"""
Functions for localizing and plotting intracranial electrodes on the freesurfer average brain.
"""
import os
import warnings
from os.path import join as pjoin
import numpy as np
from nibabel.freesurfer.io import read_geometry, read_label, read_morph_data, read_annot
from scipy.spatial.distance import cdist
from skspatial.objects import Line, Plane
from hdf5storage import loadmat
from naplib.utils import dist_calc, load_freesurfer_label
from naplib import logger
# Ignore any warning whose message starts with "nopython". ``append=True``
# places this filter at the end of the filter list.
warnings.filterwarnings(action="ignore", message="nopython", append=True)

# Accepted hemisphere identifiers and cortical surface types.
HEMIS = ("lh", "rh")
SURF_TYPES = ("pial", "inflated")

# Custom (non-atlas) region names added to the Destrieux label table,
# numbered consecutively from 76 through 83.
num2region_D_custom = dict(
    enumerate(
        (
            "O_pmHG",
            "O_alHG",
            "O_Te10",
            "O_Te11",
            "O_Te12",
            "O_mSTG",
            "O_pSTG",
            "O_IFG",
        ),
        start=76,
    )
)

# Custom (non-atlas) region name added to the Desikan-Killiany label table.
num2region_DK_custom = {42: "O_IFG"}
class Hemisphere:
def __init__(
    self,
    hemi: str,
    surf_type: str = "pial",
    subject: str = "fsaverage",
    coordinate_space: str = 'FSAverage',
    atlas=None,
    subject_dir=None,
):
    """
    Hemisphere object

    Loads the surface geometry, cortex mask, sulcal depth and atlas labels of
    one hemisphere and initializes an empty overlay.

    Parameters
    ----------
    hemi : str
        Either 'lh' or 'rh'.
    surf_type : str, default='pial'
        Cortical surface type. Must be one of ``SURF_TYPES``
        ('pial' or 'inflated').
    subject : str, default='fsaverage'
        Subject to use, must be a directory within ``subject_dir``
    coordinate_space : str, default='FSAverage'
        Coordinate space of brain vertices. Must be 'FSAverage' or 'MNI152'.
        If 'FSAverage' is requested but ``{hemi}.{surf_type}`` is not found in
        the subject's ``surf`` directory, 'MNI152' is tried instead.
    atlas : str, default=None
        Atlas for brain parcellation. Defaults to 'Destrieux' for coordinate_space='FSAverage'
        and 'Desikan-Killiany' for 'MNI152'. Can also be an annotation file name given by
        ``{subject_dir}/{subject}/label/?h.{atlas}.annot``
    subject_dir : str/path-like, defaults to the SUBJECTS_DIR environment variable, or the
        current directory if that is not set.
        Path containing the subject's folder.

    Raises
    ------
    ValueError
        If `hemi`, `surf_type`, `coordinate_space` or `atlas` is not supported.
    """
    if hemi not in HEMIS:
        raise ValueError(f"Argument `hemi` should be in {HEMIS}.")
    if surf_type not in SURF_TYPES:
        raise ValueError(f"Argument `surf_type` should be in {SURF_TYPES}.")
    # Fail fast: reject unsupported coordinate spaces before any file access.
    if coordinate_space not in ('FSAverage', 'MNI152'):
        raise ValueError(f"Argument `coordinate_space`={coordinate_space} not implemented.")

    if not atlas:
        # Destrieux for FSAverage, Desikan-Killiany for MNI152.
        atlas = 'Destrieux' if coordinate_space == 'FSAverage' else 'Desikan-Killiany'

    self.hemi = hemi
    self.surf_type = surf_type
    self.subject = subject
    self.coordinate_space = coordinate_space

    if subject_dir is None:
        subject_dir = os.environ.get("SUBJECTS_DIR", "./")
    self.subject_dir = subject_dir

    # A custom atlas name is only accepted when its annotation file exists.
    if atlas not in ['Desikan-Killiany', 'Destrieux'] and not os.path.exists(
        self.label_file(f'{self.hemi}.{atlas}.annot')
    ):
        raise ValueError('Bad atlas. Try "Desikan-Killiany" or "Destrieux"')
    self.atlas = atlas

    # Check if fsaverage geometry exists
    if self.coordinate_space == 'FSAverage':
        if os.path.exists(self.surf_file(f"{hemi}.{surf_type}")):
            self.surf = read_geometry(self.surf_file(f"{hemi}.{surf_type}"))
            self.surf_pial = read_geometry(self.surf_file(f"{hemi}.pial"))
        else:
            self.coordinate_space = 'MNI152'
            logger.warning('Trying MNI152 coordinate space')

    # Use MNI152 coordinate space if not
    if self.coordinate_space == 'MNI152':
        # The MNI152 surface is read from a .mat file with 'coords' and 'faces'.
        surf_ = loadmat(self.surf_file(f"{hemi}_pial.mat"))
        coords, faces = surf_['coords'], surf_['faces']
        faces -= 1  # make faces zero-indexed
        self.surf = (coords, faces)
        self.surf_pial = (coords, faces)

    # Best effort: a missing/unreadable cortex label means "everything is cortex".
    try:
        self.cort = np.sort(read_label(self.label_file(f"{hemi}.cortex.label")))
    except Exception:
        logger.warning(f'No {hemi}.cortex.label file found. Assuming the entire surface is cortex.')
        self.cort = np.arange(self.surf[0].shape[0])

    # Best effort: sulcal depth is optional.
    try:
        self.sulc = read_morph_data(self.surf_file(f"{hemi}.sulc"))
    except Exception:
        logger.warning(f'No {hemi}.sulc file found. No sulcus information will be used.')
        self.sulc = None
    self.sulc_alpha = 1.0

    self.load_labels()
    self.reset_overlay()
@property
def coords(self):
    """Vertex coordinate array, i.e. the first element of ``self.surf``."""
    vertex_coords = self.surf[0]
    return vertex_coords
@property
def n_verts(self):
    """Number of vertices (rows of the first element of ``self.surf``)."""
    vertex_coords = self.surf[0]
    return vertex_coords.shape[0]
@property
def trigs(self):
    """Triangle (face) index array, i.e. the second element of ``self.surf``."""
    faces = self.surf[1]
    return faces
@property
def n_trigs(self):
    """Number of triangles (rows of the second element of ``self.surf``)."""
    faces = self.surf[1]
    return faces.shape[0]
@property
def label_names(self):
    """
    Names of the labels that are assigned to at least one vertex.

    Returns
    -------
    names : list of str
        One name per distinct value in ``self.labels``, ordered by
        increasing label number so the result is deterministic.
    """
    # np.unique gives the distinct label numbers in sorted order.
    return [self.num2label[num] for num in np.unique(self.labels)]
def surf_file(self, file: str):
    """Return the path of `file` inside this subject's ``surf`` directory."""
    subject_root = pjoin(self.subject_dir, self.subject)
    return pjoin(subject_root, "surf", file)
