Skip to content

Commit aacb096

Browse files
committed
feat: add switching & regime models
1 parent 0adf91b commit aacb096

18 files changed

Lines changed: 1545 additions & 16 deletions

File tree

docs/api/index.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@ Complete reference for all public classes and functions in dynaris.
1818
+------------------+-------------------------------------------------------------+
1919
| :doc:`filters` | Kalman, EKF, UKF, and Particle filters |
2020
+------------------+-------------------------------------------------------------+
21+
| :doc:`switching` | Markov-switching models, Hamilton filter, Kim smoother |
22+
+------------------+-------------------------------------------------------------+
2123
| :doc:`smoothers` | Rauch-Tung-Striebel backward smoother |
2224
+------------------+-------------------------------------------------------------+
23-
| :doc:`estimation` | MLE, EM algorithm, diagnostics, transforms |
25+
| :doc:`estimation` | MLE, EM algorithm, diagnostics, model selection |
2426
+------------------+-------------------------------------------------------------+
2527
| :doc:`forecast` | Multi-step forecasting and batch processing |
2628
+------------------+-------------------------------------------------------------+
@@ -39,6 +41,7 @@ Complete reference for all public classes and functions in dynaris.
3941
models
4042
core
4143
filters
44+
switching
4245
smoothers
4346
estimation
4447
forecast

docs/api/switching.rst

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
Switching & Regime Models
2+
=========================
3+
4+
Markov-switching state-space models with K discrete regimes, each with
5+
its own linear-Gaussian dynamics. The Hamilton filter tracks filtered
6+
regime probabilities, and the Kim smoother provides smoothed estimates.
7+
8+
Model
9+
-----
10+
11+
.. autoclass:: dynaris.core.switching.MarkovSwitchingSSM
12+
:members:
13+
:show-inheritance:
14+
15+
Hamilton Filter
16+
---------------
17+
18+
Forward filtering for Markov-switching models. Runs K parallel Kalman
19+
filters with Kim's moment-matching collapse approximation.
20+
21+
.. autoclass:: dynaris.filters.hamilton.HamiltonFilter
22+
:members:
23+
:show-inheritance:
24+
25+
.. autofunction:: dynaris.filters.hamilton.hamilton_filter
26+
27+
Kim Smoother
28+
------------
29+
30+
Backward smoothing for Markov-switching models. Produces smoothed
31+
state estimates and regime probabilities.
32+
33+
.. autoclass:: dynaris.smoothers.kim.KimSmoother
34+
:members:
35+
:show-inheritance:
36+
37+
.. autofunction:: dynaris.smoothers.kim.kim_smooth
38+
39+
Model Selection
40+
---------------
41+
42+
.. autofunction:: dynaris.estimation.model_selection.switching_aic
43+
44+
.. autofunction:: dynaris.estimation.model_selection.switching_bic

src/dynaris/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55
FilterProtocol,
66
FilterResult,
77
GaussianState,
8+
MarkovSwitchingSSM,
89
NonlinearSSM,
910
SmootherProtocol,
1011
SmootherResult,
1112
StateSpaceModel,
13+
SwitchingFilterResult,
14+
SwitchingSmootherResult,
1215
)
1316
from dynaris.dlm import (
1417
DLM,
@@ -21,10 +24,12 @@
2124
)
2225
from dynaris.filters import (
2326
ExtendedKalmanFilter,
27+
HamiltonFilter,
2428
KalmanFilter,
2529
ParticleFilter,
2630
UnscentedKalmanFilter,
2731
ekf_filter,
32+
hamilton_filter,
2833
kalman_filter,
2934
particle_filter,
3035
ukf_filter,
@@ -35,7 +40,7 @@
3540
StochasticVolatility,
3641
transform_returns,
3742
)
38-
from dynaris.smoothers import RTSSmoother, rts_smooth
43+
from dynaris.smoothers import KimSmoother, RTSSmoother, kim_smooth, rts_smooth
3944

4045
__version__ = "0.1.0"
4146

@@ -49,10 +54,13 @@
4954
"FilterProtocol",
5055
"FilterResult",
5156
"GaussianState",
57+
"HamiltonFilter",
5258
"KalmanFilter",
59+
"KimSmoother",
5360
"LocalLevel",
5461
"LocalLinearTrend",
5562
"LorenzAttractor",
63+
"MarkovSwitchingSSM",
5664
"NonlinearSSM",
5765
"ParticleFilter",
5866
"RTSSmoother",
@@ -62,10 +70,14 @@
6270
"SmootherResult",
6371
"StateSpaceModel",
6472
"StochasticVolatility",
73+
"SwitchingFilterResult",
74+
"SwitchingSmootherResult",
6575
"UnscentedKalmanFilter",
6676
"__version__",
6777
"ekf_filter",
78+
"hamilton_filter",
6879
"kalman_filter",
80+
"kim_smooth",
6981
"particle_filter",
7082
"rts_smooth",
7183
"transform_returns",

src/dynaris/core/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,27 @@
22

