Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions eks/command_line_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,10 @@ def handle_parse_args(script_type):
parser.add_argument(
'--s-frames',
help='frames to be considered for smoothing '
'parameter optimization, first 2k frames by default. Moot if --s is specified. '
'Format: "[(start_int, end_int), (start_int, end_int), ... ]" or int. '
'Inputting a single int uses all frames from 1 to the int. '
'parameter optimization. Moot if --s is specified. '
'Format: "[(start_int, end_int), (start_int, end_int), ... ]". '
'(None, end_int) starts from first frame; (start_int, None) proceeds to last frame.',
default=[(None, 10000)],
default=None,
type=parse_s_frames,
)
parser.add_argument(
Expand Down
75 changes: 51 additions & 24 deletions eks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from jax import numpy as jnp
from sleap_io.io.slp import read_labels
from typeguard import typechecked

Expand Down Expand Up @@ -221,34 +220,62 @@ def plot_results(
print(f'see example EKS output at {save_file}')


@typechecked
def crop_frames(y: np.ndarray | jnp.ndarray, s_frames: list | tuple) -> np.ndarray | jnp.ndarray:
""" Crops frames as specified by s_frames to be used for auto-tuning s."""
# Create an empty list to store arrays
result = []
def crop_frames(y: np.ndarray,
s_frames: list[tuple[int | None, int | None]] | None) -> np.ndarray:
"""
Crop frames from `y` according to `s_frames`.

for frame in s_frames:
# Unpack the frame, setting defaults for empty start or end
start, end = frame
# Default start to 0 if not specified (and adjust for zero indexing)
start = start - 1 if start is not None else 0
# Default end to the length of ys if not specified
end = end if end is not None else len(y)
Rules:
- Each element is (start, end), where start/end are 0-based, [start, end) half-open.
Use None for open ends (e.g., (None, 100) → frames [0:100), (250, None) → [250:end)).
- s_frames is None or [(None, None)] → return y unchanged.
"""
n = len(y)

# Case 1: No cropping at all
if s_frames is None or (len(s_frames) == 1 and s_frames[0] == (None, None)):
return y
if len(s_frames) == 0:
return y

# Cap the indices within valid range
start = max(0, start)
end = min(len(y), end)
# Type enforcement
if not isinstance(s_frames, list):
raise TypeError("s_frames must be a list of (start, end) tuples or None.")

# Validate the keys
if start >= end:
raise ValueError(f"Index range ({start + 1}, {end}) "
f"is out of bounds for the list of length {len(y)}.")
spans = []
for i, frame in enumerate(s_frames):
if not (isinstance(frame, tuple) and len(frame) == 2):
raise ValueError(f"s_frames[{i}] must be a (start, end) tuple, got {frame!r}")

# Use numpy slicing to preserve the data structure
result.append(y[start:end])
start, end = frame

# Concatenate all slices into a single numpy array
return np.concatenate(result)
if start is not None and not isinstance(start, int):
raise ValueError(f"s_frames[{i}].start must be int or None, got {start!r}")
if end is not None and not isinstance(end, int):
raise ValueError(f"s_frames[{i}].end must be int or None, got {end!r}")

start_idx = 0 if start is None else start
end_idx = n if end is None else end

if start_idx < 0 or end_idx > n:
raise ValueError(f"Range ({start_idx}, {end_idx}) out of bounds for length {n}.")
if start_idx >= end_idx:
raise ValueError(f"Invalid range ({start_idx}, {end_idx}).")

spans.append((start_idx, end_idx))

# Ensure ascending, non-overlapping order
spans.sort(key=lambda s: s[0])
for i in range(1, len(spans)):
if spans[i][0] < spans[i - 1][1]:
raise ValueError(
f"Overlapping or out-of-order intervals: {spans[i - 1]} and {spans[i]}")

# Perform crop
if len(spans) == 1:
s, e = spans[0]
return y[s:e]
return np.concatenate([y[s:e] for s, e in spans], axis=0)


@typechecked()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_multicam_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def test_ensemble_kalman_smoother_multicam_no_smooth_param():

camera_names = ['cam1', 'cam2']
quantile_keep_pca = 90
s_frames = [(0, 100)]
s_frames = None

# Run the smoother without providing smooth_param
camera_dfs, smooth_params_final = ensemble_kalman_smoother_multicam(
Expand Down Expand Up @@ -191,7 +191,7 @@ def test_ensemble_kalman_smoother_multicam_n_latent():

camera_names = ['cam1', 'cam2']
quantile_keep_pca = 90
s_frames = [(0, 10)]
s_frames = None

for n_latent in [2, 3, 5]: # Test different PCA dimensions
camera_dfs, _ = ensemble_kalman_smoother_multicam(
Expand Down
62 changes: 62 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import numpy as np
import pytest

from eks.utils import crop_frames


def test_crop_frames_no_crop_none():
"""If s_frames is None, return y unchanged."""
y = np.arange(20)
out = crop_frames(y, None)
assert np.shares_memory(out, y) or np.array_equal(out, y)
assert out.shape == y.shape


def test_crop_frames_no_crop_none_none():
"""If s_frames == [(None, None)], return y unchanged."""
y = np.arange(20)
out = crop_frames(y, [(None, None)])
assert np.shares_memory(out, y) or np.array_equal(out, y)
assert out.shape == y.shape


def test_crop_frames_single_span():
"""Basic single-span crop with 1-based (inclusive) bounds."""
y = np.arange(10) # [0..9]
# (2,5) -> indices 2,3,4
out = crop_frames(y, [(2, 5)])
np.testing.assert_array_equal(out, np.array([2, 3, 4]))


def test_crop_frames_open_ended_spans():
"""Open-ended spans using None for start or end."""
y = np.arange(10)
# (None,3) -> [0,1,2]
# (7,None) -> [7,8,9]
out = crop_frames(y, [(None, 3), (7, None)])
np.testing.assert_array_equal(out, np.array([0, 1, 2, 7, 8, 9]))


def test_crop_frames_invalid_tuple_shape():
"""Each element must be a 2-tuple (start, end)."""
y = np.arange(10)
with pytest.raises(ValueError):
crop_frames(y, [(1, 3, 5)]) # 3-tuple is invalid


def test_crop_frames_out_of_bounds():
"""Out-of-bounds ranges raise ValueError."""
y = np.arange(10)
with pytest.raises(ValueError):
crop_frames(y, [(1, 20)])

# start beyond end (invalid after conversion)
with pytest.raises(ValueError):
crop_frames(y, [(6, 5)])


def test_crop_frames_overlap_raises():
"""Overlapping intervals are rejected."""
y = np.arange(20)
with pytest.raises(ValueError):
crop_frames(y, [(2, 6), (5, 10)])