Skip to content

Commit 0b42b40

Browse files
igerber and claude committed
Fix P0/P1/P2/P3 issues from PR #218 review
P0-1: Fix double-weighting in vcov — solve_ols() now computes vcov on original-scale data with weights applied once (not on sqrt(w)-transformed data with weights applied again). Fix HC1 meat to use X'diag(w*u²)X instead of (X*w*u)'(X*w*u) which gave w² in the meat. P0-2: Replace one-shot weighted within-transformation with iterative alternating projections (max_iter=100, tol=1e-8) for correct weighted FWL residualization in TWFE. P0-3: Add NaN-vcov guard when no stratum contributes variance (all singletons skipped), preventing se=0 → t=±inf instead of NaN. P1: FPC validation now checks against PSU count (not obs count), enforces constancy within stratum, and rejects fpc-only designs. P2: Add survey_metadata to MultiPeriodDiDResults. Replace placeholder R reference tests with exact manual oracle tests. P3: Fix cluster comparison to use partition equivalence. Add docstrings for survey_design/weights/weight_type parameters. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 4a54eed commit 0b42b40

File tree

9 files changed

+415
-235
lines changed

9 files changed

+415
-235
lines changed

diff_diff/estimators.py

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -181,6 +181,10 @@ def fit(
181181
List of categorical column names for high-dimensional fixed effects.
182182
Uses within-transformation (demeaning) instead of dummy variables.
183183
More efficient for large numbers of categories (e.g., firm, individual).
184+
survey_design : SurveyDesign, optional
185+
Survey design specification for design-based inference. When provided,
186+
uses Taylor Series Linearization for variance estimation and
187+
applies sampling weights to the regression.
184188
185189
Returns
186190
-------
@@ -787,6 +791,10 @@ def fit( # type: ignore[override]
787791
is detected (suggests CallawaySantAnna instead). Does NOT affect
788792
standard error computation -- use the ``cluster`` parameter for
789793
cluster-robust SEs.
794+
survey_design : SurveyDesign, optional
795+
Survey design specification for design-based inference. When provided,
796+
uses Taylor Series Linearization for variance estimation and
797+
applies sampling weights to the regression.
790798
791799
Returns
792800
-------
@@ -951,8 +959,8 @@ def fit( # type: ignore[override]
951959
# Resolve survey design if provided
952960
from diff_diff.survey import _resolve_effective_cluster, _resolve_survey_for_fit
953961

954-
resolved_survey, survey_weights, survey_weight_type, _ = _resolve_survey_for_fit(
955-
survey_design, data, self.inference
962+
resolved_survey, survey_weights, survey_weight_type, survey_metadata = (
963+
_resolve_survey_for_fit(survey_design, data, self.inference)
956964
)
957965

958966
# Handle absorbed fixed effects (within-transformation)
@@ -1161,6 +1169,7 @@ def fit( # type: ignore[override]
11611169
r_squared=r_squared,
11621170
reference_period=reference_period,
11631171
interaction_indices=interaction_indices,
1172+
survey_metadata=survey_metadata,
11641173
)
11651174

11661175
self._coefficients = coefficients

diff_diff/linalg.py

Lines changed: 78 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -390,6 +390,14 @@ def solve_ols(
390390
rank-deficient matrices. Use only when you know the design matrix is
391391
full rank. If the matrix is actually rank-deficient, results may be
392392
incorrect (minimum-norm solution instead of R-style NA handling).
393+
weights : ndarray of shape (n,), optional
394+
Observation weights for Weighted Least Squares. When provided,
395+
minimizes sum(w_i * (y_i - X_i @ beta)^2). Weights should be
396+
pre-normalized (e.g., mean=1 for pweights).
397+
weight_type : str, default "pweight"
398+
Type of weights: "pweight" (inverse selection probability),
399+
"fweight" (frequency), or "aweight" (inverse variance).
400+
Affects variance estimation but not coefficient computation.
393401
394402
Returns
395403
-------
@@ -497,6 +505,11 @@ def solve_ols(
497505
X = X * sqrt_w[:, np.newaxis]
498506
y = y * sqrt_w
499507

508+
# When weights are present, compute vcov separately on original-scale data
509+
# to avoid double-weighting. The backend only computes point estimates.
510+
_weighted_vcov_external = weights is not None
511+
_backend_return_vcov = return_vcov and not _weighted_vcov_external
512+
500513
# Fast path: skip rank check and use Rust directly when requested
501514
# This saves O(nk²) QR overhead but won't detect rank-deficient matrices
502515
result = None # Will hold the tuple from backend functions
@@ -507,23 +520,20 @@ def solve_ols(
507520
X,
508521
y,
509522
cluster_ids=cluster_ids,
510-
return_vcov=return_vcov,
523+
return_vcov=_backend_return_vcov,
511524
return_fitted=return_fitted,
512525
)
513526
# result is None on numerical instability → fall through
514527
if result is None:
515-
# Fall through to Python without rank check (user guarantees full rank)
516528
result = _solve_ols_numpy(
517529
X,
518530
y,
519531
cluster_ids=cluster_ids,
520-
return_vcov=return_vcov,
532+
return_vcov=_backend_return_vcov,
521533
return_fitted=return_fitted,
522534
rank_deficient_action=rank_deficient_action,
523535
column_names=column_names,
524536
_skip_rank_check=True,
525-
weights=weights,
526-
weight_type=weight_type,
527537
)
528538
else:
529539
# Check for rank deficiency using fast pivoted QR decomposition.
@@ -546,14 +556,13 @@ def solve_ols(
546556
X,
547557
y,
548558
cluster_ids=cluster_ids,
549-
return_vcov=return_vcov,
559+
return_vcov=_backend_return_vcov,
550560
return_fitted=return_fitted,
551561
)
552562

553563
if result is not None:
554-
# Check for NaN vcov: Rust SVD may detect rank-deficiency that QR missed
555564
vcov_check = result[-1]
556-
if return_vcov and vcov_check is not None and np.any(np.isnan(vcov_check)):
565+
if _backend_return_vcov and vcov_check is not None and np.any(np.isnan(vcov_check)):
557566
warnings.warn(
558567
"Rust backend detected ill-conditioned matrix (NaN in variance-covariance). "
559568
"Re-running with Python backend for proper rank detection.",
@@ -563,35 +572,41 @@ def solve_ols(
563572
result = None # Force Python fallback below
564573

565574
if result is None:
566-
# Python backend for: weighted, rank-deficient, Rust instability, no Rust
567575
result = _solve_ols_numpy(
568576
X,
569577
y,
570578
cluster_ids=cluster_ids,
571-
return_vcov=return_vcov,
579+
return_vcov=_backend_return_vcov,
572580
return_fitted=return_fitted,
573581
rank_deficient_action=rank_deficient_action,
574582
column_names=column_names,
575-
_precomputed_rank_info=(
576-
(rank, dropped_cols, pivot)
577-
if not (weights is not None and _original_X is not None)
578-
else None
579-
),
580-
weights=weights,
581-
weight_type=weight_type,
583+
_precomputed_rank_info=(rank, dropped_cols, pivot),
582584
)
583585

584-
# Back-transform residuals to original scale when WLS was applied.
585-
# WLS solves on transformed (X_w, y_w) but residuals should be y - X @ beta.
586+
# Back-transform residuals and compute weighted vcov on original-scale data.
587+
# The WLS transform (sqrt(w) scaling) is for point estimates only. Vcov must
588+
# be computed on original X and residuals with weights applied exactly once.
586589
if _original_X is not None and _original_y is not None:
587590
if return_fitted:
588591
coefficients, _resid_w, _fitted_w, vcov_out = result
589-
fitted_orig = np.dot(_original_X, coefficients)
590-
residuals_orig = _original_y - fitted_orig
591-
result = (coefficients, residuals_orig, fitted_orig, vcov_out)
592592
else:
593593
coefficients, _resid_w, vcov_out = result
594-
residuals_orig = _original_y - np.dot(_original_X, coefficients)
594+
595+
fitted_orig = np.dot(_original_X, coefficients)
596+
residuals_orig = _original_y - fitted_orig
597+
598+
if return_vcov:
599+
vcov_out = _compute_robust_vcov_numpy(
600+
_original_X,
601+
residuals_orig,
602+
cluster_ids,
603+
weights=weights,
604+
weight_type=weight_type,
605+
)
606+
607+
if return_fitted:
608+
result = (coefficients, residuals_orig, fitted_orig, vcov_out)
609+
else:
595610
result = (coefficients, residuals_orig, vcov_out)
596611

597612
return result
@@ -608,8 +623,6 @@ def _solve_ols_numpy(
608623
column_names: Optional[List[str]] = None,
609624
_precomputed_rank_info: Optional[Tuple[int, np.ndarray, np.ndarray]] = None,
610625
_skip_rank_check: bool = False,
611-
weights: Optional[np.ndarray] = None,
612-
weight_type: str = "pweight",
613626
) -> Union[
614627
Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]],
615628
Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]],
@@ -716,8 +729,6 @@ def _solve_ols_numpy(
716729
X_reduced,
717730
residuals,
718731
cluster_ids,
719-
weights=weights,
720-
weight_type=weight_type,
721732
)
722733
vcov = _expand_vcov_with_nan(vcov_reduced, k, kept_cols)
723734
else:
@@ -732,13 +743,7 @@ def _solve_ols_numpy(
732743
# Compute variance-covariance matrix if requested
733744
vcov = None
734745
if return_vcov:
735-
vcov = _compute_robust_vcov_numpy(
736-
X,
737-
residuals,
738-
cluster_ids,
739-
weights=weights,
740-
weight_type=weight_type,
741-
)
746+
vcov = _compute_robust_vcov_numpy(X, residuals, cluster_ids)
742747

743748
if return_fitted:
744749
return coefficients, residuals, fitted, vcov
@@ -892,8 +897,8 @@ def _compute_robust_vcov_numpy(
892897
if weights is not None and weight_type == "fweight":
893898
n_eff = int(np.sum(weights))
894899

895-
# Compute weighted scores: pweight/fweight multiply by w; aweight and
896-
# unweighted use raw residuals (aweight errors are ~homoskedastic after WLS)
900+
# Compute weighted scores for cluster-robust meat (outer product of sums).
901+
# pweight/fweight multiply by w; aweight and unweighted use raw residuals.
897902
_use_weighted_scores = weights is not None and weight_type not in ("aweight",)
898903
if _use_weighted_scores:
899904
scores = X * (weights * residuals)[:, np.newaxis]
@@ -902,8 +907,12 @@ def _compute_robust_vcov_numpy(
902907

903908
if cluster_ids is None:
904909
# HC1 (heteroskedasticity-robust) standard errors
910+
# For HC1, meat = X' diag(w * u²) X (NOT scores'scores which gives w²*u²)
905911
adjustment = n_eff / (n_eff - k)
906-
meat = scores.T @ scores
912+
if _use_weighted_scores:
913+
meat = np.dot(X.T, X * (weights * residuals**2)[:, np.newaxis])
914+
else:
915+
meat = np.dot(X.T, X * (residuals**2)[:, np.newaxis])
907916
else:
908917
# Cluster-robust standard errors (vectorized via groupby)
909918
cluster_ids = np.asarray(cluster_ids)
@@ -1450,22 +1459,42 @@ def fit(
14501459
# Rank-deficient: compute vcov for identified coefficients only
14511460
kept_cols = np.where(~nan_mask)[0]
14521461
X_reduced = X[:, kept_cols]
1453-
mse = np.sum(residuals**2) / (n - k_effective)
1454-
try:
1455-
vcov_reduced = np.linalg.solve(
1456-
X_reduced.T @ X_reduced, mse * np.eye(k_effective)
1457-
)
1458-
except np.linalg.LinAlgError:
1459-
vcov_reduced = np.linalg.pinv(X_reduced.T @ X_reduced) * mse
1462+
if self.weights is not None:
1463+
# Weighted classical vcov: use weighted RSS and X'WX
1464+
w = self.weights
1465+
mse = np.sum(w * residuals**2) / (n - k_effective)
1466+
XtWX_reduced = X_reduced.T @ (X_reduced * w[:, np.newaxis])
1467+
try:
1468+
vcov_reduced = np.linalg.solve(XtWX_reduced, mse * np.eye(k_effective))
1469+
except np.linalg.LinAlgError:
1470+
vcov_reduced = np.linalg.pinv(XtWX_reduced) * mse
1471+
else:
1472+
mse = np.sum(residuals**2) / (n - k_effective)
1473+
try:
1474+
vcov_reduced = np.linalg.solve(
1475+
X_reduced.T @ X_reduced, mse * np.eye(k_effective)
1476+
)
1477+
except np.linalg.LinAlgError:
1478+
vcov_reduced = np.linalg.pinv(X_reduced.T @ X_reduced) * mse
14601479
# Expand to full size with NaN for dropped columns
14611480
vcov = _expand_vcov_with_nan(vcov_reduced, k, kept_cols)
14621481
else:
14631482
# Full rank: standard computation
1464-
mse = np.sum(residuals**2) / (n - k)
1465-
try:
1466-
vcov = np.linalg.solve(X.T @ X, mse * np.eye(k))
1467-
except np.linalg.LinAlgError:
1468-
vcov = np.linalg.pinv(X.T @ X) * mse
1483+
if self.weights is not None:
1484+
# Weighted classical vcov: use weighted RSS and X'WX
1485+
w = self.weights
1486+
mse = np.sum(w * residuals**2) / (n - k)
1487+
XtWX = X.T @ (X * w[:, np.newaxis])
1488+
try:
1489+
vcov = np.linalg.solve(XtWX, mse * np.eye(k))
1490+
except np.linalg.LinAlgError:
1491+
vcov = np.linalg.pinv(XtWX) * mse
1492+
else:
1493+
mse = np.sum(residuals**2) / (n - k)
1494+
try:
1495+
vcov = np.linalg.solve(X.T @ X, mse * np.eye(k))
1496+
except np.linalg.LinAlgError:
1497+
vcov = np.linalg.pinv(X.T @ X) * mse
14691498

14701499
# Compute survey vcov if applicable
14711500
if _use_survey_vcov:

diff_diff/results.py

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -347,6 +347,8 @@ class MultiPeriodDiDResults:
347347
r_squared: Optional[float] = field(default=None)
348348
reference_period: Optional[Any] = field(default=None)
349349
interaction_indices: Optional[Dict[Any, int]] = field(default=None, repr=False)
350+
# Survey design metadata (SurveyMetadata instance from diff_diff.survey)
351+
survey_metadata: Optional[Any] = field(default=None)
350352

351353
def __repr__(self) -> str:
352354
"""Concise string representation."""
@@ -400,6 +402,28 @@ def summary(self, alpha: Optional[float] = None) -> str:
400402
if self.r_squared is not None:
401403
lines.append(f"{'R-squared:':<25} {self.r_squared:>10.4f}")
402404

405+
# Add survey design info
406+
if self.survey_metadata is not None:
407+
sm = self.survey_metadata
408+
lines.extend(
409+
[
410+
"",
411+
"-" * 80,
412+
"Survey Design".center(80),
413+
"-" * 80,
414+
f"{'Weight type:':<25} {sm.weight_type:>10}",
415+
]
416+
)
417+
if sm.n_strata is not None:
418+
lines.append(f"{'Strata:':<25} {sm.n_strata:>10}")
419+
if sm.n_psu is not None:
420+
lines.append(f"{'PSU/Cluster:':<25} {sm.n_psu:>10}")
421+
lines.append(f"{'Effective sample size:':<25} {sm.effective_n:>10.1f}")
422+
lines.append(f"{'Design effect (DEFF):':<25} {sm.design_effect:>10.2f}")
423+
if sm.df_survey is not None:
424+
lines.append(f"{'Survey d.f.:':<25} {sm.df_survey:>10}")
425+
lines.append("-" * 80)
426+
403427
# Pre-period effects (parallel trends test)
404428
pre_effects = {p: pe for p, pe in self.period_effects.items() if p in self.pre_periods}
405429
if pre_effects:
@@ -548,6 +572,17 @@ def to_dict(self) -> Dict[str, Any]:
548572
result[f"se_period_{period}"] = pe.se
549573
result[f"pval_period_{period}"] = pe.p_value
550574

575+
# Add survey metadata if present
576+
if self.survey_metadata is not None:
577+
sm = self.survey_metadata
578+
result["weight_type"] = sm.weight_type
579+
result["effective_n"] = sm.effective_n
580+
result["design_effect"] = sm.design_effect
581+
result["sum_weights"] = sm.sum_weights
582+
result["n_strata"] = sm.n_strata
583+
result["n_psu"] = sm.n_psu
584+
result["df_survey"] = sm.df_survey
585+
551586
return result
552587

553588
def to_dataframe(self) -> pd.DataFrame:

0 commit comments

Comments (0)