import os
import numpy as np
import nibabel as nib
from typing import Union, List, Dict, Optional, Tuple
from pathlib import Path
import pyvista as pv
import pandas as pd
import copy
import warnings
# Importing local modules
from . import freesurfertools as cltfree
from . import misctools as cltmisc
from . import colorstools as cltcol
####################################################################################################
####################################################################################################
############ ############
############ ############
############ Section 1: Class and methods for working with meshes ############
############ ############
############ ############
####################################################################################################
####################################################################################################
[docs]
class Surface:
    """
    Comprehensive class for loading and visualizing brain surface data.

    Provides an interface for working with brain surface geometries, including
    loading from files, vertex/face arrays or existing PyVista meshes, managing
    scalar maps and parcellations, and creating visualizations using PyVista.
    Supports FreeSurfer geometry files and other formats readable by PyVista.

    Attributes
    ----------
    surf : str or None
        Path to surface file if loaded from file, otherwise None.
    mesh : pv.PolyData or None
        PyVista mesh object containing surface geometry and data. None for an
        empty instance.
    hemi : str or None
        Hemisphere designation (typically 'lh' or 'rh'). None until a surface
        has been loaded.
    active_scalar : str
        Name of the active point-data array. Initialized to 'default'.
    colortables : dict
        Dictionary storing color table information for parcellations, keyed
        by parcellation name.

    Examples
    --------
    >>> # Load from FreeSurfer surface file
    >>> surface = Surface('lh.pial')
    >>>
    >>> # Create from vertex/face arrays
    >>> surface = Surface(vertices=verts, faces=faces, hemi='lh')
    >>>
    >>> # Load scalar data and parcellations
    >>> surface.load_scalar_maps('thickness.mgh', maps_names='thickness')
    >>> surface.load_annotation('lh.aparc.annot', 'aparc')
    """
##############################################################################################
[docs]
def __init__(
    self,
    surface_file: Union[str, Path, "pv.PolyData"] = None,
    vertices: np.ndarray = None,
    faces: np.ndarray = None,
    color: Union[str, np.ndarray] = "#f0f0f0",
    alpha: float = 1.0,
    hemi: str = None,
) -> None:
    """
    Initialize a Surface from a file, a PyVista mesh or arrays, or create it empty.

    Parameters
    ----------
    surface_file : str, Path or pv.PolyData, optional
        Path to a surface file (FreeSurfer .pial, .white, .inflated, or any
        format PyVista can read), or an existing PyVista mesh. Default is None.
    vertices : np.ndarray, optional
        Vertex coordinates array with shape (n_vertices, 3). Default is None.
    faces : np.ndarray, optional
        Face connectivity array with shape (n_faces, 3). Default is None.
    color : str or np.ndarray, optional
        Color for the surface mesh. Can be a hex color string (e.g., '#f0f0f0')
        or an RGB array in [0, 1] or [0, 255] range. Default is '#f0f0f0'.
    alpha : float, optional
        Alpha transparency value in [0, 1] for the surface color. Default is 1.0 (opaque).
    hemi : str, optional
        Hemisphere designation ('lh' or 'rh'). When None, it is detected from
        the file name for files and defaults to 'lh' for meshes/arrays.
        Default is None.

    Raises
    ------
    ValueError
        If alpha is outside [0, 1], if both surface_file and vertices/faces
        are provided, or if only one of vertices/faces is provided.
    TypeError
        If surface_file is not a str, Path or pv.PolyData.
    FileNotFoundError
        If the surface file doesn't exist.

    Examples
    --------
    >>> # Load from file with auto-detection
    >>> surface = Surface('lh.pial')
    >>>
    >>> # Create from arrays
    >>> vertices = np.random.rand(4, 3)
    >>> faces = np.array([[0, 1, 2], [1, 2, 3]])
    >>> surface = Surface(vertices=vertices, faces=faces, hemi='lh')
    >>>
    >>> # Create empty instance
    >>> surface = Surface()
    """
    # Initialize attributes so that an empty instance is well defined
    self.surf = None
    self.mesh = None
    self.hemi = None
    self.active_scalar = "default"
    self.colortables: Dict[str, Dict] = {}

    # Validate alpha value
    if isinstance(alpha, int):
        alpha = float(alpha)
    if not (0 <= alpha <= 1):
        raise ValueError(f"Alpha value must be in the range [0, 1], got {alpha}")

    # Build the default colortable. The RGB columns are rescaled from
    # [0, 255] to [0, 1], exactly as _create_default_parcellation does, so an
    # empty instance stores colors in the same range as a loaded one.
    color = cltcol.harmonize_colors(color, output_format="rgb")
    tmp_ctable = np.asarray(cltcol.colors_to_table(colors=color, alpha_values=alpha))
    if not np.issubdtype(tmp_ctable.dtype, np.floating):
        # An integer table would truncate the rescaled colors to 0
        tmp_ctable = tmp_ctable.astype(np.float64)
    tmp_ctable[:, :3] = tmp_ctable[:, :3] / 255
    self.colortables["default"] = {
        "names": ["default"],
        "color_table": tmp_ctable,
        "lookup_table": None,  # Will be populated by _create_parcellation_colortable if needed
    }

    # Validate the combination of input sources
    if surface_file is not None and (vertices is not None or faces is not None):
        raise ValueError("Cannot specify both surface_file and vertices/faces")
    if vertices is not None and faces is None:
        raise ValueError("If vertices are provided, faces must also be provided")
    if faces is not None and vertices is None:
        raise ValueError("If faces are provided, vertices must also be provided")

    # Load data if provided
    if surface_file is not None:
        if isinstance(surface_file, Path):
            surface_file = str(surface_file)
        if isinstance(surface_file, str):
            if not os.path.isfile(surface_file):
                raise FileNotFoundError(f"Surface file not found: {surface_file}")
            # load_from_file stores the path in self.surf on success
            self.load_from_file(surface_file, color, alpha, hemi)
        elif isinstance(surface_file, pv.PolyData):
            self.load_from_mesh(surface_file, color, alpha, hemi)
        else:
            raise TypeError(
                "surface_file must be a str, Path or pv.PolyData, "
                f"got {type(surface_file).__name__}"
            )
    elif vertices is not None:
        # The checks above guarantee faces is not None here
        self.load_from_arrays(vertices, faces, color, alpha, hemi=hemi)
################################################################################################
[docs]
def load_from_file(
    self,
    surface_file: Union[str, Path],
    color: Union[str, np.ndarray] = "#f0f0f0",
    alpha: float = 1.0,
    hemi: str = None,
) -> None:
    """
    Load surface geometry from FreeSurfer or compatible surface file.

    The file is first read with nibabel's FreeSurfer geometry reader; if that
    fails, PyVista's generic reader is tried. Instance state (surf, mesh,
    hemi, default parcellation) is only updated after all validation and
    loading has succeeded.

    Parameters
    ----------
    surface_file : str or Path
        Path to surface file (e.g., FreeSurfer .pial, .white, .inflated).
    color : str or np.ndarray, optional
        Color for the surface mesh. Can be a hex color string (e.g., '#f0f0f0')
        or an RGB array in [0, 1] or [0, 255] range. Default is '#f0f0f0'.
    alpha : float, optional
        Alpha transparency value in [0, 1] for the surface color. Default is 1.0 (opaque).
    hemi : str, optional
        Hemisphere designation ('lh' or 'rh'). If None, it is detected with
        cltfree.detect_hemi from the full path, then from the file name, and
        finally defaults to 'lh'. Default is None.

    Raises
    ------
    FileNotFoundError
        If surface file cannot be found.
    ValueError
        If alpha is outside [0, 1], or if neither nibabel nor PyVista can
        read the file.

    Examples
    --------
    >>> surface = Surface()
    >>> surface.load_from_file('lh.pial')
    >>> print(f"Loaded {surface.mesh.n_points} vertices")
    >>>
    >>> # Explicit hemisphere specification
    >>> surface.load_from_file('brain_surface.surf', hemi='rh')
    """
    if isinstance(surface_file, Path):
        surface_file = str(surface_file)
    if not os.path.isfile(surface_file):
        raise FileNotFoundError(f"Surface file not found: {surface_file}")

    # Validate inputs before touching any instance state
    if isinstance(alpha, int):
        alpha = float(alpha)
    if not (0 <= alpha <= 1):
        raise ValueError(f"Alpha value must be in the range [0, 1], got {alpha}")
    color = cltcol.harmonize_colors(color, output_format="rgb")

    # Load the surface geometry: FreeSurfer first, then any PyVista format
    try:
        vertices, faces = nib.freesurfer.read_geometry(surface_file)
        # PyVista expects every face prefixed with its vertex count (3)
        faces = np.c_[np.full(len(faces), 3), faces]
        mesh = pv.PolyData(vertices, faces)
    except Exception as e:
        try:
            mesh = pv.read(surface_file)
        except Exception as e2:
            raise ValueError(
                f"Failed to load surface file {surface_file} with error: {e}\n"
                f"Also failed with PyVista: {e2}"
            ) from e2

    # Hemisphere: explicit argument, else detection on the full path, else
    # detection on the bare file name, else default to left hemisphere.
    if hemi is None:
        hemi = cltfree.detect_hemi(surface_file)
    if hemi is None:
        hemi = cltfree.detect_hemi(os.path.basename(surface_file))
    if hemi is None:
        hemi = "lh"

    self.surf = surface_file
    self.mesh = mesh
    self.hemi = hemi

    # Create default parcellation data
    self._create_default_parcellation(
        color=color,
        alpha=alpha,
    )
##############################################################################################
[docs]
def load_from_arrays(
    self,
    vertices: np.ndarray,
    faces: np.ndarray,
    color: Union[str, np.ndarray] = "#f0f0f0",
    alpha: float = 1.0,
    hemi: str = None,
    surface_file: Union[str, Path] = None,
) -> None:
    """
    Load surface geometry from vertex and face arrays.

    All validation (alpha range, array shapes, face indices) happens before
    any instance attribute is modified, so a failed call leaves the object
    unchanged.

    Parameters
    ----------
    vertices : np.ndarray
        Vertex coordinates with shape (n_vertices, 3).
    faces : np.ndarray
        Face connectivity with shape (n_faces, 3).
    color : str or np.ndarray, optional
        Color for the surface mesh. Can be a hex color string (e.g., '#f0f0f0')
        or an RGB array in [0, 1] or [0, 255] range. Default is '#f0f0f0'.
    alpha : float, optional
        Alpha transparency value in [0, 1] for the surface color. Default is 1.0 (opaque).
    hemi : str, optional
        Hemisphere designation ('lh' or 'rh'). Defaults to 'lh'.
    surface_file : str or Path, optional
        Associated surface file path for metadata, stored as a string.
        Default is None.

    Raises
    ------
    ValueError
        If alpha is outside [0, 1], or if vertices or faces arrays have
        incorrect shapes or invalid indices.

    Examples
    --------
    >>> # Basic triangle mesh
    >>> vertices = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]])
    >>> faces = np.array([[0, 1, 2]])
    >>> surface = Surface()
    >>> surface.load_from_arrays(vertices, faces, hemi='lh')
    """
    # Validate alpha value
    if isinstance(alpha, int):
        alpha = float(alpha)
    if not (0 <= alpha <= 1):
        raise ValueError(f"Alpha value must be in the range [0, 1], got {alpha}")

    # Handle color input
    color = cltcol.harmonize_colors(color, output_format="rgb")

    # Build (and validate) the mesh before mutating any instance state
    mesh = self.create_mesh_from_arrays(vertices, faces)

    if isinstance(surface_file, Path):
        surface_file = str(surface_file)
    self.surf = surface_file
    self.mesh = mesh
    self.hemi = hemi if hemi is not None else "lh"  # Default to left hemisphere

    # Create default parcellation data
    self._create_default_parcellation(
        color=color,
        alpha=alpha,
    )
