From 413d12867160e6bc2d6f1aba778bce544274ccf9 Mon Sep 17 00:00:00 2001 From: kleyt0n Date: Tue, 24 Mar 2026 23:57:19 +0000 Subject: [PATCH] feat: implement extended kalman filter --- src/dynaris/__init__.py | 6 +- src/dynaris/core/__init__.py | 2 + src/dynaris/core/nonlinear.py | 123 +++++++++++ src/dynaris/filters/__init__.py | 3 + src/dynaris/filters/ekf.py | 210 ++++++++++++++++++ tests/test_filters/test_ekf.py | 378 ++++++++++++++++++++++++++++++++ 6 files changed, 721 insertions(+), 1 deletion(-) create mode 100644 src/dynaris/core/nonlinear.py create mode 100644 src/dynaris/filters/ekf.py create mode 100644 tests/test_filters/test_ekf.py diff --git a/src/dynaris/__init__.py b/src/dynaris/__init__.py index 993d803..cc67eba 100644 --- a/src/dynaris/__init__.py +++ b/src/dynaris/__init__.py @@ -4,6 +4,7 @@ FilterProtocol, FilterResult, GaussianState, + NonlinearSSM, SmootherProtocol, SmootherResult, StateSpaceModel, @@ -17,7 +18,7 @@ Regression, Seasonal, ) -from dynaris.filters import KalmanFilter, kalman_filter +from dynaris.filters import ExtendedKalmanFilter, KalmanFilter, ekf_filter, kalman_filter from dynaris.smoothers import RTSSmoother, rts_smooth __version__ = "0.1.0" @@ -26,12 +27,14 @@ "DLM", "Autoregressive", "Cycle", + "ExtendedKalmanFilter", "FilterProtocol", "FilterResult", "GaussianState", "KalmanFilter", "LocalLevel", "LocalLinearTrend", + "NonlinearSSM", "RTSSmoother", "Regression", "Seasonal", @@ -39,6 +42,7 @@ "SmootherResult", "StateSpaceModel", "__version__", + "ekf_filter", "kalman_filter", "rts_smooth", ] diff --git a/src/dynaris/core/__init__.py b/src/dynaris/core/__init__.py index 45dbf76..ee21584 100644 --- a/src/dynaris/core/__init__.py +++ b/src/dynaris/core/__init__.py @@ -1,5 +1,6 @@ """Core math engine: state-space representation, filter protocols, result types.""" +from dynaris.core.nonlinear import NonlinearSSM from dynaris.core.protocols import FilterProtocol, SmootherProtocol from dynaris.core.results import FilterResult, 
SmootherResult
 from dynaris.core.state_space import StateSpaceModel
@@ -9,6 +10,7 @@
     "FilterProtocol",
     "FilterResult",
     "GaussianState",
+    "NonlinearSSM",
     "SmootherProtocol",
     "SmootherResult",
     "StateSpaceModel",