def label_file(self, file: str):
    """Return the path of `file` inside this subject's ``label`` directory."""
    subject_root = pjoin(self.subject_dir, self.subject)
    return pjoin(subject_root, "label", file)
def other_file(self, file: str):
    """Return the path of `file` inside this subject's ``other`` directory."""
    subject_root = pjoin(self.subject_dir, self.subject)
    return pjoin(subject_root, "other", file)
def load_labels(self):
    """
    Load the atlas label of each vertex from annotation files.

    The annotation file is chosen from ``self.coordinate_space`` and
    ``self.atlas``. This (re)builds ``self.ignore``, ``self.labels``,
    ``self.num2label`` and ``self.label2num`` and clears the ``simplified``
    and ``is_mangled_*`` flags.

    Returns
    -------
    self : instance of self

    Raises
    ------
    ValueError
        If the coordinate space is unknown or the annotation file is missing.
    """
    # Vertices of "Unknown"/"Medial_wall" in the a2005s annotation (when that
    # file is present) are flagged as ignored.
    self.ignore = np.zeros(self.n_verts, dtype=bool)
    legacy_annot = self.label_file(f"{self.hemi}.aparc.a2005s.annot")
    if os.path.exists(legacy_annot):
        for region in ("Unknown", "Medial_wall"):
            self.ignore[load_freesurfer_label(legacy_annot, region)] = True

    self.labels = np.zeros(self.n_verts, dtype=int)

    space, atlas = self.coordinate_space, self.atlas
    if space not in ('MNI152', 'FSAverage'):
        raise ValueError('Bad coordinate space')
    if space == 'MNI152' and atlas == 'Desikan-Killiany':
        annot_name = f"FSL_MNI152.{self.hemi}.aparc.split_STG_MTG.annot"
    elif space == 'FSAverage' and atlas == 'Desikan-Killiany':
        annot_name = f"{self.hemi}.aparc.annot"
    elif space == 'FSAverage' and atlas == 'Destrieux':
        annot_name = f"{self.hemi}.aparc.a2009s.annot"
    else:
        # Any other atlas name is treated as an annotation file stem.
        annot_name = f'{self.hemi}.{self.atlas}.annot'
    annot_file = self.label_file(annot_name)

    if not os.path.exists(annot_file):
        raise ValueError('Bad atlas. Try "Desikan-Killiany" or "Destrieux".')

    # Region names are numbered by their position in the annotation file.
    region_names = read_annot(annot_file)[2]
    num2region = {idx: raw.decode("utf-8") for idx, raw in enumerate(region_names)}
    for idx, name in num2region.items():
        self.labels[load_freesurfer_label(annot_file, name)] = idx

    # Register the custom region names so that later split/join operations
    # have label numbers to assign.
    if atlas == 'Destrieux':
        num2region.update(num2region_D_custom)
    elif atlas == 'Desikan-Killiany' and space == 'MNI152':
        num2region.update(num2region_DK_custom)

    self.labels[self.ignore] = 0
    self.num2label = num2region
    self.label2num = {name: num for num, name in self.num2label.items()}

    self.simplified = False
    self.is_mangled_hg = False
    self.is_mangled_tts = False
    self.is_mangled_stg = False
    self.is_mangled_ifg = False
    return self
def simplify_labels(self):
    """
    Simplify Destrieux and Desikan-Killiany labels into shortforms.

    Every vertex whose current label is not covered by a shortform gets
    simplified label number 0. Source labels that are not registered in
    ``self.label2num`` are skipped (for example ``"O_IFG"``, which is only
    registered for some atlas/coordinate-space combinations), so the
    corresponding shortform still exists but starts out empty. Calling this
    method again on already simplified labels is a no-op.

    Returns
    -------
    self : instance of self

    Raises
    ------
    ValueError
        If the atlas is neither 'Destrieux' nor 'Desikan-Killiany'.
    """
    if self.simplified:
        # The shortform tables no longer contain the atlas names, so a second
        # pass would have nothing to map from.
        return self

    if self.atlas == 'Destrieux':
        conversions = {
            "Other": [],  # Autofill all uncovered vertices
            "HG": ["G_temp_sup-G_T_transv"],
            "pmHG": ["O_pmHG"],
            "alHG": ["O_alHG"],
            "Te1.0": ["O_Te10"],
            "Te1.1": ["O_Te11"],
            "Te1.2": ["O_Te12"],
            "TTS": ["S_temporal_transverse"],
            "PT": ["G_temp_sup-Plan_tempo"],
            "PP": ["G_temp_sup-Plan_polar"],
            "MTG": ["G_temporal_middle"],
            "ITG": ["G_temporal_inf"],
            "STG": ["G_temp_sup-Lateral"],
            "mSTG": ["O_mSTG"],
            "pSTG": ["O_pSTG"],
            "STS": ["S_temporal_sup"],
            "IFG": ["O_IFG"],
            "IFG.opr": ["G_front_inf-Opercular"],
            "IFG.tri": ["G_front_inf-Triangul"],
            "IFG.orb": ["G_front_inf-Orbital"],
            "Subcnt": ["G_and_S_subcentral"],
            "Insula": ["G_Ins_lg_and_S_cent_ins", "G_insular_short"],
            "T.Pole": ["Pole_temporal"],
        }
    elif self.atlas == 'Desikan-Killiany':
        # Every region keeps its own name, except the IFG parts which get
        # the same shortforms as in the Destrieux case.
        renamed = ('O_IFG', 'parsopercularis', 'parstriangularis', 'parsorbitalis')
        d1 = {k: [k] for k in self.label2num if k not in renamed}
        d2_override = {
            "Other": [],
            "IFG": ["O_IFG"],
            "IFG.opr": ["parsopercularis"],
            "IFG.tri": ["parstriangularis"],
            "IFG.orb": ["parsorbitalis"],
        }
        conversions = {**d1, **d2_override}
    else:
        raise ValueError('Bad atlas')

    # Translate source names to label numbers, ignoring unregistered names.
    conversions = {
        key: [self.label2num[g] for g in groups if g in self.label2num]
        for key, groups in conversions.items()
    }

    simple_num2label = dict(enumerate(conversions))
    simple_label2num = {v: k for k, v in simple_num2label.items()}

    simple_labels = np.zeros(self.n_verts, dtype=int)
    for key, groups in conversions.items():
        if groups:
            simple_labels[np.isin(self.labels, groups)] = simple_label2num[key]

    self.labels = simple_labels
    self.num2label = simple_num2label
    self.label2num = simple_label2num
    self.simplified = True
    return self
def filter_labels(self, labels):
    """
    Returns mask of vertices that are within the union of `labels`.

    Names that are not present in ``self.label2num`` are silently skipped.

    Parameters
    ----------
    labels : str | list[str]
        Label(s) of zone(s) to include in the binary mask.

    Returns
    -------
    mask : boolean array of shape (n_verts,).
    """
    names = (labels,) if isinstance(labels, str) else labels
    selected = np.zeros(self.n_verts, dtype=bool)
    known_nums = (self.label2num[name] for name in names if name in self.label2num)
    for num in known_nums:
        selected[self.labels == num] = True
    return selected