33
from dynaris.core.nonlinear import NonlinearSSM
44
from dynaris.core.protocols import FilterProtocol, SmootherProtocol
5-
from dynaris.core.results import FilterResult, SmootherResult
5+
from dynaris.core.results import (
6+
FilterResult,
7+
SmootherResult,
8+
SwitchingFilterResult,
9+
SwitchingSmootherResult,
10+
)
611
from dynaris.core.ssm import SSM
712
from dynaris.core.state_space import StateSpaceModel
13+
from dynaris.core.switching import MarkovSwitchingSSM
814
from dynaris.core.types import GaussianState
915

1016
__all__ = [
1117
"SSM",
1218
"FilterProtocol",
1319
"FilterResult",
1420
"GaussianState",
21+
"MarkovSwitchingSSM",
1522
"NonlinearSSM",
1623
"SmootherProtocol",
1724
"SmootherResult",
1825
"StateSpaceModel",
26+
"SwitchingFilterResult",
27+
"SwitchingSmootherResult",
1928
]

src/dynaris/core/results.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,69 @@ class SmootherResult(NamedTuple):
5858
predicted_covariances: Array # (T, n, n)
5959
log_likelihood: Array # ()
6060
observations: Array # (T, m)
61+
62+
63+
class SwitchingFilterResult(NamedTuple):
64+
"""Output of a Hamilton forward filtering pass for switching models.
65+
66+
Contains mixture-collapsed state estimates and per-regime quantities.
67+
68+
Attributes:
69+
filtered_states: Mixture-collapsed filtered means, shape (T, n).
70+
filtered_covariances: Mixture-collapsed filtered covs, shape (T, n, n).
71+
predicted_states: Mixture-collapsed predicted means, shape (T, n).
72+
predicted_covariances: Mixture-collapsed predicted covs, shape (T, n, n).
73+
log_likelihood: Total log-likelihood scalar, shape ().
74+
observations: Input observations, shape (T, m).
75+
regime_filtered_probs: Filtered regime probabilities, shape (T, K).
76+
regime_predicted_probs: Predicted regime probabilities, shape (T, K).
77+
regime_filtered_states: Per-regime filtered means, shape (T, K, n).
78+
regime_filtered_covs: Per-regime filtered covs, shape (T, K, n, n).
79+
regime_predicted_states: Per-regime predicted means, shape (T, K, n).
80+
regime_predicted_covs: Per-regime predicted covs, shape (T, K, n, n).
81+
"""
82+
83+
filtered_states: Array # (T, n)
84+
filtered_covariances: Array # (T, n, n)
85+
predicted_states: Array # (T, n)
86+
predicted_covariances: Array # (T, n, n)
87+
log_likelihood: Array # ()
88+
observations: Array # (T, m)
89+
regime_filtered_probs: Array # (T, K)
90+
regime_predicted_probs: Array # (T, K)
91+
regime_filtered_states: Array # (T, K, n)
92+
regime_filtered_covs: Array # (T, K, n, n)
93+
regime_predicted_states: Array # (T, K, n)
94+
regime_predicted_covs: Array # (T, K, n, n)
95+
96+
97+
class SwitchingSmootherResult(NamedTuple):
98+
"""Output of a Kim backward smoothing pass for switching models.
99+
100+
Attributes:
101+
smoothed_states: Mixture-collapsed smoothed means, shape (T, n).
102+
smoothed_covariances: Mixture-collapsed smoothed covs, shape (T, n, n).
103+
filtered_states: Mixture-collapsed filtered means, shape (T, n).
104+
filtered_covariances: Mixture-collapsed filtered covs, shape (T, n, n).
105+
predicted_states: Mixture-collapsed predicted means, shape (T, n).
106+
predicted_covariances: Mixture-collapsed predicted covs, shape (T, n, n).
107+
log_likelihood: Total log-likelihood scalar, shape ().
108+
observations: Input observations, shape (T, m).
109+
regime_smoothed_probs: Smoothed regime probabilities, shape (T, K).
110+
regime_filtered_probs: Filtered regime probabilities, shape (T, K).
111+
regime_smoothed_states: Per-regime smoothed means, shape (T, K, n).
112+
regime_smoothed_covs: Per-regime smoothed covs, shape (T, K, n, n).
113+
"""
114+
115+
smoothed_states: Array # (T, n)
116+
smoothed_covariances: Array # (T, n, n)
117+
filtered_states: Array # (T, n)
118+
filtered_covariances: Array # (T, n, n)
119+
predicted_states: Array # (T, n)
120+
predicted_covariances: Array # (T, n, n)
121+
log_likelihood: Array # ()
122+
observations: Array # (T, m)
123+
regime_smoothed_probs: Array # (T, K)
124+
regime_filtered_probs: Array # (T, K)
125+
regime_smoothed_states: Array # (T, K, n)
126+
regime_smoothed_covs: Array # (T, K, n, n)

src/dynaris/core/ssm.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,20 @@
2828
from jax import Array
2929