##############################################################################################
[docs]
def load_from_mesh(
    self,
    mesh: "pv.PolyData",
    color: Union[str, np.ndarray] = "#f0f0f0",
    alpha: float = 1.0,
    hemi: str = None,
) -> None:
    """
    Load surface geometry from existing PyVista mesh object.

    Parameters
    ----------
    mesh : pv.PolyData
        PyVista mesh object containing surface geometry.
    color : str or np.ndarray, optional
        Color for the surface mesh. Can be a hex color string (e.g., '#f0f0f0')
        or an RGB array in [0, 1] or [0, 255] range. Default is '#f0f0f0'.
    alpha : float, optional
        Alpha transparency value in [0, 1] for the surface color. Default is 1.0 (opaque).
    hemi : str, optional
        Hemisphere designation ('lh' or 'rh'). Defaults to 'lh'.

    Raises
    ------
    ValueError
        If alpha is outside [0, 1]. The object is left unchanged in that case.

    Notes
    -----
    Stores a deep copy of the input mesh so the caller's mesh is not modified,
    clears the stored file path, and creates the default parcellation
    ("default" point-data labels and colortable) on the copy.

    Examples
    --------
    >>> # From existing PyVista mesh
    >>> existing_mesh = pv.PolyData(vertices, faces)
    >>> surface = Surface()
    >>> surface.load_from_mesh(existing_mesh, hemi='rh')
    >>>
    >>> # From procedural mesh
    >>> sphere = pv.Sphere(radius=50)
    >>> surface.load_from_mesh(sphere, hemi='lh')
    """
    # Validate inputs before replacing any instance state
    if isinstance(alpha, int):
        alpha = float(alpha)
    if not (0 <= alpha <= 1):
        raise ValueError(f"Alpha value must be in the range [0, 1], got {alpha}")
    color = cltcol.harmonize_colors(color, output_format="rgb")

    self.surf = None
    self.mesh = copy.deepcopy(mesh)  # Make a copy to avoid modifying the original
    self.hemi = hemi if hemi is not None else "lh"  # Default to left hemisphere

    # Ensure mesh has default surface colors
    self._create_default_parcellation(color=color, alpha=alpha)
##############################################################################################
[docs]
def is_loaded(self) -> bool:
"""
Check whether surface data has been loaded.
Returns
-------
bool
True if surface data is loaded, False otherwise.
Examples
--------
>>> surface = Surface()
>>> print(surface.is_loaded()) # False
>>> surface.load_from_file('lh.pial')
>>> print(surface.is_loaded()) # True
"""
return self.mesh is not None
##############################################################################################
def _create_default_parcellation(
    self, color: Union[str, np.ndarray] = "#f0f0f0", alpha: float = 1.0
) -> None:
    """
    Create default parcellation data for surface visualization.

    Internal method that labels every vertex with a single "default" region
    so the surface can be displayed before any annotation is loaded.

    Parameters
    ----------
    color : str or np.ndarray
        Color for the surface mesh, as accepted by cltcol.colors_to_table.
        Callers in this class pass the output of
        cltcol.harmonize_colors(..., output_format="rgb").
    alpha : float
        Alpha transparency value in [0, 1] for the surface color. Default is 1.0 (opaque).

    Raises
    ------
    RuntimeError
        If no surface mesh has been loaded.

    Notes
    -----
    The RGB columns of the table returned by colors_to_table are divided by
    255. Every vertex is labeled with the code in column 4 of the first table
    row, and the result is stored under the name "default".
    """
    if not self.is_loaded():
        raise RuntimeError("No surface data loaded. Load data first.")

    tmp_ctable = np.asarray(cltcol.colors_to_table(colors=color, alpha_values=alpha))
    if not np.issubdtype(tmp_ctable.dtype, np.floating):
        # An integer table would truncate the in-place /255 rescale to 0
        tmp_ctable = tmp_ctable.astype(np.float64)
    tmp_ctable[:, :3] = tmp_ctable[:, :3] / 255  # Ensure colors are between 0 and 1

    default_code = int(tmp_ctable[0, 4])
    labels = np.full(self.mesh.n_points, default_code, dtype=np.uint32)
    self._store_parcellation_data(labels, tmp_ctable, ["default"], "default")
##############################################################################################
[docs]
def create_mesh_from_arrays(
self, vertices: np.ndarray, faces: np.ndarray
) -> pv.PolyData:
"""
Create PyVista mesh object from vertex and face arrays.
Parameters
----------
vertices : np.ndarray
Vertex coordinates with shape (n_vertices, 3).
faces : np.ndarray
Face connectivity with shape (n_faces, 3).
Returns
-------
pv.PolyData
PyVista mesh object with vertices, faces, and default surface colors.
Raises
------
ValueError
If arrays have incorrect shapes or face indices are invalid.
Notes
-----
Validates input arrays and creates properly formatted PyVista mesh
with default surface colors. Adds normals to point data if provided.
Examples
--------
>>> surface = Surface()
>>> vertices = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]])
>>> faces = np.array([[0, 1, 2]])
>>> mesh = surface.create_mesh_from_arrays(vertices, faces)
>>> print(f"Created mesh with {mesh.n_points} vertices")
"""
# Validate array shapes
if vertices.ndim != 2 or vertices.shape[1] != 3:
raise ValueError("Vertices array must have shape (n_vertices, 3)")
if faces.ndim != 2 or faces.shape[1] != 3:
raise ValueError("Faces array must have shape (n_faces, 3)")
# Check that face indices are valid
if np.any(faces >= len(vertices)) or np.any(faces < 0):
raise ValueError(
"Face indices must be valid indices into the vertices array"
)
mesh = self._create_pyvista_mesh(vertices, faces)
return mesh
##############################################################################################
def _create_pyvista_mesh(
    self, vertices: np.ndarray, faces: np.ndarray
) -> "pv.PolyData":
    """
    Build a PyVista PolyData from already-validated vertices and triangles.

    Parameters
    ----------
    vertices : np.ndarray
        Vertex coordinates with shape (n_vertices, 3).
    faces : np.ndarray
        Triangle connectivity with shape (n_faces, 3).

    Returns
    -------
    pv.PolyData
        Mesh whose point data contains a "default" uint8 RGBA array with the
        value (240, 240, 240, 255) for every vertex.

    Notes
    -----
    PyVista's connectivity format prefixes every face with its vertex count,
    so a leading column of 3s is prepended to ``faces``.
    """
    n_vertices = len(vertices)

    cell_sizes = np.full(len(faces), 3)
    pv_faces = np.column_stack((cell_sizes, faces))
    surface_mesh = pv.PolyData(vertices, pv_faces)

    # Uniform light gray, fully opaque
    rgba = np.full((n_vertices, 4), 240, dtype=np.uint8)
    rgba[:, 3] = 255
    surface_mesh.point_data["default"] = rgba

    return surface_mesh