def zones(self, labels, min_alpha=0):
    """
    Build zone map of brain where vertices in region-of-interest (union of `labels`), have
    alpha=1, and everywhere else has alpha=`min_alpha`.

    Names that are not present in ``self.label2num`` are ignored.

    Parameters
    ----------
    labels : str | list[str]
        Label or labels to include in the zone.
    min_alpha : float
        Value to assign to regions not included.

    Returns
    -------
    verts : boolean array of shape (n_verts,), True for vertices inside the zone.
    trigs : float array of shape (n_triangles,). Each entry is the fraction of the
        triangle's three vertices that lie inside the zone, raised to at least `min_alpha`.
    zones : int array of shape (n_verts,). 0 outside the zone, otherwise the 1-based
        position of the vertex's label within the (known) `labels`.
    """
    if isinstance(labels, str):
        labels = (labels,)
    labels = [name for name in labels if name in self.label2num]

    verts = np.zeros(self.n_verts, dtype=bool)
    zones = np.zeros(self.n_verts, dtype=int)
    for position, label in enumerate(labels, start=1):
        nodes = self.labels == self.label2num[label]
        verts[nodes] = True
        zones[nodes] = position

    # Fraction of in-zone corners per triangle, computed for all triangles at
    # once instead of looping over them in Python.
    corner_in_zone = verts[self.trigs[:, :3]]
    trigs = np.maximum(corner_in_zone.mean(axis=1), min_alpha)
    return verts, trigs, zones
def fit_ml_line(self, points):
    """
    Fit a mediolateral line to `points` that goes from medial to lateral end.

    The best-fit line is oriented so that the first component of its
    direction has the same sign as the first component of its anchor point.
    # NOTE(review): this presumes the first axis is left-right with the
    # midline near 0 - confirm for the coordinate space in use.

    Parameters
    ----------
    points : np.ndarray
        Array of 3d point coordinates in the shape of (n_points, 3).

    Returns
    -------
    line : Line, the fit line as a Line object.
    """
    fitted = Line.best_fit(points)
    # Opposite signs mean the direction points towards the other side: flip it.
    if fitted.point[0] * fitted.direction[0] < 0:
        fitted.direction = -fitted.direction
    return fitted
def fit_ml_plane_from_line(self, points):
    """
    Fit an anterior facing plane to mediolateral `points`.

    The plane passes through the mean of `points` and is spanned by the
    direction of the best-fit line and the vector ``[0, 0, 1]``. Its normal
    is oriented so that its second component is not negative.

    Parameters
    ----------
    points : np.ndarray
        Array of 3d point coordinates in the shape of (n_points, 3).

    Returns
    -------
    plane : Plane, the fit plane as a Plane object.
    """
    axis = Line.best_fit(points).direction
    centroid = points.mean(0)
    plane = Plane.from_vectors(centroid, axis, [0, 0, 1])
    if plane.normal[1] < 0:
        plane.normal = -plane.normal
    return plane
def split_hg(self, method="midpoint"):
    """
    Split HG vertices into subregions, such as posteromedial (pmHG) and anterolateral (alHG) halves.

    Parameters
    ----------
    method : {'midpoint', 'endpoint', 'median', 'te1x', 'six_four', or 'seven_three'}, default='midpoint'
        How to split the region.

        - 'midpoint': threshold halfway between the two extremes of the
          vertex positions along the fitted mediolateral line.
        - 'six_four' / 'seven_three': same, with the threshold at
          ``0.4 * min + 0.6 * max`` / ``0.3 * min + 0.7 * max``.
        - 'median': threshold at the median position.
        - 'endpoint': each vertex joins whichever extreme vertex is closer
          according to ``dist_calc``.
        - 'te1x': each vertex joins the closest of the Te1.0/Te1.1/Te1.2
          label files found in the subject's ``other`` directory.

    Returns
    -------
    self : instance of self

    Raises
    ------
    RuntimeError
        If HG has already been split.
    ValueError
        If the atlas is not 'Destrieux' or `method` is unknown.
    """
    if self.is_mangled_hg:
        raise RuntimeError(
            "HG cannot be split as it is already mangled. Try changing order of operations?"
        )
    if self.atlas != 'Destrieux':
        raise ValueError('split_hg() only supported for Destrieux atlas.')
    # Validate everything before touching state, so that a rejected call
    # does not leave the hemisphere flagged as already split.
    if method not in ("midpoint", "six_four", "seven_three", "median", "endpoint", "te1x"):
        raise ValueError(f"Invalid method argument {method}")

    def label_num(simple_name, raw_name):
        # Label tables use shortforms once simplify_labels() has been applied.
        return self.label2num[simple_name if self.simplified else raw_name]

    hg = self.filter_labels(["G_temp_sup-G_T_transv", "HG"])
    hg_idx = np.where(hg)[0]

    if method == "te1x":
        # Read Te1.x labels
        sources = [
            read_label(self.other_file(f"{self.hemi}.{name}.label"))
            for name in ("te10", "te11", "te12")
        ]
        # Distance of HG from each Te1.x area
        dists = []
        for source in sources:
            dist = dist_calc(self.surf, self.cort, source)
            dist[self.ignore] = np.inf
            dists.append(dist[hg])
        # Join each point to the closest area
        closest = np.argmin(np.stack(dists), axis=0)
        targets = (
            label_num("Te1.0", "O_Te10"),
            label_num("Te1.1", "O_Te11"),
            label_num("Te1.2", "O_Te12"),
        )
        for area, target in enumerate(targets):
            self.labels[hg_idx[closest == area]] = target
    else:
        # Fit line to HG and project vertices to line
        position = self.fit_ml_line(self.coords[hg]).transform_points(self.coords[hg])

        if method == "endpoint":
            # Two furthest endpoints w.r.t. the line
            medpoint, latpoint = np.argmin(position), np.argmax(position)
            dist_mp = dist_calc(self.surf, self.cort, hg_idx[medpoint])
            dist_mp[self.ignore] = np.inf
            dist_lp = dist_calc(self.surf, self.cort, hg_idx[latpoint])
            dist_lp[self.ignore] = np.inf
            # Join each point to closer endpoint
            medial = np.argmin(np.stack((dist_mp[hg], dist_lp[hg])), axis=0) == 0
        elif method == "median":
            medial = position <= np.median(position)
        else:
            low, high = min(position), max(position)
            if method == "midpoint":
                threshold = np.mean((low, high))
            elif method == "six_four":
                threshold = 0.4 * low + 0.6 * high
            else:  # "seven_three"
                threshold = 0.3 * low + 0.7 * high
            medial = position <= threshold

        pm_num = label_num("pmHG", "O_pmHG")
        al_num = label_num("alHG", "O_alHG")
        self.labels[hg_idx[medial]] = pm_num
        self.labels[hg_idx[~medial]] = al_num

    self.is_mangled_hg = True
    return self
