diff --git a/eks/command_line_args.py b/eks/command_line_args.py index 86bb059..be3e70b 100644 --- a/eks/command_line_args.py +++ b/eks/command_line_args.py @@ -50,11 +50,10 @@ def handle_parse_args(script_type): parser.add_argument( '--s-frames', help='frames to be considered for smoothing ' - 'parameter optimization, first 2k frames by default. Moot if --s is specified. ' - 'Format: "[(start_int, end_int), (start_int, end_int), ... ]" or int. ' - 'Inputting a single int uses all frames from 1 to the int. ' + 'parameter optimization. Moot if --s is specified. ' + 'Format: "[(start_int, end_int), (start_int, end_int), ... ]". ' '(None, end_int) starts from first frame; (start_int, None) proceeds to last frame.', - default=[(None, 10000)], + default=None, type=parse_s_frames, ) parser.add_argument( diff --git a/eks/utils.py b/eks/utils.py index 343e547..618fe83 100644 --- a/eks/utils.py +++ b/eks/utils.py @@ -3,7 +3,6 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd -from jax import numpy as jnp from sleap_io.io.slp import read_labels from typeguard import typechecked @@ -221,34 +220,62 @@ def plot_results( print(f'see example EKS output at {save_file}') -@typechecked -def crop_frames(y: np.ndarray | jnp.ndarray, s_frames: list | tuple) -> np.ndarray | jnp.ndarray: - """ Crops frames as specified by s_frames to be used for auto-tuning s.""" - # Create an empty list to store arrays - result = [] +def crop_frames(y: np.ndarray, + s_frames: list[tuple[int | None, int | None]] | None) -> np.ndarray: + """ + Crop frames from `y` according to `s_frames`. - for frame in s_frames: - # Unpack the frame, setting defaults for empty start or end - start, end = frame - # Default start to 0 if not specified (and adjust for zero indexing) - start = start - 1 if start is not None else 0 - # Default end to the length of ys if not specified - end = end if end is not None else len(y) + Rules: + - Each element is (start, end), where start/end are 0-based, [start, end) half-open. + Use None for open ends (e.g., (None, 100) → frames [0:100), (250, None) → [250:end)). + - s_frames is None or [(None, None)] → return y unchanged. + """ + n = len(y) + + # Case 1: No cropping at all + if s_frames is None or (len(s_frames) == 1 and s_frames[0] == (None, None)): + return y + if len(s_frames) == 0: + return y - # Cap the indices within valid range - start = max(0, start) - end = min(len(y), end) + # Type enforcement + if not isinstance(s_frames, list): + raise TypeError("s_frames must be a list of (start, end) tuples or None.") - # Validate the keys - if start >= end: - raise ValueError(f"Index range ({start + 1}, {end}) " - f"is out of bounds for the list of length {len(y)}.") + spans = [] + for i, frame in enumerate(s_frames): + if not (isinstance(frame, tuple) and len(frame) == 2): + raise ValueError(f"s_frames[{i}] must be a (start, end) tuple, got {frame!r}") - # Use numpy slicing to preserve the data structure - result.append(y[start:end]) + start, end = frame - # Concatenate all slices into a single numpy array - return np.concatenate(result) + if start is not None and not isinstance(start, int): + raise ValueError(f"s_frames[{i}].start must be int or None, got {start!r}") + if end is not None and not isinstance(end, int): + raise ValueError(f"s_frames[{i}].end must be int or None, got {end!r}") + + start_idx = 0 if start is None else start + end_idx = n if end is None else end + + if start_idx < 0 or end_idx > n: + raise ValueError(f"Range ({start_idx}, {end_idx}) out of bounds for length {n}.") + if start_idx >= end_idx: + raise ValueError(f"Invalid range ({start_idx}, {end_idx}).") + + spans.append((start_idx, end_idx)) + + # Ensure ascending, non-overlapping order + spans.sort(key=lambda s: s[0]) + for i in range(1, len(spans)): + if spans[i][0] < spans[i - 1][1]: + raise ValueError( + f"Overlapping or out-of-order intervals: {spans[i - 1]} and {spans[i]}") + + # Perform crop + if len(spans) == 1: + s, e = spans[0] + return y[s:e] + return np.concatenate([y[s:e] for s, e in spans], axis=0) @typechecked() diff --git a/tests/test_multicam_smoother.py b/tests/test_multicam_smoother.py index c36f0c2..e70d000 100644 --- a/tests/test_multicam_smoother.py +++ b/tests/test_multicam_smoother.py @@ -156,7 +156,7 @@ def test_ensemble_kalman_smoother_multicam_no_smooth_param(): camera_names = ['cam1', 'cam2'] quantile_keep_pca = 90 - s_frames = [(0, 100)] + s_frames = None # Run the smoother without providing smooth_param camera_dfs, smooth_params_final = ensemble_kalman_smoother_multicam( @@ -191,7 +191,7 @@ def test_ensemble_kalman_smoother_multicam_n_latent(): camera_names = ['cam1', 'cam2'] quantile_keep_pca = 90 - s_frames = [(0, 10)] + s_frames = None for n_latent in [2, 3, 5]: # Test different PCA dimensions camera_dfs, _ = ensemble_kalman_smoother_multicam( diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..45fb525 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,62 @@ +import numpy as np +import pytest + +from eks.utils import crop_frames + + +def test_crop_frames_no_crop_none(): + """If s_frames is None, return y unchanged.""" + y = np.arange(20) + out = crop_frames(y, None) + assert np.shares_memory(out, y) or np.array_equal(out, y) + assert out.shape == y.shape + + +def test_crop_frames_no_crop_none_none(): + """If s_frames == [(None, None)], return y unchanged.""" + y = np.arange(20) + out = crop_frames(y, [(None, None)]) + assert np.shares_memory(out, y) or np.array_equal(out, y) + assert out.shape == y.shape + + +def test_crop_frames_single_span(): + """Basic single-span crop with 1-based (inclusive) bounds.""" + y = np.arange(10) # [0..9] + # (2,5) -> indices 2,3,4 + out = crop_frames(y, [(2, 5)]) + np.testing.assert_array_equal(out, np.array([2, 3, 4])) + + +def test_crop_frames_open_ended_spans(): + """Open-ended spans using None for start or end.""" + y = np.arange(10) + # (None,3) -> [0,1,2] + # (7,None) -> [7,8,9] + out = crop_frames(y, [(None, 3), (7, None)]) + np.testing.assert_array_equal(out, np.array([0, 1, 2, 7, 8, 9])) + + +def test_crop_frames_invalid_tuple_shape(): + """Each element must be a 2-tuple (start, end).""" + y = np.arange(10) + with pytest.raises(ValueError): + crop_frames(y, [(1, 3, 5)]) # 3-tuple is invalid + + +def test_crop_frames_out_of_bounds(): + """Out-of-bounds ranges raise ValueError.""" + y = np.arange(10) + with pytest.raises(ValueError): + crop_frames(y, [(1, 20)]) + + # start beyond end (invalid after conversion) + with pytest.raises(ValueError): + crop_frames(y, [(6, 5)]) + + +def test_crop_frames_overlap_raises(): + """Overlapping intervals are rejected.""" + y = np.arange(20) + with pytest.raises(ValueError): + crop_frames(y, [(2, 6), (5, 10)])