##############################################################################################
[docs]
def get_vertices(self) -> np.ndarray:
"""
Get vertex coordinates from the surface mesh.
Returns
-------
np.ndarray
Array of vertex coordinates with shape (n_vertices, 3).
Raises
------
RuntimeError
If no surface data has been loaded.
Examples
--------
>>> surface = Surface('lh.pial')
>>> vertices = surface.get_vertices()
>>> print(f"Surface has {len(vertices)} vertices")
>>> print(f"First vertex: {vertices[0]}")
"""
if not self.is_loaded():
raise RuntimeError("No surface data loaded. Load data first.")
return self.mesh.points
##############################################################################################
[docs]
def get_faces(self) -> np.ndarray:
"""
Get face connectivity from the surface mesh.
Returns
-------
np.ndarray
Array of face indices with shape (n_faces, 3). Each row contains
three vertex indices forming a triangular face.
Raises
------
RuntimeError
If no surface data has been loaded.
Notes
-----
Extracts face connectivity from PyVista's internal format which stores
faces as [n_vertices, vertex_id1, vertex_id2, ...]. This method returns
only the vertex indices in standard format.
Examples
--------
>>> surface = Surface('lh.pial')
>>> faces = surface.get_faces()
>>> print(f"Surface has {len(faces)} triangular faces")
>>> print(f"First face connects vertices: {faces[0]}")
"""
if not self.is_loaded():
raise RuntimeError("No surface data loaded. Load data first.")
# PyVista stores faces as [n_vertices, vertex_id1, vertex_id2, ...]
# We need to extract just the vertices indices
faces_raw = self.mesh.faces
n_faces = self.mesh.n_cells
faces = faces_raw.reshape(n_faces, 4)[
:, 1:4
] # Skip the first column (n_vertices)
return faces
###############################################################################################
[docs]
def get_edges(self, return_counts: bool = False) -> np.ndarray:
"""
Extract unique edges from a triangular mesh using vectorized operations.
This function efficiently extracts all unique edges from a triangular mesh
represented as a faces array. Each triangle contributes three edges, and
the function automatically removes duplicates that occur when triangles
share edges.
Parameters
----------
return_counts : bool, optional
If True, also return the count of how many faces each edge belongs to.
This is useful for identifying boundary edges (count=1) vs interior
edges (count=2). Default is False.
Returns
-------
edges : np.ndarray of shape (n_edges, 2)
Array of unique edges where each row contains two vertex indices
[v1, v2] with v1 <= v2. Edges are sorted lexicographically.
counts : np.ndarray of shape (n_edges,), optional
Number of faces that contain each edge. Only returned if
return_counts=True. Boundary edges have count=1, interior edges
have count=2.
Raises
------
ValueError
If faces array does not have exactly 3 columns (not triangular).
If faces array is empty.
If faces array contains negative indices.
Examples
--------
>>> # Simple triangular mesh with 2 triangles sharing an edge
>>> faces = np.array([[0, 1, 2], [1, 3, 2]])
>>> edges = get_edges(faces)
>>> print(edges)
[[0 1]
[0 2]
[1 2]
[1 3]
[2 3]]
>>> # Get edge counts to identify boundary vs interior edges
>>> edges, counts = get_edges(faces, return_counts=True)
>>> boundary_edges = edges[counts == 1]
>>> interior_edges = edges[counts == 2]
>>> print("Boundary edges:", boundary_edges)
>>> print("Interior edges:", interior_edges)
Boundary edges: [[0 1]
[0 2]
[1 3]
[2 3]]
Interior edges: [[1 2]]
>>> # Cube mesh (8 vertices, 12 triangular faces)
>>> cube_faces = np.array([
... [0, 1, 2], [0, 2, 3], # Bottom face
... [4, 5, 6], [4, 6, 7], # Top face
... [0, 1, 5], [0, 5, 4], # Front face
... [2, 3, 7], [2, 7, 6], # Back face
... [0, 3, 7], [0, 7, 4], # Left face
... [1, 2, 6], [1, 6, 5] # Right face
... ])
>>> edges = get_edges(cube_faces)
>>> print(f"Cube has {len(edges)} unique edges")
Cube has 18 unique edges
Notes
-----
This function uses vectorized NumPy operations for high performance on
large meshes. The algorithm:
1. Extracts all three edges from each triangle simultaneously
2. Sorts vertex pairs to canonical form (smaller index first)
3. Uses numpy.unique to efficiently remove duplicates
Time complexity: O(n log n) where n is the number of faces
Space complexity: O(n) for intermediate arrays
For non-triangular meshes, use the general `extract_edges_from_faces`
function instead.
The canonical edge representation ensures that edge (i, j) and edge (j, i)
are treated as the same edge, with the final representation always having
the smaller vertex index first.
See Also
--------
extract_edges_from_faces : General version for arbitrary polygon meshes
numpy.unique : Used internally for deduplication
"""
# Getting the faces array from the mesh
faces = (
self.get_faces()
) # Extract only the vertex indices, skip the first column
# Input validation
if faces.size == 0:
raise ValueError("Faces array cannot be empty")
if faces.ndim != 2 or faces.shape[1] != 3:
raise ValueError(
f"Faces array must have shape (n_faces, 3), got {faces.shape}"
)
if np.any(faces < 0):
raise ValueError("Faces array cannot contain negative vertex indices")
# Extract all edges from all triangles using vectorized operations
# Each triangle contributes 3 edges: (v0,v1), (v1,v2), (v2,v0)
all_edges = np.concatenate(
[
faces[:, [0, 1]], # Edge from vertex 0 to vertex 1
faces[:, [1, 2]], # Edge from vertex 1 to vertex 2
faces[:, [2, 0]], # Edge from vertex 2 to vertex 0
],
axis=0,
)
# Sort each edge to canonical form (smaller vertex index first)
# This ensures (i,j) and (j,i) are treated as the same edge
canonical_edges = np.sort(all_edges, axis=1)
# Remove duplicate edges and optionally count occurrences
if return_counts:
unique_edges, counts = np.unique(
canonical_edges, axis=0, return_counts=True
)
return unique_edges, counts
else:
unique_edges = np.unique(canonical_edges, axis=0)
return unique_edges
###############################################################################################
[docs]
def get_boundary_edges(self) -> np.ndarray:
"""
Extract only the boundary edges from a triangular mesh.
Boundary edges are those that belong to only one triangle, indicating
the mesh boundary or holes in the mesh.
Parameters
----------
faces : np.ndarray of shape (n_faces, 3)
Triangular mesh faces array.
Returns
-------
boundary_edges : np.ndarray of shape (n_boundary_edges, 2)
Array of boundary edges where each edge belongs to only one face.
Examples
--------
>>> # Mesh with a hole (incomplete sphere)
>>> faces = np.array([[0, 1, 2], [1, 3, 2], [3, 4, 2]])
>>> boundary = get_boundary_edges(faces)
>>> print("Boundary edges:", boundary)
"""
edges, counts = self.get_edges(return_counts=True)
return edges[counts == 1]
###############################################################################################
[docs]
def get_manifold_edges(self) -> np.ndarray:
"""
Extract only the manifold (interior) edges from a triangular mesh.
Manifold edges are those shared by exactly two triangles, indicating
proper mesh topology without boundaries or non-manifold geometry.
Parameters
----------
faces : np.ndarray of shape (n_faces, 3)
Triangular mesh faces array.
Returns
-------
manifold_edges : np.ndarray of shape (n_manifold_edges, 2)
Array of manifold edges where each edge belongs to exactly two faces.
Examples
--------
>>> # Closed mesh (tetrahedron)
>>> faces = np.array([[0, 1, 2], [0, 2, 3], [0, 3, 1], [1, 2, 3]])
>>> manifold = get_manifold_edges(faces)
>>> print("Manifold edges:", manifold)
"""
# Getting the faces array from the mesh
faces = self.mesh.faces[:, 1:4] # Extract only the vertex indices
edges, counts = self.get_edges(faces, return_counts=True)
return edges[counts == 2]
##############################################################################################
[docs]
def compute_normals(self) -> None:
"""
Compute and store vertex normals for the surface mesh.
Calculates unit normal vectors for each vertex and stores them in the
mesh point data under the key "Normals". Normals are automatically
normalized to unit length.
Raises
------
RuntimeError
If no surface data has been loaded.
RuntimeError
If computed normals have zero length and cannot be normalized.
Notes
-----
Uses PyVista's built-in normal computation which averages face normals
at each vertex. The resulting normals are forced to be unit vectors.
Overwrites any existing normals in the mesh.
Examples
--------
>>> surface = Surface('lh.pial')
>>> surface.compute_normals()
>>> normals = surface.get_normals()
>>> print(f"Computed {len(normals)} unit normal vectors")
>>>
>>> # Check that normals are unit vectors
>>> norms = np.linalg.norm(normals, axis=1)
>>> print(f"Normal lengths range: {norms.min():.3f} - {norms.max():.3f}")
"""
if not self.is_loaded():
raise RuntimeError("No surface data loaded. Load data first.")
self.mesh.compute_normals(inplace=True) # Compute normals and store in mesh
# Force the normals to be unit vectors
if "Normals" in self.mesh.point_data:
normals = self.mesh.point_data["Normals"]
norms = np.linalg.norm(normals, axis=1)
if np.any(norms > 0):
self.mesh.point_data["Normals"] = normals / norms[:, np.newaxis]
else:
raise RuntimeError(
"Computed normals have zero length. Cannot normalize."
)
##############################################################################################
[docs]
def get_normals(self) -> Optional[np.ndarray]:
"""
Get vertex normals from the surface mesh if available.
Returns
-------
np.ndarray or None
Array of normal vectors with shape (n_vertices, 3) if normals
have been computed, None otherwise.
Notes
-----
Returns None if normals haven't been computed yet. Use compute_normals()
to calculate normals before calling this method.
Examples
--------
>>> surface = Surface('lh.pial')
>>> normals = surface.get_normals()
>>> if normals is not None:
... print(f"Found {len(normals)} normal vectors")
... else:
... print("No normals computed yet")
... surface.compute_normals()
... normals = surface.get_normals()
"""
if not self.is_loaded():
return None
return self.mesh.point_data.get("Normals", None)
##############################################################################################
[docs]
def load_annotation(
    self,
    annotation: Union[str, Path, "cltfree.AnnotParcellation"],
    parc_name: str = None,
) -> None:
    """
    Load parcellation annotation onto surface for visualization.

    Loads FreeSurfer annotation files or AnnotParcellation objects, storing
    per-vertex labels, the region color table and region names.

    Parameters
    ----------
    annotation : str, Path or AnnotParcellation
        Path to annotation file (.annot) or an object exposing ``codes``,
        ``regtable`` and ``regnames`` (e.g. AnnotParcellation).
    parc_name : str, optional
        Name used to store and reference the parcellation. If None, the
        ``id`` of the AnnotParcellation is used.

    Raises
    ------
    RuntimeError
        If no surface data has been loaded.
    FileNotFoundError
        If annotation file cannot be found.
    ValueError
        If the input type is invalid or the vertex count does not match.

    Notes
    -----
    The RGB columns of the region table are divided by 255. If every value in
    column 3 is 0, that column is set to 1.0 so the regions are not rendered
    fully transparent. The (copied) AnnotParcellation object is kept in
    ``self.colortables[parc_name]["annot_object"]``.

    Examples
    --------
    >>> # Load Desikan-Killiany parcellation
    >>> surface.load_annotation('lh.aparc.annot', 'aparc')
    >>>
    >>> # Load from object
    >>> annot = AnnotParcellation('lh.aparc.a2009s.annot')
    >>> surface.load_annotation(annot, 'destrieux')
    """
    if not self.is_loaded():
        raise RuntimeError("No surface data loaded. Load data first.")

    if isinstance(annotation, Path):
        annotation = str(annotation)

    # Handle different input types
    if isinstance(annotation, str):
        if not os.path.isfile(annotation):
            raise FileNotFoundError(f"Annotation file not found: {annotation}")
        # Create AnnotParcellation object to benefit from its processing and cleaning
        annot_parc = cltfree.AnnotParcellation()
        annot_parc.load_from_file(parc_file=annotation, annot_id=parc_name)
    elif (
        hasattr(annotation, "codes")
        and hasattr(annotation, "regtable")
        and hasattr(annotation, "regnames")
    ):
        # Copy so the caller's object is never modified
        annot_parc = copy.deepcopy(annotation)
        if parc_name is not None:
            annot_parc.id = parc_name
    else:
        raise ValueError(
            "annotation must be either a file path (str or Path) or an AnnotParcellation object"
        )

    # Extract the processed and cleaned data from AnnotParcellation
    labels = annot_parc.codes
    # astype makes a float32 copy, so the in-place rescale below does not
    # touch annot_parc.regtable
    reg_ctable = annot_parc.regtable.astype(np.float32)
    reg_names = annot_parc.regnames

    # Validate that the number of vertices matches
    if len(labels) != self.mesh.n_points:
        raise ValueError(
            f"Number of vertices in annotation ({len(labels)}) does not match surface ({self.mesh.n_points})"
        )

    # If parc_name is not provided, use the id from AnnotParcellation
    if parc_name is None:
        parc_name = annot_parc.id

    reg_ctable[:, :3] = reg_ctable[:, :3] / 255  # Ensure colors are between 0 and 1
    # If all the opacity values are 0 set them to 1
    if np.all(reg_ctable[:, 3] == 0):
        reg_ctable[:, 3] = 1.0

    self._store_parcellation_data(labels, reg_ctable, reg_names, parc_name)

    # Store reference to AnnotParcellation object for advanced operations
    self.colortables[parc_name]["annot_object"] = annot_parc
##############################################################################################
def _store_parcellation_data(
self,
labels: np.ndarray,
reg_ctable: np.ndarray,
reg_names: List[str],
parc_name: str,
) -> None:
"""
Store parcellation data and create color mappings.
Internal method for organizing parcellation labels, colors, and names
in surface object for visualization and analysis.
Parameters
----------
labels : np.ndarray
Label values for each vertex.
reg_ctable : np.ndarray
Color table with RGBA values for each region.
reg_names : list
Region names corresponding to color table.
parc_name : str
Name of the parcellation.
Notes
-----
Stores labels in mesh point data and creates organized color table
structure. Also calls color table creation for visualization.
"""
# Store labels in mesh
self.mesh.point_data[parc_name] = labels
# Store parcellation information in organized structure
self.colortables[parc_name] = {
"names": reg_names,
"color_table": reg_ctable,
"lookup_table": None, # Will be populated by _create_parcellation_colortable if needed
}
##############################################################################################
def _get_parcellation_data(
    self, annotation: Union[str, Path, "cltfree.AnnotParcellation"]
) -> "cltfree.AnnotParcellation":
    """
    Load or retrieve parcellation data from an annotation file, a stored
    parcellation name, or an AnnotParcellation object.

    Parameters
    ----------
    annotation : str, Path or AnnotParcellation
        Path to an annotation file (.annot), the name of a parcellation
        already stored in the mesh point data, or an AnnotParcellation object.

    Returns
    -------
    cltfree.AnnotParcellation
        AnnotParcellation object containing parcellation data.

    Raises
    ------
    FileNotFoundError
        If the path is not an existing file and no stored parcellation with
        both point data and a colortable exists under that name.
    ValueError
        If annotation is not a str, Path or AnnotParcellation.

    Notes
    -----
    - Existing files are loaded with AnnotParcellation.load_from_file.
    - A non-file string is looked up in ``self.mesh.point_data`` and
      ``self.colortables`` and rebuilt with AnnotParcellation.create_from_data.
    - AnnotParcellation inputs are deep-copied so the original is not modified.

    Examples
    --------
    >>> # Load from annotation file
    >>> annot_parc = surface._get_parcellation_data('lh.aparc.annot')
    """
    if isinstance(annotation, Path):
        annotation = str(annotation)

    if isinstance(annotation, cltfree.AnnotParcellation):
        # Use a copy to avoid modifying the original object
        return copy.deepcopy(annotation)

    if not isinstance(annotation, str):
        raise ValueError(
            "annotation must be a file path (str or Path), the name of a stored "
            "parcellation, or an AnnotParcellation object, "
            f"got {type(annotation).__name__}"
        )

    if os.path.isfile(annotation):
        # Create AnnotParcellation object to benefit from its processing and cleaning
        parc = cltfree.AnnotParcellation()
        parc.load_from_file(parc_file=annotation)
        return parc

    # Not a file: fall back to a parcellation already stored on the surface
    stored_on_mesh = self.mesh is not None and annotation in self.mesh.point_data
    if not stored_on_mesh:
        raise FileNotFoundError(f"Annotation file not found: {annotation}")
    if annotation not in self.colortables:
        raise FileNotFoundError(
            f"Annotation file not found: {annotation} and no colortable available"
        )

    maps_array = self.mesh.point_data[annotation]
    ctable = self.colortables[annotation]["color_table"]
    struct_names = self.colortables[annotation]["names"]
    parc = cltfree.AnnotParcellation()
    parc.create_from_data(maps_array, ctable, struct_names)
    return parc
