Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 87 additions & 20 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,92 @@ For the public feature roadmap, see [ROADMAP.md](ROADMAP.md).

---

## Priority Items for 1.0.1

### Linter/Type Errors (Blocking) - COMPLETED

| Issue | Location | Status |
|-------|----------|--------|
| ~~Unused import `Union`~~ | `power.py:25` | Fixed |
| ~~Unsorted imports~~ | `staggered.py:8` | Fixed |
| ~~10 mypy errors - Optional type handling~~ | `staggered.py:843-1631` | Fixed |

### Quick Wins - COMPLETED

- [x] Fix ruff errors (2 auto-fixable)
- [x] Fix mypy errors in staggered.py (Optional dict access needs guards)
- [x] Remove duplicate `_get_significance_stars()` from `diagnostics.py` (now imports from `results.py`)

---

## Known Limitations

| Issue | Location | Priority | Notes |
|-------|----------|----------|-------|
| MultiPeriodDiD wild bootstrap not supported | `estimators.py:944-951` | Low | Edge case |
| MultiPeriodDiD wild bootstrap not supported | `estimators.py:1068-1074` | Low | Edge case |
| `predict()` raises NotImplementedError | `estimators.py:532-554` | Low | Rarely needed |
| SyntheticDiD bootstrap can fail silently | `estimators.py:1580-1654` | Medium | Needs error handling |
| Diagnostics module error handling | `diagnostics.py:782-885` | Medium | Improve robustness |

---

## Code Quality Issues

### Bare Exception Handling - COMPLETED

~~Replace broad `except Exception` with specific exceptions:~~

| Location | Status |
|----------|--------|
| ~~`diagnostics.py:624`~~ | Fixed - catches `ValueError`, `KeyError`, `LinAlgError` |
| ~~`diagnostics.py:735`~~ | Fixed - catches `ValueError`, `KeyError`, `LinAlgError` |
| ~~`honest_did.py:807`~~ | Fixed - catches `ValueError`, `TypeError` |
| ~~`honest_did.py:822`~~ | Fixed - catches `ValueError`, `TypeError` |

### Code Duplication

| Duplicate Code | Locations | Status |
|---------------|-----------|--------|
| ~~`_get_significance_stars()`~~ | `results.py:183`, ~~`diagnostics.py`~~ | Fixed in 1.0.1 |
| Wild bootstrap inference block | `estimators.py:278-296`, `estimators.py:725-748` | Future: extract to shared method |
| Within-transformation logic | `estimators.py:217-232`, `estimators.py:787-833`, `bacon.py:567-642` | Future: extract to utils.py |
| Linear regression helper | `staggered.py:205-240`, `estimators.py:366-408` | Future: consider consolidation |

### API Inconsistencies - PARTIALLY ADDRESSED

**Bootstrap parameter naming:**
| Estimator | Parameter | Status |
|-----------|-----------|--------|
| DifferenceInDifferences | `bootstrap_weights` | OK |
| CallawaySantAnna | `bootstrap_weights` | Fixed in 1.0.1 (deprecated `bootstrap_weight_type`) |
| TwoWayFixedEffects | `bootstrap_weights` | OK |

**Cluster variable defaults:**
- ~~`TwoWayFixedEffects` silently defaults cluster to `unit` at runtime~~ - Documented in 1.0.1

---

## Large Module Files

Current line counts (target: < 1000 lines per module):

| File | Lines | Status |
|------|-------|--------|
| `staggered.py` | 1822 | Consider splitting |
| `estimators.py` | 1812 | Consider splitting |
| `honest_did.py` | 1491 | Acceptable |
| `utils.py` | 1350 | Acceptable |
| `power.py` | 1350 | Acceptable |
| `prep.py` | 1338 | Acceptable |
| `visualization.py` | 1388 | Acceptable |
| `bacon.py` | 1027 | OK |

**Potential splits:**
- `estimators.py` → `twfe.py`, `synthetic_did.py` (keep base classes in estimators.py)
- `staggered.py` → `staggered_bootstrap.py` (move bootstrap logic)

---

## Standard Error Consistency

Different estimators compute SEs differently. Consider unified interface.
Expand All @@ -43,6 +118,8 @@ Edge cases needing tests:
- [ ] CallawaySantAnna with single cohort
- [ ] SyntheticDiD with insufficient pre-periods

**Note**: 21 visualization tests are skipped when matplotlib unavailable - this is expected.

---

## Documentation Improvements
Expand All @@ -54,25 +131,6 @@ Edge cases needing tests:

---

## Code Quality

### Refactoring Candidates

- `estimators.py` is large (~1600 lines). Consider splitting TWFE and SyntheticDiD into separate modules.
- Duplicate code in fixed effects handling between `DifferenceInDifferences` and `TwoWayFixedEffects`.

### Type Hints

- Most modules have type hints, but some internal functions lack them
- Consider stricter mypy settings

### Dependencies

- Core: numpy, pandas, scipy only (no statsmodels) - keep it this way
- Optional: matplotlib for visualization

---

## CallawaySantAnna Bootstrap Improvements

