import os
from datetime import datetime
import copy
import h5py
import numpy as np
import pandas as pd
import nibabel as nib
import pyvista as pv
from pathlib import Path
from scipy import stats
from scipy.linalg import pinv
from typing import Union, List, Optional, Tuple, Dict
from rich.progress import (
Progress,
BarColumn,
TimeRemainingColumn,
TextColumn,
MofNCompleteColumn,
SpinnerColumn,
)
# Importing local modules
from . import misctools as cltmisc
from . import imagetools as cltimg
from . import segmentationtools as cltseg
from . import freesurfertools as cltfree
from . import surfacetools as cltsurf
from . import bidstools as cltbids
from . import colorstools as cltcol
from . import connectivitytools as cltcon
####################################################################################################
####################################################################################################
############ ############
############ ############
############ Section 1: Class dedicated to work with parcellation images ############
############ ############
############ ############
####################################################################################################
####################################################################################################
[docs]
class Parcellation:
"""
Comprehensive class for working with brain parcellation data.
Provides tools for loading, manipulating, and analyzing brain parcellation
files with associated lookup tables. Supports filtering, masking, grouping,
volume calculations, and various export formats for neuroimaging workflows.
"""
####################################################################################################
[docs]
def __init__(
    self,
    parc_file: Union[str, Path, np.ndarray] = None,
    color_table: Optional[Union[str, Path, dict]] = None,
    affine: Optional[np.ndarray] = None,
    parc_id: Optional[str] = None,
    space_id: Optional[str] = "unknown",
):
    """
    Initialize Parcellation object from file or array.

    Parameters
    ----------
    parc_file : str, Path, or np.ndarray
        Path to a parcellation file (loaded with ``nibabel``) or a numpy
        array of labels. For ``.nii`` / ``.nii.gz`` files without an explicit
        color table, a ``.tsv`` (then ``.lut``) sidecar with the same base
        name is used when present. ``None`` raises ``ValueError``.
    color_table : str, Path, or dict, optional
        Color lookup table for parcellation regions. Can be:
        - Path to a TSV/LUT file (loaded with ``load_colortable``)
        - Dictionary with required keys 'index' and 'name' and optional keys
          'color', 'opacity', 'headerlines'
        An explicit table (path or dict) always takes priority over sidecar
        files. If None, a sidecar file is used when available; otherwise
        the table is generated from the labels present in the data.
        Default is None.
    affine : np.ndarray, optional
        4x4 affine for array input (ignored for file input, where the
        file's affine is used). If None for array input, an identity matrix
        whose translation is minus the array center (in voxels) is used.
        Default is None.
    parc_id : str, optional
        Identifier for the parcellation. If None, derived from the file name
        (see ``get_parcellation_id``) or set to 'numpy_array' for array
        input. Default is None.
    space_id : str, optional
        Space identifier (e.g., 'MNI152NLin6Asym', 'native'). For array
        input it is stored as given. For file input, a value other than
        None/"unknown" is stored as given; otherwise the space is inferred
        from the BIDS file name, falling back to "unknown".
        Default is "unknown".

    Attributes
    ----------
    data : np.ndarray
        Label array converted to int32.
    affine : np.ndarray
        4x4 affine transformation matrix.
    index : list of int
        Region codes of the color table.
    name : list of str
        Region names corresponding to ``index``.
    color : list of str
        Colors (hex) for each region.
    opacity : list of float
        Opacity for each region (1.0 when not provided).
    headerlines : list
        Header lines for the color table ([] when not provided).
    parc_file : str
        Path to the parcellation file or 'numpy_array'.
    id : str
        Identifier for the parcellation.
    space : str
        Space identifier.
    dim : tuple
        Shape of ``data``.
    voxel_size
        Value returned by ``cltimg.get_voxel_size(affine)``; ``get_info``
        prints it as a per-axis sequence.
    voxel_volume
        Value returned by ``cltimg.get_voxel_volume(affine)``.
    dtype : np.dtype
        For file input, the on-disk dtype from the image header (``data``
        itself is int32). For array input, the dtype of ``data`` (int32).

    Raises
    ------
    ValueError
        If ``parc_file`` is None or the parcellation file does not exist.
    TypeError
        If ``parc_file`` is not a str, Path, or np.ndarray.
    FileNotFoundError
        If ``color_table`` is a path that is not an existing file.

    Examples
    --------
    >>> # Load from file with automatic color table detection
    >>> parc = Parcellation('parcellation.nii.gz')
    >>>
    >>> # Load with explicit color table
    >>> parc = Parcellation('parcellation.nii.gz', color_table='colors.tsv')
    >>>
    >>> # Create from array with custom affine
    >>> parc = Parcellation(label_array, affine=img.affine, parc_id='custom')
    >>>
    >>> # Create from array with full color dictionary
    >>> color_dict = {
    ...     'index': [1, 2, 3],
    ...     'name': ['region1', 'region2', 'region3'],
    ...     'color': ['#FF0000', '#00FF00', '#0000FF'],
    ...     'opacity': [1.0, 0.8, 0.6],  # optional
    ...     'headerlines': ['# My custom parcellation']  # optional
    ... }
    >>> parc = Parcellation(label_array, color_table=color_dict)
    >>>
    >>> # Create with minimal color dictionary (opacity defaults to 1.0)
    >>> color_dict = {
    ...     'index': [1, 2, 3],
    ...     'name': ['region1', 'region2', 'region3'],
    ...     'color': ['#FF0000', '#00FF00', '#0000FF']
    ... }
    >>> parc = Parcellation(label_array, color_table=color_dict)
    """
    if parc_file is None:
        raise ValueError(
            "parc_file cannot be None. Provide a file path or numpy array."
        )

    # Handle file path input
    if isinstance(parc_file, (str, Path)):
        parc_file = str(parc_file)
        if not os.path.exists(parc_file):
            raise ValueError(f"The parcellation file does not exist: {parc_file}")

        self.parc_file = parc_file

        # Set parcellation ID
        if parc_id is not None:
            self.id = parc_id
        else:
            self.get_parcellation_id()

        # Honor an explicitly requested space; otherwise infer it from the
        # file name ("unknown" is the parameter default, not a request).
        if space_id is not None and space_id != "unknown":
            self.set_space_id(space_id)
        else:
            self.set_space_id()

        # Load the parcellation data
        temp_iparc = nib.load(parc_file)
        self.affine = temp_iparc.affine
        self.data = temp_iparc.get_fdata().astype(np.int32)
        # On-disk dtype from the header; self.data itself is int32.
        self.dtype = temp_iparc.get_data_dtype()

        # An explicit dict wins over any sidecar file; only look for a
        # color-table file when no dict was given.
        if isinstance(color_table, dict):
            self._load_colortable_from_dict(color_table)
        else:
            lut_2_load = self._determine_color_table_file(parc_file, color_table)
            if lut_2_load is not None:
                self.load_colortable(lut_file=lut_2_load)
            else:
                # Auto-generate color table
                self._create_default_colortable()

    # Handle numpy array input
    elif isinstance(parc_file, np.ndarray):
        self.parc_file = "numpy_array"
        self.id = parc_id if parc_id is not None else "numpy_array"
        self.space = space_id
        self.data = parc_file.astype(np.int32)
        self.dtype = self.data.dtype

        # Create affine matrix if not provided
        if affine is None:
            affine = np.eye(4)
            center = np.array(self.data.shape) // 2
            affine[:3, 3] = -center
        self.affine = affine

        # Handle color table
        if isinstance(color_table, dict):
            self._load_colortable_from_dict(color_table)
        elif isinstance(color_table, (str, Path)):
            # Same contract as for file input: a given path must exist.
            if not os.path.isfile(str(color_table)):
                raise FileNotFoundError(
                    f"Specified color_table does not exist: {color_table}"
                )
            self.load_colortable(lut_file=str(color_table))
        else:
            # Auto-generate color table
            self._create_default_colortable()

    else:
        raise TypeError(
            f"parc_file must be str, Path, or np.ndarray, got {type(parc_file)}"
        )

    # Ensure required attributes exist and adjust to data
    self._ensure_attributes()
    self.adjust_values()

    # Set dimensional properties
    self.dim = self.data.shape
    self.voxel_size = cltimg.get_voxel_size(self.affine)

    # Set Voxel volume
    self.voxel_volume = cltimg.get_voxel_volume(self.affine)

    # Detect label range
    self.parc_range()
def _determine_color_table_file(
    self, parc_file: str, color_table: Optional[Union[str, Path]]
) -> Optional[str]:
    """
    Resolve the color-table file that should accompany a parcellation file.

    Lookup order: an explicitly given path, then a ``.tsv`` sidecar, then a
    ``.lut`` sidecar. Sidecars share the parcellation's base name and are
    only searched for ``.nii`` / ``.nii.gz`` parcellation files.

    Parameters
    ----------
    parc_file : str
        Path to the parcellation file.
    color_table : str, Path, or None
        Explicit color-table path. Any value that is not a str/Path
        (None, a dict, ...) is ignored and the sidecar lookup is performed.

    Returns
    -------
    str or None
        Path of the color table to load, or None when nothing applies.

    Raises
    ------
    FileNotFoundError
        If ``color_table`` is a path that is not an existing file.
    """
    if isinstance(color_table, (str, Path)):
        explicit = str(color_table)
        if not os.path.isfile(explicit):
            raise FileNotFoundError(
                f"Specified color_table does not exist: {explicit}"
            )
        return explicit

    # Strip the NIfTI extension to obtain the sidecar base name.
    stem = None
    for suffix in (".nii.gz", ".nii"):
        if parc_file.endswith(suffix):
            stem = parc_file[: -len(suffix)]
            break
    if stem is None:
        return None

    # TSV sidecars take precedence over LUT sidecars.
    for extension in (".tsv", ".lut"):
        candidate = stem + extension
        if os.path.isfile(candidate):
            return candidate
    return None
def _load_colortable_from_dict(self, color_dict: dict) -> None:
    """
    Populate the color-table attributes from a dictionary.

    Parameters
    ----------
    color_dict : dict
        Required keys:
        - 'index': sequence of region codes (converted to int)
        - 'name': sequence of region names, same length as 'index'
        Optional keys (a value of None is treated like a missing key):
        - 'color': sequence of colors, same length as 'index'. If absent,
          distinguishable colors are generated. Given or generated colors
          are normalized with ``cltcol.harmonize_colors(..., "hex")``.
        - 'opacity': sequence of opacities, same length as 'index'
          (converted to float; defaults to 1.0 for every region)
        - 'headerlines': sequence of header lines (defaults to [])

    Raises
    ------
    ValueError
        If a required key is missing or any provided list's length differs
        from 'index'. All validation happens before any attribute is set,
        so a rejected dict leaves the object unchanged.
    """
    required_keys = {"index", "name"}
    if not required_keys.issubset(color_dict.keys()):
        raise ValueError(
            f"color_table dict must contain keys: {required_keys}. "
            f"Got: {set(color_dict.keys())}"
        )

    index = color_dict["index"]
    name = color_dict["name"]
    n_regions = len(index)
    if len(name) != n_regions:
        raise ValueError(
            f"All required lists in color_table dict must have same length. "
            f"Got: index={n_regions}, name={len(name)}"
        )

    # Optional entries ('color' is optional: it must not be read eagerly).
    color = color_dict.get("color")
    opacity = color_dict.get("opacity")
    headerlines = color_dict.get("headerlines")

    if color is not None and len(color) != n_regions:
        raise ValueError(
            f"If provided, color list must have same length as index. "
            f"Got: color={len(color)}, index={n_regions}"
        )
    if opacity is not None and len(opacity) != n_regions:
        raise ValueError(
            f"If provided, opacity list must have same length as index. "
            f"Got: opacity={len(opacity)}, index={n_regions}"
        )

    # Convert before mutating self so conversion errors leave it untouched.
    index_list = [int(x) for x in index]
    opacity_list = (
        [float(x) for x in opacity] if opacity is not None else [1.0] * n_regions
    )

    if color is None:
        color = cltcol.create_distinguishable_colors(n_regions, output_format="hex")
    # Force the colors (given or generated) to be in hex format
    color = cltcol.harmonize_colors(color, output_format="hex")

    self.index = index_list
    self.name = list(name)
    self.color = list(color)
    self.opacity = opacity_list
    self.headerlines = list(headerlines) if headerlines is not None else []
def _create_default_colortable(self) -> None:
    """
    Build a color table from the labels present in ``self.data``.

    Every distinct non-zero label becomes a region: its name comes from
    ``cltmisc.create_names_from_indices``, its color from
    ``cltcol.create_distinguishable_colors`` (hex), and its opacity is 1.0.
    ``headerlines`` is reset to an empty list. When the data contains no
    non-zero label, ``color`` and ``opacity`` are empty lists.
    """
    labels = np.unique(self.data)
    self.index = [int(code) for code in labels if code != 0]
    self.name = cltmisc.create_names_from_indices(self.index)

    n_regions = len(self.index)
    if n_regions == 0:
        self.color = []
        self.opacity = []
    else:
        self.color = cltcol.create_distinguishable_colors(
            n_regions, output_format="hex"
        )
        self.opacity = [1.0] * n_regions

    self.headerlines = []
def _ensure_attributes(self) -> None:
    """
    Fill in missing color-table attributes and normalize their types.

    - ``index``: derived from the non-zero labels of ``self.data`` when
      missing or None; always coerced to a list of int.
    - ``name``: generated with ``cltmisc.create_names_from_indices`` when
      missing or None.
    - ``color``: generated hex colors when missing or None ([] when there
      are no regions).
    - ``opacity``: missing/None becomes 1.0 per region; a scalar is
      broadcast to every region; anything that is not a flat sequence with
      one entry per region is replaced by 1.0 per region; otherwise it is
      coerced to a list of float.
    - ``headerlines``: [] when missing or None.

    Finally calls ``self.parc_range()`` to refresh the label range.
    """
    if getattr(self, "index", None) is None:
        labels = np.unique(self.data)
        self.index = labels[labels != 0]
    self.index = [int(code) for code in self.index]
    n_regions = len(self.index)

    if getattr(self, "name", None) is None:
        self.name = cltmisc.create_names_from_indices(self.index)

    if getattr(self, "color", None) is None:
        self.color = (
            cltcol.create_distinguishable_colors(n_regions, output_format="hex")
            if n_regions > 0
            else []
        )

    opacity = getattr(self, "opacity", None)
    if opacity is None:
        opacity = [1.0] * n_regions
    elif not hasattr(opacity, "__len__"):
        # Scalar opacity: apply it to every region.
        opacity = [float(opacity)] * n_regions

    opacity_arr = np.array(opacity)
    if opacity_arr.ndim == 1 and len(opacity_arr) == n_regions:
        self.opacity = [float(value) for value in opacity_arr.tolist()]
    else:
        self.opacity = [1.0] * n_regions

    if getattr(self, "headerlines", None) is None:
        self.headerlines = []

    # Detect minimum and maximum labels
    self.parc_range()
#####################################################################################################
[docs]
def get_space_id(self) -> str:
    """
    Return the parcellation's space identifier without modifying it.

    If the ``space`` attribute is already set, its value is returned as-is.
    Otherwise the ``space`` entity is read from a BIDS-style ``parc_file``
    name. This method never assigns ``self.space``.

    Returns
    -------
    str
        The current or inferred space identifier, or "unknown" when it
        cannot be determined (no/empty ``parc_file``, a non-BIDS file name,
        or a BIDS name without a ``space`` entity).

    Examples
    --------
    >>> parc = Parcellation('sub-01_space-t1_atlas-xxx.nii.gz')
    >>> parc.get_space_id()
    't1'
    >>> parc = Parcellation('custom_parcellation.nii.gz')
    >>> parc.get_space_id()
    'unknown'
    """
    # An already-assigned space always wins.
    if hasattr(self, "space"):
        return self.space

    source_path = getattr(self, "parc_file", None)
    if not source_path:
        return "unknown"

    file_name = os.path.basename(source_path)
    if not cltbids.is_bids_filename(file_name):
        return "unknown"

    entities = cltbids.str2entity(file_name)
    return entities["space"] if "space" in entities else "unknown"
#####################################################################################################
[docs]
def set_space_id(self, space_id: Optional[str] = None) -> str:
    """
    Assign ``self.space`` and return the assigned value.

    Parameters
    ----------
    space_id : str, optional
        Space identifier to store. When None, the value returned by
        ``get_space_id()`` is stored instead: the current ``space`` if one
        is already set, otherwise the space inferred from the BIDS file
        name, or "unknown".

    Returns
    -------
    str
        The space identifier now stored in ``self.space``.

    Examples
    --------
    >>> parc = Parcellation('sub-01_atlas-xxx.nii.gz')
    >>> parc.set_space_id('mni152')
    'mni152'
    >>> parc = Parcellation('sub-01_space-t1_atlas-xxx.nii.gz')
    >>> parc.set_space_id()
    't1'
    >>> parc = Parcellation('custom_parcellation.nii.gz')
    >>> parc.set_space_id()
    'unknown'
    """
    resolved = self.get_space_id() if space_id is None else space_id
    self.space = resolved
    return resolved