##############################################################################################
[docs]
def load_scalar_maps(
    self,
    scalar_map: Union[str, Path, np.ndarray, pd.DataFrame],
    # Forward reference: the AnnotParcellation class lives in the project
    # module ``cltfree``; the string keeps the hint without evaluating it.
    annotation: Union[str, Path, "cltfree.AnnotParcellation"] = None,
    maps_names: Union[str, List[str]] = None,
) -> None:
    """
    Load vertex-wise or region-wise scalar maps onto the surface.

    Accepts a FreeSurfer morphometry file (e.g. ``lh.thickness``), a tabular
    file readable by ``cltmisc.smart_read_table`` (e.g. a CSV with a header
    row), a numpy array or a pandas DataFrame. Every loaded map is stored in
    ``self.mesh.point_data`` under its map name.

    Parameters
    ----------
    scalar_map : str, Path, np.ndarray or pd.DataFrame
        Source of the maps. For tabular inputs each column is one map; the
        header row of a CSV file (or the DataFrame column names) gives the
        default map names.
    annotation : str, Path or AnnotParcellation, optional
        Parcellation used to expand region-wise values to vertices. Required
        when the number of rows differs from the number of mesh vertices.
        A string that is not an existing file is looked up among the
        parcellations already stored on the surface (via
        ``self._get_parcellation_data``). Default is None.
    maps_names : str or list of str, optional
        Map names. For DataFrames and tabular files they select the columns
        to load; for numpy arrays they name the columns (one per column); for
        a FreeSurfer morphometry file a single name is expected. If None,
        the column names (tables), ``map_<i>`` (arrays) or the file's base
        name (morphometry files) are used. Default is None.

    Raises
    ------
    FileNotFoundError
        If ``scalar_map`` is a path that does not exist, or the annotation
        cannot be found.
    KeyError
        If ``maps_names`` selects a column missing from a DataFrame.
    ValueError
        If ``scalar_map`` or ``annotation`` has an unsupported type, if
        ``maps_names`` is not a string or a list of strings, if the number of
        names does not match the number of columns, if region-wise data is
        given without an annotation, or if a morphometry file does not match
        the number of vertices.

    Notes
    -----
    The row count decides how data is interpreted: as many rows as mesh
    vertices means vertex-wise data; any other count is treated as
    region-wise data and expanded to vertices with the annotation.

    A file path is first read as a table. If that fails for any reason, it
    is read with ``nibabel.freesurfer.read_morph_data`` instead; should that
    fail too, the table error stays attached as the exception context.
    Errors for in-memory inputs (arrays and DataFrames) are always raised.

    Examples
    --------
    >>> surf_lh = cltsurf.Surface("/opt/freesurfer/subjects/fsaverage/surf/lh.pial")
    >>> # Region-wise values from a CSV file, selecting one column
    >>> surf_lh.load_scalar_maps(
    ...     "/tmp/values.csv",
    ...     annotation="/opt/freesurfer/subjects/fsaverage/label/lh.aparc.annot",
    ...     maps_names="region_index",
    ... )
    >>> # Region-wise values from a DataFrame
    >>> values_df = pd.read_csv("/tmp/values.csv")
    >>> surf_lh.load_scalar_maps(
    ...     values_df,
    ...     annotation="/opt/freesurfer/subjects/fsaverage/label/lh.aparc.annot",
    ...     maps_names=["value"],
    ... )
    >>> # Vertex-wise values from a numpy array (default name "map_0")
    >>> surf_lh.load_scalar_maps(np.random.rand(surf_lh.mesh.n_points))
    >>> # Vertex-wise values from a numpy array with an explicit name
    >>> surf_lh.load_scalar_maps(
    ...     np.random.rand(surf_lh.mesh.n_points), maps_names=["random"]
    ... )
    >>> # FreeSurfer morphometry files
    >>> for map_file, name in [
    ...     ("/opt/freesurfer/subjects/fsaverage/surf/lh.thickness", "thickness"),
    ...     ("/opt/freesurfer/subjects/fsaverage/surf/lh.curv", "curvature"),
    ... ]:
    ...     surf_lh.load_scalar_maps(map_file, maps_names=name)
    >>> print(surf_lh.list_overlays())
    """
    # Normalise and validate the optional arguments first, so bad input is
    # reported directly instead of being masked by the file-format fallback.
    if isinstance(maps_names, str):
        maps_names = [maps_names]
    if maps_names is not None and not isinstance(maps_names, list):
        raise ValueError("maps_names must be a string or a list of strings")

    if isinstance(annotation, Path):
        annotation = str(annotation)
    if annotation is not None and not isinstance(
        annotation, (str, cltfree.AnnotParcellation)
    ):
        raise ValueError("annotation must be a string or an AnnotParcellation object")

    def _store_table(maps_df: pd.DataFrame) -> None:
        """Expand a table of maps to vertices and store one map per column."""
        names = maps_df.columns.tolist() if maps_names is None else maps_names
        if len(names) != maps_df.shape[1]:
            raise ValueError(
                "Length of maps_names must match the number of columns in the DataFrame"
            )
        if maps_df.shape[0] == self.mesh.n_points:
            # Vertex-wise data: one row per vertex, usable as is.
            vertex_maps = maps_df.to_numpy()
        else:
            if annotation is None:
                raise ValueError(
                    "annotation must be provided if map_file does not match the number of vertices"
                )
            # Region-wise data: expand each region value to its vertices.
            parc = self._get_parcellation_data(annotation)
            vertex_maps = parc.map_values(
                regional_values=maps_df, is_dataframe=True
            ).to_numpy()
        for column, name in enumerate(names):
            self.mesh.point_data[name] = vertex_maps[:, column]

    if isinstance(scalar_map, pd.DataFrame):
        # Work on a copy so the caller's DataFrame is never modified.
        if maps_names is None:
            _store_table(scalar_map.copy())
        else:
            _store_table(scalar_map[maps_names].copy())
        return

    if isinstance(scalar_map, np.ndarray):
        # A 1D array is a single map: turn it into one column.
        values = scalar_map[:, np.newaxis] if scalar_map.ndim == 1 else scalar_map
        if values.ndim != 2:
            raise ValueError("scalar_map arrays must be 1D or 2D")
        if maps_names is None:
            columns = [f"map_{i}" for i in range(values.shape[1])]
        elif len(maps_names) != values.shape[1]:
            raise ValueError(
                "Length of maps_names must match the number of columns in the numpy array"
            )
        else:
            columns = maps_names
        _store_table(pd.DataFrame(values, columns=columns))
        return

    if not isinstance(scalar_map, (str, Path)):
        raise ValueError(
            "scalar_map must be a file path, a numpy array or a pandas DataFrame"
        )
    map_file = str(scalar_map)
    if not os.path.isfile(map_file):
        raise FileNotFoundError(f"Map file not found: {map_file}")

    try:
        # First attempt: a delimited table file.
        maps_df = cltmisc.smart_read_table(map_file)
        if maps_names is not None:
            maps_df = maps_df[maps_names]
        _store_table(maps_df)
    except Exception:
        # Not usable as a table: read it as a FreeSurfer morphometry file.
        # Exceptions raised here keep the table error as their __context__.
        tmp_map = nib.freesurfer.read_morph_data(map_file)
        if tmp_map.shape[0] != self.mesh.n_points:
            raise ValueError(
                f"Map file {scalar_map} does not match the number of vertices"
            )
        if maps_names is None:
            # Full base name, e.g. "lh.thickness": splitting off the
            # "extension" would leave only the hemisphere prefix.
            map_name = os.path.basename(map_file)
        elif len(maps_names) != 1:
            raise ValueError(
                "maps_names must be a single string or a list with one name"
            )
        else:
            map_name = maps_names[0]
        self.mesh.point_data[map_name] = tmp_map
###############################################################################################
[docs]
def separate_mesh_components(
    self,
    component_labels: Optional[str] = "components",
    labels_to_extract: Optional[List[int]] = None,
    clean_mesh: bool = True,
    preserve_order: bool = False,
) -> List["Surface"]:
    """
    Split the surface into one ``Surface`` per component label.

    Each returned surface holds the vertices carrying one label, the
    triangles whose three corners all carry that label, and every point-data
    array of the original mesh restricted to those vertices.

    If ``component_labels`` is ``"components"`` and that array does not exist
    yet, connected components are first computed from the mesh edges. They
    are stored in ``self.mesh.point_data["components"]`` together with a
    colour table in ``self.colortables["components"]``.

    Parameters
    ----------
    component_labels : str, default="components"
        Name of the point-data array holding one label per vertex.
    labels_to_extract : list of int, optional
        Labels to extract, in the order the submeshes are returned. If None,
        every label present in the array is extracted in ascending order.
    clean_mesh : bool, default=True
        If True, each submesh keeps only its own vertices and face indices are
        renumbered to be consecutive. If False, original vertex indices are
        kept; points and point data are stored in zero-filled arrays of
        length ``max(component vertex index) + 1``.
    preserve_order : bool, default=False
        Kept for backward compatibility. Submeshes are always returned in the
        order of ``labels_to_extract`` (ascending when it is None).

    Returns
    -------
    list of Surface
        One ``Surface`` per requested label that has at least one vertex and
        one face. When ``self.colortables`` has an entry for
        ``component_labels``, the row(s) matching the label are copied to the
        submesh's ``colortables``.

    Raises
    ------
    TypeError
        If ``self.mesh`` is not a ``pv.PolyData``.
    ValueError
        If the mesh has no faces, or the label array is not 1D with one entry
        per vertex.
    KeyError
        If ``component_labels`` is not a point-data array (and is not the
        auto-computed ``"components"``).

    Warns
    -----
    UserWarning
        If non-triangular faces are detected, or a requested label has no
        vertices or no faces (that label is skipped).

    Notes
    -----
    Faces are read as triangles. Cell data is not carried over; mesh field
    data is copied to every submesh.

    Examples
    --------
    >>> surf = Surface("lh.pial")
    >>> parts = surf.separate_mesh_components()  # connected components
    >>> [part.mesh.n_points for part in parts]
    >>> # Only some labels, returned in the requested order
    >>> labels = surf.mesh.point_data["components"]
    >>> subset = surf.separate_mesh_components(labels_to_extract=[int(labels[0])])
    """
    if (
        component_labels == "components"
        and component_labels not in self.mesh.point_data
    ):
        from . import networktools as cltnet

        # No component labels yet: derive them from the mesh connectivity.
        edges = self.get_edges()
        csr_graph = cltnet.edges_to_csr(edges, n_vertices=self.mesh.n_points)
        _, raw_labels = cltnet.connected_components(csr_graph, return_labels=True)

        # Relabel every component with the label value of its row in a
        # distinguishable colour table, so labels and colours line up.
        component_ids, inverse = np.unique(np.asarray(raw_labels), return_inverse=True)
        n_components = len(component_ids)
        colors = cltcol.create_distinguishable_colors(n_components)
        ctab = cltcol.colors_to_table(colors, alpha_values=255)
        self.mesh.point_data["components"] = np.asarray(
            ctab[inverse.ravel(), 4], dtype=np.int32
        )
        self.colortables["components"] = create_surface_colortable(
            colors=colors,
            struct_names=[f"component_{i}" for i in range(n_components)],
        )

    mesh = self.mesh
    if not isinstance(mesh, pv.PolyData):
        raise TypeError("Input mesh must be a PyVista PolyData object")
    if mesh.n_cells == 0:
        raise ValueError("Mesh must contain faces")
    if component_labels not in mesh.point_data:
        raise KeyError(
            f"Component labels '{component_labels}' not found in mesh point data. "
            f"Available arrays: {list(mesh.point_data.keys())}"
        )

    labels = mesh.point_data[component_labels]
    if labels.ndim != 1 or len(labels) != mesh.n_points:
        raise ValueError(
            f"Component labels must be a 1D array with length {mesh.n_points}"
        )

    if labels_to_extract is None:
        # np.unique returns the labels in ascending order.
        labels_to_extract = np.unique(labels).tolist()

    if mesh.faces.size == 0:
        raise ValueError("Mesh contains no faces")
    # PyVista stores faces as a flat [n, v0, v1, v2, n, ...] array; this
    # reshape assumes every face is a triangle.
    faces_data = mesh.faces.reshape(-1, 4)
    if not np.all(faces_data[:, 0] == 3):
        warnings.warn(
            "Non-triangular faces detected. Only triangular faces are supported.",
            UserWarning,
        )
    triangle_faces = faces_data[:, 1:4]

    # Label of each triangle corner, computed once for all components.
    corner_labels = labels[triangle_faces]
    # Reusable old -> new vertex index lookup. Only the entries of the
    # current component are read, so values left by earlier ones are harmless.
    new_index = np.zeros(mesh.n_points, dtype=np.int64)

    has_colortable = component_labels in self.colortables
    submeshes = []
    for label in labels_to_extract:
        vertex_indices = np.flatnonzero(labels == label)
        if vertex_indices.size == 0:
            warnings.warn(
                f"No vertices found for component label {label}", UserWarning
            )
            continue

        # Keep only the triangles whose three corners carry this label.
        component_faces = triangle_faces[np.all(corner_labels == label, axis=1)]
        if len(component_faces) == 0:
            warnings.warn(
                f"No faces found for component label {label}", UserWarning
            )
            continue

        component_points = mesh.points[vertex_indices]
        if clean_mesh:
            new_index[vertex_indices] = np.arange(vertex_indices.size)
            remapped_faces = new_index[component_faces]
        else:
            # Keep original indices: pad points up to the largest index used.
            max_vertex_idx = vertex_indices.max()
            sparse_points = np.zeros((max_vertex_idx + 1, 3))
            sparse_points[vertex_indices] = component_points
            component_points = sparse_points
            remapped_faces = component_faces

        # Back to PyVista's flat [3, v0, v1, v2, ...] face layout.
        pyvista_faces = np.column_stack(
            [np.full(len(remapped_faces), 3), remapped_faces]
        ).ravel()
        submesh = pv.PolyData(component_points, pyvista_faces)

        for array_name, array_data in mesh.point_data.items():
            if clean_mesh:
                submesh.point_data[array_name] = array_data[vertex_indices]
            else:
                # Zero-padded copy aligned with the padded points; the
                # source dtype is kept (e.g. integer labels stay integers).
                padded = np.zeros(
                    (max_vertex_idx + 1,) + array_data.shape[1:],
                    dtype=array_data.dtype,
                )
                padded[vertex_indices] = array_data[vertex_indices]
                submesh.point_data[array_name] = padded

        if hasattr(mesh, "field_data"):
            submesh.field_data.update(mesh.field_data)

        subsurf_obj = Surface(submesh)
        if has_colortable:
            tmp_ctab = copy.deepcopy(self.colortables[component_labels])
            rows = np.flatnonzero(tmp_ctab["color_table"][:, 4] == label)
            if rows.size > 0:
                tmp_ctab["color_table"] = tmp_ctab["color_table"][rows, :]
                tmp_ctab["names"] = [tmp_ctab["names"][rows[0]]]
            else:
                # Label absent from the colour table: use a light grey entry.
                single_color_ctab = cltcol.colors_to_table(
                    np.array([[240, 240, 240]]), alpha_values=255
                )
                tmp_ctab["color_table"] = single_color_ctab
                # The table has a single row; column 4 holds its label value.
                tmp_ctab["names"] = [f"component_{single_color_ctab[0, 4]}"]
            subsurf_obj.colortables[component_labels] = tmp_ctab
        submeshes.append(subsurf_obj)

    # Submeshes already follow labels_to_extract (ascending when auto-derived).
    return submeshes
