import numpy as np
import pyvista as pv
import h5py
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Optional, Union, List, Tuple, Literal
import warnings
from . import colorstools as cltcol
from . import misctools as cltmisc
[docs]
class Connectome:
"""
A class to represent and visualize brain connectivity data.
Attributes:
-----------
name : str
Name identifier for the connectome
matrix : np.ndarray
Connectivity matrix (n_regions x n_regions)
coordinates : np.ndarray
3D coordinates for each region (n_regions x 3)
colors : np.ndarray
RGB color values for each region (n_regions x 3)
region_names : List[str]
Names/labels for each brain region
region_index : np.ndarray
Index codes for each region
connectivity_type : str
Type of connectivity ('structural', 'functional', 'effective', etc.)
affine : np.ndarray
4x4 affine transformation matrix
n_regions : int
Number of brain regions
"""
[docs]
def __init__(
self,
data: Optional[Union[np.ndarray, str, Path]] = None,
name: Optional[str] = None,
coordinates: Optional[np.ndarray] = None,
colors: Union[np.ndarray, List] = None,
region_names: Optional[List[str]] = None,
region_index: Optional[np.ndarray] = None,
connectivity_type: str = "structural",
affine: Optional[np.ndarray] = None,
):
"""
Initialize a Connectome object.
Parameters:
-----------
data : np.ndarray, str, Path, or None
Can be:
- np.ndarray: Connectivity matrix (n_regions x n_regions)
- str or Path: Path to HDF5 file to load
- None: Create empty Connectome
name : str, optional
Name for the connectome. If loading from file and None, uses filename stem.
coordinates : np.ndarray, optional
3D coordinates for each region (n_regions x 3)
colors : np.ndarray or List, optional
RGB color values or hex strings for each region
region_names : List[str], optional
Names/labels for each brain region
region_index : np.ndarray, optional
Index codes for each region
connectivity_type : str, optional
Type of connectivity (default: 'structural')
affine : np.ndarray, optional
4x4 affine transformation matrix
Examples:
---------
>>> # From matrix
>>> matrix = np.random.rand(10, 10)
>>> conn = Connectome(matrix)
>>> # From file
>>> conn = Connectome('/path/to/connectome.h5')
>>> # Empty connectome
>>> conn = Connectome(name='my_network')
"""
self.connectivity_type = connectivity_type
# Handle different input types for data
matrix = None
load_from_file = False
if data is not None:
if isinstance(data, (str, Path)):
# Load from file
load_from_file = True
filepath = Path(data)
# Set default name from filename if not provided
if name is None:
name = filepath.stem
elif isinstance(data, np.ndarray):
# Use as connectivity matrix
matrix = data
else:
raise TypeError(
f"data must be np.ndarray, str, Path, or None. Got {type(data)}"
)
self.name = name
# Initialize matrix and derived attributes
if matrix is not None:
if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
raise ValueError("Matrix must be square (NxN)")
self.matrix = matrix.astype(np.float64)
self._n_regions = matrix.shape[0]
else:
self.matrix = None
self._n_regions = 0
# Set coordinates
if coordinates is not None:
self.set_coordinates(coordinates)
else:
self.coordinates = None
# Set colors
if colors is not None:
colors = cltcol.harmonize_colors(colors, "hex")
self.set_colors(colors)
elif self._n_regions > 0:
# Only generate default colors if we have regions
self.colors = cltcol.create_distinguishable_colors(self._n_regions)
else:
self.colors = None
# Set region names
if region_names is not None:
self.set_region_names(region_names)
elif self._n_regions > 0:
# Only generate default names if we have regions
self.region_names = cltmisc.create_names_from_indices(
np.arange(self._n_regions) + 1
)
else:
self.region_names = None
# Set region index
if region_index is not None:
if self.matrix is not None and len(region_index) != self._n_regions:
raise ValueError(
f"Region index length ({len(region_index)}) must match matrix size ({self._n_regions})"
)
self.region_index = np.array(region_index)
else:
self.region_index = (
np.arange(self._n_regions) if self.matrix is not None else None
)
# Set affine
if affine is not None:
if affine.shape != (4, 4):
raise ValueError(f"Affine must be 4x4 array, got {affine.shape}")
self.affine = affine.astype(np.float64)
else:
self.affine = np.eye(4)
# Load from file if specified
if load_from_file:
self.load_h5(filepath)
@property
def n_regions(self) -> int:
"""Get the number of brain regions."""
return self._n_regions
[docs]
@classmethod
def from_h5(
cls, filename: Union[str, Path], name: Optional[str] = None
) -> "Connectome":
"""
Create a Connectome object from an HDF5 file.
Parameters:
-----------
filename : str or Path
Path to the HDF5 file containing connectivity data
name : str, optional
Name for the connectome. If None, uses filename stem.
Returns:
--------
Connectome : New Connectome object with loaded data
"""
filename = Path(filename)
# Set default name from filename if not provided
if name is None:
name = filename.stem
connectome = cls(name=name)
connectome.load_h5(filename)
return connectome
def _calculate_node_sizes(
self, property_type: str, threshold: float, scale: float, base_size: float
) -> np.ndarray:
"""
Calculate node sizes based on different properties.
Parameters:
-----------
property_type : str
Type of property to use for sizing
threshold : float
Threshold for degree calculation
scale : float
Scale factor
base_size : float
Base size for nodes
Returns:
--------
np.ndarray : Array of node sizes
"""
if property_type == "uniform":
return np.full(self.n_regions, base_size)
elif property_type == "strength":
# Total connectivity strength (sum of absolute connections)
strengths = np.sum(np.abs(self.matrix), axis=1)
normalized = (
strengths / np.max(strengths) if np.max(strengths) > 0 else strengths
)
return normalized * 10 * scale + base_size
elif property_type == "degree":
# Number of connections above threshold
degrees = np.sum(np.abs(self.matrix) > threshold, axis=1)
normalized = degrees / np.max(degrees) if np.max(degrees) > 0 else degrees
return normalized * 10 * scale + base_size
elif property_type == "betweenness":
try:
import networkx as nx
# Create graph from adjacency matrix
G = nx.from_numpy_array(np.abs(self.matrix))
centrality = nx.betweenness_centrality(G)
values = np.array([centrality[i] for i in range(self.n_regions)])
normalized = values / np.max(values) if np.max(values) > 0 else values
return normalized * 10 * scale + base_size
except ImportError:
warnings.warn(
"NetworkX not available. Using strength instead of betweenness centrality."
)
return self._calculate_node_sizes(
"strength", threshold, scale, base_size
)
elif property_type == "eigenvector":
try:
import networkx as nx
# Create graph from adjacency matrix
G = nx.from_numpy_array(np.abs(self.matrix))
try:
centrality = nx.eigenvector_centrality(G, max_iter=1000)
values = np.array([centrality[i] for i in range(self.n_regions)])
normalized = (
values / np.max(values) if np.max(values) > 0 else values
)
return normalized * 10 * scale + base_size
except nx.PowerIterationFailedConvergence:
warnings.warn(
"Eigenvector centrality failed to converge. Using strength instead."
)
return self._calculate_node_sizes(
"strength", threshold, scale, base_size
)
except ImportError:
warnings.warn(
"NetworkX not available. Using strength instead of eigenvector centrality."
)
return self._calculate_node_sizes(
"strength", threshold, scale, base_size
)
else:
raise ValueError(
f"Unknown node size property: {property_type}. "
f"Available options: 'uniform', 'strength', 'degree', 'betweenness', 'eigenvector'"
)
[docs]
def load_h5(self, filename: Union[str, Path]) -> None:
"""
Load connectivity data from HDF5 file.
Parameters:
-----------
filename : str or Path
Path to the HDF5 file
"""
filename = Path(filename)
if not filename.exists():
raise FileNotFoundError(f"File not found: {filename}")
try:
with h5py.File(filename, "r") as f:
# Try to find data in 'connmat' group first, then root
if "connmat" in f:
data_group = f["connmat"]
else:
data_group = f
# Load connectivity matrix (required)
if "matrix" in data_group:
self.matrix = data_group["matrix"][:]
else:
raise KeyError("No 'matrix' dataset found in HDF5 file")
self._n_regions = self.matrix.shape[0]
# Load coordinates (required for visualization)
if "coords" in data_group:
self.coordinates = data_group["coords"][:]
if self.coordinates.shape[0] != self._n_regions:
raise ValueError(
"Number of coordinates doesn't match matrix size"
)
elif "gmcoords" in data_group: # Alternative name
self.coordinates = data_group["gmcoords"][:]
if self.coordinates.shape[0] != self._n_regions:
raise ValueError(
"Number of coordinates doesn't match matrix size"
)
elif "coordinates" in data_group: # Alternative name
self.coordinates = data_group["coordinates"][:]
if self.coordinates.shape[0] != self._n_regions:
raise ValueError(
"Number of coordinates doesn't match matrix size"
)
else:
warnings.warn(
"No coordinates found. 3D visualization will not be available."
)
# Load colors (optional)
if "gmcolors" in data_group:
colors_data = data_group["gmcolors"][:]
# Decode bytes to strings if necessary
if colors_data.dtype.kind in ["S", "O"]:
colors_list = [
c.decode("utf-8") if isinstance(c, bytes) else str(c)
for c in colors_data
]
else:
colors_list = colors_data.tolist()
# Use harmonize_colors to convert to RGB array
self.colors = cltcol.harmonize_colors(colors_list)
elif "colors" in data_group: # Alternative name (legacy format)
self.colors = data_group["colors"][:]
if self.colors.shape[0] != self._n_regions:
warnings.warn("Number of colors doesn't match matrix size")
else:
# Normalize colors to [0,1] range if needed
if np.max(self.colors) > 1:
self.colors = self.colors / 255.0
# Load region names (optional)
if "gmregions" in data_group:
names_data = data_group["gmregions"][:]
# Decode bytes to strings if necessary
if names_data.dtype.kind in ["S", "O"]:
self.region_names = [
n.decode("utf-8") if isinstance(n, bytes) else str(n)
for n in names_data
]
else:
self.region_names = names_data.tolist()
if len(self.region_names) != self._n_regions:
warnings.warn(
"Number of region names doesn't match matrix size"
)
elif "name" in data_group: # Alternative name
names_data = data_group["name"][:]
if names_data.dtype.kind in ["S", "O"]:
self.region_names = [
n.decode("utf-8") if isinstance(n, bytes) else str(n)
for n in names_data
]
else:
self.region_names = names_data.tolist()
if len(self.region_names) != self._n_regions:
warnings.warn(
"Number of region names doesn't match matrix size"
)
# Load region index (optional)
if "gmindex" in data_group:
self.region_index = data_group["gmindex"][:]
elif "index" in data_group:
self.region_index = data_group["index"][:]
else:
self.region_index = np.arange(self._n_regions)
# Load affine (optional)
if "affine" in data_group:
self.affine = data_group["affine"][:]
else:
self.affine = np.eye(4)
# Load connectivity type (optional)
if "type" in data_group.attrs:
self.connectivity_type = data_group.attrs["type"]
if isinstance(self.connectivity_type, bytes):
self.connectivity_type = self.connectivity_type.decode("utf-8")
except Exception as e:
raise RuntimeError(f"Error loading HDF5 file: {e}")
[docs]
def save_h5(self, filename: Union[str, Path], compression: bool = True) -> None:
"""
Save Connectome to HDF5 file.
Parameters:
-----------
filename : str or Path
Output HDF5 filename
compression : bool, optional
Whether to use gzip compression (default: True)
"""
if self.matrix is None:
raise ValueError("No connectivity matrix to save")
filename = Path(filename)
with h5py.File(filename, "w") as f:
# Create main group
grp = f.create_group("connmat")
# Save matrix (required)
if compression:
grp.create_dataset("matrix", data=self.matrix, compression="gzip")
else:
grp.create_dataset("matrix", data=self.matrix)
# Save coordinates (if available)
if self.coordinates is not None:
grp.create_dataset("coords", data=self.coordinates)
# Save colors (if available)
if self.colors is not None:
# Convert to hex strings
colors_hex = cltcol.harmonize_colors(self.colors, "hex")
# Ensure it's a list of strings
if isinstance(colors_hex, np.ndarray):
colors_hex = colors_hex.tolist()
# Save as UTF-8 encoded strings
dt = h5py.string_dtype(encoding="utf-8")
grp.create_dataset("gmcolors", data=colors_hex, dtype=dt)
# Save region names (if available)
if self.region_names is not None:
dt = h5py.string_dtype(encoding="utf-8")
grp.create_dataset("gmregions", data=self.region_names, dtype=dt)
# Save region index
if self.region_index is not None:
grp.create_dataset("gmindex", data=self.region_index)
# Save affine
grp.create_dataset("affine", data=self.affine)
# Save metadata
grp.attrs["type"] = self.connectivity_type
grp.attrs["n_regions"] = self.n_regions
grp.attrs["density"] = self.get_density()
print(f"Connectome saved to: {filename}")
[docs]
def get_roi_names(self) -> List[str]:
"""
Get region of interest (ROI) names. If not available, generate default names.
Returns:
--------
List[str] : List of ROI names
"""
if self.region_names is not None:
return self.region_names
else:
return self.get_default_roi_names()
[docs]
def get_roi_colors(self) -> np.ndarray:
"""
Get region of interest (ROI) colors. If not available, generate default colors.
Returns:
--------
np.ndarray : Array of ROI colors
"""
if self.colors is not None:
return self.colors
else:
return self.get_default_roi_colors()
[docs]
def get_roi_coordinates(self) -> Optional[np.ndarray]:
"""
Get region of interest (ROI) coordinates.
Returns:
--------
Optional[np.ndarray] : Array of ROI coordinates or None
"""
return self.coordinates
[docs]
def set_coordinates(self, coordinates: np.ndarray) -> None:
"""
Set 3D coordinates for brain regions.
Parameters:
-----------
coordinates : np.ndarray
Array of shape (n_regions, 3) with x, y, z coordinates
"""
if self.matrix is not None and coordinates.shape != (self.n_regions, 3):
raise ValueError(
f"Coordinates shape {coordinates.shape} doesn't match expected ({self.n_regions}, 3)"
)
self.coordinates = coordinates.copy()
[docs]
def set_colors(self, colors: Union[List, np.ndarray]) -> None:
"""
Set colors for brain regions.
Parameters:
-----------
colors : np.ndarray or List
Array of shape (n_regions, 3) with RGB values [0-1] or [0-255], or list of hex colors
"""
if self.matrix is not None and len(colors) != self.n_regions:
raise ValueError(
f"Colors length {len(colors)} doesn't match expected ({self.n_regions})"
)
self.colors = cltcol.harmonize_colors(colors)
[docs]
def set_region_names(self, names: List[str]) -> None:
"""
Set names for brain regions.
Parameters:
-----------
names : List[str]
List of region names
"""
if self.matrix is not None and len(names) != self.n_regions:
raise ValueError(
f"Number of names {len(names)} doesn't match number of regions {self.n_regions}"
)
self.region_names = names.copy()
[docs]
def get_default_roi_colors(self) -> np.ndarray:
"""
Generate default colors for regions if not available.
Returns:
--------
np.ndarray : RGB colors for each region
"""
colors = cltcol.create_distinguishable_colors(self.n_regions)
return colors
[docs]
def get_default_roi_names(self) -> List[str]:
"""
Generate default region names if not available.
Returns:
--------
List[str] : Default region names
"""
names = cltmisc.create_names_from_indices(np.arange(self.n_regions) + 1)
return names
[docs]
def get_density(self) -> float:
"""
Calculate the density of the connectivity matrix.
Returns:
-------
float
Proportion of non-zero connections (excluding diagonal)
"""
if self.matrix is None:
return 0.0
n = self.matrix.shape[0]
n_possible = n * (n - 1) # Exclude diagonal
# Count non-zero off-diagonal elements
mask = ~np.eye(n, dtype=bool)
n_connections = np.count_nonzero(self.matrix[mask])
return n_connections / n_possible if n_possible > 0 else 0.0
[docs]
def get_connectivity_stats(self) -> dict:
"""
Calculate basic connectivity statistics.
Returns:
--------
dict : Dictionary with connectivity statistics
"""
if self.matrix is None:
return {}
stats = {
"n_regions": self.n_regions,
"matrix_shape": self.matrix.shape,
"min_strength": np.min(self.matrix),
"max_strength": np.max(self.matrix),
"mean_strength": np.mean(self.matrix),
"std_strength": np.std(self.matrix),
"density": self.get_density(),
"node_strengths": np.sum(np.abs(self.matrix), axis=1),
}
if self.coordinates is not None:
stats["coord_ranges"] = {
"x": (np.min(self.coordinates[:, 0]), np.max(self.coordinates[:, 0])),
"y": (np.min(self.coordinates[:, 1]), np.max(self.coordinates[:, 1])),
"z": (np.min(self.coordinates[:, 2]), np.max(self.coordinates[:, 2])),
}
return stats
[docs]
def threshold(
self,
method: Literal["value", "sparsity"] = "value",
threshold: float = 0.0,
absolute: bool = True,
binarize: bool = False,
copy: bool = True,
) -> "Connectome":
"""
Threshold the connectivity matrix.
Parameters:
----------
method : {'value', 'sparsity'}
Thresholding method:
- 'value': Keep connections above threshold value
- 'sparsity': Keep top connections to achieve target sparsity
threshold : float
- For 'value': minimum connection strength to keep
- For 'sparsity': target sparsity level (0-1), proportion of connections to keep
absolute : bool, optional
Use absolute values for thresholding (default: True)
binarize : bool, optional
Convert to binary matrix after thresholding (default: False)
copy : bool, optional
If True, return new Connectome object; if False, modify in place (default: True)
Returns:
-------
Connectome
Thresholded Connectome object (new if copy=True, self if copy=False)
"""
if self.matrix is None:
raise ValueError("No connectivity matrix available")
matrix_thresh = self.matrix.copy()
if method == "value":
# Threshold by value
if absolute:
mask = np.abs(matrix_thresh) < threshold
else:
mask = matrix_thresh < threshold
matrix_thresh[mask] = 0
elif method == "sparsity":
# Threshold by sparsity (keep top connections)
if not 0 <= threshold <= 1:
raise ValueError("Sparsity threshold must be between 0 and 1")
# Get off-diagonal elements
n = matrix_thresh.shape[0]
mask_diag = ~np.eye(n, dtype=bool)
values = matrix_thresh[mask_diag]
# Use absolute values to determine which connections to keep
if absolute:
values_for_ranking = np.abs(values)
else:
values_for_ranking = values
# Calculate how many connections to keep
n_total = len(values)
n_keep = int(n_total * threshold)
if n_keep > 0:
# Find threshold value (keep connections >= this value)
sorted_values = np.sort(values_for_ranking)[::-1]
value_threshold = sorted_values[min(n_keep - 1, len(sorted_values) - 1)]
# Apply threshold
if absolute:
mask = np.abs(matrix_thresh) < value_threshold
else:
mask = matrix_thresh < value_threshold
matrix_thresh[mask] = 0
else:
matrix_thresh[:] = 0
else:
raise ValueError(f"Unknown method: {method}. Use 'value' or 'sparsity'")
# Binarize if requested
if binarize:
matrix_thresh = (matrix_thresh != 0).astype(np.float64)
# Keep diagonal at zero
np.fill_diagonal(matrix_thresh, 0)
if copy:
# Create new Connectome object
return Connectome(
data=matrix_thresh,
name=self.name,
coordinates=(
self.coordinates.copy() if self.coordinates is not None else None
),
colors=self.colors.copy() if self.colors is not None else None,
region_names=(
self.region_names.copy() if self.region_names is not None else None
),
region_index=(
self.region_index.copy() if self.region_index is not None else None
),
connectivity_type=self.connectivity_type,
affine=self.affine.copy(),
)
else:
# Modify in place
self.matrix = matrix_thresh
return self
[docs]
def get_subnetwork(
self, region_indices: Union[np.ndarray, List[int]], copy: bool = True
) -> "Connectome":
"""
Extract a subnetwork with selected regions.
Parameters:
----------
region_indices : np.ndarray or List[int]
Indices of regions to include
copy : bool, optional
If True, return new Connectome object; if False, modify in place (default: True)
Returns:
-------
Connectome
Subnetwork Connectome object
"""
if self.matrix is None:
raise ValueError("No connectivity matrix available")
idx = np.array(region_indices)
# Extract subnetwork data
sub_matrix = self.matrix[np.ix_(idx, idx)]
sub_coords = self.coordinates[idx] if self.coordinates is not None else None
sub_colors = self.colors[idx] if self.colors is not None else None
sub_names = (
[self.region_names[i] for i in idx]
if self.region_names is not None
else None
)
sub_index = self.region_index[idx] if self.region_index is not None else None
if copy:
return Connectome(
data=sub_matrix,
name=f"{self.name}_subnetwork" if self.name else "subnetwork",
coordinates=sub_coords,
colors=sub_colors,
region_names=sub_names,
region_index=sub_index,
connectivity_type=self.connectivity_type,
affine=self.affine.copy(),
)
else:
# Modify in place
self.matrix = sub_matrix
self.coordinates = sub_coords
self.colors = sub_colors
self.region_names = sub_names
self.region_index = sub_index
self._n_regions = len(idx)
return self
[docs]
def copy(self) -> "Connectome":
"""
Create a deep copy of the Connectome.
Returns:
-------
Connectome
Deep copy of the Connectome
"""
return Connectome(
data=self.matrix.copy() if self.matrix is not None else None,
name=self.name,
coordinates=(
self.coordinates.copy() if self.coordinates is not None else None
),
colors=self.colors.copy() if self.colors is not None else None,
region_names=(
self.region_names.copy() if self.region_names is not None else None
),
region_index=(
self.region_index.copy() if self.region_index is not None else None
),
connectivity_type=self.connectivity_type,
affine=self.affine.copy(),
)
[docs]
def print_info(self) -> None:
"""Print comprehensive information about the connectome."""
print(f"=== Connectome: {self.name} ===")
if self.matrix is None:
print("No connectivity data loaded.")
return
stats = self.get_connectivity_stats()
print(f"Number of regions: {stats['n_regions']}")
print(f"Connectivity type: {self.connectivity_type}")
print(f"Matrix shape: {stats['matrix_shape']}")
print(
f"Connection strength range: [{stats['min_strength']:.3f}, {stats['max_strength']:.3f}]"
)
print(f"Mean ± SD: {stats['mean_strength']:.3f} ± {stats['std_strength']:.3f}")
print(f"Network density: {stats['density']:.3f}")
if "coord_ranges" in stats:
print("Coordinate ranges:")
for axis, (min_val, max_val) in stats["coord_ranges"].items():
print(f" {axis.upper()}: [{min_val:.1f}, {max_val:.1f}]")
else:
print("No coordinate data available")
print(f"Colors available: {'Yes' if self.colors is not None else 'No'}")
print(
f"Region names available: {'Yes' if self.region_names is not None else 'No'}"
)
if self.region_names is not None:
print(f"Sample regions: {self.region_names[:3]}...")
[docs]
def plot_matrix(
self,
figsize: Tuple[int, int] = (12, 10),
log_scale: bool = False,
show_labels: bool = True,
cmap: str = "RdBu_r",
threshold: Optional[float] = None,
threshold_mode: str = "absolute",
) -> None:
"""
Plot the connectivity matrix as a heatmap.
Parameters:
-----------
figsize : tuple
Figure size (width, height)
show_labels : bool
Whether to show region names on axes
cmap : str
Colormap for the heatmap
threshold : float, optional
Threshold value for displaying connections. Values below threshold will be set to 0.
threshold_mode : str
How to apply threshold: 'absolute' (abs(value) > threshold) or 'raw' (value > threshold)
"""
if self.matrix is None:
raise ValueError("No connectivity matrix available")
plt.figure(figsize=figsize)
# Apply threshold if specified
matrix_to_plot = self.matrix.copy()
if threshold is not None:
if threshold_mode == "absolute":
mask = np.abs(matrix_to_plot) < threshold
else: # raw mode
mask = matrix_to_plot < threshold
matrix_to_plot[mask] = 0
# Apply log scale if specified
if log_scale:
# Use symmetric log scale to handle negative values
matrix_to_plot = np.sign(matrix_to_plot) * np.log1p(np.abs(matrix_to_plot))
# Create heatmap
im = plt.imshow(matrix_to_plot, cmap=cmap, aspect="equal")
cbar = plt.colorbar(im, label="Connection Strength")
# Add threshold info to title
title = f"Connectivity Matrix - {self.name}"
if threshold is not None:
title += f" (threshold: {threshold}, mode: {threshold_mode})"
# Add labels if available and requested
if (
show_labels
and self.region_names is not None
and len(self.region_names) < 50
):
plt.xticks(
range(len(self.region_names)),
self.region_names,
rotation=45,
ha="right",
fontsize=8,
)
plt.yticks(range(len(self.region_names)), self.region_names, fontsize=8)
plt.title(title)
plt.xlabel("Brain Regions")
plt.ylabel("Brain Regions")
plt.tight_layout()
plt.show()
[docs]
def plot_circular_graph(
self,
figsize: Tuple[int, int] = (12, 12),
threshold: Optional[float] = None,
node_size_property: str = "strength",
node_size_scale: float = 1000,
edge_width_scale: float = 5,
show_labels: bool = True,
label_distance: float = 1.1,
edge_alpha: float = 0.6,
node_alpha: float = 0.8,
edge_cmap: str = "plasma",
layout_seed: Optional[int] = 42,
) -> None:
"""
Plot the connectivity matrix as a circular graph.
Parameters:
-----------
figsize : tuple
Figure size (width, height)
threshold : float, optional
Minimum connection strength to display edges
node_size_property : str
Property to scale node sizes by: 'strength', 'degree', 'uniform'
node_size_scale : float
Scale factor for node sizes
edge_width_scale : float
Scale factor for edge widths
show_labels : bool
Whether to show region labels
label_distance : float
Distance of labels from nodes (1.0 = at node border)
edge_alpha : float
Transparency of edges (0-1)
node_alpha : float
Transparency of nodes (0-1)
edge_cmap : str
Colormap for edges based on connection strength
layout_seed : int, optional
Random seed for consistent layout
"""
try:
import networkx as nx
except ImportError:
raise ImportError(
"NetworkX is required for circular graph visualization. "
"Install with: pip install networkx"
)
if self.matrix is None:
raise ValueError("No connectivity matrix available")
# Create figure
fig, ax = plt.subplots(figsize=figsize)
# Create adjacency matrix for graph
adj_matrix = self.matrix.copy()
# Apply threshold if specified
if threshold is not None:
adj_matrix[np.abs(adj_matrix) < threshold] = 0
# Create NetworkX graph
G = nx.from_numpy_array(adj_matrix)
# Get circular layout
if layout_seed is not None:
np.random.seed(layout_seed)
pos = nx.circular_layout(G)
# Calculate node sizes
if node_size_property == "uniform":
node_sizes = [
node_size_scale * 0.1
] * self.n_regions # Convert to reasonable size for circular plot
elif node_size_property == "strength":
strengths = np.sum(np.abs(adj_matrix), axis=1)
if np.max(strengths) > 0:
normalized_strengths = strengths / np.max(strengths)
else:
normalized_strengths = np.ones_like(strengths)
node_sizes = normalized_strengths * node_size_scale + node_size_scale * 0.1
elif node_size_property == "degree":
degrees = np.array([G.degree(node) for node in G.nodes()])
if np.max(degrees) > 0:
normalized_degrees = degrees / np.max(degrees)
else:
normalized_degrees = np.ones_like(degrees)
node_sizes = normalized_degrees * node_size_scale + node_size_scale * 0.1
else:
raise ValueError(f"Unknown node size property: {node_size_property}")
# Get node colors
node_colors = self.get_roi_colors()
# Get edge weights and colors
edges = G.edges()
edge_weights = []
edge_colors = []
for edge in edges:
weight = abs(adj_matrix[edge[0], edge[1]])
edge_weights.append(weight * edge_width_scale)
edge_colors.append(weight)
# Normalize edge colors
if edge_colors:
edge_colors = np.array(edge_colors)
if np.max(edge_colors) > 0:
edge_colors = edge_colors / np.max(edge_colors)
# Draw edges
if edges:
nx.draw_networkx_edges(
G,
pos,
width=edge_weights,
edge_color=edge_colors,
edge_cmap=plt.cm.get_cmap(edge_cmap),
alpha=edge_alpha,
ax=ax,
)
# Draw nodes
nx.draw_networkx_nodes(
G,
pos,
node_size=node_sizes,
node_color=node_colors,
alpha=node_alpha,
ax=ax,
)
# Add labels if requested
if show_labels:
# Get region names
region_names = self.get_roi_names()
# Create labels dictionary
labels = {i: region_names[i] for i in range(self.n_regions)}
# Calculate label positions
label_pos = {}
for node, (x, y) in pos.items():
# Move labels slightly outward from nodes
angle = np.arctan2(y, x)
label_x = x * label_distance
label_y = y * label_distance
label_pos[node] = (label_x, label_y)
# Draw labels
nx.draw_networkx_labels(
G, label_pos, labels=labels, font_size=8, font_weight="bold", ax=ax
)
# Set title
title = f"Circular Graph - {self.name}"
if threshold is not None:
title += f" (threshold: {threshold})"
if node_size_property != "uniform":
title += f" (node size: {node_size_property})"
ax.set_title(title, fontsize=16, fontweight="bold", pad=20)
# Remove axes
ax.set_axis_off()
# Make layout tight
plt.tight_layout()
# Add colorbar for edges if there are edges
if edges and len(edge_colors) > 0:
# Create a dummy plot for colorbar
sm = plt.cm.ScalarMappable(
cmap=plt.cm.get_cmap(edge_cmap),
norm=plt.Normalize(
vmin=(
np.min(np.abs(adj_matrix)[adj_matrix != 0])
if threshold is None
else threshold
),
vmax=np.max(np.abs(adj_matrix)),
),
)
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax, shrink=0.8, pad=0.1)
cbar.set_label("Connection Strength", rotation=270, labelpad=20)
plt.show()
[docs]
def visualize_3d(
self,
connectivity_threshold: float = 0.1,
node_size_scale: float = 1.0,
edge_width_scale: float = 1.0,
show_edges: bool = True,
show_labels: bool = False,
background_color: str = "black",
window_size: Tuple[int, int] = (1200, 800),
node_size_property: str = "strength",
base_node_size: float = 0.5,
) -> pv.Plotter:
"""
Create a 3D visualization of the connectome using PyVista.
Parameters:
-----------
connectivity_threshold : float
Minimum connection strength to display edges
node_size_scale : float
Scale factor for node sizes
edge_width_scale : float
Scale factor for edge widths
show_edges : bool
Whether to show connectivity edges
show_labels : bool
Whether to show region labels
background_color : str
Background color for the plot
window_size : tuple
Window size (width, height)
node_size_property : str
Property to scale node sizes by:
- 'strength': Total connectivity strength (sum of absolute connections)
- 'degree': Number of connections above threshold
- 'uniform': All nodes same size (base_node_size)
- 'betweenness': Betweenness centrality (requires networkx)
- 'eigenvector': Eigenvector centrality (requires networkx)
base_node_size : float
Base size for nodes when using 'uniform' or as minimum size for other properties
Returns:
--------
pv.Plotter : PyVista plotter object
"""
if self.matrix is None:
raise ValueError("No connectivity matrix available")
if self.coordinates is None:
raise ValueError("No coordinates available for 3D visualization")
# Create plotter
plotter = pv.Plotter(window_size=window_size)
plotter.set_background(background_color)
# Center coordinates around origin
coords_centered = self.coordinates - np.mean(self.coordinates, axis=0)
# Calculate node sizes based on selected property
node_sizes = self._calculate_node_sizes(
node_size_property, connectivity_threshold, node_size_scale, base_node_size
)
# Get colors (use provided or generate defaults)
colors = self.get_roi_colors()
# Get region names (use provided or generate defaults)
region_names = self.get_roi_names()
# Add nodes (brain regions)
for i in range(self.n_regions):
# Create sphere for each region
sphere = pv.Sphere(radius=node_sizes[i], center=coords_centered[i])
# Add sphere to plotter with color
plotter.add_mesh(
sphere,
color=colors[i],
opacity=0.8,
smooth_shading=True,
name=f"region_{i}",
)
# Add labels if requested
if show_labels:
plotter.add_point_labels(
coords_centered[i : i + 1],
[region_names[i]],
font_size=8,
text_color="white",
)
# Add connectivity edges
if show_edges:
# Get upper triangle indices (avoid duplicate edges)
i_indices, j_indices = np.triu_indices(self.n_regions, k=1)
for idx in range(len(i_indices)):
i, j = i_indices[idx], j_indices[idx]
connection_strength = abs(self.matrix[i, j])
if connection_strength > connectivity_threshold:
# Create line between regions
points = np.array([coords_centered[i], coords_centered[j]])
line = pv.Line(points[0], points[1])
# Scale line width based on connection strength
line_width = connection_strength * edge_width_scale * 5 + 1
# Color edges based on connection strength
edge_color = plt.cm.plasma(
connection_strength / np.max(np.abs(self.matrix))
)[:3]
plotter.add_mesh(
line,
color=edge_color,
line_width=line_width,
opacity=0.6,
name=f"edge_{i}_{j}",
)
# Set up camera and lighting
plotter.camera_position = "xy"
plotter.add_axes()
# Add title
title = f"Brain Connectivity Network - {self.name}"
if node_size_property != "uniform":
title += f" (node size: {node_size_property})"
plotter.add_title(title, font_size=16, color="white")
return plotter
[docs]
def save_visualization(self, filename: str, **kwargs) -> None:
"""
Save a 3D visualization to file.
Parameters:
-----------
filename : str
Output filename for the visualization
**kwargs : dict
Additional arguments passed to visualize_3d()
"""
plotter = self.visualize_3d(**kwargs)
plotter.screenshot(filename)
plotter.close()
def __repr__(self) -> str:
"""String representation of the Connectome object."""
if self.matrix is None:
return f"Connectome(name='{self.name}', no data loaded)"
return f"Connectome(name='{self.name}', type='{self.connectivity_type}', n_regions={self.n_regions}, density={self.get_density():.3f})"