####################################################################################################
[docs]
def get_parcellation_id(self) -> str:
    """
    Derive the parcellation identifier from ``parc_file`` and store it in ``self.id``.

    For BIDS-style file names the identifier joins, with underscores, the
    entities that are present, in this order: ``atlas-<..>``, ``seg-<..>``,
    ``scale-<..>``, ``desc-<..>``. For non-BIDS names, and for BIDS names
    that contain none of those four entities, the file name without its
    extension is used (``.nii.gz`` counts as a single extension).

    Returns
    -------
    str
        The identifier, also assigned to ``self.id``.

    Raises
    ------
    ValueError
        If ``parc_file`` is not set.

    Examples
    --------
    >>> parc = Parcellation('sub-01_ses-01_acq-mprage_space-t1_atlas-xxx_seg-yyy_scale-1_desc-test.nii.gz')
    >>> parc.get_parcellation_id()
    'atlas-xxx_seg-yyy_scale-1_desc-test'
    >>> parc = Parcellation('custom_parcellation.nii.gz')
    >>> parc.get_parcellation_id()
    'custom_parcellation'
    """
    if not hasattr(self, "parc_file"):
        raise ValueError(
            "The parcellation file is not set. Please load a parcellation file first."
        )

    parc_file_name = os.path.basename(self.parc_file)

    parc_fullid = ""
    if cltbids.is_bids_filename(parc_file_name):
        name_ent_dict = cltbids.str2entity(parc_file_name)
        parc_fullid = "_".join(
            f"{entity}-{name_ent_dict[entity]}"
            for entity in ("atlas", "seg", "scale", "desc")
            if entity in name_ent_dict
        )

    # Fallback: non-BIDS names, or BIDS names lacking every id entity
    # (previously this produced an empty identifier).
    if not parc_fullid:
        if parc_file_name.endswith(".nii.gz"):
            parc_fullid = parc_file_name[: -len(".nii.gz")]
        else:
            # splitext only removes a real extension (no blind 4-char chop).
            parc_fullid = os.path.splitext(parc_file_name)[0]

    self.id = parc_fullid
    return parc_fullid
######################################################################################################
[docs]
def get_info(self, verbose: bool = True) -> dict:
    """
    Display and return a summary of the Parcellation object.

    Gathers identification, image properties, color-table statistics and
    label-consistency checks. Every attribute is read defensively, so the
    method also works on partially initialized objects (missing values are
    reported as None in the dictionary and "N/A" in the printout).

    Parameters
    ----------
    verbose : bool, optional
        If True, print the summary as a bordered table to stdout.
        If False, only return the dictionary. Default is True.

    Returns
    -------
    info : dict
        - 'id' : ``self.id`` or None.
        - 'space' : ``self.space`` or None.
        - 'parc_file' : ``self.parc_file`` (a path or 'numpy_array') or None.
        - 'dim' : ``self.dim`` or None.
        - 'voxel_size' : ``self.voxel_size`` or None; printed as an
          " x "-joined list (a scalar is printed as a single value).
        - 'voxel_volume' : ``self.voxel_volume`` or None.
        - 'dtype' : str -- ``str(self.dtype)`` ('None' when unset).
        - 'affine' : ``self.affine`` or None.
        - 'n_regions' : int or None -- ``len(self.index)``.
        - 'min_label' / 'max_label' : the ``minlab`` / ``maxlab``
          attributes, or None.
        - 'opacity_range' : ``(min, max)`` of ``self.opacity``; None if
          opacity is unset, empty, or not numeric.
        - 'labels_not_in_table' : list of int -- non-zero labels present in
          ``self.data`` but absent from ``self.index`` (ideally empty).
        - 'regions_not_in_data' : list -- names (codes when ``name`` is
          unset) of table entries whose codes do not occur in ``self.data``.

    Examples
    --------
    >>> parc = Parcellation('sub-01_space-MNI_atlas-Schaefer_desc-400Parcels.nii.gz')
    >>> info = parc.get_info()
    ╔════════════════════════════════════════════════════════════════╗
    ║                       PARCELLATION INFO                        ║
    ╠════════════════════════════════════════════════════════════════╣
    ║ ID : atlas-Schaefer_desc-400Parcels                            ║
    ║ Space : MNI                                                    ║
    ║ File : ...sub-01_space-MNI_atlas-Schaefer_desc-400Parcels.nii.gz║
    ╠════════════════════════════════════════════════════════════════╣
    ║ IMAGE PROPERTIES                                               ║
    ║ Dimensions :      182 x 218 x 182                              ║
    ║ Voxel size :      1.0 x 1.0 x 1.0 mm                           ║
    ║ Voxel volume :    1.000 mm³                                    ║
    ║ Data type :                 int32                              ║
    ╠════════════════════════════════════════════════════════════════╣
    ║ COLOR TABLE                                                    ║
    ║ Regions :                     400                              ║
    ║ Label range :             1 → 400                              ║
    ║ Opacity :             1.00 → 1.00                              ║
    ╠════════════════════════════════════════════════════════════════╣
    ║ LABEL CONSISTENCY                                              ║
    ║ Labels not in table :    0                                     ║
    ║ Regions not in data :    0                                     ║
    ╚════════════════════════════════════════════════════════════════╝
    """
    WIDTH = 64  # inner content width (between the ║ borders)

    def _row(content: str) -> None:
        """Print a single bordered row, left-justified."""
        print(f"║{content.ljust(WIDTH)}║")

    def _trunc(s: str, max_w: int) -> str:
        """Keep the tail of ``s`` so that it fits in ``max_w`` characters."""
        return s if len(s) <= max_w else "..." + s[-(max_w - 3) :]

    # ── Gather information ────────────────────────────────────────────────────
    info: dict = {
        "id": getattr(self, "id", None),
        "space": getattr(self, "space", None),
        "parc_file": getattr(self, "parc_file", None),
        "dim": getattr(self, "dim", None),
        "voxel_size": getattr(self, "voxel_size", None),
        "voxel_volume": getattr(self, "voxel_volume", None),
        "dtype": str(getattr(self, "dtype", None)),
        "affine": getattr(self, "affine", None),
        "n_regions": None,
        "min_label": getattr(self, "minlab", None),
        "max_label": getattr(self, "maxlab", None),
        "opacity_range": None,
        "labels_not_in_table": [],
        "regions_not_in_data": [],
    }

    # Color-table stats
    index = getattr(self, "index", None)
    if index is not None:
        info["n_regions"] = len(index)

    opacity = getattr(self, "opacity", None)
    if opacity is not None:
        try:
            op_arr = np.asarray(opacity, dtype=float)
        except (TypeError, ValueError):
            op_arr = None  # non-numeric opacity: no meaningful range
        # An empty table has no opacity range.
        if op_arr is not None and op_arr.size > 0:
            info["opacity_range"] = (float(op_arr.min()), float(op_arr.max()))

    # Label consistency checks (only when both data and index exist)
    data = getattr(self, "data", None)
    if data is not None and index is not None:
        data_codes = {int(v) for v in np.unique(data) if v != 0}
        table_codes = {int(v) for v in index}
        info["labels_not_in_table"] = sorted(data_codes - table_codes)
        missing_in_data = sorted(table_codes - data_codes)

        names = getattr(self, "name", None)
        if names is not None:
            idx_to_name = {int(c): n for c, n in zip(index, names)}
            info["regions_not_in_data"] = [
                idx_to_name.get(c, str(c)) for c in missing_in_data
            ]
        else:
            info["regions_not_in_data"] = missing_in_data

    # ── Print ─────────────────────────────────────────────────────────────────
    if verbose:
        print("╔" + "═" * WIDTH + "╗")
        _row(" PARCELLATION INFO".center(WIDTH))
        print("╠" + "═" * WIDTH + "╣")

        # — Identification ————————————————————————————————————————————————
        id_val = info["id"] or "N/A"
        space_val = info["space"] or "N/A"
        file_val = _trunc(info["parc_file"] or "N/A", WIDTH - 12)
        _row(f" ID : {id_val}")
        _row(f" Space : {space_val}")
        _row(f" File : {file_val}")

        # — Image properties ——————————————————————————————————————————————
        print("╠" + "═" * WIDTH + "╣")
        _row(" IMAGE PROPERTIES")
        if info["dim"] is not None:
            dim_str = " x ".join(str(d) for d in info["dim"])
            _row(f" Dimensions : {dim_str:>20}")
        else:
            _row(" Dimensions : N/A")

        if info["voxel_size"] is not None:
            # atleast_1d: accept a per-axis sequence or a single scalar size.
            vs_str = " x ".join(str(d) for d in np.atleast_1d(info["voxel_size"]))
            _row(f" Voxel size : {vs_str:>20} mm")
        else:
            _row(" Voxel size : N/A")

        if info["voxel_volume"] is not None:
            _row(f" Voxel volume : {info['voxel_volume']:>16.3f} mm³")
        else:
            _row(" Voxel volume : N/A")

        _row(f" Data type : {info['dtype']:>20}")

        # — Color table ———————————————————————————————————————————————————
        print("╠" + "═" * WIDTH + "╣")
        _row(" COLOR TABLE")
        if info["n_regions"] is not None:
            _row(f" Regions : {info['n_regions']:>20,}")
        else:
            _row(" Regions : N/A")

        if info["min_label"] is not None and info["max_label"] is not None:
            range_str = f"{info['min_label']} → {info['max_label']}"
            _row(f" Label range : {range_str:>20}")
        else:
            _row(" Label range : N/A")

        if info["opacity_range"] is not None:
            op_str = (
                f"{info['opacity_range'][0]:.2f} → {info['opacity_range'][1]:.2f}"
            )
            _row(f" Opacity : {op_str:>20}")
        else:
            _row(" Opacity : N/A")

        # — Label consistency ——————————————————————————————————————————————
        print("╠" + "═" * WIDTH + "╣")
        _row(" LABEL CONSISTENCY")
        n_missing_table = len(info["labels_not_in_table"])
        n_missing_data = len(info["regions_not_in_data"])

        _row(f" Labels not in table : {n_missing_table:>4}")
        if n_missing_table > 0:
            codes_str = ", ".join(str(v) for v in info["labels_not_in_table"])
            _row(f" ↳ codes : {_trunc(codes_str, WIDTH - 14)}")

        _row(f" Regions not in data : {n_missing_data:>4}")
        if n_missing_data > 0:
            names_str = ", ".join(str(v) for v in info["regions_not_in_data"])
            _row(f" ↳ names : {_trunc(names_str, WIDTH - 14)}")

        print("╚" + "═" * WIDTH + "╝")

    return info
####################################################################################################
[docs]
def export_summary_to_hdf5(self, out_file: str, overwrite: bool = False):
    """
    Export a summary of the parcellation to an HDF5 file.

    Datasets are written under the group ``parcellation_<id>/space-<space>``:

    - ``header/file_path``, ``header/lut_file``, ``header/dim``,
      ``header/voxel_size``, ``header/affine`` -- each only when the
      corresponding attribute (``parc_file``, ``lut_file``, ``dim``,
      ``voxel_size``, ``affine``) exists.
    - ``header/id`` and ``header/space``.
    - ``header/num_regions`` -- ``len(self.index)``, or the number of
      distinct non-zero labels in the data when ``index`` is missing.
    - ``header/min_label`` / ``header/max_label`` -- the ``min_label`` /
      ``max_label`` attributes when present, otherwise the smallest/largest
      non-zero label in the data (0 when the data has no non-zero label).
    - ``regions_indices``, ``regions_names``, ``regions_colors`` -- from
      ``index``, ``name``, ``color`` when present. Lists of text are stored
      as variable-length UTF-8 strings.
    - ``regions_centroids`` / ``time_series`` -- from ``centroids`` /
      ``timeseries`` when present.
    - ``morphometry`` -- written with ``cltmisc.save_morphometry_hdf5`` when
      a ``morphometry`` attribute exists.

    The label volume (``self.data``) itself is not stored.

    Parameters
    ----------
    out_file : str
        Path to the output HDF5 file. Its parent directory must exist.
    overwrite : bool, optional
        If True, an existing ``out_file`` is replaced. Default is False.

    Raises
    ------
    ValueError
        If the output directory does not exist, if ``out_file`` exists and
        ``overwrite`` is False, or if the parcellation data is not set.

    Examples
    --------
    >>> parc.export_summary_to_hdf5('parcellation_summary.h5')
    >>> parc.export_summary_to_hdf5('parcellation_summary.h5', overwrite=True)
    """
    out_path = Path(out_file)

    if not out_path.parent.exists():
        raise ValueError(
            f"The output directory {out_path.parent} does not exist. Please create it first."
        )
    if out_path.exists() and not overwrite:
        raise ValueError(
            f"The output file {out_path} already exists. Use overwrite=True to overwrite it."
        )
    if not hasattr(self, "data"):
        raise ValueError(
            "The parcellation data is not set. Please load a parcellation file first."
        )

    # Make sure the identifiers used in the group path exist.
    if not hasattr(self, "id"):
        self.get_parcellation_id()
    if not hasattr(self, "space"):
        # get_space_id() takes no arguments; set_space_id() infers the space
        # from the file name (falling back to "unknown") and stores it.
        self.set_space_id()

    base_cad = f"parcellation_{self.id}/space-{self.space}"
    header = f"{base_cad}/header"

    # Distinct non-zero labels in the data, used for fallback statistics.
    data_labels = np.unique(self.data)
    data_labels = data_labels[data_labels != 0]

    # h5py cannot store numpy unicode ('<U') arrays; text lists need a
    # variable-length string dtype.
    str_dtype = h5py.string_dtype(encoding="utf-8")

    def _write_list(hf, path: str, values) -> None:
        """Write a list, using the UTF-8 string dtype when all items are str."""
        values = list(values)
        if values and all(isinstance(v, str) for v in values):
            hf.create_dataset(path, data=values, dtype=str_dtype)
        else:
            hf.create_dataset(path, data=values)

    # The context manager closes the file even if a write fails.
    with h5py.File(out_file, "w") as hf:
        if hasattr(self, "parc_file"):
            hf.create_dataset(f"{header}/file_path", data=self.parc_file)
        if hasattr(self, "lut_file"):
            hf.create_dataset(f"{header}/lut_file", data=self.lut_file)
        hf.create_dataset(f"{header}/id", data=self.id)
        hf.create_dataset(f"{header}/space", data=self.space)
        if hasattr(self, "dim"):
            hf.create_dataset(f"{header}/dim", data=self.dim)
        if hasattr(self, "voxel_size"):
            hf.create_dataset(f"{header}/voxel_size", data=self.voxel_size)
        if hasattr(self, "affine"):
            hf.create_dataset(f"{header}/affine", data=self.affine)

        n_regions = len(self.index) if hasattr(self, "index") else len(data_labels)
        hf.create_dataset(f"{header}/num_regions", data=n_regions)

        if hasattr(self, "min_label"):
            min_label = self.min_label
        else:
            # 0 when the data has no non-zero label (np.min would raise).
            min_label = int(data_labels.min()) if data_labels.size else 0
        hf.create_dataset(f"{header}/min_label", data=min_label)

        if hasattr(self, "max_label"):
            max_label = self.max_label
        else:
            max_label = int(data_labels.max()) if data_labels.size else 0
        hf.create_dataset(f"{header}/max_label", data=max_label)

        if hasattr(self, "index"):
            hf.create_dataset(f"{base_cad}/regions_indices", data=self.index)
        if hasattr(self, "name"):
            _write_list(hf, f"{base_cad}/regions_names", self.name)
        if hasattr(self, "color"):
            _write_list(hf, f"{base_cad}/regions_colors", self.color)
        if hasattr(self, "centroids"):
            hf.create_dataset(f"{base_cad}/regions_centroids", data=self.centroids)
        if hasattr(self, "timeseries"):
            hf.create_dataset(f"{base_cad}/time_series", data=self.timeseries)

    # Save the morphometry DataFrame if it exists (after the file is closed).
    if hasattr(self, "morphometry"):
        # Append rather than write: NOTE(review) assumes the helper follows
        # h5py/pandas mode semantics, where "w" truncates the file and would
        # discard every dataset written above -- confirm.
        cltmisc.save_morphometry_hdf5(
            out_file, f"{base_cad}/morphometry", self.morphometry, mode="a"
        )
####################################################################################################
[docs]
def prepare_for_tracking(self):
    """
    Prepare the parcellation for fiber tracking (Chimera labeling scheme).

    Steps, in order:

    1. Every label >= 5009 (corpus callosum) is relabeled to 3000.
    2. Every remaining label >= 5000 ("other structures") is removed with
       ``remove_by_code`` (skipped when there is none).
    3. 3000 is subtracted from every label >= 3000, so a white-matter label
       ``3000 + k`` is merged into gray-matter label ``k``. Consequently the
       generic white-matter label 3000, including the corpus-callosum voxels
       relabeled in step 1, becomes 0 (background).
    4. ``adjust_values`` is called to resynchronize the table with the data.

    Works on label arrays of any dimensionality.

    This is meant for parcellations generated with Chimera
    (https://github.com/connectomicslab/chimera), but applies to any
    parcellation using the same labeling scheme:

    - Gray matter regions: 1-2999
    - White matter regions: 3000-4999 (merged with gray matter)
        - 3000 for the generic white matter label
        - 3000 + <gray matter label> for FreeSurfer cortical white matter
    - Other structures: 5000-5008 (removed)
    - Corpus callosum: 5009-5013 (merged with white matter label 3000)

    Examples
    --------
    >>> parc.prepare_for_tracking()
    >>> print(f"Max label after prep: {parc.data.max()}")
    """
    # Corpus callosum -> generic white matter label (boolean masks work for
    # any array rank, unlike argwhere + 3-column indexing).
    self.data[self.data >= 5009] = 3000

    # Remaining labels >= 5000 are "other structures": remove them.
    codes_to_remove = [int(v) for v in np.unique(self.data) if v >= 5000]
    if codes_to_remove:
        self.remove_by_code(codes2remove=codes_to_remove)

    # Merge white matter (3000 + k) into the gray matter label k.
    wm_mask = self.data >= 3000
    self.data[wm_mask] -= 3000

    # Adjust the values
    self.adjust_values()
####################################################################################################
[docs]
def keep_by_name(self, names2keep: Union[list, str], rearrange: bool = False):
    """
    Keep only the regions whose names match any of the given substrings.

    Matching is delegated to ``cltmisc.get_indexes_by_substring`` (called
    with ``bool_case=False``). The codes of the matching regions are passed
    to ``keep_by_code``. If the object lacks ``index``, ``name`` or
    ``color``, nothing happens. If no region matches, a message is printed
    and the parcellation is left unchanged.

    Parameters
    ----------
    names2keep : str or list
        Substring(s) to look for in region names.
    rearrange : bool, optional
        Forwarded to ``keep_by_code``; whether to renumber labels from 1.
        Default is False.

    Examples
    --------
    >>> # Keep only hippocampal regions
    >>> parc.keep_by_name('hippocampus')
    >>>
    >>> # Keep multiple regions and rearrange
    >>> parc.keep_by_name(['frontal', 'parietal'], rearrange=True)
    """
    patterns = [names2keep] if isinstance(names2keep, str) else names2keep

    has_table = all(hasattr(self, attr) for attr in ("index", "name", "color"))
    if not has_table:
        return

    matches = cltmisc.get_indexes_by_substring(
        input_list=self.name,
        or_filter=patterns,
        invert=False,
        bool_case=False,
    )
    if len(matches) == 0:
        print("The names were not found in the parcellation")
        return

    selected_codes = [self.index[pos] for pos in matches]
    self.keep_by_code(codes2keep=selected_codes, rearrange=rearrange)
