diff --git a/.gitignore b/.gitignore index 272a10cf..b430d9b4 100644 --- a/.gitignore +++ b/.gitignore @@ -66,3 +66,6 @@ Cargo.lock # Maturin build artifacts target/ + +# Claude Code - local settings (user-specific permissions) +.claude/settings.local.json diff --git a/TODO.md b/TODO.md index 2e800f91..cee5c945 100644 --- a/TODO.md +++ b/TODO.md @@ -85,6 +85,10 @@ Enhancements for `honest_did.py`: ## CallawaySantAnna Bootstrap Improvements - [ ] Consider aligning p-value computation with R `did` package (symmetric percentile method) +- [ ] Investigate RuntimeWarnings in influence function aggregation (`staggered.py:1722`, `staggered.py:1999-2018`) + - Warnings: "divide by zero", "overflow", "invalid value" in matmul operations + - Occurs during bootstrap SE computation with small sample sizes or edge cases + - Does not affect correctness (results are still valid), but should be suppressed or handled gracefully --- diff --git a/diff_diff/pretrends.py b/diff_diff/pretrends.py index fbda06e4..fb88e96e 100644 --- a/diff_diff/pretrends.py +++ b/diff_diff/pretrends.py @@ -202,6 +202,63 @@ def to_dataframe(self) -> pd.DataFrame: """Convert results to DataFrame.""" return pd.DataFrame([self.to_dict()]) + def power_at(self, M: float) -> float: + """ + Compute power to detect a specific violation magnitude. + + This method allows computing power at different M values without + re-fitting the model, using the stored variance-covariance matrix. + + Parameters + ---------- + M : float + Violation magnitude to evaluate. + + Returns + ------- + float + Power to detect violation of magnitude M. + """ + from scipy import stats + + n_pre = self.n_pre_periods + + # Reconstruct violation weights based on violation type + # Must match PreTrendsPower._get_violation_weights() exactly + if self.violation_type == "linear": + # Linear trend: weights decrease toward treatment + # [n-1, n-2, ..., 1, 0] for n pre-periods + weights = np.arange(-n_pre + 1, 1, dtype=float) + weights = -weights # Now [n-1, n-2, ..., 1, 0] + elif self.violation_type == "constant": + weights = np.ones(n_pre) + elif self.violation_type == "last_period": + weights = np.zeros(n_pre) + weights[-1] = 1.0 + else: + # For custom, we can't reconstruct - use equal weights as fallback + weights = np.ones(n_pre) + + # Normalize weights to unit L2 norm + norm = np.linalg.norm(weights) + if norm > 0: + weights = weights / norm + + # Compute non-centrality parameter + try: + vcov_inv = np.linalg.inv(self.vcov) + except np.linalg.LinAlgError: + vcov_inv = np.linalg.pinv(self.vcov) + + # delta = M * weights + # nc = delta' * V^{-1} * delta + noncentrality = M**2 * (weights @ vcov_inv @ weights) + + # Compute power using non-central chi-squared + power = 1 - stats.ncx2.cdf(self.critical_value, df=n_pre, nc=noncentrality) + + return float(power) + @dataclass class PreTrendsPowerCurve: @@ -471,10 +528,18 @@ def _get_violation_weights(self, n_pre: int) -> np.ndarray: def _extract_pre_period_params( self, results: Union[MultiPeriodDiDResults, Any], + pre_periods: Optional[List[int]] = None, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]: """ Extract pre-period parameters from results. + Parameters + ---------- + results : MultiPeriodDiDResults or similar + Results object from event study estimation. + pre_periods : list of int, optional + Explicit list of pre-treatment periods. If None, uses results.pre_periods. + Returns ------- effects : np.ndarray @@ -487,13 +552,18 @@ def _extract_pre_period_params( Number of pre-periods. """ if isinstance(results, MultiPeriodDiDResults): - # Get pre-period information - all_pre_periods = results.pre_periods + # Get pre-period information - use explicit pre_periods if provided + if pre_periods is not None: + all_pre_periods = list(pre_periods) + else: + all_pre_periods = results.pre_periods if len(all_pre_periods) == 0: raise ValueError( "No pre-treatment periods found in results. " - "Pre-trends power analysis requires pre-period coefficients." + "Pre-trends power analysis requires pre-period coefficients. " + "If you estimated all periods as post_periods, use the pre_periods " + "parameter to specify which are actually pre-treatment." ) # Only include periods with actual estimated coefficients @@ -775,6 +845,7 @@ def fit( self, results: Union[MultiPeriodDiDResults, Any], M: Optional[float] = None, + pre_periods: Optional[List[int]] = None, ) -> PreTrendsPowerResults: """ Compute pre-trends power analysis. @@ -786,6 +857,11 @@ def fit( M : float, optional Specific violation magnitude to evaluate. If None, evaluates at a default magnitude based on the data. + pre_periods : list of int, optional + Explicit list of pre-treatment periods to use for power analysis. + If None, attempts to infer from results.pre_periods. Use this when + you've estimated an event study with all periods in post_periods + and need to specify which are actually pre-treatment. Returns ------- @@ -793,7 +869,7 @@ def fit( Power analysis results including power and MDV. """ # Extract pre-period parameters - effects, ses, vcov, n_pre = self._extract_pre_period_params(results) + effects, ses, vcov, n_pre = self._extract_pre_period_params(results, pre_periods) # Get violation weights weights = self._get_violation_weights(n_pre) @@ -831,6 +907,7 @@ def power_at( self, results: Union[MultiPeriodDiDResults, Any], M: float, + pre_periods: Optional[List[int]] = None, ) -> float: """ Compute power to detect a specific violation magnitude. @@ -841,13 +918,15 @@ def power_at( Event study results. M : float Violation magnitude. + pre_periods : list of int, optional + Explicit list of pre-treatment periods. See fit() for details. Returns ------- float Power to detect violation of magnitude M. """ - result = self.fit(results, M=M) + result = self.fit(results, M=M, pre_periods=pre_periods) return result.power def power_curve( @@ -855,6 +934,7 @@ def power_curve( results: Union[MultiPeriodDiDResults, Any], M_grid: Optional[List[float]] = None, n_points: int = 50, + pre_periods: Optional[List[int]] = None, ) -> PreTrendsPowerCurve: """ Compute power across a range of violation magnitudes. @@ -868,6 +948,8 @@ def power_curve( automatic grid from 0 to 2.5 * MDV. n_points : int, default=50 Number of points in automatic grid. + pre_periods : list of int, optional + Explicit list of pre-treatment periods. See fit() for details. Returns ------- @@ -875,7 +957,7 @@ def power_curve( Power curve data with plot method. """ # Extract parameters - effects, ses, vcov, n_pre = self._extract_pre_period_params(results) + _, ses, vcov, n_pre = self._extract_pre_period_params(results, pre_periods) weights = self._get_violation_weights(n_pre) # Compute MDV @@ -906,6 +988,7 @@ def power_curve( def sensitivity_to_honest_did( self, results: Union[MultiPeriodDiDResults, Any], + pre_periods: Optional[List[int]] = None, ) -> Dict[str, Any]: """ Compare pre-trends power analysis with HonestDiD sensitivity. @@ -917,6 +1000,8 @@ def sensitivity_to_honest_did( ---------- results : results object Event study results. + pre_periods : list of int, optional + Explicit list of pre-treatment periods. See fit() for details. Returns ------- @@ -926,7 +1011,7 @@ def sensitivity_to_honest_did( - honest_M_at_mdv: Corresponding M value for HonestDiD - interpretation: Text explaining the relationship """ - pt_results = self.fit(results) + pt_results = self.fit(results, pre_periods=pre_periods) mdv = pt_results.mdv # The MDV represents the size of violation the test could detect @@ -993,6 +1078,7 @@ def compute_pretrends_power( alpha: float = 0.05, target_power: float = 0.80, violation_type: str = "linear", + pre_periods: Optional[List[int]] = None, ) -> PreTrendsPowerResults: """ Convenience function for pre-trends power analysis. @@ -1009,6 +1095,9 @@ def compute_pretrends_power( Target power for MDV calculation. violation_type : str, default='linear' Type of violation pattern. + pre_periods : list of int, optional + Explicit list of pre-treatment periods. If None, attempts to infer + from results. Use when you've estimated all periods as post_periods. Returns ------- @@ -1021,7 +1110,7 @@ def compute_pretrends_power( >>> from diff_diff.pretrends import compute_pretrends_power >>> >>> results = MultiPeriodDiD().fit(data, ...) - >>> power_results = compute_pretrends_power(results) + >>> power_results = compute_pretrends_power(results, pre_periods=[0, 1, 2, 3]) >>> print(f"MDV: {power_results.mdv:.3f}") >>> print(f"Power: {power_results.power:.1%}") """ @@ -1030,7 +1119,7 @@ def compute_pretrends_power( power=target_power, violation_type=violation_type, ) - return pt.fit(results, M=M) + return pt.fit(results, M=M, pre_periods=pre_periods) def compute_mdv( @@ -1038,6 +1127,7 @@ def compute_mdv( alpha: float = 0.05, target_power: float = 0.80, violation_type: str = "linear", + pre_periods: Optional[List[int]] = None, ) -> float: """ Compute minimum detectable violation. @@ -1049,9 +1139,12 @@ def compute_mdv( alpha : float, default=0.05 Significance level. target_power : float, default=0.80 - Target power. + Target power for MDV calculation. violation_type : str, default='linear' Type of violation pattern. + pre_periods : list of int, optional + Explicit list of pre-treatment periods. If None, attempts to infer + from results. Use when you've estimated all periods as post_periods. Returns ------- @@ -1063,5 +1156,5 @@ def compute_mdv( power=target_power, violation_type=violation_type, ) - result = pt.fit(results) + result = pt.fit(results, pre_periods=pre_periods) return result.mdv diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py index 01d1148c..f1ee7459 100644 --- a/diff_diff/staggered.py +++ b/diff_diff/staggered.py @@ -1053,6 +1053,10 @@ def fit( df[time] = pd.to_numeric(df[time]) df[first_treat] = pd.to_numeric(df[first_treat]) + # Standardize the first_treat column name for internal use + # This avoids hardcoding column names in internal methods + df['first_treat'] = df[first_treat] + # Identify groups and time periods time_periods = sorted(df[time].unique()) treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0]) diff --git a/docs/tutorials/01_basic_did.ipynb b/docs/tutorials/01_basic_did.ipynb index 12037658..48c744ae 100644 --- a/docs/tutorials/01_basic_did.ipynb +++ b/docs/tutorials/01_basic_did.ipynb @@ -19,7 +19,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:10:35.801900Z", + "iopub.status.busy": "2026-01-18T18:10:35.801757Z", + "iopub.status.idle": "2026-01-18T18:10:36.233206Z", + "shell.execute_reply": "2026-01-18T18:10:36.232925Z" + } + }, "outputs": [], "source": [ "import numpy as np\n", @@ -40,7 +47,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:10:36.234607Z", + "iopub.status.busy": "2026-01-18T18:10:36.234517Z", + "iopub.status.idle": "2026-01-18T18:10:36.239917Z", + "shell.execute_reply": "2026-01-18T18:10:36.239725Z" + } + }, "outputs": [], "source": [ "# Generate synthetic DiD data with known ATT of 5.0\n", @@ -49,7 +63,8 @@ " n_periods=2,\n", " treatment_effect=5.0,\n", " treatment_fraction=0.5,\n", - " noise_std=1.0,\n", + " treatment_period=1, # Period 1 is post-treatment (periods are 0 and 1)\n", + " noise_sd=1.0,\n", " seed=42\n", ")\n", "\n", @@ -60,7 +75,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:10:36.251493Z", + "iopub.status.busy": "2026-01-18T18:10:36.251420Z", + "iopub.status.idle": "2026-01-18T18:10:36.254312Z", + "shell.execute_reply": "2026-01-18T18:10:36.254088Z" + } + }, "outputs": [], "source": [ "# Examine the data structure\n", @@ -80,7 +102,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:10:36.255333Z", + "iopub.status.busy": "2026-01-18T18:10:36.255268Z", + "iopub.status.idle": "2026-01-18T18:10:36.257620Z", + "shell.execute_reply": "2026-01-18T18:10:36.257437Z" + } + }, "outputs": [], "source": [ "# Create the estimator\n", @@ -115,7 +144,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:10:36.258588Z", + "iopub.status.busy": "2026-01-18T18:10:36.258510Z", + "iopub.status.idle": "2026-01-18T18:10:36.260137Z", + "shell.execute_reply": "2026-01-18T18:10:36.259955Z" + } + }, "outputs": [], "source": [ "# Access individual components\n", @@ -140,7 +176,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:10:36.261002Z", + "iopub.status.busy": "2026-01-18T18:10:36.260952Z", + "iopub.status.idle": "2026-01-18T18:10:36.262794Z", + "shell.execute_reply": "2026-01-18T18:10:36.262624Z" + } + }, "outputs": [], "source": [ "# Using formula interface (R-style)\n", @@ -156,7 +199,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:10:36.263682Z", + "iopub.status.busy": "2026-01-18T18:10:36.263631Z", + "iopub.status.idle": "2026-01-18T18:10:36.265008Z", + "shell.execute_reply": "2026-01-18T18:10:36.264810Z" + } + }, "outputs": [], "source": [ "# Verify both methods give the same result\n", @@ -177,7 +227,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:10:36.265866Z", + "iopub.status.busy": "2026-01-18T18:10:36.265814Z", + "iopub.status.idle": "2026-01-18T18:10:36.268138Z", + "shell.execute_reply": "2026-01-18T18:10:36.267940Z" + } + }, "outputs": [], "source": [ "# Add some covariates to our data\n", @@ -201,7 +258,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:10:36.269011Z", + "iopub.status.busy": "2026-01-18T18:10:36.268965Z", + "iopub.status.idle": "2026-01-18T18:10:36.270317Z", + "shell.execute_reply": "2026-01-18T18:10:36.270123Z" + } + }, "outputs": [], "source": [ "# All coefficient estimates are available\n", @@ -225,7 +289,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:10:36.271191Z", + "iopub.status.busy": "2026-01-18T18:10:36.271130Z", + "iopub.status.idle": "2026-01-18T18:10:36.274752Z", + "shell.execute_reply": "2026-01-18T18:10:36.274501Z" + } + }, "outputs": [], "source": [ "# Generate data with more structure\n", @@ -263,7 +334,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:10:36.275657Z", + "iopub.status.busy": "2026-01-18T18:10:36.275605Z", + "iopub.status.idle": "2026-01-18T18:10:36.277635Z", + "shell.execute_reply": "2026-01-18T18:10:36.277443Z" + } + }, "outputs": [], "source": [ "# Using fixed effects with dummy variables\n", @@ -282,7 +360,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:10:36.278482Z", + "iopub.status.busy": "2026-01-18T18:10:36.278432Z", + "iopub.status.idle": "2026-01-18T18:10:36.280535Z", + "shell.execute_reply": "2026-01-18T18:10:36.280352Z" + } + }, "outputs": [], "source": [ "# Using absorbed fixed effects (within-transformation)\n", @@ -311,7 +396,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:10:36.281443Z", + "iopub.status.busy": "2026-01-18T18:10:36.281372Z", + "iopub.status.idle": "2026-01-18T18:10:36.284078Z", + "shell.execute_reply": "2026-01-18T18:10:36.283923Z" + } + }, "outputs": [], "source": [ "# Two-Way Fixed Effects estimator\n", @@ -341,7 +433,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:10:36.285055Z", + "iopub.status.busy": "2026-01-18T18:10:36.284989Z", + "iopub.status.idle": "2026-01-18T18:10:36.287340Z", + "shell.execute_reply": "2026-01-18T18:10:36.287159Z" + } + }, "outputs": [], "source": [ "# Create clustered data\n", @@ -379,7 +478,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:10:36.288239Z", + "iopub.status.busy": "2026-01-18T18:10:36.288176Z", + "iopub.status.idle": "2026-01-18T18:10:36.290659Z", + "shell.execute_reply": "2026-01-18T18:10:36.290472Z" + } + }, "outputs": [], "source": [ "# Compare standard errors: robust vs cluster-robust\n", @@ -418,7 +524,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:10:36.291566Z", + "iopub.status.busy": "2026-01-18T18:10:36.291505Z", + "iopub.status.idle": "2026-01-18T18:10:36.346217Z", + "shell.execute_reply": "2026-01-18T18:10:36.346015Z" + } + }, "outputs": [], "source": [ "# Wild cluster bootstrap inference\n", @@ -443,7 +556,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:10:36.347172Z", + "iopub.status.busy": "2026-01-18T18:10:36.347118Z", + "iopub.status.idle": "2026-01-18T18:10:36.348824Z", + "shell.execute_reply": "2026-01-18T18:10:36.348617Z" + } + }, "outputs": [], "source": [ "# Compare inference methods\n", @@ -466,7 +586,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:10:36.349842Z", + "iopub.status.busy": "2026-01-18T18:10:36.349787Z", + "iopub.status.idle": "2026-01-18T18:10:36.351291Z", + "shell.execute_reply": "2026-01-18T18:10:36.351112Z" + } + }, "outputs": [], "source": [ "# Export to dictionary\n", @@ -482,7 +609,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:10:36.352211Z", + "iopub.status.busy": "2026-01-18T18:10:36.352147Z", + "iopub.status.idle": "2026-01-18T18:10:36.355144Z", + "shell.execute_reply": "2026-01-18T18:10:36.354965Z" + } + }, "outputs": [], "source": [ "# Export to DataFrame (useful for combining multiple estimates)\n", @@ -529,7 +663,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11" + "version": "3.12.4" } }, "nbformat": 4, diff --git a/docs/tutorials/02_staggered_did.ipynb b/docs/tutorials/02_staggered_did.ipynb index cf6b2a00..8a4c35db 100644 --- a/docs/tutorials/02_staggered_did.ipynb +++ b/docs/tutorials/02_staggered_did.ipynb @@ -3,14 +3,54 @@ { "cell_type": "markdown", "metadata": {}, - "source": "# Staggered Difference-in-Differences\n\nThis notebook demonstrates how to handle **staggered treatment adoption** using modern DiD estimators. In staggered DiD settings:\n\n- Different units get treated at different times\n- Traditional TWFE can give biased estimates due to \"forbidden comparisons\"\n- Modern estimators compute group-time specific effects and aggregate them properly\n\nWe'll cover:\n1. Understanding staggered adoption\n2. The problem with TWFE (and Goodman-Bacon decomposition)\n3. The Callaway-Sant'Anna estimator\n4. Group-time effects ATT(g,t)\n5. Aggregating effects (simple, group, event-study)\n6. Bootstrap inference for valid standard errors\n7. Visualization\n8. **Sun-Abraham interaction-weighted estimator**\n9. **Comparing CS and SA as a robustness check**" + "source": [ + "# Staggered Difference-in-Differences\n", + "\n", + "This notebook demonstrates how to handle **staggered treatment adoption** using modern DiD estimators. In staggered DiD settings:\n", + "\n", + "- Different units get treated at different times\n", + "- Traditional TWFE can give biased estimates due to \"forbidden comparisons\"\n", + "- Modern estimators compute group-time specific effects and aggregate them properly\n", + "\n", + "We'll cover:\n", + "1. Understanding staggered adoption\n", + "2. The problem with TWFE (and Goodman-Bacon decomposition)\n", + "3. The Callaway-Sant'Anna estimator\n", + "4. Group-time effects ATT(g,t)\n", + "5. Aggregating effects (simple, group, event-study)\n", + "6. Bootstrap inference for valid standard errors\n", + "7. Visualization\n", + "8. **Sun-Abraham interaction-weighted estimator**\n", + "9. **Comparing CS and SA as a robustness check**" + ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:23:53.563913Z", + "iopub.status.busy": "2026-01-18T18:23:53.563787Z", + "iopub.status.idle": "2026-01-18T18:23:54.146380Z", + "shell.execute_reply": "2026-01-18T18:23:54.146080Z" + } + }, "outputs": [], - "source": "import numpy as np\nimport pandas as pd\nfrom diff_diff import CallawaySantAnna, SunAbraham, MultiPeriodDiD\nfrom diff_diff.visualization import plot_event_study, plot_group_effects\n\n# For nicer plots (optional)\ntry:\n import matplotlib.pyplot as plt\n plt.style.use('seaborn-v0_8-whitegrid')\n HAS_MATPLOTLIB = True\nexcept ImportError:\n HAS_MATPLOTLIB = False\n print(\"matplotlib not installed - visualization examples will be skipped\")" + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from diff_diff import CallawaySantAnna, SunAbraham, MultiPeriodDiD\n", + "from diff_diff.visualization import plot_event_study, plot_group_effects\n", + "\n", + "# For nicer plots (optional)\n", + "try:\n", + " import matplotlib.pyplot as plt\n", + " plt.style.use('seaborn-v0_8-whitegrid')\n", + " HAS_MATPLOTLIB = True\n", + "except ImportError:\n", + " HAS_MATPLOTLIB = False\n", + " print(\"matplotlib not installed - visualization examples will be skipped\")" + ] }, { "cell_type": "markdown", @@ -24,7 +64,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:23:54.147667Z", + "iopub.status.busy": "2026-01-18T18:23:54.147584Z", + "iopub.status.idle": "2026-01-18T18:23:54.153830Z", + "shell.execute_reply": "2026-01-18T18:23:54.153627Z" + } + }, "outputs": [], "source": [ "# Generate staggered adoption data\n", @@ -77,7 +124,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:23:54.165680Z", + "iopub.status.busy": "2026-01-18T18:23:54.165600Z", + "iopub.status.idle": "2026-01-18T18:23:54.168804Z", + "shell.execute_reply": "2026-01-18T18:23:54.168561Z" + } + }, "outputs": [], "source": [ "# Examine treatment timing\n", @@ -105,7 +159,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:23:54.169846Z", + "iopub.status.busy": "2026-01-18T18:23:54.169789Z", + "iopub.status.idle": "2026-01-18T18:23:54.173444Z", + "shell.execute_reply": "2026-01-18T18:23:54.173244Z" + } + }, "outputs": [], "source": [ "from diff_diff import TwoWayFixedEffects\n", @@ -126,22 +187,80 @@ }, { "cell_type": "markdown", - "source": "### Understanding *Why* TWFE Fails: Goodman-Bacon Decomposition\n\nThe Goodman-Bacon (2021) decomposition reveals exactly why TWFE can be biased. It shows that the TWFE estimate is a weighted average of all possible 2x2 DiD comparisons, including problematic \"forbidden comparisons\" where already-treated units are used as controls.\n\nThere are three types of comparisons:\n1. **Treated vs Never-treated** (green): Clean comparisons using never-treated units\n2. **Earlier vs Later treated** (blue): Uses later-treated as controls before they're treated\n3. **Later vs Earlier treated** (red): Uses already-treated as controls — the \"forbidden comparisons\"\n\nWhen treatment effects are heterogeneous (as in our data where effects grow over time), the forbidden comparisons can bias the TWFE estimate.", - "metadata": {} + "metadata": {}, + "source": [ + "### Understanding *Why* TWFE Fails: Goodman-Bacon Decomposition\n", + "\n", + "The Goodman-Bacon (2021) decomposition reveals exactly why TWFE can be biased. It shows that the TWFE estimate is a weighted average of all possible 2x2 DiD comparisons, including problematic \"forbidden comparisons\" where already-treated units are used as controls.\n", + "\n", + "There are three types of comparisons:\n", + "1. **Treated vs Never-treated** (green): Clean comparisons using never-treated units\n", + "2. **Earlier vs Later treated** (blue): Uses later-treated as controls before they're treated\n", + "3. **Later vs Earlier treated** (red): Uses already-treated as controls — the \"forbidden comparisons\"\n", + "\n", + "When treatment effects are heterogeneous (as in our data where effects grow over time), the forbidden comparisons can bias the TWFE estimate." + ] }, { "cell_type": "code", - "source": "from diff_diff import bacon_decompose, plot_bacon\n\n# Perform the Goodman-Bacon decomposition\nbacon_results = bacon_decompose(\n df,\n outcome='outcome',\n unit='unit',\n time='period',\n first_treat='cohort' # Same as 'cohort' column - 0 means never-treated\n)\n\n# View the decomposition summary\nbacon_results.print_summary()", - "metadata": {}, "execution_count": null, - "outputs": [] + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:23:54.174452Z", + "iopub.status.busy": "2026-01-18T18:23:54.174400Z", + "iopub.status.idle": "2026-01-18T18:23:54.180139Z", + "shell.execute_reply": "2026-01-18T18:23:54.179950Z" + } + }, + "outputs": [], + "source": [ + "from diff_diff import bacon_decompose, plot_bacon\n", + "\n", + "# Perform the Goodman-Bacon decomposition\n", + "bacon_results = bacon_decompose(\n", + " df,\n", + " outcome='outcome',\n", + " unit='unit',\n", + " time='period',\n", + " first_treat='cohort' # Same as 'cohort' column - 0 means never-treated\n", + ")\n", + "\n", + "# View the decomposition summary\n", + "bacon_results.print_summary()" + ] }, { "cell_type": "code", - "source": "# Visualize the decomposition\nif HAS_MATPLOTLIB:\n fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n \n # Scatter plot: shows each 2x2 comparison\n plot_bacon(bacon_results, ax=axes[0], plot_type='scatter', show=False)\n \n # Bar chart: shows total weight by comparison type\n plot_bacon(bacon_results, ax=axes[1], plot_type='bar', show=False)\n \n plt.tight_layout()\n plt.show()\n \n # Interpret the results\n forbidden_weight = bacon_results.total_weight_later_vs_earlier\n print(f\"\\n⚠️ {forbidden_weight:.1%} of the TWFE weight comes from 'forbidden comparisons'\")\n print(\" where already-treated units are used as controls.\")\n print(\"\\n→ This explains why TWFE can be biased. Use Callaway-Sant'Anna instead!\")", - "metadata": {}, "execution_count": null, - "outputs": [] + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:23:54.181019Z", + "iopub.status.busy": "2026-01-18T18:23:54.180969Z", + "iopub.status.idle": "2026-01-18T18:23:54.283154Z", + "shell.execute_reply": "2026-01-18T18:23:54.282934Z" + } + }, + "outputs": [], + "source": [ + "# Visualize the decomposition\n", + "if HAS_MATPLOTLIB:\n", + " fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n", + " \n", + " # Scatter plot: shows each 2x2 comparison\n", + " plot_bacon(bacon_results, ax=axes[0], plot_type='scatter', show=False)\n", + " \n", + " # Bar chart: shows total weight by comparison type\n", + " plot_bacon(bacon_results, ax=axes[1], plot_type='bar', show=False)\n", + " \n", + " plt.tight_layout()\n", + " plt.show()\n", + " \n", + " # Interpret the results\n", + " forbidden_weight = bacon_results.total_weight_later_vs_earlier\n", + " print(f\"\\n⚠️ {forbidden_weight:.1%} of the TWFE weight comes from 'forbidden comparisons'\")\n", + " print(\" where already-treated units are used as controls.\")\n", + " print(\"\\n→ This explains why TWFE can be biased. Use Callaway-Sant'Anna instead!\")" + ] }, { "cell_type": "markdown", @@ -158,14 +277,20 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:23:54.284227Z", + "iopub.status.busy": "2026-01-18T18:23:54.284141Z", + "iopub.status.idle": "2026-01-18T18:23:54.289412Z", + "shell.execute_reply": "2026-01-18T18:23:54.289217Z" + } + }, "outputs": [], "source": [ "# Callaway-Sant'Anna estimation\n", "cs = CallawaySantAnna(\n", " control_group=\"never_treated\", # Use never-treated as controls\n", - " anticipation=0, # No anticipation effects\n", - " base_period=\"universal\" # Use period before treatment for each group\n", + " anticipation=0 # No anticipation effects\n", ")\n", "\n", "results_cs = cs.fit(\n", @@ -173,7 +298,8 @@ " outcome=\"outcome\",\n", " unit=\"unit\",\n", " time=\"period\",\n", - " cohort=\"cohort\"\n", + " first_treat=\"cohort\", # Column with first treatment period (0 = never treated)\n", + " aggregate=\"all\" # Compute all aggregations (simple, event_study, group)\n", ")\n", "\n", "print(results_cs.summary())" @@ -193,23 +319,37 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:23:54.290406Z", + "iopub.status.busy": "2026-01-18T18:23:54.290354Z", + "iopub.status.idle": "2026-01-18T18:23:54.291918Z", + "shell.execute_reply": "2026-01-18T18:23:54.291715Z" + } + }, "outputs": [], "source": [ "# View all group-time effects\n", "print(\"Group-Time Effects ATT(g,t):\")\n", "print(\"=\" * 60)\n", "\n", - "for gt_effect in results_cs.group_time_effects:\n", - " sig = \"*\" if gt_effect.p_value < 0.05 else \"\"\n", - " print(f\"ATT({gt_effect.cohort},{gt_effect.time}): {gt_effect.att:>7.4f} \"\n", - " f\"(SE: {gt_effect.se:.4f}, p: {gt_effect.p_value:.3f}) {sig}\")" + "for (g, t), data in results_cs.group_time_effects.items():\n", + " sig = \"*\" if data['p_value'] < 0.05 else \"\"\n", + " print(f\"ATT({g},{t}): {data['effect']:>7.4f} \"\n", + " f\"(SE: {data['se']:.4f}, p: {data['p_value']:.3f}) {sig}\")" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:23:54.292699Z", + "iopub.status.busy": "2026-01-18T18:23:54.292649Z", + "iopub.status.idle": "2026-01-18T18:23:54.295835Z", + "shell.execute_reply": "2026-01-18T18:23:54.295677Z" + } + }, "outputs": [], "source": [ "# Convert to DataFrame for easier analysis\n", @@ -230,80 +370,180 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:23:54.296754Z", + "iopub.status.busy": "2026-01-18T18:23:54.296696Z", + "iopub.status.idle": "2026-01-18T18:23:54.298177Z", + "shell.execute_reply": "2026-01-18T18:23:54.297978Z" + } + }, "outputs": [], "source": [ "# Simple aggregation: weighted average across all (g,t)\n", - "simple_agg = results_cs.aggregate(\"simple\")\n", - "\n", + "# This is computed automatically and stored in overall_att/overall_se\n", "print(\"Simple Aggregation (Overall ATT):\")\n", - "print(f\"ATT: {simple_agg['att']:.4f}\")\n", - "print(f\"SE: {simple_agg['se']:.4f}\")\n", - "print(f\"95% CI: [{simple_agg['conf_int'][0]:.4f}, {simple_agg['conf_int'][1]:.4f}]\")" + "print(f\"ATT: {results_cs.overall_att:.4f}\")\n", + "print(f\"SE: {results_cs.overall_se:.4f}\")\n", + "print(f\"95% CI: [{results_cs.overall_conf_int[0]:.4f}, {results_cs.overall_conf_int[1]:.4f}]\")" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:23:54.299040Z", + "iopub.status.busy": "2026-01-18T18:23:54.298981Z", + "iopub.status.idle": "2026-01-18T18:23:54.300379Z", + "shell.execute_reply": "2026-01-18T18:23:54.300207Z" + } + }, "outputs": [], "source": [ "# Group aggregation: average effect by cohort\n", - "group_agg = results_cs.aggregate(\"group\")\n", - "\n", + "# Requires aggregate=\"group\" or \"all\" in fit()\n", "print(\"\\nGroup Aggregation (ATT by cohort):\")\n", - "for cohort, effects in group_agg.items():\n", - " print(f\"Cohort {cohort}: ATT = {effects['att']:.4f} (SE: {effects['se']:.4f})\")" + "for cohort, effects in results_cs.group_effects.items():\n", + " print(f\"Cohort {cohort}: ATT = {effects['effect']:.4f} (SE: {effects['se']:.4f})\")" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:23:54.301184Z", + "iopub.status.busy": "2026-01-18T18:23:54.301120Z", + "iopub.status.idle": "2026-01-18T18:23:54.302604Z", + "shell.execute_reply": "2026-01-18T18:23:54.302425Z" + } + }, "outputs": [], "source": [ "# Event-study aggregation: average effect by time relative to treatment\n", - "event_agg = results_cs.aggregate(\"event\")\n", - "\n", + "# Requires aggregate=\"event_study\" or \"all\" in fit()\n", "print(\"\\nEvent-Study Aggregation (ATT by event time):\")\n", "print(f\"{'Event Time':>12} {'ATT':>10} {'SE':>10} {'95% CI':>25}\")\n", "print(\"-\" * 60)\n", "\n", - "for event_time in sorted(event_agg.keys()):\n", - " effects = event_agg[event_time]\n", + "for event_time in sorted(results_cs.event_study_effects.keys()):\n", + " effects = results_cs.event_study_effects[event_time]\n", " ci = effects['conf_int']\n", - " print(f\"{event_time:>12} {effects['att']:>10.4f} {effects['se']:>10.4f} \"\n", + " print(f\"{event_time:>12} {effects['effect']:>10.4f} {effects['se']:>10.4f} \"\n", " f\"[{ci[0]:>8.4f}, {ci[1]:>8.4f}]\")" ] }, { "cell_type": "markdown", - "source": "## 6. Bootstrap Inference\n\nWith few clusters or when analytical standard errors may be unreliable, the **multiplier bootstrap** provides valid inference. This implements the approach from Callaway & Sant'Anna (2021), perturbing unit-level influence functions.\n\n**Why use bootstrap?**\n- Analytical SEs may understate uncertainty with few clusters\n- Bootstrap provides finite-sample valid confidence intervals\n- P-values are computed from the bootstrap distribution\n\n**Weight types:**\n- `'rademacher'` - Default, ±1 with p=0.5, good for most cases\n- `'mammen'` - Two-point distribution, matches first 3 moments\n- `'webb'` - Six-point distribution, recommended for very few clusters (<10)", - "metadata": {} + "metadata": {}, + "source": [ + "## 6. Bootstrap Inference\n", + "\n", + "With few clusters or when analytical standard errors may be unreliable, the **multiplier bootstrap** provides valid inference. This implements the approach from Callaway & Sant'Anna (2021), perturbing unit-level influence functions.\n", + "\n", + "**Why use bootstrap?**\n", + "- Analytical SEs may understate uncertainty with few clusters\n", + "- Bootstrap provides finite-sample valid confidence intervals\n", + "- P-values are computed from the bootstrap distribution\n", + "\n", + "**Weight types:**\n", + "- `'rademacher'` - Default, ±1 with p=0.5, good for most cases\n", + "- `'mammen'` - Two-point distribution, matches first 3 moments\n", + "- `'webb'` - Six-point distribution, recommended for very few clusters (<10)" + ] }, { "cell_type": "code", - "source": "# Callaway-Sant'Anna with bootstrap inference\ncs_boot = CallawaySantAnna(\n control_group=\"never_treated\",\n n_bootstrap=499, # Number of bootstrap iterations\n bootstrap_weight_type='rademacher', # or 'mammen', 'webb'\n seed=42 # For reproducibility\n)\n\nresults_boot = cs_boot.fit(\n df,\n outcome=\"outcome\",\n unit=\"unit\",\n time=\"period\",\n cohort=\"cohort\",\n aggregate=\"event_study\" # Compute event study aggregation\n)\n\n# Access bootstrap results\nprint(\"Bootstrap Inference Results:\")\nprint(\"=\" * 60)\nprint(f\"\\nOverall ATT: {results_boot.overall_att:.4f}\")\nprint(f\"Bootstrap SE: {results_boot.bootstrap_results.overall_att_se:.4f}\")\nprint(f\"Bootstrap 95% CI: [{results_boot.bootstrap_results.overall_att_ci[0]:.4f}, \"\n f\"{results_boot.bootstrap_results.overall_att_ci[1]:.4f}]\")\nprint(f\"Bootstrap p-value: {results_boot.bootstrap_results.overall_att_p_value:.4f}\")", - "metadata": {}, "execution_count": null, - "outputs": [] + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:23:54.303485Z", + "iopub.status.busy": "2026-01-18T18:23:54.303436Z", + "iopub.status.idle": "2026-01-18T18:23:54.309925Z", + "shell.execute_reply": "2026-01-18T18:23:54.309739Z" + } + }, + "outputs": [], + "source": [ + "# Callaway-Sant'Anna with bootstrap inference\n", + "cs_boot = CallawaySantAnna(\n", + " control_group=\"never_treated\",\n", + " n_bootstrap=499, # Number of bootstrap iterations\n", + " bootstrap_weight_type='rademacher', # or 'mammen', 'webb'\n", + " seed=42 # For reproducibility\n", + ")\n", + "\n", + "results_boot = cs_boot.fit(\n", + " df,\n", + " outcome=\"outcome\",\n", + " unit=\"unit\",\n", + " time=\"period\",\n", + " first_treat=\"cohort\", # Column with first treatment period\n", + " aggregate=\"event_study\" # Compute event study aggregation\n", + ")\n", + "\n", + "# Access bootstrap results\n", + "print(\"Bootstrap Inference Results:\")\n", + "print(\"=\" * 60)\n", + "print(f\"\\nOverall ATT: {results_boot.overall_att:.4f}\")\n", + "print(f\"Bootstrap SE: {results_boot.bootstrap_results.overall_att_se:.4f}\")\n", + "print(f\"Bootstrap 95% CI: [{results_boot.bootstrap_results.overall_att_ci[0]:.4f}, \"\n", + " f\"{results_boot.bootstrap_results.overall_att_ci[1]:.4f}]\")\n", + "print(f\"Bootstrap p-value: {results_boot.bootstrap_results.overall_att_p_value:.4f}\")" + ] }, { "cell_type": "code", - "source": "# Event study with bootstrap confidence intervals\nprint(\"\\nEvent Study with Bootstrap Inference:\")\nprint(f\"{'Event Time':>12} {'ATT':>10} {'Boot SE':>10} {'Boot 95% CI':>25} {'p-value':>10}\")\nprint(\"-\" * 70)\n\nevent_ses = results_boot.bootstrap_results.event_study_ses\nevent_cis = results_boot.bootstrap_results.event_study_cis\nevent_pvals = results_boot.bootstrap_results.event_study_p_values\n\nfor event_time in sorted(event_ses.keys()):\n att = results_boot.event_study_effects[event_time]['effect']\n se = event_ses[event_time]\n ci = event_cis[event_time]\n pval = event_pvals[event_time]\n sig = \"*\" if pval < 0.05 else \"\"\n print(f\"{event_time:>12} {att:>10.4f} {se:>10.4f} [{ci[0]:>8.4f}, {ci[1]:>8.4f}] {pval:>10.4f} {sig}\")", - "metadata": {}, "execution_count": null, - "outputs": [] + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:23:54.310776Z", + "iopub.status.busy": "2026-01-18T18:23:54.310721Z", + "iopub.status.idle": "2026-01-18T18:23:54.312619Z", + "shell.execute_reply": "2026-01-18T18:23:54.312433Z" + } + }, + "outputs": [], + "source": [ + "# Event study with bootstrap confidence intervals\n", + "print(\"\\nEvent Study with Bootstrap Inference:\")\n", + "print(f\"{'Event Time':>12} {'ATT':>10} {'Boot SE':>10} {'Boot 95% CI':>25} {'p-value':>10}\")\n", + "print(\"-\" * 70)\n", + "\n", + "event_ses = results_boot.bootstrap_results.event_study_ses\n", + "event_cis = results_boot.bootstrap_results.event_study_cis\n", + "event_pvals = results_boot.bootstrap_results.event_study_p_values\n", + "\n", + "for event_time in sorted(event_ses.keys()):\n", + " att = results_boot.event_study_effects[event_time]['effect']\n", + " se = event_ses[event_time]\n", + " ci = event_cis[event_time]\n", + " pval = event_pvals[event_time]\n", + " sig = \"*\" if pval < 0.05 else \"\"\n", + " print(f\"{event_time:>12} {att:>10.4f} {se:>10.4f} [{ci[0]:>8.4f}, {ci[1]:>8.4f}] {pval:>10.4f} {sig}\")" + ] }, { "cell_type": "markdown", "metadata": {}, - "source": "## 7. Visualization\n\nEvent-study plots are the standard way to visualize DiD results with multiple periods." + "source": [ + "## 7. Visualization\n", + "\n", + "Event-study plots are the standard way to visualize DiD results with multiple periods." + ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:23:54.313659Z", + "iopub.status.busy": "2026-01-18T18:23:54.313599Z", + "iopub.status.idle": "2026-01-18T18:23:54.343103Z", + "shell.execute_reply": "2026-01-18T18:23:54.342886Z" + } + }, "outputs": [], "source": [ "if HAS_MATPLOTLIB:\n", @@ -325,7 +565,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:23:54.344073Z", + "iopub.status.busy": "2026-01-18T18:23:54.344016Z", + "iopub.status.idle": "2026-01-18T18:23:54.382161Z", + "shell.execute_reply": "2026-01-18T18:23:54.381956Z" + } + }, "outputs": [], "source": [ "if HAS_MATPLOTLIB:\n", @@ -343,12 +590,25 @@ { "cell_type": "markdown", "metadata": {}, - "source": "## 8. Different Control Group Options\n\nThe CS estimator supports different control group specifications:\n- `\"never_treated\"`: Only use units that are never treated\n- `\"not_yet_treated\"`: Use units that haven't been treated yet at time t" + "source": [ + "## 8. Different Control Group Options\n", + "\n", + "The CS estimator supports different control group specifications:\n", + "- `\"never_treated\"`: Only use units that are never treated\n", + "- `\"not_yet_treated\"`: Use units that haven't been treated yet at time t" + ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:23:54.383226Z", + "iopub.status.busy": "2026-01-18T18:23:54.383149Z", + "iopub.status.idle": "2026-01-18T18:23:54.388355Z", + "shell.execute_reply": "2026-01-18T18:23:54.388156Z" + } + }, "outputs": [], "source": [ "# Using not-yet-treated as control\n", @@ -361,29 +621,37 @@ " outcome=\"outcome\",\n", " unit=\"unit\",\n", " time=\"period\",\n", - " cohort=\"cohort\"\n", + " first_treat=\"cohort\"\n", ")\n", "\n", - "# Compare\n", - "simple_never = results_cs.aggregate(\"simple\")\n", - "simple_nyt = results_nyt.aggregate(\"simple\")\n", - "\n", + "# Compare using overall_att/overall_se attributes\n", "print(\"Comparison of control group specifications:\")\n", "print(f\"{'Control Group':<20} {'ATT':>10} {'SE':>10}\")\n", "print(\"-\" * 40)\n", - "print(f\"{'Never-treated':<20} {simple_never['att']:>10.4f} {simple_never['se']:>10.4f}\")\n", - "print(f\"{'Not-yet-treated':<20} {simple_nyt['att']:>10.4f} {simple_nyt['se']:>10.4f}\")" + "print(f\"{'Never-treated':<20} {results_cs.overall_att:>10.4f} {results_cs.overall_se:>10.4f}\")\n", + "print(f\"{'Not-yet-treated':<20} {results_nyt.overall_att:>10.4f} {results_nyt.overall_se:>10.4f}\")" ] }, { "cell_type": "markdown", "metadata": {}, - "source": "## 9. Handling Anticipation Effects\n\nIf units start changing behavior before official treatment (anticipation), you can specify the anticipation period." + "source": [ + "## 9. Handling Anticipation Effects\n", + "\n", + "If units start changing behavior before official treatment (anticipation), you can specify the anticipation period." + ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:23:54.389338Z", + "iopub.status.busy": "2026-01-18T18:23:54.389272Z", + "iopub.status.idle": "2026-01-18T18:23:54.393588Z", + "shell.execute_reply": "2026-01-18T18:23:54.393388Z" + } + }, "outputs": [], "source": [ "# Allow for 1 period of anticipation\n", @@ -397,22 +665,32 @@ " outcome=\"outcome\",\n", " unit=\"unit\",\n", " time=\"period\",\n", - " cohort=\"cohort\"\n", + " first_treat=\"cohort\"\n", ")\n", "\n", - "simple_antic = results_antic.aggregate(\"simple\")\n", - "print(f\"With anticipation=1: ATT = {simple_antic['att']:.4f}\")" + "print(f\"With anticipation=1: ATT = {results_antic.overall_att:.4f}\")" ] }, { "cell_type": "markdown", "metadata": {}, - "source": "## 10. Adding Covariates\n\nYou can include covariates to improve precision through outcome regression or propensity score methods." + "source": [ + "## 10. Adding Covariates\n", + "\n", + "You can include covariates to improve precision through outcome regression or propensity score methods." + ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:23:54.394514Z", + "iopub.status.busy": "2026-01-18T18:23:54.394463Z", + "iopub.status.idle": "2026-01-18T18:23:54.404705Z", + "shell.execute_reply": "2026-01-18T18:23:54.404487Z" + } + }, "outputs": [], "source": [ "# Add covariates to data\n", @@ -429,23 +707,33 @@ " outcome=\"outcome\",\n", " unit=\"unit\",\n", " time=\"period\",\n", - " cohort=\"cohort\",\n", + " first_treat=\"cohort\",\n", " covariates=[\"size\", \"age\"]\n", ")\n", "\n", - "simple_cov = results_cov.aggregate(\"simple\")\n", - "print(f\"With covariates: ATT = {simple_cov['att']:.4f} (SE: {simple_cov['se']:.4f})\")" + "print(f\"With covariates: ATT = {results_cov.overall_att:.4f} (SE: {results_cov.overall_se:.4f})\")" ] }, { "cell_type": "markdown", "metadata": {}, - "source": "## 11. Comparing with MultiPeriodDiD\n\nFor comparison, here's how you would use `MultiPeriodDiD` which estimates period-specific effects but doesn't handle staggered adoption as carefully." + "source": [ + "## 11. Comparing with MultiPeriodDiD\n", + "\n", + "For comparison, here's how you would use `MultiPeriodDiD` which estimates period-specific effects but doesn't handle staggered adoption as carefully." + ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:23:54.405752Z", + "iopub.status.busy": "2026-01-18T18:23:54.405696Z", + "iopub.status.idle": "2026-01-18T18:23:54.408664Z", + "shell.execute_reply": "2026-01-18T18:23:54.408454Z" + } + }, "outputs": [], "source": [ "# For this comparison, let's use data from cohort 3 only (single treatment timing)\n", @@ -466,7 +754,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:23:54.409636Z", + "iopub.status.busy": "2026-01-18T18:23:54.409582Z", + "iopub.status.idle": "2026-01-18T18:23:54.411076Z", + "shell.execute_reply": "2026-01-18T18:23:54.410866Z" + } + }, "outputs": [], "source": [ "# Period-specific effects from MultiPeriodDiD\n", @@ -477,39 +772,169 @@ }, { "cell_type": "markdown", - "source": "## 12. Sun-Abraham Interaction-Weighted Estimator\n\nThe Sun-Abraham (2021) estimator provides an alternative approach to staggered DiD. While Callaway-Sant'Anna aggregates 2x2 DiD comparisons, Sun-Abraham uses an **interaction-weighted regression** approach:\n\n1. Run a saturated regression with cohort × relative-time indicators\n2. Weight cohort-specific effects by each cohort's share of treated observations at each relative time\n\n**Key differences from CS:**\n- Regression-based vs. 2x2 DiD aggregation\n- Different weighting scheme\n- More efficient under homogeneous effects\n- Consistent under heterogeneous effects (like CS)\n\n**When to use both:** Running both CS and SA provides a useful robustness check. When they agree, results are more credible.", - "metadata": {} + "metadata": {}, + "source": [ + "## 12. Sun-Abraham Interaction-Weighted Estimator\n", + "\n", + "The Sun-Abraham (2021) estimator provides an alternative approach to staggered DiD. While Callaway-Sant'Anna aggregates 2x2 DiD comparisons, Sun-Abraham uses an **interaction-weighted regression** approach:\n", + "\n", + "1. Run a saturated regression with cohort × relative-time indicators\n", + "2. Weight cohort-specific effects by each cohort's share of treated observations at each relative time\n", + "\n", + "**Key differences from CS:**\n", + "- Regression-based vs. 2x2 DiD aggregation\n", + "- Different weighting scheme\n", + "- More efficient under homogeneous effects\n", + "- Consistent under heterogeneous effects (like CS)\n", + "\n", + "**When to use both:** Running both CS and SA provides a useful robustness check. When they agree, results are more credible." + ] }, { "cell_type": "code", - "source": "# Sun-Abraham estimation\nsa = SunAbraham(\n control_group=\"never_treated\", # Use never-treated as controls\n anticipation=0 # No anticipation effects\n)\n\nresults_sa = sa.fit(\n df,\n outcome=\"outcome\",\n unit=\"unit\",\n time=\"period\",\n first_treat=\"cohort\" # Column with first treatment period (0 = never treated)\n)\n\n# View summary\nresults_sa.print_summary()", - "metadata": {}, "execution_count": null, - "outputs": [] + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:23:54.412012Z", + "iopub.status.busy": "2026-01-18T18:23:54.411958Z", + "iopub.status.idle": "2026-01-18T18:23:54.422080Z", + "shell.execute_reply": "2026-01-18T18:23:54.421878Z" + } + }, + "outputs": [], + "source": [ + "# Sun-Abraham estimation\n", + "sa = SunAbraham(\n", + " control_group=\"never_treated\", # Use never-treated as controls\n", + " anticipation=0 # No anticipation effects\n", + ")\n", + "\n", + "results_sa = sa.fit(\n", + " df,\n", + " outcome=\"outcome\",\n", + " unit=\"unit\",\n", + " time=\"period\",\n", + " first_treat=\"cohort\" # Column with first treatment period (0 = never treated)\n", + ")\n", + "\n", + "# View summary\n", + "results_sa.print_summary()" + ] }, { "cell_type": "code", - "source": "# Event study effects by relative time\nprint(\"Sun-Abraham Event Study Effects:\")\nprint(f\"{'Rel. Time':>12} {'Effect':>10} {'SE':>10} {'p-value':>10}\")\nprint(\"-\" * 45)\n\nfor rel_time in sorted(results_sa.event_study_effects.keys()):\n eff = results_sa.event_study_effects[rel_time]\n sig = \"*\" if eff['p_value'] < 0.05 else \"\"\n print(f\"{rel_time:>12} {eff['effect']:>10.4f} {eff['se']:>10.4f} {eff['p_value']:>10.4f} {sig}\")\n\n# Cohort weights show how each cohort contributes to event-study estimates\nprint(\"\\n\\nCohort Weights by Relative Time:\")\nfor rel_time in sorted(results_sa.cohort_weights.keys()):\n weights = results_sa.cohort_weights[rel_time]\n print(f\"e={rel_time}: {weights}\")", - "metadata": {}, "execution_count": null, - "outputs": [] + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:23:54.423081Z", + "iopub.status.busy": "2026-01-18T18:23:54.423018Z", + "iopub.status.idle": "2026-01-18T18:23:54.425058Z", + "shell.execute_reply": "2026-01-18T18:23:54.424845Z" + } + }, + "outputs": [], + "source": [ + "# Event study effects by relative time\n", + "print(\"Sun-Abraham Event Study Effects:\")\n", + "print(f\"{'Rel. Time':>12} {'Effect':>10} {'SE':>10} {'p-value':>10}\")\n", + "print(\"-\" * 45)\n", + "\n", + "for rel_time in sorted(results_sa.event_study_effects.keys()):\n", + " eff = results_sa.event_study_effects[rel_time]\n", + " sig = \"*\" if eff['p_value'] < 0.05 else \"\"\n", + " print(f\"{rel_time:>12} {eff['effect']:>10.4f} {eff['se']:>10.4f} {eff['p_value']:>10.4f} {sig}\")\n", + "\n", + "# Cohort weights show how each cohort contributes to event-study estimates\n", + "print(\"\\n\\nCohort Weights by Relative Time:\")\n", + "for rel_time in sorted(results_sa.cohort_weights.keys()):\n", + " weights = results_sa.cohort_weights[rel_time]\n", + " print(f\"e={rel_time}: {weights}\")" + ] }, { "cell_type": "markdown", - "source": "## 13. Comparing CS and SA as a Robustness Check\n\nRunning both estimators provides a useful robustness check. When they agree, results are more credible.", - "metadata": {} + "metadata": {}, + "source": [ + "## 13. Comparing CS and SA as a Robustness Check\n", + "\n", + "Running both estimators provides a useful robustness check. When they agree, results are more credible." + ] }, { "cell_type": "code", - "source": "# Compare overall ATT from both estimators\ncs_att = results_cs.aggregate(\"simple\")\n\nprint(\"Robustness Check: CS vs SA\")\nprint(\"=\" * 50)\nprint(f\"{'Estimator':<25} {'Overall ATT':>12} {'SE':>10}\")\nprint(\"-\" * 50)\nprint(f\"{'Callaway-Sant\\'Anna':<25} {cs_att['att']:>12.4f} {cs_att['se']:>10.4f}\")\nprint(f\"{'Sun-Abraham':<25} {results_sa.overall_att:>12.4f} {results_sa.overall_se:>10.4f}\")\n\n# Compare event study effects\nprint(\"\\n\\nEvent Study Comparison:\")\nprint(f\"{'Rel. Time':>12} {'CS ATT':>10} {'SA ATT':>10} {'Difference':>12}\")\nprint(\"-\" * 50)\n\ncs_event = results_cs.aggregate(\"event\")\nfor rel_time in sorted(results_sa.event_study_effects.keys()):\n sa_eff = results_sa.event_study_effects[rel_time]['effect']\n if rel_time in cs_event:\n cs_eff = cs_event[rel_time]['att']\n diff = sa_eff - cs_eff\n print(f\"{rel_time:>12} {cs_eff:>10.4f} {sa_eff:>10.4f} {diff:>12.4f}\")\n\nprint(\"\\n→ Similar results indicate robust findings across estimation methods\")", - "metadata": {}, "execution_count": null, - "outputs": [] + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:23:54.426005Z", + "iopub.status.busy": "2026-01-18T18:23:54.425952Z", + "iopub.status.idle": "2026-01-18T18:23:54.427888Z", + "shell.execute_reply": "2026-01-18T18:23:54.427679Z" + } + }, + "outputs": [], + "source": [ + "# Compare overall ATT from both estimators\n", + "print(\"Robustness Check: CS vs SA\")\n", + "print(\"=\" * 50)\n", + "print(f\"{'Estimator':<25} {'Overall ATT':>12} {'SE':>10}\")\n", + "print(\"-\" * 50)\n", + "print(f\"{'Callaway-SantAnna':<25} {results_cs.overall_att:>12.4f} {results_cs.overall_se:>10.4f}\")\n", + "print(f\"{'Sun-Abraham':<25} {results_sa.overall_att:>12.4f} {results_sa.overall_se:>10.4f}\")\n", + "\n", + "# Compare event study effects\n", + "print(\"\\n\\nEvent Study Comparison:\")\n", + "print(f\"{'Rel. Time':>12} {'CS ATT':>10} {'SA ATT':>10} {'Difference':>12}\")\n", + "print(\"-\" * 50)\n", + "\n", + "# Use the pre-computed event_study_effects from results_cs\n", + "for rel_time in sorted(results_sa.event_study_effects.keys()):\n", + " sa_eff = results_sa.event_study_effects[rel_time]['effect']\n", + " if results_cs.event_study_effects and rel_time in results_cs.event_study_effects:\n", + " cs_eff = results_cs.event_study_effects[rel_time]['effect']\n", + " diff = sa_eff - cs_eff\n", + " print(f\"{rel_time:>12} {cs_eff:>10.4f} {sa_eff:>10.4f} {diff:>12.4f}\")\n", + "\n", + "print(\"\\nSimilar results indicate robust findings across estimation methods\")" + ] }, { "cell_type": "markdown", "metadata": {}, - "source": "## Summary\n\nKey takeaways:\n\n1. **TWFE can be biased** with staggered adoption and heterogeneous effects\n2. **Goodman-Bacon decomposition** reveals *why* TWFE fails by showing:\n - The implicit 2x2 comparisons and their weights\n - How much weight falls on \"forbidden comparisons\" (already-treated as controls)\n3. **Callaway-Sant'Anna** properly handles staggered adoption by:\n - Computing group-time specific effects ATT(g,t)\n - Only using valid comparison groups\n - Properly aggregating effects\n4. **Sun-Abraham** provides an alternative approach using:\n - Interaction-weighted regression with cohort × relative-time indicators\n - Different weighting scheme than CS\n - More efficient under homogeneous effects\n5. **Run both CS and SA** as a robustness check—when they agree, results are more credible\n6. **Aggregation options**:\n - `\"simple\"`: Overall ATT\n - `\"group\"`: ATT by cohort\n - `\"event\"`: ATT by event time (for event-study plots)\n7. **Bootstrap inference** provides valid standard errors and confidence intervals:\n - Use `n_bootstrap` parameter to enable multiplier bootstrap\n - Choose weight type: `'rademacher'`, `'mammen'`, or `'webb'`\n - Bootstrap results include SEs, CIs, and p-values for all aggregations\n8. **Control group choices** affect efficiency and assumptions:\n - `\"never_treated\"`: Stronger parallel trends assumption\n - `\"not_yet_treated\"`: Weaker assumption, uses more data\n\nFor more details, see:\n- Callaway, B., & Sant'Anna, P. H. (2021). Difference-in-differences with multiple time periods. *Journal of Econometrics*.\n- Sun, L., & Abraham, S. (2021). Estimating dynamic treatment effects in event studies with heterogeneous treatment effects. *Journal of Econometrics*.\n- Goodman-Bacon, A. (2021). Difference-in-differences with variation in treatment timing. *Journal of Econometrics*." + "source": [ + "## Summary\n", + "\n", + "Key takeaways:\n", + "\n", + "1. **TWFE can be biased** with staggered adoption and heterogeneous effects\n", + "2. **Goodman-Bacon decomposition** reveals *why* TWFE fails by showing:\n", + " - The implicit 2x2 comparisons and their weights\n", + " - How much weight falls on \"forbidden comparisons\" (already-treated as controls)\n", + "3. **Callaway-Sant'Anna** properly handles staggered adoption by:\n", + " - Computing group-time specific effects ATT(g,t)\n", + " - Only using valid comparison groups\n", + " - Properly aggregating effects\n", + "4. **Sun-Abraham** provides an alternative approach using:\n", + " - Interaction-weighted regression with cohort × relative-time indicators\n", + " - Different weighting scheme than CS\n", + " - More efficient under homogeneous effects\n", + "5. **Run both CS and SA** as a robustness check—when they agree, results are more credible\n", + "6. **Aggregation options**:\n", + " - `\"simple\"`: Overall ATT\n", + " - `\"group\"`: ATT by cohort\n", + " - `\"event\"`: ATT by event time (for event-study plots)\n", + "7. **Bootstrap inference** provides valid standard errors and confidence intervals:\n", + " - Use `n_bootstrap` parameter to enable multiplier bootstrap\n", + " - Choose weight type: `'rademacher'`, `'mammen'`, or `'webb'`\n", + " - Bootstrap results include SEs, CIs, and p-values for all aggregations\n", + "8. **Control group choices** affect efficiency and assumptions:\n", + " - `\"never_treated\"`: Stronger parallel trends assumption\n", + " - `\"not_yet_treated\"`: Weaker assumption, uses more data\n", + "\n", + "For more details, see:\n", + "- Callaway, B., & Sant'Anna, P. H. (2021). Difference-in-differences with multiple time periods. *Journal of Econometrics*.\n", + "- Sun, L., & Abraham, S. (2021). Estimating dynamic treatment effects in event studies with heterogeneous treatment effects. *Journal of Econometrics*.\n", + "- Goodman-Bacon, A. (2021). Difference-in-differences with variation in treatment timing. *Journal of Econometrics*." + ] } ], "metadata": { @@ -528,9 +953,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11" + "version": "3.12.4" } }, "nbformat": 4, "nbformat_minor": 4 -} \ No newline at end of file +} diff --git a/docs/tutorials/03_synthetic_did.ipynb b/docs/tutorials/03_synthetic_did.ipynb index 8e33a1eb..8815ea54 100644 --- a/docs/tutorials/03_synthetic_did.ipynb +++ b/docs/tutorials/03_synthetic_did.ipynb @@ -27,7 +27,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:24:00.673169Z", + "iopub.status.busy": "2026-01-18T18:24:00.673053Z", + "iopub.status.idle": "2026-01-18T18:24:01.249645Z", + "shell.execute_reply": "2026-01-18T18:24:01.249376Z" + } + }, "outputs": [], "source": [ "import numpy as np\n", @@ -60,7 +67,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:24:01.250967Z", + "iopub.status.busy": "2026-01-18T18:24:01.250874Z", + "iopub.status.idle": "2026-01-18T18:24:01.254466Z", + "shell.execute_reply": "2026-01-18T18:24:01.254263Z" + } + }, "outputs": [], "source": [ "# Generate data with few treated units\n", @@ -122,7 +136,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:24:01.266189Z", + "iopub.status.busy": "2026-01-18T18:24:01.266102Z", + "iopub.status.idle": "2026-01-18T18:24:01.326489Z", + "shell.execute_reply": "2026-01-18T18:24:01.326257Z" + } + }, "outputs": [], "source": [ "if HAS_MATPLOTLIB:\n", @@ -163,7 +184,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:24:01.327617Z", + "iopub.status.busy": "2026-01-18T18:24:01.327539Z", + "iopub.status.idle": "2026-01-18T18:24:05.053292Z", + "shell.execute_reply": "2026-01-18T18:24:05.053046Z" + } + }, "outputs": [], "source": [ "# Fit Synthetic DiD\n", @@ -196,7 +224,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:24:05.054407Z", + "iopub.status.busy": "2026-01-18T18:24:05.054345Z", + "iopub.status.idle": "2026-01-18T18:24:05.057397Z", + "shell.execute_reply": "2026-01-18T18:24:05.057200Z" + } + }, "outputs": [], "source": [ "# Create post indicator for standard DiD\n", @@ -238,7 +273,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:24:05.058450Z", + "iopub.status.busy": "2026-01-18T18:24:05.058395Z", + "iopub.status.idle": "2026-01-18T18:24:05.061129Z", + "shell.execute_reply": "2026-01-18T18:24:05.060910Z" + } + }, "outputs": [], "source": [ "# View unit weights\n", @@ -250,7 +292,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:24:05.062204Z", + "iopub.status.busy": "2026-01-18T18:24:05.062123Z", + "iopub.status.idle": "2026-01-18T18:24:05.063869Z", + "shell.execute_reply": "2026-01-18T18:24:05.063662Z" + } + }, "outputs": [], "source": [ "# Check weight properties\n", @@ -264,7 +313,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:24:05.064771Z", + "iopub.status.busy": "2026-01-18T18:24:05.064721Z", + "iopub.status.idle": "2026-01-18T18:24:05.122105Z", + "shell.execute_reply": "2026-01-18T18:24:05.121719Z" + } + }, "outputs": [], "source": [ "if HAS_MATPLOTLIB:\n", @@ -294,7 +350,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:24:05.123239Z", + "iopub.status.busy": "2026-01-18T18:24:05.123166Z", + "iopub.status.idle": "2026-01-18T18:24:05.125341Z", + "shell.execute_reply": "2026-01-18T18:24:05.125153Z" + } + }, "outputs": [], "source": [ "# View time weights\n", @@ -306,7 +369,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:24:05.126213Z", + "iopub.status.busy": "2026-01-18T18:24:05.126147Z", + "iopub.status.idle": "2026-01-18T18:24:05.177856Z", + "shell.execute_reply": "2026-01-18T18:24:05.177598Z" + } + }, "outputs": [], "source": [ "if HAS_MATPLOTLIB:\n", @@ -334,7 +404,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:24:05.178970Z", + "iopub.status.busy": "2026-01-18T18:24:05.178905Z", + "iopub.status.idle": "2026-01-18T18:24:05.180389Z", + "shell.execute_reply": "2026-01-18T18:24:05.180145Z" + } + }, "outputs": [], "source": [ "print(f\"Pre-treatment fit (RMSE): {results.pre_treatment_fit:.4f}\")\n", @@ -344,7 +421,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:24:05.181314Z", + "iopub.status.busy": "2026-01-18T18:24:05.181251Z", + "iopub.status.idle": "2026-01-18T18:24:05.222992Z", + "shell.execute_reply": "2026-01-18T18:24:05.222769Z" + } + }, "outputs": [], "source": [ "if HAS_MATPLOTLIB:\n", @@ -394,19 +478,27 @@ "\n", "SDID supports two inference methods:\n", "\n", - "1. **Bootstrap** (`n_bootstrap > 0`): Resample and re-estimate\n", - "2. **Placebo** (`n_bootstrap = 0`): Use placebo effects from control units" + "1. **Bootstrap** (`variance_method=\"bootstrap\"`): Block bootstrap at unit level (default)\n", + "2. **Placebo** (`variance_method=\"placebo\"`): Placebo-based variance using Algorithm 4 from Arkhangelsky et al. (2021)" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:24:05.224029Z", + "iopub.status.busy": "2026-01-18T18:24:05.223969Z", + "iopub.status.idle": "2026-01-18T18:24:05.230577Z", + "shell.execute_reply": "2026-01-18T18:24:05.230384Z" + } + }, "outputs": [], "source": [ "# Placebo-based inference\n", "sdid_placebo = SyntheticDiD(\n", - " n_bootstrap=0, # Use placebo inference\n", + " variance_method=\"placebo\", # Use placebo inference\n", + " n_bootstrap=200, # Number of placebo replications\n", " seed=42\n", ")\n", "\n", @@ -428,7 +520,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:24:05.231516Z", + "iopub.status.busy": "2026-01-18T18:24:05.231464Z", + "iopub.status.idle": "2026-01-18T18:24:05.276787Z", + "shell.execute_reply": "2026-01-18T18:24:05.276580Z" + } + }, "outputs": [], "source": [ "if HAS_MATPLOTLIB:\n", @@ -464,7 +563,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:24:05.277862Z", + "iopub.status.busy": "2026-01-18T18:24:05.277806Z", + "iopub.status.idle": "2026-01-18T18:24:05.292848Z", + "shell.execute_reply": "2026-01-18T18:24:05.292651Z" + } + }, "outputs": [], "source": [ "# Compare different regularization levels\n", @@ -473,7 +579,8 @@ "for lambda_reg in [0.0, 1.0, 10.0]:\n", " sdid_reg = SyntheticDiD(\n", " lambda_reg=lambda_reg,\n", - " n_bootstrap=0,\n", + " variance_method=\"placebo\",\n", + " n_bootstrap=200,\n", " seed=42\n", " )\n", " \n", @@ -514,7 +621,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:24:05.293911Z", + "iopub.status.busy": "2026-01-18T18:24:05.293847Z", + "iopub.status.idle": "2026-01-18T18:24:05.295717Z", + "shell.execute_reply": "2026-01-18T18:24:05.295518Z" + } + }, "outputs": [], "source": [ "# Filter to single treated unit\n", @@ -528,12 +642,20 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:24:05.296723Z", + "iopub.status.busy": "2026-01-18T18:24:05.296651Z", + "iopub.status.idle": "2026-01-18T18:24:05.303028Z", + "shell.execute_reply": "2026-01-18T18:24:05.302821Z" + } + }, "outputs": [], "source": [ "# Fit SDID with single treated unit\n", "sdid_single = SyntheticDiD(\n", - " n_bootstrap=0,\n", + " variance_method=\"placebo\",\n", + " n_bootstrap=200,\n", " seed=42\n", ")\n", "\n", @@ -561,7 +683,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:24:05.304012Z", + "iopub.status.busy": "2026-01-18T18:24:05.303960Z", + "iopub.status.idle": "2026-01-18T18:24:06.149954Z", + "shell.execute_reply": "2026-01-18T18:24:06.149710Z" + } + }, "outputs": [], "source": [ "# Add covariates\n", @@ -602,8 +731,8 @@ "3. **Time weights**: Determine which pre-periods are most informative\n", "4. **Pre-treatment fit**: Lower RMSE indicates better synthetic match\n", "5. **Inference options**:\n", - " - Bootstrap (`n_bootstrap > 0`): Standard bootstrap SE\n", - " - Placebo (`n_bootstrap = 0`): Uses placebo effects from controls\n", + " - Bootstrap (`variance_method=\"bootstrap\"`): Block bootstrap at unit level (default)\n", + " - Placebo (`variance_method=\"placebo\"`): Placebo-based variance from controls\n", "6. **Regularization**: Higher values give more uniform weights\n", "\n", "Reference:\n", @@ -627,7 +756,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11" + "version": "3.12.4" } }, "nbformat": 4, diff --git a/docs/tutorials/04_parallel_trends.ipynb b/docs/tutorials/04_parallel_trends.ipynb index 86be01bf..17b6aa4c 100644 --- a/docs/tutorials/04_parallel_trends.ipynb +++ b/docs/tutorials/04_parallel_trends.ipynb @@ -20,7 +20,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:25:53.984356Z", + "iopub.status.busy": "2026-01-18T18:25:53.984243Z", + "iopub.status.idle": "2026-01-18T18:25:54.582162Z", + "shell.execute_reply": "2026-01-18T18:25:54.581881Z" + } + }, "outputs": [], "source": [ "import numpy as np\n", @@ -63,7 +70,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:25:54.583396Z", + "iopub.status.busy": "2026-01-18T18:25:54.583309Z", + "iopub.status.idle": "2026-01-18T18:25:54.587946Z", + "shell.execute_reply": "2026-01-18T18:25:54.587740Z" + } + }, "outputs": [], "source": [ "def generate_panel_data(n_units=100, n_periods=8, parallel=True, seed=42):\n", @@ -139,7 +153,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:25:54.600264Z", + "iopub.status.busy": "2026-01-18T18:25:54.600181Z", + "iopub.status.idle": "2026-01-18T18:25:54.682661Z", + "shell.execute_reply": "2026-01-18T18:25:54.682443Z" + } + }, "outputs": [], "source": [ "def plot_trends(df, title, ax):\n", @@ -179,7 +200,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:25:54.683784Z", + "iopub.status.busy": "2026-01-18T18:25:54.683701Z", + "iopub.status.idle": "2026-01-18T18:25:54.686351Z", + "shell.execute_reply": "2026-01-18T18:25:54.686145Z" + } + }, "outputs": [], "source": [ "# Test for parallel trends (parallel case)\n", @@ -207,7 +235,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:25:54.687344Z", + "iopub.status.busy": "2026-01-18T18:25:54.687259Z", + "iopub.status.idle": "2026-01-18T18:25:54.689698Z", + "shell.execute_reply": "2026-01-18T18:25:54.689467Z" + } + }, "outputs": [], "source": [ "# Test for parallel trends (non-parallel case)\n", @@ -240,7 +275,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:25:54.690703Z", + "iopub.status.busy": "2026-01-18T18:25:54.690636Z", + "iopub.status.idle": "2026-01-18T18:25:54.721652Z", + "shell.execute_reply": "2026-01-18T18:25:54.721428Z" + } + }, "outputs": [], "source": [ "# Robust test (parallel case)\n", @@ -270,7 +312,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:25:54.722663Z", + "iopub.status.busy": "2026-01-18T18:25:54.722600Z", + "iopub.status.idle": "2026-01-18T18:25:54.752164Z", + "shell.execute_reply": "2026-01-18T18:25:54.751970Z" + } + }, "outputs": [], "source": [ "# Robust test (non-parallel case)\n", @@ -295,7 +344,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:25:54.753190Z", + "iopub.status.busy": "2026-01-18T18:25:54.753130Z", + "iopub.status.idle": "2026-01-18T18:25:54.833353Z", + "shell.execute_reply": "2026-01-18T18:25:54.833125Z" + } + }, "outputs": [], "source": [ "if HAS_MATPLOTLIB:\n", @@ -335,7 +391,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:25:54.834392Z", + "iopub.status.busy": "2026-01-18T18:25:54.834327Z", + "iopub.status.idle": "2026-01-18T18:25:54.837416Z", + "shell.execute_reply": "2026-01-18T18:25:54.837214Z" + } + }, "outputs": [], "source": [ "# Equivalence test (parallel case)\n", @@ -361,7 +424,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:25:54.838452Z", + "iopub.status.busy": "2026-01-18T18:25:54.838394Z", + "iopub.status.idle": "2026-01-18T18:25:54.841315Z", + "shell.execute_reply": "2026-01-18T18:25:54.841135Z" + } + }, "outputs": [], "source": [ "# Equivalence test (non-parallel case)\n", @@ -398,7 +468,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:25:54.842244Z", + "iopub.status.busy": "2026-01-18T18:25:54.842189Z", + "iopub.status.idle": "2026-01-18T18:25:54.844654Z", + "shell.execute_reply": "2026-01-18T18:25:54.844472Z" + } + }, "outputs": [], "source": [ "# First, fit the main model\n", @@ -418,7 +495,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:25:54.845525Z", + "iopub.status.busy": "2026-01-18T18:25:54.845467Z", + "iopub.status.idle": "2026-01-18T18:25:54.848634Z", + "shell.execute_reply": "2026-01-18T18:25:54.848440Z" + } + }, "outputs": [], "source": [ "# Placebo timing test\n", @@ -428,47 +512,68 @@ " outcome='outcome',\n", " treatment='treated',\n", " time='period',\n", - " placebo_time=2, # Pretend treatment at period 2\n", - " actual_treatment_time=4\n", + " fake_treatment_period=2, # Pretend treatment at period 2\n", + " post_periods=[4, 5, 6, 7] # Actual post-treatment periods to exclude\n", ")\n", "\n", "print(\"\\nPlacebo Timing Test:\")\n", "print(\"=\" * 50)\n", - "print(f\"Placebo ATT: {placebo_timing.effect:.4f}\")\n", + "print(f\"Placebo ATT: {placebo_timing.placebo_effect:.4f}\")\n", "print(f\"SE: {placebo_timing.se:.4f}\")\n", "print(f\"p-value: {placebo_timing.p_value:.4f}\")\n", - "print(f\"\\nPass (effect not significant): {placebo_timing.passed}\")" + "print(f\"\\nPass (effect not significant): {not placebo_timing.is_significant}\")" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:25:54.849616Z", + "iopub.status.busy": "2026-01-18T18:25:54.849557Z", + "iopub.status.idle": "2026-01-18T18:25:54.852741Z", + "shell.execute_reply": "2026-01-18T18:25:54.852566Z" + } + }, "outputs": [], "source": [ "# Placebo group test\n", - "# Estimate DiD using only never-treated units (random placebo assignment)\n", + "# Estimate DiD using only never-treated units (some randomly designated as \"fake treated\")\n", + "# First, identify control units (never-treated)\n", + "control_units = df_parallel[df_parallel['treated'] == 0]['unit'].unique()\n", + "\n", + "# Randomly select half of control units as \"fake treated\"\n", + "np.random.seed(42)\n", + "fake_treated = np.random.choice(control_units, size=len(control_units)//2, replace=False).tolist()\n", + "\n", "placebo_group = placebo_group_test(\n", " df_parallel,\n", " outcome='outcome',\n", - " treatment='treated',\n", - " time='post',\n", + " time='period',\n", " unit='unit',\n", - " seed=42\n", + " fake_treated_units=fake_treated,\n", + " post_periods=[4, 5, 6, 7] # Periods to use as post-treatment\n", ")\n", "\n", "print(\"\\nPlacebo Group Test:\")\n", "print(\"=\" * 50)\n", - "print(f\"Placebo ATT: {placebo_group.effect:.4f}\")\n", + "print(f\"Placebo ATT: {placebo_group.placebo_effect:.4f}\")\n", "print(f\"SE: {placebo_group.se:.4f}\")\n", "print(f\"p-value: {placebo_group.p_value:.4f}\")\n", - "print(f\"\\nPass: {placebo_group.passed}\")" + "print(f\"\\nPass (effect not significant): {not placebo_group.is_significant}\")" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:25:54.853660Z", + "iopub.status.busy": "2026-01-18T18:25:54.853607Z", + "iopub.status.idle": "2026-01-18T18:25:55.246903Z", + "shell.execute_reply": "2026-01-18T18:25:55.246635Z" + } + }, "outputs": [], "source": [ "# Permutation test\n", @@ -477,31 +582,39 @@ " outcome='outcome',\n", " treatment='treated',\n", " time='post',\n", + " unit='unit',\n", " n_permutations=999,\n", " seed=42\n", ")\n", "\n", "print(\"\\nPermutation Test:\")\n", "print(\"=\" * 50)\n", - "print(f\"Observed ATT: {perm_results.effect:.4f}\")\n", + "print(f\"Observed ATT: {perm_results.placebo_effect:.4f}\")\n", "print(f\"Permutation p-value: {perm_results.p_value:.4f}\")\n", - "print(f\"Number of permutations: {perm_results.n_permutations}\")" + "print(f\"Number of permutations: {len(perm_results.permutation_distribution)}\")" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:25:55.248297Z", + "iopub.status.busy": "2026-01-18T18:25:55.248212Z", + "iopub.status.idle": "2026-01-18T18:25:55.299914Z", + "shell.execute_reply": "2026-01-18T18:25:55.299690Z" + } + }, "outputs": [], "source": [ "if HAS_MATPLOTLIB:\n", " # Visualize permutation distribution\n", " fig, ax = plt.subplots(figsize=(10, 6))\n", " \n", - " ax.hist(perm_results.permuted_effects, bins=30, alpha=0.7, \n", + " ax.hist(perm_results.permutation_distribution, bins=30, alpha=0.7, \n", " edgecolor='black', label='Permuted effects')\n", - " ax.axvline(x=perm_results.effect, color='red', linewidth=2, \n", - " linestyle='--', label=f'Observed = {perm_results.effect:.2f}')\n", + " ax.axvline(x=perm_results.placebo_effect, color='red', linewidth=2, \n", + " linestyle='--', label=f'Observed = {perm_results.placebo_effect:.2f}')\n", " ax.axvline(x=0, color='gray', linewidth=1, linestyle=':')\n", " \n", " ax.set_xlabel('Effect')\n", @@ -524,7 +637,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:25:55.301036Z", + "iopub.status.busy": "2026-01-18T18:25:55.300963Z", + "iopub.status.idle": "2026-01-18T18:25:55.308133Z", + "shell.execute_reply": "2026-01-18T18:25:55.307926Z" + } + }, "outputs": [], "source": [ "# Run comprehensive diagnostics\n", @@ -534,7 +654,8 @@ " treatment='treated',\n", " time='period',\n", " unit='unit',\n", - " treatment_time=4,\n", + " pre_periods=[0, 1, 2, 3], # Pre-treatment periods\n", + " post_periods=[4, 5, 6, 7], # Post-treatment periods\n", " n_permutations=499,\n", " seed=42\n", ")\n", @@ -545,7 +666,11 @@ "print(\"-\" * 60)\n", "\n", "for test_name, result in all_tests.items():\n", - " print(f\"{test_name:<25} {result.effect:>10.4f} {result.p_value:>10.4f} {str(result.passed):>10}\")" + " if isinstance(result, dict) and 'error' in result:\n", + " print(f\"{test_name:<25} {'ERROR':>10} {'-':>10} {result['error'][:20]}\")\n", + " else:\n", + " passed = not result.is_significant # Pass if NOT significant\n", + " print(f\"{test_name:<25} {result.placebo_effect:>10.4f} {result.p_value:>10.4f} {str(passed):>10}\")" ] }, { @@ -560,7 +685,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:25:55.309166Z", + "iopub.status.busy": "2026-01-18T18:25:55.309093Z", + "iopub.status.idle": "2026-01-18T18:25:55.312092Z", + "shell.execute_reply": "2026-01-18T18:25:55.311875Z" + } + }, "outputs": [], "source": [ "# Event study\n", @@ -580,7 +712,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:25:55.313134Z", + "iopub.status.busy": "2026-01-18T18:25:55.313065Z", + "iopub.status.idle": "2026-01-18T18:25:55.347212Z", + "shell.execute_reply": "2026-01-18T18:25:55.346899Z" + } + }, "outputs": [], "source": [ "from diff_diff.visualization import plot_event_study\n", @@ -615,7 +754,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:25:55.348553Z", + "iopub.status.busy": "2026-01-18T18:25:55.348463Z", + "iopub.status.idle": "2026-01-18T18:25:56.516352Z", + "shell.execute_reply": "2026-01-18T18:25:56.516020Z" + } + }, "outputs": [], "source": [ "# Example: Compare standard DiD vs Synthetic DiD on non-parallel data\n", @@ -705,7 +851,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11" + "version": "3.12.4" } }, "nbformat": 4, diff --git a/docs/tutorials/05_honest_did.ipynb b/docs/tutorials/05_honest_did.ipynb index d0232f01..c6379cc2 100644 --- a/docs/tutorials/05_honest_did.ipynb +++ b/docs/tutorials/05_honest_did.ipynb @@ -27,7 +27,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-1", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:27.539029Z", + "iopub.status.busy": "2026-01-18T18:11:27.538682Z", + "iopub.status.idle": "2026-01-18T18:11:28.165566Z", + "shell.execute_reply": "2026-01-18T18:11:28.165280Z" + } + }, "outputs": [], "source": [ "import numpy as np\n", @@ -85,7 +92,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-4", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:28.166895Z", + "iopub.status.busy": "2026-01-18T18:11:28.166814Z", + "iopub.status.idle": "2026-01-18T18:11:28.171624Z", + "shell.execute_reply": "2026-01-18T18:11:28.171410Z" + } + }, "outputs": [], "source": [ "def generate_did_data(n_units=200, n_periods=10, true_att=5.0, \n", @@ -155,7 +169,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-6", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:28.172679Z", + "iopub.status.busy": "2026-01-18T18:11:28.172616Z", + "iopub.status.idle": "2026-01-18T18:11:28.178247Z", + "shell.execute_reply": "2026-01-18T18:11:28.178064Z" + } + }, "outputs": [], "source": [ "# Fit event study\n", @@ -175,7 +196,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-7", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:28.179151Z", + "iopub.status.busy": "2026-01-18T18:11:28.179098Z", + "iopub.status.idle": "2026-01-18T18:11:28.223339Z", + "shell.execute_reply": "2026-01-18T18:11:28.223125Z" + } + }, "outputs": [], "source": [ "from diff_diff.visualization import plot_event_study\n", @@ -213,7 +241,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-9", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:28.224277Z", + "iopub.status.busy": "2026-01-18T18:11:28.224216Z", + "iopub.status.idle": "2026-01-18T18:11:28.225894Z", + "shell.execute_reply": "2026-01-18T18:11:28.225716Z" + } + }, "outputs": [], "source": [ "# Create HonestDiD estimator\n", @@ -251,7 +286,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-11", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:28.226879Z", + "iopub.status.busy": "2026-01-18T18:11:28.226825Z", + "iopub.status.idle": "2026-01-18T18:11:28.228403Z", + "shell.execute_reply": "2026-01-18T18:11:28.228181Z" + } + }, "outputs": [], "source": [ "# Key results\n", @@ -277,7 +319,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-13", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:28.229357Z", + "iopub.status.busy": "2026-01-18T18:11:28.229303Z", + "iopub.status.idle": "2026-01-18T18:11:28.230989Z", + "shell.execute_reply": "2026-01-18T18:11:28.230812Z" + } + }, "outputs": [], "source": [ "# Run sensitivity analysis over a grid of M values\n", @@ -293,7 +342,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-14", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:28.231809Z", + "iopub.status.busy": "2026-01-18T18:11:28.231746Z", + "iopub.status.idle": "2026-01-18T18:11:28.233223Z", + "shell.execute_reply": "2026-01-18T18:11:28.233059Z" + } + }, "outputs": [], "source": [ "# Key takeaway: the breakdown value\n", @@ -312,7 +368,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-15", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:28.233991Z", + "iopub.status.busy": "2026-01-18T18:11:28.233943Z", + "iopub.status.idle": "2026-01-18T18:11:28.273698Z", + "shell.execute_reply": "2026-01-18T18:11:28.273494Z" + } + }, "outputs": [], "source": [ "# Visualize the sensitivity analysis\n", @@ -355,7 +418,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-18", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:28.274758Z", + "iopub.status.busy": "2026-01-18T18:11:28.274699Z", + "iopub.status.idle": "2026-01-18T18:11:28.276713Z", + "shell.execute_reply": "2026-01-18T18:11:28.276518Z" + } + }, "outputs": [], "source": [ "# Compare different M values\n", @@ -384,7 +454,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-20", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:28.277579Z", + "iopub.status.busy": "2026-01-18T18:11:28.277530Z", + "iopub.status.idle": "2026-01-18T18:11:28.279233Z", + "shell.execute_reply": "2026-01-18T18:11:28.279030Z" + } + }, "outputs": [], "source": [ "# Compute breakdown value directly\n", @@ -428,7 +505,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-22", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:28.280155Z", + "iopub.status.busy": "2026-01-18T18:11:28.280103Z", + "iopub.status.idle": "2026-01-18T18:11:28.287492Z", + "shell.execute_reply": "2026-01-18T18:11:28.287309Z" + } + }, "outputs": [], "source": [ "# Smoothness restriction\n", @@ -446,7 +530,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-23", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:28.288351Z", + "iopub.status.busy": "2026-01-18T18:11:28.288304Z", + "iopub.status.idle": "2026-01-18T18:11:28.291203Z", + "shell.execute_reply": "2026-01-18T18:11:28.291011Z" + } + }, "outputs": [], "source": [ "# Compare smoothness vs relative magnitudes\n", @@ -476,7 +567,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-25", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:28.292063Z", + "iopub.status.busy": "2026-01-18T18:11:28.292010Z", + "iopub.status.idle": "2026-01-18T18:11:28.293465Z", + "shell.execute_reply": "2026-01-18T18:11:28.293303Z" + } + }, "outputs": [], "source": [ "# One-liner for quick bounds\n", @@ -503,7 +601,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-27", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:28.294262Z", + "iopub.status.busy": "2026-01-18T18:11:28.294221Z", + "iopub.status.idle": "2026-01-18T18:11:28.297053Z", + "shell.execute_reply": "2026-01-18T18:11:28.296872Z" + } + }, "outputs": [], "source": [ "# Single result to DataFrame\n", @@ -515,7 +620,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-28", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:28.297856Z", + "iopub.status.busy": "2026-01-18T18:11:28.297808Z", + "iopub.status.idle": "2026-01-18T18:11:28.301015Z", + "shell.execute_reply": "2026-01-18T18:11:28.300833Z" + } + }, "outputs": [], "source": [ "# Sensitivity analysis to DataFrame\n", @@ -527,7 +639,41 @@ "cell_type": "markdown", "id": "cell-29", "metadata": {}, - "source": "## Summary\n\n**Key Takeaways:**\n\n1. **Honest DiD** provides robust inference without assuming parallel trends holds exactly\n\n2. **Relative magnitudes** (M̄) bounds post-treatment violations by a multiple of observed pre-treatment violations\n - M̄=0: Standard parallel trends\n - M̄=1: Violations as bad as worst pre-period\n - M̄>1: Even larger violations allowed\n\n3. **Smoothness** (M) bounds the curvature of violations over time\n - M=0: Linear extrapolation of pre-trends\n - M>0: Allows non-linear changes\n\n4. **Breakdown value** tells you how robust your conclusion is\n\n5. **Best practices:**\n - Report results for multiple M values\n - Include the sensitivity plot in publications\n - Discuss what violation magnitudes are plausible in your setting\n - Use breakdown value to assess robustness\n\n**Related Tutorials:**\n- `04_parallel_trends.ipynb` - Standard parallel trends testing\n- `06_power_analysis.ipynb` - Power analysis for study design\n- `07_pretrends_power.ipynb` - Pre-trends power analysis (Roth 2022) - assess what violations your pre-trends test could have detected\n\n**Reference:**\n\nRambachan, A., & Roth, J. (2023). A More Credible Approach to Parallel Trends. \n*The Review of Economic Studies*, 90(5), 2555-2591. \nhttps://doi.org/10.1093/restud/rdad018" + "source": [ + "## Summary\n", + "\n", + "**Key Takeaways:**\n", + "\n", + "1. **Honest DiD** provides robust inference without assuming parallel trends holds exactly\n", + "\n", + "2. **Relative magnitudes** (M̄) bounds post-treatment violations by a multiple of observed pre-treatment violations\n", + " - M̄=0: Standard parallel trends\n", + " - M̄=1: Violations as bad as worst pre-period\n", + " - M̄>1: Even larger violations allowed\n", + "\n", + "3. **Smoothness** (M) bounds the curvature of violations over time\n", + " - M=0: Linear extrapolation of pre-trends\n", + " - M>0: Allows non-linear changes\n", + "\n", + "4. **Breakdown value** tells you how robust your conclusion is\n", + "\n", + "5. **Best practices:**\n", + " - Report results for multiple M values\n", + " - Include the sensitivity plot in publications\n", + " - Discuss what violation magnitudes are plausible in your setting\n", + " - Use breakdown value to assess robustness\n", + "\n", + "**Related Tutorials:**\n", + "- `04_parallel_trends.ipynb` - Standard parallel trends testing\n", + "- `06_power_analysis.ipynb` - Power analysis for study design\n", + "- `07_pretrends_power.ipynb` - Pre-trends power analysis (Roth 2022) - assess what violations your pre-trends test could have detected\n", + "\n", + "**Reference:**\n", + "\n", + "Rambachan, A., & Roth, J. (2023). A More Credible Approach to Parallel Trends. \n", + "*The Review of Economic Studies*, 90(5), 2555-2591. \n", + "https://doi.org/10.1093/restud/rdad018" + ] } ], "metadata": { @@ -537,10 +683,18 @@ "name": "python3" }, "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", "name": "python", - "version": "3.11.0" + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" } }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/docs/tutorials/06_power_analysis.ipynb b/docs/tutorials/06_power_analysis.ipynb index 6e862c5c..8d7c661a 100644 --- a/docs/tutorials/06_power_analysis.ipynb +++ b/docs/tutorials/06_power_analysis.ipynb @@ -21,7 +21,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-1", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:38.463657Z", + "iopub.status.busy": "2026-01-18T18:11:38.463382Z", + "iopub.status.idle": "2026-01-18T18:11:39.054353Z", + "shell.execute_reply": "2026-01-18T18:11:39.054066Z" + } + }, "outputs": [], "source": [ "import numpy as np\n", @@ -59,7 +66,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-3", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:39.055594Z", + "iopub.status.busy": "2026-01-18T18:11:39.055511Z", + "iopub.status.idle": "2026-01-18T18:11:39.057334Z", + "shell.execute_reply": "2026-01-18T18:11:39.057102Z" + } + }, "outputs": [], "source": [ "# Create a PowerAnalysis object with standard settings\n", @@ -91,7 +105,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-5", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:39.058344Z", + "iopub.status.busy": "2026-01-18T18:11:39.058290Z", + "iopub.status.idle": "2026-01-18T18:11:39.059837Z", + "shell.execute_reply": "2026-01-18T18:11:39.059623Z" + } + }, "outputs": [], "source": [ "# Calculate MDE for a basic 2x2 DiD design\n", @@ -109,7 +130,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-6", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:39.060673Z", + "iopub.status.busy": "2026-01-18T18:11:39.060624Z", + "iopub.status.idle": "2026-01-18T18:11:39.062114Z", + "shell.execute_reply": "2026-01-18T18:11:39.061922Z" + } + }, "outputs": [], "source": [ "# Access individual results\n", @@ -133,7 +161,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-8", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:39.063036Z", + "iopub.status.busy": "2026-01-18T18:11:39.062974Z", + "iopub.status.idle": "2026-01-18T18:11:39.064838Z", + "shell.execute_reply": "2026-01-18T18:11:39.064670Z" + } + }, "outputs": [], "source": [ "# Compare MDE across different sample sizes\n", @@ -161,7 +196,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-10", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:39.065717Z", + "iopub.status.busy": "2026-01-18T18:11:39.065664Z", + "iopub.status.idle": "2026-01-18T18:11:39.067314Z", + "shell.execute_reply": "2026-01-18T18:11:39.067101Z" + } + }, "outputs": [], "source": [ "# How many units do we need to detect an effect of 5 units?\n", @@ -177,7 +219,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-11", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:39.068173Z", + "iopub.status.busy": "2026-01-18T18:11:39.068121Z", + "iopub.status.idle": "2026-01-18T18:11:39.070300Z", + "shell.execute_reply": "2026-01-18T18:11:39.070125Z" + } + }, "outputs": [], "source": [ "# Compare sample sizes needed for different effect sizes\n", @@ -205,7 +254,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-13", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:39.071160Z", + "iopub.status.busy": "2026-01-18T18:11:39.071108Z", + "iopub.status.idle": "2026-01-18T18:11:39.072677Z", + "shell.execute_reply": "2026-01-18T18:11:39.072488Z" + } + }, "outputs": [], "source": [ "# What's our power to detect an effect of 4 with 75 units per group?\n", @@ -236,7 +292,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-15", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:39.073490Z", + "iopub.status.busy": "2026-01-18T18:11:39.073441Z", + "iopub.status.idle": "2026-01-18T18:11:39.083809Z", + "shell.execute_reply": "2026-01-18T18:11:39.083654Z" + } + }, "outputs": [], "source": [ "# Generate power curve data\n", @@ -257,7 +320,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-16", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:39.084613Z", + "iopub.status.busy": "2026-01-18T18:11:39.084563Z", + "iopub.status.idle": "2026-01-18T18:11:39.136695Z", + "shell.execute_reply": "2026-01-18T18:11:39.136496Z" + } + }, "outputs": [], "source": [ "# Plot the power curve\n", @@ -283,7 +353,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-18", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:39.137741Z", + "iopub.status.busy": "2026-01-18T18:11:39.137683Z", + "iopub.status.idle": "2026-01-18T18:11:39.180081Z", + "shell.execute_reply": "2026-01-18T18:11:39.179871Z" + } + }, "outputs": [], "source": [ "# How does power change with sample size for a fixed effect?\n", @@ -320,7 +397,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-20", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:39.181151Z", + "iopub.status.busy": "2026-01-18T18:11:39.181088Z", + "iopub.status.idle": "2026-01-18T18:11:39.183405Z", + "shell.execute_reply": "2026-01-18T18:11:39.183191Z" + } + }, "outputs": [], "source": [ "# Compare MDE with different numbers of periods\n", @@ -346,7 +430,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-21", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:39.184381Z", + "iopub.status.busy": "2026-01-18T18:11:39.184325Z", + "iopub.status.idle": "2026-01-18T18:11:39.186535Z", + "shell.execute_reply": "2026-01-18T18:11:39.186329Z" + } + }, "outputs": [], "source": [ "# Effect of intra-cluster correlation (ICC)\n", @@ -386,7 +477,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-23", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:39.187456Z", + "iopub.status.busy": "2026-01-18T18:11:39.187404Z", + "iopub.status.idle": "2026-01-18T18:11:39.332686Z", + "shell.execute_reply": "2026-01-18T18:11:39.332458Z" + } + }, "outputs": [], "source": [ "# Simulation-based power analysis\n", @@ -411,7 +509,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-24", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:39.333754Z", + "iopub.status.busy": "2026-01-18T18:11:39.333684Z", + "iopub.status.idle": "2026-01-18T18:11:39.335737Z", + "shell.execute_reply": "2026-01-18T18:11:39.335456Z" + } + }, "outputs": [], "source": [ "# Key metrics from simulation\n", @@ -439,7 +544,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-26", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:39.337041Z", + "iopub.status.busy": "2026-01-18T18:11:39.336952Z", + "iopub.status.idle": "2026-01-18T18:11:39.769029Z", + "shell.execute_reply": "2026-01-18T18:11:39.768817Z" + } + }, "outputs": [], "source": [ "# Simulate power for multiple effect sizes\n", @@ -463,7 +575,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-27", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:39.770023Z", + "iopub.status.busy": "2026-01-18T18:11:39.769965Z", + "iopub.status.idle": "2026-01-18T18:11:39.803246Z", + "shell.execute_reply": "2026-01-18T18:11:39.803031Z" + } + }, "outputs": [], "source": [ "# Plot simulation-based power curve\n", @@ -489,7 +608,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-29", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:39.804335Z", + "iopub.status.busy": "2026-01-18T18:11:39.804277Z", + "iopub.status.idle": "2026-01-18T18:11:39.806582Z", + "shell.execute_reply": "2026-01-18T18:11:39.806373Z" + } + }, "outputs": [], "source": [ "# Quick MDE calculation\n", @@ -526,7 +652,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-31", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:39.807527Z", + "iopub.status.busy": "2026-01-18T18:11:39.807469Z", + "iopub.status.idle": "2026-01-18T18:11:39.809560Z", + "shell.execute_reply": "2026-01-18T18:11:39.809379Z" + } + }, "outputs": [], "source": [ "# Sensitivity to sigma\n", @@ -555,7 +688,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-33", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:11:39.810531Z", + "iopub.status.busy": "2026-01-18T18:11:39.810474Z", + "iopub.status.idle": "2026-01-18T18:11:39.812613Z", + "shell.execute_reply": "2026-01-18T18:11:39.812380Z" + } + }, "outputs": [], "source": [ "# Compare required sample sizes for different power levels\n", @@ -573,7 +713,26 @@ "cell_type": "markdown", "id": "cell-34", "metadata": {}, - "source": "## Summary\n\nKey takeaways for DiD power analysis:\n\n1. **Always do a power analysis** before running a study\n2. **MDE decreases** with sample size, more periods, and lower variance\n3. **ICC matters** for panel data - high autocorrelation reduces effective sample size\n4. **Use simulation** for complex designs (staggered, synthetic DiD)\n5. **Be realistic about sigma** - err on the side of larger values\n6. **Consider your smallest meaningful effect** - don't just target statistical significance\n\nFor more on DiD estimation, see the other tutorials:\n- `01_basic_did.ipynb` - Basic DiD estimation\n- `02_staggered_did.ipynb` - Staggered adoption designs\n- `03_synthetic_did.ipynb` - Synthetic DiD\n- `04_parallel_trends.ipynb` - Testing assumptions\n- `05_honest_did.ipynb` - Sensitivity analysis\n- `07_pretrends_power.ipynb` - Pre-trends power analysis (Roth 2022)" + "source": [ + "## Summary\n", + "\n", + "Key takeaways for DiD power analysis:\n", + "\n", + "1. **Always do a power analysis** before running a study\n", + "2. **MDE decreases** with sample size, more periods, and lower variance\n", + "3. **ICC matters** for panel data - high autocorrelation reduces effective sample size\n", + "4. **Use simulation** for complex designs (staggered, synthetic DiD)\n", + "5. **Be realistic about sigma** - err on the side of larger values\n", + "6. **Consider your smallest meaningful effect** - don't just target statistical significance\n", + "\n", + "For more on DiD estimation, see the other tutorials:\n", + "- `01_basic_did.ipynb` - Basic DiD estimation\n", + "- `02_staggered_did.ipynb` - Staggered adoption designs\n", + "- `03_synthetic_did.ipynb` - Synthetic DiD\n", + "- `04_parallel_trends.ipynb` - Testing assumptions\n", + "- `05_honest_did.ipynb` - Sensitivity analysis\n", + "- `07_pretrends_power.ipynb` - Pre-trends power analysis (Roth 2022)" + ] } ], "metadata": { @@ -583,10 +742,18 @@ "name": "python3" }, "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", "name": "python", - "version": "3.9.0" + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" } }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/docs/tutorials/07_pretrends_power.ipynb b/docs/tutorials/07_pretrends_power.ipynb index 8e0f3fd9..01965101 100644 --- a/docs/tutorials/07_pretrends_power.ipynb +++ b/docs/tutorials/07_pretrends_power.ipynb @@ -26,7 +26,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-1", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:46:37.272194Z", + "iopub.status.busy": "2026-01-18T18:46:37.272120Z", + "iopub.status.idle": "2026-01-18T18:46:37.811943Z", + "shell.execute_reply": "2026-01-18T18:46:37.811663Z" + } + }, "outputs": [], "source": [ "import numpy as np\n", @@ -88,7 +95,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-4", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:46:37.813263Z", + "iopub.status.busy": "2026-01-18T18:46:37.813169Z", + "iopub.status.idle": "2026-01-18T18:46:37.820440Z", + "shell.execute_reply": "2026-01-18T18:46:37.820211Z" + } + }, "outputs": [], "source": [ "def generate_event_study_data(n_units=300, n_periods=10, true_att=5.0, seed=42):\n", @@ -152,19 +166,35 @@ "cell_type": "code", "execution_count": null, "id": "cell-6", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:46:37.821563Z", + "iopub.status.busy": "2026-01-18T18:46:37.821487Z", + "iopub.status.idle": "2026-01-18T18:46:37.829928Z", + "shell.execute_reply": "2026-01-18T18:46:37.829695Z" + } + }, "outputs": [], "source": [ - "# Fit event study\n", + "# Fit event study with ALL periods (pre and post) relative to reference period\n", + "# For pre-trends power analysis, we need coefficients for pre-periods too\n", "mp_did = MultiPeriodDiD()\n", + "\n", + "# Use period 4 as the reference period (last pre-period, excluded from estimation)\n", + "# Estimate coefficients for all other periods: 0, 1, 2, 3 (pre) and 5, 6, 7, 8, 9 (post)\n", + "all_estimation_periods = [0, 1, 2, 3, 5, 6, 7, 8, 9] # All except reference period 4\n", + "\n", "event_results = mp_did.fit(\n", " df,\n", " outcome='outcome',\n", " treatment='treated',\n", " time='period',\n", - " post_periods=[5, 6, 7, 8, 9] # Periods 5-9 are post-treatment\n", + " post_periods=all_estimation_periods # Include all periods for full event study\n", ")\n", "\n", + "# Note: For standard DiD analysis, we'd normally use post_periods=[5,6,7,8,9]\n", + "# But for pre-trends power analysis, we need pre-period coefficients too\n", + "\n", "print(event_results.summary())" ] }, @@ -172,7 +202,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-7", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:46:37.831048Z", + "iopub.status.busy": "2026-01-18T18:46:37.830959Z", + "iopub.status.idle": "2026-01-18T18:46:37.832729Z", + "shell.execute_reply": "2026-01-18T18:46:37.832478Z" + } + }, "outputs": [], "source": [ "# Visualize the event study\n", @@ -211,7 +248,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-10", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:46:37.833877Z", + "iopub.status.busy": "2026-01-18T18:46:37.833807Z", + "iopub.status.idle": "2026-01-18T18:46:37.836682Z", + "shell.execute_reply": "2026-01-18T18:46:37.836459Z" + } + }, "outputs": [], "source": [ "# Create a PreTrendsPower object\n", @@ -221,8 +265,13 @@ " violation_type='linear' # Type of violation to consider\n", ")\n", "\n", - "# Fit to the event study results\n", - "pt_results = pt.fit(event_results)\n", + "# Define the actual pre-treatment periods (those before treatment starts at period 5)\n", + "# These are the periods we want to analyze for pre-trends power\n", + "pre_treatment_periods = [0, 1, 2, 3]\n", + "\n", + "# Fit to the event study results, specifying which periods are pre-treatment\n", + "# This is needed because we estimated all periods as post_periods in the event study\n", + "pt_results = pt.fit(event_results, pre_periods=pre_treatment_periods)\n", "\n", "print(pt_results.summary())" ] @@ -251,7 +300,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-12", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:46:37.837722Z", + "iopub.status.busy": "2026-01-18T18:46:37.837655Z", + "iopub.status.idle": "2026-01-18T18:46:37.839535Z", + "shell.execute_reply": "2026-01-18T18:46:37.839305Z" + } + }, "outputs": [], "source": [ "# Access key results\n", @@ -280,7 +336,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-14", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:46:37.840755Z", + "iopub.status.busy": "2026-01-18T18:46:37.840687Z", + "iopub.status.idle": "2026-01-18T18:46:37.843047Z", + "shell.execute_reply": "2026-01-18T18:46:37.842820Z" + } + }, "outputs": [], "source": [ "# Compute power for specific violation magnitudes\n", @@ -309,15 +372,40 @@ "cell_type": "code", "execution_count": null, "id": "cell-16", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:46:37.844061Z", + "iopub.status.busy": "2026-01-18T18:46:37.843996Z", + "iopub.status.idle": "2026-01-18T18:46:37.852782Z", + "shell.execute_reply": "2026-01-18T18:46:37.852547Z" + } + }, "outputs": [], - "source": "# Generate power curve\ncurve = pt.power_curve(\n event_results,\n n_points=50\n)\n\n# Preview the data\nprint(\"Power curve data (first 10 points):\")\nprint(curve.to_dataframe().head(10))" + "source": [ + "# Generate power curve\n", + "curve = pt.power_curve(\n", + " event_results,\n", + " n_points=50,\n", + " pre_periods=pre_treatment_periods\n", + ")\n", + "\n", + "# Preview the data\n", + "print(\"Power curve data (first 10 points):\")\n", + "print(curve.to_dataframe().head(10))" + ] }, { "cell_type": "code", "execution_count": null, "id": "cell-17", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:46:37.853831Z", + "iopub.status.busy": "2026-01-18T18:46:37.853755Z", + "iopub.status.idle": "2026-01-18T18:46:37.855457Z", + "shell.execute_reply": "2026-01-18T18:46:37.855229Z" + } + }, "outputs": [], "source": [ "# Plot the power curve\n", @@ -375,7 +463,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-20", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:46:37.856555Z", + "iopub.status.busy": "2026-01-18T18:46:37.856485Z", + "iopub.status.idle": "2026-01-18T18:46:37.861471Z", + "shell.execute_reply": "2026-01-18T18:46:37.861226Z" + } + }, "outputs": [], "source": [ "# Compare violation types\n", @@ -386,7 +481,7 @@ "\n", "for vtype in violation_types:\n", " pt_v = PreTrendsPower(violation_type=vtype)\n", - " results_v = pt_v.fit(event_results)\n", + " results_v = pt_v.fit(event_results, pre_periods=pre_treatment_periods)\n", " power_at_2 = results_v.power_at(2.0)\n", " print(f\"{vtype:>15} {results_v.mdv:>10.3f} {power_at_2:>15.1%}\")" ] @@ -395,22 +490,30 @@ "cell_type": "code", "execution_count": null, "id": "cell-21", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:46:37.862504Z", + "iopub.status.busy": "2026-01-18T18:46:37.862438Z", + "iopub.status.idle": "2026-01-18T18:46:37.865248Z", + "shell.execute_reply": "2026-01-18T18:46:37.865032Z" + } + }, "outputs": [], "source": [ "# Custom violation weights\n", "# Example: Violation concentrated in periods 2 and 3 (approaching treatment)\n", - "n_pre = len([e for e in event_results.period_effects if e.relative_time < 0])\n", + "# We have pre-periods 0, 1, 2, 3 estimated (reference period 4 is excluded)\n", + "n_pre = 4 # Periods 0, 1, 2, 3\n", "custom_weights = np.zeros(n_pre)\n", - "custom_weights[-2:] = 1.0 # Weight on last two pre-periods\n", + "custom_weights[-2:] = 1.0 # Weight on last two pre-periods (periods 2 and 3)\n", "\n", "pt_custom = PreTrendsPower(\n", " violation_type='custom',\n", " violation_weights=custom_weights\n", ")\n", - "results_custom = pt_custom.fit(event_results)\n", + "results_custom = pt_custom.fit(event_results, pre_periods=pre_treatment_periods)\n", "\n", - "print(f\"Custom violation (last 2 periods): MDV = {results_custom.mdv:.3f}\")" + "print(f\"Custom violation (last 2 pre-periods): MDV = {results_custom.mdv:.3f}\")" ] }, { @@ -425,9 +528,35 @@ "cell_type": "code", "execution_count": null, "id": "cell-23", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:46:37.866289Z", + "iopub.status.busy": "2026-01-18T18:46:37.866218Z", + "iopub.status.idle": "2026-01-18T18:46:37.868250Z", + "shell.execute_reply": "2026-01-18T18:46:37.868009Z" + } + }, "outputs": [], - "source": "if HAS_MATPLOTLIB:\n fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n \n for ax, vtype in zip(axes, ['linear', 'constant', 'last_period']):\n pt_v = PreTrendsPower(violation_type=vtype)\n curve_v = pt_v.power_curve(event_results, n_points=50)\n \n plot_pretrends_power(\n curve_v,\n ax=ax,\n show_mdv=True,\n target_power=0.80,\n title=f'Violation Type: {vtype.replace(\"_\", \" \").title()}',\n show=False\n )\n \n plt.tight_layout()\n plt.show()" + "source": [ + "if HAS_MATPLOTLIB:\n", + " fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n", + " \n", + " for ax, vtype in zip(axes, ['linear', 'constant', 'last_period']):\n", + " pt_v = PreTrendsPower(violation_type=vtype)\n", + " curve_v = pt_v.power_curve(event_results, n_points=50, pre_periods=pre_treatment_periods)\n", + " \n", + " plot_pretrends_power(\n", + " curve_v,\n", + " ax=ax,\n", + " show_mdv=True,\n", + " target_power=0.80,\n", + " title=f'Violation Type: {vtype.replace(\"_\", \" \").title()}',\n", + " show=False\n", + " )\n", + " \n", + " plt.tight_layout()\n", + " plt.show()" + ] }, { "cell_type": "markdown", @@ -449,14 +578,21 @@ "cell_type": "code", "execution_count": null, "id": "cell-25", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:46:37.869275Z", + "iopub.status.busy": "2026-01-18T18:46:37.869196Z", + "iopub.status.idle": "2026-01-18T18:46:37.873240Z", + "shell.execute_reply": "2026-01-18T18:46:37.872987Z" + } + }, "outputs": [], "source": [ "from diff_diff import HonestDiD\n", "\n", "# First, compute MDV\n", "pt = PreTrendsPower(violation_type='linear')\n", - "pt_results = pt.fit(event_results)\n", + "pt_results = pt.fit(event_results, pre_periods=pre_treatment_periods)\n", "\n", "print(f\"MDV from pre-trends power analysis: {pt_results.mdv:.3f}\")\n", "print(\"\")\n", @@ -475,20 +611,29 @@ "cell_type": "code", "execution_count": null, "id": "cell-26", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:46:37.874345Z", + "iopub.status.busy": "2026-01-18T18:46:37.874260Z", + "iopub.status.idle": "2026-01-18T18:46:37.877151Z", + "shell.execute_reply": "2026-01-18T18:46:37.876909Z" + } + }, "outputs": [], "source": [ "# Use the built-in sensitivity integration\n", "sensitivity_results = pt.sensitivity_to_honest_did(\n", " event_results,\n", - " honest_method='smoothness'\n", + " pre_periods=pre_treatment_periods\n", ")\n", "\n", "print(\"Joint sensitivity analysis:\")\n", "print(f\" MDV: {sensitivity_results['mdv']:.3f}\")\n", - "print(f\" Original estimate: {sensitivity_results['original_estimate']:.3f}\")\n", - "print(f\" Robust CI at M=MDV: [{sensitivity_results['ci_lb']:.3f}, {sensitivity_results['ci_ub']:.3f}]\")\n", - "print(f\" Significant at M=MDV: {sensitivity_results['significant_at_mdv']}\")" + "print(f\" Max pre-period SE: {sensitivity_results['max_pre_se']:.3f}\")\n", + "print(f\" MDV / max(SE): {sensitivity_results['mdv_in_ses']:.2f}\")\n", + "print(\"\")\n", + "print(\"Interpretation:\")\n", + "print(sensitivity_results['interpretation'])" ] }, { @@ -505,17 +650,16 @@ "cell_type": "code", "execution_count": null, "id": "cell-28", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:46:37.878268Z", + "iopub.status.busy": "2026-01-18T18:46:37.878181Z", + "iopub.status.idle": "2026-01-18T18:46:37.881771Z", + "shell.execute_reply": "2026-01-18T18:46:37.881548Z" + } + }, "outputs": [], - "source": [ - "# Quick MDV calculation\n", - "mdv = compute_mdv(event_results, power=0.80, violation_type='linear')\n", - "print(f\"MDV: {mdv:.3f}\")\n", - "\n", - "# Quick power calculation at a specific violation\n", - "power = compute_pretrends_power(event_results, violation_magnitude=2.0)\n", - "print(f\"Power at violation=2.0: {power:.1%}\")" - ] + "source": "# Quick MDV calculation\nmdv = compute_mdv(event_results, target_power=0.80, violation_type='linear', pre_periods=pre_treatment_periods)\nprint(f\"MDV: {mdv:.3f}\")\n\n# Quick power calculation at a specific violation\npower_result = compute_pretrends_power(event_results, M=2.0, pre_periods=pre_treatment_periods)\nprint(f\"Power at violation=2.0: {power_result.power:.1%}\")" }, { "cell_type": "markdown", @@ -531,42 +675,50 @@ "cell_type": "code", "execution_count": null, "id": "cell-30", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:46:37.882847Z", + "iopub.status.busy": "2026-01-18T18:46:37.882777Z", + "iopub.status.idle": "2026-01-18T18:46:37.890940Z", + "shell.execute_reply": "2026-01-18T18:46:37.890708Z" + } + }, "outputs": [], "source": [ "# Typical workflow for pre-trends power analysis\n", "\n", - "# Step 1: Estimate event study\n", + "# Step 1: Estimate event study with ALL periods (pre and post) relative to reference\n", + "# For pre-trends power analysis, we need pre-period coefficients\n", "mp_did = MultiPeriodDiD()\n", + "\n", + "# Reference period is 4 (last pre-period)\n", + "# Estimate coefficients for periods 0, 1, 2, 3 (pre) and 5, 6, 7, 8, 9 (post)\n", + "all_estimation_periods = [0, 1, 2, 3, 5, 6, 7, 8, 9]\n", + "pre_treatment_periods = [0, 1, 2, 3] # Define which are pre-treatment\n", + "\n", "results = mp_did.fit(\n", " df, \n", " outcome='outcome',\n", " treatment='treated', \n", " time='period',\n", - " post_periods=[5, 6, 7, 8, 9]\n", + " post_periods=all_estimation_periods\n", ")\n", "\n", - "# Step 2: Check standard parallel trends test\n", - "print(\"Step 2: Standard Pre-Trends Test\")\n", - "print(f\"Pre-trends test p-value: {results.pretrend_test_pvalue:.4f}\")\n", - "print(f\"Conclusion: {'Fail to reject parallel trends' if results.pretrend_test_pvalue > 0.05 else 'Reject parallel trends'}\")\n", - "print(\"\")\n", - "\n", - "# Step 3: Assess power of the pre-trends test \n", - "print(\"Step 3: Pre-Trends Power Analysis\")\n", + "# Step 2: Assess power of the pre-trends test \n", + "print(\"Step 2: Pre-Trends Power Analysis\")\n", "pt = PreTrendsPower(alpha=0.05, power=0.80, violation_type='linear')\n", - "pt_results = pt.fit(results)\n", + "pt_results = pt.fit(results, pre_periods=pre_treatment_periods)\n", "print(f\"MDV (80% power): {pt_results.mdv:.3f}\")\n", "print(\"\")\n", "\n", - "# Step 4: Interpret\n", - "print(\"Step 4: Interpretation\")\n", + "# Step 3: Interpret\n", + "print(\"Step 3: Interpretation\")\n", "print(f\"Your pre-trends test could only detect violations >= {pt_results.mdv:.3f}\")\n", "print(f\"Violations smaller than this would likely go undetected.\")\n", "print(\"\")\n", "\n", - "# Step 5: Connect to Honest DiD for robust inference\n", - "print(\"Step 5: Robust Inference with Honest DiD\")\n", + "# Step 4: Connect to Honest DiD for robust inference\n", + "print(\"Step 4: Robust Inference with Honest DiD\")\n", "honest = HonestDiD(method='smoothness', M=pt_results.mdv)\n", "honest_results = honest.fit(results)\n", "print(f\"Robust 95% CI (M=MDV): [{honest_results.ci_lb:.3f}, {honest_results.ci_ub:.3f}]\")\n", @@ -587,7 +739,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-32", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:46:37.891913Z", + "iopub.status.busy": "2026-01-18T18:46:37.891848Z", + "iopub.status.idle": "2026-01-18T18:46:37.901021Z", + "shell.execute_reply": "2026-01-18T18:46:37.900793Z" + } + }, "outputs": [], "source": [ "# Export single result\n", @@ -597,7 +756,7 @@ "\n", "# Export power curve\n", "print(\"Power curve as DataFrame (first 10 rows):\")\n", - "curve = pt.power_curve(event_results)\n", + "curve = pt.power_curve(event_results, pre_periods=pre_treatment_periods)\n", "print(curve.to_dataframe().head(10))" ] }, @@ -605,7 +764,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-33", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:46:37.902030Z", + "iopub.status.busy": "2026-01-18T18:46:37.901963Z", + "iopub.status.idle": "2026-01-18T18:46:37.903580Z", + "shell.execute_reply": "2026-01-18T18:46:37.903362Z" + } + }, "outputs": [], "source": [ "# Export to dict for JSON serialization\n", @@ -667,8 +833,16 @@ "name": "python3" }, "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", "name": "python", - "version": "3.9" + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.6" } }, "nbformat": 4, diff --git a/docs/tutorials/08_triple_diff.ipynb b/docs/tutorials/08_triple_diff.ipynb index 761891ad..18152db4 100644 --- a/docs/tutorials/08_triple_diff.ipynb +++ b/docs/tutorials/08_triple_diff.ipynb @@ -31,7 +31,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:01.278606Z", + "iopub.status.busy": "2026-01-18T18:12:01.278233Z", + "iopub.status.idle": "2026-01-18T18:12:01.870293Z", + "shell.execute_reply": "2026-01-18T18:12:01.870005Z" + } + }, "outputs": [], "source": [ "import numpy as np\n", @@ -56,7 +63,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:01.871622Z", + "iopub.status.busy": "2026-01-18T18:12:01.871529Z", + "iopub.status.idle": "2026-01-18T18:12:01.888251Z", + "shell.execute_reply": "2026-01-18T18:12:01.888059Z" + } + }, "outputs": [], "source": [ "def generate_ddd_data(n_per_cell=200, true_att=2.0, seed=42):\n", @@ -118,7 +132,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:01.900276Z", + "iopub.status.busy": "2026-01-18T18:12:01.900188Z", + "iopub.status.idle": "2026-01-18T18:12:01.902831Z", + "shell.execute_reply": "2026-01-18T18:12:01.902625Z" + } + }, "outputs": [], "source": [ "# Create and fit the DDD estimator\n", @@ -157,7 +178,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:01.903891Z", + "iopub.status.busy": "2026-01-18T18:12:01.903826Z", + "iopub.status.idle": "2026-01-18T18:12:01.908608Z", + "shell.execute_reply": "2026-01-18T18:12:01.908413Z" + } + }, "outputs": [], "source": [ "# Compute cell means\n", @@ -199,7 +227,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:01.909520Z", + "iopub.status.busy": "2026-01-18T18:12:01.909466Z", + "iopub.status.idle": "2026-01-18T18:12:01.914372Z", + "shell.execute_reply": "2026-01-18T18:12:01.914168Z" + } + }, "outputs": [], "source": [ "# Compare estimation methods\n", @@ -231,7 +266,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:01.915272Z", + "iopub.status.busy": "2026-01-18T18:12:01.915218Z", + "iopub.status.idle": "2026-01-18T18:12:01.919476Z", + "shell.execute_reply": "2026-01-18T18:12:01.919279Z" + } + }, "outputs": [], "source": [ "# Estimate with covariates\n", @@ -261,7 +303,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:01.920395Z", + "iopub.status.busy": "2026-01-18T18:12:01.920337Z", + "iopub.status.idle": "2026-01-18T18:12:01.924546Z", + "shell.execute_reply": "2026-01-18T18:12:01.924341Z" + } + }, "outputs": [], "source": [ "# One-liner estimation\n", @@ -290,7 +339,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:01.925436Z", + "iopub.status.busy": "2026-01-18T18:12:01.925375Z", + "iopub.status.idle": "2026-01-18T18:12:01.990959Z", + "shell.execute_reply": "2026-01-18T18:12:01.990744Z" + } + }, "outputs": [], "source": [ "# Plot cell means over time\n", @@ -356,7 +412,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:01.992031Z", + "iopub.status.busy": "2026-01-18T18:12:01.991971Z", + "iopub.status.idle": "2026-01-18T18:12:01.993718Z", + "shell.execute_reply": "2026-01-18T18:12:01.993529Z" + } + }, "outputs": [], "source": [ "# Access individual results\n", @@ -372,7 +435,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:01.994564Z", + "iopub.status.busy": "2026-01-18T18:12:01.994507Z", + "iopub.status.idle": "2026-01-18T18:12:01.996334Z", + "shell.execute_reply": "2026-01-18T18:12:01.996146Z" + } + }, "outputs": [], "source": [ "# Convert to DataFrame for further analysis\n", @@ -383,7 +453,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:01.997174Z", + "iopub.status.busy": "2026-01-18T18:12:01.997120Z", + "iopub.status.idle": "2026-01-18T18:12:01.998502Z", + "shell.execute_reply": "2026-01-18T18:12:01.998337Z" + } + }, "outputs": [], "source": [ "# View cell means\n", @@ -434,7 +511,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9" + "version": "3.12.4" } }, "nbformat": 4, diff --git a/docs/tutorials/09_real_world_examples.ipynb b/docs/tutorials/09_real_world_examples.ipynb index 731d2dab..7e6054ca 100644 --- a/docs/tutorials/09_real_world_examples.ipynb +++ b/docs/tutorials/09_real_world_examples.ipynb @@ -20,7 +20,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-1", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:12.708254Z", + "iopub.status.busy": "2026-01-18T18:12:12.708150Z", + "iopub.status.idle": "2026-01-18T18:12:13.272246Z", + "shell.execute_reply": "2026-01-18T18:12:13.271985Z" + } + }, "outputs": [], "source": [ "import numpy as np\n", @@ -55,7 +62,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-2", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:13.273525Z", + "iopub.status.busy": "2026-01-18T18:12:13.273430Z", + "iopub.status.idle": "2026-01-18T18:12:13.275263Z", + "shell.execute_reply": "2026-01-18T18:12:13.275046Z" + } + }, "outputs": [], "source": [ "# List available datasets\n", @@ -93,7 +107,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-4", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:13.276232Z", + "iopub.status.busy": "2026-01-18T18:12:13.276177Z", + "iopub.status.idle": "2026-01-18T18:12:13.401937Z", + "shell.execute_reply": "2026-01-18T18:12:13.401672Z" + } + }, "outputs": [], "source": [ "# Load the Card-Krueger dataset\n", @@ -110,7 +131,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-5", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:13.403112Z", + "iopub.status.busy": "2026-01-18T18:12:13.403030Z", + "iopub.status.idle": "2026-01-18T18:12:13.410010Z", + "shell.execute_reply": "2026-01-18T18:12:13.409796Z" + } + }, "outputs": [], "source": [ "# Summary statistics by state\n", @@ -146,7 +174,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-7", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:13.410999Z", + "iopub.status.busy": "2026-01-18T18:12:13.410935Z", + "iopub.status.idle": "2026-01-18T18:12:13.417417Z", + "shell.execute_reply": "2026-01-18T18:12:13.417183Z" + } + }, "outputs": [], "source": [ "# Reshape to long format\n", @@ -181,7 +216,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-9", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:13.418354Z", + "iopub.status.busy": "2026-01-18T18:12:13.418291Z", + "iopub.status.idle": "2026-01-18T18:12:13.421888Z", + "shell.execute_reply": "2026-01-18T18:12:13.421691Z" + } + }, "outputs": [], "source": [ "# Basic DiD estimation\n", @@ -203,7 +245,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-10", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:13.422847Z", + "iopub.status.busy": "2026-01-18T18:12:13.422783Z", + "iopub.status.idle": "2026-01-18T18:12:13.425967Z", + "shell.execute_reply": "2026-01-18T18:12:13.425767Z" + } + }, "outputs": [], "source": [ "# Manual calculation to verify\n", @@ -230,7 +279,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-11", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:13.426980Z", + "iopub.status.busy": "2026-01-18T18:12:13.426903Z", + "iopub.status.idle": "2026-01-18T18:12:13.429603Z", + "shell.execute_reply": "2026-01-18T18:12:13.429413Z" + } + }, "outputs": [], "source": [ "# With chain fixed effects for better precision\n", @@ -266,7 +322,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-13", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:13.430537Z", + "iopub.status.busy": "2026-01-18T18:12:13.430466Z", + "iopub.status.idle": "2026-01-18T18:12:13.516522Z", + "shell.execute_reply": "2026-01-18T18:12:13.516288Z" + } + }, "outputs": [], "source": [ "# Visualization: Employment trends\n", @@ -333,7 +396,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-15", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:13.517666Z", + "iopub.status.busy": "2026-01-18T18:12:13.517584Z", + "iopub.status.idle": "2026-01-18T18:12:13.727933Z", + "shell.execute_reply": "2026-01-18T18:12:13.727670Z" + } + }, "outputs": [], "source": [ "# Load the Castle Doctrine dataset\n", @@ -349,7 +419,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-16", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:13.729368Z", + "iopub.status.busy": "2026-01-18T18:12:13.729275Z", + "iopub.status.idle": "2026-01-18T18:12:13.732727Z", + "shell.execute_reply": "2026-01-18T18:12:13.732440Z" + } + }, "outputs": [], "source": [ "# Treatment timing\n", @@ -381,7 +458,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-18", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:13.734052Z", + "iopub.status.busy": "2026-01-18T18:12:13.733946Z", + "iopub.status.idle": "2026-01-18T18:12:13.739863Z", + "shell.execute_reply": "2026-01-18T18:12:13.739632Z" + } + }, "outputs": [], "source": [ "# TWFE estimation (potentially biased)\n", @@ -409,7 +493,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-19", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:13.740983Z", + "iopub.status.busy": "2026-01-18T18:12:13.740911Z", + "iopub.status.idle": "2026-01-18T18:12:13.768572Z", + "shell.execute_reply": "2026-01-18T18:12:13.768377Z" + } + }, "outputs": [], "source": [ "# Goodman-Bacon decomposition reveals the problem\n", @@ -428,7 +519,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-20", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:13.769528Z", + "iopub.status.busy": "2026-01-18T18:12:13.769465Z", + "iopub.status.idle": "2026-01-18T18:12:13.866261Z", + "shell.execute_reply": "2026-01-18T18:12:13.866039Z" + } + }, "outputs": [], "source": [ "# Visualize the decomposition\n", @@ -462,31 +560,103 @@ "cell_type": "code", "execution_count": null, "id": "cell-22", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:13.867310Z", + "iopub.status.busy": "2026-01-18T18:12:13.867249Z", + "iopub.status.idle": "2026-01-18T18:12:13.876157Z", + "shell.execute_reply": "2026-01-18T18:12:13.875945Z" + } + }, "outputs": [], - "source": "# Callaway-Sant'Anna estimation\ncs = CallawaySantAnna(\n control_group='never_treated',\n n_bootstrap=199,\n seed=42\n)\n\nresults_cs = cs.fit(\n castle,\n outcome='homicide_rate',\n unit='state',\n time='year',\n first_treat='first_treat',\n aggregate='all' # Compute all aggregations (simple, event_study, group)\n)\n\nprint(results_cs.summary())" + "source": [ + "# Callaway-Sant'Anna estimation\n", + "cs = CallawaySantAnna(\n", + " control_group='never_treated',\n", + " n_bootstrap=199,\n", + " seed=42\n", + ")\n", + "\n", + "results_cs = cs.fit(\n", + " castle,\n", + " outcome='homicide_rate',\n", + " unit='state',\n", + " time='year',\n", + " first_treat='first_treat',\n", + " aggregate='all' # Compute all aggregations (simple, event_study, group)\n", + ")\n", + "\n", + "print(results_cs.summary())" + ] }, { "cell_type": "code", "execution_count": null, "id": "cell-23", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:13.877146Z", + "iopub.status.busy": "2026-01-18T18:12:13.877089Z", + "iopub.status.idle": "2026-01-18T18:12:13.878877Z", + "shell.execute_reply": "2026-01-18T18:12:13.878673Z" + } + }, "outputs": [], - "source": "# Aggregated Results\nprint(\"Aggregated Results\")\nprint(\"=\" * 60)\n\n# Overall ATT (simple aggregation is computed automatically)\nprint(f\"\\nOverall ATT: {results_cs.overall_att:.4f} (SE: {results_cs.overall_se:.4f})\")\nprint(f\"95% CI: [{results_cs.overall_conf_int[0]:.4f}, {results_cs.overall_conf_int[1]:.4f}]\")\n\n# By cohort (group_effects is populated when aggregate='group' or 'all')\nprint(\"\\nEffects by Adoption Cohort:\")\nfor cohort in sorted(results_cs.group_effects.keys()):\n eff = results_cs.group_effects[cohort]\n print(f\" Cohort {cohort}: {eff['effect']:>7.4f} (SE: {eff['se']:.4f})\")" + "source": [ + "# Aggregated Results\n", + "print(\"Aggregated Results\")\n", + "print(\"=\" * 60)\n", + "\n", + "# Overall ATT (simple aggregation is computed automatically)\n", + "print(f\"\\nOverall ATT: {results_cs.overall_att:.4f} (SE: {results_cs.overall_se:.4f})\")\n", + "print(f\"95% CI: [{results_cs.overall_conf_int[0]:.4f}, {results_cs.overall_conf_int[1]:.4f}]\")\n", + "\n", + "# By cohort (group_effects is populated when aggregate='group' or 'all')\n", + "print(\"\\nEffects by Adoption Cohort:\")\n", + "for cohort in sorted(results_cs.group_effects.keys()):\n", + " eff = results_cs.group_effects[cohort]\n", + " print(f\" Cohort {cohort}: {eff['effect']:>7.4f} (SE: {eff['se']:.4f})\")" + ] }, { "cell_type": "code", "execution_count": null, "id": "cell-24", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:13.879816Z", + "iopub.status.busy": "2026-01-18T18:12:13.879764Z", + "iopub.status.idle": "2026-01-18T18:12:13.881970Z", + "shell.execute_reply": "2026-01-18T18:12:13.881750Z" + } + }, "outputs": [], - "source": "# Event study aggregation (event_study_effects is populated when aggregate='event_study' or 'all')\nprint(\"Event Study Results (Effect by Years Since Adoption)\")\nprint(\"=\" * 60)\nprint(f\"{'Event Time':>12} {'ATT':>10} {'SE':>10} {'95% CI':>25}\")\nprint(\"-\" * 60)\n\nfor e in sorted(results_cs.event_study_effects.keys()):\n eff = results_cs.event_study_effects[e]\n ci = eff['conf_int']\n sig = '*' if eff['p_value'] < 0.05 else ''\n print(f\"{e:>12} {eff['effect']:>10.4f} {eff['se']:>10.4f} [{ci[0]:>8.4f}, {ci[1]:>8.4f}] {sig}\")" + "source": [ + "# Event study aggregation (event_study_effects is populated when aggregate='event_study' or 'all')\n", + "print(\"Event Study Results (Effect by Years Since Adoption)\")\n", + "print(\"=\" * 60)\n", + "print(f\"{'Event Time':>12} {'ATT':>10} {'SE':>10} {'95% CI':>25}\")\n", + "print(\"-\" * 60)\n", + "\n", + "for e in sorted(results_cs.event_study_effects.keys()):\n", + " eff = results_cs.event_study_effects[e]\n", + " ci = eff['conf_int']\n", + " sig = '*' if eff['p_value'] < 0.05 else ''\n", + " print(f\"{e:>12} {eff['effect']:>10.4f} {eff['se']:>10.4f} [{ci[0]:>8.4f}, {ci[1]:>8.4f}] {sig}\")" + ] }, { "cell_type": "code", "execution_count": null, "id": "cell-25", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:13.882893Z", + "iopub.status.busy": "2026-01-18T18:12:13.882839Z", + "iopub.status.idle": "2026-01-18T18:12:13.914914Z", + "shell.execute_reply": "2026-01-18T18:12:13.914698Z" + } + }, "outputs": [], "source": [ "# Event study visualization\n", @@ -517,7 +687,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-27", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:13.915986Z", + "iopub.status.busy": "2026-01-18T18:12:13.915913Z", + "iopub.status.idle": "2026-01-18T18:12:13.945161Z", + "shell.execute_reply": "2026-01-18T18:12:13.944977Z" + } + }, "outputs": [], "source": [ "# Sun-Abraham estimation\n", @@ -538,9 +715,25 @@ "cell_type": "code", "execution_count": null, "id": "cell-28", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:13.946119Z", + "iopub.status.busy": "2026-01-18T18:12:13.946062Z", + "iopub.status.idle": "2026-01-18T18:12:13.947615Z", + "shell.execute_reply": "2026-01-18T18:12:13.947414Z" + } + }, "outputs": [], - "source": "# Compare CS and SA\nprint(\"Robustness Check: CS vs Sun-Abraham\")\nprint(\"=\" * 60)\nprint(f\"{'Estimator':<25} {'Overall ATT':>15} {'SE':>10}\")\nprint(\"-\" * 60)\nprint(f\"{'Callaway-Sant\\'Anna':<25} {results_cs.overall_att:>15.4f} {results_cs.overall_se:>10.4f}\")\nprint(f\"{'Sun-Abraham':<25} {results_sa.overall_att:>15.4f} {results_sa.overall_se:>10.4f}\")\nprint(f\"{'TWFE (potentially biased)':<25} {results_twfe.att:>15.4f} {results_twfe.se:>10.4f}\")" + "source": [ + "# Compare CS and SA\n", + "print(\"Robustness Check: CS vs Sun-Abraham\")\n", + "print(\"=\" * 60)\n", + "print(f\"{'Estimator':<25} {'Overall ATT':>15} {'SE':>10}\")\n", + "print(\"-\" * 60)\n", + "print(f\"{'Callaway-Sant\\'Anna':<25} {results_cs.overall_att:>15.4f} {results_cs.overall_se:>10.4f}\")\n", + "print(f\"{'Sun-Abraham':<25} {results_sa.overall_att:>15.4f} {results_sa.overall_se:>10.4f}\")\n", + "print(f\"{'TWFE (potentially biased)':<25} {results_twfe.att:>15.4f} {results_twfe.se:>10.4f}\")" + ] }, { "cell_type": "markdown", @@ -569,7 +762,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-30", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:13.948464Z", + "iopub.status.busy": "2026-01-18T18:12:13.948408Z", + "iopub.status.idle": "2026-01-18T18:12:14.044098Z", + "shell.execute_reply": "2026-01-18T18:12:14.043827Z" + } + }, "outputs": [], "source": [ "# Load divorce laws dataset\n", @@ -585,7 +785,14 @@ "cell_type": "code", "execution_count": null, "id": "cell-31", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:14.045297Z", + "iopub.status.busy": "2026-01-18T18:12:14.045219Z", + "iopub.status.idle": "2026-01-18T18:12:14.048293Z", + "shell.execute_reply": "2026-01-18T18:12:14.048051Z" + } + }, "outputs": [], "source": [ "# Treatment timing distribution\n", @@ -606,23 +813,73 @@ "cell_type": "code", "execution_count": null, "id": "cell-32", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:14.049266Z", + "iopub.status.busy": "2026-01-18T18:12:14.049203Z", + "iopub.status.idle": "2026-01-18T18:12:14.078423Z", + "shell.execute_reply": "2026-01-18T18:12:14.078229Z" + } + }, "outputs": [], - "source": "# Callaway-Sant'Anna estimation\ncs_divorce = CallawaySantAnna(\n control_group='never_treated',\n n_bootstrap=199,\n seed=42\n)\n\nresults_divorce = cs_divorce.fit(\n divorce,\n outcome='divorce_rate',\n unit='state',\n time='year',\n first_treat='first_treat',\n aggregate='all' # Compute all aggregations (simple, event_study, group)\n)\n\nprint(results_divorce.summary())" + "source": [ + "# Callaway-Sant'Anna estimation\n", + "cs_divorce = CallawaySantAnna(\n", + " control_group='never_treated',\n", + " n_bootstrap=199,\n", + " seed=42\n", + ")\n", + "\n", + "results_divorce = cs_divorce.fit(\n", + " divorce,\n", + " outcome='divorce_rate',\n", + " unit='state',\n", + " time='year',\n", + " first_treat='first_treat',\n", + " aggregate='all' # Compute all aggregations (simple, event_study, group)\n", + ")\n", + "\n", + "print(results_divorce.summary())" + ] }, { "cell_type": "code", "execution_count": null, "id": "cell-33", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:14.079304Z", + "iopub.status.busy": "2026-01-18T18:12:14.079243Z", + "iopub.status.idle": "2026-01-18T18:12:14.081149Z", + "shell.execute_reply": "2026-01-18T18:12:14.080864Z" + } + }, "outputs": [], - "source": "# Event study results (event_study_effects is populated when aggregate='event_study' or 'all')\nprint(\"Event Study: Effect of Unilateral Divorce on Divorce Rates\")\nprint(\"=\" * 65)\nprint(f\"{'Years Since':>12} {'Effect':>10} {'SE':>10} {'Significant':>12}\")\nprint(\"-\" * 65)\n\nfor e in sorted(results_divorce.event_study_effects.keys()):\n eff = results_divorce.event_study_effects[e]\n sig = 'Yes' if eff['p_value'] < 0.05 else 'No'\n print(f\"{e:>12} {eff['effect']:>10.4f} {eff['se']:>10.4f} {sig:>12}\")" + "source": [ + "# Event study results (event_study_effects is populated when aggregate='event_study' or 'all')\n", + "print(\"Event Study: Effect of Unilateral Divorce on Divorce Rates\")\n", + "print(\"=\" * 65)\n", + "print(f\"{'Years Since':>12} {'Effect':>10} {'SE':>10} {'Significant':>12}\")\n", + "print(\"-\" * 65)\n", + "\n", + "for e in sorted(results_divorce.event_study_effects.keys()):\n", + " eff = results_divorce.event_study_effects[e]\n", + " sig = 'Yes' if eff['p_value'] < 0.05 else 'No'\n", + " print(f\"{e:>12} {eff['effect']:>10.4f} {eff['se']:>10.4f} {sig:>12}\")" + ] }, { "cell_type": "code", "execution_count": null, "id": "cell-34", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:14.082045Z", + "iopub.status.busy": "2026-01-18T18:12:14.081990Z", + "iopub.status.idle": "2026-01-18T18:12:14.128418Z", + "shell.execute_reply": "2026-01-18T18:12:14.128205Z" + } + }, "outputs": [], "source": [ "# Event study visualization\n", @@ -659,9 +916,25 @@ "cell_type": "code", "execution_count": null, "id": "cell-36", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-01-18T18:12:14.129425Z", + "iopub.status.busy": "2026-01-18T18:12:14.129370Z", + "iopub.status.idle": "2026-01-18T18:12:14.131060Z", + "shell.execute_reply": "2026-01-18T18:12:14.130880Z" + } + }, "outputs": [], - "source": "# Effects by cohort (group_effects is populated when aggregate='group' or 'all')\nprint(\"Effects by Adoption Cohort\")\nprint(\"=\" * 50)\n\nfor cohort in sorted(results_divorce.group_effects.keys()):\n eff = results_divorce.group_effects[cohort]\n sig = '*' if eff['p_value'] < 0.05 else ''\n print(f\"Cohort {cohort}: {eff['effect']:>7.4f} (SE: {eff['se']:.4f}) {sig}\")" + "source": [ + "# Effects by cohort (group_effects is populated when aggregate='group' or 'all')\n", + "print(\"Effects by Adoption Cohort\")\n", + "print(\"=\" * 50)\n", + "\n", + "for cohort in sorted(results_divorce.group_effects.keys()):\n", + " eff = results_divorce.group_effects[cohort]\n", + " sig = '*' if eff['p_value'] < 0.05 else ''\n", + " print(f\"Cohort {cohort}: {eff['effect']:>7.4f} (SE: {eff['se']:.4f}) {sig}\")" + ] }, { "cell_type": "markdown", @@ -723,10 +996,18 @@ "name": "python3" }, "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", "name": "python", - "version": "3.9" + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" } }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/docs/tutorials/10_trop.ipynb b/docs/tutorials/10_trop.ipynb index 21fa5b55..715d54a2 100644 --- a/docs/tutorials/10_trop.ipynb +++ b/docs/tutorials/10_trop.ipynb @@ -65,7 +65,105 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": "def generate_factor_dgp(\n n_units=50,\n n_pre=10,\n n_post=5,\n n_treated=10,\n n_factors=2,\n treatment_effect=2.0,\n factor_strength=1.0,\n noise_std=0.5,\n seed=42\n):\n \"\"\"\n Generate panel data with known factor structure.\n \n DGP: Y_it = mu + alpha_i + beta_t + L_it + tau*D_it + eps_it\n \n where L_it = Lambda_i'F_t is the interactive fixed effects component.\n \n This creates a scenario where standard DiD/SDID may be biased,\n but TROP should recover the true treatment effect.\n \n Returns DataFrame with columns:\n - 'treated': observation-level indicator (1 if treated AND post-period) - for TROP\n - 'treat': unit-level ever-treated indicator (1 for all periods if unit is treated) - for SDID\n \"\"\"\n rng = np.random.default_rng(seed)\n \n n_control = n_units - n_treated\n n_periods = n_pre + n_post\n \n # Generate factors F: (n_periods, n_factors)\n F = rng.normal(0, 1, (n_periods, n_factors))\n \n # Generate loadings Lambda: (n_factors, n_units)\n # Make treated units have correlated loadings (creates confounding)\n Lambda = rng.normal(0, 1, (n_factors, n_units))\n Lambda[:, :n_treated] += 0.5 # Treated units have higher loadings\n \n # Unit fixed effects\n alpha = rng.normal(0, 1, n_units)\n alpha[:n_treated] += 1.0 # Treated units have higher intercept\n \n # Time fixed effects\n beta = np.linspace(0, 2, n_periods)\n \n # Generate outcomes\n data = []\n for i in range(n_units):\n is_treated = i < n_treated\n \n for t in range(n_periods):\n post = t >= n_pre\n \n y = 10.0 + alpha[i] + beta[t]\n y += factor_strength * (Lambda[:, i] @ F[t, :]) # L_it component\n \n if is_treated and post:\n y += treatment_effect\n \n y += rng.normal(0, noise_std)\n \n data.append({\n 'unit': i,\n 'period': t,\n 'outcome': y,\n 'treated': int(is_treated and post), # Observation-level (for TROP)\n 'treat': int(is_treated) # Unit-level ever-treated (for SDID)\n })\n \n return pd.DataFrame(data)\n\n\n# Generate data with factor structure\ntrue_att = 2.0\nn_factors = 2\nn_pre = 10\nn_post = 5\n\ndf = generate_factor_dgp(\n n_units=50,\n n_pre=n_pre,\n n_post=n_post,\n n_treated=10,\n n_factors=n_factors,\n treatment_effect=true_att,\n factor_strength=1.5, # Strong factor confounding\n noise_std=0.5,\n seed=42\n)\n\nprint(f\"Dataset: {len(df)} observations\")\nprint(f\"Treated units: 10\")\nprint(f\"Control units: 40\")\nprint(f\"Pre-treatment periods: {n_pre}\")\nprint(f\"Post-treatment periods: {n_post}\")\nprint(f\"True treatment effect: {true_att}\")\nprint(f\"True number of factors: {n_factors}\")" + "source": [ + "def generate_factor_dgp(\n", + " n_units=50,\n", + " n_pre=10,\n", + " n_post=5,\n", + " n_treated=10,\n", + " n_factors=2,\n", + " treatment_effect=2.0,\n", + " factor_strength=1.0,\n", + " noise_std=0.5,\n", + " seed=42\n", + "):\n", + " \"\"\"\n", + " Generate panel data with known factor structure.\n", + " \n", + " DGP: Y_it = mu + alpha_i + beta_t + L_it + tau*D_it + eps_it\n", + " \n", + " where L_it = Lambda_i'F_t is the interactive fixed effects component.\n", + " \n", + " This creates a scenario where standard DiD/SDID may be biased,\n", + " but TROP should recover the true treatment effect.\n", + " \n", + " Returns DataFrame with columns:\n", + " - 'treated': observation-level indicator (1 if treated AND post-period) - for TROP\n", + " - 'treat': unit-level ever-treated indicator (1 for all periods if unit is treated) - for SDID\n", + " \"\"\"\n", + " rng = np.random.default_rng(seed)\n", + " \n", + " n_control = n_units - n_treated\n", + " n_periods = n_pre + n_post\n", + " \n", + " # Generate factors F: (n_periods, n_factors)\n", + " F = rng.normal(0, 1, (n_periods, n_factors))\n", + " \n", + " # Generate loadings Lambda: (n_factors, n_units)\n", + " # Make treated units have correlated loadings (creates confounding)\n", + " Lambda = rng.normal(0, 1, (n_factors, n_units))\n", + " Lambda[:, :n_treated] += 0.5 # Treated units have higher loadings\n", + " \n", + " # Unit fixed effects\n", + " alpha = rng.normal(0, 1, n_units)\n", + " alpha[:n_treated] += 1.0 # Treated units have higher intercept\n", + " \n", + " # Time fixed effects\n", + " beta = np.linspace(0, 2, n_periods)\n", + " \n", + " # Generate outcomes\n", + " data = []\n", + " for i in range(n_units):\n", + " is_treated = i < n_treated\n", + " \n", + " for t in range(n_periods):\n", + " post = t >= n_pre\n", + " \n", + " y = 10.0 + alpha[i] + beta[t]\n", + " y += factor_strength * (Lambda[:, i] @ F[t, :]) # L_it component\n", + " \n", + " if is_treated and post:\n", + " y += treatment_effect\n", + " \n", + " y += rng.normal(0, noise_std)\n", + " \n", + " data.append({\n", + " 'unit': i,\n", + " 'period': t,\n", + " 'outcome': y,\n", + " 'treated': int(is_treated and post), # Observation-level (for TROP)\n", + " 'treat': int(is_treated) # Unit-level ever-treated (for SDID)\n", + " })\n", + " \n", + " return pd.DataFrame(data)\n", + "\n", + "\n", + "# Generate data with factor structure (reduced size for faster execution)\n", + "true_att = 2.0\n", + "n_factors = 2\n", + "n_pre = 6 # Reduced from 10\n", + "n_post = 3 # Reduced from 5\n", + "\n", + "df = generate_factor_dgp(\n", + " n_units=30, # Reduced from 50\n", + " n_pre=n_pre,\n", + " n_post=n_post,\n", + " n_treated=6, # Reduced from 10\n", + " n_factors=n_factors,\n", + " treatment_effect=true_att,\n", + " factor_strength=1.5, # Strong factor confounding\n", + " noise_std=0.5,\n", + " seed=42\n", + ")\n", + "\n", + "print(f\"Dataset: {len(df)} observations\")\n", + "print(f\"Treated units: 6\")\n", + "print(f\"Control units: 24\")\n", + "print(f\"Pre-treatment periods: {n_pre}\")\n", + "print(f\"Post-treatment periods: {n_post}\")\n", + "print(f\"True treatment effect: {true_att}\")\n", + "print(f\"True number of factors: {n_factors}\")" + ] }, { "cell_type": "code", @@ -128,10 +226,10 @@ "source": [ "# Fit TROP with automatic tuning via LOOCV\n", "trop_est = TROP(\n", - " lambda_time_grid=[0.0, 0.5, 1.0, 2.0], # Time decay grid\n", - " lambda_unit_grid=[0.0, 0.5, 1.0, 2.0], # Unit distance grid \n", - " lambda_nn_grid=[0.0, 0.1, 1.0], # Nuclear norm grid\n", - " n_bootstrap=100, # Bootstrap replications for SE\n", + " lambda_time_grid=[0.0, 1.0], # Reduced time decay grid\n", + " lambda_unit_grid=[0.0, 1.0], # Reduced unit distance grid \n", + " lambda_nn_grid=[0.0, 0.1], # Reduced nuclear norm grid\n", + " n_bootstrap=50, # Reduced bootstrap replications for SE\n", " seed=42\n", ")\n", "\n", @@ -204,12 +302,12 @@ "print(f\"{'λ_nn':>10} {'ATT':>12} {'Bias':>12} {'Eff. Rank':>15}\")\n", "print(\"-\"*65)\n", "\n", - "for lambda_nn in [0.0, 0.1, 1.0, 10.0]:\n", + "for lambda_nn in [0.0, 0.1, 1.0]: # Reduced grid\n", " trop_fixed = TROP(\n", " lambda_time_grid=[1.0], # Fixed\n", " lambda_unit_grid=[1.0], # Fixed\n", " lambda_nn_grid=[lambda_nn], # Vary this\n", - " n_bootstrap=20,\n", + " n_bootstrap=20, # Reduced for faster execution\n", " seed=42\n", " )\n", " \n", @@ -360,7 +458,55 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": "# SDID (no factor adjustment)\n# Note: SDID uses 'treat' (unit-level ever-treated indicator)\nsdid = SyntheticDiD(\n n_bootstrap=100,\n seed=42\n)\n\nsdid_results = sdid.fit(\n df,\n outcome='outcome',\n treatment='treat', # Unit-level ever-treated indicator\n unit='unit',\n time='period',\n post_periods=post_periods\n)\n\n# TROP (with factor adjustment)\n# Note: TROP uses 'treated' (observation-level treatment indicator)\ntrop_est2 = TROP(\n lambda_nn_grid=[0.0, 0.1, 1.0], # Allow factor estimation\n n_bootstrap=100,\n seed=42\n)\n\ntrop_results = trop_est2.fit(\n df,\n outcome='outcome',\n treatment='treated', # Observation-level indicator\n unit='unit',\n time='period',\n post_periods=post_periods\n)\n\nprint(\"Comparison: SDID vs TROP\")\nprint(\"=\"*60)\nprint(f\"True ATT: {true_att:.4f}\")\nprint()\nprint(f\"Synthetic DiD (no factor adjustment):\")\nprint(f\" ATT: {sdid_results.att:.4f}\")\nprint(f\" SE: {sdid_results.se:.4f}\")\nprint(f\" Bias: {sdid_results.att - true_att:.4f}\")\nprint()\nprint(f\"TROP (with factor adjustment):\")\nprint(f\" ATT: {trop_results.att:.4f}\")\nprint(f\" SE: {trop_results.se:.4f}\")\nprint(f\" Bias: {trop_results.att - true_att:.4f}\")\nprint(f\" Effective rank: {trop_results.effective_rank:.2f}\")" + "source": [ + "# SDID (no factor adjustment)\n", + "# Note: SDID uses 'treat' (unit-level ever-treated indicator)\n", + "sdid = SyntheticDiD(\n", + " n_bootstrap=50, # Reduced for faster execution\n", + " seed=42\n", + ")\n", + "\n", + "sdid_results = sdid.fit(\n", + " df,\n", + " outcome='outcome',\n", + " treatment='treat', # Unit-level ever-treated indicator\n", + " unit='unit',\n", + " time='period',\n", + " post_periods=post_periods\n", + ")\n", + "\n", + "# TROP (with factor adjustment)\n", + "# Note: TROP uses 'treated' (observation-level treatment indicator)\n", + "trop_est2 = TROP(\n", + " lambda_nn_grid=[0.0, 0.1], # Reduced grid for faster execution\n", + " n_bootstrap=50, # Reduced for faster execution\n", + " seed=42\n", + ")\n", + "\n", + "trop_results = trop_est2.fit(\n", + " df,\n", + " outcome='outcome',\n", + " treatment='treated', # Observation-level indicator\n", + " unit='unit',\n", + " time='period',\n", + " post_periods=post_periods\n", + ")\n", + "\n", + "print(\"Comparison: SDID vs TROP\")\n", + "print(\"=\"*60)\n", + "print(f\"True ATT: {true_att:.4f}\")\n", + "print()\n", + "print(f\"Synthetic DiD (no factor adjustment):\")\n", + "print(f\" ATT: {sdid_results.att:.4f}\")\n", + "print(f\" SE: {sdid_results.se:.4f}\")\n", + "print(f\" Bias: {sdid_results.att - true_att:.4f}\")\n", + "print()\n", + "print(f\"TROP (with factor adjustment):\")\n", + "print(f\" ATT: {trop_results.att:.4f}\")\n", + "print(f\" SE: {trop_results.se:.4f}\")\n", + "print(f\" Bias: {trop_results.att - true_att:.4f}\")\n", + "print(f\" Effective rank: {trop_results.effective_rank:.2f}\")" + ] }, { "cell_type": "markdown", @@ -376,7 +522,81 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": "# Monte Carlo comparison\nn_sims = 20\ntrop_estimates = []\nsdid_estimates = []\n\nprint(f\"Running {n_sims} simulations...\")\n\nfor sim in range(n_sims):\n # Generate new data (includes both 'treated' and 'treat' columns)\n sim_data = generate_factor_dgp(\n n_units=50,\n n_pre=10,\n n_post=5,\n n_treated=10,\n n_factors=2,\n treatment_effect=2.0,\n factor_strength=1.5,\n noise_std=0.5,\n seed=100 + sim\n )\n \n # TROP (uses observation-level 'treated')\n try:\n trop_m = TROP(\n lambda_time_grid=[1.0],\n lambda_unit_grid=[1.0],\n lambda_nn_grid=[0.1],\n n_bootstrap=10, \n seed=42 + sim\n )\n trop_res = trop_m.fit(\n sim_data,\n outcome='outcome',\n treatment='treated',\n unit='unit',\n time='period',\n post_periods=list(range(10, 15))\n )\n trop_estimates.append(trop_res.att)\n except Exception as e:\n print(f\"TROP failed on sim {sim}: {e}\")\n \n # SDID (uses unit-level 'treat')\n try:\n sdid_m = SyntheticDiD(n_bootstrap=10, seed=42 + sim)\n sdid_res = sdid_m.fit(\n sim_data,\n outcome='outcome',\n treatment='treat', # Unit-level ever-treated indicator\n unit='unit',\n time='period',\n post_periods=list(range(10, 15))\n )\n sdid_estimates.append(sdid_res.att)\n except Exception as e:\n print(f\"SDID failed on sim {sim}: {e}\")\n\nprint(f\"\\nMonte Carlo Results (True ATT = {true_att})\")\nprint(\"=\"*60)\nprint(f\"{'Estimator':<15} {'Mean':>12} {'Bias':>12} {'RMSE':>12}\")\nprint(\"-\"*60)\n\nif trop_estimates:\n trop_mean = np.mean(trop_estimates)\n trop_bias = trop_mean - true_att\n trop_rmse = np.sqrt(np.mean([(e - true_att)**2 for e in trop_estimates]))\n print(f\"{'TROP':<15} {trop_mean:>12.4f} {trop_bias:>12.4f} {trop_rmse:>12.4f}\")\n\nif sdid_estimates:\n sdid_mean = np.mean(sdid_estimates)\n sdid_bias = sdid_mean - true_att\n sdid_rmse = np.sqrt(np.mean([(e - true_att)**2 for e in sdid_estimates]))\n print(f\"{'SDID':<15} {sdid_mean:>12.4f} {sdid_bias:>12.4f} {sdid_rmse:>12.4f}\")" + "source": [ + "# Monte Carlo comparison (reduced for faster tutorial execution)\n", + "n_sims = 5 # Reduced from 20 for faster validation\n", + "trop_estimates = []\n", + "sdid_estimates = []\n", + "\n", + "print(f\"Running {n_sims} simulations...\")\n", + "\n", + "for sim in range(n_sims):\n", + " # Generate new data (includes both 'treated' and 'treat' columns)\n", + " sim_data = generate_factor_dgp(\n", + " n_units=50,\n", + " n_pre=10,\n", + " n_post=5,\n", + " n_treated=10,\n", + " n_factors=2,\n", + " treatment_effect=2.0,\n", + " factor_strength=1.5,\n", + " noise_std=0.5,\n", + " seed=100 + sim\n", + " )\n", + " \n", + " # TROP (uses observation-level 'treated')\n", + " try:\n", + " trop_m = TROP(\n", + " lambda_time_grid=[1.0],\n", + " lambda_unit_grid=[1.0],\n", + " lambda_nn_grid=[0.1],\n", + " n_bootstrap=10, \n", + " seed=42 + sim\n", + " )\n", + " trop_res = trop_m.fit(\n", + " sim_data,\n", + " outcome='outcome',\n", + " treatment='treated',\n", + " unit='unit',\n", + " time='period',\n", + " post_periods=list(range(10, 15))\n", + " )\n", + " trop_estimates.append(trop_res.att)\n", + " except Exception as e:\n", + " print(f\"TROP failed on sim {sim}: {e}\")\n", + " \n", + " # SDID (uses unit-level 'treat')\n", + " try:\n", + " sdid_m = SyntheticDiD(n_bootstrap=10, seed=42 + sim)\n", + " sdid_res = sdid_m.fit(\n", + " sim_data,\n", + " outcome='outcome',\n", + " treatment='treat', # Unit-level ever-treated indicator\n", + " unit='unit',\n", + " time='period',\n", + " post_periods=list(range(10, 15))\n", + " )\n", + " sdid_estimates.append(sdid_res.att)\n", + " except Exception as e:\n", + " print(f\"SDID failed on sim {sim}: {e}\")\n", + "\n", + "print(f\"\\nMonte Carlo Results (True ATT = {true_att})\")\n", + "print(\"=\"*60)\n", + "print(f\"{'Estimator':<15} {'Mean':>12} {'Bias':>12} {'RMSE':>12}\")\n", + "print(\"-\"*60)\n", + "\n", + "if trop_estimates:\n", + " trop_mean = np.mean(trop_estimates)\n", + " trop_bias = trop_mean - true_att\n", + " trop_rmse = np.sqrt(np.mean([(e - true_att)**2 for e in trop_estimates]))\n", + " print(f\"{'TROP':<15} {trop_mean:>12.4f} {trop_bias:>12.4f} {trop_rmse:>12.4f}\")\n", + "\n", + "if sdid_estimates:\n", + " sdid_mean = np.mean(sdid_estimates)\n", + " sdid_bias = sdid_mean - true_att\n", + " sdid_rmse = np.sqrt(np.mean([(e - true_att)**2 for e in sdid_estimates]))\n", + " print(f\"{'SDID':<15} {sdid_mean:>12.4f} {sdid_bias:>12.4f} {sdid_rmse:>12.4f}\")" + ] }, { "cell_type": "code", @@ -423,7 +643,7 @@ " unit='unit',\n", " time='period',\n", " post_periods=post_periods,\n", - " n_bootstrap=50,\n", + " n_bootstrap=20, # Reduced for faster execution\n", " seed=42\n", ")\n", "\n", @@ -463,7 +683,7 @@ " lambda_unit_grid=[1.0], \n", " lambda_nn_grid=[0.1],\n", " variance_method=method,\n", - " n_bootstrap=100,\n", + " n_bootstrap=30, # Reduced for faster execution\n", " seed=42\n", " )\n", " \n", @@ -568,4 +788,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} \ No newline at end of file +} diff --git a/TROP-ref/2508.21536v2.pdf b/papers/2508.21536v2.pdf similarity index 100% rename from TROP-ref/2508.21536v2.pdf rename to papers/2508.21536v2.pdf diff --git a/tests/test_pretrends.py b/tests/test_pretrends.py index b2fd36ef..c7e8db36 100644 --- a/tests/test_pretrends.py +++ b/tests/test_pretrends.py @@ -811,3 +811,218 @@ def test_power_curve_has_plot_method(self, mock_multiperiod_results): assert hasattr(curve, 'plot') assert callable(curve.plot) + + +# ============================================================================= +# Tests for PreTrendsPowerResults.power_at() method +# ============================================================================= + + +class TestPreTrendsPowerResultsPowerAt: + """Tests for the power_at method on PreTrendsPowerResults.""" + + def test_power_at_basic(self, mock_multiperiod_results): + """Test basic power_at functionality.""" + pt = PreTrendsPower() + results = pt.fit(mock_multiperiod_results) + + # Compute power at different M values + power_1 = results.power_at(1.0) + power_2 = results.power_at(2.0) + power_5 = results.power_at(5.0) + + # Power should increase with M + assert power_1 < power_2 < power_5 + + # Power should be between 0 and 1 + assert 0 <= power_1 <= 1 + assert 0 <= power_2 <= 1 + assert 0 <= power_5 <= 1 + + def test_power_at_zero(self, mock_multiperiod_results): + """Test power_at with M=0 (should equal alpha).""" + pt = PreTrendsPower(alpha=0.05) + results = pt.fit(mock_multiperiod_results) + + power_0 = results.power_at(0.0) + + # At M=0, power should equal size (alpha) + assert np.isclose(power_0, 0.05, atol=0.01) + + def test_power_at_matches_fit(self, mock_multiperiod_results): + """Test that power_at gives same result as fitting with that M.""" + pt = PreTrendsPower() + + # Get results from fit + results1 = pt.fit(mock_multiperiod_results, M=2.0) + + # Get power from power_at method + results_base = pt.fit(mock_multiperiod_results) + power_from_method = results_base.power_at(2.0) + + # Should be the same (or very close) + assert np.isclose(results1.power, power_from_method, rtol=0.01) + + def test_power_at_linear_weights(self, mock_multiperiod_results): + """Test power_at uses correct linear weights.""" + pt = PreTrendsPower(violation_type="linear") + results = pt.fit(mock_multiperiod_results) + + # Power_at should work without error + power = results.power_at(1.0) + assert 0 <= power <= 1 + + def test_power_at_constant_weights(self, mock_multiperiod_results): + """Test power_at uses correct constant weights.""" + pt = PreTrendsPower(violation_type="constant") + results = pt.fit(mock_multiperiod_results) + + power = results.power_at(1.0) + assert 0 <= power <= 1 + + def test_power_at_last_period_weights(self, mock_multiperiod_results): + """Test power_at uses correct last_period weights.""" + pt = PreTrendsPower(violation_type="last_period") + results = pt.fit(mock_multiperiod_results) + + power = results.power_at(1.0) + assert 0 <= power <= 1 + + +# ============================================================================= +# Tests for pre_periods parameter +# ============================================================================= + + +class TestPrePeriodsParameter: + """Tests for the pre_periods parameter in fit and related methods.""" + + @pytest.fixture + def event_study_all_periods_results(self): + """Create results simulating all periods estimated as post_periods. + + This mimics the event study workflow where we estimate coefficients + for ALL periods (pre and post) to get pre-period placebo effects. + """ + # Periods 0-3 are pre-treatment, 4-7 are post + # But we estimate ALL periods as "post" to get coefficients + period_effects = {} + coefficients = {} + + # Pre-periods (0, 1, 2) - period 3 would be reference + for p in [0, 1, 2]: + period_effects[p] = PeriodEffect( + period=p, effect=np.random.normal(0, 0.1), se=0.5, + t_stat=0.2, p_value=0.84, conf_int=(-0.88, 1.08) + ) + coefficients[f'treated:period_{p}'] = period_effects[p].effect + + # Post-periods (4, 5, 6, 7) + for p in [4, 5, 6, 7]: + period_effects[p] = PeriodEffect( + period=p, effect=5.0 + np.random.normal(0, 0.1), se=0.5, + t_stat=10.0, p_value=0.0001, conf_int=(4.02, 5.98) + ) + coefficients[f'treated:period_{p}'] = period_effects[p].effect + + # In this scenario, pre_periods=[3] (only reference), post_periods=[0,1,2,4,5,6,7] + vcov = np.diag([0.25] * 7) + + return MultiPeriodDiDResults( + period_effects=period_effects, + avg_att=5.0, + avg_se=0.25, + avg_t_stat=20.0, + avg_p_value=0.0001, + avg_conf_int=(4.51, 5.49), + n_obs=800, + n_treated=400, + n_control=400, + pre_periods=[3], # Only reference period + post_periods=[0, 1, 2, 4, 5, 6, 7], # All estimated periods + vcov=vcov, + coefficients=coefficients, + ) + + def test_fit_with_explicit_pre_periods(self, event_study_all_periods_results): + """Test fit() with explicit pre_periods parameter.""" + pt = PreTrendsPower() + + # Without pre_periods, would fail because results.pre_periods=[3] + # and period 3 has no coefficient (it's the reference) + # With explicit pre_periods=[0,1,2], should work + results = pt.fit( + event_study_all_periods_results, + pre_periods=[0, 1, 2] + ) + + assert results.n_pre_periods == 3 + assert results.power >= 0 + assert results.mdv > 0 + + def test_pre_periods_overrides_results(self, event_study_all_periods_results): + """Test that pre_periods parameter overrides results.pre_periods.""" + pt = PreTrendsPower() + + # Explicitly set pre_periods to [0, 1] + results = pt.fit( + event_study_all_periods_results, + pre_periods=[0, 1] + ) + + # Should use 2 pre-periods, not what's in results + assert results.n_pre_periods == 2 + + def test_power_at_with_pre_periods(self, event_study_all_periods_results): + """Test power_at() method with pre_periods parameter.""" + pt = PreTrendsPower() + + power = pt.power_at( + event_study_all_periods_results, + M=1.0, + pre_periods=[0, 1, 2] + ) + + assert 0 <= power <= 1 + + def test_power_curve_with_pre_periods(self, event_study_all_periods_results): + """Test power_curve() with pre_periods parameter.""" + pt = PreTrendsPower() + + curve = pt.power_curve( + event_study_all_periods_results, + n_points=10, + pre_periods=[0, 1, 2] + ) + + assert len(curve.M_values) == 10 + assert len(curve.powers) == 10 + + def test_sensitivity_to_honest_did_with_pre_periods(self, event_study_all_periods_results): + """Test sensitivity_to_honest_did() with pre_periods parameter.""" + pt = PreTrendsPower() + + sensitivity = pt.sensitivity_to_honest_did( + event_study_all_periods_results, + pre_periods=[0, 1, 2] + ) + + assert 'mdv' in sensitivity + assert sensitivity['mdv'] > 0 + + def test_convenience_functions_with_pre_periods(self, event_study_all_periods_results): + """Test convenience functions with pre_periods parameter.""" + # compute_mdv + mdv = compute_mdv( + event_study_all_periods_results, + pre_periods=[0, 1, 2] + ) + assert mdv > 0 + + # compute_pretrends_power + results = compute_pretrends_power( + event_study_all_periods_results, + M=1.0, + pre_periods=[0, 1, 2] + ) + assert results.n_pre_periods == 3 diff --git a/tests/test_staggered.py b/tests/test_staggered.py index 837d0cc2..06caf79e 100644 --- a/tests/test_staggered.py +++ b/tests/test_staggered.py @@ -1555,3 +1555,248 @@ def test_event_study_analytical_se(self): f"Event study SE at e={e}: analytical={se_analytical:.4f}, " f"bootstrap={se_bootstrap:.4f}, diff={rel_diff:.1%}" ) + + +class TestCallawaySantAnnaNonStandardColumnNames: + """Tests for CallawaySantAnna with non-standard column names. + + These tests verify that the estimator works correctly when column names + differ from the default names (outcome, unit, time, first_treat). + """ + + def generate_data_with_custom_names( + self, + outcome_name: str = 'y', + unit_name: str = 'id', + time_name: str = 'period', + first_treat_name: str = 'treatment_start', + n_units: int = 100, + n_periods: int = 10, + seed: int = 42, + ) -> pd.DataFrame: + """Generate staggered data with custom column names.""" + np.random.seed(seed) + + # Generate standard data + units = np.repeat(np.arange(n_units), n_periods) + times = np.tile(np.arange(n_periods), n_units) + + # 30% never-treated, rest treated at period 4 or 6 + n_never = int(n_units * 0.3) + first_treat = np.zeros(n_units) + first_treat[n_never:n_never + (n_units - n_never) // 2] = 4 + first_treat[n_never + (n_units - n_never) // 2:] = 6 + first_treat_expanded = np.repeat(first_treat, n_periods) + + # Generate outcomes + unit_fe = np.repeat(np.random.randn(n_units) * 2, n_periods) + time_fe = np.tile(np.linspace(0, 1, n_periods), n_units) + post = (times >= first_treat_expanded) & (first_treat_expanded > 0) + outcomes = unit_fe + time_fe + 2.5 * post + np.random.randn(len(units)) * 0.5 + + return pd.DataFrame({ + outcome_name: outcomes, + unit_name: units, + time_name: times, + first_treat_name: first_treat_expanded.astype(int), + }) + + def test_non_standard_first_treat_name(self): + """Test with non-standard first_treat column name.""" + data = self.generate_data_with_custom_names( + first_treat_name='treatment_cohort' + ) + + cs = CallawaySantAnna() + results = cs.fit( + data, + outcome='y', + unit='id', + time='period', + first_treat='treatment_cohort' + ) + + assert results.overall_att is not None + assert np.isfinite(results.overall_att) + assert results.overall_se > 0 + # Treatment effect should be approximately 2.5 + assert abs(results.overall_att - 2.5) < 1.5 + + def test_non_standard_all_column_names(self): + """Test with all non-standard column names.""" + data = self.generate_data_with_custom_names( + outcome_name='response_var', + unit_name='entity_id', + time_name='time_period', + first_treat_name='treatment_timing', + ) + + cs = CallawaySantAnna() + results = cs.fit( + data, + outcome='response_var', + unit='entity_id', + time='time_period', + first_treat='treatment_timing' + ) + + assert results.overall_att is not None + assert np.isfinite(results.overall_att) + assert results.overall_se > 0 + + def test_non_standard_names_with_bootstrap(self): + """Test non-standard column names with bootstrap inference.""" + data = self.generate_data_with_custom_names( + first_treat_name='g', # Short name like R's `did` package uses + n_units=50 + ) + + cs = CallawaySantAnna(n_bootstrap=99, seed=42) + results = cs.fit( + data, + outcome='y', + unit='id', + time='period', + first_treat='g' + ) + + assert results.bootstrap_results is not None + assert results.overall_se > 0 + assert results.overall_conf_int[0] < results.overall_att < results.overall_conf_int[1] + + def test_non_standard_names_with_event_study(self): + """Test non-standard column names with event study aggregation.""" + data = self.generate_data_with_custom_names( + first_treat_name='cohort', + n_periods=12 + ) + + cs = CallawaySantAnna() + results = cs.fit( + data, + outcome='y', + unit='id', + time='period', + first_treat='cohort', + aggregate='event_study' + ) + + assert results.event_study_effects is not None + assert len(results.event_study_effects) > 0 + + def test_non_standard_names_with_covariates(self): + """Test non-standard column names with covariate adjustment.""" + # Generate data with covariates + data = self.generate_data_with_custom_names( + first_treat_name='treatment_time' + ) + # Add covariates with custom names + data['covariate_x'] = np.random.randn(len(data)) + data['covariate_z'] = np.random.binomial(1, 0.5, len(data)) + + cs = CallawaySantAnna(estimation_method='dr') + results = cs.fit( + data, + outcome='y', + unit='id', + time='period', + first_treat='treatment_time', + covariates=['covariate_x', 'covariate_z'] + ) + + assert results.overall_att is not None + assert results.overall_se > 0 + + def test_non_standard_names_with_not_yet_treated(self): + """Test non-standard column names with not_yet_treated control group.""" + data = self.generate_data_with_custom_names( + first_treat_name='adoption_period' + ) + + cs = CallawaySantAnna(control_group='not_yet_treated') + results = cs.fit( + data, + outcome='y', + unit='id', + time='period', + first_treat='adoption_period' + ) + + assert results.overall_att is not None + assert results.control_group == 'not_yet_treated' + + def test_non_standard_names_matches_standard_names(self): + """Verify results are identical regardless of column naming.""" + np.random.seed(42) + + # Generate identical data with different column names + data_standard = generate_staggered_data(n_units=80, seed=42) + + data_custom = data_standard.rename(columns={ + 'outcome': 'y', + 'unit': 'entity', + 'time': 't', + 'first_treat': 'g', + }) + + # Fit with standard names + cs1 = CallawaySantAnna(seed=123) + results1 = cs1.fit( + data_standard, + outcome='outcome', + unit='unit', + time='time', + first_treat='first_treat' + ) + + # Fit with custom names + cs2 = CallawaySantAnna(seed=123) + results2 = cs2.fit( + data_custom, + outcome='y', + unit='entity', + time='t', + first_treat='g' + ) + + # Results should be identical + assert abs(results1.overall_att - results2.overall_att) < 1e-10 + assert abs(results1.overall_se - results2.overall_se) < 1e-10 + + def test_column_name_with_spaces(self): + """Test column names containing spaces.""" + data = self.generate_data_with_custom_names() + data = data.rename(columns={ + 'y': 'outcome variable', + 'treatment_start': 'treatment period', + }) + + cs = CallawaySantAnna() + results = cs.fit( + data, + outcome='outcome variable', + unit='id', + time='period', + first_treat='treatment period' + ) + + assert results.overall_att is not None + assert results.overall_se > 0 + + def test_column_name_with_special_characters(self): + """Test column names with underscores and numbers.""" + data = self.generate_data_with_custom_names() + data = data.rename(columns={ + 'treatment_start': 'first_treat_2024', + }) + + cs = CallawaySantAnna() + results = cs.fit( + data, + outcome='y', + unit='id', + time='period', + first_treat='first_treat_2024' + ) + + assert results.overall_att is not None