Skip to content

Commit 69fde20

Browse files
igerberclaude
andcommitted
Add warning and test for compute_robust_vcov numerical instability fallback
Address PR #115 review round 3: - P2: Add UserWarning when Rust backend falls back to Python on numerical instability - P3: Add test_numerical_instability_fallback_warns to verify warning and fallback behavior Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 1a00fb9 commit 69fde20

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

diff_diff/linalg.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -773,7 +773,13 @@ def compute_robust_vcov(
773773
"and covariates for linear dependencies."
774774
) from e
775775
if "numerically unstable" in error_msg.lower():
776-
# Fall back to NumPy on numerical instability
776+
# Fall back to NumPy on numerical instability (with warning)
777+
warnings.warn(
778+
f"Rust backend detected numerical instability: {e}. "
779+
"Falling back to Python backend for variance computation.",
780+
UserWarning,
781+
stacklevel=2,
782+
)
777783
return _compute_robust_vcov_numpy(X, residuals, cluster_ids)
778784
raise
779785

tests/test_linalg.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,41 @@ def test_cluster_robust_symmetric(self, ols_data):
520520

521521
np.testing.assert_array_almost_equal(vcov, vcov.T)
522522

523+
def test_numerical_instability_fallback_warns(self, ols_data):
524+
"""Test that numerical instability in Rust backend triggers warning and fallback."""
525+
from unittest.mock import patch
526+
import warnings
527+
528+
from diff_diff import HAS_RUST_BACKEND
529+
530+
if not HAS_RUST_BACKEND:
531+
pytest.skip("Rust backend not available")
532+
533+
X, residuals = ols_data
534+
535+
# Mock _rust_compute_robust_vcov to raise numerical instability error
536+
def mock_rust_vcov(*args, **kwargs):
537+
raise ValueError("Matrix inversion numerically unstable")
538+
539+
with patch("diff_diff.linalg._rust_compute_robust_vcov", mock_rust_vcov):
540+
with warnings.catch_warnings(record=True) as caught_warnings:
541+
warnings.simplefilter("always")
542+
vcov = compute_robust_vcov(X, residuals)
543+
544+
# Verify warning was emitted
545+
instability_warnings = [
546+
w for w in caught_warnings
547+
if "numerical instability" in str(w.message).lower()
548+
]
549+
assert len(instability_warnings) == 1, (
550+
f"Expected 1 numerical instability warning, got {len(instability_warnings)}"
551+
)
552+
553+
# Verify fallback produced valid vcov matrix
554+
assert vcov.shape == (X.shape[1], X.shape[1])
555+
assert np.allclose(vcov, vcov.T) # Symmetric
556+
assert np.all(np.linalg.eigvalsh(vcov) >= -1e-10) # PSD
557+
523558

524559
class TestComputeRSquared:
525560
"""Tests for compute_r_squared function."""

0 commit comments

Comments
 (0)