#####################################################################################################
[docs]
def keep_by_code(
    self, codes2keep: Union[str, list, np.ndarray], rearrange: bool = False
):
    """
    Filter parcellation to keep only specified region codes.

    Voxels whose label is not in ``codes2keep`` are set to 0. If the object
    has an ``index`` attribute, ``index`` and any present ``name``, ``color``
    and ``opacity`` attributes are pruned to the codes still present in the
    data afterwards.

    Parameters
    ----------
    codes2keep : str, list, tuple, set, range or np.ndarray
        Region codes to retain in the parcellation. Strings and list-like
        containers are expanded with ``cltmisc.build_indices``, so range
        strings such as "11:12" or "50-52" can be used. NumPy arrays are
        used as-is. An empty selection clears every voxel.
    rearrange : bool, optional
        Whether to rearrange labels consecutively from 1 (via
        ``self.rearrange()``). Default is False.

    Notes
    -----
    If ``opacity`` is not a 1-D sequence with the same length as ``index``,
    it is replaced with 1.0 for every remaining region. ``self.parc_range()``
    is always called at the end.

    Examples
    --------
    >>> # Keep specific regions
    >>> parc.keep_by_code([1, 2, 5, 10])
    >>>
    >>> # Keep and rearrange
    >>> parc.keep_by_code([100, 200, 300], rearrange=True)
    """
    # Normalise the selection. Non-list containers (tuple/set/range) are turned
    # into lists so they go through build_indices as well. np.array(set)
    # would otherwise produce a 0-d object array that matches nothing.
    if isinstance(codes2keep, str):
        codes2keep = [codes2keep]
    elif isinstance(codes2keep, (tuple, set, frozenset, range)):
        codes2keep = list(codes2keep)
    if isinstance(codes2keep, list):
        codes2keep = cltmisc.build_indices(codes2keep)
    codes2keep = np.asarray(codes2keep)

    # Zero every voxel whose label is not selected
    keep_mask = np.isin(self.data, codes2keep)
    self.data[~keep_mask] = 0

    # Codes actually present after filtering (background excluded)
    remaining_codes = np.unique(self.data)
    remaining_codes = remaining_codes[remaining_codes != 0]

    # Prune metadata. Every per-region list is aligned with ``index``, so the
    # pruning mask only exists (and only makes sense) when ``index`` does.
    if hasattr(self, "index"):
        temp_index = np.array(self.index)
        metadata_mask = np.isin(temp_index, remaining_codes)
        self.index = temp_index[metadata_mask].tolist()

        if hasattr(self, "name"):
            self.name = np.array(self.name)[metadata_mask].tolist()
        if hasattr(self, "color"):
            self.color = np.array(self.color)[metadata_mask].tolist()
        if hasattr(self, "opacity"):
            opacity_arr = np.array(self.opacity)
            # Opacity must be 1-D and aligned with index to be masked
            if opacity_arr.ndim == 1 and len(opacity_arr) == len(metadata_mask):
                self.opacity = opacity_arr[metadata_mask].tolist()
            else:
                # Fallback: default opacity for the remaining regions
                self.opacity = [1.0] * int(metadata_mask.sum())

    # If rearrange is True, the parcellation will be rearranged starting from 1
    if rearrange:
        self.rearrange()

    # Detect minimum and maximum labels
    self.parc_range()
#####################################################################################################
[docs]
def remove_by_code(
    self, codes2remove: Union[str, list, np.ndarray], rearrange: bool = False
):
    """
    Delete the given region codes from the parcellation.

    Voxels labelled with any of ``codes2remove`` are set to 0. The non-zero
    labels still present afterwards are passed to :meth:`keep_by_code`, which
    prunes the metadata, optionally rearranges labels and refreshes the label
    range.

    Parameters
    ----------
    codes2remove : str, list or np.ndarray
        Region codes to remove. Strings and lists are expanded with
        ``cltmisc.build_indices``.
    rearrange : bool, optional
        Whether to rearrange remaining labels from 1. Default is False.

    Examples
    --------
    >>> # Remove specific regions
    >>> parc.remove_by_code([1, 5, 10])
    >>>
    >>> # Remove and rearrange
    >>> parc.remove_by_code([100, 200], rearrange=True)
    """
    if isinstance(codes2remove, str):
        codes2remove = [codes2remove]
    if isinstance(codes2remove, list):
        codes2remove = cltmisc.build_indices(codes2remove)

    doomed_voxels = np.isin(self.data, np.array(codes2remove))
    self.data[doomed_voxels] = 0

    # Non-background labels that survived the removal
    surviving_codes = np.setdiff1d(np.unique(self.data), [0])

    # keep_by_code cleans metadata and already calls parc_range()
    self.keep_by_code(codes2keep=surviving_codes, rearrange=rearrange)
#####################################################################################################
[docs]
def remove_by_name(self, names2remove: Union[list, str], rearrange: bool = False):
    """
    Remove the regions whose names contain any of the given substrings.

    Names are matched with ``cltmisc.get_indexes_by_substring`` (called with
    ``bool_case=False``). The corresponding codes are removed through
    :meth:`remove_by_code`.

    Parameters
    ----------
    names2remove : str or list
        Name substring(s) to search for removal.
    rearrange : bool, optional
        Whether to rearrange remaining labels from 1. Default is False.

    Raises
    ------
    AttributeError
        If the object has no ``name`` or no ``index`` attribute.

    Notes
    -----
    If nothing matches, a message is printed and the parcellation is left
    unchanged.

    Examples
    --------
    >>> # Remove ventricles
    >>> parc.remove_by_name('ventricle')
    >>>
    >>> # Remove multiple structures
    >>> parc.remove_by_name(['csf', 'unknown'], rearrange=True)
    """
    patterns = [names2remove] if isinstance(names2remove, str) else names2remove

    if not hasattr(self, "name") or not hasattr(self, "index"):
        raise AttributeError(
            "Parcellation must have 'name' and 'index' attributes to remove by name"
        )

    hits = cltmisc.get_indexes_by_substring(
        input_list=self.name, or_filter=patterns, invert=False, bool_case=False
    )
    if len(hits) == 0:
        print(f"No regions found matching: {patterns}")
        return

    # remove_by_code takes care of metadata cleanup and parc_range()
    self.remove_by_code(
        codes2remove=[self.index[pos] for pos in hits], rearrange=rearrange
    )
#####################################################################################################
[docs]
def apply_mask(
    self,
    image_mask: Union[str, Path, np.ndarray],
    mask_codes: Union[str, list, np.ndarray] = None,
    invert: bool = False,
    fill: bool = False,
):
    """
    Restrict the parcellation spatially with a mask image.

    Parameters
    ----------
    image_mask : str, Path, np.ndarray, or Parcellation
        Mask given as a NIfTI file path (loaded with ``nib.load(...).get_fdata()``),
        a 3D array, or another Parcellation (its ``data`` is used). It must
        have the same shape as ``self.data``.
    mask_codes : str, list or np.ndarray, optional
        Values of the mask that define the masked region. If None, all
        non-zero mask values are used. Otherwise they are expanded with
        ``cltmisc.build_indices``. Default is None.
    invert : bool, optional
        If False, keep only voxels where the mask has the selected codes.
        If True, clear those voxels instead. Default is False.
    fill : bool, optional
        If True, grow the remaining regions with ``cltimg.region_growing``
        into the kept area (the selected area, or its complement when
        ``invert`` is True). Default is False.

    Raises
    ------
    ValueError
        If the mask file does not exist, the mask type is unsupported, or the
        mask shape differs from the parcellation shape.

    Examples
    --------
    >>> # Apply binary cortical mask
    >>> parc.apply_mask(cortex_mask)
    >>>
    >>> # Mask using specific regions from another parcellation
    >>> parc.apply_mask(roi_parc, mask_codes=[1, 2, 3])
    >>>
    >>> # Inverse masking with region growing
    >>> parc.apply_mask(exclusion_mask, invert=True, fill=True)
    """
    # --- obtain the mask volume -------------------------------------------
    if isinstance(image_mask, (str, Path)):
        mask_path = str(image_mask)
        if not os.path.exists(mask_path):
            raise ValueError(f"Mask file does not exist: {mask_path}")
        mask_volume = nib.load(mask_path).get_fdata()
    elif isinstance(image_mask, np.ndarray):
        mask_volume = image_mask
    elif isinstance(image_mask, Parcellation):
        mask_volume = image_mask.data
    else:
        raise ValueError(
            "image_mask must be a file path, numpy array, or Parcellation object"
        )

    if mask_volume.shape != self.data.shape:
        raise ValueError(
            f"Mask shape {mask_volume.shape} doesn't match parcellation shape {self.data.shape}"
        )

    # --- which mask values count as "inside" --------------------------------
    if mask_codes is not None:
        requested = [mask_codes] if isinstance(mask_codes, str) else mask_codes
        selected_values = np.array(cltmisc.build_indices(requested))
    else:
        selected_values = np.unique(mask_volume)
        selected_values = selected_values[selected_values != 0]

    inside = np.isin(mask_volume, selected_values)

    # The area that survives is the selection itself, or its complement when
    # inverting. Everything outside of it is cleared, and region growing (if
    # requested) is confined to it.
    kept_area = ~inside if invert else inside
    self.data[~kept_area] = 0

    if fill:
        self.data = cltimg.region_growing(self.data, kept_area)

    self.adjust_values()
    self.parc_range()
####################################################################################################
[docs]
def mask_image(
    self,
    image_2mask: Union[str, Path, list, np.ndarray],
    masked_image: Union[str, Path, list, None] = None,
    roi_codes: Union[str, list, np.ndarray] = None,
    roi_names: Union[str, list] = None,
    invert: bool = False,
) -> Union[np.ndarray, list]:
    """
    Mask external images using parcellation as binary mask.

    Parameters
    ----------
    image_2mask : str, Path, list, or np.ndarray
        Image(s) to mask. Either one path, a non-empty list of paths, or an
        array whose first three dimensions match the parcellation (extra
        trailing dimensions, e.g. time, are masked for every volume).
    masked_image : str, Path, list, optional
        Output path(s) for masked images. Required when image_2mask is path(s)
        and must have the same length. Ignored for array input. Default is None.
    roi_codes : str, list or np.ndarray, optional
        Region codes to use for masking (expanded with ``cltmisc.build_indices``).
        Default is None (all non-zero regions).
    roi_names : str or list, optional
        Region name substrings used to select the regions. Default is None.
    invert : bool, optional
        If False, keep only voxels within the selected regions.
        If True, zero the voxels within the selected regions.
        Default is False.

    Returns
    -------
    np.ndarray or list
        For array input, a masked copy (the input is not modified).
        For path input, the list of output paths written.

    Raises
    ------
    ValueError
        If the input type is unsupported, a path list is empty or mixes paths
        with other objects, both roi_codes and roi_names are given, no region
        matches roi_names, output paths are missing or don't match inputs in
        length, an input file doesn't exist, or shapes don't match.

    Notes
    -----
    File images are read with ``get_fdata()`` (floating point) and saved with
    the input image's affine and header.

    Examples
    --------
    >>> # Mask T1 image with all parcellation regions
    >>> parc.mask_image('T1w.nii.gz', 'T1w_masked.nii.gz')
    ['T1w_masked.nii.gz']
    >>> # Mask with specific region codes
    >>> parc.mask_image('fmri.nii.gz', 'fmri_masked.nii.gz', roi_codes=[1, 2, 3])
    ['fmri_masked.nii.gz']
    >>> # Mask with specific region names
    >>> parc.mask_image('dwi.nii.gz', 'dwi_masked.nii.gz', roi_names=['cortex', 'hippocampus'])
    ['dwi_masked.nii.gz']
    >>> # Inverted masking (remove specific regions)
    >>> parc.mask_image('dwi.nii.gz', 'dwi_masked.nii.gz', roi_codes=[5, 6], invert=True)
    ['dwi_masked.nii.gz']
    >>> # Mask numpy array
    >>> masked_data = parc.mask_image(img_array, roi_codes=[10, 20])
    """
    type_error_msg = (
        "image_2mask must be a file path, Path object, list of paths, or numpy array"
    )

    # ---- validate / normalise the input images --------------------------
    if isinstance(image_2mask, (str, Path)):
        image_2mask = [image_2mask]

    if isinstance(image_2mask, list):
        if len(image_2mask) == 0:
            raise ValueError("image_2mask list is empty; provide at least one image path")
        if not all(isinstance(p, (str, Path)) for p in image_2mask):
            raise ValueError(type_error_msg)
        is_file_input = True
    elif isinstance(image_2mask, np.ndarray):
        is_file_input = False
    else:
        raise ValueError(type_error_msg)

    # Handle masked_image paths
    if is_file_input:
        if masked_image is None:
            raise ValueError(
                "masked_image output path(s) required when image_2mask is file path(s)"
            )
        if isinstance(masked_image, (str, Path)):
            masked_image = [masked_image]
        image_2mask = [str(p) for p in image_2mask]
        masked_image = [str(p) for p in masked_image]
        if len(masked_image) != len(image_2mask):
            raise ValueError(
                f"Number of output paths ({len(masked_image)}) must match "
                f"number of input images ({len(image_2mask)})"
            )

    # ---- determine which region codes define the mask --------------------
    if roi_codes is not None and roi_names is not None:
        raise ValueError(
            "Cannot specify both roi_codes and roi_names. Please choose one."
        )

    if roi_codes is not None:
        if isinstance(roi_codes, str):
            roi_codes = [roi_codes]
        codes_to_use = np.array(cltmisc.build_indices(roi_codes))
    elif roi_names is not None:
        if isinstance(roi_names, str):
            roi_names = [roi_names]
        if not hasattr(self, "name") or not hasattr(self, "index"):
            raise ValueError(
                "Parcellation must have 'name' and 'index' attributes to use roi_names"
            )
        indexes = cltmisc.get_indexes_by_substring(
            input_list=self.name, or_filter=roi_names, invert=False, bool_case=False
        )
        if len(indexes) == 0:
            raise ValueError(f"No regions found matching names: {roi_names}")
        codes_to_use = np.array([self.index[i] for i in indexes])
    else:
        # Use all non-zero codes
        codes_to_use = np.unique(self.data)
        codes_to_use = codes_to_use[codes_to_use != 0]

    # 3D boolean mask of voxels to zero out. Boolean indexing with it on an
    # N-D image (N >= 3) clears those voxels across all trailing dimensions.
    in_selection = np.isin(self.data, codes_to_use)
    voxels_to_zero = in_selection if invert else ~in_selection

    # ---- array input -----------------------------------------------------
    if not is_file_input:
        if image_2mask.shape[:3] != self.data.shape:
            raise ValueError(
                f"Image shape {image_2mask.shape[:3]} doesn't match "
                f"parcellation shape {self.data.shape}"
            )
        img_data = image_2mask.copy()  # never modify the caller's array
        img_data[voxels_to_zero] = 0
        return img_data

    # ---- file input --------------------------------------------------------
    output_paths = []
    for img_path, out_path in zip(image_2mask, masked_image):
        if not os.path.exists(img_path):
            raise ValueError(f"Image file does not exist: {img_path}")

        temp_img = nib.load(img_path)
        img_data = temp_img.get_fdata()

        if img_data.shape[:3] != self.data.shape:
            raise ValueError(
                f"Image shape {img_data.shape[:3]} doesn't match "
                f"parcellation shape {self.data.shape}"
            )

        img_data[voxels_to_zero] = 0

        out_img = nib.Nifti1Image(img_data, temp_img.affine, temp_img.header)
        nib.save(out_img, out_path)
        output_paths.append(out_path)

    return output_paths
