diff --git a/eks/core.py b/eks/core.py index d00a49a..bb8dc6a 100644 --- a/eks/core.py +++ b/eks/core.py @@ -1,21 +1,14 @@ -from functools import partial - import jax -import jax.scipy as jsc import numpy as np import optax -from jax import jit -from jax import numpy as jnp -from jax import vmap +from dynamax.nonlinear_gaussian_ssm import ParamsNLGSSM, extended_kalman_filter, \ + extended_kalman_smoother +from jax import numpy as jnp, vmap, jit, value_and_grad, lax from typeguard import typechecked -from typing import List, Literal, Optional, Tuple, Union +from typing import Literal, Union, Optional, List, Tuple from eks.marker_array import MarkerArray -from eks.utils import crop_frames - -# ------------------------------------------------------------------------------------- -# Kalman Functions: Functions related to performing filtering and smoothing -# ------------------------------------------------------------------------------------- +from eks.utils import build_R_from_vars, crop_frames, crop_R @typechecked @@ -86,352 +79,6 @@ def compute_stats(data_x, data_y, data_lh): return ensemble_marker_array -@typechecked -def kalman_filter_step( - carry, - inputs -) -> Tuple[tuple, Tuple[jnp.ndarray, jnp.ndarray, jax.Array]]: - """ - Performs a single Kalman filter update step using time-varying observation noise - from ensemble variance. - - Used in a scan loop, updating the state mean and covariance - based on the current observation and its associated ensemble variance. - - Args: - carry: Tuple containing the previous state and model parameters: - - m_prev (jnp.ndarray): Previous state estimate (mean vector). - - V_prev (jnp.ndarray): Previous state covariance matrix. - - A (jnp.ndarray): State transition matrix. - - Q (jnp.ndarray): Process noise covariance matrix. - - C (jnp.ndarray): Observation matrix. - - nll_net (float): Accumulated negative log-likelihood. - inputs: Tuple containing the current observation and its estimated ensemble variance: - - curr_y (jnp.ndarray): Current observation vector. - - curr_ensemble_var (jnp.ndarray): Estimated observation noise variance - (used to build time-varying R matrix). - - Returns: - A tuple of two elements: - - carry (tuple): Updated (m_t, V_t, A, Q, C, nll_net) to pass to the next step. - - output (tuple): Tuple of: - - m_t (jnp.ndarray): Updated state mean. - - V_t (jnp.ndarray): Updated state covariance. - - nll_current (float, stored as jax.Array): NLL of the current observation. - """ - m_prev, V_prev, A, Q, C, nll_net = carry - curr_y, curr_ensemble_var = inputs - - # Update R with time-varying ensemble variance - R = jnp.diag(curr_ensemble_var) - - # Predict - m_pred = jnp.dot(A, m_prev) - V_pred = jnp.dot(A, jnp.dot(V_prev, A.T)) + Q - - # Update - innovation = curr_y - jnp.dot(C, m_pred) - innovation_cov = jnp.dot(C, jnp.dot(V_pred, C.T)) + R - K = jnp.dot(V_pred, jnp.dot(C.T, jnp.linalg.inv(innovation_cov))) - m_t = m_pred + jnp.dot(K, innovation) - V_t = jnp.dot((jnp.eye(V_pred.shape[0]) - jnp.dot(K, C)), V_pred) - - nll_current = single_timestep_nll(innovation, innovation_cov) - nll_net = nll_net + nll_current - - return (m_t, V_t, A, Q, C, nll_net), (m_t, V_t, nll_current) - - -@typechecked -def kalman_filter_step_nlls( - carry: tuple, - inputs: tuple -) -> Tuple[tuple, Tuple[jnp.ndarray, jnp.ndarray, float]]: - """ - Performs a single Kalman filter update step and records per-timestep negative - log-likelihoods (NLLs) into a preallocated array. - - Used inside a `lax.scan` loop. 
In addition to updating the state estimate and total NLL, - it writes the NLL of each timestep into a persistent array for later analysis/plotting. - - Args: - carry: Tuple containing: - - m_prev (jnp.ndarray): Previous state estimate (mean vector). - - V_prev (jnp.ndarray): Previous state covariance matrix. - - A (jnp.ndarray): State transition matrix. - - Q (jnp.ndarray): Process noise covariance matrix. - - C (jnp.ndarray): Observation matrix. - - nll_net (float): Cumulative negative log-likelihood. - - nll_array (jnp.ndarray): Preallocated array for per-step NLL values. - - t (int): Current timestep index into the NLL array. - - inputs: Tuple containing: - - curr_y (jnp.ndarray): Current observation vector. - - curr_ensemble_var (jnp.ndarray): Estimated observation noise variance, - used to construct the time-varying R matrix. - - Returns: - A tuple of: - - carry (tuple): Updated state and NLL tracking info for the next timestep. - - output (tuple): - - m_t (jnp.ndarray): Updated state mean. - - V_t (jnp.ndarray): Updated state covariance. - - nll_current (float): Negative log-likelihood of the current timestep. - """ - # Unpack carry and inputs - m_prev, V_prev, A, Q, C, nll_net, nll_array, t = carry - curr_y, curr_ensemble_var = inputs - - # Update R with the current ensemble variance - R = jnp.diag(curr_ensemble_var) - - # Predict - m_pred = jnp.dot(A, m_prev) - V_pred = jnp.dot(A, jnp.dot(V_prev, A.T)) + Q - - # Update - innovation = curr_y - jnp.dot(C, m_pred) - innovation_cov = jnp.dot(C, jnp.dot(V_pred, C.T)) + R - K = jnp.dot(V_pred, jnp.dot(C.T, jnp.linalg.inv(innovation_cov))) - m_t = m_pred + jnp.dot(K, innovation) - V_t = V_pred - jnp.dot(K, jnp.dot(C, V_pred)) - - # Compute the negative log-likelihood for the current time step - nll_current = single_timestep_nll(innovation, innovation_cov) - - # Accumulate the negative log-likelihood - nll_net = nll_net + nll_current - - # Save the current NLL to the preallocated array - nll_array = nll_array.at[t].set(nll_current) - - # Increment the time step - t = t + 1 - - # Return the updated state and outputs - return (m_t, V_t, A, Q, C, nll_net, nll_array, t), (m_t, V_t, nll_current) - - -@partial(jit, backend='cpu') -def forward_pass( - y: jnp.ndarray, - m0: jnp.ndarray, - cov0: jnp.ndarray, - A: jnp.ndarray, - Q: jnp.ndarray, - C: jnp.ndarray, - ensemble_vars: jnp.ndarray -) -> Tuple[jnp.ndarray, jnp.ndarray, float]: - """ - Executes the Kalman filter forward pass for a single keypoint over time, - incorporating time-varying observation noise variances. - - Computes filtered state means, covariances, and the cumulative - negative log-likelihood across all timesteps. Used within `vmap` to - handle multiple keypoints in parallel. - - Args: - y: Array of shape (T, obs_dim). Sequence of observations over time. - m0: Array of shape (state_dim,). Initial state estimate. - cov0: Array of shape (state_dim, state_dim). Initial state covariance. - A: Array of shape (state_dim, state_dim). State transition matrix. - Q: Array of shape (state_dim, state_dim). Process noise covariance matrix. - C: Array of shape (obs_dim, state_dim). Observation matrix. - ensemble_vars: Array of shape (T, obs_dim). Per-frame observation noise variances. - - Returns: - mfs: Array of shape (T, state_dim). Filtered mean estimates at each timestep. - Vfs: Array of shape (T, state_dim, state_dim). Filtered covariance estimates at each timestep. - nll_net: Scalar float. Total negative log-likelihood across all timesteps. 
- """ - # Initialize carry - carry = (m0, cov0, A, Q, C, 0) - # Run the scan, passing y and ensemble_vars as inputs to kalman_filter_step - carry, outputs = jax.lax.scan(kalman_filter_step, carry, (y, ensemble_vars)) - mfs, Vfs, _ = outputs - nll_net = carry[-1] - return mfs, Vfs, nll_net - - -@typechecked -def kalman_smoother_step( - carry: tuple, - X: list, -) -> Tuple[tuple, Tuple[jnp.ndarray, jnp.ndarray]]: - """ - Performs a single backward pass of the Kalman smoother. - - Updates the smoothed state estimate and covariance based on the - current filtered estimate and the next time step's smoothed estimate. Used - within a `jax.lax.scan` in reverse over the time axis. - - Args: - carry: Tuple containing: - - m_ahead_smooth (jnp.ndarray): Smoothed state mean at the next timestep. - - v_ahead_smooth (jnp.ndarray): Smoothed state covariance at the next timestep. - - A (jnp.ndarray): State transition matrix. - - Q (jnp.ndarray): Process noise covariance matrix. - - X: Tuple containing: - - m_curr_filter (jnp.ndarray): Filtered mean estimate at the current timestep. - - v_curr_filter (jnp.ndarray): Filtered covariance at the current timestep. - - Returns: - A tuple of: - - carry (tuple): Updated smoothed state (mean, cov) and model params for the next step. - - output (tuple): - - smoothed_state (jnp.ndarray): Smoothed mean estimate at the current timestep. - - smoothed_cov (jnp.ndarray): Smoothed covariance at the current timestep. - """ - m_ahead_smooth, v_ahead_smooth, A, Q = carry - m_curr_filter, v_curr_filter = X[0], X[1] - - # Compute the smoother gain - ahead_cov = jnp.dot(A, jnp.dot(v_curr_filter, A.T)) + Q - - smoothing_gain = jsc.linalg.solve(ahead_cov, jnp.dot(A, v_curr_filter.T)).T - smoothed_state = m_curr_filter + jnp.dot(smoothing_gain, m_ahead_smooth - m_curr_filter) - smoothed_cov = v_curr_filter + jnp.dot(jnp.dot(smoothing_gain, v_ahead_smooth - ahead_cov), - smoothing_gain.T) - - return (smoothed_state, smoothed_cov, A, Q), (smoothed_state, smoothed_cov) - - -# @typechecked -- raises InstrumentationWarning as @jit rewrites into compiled form (JAX XLA) -@partial(jit, backend='cpu') -def backward_pass( - mfs: jnp.ndarray, - Vfs: jnp.ndarray, - A: jnp.ndarray, - Q: jnp.ndarray -) -> Tuple[jnp.ndarray, jnp.ndarray]: - """ - Executes the Kalman smoother backward pass using filtered means and covariances. - - Refines forward-filtered estimates by incorporating future observations. - Used after a Kalman filter forward pass to recover more accurate state estimates. - - Args: - mfs: Array of shape (T, state_dim). Filtered state means from the forward pass. - Vfs: Array of shape (T, state_dim, state_dim). Filtered covariances from the forward pass. - A: Array of shape (state_dim, state_dim). State transition matrix. - Q: Array of shape (state_dim, state_dim). Process noise covariance matrix. - - Returns: - smoothed_states: Array of shape (T, state_dim). Smoothed state mean estimates. - smoothed_state_covariances: Array of shape (T, state_dim, state_dim). - Smoothed state covariance estimates. 
- """ - carry = (mfs[-1], Vfs[-1], A, Q) - - # Reverse scan over the time steps - carry, outputs = jax.lax.scan( - kalman_smoother_step, - carry, - [mfs[:-1], Vfs[:-1]], - reverse=True - ) - - smoothed_states, smoothed_state_covariances = outputs - smoothed_states = jnp.append(smoothed_states, jnp.expand_dims(mfs[-1], 0), 0) - smoothed_state_covariances = jnp.append(smoothed_state_covariances, - jnp.expand_dims(Vfs[-1], 0), 0) - return smoothed_states, smoothed_state_covariances - - -@typechecked -def single_timestep_nll( - innovation: jnp.ndarray, - innovation_cov: jnp.ndarray -) -> jax.Array: - """ - Computes the negative log-likelihood (NLL) of a single multivariate Gaussian observation. - - Measures how well the predicted state explains the current observation. - A small regularization term (epsilon) is added to the covariance to ensure numerical stability. - - Args: - innovation: Array of shape (D,). The difference between observed and predicted observation. - innovation_cov: Array of shape (D, D). Covariance of the innovation. - - Returns: - nll_increment: Scalar float stored as a jax.Array. - Negative log-likelihood of observing the current innovation. - """ - epsilon = 1e-6 - n_coords = innovation.shape[0] - - # Regularize the innovation covariance matrix by adding epsilon to the diagonal - reg_innovation_cov = innovation_cov + epsilon * jnp.eye(n_coords) - - # Compute the log determinant of the regularized covariance matrix - log_det_S = jnp.log(jnp.abs(jnp.linalg.det(reg_innovation_cov)) + epsilon) - solved_term = jnp.linalg.solve(reg_innovation_cov, innovation) - quadratic_term = jnp.dot(innovation, solved_term) - - # Compute the NLL increment for the current time step - c = jnp.log(2 * jnp.pi) * n_coords # The Gaussian normalization constant part - nll_increment = 0.5 * jnp.abs(log_det_S + quadratic_term + c) - return nll_increment - - -@typechecked -def final_forwards_backwards_pass( - process_cov: jnp.ndarray, - s: np.ndarray, - ys: np.ndarray, - m0s: jnp.ndarray, - S0s: jnp.ndarray, - Cs: jnp.ndarray, - As: jnp.ndarray, - ensemble_vars: np.ndarray, -) -> Tuple[np.ndarray, np.ndarray]: - """ - Runs the full Kalman forward-backward smoother across all keypoints using - optimized smoothing parameters. - - Computes smoothed state means and covariances for each keypoint over time. - The process noise covariance is scaled per-keypoint by a learned smoothing parameter `s`. - - Args: - process_cov: Array of shape (K, D, D). Base process noise covariance per keypoint. - s: Array of shape (K,). Smoothing scalars applied to process_cov per keypoint. - ys: Array of shape (K, T, obs_dim). Observations per keypoint over time. - m0s: Array of shape (K, D). Initial state mean per keypoint. - S0s: Array of shape (K, D, D). Initial state covariance per keypoint. - Cs: Array of shape (K, obs_dim, D). Observation matrix per keypoint. - As: Array of shape (K, D, D). State transition matrix per keypoint. - ensemble_vars: Array of shape (T, K, obs_dim). Time-varying obs variances per keypoint. - - Returns: - smoothed_means: Array of shape (K, T, D). Smoothed state means for each keypoint over time. - smoothed_covariances: Array of shape (K, T, D, D). Smoothed state covariances over time. 
- """ - - # Initialize - n_keypoints = ys.shape[0] - ms_array = [] - Vs_array = [] - Qs = s[:, None, None] * process_cov - - # Run forward and backward pass for each keypoint - for k in range(n_keypoints): - mf, Vf, nll = forward_pass( - ys[k], m0s[k], S0s[k], As[k], Qs[k], Cs[k], ensemble_vars[:, k, :]) - ms, Vs = backward_pass(mf, Vf, As[k], Qs[k]) - - ms_array.append(np.array(ms)) - Vs_array.append(np.array(Vs)) - - smoothed_means = np.stack(ms_array, axis=0) - smoothed_covariances = np.stack(Vs_array, axis=0) - - return smoothed_means, smoothed_covariances - -# ------------------------------------------------------------------------------------- -# Optimization: Functions related to optimizing the smoothing hyperparameter -# ------------------------------------------------------------------------------------- - @typechecked def compute_initial_guesses( @@ -461,198 +108,298 @@ def compute_initial_guesses( # Compute temporal differences temporal_diffs = ensemble_vars[1:] - ensemble_vars[:-1] - # Compute standard deviation across all temporal differences std_dev_guess = round(np.nanstd(temporal_diffs), 5) return float(std_dev_guess) +def params_nlgssm_for_keypoint(m0, S0, Q, s, R, f_fn, h_fn) -> ParamsNLGSSM: + """ + Construct the ParamsNLGSSM for a single (keypoint) sequence. + """ + return ParamsNLGSSM( + initial_mean=jnp.asarray(m0), + initial_covariance=jnp.asarray(S0), + dynamics_function=f_fn, + dynamics_covariance=jnp.asarray(s) * jnp.asarray(Q), + emission_function=h_fn, + emission_covariance=jnp.asarray(R), + ) + + +# ----------------- Public API ----------------- @typechecked -def optimize_smooth_param( - cov_mats: jnp.ndarray, - ys: np.ndarray, - m0s: jnp.ndarray, - S0s: jnp.ndarray, - Cs: jnp.ndarray, - As: jnp.ndarray, - ensemble_vars: np.ndarray, +def run_kalman_smoother( + ys: jnp.ndarray, # (K, T, obs) + m0s: jnp.ndarray, # (K, D) + S0s: jnp.ndarray, # (K, D, D) + As: jnp.ndarray, # (K, D, D) + Cs: jnp.ndarray, # (K, obs, D) + Qs: jnp.ndarray, # (K, D, D) + ensemble_vars: np.ndarray, # (T, K, obs) s_frames: Optional[List] = None, smooth_param: Optional[Union[float, List[float]]] = None, blocks: Optional[List[List[int]]] = None, - maxiter: int = 1000, verbose: bool = False, + # JIT-closed constants: + lr: float = 0.25, + s_bounds_log: tuple = (-8.0, 8.0), + tol: float = 1e-3, + safety_cap: int = 5000, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ - Optimize smoothing parameters for each keypoint (or block of keypoints) using - negative log-likelihood minimization, and apply final Kalman forward-backward smoothing. + Optimize the process-noise scale `s` (shared within each block of keypoints) by minimizing + the summed EKF filter negative log-likelihood (NLL) in a *linear* state-space model, + then run the EKF smoother for final trajectories. - If `smooth_param` is provided, it is used directly. Otherwise, the function computes - initial guesses and uses gradient descent to optimize per-block values of `s`. + Model per keypoint k: + x_{t+1} = A_k x_t + w_t, y_t = C_k x_t + v_t + w_t ~ N(0, s * Q_k), v_t ~ N(0, R_{k,t}), with time-varying R_{k,t}. Args: - cov_mats: Array of shape (K, D, D). Base process noise covariances per keypoint. - ys: Array of shape (K, T, obs_dim). Observations per keypoint over time. - m0s: Array of shape (K, D). Initial state means per keypoint. - S0s: Array of shape (K, D, D). Initial state covariances per keypoint. - Cs: Array of shape (K, obs_dim, D). Observation matrices per keypoint. - As: Array of shape (K, D, D). 
State transition matrices per keypoint. - ensemble_vars: Array of shape (T, K, obs_dim). Time-varying ensemble variances. - s_frames: Optional list of frame indices for computing initial guess statistics. - smooth_param: Optional fixed value(s) of smoothing param `s`. - Can be a float or list of floats (one per keypoint/block). - blocks: Optional list of lists of keypoint indices to share a smoothing param. - Defaults to treating each keypoint independently. - maxiter: Max number of optimization steps per block. - verbose: If True, print progress logs. + ys: (K, T, obs) observations per keypoint over time. + m0s: (K, D) initial state means. + S0s: (K, D, D) initial state covariances. + As: (K, D, D) transition matrices. + Cs: (K, obs, D) observation matrices. + Qs: (K, D, D) base process covariances (scaled by `s`). + ensemble_vars: (T, K, obs) per-frame ensemble variances; used to build R_{k,t} + via diag(clip(ensemble_vars[t, k, :], 1e-12, ∞)). + s_frames: Optional list of (start, end) tuples (1-based, inclusive end) to crop + the time axis *for the loss only*. Final smoothing uses the full sequence. + smooth_param: If provided, bypass optimization. Either a scalar (shared across K) + or a list of length K (per-keypoint). + blocks: Optional list of lists of keypoint indices; each block shares one `s`. + Default: each keypoint is its own block. + verbose: Print per-block optimization summaries if True. + lr: Adam learning rate (on log(s)). + s_bounds_log: Clamp bounds for log(s) during optimization. + tol: Relative tolerance on loss change for early stopping. + safety_cap: Hard iteration cap inside the jitted while-loop. Returns: - s_finals: Array of shape (K,). Final smoothing parameter per keypoint. - ms: Array of shape (K, T, D). Smoothed state means. - Vs: Array of shape (K, T, D, D). Smoothed state covariances. + s_finals: (K,) final `s` per keypoint (block optimum broadcast to members). + ms: (K, T, D) smoothed state means. + Vs: (K, T, D, D) smoothed state covariances. 
""" - - n_keypoints = ys.shape[0] - s_finals = [] - if blocks is None: - blocks = [] - if len(blocks) == 0: - for n in range(n_keypoints): - blocks.append([n]) + K, T, obs_dim = ys.shape + if not blocks: + blocks = [[k] for k in range(K)] if verbose: - print(f'Correlated keypoint blocks: {blocks}') + print(f"Correlated keypoint blocks: {blocks}") - @partial(jit) - def nll_loss_sequential_scan(s, cov_mats, cropped_ys, m0s, S0s, Cs, As, ensemble_vars): - s = jnp.exp(s) # To ensure positivity - return smooth_min( - s, cov_mats, cropped_ys, m0s, S0s, Cs, As, ensemble_vars) + # Build time-varying R (K, T, obs, obs) + Rs = jnp.asarray(build_R_from_vars(np.swapaxes(ensemble_vars, 0, 1))) - loss_function = nll_loss_sequential_scan + # Initial guesses per keypoint (host-side) + s_guess_per_k = np.empty(K, dtype=float) + for k in range(K): + g = float(compute_initial_guesses(ensemble_vars[:, k, :]) or 2.0) + s_guess_per_k[k] = g if (np.isfinite(g) and g > 0.0) else 2.0 - # Optimize smooth_param + # Choose or optimize s + s_finals = np.empty(K, dtype=float) if smooth_param is not None: - if isinstance(smooth_param, float): - s_finals = [smooth_param] - elif isinstance(smooth_param, int): - s_finals = [float(smooth_param)] + if isinstance(smooth_param, (int, float)): + s_finals[:] = float(smooth_param) else: - s_finals = smooth_param + s_finals[:] = np.asarray(smooth_param, dtype=float) else: - guesses = [] - cropped_ys = [] - for k in range(n_keypoints): - current_guess = compute_initial_guesses(ensemble_vars[:, k, :]) - guesses.append(current_guess) - if s_frames is None or len(s_frames) == 0: - cropped_ys.append(ys[k]) - else: - cropped_ys.append(crop_frames(ys[k], s_frames)) - - cropped_ys = np.array(cropped_ys) # Concatenation of this list along dimension 0 - - # Optimize negative log likelihood - for block in blocks: - s_init = guesses[block[0]] - if s_init <= 0: - s_init = 2 - s_init = jnp.log(s_init) - optimizer = optax.adam(learning_rate=0.25) - opt_state = optimizer.init(s_init) - - selector = np.array(block).astype(int) - cov_mats_sub = cov_mats[selector] - m0s_crop = m0s[selector] - S0s_crop = S0s[selector] - Cs_crop = Cs[selector] - As_crop = As[selector] - y_subset = cropped_ys[selector] - ensemble_vars_crop = np.swapaxes(ensemble_vars[:, selector, :], 0, 1) - - def step(s, opt_state): - loss, grads = jax.value_and_grad(loss_function)( - s, cov_mats_sub, y_subset, m0s_crop, S0s_crop, Cs_crop, As_crop, - ensemble_vars_crop) - updates, opt_state = optimizer.update(grads, opt_state) - s = optax.apply_updates(s, updates) - return s, opt_state, loss - - prev_loss = jnp.inf - for iteration in range(maxiter): - s_init, opt_state, loss = step(s_init, opt_state) - - if verbose and iteration % 10 == 0 or iteration == maxiter - 1: - print(f'Iteration {iteration}, Current loss: {loss}, Current s: {s_init}') - - tol = 0.001 * jnp.abs(jnp.log(prev_loss)) - if jnp.linalg.norm(loss - prev_loss) < tol + 1e-6: - break - prev_loss = loss - - s_final = jnp.exp(s_init) # Convert back from log-space - - for b in block: - if verbose: - print(f's={s_final} for keypoint {b}') - s_finals.append(s_final) - - s_finals = np.array(s_finals) - # Final smooth with optimized s - ms, Vs = final_forwards_backwards_pass( - cov_mats, s_finals, ys, m0s, S0s, Cs, As, ensemble_vars, - ) + optimize_smooth_param( + ys=ys, + m0s=m0s, + S0s=S0s, + As=As, + Cs=Cs, + Qs=Qs, + Rs=Rs, + blocks=blocks, + lr=lr, + s_bounds_log=s_bounds_log, + s_finals=s_finals, + s_frames=s_frames, + s_guess_per_k=s_guess_per_k, + tol=tol, + 
verbose=verbose, + safety_cap=safety_cap, + ) + + # Final smoother pass (full R_t) + def _params_linear_for_k(k: int, s_val: float): + A_k, C_k = As[k], Cs[k] + f_fn = (lambda x, A=A_k: A @ x) + h_fn = (lambda x, C=C_k: C @ x) + return params_nlgssm_for_keypoint( + m0s[k], S0s[k], Qs[k], s_val, Rs[k], f_fn, h_fn) + + means_list, covs_list = [], [] + for k in range(K): + params_k = _params_linear_for_k(k, s_finals[k]) + sm = extended_kalman_smoother(params_k, ys[k]) + if hasattr(sm, "smoothed_means"): + m_k, V_k = sm.smoothed_means, sm.smoothed_covariances + else: + m_k, V_k = sm.filtered_means, sm.filtered_covariances + means_list.append(np.array(m_k)) + covs_list.append(np.array(V_k)) + ms = np.stack(means_list, axis=0) + Vs = np.stack(covs_list, axis=0) return s_finals, ms, Vs -@typechecked -def inner_smooth_min_routine( - y: jnp.ndarray, - m0: jnp.ndarray, - S0: jnp.ndarray, - A: jnp.ndarray, - Q: jnp.ndarray, - C: jnp.ndarray, - ensemble_var: jnp.ndarray -) -> jax.Array: - # Run filtering with the current smooth_param - _, _, nll = forward_pass(y, m0, S0, A, Q, C, ensemble_var) - return nll - - -inner_smooth_min_routine_vmap = vmap(inner_smooth_min_routine, in_axes=(0, 0, 0, 0, 0, 0, 0)) - - -@typechecked -def smooth_min( - smooth_param: jax.Array, - cov_mats: jnp.ndarray, - ys: jnp.ndarray, - m0s: jnp.ndarray, - S0s: jnp.ndarray, - Cs: jnp.ndarray, - As: jnp.ndarray, - ensemble_vars: jnp.ndarray -) -> jax.Array: +# ----------------- Optimizer (blockwise s) ----------------- +def optimize_smooth_param( + ys: jnp.ndarray, # (K, T, obs) + m0s: jnp.ndarray, # (K, D) + S0s: jnp.ndarray, # (K, D, D) + As: jnp.ndarray, # (K, D, D) + Cs: jnp.ndarray, # (K, obs, D) + Qs: jnp.ndarray, # (K, D, D) + Rs: jnp.ndarray, # (K, T, obs, obs) time-varying R_t + blocks: Optional[list], + lr: float, + s_bounds_log: tuple, + s_finals: np.ndarray, # (K,), filled in-place + s_frames: Optional[list], + s_guess_per_k: np.ndarray, # (K,) + tol: float, + verbose: bool, + safety_cap: int, +) -> None: """ - Computes the total negative log-likelihood (NLL) for a given smoothing parameter - by running a full forward-pass Kalman filter over all keypoints. - - This is the objective function minimized during smoothing parameter optimization. - - Args: - smooth_param: Scalar float value of the smoothing parameter `s`. - cov_mats: Array of shape (K, D, D). Process noise covariance templates. - ys: Array of shape (K, T, obs_dim). Observations per keypoint. - m0s: Array of shape (K, D). Initial state means. - S0s: Array of shape (K, D, D). Initial state covariances. - Cs: Array of shape (K, obs_dim, D). Observation matrices. - As: Array of shape (K, D, D). State transition matrices. - ensemble_vars: Array of shape (T, K, obs_dim). Time-varying ensemble variances. - - Returns: - nlls: Scalar JAX array. Total negative log-likelihood across all keypoints. + Optimize a single scalar process-noise scale `s` per block of keypoints by minimizing + the sum of EKF filter negative log-likelihoods, using time-varying observation noise + R_{k,t}. Writes results into `s_finals` in place. + + Parameters + ---------- + ys : jnp.ndarray, shape (K, T, obs) + Observations per keypoint (JAX). For cropped loss, host-side slices are created. + m0s : jnp.ndarray, shape (K, D) + Initial state means per keypoint. + S0s : jnp.ndarray, shape (K, D, D) + Initial state covariances per keypoint. + As : jnp.ndarray, shape (K, D, D) + State transition matrices. + Cs : jnp.ndarray, shape (K, obs, D) + Observation matrices. 
+ Qs : jnp.ndarray, shape (K, D, D) + Base process covariances (scaled by `s` inside the model). + Rs : jnp.ndarray, shape (K, T, obs, obs) + Time-varying observation covariances for each keypoint. + blocks : list[list[int]] or None + Groups of keypoint indices that share a single `s`. + If None/empty, each keypoint is its own block. + lr : float + Adam learning rate (on log(s)). + s_bounds_log : (float, float) + Clamp bounds for log(s) to stabilize optimization. + s_finals : np.ndarray, shape (K,) + Output array filled with final per-keypoint `s` (block optimum broadcast). + s_frames : list or None + Frame ranges for cropping (list of (start, end); 1-based start, inclusive end). + Applied to both y and R_t for the loss only. + s_guess_per_k : np.ndarray, shape (K,) + Heuristic initial guesses of `s` per keypoint. Block init uses the mean over members. + tol : float + Relative tolerance on loss change for early stopping. + verbose : bool + If True, prints per-block optimization progress. + safety_cap : int + Maximum number of iterations inside the jitted while-loop. + + Returns + ------- + None + Results are written into `s_finals`. """ - # Adjust Q based on smooth_param and cov_matrix - Qs = smooth_param * cov_mats - nlls = jnp.sum(inner_smooth_min_routine_vmap(ys, m0s, S0s, As, Qs, Cs, ensemble_vars)) - return nlls + optimizer = optax.adam(float(lr)) + s_bounds_log_j = jnp.array(s_bounds_log, dtype=jnp.float32) + tol_j = float(tol) + + def _params_linear(m0, S0, A, Q_base, s, R_any, C): + f_fn = (lambda x, A=A: A @ x) + h_fn = (lambda x, C=C: C @ x) + return params_nlgssm_for_keypoint(m0, S0, Q_base, s, R_any, f_fn, h_fn) + + def _nll_one_keypoint(log_s, y_k, m0_k, S0_k, A_k, Q_k, C_k, R_k_tv): + s = jnp.exp(jnp.clip(log_s, s_bounds_log_j[0], s_bounds_log_j[1])) + params = _params_linear(m0_k, S0_k, A_k, Q_k, s, R_k_tv, C_k) + post = extended_kalman_filter(params, jnp.asarray(y_k)) + return -post.marginal_loglik + + def _nll_block(log_s, ys_b, m0s_b, S0s_b, As_b, Qs_b, Cs_b, Rs_b_tv): + nlls = jax.vmap(_nll_one_keypoint, in_axes=(None, 0, 0, 0, 0, 0, 0, 0))( + log_s, ys_b, m0s_b, S0s_b, As_b, Qs_b, Cs_b, Rs_b_tv + ) + return jnp.sum(nlls) + + @jit + def _opt_step(log_s, opt_state, ys_b, m0s_b, S0s_b, As_b, Qs_b, Cs_b, Rs_b_tv): + loss, grad = value_and_grad(_nll_block)( + log_s, ys_b, m0s_b, S0s_b, As_b, Qs_b, Cs_b, Rs_b_tv + ) + updates, opt_state = optimizer.update(grad, opt_state) + log_s = optax.apply_updates(log_s, updates) + return log_s, opt_state, loss + + @jit + def _run_tol_loop(log_s0, opt_state0, ys_b, m0s_b, S0s_b, As_b, Qs_b, Cs_b, Rs_b_tv): + def cond(carry): + _, _, prev_loss, iters, done = carry + return jnp.logical_and(~done, iters < safety_cap) + + def body(carry): + log_s, opt_state, prev_loss, iters, _ = carry + log_s, opt_state, loss = _opt_step( + log_s, opt_state, ys_b, m0s_b, S0s_b, As_b, Qs_b, Cs_b, Rs_b_tv + ) + rel_tol = tol_j * jnp.abs(jnp.log(jnp.maximum(prev_loss, 1e-12))) + done = jnp.where( + jnp.isfinite(prev_loss), + jnp.linalg.norm(loss - prev_loss) < (rel_tol + 1e-6), + False + ) + return (log_s, opt_state, loss, iters + 1, done) + + return lax.while_loop( + cond, body, (log_s0, opt_state0, jnp.inf, jnp.array(0), jnp.array(False)) + ) + + # For cropping only: host view + Rs_np = np.asarray(Rs) + ys_np = np.asarray(ys) + + # Optimize per block (shared s) + for block in (blocks or []): + sel = jnp.asarray(block, dtype=int) + + if s_frames and len(s_frames) > 0: + y_block_list = [crop_frames(ys_np[int(k)], s_frames) for k in block] # (T', obs) + 
R_block_list = [crop_R(Rs_np[int(k)], s_frames) for k in block] # (T', obs, obs) + y_block = jnp.asarray(np.stack(y_block_list, axis=0)) # (B, T', obs) + R_block = jnp.asarray(np.stack(R_block_list, axis=0)) # (B, T', obs, obs) + else: + y_block = ys[sel] # (B, T, obs) + R_block = Rs[sel] # (B, T, obs, obs) + + m0_block = m0s[sel] + S0_block = S0s[sel] + A_block = As[sel] + Q_block = Qs[sel] + C_block = Cs[sel] + + s0 = float(np.mean([s_guess_per_k[k] for k in block])) + log_s0 = jnp.array(np.log(max(s0, 1e-6)), dtype=jnp.float32) + opt_state0 = optimizer.init(log_s0) + + log_s_f, opt_state_f, last_loss, iters_f, _done = _run_tol_loop( + log_s0, opt_state0, y_block, m0_block, S0_block, A_block, Q_block, C_block, R_block + ) + s_star = float(jnp.exp(jnp.clip(log_s_f, s_bounds_log_j[0], s_bounds_log_j[1]))) + for k in block: + s_finals[k] = s_star + if verbose: + print(f"[Block {block}] s={s_star:.6g}, " + f"iters={int(iters_f)}, NLL={float(last_loss):.6f}") diff --git a/eks/ibl_paw_multicam_smoother.py b/eks/ibl_paw_multicam_smoother.py index 81ed91c..80d6c40 100644 --- a/eks/ibl_paw_multicam_smoother.py +++ b/eks/ibl_paw_multicam_smoother.py @@ -9,12 +9,9 @@ from eks.marker_array import ( MarkerArray, input_dfs_to_markerArray, - mA_to_stacked_array, - stacked_array_to_mA, ) from eks.multicam_smoother import ensemble_kalman_smoother_multicam -from eks.stats import compute_pca -from eks.utils import convert_lp_dlc, make_dlc_pandas_index +from eks.utils import convert_lp_dlc def remove_camera_means(ensemble_stacks, camera_means): @@ -39,6 +36,7 @@ def pca(S, n_comps): pca_ = PCA(n_components=n_comps) return pca_.fit(S), pca_.explained_variance_ratio_ + @typechecked def fit_eks_multicam_ibl_paw( input_source: str | list, @@ -67,7 +65,7 @@ def fit_eks_multicam_ibl_paw( 'var' | 'confidence_weighted_var' verbose: True to print out details img_width: The width of the image being smoothed (128 default, IBL-specific). 
- inflate_vars: True to use Mahalanobis distance thresholding to inflate ensemble variance + inflate_vars: True to use Mahalanobis distance threshold to inflate ensemble variance n_latent: number of dimensions to keep from PCA Returns: @@ -119,7 +117,7 @@ def fit_eks_multicam_ibl_paw( raise ValueError('Need timestamps for both cameras') if len(input_dfs_right) != len(input_dfs_left) or len(input_dfs_left) == 0: raise ValueError( - 'There must be the same number of left and right camera models and >=1 model for each.') + 'Need same number of left and right camera models and >=1 model for each.') # Interpolate right cam markers to left cam timestamps markers_list_stacked_interp = [] diff --git a/eks/ibl_pupil_smoother.py b/eks/ibl_pupil_smoother.py index 2209020..0f0de25 100644 --- a/eks/ibl_pupil_smoother.py +++ b/eks/ibl_pupil_smoother.py @@ -1,18 +1,24 @@ import os import warnings -from functools import partial + +from dynamax.nonlinear_gaussian_ssm.inference_ekf import ( + extended_kalman_filter, + extended_kalman_smoother, +) import jax import numpy as np import optax import pandas as pd -from jax import jit +from jax import jit, lax, value_and_grad from jax import numpy as jnp +from numbers import Real from typeguard import typechecked +from typing import List, Optional, Sequence, Tuple -from eks.core import backward_pass, ensemble, forward_pass +from eks.core import ensemble, params_nlgssm_for_keypoint from eks.marker_array import MarkerArray, input_dfs_to_markerArray -from eks.utils import crop_frames, format_data, make_dlc_pandas_index +from eks.utils import build_R_from_vars, crop_frames, crop_R, format_data, make_dlc_pandas_index @typechecked @@ -226,9 +232,6 @@ def ensemble_kalman_smoother_ibl_pupil( [-.5, 1, 0], [0, 0, 1] ]) - # placeholder diagonal matrix for ensemble variance - R = np.eye(8) - centered_ensemble_preds = ensemble_preds.copy() # subtract COM means from the ensemble predictions for i in range(ensemble_preds.shape[1]): @@ -241,10 +244,19 @@ def ensemble_kalman_smoother_ibl_pupil( # ------------------------------------------------------- # Perform filtering with SINGLE PAIR of diameter_s, com_s # ------------------------------------------------------- - s_finals, ms, Vs, nll = pupil_optimize_smooth( - y_obs, m0, S0, C, R, ensemble_vars, - np.var(pupil_diameters), np.var(x_t_obs), np.var(y_t_obs), s_frames, smooth_params, - verbose=verbose) + s_finals, ms, Vs = run_pupil_kalman_smoother( + ys=jnp.asarray(y_obs), + m0=jnp.asarray(m0), + S0=jnp.asarray(S0), + C=jnp.asarray(C), + ensemble_vars=ensemble_vars, + diameters_var=np.var(pupil_diameters), + x_var=np.var(x_t_obs), + y_var=np.var(y_t_obs), + s_frames=s_frames, + smooth_params=smooth_params, + verbose=verbose + ) if verbose: print(f"diameter_s={s_finals[0]}, com_s={s_finals[1]}") # Smoothed posterior over ys @@ -305,133 +317,281 @@ def ensemble_kalman_smoother_ibl_pupil( return markers_df, s_finals -def pupil_optimize_smooth( - ys: np.ndarray, - m0: np.ndarray, - S0: np.ndarray, - C: np.ndarray, - R: np.ndarray, - ensemble_vars: np.ndarray, - diameters_var: np.ndarray, - x_var: np.ndarray, - y_var: np.ndarray, - s_frames: list | None = [(1, 2000)], - smooth_params: list | None = [None, None], - maxiter: int = 1000, - verbose: bool = False, -) -> tuple: - """Optimize-and-smooth function for the pupil example script. - - Parameters: - ys: Observations. Shape (keypoints, frames, coordinates). - m0: Initial mean state. - S0: Initial state covariance. - C: Measurement function. - R: Measurement noise covariance. 
- ensemble_vars: Ensemble variances. - diameters_var: Diameter variance - x_var: x variance for COM - y_var: y variance for COM - s_frames: List of frames. - smooth_params: Smoothing parameter tuple (diameter_s, com_s) - verbose: Prints extra information for smoothing parameter iterations +# ----------------- Public API ----------------- +@typechecked +def run_pupil_kalman_smoother( + ys: jnp.ndarray, # (T, 8) centered obs + m0: jnp.ndarray, # (3,) + S0: jnp.ndarray, # (3,3) + C: jnp.ndarray, # (8,3) + ensemble_vars: np.ndarray, # (T, 8) + diameters_var: Real, + x_var: Real, + y_var: Real, + s_frames: Optional[List[Tuple[Optional[int], Optional[int]]]] = None, + smooth_params: Optional[list] = None, # [s_diam, s_com] in (0,1) + verbose: bool = False, + # optimizer/loop knobs + lr: float = 5e-3, + tol: float = 1e-6, + safety_cap: int = 5000, +) -> Tuple[List[float], np.ndarray, np.ndarray]: + """ + Optimize pupil AR(1) smoothing params `[s_diam, s_com]` via EKF filter NLL with + time-varying R_t built from ensemble variances, then run EKF smoother for final + trajectories. + + Args: + ys: (T, 8) centered observations (order: top,bottom,right,left x/y). + m0: (3,) initial state mean [diameter, com_x, com_y]. + S0: (3,3) initial state covariance. + C: (8,3) observation matrix mapping state -> 8 observed coords. + ensemble_vars: (T, 8) per-dimension ensemble variances; used to build R_t. + diameters_var: variance scale for diameter latent. + x_var, y_var: variance scales for com_x, com_y latents. + s_frames: optional list of (start, end) 1-based, inclusive frame ranges for + NLL optimization only (final smoothing runs over the full T). + smooth_params: if provided, use `[s_diam, s_com]` directly (values in (0,1)). + verbose: print optimization summary. + lr: Adam learning rate on the unconstrained parameters. + tol: relative tolerance for early stopping. + safety_cap: hard limit on optimizer steps inside the jitted loop. Returns: - tuple: Final smoothing parameters, smoothed means, smoothed covariances, - negative log-likelihoods, negative log-likelihood values. 
+ (s_finals, ms, Vs): + s_finals: [s_diam, s_com] + ms: (T, 3) smoothed state means + Vs: (T, 3, 3) smoothed state covariances """ + # build time-varying R_t (T, 8, 8) and JAX-ify inputs + R = jnp.asarray(build_R_from_vars(ensemble_vars)) + + # --- optimize [s_diam, s_com] on cropped loss (if requested) --- + s_d, s_c = pupil_optimize_smooth( + ys=ys, + m0=m0, + S0=S0, + C=C, + R=R, + diameters_var=diameters_var, + x_var=x_var, + y_var=y_var, + s_frames=s_frames, + smooth_params=smooth_params, + lr=lr, + tol=tol, + safety_cap=safety_cap, + verbose=verbose, + ) - @partial(jit) - def nll_loss_sequential_scan( - s_log, ys, m0, S0, C, R, ensemble_vars, diameters_var, x_var, y_var): - s = jnp.exp(s_log) # Ensure positivity - return pupil_smooth( - s, ys, m0, S0, C, R, ensemble_vars, diameters_var, x_var, y_var) - - loss_function = nll_loss_sequential_scan - # Optimize smooth_param - if smooth_params is None or smooth_params[0] is None or smooth_params[1] is None: - # Crop to only contain s_frames for time axis - y_cropped = crop_frames(ys, s_frames) - ensemble_vars_cropped = crop_frames(ensemble_vars, s_frames) - - # Optimize negative log likelihood - s_init = jnp.log(jnp.array([0.99, 0.98])) # reasonable guess for s_finals - optimizer = optax.adam(learning_rate=0.005) - opt_state = optimizer.init(s_init) - - def step(s, opt_state): - loss, grads = jax.value_and_grad(loss_function)( - s, y_cropped, m0, S0, C, R, ensemble_vars_cropped, diameters_var, x_var, y_var - ) - updates, opt_state = optimizer.update(grads, opt_state) - s = optax.apply_updates(s, updates) - return s, opt_state, loss + # --- final smoother on full sequence with A(s), Q(s) and supplied R_t --- + s_d_j, s_c_j = jnp.asarray(s_d), jnp.asarray(s_c) + A = jnp.diag(jnp.array([s_d_j, s_c_j, s_c_j])) + Q = jnp.diag(jnp.array([ + jnp.asarray(diameters_var) * (1.0 - s_d_j**2), + jnp.asarray(x_var) * (1.0 - s_c_j**2), + jnp.asarray(y_var) * (1.0 - s_c_j**2), + ])) - prev_loss = jnp.inf - for iteration in range(maxiter): - s_init, opt_state, loss = step(s_init, opt_state) + f_fn = (lambda x: A @ x) + h_fn = (lambda x: C @ x) + # Pass Q as exact and s=1.0 (we already encoded s into A, Q) + params = params_nlgssm_for_keypoint(m0, S0, Q, 1.0, R, f_fn, h_fn) - if verbose and iteration % 10 == 0 or iteration == maxiter - 1: - print(f'Iteration {iteration}, Current loss: {loss}, Current s: {jnp.exp(s_init)}') + sm = extended_kalman_smoother(params, ys) + ms = np.array(getattr(sm, "smoothed_means", sm.filtered_means)) + Vs = np.array(getattr(sm, "smoothed_covariances", sm.filtered_covariances)) + return [float(s_d), float(s_c)], ms, Vs - tol = 1e-6 * jnp.abs(jnp.log(prev_loss)) - if jnp.linalg.norm(loss - prev_loss) < tol + 1e-6: - break - prev_loss = loss - s_finals = jnp.exp(s_init) - s_finals = [round(s_finals[0], 5), round(s_finals[1], 5)] - print(f'Optimized to diameter_s={s_finals[0]}, com_s={s_finals[1]}') +# ----------------- Optimizer (two-parameter AR(1)) ----------------- +@typechecked +def pupil_optimize_smooth( + ys: jnp.ndarray, # (T, 8) centered obs + m0: jnp.ndarray, # (3,) + S0: jnp.ndarray, # (3,3) + C: jnp.ndarray, # (8,3) + R: jnp.ndarray, # (T, 8, 8) time-varying obs covariance + diameters_var: Real, + x_var: Real, + y_var: Real, + s_frames: Optional[List[Tuple[Optional[int], Optional[int]]]] = None, + smooth_params: Optional[list] = None, # [s_diam, s_com] in (0,1) + lr: float = 5e-3, + tol: float = 1e-6, + safety_cap: int = 5000, + verbose: bool = False, +) -> Tuple[float, float]: + """ + Optimize `[s_diam, s_com]` for 
the pupil AR(1) model by minimizing EKF filter + negative log-likelihood on (optionally) cropped data. Uses a logistic reparam + to keep the parameters in (ε, 1−ε). Returns the optimized pair. + + Parameters + ---------- + ys : jnp.ndarray, shape (T, 8) + Centered observations. + m0 : jnp.ndarray, shape (3,) + Initial state mean. + S0 : jnp.ndarray, shape (3, 3) + Initial state covariance. + C : jnp.ndarray, shape (8, 3) + Observation matrix. + R : jnp.ndarray, shape (T, 8, 8) + Time-varying observation covariance. + diameters_var : Real + Variance scale for diameter latent. + x_var, y_var : Real + Variance scales for com_x and com_y latents. + s_frames : list[(start, end)] or None + 1-based start, inclusive end cropping ranges for the loss only. + smooth_params : Optional[Sequence[Real]] + If provided and both values are not None, bypass optimization and use them directly. + lr : float + Adam learning rate on the unconstrained variables. + tol : float + Relative tolerance for early stopping. + safety_cap : int + Hard iteration cap in the jitted loop. + verbose : bool + Print optimization summary if True. + + Returns + ------- + (s_diam, s_com) : Tuple[float, float] + Optimized AR(1) parameters in (0, 1). + """ + # Map unconstrained u -> s in (eps, 1-eps) + def _to_stable_s(u, eps=1e-3): + return jax.nn.sigmoid(u) * (1.0 - 2 * eps) + eps + + # Cropping for loss (host-side), then back to JAX + ys_np = np.asarray(ys) + R_np = np.asarray(R) + if s_frames and len(s_frames) > 0: + y_loss = jnp.asarray(crop_frames(ys_np, s_frames)) # (T', 8) + R_loss = jnp.asarray(crop_R(R_np, s_frames)) # (T', 8, 8) else: - s_finals = smooth_params + y_loss = ys + R_loss = R + + # Params builder with Q exact and s=1.0 (A, Q depend on s directly) + def _params_linear(m0, S0, A, Q_exact, R_any, C): + f_fn = (lambda x, A=A: A @ x) + h_fn = (lambda x, C=C: C @ x) + return params_nlgssm_for_keypoint(m0, S0, Q_exact, 1.0, R_any, f_fn, h_fn) + + # NLL(u) with u = [u_diam, u_com] + def _nll_from_u(u: jnp.ndarray) -> jnp.ndarray: + s_d, s_c = _to_stable_s(u) + A = jnp.diag(jnp.array([s_d, s_c, s_c])) + Q = jnp.diag(jnp.array([ + jnp.asarray(diameters_var) * (1.0 - s_d**2), + jnp.asarray(x_var) * (1.0 - s_c**2), + jnp.asarray(y_var) * (1.0 - s_c**2), + ])) + params = _params_linear(m0, S0, A, Q, R_loss, C) + post = extended_kalman_filter(params, y_loss) + return -post.marginal_loglik + + # If user provided both params, just use them + if smooth_params is not None and all(v is not None for v in smooth_params): + s = jnp.clip(jnp.asarray(smooth_params, dtype=jnp.float32), 1e-3, 1 - 1e-3) + return float(s[0]), float(s[1]) + + # Otherwise optimize in unconstrained space + optimizer = optax.adam(lr) + s0 = jnp.array([0.99, 0.98], dtype=jnp.float32) + u0 = jnp.log(s0 / (1.0 - s0)) + opt_state0 = optimizer.init(u0) + + @jit + def _opt_step(u, opt_state): + loss, grad = value_and_grad(_nll_from_u)(u) + updates, opt_state = optimizer.update(grad, opt_state) + u = optax.apply_updates(u, updates) + return u, opt_state, loss + + @jit + def _run_tol_loop(u0, opt_state0): + def cond(carry): + _, _, prev_loss, iters, done = carry + return jnp.logical_and(~done, iters < safety_cap) + + def body(carry): + u, opt_state, prev_loss, iters, _ = carry + u, opt_state, loss = _opt_step(u, opt_state) + rel_tol = tol * jnp.abs(jnp.log(jnp.maximum(prev_loss, 1e-12))) + done = jnp.where( + jnp.isfinite(prev_loss), + jnp.linalg.norm(loss - prev_loss) < (rel_tol + 1e-6), + False + ) + return (u, opt_state, loss, iters + 1, done) - # Final smooth with 
optimized s
-    ms, Vs, nll = pupil_smooth(
-        s_finals, ys, m0, S0, C, R, ensemble_vars, diameters_var, x_var, y_var, return_full=True)
+        return lax.while_loop(
+            cond, body, (u0, opt_state0, jnp.inf, jnp.array(0), jnp.array(False))
+        )
 
-    return s_finals, ms, Vs, nll
+    u_f, _opt_state_f, last_loss, iters_f, _ = _run_tol_loop(u0, opt_state0)
+    s_opt = _to_stable_s(u_f)
+    if verbose:
+        print(f"[pupil/dynamax] iters={int(iters_f)} "
+              f"s_diam={float(s_opt[0]):.6f} s_com={float(s_opt[1]):.6f} "
+              f"NLL={float(last_loss):.6f}")
+    return float(s_opt[0]), float(s_opt[1])
 
 
-def pupil_smooth(smooth_params, ys, m0, S0, C, R, ensemble_vars, diameters_var, x_var, y_var,
-                 return_full=False):
+@typechecked
+def pupil_smooth(
+    smooth_params: Sequence[float],  # [s_diam, s_com] in (0,1)
+    ys: np.ndarray | jnp.ndarray,  # (T, 8)
+    m0: np.ndarray | jnp.ndarray,  # (3,)
+    S0: np.ndarray | jnp.ndarray,  # (3,3)
+    C: np.ndarray | jnp.ndarray,  # (8,3)
+    R: np.ndarray | jnp.ndarray,  # (T, 8, 8) time-varying obs covariance
+    diameters_var: Real,
+    x_var: Real,
+    y_var: Real,
+    return_full: bool = False,
+):
     """
-    Smooths once using the given smooth_param. Returns only the nll loss by default
-    (if return_full is False).
-
-    Parameters:
-        smooth_params (float): Smoothing parameter.
-        block (list): List of blocks.
-        cov_mats (np.ndarray): Covariance matrices.
-        ys (np.ndarray): Observations.
-        m0s (np.ndarray): Initial mean state.
-        S0s (np.ndarray): Initial state covariance.
-        Cs (np.ndarray): Measurement function.
-        As (np.ndarray): State-transition matrix.
-        Rs (np.ndarray): Measurement noise covariance.
-
-    Returns:
-        float: Negative log-likelihood.
+    One EKF forward pass (and optional smoother) using Dynamax NLGSSM with:
+      A = diag([s_d, s_c, s_c]) and Q = diag([σ_d^2(1-s_d^2), σ_x^2(1-s_c^2), σ_y^2(1-s_c^2)]).
+    R_t is passed directly via `R` (e.g. diag(ensemble_vars[t]) built by build_R_from_vars).
""" - # Construct As - diameter_s, com_s = smooth_params[0], smooth_params[1] - A = jnp.array([ - [diameter_s, 0, 0], - [0, com_s, 0], - [0, 0, com_s] - ]) - - # Construct cov_matrix Q - Q = jnp.array([ - [diameters_var * (1 - (A[0, 0] ** 2)), 0, 0], - [0, x_var * (1 - A[1, 1] ** 2), 0], - [0, 0, y_var * (1 - (A[2, 2] ** 2))] - ]) - - mf, Vf, nll = forward_pass(ys, m0, S0, A, Q, C, ensemble_vars) - - if return_full: - ms, Vs = backward_pass(mf, Vf, A, Q) - return ms, Vs, nll - - return nll \ No newline at end of file + ys = jnp.asarray(ys) + m0 = jnp.asarray(m0) + S0 = jnp.asarray(S0) + C = jnp.asarray(C) + + s_d = jnp.clip(jnp.asarray(smooth_params[0]), 1e-3, 1 - 1e-3) + s_c = jnp.clip(jnp.asarray(smooth_params[1]), 1e-3, 1 - 1e-3) + + A = jnp.diag(jnp.array([s_d, s_c, s_c])) + Q = jnp.diag(jnp.array([ + diameters_var * (1.0 - s_d**2), + x_var * (1.0 - s_c**2), + y_var * (1.0 - s_c**2), + ])) + + # linear f/h closures + f_fn = (lambda x, A=A: A @ x) + h_fn = (lambda x, C=C: C @ x) + + # build NLGSSM params; pass Q as exact and s=1.0 to avoid extra scaling + params = params_nlgssm_for_keypoint(m0, S0, Q, 1.0, R, f_fn, h_fn) + + filt = extended_kalman_filter(params, ys) + nll = -filt.marginal_loglik + if not return_full: + return nll + + sm = extended_kalman_smoother(params, ys) + if hasattr(sm, "smoothed_means"): + ms = sm.smoothed_means + Vs = sm.smoothed_covariances + else: + ms = sm.filtered_means + Vs = sm.filtered_covariances + return ms, Vs, -filt.marginal_loglik diff --git a/eks/multicam_smoother.py b/eks/multicam_smoother.py index 9e3eaeb..b90756f 100644 --- a/eks/multicam_smoother.py +++ b/eks/multicam_smoother.py @@ -6,7 +6,7 @@ from sklearn.decomposition import PCA from typeguard import typechecked -from eks.core import ensemble, optimize_smooth_param +from eks.core import ensemble, run_kalman_smoother from eks.marker_array import ( MarkerArray, input_dfs_to_markerArray, @@ -129,7 +129,7 @@ def fit_eks_multicam( var_mode: str = 'confidence_weighted_var', inflate_vars: bool = False, verbose: bool = False, - n_latent: int = 3 + n_latent: int = 3, ) -> tuple: """ Fit the Ensemble Kalman Smoother for un-mirrored multi-camera data. @@ -177,7 +177,7 @@ def fit_eks_multicam( var_mode=var_mode, verbose=verbose, inflate_vars=inflate_vars, - n_latent=n_latent + n_latent=n_latent, ) # Save output DataFrames to CSVs (one per camera view) os.makedirs(save_dir, exist_ok=True) @@ -201,7 +201,7 @@ def ensemble_kalman_smoother_multicam( inflate_vars_kwargs: dict = {}, verbose: bool = False, pca_object: PCA | None = None, - n_latent: int = 3 + n_latent: int = 3, ) -> tuple: """ Use multi-view constraints to fit a 3D latent subspace for each body part. 
@@ -269,7 +269,7 @@ def ensemble_kalman_smoother_multicam(
 
     # Kalman Filter Section ------------------------------------------
     # Initialize Kalman filter parameters
-    m0s, S0s, As, cov_mats, Cs = initialize_kalman_filter_pca(
+    m0s, S0s, As, Qs, Cs = initialize_kalman_filter_pca(
         good_pcs_list=good_pcs_list,
         ensemble_pca=ensemble_pca,
         n_latent=n_latent,
@@ -286,17 +286,17 @@ def ensemble_kalman_smoother_multicam(
     ])
 
     # Optimize smoothing
-    s_finals, ms, Vs = optimize_smooth_param(
-        cov_mats=cov_mats,
-        ys=ys,
+    s_finals, ms, Vs = run_kalman_smoother(
+        ys=jnp.asarray(ys),
         m0s=m0s,
         S0s=S0s,
-        Cs=Cs,
         As=As,
+        Cs=Cs,
+        Qs=Qs,
         ensemble_vars=np.swapaxes(ensemble_vars, 0, 1),
         s_frames=s_frames,
         smooth_param=smooth_param,
-        verbose=verbose
+        verbose=verbose,
     )
     # Reproject from latent space back to observed space
     camera_arrs = [[] for _ in camera_names]
@@ -373,7 +373,6 @@ def initialize_kalman_filter_pca(
     ])
     As = np.tile(np.eye(n_latent), (n_keypoints, 1, 1))
     Cs = np.stack([pca.components_.T for pca in ensemble_pca])
-    Rs = np.tile(np.eye(n_latent), (n_keypoints, 1, 1))
 
     cov_mats = []
     for k in range(n_keypoints):
@@ -395,7 +394,7 @@ def initialize_kalman_filter_pca(
 
 
 def mA_compute_maha(centered_emA_preds, emA_vars, emA_likes, n_latent,
-                    inflate_vars_kwargs={}, threshold=5, scalar=2):
+                    inflate_vars_kwargs={}, threshold=5, scalar=10):
     """
     Reshape marker arrays for Mahalanobis computation, compute Mahalanobis distances,
     and optionally inflate variances for all keypoints.
diff --git a/eks/singlecam_smoother.py b/eks/singlecam_smoother.py
index 3c9f849..5e6c28e 100644
--- a/eks/singlecam_smoother.py
+++ b/eks/singlecam_smoother.py
@@ -5,7 +5,7 @@ import pandas as pd
 
 from typeguard import typechecked
 
-from eks.core import ensemble, optimize_smooth_param
+from eks.core import ensemble, run_kalman_smoother
 from eks.marker_array import MarkerArray, input_dfs_to_markerArray
 from eks.utils import center_predictions, format_data, make_dlc_pandas_index
 
@@ -94,7 +94,7 @@ def ensemble_kalman_smoother_singlecam(
     keypoint_names: List of body parts to run smoothing on
     smooth_param: value in (0, Inf); smaller values lead to more smoothing
     s_frames: List of frames for automatic computation of smoothing parameter
-    blocks: keypoints to be blocked for correlated noise. Generates on smoothing param per
+    blocks: keypoints to be blocked for correlated noise. Generates one smoothing param per
         block, as opposed to per keypoint.
Specified by the form "x1, x2, x3; y1, y2" referring to keypoint indices (start at 0)
     avg_mode: mode for averaging across ensemble
@@ -134,12 +134,21 @@ def ensemble_kalman_smoother_singlecam(
 
-    # Prepare params for singlecam_optimize_smooth()
+    # Prepare params for run_kalman_smoother()
     ys = emA_centered_preds.get_array(squeeze=True).transpose(1, 0, 2)
-    m0s, S0s, As, cov_mats, Cs = initialize_kalman_filter(emA_centered_preds)
+    m0s, S0s, As, Qs, Cs = initialize_kalman_filter(emA_centered_preds)
 
     # Main smoothing function
-    s_finals, ms, Vs = optimize_smooth_param(
-        cov_mats, ys, m0s, S0s, Cs, As, emA_vars.get_array(squeeze=True),
-        s_frames, smooth_param, blocks, verbose=verbose,
+    s_finals, ms, Vs = run_kalman_smoother(
+        ys=jnp.asarray(ys),
+        m0s=m0s,
+        S0s=S0s,
+        As=As,
+        Cs=Cs,
+        Qs=Qs,
+        ensemble_vars=emA_vars.get_array(squeeze=True),
+        s_frames=s_frames,
+        smooth_param=smooth_param,
+        blocks=blocks,
+        verbose=verbose,
     )
 
     y_m_smooths = np.zeros((n_keypoints, n_frames, 2))
diff --git a/eks/utils.py b/eks/utils.py
index 3783fee..343e547 100644
--- a/eks/utils.py
+++ b/eks/utils.py
@@ -251,6 +251,7 @@ def crop_frames(y: np.ndarray | jnp.ndarray, s_frames: list | tuple) -> np.ndarr
     return np.concatenate(result)
 
 
+@typechecked
 def center_predictions(
     ensemble_marker_array: MarkerArray,
     quantile_keep_pca: float
@@ -324,3 +325,38 @@ def center_predictions(
     emA_means = MarkerArray.stack(emA_means_list, "keypoints")
 
     return valid_frames_mask, emA_centered_preds, emA_good_centered_preds, emA_means
+
+
+@typechecked
+def build_R_from_vars(ev: np.ndarray) -> np.ndarray:
+    """
+    Build time-varying diagonal observation covariances from per-dimension variances.
+    ev shape: (..., T, O) -> returns (..., T, O, O) with diag(ev[t]).
+    """
+    ev_np = np.clip(np.asarray(ev), 1e-12, None)
+    O_dim = ev_np.shape[-1]
+    # Broadcast-diagonal without Python loops:
+    # (..., T, O, 1) * (O, O) -> (..., T, O, O), scaling rows of the identity.
+    return ev_np[..., :, None] * np.eye(O_dim, dtype=ev_np.dtype)
+
+
+@typechecked
+def crop_R(R: np.ndarray, s_frames: list | None) -> np.ndarray:
+    """
+    Crop time-varying R along its time axis using the same spec as crop_frames.
+    R shape: (..., T, O, O) -> returns (..., T', O, O).
+    Assumes R is diagonal (built via build_R_from_vars) but works generically.
+    """
+    if not s_frames:
+        return np.asarray(R)
+    R_np = np.asarray(R)
+    leading = R_np.shape[:-3]  # any leading batch dims
+    T, O, O2 = R_np.shape[-3:]
+    assert O == O2, "R must be square in its last two dims"
+    # Flatten leading dims so the time axis can be cropped contiguously
+    R_flat = R_np.reshape((-1, T, O, O))
+    cropped_list = []
+    for block in R_flat:
+        cropped_list.append(crop_frames(block, s_frames))  # uses the same semantics
+    R_cropped = np.stack(cropped_list, axis=0)
+    return R_cropped.reshape((*leading, -1, O, O))
diff --git a/scripts/multicam_example.py b/scripts/multicam_example.py
index f25dbc0..1656a8f 100644
--- a/scripts/multicam_example.py
+++ b/scripts/multicam_example.py
@@ -39,7 +39,7 @@
     quantile_keep_pca=quantile_keep_pca,
     verbose=verbose,
     inflate_vars=inflate_vars,
-    n_latent=args.n_latent
+    n_latent=args.n_latent,
 )
 
 # Plot results for a specific keypoint (default to last keypoint of last camera view)
diff --git a/scripts/singlecam_example.py b/scripts/singlecam_example.py
index 54389cf..c6f76d4 100644
--- a/scripts/singlecam_example.py
+++ b/scripts/singlecam_example.py
@@ -33,7 +33,7 @@
     smooth_param=s,
     s_frames=s_frames,
     blocks=blocks,
-    verbose=verbose
+    verbose=verbose,
 )
 
 # Plot results for a specific keypoint (default to last keypoint)
diff --git a/setup.py b/setup.py
index 3e2b1ca..8f74105 100644
--- a/setup.py
+++ b/setup.py
@@ -41,6 +41,7 @@ def get_version(rel_path):
     'sleap_io',
     'jax',
     'jaxlib',
+    'dynamax',
 ]
 
 # additional requirements
@@ -67,7 +68,7 @@ def get_version(rel_path):
     long_description_content_type='text/markdown',
     author='Cole Hurwitz',
     author_email='',
-    url='http://www.github.com/colehurwitz/eks',
+    url='http://www.github.com/paninski-lab/eks',
     packages=['eks'],
     install_requires=install_requires,
     extras_require=extras_require,
diff --git a/tests/test_multicam_smoother.py b/tests/test_multicam_smoother.py
index 5a0ad1f..c36f0c2 100644
--- a/tests/test_multicam_smoother.py
+++ b/tests/test_multicam_smoother.py
@@ -45,9 +45,10 @@ def test_ensemble_kalman_smoother_multicam():
         f"Expected {len(camera_names)} entries in camera_dfs, got {len(camera_dfs)}"
     assert isinstance(smooth_params_final, np.ndarray), \
         f"Expected smooth_param_final to be an array, got {type(smooth_params_final)}"
-    assert smooth_params_final == smooth_param, \
-        f"Expected smooth_param_final to match input smooth_param ({smooth_param}), " \
-        f"got {smooth_params_final}"
+    for k in range(len(keypoint_names)):
+        assert smooth_params_final[k] == smooth_param, \
+            f"Expected smooth_param_final to match input smooth_param ({smooth_param}), " \
+            f"got {smooth_params_final}"
 
     # ---------------------------------------------------
     # Run with variance inflation
@@ -69,9 +70,10 @@ def test_ensemble_kalman_smoother_multicam():
         f"Expected {len(camera_names)} entries in camera_dfs, got {len(camera_dfs)}"
     assert isinstance(smooth_params_final, np.ndarray), \
         f"Expected smooth_param_final to be an array, got {type(smooth_params_final)}"
-    assert smooth_params_final == smooth_param, \
-        f"Expected smooth_param_final to match input smooth_param ({smooth_param}), " \
-        f"got {smooth_params_final}"
+    for k in range(len(keypoint_names)):
+        assert smooth_params_final[k] == smooth_param, \
+            f"Expected smooth_param_final to match input smooth_param ({smooth_param}), " \
+            f"got {smooth_params_final}"
 
     # ---------------------------------------------------
     # Run with variance inflation + more maha kwargs
@@ -95,9 +97,10 @@ def test_ensemble_kalman_smoother_multicam():
        f"Expected {len(camera_names)} entries in
camera_dfs, got {len(camera_dfs)}"
     assert isinstance(smooth_params_final, np.ndarray), \
         f"Expected smooth_param_final to be an array, got {type(smooth_params_final)}"
-    assert smooth_params_final == smooth_param, \
-        f"Expected smooth_param_final to match input smooth_param ({smooth_param}), " \
-        f"got {smooth_params_final}"
+    for k in range(len(keypoint_names)):
+        assert smooth_params_final[k] == smooth_param, \
+            f"Expected smooth_param_final to match input smooth_param ({smooth_param}), " \
+            f"got {smooth_params_final}"
 
     # ---------------------------------------------------
     # Run with variance inflation + more maha kwargs
diff --git a/tests/test_singlecam_smoother.py b/tests/test_singlecam_smoother.py
index d8751fd..678fb40 100644
--- a/tests/test_singlecam_smoother.py
+++ b/tests/test_singlecam_smoother.py
@@ -41,7 +41,8 @@ def _check_outputs(df, params):
         blocks=blocks,
     )
     _check_outputs(df_smoothed, s_finals)
-    assert s_finals == [smooth_param]
+    for k in range(len(keypoint_names)):
+        assert s_finals[k] == smooth_param
 
     # run with fixed smooth param (int)
     smooth_param = 5
@@ -53,7 +54,8 @@ def _check_outputs(df, params):
         blocks=blocks,
     )
     _check_outputs(df_smoothed, s_finals)
-    assert s_finals == [smooth_param]
+    for k in range(len(keypoint_names)):
+        assert s_finals[k] == smooth_param
 
     # run with fixed smooth param (single-entry list)
     smooth_param = [0.1]
@@ -65,7 +67,8 @@ def _check_outputs(df, params):
         blocks=blocks,
     )
     _check_outputs(df_smoothed, s_finals)
-    assert s_finals == smooth_param
+    for k in range(len(keypoint_names)):
+        assert s_finals[k] == smooth_param[0]
 
     # run with fixed smooth param (list)
     smooth_param = [0.1, 0.4]
@@ -77,7 +80,8 @@ def _check_outputs(df, params):
         blocks=blocks,
     )
     _check_outputs(df_smoothed, s_finals)
-    assert np.all(s_finals == smooth_param)
+    for k in range(len(keypoint_names)):
+        assert s_finals[k] == smooth_param[k]
 
     # run with None smooth param
     smooth_param = None
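
A minimal end-to-end sketch of the new run_kalman_smoother entry point, assuming
random-walk dynamics (A = C = I), synthetic observations, and constant per-frame
ensemble variances; shapes follow the docstring (K keypoints, T frames, D = obs = 2):

    import numpy as np
    from jax import numpy as jnp
    from eks.core import run_kalman_smoother

    K, T, D = 3, 200, 2
    rng = np.random.default_rng(0)
    ys = jnp.asarray(rng.normal(size=(K, T, D)).cumsum(axis=1))  # (K, T, obs) random walks
    m0s = jnp.zeros((K, D))                   # initial state means
    S0s = jnp.tile(jnp.eye(D), (K, 1, 1))     # initial state covariances
    As = jnp.tile(jnp.eye(D), (K, 1, 1))      # random-walk transition matrices
    Cs = jnp.tile(jnp.eye(D), (K, 1, 1))      # direct observation of the state
    Qs = jnp.tile(jnp.eye(D), (K, 1, 1))      # base process covariance, scaled by s
    ensemble_vars = np.full((T, K, D), 0.1)   # per-frame observation variances -> R_t

    s_finals, ms, Vs = run_kalman_smoother(
        ys=ys, m0s=m0s, S0s=S0s, As=As, Cs=Cs, Qs=Qs,
        ensemble_vars=ensemble_vars,
        smooth_param=None,  # None -> optimize s per keypoint via the EKF filter NLL
    )
    # s_finals: (K,); ms: (K, T, D); Vs: (K, T, D, D)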