diff --git a/README.md b/README.md index 20b6dcc3..338fd6a7 100644 --- a/README.md +++ b/README.md @@ -1138,13 +1138,15 @@ trop_est = TROP( lambda_nn_grid=[0.0, 0.1, 1.0], # Nuclear norm grid n_bootstrap=200 ) +# Note: TROP infers treatment periods from the treatment indicator column. +# The 'treated' column must be an absorbing state (D=1 for all periods +# during and after treatment starts for each unit). results = trop_est.fit( panel_data, outcome='gdp_growth', treatment='treated', unit='state', - time='year', - post_periods=[2015, 2016, 2017, 2018] + time='year' ) # View results @@ -1232,9 +1234,11 @@ sdid_results = sdid.fit(data, outcome='y', treatment='treated', unit='unit', time='time', post_periods=[5,6,7]) # TROP (accounts for factors) +# Note: TROP infers treatment periods from the treatment indicator column +# (D=1 for treated observations, D=0 for control) trop_est = TROP() # Uses default grids with LOOCV selection trop_results = trop_est.fit(data, outcome='y', treatment='treated', - unit='unit', time='time', post_periods=[5,6,7]) + unit='unit', time='time') print(f"SDID estimate: {sdid_results.att:.3f}") print(f"TROP estimate: {trop_results.att:.3f}") @@ -1279,13 +1283,13 @@ TROP( ```python # One-liner estimation with default tuning grids +# Note: TROP infers treatment periods from the treatment indicator results = trop( data, outcome='y', treatment='treated', unit='unit', time='time', - post_periods=[5, 6, 7], n_bootstrap=200 ) ``` @@ -1877,10 +1881,11 @@ TROP( |-----------|------|-------------| | `data` | DataFrame | Panel data | | `outcome` | str | Outcome variable column name | -| `treatment` | str | Treatment indicator column (0/1) | +| `treatment` | str | Treatment indicator column (0/1 absorbing state) | | `unit` | str | Unit identifier column | | `time` | str | Time period column | -| `post_periods` | list | List of post-treatment period values | + +Note: TROP infers treatment periods from the treatment indicator column. The treatment column should be an absorbing state indicator where D=1 for all periods during and after treatment starts. ### TROPResults @@ -1906,8 +1911,8 @@ TROP( | `factor_matrix` | Low-rank factor matrix L (n_periods x n_units) | | `effective_rank` | Effective rank of factor matrix | | `loocv_score` | LOOCV score for selected parameters | -| `pre_periods` | List of pre-treatment periods | -| `post_periods` | List of post-treatment periods | +| `n_pre_periods` | Number of pre-treatment periods | +| `n_post_periods` | Number of post-treatment periods | | `variance_method` | Variance estimation method | | `bootstrap_distribution` | Bootstrap distribution (if bootstrap) | diff --git a/ROADMAP.md b/ROADMAP.md index 87a47e0f..b76ee5a0 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -44,6 +44,17 @@ Two-stage approach gaining traction in applied work. First residualizes outcomes **Reference**: Gardner (2022). *Working Paper*. +### Stacked Difference-in-Differences + +An intuitive approach that explicitly constructs sub-experiments for each treatment cohort, avoiding forbidden comparisons. + +- Creates separate datasets per cohort with valid controls only +- Stacks sub-experiments and applies corrective sample weights +- Returns variance-weighted ATT with proper compositional balance +- Conceptually simpler alternative to aggregation-based methods + +**Reference**: [Wing, Freedman & Hollingsworth (2024)](https://www.nber.org/papers/w32054). *NBER Working Paper 32054*. Stata: `STACKDID`. + ### Staggered Triple Difference (DDD) Extend the existing `TripleDifference` estimator to handle staggered adoption settings. The current implementation handles 2-period DDD; this extends to multi-period designs. diff --git a/diff_diff/trop.py b/diff_diff/trop.py index 0b45a375..dc71a692 100644 --- a/diff_diff/trop.py +++ b/diff_diff/trop.py @@ -43,6 +43,11 @@ from diff_diff.utils import compute_confidence_interval, compute_p_value +# Sentinel value for "disabled" mode in LOOCV parameter search +# Following paper's footnote 2: λ=∞ disables the corresponding component +_LAMBDA_INF: float = float('inf') + + class _PrecomputedStructures(TypedDict): """Type definition for pre-computed structures used across LOOCV iterations. @@ -109,11 +114,15 @@ class TROPResults: treatment_effects : dict Individual treatment effects for each treated (unit, time) pair. lambda_time : float - Selected time weight decay parameter. + Selected time weight decay parameter from grid. Note: infinity values + are converted internally (∞ → 0.0 for uniform weights) for computation. lambda_unit : float - Selected unit weight decay parameter. + Selected unit weight decay parameter from grid. Note: infinity values + are converted internally (∞ → 0.0 for uniform weights) for computation. lambda_nn : float - Selected nuclear norm regularization parameter. + Selected nuclear norm regularization parameter from grid. Note: infinity + values are converted internally (∞ → 1e10, factor model disabled) for + computation. factor_matrix : np.ndarray Estimated low-rank factor matrix L (n_periods x n_units). effective_rank : float @@ -124,10 +133,10 @@ class TROPResults: Method used for variance estimation. alpha : float Significance level for confidence interval. - pre_periods : list - List of pre-treatment period identifiers. - post_periods : list - List of post-treatment period identifiers. + n_pre_periods : int + Number of pre-treatment periods. + n_post_periods : int + Number of post-treatment periods (periods with D=1 observations). n_bootstrap : int, optional Number of bootstrap replications (if bootstrap variance). bootstrap_distribution : np.ndarray, optional @@ -154,8 +163,8 @@ class TROPResults: loocv_score: float variance_method: str alpha: float = 0.05 - pre_periods: List[Any] = field(default_factory=list) - post_periods: List[Any] = field(default_factory=list) + n_pre_periods: int = 0 + n_post_periods: int = 0 n_bootstrap: Optional[int] = field(default=None) bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False) @@ -197,8 +206,8 @@ def summary(self, alpha: Optional[float] = None) -> str: f"{'Treated units:':<25} {self.n_treated:>10}", f"{'Control units:':<25} {self.n_control:>10}", f"{'Treated observations:':<25} {self.n_treated_obs:>10}", - f"{'Pre-treatment periods:':<25} {len(self.pre_periods):>10}", - f"{'Post-treatment periods:':<25} {len(self.post_periods):>10}", + f"{'Pre-treatment periods:':<25} {self.n_pre_periods:>10}", + f"{'Post-treatment periods:':<25} {self.n_post_periods:>10}", "", "-" * 75, "Tuning Parameters (selected via LOOCV)".center(75), @@ -261,8 +270,8 @@ def to_dict(self) -> Dict[str, Any]: "n_treated": self.n_treated, "n_control": self.n_control, "n_treated_obs": self.n_treated_obs, - "n_pre_periods": len(self.pre_periods), - "n_post_periods": len(self.post_periods), + "n_pre_periods": self.n_pre_periods, + "n_post_periods": self.n_post_periods, "lambda_time": self.lambda_time, "lambda_unit": self.lambda_unit, "lambda_nn": self.lambda_nn, @@ -397,7 +406,6 @@ class TROP: ... treatment='treated', ... unit='unit', ... time='period', - ... post_periods=[5, 6, 7, 8] ... ) >>> results.print_summary() @@ -658,6 +666,168 @@ def _compute_unit_distance_for_obs( else: return np.inf + def _univariate_loocv_search( + self, + Y: np.ndarray, + D: np.ndarray, + control_mask: np.ndarray, + control_unit_idx: np.ndarray, + n_units: int, + n_periods: int, + param_name: str, + grid: List[float], + fixed_params: Dict[str, float], + ) -> Tuple[float, float]: + """ + Search over one parameter with others fixed. + + Following paper's footnote 2, this performs a univariate grid search + for one tuning parameter while holding others fixed. The fixed_params + can include _LAMBDA_INF values to disable specific components: + - lambda_nn = inf: Skip nuclear norm regularization (L=0) + - lambda_time = inf: Uniform time weights (treated as 0) + - lambda_unit = inf: Uniform unit weights (treated as 0) + + Parameters + ---------- + Y : np.ndarray + Outcome matrix (n_periods x n_units). + D : np.ndarray + Treatment indicator matrix (n_periods x n_units). + control_mask : np.ndarray + Boolean mask for control observations. + control_unit_idx : np.ndarray + Indices of control units. + n_units : int + Number of units. + n_periods : int + Number of periods. + param_name : str + Name of parameter to search: 'lambda_time', 'lambda_unit', or 'lambda_nn'. + grid : List[float] + Grid of values to search over. + fixed_params : Dict[str, float] + Fixed values for other parameters. May include _LAMBDA_INF. + + Returns + ------- + Tuple[float, float] + (best_value, best_score) for the searched parameter. + """ + best_score = np.inf + best_value = grid[0] if grid else 0.0 + + for value in grid: + params = {**fixed_params, param_name: value} + + # Convert inf values to 0 for computation (inf means "disabled" = uniform weights) + lambda_time = params.get('lambda_time', 0.0) + lambda_unit = params.get('lambda_unit', 0.0) + lambda_nn = params.get('lambda_nn', 0.0) + + # Handle infinity as "disabled" mode + # Per paper Equations 2-3: + # - λ_time/λ_unit=∞ → exp(-∞×dist)→0 for dist>0, uniform weights → use 0.0 + # - λ_nn=∞ → infinite penalty → L≈0 (factor model disabled) → use 1e10 + # Note: λ_nn=0 means NO regularization (full-rank L), opposite of "disabled" + if np.isinf(lambda_time): + lambda_time = 0.0 # Uniform time weights + if np.isinf(lambda_unit): + lambda_unit = 0.0 # Uniform unit weights + if np.isinf(lambda_nn): + lambda_nn = 1e10 # Very large → L≈0 (factor model disabled) + + try: + score = self._loocv_score_obs_specific( + Y, D, control_mask, control_unit_idx, + lambda_time, lambda_unit, lambda_nn, + n_units, n_periods + ) + if score < best_score: + best_score = score + best_value = value + except (np.linalg.LinAlgError, ValueError): + continue + + return best_value, best_score + + def _cycling_parameter_search( + self, + Y: np.ndarray, + D: np.ndarray, + control_mask: np.ndarray, + control_unit_idx: np.ndarray, + n_units: int, + n_periods: int, + initial_lambda: Tuple[float, float, float], + max_cycles: int = 10, + ) -> Tuple[float, float, float]: + """ + Cycle through parameters until convergence (coordinate descent). + + Following paper's footnote 2 (Stage 2), this iteratively optimizes + each tuning parameter while holding the others fixed, until convergence. + + Parameters + ---------- + Y : np.ndarray + Outcome matrix (n_periods x n_units). + D : np.ndarray + Treatment indicator matrix (n_periods x n_units). + control_mask : np.ndarray + Boolean mask for control observations. + control_unit_idx : np.ndarray + Indices of control units. + n_units : int + Number of units. + n_periods : int + Number of periods. + initial_lambda : Tuple[float, float, float] + Initial values (lambda_time, lambda_unit, lambda_nn). + max_cycles : int, default=10 + Maximum number of coordinate descent cycles. + + Returns + ------- + Tuple[float, float, float] + Optimized (lambda_time, lambda_unit, lambda_nn). + """ + lambda_time, lambda_unit, lambda_nn = initial_lambda + prev_score = np.inf + + for cycle in range(max_cycles): + # Optimize λ_unit (fix λ_time, λ_nn) + lambda_unit, _ = self._univariate_loocv_search( + Y, D, control_mask, control_unit_idx, n_units, n_periods, + 'lambda_unit', self.lambda_unit_grid, + {'lambda_time': lambda_time, 'lambda_nn': lambda_nn} + ) + + # Optimize λ_time (fix λ_unit, λ_nn) + lambda_time, _ = self._univariate_loocv_search( + Y, D, control_mask, control_unit_idx, n_units, n_periods, + 'lambda_time', self.lambda_time_grid, + {'lambda_unit': lambda_unit, 'lambda_nn': lambda_nn} + ) + + # Optimize λ_nn (fix λ_unit, λ_time) + lambda_nn, score = self._univariate_loocv_search( + Y, D, control_mask, control_unit_idx, n_units, n_periods, + 'lambda_nn', self.lambda_nn_grid, + {'lambda_unit': lambda_unit, 'lambda_time': lambda_time} + ) + + # Check convergence + if abs(score - prev_score) < 1e-6: + logger.debug( + "Cycling search converged after %d cycles with score %.6f", + cycle + 1, score + ) + break + prev_score = score + + return lambda_time, lambda_unit, lambda_nn + def fit( self, data: pd.DataFrame, @@ -665,7 +835,6 @@ def fit( treatment: str, unit: str, time: str, - post_periods: Optional[List[Any]] = None, ) -> TROPResults: """ Fit the TROP model. @@ -679,20 +848,31 @@ def fit( Name of the outcome variable column. treatment : str Name of the treatment indicator column (0/1). - Should be 1 for treated unit-time observations. + + IMPORTANT: This should be an ABSORBING STATE indicator, not a + treatment timing indicator. For each unit, D=1 for ALL periods + during and after treatment: + + - D[t, i] = 0 for all t < g_i (pre-treatment periods) + - D[t, i] = 1 for all t >= g_i (treatment and post-treatment) + + where g_i is the treatment start time for unit i. + + For staggered adoption, different units can have different g_i. + The ATT averages over ALL D=1 cells per Equation 1 of the paper. unit : str Name of the unit identifier column. time : str Name of the time period column. - post_periods : list, optional - List of time period values that are post-treatment. - If None, infers from treatment indicator. Returns ------- TROPResults Object containing the ATT estimate, standard error, - factor estimates, and tuning parameters. + factor estimates, and tuning parameters. The lambda_* + attributes show the selected grid values. Infinity values + (∞) are converted internally: λ_time/λ_unit=∞ → 0.0 (uniform + weights), λ_nn=∞ → 1e10 (factor model disabled). """ # Validate inputs required_cols = [outcome, treatment, unit, time] @@ -720,13 +900,39 @@ def fit( .reindex(index=all_periods, columns=all_units) .values ) - D = ( + + # For D matrix, track missing values BEFORE fillna to support unbalanced panels + # Issue 3 fix: Missing observations should not trigger spurious violations + D_raw = ( data.pivot(index=time, columns=unit, values=treatment) .reindex(index=all_periods, columns=all_units) - .fillna(0) - .astype(int) - .values ) + missing_mask = pd.isna(D_raw).values # True where originally missing + D = D_raw.fillna(0).astype(int).values + + # Validate D is monotonic non-decreasing per unit (absorbing state) + # D[t, i] must satisfy: once D=1, it must stay 1 for all subsequent periods + # Issue 3 fix (round 10): Check each unit's OBSERVED D sequence for monotonicity + # This catches 1→0 violations that span missing period gaps + # Example: D[2]=1, missing [3,4], D[5]=0 is a real violation even though + # adjacent period transitions don't show it (the gap hides the transition) + violating_units = [] + for unit_idx in range(n_units): + # Get observed D values for this unit (where not missing) + observed_mask = ~missing_mask[:, unit_idx] + observed_d = D[observed_mask, unit_idx] + + # Check if observed sequence is monotonically non-decreasing + if len(observed_d) > 1 and np.any(np.diff(observed_d) < 0): + violating_units.append(all_units[unit_idx]) + + if violating_units: + raise ValueError( + f"Treatment indicator is not an absorbing state for units: {violating_units}. " + f"D[t, unit] must be monotonic non-decreasing (once treated, always treated). " + f"If this is event-study style data, convert to absorbing state: " + f"D[t, i] = 1 for all t >= first treatment period." + ) # Identify treated observations treated_mask = D == 1 @@ -743,28 +949,23 @@ def fit( if len(control_unit_idx) == 0: raise ValueError("No control units found") - # Determine pre/post periods - if post_periods is None: - # Infer from first treatment time - first_treat_period = None - for t in range(n_periods): - if np.any(D[t, :] == 1): - first_treat_period = t - break - if first_treat_period is None: - raise ValueError("Could not infer post-treatment periods") - pre_period_idx = list(range(first_treat_period)) - post_period_idx = list(range(first_treat_period, n_periods)) - else: - post_period_idx = [period_to_idx[p] for p in post_periods if p in period_to_idx] - pre_period_idx = [i for i in range(n_periods) if i not in post_period_idx] + # Determine pre/post periods from treatment indicator D + # D matrix is the sole input for treatment timing per the paper + first_treat_period = None + for t in range(n_periods): + if np.any(D[t, :] == 1): + first_treat_period = t + break + if first_treat_period is None: + raise ValueError("Could not infer post-treatment periods from D matrix") - if len(pre_period_idx) < 2: - raise ValueError("Need at least 2 pre-treatment periods") + n_pre_periods = first_treat_period + # Count periods where D=1 is actually observed (matches docstring) + # Per docstring: "Number of post-treatment periods (periods with D=1 observations)" + n_post_periods = int(np.sum(np.any(D[first_treat_period:, :] == 1, axis=1))) - pre_periods_list = [idx_to_period[i] for i in pre_period_idx] - post_periods_list = [idx_to_period[i] for i in post_period_idx] - n_treated_periods = len(post_period_idx) + if n_pre_periods < 2: + raise ValueError("Need at least 2 pre-treatment periods") # Step 1: Grid search with LOOCV for tuning parameters best_lambda = None @@ -789,14 +990,45 @@ def fit( lambda_unit_arr = np.array(self.lambda_unit_grid, dtype=np.float64) lambda_nn_arr = np.array(self.lambda_nn_grid, dtype=np.float64) - best_lt, best_lu, best_ln, best_score = _rust_loocv_grid_search( + result = _rust_loocv_grid_search( Y, D.astype(np.float64), control_mask_u8, time_dist_matrix, lambda_time_arr, lambda_unit_arr, lambda_nn_arr, self.max_loocv_samples, self.max_iter, self.tol, self.seed if self.seed is not None else 0 ) - best_lambda = (best_lt, best_lu, best_ln) + # Unpack result - 7 values including optional first_failed_obs + best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed_obs = result + # Only accept finite scores - infinite means all fits failed + if np.isfinite(best_score): + best_lambda = (best_lt, best_lu, best_ln) + # else: best_lambda stays None, triggering defaults fallback + # Emit warnings consistent with Python implementation + if n_valid == 0: + # Include failed observation coordinates if available (Issue 2 fix) + obs_info = "" + if first_failed_obs is not None: + t_idx, i_idx = first_failed_obs + obs_info = f" First failure at observation ({t_idx}, {i_idx})." + warnings.warn( + f"LOOCV: All {n_attempted} fits failed for " + f"λ=({best_lt}, {best_lu}, {best_ln}). " + f"Returning infinite score.{obs_info}", + UserWarning + ) + elif n_attempted > 0 and (n_attempted - n_valid) > 0.1 * n_attempted: + n_failed = n_attempted - n_valid + # Include failed observation coordinates if available + obs_info = "" + if first_failed_obs is not None: + t_idx, i_idx = first_failed_obs + obs_info = f" First failure at observation ({t_idx}, {i_idx})." + warnings.warn( + f"LOOCV: {n_failed}/{n_attempted} fits failed for " + f"λ=({best_lt}, {best_lu}, {best_ln}). " + f"This may indicate numerical instability.{obs_info}", + UserWarning + ) except Exception as e: # Fall back to Python implementation on error logger.debug( @@ -806,21 +1038,54 @@ def fit( best_score = np.inf # Fall back to Python implementation if Rust unavailable or failed + # Uses two-stage approach per paper's footnote 2: + # Stage 1: Univariate searches for initial values + # Stage 2: Cycling (coordinate descent) until convergence if best_lambda is None: - for lambda_time in self.lambda_time_grid: - for lambda_unit in self.lambda_unit_grid: - for lambda_nn in self.lambda_nn_grid: - try: - score = self._loocv_score_obs_specific( - Y, D, control_mask, control_unit_idx, - lambda_time, lambda_unit, lambda_nn, - n_units, n_periods - ) - if score < best_score: - best_score = score - best_lambda = (lambda_time, lambda_unit, lambda_nn) - except (np.linalg.LinAlgError, ValueError): - continue + # Stage 1: Univariate searches with extreme fixed values + # Following paper's footnote 2 for initial bounds + + # λ_time search: fix λ_unit=0, λ_nn=∞ (disabled - no factor adjustment) + lambda_time_init, _ = self._univariate_loocv_search( + Y, D, control_mask, control_unit_idx, n_units, n_periods, + 'lambda_time', self.lambda_time_grid, + {'lambda_unit': 0.0, 'lambda_nn': _LAMBDA_INF} + ) + + # λ_nn search: fix λ_time=∞ (uniform time weights), λ_unit=0 + lambda_nn_init, _ = self._univariate_loocv_search( + Y, D, control_mask, control_unit_idx, n_units, n_periods, + 'lambda_nn', self.lambda_nn_grid, + {'lambda_time': _LAMBDA_INF, 'lambda_unit': 0.0} + ) + + # λ_unit search: fix λ_nn=∞, λ_time=0 + lambda_unit_init, _ = self._univariate_loocv_search( + Y, D, control_mask, control_unit_idx, n_units, n_periods, + 'lambda_unit', self.lambda_unit_grid, + {'lambda_nn': _LAMBDA_INF, 'lambda_time': 0.0} + ) + + # Stage 2: Cycling refinement (coordinate descent) + lambda_time, lambda_unit, lambda_nn = self._cycling_parameter_search( + Y, D, control_mask, control_unit_idx, n_units, n_periods, + (lambda_time_init, lambda_unit_init, lambda_nn_init) + ) + + # Compute final score for the optimized parameters + try: + best_score = self._loocv_score_obs_specific( + Y, D, control_mask, control_unit_idx, + lambda_time, lambda_unit, lambda_nn, + n_units, n_periods + ) + # Only accept finite scores - infinite means all fits failed + if np.isfinite(best_score): + best_lambda = (lambda_time, lambda_unit, lambda_nn) + # else: best_lambda stays None, triggering defaults fallback + except (np.linalg.LinAlgError, ValueError): + # If even the optimized parameters fail, best_lambda stays None + pass if best_lambda is None: warnings.warn( @@ -833,6 +1098,26 @@ def fit( self._optimal_lambda = best_lambda lambda_time, lambda_unit, lambda_nn = best_lambda + # Convert infinity values for final estimation (matching LOOCV conversion) + # This ensures final estimation uses the same effective parameters that LOOCV evaluated. + # See REGISTRY.md "λ=∞ implementation" for rationale. + # + # IMPORTANT: Store original grid values for results, use converted for computation. + # This lets users see what was selected from their grid, while ensuring consistent + # behavior between point estimation and variance estimation. + original_lambda_time, original_lambda_unit, original_lambda_nn = best_lambda + + if np.isinf(lambda_time): + lambda_time = 0.0 # Uniform time weights + if np.isinf(lambda_unit): + lambda_unit = 0.0 # Uniform unit weights + if np.isinf(lambda_nn): + lambda_nn = 1e10 # Very large → L≈0 (factor model disabled) + + # Create effective_lambda with converted values for ALL downstream computation + # This ensures variance estimation uses the same parameters as point estimation + effective_lambda = (lambda_time, lambda_unit, lambda_nn) + # Step 2: Final estimation - per-observation model fitting following Algorithm 2 # For each treated (i,t): compute observation-specific weights, fit model, compute τ̂_{it} treatment_effects = {} @@ -886,14 +1171,16 @@ def fit( effective_rank = 0.0 # Step 4: Variance estimation + # Use effective_lambda (converted values) to ensure SE is computed with same + # parameters as point estimation. This fixes the variance inconsistency issue. if self.variance_method == "bootstrap": se, bootstrap_dist = self._bootstrap_variance( - data, outcome, treatment, unit, time, post_periods_list, - best_lambda, Y=Y, D=D, control_unit_idx=control_unit_idx + data, outcome, treatment, unit, time, + effective_lambda, Y=Y, D=D, control_unit_idx=control_unit_idx ) else: se, bootstrap_dist = self._jackknife_variance( - Y, D, control_mask, control_unit_idx, best_lambda, + Y, D, control_mask, control_unit_idx, effective_lambda, n_units, n_periods ) @@ -901,11 +1188,12 @@ def fit( if se > 0: t_stat = att / se p_value = 2 * (1 - stats.t.cdf(abs(t_stat), df=max(1, n_treated_obs - 1))) + conf_int = compute_confidence_interval(att, se, self.alpha) else: - t_stat = 0.0 - p_value = 1.0 - - conf_int = compute_confidence_interval(att, se, self.alpha) + # When SE is undefined/zero, ALL inference fields should be NaN + t_stat = np.nan + p_value = np.nan + conf_int = (np.nan, np.nan) # Create results dictionaries unit_effects_dict = {idx_to_unit[i]: alpha_hat[i] for i in range(n_units)} @@ -925,16 +1213,18 @@ def fit( unit_effects=unit_effects_dict, time_effects=time_effects_dict, treatment_effects=treatment_effects, - lambda_time=lambda_time, - lambda_unit=lambda_unit, - lambda_nn=lambda_nn, + # Store ORIGINAL grid values (possibly inf) so users see what was selected. + # Internally, infinity values are converted for computation (see effective_lambda). + lambda_time=original_lambda_time, + lambda_unit=original_lambda_unit, + lambda_nn=original_lambda_nn, factor_matrix=L_hat, effective_rank=effective_rank, loocv_score=best_score, variance_method=self.variance_method, alpha=self.alpha, - pre_periods=pre_periods_list, - post_periods=post_periods_list, + n_pre_periods=n_pre_periods, + n_post_periods=n_post_periods, n_bootstrap=self.n_bootstrap if self.variance_method == "bootstrap" else None, bootstrap_distribution=bootstrap_dist if len(bootstrap_dist) > 0 else None, ) @@ -1409,6 +1699,17 @@ def _loocv_score_obs_specific( indices = rng.choice(len(control_obs), size=max_loocv, replace=False) control_obs = [control_obs[idx] for idx in indices] + # Empty control set check: if no control observations, return infinity + # A score of 0.0 would incorrectly "win" over legitimate parameters + if len(control_obs) == 0: + warnings.warn( + f"LOOCV: No valid control observations for " + f"λ=({lambda_time}, {lambda_unit}, {lambda_nn}). " + "Returning infinite score.", + UserWarning + ) + return np.inf + tau_squared_sum = 0.0 n_valid = 0 @@ -1433,12 +1734,19 @@ def _loocv_score_obs_specific( n_valid += 1 except (np.linalg.LinAlgError, ValueError): - continue - - if n_valid == 0: - return np.inf + # Per Equation 5: Q(λ) must sum over ALL D==0 cells + # Any failure means this λ cannot produce valid estimates for all cells + warnings.warn( + f"LOOCV: Fit failed for observation ({t}, {i}) with " + f"λ=({lambda_time}, {lambda_unit}, {lambda_nn}). " + "Returning infinite score per Equation 5.", + UserWarning + ) + return np.inf - return tau_squared_sum / n_valid + # Return SUM of squared pseudo-treatment effects per Equation 5 (page 8): + # Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² + return tau_squared_sum def _bootstrap_variance( self, @@ -1447,7 +1755,6 @@ def _bootstrap_variance( treatment: str, unit: str, time: str, - post_periods: List[Any], optimal_lambda: Tuple[float, float, float], Y: Optional[np.ndarray] = None, D: Optional[np.ndarray] = None, @@ -1473,8 +1780,6 @@ def _bootstrap_variance( Name of the unit identifier column in data. time : str Name of the time period column in data. - post_periods : list - List of post-treatment time periods. optimal_lambda : tuple of float Optimal tuning parameters (lambda_time, lambda_unit, lambda_nn) from cross-validation. Used for model estimation in each bootstrap. @@ -1579,7 +1884,7 @@ def _bootstrap_variance( # Fit with fixed lambda (skip LOOCV for speed) att = self._fit_with_fixed_lambda( boot_data, outcome, treatment, unit, time, - post_periods, optimal_lambda + optimal_lambda ) bootstrap_estimates_list.append(att) except (ValueError, np.linalg.LinAlgError, KeyError): @@ -1703,7 +2008,6 @@ def _fit_with_fixed_lambda( treatment: str, unit: str, time: str, - post_periods: List[Any], fixed_lambda: Tuple[float, float, float], ) -> float: """ @@ -1803,7 +2107,6 @@ def trop( treatment: str, unit: str, time: str, - post_periods: Optional[List[Any]] = None, **kwargs, ) -> TROPResults: """ @@ -1816,13 +2119,16 @@ def trop( outcome : str Outcome variable column name. treatment : str - Treatment indicator column name. + Treatment indicator column name (0/1). + + IMPORTANT: This should be an ABSORBING STATE indicator, not a treatment + timing indicator. For each unit, D=1 for ALL periods during and after + treatment (D[t,i]=0 for t < g_i, D[t,i]=1 for t >= g_i where g_i is + the treatment start time for unit i). unit : str Unit identifier column name. time : str Time period column name. - post_periods : list, optional - Post-treatment periods. **kwargs Additional arguments passed to TROP constructor. @@ -1834,8 +2140,8 @@ def trop( Examples -------- >>> from diff_diff import trop - >>> results = trop(data, 'y', 'treated', 'unit', 'time', post_periods=[5,6,7]) + >>> results = trop(data, 'y', 'treated', 'unit', 'time') >>> print(f"ATT: {results.att:.3f}") """ estimator = TROP(**kwargs) - return estimator.fit(data, outcome, treatment, unit, time, post_periods) + return estimator.fit(data, outcome, treatment, unit, time) diff --git a/docs/api/trop.rst b/docs/api/trop.rst index 3712a4e0..732dbcfa 100644 --- a/docs/api/trop.rst +++ b/docs/api/trop.rst @@ -130,13 +130,15 @@ Basic usage:: seed=42 ) + # Note: TROP infers treatment periods from the treatment indicator column. + # The treatment column should be an absorbing state (D=1 for all periods + # during and after treatment starts). results = trop.fit( data, outcome='y', treatment='treated', unit='unit_id', - time='period', - post_periods=[10, 11, 12, 13, 14] + time='period' ) results.print_summary() @@ -150,7 +152,6 @@ Quick estimation with convenience function:: treatment='treated', unit='unit_id', time='period', - post_periods=[10, 11, 12, 13, 14], n_bootstrap=200 ) diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 015ba986..6f779c7b 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -449,9 +449,33 @@ Doubly robust estimator: **Key implementation requirements:** *Assumption checks / warnings:* -- Requires sufficient pre-treatment periods for factor estimation +- Requires sufficient pre-treatment periods for factor estimation (at least 2) - Warns if estimated rank seems too high/low relative to panel dimensions - Unit weights can become degenerate if λ_unit too large +- Returns Q(λ) = ∞ if ANY LOOCV fit fails (Equation 5 compliance) + +*Treatment indicator (D matrix) semantics:* + +D must be an **ABSORBING STATE** indicator, not a treatment timing indicator: +- D[t, i] = 0 for all t < g_i (pre-treatment periods for unit i) +- D[t, i] = 1 for all t >= g_i (during and after treatment for unit i) + +where g_i is the treatment start time for unit i. + +For staggered adoption, different units have different treatment start times g_i. +The D matrix naturally handles this - distances use periods where BOTH units +have D=0, matching the paper's (1 - W_iu)(1 - W_ju) formula in Equation 3. + +**Wrong D specification**: If user provides event-style D (only first treatment period +has D=1), ATT will be incorrect - document this clearly. + +*ATT definition (Equation 1, Section 6.1):* +``` +τ̂ = (1 / Σ_i Σ_t W_{it}) Σ_{i=1}^N Σ_{t=1}^T W_{it} τ̂_{it}(λ̂) +``` +- ATT averages over ALL cells where D_it=1 (treatment indicator) +- No separate "post_periods" concept - D matrix is the sole input for treatment timing +- Supports general assignment patterns including staggered adoption *Estimator equation (as implemented):* @@ -478,6 +502,30 @@ where d(j, treated) is RMSE distance to treated units in pre-period. Time weights: analogous construction for periods. +*LOOCV tuning parameter selection (Equation 5, Footnote 2):* +``` +Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² +``` +- Score is **SUM** of squared pseudo-treatment effects on control observations +- **Two-stage procedure** (per paper's footnote 2): + - Stage 1: Univariate grid searches with extreme fixed values + - λ_time search: fix λ_unit=0, λ_nn=∞ (disabled) + - λ_nn search: fix λ_time=∞ (uniform time weights), λ_unit=0 + - λ_unit search: fix λ_nn=∞, λ_time=0 + - Stage 2: Cycling (coordinate descent) until convergence +- **"Disabled" parameter semantics** (per paper Equations 2-3): + - `λ_time=∞` or `λ_unit=∞`: Converts to `0.0` internally → exp(-0×dist)=1 → uniform weights + - `λ_nn=∞`: Converts to `1e10` internally → very large penalty → L≈0 (factor model off, recovers DID/TWFE) + - **Note**: `λ_nn=0` means NO regularization (full-rank L), which is the OPPOSITE of "disabled" +- **Subsampling**: max_loocv_samples (default 100) for computational tractability + - This subsamples control observations, NOT parameter combinations + - Increases precision at cost of computation; increase for more precise tuning +- **LOOCV failure handling** (Equation 5 compliance): + - If ANY LOOCV fit fails for a parameter combination, Q(λ) = ∞ + - A warning is emitted on the first failure with the observation (t, i) and λ values + - Subsequent failures for the same λ are not individually warned (early return) + - This ensures λ selection only considers fully estimable combinations + *Standard errors:* - Default: Block bootstrap preserving panel structure - Alternative: Jackknife (leave-one-unit-out) @@ -486,16 +534,42 @@ Time weights: analogous construction for periods. - Rank selection: automatic via cross-validation, information criterion, or elbow - Zero singular values: handled by soft-thresholding - Extreme distances: weights regularized to prevent degeneracy +- LOOCV fit failures: returns Q(λ) = ∞ on first failure (per Equation 5 requirement that Q sums over ALL D==0 cells); if all parameter combinations fail, falls back to defaults (1.0, 1.0, 0.1) +- **λ=∞ implementation**: Infinity values are converted in both LOOCV search and final estimation: + - λ_time=∞ or λ_unit=∞ → 0.0 (uniform weights via exp(-0×d)=1) + - λ_nn=∞ → 1e10 (large penalty → L≈0, factor model disabled) + - Conversion applied to grid values during LOOCV (including Rust backend) + - Conversion applied to selected values for point estimation + - Conversion applied to selected values for variance estimation (ensures SE matches ATT) + - **Results storage**: `TROPResults` stores *original* grid values (e.g., inf), while computations use converted values. This lets users see what was selected from their grid. +- **Empty control observations**: If LOOCV control observations become empty (edge case during subsampling), returns Q(λ) = ∞ with warning. A score of 0.0 would incorrectly "win" over legitimate parameters. +- **Infinite LOOCV score handling**: If best LOOCV score is infinite, `best_lambda` is set to None, triggering defaults fallback +- Validation: requires at least 2 periods before first treatment +- **D matrix validation**: Treatment indicator must be an absorbing state (monotonic non-decreasing per unit) + - Detection: `np.diff(D, axis=0) < 0` for any column indicates violation + - Handling: Raises `ValueError` with list of violating unit IDs and remediation guidance + - Error message includes: "convert to absorbing state: D[t, i] = 1 for all t >= first treatment period" + - **Rationale**: Event-style D (0→1→0) silently biases ATT; runtime validation prevents misuse + - **Unbalanced panels**: Missing unit-period observations are allowed. Monotonicity validation checks each unit's *observed* D sequence for monotonicity, which correctly catches 1→0 violations that span missing period gaps (e.g., D[2]=1, missing [3,4], D[5]=0 is detected as a violation even though the gap hides the transition in adjacent-period checks). + - **n_post_periods metadata**: Counts periods where D=1 is actually observed (at least one unit has D=1), not calendar periods from first treatment. In unbalanced panels where treated units are missing in some post-treatment periods, only periods with observed D=1 values are counted. +- Wrong D specification: if user provides event-style D (only first treatment period), + the absorbing-state validation will raise ValueError with helpful guidance +- **LOOCV failure metadata**: When LOOCV fits fail in the Rust backend, the first failed observation coordinates (t, i) are returned to Python for informative warning messages **Reference implementation(s):** - Authors' replication code (forthcoming) **Requirements checklist:** -- [ ] Factor matrix estimated via soft-threshold SVD -- [ ] Unit weights: `exp(-λ_unit × distance)` with normalization -- [ ] LOOCV implemented for tuning parameter selection -- [ ] Multiple rank selection methods: cv, ic, elbow -- [ ] Returns factor loadings and scores for interpretation +- [x] Factor matrix estimated via soft-threshold SVD +- [x] Unit weights: `exp(-λ_unit × distance)` with normalization +- [x] LOOCV implemented for tuning parameter selection +- [x] LOOCV uses SUM of squared errors per Equation 5 +- [x] Multiple rank selection methods: cv, ic, elbow +- [x] Returns factor loadings and scores for interpretation +- [x] ATT averages over all D==1 cells (general assignment patterns) +- [x] No post_periods parameter (D matrix determines treatment timing) +- [x] D matrix semantics documented (absorbing state, not event indicator) +- [x] Unbalanced panels supported (missing observations don't trigger false violations) --- diff --git a/docs/tutorials/10_trop.ipynb b/docs/tutorials/10_trop.ipynb index 056f1658..a2aa020a 100644 --- a/docs/tutorials/10_trop.ipynb +++ b/docs/tutorials/10_trop.ipynb @@ -161,29 +161,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "# Fit TROP with automatic tuning via LOOCV\n", - "trop_est = TROP(\n", - " lambda_time_grid=[0.0, 1.0], # Reduced time decay grid\n", - " lambda_unit_grid=[0.0, 1.0], # Reduced unit distance grid \n", - " lambda_nn_grid=[0.0, 0.1], # Reduced nuclear norm grid\n", - " n_bootstrap=50, # Reduced bootstrap replications for SE\n", - " seed=42\n", - ")\n", - "\n", - "post_periods = list(range(n_pre, n_pre + n_post))\n", - "\n", - "results = trop_est.fit(\n", - " df,\n", - " outcome='outcome',\n", - " treatment='treated',\n", - " unit='unit',\n", - " time='period',\n", - " post_periods=post_periods\n", - ")\n", - "\n", - "print(results.summary())" - ] + "source": "# Fit TROP with automatic tuning via LOOCV\ntrop_est = TROP(\n lambda_time_grid=[0.0, 1.0], # Reduced time decay grid\n lambda_unit_grid=[0.0, 1.0], # Reduced unit distance grid \n lambda_nn_grid=[0.0, 0.1], # Reduced nuclear norm grid\n n_bootstrap=50, # Reduced bootstrap replications for SE\n seed=42\n)\n\n# Note: TROP infers treatment periods from the treatment indicator column.\n# The 'treated' column should be an absorbing state (D=1 for all periods\n# during and after treatment starts).\n\n# For SDID comparison later, we keep post_periods for SDID\npost_periods = list(range(n_pre, n_pre + n_post))\n\nresults = trop_est.fit(\n df,\n outcome='outcome',\n treatment='treated',\n unit='unit',\n time='period'\n)\n\nprint(results.summary())" }, { "cell_type": "code", @@ -239,34 +217,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "# Effect of different nuclear norm regularization levels\n", - "print(\"Effect of nuclear norm regularization (λ_nn):\")\n", - "print(\"=\"*65)\n", - "print(f\"{'λ_nn':>10} {'ATT':>12} {'Bias':>12} {'Eff. Rank':>15}\")\n", - "print(\"-\"*65)\n", - "\n", - "for lambda_nn in [0.0, 0.1, 1.0]: # Reduced grid\n", - " trop_fixed = TROP(\n", - " lambda_time_grid=[1.0], # Fixed\n", - " lambda_unit_grid=[1.0], # Fixed\n", - " lambda_nn_grid=[lambda_nn], # Vary this\n", - " n_bootstrap=20, # Reduced for faster execution\n", - " seed=42\n", - " )\n", - " \n", - " res = trop_fixed.fit(\n", - " df,\n", - " outcome='outcome',\n", - " treatment='treated',\n", - " unit='unit',\n", - " time='period',\n", - " post_periods=post_periods\n", - " )\n", - " \n", - " bias = res.att - true_att\n", - " print(f\"{lambda_nn:>10.1f} {res.att:>12.4f} {bias:>12.4f} {res.effective_rank:>15.2f}\")" - ] + "source": "# Effect of different nuclear norm regularization levels\nprint(\"Effect of nuclear norm regularization (λ_nn):\")\nprint(\"=\"*65)\nprint(f\"{'λ_nn':>10} {'ATT':>12} {'Bias':>12} {'Eff. Rank':>15}\")\nprint(\"-\"*65)\n\nfor lambda_nn in [0.0, 0.1, 1.0]: # Reduced grid\n trop_fixed = TROP(\n lambda_time_grid=[1.0], # Fixed\n lambda_unit_grid=[1.0], # Fixed\n lambda_nn_grid=[lambda_nn], # Vary this\n n_bootstrap=20, # Reduced for faster execution\n seed=42\n )\n \n res = trop_fixed.fit(\n df,\n outcome='outcome',\n treatment='treated',\n unit='unit',\n time='period'\n )\n \n bias = res.att - true_att\n print(f\"{lambda_nn:>10.1f} {res.att:>12.4f} {bias:>12.4f} {res.effective_rank:>15.2f}\")" }, { "cell_type": "markdown", @@ -402,55 +353,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "# SDID (no factor adjustment)\n", - "# Note: SDID uses 'treat' (unit-level ever-treated indicator)\n", - "sdid = SyntheticDiD(\n", - " n_bootstrap=50, # Reduced for faster execution\n", - " seed=42\n", - ")\n", - "\n", - "sdid_results = sdid.fit(\n", - " df,\n", - " outcome='outcome',\n", - " treatment='treat', # Unit-level ever-treated indicator\n", - " unit='unit',\n", - " time='period',\n", - " post_periods=post_periods\n", - ")\n", - "\n", - "# TROP (with factor adjustment)\n", - "# Note: TROP uses 'treated' (observation-level treatment indicator)\n", - "trop_est2 = TROP(\n", - " lambda_nn_grid=[0.0, 0.1], # Reduced grid for faster execution\n", - " n_bootstrap=50, # Reduced for faster execution\n", - " seed=42\n", - ")\n", - "\n", - "trop_results = trop_est2.fit(\n", - " df,\n", - " outcome='outcome',\n", - " treatment='treated', # Observation-level indicator\n", - " unit='unit',\n", - " time='period',\n", - " post_periods=post_periods\n", - ")\n", - "\n", - "print(\"Comparison: SDID vs TROP\")\n", - "print(\"=\"*60)\n", - "print(f\"True ATT: {true_att:.4f}\")\n", - "print()\n", - "print(f\"Synthetic DiD (no factor adjustment):\")\n", - "print(f\" ATT: {sdid_results.att:.4f}\")\n", - "print(f\" SE: {sdid_results.se:.4f}\")\n", - "print(f\" Bias: {sdid_results.att - true_att:.4f}\")\n", - "print()\n", - "print(f\"TROP (with factor adjustment):\")\n", - "print(f\" ATT: {trop_results.att:.4f}\")\n", - "print(f\" SE: {trop_results.se:.4f}\")\n", - "print(f\" Bias: {trop_results.att - true_att:.4f}\")\n", - "print(f\" Effective rank: {trop_results.effective_rank:.2f}\")" - ] + "source": "# SDID (no factor adjustment)\n# Note: SDID uses 'treat' (unit-level ever-treated indicator)\nsdid = SyntheticDiD(\n n_bootstrap=50, # Reduced for faster execution\n seed=42\n)\n\n# SDID still uses post_periods parameter\nsdid_results = sdid.fit(\n df,\n outcome='outcome',\n treatment='treat', # Unit-level ever-treated indicator\n unit='unit',\n time='period',\n post_periods=post_periods\n)\n\n# TROP (with factor adjustment)\n# Note: TROP uses 'treated' (observation-level treatment indicator)\n# and infers treatment periods automatically\ntrop_est2 = TROP(\n lambda_nn_grid=[0.0, 0.1], # Reduced grid for faster execution\n n_bootstrap=50, # Reduced for faster execution\n seed=42\n)\n\ntrop_results = trop_est2.fit(\n df,\n outcome='outcome',\n treatment='treated', # Observation-level indicator\n unit='unit',\n time='period'\n)\n\nprint(\"Comparison: SDID vs TROP\")\nprint(\"=\"*60)\nprint(f\"True ATT: {true_att:.4f}\")\nprint()\nprint(f\"Synthetic DiD (no factor adjustment):\")\nprint(f\" ATT: {sdid_results.att:.4f}\")\nprint(f\" SE: {sdid_results.se:.4f}\")\nprint(f\" Bias: {sdid_results.att - true_att:.4f}\")\nprint()\nprint(f\"TROP (with factor adjustment):\")\nprint(f\" ATT: {trop_results.att:.4f}\")\nprint(f\" SE: {trop_results.se:.4f}\")\nprint(f\" Bias: {trop_results.att - true_att:.4f}\")\nprint(f\" Effective rank: {trop_results.effective_rank:.2f}\")" }, { "cell_type": "markdown", @@ -466,82 +369,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "# Monte Carlo comparison (reduced for faster tutorial execution)\n", - "n_sims = 5 # Reduced from 20 for faster validation\n", - "trop_estimates = []\n", - "sdid_estimates = []\n", - "\n", - "print(f\"Running {n_sims} simulations...\")\n", - "\n", - "for sim in range(n_sims):\n", - " # Generate new data using the library function\n", - " # (includes both 'treated' and 'treat' columns)\n", - " sim_data = generate_factor_data(\n", - " n_units=50,\n", - " n_pre=10,\n", - " n_post=5,\n", - " n_treated=10,\n", - " n_factors=2,\n", - " treatment_effect=2.0,\n", - " factor_strength=1.5,\n", - " noise_sd=0.5,\n", - " seed=100 + sim\n", - " )\n", - " \n", - " # TROP (uses observation-level 'treated')\n", - " try:\n", - " trop_m = TROP(\n", - " lambda_time_grid=[1.0],\n", - " lambda_unit_grid=[1.0],\n", - " lambda_nn_grid=[0.1],\n", - " n_bootstrap=10, \n", - " seed=42 + sim\n", - " )\n", - " trop_res = trop_m.fit(\n", - " sim_data,\n", - " outcome='outcome',\n", - " treatment='treated',\n", - " unit='unit',\n", - " time='period',\n", - " post_periods=list(range(10, 15))\n", - " )\n", - " trop_estimates.append(trop_res.att)\n", - " except Exception as e:\n", - " print(f\"TROP failed on sim {sim}: {e}\")\n", - " \n", - " # SDID (uses unit-level 'treat')\n", - " try:\n", - " sdid_m = SyntheticDiD(n_bootstrap=10, seed=42 + sim)\n", - " sdid_res = sdid_m.fit(\n", - " sim_data,\n", - " outcome='outcome',\n", - " treatment='treat', # Unit-level ever-treated indicator\n", - " unit='unit',\n", - " time='period',\n", - " post_periods=list(range(10, 15))\n", - " )\n", - " sdid_estimates.append(sdid_res.att)\n", - " except Exception as e:\n", - " print(f\"SDID failed on sim {sim}: {e}\")\n", - "\n", - "print(f\"\\nMonte Carlo Results (True ATT = {true_att})\")\n", - "print(\"=\"*60)\n", - "print(f\"{'Estimator':<15} {'Mean':>12} {'Bias':>12} {'RMSE':>12}\")\n", - "print(\"-\"*60)\n", - "\n", - "if trop_estimates:\n", - " trop_mean = np.mean(trop_estimates)\n", - " trop_bias = trop_mean - true_att\n", - " trop_rmse = np.sqrt(np.mean([(e - true_att)**2 for e in trop_estimates]))\n", - " print(f\"{'TROP':<15} {trop_mean:>12.4f} {trop_bias:>12.4f} {trop_rmse:>12.4f}\")\n", - "\n", - "if sdid_estimates:\n", - " sdid_mean = np.mean(sdid_estimates)\n", - " sdid_bias = sdid_mean - true_att\n", - " sdid_rmse = np.sqrt(np.mean([(e - true_att)**2 for e in sdid_estimates]))\n", - " print(f\"{'SDID':<15} {sdid_mean:>12.4f} {sdid_bias:>12.4f} {sdid_rmse:>12.4f}\")" - ] + "source": "# Monte Carlo comparison (reduced for faster tutorial execution)\nn_sims = 5 # Reduced from 20 for faster validation\ntrop_estimates = []\nsdid_estimates = []\n\nprint(f\"Running {n_sims} simulations...\")\n\nfor sim in range(n_sims):\n # Generate new data using the library function\n # (includes both 'treated' and 'treat' columns)\n sim_data = generate_factor_data(\n n_units=50,\n n_pre=10,\n n_post=5,\n n_treated=10,\n n_factors=2,\n treatment_effect=2.0,\n factor_strength=1.5,\n noise_sd=0.5,\n seed=100 + sim\n )\n \n # TROP (uses observation-level 'treated')\n # Note: TROP infers treatment periods from the treatment indicator\n try:\n trop_m = TROP(\n lambda_time_grid=[1.0],\n lambda_unit_grid=[1.0],\n lambda_nn_grid=[0.1],\n n_bootstrap=10, \n seed=42 + sim\n )\n trop_res = trop_m.fit(\n sim_data,\n outcome='outcome',\n treatment='treated',\n unit='unit',\n time='period'\n )\n trop_estimates.append(trop_res.att)\n except Exception as e:\n print(f\"TROP failed on sim {sim}: {e}\")\n \n # SDID (uses unit-level 'treat')\n # Note: SDID still uses post_periods parameter\n try:\n sdid_m = SyntheticDiD(n_bootstrap=10, seed=42 + sim)\n sdid_res = sdid_m.fit(\n sim_data,\n outcome='outcome',\n treatment='treat', # Unit-level ever-treated indicator\n unit='unit',\n time='period',\n post_periods=list(range(10, 15))\n )\n sdid_estimates.append(sdid_res.att)\n except Exception as e:\n print(f\"SDID failed on sim {sim}: {e}\")\n\nprint(f\"\\nMonte Carlo Results (True ATT = {true_att})\")\nprint(\"=\"*60)\nprint(f\"{'Estimator':<15} {'Mean':>12} {'Bias':>12} {'RMSE':>12}\")\nprint(\"-\"*60)\n\nif trop_estimates:\n trop_mean = np.mean(trop_estimates)\n trop_bias = trop_mean - true_att\n trop_rmse = np.sqrt(np.mean([(e - true_att)**2 for e in trop_estimates]))\n print(f\"{'TROP':<15} {trop_mean:>12.4f} {trop_bias:>12.4f} {trop_rmse:>12.4f}\")\n\nif sdid_estimates:\n sdid_mean = np.mean(sdid_estimates)\n sdid_bias = sdid_mean - true_att\n sdid_rmse = np.sqrt(np.mean([(e - true_att)**2 for e in sdid_estimates]))\n print(f\"{'SDID':<15} {sdid_mean:>12.4f} {sdid_bias:>12.4f} {sdid_rmse:>12.4f}\")" }, { "cell_type": "code", @@ -579,27 +407,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "# One-liner estimation with default tuning grid\n", - "quick_results = trop(\n", - " df,\n", - " outcome='outcome',\n", - " treatment='treated',\n", - " unit='unit',\n", - " time='period',\n", - " post_periods=post_periods,\n", - " n_bootstrap=20, # Reduced for faster execution\n", - " seed=42\n", - ")\n", - "\n", - "print(f\"Quick estimation:\")\n", - "print(f\" ATT: {quick_results.att:.4f}\")\n", - "print(f\" SE: {quick_results.se:.4f}\")\n", - "print(f\" λ_time: {quick_results.lambda_time:.2f}\")\n", - "print(f\" λ_unit: {quick_results.lambda_unit:.2f}\")\n", - "print(f\" λ_nn: {quick_results.lambda_nn:.2f}\")\n", - "print(f\" Effective rank: {quick_results.effective_rank:.2f}\")" - ] + "source": "# One-liner estimation with default tuning grid\n# Note: TROP infers treatment periods from the treatment indicator\nquick_results = trop(\n df,\n outcome='outcome',\n treatment='treated',\n unit='unit',\n time='period',\n n_bootstrap=20, # Reduced for faster execution\n seed=42\n)\n\nprint(f\"Quick estimation:\")\nprint(f\" ATT: {quick_results.att:.4f}\")\nprint(f\" SE: {quick_results.se:.4f}\")\nprint(f\" λ_time: {quick_results.lambda_time:.2f}\")\nprint(f\" λ_unit: {quick_results.lambda_unit:.2f}\")\nprint(f\" λ_nn: {quick_results.lambda_nn:.2f}\")\nprint(f\" Effective rank: {quick_results.effective_rank:.2f}\")" }, { "cell_type": "markdown", @@ -617,35 +425,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "# Compare variance estimation methods\n", - "print(\"Variance estimation comparison:\")\n", - "print(\"=\"*50)\n", - "\n", - "for method in ['bootstrap', 'jackknife']:\n", - " trop_var = TROP(\n", - " lambda_time_grid=[1.0],\n", - " lambda_unit_grid=[1.0], \n", - " lambda_nn_grid=[0.1],\n", - " variance_method=method,\n", - " n_bootstrap=30, # Reduced for faster execution\n", - " seed=42\n", - " )\n", - " \n", - " res = trop_var.fit(\n", - " df,\n", - " outcome='outcome',\n", - " treatment='treated',\n", - " unit='unit',\n", - " time='period',\n", - " post_periods=post_periods\n", - " )\n", - " \n", - " print(f\"\\n{method.capitalize()}:\")\n", - " print(f\" ATT: {res.att:.4f}\")\n", - " print(f\" SE: {res.se:.4f}\")\n", - " print(f\" 95% CI: [{res.conf_int[0]:.4f}, {res.conf_int[1]:.4f}]\")" - ] + "source": "# Compare variance estimation methods\nprint(\"Variance estimation comparison:\")\nprint(\"=\"*50)\n\nfor method in ['bootstrap', 'jackknife']:\n trop_var = TROP(\n lambda_time_grid=[1.0],\n lambda_unit_grid=[1.0], \n lambda_nn_grid=[0.1],\n variance_method=method,\n n_bootstrap=30, # Reduced for faster execution\n seed=42\n )\n \n res = trop_var.fit(\n df,\n outcome='outcome',\n treatment='treated',\n unit='unit',\n time='period'\n )\n \n print(f\"\\n{method.capitalize()}:\")\n print(f\" ATT: {res.att:.4f}\")\n print(f\" SE: {res.se:.4f}\")\n print(f\" 95% CI: [{res.conf_int[0]:.4f}, {res.conf_int[1]:.4f}]\")" }, { "cell_type": "markdown", @@ -728,4 +508,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/rust/src/trop.rs b/rust/src/trop.rs index 10ccf226..382badd5 100644 --- a/rust/src/trop.rs +++ b/rust/src/trop.rs @@ -146,10 +146,160 @@ fn compute_pair_distance( } } -/// Perform LOOCV grid search over tuning parameters in parallel. +/// Perform univariate LOOCV search over a single parameter. /// -/// Evaluates all combinations of (lambda_time, lambda_unit, lambda_nn) in parallel -/// and returns the combination with the lowest LOOCV score. +/// Following paper's footnote 2, this performs a grid search for one parameter +/// while holding others fixed. Used in the two-stage LOOCV approach. +/// +/// # Arguments +/// * `y` - Outcome matrix (n_periods x n_units) +/// * `d` - Treatment indicator matrix +/// * `control_mask` - Boolean mask for control observations +/// * `time_dist` - Time distance matrix +/// * `control_obs` - List of control observations for LOOCV +/// * `grid` - Grid of values to search +/// * `fixed_time` - Fixed lambda_time (inf for disabled) +/// * `fixed_unit` - Fixed lambda_unit (inf for disabled) +/// * `fixed_nn` - Fixed lambda_nn (inf for disabled) +/// * `param_type` - Which parameter to search: 0=time, 1=unit, 2=nn +/// * `max_iter` - Maximum iterations +/// * `tol` - Convergence tolerance +/// +/// # Returns +/// (best_value, best_score) +fn univariate_loocv_search( + y: &ArrayView2, + d: &ArrayView2, + control_mask: &ArrayView2, + time_dist: &ArrayView2, + control_obs: &[(usize, usize)], + grid: &[f64], + fixed_time: f64, + fixed_unit: f64, + fixed_nn: f64, + param_type: usize, // 0=time, 1=unit, 2=nn + max_iter: usize, + tol: f64, +) -> (f64, f64) { + let mut best_score = f64::INFINITY; + let mut best_value = grid.first().copied().unwrap_or(0.0); + + // Parallelize over grid values + let results: Vec<(f64, f64)> = grid + .par_iter() + .map(|&value| { + // Set parameters, converting inf for "disabled" mode + // Per paper Equations 2-3: + // - λ_time/λ_unit=∞ → uniform weights → use 0.0 + // - λ_nn=∞ → infinite penalty → L≈0 (factor model disabled) → use 1e10 + // Note: λ_nn=0 means NO regularization (full-rank L), opposite of "disabled" + // + // IMPORTANT: Convert the grid value BEFORE using it, matching Python behavior. + // This ensures Rust and Python evaluate the same objective for infinity grids. + let (lambda_time, lambda_unit, lambda_nn) = match param_type { + 0 => { + // Searching λ_time: convert grid value if infinite + let value_converted = if value.is_infinite() { 0.0 } else { value }; + (value_converted, + if fixed_unit.is_infinite() { 0.0 } else { fixed_unit }, + if fixed_nn.is_infinite() { 1e10 } else { fixed_nn }) + }, + 1 => { + // Searching λ_unit: convert grid value if infinite + let value_converted = if value.is_infinite() { 0.0 } else { value }; + (if fixed_time.is_infinite() { 0.0 } else { fixed_time }, + value_converted, + if fixed_nn.is_infinite() { 1e10 } else { fixed_nn }) + }, + _ => { + // Searching λ_nn: convert grid value if infinite + let value_converted = if value.is_infinite() { 1e10 } else { value }; + (if fixed_time.is_infinite() { 0.0 } else { fixed_time }, + if fixed_unit.is_infinite() { 0.0 } else { fixed_unit }, + value_converted) + }, + }; + + let (score, _, _) = loocv_score_for_params( + y, d, control_mask, time_dist, control_obs, + lambda_time, lambda_unit, lambda_nn, + max_iter, tol, + ); + (value, score) + }) + .collect(); + + for (value, score) in results { + if score < best_score { + best_score = score; + best_value = value; + } + } + + (best_value, best_score) +} + +/// Cycle through parameters until convergence (coordinate descent). +/// +/// Following paper's footnote 2 (Stage 2), iteratively optimize each parameter. +fn cycling_parameter_search( + y: &ArrayView2, + d: &ArrayView2, + control_mask: &ArrayView2, + time_dist: &ArrayView2, + control_obs: &[(usize, usize)], + lambda_time_grid: &[f64], + lambda_unit_grid: &[f64], + lambda_nn_grid: &[f64], + initial_time: f64, + initial_unit: f64, + initial_nn: f64, + max_iter: usize, + tol: f64, + max_cycles: usize, +) -> (f64, f64, f64) { + let mut lambda_time = initial_time; + let mut lambda_unit = initial_unit; + let mut lambda_nn = initial_nn; + let mut prev_score = f64::INFINITY; + + for _cycle in 0..max_cycles { + // Optimize λ_unit (fix λ_time, λ_nn) + let (new_unit, _) = univariate_loocv_search( + y, d, control_mask, time_dist, control_obs, + lambda_unit_grid, lambda_time, 0.0, lambda_nn, 1, max_iter, tol, + ); + lambda_unit = new_unit; + + // Optimize λ_time (fix λ_unit, λ_nn) + let (new_time, _) = univariate_loocv_search( + y, d, control_mask, time_dist, control_obs, + lambda_time_grid, 0.0, lambda_unit, lambda_nn, 0, max_iter, tol, + ); + lambda_time = new_time; + + // Optimize λ_nn (fix λ_unit, λ_time) + let (new_nn, score) = univariate_loocv_search( + y, d, control_mask, time_dist, control_obs, + lambda_nn_grid, lambda_time, lambda_unit, 0.0, 2, max_iter, tol, + ); + lambda_nn = new_nn; + + // Check convergence + if (score - prev_score).abs() < 1e-6 { + break; + } + prev_score = score; + } + + (lambda_time, lambda_unit, lambda_nn) +} + +/// Perform LOOCV grid search over tuning parameters using two-stage approach. +/// +/// Following paper's footnote 2: +/// - Stage 1: Univariate searches for initial values with extreme fixed parameters +/// - Stage 2: Cycling (coordinate descent) until convergence /// /// Following TROP Equation 5 (page 8): /// Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² @@ -158,8 +308,6 @@ fn compute_pair_distance( /// * `y` - Outcome matrix (n_periods x n_units) /// * `d` - Treatment indicator matrix (n_periods x n_units) /// * `control_mask` - Boolean mask (n_periods x n_units) for control observations -/// * `control_unit_idx` - Array of control unit indices -/// * `unit_dist_matrix` - Pre-computed unit distance matrix (n_units x n_units) /// * `time_dist_matrix` - Pre-computed time distance matrix (n_periods x n_periods) /// * `lambda_time_grid` - Grid of time decay parameters /// * `lambda_unit_grid` - Grid of unit distance parameters @@ -170,7 +318,10 @@ fn compute_pair_distance( /// * `seed` - Random seed for subsampling /// /// # Returns -/// (best_lambda_time, best_lambda_unit, best_lambda_nn, best_score) +/// (best_lambda_time, best_lambda_unit, best_lambda_nn, best_score, n_valid, n_attempted, first_failed_obs) +/// where n_valid and n_attempted are the counts for the best parameter combination, +/// allowing Python to emit warnings when >10% of fits fail. +/// first_failed_obs is Some((t, i)) if a fit failed during final score computation, None otherwise. #[pyfunction] #[pyo3(signature = (y, d, control_mask, time_dist_matrix, lambda_time_grid, lambda_unit_grid, lambda_nn_grid, max_loocv_samples, max_iter, tol, seed))] #[allow(clippy::too_many_arguments)] @@ -187,7 +338,7 @@ pub fn loocv_grid_search<'py>( max_iter: usize, tol: f64, seed: u64, -) -> PyResult<(f64, f64, f64, f64)> { +) -> PyResult<(f64, f64, f64, f64, usize, usize, Option<(usize, usize)>)> { let y_arr = y.as_array(); let d_arr = d.as_array(); let control_mask_arr = control_mask.as_array(); @@ -204,43 +355,53 @@ pub fn loocv_grid_search<'py>( seed, ); - // Generate all parameter combinations - let mut param_combos: Vec<(f64, f64, f64)> = Vec::new(); - for < in &lambda_time_vec { - for &lu in &lambda_unit_vec { - for &ln in &lambda_nn_vec { - param_combos.push((lt, lu, ln)); - } - } - } + let n_attempted = control_obs.len(); - // Evaluate all combinations in parallel - let results: Vec<(f64, f64, f64, f64)> = param_combos - .par_iter() - .map(|&(lambda_time, lambda_unit, lambda_nn)| { - let score = loocv_score_for_params( - &y_arr, - &d_arr, - &control_mask_arr, - &time_dist_arr, - &control_obs, - lambda_time, - lambda_unit, - lambda_nn, - max_iter, - tol, - ); - (lambda_time, lambda_unit, lambda_nn, score) - }) - .collect(); + // Stage 1: Univariate searches for initial values (paper footnote 2) + // λ_time search: fix λ_unit=0, λ_nn=∞ (disabled) + let (lambda_time_init, _) = univariate_loocv_search( + &y_arr, &d_arr, &control_mask_arr, &time_dist_arr, &control_obs, + &lambda_time_vec, 0.0, 0.0, f64::INFINITY, 0, max_iter, tol, + ); - // Find best (minimum score) - let best = results - .into_iter() - .min_by(|a, b| a.3.partial_cmp(&b.3).unwrap_or(std::cmp::Ordering::Equal)) - .unwrap_or((1.0, 1.0, 0.1, f64::INFINITY)); + // λ_nn search: fix λ_time=∞ (disabled), λ_unit=0 + let (lambda_nn_init, _) = univariate_loocv_search( + &y_arr, &d_arr, &control_mask_arr, &time_dist_arr, &control_obs, + &lambda_nn_vec, f64::INFINITY, 0.0, 0.0, 2, max_iter, tol, + ); + + // λ_unit search: fix λ_nn=∞, λ_time=0 + let (lambda_unit_init, _) = univariate_loocv_search( + &y_arr, &d_arr, &control_mask_arr, &time_dist_arr, &control_obs, + &lambda_unit_vec, 0.0, 0.0, f64::INFINITY, 1, max_iter, tol, + ); - Ok(best) + // Stage 2: Cycling refinement + let (best_time, best_unit, best_nn) = cycling_parameter_search( + &y_arr, &d_arr, &control_mask_arr, &time_dist_arr, &control_obs, + &lambda_time_vec, &lambda_unit_vec, &lambda_nn_vec, + lambda_time_init, lambda_unit_init, lambda_nn_init, + max_iter, tol, 10, + ); + + // Convert infinity values BEFORE computing final score (Issue 1 fix) + // Per paper Equations 2-3: + // - λ_time/λ_unit=∞ → uniform weights → use 0.0 + // - λ_nn=∞ → infinite penalty → L≈0 (factor model disabled) → use 1e10 + // This ensures final score computation matches what LOOCV evaluated. + let best_time_eff = if best_time.is_infinite() { 0.0 } else { best_time }; + let best_unit_eff = if best_unit.is_infinite() { 0.0 } else { best_unit }; + let best_nn_eff = if best_nn.is_infinite() { 1e10 } else { best_nn }; + + // Compute final score with converted values + let (best_score, n_valid, first_failed) = loocv_score_for_params( + &y_arr, &d_arr, &control_mask_arr, &time_dist_arr, &control_obs, + best_time_eff, best_unit_eff, best_nn_eff, + max_iter, tol, + ); + + // Return ORIGINAL grid values (for user visibility) but score computed with converted + Ok((best_time, best_unit, best_nn, best_score, n_valid, n_attempted, first_failed)) } /// Get sampled control observations for LOOCV. @@ -277,6 +438,10 @@ fn get_control_observations( } /// Compute LOOCV score for a specific parameter combination. +/// +/// # Returns +/// (score, n_valid, first_failed_obs) - the LOOCV score, number of successful fits, +/// and the first failed observation (t, i) if any fit failed, None otherwise. #[allow(clippy::too_many_arguments)] fn loocv_score_for_params( y: &ArrayView2, @@ -289,7 +454,7 @@ fn loocv_score_for_params( lambda_nn: f64, max_iter: usize, tol: f64, -) -> f64 { +) -> (f64, usize, Option<(usize, usize)>) { let n_periods = y.nrows(); let n_units = y.ncols(); @@ -328,14 +493,21 @@ fn loocv_score_for_params( tau_sq_sum += tau * tau; n_valid += 1; } - None => continue, // Skip if estimation failed + None => { + // Per Equation 5: Q(λ) must sum over ALL D==0 cells + // Any failure means this λ cannot produce valid estimates for all cells + // Return the failed observation (t, i) for warning metadata + return (f64::INFINITY, n_valid, Some((t, i))); + } } } if n_valid == 0 { - f64::INFINITY + (f64::INFINITY, 0, None) } else { - tau_sq_sum / n_valid as f64 + // Return SUM of squared pseudo-treatment effects per Equation 5 (page 8): + // Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² + (tau_sq_sum, n_valid, None) } } diff --git a/tests/test_rust_backend.py b/tests/test_rust_backend.py index 5362cb7f..332f5321 100644 --- a/tests/test_rust_backend.py +++ b/tests/test_rust_backend.py @@ -970,7 +970,7 @@ def test_loocv_grid_search_returns_valid_params(self): lambda_unit = np.array([0.0, 1.0], dtype=np.float64) lambda_nn = np.array([0.0, 0.1], dtype=np.float64) - best_lt, best_lu, best_ln, score = loocv_grid_search( + best_lt, best_lu, best_ln, score, n_valid, n_attempted, first_failed = loocv_grid_search( Y, D, control_mask, time_dist, lambda_time, lambda_unit, lambda_nn, 50, 100, 1e-6, 42 @@ -981,6 +981,12 @@ def test_loocv_grid_search_returns_valid_params(self): assert best_lu in lambda_unit assert best_ln in lambda_nn assert np.isfinite(score) or score == np.inf + # Check failure counts are valid + assert n_valid >= 0 + assert n_attempted >= 0 + assert n_valid <= n_attempted + # Check first_failed is None or a valid (unit, time) tuple + assert first_failed is None or (isinstance(first_failed, tuple) and len(first_failed) == 2) def test_bootstrap_variance_shape(self): """Test bootstrap returns correct shapes.""" diff --git a/tests/test_trop.py b/tests/test_trop.py index 2bec030a..2babc901 100644 --- a/tests/test_trop.py +++ b/tests/test_trop.py @@ -1,5 +1,7 @@ """Tests for Triply Robust Panel (TROP) estimator.""" +import warnings + import numpy as np import pandas as pd import pytest @@ -108,7 +110,6 @@ def test_basic_fit(self, simple_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[5, 6, 7], ) assert isinstance(results, TROPResults) @@ -126,14 +127,12 @@ def test_fit_with_factors(self, factor_dgp_data): n_bootstrap=20, seed=42 ) - post_periods = list(range(8, 12)) results = trop_est.fit( factor_dgp_data, outcome="outcome", treatment="treated", unit="unit", time="period", - post_periods=post_periods, ) assert isinstance(results, TROPResults) @@ -151,14 +150,12 @@ def test_treatment_effect_recovery(self, factor_dgp_data): n_bootstrap=30, seed=42 ) - post_periods = list(range(8, 12)) results = trop_est.fit( factor_dgp_data, outcome="outcome", treatment="treated", unit="unit", time="period", - post_periods=post_periods, ) # ATT should be positive (correct direction) @@ -181,7 +178,6 @@ def test_tuning_parameter_selection(self, simple_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[5, 6, 7], ) # Check that lambda values are from the grid @@ -205,7 +201,6 @@ def test_bootstrap_variance(self, simple_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[5, 6, 7], ) assert results.se > 0 @@ -228,7 +223,6 @@ def test_jackknife_variance(self, simple_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[5, 6, 7], ) assert results.se >= 0 @@ -250,7 +244,6 @@ def test_confidence_interval(self, simple_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[5, 6, 7], ) lower, upper = results.conf_int @@ -356,7 +349,6 @@ def test_summary(self, simple_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[5, 6, 7], ) summary = results.summary() @@ -381,7 +373,6 @@ def test_to_dict(self, simple_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[5, 6, 7], ) d = results.to_dict() @@ -407,7 +398,6 @@ def test_to_dataframe(self, simple_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[5, 6, 7], ) df = results.to_dataframe() @@ -430,7 +420,6 @@ def test_get_treatment_effects_df(self, simple_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[5, 6, 7], ) effects_df = results.get_treatment_effects_df() @@ -455,7 +444,6 @@ def test_get_unit_effects_df(self, simple_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[5, 6, 7], ) effects_df = results.get_unit_effects_df() @@ -478,7 +466,6 @@ def test_get_time_effects_df(self, simple_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[5, 6, 7], ) effects_df = results.get_time_effects_df() @@ -502,7 +489,6 @@ def test_is_significant(self, simple_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[5, 6, 7], ) assert isinstance(results.is_significant, bool) @@ -522,12 +508,51 @@ def test_significance_stars(self, simple_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[5, 6, 7], ) stars = results.significance_stars assert stars in ["", ".", "*", "**", "***"] + def test_nan_propagation_when_se_zero(self): + """Test that inference fields are NaN when SE is zero/undefined. + + This verifies the P0 fix: when SE <= 0, all inference fields + (t_stat, p_value, conf_int) should be NaN, not finite values. + """ + from diff_diff.trop import TROPResults + + # Create a TROPResults directly with SE=0 + results = TROPResults( + att=1.0, + se=0.0, # Zero SE - inference should be undefined + t_stat=np.nan, + p_value=np.nan, + conf_int=(np.nan, np.nan), + n_obs=100, + n_treated=5, + n_control=10, + n_treated_obs=20, + unit_effects={0: 0.1, 1: 0.2}, + time_effects={0: 0.0, 1: 0.1}, + treatment_effects={(0, 5): 1.0}, + lambda_time=1.0, + lambda_unit=1.0, + lambda_nn=0.1, + factor_matrix=np.zeros((10, 15)), + effective_rank=2.0, + loocv_score=0.5, + variance_method="bootstrap", + ) + + # Verify that all inference fields are NaN when SE=0 + assert np.isnan(results.t_stat), "t_stat should be NaN when SE=0" + assert np.isnan(results.p_value), "p_value should be NaN when SE=0" + assert np.isnan(results.conf_int[0]), "conf_int[0] should be NaN when SE=0" + assert np.isnan(results.conf_int[1]), "conf_int[1] should be NaN when SE=0" + + # Verify the ATT itself is still valid + assert results.att == 1.0, "ATT should still be valid" + class TestTROPvsSDID: """Tests comparing TROP to SDID under different DGPs.""" @@ -545,7 +570,6 @@ def test_trop_handles_factor_dgp(self): noise_std=0.5, seed=42, ) - post_periods = list(range(8, 12)) # TROP should complete without error trop_est = TROP( @@ -561,7 +585,6 @@ def test_trop_handles_factor_dgp(self): treatment="treated", unit="unit", time="period", - post_periods=post_periods, ) assert results.att != 0 @@ -579,7 +602,6 @@ def test_convenience_function(self, simple_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[5, 6, 7], lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], @@ -598,7 +620,6 @@ def test_convenience_with_kwargs(self, simple_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[5, 6, 7], lambda_time_grid=[0.0, 0.5, 1.0], lambda_unit_grid=[0.0, 0.5], lambda_nn_grid=[0.0, 0.1], @@ -654,7 +675,6 @@ def test_limiting_case_uniform_weights(self): }) df = pd.DataFrame(data) - post_periods = list(range(n_pre, n_pre + n_post)) # TROP with uniform weights trop_est = TROP( @@ -670,7 +690,6 @@ def test_limiting_case_uniform_weights(self): treatment="treated", unit="unit", time="period", - post_periods=post_periods, ) # Should recover treatment effect within reasonable tolerance @@ -722,7 +741,6 @@ def test_unit_weights_reduce_bias(self): }) df = pd.DataFrame(data) - post_periods = list(range(n_pre, n_pre + n_post)) # TROP with unit weighting enabled trop_est = TROP( @@ -738,7 +756,6 @@ def test_unit_weights_reduce_bias(self): treatment="treated", unit="unit", time="period", - post_periods=post_periods, ) # Should recover treatment effect reasonably well @@ -781,7 +798,6 @@ def test_time_weights_reduce_bias(self): }) df = pd.DataFrame(data) - post_periods = list(range(n_pre, n_pre + n_post)) # TROP with time weighting enabled trop_est = TROP( @@ -797,7 +813,6 @@ def test_time_weights_reduce_bias(self): treatment="treated", unit="unit", time="period", - post_periods=post_periods, ) # Should recover treatment effect direction @@ -824,7 +839,6 @@ def test_factor_model_reduces_bias(self): noise_std=0.5, seed=789, ) - post_periods = list(range(10, 15)) # TROP with nuclear norm regularization trop_est = TROP( @@ -840,7 +854,6 @@ def test_factor_model_reduces_bias(self): treatment="treated", unit="unit", time="period", - post_periods=post_periods, ) true_att = 2.0 @@ -906,7 +919,6 @@ def test_paper_dgp_recovery(self): }) df = pd.DataFrame(data) - post_periods = list(range(n_pre, n_pre + n_post)) # TROP estimation trop_est = TROP( @@ -922,7 +934,6 @@ def test_paper_dgp_recovery(self): treatment="treated", unit="unit", time="period", - post_periods=post_periods, ) # Under null hypothesis, ATT should be close to zero @@ -964,7 +975,6 @@ def test_precomputed_structures_consistency(self, simple_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[5, 6, 7], ) precomputed = trop_est._precomputed @@ -1058,7 +1068,6 @@ def test_vectorized_weights_computation(self, simple_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[5, 6, 7], ) precomputed = trop_est._precomputed @@ -1170,7 +1179,6 @@ def test_reproducibility_with_seed(self, simple_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[5, 6, 7], lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], @@ -1184,7 +1192,6 @@ def test_reproducibility_with_seed(self, simple_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[5, 6, 7], lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], @@ -1200,6 +1207,231 @@ def test_reproducibility_with_seed(self, simple_panel_data): assert results1.lambda_nn == results2.lambda_nn +class TestDMatrixValidation: + """Tests for D matrix absorbing-state validation.""" + + def test_d_matrix_absorbing_state_validation_valid(self): + """Test that valid absorbing-state D passes validation.""" + # Staggered adoption: once treated, always treated + rng = np.random.default_rng(42) + n_units = 15 + n_periods = 8 + + data = [] + for i in range(n_units): + # Different treatment timing for different units + if i < 5: + treat_period = 3 # Early adopters + elif i < 10: + treat_period = 5 # Late adopters + else: + treat_period = None # Never treated + + for t in range(n_periods): + is_treated = treat_period is not None and t >= treat_period + y = 10.0 + rng.normal(0, 0.5) + if is_treated: + y += 2.0 + data.append({ + "unit": i, + "period": t, + "outcome": y, + "treated": 1 if is_treated else 0, + }) + + df = pd.DataFrame(data) + + # Should work without error + trop_est = TROP( + lambda_time_grid=[0.0], + lambda_unit_grid=[0.0], + lambda_nn_grid=[0.0], + n_bootstrap=5, + seed=42 + ) + results = trop_est.fit( + df, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + assert results is not None + assert isinstance(results, TROPResults) + + def test_d_matrix_absorbing_state_validation_invalid(self): + """Test that non-absorbing D raises ValueError.""" + # Event-style D: only first treatment period has D=1 + data = [] + n_units = 10 + n_periods = 6 + + for i in range(n_units): + is_treated_unit = i < 3 + for t in range(n_periods): + # Event-style: D=1 only at t=3, then back to 0 + if is_treated_unit and t == 3: + treated = 1 + else: + treated = 0 + data.append({ + "unit": i, + "period": t, + "outcome": float(i + t), + "treated": treated, + }) + + df = pd.DataFrame(data) + + trop_est = TROP( + lambda_time_grid=[0.0], + lambda_unit_grid=[0.0], + lambda_nn_grid=[0.0], + n_bootstrap=5 + ) + + with pytest.raises(ValueError, match="not an absorbing state"): + trop_est.fit( + df, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + def test_d_matrix_validation_error_message_helpful(self): + """Test that error message includes unit IDs and remediation guidance.""" + # Event-style D for unit 5 only + data = [] + for i in range(10): + for t in range(5): + # Unit 5: D goes 0→1→0 (invalid) + if i == 5: + treated = 1 if t == 2 else 0 + else: + # Other units: proper absorbing state + treated = 1 if (i < 3 and t >= 3) else 0 + data.append({ + "unit": i, + "period": t, + "outcome": float(i + t), + "treated": treated, + }) + + df = pd.DataFrame(data) + + trop_est = TROP( + lambda_time_grid=[0.0], + lambda_unit_grid=[0.0], + lambda_nn_grid=[0.0], + n_bootstrap=5 + ) + + with pytest.raises(ValueError) as exc_info: + trop_est.fit( + df, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + error_msg = str(exc_info.value) + # Check that error message is helpful + assert "5" in error_msg, "Should mention unit ID 5" + assert "absorbing state" in error_msg + assert "monotonic" in error_msg.lower() or "non-decreasing" in error_msg.lower() + assert "D[t, i] = 1 for all t >= first treatment" in error_msg + + +class TestCyclingSearch: + """Tests for LOOCV cycling (coordinate descent) search.""" + + def test_cycling_search_converges(self, simple_panel_data): + """Test that cycling search converges to reasonable values.""" + trop_est = TROP( + lambda_time_grid=[0.0, 0.5, 1.0], + lambda_unit_grid=[0.0, 0.5, 1.0], + lambda_nn_grid=[0.0, 0.1, 1.0], + n_bootstrap=5, + seed=42 + ) + + results = trop_est.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + # Check that lambda values are from the grid + assert results.lambda_time in trop_est.lambda_time_grid + assert results.lambda_unit in trop_est.lambda_unit_grid + assert results.lambda_nn in trop_est.lambda_nn_grid + + # Check that results are reasonable + assert np.isfinite(results.att) + assert results.se >= 0 + + def test_cycling_search_reproducible(self, simple_panel_data): + """Test that cycling search produces reproducible results.""" + results1 = trop( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + lambda_time_grid=[0.0, 0.5, 1.0], + lambda_unit_grid=[0.0, 0.5, 1.0], + lambda_nn_grid=[0.0, 0.1], + n_bootstrap=10, + seed=42, + ) + + results2 = trop( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + lambda_time_grid=[0.0, 0.5, 1.0], + lambda_unit_grid=[0.0, 0.5, 1.0], + lambda_nn_grid=[0.0, 0.1], + n_bootstrap=10, + seed=42, + ) + + # Results should be identical with same seed + assert results1.att == results2.att + assert results1.lambda_time == results2.lambda_time + assert results1.lambda_unit == results2.lambda_unit + assert results1.lambda_nn == results2.lambda_nn + + def test_cycling_search_single_value_grids(self, simple_panel_data): + """Test cycling search with single-value grids (degenerate case).""" + trop_est = TROP( + lambda_time_grid=[0.5], # Single value + lambda_unit_grid=[0.5], # Single value + lambda_nn_grid=[0.1], # Single value + n_bootstrap=5, + seed=42 + ) + + results = trop_est.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + # Should use the only available values + assert results.lambda_time == 0.5 + assert results.lambda_unit == 0.5 + assert results.lambda_nn == 0.1 + + class TestPaperConformanceFixes: """Tests verifying fixes for paper conformance issues. @@ -1328,7 +1560,6 @@ def test_issue_b_distance_excludes_target_period(self): treatment="treated", unit="unit", time="period", - post_periods=[3, 4, 5], ) # Model should fit without error @@ -1389,7 +1620,6 @@ def test_issue_c_weighted_nuclear_norm(self): treatment="treated", unit="unit", time="period", - post_periods=[5, 6, 7], ) # Factor matrix should have been estimated with non-zero effective rank @@ -1447,7 +1677,6 @@ def test_issue_d_stratified_bootstrap(self): treatment="treated", unit="unit", time="period", - post_periods=[3, 4, 5], ) # Bootstrap should complete successfully @@ -1490,3 +1719,939 @@ def test_weighted_nuclear_norm_solver_convergence(self): # Regularized singular values should be smaller than original assert np.sum(s) < np.sum(s_orig), \ "Nuclear norm regularization should reduce total singular value mass" + + +class TestAPIChangesV2_1_8: + """Tests verifying API changes in v2.1.8. + + These tests verify: + 1. post_periods parameter has been removed + 2. TROPResults uses n_pre_periods/n_post_periods instead of lists + 3. CV scoring uses sum (not average) per Equation 5 + 4. LOOCV warning is emitted when fits fail + """ + + def test_fit_no_post_periods_parameter(self, simple_panel_data): + """Test that fit() no longer accepts post_periods parameter.""" + trop_est = TROP( + lambda_time_grid=[0.0], + lambda_unit_grid=[0.0], + lambda_nn_grid=[0.0], + n_bootstrap=5, + seed=42 + ) + + # This should work - no post_periods parameter + results = trop_est.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + assert results is not None + assert isinstance(results, TROPResults) + + # Verify the API change - post_periods should raise TypeError + with pytest.raises(TypeError, match="unexpected keyword argument"): + trop_est.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + post_periods=[5, 6, 7], # This should fail + ) + + def test_convenience_function_no_post_periods(self, simple_panel_data): + """Test that trop() convenience function no longer accepts post_periods.""" + # This should work + results = trop( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + lambda_time_grid=[0.0], + lambda_unit_grid=[0.0], + lambda_nn_grid=[0.0], + n_bootstrap=5, + seed=42, + ) + assert results is not None + + # This should fail + with pytest.raises(TypeError, match="unexpected keyword argument"): + trop( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + post_periods=[5, 6, 7], # Should fail + lambda_time_grid=[0.0], + lambda_unit_grid=[0.0], + lambda_nn_grid=[0.0], + n_bootstrap=5, + seed=42, + ) + + def test_results_has_period_counts_not_lists(self, simple_panel_data): + """Test that TROPResults has n_pre_periods/n_post_periods, not lists.""" + trop_est = TROP( + lambda_time_grid=[0.0], + lambda_unit_grid=[0.0], + lambda_nn_grid=[0.0], + n_bootstrap=5, + seed=42 + ) + results = trop_est.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + # Should have count attributes, not list attributes + assert hasattr(results, "n_pre_periods") + assert hasattr(results, "n_post_periods") + assert isinstance(results.n_pre_periods, int) + assert isinstance(results.n_post_periods, int) + + # Should NOT have list attributes + assert not hasattr(results, "pre_periods") + assert not hasattr(results, "post_periods") + + # Values should be correct (5 pre, 3 post in simple_panel_data) + assert results.n_pre_periods == 5 + assert results.n_post_periods == 3 + + def test_validation_still_checks_pre_periods(self): + """Test that validation still requires at least 2 pre-treatment periods.""" + # Create data with only 1 pre-treatment period + data = pd.DataFrame({ + "unit": [0, 0, 1, 1], + "period": [0, 1, 0, 1], + "outcome": [1.0, 2.0, 1.5, 2.5], + "treated": [0, 1, 0, 0], # Treatment at period 1 + }) + + trop_est = TROP( + lambda_time_grid=[0.0], + lambda_unit_grid=[0.0], + lambda_nn_grid=[0.0], + n_bootstrap=5 + ) + + with pytest.raises(ValueError, match="at least 2 pre-treatment periods"): + trop_est.fit( + data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + def test_loocv_warning_on_many_failures(self): + """Test that LOOCV emits warning when many fits fail.""" + import warnings + + # Create numerically challenging data that may cause LOOCV failures + rng = np.random.default_rng(42) + n_units = 10 + n_periods = 5 + + data = [] + for i in range(n_units): + is_treated = i < 2 + for t in range(n_periods): + post = t >= 3 + # Add some extreme values that might cause numerical issues + y = rng.normal(0, 1) if not (is_treated and post) else 1e10 + treatment_indicator = 1 if (is_treated and post) else 0 + data.append({ + "unit": i, + "period": t, + "outcome": y, + "treated": treatment_indicator, + }) + + df = pd.DataFrame(data) + + trop_est = TROP( + lambda_time_grid=[100.0], # Extreme lambda may cause issues + lambda_unit_grid=[100.0], + lambda_nn_grid=[0.0], + n_bootstrap=5, + seed=42 + ) + + # Capture warnings and verify the warning code path + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + fit_succeeded = False + try: + trop_est.fit( + df, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + fit_succeeded = True + except (ValueError, np.linalg.LinAlgError): + # Expected if data is too extreme - this is valid behavior + pass + + # Check for LOOCV-related warnings + loocv_warnings = [ + x for x in w + if issubclass(x.category, UserWarning) + and "LOOCV" in str(x.message) + ] + + # If fit succeeded, check that we can capture warnings properly + # (warnings may or may not be raised depending on data) + if fit_succeeded: + # At minimum, verify warnings capture infrastructure is working + # by checking that w is a list we can inspect + assert isinstance(w, list), "Warning capture should work" + + # If any LOOCV warnings were raised, verify they have expected content + for warning in loocv_warnings: + msg = str(warning.message) + # Warnings should mention LOOCV and provide context + assert "LOOCV" in msg, f"Warning should mention LOOCV: {msg}" + + def test_loocv_warning_deterministic_with_mock(self, simple_panel_data): + """Test that LOOCV returns infinity and warns on first fit failure. + + Per Equation 5, Q(λ) must sum over ALL D==0 cells. Any failure means + this λ cannot produce valid estimates, so we return infinity immediately. + """ + import warnings + from unittest.mock import patch + + trop_est = TROP( + lambda_time_grid=[1.0], + lambda_unit_grid=[1.0], + lambda_nn_grid=[0.1], + n_bootstrap=5, + seed=42 + ) + + # Mock _estimate_model to fail on the first LOOCV call + # This simulates a parameter combination that can't estimate all control cells + call_count = [0] + original_estimate = trop_est._estimate_model + + def mock_estimate_with_failure(*args, **kwargs): + """Mock that fails on first call (immediate rejection per Equation 5).""" + call_count[0] += 1 + # Fail on first call to trigger immediate infinity return + if call_count[0] == 1: + raise np.linalg.LinAlgError("Simulated failure") + return original_estimate(*args, **kwargs) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + # Disable Rust backend for this test by patching the module-level variables + import sys + trop_module = sys.modules['diff_diff.trop'] + with patch.object(trop_module, 'HAS_RUST_BACKEND', False), \ + patch.object(trop_module, '_rust_loocv_grid_search', None), \ + patch.object(trop_est, '_estimate_model', mock_estimate_with_failure): + try: + trop_est.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + except (ValueError, np.linalg.LinAlgError): + # If all fits fail, that's acceptable + pass + + # Check that LOOCV warning was raised on first failure + loocv_warnings = [ + x for x in w + if issubclass(x.category, UserWarning) + and "LOOCV" in str(x.message) + ] + + # With any failure, we should get a warning about returning infinity + assert len(loocv_warnings) > 0, ( + "Expected LOOCV warning on first failure, but none was raised. " + f"call_count={call_count[0]}, warnings={[str(x.message) for x in w]}" + ) + + # Verify warning content mentions Equation 5 and returning infinity + msg = str(loocv_warnings[0].message) + assert "LOOCV" in msg + assert "fail" in msg.lower(), f"Warning should mention failure: {msg}" + assert "Equation 5" in msg, f"Warning should reference Equation 5: {msg}" + + +class TestLOOCVFallback: + """Tests for LOOCV fallback to defaults when all fits fail.""" + + def test_infinite_score_triggers_fallback(self, simple_panel_data): + """ + Test that infinite LOOCV scores trigger fallback to defaults. + + When all LOOCV fits return infinity (e.g., due to numerical issues), + the estimator should: + 1. Emit a warning about using defaults + 2. Use default parameters (1.0, 1.0, 0.1) + 3. Still complete estimation + """ + import sys + from unittest.mock import patch + + trop_est = TROP( + lambda_time_grid=[0.0, 1.0], + lambda_unit_grid=[0.0, 1.0], + lambda_nn_grid=[0.0, 0.1], + n_bootstrap=5, + seed=42 + ) + + # Mock LOOCV to always return infinity + def always_infinity(*args, **kwargs): + return np.inf + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + # Disable Rust backend and mock LOOCV score to always return infinity + trop_module = sys.modules['diff_diff.trop'] + with patch.object(trop_module, 'HAS_RUST_BACKEND', False), \ + patch.object(trop_module, '_rust_loocv_grid_search', None), \ + patch.object(trop_est, '_loocv_score_obs_specific', always_infinity): + results = trop_est.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + # Verify warning emitted about fallback to defaults + fallback_warnings = [ + x for x in w + if issubclass(x.category, UserWarning) + and "defaults" in str(x.message).lower() + ] + assert len(fallback_warnings) > 0, ( + f"Expected fallback warning, got: {[str(x.message) for x in w]}" + ) + + # Verify defaults used (per REGISTRY.md: 1.0, 1.0, 0.1) + assert results.lambda_time == 1.0, \ + f"Expected default lambda_time=1.0, got {results.lambda_time}" + assert results.lambda_unit == 1.0, \ + f"Expected default lambda_unit=1.0, got {results.lambda_unit}" + assert results.lambda_nn == 0.1, \ + f"Expected default lambda_nn=0.1, got {results.lambda_nn}" + + # Verify estimation still completed + assert np.isfinite(results.att), "ATT should be finite even with default params" + + def test_rust_infinite_score_triggers_fallback(self, simple_panel_data): + """ + Test that infinite LOOCV score from Rust backend triggers fallback. + + The Rust backend may return infinite score when all fits fail. + Python should detect this and fall back to defaults. + When Rust returns infinity, best_lambda stays None, then Python fallback + is attempted. If Python also returns infinity, defaults are used. + """ + import sys + from unittest.mock import patch, MagicMock + + trop_est = TROP( + lambda_time_grid=[0.0, 1.0], + lambda_unit_grid=[0.0, 1.0], + lambda_nn_grid=[0.0, 0.1], + n_bootstrap=5, + seed=42 + ) + + # Mock Rust function to return infinite score + # Return format: (lambda_time, lambda_unit, lambda_nn, score, n_valid, n_attempted, first_failed_obs) + mock_rust_loocv = MagicMock(return_value=(0.5, 0.5, 0.05, np.inf, 0, 100, None)) + + # Also mock Python LOOCV to return infinity (so Python fallback also fails) + def always_infinity(*args, **kwargs): + return np.inf + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + trop_module = sys.modules['diff_diff.trop'] + with patch.object(trop_module, 'HAS_RUST_BACKEND', True), \ + patch.object(trop_module, '_rust_loocv_grid_search', mock_rust_loocv), \ + patch.object(trop_est, '_loocv_score_obs_specific', always_infinity): + results = trop_est.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + # Verify warning emitted about fallback to defaults + fallback_warnings = [ + x for x in w + if issubclass(x.category, UserWarning) + and "defaults" in str(x.message).lower() + ] + assert len(fallback_warnings) > 0, ( + f"Expected fallback warning with Rust backend, got: {[str(x.message) for x in w]}" + ) + + # Verify defaults used (NOT the Rust-returned values) + assert results.lambda_time == 1.0, \ + f"Expected default lambda_time=1.0, got {results.lambda_time}" + assert results.lambda_unit == 1.0, \ + f"Expected default lambda_unit=1.0, got {results.lambda_unit}" + assert results.lambda_nn == 0.1, \ + f"Expected default lambda_nn=0.1, got {results.lambda_nn}" + + def test_infinity_grid_values_handled_consistently(self, simple_panel_data): + """ + Test that infinity in grids is handled consistently in LOOCV and final estimation. + + When infinity is in the parameter grid: + - LOOCV converts it for computation (λ_time=∞→0, λ_unit=∞→0, λ_nn=∞→1e10) + - LOOCV returns the original grid value (inf) if it was best + - Final estimation must also convert infinity to match LOOCV behavior + + This test ensures the conversion in final estimation matches LOOCV. + """ + # Create estimator with infinity in grids + # Use grids where infinity is likely to be selected: + # - lambda_time_grid: [inf] forces selection of inf + # - lambda_nn_grid: [inf] forces selection of inf + trop_est = TROP( + lambda_time_grid=[np.inf], # Only inf available → must be selected + lambda_unit_grid=[0.0], # Normal value + lambda_nn_grid=[np.inf], # Only inf available → must be selected + n_bootstrap=5, + seed=42 + ) + + results = trop_est.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + # ATT should be finite (no NaN/inf from unconverted infinity parameters) + assert np.isfinite(results.att), ( + f"ATT should be finite when infinity params are converted, got {results.att}" + ) + + # SE should be finite or at least non-negative + assert np.isfinite(results.se) or results.se >= 0, ( + f"SE should be finite, got {results.se}" + ) + + # The stored lambda values should be the original grid values (inf) + # because we store what was selected, but conversion happens internally + # (This documents current behavior; the key is that ATT is finite) + assert np.isinf(results.lambda_time) or results.lambda_time == 0.0, ( + f"lambda_time should be inf (stored) or 0.0 (if converted for storage)" + ) + assert np.isinf(results.lambda_nn) or results.lambda_nn == 1e10, ( + f"lambda_nn should be inf (stored) or 1e10 (if converted for storage)" + ) + + def test_variance_estimation_uses_converted_params(self, simple_panel_data): + """ + Test that variance estimation uses the same converted parameters as point estimation. + + When infinity is in the grid and gets selected, both ATT and SE should be + computed with the same effective parameters (e.g., λ_time=∞ converted to 0.0). + This test verifies the fix for variance estimation inconsistency (PR #110 Round 7). + """ + from unittest.mock import patch + + # Use grids with only infinity values to force selection + trop_est = TROP( + lambda_time_grid=[np.inf], # Will be converted to 0.0 internally + lambda_unit_grid=[0.0], + lambda_nn_grid=[np.inf], # Will be converted to 1e10 internally + n_bootstrap=5, + variance_method="bootstrap", + seed=42 + ) + + # Track what parameters are passed to _fit_with_fixed_lambda + # (called by bootstrap variance estimation) + original_fit_with_fixed = TROP._fit_with_fixed_lambda + captured_lambda = [] + + def tracking_fit(self, data, outcome, treatment, unit, time, fixed_lambda): + captured_lambda.append(fixed_lambda) + return original_fit_with_fixed(self, data, outcome, treatment, unit, time, fixed_lambda) + + with patch.object(TROP, '_fit_with_fixed_lambda', tracking_fit): + results = trop_est.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + # Results should store original grid values + assert np.isinf(results.lambda_time), "Results should store original infinity value" + assert np.isinf(results.lambda_nn), "Results should store original infinity value" + + # ATT should be finite (computed with converted params) + assert np.isfinite(results.att), "ATT should be finite" + + # Variance estimation should have received converted parameters + # Check that bootstrap iterations used converted (non-infinite) values + for captured in captured_lambda: + lambda_time, lambda_unit, lambda_nn = captured + assert not np.isinf(lambda_time), ( + f"Bootstrap should receive converted λ_time=0.0, not {lambda_time}" + ) + assert not np.isinf(lambda_nn), ( + f"Bootstrap should receive converted λ_nn=1e10, not {lambda_nn}" + ) + + def test_empty_control_obs_returns_infinity(self, simple_panel_data): + """ + Test that LOOCV returns infinity when control observations are empty. + + A score of 0.0 for empty control would incorrectly "win" over legitimate + parameters. This test verifies the fix for empty control handling (PR #110 Round 7). + """ + import warnings + + trop_est = TROP( + lambda_time_grid=[1.0], + lambda_unit_grid=[1.0], + lambda_nn_grid=[1.0], + max_loocv_samples=100, + seed=42 + ) + + # Setup matrices from data + data = simple_panel_data + all_units = sorted(data['unit'].unique()) + all_periods = sorted(data['period'].unique()) + n_units = len(all_units) + n_periods = len(all_periods) + + Y = ( + data.pivot(index='period', columns='unit', values='outcome') + .reindex(index=all_periods, columns=all_units) + .values + ) + D = ( + data.pivot(index='period', columns='unit', values='treated') + .reindex(index=all_periods, columns=all_units) + .fillna(0) + .astype(int) + .values + ) + + control_mask = D == 0 + control_unit_idx = np.where(~np.any(D == 1, axis=0))[0] + + # Force empty control_obs by setting precomputed with empty list + trop_est._precomputed = { + "control_obs": [], # Empty! + "time_dist_matrix": np.abs(np.subtract.outer( + np.arange(n_periods), np.arange(n_periods) + )), + } + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + score = trop_est._loocv_score_obs_specific( + Y, D, control_mask, control_unit_idx, + 1.0, 1.0, 1.0, n_units, n_periods + ) + + # Should return infinity, not 0.0 + assert np.isinf(score), f"Empty control_obs should return inf, got {score}" + + # Should emit warning + warning_msgs = [str(warning.message) for warning in w] + assert any("No valid control observations" in msg for msg in warning_msgs), ( + f"Should warn about empty control obs. Warnings: {warning_msgs}" + ) + + def test_original_grid_values_stored_in_results(self, simple_panel_data): + """ + Test that TROPResults stores the original grid values, not converted ones. + + Per the design decision in PR #110 Round 7, results should store the + original grid values (possibly inf) so users can see what was selected, + while internally using converted values for computation. + """ + trop_est = TROP( + lambda_time_grid=[np.inf], # Original value: inf + lambda_unit_grid=[0.5], + lambda_nn_grid=[np.inf], # Original value: inf + n_bootstrap=5, + seed=42 + ) + + results = trop_est.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + # Results should store original grid values (inf) + assert np.isinf(results.lambda_time), ( + f"results.lambda_time should be inf (original), got {results.lambda_time}" + ) + assert results.lambda_unit == 0.5, ( + f"results.lambda_unit should be 0.5, got {results.lambda_unit}" + ) + assert np.isinf(results.lambda_nn), ( + f"results.lambda_nn should be inf (original), got {results.lambda_nn}" + ) + + # But ATT should still be finite (computed with converted values) + assert np.isfinite(results.att), "ATT should be finite" + + +class TestPR110FeedbackRound8: + """Tests for PR #110 feedback round 8 fixes. + + Issue 1: Final LOOCV score uses converted infinity values (not raw inf) + Issue 2: Rust LOOCV warnings include failed observation metadata + Issue 3: D matrix validation handles unbalanced panels correctly + """ + + def test_unbalanced_panel_d_matrix_validation(self): + """Test that unbalanced panels don't trigger spurious D matrix violations. + + Issue 3 fix: Missing unit-period observations should not be flagged + as violations. Only validate monotonicity between observed periods. + """ + # Create an unbalanced panel: unit 1 is missing period 5 + # Unit 1: treated from period 3 onwards, but missing period 5 + # This should NOT raise an error, because the 1→0 transition at period 5 + # is due to missing data, not a real violation. + data = [] + + # Unit 0: control, complete panel + for t in range(6): + data.append({ + "unit": 0, + "period": t, + "outcome": 10.0 + t, + "treated": 0, + }) + + # Unit 1: treated from t=3, missing t=5 (unbalanced) + for t in range(6): + if t == 5: + continue # Skip period 5 - creates unbalanced panel + treated = 1 if t >= 3 else 0 + data.append({ + "unit": 1, + "period": t, + "outcome": 10.0 + t + (2.0 if treated else 0), + "treated": treated, + }) + + # Unit 2: control, complete panel + for t in range(6): + data.append({ + "unit": 2, + "period": t, + "outcome": 10.0 + t, + "treated": 0, + }) + + df = pd.DataFrame(data) + + # This should NOT raise an error + trop_est = TROP( + lambda_time_grid=[0.0], + lambda_unit_grid=[0.0], + lambda_nn_grid=[0.0], + n_bootstrap=5, + seed=42 + ) + + # Should not raise ValueError - missing data is not a violation + try: + results = trop_est.fit( + df, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + # Basic sanity checks + assert results is not None + assert np.isfinite(results.att) + except ValueError as e: + if "absorbing state" in str(e): + pytest.fail( + f"Unbalanced panel incorrectly flagged as absorbing state violation: {e}" + ) + raise + + def test_unbalanced_panel_real_violation_still_caught(self): + """Test that real violations are still caught in unbalanced panels. + + Even with missing data, actual D→1→0 violations on observed periods + should still be detected and raise ValueError. + """ + data = [] + + # Unit 0: control, complete + for t in range(5): + data.append({ + "unit": 0, + "period": t, + "outcome": 10.0 + t, + "treated": 0, + }) + + # Unit 1: REAL violation - D goes 0→1→0 on observed periods (t=2: D=1, t=3: D=0) + # This is a real violation, not a missing data artifact + for t in range(5): + if t == 2: + treated = 1 + else: + treated = 0 + data.append({ + "unit": 1, + "period": t, + "outcome": 10.0 + t, + "treated": treated, + }) + + # Unit 2: control + for t in range(5): + data.append({ + "unit": 2, + "period": t, + "outcome": 10.0 + t, + "treated": 0, + }) + + df = pd.DataFrame(data) + + trop_est = TROP( + lambda_time_grid=[0.0], + lambda_unit_grid=[0.0], + lambda_nn_grid=[0.0], + n_bootstrap=5 + ) + + # This SHOULD raise an error - real violation + with pytest.raises(ValueError, match="absorbing state"): + trop_est.fit( + df, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + def test_unbalanced_panel_multiple_missing_periods(self): + """Test unbalanced panel with multiple missing periods per unit.""" + data = [] + + # Unit 0: control, complete + for t in range(8): + data.append({ + "unit": 0, + "period": t, + "outcome": 10.0 + t, + "treated": 0, + }) + + # Unit 1: treated from t=4, missing t=2 and t=6 + for t in range(8): + if t in [2, 6]: + continue # Skip these periods + treated = 1 if t >= 4 else 0 + data.append({ + "unit": 1, + "period": t, + "outcome": 10.0 + t + (2.0 if treated else 0), + "treated": treated, + }) + + # Unit 2: control, missing t=0 + for t in range(8): + if t == 0: + continue + data.append({ + "unit": 2, + "period": t, + "outcome": 10.0 + t, + "treated": 0, + }) + + df = pd.DataFrame(data) + + trop_est = TROP( + lambda_time_grid=[0.0], + lambda_unit_grid=[0.0], + lambda_nn_grid=[0.0], + n_bootstrap=5, + seed=42 + ) + + # Should not raise error + results = trop_est.fit( + df, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + assert results is not None + assert np.isfinite(results.att) + + def test_infinity_grid_values_with_final_score_computation(self, simple_panel_data): + """Test that infinity grid values are properly converted for final score. + + Issue 1 fix: When LOOCV selects infinity values from the grid, the + final score computation should use converted values (0.0 or 1e10), + not raw infinity. + """ + trop_est = TROP( + lambda_time_grid=[np.inf, 0.5], # inf should convert to 0.0 + lambda_unit_grid=[np.inf, 0.5], # inf should convert to 0.0 + lambda_nn_grid=[np.inf, 0.1], # inf should convert to 1e10 + n_bootstrap=5, + seed=42 + ) + + # This should complete without error, even if inf values are selected + results = trop_est.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + # ATT should be finite regardless of which grid values were selected + assert np.isfinite(results.att), "ATT should be finite with inf grid values" + assert results.se >= 0, "SE should be non-negative" + + # If inf values were selected, LOOCV score should still be computed correctly + # (using converted values, not raw inf) + if np.isinf(results.loocv_score): + # Infinite LOOCV score is acceptable (means fits failed) + # but ATT should still be finite (falls back to defaults) + pass + else: + assert np.isfinite(results.loocv_score), ( + "LOOCV score should be finite when computed with converted inf values" + ) + + def test_violation_across_missing_gap_caught(self): + """Test that 1→0 violations spanning missing periods are caught. + + Issue: If periods [3, 4] are missing and D[2]=1, D[5]=0, this is a + real violation that must be detected even though the adjacent + period transitions don't show it (the gap hides the transition). + + PR #110 round 10 fix: Check each unit's observed D sequence for + monotonicity, not just adjacent periods in the full time grid. + """ + data = [] + + # Unit 0: control, complete + for t in range(6): + data.append({"unit": 0, "period": t, "outcome": 10.0 + t, "treated": 0}) + + # Unit 1: VIOLATION across gap + # Observed at [0, 1, 2, 5], missing [3, 4] + # D[2]=1, D[5]=0 is a real violation spanning the gap + for t in [0, 1, 2, 5]: + treated = 1 if t == 2 else 0 # Only treated at period 2 + data.append({"unit": 1, "period": t, "outcome": 10.0 + t, "treated": treated}) + + # Unit 2: control, complete + for t in range(6): + data.append({"unit": 2, "period": t, "outcome": 10.0 + t, "treated": 0}) + + df = pd.DataFrame(data) + trop_est = TROP( + lambda_time_grid=[0.0], + lambda_unit_grid=[0.0], + lambda_nn_grid=[0.0], + n_bootstrap=5, + ) + + with pytest.raises(ValueError, match="absorbing state"): + trop_est.fit( + df, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + def test_n_post_periods_counts_observed_treatment(self): + """Test n_post_periods counts periods with actual D=1 observations. + + Per docstring: "Number of post-treatment periods (periods with D=1 observations)" + + This tests that n_post_periods reflects periods where treatment is + actually observed, not just calendar periods from first treatment. + """ + data = [] + + # Create panel where period 5 exists but has no D=1 observations + # (all treated units are missing at period 5) + for unit in range(3): + for period in range(6): + # Units 1, 2 are treated from period 3, but missing at period 5 + if unit in [1, 2] and period == 5: + continue # Skip - creates unbalanced panel + treated = 1 if (unit in [1, 2] and period >= 3) else 0 + data.append({ + "unit": unit, + "period": period, + "outcome": 10.0 + period, + "treated": treated, + }) + + df = pd.DataFrame(data) + trop_est = TROP( + lambda_time_grid=[0.0], + lambda_unit_grid=[0.0], + lambda_nn_grid=[0.0], + n_bootstrap=5, + seed=42, + ) + results = trop_est.fit( + df, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + # Periods with D=1 observations: 3, 4 (not 5 - missing for treated units) + assert results.n_post_periods == 2, ( + f"Expected 2 post-periods with D=1, got {results.n_post_periods}" + )