#####################################################################################################
[docs]
def compute_region_adjacency(
    self,
    roi_codes: Union[List[int], np.ndarray] = None,
    roi_names: Union[List[str], str] = None,
    rearrange: bool = False,
) -> Tuple[np.ndarray, dict, dict]:
    """
    Computes the region adjacency (neighbor) matrix for the parcellation.

    Each region's mask is dilated once with
    ``MorphologicalOperations().dilate(..., iterations=1)``. Any other labelled
    region found in the dilated rim counts as a neighbour (the structuring
    element / connectivity is whatever that helper uses). The computation runs
    on a deep copy, so the object itself is not modified.

    Parameters
    ----------
    roi_codes : list or np.ndarray, optional
        Specific region codes to include. Default is None (all regions).
    roi_names : list or str, optional
        Specific region names to include. Default is None.
    rearrange : bool, optional
        Whether to rearrange the parcellation labels before computing
        adjacency. It is applied through the filtering step when roi_codes or
        roi_names is given, and via ``rearrange()`` otherwise. Codes in the
        output then refer to the rearranged labels. Default is False.

    Returns
    -------
    neighb_matrix : np.ndarray
        Symmetric binary matrix (n_regions x n_regions). Rows and columns
        follow the order of the (filtered) ``index`` attribute.
    source : dict
        Lists ``"idxs"`` (matrix row), ``"codes"`` and ``"names"``, one entry
        per unordered neighbouring pair. Each pair stores its smaller code as
        the source, and pairs are sorted by (source code, target code).
    target : dict
        Same layout as ``source`` for the other member of each pair.

    Raises
    ------
    ValueError
        If both roi_codes and roi_names are specified.

    Notes
    -----
    Labels present in the data but missing from ``index`` (no lookup-table
    entry) have no matrix row, so they are ignored as neighbours.
    """
    from .imagetools import MorphologicalOperations

    if roi_codes is not None and roi_names is not None:
        raise ValueError(
            "Cannot specify both roi_codes and roi_names. Please choose one."
        )

    # Work on a copy to avoid modifying the original
    temp_parc = copy.deepcopy(self)
    if roi_codes is not None:
        temp_parc.keep_by_code(codes2keep=roi_codes, rearrange=rearrange)
    elif roi_names is not None:
        temp_parc.keep_by_name(names2keep=roi_names, rearrange=rearrange)
    elif rearrange:
        temp_parc.rearrange()

    data = temp_parc.data
    region_codes = [int(code) for code in temp_parc.index]
    region_names = temp_parc.name
    n_rois = len(region_codes)

    # code -> matrix row; the first occurrence wins on duplicate codes
    code_to_row = {}
    for row, code in enumerate(region_codes):
        code_to_row.setdefault(code, row)

    morph = MorphologicalOperations()  # loop-invariant
    neighbour_pairs = set()
    for code in region_codes:
        region_mask = data == code
        dilated_mask = morph.dilate(region_mask.astype(int), iterations=1)
        # Rim = dilated area minus the region itself
        boundary = (dilated_mask == 1) & ~region_mask

        for neighbour in np.unique(data[boundary]):
            neighbour = int(neighbour)
            # Skip background, self and labels without a LUT entry
            if neighbour == 0 or neighbour == code or neighbour not in code_to_row:
                continue
            neighbour_pairs.add((min(code, neighbour), max(code, neighbour)))

    neighb_matrix = np.zeros((n_rois, n_rois), dtype=int)
    source = {"idxs": [], "codes": [], "names": []}
    target = {"idxs": [], "codes": [], "names": []}

    for code1, code2 in sorted(neighbour_pairs):
        row1 = code_to_row[code1]
        row2 = code_to_row[code2]

        neighb_matrix[row1, row2] = 1
        neighb_matrix[row2, row1] = 1

        # Each unordered pair is stored once
        source["idxs"].append(row1)
        source["codes"].append(code1)
        source["names"].append(region_names[row1])
        target["idxs"].append(row2)
        target["codes"].append(code2)
        target["names"].append(region_names[row2])

    return neighb_matrix, source, target
######################################################################################################
[docs]
def compute_centroids(
    self,
    roi_codes: Union[List[int], np.ndarray] = None,
    roi_names: Union[List[str], str] = None,
    gaussian_smooth: bool = True,
    sigma: float = 1.0,
    closing_iterations: int = 2,
    centroid_table: Union[str, Path, None] = None,
) -> pd.DataFrame:
    """
    Compute region centroids, voxel counts, and volumes.

    The computation runs on a deep copy, so filtering does not affect the
    object. The only side effect is the ``centroids`` attribute (see Notes).

    Parameters
    ----------
    roi_codes : list or np.ndarray, optional
        Specific region codes to include. Default is None (all regions).
    roi_names : list or str, optional
        Specific region names to include. Default is None.
    gaussian_smooth : bool, optional
        Whether to apply Gaussian smoothing before centroid calculation.
        Default is True.
    sigma : float, optional
        Standard deviation for Gaussian smoothing. Default is 1.0.
    closing_iterations : int, optional
        Number of morphological closing iterations. Default is 2.
    centroid_table : str or Path, optional
        Path to save the results as a tab-separated file. Default is None.

    Returns
    -------
    pd.DataFrame
        One row per region with columns: index, name, color, x_vox, y_vox,
        z_vox, x_mm, y_mm, z_mm, nvoxels, volume. ``volume`` is the voxel
        count times the voxel volume from ``cltimg.get_voxel_volume``.

    Raises
    ------
    ValueError
        If both roi_codes and roi_names are specified.

    Notes
    -----
    Sets ``self.centroids`` to an (n_regions x 3) float array of centroids in
    mm (voxel centroids mapped through the affine), for the selected regions
    only. If the table cannot be written (missing directory or write error),
    a ``UserWarning`` is issued and the DataFrame is still returned.

    Examples
    --------
    >>> # Compute all centroids
    >>> centroids_df = parc.compute_centroids()
    >>>
    >>> # Specific regions with file output
    >>> df = parc.compute_centroids(
    ...     roi_codes=[1, 2, 3],
    ...     centroid_table='centroids.tsv'
    ... )
    >>> # Specific regions by name
    >>> df = parc.compute_centroids(
    ...     roi_names=['hippocampus', 'amygdala'],
    ...     centroid_table='centroids.tsv'
    ... )
    """
    import warnings

    if roi_codes is not None and roi_names is not None:
        raise ValueError(
            "Cannot specify both roi_codes and roi_names. Please choose one."
        )

    # Work on a copy to avoid modifying the original
    temp_parc = copy.deepcopy(self)
    if roi_codes is not None:
        temp_parc.keep_by_code(codes2keep=roi_codes)
    if roi_names is not None:
        temp_parc.keep_by_name(names2keep=roi_names)

    voxel_volume = cltimg.get_voxel_volume(temp_parc.affine)

    codes, names, colors = [], [], []
    x_coords_vox, y_coords_vox, z_coords_vox = [], [], []
    num_voxels, volumes = [], []

    # Positions in temp_parc.index align with temp_parc.name / temp_parc.color
    for i, region_code in enumerate(np.array(temp_parc.index)):
        centroid_vox, voxel_count = cltimg.extract_centroid_from_volume(
            temp_parc.data == region_code,
            gaussian_smooth=gaussian_smooth,
            sigma=sigma,
            closing_iterations=closing_iterations,
        )

        codes.append(int(region_code))
        names.append(temp_parc.name[i])
        colors.append(temp_parc.color[i])
        x_coords_vox.append(centroid_vox[0])
        y_coords_vox.append(centroid_vox[1])
        z_coords_vox.append(centroid_vox[2])
        num_voxels.append(voxel_count)
        volumes.append(voxel_count * voxel_volume)

    # (n_regions, 3) voxel coordinates -> mm
    coords_vox = np.stack(
        (np.array(x_coords_vox), np.array(y_coords_vox), np.array(z_coords_vox)),
        axis=-1,
    )
    coords_mm = cltimg.vox2mm(coords_vox, temp_parc.affine)
    self.centroids = coords_mm.astype(float)

    df = pd.DataFrame(
        {
            "index": codes,
            "name": names,
            "color": colors,
            "x_vox": x_coords_vox,
            "y_vox": y_coords_vox,
            "z_vox": z_coords_vox,
            "x_mm": coords_mm[:, 0].tolist(),
            "y_mm": coords_mm[:, 1].tolist(),
            "z_mm": coords_mm[:, 2].tolist(),
            "nvoxels": num_voxels,
            "volume": volumes,
        }
    )

    if centroid_table is not None:
        centroid_table = str(centroid_table)
        directory = os.path.dirname(centroid_table)
        if directory and not os.path.exists(directory):
            # Best-effort save: report through the same channel as write errors
            warnings.warn(
                f"Directory does not exist: {directory}. Centroid table not saved.",
                UserWarning,
            )
            return df
        try:
            df.to_csv(centroid_table, sep="\t", index=False)
            print(f"Centroid table saved to: {centroid_table}")
        except Exception as e:  # best-effort write; the caller still gets df
            warnings.warn(f"Failed to save centroid table: {e}", UserWarning)

    return df
######################################################################################################
[docs]
def get_regionwise_timeseries(
    self,
    time_series_data: Union[str, Path, np.ndarray],
    vols_to_delete: Union[List[int], np.ndarray] = None,
    method: str = "nilearn",
    metric: str = "mean",
    roi_codes: Union[List[int], np.ndarray] = None,
    roi_names: Union[List[str], str] = None,
) -> "RegionTimeSeries":
    """
    Extract one time series per region from a 4D image.

    Parameters
    ----------
    time_series_data : str, Path or np.ndarray
        Path to a 4D NIfTI file, or a 4D array shaped
        (dimx, dimy, dimz, timepoints).
    vols_to_delete : list of int or np.ndarray, optional
        Volumes to discard before extraction. For file input with the
        "nilearn" method they are removed into a temporary NIfTI file, which is
        deleted before returning. Otherwise they are removed from the loaded
        array. Default is None.
    method : str, optional
        Case-insensitive. Default is "nilearn".
        - "nilearn": uses ``nilearn.maskers.NiftiLabelsMasker`` with
          ``standardize="zscore_sample"``, so the series are z-scored and
          ``metric`` is not used. A file path is required; array input falls
          back to the voxel-wise method with a printed message.
        - any other value (e.g. "clabtoolkit"): computes ``metric`` over each
          region's voxels with ``cltimg.compute_statistics_at_nonzero_voxels``.
    metric : str, optional
        Statistic for the voxel-wise method. Default is "mean".
    roi_codes : list or np.ndarray, optional
        Specific region codes to include. Default is None (all regions).
    roi_names : list or str, optional
        Specific region names to include. Default is None. If both roi_codes
        and roi_names are given, roi_codes is ignored and a message is printed.

    Returns
    -------
    RegionTimeSeries
        Built from the (n_regions x timepoints) matrix, the method actually
        used, and the selected regions' names and colors.

    Raises
    ------
    ValueError
        If the input is neither a path nor an array, the file does not exist,
        or (voxel-wise method) the data are not 4D or their spatial shape does
        not match the parcellation.
    ImportError
        If the nilearn method is requested but nilearn is not installed.

    Examples
    --------
    >>> # Compute region-wise time series from file
    >>> region_ts = parc.get_regionwise_timeseries('timeseries.nii.gz')
    >>>
    >>> # Compute from numpy array
    >>> region_ts = parc.get_regionwise_timeseries(time_series_data=np.random.rand(64, 64, 64, 100))
    >>>
    >>> # Compute with specific regions using codes
    >>> region_ts = parc.get_regionwise_timeseries(
    ...     time_series_data='timeseries.nii.gz',
    ...     roi_codes=[1, 2, 3])
    """
    import tempfile

    type_error_msg = "time_series_data must be a string (file path) or a numpy array"
    method = method.lower()

    if roi_codes is not None and roi_names is not None:
        roi_codes = None
        print(
            "Both roi_codes and roi_names were specified. Ignoring roi_codes and using roi_names for region selection."
        )

    # Path objects are treated exactly like string paths
    if isinstance(time_series_data, Path):
        time_series_data = str(time_series_data)

    temp_parc = copy.deepcopy(self)
    if roi_codes is not None:
        temp_parc.keep_by_code(codes2keep=roi_codes)
    if roi_names is not None:
        temp_parc.keep_by_name(names2keep=roi_names)

    tmp_dir = tempfile.gettempdir()
    tmp_files = []  # temporary files created below; removed in `finally`
    try:
        # ---- 1) optional volume removal ------------------------------------
        ts_input = time_series_data
        if vols_to_delete is not None:
            if isinstance(time_series_data, str):
                if method == "nilearn":
                    # nilearn needs a file, so write the trimmed 4D image to disk
                    tmp_image = cltmisc.create_temporary_filename(
                        prefix="temp_timeseries",
                        extension=".nii.gz",
                        tmp_dir=tmp_dir,
                    )
                    tmp_files.append(tmp_image)
                    cltimg.delete_volumes_from_4D_images(
                        in_image=time_series_data,
                        out_image=tmp_image,
                        vols_to_delete=vols_to_delete,
                    )
                    ts_input = tmp_image
                else:
                    img = nib.load(time_series_data)
                    ts_input, _ = cltimg.delete_volumes_from_4D_array(
                        in_array=img.get_fdata(), vols_to_delete=vols_to_delete
                    )
            elif isinstance(time_series_data, np.ndarray):
                ts_input, _ = cltimg.delete_volumes_from_4D_array(
                    in_array=time_series_data, vols_to_delete=vols_to_delete
                )

        # ---- 2) nilearn cannot take arrays: fall back ----------------------
        if method == "nilearn" and isinstance(ts_input, np.ndarray):
            print(
                "Using nilearn method requires a file path. Please provide a valid file path."
            )
            print(
                "Computing region-wise timeseries without using nilearn. This may take longer."
            )
            method = "clabtoolkit"

        # ---- 3a) nilearn extraction ----------------------------------------
        if method == "nilearn":
            if not isinstance(ts_input, str):
                raise ValueError(type_error_msg)
            try:
                from nilearn.maskers import NiftiLabelsMasker
            except ImportError as err:
                raise ImportError(
                    "nilearn is not installed. Please install it to use this method."
                ) from err

            # Write the (filtered) parcellation and a nilearn-style LUT to disk
            tmp_parc_image = cltmisc.create_temporary_filename(
                prefix="temp_parcellation", extension=".nii.gz", tmp_dir=tmp_dir
            )
            tmp_basename = cltmisc.get_real_basename(tmp_parc_image)
            tmp_parc_image_nilearnlut = os.path.join(
                tmp_dir, f"{tmp_basename}_nilearnlut.txt"
            )
            tmp_files.extend([tmp_parc_image, tmp_parc_image_nilearnlut])
            temp_parc.save_parcellation(
                out_file=tmp_parc_image,
                lut_file=tmp_parc_image_nilearnlut,
                lut_type="nilearn",
                force=True,
            )

            masker = NiftiLabelsMasker(
                labels_img=tmp_parc_image,
                lut=tmp_parc_image_nilearnlut,
                standardize="zscore_sample",
                standardize_confounds=True,
                memory="nilearn_cache",
                verbose=1,
            )
            # fit_transform gives (timepoints, regions); transpose to regions x time
            region_ts_matrix = masker.fit_transform(ts_input).T

        # ---- 3b) voxel-wise extraction -------------------------------------
        else:
            if isinstance(ts_input, str):
                if not os.path.exists(ts_input):
                    raise ValueError("The time series file does not exist")
                time_series = nib.load(ts_input).get_fdata()
            elif isinstance(ts_input, np.ndarray):
                time_series = ts_input
            else:
                raise ValueError(type_error_msg)

            if time_series.ndim != 4:
                raise ValueError(
                    "Time series data must have 4 dimensions (dimx, dimy, dimz, timepoints)"
                )
            if time_series.shape[:3] != temp_parc.data.shape:
                raise ValueError(
                    "Time series dimensions do not match parcellation dimensions"
                )

            region_codes = np.array(temp_parc.index)
            region_ts_matrix = np.zeros((len(region_codes), time_series.shape[-1]))
            # Row i corresponds to temp_parc.index[i]
            for i, region_code in enumerate(region_codes):
                region_ts_matrix[i, :] = cltimg.compute_statistics_at_nonzero_voxels(
                    temp_parc.data == region_code, time_series, metric=metric
                )
    finally:
        for tmp_file in tmp_files:
            if os.path.exists(tmp_file):
                os.remove(tmp_file)

    return RegionTimeSeries(
        region_ts_matrix,
        method=method,
        region_names=temp_parc.name,
        region_colors=temp_parc.color,
    )
######################################################################################################
######################################################################################################
[docs]
def adjust_values(self):
    """
    Synchronize index, name, color and opacity attributes with data contents.

    Entries whose code no longer appears in ``self.data`` (background 0
    excluded) are removed from ``index`` and from any present ``name``,
    ``color`` and ``opacity`` lists, keeping their order. ``index`` entries are
    converted to Python ``int``. ``self.parc_range()`` is always called at the
    end. Objects without an ``index`` attribute are left as they are, apart
    from that final call.

    Raises
    ------
    ValueError
        If ``index``, ``name``, ``color`` and ``opacity`` all exist but have
        different lengths.

    Examples
    --------
    >>> parc.adjust_values()
    >>> print(f"Regions in data: {len(parc.index)}")
    """
    metadata_attrs = ("name", "color", "opacity")

    if hasattr(self, "index") and all(hasattr(self, a) for a in metadata_attrs):
        if not (
            len(self.index)
            == len(self.name)
            == len(self.color)
            == len(self.opacity)
        ):
            raise ValueError(
                "The length of index, name, color and opacity attributes must be the same."
            )

    # Only prune when there is an index to align the metadata with
    if hasattr(self, "index"):
        present_codes = np.unique(self.data)
        present_codes = present_codes[present_codes != 0]

        keep = np.isin(self.index, present_codes)
        kept_positions = np.flatnonzero(keep)

        self.index = [int(code) for code in np.array(self.index)[keep].tolist()]
        for attr in metadata_attrs:
            if hasattr(self, attr):
                values = getattr(self, attr)
                setattr(self, attr, [values[i] for i in kept_positions])

    self.parc_range()