###############################################################################################
[docs]
def list_overlays(self) -> Dict[str, str]:
    """
    List every point-data array of the mesh together with a guessed type.

    Returns
    -------
    dict
        Mapping from array name to one of:

        - ``'scalar'``: 1D array without an entry in ``self.colortables``.
        - ``'scalar_with_colortable'``: 1D array with an entry in
          ``self.colortables`` (e.g. a parcellation).
        - ``'normals'``: ``(n_vertices, 3)`` array whose row norms round to 1
          and that contains negative values.
        - ``'color'``: any other ``(n_vertices, 3)`` array.
        - ``'color_with_alpha'``: ``(n_vertices, 4)`` array.
        - ``'unknown'``: anything else (other column counts, arrays with
          more than two dimensions, non-array data).

    Notes
    -----
    The classification is heuristic and based only on shape, the presence of
    a colour table and, for 3-column arrays, the row norms.

    Examples
    --------
    >>> overlays = surface.list_overlays()
    >>> scalar_maps = [k for k, v in overlays.items() if v == "scalar"]
    """
    overlays = {}
    for key in self.mesh.point_data.keys():
        tmp = self.mesh.point_data[key]
        if not isinstance(tmp, np.ndarray):
            overlays[key] = "unknown"
        elif tmp.ndim == 1:
            overlays[key] = (
                "scalar_with_colortable" if key in self.colortables else "scalar"
            )
        elif tmp.ndim == 2 and tmp.shape[1] == 3:
            # Unit-length rows with negative components look like normals;
            # colours are non-negative.
            unit_rows = np.all(np.round(np.linalg.norm(tmp, axis=1)) == 1)
            overlays[key] = "normals" if unit_rows and np.any(tmp < 0) else "color"
        elif tmp.ndim == 2 and tmp.shape[1] == 4:
            overlays[key] = "color_with_alpha"
        else:
            overlays[key] = "unknown"
    return overlays
##############################################################################################
[docs]
def set_active_overlay(self, overlay_name: str) -> None:
    """
    Make ``overlay_name`` the active scalar array of the mesh.

    Parameters
    ----------
    overlay_name : str
        Name of a point-data array of the mesh.

    Raises
    ------
    ValueError
        If no point-data array has that name.

    Notes
    -----
    Calls ``self.mesh.set_active_scalars`` and records the name in
    ``self.active_scalar``.

    Examples
    --------
    >>> surface.set_active_overlay("thickness")
    >>> surface.set_active_overlay("aparc")
    """
    available = self.mesh.point_data
    if overlay_name in available:
        self.mesh.set_active_scalars(overlay_name)
        self.active_scalar = overlay_name
        return
    raise ValueError(f"Overlay '{overlay_name}' not found in mesh point data")
##############################################################################################
[docs]
def remove_overlay(self, overlay_name: str) -> None:
    """
    Remove an overlay and its colour table from the surface.

    Deletes ``overlay_name`` from the mesh point data and from
    ``self.colortables``, whichever contain it. If the removed overlay was
    the active one, ``"default"`` is activated when present, otherwise the
    first remaining point-data array. ``self.active_scalar`` is updated to
    match, or set to None when no array remains.

    Parameters
    ----------
    overlay_name : str
        Name of the overlay to remove.

    Raises
    ------
    ValueError
        If the name is neither a point-data array nor a colour-table entry.

    Warns
    -----
    UserWarning
        If the replacement overlay cannot be activated. The removal itself
        still takes effect.

    Examples
    --------
    >>> surface.remove_overlay("thickness")
    >>> "thickness" in surface.list_overlays()
    False
    """
    if (
        overlay_name not in self.mesh.point_data
        and overlay_name not in self.colortables
    ):
        raise ValueError(f"Overlay '{overlay_name}' not found")

    # Check *before* deleting: removing the active array may clear the mesh's
    # active-scalars designation, hiding that this overlay was the active one.
    try:
        was_active = self.mesh.active_scalars_name == overlay_name
    except Exception:
        # Best effort only: an unreadable active name must not block removal.
        was_active = False
    was_active = was_active or getattr(self, "active_scalar", None) == overlay_name

    if overlay_name in self.mesh.point_data:
        del self.mesh.point_data[overlay_name]
    if overlay_name in self.colortables:
        del self.colortables[overlay_name]

    if not was_active:
        return

    remaining = list(self.mesh.point_data.keys())
    if "default" in remaining:
        replacement = "default"
    elif remaining:
        replacement = remaining[0]
    else:
        replacement = None

    if replacement is not None:
        try:
            self.mesh.set_active_scalars(replacement)
        except Exception as err:
            warnings.warn(
                f"Could not activate overlay '{replacement}' after removing "
                f"'{overlay_name}': {err}",
                UserWarning,
            )
            return
    self.active_scalar = replacement
##############################################################################################
[docs]
def get_overlay_info(self, overlay_name: str) -> Dict:
    """
    Summarise one overlay of the surface.

    Parameters
    ----------
    overlay_name : str
        Name of the overlay to query.

    Returns
    -------
    dict
        Always contains:

        - ``'name'``: the queried name.
        - ``'data_shape'``: shape of the point-data array, or None if the
          mesh has no such array.
        - ``'data_type'``: dtype of that array as a string, or None.
        - ``'has_colortable'``: whether ``self.colortables`` has an entry.

        When a colour table exists, it also contains ``'num_regions'``,
        ``'region_names'`` (the table's ``'names'``) and
        ``'has_annot_object'`` (whether the entry has an ``'annot_object'``).

    Warns
    -----
    UserWarning
        If the overlay is not a point-data array of the mesh. The info
        dictionary is still returned.

    Examples
    --------
    >>> info = surface.get_overlay_info("aparc")
    >>> print(f"Overlay has {info['num_regions']} regions")
    >>> print(f"Data type: {info['data_type']}")
    """
    has_colortable = overlay_name in self.colortables

    if overlay_name in self.mesh.point_data:
        data = self.mesh.point_data[overlay_name]
        data_shape, data_type = data.shape, str(data.dtype)
    else:
        warnings.warn(f"Overlay '{overlay_name}' not found", UserWarning)
        data_shape, data_type = None, None

    info = {
        "name": overlay_name,
        "data_shape": data_shape,
        "data_type": data_type,
        "has_colortable": has_colortable,
    }

    if has_colortable:
        ctable_info = self.colortables[overlay_name]
        info["num_regions"] = len(ctable_info["names"])
        info["region_names"] = ctable_info["names"]
        info["has_annot_object"] = "annot_object" in ctable_info
    return info
##############################################################################################
[docs]
def get_region_vertices(self, parc_name: str, region_name: str) -> np.ndarray:
    """
    Return the indices of the vertices belonging to one parcellation region.

    The region's label value is read from column 4 of the parcellation's
    colour table (row matching ``region_name`` in ``names``). The returned
    vertices are those whose value in ``self.mesh.point_data[parc_name]``
    equals that label.

    Parameters
    ----------
    parc_name : str
        Name of the parcellation (a colour-table entry and a point-data array).
    region_name : str
        Name of the region inside the parcellation.

    Returns
    -------
    np.ndarray
        Vertex indices of the region, in ascending order.

    Raises
    ------
    ValueError
        If the parcellation has no colour table or no per-vertex labels, or
        if the region is not listed in the parcellation.

    Examples
    --------
    >>> vertices = surface.get_region_vertices("aparc", "precentral")
    >>> print(f"Precentral region has {len(vertices)} vertices")
    """
    if parc_name not in self.colortables or parc_name not in self.mesh.point_data:
        raise ValueError(f"Parcellation '{parc_name}' not found")

    ctable_info = self.colortables[parc_name]
    # list() also accepts names stored as a numpy array.
    region_names = list(ctable_info["names"])
    if region_name not in region_names:
        raise ValueError(
            f"Region '{region_name}' not found in parcellation '{parc_name}'"
        )

    label_value = ctable_info["color_table"][region_names.index(region_name), 4]
    return np.where(self.mesh.point_data[parc_name] == label_value)[0]
##############################################################################################
[docs]
def get_vertexwise_colors(
    self,
    overlay_name: str = "default",
    colormap: str = "viridis",
    vmin: np.float64 = None,
    vmax: np.float64 = None,
    range_min: np.float64 = None,
    range_max: np.float64 = None,
    range_color: Tuple = (128, 128, 128, 255),
) -> np.ndarray:
    """
    Compute one colour per vertex for the given overlay.

    Overlays with an entry in ``self.colortables`` (e.g. parcellations) are
    coloured with ``cltcol.get_colors_from_colortable``. All other overlays
    go through ``cltcol.values2colors`` with ``output_format="rgb"``.

    Parameters
    ----------
    overlay_name : str, optional
        Point-data array to colour. If None, the first point-data array is
        used. Default is ``"default"``.
    colormap : str, optional
        Colormap for overlays without a colour table. Default is "viridis".
    vmin, vmax : float, optional
        Limits of the colormap scaling, forwarded to ``values2colors``.
    range_min, range_max : float, optional
        Bounds of the special colour range, forwarded to ``values2colors``.
    range_color : tuple, optional
        Colour for values outside ``(range_min, range_max)``, forwarded to
        ``values2colors``.

    Returns
    -------
    np.ndarray
        Vertex colours as produced by the colour-table or colormap helper.

    Raises
    ------
    ValueError
        If the mesh has no point-data arrays, or the overlay is not one of
        them.

    Examples
    --------
    >>> colors = surface.get_vertexwise_colors(overlay_name="aparc")
    >>> colors = surface.get_vertexwise_colors(overlay_name="thickness", colormap="hot")
    """
    overlays = list(self.mesh.point_data.keys())
    if not overlays:
        raise ValueError("No overlays available on this surface")
    if overlay_name is None:
        overlay_name = overlays[0]
    if overlay_name not in overlays:
        raise ValueError(
            f"Overlay '{overlay_name}' not found. Available overlays: {', '.join(overlays)}"
        )

    vertex_values = self.mesh.point_data[overlay_name]

    colortables = getattr(self, "colortables", {})
    if overlay_name in colortables:
        # Categorical data: use the discrete colours of its colour table.
        return cltcol.get_colors_from_colortable(
            vertex_values, colortables[overlay_name]["color_table"]
        )

    # Continuous data: map values through the colormap.
    return cltcol.values2colors(
        vertex_values,
        cmap=colormap,
        output_format="rgb",
        vmin=vmin,
        vmax=vmax,
        range_min=range_min,
        range_max=range_max,
        range_color=range_color,
    )
