From 61dddb296ff6d07642ba3088a13ef05427389211 Mon Sep 17 00:00:00 2001 From: Keemin Lee Date: Tue, 15 Jul 2025 11:43:47 -0400 Subject: [PATCH 01/11] dynamax-linear for singlecam with fixed s --- eks/command_line_args.py | 6 +++++ eks/core.py | 26 ++++++++++++++++--- eks/kalman_backends.py | 50 ++++++++++++++++++++++++++++++++++++ eks/singlecam_smoother.py | 5 +++- scripts/singlecam_example.py | 4 ++- 5 files changed, 85 insertions(+), 6 deletions(-) create mode 100644 eks/kalman_backends.py diff --git a/eks/command_line_args.py b/eks/command_line_args.py index 86bb059..99a8be0 100644 --- a/eks/command_line_args.py +++ b/eks/command_line_args.py @@ -71,6 +71,12 @@ def handle_parse_args(script_type): default='', type=str, ) + parser.add_argument( + '--backend', + help='Options: jax, dynamax-linear. Determines the backend to be used for smoothing.', + default='jax', + type=str + ) if script_type == 'singlecam': add_bodyparts(parser) add_s(parser) diff --git a/eks/core.py b/eks/core.py index d00a49a..53cba2d 100644 --- a/eks/core.py +++ b/eks/core.py @@ -12,6 +12,7 @@ from eks.marker_array import MarkerArray from eks.utils import crop_frames +from eks.kalman_backends import dynamax_linear_smooth_routine # ------------------------------------------------------------------------------------- # Kalman Functions: Functions related to performing filtering and smoothing @@ -385,6 +386,7 @@ def final_forwards_backwards_pass( Cs: jnp.ndarray, As: jnp.ndarray, ensemble_vars: np.ndarray, + backend: str = 'jax' ) -> Tuple[np.ndarray, np.ndarray]: """ Runs the full Kalman forward-backward smoother across all keypoints using @@ -416,9 +418,24 @@ def final_forwards_backwards_pass( # Run forward and backward pass for each keypoint for k in range(n_keypoints): - mf, Vf, nll = forward_pass( - ys[k], m0s[k], S0s[k], As[k], Qs[k], Cs[k], ensemble_vars[:, k, :]) - ms, Vs = backward_pass(mf, Vf, As[k], Qs[k]) + y = ys[k] # (T, obs_dim) + m0 = m0s[k] # (D,) + S0 = S0s[k] # (D, D) + A = 
As[k] # (D, D) + Q = Qs[k] # (D, D) + C = Cs[k] # (obs_dim, D) + R_diag = np.mean(ensemble_vars[:, k, :], axis=0) + R = np.diag(R_diag) # (obs_dim, obs_dim) + + if backend == 'jax': + mf, Vf, _ = forward_pass(y, m0, S0, A, Q, C, ensemble_vars[:, k, :]) + ms, Vs = backward_pass(mf, Vf, A, Q) + + elif backend == 'dynamax-linear': + ms, Vs = dynamax_linear_smooth_routine(y, m0, S0, A, Q, C, R) + + else: + raise ValueError(f"Unsupported backend: {backend}") ms_array.append(np.array(ms)) Vs_array.append(np.array(Vs)) @@ -481,6 +498,7 @@ def optimize_smooth_param( blocks: Optional[List[List[int]]] = None, maxiter: int = 1000, verbose: bool = False, + backend: str = 'jax', ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Optimize smoothing parameters for each keypoint (or block of keypoints) using @@ -598,7 +616,7 @@ def step(s, opt_state): s_finals = np.array(s_finals) # Final smooth with optimized s ms, Vs = final_forwards_backwards_pass( - cov_mats, s_finals, ys, m0s, S0s, Cs, As, ensemble_vars, + cov_mats, s_finals, ys, m0s, S0s, Cs, As, ensemble_vars, backend=backend, ) return s_finals, ms, Vs diff --git a/eks/kalman_backends.py b/eks/kalman_backends.py new file mode 100644 index 0000000..3b398b2 --- /dev/null +++ b/eks/kalman_backends.py @@ -0,0 +1,50 @@ +from dynamax.linear_gaussian_ssm.models import ( + LinearGaussianSSM, + ParamsLGSSM, + ParamsLGSSMInitial, + ParamsLGSSMDynamics, + ParamsLGSSMEmissions +) +import jax +import jax.numpy as jnp +import numpy as np +from typing import Union, Tuple +from typeguard import typechecked + +ArrayLike = Union[np.ndarray, jax.Array] + +@typechecked +def dynamax_linear_smooth_routine( + y: ArrayLike, + m0: ArrayLike, + S0: ArrayLike, + A: ArrayLike, + Q: ArrayLike, + C: ArrayLike, + R: ArrayLike +) -> Tuple[jnp.ndarray, jnp.ndarray]: + # Convert everything to JAX arrays + y, m0, S0, A, Q, C, R = map(jnp.asarray, (y, m0, S0, A, Q, C, R)) + state_dim, obs_dim = A.shape[0], C.shape[0] + + # Build model and correct param 
structure + model = LinearGaussianSSM(state_dim, obs_dim) + + params = ParamsLGSSM( + initial=ParamsLGSSMInitial(mean=m0, cov=S0), + dynamics=ParamsLGSSMDynamics( + weights=A, + cov=Q, + bias=jnp.zeros(A.shape[0]), # shape (state_dim,) + input_weights=jnp.zeros((A.shape[0], 0)) # shape (state_dim, 0) for no control input + ), + emissions=ParamsLGSSMEmissions( + weights=C, + cov=R, + bias=jnp.zeros(C.shape[0]), # shape (obs_dim,) + input_weights=jnp.zeros((C.shape[0], 0)) # shape (obs_dim, 0) for no control input + ) + ) + + posterior = model.smoother(params, y) + return posterior.smoothed_means, posterior.smoothed_covariances diff --git a/eks/singlecam_smoother.py b/eks/singlecam_smoother.py index 3c9f849..cf96aa9 100644 --- a/eks/singlecam_smoother.py +++ b/eks/singlecam_smoother.py @@ -21,6 +21,7 @@ def fit_eks_singlecam( avg_mode: str = 'median', var_mode: str = 'confidence_weighted_var', verbose: bool = False, + backend: str = 'jax', ) -> tuple: """Fit the Ensemble Kalman Smoother for single-camera data. @@ -65,6 +66,7 @@ def fit_eks_singlecam( avg_mode=avg_mode, var_mode=var_mode, verbose=verbose, + backend=backend, ) # Save the output DataFrame to CSV @@ -85,6 +87,7 @@ def ensemble_kalman_smoother_singlecam( avg_mode: str = 'median', var_mode: str = 'confidence_weighted_var', verbose: bool = False, + backend: str = 'jax', ) -> tuple: """Perform Ensemble Kalman Smoothing for single-camera data. 
@@ -139,7 +142,7 @@ def ensemble_kalman_smoother_singlecam( # Main smoothing function s_finals, ms, Vs = optimize_smooth_param( cov_mats, ys, m0s, S0s, Cs, As, emA_vars.get_array(squeeze=True), - s_frames, smooth_param, blocks, verbose=verbose, + s_frames, smooth_param, blocks, verbose=verbose, backend=backend ) y_m_smooths = np.zeros((n_keypoints, n_frames, 2)) diff --git a/scripts/singlecam_example.py b/scripts/singlecam_example.py index 54389cf..f082d7c 100644 --- a/scripts/singlecam_example.py +++ b/scripts/singlecam_example.py @@ -24,6 +24,7 @@ s_frames = args.s_frames # Frames to be used for automatic optimization if s is not provided blocks = args.blocks verbose = True if args.verbose == 'True' else False +backend = args.backend # Fit EKS using the provided input data output_df, s_finals, input_dfs, bodypart_list = fit_eks_singlecam( @@ -33,7 +34,8 @@ smooth_param=s, s_frames=s_frames, blocks=blocks, - verbose=verbose + verbose=verbose, + backend=backend ) # Plot results for a specific keypoint (default to last keypoint) From 1b5b880deb4ce393882b5b7d74d7ace5eb35b587 Mon Sep 17 00:00:00 2001 From: Keemin Lee Date: Tue, 15 Jul 2025 11:52:21 -0400 Subject: [PATCH 02/11] dynamax-linear for multicam with fixed s --- eks/multicam_smoother.py | 13 ++++++++----- scripts/multicam_example.py | 4 +++- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/eks/multicam_smoother.py b/eks/multicam_smoother.py index 9e3eaeb..001f296 100644 --- a/eks/multicam_smoother.py +++ b/eks/multicam_smoother.py @@ -129,7 +129,8 @@ def fit_eks_multicam( var_mode: str = 'confidence_weighted_var', inflate_vars: bool = False, verbose: bool = False, - n_latent: int = 3 + n_latent: int = 3, + backend: str = 'jax', ) -> tuple: """ Fit the Ensemble Kalman Smoother for un-mirrored multi-camera data. 
@@ -177,7 +178,8 @@ def fit_eks_multicam( var_mode=var_mode, verbose=verbose, inflate_vars=inflate_vars, - n_latent=n_latent + n_latent=n_latent, + backend=backend, ) # Save output DataFrames to CSVs (one per camera view) os.makedirs(save_dir, exist_ok=True) @@ -201,7 +203,8 @@ def ensemble_kalman_smoother_multicam( inflate_vars_kwargs: dict = {}, verbose: bool = False, pca_object: PCA | None = None, - n_latent: int = 3 + n_latent: int = 3, + backend: str = 'jax', ) -> tuple: """ Use multi-view constraints to fit a 3D latent subspace for each body part. @@ -296,7 +299,8 @@ def ensemble_kalman_smoother_multicam( ensemble_vars=np.swapaxes(ensemble_vars, 0, 1), s_frames=s_frames, smooth_param=smooth_param, - verbose=verbose + verbose=verbose, + backend=backend, ) # Reproject from latent space back to observed space camera_arrs = [[] for _ in camera_names] @@ -373,7 +377,6 @@ def initialize_kalman_filter_pca( ]) As = np.tile(np.eye(n_latent), (n_keypoints, 1, 1)) Cs = np.stack([pca.components_.T for pca in ensemble_pca]) - Rs = np.tile(np.eye(n_latent), (n_keypoints, 1, 1)) cov_mats = [] for k in range(n_keypoints): diff --git a/scripts/multicam_example.py b/scripts/multicam_example.py index f25dbc0..aec05c2 100644 --- a/scripts/multicam_example.py +++ b/scripts/multicam_example.py @@ -27,6 +27,7 @@ verbose = True if args.verbose == 'True' else False inflate_vars = True if args.inflate_vars == 'True' else False n_latent = args.n_latent +backend = args.backend # Fit EKS using the provided input data camera_dfs, s_finals, input_dfs, bodypart_list = fit_eks_multicam( @@ -39,7 +40,8 @@ quantile_keep_pca=quantile_keep_pca, verbose=verbose, inflate_vars=inflate_vars, - n_latent=args.n_latent + n_latent=args.n_latent, + backend=backend, ) # Plot results for a specific keypoint (default to last keypoint of last camera view) From 3dd814ebe722d4744f476cc07bfe8e1d08134ac1 Mon Sep 17 00:00:00 2001 From: Keemin Lee Date: Wed, 16 Jul 2025 14:47:15 -0400 Subject: [PATCH 03/11] R_t 
impl --- eks/core.py | 5 ++--- eks/kalman_backends.py | 36 +++++++++++++++++++++++++++--------- 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/eks/core.py b/eks/core.py index 53cba2d..eb90df6 100644 --- a/eks/core.py +++ b/eks/core.py @@ -424,15 +424,14 @@ def final_forwards_backwards_pass( A = As[k] # (D, D) Q = Qs[k] # (D, D) C = Cs[k] # (obs_dim, D) - R_diag = np.mean(ensemble_vars[:, k, :], axis=0) - R = np.diag(R_diag) # (obs_dim, obs_dim) + per_timestep_vars = ensemble_vars[:, k, :] # (T, obs_dim) if backend == 'jax': mf, Vf, _ = forward_pass(y, m0, S0, A, Q, C, ensemble_vars[:, k, :]) ms, Vs = backward_pass(mf, Vf, A, Q) elif backend == 'dynamax-linear': - ms, Vs = dynamax_linear_smooth_routine(y, m0, S0, A, Q, C, R) + ms, Vs = dynamax_linear_smooth_routine(y, m0, S0, A, Q, C, per_timestep_vars) else: raise ValueError(f"Unsupported backend: {backend}") diff --git a/eks/kalman_backends.py b/eks/kalman_backends.py index 3b398b2..76dcbaa 100644 --- a/eks/kalman_backends.py +++ b/eks/kalman_backends.py @@ -21,13 +21,31 @@ def dynamax_linear_smooth_routine( A: ArrayLike, Q: ArrayLike, C: ArrayLike, - R: ArrayLike + ensemble_vars: ArrayLike # shape (T, obs_dim) ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Run Dynamax smoother with time-varying diagonal observation noise from ensemble variances. 
+ + Args: + y: (T, obs_dim) observations + m0: (state_dim,) initial mean + S0: (state_dim, state_dim) initial covariance + A: (state_dim, state_dim) transition matrix + Q: (state_dim, state_dim) process noise + C: (obs_dim, state_dim) emission matrix + ensemble_vars: (T, obs_dim) per-timestep observation noise variance + + Returns: + smoothed_means: (T, state_dim) + smoothed_covs: (T, state_dim, state_dim) + """ + # Convert everything to JAX arrays - y, m0, S0, A, Q, C, R = map(jnp.asarray, (y, m0, S0, A, Q, C, R)) - state_dim, obs_dim = A.shape[0], C.shape[0] + y, m0, S0, A, Q, C, ensemble_vars = map(jnp.asarray, (y, m0, S0, A, Q, C, ensemble_vars)) + T, obs_dim = y.shape + state_dim = A.shape[0] - # Build model and correct param structure + # Dynamax accepts time-varying diagonal R_t as (T, obs_dim) model = LinearGaussianSSM(state_dim, obs_dim) params = ParamsLGSSM( @@ -35,14 +53,14 @@ def dynamax_linear_smooth_routine( dynamics=ParamsLGSSMDynamics( weights=A, cov=Q, - bias=jnp.zeros(A.shape[0]), # shape (state_dim,) - input_weights=jnp.zeros((A.shape[0], 0)) # shape (state_dim, 0) for no control input + bias=jnp.zeros(state_dim), + input_weights=jnp.zeros((state_dim, 0)) ), emissions=ParamsLGSSMEmissions( weights=C, - cov=R, - bias=jnp.zeros(C.shape[0]), # shape (obs_dim,) - input_weights=jnp.zeros((C.shape[0], 0)) # shape (obs_dim, 0) for no control input + cov=ensemble_vars, # <=== time-varying diagonal noise + bias=jnp.zeros(obs_dim), + input_weights=jnp.zeros((obs_dim, 0)) ) ) From 3d5531c029017ed0d47860b0ce1ae7c9429ce3fc Mon Sep 17 00:00:00 2001 From: Keemin Lee Date: Tue, 16 Sep 2025 11:27:57 -0400 Subject: [PATCH 04/11] nonlinear notebooks --- bb_crop.ipynb | 173 ++++++++ bbox_uncrop.ipynb | 231 ++++++++++ ekf.ipynb | 964 +++++++++++++++++++++++++++++++++++++++++ eks/core.py | 11 +- eks/kalman_backends.py | 86 ++-- 5 files changed, 1417 insertions(+), 48 deletions(-) create mode 100644 bb_crop.ipynb create mode 100644 bbox_uncrop.ipynb create mode 100644 
ekf.ipynb diff --git a/bb_crop.ipynb b/bb_crop.ipynb new file mode 100644 index 0000000..7e93793 --- /dev/null +++ b/bb_crop.ipynb @@ -0,0 +1,173 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "aef7a1cd", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "from typing import List, Union\n", + "import pandas as pd\n", + "import re\n", + "\n", + "def _detect_frame_index_pose(df: pd.DataFrame) -> pd.DataFrame:\n", + " df = df.copy()\n", + " if 'scorer' in df.columns and pd.api.types.is_integer_dtype(df['scorer']):\n", + " return df.set_index('scorer', drop=True).rename_axis('frame')\n", + " for col in ('frame', 'index', 'Unnamed: 0'):\n", + " if col in df.columns and pd.api.types.is_integer_dtype(df[col]):\n", + " return df.set_index(col, drop=True).rename_axis('frame')\n", + " return df.reset_index(drop=True).rename_axis('frame')\n", + "\n", + "def _detect_frame_index_bbox(df: pd.DataFrame) -> pd.DataFrame:\n", + " df = df.copy()\n", + " if 'frame' in df.columns:\n", + " return df.set_index('frame', drop=True)\n", + " if 'Unnamed: 0' in df.columns:\n", + " return df.rename(columns={'Unnamed: 0': 'frame'}).set_index('frame', drop=True)\n", + " if not isinstance(df.index, pd.RangeIndex):\n", + " return df.reset_index(drop=False).rename(columns={'index': 'frame'}).set_index('frame', drop=True)\n", + " return df.rename_axis('frame')\n", + "\n", + "def _xy_like_multiindex_cols(columns: pd.MultiIndex):\n", + " x_cols, y_cols = [], []\n", + " for col in columns:\n", + " last = col[-1]\n", + " if isinstance(last, str):\n", + " if last.startswith('x') and 'var' not in last: x_cols.append(col)\n", + " if last.startswith('y') and 'var' not in last: y_cols.append(col)\n", + " return x_cols, y_cols\n", + "\n", + "def _xy_like_flat_cols(columns: pd.Index):\n", + " x_cols, y_cols = [], []\n", + " for c in columns.astype(str):\n", + " if 'var' in c.lower(): \n", + " continue\n", + " if c == 'x' or c.endswith('_x') or 
c.startswith('x_'): x_cols.append(c)\n", + " if c == 'y' or c.endswith('_y') or c.startswith('y_'): y_cols.append(c)\n", + " return x_cols, y_cols\n", + "\n", + "def translate_pose_by_bbox(pose_df: pd.DataFrame, bbox_df: pd.DataFrame, mode: str = \"subtract\") -> pd.DataFrame:\n", + " \"\"\"Map full-frame coords -> bbox-cropped coords (mode='subtract').\n", + " Use mode='add' to go back to full-frame.\"\"\"\n", + " pose_df = _detect_frame_index_pose(pose_df)\n", + " bbox_df = _detect_frame_index_bbox(bbox_df)\n", + " common = pose_df.index.intersection(bbox_df.index)\n", + " pose_df, bbox_df = pose_df.loc[common].copy(), bbox_df.loc[common].copy()\n", + " if not {'x','y'}.issubset(bbox_df.columns):\n", + " raise ValueError(f\"bbox_df must have 'x' and 'y'; got {bbox_df.columns}\")\n", + " sign = -1 if mode == \"subtract\" else 1\n", + "\n", + " if isinstance(pose_df.columns, pd.MultiIndex):\n", + " x_cols, y_cols = _xy_like_multiindex_cols(pose_df.columns)\n", + " else:\n", + " x_cols, y_cols = _xy_like_flat_cols(pose_df.columns)\n", + "\n", + " if x_cols: pose_df.loc[:, x_cols] = pose_df.loc[:, x_cols].add(sign * bbox_df['x'], axis=0)\n", + " if y_cols: pose_df.loc[:, y_cols] = pose_df.loc[:, y_cols].add(sign * bbox_df['y'], axis=0)\n", + " return pose_df\n", + "\n", + "def batch_translate_pose_csvs(\n", + " pose_csvs: List[Union[str, Path]],\n", + " bbox_csvs: List[Union[str, Path]],\n", + " output_dir: Union[str, Path],\n", + " mode: str = \"subtract\",\n", + " suffix: str = \"\",\n", + "):\n", + " if len(pose_csvs) != len(bbox_csvs):\n", + " raise ValueError(\"pose_csvs and bbox_csvs must be same length (one per view).\")\n", + " output_dir = Path(output_dir); output_dir.mkdir(parents=True, exist_ok=True)\n", + " outs = []\n", + " for pose_csv, bbox_csv in zip(pose_csvs, bbox_csvs):\n", + " # Try MultiIndex header first (Lightning Pose/EKS), else flat\n", + " try: pose_df = pd.read_csv(pose_csv, header=[0,1,2])\n", + " except Exception: pose_df = 
pd.read_csv(pose_csv)\n", + " bbox_df = pd.read_csv(bbox_csv)\n", + " translated = translate_pose_by_bbox(pose_df, bbox_df, mode=mode)\n", + " out_path = output_dir / f\"{Path(pose_csv).stem}.csv\"\n", + " translated.to_csv(out_path)\n", + " outs.append(out_path)\n", + " return outs" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a1d8c557", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[WindowsPath('outputs/cropped_csvs/PRL43_200617_131904_lBack.short_cropped.csv')]\n" + ] + } + ], + "source": [ + "pose_csvs = [\"./outputs/chickadee-preds/video_preds/PRL43_200617_131904_lBack.short.csv\"] # one per view\n", + "bbox_csvs = [\"./data/bounding_boxes/PRL43_200617_131904_lBack.short_bbox.csv\"] # matching order\n", + "out_files = batch_translate_pose_csvs(pose_csvs, bbox_csvs, output_dir=\"./outputs/cropped_csvs\")\n", + "print(out_files)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d470ddc0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Frames: 1800\n", + "FPS: 60.0\n", + "Duration (s): 30.000\n" + ] + } + ], + "source": [ + "import cv2\n", + "\n", + "path = \"./videos/chickadee/PRL43_200617_131904_lBack.short.mp4\" # update to your file path\n", + "cap = cv2.VideoCapture(path)\n", + "if not cap.isOpened():\n", + " raise RuntimeError(\"Could not open video\")\n", + "\n", + "frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n", + "fps = cap.get(cv2.CAP_PROP_FPS)\n", + "duration_sec = frame_count / fps if fps > 0 else None\n", + "\n", + "cap.release()\n", + "\n", + "print(f\"Frames: {frame_count}\")\n", + "print(f\"FPS: {fps}\")\n", + "print(f\"Duration (s): {duration_sec:.3f}\" if duration_sec else \"Duration unknown\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + 
"version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/bbox_uncrop.ipynb b/bbox_uncrop.ipynb new file mode 100644 index 0000000..1cdfc88 --- /dev/null +++ b/bbox_uncrop.ipynb @@ -0,0 +1,231 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "5b2692e2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[INFO] CWD: E:\\eks\n", + "[INFO] preds_root: E:\\eks\\data\\chickadee\n", + "[INFO] bbox_root : E:\\eks\\data\\bounding_boxes\n", + "[INFO] output_dir: E:\\eks\\data\\chickadee_uncropped\n" + ] + }, + { + "ename": "IntCastingNaNError", + "evalue": "Cannot convert non-finite values (NA or inf) to integer", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mIntCastingNaNError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[6], line 150\u001b[0m\n\u001b[0;32m 147\u001b[0m bbox_root \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m./data/bounding_boxes\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;66;03m# one bbox CSV per camera (name contains camera)\u001b[39;00m\n\u001b[0;32m 148\u001b[0m output_dir \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m./data/chickadee_uncropped\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m--> 150\u001b[0m \u001b[43mprocess_all\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcamera_names\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpreds_root\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbbox_root\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_dir\u001b[49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[1;32mIn[6], line 125\u001b[0m, in 
\u001b[0;36mprocess_all\u001b[1;34m(camera_names, preds_root, bbox_root, output_dir)\u001b[0m\n\u001b[0;32m 120\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[0;32m 121\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNeed exactly one bbox CSV for camera \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcam\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, found \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(bbox_matches)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 122\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m[\u001b[38;5;28mstr\u001b[39m(p)\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mfor\u001b[39;00m\u001b[38;5;250m \u001b[39mp\u001b[38;5;250m \u001b[39m\u001b[38;5;129;01min\u001b[39;00m\u001b[38;5;250m \u001b[39mbbox_matches]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 123\u001b[0m )\n\u001b[0;32m 124\u001b[0m bbox_path \u001b[38;5;241m=\u001b[39m bbox_matches[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m--> 125\u001b[0m df_bbox \u001b[38;5;241m=\u001b[39m \u001b[43m_read_bbox_csv\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbbox_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 127\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[INFO] Camera \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcam\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(preds_paths)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m preds, bbox=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mbbox_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 129\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m 
pred_path \u001b[38;5;129;01min\u001b[39;00m preds_paths:\n", + "Cell \u001b[1;32mIn[6], line 33\u001b[0m, in \u001b[0;36m_read_bbox_csv\u001b[1;34m(path)\u001b[0m\n\u001b[0;32m 30\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m:\n\u001b[0;32m 31\u001b[0m df \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mread_csv(p, header\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, names\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mframe\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mx\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124my\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mh\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m---> 33\u001b[0m df[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mframe\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[43mdf\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mframe\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mastype\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mint\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 34\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m c \u001b[38;5;129;01min\u001b[39;00m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mx\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124my\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mh\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m\"\u001b[39m]:\n\u001b[0;32m 35\u001b[0m df[c] \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mto_numeric(df[c], 
errors\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mraise\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\pandas\\core\\generic.py:6640\u001b[0m, in \u001b[0;36mNDFrame.astype\u001b[1;34m(self, dtype, copy, errors)\u001b[0m\n\u001b[0;32m 6634\u001b[0m results \u001b[38;5;241m=\u001b[39m [\n\u001b[0;32m 6635\u001b[0m ser\u001b[38;5;241m.\u001b[39mastype(dtype, copy\u001b[38;5;241m=\u001b[39mcopy, errors\u001b[38;5;241m=\u001b[39merrors) \u001b[38;5;28;01mfor\u001b[39;00m _, ser \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mitems()\n\u001b[0;32m 6636\u001b[0m ]\n\u001b[0;32m 6638\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 6639\u001b[0m \u001b[38;5;66;03m# else, only a single dtype is given\u001b[39;00m\n\u001b[1;32m-> 6640\u001b[0m new_data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_mgr\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mastype\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcopy\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43merrors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43merrors\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 6641\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_constructor_from_mgr(new_data, axes\u001b[38;5;241m=\u001b[39mnew_data\u001b[38;5;241m.\u001b[39maxes)\n\u001b[0;32m 6642\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m res\u001b[38;5;241m.\u001b[39m__finalize__(\u001b[38;5;28mself\u001b[39m, 
method\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mastype\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\pandas\\core\\internals\\managers.py:430\u001b[0m, in \u001b[0;36mBaseBlockManager.astype\u001b[1;34m(self, dtype, copy, errors)\u001b[0m\n\u001b[0;32m 427\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m using_copy_on_write():\n\u001b[0;32m 428\u001b[0m copy \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m--> 430\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mapply\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 431\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mastype\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m 432\u001b[0m \u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 433\u001b[0m \u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcopy\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 434\u001b[0m \u001b[43m \u001b[49m\u001b[43merrors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43merrors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 435\u001b[0m \u001b[43m \u001b[49m\u001b[43musing_cow\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43musing_copy_on_write\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 436\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\pandas\\core\\internals\\managers.py:363\u001b[0m, in \u001b[0;36mBaseBlockManager.apply\u001b[1;34m(self, f, align_keys, 
**kwargs)\u001b[0m\n\u001b[0;32m 361\u001b[0m applied \u001b[38;5;241m=\u001b[39m b\u001b[38;5;241m.\u001b[39mapply(f, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 362\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m--> 363\u001b[0m applied \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mgetattr\u001b[39m(b, f)(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 364\u001b[0m result_blocks \u001b[38;5;241m=\u001b[39m extend_blocks(applied, result_blocks)\n\u001b[0;32m 366\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39mfrom_blocks(result_blocks, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maxes)\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\pandas\\core\\internals\\blocks.py:758\u001b[0m, in \u001b[0;36mBlock.astype\u001b[1;34m(self, dtype, copy, errors, using_cow, squeeze)\u001b[0m\n\u001b[0;32m 755\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCan not squeeze with more than one column.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 756\u001b[0m values \u001b[38;5;241m=\u001b[39m values[\u001b[38;5;241m0\u001b[39m, :] \u001b[38;5;66;03m# type: ignore[call-overload]\u001b[39;00m\n\u001b[1;32m--> 758\u001b[0m new_values \u001b[38;5;241m=\u001b[39m \u001b[43mastype_array_safe\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalues\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcopy\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43merrors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43merrors\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 760\u001b[0m new_values 
\u001b[38;5;241m=\u001b[39m maybe_coerce_values(new_values)\n\u001b[0;32m 762\u001b[0m refs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\pandas\\core\\dtypes\\astype.py:237\u001b[0m, in \u001b[0;36mastype_array_safe\u001b[1;34m(values, dtype, copy, errors)\u001b[0m\n\u001b[0;32m 234\u001b[0m dtype \u001b[38;5;241m=\u001b[39m dtype\u001b[38;5;241m.\u001b[39mnumpy_dtype\n\u001b[0;32m 236\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m--> 237\u001b[0m new_values \u001b[38;5;241m=\u001b[39m \u001b[43mastype_array\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalues\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcopy\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 238\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mValueError\u001b[39;00m, \u001b[38;5;167;01mTypeError\u001b[39;00m):\n\u001b[0;32m 239\u001b[0m \u001b[38;5;66;03m# e.g. 
_astype_nansafe can fail on object-dtype of strings\u001b[39;00m\n\u001b[0;32m 240\u001b[0m \u001b[38;5;66;03m# trying to convert to float\u001b[39;00m\n\u001b[0;32m 241\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m errors \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mignore\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\pandas\\core\\dtypes\\astype.py:182\u001b[0m, in \u001b[0;36mastype_array\u001b[1;34m(values, dtype, copy)\u001b[0m\n\u001b[0;32m 179\u001b[0m values \u001b[38;5;241m=\u001b[39m values\u001b[38;5;241m.\u001b[39mastype(dtype, copy\u001b[38;5;241m=\u001b[39mcopy)\n\u001b[0;32m 181\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m--> 182\u001b[0m values \u001b[38;5;241m=\u001b[39m \u001b[43m_astype_nansafe\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalues\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcopy\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 184\u001b[0m \u001b[38;5;66;03m# in pandas we don't store numpy str dtypes, so convert to object\u001b[39;00m\n\u001b[0;32m 185\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(dtype, np\u001b[38;5;241m.\u001b[39mdtype) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28missubclass\u001b[39m(values\u001b[38;5;241m.\u001b[39mdtype\u001b[38;5;241m.\u001b[39mtype, \u001b[38;5;28mstr\u001b[39m):\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\pandas\\core\\dtypes\\astype.py:101\u001b[0m, in \u001b[0;36m_astype_nansafe\u001b[1;34m(arr, dtype, copy, skipna)\u001b[0m\n\u001b[0;32m 96\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
lib\u001b[38;5;241m.\u001b[39mensure_string_array(\n\u001b[0;32m 97\u001b[0m arr, skipna\u001b[38;5;241m=\u001b[39mskipna, convert_na_value\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[0;32m 98\u001b[0m )\u001b[38;5;241m.\u001b[39mreshape(shape)\n\u001b[0;32m 100\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m np\u001b[38;5;241m.\u001b[39missubdtype(arr\u001b[38;5;241m.\u001b[39mdtype, np\u001b[38;5;241m.\u001b[39mfloating) \u001b[38;5;129;01mand\u001b[39;00m dtype\u001b[38;5;241m.\u001b[39mkind \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124miu\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m--> 101\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_astype_float_to_int_nansafe\u001b[49m\u001b[43m(\u001b[49m\u001b[43marr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 103\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m arr\u001b[38;5;241m.\u001b[39mdtype \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mobject\u001b[39m:\n\u001b[0;32m 104\u001b[0m \u001b[38;5;66;03m# if we have a datetime/timedelta array of objects\u001b[39;00m\n\u001b[0;32m 105\u001b[0m \u001b[38;5;66;03m# then coerce to datetime64[ns] and use DatetimeArray.astype\u001b[39;00m\n\u001b[0;32m 107\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m lib\u001b[38;5;241m.\u001b[39mis_np_dtype(dtype, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mM\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\pandas\\core\\dtypes\\astype.py:145\u001b[0m, in \u001b[0;36m_astype_float_to_int_nansafe\u001b[1;34m(values, dtype, copy)\u001b[0m\n\u001b[0;32m 141\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m 142\u001b[0m \u001b[38;5;124;03mastype with a check 
preventing converting NaN to an meaningless integer value.\u001b[39;00m\n\u001b[0;32m 143\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m 144\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m np\u001b[38;5;241m.\u001b[39misfinite(values)\u001b[38;5;241m.\u001b[39mall():\n\u001b[1;32m--> 145\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m IntCastingNaNError(\n\u001b[0;32m 146\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCannot convert non-finite values (NA or inf) to integer\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 147\u001b[0m )\n\u001b[0;32m 148\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dtype\u001b[38;5;241m.\u001b[39mkind \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mu\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m 149\u001b[0m \u001b[38;5;66;03m# GH#45151\u001b[39;00m\n\u001b[0;32m 150\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (values \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mall():\n", + "\u001b[1;31mIntCastingNaNError\u001b[0m: Cannot convert non-finite values (NA or inf) to integer" + ] + } + ], + "source": [ + "#!/usr/bin/env python3\n", + "\"\"\"\n", + "Un-crop DLC-style predictions using per-frame bounding boxes.\n", + "Recursively scans directories for prediction CSVs and bbox CSVs.\n", + "\"\"\"\n", + "\n", + "import os\n", + "import sys\n", + "from pathlib import Path\n", + "from typing import List\n", + "import pandas as pd\n", + "\n", + "\n", + "# --------------------------\n", + "# Helpers\n", + "# --------------------------\n", + "\n", + "def _read_dlc_csv(path: str) -> pd.DataFrame:\n", + " \"\"\"Read DLC-style CSV with 3-row multi-index header.\"\"\"\n", + " return pd.read_csv(path, header=[0, 1, 2])\n", + "\n", + "def _read_bbox_csv(path: str) -> pd.DataFrame:\n", + " \"\"\"\n", + " Read a headerless bbox CSV with 5 columns: frame, x, y, h, w.\n", + " 
Assumes first cell (A1) is blank, so no header row.\n", + " \"\"\"\n", + " import pandas as pd\n", + "\n", + " # Force header=None so the first row is treated as data\n", + " df = pd.read_csv(path, header=None, names=[\"frame\", \"x\", \"y\", \"h\", \"w\"])\n", + "\n", + " # Coerce numeric values\n", + " for c in [\"frame\", \"x\", \"y\", \"h\", \"w\"]:\n", + " df[c] = pd.to_numeric(df[c], errors=\"coerce\")\n", + "\n", + " # If frames look 1-based, shift to 0-based\n", + " if df[\"frame\"].min() == 1:\n", + " df[\"frame\"] = df[\"frame\"] - 1\n", + "\n", + " # Final cast to int for frame\n", + " df[\"frame\"] = df[\"frame\"].round().astype(int)\n", + "\n", + " return df.sort_values(\"frame\").reset_index(drop=True)\n", + "\n", + "def _transform_predictions(df_pred: pd.DataFrame, df_bbox: pd.DataFrame) -> pd.DataFrame:\n", + " \"\"\"Apply uncropping transform.\"\"\"\n", + " n_frames = len(df_pred)\n", + " merged = pd.DataFrame({\"frame\": range(n_frames)}).merge(df_bbox, on=\"frame\", how=\"left\")\n", + " if merged[[\"x\", \"y\", \"h\", \"w\"]].isna().any().any():\n", + " missing = merged[merged[[\"x\",\"y\",\"h\",\"w\"]].isna().any(axis=1)][\"frame\"].tolist()\n", + " raise ValueError(f\"Missing bbox entries for frames (first 10 shown): {missing[:10]}\")\n", + "\n", + " x_off, y_off, h, w = [merged[c].to_numpy() for c in [\"x\", \"y\", \"h\", \"w\"]]\n", + " out = df_pred.copy()\n", + "\n", + " lvl0 = out.columns.get_level_values(0).unique()\n", + " lvl1 = out.columns.get_level_values(1).unique()\n", + "\n", + " for scorer in lvl0:\n", + " for bp in lvl1:\n", + " if (scorer, bp, \"x\") in out.columns:\n", + " x_vals = pd.to_numeric(out[(scorer, bp, \"x\")], errors=\"coerce\").to_numpy()\n", + " out[(scorer, bp, \"x\")] = (x_vals / 320.0) * w + x_off\n", + " if (scorer, bp, \"y\") in out.columns:\n", + " y_vals = pd.to_numeric(out[(scorer, bp, \"y\")], errors=\"coerce\").to_numpy()\n", + " out[(scorer, bp, \"y\")] = (y_vals / 320.0) * h + y_off\n", + " return 
out\n", + "\n", + "def _derive_output_path(pred_path: Path, output_dir: Path | None) -> Path:\n", + " \"\"\"Output path with _uncropped.csv suffix.\"\"\"\n", + " root = pred_path.stem\n", + " out_name = f\"{root}_uncropped.csv\"\n", + " if output_dir is not None:\n", + " output_dir.mkdir(parents=True, exist_ok=True)\n", + " return output_dir / out_name\n", + " return pred_path.parent / out_name\n", + "\n", + "\n", + "# --------------------------\n", + "# Discovery\n", + "# --------------------------\n", + "\n", + "def _rglob_csvs(root: Path) -> list[Path]:\n", + " return [p for p in root.rglob(\"*.csv\") if p.is_file()]\n", + "\n", + "def _filter_by_cam(paths: list[Path], cam: str) -> list[Path]:\n", + " cam_l = cam.lower()\n", + " return [p for p in paths if cam_l in p.name.lower()]\n", + "\n", + "# --------------------------\n", + "# Main processing\n", + "# --------------------------\n", + "\n", + "def process_all(camera_names: List[str], preds_root: str, bbox_root: str, output_dir: str | None = None) -> None:\n", + " # Resolve to absolute paths relative to the kernel's CWD\n", + " preds_root_p = Path(preds_root).expanduser().resolve()\n", + " bbox_root_p = Path(bbox_root).expanduser().resolve()\n", + " output_dir_p = Path(output_dir).expanduser().resolve() if output_dir else None\n", + "\n", + " print(f\"[INFO] CWD: {Path.cwd().resolve()}\")\n", + " print(f\"[INFO] preds_root: {preds_root_p}\")\n", + " print(f\"[INFO] bbox_root : {bbox_root_p}\")\n", + " if output_dir_p: print(f\"[INFO] output_dir: {output_dir_p}\")\n", + "\n", + " if not preds_root_p.exists():\n", + " raise FileNotFoundError(f\"preds_root does not exist: {preds_root_p}\")\n", + " if not bbox_root_p.exists():\n", + " raise FileNotFoundError(f\"bbox_root does not exist: {bbox_root_p}\")\n", + "\n", + " # Discover all CSVs once (recursive)\n", + " all_pred_csvs = _rglob_csvs(preds_root_p)\n", + " all_bbox_csvs = _rglob_csvs(bbox_root_p)\n", + "\n", + " if not all_pred_csvs:\n", + " 
print(f\"[WARN] No prediction CSVs found under {preds_root_p}\", file=sys.stderr)\n", + " if not all_bbox_csvs:\n", + " print(f\"[WARN] No bbox CSVs found under {bbox_root_p}\", file=sys.stderr)\n", + "\n", + " for cam in camera_names:\n", + " preds_paths = _filter_by_cam(all_pred_csvs, cam)\n", + " if not preds_paths:\n", + " print(f\"[WARN] No prediction CSVs matched camera '{cam}'\", file=sys.stderr)\n", + " continue\n", + "\n", + " bbox_matches = _filter_by_cam(all_bbox_csvs, cam)\n", + " if len(bbox_matches) != 1:\n", + " raise ValueError(\n", + " f\"Need exactly one bbox CSV for camera '{cam}', found {len(bbox_matches)}: \"\n", + " f\"{[str(p) for p in bbox_matches]}\"\n", + " )\n", + " bbox_path = bbox_matches[0]\n", + " df_bbox = _read_bbox_csv(bbox_path)\n", + "\n", + " print(f\"[INFO] Camera '{cam}': {len(preds_paths)} preds, bbox={bbox_path}\")\n", + "\n", + " for pred_path in preds_paths:\n", + " try:\n", + " df_pred = _read_dlc_csv(pred_path)\n", + " except Exception as e:\n", + " print(f\"[WARN] Skipping (bad DLC header?): {pred_path} ({e})\", file=sys.stderr)\n", + " continue\n", + " df_out = _transform_predictions(df_pred, df_bbox)\n", + " out_path = _derive_output_path(pred_path, output_dir_p)\n", + " df_out.to_csv(out_path, index=False)\n", + " print(f\" → Saved: {out_path}\")\n", + "\n", + "# --------------------------\n", + "# Hard-coded config\n", + "# --------------------------\n", + "\n", + "if __name__ == \"__main__\":\n", + " camera_names = [\"lBack\", \"lFront\", \"lTop\", \"rBack\", \"rFront\", \"rTop\"]\n", + " preds_root = \"./data/chickadee\" # can be relative; resolved & printed\n", + " bbox_root = \"./data/bounding_boxes\" # one bbox CSV per camera (name contains camera)\n", + " output_dir = \"./data/chickadee_uncropped\"\n", + "\n", + " process_all(camera_names, preds_root, bbox_root, output_dir)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5e90a3c4", + "metadata": {}, + "outputs": [], + "source": [] + 
} + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/ekf.ipynb b/ekf.ipynb new file mode 100644 index 0000000..05eaffb --- /dev/null +++ b/ekf.ipynb @@ -0,0 +1,964 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "d3bfa09d", + "metadata": {}, + "outputs": [], + "source": [ + "# Pose Smoothing with Dynamax EKF\n", + "# We load ensemble 2D pose predictions from 6 cameras (A–F), compute ensemble variance for observation noise, \n", + "# triangulate a geometric 3D latent state using calibrated camera parameters, and apply the Extended Kalman Smoother (EKF) using Dynamax.\n", + "\n", + "import os\n", + "import numpy as np\n", + "import pandas as pd\n", + "from pathlib import Path\n", + "from glob import glob\n", + "\n", + "from aniposelib.boards import CharucoBoard\n", + "from aniposelib.cameras import CameraGroup" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "d199b1f6", + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "from jax import jit\n", + "\n", + "def _rodrigues(rvec):\n", + " \"\"\"OpenCV-style Rodrigues: rvec (3,) -> R (3,3).\"\"\"\n", + " theta = jnp.linalg.norm(rvec)\n", + " def small_angle(_):\n", + " # First-order approx: R ≈ I + [r]_x (good when theta ~ 0)\n", + " rx, ry, rz = rvec\n", + " K = jnp.array([[0.0, -rz, ry],\n", + " [rz, 0.0, -rx],\n", + " [-ry, rx, 0.0]])\n", + " return jnp.eye(3) + K\n", + " def general(_):\n", + " rx, ry, rz = rvec / theta\n", + " K = jnp.array([[0.0, -rz, ry],\n", + " [rz, 0.0, -rx],\n", + " [-ry, rx, 0.0]])\n", + " s = jnp.sin(theta)\n", + " c = 
jnp.cos(theta)\n", + " return jnp.eye(3) + s*K + (1.0 - c) * (K @ K)\n", + " return jax.lax.cond(theta < 1e-12, small_angle, general, operand=None)\n", + "\n", + "def _parse_dist(dist_coeffs):\n", + " \"\"\"\n", + " OpenCV pinhole distortion ordering:\n", + " [k1, k2, p1, p2, k3, k4, k5, k6, s1, s2, s3, s4, tx, ty] (tx,ty tilt optional)\n", + " We support up to s1..s4; tilt is ignored here.\n", + " \"\"\"\n", + " dc = jnp.pad(jnp.asarray(dist_coeffs, dtype=jnp.float64), (0, max(0, 14 - len(dist_coeffs)))) # length ≥ 14\n", + " k1, k2, p1, p2, k3, k4, k5, k6, s1, s2, s3, s4, tx, ty = [dc[i] for i in range(14)]\n", + " return dict(k1=k1, k2=k2, p1=p1, p2=p2, k3=k3, k4=k4, k5=k5, k6=k6, s1=s1, s2=s2, s3=s3, s4=s4)\n", + "\n", + "def make_jax_projection_fn(rvec, tvec, K, dist_coeffs):\n", + " \"\"\"\n", + " JAX-compatible replacement for cv2.projectPoints (standard pinhole model).\n", + "\n", + " Args\n", + " ----\n", + " rvec : (3,) Rodrigues rotation vector (world -> camera)\n", + " tvec : (3,) translation (world -> camera), same units as your world coords\n", + " K : (3,3) camera intrinsic matrix\n", + " [[fx, s, cx],\n", + " [ 0, fy, cy],\n", + " [ 0, 0, 1 ]]\n", + " dist_coeffs : iterable of distortion coefficients in OpenCV order\n", + " [k1, k2, p1, p2[, k3[, k4, k5, k6[, s1, s2, s3, s4[, tx, ty]]]]]\n", + "\n", + " Returns\n", + " -------\n", + " project(object_points) -> image_points\n", + " object_points: (..., 3)\n", + " image_points: (..., 2)\n", + " \"\"\"\n", + " # cache params as arrays\n", + " rvec = jnp.asarray(rvec, dtype=jnp.float64)\n", + " tvec = jnp.asarray(tvec, dtype=jnp.float64)\n", + " K = jnp.asarray(K, dtype=jnp.float64)\n", + " fx, fy, cx, cy, skew = K[0,0], K[1,1], K[0,2], K[1,2], K[0,1]\n", + " d = _parse_dist(dist_coeffs)\n", + " R = _rodrigues(rvec)\n", + "\n", + " @jit\n", + " def project(object_points):\n", + " # object_points: (..., 3)\n", + " Xw = jnp.asarray(object_points, dtype=jnp.float64)\n", + " # world -> camera\n", + " Xc = 
Xw @ R.T + tvec # (..., 3)\n", + " X, Y, Z = Xc[..., 0], Xc[..., 1], Xc[..., 2]\n", + "\n", + " # normalized coords\n", + " x = X / Z\n", + " y = Y / Z\n", + "\n", + " r2 = x*x + y*y\n", + " r4 = r2*r2\n", + " r6 = r4*r2\n", + " r8 = r4*r4\n", + " r10 = r8*r2\n", + " r12 = r6*r6\n", + "\n", + " radial = (\n", + " 1.0\n", + " + d[\"k1\"]*r2 + d[\"k2\"]*r4 + d[\"k3\"]*r6\n", + " + d[\"k4\"]*r8 + d[\"k5\"]*r10 + d[\"k6\"]*r12\n", + " )\n", + "\n", + " x_tan = 2.0*d[\"p1\"]*x*y + d[\"p2\"]*(r2 + 2.0*x*x)\n", + " y_tan = d[\"p1\"]*(r2 + 2.0*y*y) + 2.0*d[\"p2\"]*x*y\n", + "\n", + " # thin-prism\n", + " x_tp = d[\"s1\"]*r2 + d[\"s2\"]*r4\n", + " y_tp = d[\"s3\"]*r2 + d[\"s4\"]*r4\n", + "\n", + " xd = x * radial + x_tan + x_tp\n", + " yd = y * radial + y_tan + y_tp\n", + "\n", + " # intrinsics (allow nonzero skew)\n", + " u = fx * xd + skew * yd + cx\n", + " v = fy * yd + cy\n", + "\n", + " return jnp.stack([u, v], axis=-1) # (..., 2)\n", + "\n", + " return project" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35029956", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import cv2\n", + "import jax.numpy as jnp\n", + "from jax import jit, vmap\n", + "import numpy as np\n", + "import pandas as pd\n", + "from aniposelib.cameras import CameraGroup\n", + "from sklearn.decomposition import PCA\n", + "from typeguard import typechecked\n", + "from typing import Tuple, Callable\n", + "from eks.core import ensemble\n", + "from eks.marker_array import (\n", + " MarkerArray,\n", + " input_dfs_to_markerArray,\n", + " mA_to_stacked_array,\n", + " stacked_array_to_mA,\n", + ")\n", + "import jax\n", + "jax.config.update(\"jax_enable_x64\", True)\n", + "from eks.stats import compute_mahalanobis, compute_pca\n", + "from eks.utils import center_predictions, format_data, make_dlc_pandas_index\n", + "from eks.multicam_smoother import mA_compute_maha, initialize_kalman_filter_pca\n", + "\n", + "def fit_eks_multicam(\n", + " input_source: str | 
def initialize_kalman_filter_geometric(ys: np.ndarray) -> Tuple[jnp.ndarray, ...]:
    """
    Initialize Kalman filter parameters for geometric (3D) keypoints.

    Args:
        ys: Array of shape (K, T, 3) — triangulated keypoints.

    Returns:
        Tuple of Kalman filter parameters:
            - m0s: (K, 3) initial means (zeros)
            - S0s: (K, 3, 3) initial covariances (diagonal, from temporal variance)
            - As: (K, 3, 3) transition matrices (identity / random walk)
            - Qs: (K, 3, 3) process noise covariances (small isotropic default)
            - Cs: (K, 3, 3) observation matrices (identity)
    """
    n_keypoints, _, dim = ys.shape

    # Zero initial means (could alternatively seed from ys[:, 0, :]).
    m0s = np.zeros((n_keypoints, dim))

    # Per-axis temporal variance as initial uncertainty; the small epsilon
    # keeps the covariance matrices non-degenerate.
    axis_var = np.nanvar(ys, axis=1) + 1e-4  # (K, 3)
    S0s = np.stack([np.diag(v) for v in axis_var])  # (K, 3, 3)

    # Identity dynamics and observation models; small default process noise.
    eyes = np.tile(np.eye(dim), (n_keypoints, 1, 1))
    As = eyes
    Cs = eyes.copy()
    Qs = eyes * 1e-3

    return (
        jnp.array(m0s),
        jnp.array(S0s),
        jnp.array(As),
        jnp.array(Qs),
        jnp.array(Cs),
    )
ensemble(marker_array, avg_mode=avg_mode, var_mode=var_mode)\n", + " emA_unsmoothed_preds = ensemble_marker_array.slice_fields(\"x\", \"y\")\n", + " emA_vars = ensemble_marker_array.slice_fields(\"var_x\", \"var_y\")\n", + " emA_likes = ensemble_marker_array.slice_fields(\"likelihood\")\n", + "\n", + " # === Triangulate all 3D positions ===\n", + " triangulated_3d_models = np.zeros((n_models, n_keypoints, n_frames, 3))\n", + " raw_array = marker_array.get_array()\n", + " for m in range(n_models):\n", + " for k in range(n_keypoints):\n", + " for t in range(n_frames):\n", + " xy_views = [raw_array[m, c, t, k, :2] for c in range(n_cameras)]\n", + " triangulated_3d_models[m, k, t] = camgroup.triangulate(np.array(xy_views))\n", + "\n", + " ys_3d = triangulated_3d_models.mean(axis=0) # (K, T, 3)\n", + " ensemble_vars_3d = triangulated_3d_models.var(axis=0) # (K, T, 3)\n", + "\n", + " # === Define a single multi-view h_fn (ℝ³ → ℝ^{2V}) ===\n", + " h_cams = []\n", + " for cam in camgroup.cameras:\n", + " print(cam.get_size())\n", + " rot = np.array(cam.get_rotation())\n", + " # Convert to Rodrigues vector if needed\n", + " rvec = cv2.Rodrigues(rot)[0].ravel() if rot.shape == (3, 3) else rot.ravel()\n", + " tvec = np.array(cam.get_translation()).ravel()\n", + " K = np.array(cam.get_camera_matrix())\n", + " dist = np.array(cam.get_distortions()).ravel() # distortion coeffs: k1,k2,p1,p2,k3,...\n", + "\n", + " h_cams.append(\n", + " make_jax_projection_fn(\n", + " jnp.array(rvec),\n", + " jnp.array(tvec),\n", + " jnp.array(K),\n", + " jnp.array(dist)\n", + " )\n", + " )\n", + "\n", + " def make_combined_h_fn(h_list):\n", + " def h_fn(x):\n", + " return jnp.concatenate([h(x) for h in h_list], axis=0)\n", + " return h_fn\n", + "\n", + " h_fn_combined = make_combined_h_fn(h_cams)\n", + "\n", + " # === Initialize Kalman filter ===\n", + "\n", + " m0s, S0s, As, cov_mats, Cs = initialize_kalman_filter_geometric(ys_3d)\n", + " m0s = np.array([ys_3d[k, :10].mean(axis=0) for k in 
range(n_keypoints)])\n", + " s_finals = np.full(len(keypoint_names), smooth_param) if np.isscalar(smooth_param) else np.asarray(smooth_param)\n", + "\n", + " # === Apply EKF in latent 3D space using projected 2D observations ===\n", + " ms_all, Vs_all = [], []\n", + " for k in range(n_keypoints):\n", + " y_proj = np.concatenate([vmap(h)(ys_3d[k]) for h in h_cams], axis=1) # (T, 2V)\n", + " r_proj = np.concatenate([ensemble_vars_3d[k][:, :2] for _ in range(n_cameras)], axis=1) # (T, 2V)\n", + " \n", + " ms, Vs = dynamax_ekf_smooth_routine(\n", + " y=ys_3d[k],\n", + " m0=m0s[k],\n", + " S0=S0s[k],\n", + " A=As[k],\n", + " Q=s_finals[k] * cov_mats[k],\n", + " C=np.eye(3),\n", + " ensemble_vars=ensemble_vars_3d[k],\n", + " f_fn=None,\n", + " h_fn=None, \n", + " )\n", + "\n", + " # ms, Vs = dynamax_ekf_smooth_routine(\n", + " # y=y_proj,\n", + " # m0=m0s[k],\n", + " # S0=S0s[k],\n", + " # A=As[k],\n", + " # Q=s_finals[k] * cov_mats[k],\n", + " # C=None,\n", + " # ensemble_vars=r_proj,\n", + " # f_fn=None,\n", + " # h_fn=h_fn_combined, \n", + " # )\n", + "\n", + "\n", + " ms_all.append(np.array(ms))\n", + " Vs_all.append(np.array(Vs))\n", + "\n", + " ms_all = np.stack(ms_all, axis=0) # (K, T, 3)\n", + " Vs_all = np.stack(Vs_all, axis=0) # (K, T, 3, 3)\n", + "\n", + "\n", + " # === Reproject smoothed 3D estimates back to each camera ===\n", + " camera_arrs = [[] for _ in camera_names]\n", + " for k, keypoint in enumerate(keypoint_names):\n", + " ms_k = ms_all[k]\n", + " Vs_k = Vs_all[k]\n", + " inflated_vars_k = ensemble_vars_3d[k]\n", + " \n", + " # rebuild a no-distortion projector per cam using the same rvec,tvec,K\n", + " print(\"camgroup order:\", [getattr(cam, \"name\", f\"cam{i}\") for i,cam in enumerate(camgroup.cameras)])\n", + " print(\"marker_array order:\", marker_array.get_camera_names() if hasattr(marker_array, \"get_camera_names\") else \"unknown\")\n", + "\n", + " # Compare one frame k=0,t=0\n", + " k=0; t=0\n", + " for c in 
range(len(camgroup.cameras)):\n", + " obs = emA_unsmoothed_preds.slice(\"keypoints\", k).slice(\"cameras\", c).get_array(squeeze=True)[t]\n", + " prj = np.array(h_cams[c](ms_all[k][t]))\n", + " print(f\"c{c}: obs={obs}, proj={prj}, diff={obs-prj}\")\n", + " \n", + " for c, camera in enumerate(camgroup.cameras):\n", + " #xy_proj = camera.project(ms_k).reshape(-1, 2)\n", + " xy_proj = np.array(vmap(h_cams[c])(ms_k)) # (T, 2)\n", + " xy_obs = emA_unsmoothed_preds.slice(\"keypoints\", k).slice(\"cameras\", c).get_array(squeeze=True) # (T,2)\n", + " resid = xy_obs - xy_proj # (T,2)\n", + " print(f\"cam {c} mean residual (px):\", resid.mean(axis=0), \" std:\", resid.std(axis=0))\n", + " try:\n", + " cov2d_proj = camera.project_covariance(ms_k, Vs_k)\n", + " var_x = cov2d_proj[:, 0, 0] + inflated_vars_k[:, 0]\n", + " var_y = cov2d_proj[:, 1, 1] + inflated_vars_k[:, 1]\n", + " except AttributeError:\n", + " var_x = np.full(ms_k.shape[0], np.nan)\n", + " var_y = np.full(ms_k.shape[0], np.nan)\n", + "\n", + " data_arr = camera_arrs[c]\n", + " data_arr.extend([\n", + " xy_proj[:, 0],\n", + " xy_proj[:, 1],\n", + " emA_likes.slice(\"keypoints\", k).slice(\"cameras\", c).get_array(squeeze=True),\n", + " emA_unsmoothed_preds.slice(\"keypoints\", k).slice(\"cameras\", c).slice_fields(\"x\").get_array(squeeze=True),\n", + " emA_unsmoothed_preds.slice(\"keypoints\", k).slice(\"cameras\", c).slice_fields(\"y\").get_array(squeeze=True),\n", + " emA_vars.slice(\"keypoints\", k).slice(\"cameras\", c).slice_fields(\"var_x\").get_array(squeeze=True),\n", + " emA_vars.slice(\"keypoints\", k).slice(\"cameras\", c).slice_fields(\"var_y\").get_array(squeeze=True),\n", + " var_x,\n", + " var_y,\n", + " ])\n", + "\n", + " # === Format output ===\n", + " labels = ['x', 'y', 'likelihood', 'x_ens_median', 'y_ens_median',\n", + " 'x_ens_var', 'y_ens_var', 'x_posterior_var', 'y_posterior_var']\n", + " pdindex = make_dlc_pandas_index(keypoint_names, labels=labels)\n", + " camera_dfs = 
def dynamax_ekf_smooth_routine(
    y: ArrayLike,
    m0: ArrayLike,
    S0: ArrayLike,
    A: ArrayLike,
    Q: ArrayLike,
    C: ArrayLike | None,
    ensemble_vars: ArrayLike,  # shape (T, obs_dim)
    f_fn: Callable[[jnp.ndarray], jnp.ndarray] | None = None,
    h_fn: Callable[[jnp.ndarray], jnp.ndarray] | None = None
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Extended Kalman smoother using the Dynamax nonlinear interface,
    allowing for time-varying observation noise.

    By default, uses linear dynamics and emissions: f(x) = Ax, h(x) = Cx.

    Args:
        y: (T, obs_dim) observation sequence.
        m0: (state_dim,) initial mean.
        S0: (state_dim, state_dim) initial covariance.
        A: (state_dim, state_dim) dynamics matrix.
        Q: (state_dim, state_dim) process noise covariance.
        C: (obs_dim, state_dim) emission matrix (optional if h_fn given).
        ensemble_vars: (T, obs_dim) per-timestep observation noise variance.
        f_fn: optional dynamics function f(x).
        h_fn: optional emission function h(x).

    Returns:
        smoothed_means: (T, state_dim)
        smoothed_covariances: (T, state_dim, state_dim)

    Raises:
        ValueError: if neither C nor h_fn is provided.
    """
    y, m0, S0, A, Q, ensemble_vars = map(jnp.asarray, (y, m0, S0, A, Q, ensemble_vars))
    C = jnp.asarray(C) if C is not None else None

    if f_fn is None:
        f_fn = lambda x: A @ x
    if h_fn is None:
        if C is None:
            raise ValueError("Must provide either emission matrix C or a nonlinear emission function h_fn.")
        h_fn = lambda x: C @ x

    # Time-varying diagonal observation noise R_t, built in one vectorized
    # pass instead of a Python loop over timesteps.
    obs_dim = y.shape[1]
    R_t = jax.vmap(jnp.diag)(ensemble_vars[:, :obs_dim])  # (T, obs_dim, obs_dim)

    params = ParamsNLGSSM(
        initial_mean=m0,
        initial_covariance=S0,
        dynamics_function=f_fn,
        dynamics_covariance=Q,
        emission_function=h_fn,
        emission_covariance=R_t,
    )
    posterior = extended_kalman_smoother(params, y)
    # BUG FIX: this routine previously returned posterior.filtered_means /
    # filtered_covariances, i.e. the forward-pass (filtered) estimates —
    # contradicting the function name, docstring, and the smoother call.
    # Return the smoothed posterior as documented.
    return posterior.smoothed_means, posterior.smoothed_covariances
544.48723124], diff=[-8.61965526 12.3116419 ]\n", + "cam 0 mean residual (px): [-2.00586461 -0.04818341] std: [4.65336109 3.65989187]\n", + "cam 1 mean residual (px): [-0.400378 -0.55426035] std: [4.47750558 4.57930606]\n", + "cam 2 mean residual (px): [-4.25086082 4.0237834 ] std: [8.05280531 7.82694513]\n", + "cam 3 mean residual (px): [-2.11519798 -1.08246869] std: [3.01945266 2.22987753]\n", + "cam 4 mean residual (px): [-2.22338437 -1.25190865] std: [4.80872571 3.11726965]\n", + "cam 5 mean residual (px): [-2.02939332 0.71905944] std: [3.34086571 2.34249222]\n", + "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", + "marker_array order: unknown\n", + "c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", + "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", + "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", + "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", + "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", + "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", + "cam 0 mean residual (px): [-16.86640789 27.81901492] std: [21.06616917 10.16015718]\n", + "cam 1 mean residual (px): [ 1.27670095 31.39372914] std: [34.65760412 12.1620663 ]\n", + "cam 2 mean residual (px): [-16.98771473 14.33849501] std: [23.61536901 17.70184138]\n", + "cam 3 mean residual (px): [-4.87177248 25.47695466] std: [24.57880916 7.24002574]\n", + "cam 4 mean residual (px): [13.91017088 28.7900739 ] std: [27.64924222 9.96280995]\n", + "cam 5 mean residual (px): [ 6.45919793 28.26663876] std: [26.92549736 14.30860297]\n", + "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", + "marker_array order: 
unknown\n", + "c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", + "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", + "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", + "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", + "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", + "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", + "cam 0 mean residual (px): [-43.94862492 38.07620084] std: [39.50286349 37.61951063]\n", + "cam 1 mean residual (px): [-22.92389838 34.12073141] std: [60.91164161 53.11203159]\n", + "cam 2 mean residual (px): [-52.44768635 10.2830962 ] std: [37.41010619 48.59247171]\n", + "cam 3 mean residual (px): [ 6.26194338 35.10307059] std: [43.47138297 32.47832518]\n", + "cam 4 mean residual (px): [50.13023656 32.6284639 ] std: [40.05778705 44.31061814]\n", + "cam 5 mean residual (px): [35.77109648 40.20139873] std: [39.75000321 38.20310405]\n", + "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", + "marker_array order: unknown\n", + "c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", + "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", + "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", + "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", + "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", + "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", + 
"cam 0 mean residual (px): [-30.25113722 -40.99577707] std: [28.65042933 44.07906918]\n", + "cam 1 mean residual (px): [-34.98391788 -64.6586051 ] std: [41.16062318 54.55639073]\n", + "cam 2 mean residual (px): [-40.71146852 -33.90442907] std: [28.85536403 48.46962482]\n", + "cam 3 mean residual (px): [ 12.82158102 -37.16530327] std: [31.12152066 36.56094066]\n", + "cam 4 mean residual (px): [ 35.36500756 -54.25401178] std: [24.95655536 42.88833274]\n", + "cam 5 mean residual (px): [ 30.04054017 -24.99231918] std: [25.92538582 32.81484657]\n", + "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", + "marker_array order: unknown\n", + "c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", + "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", + "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", + "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", + "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", + "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", + "cam 0 mean residual (px): [-77.03501222 -14.06791509] std: [49.55689821 54.99093352]\n", + "cam 1 mean residual (px): [-121.43571681 -59.02314783] std: [72.04167148 74.07204966]\n", + "cam 2 mean residual (px): [-127.85352757 -25.78346588] std: [55.91121001 66.00386657]\n", + "cam 3 mean residual (px): [ 56.73325089 -15.44272932] std: [49.70291459 45.57550667]\n", + "cam 4 mean residual (px): [118.81426752 -50.95219002] std: [53.48983305 57.99132334]\n", + "cam 5 mean residual (px): [107.70158039 -10.72051293] std: [51.63925117 43.73291347]\n", + "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", + "marker_array order: unknown\n", + 
"c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", + "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", + "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", + "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", + "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", + "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", + "cam 0 mean residual (px): [-107.73449426 -25.13505912] std: [75.17239777 57.95634324]\n", + "cam 1 mean residual (px): [-265.44229017 -111.57601448] std: [122.9936259 83.63128374]\n", + "cam 2 mean residual (px): [-226.29392882 -25.78944817] std: [96.08242022 79.07648141]\n", + "cam 3 mean residual (px): [127.76895727 -33.62190642] std: [79.13346773 49.20533222]\n", + "cam 4 mean residual (px): [ 214.87055494 -109.10578232] std: [89.70933629 69.53311219]\n", + "cam 5 mean residual (px): [206.38994512 -45.90479289] std: [90.30455758 59.17178075]\n", + "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", + "marker_array order: unknown\n", + "c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", + "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", + "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", + "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", + "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", + "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", + 
"cam 0 mean residual (px): [-20.05716054 15.86566925] std: [15.31337375 8.3487262 ]\n", + "cam 1 mean residual (px): [-1.81428365 15.14691077] std: [27.12656552 11.97268196]\n", + "cam 2 mean residual (px): [-19.78322016 4.58641567] std: [18.58410842 14.75168836]\n", + "cam 3 mean residual (px): [-3.86281001 15.00418717] std: [19.08585247 6.15065234]\n", + "cam 4 mean residual (px): [17.14123839 15.34504954] std: [22.33711894 10.23454471]\n", + "cam 5 mean residual (px): [ 8.98992669 20.32010298] std: [21.78587376 11.17949793]\n", + "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", + "marker_array order: unknown\n", + "c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", + "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", + "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", + "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", + "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", + "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", + "cam 0 mean residual (px): [-31.64188096 3.83078018] std: [15.8083754 29.82208695]\n", + "cam 1 mean residual (px): [-10.15232963 -4.38954052] std: [28.9432177 39.41426118]\n", + "cam 2 mean residual (px): [-31.68059102 -9.37822028] std: [18.92196291 30.33923798]\n", + "cam 3 mean residual (px): [-1.09225093 4.94040225] std: [20.64178028 25.81219085]\n", + "cam 4 mean residual (px): [30.50141194 0.16262853] std: [18.36859319 33.19362351]\n", + "cam 5 mean residual (px): [18.75977526 14.56760995] std: [20.02123306 24.32289152]\n", + "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", + "marker_array order: unknown\n", + "c0: obs=[1750.93206108 
391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", + "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", + "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", + "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", + "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", + "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", + "cam 0 mean residual (px): [-45.63755893 7.96823886] std: [23.85887077 47.18275982]\n", + "cam 1 mean residual (px): [-16.02835436 -3.75159711] std: [39.33860605 61.70858653]\n", + "cam 2 mean residual (px): [-45.67716142 -13.88146324] std: [25.67134238 44.52143834]\n", + "cam 3 mean residual (px): [0.24121305 9.63683264] std: [28.12789429 41.08081596]\n", + "cam 4 mean residual (px): [46.60601802 2.736363 ] std: [24.2237432 51.62540431]\n", + "cam 5 mean residual (px): [29.43244706 22.25706045] std: [25.84400445 37.5825051 ]\n", + "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", + "marker_array order: unknown\n", + "c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", + "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", + "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", + "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", + "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", + "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", + "cam 0 mean residual (px): [-53.01464164 
-54.71981158] std: [35.82581259 50.09749436]\n", + "cam 1 mean residual (px): [-66.18716647 -95.0638527 ] std: [51.59764782 65.77213503]\n", + "cam 2 mean residual (px): [-73.51290185 -50.69122393] std: [39.62568939 55.19435951]\n", + "cam 3 mean residual (px): [ 27.11963099 -49.21706219] std: [36.57808591 41.87582014]\n", + "cam 4 mean residual (px): [ 68.78763228 -78.53982291] std: [37.86301571 52.00905446]\n", + "cam 5 mean residual (px): [ 58.83761913 -33.31484122] std: [36.69507741 37.25871558]\n", + "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", + "marker_array order: unknown\n", + "c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", + "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", + "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", + "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", + "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", + "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", + "cam 0 mean residual (px): [-27.48803853 -68.9780643 ] std: [28.86503036 51.00612907]\n", + "cam 1 mean residual (px): [-23.11555466 -96.6514439 ] std: [41.02268415 66.51927875]\n", + "cam 2 mean residual (px): [-28.40975361 -54.02259942] std: [27.75373633 53.37846493]\n", + "cam 3 mean residual (px): [ 4.31867099 -60.81992412] std: [30.67331701 42.8880862 ]\n", + "cam 4 mean residual (px): [ 26.01591817 -79.63224592] std: [27.38312343 52.87647067]\n", + "cam 5 mean residual (px): [ 18.82322649 -40.97111631] std: [26.47418793 36.46591139]\n", + "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", + "marker_array order: unknown\n", + "c0: obs=[1750.93206108 391.50040874], 
proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", + "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", + "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", + "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", + "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", + "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", + "cam 0 mean residual (px): [-10.83391517 17.11601076] std: [18.66345732 9.56010962]\n", + "cam 1 mean residual (px): [-5.64380896 18.17178012] std: [25.47195893 12.95737098]\n", + "cam 2 mean residual (px): [-15.81379777 12.05263804] std: [16.14768232 18.77020813]\n", + "cam 3 mean residual (px): [ 0.36446203 14.60267179] std: [19.09077463 7.12737293]\n", + "cam 4 mean residual (px): [ 9.74926083 15.46465836] std: [20.28911007 10.00831862]\n", + "cam 5 mean residual (px): [ 7.08709365 15.35813022] std: [18.31135053 13.90714397]\n", + "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", + "marker_array order: unknown\n", + "c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", + "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", + "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", + "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", + "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", + "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", + "cam 0 mean residual (px): [-14.70415945 5.94012917] std: 
[26.89201462 31.44245168]\n", + "cam 1 mean residual (px): [-18.58007466 1.22947655] std: [33.45409109 40.48061898]\n", + "cam 2 mean residual (px): [-25.29845403 4.64110098] std: [20.5573175 37.54246641]\n", + "cam 3 mean residual (px): [7.57065902 3.82773206] std: [26.09452606 26.51365765]\n", + "cam 4 mean residual (px): [17.51522829 -0.06323189] std: [22.11948331 32.61460809]\n", + "cam 5 mean residual (px): [16.29940983 4.89736708] std: [19.6988043 28.20827179]\n", + "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", + "marker_array order: unknown\n", + "c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", + "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", + "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", + "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", + "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", + "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", + "cam 0 mean residual (px): [-23.41938215 14.44121108] std: [40.35550002 46.97320393]\n", + "cam 1 mean residual (px): [-33.81880448 7.43041355] std: [51.00221251 61.34696419]\n", + "cam 2 mean residual (px): [-41.89819398 8.90692818] std: [29.37183653 55.51384624]\n", + "cam 3 mean residual (px): [16.33399736 10.86549482] std: [39.66516442 39.85803748]\n", + "cam 4 mean residual (px): [32.39365723 4.64546822] std: [32.76761261 49.46399521]\n", + "cam 5 mean residual (px): [31.08093032 9.83292101] std: [29.15675992 41.8012197 ]\n", + "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", + "marker_array order: unknown\n", + "c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 
6.39615584]\n", + "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", + "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", + "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", + "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", + "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", + "cam 0 mean residual (px): [-45.45435744 -54.61607257] std: [39.64273023 48.26931621]\n", + "cam 1 mean residual (px): [-74.85878053 -94.18487352] std: [56.23314573 63.06956186]\n", + "cam 2 mean residual (px): [-73.46970474 -43.86349541] std: [40.39699254 56.30734317]\n", + "cam 3 mean residual (px): [ 34.25225743 -51.24268605] std: [40.82544648 40.30805024]\n", + "cam 4 mean residual (px): [ 64.82452392 -81.53024431] std: [39.03215943 49.94257993]\n", + "cam 5 mean residual (px): [ 60.9854295 -40.16336504] std: [37.80814396 39.1601131 ]\n", + "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", + "marker_array order: unknown\n", + "c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", + "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", + "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", + "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", + "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", + "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", + "cam 0 mean residual (px): [-17.62599391 -66.77606412] std: [34.05057432 49.62368119]\n", + "cam 1 
mean residual (px): [-31.71298444 -92.47139918] std: [45.62965855 63.94357298]\n", + "cam 2 mean residual (px): [-27.17331358 -44.40016944] std: [27.94681585 53.8826323 ]\n", + "cam 3 mean residual (px): [ 12.01686193 -61.44974139] std: [35.06009354 42.20340585]\n", + "cam 4 mean residual (px): [ 19.79745444 -80.43697436] std: [28.57462023 51.59245566]\n", + "cam 5 mean residual (px): [ 19.93228684 -47.6077619 ] std: [26.79455762 39.85515655]\n", + "see example EKS output at ./outputs/multicam_rightFoot.pdf\n" + ] + } + ], + "source": [ + "from eks.utils import plot_results\n", + "\n", + "input_source = \"./data/chickadee_uncropped\"\n", + "camera_names = [\"lBack\", \"lFront\", \"lTop\", \"rBack\", \"rFront\", \"rTop\"]\n", + "keypoints = [\"topBeak\", \"topHead\", \"backHead\", \"centerChes\", \"baseTail\", \"tipTail\", \"leftEye\", \"leftNeck\", \"leftWing\", \"leftAnkle\", \"leftFoot\", \"rightEye\", \"rightNeck\", \"rightWing\", \"rightAnkle\", \"rightFoot\"]\n", + "camgroup = CameraGroup.load(\"./data/chickadee/calibration.toml\")\n", + "# input_source = \"./data/fly\"\n", + "# camera_names = [\"Cam-A\", \"Cam-B\", \"Cam-C\", \"Cam-D\", \"Cam-E\", \"Cam-F\"]\n", + "# keypoints = [\"L1A\", \"L1B\"]\n", + "# camgroup = CameraGroup.load(\"./data/fly/calibration.toml\")\n", + "\n", + "save_dir = \"./outputs/\"\n", + "\n", + "# Load calibration file\n", + "\n", + "\n", + "camera_dfs, s_finals, input_dfs, bodypart_list, marker_array, h_cams, ys_3d = fit_eks_multicam(\n", + " input_source=input_source,\n", + " save_dir=save_dir,\n", + " bodypart_list=keypoints,\n", + " smooth_param=10,\n", + " camera_names=camera_names,\n", + " quantile_keep_pca=95,\n", + " verbose=True,\n", + " inflate_vars=False,\n", + " n_latent=3,\n", + " backend=\"dynamax-ekf\",\n", + " camgroup=camgroup\n", + ")\n", + "\n", + "keypoint_i = -1\n", + "camera_c = -1\n", + "plot_results(\n", + " output_df=camera_dfs[camera_c],\n", + " input_dfs_list=input_dfs[camera_c],\n", + " 
key=f'{bodypart_list[keypoint_i]}',\n", + " idxs=(0, 500),\n", + " s_final=s_finals[keypoint_i],\n", + " nll_values=None,\n", + " save_dir=save_dir,\n", + " smoother_type='multicam',\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "63fd6ec4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Triangulated 3D point: [-0.05837184 -0.28620351 0.42230134]\n", + "Camera 0: reprojection error = 5.255 pixels\n", + "Camera 1: reprojection error = 5.623 pixels\n", + "Camera 2: reprojection error = 2.997 pixels\n", + "Camera 3: reprojection error = 2.839 pixels\n", + "Camera 4: reprojection error = 0.815 pixels\n", + "Camera 5: reprojection error = 6.390 pixels\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import jax.numpy as jnp\n", + "\n", + "# === Settings ===\n", + "frame_idx = 120\n", + "keypoint_idx = 1\n", + "model_idx = 0 # or average across models\n", + "\n", + "# === Step 1: Extract 2D predictions from all cameras ===\n", + "raw_array = marker_array.get_array() # (n_models, n_cameras, n_frames, n_keypoints, 2+)\n", + "xy_views = [raw_array[model_idx, c, frame_idx, keypoint_idx, :2] for c in range(len(camgroup.cameras))]\n", + "xy_views_np = np.stack(xy_views) # (n_cameras, 2)\n", + "\n", + "# === Step 2: Triangulate to get 3D point ===\n", + "x_3d = camgroup.triangulate(xy_views_np) # shape (3,)\n", + "\n", + "print(f\"Triangulated 3D point: {x_3d}\")\n", + "\n", + "# === Step 3: Reproject into each view ===\n", + "projected_views = [h(jnp.array(x_3d)) for h in h_cams]\n", + "projected_views_np = np.stack([np.array(p) for p in projected_views]) # (n_cameras, 2)\n", + "\n", + "# === Step 4: Compute reprojection error per view ===\n", + "for i, (orig, proj) in enumerate(zip(xy_views_np, projected_views_np)):\n", + " err = np.linalg.norm(orig - proj)\n", + " print(f\"Camera {i}: reprojection error = {err:.3f} pixels\")" + ] + }, + { + "cell_type": "code", + "execution_count": 
7, + "id": "6cc22bf1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Triangulated point: [-0.09515544 -0.22536958 0.36407084]\n", + "PCA-reconstructed point:[-0.09511244 -0.22479253 0.36471556]\n", + "\n", + "Camera 0: tri e_custom=4.01px, e_cg=4.01px, projΔ=0.00px | pca e_custom=3.46px, e_cg=3.46px, projΔ=0.00px\n", + "Camera 1: tri e_custom=12.93px, e_cg=12.93px, projΔ=0.00px | pca e_custom=13.93px, e_cg=13.93px, projΔ=0.00px\n", + "Camera 2: tri e_custom=2.23px, e_cg=2.23px, projΔ=0.00px | pca e_custom=1.27px, e_cg=1.27px, projΔ=0.00px\n", + "Camera 3: tri e_custom=2.84px, e_cg=2.84px, projΔ=0.00px | pca e_custom=2.60px, e_cg=2.60px, projΔ=0.00px\n", + "Camera 4: tri e_custom=5.52px, e_cg=5.52px, projΔ=0.00px | pca e_custom=4.95px, e_cg=4.95px, projΔ=0.00px\n", + "Camera 5: tri e_custom=1.84px, e_cg=1.84px, projΔ=0.00px | pca e_custom=2.30px, e_cg=2.30px, projΔ=0.00px\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "from sklearn.decomposition import PCA\n", + "import jax.numpy as jnp\n", + "\n", + "# === Settings ===\n", + "frame_idx = 45\n", + "keypoint_idx = 0\n", + "model_idx = 0 # pick one or average if needed\n", + "\n", + "# --- helper: make camgroup.project output (n_cams, 2) ---\n", + "def cg_project_point(camgroup, x3d):\n", + " \"\"\"Project a single 3D point with aniposelib CameraGroup, return (n_cams, 2).\"\"\"\n", + " x = np.asarray(x3d, dtype=float).reshape(1, 3)\n", + " out = camgroup.project(x) # library-dependent shape\n", + " # Try common shapes: (n_cams, 1, 2), (n_cams, 2), dict of cam->(1,2)\n", + " if isinstance(out, dict):\n", + " proj = np.stack([np.asarray(out[cam.name])[0] for cam in camgroup.cameras], axis=0)\n", + " else:\n", + " arr = np.asarray(out)\n", + " if arr.ndim == 3 and arr.shape[1] == 1 and arr.shape[2] == 2:\n", + " proj = arr[:, 0, :]\n", + " elif arr.ndim == 2 and arr.shape == (len(camgroup.cameras), 2):\n", + " proj = arr\n", + " elif arr.ndim == 2 and 
arr.shape == (2, len(camgroup.cameras)):\n", + " proj = arr.T\n", + " else:\n", + " raise ValueError(f\"Unexpected camgroup.project shape: {arr.shape}\")\n", + " return proj # (n_cams, 2)\n", + "\n", + "# === 1) 2D observations for this keypoint+frame from all cameras ===\n", + "raw_array = marker_array.get_array() # (n_models, n_cams, n_frames, n_keypoints, 2+)\n", + "n_cams = len(camgroup.cameras)\n", + "xy_views_np = np.stack(\n", + " [raw_array[model_idx, c, frame_idx, keypoint_idx, :2] for c in range(n_cams)],\n", + " axis=0\n", + ") # (n_cams, 2)\n", + "\n", + "# === 2) Triangulated 3D point ===\n", + "x_triang = camgroup.triangulate(xy_views_np) # (3,)\n", + "\n", + "# === 3) PCA-reconstructed 3D point ===\n", + "# Assume ys_3d: (K, T, 3)\n", + "ys_3d_reshaped = ys_3d.reshape(-1, 3)\n", + "pca = PCA(n_components=3)\n", + "Z = pca.fit_transform(ys_3d_reshaped)\n", + "ys_3d_pca = pca.inverse_transform(Z).reshape(ys_3d.shape)\n", + "x_pca = ys_3d_pca[keypoint_idx, frame_idx] # (3,)\n", + "\n", + "# === 4) Project both 3D points with:\n", + "# (a) your custom JAX projectors h_cams\n", + "# (b) camgroup.project (OpenCV-based)\n", + "reproj_triang_custom = np.stack([np.array(h(jnp.array(x_triang))) for h in h_cams], axis=0) # (n_cams, 2)\n", + "reproj_pca_custom = np.stack([np.array(h(jnp.array(x_pca))) for h in h_cams], axis=0)\n", + "\n", + "reproj_triang_cg = cg_project_point(camgroup, x_triang) # (n_cams, 2)\n", + "reproj_pca_cg = cg_project_point(camgroup, x_pca)\n", + "\n", + "# === 5) Print comparison ===\n", + "print(f\"Triangulated point: {x_triang}\")\n", + "print(f\"PCA-reconstructed point:{x_pca}\\n\")\n", + "\n", + "for i in range(n_cams):\n", + " obs = xy_views_np[i]\n", + "\n", + " # errors vs observations\n", + " err_tri_custom = np.linalg.norm(obs - reproj_triang_custom[i])\n", + " err_tri_cg = np.linalg.norm(obs - reproj_triang_cg[i])\n", + "\n", + " err_pca_custom = np.linalg.norm(obs - reproj_pca_custom[i])\n", + " err_pca_cg = 
np.linalg.norm(obs - reproj_pca_cg[i])\n", + "\n", + " # difference between projectors (should be ~0 if both are consistent)\n", + " diff_tri = np.linalg.norm(reproj_triang_custom[i] - reproj_triang_cg[i])\n", + " diff_pca = np.linalg.norm(reproj_pca_custom[i] - reproj_pca_cg[i])\n", + "\n", + " print(\n", + " f\"Camera {i}: \"\n", + " f\"tri e_custom={err_tri_custom:.2f}px, e_cg={err_tri_cg:.2f}px, projΔ={diff_tri:.2f}px | \"\n", + " f\"pca e_custom={err_pca_custom:.2f}px, e_cg={err_pca_cg:.2f}px, projΔ={diff_pca:.2f}px\"\n", + " )\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/eks/core.py b/eks/core.py index eb90df6..da8f69f 100644 --- a/eks/core.py +++ b/eks/core.py @@ -12,7 +12,7 @@ from eks.marker_array import MarkerArray from eks.utils import crop_frames -from eks.kalman_backends import dynamax_linear_smooth_routine +from eks.kalman_backends import dynamax_ekf_smooth_routine # ------------------------------------------------------------------------------------- # Kalman Functions: Functions related to performing filtering and smoothing @@ -424,14 +424,13 @@ def final_forwards_backwards_pass( A = As[k] # (D, D) Q = Qs[k] # (D, D) C = Cs[k] # (obs_dim, D) - per_timestep_vars = ensemble_vars[:, k, :] # (T, obs_dim) + keypoint_vars = ensemble_vars[:, k, :] # (T, obs_dim) if backend == 'jax': - mf, Vf, _ = forward_pass(y, m0, S0, A, Q, C, ensemble_vars[:, k, :]) + mf, Vf, _ = forward_pass(y, m0, S0, A, Q, C, keypoint_vars) ms, Vs = backward_pass(mf, Vf, A, Q) - - elif backend == 'dynamax-linear': - ms, Vs = dynamax_linear_smooth_routine(y, m0, S0, A, Q, C, 
per_timestep_vars) + elif backend == 'dynamax-ekf': + ms, Vs = dynamax_ekf_smooth_routine(y, m0, S0, A, Q, C, keypoint_vars) else: raise ValueError(f"Unsupported backend: {backend}") diff --git a/eks/kalman_backends.py b/eks/kalman_backends.py index 76dcbaa..ccd61b4 100644 --- a/eks/kalman_backends.py +++ b/eks/kalman_backends.py @@ -1,68 +1,70 @@ -from dynamax.linear_gaussian_ssm.models import ( - LinearGaussianSSM, - ParamsLGSSM, - ParamsLGSSMInitial, - ParamsLGSSMDynamics, - ParamsLGSSMEmissions +from dynamax.nonlinear_gaussian_ssm.inference_ekf import extended_kalman_smoother +from dynamax.nonlinear_gaussian_ssm.models import ( + ParamsNLGSSM, ) + import jax import jax.numpy as jnp import numpy as np -from typing import Union, Tuple +from typing import Union, Tuple, Callable from typeguard import typechecked ArrayLike = Union[np.ndarray, jax.Array] @typechecked -def dynamax_linear_smooth_routine( +def dynamax_ekf_smooth_routine( y: ArrayLike, m0: ArrayLike, S0: ArrayLike, A: ArrayLike, Q: ArrayLike, - C: ArrayLike, - ensemble_vars: ArrayLike # shape (T, obs_dim) + C: ArrayLike | None, + ensemble_vars: ArrayLike, # shape (T, obs_dim) + f_fn: Callable[[jnp.ndarray], jnp.ndarray] | None = None, + h_fn: Callable[[jnp.ndarray], jnp.ndarray] | None = None ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ - Run Dynamax smoother with time-varying diagonal observation noise from ensemble variances. + Extended Kalman smoother using the Dynamax nonlinear interface, + allowing for time-varying observation noise. + + By default, uses linear dynamics and emissions: f(x) = Ax, h(x) = Cx. Args: - y: (T, obs_dim) observations - m0: (state_dim,) initial mean - S0: (state_dim, state_dim) initial covariance - A: (state_dim, state_dim) transition matrix - Q: (state_dim, state_dim) process noise - C: (obs_dim, state_dim) emission matrix - ensemble_vars: (T, obs_dim) per-timestep observation noise variance + y: (T, obs_dim) observation sequence. + m0: (state_dim,) initial mean. 
+ S0: (state_dim, state_dim) initial covariance. + A: (state_dim, state_dim) dynamics matrix. + Q: (state_dim, state_dim) process noise covariance. + C: (obs_dim, state_dim) emission matrix (optional). + ensemble_vars: (T, obs_dim) per-timestep observation noise variance. + f_fn: optional dynamics function f(x). + h_fn: optional emission function h(x). Returns: smoothed_means: (T, state_dim) - smoothed_covs: (T, state_dim, state_dim) + smoothed_covariances: (T, state_dim, state_dim) """ + y, m0, S0, A, Q, ensemble_vars = map(jnp.asarray, (y, m0, S0, A, Q, ensemble_vars)) + C = jnp.asarray(C) if C is not None else None - # Convert everything to JAX arrays - y, m0, S0, A, Q, C, ensemble_vars = map(jnp.asarray, (y, m0, S0, A, Q, C, ensemble_vars)) - T, obs_dim = y.shape - state_dim = A.shape[0] - - # Dynamax accepts time-varying diagonal R_t as (T, obs_dim) - model = LinearGaussianSSM(state_dim, obs_dim) + if f_fn is None: + f_fn = lambda x: A @ x + if h_fn is None: + if C is None: + raise ValueError("Must provide either emission matrix C or a nonlinear emission function h_fn.") + h_fn = lambda x: C @ x - params = ParamsLGSSM( - initial=ParamsLGSSMInitial(mean=m0, cov=S0), - dynamics=ParamsLGSSMDynamics( - weights=A, - cov=Q, - bias=jnp.zeros(state_dim), - input_weights=jnp.zeros((state_dim, 0)) - ), - emissions=ParamsLGSSMEmissions( - weights=C, - cov=ensemble_vars, # <=== time-varying diagonal noise - bias=jnp.zeros(obs_dim), - input_weights=jnp.zeros((obs_dim, 0)) - ) + # Dynamically determine obs_dim from h_fn output + obs_dim = y.shape[1] + R_t = jnp.stack([jnp.diag(var_t[:obs_dim]) for var_t in ensemble_vars], axis=0) # shape (T, obs_dim, obs_dim) + params = ParamsNLGSSM( + initial_mean=m0, + initial_covariance=S0, + dynamics_function=f_fn, + dynamics_covariance=Q, + emission_function=h_fn, + emission_covariance=R_t, ) - posterior = model.smoother(params, y) - return posterior.smoothed_means, posterior.smoothed_covariances + posterior = 
extended_kalman_smoother(params, y) + return posterior.smoothed_means, posterior.smoothed_covariances \ No newline at end of file From 3a57255e4ba32982caa0f719987426dee9b383fb Mon Sep 17 00:00:00 2001 From: Keemin Lee Date: Fri, 3 Oct 2025 13:33:13 -0400 Subject: [PATCH 05/11] dynamax backend refactor! --- eks/command_line_args.py | 6 - eks/core.py | 750 +++++++++---------------------- eks/ibl_paw_multicam_smoother.py | 10 +- eks/ibl_pupil_smoother.py | 287 +++++++----- eks/kalman_backends.py | 70 --- eks/multicam_smoother.py | 13 +- eks/singlecam_smoother.py | 5 +- eks/utils.py | 36 ++ scripts/multicam_example.py | 2 - scripts/singlecam_example.py | 2 - tests/test_multicam_smoother.py | 21 +- tests/test_singlecam_smoother.py | 12 +- 12 files changed, 440 insertions(+), 774 deletions(-) delete mode 100644 eks/kalman_backends.py diff --git a/eks/command_line_args.py b/eks/command_line_args.py index 99a8be0..86bb059 100644 --- a/eks/command_line_args.py +++ b/eks/command_line_args.py @@ -71,12 +71,6 @@ def handle_parse_args(script_type): default='', type=str, ) - parser.add_argument( - '--backend', - help='Options: jax, dynamax-linear. 
Determines the backend to be used for smoothing.', - default='jax', - type=str - ) if script_type == 'singlecam': add_bodyparts(parser) add_s(parser) diff --git a/eks/core.py b/eks/core.py index da8f69f..c99be07 100644 --- a/eks/core.py +++ b/eks/core.py @@ -1,22 +1,14 @@ -from functools import partial - import jax -import jax.scipy as jsc import numpy as np import optax -from jax import jit -from jax import numpy as jnp -from jax import vmap +from dynamax.nonlinear_gaussian_ssm import ParamsNLGSSM, extended_kalman_filter, \ + extended_kalman_smoother +from jax import numpy as jnp, vmap, jit, value_and_grad, lax from typeguard import typechecked -from typing import List, Literal, Optional, Tuple, Union +from typing import Literal, Union, Optional, List, Tuple from eks.marker_array import MarkerArray -from eks.utils import crop_frames -from eks.kalman_backends import dynamax_ekf_smooth_routine - -# ------------------------------------------------------------------------------------- -# Kalman Functions: Functions related to performing filtering and smoothing -# ------------------------------------------------------------------------------------- +from eks.utils import build_R_from_vars, crop_frames, crop_R @typechecked @@ -87,366 +79,6 @@ def compute_stats(data_x, data_y, data_lh): return ensemble_marker_array -@typechecked -def kalman_filter_step( - carry, - inputs -) -> Tuple[tuple, Tuple[jnp.ndarray, jnp.ndarray, jax.Array]]: - """ - Performs a single Kalman filter update step using time-varying observation noise - from ensemble variance. - - Used in a scan loop, updating the state mean and covariance - based on the current observation and its associated ensemble variance. - - Args: - carry: Tuple containing the previous state and model parameters: - - m_prev (jnp.ndarray): Previous state estimate (mean vector). - - V_prev (jnp.ndarray): Previous state covariance matrix. - - A (jnp.ndarray): State transition matrix. 
- - Q (jnp.ndarray): Process noise covariance matrix. - - C (jnp.ndarray): Observation matrix. - - nll_net (float): Accumulated negative log-likelihood. - inputs: Tuple containing the current observation and its estimated ensemble variance: - - curr_y (jnp.ndarray): Current observation vector. - - curr_ensemble_var (jnp.ndarray): Estimated observation noise variance - (used to build time-varying R matrix). - - Returns: - A tuple of two elements: - - carry (tuple): Updated (m_t, V_t, A, Q, C, nll_net) to pass to the next step. - - output (tuple): Tuple of: - - m_t (jnp.ndarray): Updated state mean. - - V_t (jnp.ndarray): Updated state covariance. - - nll_current (float, stored as jax.Array): NLL of the current observation. - """ - m_prev, V_prev, A, Q, C, nll_net = carry - curr_y, curr_ensemble_var = inputs - - # Update R with time-varying ensemble variance - R = jnp.diag(curr_ensemble_var) - - # Predict - m_pred = jnp.dot(A, m_prev) - V_pred = jnp.dot(A, jnp.dot(V_prev, A.T)) + Q - - # Update - innovation = curr_y - jnp.dot(C, m_pred) - innovation_cov = jnp.dot(C, jnp.dot(V_pred, C.T)) + R - K = jnp.dot(V_pred, jnp.dot(C.T, jnp.linalg.inv(innovation_cov))) - m_t = m_pred + jnp.dot(K, innovation) - V_t = jnp.dot((jnp.eye(V_pred.shape[0]) - jnp.dot(K, C)), V_pred) - - nll_current = single_timestep_nll(innovation, innovation_cov) - nll_net = nll_net + nll_current - - return (m_t, V_t, A, Q, C, nll_net), (m_t, V_t, nll_current) - - -@typechecked -def kalman_filter_step_nlls( - carry: tuple, - inputs: tuple -) -> Tuple[tuple, Tuple[jnp.ndarray, jnp.ndarray, float]]: - """ - Performs a single Kalman filter update step and records per-timestep negative - log-likelihoods (NLLs) into a preallocated array. - - Used inside a `lax.scan` loop. In addition to updating the state estimate and total NLL, - it writes the NLL of each timestep into a persistent array for later analysis/plotting. 
- - Args: - carry: Tuple containing: - - m_prev (jnp.ndarray): Previous state estimate (mean vector). - - V_prev (jnp.ndarray): Previous state covariance matrix. - - A (jnp.ndarray): State transition matrix. - - Q (jnp.ndarray): Process noise covariance matrix. - - C (jnp.ndarray): Observation matrix. - - nll_net (float): Cumulative negative log-likelihood. - - nll_array (jnp.ndarray): Preallocated array for per-step NLL values. - - t (int): Current timestep index into the NLL array. - - inputs: Tuple containing: - - curr_y (jnp.ndarray): Current observation vector. - - curr_ensemble_var (jnp.ndarray): Estimated observation noise variance, - used to construct the time-varying R matrix. - - Returns: - A tuple of: - - carry (tuple): Updated state and NLL tracking info for the next timestep. - - output (tuple): - - m_t (jnp.ndarray): Updated state mean. - - V_t (jnp.ndarray): Updated state covariance. - - nll_current (float): Negative log-likelihood of the current timestep. - """ - # Unpack carry and inputs - m_prev, V_prev, A, Q, C, nll_net, nll_array, t = carry - curr_y, curr_ensemble_var = inputs - - # Update R with the current ensemble variance - R = jnp.diag(curr_ensemble_var) - - # Predict - m_pred = jnp.dot(A, m_prev) - V_pred = jnp.dot(A, jnp.dot(V_prev, A.T)) + Q - - # Update - innovation = curr_y - jnp.dot(C, m_pred) - innovation_cov = jnp.dot(C, jnp.dot(V_pred, C.T)) + R - K = jnp.dot(V_pred, jnp.dot(C.T, jnp.linalg.inv(innovation_cov))) - m_t = m_pred + jnp.dot(K, innovation) - V_t = V_pred - jnp.dot(K, jnp.dot(C, V_pred)) - - # Compute the negative log-likelihood for the current time step - nll_current = single_timestep_nll(innovation, innovation_cov) - - # Accumulate the negative log-likelihood - nll_net = nll_net + nll_current - - # Save the current NLL to the preallocated array - nll_array = nll_array.at[t].set(nll_current) - - # Increment the time step - t = t + 1 - - # Return the updated state and outputs - return (m_t, V_t, A, Q, C, nll_net, 
nll_array, t), (m_t, V_t, nll_current) - - -@partial(jit, backend='cpu') -def forward_pass( - y: jnp.ndarray, - m0: jnp.ndarray, - cov0: jnp.ndarray, - A: jnp.ndarray, - Q: jnp.ndarray, - C: jnp.ndarray, - ensemble_vars: jnp.ndarray -) -> Tuple[jnp.ndarray, jnp.ndarray, float]: - """ - Executes the Kalman filter forward pass for a single keypoint over time, - incorporating time-varying observation noise variances. - - Computes filtered state means, covariances, and the cumulative - negative log-likelihood across all timesteps. Used within `vmap` to - handle multiple keypoints in parallel. - - Args: - y: Array of shape (T, obs_dim). Sequence of observations over time. - m0: Array of shape (state_dim,). Initial state estimate. - cov0: Array of shape (state_dim, state_dim). Initial state covariance. - A: Array of shape (state_dim, state_dim). State transition matrix. - Q: Array of shape (state_dim, state_dim). Process noise covariance matrix. - C: Array of shape (obs_dim, state_dim). Observation matrix. - ensemble_vars: Array of shape (T, obs_dim). Per-frame observation noise variances. - - Returns: - mfs: Array of shape (T, state_dim). Filtered mean estimates at each timestep. - Vfs: Array of shape (T, state_dim, state_dim). Filtered covariance estimates at each timestep. - nll_net: Scalar float. Total negative log-likelihood across all timesteps. - """ - # Initialize carry - carry = (m0, cov0, A, Q, C, 0) - # Run the scan, passing y and ensemble_vars as inputs to kalman_filter_step - carry, outputs = jax.lax.scan(kalman_filter_step, carry, (y, ensemble_vars)) - mfs, Vfs, _ = outputs - nll_net = carry[-1] - return mfs, Vfs, nll_net - - -@typechecked -def kalman_smoother_step( - carry: tuple, - X: list, -) -> Tuple[tuple, Tuple[jnp.ndarray, jnp.ndarray]]: - """ - Performs a single backward pass of the Kalman smoother. - - Updates the smoothed state estimate and covariance based on the - current filtered estimate and the next time step's smoothed estimate. 
Used - within a `jax.lax.scan` in reverse over the time axis. - - Args: - carry: Tuple containing: - - m_ahead_smooth (jnp.ndarray): Smoothed state mean at the next timestep. - - v_ahead_smooth (jnp.ndarray): Smoothed state covariance at the next timestep. - - A (jnp.ndarray): State transition matrix. - - Q (jnp.ndarray): Process noise covariance matrix. - - X: Tuple containing: - - m_curr_filter (jnp.ndarray): Filtered mean estimate at the current timestep. - - v_curr_filter (jnp.ndarray): Filtered covariance at the current timestep. - - Returns: - A tuple of: - - carry (tuple): Updated smoothed state (mean, cov) and model params for the next step. - - output (tuple): - - smoothed_state (jnp.ndarray): Smoothed mean estimate at the current timestep. - - smoothed_cov (jnp.ndarray): Smoothed covariance at the current timestep. - """ - m_ahead_smooth, v_ahead_smooth, A, Q = carry - m_curr_filter, v_curr_filter = X[0], X[1] - - # Compute the smoother gain - ahead_cov = jnp.dot(A, jnp.dot(v_curr_filter, A.T)) + Q - - smoothing_gain = jsc.linalg.solve(ahead_cov, jnp.dot(A, v_curr_filter.T)).T - smoothed_state = m_curr_filter + jnp.dot(smoothing_gain, m_ahead_smooth - m_curr_filter) - smoothed_cov = v_curr_filter + jnp.dot(jnp.dot(smoothing_gain, v_ahead_smooth - ahead_cov), - smoothing_gain.T) - - return (smoothed_state, smoothed_cov, A, Q), (smoothed_state, smoothed_cov) - - -# @typechecked -- raises InstrumentationWarning as @jit rewrites into compiled form (JAX XLA) -@partial(jit, backend='cpu') -def backward_pass( - mfs: jnp.ndarray, - Vfs: jnp.ndarray, - A: jnp.ndarray, - Q: jnp.ndarray -) -> Tuple[jnp.ndarray, jnp.ndarray]: - """ - Executes the Kalman smoother backward pass using filtered means and covariances. - - Refines forward-filtered estimates by incorporating future observations. - Used after a Kalman filter forward pass to recover more accurate state estimates. - - Args: - mfs: Array of shape (T, state_dim). Filtered state means from the forward pass. 
- Vfs: Array of shape (T, state_dim, state_dim). Filtered covariances from the forward pass. - A: Array of shape (state_dim, state_dim). State transition matrix. - Q: Array of shape (state_dim, state_dim). Process noise covariance matrix. - - Returns: - smoothed_states: Array of shape (T, state_dim). Smoothed state mean estimates. - smoothed_state_covariances: Array of shape (T, state_dim, state_dim). - Smoothed state covariance estimates. - """ - carry = (mfs[-1], Vfs[-1], A, Q) - - # Reverse scan over the time steps - carry, outputs = jax.lax.scan( - kalman_smoother_step, - carry, - [mfs[:-1], Vfs[:-1]], - reverse=True - ) - - smoothed_states, smoothed_state_covariances = outputs - smoothed_states = jnp.append(smoothed_states, jnp.expand_dims(mfs[-1], 0), 0) - smoothed_state_covariances = jnp.append(smoothed_state_covariances, - jnp.expand_dims(Vfs[-1], 0), 0) - return smoothed_states, smoothed_state_covariances - - -@typechecked -def single_timestep_nll( - innovation: jnp.ndarray, - innovation_cov: jnp.ndarray -) -> jax.Array: - """ - Computes the negative log-likelihood (NLL) of a single multivariate Gaussian observation. - - Measures how well the predicted state explains the current observation. - A small regularization term (epsilon) is added to the covariance to ensure numerical stability. - - Args: - innovation: Array of shape (D,). The difference between observed and predicted observation. - innovation_cov: Array of shape (D, D). Covariance of the innovation. - - Returns: - nll_increment: Scalar float stored as a jax.Array. - Negative log-likelihood of observing the current innovation. 
- """ - epsilon = 1e-6 - n_coords = innovation.shape[0] - - # Regularize the innovation covariance matrix by adding epsilon to the diagonal - reg_innovation_cov = innovation_cov + epsilon * jnp.eye(n_coords) - - # Compute the log determinant of the regularized covariance matrix - log_det_S = jnp.log(jnp.abs(jnp.linalg.det(reg_innovation_cov)) + epsilon) - solved_term = jnp.linalg.solve(reg_innovation_cov, innovation) - quadratic_term = jnp.dot(innovation, solved_term) - - # Compute the NLL increment for the current time step - c = jnp.log(2 * jnp.pi) * n_coords # The Gaussian normalization constant part - nll_increment = 0.5 * jnp.abs(log_det_S + quadratic_term + c) - return nll_increment - - -@typechecked -def final_forwards_backwards_pass( - process_cov: jnp.ndarray, - s: np.ndarray, - ys: np.ndarray, - m0s: jnp.ndarray, - S0s: jnp.ndarray, - Cs: jnp.ndarray, - As: jnp.ndarray, - ensemble_vars: np.ndarray, - backend: str = 'jax' -) -> Tuple[np.ndarray, np.ndarray]: - """ - Runs the full Kalman forward-backward smoother across all keypoints using - optimized smoothing parameters. - - Computes smoothed state means and covariances for each keypoint over time. - The process noise covariance is scaled per-keypoint by a learned smoothing parameter `s`. - - Args: - process_cov: Array of shape (K, D, D). Base process noise covariance per keypoint. - s: Array of shape (K,). Smoothing scalars applied to process_cov per keypoint. - ys: Array of shape (K, T, obs_dim). Observations per keypoint over time. - m0s: Array of shape (K, D). Initial state mean per keypoint. - S0s: Array of shape (K, D, D). Initial state covariance per keypoint. - Cs: Array of shape (K, obs_dim, D). Observation matrix per keypoint. - As: Array of shape (K, D, D). State transition matrix per keypoint. - ensemble_vars: Array of shape (T, K, obs_dim). Time-varying obs variances per keypoint. - - Returns: - smoothed_means: Array of shape (K, T, D). Smoothed state means for each keypoint over time. 
- smoothed_covariances: Array of shape (K, T, D, D). Smoothed state covariances over time. - """ - - # Initialize - n_keypoints = ys.shape[0] - ms_array = [] - Vs_array = [] - Qs = s[:, None, None] * process_cov - - # Run forward and backward pass for each keypoint - for k in range(n_keypoints): - y = ys[k] # (T, obs_dim) - m0 = m0s[k] # (D,) - S0 = S0s[k] # (D, D) - A = As[k] # (D, D) - Q = Qs[k] # (D, D) - C = Cs[k] # (obs_dim, D) - keypoint_vars = ensemble_vars[:, k, :] # (T, obs_dim) - - if backend == 'jax': - mf, Vf, _ = forward_pass(y, m0, S0, A, Q, C, keypoint_vars) - ms, Vs = backward_pass(mf, Vf, A, Q) - elif backend == 'dynamax-ekf': - ms, Vs = dynamax_ekf_smooth_routine(y, m0, S0, A, Q, C, keypoint_vars) - - else: - raise ValueError(f"Unsupported backend: {backend}") - - ms_array.append(np.array(ms)) - Vs_array.append(np.array(Vs)) - - smoothed_means = np.stack(ms_array, axis=0) - smoothed_covariances = np.stack(Vs_array, axis=0) - - return smoothed_means, smoothed_covariances - -# ------------------------------------------------------------------------------------- -# Optimization: Functions related to optimizing the smoothing hyperparameter -# ------------------------------------------------------------------------------------- - @typechecked def compute_initial_guesses( @@ -476,199 +108,233 @@ def compute_initial_guesses( # Compute temporal differences temporal_diffs = ensemble_vars[1:] - ensemble_vars[:-1] - # Compute standard deviation across all temporal differences std_dev_guess = round(np.nanstd(temporal_diffs), 5) return float(std_dev_guess) +def params_nlgssm_for_keypoint(m0, S0, Q, s, R, f_fn, h_fn) -> ParamsNLGSSM: + """ + Construct the ParamsNLGSSM for a single (keypoint) sequence. 
+ """ + return ParamsNLGSSM( + initial_mean=jnp.asarray(m0), + initial_covariance=jnp.asarray(S0), + dynamics_function=f_fn, + dynamics_covariance=jnp.asarray(s) * jnp.asarray(Q), + emission_function=h_fn, + emission_covariance=jnp.asarray(R), + ) + + @typechecked def optimize_smooth_param( - cov_mats: jnp.ndarray, - ys: np.ndarray, - m0s: jnp.ndarray, - S0s: jnp.ndarray, - Cs: jnp.ndarray, - As: jnp.ndarray, - ensemble_vars: np.ndarray, + Qs: jnp.ndarray, # (K, D, D) + ys: np.ndarray, # (K, T, obs) + m0s: jnp.ndarray, # (K, D) + S0s: jnp.ndarray, # (K, D, D) + Cs: jnp.ndarray, # (K, obs, D) + As: jnp.ndarray, # (K, D, D) + ensemble_vars: np.ndarray, # (T, K, obs) s_frames: Optional[List] = None, smooth_param: Optional[Union[float, List[float]]] = None, blocks: Optional[List[List[int]]] = None, - maxiter: int = 1000, verbose: bool = False, - backend: str = 'jax', + # JIT-closed constants: + lr: float = 0.25, + s_bounds_log: tuple = (-8.0, 8.0), + tol: float = 1e-3, + safety_cap: int = 5000, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ - Optimize smoothing parameters for each keypoint (or block of keypoints) using - negative log-likelihood minimization, and apply final Kalman forward-backward smoothing. + Optimize the process-noise scale `s` (shared within each block of keypoints) by minimizing + summed negative log-likelihood (NLL) under a *linear* state-space model using the + Dynamax EKF filter (fast), then produce final trajectories via the EKF smoother. - If `smooth_param` is provided, it is used directly. Otherwise, the function computes - initial guesses and uses gradient descent to optimize per-block values of `s`. + Model (per keypoint k): + x_{t+1} = A_k x_t + w_t, y_t = C_k x_t + v_t + w_t ~ N(0, s * Q_k), v_t ~ N(0, R_{k,t}) + + where R_{k,t} is **time-varying**, built from ensemble variances: + R_{k,t} = diag( clip( ensemble_vars[t, k, :], 1e-12, ∞ ) ). Args: - cov_mats: Array of shape (K, D, D). Base process noise covariances per keypoint. 
- ys: Array of shape (K, T, obs_dim). Observations per keypoint over time. - m0s: Array of shape (K, D). Initial state means per keypoint. - S0s: Array of shape (K, D, D). Initial state covariances per keypoint. - Cs: Array of shape (K, obs_dim, D). Observation matrices per keypoint. - As: Array of shape (K, D, D). State transition matrices per keypoint. - ensemble_vars: Array of shape (T, K, obs_dim). Time-varying ensemble variances. - s_frames: Optional list of frame indices for computing initial guess statistics. - smooth_param: Optional fixed value(s) of smoothing param `s`. - Can be a float or list of floats (one per keypoint/block). - blocks: Optional list of lists of keypoint indices to share a smoothing param. - Defaults to treating each keypoint independently. - maxiter: Max number of optimization steps per block. - verbose: If True, print progress logs. + Qs: (K, D, D) base process noise covariances Q_k per keypoint (scaled by `s`). + ys: (K, T, obs) observations per keypoint across time. + m0s: (K, D) initial state means per keypoint. + S0s: (K, D, D) initial state covariances per keypoint. + Cs: (K, obs, D) observation matrices C_k per keypoint. + As: (K, D, D) transition matrices A_k per keypoint. + ensemble_vars: (T, K, obs) per-frame ensemble variances for each keypoint’s obs dims; + used to construct time-varying R_{k,t}. + s_frames: Optional list of frame indices used for NLL optimization (cropping the loss). + Final smoothing always runs on the full sequence. + smooth_param: If provided, bypass optimization. + • float/int: same `s` for all keypoints; + • list[float] of length K: per-keypoint `s`. + blocks: Optional list of lists of keypoint indices; each block shares a single `s`. + Default: each keypoint forms its own block. + verbose: If True, prints per-block optimization summaries. + lr: Adam learning rate for optimizing log(s). + s_bounds_log: (low, high) clamp for log(s) during optimization. 
+ tol: Relative tolerance on loss change for early stopping. + safety_cap: Hard limit on iterations inside the jitted while-loop. Returns: - s_finals: Array of shape (K,). Final smoothing parameter per keypoint. - ms: Array of shape (K, T, D). Smoothed state means. - Vs: Array of shape (K, T, D, D). Smoothed state covariances. + s_finals: (K,) final `s` per keypoint (blockwise value broadcast to members). + ms: (K, T, D) smoothed state means. + Vs: (K, T, D, D) smoothed state covariances. + + Notes: + • NLL is computed with EKF *filter*; outputs use EKF *smoother*. + • Loss for a block is the sum of member keypoints’ NLLs (via vmap). + • All jitted helpers close over optimizer/tol/bounds to avoid passing Python objects. """ - - n_keypoints = ys.shape[0] - s_finals = [] - if blocks is None: - blocks = [] - if len(blocks) == 0: - for n in range(n_keypoints): - blocks.append([n]) + # -------------------- setup & time-varying R_t -------------------- + K, T, obs_dim = ys.shape + if not blocks: + blocks = [[k] for k in range(K)] if verbose: - print(f'Correlated keypoint blocks: {blocks}') - - @partial(jit) - def nll_loss_sequential_scan(s, cov_mats, cropped_ys, m0s, S0s, Cs, As, ensemble_vars): - s = jnp.exp(s) # To ensure positivity - return smooth_min( - s, cov_mats, cropped_ys, m0s, S0s, Cs, As, ensemble_vars) - - loss_function = nll_loss_sequential_scan - - # Optimize smooth_param + print(f"Correlated keypoint blocks: {blocks}") + + # Build time-varying R + Rs = build_R_from_vars(np.swapaxes(ensemble_vars, 0, 1)) + Rs_j = jnp.asarray(Rs) + + # Device arrays once + ys_j = jnp.asarray(ys) + m0s_j = jnp.asarray(m0s) + S0s_j = jnp.asarray(S0s) + As_j = jnp.asarray(As) + Qs_j = jnp.asarray(Qs) + Cs_j = jnp.asarray(Cs) + + # Initial guesses per keypoint + s_guess_per_k = np.empty(K, dtype=float) + for k in range(K): + g = float(compute_initial_guesses(ensemble_vars[:, k, :]) or 2.0) + s_guess_per_k[k] = g if (np.isfinite(g) and g > 0.0) else 2.0 + + # -------------------- 
choose or optimize s -------------------- + s_finals = np.empty(K, dtype=float) if smooth_param is not None: - if isinstance(smooth_param, float): - s_finals = [smooth_param] - elif isinstance(smooth_param, int): - s_finals = [float(smooth_param)] + if isinstance(smooth_param, (int, float)): + s_finals[:] = float(smooth_param) else: - s_finals = smooth_param + s_finals[:] = np.asarray(smooth_param, dtype=float) else: - guesses = [] - cropped_ys = [] - for k in range(n_keypoints): - current_guess = compute_initial_guesses(ensemble_vars[:, k, :]) - guesses.append(current_guess) - if s_frames is None or len(s_frames) == 0: - cropped_ys.append(ys[k]) - else: - cropped_ys.append(crop_frames(ys[k], s_frames)) - - cropped_ys = np.array(cropped_ys) # Concatenation of this list along dimension 0 - - # Optimize negative log likelihood + optimizer = optax.adam(float(lr)) + s_bounds_log_j = jnp.array(s_bounds_log, dtype=jnp.float32) + tol_j = float(tol) + + def _params_linear(m0, S0, A, Q_base, s, R_any, C): + f_fn = (lambda x, A=A: A @ x) # linear dynamics + h_fn = (lambda x, C=C: C @ x) # linear emission + return params_nlgssm_for_keypoint(m0, S0, Q_base, s, R_any, f_fn, h_fn) + + # NLL for a single keypoint with time-varying R_t + def _nll_one_keypoint(log_s, y_k, m0_k, S0_k, A_k, Q_k, C_k, R_k_tv): + s = jnp.exp(jnp.clip(log_s, s_bounds_log_j[0], s_bounds_log_j[1])) + params = _params_linear(m0_k, S0_k, A_k, Q_k, s, R_k_tv, C_k) + post = extended_kalman_filter(params, jnp.asarray(y_k)) + return -post.marginal_loglik + + # Sum NLL across all keypoints in the block + def _nll_block(log_s, ys_b, m0s_b, S0s_b, As_b, Qs_b, Cs_b, Rs_b_tv): + nlls = vmap(_nll_one_keypoint, in_axes=(None, 0, 0, 0, 0, 0, 0, 0))( + log_s, ys_b, m0s_b, S0s_b, As_b, Qs_b, Cs_b, Rs_b_tv + ) + return jnp.sum(nlls) + + @jit + def _opt_step(log_s, opt_state, ys_b, m0s_b, S0s_b, As_b, Qs_b, Cs_b, Rs_b_tv): + loss, grad = value_and_grad(_nll_block)( + log_s, ys_b, m0s_b, S0s_b, As_b, Qs_b, Cs_b, Rs_b_tv + ) 
+ updates, opt_state = optimizer.update(grad, opt_state) + log_s = optax.apply_updates(log_s, updates) + return log_s, opt_state, loss + + @jit + def _run_tol_loop(log_s0, opt_state0, ys_b, m0s_b, S0s_b, As_b, Qs_b, Cs_b, Rs_b_tv): + def cond(carry): + _, _, prev_loss, iters, done = carry + return jnp.logical_and(~done, iters < safety_cap) + + def body(carry): + log_s, opt_state, prev_loss, iters, _ = carry + log_s, opt_state, loss = _opt_step( + log_s, opt_state, ys_b, m0s_b, S0s_b, As_b, Qs_b, Cs_b, Rs_b_tv + ) + rel_tol = tol_j * jnp.abs(jnp.log(jnp.maximum(prev_loss, 1e-12))) + done = jnp.where( + jnp.isfinite(prev_loss), + jnp.linalg.norm(loss - prev_loss) < (rel_tol + 1e-6), + False + ) + return (log_s, opt_state, loss, iters + 1, done) + + return lax.while_loop( + cond, body, (log_s0, opt_state0, jnp.inf, jnp.array(0), jnp.array(False)) + ) + + # Optimize per block (shared s within each block) for block in blocks: - s_init = guesses[block[0]] - if s_init <= 0: - s_init = 2 - s_init = jnp.log(s_init) - optimizer = optax.adam(learning_rate=0.25) - opt_state = optimizer.init(s_init) - - selector = np.array(block).astype(int) - cov_mats_sub = cov_mats[selector] - m0s_crop = m0s[selector] - S0s_crop = S0s[selector] - Cs_crop = Cs[selector] - As_crop = As[selector] - y_subset = cropped_ys[selector] - ensemble_vars_crop = np.swapaxes(ensemble_vars[:, selector, :], 0, 1) - - def step(s, opt_state): - loss, grads = jax.value_and_grad(loss_function)( - s, cov_mats_sub, y_subset, m0s_crop, S0s_crop, Cs_crop, As_crop, - ensemble_vars_crop) - updates, opt_state = optimizer.update(grads, opt_state) - s = optax.apply_updates(s, updates) - return s, opt_state, loss - - prev_loss = jnp.inf - for iteration in range(maxiter): - s_init, opt_state, loss = step(s_init, opt_state) - - if verbose and iteration % 10 == 0 or iteration == maxiter - 1: - print(f'Iteration {iteration}, Current loss: {loss}, Current s: {s_init}') - - tol = 0.001 * jnp.abs(jnp.log(prev_loss)) - if 
jnp.linalg.norm(loss - prev_loss) < tol + 1e-6: - break - prev_loss = loss - - s_final = jnp.exp(s_init) # Convert back from log-space - - for b in block: - if verbose: - print(f's={s_final} for keypoint {b}') - s_finals.append(s_final) - - s_finals = np.array(s_finals) - # Final smooth with optimized s - ms, Vs = final_forwards_backwards_pass( - cov_mats, s_finals, ys, m0s, S0s, Cs, As, ensemble_vars, backend=backend, - ) - - return s_finals, ms, Vs + sel = jnp.asarray(block, dtype=int) + # Crop frames for the loss (both y and R_t) if s_frames is provided + if s_frames and len(s_frames) > 0: + # Crop both y and R_t using the same frame spec -- each (T', obs) + y_block_list = [crop_frames(ys[int(k)], s_frames) for k in block] + R_block_list = [crop_R(Rs[int(k)], s_frames) for k in block] -@typechecked -def inner_smooth_min_routine( - y: jnp.ndarray, - m0: jnp.ndarray, - S0: jnp.ndarray, - A: jnp.ndarray, - Q: jnp.ndarray, - C: jnp.ndarray, - ensemble_var: jnp.ndarray -) -> jax.Array: - # Run filtering with the current smooth_param - _, _, nll = forward_pass(y, m0, S0, A, Q, C, ensemble_var) - return nll - - -inner_smooth_min_routine_vmap = vmap(inner_smooth_min_routine, in_axes=(0, 0, 0, 0, 0, 0, 0)) - - -@typechecked -def smooth_min( - smooth_param: jax.Array, - cov_mats: jnp.ndarray, - ys: jnp.ndarray, - m0s: jnp.ndarray, - S0s: jnp.ndarray, - Cs: jnp.ndarray, - As: jnp.ndarray, - ensemble_vars: jnp.ndarray -) -> jax.Array: - """ - Computes the total negative log-likelihood (NLL) for a given smoothing parameter - by running a full forward-pass Kalman filter over all keypoints. - - This is the objective function minimized during smoothing parameter optimization. - - Args: - smooth_param: Scalar float value of the smoothing parameter `s`. - cov_mats: Array of shape (K, D, D). Process noise covariance templates. - ys: Array of shape (K, T, obs_dim). Observations per keypoint. - m0s: Array of shape (K, D). Initial state means. - S0s: Array of shape (K, D, D). 
Initial state covariances. - Cs: Array of shape (K, obs_dim, D). Observation matrices. - As: Array of shape (K, D, D). State transition matrices. - ensemble_vars: Array of shape (T, K, obs_dim). Time-varying ensemble variances. + # Stack and jnp + y_block = jnp.asarray(np.stack(y_block_list, axis=0)) # (B, T', obs) + R_block = jnp.asarray(np.stack(R_block_list, axis=0)) # (B, T', obs, obs) + else: + y_block = ys_j[sel] # (B, T, obs) + R_block = Rs_j[sel] # (B, T, obs, obs) + + m0_block = m0s_j[sel] + S0_block = S0s_j[sel] + A_block = As_j[sel] + Q_block = Qs_j[sel] + C_block = Cs_j[sel] + + s0 = float(np.mean([s_guess_per_k[k] for k in block])) + log_s0 = jnp.array(np.log(max(s0, 1e-6)), dtype=jnp.float32) + opt_state0 = optimizer.init(log_s0) + + log_s_f, opt_state_f, last_loss, iters_f, _done = _run_tol_loop( + log_s0, opt_state0, y_block, m0_block, S0_block, A_block, Q_block, C_block, + R_block + ) + s_star = float(jnp.exp(jnp.clip(log_s_f, s_bounds_log_j[0], s_bounds_log_j[1]))) + for k in block: + s_finals[k] = s_star + if verbose: + print(f"[Block {block}] s={s_star:.6g}, iters={int(iters_f)}, " + f"NLL={float(last_loss):.6f}") + + # -------------------- final smoother pass (full R_t) -------------------- + def _params_linear_for_k(k: int, s_val: float): + A_k, C_k = As_j[k], Cs_j[k] + f_fn = (lambda x, A=A_k: A @ x) + h_fn = (lambda x, C=C_k: C @ x) + return params_nlgssm_for_keypoint( + m0s_j[k], S0s_j[k], Qs_j[k], s_val, Rs[k], f_fn, h_fn) + + means_list, covs_list = [], [] + for k in range(K): + params_k = _params_linear_for_k(k, s_finals[k]) + sm = extended_kalman_smoother(params_k, ys_j[k]) + if hasattr(sm, "smoothed_means"): + m_k, V_k = sm.smoothed_means, sm.smoothed_covariances + else: + m_k, V_k = sm.filtered_means, sm.filtered_covariances + means_list.append(np.array(m_k)) + covs_list.append(np.array(V_k)) - Returns: - nlls: Scalar JAX array. Total negative log-likelihood across all keypoints. 
- """ - # Adjust Q based on smooth_param and cov_matrix - Qs = smooth_param * cov_mats - nlls = jnp.sum(inner_smooth_min_routine_vmap(ys, m0s, S0s, As, Qs, Cs, ensemble_vars)) - return nlls + ms = np.stack(means_list, axis=0) # (K, T, D) + Vs = np.stack(covs_list, axis=0) # (K, T, D, D) + return s_finals, ms, Vs diff --git a/eks/ibl_paw_multicam_smoother.py b/eks/ibl_paw_multicam_smoother.py index 81ed91c..80d6c40 100644 --- a/eks/ibl_paw_multicam_smoother.py +++ b/eks/ibl_paw_multicam_smoother.py @@ -9,12 +9,9 @@ from eks.marker_array import ( MarkerArray, input_dfs_to_markerArray, - mA_to_stacked_array, - stacked_array_to_mA, ) from eks.multicam_smoother import ensemble_kalman_smoother_multicam -from eks.stats import compute_pca -from eks.utils import convert_lp_dlc, make_dlc_pandas_index +from eks.utils import convert_lp_dlc def remove_camera_means(ensemble_stacks, camera_means): @@ -39,6 +36,7 @@ def pca(S, n_comps): pca_ = PCA(n_components=n_comps) return pca_.fit(S), pca_.explained_variance_ratio_ + @typechecked def fit_eks_multicam_ibl_paw( input_source: str | list, @@ -67,7 +65,7 @@ def fit_eks_multicam_ibl_paw( 'var' | 'confidence_weighted_var' verbose: True to print out details img_width: The width of the image being smoothed (128 default, IBL-specific). 
- inflate_vars: True to use Mahalanobis distance thresholding to inflate ensemble variance + inflate_vars: True to use Mahalanobis distance threshold to inflate ensemble variance n_latent: number of dimensions to keep from PCA Returns: @@ -119,7 +117,7 @@ def fit_eks_multicam_ibl_paw( raise ValueError('Need timestamps for both cameras') if len(input_dfs_right) != len(input_dfs_left) or len(input_dfs_left) == 0: raise ValueError( - 'There must be the same number of left and right camera models and >=1 model for each.') + 'Need same number of left and right camera models and >=1 model for each.') # Interpolate right cam markers to left cam timestamps markers_list_stacked_interp = [] diff --git a/eks/ibl_pupil_smoother.py b/eks/ibl_pupil_smoother.py index 2209020..a76b773 100644 --- a/eks/ibl_pupil_smoother.py +++ b/eks/ibl_pupil_smoother.py @@ -1,18 +1,24 @@ import os import warnings -from functools import partial + +from dynamax.nonlinear_gaussian_ssm.inference_ekf import ( + extended_kalman_filter, + extended_kalman_smoother, +) import jax import numpy as np import optax import pandas as pd -from jax import jit +from jax import jit, lax, value_and_grad from jax import numpy as jnp +from numbers import Real from typeguard import typechecked +from typing import List, Optional, Sequence, Tuple -from eks.core import backward_pass, ensemble, forward_pass +from eks.core import ensemble, params_nlgssm_for_keypoint from eks.marker_array import MarkerArray, input_dfs_to_markerArray -from eks.utils import crop_frames, format_data, make_dlc_pandas_index +from eks.utils import build_R_from_vars, crop_frames, format_data, make_dlc_pandas_index @typechecked @@ -226,9 +232,6 @@ def ensemble_kalman_smoother_ibl_pupil( [-.5, 1, 0], [0, 0, 1] ]) - # placeholder diagonal matrix for ensemble variance - R = np.eye(8) - centered_ensemble_preds = ensemble_preds.copy() # subtract COM means from the ensemble predictions for i in range(ensemble_preds.shape[1]): @@ -242,7 +245,7 @@ def 
ensemble_kalman_smoother_ibl_pupil( # Perform filtering with SINGLE PAIR of diameter_s, com_s # ------------------------------------------------------- s_finals, ms, Vs, nll = pupil_optimize_smooth( - y_obs, m0, S0, C, R, ensemble_vars, + y_obs, m0, S0, C, ensemble_vars, np.var(pupil_diameters), np.var(x_t_obs), np.var(y_t_obs), s_frames, smooth_params, verbose=verbose) if verbose: @@ -305,133 +308,175 @@ def ensemble_kalman_smoother_ibl_pupil( return markers_df, s_finals +@typechecked def pupil_optimize_smooth( - ys: np.ndarray, - m0: np.ndarray, - S0: np.ndarray, - C: np.ndarray, - R: np.ndarray, - ensemble_vars: np.ndarray, - diameters_var: np.ndarray, - x_var: np.ndarray, - y_var: np.ndarray, - s_frames: list | None = [(1, 2000)], - smooth_params: list | None = [None, None], - maxiter: int = 1000, - verbose: bool = False, + ys: np.ndarray, # (T, 8) centered obs + m0: np.ndarray, # (3,) + S0: np.ndarray, # (3,3) + C: np.ndarray, # (8,3) + ensemble_vars: np.ndarray, # (T, 8) + diameters_var: Real, + x_var: float, + y_var: float, + s_frames: Optional[List[Tuple[Optional[int], Optional[int]]]] = [(1, 2000)], + smooth_params: list | None = None, # [diam_s, com_s] in (0,1) + maxiter: int = 1000, # retained (unused with tol-loop) + verbose: bool = False, + # optimizer/loop knobs + lr: float = 5e-3, + tol: float = 1e-6, + safety_cap: int = 5000, ) -> tuple: - """Optimize-and-smooth function for the pupil example script. - - Parameters: - ys: Observations. Shape (keypoints, frames, coordinates). - m0: Initial mean state. - S0: Initial state covariance. - C: Measurement function. - R: Measurement noise covariance. - ensemble_vars: Ensemble variances. - diameters_var: Diameter variance - x_var: x variance for COM - y_var: y variance for COM - s_frames: List of frames. 
- smooth_params: Smoothing parameter tuple (diameter_s, com_s) - verbose: Prints extra information for smoothing parameter iterations + """ + Dynamax backend: optimize [s_diameter, s_com] with EKF NLL, then EKF smoother. Returns: - tuple: Final smoothing parameters, smoothed means, smoothed covariances, - negative log-likelihoods, negative log-likelihood values. + s_finals (list[float]), ms (T,3), Vs (T,3,3), nll (float) """ - @partial(jit) - def nll_loss_sequential_scan( - s_log, ys, m0, S0, C, R, ensemble_vars, diameters_var, x_var, y_var): - s = jnp.exp(s_log) # Ensure positivity - return pupil_smooth( - s, ys, m0, S0, C, R, ensemble_vars, diameters_var, x_var, y_var) + # logistic reparam to keep s in (eps,1-eps) + def _to_stable_s(u, eps=1e-3): + return jax.nn.sigmoid(u) * (1.0 - 2 * eps) + eps - loss_function = nll_loss_sequential_scan - # Optimize smooth_param + # crop ys and ev for the loss if s_frames provided + if s_frames and len(s_frames) > 0: + y_cropped = crop_frames(ys, s_frames) # (T', 8) + ev_cropped = crop_frames(ensemble_vars, s_frames) # (T', 8) + else: + y_cropped, ev_cropped = ys, ensemble_vars + + # build time-varying R_t for loss and for final smoothing (full sequence) + R_loss = build_R_from_vars(ev_cropped) # (T' ,8,8) + R_full = build_R_from_vars(ensemble_vars) # (T ,8,8) + + # jnp once + y_c = jnp.asarray(y_cropped) + R_loss = jnp.asarray(R_loss) + m0_j, S0_j, C_j = jnp.asarray(m0), jnp.asarray(S0), jnp.asarray(C) + y_full = jnp.asarray(ys) + R_full = jnp.asarray(R_full) + + # local params builder using your NLGSSM wrapper; pass Q_exact with s=1.0 + def _params_linear(m0, S0, A, Q_exact, R_any, C): + f_fn = (lambda x, A=A: A @ x) + h_fn = (lambda x, C=C: C @ x) + return params_nlgssm_for_keypoint(m0, S0, Q_exact, 1.0, R_any, f_fn, h_fn) + + # EKF NLL for a given unconstrained u = [u_d, u_c] + def _nll_from_u(u: jnp.ndarray) -> jnp.ndarray: + s_d, s_c = _to_stable_s(u) + A = jnp.diag(jnp.array([s_d, s_c, s_c])) + Q = jnp.diag(jnp.array([ + 
diameters_var * (1.0 - s_d**2), + x_var * (1.0 - s_c**2), + y_var * (1.0 - s_c**2), + ])) + params = _params_linear(m0_j, S0_j, A, Q, R_loss, C_j) + post = extended_kalman_filter(params, y_c) + return -post.marginal_loglik + + optimizer = optax.adam(lr) if smooth_params is None or smooth_params[0] is None or smooth_params[1] is None: - # Crop to only contain s_frames for time axis - y_cropped = crop_frames(ys, s_frames) - ensemble_vars_cropped = crop_frames(ensemble_vars, s_frames) - - # Optimize negative log likelihood - s_init = jnp.log(jnp.array([0.99, 0.98])) # reasonable guess for s_finals - optimizer = optax.adam(learning_rate=0.005) - opt_state = optimizer.init(s_init) - - def step(s, opt_state): - loss, grads = jax.value_and_grad(loss_function)( - s, y_cropped, m0, S0, C, R, ensemble_vars_cropped, diameters_var, x_var, y_var + # init near your old guess (invert logistic) + s0 = jnp.array([0.99, 0.98]) + u0 = jnp.log(s0 / (1.0 - s0)) + opt_state0 = optimizer.init(u0) + + @jit + def _opt_step(u, opt_state): + loss, grad = value_and_grad(_nll_from_u)(u) + updates, opt_state = optimizer.update(grad, opt_state) + u = optax.apply_updates(u, updates) + return u, opt_state, loss + + @jit + def _run_tol_loop(u0, opt_state0): + def cond(carry): + _, _, prev_loss, iters, done = carry + return jnp.logical_and(~done, iters < safety_cap) + + def body(carry): + u, opt_state, prev_loss, iters, _ = carry + u, opt_state, loss = _opt_step(u, opt_state) + rel_tol = tol * jnp.abs(jnp.log(jnp.maximum(prev_loss, 1e-12))) + done = jnp.where(jnp.isfinite(prev_loss), + jnp.linalg.norm(loss - prev_loss) < (rel_tol + 1e-6), + False) + return (u, opt_state, loss, iters + 1, done) + return lax.while_loop( + cond, body, (u0, opt_state0, jnp.inf, jnp.array(0), jnp.array(False)) ) - updates, opt_state = optimizer.update(grads, opt_state) - s = optax.apply_updates(s, updates) - return s, opt_state, loss - prev_loss = jnp.inf - for iteration in range(maxiter): - s_init, opt_state, loss = 
step(s_init, opt_state) - - if verbose and iteration % 10 == 0 or iteration == maxiter - 1: - print(f'Iteration {iteration}, Current loss: {loss}, Current s: {jnp.exp(s_init)}') - - tol = 1e-6 * jnp.abs(jnp.log(prev_loss)) - if jnp.linalg.norm(loss - prev_loss) < tol + 1e-6: - break - prev_loss = loss - - s_finals = jnp.exp(s_init) - s_finals = [round(s_finals[0], 5), round(s_finals[1], 5)] - print(f'Optimized to diameter_s={s_finals[0]}, com_s={s_finals[1]}') + u_f, opt_state_f, last_loss, iters_f, _ = _run_tol_loop(u0, opt_state0) + s_opt = _to_stable_s(u_f) + if verbose: + print(f"[pupil/dynamax] iters={int(iters_f)} s_diam={float(s_opt[0]):.6f} " + f"s_com={float(s_opt[1]):.6f} NLL={float(last_loss):.6f}") else: - s_finals = smooth_params + s_user = jnp.clip(jnp.asarray(smooth_params, dtype=jnp.float32), 1e-3, 1 - 1e-3) + s_opt = s_user - # Final smooth with optimized s + # final smoother on full sequence with full R_t + s_d, s_c = float(s_opt[0]), float(s_opt[1]) ms, Vs, nll = pupil_smooth( - s_finals, ys, m0, S0, C, R, ensemble_vars, diameters_var, x_var, y_var, return_full=True) - - return s_finals, ms, Vs, nll + smooth_params=[s_d, s_c], + ys=y_full, m0=m0_j, S0=S0_j, C=C_j, R=R_full, + diameters_var=diameters_var, x_var=x_var, y_var=y_var, + return_full=True + ) + return [s_d, s_c], np.asarray(ms), np.asarray(Vs), float(nll) -def pupil_smooth(smooth_params, ys, m0, S0, C, R, ensemble_vars, diameters_var, x_var, y_var, - return_full=False): +@typechecked +def pupil_smooth( + smooth_params: Sequence[float], # [s_diam, s_com] in (0,1) + ys: np.ndarray | jnp.ndarray, # (T, 8) + m0: np.ndarray | jnp.ndarray, # (3,) + S0: np.ndarray | jnp.ndarray, # (3,3) + C: np.ndarray | jnp.ndarray, # (8,3) + R: np.ndarray | jnp.ndarray, # (T, 8, 8) time-varying obs covariance + diameters_var: Real, + x_var: float, + y_var: float, + return_full: bool = False, +): """ - Smooths once using the given smooth_param. Returns only the nll loss by default - (if return_full is False). 
- - Parameters: - smooth_params (float): Smoothing parameter. - block (list): List of blocks. - cov_mats (np.ndarray): Covariance matrices. - ys (np.ndarray): Observations. - m0s (np.ndarray): Initial mean state. - S0s (np.ndarray): Initial state covariance. - Cs (np.ndarray): Measurement function. - As (np.ndarray): State-transition matrix. - Rs (np.ndarray): Measurement noise covariance. - - Returns: - float: Negative log-likelihood. + One EKF forward (and optional smoother) using Dynamax NLGSSM with: + A = diag([s_d, s_c, s_c]) and Q = diag([σ_d^2(1-s_d^2), σ_x^2(1-s_c^2), σ_y^2(1-s_c^2)]). + R_t = diag(ensemble_vars[t]) (or provided via _R_override). """ - # Construct As - diameter_s, com_s = smooth_params[0], smooth_params[1] - A = jnp.array([ - [diameter_s, 0, 0], - [0, com_s, 0], - [0, 0, com_s] - ]) - - # Construct cov_matrix Q - Q = jnp.array([ - [diameters_var * (1 - (A[0, 0] ** 2)), 0, 0], - [0, x_var * (1 - A[1, 1] ** 2), 0], - [0, 0, y_var * (1 - (A[2, 2] ** 2))] - ]) - - mf, Vf, nll = forward_pass(ys, m0, S0, A, Q, C, ensemble_vars) - - if return_full: - ms, Vs = backward_pass(mf, Vf, A, Q) - return ms, Vs, nll - - return nll \ No newline at end of file + ys = jnp.asarray(ys) + m0 = jnp.asarray(m0) + S0 = jnp.asarray(S0) + C = jnp.asarray(C) + + s_d = jnp.clip(jnp.asarray(smooth_params[0]), 1e-3, 1 - 1e-3) + s_c = jnp.clip(jnp.asarray(smooth_params[1]), 1e-3, 1 - 1e-3) + + A = jnp.diag(jnp.array([s_d, s_c, s_c])) + Q = jnp.diag(jnp.array([ + diameters_var * (1.0 - s_d**2), + x_var * (1.0 - s_c**2), + y_var * (1.0 - s_c**2), + ])) + + # linear f/h closures + f_fn = (lambda x, A=A: A @ x) + h_fn = (lambda x, C=C: C @ x) + + # build NLGSSM params; pass Q as exact and s=1.0 to avoid extra scaling + params = params_nlgssm_for_keypoint(m0, S0, Q, 1.0, R, f_fn, h_fn) + + filt = extended_kalman_filter(params, ys) + nll = -filt.marginal_loglik + if not return_full: + return nll + + sm = extended_kalman_smoother(params, ys) + if hasattr(sm, "smoothed_means"): + 
ms = sm.smoothed_means + Vs = sm.smoothed_covariances + else: + ms = sm.filtered_means + Vs = sm.filtered_covariances + return ms, Vs, -filt.marginal_loglik diff --git a/eks/kalman_backends.py b/eks/kalman_backends.py deleted file mode 100644 index ccd61b4..0000000 --- a/eks/kalman_backends.py +++ /dev/null @@ -1,70 +0,0 @@ -from dynamax.nonlinear_gaussian_ssm.inference_ekf import extended_kalman_smoother -from dynamax.nonlinear_gaussian_ssm.models import ( - ParamsNLGSSM, -) - -import jax -import jax.numpy as jnp -import numpy as np -from typing import Union, Tuple, Callable -from typeguard import typechecked - -ArrayLike = Union[np.ndarray, jax.Array] - -@typechecked -def dynamax_ekf_smooth_routine( - y: ArrayLike, - m0: ArrayLike, - S0: ArrayLike, - A: ArrayLike, - Q: ArrayLike, - C: ArrayLike | None, - ensemble_vars: ArrayLike, # shape (T, obs_dim) - f_fn: Callable[[jnp.ndarray], jnp.ndarray] | None = None, - h_fn: Callable[[jnp.ndarray], jnp.ndarray] | None = None -) -> Tuple[jnp.ndarray, jnp.ndarray]: - """ - Extended Kalman smoother using the Dynamax nonlinear interface, - allowing for time-varying observation noise. - - By default, uses linear dynamics and emissions: f(x) = Ax, h(x) = Cx. - - Args: - y: (T, obs_dim) observation sequence. - m0: (state_dim,) initial mean. - S0: (state_dim, state_dim) initial covariance. - A: (state_dim, state_dim) dynamics matrix. - Q: (state_dim, state_dim) process noise covariance. - C: (obs_dim, state_dim) emission matrix (optional). - ensemble_vars: (T, obs_dim) per-timestep observation noise variance. - f_fn: optional dynamics function f(x). - h_fn: optional emission function h(x). 
- - Returns: - smoothed_means: (T, state_dim) - smoothed_covariances: (T, state_dim, state_dim) - """ - y, m0, S0, A, Q, ensemble_vars = map(jnp.asarray, (y, m0, S0, A, Q, ensemble_vars)) - C = jnp.asarray(C) if C is not None else None - - if f_fn is None: - f_fn = lambda x: A @ x - if h_fn is None: - if C is None: - raise ValueError("Must provide either emission matrix C or a nonlinear emission function h_fn.") - h_fn = lambda x: C @ x - - # Dynamically determine obs_dim from h_fn output - obs_dim = y.shape[1] - R_t = jnp.stack([jnp.diag(var_t[:obs_dim]) for var_t in ensemble_vars], axis=0) # shape (T, obs_dim, obs_dim) - params = ParamsNLGSSM( - initial_mean=m0, - initial_covariance=S0, - dynamics_function=f_fn, - dynamics_covariance=Q, - emission_function=h_fn, - emission_covariance=R_t, - ) - - posterior = extended_kalman_smoother(params, y) - return posterior.smoothed_means, posterior.smoothed_covariances \ No newline at end of file diff --git a/eks/multicam_smoother.py b/eks/multicam_smoother.py index 001f296..3396897 100644 --- a/eks/multicam_smoother.py +++ b/eks/multicam_smoother.py @@ -130,7 +130,6 @@ def fit_eks_multicam( inflate_vars: bool = False, verbose: bool = False, n_latent: int = 3, - backend: str = 'jax', ) -> tuple: """ Fit the Ensemble Kalman Smoother for un-mirrored multi-camera data. 
@@ -179,7 +178,6 @@ def fit_eks_multicam( verbose=verbose, inflate_vars=inflate_vars, n_latent=n_latent, - backend=backend, ) # Save output DataFrames to CSVs (one per camera view) os.makedirs(save_dir, exist_ok=True) @@ -257,8 +255,8 @@ def ensemble_kalman_smoother_multicam( n_components=n_latent, pca_object=pca_object, ) - if inflate_vars: + print('inflating') if inflate_vars_kwargs.get("mean", None) is not None: # set mean to zero since we are passing in centered predictions inflate_vars_kwargs["mean"] = np.zeros_like(inflate_vars_kwargs["mean"]) @@ -272,7 +270,7 @@ def ensemble_kalman_smoother_multicam( # Kalman Filter Section ------------------------------------------ # Initialize Kalman filter parameters - m0s, S0s, As, cov_mats, Cs = initialize_kalman_filter_pca( + m0s, S0s, As, Qs, Cs = initialize_kalman_filter_pca( good_pcs_list=good_pcs_list, ensemble_pca=ensemble_pca, n_latent=n_latent, @@ -290,7 +288,7 @@ def ensemble_kalman_smoother_multicam( # Optimize smoothing s_finals, ms, Vs = optimize_smooth_param( - cov_mats=cov_mats, + Qs=Qs, ys=ys, m0s=m0s, S0s=S0s, @@ -300,7 +298,6 @@ def ensemble_kalman_smoother_multicam( s_frames=s_frames, smooth_param=smooth_param, verbose=verbose, - backend=backend, ) # Reproject from latent space back to observed space camera_arrs = [[] for _ in camera_names] @@ -398,7 +395,7 @@ def initialize_kalman_filter_pca( def mA_compute_maha(centered_emA_preds, emA_vars, emA_likes, n_latent, - inflate_vars_kwargs={}, threshold=5, scalar=2): + inflate_vars_kwargs={}, threshold=5, scalar=10): """ Reshape marker arrays for Mahalanobis computation, compute Mahalanobis distances, and optionally inflate variances for all keypoints. 
@@ -430,7 +427,7 @@ def mA_compute_maha(centered_emA_preds, emA_vars, emA_likes, n_latent, inflate_vars_kwargs['v_quantile_threshold'] = 50.0 inflated = True tmp_vars = vars - + print(f'inflating keypoint: {k}') while inflated: # Compute Mahalanobis distances if inflate_vars_kwargs.get("likelihoods", None) is None: diff --git a/eks/singlecam_smoother.py b/eks/singlecam_smoother.py index cf96aa9..441e06d 100644 --- a/eks/singlecam_smoother.py +++ b/eks/singlecam_smoother.py @@ -21,7 +21,6 @@ def fit_eks_singlecam( avg_mode: str = 'median', var_mode: str = 'confidence_weighted_var', verbose: bool = False, - backend: str = 'jax', ) -> tuple: """Fit the Ensemble Kalman Smoother for single-camera data. @@ -66,7 +65,6 @@ def fit_eks_singlecam( avg_mode=avg_mode, var_mode=var_mode, verbose=verbose, - backend=backend, ) # Save the output DataFrame to CSV @@ -87,7 +85,6 @@ def ensemble_kalman_smoother_singlecam( avg_mode: str = 'median', var_mode: str = 'confidence_weighted_var', verbose: bool = False, - backend: str = 'jax', ) -> tuple: """Perform Ensemble Kalman Smoothing for single-camera data. 
@@ -142,7 +139,7 @@ def ensemble_kalman_smoother_singlecam( # Main smoothing function s_finals, ms, Vs = optimize_smooth_param( cov_mats, ys, m0s, S0s, Cs, As, emA_vars.get_array(squeeze=True), - s_frames, smooth_param, blocks, verbose=verbose, backend=backend + s_frames, smooth_param, blocks, verbose=verbose ) y_m_smooths = np.zeros((n_keypoints, n_frames, 2)) diff --git a/eks/utils.py b/eks/utils.py index 3783fee..343e547 100644 --- a/eks/utils.py +++ b/eks/utils.py @@ -251,6 +251,7 @@ def crop_frames(y: np.ndarray | jnp.ndarray, s_frames: list | tuple) -> np.ndarr return np.concatenate(result) +@typechecked() def center_predictions( ensemble_marker_array: MarkerArray, quantile_keep_pca: float @@ -324,3 +325,38 @@ def center_predictions( emA_means = MarkerArray.stack(emA_means_list, "keypoints") return valid_frames_mask, emA_centered_preds, emA_good_centered_preds, emA_means + + +@typechecked +def build_R_from_vars(ev: np.ndarray) -> np.ndarray: + """ + Build time-varying diagonal observation covariances from per-dimension variances. + ev shape: (..., T, O) -> returns (..., T, O, O) with diag(ev[t]). + """ + ev_np = np.clip(np.asarray(ev), 1e-12, None) + O_dim = ev_np.shape[-1] + # Broadcast-diagonal without Python loops: + # (..., T, O, 1) * (O, O) -> (..., T, O, O), scaling rows of the identity. + return ev_np[..., :, None] * np.eye(O_dim, dtype=ev_np.dtype) + + +@typechecked +def crop_R(R: np.ndarray, s_frames: list | None) -> np.ndarray: + """ + Crop time-varying R along its time axis using the same spec as crop_frames. + R_tv shape: (..., T, O, O) -> returns (..., T', O, O). + Assumes R_tv is diagonal (built via build_R_tv_from_vars) but works generically. 
+ """ + if not s_frames: + return np.asarray(R) + R_np = np.asarray(R) + leading = R_np.shape[:-3] # any leading batch dims + T, O, O2 = R_np.shape[-3:] + assert O == O2, "R_tv must be square in its last two dims" + # Flatten leading dims to crop time contiguous + R_flat = R_np.reshape((-1, T, O, O)) + cropped_list = [] + for block in R_flat: + cropped_list.append(crop_frames(block, s_frames)) # uses the same semantics + R_cropped = np.stack(cropped_list, axis=0) + return R_cropped.reshape((*leading, -1, O, O)) diff --git a/scripts/multicam_example.py b/scripts/multicam_example.py index aec05c2..1656a8f 100644 --- a/scripts/multicam_example.py +++ b/scripts/multicam_example.py @@ -27,7 +27,6 @@ verbose = True if args.verbose == 'True' else False inflate_vars = True if args.inflate_vars == 'True' else False n_latent = args.n_latent -backend = args.backend # Fit EKS using the provided input data camera_dfs, s_finals, input_dfs, bodypart_list = fit_eks_multicam( @@ -41,7 +40,6 @@ verbose=verbose, inflate_vars=inflate_vars, n_latent=args.n_latent, - backend=backend, ) # Plot results for a specific keypoint (default to last keypoint of last camera view) diff --git a/scripts/singlecam_example.py b/scripts/singlecam_example.py index f082d7c..c6f76d4 100644 --- a/scripts/singlecam_example.py +++ b/scripts/singlecam_example.py @@ -24,7 +24,6 @@ s_frames = args.s_frames # Frames to be used for automatic optimization if s is not provided blocks = args.blocks verbose = True if args.verbose == 'True' else False -backend = args.backend # Fit EKS using the provided input data output_df, s_finals, input_dfs, bodypart_list = fit_eks_singlecam( @@ -35,7 +34,6 @@ s_frames=s_frames, blocks=blocks, verbose=verbose, - backend=backend ) # Plot results for a specific keypoint (default to last keypoint) diff --git a/tests/test_multicam_smoother.py b/tests/test_multicam_smoother.py index 5a0ad1f..99f5721 100644 --- a/tests/test_multicam_smoother.py +++ b/tests/test_multicam_smoother.py @@ 
-45,9 +45,10 @@ def test_ensemble_kalman_smoother_multicam(): f"Expected {len(camera_names)} entries in camera_dfs, got {len(camera_dfs)}" assert isinstance(smooth_params_final, np.ndarray), \ f"Expected smooth_param_final to be an array, got {type(smooth_params_final)}" - assert smooth_params_final == smooth_param, \ - f"Expected smooth_param_final to match input smooth_param ({smooth_param}), " \ - f"got {smooth_params_final}" + for k in range(len(keypoint_names)): + assert smooth_params_final[c] == smooth_param, \ + f"Expected smooth_param_final to match input smooth_param ({smooth_param}), " \ + f"got {smooth_params_final}" # --------------------------------------------------- # Run with variance inflation @@ -69,9 +70,10 @@ def test_ensemble_kalman_smoother_multicam(): f"Expected {len(camera_names)} entries in camera_dfs, got {len(camera_dfs)}" assert isinstance(smooth_params_final, np.ndarray), \ f"Expected smooth_param_final to be an array, got {type(smooth_params_final)}" - assert smooth_params_final == smooth_param, \ - f"Expected smooth_param_final to match input smooth_param ({smooth_param}), " \ - f"got {smooth_params_final}" + for k in range(len(keypoint_names)): + assert smooth_params_final[c] == smooth_param, \ + f"Expected smooth_param_final to match input smooth_param ({smooth_param}), " \ + f"got {smooth_params_final}" # --------------------------------------------------- # Run with variance inflation + more maha kwargs @@ -95,9 +97,10 @@ def test_ensemble_kalman_smoother_multicam(): f"Expected {len(camera_names)} entries in camera_dfs, got {len(camera_dfs)}" assert isinstance(smooth_params_final, np.ndarray), \ f"Expected smooth_param_final to be an array, got {type(smooth_params_final)}" - assert smooth_params_final == smooth_param, \ - f"Expected smooth_param_final to match input smooth_param ({smooth_param}), " \ - f"got {smooth_params_final}" + for k in range(len(keypoint_names)): + assert smooth_params_final[c] == smooth_param, \ + 
f"Expected smooth_param_final to match input smooth_param ({smooth_param}), " \ + f"got {smooth_params_final}" # --------------------------------------------------- # Run with variance inflation + more maha kwargs diff --git a/tests/test_singlecam_smoother.py b/tests/test_singlecam_smoother.py index d8751fd..678fb40 100644 --- a/tests/test_singlecam_smoother.py +++ b/tests/test_singlecam_smoother.py @@ -41,7 +41,8 @@ def _check_outputs(df, params): blocks=blocks, ) _check_outputs(df_smoothed, s_finals) - assert s_finals == [smooth_param] + for k in range(len(keypoint_names)): + assert s_finals[k] == smooth_param # run with fixed smooth param (int) smooth_param = 5 @@ -53,7 +54,8 @@ def _check_outputs(df, params): blocks=blocks, ) _check_outputs(df_smoothed, s_finals) - assert s_finals == [smooth_param] + for k in range(len(keypoint_names)): + assert s_finals[k] == smooth_param # run with fixed smooth param (single-entry list) smooth_param = [0.1] @@ -65,7 +67,8 @@ def _check_outputs(df, params): blocks=blocks, ) _check_outputs(df_smoothed, s_finals) - assert s_finals == smooth_param + for k in range(len(keypoint_names)): + assert s_finals[k] == smooth_param # run with fixed smooth param (list) smooth_param = [0.1, 0.4] @@ -77,7 +80,8 @@ def _check_outputs(df, params): blocks=blocks, ) _check_outputs(df_smoothed, s_finals) - assert np.all(s_finals == smooth_param) + for k in range(len(keypoint_names)): + assert s_finals[k] == smooth_param[k] # run with None smooth param smooth_param = None From ac674145f7285e0da5bfd2eb8cba995ffbae7705 Mon Sep 17 00:00:00 2001 From: Keemin Lee Date: Fri, 3 Oct 2025 13:42:02 -0400 Subject: [PATCH 06/11] removed notebooks from git --- bb_crop.ipynb | 173 --------- bbox_uncrop.ipynb | 231 ----------- ekf.ipynb | 964 ---------------------------------------------- 3 files changed, 1368 deletions(-) delete mode 100644 bb_crop.ipynb delete mode 100644 bbox_uncrop.ipynb delete mode 100644 ekf.ipynb diff --git a/bb_crop.ipynb b/bb_crop.ipynb 
deleted file mode 100644 index 7e93793..0000000 --- a/bb_crop.ipynb +++ /dev/null @@ -1,173 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "aef7a1cd", - "metadata": {}, - "outputs": [], - "source": [ - "from pathlib import Path\n", - "from typing import List, Union\n", - "import pandas as pd\n", - "import re\n", - "\n", - "def _detect_frame_index_pose(df: pd.DataFrame) -> pd.DataFrame:\n", - " df = df.copy()\n", - " if 'scorer' in df.columns and pd.api.types.is_integer_dtype(df['scorer']):\n", - " return df.set_index('scorer', drop=True).rename_axis('frame')\n", - " for col in ('frame', 'index', 'Unnamed: 0'):\n", - " if col in df.columns and pd.api.types.is_integer_dtype(df[col]):\n", - " return df.set_index(col, drop=True).rename_axis('frame')\n", - " return df.reset_index(drop=True).rename_axis('frame')\n", - "\n", - "def _detect_frame_index_bbox(df: pd.DataFrame) -> pd.DataFrame:\n", - " df = df.copy()\n", - " if 'frame' in df.columns:\n", - " return df.set_index('frame', drop=True)\n", - " if 'Unnamed: 0' in df.columns:\n", - " return df.rename(columns={'Unnamed: 0': 'frame'}).set_index('frame', drop=True)\n", - " if not isinstance(df.index, pd.RangeIndex):\n", - " return df.reset_index(drop=False).rename(columns={'index': 'frame'}).set_index('frame', drop=True)\n", - " return df.rename_axis('frame')\n", - "\n", - "def _xy_like_multiindex_cols(columns: pd.MultiIndex):\n", - " x_cols, y_cols = [], []\n", - " for col in columns:\n", - " last = col[-1]\n", - " if isinstance(last, str):\n", - " if last.startswith('x') and 'var' not in last: x_cols.append(col)\n", - " if last.startswith('y') and 'var' not in last: y_cols.append(col)\n", - " return x_cols, y_cols\n", - "\n", - "def _xy_like_flat_cols(columns: pd.Index):\n", - " x_cols, y_cols = [], []\n", - " for c in columns.astype(str):\n", - " if 'var' in c.lower(): \n", - " continue\n", - " if c == 'x' or c.endswith('_x') or c.startswith('x_'): x_cols.append(c)\n", - " if c 
== 'y' or c.endswith('_y') or c.startswith('y_'): y_cols.append(c)\n", - " return x_cols, y_cols\n", - "\n", - "def translate_pose_by_bbox(pose_df: pd.DataFrame, bbox_df: pd.DataFrame, mode: str = \"subtract\") -> pd.DataFrame:\n", - " \"\"\"Map full-frame coords -> bbox-cropped coords (mode='subtract').\n", - " Use mode='add' to go back to full-frame.\"\"\"\n", - " pose_df = _detect_frame_index_pose(pose_df)\n", - " bbox_df = _detect_frame_index_bbox(bbox_df)\n", - " common = pose_df.index.intersection(bbox_df.index)\n", - " pose_df, bbox_df = pose_df.loc[common].copy(), bbox_df.loc[common].copy()\n", - " if not {'x','y'}.issubset(bbox_df.columns):\n", - " raise ValueError(f\"bbox_df must have 'x' and 'y'; got {bbox_df.columns}\")\n", - " sign = -1 if mode == \"subtract\" else 1\n", - "\n", - " if isinstance(pose_df.columns, pd.MultiIndex):\n", - " x_cols, y_cols = _xy_like_multiindex_cols(pose_df.columns)\n", - " else:\n", - " x_cols, y_cols = _xy_like_flat_cols(pose_df.columns)\n", - "\n", - " if x_cols: pose_df.loc[:, x_cols] = pose_df.loc[:, x_cols].add(sign * bbox_df['x'], axis=0)\n", - " if y_cols: pose_df.loc[:, y_cols] = pose_df.loc[:, y_cols].add(sign * bbox_df['y'], axis=0)\n", - " return pose_df\n", - "\n", - "def batch_translate_pose_csvs(\n", - " pose_csvs: List[Union[str, Path]],\n", - " bbox_csvs: List[Union[str, Path]],\n", - " output_dir: Union[str, Path],\n", - " mode: str = \"subtract\",\n", - " suffix: str = \"\",\n", - "):\n", - " if len(pose_csvs) != len(bbox_csvs):\n", - " raise ValueError(\"pose_csvs and bbox_csvs must be same length (one per view).\")\n", - " output_dir = Path(output_dir); output_dir.mkdir(parents=True, exist_ok=True)\n", - " outs = []\n", - " for pose_csv, bbox_csv in zip(pose_csvs, bbox_csvs):\n", - " # Try MultiIndex header first (Lightning Pose/EKS), else flat\n", - " try: pose_df = pd.read_csv(pose_csv, header=[0,1,2])\n", - " except Exception: pose_df = pd.read_csv(pose_csv)\n", - " bbox_df = 
pd.read_csv(bbox_csv)\n", - " translated = translate_pose_by_bbox(pose_df, bbox_df, mode=mode)\n", - " out_path = output_dir / f\"{Path(pose_csv).stem}.csv\"\n", - " translated.to_csv(out_path)\n", - " outs.append(out_path)\n", - " return outs" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "a1d8c557", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[WindowsPath('outputs/cropped_csvs/PRL43_200617_131904_lBack.short_cropped.csv')]\n" - ] - } - ], - "source": [ - "pose_csvs = [\"./outputs/chickadee-preds/video_preds/PRL43_200617_131904_lBack.short.csv\"] # one per view\n", - "bbox_csvs = [\"./data/bounding_boxes/PRL43_200617_131904_lBack.short_bbox.csv\"] # matching order\n", - "out_files = batch_translate_pose_csvs(pose_csvs, bbox_csvs, output_dir=\"./outputs/cropped_csvs\")\n", - "print(out_files)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d470ddc0", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Frames: 1800\n", - "FPS: 60.0\n", - "Duration (s): 30.000\n" - ] - } - ], - "source": [ - "import cv2\n", - "\n", - "path = \"./videos/chickadee/PRL43_200617_131904_lBack.short.mp4\" # update to your file path\n", - "cap = cv2.VideoCapture(path)\n", - "if not cap.isOpened():\n", - " raise RuntimeError(\"Could not open video\")\n", - "\n", - "frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n", - "fps = cap.get(cv2.CAP_PROP_FPS)\n", - "duration_sec = frame_count / fps if fps > 0 else None\n", - "\n", - "cap.release()\n", - "\n", - "print(f\"Frames: {frame_count}\")\n", - "print(f\"FPS: {fps}\")\n", - "print(f\"Duration (s): {duration_sec:.3f}\" if duration_sec else \"Duration unknown\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": 
".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/bbox_uncrop.ipynb b/bbox_uncrop.ipynb deleted file mode 100644 index 1cdfc88..0000000 --- a/bbox_uncrop.ipynb +++ /dev/null @@ -1,231 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "5b2692e2", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[INFO] CWD: E:\\eks\n", - "[INFO] preds_root: E:\\eks\\data\\chickadee\n", - "[INFO] bbox_root : E:\\eks\\data\\bounding_boxes\n", - "[INFO] output_dir: E:\\eks\\data\\chickadee_uncropped\n" - ] - }, - { - "ename": "IntCastingNaNError", - "evalue": "Cannot convert non-finite values (NA or inf) to integer", - "output_type": "error", - "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mIntCastingNaNError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[1;32mIn[6], line 150\u001b[0m\n\u001b[0;32m 147\u001b[0m bbox_root \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m./data/bounding_boxes\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;66;03m# one bbox CSV per camera (name contains camera)\u001b[39;00m\n\u001b[0;32m 148\u001b[0m output_dir \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m./data/chickadee_uncropped\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m--> 150\u001b[0m \u001b[43mprocess_all\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcamera_names\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpreds_root\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbbox_root\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_dir\u001b[49m\u001b[43m)\u001b[49m\n", - "Cell \u001b[1;32mIn[6], line 125\u001b[0m, in \u001b[0;36mprocess_all\u001b[1;34m(camera_names, 
preds_root, bbox_root, output_dir)\u001b[0m\n\u001b[0;32m 120\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[0;32m 121\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNeed exactly one bbox CSV for camera \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcam\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, found \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(bbox_matches)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 122\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m[\u001b[38;5;28mstr\u001b[39m(p)\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mfor\u001b[39;00m\u001b[38;5;250m \u001b[39mp\u001b[38;5;250m \u001b[39m\u001b[38;5;129;01min\u001b[39;00m\u001b[38;5;250m \u001b[39mbbox_matches]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 123\u001b[0m )\n\u001b[0;32m 124\u001b[0m bbox_path \u001b[38;5;241m=\u001b[39m bbox_matches[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m--> 125\u001b[0m df_bbox \u001b[38;5;241m=\u001b[39m \u001b[43m_read_bbox_csv\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbbox_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 127\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[INFO] Camera \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcam\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(preds_paths)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m preds, bbox=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mbbox_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 129\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m pred_path \u001b[38;5;129;01min\u001b[39;00m 
preds_paths:\n", - "Cell \u001b[1;32mIn[6], line 33\u001b[0m, in \u001b[0;36m_read_bbox_csv\u001b[1;34m(path)\u001b[0m\n\u001b[0;32m 30\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m:\n\u001b[0;32m 31\u001b[0m df \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mread_csv(p, header\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, names\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mframe\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mx\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124my\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mh\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m---> 33\u001b[0m df[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mframe\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[43mdf\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mframe\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mastype\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mint\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 34\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m c \u001b[38;5;129;01min\u001b[39;00m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mx\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124my\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mh\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m\"\u001b[39m]:\n\u001b[0;32m 35\u001b[0m df[c] \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mto_numeric(df[c], errors\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mraise\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", - "File 
\u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\pandas\\core\\generic.py:6640\u001b[0m, in \u001b[0;36mNDFrame.astype\u001b[1;34m(self, dtype, copy, errors)\u001b[0m\n\u001b[0;32m 6634\u001b[0m results \u001b[38;5;241m=\u001b[39m [\n\u001b[0;32m 6635\u001b[0m ser\u001b[38;5;241m.\u001b[39mastype(dtype, copy\u001b[38;5;241m=\u001b[39mcopy, errors\u001b[38;5;241m=\u001b[39merrors) \u001b[38;5;28;01mfor\u001b[39;00m _, ser \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mitems()\n\u001b[0;32m 6636\u001b[0m ]\n\u001b[0;32m 6638\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 6639\u001b[0m \u001b[38;5;66;03m# else, only a single dtype is given\u001b[39;00m\n\u001b[1;32m-> 6640\u001b[0m new_data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_mgr\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mastype\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcopy\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43merrors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43merrors\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 6641\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_constructor_from_mgr(new_data, axes\u001b[38;5;241m=\u001b[39mnew_data\u001b[38;5;241m.\u001b[39maxes)\n\u001b[0;32m 6642\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m res\u001b[38;5;241m.\u001b[39m__finalize__(\u001b[38;5;28mself\u001b[39m, method\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mastype\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", - "File 
\u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\pandas\\core\\internals\\managers.py:430\u001b[0m, in \u001b[0;36mBaseBlockManager.astype\u001b[1;34m(self, dtype, copy, errors)\u001b[0m\n\u001b[0;32m 427\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m using_copy_on_write():\n\u001b[0;32m 428\u001b[0m copy \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m--> 430\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mapply\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 431\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mastype\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m 432\u001b[0m \u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 433\u001b[0m \u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcopy\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 434\u001b[0m \u001b[43m \u001b[49m\u001b[43merrors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43merrors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 435\u001b[0m \u001b[43m \u001b[49m\u001b[43musing_cow\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43musing_copy_on_write\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 436\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\pandas\\core\\internals\\managers.py:363\u001b[0m, in \u001b[0;36mBaseBlockManager.apply\u001b[1;34m(self, f, align_keys, **kwargs)\u001b[0m\n\u001b[0;32m 361\u001b[0m applied \u001b[38;5;241m=\u001b[39m b\u001b[38;5;241m.\u001b[39mapply(f, 
\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 362\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m--> 363\u001b[0m applied \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mgetattr\u001b[39m(b, f)(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 364\u001b[0m result_blocks \u001b[38;5;241m=\u001b[39m extend_blocks(applied, result_blocks)\n\u001b[0;32m 366\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39mfrom_blocks(result_blocks, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maxes)\n", - "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\pandas\\core\\internals\\blocks.py:758\u001b[0m, in \u001b[0;36mBlock.astype\u001b[1;34m(self, dtype, copy, errors, using_cow, squeeze)\u001b[0m\n\u001b[0;32m 755\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCan not squeeze with more than one column.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 756\u001b[0m values \u001b[38;5;241m=\u001b[39m values[\u001b[38;5;241m0\u001b[39m, :] \u001b[38;5;66;03m# type: ignore[call-overload]\u001b[39;00m\n\u001b[1;32m--> 758\u001b[0m new_values \u001b[38;5;241m=\u001b[39m \u001b[43mastype_array_safe\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalues\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcopy\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43merrors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43merrors\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 760\u001b[0m new_values \u001b[38;5;241m=\u001b[39m maybe_coerce_values(new_values)\n\u001b[0;32m 762\u001b[0m refs \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;01mNone\u001b[39;00m\n", - "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\pandas\\core\\dtypes\\astype.py:237\u001b[0m, in \u001b[0;36mastype_array_safe\u001b[1;34m(values, dtype, copy, errors)\u001b[0m\n\u001b[0;32m 234\u001b[0m dtype \u001b[38;5;241m=\u001b[39m dtype\u001b[38;5;241m.\u001b[39mnumpy_dtype\n\u001b[0;32m 236\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m--> 237\u001b[0m new_values \u001b[38;5;241m=\u001b[39m \u001b[43mastype_array\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalues\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcopy\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 238\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mValueError\u001b[39;00m, \u001b[38;5;167;01mTypeError\u001b[39;00m):\n\u001b[0;32m 239\u001b[0m \u001b[38;5;66;03m# e.g. 
_astype_nansafe can fail on object-dtype of strings\u001b[39;00m\n\u001b[0;32m 240\u001b[0m \u001b[38;5;66;03m# trying to convert to float\u001b[39;00m\n\u001b[0;32m 241\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m errors \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mignore\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n", - "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\pandas\\core\\dtypes\\astype.py:182\u001b[0m, in \u001b[0;36mastype_array\u001b[1;34m(values, dtype, copy)\u001b[0m\n\u001b[0;32m 179\u001b[0m values \u001b[38;5;241m=\u001b[39m values\u001b[38;5;241m.\u001b[39mastype(dtype, copy\u001b[38;5;241m=\u001b[39mcopy)\n\u001b[0;32m 181\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m--> 182\u001b[0m values \u001b[38;5;241m=\u001b[39m \u001b[43m_astype_nansafe\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalues\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcopy\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 184\u001b[0m \u001b[38;5;66;03m# in pandas we don't store numpy str dtypes, so convert to object\u001b[39;00m\n\u001b[0;32m 185\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(dtype, np\u001b[38;5;241m.\u001b[39mdtype) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28missubclass\u001b[39m(values\u001b[38;5;241m.\u001b[39mdtype\u001b[38;5;241m.\u001b[39mtype, \u001b[38;5;28mstr\u001b[39m):\n", - "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\pandas\\core\\dtypes\\astype.py:101\u001b[0m, in \u001b[0;36m_astype_nansafe\u001b[1;34m(arr, dtype, copy, skipna)\u001b[0m\n\u001b[0;32m 96\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
lib\u001b[38;5;241m.\u001b[39mensure_string_array(\n\u001b[0;32m 97\u001b[0m arr, skipna\u001b[38;5;241m=\u001b[39mskipna, convert_na_value\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[0;32m 98\u001b[0m )\u001b[38;5;241m.\u001b[39mreshape(shape)\n\u001b[0;32m 100\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m np\u001b[38;5;241m.\u001b[39missubdtype(arr\u001b[38;5;241m.\u001b[39mdtype, np\u001b[38;5;241m.\u001b[39mfloating) \u001b[38;5;129;01mand\u001b[39;00m dtype\u001b[38;5;241m.\u001b[39mkind \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124miu\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m--> 101\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_astype_float_to_int_nansafe\u001b[49m\u001b[43m(\u001b[49m\u001b[43marr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 103\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m arr\u001b[38;5;241m.\u001b[39mdtype \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mobject\u001b[39m:\n\u001b[0;32m 104\u001b[0m \u001b[38;5;66;03m# if we have a datetime/timedelta array of objects\u001b[39;00m\n\u001b[0;32m 105\u001b[0m \u001b[38;5;66;03m# then coerce to datetime64[ns] and use DatetimeArray.astype\u001b[39;00m\n\u001b[0;32m 107\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m lib\u001b[38;5;241m.\u001b[39mis_np_dtype(dtype, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mM\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", - "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\pandas\\core\\dtypes\\astype.py:145\u001b[0m, in \u001b[0;36m_astype_float_to_int_nansafe\u001b[1;34m(values, dtype, copy)\u001b[0m\n\u001b[0;32m 141\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m 142\u001b[0m \u001b[38;5;124;03mastype with a check 
preventing converting NaN to an meaningless integer value.\u001b[39;00m\n\u001b[0;32m 143\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m 144\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m np\u001b[38;5;241m.\u001b[39misfinite(values)\u001b[38;5;241m.\u001b[39mall():\n\u001b[1;32m--> 145\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m IntCastingNaNError(\n\u001b[0;32m 146\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCannot convert non-finite values (NA or inf) to integer\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 147\u001b[0m )\n\u001b[0;32m 148\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dtype\u001b[38;5;241m.\u001b[39mkind \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mu\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m 149\u001b[0m \u001b[38;5;66;03m# GH#45151\u001b[39;00m\n\u001b[0;32m 150\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (values \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mall():\n", - "\u001b[1;31mIntCastingNaNError\u001b[0m: Cannot convert non-finite values (NA or inf) to integer" - ] - } - ], - "source": [ - "#!/usr/bin/env python3\n", - "\"\"\"\n", - "Un-crop DLC-style predictions using per-frame bounding boxes.\n", - "Recursively scans directories for prediction CSVs and bbox CSVs.\n", - "\"\"\"\n", - "\n", - "import os\n", - "import sys\n", - "from pathlib import Path\n", - "from typing import List\n", - "import pandas as pd\n", - "\n", - "\n", - "# --------------------------\n", - "# Helpers\n", - "# --------------------------\n", - "\n", - "def _read_dlc_csv(path: str) -> pd.DataFrame:\n", - " \"\"\"Read DLC-style CSV with 3-row multi-index header.\"\"\"\n", - " return pd.read_csv(path, header=[0, 1, 2])\n", - "\n", - "def _read_bbox_csv(path: str) -> pd.DataFrame:\n", - " \"\"\"\n", - " Read a headerless bbox CSV with 5 columns: frame, x, y, h, w.\n", - " 
Assumes first cell (A1) is blank, so no header row.\n", - " \"\"\"\n", - " import pandas as pd\n", - "\n", - " # Force header=None so the first row is treated as data\n", - " df = pd.read_csv(path, header=None, names=[\"frame\", \"x\", \"y\", \"h\", \"w\"])\n", - "\n", - " # Coerce numeric values\n", - " for c in [\"frame\", \"x\", \"y\", \"h\", \"w\"]:\n", - " df[c] = pd.to_numeric(df[c], errors=\"coerce\")\n", - "\n", - " # If frames look 1-based, shift to 0-based\n", - " if df[\"frame\"].min() == 1:\n", - " df[\"frame\"] = df[\"frame\"] - 1\n", - "\n", - " # Final cast to int for frame\n", - " df[\"frame\"] = df[\"frame\"].round().astype(int)\n", - "\n", - " return df.sort_values(\"frame\").reset_index(drop=True)\n", - "\n", - "def _transform_predictions(df_pred: pd.DataFrame, df_bbox: pd.DataFrame) -> pd.DataFrame:\n", - " \"\"\"Apply uncropping transform.\"\"\"\n", - " n_frames = len(df_pred)\n", - " merged = pd.DataFrame({\"frame\": range(n_frames)}).merge(df_bbox, on=\"frame\", how=\"left\")\n", - " if merged[[\"x\", \"y\", \"h\", \"w\"]].isna().any().any():\n", - " missing = merged[merged[[\"x\",\"y\",\"h\",\"w\"]].isna().any(axis=1)][\"frame\"].tolist()\n", - " raise ValueError(f\"Missing bbox entries for frames (first 10 shown): {missing[:10]}\")\n", - "\n", - " x_off, y_off, h, w = [merged[c].to_numpy() for c in [\"x\", \"y\", \"h\", \"w\"]]\n", - " out = df_pred.copy()\n", - "\n", - " lvl0 = out.columns.get_level_values(0).unique()\n", - " lvl1 = out.columns.get_level_values(1).unique()\n", - "\n", - " for scorer in lvl0:\n", - " for bp in lvl1:\n", - " if (scorer, bp, \"x\") in out.columns:\n", - " x_vals = pd.to_numeric(out[(scorer, bp, \"x\")], errors=\"coerce\").to_numpy()\n", - " out[(scorer, bp, \"x\")] = (x_vals / 320.0) * w + x_off\n", - " if (scorer, bp, \"y\") in out.columns:\n", - " y_vals = pd.to_numeric(out[(scorer, bp, \"y\")], errors=\"coerce\").to_numpy()\n", - " out[(scorer, bp, \"y\")] = (y_vals / 320.0) * h + y_off\n", - " return 
out\n", - "\n", - "def _derive_output_path(pred_path: Path, output_dir: Path | None) -> Path:\n", - " \"\"\"Output path with _uncropped.csv suffix.\"\"\"\n", - " root = pred_path.stem\n", - " out_name = f\"{root}_uncropped.csv\"\n", - " if output_dir is not None:\n", - " output_dir.mkdir(parents=True, exist_ok=True)\n", - " return output_dir / out_name\n", - " return pred_path.parent / out_name\n", - "\n", - "\n", - "# --------------------------\n", - "# Discovery\n", - "# --------------------------\n", - "\n", - "def _rglob_csvs(root: Path) -> list[Path]:\n", - " return [p for p in root.rglob(\"*.csv\") if p.is_file()]\n", - "\n", - "def _filter_by_cam(paths: list[Path], cam: str) -> list[Path]:\n", - " cam_l = cam.lower()\n", - " return [p for p in paths if cam_l in p.name.lower()]\n", - "\n", - "# --------------------------\n", - "# Main processing\n", - "# --------------------------\n", - "\n", - "def process_all(camera_names: List[str], preds_root: str, bbox_root: str, output_dir: str | None = None) -> None:\n", - " # Resolve to absolute paths relative to the kernel's CWD\n", - " preds_root_p = Path(preds_root).expanduser().resolve()\n", - " bbox_root_p = Path(bbox_root).expanduser().resolve()\n", - " output_dir_p = Path(output_dir).expanduser().resolve() if output_dir else None\n", - "\n", - " print(f\"[INFO] CWD: {Path.cwd().resolve()}\")\n", - " print(f\"[INFO] preds_root: {preds_root_p}\")\n", - " print(f\"[INFO] bbox_root : {bbox_root_p}\")\n", - " if output_dir_p: print(f\"[INFO] output_dir: {output_dir_p}\")\n", - "\n", - " if not preds_root_p.exists():\n", - " raise FileNotFoundError(f\"preds_root does not exist: {preds_root_p}\")\n", - " if not bbox_root_p.exists():\n", - " raise FileNotFoundError(f\"bbox_root does not exist: {bbox_root_p}\")\n", - "\n", - " # Discover all CSVs once (recursive)\n", - " all_pred_csvs = _rglob_csvs(preds_root_p)\n", - " all_bbox_csvs = _rglob_csvs(bbox_root_p)\n", - "\n", - " if not all_pred_csvs:\n", - " 
print(f\"[WARN] No prediction CSVs found under {preds_root_p}\", file=sys.stderr)\n", - " if not all_bbox_csvs:\n", - " print(f\"[WARN] No bbox CSVs found under {bbox_root_p}\", file=sys.stderr)\n", - "\n", - " for cam in camera_names:\n", - " preds_paths = _filter_by_cam(all_pred_csvs, cam)\n", - " if not preds_paths:\n", - " print(f\"[WARN] No prediction CSVs matched camera '{cam}'\", file=sys.stderr)\n", - " continue\n", - "\n", - " bbox_matches = _filter_by_cam(all_bbox_csvs, cam)\n", - " if len(bbox_matches) != 1:\n", - " raise ValueError(\n", - " f\"Need exactly one bbox CSV for camera '{cam}', found {len(bbox_matches)}: \"\n", - " f\"{[str(p) for p in bbox_matches]}\"\n", - " )\n", - " bbox_path = bbox_matches[0]\n", - " df_bbox = _read_bbox_csv(bbox_path)\n", - "\n", - " print(f\"[INFO] Camera '{cam}': {len(preds_paths)} preds, bbox={bbox_path}\")\n", - "\n", - " for pred_path in preds_paths:\n", - " try:\n", - " df_pred = _read_dlc_csv(pred_path)\n", - " except Exception as e:\n", - " print(f\"[WARN] Skipping (bad DLC header?): {pred_path} ({e})\", file=sys.stderr)\n", - " continue\n", - " df_out = _transform_predictions(df_pred, df_bbox)\n", - " out_path = _derive_output_path(pred_path, output_dir_p)\n", - " df_out.to_csv(out_path, index=False)\n", - " print(f\" → Saved: {out_path}\")\n", - "\n", - "# --------------------------\n", - "# Hard-coded config\n", - "# --------------------------\n", - "\n", - "if __name__ == \"__main__\":\n", - " camera_names = [\"lBack\", \"lFront\", \"lTop\", \"rBack\", \"rFront\", \"rTop\"]\n", - " preds_root = \"./data/chickadee\" # can be relative; resolved & printed\n", - " bbox_root = \"./data/bounding_boxes\" # one bbox CSV per camera (name contains camera)\n", - " output_dir = \"./data/chickadee_uncropped\"\n", - "\n", - " process_all(camera_names, preds_root, bbox_root, output_dir)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5e90a3c4", - "metadata": {}, - "outputs": [], - "source": [] - 
} - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/ekf.ipynb b/ekf.ipynb deleted file mode 100644 index 05eaffb..0000000 --- a/ekf.ipynb +++ /dev/null @@ -1,964 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "d3bfa09d", - "metadata": {}, - "outputs": [], - "source": [ - "# Pose Smoothing with Dynamax EKF\n", - "# We load ensemble 2D pose predictions from 6 cameras (A–F), compute ensemble variance for observation noise, \n", - "# triangulate a geometric 3D latent state using calibrated camera parameters, and apply the Extended Kalman Smoother (EKF) using Dynamax.\n", - "\n", - "import os\n", - "import numpy as np\n", - "import pandas as pd\n", - "from pathlib import Path\n", - "from glob import glob\n", - "\n", - "from aniposelib.boards import CharucoBoard\n", - "from aniposelib.cameras import CameraGroup" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "d199b1f6", - "metadata": {}, - "outputs": [], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "from jax import jit\n", - "\n", - "def _rodrigues(rvec):\n", - " \"\"\"OpenCV-style Rodrigues: rvec (3,) -> R (3,3).\"\"\"\n", - " theta = jnp.linalg.norm(rvec)\n", - " def small_angle(_):\n", - " # First-order approx: R ≈ I + [r]_x (good when theta ~ 0)\n", - " rx, ry, rz = rvec\n", - " K = jnp.array([[0.0, -rz, ry],\n", - " [rz, 0.0, -rx],\n", - " [-ry, rx, 0.0]])\n", - " return jnp.eye(3) + K\n", - " def general(_):\n", - " rx, ry, rz = rvec / theta\n", - " K = jnp.array([[0.0, -rz, ry],\n", - " [rz, 0.0, -rx],\n", - " [-ry, rx, 0.0]])\n", - " s = jnp.sin(theta)\n", - " 
c = jnp.cos(theta)\n", - " return jnp.eye(3) + s*K + (1.0 - c) * (K @ K)\n", - " return jax.lax.cond(theta < 1e-12, small_angle, general, operand=None)\n", - "\n", - "def _parse_dist(dist_coeffs):\n", - " \"\"\"\n", - " OpenCV pinhole distortion ordering:\n", - " [k1, k2, p1, p2, k3, k4, k5, k6, s1, s2, s3, s4, tx, ty] (tx,ty tilt optional)\n", - " We support up to s1..s4; tilt is ignored here.\n", - " \"\"\"\n", - " dc = jnp.pad(jnp.asarray(dist_coeffs, dtype=jnp.float64), (0, max(0, 14 - len(dist_coeffs)))) # length ≥ 14\n", - " k1, k2, p1, p2, k3, k4, k5, k6, s1, s2, s3, s4, tx, ty = [dc[i] for i in range(14)]\n", - " return dict(k1=k1, k2=k2, p1=p1, p2=p2, k3=k3, k4=k4, k5=k5, k6=k6, s1=s1, s2=s2, s3=s3, s4=s4)\n", - "\n", - "def make_jax_projection_fn(rvec, tvec, K, dist_coeffs):\n", - " \"\"\"\n", - " JAX-compatible replacement for cv2.projectPoints (standard pinhole model).\n", - "\n", - " Args\n", - " ----\n", - " rvec : (3,) Rodrigues rotation vector (world -> camera)\n", - " tvec : (3,) translation (world -> camera), same units as your world coords\n", - " K : (3,3) camera intrinsic matrix\n", - " [[fx, s, cx],\n", - " [ 0, fy, cy],\n", - " [ 0, 0, 1 ]]\n", - " dist_coeffs : iterable of distortion coefficients in OpenCV order\n", - " [k1, k2, p1, p2[, k3[, k4, k5, k6[, s1, s2, s3, s4[, tx, ty]]]]]\n", - "\n", - " Returns\n", - " -------\n", - " project(object_points) -> image_points\n", - " object_points: (..., 3)\n", - " image_points: (..., 2)\n", - " \"\"\"\n", - " # cache params as arrays\n", - " rvec = jnp.asarray(rvec, dtype=jnp.float64)\n", - " tvec = jnp.asarray(tvec, dtype=jnp.float64)\n", - " K = jnp.asarray(K, dtype=jnp.float64)\n", - " fx, fy, cx, cy, skew = K[0,0], K[1,1], K[0,2], K[1,2], K[0,1]\n", - " d = _parse_dist(dist_coeffs)\n", - " R = _rodrigues(rvec)\n", - "\n", - " @jit\n", - " def project(object_points):\n", - " # object_points: (..., 3)\n", - " Xw = jnp.asarray(object_points, dtype=jnp.float64)\n", - " # world -> camera\n", - " Xc 
= Xw @ R.T + tvec # (..., 3)\n", - " X, Y, Z = Xc[..., 0], Xc[..., 1], Xc[..., 2]\n", - "\n", - " # normalized coords\n", - " x = X / Z\n", - " y = Y / Z\n", - "\n", - " r2 = x*x + y*y\n", - " r4 = r2*r2\n", - " r6 = r4*r2\n", - " r8 = r4*r4\n", - " r10 = r8*r2\n", - " r12 = r6*r6\n", - "\n", - " radial = (\n", - " 1.0\n", - " + d[\"k1\"]*r2 + d[\"k2\"]*r4 + d[\"k3\"]*r6\n", - " + d[\"k4\"]*r8 + d[\"k5\"]*r10 + d[\"k6\"]*r12\n", - " )\n", - "\n", - " x_tan = 2.0*d[\"p1\"]*x*y + d[\"p2\"]*(r2 + 2.0*x*x)\n", - " y_tan = d[\"p1\"]*(r2 + 2.0*y*y) + 2.0*d[\"p2\"]*x*y\n", - "\n", - " # thin-prism\n", - " x_tp = d[\"s1\"]*r2 + d[\"s2\"]*r4\n", - " y_tp = d[\"s3\"]*r2 + d[\"s4\"]*r4\n", - "\n", - " xd = x * radial + x_tan + x_tp\n", - " yd = y * radial + y_tan + y_tp\n", - "\n", - " # intrinsics (allow nonzero skew)\n", - " u = fx * xd + skew * yd + cx\n", - " v = fy * yd + cy\n", - "\n", - " return jnp.stack([u, v], axis=-1) # (..., 2)\n", - "\n", - " return project" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "35029956", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import cv2\n", - "import jax.numpy as jnp\n", - "from jax import jit, vmap\n", - "import numpy as np\n", - "import pandas as pd\n", - "from aniposelib.cameras import CameraGroup\n", - "from sklearn.decomposition import PCA\n", - "from typeguard import typechecked\n", - "from typing import Tuple, Callable\n", - "from eks.core import ensemble\n", - "from eks.marker_array import (\n", - " MarkerArray,\n", - " input_dfs_to_markerArray,\n", - " mA_to_stacked_array,\n", - " stacked_array_to_mA,\n", - ")\n", - "import jax\n", - "jax.config.update(\"jax_enable_x64\", True)\n", - "from eks.stats import compute_mahalanobis, compute_pca\n", - "from eks.utils import center_predictions, format_data, make_dlc_pandas_index\n", - "from eks.multicam_smoother import mA_compute_maha, initialize_kalman_filter_pca\n", - "\n", - "def fit_eks_multicam(\n", - " input_source: str | 
list,\n", - " save_dir: str,\n", - " bodypart_list: list | None = None,\n", - " smooth_param: float | list | None = None,\n", - " s_frames: list | None = None,\n", - " camera_names: list | None = None,\n", - " quantile_keep_pca: float = 95.0,\n", - " avg_mode: str = 'median',\n", - " var_mode: str = 'confidence_weighted_var',\n", - " inflate_vars: bool = False,\n", - " verbose: bool = False,\n", - " n_latent: int = 3,\n", - " backend: str = 'jax',\n", - " camgroup=None\n", - ") -> tuple:\n", - "\n", - " # Load and format input files\n", - " # NOTE: input_dfs_list is a list of camera-specific lists of Dataframes\n", - " input_dfs_list, keypoint_names = format_data(input_source, camera_names=camera_names)\n", - " if bodypart_list is None:\n", - " bodypart_list = keypoint_names\n", - "\n", - " marker_array = input_dfs_to_markerArray(input_dfs_list, bodypart_list, camera_names)\n", - "\n", - " # Run the ensemble Kalman smoother for multi-camera data\n", - " camera_dfs, smooth_params_final, h_cams, ys_3d = ensemble_kalman_smoother_multicam(\n", - " marker_array=marker_array,\n", - " keypoint_names=bodypart_list,\n", - " smooth_param=smooth_param,\n", - " quantile_keep_pca=quantile_keep_pca,\n", - " camera_names=camera_names,\n", - " s_frames=s_frames,\n", - " avg_mode=avg_mode,\n", - " var_mode=var_mode,\n", - " verbose=verbose,\n", - " inflate_vars=inflate_vars,\n", - " n_latent=n_latent,\n", - " backend=backend,\n", - " camgroup=camgroup\n", - " )\n", - " # Save output DataFrames to CSVs (one per camera view)\n", - " os.makedirs(save_dir, exist_ok=True)\n", - " for c, camera in enumerate(camera_names):\n", - " save_filename = f'multicam_{camera}_results.csv'\n", - " camera_dfs[c].to_csv(os.path.join(save_dir, save_filename))\n", - " return camera_dfs, smooth_params_final, input_dfs_list, bodypart_list, marker_array, h_cams, ys_3d\n", - "\n", - "def initialize_kalman_filter_geometric(ys: np.ndarray) -> Tuple[jnp.ndarray, ...]:\n", - " \"\"\"\n", - " Initialize Kalman 
filter parameters for geometric (3D) keypoints.\n", - "\n", - " Args:\n", - " ys: Array of shape (K, T, 3) — triangulated keypoints.\n", - "\n", - " Returns:\n", - " Tuple of Kalman filter parameters:\n", - " - m0s: (K, 3) initial means\n", - " - S0s: (K, 3, 3) initial covariances\n", - " - As: (K, 3, 3) transition matrices\n", - " - Qs: (K, 3, 3) process noise covariances\n", - " - Cs: (K, 3, 3) observation matrices\n", - " \"\"\"\n", - " K, T, D = ys.shape\n", - "\n", - " # Initial state means (can also use ys[:, 0, :] if preferred)\n", - " m0s = np.zeros((K, D))\n", - " # Use variance across time to estimate initial uncertainty\n", - " S0s = np.array([\n", - " np.diag([\n", - " np.nanvar(ys[k, :, d]) + 1e-4 # avoid degenerate matrices\n", - " for d in range(D)\n", - " ])\n", - " for k in range(K)\n", - " ]) # (K, 3, 3)\n", - "\n", - " # Identity matrices\n", - " As = np.tile(np.eye(D), (K, 1, 1))\n", - " Cs = np.tile(np.eye(D), (K, 1, 1))\n", - " Qs = np.tile(np.eye(D), (K, 1, 1)) * 1e-3 # small default process noise\n", - "\n", - " return (\n", - " jnp.array(m0s),\n", - " jnp.array(S0s),\n", - " jnp.array(As),\n", - " jnp.array(Qs),\n", - " jnp.array(Cs),\n", - " )\n", - "\n", - "\n", - "def ensemble_kalman_smoother_multicam(\n", - " marker_array: MarkerArray,\n", - " keypoint_names: list,\n", - " smooth_param: float | list | None = None,\n", - " quantile_keep_pca: float = 95.0,\n", - " camera_names: list | None = None,\n", - " s_frames: list | None = None,\n", - " avg_mode: str = 'median',\n", - " var_mode: str = 'confidence_weighted_var',\n", - " inflate_vars: bool = False,\n", - " inflate_vars_kwargs: dict = {},\n", - " verbose: bool = False,\n", - " pca_object: PCA | None = None,\n", - " n_latent: int = 3,\n", - " backend: str = 'jax',\n", - " camgroup=None,\n", - ") -> tuple:\n", - "\n", - " n_models, n_cameras, n_frames, n_keypoints, _ = marker_array.shape\n", - "\n", - " # === Ensemble Mean/Var per camera/keypoint ===\n", - " ensemble_marker_array = 
ensemble(marker_array, avg_mode=avg_mode, var_mode=var_mode)\n", - " emA_unsmoothed_preds = ensemble_marker_array.slice_fields(\"x\", \"y\")\n", - " emA_vars = ensemble_marker_array.slice_fields(\"var_x\", \"var_y\")\n", - " emA_likes = ensemble_marker_array.slice_fields(\"likelihood\")\n", - "\n", - " # === Triangulate all 3D positions ===\n", - " triangulated_3d_models = np.zeros((n_models, n_keypoints, n_frames, 3))\n", - " raw_array = marker_array.get_array()\n", - " for m in range(n_models):\n", - " for k in range(n_keypoints):\n", - " for t in range(n_frames):\n", - " xy_views = [raw_array[m, c, t, k, :2] for c in range(n_cameras)]\n", - " triangulated_3d_models[m, k, t] = camgroup.triangulate(np.array(xy_views))\n", - "\n", - " ys_3d = triangulated_3d_models.mean(axis=0) # (K, T, 3)\n", - " ensemble_vars_3d = triangulated_3d_models.var(axis=0) # (K, T, 3)\n", - "\n", - " # === Define a single multi-view h_fn (ℝ³ → ℝ^{2V}) ===\n", - " h_cams = []\n", - " for cam in camgroup.cameras:\n", - " print(cam.get_size())\n", - " rot = np.array(cam.get_rotation())\n", - " # Convert to Rodrigues vector if needed\n", - " rvec = cv2.Rodrigues(rot)[0].ravel() if rot.shape == (3, 3) else rot.ravel()\n", - " tvec = np.array(cam.get_translation()).ravel()\n", - " K = np.array(cam.get_camera_matrix())\n", - " dist = np.array(cam.get_distortions()).ravel() # distortion coeffs: k1,k2,p1,p2,k3,...\n", - "\n", - " h_cams.append(\n", - " make_jax_projection_fn(\n", - " jnp.array(rvec),\n", - " jnp.array(tvec),\n", - " jnp.array(K),\n", - " jnp.array(dist)\n", - " )\n", - " )\n", - "\n", - " def make_combined_h_fn(h_list):\n", - " def h_fn(x):\n", - " return jnp.concatenate([h(x) for h in h_list], axis=0)\n", - " return h_fn\n", - "\n", - " h_fn_combined = make_combined_h_fn(h_cams)\n", - "\n", - " # === Initialize Kalman filter ===\n", - "\n", - " m0s, S0s, As, cov_mats, Cs = initialize_kalman_filter_geometric(ys_3d)\n", - " m0s = np.array([ys_3d[k, :10].mean(axis=0) for k in 
range(n_keypoints)])\n", - " s_finals = np.full(len(keypoint_names), smooth_param) if np.isscalar(smooth_param) else np.asarray(smooth_param)\n", - "\n", - " # === Apply EKF in latent 3D space using projected 2D observations ===\n", - " ms_all, Vs_all = [], []\n", - " for k in range(n_keypoints):\n", - " y_proj = np.concatenate([vmap(h)(ys_3d[k]) for h in h_cams], axis=1) # (T, 2V)\n", - " r_proj = np.concatenate([ensemble_vars_3d[k][:, :2] for _ in range(n_cameras)], axis=1) # (T, 2V)\n", - " \n", - " ms, Vs = dynamax_ekf_smooth_routine(\n", - " y=ys_3d[k],\n", - " m0=m0s[k],\n", - " S0=S0s[k],\n", - " A=As[k],\n", - " Q=s_finals[k] * cov_mats[k],\n", - " C=np.eye(3),\n", - " ensemble_vars=ensemble_vars_3d[k],\n", - " f_fn=None,\n", - " h_fn=None, \n", - " )\n", - "\n", - " # ms, Vs = dynamax_ekf_smooth_routine(\n", - " # y=y_proj,\n", - " # m0=m0s[k],\n", - " # S0=S0s[k],\n", - " # A=As[k],\n", - " # Q=s_finals[k] * cov_mats[k],\n", - " # C=None,\n", - " # ensemble_vars=r_proj,\n", - " # f_fn=None,\n", - " # h_fn=h_fn_combined, \n", - " # )\n", - "\n", - "\n", - " ms_all.append(np.array(ms))\n", - " Vs_all.append(np.array(Vs))\n", - "\n", - " ms_all = np.stack(ms_all, axis=0) # (K, T, 3)\n", - " Vs_all = np.stack(Vs_all, axis=0) # (K, T, 3, 3)\n", - "\n", - "\n", - " # === Reproject smoothed 3D estimates back to each camera ===\n", - " camera_arrs = [[] for _ in camera_names]\n", - " for k, keypoint in enumerate(keypoint_names):\n", - " ms_k = ms_all[k]\n", - " Vs_k = Vs_all[k]\n", - " inflated_vars_k = ensemble_vars_3d[k]\n", - " \n", - " # rebuild a no-distortion projector per cam using the same rvec,tvec,K\n", - " print(\"camgroup order:\", [getattr(cam, \"name\", f\"cam{i}\") for i,cam in enumerate(camgroup.cameras)])\n", - " print(\"marker_array order:\", marker_array.get_camera_names() if hasattr(marker_array, \"get_camera_names\") else \"unknown\")\n", - "\n", - " # Compare one frame k=0,t=0\n", - " k=0; t=0\n", - " for c in 
range(len(camgroup.cameras)):\n", - " obs = emA_unsmoothed_preds.slice(\"keypoints\", k).slice(\"cameras\", c).get_array(squeeze=True)[t]\n", - " prj = np.array(h_cams[c](ms_all[k][t]))\n", - " print(f\"c{c}: obs={obs}, proj={prj}, diff={obs-prj}\")\n", - " \n", - " for c, camera in enumerate(camgroup.cameras):\n", - " #xy_proj = camera.project(ms_k).reshape(-1, 2)\n", - " xy_proj = np.array(vmap(h_cams[c])(ms_k)) # (T, 2)\n", - " xy_obs = emA_unsmoothed_preds.slice(\"keypoints\", k).slice(\"cameras\", c).get_array(squeeze=True) # (T,2)\n", - " resid = xy_obs - xy_proj # (T,2)\n", - " print(f\"cam {c} mean residual (px):\", resid.mean(axis=0), \" std:\", resid.std(axis=0))\n", - " try:\n", - " cov2d_proj = camera.project_covariance(ms_k, Vs_k)\n", - " var_x = cov2d_proj[:, 0, 0] + inflated_vars_k[:, 0]\n", - " var_y = cov2d_proj[:, 1, 1] + inflated_vars_k[:, 1]\n", - " except AttributeError:\n", - " var_x = np.full(ms_k.shape[0], np.nan)\n", - " var_y = np.full(ms_k.shape[0], np.nan)\n", - "\n", - " data_arr = camera_arrs[c]\n", - " data_arr.extend([\n", - " xy_proj[:, 0],\n", - " xy_proj[:, 1],\n", - " emA_likes.slice(\"keypoints\", k).slice(\"cameras\", c).get_array(squeeze=True),\n", - " emA_unsmoothed_preds.slice(\"keypoints\", k).slice(\"cameras\", c).slice_fields(\"x\").get_array(squeeze=True),\n", - " emA_unsmoothed_preds.slice(\"keypoints\", k).slice(\"cameras\", c).slice_fields(\"y\").get_array(squeeze=True),\n", - " emA_vars.slice(\"keypoints\", k).slice(\"cameras\", c).slice_fields(\"var_x\").get_array(squeeze=True),\n", - " emA_vars.slice(\"keypoints\", k).slice(\"cameras\", c).slice_fields(\"var_y\").get_array(squeeze=True),\n", - " var_x,\n", - " var_y,\n", - " ])\n", - "\n", - " # === Format output ===\n", - " labels = ['x', 'y', 'likelihood', 'x_ens_median', 'y_ens_median',\n", - " 'x_ens_var', 'y_ens_var', 'x_posterior_var', 'y_posterior_var']\n", - " pdindex = make_dlc_pandas_index(keypoint_names, labels=labels)\n", - " camera_dfs = 
[pd.DataFrame(np.asarray(arr).T, columns=pdindex) for arr in camera_arrs]\n", - "\n", - " return camera_dfs, s_finals, h_cams, ys_3d\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "80c930d4", - "metadata": {}, - "outputs": [], - "source": [ - "from dynamax.nonlinear_gaussian_ssm.inference_ekf import extended_kalman_smoother, extended_kalman_filter\n", - "from dynamax.nonlinear_gaussian_ssm.models import (\n", - " ParamsNLGSSM,\n", - ")\n", - "\n", - "import jax\n", - "import jax.numpy as jnp\n", - "import numpy as np\n", - "from typing import Union, Tuple, Callable\n", - "from typeguard import typechecked\n", - "\n", - "ArrayLike = Union[np.ndarray, jax.Array]\n", - "\n", - "def dynamax_ekf_smooth_routine(\n", - " y: ArrayLike,\n", - " m0: ArrayLike,\n", - " S0: ArrayLike,\n", - " A: ArrayLike,\n", - " Q: ArrayLike,\n", - " C: ArrayLike | None,\n", - " ensemble_vars: ArrayLike, # shape (T, obs_dim)\n", - " f_fn: Callable[[jnp.ndarray], jnp.ndarray] | None = None,\n", - " h_fn: Callable[[jnp.ndarray], jnp.ndarray] | None = None\n", - ") -> Tuple[jnp.ndarray, jnp.ndarray]:\n", - " \"\"\"\n", - " Extended Kalman smoother using the Dynamax nonlinear interface,\n", - " allowing for time-varying observation noise.\n", - "\n", - " By default, uses linear dynamics and emissions: f(x) = Ax, h(x) = Cx.\n", - "\n", - " Args:\n", - " y: (T, obs_dim) observation sequence.\n", - " m0: (state_dim,) initial mean.\n", - " S0: (state_dim, state_dim) initial covariance.\n", - " A: (state_dim, state_dim) dynamics matrix.\n", - " Q: (state_dim, state_dim) process noise covariance.\n", - " C: (obs_dim, state_dim) emission matrix (optional).\n", - " ensemble_vars: (T, obs_dim) per-timestep observation noise variance.\n", - " f_fn: optional dynamics function f(x).\n", - " h_fn: optional emission function h(x).\n", - "\n", - " Returns:\n", - " smoothed_means: (T, state_dim)\n", - " smoothed_covariances: (T, state_dim, state_dim)\n", - " \"\"\"\n", - " y, m0, 
S0, A, Q, ensemble_vars = map(jnp.asarray, (y, m0, S0, A, Q, ensemble_vars))\n", - " C = jnp.asarray(C) if C is not None else None\n", - "\n", - " if f_fn is None:\n", - " f_fn = lambda x: A @ x\n", - " if h_fn is None:\n", - " if C is None:\n", - " raise ValueError(\"Must provide either emission matrix C or a nonlinear emission function h_fn.\")\n", - " h_fn = lambda x: C @ x\n", - " # Dynamically determine obs_dim from h_fn output\n", - " obs_dim = y.shape[1]\n", - " R_t = jnp.stack([jnp.diag(var_t[:obs_dim]) for var_t in ensemble_vars], axis=0) # shape (T, obs_dim, obs_dim)\n", - " params = ParamsNLGSSM(\n", - " initial_mean=m0,\n", - " initial_covariance=S0,\n", - " dynamics_function=f_fn,\n", - " dynamics_covariance=Q,\n", - " emission_function=h_fn,\n", - " emission_covariance=R_t,\n", - " )\n", - " #with jax.disable_jit():\n", - " posterior = extended_kalman_smoother(params, y)\n", - " return posterior.filtered_means, posterior.filtered_covariances" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "49d1e9fc", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[2816, 1408]\n", - "[2816, 1408]\n", - "[2816, 1696]\n", - "[2816, 1408]\n", - "[2816, 1408]\n", - "[2816, 1696]\n", - "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", - "marker_array order: unknown\n", - "c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", - "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", - "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", - "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", - "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", - "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 
544.48723124], diff=[-8.61965526 12.3116419 ]\n", - "cam 0 mean residual (px): [-2.00586461 -0.04818341] std: [4.65336109 3.65989187]\n", - "cam 1 mean residual (px): [-0.400378 -0.55426035] std: [4.47750558 4.57930606]\n", - "cam 2 mean residual (px): [-4.25086082 4.0237834 ] std: [8.05280531 7.82694513]\n", - "cam 3 mean residual (px): [-2.11519798 -1.08246869] std: [3.01945266 2.22987753]\n", - "cam 4 mean residual (px): [-2.22338437 -1.25190865] std: [4.80872571 3.11726965]\n", - "cam 5 mean residual (px): [-2.02939332 0.71905944] std: [3.34086571 2.34249222]\n", - "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", - "marker_array order: unknown\n", - "c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", - "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", - "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", - "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", - "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", - "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", - "cam 0 mean residual (px): [-16.86640789 27.81901492] std: [21.06616917 10.16015718]\n", - "cam 1 mean residual (px): [ 1.27670095 31.39372914] std: [34.65760412 12.1620663 ]\n", - "cam 2 mean residual (px): [-16.98771473 14.33849501] std: [23.61536901 17.70184138]\n", - "cam 3 mean residual (px): [-4.87177248 25.47695466] std: [24.57880916 7.24002574]\n", - "cam 4 mean residual (px): [13.91017088 28.7900739 ] std: [27.64924222 9.96280995]\n", - "cam 5 mean residual (px): [ 6.45919793 28.26663876] std: [26.92549736 14.30860297]\n", - "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", - "marker_array order: 
unknown\n", - "c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", - "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", - "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", - "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", - "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", - "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", - "cam 0 mean residual (px): [-43.94862492 38.07620084] std: [39.50286349 37.61951063]\n", - "cam 1 mean residual (px): [-22.92389838 34.12073141] std: [60.91164161 53.11203159]\n", - "cam 2 mean residual (px): [-52.44768635 10.2830962 ] std: [37.41010619 48.59247171]\n", - "cam 3 mean residual (px): [ 6.26194338 35.10307059] std: [43.47138297 32.47832518]\n", - "cam 4 mean residual (px): [50.13023656 32.6284639 ] std: [40.05778705 44.31061814]\n", - "cam 5 mean residual (px): [35.77109648 40.20139873] std: [39.75000321 38.20310405]\n", - "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", - "marker_array order: unknown\n", - "c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", - "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", - "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", - "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", - "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", - "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", - 
"cam 0 mean residual (px): [-30.25113722 -40.99577707] std: [28.65042933 44.07906918]\n", - "cam 1 mean residual (px): [-34.98391788 -64.6586051 ] std: [41.16062318 54.55639073]\n", - "cam 2 mean residual (px): [-40.71146852 -33.90442907] std: [28.85536403 48.46962482]\n", - "cam 3 mean residual (px): [ 12.82158102 -37.16530327] std: [31.12152066 36.56094066]\n", - "cam 4 mean residual (px): [ 35.36500756 -54.25401178] std: [24.95655536 42.88833274]\n", - "cam 5 mean residual (px): [ 30.04054017 -24.99231918] std: [25.92538582 32.81484657]\n", - "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", - "marker_array order: unknown\n", - "c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", - "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", - "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", - "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", - "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", - "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", - "cam 0 mean residual (px): [-77.03501222 -14.06791509] std: [49.55689821 54.99093352]\n", - "cam 1 mean residual (px): [-121.43571681 -59.02314783] std: [72.04167148 74.07204966]\n", - "cam 2 mean residual (px): [-127.85352757 -25.78346588] std: [55.91121001 66.00386657]\n", - "cam 3 mean residual (px): [ 56.73325089 -15.44272932] std: [49.70291459 45.57550667]\n", - "cam 4 mean residual (px): [118.81426752 -50.95219002] std: [53.48983305 57.99132334]\n", - "cam 5 mean residual (px): [107.70158039 -10.72051293] std: [51.63925117 43.73291347]\n", - "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", - "marker_array order: unknown\n", - 
"c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", - "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", - "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", - "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", - "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", - "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", - "cam 0 mean residual (px): [-107.73449426 -25.13505912] std: [75.17239777 57.95634324]\n", - "cam 1 mean residual (px): [-265.44229017 -111.57601448] std: [122.9936259 83.63128374]\n", - "cam 2 mean residual (px): [-226.29392882 -25.78944817] std: [96.08242022 79.07648141]\n", - "cam 3 mean residual (px): [127.76895727 -33.62190642] std: [79.13346773 49.20533222]\n", - "cam 4 mean residual (px): [ 214.87055494 -109.10578232] std: [89.70933629 69.53311219]\n", - "cam 5 mean residual (px): [206.38994512 -45.90479289] std: [90.30455758 59.17178075]\n", - "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", - "marker_array order: unknown\n", - "c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", - "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", - "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", - "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", - "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", - "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", - 
"cam 0 mean residual (px): [-20.05716054 15.86566925] std: [15.31337375 8.3487262 ]\n", - "cam 1 mean residual (px): [-1.81428365 15.14691077] std: [27.12656552 11.97268196]\n", - "cam 2 mean residual (px): [-19.78322016 4.58641567] std: [18.58410842 14.75168836]\n", - "cam 3 mean residual (px): [-3.86281001 15.00418717] std: [19.08585247 6.15065234]\n", - "cam 4 mean residual (px): [17.14123839 15.34504954] std: [22.33711894 10.23454471]\n", - "cam 5 mean residual (px): [ 8.98992669 20.32010298] std: [21.78587376 11.17949793]\n", - "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", - "marker_array order: unknown\n", - "c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", - "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", - "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", - "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", - "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", - "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", - "cam 0 mean residual (px): [-31.64188096 3.83078018] std: [15.8083754 29.82208695]\n", - "cam 1 mean residual (px): [-10.15232963 -4.38954052] std: [28.9432177 39.41426118]\n", - "cam 2 mean residual (px): [-31.68059102 -9.37822028] std: [18.92196291 30.33923798]\n", - "cam 3 mean residual (px): [-1.09225093 4.94040225] std: [20.64178028 25.81219085]\n", - "cam 4 mean residual (px): [30.50141194 0.16262853] std: [18.36859319 33.19362351]\n", - "cam 5 mean residual (px): [18.75977526 14.56760995] std: [20.02123306 24.32289152]\n", - "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", - "marker_array order: unknown\n", - "c0: obs=[1750.93206108 
391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", - "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", - "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", - "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", - "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", - "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", - "cam 0 mean residual (px): [-45.63755893 7.96823886] std: [23.85887077 47.18275982]\n", - "cam 1 mean residual (px): [-16.02835436 -3.75159711] std: [39.33860605 61.70858653]\n", - "cam 2 mean residual (px): [-45.67716142 -13.88146324] std: [25.67134238 44.52143834]\n", - "cam 3 mean residual (px): [0.24121305 9.63683264] std: [28.12789429 41.08081596]\n", - "cam 4 mean residual (px): [46.60601802 2.736363 ] std: [24.2237432 51.62540431]\n", - "cam 5 mean residual (px): [29.43244706 22.25706045] std: [25.84400445 37.5825051 ]\n", - "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", - "marker_array order: unknown\n", - "c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", - "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", - "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", - "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", - "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", - "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", - "cam 0 mean residual (px): [-53.01464164 
-54.71981158] std: [35.82581259 50.09749436]\n", - "cam 1 mean residual (px): [-66.18716647 -95.0638527 ] std: [51.59764782 65.77213503]\n", - "cam 2 mean residual (px): [-73.51290185 -50.69122393] std: [39.62568939 55.19435951]\n", - "cam 3 mean residual (px): [ 27.11963099 -49.21706219] std: [36.57808591 41.87582014]\n", - "cam 4 mean residual (px): [ 68.78763228 -78.53982291] std: [37.86301571 52.00905446]\n", - "cam 5 mean residual (px): [ 58.83761913 -33.31484122] std: [36.69507741 37.25871558]\n", - "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", - "marker_array order: unknown\n", - "c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", - "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", - "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", - "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", - "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", - "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", - "cam 0 mean residual (px): [-27.48803853 -68.9780643 ] std: [28.86503036 51.00612907]\n", - "cam 1 mean residual (px): [-23.11555466 -96.6514439 ] std: [41.02268415 66.51927875]\n", - "cam 2 mean residual (px): [-28.40975361 -54.02259942] std: [27.75373633 53.37846493]\n", - "cam 3 mean residual (px): [ 4.31867099 -60.81992412] std: [30.67331701 42.8880862 ]\n", - "cam 4 mean residual (px): [ 26.01591817 -79.63224592] std: [27.38312343 52.87647067]\n", - "cam 5 mean residual (px): [ 18.82322649 -40.97111631] std: [26.47418793 36.46591139]\n", - "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", - "marker_array order: unknown\n", - "c0: obs=[1750.93206108 391.50040874], 
proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", - "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", - "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", - "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", - "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", - "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", - "cam 0 mean residual (px): [-10.83391517 17.11601076] std: [18.66345732 9.56010962]\n", - "cam 1 mean residual (px): [-5.64380896 18.17178012] std: [25.47195893 12.95737098]\n", - "cam 2 mean residual (px): [-15.81379777 12.05263804] std: [16.14768232 18.77020813]\n", - "cam 3 mean residual (px): [ 0.36446203 14.60267179] std: [19.09077463 7.12737293]\n", - "cam 4 mean residual (px): [ 9.74926083 15.46465836] std: [20.28911007 10.00831862]\n", - "cam 5 mean residual (px): [ 7.08709365 15.35813022] std: [18.31135053 13.90714397]\n", - "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", - "marker_array order: unknown\n", - "c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", - "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", - "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", - "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", - "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", - "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", - "cam 0 mean residual (px): [-14.70415945 5.94012917] std: 
[26.89201462 31.44245168]\n", - "cam 1 mean residual (px): [-18.58007466 1.22947655] std: [33.45409109 40.48061898]\n", - "cam 2 mean residual (px): [-25.29845403 4.64110098] std: [20.5573175 37.54246641]\n", - "cam 3 mean residual (px): [7.57065902 3.82773206] std: [26.09452606 26.51365765]\n", - "cam 4 mean residual (px): [17.51522829 -0.06323189] std: [22.11948331 32.61460809]\n", - "cam 5 mean residual (px): [16.29940983 4.89736708] std: [19.6988043 28.20827179]\n", - "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", - "marker_array order: unknown\n", - "c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", - "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", - "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", - "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", - "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", - "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", - "cam 0 mean residual (px): [-23.41938215 14.44121108] std: [40.35550002 46.97320393]\n", - "cam 1 mean residual (px): [-33.81880448 7.43041355] std: [51.00221251 61.34696419]\n", - "cam 2 mean residual (px): [-41.89819398 8.90692818] std: [29.37183653 55.51384624]\n", - "cam 3 mean residual (px): [16.33399736 10.86549482] std: [39.66516442 39.85803748]\n", - "cam 4 mean residual (px): [32.39365723 4.64546822] std: [32.76761261 49.46399521]\n", - "cam 5 mean residual (px): [31.08093032 9.83292101] std: [29.15675992 41.8012197 ]\n", - "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", - "marker_array order: unknown\n", - "c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 
6.39615584]\n", - "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", - "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", - "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", - "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", - "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", - "cam 0 mean residual (px): [-45.45435744 -54.61607257] std: [39.64273023 48.26931621]\n", - "cam 1 mean residual (px): [-74.85878053 -94.18487352] std: [56.23314573 63.06956186]\n", - "cam 2 mean residual (px): [-73.46970474 -43.86349541] std: [40.39699254 56.30734317]\n", - "cam 3 mean residual (px): [ 34.25225743 -51.24268605] std: [40.82544648 40.30805024]\n", - "cam 4 mean residual (px): [ 64.82452392 -81.53024431] std: [39.03215943 49.94257993]\n", - "cam 5 mean residual (px): [ 60.9854295 -40.16336504] std: [37.80814396 39.1601131 ]\n", - "camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']\n", - "marker_array order: unknown\n", - "c0: obs=[1750.93206108 391.50040874], proj=[1751.38181764 385.1042529 ], diff=[-0.44975657 6.39615584]\n", - "c1: obs=[1594.26820338 463.19929254], proj=[1609.76057455 480.0672997 ], diff=[-15.49237117 -16.86800716]\n", - "c2: obs=[1691.0761517 778.44187536], proj=[1692.78650143 778.0042768 ], diff=[-1.71034972 0.43759856]\n", - "c3: obs=[1280.40348787 326.09742532], proj=[1283.34233378 326.58320903], diff=[-2.93884591 -0.48578371]\n", - "c4: obs=[1006.14625554 404.40209646], proj=[1015.72107357 401.72426204], diff=[-9.57481803 2.67783442]\n", - "c5: obs=[1113.13084931 556.79887314], proj=[1121.75050458 544.48723124], diff=[-8.61965526 12.3116419 ]\n", - "cam 0 mean residual (px): [-17.62599391 -66.77606412] std: [34.05057432 49.62368119]\n", - "cam 1 
mean residual (px): [-31.71298444 -92.47139918] std: [45.62965855 63.94357298]\n", - "cam 2 mean residual (px): [-27.17331358 -44.40016944] std: [27.94681585 53.8826323 ]\n", - "cam 3 mean residual (px): [ 12.01686193 -61.44974139] std: [35.06009354 42.20340585]\n", - "cam 4 mean residual (px): [ 19.79745444 -80.43697436] std: [28.57462023 51.59245566]\n", - "cam 5 mean residual (px): [ 19.93228684 -47.6077619 ] std: [26.79455762 39.85515655]\n", - "see example EKS output at ./outputs/multicam_rightFoot.pdf\n" - ] - } - ], - "source": [ - "from eks.utils import plot_results\n", - "\n", - "input_source = \"./data/chickadee_uncropped\"\n", - "camera_names = [\"lBack\", \"lFront\", \"lTop\", \"rBack\", \"rFront\", \"rTop\"]\n", - "keypoints = [\"topBeak\", \"topHead\", \"backHead\", \"centerChes\", \"baseTail\", \"tipTail\", \"leftEye\", \"leftNeck\", \"leftWing\", \"leftAnkle\", \"leftFoot\", \"rightEye\", \"rightNeck\", \"rightWing\", \"rightAnkle\", \"rightFoot\"]\n", - "camgroup = CameraGroup.load(\"./data/chickadee/calibration.toml\")\n", - "# input_source = \"./data/fly\"\n", - "# camera_names = [\"Cam-A\", \"Cam-B\", \"Cam-C\", \"Cam-D\", \"Cam-E\", \"Cam-F\"]\n", - "# keypoints = [\"L1A\", \"L1B\"]\n", - "# camgroup = CameraGroup.load(\"./data/fly/calibration.toml\")\n", - "\n", - "save_dir = \"./outputs/\"\n", - "\n", - "# Load calibration file\n", - "\n", - "\n", - "camera_dfs, s_finals, input_dfs, bodypart_list, marker_array, h_cams, ys_3d = fit_eks_multicam(\n", - " input_source=input_source,\n", - " save_dir=save_dir,\n", - " bodypart_list=keypoints,\n", - " smooth_param=10,\n", - " camera_names=camera_names,\n", - " quantile_keep_pca=95,\n", - " verbose=True,\n", - " inflate_vars=False,\n", - " n_latent=3,\n", - " backend=\"dynamax-ekf\",\n", - " camgroup=camgroup\n", - ")\n", - "\n", - "keypoint_i = -1\n", - "camera_c = -1\n", - "plot_results(\n", - " output_df=camera_dfs[camera_c],\n", - " input_dfs_list=input_dfs[camera_c],\n", - " 
key=f'{bodypart_list[keypoint_i]}',\n", - " idxs=(0, 500),\n", - " s_final=s_finals[keypoint_i],\n", - " nll_values=None,\n", - " save_dir=save_dir,\n", - " smoother_type='multicam',\n", - ")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "63fd6ec4", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Triangulated 3D point: [-0.05837184 -0.28620351 0.42230134]\n", - "Camera 0: reprojection error = 5.255 pixels\n", - "Camera 1: reprojection error = 5.623 pixels\n", - "Camera 2: reprojection error = 2.997 pixels\n", - "Camera 3: reprojection error = 2.839 pixels\n", - "Camera 4: reprojection error = 0.815 pixels\n", - "Camera 5: reprojection error = 6.390 pixels\n" - ] - } - ], - "source": [ - "import numpy as np\n", - "import jax.numpy as jnp\n", - "\n", - "# === Settings ===\n", - "frame_idx = 120\n", - "keypoint_idx = 1\n", - "model_idx = 0 # or average across models\n", - "\n", - "# === Step 1: Extract 2D predictions from all cameras ===\n", - "raw_array = marker_array.get_array() # (n_models, n_cameras, n_frames, n_keypoints, 2+)\n", - "xy_views = [raw_array[model_idx, c, frame_idx, keypoint_idx, :2] for c in range(len(camgroup.cameras))]\n", - "xy_views_np = np.stack(xy_views) # (n_cameras, 2)\n", - "\n", - "# === Step 2: Triangulate to get 3D point ===\n", - "x_3d = camgroup.triangulate(xy_views_np) # shape (3,)\n", - "\n", - "print(f\"Triangulated 3D point: {x_3d}\")\n", - "\n", - "# === Step 3: Reproject into each view ===\n", - "projected_views = [h(jnp.array(x_3d)) for h in h_cams]\n", - "projected_views_np = np.stack([np.array(p) for p in projected_views]) # (n_cameras, 2)\n", - "\n", - "# === Step 4: Compute reprojection error per view ===\n", - "for i, (orig, proj) in enumerate(zip(xy_views_np, projected_views_np)):\n", - " err = np.linalg.norm(orig - proj)\n", - " print(f\"Camera {i}: reprojection error = {err:.3f} pixels\")" - ] - }, - { - "cell_type": "code", - "execution_count": 
7, - "id": "6cc22bf1", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Triangulated point: [-0.09515544 -0.22536958 0.36407084]\n", - "PCA-reconstructed point:[-0.09511244 -0.22479253 0.36471556]\n", - "\n", - "Camera 0: tri e_custom=4.01px, e_cg=4.01px, projΔ=0.00px | pca e_custom=3.46px, e_cg=3.46px, projΔ=0.00px\n", - "Camera 1: tri e_custom=12.93px, e_cg=12.93px, projΔ=0.00px | pca e_custom=13.93px, e_cg=13.93px, projΔ=0.00px\n", - "Camera 2: tri e_custom=2.23px, e_cg=2.23px, projΔ=0.00px | pca e_custom=1.27px, e_cg=1.27px, projΔ=0.00px\n", - "Camera 3: tri e_custom=2.84px, e_cg=2.84px, projΔ=0.00px | pca e_custom=2.60px, e_cg=2.60px, projΔ=0.00px\n", - "Camera 4: tri e_custom=5.52px, e_cg=5.52px, projΔ=0.00px | pca e_custom=4.95px, e_cg=4.95px, projΔ=0.00px\n", - "Camera 5: tri e_custom=1.84px, e_cg=1.84px, projΔ=0.00px | pca e_custom=2.30px, e_cg=2.30px, projΔ=0.00px\n" - ] - } - ], - "source": [ - "import numpy as np\n", - "from sklearn.decomposition import PCA\n", - "import jax.numpy as jnp\n", - "\n", - "# === Settings ===\n", - "frame_idx = 45\n", - "keypoint_idx = 0\n", - "model_idx = 0 # pick one or average if needed\n", - "\n", - "# --- helper: make camgroup.project output (n_cams, 2) ---\n", - "def cg_project_point(camgroup, x3d):\n", - " \"\"\"Project a single 3D point with aniposelib CameraGroup, return (n_cams, 2).\"\"\"\n", - " x = np.asarray(x3d, dtype=float).reshape(1, 3)\n", - " out = camgroup.project(x) # library-dependent shape\n", - " # Try common shapes: (n_cams, 1, 2), (n_cams, 2), dict of cam->(1,2)\n", - " if isinstance(out, dict):\n", - " proj = np.stack([np.asarray(out[cam.name])[0] for cam in camgroup.cameras], axis=0)\n", - " else:\n", - " arr = np.asarray(out)\n", - " if arr.ndim == 3 and arr.shape[1] == 1 and arr.shape[2] == 2:\n", - " proj = arr[:, 0, :]\n", - " elif arr.ndim == 2 and arr.shape == (len(camgroup.cameras), 2):\n", - " proj = arr\n", - " elif arr.ndim == 2 and 
arr.shape == (2, len(camgroup.cameras)):\n", - " proj = arr.T\n", - " else:\n", - " raise ValueError(f\"Unexpected camgroup.project shape: {arr.shape}\")\n", - " return proj # (n_cams, 2)\n", - "\n", - "# === 1) 2D observations for this keypoint+frame from all cameras ===\n", - "raw_array = marker_array.get_array() # (n_models, n_cams, n_frames, n_keypoints, 2+)\n", - "n_cams = len(camgroup.cameras)\n", - "xy_views_np = np.stack(\n", - " [raw_array[model_idx, c, frame_idx, keypoint_idx, :2] for c in range(n_cams)],\n", - " axis=0\n", - ") # (n_cams, 2)\n", - "\n", - "# === 2) Triangulated 3D point ===\n", - "x_triang = camgroup.triangulate(xy_views_np) # (3,)\n", - "\n", - "# === 3) PCA-reconstructed 3D point ===\n", - "# Assume ys_3d: (K, T, 3)\n", - "ys_3d_reshaped = ys_3d.reshape(-1, 3)\n", - "pca = PCA(n_components=3)\n", - "Z = pca.fit_transform(ys_3d_reshaped)\n", - "ys_3d_pca = pca.inverse_transform(Z).reshape(ys_3d.shape)\n", - "x_pca = ys_3d_pca[keypoint_idx, frame_idx] # (3,)\n", - "\n", - "# === 4) Project both 3D points with:\n", - "# (a) your custom JAX projectors h_cams\n", - "# (b) camgroup.project (OpenCV-based)\n", - "reproj_triang_custom = np.stack([np.array(h(jnp.array(x_triang))) for h in h_cams], axis=0) # (n_cams, 2)\n", - "reproj_pca_custom = np.stack([np.array(h(jnp.array(x_pca))) for h in h_cams], axis=0)\n", - "\n", - "reproj_triang_cg = cg_project_point(camgroup, x_triang) # (n_cams, 2)\n", - "reproj_pca_cg = cg_project_point(camgroup, x_pca)\n", - "\n", - "# === 5) Print comparison ===\n", - "print(f\"Triangulated point: {x_triang}\")\n", - "print(f\"PCA-reconstructed point:{x_pca}\\n\")\n", - "\n", - "for i in range(n_cams):\n", - " obs = xy_views_np[i]\n", - "\n", - " # errors vs observations\n", - " err_tri_custom = np.linalg.norm(obs - reproj_triang_custom[i])\n", - " err_tri_cg = np.linalg.norm(obs - reproj_triang_cg[i])\n", - "\n", - " err_pca_custom = np.linalg.norm(obs - reproj_pca_custom[i])\n", - " err_pca_cg = 
np.linalg.norm(obs - reproj_pca_cg[i])\n", - "\n", - " # difference between projectors (should be ~0 if both are consistent)\n", - " diff_tri = np.linalg.norm(reproj_triang_custom[i] - reproj_triang_cg[i])\n", - " diff_pca = np.linalg.norm(reproj_pca_custom[i] - reproj_pca_cg[i])\n", - "\n", - " print(\n", - " f\"Camera {i}: \"\n", - " f\"tri e_custom={err_tri_custom:.2f}px, e_cg={err_tri_cg:.2f}px, projΔ={diff_tri:.2f}px | \"\n", - " f\"pca e_custom={err_pca_custom:.2f}px, e_cg={err_pca_cg:.2f}px, projΔ={diff_pca:.2f}px\"\n", - " )\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From 4519048c40276e501b44a0682ac240a8783a807f Mon Sep 17 00:00:00 2001 From: Keemin Lee Date: Fri, 3 Oct 2025 14:13:52 -0400 Subject: [PATCH 07/11] added dynamax to install requirements --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 3e2b1ca..94ec5ed 100644 --- a/setup.py +++ b/setup.py @@ -41,6 +41,7 @@ def get_version(rel_path): 'sleap_io', 'jax', 'jaxlib', + 'dynamax' ] # additional requirements From 8b8149908de62fcd4d8a0e230fd5248fcad674ab Mon Sep 17 00:00:00 2001 From: Keemin Lee Date: Fri, 3 Oct 2025 14:20:45 -0400 Subject: [PATCH 08/11] pytest bugfix --- setup.py | 2 +- tests/test_multicam_smoother.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 94ec5ed..8f74105 100644 --- a/setup.py +++ b/setup.py @@ -68,7 +68,7 @@ def get_version(rel_path): long_description_content_type='text/markdown', author='Cole Hurwitz', author_email='', - url='http://www.github.com/colehurwitz/eks', + 
url='http://www.github.com/paninski-lab/eks', packages=['eks'], install_requires=install_requires, extras_require=extras_require, diff --git a/tests/test_multicam_smoother.py b/tests/test_multicam_smoother.py index 99f5721..c36f0c2 100644 --- a/tests/test_multicam_smoother.py +++ b/tests/test_multicam_smoother.py @@ -46,7 +46,7 @@ def test_ensemble_kalman_smoother_multicam(): assert isinstance(smooth_params_final, np.ndarray), \ f"Expected smooth_param_final to be an array, got {type(smooth_params_final)}" for k in range(len(keypoint_names)): - assert smooth_params_final[c] == smooth_param, \ + assert smooth_params_final[k] == smooth_param, \ f"Expected smooth_param_final to match input smooth_param ({smooth_param}), " \ f"got {smooth_params_final}" @@ -71,7 +71,7 @@ def test_ensemble_kalman_smoother_multicam(): assert isinstance(smooth_params_final, np.ndarray), \ f"Expected smooth_param_final to be an array, got {type(smooth_params_final)}" for k in range(len(keypoint_names)): - assert smooth_params_final[c] == smooth_param, \ + assert smooth_params_final[k] == smooth_param, \ f"Expected smooth_param_final to match input smooth_param ({smooth_param}), " \ f"got {smooth_params_final}" @@ -98,7 +98,7 @@ def test_ensemble_kalman_smoother_multicam(): assert isinstance(smooth_params_final, np.ndarray), \ f"Expected smooth_param_final to be an array, got {type(smooth_params_final)}" for k in range(len(keypoint_names)): - assert smooth_params_final[c] == smooth_param, \ + assert smooth_params_final[k] == smooth_param, \ f"Expected smooth_param_final to match input smooth_param ({smooth_param}), " \ f"got {smooth_params_final}" From 89de5185463d8014ed7c717b1a6b0a33923cfe67 Mon Sep 17 00:00:00 2001 From: Keemin Lee Date: Sun, 5 Oct 2025 22:41:47 -0400 Subject: [PATCH 09/11] refactoring optimize smooth funcs --- eks/core.py | 258 ++++++++++++++++++++++++-------------- eks/multicam_smoother.py | 5 +- eks/singlecam_smoother.py | 4 +- 3 files changed, 169 insertions(+), 98 
deletions(-) diff --git a/eks/core.py b/eks/core.py index c99be07..305d243 100644 --- a/eks/core.py +++ b/eks/core.py @@ -128,7 +128,7 @@ def params_nlgssm_for_keypoint(m0, S0, Q, s, R, f_fn, h_fn) -> ParamsNLGSSM: @typechecked -def optimize_smooth_param( +def run_kalman_smoother( Qs: jnp.ndarray, # (K, D, D) ys: np.ndarray, # (K, T, obs) m0s: jnp.ndarray, # (K, D) @@ -223,98 +223,8 @@ def optimize_smooth_param( else: s_finals[:] = np.asarray(smooth_param, dtype=float) else: - optimizer = optax.adam(float(lr)) - s_bounds_log_j = jnp.array(s_bounds_log, dtype=jnp.float32) - tol_j = float(tol) - - def _params_linear(m0, S0, A, Q_base, s, R_any, C): - f_fn = (lambda x, A=A: A @ x) # linear dynamics - h_fn = (lambda x, C=C: C @ x) # linear emission - return params_nlgssm_for_keypoint(m0, S0, Q_base, s, R_any, f_fn, h_fn) - - # NLL for a single keypoint with time-varying R_t - def _nll_one_keypoint(log_s, y_k, m0_k, S0_k, A_k, Q_k, C_k, R_k_tv): - s = jnp.exp(jnp.clip(log_s, s_bounds_log_j[0], s_bounds_log_j[1])) - params = _params_linear(m0_k, S0_k, A_k, Q_k, s, R_k_tv, C_k) - post = extended_kalman_filter(params, jnp.asarray(y_k)) - return -post.marginal_loglik - - # Sum NLL across all keypoints in the block - def _nll_block(log_s, ys_b, m0s_b, S0s_b, As_b, Qs_b, Cs_b, Rs_b_tv): - nlls = vmap(_nll_one_keypoint, in_axes=(None, 0, 0, 0, 0, 0, 0, 0))( - log_s, ys_b, m0s_b, S0s_b, As_b, Qs_b, Cs_b, Rs_b_tv - ) - return jnp.sum(nlls) - - @jit - def _opt_step(log_s, opt_state, ys_b, m0s_b, S0s_b, As_b, Qs_b, Cs_b, Rs_b_tv): - loss, grad = value_and_grad(_nll_block)( - log_s, ys_b, m0s_b, S0s_b, As_b, Qs_b, Cs_b, Rs_b_tv - ) - updates, opt_state = optimizer.update(grad, opt_state) - log_s = optax.apply_updates(log_s, updates) - return log_s, opt_state, loss - - @jit - def _run_tol_loop(log_s0, opt_state0, ys_b, m0s_b, S0s_b, As_b, Qs_b, Cs_b, Rs_b_tv): - def cond(carry): - _, _, prev_loss, iters, done = carry - return jnp.logical_and(~done, iters < safety_cap) - - def 
body(carry): - log_s, opt_state, prev_loss, iters, _ = carry - log_s, opt_state, loss = _opt_step( - log_s, opt_state, ys_b, m0s_b, S0s_b, As_b, Qs_b, Cs_b, Rs_b_tv - ) - rel_tol = tol_j * jnp.abs(jnp.log(jnp.maximum(prev_loss, 1e-12))) - done = jnp.where( - jnp.isfinite(prev_loss), - jnp.linalg.norm(loss - prev_loss) < (rel_tol + 1e-6), - False - ) - return (log_s, opt_state, loss, iters + 1, done) - - return lax.while_loop( - cond, body, (log_s0, opt_state0, jnp.inf, jnp.array(0), jnp.array(False)) - ) - - # Optimize per block (shared s within each block) - for block in blocks: - sel = jnp.asarray(block, dtype=int) - - # Crop frames for the loss (both y and R_t) if s_frames is provided - if s_frames and len(s_frames) > 0: - # Crop both y and R_t using the same frame spec -- each (T', obs) - y_block_list = [crop_frames(ys[int(k)], s_frames) for k in block] - R_block_list = [crop_R(Rs[int(k)], s_frames) for k in block] - - # Stack and jnp - y_block = jnp.asarray(np.stack(y_block_list, axis=0)) # (B, T', obs) - R_block = jnp.asarray(np.stack(R_block_list, axis=0)) # (B, T', obs, obs) - else: - y_block = ys_j[sel] # (B, T, obs) - R_block = Rs_j[sel] # (B, T, obs, obs) - - m0_block = m0s_j[sel] - S0_block = S0s_j[sel] - A_block = As_j[sel] - Q_block = Qs_j[sel] - C_block = Cs_j[sel] - - s0 = float(np.mean([s_guess_per_k[k] for k in block])) - log_s0 = jnp.array(np.log(max(s0, 1e-6)), dtype=jnp.float32) - opt_state0 = optimizer.init(log_s0) - - log_s_f, opt_state_f, last_loss, iters_f, _done = _run_tol_loop( - log_s0, opt_state0, y_block, m0_block, S0_block, A_block, Q_block, C_block, - R_block - ) - s_star = float(jnp.exp(jnp.clip(log_s_f, s_bounds_log_j[0], s_bounds_log_j[1]))) - for k in block: - s_finals[k] = s_star - if verbose: - print(f"[Block {block}] s={s_star:.6g}, iters={int(iters_f)}, " - f"NLL={float(last_loss):.6f}") + optimize_smooth_param(As_j, Cs_j, Qs_j, Rs, Rs_j, S0s_j, blocks, lr, m0s_j, s_bounds_log, + s_finals, s_frames, s_guess_per_k, tol, 
verbose, ys, ys_j) # -------------------- final smoother pass (full R_t) -------------------- def _params_linear_for_k(k: int, s_val: float): @@ -338,3 +248,165 @@ def _params_linear_for_k(k: int, s_val: float): ms = np.stack(means_list, axis=0) # (K, T, D) Vs = np.stack(covs_list, axis=0) # (K, T, D, D) return s_finals, ms, Vs + + +from jax import jit, lax, value_and_grad +import jax +import jax.numpy as jnp +import numpy as np +import optax + +def optimize_smooth_param( + As, # jnp.ndarray, (K, D, D) + Cs, # jnp.ndarray, (K, obs, D) + Qs, # jnp.ndarray, (K, D, D) + Rs, # jnp.ndarray, (K, T, obs, obs) + S0s, # jnp.ndarray, (K, D, D) + blocks, # Optional[List[List[int]]] + lr: float, + m0s, # jnp.ndarray, (K, D) + s_bounds_log: tuple, + s_finals: np.ndarray, # (K,), filled in-place + s_frames, # Optional[List] + s_guess_per_k: np.ndarray, # (K,) + tol: float, + verbose: bool, + ys: np.ndarray, # (K, T, obs) (NumPy for cropping) + ys_j, # jnp.ndarray, (K, T, obs) + safety_cap: int, +) -> None: + """ + Blockwise optimization of a single scalar process-noise scale `s` (shared within + each block of keypoints) by minimizing the sum of EKF negative log-likelihoods, + using time-varying observation covariances R_{k,t}. Updates `s_finals` in place. + + Parameters + ---------- + As, Cs, Qs : jnp.ndarray + Per-keypoint model matrices with shapes (K,D,D), (K,obs,D), (K,D,D). + Rs : jnp.ndarray + Time-varying observation covariances, shape (K, T, obs, obs). JAX array for + fast slicing; converted to NumPy internally only when cropping. + S0s : jnp.ndarray + Initial state covariances, shape (K, D, D). + blocks : list[list[int]] or None + Groups of keypoint indices that share a single `s`. If None/empty, each keypoint + is its own block. + lr : float + Adam learning rate for optimizing log(s). + m0s : jnp.ndarray + Initial state means, shape (K, D). + s_bounds_log : (float, float) + Clamp bounds for log(s) (numerical stability). 
+ s_finals : np.ndarray + Output array of shape (K,). Filled with the final `s` per keypoint + (blockwise optimum broadcast to members). + s_frames : list or None + Frame ranges for cropping (list of (start, end) tuples; 1-based start, inclusive end). + Applied to both y and R for the loss only. + s_guess_per_k : np.ndarray + Heuristic initial guesses of `s` per keypoint; block init is the mean over members. + tol : float + Relative tolerance on the loss change for early stopping. + verbose : bool + If True, prints per-block optimization progress. + ys : np.ndarray + Observations (NumPy) for CPU-side cropping, shape (K, T, obs). + ys_j : jnp.ndarray + Observations (JAX) for fast slicing when not cropping, shape (K, T, obs). + safety_cap : int + Maximum number of iterations inside the jitted while-loop. + + Returns + ------- + None + Results are written into `s_finals` in place. + """ + optimizer = optax.adam(float(lr)) + s_bounds_log_j = jnp.array(s_bounds_log, dtype=jnp.float32) + tol_j = float(tol) + + def _params_linear(m0, S0, A, Q_base, s, R_any, C): + f_fn = (lambda x, A=A: A @ x) # linear dynamics + h_fn = (lambda x, C=C: C @ x) # linear emission + return params_nlgssm_for_keypoint(m0, S0, Q_base, s, R_any, f_fn, h_fn) + + def _nll_one_keypoint(log_s, y_k, m0_k, S0_k, A_k, Q_k, C_k, R_k_tv): + s = jnp.exp(jnp.clip(log_s, s_bounds_log_j[0], s_bounds_log_j[1])) + params = _params_linear(m0_k, S0_k, A_k, Q_k, s, R_k_tv, C_k) + post = extended_kalman_filter(params, jnp.asarray(y_k)) + return -post.marginal_loglik + + def _nll_block(log_s, ys_b, m0s_b, S0s_b, As_b, Qs_b, Cs_b, Rs_b_tv): + nlls = jax.vmap(_nll_one_keypoint, in_axes=(None, 0, 0, 0, 0, 0, 0, 0))( + log_s, ys_b, m0s_b, S0s_b, As_b, Qs_b, Cs_b, Rs_b_tv + ) + return jnp.sum(nlls) + + @jit + def _opt_step(log_s, opt_state, ys_b, m0s_b, S0s_b, As_b, Qs_b, Cs_b, Rs_b_tv): + loss, grad = value_and_grad(_nll_block)( + log_s, ys_b, m0s_b, S0s_b, As_b, Qs_b, Cs_b, Rs_b_tv + ) + updates, opt_state = 
optimizer.update(grad, opt_state) + log_s = optax.apply_updates(log_s, updates) + return log_s, opt_state, loss + + @jit + def _run_tol_loop(log_s0, opt_state0, ys_b, m0s_b, S0s_b, As_b, Qs_b, Cs_b, Rs_b_tv): + def cond(carry): + _, _, prev_loss, iters, done = carry + return jnp.logical_and(~done, iters < safety_cap) + + def body(carry): + log_s, opt_state, prev_loss, iters, _ = carry + log_s, opt_state, loss = _opt_step( + log_s, opt_state, ys_b, m0s_b, S0s_b, As_b, Qs_b, Cs_b, Rs_b_tv + ) + rel_tol = tol_j * jnp.abs(jnp.log(jnp.maximum(prev_loss, 1e-12))) + done = jnp.where( + jnp.isfinite(prev_loss), + jnp.linalg.norm(loss - prev_loss) < (rel_tol + 1e-6), + False + ) + return (log_s, opt_state, loss, iters + 1, done) + + return lax.while_loop( + cond, body, (log_s0, opt_state0, jnp.inf, jnp.array(0), jnp.array(False)) + ) + + # for cropping only: NumPy view of Rs + Rs_np = np.asarray(Rs) + + # Optimize per block (shared s within each block) + for block in (blocks or []): + sel = jnp.asarray(block, dtype=int) + + if s_frames and len(s_frames) > 0: + y_block_list = [crop_frames(ys[int(k)], s_frames) for k in block] # (T', obs) + R_block_list = [crop_R_tv(Rs_np[int(k)], s_frames) for k in block] # (T', obs, obs) + y_block = jnp.asarray(np.stack(y_block_list, axis=0)) # (B, T', obs) + R_block = jnp.asarray(np.stack(R_block_list, axis=0)) # (B, T', obs, obs) + else: + y_block = ys_j[sel] # (B, T, obs) + R_block = Rs[sel] # (B, T, obs, obs) + + m0_block = m0s[sel] + S0_block = S0s[sel] + A_block = As[sel] + Q_block = Qs[sel] + C_block = Cs[sel] + + s0 = float(np.mean([s_guess_per_k[k] for k in block])) + log_s0 = jnp.array(np.log(max(s0, 1e-6)), dtype=jnp.float32) + opt_state0 = optimizer.init(log_s0) + + log_s_f, opt_state_f, last_loss, iters_f, _done = _run_tol_loop( + log_s0, opt_state0, y_block, m0_block, S0_block, A_block, Q_block, C_block, R_block + ) + s_star = float(jnp.exp(jnp.clip(log_s_f, s_bounds_log_j[0], s_bounds_log_j[1]))) + for k in block: + 
s_finals[k] = s_star + if verbose: + print(f"[Block {block}] s={s_star:.6g}, " + f"iters={int(iters_f)}, NLL={float(last_loss):.6f}") diff --git a/eks/multicam_smoother.py b/eks/multicam_smoother.py index 3396897..1f30062 100644 --- a/eks/multicam_smoother.py +++ b/eks/multicam_smoother.py @@ -6,7 +6,7 @@ from sklearn.decomposition import PCA from typeguard import typechecked -from eks.core import ensemble, optimize_smooth_param +from eks.core import ensemble, run_kalman_smoother from eks.marker_array import ( MarkerArray, input_dfs_to_markerArray, @@ -202,7 +202,6 @@ def ensemble_kalman_smoother_multicam( verbose: bool = False, pca_object: PCA | None = None, n_latent: int = 3, - backend: str = 'jax', ) -> tuple: """ Use multi-view constraints to fit a 3D latent subspace for each body part. @@ -287,7 +286,7 @@ def ensemble_kalman_smoother_multicam( ]) # Optimize smoothing - s_finals, ms, Vs = optimize_smooth_param( + s_finals, ms, Vs = run_kalman_smoother( Qs=Qs, ys=ys, m0s=m0s, diff --git a/eks/singlecam_smoother.py b/eks/singlecam_smoother.py index 441e06d..870a04c 100644 --- a/eks/singlecam_smoother.py +++ b/eks/singlecam_smoother.py @@ -5,7 +5,7 @@ import pandas as pd from typeguard import typechecked -from eks.core import ensemble, optimize_smooth_param +from eks.core import ensemble, run_kalman_smoother from eks.marker_array import MarkerArray, input_dfs_to_markerArray from eks.utils import center_predictions, format_data, make_dlc_pandas_index @@ -137,7 +137,7 @@ def ensemble_kalman_smoother_singlecam( m0s, S0s, As, cov_mats, Cs = initialize_kalman_filter(emA_centered_preds) # Main smoothing function - s_finals, ms, Vs = optimize_smooth_param( + s_finals, ms, Vs = run_kalman_smoother( cov_mats, ys, m0s, S0s, Cs, As, emA_vars.get_array(squeeze=True), s_frames, smooth_param, blocks, verbose=verbose ) From 79de8fdfa016210ecd846703e6bf34bd2676e9d9 Mon Sep 17 00:00:00 2001 From: Keemin Lee Date: Mon, 6 Oct 2025 10:50:54 -0400 Subject: [PATCH 10/11] api 
runner/optimizer function refactor for core and pupil --- eks/core.py | 217 ++++++++++++++-------------- eks/ibl_pupil_smoother.py | 293 ++++++++++++++++++++++++++------------ eks/multicam_smoother.py | 6 +- 3 files changed, 312 insertions(+), 204 deletions(-) diff --git a/eks/core.py b/eks/core.py index 305d243..bb8dc6a 100644 --- a/eks/core.py +++ b/eks/core.py @@ -127,14 +127,15 @@ def params_nlgssm_for_keypoint(m0, S0, Q, s, R, f_fn, h_fn) -> ParamsNLGSSM: ) +# ----------------- Public API ----------------- @typechecked def run_kalman_smoother( - Qs: jnp.ndarray, # (K, D, D) - ys: np.ndarray, # (K, T, obs) + ys: jnp.ndarray, # (K, T, obs) m0s: jnp.ndarray, # (K, D) S0s: jnp.ndarray, # (K, D, D) - Cs: jnp.ndarray, # (K, obs, D) As: jnp.ndarray, # (K, D, D) + Cs: jnp.ndarray, # (K, obs, D) + Qs: jnp.ndarray, # (K, D, D) ensemble_vars: np.ndarray, # (T, K, obs) s_frames: Optional[List] = None, smooth_param: Optional[Union[float, List[float]]] = None, @@ -148,74 +149,55 @@ def run_kalman_smoother( ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Optimize the process-noise scale `s` (shared within each block of keypoints) by minimizing - summed negative log-likelihood (NLL) under a *linear* state-space model using the - Dynamax EKF filter (fast), then produce final trajectories via the EKF smoother. + the summed EKF filter negative log-likelihood (NLL) in a *linear* state-space model, + then run the EKF smoother for final trajectories. - Model (per keypoint k): + Model per keypoint k: x_{t+1} = A_k x_t + w_t, y_t = C_k x_t + v_t - w_t ~ N(0, s * Q_k), v_t ~ N(0, R_{k,t}) - - where R_{k,t} is **time-varying**, built from ensemble variances: - R_{k,t} = diag( clip( ensemble_vars[t, k, :], 1e-12, ∞ ) ). + w_t ~ N(0, s * Q_k), v_t ~ N(0, R_{k,t}), with time-varying R_{k,t}. Args: - Qs: (K, D, D) base process noise covariances Q_k per keypoint (scaled by `s`). - ys: (K, T, obs) observations per keypoint across time. - m0s: (K, D) initial state means per keypoint. 
- S0s: (K, D, D) initial state covariances per keypoint. - Cs: (K, obs, D) observation matrices C_k per keypoint. - As: (K, D, D) transition matrices A_k per keypoint. - ensemble_vars: (T, K, obs) per-frame ensemble variances for each keypoint’s obs dims; - used to construct time-varying R_{k,t}. - s_frames: Optional list of frame indices used for NLL optimization (cropping the loss). - Final smoothing always runs on the full sequence. - smooth_param: If provided, bypass optimization. - • float/int: same `s` for all keypoints; - • list[float] of length K: per-keypoint `s`. - blocks: Optional list of lists of keypoint indices; each block shares a single `s`. - Default: each keypoint forms its own block. - verbose: If True, prints per-block optimization summaries. - lr: Adam learning rate for optimizing log(s). - s_bounds_log: (low, high) clamp for log(s) during optimization. + ys: (K, T, obs) observations per keypoint over time. + m0s: (K, D) initial state means. + S0s: (K, D, D) initial state covariances. + As: (K, D, D) transition matrices. + Cs: (K, obs, D) observation matrices. + Qs: (K, D, D) base process covariances (scaled by `s`). + ensemble_vars: (T, K, obs) per-frame ensemble variances; used to build R_{k,t} + via diag(clip(ensemble_vars[t, k, :], 1e-12, ∞)). + s_frames: Optional list of (start, end) tuples (1-based, inclusive end) to crop + the time axis *for the loss only*. Final smoothing uses the full sequence. + smooth_param: If provided, bypass optimization. Either a scalar (shared across K) + or a list of length K (per-keypoint). + blocks: Optional list of lists of keypoint indices; each block shares one `s`. + Default: each keypoint is its own block. + verbose: Print per-block optimization summaries if True. + lr: Adam learning rate (on log(s)). + s_bounds_log: Clamp bounds for log(s) during optimization. tol: Relative tolerance on loss change for early stopping. - safety_cap: Hard limit on iterations inside the jitted while-loop. 
+ safety_cap: Hard iteration cap inside the jitted while-loop. Returns: - s_finals: (K,) final `s` per keypoint (blockwise value broadcast to members). + s_finals: (K,) final `s` per keypoint (block optimum broadcast to members). ms: (K, T, D) smoothed state means. Vs: (K, T, D, D) smoothed state covariances. - - Notes: - • NLL is computed with EKF *filter*; outputs use EKF *smoother*. - • Loss for a block is the sum of member keypoints’ NLLs (via vmap). - • All jitted helpers close over optimizer/tol/bounds to avoid passing Python objects. """ - # -------------------- setup & time-varying R_t -------------------- K, T, obs_dim = ys.shape if not blocks: blocks = [[k] for k in range(K)] if verbose: print(f"Correlated keypoint blocks: {blocks}") - # Build time-varying R - Rs = build_R_from_vars(np.swapaxes(ensemble_vars, 0, 1)) - Rs_j = jnp.asarray(Rs) + # Build time-varying R (K, T, obs, obs) + Rs = jnp.asarray(build_R_from_vars(np.swapaxes(ensemble_vars, 0, 1))) - # Device arrays once - ys_j = jnp.asarray(ys) - m0s_j = jnp.asarray(m0s) - S0s_j = jnp.asarray(S0s) - As_j = jnp.asarray(As) - Qs_j = jnp.asarray(Qs) - Cs_j = jnp.asarray(Cs) - - # Initial guesses per keypoint + # Initial guesses per keypoint (host-side) s_guess_per_k = np.empty(K, dtype=float) for k in range(K): g = float(compute_initial_guesses(ensemble_vars[:, k, :]) or 2.0) s_guess_per_k[k] = g if (np.isfinite(g) and g > 0.0) else 2.0 - # -------------------- choose or optimize s -------------------- + # Choose or optimize s s_finals = np.empty(K, dtype=float) if smooth_param is not None: if isinstance(smooth_param, (int, float)): @@ -223,21 +205,37 @@ def run_kalman_smoother( else: s_finals[:] = np.asarray(smooth_param, dtype=float) else: - optimize_smooth_param(As_j, Cs_j, Qs_j, Rs, Rs_j, S0s_j, blocks, lr, m0s_j, s_bounds_log, - s_finals, s_frames, s_guess_per_k, tol, verbose, ys, ys_j) + optimize_smooth_param( + ys=ys, + m0s=m0s, + S0s=S0s, + As=As, + Cs=Cs, + Qs=Qs, + Rs=Rs, + blocks=blocks, + 
lr=lr, + s_bounds_log=s_bounds_log, + s_finals=s_finals, + s_frames=s_frames, + s_guess_per_k=s_guess_per_k, + tol=tol, + verbose=verbose, + safety_cap=safety_cap, + ) - # -------------------- final smoother pass (full R_t) -------------------- + # Final smoother pass (full R_t) def _params_linear_for_k(k: int, s_val: float): - A_k, C_k = As_j[k], Cs_j[k] + A_k, C_k = As[k], Cs[k] f_fn = (lambda x, A=A_k: A @ x) h_fn = (lambda x, C=C_k: C @ x) return params_nlgssm_for_keypoint( - m0s_j[k], S0s_j[k], Qs_j[k], s_val, Rs[k], f_fn, h_fn) + m0s[k], S0s[k], Qs[k], s_val, Rs[k], f_fn, h_fn) means_list, covs_list = [], [] for k in range(K): params_k = _params_linear_for_k(k, s_finals[k]) - sm = extended_kalman_smoother(params_k, ys_j[k]) + sm = extended_kalman_smoother(params_k, ys[k]) if hasattr(sm, "smoothed_means"): m_k, V_k = sm.smoothed_means, sm.smoothed_covariances else: @@ -245,90 +243,84 @@ def _params_linear_for_k(k: int, s_val: float): means_list.append(np.array(m_k)) covs_list.append(np.array(V_k)) - ms = np.stack(means_list, axis=0) # (K, T, D) - Vs = np.stack(covs_list, axis=0) # (K, T, D, D) + ms = np.stack(means_list, axis=0) + Vs = np.stack(covs_list, axis=0) return s_finals, ms, Vs -from jax import jit, lax, value_and_grad -import jax -import jax.numpy as jnp -import numpy as np -import optax - +# ----------------- Optimizer (blockwise s) ----------------- def optimize_smooth_param( - As, # jnp.ndarray, (K, D, D) - Cs, # jnp.ndarray, (K, obs, D) - Qs, # jnp.ndarray, (K, D, D) - Rs, # jnp.ndarray, (K, T, obs, obs) - S0s, # jnp.ndarray, (K, D, D) - blocks, # Optional[List[List[int]]] + ys: jnp.ndarray, # (K, T, obs) + m0s: jnp.ndarray, # (K, D) + S0s: jnp.ndarray, # (K, D, D) + As: jnp.ndarray, # (K, D, D) + Cs: jnp.ndarray, # (K, obs, D) + Qs: jnp.ndarray, # (K, D, D) + Rs: jnp.ndarray, # (K, T, obs, obs) time-varying R_t + blocks: Optional[list], lr: float, - m0s, # jnp.ndarray, (K, D) s_bounds_log: tuple, s_finals: np.ndarray, # (K,), filled in-place - 
s_frames, # Optional[List] + s_frames: Optional[list], s_guess_per_k: np.ndarray, # (K,) tol: float, verbose: bool, - ys: np.ndarray, # (K, T, obs) (NumPy for cropping) - ys_j, # jnp.ndarray, (K, T, obs) safety_cap: int, ) -> None: """ - Blockwise optimization of a single scalar process-noise scale `s` (shared within - each block of keypoints) by minimizing the sum of EKF negative log-likelihoods, - using time-varying observation covariances R_{k,t}. Updates `s_finals` in place. + Optimize a single scalar process-noise scale `s` per block of keypoints by minimizing + the sum of EKF filter negative log-likelihoods, using time-varying observation noise + R_{k,t}. Writes results into `s_finals` in place. Parameters ---------- - As, Cs, Qs : jnp.ndarray - Per-keypoint model matrices with shapes (K,D,D), (K,obs,D), (K,D,D). - Rs : jnp.ndarray - Time-varying observation covariances, shape (K, T, obs, obs). JAX array for - fast slicing; converted to NumPy internally only when cropping. - S0s : jnp.ndarray - Initial state covariances, shape (K, D, D). + ys : jnp.ndarray, shape (K, T, obs) + Observations per keypoint (JAX). For cropped loss, host-side slices are created. + m0s : jnp.ndarray, shape (K, D) + Initial state means per keypoint. + S0s : jnp.ndarray, shape (K, D, D) + Initial state covariances per keypoint. + As : jnp.ndarray, shape (K, D, D) + State transition matrices. + Cs : jnp.ndarray, shape (K, obs, D) + Observation matrices. + Qs : jnp.ndarray, shape (K, D, D) + Base process covariances (scaled by `s` inside the model). + Rs : jnp.ndarray, shape (K, T, obs, obs) + Time-varying observation covariances for each keypoint. blocks : list[list[int]] or None - Groups of keypoint indices that share a single `s`. If None/empty, each keypoint - is its own block. + Groups of keypoint indices that share a single `s`. + If None/empty, each keypoint is its own block. lr : float - Adam learning rate for optimizing log(s). 
- m0s : jnp.ndarray - Initial state means, shape (K, D). + Adam learning rate (on log(s)). s_bounds_log : (float, float) - Clamp bounds for log(s) (numerical stability). - s_finals : np.ndarray - Output array of shape (K,). Filled with the final `s` per keypoint - (blockwise optimum broadcast to members). + Clamp bounds for log(s) to stabilize optimization. + s_finals : np.ndarray, shape (K,) + Output array filled with final per-keypoint `s` (block optimum broadcast). s_frames : list or None - Frame ranges for cropping (list of (start, end) tuples; 1-based start, inclusive end). - Applied to both y and R for the loss only. - s_guess_per_k : np.ndarray - Heuristic initial guesses of `s` per keypoint; block init is the mean over members. + Frame ranges for cropping (list of (start, end); 1-based start, inclusive end). + Applied to both y and R_t for the loss only. + s_guess_per_k : np.ndarray, shape (K,) + Heuristic initial guesses of `s` per keypoint. Block init uses the mean over members. tol : float - Relative tolerance on the loss change for early stopping. + Relative tolerance on loss change for early stopping. verbose : bool If True, prints per-block optimization progress. - ys : np.ndarray - Observations (NumPy) for CPU-side cropping, shape (K, T, obs). - ys_j : jnp.ndarray - Observations (JAX) for fast slicing when not cropping, shape (K, T, obs). safety_cap : int Maximum number of iterations inside the jitted while-loop. Returns ------- None - Results are written into `s_finals` in place. + Results are written into `s_finals`. 
""" optimizer = optax.adam(float(lr)) s_bounds_log_j = jnp.array(s_bounds_log, dtype=jnp.float32) tol_j = float(tol) def _params_linear(m0, S0, A, Q_base, s, R_any, C): - f_fn = (lambda x, A=A: A @ x) # linear dynamics - h_fn = (lambda x, C=C: C @ x) # linear emission + f_fn = (lambda x, A=A: A @ x) + h_fn = (lambda x, C=C: C @ x) return params_nlgssm_for_keypoint(m0, S0, Q_base, s, R_any, f_fn, h_fn) def _nll_one_keypoint(log_s, y_k, m0_k, S0_k, A_k, Q_k, C_k, R_k_tv): @@ -375,21 +367,22 @@ def body(carry): cond, body, (log_s0, opt_state0, jnp.inf, jnp.array(0), jnp.array(False)) ) - # for cropping only: NumPy view of Rs + # For cropping only: host view Rs_np = np.asarray(Rs) + ys_np = np.asarray(ys) - # Optimize per block (shared s within each block) + # Optimize per block (shared s) for block in (blocks or []): sel = jnp.asarray(block, dtype=int) if s_frames and len(s_frames) > 0: - y_block_list = [crop_frames(ys[int(k)], s_frames) for k in block] # (T', obs) - R_block_list = [crop_R_tv(Rs_np[int(k)], s_frames) for k in block] # (T', obs, obs) - y_block = jnp.asarray(np.stack(y_block_list, axis=0)) # (B, T', obs) - R_block = jnp.asarray(np.stack(R_block_list, axis=0)) # (B, T', obs, obs) + y_block_list = [crop_frames(ys_np[int(k)], s_frames) for k in block] # (T', obs) + R_block_list = [crop_R(Rs_np[int(k)], s_frames) for k in block] # (T', obs, obs) + y_block = jnp.asarray(np.stack(y_block_list, axis=0)) # (B, T', obs) + R_block = jnp.asarray(np.stack(R_block_list, axis=0)) # (B, T', obs, obs) else: - y_block = ys_j[sel] # (B, T, obs) - R_block = Rs[sel] # (B, T, obs, obs) + y_block = ys[sel] # (B, T, obs) + R_block = Rs[sel] # (B, T, obs, obs) m0_block = m0s[sel] S0_block = S0s[sel] diff --git a/eks/ibl_pupil_smoother.py b/eks/ibl_pupil_smoother.py index a76b773..0f0de25 100644 --- a/eks/ibl_pupil_smoother.py +++ b/eks/ibl_pupil_smoother.py @@ -18,7 +18,7 @@ from eks.core import ensemble, params_nlgssm_for_keypoint from eks.marker_array import MarkerArray, 
input_dfs_to_markerArray -from eks.utils import build_R_from_vars, crop_frames, format_data, make_dlc_pandas_index +from eks.utils import build_R_from_vars, crop_frames, crop_R, format_data, make_dlc_pandas_index @typechecked @@ -244,10 +244,19 @@ def ensemble_kalman_smoother_ibl_pupil( # ------------------------------------------------------- # Perform filtering with SINGLE PAIR of diameter_s, com_s # ------------------------------------------------------- - s_finals, ms, Vs, nll = pupil_optimize_smooth( - y_obs, m0, S0, C, ensemble_vars, - np.var(pupil_diameters), np.var(x_t_obs), np.var(y_t_obs), s_frames, smooth_params, - verbose=verbose) + s_finals, ms, Vs = run_pupil_kalman_smoother( + ys=jnp.asarray(y_obs), + m0=jnp.asarray(m0), + S0=jnp.asarray(S0), + C=jnp.asarray(C), + ensemble_vars=ensemble_vars, + diameters_var=np.var(pupil_diameters), + x_var=np.var(x_t_obs), + y_var=np.var(y_t_obs), + s_frames=s_frames, + smooth_params=smooth_params, + verbose=verbose + ) if verbose: print(f"diameter_s={s_finals[0]}, com_s={s_finals[1]}") # Smoothed posterior over ys @@ -308,123 +317,229 @@ def ensemble_kalman_smoother_ibl_pupil( return markers_df, s_finals +# ----------------- Public API ----------------- @typechecked -def pupil_optimize_smooth( - ys: np.ndarray, # (T, 8) centered obs - m0: np.ndarray, # (3,) - S0: np.ndarray, # (3,3) - C: np.ndarray, # (8,3) - ensemble_vars: np.ndarray, # (T, 8) +def run_pupil_kalman_smoother( + ys: jnp.ndarray, # (T, 8) centered obs + m0: jnp.ndarray, # (3,) + S0: jnp.ndarray, # (3,3) + C: jnp.ndarray, # (8,3) + ensemble_vars: np.ndarray, # (T, 8) diameters_var: Real, - x_var: float, - y_var: float, - s_frames: Optional[List[Tuple[Optional[int], Optional[int]]]] = [(1, 2000)], - smooth_params: list | None = None, # [diam_s, com_s] in (0,1) - maxiter: int = 1000, # retained (unused with tol-loop) + x_var: Real, + y_var: Real, + s_frames: Optional[List[Tuple[Optional[int], Optional[int]]]] = None, + smooth_params: Optional[list] = 
None, # [s_diam, s_com] in (0,1) verbose: bool = False, # optimizer/loop knobs lr: float = 5e-3, tol: float = 1e-6, safety_cap: int = 5000, -) -> tuple: +) -> Tuple[List[float], np.ndarray, np.ndarray]: """ - Dynamax backend: optimize [s_diameter, s_com] with EKF NLL, then EKF smoother. + Optimize pupil AR(1) smoothing params `[s_diam, s_com]` via EKF filter NLL with + time-varying R_t built from ensemble variances, then run EKF smoother for final + trajectories. + + Args: + ys: (T, 8) centered observations (order: top,bottom,right,left x/y). + m0: (3,) initial state mean [diameter, com_x, com_y]. + S0: (3,3) initial state covariance. + C: (8,3) observation matrix mapping state -> 8 observed coords. + ensemble_vars: (T, 8) per-dimension ensemble variances; used to build R_t. + diameters_var: variance scale for diameter latent. + x_var, y_var: variance scales for com_x, com_y latents. + s_frames: optional list of (start, end) 1-based, inclusive frame ranges for + NLL optimization only (final smoothing runs over the full T). + smooth_params: if provided, use `[s_diam, s_com]` directly (values in (0,1)). + verbose: print optimization summary. + lr: Adam learning rate on the unconstrained parameters. + tol: relative tolerance for early stopping. + safety_cap: hard limit on optimizer steps inside the jitted loop. 
Returns: - s_finals (list[float]), ms (T,3), Vs (T,3,3), nll (float) + (s_finals, ms, Vs): + s_finals: [s_diam, s_com] + ms: (T, 3) smoothed state means + Vs: (T, 3, 3) smoothed state covariances """ + # build time-varying R_t (T, 8, 8) and JAX-ify inputs + R = jnp.asarray(build_R_from_vars(ensemble_vars)) + + # --- optimize [s_diam, s_com] on cropped loss (if requested) --- + s_d, s_c = pupil_optimize_smooth( + ys=ys, + m0=m0, + S0=S0, + C=C, + R=R, + diameters_var=diameters_var, + x_var=x_var, + y_var=y_var, + s_frames=s_frames, + smooth_params=smooth_params, + lr=lr, + tol=tol, + safety_cap=safety_cap, + verbose=verbose, + ) - # logistic reparam to keep s in (eps,1-eps) + # --- final smoother on full sequence with A(s), Q(s) and supplied R_t --- + s_d_j, s_c_j = jnp.asarray(s_d), jnp.asarray(s_c) + A = jnp.diag(jnp.array([s_d_j, s_c_j, s_c_j])) + Q = jnp.diag(jnp.array([ + jnp.asarray(diameters_var) * (1.0 - s_d_j**2), + jnp.asarray(x_var) * (1.0 - s_c_j**2), + jnp.asarray(y_var) * (1.0 - s_c_j**2), + ])) + + f_fn = (lambda x: A @ x) + h_fn = (lambda x: C @ x) + # Pass Q as exact and s=1.0 (we already encoded s into A, Q) + params = params_nlgssm_for_keypoint(m0, S0, Q, 1.0, R, f_fn, h_fn) + + sm = extended_kalman_smoother(params, ys) + ms = np.array(getattr(sm, "smoothed_means", sm.filtered_means)) + Vs = np.array(getattr(sm, "smoothed_covariances", sm.filtered_covariances)) + return [float(s_d), float(s_c)], ms, Vs + + +# ----------------- Optimizer (two-parameter AR(1)) ----------------- +@typechecked +def pupil_optimize_smooth( + ys: jnp.ndarray, # (T, 8) centered obs + m0: jnp.ndarray, # (3,) + S0: jnp.ndarray, # (3,3) + C: jnp.ndarray, # (8,3) + R: jnp.ndarray, # (T, 8, 8) time-varying obs covariance + diameters_var: Real, + x_var: Real, + y_var: Real, + s_frames: Optional[List[Tuple[Optional[int], Optional[int]]]] = None, + smooth_params: Optional[list] = None, # [s_diam, s_com] in (0,1) + lr: float = 5e-3, + tol: float = 1e-6, + safety_cap: int = 5000, + 
verbose: bool = False, +) -> Tuple[float, float]: + """ + Optimize `[s_diam, s_com]` for the pupil AR(1) model by minimizing EKF filter + negative log-likelihood on (optionally) cropped data. Uses a logistic reparam + to keep the parameters in (ε, 1−ε). Returns the optimized pair. + + Parameters + ---------- + ys : jnp.ndarray, shape (T, 8) + Centered observations. + m0 : jnp.ndarray, shape (3,) + Initial state mean. + S0 : jnp.ndarray, shape (3, 3) + Initial state covariance. + C : jnp.ndarray, shape (8, 3) + Observation matrix. + R : jnp.ndarray, shape (T, 8, 8) + Time-varying observation covariance. + diameters_var : Real + Variance scale for diameter latent. + x_var, y_var : Real + Variance scales for com_x and com_y latents. + s_frames : list[(start, end)] or None + 1-based start, inclusive end cropping ranges for the loss only. + smooth_params : Optional[Sequence[Real]] + If provided and both values are not None, bypass optimization and use them directly. + lr : float + Adam learning rate on the unconstrained variables. + tol : float + Relative tolerance for early stopping. + safety_cap : int + Hard iteration cap in the jitted loop. + verbose : bool + Print optimization summary if True. + + Returns + ------- + (s_diam, s_com) : Tuple[float, float] + Optimized AR(1) parameters in (0, 1). 
+ """ + # Map unconstrained u -> s in (eps, 1-eps) def _to_stable_s(u, eps=1e-3): return jax.nn.sigmoid(u) * (1.0 - 2 * eps) + eps - # crop ys and ev for the loss if s_frames provided + # Cropping for loss (host-side), then back to JAX + ys_np = np.asarray(ys) + R_np = np.asarray(R) if s_frames and len(s_frames) > 0: - y_cropped = crop_frames(ys, s_frames) # (T', 8) - ev_cropped = crop_frames(ensemble_vars, s_frames) # (T', 8) + y_loss = jnp.asarray(crop_frames(ys_np, s_frames)) # (T', 8) + R_loss = jnp.asarray(crop_R(R_np, s_frames)) # (T', 8, 8) else: - y_cropped, ev_cropped = ys, ensemble_vars - - # build time-varying R_t for loss and for final smoothing (full sequence) - R_loss = build_R_from_vars(ev_cropped) # (T' ,8,8) - R_full = build_R_from_vars(ensemble_vars) # (T ,8,8) + y_loss = ys + R_loss = R - # jnp once - y_c = jnp.asarray(y_cropped) - R_loss = jnp.asarray(R_loss) - m0_j, S0_j, C_j = jnp.asarray(m0), jnp.asarray(S0), jnp.asarray(C) - y_full = jnp.asarray(ys) - R_full = jnp.asarray(R_full) - - # local params builder using your NLGSSM wrapper; pass Q_exact with s=1.0 + # Params builder with Q exact and s=1.0 (A, Q depend on s directly) def _params_linear(m0, S0, A, Q_exact, R_any, C): f_fn = (lambda x, A=A: A @ x) h_fn = (lambda x, C=C: C @ x) return params_nlgssm_for_keypoint(m0, S0, Q_exact, 1.0, R_any, f_fn, h_fn) - # EKF NLL for a given unconstrained u = [u_d, u_c] + # NLL(u) with u = [u_diam, u_com] def _nll_from_u(u: jnp.ndarray) -> jnp.ndarray: s_d, s_c = _to_stable_s(u) A = jnp.diag(jnp.array([s_d, s_c, s_c])) Q = jnp.diag(jnp.array([ - diameters_var * (1.0 - s_d**2), - x_var * (1.0 - s_c**2), - y_var * (1.0 - s_c**2), + jnp.asarray(diameters_var) * (1.0 - s_d**2), + jnp.asarray(x_var) * (1.0 - s_c**2), + jnp.asarray(y_var) * (1.0 - s_c**2), ])) - params = _params_linear(m0_j, S0_j, A, Q, R_loss, C_j) - post = extended_kalman_filter(params, y_c) + params = _params_linear(m0, S0, A, Q, R_loss, C) + post = extended_kalman_filter(params, y_loss) 
return -post.marginal_loglik + # If user provided both params, just use them + if smooth_params is not None and all(v is not None for v in smooth_params): + s = jnp.clip(jnp.asarray(smooth_params, dtype=jnp.float32), 1e-3, 1 - 1e-3) + return float(s[0]), float(s[1]) + + # Otherwise optimize in unconstrained space optimizer = optax.adam(lr) - if smooth_params is None or smooth_params[0] is None or smooth_params[1] is None: - # init near your old guess (invert logistic) - s0 = jnp.array([0.99, 0.98]) - u0 = jnp.log(s0 / (1.0 - s0)) - opt_state0 = optimizer.init(u0) - - @jit - def _opt_step(u, opt_state): - loss, grad = value_and_grad(_nll_from_u)(u) - updates, opt_state = optimizer.update(grad, opt_state) - u = optax.apply_updates(u, updates) - return u, opt_state, loss - - @jit - def _run_tol_loop(u0, opt_state0): - def cond(carry): - _, _, prev_loss, iters, done = carry - return jnp.logical_and(~done, iters < safety_cap) - - def body(carry): - u, opt_state, prev_loss, iters, _ = carry - u, opt_state, loss = _opt_step(u, opt_state) - rel_tol = tol * jnp.abs(jnp.log(jnp.maximum(prev_loss, 1e-12))) - done = jnp.where(jnp.isfinite(prev_loss), - jnp.linalg.norm(loss - prev_loss) < (rel_tol + 1e-6), - False) - return (u, opt_state, loss, iters + 1, done) - return lax.while_loop( - cond, body, (u0, opt_state0, jnp.inf, jnp.array(0), jnp.array(False)) + s0 = jnp.array([0.99, 0.98], dtype=jnp.float32) + u0 = jnp.log(s0 / (1.0 - s0)) + opt_state0 = optimizer.init(u0) + + @jit + def _opt_step(u, opt_state): + loss, grad = value_and_grad(_nll_from_u)(u) + updates, opt_state = optimizer.update(grad, opt_state) + u = optax.apply_updates(u, updates) + return u, opt_state, loss + + @jit + def _run_tol_loop(u0, opt_state0): + def cond(carry): + _, _, prev_loss, iters, done = carry + return jnp.logical_and(~done, iters < safety_cap) + + def body(carry): + u, opt_state, prev_loss, iters, _ = carry + u, opt_state, loss = _opt_step(u, opt_state) + rel_tol = tol * 
jnp.abs(jnp.log(jnp.maximum(prev_loss, 1e-12))) + done = jnp.where( + jnp.isfinite(prev_loss), + jnp.linalg.norm(loss - prev_loss) < (rel_tol + 1e-6), + False ) + return (u, opt_state, loss, iters + 1, done) - u_f, opt_state_f, last_loss, iters_f, _ = _run_tol_loop(u0, opt_state0) - s_opt = _to_stable_s(u_f) - if verbose: - print(f"[pupil/dynamax] iters={int(iters_f)} s_diam={float(s_opt[0]):.6f} " - f"s_com={float(s_opt[1]):.6f} NLL={float(last_loss):.6f}") - else: - s_user = jnp.clip(jnp.asarray(smooth_params, dtype=jnp.float32), 1e-3, 1 - 1e-3) - s_opt = s_user - - # final smoother on full sequence with full R_t - s_d, s_c = float(s_opt[0]), float(s_opt[1]) - ms, Vs, nll = pupil_smooth( - smooth_params=[s_d, s_c], - ys=y_full, m0=m0_j, S0=S0_j, C=C_j, R=R_full, - diameters_var=diameters_var, x_var=x_var, y_var=y_var, - return_full=True - ) - return [s_d, s_c], np.asarray(ms), np.asarray(Vs), float(nll) + return lax.while_loop( + cond, body, (u0, opt_state0, jnp.inf, jnp.array(0), jnp.array(False)) + ) + + u_f, _opt_state_f, last_loss, iters_f, _ = _run_tol_loop(u0, opt_state0) + s_opt = _to_stable_s(u_f) + if verbose: + print(f"[pupil/dynamax] iters={int(iters_f)} " + f"s_diam={float(s_opt[0]):.6f} s_com={float(s_opt[1]):.6f} " + f"NLL={float(last_loss):.6f}") + return float(s_opt[0]), float(s_opt[1]) @typechecked diff --git a/eks/multicam_smoother.py b/eks/multicam_smoother.py index 1f30062..b90756f 100644 --- a/eks/multicam_smoother.py +++ b/eks/multicam_smoother.py @@ -287,12 +287,12 @@ def ensemble_kalman_smoother_multicam( # Optimize smoothing s_finals, ms, Vs = run_kalman_smoother( - Qs=Qs, - ys=ys, + ys=jnp.asarray(ys), m0s=m0s, S0s=S0s, - Cs=Cs, As=As, + Cs=Cs, + Qs=Qs, ensemble_vars=np.swapaxes(ensemble_vars, 0, 1), s_frames=s_frames, smooth_param=smooth_param, From 648064a3e0128ca21dd2df3d168864590fe6d840 Mon Sep 17 00:00:00 2001 From: Keemin Lee Date: Mon, 6 Oct 2025 16:15:13 -0400 Subject: [PATCH 11/11] singlecam compat fix --- 
eks/singlecam_smoother.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/eks/singlecam_smoother.py b/eks/singlecam_smoother.py index 870a04c..5e6c28e 100644 --- a/eks/singlecam_smoother.py +++ b/eks/singlecam_smoother.py @@ -94,7 +94,7 @@ def ensemble_kalman_smoother_singlecam( keypoint_names: List of body parts to run smoothing on smooth_param: value in (0, Inf); smaller values lead to more smoothing s_frames: List of frames for automatic computation of smoothing parameter - blocks: keypoints to be blocked for correlated noise. Generates on smoothing param per + blocks: keypoints to be blocked for correlated noise. Generates one smoothing param per block, as opposed to per keypoint. Specified by the form "x1, x2, x3; y1, y2" referring to keypoint indices (start at 0) avg_mode: mode for averaging across ensemble @@ -134,12 +134,21 @@ def ensemble_kalman_smoother_singlecam( # Prepare params for singlecam_optimize_smooth() ys = emA_centered_preds.get_array(squeeze=True).transpose(1, 0, 2) - m0s, S0s, As, cov_mats, Cs = initialize_kalman_filter(emA_centered_preds) + m0s, S0s, As, Qs, Cs = initialize_kalman_filter(emA_centered_preds) # Main smoothing function s_finals, ms, Vs = run_kalman_smoother( - cov_mats, ys, m0s, S0s, Cs, As, emA_vars.get_array(squeeze=True), - s_frames, smooth_param, blocks, verbose=verbose + ys=jnp.asarray(ys), + m0s=m0s, + S0s=S0s, + As=As, + Cs=Cs, + Qs=Qs, + ensemble_vars=emA_vars.get_array(squeeze=True), + s_frames=s_frames, + smooth_param=smooth_param, + blocks=blocks, + verbose=verbose ) y_m_smooths = np.zeros((n_keypoints, n_frames, 2))