######################################################################################################
[docs]
def group_by_codes(
    self, group_dict: dict, keep_ungrouped: bool = False
) -> Tuple[np.ndarray, dict]:
    """
    Merge groups of region codes into new labels and rebuild the color table.

    Every relabelling mask is computed against a snapshot of the labels taken
    before any replacement. A group's new ID that equals another group's old
    code is therefore never relabelled a second time.

    Parameters
    ----------
    group_dict : dict
        ``{new_id: {'index': [old_ids], 'name': str, 'color': str, 'opacity': float}}``.
        Index values can be integers, strings with ranges ("11:12", "50-52"),
        or mixed (expanded with ``cltmisc.build_indices``). Name, color and
        opacity are optional. Missing names/colors are generated
        automatically, and opacity defaults to 1.0.
    keep_ungrouped : bool, optional
        If False (default), codes not included in any group are removed
        (via :meth:`keep_by_code`) before relabelling. If True they are kept
        unchanged, with their original name, color and opacity.

    Returns
    -------
    tuple : (modified_array, color_table)
        modified_array : numpy.ndarray
            The relabelled label array (also stored in ``self.data``).
        color_table : dict
            Keys 'index', 'name', 'color', 'opacity', 'headerlines'. Grouped
            entries come first, in ``group_dict`` order, followed by any
            ungrouped codes still present. 'headerlines' is left empty.

    Notes
    -----
    An ungrouped code equal to one of the new IDs becomes part of that group.
    Ungrouped codes without an entry in the current ``index`` are given the
    name ``region_<code>``, color ``#ffffff`` and opacity 1.0. The object's
    index/name/color/opacity are replaced by the new table (colors harmonised
    to hex) and then synchronised with :meth:`adjust_values`.

    Examples
    --------
    >>> import numpy as np
    >>> array = np.random.randint(0, 60, (100, 100, 100))
    >>>
    >>> # Define groups with mixed specifications
    >>> group_dict = {
    ...     3: {'index': ["11:12", "50-52", 13]},  # Auto-generated name and color
    ...     4: {'index': [10, 49], 'name': 'Thalamus', 'color': '#33FF57', 'opacity': 0.8},
    ...     5: {'index': [17, 53, 18, 54], 'name': 'LimbicSystem', 'color': '#3357FF'},
    ...     6: {'index': [8, 47], 'name': 'Cerebellum', 'color': '#F1C40F', 'opacity': 0.8}
    ... }
    >>>
    >>> grouped_array, color_table = parc.group_by_codes(group_dict)
    >>> print(color_table['index'])  # [3, 4, 5, 6, ...ungrouped codes...]
    >>> print(color_table['name'])  # ['group_1', 'Thalamus', 'LimbicSystem', 'Cerebellum', ...original names...]
    """
    # Expand each group's member codes once
    group_members = [
        (new_id, params, cltmisc.build_indices(params["index"]))
        for new_id, params in group_dict.items()
    ]

    if not keep_ungrouped:
        # Drop every code that does not belong to any group
        all_old_ids = set()
        for _, _, old_ids in group_members:
            all_old_ids.update(old_ids)
        self.keep_by_code(codes2keep=list(all_old_ids))

    array = self.data
    # Snapshot used for every mask, so sequential relabelling cannot chain
    original_labels = array.copy()

    color_table = {
        "index": [],
        "name": [],
        "color": [],
        "opacity": [],
        "headerlines": [],
    }

    ngroups = len(group_members)
    def_color = cltcol.create_distinguishable_colors(ngroups, output_format="hex")
    def_names = cltmisc.create_names_from_indices(
        np.arange(1, ngroups + 1), prefix="group"
    )

    for i, (new_id, params, old_ids) in enumerate(group_members):
        array[np.isin(original_labels, old_ids)] = new_id

        color_table["index"].append(new_id)
        color_table["name"].append(params.get("name", def_names[i]))
        color_table["color"].append(params.get("color", def_color[i]))
        color_table["opacity"].append(params.get("opacity", 1.0))

    self.data = array

    # Append codes that are still present but belong to no group
    grouped_ids = set(color_table["index"])
    for code in np.unique(array):
        if code == 0 or code in grouped_ids:
            continue
        code = int(code)
        color_table["index"].append(code)
        try:
            pos = self.index.index(code)
            color_table["name"].append(self.name[pos])
            color_table["color"].append(self.color[pos])
            color_table["opacity"].append(self.opacity[pos])
        except ValueError:
            # Code not in the original color table: use defaults
            color_table["name"].append(f"region_{code}")
            color_table["color"].append("#ffffff")
            color_table["opacity"].append(1.0)

    self.index = color_table["index"]
    self.name = color_table["name"]
    self.color = cltcol.harmonize_colors(color_table["color"], output_format="hex")
    self.opacity = color_table["opacity"]

    # Drops entries without voxels; adjust_values() also calls parc_range()
    self.adjust_values()

    return array, color_table
######################################################################################################
[docs]
def group_by_names(
    self, group_dict: dict, keep_ungrouped: bool = True
) -> Tuple[np.ndarray, dict]:
    """
    Group array values and create color table for new groups using name-based dictionary.

    For every group, the region names listed under ``'names'`` are looked up in
    ``self.name`` with ``cltmisc.get_indexes_by_substring``. The codes of the
    matched regions are then relabelled to the group's ``'index'``. Structures
    not included in any group keep their original code and properties in both
    the array and the color table (unless ``keep_ungrouped`` is False).

    Parameters
    ----------
    group_dict : dict
        ``{group_name: {'names': [region_names], 'index': new_id,
        'color': (R,G,B) or hex str, 'opacity': float}}``.
        ``'names'`` (a list of names, or a single name string) and ``'index'``
        are required. ``'color'`` is optional; a distinguishable color is
        generated when missing. ``'opacity'`` is optional (default: 1.0).
    keep_ungrouped : bool, optional
        Whether to keep structures not included in any group. When False the
        parcellation is first restricted with ``keep_by_name`` to the names
        requested in ``group_dict``. Default is True.

    Returns
    -------
    tuple : (modified_array, color_table)
        modified_array : numpy.ndarray
            Array with grouped values replaced by new IDs. Ungrouped
            structures remain unchanged. This is also stored in ``self.data``.
        color_table : dict
            Color table with 'index', 'name', 'color', 'opacity', 'headerlines'
            keys. It holds the groups first, followed by the ungrouped
            structures with their original properties.

    Notes
    -----
    Region selection is evaluated against the labels present *before* any
    relabelling. A group's new index may therefore coincide with an original
    code claimed by a later group without those voxels being remapped twice.
    When two groups claim the same region, the first group wins.

    Examples
    --------
    >>> group_dict = {
    ...     'Thalamus': {"names": ["Thalamus"], "color": (0, 255, 0), "opacity": 0.8, "index": 1},
    ...     'Hippocampus': {"names": ["Hippocampus"], "color": "#0000ff", "index": 2},
    ... }
    >>> grouped_array, color_table = parc.group_by_names(group_dict)
    >>> print(color_table['index'])  # [1, 2, ...ungrouped codes...]
    >>> print(color_table['name'])   # ['Thalamus', 'Hippocampus', ...original names...]
    """

    def _as_name_list(names):
        # A bare string would otherwise be iterated character by character
        return [names] if isinstance(names, str) else list(names)

    if not keep_ungrouped:
        # Restrict the parcellation to the regions that are going to be grouped
        names_to_keep = set()
        for params in group_dict.values():
            names_to_keep.update(_as_name_list(params["names"]))
        self.keep_by_name(names2keep=list(names_to_keep))

    array = self.data
    # Snapshot of the labels before relabelling, so that a new group index
    # equal to an old code of a later group does not get remapped again.
    original = np.array(array, copy=True)
    already_grouped = np.zeros(original.shape, dtype=bool)

    # First position of every code in the current lookup table (O(1) lookups)
    code_to_pos = {}
    for pos, code in enumerate(self.index):
        code_to_pos.setdefault(code, pos)

    color_table = {
        "index": [],
        "name": [],
        "color": [],
        "opacity": [],
        "headerlines": [],
    }
    n_groups = len(group_dict)
    default_colors = cltcol.create_distinguishable_colors(
        n_groups, output_format="hex"
    )

    for grp_pos, (grp_name, params) in enumerate(group_dict.items()):
        new_id = params["index"]
        positions = cltmisc.get_indexes_by_substring(
            self.name, _as_name_list(params["names"])
        )
        old_ids = [self.index[p] for p in positions]

        # Match on the original labels; skip voxels already claimed by an
        # earlier group (first group wins on overlap).
        selected = np.isin(original, old_ids) & ~already_grouped
        array[selected] = new_id
        already_grouped |= selected

        # Accept RGB tuples or hex strings and store them as hex
        color = params.get("color", default_colors[grp_pos])
        color = cltcol.harmonize_colors([color], output_format="hex")[0]

        color_table["index"].append(new_id)
        color_table["name"].append(grp_name)
        color_table["color"].append(color)
        color_table["opacity"].append(params.get("opacity", 1.0))

    # Updating the parcellation data
    self.data = array

    # Append the codes that survived without being grouped
    grouped_ids = set(color_table["index"])
    present = np.unique(array)
    for code in present[present != 0]:
        if code in grouped_ids:
            continue
        pos = code_to_pos.get(code)
        color_table["index"].append(int(code))
        if pos is None:
            # Code not in original color table, use defaults
            color_table["name"].append(f"region_{code}")
            color_table["color"].append("#ffffff")
            color_table["opacity"].append(1.0)
        else:
            color_table["name"].append(self.name[pos])
            color_table["color"].append(self.color[pos])
            color_table["opacity"].append(self.opacity[pos])

    self.index = color_table["index"]
    self.name = color_table["name"]
    self.color = cltcol.harmonize_colors(color_table["color"], output_format="hex")
    self.opacity = color_table["opacity"]
    self.adjust_values()

    # Detect minimum and maximum labels
    self.parc_range()

    return array, color_table
######################################################################################################
[docs]
def rearrange(self, offset: int = 0):
    """
    Rearrange parcellation labels to consecutive integers.

    Regions are renumbered in the order in which they appear in ``self.index``.
    Only regions whose code is present in the data are kept. They receive the
    gap-free labels ``offset + 1, offset + 2, ...``. Voxels whose code has no
    entry in ``self.index`` are set to 0 in the rearranged data.

    Parameters
    ----------
    offset : int, optional
        Value added to every new label; the first kept region gets
        ``offset + 1``. Default is 0 (labels start at 1).

    Examples
    --------
    >>> # Rearrange to 1, 2, 3, ...
    >>> parc.rearrange()
    >>>
    >>> # Start from 100
    >>> parc.rearrange(offset=99)
    """
    # First, adjust values to ensure index, name, color align with data
    self.adjust_values()

    # Codes present in the data (background 0 excluded)
    present = np.unique(self.data)
    present = set(present[present != 0].tolist())

    has_opacity = hasattr(self, "opacity")
    new_parc = np.zeros_like(self.data)
    new_index, new_name, new_color, new_opacity = [], [], [], []

    for pos, code in enumerate(self.index):
        if code not in present:
            continue
        # Count only the regions that are kept, so labels stay consecutive even
        # when some lookup-table entries are absent from the data.
        new_label = len(new_index) + 1 + offset
        new_parc[self.data == code] = new_label
        new_index.append(new_label)
        new_name.append(self.name[pos])
        new_color.append(self.color[pos])
        if has_opacity:
            new_opacity.append(self.opacity[pos])

    # Update data and lookup table with rearranged labels
    self.data = new_parc
    self.index = new_index
    self.name = new_name
    self.color = new_color
    if has_opacity:
        self.opacity = new_opacity

    # Update parcellation range
    self.parc_range()
######################################################################################################
[docs]
def harmonize(self):
    """
    Harmonize parcellation attributes with data contents.

    Ensures index, name, color and opacity attributes align in type and length
    with the actual data values. Lookup-table entries whose code does not occur
    in the data (background 0 excluded) are removed, ``index`` becomes a list of
    Python ints, ``opacity`` becomes a list of floats (a scalar opacity is
    broadcast to every region), colors are passed through
    ``cltcol.harmonize_colors`` and the min/max label range is updated.

    Examples
    --------
    >>> parc.harmonize()
    >>> print(f"Regions in data: {len(parc.index)}")
    """
    # Get unique codes present in data (excluding background/0)
    present = np.unique(self.data)
    present = present[present != 0]

    # Normalise opacity to a per-region list *before* filtering, so scalar or
    # array opacities can be filtered by position like the other attributes.
    if hasattr(self, "opacity"):
        opacity = self.opacity
        if isinstance(opacity, np.ndarray):
            opacity = opacity.tolist()  # 0-d arrays become Python scalars
        if not hasattr(opacity, "__len__"):
            # Scalar: broadcast to all regions of the (unfiltered) lookup table
            n_entries = len(self.index) if hasattr(self, "index") else 1
            opacity = [opacity] * n_entries
        self.opacity = [float(x) for x in opacity]

    # Keep only lookup-table entries whose code is present in the data
    if hasattr(self, "index"):
        keep = np.isin(self.index, present)
        kept_pos = np.where(keep)[0]
        self.index = [int(x) for x in np.asarray(self.index)[keep].tolist()]

        if hasattr(self, "name"):
            self.name = [self.name[i] for i in kept_pos]
        if hasattr(self, "color"):
            self.color = [self.color[i] for i in kept_pos]
        if hasattr(self, "opacity"):
            self.opacity = [self.opacity[i] for i in kept_pos]

    # Harmonize colors to consistent format (after filtering)
    if hasattr(self, "color"):
        self.color = cltcol.harmonize_colors(self.color)

    # Update parcellation range
    self.parc_range()
######################################################################################################
[docs]
def add_parcellation(self, parc2add, append: bool = False):
    """
    Combine another parcellation into current object.

    Each parcellation is deep-copied and harmonized, then its non-zero voxels
    are written over this object's data. Lookup tables (index/name/color and
    opacity) are concatenated when both sides have them.

    Parameters
    ----------
    parc2add : Parcellation or list
        Parcellation object(s) to add.
    append : bool, optional
        If True, the labels of each added parcellation are offset by the
        current maximum label, so they follow the existing ones. This offset is
        refreshed after every parcellation, so several appended parcellations
        do not collide with each other. If False, labels are overlaid
        directly. Default is False.

    Raises
    ------
    TypeError
        If ``parc2add`` is not a Parcellation or a list of Parcellation objects.
    ValueError
        If ``parc2add`` is an empty list.

    Examples
    --------
    >>> # Overlay parcellations
    >>> parc1.add_parcellation(parc2, append=False)
    >>>
    >>> # Append with new labels
    >>> parc1.add_parcellation(parc2, append=True)
    """
    # Harmonize current parcellation (also sets minlab/maxlab)
    self.harmonize()

    # Convert single parcellation to list
    if isinstance(parc2add, Parcellation):
        parc2add = [parc2add]

    if not isinstance(parc2add, list):
        raise TypeError(
            "parc2add must be a Parcellation object or list of Parcellation objects"
        )

    if len(parc2add) == 0:
        raise ValueError("The parcellation list is empty")

    lut_attrs = ("index", "name", "color")

    for parc in parc2add:
        if not isinstance(parc, Parcellation):
            raise TypeError("All elements must be Parcellation objects")

        # Deep copy and harmonize so the caller's object is left untouched
        tmp_parc = copy.deepcopy(parc)
        tmp_parc.harmonize()

        # Get non-zero voxels
        ind = np.where(tmp_parc.data != 0)

        # Adjust labels if appending
        if append:
            tmp_parc.data[ind] = tmp_parc.data[ind] + self.maxlab
            if hasattr(tmp_parc, "index"):
                tmp_parc.index = [int(x + self.maxlab) for x in tmp_parc.index]

        has_lut = all(hasattr(tmp_parc, attr) for attr in lut_attrs)
        self_has_lut = all(hasattr(self, attr) for attr in lut_attrs)

        if has_lut and self_has_lut:
            # Remember the sizes before concatenating: default opacities must
            # cover exactly the entries of the side that lacked them.
            n_self = len(self.index)
            n_new = len(tmp_parc.index)

            # After harmonize(), all attributes are lists, so simple concatenation
            self.index = self.index + tmp_parc.index
            self.name = self.name + tmp_parc.name
            self.color = self.color + tmp_parc.color

            # Handle opacity if present in either parcellation
            if hasattr(tmp_parc, "opacity"):
                if hasattr(self, "opacity"):
                    self.opacity = self.opacity + tmp_parc.opacity
                else:
                    # Default opacity for the labels that were already here
                    self.opacity = [1.0] * n_self + tmp_parc.opacity
            elif hasattr(self, "opacity"):
                # Default opacity for the labels just added
                self.opacity = self.opacity + [1.0] * n_new

        elif has_lut and np.sum(self.data) == 0:
            # Self is empty, copy all attributes from tmp_parc
            self.index = tmp_parc.index
            self.name = tmp_parc.name
            self.color = tmp_parc.color
            if hasattr(tmp_parc, "opacity"):
                self.opacity = tmp_parc.opacity

        # Update parcellation data
        self.data[ind] = tmp_parc.data[ind]

        # Refresh minlab/maxlab so the next appended parcellation is offset
        # past the labels that were just added.
        self.parc_range()

    # Final harmonization
    self.harmonize()

    # Detect minimum and maximum labels
    self.parc_range()
######################################################################################################
[docs]
def save_parcellation(
    self,
    out_file: Union[str, Path],
    affine: Optional[np.ndarray] = None,
    headerlines: Optional[Union[list, str]] = None,
    lut_file: Union[str, Path, List[str], List[Path]] = None,
    lut_type: Union[str, List[str]] = "lut",
    force: bool = True,
):
    """
    Save parcellation to NIfTI file with optional lookup tables.

    All lookup-table arguments are validated *before* anything is written, so
    an invalid ``lut_type`` / ``lut_file`` combination never leaves a NIfTI
    file behind without its lookup table.

    Parameters
    ----------
    out_file : str or Path
        Output file path.
    affine : np.ndarray, optional
        Affine transformation matrix. If None, uses object's affine.
    headerlines : list, str, or None, optional
        Header lines for LUT format. If None (default), uses the object's
        headerlines (or none if the object has no ``headerlines`` attribute).
    lut_file : str, Path, list of str/Path, or None, optional
        Path(s) for lookup table file(s). If None, paths are auto-generated
        from out_file using the appropriate extension for each lut_type.
        If a list, must match the length of lut_type.
    lut_type : str, list of str, or None, optional
        Lookup table format(s): 'lut', 'tsv', 'fsl', or 'nilearn'.
        Can be a list to export multiple formats simultaneously,
        e.g. ['lut', 'tsv']. If None, only the NIfTI file is written.
        Default is 'lut'.
    force : bool, optional
        Whether to overwrite existing files. Default is True.

    Raises
    ------
    FileExistsError
        If ``out_file`` exists and ``force`` is False.
    ValueError
        If lut_file is a list whose length does not match lut_type,
        if a single lut_file is given for several lut_types,
        or if an unrecognised lut_type is given.
    TypeError
        If lut_file is not a str, Path, list or None.

    Examples
    --------
    >>> # Save with a single LUT format
    >>> parc.save_parcellation('output.nii.gz', lut_type='tsv')
    >>> # Save with multiple LUT formats (auto-generated paths)
    >>> parc.save_parcellation('output.nii.gz', lut_type=['lut', 'tsv'])
    >>> # Save with multiple LUT formats and explicit paths
    >>> parc.save_parcellation(
    ...     'output.nii.gz',
    ...     lut_type=['lut', 'tsv'],
    ...     lut_file=['custom.lut', 'custom.tsv']
    ... )
    """
    # Mapping from lut_type to file extension
    _EXT_MAP = {
        "lut": ".lut",
        "tsv": ".tsv",
        "fsl": ".fsllut",
        "nilearn": ".nilearnlut",
    }

    # Handle affine
    if affine is None:
        affine = self.affine

    # Handle headerlines (copy so the object's own list is never modified)
    if headerlines is None:
        headerlines = list(getattr(self, "headerlines", None) or [])
    elif isinstance(headerlines, str):
        headerlines = [headerlines]
    else:
        headerlines = list(headerlines)

    # Normalise out_file to str
    out_file = str(out_file)

    if not force and os.path.exists(out_file):
        raise FileExistsError(
            f"File {out_file} already exists. Set force=True to overwrite."
        )

    # --- Resolve and validate the lookup-table jobs before writing anything ---
    export_jobs = []
    if lut_type is not None:
        lut_types = [lut_type] if isinstance(lut_type, str) else list(lut_type)

        for lt in lut_types:
            if lt.lower() not in _EXT_MAP:
                raise ValueError(
                    f"Unrecognised lut_type '{lt}'. Must be one of: {list(_EXT_MAP.keys())}"
                )

        if lut_file is None:
            # Auto-generate one path per format
            base_name = cltmisc.get_real_basename(os.path.basename(out_file))
            out_dir = os.path.dirname(out_file)
            lut_files = [
                os.path.join(out_dir, base_name + _EXT_MAP[lt.lower()])
                for lt in lut_types
            ]
        elif isinstance(lut_file, (str, Path)):
            # Single explicit path — only valid when a single format is requested
            if len(lut_types) > 1:
                raise ValueError(
                    f"A single lut_file was provided but lut_type contains "
                    f"{len(lut_types)} formats. Provide a list of paths or set "
                    f"lut_file=None to auto-generate them."
                )
            lut_files = [str(lut_file)]
        elif isinstance(lut_file, list):
            lut_files = [str(f) for f in lut_file]
            if len(lut_files) != len(lut_types):
                raise ValueError(
                    f"lut_file list length ({len(lut_files)}) must match "
                    f"lut_type list length ({len(lut_types)})."
                )
        else:
            raise TypeError(
                f"lut_file must be a str, Path, list, or None; got {type(lut_file)}."
            )

        export_jobs = list(zip(lut_files, lut_types))

    # Save NIfTI file with proper data type
    data_to_save = self.data.astype(np.int32)
    out_atlas = nib.Nifti1Image(data_to_save, affine)
    nib.save(out_atlas, out_file)

    # --- Export one colortable per (lut_file, lut_type) pair ---
    for lf, lt in export_jobs:
        self.export_colortable(
            out_file=lf, lut_type=lt.lower(), force=force, headerlines=headerlines
        )