##############################################################################################
[docs]
def prepare_colors(
    self,
    overlay_name: str = None,
    cmap: str = "viridis",
    vmin: np.float64 = None,
    vmax: np.float64 = None,
    range_min: np.float64 = None,
    range_max: np.float64 = None,
    range_color: Tuple = (128, 128, 128, 255),
) -> None:
    """
    Prepare vertices colors for visualization based on the specified overlay.

    The overlay values are sanitized (NaN and +/-inf replaced by 0 and written
    back into ``self.mesh.point_data[overlay_name]``). Color limits are derived
    from the sanitized values when not given. The per-vertex colors are then
    computed with ``get_vertexwise_colors`` and stored in
    ``self.mesh.point_data["rgba"]``. Overlays with an entry in
    ``self.colortables`` are colored with their discrete color table; all other
    overlays are colored with ``cmap``.

    Parameters
    ----------
    overlay_name : str, optional
        Name of the overlay to visualize. If None, the first overlay returned
        by ``self.list_overlays()`` is used.
    cmap : str, optional
        Colormap for scalar overlays. If None, 'viridis' is used. Ignored for
        overlays that have an associated color table.
    vmin : np.float64, optional
        Lower colormap limit. If None, the minimum of the sanitized overlay.
    vmax : np.float64, optional
        Upper colormap limit. If None, the maximum of the sanitized overlay.
    range_min : np.float64, optional
        Lower bound of a special color range; values below it are colored
        with ``range_color``.
    range_max : np.float64, optional
        Upper bound of a special color range; values above it are colored
        with ``range_color``.
    range_color : Tuple, optional
        RGBA color used for values outside (range_min, range_max).

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``overlay_name`` is None and no overlays are available.
    ValueError
        If the overlay is not present in the mesh point data.

    Notes
    -----
    The mesh is modified in place: the overlay array is replaced by its
    sanitized copy and the colors are written under the "rgba" key. The
    mesh's active scalars are NOT changed by this method.

    Examples
    --------
    >>> # Prepare colors for a parcellation (uses discrete colors)
    >>> surface.prepare_colors(overlay_name="aparc")
    >>>
    >>> # Prepare colors for scalar data with custom colormap
    >>> surface.prepare_colors(overlay_name="thickness", cmap="hot")
    >>>
    >>> # Prepare colors for first available overlay
    >>> surface.prepare_colors()
    """
    if overlay_name is None:
        # Same fallback as get_vertexwise_colors: the first listed overlay.
        available = list(self.list_overlays().keys())
        if not available:
            raise ValueError("No overlays available to prepare colors")
        overlay_name = available[0]

    if overlay_name not in self.mesh.point_data:
        raise ValueError(
            f"Data array '{overlay_name}' not found in surface point_data"
        )

    if cmap is None:
        cmap = "viridis"

    # Replace NaNs and infinities with zero so they cannot poison the color
    # limits or the colormap lookup; the cleaned values are stored back.
    vertex_values = np.nan_to_num(
        np.asarray(self.mesh.point_data[overlay_name]),
        nan=0.0,
        posinf=0.0,
        neginf=0.0,
    )
    self.mesh.point_data[overlay_name] = vertex_values

    # Limits are derived after sanitization so they are always finite.
    if vmin is None:
        vmin = np.min(vertex_values)
    if vmax is None:
        vmax = np.max(vertex_values)

    # Apply colors to mesh data (keyword arguments guard against reordering).
    self.mesh.point_data["rgba"] = self.get_vertexwise_colors(
        overlay_name=overlay_name,
        colormap=cmap,
        vmin=vmin,
        vmax=vmax,
        range_min=range_min,
        range_max=range_max,
        range_color=range_color,
    )
##############################################################################################
[docs]
def add_surface(
    self,
    surf2add: Union[str, Path, "Surface", List[Union[str, Path, "Surface"]]],
) -> "Surface":
    """
    Merge this surface with one or more others into a new Surface.

    The geometries are combined with ``pyvista.merge``. Only point_data
    fields present in ALL input surfaces are retained, and only colortables
    belonging to those shared fields are carried over. The input meshes are
    deep-copied, so neither ``self`` nor the other surfaces are modified.

    Parameters
    ----------
    surf2add : str, Path, Surface, or list of these
        Surface(s) to add. File paths are loaded with ``Surface(path)``.
        A list/tuple may mix Surface objects and file paths.

    Returns
    -------
    Surface
        New merged Surface with ``hemi="unknown"`` and
        ``surf="merged_surface"``. Its ``active_scalar`` is this surface's
        active scalar; if that field was dropped by the merge, the first
        shared field (or None when there is none) is used instead.

    Raises
    ------
    TypeError
        If ``surf2add`` (or one of its items) is not a Surface, str, or Path.
    ValueError
        If an empty list is given, or a Path exists but is not a file.
    FileNotFoundError
        If a file path does not exist.

    Examples
    --------
    >>> # Merge left and right hemisphere surfaces
    >>> lh_surf = Surface("lh.pial")
    >>> rh_surf = Surface("rh.pial")
    >>> merged = lh_surf.add_surface(rh_surf)
    >>> print(f"Merged surface has {merged.mesh.n_points} vertices")
    >>>
    >>> # Merge several surfaces, mixing objects and file paths
    >>> surf1 = Surface("surface1.pial")
    >>> merged = surf1.add_surface([rh_surf, "surface2.pial"])
    """

    def _resolve(item, index):
        # Return a Surface for one input item, loading it from disk if needed.
        if isinstance(item, Surface):
            return item
        if isinstance(item, str):
            if not os.path.isfile(item):
                raise FileNotFoundError(f"File '{item}' not found")
            return Surface(item)
        if isinstance(item, Path):
            if not item.exists():
                raise FileNotFoundError(f"Path '{str(item)}' does not exist")
            if not item.is_file():
                raise ValueError(f"Path '{str(item)}' is not a file")
            return Surface(item)
        raise TypeError(f"Item at index {index} is not a Surface object")

    if isinstance(surf2add, (str, Path, Surface)):
        items = [surf2add]
    elif isinstance(surf2add, (list, tuple)):
        items = list(surf2add)
    else:
        raise TypeError(
            "surf2add must be a Surface, a file path, or a list of them"
        )
    if len(items) == 0:
        raise ValueError("surfaces list cannot be empty")

    others = [_resolve(item, i) for i, item in enumerate(items)]
    all_surfaces = [self] + others

    # Fields shared by every surface; ordered like this surface's point_data
    # so the result is deterministic.
    common_set = set(self.mesh.point_data.keys())
    for surf in others:
        common_set &= set(surf.mesh.point_data.keys())
    common_fields = [k for k in self.mesh.point_data.keys() if k in common_set]

    # Copy each mesh and strip the fields that are not shared.
    meshes_to_merge = []
    for surf in all_surfaces:
        mesh_copy = copy.deepcopy(surf.mesh)
        for field in set(mesh_copy.point_data.keys()) - common_set:
            del mesh_copy.point_data[field]
        meshes_to_merge.append(mesh_copy)

    # There are always at least two meshes here (self + a non-empty list).
    merged_mesh = pv.merge(meshes_to_merge)

    # Create new Surface object without calling __init__
    merged_surface = Surface.__new__(Surface)
    merged_surface.mesh = merged_mesh
    merged_surface.hemi = "unknown"
    merged_surface.surf = "merged_surface"

    # Keep only colortables of shared fields; the first one encountered wins.
    merged_colortables = {}
    for surf in all_surfaces:
        for key, value in getattr(surf, "colortables", {}).items():
            if key not in common_set or key in merged_colortables:
                continue
            if isinstance(value, dict):
                # Copy array-like entries so the merged surface owns its data.
                merged_colortables[key] = {
                    k: v.copy() if isinstance(v, (list, np.ndarray)) else v
                    for k, v in value.items()
                }
            else:
                merged_colortables[key] = value
    merged_surface.colortables = merged_colortables

    # Never point the active scalar at a field removed by the merge.
    active = getattr(self, "active_scalar", None)
    if active is not None and active not in common_set:
        active = common_fields[0] if common_fields else None
    merged_surface.active_scalar = active

    return merged_surface
#############################################################################################
[docs]
def map_volume_to_surface(
    self,
    image: Union[str, Path, np.ndarray, "nib.Nifti1Image"],
    method: str = "nilearn",
    interp_method: str = "linear",
    overlay_name: str = None,
) -> np.ndarray:
    """
    Map volumetric neuroimaging data onto the vertices of this surface.

    Projects 3D or 4D volumetric data (e.g. structural or functional MRI)
    onto ``self.mesh.points`` using spatial interpolation.

    Parameters
    ----------
    image : str, Path, np.ndarray, or nibabel image
        Input volumetric data to project. Can be:
        - str / Path: file path to an image readable by ``nibabel.load``
          (.nii, .nii.gz, .mgz, ...)
        - numpy.ndarray: 3D or 4D array of volumetric data (no affine;
          only usable with method="clabtoolkit")
        - nibabel.Nifti1Image, Nifti2Image or MGHImage: loaded image object
    method : str, default "nilearn"
        Projection method:
        - "nilearn": uses ``nilearn.surface.vol_to_surf``
        - "clabtoolkit": uses ``imagetools.interpolate``
    interp_method : str, default "linear"
        Interpolation method, e.g. "linear" or "nearest".
    overlay_name : str, optional
        If given, the projected data is also stored in
        ``self.mesh.point_data[overlay_name]`` (for either method).

    Returns
    -------
    np.ndarray
        Projected surface data:
        - 3D input: array of shape (n_vertices,)
        - 4D input: array of shape (n_vertices, n_timepoints)

    Raises
    ------
    FileNotFoundError
        If the image file path does not exist.
    ValueError
        If ``method`` is unknown, ``overlay_name`` is not a string, the
        image type is unsupported, a numpy array is not 3D/4D, or a numpy
        array is used with method="nilearn".
    ImportError
        If nilearn is required but not installed.

    Notes
    -----
    - nilearn expects vertices in world coordinates (mm) and handles the
      transformation using the image affine.
    - With "clabtoolkit", vertices are converted to voxel coordinates with
      ``imagetools.mm2vox`` when an affine is available; for a plain numpy
      array the vertices are assumed to already be in voxel space.
    - 4D data is interpolated one timepoint at a time with "clabtoolkit".

    Examples
    --------
    Project a structural image onto the cortical surface:

    >>> surface_data = surf_obj.map_volume_to_surface("T1w.nii.gz", method="nilearn")
    >>> print(f"Projected {len(surface_data)} vertex values")

    Project 4D fMRI data for time series analysis:

    >>> fmri_surface = surf_obj.map_volume_to_surface("task_fmri.nii.gz")
    >>> vertex_timeseries = fmri_surface[1000, :]

    Use a numpy array with the clabtoolkit method and store the result:

    >>> volume_data = np.random.rand(64, 64, 30)
    >>> surface_data = surf_obj.map_volume_to_surface(
    ...     volume_data, method="clabtoolkit", overlay_name="random_data"
    ... )

    See Also
    --------
    nilearn.surface.vol_to_surf : Underlying nilearn projection function
    clabtoolkit.imagetools.interpolate : Underlying clabtoolkit interpolation
    """
    # Validate cheap arguments before loading a potentially large image.
    if method not in ("nilearn", "clabtoolkit"):
        raise ValueError(
            f"Unknown projection method: '{method}'. "
            f"Supported methods: 'nilearn', 'clabtoolkit'"
        )
    if overlay_name is not None and not isinstance(overlay_name, str):
        raise ValueError("overlay_name must be a string")

    # Surface vertices in world coordinates (mm)
    vertices = self.mesh.points

    img = None  # nibabel image (carries the affine), when available
    array_data = None  # raw numpy volume, when given directly
    if isinstance(image, Path):
        image = str(image)  # nibabel expects a str path
    if isinstance(image, str):
        if not os.path.isfile(image):
            raise FileNotFoundError(f"Image file not found: {image}")
        img = nib.load(image)
    elif isinstance(image, np.ndarray):
        if image.ndim not in (3, 4):
            raise ValueError("Input numpy array must be 3D or 4D")
        array_data = image
    elif isinstance(image, (nib.Nifti1Image, nib.Nifti2Image, nib.MGHImage)):
        img = image
    else:
        raise ValueError(
            "Image must be a file path (str or Path), a nibabel image object, "
            "or numpy.ndarray"
        )

    if method == "nilearn":
        if img is None:
            raise ValueError(
                "nilearn method requires a nibabel image object with coordinate "
                "information. Cannot use with plain numpy arrays."
            )
        try:
            from nilearn import surface as nlsurf
        except ImportError as err:
            raise ImportError(
                "nilearn is required for this projection method. "
                "Install with: pip install nilearn"
            ) from err

        # (n_faces, 3) connectivity, consistent with the exporters.
        faces = self.mesh.regular_faces
        # nilearn converts mm -> voxel coordinates internally via the affine.
        interpolated_data = nlsurf.vol_to_surf(
            img,
            (vertices, faces),
            interpolation=interp_method,
        )
    else:
        from . import imagetools as cltimg

        if img is not None:
            scalar_data = img.get_fdata()
            vertices_vox = cltimg.mm2vox(vertices, img.affine)
        else:
            # No affine: assume vertices are already in voxel space.
            scalar_data = array_data
            vertices_vox = vertices

        if scalar_data.ndim == 4:
            # Interpolate each timepoint separately.
            n_timepoints = scalar_data.shape[3]
            interpolated_data = np.zeros((len(vertices_vox), n_timepoints))
            for t in range(n_timepoints):
                interpolated_data[:, t] = cltimg.interpolate(
                    scalar_data[:, :, :, t],
                    vertices_vox,
                    interp_method=interp_method,
                )
        else:
            interpolated_data = cltimg.interpolate(
                scalar_data, vertices_vox, interp_method=interp_method
            )

    if overlay_name is not None:
        self.mesh.point_data[overlay_name] = interpolated_data

    return interpolated_data