def remove_tts(self, method="split"):
    """
    Convert TTS labels into either HG or PT ones.

    Parameters
    ----------
    method : {'split', 'join_hg', 'join_pt'}, default='split'
        Method for removing. 'split' will convert labels to either PT or HG depending on
        which is closer, while 'join_hg' or 'join_pt' will convert the entire region to
        HG or PT, respectively.

    Returns
    -------
    self : instance of self

    Raises
    ------
    ValueError
        If the atlas is not 'Destrieux' or `method` is unknown.
    RuntimeError
        If TTS has already been removed.
    """
    if self.atlas != 'Destrieux':
        raise ValueError('remove_tts() only supported for Destrieux atlas.')
    if self.is_mangled_tts:
        raise RuntimeError(
            "TTS cannot be removed as it is already mangled. Try changing order of operations?"
        )
    # Reject unknown methods before touching state, so that a failed call
    # does not leave the hemisphere flagged as already processed.
    if method not in ("split", "join_hg", "join_pt"):
        raise ValueError(f"Invalid method argument {method}")

    hg_name = "HG" if self.simplified else "G_temp_sup-G_T_transv"
    pt_name = "PT" if self.simplified else "G_temp_sup-Plan_tempo"

    tts = self.filter_labels(["S_temporal_transverse", "TTS"])

    if method == "join_hg":
        self.labels[tts] = self.label2num[hg_name]
    elif method == "join_pt":
        self.labels[tts] = self.label2num[pt_name]
    else:  # "split"
        # Distance of TTS points from HG (including already split HG parts)
        hg = self.filter_labels(
            ["G_temp_sup-G_T_transv", "HG", "O_pmHG", "O_alHG", "pmHG", "alHG"]
        )
        dist_hg = dist_calc(self.surf, self.cort, np.where(hg)[0])
        dist_hg[self.ignore] = np.inf

        # Distance of TTS points from PT
        pt = self.filter_labels(["G_temp_sup-Plan_tempo", "PT"])
        dist_pt = dist_calc(self.surf, self.cort, np.where(pt)[0])
        dist_pt[self.ignore] = np.inf

        # Join each point to closer region (ties go to HG)
        closer_to_hg = (
            np.argmin(np.stack((dist_hg[tts], dist_pt[tts])), axis=0) == 0
        )
        tts_idx = np.where(tts)[0]
        hg_num, pt_num = self.label2num[hg_name], self.label2num[pt_name]
        self.labels[tts_idx[closer_to_hg]] = hg_num
        self.labels[tts_idx[~closer_to_hg]] = pt_num

    self.is_mangled_tts = True
    return self
def split_stg(self, method="tts_plane"):
    """
    Split STG into middle (mSTG) and posterior (pSTG) halves.

    Parameters
    ----------
    method : {'tts_plane'}, default='tts_plane'
        Method for splitting, currently only support tts_plane. A plane is
        fitted to the TTS vertices with ``fit_ml_plane_from_line``; STG
        vertices whose signed distance to that plane is <= 0 become pSTG,
        the others become mSTG.

    Returns
    -------
    self : instance of self

    Raises
    ------
    ValueError
        If the atlas is not 'Destrieux' or `method` is unknown.
    RuntimeError
        If STG has already been split.
    """
    if self.atlas != 'Destrieux':
        raise ValueError('split_stg() only supported for Destrieux atlas.')
    if self.is_mangled_stg:
        raise RuntimeError(
            "STG cannot be split as it is already mangled. Try changing order of operations?"
        )
    # Reject unknown methods before touching state, so that a failed call
    # does not leave the hemisphere flagged as already split.
    if method != "tts_plane":
        raise ValueError(f"Invalid method argument {method}")

    stg = self.filter_labels(["G_temp_sup-Lateral", "STG"])
    stg_idx = np.where(stg)[0]

    # Compute TTS plane
    tts = self.filter_labels(["S_temporal_transverse", "TTS"])
    plane = self.fit_ml_plane_from_line(self.coords[tts])

    # Split STG using the plane
    signed_dist = np.array([plane.distance_point_signed(p) for p in self.coords[stg]])
    posterior = signed_dist <= 0

    p_num = self.label2num["pSTG" if self.simplified else "O_pSTG"]
    m_num = self.label2num["mSTG" if self.simplified else "O_mSTG"]
    self.labels[stg_idx[posterior]] = p_num
    self.labels[stg_idx[~posterior]] = m_num

    self.is_mangled_stg = True
    return self
def join_ifg(self):
    """
    Join all three subregion of IFG into one.

    The opercular, triangular and orbital parts (under their atlas names or
    their shortforms) are relabeled as ``"IFG"`` when the labels are
    simplified and ``"O_IFG"`` otherwise. If that target name is not yet
    registered in the label tables, it is added with the next free number.
    For atlases other than 'Destrieux' and 'Desikan-Killiany' nothing is
    changed and a warning is logged.

    Returns
    -------
    self : instance of self

    Raises
    ------
    RuntimeError
        If IFG has already been joined.
    """
    if self.is_mangled_ifg:
        raise RuntimeError(
            "IFG cannot be joined as it is already mangled. Try changing order of operations?"
        )

    if self.atlas == 'Destrieux':
        atlas_parts = ["G_front_inf-Opercular", "G_front_inf-Triangul", "G_front_inf-Orbital"]
    elif self.atlas == 'Desikan-Killiany':
        atlas_parts = ["parsopercularis", "parstriangularis", "parsorbitalis"]
    else:
        # Nothing was joined, so the hemisphere is not flagged as mangled.
        logger.warning(f'join_ifg() makes no change for atlas {self.atlas}')
        return self

    target = "IFG" if self.simplified else "O_IFG"
    if target not in self.label2num:
        # The joined label is not part of every atlas table; register it
        # under the next unused number instead of failing with a KeyError.
        new_num = max(self.num2label, default=-1) + 1
        self.num2label[new_num] = target
        self.label2num[target] = new_num

    ifg = self.filter_labels(atlas_parts + ["IFG.opr", "IFG.tri", "IFG.orb"])
    self.labels[ifg] = self.label2num[target]
    self.is_mangled_ifg = True
    return self
def reset_overlay(self):
    """
    Reset the overlay and visibility state.

    Sets the per-vertex overlay to zeros, the per-triangle alpha to ones and
    marks every vertex and every triangle as visible.

    Returns
    -------
    self : instance of self
    """
    vertex_count = self.surf[0].shape[0]
    triangle_count = self.surf[1].shape[0]
    self.overlay = np.zeros(vertex_count)
    self.alpha = np.ones(triangle_count)
    self.keep_visible = np.ones(vertex_count, dtype=bool)
    self.keep_visible_cells = np.ones(triangle_count, dtype=bool)
    return self
def paint_overlay(self, labels, value=1):
    """
    Paint brain region(s) specified by label(s).

    Parameters
    ----------
    labels : str | list[str]
        Label(s) to paint. Names missing from ``self.label2num`` are skipped.
    value : scalar, default=1
        Overlay value written to every vertex of the given region(s).

    Returns
    -------
    self : instance of self
    """
    names = [labels] if isinstance(labels, str) else labels
    for name in names:
        if name not in self.label2num:
            continue
        region = self.labels == self.label2num[name]
        self.overlay[region] = value
    return self