######################################################################################################
[docs]
def load_colortable(self, lut_file: Union[str, Path, dict] = None):
    """
    Load lookup table to associate codes with names and colors.

    Sets ``index``, ``name``, ``color`` (hex), ``opacity``, ``headerlines`` and
    ``lut_file`` on the object, then calls ``adjust_values`` and
    ``parc_range``.

    Parameters
    ----------
    lut_file : str, Path or dict, optional
        Path to a LUT file, or a dictionary with at least 'index' and 'name'
        keys (optionally 'color', 'opacity', 'headerlines'). If None, the file
        ``$FREESURFER_HOME/FreeSurferColorLUT.txt`` is used. Default is None.

    Raises
    ------
    ValueError
        If ``lut_file`` is None and ``FREESURFER_HOME`` is not set, if the LUT
        file does not exist, or if the table lacks 'index' or 'name'.
    TypeError
        If ``lut_file`` is not a str, Path, dict or None.

    Examples
    --------
    >>> # Load FreeSurfer LUT
    >>> parc.load_colortable('FreeSurferColorLUT.txt')
    >>>
    >>> # Load TSV table
    >>> parc.load_colortable('regions.tsv')
    """
    if lut_file is None:
        # Fall back to the LUT shipped with FreeSurfer
        freesurfer_home = os.getenv("FREESURFER_HOME")
        if not freesurfer_home:
            raise ValueError(
                "No lut_file was given and the FREESURFER_HOME environment "
                "variable is not set, so FreeSurferColorLUT.txt cannot be located"
            )
        lut_file = os.path.join(freesurfer_home, "FreeSurferColorLUT.txt")

    if isinstance(lut_file, (str, Path)):
        if not os.path.exists(lut_file):
            raise ValueError("The lut file does not exist")
        self.lut_file = lut_file
        col_dict = cltcol.ColorTableLoader.load_colortable(lut_file)
    elif isinstance(lut_file, dict):
        self.lut_file = None
        # Copy so later edits of the object do not alter the caller's dict
        col_dict = copy.deepcopy(lut_file)
    else:
        raise TypeError(
            f"lut_file must be a str, Path, dict or None; got {type(lut_file)}"
        )

    if "index" not in col_dict or "name" not in col_dict:
        raise ValueError("The dictionary must contain the keys 'index' and 'name'")
    self.index = col_dict["index"]
    self.name = col_dict["name"]

    if "color" in col_dict:
        self.color = cltcol.harmonize_colors(col_dict["color"], output_format="hex")
    else:
        self.color = cltcol.create_distinguishable_colors(
            len(self.index), output_format="hex"
        )

    if "opacity" in col_dict:
        self.opacity = col_dict["opacity"]
    else:
        self.opacity = [1.0] * len(self.index)

    if "headerlines" in col_dict:
        self.headerlines = col_dict["headerlines"]
    else:
        self.headerlines = []

    self.adjust_values()
    self.parc_range()
######################################################################################################
[docs]
def sort_index(self):
    """
    Sort index, name, color (and opacity, when present) by index values.

    ``index`` is put in ascending order and the other lookup-table attributes
    are reordered to match. Entries that share a code keep their relative
    order (stable sort). Objects without an ``opacity`` attribute are
    supported.

    Examples
    --------
    >>> parc.sort_index()
    >>> print(f"First region: {parc.name[0]} (code: {parc.index[0]})")
    """
    order = np.argsort(self.index, kind="stable")

    self.index = [self.index[i] for i in order]
    self.name = [self.name[i] for i in order]
    self.color = [self.color[i] for i in order]
    # Opacity is optional elsewhere in the class; only reorder it if present
    if hasattr(self, "opacity"):
        self.opacity = [self.opacity[i] for i in order]
######################################################################################################
[docs]
def export_colortable(
    self,
    out_file: str,
    lut_type: str = "lut",
    headerlines: Optional[Union[list, str]] = None,
    force: bool = True,
):
    """
    Export lookup table to file.

    Before writing, the lookup table stored on the object is pruned *in place*.
    Only regions whose code occurs in ``self.data`` (background 0 excluded)
    are kept, and ``index`` becomes a list of ints. The table is then written
    with ``cltcol.ColorTableLoader.export``.

    Parameters
    ----------
    out_file : str
        Output file path.
    lut_type : str, optional
        Output format passed to ``ColorTableLoader.export``, e.g. 'lut',
        'tsv', 'fsl' or 'nilearn'. Default is 'lut'.
    headerlines : list, str or None, optional
        Header lines for the table. If None or empty, the object's
        ``headerlines`` are used. If those are empty too, a header with the
        output path, the current date and (when it is an existing file) the
        parcellation file is generated. Default is None.
    force : bool, optional
        Whether to overwrite existing files. Default is True.

    Raises
    ------
    ValueError
        If the object has no index/name/color attributes, or if ``lut_type``
        is 'tsv' and the colors are strings not in hexadecimal format.

    Examples
    --------
    >>> # Export FreeSurfer LUT
    >>> parc.export_colortable('regions.lut', lut_type='lut')
    >>>
    >>> # Export TSV
    >>> parc.export_colortable('regions.tsv', lut_type='tsv')
    """
    # Normalise header lines to a fresh list (never mutate caller/object lists)
    if headerlines is None:
        headerlines = []
    elif isinstance(headerlines, str):
        headerlines = [headerlines]
    else:
        headerlines = list(headerlines)
    if len(headerlines) == 0:
        headerlines = list(getattr(self, "headerlines", None) or [])

    if (
        not hasattr(self, "index")
        or not hasattr(self, "name")
        or not hasattr(self, "color")
    ):
        raise ValueError(
            "The parcellation does not contain a color table. The index, name and color attributes must be present"
        )

    # Adjusting the colortable to the values in the parcellation
    present = np.unique(self.data)
    present = present[present != 0]
    keep = np.isin(self.index, present)
    kept_pos = np.where(keep)[0]

    # Keep index a list (not an ndarray) so list methods keep working elsewhere
    self.index = [int(x) for x in np.asarray(self.index)[keep].tolist()]
    self.name = [self.name[i] for i in kept_pos]
    self.color = [self.color[i] for i in kept_pos]
    if hasattr(self, "opacity"):
        self.opacity = [self.opacity[i] for i in kept_pos]

    # Default header: output path, timestamp and source parcellation file
    if len(headerlines) == 0:
        date_time = datetime.now().strftime("%m/%d/%Y, %H:%M:%S")
        headerlines = ["# $Id: {} {} \n".format(out_file, date_time)]
        parc_file = getattr(self, "parc_file", None)
        # parc_file may be None or an array when built from data in memory
        if isinstance(parc_file, (str, Path)) and os.path.isfile(parc_file):
            headerlines.append(
                "# Corresponding parcellation: {} \n".format(parc_file)
            )

    if lut_type == "tsv":
        colors = self.color
        if (
            isinstance(colors, list)
            and len(colors) > 0
            and isinstance(colors[0], str)
            and not colors[0].startswith("#")
        ):
            raise ValueError("The colors must be in hexadecimal format")

    col_dict = {
        "index": self.index,
        "name": self.name,
        "color": self.color,
        "opacity": getattr(self, "opacity", [1.0] * len(self.index)),
        "headerlines": headerlines,
    }
    col_obj = cltcol.ColorTableLoader(col_dict)
    col_obj.export(out_file, out_format=lut_type, overwrite=force)
######################################################################################################
[docs]
def replace_values(
    self,
    codes2rep: Union[List[Union[int, List[int]]], np.ndarray, Dict],
    new_codes: Union[int, List[int], np.ndarray] = None,
) -> None:
    """
    Replace region codes with new values, supporting group replacements.

    Parameters
    ----------
    codes2rep : list, np.ndarray or dict
        Codes to replace.
        - A flat list of ints or a 1-D array defines one group per code.
        - A list of lists defines one group per inner list; every code in a
          group is mapped to the same new code. Each group is expanded with
          ``cltmisc.build_indices``.
        - A dict maps old codes (keys) to new codes (values). Keys and values
          are expanded with ``cltmisc.build_indices`` and ``new_codes`` is
          ignored.
    new_codes : int, list or np.ndarray
        One new code per group. Required unless ``codes2rep`` is a dict.

    Raises
    ------
    AttributeError
        If the object has no ``data`` attribute.
    ValueError
        If ``codes2rep`` is empty or the number of new codes differs from the
        number of groups.
    TypeError
        If ``codes2rep`` has an unsupported type/shape, or ``new_codes`` is
        missing.

    Notes
    -----
    All groups are matched against the labels as they were before the call.
    A group's new code may therefore equal a code from another group without
    those voxels being remapped twice. The same holds for ``self.index``:
    each entry is renamed at most once. The caller's ``codes2rep`` is not
    modified.
    """
    # Input validation
    if not hasattr(self, "data"):
        raise AttributeError("Object must have 'data' attribute")

    # Handle Dictionary input
    if isinstance(codes2rep, dict):
        old_codes_list = cltmisc.build_indices(list(codes2rep.keys()))
        new_codes_list = cltmisc.build_indices(list(codes2rep.values()))
        codes2rep = copy.deepcopy(old_codes_list)
        new_codes = copy.deepcopy(new_codes_list)

    # Process codes2rep to determine structure and number of groups
    if isinstance(codes2rep, list):
        if len(codes2rep) == 0:
            raise ValueError("codes2rep cannot be empty")

        # Detect whether it's a flat list of ints or a list of lists
        if all(isinstance(x, (int, np.integer)) for x in codes2rep):
            groups = [[x] for x in codes2rep]
        elif all(isinstance(x, list) for x in codes2rep):
            groups = list(codes2rep)  # shallow copy: never mutate caller's list
        else:
            raise TypeError(
                "codes2rep must be a list of ints or a list of lists of ints"
            )
    elif isinstance(codes2rep, np.ndarray):
        if codes2rep.ndim != 1:
            raise TypeError("Unsupported numpy array shape for codes2rep")
        groups = [[int(x)] for x in codes2rep.tolist()]
    else:
        raise TypeError(
            f"codes2rep must be list or numpy array, got {type(codes2rep)}"
        )

    # Expand ranges/strings inside every group
    groups = [cltmisc.build_indices(group, nonzeros=False) for group in groups]
    n_groups = len(groups)

    # Process new_codes (handle single int, list, or array)
    if new_codes is None:
        raise TypeError("new_codes must be provided unless codes2rep is a dict")
    if isinstance(new_codes, (int, np.integer)):
        new_codes = np.array([new_codes], dtype=np.int32)
    elif isinstance(new_codes, list):
        new_codes = np.array(
            cltmisc.build_indices(new_codes, nonzeros=False), dtype=np.int32
        )
    else:
        new_codes = np.atleast_1d(np.array(new_codes, dtype=np.int32))

    # Validate matching lengths
    if len(new_codes) != n_groups:
        raise ValueError(
            f"Number of new codes ({len(new_codes)}) must equal "
            f"number of groups ({n_groups}) to be replaced"
        )

    # Match every group against a snapshot so replacements do not chain
    original = self.data.copy()

    has_index = hasattr(self, "index")
    if has_index:
        orig_index = list(self.index)
        new_index = list(orig_index)
        # First position of each code in the original lookup table
        first_pos = {}
        for pos, code in enumerate(orig_index):
            first_pos.setdefault(code, pos)

    for codes, new_code in zip(groups, new_codes):
        codes_to_replace = np.array(codes)
        self.data[np.isin(original, codes_to_replace)] = new_code

        # Update color table if present (name/color are left untouched)
        if has_index:
            for code in codes_to_replace:
                pos = first_pos.get(code)
                if pos is not None:
                    new_index[pos] = int(new_code)

    if has_index:
        self.index = new_index

    # Optional post-processing
    if hasattr(self, "index") and hasattr(self, "name") and hasattr(self, "color"):
        if hasattr(self, "adjust_values"):
            self.adjust_values()
    if hasattr(self, "parc_range"):
        self.parc_range()
######################################################################################################
[docs]
def parc_range(self) -> Tuple[int, int]:
    """
    Update minimum and maximum label values in parcellation.

    Sets the ``minlab`` and ``maxlab`` attributes from the non-zero values in
    ``self.data``. Both are 0 when the data contains only background.

    Returns
    -------
    tuple : (minlab, maxlab)
        minlab : int
            Minimum label value (excluding zero).
        maxlab : int
            Maximum label value.

    Examples
    --------
    >>> parc.parc_range()
    >>> print(f"Label range: {parc.minlab} - {parc.maxlab}")
    """
    # A direct masked min/max avoids the sort np.unique would perform
    nonzero_codes = self.data[self.data != 0]

    if nonzero_codes.size > 0:
        self.minlab = np.min(nonzero_codes)
        self.maxlab = np.max(nonzero_codes)
    else:
        self.minlab = 0
        self.maxlab = 0

    return self.minlab, self.maxlab