3030
from dynaris.core.nonlinear import NonlinearSSM
31-
from dynaris.core.results import FilterResult
31+
from dynaris.core.results import FilterResult, SwitchingFilterResult
3232
from dynaris.core.state_space import StateSpaceModel
33+
from dynaris.core.switching import MarkovSwitchingSSM
3334
from dynaris.core.types import GaussianState
3435
from dynaris.filters.ekf import ekf_filter
36+
from dynaris.filters.hamilton import hamilton_filter
3537
from dynaris.filters.kalman import kalman_filter
3638
from dynaris.filters.particle import particle_filter
3739
from dynaris.filters.ukf import ukf_filter
3840

39-
_VALID_FILTERS = {"auto", "kalman", "ekf", "ukf", "particle"}
41+
_VALID_FILTERS = {"auto", "kalman", "ekf", "ukf", "particle", "hamilton"}
4042
_LINEAR_FILTERS = {"kalman"}
4143
_NONLINEAR_FILTERS = {"ekf", "ukf", "particle"}
44+
_SWITCHING_FILTERS = {"hamilton"}
4245

4346

4447
def _to_jax_2d(y: Any) -> tuple[Array, pd.DatetimeIndex | None]:
@@ -92,14 +95,17 @@ class SSM:
9295

9396
def __init__(
9497
self,
95-
model: StateSpaceModel | NonlinearSSM,
98+
model: StateSpaceModel | NonlinearSSM | MarkovSwitchingSSM,
9699
filter: str = "auto",
97100
*,
98101
key: Array | None = None,
99102
**filter_kwargs: Any,
100103
) -> None:
101-
if not isinstance(model, (StateSpaceModel, NonlinearSSM)):
102-
msg = f"model must be a StateSpaceModel or NonlinearSSM, got {type(model).__name__}"
104+
if not isinstance(model, (StateSpaceModel, NonlinearSSM, MarkovSwitchingSSM)):
105+
msg = (
106+
f"model must be a StateSpaceModel, NonlinearSSM, or "
107+
f"MarkovSwitchingSSM, got {type(model).__name__}"
108+
)
103109
raise TypeError(msg)
104110

105111
if filter not in _VALID_FILTERS:
@@ -108,29 +114,41 @@ def __init__(
108114

109115
# Resolve auto-selection
110116
is_linear = isinstance(model, StateSpaceModel)
111-
filter_name = ("kalman" if is_linear else "ukf") if filter == "auto" else filter
117+
is_switching = isinstance(model, MarkovSwitchingSSM)
118+
if filter == "auto":
119+
if is_switching:
120+
filter_name = "hamilton"
121+
elif is_linear:
122+
filter_name = "kalman"
123+
else:
124+
filter_name = "ukf"
125+
else:
126+
filter_name = filter
112127

113128
# Validate filter/model compatibility
114129
if filter_name in _LINEAR_FILTERS and not is_linear:
115-
msg = f"Filter {filter_name!r} requires a StateSpaceModel, got NonlinearSSM."
130+
msg = f"Filter {filter_name!r} requires a StateSpaceModel."
116131
raise ValueError(msg)
117-
if filter_name in _NONLINEAR_FILTERS and is_linear:
118-
msg = f"Filter {filter_name!r} requires a NonlinearSSM, got StateSpaceModel."
132+
if filter_name in _NONLINEAR_FILTERS and (is_linear or is_switching):
133+
msg = f"Filter {filter_name!r} requires a NonlinearSSM."
134+
raise ValueError(msg)
135+
if filter_name in _SWITCHING_FILTERS and not is_switching:
136+
msg = f"Filter {filter_name!r} requires a MarkovSwitchingSSM."
119137
raise ValueError(msg)
120138

121139
self._model = model
122140
self._filter_name = filter_name
123141
self._filter_kwargs = filter_kwargs
124142
self._key = key
125-
self._filter_result: FilterResult | None = None
143+
self._filter_result: FilterResult | SwitchingFilterResult | None = None
126144
self._observations: Array | None = None
127145
self._index: pd.DatetimeIndex | None = None
128146
self._is_fitted = False
129147

130148
# --- Properties ---
131149

132150
@property
133-
def model(self) -> StateSpaceModel | NonlinearSSM:
151+
def model(self) -> StateSpaceModel | NonlinearSSM | MarkovSwitchingSSM:
134152
"""The underlying state-space model."""
135153
return self._model
136154

@@ -140,7 +158,7 @@ def filter_name(self) -> str:
140158
return self._filter_name
141159

142160
@property
143-
def filter_result(self) -> FilterResult:
161+
def filter_result(self) -> FilterResult | SwitchingFilterResult:
144162
"""Filter result from the last ``fit()`` call."""
145163
if self._filter_result is None:
146164
msg = "Model not fitted. Call .fit() first."
@@ -190,6 +208,10 @@ def fit(
190208
initial_state=initial_state,
191209
**self._filter_kwargs,
192210
)
211+
elif self._filter_name == "hamilton":
212+
self._filter_result = hamilton_filter(
213+
self._model, obs, initial_state=initial_state
214+
)
193215

194216
self._is_fitted = True
195217
return self

0 commit comments

Comments
 (0)