diff --git a/.gitignore b/.gitignore index d4d86f36..3811dc20 100644 --- a/.gitignore +++ b/.gitignore @@ -80,3 +80,6 @@ scripts/ # Launch directories (local only) launch/ launch-video/ + +# Reference implementations (local only) +trop_avg_ref/ diff --git a/CLAUDE.md b/CLAUDE.md index 4001f86f..cfe1930a 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -138,6 +138,9 @@ test bootstrap::tests::test_webb_mean_approx_zero ... ok - `TROPResults` - Results with ATT, factors, loadings, unit/time weights - `trop()` - Convenience function for quick estimation - Three robustness components: factor adjustment, unit weights, time weights + - Two estimation methods via `method` parameter: + - `"twostep"` (default): Per-observation model fitting (Algorithm 2 of paper) + - `"joint"`: Weighted least squares with homogeneous treatment effect (faster) - Automatic rank selection via cross-validation, information criterion, or elbow detection - Bootstrap and placebo-based variance estimation diff --git a/README.md b/README.md index 338fd6a7..d7fea1bd 100644 --- a/README.md +++ b/README.md @@ -1267,6 +1267,7 @@ trop = TROP( ```python TROP( + method='twostep', # Estimation method: 'twostep' (default) or 'joint' lambda_time_grid=None, # Time decay grid (default: [0, 0.1, 0.5, 1, 2, 5]) lambda_unit_grid=None, # Unit distance grid (default: [0, 0.1, 0.5, 1, 2, 5]) lambda_nn_grid=None, # Nuclear norm grid (default: [0, 0.01, 0.1, 1, 10]) @@ -1279,6 +1280,10 @@ TROP( ) ``` +**Estimation methods:** +- `'twostep'` (default): Per-observation model fitting following Algorithm 2 of the paper. Computes observation-specific weights and fits a model for each treated observation, then averages the individual treatment effects. More flexible but computationally intensive. +- `'joint'`: Joint weighted least squares optimization. Estimates a single scalar treatment effect τ along with fixed effects and optional low-rank factor adjustment. Faster but assumes homogeneous treatment effects. 
+ **Convenience function:** ```python diff --git a/diff_diff/_backend.py b/diff_diff/_backend.py index 6d22ead1..60bf853b 100644 --- a/diff_diff/_backend.py +++ b/diff_diff/_backend.py @@ -23,10 +23,13 @@ project_simplex as _rust_project_simplex, solve_ols as _rust_solve_ols, compute_robust_vcov as _rust_compute_robust_vcov, - # TROP estimator acceleration + # TROP estimator acceleration (twostep method) compute_unit_distance_matrix as _rust_unit_distance_matrix, loocv_grid_search as _rust_loocv_grid_search, bootstrap_trop_variance as _rust_bootstrap_trop_variance, + # TROP estimator acceleration (joint method) + loocv_grid_search_joint as _rust_loocv_grid_search_joint, + bootstrap_trop_variance_joint as _rust_bootstrap_trop_variance_joint, ) _rust_available = True except ImportError: @@ -36,10 +39,13 @@ _rust_project_simplex = None _rust_solve_ols = None _rust_compute_robust_vcov = None - # TROP estimator acceleration + # TROP estimator acceleration (twostep method) _rust_unit_distance_matrix = None _rust_loocv_grid_search = None _rust_bootstrap_trop_variance = None + # TROP estimator acceleration (joint method) + _rust_loocv_grid_search_joint = None + _rust_bootstrap_trop_variance_joint = None # Determine final backend based on environment variable and availability if _backend_env == 'python': @@ -50,10 +56,13 @@ _rust_project_simplex = None _rust_solve_ols = None _rust_compute_robust_vcov = None - # TROP estimator acceleration + # TROP estimator acceleration (twostep method) _rust_unit_distance_matrix = None _rust_loocv_grid_search = None _rust_bootstrap_trop_variance = None + # TROP estimator acceleration (joint method) + _rust_loocv_grid_search_joint = None + _rust_bootstrap_trop_variance_joint = None elif _backend_env == 'rust': # Force Rust mode - fail if not available if not _rust_available: @@ -73,8 +82,11 @@ '_rust_project_simplex', '_rust_solve_ols', '_rust_compute_robust_vcov', - # TROP estimator acceleration + # TROP estimator acceleration (twostep 
method) '_rust_unit_distance_matrix', '_rust_loocv_grid_search', '_rust_bootstrap_trop_variance', + # TROP estimator acceleration (joint method) + '_rust_loocv_grid_search_joint', + '_rust_bootstrap_trop_variance_joint', ] diff --git a/diff_diff/trop.py b/diff_diff/trop.py index dc71a692..a6e2fdca 100644 --- a/diff_diff/trop.py +++ b/diff_diff/trop.py @@ -38,6 +38,8 @@ _rust_unit_distance_matrix, _rust_loocv_grid_search, _rust_bootstrap_trop_variance, + _rust_loocv_grid_search_joint, + _rust_bootstrap_trop_variance_joint, ) from diff_diff.results import _get_significance_stars from diff_diff.utils import compute_confidence_interval, compute_p_value @@ -365,6 +367,20 @@ class TROP: Parameters ---------- + method : str, default='twostep' + Estimation method to use: + + - 'twostep': Per-observation model fitting following Algorithm 2 of + Athey et al. (2025). Computes observation-specific weights and fits + a model for each treated observation, averaging the individual + treatment effects. More flexible but computationally intensive. + + - 'joint': Joint weighted least squares optimization. Estimates a + single scalar treatment effect τ along with fixed effects and + optional low-rank factor adjustment. Faster but assumes homogeneous + treatment effects. Uses alternating minimization when nuclear norm + penalty is finite. + lambda_time_grid : list, optional Grid of time weight decay parameters. Default: [0, 0.1, 0.5, 1, 2, 5]. 
lambda_unit_grid : list, optional @@ -434,6 +450,7 @@ class TROP: def __init__( self, + method: str = "twostep", lambda_time_grid: Optional[List[float]] = None, lambda_unit_grid: Optional[List[float]] = None, lambda_nn_grid: Optional[List[float]] = None, @@ -445,6 +462,14 @@ def __init__( max_loocv_samples: int = 100, seed: Optional[int] = None, ): + # Validate method parameter + valid_methods = ("twostep", "joint") + if method not in valid_methods: + raise ValueError( + f"method must be one of {valid_methods}, got '{method}'" + ) + self.method = method + # Default grids from paper self.lambda_time_grid = lambda_time_grid or [0.0, 0.1, 0.5, 1.0, 2.0, 5.0] self.lambda_unit_grid = lambda_unit_grid or [0.0, 0.1, 0.5, 1.0, 2.0, 5.0] @@ -828,7 +853,368 @@ def _cycling_parameter_search( return lambda_time, lambda_unit, lambda_nn - def fit( + # ========================================================================= + # Joint estimation method + # ========================================================================= + + def _compute_joint_weights( + self, + Y: np.ndarray, + D: np.ndarray, + lambda_time: float, + lambda_unit: float, + treated_periods: int, + n_units: int, + n_periods: int, + ) -> np.ndarray: + """ + Compute distance-based weights for joint estimation. + + Following the reference implementation, weights are computed based on: + - Time distance: distance to center of treated block + - Unit distance: RMSE to average treated trajectory over pre-periods + + Parameters + ---------- + Y : np.ndarray + Outcome matrix (n_periods x n_units). + D : np.ndarray + Treatment indicator matrix (n_periods x n_units). + lambda_time : float + Time weight decay parameter. + lambda_unit : float + Unit weight decay parameter. + treated_periods : int + Number of post-treatment periods. + n_units : int + Number of units. + n_periods : int + Number of periods. + + Returns + ------- + np.ndarray + Weight matrix (n_periods x n_units). 
+ """ + # Identify treated units (ever treated) + treated_mask = np.any(D == 1, axis=0) + treated_unit_idx = np.where(treated_mask)[0] + + if len(treated_unit_idx) == 0: + raise ValueError("No treated units found") + + # Time weights: distance to center of treated block + # Following reference: center = T - treated_periods/2 + center = n_periods - treated_periods / 2.0 + dist_time = np.abs(np.arange(n_periods, dtype=float) - center) + delta_time = np.exp(-lambda_time * dist_time) + + # Unit weights: RMSE to average treated trajectory over pre-periods + # Compute average treated trajectory (use nanmean to handle NaN) + average_treated = np.nanmean(Y[:, treated_unit_idx], axis=1) + + # Pre-period mask: 1 in pre, 0 in post + pre_mask = np.ones(n_periods, dtype=float) + pre_mask[-treated_periods:] = 0.0 + + # Compute RMS distance for each unit + # dist_unit[i] = sqrt(sum_pre(avg_tr - Y_i)^2 / n_pre) + # Use NaN-safe operations: treat NaN differences as 0 (excluded) + diff = average_treated[:, np.newaxis] - Y + diff_sq = np.where(np.isfinite(diff), diff ** 2, 0.0) * pre_mask[:, np.newaxis] + + # Count valid observations per unit in pre-period + # Must check diff is finite (both Y and average_treated finite) + # to match the periods contributing to diff_sq + valid_count = np.sum( + np.isfinite(diff) * pre_mask[:, np.newaxis], axis=0 + ) + sum_sq = np.sum(diff_sq, axis=0) + n_pre = np.sum(pre_mask) + + if n_pre == 0: + raise ValueError("No pre-treatment periods") + + # Track units with no valid pre-period data + no_valid_pre = valid_count == 0 + + # Use valid count per unit (avoid division by zero for calculation) + valid_count_safe = np.maximum(valid_count, 1) + dist_unit = np.sqrt(sum_sq / valid_count_safe) + + # Units with no valid pre-period data get zero weight + # (dist is undefined, so we set it to inf -> delta_unit = exp(-inf) = 0) + delta_unit = np.exp(-lambda_unit * dist_unit) + delta_unit[no_valid_pre] = 0.0 + + # Outer product: (n_periods x n_units) + delta = 
np.outer(delta_time, delta_unit) + + return delta + + def _loocv_score_joint( + self, + Y: np.ndarray, + D: np.ndarray, + control_obs: List[Tuple[int, int]], + lambda_time: float, + lambda_unit: float, + lambda_nn: float, + treated_periods: int, + n_units: int, + n_periods: int, + ) -> float: + """ + Compute LOOCV score for joint method with specific parameter combination. + + Following paper's Equation 5: + Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² + + For joint method, we exclude each control observation, fit the joint model + on remaining data, and compute the pseudo-treatment effect at the excluded obs. + + Parameters + ---------- + Y : np.ndarray + Outcome matrix (n_periods x n_units). + D : np.ndarray + Treatment indicator matrix (n_periods x n_units). + control_obs : List[Tuple[int, int]] + List of (t, i) control observations for LOOCV. + lambda_time : float + Time weight decay parameter. + lambda_unit : float + Unit weight decay parameter. + lambda_nn : float + Nuclear norm regularization parameter. + treated_periods : int + Number of post-treatment periods. + n_units : int + Number of units. + n_periods : int + Number of periods. + + Returns + ------- + float + LOOCV score (sum of squared pseudo-treatment effects). 
+ """ + # Compute global weights (same for all LOOCV iterations) + delta = self._compute_joint_weights( + Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods + ) + + tau_sq_sum = 0.0 + n_valid = 0 + + for t_ex, i_ex in control_obs: + # Create modified delta with excluded observation zeroed out + delta_ex = delta.copy() + delta_ex[t_ex, i_ex] = 0.0 + + try: + # Fit joint model excluding this observation + if lambda_nn >= 1e10: + mu, alpha, beta, tau = self._solve_joint_no_lowrank(Y, D, delta_ex) + L = np.zeros((n_periods, n_units)) + else: + mu, alpha, beta, L, tau = self._solve_joint_with_lowrank( + Y, D, delta_ex, lambda_nn, self.max_iter, self.tol + ) + + # Pseudo treatment effect: τ = Y - μ - α - β - L + if np.isfinite(Y[t_ex, i_ex]): + tau_loocv = Y[t_ex, i_ex] - mu - alpha[i_ex] - beta[t_ex] - L[t_ex, i_ex] + tau_sq_sum += tau_loocv ** 2 + n_valid += 1 + + except (np.linalg.LinAlgError, ValueError): + # Any failure means this λ combination is invalid per Equation 5 + return np.inf + + if n_valid == 0: + return np.inf + + return tau_sq_sum + + def _solve_joint_no_lowrank( + self, + Y: np.ndarray, + D: np.ndarray, + delta: np.ndarray, + ) -> Tuple[float, np.ndarray, np.ndarray, float]: + """ + Solve joint TWFE + treatment via weighted least squares (no low-rank). + + Solves: min Σ δ_{it}(Y_{it} - μ - α_i - β_t - τ*W_{it})² + + Parameters + ---------- + Y : np.ndarray + Outcome matrix (n_periods x n_units). + D : np.ndarray + Treatment indicator matrix (n_periods x n_units). + delta : np.ndarray + Weight matrix (n_periods x n_units). + + Returns + ------- + Tuple[float, np.ndarray, np.ndarray, float] + (mu, alpha, beta, tau) estimated parameters. 
+ """ + n_periods, n_units = Y.shape + + # Flatten matrices for regression + y = Y.flatten() # length n_periods * n_units + w = D.flatten() + weights = delta.flatten() + + # Handle NaN values: zero weight for NaN outcomes/weights, impute with 0 + # This ensures NaN observations don't contribute to estimation + valid_y = np.isfinite(y) + valid_w = np.isfinite(weights) + valid_mask = valid_y & valid_w + weights = np.where(valid_mask, weights, 0.0) + y = np.where(valid_mask, y, 0.0) + + sqrt_weights = np.sqrt(np.maximum(weights, 0)) + + # Check for all-zero weights (matches Rust's sum_w < 1e-10 check) + sum_w = np.sum(weights) + if sum_w < 1e-10: + raise ValueError("All weights are zero - cannot estimate") + + # Build design matrix: [intercept, unit_dummies, time_dummies, treatment] + # Total columns: 1 + n_units + n_periods + 1 + # But we need to drop one unit and one time dummy for identification + # Drop first unit (unit 0) and first time (time 0) + n_obs = n_periods * n_units + n_params = 1 + (n_units - 1) + (n_periods - 1) + 1 + + X = np.zeros((n_obs, n_params)) + X[:, 0] = 1.0 # intercept + + # Unit dummies (skip unit 0) + for i in range(1, n_units): + for t in range(n_periods): + X[t * n_units + i, i] = 1.0 + + # Time dummies (skip time 0) + for t in range(1, n_periods): + for i in range(n_units): + X[t * n_units + i, (n_units - 1) + t] = 1.0 + + # Treatment indicator + X[:, -1] = w + + # Apply weights + X_weighted = X * sqrt_weights[:, np.newaxis] + y_weighted = y * sqrt_weights + + # Solve weighted least squares + try: + coeffs, _, _, _ = np.linalg.lstsq(X_weighted, y_weighted, rcond=None) + except np.linalg.LinAlgError: + # Fallback: use pseudo-inverse + coeffs = np.linalg.pinv(X_weighted) @ y_weighted + + # Extract parameters + mu = coeffs[0] + alpha = np.zeros(n_units) + alpha[1:] = coeffs[1:n_units] + beta = np.zeros(n_periods) + beta[1:] = coeffs[n_units:(n_units + n_periods - 1)] + tau = coeffs[-1] + + return float(mu), alpha, beta, float(tau) + + def 
_solve_joint_with_lowrank( + self, + Y: np.ndarray, + D: np.ndarray, + delta: np.ndarray, + lambda_nn: float, + max_iter: int = 100, + tol: float = 1e-6, + ) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray, float]: + """ + Solve joint TWFE + treatment + low-rank via alternating minimization. + + Solves: min Σ δ_{it}(Y_{it} - μ - α_i - β_t - L_{it} - τ*W_{it})² + λ_nn||L||_* + + Parameters + ---------- + Y : np.ndarray + Outcome matrix (n_periods x n_units). + D : np.ndarray + Treatment indicator matrix (n_periods x n_units). + delta : np.ndarray + Weight matrix (n_periods x n_units). + lambda_nn : float + Nuclear norm regularization parameter. + max_iter : int, default=100 + Maximum iterations for alternating minimization. + tol : float, default=1e-6 + Convergence tolerance. + + Returns + ------- + Tuple[float, np.ndarray, np.ndarray, np.ndarray, float] + (mu, alpha, beta, L, tau) estimated parameters. + """ + n_periods, n_units = Y.shape + + # Handle NaN values: impute with 0 for computations + # The solver will also zero weights for NaN observations + Y_safe = np.where(np.isfinite(Y), Y, 0.0) + + # Mask delta to exclude NaN outcomes from estimation + # This ensures NaN observations don't contribute to the gradient step + nan_mask = ~np.isfinite(Y) + delta_masked = delta.copy() + delta_masked[nan_mask] = 0.0 + + # Initialize L = 0 + L = np.zeros((n_periods, n_units)) + + for iteration in range(max_iter): + L_old = L.copy() + + # Step 1: Fix L, solve for (mu, alpha, beta, tau) + # Adjusted outcome: Y - L (using NaN-safe Y) + # Pass masked delta to exclude NaN observations from WLS + Y_adj = Y_safe - L + mu, alpha, beta, tau = self._solve_joint_no_lowrank(Y_adj, D, delta_masked) + + # Step 2: Fix (mu, alpha, beta, tau), update L + # Residual: R = Y - mu - alpha - beta - tau*D (using NaN-safe Y) + R = Y_safe - mu - alpha[np.newaxis, :] - beta[:, np.newaxis] - tau * D + + # Weighted proximal step for L (soft-threshold SVD) + # Normalize weights (using masked delta 
to exclude NaN observations) + delta_max = np.max(delta_masked) + if delta_max > 0: + delta_norm = delta_masked / delta_max + else: + delta_norm = delta_masked + + # Weighted average between current L and target R + # L_next = L + delta_norm * (R - L), then soft-threshold + # NaN observations have delta_norm=0, so they don't influence L update + gradient_step = L + delta_norm * (R - L) + + # Soft-threshold singular values + # Use eta * lambda_nn for proper proximal step size (matches Rust) + eta = 1.0 / delta_max if delta_max > 0 else 1.0 + L = self._soft_threshold_svd(gradient_step, eta * lambda_nn) + + # Check convergence + if np.max(np.abs(L - L_old)) < tol: + break + + return mu, alpha, beta, L, tau + + def _fit_joint( self, data: pd.DataFrame, outcome: str, @@ -837,49 +1223,664 @@ def fit( time: str, ) -> TROPResults: """ - Fit the TROP model. + Fit TROP using joint weighted least squares method. + + This method estimates a single scalar treatment effect τ along with + fixed effects and optional low-rank factor adjustment. Parameters ---------- data : pd.DataFrame - Panel data with observations for multiple units over multiple - time periods. + Panel data. outcome : str - Name of the outcome variable column. + Outcome variable column name. treatment : str - Name of the treatment indicator column (0/1). - - IMPORTANT: This should be an ABSORBING STATE indicator, not a - treatment timing indicator. For each unit, D=1 for ALL periods - during and after treatment: - - - D[t, i] = 0 for all t < g_i (pre-treatment periods) - - D[t, i] = 1 for all t >= g_i (treatment and post-treatment) - - where g_i is the treatment start time for unit i. - - For staggered adoption, different units can have different g_i. - The ATT averages over ALL D=1 cells per Equation 1 of the paper. + Treatment indicator column name. unit : str - Name of the unit identifier column. + Unit identifier column name. time : str - Name of the time period column. + Time period column name. 
Returns ------- TROPResults - Object containing the ATT estimate, standard error, - factor estimates, and tuning parameters. The lambda_* - attributes show the selected grid values. Infinity values - (∞) are converted internally: λ_time/λ_unit=∞ → 0.0 (uniform - weights), λ_nn=∞ → 1e10 (factor model disabled). + Estimation results. + + Notes + ----- + Bootstrap and jackknife variance estimation assume simultaneous treatment + adoption (fixed `treated_periods` across resamples). The treatment timing + is inferred from the data once and held constant for all bootstrap/jackknife + iterations. For staggered adoption designs where treatment timing varies + across units, use `method="twostep"` which computes observation-specific + weights that naturally handle heterogeneous timing. """ - # Validate inputs - required_cols = [outcome, treatment, unit, time] - missing = [c for c in required_cols if c not in data.columns] - if missing: - raise ValueError(f"Missing columns: {missing}") + # Data setup (same as twostep method) + all_units = sorted(data[unit].unique()) + all_periods = sorted(data[time].unique()) + + n_units = len(all_units) + n_periods = len(all_periods) + + idx_to_unit = {i: u for i, u in enumerate(all_units)} + idx_to_period = {i: p for i, p in enumerate(all_periods)} + + # Create matrices + Y = ( + data.pivot(index=time, columns=unit, values=outcome) + .reindex(index=all_periods, columns=all_units) + .values + ) + + D_raw = ( + data.pivot(index=time, columns=unit, values=treatment) + .reindex(index=all_periods, columns=all_units) + ) + missing_mask = pd.isna(D_raw).values + D = D_raw.fillna(0).astype(int).values + + # Validate absorbing state + violating_units = [] + for unit_idx in range(n_units): + observed_mask = ~missing_mask[:, unit_idx] + observed_d = D[observed_mask, unit_idx] + if len(observed_d) > 1 and np.any(np.diff(observed_d) < 0): + violating_units.append(all_units[unit_idx]) + + if violating_units: + raise ValueError( + f"Treatment indicator is 
not an absorbing state for units: {violating_units}. " + f"D[t, unit] must be monotonic non-decreasing." + ) + + # Identify treated observations + treated_mask = D == 1 + n_treated_obs = np.sum(treated_mask) + + if n_treated_obs == 0: + raise ValueError("No treated observations found") + + # Identify treated and control units + unit_ever_treated = np.any(D == 1, axis=0) + treated_unit_idx = np.where(unit_ever_treated)[0] + control_unit_idx = np.where(~unit_ever_treated)[0] + + if len(control_unit_idx) == 0: + raise ValueError("No control units found") + + # Determine pre/post periods + first_treat_period = None + for t in range(n_periods): + if np.any(D[t, :] == 1): + first_treat_period = t + break + + if first_treat_period is None: + raise ValueError("Could not infer post-treatment periods from D matrix") + + n_pre_periods = first_treat_period + treated_periods = n_periods - first_treat_period + n_post_periods = int(np.sum(np.any(D[first_treat_period:, :] == 1, axis=1))) + + if n_pre_periods < 2: + raise ValueError("Need at least 2 pre-treatment periods") + + # Check for staggered adoption (joint method requires simultaneous treatment) + # Use only observed periods (skip missing) to avoid false positives on unbalanced panels + first_treat_by_unit = [] + for i in treated_unit_idx: + observed_mask = ~missing_mask[:, i] + # Get D values for observed periods only + observed_d = D[observed_mask, i] + observed_periods = np.where(observed_mask)[0] + # Find first treatment among observed periods + treated_idx = np.where(observed_d == 1)[0] + if len(treated_idx) > 0: + first_treat_by_unit.append(observed_periods[treated_idx[0]]) + + unique_starts = sorted(set(first_treat_by_unit)) + if len(unique_starts) > 1: + raise ValueError( + f"method='joint' requires simultaneous treatment adoption, but your data " + f"shows staggered adoption (units first treated at periods {unique_starts}). " + f"Use method='twostep' which properly handles staggered adoption designs." 
+ ) + + # LOOCV grid search for tuning parameters + # Use Rust backend when available for parallel LOOCV (5-10x speedup) + best_lambda = None + best_score = np.inf + control_mask = D == 0 + + if HAS_RUST_BACKEND and _rust_loocv_grid_search_joint is not None: + try: + # Prepare inputs for Rust function + control_mask_u8 = control_mask.astype(np.uint8) + + lambda_time_arr = np.array(self.lambda_time_grid, dtype=np.float64) + lambda_unit_arr = np.array(self.lambda_unit_grid, dtype=np.float64) + lambda_nn_arr = np.array(self.lambda_nn_grid, dtype=np.float64) + + result = _rust_loocv_grid_search_joint( + Y, D.astype(np.float64), control_mask_u8, + lambda_time_arr, lambda_unit_arr, lambda_nn_arr, + self.max_loocv_samples, self.max_iter, self.tol, + self.seed if self.seed is not None else 0 + ) + # Unpack result - 7 values including optional first_failed_obs + best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed_obs = result + # Only accept finite scores - infinite means all fits failed + if np.isfinite(best_score): + best_lambda = (best_lt, best_lu, best_ln) + # Emit warnings consistent with Python implementation + if n_valid == 0: + obs_info = "" + if first_failed_obs is not None: + t_idx, i_idx = first_failed_obs + obs_info = f" First failure at observation ({t_idx}, {i_idx})." + warnings.warn( + f"LOOCV: All {n_attempted} fits failed for " + f"λ=({best_lt}, {best_lu}, {best_ln}). " + f"Returning infinite score.{obs_info}", + UserWarning + ) + elif n_attempted > 0 and (n_attempted - n_valid) > 0.1 * n_attempted: + n_failed = n_attempted - n_valid + obs_info = "" + if first_failed_obs is not None: + t_idx, i_idx = first_failed_obs + obs_info = f" First failure at observation ({t_idx}, {i_idx})." + warnings.warn( + f"LOOCV: {n_failed}/{n_attempted} fits failed for " + f"λ=({best_lt}, {best_lu}, {best_ln}). 
" + f"This may indicate numerical instability.{obs_info}", + UserWarning + ) + except Exception as e: + # Fall back to Python implementation on error + logger.debug( + "Rust LOOCV grid search (joint) failed, falling back to Python: %s", e + ) + best_lambda = None + best_score = np.inf + + # Fall back to Python implementation if Rust unavailable or failed + if best_lambda is None: + # Get control observations for LOOCV + control_obs = [ + (t, i) for t in range(n_periods) for i in range(n_units) + if control_mask[t, i] and not np.isnan(Y[t, i]) + ] + + # Subsample if needed (sample indices to avoid ValueError on list of tuples) + rng = np.random.default_rng(self.seed) + max_loocv = min(self.max_loocv_samples, len(control_obs)) + if len(control_obs) > max_loocv: + indices = rng.choice(len(control_obs), size=max_loocv, replace=False) + control_obs = [control_obs[idx] for idx in indices] + + # Grid search with true LOOCV + for lambda_time_val in self.lambda_time_grid: + for lambda_unit_val in self.lambda_unit_grid: + for lambda_nn_val in self.lambda_nn_grid: + # Convert infinity values + lt = 0.0 if np.isinf(lambda_time_val) else lambda_time_val + lu = 0.0 if np.isinf(lambda_unit_val) else lambda_unit_val + ln = 1e10 if np.isinf(lambda_nn_val) else lambda_nn_val + + try: + score = self._loocv_score_joint( + Y, D, control_obs, lt, lu, ln, + treated_periods, n_units, n_periods + ) + + if score < best_score: + best_score = score + best_lambda = (lambda_time_val, lambda_unit_val, lambda_nn_val) + + except (np.linalg.LinAlgError, ValueError): + continue + + if best_lambda is None: + warnings.warn( + "All tuning parameter combinations failed. 
Using defaults.", + UserWarning + ) + best_lambda = (1.0, 1.0, 0.1) + best_score = np.nan + + # Final estimation with best parameters + lambda_time, lambda_unit, lambda_nn = best_lambda + original_lambda_time, original_lambda_unit, original_lambda_nn = best_lambda + + # Convert infinity values for computation + if np.isinf(lambda_time): + lambda_time = 0.0 + if np.isinf(lambda_unit): + lambda_unit = 0.0 + if np.isinf(lambda_nn): + lambda_nn = 1e10 + + # Compute final weights and fit + delta = self._compute_joint_weights( + Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods + ) + + if lambda_nn >= 1e10: + mu, alpha, beta, tau = self._solve_joint_no_lowrank(Y, D, delta) + L = np.zeros((n_periods, n_units)) + else: + mu, alpha, beta, L, tau = self._solve_joint_with_lowrank( + Y, D, delta, lambda_nn, self.max_iter, self.tol + ) + + # ATT is the scalar treatment effect + att = tau + + # Compute individual treatment effects for reporting (same τ for all) + treatment_effects = {} + for t in range(n_periods): + for i in range(n_units): + if D[t, i] == 1: + unit_id = idx_to_unit[i] + time_id = idx_to_period[t] + treatment_effects[(unit_id, time_id)] = tau + + # Compute effective rank of L + _, s, _ = np.linalg.svd(L, full_matrices=False) + if s[0] > 0: + effective_rank = np.sum(s) / s[0] + else: + effective_rank = 0.0 + + # Bootstrap variance estimation + effective_lambda = (lambda_time, lambda_unit, lambda_nn) + + if self.variance_method == "bootstrap": + se, bootstrap_dist = self._bootstrap_variance_joint( + data, outcome, treatment, unit, time, + effective_lambda, treated_periods + ) + else: + # Jackknife for joint method + se, bootstrap_dist = self._jackknife_variance_joint( + Y, D, effective_lambda, treated_periods, + n_units, n_periods + ) + + # Compute test statistics + if se > 0: + t_stat = att / se + p_value = 2 * (1 - stats.t.cdf(abs(t_stat), df=max(1, n_treated_obs - 1))) + conf_int = compute_confidence_interval(att, se, self.alpha) + else: + 
t_stat = np.nan + p_value = np.nan + conf_int = (np.nan, np.nan) + + # Create results dictionaries + unit_effects_dict = {idx_to_unit[i]: alpha[i] for i in range(n_units)} + time_effects_dict = {idx_to_period[t]: beta[t] for t in range(n_periods)} + + self.results_ = TROPResults( + att=float(att), + se=float(se), + t_stat=float(t_stat) if np.isfinite(t_stat) else t_stat, + p_value=float(p_value) if np.isfinite(p_value) else p_value, + conf_int=conf_int, + n_obs=len(data), + n_treated=len(treated_unit_idx), + n_control=len(control_unit_idx), + n_treated_obs=int(n_treated_obs), + unit_effects=unit_effects_dict, + time_effects=time_effects_dict, + treatment_effects=treatment_effects, + lambda_time=original_lambda_time, + lambda_unit=original_lambda_unit, + lambda_nn=original_lambda_nn, + factor_matrix=L, + effective_rank=effective_rank, + loocv_score=best_score, + variance_method=self.variance_method, + alpha=self.alpha, + n_pre_periods=n_pre_periods, + n_post_periods=n_post_periods, + n_bootstrap=self.n_bootstrap if self.variance_method == "bootstrap" else None, + bootstrap_distribution=bootstrap_dist if len(bootstrap_dist) > 0 else None, + ) + + self.is_fitted_ = True + return self.results_ + + def _bootstrap_variance_joint( + self, + data: pd.DataFrame, + outcome: str, + treatment: str, + unit: str, + time: str, + optimal_lambda: Tuple[float, float, float], + treated_periods: int, + ) -> Tuple[float, np.ndarray]: + """ + Compute bootstrap standard error for joint method. + + Uses Rust backend when available for parallel bootstrap (5-15x speedup). + + Parameters + ---------- + data : pd.DataFrame + Original data. + outcome : str + Outcome column name. + treatment : str + Treatment column name. + unit : str + Unit column name. + time : str + Time column name. + optimal_lambda : tuple + Optimal tuning parameters. + treated_periods : int + Number of post-treatment periods. + + Returns + ------- + Tuple[float, np.ndarray] + (se, bootstrap_estimates). 
+ """ + lambda_time, lambda_unit, lambda_nn = optimal_lambda + + # Try Rust backend for parallel bootstrap (5-15x speedup) + if HAS_RUST_BACKEND and _rust_bootstrap_trop_variance_joint is not None: + try: + # Create matrices for Rust function + all_units = sorted(data[unit].unique()) + all_periods = sorted(data[time].unique()) + + Y = ( + data.pivot(index=time, columns=unit, values=outcome) + .reindex(index=all_periods, columns=all_units) + .values + ) + D = ( + data.pivot(index=time, columns=unit, values=treatment) + .reindex(index=all_periods, columns=all_units) + .fillna(0) + .astype(np.float64) + .values + ) + + bootstrap_estimates, se = _rust_bootstrap_trop_variance_joint( + Y, D, + lambda_time, lambda_unit, lambda_nn, + self.n_bootstrap, self.max_iter, self.tol, + self.seed if self.seed is not None else 0 + ) + + if len(bootstrap_estimates) < 10: + warnings.warn( + f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.", + UserWarning + ) + if len(bootstrap_estimates) == 0: + return 0.0, np.array([]) + + return float(se), np.array(bootstrap_estimates) + + except Exception as e: + logger.debug( + "Rust bootstrap (joint) failed, falling back to Python: %s", e + ) + + # Python fallback implementation + rng = np.random.default_rng(self.seed) + + # Stratified bootstrap sampling + unit_ever_treated = data.groupby(unit)[treatment].max() + treated_units = np.array(unit_ever_treated[unit_ever_treated == 1].index.tolist()) + control_units = np.array(unit_ever_treated[unit_ever_treated == 0].index.tolist()) + + n_treated_units = len(treated_units) + n_control_units = len(control_units) + + bootstrap_estimates_list: List[float] = [] + + for _ in range(self.n_bootstrap): + # Stratified sampling + if n_control_units > 0: + sampled_control = rng.choice( + control_units, size=n_control_units, replace=True + ) + else: + sampled_control = np.array([], dtype=object) + + if n_treated_units > 0: + sampled_treated = rng.choice( + treated_units, size=n_treated_units, 
replace=True + ) + else: + sampled_treated = np.array([], dtype=object) + + sampled_units = np.concatenate([sampled_control, sampled_treated]) + + # Create bootstrap sample + boot_data = pd.concat([ + data[data[unit] == u].assign(**{unit: f"{u}_{idx}"}) + for idx, u in enumerate(sampled_units) + ], ignore_index=True) + + try: + tau = self._fit_joint_with_fixed_lambda( + boot_data, outcome, treatment, unit, time, + optimal_lambda, treated_periods + ) + bootstrap_estimates_list.append(tau) + except (ValueError, np.linalg.LinAlgError, KeyError): + continue + + bootstrap_estimates = np.array(bootstrap_estimates_list) + + if len(bootstrap_estimates) < 10: + warnings.warn( + f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.", + UserWarning + ) + if len(bootstrap_estimates) == 0: + return 0.0, np.array([]) + + se = np.std(bootstrap_estimates, ddof=1) + return float(se), bootstrap_estimates + + def _fit_joint_with_fixed_lambda( + self, + data: pd.DataFrame, + outcome: str, + treatment: str, + unit: str, + time: str, + fixed_lambda: Tuple[float, float, float], + treated_periods: int, + ) -> float: + """ + Fit joint model with fixed tuning parameters. + + Returns only the treatment effect τ. 
+ """ + lambda_time, lambda_unit, lambda_nn = fixed_lambda + + all_units = sorted(data[unit].unique()) + all_periods = sorted(data[time].unique()) + + n_units = len(all_units) + n_periods = len(all_periods) + + Y = ( + data.pivot(index=time, columns=unit, values=outcome) + .reindex(index=all_periods, columns=all_units) + .values + ) + D = ( + data.pivot(index=time, columns=unit, values=treatment) + .reindex(index=all_periods, columns=all_units) + .fillna(0) + .astype(int) + .values + ) + + # Compute weights + delta = self._compute_joint_weights( + Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods + ) + + # Fit model + if lambda_nn >= 1e10: + _, _, _, tau = self._solve_joint_no_lowrank(Y, D, delta) + else: + _, _, _, _, tau = self._solve_joint_with_lowrank( + Y, D, delta, lambda_nn, self.max_iter, self.tol + ) + + return tau + + def _jackknife_variance_joint( + self, + Y: np.ndarray, + D: np.ndarray, + optimal_lambda: Tuple[float, float, float], + treated_periods: int, + n_units: int, + n_periods: int, + ) -> Tuple[float, np.ndarray]: + """ + Compute jackknife standard error for joint method. + + Parameters + ---------- + Y : np.ndarray + Outcome matrix. + D : np.ndarray + Treatment matrix. + optimal_lambda : tuple + Optimal tuning parameters. + treated_periods : int + Number of post-treatment periods. + n_units : int + Number of units. + n_periods : int + Number of periods. + + Returns + ------- + Tuple[float, np.ndarray] + (se, jackknife_estimates). 
+ """ + lambda_time, lambda_unit, lambda_nn = optimal_lambda + jackknife_estimates = [] + + # Get treated unit indices + treated_unit_idx = np.where(np.any(D == 1, axis=0))[0] + + for leave_out in treated_unit_idx: + # True leave-one-out: zero the delta weight for the left-out unit + # This excludes the unit from estimation without imputation + Y_jack = Y.copy() + D_jack = D.copy() + D_jack[:, leave_out] = 0 # Mark as not treated for weight computation + + try: + # Compute weights (left-out unit is still in calculation) + delta = self._compute_joint_weights( + Y_jack, D_jack, lambda_time, lambda_unit, + treated_periods, n_units, n_periods + ) + + # Zero the delta weight for the left-out unit + # This ensures the unit doesn't contribute to estimation + delta[:, leave_out] = 0.0 + + # Fit model (left-out unit has zero weight, truly excluded) + if lambda_nn >= 1e10: + _, _, _, tau = self._solve_joint_no_lowrank(Y_jack, D_jack, delta) + else: + _, _, _, _, tau = self._solve_joint_with_lowrank( + Y_jack, D_jack, delta, lambda_nn, self.max_iter, self.tol + ) + + jackknife_estimates.append(tau) + + except (np.linalg.LinAlgError, ValueError): + continue + + jackknife_estimates = np.array(jackknife_estimates) + + if len(jackknife_estimates) < 2: + return 0.0, jackknife_estimates + + # Jackknife SE formula + n = len(jackknife_estimates) + mean_est = np.mean(jackknife_estimates) + se = np.sqrt((n - 1) / n * np.sum((jackknife_estimates - mean_est) ** 2)) + + return float(se), jackknife_estimates + + def fit( + self, + data: pd.DataFrame, + outcome: str, + treatment: str, + unit: str, + time: str, + ) -> TROPResults: + """ + Fit the TROP model. + + Parameters + ---------- + data : pd.DataFrame + Panel data with observations for multiple units over multiple + time periods. + outcome : str + Name of the outcome variable column. + treatment : str + Name of the treatment indicator column (0/1). 
+ + IMPORTANT: This should be an ABSORBING STATE indicator, not a + treatment timing indicator. For each unit, D=1 for ALL periods + during and after treatment: + + - D[t, i] = 0 for all t < g_i (pre-treatment periods) + - D[t, i] = 1 for all t >= g_i (treatment and post-treatment) + + where g_i is the treatment start time for unit i. + + For staggered adoption, different units can have different g_i. + The ATT averages over ALL D=1 cells per Equation 1 of the paper. + unit : str + Name of the unit identifier column. + time : str + Name of the time period column. + + Returns + ------- + TROPResults + Object containing the ATT estimate, standard error, + factor estimates, and tuning parameters. The lambda_* + attributes show the selected grid values. Infinity values + (∞) are converted internally: λ_time/λ_unit=∞ → 0.0 (uniform + weights), λ_nn=∞ → 1e10 (factor model disabled). + """ + # Validate inputs + required_cols = [outcome, treatment, unit, time] + missing = [c for c in required_cols if c not in data.columns] + if missing: + raise ValueError(f"Missing columns: {missing}") + + # Dispatch based on estimation method + if self.method == "joint": + return self._fit_joint(data, outcome, treatment, unit, time) + # Below is the twostep method (default) # Get unique units and periods all_units = sorted(data[unit].unique()) all_periods = sorted(data[time].unique()) @@ -2079,6 +3080,7 @@ def _fit_with_fixed_lambda( def get_params(self) -> Dict[str, Any]: """Get estimator parameters.""" return { + "method": self.method, "lambda_time_grid": self.lambda_time_grid, "lambda_unit_grid": self.lambda_unit_grid, "lambda_nn_grid": self.lambda_nn_grid, diff --git a/docs/api/trop.rst b/docs/api/trop.rst index 732dbcfa..e359b41b 100644 --- a/docs/api/trop.rst +++ b/docs/api/trop.rst @@ -95,10 +95,14 @@ TROP uses leave-one-out cross-validation (LOOCV) to select three tuning paramete - Nuclear norm penalty - Higher values encourage lower-rank factor structure -Algorithm ---------- 
+Estimation Methods +------------------ -TROP follows Algorithm 2 from the paper: +TROP supports two estimation methods via the ``method`` parameter: + +**Two-Step Method** (``method='twostep'``, default) + +The default method follows Algorithm 2 from the paper: 1. **Grid search with LOOCV**: For each (λ_time, λ_unit, λ_nn) combination, compute cross-validation score by treating control observations as pseudo-treated @@ -111,10 +115,54 @@ TROP follows Algorithm 2 from the paper: 3. **Average**: ATT = mean(τ̂_{it}) over all treated observations -This structure provides the **triple robustness** property (Theorem 5.1): +This provides the **triple robustness** property (Theorem 5.1): the estimator is consistent if any one of the three components (unit weights, time weights, factor model) is correctly specified. +**Joint Method** (``method='joint'``) + +An alternative approach that estimates a single scalar treatment effect: + +1. **Compute weights**: Distance-based unit and time weights computed once + (distance to center of treated block, RMSE to average treated trajectory) + +2. **Joint optimization**: Solve weighted least squares problem + + .. math:: + + \min_{\mu, \alpha, \beta, L, \tau} \sum_{i,t} \delta_{it} (Y_{it} - \mu - \alpha_i - \beta_t - L_{it} - W_{it} \tau)^2 + \lambda_{nn} \|L\|_* + + where τ is a **single scalar** (homogeneous treatment effect). + +3. **With low-rank** (finite λ_nn): Uses alternating minimization between + weighted LS for (μ, α, β, τ) and soft-threshold SVD for L. + +The joint method is **faster** (single optimization vs N_treated optimizations) +but assumes **homogeneous treatment effects** across all treated observations. + +.. 
list-table:: + :header-rows: 1 + :widths: 20 40 40 + + * - Feature + - Two-Step (default) + - Joint + * - Treatment effect + - Per-observation τ_{it} + - Single scalar τ + * - Flexibility + - Heterogeneous effects + - Homogeneous assumption + * - Speed + - Slower (N_treated fits) + - Faster (single fit) + * - Weights + - Observation-specific + - Global (center of treated block) + +Use ``method='twostep'`` when treatment effects may vary across observations. +Use ``method='joint'`` for faster estimation when effects are expected to be homogeneous. + Example Usage ------------- @@ -155,6 +203,28 @@ Quick estimation with convenience function:: n_bootstrap=200 ) +Using the joint method for faster estimation:: + + from diff_diff import TROP + + # Joint method: single scalar treatment effect via weighted LS + trop_joint = TROP( + method='joint', # Use joint weighted least squares + lambda_time_grid=[0.0, 0.5, 1.0, 2.0], + lambda_unit_grid=[0.0, 0.5, 1.0, 2.0], + lambda_nn_grid=[0.0, 0.1, 1.0], + n_bootstrap=200, + seed=42 + ) + results_joint = trop_joint.fit(data, outcome='y', treatment='treated', + unit='unit_id', time='period') + + # Compare methods + trop_twostep = TROP(method='twostep', ...) # Default + results_twostep = trop_twostep.fit(data, ...) + print(f"Two-step ATT: {results_twostep.att:.3f}") + print(f"Joint ATT: {results_joint.att:.3f}") + Examining factor structure:: # Get the estimated factor matrix diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 6f779c7b..36febbc9 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -571,6 +571,87 @@ Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² - [x] D matrix semantics documented (absorbing state, not event indicator) - [x] Unbalanced panels supported (missing observations don't trigger false violations) +### TROP Joint Optimization Method + +**Method**: `method="joint"` in TROP estimator + +**Approach**: Joint weighted least squares with optional nuclear norm penalty. 
+Estimates fixed effects, factor matrix, and scalar treatment effect simultaneously. + +**Objective function** (Equation J1): +``` +min_{μ, α, β, L, τ} Σ_{i,t} δ_{it} × (Y_{it} - μ - α_i - β_t - L_{it} - W_{it}×τ)² + λ_nn×||L||_* +``` + +where: +- δ_{it} = δ_time(t) × δ_unit(i) are observation weights (product of time and unit weights) +- μ is the intercept +- α_i are unit fixed effects +- β_t are time fixed effects +- L_{it} is the low-rank factor component +- τ is a **single scalar** (homogeneous treatment effect assumption) +- W_{it} is the treatment indicator + +**Weight computation** (differs from twostep): +- Time weights: δ_time(t) = exp(-λ_time × |t - center|) where center = T - treated_periods/2 +- Unit weights: δ_unit(i) = exp(-λ_unit × RMSE(i, treated_avg)) + where RMSE is computed over pre-treatment periods comparing to average treated trajectory + +**Implementation approach** (without CVXPY): + +1. **Without low-rank (λ_nn = ∞)**: Standard weighted least squares + - Build design matrix with unit/time dummies + treatment indicator + - Solve via iterative coordinate descent for (μ, α, β, τ) + +2. **With low-rank (finite λ_nn)**: Alternating minimization + - Alternate between: + - Fix L, solve weighted LS for (μ, α, β, τ) + - Fix (μ, α, β, τ), soft-threshold SVD for L (proximal step) + - Continue until convergence + +**LOOCV parameter selection** (unified with twostep, Equation 5): +Following paper's Equation 5 and footnote 2: +``` +Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² +``` +where τ̂_js^loocv is the pseudo-treatment effect at control observation (j,s) +with that observation excluded from fitting. + +For joint method, LOOCV works as follows: +1. For each control observation (t, i): + - Zero out weight δ_{ti} = 0 (exclude from weighted objective) + - Fit joint model on remaining data → obtain (μ̂, α̂, β̂, L̂) + - Compute pseudo-treatment: τ̂_{ti} = Y_{ti} - μ̂ - α̂_i - β̂_t - L̂_{ti} +2. Score = Σ τ̂_{ti}² (sum of squared pseudo-treatment effects) +3. 
Select λ combination that minimizes Q(λ) + +**Rust acceleration**: The LOOCV grid search is parallelized in Rust for 5-10x speedup. +- `loocv_grid_search_joint()` - Parallel LOOCV across all λ combinations +- `bootstrap_trop_variance_joint()` - Parallel bootstrap variance estimation + +**Key differences from twostep method**: +- Treatment effect τ is a single scalar (homogeneous assumption) vs. per-observation τ_{it} +- Global weights (distance to treated block center) vs. per-observation weights +- Single model fit per λ combination vs. N_treated fits +- Faster computation for large panels + +**Assumptions**: +- **Simultaneous adoption (enforced)**: The joint method requires all treated units + to receive treatment at the same time. A `ValueError` is raised if staggered + adoption is detected (units first treated at different periods). Treatment timing is + inferred once and held constant for bootstrap/jackknife variance estimation. + For staggered adoption designs, use `method="twostep"`. + +**Reference**: Adapted from reference implementation. See also Athey et al. (2025). + +**Requirements checklist:** +- [x] Same LOOCV framework as twostep (Equation 5) +- [x] Global weight computation using treated block center +- [x] Weighted least squares with treatment indicator +- [x] Alternating minimization for nuclear norm penalty +- [x] Returns scalar τ (homogeneous treatment effect) +- [x] Rust acceleration for LOOCV and bootstrap + --- # Diagnostics & Sensitivity diff --git a/docs/tutorials/10_trop.ipynb b/docs/tutorials/10_trop.ipynb index e52c8a15..ab25d9d6 100644 --- a/docs/tutorials/10_trop.ipynb +++ b/docs/tutorials/10_trop.ipynb @@ -3,29 +3,7 @@ { "cell_type": "markdown", "metadata": {}, - "source": [ - "# Triply Robust Panel (TROP) Estimator\n", - "\n", - "This notebook demonstrates the **Triply Robust Panel (TROP)** estimator (Athey, Imbens, Qu & Viviano, 2025), which combines three robustness components:\n", - "\n", - "1. 
**Nuclear Norm Regularized Factor Model**: Estimates interactive fixed effects via matrix completion with nuclear norm penalty\n", - "2. **Exponential Distance-Based Unit Weights**: ω_j = exp(-λ_unit × dist(j,i)) where dist(j,i) is the root mean squared difference in outcomes between units j and i, computed only on periods where both units are untreated and excluding the target period t (Equation 3 in the paper)\n", - "3. **Exponential Time Decay Weights**: θ_s = exp(-λ_time × |s-t|) weighting by proximity to treatment\n", - "\n", - "**Weight Normalization**: Following the paper, the observation-specific weights ω and θ are treated as probability weights that effectively sum to one within each treated observation's counterfactual estimation.\n", - "\n", - "TROP is particularly useful when:\n", - "- There may be unobserved time-varying confounders with factor structure\n", - "- Standard DiD or SDID may be biased due to latent factors\n", - "- You want robust inference under factor confounding\n", - "\n", - "We'll cover:\n", - "1. When to use TROP\n", - "2. Basic estimation with LOOCV tuning\n", - "3. Understanding tuning parameters\n", - "4. Examining factor structure\n", - "5. Comparing TROP vs SDID" - ] + "source": "# Triply Robust Panel (TROP) Estimator\n\nThis notebook demonstrates the **Triply Robust Panel (TROP)** estimator (Athey, Imbens, Qu & Viviano, 2025), which combines three robustness components:\n\n1. **Nuclear Norm Regularized Factor Model**: Estimates interactive fixed effects via matrix completion with nuclear norm penalty\n2. **Exponential Distance-Based Unit Weights**: ω_j = exp(-λ_unit × dist(j,i)) where dist(j,i) is the root mean squared difference in outcomes between units j and i, computed only on periods where both units are untreated and excluding the target period t (Equation 3 in the paper)\n3. 
**Exponential Time Decay Weights**: θ_s = exp(-λ_time × |s-t|) weighting by proximity to treatment\n\n**Weights**: The observation-specific weights ω and θ are importance weights that control the relative contribution of each observation to counterfactual estimation. Higher weights indicate more relevant observations for the target counterfactual.\n\nTROP is particularly useful when:\n- There may be unobserved time-varying confounders with factor structure\n- Standard DiD or SDID may be biased due to latent factors\n- You want robust inference under factor confounding\n\nWe'll cover:\n1. When to use TROP\n2. Basic estimation with LOOCV tuning\n3. Understanding tuning parameters\n4. Examining factor structure\n5. Comparing TROP vs SDID" }, { "cell_type": "code", @@ -651,6 +629,18 @@ " print(f\" 95% CI: [{res.conf_int[0]:.4f}, {res.conf_int[1]:.4f}]\")" ] }, + { + "cell_type": "code", + "source": "# Compare estimation methods\nprint(\"Estimation method comparison:\")\nprint(\"=\"*60)\n\nimport time\n\n# Two-step method (default)\nstart = time.time()\ntrop_twostep = TROP(\n method='twostep',\n lambda_time_grid=[0.0, 1.0],\n lambda_unit_grid=[0.0, 1.0], \n lambda_nn_grid=[0.0, 0.1],\n n_bootstrap=20,\n seed=42\n)\nresults_twostep = trop_twostep.fit(\n df,\n outcome='outcome',\n treatment='treated',\n unit='unit',\n time='period'\n)\ntwostep_time = time.time() - start\n\n# Joint method\nstart = time.time()\ntrop_joint = TROP(\n method='joint',\n lambda_time_grid=[0.0, 1.0],\n lambda_unit_grid=[0.0, 1.0], \n lambda_nn_grid=[0.0, 0.1],\n n_bootstrap=20,\n seed=42\n)\nresults_joint = trop_joint.fit(\n df,\n outcome='outcome',\n treatment='treated',\n unit='unit',\n time='period'\n)\njoint_time = time.time() - start\n\nprint(f\"\\n{'Method':<15} {'ATT':>10} {'SE':>10} {'Time (s)':>12}\")\nprint(\"-\"*60)\nprint(f\"{'Two-step':<15} {results_twostep.att:>10.4f} {results_twostep.se:>10.4f} {twostep_time:>12.2f}\")\nprint(f\"{'Joint':<15} {results_joint.att:>10.4f} 
{results_joint.se:>10.4f} {joint_time:>12.2f}\")\nprint(f\"\\nTrue ATT: {true_att}\")\nprint(f\"Two-step bias: {results_twostep.att - true_att:.4f}\")\nprint(f\"Joint bias: {results_joint.att - true_att:.4f}\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": "## 10. Estimation Methods: Two-Step vs Joint\n\nTROP supports two estimation methods via the `method` parameter:\n\n**Two-Step Method** (`method='twostep'`, default):\n- Follows Algorithm 2 from the paper\n- Computes observation-specific weights for each treated observation\n- Fits a model per treated observation, then averages the individual effects\n- More flexible, allows for heterogeneous treatment effects\n- Computationally intensive (N_treated optimizations)\n\n**Joint Method** (`method='joint'`):\n- Weighted least squares with a single scalar treatment effect τ\n- Weights computed once (distance to center of treated block)\n- With low-rank: uses alternating minimization between weighted LS and soft-threshold SVD\n- Faster but assumes homogeneous treatment effects", + "metadata": {} + }, { "cell_type": "markdown", "metadata": {}, @@ -677,16 +667,11 @@ ] }, { - "cell_type": "code", + "cell_type": "markdown", "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "# Convert to DataFrame\n", - "results_df = results.to_dataframe()\n", - "print(\"\\nResults as DataFrame:\")\n", - "print(results_df.T)" - ] + "source": "## Summary\n\nKey takeaways for TROP:\n\n1. **Best use cases**: Factor confounding, unobserved time-varying confounders with interactive effects\n2. **Factor estimation**: Nuclear norm regularization with LOOCV for tuning\n3. **Three tuning parameters**: λ_time, λ_unit, λ_nn selected automatically via LOOCV\n4. **Unit weights**: Exponential distance-based weighting of control units, where distance is computed as RMS outcome difference on control periods excluding the target period\n5. 
**Time weights**: Exponential decay weighting of pre-treatment periods\n6. **Weights**: Importance weights controlling relative contribution of observations (higher = more relevant)\n7. **Estimation methods**:\n - `method='twostep'` (default): Per-observation estimation, allows heterogeneous effects\n - `method='joint'`: Single scalar treatment effect, faster but assumes homogeneity\n\n**When to use TROP vs SDID**:\n- Use **SDID** when parallel trends is plausible and factors are not a concern\n- Use **TROP** when you suspect factor confounding (regional shocks, economic cycles, latent factors)\n- Running both provides a useful robustness check\n\n**When to use twostep vs joint method**:\n- Use **twostep** (default) for maximum flexibility and heterogeneous treatment effects\n- Use **joint** for faster estimation when effects are expected to be homogeneous\n\n**Reference**:\n- Athey, S., Imbens, G. W., Qu, Z., & Viviano, D. (2025). Triply Robust Panel Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536" }, { "cell_type": "code", @@ -703,26 +688,7 @@ { "cell_type": "markdown", "metadata": {}, - "source": [ - "## Summary\n", - "\n", - "Key takeaways for TROP:\n", - "\n", - "1. **Best use cases**: Factor confounding, unobserved time-varying confounders with interactive effects\n", - "2. **Factor estimation**: Nuclear norm regularization with LOOCV for tuning\n", - "3. **Three tuning parameters**: λ_time, λ_unit, λ_nn selected automatically via LOOCV\n", - "4. **Unit weights**: Exponential distance-based weighting of control units, where distance is computed as RMS outcome difference on control periods excluding the target period\n", - "5. **Time weights**: Exponential decay weighting of pre-treatment periods\n", - "6. 
**Weight normalization**: Weights are treated as probability weights that sum to one\n", - "\n", - "**When to use TROP vs SDID**:\n", - "- Use **SDID** when parallel trends is plausible and factors are not a concern\n", - "- Use **TROP** when you suspect factor confounding (regional shocks, economic cycles, latent factors)\n", - "- Running both provides a useful robustness check\n", - "\n", - "**Reference**:\n", - "- Athey, S., Imbens, G. W., Qu, Z., & Viviano, D. (2025). Triply Robust Panel Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536" - ] + "source": "## Summary\n\nKey takeaways for TROP:\n\n1. **Best use cases**: Factor confounding, unobserved time-varying confounders with interactive effects\n2. **Factor estimation**: Nuclear norm regularization with LOOCV for tuning\n3. **Three tuning parameters**: λ_time, λ_unit, λ_nn selected automatically via LOOCV\n4. **Unit weights**: Exponential distance-based weighting of control units, where distance is computed as RMS outcome difference on control periods excluding the target period\n5. **Time weights**: Exponential decay weighting of pre-treatment periods\n6. **Weights**: Importance weights controlling relative contribution of observations (higher = more relevant)\n\n**When to use TROP vs SDID**:\n- Use **SDID** when parallel trends is plausible and factors are not a concern\n- Use **TROP** when you suspect factor confounding (regional shocks, economic cycles, latent factors)\n- Running both provides a useful robustness check\n\n**Reference**:\n- Athey, S., Imbens, G. W., Qu, Z., & Viviano, D. (2025). Triply Robust Panel Estimators. *Working Paper*. 
https://arxiv.org/abs/2508.21536" } ], "metadata": { @@ -732,4 +698,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/rust/src/lib.rs b/rust/src/lib.rs index eb168d4f..c11d3a99 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -27,11 +27,15 @@ fn _rust_backend(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(linalg::solve_ols, m)?)?; m.add_function(wrap_pyfunction!(linalg::compute_robust_vcov, m)?)?; - // TROP estimator acceleration + // TROP estimator acceleration (twostep method) m.add_function(wrap_pyfunction!(trop::compute_unit_distance_matrix, m)?)?; m.add_function(wrap_pyfunction!(trop::loocv_grid_search, m)?)?; m.add_function(wrap_pyfunction!(trop::bootstrap_trop_variance, m)?)?; + // TROP estimator acceleration (joint method) + m.add_function(wrap_pyfunction!(trop::loocv_grid_search_joint, m)?)?; + m.add_function(wrap_pyfunction!(trop::bootstrap_trop_variance_joint, m)?)?; + // Version info m.add("__version__", env!("CARGO_PKG_VERSION"))?; diff --git a/rust/src/trop.rs b/rust/src/trop.rs index 382badd5..2f269946 100644 --- a/rust/src/trop.rs +++ b/rust/src/trop.rs @@ -1024,6 +1024,671 @@ pub fn bootstrap_trop_variance<'py>( Ok((estimates_arr.into_pyarray(py), se)) } +// ============================================================================ +// Joint method implementation +// ============================================================================ + +/// Compute global weights for joint method estimation. 
+/// +/// Unlike twostep (which computes per-observation weights), joint uses global +/// weights based on: +/// - Time weights: distance to center of treated block +/// - Unit weights: RMSE to average treated trajectory over pre-periods +/// +/// # Arguments +/// * `y` - Outcome matrix (n_periods x n_units) +/// * `d` - Treatment indicator matrix (n_periods x n_units) +/// * `lambda_time` - Time weight decay parameter +/// * `lambda_unit` - Unit weight decay parameter +/// * `treated_periods` - Number of post-treatment periods +/// +/// # Returns +/// Weight matrix (n_periods x n_units) +fn compute_joint_weights( + y: &ArrayView2, + d: &ArrayView2, + lambda_time: f64, + lambda_unit: f64, + treated_periods: usize, +) -> Array2 { + let n_periods = y.nrows(); + let n_units = y.ncols(); + + // Identify treated units (ever treated) + let mut treated_unit_idx: Vec = Vec::new(); + for i in 0..n_units { + if (0..n_periods).any(|t| d[[t, i]] == 1.0) { + treated_unit_idx.push(i); + } + } + + // Time weights: distance to center of treated block + // center = T - treated_periods / 2 + let center = n_periods as f64 - treated_periods as f64 / 2.0; + let mut delta_time = Array1::::zeros(n_periods); + for t in 0..n_periods { + let dist = (t as f64 - center).abs(); + delta_time[t] = (-lambda_time * dist).exp(); + } + + // Unit weights: RMSE to average treated trajectory over pre-periods + let n_pre = n_periods.saturating_sub(treated_periods); + + // Compute average treated trajectory + // Initialize to NaN so periods with all-NaN treated data stay NaN (excluded from RMSE) + let mut average_treated = Array1::::from_elem(n_periods, f64::NAN); + if !treated_unit_idx.is_empty() { + for t in 0..n_periods { + let mut sum = 0.0; + let mut count = 0; + for &i in &treated_unit_idx { + if y[[t, i]].is_finite() { + sum += y[[t, i]]; + count += 1; + } + } + if count > 0 { + average_treated[t] = sum / count as f64; + } + // If count == 0, average_treated[t] stays NaN (correctly excluded) + } + 
} + + // Compute RMS distance for each unit over pre-periods + let mut delta_unit = Array1::::zeros(n_units); + for i in 0..n_units { + if n_pre > 0 { + let mut sum_sq = 0.0; + let mut n_valid = 0; + for t in 0..n_pre { + if y[[t, i]].is_finite() && average_treated[t].is_finite() { + let diff = average_treated[t] - y[[t, i]]; + sum_sq += diff * diff; + n_valid += 1; + } + } + let dist = if n_valid > 0 { + (sum_sq / n_valid as f64).sqrt() + } else { + // No valid pre-period observations for this unit. + // Set dist = infinity so delta_unit = exp(-infinity) = 0. + // This ensures units with no valid pre-period data get zero weight, + // matching the Python behavior. + f64::INFINITY + }; + delta_unit[i] = (-lambda_unit * dist).exp(); + } else { + delta_unit[i] = 1.0; + } + } + + // Outer product: (n_periods x n_units) + let mut delta = Array2::::zeros((n_periods, n_units)); + for t in 0..n_periods { + for i in 0..n_units { + delta[[t, i]] = delta_time[t] * delta_unit[i]; + } + } + + delta +} + +/// Solve joint TWFE + treatment via weighted least squares (no low-rank). 
+/// +/// Minimizes: min Σ δ_{it}(Y_{it} - μ - α_i - β_t - τ*W_{it})² +/// +/// # Returns +/// (mu, alpha, beta, tau) estimated parameters +fn solve_joint_no_lowrank( + y: &ArrayView2, + d: &ArrayView2, + delta: &ArrayView2, +) -> Option<(f64, Array1, Array1, f64)> { + let n_periods = y.nrows(); + let n_units = y.ncols(); + + // We solve using normal equations with the design matrix structure + // Rather than build full X matrix, use block structure for efficiency + // + // The model: Y_it = μ + α_i + β_t + τ*D_it + ε_it + // With identification: α_0 = β_0 = 0 + + // Compute weighted sums needed for normal equations + let mut sum_w = 0.0; + let mut sum_wy = 0.0; + + // Per-unit and per-period weighted sums + let mut sum_w_by_unit = Array1::::zeros(n_units); + let mut sum_wy_by_unit = Array1::::zeros(n_units); + let mut sum_w_by_period = Array1::::zeros(n_periods); + let mut sum_wy_by_period = Array1::::zeros(n_periods); + + for t in 0..n_periods { + for i in 0..n_units { + // NaN outcomes get zero weight (not imputed to 0.0 with active weight) + let w = if y[[t, i]].is_finite() { delta[[t, i]] } else { 0.0 }; + let y_ti = if y[[t, i]].is_finite() { y[[t, i]] } else { 0.0 }; + + sum_w += w; + sum_wy += w * y_ti; + + sum_w_by_unit[i] += w; + sum_wy_by_unit[i] += w * y_ti; + sum_w_by_period[t] += w; + sum_wy_by_period[t] += w * y_ti; + } + } + + if sum_w < 1e-10 { + return None; + } + + // Use iterative approach: alternate between (alpha, beta, tau) and mu + // until convergence (simpler than full normal equations) + let mut mu = sum_wy / sum_w; + let mut alpha = Array1::::zeros(n_units); + let mut beta = Array1::::zeros(n_periods); + let mut tau = 0.0; + + for _ in 0..50 { + let mu_old = mu; + let tau_old = tau; + + // Update alpha (fixing beta, tau, mu) + for i in 1..n_units { // α_0 = 0 for identification + if sum_w_by_unit[i] > 1e-10 { + let mut num = 0.0; + for t in 0..n_periods { + // NaN outcomes get zero weight + let w = if y[[t, i]].is_finite() { delta[[t, 
i]] } else { 0.0 }; + let y_ti = if y[[t, i]].is_finite() { y[[t, i]] } else { 0.0 }; + num += w * (y_ti - mu - beta[t] - tau * d[[t, i]]); + } + alpha[i] = num / sum_w_by_unit[i]; + } + } + + // Update beta (fixing alpha, tau, mu) + for t in 1..n_periods { // β_0 = 0 for identification + if sum_w_by_period[t] > 1e-10 { + let mut num = 0.0; + for i in 0..n_units { + // NaN outcomes get zero weight + let w = if y[[t, i]].is_finite() { delta[[t, i]] } else { 0.0 }; + let y_ti = if y[[t, i]].is_finite() { y[[t, i]] } else { 0.0 }; + num += w * (y_ti - mu - alpha[i] - tau * d[[t, i]]); + } + beta[t] = num / sum_w_by_period[t]; + } + } + + // Update tau (fixing alpha, beta, mu) + let mut num_tau = 0.0; + let mut denom_tau = 0.0; + for t in 0..n_periods { + for i in 0..n_units { + // NaN outcomes get zero weight + let w = if y[[t, i]].is_finite() { delta[[t, i]] } else { 0.0 }; + let y_ti = if y[[t, i]].is_finite() { y[[t, i]] } else { 0.0 }; + let d_ti = d[[t, i]]; + if d_ti > 0.5 { // Only treated observations contribute + num_tau += w * d_ti * (y_ti - mu - alpha[i] - beta[t]); + denom_tau += w * d_ti * d_ti; + } + } + } + if denom_tau > 1e-10 { + tau = num_tau / denom_tau; + } + + // Update mu (fixing alpha, beta, tau) + let mut num_mu = 0.0; + for t in 0..n_periods { + for i in 0..n_units { + // NaN outcomes get zero weight + let w = if y[[t, i]].is_finite() { delta[[t, i]] } else { 0.0 }; + let y_ti = if y[[t, i]].is_finite() { y[[t, i]] } else { 0.0 }; + num_mu += w * (y_ti - alpha[i] - beta[t] - tau * d[[t, i]]); + } + } + mu = num_mu / sum_w; + + // Check convergence + if (mu - mu_old).abs() < 1e-8 && (tau - tau_old).abs() < 1e-8 { + break; + } + } + + Some((mu, alpha, beta, tau)) +} + +/// Solve joint TWFE + treatment + low-rank via alternating minimization. 
+/// +/// Minimizes: min Σ δ_{it}(Y_{it} - μ - α_i - β_t - L_{it} - τ*W_{it})² + λ_nn||L||_* +/// +/// # Returns +/// (mu, alpha, beta, L, tau) estimated parameters +fn solve_joint_with_lowrank( + y: &ArrayView2, + d: &ArrayView2, + delta: &ArrayView2, + lambda_nn: f64, + max_iter: usize, + tol: f64, +) -> Option<(f64, Array1, Array1, Array2, f64)> { + let n_periods = y.nrows(); + let n_units = y.ncols(); + + // Initialize L = 0 + let mut l = Array2::::zeros((n_periods, n_units)); + + for _ in 0..max_iter { + let l_old = l.clone(); + + // Step 1: Fix L, solve for (mu, alpha, beta, tau) + // Adjusted outcome: Y - L (preserve NaN so solve_joint_no_lowrank masks weights) + let y_adj = Array2::from_shape_fn((n_periods, n_units), |(t, i)| { + y[[t, i]] - l[[t, i]] // NaN - finite = NaN (preserves NaN info) + }); + + let (mu, alpha, beta, tau) = solve_joint_no_lowrank(&y_adj.view(), d, delta)?; + + // Step 2: Fix (mu, alpha, beta, tau), update L + // Residual: R = Y - mu - alpha - beta - tau*D (preserve NaN) + let mut r = Array2::::zeros((n_periods, n_units)); + for t in 0..n_periods { + for i in 0..n_units { + // NaN - finite = NaN (will be masked in gradient step) + r[[t, i]] = y[[t, i]] - mu - alpha[i] - beta[t] - tau * d[[t, i]]; + } + } + + // Weighted proximal step for L (soft-threshold SVD) + let delta_max = delta.iter().cloned().fold(0.0_f64, f64::max); + let eta = if delta_max > 0.0 { 1.0 / delta_max } else { 1.0 }; + + // gradient_step = L + eta * delta * (R - L) + // NaN outcomes get zero weight so they don't affect gradient + let mut gradient_step = Array2::::zeros((n_periods, n_units)); + for t in 0..n_periods { + for i in 0..n_units { + // Mask delta for NaN outcomes + let delta_ti = if y[[t, i]].is_finite() { delta[[t, i]] } else { 0.0 }; + let delta_norm = if delta_max > 0.0 { + delta_ti / delta_max + } else { + delta_ti + }; + // r[[t,i]] may be NaN, but delta_norm=0 for NaN obs, so contribution=0 + let r_contrib = if r[[t, i]].is_finite() { r[[t, i]] } 
else { 0.0 }; + gradient_step[[t, i]] = l[[t, i]] + delta_norm * (r_contrib - l[[t, i]]); + } + } + + // Soft-threshold singular values + l = soft_threshold_svd(&gradient_step, eta * lambda_nn)?; + + // Check convergence + let l_diff = max_abs_diff_2d(&l, &l_old); + if l_diff < tol { + break; + } + } + + // Final solve with converged L (preserve NaN so solve_joint_no_lowrank masks weights) + let y_adj = Array2::from_shape_fn((n_periods, n_units), |(t, i)| { + y[[t, i]] - l[[t, i]] // NaN - finite = NaN (preserves NaN info) + }); + let (mu, alpha, beta, tau) = solve_joint_no_lowrank(&y_adj.view(), d, delta)?; + + Some((mu, alpha, beta, l, tau)) +} + +/// Compute LOOCV score for joint method with specific parameter combination. +/// +/// Following paper's Equation 5: +/// Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² +/// +/// For joint method, we exclude each control observation, fit the joint model +/// on remaining data, and compute the pseudo-treatment effect at the excluded obs. +/// +/// # Returns +/// (score, n_valid, first_failed_obs) +#[allow(clippy::too_many_arguments)] +fn loocv_score_joint( + y: &ArrayView2<f64>, + d: &ArrayView2<f64>, + control_obs: &[(usize, usize)], + lambda_time: f64, + lambda_unit: f64, + lambda_nn: f64, + treated_periods: usize, + max_iter: usize, + tol: f64, +) -> (f64, usize, Option<(usize, usize)>) { + let n_periods = y.nrows(); + let n_units = y.ncols(); + + let mut tau_sq_sum = 0.0; + let mut n_valid = 0usize; + + // Compute global weights (same for all LOOCV iterations) + let delta = compute_joint_weights(y, d, lambda_time, lambda_unit, treated_periods); + + for &(t_ex, i_ex) in control_obs { + // Create modified delta with excluded observation zeroed out + let mut delta_ex = delta.clone(); + delta_ex[[t_ex, i_ex]] = 0.0; + + // Fit joint model excluding this observation + let result = if lambda_nn >= 1e10 { + solve_joint_no_lowrank(y, d, &delta_ex.view()) + .map(|(mu, alpha, beta, tau)| { + let l = Array2::<f64>::zeros((n_periods, n_units)); + 
(mu, alpha, beta, l, tau) + }) + } else { + solve_joint_with_lowrank(y, d, &delta_ex.view(), lambda_nn, max_iter, tol) + }; + + match result { + Some((mu, alpha, beta, l, _tau)) => { + // Pseudo treatment effect: τ = Y - μ - α - β - L + let y_ti = if y[[t_ex, i_ex]].is_finite() { + y[[t_ex, i_ex]] + } else { + continue; + }; + let tau_loocv = y_ti - mu - alpha[i_ex] - beta[t_ex] - l[[t_ex, i_ex]]; + tau_sq_sum += tau_loocv * tau_loocv; + n_valid += 1; + } + None => { + // Any failure means this λ combination is invalid per Equation 5 + return (f64::INFINITY, n_valid, Some((t_ex, i_ex))); + } + } + } + + if n_valid == 0 { + (f64::INFINITY, 0, None) + } else { + (tau_sq_sum, n_valid, None) + } +} + +/// Perform LOOCV grid search for joint method using parallel grid search. +/// +/// Evaluates all combinations of (lambda_time, lambda_unit, lambda_nn) in parallel +/// and returns the combination with lowest LOOCV score. +/// +/// # Arguments +/// * `y` - Outcome matrix (n_periods x n_units) +/// * `d` - Treatment indicator matrix (n_periods x n_units) +/// * `control_mask` - Boolean mask (n_periods x n_units) for control observations +/// * `lambda_time_grid` - Grid of time decay parameters +/// * `lambda_unit_grid` - Grid of unit distance parameters +/// * `lambda_nn_grid` - Grid of nuclear norm parameters +/// * `max_loocv_samples` - Maximum control observations to evaluate +/// * `max_iter` - Maximum iterations for model estimation +/// * `tol` - Convergence tolerance +/// * `seed` - Random seed for subsampling +/// +/// # Returns +/// (best_lambda_time, best_lambda_unit, best_lambda_nn, best_score, n_valid, n_attempted, first_failed_obs) +#[pyfunction] +#[pyo3(signature = (y, d, control_mask, lambda_time_grid, lambda_unit_grid, lambda_nn_grid, max_loocv_samples, max_iter, tol, seed))] +#[allow(clippy::too_many_arguments)] +pub fn loocv_grid_search_joint<'py>( + _py: Python<'py>, + y: PyReadonlyArray2<'py, f64>, + d: PyReadonlyArray2<'py, f64>, + control_mask: 
PyReadonlyArray2<'py, u8>, + lambda_time_grid: PyReadonlyArray1<'py, f64>, + lambda_unit_grid: PyReadonlyArray1<'py, f64>, + lambda_nn_grid: PyReadonlyArray1<'py, f64>, + max_loocv_samples: usize, + max_iter: usize, + tol: f64, + seed: u64, +) -> PyResult<(f64, f64, f64, f64, usize, usize, Option<(usize, usize)>)> { + let y_arr = y.as_array(); + let d_arr = d.as_array(); + let control_mask_arr = control_mask.as_array(); + let lambda_time_vec: Vec<f64> = lambda_time_grid.as_array().to_vec(); + let lambda_unit_vec: Vec<f64> = lambda_unit_grid.as_array().to_vec(); + let lambda_nn_vec: Vec<f64> = lambda_nn_grid.as_array().to_vec(); + + let n_periods = y_arr.nrows(); + let n_units = y_arr.ncols(); + + // Determine treated periods from D matrix + let mut first_treat_period = n_periods; + for t in 0..n_periods { + for i in 0..n_units { + if d_arr[[t, i]] == 1.0 { + first_treat_period = first_treat_period.min(t); + break; + } + } + } + let treated_periods = n_periods.saturating_sub(first_treat_period); + + // Get control observations for LOOCV + let control_obs = get_control_observations(&y_arr, &control_mask_arr, max_loocv_samples, seed); + let n_attempted = control_obs.len(); + + // Build grid combinations + let mut grid_combinations: Vec<(f64, f64, f64)> = Vec::new(); + for &lt in &lambda_time_vec { + for &lu in &lambda_unit_vec { + for &ln in &lambda_nn_vec { + grid_combinations.push((lt, lu, ln)); + } + } + } + + // Parallel grid search - try all combinations + let results: Vec<(f64, f64, f64, f64, usize, Option<(usize, usize)>)> = grid_combinations + .into_par_iter() + .map(|(lt, lu, ln)| { + // Convert infinity values + let lt_eff = if lt.is_infinite() { 0.0 } else { lt }; + let lu_eff = if lu.is_infinite() { 0.0 } else { lu }; + let ln_eff = if ln.is_infinite() { 1e10 } else { ln }; + + let (score, n_valid, first_failed) = loocv_score_joint( + &y_arr, + &d_arr, + &control_obs, + lt_eff, + lu_eff, + ln_eff, + treated_periods, + max_iter, + tol, + ); + + (lt, lu, ln, score, n_valid, 
first_failed) + }) + .collect(); + + // Find best result + let mut best_result = ( + lambda_time_vec.first().copied().unwrap_or(0.0), + lambda_unit_vec.first().copied().unwrap_or(0.0), + lambda_nn_vec.first().copied().unwrap_or(0.0), + f64::INFINITY, + 0usize, + None, + ); + + for (lt, lu, ln, score, n_valid, first_failed) in results { + if score < best_result.3 { + best_result = (lt, lu, ln, score, n_valid, first_failed); + } + } + + let (best_lt, best_lu, best_ln, best_score, n_valid, first_failed) = best_result; + + Ok((best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed)) +} + +/// Compute bootstrap variance estimation for TROP joint method in parallel. +/// +/// Performs unit-level block bootstrap, parallelizing across bootstrap iterations. +/// Uses stratified sampling to preserve treated/control unit ratio. +/// +/// # Arguments +/// * `y` - Outcome matrix (n_periods x n_units) +/// * `d` - Treatment indicator matrix (n_periods x n_units) +/// * `lambda_time` - Selected time decay parameter +/// * `lambda_unit` - Selected unit distance parameter +/// * `lambda_nn` - Selected nuclear norm parameter +/// * `n_bootstrap` - Number of bootstrap iterations +/// * `max_iter` - Maximum iterations for model estimation +/// * `tol` - Convergence tolerance +/// * `seed` - Random seed +/// +/// # Returns +/// (bootstrap_estimates, standard_error) +#[pyfunction] +#[pyo3(signature = (y, d, lambda_time, lambda_unit, lambda_nn, n_bootstrap, max_iter, tol, seed))] +#[allow(clippy::too_many_arguments)] +pub fn bootstrap_trop_variance_joint<'py>( + py: Python<'py>, + y: PyReadonlyArray2<'py, f64>, + d: PyReadonlyArray2<'py, f64>, + lambda_time: f64, + lambda_unit: f64, + lambda_nn: f64, + n_bootstrap: usize, + max_iter: usize, + tol: f64, + seed: u64, +) -> PyResult<(&'py PyArray1<f64>, f64)> { + let y_arr = y.as_array().to_owned(); + let d_arr = d.as_array().to_owned(); + + let n_units = y_arr.ncols(); + let n_periods = y_arr.nrows(); + + // Identify treated 
and control units for stratified sampling + let mut original_treated_units: Vec<usize> = Vec::new(); + let mut original_control_units: Vec<usize> = Vec::new(); + for i in 0..n_units { + let is_ever_treated = (0..n_periods).any(|t| d_arr[[t, i]] == 1.0); + if is_ever_treated { + original_treated_units.push(i); + } else { + original_control_units.push(i); + } + } + let n_treated_units = original_treated_units.len(); + let n_control_units = original_control_units.len(); + + // Determine treated periods from D matrix + let mut first_treat_period = n_periods; + for t in 0..n_periods { + for i in 0..n_units { + if d_arr[[t, i]] == 1.0 { + first_treat_period = first_treat_period.min(t); + break; + } + } + } + let treated_periods = n_periods.saturating_sub(first_treat_period); + + // Convert infinity values for computation + let lt_eff = if lambda_time.is_infinite() { 0.0 } else { lambda_time }; + let lu_eff = if lambda_unit.is_infinite() { 0.0 } else { lambda_unit }; + let ln_eff = if lambda_nn.is_infinite() { 1e10 } else { lambda_nn }; + + // Run bootstrap iterations in parallel + let bootstrap_estimates: Vec<f64> = (0..n_bootstrap) + .into_par_iter() + .filter_map(|b| { + use rand::prelude::*; + use rand_xoshiro::Xoshiro256PlusPlus; + + let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed.wrapping_add(b as u64)); + + // Stratified sampling - sample control and treated units separately + let mut sampled_units: Vec<usize> = Vec::with_capacity(n_units); + + // Sample control units with replacement + for _ in 0..n_control_units { + if n_control_units > 0 { + let idx = rng.gen_range(0..n_control_units); + sampled_units.push(original_control_units[idx]); + } + } + + // Sample treated units with replacement + for _ in 0..n_treated_units { + if n_treated_units > 0 { + let idx = rng.gen_range(0..n_treated_units); + sampled_units.push(original_treated_units[idx]); + } + } + + // Create bootstrap matrices by selecting columns + let mut y_boot = Array2::<f64>::zeros((n_periods, n_units)); + let mut d_boot = 
Array2::<f64>::zeros((n_periods, n_units)); + + for (new_idx, &old_idx) in sampled_units.iter().enumerate() { + for t in 0..n_periods { + y_boot[[t, new_idx]] = y_arr[[t, old_idx]]; + d_boot[[t, new_idx]] = d_arr[[t, old_idx]]; + } + } + + // Compute weights and fit joint model + let delta = compute_joint_weights( + &y_boot.view(), + &d_boot.view(), + lt_eff, + lu_eff, + treated_periods, + ); + + let result = if ln_eff >= 1e10 { + solve_joint_no_lowrank(&y_boot.view(), &d_boot.view(), &delta.view()) + .map(|(_, _, _, tau)| tau) + } else { + solve_joint_with_lowrank( + &y_boot.view(), + &d_boot.view(), + &delta.view(), + ln_eff, + max_iter, + tol, + ) + .map(|(_, _, _, _, tau)| tau) + }; + + result + }) + .collect(); + + // Compute standard error + let se = if bootstrap_estimates.len() < 2 { + 0.0 + } else { + let n = bootstrap_estimates.len() as f64; + let mean = bootstrap_estimates.iter().sum::<f64>() / n; + let variance = bootstrap_estimates + .iter() + .map(|x| (x - mean).powi(2)) + .sum::<f64>() + / (n - 1.0); + variance.sqrt() + }; + + let estimates_arr = Array1::from_vec(bootstrap_estimates); + Ok((estimates_arr.into_pyarray(py), se)) +} + #[cfg(test)] mod tests { use super::*; diff --git a/tests/test_rust_backend.py b/tests/test_rust_backend.py index 332f5321..76edf80e 100644 --- a/tests/test_rust_backend.py +++ b/tests/test_rust_backend.py @@ -1134,6 +1134,540 @@ def test_trop_produces_valid_results(self): assert results.lambda_nn in [0.0, 0.1] + +@pytest.mark.skipif(not HAS_RUST_BACKEND, reason="Rust backend not available") +class TestTROPJointRustBackend: + """Test suite for TROP joint method Rust backend functions.""" + + def test_loocv_grid_search_joint_returns_valid_result(self): + """Test loocv_grid_search_joint returns valid tuning parameters.""" + from diff_diff._rust_backend import loocv_grid_search_joint + + np.random.seed(42) + n_periods, n_units = 10, 20 + n_treated = 5 + n_post = 3 + + # Generate simple data + Y = np.random.randn(n_periods, n_units) + D = 
np.zeros((n_periods, n_units)) + D[-n_post:, :n_treated] = 1.0 + + control_mask = (D == 0).astype(np.uint8) + lambda_time_grid = np.array([0.0, 1.0]) + lambda_unit_grid = np.array([0.0, 1.0]) + lambda_nn_grid = np.array([0.0, 0.1]) + + result = loocv_grid_search_joint( + Y, D, control_mask, + lambda_time_grid, lambda_unit_grid, lambda_nn_grid, + 50, 100, 1e-6, 42 + ) + + best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, _ = result + + # Check types and bounds + assert isinstance(best_lt, float) + assert isinstance(best_lu, float) + assert isinstance(best_ln, float) + assert best_lt in [0.0, 1.0] + assert best_lu in [0.0, 1.0] + assert best_ln in [0.0, 0.1] + assert n_valid > 0 + assert n_attempted > 0 + assert best_score >= 0 or np.isinf(best_score) + + def test_loocv_grid_search_joint_reproducible(self): + """Test loocv_grid_search_joint is reproducible with same seed.""" + from diff_diff._rust_backend import loocv_grid_search_joint + + np.random.seed(42) + n_periods, n_units = 8, 15 + n_treated = 4 + n_post = 2 + + Y = np.random.randn(n_periods, n_units) + D = np.zeros((n_periods, n_units)) + D[-n_post:, :n_treated] = 1.0 + + control_mask = (D == 0).astype(np.uint8) + lambda_time_grid = np.array([0.0, 0.5]) + lambda_unit_grid = np.array([0.0, 0.5]) + lambda_nn_grid = np.array([0.0, 0.1]) + + result1 = loocv_grid_search_joint( + Y, D, control_mask, + lambda_time_grid, lambda_unit_grid, lambda_nn_grid, + 30, 50, 1e-6, 42 + ) + result2 = loocv_grid_search_joint( + Y, D, control_mask, + lambda_time_grid, lambda_unit_grid, lambda_nn_grid, + 30, 50, 1e-6, 42 + ) + + # Same seed should produce same results + assert result1[:4] == result2[:4] + + def test_bootstrap_trop_variance_joint_shape(self): + """Test bootstrap_trop_variance_joint returns valid output.""" + from diff_diff._rust_backend import bootstrap_trop_variance_joint + + np.random.seed(42) + n_periods, n_units = 8, 15 + n_treated = 4 + n_post = 2 + + Y = np.random.randn(n_periods, n_units) + D = 
np.zeros((n_periods, n_units)) + D[-n_post:, :n_treated] = 1.0 + + estimates, se = bootstrap_trop_variance_joint( + Y, D, 0.5, 0.5, 0.1, 50, 50, 1e-6, 42 + ) + + assert isinstance(estimates, np.ndarray) + assert len(estimates) > 0 + assert isinstance(se, float) + assert se >= 0 + + def test_bootstrap_trop_variance_joint_reproducible(self): + """Test bootstrap_trop_variance_joint is reproducible.""" + from diff_diff._rust_backend import bootstrap_trop_variance_joint + + np.random.seed(42) + n_periods, n_units = 8, 15 + n_treated = 4 + n_post = 2 + + Y = np.random.randn(n_periods, n_units) + D = np.zeros((n_periods, n_units)) + D[-n_post:, :n_treated] = 1.0 + + est1, se1 = bootstrap_trop_variance_joint( + Y, D, 0.5, 0.5, 0.1, 50, 50, 1e-6, 42 + ) + est2, se2 = bootstrap_trop_variance_joint( + Y, D, 0.5, 0.5, 0.1, 50, 50, 1e-6, 42 + ) + + np.testing.assert_array_almost_equal(est1, est2) + np.testing.assert_almost_equal(se1, se2) + + +@pytest.mark.skipif(not HAS_RUST_BACKEND, reason="Rust backend not available") +class TestTROPJointRustVsNumpy: + """Tests comparing TROP joint Rust and NumPy implementations.""" + + def test_trop_joint_produces_valid_results(self): + """Test TROP joint with Rust backend produces valid results.""" + import pandas as pd + from diff_diff import TROP + + np.random.seed(42) + n_units, n_periods = 20, 10 + n_treated = 5 + n_post = 3 + true_effect = 2.0 + + data = [] + for i in range(n_units): + is_treated = i < n_treated + for t in range(n_periods): + post = t >= (n_periods - n_post) + y = 10.0 + i * 0.2 + t * 0.3 + np.random.randn() * 0.5 + treatment_indicator = 1 if (is_treated and post) else 0 + if treatment_indicator: + y += true_effect + data.append({ + 'unit': i, + 'time': t, + 'outcome': y, + 'treated': treatment_indicator, + }) + + df = pd.DataFrame(data) + + trop = TROP( + method="joint", + lambda_time_grid=[0.0, 1.0], + lambda_unit_grid=[0.0, 1.0], + lambda_nn_grid=[0.0, 0.1], + n_bootstrap=30, + seed=42 + ) + results = trop.fit(df, 
'outcome', 'treated', 'unit', 'time') + + # Check results are valid + assert np.isfinite(results.att), "ATT should be finite" + assert np.isfinite(results.se), "SE should be finite" + assert results.se >= 0, "SE should be non-negative" + + # ATT should be positive (same direction as true effect) + assert results.att > 0, f"ATT {results.att:.2f} should be positive" + + # Tuning parameters should be from the grid + assert results.lambda_time in [0.0, 1.0] + assert results.lambda_unit in [0.0, 1.0] + assert results.lambda_nn in [0.0, 0.1] + + def test_trop_joint_and_twostep_agree_in_direction(self): + """Test joint and twostep methods agree on treatment effect direction.""" + import pandas as pd + from diff_diff import TROP + + np.random.seed(42) + n_units, n_periods = 20, 10 + n_treated = 5 + n_post = 3 + true_effect = 2.0 + + data = [] + for i in range(n_units): + is_treated = i < n_treated + for t in range(n_periods): + post = t >= (n_periods - n_post) + y = 10.0 + i * 0.2 + t * 0.3 + np.random.randn() * 0.5 + treatment_indicator = 1 if (is_treated and post) else 0 + if treatment_indicator: + y += true_effect + data.append({ + 'unit': i, + 'time': t, + 'outcome': y, + 'treated': treatment_indicator, + }) + + df = pd.DataFrame(data) + + # Fit with joint method + trop_joint = TROP( + method="joint", + lambda_time_grid=[0.0, 1.0], + lambda_unit_grid=[0.0, 1.0], + lambda_nn_grid=[0.0, 0.1], + n_bootstrap=20, + seed=42 + ) + results_joint = trop_joint.fit(df, 'outcome', 'treated', 'unit', 'time') + + # Fit with twostep method + trop_twostep = TROP( + method="twostep", + lambda_time_grid=[0.0, 1.0], + lambda_unit_grid=[0.0, 1.0], + lambda_nn_grid=[0.0, 0.1], + n_bootstrap=20, + seed=42 + ) + results_twostep = trop_twostep.fit(df, 'outcome', 'treated', 'unit', 'time') + + # Both should have same sign (both positive for true_effect=2.0) + assert np.sign(results_joint.att) == np.sign(results_twostep.att) + + def test_trop_joint_handles_nan_outcomes(self): + """Test TROP 
joint method handles NaN outcome values gracefully.""" + import pandas as pd + from diff_diff import TROP + + np.random.seed(42) + n_units, n_periods = 20, 10 + n_treated = 5 + n_post = 3 + true_effect = 2.0 + + data = [] + for i in range(n_units): + is_treated = i < n_treated + for t in range(n_periods): + post = t >= (n_periods - n_post) + y = 10.0 + i * 0.2 + t * 0.3 + np.random.randn() * 0.5 + treatment_indicator = 1 if (is_treated and post) else 0 + if treatment_indicator: + y += true_effect + data.append({ + 'unit': i, + 'time': t, + 'outcome': y, + 'treated': treatment_indicator, + }) + + df = pd.DataFrame(data) + + # Introduce NaN values in control observations (pre-treatment periods) + # Set 5% of control pre-treatment observations to NaN + nan_indices = [] + for idx, row in df.iterrows(): + if row['treated'] == 0 and row['time'] < (n_periods - n_post): + if np.random.rand() < 0.05: + nan_indices.append(idx) + df.loc[nan_indices, 'outcome'] = np.nan + + n_nan = len(nan_indices) + assert n_nan > 0, "Should have introduced some NaN values" + + trop = TROP( + method="joint", + lambda_time_grid=[0.0, 1.0], + lambda_unit_grid=[0.0, 1.0], + lambda_nn_grid=[0.0, 0.1], + n_bootstrap=20, + seed=42 + ) + results = trop.fit(df, 'outcome', 'treated', 'unit', 'time') + + # Results should be finite (NaN observations are excluded) + assert np.isfinite(results.att), f"ATT {results.att} should be finite with NaN data" + assert np.isfinite(results.se), f"SE {results.se} should be finite with NaN data" + assert results.se >= 0, "SE should be non-negative" + + # ATT should still be positive (true effect is positive) + assert results.att > 0, f"ATT {results.att:.2f} should be positive" + + def test_trop_joint_no_valid_pre_unit_gets_zero_weight(self): + """Test that units with no valid pre-period data get zero weight. + + When a control unit has all NaN values in the pre-treatment period, + it should receive zero weight (not maximum weight). 
This prevents + such units from influencing the counterfactual estimation. + + This tests the fix for PR #113 Round 3 feedback (P1-1) where Rust + backend was setting dist=0 -> delta_unit=exp(0)=1.0 (max weight) + instead of dist=inf -> delta_unit=exp(-inf)=0.0 (zero weight). + """ + import pandas as pd + from diff_diff import TROP + + np.random.seed(42) + n_units, n_periods = 15, 10 + n_treated = 3 + n_post = 3 + true_effect = 2.0 + + data = [] + for i in range(n_units): + is_treated = i < n_treated + for t in range(n_periods): + post = t >= (n_periods - n_post) + y = 10.0 + i * 0.2 + t * 0.3 + np.random.randn() * 0.3 + treatment_indicator = 1 if (is_treated and post) else 0 + if treatment_indicator: + y += true_effect + data.append({ + 'unit': i, + 'time': t, + 'outcome': y, + 'treated': treatment_indicator, + }) + + df = pd.DataFrame(data) + + # Set ALL pre-period outcomes to NaN for one control unit (unit n_treated) + # This unit has no valid pre-period data and should get zero weight + control_unit_with_no_pre = n_treated # First control unit + pre_mask = (df['unit'] == control_unit_with_no_pre) & (df['time'] < (n_periods - n_post)) + df.loc[pre_mask, 'outcome'] = np.nan + + # Verify we set NaN correctly + unit_pre_data = df[(df['unit'] == control_unit_with_no_pre) & (df['time'] < (n_periods - n_post))] + assert unit_pre_data['outcome'].isna().all(), "Control unit should have all NaN in pre-period" + + # Fit with joint method - should handle gracefully + trop = TROP( + method="joint", + lambda_time_grid=[0.5, 1.0], + lambda_unit_grid=[0.5, 1.0], + lambda_nn_grid=[0.0], + n_bootstrap=20, + seed=42 + ) + results = trop.fit(df, 'outcome', 'treated', 'unit', 'time') + + # Results should be finite - the unit with no valid pre-period data + # should get zero weight and not break estimation + assert np.isfinite(results.att), f"ATT {results.att} should be finite" + assert np.isfinite(results.se), f"SE {results.se} should be finite" + + # ATT should be in reasonable 
range of true effect + # The no-valid-pre unit getting zero weight shouldn't corrupt the estimate + assert abs(results.att - true_effect) < 1.5, \ + f"ATT {results.att:.2f} should be close to true effect {true_effect}" + + def test_trop_joint_nan_exclusion_rust_python_parity(self): + """Test Rust and Python backends produce matching results with NaN data. + + This verifies that when data contains NaN values: + 1. Both backends exclude NaN observations consistently + 2. ATT estimates are close (within tolerance) + 3. Neither backend produces corrupt results + + This tests the fix for PR #113 Round 3 feedback (P2-1). + """ + import os + import pandas as pd + from diff_diff import TROP + + np.random.seed(42) + n_units, n_periods = 20, 10 + n_treated = 5 + n_post = 3 + true_effect = 2.0 + + data = [] + for i in range(n_units): + is_treated = i < n_treated + for t in range(n_periods): + post = t >= (n_periods - n_post) + y = 10.0 + i * 0.2 + t * 0.3 + np.random.randn() * 0.3 + treatment_indicator = 1 if (is_treated and post) else 0 + if treatment_indicator: + y += true_effect + data.append({ + 'unit': i, + 'time': t, + 'outcome': y, + 'treated': treatment_indicator, + }) + + df = pd.DataFrame(data) + + # Introduce scattered NaN values (5% of control pre-period observations) + np.random.seed(123) # Different seed for NaN placement + for idx, row in df.iterrows(): + if row['treated'] == 0 and row['time'] < (n_periods - n_post): + if np.random.rand() < 0.05: + df.loc[idx, 'outcome'] = np.nan + + n_nan = df['outcome'].isna().sum() + assert n_nan > 0, "Should have some NaN values" + + # Common TROP parameters + trop_params = dict( + method="joint", + lambda_time_grid=[0.5, 1.0], + lambda_unit_grid=[0.5, 1.0], + lambda_nn_grid=[0.0], + n_bootstrap=20, + seed=42 + ) + + # Run with Rust backend (current default when available) + trop_rust = TROP(**trop_params) + results_rust = trop_rust.fit(df.copy(), 'outcome', 'treated', 'unit', 'time') + + # Run with Python-only backend 
using mock.patch to avoid module reload issues + # (Module reload breaks isinstance() checks in other tests due to class identity) + from unittest.mock import patch + import sys + trop_module = sys.modules['diff_diff.trop'] + + with patch.object(trop_module, 'HAS_RUST_BACKEND', False), \ + patch.object(trop_module, '_rust_loocv_grid_search_joint', None), \ + patch.object(trop_module, '_rust_bootstrap_trop_variance_joint', None): + + trop_python = TROP(**trop_params) + results_python = trop_python.fit(df.copy(), 'outcome', 'treated', 'unit', 'time') + + # Both should produce finite results + assert np.isfinite(results_rust.att), f"Rust ATT {results_rust.att} should be finite" + assert np.isfinite(results_python.att), f"Python ATT {results_python.att} should be finite" + + # ATT estimates should be close (within reasonable tolerance) + # Allow some difference due to LOOCV randomness and numerical differences + att_diff = abs(results_rust.att - results_python.att) + assert att_diff < 0.5, \ + f"Rust ATT ({results_rust.att:.3f}) and Python ATT ({results_python.att:.3f}) " \ + f"differ by {att_diff:.3f}, should be < 0.5" + + # Both should recover true effect direction + assert results_rust.att > 0, f"Rust ATT {results_rust.att} should be positive" + assert results_python.att > 0, f"Python ATT {results_python.att} should be positive" + + def test_trop_joint_treated_pre_nan_rust_python_parity(self): + """Test Rust/Python parity when treated units have pre-period NaN. + + When all treated units have NaN at a pre-period, average_treated[t] = NaN. + Both backends should exclude this period from unit distance calculation + (both numerator and denominator) to avoid inflating valid_count. + + This tests the fix for PR #113 Round 5 feedback (P2). 
+ """ + import os + import pandas as pd + from diff_diff import TROP + + np.random.seed(42) + n_units, n_periods = 20, 10 + n_treated = 5 + n_post = 3 + true_effect = 2.0 + + data = [] + for i in range(n_units): + is_treated = i < n_treated + for t in range(n_periods): + post = t >= (n_periods - n_post) + y = 10.0 + i * 0.2 + t * 0.3 + np.random.randn() * 0.3 + treatment_indicator = 1 if (is_treated and post) else 0 + if treatment_indicator: + y += true_effect + data.append({ + 'unit': i, + 'time': t, + 'outcome': y, + 'treated': treatment_indicator, + }) + + df = pd.DataFrame(data) + + # Set ALL treated units' outcomes at period 3 (a pre-period) to NaN + # This makes average_treated[3] = NaN + target_period = 3 + treated_units = list(range(n_treated)) + mask = df['unit'].isin(treated_units) & (df['time'] == target_period) + df.loc[mask, 'outcome'] = np.nan + + # Verify we set NaN correctly + n_nan = df.loc[mask, 'outcome'].isna().sum() + assert n_nan == n_treated, f"Should have {n_treated} NaN, got {n_nan}" + + # Common TROP parameters + trop_params = dict( + method="joint", + lambda_time_grid=[1.0], + lambda_unit_grid=[1.0], + lambda_nn_grid=[0.0], + n_bootstrap=20, + seed=42 + ) + + # Run with Rust backend (current default when available) + trop_rust = TROP(**trop_params) + results_rust = trop_rust.fit(df.copy(), 'outcome', 'treated', 'unit', 'time') + + # Run with Python-only backend using mock.patch to avoid module reload issues + # (Module reload breaks isinstance() checks in other tests due to class identity) + from unittest.mock import patch + import sys + trop_module = sys.modules['diff_diff.trop'] + + with patch.object(trop_module, 'HAS_RUST_BACKEND', False), \ + patch.object(trop_module, '_rust_loocv_grid_search_joint', None), \ + patch.object(trop_module, '_rust_bootstrap_trop_variance_joint', None): + + trop_python = TROP(**trop_params) + results_python = trop_python.fit(df.copy(), 'outcome', 'treated', 'unit', 'time') + + # Both should produce finite 
results + assert np.isfinite(results_rust.att), f"Rust ATT {results_rust.att} should be finite" + assert np.isfinite(results_python.att), f"Python ATT {results_python.att} should be finite" + + # ATT estimates should be close (within reasonable tolerance) + att_diff = abs(results_rust.att - results_python.att) + assert att_diff < 0.5, \ + f"Rust ATT ({results_rust.att:.3f}) and Python ATT ({results_python.att:.3f}) " \ + f"differ by {att_diff:.3f}, should be < 0.5" + + class TestFallbackWhenNoRust: """Test that pure Python fallback works when Rust is unavailable.""" diff --git a/tests/test_trop.py b/tests/test_trop.py index 2babc901..eb38e3a0 100644 --- a/tests/test_trop.py +++ b/tests/test_trop.py @@ -2655,3 +2655,641 @@ def test_n_post_periods_counts_observed_treatment(self): assert results.n_post_periods == 2, ( f"Expected 2 post-periods with D=1, got {results.n_post_periods}" ) + + +class TestTROPJointMethod: + """Tests for TROP method='joint'. + + The joint method estimates a single scalar treatment effect τ via + weighted least squares, as opposed to the twostep method which + computes per-observation effects. 
+ """ + + def test_joint_basic(self, simple_panel_data): + """Joint method runs and produces reasonable ATT.""" + trop_est = TROP( + method="joint", + lambda_time_grid=[0.0, 1.0], + lambda_unit_grid=[0.0, 1.0], + lambda_nn_grid=[0.0, 0.1], + n_bootstrap=10, + seed=42, + ) + results = trop_est.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + assert isinstance(results, TROPResults) + assert trop_est.is_fitted_ + assert results.n_obs == len(simple_panel_data) + assert results.n_control == 15 + assert results.n_treated == 5 + # ATT should be positive (true effect is 3.0) + assert results.att > 0 + + def test_joint_no_lowrank(self, simple_panel_data): + """Joint method with lambda_nn=inf (no low-rank).""" + trop_est = TROP( + method="joint", + lambda_time_grid=[0.0], + lambda_unit_grid=[0.0], + lambda_nn_grid=[float('inf')], # Disable low-rank + n_bootstrap=10, + seed=42, + ) + results = trop_est.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + assert isinstance(results, TROPResults) + # Effective rank should be 0 when L=0 + assert results.effective_rank == 0.0 + # Factor matrix should be all zeros + assert np.allclose(results.factor_matrix, 0.0) + + def test_joint_with_lowrank(self, factor_dgp_data): + """Joint method with finite lambda_nn (with low-rank).""" + trop_est = TROP( + method="joint", + lambda_time_grid=[0.0, 1.0], + lambda_unit_grid=[0.0, 1.0], + lambda_nn_grid=[0.0, 0.1, 1.0], + n_bootstrap=20, + seed=42, + ) + results = trop_est.fit( + factor_dgp_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + assert isinstance(results, TROPResults) + assert results.effective_rank >= 0 + # Should produce non-zero factor matrix if low-rank is used + # (depends on which lambda_nn is selected) + + def test_joint_matches_direction(self, simple_panel_data): + """Joint method sign/magnitude roughly matches 
twostep.""" + # Fit with twostep + trop_twostep = TROP( + method="twostep", + lambda_time_grid=[0.0, 1.0], + lambda_unit_grid=[0.0, 1.0], + lambda_nn_grid=[0.0, 0.1], + n_bootstrap=10, + seed=42, + ) + results_twostep = trop_twostep.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + # Fit with joint + trop_joint = TROP( + method="joint", + lambda_time_grid=[0.0, 1.0], + lambda_unit_grid=[0.0, 1.0], + lambda_nn_grid=[0.0, 0.1], + n_bootstrap=10, + seed=42, + ) + results_joint = trop_joint.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + # Both should have positive ATT (true effect is 3.0) + assert results_twostep.att > 0 + assert results_joint.att > 0 + + # Signs should match + assert np.sign(results_twostep.att) == np.sign(results_joint.att) + + def test_method_parameter_validation(self): + """Invalid method raises ValueError.""" + with pytest.raises(ValueError, match="method must be one of"): + TROP(method="invalid_method") + + def test_method_in_get_params(self): + """method parameter appears in get_params().""" + trop_est = TROP(method="joint") + params = trop_est.get_params() + assert "method" in params + assert params["method"] == "joint" + + def test_method_in_set_params(self): + """method parameter can be set via set_params().""" + trop_est = TROP(method="twostep") + assert trop_est.method == "twostep" + + trop_est.set_params(method="joint") + assert trop_est.method == "joint" + + def test_joint_bootstrap_variance(self, simple_panel_data): + """Joint method bootstrap variance estimation works.""" + trop_est = TROP( + method="joint", + lambda_time_grid=[0.0, 1.0], + lambda_unit_grid=[0.0, 1.0], + lambda_nn_grid=[0.0, 0.1], + variance_method="bootstrap", + n_bootstrap=20, + seed=42, + ) + results = trop_est.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + assert results.se > 0 + 
assert results.variance_method == "bootstrap" + assert results.n_bootstrap == 20 + assert results.bootstrap_distribution is not None + + def test_joint_jackknife_variance(self, simple_panel_data): + """Joint method jackknife variance estimation works.""" + trop_est = TROP( + method="joint", + lambda_time_grid=[0.0, 1.0], + lambda_unit_grid=[0.0, 1.0], + lambda_nn_grid=[0.0, 0.1], + variance_method="jackknife", + seed=42, + ) + results = trop_est.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + assert results.se >= 0 + assert results.variance_method == "jackknife" + + def test_joint_confidence_interval(self, simple_panel_data): + """Joint method produces valid confidence intervals.""" + trop_est = TROP( + method="joint", + lambda_time_grid=[0.0, 1.0], + lambda_unit_grid=[0.0, 1.0], + lambda_nn_grid=[0.0, 0.1], + alpha=0.05, + n_bootstrap=30, + seed=42, + ) + results = trop_est.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + lower, upper = results.conf_int + assert lower < results.att < upper + assert lower < upper + + def test_joint_loocv_selects_from_grid(self, simple_panel_data): + """Joint method LOOCV selects tuning parameters from the grid.""" + grid_time = [0.0, 0.5, 1.0] + grid_unit = [0.0, 0.5, 1.0] + grid_nn = [0.0, 0.1] + + trop_est = TROP( + method="joint", + lambda_time_grid=grid_time, + lambda_unit_grid=grid_unit, + lambda_nn_grid=grid_nn, + n_bootstrap=10, + seed=42, + ) + results = trop_est.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + # Selected lambdas should be from the grid + assert results.lambda_time in grid_time + assert results.lambda_unit in grid_unit + assert results.lambda_nn in grid_nn + # LOOCV score should be computed + assert np.isfinite(results.loocv_score) or np.isnan(results.loocv_score) + + def test_joint_loocv_score_internal(self, 
simple_panel_data): + """Test the internal _loocv_score_joint method produces valid scores.""" + trop_est = TROP( + method="joint", + lambda_time_grid=[0.0, 1.0], + lambda_unit_grid=[0.0, 1.0], + lambda_nn_grid=[0.0, 0.1], + seed=42, + ) + + # Setup data matrices + all_units = sorted(simple_panel_data['unit'].unique()) + all_periods = sorted(simple_panel_data['period'].unique()) + n_units = len(all_units) + n_periods = len(all_periods) + + Y = ( + simple_panel_data.pivot(index='period', columns='unit', values='outcome') + .reindex(index=all_periods, columns=all_units) + .values + ) + D = ( + simple_panel_data.pivot(index='period', columns='unit', values='treated') + .reindex(index=all_periods, columns=all_units) + .fillna(0) + .astype(int) + .values + ) + + control_mask = D == 0 + control_obs = [ + (t, i) for t in range(n_periods) for i in range(n_units) + if control_mask[t, i] and not np.isnan(Y[t, i]) + ][:20] # Limit for speed + + treated_periods = 3 # From fixture: n_post = 3 + + # Score should be finite + score = trop_est._loocv_score_joint( + Y, D, control_obs, 0.0, 0.0, 0.0, + treated_periods, n_units, n_periods + ) + assert np.isfinite(score) or np.isinf(score), "Score should be finite or inf" + + # Score with larger lambda_nn should still work + score2 = trop_est._loocv_score_joint( + Y, D, control_obs, 1.0, 1.0, 0.1, + treated_periods, n_units, n_periods + ) + assert np.isfinite(score2) or np.isinf(score2), "Score should be finite or inf" + + def test_joint_handles_nan_outcomes(self, simple_panel_data): + """Joint method handles NaN outcome values gracefully.""" + # Introduce NaN in some control observations + data = simple_panel_data.copy() + control_mask = data['treated'] == 0 + control_indices = data[control_mask].index.tolist() + + # Set 5 random control observations to NaN + np.random.seed(42) + nan_indices = np.random.choice(control_indices, size=5, replace=False) + data.loc[nan_indices, 'outcome'] = np.nan + + trop_est = TROP( + method="joint", + 
lambda_time_grid=[0.0, 1.0], + lambda_unit_grid=[0.0, 1.0], + lambda_nn_grid=[0.0, 0.1], + n_bootstrap=10, + seed=42, + ) + results = trop_est.fit( + data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + # Results should be finite (NaN observations excluded) + assert np.isfinite(results.att), "ATT should be finite with NaN data" + assert np.isfinite(results.se), "SE should be finite with NaN data" + # ATT should be positive (true effect is 3.0) + assert results.att > 0, "ATT should be positive" + + def test_joint_with_lowrank_handles_nan(self, simple_panel_data): + """Joint method with low-rank handles NaN values correctly.""" + # Introduce NaN in some control observations + data = simple_panel_data.copy() + control_mask = data['treated'] == 0 + control_indices = data[control_mask].index.tolist() + + # Set 3 random control observations to NaN + np.random.seed(123) + nan_indices = np.random.choice(control_indices, size=3, replace=False) + data.loc[nan_indices, 'outcome'] = np.nan + + trop_est = TROP( + method="joint", + lambda_time_grid=[0.0], + lambda_unit_grid=[0.0], + lambda_nn_grid=[0.1], # Finite lambda_nn enables low-rank + n_bootstrap=10, + seed=42, + ) + results = trop_est.fit( + data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + # Results should be finite + assert np.isfinite(results.att), "ATT should be finite with NaN data" + assert np.isfinite(results.se), "SE should be finite with NaN data" + + def test_joint_nan_exclusion_behavior(self, simple_panel_data): + """Verify NaN observations are truly excluded from estimation. + + This tests the PR #113 fix: NaN observations should not contribute + to the weighted gradient step. We verify this by comparing results + when fitting on data with NaN vs data with those observations removed. 
+ """ + # Get a clean copy + data_full = simple_panel_data.copy() + + # Identify a specific control observation to "remove" + control_mask = data_full['treated'] == 0 + control_indices = data_full[control_mask].index.tolist() + + # Pick a few specific observations to remove/set to NaN + np.random.seed(42) + remove_indices = np.random.choice(control_indices, size=3, replace=False) + + # Create version with NaN + data_nan = data_full.copy() + data_nan.loc[remove_indices, 'outcome'] = np.nan + + # Create version with rows removed + data_dropped = data_full.drop(remove_indices) + + # Fit on both versions with identical settings + trop_nan = TROP( + method="joint", + lambda_time_grid=[1.0], + lambda_unit_grid=[1.0], + lambda_nn_grid=[0.0], # Disable low-rank for cleaner comparison + n_bootstrap=10, + seed=42, + ) + trop_dropped = TROP( + method="joint", + lambda_time_grid=[1.0], + lambda_unit_grid=[1.0], + lambda_nn_grid=[0.0], + n_bootstrap=10, + seed=42, + ) + + results_nan = trop_nan.fit( + data_nan, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + results_dropped = trop_dropped.fit( + data_dropped, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + # ATT should be very close (allowing small numerical tolerance) + # If NaN observations were not truly excluded, ATT would differ + assert np.abs(results_nan.att - results_dropped.att) < 0.5, ( + f"ATT with NaN ({results_nan.att:.4f}) should match dropped data " + f"({results_dropped.att:.4f}) - true NaN exclusion" + ) + + def test_joint_jackknife_produces_variation(self, simple_panel_data): + """Verify jackknife produces variation across leave-out iterations. + + This tests the PR #113 fix: jackknife should truly exclude units + via weight zeroing, not imputation. If imputation were used, all + jackknife estimates would be nearly identical. 
+ """ + trop_est = TROP( + method="joint", + lambda_time_grid=[1.0], + lambda_unit_grid=[1.0], + lambda_nn_grid=[0.0], + variance_method="jackknife", + seed=42, + ) + results = trop_est.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + # SE should be positive (variation exists) + assert results.se > 0, "Jackknife SE should be positive" + + # If we can access jackknife estimates, verify they vary + # (The SE being > 0 already implies variation, but this is more explicit) + if hasattr(results, 'bootstrap_distribution') and results.bootstrap_distribution is not None: + # For jackknife, this stores the jackknife estimates + jack_estimates = results.bootstrap_distribution + if len(jack_estimates) > 1: + estimate_std = np.std(jack_estimates) + assert estimate_std > 0, "Jackknife estimates should vary" + + def test_joint_unit_no_valid_pre_gets_zero_weight(self, simple_panel_data): + """Verify units with no valid pre-period data get zero weight. + + This tests the PR #113 fix: units with no valid pre-period observations + should get zero weight (instead of max weight via dist=0). 
+ """ + # Create data where one control unit has all NaN in pre-period + data = simple_panel_data.copy() + + # Find a control unit (unit that never has treated=1) + unit_ever_treated = data.groupby('unit')['treated'].max() + control_units = unit_ever_treated[unit_ever_treated == 0].index.tolist() + target_unit = control_units[0] + + # Get pre-periods (periods where this control unit has treated=0) + unit_data = data[data['unit'] == target_unit] + pre_periods = sorted(unit_data[unit_data['treated'] == 0]['period'].unique())[:5] + + # Set all pre-period values for target_unit to NaN + mask = (data['unit'] == target_unit) & (data['period'].isin(pre_periods)) + data.loc[mask, 'outcome'] = np.nan + + trop_est = TROP( + method="joint", + lambda_time_grid=[1.0], + lambda_unit_grid=[1.0], # Non-zero lambda_unit to use distance weighting + lambda_nn_grid=[0.0], + n_bootstrap=10, + seed=42, + ) + + # This should not error and should produce finite results + results = trop_est.fit( + data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + assert np.isfinite(results.att), "ATT should be finite even with unit having no pre-period data" + assert np.isfinite(results.se), "SE should be finite" + + def test_joint_treated_pre_nan_handling(self, simple_panel_data): + """Verify joint method handles NaN in treated units during pre-periods. + + When all treated units have NaN at a pre-period, average_treated[t] = NaN. + This period should be excluded from unit distance calculation (both numerator + and denominator) to avoid inflating valid_count. + + This tests the fix for PR #113 Round 5 feedback (P1). 
+ """ + data = simple_panel_data.copy() + + # Find treated units and pre-periods + treated_units = data[data['treated'] == 1]['unit'].unique() + # Pre-periods are periods where treated=0 for treated units + pre_periods = sorted( + data[(data['unit'].isin(treated_units)) & (data['treated'] == 0)]['period'].unique() + ) + assert len(pre_periods) >= 2, "Need at least 2 pre-periods for this test" + + # Pick a middle pre-period + target_period = pre_periods[len(pre_periods) // 2] + + # Set ALL treated units' outcomes at target_period to NaN + # This makes average_treated[target_period] = NaN + mask = (data['unit'].isin(treated_units)) & (data['period'] == target_period) + data.loc[mask, 'outcome'] = np.nan + + # Verify we set NaN correctly + n_nan = data.loc[mask, 'outcome'].isna().sum() + assert n_nan == len(treated_units), f"Should have {len(treated_units)} NaN, got {n_nan}" + + trop_est = TROP( + method="joint", + lambda_time_grid=[1.0], + lambda_unit_grid=[1.0], + lambda_nn_grid=[0.0], + n_bootstrap=10, + seed=42, + ) + results = trop_est.fit( + data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + # Results should be finite - NaN period properly excluded from distance calc + assert np.isfinite(results.att), f"ATT should be finite, got {results.att}" + assert np.isfinite(results.se), f"SE should be finite, got {results.se}" + + def test_joint_rejects_staggered_adoption(self): + """Joint method raises ValueError for staggered adoption data. + + The joint method assumes all treated units receive treatment at the + same time. With staggered adoption (units first treated at different + periods), the method's weights and variance estimation are invalid. 
+ """ + # Create data with staggered treatment (units treated at different times) + data = [] + np.random.seed(42) + for i in range(10): + # Units 0-2 first treated at t=5, units 3-4 first treated at t=7 + first_treat = 5 if i < 3 else 7 + is_treated_unit = i < 5 # Units 0-4 are treated, 5-9 are control + for t in range(10): + treated = 1 if is_treated_unit and t >= first_treat else 0 + data.append({ + 'unit': i, + 'time': t, + 'outcome': np.random.randn(), + 'treated': treated + }) + df = pd.DataFrame(data) + + trop = TROP(method="joint") + with pytest.raises(ValueError, match="staggered adoption"): + trop.fit(df, 'outcome', 'treated', 'unit', 'time') + + def test_joint_python_loocv_subsampling(self): + """Test that joint method works with Python-only LOOCV when control_obs > max_loocv_samples. + + This tests the fix for PR #113 Round 7 feedback (P1): Python fallback + LOOCV sampling could raise ValueError when control_obs is a list of tuples. + """ + from unittest.mock import patch + import sys + + np.random.seed(42) + # Create data with many control observations (> default max_loocv_samples=500) + n_units, n_periods = 30, 25 # 30*25 = 750 observations, most are control + n_treated = 3 + n_post = 3 + + data = [] + for i in range(n_units): + is_treated = i < n_treated + for t in range(n_periods): + post = t >= (n_periods - n_post) + y = 10.0 + i * 0.1 + t * 0.1 + np.random.randn() * 0.5 + treatment_indicator = 1 if (is_treated and post) else 0 + if treatment_indicator: + y += 2.0 + data.append({ + 'unit': i, + 'time': t, + 'outcome': y, + 'treated': treatment_indicator, + }) + + df = pd.DataFrame(data) + + # Patch to force Python backend and set small max_loocv_samples + trop_module = sys.modules['diff_diff.trop'] + + with patch.object(trop_module, 'HAS_RUST_BACKEND', False), \ + patch.object(trop_module, '_rust_loocv_grid_search_joint', None), \ + patch.object(trop_module, '_rust_bootstrap_trop_variance_joint', None): + + # Use small max_loocv_samples to trigger 
subsampling + trop_est = TROP( + method="joint", + lambda_time_grid=[1.0], + lambda_unit_grid=[1.0], + lambda_nn_grid=[0.0], + max_loocv_samples=100, # Force subsampling (control_obs > 100) + n_bootstrap=0, + seed=42 + ) + + # This should not raise ValueError + results = trop_est.fit(df, 'outcome', 'treated', 'unit', 'time') + + assert isinstance(results, TROPResults) + assert np.isfinite(results.att)