diff --git a/CorpusCallosum/segmentation/segmentation_postprocessing.py b/CorpusCallosum/segmentation/segmentation_postprocessing.py index 5f12ddbb..4f5a9fc1 100644 --- a/CorpusCallosum/segmentation/segmentation_postprocessing.py +++ b/CorpusCallosum/segmentation/segmentation_postprocessing.py @@ -265,29 +265,29 @@ def connect_nearby_components(seg_arr: ArrayType, max_connection_distance: float # Plot components for debugging if plot: - import matplotlib import matplotlib.pyplot as plt - curr_backend = matplotlib.get_backend() - plt.switch_backend("qtagg") - n_components = len(component_sizes) - fig, axes = plt.subplots(1, n_components + 1, figsize=(5*(n_components + 1), 5)) - if n_components == 1: - axes = [axes] - # Plot each component in a different color - for i, (comp_id, comp_size) in enumerate(component_sizes): - component_mask = labels_cc == comp_id - axes[i].imshow(component_mask[component_mask.shape[0]//2], cmap='gray') - axes[i].set_title(f'Component {comp_id}\nSize: {comp_size}') - axes[i].axis('off') - - # Plot the connected segmentation - axes[-1].imshow(connected_seg[connected_seg.shape[0]//2], cmap='gray') - axes[-1].set_title('Connected Segmentation') - axes[-1].axis('off') - plt.tight_layout() - plt.show() - plt.switch_backend(curr_backend) - + + from FastSurferCNN.utils.plotting import backend + + with backend("qtagg"): + n_components = len(component_sizes) + fig, axes = plt.subplots(1, n_components + 1, figsize=(5*(n_components + 1), 5)) + if n_components == 1: + axes = [axes] + # Plot each component in a different color + for i, (comp_id, comp_size) in enumerate(component_sizes): + component_mask = labels_cc == comp_id + axes[i].imshow(component_mask[component_mask.shape[0]//2], cmap='gray') + axes[i].set_title(f'Component {comp_id}\nSize: {comp_size}') + axes[i].axis('off') + + # Plot the connected segmentation + axes[-1].imshow(connected_seg[connected_seg.shape[0]//2], cmap='gray') + axes[-1].set_title('Connected Segmentation') + axes[-1].axis('off') + plt.tight_layout() + plt.show() + return connected_seg diff --git a/CorpusCallosum/shape/endpoint_heuristic.py b/CorpusCallosum/shape/endpoint_heuristic.py index 09e8e128..c5e29c36 100644 --- a/CorpusCallosum/shape/endpoint_heuristic.py +++ b/CorpusCallosum/shape/endpoint_heuristic.py @@ -205,25 +205,24 @@ def find_cc_endpoints( pc_startpoint_idx = np.argmin(np.linalg.norm(contour - posterior_anchor_2d[:, None], axis=0)) if plot: # interactive debug plot of contour, ac, pc and endpoints - import matplotlib import matplotlib.pyplot as plt - curr_backend = matplotlib.get_backend() - plt.switch_backend("qtagg") - plt.figure(figsize=(10, 8)) - plt.plot(contour[0, :], contour[1, :], 'b-', label='CC Contour', linewidth=2) - plt.plot(ac_2d[0], ac_2d[1], 'go', markersize=10, label='AC') - plt.plot(pc_2d[0], pc_2d[1], 'ro', markersize=10, label='PC') - plt.plot(anterior_anchor_2d[0], anterior_anchor_2d[1], 'g^', markersize=10, label='Anterior Anchor') - plt.plot(posterior_anchor_2d[0], posterior_anchor_2d[1], 'r^', markersize=10, label='Posterior Anchor') - plt.plot(contour[0, ac_startpoint_idx], contour[1, ac_startpoint_idx], 'g*', markersize=15, label='AC Endpoint') - plt.plot(contour[0, pc_startpoint_idx], contour[1, pc_startpoint_idx], 'r*', markersize=15, label='PC Endpoint') - plt.xlabel('A-S (mm)') - plt.ylabel('I-S (mm)') - plt.title('CC Contour with Endpoints') - plt.legend() - plt.axis('equal') - plt.grid(True, alpha=0.3) - plt.show() - plt.switch_backend(curr_backend) + + from FastSurferCNN.utils.plotting import backend + with backend("qtagg"): + plt.figure(figsize=(10, 8)) + plt.plot(contour[0, :], contour[1, :], 'b-', label='CC Contour', linewidth=2) + plt.plot(*ac_2d[0:2], 'go', markersize=10, label='AC') + plt.plot(*pc_2d[0:2], 'ro', markersize=10, label='PC') + plt.plot(*anterior_anchor_2d[0:2], 'g^', markersize=10, label='Anterior Anchor') + plt.plot(*posterior_anchor_2d[0:2], 'r^', markersize=10, label='Posterior Anchor') + plt.plot(*contour[0:2, ac_startpoint_idx], 'g*', markersize=15, label='AC Endpoint') + plt.plot(*contour[0:2, pc_startpoint_idx], 'r*', markersize=15, label='PC Endpoint') + plt.xlabel('A-S (mm)') + plt.ylabel('I-S (mm)') + plt.title('CC Contour with Endpoints') + plt.legend() + plt.axis('equal') + plt.grid(True, alpha=0.3) + plt.show() return ac_startpoint_idx, pc_startpoint_idx diff --git a/CorpusCallosum/shape/metrics.py b/CorpusCallosum/shape/metrics.py index 4003b7b7..6442d26c 100644 --- a/CorpusCallosum/shape/metrics.py +++ b/CorpusCallosum/shape/metrics.py @@ -214,24 +214,22 @@ def calculate_cc_index(cc_contour: np.ndarray, plot: bool = False) -> float: cc_index = (anterior_thickness + posterior_thickness + middle_thickness) / ap_distance if plot: - import matplotlib import matplotlib.pyplot as plt - curr_backend = matplotlib.get_backend() - plt.switch_backend("qtagg") - - fig, ax = plt.subplots(figsize=(8, 6)) - plot_cc_index_calculation( - ax, - cc_contour, - anterior_idx, - posterior_idx, - ap_intersections, - middle_intersections, - midpoint, - ) - ax.legend() - plt.show() - plt.switch_backend(curr_backend) + + from FastSurferCNN.utils.plotting import backend + with backend("qtagg"): + fig, ax = plt.subplots(figsize=(8, 6)) + plot_cc_index_calculation( + ax, + cc_contour, + anterior_idx, + posterior_idx, + ap_intersections, + middle_intersections, + midpoint, + ) + ax.legend() + plt.show() return cc_index diff --git a/CorpusCallosum/shape/postprocessing.py b/CorpusCallosum/shape/postprocessing.py index 6ad7b421..e1a3df6c 100644 --- a/CorpusCallosum/shape/postprocessing.py +++ b/CorpusCallosum/shape/postprocessing.py @@ -629,19 +629,19 @@ def make_subdivision_mask( subdivision_mask[points_left_of_line] = label if plot: # interactive debug plot - import matplotlib import matplotlib.pyplot as plt - curr_backend = matplotlib.get_backend() - plt.switch_backend("qtagg") - plt.figure(figsize=(10, 8)) - plt.imshow(subdivision_mask, cmap='tab10') - plt.colorbar(label='Subdivision') - plt.title('CC Subdivision Mask') - plt.xlabel('X') - plt.ylabel('Y') - plt.tight_layout() - plt.show() - plt.switch_backend(curr_backend) + + from FastSurferCNN.utils.plotting import backend + with backend("qtagg"): + plt.figure(figsize=(10, 8)) + plt.imshow(subdivision_mask, cmap='tab10') + plt.colorbar(label='Subdivision') + plt.title('CC Subdivision Mask') + plt.xlabel('X') + plt.ylabel('Y') + plt.tight_layout() + plt.show() + return subdivision_mask diff --git a/CorpusCallosum/utils/visualization.py b/CorpusCallosum/utils/visualization.py index 74b17c62..da4267fd 100644 --- a/CorpusCallosum/utils/visualization.py +++ b/CorpusCallosum/utils/visualization.py @@ -14,17 +14,16 @@ from pathlib import Path -import matplotlib -import matplotlib.pyplot as plt +import matplotlib.axes import nibabel as nib import numpy as np from CorpusCallosum.utils.types import ContourList, Polygon2dType -from FastSurferCNN.utils import AffineMatrix4x4, Image3d, Vector2d +from FastSurferCNN.utils import AffineMatrix4x4, Image3d, Vector2d, noop_context def plot_standardized_space( - ax_row: list[plt.Axes], + ax_row: list[matplotlib.axes.Axes], vol: np.ndarray, ac_coords: np.ndarray, pc_coords: np.ndarray @@ -33,7 +32,7 @@ def plot_standardized_space( Parameters ---------- - ax_row : list[plt.Axes] + ax_row : list[matplotlib.axes.Axes] Row of axes to plot on (should be length 3). vol : np.ndarray Volume data to visualize. @@ -125,6 +124,7 @@ def visualize_coordinate_spaces( 3. standardized image space as a single image named 'ac_pc_spaces.png' in `output_dir`. """ + from matplotlib import pyplot as plt fig, ax = plt.subplots(3, 3, figsize=(12, 12)) # Original space (Column 0) @@ -193,11 +193,12 @@ def plot_contours( from nibabel.affines import apply_affine + from FastSurferCNN.utils.plotting import backend + if vox2ras is None and None in (split_contours, midline_equidistant, levelpaths): raise ValueError("vox_size must be provided if split_contours, midline_equidistant, or levelpaths are given.") - if output_path is not None: - matplotlib.use('Agg') # Use non-GUI backend + _backend_context = noop_context if output_path is None else partial(backend, 'agg') # Use non-GUI backend # convert vox_size from LIA to AS ras2vox = partial(apply_affine, np.linalg.inv(vox2ras)[1:, 1:]) @@ -210,50 +211,54 @@ def plot_contours( has_first_plot = not (len(_split_contours) == 0 and ac_coords_vox is None and pc_coords_vox is None) num_plots = 1 + int(has_first_plot) - fig, ax = plt.subplots(1, num_plots, sharex=True, sharey=True, figsize=(15, 10)) + with _backend_context(): + # import here to have the correct backend set for non-GUI environments + from matplotlib import pyplot as plt + + fig, ax = plt.subplots(1, num_plots, sharex=True, sharey=True, figsize=(15, 10)) - # NOTE: For all plots imshow shows y inverted - current_plot = 0 + # NOTE: For all plots imshow shows y inverted + current_plot = 0 + + if _split_contours: + reference_contour = _split_contours[-1] - if _split_contours: - reference_contour = _split_contours[-1] + # This visualization uses voxel coordinates in fsaverage space... + if has_first_plot: + ax[current_plot].imshow(slice_or_slab[slice_or_slab.shape[0] // 2], cmap="gray") + ax[current_plot].set_title(title) + if _split_contours: + for this_contour in _split_contours: + ax[current_plot].fill(this_contour[1, :], this_contour[0, :], color="steelblue", alpha=0.25) + kwargs = {"color": "mediumblue", "linewidth": 0.7, "linestyle": "solid"} + ax[current_plot].plot(this_contour[1, :], this_contour[0, :], **kwargs) + if ac_coords_vox is not None: + ax[current_plot].scatter(ac_coords_vox[1], ac_coords_vox[0], color="red", marker="x") + if pc_coords_vox is not None: + ax[current_plot].scatter(pc_coords_vox[1], pc_coords_vox[0], color="blue", marker="x") + current_plot += int(has_first_plot) - # This visualization uses voxel coordinates in fsaverage space... - if has_first_plot: ax[current_plot].imshow(slice_or_slab[slice_or_slab.shape[0] // 2], cmap="gray") - ax[current_plot].set_title(title) - if _split_contours: - for this_contour in _split_contours: - ax[current_plot].fill(this_contour[1, :], this_contour[0, :], color="steelblue", alpha=0.25) - kwargs = {"color": "mediumblue", "linewidth": 0.7, "linestyle": "solid"} - ax[current_plot].plot(this_contour[1, :], this_contour[0, :], **kwargs) - if ac_coords_vox is not None: - ax[current_plot].scatter(ac_coords_vox[1], ac_coords_vox[0], color="red", marker="x") - if pc_coords_vox is not None: - ax[current_plot].scatter(pc_coords_vox[1], pc_coords_vox[0], color="blue", marker="x") - current_plot += int(has_first_plot) - - ax[current_plot].imshow(slice_or_slab[slice_or_slab.shape[0] // 2], cmap="gray") - for this_path in _levelpaths: - ax[current_plot].plot(this_path[:, 1], this_path[:, 0], color="brown", linewidth=0.8) - ax[current_plot].set_title("Midline & Levelpaths") - if _midline_equi.shape[0] > 0: - ax[current_plot].plot(_midline_equi[:, 1], _midline_equi[:, 0], color="red") - if _split_contours: - ax[current_plot].plot(reference_contour[1, :], reference_contour[0, :], color="red", linewidth=0.5) - - padding = 30 - for a in ax.flatten(): - a.set_aspect("equal", adjustable="box") - a.axis("off") + for this_path in _levelpaths: + ax[current_plot].plot(this_path[:, 1], this_path[:, 0], color="brown", linewidth=0.8) + ax[current_plot].set_title("Midline & Levelpaths") + if _midline_equi.shape[0] > 0: + ax[current_plot].plot(_midline_equi[:, 1], _midline_equi[:, 0], color="red") if _split_contours: - # get bounding box of contours - a.set_xlim(reference_contour[1, :].min() - padding, reference_contour[1, :].max() + padding) - a.set_ylim((reference_contour[0, :]).max() + padding, (reference_contour[0, :]).min() - padding) - - if output_path is None: - return plt.show() - for _output_path in (output_path if isinstance(output_path, (list, tuple)) else [output_path]): - Path(_output_path).parent.mkdir(parents=True, exist_ok=True) - fig.savefig(_output_path, dpi=300, bbox_inches="tight") + ax[current_plot].plot(reference_contour[1, :], reference_contour[0, :], color="red", linewidth=0.5) + + padding = 30 + for a in ax.flatten(): + a.set_aspect("equal", adjustable="box") + a.axis("off") + if _split_contours: + # get bounding box of contours + a.set_xlim(reference_contour[1, :].min() - padding, reference_contour[1, :].max() + padding) + a.set_ylim((reference_contour[0, :]).max() + padding, (reference_contour[0, :]).min() - padding) + + if output_path is None: + return plt.show() + for _output_path in (output_path if isinstance(output_path, (list, tuple)) else [output_path]): + Path(_output_path).parent.mkdir(parents=True, exist_ok=True) + fig.savefig(_output_path, dpi=300, bbox_inches="tight") return None diff --git a/FastSurferCNN/segstats.py b/FastSurferCNN/segstats.py index 1955c8d8..7a662b1a 100644 --- a/FastSurferCNN/segstats.py +++ b/FastSurferCNN/segstats.py @@ -764,8 +764,7 @@ def main(args: argparse.Namespace) -> Literal[0] | str: read_lut = manager.make_read_hook(read_classes_from_lut) if lut_file := getattr(args, "lut", None): read_lut(lut_file, blocking=False) - # load these files in different threads to avoid waiting on IO - # (not parallel due to GIL though) + # load these files in different threads to avoid waiting on IO (not parallel due to GIL though) load_image = manager.make_read_hook(read_volume_file) preload_image = partial(load_image, blocking=False) preload_image(segfile) diff --git a/FastSurferCNN/utils/__init__.py b/FastSurferCNN/utils/__init__.py index 223b208a..0aacf105 100644 --- a/FastSurferCNN/utils/__init__.py +++ b/FastSurferCNN/utils/__init__.py @@ -31,6 +31,7 @@ "misc", "nibabelImage", "nibabelHeader", + "noop_context", "parser_defaults", "parallel", "Plane", @@ -48,6 +49,7 @@ "Vector3d", ] +from contextlib import contextmanager from typing import Literal, TypeVar # there are very few cases, when we do not need nibabel in any "full script" so always @@ -92,3 +94,8 @@ Mask3d = ndarray[Shape3d, dtype[bool_]] Mask4d = ndarray[Shape4d, dtype[bool_]] RotationMatrix3x3 = ndarray[tuple[Literal[3], Literal[3]], dtype[float64]] + +@contextmanager +def noop_context(): + """A no-op context manager that does nothing.""" + yield \ No newline at end of file diff --git a/FastSurferCNN/utils/plotting.py b/FastSurferCNN/utils/plotting.py new file mode 100644 index 00000000..3f6700ce --- /dev/null +++ b/FastSurferCNN/utils/plotting.py @@ -0,0 +1,51 @@ +#!/bin/python + +# Copyright 2026 Image Analysis Lab, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from contextlib import contextmanager + + +@contextmanager +def backend(backend: str): + """ + Context manager to temporarily set the matplotlib backend. + + Parameters + ---------- + backend : str + The name of the matplotlib backend to use within the context. + + Yields + ------ + None + This function does not yield any value, it only sets the backend temporarily. + """ + import matplotlib + + original_backend = matplotlib.get_backend() + try: + matplotlib.use(backend, force=True) + yield + finally: + try: + matplotlib.use(original_backend, force=True) + except Exception as e: + logger = logging.getLogger(__name__) + logger.warning(f"Failed switching back to the original matplotlib backend {original_backend}: " + f"{', '.join(map(str, e.args))}") + # Fall back to a safe non-interactive backend if the original one cannot be restored in the current + # environment. + matplotlib.use("Agg", force=True) diff --git a/HypVINN/inference.py b/HypVINN/inference.py index 017f880e..8c19ac6d 100644 --- a/HypVINN/inference.py +++ b/HypVINN/inference.py @@ -52,7 +52,6 @@ class Inference: def __init__( self, cfg, - threads: int = -1, async_io: bool = False, device: str = "auto", viewagg_device: str = "auto", @@ -68,8 +67,6 @@ def __init__( ---------- cfg : yacs.config.CfgNode The configuration node containing the parameters for the model. - threads : int, optional - The number of threads to use. Default is -1, which uses all available threads. async_io : bool, optional Whether to use asynchronous IO. Default is False. device : str, optional @@ -77,8 +74,9 @@ def __init__( viewagg_device : str, optional The device to use for view aggregation. Can be 'auto', 'cpu', or 'cuda'. Default is 'auto'. """ - self._threads = threads - torch.set_num_threads(self._threads) + from FastSurferCNN.utils.parallel import get_num_threads + + torch.set_num_threads(get_num_threads()) self._async_io = async_io # Set random seed from configs. diff --git a/HypVINN/run_prediction.py b/HypVINN/run_prediction.py index f789a6ff..14486c28 100644 --- a/HypVINN/run_prediction.py +++ b/HypVINN/run_prediction.py @@ -23,16 +23,15 @@ if TYPE_CHECKING: import yacs.config - from nibabel.filebasedimages import FileBasedHeader -from FastSurferCNN.utils import PLANES, Plane, logging, parser_defaults +from FastSurferCNN.utils import PLANES, Plane, logging, nibabelHeader, parser_defaults from FastSurferCNN.utils.checkpoint import ( get_checkpoints, get_config_file, load_checkpoint_config_defaults, ) from FastSurferCNN.utils.common import update_docstring -from FastSurferCNN.utils.parallel import SerialExecutor +from FastSurferCNN.utils.parallel import get_num_threads, thread_executor from HypVINN.config.hypvinn_files import HYPVINN_MASK_NAME, HYPVINN_SEG_NAME from HypVINN.data_loader.data_utils import hypo_map_label2subseg, rescale_image from HypVINN.inference import Inference @@ -157,7 +156,6 @@ def main( hypo_maskfile: str = HYPVINN_MASK_NAME, qc_snapshots: bool = False, reg_mode: Literal["coreg", "robust", "none"] = "coreg", - threads: int = -1, batch_size: int = 1, async_io: bool = False, device: str = "auto", @@ -189,32 +187,29 @@ def main( cfg_sag : Path The path to the sagittal configuration file. hypo_segfile : str, default="{HYPVINN_SEG_NAME}" - The name of the hypothalamus segmentation file. Default is {HYPVINN_SEG_NAME}. + The name of the hypothalamus segmentation file. hypo_maskfile : str, default="{HYPVINN_MASK_NAME}" - The name of the hypothalamus mask file. Default is {HYPVINN_MASK_NAME}. - qc_snapshots : bool, optional - Whether to create QC snapshots. Default is False. + The name of the hypothalamus mask file. + qc_snapshots : bool, default=False + Whether to create QC snapshots. reg_mode : "coreg", "robust", "none", default="coreg" - The registration mode to use. Default is "coreg". - threads : int, default=-1 - The number of threads to use. Default is -1, which uses all available threads. + The registration mode to use. batch_size : int, default=1 - The batch size to use. Default is 1. + The batch size to use. async_io : bool, default=False - Whether to use asynchronous I/O. Default is False. + Whether to use asynchronous I/O. device : str, default="auto" - The device to use. Default is "auto", which automatically selects the device. + The device to use. "auto" automatically selects the device. viewagg_device : str, default="auto" - The view aggregation device to use. Default is "auto", which automatically selects the device. + The view aggregation device to use. "auto" automatically selects the device. Returns ------- int, str 0, if successful, an error message describing the cause for the failure otherwise. """ - from concurrent.futures import Future, ProcessPoolExecutor + from concurrent.futures import Future - pool = ProcessPoolExecutor(threads) if threads != 1 else SerialExecutor() prep_tasks: dict[str, Future] = {} # mapped freesurfer orig input name to the hypvinn t1 name @@ -225,7 +220,7 @@ def main( start = time() try: # Set up logging - prep_tasks["cp"] = pool.submit(prepare_checkpoints, ckpt_ax, ckpt_cor, ckpt_sag) + prep_tasks["cp"] = thread_executor().submit(prepare_checkpoints, ckpt_ax, ckpt_cor, ckpt_sag) kwargs = {} if t1_path is not None: @@ -242,7 +237,7 @@ def main( f"available." ) - # Create output directory if it does not already exist. + # Create the output directory if it does not already exist. create_expand_output_directory(subject_dir, qc_snapshots) logger.info(f"Running HypVINN segmentation pipeline on subject {sid}") logger.info(f"Output will be stored in: {subject_dir}") @@ -254,12 +249,12 @@ def main( # Note, that t1_path and t2_path are guaranteed to be not None via # get_hypvinn_mode, which only returns t1t2, if t1 and t2 exist. # hypvinn_preproc returns the path to the t2 that is registered to the t1 - prep_tasks["reg"] = pool.submit( + prep_tasks["reg"] = thread_executor().submit( hypvinn_preproc, mode, reg_mode, subject_dir=Path(subject_dir), - threads=threads, + threads=get_num_threads(), **kwargs, ) @@ -288,13 +283,12 @@ def main( if "reg" in prep_tasks: t2_path = prep_tasks["reg"].result() kwargs["t2_path"] = t2_path - prep_tasks["load"] = pool.submit(load_volumes, mode=mode, **kwargs) + prep_tasks["load"] = thread_executor().submit(load_volumes, mode=mode, **kwargs) # Set up model model = Inference( cfg=cfg_fin, async_io=async_io, - threads=threads, viewagg_device=viewagg_device, device=device, ) @@ -336,20 +330,17 @@ def main( ras_affine=affine, ras_header=header, subject_dir=subject_dir, - seg_file=hypo_segfile, + seg_file=Path(hypo_segfile), mask_file=hypo_maskfile, save_mask=True, ) logger.info(f"Prediction successfully saved in {time_needed} seconds.") if qc_snapshots: - qc_future: Future | None = pool.submit( + qc_future: Future | None = thread_executor().submit( plot_qc_images, subject_qc_dir=subject_dir / "qc_snapshots", orig_path=orig_path, - prediction_path=Path(subject_dir / "mri" /hypo_segfile), - ) - qc_future.add_done_callback( - lambda x: logger.info(f"QC snapshots saved in {x.result()} seconds."), + prediction_path=subject_dir / "mri" / hypo_segfile, ) else: qc_future = None @@ -357,30 +348,36 @@ def main( logger.info("Computing stats") return_value = compute_stats( orig_path=orig_path, - prediction_path=Path(subject_dir / "mri" /hypo_segfile), + prediction_path=subject_dir / "mri" / hypo_segfile, stats_dir=subject_dir / "stats", - threads=threads, ) if return_value != 0: + # if not 0, return_value is a string describing the error logger.error(return_value) - logger.info( - f"Processing segmentation finished in {time() - seg:0.4f} seconds." - ) + logger.info(f"Processing segmentation finished in {time() - seg:0.4f} seconds.") except (FileNotFoundError, RuntimeError) as e: logger.info(f"Failed Evaluation on {subject_name}:") logger.exception(e) + + return f"HypVINN segmentation pipeline failed with {type(e).__name__}: {'; '.join(map(str, e.args))}." else: if qc_future: # finish qc - qc_future.result() + if e := qc_future.exception(): + logger.warning(f"Failed to create qc snapshots for {subject_name}:") + logger.exception(e) - logger.info( - f"Processing whole pipeline finished in {time() - start:.4f} seconds." - ) + # Note that a failure of qc image generation is only a warning for the whole hypothalamus segmentation. + else: + logger.info(f"QC snapshots saved in {qc_future.result()} seconds.") + + logger.info(f"Processing whole pipeline finished in {time() - start:.4f} seconds.") + + return return_value -def prepare_checkpoints(ckpt_ax, ckpt_cor, ckpt_sag): +def prepare_checkpoints(ckpt_ax: str | Path, ckpt_cor: str | Path, ckpt_sag: str | Path) -> None: """ Prepare the checkpoints for the Hypothalamus Segmentation model. @@ -389,11 +386,11 @@ def prepare_checkpoints(ckpt_ax, ckpt_cor, ckpt_sag): Parameters ---------- - ckpt_ax : str + ckpt_ax : str, Path The path to the axial checkpoint file. - ckpt_cor : str + ckpt_cor : str. Path The path to the coronal checkpoint file. - ckpt_sag : str + ckpt_sag : str, Path The path to the sagittal checkpoint file. """ logger.info("Checking or downloading default checkpoints ...") @@ -409,7 +406,7 @@ def load_volumes( ) -> tuple[ ModalityDict, npt.NDArray[float], - "FileBasedHeader", + "nibabelHeader", tuple[float, float, float], tuple[int, int, int], ]: @@ -455,7 +452,7 @@ def load_volumes( t1_zoom = () t2_zoom = () affine: npt.NDArray[float] = np.ndarray([0]) - header: FileBasedHeader | None = None + header: nibabelHeader | None = None zoom: tuple[float, float, float] = (0.0, 0.0, 0.0) size: tuple[int, ...] = (0, 0, 0) diff --git a/HypVINN/utils/stats_utils.py b/HypVINN/utils/stats_utils.py index b9e8edfe..cdb79b1a 100644 --- a/HypVINN/utils/stats_utils.py +++ b/HypVINN/utils/stats_utils.py @@ -19,7 +19,6 @@ def compute_stats( orig_path: Path, prediction_path: Path, stats_dir: Path, - threads: int, ) -> int | str: """ Compute statistics for the segmentation results. @@ -32,8 +31,6 @@ def compute_stats( The path to the predicted segmentation. stats_dir : Path The directory for storing the statistics. - threads : int - The number of threads to be used. Returns ------- @@ -45,6 +42,11 @@ def compute_stats( ------ RuntimeError If the main function from FastSurferCNN.segstats fails to run. + + Notes + ----- + The underlying segstats will read the number of threads from the global variable set via + `FastSurfer.utils.parallel.set_num_threads`. """ from collections import namedtuple @@ -67,7 +69,7 @@ def compute_stats( args.ids = labels args.merged_labels = [] args.robust = None - args.threads = threads + # the threads argument no longer works, this is handled globally via set_num_threads and get_num_threads args.patch_size = 32 args.device = "auto" args.lut = FASTSURFER_ROOT / "FastSurferCNN/config/FreeSurferColorLUT.txt" diff --git a/HypVINN/utils/visualization_utils.py b/HypVINN/utils/visualization_utils.py index f3bf6ddb..8fb2145b 100644 --- a/HypVINN/utils/visualization_utils.py +++ b/HypVINN/utils/visualization_utils.py @@ -13,17 +13,13 @@ # limitations under the License. from pathlib import Path -import matplotlib.pyplot as plt import nibabel as nib import numpy as np from FastSurferCNN.utils.common import update_docstring - -#from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT +from FastSurferCNN.utils.plotting import backend from HypVINN.config.hypvinn_files import HYPVINN_LUT -#_doc_HYPVINN_LUT = os.path.relpath(HYPVINN_LUT, FASTSURFER_ROOT) - def remove_values_from_list(the_list, val): """ @@ -250,7 +246,8 @@ def plot_qc_images( prediction_path: Path, padd: int = 45, lut_file: Path = HYPVINN_LUT, - slice_step: int = 2): + slice_step: int = 2, +) -> None: """ Plot the quality control images for the subject. @@ -290,39 +287,27 @@ def plot_qc_images( hypo_seg, cmap = map_hyposeg2label(hyposeg=mod_pred, lut_file=lut_file) if len(idx) > 0: - crop_image = mod_image[idx, :, :] - crop_seg = hypo_seg[idx, :, :] - - cm = ndimage.center_of_mass(crop_seg > 0) - - cm = np.asarray(cm).astype(int) - - crop_image = crop_image[:, cm[1] - padd:cm[1] + padd, cm[2] - padd:cm[2] + padd] - crop_seg = crop_seg[:, cm[1] - padd:cm[1] + padd, cm[2] - padd:cm[2] + padd] - + center_of_mass = np.asarray(ndimage.center_of_mass(crop_seg > 0), dtype=int) else: depth = hypo_seg.shape[0] // 2 crop_image = mod_image[depth - 8:depth + 8, :, :] crop_seg = hypo_seg[depth - 8:depth + 8, :, :] - cm = [crop_image.shape[0] // 2, crop_image.shape[1] // 2, crop_image.shape[2] // 2] - cm = np.array(cm).astype(int) - - crop_image = crop_image[:, cm[1] - padd:cm[1] + padd, cm[2] - padd:cm[2] + padd] - crop_seg = crop_seg[:, cm[1] - padd:cm[1] + padd, cm[2] - padd:cm[2] + padd] + center_of_mass = np.asarray([0, crop_image.shape[1] // 2, crop_image.shape[2] // 2], dtype=int) - crop_image = np.rot90(np.flip(crop_image, axis=0), k=-1, axes=(1, 2)) - crop_seg = np.rot90(np.flip(crop_seg, axis=0), k=-1, axes=(1, 2)) + com_indexer = (slice(None),) + tuple(slice(com - padd, com + padd) for com in center_of_mass[1:3]) - fig = plot_coronal_predictions( - cmap=cmap, - images_batch=crop_image, - pred_batch=crop_seg, - img_per_row=crop_image.shape[0], - ) + def _data_around_com(data): + return np.rot90(np.flip(data[com_indexer], axis=0), k=-1, axes=(1, 2)) - fig.savefig(subject_qc_dir / HYPVINN_QC_IMAGE_NAME, transparent=False) + with backend('agg'): + fig = plot_coronal_predictions( + cmap=cmap, + images_batch=_data_around_com(crop_image), + pred_batch=_data_around_com(crop_seg), + img_per_row=crop_image.shape[0], + ) - plt.close(fig) + fig.savefig(subject_qc_dir / HYPVINN_QC_IMAGE_NAME, transparent=False)