#######################################################################################################
[docs]
def compute_morphometry_table(
    self,
    output_table: Union[str, Path] = None,
    add_bids_entities: bool = False,
    map_files: Union[str, Path, list] = None,
    map_ids: Union[str, list] = None,
    units: Union[str, list] = "unknown",
    exclude_by_code: Union[list, np.ndarray] = None,
    exclude_by_name: Union[list, str] = None,
    include_by_code: Union[list, np.ndarray] = None,
    include_by_name: Union[list, str] = None,
    include_global: bool = True,
):
    """
    Compute morphometry table for all regions in parcellation.

    Always computes region volumes. For each existing map file it also
    computes regional values and concatenates them to the table. The result
    is stored in ``self.morphometry`` and returned.

    Parameters
    ----------
    output_table : str or Path, optional
        Path to save the output table as CSV. If None, does not save. The
        type and parent directory are checked before any computation.
        Default is None.
    add_bids_entities : bool, optional
        Whether to add BIDS entities to the output. Default is False.
    map_files : str, Path, or list of str/Path, optional
        Paths to additional map files for morphometry. If None, only base
        morphometry is computed. Missing files are skipped with a warning.
        Default is None.
    map_ids : str or list, optional
        IDs for the additional maps. If None (or mismatched in length), IDs are
        derived from the filenames. Default is None.
    units : str or list, optional
        Units for the additional maps. A single string is used for all maps.
        Default is "unknown".
    exclude_by_code : list or np.ndarray, optional
        List of region codes to exclude. If None, excludes none. Default is None.
    exclude_by_name : list or str, optional
        List of region names to exclude. If None, excludes none. Default is None.
    include_by_code : list or np.ndarray, optional
        List of region codes to include. If None, includes all. Default is None.
    include_by_name : list or str, optional
        List of region names to include. If None, includes all. Default is None.
    include_global : bool, optional
        Whether to include global morphometry metrics. Default is True.

    Returns
    -------
    pandas.DataFrame
        The concatenated morphometry table.

    Raises
    ------
    TypeError
        If output_table is not a string or Path, or if map_files/map_ids/units
        have incorrect types.
    FileNotFoundError
        If the output directory does not exist.

    Examples
    --------
    >>> parc.compute_morphometry_table(
    ...     output_table='morphometry.csv',
    ...     add_bids_entities=True,
    ...     map_files=['map1.nii.gz', 'map2.nii.gz'],
    ...     map_ids=['map1', 'map2'],
    ...     units=['mm^3', 'unknown']
    ... )
    >>> # Compute morphometry without additional maps
    >>> parc.compute_morphometry_table(output_table=Path('morphometry_base.csv'))
    """
    # Validate the output destination first, so a bad path fails before the
    # (potentially slow) morphometry computation rather than after it.
    if output_table is not None:
        if not isinstance(output_table, (str, Path)):
            raise TypeError(
                f"output_table must be a string or Path, got {type(output_table)}"
            )
        output_table = Path(output_table)
        if not output_table.parent.exists():
            raise FileNotFoundError(
                f"Output directory does not exist: {output_table.parent}"
            )

    from rich.markup import escape

    from . import morphometrytools as cltmorpho

    # Validated (existing) maps; empty when no maps are requested
    fin_maps = []
    fin_map_ids = []
    fin_units = []

    if map_files is not None:
        # Normalize map_files to list of strings
        if isinstance(map_files, (str, Path)):
            map_files = [str(map_files)]
        elif isinstance(map_files, list):
            if not all(isinstance(f, (str, Path)) for f in map_files):
                raise TypeError(
                    "All items in map_files must be strings or Path objects"
                )
            map_files = [str(f) for f in map_files]
        else:
            raise TypeError(
                f"map_files must be str, Path, or list, got {type(map_files)}"
            )
        n_maps = len(map_files)

        # Normalize map_ids to list
        if map_ids is None:
            map_ids = [cltmisc.get_real_basename(f) for f in map_files]
        elif isinstance(map_ids, str):
            if n_maps == 1:
                map_ids = [map_ids]
            else:
                # Single string provided for multiple maps - auto-generate instead
                print(
                    f"Warning: Single map_id provided for {n_maps} maps. Auto-generating IDs."
                )
                map_ids = [cltmisc.get_real_basename(f) for f in map_files]
        elif isinstance(map_ids, list):
            if not all(isinstance(mid, str) for mid in map_ids):
                raise TypeError("All items in map_ids must be strings")
            if len(map_ids) != n_maps:
                print(
                    f"Warning: Number of map_ids ({len(map_ids)}) doesn't match number of map_files ({n_maps}). Auto-generating IDs."
                )
                map_ids = [cltmisc.get_real_basename(f) for f in map_files]
        else:
            raise TypeError(f"map_ids must be str or list, got {type(map_ids)}")

        # Normalize units to list
        if units is None:
            units = ["unknown"] * n_maps
        elif isinstance(units, str):
            # A single unit applies to every map
            units = [units] * n_maps
        elif isinstance(units, list):
            if not all(isinstance(u, str) for u in units):
                raise TypeError("All items in units must be strings")
            if len(units) != n_maps:
                print(
                    f"Warning: Number of units ({len(units)}) doesn't match number of map_files ({n_maps}). Using 'unknown' for all."
                )
                units = ["unknown"] * n_maps
        else:
            raise TypeError(f"units must be str or list, got {type(units)}")

        # Filter out non-existent files
        for map_file, map_id, unit in zip(map_files, map_ids, units):
            if os.path.exists(map_file):
                fin_maps.append(map_file)
                fin_map_ids.append(map_id)
                fin_units.append(unit)
            else:
                print(f"Warning: Map file not found: {map_file}")

        if not fin_maps:
            print(
                "Warning: No valid map files found. Computing only base morphometry."
            )

    n_valid_maps = len(fin_maps)

    with Progress(
        SpinnerColumn(),
        TextColumn("[bold blue]{task.description}", justify="right"),
        BarColumn(bar_width=None),
        MofNCompleteColumn(),
        TextColumn("•"),
        TimeRemainingColumn(),
        expand=True,
    ) as progress:
        # Total steps: 1 base + N maps
        total_steps = 1 + n_valid_maps
        task = progress.add_task(
            "[bold green]Computing base morphometry: [bold green]volume[/bold green] ([yellow]cm³[/yellow])",
            total=total_steps,
        )

        # --- Step 1: volume table ---
        morphometry_table, *_ = cltmorpho.compute_reg_volume_fromparcellation(
            self,
            add_bids_entities=add_bids_entities,
            include_by_code=include_by_code,
            include_by_name=include_by_name,
            exclude_by_code=exclude_by_code,
            exclude_by_name=exclude_by_name,
            include_global=include_global,
        )
        progress.update(task, completed=1)

        # --- Step 2: additional maps ---
        for i, (map_file, map_id, unit) in enumerate(
            zip(fin_maps, fin_map_ids, fin_units), start=1
        ):
            # Escape user-provided text so brackets are not parsed as markup
            progress.update(
                task,
                description=f"[bold green]Processing map {i}/{n_valid_maps}[/bold green] • {escape(map_id)} ({escape(unit)})",
            )
            try:
                df, _, _ = cltmorpho.compute_reg_val_fromparcellation(
                    map_file,
                    self,
                    add_bids_entities=add_bids_entities,
                    metric=map_id,
                    units=unit,
                    include_by_code=include_by_code,
                    include_by_name=include_by_name,
                    exclude_by_code=exclude_by_code,
                    exclude_by_name=exclude_by_name,
                    include_global=include_global,
                )
                morphometry_table = pd.concat([morphometry_table, df], axis=0)
            except Exception as e:
                # Best effort: one failing map must not abort the others
                print(f"[WARNING] Map {map_id} failed with: {e}", flush=True)
            # Advance only after the map has actually been processed
            progress.update(task, advance=1)

        # Final update with clear summary
        final_msg = (
            f"[bold green]✓[/bold green] Completed: 1 base + {n_valid_maps} map(s)"
            if n_valid_maps > 0
            else "[bold green]✓[/bold green] Completed base morphometry"
        )
        progress.update(task, description=final_msg, completed=total_steps)

    self.morphometry = morphometry_table

    # Saving the morphometry table if output_table is provided
    if output_table is not None:
        morphometry_table.to_csv(output_table, index=False)
        print(f"Saved morphometry table to {output_table}")

    return morphometry_table
######################################################################################################
[docs]
def compute_volume_table(
    self,
    exclude_by_code: Union[list, np.ndarray] = None,
    exclude_by_name: Union[list, str] = None,
    include_by_code: Union[list, np.ndarray] = None,
    include_by_name: Union[list, str] = None,
    include_global: bool = True,
    output_table: Union[str, Path] = None,
):
    """
    Compute the regional volume table for this parcellation.

    Thin wrapper around
    ``morphometrytools.compute_reg_volume_fromparcellation``: this
    parcellation object and every argument below are forwarded unchanged,
    and the helper's result is returned as-is.

    Parameters
    ----------
    exclude_by_code : list or np.ndarray, optional
        Region codes to exclude from the analysis. If None, no regions are
        excluded by code. Useful for excluding regions such as ventricles or
        non-brain tissue.
    exclude_by_name : list or str, optional
        Region names to exclude from the analysis. If None, no regions are
        excluded by name. Example: ``["Ventricles", "White-Matter"]`` to
        focus on gray-matter regions.
    include_by_code : list or np.ndarray, optional
        Region codes to include in the analysis. If None, all regions are
        included.
    include_by_name : list or str, optional
        Region names to include in the analysis. If None, all regions are
        included. Example: ``["Cortex", "Hippocampus"]``.
    include_global : bool, default=True
        Whether to include the total (global) volume in the output table.
    output_table : str or Path, optional
        Forwarded as ``output_table`` to the helper. Presumably the path of a
        file the table is written to; None presumably means "do not save"
        (TODO confirm against ``compute_reg_volume_fromparcellation``).

    Returns
    -------
    volume_table
        Exactly what ``compute_reg_volume_fromparcellation`` returns.
        NOTE(review): earlier documentation unpacked this as a 2-tuple
        (``volume_df, _``); confirm the helper's return type.

    Notes
    -----
    This method itself does not assign any attribute on ``self``.
    NOTE(review): earlier documentation stated that a ``volumetable``
    attribute is set; confirm whether the helper does so.

    Examples
    --------
    >>> volume_table = parc.compute_volume_table(exclude_by_name=["Ventricles"])
    """
    # Local import, as in the sibling morphometry method, presumably to avoid
    # a circular import between this module and morphometrytools.
    from . import morphometrytools as cltmorpho

    volume_table = cltmorpho.compute_reg_volume_fromparcellation(
        self,
        exclude_by_code=exclude_by_code,
        exclude_by_name=exclude_by_name,
        include_by_code=include_by_code,
        include_by_name=include_by_name,
        include_global=include_global,
        output_table=output_table,
    )
    return volume_table
######################################################################################################
[docs]
def print_properties(self):
    """
    Print the attribute and method names of the parcellation object.

    Every name returned by ``dir(self)`` that does not start with a double
    underscore is printed, in ``dir`` (alphabetical) order, under one of two
    headings: ``Attributes:`` for names whose value is not callable and
    ``Methods:`` for names whose value is callable. Single-underscore
    (private) names are included.

    Examples
    --------
    >>> parc.print_properties()
    Attributes:
    affine
    data
    index
    ...

    Methods:
    keep_by_code
    save_parcellation
    ...
    """
    attributes = []
    methods = []
    # Single pass over dir(): each member is fetched (and, for properties,
    # evaluated) once instead of twice, and dunders are skipped up front.
    for name in dir(self):
        if name.startswith("__"):
            continue
        if callable(getattr(self, name)):
            methods.append(name)
        else:
            attributes.append(name)

    print("Attributes:")
    for name in attributes:
        print(name)
    print("\nMethods:")
    for name in methods:
        print(name)
#######################################################################################################
[docs]
def compute_fc_matrix(
    self,
    data: Union[str, Path, np.ndarray],
    method: str = "pearson",
    *,
    z_transform: bool = False,
    absolute: bool = False,
    threshold: Optional[float] = None,
    normalize_rows: bool = False,
    vols_to_delete: Union[str, list, np.ndarray] = None,
    ts_method: str = "nilearn",
    roi_codes: Union[List[int], np.ndarray] = None,
    roi_names: Union[List[str], str] = None,
) -> cltcon.Connectome:
    """Compute a functional connectivity (FC) matrix over this parcellation's regions.

    Each entry ``FC[i, j]`` reflects the pairwise association between the
    time series of ROI *i* and ROI *j*. The matrix is symmetric with shape
    ``(n_rois, n_rois)`` and is returned wrapped in a ``Connectome``.

    Before post-processing the diagonal is 1 for ``"pearson"``,
    ``"spearman"``, ``"kendall"`` and ``"partial"`` (Pearson/Spearman give
    NaN for a constant series). For ``"mutual_info"`` it is 1 for every
    non-constant series and 0 for constant ones.

    Parameters
    ----------
    data : str, Path, np.ndarray or RegionTimeSeries
        Input signal. Accepted forms:

        * a path (``str`` or ``Path``) to an existing file, from which
          region-wise time series are extracted with
          ``get_regionwise_timeseries``;
        * a 4-D array, also passed to ``get_regionwise_timeseries``;
        * a 2-D array of shape ``(n_rois, n_timepoints)`` whose rows must
          match, in number, the (optionally ROI-filtered) regions of this
          parcellation;
        * a ``RegionTimeSeries``: the computation is delegated to
          ``RegionTimeSeries.compute_fc_matrix`` and this parcellation's
          regions are not used.
    method : str, default ``"pearson"``
        Correlation / association method (case-insensitive):

        ``"pearson"``
            Standard Pearson *r* via ``np.corrcoef``.
        ``"spearman"``
            Spearman *rho*: Pearson correlation of the row-wise ranks.
        ``"kendall"``
            Kendall *tau* via ``scipy.stats.kendalltau``, computed pair by
            pair in a Python loop (slowest for many ROIs).
        ``"partial"``
            Partial correlation from the (pseudo-)inverse of the covariance
            matrix. Requires ``n_timepoints > n_rois``.
        ``"mutual_info"``
            Normalised mutual information, ``MI / sqrt(H_i * H_j)``, on
            quantile-binned series (scikit-learn required). Values in
            ``[0, 1]``.
    z_transform : bool, default False
        Apply Fisher's *r*-to-*z* transform ``arctanh(r)``. Values are first
        clipped to ``[-0.9999, 0.9999]``, so the diagonal becomes
        ``arctanh(0.9999)`` (about 4.95). Not applied for ``"mutual_info"``.
    absolute : bool, default False
        Return ``|FC|`` instead of signed values.
    threshold : float, optional
        After all other transforms, set every entry whose absolute value is
        below *threshold* to 0 (diagonal included).
    normalize_rows : bool, default False
        Z-score each ROI time series before computing the FC matrix.
    vols_to_delete : str, list, or np.ndarray, optional
        Volume indices to discard, normalised with ``cltmisc.build_indices``.
        Forwarded to ``get_regionwise_timeseries`` for path / 4-D input and
        applied directly to the columns of 2-D input.
    ts_method : str, default ``"nilearn"``
        Time-series extraction backend passed to ``get_regionwise_timeseries``.
    roi_codes : list of int or np.ndarray, optional
        Keep only these region codes (``keep_by_code`` on a copy of this
        parcellation). Ignored, with a printed notice, when *roi_names* is
        also given. Not supported for ``RegionTimeSeries`` input.
    roi_names : list of str or str, optional
        Keep only these region names (``keep_by_name`` on a copy of this
        parcellation). For ``RegionTimeSeries`` input it is forwarded as
        that method's ``roi_names``.

    Returns
    -------
    fc_connectome : cltcon.Connectome
        Connectome holding the FC matrix plus the names, indices and colours
        of the selected regions, tagged ``connectivity_type="functional"``.

    Raises
    ------
    ValueError
        Unsupported *method*; missing data file; array that is neither 2-D
        nor 4-D; out-of-range ``vols_to_delete``; ROI count that does not
        match the parcellation; too few time points for ``"partial"``; or
        *roi_codes* combined with ``RegionTimeSeries`` input.
    TypeError
        *data* is of an unsupported type.

    Examples
    --------
    >>> fc = parc.compute_fc_matrix("sub-01_task-rest_bold.nii.gz")
    >>> ts = np.random.randn(len(parc.index), 200)  # one row per region
    >>> fc_rho = parc.compute_fc_matrix(ts, method="spearman", z_transform=True)
    >>> fc_par = parc.compute_fc_matrix(ts, method="partial")
    """
    SUPPORTED = {"pearson", "spearman", "kendall", "partial", "mutual_info"}

    # Validate the method first so a typo fails before any (possibly
    # expensive) time-series extraction.
    method = method.lower().strip()
    if method not in SUPPORTED:
        raise ValueError(
            f"Unknown method '{method}'. Choose from: {sorted(SUPPORTED)}."
        )

    # ------------------------------------------------------------------
    # Pre-extracted time series: delegate entirely
    # ------------------------------------------------------------------
    if isinstance(data, RegionTimeSeries):
        # RegionTimeSeries.compute_fc_matrix has no roi_codes parameter;
        # forwarding it would raise an unexpected-keyword TypeError.
        if roi_codes is not None:
            raise ValueError(
                "roi_codes is not supported when data is a RegionTimeSeries; "
                "use roi_names to select regions."
            )
        return data.compute_fc_matrix(
            method=method,
            z_transform=z_transform,
            absolute=absolute,
            threshold=threshold,
            normalize_rows=normalize_rows,
            vols_to_delete=vols_to_delete,
            roi_names=roi_names,
        )

    # ------------------------------------------------------------------
    # Region selection (on a deep copy so ``self`` is never modified)
    # ------------------------------------------------------------------
    # If both roi_codes and roi_names are given, roi_names takes precedence.
    if roi_codes is not None and roi_names is not None:
        roi_codes = None
        print(
            "Both roi_codes and roi_names were specified. Ignoring roi_codes and using roi_names for region selection."
        )
    temp_parc = copy.deepcopy(self)
    if roi_codes is not None:
        temp_parc.keep_by_code(codes2keep=roi_codes)
    if roi_names is not None:
        temp_parc.keep_by_name(names2keep=roi_names)

    # Normalise vols_to_delete to a flat list of ints, or None when empty.
    if vols_to_delete is not None:
        if not isinstance(vols_to_delete, list):
            vols_to_delete = [vols_to_delete]
        vols_to_delete = cltmisc.build_indices(vols_to_delete, nonzeros=False)
        if len(vols_to_delete) == 0:
            vols_to_delete = None

    # ------------------------------------------------------------------
    # Resolve ``data`` into an (n_rois, n_timepoints) array
    # ------------------------------------------------------------------
    if isinstance(data, (str, Path)):
        if not os.path.exists(data):
            raise ValueError(f"Data file does not exist: {data}")
        # Path objects used to skip extraction entirely; passing str(data)
        # routes both path types through the same extraction call.
        region_ts = temp_parc.get_regionwise_timeseries(
            str(data), vols_to_delete=vols_to_delete, method=ts_method
        )
        data = np.asarray(region_ts.data, dtype=float)
    elif isinstance(data, np.ndarray):
        data = np.asarray(data, dtype=float)
        if data.ndim == 4:
            region_ts = temp_parc.get_regionwise_timeseries(
                data, vols_to_delete=vols_to_delete, method=ts_method
            )
            data = np.asarray(region_ts.data, dtype=float)
        elif data.ndim == 2:
            if vols_to_delete is not None:
                if max(vols_to_delete) >= data.shape[1]:
                    raise ValueError(
                        f"vols_to_delete contains indices that exceed the number of time points ({data.shape[1]})."
                    )
                data = np.delete(data, vols_to_delete, axis=1)
        else:
            raise ValueError(
                f"Data array must be 2-D (n_rois × n_timepoints) or 4-D (n_x × n_y × n_z × n_timepoints), got shape {data.shape}."
            )
    else:
        raise TypeError(
            "data must be a file path (str or Path), a 2-D or 4-D numpy array, "
            f"or a RegionTimeSeries; got {type(data)}."
        )

    n_rois, n_timepoints = data.shape
    if n_rois != len(temp_parc.index):
        raise ValueError(
            f"Number of ROIs in data ({n_rois}) does not match number of regions in parcellation ({len(temp_parc.index)})."
        )

    # ------------------------------------------------------------------
    # Optional row-wise z-scoring
    # ------------------------------------------------------------------
    if normalize_rows:
        mu = data.mean(axis=1, keepdims=True)
        sigma = data.std(axis=1, keepdims=True)
        sigma[sigma == 0] = 1.0  # constant rows: avoid division by zero
        data = (data - mu) / sigma

    # ------------------------------------------------------------------
    # Compute FC
    # ------------------------------------------------------------------
    if method == "pearson":
        fc = np.corrcoef(data)  # (n_rois, n_rois), vectorised
    elif method == "spearman":
        # Pearson on row-wise ranks is Spearman's rho, without a Python loop.
        ranked = np.apply_along_axis(stats.rankdata, axis=1, arr=data)
        fc = np.corrcoef(ranked)
    elif method == "kendall":
        fc = np.eye(n_rois)
        for i in range(n_rois):
            for j in range(i + 1, n_rois):
                tau, _ = stats.kendalltau(data[i], data[j])
                fc[i, j] = fc[j, i] = tau
    elif method == "partial":
        if n_timepoints <= n_rois:
            raise ValueError(
                f"Partial correlation requires n_timepoints ({n_timepoints}) "
                f"> n_rois ({n_rois})."
            )
        cov = np.cov(data)  # (n_rois, n_rois)
        prec = pinv(cov)  # precision matrix
        # pcor[i, j] = -prec[i, j] / sqrt(prec[i, i] * prec[j, j])
        d = np.sqrt(np.diag(prec))
        fc = -prec / np.outer(d, d)
        np.fill_diagonal(fc, 1.0)
    else:  # method == "mutual_info" (validated above)
        from sklearn.metrics import mutual_info_score
        from sklearn.preprocessing import KBinsDiscretizer

        # Discretise each time series into quantile bins for MI estimation.
        n_bins = max(10, int(np.sqrt(n_timepoints)))
        est = KBinsDiscretizer(n_bins=n_bins, encode="ordinal", strategy="quantile")
        data_disc = est.fit_transform(data.T).T.astype(int)  # (n_rois, n_timepoints)
        mi_raw = np.zeros((n_rois, n_rois))
        for i in range(n_rois):
            for j in range(i, n_rois):
                mi = mutual_info_score(data_disc[i], data_disc[j])
                mi_raw[i, j] = mi_raw[j, i] = mi
        # NMI(i, j) = MI(i, j) / sqrt(H(i) * H(j)); MI(x, x) == H(x).
        entropies = np.diag(mi_raw)
        denom = np.sqrt(np.outer(entropies, entropies))
        denom[denom == 0] = 1.0
        fc = mi_raw / denom

    # ------------------------------------------------------------------
    # Post-processing
    # ------------------------------------------------------------------
    if method != "mutual_info":
        # Clip numerical noise outside [-1, 1] for correlation-based methods.
        fc = np.clip(fc, -1.0, 1.0)
    if z_transform and method != "mutual_info":
        # arctanh is infinite at ±1, so clip before transforming.
        fc = np.arctanh(np.clip(fc, -0.9999, 0.9999))
    if absolute:
        fc = np.abs(fc)
    if threshold is not None:
        fc[np.abs(fc) < threshold] = 0.0

    fc_connectome = cltcon.Connectome(
        fc,
        region_names=temp_parc.name,
        region_index=temp_parc.index,
        region_colors=temp_parc.color,
        connectivity_type="functional",
    )
    return fc_connectome
