diff --git a/.codespellignore b/.codespellignore index df7814600..2cbe5c5e8 100644 --- a/.codespellignore +++ b/.codespellignore @@ -4,3 +4,4 @@ padd struc TE warmup +ALS diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml new file mode 100644 index 000000000..68ab6290c --- /dev/null +++ b/.github/workflows/unittest.yaml @@ -0,0 +1,88 @@ +name: test + +# File: test.yaml +# Author: David Kügler +# Created on: 2025-09-08 +# Functionality: This workflow runs unit tests defined in tests/image. + +concurrency: + group: ${{ github.workflow }}-${{ github.event.number }}-${{ github.event.ref }} + cancel-in-progress: true + +on: + pull_request: + workflow_dispatch: + +permissions: + actions: read + attestations: read + checks: read + contents: read + deployments: read + issues: read + discussions: read + packages: read + pages: read + pull-requests: write + repository-projects: read + security-events: read + statuses: read + +jobs: + image-test: + name: 'Run FastSurfer unit tests from test/image' + runs-on: ubuntu-latest + timeout-minutes: 180 + strategy: + matrix: + tests: ["image"] + pytest-flags: [""] + + steps: + - uses: actions/checkout@v4 + - name: Setup Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: '3.10' + architecture: 'x64' + cache: 'pip' # caching pip dependencies + - name: Install dependencies + shell: bash + # uses the same python environment as quicktest + run: | + echo "::group::Create python environment" + python -m pip install --progress-bar off --upgrade pip setuptools wheel + python -m pip install --progress-bar off .[quicktest] + echo "::endgroup::" + # run pytest checks for data consistency/quality + - name: Run pytest + env: + PYTHON_PATH: . 
+ FASTSURFER_HOME: ${{ github.workspace }} + shell: bash + continue-on-error: true # pytest only exits with code 0, if all tests pass + run: | + echo "::group::Run tests" + flags=("--junit-xml=/tmp/fastsurfer-unittest-${{ matrix.tests }}.junit.xml") + if [[ "$ACTION_RUNNER_DEBUG" == "true" ]] + then + flags+=(-vv --log_cli_level=DEBUG) + fi + python -m pytest "${flags[@]}" ${{ matrix.pytest-flags }} test/${{ matrix.tests }} + echo "::endgroup::" + - name: Upload the JUnit XML file as an artifact + uses: actions/upload-artifact@v4 + with: + name: fastsurfer-${{ github.sha }}-${{ matrix.tests }}-junit + path: /tmp/fastsurfer-unittest-${{ matrix.tests }}.junit.xml + - name: Annotate the results into the check + uses: mikepenz/action-junit-report@v5 + with: + report_paths: /tmp/fastsurfer-unittest-${{ matrix.tests }}.junit.xml + check_name: Annotate the test results as checks + fail_on_failure: 'true' + fail_on_parse_error: 'true' + # here, we only add the annotations, the fail-state cannot be set here by this action because the github token + # does not have the required permissions in PRs from forks. 
+ # See: https://github.com/mikepenz/action-junit-report?tab=readme-ov-file#pr-run-permissions + annotate_only: 'true' diff --git a/FastSurferCNN/data_loader/conform.py b/FastSurferCNN/data_loader/conform.py index 5e56e9d87..83b2d424d 100644 --- a/FastSurferCNN/data_loader/conform.py +++ b/FastSurferCNN/data_loader/conform.py @@ -19,18 +19,16 @@ import re import sys from collections.abc import Callable, Iterable, Sequence -from typing import TYPE_CHECKING, Literal, TypeVar, Union, cast +from typing import TYPE_CHECKING, Literal, TypeVar, cast import nibabel import nibabel as nib import numpy as np -import numpy.typing as npt -from nibabel.freesurfer.mghformat import MGHHeader +from nibabel.freesurfer.mghformat import MGHHeader, MGHImage +from numpy import typing as npt -if TYPE_CHECKING: - import torch - -from FastSurferCNN.utils import ScalarType, logging, nibabelImage +from FastSurferCNN.utils import AffineMatrix4x4, ScalarType, Shape1d, logging, nibabelHeader, nibabelImage +from FastSurferCNN.utils.affines import OrntArrayType, aff2axcodes, io_orientation from FastSurferCNN.utils.arg_types import ImageSizeOption, OrientationType, StrictOrientationType, VoxSizeOption from FastSurferCNN.utils.arg_types import float_gt_zero_and_le_one as __conform_to_one_mm from FastSurferCNN.utils.arg_types import img_size as __img_size @@ -56,11 +54,19 @@ Modified by: David Kügler Date: May-12-2025 """ +FIX_MGH_AFFINE_CALCULATION = False +FIX_CENTER_NOT_CENTER = True LOGGER = logging.getLogger(__name__) -_TA = TypeVar("_TA", bound=Union[np.ndarray, "torch.Tensor"]) -_TB = TypeVar("_TB", bound=Union[np.ndarray, "torch.Tensor"]) +if TYPE_CHECKING: + from torch import Tensor + + _TA = TypeVar("_TA", bound=np.ndarray | Tensor) + _TB = TypeVar("_TB", bound=np.ndarray | Tensor) +else: + _TA = TypeVar("_TA", bound=np.ndarray) + _TB = TypeVar("_TB", bound=np.ndarray) def __rescale_type(a: str) -> float | int | None: @@ -254,7 +260,7 @@ def options_parse(): def to_target_orientation( 
image_data: _TA, - source_affine: npt.NDArray[float], + source_affine: AffineMatrix4x4, target_orientation: StrictOrientationType, ) -> tuple[_TA, Callable[[_TB], _TB]]: """ @@ -264,7 +270,7 @@ def to_target_orientation( ---------- image_data : np.ndarray, torch.Tensor The image data to reorder/flip. - source_affine : npt.NDArray[float] + source_affine : AffineMatrix4x4 The affine to detect the reorientation operations. target_orientation : StrictOrientationType The target orientation to reorient to. @@ -276,51 +282,67 @@ def to_target_orientation( Callable[[np.ndarray], np.ndarray], Callable[[torch.Tensor], torch.Tensor] A function that flips and reorders the data back (returns same type as output). """ - reorient_ornt, unorient_ornt = orientation_to_ornts(source_affine, target_orientation) - - if np.any([reorient_ornt[:, 1] != 1, reorient_ornt[:, 0] != np.arange(reorient_ornt.shape[0])]): # is not lia yet - def back_to_native(data: _TB) -> _TB: - return apply_orientation(data, unorient_ornt) - - return apply_orientation(image_data, reorient_ornt), back_to_native - else: # data is already in lia - def do_nothing(data: _TB) -> _TB: - return data + def do_nothing(data: _TB) -> _TB: + return data + _target_orientation = target_orientation[slice(5 if target_orientation.lower().startswith("soft") else 0)].lower() + if _target_orientation == "native": # should not really happen, but let's be safe return image_data, do_nothing - -def orientation_to_ornts( - source_affine: npt.NDArray[float], - target_orientation: StrictOrientationType, -) -> tuple[npt.NDArray[int], npt.NDArray[int]]: + # vox2vox should always be a "soft" transform + vox2vox = vox2vox_for_target_orientation(source_affine, "soft " + _target_orientation, image_data.shape).round() + if np.allclose(vox2vox, np.eye(4)): # is already target_orientation + return image_data, do_nothing + else: # is not target_affine yet + from functools import partial + out_shape = np.abs(vox2vox[:3, :3] @ 
np.asarray(image_data.shape)).astype(int) + inverse_apply_vox2vox = partial(apply_vox2vox, vox2vox=np.linalg.inv(vox2vox), out_shape=image_data.shape) + return apply_vox2vox(image_data, vox2vox, out_shape), inverse_apply_vox2vox + + +def vox2vox_for_target_orientation( + source_affine: AffineMatrix4x4, + target_orientation: OrientationType, + shape: npt.ArrayLike, +) -> AffineMatrix4x4: """ - Determine the nibabel `ornt` Array to reorder and flip data from source_affine such that the data is in orientation. + Determine the affine matrix to reorder and flip/interpolate data from source_affine to orientation. + + The resulting transform is a vox2vox from source to target. Parameters ---------- source_affine : npt.NDArray[float] The affine to detect the reorientation operations. - target_orientation : StrictOrientationType + target_orientation : OrientationType The target orientation to reorient to. + shape : array_like + The source shape of the data to reorder. Returns ------- - npt.NDArray[int] - The `ornt` transform from source_affine to target_orientation. - npt.NDArray[int] - The `ornt` transform back from target_orientation to source_affine. + AffineMatrix4x4 + The affine matrix to transform from source_affine to target_orientation. 
""" - from nibabel.orientations import axcodes2ornt, io_orientation, ornt_transform - - source_ornt = io_orientation(source_affine) - target_ornt = axcodes2ornt(target_orientation.upper()) - reorient_ornt = ornt_transform(source_ornt, target_ornt) - unorient_ornt = ornt_transform(target_ornt, source_ornt) - return reorient_ornt.astype(int), unorient_ornt.astype(int) + from nibabel.orientations import axcodes2ornt + from FastSurferCNN.utils.affines import io_orientation + + _target_orientation = target_orientation.lower() + if _target_orientation == "native": + return np.eye(4, dtype=source_affine.dtype) + # use strict affine if soft orientation intended + elif _target_orientation.startswith("soft"): + _target_orientation = _target_orientation[5:] + _source_affine = ornt2affine(io_orientation(source_affine), (0,) * 3) + else: + _source_affine = source_affine + if any(c not in "lrpais" for c in _target_orientation): + raise ValueError(f"Invalid target_orientation: {target_orientation}.") + target_strict_affine = ornt2affine(axcodes2ornt(_target_orientation, ("lr", "pa", "is")), shape) + return np.linalg.inv(_source_affine) @ target_strict_affine -def apply_orientation(arr: _TB | npt.ArrayLike, ornt: npt.NDArray[int]) -> _TB: +def apply_orientation(arr: _TB | npt.ArrayLike, ornt: OrntArrayType) -> _TB: """ Apply transformations implied by `ornt` to the first n axes of the array `arr`. 
@@ -346,31 +368,34 @@ def apply_orientation(arr: _TB | npt.ArrayLike, ornt: npt.NDArray[int]) -> _TB: """ from nibabel.orientations import OrientationError from nibabel.orientations import apply_orientation as _apply_orientation - from torch import is_tensor as _is_tensor - - if _is_tensor(arr): - ornt = np.asarray(ornt) - n = ornt.shape[0] - if arr.ndim < n: - raise OrientationError("Data array has fewer dimensions than orientation") - # apply ornt transformations - flip_dims = np.nonzero(ornt[:, 1] == -1)[0].tolist() - if len(flip_dims) > 0: - arr = arr.flip(flip_dims) - full_transpose = np.arange(arr.ndim) - # ornt indicates the transpose that has occurred - we reverse it - full_transpose[:n] = np.argsort(ornt[:, 0]) - t_arr = arr.permute(*full_transpose) - return t_arr - else: - return _apply_orientation(arr, ornt) + + # only import torch, if it is likely we are dealing with a tensor + if hasattr(arr, "device"): + from torch import is_tensor as _is_tensor + + if _is_tensor(arr): + ornt = np.asarray(ornt) + n = ornt.shape[0] + if arr.ndim < n: + raise OrientationError("Data array has fewer dimensions than orientation") + # apply ornt transformations + flip_dims = np.nonzero(ornt[:, 1] == -1)[0].tolist() + if len(flip_dims) > 0: + arr = arr.flip(flip_dims) + full_transpose = np.arange(arr.ndim) + # ornt indicates the transpose that has occurred - we reverse it + full_transpose[:n] = np.argsort(ornt[:, 0]) + t_arr = arr.permute(*full_transpose) + return t_arr + + return _apply_orientation(arr, ornt) def map_image( img: nibabelImage, - out_affine: npt.NDArray[float], - out_shape: tuple[int, ...] 
| npt.NDArray[int] | Iterable[int], - ras2ras: npt.NDArray[np.number] | None = None, + out_affine: AffineMatrix4x4, + out_shape: np.ndarray[Shape1d, np.dtype[np.integer]] | Iterable[int], + ras2ras: AffineMatrix4x4 | None = None, order: int = 1, dtype: np.dtype[ScalarType] | npt.DTypeLike | None = None, vox_eps: float = 1e-4, @@ -381,13 +406,13 @@ def map_image( Parameters ---------- - img : nibabelImage + img : nibabel.spatialimages.SpatialImage The src 3D image with data and affine set. - out_affine : np.ndarray + out_affine : AffineMatrix4x4 Trg image affine. - out_shape : tuple[int, ...], np.ndarray + out_shape : tuple[int, ...], np.ndarray of int The target shape information. - ras2ras : np.ndarray, optional + ras2ras : AffineMatrix4x4, optional An additional mapping that should be applied (default=id to just reslice). order : int, default=1 Order of interpolation (0=nearest,1=linear,2=quadratic,3=cubic). @@ -407,10 +432,43 @@ def map_image( ras2ras = np.eye(4) # compute vox2vox from src to trg - vox2vox = np.linalg.inv(out_affine) @ ras2ras @ img.affine - + vox2vox = np.linalg.inv(out_affine) @ ras2ras @ get_affine_from_any(img) # here we apply the inverse vox2vox (to pull back the src info to the target image) image_data = np.asarray(img.dataobj, dtype=dtype) + return apply_vox2vox(image_data, vox2vox, out_shape=out_shape, order=order, vox_eps=vox_eps, rot_eps=rot_eps) + + +def apply_vox2vox( + image_data: _TA, + vox2vox: AffineMatrix4x4, + out_shape: np.ndarray[tuple[int], np.dtype[np.integer]] | Iterable[int], + order: int = 1, + vox_eps: float = 1e-4, + rot_eps: float = 1e-6, + ) -> _TA: + """ + Map image to new voxel space (RAS orientation). + + Parameters + ---------- + image_data : np.ndarray + The 3D image data. + vox2vox : np.ndarray + To-apply affine. + out_shape : tuple[int, ...], np.ndarray + The target shape information. + order : int, default=1 + Order of interpolation (0=nearest,1=linear,2=quadratic,3=cubic). 
+ vox_eps : float, default=1e-4 + The epsilon for the voxelsize check. + rot_eps : float, default=1e-6 + The epsilon for the affine rotation check. + + Returns + ------- + np.ndarray + Mapped image data array. + """ # convert frames to single image out_shape = tuple(out_shape) @@ -442,19 +500,21 @@ def map_image( inv_vox2vox = np.linalg.inv(vox2vox) if not does_vox2vox_rot_require_interpolation(vox2vox, vox_eps=vox_eps, rot_eps=rot_eps): - LOGGER.debug(f"vox2vox: {vox2vox}") # second condition: translations are integers if np.allclose(vox2vox[:, 3], np.round(vox2vox[:, 3]), atol=1e-4): # reorder axes - ornt = nib.orientations.io_orientation(vox2vox) + from FastSurferCNN.utils.affines import io_orientation + ornt = io_orientation(vox2vox) + reordered = apply_orientation(image_data, ornt) + new_old_index = list(enumerate(map(int, ornt[:, 0]))) # if the direction is flipped (ornt[j, 1] == -1), offset has to start at the other end - offsets = [ornt[j, 1] * vox2vox[i, 3] + (ornt[j, 1] == -1) * img.shape[j] for i, j in new_old_index] + offsets = [-vox2vox[i, 3] + (ornt[j, 1] == -1) * (image_data.shape[j] - 1) for i, j in new_old_index] offsets = list(map(lambda x: int(x.astype(int)), offsets)) - reordered = apply_orientation(image_data, ornt) # pad=0 => pad with zeros return crop_transform(reordered, offsets=offsets, target_shape=out_shape, pad=0) + # TODO: in contrast to the type annotation, the following is not compatible with torch.Tensor from scipy.ndimage import affine_transform return affine_transform(image_data, inv_vox2vox, output_shape=out_shape, order=order) @@ -628,6 +688,37 @@ def rescale( return data_new +def get_affine_from_any(image: nibabelImage | nibabelHeader) -> AffineMatrix4x4: + """ + Retrieve the affine matrix from an MGH header. 
+ + This function also incorporates the `FIX_MGH_AFFINE_CALCULATION` hardcoded flag, which attempts to fix the nibabel + calculation of the affine matrix for MGH images (which incorrectly assumes Pxyz_c to be at the center of the image). + It is not, Pxyz_c is offset by half a voxel, see + https://surfer.nmr.mgh.harvard.edu/fswiki/CoordinateSystems?action=AttachFile&do=get&target=fscoordinates.pdf . + + Parameters + ========== + image : nibabel.spatialimages.SpatialImage, nibabel.spatialimages.SpatialHeader + The image object or file image header object. + + Returns + ======= + AffineMatrix4x4 + The 4x4 affine transformation matrix for mapping voxel data to world coordinates. + """ + if FIX_MGH_AFFINE_CALCULATION and isinstance(image, (MGHImage, MGHHeader)): + mgh_header = image.header if isinstance(image, MGHImage) else image + # the function header.get_affine() is actually bugged, because it uses dims and not dims-1 for center :/ + MdcD = np.asarray(mgh_header["Mdc"]).T * mgh_header["delta"] + vol_center = MdcD.dot(np.asarray(mgh_header["dims"][:3]) - 1) / 2 + return nib.affines.from_matvec(MdcD, mgh_header["Pxyz_c"] - vol_center) + elif isinstance(image, nib.analyze.SpatialHeader): + return image.get_best_affine() + else: + return image.affine + + def conform( img: nibabelImage, order: int = 1, @@ -636,7 +727,7 @@ def conform( dtype: type | None = np.uint8, orientation: OrientationType | None = "lia", threshold_1mm: float | None = None, - rescale: int | float | Literal["none"] = 255, + rescale: int | float | None = 255, vox_eps: float = 1e-4, rot_eps: float = 1e-6, **kwargs, @@ -648,7 +739,7 @@ def conform( Parameters ---------- - img : nibabelImage + img : nib.spatialimages.SpatialImage Loaded source image. order : int, default=1 Interpolation order (0=nearest, 1=linear, 2=quadratic, 3=cubic). 
@@ -698,12 +789,13 @@ def conform( LOGGER.warning("conform_vox_size is deprecated, replaced by vox_size and will be removed.") vox_size = kwargs["conform_vox_size"] - vox_img = conformed_vox_img_size(img, vox_size, img_size, threshold_1mm=threshold_1mm, vox_eps=vox_eps) + vox_size, img_size = conformed_vox_img_size(img, vox_size, img_size, threshold_1mm=threshold_1mm, vox_eps=vox_eps) _orientation: OrientationType = "native" if orientation is None else orientation - h1 = prepare_mgh_header(img, *vox_img, _orientation, vox_eps=vox_eps, rot_eps=rot_eps) + h1 = prepare_mgh_header(img, vox_size, img_size, _orientation, vox_eps=vox_eps, rot_eps=rot_eps) # affine is the computed target affine for the output image - target_affine = h1.get_affine() + target_affine = get_affine_from_any(h1) + if LOGGER.getEffectiveLevel() <= logging.DEBUG: with np.printoptions(precision=2, suppress=True): from re import sub @@ -729,7 +821,7 @@ def conform( target_affine, h1.get_data_shape(), order=order, - dtype=float, + dtype=np.float64, vox_eps=vox_eps, rot_eps=rot_eps, ) @@ -751,7 +843,8 @@ def conform( # mapped data is still float here, clip to integers now if np.issubdtype(target_dtype, np.integer): mapped_data = np.rint(mapped_data) - new_img = nibabel.MGHImage(mapped_data.astype(target_dtype), target_affine, h1) + # using h1.get_affine() here to keep affine and header consistent within nibabel calculations + new_img = nibabel.MGHImage(mapped_data.astype(target_dtype), get_affine_from_any(h1), h1) # make sure we store uchar from nibabel.freesurfer import mghformat @@ -770,29 +863,67 @@ def conform( return new_img +def ornt2affine( + ornt: OrntArrayType, + shape: npt.ArrayLike | None = None, +) -> AffineMatrix4x4: + """ + Calculate the affine of the orientation transform `ornt` (this is not the target orientation, but operation). + + Parameters + ---------- + ornt : array_like + The orientation to transform by. + shape : array_like, optional + The shape of the (input) data. 
+ + Returns + ------- + AffineMatrix4x4 + The transformation affine, a homogeneous affine if shape is passed. + """ + _ornt = np.asarray(ornt, dtype=int) + # read dim from ornt + if _ornt.shape[1] != 2: + raise ValueError("shape of ornt must be (dim, 2)") + dim = _ornt.shape[0] + homogeneous_affine = shape is not None + aff = np.zeros((dim + int(homogeneous_affine),) * 2, dtype=float) + # reorder, then flip + aff[_ornt[:, 0], np.arange(dim)] = _ornt[:, 1] + if homogeneous_affine: + _center = (np.asarray(shape) - 1) / 2 + if _center.size != dim: + raise ValueError(f"The length of shape needs to be equal ornt.shape[0] ({dim})!") + aff[:, dim] += np.concatenate([_center, [1]]) + origin_out_CS = np.expand_dims(_center[_ornt[:, 0]], 1) + aff[:dim, dim] -= (aff[:dim, :dim] @ origin_out_CS)[:, 0] + return aff + + def prepare_mgh_header( img: nibabelImage, - target_vox_size: npt.NDArray[float] | None = None, - target_img_size: npt.NDArray[int] | None = None, + target_vox_size: Sequence[float] | float | None = None, + target_img_size: Sequence[int] | float | None = None, orientation: OrientationType = "native", vox_eps: float = 1e-4, rot_eps: float = 1e-6, ) -> MGHHeader: """ - Prepare the header with affine by target voxel size, target image size and criteria - initialized from img. + Prepare the header with affine by target voxel size, target image size, and criteria - initialized from img. This implicitly prepares the affine, which can be computed by `return_value.get_affine()`. Parameters ---------- - img : nibabel.analyze.SpatialImage + img : nibabel.spatialimages.SpatialImage The image object to base the header on. - target_vox_size : npt.NDArray[float], None, default=None + target_vox_size : sequence of float, float, None, default=None The target voxel size, importantly still in native orientation (reordering after). 
- target_img_size : npt.NDArray[int], None, default=None + target_img_size : sequence of int, int, None, default=None The target image size, importantly still in native orientation (reordering after). orientation : "native", "soft-", "", default="native" - How the affine should look like. + What the affine should be oriented like. vox_eps : float, default=1e-4 The epsilon for the voxelsize check. rot_eps : float, default=1e-6 @@ -803,45 +934,58 @@ def prepare_mgh_header( nibabel.freesurfer.mghformat.MGHHeader The header object to the "conformed" image based on img and the other parameters. """ + def _validate(param_name, param_type, param_value): + if param_value is None: + return param_value + elif not isinstance(param_value, (Sequence, np.ndarray, param_type)): + raise TypeError(f"{param_name} needs to be sequence, ndarray or float.") + elif isinstance(param_value, param_type): + return (param_value,) * 3 + elif len(param_value) != 3: + raise ValueError(f"{param_name} needs to have length 3.") + else: + return tuple(param_value) + + target_vox_size = _validate("target_vox_size", float, target_vox_size) + target_img_size = _validate("target_img_size", int, target_img_size) + # may copy some parameters if input was MGH format - h1 = MGHHeader.from_header(img.header) + h1: MGHHeader = MGHHeader.from_header(img.header) # nibabel only copies header information, if the file type is the same (here, this would be only of mgh header) source_img_shape = img.header.get_data_shape() source_vox_size = img.header.get_zooms() - source_mdc = img.affine[:3, :3] / np.linalg.norm(img.affine[:3, :3], axis=0, keepdims=True) - # native - if orientation == "native": - re_order_axes = [0, 1, 2] - mdc_affine = np.linalg.inv(source_mdc) - else: - _ornt_transform, _ = orientation_to_ornts(img.affine, orientation[-3:]) - re_order_axes = _ornt_transform[:, 0] - if len(orientation) == 3: # lia, ras, etc - # this is a 3x3 matrix - out_ornt = 
nib.orientations.axcodes2ornt(orientation[-3:].upper()) - mdc_affine = nib.orientations.inv_ornt_aff(out_ornt, source_img_shape)[:3, :3] - else: # soft lia, ras, .... - aff = _ornt_transform[:, 1][None] * source_mdc - mdc_affine = np.stack([aff[:3, int(ax)] for ax in _ornt_transform[:, 0]], axis=-1) + source_affine = get_affine_from_any(img) + _vox2vox = vox2vox_for_target_orientation(source_affine, orientation, (0,) * 3)[:3, :3] + _target_affine = source_affine[:3, :3] @ _vox2vox + h1["Mdc"] = np.linalg.inv(_target_affine / np.linalg.norm(_target_affine, axis=0, keepdims=True)) + _vox2vox_hom = np.eye(4, dtype=_vox2vox.dtype) + _vox2vox_hom[:3, :3] = _vox2vox + re_order_axes = io_orientation(_vox2vox_hom)[:, 0].astype(int).tolist() shape: list[int] = [(source_img_shape if target_img_size is None else target_img_size)[i] for i in re_order_axes] h1.set_data_shape(shape + [1]) # --> h1['delta'] h1.set_zooms([(target_vox_size if target_vox_size is not None else source_vox_size)[i] for i in re_order_axes]) - - h1["Mdc"] = mdc_affine # fov should only be defined, if the image has same fov in all directions? 
fov == one number _fov = np.asarray([i * v for i, v in zip(h1.get_data_shape(), h1.get_zooms(), strict=False)]) if _fov.min() == _fov.max(): # fov is not needed for MGHHeader.get_affine() h1["fov"] = _fov[0] - center = np.asarray(img.shape[:3], dtype=float) / 2.0 - h1["Pxyz_c"] = img.affine.dot(np.hstack((center, [1.0])))[:3] + center = (np.asarray(img.shape[:3], dtype=float) - (1 if FIX_MGH_AFFINE_CALCULATION else 0)) / 2.0 + if FIX_CENTER_NOT_CENTER: + # The center is not actually the center, but rather the position of the voxel at Ni/2 (counting at voxel 0) + # Therefore, the center changes, if we apply a vox2vox + # to get to the true center, move back half a voxel in all directions + true_center = center - 0.5 * np.ones((1, 3)) @ source_affine[:3, :3] + # new image center from true center go half a voxel in all direction of the new affine + center = 0.5 * np.ones((1, 3)) @ get_affine_from_any(h1)[:3, :3] + true_center + h1["Pxyz_c"] = nib.affines.apply_affine(source_affine, center) + # h1["Pxyz_c"] = source_affine.dot(np.hstack((center, [1.0])))[:3] # There is a special case here, where an interpolation is triggered, but it is not necessary, if the position of # the center could "fix this" condition: - vox2vox = np.linalg.inv(h1.get_affine()) @ img.affine + vox2vox = np.linalg.inv(get_affine_from_any(h1)) @ source_affine if does_vox2vox_rot_require_interpolation(vox2vox, vox_eps=vox_eps, rot_eps=rot_eps): # 1. has rotation, or vox-size resampling => requires resampling pass @@ -853,10 +997,10 @@ def prepare_mgh_header( # is it fixable? 
if not np.allclose(vec, np.round(vec), **tols) and np.allclose(vec * 2, np.round(vec * 2), **tols): new_center = (center + (1 - np.isclose(vec, np.round(vec), **tols)) / 2.0, [1.0]) - h1["Pxyz_c"] = img.affine.dot(np.hstack(new_center))[:3] + h1["Pxyz_c"] = source_affine.dot(np.hstack(new_center))[:3] # tr information is not copied when copying from non-mgh formats - if len(img.header.get('pixdim', [])) : + if len(img.header.get('pixdim', [])): h1['tr'] = img.header['pixdim'][4] * 1000 # The affine can be explicitly constructed by MGHHeader.get_affine() / h1.get_affine() @@ -915,7 +1059,7 @@ def is_conform( Parameters ---------- - img : nibabelImage + img : nib.analyze.SpatialImage Loaded source image. vox_size : float, "min", None, default=1.0 Which voxel size to conform to. Can either be a float between 0.0 and 1.0, 'min' (to check, whether the image is @@ -989,15 +1133,17 @@ def is_conform( img_size_criteria = f"Dimensions {img_size}={'x'.join(map(str, _img_size[:3]))}" checks[img_size_criteria] = np.array_equal(np.asarray(img.shape[:3]), _img_size), img_size_text + img_affine = get_affine_from_any(img) + # check orientation LIA - affcode = "".join(nib.orientations.aff2axcodes(img.affine)) + affcode = "".join(aff2axcodes(img_affine)) with np.printoptions(precision=2, suppress=True): - orientation_text = "affine=" + re.sub("\\s+", " ", str(img.affine[:3, :3])) + f" => {affcode}" + orientation_text = "affine=" + re.sub("\\s+", " ", str(img_affine[:3, :3])) + f" => {affcode}" if orientation is None or orientation == "native": checks[f"Orientation {orientation}"] = "IGNORED", orientation_text else: is_soft = not orientation.startswith("soft") - is_correct_orientation = is_orientation(img.affine, orientation[-3:], is_soft, eps) + is_correct_orientation = is_orientation(img_affine, orientation[-3:], is_soft, eps) checks[f"Orientation {orientation.upper()}"] = is_correct_orientation, orientation_text # check dtype uchar @@ -1085,13 +1231,13 @@ def is_orientation( bool 
Whether the affine is LIA-oriented. """ - if "".join(nib.orientations.aff2axcodes(affine, tol=eps)).lower() == target_orientation.lower(): + if "".join(aff2axcodes(affine, tol=eps)).lower() == target_orientation.lower(): if soft: return True else: return False - return does_vox2vox_rot_require_interpolation(affine / np.linalg.norm(affine, axis=0), eps=eps) + return does_vox2vox_rot_require_interpolation(affine / np.linalg.norm(affine, axis=0), rot_eps=eps, vox_eps=eps) def conformed_vox_img_size( @@ -1109,7 +1255,7 @@ def conformed_vox_img_size( Parameters ---------- - img : nibabelImage + img : nib.spatialimages.SpatialImage Loaded source image. vox_size : float, "min", None The voxel size parameter to use: either a voxel size as float, or the string "min" to automatically find a @@ -1232,8 +1378,10 @@ def check_affine_in_nifti( # Exit otherwise vox_size_header = header.get_zooms() + img_affine = get_affine_from_any(img) + # voxel size in xyz direction from the affine - vox_size_affine = np.sqrt((img.affine[:3, :3] * img.affine[:3, :3]).sum(0)) + vox_size_affine = np.sqrt((img_affine[:3, :3] * img_affine[:3, :3]).sum(0)) if not np.allclose(vox_size_affine, vox_size_header, atol=1e-3): message = ( @@ -1241,7 +1389,7 @@ def check_affine_in_nifti( f"ERROR: Invalid Nifti-header! Affine matrix is inconsistent with " f"Voxel sizes. \nVoxel size (from header) vs. 
Voxel size in affine:\n" f"{tuple(vox_size_header[:3])}, {tuple(vox_size_affine)}\n" - f"Input Affine----------------\n{img.affine}\n" + f"Input Affine----------------\n{img_affine}\n" f"#############################################################" ) check = False diff --git a/FastSurferCNN/run_prediction.py b/FastSurferCNN/run_prediction.py index cd79317ca..013d94f3d 100644 --- a/FastSurferCNN/run_prediction.py +++ b/FastSurferCNN/run_prediction.py @@ -39,7 +39,7 @@ import FastSurferCNN.reduce_to_aseg as rta from FastSurferCNN.data_loader import data_utils as du -from FastSurferCNN.data_loader.conform import conform, is_conform, orientation_to_ornts, to_target_orientation +from FastSurferCNN.data_loader.conform import vox2vox_for_target_orientation, conform, is_conform, to_target_orientation from FastSurferCNN.inference import Inference from FastSurferCNN.quick_qc import check_volume from FastSurferCNN.utils import PLANES, Plane, logging, nibabelImage, parser_defaults @@ -388,7 +388,7 @@ def get_prediction( orig_in_lia, back_to_native = to_target_orientation(orig_data, affine, target_orientation="LIA") shape = orig_in_lia.shape + (self.get_num_classes(),) - _ornt_transform, _ = orientation_to_ornts(affine, target_orientation="LIA") + _ornt_transform, _ = vox2vox_for_target_orientation(affine, target_orientation="LIA") _zoom = _zoom[_ornt_transform[:, 0]] pred_prob = torch.zeros(shape, **kwargs) diff --git a/FastSurferCNN/utils/__init__.py b/FastSurferCNN/utils/__init__.py index 0aacf105f..efb00bb50 100644 --- a/FastSurferCNN/utils/__init__.py +++ b/FastSurferCNN/utils/__init__.py @@ -12,6 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from contextlib import contextmanager +from typing import TYPE_CHECKING, Literal, TypeVar + +if TYPE_CHECKING: + # if we are type-checking, we require numpy and nibabel to be installed + from nibabel.analyze import SpatialHeader as nibabelHeader + from nibabel.analyze import SpatialImage as nibabelImage + from numpy import bool_, dtype, float64, ndarray, number +else: + # if we are not type-checking, we try to import numpy and nibabel, but provide fallbacks if they are not installed + + # there are very few cases, when we do not need nibabel in any "full script" so always + # including nibabel does not overly drag down performance + try: + from nibabel.analyze import SpatialHeader as nibabelHeader + from nibabel.analyze import SpatialImage as nibabelImage + # Some scripts like the build script do not require the full FastSurfer environment. This makes sure, this typing + # module is still functional in such cases. + except (ImportError, ModuleNotFoundError): + nibabelImage = None + nibabelHeader = None + try: + from numpy import bool_, dtype, float64, ndarray, number + # Some scripts like the build script do not require the full FastSurfer environment. This makes sure, this typing + # module is still functional in such cases. + except (ImportError, ModuleNotFoundError): + float64 = float + bool_ = bool + # by typing this with tuple, ndarray[...] and dtype [...] 
will still be valid syntax + ndarray = tuple + dtype = tuple + from numbers import Number as number + __all__ = [ "AffineMatrix4x4", "checkpoint", @@ -41,6 +74,7 @@ "PLANES", "RotationMatrix3x3", "ScalarType", + "Shape1d", "Shape2d", "Shape3d", "Shape4d", @@ -49,31 +83,6 @@ "Vector3d", ] -from contextlib import contextmanager -from typing import Literal, TypeVar - -# there are very few cases, when we do not need nibabel in any "full script" so always -# including nibabel does not overly drag down performance -try: - from nibabel.analyze import SpatialHeader as nibabelHeader - from nibabel.analyze import SpatialImage as nibabelImage -# Some scripts like the build script do not require the full FastSurfer environment. This makes sure, this typing -# module is still functional in such cases. -except (ImportError, ModuleNotFoundError): - nibabelImage = None - nibabelHeader = None -try: - from numpy import bool_, dtype, float64, ndarray, number -# Some scripts like the build script do not require the full FastSurfer environment. This makes sure, this typing -# module is still functional in such cases. -except (ImportError, ModuleNotFoundError): - float64 = float - bool_ = bool - # by typing this with tuple, ndarray[...] and dtype [...] 
will still be valid syntax - ndarray = tuple - dtype = tuple - from numbers import Number as number - AffineMatrix4x4 = ndarray[tuple[Literal[4], Literal[4]], dtype[float64]] PlaneAxial = Literal["axial"] PlaneCoronal = Literal["coronal"] @@ -83,6 +92,7 @@ ScalarType = TypeVar("ScalarType", bound=number) Vector2d = ndarray[tuple[Literal[2]], dtype[float64]] Vector3d = ndarray[tuple[Literal[3]], dtype[float64]] +Shape1d = tuple[int] Shape2d = tuple[int, int] Shape3d = tuple[int, int, int] Shape4d = tuple[int, int, int, int] @@ -98,4 +108,4 @@ @contextmanager def noop_context(): """A no-op context manager that does nothing.""" - yield \ No newline at end of file + yield diff --git a/FastSurferCNN/utils/affines.py b/FastSurferCNN/utils/affines.py new file mode 100644 index 000000000..ce4078c90 --- /dev/null +++ b/FastSurferCNN/utils/affines.py @@ -0,0 +1,103 @@ +# Copyright 2026 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Literal + +import numpy as np + +from FastSurferCNN.utils import AffineMatrix4x4, ScalarType + +OrntArrayType = np.ndarray[tuple[int, Literal[2]], np.dtype[ScalarType]] + +# nibabel's implementations of io_orientation and aff2axcodes are not fixed in 5.3.3 yet, so at a minimum for tests, we +# need nibabel>5.3.3 + + +def io_orientation(affine, tol=None): + """Orientation of input axes in terms of output axes for `affine` + + Valid for an affine transformation from ``p`` dimensions to ``q`` + dimensions (``affine.shape == (q + 1, p + 1)``). + + The calculated orientations can be used to transform associated + arrays to best match the output orientations. If ``p`` > ``q``, then + some of the output axes should be considered dropped in this + orientation. + + Parameters + ---------- + affine : (q+1, p+1) ndarray-like + Transformation affine from ``p`` inputs to ``q`` outputs. Usually this + will be a shape (4,4) matrix, transforming 3 inputs to 3 outputs, but + the code also handles the more general case + tol : {None, float}, optional + threshold below which SVD values of the affine are considered zero. If + `tol` is None, and ``S`` is an array with singular values for `affine`, + and ``eps`` is the epsilon value for datatype of ``S``, then `tol` set + to ``S.max() * max((q, p)) * eps`` + + Returns + ------- + orientations : (p, 2) ndarray + one row per input axis, where the first value in each row is the closest + corresponding output axis. The second value in each row is 1 if the + input axis is in the same direction as the corresponding output axis and + -1 if it is in the opposite direction. If a row is [np.nan, np.nan], + which can happen when p > q, then this row should be considered dropped. 
+ """ + affine = np.asarray(affine) + q, p = affine.shape[0] - 1, affine.shape[1] - 1 + # extract the underlying rotation, zoom, shear matrix + RZS = affine[:q, :p] + zooms = np.sqrt(np.sum(RZS * RZS, axis=0)) + # Zooms can be zero, in which case all elements in the column are zero, and + # we can leave them as they are + zooms[zooms == 0] = 1 + RS = RZS / zooms + # Transform below is polar decomposition, returning the closest + # shearless matrix R to RS + P, S, Qs = np.linalg.svd(RS, full_matrices=False) + # Threshold the singular values to determine the rank. + if tol is None: + tol = S.max() * max(RS.shape) * np.finfo(S.dtype).eps + keep = S > tol + R = np.dot(P[:, keep], Qs[keep]) + # the matrix R is such that np.dot(R,R.T) is projection onto the + # columns of P[:,keep] and np.dot(R.T,R) is projection onto the rows + # of Qs[keep]. R (== np.dot(R, np.eye(p))) gives rotation of the + # unit input vectors to output coordinates. Therefore, the row + # index of abs max R[:,N], is the output axis changing most as input + # axis N changes. In case there are ties, we choose the axes + # iteratively, removing used axes from consideration as we go + ornt = np.ones((p, 2), dtype=np.int8) * np.nan + for _ in range(p): + row, col = np.unravel_index(np.argmax(np.abs(R)), R.shape) + max_val = R[row, col] + if not np.isclose(max_val, 0): + ornt[col] = [row, np.sign(max_val)] + R[row, :] = 0 + R[:, col] = 0 + return ornt + + +def aff2axcodes(aff: AffineMatrix4x4, labels: tuple[str, str, str] = ("LR", "PA", "IS"), tol: float = 1e-5) \ + -> OrntArrayType: + """Convert an affine matrix to axis codes. 
+ + See Also + -------- + + """ + from nibabel.orientations import ornt2axcodes + ornt = io_orientation(aff, tol) + return ornt2axcodes(ornt, labels) \ No newline at end of file diff --git a/HypVINN/data_loader/data_utils.py b/HypVINN/data_loader/data_utils.py index 8bda9e0b4..9b71a0500 100644 --- a/HypVINN/data_loader/data_utils.py +++ b/HypVINN/data_loader/data_utils.py @@ -14,11 +14,11 @@ # limitations under the License. # IMPORTS -import nibabel as nib import numpy as np from numpy import typing as npt -from FastSurferCNN.data_loader.conform import getscale, scalecrop +from FastSurferCNN.data_loader.conform import conform, getscale, scalecrop +from FastSurferCNN.utils.affines import aff2axcodes from HypVINN.config.hypvinn_global_var import ( FS_CLASS_NAMES, HYPVINN_CLASS_NAMES, @@ -31,39 +31,6 @@ ## -def calculate_flip_orientation(iornt: np.ndarray, base_ornt: np.ndarray) -> np.ndarray: - """ - Compute the flip orientation transform. - - ornt[N, 1] is flip of axis N, where 1 means no flip and -1 means flip. - - Parameters - ---------- - iornt : np.ndarray - Initial orientation. - base_ornt : np.ndarray - Base orientation. - - Returns - ------- - new_iornt : np.ndarray - New orientation. - """ - new_iornt = iornt.copy() - - # Find the axis to compared and then compared orientation, where 1 means no flip - # and -1 means flip. - for axno, direction in np.asarray(base_ornt): - idx = np.where(iornt[:, 0] == axno) - idirection = iornt[int(idx[0][0]), 1] - if direction == idirection: - new_iornt[int(idx[0][0]), 1] = 1.0 - else: - new_iornt[int(idx[0][0]), 1] = -1.0 - - return new_iornt - - def reorient_img(img, ref_img): """ Reorient a Nibabel image based on the orientation of a reference nibabel image. @@ -80,19 +47,11 @@ def reorient_img(img, ref_img): img : nibabel.Nifti1Image Reoriented image. 
""" - ref_ornt = nib.io_orientation(ref_img.affine) - iornt = nib.io_orientation(img.affine) - - if not np.array_equal(iornt, ref_ornt): - # first flip orientation - fornt = calculate_flip_orientation(iornt, ref_ornt) - img = img.as_reoriented(fornt) - # the transpose axis - tornt = np.ones_like(ref_ornt) - tornt[:, 0] = ref_ornt[:, 0] - img = img.as_reoriented(tornt) - - return img + # if the affines are the same, no reorientation is required and we can skip this + if np.array_equal(ref_img.affine, img.affine): + return img + target_orientation = "soft " + aff2axcodes(ref_img.affine, ("LR", "PA", "IS")) + return conform(img, orientation=target_orientation, vox_size=None, img_size=None, dtype=None, rescale=None) def transform_axial2coronal(vol: np.ndarray, axial2coronal: bool = True) -> np.ndarray: diff --git a/HypVINN/utils/img_processing_utils.py b/HypVINN/utils/img_processing_utils.py index c4800638c..86a6bc42d 100644 --- a/HypVINN/utils/img_processing_utils.py +++ b/HypVINN/utils/img_processing_utils.py @@ -21,30 +21,12 @@ from skimage.measure import label import FastSurferCNN.utils.logging as logging +from FastSurferCNN.utils.affines import aff2axcodes from HypVINN.data_loader.data_utils import hypo_map_subseg_2_fsseg LOGGER = logging.get_logger(__name__) -def img2axcodes(img: nib.Nifti1Image) -> tuple: - """ - Convert the affine matrix of an image to axis codes. - - This function takes an image as input and returns the axis codes corresponding to the affine matrix of the image. - - Parameters - ---------- - img : nibabel image object - The input image. - - Returns - ------- - tuple - The axis codes corresponding to the affine matrix of the image. 
- """ - return nib.aff2axcodes(img.affine) - - def save_segmentation( prediction: np.ndarray, orig_path: Path, @@ -94,23 +76,19 @@ def save_segmentation( # Mapped HypVINN labelst to FreeSurfer Hypvinn Labels pred_arr = hypo_map_subseg_2_fsseg(pred_arr) orig_img = nib.load(orig_path) - LOGGER.info(f"Orig data orientation : {img2axcodes(orig_img)}") + LOGGER.info(f"Orig data orientation : {aff2axcodes(orig_img.affine)}") if save_mask: mask_img = nib.Nifti1Image(labels_cc, affine=ras_affine, header=ras_header) - LOGGER.info(f"HypVINN Mask orientation: {img2axcodes(mask_img)}") + LOGGER.info(f"HypVINN Mask orientation: {aff2axcodes(mask_img.affine)}") mask_img = reorient_img(mask_img, orig_img) - LOGGER.info( - f"HypVINN Mask after re-orientation: {img2axcodes(mask_img)}" - ) + LOGGER.info(f"HypVINN Mask after re-orientation: {aff2axcodes(mask_img.affine)}") nib.save(mask_img, subject_dir / "mri" / mask_file) pred_img = nib.Nifti1Image(pred_arr, affine=ras_affine, header=ras_header) - LOGGER.info(f"HypVINN Prediction orientation: {img2axcodes(pred_img)}") + LOGGER.info(f"HypVINN Prediction orientation: {aff2axcodes(pred_img.affine)}") pred_img = reorient_img(pred_img, orig_img) - LOGGER.info( - f"HypVINN Prediction after re-orientation: {img2axcodes(pred_img)}" - ) + LOGGER.info(f"HypVINN Prediction after re-orientation: {aff2axcodes(pred_img.affine)}") pred_img.set_data_dtype(np.int16) # Maximum value 984 nib.save(pred_img, subject_dir / "mri" / seg_file) return time() - starttime @@ -153,17 +131,15 @@ def save_logits( """ from HypVINN.data_loader.data_utils import reorient_img orig_img = nib.load(orig_path) - LOGGER.info(f"Orig data orientation: {img2axcodes(orig_img)}") + LOGGER.info(f"Orig data orientation: {aff2axcodes(orig_img.affine)}") nifti_img = nib.Nifti1Image( logits.astype(np.float32), affine=ras_affine, header=ras_header, ) - LOGGER.info(f"HypVINN logits orientation: {img2axcodes(nifti_img)}") + LOGGER.info(f"HypVINN logits orientation: 
{aff2axcodes(nifti_img.affine)}") nifti_img = reorient_img(nifti_img, orig_img) - LOGGER.info( - f"HypVINN logits after re-orientation: {img2axcodes(nifti_img)}" - ) + LOGGER.info(f"HypVINN logits after re-orientation: {aff2axcodes(nifti_img.affine)}") nifti_img.set_data_dtype(np.float32) save_as = save_dir / f"HypVINN_logits_{mode}.nii.gz" nib.save(nifti_img, save_as) diff --git a/test/image/conftest.py b/test/image/conftest.py new file mode 100644 index 000000000..d21bffc46 --- /dev/null +++ b/test/image/conftest.py @@ -0,0 +1,48 @@ +import nibabel as nib +import numpy as np +import pytest + +from FastSurferCNN.utils import AffineMatrix4x4 +from FastSurferCNN.utils.affines import aff2axcodes +from FastSurferCNN.utils.arg_types import OrientationType, StrictOrientationType + + +@pytest.fixture(scope="session", params=["soft LIA", "soft ARS"]) +def soft_orientation(request) -> OrientationType: + return request.param + + +@pytest.fixture(scope="session", params=["soft LIA", "soft ARS", "PIL"]) +def orientation(request) -> OrientationType: + return request.param + + +@pytest.fixture(scope="session", params=["LIA", "ARS", "PIL"]) +def strict_orientation(request) -> StrictOrientationType: + return request.param + + +@pytest.fixture(scope="session", params=[0.8, 1.0]) +def vox_size(request) -> float: + return request.param + + +@pytest.fixture(scope="session", params=[8, 15]) # [128, 256]) +def img_size(request) -> float: + return request.param + + +@pytest.fixture(scope="session") +def random_affine(img_size: int, vox_size: float) -> AffineMatrix4x4: + from scipy.spatial.transform import Rotation + affine = np.eye(4, dtype=np.float64) + vec = np.random.randn(3) + rotvec = vec / np.linalg.norm(vec, axis=0) * np.random.rand(1) * np.pi + affine[:3, :3] = Rotation.from_rotvec(rotvec, False).as_matrix() * vox_size + try: + L_axis = aff2axcodes(affine, ("LR", "PA", "IS")).index("L") + affine[:3, L_axis] *= -1 + except ValueError: + pass + affine[:3, 3] = (np.random.rand(3) - 
0.5) * img_size + return affine diff --git a/test/image/test_conform_reorient.py b/test/image/test_conform_reorient.py new file mode 100644 index 000000000..c05026033 --- /dev/null +++ b/test/image/test_conform_reorient.py @@ -0,0 +1,232 @@ +from logging import getLogger +from typing import TypedDict + +import nibabel as nib +import numpy as np +import pytest +from pytest import approx + +from FastSurferCNN.data_loader.conform import OrientationType, conform, prepare_mgh_header +from FastSurferCNN.utils import AffineMatrix4x4, Image3d, nibabelHeader, nibabelImage +from FastSurferCNN.utils.affines import aff2axcodes +from FastSurferCNN.utils.arg_types import StrictOrientationType + +logger = getLogger(__name__) + +class MultiCoordImages(TypedDict): + X: nibabelImage + Y: nibabelImage + Z: nibabelImage + +conform_reorient = {"rescale": None, "dtype": np.float32} + + +def circle_data(img_size: int, radius: float, center: float) -> Image3d[np.float32]: + """Generates a 3D image with a centered sphere of radius img_size/2.""" + data = np.mgrid[0:img_size, 0:img_size, 0:img_size].astype(np.float32) - center + return (np.sum(data * data, axis=0) < radius * radius).astype(np.float32) + + +@pytest.fixture(scope="session") +def radius(img_size: int) -> float: + return img_size / 2.0 - 2.0 + +@pytest.fixture(scope="session") +def center(img_size: int) -> float: + return (img_size - 1) / 2.0 + + +@pytest.fixture(scope="session") +def circle_image(random_affine: AffineMatrix4x4, img_size: int, radius: float, center: float) -> nib.Nifti1Image: + return nib.Nifti1Image(circle_data(img_size, radius, center), random_affine) + + +@pytest.fixture(scope="session", params=[1.0, 0.23]) +def resample_factor(request) -> float: + return request.param + + +@pytest.fixture(scope="session") +def random_image(random_affine: AffineMatrix4x4, img_size: int) -> nib.Nifti1Image: + return nib.Nifti1Image(np.random.randn(img_size, img_size, img_size), random_affine) + + 
+@pytest.fixture(scope="session") +def empty_image(random_affine: AffineMatrix4x4, img_size: int) -> nib.Nifti1Image: + return nib.Nifti1Image(np.empty((img_size,) * 3), random_affine) + + +def affine2orientation(affine: AffineMatrix4x4) -> OrientationType: + """Generates the orientation type string from an affine matrix.""" + orientation: StrictOrientationType = "".join(aff2axcodes(affine, ("LR", "PA", "IS"))) + # make sure the affine is normalized for vox_sizes not 1 + norm_affine = affine[:3, :3] / np.linalg.norm(affine[:3, :3], keepdims=True, axis=0) + if np.allclose(np.sum([np.isclose(np.abs(norm_affine), i) for i in (0, 1.)], axis=0), 1): + return orientation + else: + return "soft " + orientation + + +class HeaderTests: + def test_affine_orientation(self, affine: AffineMatrix4x4, orientation: OrientationType): + """Tests whether a conformed image actually has the correct orientation.""" + actual = affine2orientation(affine) + expected = orientation + assert actual == expected, "The expected orientation did not match the actual orientation." + + def test_affine_vox_size(self, affine: AffineMatrix4x4, vox_size: float): + """Tests whether a conformed image actually has the correct voxel size.""" + actual = np.linalg.norm(affine[:3, :3], axis=0) + expected = vox_size + assert actual == approx(expected), "The actual voxel sizes in the affine did not match the expected." + + def test_vox_size(self, header: nibabelHeader, vox_size: float): + """Tests whether a conformed image actually has the correct voxel size.""" + actual = header.get_zooms() + expected = np.full_like(actual, vox_size) + assert actual == approx(expected), "The actual voxel sizes in the affine did not match the expected." 
+ + +class TestPrepareHeader(HeaderTests): + + @pytest.fixture(scope="class") + def header(self, empty_image: nib.Nifti1Image, orientation: OrientationType, img_size: int, vox_size: float) \ + -> nib.freesurfer.mghformat.MGHHeader: + return prepare_mgh_header(empty_image, [vox_size] * 3, [img_size] * 3, orientation) + + @pytest.fixture(scope="class") + def affine(self, header: nib.freesurfer.mghformat.MGHHeader) -> AffineMatrix4x4: + return header.get_affine() + + +class TestConformAffine(HeaderTests): + + @pytest.fixture(scope="class") + def image(self, empty_image: nib.Nifti1Image, orientation: OrientationType, vox_size: float) -> nib.Nifti1Image: + return conform(empty_image, orientation=orientation, vox_size=vox_size, **conform_reorient) + + @pytest.fixture(scope="class") + def affine(self, image: nib.Nifti1Image) -> AffineMatrix4x4: + return image.affine + + @pytest.fixture(scope="class") + def header(self, image: nib.Nifti1Image) -> nib.Nifti1Header: + return image.header + + +class TestThereAndBack: + + @pytest.fixture(scope="class") + def image( + self, + circle_image: nib.Nifti1Image, + soft_orientation: OrientationType, + img_size: int, + vox_size: float, + resample_factor: float, + ) -> nib.MGHImage: + """ + Conform `circle_image` to `soft_orientation` and back to the original orientation. + """ + there = conform( + circle_image, + orientation=soft_orientation, vox_size=vox_size * resample_factor, img_size=64, order=1, + **conform_reorient, + ) + in_orientation: OrientationType = "soft " + "".join(aff2axcodes(circle_image.affine, ("LR", "PA", "IS"))) + return conform(there, orientation=in_orientation, img_size=img_size, vox_size=vox_size, **conform_reorient) + + def test_affine(self, image: nib.MGHImage, circle_image: nibabelImage) -> None: + """ + Tests whether the affines of the original and the re-oriented images are the same. 
+ """ + expected = circle_image.affine + actual = image.affine + # currently, the translation parts of the affines differ + assert actual == approx(expected, abs=1e-5), "The affines of original and re-reoriented images differ!" + + def test_image( + self, + image: nib.MGHImage, + circle_image: nibabelImage, + radius: float, + center: float, + soft_orientation: AffineMatrix4x4, + ) -> None: + """ + Tests whether the content of the original and the re-oriented images are the same in the "center circle". + """ + + # this has to be filtered by the region of the image that is the same + inner_circle = circle_data(circle_image.shape[0], radius - 1, center) + outer_circle = circle_data(circle_image.shape[0], radius + 1, center) + # the expected image is 0 outside the outer circle, 1 inside the inner circle and ignored (NaN) in between, + # where interpolation can cause differences + expected = np.where(outer_circle - inner_circle, np.nan, circle_image.dataobj) + actual = np.asarray(image.dataobj) + logger.info(f"Difference is {np.max(np.abs(actual - circle_image.get_fdata()))} - {np.nanmax(np.abs(actual - expected))}") + + assert actual == approx(expected, abs=0.3, nan_ok=True), "The data differs from the re-oriented image!" 
+
+class TestReorientWorldCoords:
+
+    @staticmethod
+    def worldcoords_data(affine: AffineMatrix4x4, img_size: int) -> np.ndarray:
+        xi = np.moveaxis(np.mgrid[0:img_size, 0:img_size, 0:img_size], 0, -1)
+        return nib.affines.apply_affine(affine, xi.reshape((-1, 3)).astype(float)).reshape(xi.shape).astype(np.float32)
+
+    @pytest.fixture(scope="class")
+    def worldcoord_images(self, random_affine: AffineMatrix4x4, img_size: int) -> MultiCoordImages:
+        data = self.worldcoords_data(random_affine, img_size)
+        return MultiCoordImages(**{c: nib.Nifti1Image(data[..., i], random_affine) for i, c in enumerate("XYZ")})
+
+    @pytest.fixture(scope="class")
+    def conf_img_size(self) -> int:
+        return 16
+
+    @pytest.fixture(scope="class")
+    def conf_images(
+        self,
+        conf_img_size: int,
+        worldcoord_images: MultiCoordImages,
+        orientation: OrientationType,
+    ) -> MultiCoordImages:
+        return MultiCoordImages(
+            **{k: conform(img, **conform_reorient, img_size=conf_img_size, orientation=orientation) for k, img in worldcoord_images.items()},
+        )
+
+    @pytest.mark.parametrize(argnames=["dim_name"], argvalues=[["X"], ["Y"], ["Z"]])
+    def test_reorient_worldcoords_affine(
+        self,
+        conf_images: MultiCoordImages,
+        orientation: OrientationType,
+        dim_name: str,
+    ):
+        """
+        This test checks, whether the affines of the world coordinate images are consistent with the original affine.
+        """
+        actual = "".join(aff2axcodes(conf_images[dim_name].affine, ("LR", "PA", "IS")))
+        expected = orientation.removeprefix("soft").lstrip("-_ ")
+        assert actual == expected
+
+
+    def test_reorient_worldcoords_image(
+        self,
+        conf_images: MultiCoordImages,
+        soft_orientation: OrientationType,
+        vox_size: float,
+        img_size: int,
+    ):
+        """
+        This test checks, whether the world coordinates are consistent.
+ """ + from pytest import approx, xfail + + xyz = np.stack([img.get_fdata() for img in conf_images], axis=-1) + logger.info("Checking affines of world images:") + + worldcoords = self.worldcoords_data(conf_images[0].affine, conf_images[0].shape[0]) + # center_mask = np.pad(circle_data(conf_images[0] - 2), 1, constant_values=0) + expected = np.where(center_mask[..., None].astype(np.bool_), worldcoords, np.NaN) + + assert xyz == approx(expected, abs=0.1), "The world images " diff --git a/test/image/test_orientation_transform.py b/test/image/test_orientation_transform.py new file mode 100644 index 000000000..c512d034e --- /dev/null +++ b/test/image/test_orientation_transform.py @@ -0,0 +1,151 @@ +from logging import getLogger + +import nibabel as nib +import numpy as np +import pytest +from numpy import typing as npt +from pytest import approx + +from FastSurferCNN.data_loader.conform import ( + vox2vox_for_target_orientation, + does_vox2vox_rot_require_interpolation, + ornt2affine, +) +from FastSurferCNN.utils import AffineMatrix4x4 +from FastSurferCNN.utils.affines import aff2axcodes +from FastSurferCNN.utils.arg_types import OrientationType, StrictOrientationType + +logger = getLogger(__name__) +SQRT1_2 = np.sqrt(0.5) +SQRT3_4 = np.sqrt(0.75) + + +@pytest.mark.parametrize(argnames=["shape"], argvalues=[[(1,) * 2]]) +def test_ornt2affine_valueerror(shape: tuple[int]): + """Test whether ornt2affine raises the correct ValueError""" + with pytest.raises(ValueError, match="length of shape"): + ornt2affine(np.transpose([np.arange(3), np.ones((3,))]), shape=shape) + + +@pytest.mark.parametrize(argnames="shape", argvalues=[None, (1, 1, 1)]) +def test_ornt2affine_shape(shape): + """Test whether ornt2affine returns the correct output shape.""" + actual = ornt2affine(np.transpose([np.arange(3), np.ones((3,))]), shape=shape).shape + expected = (3 + int(shape is not None),) * 2 + assert actual == expected, "The shape of ornt2affine was incorrect!" 
+ + +def test_ornt2affine_axcode(strict_orientation: StrictOrientationType): + """Test whether ornt2affine returns the an affine of the correct axcode.""" + vox2vox = ornt2affine(nib.orientations.axcodes2ornt(strict_orientation, ("LR", "PA", "IS")), (0,) * 3) + actual = "".join(aff2axcodes(vox2vox, ("LR", "PA", "IS"))) + expected = strict_orientation + assert actual == expected, "ornt2affine did not return a vox2vox of the correct orientation." + + +@pytest.mark.parametrize(argnames=["axcode", "translation"], argvalues=[["ALS", [1, 0, 0]], ["PIR", [0, 1, 1]]]) +def test_ornt2affine_translation(axcode, translation, img_size: int): + """Test whether the translation of ornt2affine is correct.""" + ornt = nib.orientations.axcodes2ornt(axcode) + actual = ornt2affine(ornt, shape=(img_size,) * 3)[:3, 3] + expected = np.asarray(translation) * (img_size - 1) + assert actual == approx(expected), "Translation component of the vox2vox from ornt2affine did not match!" + + +@pytest.mark.parametrize( + argnames=["axcode", "translation"], + argvalues=[["ALS", [SQRT3_4, 0.5, 0]], ["PIR", [-0.5, SQRT3_4, 1]], ["PIL", [SQRT3_4 - 0.5, 0.5 + SQRT3_4, 1]]], +) +def test_ornt2affine_translation2(axcode: StrictOrientationType, translation: npt.ArrayLike, img_size: int): + """Test whether the translation of ornt2affine is correct.""" + from scipy.spatial.transform import Rotation + + ornt = nib.orientations.axcodes2ornt(axcode) + img_affine = np.pad(Rotation.from_euler("XYZ", [0, 0, 30], degrees=True).as_matrix(), ((0, 1), (0, 1))) + img_affine[:, 3] = [0, 5, 0, 1] + + vox2vox = ornt2affine(ornt, shape=(img_size,) * 3) + actual = (img_affine @ vox2vox)[:3, 3] + expected = np.asarray(translation) * (img_size - 1) + img_affine[:3, 3] + assert actual == approx(expected), "Translation component of the vox2vox from ornt2affine did not match!" 
+ + +def test_ornt2affine_data(strict_orientation: OrientationType, img_size: int): + """Test whether ornt2affine + apply_vox2vox equals scipy.ndimage.affine_transform.""" + from scipy.ndimage import affine_transform + + shape = (img_size,) * 3 + # not actually using the strict re-orientation + ornt = nib.orientations.axcodes2ornt(strict_orientation, ("LR", "PA", "IS")) + vox2vox = ornt2affine(ornt, shape=shape) + data = np.random.randn(*shape) + expected = nib.orientations.apply_orientation(data, ornt) + # affine_transform applies the inverse of the given transformation, so we need to invert vox2vox here + actual = affine_transform(data, np.linalg.inv(vox2vox)) + assert actual == approx(expected), "affine_transform and apply_affine did not yield the same result!" + + +def test_affine_for_target_orientation(random_affine: AffineMatrix4x4, img_size: int, orientation: OrientationType): + """Test whether affine_for_target_orientation works as expected.""" + vox2vox = vox2vox_for_target_orientation(random_affine, orientation, (img_size,) * 3) + combined = random_affine @ vox2vox + actual = "".join(aff2axcodes(combined, ("lr", "pa", "is"))) + expected = orientation.lower().removeprefix("soft").lstrip("-_ ") + assert actual == expected, f"affine_for_target_orientation did not yield a transformation to {orientation}!" + + +@pytest.mark.parametrize( + argnames=["affine", "axcode"], + argvalues=[ + [np.diag([2, 3, 4, 1]), "RAS"], + [np.diag([-2, 3, -4, 1])[:, [1, 0, 2, 3]], "ALI"], + [np.diag([2, -3, 4, 1])[:, [0, 2, 1, 3]], "RSP"], + [np.diag([-2, -3, -4, 1])[:, [2, 0, 1, 3]], "ILP"], + ], +) +def test_aff2axcodes(affine: AffineMatrix4x4, axcode: StrictOrientationType): + """Test whether aff2axcodes works as expected.""" + actual = "".join(aff2axcodes(affine, ("lr", "pa", "is"))) + expected = axcode.lower() + assert actual == expected, "aff2axcodes did not return the expected axcode!" 
+ + +def test_affine_for_target_orientation2( + random_affine: AffineMatrix4x4, + img_size: int, + strict_orientation: StrictOrientationType, +): + """Test whether affine_for_target_orientation works as expected.""" + vox2vox = vox2vox_for_target_orientation(random_affine, "soft " + strict_orientation, (img_size,) * 3) + is_reorder_flip = not does_vox2vox_rot_require_interpolation(vox2vox) + assert is_reorder_flip, "affine_for_target_orientation did not yield a \"soft\" vox2vox transformation!" + +# +# class TestAffineForTargetOrientation: +# """Tests the function conform.orientation_to_ornts.""" +# +# TwoOrnts = tuple[npt.NDArray[int], npt.NDArray[int]] +# +# @pytest.fixture(scope="class") +# def output(self, random_affine: npt.NDArray[float], strict_orientation: StrictOrientationType) -> TwoOrnts: +# return orientation_to_ornts(random_affine, strict_orientation) +# +# def test_forward_and_back(self, output: TwoOrnts): +# """Test whether the two outputs of orientation_to_ornts are inverse if each other.""" +# actual = nib.orientations.ornt_transform(*output) +# expected = np.stack([np.arange(3), np.ones((3,))], axis=-1) +# assert actual == approx(expected), "The two outputs of orientation_to_ornts are not the inverse of each other" +# +# def test_axcode( +# self, +# output: TwoOrnts, +# random_affine: npt.NDArray[float], +# strict_orientation: StrictOrientationType, +# img_size: int, +# ): +# """Test whether the axcodes of an affine transformed by reorient_affine are correct.""" +# from nibabel.orientations import ornt2axcodes, io_orientation +# reoriented_affine = reorient_affine(random_affine, output[0], (img_size,) * 3) +# actual = "".join(ornt2axcodes(io_orientation(reoriented_affine), ("LR", "PA", "IS"))) +# expected = strict_orientation +# assert actual == expected, "Axcodes of affine after orientations_to_ornts + reorient_affine was not correct."