# Copyright 2022 - 2026 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Staggered Difference in Differences (Imputation-based).
This module implements the imputation-based staggered DiD estimator, following
the approach of Borusyak, Jaravel, and Spiess (2024). It handles settings where
different units receive treatment at different times.
"""
import warnings
from typing import Any, Literal
import numpy as np
import pandas as pd
import xarray as xr
from matplotlib import pyplot as plt
from patsy import dmatrices
from sklearn.base import RegressorMixin
from causalpy.constants import HDI_PROB, LEGEND_FONT_SIZE
from causalpy.custom_exceptions import DataException, FormulaException
from causalpy.pymc_models import LinearRegression, PyMCModel
from causalpy.reporting import EffectSummary
from .base import BaseExperiment
[docs]
class StaggeredDifferenceInDifferences(BaseExperiment):
"""A class to analyse data from staggered adoption Difference-in-Differences settings.
This class implements the Borusyak, Jaravel, and Spiess (BJS, 2024)
imputation estimator for staggered adoption settings. It fits a model on
untreated observations only (pre-treatment periods for eventually-treated
units plus all periods for never-treated units), then predicts
counterfactual outcomes for all observations. Treatment effects are computed
as the difference between observed and predicted outcomes for treated
observations.
Assumptions
-----------
This estimator requires the following identifying assumptions:
1. **Absorbing treatment**: Once a unit receives treatment, it must remain
treated in all subsequent periods. Treatment cannot be reversed or
temporarily suspended. This is validated at runtime.
2. **Parallel trends**: In the absence of treatment, treated and control
units would have followed parallel outcome trajectories.
3. **No anticipation**: Units do not change their behavior in anticipation
of future treatment.
4. **Untreated support at each calendar period**: The time fixed effect
:math:`\\gamma_t` for calendar period :math:`t` is identified only if at
least one unit is untreated in that period. Without never-treated units,
post-treatment effects for the last-treated cohort (and any calendar
periods where every unit is already treated) are not identified. CausalPy
warns when this condition fails and marks the affected ``ATT(g, t)`` and
``ATT(e)`` cells as non-identified in the output tables.
Parameters
----------
data : pd.DataFrame
A pandas dataframe with panel data (unit x time observations).
formula : str
A statistical model formula. Recommended: "y ~ 1 + C(unit) + C(time)"
for unit and time fixed effects.
unit_variable_name : str
Name of the column identifying units.
time_variable_name : str
Name of the column identifying time periods.
treated_variable_name : str, optional
Name of the column indicating treatment status (0/1). Defaults to "treated".
treatment_time_variable_name : str, optional
Name of the column containing unit-level treatment time (G_i).
If None, treatment time is inferred from the treated_variable_name column.
never_treated_value : Any, optional
Value indicating never-treated units in treatment_time column.
Defaults to np.inf.
model : PyMCModel or RegressorMixin, optional
A model for the untreated outcome. Defaults to LinearRegression.
event_window : tuple[int, int], optional
Tuple (min_event_time, max_event_time) to restrict event-time aggregation.
If None, uses all available event-times.
reference_event_time : int, optional
Event-time index associated with plots (reserved for future use).
Defaults to -1.
**kwargs
Accepted for API compatibility with other experiment classes and currently
ignored; they are not forwarded to :class:`BaseExperiment`.
Attributes
----------
data_ : pd.DataFrame
Augmented data with G (treatment time), event_time, y_hat0 (counterfactual),
and tau_hat (treatment effect) columns.
att_group_time_ : pd.DataFrame
Group-time ATT estimates: ATT(g, t) for each cohort g and calendar time t.
Includes an ``identified`` column; non-identified cells have ``NaN`` estimates.
att_event_time_ : pd.DataFrame
Event-time ATT estimates: ATT(e) for each event-time e = t - G.
Includes an ``identified`` column; non-identified cells have ``NaN`` estimates.
non_identified_periods_ : set
Calendar periods with no untreated observations.
non_identified_cohorts_ : set
Treatment cohorts with at least one non-identified post-treatment ATT(g, t).
Notes
-----
**Panel Balance**: This implementation supports both balanced and unbalanced panel
data. While balanced panels (where each unit is observed in every time period) are
common in staggered DiD applications, the imputation-based approach of Borusyak et
al. (2024) can accommodate unbalanced panels. The key requirement is that treatment
timing is well-defined for each unit, not that all units are observed in all periods.
Unit and observation counts in the summary output are computed without assuming
balanced panels.
Example
-------
>>> import causalpy as cp
>>> from causalpy.data.simulate_data import generate_staggered_did_data
>>> df = generate_staggered_did_data(n_units=30, n_time_periods=15, seed=42)
>>> result = cp.StaggeredDifferenceInDifferences(
... df,
... formula="y ~ 1 + C(unit) + C(time)",
... unit_variable_name="unit",
... time_variable_name="time",
... treated_variable_name="treated",
... treatment_time_variable_name="treatment_time",
... model=cp.pymc_models.LinearRegression(
... sample_kwargs={
... "tune": 100,
... "draws": 200,
... "chains": 2,
... "progressbar": False,
... }
... ),
... ) # doctest: +SKIP
References
----------
Borusyak, K., Jaravel, X., & Spiess, J. (2024). Revisiting Event Study Designs:
Robust and Efficient Estimation. Review of Economic Studies.
"""
supports_ols = True
supports_bayes = True
_default_model_class = LinearRegression
[docs]
def __init__(
    self,
    data: pd.DataFrame,
    formula: str,
    unit_variable_name: str,
    time_variable_name: str,
    treated_variable_name: str = "treated",
    treatment_time_variable_name: str | None = None,
    never_treated_value: Any = np.inf,
    model: PyMCModel | RegressorMixin | None = None,
    event_window: tuple[int, int] | None = None,
    reference_event_time: int = -1,
    **kwargs: Any,
) -> None:
    # NOTE: ``kwargs`` exists only so the signature matches the other
    # experiment classes; nothing in it is read or forwarded.
    super().__init__(model=model)

    # User-facing configuration.
    self.expt_type = "Staggered Difference in Differences"
    self.formula = formula
    self.unit_variable_name = unit_variable_name
    self.time_variable_name = time_variable_name
    self.treated_variable_name = treated_variable_name
    self.treatment_time_variable_name = treatment_time_variable_name
    self.never_treated_value = never_treated_value
    self.event_window = event_window
    self.reference_event_time = reference_event_time

    # Work on a private copy so the caller's DataFrame is never mutated.
    panel = data.copy()
    panel.index.name = "obs_ind"
    self.data = panel
    self.input_validation()

    # Preparation steps must run in this order: later steps read columns
    # added by earlier ones (e.g. ``G`` and ``_is_untreated``).
    preparation_steps = (
        self._compute_treatment_times,  # G_i per unit
        self._compute_event_times,  # t - G_i per observation
        self._identify_untreated_observations,  # training-set flag
        self._check_att_identification,  # untreated support per period
        self._build_design_matrices,  # patsy design matrices
    )
    for step in preparation_steps:
        step()
    self.algorithm()
[docs]
def algorithm(self) -> None:
"""Run the experiment algorithm: fit model, predict counterfactuals, and aggregate effects."""
# Step 5: Fit model on untreated observations
self._fit_model()
# Step 6: Predict counterfactuals for all observations
self._predict_counterfactuals()
# Step 7: Compute treatment effects
self._compute_treatment_effects()
# Step 8: Aggregate to group-time and event-time ATTs
self._aggregate_effects()
def _validate_absorbing_treatment(self) -> None:
"""Validate that treatment is absorbing (once treated, always treated)."""
if self.treated_variable_name not in self.data.columns:
# Will infer from treatment_time, skip validation here
return
for unit in self.data[self.unit_variable_name].unique():
unit_data = self.data[
self.data[self.unit_variable_name] == unit
].sort_values(self.time_variable_name)
treated_values = unit_data[self.treated_variable_name].values
# Find first treated period
treated_indices = np.where(treated_values == 1)[0]
if len(treated_indices) == 0:
continue # Never treated
first_treated_idx = treated_indices[0]
# Check all subsequent periods are also treated
if not np.all(treated_values[first_treated_idx:] == 1):
raise DataException(
f"Treatment is not absorbing for unit {unit}. "
"Once a unit is treated, it must remain treated in all "
"subsequent periods."
)
def _compute_treatment_times(self) -> None:
"""Compute treatment time G_i for each unit."""
if self.treatment_time_variable_name is not None:
# Use provided treatment time column
# Get unique treatment time per unit
g_map = (
self.data.groupby(self.unit_variable_name)[
self.treatment_time_variable_name
]
.first()
.to_dict()
)
self.data["G"] = self.data[self.unit_variable_name].map(g_map)
else:
# Infer from treated variable: G = min{t : D_it = 1}
g_map = {}
for unit in self.data[self.unit_variable_name].unique():
unit_data = self.data[self.data[self.unit_variable_name] == unit]
treated_times = unit_data.loc[
unit_data[self.treated_variable_name] == 1, self.time_variable_name
]
if len(treated_times) == 0:
g_map[unit] = self.never_treated_value
else:
g_map[unit] = treated_times.min()
self.data["G"] = self.data[self.unit_variable_name].map(g_map)
# Store unique cohorts (excluding never-treated)
self.cohorts = sorted(
[g for g in self.data["G"].unique() if g != self.never_treated_value]
)
def _compute_event_times(self) -> None:
"""Compute event time (t - G) for each observation."""
self.data["event_time"] = self.data[self.time_variable_name] - self.data["G"]
# Set event_time to NaN for never-treated units
self.data.loc[self.data["G"] == self.never_treated_value, "event_time"] = np.nan
def _identify_untreated_observations(self) -> None:
"""Identify untreated observations for the training set."""
# Untreated if: (t < G) OR (never-treated)
is_never_treated = self.data["G"] == self.never_treated_value
is_pre_treatment = self.data[self.time_variable_name] < self.data["G"]
self.data["_is_untreated"] = is_never_treated | is_pre_treatment
# Verify we have some training data
n_untreated = self.data["_is_untreated"].sum()
if n_untreated == 0:
raise DataException(
"No untreated observations found. Cannot fit the model. "
"Ensure there are never-treated units or pre-treatment periods."
)
def _get_periods_without_untreated_support(self) -> set[Any]:
"""Return calendar periods with zero untreated observations."""
untreated_periods = set(
self.data.loc[self.data["_is_untreated"], self.time_variable_name].unique()
)
all_periods = set(self.data[self.time_variable_name].unique())
return all_periods - untreated_periods
def _get_non_identified_cohorts(self, periods: set[Any]) -> set[Any]:
"""Return cohorts with post-treatment cells in non-identified periods."""
non_identified_cohorts: set[Any] = set()
for cohort in self.cohorts:
for period in periods:
if period >= cohort:
non_identified_cohorts.add(cohort)
break
return non_identified_cohorts
def _check_att_identification(self) -> None:
"""Detect non-identified ATT cells and warn when untreated support is missing."""
self.non_identified_periods_ = self._get_periods_without_untreated_support()
self.non_identified_cohorts_ = self._get_non_identified_cohorts(
self.non_identified_periods_
)
if not self.non_identified_periods_:
return
periods_str = ", ".join(str(p) for p in sorted(self.non_identified_periods_))
cohorts_str = ", ".join(str(c) for c in sorted(self.non_identified_cohorts_))
warnings.warn(
"No untreated observations in calendar period(s) "
f"{{{periods_str}}}; treatment effects for cohort(s) "
f"{{{cohorts_str}}} are not identified at the affected post-treatment "
"cells. Provide never-treated units or restrict the event window. "
"Non-identified ATT(g, t) and ATT(e) cells are marked in the output "
"tables (identified=False) with NaN estimates.",
UserWarning,
stacklevel=2,
)
def _is_calendar_period_identified(self, period: Any) -> bool:
"""Return whether calendar period ``period`` has untreated support."""
return period not in self.non_identified_periods_
def _is_event_time_att_identified(self, event_time: int) -> bool:
"""Return whether aggregated ATT(e) is identified."""
for cohort in self.cohorts:
period = cohort + event_time
has_contributing_obs = (
(self.data["G"] == cohort)
& (self.data[self.time_variable_name] == period)
& (self.data["event_time"] == event_time)
).any()
if has_contributing_obs and not self._is_calendar_period_identified(period):
return False
return True
def _mark_non_identified_att_rows(self, att_df: pd.DataFrame) -> pd.DataFrame:
"""Add ``identified`` column and mask non-identified point estimates."""
if len(att_df) == 0:
att_df = att_df.copy()
att_df["identified"] = pd.Series(dtype=bool)
return att_df
att_df = att_df.copy()
if "cohort" in att_df.columns and "time" in att_df.columns:
att_df["identified"] = att_df["time"].map(
self._is_calendar_period_identified
)
elif "event_time" in att_df.columns:
att_df["identified"] = att_df["event_time"].apply(
lambda e: self._is_event_time_att_identified(int(e))
)
else:
att_df["identified"] = True
value_columns = [
col
for col in ("att", "att_lower", "att_upper", "att_std")
if col in att_df.columns
]
for col in value_columns:
att_df.loc[~att_df["identified"], col] = np.nan
return att_df
def _build_design_matrices(self) -> None:
    """Build the full and training design matrices from ``self.formula``.

    Stores the patsy design infos, the coefficient labels
    (``self.labels``), the outcome column name, the full matrices
    ``X_full``/``y_full`` (one row per row of ``self.data``), and the
    untreated training subset ``X_train``/``y_train``.

    Raises
    ------
    DataException
        If patsy returns a different number of rows than ``self.data``.
        By default patsy drops rows with missing values in formula
        variables. The matrices would then no longer line up row-for-row
        with ``self.data`` and the ``_is_untreated`` mask.
    """
    # Build design matrix for the full data
    y, X = dmatrices(self.formula, self.data)
    n_obs = len(self.data)
    if X.shape[0] != n_obs:
        raise DataException(
            f"The formula produced {X.shape[0]} design-matrix rows for {n_obs} "
            "observations. patsy drops rows with missing values in the formula "
            "variables, which would misalign the model with the data; remove or "
            "impute missing values before fitting."
        )
    self._y_design_info = y.design_info
    self._x_design_info = X.design_info
    self.labels = X.design_info.column_names
    self.outcome_variable_name = y.design_info.column_names[0]
    # Store full design matrix
    self.X_full = np.asarray(X)
    self.y_full = np.asarray(y)
    # Get untreated subset for training (positions line up with self.data)
    untreated_mask = np.asarray(self.data["_is_untreated"].values, dtype=bool)
    self.X_train = self.X_full[untreated_mask]
    self.y_train = self.y_full[untreated_mask]
def _fit_model(self) -> None:
    """Train ``self.model`` using only the untreated rows.

    ``PyMCModel`` instances receive ``xarray`` inputs labelled with
    ``obs_ind``/``coeffs`` (predictors) and ``obs_ind``/``treated_units``
    (outcome), plus a matching ``coords`` mapping. scikit-learn regressors
    receive the plain ``X_train``/``y_train`` arrays.

    Raises
    ------
    ValueError
        If ``self.model`` is neither a ``PyMCModel`` nor a ``RegressorMixin``.
    """
    n_train = self.X_train.shape[0]
    model = self.model
    if isinstance(model, PyMCModel):
        predictors = xr.DataArray(
            self.X_train,
            dims=["obs_ind", "coeffs"],
            coords={"obs_ind": np.arange(n_train), "coeffs": self.labels},
        )
        outcome = xr.DataArray(
            self.y_train,
            dims=["obs_ind", "treated_units"],
            coords={"obs_ind": np.arange(n_train), "treated_units": ["unit_0"]},
        )
        model.fit(
            X=predictors,
            y=outcome,
            coords={
                "coeffs": self.labels,
                "obs_ind": np.arange(n_train),
                "treated_units": ["unit_0"],
            },
        )
    elif isinstance(model, RegressorMixin):
        model.fit(X=self.X_train, y=self.y_train)
    else:
        raise ValueError("Model type not recognized")
def _predict_counterfactuals(self) -> None:
    """Predict the untreated outcome ``y_hat0`` for every observation.

    For ``PyMCModel`` instances the full prediction result is kept in
    ``self.y_pred``, and ``y_hat0`` is the posterior mean of ``mu`` over
    ``chain`` and ``draw`` for the first ``treated_units`` entry. For
    scikit-learn regressors ``self.y_pred`` holds the raw predictions and
    ``y_hat0`` is their squeezed form.

    Raises
    ------
    ValueError
        If ``self.model`` is neither a ``PyMCModel`` nor a ``RegressorMixin``.
    """
    n_full = self.X_full.shape[0]
    if isinstance(self.model, PyMCModel):
        design = xr.DataArray(
            self.X_full,
            dims=["obs_ind", "coeffs"],
            coords={"obs_ind": np.arange(n_full), "coeffs": self.labels},
        )
        self.y_pred = self.model.predict(X=design)
        posterior_mu = self.y_pred["posterior_predictive"].mu
        posterior_mean = posterior_mu.mean(dim=["chain", "draw"])
        self.data["y_hat0"] = posterior_mean.isel(treated_units=0).values
    elif isinstance(self.model, RegressorMixin):
        self.y_pred = self.model.predict(self.X_full)
        self.data["y_hat0"] = np.squeeze(self.y_pred)
    else:
        raise ValueError("Model type not recognized")
def _compute_treatment_effects(self) -> None:
"""Compute treatment effects tau_hat = y - y_hat0 for treated observations."""
self.data["tau_hat"] = np.nan # Initialize with NaN
treated_mask = ~self.data["_is_untreated"]
self.data.loc[treated_mask, "tau_hat"] = (
self.data.loc[treated_mask, self.outcome_variable_name]
- self.data.loc[treated_mask, "y_hat0"]
)
# Store augmented data
self.data_ = self.data.copy()
def _aggregate_effects(self) -> None:
    """Aggregate per-observation effects into ATT(g, t) and ATT(e) tables.

    Two row subsets are passed to the model-specific aggregator:

    - treated rows (``_is_untreated`` is False), which give the
      post-treatment effects behind ATT(g, t) and ATT(e) for e >= 0;
    - pre-treatment rows of eventually-treated units
      (``G != never_treated_value`` and ``event_time < 0``), whose
      residuals ``y - y_hat0`` act as a placebo check. They should be
      centred around zero if parallel trends hold.

    Bayesian (``PyMCModel``) fits use the posterior-draw aggregator; all
    other models use the point-estimate aggregator.
    """
    frame = self.data
    treated_data = frame.loc[~frame["_is_untreated"]].copy()
    eventually_treated = frame["G"] != self.never_treated_value
    before_onset = frame["event_time"] < 0
    pretreatment_data = frame.loc[eventually_treated & before_onset].copy()
    aggregate = (
        self._aggregate_effects_bayesian
        if isinstance(self.model, PyMCModel)
        else self._aggregate_effects_ols
    )
    aggregate(treated_data, pretreatment_data)
def _aggregate_effects_bayesian(
    self,
    treated_data: pd.DataFrame,
    pretreatment_data: pd.DataFrame,
    hdi_prob: float = HDI_PROB,
) -> None:
    """Aggregate effects for Bayesian model with posterior uncertainty.

    For each posterior draw, per-observation effects ``y - mu`` are averaged
    within a cell: a ``(cohort, time)`` pair or an event time. The reported
    ``att`` is the mean of those cell averages over all draws.
    ``att_lower``/``att_upper`` are the equal-tailed
    ``(1 - hdi_prob) / 2`` and ``(1 + hdi_prob) / 2`` percentiles of them.

    Parameters
    ----------
    treated_data : pd.DataFrame
        Treated observations (event_time >= 0). These must be the rows of
        ``self.data`` whose ``_is_untreated`` flag is False, in their
        original order, because they are matched to posterior draws by
        position.
    pretreatment_data : pd.DataFrame
        Pre-treatment observations from eventually-treated units
        (event_time < 0), used as a placebo check. Rows are matched to
        ``self.data`` by index label.
    hdi_prob : float, optional
        Probability mass between the reported bounds. Defaults to
        :data:`~causalpy.constants.HDI_PROB`.

    Notes
    -----
    Group-time cells are located by row position inside ``treated_data``,
    not by index label. Input data with a non-default index (e.g. a
    filtered DataFrame) is therefore aggregated correctly.
    """
    # Store the HDI probability used for interval computation
    self.hdi_prob_ = hdi_prob
    lower_pct = (1 - hdi_prob) / 2 * 100
    upper_pct = (1 + hdi_prob) / 2 * 100

    def summarise(cell_tau_draws: np.ndarray) -> dict[str, float]:
        """Average over observations, then summarise across posterior draws."""
        cell_means = cell_tau_draws.mean(axis=2)
        return {
            "att": float(cell_means.mean()),
            "att_lower": float(np.percentile(cell_means, lower_pct)),
            "att_upper": float(np.percentile(cell_means, upper_pct)),
        }

    def event_time_rows(
        event_times: np.ndarray, tau_draws: np.ndarray
    ) -> list[dict]:
        """Build one ATT(e) row per distinct event time inside ``event_window``."""
        candidates = np.unique(event_times[~np.isnan(event_times)])
        if self.event_window is not None:
            candidates = candidates[
                (candidates >= self.event_window[0])
                & (candidates <= self.event_window[1])
            ]
        rows: list[dict] = []
        for e in sorted(candidates):
            positions = np.flatnonzero(event_times == e)
            if positions.size == 0:
                continue
            rows.append(
                {
                    "event_time": int(e),
                    **summarise(tau_draws[:, :, positions]),
                    "n_obs": int(positions.size),
                }
            )
        return rows

    # Posterior draws of the counterfactual mean; after selecting the single
    # treated unit they are laid out as (chain, draw, obs_ind).
    mu_draws = self.y_pred["posterior_predictive"].mu.isel(treated_units=0)
    y_observed = np.asarray(self.data[self.outcome_variable_name].values)
    tau_draws_all = y_observed - mu_draws.values

    # Row positions (not labels) of treated observations inside self.data.
    treated_mask = ~np.asarray(self.data["_is_untreated"].values, dtype=bool)
    tau_draws_treated = tau_draws_all[:, :, np.flatnonzero(treated_mask)]

    # --- Group-time ATTs (post-treatment only) ---
    # After reset_index(drop=True) each group's index holds its row positions
    # within treated_data, i.e. within tau_draws_treated.
    gt_keys = treated_data[["G", self.time_variable_name]].reset_index(drop=True)
    att_gt_rows: list[dict] = []
    for (g_val, t_val), group in gt_keys.groupby(["G", self.time_variable_name]):
        att_gt_rows.append(
            {
                "cohort": g_val,
                "time": t_val,
                **summarise(tau_draws_treated[:, :, group.index.to_numpy()]),
            }
        )
    self.att_group_time_ = self._mark_non_identified_att_rows(
        pd.DataFrame(att_gt_rows)
    )

    # --- Event-time ATTs (including pre-treatment placebo) ---
    att_et_rows: list[dict] = []
    if len(pretreatment_data) > 0:
        data_index = self.data.index
        if data_index.is_unique:
            # O(n) label -> position lookup.
            pretreat_positions = data_index.get_indexer(pretreatment_data.index)
        else:
            # Duplicate labels are ambiguous; fall back to first-match lookup.
            pretreat_positions = np.array(
                [
                    np.flatnonzero(data_index == label)[0]
                    for label in pretreatment_data.index
                ]
            )
        att_et_rows.extend(
            event_time_rows(
                np.asarray(pretreatment_data["event_time"].values),
                tau_draws_all[:, :, pretreat_positions],
            )
        )
    # Post-treatment effects (event_time >= 0)
    att_et_rows.extend(
        event_time_rows(
            np.asarray(treated_data["event_time"].values), tau_draws_treated
        )
    )
    self.att_event_time_ = self._mark_non_identified_att_rows(
        pd.DataFrame(att_et_rows)
    )
def _aggregate_effects_ols(
self, treated_data: pd.DataFrame, pretreatment_data: pd.DataFrame
) -> None:
"""Aggregate effects for OLS model (point estimates only).
Parameters
----------
treated_data : pd.DataFrame
DataFrame containing only treated observations (event_time >= 0)
pretreatment_data : pd.DataFrame
DataFrame containing pre-treatment observations from eventually-treated
units (event_time < 0) for placebo check
"""
# --- Group-time ATTs (post-treatment only) ---
att_gt = (
treated_data.groupby(["G", self.time_variable_name])["tau_hat"]
.agg(["mean", "std", "count"])
.reset_index()
)
att_gt.columns = ["cohort", "time", "att", "att_std", "n_obs"]
self.att_group_time_ = self._mark_non_identified_att_rows(att_gt)
# --- Event-time ATTs (including pre-treatment placebo) ---
# Compute tau_hat for pre-treatment observations (residuals)
if len(pretreatment_data) > 0:
pretreatment_data = pretreatment_data.copy()
pretreatment_data["tau_hat"] = (
pretreatment_data[self.outcome_variable_name]
- pretreatment_data["y_hat0"]
)
# Combine pre-treatment and post-treatment for event-time aggregation
event_data = pd.concat([pretreatment_data, treated_data], ignore_index=True)
# Apply event window filter if specified
if self.event_window is not None:
event_data = event_data[
(event_data["event_time"] >= self.event_window[0])
& (event_data["event_time"] <= self.event_window[1])
]
att_et = (
event_data.groupby("event_time")["tau_hat"]
.agg(["mean", "std", "count"])
.reset_index()
)
att_et.columns = ["event_time", "att", "att_std", "n_obs"]
att_et["event_time"] = att_et["event_time"].astype(int)
self.att_event_time_ = self._mark_non_identified_att_rows(att_et)
[docs]
def summary(
    self, round_to: int | None = 2, include_group_time: bool = False
) -> None:
    """Print a text report of the fitted staggered DiD experiment.

    The report lists, in order:

    - the formula and panel dimensions;
    - the treatment cohorts and the number of never-treated units;
    - the event-time table, with each row labelled ``placebo`` (negative
      event time) or ``ATT``;
    - optionally the group-time table;
    - the model coefficients.

    Parameters
    ----------
    round_to : int, optional
        Number of decimals passed on to the coefficient printout.
        Defaults to 2.
    include_group_time : bool
        If True, also print the disaggregated cohort-by-calendar-time
        ``ATT(g, t)`` table after the event-time estimates. Defaults to
        ``False``.
    """
    unit_col = self.unit_variable_name
    print(f"{self.expt_type:=^80}")
    print(f"Formula: {self.formula}")
    print(f"Number of units: {self.data[unit_col].nunique()}")
    print(f"Number of time periods: {self.data[self.time_variable_name].nunique()}")
    print(f"Treatment cohorts: {self.cohorts}")
    never_treated_units = self.data.loc[
        self.data["G"] == self.never_treated_value, unit_col
    ]
    print(f"Never-treated units: {never_treated_units.nunique()}")

    print("\nEvent-time estimates:")
    event_table = self.att_event_time_.copy()
    # Label each row so pre-treatment (placebo) rows are easy to spot.
    event_table["type"] = event_table["event_time"].apply(
        lambda e: "placebo" if e < 0 else "ATT"
    )
    leading = ["event_time", "type"]
    ordered_columns = leading + [c for c in event_table.columns if c not in leading]
    print(event_table[ordered_columns].to_string(index=False))

    if include_group_time:
        print("\nGroup-time estimates:")
        print(self.att_group_time_.to_string(index=False))

    print("\nModel coefficients:")
    self.print_coefficients(round_to)
[docs]
def plot(
    self,
    *,
    hdi_prob: float | None = None,
    figsize: tuple[float, float] = (10, 6),
    show: bool = True,
    legend_kwargs: dict[str, Any] | None = None,
) -> tuple[plt.Figure, list[plt.Axes]]:
    """Draw the aggregated staggered difference-in-differences event study.

    Parameters
    ----------
    hdi_prob : float, optional
        Probability mass of the highest density interval drawn as error
        bars. In contrast to most CausalPy experiments, staggered DiD fixes
        ``hdi_prob`` when effects are aggregated at fit time and caches the
        resulting bounds on the instance. A value given here must equal
        the cached :attr:`hdi_prob_`, otherwise a :class:`ValueError` is
        raised. Leave it as ``None`` (the default) to use the cached value.
        Ignored for OLS models.
    figsize : tuple of (float, float)
        Figure width and height in inches, forwarded to
        :func:`matplotlib.pyplot.subplots`. Defaults to ``(10, 6)``.
    show : bool
        Whether the plot is displayed automatically. Defaults to ``True``.
    legend_kwargs : dict, optional
        Options adjusting legend placement and styling. Supported keys:
        ``loc``, ``bbox_to_anchor``, ``fontsize``, ``frameon``, ``title``
        (``bbox_transform`` may accompany ``bbox_to_anchor``). The existing
        legend is modified **in place** so custom handles survive.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The newly created figure.
    ax : list[matplotlib.axes.Axes]
        A one-element list holding the event-study axes.
    """
    render_kwargs: dict[str, Any] = {
        "show": show,
        "legend_kwargs": legend_kwargs,
        "hdi_prob": hdi_prob,
        "figsize": figsize,
    }
    return self._render_plot(**render_kwargs)
[docs]
def plot_group_time(
    self,
    *,
    hdi_prob: float | None = None,
    layout: Literal["facet", "overlay"] = "facet",
    x_axis: Literal["event_time", "calendar_time"] = "event_time",
    include_placebo: bool = True,
    figsize: tuple[float, float] | None = None,
    show: bool = True,
    legend_kwargs: dict[str, Any] | None = None,
) -> tuple[plt.Figure, list[plt.Axes]]:
    """Draw one ``ATT(g, t)`` trajectory per treatment cohort.

    Parameters
    ----------
    hdi_prob : float, optional
        Probability mass of the highest density interval drawn as bands.
        Bayesian ``ATT(g, t)`` bounds are cached during effect aggregation,
        as for :meth:`plot`. A value given here must equal the cached
        :attr:`hdi_prob_`, otherwise a :class:`ValueError` is raised. Leave
        it as ``None`` (the default) to use the cached value. Ignored for
        OLS models.
    layout : {"facet", "overlay"}
        ``"facet"`` gives each cohort its own row; ``"overlay"`` puts all
        cohorts on one axes. Defaults to ``"facet"``.
    x_axis : {"event_time", "calendar_time"}
        Horizontal scale of the trajectories. ``"event_time"`` uses periods
        since treatment, giving an ``ATT(g, e)`` view derived from
        ``ATT(g, t)``. ``"calendar_time"`` uses calendar time ``t``.
        Defaults to ``"event_time"``.
    include_placebo : bool
        Whether pre-treatment residual estimates of eventually-treated
        cohorts are shown as placebo diagnostics. Defaults to ``True``.
    figsize : tuple of (float, float), optional
        Figure width and height in inches, forwarded to
        :func:`matplotlib.pyplot.subplots`. By default the height scales
        with the number of cohorts for ``layout="facet"``, and ``(10, 6)``
        is used for ``layout="overlay"``.
    show : bool
        Whether the plot is displayed automatically. Defaults to ``True``.
    legend_kwargs : dict, optional
        Options adjusting legend placement and styling. Supported keys:
        ``loc``, ``bbox_to_anchor``, ``fontsize``, ``frameon``, ``title``
        (``bbox_transform`` may accompany ``bbox_to_anchor``). The existing
        legend is modified **in place** so custom handles survive.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The newly created figure.
    ax : list[matplotlib.axes.Axes]
        The axes holding the trajectories: one per cohort for
        ``layout="facet"``, a single one for ``layout="overlay"``.
    """
    return self._render_plot(
        view="group_time",
        show=show,
        legend_kwargs=legend_kwargs,
        hdi_prob=hdi_prob,
        layout=layout,
        x_axis=x_axis,
        include_placebo=include_placebo,
        figsize=figsize,
    )
def _bayesian_plot(
    self,
    hdi_prob: float | None = None,
    figsize: tuple[float, float] | None = (10, 6),
    view: Literal["event_time", "group_time"] = "event_time",
    layout: Literal["facet", "overlay"] = "facet",
    x_axis: Literal["event_time", "calendar_time"] = "event_time",
    include_placebo: bool = True,
    **kwargs: Any,
) -> tuple[plt.Figure, list[plt.Axes]]:
    """Plot results for Bayesian model.

    Parameters
    ----------
    hdi_prob : float, optional
        Probability mass of the highest density interval shown by the
        error bars. Unlike most other CausalPy experiments, ``hdi_prob``
        for ``StaggeredDiD`` is fixed at fit time during effect
        aggregation (see ``_aggregate_effects_bayesian``), and the
        resulting bounds are cached on the instance. If supplied here,
        the value must match the cached
        :attr:`~causalpy.experiments.staggered_did.StaggeredDiD.hdi_prob_`;
        otherwise a :class:`ValueError` is raised. Pass ``None`` (the
        default) to plot using the cached value.
    figsize : tuple of (float, float), optional
        Width and height of the figure in inches. Defaults to ``(10, 6)``.
    view : {"event_time", "group_time"}, optional
        Plot view to render. ``"event_time"`` draws the aggregated event
        study and ``"group_time"`` draws cohort-specific ``ATT(g, t)``
        trajectories. Defaults to ``"event_time"``.
    layout : {"facet", "overlay"}, optional
        Plot layout for the ``"group_time"`` view. Defaults to
        ``"facet"``.
    x_axis : {"event_time", "calendar_time"}, optional
        Time scale for the ``"group_time"`` view. Defaults to
        ``"event_time"``.
    include_placebo : bool, optional
        Whether to include pre-treatment residual estimates in the
        ``"group_time"`` view. Defaults to ``True``.

    Returns
    -------
    tuple[plt.Figure, list[plt.Axes]]
        Figure and axes objects.

    Raises
    ------
    ValueError
        If ``hdi_prob`` differs from the cached value, or ``view`` is not
        one of the supported options.
    """
    if hdi_prob is not None and hdi_prob != self.hdi_prob_:
        raise ValueError(
            "StaggeredDiD HDI bounds are computed during effect "
            "aggregation, not at plot time. The cached HDI probability "
            f"is {self.hdi_prob_}, but plot() received hdi_prob="
            f"{hdi_prob}. To plot at a different HDI probability, "
            "re-fit the experiment so that aggregation uses the desired "
            "value, or omit hdi_prob to use the cached value."
        )
    if view == "group_time":
        return self._bayesian_plot_group_time(
            figsize=figsize,
            layout=layout,
            x_axis=x_axis,
            include_placebo=include_placebo,
        )
    if view != "event_time":
        raise ValueError("view must be 'event_time' or 'group_time'")
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    att_et = self.att_event_time_.copy()
    # Format the interval mass without truncation: int(0.57 * 100) == 56
    # because 0.57 * 100 == 56.99999999999999 in floating point.
    interval_label = f"{self.hdi_prob_ * 100:g}% HDI"
    # Separate pre-treatment (placebo) and post-treatment (ATT)
    pre_treatment = att_et[att_et["event_time"] < 0]
    post_treatment = att_et[att_et["event_time"] >= 0]
    # Plot pre-treatment placebo estimates (different style)
    if len(pre_treatment) > 0:
        ax.errorbar(
            pre_treatment["event_time"],
            pre_treatment["att"],
            yerr=[
                pre_treatment["att"] - pre_treatment["att_lower"],
                pre_treatment["att_upper"] - pre_treatment["att"],
            ],
            fmt="s",  # Square markers for placebo
            capsize=4,
            capthick=2,
            markersize=7,
            color="gray",
            alpha=0.7,
            label=f"Placebo estimate ({interval_label})",
        )
    # Plot post-treatment ATT estimates
    if len(post_treatment) > 0:
        ax.errorbar(
            post_treatment["event_time"],
            post_treatment["att"],
            yerr=[
                post_treatment["att"] - post_treatment["att_lower"],
                post_treatment["att_upper"] - post_treatment["att"],
            ],
            fmt="o",
            capsize=4,
            capthick=2,
            markersize=8,
            color="C0",
            label=f"ATT estimate ({interval_label})",
        )
    # Add horizontal line at zero
    ax.axhline(y=0, color="black", linestyle="--", linewidth=1, alpha=0.7)
    # Vertical line between event times -1 and 0 marks treatment onset
    ax.axvline(x=-0.5, color="red", linestyle="-", linewidth=2, alpha=0.7)
    # Shade pre-treatment region
    event_min = att_et["event_time"].min()
    if event_min < 0:
        ax.axvspan(
            event_min - 0.5,
            -0.5,
            alpha=0.1,
            color="gray",
        )
    # Labels and formatting
    ax.set_xlabel("Event Time (periods relative to treatment)", fontsize=12)
    ax.set_ylabel("Effect Estimate", fontsize=12)
    ax.set_title("Staggered DiD Event Study", fontsize=14)
    ax.legend(fontsize=LEGEND_FONT_SIZE)
    # Set integer ticks for event time
    ax.set_xticks(att_et["event_time"].values)
    return fig, [ax]
def _bayesian_plot_group_time(
    self,
    figsize: tuple[float, float] | None = None,
    layout: Literal["facet", "overlay"] = "facet",
    x_axis: Literal["event_time", "calendar_time"] = "event_time",
    include_placebo: bool = True,
) -> tuple[plt.Figure, list[plt.Axes]]:
    """Plot Bayesian cohort-time ``ATT(g, t)`` trajectories.

    Parameters
    ----------
    figsize : tuple of (float, float), optional
        Figure size in inches. ``None`` lets ``_make_group_time_axes`` choose
        a size (``(10, 6)`` for ``"overlay"``; a height scaled by the number
        of cohorts for ``"facet"``).
    layout : {"facet", "overlay"}, optional
        ``"facet"`` draws one panel per cohort; ``"overlay"`` draws every
        cohort on one shared axis, colour-coded by cohort.
    x_axis : {"event_time", "calendar_time"}, optional
        Horizontal time scale.
    include_placebo : bool, optional
        Whether to draw pre-treatment placebo estimates as well.

    Returns
    -------
    tuple[plt.Figure, list[plt.Axes]]
        The figure and its axes (one per cohort for ``"facet"``, a single
        axis for ``"overlay"``).
    """
    att_gt, x_col, x_label, y_label = self._get_group_time_plot_data(
        x_axis=x_axis, include_placebo=include_placebo
    )
    cohort_groups = list(att_gt.groupby("cohort", sort=True))
    is_facet = layout == "facet"
    sharex = x_axis == "event_time"
    # Raises ValueError for an invalid layout before anything is drawn.
    fig, axes = self._make_group_time_axes(
        att_gt=att_gt,
        layout=layout,
        figsize=figsize,
        sharex=sharex,
        sharey=is_facet,
    )
    for cohort_idx, (cohort, cohort_data) in enumerate(cohort_groups):
        ax = axes[cohort] if is_facet else axes["overlay"]
        cohort_color = f"C{cohort_idx % 10}"
        self._plot_bayesian_group_time_segment(
            ax=ax,
            cohort_data=cohort_data[cohort_data["type"] == "placebo"],
            x_col=x_col,
            line_type="placebo",
            color="gray" if is_facet else cohort_color,
            label="Placebo estimate" if is_facet else f"Cohort {cohort} placebo",
        )
        self._plot_bayesian_group_time_segment(
            ax=ax,
            cohort_data=cohort_data[cohort_data["type"] == "ATT"],
            x_col=x_col,
            line_type="ATT",
            color="C0" if is_facet else cohort_color,
            label="ATT estimate" if is_facet else f"Cohort {cohort} ATT",
        )
        if is_facet:
            self._format_group_time_axis(
                ax=ax,
                cohort=cohort,
                x_label=self._get_group_time_axis_label(
                    x_label=x_label,
                    layout=layout,
                    sharex=sharex,
                    axis_index=cohort_idx,
                    n_axes=len(cohort_groups),
                ),
                y_label=y_label,
                x_axis=x_axis,
                treatment_time=cohort,
            )
            ax.legend(fontsize=LEGEND_FONT_SIZE)
    if not is_facet:
        overlay_ax = axes["overlay"]
        # Format the shared axis exactly once: doing it per cohort would
        # stack duplicate zero/onset reference lines on top of each other.
        self._format_group_time_axis(
            ax=overlay_ax,
            cohort=None,
            x_label=x_label,
            y_label=y_label,
            x_axis=x_axis,
            treatment_time=None,
        )
        overlay_ax.legend(title="Treatment cohort", fontsize=LEGEND_FONT_SIZE)
    return fig, list(axes.values())
def _ols_plot(
    self,
    figsize: tuple[float, float] | None = None,
    view: Literal["event_time", "group_time"] = "event_time",
    layout: Literal["facet", "overlay"] = "facet",
    x_axis: Literal["event_time", "calendar_time"] = "event_time",
    include_placebo: bool = True,
    **kwargs: Any,
) -> tuple[plt.Figure, list[plt.Axes]]:
    """Plot results for OLS model.

    Parameters
    ----------
    figsize : tuple of (float, float), optional
        Width and height of the figure in inches. When ``None`` (the
        default) the ``"event_time"`` view uses ``(10, 6)`` and the
        ``"group_time"`` view lets the layout pick a size (the facet height
        grows with the number of cohorts).
    view : {"event_time", "group_time"}, optional
        Plot view to render. ``"event_time"`` draws the aggregated event
        study and ``"group_time"`` draws cohort-specific ``ATT(g, t)``
        trajectories. Defaults to ``"event_time"``.
    layout : {"facet", "overlay"}, optional
        Plot layout for the ``"group_time"`` view. Defaults to
        ``"facet"``.
    x_axis : {"event_time", "calendar_time"}, optional
        Time scale for the ``"group_time"`` view. Defaults to
        ``"event_time"``.
    include_placebo : bool, optional
        Whether to include pre-treatment residual estimates in the
        ``"group_time"`` view. Defaults to ``True``.
    **kwargs
        Accepted for interface compatibility; not used.

    Returns
    -------
    tuple[plt.Figure, list[plt.Axes]]
        Figure and axes objects.

    Raises
    ------
    ValueError
        If ``view`` is not ``"event_time"`` or ``"group_time"``.
    """
    if view == "group_time":
        # Forward figsize untouched so None reaches the auto-sizing logic.
        return self._ols_plot_group_time(
            figsize=figsize,
            layout=layout,
            x_axis=x_axis,
            include_placebo=include_placebo,
        )
    if view != "event_time":
        raise ValueError("view must be 'event_time' or 'group_time'")
    fig, ax = plt.subplots(1, 1, figsize=figsize or (10, 6))
    att_et = self.att_event_time_.copy()
    # The 1.96 * std / sqrt(n) error bars need both columns.
    has_spread = {"att_std", "n_obs"}.issubset(att_et.columns)
    # Separate pre-treatment (placebo) and post-treatment (ATT)
    pre_treatment = att_et[att_et["event_time"] < 0]
    post_treatment = att_et[att_et["event_time"] >= 0]
    # Plot pre-treatment placebo estimates (different style)
    if len(pre_treatment) > 0:
        ax.scatter(
            pre_treatment["event_time"],
            pre_treatment["att"],
            s=60,
            color="gray",
            marker="s",  # Square markers for placebo
            zorder=3,
            alpha=0.7,
            label="Placebo estimate",
        )
        if has_spread:
            se = pre_treatment["att_std"] / np.sqrt(pre_treatment["n_obs"])
            ax.errorbar(
                pre_treatment["event_time"],
                pre_treatment["att"],
                yerr=1.96 * se,
                fmt="none",
                capsize=4,
                capthick=2,
                color="gray",
                alpha=0.5,
            )
    # Plot post-treatment ATT estimates
    if len(post_treatment) > 0:
        ax.scatter(
            post_treatment["event_time"],
            post_treatment["att"],
            s=80,
            color="C0",
            zorder=3,
            label="ATT estimate",
        )
        if has_spread:
            se = post_treatment["att_std"] / np.sqrt(post_treatment["n_obs"])
            ax.errorbar(
                post_treatment["event_time"],
                post_treatment["att"],
                yerr=1.96 * se,
                fmt="none",
                capsize=4,
                capthick=2,
                color="C0",
                alpha=0.7,
            )
    # Add horizontal line at zero
    ax.axhline(y=0, color="black", linestyle="--", linewidth=1, alpha=0.7)
    # Vertical line half a period before event_time = 0 (treatment onset)
    ax.axvline(x=-0.5, color="red", linestyle="-", linewidth=2, alpha=0.7)
    # Shade pre-treatment region
    event_min = att_et["event_time"].min()
    if event_min < 0:
        ax.axvspan(
            event_min - 0.5,
            -0.5,
            alpha=0.1,
            color="gray",
        )
    # Labels and formatting
    ax.set_xlabel("Event Time (periods relative to treatment)", fontsize=12)
    ax.set_ylabel("Effect Estimate", fontsize=12)
    ax.set_title("Staggered DiD Event Study", fontsize=14)
    ax.legend(fontsize=LEGEND_FONT_SIZE)
    # Set integer ticks for event time
    ax.set_xticks(att_et["event_time"].values)
    return fig, [ax]
def _ols_plot_group_time(
    self,
    figsize: tuple[float, float] | None = None,
    layout: Literal["facet", "overlay"] = "facet",
    x_axis: Literal["event_time", "calendar_time"] = "event_time",
    include_placebo: bool = True,
) -> tuple[plt.Figure, list[plt.Axes]]:
    """Plot OLS cohort-time ``ATT(g, t)`` trajectories.

    Parameters
    ----------
    figsize : tuple of (float, float), optional
        Figure size in inches. ``None`` lets ``_make_group_time_axes`` choose
        a size (``(10, 6)`` for ``"overlay"``; a height scaled by the number
        of cohorts for ``"facet"``).
    layout : {"facet", "overlay"}, optional
        ``"facet"`` draws one panel per cohort; ``"overlay"`` draws every
        cohort on one shared axis, colour-coded by cohort.
    x_axis : {"event_time", "calendar_time"}, optional
        Horizontal time scale.
    include_placebo : bool, optional
        Whether to draw pre-treatment placebo (residual) estimates as well.

    Returns
    -------
    tuple[plt.Figure, list[plt.Axes]]
        The figure and its axes (one per cohort for ``"facet"``, a single
        axis for ``"overlay"``).
    """
    att_gt, x_col, x_label, y_label = self._get_group_time_plot_data(
        x_axis=x_axis, include_placebo=include_placebo
    )
    cohort_groups = list(att_gt.groupby("cohort", sort=True))
    is_facet = layout == "facet"
    sharex = x_axis == "event_time"
    # Raises ValueError for an invalid layout before anything is drawn.
    fig, axes = self._make_group_time_axes(
        att_gt=att_gt,
        layout=layout,
        figsize=figsize,
        sharex=sharex,
        sharey=is_facet,
    )
    for cohort_idx, (cohort, cohort_data) in enumerate(cohort_groups):
        ax = axes[cohort] if is_facet else axes["overlay"]
        cohort_color = f"C{cohort_idx % 10}"
        self._plot_ols_group_time_segment(
            ax=ax,
            cohort_data=cohort_data[cohort_data["type"] == "placebo"],
            x_col=x_col,
            line_type="placebo",
            color="gray" if is_facet else cohort_color,
            label="Placebo estimate" if is_facet else f"Cohort {cohort} placebo",
        )
        self._plot_ols_group_time_segment(
            ax=ax,
            cohort_data=cohort_data[cohort_data["type"] == "ATT"],
            x_col=x_col,
            line_type="ATT",
            color="C0" if is_facet else cohort_color,
            label="ATT estimate" if is_facet else f"Cohort {cohort} ATT",
        )
        if is_facet:
            self._format_group_time_axis(
                ax=ax,
                cohort=cohort,
                x_label=self._get_group_time_axis_label(
                    x_label=x_label,
                    layout=layout,
                    sharex=sharex,
                    axis_index=cohort_idx,
                    n_axes=len(cohort_groups),
                ),
                y_label=y_label,
                x_axis=x_axis,
                treatment_time=cohort,
            )
            ax.legend(fontsize=LEGEND_FONT_SIZE)
    if not is_facet:
        overlay_ax = axes["overlay"]
        # Format the shared axis exactly once: doing it per cohort would
        # stack duplicate zero/onset reference lines on top of each other.
        self._format_group_time_axis(
            ax=overlay_ax,
            cohort=None,
            x_label=x_label,
            y_label=y_label,
            x_axis=x_axis,
            treatment_time=None,
        )
        overlay_ax.legend(title="Treatment cohort", fontsize=LEGEND_FONT_SIZE)
    return fig, list(axes.values())
def _get_group_time_plot_data(
    self,
    x_axis: Literal["event_time", "calendar_time"],
    include_placebo: bool,
) -> tuple[pd.DataFrame, str, str, str]:
    """Return cohort-time data with the requested plotting time scale.

    Parameters
    ----------
    x_axis : {"event_time", "calendar_time"}
        Horizontal time scale. ``"event_time"`` adds an ``event_time``
        column equal to ``time - cohort``.
    include_placebo : bool
        Whether to prepend the cohort-time placebo rows returned by
        ``_get_group_time_placebo_data``.

    Returns
    -------
    tuple of (pd.DataFrame, str, str, str)
        ``(data, x_col, x_label, y_label)``. ``data`` holds the ATT rows
        (``type == "ATT"``) plus, optionally, placebo rows
        (``type == "placebo"``), sorted by cohort and time. ``x_col`` names
        the column to plot on the horizontal axis.

    Raises
    ------
    ValueError
        If ``x_axis`` is not ``"event_time"`` or ``"calendar_time"``.
    """
    if x_axis not in {"event_time", "calendar_time"}:
        raise ValueError("x_axis must be 'event_time' or 'calendar_time'")
    # sort_values/assign return new frames, so att_group_time_ is untouched.
    att_gt = self.att_group_time_.sort_values(["cohort", "time"]).assign(type="ATT")
    if include_placebo:
        placebo = self._get_group_time_placebo_data()
        # Skip an empty placebo frame: concatenating empty entries is
        # deprecated in pandas and can influence the result dtypes.
        if len(placebo) > 0:
            att_gt = pd.concat(
                [placebo, att_gt],
                ignore_index=True,
                sort=False,
            ).sort_values(["cohort", "time"])
    y_label = "ATT(g, e)" if x_axis == "event_time" else "ATT(g, t)"
    if include_placebo:
        y_label = f"{y_label} / placebo"
    if x_axis == "calendar_time":
        return att_gt, "time", "Calendar Time", y_label
    att_gt["event_time"] = att_gt["time"] - att_gt["cohort"]
    return (
        att_gt,
        "event_time",
        "Event Time (periods relative to treatment)",
        y_label,
    )
def _get_group_time_placebo_data(self) -> pd.DataFrame:
    """Dispatch to the model-specific cohort-time placebo builder.

    PyMC models use the posterior-draw builder (with interval bounds); any
    other model uses the OLS residual builder.
    """
    build_placebo = (
        self._get_group_time_placebo_data_bayesian
        if isinstance(self.model, PyMCModel)
        else self._get_group_time_placebo_data_ols
    )
    return build_placebo()
def _get_group_time_placebo_observations(self) -> pd.DataFrame:
    """Select pre-treatment rows that belong to eventually-treated cohorts.

    Returns
    -------
    pd.DataFrame
        A copy of the rows of ``self.data`` whose ``G`` differs from
        ``never_treated_value`` and whose ``event_time`` is negative.
    """
    frame = self.data
    in_treated_cohort = frame["G"] != self.never_treated_value
    before_onset = frame["event_time"] < 0
    return frame.loc[in_treated_cohort & before_onset].copy()
def _get_group_time_placebo_data_bayesian(self) -> pd.DataFrame:
    """Return Bayesian cohort-time placebo estimates with interval bounds.

    For each eventually-treated observation with negative ``event_time``,
    posterior draws of ``y - mu`` are averaged within each ``(G, time)``
    cell. The point estimate is the mean of those averaged draws. The bounds
    are equal-tailed percentiles at ``hdi_prob_``, falling back to
    ``HDI_PROB`` when that attribute is absent.

    Rows are matched to posterior draws by position, not by index label.
    The result is therefore correct for a non-unique ``self.data`` index,
    and the lookup is linear in the number of observations.

    Returns
    -------
    pd.DataFrame
        Columns ``cohort``, ``time``, ``att``, ``att_lower``, ``att_upper``,
        ``n_obs`` and ``type`` (always ``"placebo"``), sorted by cohort and
        time. An empty DataFrame if there are no placebo observations.
    """
    data = self.data
    # Same selection as _get_group_time_placebo_observations, kept as a
    # positional mask so rows line up with the last axis of the draws.
    placebo_mask = np.asarray(
        (data["G"] != self.never_treated_value) & (data["event_time"] < 0),
        dtype=bool,
    )
    if not placebo_mask.any():
        return pd.DataFrame()
    hdi_prob = self.hdi_prob_ if hasattr(self, "hdi_prob_") else HDI_PROB
    lower_pct = (1 - hdi_prob) / 2 * 100
    upper_pct = (1 + hdi_prob) / 2 * 100
    placebo_positions = np.flatnonzero(placebo_mask)
    mu_values = np.asarray(
        self.y_pred["posterior_predictive"].mu.isel(treated_units=0).values
    )
    y_observed = np.asarray(data[self.outcome_variable_name].values)
    # Residual draws for placebo rows only; the last axis follows
    # placebo_positions (observations are assumed to be the last axis of mu).
    tau_draws = y_observed[placebo_positions] - mu_values[:, :, placebo_positions]
    placebo_obs = data.iloc[placebo_positions]
    # .indices gives positions into placebo_obs, i.e. into tau_draws' last axis.
    cell_positions = placebo_obs.groupby(["G", self.time_variable_name]).indices
    att_gt_rows: list[dict[str, Any]] = []
    for (g_val, t_val), local_idx in cell_positions.items():
        tau_gt = tau_draws[:, :, local_idx].mean(axis=2)
        att_gt_rows.append(
            {
                "cohort": g_val,
                "time": t_val,
                "att": float(tau_gt.mean()),
                "att_lower": float(np.percentile(tau_gt, lower_pct)),
                "att_upper": float(np.percentile(tau_gt, upper_pct)),
                "n_obs": len(local_idx),
                "type": "placebo",
            }
        )
    if not att_gt_rows:
        # Every placebo row had a missing grouping key.
        return pd.DataFrame()
    return (
        pd.DataFrame(att_gt_rows)
        .sort_values(["cohort", "time"])
        .reset_index(drop=True)
    )
def _get_group_time_placebo_data_ols(self) -> pd.DataFrame:
    """Summarise OLS placebo residuals per cohort and calendar time.

    The residual ``outcome - y_hat0`` of each pre-treatment observation of an
    eventually-treated cohort is grouped by ``(G, time)`` and summarised by
    its mean, sample standard deviation and non-missing count.

    Returns
    -------
    pd.DataFrame
        Columns ``cohort``, ``time``, ``att``, ``att_std``, ``n_obs`` and
        ``type`` (always ``"placebo"``). An empty DataFrame when there are
        no placebo observations.
    """
    placebo_obs = self._get_group_time_placebo_observations()
    if not len(placebo_obs):
        return pd.DataFrame()
    residuals = placebo_obs[self.outcome_variable_name] - placebo_obs["y_hat0"]
    summary = (
        placebo_obs.assign(tau_hat=residuals)
        .groupby(["G", self.time_variable_name])["tau_hat"]
        .agg(["mean", "std", "count"])
        .reset_index()
    )
    # Positional renaming keeps this independent of the grouping column names.
    summary.columns = ["cohort", "time", "att", "att_std", "n_obs"]
    return summary.assign(type="placebo")
def _make_group_time_axes(
    self,
    att_gt: pd.DataFrame,
    layout: Literal["facet", "overlay"],
    figsize: tuple[float, float] | None,
    sharex: bool,
    sharey: bool,
) -> tuple[plt.Figure, dict[Any, plt.Axes]]:
    """Create axes for cohort trajectory plots.

    Parameters
    ----------
    att_gt : pd.DataFrame
        Cohort-time data; only its ``cohort`` column is read (facet layout).
    layout : {"facet", "overlay"}
        ``"facet"`` creates one row per cohort; ``"overlay"`` a single axis.
    figsize : tuple of (float, float) or None
        Figure size. ``None`` uses ``(10, 6)`` for ``"overlay"`` and
        ``(10, max(2.5 * n_cohorts, 3.0))`` for ``"facet"``.
    sharex, sharey : bool
        Axis sharing between facet rows.

    Returns
    -------
    tuple of (plt.Figure, dict)
        The figure and a mapping from cohort to its axis (``"facet"``), or
        ``{"overlay": ax}`` (``"overlay"``).

    Raises
    ------
    ValueError
        If ``layout`` is invalid, or a facet layout is requested but
        ``att_gt`` contains no cohorts.
    """
    if layout not in ("facet", "overlay"):
        raise ValueError("layout must be 'facet' or 'overlay'")
    if layout == "overlay":
        fig, ax = plt.subplots(
            1, 1, figsize=figsize or (10, 6), layout="constrained"
        )
        return fig, {"overlay": ax}
    # First-appearance order; callers pass data already sorted by cohort.
    cohorts = list(att_gt["cohort"].drop_duplicates())
    if not cohorts:
        # plt.subplots(0, 1) would fail with an opaque row-count error.
        raise ValueError("No cohort-time estimates available to plot.")
    fig_height = max(2.5 * len(cohorts), 3.0)
    fig, axes_arr = plt.subplots(
        len(cohorts),
        1,
        figsize=figsize or (10, fig_height),
        sharex=sharex,
        sharey=sharey,
        squeeze=False,
        layout="constrained",
    )
    return fig, {
        cohort: axes_arr[row_idx, 0] for row_idx, cohort in enumerate(cohorts)
    }
def _format_group_time_axis(
    self,
    ax: plt.Axes,
    cohort: Any | None,
    x_label: str,
    y_label: str,
    x_axis: Literal["event_time", "calendar_time"],
    treatment_time: Any,
) -> None:
    """Decorate a cohort-trajectory axis with reference lines and labels.

    A dashed horizontal zero line is always drawn. A red treatment-onset
    marker is drawn at ``-0.5`` on the event-time scale, or at
    ``treatment_time - 0.5`` on the calendar scale when a specific cohort
    is given (no marker for a calendar-scale overlay). The title names the
    cohort, or describes the overlay when ``cohort`` is ``None``.
    """
    ax.axhline(y=0, color="black", linestyle="--", linewidth=1, alpha=0.7)
    onset_x = None
    if x_axis == "event_time":
        onset_x = -0.5
    elif cohort is not None:
        onset_x = treatment_time - 0.5
    if onset_x is not None:
        ax.axvline(x=onset_x, color="red", linestyle="-", linewidth=1, alpha=0.5)
    ax.set_xlabel(x_label, fontsize=12)
    ax.set_ylabel(y_label, fontsize=12)
    if cohort is not None:
        ax.set_title(f"Cohort {cohort}", fontsize=12)
        return
    ax.set_title("Staggered DiD Cohort Trajectories", fontsize=14)
def _get_group_time_axis_label(
    self,
    x_label: str,
    layout: Literal["facet", "overlay"],
    sharex: bool,
    axis_index: int,
    n_axes: int,
) -> str:
    """Return the x-axis label, blanked on all but the last shared facet.

    With a facet layout and a shared x-axis only the bottom panel
    (``axis_index == n_axes - 1``) keeps the label; every other case keeps it.
    """
    is_bottom_axis = axis_index >= n_axes - 1
    suppress = layout == "facet" and sharex and not is_bottom_axis
    return "" if suppress else x_label
def _plot_bayesian_group_time_segment(
    self,
    ax: plt.Axes,
    cohort_data: pd.DataFrame,
    x_col: str,
    line_type: Literal["placebo", "ATT"],
    color: str,
    label: str,
) -> None:
    """Draw one Bayesian placebo or ATT trajectory with its interval band.

    Placebo segments use square markers, a dashed line and a lighter band;
    ATT segments use circles, a solid line and a slightly darker band. The
    band spans ``att_lower`` to ``att_upper``. Empty input draws nothing.
    """
    if len(cohort_data) == 0:
        return
    is_placebo = line_type == "placebo"
    x_values = cohort_data[x_col]
    ax.plot(
        x_values,
        cohort_data["att"],
        marker="s" if is_placebo else "o",
        linestyle="--" if is_placebo else "-",
        color=color,
        label=label,
    )
    ax.fill_between(
        x_values,
        cohort_data["att_lower"],
        cohort_data["att_upper"],
        color=color,
        alpha=0.15 if is_placebo else 0.2,
    )
def _plot_ols_group_time_segment(
    self,
    ax: plt.Axes,
    cohort_data: pd.DataFrame,
    x_col: str,
    line_type: Literal["placebo", "ATT"],
    color: str,
    label: str,
) -> None:
    """Draw one OLS placebo or ATT trajectory, with error bars when possible.

    When both ``att_std`` and ``n_obs`` are present the points carry
    ``1.96 * att_std / sqrt(n_obs)`` error bars; otherwise a plain line is
    drawn. Placebo segments use squares and dashes, ATT segments circles and
    a solid line. Empty input draws nothing.
    """
    if len(cohort_data) == 0:
        return
    is_placebo = line_type == "placebo"
    marker = "s" if is_placebo else "o"
    linestyle = "--" if is_placebo else "-"
    x_values = cohort_data[x_col]
    y_values = cohort_data["att"]
    if not {"att_std", "n_obs"}.issubset(cohort_data.columns):
        ax.plot(
            x_values,
            y_values,
            marker=marker,
            linestyle=linestyle,
            color=color,
            label=label,
        )
        return
    standard_error = cohort_data["att_std"] / np.sqrt(cohort_data["n_obs"])
    ax.errorbar(
        x_values,
        y_values,
        yerr=1.96 * standard_error,
        fmt=marker + linestyle,
        capsize=4,
        capthick=2,
        color=color,
        label=label,
    )
[docs]
def get_plot_data_bayesian(self, hdi_prob: float = HDI_PROB) -> pd.DataFrame:
    """Get plotting data for Bayesian model.

    Parameters
    ----------
    hdi_prob : float, optional
        Probability for HDI interval. Defaults to
        :data:`~causalpy.constants.HDI_PROB` (currently 0.94).

    Returns
    -------
    pd.DataFrame
        DataFrame with event_time, att, att_lower, att_upper columns.
        Includes both pre-treatment (placebo) and post-treatment effects.

    Notes
    -----
    If ``hdi_prob`` matches the probability stored as ``hdi_prob_``, a copy
    of the cached ``att_event_time_`` table is returned. Otherwise the
    intervals are recomputed from posterior draws of ``y - mu``. Each event
    time's draws are averaged over its observations, and equal-tailed
    percentile bounds are taken.
    """
    # If the requested hdi_prob matches what was used during aggregation,
    # return the pre-computed results
    stored_hdi_prob = getattr(self, "hdi_prob_", HDI_PROB)
    if np.isclose(hdi_prob, stored_hdi_prob):
        return self.att_event_time_.copy()
    lower_pct = (1 - hdi_prob) / 2 * 100
    upper_pct = (1 + hdi_prob) / 2 * 100
    mu_draws = self.y_pred["posterior_predictive"].mu.isel(treated_units=0)
    y_observed = np.asarray(self.data[self.outcome_variable_name].values)
    # Residual draws for every observation; last axis = row position.
    tau_draws_all = y_observed - mu_draws.values
    event_time_all = np.asarray(self.data["event_time"].values)
    # Work with positional boolean masks so rows map straight onto the
    # draws' last axis (no label lookups, safe for a non-unique index).
    pretreat_mask = np.asarray(
        (self.data["G"] != self.never_treated_value)
        & (self.data["event_time"] < 0),
        dtype=bool,
    )
    treated_mask = ~np.asarray(self.data["_is_untreated"].values, dtype=bool)
    att_et_rows: list[dict] = []
    # Pre-treatment placebo rows first, then post-treatment treated rows.
    for subset_mask in (pretreat_mask, treated_mask):
        subset_positions = np.flatnonzero(subset_mask)
        tau_draws_subset = tau_draws_all[:, :, subset_positions]
        event_time_subset = event_time_all[subset_positions]
        # np.unique returns the event times already sorted.
        event_times = np.unique(event_time_subset[~np.isnan(event_time_subset)])
        if self.event_window is not None:
            event_times = event_times[
                (event_times >= self.event_window[0])
                & (event_times <= self.event_window[1])
            ]
        for e in event_times:
            e_positions = np.flatnonzero(event_time_subset == e)
            tau_e = tau_draws_subset[:, :, e_positions].mean(axis=2)
            att_et_rows.append(
                {
                    "event_time": int(e),
                    "att": float(tau_e.mean()),
                    "att_lower": float(np.percentile(tau_e, lower_pct)),
                    "att_upper": float(np.percentile(tau_e, upper_pct)),
                    "n_obs": int(e_positions.size),
                }
            )
    return self._mark_non_identified_att_rows(pd.DataFrame(att_et_rows))
[docs]
def get_plot_data_ols(self) -> pd.DataFrame:
    """Return the OLS event-time estimates used for plotting.

    Returns
    -------
    pd.DataFrame
        An independent copy of ``att_event_time_`` (event_time, att,
        att_std, n_obs columns), so callers may modify it freely.
    """
    event_time_table = self.att_event_time_
    return event_time_table.copy()
[docs]
def effect_summary(
    self,
    *,
    direction: Literal["increase", "decrease", "two-sided"] = "increase",
    alpha: float = 0.05,
    min_effect: float | None = None,
    **kwargs: Any,
) -> EffectSummary:
    """
    Generate a decision-ready summary of causal effects for Staggered Difference-in-Differences.

    Parameters
    ----------
    direction : {"increase", "decrease", "two-sided"}, default="increase"
        Direction for tail probability calculation (PyMC only, ignored for OLS).
    alpha : float, default=0.05
        Significance level for HDI/CI intervals (1-alpha confidence level).
    min_effect : float, optional
        Region of Practical Equivalence (ROPE) threshold (PyMC only, ignored for OLS).
    **kwargs
        Reserved for forward-compatibility; not consumed by this
        implementation.

    Returns
    -------
    EffectSummary
        Object with .table (DataFrame) and .text (str) attributes
    """
    # Imported lazily, as in the original implementation.
    from causalpy.reporting import _effect_summary_staggered_did

    summary_options = {
        "direction": direction,
        "alpha": alpha,
        "min_effect": min_effect,
    }
    return _effect_summary_staggered_did(self, **summary_options)