Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions src/cedalion/dot/head_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def from_surfaces(
scalp_face_count: int | None = 60000,
fill_holes: bool = False,
parcel_file: Path | str | None = None,
parcel_volume_file: Path | str | None = None,

) -> "TwoSurfaceHeadModel":
"""Constructor from seg.masks, brain and head surfaces as gained from MRI scans.
Expand All @@ -246,6 +247,7 @@ def from_surfaces(
scalp_face_count: Number of faces for the scalp surface.
fill_holes: Whether to fill holes in the segmentation masks.
parcel_file: path to the json file mapping vertices to parcels.
parcel_volume_file: Path to parcel nifiti file (annotated voxels).

Returns:
TwoSurfaceHeadModel: An instance of the TwoSurfaceHeadModel class.
Expand Down Expand Up @@ -327,18 +329,41 @@ def from_surfaces(
"segmentation_type"
)

# load parcellations
if parcel_file is not None:
parcels = cedalion.io.read_parcellations(parcel_file)
assert len(parcels) == brain_ijk.nvertices
brain_ijk.vertex_coords["parcel"] = np.asarray(parcels.Label.tolist())
else:
parcels = None

if parcel_volume_file is not None:
import nibabel as nib
voxel_parcels = nib.load(parcel_volume_file)
affine = voxel_parcels.affine
voxel_parcels = voxel_parcels.get_fdata()
labels = np.unique(voxel_parcels.astype(int))
if os.path.exists(parcel_volume_file.replace('.nii.gz', '_labels.csv')):
with open(parcel_volume_file.replace('.nii.gz', '_labels.csv'), 'r') as f:
lines = [l.split() for l in f.readlines()]
csv_cbig = {int(l[0]): l[1] for l in lines}
parcels_dict = parcels.Label.to_dict()
for i, l in csv_cbig.items():
assert csv_cbig[i] == l
assert brain_mask.shape == voxel_parcels.shape
assert (t_ijk2ras.values == affine).all()
voxel_parcels = voxel_parcels.astype(int)
else:
voxel_parcels = None

voxel_to_vertex_brain = map_segmentation_mask_to_surface(
brain_mask, t_ijk2ras, brain_ijk.apply_transform(t_ijk2ras)
brain_mask, t_ijk2ras, brain_ijk.apply_transform(t_ijk2ras),
parcels_vox=voxel_parcels, parcels_verts=parcels
)
voxel_to_vertex_scalp = map_segmentation_mask_to_surface(
scalp_mask, t_ijk2ras, scalp_ijk.apply_transform(t_ijk2ras)
)

if parcel_file is not None:
parcels = read_parcellations(parcel_file)
assert len(parcels) == brain_ijk.nvertices
brain_ijk.vertex_coords["parcel"] = np.asarray(parcels.Label.tolist())

return cls(
segmentation_masks=segmentation_masks,
brain=brain_ijk,
Expand Down
41 changes: 41 additions & 0 deletions src/cedalion/dot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def map_segmentation_mask_to_surface(
segmentation_mask: xr.DataArray,
transform_vox2ras: cdt.AffineTransform, # FIXME
surface: cdc.Surface,
parcels_vox: np.ndarray = None,
parcels_verts: xr.DataArray = None,
):
"""Find for each voxel the closest vertex on the surface.

Expand All @@ -25,6 +27,10 @@ def map_segmentation_mask_to_surface(
transform_vox2ras (xr.DataArray): The affine transformation from voxel to RAS
space.
surface (cedalion.dataclasses.Surface): The surface to map the voxels to.
parcels_vox (np.ndarray, optional): An array of shape (nx, ny, nz) containing
the parcel label indices for each voxel.
parcels_verts (xr.DataArray, optional): An array of shape (nvertices,) containing
the parcellation information for each brain surface vertex.

Returns:
coo_array: A sparse matrix of shape (ncells, nvertices) that maps voxels to
Expand All @@ -49,6 +55,41 @@ def map_segmentation_mask_to_surface(
cell_coords.values[cell_indices, :], workers=-1
)

if parcels_vox is not None and parcels_verts is not None:
# overwrite voxel labels if in segmentation mask not in brain tissue
fs_num_labeled_vox = np.sum(np.flatnonzero(parcels_vox))
parcels_vox *= segmentation_mask.values
print("Num of labeled voxels before seg-masking: %d\nNum of labeled voxels after seg-masking: %d" % (fs_num_labeled_vox, np.sum(np.flatnonzero(parcels_vox))))

# if parcellation is provided, overwrite vertex_indices with mapping
# constraint to vertices-mapping within the same parcel
for parcel_id in np.unique(parcels_vox):
if parcel_id == 0:
continue
# get cell indices within this parcel
parcels_vox_flat = parcels_vox.flatten()
parcel_cell_indices = np.argwhere(parcels_vox_flat[cell_indices] == parcel_id)[:, 0]
if len(parcel_cell_indices) == 0:
continue
# get vertices within this parcel
parcel_vertex_indices = np.where(
parcels_verts.index == parcel_id
)[0]
if len(parcel_vertex_indices) == 0:
continue
# build a KDTree for the parcel vertices
from scipy.spatial import KDTree
parcel_tree = KDTree(surface.vertices[parcel_vertex_indices, :])
# query the parcel_tree for the parcel_cell_indices
dists_parcel, vertex_indices_parcel = parcel_tree.query(
cell_coords.values[parcel_cell_indices, :], workers=-1
)
# map back to global vertex indices
global_vertex_indices_parcel = parcel_vertex_indices[vertex_indices_parcel]
# update vertex_indices for these cell indices

vertex_indices[parcel_cell_indices] = global_vertex_indices_parcel

# construct a sparse matrix of shape (ncells, nvertices)
# that maps voxels to cells
map_voxel_to_vertex = coo_array(
Expand Down
3 changes: 2 additions & 1 deletion src/cedalion/io/anatomy.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def read_parcellations(parcel_file: str | Path) -> pd.DataFrame:
parcels["Vertices"] = parcels["Vertices"].astype(int)
parcels = parcels.sort_values("Vertices")

parcels["Label"] = parcels["Label"].apply(lambda x: "_".join(x.split(" ")) + "H")
if not parcels["Label"].values[1].endswith('H'):
parcels["Label"] = parcels["Label"].apply(lambda x: "_".join(x.split(" ")) + "H")

return parcels
Loading