Skip to content

Commit a1ff847

Browse files
authored
Merge pull request #3 from quant-sci/implemnet-ekf
feat: implement extended kalman filter
2 parents 4b021a9 + 413d128 commit a1ff847

6 files changed

Lines changed: 721 additions & 1 deletion

File tree

src/dynaris/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
FilterProtocol,
55
FilterResult,
66
GaussianState,
7+
NonlinearSSM,
78
SmootherProtocol,
89
SmootherResult,
910
StateSpaceModel,
@@ -17,7 +18,7 @@
1718
Regression,
1819
Seasonal,
1920
)
20-
from dynaris.filters import KalmanFilter, kalman_filter
21+
from dynaris.filters import ExtendedKalmanFilter, KalmanFilter, ekf_filter, kalman_filter
2122
from dynaris.smoothers import RTSSmoother, rts_smooth
2223

2324
__version__ = "0.1.0"
@@ -26,19 +27,22 @@
2627
"DLM",
2728
"Autoregressive",
2829
"Cycle",
30+
"ExtendedKalmanFilter",
2931
"FilterProtocol",
3032
"FilterResult",
3133
"GaussianState",
3234
"KalmanFilter",
3335
"LocalLevel",
3436
"LocalLinearTrend",
37+
"NonlinearSSM",
3538
"RTSSmoother",
3639
"Regression",
3740
"Seasonal",
3841
"SmootherProtocol",
3942
"SmootherResult",
4043
"StateSpaceModel",
4144
"__version__",
45+
"ekf_filter",
4246
"kalman_filter",
4347
"rts_smooth",
4448
]

src/dynaris/core/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Core math engine: state-space representation, filter protocols, result types."""
22

3+
from dynaris.core.nonlinear import NonlinearSSM
34
from dynaris.core.protocols import FilterProtocol, SmootherProtocol
45
from dynaris.core.results import FilterResult, SmootherResult
56
from dynaris.core.state_space import StateSpaceModel
@@ -9,6 +10,7 @@
910
"FilterProtocol",
1011
"FilterResult",
1112
"GaussianState",
13+
"NonlinearSSM",
1214
"SmootherProtocol",
1315
"SmootherResult",
1416
"StateSpaceModel",

src/dynaris/core/nonlinear.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
"""Nonlinear state-space model representation for EKF/UKF."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import dataclass
6+
from typing import Callable
7+
8+
import jax
9+
import jax.numpy as jnp
10+
from jax import Array
11+
12+
from dynaris.core.types import GaussianState
13+
14+
# Type aliases for transition and observation functions.
15+
# transition_fn: (state_vec,) -> predicted_state_vec
16+
# observation_fn: (state_vec,) -> predicted_observation_vec
17+
TransitionFn = Callable[[Array], Array]
18+
ObservationFn = Callable[[Array], Array]
19+
20+
21+
@dataclass(frozen=True)
22+
class NonlinearSSM:
23+
"""Nonlinear state-space model for use with the Extended Kalman Filter.
24+
25+
State equation: theta_t = f(theta_{t-1}) + omega_t, omega_t ~ N(0, Q)
26+
Observation eq: Y_t = h(theta_t) + nu_t, nu_t ~ N(0, R)
27+
28+
The Jacobians of f and h are computed automatically via ``jax.jacfwd``,
29+
so no manual derivation is required.
30+
31+
Attributes:
32+
transition_fn: f, maps state (n,) -> state (n,).
33+
observation_fn: h, maps state (n,) -> observation (m,).
34+
transition_cov: Q, evolution noise covariance, shape (n, n).
35+
observation_cov: R, observation noise covariance, shape (m, m).
36+
state_dim: Dimension of the state vector.
37+
obs_dim: Dimension of the observation vector.
38+
"""
39+
40+
transition_fn: TransitionFn
41+
observation_fn: ObservationFn
42+
transition_cov: Array # Q: (n, n)
43+
observation_cov: Array # R: (m, m)
44+
state_dim: int
45+
obs_dim: int
46+
47+
# --- Short aliases ---
48+
49+
@property
50+
def Q(self) -> Array: # noqa: N802
51+
"""Evolution / transition noise covariance."""
52+
return self.transition_cov
53+
54+
@property
55+
def R(self) -> Array: # noqa: N802
56+
"""Observation noise covariance."""
57+
return self.observation_cov
58+
59+
@property
60+
def f(self) -> TransitionFn:
61+
"""Transition function alias."""
62+
return self.transition_fn
63+
64+
@property
65+
def h(self) -> ObservationFn:
66+
"""Observation function alias."""
67+
return self.observation_fn
68+
69+
# --- Factory methods ---
70+
71+
def initial_state(
72+
self,
73+
mean: Array | None = None,
74+
cov: Array | None = None,
75+
) -> GaussianState:
76+
"""Create a default initial GaussianState for this model.
77+
78+
Args:
79+
mean: Initial state mean. Defaults to zeros.
80+
cov: Initial state covariance. Defaults to 1e6 * I (diffuse prior).
81+
82+
Returns:
83+
GaussianState with the specified or default initial conditions.
84+
"""
85+
n = self.state_dim
86+
if mean is None:
87+
mean = jnp.zeros(n)
88+
if cov is None:
89+
cov = jnp.eye(n) * 1e6
90+
return GaussianState(mean=mean, cov=cov)
91+
92+
def __repr__(self) -> str:
93+
return f"NonlinearSSM(state_dim={self.state_dim}, obs_dim={self.obs_dim})"
94+
95+
# --- JAX pytree registration ---
96+
97+
def tree_flatten(self) -> tuple[list[Array], dict[str, object]]:
98+
"""Flatten into JAX pytree leaves and auxiliary data."""
99+
leaves = [self.transition_cov, self.observation_cov]
100+
aux = {
101+
"transition_fn": self.transition_fn,
102+
"observation_fn": self.observation_fn,
103+
"state_dim": self.state_dim,
104+
"obs_dim": self.obs_dim,
105+
}
106+
return leaves, aux
107+
108+
@classmethod
109+
def tree_unflatten(
110+
cls, aux_data: dict[str, object], children: list[Array]
111+
) -> NonlinearSSM:
112+
"""Reconstruct from JAX pytree leaves."""
113+
return cls(
114+
transition_fn=aux_data["transition_fn"], # type: ignore[arg-type]
115+
observation_fn=aux_data["observation_fn"], # type: ignore[arg-type]
116+
transition_cov=children[0],
117+
observation_cov=children[1],
118+
state_dim=aux_data["state_dim"], # type: ignore[arg-type]
119+
obs_dim=aux_data["obs_dim"], # type: ignore[arg-type]
120+
)
121+
122+
123+
jax.tree_util.register_pytree_node_class(NonlinearSSM)

src/dynaris/filters/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
"""Filtering algorithms: Kalman filter and variants."""
22

3+
from dynaris.filters.ekf import ExtendedKalmanFilter, ekf_filter
34
from dynaris.filters.kalman import KalmanFilter, kalman_filter
45

56
__all__ = [
7+
"ExtendedKalmanFilter",
68
"KalmanFilter",
9+
"ekf_filter",
710
"kalman_filter",
811
]

0 commit comments

Comments
 (0)