"""
Visualization tools for neuroimaging data using PyVista.
Provides classes and functions for plotting brain surfaces, tractograms, and point clouds
with customizable views and color mappings.
Classes:
- BrainPlotter: A class for visualizing brain surfaces with various view configurations,
colormaps, and optional colorbars.
Functions:
- (Additional functions can be added here as needed)
"""
from __future__ import annotations
# Standard library imports
import copy
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
# Third-party imports
import numpy as np
import pyvista as pv
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
from matplotlib.colors import TwoSlopeNorm, Normalize
from matplotlib.transforms import blended_transform_factory
# Use TYPE_CHECKING to avoid circular imports
# Importing local modules
# Note: The following imports are placed here to avoid circular import issues. If you need to use these modules
from . import build_visualization_layout as vislayout
from . import misctools as cltmisc
from . import plottools as cltplot
from . import pointstools as cltpts
from . import surfacetools as cltsurf
from . import tracttools as clttract
from . import visualization_utils as visutils
from . import colorstools as cltcol
from . import parcellationtools as cltparc
####################################################################################################
####################################################################################################
############ ############
############ ############
############ Section 1: Class dedicated to plot Surface objects ############
############ ############
############ ############
####################################################################################################
####################################################################################################
[docs]
class BrainPlotter:
"""
A comprehensive brain surface visualization tool using PyVista.
This class provides flexible brain plotting capabilities with multiple view configurations,
customizable colormaps, and optional colorbar support for neuroimaging data visualization.
Attributes
----------
config_file : str
Path to the JSON configuration file containing layout definitions.
figure_conf : dict
Loaded figure configuration with styling settings.
views_conf : dict
Loaded views configuration with layout definitions.
Examples
--------
>>> plotter = BrainPlotter("brain_plot_configs.json")
>>> plotter.plot_hemispheres(surf_lh, surf_rh, map_name="thickness",
... views="8_views", colorbar=True)
>>>
>>> # Dynamic view selection
>>> plotter.plot_hemispheres(surf_lh, surf_rh, views=["lateral", "medial", "dorsal"])
"""
###############################################################################################
[docs]
def __init__(self, config_file: Union[str, Path, Dict] = None):
"""
Initialize the BrainPlotter with configuration file.
Parameters
----------
config_file : str, optional
Path to JSON file containing figure and view configurations.
If None, uses default "viz_views.json" from config directory.
Raises
------
FileNotFoundError
If the configuration file doesn't exist.
json.JSONDecodeError
If the configuration file contains invalid JSON.
KeyError
If required keys 'figure_conf' or 'views_conf' are missing.
Examples
--------
>>> plotter = BrainPlotter() # Use default config
>>>
>>> plotter = BrainPlotter("custom_views.json") # Use custom config
"""
# Loading the default configuration file
cwd = os.path.dirname(os.path.abspath(__file__))
# Default to the standard configuration file
def_config_file = os.path.join(cwd, "config", "viz_views.json")
configs = visutils.load_configs(def_config_file)
self.config_file = def_config_file
if config_file is not None:
# Use the provided configuration file path
try:
if isinstance(config_file, Dict):
user_configs = copy.deepcopy(config_file)
elif isinstance(config_file, str) or isinstance(config_file, Path):
user_configs = cltmisc.load_json(config_file)
self.config_file = config_file
else:
raise TypeError("config_file must be a str, Path, or Dict")
# Update default configs with user configs
configs = cltmisc.update_dict(configs, user_configs)
except Exception as e:
print(f"Error loading configuration file: {e}")
# Create attributes
config_keys = list(configs.keys())
for key in config_keys:
setattr(self, key, configs[key])
# Define mapping from simple view names to configuration titles
self._view_name_mapping = {
"lateral": ["LH: Lateral view", "RH: Lateral view"],
"medial": ["LH: Medial view", "RH: Medial view"],
"dorsal": ["Dorsal view"],
"ventral": ["Ventral view"],
"rostral": ["Rostral view"],
"caudal": ["Caudal view"],
}
################################################################################################
def _update_configs(self, config_file: Union[str, Path, Dict]):
"""
Update the plotting configurations from a new configuration file.
Parameters
----------
config_file : str, Path, Dict
Path to the new JSON configuration file or a dictionary with configurations.
"""
# Load new configurations
if isinstance(config_file, Dict):
configs = copy.deepcopy(config_file)
else:
configs = cltmisc.load_json(config_file)
# Create attributes
config_keys = list(configs.keys())
for key in config_keys:
if key in self.__dict__:
tmp = getattr(self, key)
upd_dict = cltmisc.update_dict(
tmp, configs[key], merge_lists=True, allow_new_keys=True
)
setattr(
self,
key,
upd_dict,
)
else:
print(
f"Warning: Key '{key}' not found in existing attributes. Skipping update."
)
###############################################################################################
def _build_plotting_config(
self,
views: list,
hemi_id: str = ["lh", "rh"],
orientation: str = "horizontal",
objs2plot: Union[Any, List[Any]] = None,
maps_dict: Dict = {},
colorbar: bool = True,
colorbar_style: str = "individual",
colorbar_position: str = "right",
):
"""
Build the plotting configuration based on user inputs.
Returns
-------
Tuple[List[int], List[float], List[float], List[Tuple], Dict, List[Dict]]
(shape, row_weights, col_weights, groups, brain_positions, colorbar_positions)
"""
# Normalize inputs
objs2plot = cltmisc.to_list(objs2plot) if objs2plot else []
maps_names = list(maps_dict.keys())
n_maps = len(maps_names)
n_objs = len(objs2plot)
# Force single view when both maps and objs2plot > 1
if n_maps > 1 and n_objs > 1 and len(views) > 1:
print(
"🔧 FORCING single view (dorsal) because n_maps > 1, n_objects > 1 and n_views > 1"
)
views = ["dorsal"]
# Get view configuration
view_ids = visutils.get_views_to_plot(self, views, hemi_id=hemi_id)
n_views = len(view_ids)
if n_maps > 1 and n_objs > 1:
view_ids = ["merg-dorsal"]
n_views = 1
print(
f"Number of views: {n_views}, Number of maps: {n_maps}, Number of objects: {n_objs}"
)
# Check if colorbar is needed
# colorbar = colorbar and visutils.colorbar_needed(maps_names, surfaces)
# Build configuration based on dimensions
config, colorbar_list = vislayout.build_layout_config(
self,
view_ids,
objs2plot,
maps_dict,
colorbar,
orientation,
colorbar_style,
colorbar_position,
)
return (
view_ids,
config,
colorbar_list,
)
###############################################################################
[docs]
def plot(
self,
objs2plot: Union[cltsurf.Surface, clttract.Tractogram, cltpts.PointCloud, List],
hemi_id: Union[str, List[str]] = "lh",
views: Union[str, List[str]] = "dorsal",
views_orientation: str = "horizontal",
notebook: bool = False,
map_names: Union[str, List[str]] = ["default"],
v_limits: Optional[Union[Tuple[float, float], List[Tuple[float, float]]]] = (
None,
None,
),
v_range: Optional[Union[Tuple[float, float], List[Tuple[float, float]]]] = (
None,
None,
),
range_color: Tuple = (128, 128, 128, 255),
use_opacity: bool = True,
colormaps: Union[str, List[str]] = "BrBG",
save_path: Optional[str] = None,
non_blocking: bool = True,
colorbar: bool = True,
colorbar_style: str = "individual",
colorbar_titles: Union[str, List[str]] = None,
colorbar_position: str = "right",
config_file: Union[str, Path, Dict] = None,
) -> None:
"""
Plot brain surfaces with optional threading and screenshot support.
Parameters
----------
objs2plot : Union[cltsurf.Surface, clttract.Tractogram, cltpts.PointCloud, List]
Object(s) to plot. Can be a single object or a list of objects.
hemi_id : List[str], default ["lh"]
Hemisphere identifiers.
views : Union[str, List[str]], default "dorsal"
View angles for the surfaces.
views_orientation : str, default "horizontal"
Orientation of the views layout.
notebook : bool, default False
Whether running in Jupyter notebook environment.
map_names : Union[str, List[str]], default ["default"]
Names of the surface maps to plot.
v_limits : Optional[Union[Tuple[float, float], List[Tuple[float, float]]]], default (None, None)
Value limits for colormapping.
v_range : Optional[Union[Tuple[float, float], List[Tuple[float, float]]]], default (None, None)
Value range for colormap application. Values outside this range will be displayed in range_color.
range_color : Tuple, default (128, 128, 128, 255)
RGBA color to use for values outside the specified v_range.
colormaps : Union[str, List[str]], default "BrBG"
Colormaps to use for each map.
use_opacity : bool, default True
Whether to use opacity in the surface rendering. This is important when saving to HTML format to
ensure proper visualization. If False, surfaces will be fully opaque.
save_path : Optional[str], default None
File path for saving the figure. If None, plot is displayed.
non_blocking : bool, default False
If True, display the plot in a separate thread, allowing the terminal
or notebook to remain interactive. Only applies when save_path is None.
colorbar : bool, default True
Whether to show colorbars.
colorbar_style : str, default "individual"
Style of colormap application.
colorbar_titles : Union[str, List[str]], optional
Titles for the colorbars.
colorbar_position : str, default "right"
Position of the colorbars.
"""
# Validate and process hemi_id parameter
if isinstance(hemi_id, str):
hemi_id = [hemi_id]
# the hemi_id must be one of the following
valid_hemi_ids = ["lh", "rh", "both"]
# Leave in hemi_id only valid values
hemi_id = [h for h in hemi_id if h in valid_hemi_ids]
if "both" in hemi_id and len(hemi_id) > 1:
hemi_id = ["lh", "rh"]
# Loading custom configuration file if provided
if config_file is not None:
try:
self._update_configs(config_file)
except Exception as e:
print(
f"Error loading configuration file: {e}. Using existing configurations."
)
# Preparing the surfaces to be plotted
if not isinstance(objs2plot, List):
obj2plot = [copy.deepcopy(objs2plot)]
else:
obj2plot = copy.deepcopy(objs2plot)
# Number of objects to plot
n_objects = len(obj2plot)
# Filter to only available maps
if isinstance(map_names, str):
map_names = [map_names]
n_maps = len(map_names)
# Available overlays
no_ctab_maps, map_names = visutils.find_common_map_names(obj2plot, map_names)
# Preparing the maps dictionary
maps_dict = {}
no_ctab_maps_dict = {}
if len(no_ctab_maps) != 0:
no_ctab_maps_dict = visutils.prepare_map_plotting_params(
no_ctab_maps,
colormaps=colormaps,
v_limits=v_limits,
v_range=v_range,
range_color=range_color,
colorbar_titles=colorbar_titles,
)
# Difference list between map_names and no_ctab_maps
diff_maps = list(set(map_names) - set(no_ctab_maps))
ctab_maps_dict = {}
if len(diff_maps) > 0:
for map_name in diff_maps:
ctab_maps_dict[map_name] = {
"colormap": "colortable",
"vmin": None,
"vmax": None,
"range_min": None,
"range_max": None,
"range_color": None,
"colorbar": False,
"colorbar_title": None,
}
# Merge both dictionaries
maps_dict.update(ctab_maps_dict)
maps_dict.update(no_ctab_maps_dict)
# Assing the name of the map for the colorbar titles if not provided
for map_name in maps_dict.keys():
if maps_dict[map_name]["colorbar_title"] is None:
maps_dict[map_name]["colorbar_title"] = map_name
(
view_ids,
config_dict,
colorbar_dict_list,
) = self._build_plotting_config(
views=views,
objs2plot=obj2plot,
maps_dict=maps_dict,
colorbar=colorbar,
orientation=views_orientation,
hemi_id=hemi_id,
colorbar_style=colorbar_style,
colorbar_position=colorbar_position,
)
# Determine rendering mode based on save_path, environment, and threading preference
save_mode, use_off_screen, use_notebook, use_threading = (
visutils.determine_render_mode(save_path, notebook, non_blocking)
)
# Detecting the screen size for the plotter
screen_size = cltplot.get_current_monitor_size()
# Create PyVista plotter with appropriate rendering mode
plotter_kwargs = {
"notebook": use_notebook,
"window_size": [screen_size[0], screen_size[1]],
"off_screen": use_off_screen,
"shape": config_dict["shape"],
"row_weights": config_dict["row_weights"],
"col_weights": config_dict["col_weights"],
"border": self.figure_conf.get("subplot_border", True),
}
groups = config_dict["groups"]
if groups:
plotter_kwargs["groups"] = groups
pv_plotter = pv.Plotter(**plotter_kwargs)
# Now you can place brain surfaces at specific positions
pv_plotter.set_background(self.figure_conf["background_color"])
brain_positions = config_dict["brain_positions"]
# Computing the plot indexes
subplot_indices = []
n_subplots = len(pv_plotter.renderers)
n_rows = config_dict["shape"][0]
n_cols = config_dict["shape"][1]
subplot_indices = []
for (map_idx, obj_idx, view_idx), position in brain_positions.items():
# Handle case where position might be a list/tuple of coordinates
if isinstance(position, (list, tuple)) and len(position) >= 2:
row, col = position[0], position[1]
else:
row, col = position
# Ensure row and col are integers
if isinstance(row, (list, tuple)):
row = row[0] if row else 0
if isinstance(col, (list, tuple)):
col = col[0] if col else 0
subplot_indices.append(int(row) * n_cols + int(col))
# If there is any element of subplot_indices that is bigger than n_subplots do something else
if any(sp_index > n_subplots for sp_index in subplot_indices):
# Remove all the elements that are bigger than n_subplots
# Take a vector from 0 to 6*4 and reshape it to a matrix of 6 rows and 4 columns and print it
tmp = np.arange(0, n_rows * n_cols).reshape(n_rows, n_cols)
# Now remove the last column and print the matrix
tmp = tmp[:, :-1]
# Now, if the matrix has n_rows bigger than 3, remove , from rows 3 to n_rows -1
if tmp.shape[0] > 3:
for cont, r in enumerate(range(1, tmp.shape[0])):
tmp[r, :] = tmp[r, :] - cont
subplot_indices = tmp.T.flatten().tolist()
map_limits = config_dict["colormap_limits"]
for (map_idx, obj_idx, view_idx), (row, col) in brain_positions.items():
pv_plotter.subplot(row, col)
# Set background color from figure configuration
pv_plotter.set_background(self.figure_conf["background_color"])
tmp_view_name = view_ids[view_idx]
# Split the view name if it contains '_'
if "-" in tmp_view_name:
tmp_view_name = tmp_view_name.split("-")[1]
# Capitalize the first letter
tmp_view_name = tmp_view_name.capitalize()
pv_plotter.add_text(
f"{map_names[map_idx]}, Object: {obj_idx}, View: {tmp_view_name}",
font_size=self.figure_conf["title_font_size"],
position="upper_edge",
color=self.figure_conf["title_font_color"],
shadow=self.figure_conf["title_shadow"],
font=self.figure_conf["title_font_type"],
)
# Geting the vmin and vmax for the current map
vmin, vmax, map_name = map_limits[map_idx, obj_idx, view_idx][0]
# Select the colormap for the current map
idx = [i for i, name in enumerate(map_names) if name == map_name]
# colormap = colormaps[idx[0]] if idx else colormaps[0]
colormap = maps_dict[map_names[idx[0]]]["colormap"]
range_min = maps_dict[map_names[idx[0]]]["range_min"]
range_max = maps_dict[map_names[idx[0]]]["range_max"]
range_color = maps_dict[map_names[idx[0]]]["range_color"]
# Add the brain surface mesh
prep_obj = visutils.prepare_list_obj_for_plotting(
obj2plot[obj_idx],
map_names[map_idx],
colormap,
vmin=vmin,
vmax=vmax,
range_min=range_min,
range_max=range_max,
range_color=range_color,
)
for tmp_obj in prep_obj:
if isinstance(tmp_obj, clttract.Tractogram):
tracts = tmp_obj.tracts
rgba_data = tmp_obj.data_per_point["rgba"]
# 1. Concatenate all points and colors
all_points = np.vstack(tracts)
all_rgba = np.vstack(rgba_data)
if use_opacity is False:
all_rgba = all_rgba[:, :3]
# 2. Build the lines connectivity array
# Format: [n1, idx0, idx1, ..., n2, idx0, idx1, ...]
lines = []
offset = 0
for tract in tracts:
n = len(tract)
lines.append(n)
lines.extend(range(offset, offset + n))
offset += n
lines = np.array(lines, dtype=np.int_)
# 3. Create single PolyData with all curves
if self.objs_conf["tracts"]["tubes"]:
# Create a PolyData object for tube representation
# Create a PolyData object for tube representation
poly = pv.PolyData()
poly.points = all_points
poly.lines = lines
# Attach your RGBA scalars
poly.point_data["rgba"] = all_rgba # <-- important
# Add tube filter (tube cannot take scalars directly)
tube_radius = self.objs_conf["tracts"]["tube_radius"]
tube_sides = self.objs_conf["tracts"]["tube_sides"]
tube_poly = poly.tube(
radius=tube_radius,
n_sides=tube_sides,
)
# Add the mesh with tube representation
pv_plotter.add_mesh(
tube_poly,
scalars="rgba", # use the same name
rgb=True,
ambient=self.figure_conf["mesh_ambient"],
diffuse=self.figure_conf["mesh_diffuse"],
specular=self.figure_conf["mesh_specular"],
specular_power=self.figure_conf["mesh_specular_power"],
smooth_shading=self.figure_conf["mesh_smooth_shading"],
show_scalar_bar=False,
)
else:
poly = pv.PolyData()
poly.points = all_points
poly.lines = lines
poly.point_data["rgba"] = all_rgba
# 4. Single add_mesh call
pv_plotter.add_mesh(
poly,
scalars="rgba",
line_width=2,
rgb=True,
ambient=self.figure_conf["mesh_ambient"],
diffuse=self.figure_conf["mesh_diffuse"],
specular=self.figure_conf["mesh_specular"],
specular_power=self.figure_conf["mesh_specular_power"],
smooth_shading=self.figure_conf["mesh_smooth_shading"],
show_scalar_bar=False,
)
elif isinstance(tmp_obj, cltpts.PointCloud):
rgba_data = tmp_obj.point_data["rgba"]
if use_opacity is False:
rgba_data = rgba_data[:, :3]
pv_plotter.add_points(
tmp_obj.coords,
render_points_as_spheres=self.objs_conf["points"]["spheres"],
point_size=self.objs_conf["points"]["spheres_radius"],
scalars=rgba_data,
rgb=True,
ambient=self.figure_conf["mesh_ambient"],
diffuse=self.figure_conf["mesh_diffuse"],
specular=self.figure_conf["mesh_specular"],
specular_power=self.figure_conf["mesh_specular_power"],
smooth_shading=self.figure_conf["mesh_smooth_shading"],
show_scalar_bar=False,
)
elif isinstance(tmp_obj, cltsurf.Surface):
if not use_opacity:
# delete the alpha channel if exists
if "rgba" in tmp_obj.mesh.point_data:
tmp_obj.mesh.point_data["rgba"] = tmp_obj.mesh.point_data[
"rgba"
][:, :3]
pv_plotter.add_mesh(
copy.deepcopy(tmp_obj.mesh),
scalars="rgba",
rgb=True,
ambient=self.figure_conf["mesh_ambient"],
diffuse=self.figure_conf["mesh_diffuse"],
specular=self.figure_conf["mesh_specular"],
specular_power=self.figure_conf["mesh_specular_power"],
smooth_shading=self.figure_conf["mesh_smooth_shading"],
show_scalar_bar=False,
)
# Set the camera view
tmp_view = view_ids[view_idx]
# Replace merg from the view id if needed
if "merg" in tmp_view:
tmp_view = tmp_view.replace("merg", "lh")
camera_params = self.views_conf[tmp_view]
pv_plotter.camera_position = camera_params["view"]
pv_plotter.camera.azimuth = camera_params["azimuth"]
pv_plotter.camera.elevation = camera_params["elevation"]
pv_plotter.camera.zoom(camera_params["zoom"])
# And place colorbars at their positions
if len(colorbar_dict_list):
for colorbar_dict in colorbar_dict_list:
if colorbar_dict is not False:
row, col = colorbar_dict["position"]
orientation = colorbar_dict["orientation"]
colorbar_id = colorbar_dict["map_name"]
colormap = colorbar_dict["colormap"]
colorbar_title = colorbar_dict["title"]
vmin = colorbar_dict["vmin"]
vmax = colorbar_dict["vmax"]
pv_plotter.subplot(row, col)
if colormap == "colortable":
pass # Currently, no colorbar for categorical maps is implemented
else:
visutils.add_colorbar(
self,
plotter=pv_plotter,
colorbar_subplot=(row, col),
vmin=vmin,
vmax=vmax,
map_name=colorbar_id,
colormap=colormap,
colorbar_title=colorbar_title,
colorbar_position=orientation,
)
# Linking the cameras from the subplots with the same view
unique_v_indices = set(key[2] for key in brain_positions.keys())
grouped_by_v_idx = {}
for v_idx in unique_v_indices:
grouped_by_v_idx[v_idx] = []
for i, ((m_idx, s_idx, v_idx), (row, col)) in enumerate(
brain_positions.items()
):
if v_idx in grouped_by_v_idx: # Safety check
grouped_by_v_idx[v_idx].append(subplot_indices[i])
# After all subplots are created and populated, link the views
for v_idx, positions in grouped_by_v_idx.items():
if len(positions) > 1:
# Link all views in this group
try:
pv_plotter.link_views(positions)
except:
try:
# Try substracting 1 from positions bigger than n_horz_plots
# This is to handle the case when there are colorbars in the last column
# and there are more than 2 rows
# Get number of horizontal plots
n_horz_plots = pv_plotter.shape[1] - 1
# Substract 1 to all the elements in positions that are bigger than n_horz_plots
new_positions = np.arange(len(pv_plotter.renderers)).tolist()
# Remove the element equal to n_horz_plots-1 from new_positions
if n_horz_plots in new_positions:
new_positions.remove(n_horz_plots)
pv_plotter.link_views(new_positions)
except:
print(
f"Could not link views for view index {v_idx} at positions {positions}"
)
# Handle final rendering - either save, display blocking, or display non-blocking
visutils.finalize_plot(pv_plotter, save_mode, save_path, use_threading)
###########################################################################
[docs]
def plot_hemispheres(
self,
obj_rh: Union[cltsurf.Surface, clttract.Tractogram, cltpts.PointCloud, List],
obj_lh: Union[cltsurf.Surface, clttract.Tractogram, cltpts.PointCloud, List],
map_name: str = "default",
views: Union[str, List[str]] = "dorsal",
vmin: Optional[float] = None,
vmax: Optional[float] = None,
range_min: Optional[float] = None,
range_max: Optional[float] = None,
range_color: Tuple = (128, 128, 128, 255),
use_opacity: bool = True,
colormap: str = "viridis",
colorbar: bool = True,
colorbar_title: str = None,
colorbar_position: str = "right",
notebook: bool = False,
non_blocking: bool = False,
save_path: Optional[str] = None,
config_file: Union[str, Path, Dict] = None,
):
"""
Plot brain hemispheres with multiple views.
Parameters
----------
obj_rh : Union[cltsurf.Surface, clttract.Tractogram, cltpts.PointCloud]
Right hemisphere object to plot. Can be a Surface, Tractogram, or PointCloud.
obj_lh : Union[cltsurf.Surface, clttract.Tractogram, cltpts.PointCloud]
Left hemisphere object to plot. Can be a Surface, Tractogram, or PointCloud.
map_name : str "default"
Name of the data maps to visualize. Must be present in all the objects.
views : str or list of str, default "dorsal"
Views to display. Options include 'dorsal', 'ventral', 'lateral', 'medial', 'anterior', 'posterior'.
Can be a single view or a list of views. It can also include different multiple views specified as layouts:
>>> plotter = BrainPlotter("configs.json")
>>> layouts = plotter.list_available_layouts()
vmin : float, optional
Minimum value for colormap scaling. If None, uses the minimum value from the data.
vmax : float, optional
Maximum value for colormap scaling. If None, uses the maximum value from the data.
range_min : float, optional
Minimum value for the value range. Values below this will be colored with `range_color`.
range_max : float, optional
Maximum value for the value range. Values above this will be colored with `range_color`.
range_color : Tuple, default (128, 128, 128, 255)
RGBA color to use for values outside the specified v_range.
colormap : str or list of str, default "BrBG"
Colormap to use for visualization.
colorbar : bool, default True
Whether to display colorbars for the maps.
colorbar_title : str, optional
Title for the colorbar. If None, map name are used.
colorbar_position : str, default "right"
Position of the colorbars. Options are 'right' or 'bottom'.
notebook : bool, default False
Whether to render the plot in a Jupyter notebook environment.
If True, uses notebook-compatible rendering.
non_blocking : bool, default False
If True, displays the plot in a non-blocking manner using threading.
Only applicable when `notebook` is False and `save_path` is None.
save_path : str, optional
File path to save the rendered figure. If provided, the figure is saved to this path
instead of being displayed.
config_file : Union[str, Path, Dict], optional
Path to a custom configuration file (JSON) or a dictionary containing configuration settings.
If provided, this configuration will override the default settings for plotting.
Returns
-------
None
The function does not return any value. It either displays the plot or saves it to a
file, depending on the parameters provided.
"""
# Loading custom configuration file if provided
if config_file is not None:
try:
self._update_configs(config_file)
except Exception as e:
print(
f"Error loading configuration file: {e}. Using existing configurations."
)
# Preparing the surfaces to be plotted
if not isinstance(obj_lh, List):
obj_lh = [copy.deepcopy(obj_lh)]
if not isinstance(obj_rh, List):
obj_rh = [copy.deepcopy(obj_rh)]
# Filter to only available maps
if isinstance(map_name, str):
map_name = [map_name]
all_objects = visutils.flatten_objects(obj_lh + obj_rh)
# Available overlays
no_ctab_map, map_name = visutils.find_common_map_names(all_objects, map_name)
# Preparing the maps dictionary
maps_dict = {}
no_ctab_maps_dict = {}
if len(no_ctab_map) != 0:
no_ctab_maps_dict = visutils.prepare_map_plotting_params(
no_ctab_map,
colormaps=colormap,
v_limits=(vmin, vmax),
v_range=(range_min, range_max),
range_color=range_color,
colorbar_titles=colorbar_title,
)
# Difference list between map_names and no_ctab_maps
diff_maps = list(set(map_name) - set(no_ctab_map))
ctab_maps_dict = {}
if len(diff_maps) > 0:
for map_name in diff_maps:
ctab_maps_dict[map_name] = {
"colormap": "colortable",
"vmin": None,
"vmax": None,
"range_min": None,
"range_max": None,
"range_color": None,
"colorbar": False,
"colorbar_title": None,
}
# Merge both dictionaries
maps_dict.update(ctab_maps_dict)
maps_dict.update(no_ctab_maps_dict)
# Assing the name of the map for the colorbar titles if not provided
for map_name in maps_dict.keys():
if maps_dict[map_name]["colorbar_title"] is None:
maps_dict[map_name]["colorbar_title"] = map_name
valid_views = visutils.get_views_to_plot(self, views, ["lh", "rh"])
n_views = len(valid_views)
colorbar_size = self.figure_conf["colorbar_size"]
limits_dict, charac_dict = visutils.get_map_characteristics(
all_objects, maps_dict
)
##### Determine colormap limits based on colorbar style #####
colormap_limits = {}
for view_idx in range(n_views):
colormap_limits[(0, 0, view_idx)] = [limits_dict["shared"]]
config_dict, colorbar_list = vislayout.grid_multi_views_layout(
maps_dict,
colormap_limits,
charac_dict,
valid_views,
colorbar,
colorbar_position,
colorbar_size,
)
# Determine rendering mode based on save_path, environment, and threading preference
save_mode, use_off_screen, use_notebook, use_threading = (
visutils.determine_render_mode(save_path, notebook, non_blocking)
)
# Detecting the screen size for the plotter
screen_size = cltplot.get_current_monitor_size()
# Create PyVista plotter with appropriate rendering mode
plotter_kwargs = {
"notebook": use_notebook,
"window_size": [screen_size[0], screen_size[1]],
"off_screen": use_off_screen,
"shape": config_dict["shape"],
"row_weights": config_dict["row_weights"],
"col_weights": config_dict["col_weights"],
"border": self.figure_conf.get("subplot_border", True),
}
groups = config_dict["groups"]
if groups:
plotter_kwargs["groups"] = groups
pv_plotter = pv.Plotter(**plotter_kwargs)
# Now you can place brain objects at specific positions
pv_plotter.set_background(self.figure_conf["background_color"])
brain_positions = config_dict["brain_positions"]
# Computing the plot indexes
map_limits = config_dict["colormap_limits"]
# Geting the vmin and vmax for the current map
vmin, vmax, map_name = map_limits[0, 0, 0][0]
colormap = maps_dict[map_name]["colormap"]
range_min = maps_dict[map_name]["range_min"]
range_max = maps_dict[map_name]["range_max"]
range_color = maps_dict[map_name]["range_color"]
for (map_idx, obj_idx, view_idx), (row, col) in brain_positions.items():
pv_plotter.subplot(row, col)
# Set background color from figure configuration
pv_plotter.set_background(self.figure_conf["background_color"])
tmp_view_name = valid_views[view_idx]
# Split the view name if it contains '_'
if "-" in tmp_view_name:
tmp_view_name = tmp_view_name.split("-")[1]
# Capitalize the first letter
tmp_view_name = tmp_view_name.capitalize()
# Detecting if the view is left or right
if "lh" in valid_views[view_idx]:
subplot_title = "Left hemisphere: " + tmp_view_name + " view"
elif "rh" in valid_views[view_idx]:
subplot_title = "Right hemisphere: " + tmp_view_name + " view"
elif "merg" in valid_views[view_idx]:
subplot_title = tmp_view_name + " view"
pv_plotter.add_text(
subplot_title,
font_size=self.figure_conf["title_font_size"],
position="upper_edge",
color=self.figure_conf["title_font_color"],
shadow=self.figure_conf["title_shadow"],
font=self.figure_conf["title_font_type"],
)
# Add the brain surface mesh
if "lh" in valid_views[view_idx]:
prep_obj = visutils.prepare_list_obj_for_plotting(
obj_lh,
map_name,
colormap,
vmin=vmin,
vmax=vmax,
range_min=range_min,
range_max=range_max,
range_color=range_color,
)
elif "rh" in valid_views[view_idx]:
prep_obj = visutils.prepare_list_obj_for_plotting(
obj_rh,
map_name,
colormap,
vmin=vmin,
vmax=vmax,
range_min=range_min,
range_max=range_max,
range_color=range_color,
)
elif "merg" in valid_views[view_idx]:
prep_obj = visutils.prepare_list_obj_for_plotting(
all_objects,
map_name,
colormap,
vmin=vmin,
vmax=vmax,
range_min=range_min,
range_max=range_max,
range_color=range_color,
)
for tmp_obj in prep_obj:
if isinstance(tmp_obj, clttract.Tractogram):
tracts = tmp_obj.tracts
rgba_data = tmp_obj.data_per_point["rgba"]
# 1. Concatenate all points and colors
all_points = np.vstack(tracts)
all_rgba = np.vstack(rgba_data)
if use_opacity is False:
all_rgba = all_rgba[:, :3]
# 2. Build the lines connectivity array
# Format: [n1, idx0, idx1, ..., n2, idx0, idx1, ...]
lines = []
offset = 0
for tract in tracts:
n = len(tract)
lines.append(n)
lines.extend(range(offset, offset + n))
offset += n
lines = np.array(lines, dtype=np.int_)
# 3. Create single PolyData with all curves
if self.objs_conf["tracts"]["tubes"]:
# Create a PolyData object for tube representation
# Create a PolyData object for tube representation
poly = pv.PolyData()
poly.points = all_points
poly.lines = lines
# Attach your RGBA scalars
poly.point_data["rgba"] = all_rgba # <-- important
# Add tube filter (tube cannot take scalars directly)
tube_radius = self.objs_conf["tracts"]["tube_radius"]
tube_sides = self.objs_conf["tracts"]["tube_sides"]
tube_poly = poly.tube(
radius=tube_radius,
n_sides=tube_sides,
)
# Add the mesh with tube representation
pv_plotter.add_mesh(
tube_poly,
scalars="rgba", # use the same name
rgb=True,
ambient=self.figure_conf["mesh_ambient"],
diffuse=self.figure_conf["mesh_diffuse"],
specular=self.figure_conf["mesh_specular"],
specular_power=self.figure_conf["mesh_specular_power"],
smooth_shading=self.figure_conf["mesh_smooth_shading"],
show_scalar_bar=False,
)
else:
poly = pv.PolyData()
poly.points = all_points
poly.lines = lines
poly.point_data["rgba"] = all_rgba
# 4. Single add_mesh call
pv_plotter.add_mesh(
poly,
scalars="rgba",
line_width=2,
rgb=True,
ambient=self.figure_conf["mesh_ambient"],
diffuse=self.figure_conf["mesh_diffuse"],
specular=self.figure_conf["mesh_specular"],
specular_power=self.figure_conf["mesh_specular_power"],
smooth_shading=self.figure_conf["mesh_smooth_shading"],
show_scalar_bar=False,
)
elif isinstance(tmp_obj, cltpts.PointCloud):
rgba_data = tmp_obj.point_data["rgba"]
if use_opacity is False:
rgba_data = rgba_data[:, :3]
pv_plotter.add_points(
tmp_obj.coords,
render_points_as_spheres=self.objs_conf["points"]["spheres"],
point_size=self.objs_conf["points"]["spheres_radius"],
scalars=rgba_data,
rgb=True,
ambient=self.figure_conf["mesh_ambient"],
diffuse=self.figure_conf["mesh_diffuse"],
specular=self.figure_conf["mesh_specular"],
specular_power=self.figure_conf["mesh_specular_power"],
smooth_shading=self.figure_conf["mesh_smooth_shading"],
show_scalar_bar=False,
)
elif isinstance(tmp_obj, cltsurf.Surface):
if not use_opacity:
# delete the alpha channel if exists
if "rgba" in tmp_obj.mesh.point_data:
tmp_obj.mesh.point_data["rgba"] = tmp_obj.mesh.point_data[
"rgba"
][:, :3]
pv_plotter.add_mesh(
copy.deepcopy(tmp_obj.mesh),
scalars="rgba",
rgb=True,
ambient=self.figure_conf["mesh_ambient"],
diffuse=self.figure_conf["mesh_diffuse"],
specular=self.figure_conf["mesh_specular"],
specular_power=self.figure_conf["mesh_specular_power"],
smooth_shading=self.figure_conf["mesh_smooth_shading"],
show_scalar_bar=False,
)
# Set the camera view
tmp_view = valid_views[view_idx]
# Replace merg from the view id if needed
if "merg" in tmp_view:
tmp_view = tmp_view.replace("merg", "lh")
camera_params = self.views_conf[tmp_view]
pv_plotter.camera_position = camera_params["view"]
pv_plotter.camera.azimuth = camera_params["azimuth"]
pv_plotter.camera.elevation = camera_params["elevation"]
pv_plotter.camera.zoom(camera_params["zoom"])
# And place colorbars at their positions
if len(colorbar_list):
for colorbar_dict in colorbar_list:
if colorbar_dict is not False:
row, col = colorbar_dict["position"]
orientation = colorbar_dict["orientation"]
colorbar_id = colorbar_dict["map_name"]
colormap = colorbar_dict["colormap"]
colorbar_title = colorbar_dict["title"]
vmin = colorbar_dict["vmin"]
vmax = colorbar_dict["vmax"]
pv_plotter.subplot(row, col)
if colormap == "colortable":
pass # Currently, no colorbar for categorical maps is implemented
else:
visutils.add_colorbar(
self,
plotter=pv_plotter,
colorbar_subplot=(row, col),
vmin=vmin,
vmax=vmax,
map_name=colorbar_id,
colormap=colormap,
colorbar_title=colorbar_title,
colorbar_position=orientation,
)
# Handle final rendering - either save, display blocking, or display non-blocking
visutils.finalize_plot(pv_plotter, save_mode, save_path, use_threading)
###########################################################################
[docs]
def plot_scene(
self,
scene_objects: Union[
cltsurf.Surface, clttract.Tractogram, cltpts.PointCloud, List
],
scene_config: Dict = None,
views: Union[str, List[str]] = "dorsal",
notebook: bool = False,
colorbar: bool = True,
colorbar_position: str = "right",
use_opacity: bool = True,
non_blocking: bool = False,
save_path: Optional[str] = None,
config_file: Union[str, Path, Dict] = None,
):
# Loading custom configuration file if provided
if config_file is not None:
try:
self._update_configs(config_file)
except Exception as e:
print(
f"Error loading configuration file: {e}. Using existing configurations."
)
fin_obj_config = visutils.create_final_object_config(
scene_objects, maps_config=scene_config
)
valid_views = visutils.get_views_to_plot(self, views, ["lh", "rh"])
n_views = len(valid_views)
colorbar_size = self.figure_conf["colorbar_size"]
config_dict = vislayout.scene_layout(
valid_views, colorbar, colorbar_position, colorbar_size
)
# Determine rendering mode based on save_path, environment, and threading preference
save_mode, use_off_screen, use_notebook, use_threading = (
visutils.determine_render_mode(save_path, notebook, non_blocking)
)
# Detecting the screen size for the plotter
screen_size = cltplot.get_current_monitor_size()
# Create PyVista plotter with appropriate rendering mode
plotter_kwargs = {
"notebook": use_notebook,
"window_size": [screen_size[0], screen_size[1]],
"off_screen": use_off_screen,
"shape": config_dict["shape"],
"row_weights": config_dict["row_weights"],
"col_weights": config_dict["col_weights"],
"border": self.figure_conf.get("subplot_border", True),
}
groups = config_dict["groups"]
if groups:
plotter_kwargs["groups"] = groups
pv_plotter = pv.Plotter(**plotter_kwargs)
# Now you can place brain objects at specific positions
pv_plotter.set_background(self.figure_conf["background_color"])
brain_positions = config_dict["brain_positions"]
for (map_idx, obj_idx, view_idx), (row, col) in brain_positions.items():
pv_plotter.subplot(row, col)
# Set background color from figure configuration
pv_plotter.set_background(self.figure_conf["background_color"])
tmp_view_name = valid_views[view_idx]
# Split the view name if it contains '_'
if "-" in tmp_view_name:
tmp_view_name = tmp_view_name.split("-")[1]
# Capitalize the first letter
tmp_view_name = tmp_view_name.capitalize()
# Detecting if the view is left or right
if "lh" in valid_views[view_idx]:
subplot_title = "Left hemisphere: " + tmp_view_name + " view"
elif "rh" in valid_views[view_idx]:
subplot_title = "Right hemisphere: " + tmp_view_name + " view"
elif "merg" in valid_views[view_idx]:
subplot_title = tmp_view_name + " view"
pv_plotter.add_text(
subplot_title,
font_size=self.figure_conf["title_font_size"],
position="upper_edge",
color=self.figure_conf["title_font_color"],
shadow=self.figure_conf["title_shadow"],
font=self.figure_conf["title_font_type"],
)
prep_obj = []
for idx, obj in enumerate(scene_objects):
map_name = fin_obj_config[idx]["map_name"]
colormap = fin_obj_config[idx]["colormap"]
vmin = fin_obj_config[idx]["v_limits"][0]
vmax = fin_obj_config[idx]["v_limits"][1]
range_min = fin_obj_config[idx]["v_range"][0]
range_max = fin_obj_config[idx]["v_range"][1]
range_color = fin_obj_config[idx]["range_color"]
opacity = fin_obj_config[idx]["opacity"]
prep_obj.extend(
visutils.prepare_list_obj_for_plotting(
obj,
map_name,
colormap,
vmin=vmin,
vmax=vmax,
range_min=range_min,
range_max=range_max,
range_color=range_color,
)
)
for idx, tmp_obj in enumerate(prep_obj):
opacity = fin_obj_config[idx]["opacity"]
if isinstance(tmp_obj, clttract.Tractogram):
tracts = tmp_obj.tracts
rgba_data = tmp_obj.data_per_point["rgba"]
# 1. Concatenate all points and colors
all_points = np.vstack(tracts)
all_rgba = np.vstack(rgba_data)
all_rgba = all_rgba[:, :3]
if use_opacity is True:
# Check if data is in 0-1 range or 0-255 range
if all_rgba.max() <= 1.0:
# Data is in 0-1 range
alpha_column = np.ones(all_rgba.shape[0]) * opacity
else:
# Data is in 0-255 range
alpha_column = np.ones(all_rgba.shape[0]) * opacity * 255
# Add alpha channel
rgba_with_alpha = np.column_stack([all_rgba, alpha_column])
# Maintain the same dtype as original
rgba_with_alpha = rgba_with_alpha.astype(all_rgba.dtype)
# Assign back
all_rgba = rgba_with_alpha
# 2. Build the lines connectivity array
# Format: [n1, idx0, idx1, ..., n2, idx0, idx1, ...]
lines = []
offset = 0
for tract in tracts:
n = len(tract)
lines.append(n)
lines.extend(range(offset, offset + n))
offset += n
lines = np.array(lines, dtype=np.int_)
# 3. Create single PolyData with all curves
if self.objs_conf["tracts"]["tubes"]:
# Create a PolyData object for tube representation
# Create a PolyData object for tube representation
poly = pv.PolyData()
poly.points = all_points
poly.lines = lines
# Attach your RGBA scalars
poly.point_data["rgba"] = all_rgba # <-- important
# Add tube filter (tube cannot take scalars directly)
tube_radius = self.objs_conf["tracts"]["tube_radius"]
tube_sides = self.objs_conf["tracts"]["tube_sides"]
tube_poly = poly.tube(
radius=tube_radius,
n_sides=tube_sides,
)
# Add the mesh with tube representation
pv_plotter.add_mesh(
tube_poly,
scalars="rgba", # use the same name
rgb=True,
ambient=self.figure_conf["mesh_ambient"],
diffuse=self.figure_conf["mesh_diffuse"],
specular=self.figure_conf["mesh_specular"],
specular_power=self.figure_conf["mesh_specular_power"],
smooth_shading=self.figure_conf["mesh_smooth_shading"],
show_scalar_bar=False,
)
else:
poly = pv.PolyData()
poly.points = all_points
poly.lines = lines
poly.point_data["rgba"] = all_rgba
# 4. Single add_mesh call
pv_plotter.add_mesh(
poly,
scalars="rgba",
line_width=2,
rgb=True,
ambient=self.figure_conf["mesh_ambient"],
diffuse=self.figure_conf["mesh_diffuse"],
specular=self.figure_conf["mesh_specular"],
specular_power=self.figure_conf["mesh_specular_power"],
smooth_shading=self.figure_conf["mesh_smooth_shading"],
show_scalar_bar=False,
)
elif isinstance(tmp_obj, cltpts.PointCloud):
rgba_data = tmp_obj.point_data["rgba"]
# Check if data is in 0-1 range or 0-255 range
if rgba_data.max() <= 1.0:
# Data is in 0-1 range
alpha_column = np.ones(rgba_data.shape[0]) * opacity
else:
# Data is in 0-255 range
alpha_column = np.ones(rgba_data.shape[0]) * opacity * 255
# Add alpha channel
rgba_with_alpha = np.column_stack([rgba_data, alpha_column])
# Maintain the same dtype as original
rgba_with_alpha = rgba_with_alpha.astype(rgba_data.dtype)
# Assign back
tmp_obj.point_data["rgba"] = rgba_with_alpha
if use_opacity is False:
rgba_data = rgba_data[:, :3]
pv_plotter.add_points(
tmp_obj.coords,
render_points_as_spheres=self.objs_conf["points"]["spheres"],
point_size=self.objs_conf["points"]["spheres_radius"],
scalars=rgba_data,
rgb=True,
ambient=self.figure_conf["mesh_ambient"],
diffuse=self.figure_conf["mesh_diffuse"],
specular=self.figure_conf["mesh_specular"],
specular_power=self.figure_conf["mesh_specular_power"],
smooth_shading=self.figure_conf["mesh_smooth_shading"],
show_scalar_bar=False,
)
elif isinstance(tmp_obj, cltsurf.Surface):
if not use_opacity:
# delete the alpha channel if exists
if "rgba" in tmp_obj.mesh.point_data:
rgba_data = tmp_obj.mesh.point_data["rgba"][:, :3]
else:
rgba_data = tmp_obj.mesh.point_data["rgba"][:, :3]
# Check if data is in 0-1 range or 0-255 range
if rgba_data.max() <= 1.0:
# Data is in 0-1 range
alpha_column = np.ones(rgba_data.shape[0]) * opacity
else:
# Data is in 0-255 range
alpha_column = np.ones(rgba_data.shape[0]) * opacity * 255
# Add alpha channel
rgba_with_alpha = np.column_stack([rgba_data, alpha_column])
# Maintain the same dtype as original
rgba_with_alpha = rgba_with_alpha.astype(rgba_data.dtype)
# Assign back
tmp_obj.mesh.point_data["rgba"] = rgba_with_alpha
pv_plotter.add_mesh(
copy.deepcopy(tmp_obj.mesh),
scalars="rgba",
rgb=True,
ambient=self.figure_conf["mesh_ambient"],
diffuse=self.figure_conf["mesh_diffuse"],
specular=self.figure_conf["mesh_specular"],
specular_power=self.figure_conf["mesh_specular_power"],
smooth_shading=self.figure_conf["mesh_smooth_shading"],
show_scalar_bar=False,
)
# Set the camera view
tmp_view = valid_views[view_idx]
# Replace merg from the view id if needed
if "merg" in tmp_view:
tmp_view = tmp_view.replace("merg", "lh")
camera_params = self.views_conf[tmp_view]
pv_plotter.camera_position = camera_params["view"]
pv_plotter.camera.azimuth = camera_params["azimuth"]
pv_plotter.camera.elevation = camera_params["elevation"]
pv_plotter.camera.zoom(camera_params["zoom"])
# # And place colorbars at their positions
# if len(colorbar_list):
# for colorbar_dict in colorbar_list:
# if colorbar_dict is not False:
# row, col = colorbar_dict["position"]
# orientation = colorbar_dict["orientation"]
# colorbar_id = colorbar_dict["map_name"]
# colormap = colorbar_dict["colormap"]
# colorbar_title = colorbar_dict["title"]
# vmin = colorbar_dict["vmin"]
# vmax = colorbar_dict["vmax"]
# pv_plotter.subplot(row, col)
# if colormap == "colortable":
# pass # Currently, no colorbar for categorical maps is implemented
# else:
# visutils.add_colorbar(
# self,
# plotter=pv_plotter,
# colorbar_subplot=(row, col),
# vmin=vmin,
# vmax=vmax,
# map_name=colorbar_id,
# colormap=colormap,
# colorbar_title=colorbar_title,
# colorbar_position=orientation,
# )
# Handle final rendering - either save, display blocking, or display non-blocking
visutils.finalize_plot(pv_plotter, save_mode, save_path, use_threading)
###############################################################################################
[docs]
def list_available_view_names(self) -> List[str]:
"""
List available view names for dynamic view selection.
Returns
-------
List[str]
Available view names that can be used in views parameter:
['Lateral', 'Medial', 'Dorsal', 'Ventral', 'Rostral', 'Caudal'].
Examples
--------
>>> plotter = BrainPlotter()
>>> view_names = plotter.list_available_view_names()
>>> print(f"Available views: {view_names}")
"""
return visutils.list_available_view_names(self)
###############################################################################################
[docs]
def list_available_layouts(self) -> Dict[str, Dict[str, Any]]:
"""
Display available visualization layouts and their configurations.
Returns
-------
Dict[str, Dict[str, Any]]
Dictionary containing detailed layout information for each configuration.
Keys are configuration names, values contain shape, window_size,
num_views, and views information.
Examples
--------
>>> plotter = BrainPlotter("configs.json")
>>> layouts = plotter.list_available_layouts()
>>> print(f"Available layouts: {list(layouts.keys())}")
>>>
>>> # Access specific layout info
>>> layout_info = layouts['8_views']
>>> print(f"Shape: {layout_info['shape']}")
>>> print(f"Views: {layout_info['num_views']}")
"""
return visutils.list_available_layouts(self)
###############################################################################################
[docs]
def get_layout_details(self, views: str) -> Optional[Dict[str, Any]]:
"""
Get detailed information about a specific layout configuration.
Parameters
----------
views : str
Name of the configuration to examine.
Returns
-------
Dict[str, Any] or None
Detailed configuration information if found, None if configuration
doesn't exist. Contains shape, window_size, and views information.
Examples
--------
>>> plotter = BrainPlotter("configs.json")
>>> details = plotter.get_layout_details("8_views")
>>> if details:
... print(f"Grid shape: {details['shape']}")
... print(f"Views: {len(details['views'])}")
>>>
>>> # Handle non-existent configuration
>>> details = plotter.get_layout_details("invalid_config")
"""
return visutils.get_layout_details(self, views)
###############################################################################################
###############################################################################################
def _list_all_views_and_layouts(self) -> List[str]:
"""
List available layout configurations from the loaded JSON file.
Returns
-------
List[str]
List of configuration names available for plotting.
Examples
--------
>>> plotter = BrainPlotter("configs.json")
>>> layouts = plotter._list_all_views_and_layouts()
>>> print(layouts)
['8_views', '8_views_8x1', '8_views_1x8', '6_views', '6_views_6x1', '6_views_1x6', '4_views', '4_views_4x1', '4_views_1x4', '2_views', 'lateral', 'medial', 'dorsal', 'ventral', 'rostral', 'caudal']
"""
all_views_and_layouts = visutils.list_multiviews_layouts(
self
) + visutils.list_single_views(self)
return all_views_and_layouts
###############################################################################################
def _list_multiviews_layouts(self) -> List[str]:
"""
List available multi-view configurations from the loaded JSON file.
Returns
-------
List[str]
List of multi-view configuration names available for plotting.
Examples
--------
>>> plotter = BrainPlotter("configs.json")
>>> multiviews = plotter._list_multiviews_layouts()
>>> print(multiviews)
['8_views', '6_views', '4_views', '8_views_8x1', '6_views_6x1', '4_views_4x1', '8_views_1x8', '6_views_1x6', '4_views_1x4', '2_views']
"""
return visutils.list_multiviews_layouts(self)
###############################################################################################
def _list_single_views(self) -> List[str]:
"""
List available single view names.
"""
return visutils.list_single_views(self)
###############################################################################################
def _create_threaded_plot(self, plotter: pv.Plotter) -> None:
"""
Create and show plot in a separate thread for non-blocking visualization.
Parameters
----------
plotter : pv.Plotter
PyVista plotter instance ready for display.
"""
visutils.create_threaded_plot(plotter)
print("Plot opened in separate window. Terminal remains interactive.")
print("Note: Plot window may take a moment to appear.")
###############################################################################################
[docs]
def list_available_themes(self) -> None:
"""
Display all available themes with descriptions and previews.
Examples
--------
>>> plotter = BrainPlotter("configs.json")
>>> plotter.list_available_themes()
"""
themes = {
"dark": "Dark background with white text (default)",
"light": "Light background with dark text",
"high_contrast": "Maximum contrast for presentations",
"minimal": "Clean, minimal styling",
"publication": "Optimized for academic publications",
"colorful": "Vibrant colors for engaging visuals",
}
print("🎨 Available Themes:")
print("=" * 50)
for i, (theme_name, description) in enumerate(themes.items(), 1):
print(f"{i:2d}. {theme_name:12s} - {description}")
print("\n💡 Usage:")
print(" plotter.apply_theme('light') # Apply light theme")
print(" plotter.apply_theme('publication', auto_save=False) # Don't save")
print("=" * 50)
###############################################################################################
def _get_valid_views(self, views: Union[str, List[str]]) -> List[str]:
"""
Get valid view names from the provided views parameter.
Parameters
----------
views : str or List[str]
Either a single view name or a list of view names.
Returns
-------
List[str]
List of valid view names.
Raises
------
ValueError
If no valid views are found.
Examples
--------
>>> plotter = BrainPlotter("configs.json")
>>> valid_views = plotter._get_valid_views("8_views")
>>> print(valid_views)
['lateral', 'medial', 'dorsal', 'ventral', 'rostral', 'caudal']
"""
return visutils.get_valid_views(self, views)
###############################################################################################
###############################################################################################
[docs]
def apply_theme(self, theme_name: str, auto_save: bool = False) -> None:
"""
Apply predefined visual themes to quickly customize plot appearance.
Parameters
----------
theme_name : str
Name of the theme to apply. Available themes:
- "dark" : Dark background with white text
- "light" : Light background with dark text
- "high_contrast" : Maximum contrast for presentations
- "minimal" : Clean, minimal styling
- "publication" : Optimized for academic publications
- "colorful" : Vibrant colors for engaging visuals
auto_save : bool, default True
Whether to automatically save theme to configuration file.
Raises
------
ValueError
If theme_name is not recognized.
Examples
--------
>>> plotter = BrainPlotter("configs.json")
>>>
>>> # Apply light theme for presentations
>>> plotter.apply_theme("light")
>>>
>>> # Use high contrast for better visibility
>>> plotter.apply_theme("high_contrast")
>>>
>>> # Publication-ready styling
>>> plotter.apply_theme("publication")
"""
visutils.apply_theme(self, theme_name, auto_save)
###############################################################################################
[docs]
def save_config(self) -> None:
"""
Save current configuration (both figure_conf and views_conf) to JSON file.
Raises
------
IOError
If unable to write to configuration file.
Examples
--------
>>> plotter = BrainPlotter("configs.json")
>>> plotter.update_figure_config(background_color="white", auto_save=False)
>>> plotter.save_config() # Manually save changes
"""
visutils.save_config(self)
[docs]
def preview_theme(self, theme_name: str) -> None:
"""
Preview a theme's parameters without applying them.
Parameters
----------
theme_name : str
Name of the theme to preview.
Examples
--------
>>> plotter = BrainPlotter("configs.json")
>>> plotter.preview_theme("light") # See what light theme would change
"""
visutils.preview_theme(self, theme_name)
####################################################################################################
####################################################################################################
############ ############
############ ############
############ Useful functions ############
############ ############
############ ############
####################################################################################################
####################################################################################################
[docs]
def create_carpet_plot(
data: Union[np.ndarray, cltparc.RegionTimeSeries],
structure_names: list[str] = None,
*,
time_points: Optional[np.ndarray] = None,
tr: Optional[float] = None,
groups: Optional[dict[str, list[int]]] = None,
group_colors: Optional[list[str]] = None,
groups_title: str = None,
fd_trace: Optional[np.ndarray] = None,
global_signal: Optional[np.ndarray] = None,
normalize_rows: bool = True,
figsize: tuple[float, float] = (15, 10),
cmap: str = "RdBu_r",
unknown_color: Union[str, np.ndarray, tuple] = "#888888",
center_colormap: bool = True,
vmax: Optional[float] = None,
fd_threshold: float = 0.5,
show_structure_names: bool = True,
x_label: str = "Volume index",
y_label: str = "Brain structures",
title: str = "Carpet Plot",
save_path: Optional[Union[str, Path]] = None,
dpi: int = 150,
) -> dict:
"""Create a carpet plot for brain structure time series data.
Parameters
----------
data : np.ndarray, shape (n_structures, n_timepoints)
Time series data for each brain structure.
structure_names : list of str
Names of the brain structures. Must have length ``data.shape[0]``. Defaults to generic names if not provided.
time_points : np.ndarray, optional
Explicit time-point values for the x-axis. When *None*, volume
indices are used (scaled by *tr* if provided).
tr : float, optional
Repetition time in seconds. When provided and *time_points* is
*None*, the x-axis is expressed in seconds.
groups : dict[str, list[int | str]], optional
Mapping from group name to a list of row indices (int) **or**
substrings matched against *structure_names* (str), or a mix of
both. Rows not covered by any group are collected into an
"Unknown" group coloured by *unknown_color*. Example::
groups = {
"Cortex": list(range(60)),
"Subcortex": list(range(60, 75)),
"Cerebellum": ["Cereb", "Vermis"],
}
group_colors : list of str, optional
Hex colour strings for each group in the order they appear in
*groups*. Falls back to a built-in ten-colour palette.
groups_title : str, optional
Title for the groups colorbar. If *None*, defaults to "Groups".
unknown_color : str, default ``"#888888"``
Colour used for the automatic "Unknown" group that collects any
rows not assigned to an explicit group.
fd_trace : np.ndarray, optional
Framewise displacement trace (length n_timepoints). When provided
a motion panel is drawn above the carpet.
global_signal : np.ndarray, optional
Global signal trace (length n_timepoints). Overlaid on the top
panel as a blue line (on a twin y-axis when *fd_trace* is also
supplied).
normalize_rows : bool, default True
If *True*, each row is z-scored independently before plotting.
figsize : tuple of float, default (15, 10)
Figure size (width, height) in inches.
cmap : str, default ``"RdBu_r"``
Matplotlib colormap name.
center_colormap : bool, default True
If *True*, ``TwoSlopeNorm`` is used so zero maps to the midpoint.
vmax : float, optional
Colour scale maximum (absolute value). Derived from data when not
provided.
fd_threshold : float, default 0.5
FD threshold in mm shown as a dashed reference line.
show_structure_names : bool, default True
Whether to display y-axis structure labels.
x_label : str, default ``"Volume index"``
Label for the x-axis. If *time_points* is provided, this should be set to something like ``"Time (s)"``.
y_label : str, default ``"Brain structures"``
title : str, default ``"Carpet Plot"``
Plot title.
save_path : str or Path, optional
Destination path for the figure.
dpi : int, default 150
Resolution used when *save_path* is given.
Returns
-------
dict
``"fig"`` – ``matplotlib.figure.Figure``
``"ax_carpet"`` – main carpet ``Axes``
``"ax_top"`` – top panel ``Axes`` (*None* if not requested)
``"ax_strip"`` – *None* (kept for API stability)
``"im"`` – ``AxesImage`` returned by ``imshow``
``"cbar"`` – ``Colorbar`` object
Examples
--------
>>> import numpy as np
>>> from clabtoolkit.visualizationtools import create_carpet_plot
>>> >>> # Generate synthetic data
>>> n_structures = 100
>>> n_timepoints = 200
>>> data = np.random.randn(n_structures, n_timepoints)
>>> structure_names = [f"Struct_{i}" for i in range(n_structures)]
>>> >>> # Define groups
>>> groups = {
... "Group A": list(range(0, 50)),
... "Group B": list(range(50, 80)),
... "Group C": list(range(80, 100)),
... }
>>> group_colors = ["#1f77b4", "#ff7f0e", "#2ca02c"]
>>> >>> # Create carpet plot
>>> result = create_carpet_plot(
... data=data,
... structure_names=structure_names,
... groups=groups,
... group_colors=group_colors,
... normalize_rows=True,
... figsize=(12, 8),
... cmap="RdBu_r")
"""
# ------------------------------------------------------------------
# Validation
# ------------------------------------------------------------------
if isinstance(data, cltparc.RegionTimeSeries):
structure_names = data.region_names
data = data.data
if structure_names is None:
structure_names = cltmisc.create_names_from_indices(np.arange(data.shape[0]))
data = np.asarray(data, dtype=float)
if data.ndim != 2:
raise ValueError(
f"data must be 2-D (n_structures × n_timepoints), got shape {data.shape}."
)
n_structures, n_timepoints = data.shape
structure_names = list(structure_names)
if len(structure_names) != n_structures:
raise ValueError(
f"len(structure_names) ({len(structure_names)}) must equal "
f"data.shape[0] ({n_structures})."
)
for param_name, arr in [("fd_trace", fd_trace), ("global_signal", global_signal)]:
if arr is not None and len(arr) != n_timepoints:
raise ValueError(
f"{param_name} length ({len(arr)}) must equal n_timepoints ({n_timepoints})."
)
# ------------------------------------------------------------------
# Time axis
# ------------------------------------------------------------------
if time_points is None:
if tr is not None:
time_points = np.arange(1, n_timepoints + 1) * tr # vol 0 → 1×TR
else:
time_points = np.arange(n_timepoints, dtype=float)
time_points = np.asarray(time_points, dtype=float)
if x_label is None:
x_label = "Time (s)" if tr is not None else "Volume index"
# ------------------------------------------------------------------
# Groups: resolve indices, reorder rows, build sequential positions
# ------------------------------------------------------------------
# `_resolved` is an *ordered* list of dicts:
# { "name": str, "color": str, "rows": list[int] }
# where "rows" gives the NEW (post-reorder) consecutive row positions.
_resolved: list[dict] = []
if groups is not None:
_gcolors_base = group_colors or cltcol.create_distinguishable_colors(
len(groups), output_format="hex", exclude_colors=unknown_color
)
# ---- Step 1: resolve each group's original indices ----
covered: list[int] = []
raw_groups: list[dict] = [] # name, color, original_ids
# Track which original row indices have already been claimed so that
# overlapping assignments are resolved in favour of the first group.
already_claimed: set[int] = set()
for gi, (gname, id_list) in enumerate(groups.items()):
# ---- Normalise the value to a plain Python list ----
if isinstance(id_list, str):
# A bare string is treated as a single substring matcher
id_list = [id_list]
elif isinstance(id_list, np.ndarray):
id_list = id_list.tolist()
else:
id_list = list(id_list)
numbers = [
int(x) for x in id_list if isinstance(x, (int, float, np.integer))
]
strings = [x for x in id_list if isinstance(x, str)]
# Match substrings against structure_names
str_ids: list[int] = []
for substr in strings:
str_ids += [
i
for i, s in enumerate(structure_names)
if substr.lower() in s.lower()
]
# Deduplicate within this group, then remove any rows already
# claimed by an earlier group (first-seen wins).
candidate_ids = sorted(set(numbers + str_ids))
orig_ids = [i for i in candidate_ids if i not in already_claimed]
already_claimed.update(orig_ids)
covered.extend(orig_ids)
color = _gcolors_base[gi % len(_gcolors_base)]
raw_groups.append({"name": gname, "color": color, "orig_ids": orig_ids})
# ---- Step 2: collect ungrouped rows ----
ungrouped = sorted(set(range(n_structures)) - set(covered))
if ungrouped:
raw_groups.append(
{"name": "Unknown", "color": unknown_color, "orig_ids": ungrouped}
)
# ---- Step 3: build the final row order & remap to new positions ----
final_order: list[int] = []
for grp in raw_groups:
final_order.extend(grp["orig_ids"])
# Reorder data and names NOW — before z-scoring and imshow
data = data[final_order, :]
structure_names = [structure_names[i] for i in final_order]
n_structures = len(structure_names)
# Map each group's original ids → new consecutive row positions.
# Groups whose ids were entirely claimed by earlier groups are skipped
# so they never appear in rectangles or the legend.
orig_to_new = {orig: new for new, orig in enumerate(final_order)}
for grp in raw_groups:
if not grp["orig_ids"]: # fully emptied by overlap deduplication
continue
new_rows = sorted(orig_to_new[i] for i in grp["orig_ids"])
_resolved.append(
{"name": grp["name"], "color": grp["color"], "rows": new_rows}
)
# ------------------------------------------------------------------
# Per-row z-score normalisation (after reordering!)
# ------------------------------------------------------------------
if normalize_rows:
row_mean = data.mean(axis=1, keepdims=True)
row_std = data.std(axis=1, keepdims=True)
row_std[row_std == 0] = 1.0
plot_data = (data - row_mean) / row_std
cbar_label = "Z-score"
else:
plot_data = data.copy()
cbar_label = "Signal intensity"
# ------------------------------------------------------------------
# Colour normalisation
# ------------------------------------------------------------------
if center_colormap:
_vmax = float(vmax) if vmax is not None else float(np.abs(plot_data).max())
norm: Normalize = TwoSlopeNorm(vmin=-_vmax, vcenter=0.0, vmax=_vmax)
else:
norm = Normalize(
vmin=float(plot_data.min()),
vmax=float(vmax) if vmax is not None else float(plot_data.max()),
)
# ------------------------------------------------------------------
# Figure layout
# ------------------------------------------------------------------
has_top = fd_trace is not None or global_signal is not None
carpet_row = 1 if has_top else 0
fig = plt.figure(figsize=figsize)
gs = gridspec.GridSpec(
2 if has_top else 1,
1,
figure=fig,
height_ratios=[1, 5] if has_top else [1],
hspace=0.04,
)
ax_carpet: plt.Axes = fig.add_subplot(gs[carpet_row, 0])
ax_top: Optional[plt.Axes] = (
fig.add_subplot(gs[0, 0], sharex=ax_carpet) if has_top else None
)
# ------------------------------------------------------------------
# Carpet panel
# ------------------------------------------------------------------
im = ax_carpet.imshow(
plot_data,
aspect="auto",
cmap=cmap,
norm=norm,
interpolation="nearest",
)
# X-axis
n_xticks = min(10, n_timepoints)
xtick_idx = np.linspace(0, n_timepoints - 1, n_xticks, dtype=int).tolist()
ax_carpet.set_xticks(xtick_idx)
ax_carpet.set_xticklabels(
[
f"{time_points[i]:.1f}" if tr is not None else f"{int(time_points[i])}"
for i in xtick_idx
],
fontsize=9,
)
ax_carpet.set_xlabel(x_label, fontsize=11)
# Y-axis — plain black labels
if show_structure_names:
label_size = max(4, 8 - n_structures // 20)
ax_carpet.set_yticks(range(n_structures))
ax_carpet.set_yticklabels(structure_names, fontsize=label_size)
else:
ax_carpet.set_yticks([])
ax_carpet.set_ylabel(y_label, fontsize=11)
# Horizontal white dividers at group boundaries
if _resolved:
seen: set[float] = set()
for grp in _resolved:
if grp["rows"]:
boundary = max(grp["rows"]) + 0.5
if boundary not in seen and 0.0 < boundary < n_structures:
ax_carpet.axhline(boundary, color="white", lw=1.5, alpha=0.75)
seen.add(boundary)
# Colorbar
cbar = fig.colorbar(im, ax=ax_carpet, fraction=0.02, pad=0.01, aspect=40)
cbar.set_label(cbar_label, rotation=270, labelpad=14, fontsize=10)
cbar.ax.tick_params(labelsize=8)
# Title on the topmost axes
(ax_top if ax_top is not None else ax_carpet).set_title(
title, fontsize=14, fontweight="bold", pad=6
)
# ------------------------------------------------------------------
# Top panel (FD + global signal)
# ------------------------------------------------------------------
if ax_top is not None:
t_axis = np.arange(n_timepoints)
ax_gs_twin: Optional[plt.Axes] = None
if fd_trace is not None:
fd_arr = np.asarray(fd_trace, dtype=float)
ax_top.fill_between(
t_axis, 0, fd_arr, color="#e15759", alpha=0.25, label="_nolegend_"
)
ax_top.plot(t_axis, fd_arr, color="#e15759", lw=0.9, label="FD")
ax_top.axhline(
fd_threshold,
color="#e15759",
lw=0.8,
ls="--",
alpha=0.9,
label=f"FD = {fd_threshold} mm",
)
above = fd_arr > fd_threshold
if above.any():
ax_top.fill_between(
t_axis,
fd_arr,
fd_threshold,
where=above,
color="#e15759",
alpha=0.55,
label="_nolegend_",
)
ax_top.set_ylabel("FD (mm)", fontsize=9, color="#e15759")
ax_top.tick_params(axis="y", labelcolor="#e15759", labelsize=8)
if global_signal is not None:
gs_arr = np.asarray(global_signal, dtype=float)
ax_gs = ax_top.twinx() if fd_trace is not None else ax_top
if fd_trace is not None:
ax_gs_twin = ax_gs
ax_gs.plot(t_axis, gs_arr, color="#4e79a7", lw=0.9, label="GS")
ax_gs.set_ylabel("Global signal", fontsize=9, color="#4e79a7")
ax_gs.tick_params(axis="y", labelcolor="#4e79a7", labelsize=8)
handles, labels = ax_top.get_legend_handles_labels()
if ax_gs_twin is not None:
h2, l2 = ax_gs_twin.get_legend_handles_labels()
handles += h2
labels += l2
if handles:
ax_top.legend(
handles,
labels,
loc="upper right",
fontsize=8,
framealpha=0.6,
ncol=len(handles),
)
plt.setp(ax_top.get_xticklabels(), visible=False)
ax_top.tick_params(axis="x", bottom=False)
ax_top.spines["bottom"].set_visible(False)
ax_top.spines["top"].set_visible(False)
# ------------------------------------------------------------------
# Group colour rectangles — blended transform, no text, no overlap.
# x is in axes coordinates (0 = left spine, width ≈ 1 % of plot).
# y is in data coordinates (row indices).
# ------------------------------------------------------------------
if _resolved:
strip_width = 0.010
trans = blended_transform_factory(ax_carpet.transAxes, ax_carpet.transData)
for grp in _resolved:
rows = grp["rows"]
if not rows:
continue
ymin = min(rows) - 0.5
ymax = max(rows) + 0.5
rect = mpatches.Rectangle(
(0.0, ymin),
strip_width,
ymax - ymin,
transform=trans,
clip_on=False,
facecolor=grp["color"],
edgecolor="none",
zorder=5,
)
ax_carpet.add_patch(rect)
# ------------------------------------------------------------------
# Bottom figure legend — no frame, anchored well below x-axis label
# ------------------------------------------------------------------
if _resolved:
legend_patches = [
mpatches.Patch(facecolor=grp["color"], label=grp["name"])
for grp in _resolved
]
n_cols = min(len(_resolved), 6)
if groups_title is None:
fig.legend(
handles=legend_patches,
loc="lower center",
ncol=n_cols,
fontsize=10,
frameon=False,
bbox_to_anchor=(0.5, 0.0),
title_fontsize=10,
)
else:
fig.legend(
handles=legend_patches,
loc="lower center",
ncol=n_cols,
fontsize=10,
frameon=False,
bbox_to_anchor=(0.5, 0.0),
title=groups_title,
title_fontsize=10,
)
# ------------------------------------------------------------------
# Final layout — reserve bottom margin for the legend
# ------------------------------------------------------------------
fig.tight_layout()
if _resolved:
n_legend_rows = max(1, len(_resolved) // 6 + (1 if len(_resolved) % 6 else 0))
fig.subplots_adjust(bottom=0.08 + 0.04 * n_legend_rows)
if save_path is not None:
fig.savefig(Path(save_path), dpi=dpi, bbox_inches="tight")
return {
"fig": fig,
"ax_carpet": ax_carpet,
"ax_top": ax_top,
"ax_strip": None,
"im": im,
"cbar": cbar,
}