def interpolate_electrodes_onto_brain(self, coords, values, k, max_dist, roi='all'):
    """
    Use electrode coordinates to interpolate 1-dimensional values corresponding
    to each electrode onto the brain's surface.

    Each vertex inside `roi` receives the inverse-distance weighted average
    of the values of its (at most) `k` nearest electrodes that lie within
    `max_dist` (Euclidean distance). Vertices with no electrode in range,
    or outside `roi`, keep their current overlay value.

    Parameters
    ----------
    coords : np.ndarray (elecs, 3)
        3D coordinates of electrodes
    values : np.ndarray (elecs,)
        Value for each electrode
    k : int
        Number of nearest neighbors to consider
    max_dist : float
        Maximum distance outside of which nearest neighbors will be ignored
    roi : list of strings, or string, default='all'
        Regions to allow interpolation over. By default, the entire brain surface
        is allowed. 'temporal' selects a predefined set of temporal regions
        (Destrieux atlas only). Can also be specified as a list of string labels
        (drawing from self.label_names), or as a single label name.

    Returns
    -------
    self : instance of self

    Raises
    ------
    ValueError
        If ``roi='temporal'`` is used with an atlas other than Destrieux.

    Notes
    -----
    After running this function, you can use the visualization function ``plot_brain_overlay``
    for a quick matplotlib plot, or you can extract the surface values from the ``self.overlay``
    attribute for plotting with another tool like pysurfer.
    """
    # label_names is recomputed on every access, so read it only once.
    available = self.label_names

    if isinstance(roi, str) and roi == 'all':
        roi_list = available
    elif isinstance(roi, str) and roi == 'temporal':
        if self.atlas != 'Destrieux':
            raise ValueError("roi='temporal' only supported for Destrieux atlas. Must specify list of specific region names")
        if self.simplified:
            roi_list = ['alHG','pmHG','HG','TTS','PT','PP','MTG','ITG','mSTG','pSTG','STG','STS','T.Pole']
        else:
            temporal_regions_nums = [33, 34, 35, 36, 74, 41, 43, 72, 73, 38, 37, 76, 77, 78, 79, 80, 81, 82]
            roi_list = [self.num2label[num] for num in temporal_regions_nums]
    elif isinstance(roi, str):
        # A single label name, as accepted by zones()/filter_labels().
        roi_list = [roi]
    else:
        roi_list = list(roi)

    available_set = set(available)
    roi_list_subset = [name for name in roi_list if name in available_set]
    zones_to_include, _, _ = self.zones(roi_list_subset)

    values = np.asarray(values)

    # Euclidean distances from each surface vertex to each coordinate
    dists = cdist(self.surf[0], coords)

    # Closest k electrodes to each vertex (one sort gives indices and distances)
    indices = np.argsort(dists, axis=-1)[:, :k]
    sorted_dists = np.take_along_axis(dists, indices, axis=-1)

    # Mask out distances greater than max_dist
    valid_mask = sorted_dists <= max_dist

    # Values of the neighbors, NaN where the neighbor is out of range
    masked_values = np.where(valid_mask, values[indices], np.nan)

    # Inverse distance weighting (the epsilon avoids division by zero)
    weights = np.where(valid_mask, 1 / (sorted_dists + 1e-10), 0)

    # Weighted sum and total weight per vertex
    weighted_sum = np.nansum(masked_values * weights, axis=1)
    total_weight = np.nansum(weights, axis=1)

    # Only vertices inside the ROI with at least one electrode in range change.
    updated_vertices = np.logical_and(total_weight > 0, zones_to_include)
    self.overlay[updated_vertices] = (
        weighted_sum[updated_vertices] / total_weight[updated_vertices]
    )
    return self
def mark_overlay(self, verts, value=1, inner_radius=0.8, taper=True):
    """
    Fill circle(s) around target(s).

    Three nested rings of radius 1x, 2x and 3x `inner_radius` (measured with
    ``dist_calc``) are written to the overlay, keeping the maximum of the
    existing and the new value at every vertex.

    Parameters
    ----------
    verts : int or sequence of int
        Target vertex or vertices.
    value : scalar, default=1
        Value of the innermost ring.
    inner_radius : float, default=0.8
        Radius of the innermost ring.
    taper : bool, default=True
        If True the middle and outer rings get ``value / 2`` and
        ``value / 8``; otherwise all rings get `value`.

    Returns
    -------
    self
    """
    targets = [verts] if np.isscalar(verts) else verts

    dist = dist_calc(self.surf, self.cort, targets)
    dist[self.ignore] = np.inf

    if taper:
        rings = ((3, value / 8), (2, value / 2), (1, value))
    else:
        rings = ((3, value), (2, value), (1, value))

    # Outermost ring first so inner rings can raise the value further.
    for multiple, ring_value in rings:
        inside = dist <= multiple * inner_radius
        self.overlay[inside] = np.maximum(self.overlay[inside], ring_value)
    return self
def parcellate_overlay(self, merge_func=np.mean):
    """
    Merges overlay values within each parcel for a single hemisphere.

    Every vertex of a parcel receives the value that `merge_func` returns
    for the overlay values of that parcel.

    Parameters
    ----------
    merge_func : callable, default=numpy.mean
        Function to merge values within each parcel. Should accept a 1D
        NumPy array and return a scalar.

    Returns
    -------
    self : instance of self
    """
    # Resolve all label numbers first so that an unknown name fails before
    # any value is merged.
    parcel_nums = np.array(
        [self.label2num[name] for name in self.label_names], dtype=self.labels.dtype
    )

    merged = np.zeros_like(self.overlay)
    for parcel_num in parcel_nums:
        members = self.labels == parcel_num
        if not members.any():
            # A label without vertices has nothing to merge.
            continue
        merged[members] = merge_func(self.overlay[members])

    self.overlay = merged
    return self
def set_visible(self, labels, min_alpha=0):
    """
    Restrict visibility to the zone given by `labels`.

    Parameters
    ----------
    labels : str | list[str]
        Label(s) passed on to ``zones``.
    min_alpha : float, default=0
        Lower bound of the per-triangle alpha; vertices and triangles whose
        zone value is not above it are marked as not visible.

    Returns
    -------
    self : instance of self
    """
    vertex_in_zone, triangle_alpha, _ = self.zones(labels, min_alpha=min_alpha)
    self.keep_visible = vertex_in_zone > min_alpha
    self.keep_visible_cells = triangle_alpha > min_alpha
    self.alpha = np.maximum(triangle_alpha, min_alpha)
    return self
def reset_overlay_except(self, labels):
    """
    Zero the overlay everywhere outside the zone given by `labels`.

    The per-triangle alpha is replaced by the one ``zones`` returns for
    `labels` (with ``min_alpha=0``).

    Returns
    -------
    self : instance of self
    """
    vertex_in_zone, triangle_alpha, _ = self.zones(labels, min_alpha=0)
    self.alpha = triangle_alpha
    outside = ~vertex_in_zone
    self.overlay[outside] = 0
    return self
