63 changes: 63 additions & 0 deletions .github/workflows/notebooks.yml
@@ -0,0 +1,63 @@
name: Tutorial Notebooks

on:
  push:
    branches: [main]
    paths:
      - 'docs/tutorials/**'
      - 'diff_diff/**'
      - 'pyproject.toml'
      - '.github/workflows/notebooks.yml'
  pull_request:
    branches: [main]
    paths:
      - 'docs/tutorials/**'
      - 'diff_diff/**'
      - 'pyproject.toml'
      - '.github/workflows/notebooks.yml'
  schedule:
    # Weekly Sunday 6am UTC — smoke test that notebooks still execute cleanly
    - cron: '0 6 * * 0'

jobs:
  execute-notebooks:
    name: Execute tutorial notebooks
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Install dependencies
        # Keep in sync with pyproject.toml [project.dependencies] and [project.optional-dependencies.dev]
        run: |
          pip install numpy pandas scipy matplotlib nbmake pytest ipykernel
          # Add repo root to Python path so Jupyter kernels can import diff_diff
          # (pip install -e . requires the Rust/maturin toolchain; .pth avoids that)
          python -c "import site; print(site.getsitepackages()[0])" | xargs -I{} sh -c 'echo "$PWD" > {}/diff_diff_dev.pth'

      - name: Execute notebooks
        env:
          DIFF_DIFF_BACKEND: python
        run: |
          pytest --nbmake docs/tutorials/ \
            --nbmake-timeout=600 \
            --ignore=docs/tutorials/06_power_analysis.ipynb \
            --ignore=docs/tutorials/10_trop.ipynb \
            -v \
            --tb=short
          # Excluded notebooks (too slow for pure-Python CI without Rust backend):
          # 06_power_analysis — SyntheticDiD simulate_power Monte Carlo (>600s)
          # 10_trop — LOOCV grid search (>600s)

      - name: Upload failed notebook outputs
        if: failure()
        uses: actions/upload-artifact@v4
        with:
          name: failed-notebook-outputs
          path: docs/tutorials/*.ipynb
          retention-days: 7
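The `.pth` trick in the install step leans on a standard Python mechanism: each line of a `.pth` file inside a site directory is appended to `sys.path`. A minimal sketch of the mechanism, using a throwaway directory instead of the real site-packages:

```python
import site
import sys
import tempfile
from pathlib import Path

# Each line of a .pth file placed in a site directory is appended to
# sys.path; this is what lets Jupyter kernels import diff_diff from the
# checkout without `pip install -e .`. Demonstrated with a throwaway dir.
with tempfile.TemporaryDirectory() as d:
    site_dir = Path(d) / "fake_site_packages"
    repo_root = Path(d) / "repo_checkout"
    site_dir.mkdir()
    repo_root.mkdir()
    (site_dir / "diff_diff_dev.pth").write_text(f"{repo_root}\n")
    site.addsitedir(str(site_dir))  # processes .pth files in that directory
    print(str(repo_root) in sys.path)  # True
```

Note that `.pth` entries are only added to `sys.path` if the listed directory exists, which is why the workflow writes `$PWD` of the checked-out repo.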
29 changes: 8 additions & 21 deletions TODO.md
@@ -65,11 +65,9 @@ Deferred items from PR reviews that were not addressed before merge.

| Issue | Location | PR | Priority |
|-------|----------|----|----------|
| Tutorial notebooks not executed in CI | `docs/tutorials/*.ipynb` | #159 | Low |
| R comparison tests spawn separate `Rscript` per test (slow CI) | `tests/test_methodology_twfe.py:294` | #139 | Low |
| CS R helpers hard-code `xformla = ~ 1`; no covariate-adjusted R benchmark for IRLS path | `tests/test_methodology_callaway.py` | #202 | Low |
| ~~Context-dependent doc snippets pass via blanket NameError~~ | `tests/test_doc_snippets.py` | #206 | ~~Low~~ — resolved: allow-list replaces blanket catch |
| ~1,460 `duplicate object description` Sphinx warnings — each class attribute is documented in both module API pages and autosummary stubs; fix by adding `:no-index:` to one location or restructuring API docs to avoid overlap | `docs/api/*.rst`, `docs/api/_autosummary/` | — | Low |
| ~376 `duplicate object description` Sphinx warnings — caused by autodoc `:members:` on dataclass attributes within manual API pages (not from autosummary stubs); fix requires restructuring `docs/api/*.rst` pages to avoid documenting the same attribute via both `:members:` and inline `autosummary` tables | `docs/api/*.rst` | — | Low |

---

@@ -88,29 +86,15 @@ Different estimators compute SEs differently. Consider unified interface.

### Type Annotations

Mypy reports 9 errors (down from 81 before spring cleanup). All remaining are
mixin `attr-defined` errors — methods accessed via `self` that live on the
concrete class, not the mixin. Fixing these requires Protocol classes, which is
low priority.

| Category | Count | Notes |
|----------|-------|-------|
| attr-defined (mixin methods) | 9 | Structural — requires Protocol refactor |

**Resolved in spring cleanup:**
- [x] `@overload` on `solve_ols` / `_solve_ols_numpy` — eliminated all unpacking mismatches
- [x] `assert X is not None` guards — eliminated all Optional indexing errors
- [x] Mixin scalar attribute stubs — eliminated 26 mixin attr-defined errors
- [x] Matplotlib `tab10` lookup fix
Mypy reports 0 errors. All mixin `attr-defined` errors resolved via
`TYPE_CHECKING`-guarded method stubs in bootstrap mixin classes.
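As a minimal sketch of that stub pattern (hypothetical class and method names, not the real mixins):

```python
from typing import TYPE_CHECKING, List

class SomeBootstrapMixin:
    # Stubs under TYPE_CHECKING exist only for the type checker: they tell
    # mypy that `self._fit_once` is provided by the concrete class, so the
    # mixin type-checks without a Protocol refactor.
    if TYPE_CHECKING:
        def _fit_once(self, draw: int) -> float: ...

    def run_bootstrap(self, n_draws: int) -> List[float]:
        # At runtime the stub does not exist; the subclass supplies the method.
        return [self._fit_once(i) for i in range(n_draws)]

class ConcreteEstimator(SomeBootstrapMixin):
    def _fit_once(self, draw: int) -> float:
        return draw * 0.5

print(ConcreteEstimator().run_bootstrap(3))  # [0.0, 0.5, 1.0]
```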

## Deprecated Code

Deprecated parameters still present for backward compatibility:

- [x] `bootstrap_weight_type` in `CallawaySantAnna` (`staggered.py`)
- `bootstrap_weight_type` in `CallawaySantAnna` (`staggered.py`)
- Deprecated in favor of `bootstrap_weights` parameter
- ✅ Deprecation warning updated to say "removed in v3.0"
- ✅ README.md and tutorial 02 updated to use `bootstrap_weights`
- Remove in next major version (v3.0)
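A minimal sketch of the forwarding shim this describes (illustrative free function, not the actual `CallawaySantAnna` signature):

```python
import warnings
from typing import Optional

def fit(bootstrap_weights: str = "rademacher",
        bootstrap_weight_type: Optional[str] = None) -> str:
    # Deprecated alias: accept the old keyword, warn, forward to the new one.
    if bootstrap_weight_type is not None:
        warnings.warn(
            "bootstrap_weight_type is deprecated and will be removed in "
            "v3.0; use bootstrap_weights instead",
            DeprecationWarning,
            stacklevel=2,
        )
        bootstrap_weights = bootstrap_weight_type
    return bootstrap_weights

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert fit(bootstrap_weight_type="mammen") == "mammen"
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```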

---
@@ -126,7 +110,10 @@ Deprecated parameters still present for backward compatibility:
Enhancements for `honest_did.py`:

- [ ] Improved C-LF implementation with direct optimization instead of grid search
- [ ] Support for CallawaySantAnnaResults (currently only MultiPeriodDiDResults)
(current implementation uses simplified FLCI approach with estimation uncertainty
adjustment; see `honest_did.py:947`)
- [x] Support for CallawaySantAnnaResults (implemented in `honest_did.py:612-653`;
requires `aggregate='event_study'` when calling `CallawaySantAnna.fit()`)
- [ ] Event-study-specific bounds for each post-period
- [ ] Hybrid inference methods
- [ ] Simulation-based power analysis for honest bounds
38 changes: 37 additions & 1 deletion diff_diff/imputation_bootstrap.py
@@ -6,7 +6,7 @@
"""

import warnings
from typing import Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple

import numpy as np
import pandas as pd
@@ -68,6 +68,42 @@ class ImputationDiDBootstrapMixin:
    anticipation: int
    horizon_max: Optional[int]

    if TYPE_CHECKING:

        def _compute_cluster_psi_sums(
            self,
            df: pd.DataFrame,
            outcome: str,
            unit: str,
            time: str,
            first_treat: str,
            covariates: Optional[List[str]],
            omega_0_mask: pd.Series,
            omega_1_mask: pd.Series,
            unit_fe: Dict[Any, float],
            time_fe: Dict[Any, float],
            grand_mean: float,
            delta_hat: Optional[np.ndarray],
            weights: np.ndarray,
            cluster_var: str,
            kept_cov_mask: Optional[np.ndarray] = None,
        ) -> Tuple[np.ndarray, np.ndarray]: ...

        @staticmethod
        def _build_cohort_rel_times(
            df: pd.DataFrame,
            first_treat: str,
        ) -> Dict[Any, Set[int]]: ...

        @staticmethod
        def _compute_balanced_cohort_mask(
            df_treated: pd.DataFrame,
            first_treat: str,
            all_horizons: List[int],
            balance_e: int,
            cohort_rel_times: Dict[Any, Set[int]],
        ) -> np.ndarray: ...

    def _precompute_bootstrap_psi(
        self,
        df: pd.DataFrame,
20 changes: 19 additions & 1 deletion diff_diff/staggered_bootstrap.py
@@ -29,7 +29,9 @@
)

if TYPE_CHECKING:
    pass
    import pandas as pd

    from diff_diff.staggered_aggregation import PrecomputedData


# =============================================================================
@@ -117,6 +119,22 @@ class CallawaySantAnnaBootstrapMixin:
    seed: Optional[int]
    anticipation: int

    if TYPE_CHECKING:

        def _compute_combined_influence_function(
            self,
            gt_pairs: List[Tuple[Any, Any]],
            weights: np.ndarray,
            effects: np.ndarray,
            groups_for_gt: np.ndarray,
            influence_func_info: Dict,
            df: "pd.DataFrame",
            unit: str,
            precomputed: Optional["PrecomputedData"] = None,
            global_unit_to_idx: Optional[Dict[Any, int]] = None,
            n_global_units: Optional[int] = None,
        ) -> Tuple[np.ndarray, Optional[List]]: ...

    def _run_multiplier_bootstrap(
        self,
        group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]],
23 changes: 22 additions & 1 deletion diff_diff/two_stage_bootstrap.py
@@ -7,7 +7,7 @@
"""

import warnings
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple

import numpy as np
import pandas as pd
@@ -41,6 +41,27 @@ class TwoStageDiDBootstrapMixin:
    seed: Optional[int]
    horizon_max: Optional[int]

    if TYPE_CHECKING:
        from scipy import sparse

        def _build_fe_design(
            self,
            df: pd.DataFrame,
            unit: str,
            time: str,
            covariates: Optional[List[str]],
            omega_0_mask: pd.Series,
        ) -> Tuple[
            "sparse.csr_matrix", "sparse.csr_matrix", Dict[Any, int], Dict[Any, int]
        ]: ...

        @staticmethod
        def _compute_gmm_scores(
            c_by_cluster: np.ndarray,
            gamma_hat: np.ndarray,
            s2_by_cluster: np.ndarray,
        ) -> np.ndarray: ...

    def _compute_cluster_S_scores(
        self,
        df: pd.DataFrame,
85 changes: 4 additions & 81 deletions docs/tutorials/07_pretrends_power.ipynb
@@ -131,28 +131,7 @@
"id": "cell-6",
"metadata": {},
"outputs": [],
"source": [
"# Fit event study with ALL periods (pre and post) relative to reference period\n",
"# For pre-trends power analysis, we need coefficients for pre-periods too\n",
"mp_did = MultiPeriodDiD()\n",
"\n",
"# Use period 4 as the reference period (last pre-period, excluded from estimation)\n",
"# Estimate coefficients for all other periods: 0, 1, 2, 3 (pre) and 5, 6, 7, 8, 9 (post)\n",
"all_estimation_periods = [0, 1, 2, 3, 5, 6, 7, 8, 9] # All except reference period 4\n",
"\n",
"event_results = mp_did.fit(\n",
" df,\n",
" outcome='outcome',\n",
" treatment='treated',\n",
" time='period',\n",
" post_periods=all_estimation_periods # Include all periods for full event study\n",
")\n",
"\n",
"# Note: For standard DiD analysis, we'd normally use post_periods=[5,6,7,8,9]\n",
"# But for pre-trends power analysis, we need pre-period coefficients too\n",
"\n",
"print(event_results.summary())"
]
"source": "# Fit event study with ALL periods (pre and post) relative to reference period\n# For pre-trends power analysis, we need coefficients for pre-periods too\nmp_did = MultiPeriodDiD()\n\n# Use period 4 as the reference period (last pre-period, excluded from estimation)\n# Specify post_periods as the actual post-treatment periods; MultiPeriodDiD\n# automatically estimates pre-period coefficients for the event study.\nevent_results = mp_did.fit(\n df,\n outcome='outcome',\n treatment='treated',\n time='period',\n post_periods=[5, 6, 7, 8, 9]\n)\n\nprint(event_results.summary())"
},
{
"cell_type": "code",
@@ -199,24 +178,7 @@
"id": "cell-10",
"metadata": {},
"outputs": [],
"source": [
"# Create a PreTrendsPower object\n",
"pt = PreTrendsPower(\n",
" alpha=0.05, # Significance level for pre-trends test\n",
" power=0.80, # Target power for MDV calculation\n",
" violation_type='linear' # Type of violation to consider\n",
")\n",
"\n",
"# Define the actual pre-treatment periods (those before treatment starts at period 5)\n",
"# These are the periods we want to analyze for pre-trends power\n",
"pre_treatment_periods = [0, 1, 2, 3]\n",
"\n",
"# Fit to the event study results, specifying which periods are pre-treatment\n",
"# This is needed because we estimated all periods as post_periods in the event study\n",
"pt_results = pt.fit(event_results, pre_periods=pre_treatment_periods)\n",
"\n",
"print(pt_results.summary())"
]
"source": "# Create a PreTrendsPower object\npt = PreTrendsPower(\n alpha=0.05, # Significance level for pre-trends test\n power=0.80, # Target power for MDV calculation\n violation_type='linear' # Type of violation to consider\n)\n\n# Define the actual pre-treatment periods (those before treatment starts at period 5)\n# These are the periods we want to analyze for pre-trends power\npre_treatment_periods = [0, 1, 2, 3]\n\n# Fit to the event study results, specifying which periods are pre-treatment\npt_results = pt.fit(event_results, pre_periods=pre_treatment_periods)\n\nprint(pt_results.summary())"
},
{
"cell_type": "markdown",
@@ -558,46 +520,7 @@
"id": "cell-30",
"metadata": {},
"outputs": [],
"source": [
"# Typical workflow for pre-trends power analysis\n",
"\n",
"# Step 1: Estimate event study with ALL periods (pre and post) relative to reference\n",
"# For pre-trends power analysis, we need pre-period coefficients\n",
"mp_did = MultiPeriodDiD()\n",
"\n",
"# Reference period is 4 (last pre-period)\n",
"# Estimate coefficients for periods 0, 1, 2, 3 (pre) and 5, 6, 7, 8, 9 (post)\n",
"all_estimation_periods = [0, 1, 2, 3, 5, 6, 7, 8, 9]\n",
"pre_treatment_periods = [0, 1, 2, 3] # Define which are pre-treatment\n",
"\n",
"results = mp_did.fit(\n",
" df, \n",
" outcome='outcome',\n",
" treatment='treated', \n",
" time='period',\n",
" post_periods=all_estimation_periods\n",
")\n",
"\n",
"# Step 2: Assess power of the pre-trends test \n",
"print(\"Step 2: Pre-Trends Power Analysis\")\n",
"pt = PreTrendsPower(alpha=0.05, power=0.80, violation_type='linear')\n",
"pt_results = pt.fit(results, pre_periods=pre_treatment_periods)\n",
"print(f\"MDV (80% power): {pt_results.mdv:.3f}\")\n",
"print(\"\")\n",
"\n",
"# Step 3: Interpret\n",
"print(\"Step 3: Interpretation\")\n",
"print(f\"Your pre-trends test could only detect violations >= {pt_results.mdv:.3f}\")\n",
"print(f\"Violations smaller than this would likely go undetected.\")\n",
"print(\"\")\n",
"\n",
"# Step 4: Connect to Honest DiD for robust inference\n",
"print(\"Step 4: Robust Inference with Honest DiD\")\n",
"honest = HonestDiD(method='smoothness', M=pt_results.mdv)\n",
"honest_results = honest.fit(results)\n",
"print(f\"Robust 95% CI (M=MDV): [{honest_results.ci_lb:.3f}, {honest_results.ci_ub:.3f}]\")\n",
"print(f\"Conclusion: {'Effect is robust' if honest_results.is_significant else 'Effect may not be robust'}\")"
]
"source": "# Typical workflow for pre-trends power analysis\n\n# Step 1: Estimate event study with proper pre/post period classification\nmp_did = MultiPeriodDiD()\n\n# Specify actual post-treatment periods; pre-period coefficients are\n# estimated automatically by MultiPeriodDiD for the event study\npre_treatment_periods = [0, 1, 2, 3] # Define which are pre-treatment\n\nresults = mp_did.fit(\n df, \n outcome='outcome',\n treatment='treated', \n time='period',\n post_periods=[5, 6, 7, 8, 9]\n)\n\n# Step 2: Assess power of the pre-trends test \nprint(\"Step 2: Pre-Trends Power Analysis\")\npt = PreTrendsPower(alpha=0.05, power=0.80, violation_type='linear')\npt_results = pt.fit(results, pre_periods=pre_treatment_periods)\nprint(f\"MDV (80% power): {pt_results.mdv:.3f}\")\nprint(\"\")\n\n# Step 3: Interpret\nprint(\"Step 3: Interpretation\")\nprint(f\"Your pre-trends test could only detect violations >= {pt_results.mdv:.3f}\")\nprint(f\"Violations smaller than this would likely go undetected.\")\nprint(\"\")\n\n# Step 4: Connect to Honest DiD for robust inference\nprint(\"Step 4: Robust Inference with Honest DiD\")\nhonest = HonestDiD(method='smoothness', M=pt_results.mdv)\nhonest_results = honest.fit(results)\nprint(f\"Robust 95% CI (M=MDV): [{honest_results.ci_lb:.3f}, {honest_results.ci_ub:.3f}]\")\nprint(f\"Conclusion: {'Effect is robust' if honest_results.is_significant else 'Effect may not be robust'}\")"
},
{
"cell_type": "markdown",
@@ -693,4 +616,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
1 change: 1 addition & 0 deletions pyproject.toml
@@ -59,6 +59,7 @@ dev = [
    "mypy>=1.0",
    "maturin>=1.4,<2.0",
    "matplotlib>=3.5",
    "nbmake>=1.5",
]
docs = [
    "sphinx>=6.0",