Source code for clabtoolkit.dwitools

import os
import numpy as np
from pathlib import Path
from typing import List, Union

import nibabel as nib

from skimage import measure

# Importing the internal modules
from . import misctools as cltmisc
from . import plottools as cltplot
from . import colorstools as cltcol

import pyvista as pv


####################################################################################################
####################################################################################################
############                                                                            ############
############                                                                            ############
############                Section 1: Methods to work with DWI images                  ############
############                                                                            ############
############                                                                            ############
####################################################################################################
####################################################################################################
[docs] def delete_dwi_volumes( in_image: str, bvec_file: str = None, bval_file: str = None, out_image: str = None, bvals_to_delete: Union[int, List[Union[int, tuple, list, str, np.ndarray]]] = None, vols_to_delete: Union[int, List[Union[int, tuple, list, str, np.ndarray]]] = None, ) -> str: """ Remove specific volumes from DWI image. If no volumes are specified, the function will remove the last B0s of the DWI image. Parameters ---------- in_image : str Path to the diffusion weighted image file. bvec_file : str, optional Path to the bvec file. If None, it will assume the bvec file is in the same directory as the DWI file with the same name but with the .bvec extension. bval_file : str, optional Path to the bval file. If None, it will assume the bval file is in the same directory as the DWI file with the same name but with the .bval extension. out_image : str, optional Path to the output file. If None, it will assume the output file is in the same directory as the DWI file with the same name but with the .nii.gz extension. The original file will be overwritten if the output file is not specified. bvals_to_delete : int, list, optional List of bvals to delete. If None, it will assume the bvals to delete are the last B0s of the DWI image. Some conditions could be used to delete the volumes. For example: 1. If you want to delete all the volumes with bval = 0, you can use: bvals_to_delete = [0] 2. If you want to delete all the volumes with b-values higher than 1000, you can use: bvals_to_delete = [bvals > 1000] or bvals_to_delete = [bvals >= 1000] if you want to include the 1000 bvals. 3. If you want to delete all the volumes with b-values between 1000 and 3000 you can use: bvals_to_delete = [1000 < bvals < 3000] or bvals_to_delete = [1000 <= bvals < 3000] if you want to include the 1000 but not the 3000 bvals. For more complex conditions, you can see the function get_indices_by_condition. Included in the clabtoolkit.misctools module. vols_to_delete : int, list, optional Indices of the volumes to delete. If None, it will assume the volumes to delete are the last B0s of the DWI image. Some conditions could be used to delete the volumes. For example: 1. If you want to delete the first 3 volumes, you can use: vols_to_delete = [0, 1, 2] 2. If you want to delete the volumes from 0 to 10, you can use: vols_to_delete = [0:10] or vols_to_delete = [0-10] 3. If you want to delete the volumes from 0 to 10 and 20 to 30, you can use: vols_to_delete = [0:10, 20:30] or vols_to_delete = [0-10, 20-30] 4. If you want to delete the volumes from 0 to 10 and the volumes 40 and 60, you can use: vols_to_delete = [0:10, 40, 60] or vols_to_delete = [0-10, 40, 60] or vols_to_delete = ['0-10, 40, 60'], etc For more complex conditions, you can see the function build_indices. Included in the clabtoolkit.misctools module. If both bvals_to_delete and vols_to_delete are specified, the function will remove the volumes with the bvals specified and the volumes specified in the vols_to_delete list. The function will unify all the indices in a single list and remove the volumes from the DWI image. Returns ------- out_image : str Path to the diffusion weighted image file. out_bvecs_file : str Path to the bvec file. If None, it will assume the bvec file is in the same directory as the DWI file with the same name but with the .bvec extension. out_bvals_file : str Path to the bval file. If None, it will assume the bval file is in the same directory as the DWI file with the same name but with the .bval extension. vols2rem : list List of volumes removed. Notes ----- IMPORTANT: The function will overwrite the original DWI file if the output file is not specified. IMPORTANT: The function will overwrite the original bvec and bval files if the output file is not specified. IMPORTANT: The function will remove the last B0s of the DWI image if no volumes are specified. Examples ----------- >>> delete_volumes('dwi.nii.gz') # will remove the last B0s. The original file will be overwritten. >>> delete_volumes('dwi.nii.gz', out_image='dwi_clean.nii.gz') # will remove the last B0s and save the output in dwi_clean.nii.gz >>> delete_volumes('dwi.nii.gz', vols_to_delete=[0, 1, 2]) # will remove the first 3 volumes >>> delete_volumes('dwi.nii.gz', bvec_file='dwi.bvec', bval_file='dwi.bval') # will remove the last B0s and it will assume the bvec and bval files are in the same directory as the DWI file. >>> delete_volumes('dwi.nii.gz', bvec_file='dwi.bvec', bval_file='dwi.bval', bvals_to_delete= [3000, "bvals >=5000"], out_image='dwi_clean.nii.gz') # will remove the volumes with bvals equal to 3000 and equal or higher than 5000. The output will be saved in in dwi_clean.nii.gz IMPORTANT: the b-values file dwi.bval should be in the same directory as the DWI file. """ # Creating the name for the json file if os.path.isfile(in_image): pth = os.path.dirname(in_image) fname = os.path.basename(in_image) else: raise FileNotFoundError(f"File {in_image} not found.") if fname.endswith(".nii.gz"): flname = fname[0:-7] elif fname.endswith(".nii"): flname = fname[0:-4] # Checking if the file exists. If it is None assume it is in the same directory with the same name as the DWI file but with the .bvec extensions. if bvec_file is None: bvec_file = os.path.join(pth, flname + ".bvec") # Checking if the file exists. If it is None assume it is in the same directory with the same name as the DWI file but with the .bval extensions. if bval_file is None: bval_file = os.path.join(pth, flname + ".bval") # Checking the ouput basename if out_image is not None: fl_out_name = os.path.basename(out_image) if fl_out_name.endswith(".nii.gz"): fl_out_name = fl_out_name[0:-7] elif fl_out_name.endswith(".nii"): fl_out_name = fl_out_name[0:-4] fl_out_path = os.path.dirname(out_image) if not os.path.isdir(fl_out_path): raise FileNotFoundError(f"Output path {fl_out_path} does not exist.") else: fl_out_name = fname fl_out_path = pth # Checking the volumes to delete if vols_to_delete is not None: if not isinstance(vols_to_delete, list): vols_to_delete = [vols_to_delete] vols_to_delete = cltmisc.build_indices(vols_to_delete, nonzeros=False) # Checking the bvals to delete. This variable will overwrite the vols_to_delete variable if it is not None. if bvals_to_delete is not None: if not isinstance(bvals_to_delete, list): bvals_to_delete = [bvals_to_delete] # Loading bvalues if os.path.exists(bval_file): bvals = np.loadtxt(bval_file, dtype=float, max_rows=5).astype(int) tmp_bvals = cltmisc.build_values_with_conditions( bvals_to_delete, bvals=bvals, nonzeros=False ) tmp_bvals_to_delete = np.where(np.isin(bvals, tmp_bvals))[0] if vols_to_delete is not None: vols_to_delete += tmp_bvals_to_delete.tolist() # Remove duplicates vols_to_delete = list(set(vols_to_delete)) else: vols_to_delete = tmp_bvals_to_delete if vols_to_delete is not None: # check if vols_to_delete is not empty if len(vols_to_delete) == 0: print(f"No volumes to delete. The volumes to delete are empty.") return in_image # Loading the DWI image mapI = nib.load(in_image) # getting the dimensions of the image dim = mapI.shape # Only remove the volumes is the image is 4D if len(dim) == 4: # Getting the number of volumes nvols = dim[3] if vols_to_delete is not None: if len(vols_to_delete) == nvols: # If the number of volumes to delete is equal to the number of volumes, send a warning and return the original file print( f"Number of volumes to delete is equal to the number of volumes. No volumes will be deleted." ) return in_image # Check if the volumes to delete are in the range of the number of volumes if np.max(vols_to_delete) >= nvols: # Detect which values of the list vols_to_delete are out of range # Convert the list to a numpy array vols_to_delete = np.array(vols_to_delete) # Check if the values are out of range out_of_range = np.where(vols_to_delete >= nvols)[0] # Raise an error with the out of range values raise ValueError( f"Volumes out of the range: {vols_to_delete[out_of_range]} . The values should be between 0 and {nvols-1}." ) # Check if the volumes to delete are in the range of the number of volumes if np.min(vols_to_delete) < 0: raise ValueError( f"Volumes to delete {vols_to_delete} are out of range. The values shoudl be between 0 and {nvols-1}." ) vols2rem = np.where(np.isin(np.arange(nvols), vols_to_delete))[0] vols2keep = np.where( np.isin(np.arange(nvols), vols_to_delete, invert=True) )[0] else: # Loading bvalues if os.path.exists(bval_file): bvals = np.loadtxt(bval_file, dtype=float, max_rows=5).astype(int) mask = bvals < 10 lb_bvals = measure.label(mask, 2) if np.max(lb_bvals) > 1 and lb_bvals[-1] != 0: # Removing the last cluster of B0s lab2rem = lb_bvals[-1] vols2rem = np.where(lb_bvals == lab2rem)[0] vols2keep = np.where(lb_bvals != lab2rem)[0] else: # Exit the function if there are no B0s to remove at the end of the volume. Leave a message. print("No B0s to remove at the end of the volume.") return in_image else: raise FileNotFoundError( f"File {bval_file} not found. It is mandatory if the volumes to remove are not specified (vols_to_delete)." ) diffData = mapI.get_fdata() affine = mapI.affine # Removing the volumes array_data = np.delete(diffData, vols2rem, 3) # Temporal image and diffusion scheme array_img = nib.Nifti1Image(array_data, affine) nib.save(array_img, out_image) # Saving new bvecs and new bvals if os.path.isfile(bvec_file): bvecs = np.loadtxt(bvec_file, dtype=float) if bvecs.shape[0] == 3: select_bvecs = bvecs[:, vols2keep] else: select_bvecs = bvecs[vols2keep, :] select_bvecs.transpose() if out_image.endswith("nii.gz"): out_bvecs_file = out_image.replace(".nii.gz", ".bvec") elif out_image.endswith("nii"): out_bvecs_file = out_image.replace(".nii", ".bvec") np.savetxt(out_bvecs_file, select_bvecs, fmt="%f") else: out_bvecs_file = None # Saving new bvals if os.path.isfile(bval_file): bvals = np.loadtxt(bval_file, dtype=float, max_rows=5).astype(int) select_bvals = bvals[vols2keep] select_bvals.transpose() if out_image.endswith("nii.gz"): out_bvals_file = out_image.replace(".nii.gz", ".bval") elif out_image.endswith("nii"): out_bvals_file = out_image.replace(".nii", ".bval") np.savetxt(out_bvals_file, select_bvals, newline=" ", fmt="%d") else: out_bvals_file = None else: vols2rem = None raise Warning(f"Image {in_image} is not a 4D image. No volumes to remove.") return out_image, out_bvecs_file, out_bvals_file, vols2rem
####################################################################################################
[docs] def get_b0s( dwi_img: str, b0s_img: str, bval_file: str = None, bval_thresh: int = 0 ) -> str: """ Extract B0 volumes from a DWI image and save them as a separate NIfTI file. Parameters ---------- dwi_img : str Path to the input DWI image file. b0s_img : str Path to the output B0 image file. bval_file : str, optional Path to the bval file. If None, it will assume the bval file is in the same directory as the DWI file with the same name but with the .bval extension. The bval file is used to identify the B0 volumes in the DWI image. bval_thresh : int, optional Threshold for identifying B0 volumes. Default is 0. Volumes with b-values below this threshold will be considered B0 volumes. Returns ------- b0s_img : str Path to the output B0 image file. b0_vols : List[int] List of indices of the B0 volumes extracted from the DWI image. Raises ------ FileNotFoundError If the input DWI image file or the bval file does not exist. ValueError If the output path for the B0 image file does not exist. Examples ----------- >>> dwi_img = 'path/to/dwi_image.nii.gz' >>> b0s_img = 'path/to/b0_image.nii.gz' >>> bval_file = 'path/to/bvals.bval' >>> b0s_img, b0_vols = get_b0s(dwi_img, b0s_img, bval_file) >>> print(f"B0 image saved at: {b0s_img}") >>> print(f"B0 volumes indices: {b0_vols}") >>> b0s_img, b0_vols = get_b0s(dwi_img, b0s_img, bval_file, bval_thresh=10) >>> print(f"B0 image saved at: {b0s_img}") >>> print(f"B0 volumes indices: {b0_vols}") >>> All the volumes with b-values below 10 will be considered B0 volumes. >>> b0s_img, b0_vols = get_b0s(dwi_img, b0s_img) >>> print(f"B0 image saved at: {b0s_img}") >>> print(f"B0 volumes indices: {b0_vols}") >>> The bval file will be assumed to be in the same directory as the DWI file with the same name but with the .bval extension. """ # Creating the name for the json file if os.path.isfile(dwi_img): pth = os.path.dirname(dwi_img) fname = os.path.basename(dwi_img) else: raise FileNotFoundError(f"File {dwi_img} not found.") if fname.endswith(".nii.gz"): flname = fname[0:-7] elif fname.endswith(".nii"): flname = fname[0:-4] # Checking if the file exists. If it is None assume it is in the same directory with the same name as the DWI file but with the .bval extensions. if bval_file is None: bval_file = os.path.join(pth, flname + ".bval") # Checking the ouput basename if b0s_img is not None: fl_out_name = os.path.basename(b0s_img) if fl_out_name.endswith(".nii.gz"): fl_out_name = fl_out_name[0:-7] elif fl_out_name.endswith(".nii"): fl_out_name = fl_out_name[0:-4] fl_out_path = os.path.dirname(b0s_img) if not os.path.isdir(fl_out_path): raise FileNotFoundError(f"Output path {fl_out_path} does not exist.") else: fl_out_name = fname fl_out_path = pth # Loading bvalues if os.path.exists(bval_file): bvals = np.loadtxt(bval_file, dtype=float, max_rows=5).astype(int) # Generate search cad cad = ["bvals > " + str(bval_thresh)] # Get the indices of the volumes that will be removed vols2rem = cltmisc.build_indices_with_conditions( cad, bvals=bvals, nonzeros=False ) b0_vols = np.setdiff1d(np.arange(bvals.shape[0]), vols2rem) if len(vols2rem) == 0: print(f"No B0s to remove. The volumes to delete are empty.") return dwi_img else: mapI = nib.load(dwi_img) diffData = mapI.get_fdata() affine = mapI.affine # Removing the volumes array_data = np.delete(diffData, vols2rem, 3) # Temporal image and diffusion scheme array_img = nib.Nifti1Image(array_data, affine) nib.save(array_img, b0s_img) return b0s_img, b0_vols
############################################################################################################
[docs] def maps_from_tensor_eigenvalues( eigvals: Union[str, Path, list, tuple], out_basename: Union[str, Path], dtmaps: list = ["all"], overwrite: bool = False, ) -> dict: """ Compute scalar maps derived from diffusion tensor eigenvalues. Eigenvalues can be supplied either as a single 4D NIfTI image (volumes ordered as λ1, λ2, λ3 along the 4th axis) or as a list/tuple of three separate 3D NIfTI files ``[l1_path, l2_path, l3_path]``. Division-by-zero voxels are handled safely: whenever the denominator is zero the result at that voxel is set to 0. Parameters ---------- eigvals : str or list/tuple of str Path to a 4D eigenvalue NIfTI image **or** a list/tuple of three paths to the individual eigenvalue volumes ``[λ1, λ2, λ3]``. out_basename : str Full path prefix for the output files. The map tag and ``.nii.gz`` extension are appended automatically (e.g. ``/path/sub-01_desc-DTI`` → ``/path/sub-01_desc-DTI_FA.nii.gz``). dtmaps : list of str, optional Scalar maps to compute. Use ``['all']`` (default) to compute every supported map. Supported tags (case-insensitive): ======== ============================================== Tag Description ======== ============================================== ``AD`` Axial Diffusivity (λ1) ``RD`` Radial Diffusivity ((λ2 + λ3) / 2) ``MD`` Mean Diffusivity ((λ1 + λ2 + λ3) / 3) ``FA`` Fractional Anisotropy ``CL`` Linear Anisotropy Coefficient ``CP`` Planar Anisotropy Coefficient ``CS`` Spherical Anisotropy Coefficient ``VF`` Volume Fraction ``GA`` Geodesic Anisotropy ``RA`` Relative Anisotropy ======== ============================================== overwrite : bool, optional If ``True``, recompute and overwrite existing output files. Default is ``False``. Returns ------- dict Dictionary mapping each requested tag to the path of the saved NIfTI file, or to an empty string if the file could not be created. Raises ------ ValueError If ``eigvals`` is not a str, list, or tuple; or if a list/tuple does not contain exactly three elements. FileNotFoundError If any of the supplied eigenvalue paths do not exist. Examples -------- >>> # 4D eigenvalue image >>> maps = compute_scalar_maps_from_tensor( ... "sub-01_eigvals.nii.gz", ... "out/sub-01", ... dtmaps=["FA", "MD"], ... ) >>> # Three separate eigenvalue files >>> maps = compute_scalar_maps_from_tensor( ... ["sub-01_l1.nii.gz", "sub-01_l2.nii.gz", "sub-01_l3.nii.gz"], ... "out/sub-01", ... dtmaps=["all"], ... ) """ # ------------------------------------------------------------------ # # Input validation and eigenvalue loading # ------------------------------------------------------------------ # # Normalize Path objects to strings if isinstance(out_basename, Path): out_basename = str(out_basename) if isinstance(eigvals, Path): eigvals = str(eigvals) elif isinstance(eigvals, (list, tuple)): eigvals = [str(p) if isinstance(p, Path) else p for p in eigvals] if isinstance(eigvals, str): if not os.path.isfile(eigvals): raise FileNotFoundError(f"Eigenvalue file not found: {eigvals}") ref_img = nib.load(eigvals) data4d = ref_img.get_fdata() if data4d.ndim != 4 or data4d.shape[3] < 3: raise ValueError( "4D eigenvalue image must have at least 3 volumes along the 4th axis." ) affine = ref_img.affine l1_data = data4d[..., 0] l2_data = data4d[..., 1] l3_data = data4d[..., 2] elif isinstance(eigvals, (list, tuple)): if len(eigvals) != 3: raise ValueError( "When supplying separate eigenvalue files, exactly 3 paths are required " f"(got {len(eigvals)})." ) missing = [p for p in eigvals if not os.path.isfile(p)] if missing: raise FileNotFoundError(f"Eigenvalue file(s) not found: {missing}") ref_img = nib.load(eigvals[0]) affine = ref_img.affine l1_data = ref_img.get_fdata() l2_data = nib.load(eigvals[1]).get_fdata() l3_data = nib.load(eigvals[2]).get_fdata() else: raise ValueError( "'eigvals' must be a path string or a list/tuple of three path strings." ) # ------------------------------------------------------------------ # # Output directory # ------------------------------------------------------------------ # out_path = os.path.dirname(out_basename) if out_path and not os.path.isdir(out_path): # If the output directory does not exist, raise an error raise FileNotFoundError(f"Output directory does not exist: {out_path}") # ------------------------------------------------------------------ # # Helpers # ------------------------------------------------------------------ # dtmaps = [x.lower() for x in dtmaps] compute_all = dtmaps[0] == "all" def _safe_div(num: np.ndarray, den: np.ndarray) -> np.ndarray: """Element-wise division; returns 0 where denominator is zero.""" return np.divide( num, den, out=np.zeros_like(num, dtype=np.float64), where=den != 0 ) def _safe_log(arr: np.ndarray) -> np.ndarray: """Element-wise natural log; returns 0 where arr <= 0.""" return np.where(arr > 0, np.log(np.maximum(arr, np.finfo(float).tiny)), 0.0) def _save_map(data: np.ndarray, tag: str) -> str: """NaN-fill → save → return path (or '' on failure).""" fpath = f"{out_basename}_{tag}.nii.gz" data = np.where(np.isnan(data), 0.0, data) nib.save(nib.Nifti1Image(data, affine), fpath) return fpath if os.path.isfile(fpath) else "" # Pre-compute quantities shared across multiple maps suma = l1_data + l2_data + l3_data # used by CL, CP, CS, GA RD_data = (l2_data + l3_data) / 2 # used by RD, VF, RA MD_data = suma / 3 # used by MD, FA scalar_maps: dict = {} # ------------------------------------------------------------------ # # AD — Axial Diffusivity # ------------------------------------------------------------------ # if "ad" in dtmaps or compute_all: fpath = f"{out_basename}_AD.nii.gz" if not os.path.isfile(fpath) or overwrite: scalar_maps["AD"] = _save_map(l1_data.copy(), "AD") else: scalar_maps["AD"] = fpath # ------------------------------------------------------------------ # # RD — Radial Diffusivity # ------------------------------------------------------------------ # if "rd" in dtmaps or compute_all: fpath = f"{out_basename}_RD.nii.gz" if not os.path.isfile(fpath) or overwrite: scalar_maps["RD"] = _save_map(RD_data, "RD") else: scalar_maps["RD"] = fpath # ------------------------------------------------------------------ # # MD — Mean Diffusivity # ------------------------------------------------------------------ # if "md" in dtmaps or compute_all: fpath = f"{out_basename}_MD.nii.gz" if not os.path.isfile(fpath) or overwrite: scalar_maps["MD"] = _save_map(MD_data, "MD") else: scalar_maps["MD"] = fpath # ------------------------------------------------------------------ # # FA — Fractional Anisotropy # ------------------------------------------------------------------ # if "fa" in dtmaps or compute_all: fpath = f"{out_basename}_FA.nii.gz" if not os.path.isfile(fpath) or overwrite: num = ( (l1_data - MD_data) ** 2 + (l2_data - MD_data) ** 2 + (l3_data - MD_data) ** 2 ) den = l1_data**2 + l2_data**2 + l3_data**2 FA = np.sqrt(0.5 * _safe_div(num, den)) scalar_maps["FA"] = _save_map(FA, "FA") else: scalar_maps["FA"] = fpath # ------------------------------------------------------------------ # # CL — Linear Anisotropy Coefficient # ------------------------------------------------------------------ # if "cl" in dtmaps or compute_all: fpath = f"{out_basename}_CL.nii.gz" if not os.path.isfile(fpath) or overwrite: CL = _safe_div(l1_data - l2_data, suma) scalar_maps["CL"] = _save_map(CL, "CL") else: scalar_maps["CL"] = fpath # ------------------------------------------------------------------ # # CP — Planar Anisotropy Coefficient # ------------------------------------------------------------------ # if "cp" in dtmaps or compute_all: fpath = f"{out_basename}_CP.nii.gz" if not os.path.isfile(fpath) or overwrite: CP = _safe_div(2 * (l2_data - l3_data), suma) scalar_maps["CP"] = _save_map(CP, "CP") else: scalar_maps["CP"] = fpath # ------------------------------------------------------------------ # # CS — Spherical Anisotropy Coefficient # ------------------------------------------------------------------ # if "cs" in dtmaps or compute_all: fpath = f"{out_basename}_CS.nii.gz" if not os.path.isfile(fpath) or overwrite: CS = _safe_div(3 * l3_data, suma) scalar_maps["CS"] = _save_map(CS, "CS") else: scalar_maps["CS"] = fpath # ------------------------------------------------------------------ # # VF — Volume Fraction # ------------------------------------------------------------------ # if "vf" in dtmaps or compute_all: fpath = f"{out_basename}_VF.nii.gz" if not os.path.isfile(fpath) or overwrite: product = l1_data * l2_data * l3_data den = RD_data**3 VF = 1 - _safe_div(product, den) scalar_maps["VF"] = _save_map(VF, "VF") else: scalar_maps["VF"] = fpath # ------------------------------------------------------------------ # # GA — Geodesic Anisotropy # ------------------------------------------------------------------ # if "ga" in dtmaps or compute_all: fpath = f"{out_basename}_GA.nii.gz" if not os.path.isfile(fpath) or overwrite: D = (l1_data * l2_data * l3_data) ** (1 / 3) log_sum_sq = ( _safe_log(l1_data) ** 2 + _safe_log(l2_data) ** 2 + _safe_log(l3_data) ** 2 ) GA = np.sqrt(_safe_div(log_sum_sq, D)) scalar_maps["GA"] = _save_map(GA, "GA") else: scalar_maps["GA"] = fpath # ------------------------------------------------------------------ # # RA — Relative Anisotropy # ------------------------------------------------------------------ # if "ra" in dtmaps or compute_all: fpath = f"{out_basename}_RA.nii.gz" if not os.path.isfile(fpath) or overwrite: num = ( (l1_data - RD_data) ** 2 + (l2_data - RD_data) ** 2 + (l3_data - RD_data) ** 2 ) RA = np.sqrt(_safe_div(num / 3, suma)) scalar_maps["RA"] = _save_map(RA, "RA") else: scalar_maps["RA"] = fpath return scalar_maps
#################################################################################################### #################################################################################################### ############ ############ ############ ############ ############ Section 2: Class to work with Diffusion Schemes ############ ############ ############ ############ ############ #################################################################################################### ####################################################################################################
[docs] class DiffusionScheme:
[docs] def __init__(self): self.gradients = None # (N, 3) self.bvals = None # (N,) self.scheme_type = None
# ------------------------- # Loaders # -------------------------
[docs] @classmethod def from_bvec_bval_files(cls, bvec_file, bval_file): bvecs = np.loadtxt(bvec_file) bvals = np.loadtxt(bval_file) return cls.from_bvec_bval_arrays(bvecs, bvals)
[docs] @classmethod def from_bvec_bval_arrays(cls, bvecs, bvals): obj = cls() bvecs = np.asarray(bvecs) bvals = np.asarray(bvals) if bvecs.shape[0] == 3: bvecs = bvecs.T obj.gradients = bvecs obj.bvals = bvals.flatten() # Ensure 1D array obj._detect_scheme() return obj
[docs] @classmethod def from_bmatrix_file(cls, bmat_file): bmat = np.loadtxt(bmat_file) return cls.from_bmatrix_array(bmat)
[docs] @classmethod def from_bmatrix_array(cls, bmat): """ bmat shape: (N, 6) with: [Bxx, Byy, Bzz, Bxy, Bxz, Byz] """ obj = cls() bmat = np.asarray(bmat) Bxx, Byy, Bzz, Bxy, Bxz, Byz = bmat.T obj.bvals = Bxx + Byy + Bzz gradients = np.vstack( [ np.sqrt(np.maximum(Bxx, 0)), np.sqrt(np.maximum(Byy, 0)), np.sqrt(np.maximum(Bzz, 0)), ] ).T # Normalize norms = np.linalg.norm(gradients, axis=1) norms[norms == 0] = 1 obj.gradients = gradients / norms[:, None] obj._detect_scheme() return obj
# ------------------------- # Scheme detection # ------------------------- def _detect_scheme(self): """ Detect if the acquisition scheme is shelled (HARDI/multi-shell) or cartesian (DSI). Detection criteria: - Shelled: Few discrete b-value shells (typically 2-6) with many samples per shell - Cartesian (DSI): Many b-value shells (typically >8) with samples distributed across wide range """ b0_thresh = 50 # Threshold to identify b0 images # Get non-b0 values non_b0_bvals = self.bvals[self.bvals > b0_thresh] if len(non_b0_bvals) == 0: self.scheme_type = "b0_only" return # Round to nearest 100 to account for small variations rounded_bvals = np.round(non_b0_bvals, -2) unique_shells = np.unique(rounded_bvals) n_shells = len(unique_shells) # Calculate distribution metrics mean_bval = non_b0_bvals.mean() std_bval = non_b0_bvals.std() cv = std_bval / mean_bval # Coefficient of variation # Count samples per shell samples_per_shell = [] for shell in unique_shells: n_samples = np.sum(np.abs(rounded_bvals - shell) < 50) samples_per_shell.append(n_samples) mean_samples_per_shell = np.mean(samples_per_shell) # Decision criteria # Shelled data typically has: # - Few shells (2-6) # - Many samples per shell (>10) # - Lower coefficient of variation (<0.4) # # DSI data typically has: # - Many shells (>8) # - Few samples per shell (<15) # - Higher coefficient of variation (>0.35) if n_shells <= 6 and mean_samples_per_shell > 10: self.scheme_type = "shelled" elif n_shells > 8 and cv > 0.35: self.scheme_type = "cartesian" else: # Borderline case - use number of shells as primary criterion if n_shells <= 6: self.scheme_type = "shelled" else: self.scheme_type = "cartesian" # ------------------------- # Visualization # -------------------------
[docs] def plot( self, show=True, use_notebook: bool = False, radius: float = 10.0, colormap: str = "jet", toroid_radius: float = None, toroid_alpha: float = 0.3, b0_thresh: float = 10.0, show_colorbar: bool = True, show_axes: bool = True, show_opposite_dirs: bool = True, ): g = self.gradients b = self.bvals if self.scheme_type is None: self._detect_scheme() # Apply appropriate coordinate transformation if self.scheme_type == "shelled": # HARDI: scale by b-value coords = g * b[:, None] else: # DSI: Coords = max(bvals)*grads.*sqrt(bvals/max(bvals)) # This simplifies to: grads * sqrt(bvals * max(bvals)) b_max = b.max() coords = g * np.sqrt(b[:, None] * b_max) # FIGURE CONFIGURATION figure_conf = { "background_color": "black", "title_font_color": "white", "colorbar_font_color": "white", "title_font_type": "arial", "title_font_size": 10, "title_shadow": True, "mesh_ambient": 0.2, "mesh_diffuse": 0.5, "mesh_specular": 0.5, "mesh_specular_power": 15, "mesh_smooth_shading": True, } rgba_data = cltcol.values2colors( b, cmap=colormap, vmin=b.min(), vmax=b.max(), output_format="rgb" ) # Optionally add opposite directions (mirror across origin) if show_opposite_dirs: coords = np.vstack([coords, -coords]) rgba_data = np.vstack([rgba_data, rgba_data]) b = np.concatenate([b, b]) # Detecting the screen size for the plotter screen_size = cltplot.get_current_monitor_size() # Create PyVista plotter with appropriate rendering mode plotter_kwargs = { "notebook": use_notebook, "window_size": [screen_size[0], screen_size[1]], } pv_plotter = pv.Plotter(**plotter_kwargs) pv_plotter.set_background(figure_conf["background_color"]) # Add gradient points as spheres pv_plotter.add_points( coords, render_points_as_spheres=True, point_size=radius, scalars=rgba_data, rgb=True, ambient=figure_conf["mesh_ambient"], diffuse=figure_conf["mesh_diffuse"], specular=figure_conf["mesh_specular"], specular_power=figure_conf["mesh_specular_power"], smooth_shading=figure_conf["mesh_smooth_shading"], show_scalar_bar=False, ) # Add center sphere for b0 pv_plotter.add_points( np.array([[0, 0, 0]]), render_points_as_spheres=True, point_size=radius, color="white", ambient=figure_conf["mesh_ambient"], diffuse=figure_conf["mesh_diffuse"], specular=figure_conf["mesh_specular"], specular_power=figure_conf["mesh_specular_power"], smooth_shading=figure_conf["mesh_smooth_shading"], ) # Add toroidal shells at each unique b-value # Use original b-values (not duplicated) for toroids original_bvals = self.bvals unique_bvals = np.unique(original_bvals) unique_bvals = unique_bvals[unique_bvals > b0_thresh] # Exclude b0 # Get colors for each unique b-value unique_colors = cltcol.values2colors( unique_bvals, cmap=colormap, vmin=original_bvals.min(), vmax=original_bvals.max(), output_format="rgb", ) # Auto-calculate toroid tube radius if not provided if toroid_radius is None: b_max = original_bvals.max() toroid_radius = b_max * 0.005 if b_max > 0 else 0.5 for bval, color in zip(unique_bvals, unique_colors): torus = pv.ParametricTorus( ringradius=bval, crosssectionradius=toroid_radius ) pv_plotter.add_mesh( torus, color=color, opacity=toroid_alpha, ambient=figure_conf["mesh_ambient"], diffuse=figure_conf["mesh_diffuse"], specular=figure_conf["mesh_specular"], specular_power=figure_conf["mesh_specular_power"], smooth_shading=figure_conf["mesh_smooth_shading"], ) # Add coordinate axes if show_axes: b_max = original_bvals.max() axis_length = b_max * 1.1 # X-axis pv_plotter.add_lines( np.array([[-axis_length, 0, 0], [axis_length, 0, 0]]), color="white", width=2, ) # Y-axis pv_plotter.add_lines( np.array([[0, -axis_length, 0], [0, axis_length, 0]]), color="white", width=2, ) # Z-axis pv_plotter.add_lines( np.array([[0, 0, -axis_length], [0, 0, axis_length]]), color="white", width=2, ) # Add colorbar - vertical on the right if show_colorbar: # Create a dummy mesh for the colorbar dummy_mesh = pv.PolyData(coords) dummy_mesh["bvalues"] = b pv_plotter.add_mesh( dummy_mesh, scalars="bvalues", cmap=colormap, show_edges=False, opacity=0, # Make it invisible scalar_bar_args={ "title": "b-value (s/mm²)", "title_font_size": 20, "label_font_size": 16, "color": "white", "position_x": 0.90, # Far right "position_y": 0.25, # Vertically centered "width": 0.08, # Narrow width for vertical bar "height": 0.5, # Tall for vertical orientation "vertical": True, # Explicitly vertical "n_labels": 5, # Number of labels "fmt": "%.0f", # Format as integers }, ) # Count b0 and DWI images (use original counts) n_b0s = np.sum(original_bvals <= b0_thresh) n_dwi = len(original_bvals) - n_b0s # Add title pv_plotter.add_text( f"q-Space Plot: {n_b0s} B0 Images and {n_dwi} Diffusion Images\nScheme: {self.scheme_type}", position="upper_edge", font_size=14, color="white", font="arial", ) # Set camera and lighting pv_plotter.add_light(pv.Light(position=(1, 1, 1))) pv_plotter.view_isometric() if show: pv_plotter.show() return pv_plotter