##############################################################################################
[docs]
def save_surface(
    self,
    filename: str,
    format: str = "freesurfer",
    save_annotation: str = None,
    map_name: str = None,
    overwrite: bool = False,
) -> None:
    """
    Save the surface mesh to a file in the specified format.

    Dispatches to ``export_to_freesurfer``, ``export_to_obj`` or
    ``export_to_pyvista``. For 'vtk', 'ply' and 'stl' the output extension
    is forced to match ``format``, because PyVista picks the writer from the
    extension. A different mesh extension (.vtk/.ply/.stl) is replaced;
    any other extension (e.g. 'lh.pial') gets the format's extension
    appended. The overwrite check is applied to this final filename.

    Parameters
    ----------
    filename : str
        Output filename with or without extension.
    format : str, default "freesurfer"
        Output format (case-insensitive): 'freesurfer', 'vtk', 'ply', 'stl',
        or 'obj'.
    save_annotation : str, optional
        Path to save an annotation file (for parcellation data).
    map_name : str, optional
        Name of overlay/parcellation to include with the surface data.
    overwrite : bool, default False
        Whether to overwrite existing files.

    Raises
    ------
    ValueError
        If filename is invalid, format is unsupported, or the output file
        exists and overwrite is False.
    FileNotFoundError
        If the output directory does not exist.

    Examples
    --------
    >>> surface.save_surface("lh.pial.vtk", format="vtk")
    >>> surface.save_surface("cortex.ply", format="ply", overwrite=True)
    """
    if not isinstance(filename, str):
        raise ValueError("filename must be a string")
    if not filename:
        raise ValueError("filename cannot be empty")

    fmt = format.lower()
    if fmt not in ("freesurfer", "obj", "vtk", "ply", "stl"):
        raise ValueError(
            f"Unsupported file format: {format}. Supported formats are 'vtk', 'ply', 'stl', 'obj', and 'freesurfer'."
        )

    if fmt in ("vtk", "ply", "stl"):
        # Make the extension match the requested format before any checks.
        target_ext = "." + fmt
        root, ext = os.path.splitext(filename)
        if ext.lower() in (".vtk", ".ply", ".stl"):
            if ext.lower() != target_ext:
                filename = root + target_ext
        else:
            filename += target_ext

    # Check the name that will actually be written.
    if os.path.exists(filename) and not overwrite:
        raise ValueError(
            f"File '{filename}' already exists. Please set overwrite to True or choose a different name."
        )
    directory = os.path.dirname(filename)
    if directory and not os.path.exists(directory):
        raise FileNotFoundError(f"Directory '{directory}' does not exist")

    if fmt == "freesurfer":
        self.export_to_freesurfer(filename, save_annotation, map_name, overwrite)
    elif fmt == "obj":
        self.export_to_obj(filename, save_annotation, map_name, overwrite)
    else:
        self.export_to_pyvista(filename, save_annotation, map_name, overwrite)
        print(f"Surface saved to {filename}")
##############################################################################################
[docs]
def export_to_obj(
    self,
    filename: str,
    save_annotation: str = None,
    map_name: str = None,
    overwrite: bool = False,
) -> None:
    """
    Write the surface geometry as a Wavefront OBJ text file.

    The file contains a short comment header, one ``v x y z`` line per
    vertex (6 decimals), a blank line, then one ``f i j k`` line per face
    using OBJ's 1-based vertex indices (the first three indices of each row
    of ``self.mesh.regular_faces``).

    Parameters
    ----------
    filename : str
        Output filename, normally ending in .obj.
    save_annotation : str, optional
        If given, ``export_annotation`` is called first to write this
        annotation file for ``map_name``.
    map_name : str, optional
        Parcellation name forwarded to ``export_annotation``.
    overwrite : bool, default False
        Whether existing output files may be replaced.

    Raises
    ------
    ValueError
        If filename is not a non-empty string, or it exists and overwrite
        is False.
    FileNotFoundError
        If the output directory does not exist.

    Examples
    --------
    >>> surface.export_to_obj("brain_surface.obj")
    >>> surface.export_to_obj("lh.pial.obj", save_annotation="lh.aparc.annot", map_name="aparc")
    """
    # Output-path guards
    if not isinstance(filename, str):
        raise ValueError("filename must be a string")
    if not filename:
        raise ValueError("filename cannot be empty")
    if not overwrite and os.path.exists(filename):
        raise ValueError(
            f"File '{filename}' already exists. Please set overwrite to True or choose a different name."
        )
    parent_dir = os.path.dirname(filename)
    if parent_dir and not os.path.exists(parent_dir):
        raise FileNotFoundError(f"Directory '{parent_dir}' does not exist")

    if save_annotation is not None:
        self.export_annotation(
            filename=save_annotation, parc_name=map_name, overwrite=overwrite
        )

    coords = self.mesh.points
    triangles = self.mesh.regular_faces

    with open(filename, "w") as handle:
        handle.writelines(
            [
                "# OBJ file exported from Surface class\n",
                f"# Vertices: {len(coords)}\n",
                f"# Faces: {len(triangles)}\n\n",
            ]
        )
        handle.writelines(
            "v {:.6f} {:.6f} {:.6f}\n".format(p[0], p[1], p[2]) for p in coords
        )
        handle.write("\n")
        # OBJ indices start at 1
        handle.writelines(
            "f {} {} {}\n".format(t[0] + 1, t[1] + 1, t[2] + 1) for t in triangles
        )

    print(f"Surface exported to {filename}")
[docs]
def export_to_pyvista(
    self,
    filename: str,
    save_annotation: str = None,
    map_name: str = None,
    overwrite: bool = False,
) -> None:
    """
    Save the surface with PyVista (VTK, PLY, STL, ... chosen by extension).

    When ``map_name`` is given, that point-data array is made the active
    scalars and ``prepare_colors`` is run for it (with ``cmap=None``) before
    ``self.mesh.save(filename)`` is called. Both steps modify ``self.mesh``.

    Parameters
    ----------
    filename : str
        Output filename with extension (.vtk, .ply, or .stl).
    save_annotation : str, optional
        If given, ``export_annotation`` is called first to write this
        annotation file for ``map_name``.
    map_name : str, optional
        Point-data array to activate and color before saving.
    overwrite : bool, default False
        Whether existing output files may be replaced.

    Raises
    ------
    ValueError
        If filename is not a non-empty string, it exists and overwrite is
        False, or map_name is not a non-empty string present in point data.
    FileNotFoundError
        If the output directory does not exist.

    Notes
    -----
    Whether the prepared "rgba" colors end up in the file depends on the
    writer PyVista selects for the extension.

    Examples
    --------
    >>> surface.export_to_pyvista("brain.vtk")
    >>> surface.export_to_pyvista("surface.ply", map_name="thickness")
    """
    # Output-path guards
    if not isinstance(filename, str):
        raise ValueError("filename must be a string")
    if not filename:
        raise ValueError("filename cannot be empty")
    if not overwrite and os.path.exists(filename):
        raise ValueError(
            f"File '{filename}' already exists. Please set overwrite to True or choose a different name."
        )
    target_dir = os.path.dirname(filename)
    if target_dir and not os.path.exists(target_dir):
        raise FileNotFoundError(f"Directory '{target_dir}' does not exist")

    if save_annotation is not None:
        self.export_annotation(
            filename=save_annotation, parc_name=map_name, overwrite=overwrite
        )

    if map_name is not None:
        # map_name must be a non-empty string naming an existing array.
        if not isinstance(map_name, str):
            raise ValueError("map_name must be a string")
        if not map_name:
            raise ValueError("map_name cannot be empty")
        if map_name not in self.mesh.point_data:
            raise ValueError(f"Map '{map_name}' not found in mesh point data")

        self.mesh.set_active_scalars(map_name)
        color_options = dict(
            overlay_name=map_name,
            cmap=None,
            vmin=None,
            vmax=None,
            range_min=None,
            range_max=None,
            range_color=(128, 128, 128, 255),
        )
        self.prepare_colors(**color_options)

    # One save call covers both the plain and the colored export.
    self.mesh.save(filename)
##############################################################################################
[docs]
def export_to_freesurfer(
    self,
    filename: str,
    save_annotation: str = None,
    map_name: str = None,
    overwrite: bool = False,
) -> None:
    """
    Write the surface geometry in FreeSurfer's binary surface format.

    Vertices (``self.mesh.points``) and triangles
    (``self.mesh.regular_faces``) are written with
    ``nibabel.freesurfer.write_geometry``. Point data is not stored in this
    format; use ``save_annotation`` to write a parcellation alongside it.

    Parameters
    ----------
    filename : str
        Output filename, typically without extension (e.g. 'lh.pial').
    save_annotation : str, optional
        If given, ``export_annotation`` is called first to write this
        annotation file for ``map_name``.
    map_name : str, optional
        Parcellation name forwarded to ``export_annotation``.
    overwrite : bool, default False
        Whether existing output files may be replaced.

    Raises
    ------
    ValueError
        If filename is not a non-empty string, or it exists and overwrite
        is False.
    FileNotFoundError
        If the output directory does not exist.

    Examples
    --------
    >>> surface.export_to_freesurfer("lh.pial")
    >>> surface.export_to_freesurfer("rh.white", save_annotation="rh.aparc.annot", map_name="aparc")
    """
    # Output-path guards
    if not isinstance(filename, str):
        raise ValueError("filename must be a string")
    if not filename:
        raise ValueError("filename cannot be empty")
    if not overwrite and os.path.exists(filename):
        raise ValueError(
            f"File '{filename}' already exists. Please set overwrite to True or choose a different name."
        )
    out_dir = os.path.dirname(filename)
    if out_dir and not os.path.exists(out_dir):
        raise FileNotFoundError(f"Directory '{out_dir}' does not exist")

    if save_annotation is not None:
        self.export_annotation(
            filename=save_annotation, parc_name=map_name, overwrite=overwrite
        )

    # Geometry only: coordinates first, then triangle connectivity.
    nib.freesurfer.write_geometry(
        filename, self.mesh.points, self.mesh.regular_faces
    )
##############################################################################################
[docs]
def export_annotation(
    self,
    filename: str,
    parc_name: str,
    overwrite: bool = False,
) -> None:
    """
    Export a parcellation overlay to a FreeSurfer annotation file.

    The vertex labels in ``self.mesh.point_data[parc_name]`` are written
    together with the color table and region names stored in
    ``self.colortables[parc_name]`` (keys "color_table" and "names"), using
    ``freesurfertools.AnnotParcellation``.

    Parameters
    ----------
    filename : str
        Output filename for the annotation file (should end with .annot).
    parc_name : str
        Name of the parcellation overlay to export.
    overwrite : bool, default False
        Whether an existing annotation file may be replaced.

    Raises
    ------
    ValueError
        If filename or parc_name is not a non-empty string, the parcellation
        is not in the mesh point data, or the file exists and overwrite is
        False.
    FileNotFoundError
        If the output directory does not exist.

    Warns
    -----
    UserWarning
        If ``parc_name`` has no associated colortable; nothing is written.

    Examples
    --------
    >>> surface.export_annotation("lh.aparc.annot", "aparc")
    >>> surface.export_annotation("rh.destrieux.annot", "destrieux", overwrite=True)
    """
    if not isinstance(filename, str):
        raise ValueError("The annotation filename must be a string")
    if not filename:
        raise ValueError("The annotation filename cannot be empty")
    if os.path.isfile(filename) and not overwrite:
        raise ValueError(
            f"Annotation file '{filename}' already exists. Please set overwrite to True or choose a different name."
        )
    annot_directory = os.path.dirname(filename)
    if annot_directory and not os.path.exists(annot_directory):
        raise FileNotFoundError(f"Directory '{annot_directory}' does not exist")

    if not isinstance(parc_name, str):
        raise ValueError("parc_name must be a string")
    if not parc_name:
        raise ValueError("parc_name cannot be empty")
    if parc_name not in self.mesh.point_data:
        raise ValueError(f"Map '{parc_name}' not found in mesh point data")

    # Without a colortable there are no region names/colors to write; skip the
    # export but make the skip visible through the warnings machinery.
    colortables = getattr(self, "colortables", None) or {}
    if parc_name not in colortables:
        warnings.warn(
            f"No colortable found for map '{parc_name}'. Annotation file will not be saved.",
            UserWarning,
            stacklevel=2,
        )
        return

    maps_array = self.mesh.point_data[parc_name]
    ctable_info = colortables[parc_name]

    annot_obj = cltfree.AnnotParcellation()
    annot_obj.create_from_data(
        maps_array,
        ctable_info["color_table"],
        ctable_info["names"],
        annot_id=parc_name,
    )
    annot_obj.save_annotation(filename, force=overwrite)