class Brain:
def __init__(
    self,
    surf_type: str = "pial",
    subject: str = "fsaverage",
    coordinate_space: str = 'FSAverage',
    atlas=None,
    subject_dir=None
):
    """
    Brain representation containing a left and right hemisphere. Can be used for plotting,
    distance calculations, etc.

    Both hemispheres are built as ``Hemisphere`` objects from the same
    arguments and stored as ``self.lh`` and ``self.rh``.

    Parameters
    ----------
    surf_type : str, default='pial'
        Cortical surface type. Must be one of ``SURF_TYPES``.
    subject : str, default='fsaverage'
        Subject to use, must be a directory within ``subject_dir``
    coordinate_space : str, default='FSAverage'
        Coordinate space of brain vertices, passed on to ``Hemisphere``.
    atlas : str, default=None
        Atlas for brain parcellation, passed on to ``Hemisphere``.
    subject_dir : str/path-like, default=None
        Path containing the subject's folder, passed on to ``Hemisphere``.

    Raises
    ------
    ValueError
        If `surf_type` is not supported.

    Examples
    --------
    >>> from naplib.localization import Brain
    >>> import numpy as np
    >>> brain = Brain('pial', subject_dir='path/to/freesurfer/subjects/').split_stg().join_ifg()
    >>> coords = np.array([[-47.281147, 17.026093, -21.833099],
    ...                    [-48.273964, 16.155487, -20.162935]])
    >>> isleft = np.array([True, True])
    >>> annotations = brain.annotate_coords(coords, isleft)
    """
    if surf_type not in SURF_TYPES:
        raise ValueError(f"Argument `surf_type` should be in {SURF_TYPES}.")

    self.surf_type = surf_type
    self.subject = subject

    # Left hemisphere is built first, then the right one, from the same settings.
    shared_args = (surf_type, subject, coordinate_space, atlas)
    self.lh = Hemisphere("lh", *shared_args, subject_dir=subject_dir)
    self.rh = Hemisphere("rh", *shared_args, subject_dir=subject_dir)
@property
def num2label(self):
    """Label-number to label-name table, taken from the left hemisphere."""
    left = self.lh
    return left.num2label
@property
def label2num(self):
    """Label-name to label-number table, taken from the left hemisphere."""
    left = self.lh
    return left.label2num
