diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 99bbc905..59aafda3 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -9,6 +9,7 @@ - The factory method `cedalion.dot.get_standard_headmodel` to construct the `TwoSurfaceHeadModel` of the standard Colin27 and ICBM-152 heads was added, by [Eike Middell](https://github.com/emiddell). - Added `cedalion.xrutils.dot_dataarray_csr` for matrix products between `xr.DataArray` and `scipy.sparse` arrays, by [Eike Middell](https://github.com/emiddell). +- Added `cedalion.geometry.landmarks.normalize_landmarks_labels` to map alternative landmark names (e.g., "nasion", "left ear", "nz") to their canonical 10-10 system labels (e.g. Nz, LPA). The function handles now case-insensitive matching and supports common naming conventions. Usage: `geo3d = normalize_landmarks_labels(geo3d)` before calling registration or plotting functions, by [Mohammad Orabe](https://github.com/orabe). ([#84](https://github.com/ibs-lab/cedalion/issues/84)) ### Changed - The package `cedalion.sigproc.motion_correct` was renamed to `cedalion.sigproc.motion`. - The ICA-EBM and ICA_ERBM implementations were moved into `cedalion.sigdecomp.unimodal`. @@ -35,6 +36,7 @@ access example datasets are now available under `cedalion.data`. coordinate system, by [Nils Harmening](https://github.com/harmening). ([#110](https://github.com/ibs-lab/cedalion/pull/110)) - Changed the names of several motion correction algorithms from `motion_correct.motion_correct_X` to `motion_correct.X`. Argument names were made PEP8 compliant. The example `22_motion_artefacts_and_correction` was improved. By [Eike Middell](https://github.com/emiddell). +- The function `cedalion.vis.anatomy.plot_montage3D` now accepts a `landmarks` parameter to specify which landmarks should be highlighted. Pass `None` (default) to show all available canonical registration landmarks (e.g. Nz, Iz, LPA, RPA, Cz), a list of landmark names to show specific ones, or an empty list to show none, by [Mohammad Orabe](https://github.com/orabe). ([#84](https://github.com/ibs-lab/cedalion/issues/84)) ### Deprecated diff --git a/src/cedalion/geometry/landmarks.py b/src/cedalion/geometry/landmarks.py index ef9ea83c..053c0d5d 100644 --- a/src/cedalion/geometry/landmarks.py +++ b/src/cedalion/geometry/landmarks.py @@ -20,6 +20,70 @@ from cedalion.typing import LabeledPoints +def normalize_landmarks_labels(geo3d: LabeledPoints) -> LabeledPoints: + """Normalize landmark labels to canonical names. + + Maps commonly used alternative landmark names to canonical names: + - NASION, Nasion, nasion, nas, Nas, NAS -> Nz + - INION, Inion, inion, ini, Ini, INI -> Iz + - LPA_L, lpa, left ear, Left Ear, LEFT EAR, LE, left, Left, L -> LPA + - RPA_R, rpa, right ear, Right Ear, RIGHT EAR, RE, right, Right, R -> RPA + - CZ, cz, vertex, Vertex, VERTEX -> Cz + + When multiple labels normalize to the same canonical name: + - If the canonical name already exists, alternative forms are dropped + - If multiple alternatives exist without the canonical form, only the + first one is renamed and others are dropped + + Args: + geo3d: LabeledPoints with potentially non-canonical landmark names. + + Returns: + LabeledPoints with normalized landmark labels (duplicates removed). + """ + if len(geo3d.label) == 0: + return geo3d + + label_mapping = { + "Nz": {"NASION", "Nasion", "nasion", "nas", "Nas", "NAS"}, + "Iz": {"INION", "Inion", "inion", "ini", "Ini", "INI"}, + "LPA": {"LPA_L", "lpa", "left ear", "Left Ear", "LEFT EAR", "LE", "left", "Left", "L"}, + "RPA": {"RPA_R", "rpa", "right ear", "Right Ear", "RIGHT EAR", "RE", "right", "Right", "R"}, + "Cz": {"CZ", "cz", "vertex", "Vertex", "VERTEX"}, + } + + existing_labels = set(geo3d.label.values) + + labels_to_rename = {} # {alternative_label: canonical_label} + labels_to_drop = [] + + for canonical, alternatives in label_mapping.items(): + # Find which alternative forms are present in the data + present_alternatives = [lbl for lbl in existing_labels if lbl in alternatives] + + if len(present_alternatives) > 0: + if canonical in existing_labels: + # drop all alternative forms if canonical already exists + labels_to_drop.extend(present_alternatives) + else: + # If no canonical form: rename first alternative, drop rest + labels_to_rename[present_alternatives[0]] = canonical + if len(present_alternatives) > 1: + labels_to_drop.extend(present_alternatives[1:]) + + # Apply transformations: first drop duplicates, then rename + if labels_to_drop: + keep_labels = [label for label in geo3d.label.values + if label not in labels_to_drop] + geo3d = geo3d.sel(label=keep_labels) + + # Rename alternative labels to canonical names + if labels_to_rename: + geo3d = geo3d.points.rename(labels_to_rename) + + return geo3d + + def _sort_line_points(start_point: np.ndarray, points: np.ndarray): sorted_indices = [] sorted_distances = [] diff --git a/src/cedalion/vis/anatomy/montage.py b/src/cedalion/vis/anatomy/montage.py index 9122d7d8..b64371b0 100644 --- a/src/cedalion/vis/anatomy/montage.py +++ b/src/cedalion/vis/anatomy/montage.py @@ -1,13 +1,22 @@ import cedalion.typing as cdt +from cedalion.dataclasses.geometry import PointType import matplotlib.pyplot as p -def plot_montage3D(amp: cdt.NDTimeSeries, geo3d: cdt.LabeledPoints): +def plot_montage3D( + amp: cdt.NDTimeSeries, + geo3d: cdt.LabeledPoints, + landmarks: list[str] | None = None +): """Plots a 3D visualization of a montage. Args: amp: Time series data array. geo3d: Landmark coordinates. + landmarks: Landmarks to highlight in the plot. Can be: + - None (default): Shows canonical registration landmarks (Nz, Iz, LPA, RPA, Cz) + - list of str: Shows specific landmarks (only if they exist in geo3d) + - []: Empty list shows no landmarks """ geo3d = geo3d.pint.dequantify() @@ -15,19 +24,45 @@ def plot_montage3D(amp: cdt.NDTimeSeries, geo3d: cdt.LabeledPoints): ax = f.add_subplot(projection="3d") colors = ["r", "b", "gray"] sizes = [20, 20, 2] - for i, (type, x) in enumerate(geo3d.groupby("type")): - ax.scatter(x[:, 0], x[:, 1], x[:, 2], c=colors[i], s=sizes[i]) + for i, (point_type, x) in enumerate(geo3d.groupby("type")): + if len(x) > 0: + ax.scatter(x[:, 0], x[:, 1], x[:, 2], c=colors[i], s=sizes[i]) + # Draw lines connecting sources to detectors for each channel for i in range(amp.sizes["channel"]): src = geo3d.loc[amp.source[i], :] det = geo3d.loc[amp.detector[i], :] ax.plot([src[0], det[0]], [src[1], det[1]], [src[2], det[2]], c="k") - # if available mark Nasion in yellow - if "Nz" in geo3d.label: + # Determine which landmarks to highlight + if landmarks is None: + # Default: show canonical registration landmarks + canonical_landmarks = ["Nz", "Iz", "LPA", "RPA", "Cz"] + landmarks_to_plot = [ + label for label in canonical_landmarks + if label in geo3d.label.values + ] + else: + # Show specified landmarks (filter non-existent ones) + landmarks_to_plot = [ + label for label in landmarks + if label in geo3d.label.values + ] + + landmark_colors = ["y", "m", "c", "orange", "lime", "pink", "brown", "purple"] + for idx, label in enumerate(landmarks_to_plot): + color = landmark_colors[idx % len(landmark_colors)] ax.scatter( - geo3d.loc["Nz", 0], geo3d.loc["Nz", 1], geo3d.loc["Nz", 2], c="y", s=25 + geo3d.loc[label, 0], + geo3d.loc[label, 1], + geo3d.loc[label, 2], + c=color, + s=50, + label=label ) + + if landmarks_to_plot: + ax.legend(bbox_to_anchor=(0, 0.5), loc='center right') ax.view_init(elev=30, azim=145) p.tight_layout() diff --git a/tests/test_landmarks.py b/tests/test_landmarks.py new file mode 100644 index 00000000..bdf826bd --- /dev/null +++ b/tests/test_landmarks.py @@ -0,0 +1,256 @@ +"""Tests for landmark normalization and visualization functions.""" + +import numpy as np +import pytest +import xarray as xr + +import cedalion.dataclasses as cdc +from cedalion.geometry.landmarks import normalize_landmarks_labels +from cedalion.vis.anatomy.montage import plot_montage3D + + +def test_normalize_all_alternatives(): + """Test normalization of all alternative landmark names.""" + labels = ["NASION", "INION", "lpa", "rpa", "CZ", "left ear", "vertex", "S1"] + coords = np.random.rand(len(labels), 3) * 100 + types = [cdc.PointType.LANDMARK] * 7 + [cdc.PointType.SOURCE] + + geo3d = cdc.build_labeled_points( + coords, crs="unknown", units="mm", labels=labels, types=types + ) + + result = normalize_landmarks_labels(geo3d) + + expected = {"Nz", "Iz", "LPA", "RPA", "Cz", "S1"} + assert set(result.label.values) == expected + + +def test_normalize_duplicate(): + """Test critical duplicate handling: multiple alternatives + canonical present.""" + labels = ["NASION", "Nasion", "Nz", "INION", "Iz", "S1"] + coords = np.random.rand(len(labels), 3) * 100 + types = [cdc.PointType.LANDMARK] * 5 + [cdc.PointType.SOURCE] + + geo3d = cdc.build_labeled_points( + coords, crs="unknown", units="mm", labels=labels, types=types + ) + + result = normalize_landmarks_labels(geo3d) + + # Should keep canonicals, drop alternatives, no duplicates + assert list(result.label.values).count("Nz") == 1 + assert list(result.label.values).count("Iz") == 1 + assert len(result.label) == 3 # Nz, Iz, S1 + + +def test_normalize_preserve_unknown_and_canonical(): + """Test unknown labels and already-canonical labels are preserved.""" + labels = ["Nz", "UnknownLandmark", "S1"] + coords = np.random.rand(len(labels), 3) * 100 + types = [cdc.PointType.LANDMARK] * 2 + [cdc.PointType.SOURCE] + + geo3d = cdc.build_labeled_points( + coords, crs="unknown", units="mm", labels=labels, types=types + ) + + result = normalize_landmarks_labels(geo3d) + + assert set(result.label.values) == {"Nz", "UnknownLandmark", "S1"} + + +def test_normalize_empty_data(): + """Test handling of empty input.""" + geo3d = cdc.build_labeled_points( + np.empty((0, 3)), crs="unknown", units="mm", labels=[], types=[] + ) + + result = normalize_landmarks_labels(geo3d) + + assert len(result.label) == 0 + + +@pytest.fixture +def sample_geo3d(): + """Create sample geo3d with all canonical landmarks and optodes.""" + labels = ["Nz", "Iz", "LPA", "RPA", "Cz", "S1", "D1"] + coords = np.array([ + [0, 0, 100], # Nz + [0, 0, -50], # Iz + [-50, 0, 25], # LPA + [50, 0, 25], # RPA + [0, 50, 50], # Cz + [-30, 40, 80], # S1 + [-20, 40, 80], # D1 + ]) + types = [ + cdc.PointType.LANDMARK, + cdc.PointType.LANDMARK, + cdc.PointType.LANDMARK, + cdc.PointType.LANDMARK, + cdc.PointType.LANDMARK, + cdc.PointType.SOURCE, + cdc.PointType.DETECTOR, + ] + + return cdc.build_labeled_points( + coords, crs="unknown", units="mm", labels=labels, types=types + ) + + +@pytest.fixture +def sample_amp(): + """Create sample amplitude data.""" + return xr.DataArray( + np.random.rand(10, 1), + dims=["time", "channel"], + coords={ + "time": np.arange(10), + "source": ("channel", ["S1"]), + "detector": ("channel", ["D1"]), + }, + ) + + +def test_plot_landmarks_modes(sample_amp, sample_geo3d): + """Test landmark modes: None (default), list, and empty list.""" + # None should show all canonical landmarks present (default) + plot_montage3D(sample_amp, sample_geo3d, landmarks=None) + # List should show specified landmarks + plot_montage3D(sample_amp, sample_geo3d, landmarks=["Nz", "Iz"]) + # Empty list should show no landmarks + plot_montage3D(sample_amp, sample_geo3d, landmarks=[]) + + +def test_plot_default_shows_all_canonical_landmarks(sample_amp, sample_geo3d): + """Test that default (None) shows all 5 canonical landmarks if present.""" + # sample_geo3d has all 5 canonical landmarks + plot_montage3D(sample_amp, sample_geo3d) + # Should not raise and would show Nz, Iz, LPA, RPA, Cz + + +def test_plot_default_with_partial_canonical_landmarks(sample_amp): + """Test default shows only canonical landmarks that are present.""" + # Create geo3d with only some canonical landmarks + labels = ["Nz", "LPA", "S1", "D1", "UnknownLandmark"] + coords = np.array([ + [0, 0, 100], # Nz + [-50, 0, 25], # LPA + [-30, 40, 80], # S1 + [-20, 40, 80], # D1 + [10, 10, 10], # UnknownLandmark + ]) + types = [ + cdc.PointType.LANDMARK, + cdc.PointType.LANDMARK, + cdc.PointType.SOURCE, + cdc.PointType.DETECTOR, + cdc.PointType.LANDMARK, + ] + + geo3d = cdc.build_labeled_points( + coords, crs="unknown", units="mm", labels=labels, types=types + ) + + # Should show only Nz and LPA (canonical landmarks present) + # UnknownLandmark should not be shown + plot_montage3D(sample_amp, geo3d) + + +def test_plot_empty_list_shows_no_landmarks(sample_amp, sample_geo3d): + """Test that empty list explicitly shows no landmarks.""" + # Should not raise and show no landmarks + plot_montage3D(sample_amp, sample_geo3d, landmarks=[]) + + +def test_plot_custom_list_shows_only_specified(sample_amp, sample_geo3d): + """Test custom list shows only specified landmarks.""" + # Should show only Nz and Cz, not the others + plot_montage3D(sample_amp, sample_geo3d, landmarks=["Nz", "Cz"]) + + +def test_plot_nonexistent_landmarks_filtered(sample_amp, sample_geo3d): + """Test that non-existent landmarks are filtered out silently.""" + # Should show only Nz and RPA (NonExistent filtered out) + plot_montage3D( + sample_amp, sample_geo3d, landmarks=["Nz", "NonExistent", "RPA"] + ) + + +def test_normalize_then_plot(): + """Test full workflow: normalize alternative names then plot.""" + labels = ["NASION", "INION", "lpa", "rpa", "CZ", "S1", "D1"] + coords = np.array([ + [0, 0, 100], # NASION -> Nz + [0, 0, -50], # INION -> Iz + [-50, 0, 25], # lpa -> LPA + [50, 0, 25], # rpa -> RPA + [0, 50, 50], # CZ -> Cz + [-30, 40, 80], # S1 + [-20, 40, 80], # D1 + ]) + types = [ + cdc.PointType.LANDMARK, + cdc.PointType.LANDMARK, + cdc.PointType.LANDMARK, + cdc.PointType.LANDMARK, + cdc.PointType.LANDMARK, + cdc.PointType.SOURCE, + cdc.PointType.DETECTOR, + ] + + geo3d = cdc.build_labeled_points( + coords, crs="unknown", units="mm", labels=labels, types=types + ) + + amp = xr.DataArray( + np.random.rand(10, 1), + dims=["time", "channel"], + coords={ + "time": np.arange(10), + "source": ("channel", ["S1"]), + "detector": ("channel", ["D1"]), + }, + ) + + # Normalize labels + normalized_geo3d = normalize_landmarks_labels(geo3d) + + # Verify normalization + assert "Nz" in normalized_geo3d.label.values + assert "Iz" in normalized_geo3d.label.values + assert "LPA" in normalized_geo3d.label.values + assert "RPA" in normalized_geo3d.label.values + assert "Cz" in normalized_geo3d.label.values + assert "NASION" not in normalized_geo3d.label.values + + # Plot with default (shows all canonical landmarks) + plot_montage3D(amp, normalized_geo3d) + + # Plot with custom list + plot_montage3D(amp, normalized_geo3d, landmarks=["Nz", "Cz"]) + + # Plot with empty list + plot_montage3D(amp, normalized_geo3d, landmarks=[]) + + +def test_plot_no_canonical_landmarks_present(sample_amp): + """Test default behavior when no canonical landmarks are present.""" + # Create geo3d with only non-canonical landmarks + labels = ["S1", "D1", "CustomLandmark"] + coords = np.array([ + [-30, 40, 80], # S1 + [-20, 40, 80], # D1 + [10, 10, 10], # CustomLandmark + ]) + types = [ + cdc.PointType.SOURCE, + cdc.PointType.DETECTOR, + cdc.PointType.LANDMARK, + ] + + geo3d = cdc.build_labeled_points( + coords, crs="unknown", units="mm", labels=labels, types=types + ) + + # Should not raise, just show no landmarks + plot_montage3D(sample_amp, geo3d)