########################################################################################################
[docs]
class RegionTimeSeries:
"""
Class for handling region-wise time series extracted from parcellations.
Attributes
----------
data : np.ndarray
2-D array of shape (n_regions, n_timepoints) containing the time series for each region.
region_names : list of str
List of region names corresponding to the rows of `data`.
region_colors : list of tuples or list of str
List of region colors corresponding to the rows of `data`. Colors can be in RGB tuples or hex string format.
"""
[docs]
def __init__(
    self,
    data: np.ndarray,
    region_names: List[str] = None,
    region_colors: List[Union[Tuple[float, float, float], str]] = None,
    method: str = "clabtoolkit",
):
    """
    Initialize the RegionTimeSeries object.

    Parameters
    ----------
    data : np.ndarray
        2-D array of shape (n_regions, n_timepoints) with one time series
        per row. Non-ndarray array-likes are converted with ``np.asarray``.
    region_names : list of str, optional
        One name per row of `data`. If None, names are generated with
        ``cltmisc.create_names_from_indices`` from the 1-based indices
        ``1..n_regions``.
    region_colors : list of tuples or list of str, optional
        One colour per row of `data`, as RGB tuples (e.g. (255, 0, 0)) or hex
        strings (e.g. "#FF0000"). If None, ``n_regions`` distinguishable hex
        colours are generated with ``cltcol.create_distinguishable_colors``.
    method : str, optional
        Accepted for API compatibility. It is not read by this constructor
        and not stored on the instance.

    Raises
    ------
    ValueError
        If `data` is not 2-D, or if `region_names` / `region_colors` do not
        have exactly one entry per region.
    """
    if not isinstance(data, np.ndarray):
        data = np.asarray(data)
    # Fail with a clear message instead of an IndexError on data.shape[1].
    if data.ndim != 2:
        raise ValueError(
            f"data must be a 2-D array of shape (n_regions, n_timepoints), got shape {data.shape}."
        )

    n_regions, n_timepoints = data.shape
    self.data = data
    self._n_regions = n_regions
    self._n_timepoints = n_timepoints

    # Fill in defaults, then validate lengths.
    if region_names is None:
        region_names = cltmisc.create_names_from_indices(
            np.arange(1, n_regions + 1)
        )
    if len(region_names) != n_regions:
        raise ValueError(
            f"Length of region_names ({len(region_names)}) must match number of regions in data ({n_regions})."
        )
    self.region_names = region_names

    if region_colors is None:
        region_colors = cltcol.create_distinguishable_colors(
            n_regions, output_format="hex"
        )
    if len(region_colors) != n_regions:
        raise ValueError(
            f"Length of region_colors ({len(region_colors)}) must match number of regions in data ({n_regions})."
        )
    self.region_colors = region_colors
##############################################################################################
[docs]
def compute_fc_matrix(
    self,
    method: str = "pearson",
    *,
    z_transform: bool = False,
    absolute: bool = False,
    threshold: Optional[float] = None,
    normalize_rows: bool = False,
    vols_to_delete: Union[str, list, np.ndarray] = None,
    roi_names: Union[List[str], str] = None,
) -> cltcon.Connectome:
    """Compute a functional connectivity (FC) matrix from the stored time series.

    Each entry ``FC[i, j]`` reflects the pairwise association between the
    time series of ROI *i* and ROI *j*. The matrix is symmetric with shape
    ``(n_rois, n_rois)``. Before post-processing the diagonal is 1 for
    ``"pearson"``, ``"spearman"``, ``"kendall"`` and ``"partial"``
    (Pearson/Spearman give NaN for a constant series). For
    ``"mutual_info"`` it is 1 for non-constant series and 0 for constant ones.

    Parameters
    ----------
    method : str, default ``"pearson"``
        Correlation / association method (case-insensitive):

        ``"pearson"``
            Standard Pearson *r* via ``np.corrcoef``.
        ``"spearman"``
            Spearman *rho*: Pearson correlation of the row-wise ranks.
        ``"kendall"``
            Kendall *tau* via ``scipy.stats.kendalltau``, computed pair by
            pair in a Python loop (slowest for many ROIs).
        ``"partial"``
            Partial correlation from the (pseudo-)inverse of the covariance
            matrix. Requires ``n_timepoints > n_rois``.
        ``"mutual_info"``
            Normalised mutual information on quantile-binned series
            (scikit-learn required). Values in ``[0, 1]``.
    z_transform : bool, default False
        Apply Fisher's ``arctanh`` transform after clipping to
        ``[-0.9999, 0.9999]``. Not applied for ``"mutual_info"``.
    absolute : bool, default False
        Return ``|FC|`` instead of signed values.
    threshold : float, optional
        After all other transforms, set entries whose absolute value is below
        *threshold* to 0 (diagonal included).
    normalize_rows : bool, default False
        Z-score each ROI time series before computing the FC matrix.
    vols_to_delete : str, list, or np.ndarray, optional
        Time-point (column) indices to drop from ``self.data`` before the
        computation, normalised with ``cltmisc.build_indices``. An
        empty selection deletes nothing.
    roi_names : list of str or str, optional
        Restrict the computation to regions selected from ``region_names`` by
        ``cltmisc.get_indexes_by_substring`` with these patterns.

    Returns
    -------
    fc_connectome : cltcon.Connectome
        Connectome holding the FC matrix, the selected region names/colours,
        1-based region indices, tagged ``connectivity_type="functional"``.

    Raises
    ------
    ValueError
        Unsupported method, no region matching *roi_names*, out-of-range
        *vols_to_delete*, or too few time points for ``"partial"``.
    TypeError
        *roi_names* is not a str or a list of str.

    Examples
    --------
    >>> rts = RegionTimeSeries(np.random.randn(90, 200))  # 90 ROIs, 200 TPs
    >>> fc_r = rts.compute_fc_matrix(method="pearson")
    >>> fc_rho = rts.compute_fc_matrix(method="spearman", z_transform=True)
    >>> fc_par = rts.compute_fc_matrix(method="partial", vols_to_delete=[0, 1])
    """
    SUPPORTED = {"pearson", "spearman", "kendall", "partial", "mutual_info"}

    # Validate the method before doing any data selection work.
    method = method.lower().strip()
    if method not in SUPPORTED:
        raise ValueError(
            f"Unknown method '{method}'. Choose from: {sorted(SUPPORTED)}."
        )

    # ------------------------------------------------------------------
    # Region selection
    # ------------------------------------------------------------------
    if roi_names is not None:
        if isinstance(roi_names, str):
            roi_names = [roi_names]
        elif isinstance(roi_names, list):
            if not all(isinstance(name, str) for name in roi_names):
                raise TypeError("All items in roi_names must be strings")
        else:
            raise TypeError(f"roi_names must be str or list, got {type(roi_names)}")
        indices = cltmisc.get_indexes_by_substring(self.region_names, roi_names)
        # An empty selection would otherwise yield a meaningless 0x0 result.
        if len(indices) == 0:
            raise ValueError(
                f"None of the requested roi_names {roi_names} matched any region name."
            )
        data = self.data[indices, :]
        roi_names = [self.region_names[i] for i in indices]
        roi_colors = [self.region_colors[i] for i in indices]
    else:
        data = self.data
        roi_names = self.region_names
        roi_colors = self.region_colors

    # ------------------------------------------------------------------
    # Optional time-point removal
    # ------------------------------------------------------------------
    if vols_to_delete is not None:
        if not isinstance(vols_to_delete, list):
            vols_to_delete = [vols_to_delete]
        vols_to_delete = cltmisc.build_indices(vols_to_delete, nonzeros=False)
        # Only delete when something was selected; max() of an empty/None
        # selection used to raise a TypeError here.
        if len(vols_to_delete) > 0:
            if max(vols_to_delete) >= data.shape[1]:
                raise ValueError(
                    f"vols_to_delete contains indices that exceed the number of time points ({data.shape[1]})."
                )
            data = np.delete(data, vols_to_delete, axis=1)

    n_rois, n_timepoints = data.shape

    # ------------------------------------------------------------------
    # Optional row-wise z-scoring
    # ------------------------------------------------------------------
    if normalize_rows:
        mu = data.mean(axis=1, keepdims=True)
        sigma = data.std(axis=1, keepdims=True)
        sigma[sigma == 0] = 1.0  # constant rows: avoid division by zero
        data = (data - mu) / sigma

    # ------------------------------------------------------------------
    # Compute FC
    # ------------------------------------------------------------------
    if method == "pearson":
        fc = np.corrcoef(data)  # (n_rois, n_rois), vectorised
    elif method == "spearman":
        # Pearson on row-wise ranks is Spearman's rho, without a Python loop.
        ranked = np.apply_along_axis(stats.rankdata, axis=1, arr=data)
        fc = np.corrcoef(ranked)
    elif method == "kendall":
        fc = np.eye(n_rois)
        for i in range(n_rois):
            for j in range(i + 1, n_rois):
                tau, _ = stats.kendalltau(data[i], data[j])
                fc[i, j] = fc[j, i] = tau
    elif method == "partial":
        if n_timepoints <= n_rois:
            raise ValueError(
                f"Partial correlation requires n_timepoints ({n_timepoints}) "
                f"> n_rois ({n_rois})."
            )
        cov = np.cov(data)  # (n_rois, n_rois)
        prec = pinv(cov)  # precision matrix
        # pcor[i, j] = -prec[i, j] / sqrt(prec[i, i] * prec[j, j])
        d = np.sqrt(np.diag(prec))
        fc = -prec / np.outer(d, d)
        np.fill_diagonal(fc, 1.0)
    else:  # method == "mutual_info" (validated above)
        from sklearn.metrics import mutual_info_score
        from sklearn.preprocessing import KBinsDiscretizer

        # Discretise each time series into quantile bins for MI estimation.
        n_bins = max(10, int(np.sqrt(n_timepoints)))
        est = KBinsDiscretizer(n_bins=n_bins, encode="ordinal", strategy="quantile")
        data_disc = est.fit_transform(data.T).T.astype(int)  # (n_rois, n_timepoints)
        mi_raw = np.zeros((n_rois, n_rois))
        for i in range(n_rois):
            for j in range(i, n_rois):
                mi = mutual_info_score(data_disc[i], data_disc[j])
                mi_raw[i, j] = mi_raw[j, i] = mi
        # NMI(i, j) = MI(i, j) / sqrt(H(i) * H(j)); MI(x, x) == H(x).
        entropies = np.diag(mi_raw)
        denom = np.sqrt(np.outer(entropies, entropies))
        denom[denom == 0] = 1.0
        fc = mi_raw / denom

    # ------------------------------------------------------------------
    # Post-processing
    # ------------------------------------------------------------------
    if method != "mutual_info":
        # Clip numerical noise outside [-1, 1] for correlation-based methods.
        fc = np.clip(fc, -1.0, 1.0)
    if z_transform and method != "mutual_info":
        # arctanh is infinite at ±1, so clip before transforming.
        fc = np.arctanh(np.clip(fc, -0.9999, 0.9999))
    if absolute:
        fc = np.abs(fc)
    if threshold is not None:
        fc[np.abs(fc) < threshold] = 0.0

    fc_connectome = cltcon.Connectome(
        fc,
        region_names=roi_names,
        region_index=list(range(1, n_rois + 1)),
        region_colors=roi_colors,
        connectivity_type="functional",
    )
    return fc_connectome
#########################################################################################################
[docs]
def get_info(self) -> None:
    """
    Print a boxed, human-readable summary of the RegionTimeSeries object.

    The box contains a title bar and a DATA section. The DATA section shows
    shape, dtype, and min/max/mean/std of ``self.data``. It is followed by a
    REGIONS section listing index, name and colour for at most the first 10
    regions, plus a line counting any regions left out. Names longer than 38
    characters and colours longer than 18 characters are cut and end with
    an ellipsis.

    Examples
    --------
    >>> rts.get_info()
    """
    max_rows = 10
    name_width = 38
    color_width = 18
    # Region row = 4-char index + 1 space + name + 1 space + colour (62
    # chars); 2 extra columns of slack give 64 chars between the side walls.
    box_width = (4 + 1 + name_width + 1 + color_width) + 2

    def rule(left="╠", fill="═", right="╣") -> None:
        print(left + fill * box_width + right)

    def line(text: str) -> None:
        # Clip, then left-justify, so the right wall always lines up.
        print("║" + text[:box_width].ljust(box_width) + "║")

    def clip(text: str, width: int) -> str:
        if len(text) > width:
            return text[: width - 1] + "…"
        return text

    # Private counters may be missing; fall back to the data shape.
    n_regions = getattr(self, "_n_regions", self.data.shape[0])
    n_timepoints = getattr(self, "_n_timepoints", self.data.shape[1])
    names = getattr(self, "region_names", [])
    colors = getattr(self, "region_colors", [])

    # Title bar
    rule("╔", "═", "╗")
    line(" REGION TIME SERIES".center(box_width))
    rule()

    # DATA section
    line(" DATA")
    line(f" Shape : {n_regions} regions × {n_timepoints} timepoints")
    line(f" Dtype : {self.data.dtype}")
    try:
        line(f" Min / Max : {self.data.min():.4g} / {self.data.max():.4g}")
        line(f" Mean / Std : {self.data.mean():.4g} / {self.data.std():.4g}")
    except (ValueError, TypeError):
        line(" Statistics : unavailable")

    # REGIONS section
    rule()
    total = len(names)
    line(f" REGIONS ({total})")
    if not total:
        line(" (no region names stored)")
    else:
        line(f" {'#':>4} {'Name':<{name_width}} {'Color':<{color_width}}")
        rule("╟", "─", "╢")
        for pos in range(min(total, max_rows)):
            label = clip(str(names[pos]), name_width)
            swatch = str(colors[pos]) if pos < len(colors) else "—"
            swatch = clip(swatch, color_width)
            line(f" {pos + 1:>4} {label:<{name_width}} {swatch:<{color_width}}")
        hidden = total - max_rows
        if hidden > 0:
            rule("╟", "─", "╢")
            line(f" … {hidden} more region(s) not shown")

    # Footer
    rule("╚", "═", "╝")
[docs]
def show_content(
    self,
    show_private=False,
    show_dunder=False,
    show_methods=True,
    show_properties=True,
    show_attributes=True,
):
    """
    Display the members of the RegionTimeSeries object.

    Thin wrapper around ``cltmisc.show_object_content``: this object and
    every flag below are forwarded unchanged. This is not an alias for
    ``get_info()``, which prints a formatted data/region summary instead.

    Parameters
    ----------
    show_private : bool, default False
        Forwarded as ``show_private`` (presumably whether single-underscore
        members are listed; TODO confirm in ``show_object_content``).
    show_dunder : bool, default False
        Forwarded as ``show_dunder`` (presumably whether ``__dunder__``
        members are listed).
    show_methods : bool, default True
        Forwarded as ``show_methods``.
    show_properties : bool, default True
        Forwarded as ``show_properties``.
    show_attributes : bool, default True
        Forwarded as ``show_attributes``.

    Examples
    --------
    >>> rts.show_content()
    >>> rts.show_content(show_private=True, show_methods=False)
    """
    cltmisc.show_object_content(
        self,
        show_private=show_private,
        show_dunder=show_dunder,
        show_methods=show_methods,
        show_properties=show_properties,
        show_attributes=show_attributes,
    )