From c91728fa988488eb269927dff31a506c00ebe0b3 Mon Sep 17 00:00:00 2001 From: Keemin Lee Date: Mon, 10 Nov 2025 18:25:46 -0500 Subject: [PATCH 1/4] fixed crop_frames for robustness --- eks/command_line_args.py | 7 ++-- eks/utils.py | 74 +++++++++++++++++++++++++++------------- tests/test_utils.py | 64 ++++++++++++++++++++++++++++++++++ 3 files changed, 117 insertions(+), 28 deletions(-) create mode 100644 tests/test_utils.py diff --git a/eks/command_line_args.py b/eks/command_line_args.py index 86bb059..be3e70b 100644 --- a/eks/command_line_args.py +++ b/eks/command_line_args.py @@ -50,11 +50,10 @@ def handle_parse_args(script_type): parser.add_argument( '--s-frames', help='frames to be considered for smoothing ' - 'parameter optimization, first 2k frames by default. Moot if --s is specified. ' - 'Format: "[(start_int, end_int), (start_int, end_int), ... ]" or int. ' - 'Inputting a single int uses all frames from 1 to the int. ' + 'parameter optimization. Moot if --s is specified. ' + 'Format: "[(start_int, end_int), (start_int, end_int), ... ]". ' '(None, end_int) starts from first frame; (start_int, None) proceeds to last frame.', - default=[(None, 10000)], + default=None, type=parse_s_frames, ) parser.add_argument( diff --git a/eks/utils.py b/eks/utils.py index 343e547..0297f91 100644 --- a/eks/utils.py +++ b/eks/utils.py @@ -3,7 +3,6 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd -from jax import numpy as jnp from sleap_io.io.slp import read_labels from typeguard import typechecked @@ -221,34 +220,61 @@ def plot_results( print(f'see example EKS output at {save_file}') -@typechecked -def crop_frames(y: np.ndarray | jnp.ndarray, s_frames: list | tuple) -> np.ndarray | jnp.ndarray: - """ Crops frames as specified by s_frames to be used for auto-tuning s.""" - # Create an empty list to store arrays - result = [] +def crop_frames(y: np.ndarray, + s_frames: list[tuple[int | None, int | None]] | None) -> np.ndarray: + """ + Crop frames from `y` according to `s_frames`. - for frame in s_frames: - # Unpack the frame, setting defaults for empty start or end - start, end = frame - # Default start to 0 if not specified (and adjust for zero indexing) - start = start - 1 if start is not None else 0 - # Default end to the length of ys if not specified - end = end if end is not None else len(y) + Rules (1-based, inclusive user spans): + - Each element is (start, end), where start/end are 1-based, inclusive. + Use None for open ends (e.g., (None, 100) → frames [0:100), (250, None) → [249:end)). + - s_frames is None or [(None, None)] → return y unchanged. + """ + n = len(y) + + # Case 1: No cropping at all + if s_frames is None or (len(s_frames) == 1 and s_frames[0] == (None, None)): + return y - # Cap the indices within valid range - start = max(0, start) - end = min(len(y), end) + # Type enforcement + if not isinstance(s_frames, list): + raise TypeError("s_frames must be a list of (start, end) tuples or None.") - # Validate the keys - if start >= end: - raise ValueError(f"Index range ({start + 1}, {end}) " - f"is out of bounds for the list of length {len(y)}.") + spans = [] + for i, frame in enumerate(s_frames): + if not (isinstance(frame, tuple) and len(frame) == 2): + raise ValueError(f"s_frames[{i}] must be a (start, end) tuple, got {frame!r}") - # Use numpy slicing to preserve the data structure - result.append(y[start:end]) + start, end = frame - # Concatenate all slices into a single numpy array - return np.concatenate(result) + if start is not None and not isinstance(start, int): + raise ValueError(f"s_frames[{i}].start must be int or None, got {start!r}") + if end is not None and not isinstance(end, int): + raise ValueError(f"s_frames[{i}].end must be int or None, got {end!r}") + + # Convert 1-based inclusive to 0-based half-open + start_idx = 0 if start is None else start - 1 + end_idx = n if end is None else end + + if start_idx < 0 or end_idx > n: + raise ValueError(f"Range ({start_idx + 1}, {end_idx}) out of bounds for length {n}.") + if start_idx >= end_idx: + raise ValueError(f"Invalid range ({start_idx + 1}, {end_idx}).") + + spans.append((start_idx, end_idx)) + + # Ensure ascending, non-overlapping order + spans.sort(key=lambda s: s[0]) + for i in range(1, len(spans)): + if spans[i][0] < spans[i - 1][1]: + raise ValueError( + f"Overlapping or out-of-order intervals: {spans[i - 1]} and {spans[i]}") + + # Perform crop + if len(spans) == 1: + s, e = spans[0] + return y[s:e] + return np.concatenate([y[s:e] for s, e in spans], axis=0) @typechecked() diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..2039406 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,64 @@ +import numpy as np +import pytest + +from eks.utils import crop_frames + + +def test_crop_frames_no_crop_none(): + """If s_frames is None, return y unchanged.""" + y = np.arange(20) + out = crop_frames(y, None) + assert np.shares_memory(out, y) or np.array_equal(out, y) + assert out.shape == y.shape + + +def test_crop_frames_no_crop_none_none(): + """If s_frames == [(None, None)], return y unchanged.""" + y = np.arange(20) + out = crop_frames(y, [(None, None)]) + assert np.shares_memory(out, y) or np.array_equal(out, y) + assert out.shape == y.shape + + +def test_crop_frames_single_span(): + """Basic single-span crop with 1-based (inclusive) bounds.""" + y = np.arange(10) # [0..9] + # (start=2, end=5) → 1-based inclusive => [1..5) in 0-based => indices 1..4 => [1,2,3,4] + out = crop_frames(y, [(2, 5)]) + np.testing.assert_array_equal(out, np.array([1, 2, 3, 4])) + + +def test_crop_frames_open_ended_spans(): + """Open-ended spans using None for start or end.""" + y = np.arange(10) # [0..9] + # (None, 3) -> [0:3) => [0,1,2] + # (7, None) -> [6:end) => [6,7,8,9] + out = crop_frames(y, [(None, 3), (7, None)]) + np.testing.assert_array_equal(out, np.array([0, 1, 2, 6, 7, 8, 9])) + + +def test_crop_frames_invalid_tuple_shape(): + """Each element must be a 2-tuple (start, end).""" + y = np.arange(10) + with pytest.raises(ValueError): + crop_frames(y, [(1, 3, 5)]) # 3-tuple is invalid + + +def test_crop_frames_out_of_bounds(): + """Out-of-bounds ranges raise ValueError.""" + y = np.arange(10) + # end too large (1-based end=20 -> 0-based end_idx=20 > n) + with pytest.raises(ValueError): + crop_frames(y, [(1, 20)]) + + # start beyond end (invalid after conversion) + with pytest.raises(ValueError): + crop_frames(y, [(6, 5)]) + + +def test_crop_frames_overlap_raises(): + """Overlapping intervals are rejected.""" + y = np.arange(20) + # Overlap: (2, 6) -> [1:6), (5, 10) -> [4:10) overlap on indices 4..5 + with pytest.raises(ValueError): + crop_frames(y, [(2, 6), (5, 10)]) From 04390268d8e8de8edf3ab411cc3e7c49e29456f7 Mon Sep 17 00:00:00 2001 From: Keemin Lee Date: Mon, 10 Nov 2025 18:33:27 -0500 Subject: [PATCH 2/4] fixed pytest compat with s-frames changes --- tests/test_multicam_smoother.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_multicam_smoother.py b/tests/test_multicam_smoother.py index c36f0c2..e70d000 100644 --- a/tests/test_multicam_smoother.py +++ b/tests/test_multicam_smoother.py @@ -156,7 +156,7 @@ def test_ensemble_kalman_smoother_multicam_no_smooth_param(): camera_names = ['cam1', 'cam2'] quantile_keep_pca = 90 - s_frames = [(0, 100)] + s_frames = None # Run the smoother without providing smooth_param camera_dfs, smooth_params_final = ensemble_kalman_smoother_multicam( @@ -191,7 +191,7 @@ def test_ensemble_kalman_smoother_multicam_n_latent(): camera_names = ['cam1', 'cam2'] quantile_keep_pca = 90 - s_frames = [(0, 10)] + s_frames = None for n_latent in [2, 3, 5]: # Test different PCA dimensions camera_dfs, _ = ensemble_kalman_smoother_multicam( From d357b7789df4315f05fbb6d553bb66683a710051 Mon Sep 17 00:00:00 2001 From: Keemin Lee Date: Tue, 11 Nov 2025 19:21:38 -0500 Subject: [PATCH 3/4] 0-based half open change for frame cropping --- eks/utils.py | 15 ++++++++------- tests/test_utils.py | 12 +++++------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/eks/utils.py b/eks/utils.py index 0297f91..618fe83 100644 --- a/eks/utils.py +++ b/eks/utils.py @@ -225,9 +225,9 @@ def crop_frames(y: np.ndarray, """ Crop frames from `y` according to `s_frames`. - Rules (1-based, inclusive user spans): - - Each element is (start, end), where start/end are 1-based, inclusive. - Use None for open ends (e.g., (None, 100) → frames [0:100), (250, None) → [249:end)). + Rules: + - Each element is (start, end), where start/end are 0-based, [start, end) half-open. + Use None for open ends (e.g., (None, 100) → frames [0:100), (250, None) → [250:end)). - s_frames is None or [(None, None)] → return y unchanged. """ n = len(y) @@ -235,6 +235,8 @@ def crop_frames(y: np.ndarray, # Case 1: No cropping at all if s_frames is None or (len(s_frames) == 1 and s_frames[0] == (None, None)): return y + if len(s_frames) == 0: + return y # Type enforcement if not isinstance(s_frames, list): @@ -252,14 +254,13 @@ def crop_frames(y: np.ndarray, if end is not None and not isinstance(end, int): raise ValueError(f"s_frames[{i}].end must be int or None, got {end!r}") - # Convert 1-based inclusive to 0-based half-open - start_idx = 0 if start is None else start - 1 + start_idx = 0 if start is None else start end_idx = n if end is None else end if start_idx < 0 or end_idx > n: - raise ValueError(f"Range ({start_idx + 1}, {end_idx}) out of bounds for length {n}.") + raise ValueError(f"Range ({start_idx}, {end_idx}) out of bounds for length {n}.") if start_idx >= end_idx: - raise ValueError(f"Invalid range ({start_idx + 1}, {end_idx}).") + raise ValueError(f"Invalid range ({start_idx}, {end_idx}).") spans.append((start_idx, end_idx)) diff --git a/tests/test_utils.py b/tests/test_utils.py index 2039406..d52bb28 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -23,16 +23,16 @@ def test_crop_frames_no_crop_none_none(): def test_crop_frames_single_span(): """Basic single-span crop with 1-based (inclusive) bounds.""" y = np.arange(10) # [0..9] - # (start=2, end=5) → 1-based inclusive => [1..5) in 0-based => indices 1..4 => [1,2,3,4] + # (2,5) -> indices 2,3,4 out = crop_frames(y, [(2, 5)]) - np.testing.assert_array_equal(out, np.array([1, 2, 3, 4])) + np.testing.assert_array_equal(out, np.array([2, 3, 4])) def test_crop_frames_open_ended_spans(): """Open-ended spans using None for start or end.""" - y = np.arange(10) # [0..9] - # (None, 3) -> [0:3) => [0,1,2] - # (7, None) -> [6:end) => [6,7,8,9] + y = np.arange(10) + # (None,3) -> [0,1,2] + # (7,None) -> [7,8,9] out = crop_frames(y, [(None, 3), (7, None)]) np.testing.assert_array_equal(out, np.array([0, 1, 2, 6, 7, 8, 9])) @@ -47,7 +47,6 @@ def test_crop_frames_invalid_tuple_shape(): def test_crop_frames_out_of_bounds(): """Out-of-bounds ranges raise ValueError.""" y = np.arange(10) - # end too large (1-based end=20 -> 0-based end_idx=20 > n) with pytest.raises(ValueError): crop_frames(y, [(1, 20)]) @@ -59,6 +58,5 @@ def test_crop_frames_out_of_bounds(): def test_crop_frames_overlap_raises(): """Overlapping intervals are rejected.""" y = np.arange(20) - # Overlap: (2, 6) -> [1:6), (5, 10) -> [4:10) overlap on indices 4..5 with pytest.raises(ValueError): crop_frames(y, [(2, 6), (5, 10)]) From 5ce3071c356782b70926bb37703e84829176dec9 Mon Sep 17 00:00:00 2001 From: Keemin Lee Date: Tue, 11 Nov 2025 19:31:27 -0500 Subject: [PATCH 4/4] minor test fix for s_frames --- tests/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index d52bb28..45fb525 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -34,7 +34,7 @@ def test_crop_frames_open_ended_spans(): # (None,3) -> [0,1,2] # (7,None) -> [7,8,9] out = crop_frames(y, [(None, 3), (7, None)]) - np.testing.assert_array_equal(out, np.array([0, 1, 2, 6, 7, 8, 9])) + np.testing.assert_array_equal(out, np.array([0, 1, 2, 7, 8, 9])) def test_crop_frames_invalid_tuple_shape():