Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 22 additions & 22 deletions CorpusCallosum/segmentation/segmentation_postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,29 +265,29 @@ def connect_nearby_components(seg_arr: ArrayType, max_connection_distance: float

# Plot components for debugging
if plot:
import matplotlib
import matplotlib.pyplot as plt
curr_backend = matplotlib.get_backend()
plt.switch_backend("qtagg")
n_components = len(component_sizes)
fig, axes = plt.subplots(1, n_components + 1, figsize=(5*(n_components + 1), 5))
if n_components == 1:
axes = [axes]
# Plot each component in a different color
for i, (comp_id, comp_size) in enumerate(component_sizes):
component_mask = labels_cc == comp_id
axes[i].imshow(component_mask[component_mask.shape[0]//2], cmap='gray')
axes[i].set_title(f'Component {comp_id}\nSize: {comp_size}')
axes[i].axis('off')

# Plot the connected segmentation
axes[-1].imshow(connected_seg[connected_seg.shape[0]//2], cmap='gray')
axes[-1].set_title('Connected Segmentation')
axes[-1].axis('off')
plt.tight_layout()
plt.show()
plt.switch_backend(curr_backend)


from FastSurferCNN.utils.plotting import backend

with backend("qtagg"):
n_components = len(component_sizes)
fig, axes = plt.subplots(1, n_components + 1, figsize=(5*(n_components + 1), 5))
if n_components == 1:
axes = [axes]
# Plot each component in a different color
for i, (comp_id, comp_size) in enumerate(component_sizes):
component_mask = labels_cc == comp_id
axes[i].imshow(component_mask[component_mask.shape[0]//2], cmap='gray')
axes[i].set_title(f'Component {comp_id}\nSize: {comp_size}')
axes[i].axis('off')

# Plot the connected segmentation
axes[-1].imshow(connected_seg[connected_seg.shape[0]//2], cmap='gray')
axes[-1].set_title('Connected Segmentation')
axes[-1].axis('off')
plt.tight_layout()
plt.show()

return connected_seg


Expand Down
37 changes: 18 additions & 19 deletions CorpusCallosum/shape/endpoint_heuristic.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,25 +205,24 @@ def find_cc_endpoints(
pc_startpoint_idx = np.argmin(np.linalg.norm(contour - posterior_anchor_2d[:, None], axis=0))

if plot: # interactive debug plot of contour, ac, pc and endpoints
import matplotlib
import matplotlib.pyplot as plt
curr_backend = matplotlib.get_backend()
plt.switch_backend("qtagg")
plt.figure(figsize=(10, 8))
plt.plot(contour[0, :], contour[1, :], 'b-', label='CC Contour', linewidth=2)
plt.plot(ac_2d[0], ac_2d[1], 'go', markersize=10, label='AC')
plt.plot(pc_2d[0], pc_2d[1], 'ro', markersize=10, label='PC')
plt.plot(anterior_anchor_2d[0], anterior_anchor_2d[1], 'g^', markersize=10, label='Anterior Anchor')
plt.plot(posterior_anchor_2d[0], posterior_anchor_2d[1], 'r^', markersize=10, label='Posterior Anchor')
plt.plot(contour[0, ac_startpoint_idx], contour[1, ac_startpoint_idx], 'g*', markersize=15, label='AC Endpoint')
plt.plot(contour[0, pc_startpoint_idx], contour[1, pc_startpoint_idx], 'r*', markersize=15, label='PC Endpoint')
plt.xlabel('A-S (mm)')
plt.ylabel('I-S (mm)')
plt.title('CC Contour with Endpoints')
plt.legend()
plt.axis('equal')
plt.grid(True, alpha=0.3)
plt.show()
plt.switch_backend(curr_backend)

from FastSurferCNN.utils.plotting import backend
with backend("qtagg"):
plt.figure(figsize=(10, 8))
plt.plot(contour[0, :], contour[1, :], 'b-', label='CC Contour', linewidth=2)
plt.plot(*ac_2d[0:2], 'go', markersize=10, label='AC')
plt.plot(*pc_2d[0:2], 'ro', markersize=10, label='PC')
plt.plot(*anterior_anchor_2d[0:2], 'g^', markersize=10, label='Anterior Anchor')
plt.plot(*posterior_anchor_2d[0:2], 'r^', markersize=10, label='Posterior Anchor')
plt.plot(*contour[0:2, ac_startpoint_idx], 'g*', markersize=15, label='AC Endpoint')
plt.plot(*contour[0:2, pc_startpoint_idx], 'r*', markersize=15, label='PC Endpoint')
plt.xlabel('A-S (mm)')
plt.ylabel('I-S (mm)')
plt.title('CC Contour with Endpoints')
plt.legend()
plt.axis('equal')
plt.grid(True, alpha=0.3)
plt.show()

return ac_startpoint_idx, pc_startpoint_idx
32 changes: 15 additions & 17 deletions CorpusCallosum/shape/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,24 +214,22 @@ def calculate_cc_index(cc_contour: np.ndarray, plot: bool = False) -> float:
cc_index = (anterior_thickness + posterior_thickness + middle_thickness) / ap_distance

if plot:
import matplotlib
import matplotlib.pyplot as plt
curr_backend = matplotlib.get_backend()
plt.switch_backend("qtagg")

fig, ax = plt.subplots(figsize=(8, 6))
plot_cc_index_calculation(
ax,
cc_contour,
anterior_idx,
posterior_idx,
ap_intersections,
middle_intersections,
midpoint,
)
ax.legend()
plt.show()
plt.switch_backend(curr_backend)

from FastSurferCNN.utils.plotting import backend
with backend("qtagg"):
fig, ax = plt.subplots(figsize=(8, 6))
plot_cc_index_calculation(
ax,
cc_contour,
anterior_idx,
posterior_idx,
ap_intersections,
middle_intersections,
midpoint,
)
ax.legend()
plt.show()

return cc_index

Expand Down
24 changes: 12 additions & 12 deletions CorpusCallosum/shape/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,19 +629,19 @@ def make_subdivision_mask(
subdivision_mask[points_left_of_line] = label

if plot: # interactive debug plot
import matplotlib
import matplotlib.pyplot as plt
curr_backend = matplotlib.get_backend()
plt.switch_backend("qtagg")
plt.figure(figsize=(10, 8))
plt.imshow(subdivision_mask, cmap='tab10')
plt.colorbar(label='Subdivision')
plt.title('CC Subdivision Mask')
plt.xlabel('X')
plt.ylabel('Y')
plt.tight_layout()
plt.show()
plt.switch_backend(curr_backend)

from FastSurferCNN.utils.plotting import backend
with backend("qtagg"):
plt.figure(figsize=(10, 8))
plt.imshow(subdivision_mask, cmap='tab10')
plt.colorbar(label='Subdivision')
plt.title('CC Subdivision Mask')
plt.xlabel('X')
plt.ylabel('Y')
plt.tight_layout()
plt.show()

return subdivision_mask


Expand Down
101 changes: 53 additions & 48 deletions CorpusCallosum/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,16 @@

from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.axes
import nibabel as nib
import numpy as np

from CorpusCallosum.utils.types import ContourList, Polygon2dType
from FastSurferCNN.utils import AffineMatrix4x4, Image3d, Vector2d
from FastSurferCNN.utils import AffineMatrix4x4, Image3d, Vector2d, noop_context


def plot_standardized_space(
ax_row: list[plt.Axes],
ax_row: list[matplotlib.axes.Axes],
vol: np.ndarray,
ac_coords: np.ndarray,
pc_coords: np.ndarray
Expand All @@ -33,7 +32,7 @@ def plot_standardized_space(

Parameters
----------
ax_row : list[plt.Axes]
ax_row : list[matplotlib.axes.Axes]
Row of axes to plot on (should be length 3).
vol : np.ndarray
Volume data to visualize.
Expand Down Expand Up @@ -125,6 +124,7 @@ def visualize_coordinate_spaces(
3. standardized image space
as a single image named 'ac_pc_spaces.png' in `output_dir`.
"""
from matplotlib import pyplot as plt
fig, ax = plt.subplots(3, 3, figsize=(12, 12))

# Original space (Column 0)
Expand Down Expand Up @@ -193,11 +193,12 @@ def plot_contours(

from nibabel.affines import apply_affine

from FastSurferCNN.utils.plotting import backend

if vox2ras is None and None in (split_contours, midline_equidistant, levelpaths):
raise ValueError("vox_size must be provided if split_contours, midline_equidistant, or levelpaths are given.")

if output_path is not None:
matplotlib.use('Agg') # Use non-GUI backend
_backend_context = noop_context if output_path is None else partial(backend, 'agg') # Use non-GUI backend

# convert vox_size from LIA to AS
ras2vox = partial(apply_affine, np.linalg.inv(vox2ras)[1:, 1:])
Expand All @@ -210,50 +211,54 @@ def plot_contours(
has_first_plot = not (len(_split_contours) == 0 and ac_coords_vox is None and pc_coords_vox is None)
num_plots = 1 + int(has_first_plot)

fig, ax = plt.subplots(1, num_plots, sharex=True, sharey=True, figsize=(15, 10))
with _backend_context():
# import here to have the correct backend set for non-GUI environments
from matplotlib import pyplot as plt

fig, ax = plt.subplots(1, num_plots, sharex=True, sharey=True, figsize=(15, 10))

# NOTE: For all plots imshow shows y inverted
current_plot = 0
# NOTE: For all plots imshow shows y inverted
current_plot = 0

if _split_contours:
reference_contour = _split_contours[-1]

if _split_contours:
reference_contour = _split_contours[-1]
# This visualization uses voxel coordinates in fsaverage space...
if has_first_plot:
ax[current_plot].imshow(slice_or_slab[slice_or_slab.shape[0] // 2], cmap="gray")
ax[current_plot].set_title(title)
if _split_contours:
for this_contour in _split_contours:
ax[current_plot].fill(this_contour[1, :], this_contour[0, :], color="steelblue", alpha=0.25)
kwargs = {"color": "mediumblue", "linewidth": 0.7, "linestyle": "solid"}
ax[current_plot].plot(this_contour[1, :], this_contour[0, :], **kwargs)
if ac_coords_vox is not None:
ax[current_plot].scatter(ac_coords_vox[1], ac_coords_vox[0], color="red", marker="x")
if pc_coords_vox is not None:
ax[current_plot].scatter(pc_coords_vox[1], pc_coords_vox[0], color="blue", marker="x")
current_plot += int(has_first_plot)

# This visualization uses voxel coordinates in fsaverage space...
if has_first_plot:
ax[current_plot].imshow(slice_or_slab[slice_or_slab.shape[0] // 2], cmap="gray")
ax[current_plot].set_title(title)
if _split_contours:
for this_contour in _split_contours:
ax[current_plot].fill(this_contour[1, :], this_contour[0, :], color="steelblue", alpha=0.25)
kwargs = {"color": "mediumblue", "linewidth": 0.7, "linestyle": "solid"}
ax[current_plot].plot(this_contour[1, :], this_contour[0, :], **kwargs)
if ac_coords_vox is not None:
ax[current_plot].scatter(ac_coords_vox[1], ac_coords_vox[0], color="red", marker="x")
if pc_coords_vox is not None:
ax[current_plot].scatter(pc_coords_vox[1], pc_coords_vox[0], color="blue", marker="x")
current_plot += int(has_first_plot)

ax[current_plot].imshow(slice_or_slab[slice_or_slab.shape[0] // 2], cmap="gray")
for this_path in _levelpaths:
ax[current_plot].plot(this_path[:, 1], this_path[:, 0], color="brown", linewidth=0.8)
ax[current_plot].set_title("Midline & Levelpaths")
if _midline_equi.shape[0] > 0:
ax[current_plot].plot(_midline_equi[:, 1], _midline_equi[:, 0], color="red")
if _split_contours:
ax[current_plot].plot(reference_contour[1, :], reference_contour[0, :], color="red", linewidth=0.5)

padding = 30
for a in ax.flatten():
a.set_aspect("equal", adjustable="box")
a.axis("off")
for this_path in _levelpaths:
ax[current_plot].plot(this_path[:, 1], this_path[:, 0], color="brown", linewidth=0.8)
ax[current_plot].set_title("Midline & Levelpaths")
if _midline_equi.shape[0] > 0:
ax[current_plot].plot(_midline_equi[:, 1], _midline_equi[:, 0], color="red")
if _split_contours:
# get bounding box of contours
a.set_xlim(reference_contour[1, :].min() - padding, reference_contour[1, :].max() + padding)
a.set_ylim((reference_contour[0, :]).max() + padding, (reference_contour[0, :]).min() - padding)

if output_path is None:
return plt.show()
for _output_path in (output_path if isinstance(output_path, (list, tuple)) else [output_path]):
Path(_output_path).parent.mkdir(parents=True, exist_ok=True)
fig.savefig(_output_path, dpi=300, bbox_inches="tight")
ax[current_plot].plot(reference_contour[1, :], reference_contour[0, :], color="red", linewidth=0.5)

padding = 30
for a in ax.flatten():
a.set_aspect("equal", adjustable="box")
a.axis("off")
if _split_contours:
# get bounding box of contours
a.set_xlim(reference_contour[1, :].min() - padding, reference_contour[1, :].max() + padding)
a.set_ylim((reference_contour[0, :]).max() + padding, (reference_contour[0, :]).min() - padding)

if output_path is None:
return plt.show()
for _output_path in (output_path if isinstance(output_path, (list, tuple)) else [output_path]):
Path(_output_path).parent.mkdir(parents=True, exist_ok=True)
fig.savefig(_output_path, dpi=300, bbox_inches="tight")
return None
3 changes: 1 addition & 2 deletions FastSurferCNN/segstats.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,8 +764,7 @@ def main(args: argparse.Namespace) -> Literal[0] | str:
read_lut = manager.make_read_hook(read_classes_from_lut)
if lut_file := getattr(args, "lut", None):
read_lut(lut_file, blocking=False)
# load these files in different threads to avoid waiting on IO
# (not parallel due to GIL though)
# load these files in different threads to avoid waiting on IO (not parallel due to GIL though)
load_image = manager.make_read_hook(read_volume_file)
preload_image = partial(load_image, blocking=False)
preload_image(segfile)
Expand Down
7 changes: 7 additions & 0 deletions FastSurferCNN/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"misc",
"nibabelImage",
"nibabelHeader",
"noop_context",
"parser_defaults",
"parallel",
"Plane",
Expand All @@ -48,6 +49,7 @@
"Vector3d",
]

from contextlib import contextmanager
from typing import Literal, TypeVar

# there are very few cases, when we do not need nibabel in any "full script" so always
Expand Down Expand Up @@ -92,3 +94,8 @@
Mask3d = ndarray[Shape3d, dtype[bool_]]
Mask4d = ndarray[Shape4d, dtype[bool_]]
RotationMatrix3x3 = ndarray[tuple[Literal[3], Literal[3]], dtype[float64]]

@contextmanager
def noop_context():
"""A no-op context manager that does nothing."""
yield
Loading