@property
def label_names(self):
    """
    Names of all labels present in either hemisphere, without duplicates.

    Returns
    -------
    names : list of str
        Left-hemisphere names first, followed by names that only occur in
        the right hemisphere, each in the order the hemisphere reports them.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # result does not depend on string hashing.
    return list(dict.fromkeys(self.lh.label_names + self.rh.label_names))
def load_labels(self):
    """
    Reload the atlas labels of each vertex for both hemispheres.

    Returns
    -------
    self : instance of self
    """
    for hemisphere in (self.lh, self.rh):
        hemisphere.load_labels()
    return self
def simplify_labels(self):
    """
    Simplify the atlas labels of both hemispheres into shortforms.

    Returns
    -------
    self : instance of self
    """
    for hemisphere in (self.lh, self.rh):
        hemisphere.simplify_labels()
    return self
def split_hg(self, method="midpoint"):
    """
    Split HG vertices into subregions in both hemispheres.

    Parameters
    ----------
    method : str, default='midpoint'
        How to split the region; forwarded to ``Hemisphere.split_hg``. One of
        'midpoint', 'endpoint', 'median', 'te1x', 'six_four' or 'seven_three'.

    Returns
    -------
    self : instance of self
    """
    for hemisphere in (self.lh, self.rh):
        hemisphere.split_hg(method)
    return self
def split_stg(self, method="tts_plane"):
    """
    Split STG into middle (mSTG) and posterior (pSTG) halves in both hemispheres.

    Parameters
    ----------
    method : str, default='tts_plane'
        Splitting method, forwarded to ``Hemisphere.split_stg``.

    Returns
    -------
    self : instance of self
    """
    for hemisphere in (self.lh, self.rh):
        hemisphere.split_stg(method)
    return self
def remove_tts(self, method="split"):
    """
    Convert TTS labels into either HG or PT ones in both hemispheres.

    Parameters
    ----------
    method : str, default='split'
        Removal method, forwarded to ``Hemisphere.remove_tts``. One of
        'split', 'join_hg' or 'join_pt'.

    Returns
    -------
    self : instance of self
    """
    for hemisphere in (self.lh, self.rh):
        hemisphere.remove_tts(method)
    return self
def join_ifg(self):
    """
    Join all three subregions of IFG into one in both hemispheres.

    Returns
    -------
    self : instance of self
    """
    for hemisphere in (self.lh, self.rh):
        hemisphere.join_ifg()
    return self
def annotate(self, verts, is_left, is_surf=None, text=True):
    """
    Get labels for vertices of the surface.

    Parameters
    ----------
    verts : array-like of int
        Vertex indices; entry ``i`` indexes the left hemisphere when
        ``is_left[i]`` is True and the right hemisphere otherwise.
    is_left : array-like of bool
        Whether each vertex belongs to the left hemisphere. Same length as `verts`.
    is_surf : array-like of bool, optional
        Same length as `verts`. Entries that are False get label number 0.
        # NOTE(review): this argument was previously documented as having one
        # entry per surface vertex, but it is applied per entry of `verts` -
        # confirm the intended shape with callers.
    text : bool, default=True
        Whether to return labels as string names, or integer labels.

    Returns
    -------
    labels : np.ndarray
        Array of labels, either as strings or ints.
    """
    # Accept lists and non-boolean masks: boolean masks are required for the
    # hemisphere selection and its negation below.
    verts = np.asarray(verts, dtype=int)
    is_left = np.asarray(is_left, dtype=bool)

    labels = np.zeros(len(verts), dtype=int)
    labels[is_left] = self.lh.labels[verts[is_left]]
    labels[~is_left] = self.rh.labels[verts[~is_left]]

    # NOTE(review): vertex indices 0 and 1 are always mapped to label 0;
    # presumably these are "no vertex found" markers - confirm with the
    # function that produces `verts`.
    labels[verts <= 1] = 0

    if is_surf is not None:
        labels[~np.asarray(is_surf, dtype=bool)] = 0

    if text:
        # Each hemisphere translates its own label numbers.
        labels = np.array([
            (self.lh if left else self.rh).num2label[label]
            for label, left in zip(labels, is_left)
        ])
    return labels
def annotate_coords(
    self, coords, isleft=None, distance_cutoff=10, is_surf=None, text=True, get_dists=False,
):
    """
    Get labels (like pmHG, IFG, etc) for coordinates. Note, the coordinates should match the
    `surf_type` of this brain, otherwise finding nearest surface points to each coordinate in order
    to label it may be inaccurate.

    Parameters
    ----------
    coords : np.ndarray
        Array of coordinates, shape (num_elecs, 3).
    isleft : np.ndarray (elecs,), optional
        If provided, specifies a boolean which is True for each electrode that is in the left hemisphere.
        If not given, this will be inferred from the first dimension of the coords (negative is left).
    distance_cutoff : float, default=10
        Electrodes whose distance to the nearest vertex is not below this value are labeled as None.
    is_surf : boolean np.ndarray, optional
        Forwarded unchanged to ``annotate``.
    text : bool, default=True
        Whether to return labels as string names, or integer labels.
    get_dists : bool, default=False
        Whether to return distances for each electrode to the nearest vertex.

    Returns
    -------
    labels : np.ndarray
        Array of labels, either as strings or ints, with None for electrodes beyond the cutoff.
    dists : np.ndarray, optional
        Array of minimum distances as floats. Only returned if `get_dists` is True.
    """
    if isleft is None:
        # Negative first coordinate means left hemisphere.
        isleft = coords[:, 0] < 0

    verts, dists = get_nearest_vert_index(
        coords, isleft, self.lh.surf, self.rh.surf, verbose=False
    )
    nearest_labels = self.annotate(verts, isleft, is_surf=is_surf, text=text)

    # Drop the label of every electrode that is too far from the surface.
    kept = []
    for label, dist in zip(nearest_labels, dists):
        if dist < distance_cutoff:
            kept.append(label)
        else:
            kept.append(None)
    labels = np.asarray(kept)

    if get_dists:
        return labels, dists
    return labels
def distance_from_region(self, coords, isleft=None, region="pmHG", metric="surf"):
    """
    Get distance from a certain region for each electrode's coordinates.

    The distance is measured to the region's center (mean of its vertices) in the
    electrode's own hemisphere, either along the cortical surface or as euclidean
    distance. For proper results, assuming coordinates are in pial space, the brain
    must also be in pial space.

    Parameters
    ----------
    coords : np.ndarray
        Array of coordinates in pial space for this brain's subject_id, shape (num_elecs, 3).
    isleft : np.ndarray (elecs,), optional
        If provided, specifies a boolean which is True for each electrode that is in the
        left hemisphere. If not given, this will be inferred from the first dimension of
        the coords (negative is left).
    region : str, default='pmHG'
        Anatomical label. Must exist in the labels for the brain.
    metric : {'surf','euclidean'}, default='surf'
        Either surf, for distance along cortical surface, or euclidean, for euclidean distance.

    Returns
    -------
    distances : np.ndarray
        Array of distances, one per electrode.

    Raises
    ------
    KeyError
        If ``region`` is not a key of ``self.label2num``.
    ValueError
        If ``metric`` is not recognised, or if the region has no vertices in a
        hemisphere that contains electrodes (or in either hemisphere).
    """
    # Validate the cheap argument first so no expensive work is wasted on a typo.
    if metric not in ("surf", "euclidean"):
        raise ValueError(f"metric must be surf or euclidean but got {metric}")

    coords = np.asarray(coords)
    if isleft is None:
        isleft = coords[:, 0] < 0
    # A real boolean mask is required below for `.any()` / `~`.
    isleft = np.asarray(isleft, dtype=bool)

    region_label_num = self.label2num[region]
    hemis = {"lh": self.lh, "rh": self.rh}
    # which vertices correspond to this region in each hemi
    region_masks = {
        name: hemi.labels == region_label_num for name, hemi in hemis.items()
    }
    has_region = {name: bool(np.any(mask)) for name, mask in region_masks.items()}
    has_elecs = {"lh": bool(isleft.any()), "rh": bool((~isleft).any())}

    # A hemisphere without the region would yield a NaN center, which silently
    # turns into "distance from vertex 0" (surf) or NaN (euclidean), so fail loudly.
    region_nowhere = not any(has_region.values())
    missing = [
        name
        for name in hemis
        if not has_region[name] and (has_elecs[name] or region_nowhere)
    ]
    if missing:
        raise ValueError(
            "Region not found in existing labels. One possible issue is that you have not yet called"
            " brain.split_hg(), or a similar method. For example, Te1.1 is only"
            " available after calling brain.split_hg(method='te1x')"
            f" (missing in hemisphere(s): {', '.join(missing)})"
        )

    # Only hemispheres that actually contain electrodes need any computation.
    needed = [name for name in hemis if has_elecs[name]]

    def _region_center(name):
        # Mean position of the region's vertices, kept 2-D with shape (1, 3).
        return hemis[name].surf[0][region_masks[name]].mean(0, keepdims=True)

    if metric == "euclidean":
        # cdist returns shape (1, num_elecs); take row 0 instead of squeeze() so a
        # single electrode still gives a 1-D, indexable array.
        dist_by_hemi = {
            name: cdist(_region_center(name), coords)[0] for name in needed
        }
        lookup = range(len(coords))
    else:
        dist_by_hemi = {}
        for name in needed:
            surf = hemis[name].surf
            # closest surface vertex to the region's center point
            seed_vert = np.argmin(
                np.square(surf[0] - _region_center(name)[0]).sum(1)
            )
            # distance from every vertex on the surface to that vertex
            dist_by_hemi[name] = dist_calc(surf, hemis[name].cort, seed_vert)
        # approximate every electrode by its nearest vertex
        lookup, _ = get_nearest_vert_index(
            coords, isleft, self.lh.surf, self.rh.surf, verbose=False
        )

    return np.asarray(
        [
            dist_by_hemi["lh" if left else "rh"][idx]
            for idx, left in zip(lookup, isleft)
        ]
    )
def reset_overlay(self):
    """
    Call ``reset_overlay`` on the left and then the right hemisphere.

    Returns
    -------
    self : instance of self
    """
    for hemi in (self.lh, self.rh):
        hemi.reset_overlay()
    return self
[docs]
def paint_overlay(self, labels, value=1):
    """
    Paint the overlay of the given region(s) on both hemispheres.

    Parameters
    ----------
    labels : str | list[str]
        Region or regions whose overlay should be painted.
    value : float, default=1
        Value written to the overlay of those regions.

    Returns
    -------
    self : an instance of self
    """
    # Left hemisphere first, then right, each receiving the same arguments.
    for hemi in (self.lh, self.rh):
        hemi.paint_overlay(labels, value)
    return self
[docs]
def mark_overlay(self, verts, isleft, value=1, inner_radius=0.8, taper=True):
"""
Fill circle(s) around target(s).
Parameters
----------
verts : np.ndarray
Vertices to mark.
isleft : np.ndarray of booleans
Indicator of same shape as verts for whether they are in the left hemisphere.
value : float, default=1
Value to mark with.
inner_radius : float, default=0.8
Radius of circle to mark around each vertex.
taper : bool, default=True
Whether to taper the circular mark.
Returns
-------
self : instance of self
"""
self.lh.mark_overlay(verts[isleft], value, inner_radius, taper)
self.rh.mark_overlay(verts[~isleft], value, inner_radius, taper)
return self
[docs]
def set_visible(self, labels, min_alpha=0):
    """
    Make the given region(s) visible on both hemispheres; the rest become invisible.

    Parameters
    ----------
    labels : str | list[str]
        Label(s) to set as visible.
    min_alpha : float, default=0
        Forwarded unchanged to each hemisphere's ``set_visible``.
        NOTE(review): presumably the alpha given to non-selected regions -- confirm
        against ``Hemisphere.set_visible``.

    Returns
    -------
    self : instance of self
    """
    # Left hemisphere first, then right, each receiving the same arguments.
    for hemi in (self.lh, self.rh):
        hemi.set_visible(labels, min_alpha)
    return self
[docs]
def reset_overlay_except(self, labels):
    """
    Keep the overlay of the given region(s) on both hemispheres and leave the rest colorless.

    Parameters
    ----------
    labels : str | list[str]
        Label(s) whose overlay is kept.

    Returns
    -------
    self : instance of self
    """
    # Left hemisphere first, then right.
    for hemi in (self.lh, self.rh):
        hemi.reset_overlay_except(labels)
    return self
[docs]
def interpolate_electrodes_onto_brain(self, coords, values, isleft=None, k=10, max_dist=10, roi='all', reset_overlay_first=True):
    """
    Use electrode coordinates to interpolate 1-dimensional values corresponding
    to each electrode onto the brain's surface.

    Electrodes are split by hemisphere and each subset is passed to that
    hemisphere's ``interpolate_electrodes_onto_brain``.

    Parameters
    ----------
    coords : array-like (elecs, 3)
        3D coordinates of electrodes
    values : array-like (elecs,)
        Value for each electrode
    isleft : array-like (elecs,), optional
        If provided, specifies a boolean which is True for each electrode that is in the
        left hemisphere. Lists and 0/1 integer arrays are accepted and interpreted by
        truthiness. If not given, this will be inferred from the first dimension of the
        coords (negative is left).
    k : int, default=10
        Number of nearest neighbors to consider
    max_dist : float, default=10
        Maximum distance (in mm) outside of which nearest neighbors will be ignored
    roi : list of strings, or string in {'all', 'temporal'}, default='all'
        Regions to allow interpolation over. By default, the entire brain surface
        is allowed. Can also be specified as a list of string labels (drawing from
        self.lh.label_names)
    reset_overlay_first : bool, default=True
        If True (default), reset the overlay before creating a new overlay

    Returns
    -------
    self : instance of self

    Notes
    -----
    After running this function, you can use the visualization function ``plot_brain_overlay``
    for a quick matplotlib plot, or you can extract the surface values from the ``self.lh.overlay``
    and ``self.rh.overlay`` attributes, etc, for plotting with another tool like pysurfer or plotly.
    """
    # Normalise the inputs before touching the overlay, so malformed arguments
    # cannot wipe an existing overlay and then fail.
    coords = np.asarray(coords)
    values = np.asarray(values)
    if isleft is None:
        isleft = coords[:, 0] < 0
    # Force a real boolean mask: with an integer 0/1 array, `coords[isleft]` would be
    # fancy indexing and `~isleft` a bitwise NOT, silently selecting wrong electrodes.
    isleft = np.asarray(isleft, dtype=bool)
    isright = ~isleft

    if reset_overlay_first:
        self.reset_overlay()

    self.lh.interpolate_electrodes_onto_brain(coords[isleft], values[isleft], k=k, max_dist=max_dist, roi=roi)
    self.rh.interpolate_electrodes_onto_brain(coords[isright], values[isright], k=k, max_dist=max_dist, roi=roi)
    return self
[docs]
def parcellate_overlay(self, merge_func=np.mean):
    """Merge brain overlay values within each atlas parcel.

    Applies ``merge_func`` to the overlay values of each hemisphere by delegating
    to that hemisphere's ``parcellate_overlay``. It is typically used after
    interpolating electrode data onto the brain surface
    (e.g., via `brain.interpolate_electrodes_onto_brain()`) to summarize
    the data within each parcel.

    Parameters
    ----------
    merge_func : callable, default=numpy.mean
        Function used to combine the overlay values within each parcel. It should
        accept an array-like object of values and return a single value. Common
        examples include `numpy.mean` (default), `numpy.median`, and `numpy.max`.

    Returns
    -------
    self : instance of self
        Returns the instance itself, with the overlay data parcellated.
    """
    # Left hemisphere first, then right, both with the same merge function.
    for hemi in (self.lh, self.rh):
        hemi.parcellate_overlay(merge_func)
    return self
def get_nearest_vert_index(coords, isleft, surf_lh, surf_rh, verbose=False):
    """
    Find, for each coordinate, the nearest vertex of its hemisphere's surface.

    Parameters
    ----------
    coords : np.ndarray
        Coordinates, one row of length 3 per electrode.
    isleft : sequence of bool
        Truthy for each coordinate that should be matched against ``surf_lh``;
        otherwise ``surf_rh`` is used.
    surf_lh, surf_rh : sequence
        Surfaces whose first element is the array of vertex coordinates, one row
        per vertex. Only the first element is read here.
    verbose : bool, default=False
        If True, print each minimum distance as it is found.

    Returns
    -------
    vert_indices : np.ndarray of int
        Index of the nearest vertex for each coordinate. Always an integer array,
        including for empty input, so it can be used directly for indexing.
    min_dists : np.ndarray
        Euclidean distance from each coordinate to that vertex.
    """
    vert_indices = []
    min_dists = []
    for coord, left in zip(coords, isleft):
        # Look the surface up lazily so an unused hemisphere is never touched.
        surf_verts = (surf_lh if left else surf_rh)[0]
        sq_dists = np.square(surf_verts - coord).sum(1)
        # sqrt is monotonic: locate the minimum on squared distances and take a
        # single square root instead of one per vertex.
        nearest = sq_dists.argmin()
        vert_indices.append(nearest)
        min_dists.append(np.sqrt(sq_dists[nearest]))
        if verbose:
            print(min_dists[-1])
    # An explicit integer dtype keeps the (otherwise float64) empty result usable as an index.
    return np.asarray(vert_indices, dtype=np.intp), np.asarray(min_dists)
def find_closest_vertices(surface_coords, point_coords):
    """Map points to their closest vertices on a surface mesh.

    Closeness is measured with Euclidean distance.

    Parameters
    ----------
    surface_coords : numpy array
        Coordinates of the mesh vertices, one row per vertex.
    point_coords : numpy array
        Coordinates to map to vertices; a single point is promoted to one row.

    Returns
    -------
    closest_vertices : numpy array
        Index of the closest mesh vertex for each point.
    closest_dists : numpy array
        Distance from each point to that vertex.
    """
    targets = np.atleast_2d(point_coords)
    # Rows index mesh vertices, columns index the query points.
    pairwise = cdist(surface_coords, targets)
    closest_vertices = pairwise.argmin(axis=0)
    closest_dists = pairwise.min(axis=0)
    return closest_vertices, closest_dists