###############################################################################################
[docs]
def plot(
    self,
    overlay_name: str = None,
    cmap: str = "viridis",
    vmin: np.float64 = None,
    vmax: np.float64 = None,
    range_min: np.float64 = None,
    range_max: np.float64 = None,
    range_color: Tuple = (128, 128, 128, 255),
    use_opacity: bool = False,
    views: Union[str, List[str]] = None,
    views_orientation: str = "grid",
    hemi: str = "lh",
    notebook: bool = False,
    show_colorbar: Optional[bool] = None,
    colorbar_title: str = None,
    colorbar_position: str = "bottom",
    save_path: str = None,
):
    """
    Plot the surface with the given overlay using ``visualizationtools.BrainPlotter``.

    Parameters
    ----------
    overlay_name : str, optional
        Name of the overlay to visualize. If None, ``self.active_scalar``
        is used.
    cmap : str, default "viridis"
        Colormap for scalar data. Passed to the plotter as ``colormaps``.
    vmin, vmax : float, optional
        Colormap limits, passed as ``v_limits=(vmin, vmax)``.
    range_min, range_max : float, optional
        Data range passed as ``v_range``; values outside it are displayed in
        ``range_color``.
    range_color : Tuple, default (128, 128, 128, 255)
        RGBA color for values outside the data range.
    use_opacity : bool, default False
        Whether to apply opacity based on data values.
    views : str or List[str], optional
        Camera view(s), e.g. 'lateral', 'medial', ['lateral', 'medial'], or
        a preset such as '4_views'. Defaults to ["lateral"].
    views_orientation : str, default "grid"
        Layout of multiple views, forwarded to the plotter.
    hemi : str, default "lh"
        Hemisphere identifier passed as ``hemi_id``.
    notebook : bool, default False
        Whether to render inside a Jupyter notebook.
    show_colorbar : bool, optional
        Explicit True/False is honored. If None (default), a colorbar is
        shown unless ``cmap`` is None and the overlay has a colortable
        (i.e. categorical parcellation data).
    colorbar_title : str, optional
        Title for the colorbar.
    colorbar_position : str, default "bottom"
        Colorbar position: 'bottom', 'top', 'left', or 'right'.
    save_path : str, optional
        Path to save the plot as an image.

    Returns
    -------
    object
        Whatever ``BrainPlotter.plot`` returns (assumed to be the plotter
        or None — TODO confirm against visualizationtools).

    Examples
    --------
    >>> surface.plot(overlay_name="aparc", cmap=None)
    >>> surface.plot(overlay_name="thickness", cmap="hot", views="medial", show_colorbar=True)
    """
    # Avoid a shared mutable default list.
    if views is None:
        views = ["lateral"]

    if overlay_name is None:
        overlay_name = self.active_scalar

    if show_colorbar is None:
        # Automatic choice: hide the colorbar only for color-table overlays
        # rendered without an explicit colormap.
        colortables = getattr(self, "colortables", None) or {}
        show_colorbar = not (cmap is None and overlay_name in colortables)

    from . import visualizationtools as cltvis

    plotter = cltvis.BrainPlotter()
    return plotter.plot(
        self,
        hemi_id=hemi,
        views=views,
        views_orientation=views_orientation,
        map_names=overlay_name,
        colormaps=cmap,
        v_limits=(vmin, vmax),
        use_opacity=use_opacity,
        range_color=range_color,
        v_range=(range_min, range_max),
        notebook=notebook,
        colorbar=show_colorbar,
        colorbar_titles=colorbar_title,
        colorbar_position=colorbar_position,
        save_path=save_path,
    )
#################################################################################################
[docs]
def merge_surfaces(
    surfaces: List[Union[str, Path, Surface]],
    color_table: Optional[dict] = None,
    map_name: str = "surf_id",
) -> Union[Surface, None]:
    """
    Merge multiple surface meshes into a single surface with distinct region IDs.

    The first input is loaded (or deep-copied, if it is already a Surface) and
    every following input is appended to it with ``Surface.add_surface``. Each
    vertex of the merged mesh is then labelled, in the point-data map
    ``map_name``, with the ID (column 4 of the colortable array) of the input
    surface it came from. The colortable is registered under
    ``merged.colortables[map_name]``.

    Parameters
    ----------
    surfaces : List[Union[str, Path, Surface]]
        List of surface meshes to merge. Each item can be:
        - Surface object
        - File path (str or Path) to a surface file
    color_table : dict, optional
        Colortable dictionary defining names and colors for each surface.
        If None, a default colortable with distinguishable colors is created
        (RGB rescaled to 0-1, alpha 1, IDs 1..N, names ``mesh_0``..``mesh_{N-1}``).
        When provided, it must have the following structure:
        {
            "names": List[str],           # one name per surface
            "color_table": array-like,    # N x >=5 rows: R, G, B, A, ID
            "lookup_table": ...           # optional, copied through unchanged
        }
    map_name : str, default "surf_id"
        Name of the point-data map that stores the per-vertex surface IDs.

    Returns
    -------
    Union[Surface, None]
        Merged Surface object with distinct region IDs. Returns None if
        ``surfaces`` is empty, or if loading/merging any surface raises. In the
        latter case a warning is emitted.

    Raises
    ------
    TypeError
        If surfaces is not a list, contains invalid items, or color_table is
        not a dictionary.
    ValueError
        If color_table lacks the required keys, its number of names/rows does
        not match the number of surfaces, or its array is not 2-D with at
        least 5 columns.

    Examples
    --------
    Merge multiple surface files with a custom colortable:

    >>> surfaces = ["lh.pial", "rh.pial", "lh.white", "rh.white"]
    >>> color_table = {
    ...     "names": ["Left Pial", "Right Pial", "Left White", "Right White"],
    ...     "color_table": np.array([[255, 0, 0, 255, 1],
    ...                              [0, 255, 0, 255, 2],
    ...                              [0, 0, 255, 255, 3],
    ...                              [255, 255, 0, 255, 4]])
    ... }
    >>> merged_surface = merge_surfaces(surfaces, color_table=color_table, map_name="region_id")
    >>> ids = merged_surface.mesh.point_data["region_id"]  # per-vertex IDs 1..4
    """
    if not isinstance(surfaces, list):
        raise TypeError("surfaces must be a list")
    if any(not isinstance(surf, (str, Path, Surface)) for surf in surfaces):
        raise TypeError(
            "All items in surfaces must be Surface objects, file paths, or Path objects"
        )
    if not surfaces:
        return None

    n_surfaces = len(surfaces)

    # Resolve the colortable BEFORE doing any (potentially expensive) merging,
    # so malformed input fails fast instead of after all meshes are loaded.
    if color_table is not None:
        if not isinstance(color_table, dict):
            raise TypeError("color_table must be a dictionary")
        required_keys = ["names", "color_table"]
        if not all(key in color_table for key in required_keys):
            raise ValueError(f"color_table must contain the keys: {required_keys}")
        if len(color_table["names"]) != n_surfaces:
            raise ValueError(
                "Length of 'names' in color_table must match number of surfaces"
            )
        color_table_array = np.asarray(color_table["color_table"])
        # Column 4 is read below as the per-surface ID.
        if color_table_array.ndim != 2 or color_table_array.shape[1] < 5:
            raise ValueError(
                "'color_table' must be a 2-D array with at least 5 columns (R, G, B, A, ID)"
            )
        if color_table_array.shape[0] != n_surfaces:
            raise ValueError(
                "Number of rows in 'color_table' must match number of surfaces"
            )
        # Shallow copy so the stored entry always holds an ndarray and is not
        # aliased to the caller's dict.
        color_table_dict = {**color_table, "color_table": color_table_array}
    else:
        colors = cltcol.create_distinguishable_colors(n_surfaces)
        # Force a float table: rescaling RGB in place on an integer array would
        # silently truncate every channel to 0 or 1.
        color_table_array = np.asarray(
            cltcol.colors_to_table(
                colors=colors, alpha_values=1, values=range(n_surfaces)
            ),
            dtype=float,
        )
        color_table_array[:, :3] /= 255
        color_table_array[:, 4] = np.arange(1, n_surfaces + 1)
        color_table_dict = {
            "names": [f"mesh_{i}" for i in range(n_surfaces)],
            "color_table": color_table_array,
            "lookup_table": None,
        }

    # (start, end) vertex-index range contributed by each input surface.
    point_ranges = []
    try:
        for i, surf in enumerate(surfaces):
            if i == 0:
                # Deep-copy so the caller's Surface is never mutated by the merge.
                merged = (
                    copy.deepcopy(surf) if isinstance(surf, Surface) else Surface(surf)
                )
                point_ranges.append((0, merged.mesh.n_points))
                continue
            n_points_before = merged.mesh.n_points
            result = merged.add_surface(surf)
            # add_surface may work in place (returning None) or return a new
            # object; support both.
            if result is not None:
                merged = result
            point_ranges.append((n_points_before, merged.mesh.n_points))
    except Exception as e:
        # Best-effort contract: report the failure and return None.
        warnings.warn(f"Merge failed: {e}")
        return None

    # Label every vertex with the ID of the surface it came from.
    surf_ids = np.zeros((merged.mesh.n_points, 1))
    for i, (start, end) in enumerate(point_ranges):
        surf_ids[start:end] = color_table_array[i, 4]

    merged.mesh.point_data[map_name] = surf_ids
    merged.colortables[map_name] = color_table_dict
    return merged
################################################################################################
[docs]
def create_surface_colortable(
    colors: Union[str, List[str], np.ndarray],
    struct_names: Optional[List[str]] = None,
    alpha: float = 1.0,
) -> dict:
    """
    Create a colortable dictionary used for surface parcellation and visualization.

    Colors are normalized with ``cltcol.harmonize_colors`` and tabulated with
    ``cltcol.colors_to_table``. The RGB columns of the resulting table are then
    rescaled from 0-255 to 0-1.

    Parameters
    ----------
    colors : str, List[str], or np.ndarray
        Colors for each region. Can be:
        - String: Single color name or hex code
        - List of strings: Color names or hex codes for each region
        - numpy.ndarray: Nx3 or Nx4 array of RGB(A) values (0-255)
    struct_names : List[str], optional
        Region names corresponding to the colors (one per table row). If None,
        names ``region_1``, ``region_2``, ... are generated.
    alpha : float, default 1.0
        Alpha transparency in [0.0, 1.0], forwarded to ``colors_to_table``.
        Integers are converted to float.

    Returns
    -------
    dict
        Dictionary with keys:
        - 'names': list of region names
        - 'color_table': float numpy array. The first three columns are RGB
          in the 0-1 range. The remaining columns come from
          ``colors_to_table`` (alpha, and possibly an ID column — confirm
          against that helper).
        - 'lookup_table': None (placeholder)

    Raises
    ------
    ValueError
        If alpha is outside [0, 1] (or NaN), or if the length of struct_names
        does not match the number of colors.

    Examples
    --------
    >>> ct = create_surface_colortable(
    ...     ["red", "green", "blue"], struct_names=["Region A", "Region B", "Region C"], alpha=0.8
    ... )
    >>> ct["names"]
    ['Region A', 'Region B', 'Region C']
    >>> ct["lookup_table"] is None
    True
    """
    # Validate alpha value
    if isinstance(alpha, int):
        alpha = float(alpha)
    # Chained comparison is also False for NaN, so NaN is rejected here too.
    if not (0 <= alpha <= 1):
        raise ValueError(f"Alpha value must be in the range [0, 1], got {alpha}")

    # Handle color input
    colors = cltcol.harmonize_colors(colors, output_format="rgb")
    # Force a float table: rescaling RGB in place on an integer array would
    # silently truncate every channel to 0 or 1.
    tmp_ctable = np.asarray(
        cltcol.colors_to_table(colors=colors, alpha_values=alpha), dtype=float
    )
    tmp_ctable[:, :3] /= 255  # RGB from 0-255 to 0-1

    n_rows = tmp_ctable.shape[0]
    if struct_names is None:
        struct_names = [f"region_{i + 1}" for i in range(n_rows)]
    elif len(struct_names) != n_rows:
        raise ValueError(
            "Length of struct_names must match number of colors in the colortable"
        )

    return {
        "names": struct_names,
        "color_table": tmp_ctable,
        # Placeholder; not populated by this function.
        "lookup_table": None,
    }