diff --git a/src/dynaris/core/nonlinear.py b/src/dynaris/core/nonlinear.py
new file mode 100644
index 0000000..6e1e9a1
--- /dev/null
+++ b/src/dynaris/core/nonlinear.py
@@ -0,0 +1,123 @@
+"""Nonlinear state-space model representation for EKF/UKF."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Callable
+
+import jax
+import jax.numpy as jnp
+from jax import Array
+
+from dynaris.core.types import GaussianState
+
+# Type aliases for transition and observation functions.
+# transition_fn: (state_vec,) -> predicted_state_vec
+# observation_fn: (state_vec,) -> predicted_observation_vec
+TransitionFn = Callable[[Array], Array]
+ObservationFn = Callable[[Array], Array]
+
+
+@dataclass(frozen=True)
+class NonlinearSSM:
+    """Nonlinear state-space model for use with the Extended Kalman Filter.
+
+    State equation: theta_t = f(theta_{t-1}) + omega_t, omega_t ~ N(0, Q)
+    Observation eq: Y_t = h(theta_t) + nu_t, nu_t ~ N(0, R)
+
+    The Jacobians of f and h are computed automatically via ``jax.jacfwd``,
+    so no manual derivation is required.
+
+    Attributes:
+        transition_fn: f, maps state (n,) -> state (n,).
+        observation_fn: h, maps state (n,) -> observation (m,).
+        transition_cov: Q, evolution noise covariance, shape (n, n).
+        observation_cov: R, observation noise covariance, shape (m, m).
+        state_dim: Dimension of the state vector.
+        obs_dim: Dimension of the observation vector.
+    """
+
+    transition_fn: TransitionFn
+    observation_fn: ObservationFn
+    transition_cov: Array  # Q: (n, n)
+    observation_cov: Array  # R: (m, m)
+    state_dim: int
+    obs_dim: int
+
+    # --- Short aliases ---
+
+    @property
+    def Q(self) -> Array:  # noqa: N802
+        """Evolution / transition noise covariance."""
+        return self.transition_cov
+
+    @property
+    def R(self) -> Array:  # noqa: N802
+        """Observation noise covariance."""
+        return self.observation_cov
+
+    @property
+    def f(self) -> TransitionFn:
+        """Transition function alias."""
+        return self.transition_fn
+
+    @property
+    def h(self) -> ObservationFn:
+        """Observation function alias."""
+        return self.observation_fn
+
+    # --- Factory methods ---
+
+    def initial_state(
+        self,
+        mean: Array | None = None,
+        cov: Array | None = None,
+    ) -> GaussianState:
+        """Create a default initial GaussianState for this model.
+
+        Args:
+            mean: Initial state mean. Defaults to zeros.
+            cov: Initial state covariance. Defaults to 1e6 * I (diffuse prior).
+
+        Returns:
+            GaussianState with the specified or default initial conditions.
+        """
+        n = self.state_dim
+        if mean is None:
+            mean = jnp.zeros(n)
+        if cov is None:
+            cov = jnp.eye(n) * 1e6
+        return GaussianState(mean=mean, cov=cov)
+
+    def __repr__(self) -> str:
+        return f"NonlinearSSM(state_dim={self.state_dim}, obs_dim={self.obs_dim})"
+
+    # --- JAX pytree registration ---
+
+    def tree_flatten(self) -> tuple[list[Array], tuple[object, ...]]:
+        """Flatten into JAX pytree leaves and auxiliary data.
+
+        The auxiliary data is a tuple rather than a dict because JAX stores
+        it in the treedef and documents that it must be hashable.
+        """
+        leaves = [self.transition_cov, self.observation_cov]
+        aux = (self.transition_fn, self.observation_fn, self.state_dim, self.obs_dim)
+        return leaves, aux
+
+    @classmethod
+    def tree_unflatten(
+        cls, aux_data: tuple[object, ...], children: list[Array]
+    ) -> NonlinearSSM:
+        """Rebuild a model from the aux tuple and leaves made by ``tree_flatten``."""
+        transition_fn, observation_fn, state_dim, obs_dim = aux_data
+        return cls(
+            transition_fn=transition_fn,  # type: ignore[arg-type]
+            observation_fn=observation_fn,  # type: ignore[arg-type]
+            transition_cov=children[0],
+            observation_cov=children[1],
+            state_dim=state_dim,  # type: ignore[arg-type]
+            obs_dim=obs_dim,  # type: ignore[arg-type]
+        )
+
+
+jax.tree_util.register_pytree_node_class(NonlinearSSM)
diff --git a/src/dynaris/filters/__init__.py b/src/dynaris/filters/__init__.py
index 69adcd5..e0be3e9 100644
--- a/src/dynaris/filters/__init__.py
+++ b/src/dynaris/filters/__init__.py
@@ -1,8 +1,11 @@
 """Filtering algorithms: Kalman filter and variants."""
 
+from dynaris.filters.ekf import ExtendedKalmanFilter, ekf_filter
 from dynaris.filters.kalman import KalmanFilter, kalman_filter
 
 __all__ = [
+    "ExtendedKalmanFilter",
     "KalmanFilter",
+    "ekf_filter",
     "kalman_filter",
 ]
diff --git a/src/dynaris/filters/ekf.py b/src/dynaris/filters/ekf.py
new file mode 100644
index 0000000..9a6f0f5
--- /dev/null
+++ b/src/dynaris/filters/ekf.py
@@ -0,0 +1,210 @@
+"""Extended Kalman Filter for nonlinear state-space models.
+
+Linearizes the transition and observation functions at each time step using
+automatic Jacobians via ``jax.jacfwd``, then applies the standard Kalman
+predict/update equations to the linearized system.
+"""
+
+from __future__ import annotations
+
+from typing import NamedTuple
+
+import jax
+import jax.numpy as jnp
+from jax import Array
+
+from dynaris.core.nonlinear import NonlinearSSM
+from dynaris.core.results import FilterResult
+from dynaris.core.types import GaussianState
+
+# ---------------------------------------------------------------------------
+# Internal scan carry
+# ---------------------------------------------------------------------------
+
+
+class _ScanCarry(NamedTuple):
+    filtered: GaussianState  # posterior handed to the next time step
+    log_likelihood: Array  # scalar, running sum over steps
+
+
+class _ScanOutput(NamedTuple):
+    predicted_mean: Array
+    predicted_cov: Array
+    filtered_mean: Array
+    filtered_cov: Array
+
+
+# ---------------------------------------------------------------------------
+# Pure-function predict and update steps
+# ---------------------------------------------------------------------------
+
+
+def predict(state: GaussianState, model: NonlinearSSM) -> GaussianState:
+    """EKF predict step (time update).
+
+    Propagates the state through the nonlinear transition function and
+    linearizes using the Jacobian F = df/dx evaluated at the current mean.
+
+        a_t = f(m_{t-1})
+        F_t = Jacobian of f at m_{t-1}
+        R_t = F_t @ C_{t-1} @ F_t' + Q
+    """
+    mean = model.f(state.mean)
+    F_jac = jax.jacfwd(model.f)(state.mean)  # (n, n)
+    cov = F_jac @ state.cov @ F_jac.T + model.Q
+    return GaussianState(mean=mean, cov=cov)
+
+
+def update(
+    predicted: GaussianState,
+    observation: Array,
+    model: NonlinearSSM,
+) -> tuple[GaussianState, Array]:
+    """EKF update step (measurement update).
+
+    Linearizes ``model.h`` at the predicted mean and applies the standard
+    Kalman update. Returns ``(filtered_state, log_likelihood_term)``.
+
+    An observation containing any NaN is treated as missing: the predicted
+    state is returned unchanged and the log-likelihood term is 0.
+    """
+    y_pred = model.h(predicted.mean)  # (m,)
+    H_jac = jax.jacfwd(model.h)(predicted.mean)  # (m, n)
+
+    # Replace a missing observation by y_pred so no NaN enters the algebra:
+    # jnp.where repairs forward values only, and a NaN in its unselected
+    # branch still turns reverse-mode gradients into NaN.
+    obs_valid = ~jnp.any(jnp.isnan(observation))
+    y = jnp.where(obs_valid, observation, y_pred)
+    e = y - y_pred  # innovation (m,); all zeros when the step is missing
+    S = H_jac @ predicted.cov @ H_jac.T + model.R  # innovation covariance (m, m)
+
+    # Kalman gain: K = P @ H' @ S^{-1}
+    K = jnp.linalg.solve(S.T, (predicted.cov @ H_jac.T).T).T  # (n, m)
+    filtered_mean = predicted.mean + K @ e
+    identity = jnp.eye(predicted.mean.shape[-1])
+    filtered_cov = (identity - K @ H_jac) @ predicted.cov
+
+    # Log-likelihood: log N(e; 0, S)
+    m = observation.shape[-1]
+    log_det = jnp.linalg.slogdet(S)[1]
+    mahal = e @ jnp.linalg.solve(S, e)
+    ll = -0.5 * (m * jnp.log(2.0 * jnp.pi) + log_det + mahal)
+
+    # Missing step: keep the prediction and add nothing to the likelihood
+    filtered_mean = jnp.where(obs_valid, filtered_mean, predicted.mean)
+    filtered_cov = jnp.where(obs_valid, filtered_cov, predicted.cov)
+    ll = jnp.where(obs_valid, ll, 0.0)
+    return GaussianState(mean=filtered_mean, cov=filtered_cov), ll
+
+
+# ---------------------------------------------------------------------------
+# Full forward pass via lax.scan
+# ---------------------------------------------------------------------------
+
+
+class ExtendedKalmanFilter:
+    """Extended Kalman Filter implementing the same interface as KalmanFilter.
+
+    Linearizes nonlinear transition and observation functions at each step
+    using ``jax.jacfwd`` for automatic Jacobian computation.
+    """
+
+    def predict(self, state: GaussianState, model: NonlinearSSM) -> GaussianState:
+        """Time update; forwards to the module-level ``predict``."""
+        return predict(state, model)
+
+    def update(
+        self,
+        predicted: GaussianState,
+        observation: Array,
+        model: NonlinearSSM,
+    ) -> GaussianState:
+        """Measurement update returning only the filtered state.
+
+        The log-likelihood term computed by the module-level ``update`` is
+        discarded.
+        """
+        posterior, _ = update(predicted, observation, model)
+        return posterior
+
+    def scan(
+        self,
+        model: NonlinearSSM,
+        observations: Array,
+        initial_state: GaussianState | None = None,
+    ) -> FilterResult:
+        """Full forward pass; forwards to ``ekf_filter``."""
+        return ekf_filter(model, observations, initial_state)
+
+
+@jax.jit
+def ekf_filter(
+    model: NonlinearSSM,
+    observations: Array,
+    initial_state: GaussianState | None = None,
+) -> FilterResult:
+    """Run the Extended Kalman Filter over a whole sequence under ``jax.jit``.
+
+    Each step calls ``predict`` and then ``update``, which linearize the
+    transition and observation functions with ``jax.jacfwd``. The recursion
+    is driven by ``jax.lax.scan`` over the leading axis of ``observations``.
+
+    Args:
+        model: Nonlinear state-space model with callable f and h.
+        observations: Observation sequence, shape (T, obs_dim).
+        initial_state: Belief before the first observation. When omitted,
+            ``model.initial_state()`` is used.
+
+    Returns:
+        FilterResult holding the per-step filtered and predicted means and
+        covariances, the summed log-likelihood and the observations.
+
+    Example::
+
+        import jax.numpy as jnp
+        from dynaris.core.nonlinear import NonlinearSSM
+        from dynaris.filters.ekf import ekf_filter
+
+        def f(x):
+            return x  # random walk
+
+        def h(x):
+            return x  # direct observation
+
+        model = NonlinearSSM(
+            transition_fn=f, observation_fn=h,
+            transition_cov=jnp.eye(1), observation_cov=jnp.eye(1),
+            state_dim=1, obs_dim=1,
+        )
+        observations = jnp.zeros((10, 1))
+        result = ekf_filter(model, observations)
+    """
+    start = model.initial_state() if initial_state is None else initial_state
+
+    def _step(carry: _ScanCarry, obs: Array) -> tuple[_ScanCarry, _ScanOutput]:
+        # One time step: predict from the previous posterior, then correct.
+        prior = predict(carry.filtered, model)
+        posterior, step_ll = update(prior, obs, model)
+        running_ll = carry.log_likelihood + step_ll
+        emitted = _ScanOutput(
+            predicted_mean=prior.mean,
+            predicted_cov=prior.cov,
+            filtered_mean=posterior.mean,
+            filtered_cov=posterior.cov,
+        )
+        return _ScanCarry(filtered=posterior, log_likelihood=running_ll), emitted
+
+    # The carry threads the latest posterior and the running log-likelihood;
+    # scan stacks the per-step outputs along a new leading axis.
+    seed = _ScanCarry(filtered=start, log_likelihood=jnp.array(0.0))
+    last, trace = jax.lax.scan(_step, seed, observations)
+
+    return FilterResult(
+        filtered_states=trace.filtered_mean,
+        filtered_covariances=trace.filtered_cov,
+        predicted_states=trace.predicted_mean,
+        predicted_covariances=trace.predicted_cov,
+        log_likelihood=last.log_likelihood,
+        observations=observations,
+    )
diff --git a/tests/test_filters/test_ekf.py b/tests/test_filters/test_ekf.py
new file mode 100644
index 0000000..b1991a2
--- /dev/null
+++ b/tests/test_filters/test_ekf.py
@@ -0,0 +1,378 @@
+"""Tests for the Extended Kalman Filter."""
+
+from __future__ import annotations
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from jax import Array
+
+from dynaris.core.nonlinear import NonlinearSSM
+from dynaris.core.results import FilterResult
+from dynaris.core.state_space import StateSpaceModel
+from dynaris.core.types import GaussianState
+from dynaris.datasets import load_nile_jax
+from
dynaris.filters.ekf import ExtendedKalmanFilter, ekf_filter, predict, update
+from dynaris.filters.kalman import kalman_filter
+
+NILE = load_nile_jax()
+
+
+# ---------------------------------------------------------------------------
+# Helper: linear model expressed as NonlinearSSM (for comparison with Kalman)
+# ---------------------------------------------------------------------------
+
+
+def _linear_nonlinear_model(
+    sigma_level: float = 1.0, sigma_obs: float = 1.0
+) -> NonlinearSSM:
+    """Local-level model as a NonlinearSSM (identity transition/observation)."""
+    Q = jnp.array([[sigma_level**2]])
+    R = jnp.array([[sigma_obs**2]])
+
+    def f(x: Array) -> Array:
+        return x
+
+    def h(x: Array) -> Array:
+        return x
+
+    return NonlinearSSM(
+        transition_fn=f,
+        observation_fn=h,
+        transition_cov=Q,
+        observation_cov=R,
+        state_dim=1,
+        obs_dim=1,
+    )
+
+
+def _linear_ssm(sigma_level: float = 1.0, sigma_obs: float = 1.0) -> StateSpaceModel:
+    """Equivalent linear model for Kalman filter comparison."""
+    return StateSpaceModel(
+        system_matrix=jnp.array([[1.0]]),
+        observation_matrix=jnp.array([[1.0]]),
+        evolution_cov=jnp.array([[sigma_level**2]]),
+        obs_cov=jnp.array([[sigma_obs**2]]),
+    )
+
+
+# ---------------------------------------------------------------------------
+# Predict step tests
+# ---------------------------------------------------------------------------
+
+
+def test_predict_identity_transition() -> None:
+    model = _linear_nonlinear_model()
+    state = GaussianState(mean=jnp.array([5.0]), cov=jnp.array([[2.0]]))
+    pred = predict(state, model)
+    # Identity transition: mean unchanged, cov = P + Q
+    np.testing.assert_allclose(pred.mean, [5.0], atol=1e-6)
+    np.testing.assert_allclose(pred.cov, [[3.0]], atol=1e-6)
+
+
+def test_predict_nonlinear_transition() -> None:
+    """Test with a nonlinear transition: f(x) = x + 0.1 * sin(x)."""
+    Q = jnp.array([[0.5]])
+
+    def f(x: Array) -> Array:
+        return x + 0.1 * jnp.sin(x)
+
+    model = NonlinearSSM(
+        transition_fn=f,
+        observation_fn=lambda x: x,
+        transition_cov=Q,
+        observation_cov=jnp.array([[1.0]]),
+        state_dim=1,
+        obs_dim=1,
+    )
+    state = GaussianState(mean=jnp.array([1.0]), cov=jnp.array([[0.5]]))
+    pred = predict(state, model)
+    # Mean should be f(1.0) = 1.0 + 0.1 * sin(1.0)
+    expected_mean = 1.0 + 0.1 * float(jnp.sin(1.0))
+    np.testing.assert_allclose(pred.mean, [expected_mean], atol=1e-5)
+    assert jnp.all(jnp.isfinite(pred.cov))
+
+
+# ---------------------------------------------------------------------------
+# Update step tests
+# ---------------------------------------------------------------------------
+
+
+def test_update_reduces_uncertainty() -> None:
+    model = _linear_nonlinear_model(sigma_level=1.0, sigma_obs=1.0)
+    predicted = GaussianState(mean=jnp.array([0.0]), cov=jnp.array([[10.0]]))
+    obs = jnp.array([5.0])
+    filtered, ll = update(predicted, obs, model)
+    assert float(filtered.cov[0, 0]) < 10.0
+    assert float(filtered.mean[0]) > 0.0
+    assert jnp.isfinite(ll)
+
+
+def test_update_nan_skips() -> None:
+    model = _linear_nonlinear_model()
+    predicted = GaussianState(mean=jnp.array([3.0]), cov=jnp.array([[2.0]]))
+    obs = jnp.array([jnp.nan])
+    filtered, ll = update(predicted, obs, model)
+    np.testing.assert_allclose(filtered.mean, predicted.mean)
+    np.testing.assert_allclose(filtered.cov, predicted.cov)
+    assert float(ll) == 0.0
+
+
+# ---------------------------------------------------------------------------
+# EKF matches Kalman on linear models
+# ---------------------------------------------------------------------------
+
+
+def test_ekf_matches_kalman_on_linear_model() -> None:
+    """When the model is linear, EKF should produce identical results to Kalman."""
+    sigma_level, sigma_obs = 40.0, 120.0
+    nl_model = _linear_nonlinear_model(sigma_level, sigma_obs)
+    lin_model = _linear_ssm(sigma_level, sigma_obs)
+
+    observations = NILE.reshape(-1, 1)
+
+    # Use same initial state for both
+    init = GaussianState(
+        mean=jnp.zeros(1),
+        cov=jnp.eye(1) * 1e6,
+    )
+
+    ekf_result = ekf_filter(nl_model, observations, initial_state=init)
+    kf_result = kalman_filter(lin_model, observations, initial_state=init)
+
+    np.testing.assert_allclose(
+        ekf_result.filtered_states, kf_result.filtered_states, atol=1e-4
+    )
+    np.testing.assert_allclose(
+        ekf_result.filtered_covariances, kf_result.filtered_covariances, atol=1e-3
+    )
+    np.testing.assert_allclose(
+        ekf_result.log_likelihood, kf_result.log_likelihood, atol=1e-2
+    )
+
+
+# ---------------------------------------------------------------------------
+# Full filter scan tests
+# ---------------------------------------------------------------------------
+
+
+def test_ekf_filter_shapes() -> None:
+    model = _linear_nonlinear_model(sigma_level=40.0, sigma_obs=120.0)
+    observations = NILE.reshape(-1, 1)
+    result = ekf_filter(model, observations)
+
+    assert isinstance(result, FilterResult)
+    assert result.filtered_states.shape == (100, 1)
+    assert result.filtered_covariances.shape == (100, 1, 1)
+    assert result.predicted_states.shape == (100, 1)
+    assert result.predicted_covariances.shape == (100, 1, 1)
+    assert result.log_likelihood.shape == ()
+    assert result.observations.shape == (100, 1)
+
+
+def test_ekf_filter_finite() -> None:
+    model = _linear_nonlinear_model(sigma_level=40.0, sigma_obs=120.0)
+    observations = NILE.reshape(-1, 1)
+    result = ekf_filter(model, observations)
+
+    assert jnp.all(jnp.isfinite(result.filtered_states))
+    assert jnp.all(jnp.isfinite(result.filtered_covariances))
+    assert jnp.isfinite(result.log_likelihood)
+
+
+def test_ekf_filter_negative_log_likelihood() -> None:
+    model = _linear_nonlinear_model(sigma_level=40.0, sigma_obs=120.0)
+    observations = NILE.reshape(-1, 1)
+    result = ekf_filter(model, observations)
+    assert float(result.log_likelihood) < 0.0
+
+
+def test_ekf_filter_with_missing_obs() -> None:
+    model = _linear_nonlinear_model(sigma_level=40.0, sigma_obs=120.0)
+    observations = NILE.reshape(-1, 1)
+    observations = observations.at[10, 0].set(jnp.nan)
+    observations = observations.at[20, 0].set(jnp.nan)
+
+    result = ekf_filter(model, observations)
+    assert jnp.all(jnp.isfinite(result.filtered_states))
+    assert jnp.isfinite(result.log_likelihood)
+
+    # At NaN points, predicted == filtered
+    np.testing.assert_allclose(
+        result.filtered_states[10], result.predicted_states[10], atol=1e-5
+    )
+
+
+# ---------------------------------------------------------------------------
+# Nonlinear model tests
+# ---------------------------------------------------------------------------
+
+
+def test_ekf_nonlinear_tracking() -> None:
+    """EKF should track a nonlinear state through noisy observations."""
+    key = jax.random.PRNGKey(42)
+    k1, k2 = jax.random.split(key)
+    n_steps = 200
+
+    # Mild nonlinear transition: x_t = 0.95 * x_{t-1} + 0.1 * sin(x_{t-1})
+    def f(x: Array) -> Array:
+        return 0.95 * x + 0.1 * jnp.sin(x)
+
+    def h(x: Array) -> Array:
+        return x
+
+    # Simulate true states and observations
+    sigma_q, sigma_r = 0.5, 1.0
+    state_noise = jax.random.normal(k1, (n_steps,)) * sigma_q
+    obs_noise = jax.random.normal(k2, (n_steps,)) * sigma_r
+
+    # Build true state sequence
+    state = jnp.array([5.0])
+    states_list = []
+    for t in range(n_steps):
+        state = f(state) + state_noise[t : t + 1]
+        states_list.append(state)
+    true_states = jnp.concatenate(states_list)
+    observations = (true_states + obs_noise).reshape(-1, 1)
+
+    model = NonlinearSSM(
+        transition_fn=f,
+        observation_fn=h,
+        transition_cov=jnp.array([[sigma_q**2]]),
+        observation_cov=jnp.array([[sigma_r**2]]),
+        state_dim=1,
+        obs_dim=1,
+    )
+
+    init = GaussianState(mean=jnp.array([5.0]), cov=jnp.array([[1.0]]))
+    result = ekf_filter(model, observations, initial_state=init)
+
+    # Filtered states should track the true states well
+    filtered = result.filtered_states[:, 0]
+    correlation = jnp.corrcoef(jnp.stack([filtered, true_states]))[0, 1]
+    assert float(correlation) > 0.7, f"Correlation {correlation} too low"
+    assert jnp.all(jnp.isfinite(result.filtered_states))
+
+
+def test_ekf_2d_nonlinear() -> None:
+    """Test EKF with a 2D nonlinear model (polar-to-cartesian observation)."""
+
+    def f(x: Array) -> Array:
+        # Near-constant velocity
+        return x * 0.99
+
+    def h(x: Array) -> Array:
+        # Observe range and bearing (nonlinear observation)
+        r = jnp.sqrt(x[0] ** 2 + x[1] ** 2)
+        theta = jnp.arctan2(x[1], x[0])
+        return jnp.array([r, theta])
+
+    model = NonlinearSSM(
+        transition_fn=f,
+        observation_fn=h,
+        transition_cov=jnp.eye(2) * 0.1,
+        observation_cov=jnp.eye(2) * 0.01,
+        state_dim=2,
+        obs_dim=2,
+    )
+
+    key = jax.random.PRNGKey(7)
+    # Simulate observations from a known trajectory
+    true_state = jnp.array([3.0, 4.0])
+    obs_list = []
+    for _ in range(50):
+        # Split first, then sample only from fresh subkeys: no key is reused
+        key, k_state, k_obs = jax.random.split(key, 3)
+        true_state = f(true_state) + jax.random.normal(k_state, (2,)) * 0.01
+        obs = h(true_state) + jax.random.normal(k_obs, (2,)) * 0.1
+        obs_list.append(obs)
+    observations = jnp.stack(obs_list)
+
+    init = GaussianState(mean=jnp.array([3.0, 4.0]), cov=jnp.eye(2) * 1.0)
+    result = ekf_filter(model, observations, initial_state=init)
+
+    assert result.filtered_states.shape == (50, 2)
+    assert jnp.all(jnp.isfinite(result.filtered_states))
+    assert jnp.isfinite(result.log_likelihood)
+
+
+# ---------------------------------------------------------------------------
+# JIT compatibility
+# ---------------------------------------------------------------------------
+
+
+def test_ekf_filter_jit() -> None:
+    """Verify ekf_filter is JIT-compiled without errors."""
+    model = _linear_nonlinear_model(sigma_level=40.0, sigma_obs=120.0)
+    observations = NILE[:20].reshape(-1, 1)
+    r1 = ekf_filter(model, observations)
+    r2 = ekf_filter(model, observations)
+    np.testing.assert_allclose(r1.log_likelihood, r2.log_likelihood, atol=1e-5)
+
+
+def test_grad_through_ekf() -> None:
+    """Verify autodiff works through the EKF log-likelihood."""
+    observations = NILE[:20].reshape(-1, 1)
+
+    def neg_ll(log_sigma_level: Array, log_sigma_obs: Array) -> Array:
+        Q = jnp.exp(log_sigma_level) * jnp.eye(1)
+        R = jnp.exp(log_sigma_obs) * jnp.eye(1)
+        model = NonlinearSSM(
+            transition_fn=lambda x: x,
+            observation_fn=lambda x: x,
+            transition_cov=Q,
+            observation_cov=R,
+            state_dim=1,
+            obs_dim=1,
+        )
+        result = ekf_filter(model, observations)
+        return -result.log_likelihood
+
+    grad_fn = jax.grad(neg_ll, argnums=(0, 1))
+    g1, g2 = grad_fn(jnp.log(jnp.array(1600.0)), jnp.log(jnp.array(15000.0)))
+    assert jnp.isfinite(g1)
+    assert jnp.isfinite(g2)
+
+
+# ---------------------------------------------------------------------------
+# Class interface
+# ---------------------------------------------------------------------------
+
+
+def test_ekf_class_scan() -> None:
+    ekf = ExtendedKalmanFilter()
+    model = _linear_nonlinear_model(sigma_level=40.0, sigma_obs=120.0)
+    observations = NILE[:10].reshape(-1, 1)
+    result = ekf.scan(model, observations)
+    assert isinstance(result, FilterResult)
+    assert result.filtered_states.shape == (10, 1)
+
+
+# ---------------------------------------------------------------------------
+# NonlinearSSM tests
+# ---------------------------------------------------------------------------
+
+
+def test_nonlinear_ssm_repr() -> None:
+    model = _linear_nonlinear_model()
+    assert "NonlinearSSM" in repr(model)
+    assert "state_dim=1" in repr(model)
+
+
+def test_nonlinear_ssm_initial_state() -> None:
+    model = _linear_nonlinear_model()
+    init = model.initial_state()
+    assert init.mean.shape == (1,)
+    assert init.cov.shape == (1, 1)
+    np.testing.assert_allclose(init.mean, [0.0])
+    assert float(init.cov[0, 0]) == 1e6
+
+
+def test_nonlinear_ssm_pytree() -> None:
+    """NonlinearSSM should be a valid JAX pytree."""
+    model = _linear_nonlinear_model(sigma_level=2.0, sigma_obs=3.0)
+    leaves, treedef = jax.tree_util.tree_flatten(model)
+    reconstructed = jax.tree_util.tree_unflatten(treedef, leaves)
+    np.testing.assert_allclose(reconstructed.Q, model.Q)
+    np.testing.assert_allclose(reconstructed.R, model.R)
+    assert reconstructed.state_dim == model.state_dim