Deferred improvements from code review (PR #32):
Expand Down Expand Up @@ -101,3 +159,12 @@ No major performance issues identified. Potential future optimizations:
- JIT compilation for bootstrap loops (numba)
- Parallel bootstrap iterations
- Sparse matrix handling for large fixed effects

---

## Type Hints

Missing type hints in internal functions:
- `utils.py:593` - `compute_trend()` nested function
- `staggered.py:173, 180` - Nested functions in `_logistic_regression()`
- `prep.py:604` - `format_label()` nested function
18 changes: 3 additions & 15 deletions diff_diff/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,10 @@
import pandas as pd

from diff_diff.estimators import DifferenceInDifferences
from diff_diff.results import _get_significance_stars
from diff_diff.utils import compute_confidence_interval, compute_p_value


def _get_significance_stars(p_value: float) -> str:
"""Return significance stars based on p-value."""
if p_value < 0.001:
return "***"
elif p_value < 0.01:
return "**"
elif p_value < 0.05:
return "*"
elif p_value < 0.1:
return "."
return ""


@dataclass
class PlaceboTestResults:
"""
Expand Down Expand Up @@ -633,7 +621,7 @@ def permutation_test(
time=time
)
permuted_effects[i] = perm_results.att
except Exception:
except (ValueError, KeyError, np.linalg.LinAlgError):
# Handle edge cases where fitting fails
permuted_effects[i] = np.nan

Expand Down Expand Up @@ -744,7 +732,7 @@ def leave_one_out_test(
time=time
)
loo_effects[u] = loo_results.att
except Exception:
except (ValueError, KeyError, np.linalg.LinAlgError):
# Skip units that cause fitting issues
loo_effects[u] = np.nan

Expand Down
4 changes: 3 additions & 1 deletion diff_diff/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,9 @@ class TwoWayFixedEffects(DifferenceInDifferences):
Whether to use heteroskedasticity-robust standard errors.
cluster : str, optional
Column name for cluster-robust standard errors.
Defaults to clustering at the unit level.
If None, automatically clusters at the unit level (the `unit`
parameter passed to `fit()`). This differs from
DifferenceInDifferences where cluster=None means no clustering.
alpha : float, default=0.05
Significance level for confidence intervals.

Expand Down
6 changes: 4 additions & 2 deletions diff_diff/honest_did.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,8 @@ def _solve_bounds_lp(
min_val = result_min.fun
else:
min_val = -np.inf
except Exception:
except (ValueError, TypeError):
# Optimization failed - return unbounded
min_val = -np.inf

# Solve for upper bound of -l'@delta (which gives lower bound of theta)
Expand All @@ -818,7 +819,8 @@ def _solve_bounds_lp(
max_val = -result_max.fun
else:
max_val = np.inf
except Exception:
except (ValueError, TypeError):
# Optimization failed - return unbounded
max_val = np.inf

theta_base = np.dot(l_vec, beta_post)
Expand Down
2 changes: 1 addition & 1 deletion diff_diff/power.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import warnings
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
Expand Down
69 changes: 50 additions & 19 deletions diff_diff/staggered.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
compute_p_value,
)


# =============================================================================
# Bootstrap Weight Generators
# =============================================================================
Expand Down Expand Up @@ -554,11 +553,14 @@ class CallawaySantAnna:
Number of bootstrap iterations for inference.
If 0, uses analytical standard errors.
Recommended: 999 or more for reliable inference.
bootstrap_weight_type : str, default="rademacher"
bootstrap_weights : str, default="rademacher"
Type of weights for multiplier bootstrap:
- "rademacher": +1/-1 with equal probability (standard choice)
- "mammen": Two-point distribution (asymptotically valid, matches skewness)
- "webb": Six-point distribution (recommended when n_clusters < 20)
bootstrap_weight_type : str, optional
.. deprecated:: 1.0.1
Use ``bootstrap_weights`` instead. Will be removed in v2.0.
seed : int, optional
Random seed for reproducibility.

Expand Down Expand Up @@ -640,9 +642,12 @@ def __init__(
alpha: float = 0.05,
cluster: Optional[str] = None,
n_bootstrap: int = 0,
bootstrap_weight_type: str = "rademacher",
bootstrap_weights: Optional[str] = None,
bootstrap_weight_type: Optional[str] = None,
seed: Optional[int] = None,
):
import warnings

if control_group not in ["never_treated", "not_yet_treated"]:
raise ValueError(
f"control_group must be 'never_treated' or 'not_yet_treated', "
Expand All @@ -653,10 +658,26 @@ def __init__(
f"estimation_method must be 'dr', 'ipw', or 'reg', "
f"got '{estimation_method}'"
)
if bootstrap_weight_type not in ["rademacher", "mammen", "webb"]:

# Handle bootstrap_weight_type deprecation
if bootstrap_weight_type is not None:
warnings.warn(
"bootstrap_weight_type is deprecated and will be removed in v2.0. "
"Use bootstrap_weights instead.",
DeprecationWarning,
stacklevel=2
)
if bootstrap_weights is None:
bootstrap_weights = bootstrap_weight_type

# Default to rademacher if neither specified
if bootstrap_weights is None:
bootstrap_weights = "rademacher"

if bootstrap_weights not in ["rademacher", "mammen", "webb"]:
raise ValueError(
f"bootstrap_weight_type must be 'rademacher', 'mammen', or 'webb', "
f"got '{bootstrap_weight_type}'"
f"bootstrap_weights must be 'rademacher', 'mammen', or 'webb', "
f"got '{bootstrap_weights}'"
)

self.control_group = control_group
Expand All @@ -665,7 +686,9 @@ def __init__(
self.alpha = alpha
self.cluster = cluster
self.n_bootstrap = n_bootstrap
self.bootstrap_weight_type = bootstrap_weight_type
self.bootstrap_weights = bootstrap_weights
# Keep bootstrap_weight_type for backward compatibility
self.bootstrap_weight_type = bootstrap_weights
self.seed = seed

self.is_fitted_ = False
Expand Down Expand Up @@ -838,31 +861,37 @@ def fit(
group_time_effects[gt]['se'] = bootstrap_results.group_time_ses[gt]
group_time_effects[gt]['conf_int'] = bootstrap_results.group_time_cis[gt]
group_time_effects[gt]['p_value'] = bootstrap_results.group_time_p_values[gt]
effect = group_time_effects[gt]['effect']
se = group_time_effects[gt]['se']
effect = float(group_time_effects[gt]['effect'])
se = float(group_time_effects[gt]['se'])
group_time_effects[gt]['t_stat'] = effect / se if se > 0 else 0.0

# Update event study effects with bootstrap SEs
if event_study_effects is not None and bootstrap_results.event_study_ses is not None:
if (event_study_effects is not None
and bootstrap_results.event_study_ses is not None
and bootstrap_results.event_study_cis is not None
and bootstrap_results.event_study_p_values is not None):
for e in event_study_effects:
if e in bootstrap_results.event_study_ses:
event_study_effects[e]['se'] = bootstrap_results.event_study_ses[e]
event_study_effects[e]['conf_int'] = bootstrap_results.event_study_cis[e]
p_val = bootstrap_results.event_study_p_values[e]
event_study_effects[e]['p_value'] = p_val
effect = event_study_effects[e]['effect']
se = event_study_effects[e]['se']
effect = float(event_study_effects[e]['effect'])
se = float(event_study_effects[e]['se'])
event_study_effects[e]['t_stat'] = effect / se if se > 0 else 0.0

# Update group effects with bootstrap SEs
if group_effects is not None and bootstrap_results.group_effect_ses is not None:
if (group_effects is not None
and bootstrap_results.group_effect_ses is not None
and bootstrap_results.group_effect_cis is not None
and bootstrap_results.group_effect_p_values is not None):
for g in group_effects:
if g in bootstrap_results.group_effect_ses:
group_effects[g]['se'] = bootstrap_results.group_effect_ses[g]
group_effects[g]['conf_int'] = bootstrap_results.group_effect_cis[g]
group_effects[g]['p_value'] = bootstrap_results.group_effect_p_values[g]
effect = group_effects[g]['effect']
se = group_effects[g]['se']
effect = float(group_effects[g]['effect'])
se = float(group_effects[g]['se'])
group_effects[g]['t_stat'] = effect / se if se > 0 else 0.0

# Store results
Expand Down Expand Up @@ -1557,7 +1586,7 @@ def _run_multiplier_bootstrap(
bootstrap_overall[b] = np.sum(overall_weights * bootstrap_atts_gt[b, :])

# Compute bootstrap event study effects
if bootstrap_event_study is not None:
if bootstrap_event_study is not None and event_study_info is not None:
for e, agg_info in event_study_info.items():
gt_indices = agg_info['gt_indices']
weights = agg_info['weights']
Expand All @@ -1566,7 +1595,7 @@ def _run_multiplier_bootstrap(
)

# Compute bootstrap group effects
if bootstrap_group is not None:
if bootstrap_group is not None and group_agg_info is not None:
for g, agg_info in group_agg_info.items():
gt_indices = agg_info['gt_indices']
weights = agg_info['weights']
Expand Down Expand Up @@ -1602,7 +1631,7 @@ def _run_multiplier_bootstrap(
event_study_cis = None
event_study_p_values = None

if bootstrap_event_study is not None:
if bootstrap_event_study is not None and event_study_info is not None:
event_study_ses = {}
event_study_cis = {}
event_study_p_values = {}
Expand All @@ -1622,7 +1651,7 @@ def _run_multiplier_bootstrap(
group_effect_cis = None
group_effect_p_values = None

if bootstrap_group is not None:
if bootstrap_group is not None and group_agg_info is not None:
group_effect_ses = {}
group_effect_cis = {}
group_effect_p_values = {}
Expand Down Expand Up @@ -1797,6 +1826,8 @@ def get_params(self) -> Dict[str, Any]:
"alpha": self.alpha,
"cluster": self.cluster,
"n_bootstrap": self.n_bootstrap,
"bootstrap_weights": self.bootstrap_weights,
# Deprecated but kept for backward compatibility
"bootstrap_weight_type": self.bootstrap_weight_type,
"seed": self.seed,
}
Expand Down
Loading