diff --git a/duckreg/duckreg.py b/duckreg/duckreg.py
index 54b4dd7..158a5a1 100644
--- a/duckreg/duckreg.py
+++ b/duckreg/duckreg.py
@@ -12,6 +12,7 @@ def __init__(
seed: int,
n_bootstraps: int = 100,
fitter="numpy",
+ keep_connection_open=False,
):
self.db_name = db_name
self.table_name = table_name
@@ -20,6 +21,7 @@ def __init__(
self.conn = duckdb.connect(db_name)
self.rng = np.random.default_rng(seed)
self.fitter = fitter
+ self.keep_connection_open = keep_connection_open
@abstractmethod
def prepare_data(self):
@@ -54,6 +56,7 @@ def fit(self):
self.point_estimate = self.estimate()
if self.n_bootstraps > 0:
self.vcov = self.bootstrap()
+ self.conn.close() if not self.keep_connection_open else None
return None
elif self.fitter == "feols":
fit = self.estimate_feols()
@@ -64,6 +67,7 @@ def fit(self):
fit.get_inference()
fit._vcov_type = "NP-Bootstrap"
fit._vcov_type_detail = "NP-Bootstrap"
+ self.conn.close() if not self.keep_connection_open else None
return fit
else:
@@ -73,7 +77,12 @@ def fit(self):
)
)
- def summary(self):
+ def summary(self) -> dict:
+ """Summary of regression
+
+ Returns:
+ dict
+ """
if self.n_bootstraps > 0:
return {
"point_estimate": self.point_estimate,
@@ -81,9 +90,22 @@ def summary(self):
}
return {"point_estimate": self.point_estimate}
+    def queries(self) -> dict:
+        """Collect every attribute of this instance whose name contains "query".
+
+        The matching names are also stored on ``self._query_names``.
+
+        Returns:
+            dict: Mapping of attribute name to the attribute's current value.
+        """
+        # ``_query_names`` is this method's own bookkeeping attribute and its
+        # name contains "query" too; skip it so repeated calls agree.
+        self._query_names = [
+            name for name in dir(self) if "query" in name and name != "_query_names"
+        ]
+        # Return the mapping rather than assigning it to ``self.queries``: that
+        # assignment replaced this method with a dict on the instance, so a
+        # second call failed with "'dict' object is not callable".
+        return {name: getattr(self, name) for name in self._query_names}
-def wls(X: np.ndarray, y: np.ndarray, n: np.ndarray) -> np.ndarray:
+
+def wls(X: np.ndarray, y: np.ndarray, n: np.ndarray) -> np.ndarray:
+ """Weighted least squares with frequency weights"""
N = np.sqrt(n)
N = N.reshape(-1, 1) if N.ndim == 1 else N
Xn = X * N
diff --git a/duckreg/estimators.py b/duckreg/estimators.py
index 5529acc..70167ac 100644
--- a/duckreg/estimators.py
+++ b/duckreg/estimators.py
@@ -18,8 +18,10 @@ def __init__(
cluster_col: str,
seed: int,
n_bootstraps: int = 100,
+ event_study: bool = False,
rowid_col: str = "rowid",
fitter: str = "numpy",
+ **kwargs,
):
super().__init__(
db_name=db_name,
@@ -27,6 +29,7 @@ def __init__(
seed=seed,
n_bootstraps=n_bootstraps,
fitter=fitter,
+ **kwargs,
)
self.formula = formula
self.cluster_col = cluster_col
@@ -34,7 +37,6 @@ def __init__(
self._parse_formula()
def _parse_formula(self):
-
lhs, rhs = self.formula.split("~")
rhs_deparsed = rhs.split("|")
covars, fevars = rhs.split("|") if len(rhs_deparsed) > 1 else (rhs, None)
@@ -76,7 +78,6 @@ def compress_data(self):
self.df_compressed.eval(create_means, inplace=True)
def collect_data(self, data: pd.DataFrame) -> pd.DataFrame:
-
y = data.filter(
regex=f"mean_{'(' + '|'.join(self.outcome_vars) + ')'}", axis=1
).values
@@ -124,7 +125,6 @@ def fit_vcov(self):
self.vcov = n_nk * (bread @ meat @ bread)
def estimate_feols(self):
-
if self.fevars:
fml = f"{'+'.join([f'mean_{x}' for x in self.outcome_vars])} ~ {' + '.join(self.covars)} | {' + '.join(self.fevars)}"
else:
@@ -234,12 +234,14 @@ def __init__(
time_col: str = None,
n_bootstraps: int = 100,
cluster_col: str = None,
+ **kwargs,
):
super().__init__(
db_name=db_name,
table_name=table_name,
seed=seed,
n_bootstraps=n_bootstraps,
+ **kwargs,
)
self.outcome_var = outcome_var
self.covariates = covariates
@@ -299,12 +301,12 @@ def compress_data(self):
{', ' + ', '.join([f'avg_{cov}_time' for cov in self.covariates]) if self.time_col is not None else ''}
"""
self.df_compressed = self.conn.execute(self.compress_query).fetchdf()
+
self.df_compressed[f"mean_{self.outcome_var}"] = (
self.df_compressed[f"sum_{self.outcome_var}"] / self.df_compressed["count"]
)
def collect_data(self, data: pd.DataFrame):
-
rhs = (
self.covariates
+ [f"avg_{cov}_unit" for cov in self.covariates]
@@ -326,7 +328,6 @@ def collect_data(self, data: pd.DataFrame):
return y, X, n
def estimate(self):
-
y, X, n = self.collect_data(data=self.df_compressed)
return wls(X, y, n)
@@ -402,8 +403,259 @@ def bootstrap(self):
################################################################################
+class DuckMundlakEventStudy(DuckReg):
+    """Event-study regression on compressed data with cohort x period dummies.
+
+    The design matrix holds an intercept, one dummy per treatment cohort, one
+    dummy per period and one dummy per (cohort, period) pair. For each cohort
+    the reported event-study path is the (cohort, period) coefficient plus the
+    cohort intercept.
+
+    NOTE(review): period dummies are generated for ``0..MAX(time_col)``, so the
+    time column is assumed to hold non-negative integers - confirm for callers
+    whose periods do not start at 0.
+    """
+
+    def __init__(
+        self,
+        db_name: str,
+        table_name: str,
+        outcome_var: str,
+        treatment_col: str,
+        unit_col: str,
+        time_col: str,
+        cluster_col: str,
+        n_bootstraps: int = 100,
+        **kwargs,
+    ):
+        """Store column names; remaining keyword arguments go to ``DuckReg``.
+
+        Args:
+            db_name: DuckDB database to connect to.
+            table_name: Table holding the panel.
+            outcome_var: Outcome column.
+            treatment_col: Treatment indicator column (treated rows equal 1).
+            unit_col: Unit identifier column.
+            time_col: Period column.
+            cluster_col: Column whose values are resampled in ``bootstrap``.
+            n_bootstraps: Number of bootstrap replicates.
+            **kwargs: Forwarded to the parent constructor (e.g. ``seed``).
+        """
+        super().__init__(
+            db_name=db_name,
+            table_name=table_name,
+            n_bootstraps=n_bootstraps,
+            **kwargs,
+        )
+        self.table_name = table_name
+        self.outcome_var = outcome_var
+        self.treatment_col = treatment_col
+        self.unit_col = unit_col
+        self.time_col = time_col
+        self.cluster_col = cluster_col
+        self.num_periods = None
+        self.cohorts = None
+        self.time_dummies = None
+        self.post_treatment_dummies = None
+        self.transformed_query = None
+        self.compression_query = None
+
+    def _periods(self) -> range:
+        """Return the period indices ``0..num_periods`` used to name dummies."""
+        return range(self.num_periods + 1)
+
+    def prepare_data(self):
+        """Add cohort columns to the table and build ``transformed_panel_data``.
+
+        Raises:
+            ValueError: If the table has no rows or no unit is ever treated.
+        """
+        # Cohort = first period in which the unit is treated. MIN over no rows
+        # is NULL, so never-treated units get a NULL cohort without a sentinel.
+        # IF NOT EXISTS keeps a second run on the same database from failing.
+        self.cohort_query = f"""
+        ALTER TABLE {self.table_name} ADD COLUMN IF NOT EXISTS cohort INTEGER;
+        UPDATE {self.table_name} SET cohort = (
+            SELECT MIN({self.time_col})
+            FROM {self.table_name} AS p2
+            WHERE p2.{self.unit_col} = {self.table_name}.{self.unit_col} AND p2.{self.treatment_col} = 1
+        );
+        """
+        self.conn.execute(self.cohort_query)
+        self.ever_treated_query = f"""
+        ALTER TABLE {self.table_name} ADD COLUMN IF NOT EXISTS ever_treated INTEGER;
+        UPDATE {self.table_name} SET ever_treated = CASE WHEN cohort IS NOT NULL THEN 1 ELSE 0 END;
+        """
+        self.conn.execute(self.ever_treated_query)
+
+        self.num_periods = self.conn.execute(
+            f"SELECT MAX({self.time_col}) FROM {self.table_name}"
+        ).fetchone()[0]
+        if self.num_periods is None:
+            raise ValueError(f"Table {self.table_name} has no rows to estimate on")
+        # ORDER BY makes the column order of the design matrix deterministic.
+        cohort_rows = self.conn.execute(
+            f"SELECT DISTINCT cohort FROM {self.table_name} WHERE cohort IS NOT NULL ORDER BY cohort"
+        ).fetchall()
+        self.cohorts = [row[0] for row in cohort_rows]
+        if not self.cohorts:
+            # Without this check the SELECT below has an empty list entry and
+            # fails with an opaque SQL parser error.
+            raise ValueError(
+                f"No treated units found: {self.treatment_col} never equals 1 in {self.table_name}"
+            )
+
+        periods = self._periods()
+        self.time_dummies = ",\n".join(
+            f"CASE WHEN {self.time_col} = {i} THEN 1 ELSE 0 END AS time_{i}"
+            for i in periods
+        )
+        self.cohort_intercepts = ",\n".join(
+            f"CASE WHEN cohort = {cohort} THEN 1 ELSE 0 END AS cohort_{cohort}"
+            for cohort in self.cohorts
+        )
+        self.treatment_dummies = ",\n".join(
+            f"CASE WHEN cohort = {cohort} AND {self.time_col} = {i} THEN 1 ELSE 0 END AS treatment_time_{cohort}_{i}"
+            for cohort in self.cohorts
+            for i in periods
+        )
+
+        # bootstrap() reads the cluster column from the transformed table, so
+        # carry it along when it is not already one of the selected columns.
+        id_cols = [self.unit_col, self.time_col, self.treatment_col, self.outcome_var]
+        if self.cluster_col is not None and self.cluster_col not in id_cols:
+            id_cols.append(self.cluster_col)
+        id_select = ",\n".join(f"p.{col}" for col in id_cols)
+        self.design_matrix_query = f"""
+        CREATE OR REPLACE TEMP TABLE transformed_panel_data AS
+        SELECT
+            {id_select},
+            -- Intercept (constant term)
+            1 AS intercept,
+            -- cohort intercepts
+            {self.cohort_intercepts},
+            -- Time dummies for each period
+            {self.time_dummies},
+            -- Treated group interacted with treatment time dummies
+            {self.treatment_dummies}
+        FROM
+            {self.table_name} p;
+        """
+        self.conn.execute(self.design_matrix_query)
+
+    def compress_data(self):
+        """Group the design matrix into unique rows with counts and outcome sums."""
+        periods = self._periods()
+        self._rhs_list = (
+            ["intercept"]
+            + [f"cohort_{cohort}" for cohort in self.cohorts]
+            + [f"time_{i}" for i in periods]
+            + [f"treatment_time_{cohort}_{i}" for cohort in self.cohorts for i in periods]
+        )
+        self.rhs = ", ".join(self._rhs_list)
+        self.compression_query = f"""
+        CREATE OR REPLACE TEMP TABLE compressed_panel_data AS
+        SELECT
+            {self.rhs},
+            COUNT(*) AS count,
+            SUM({self.outcome_var}) AS sum_{self.outcome_var}
+        FROM
+            transformed_panel_data
+        GROUP BY
+            {self.rhs}
+        """
+        self.conn.execute(self.compression_query)
+        self.df_compressed = self.conn.execute(
+            "SELECT * FROM compressed_panel_data"
+        ).fetchdf()
+        self.df_compressed[f"mean_{self.outcome_var}"] = (
+            self.df_compressed[f"sum_{self.outcome_var}"] / self.df_compressed["count"]
+        )
+
+    def collect_data(self, data):
+        """Split a compressed frame into outcome means, regressors and counts.
+
+        Args:
+            data: Frame with the design columns, ``mean_<outcome>`` and ``count``.
+
+        Returns:
+            Tuple ``(y, X, n)``; ``y`` and ``X`` are reshaped to 2-D if needed.
+        """
+        # Derived from self.rhs so the method also works when rhs was set by hand.
+        self._rhs_list = [col.strip() for col in self.rhs.replace(";", "").split(",")]
+        X = data[self._rhs_list].values
+        y = data[f"mean_{self.outcome_var}"].values
+        n = data["count"].values
+
+        y = y.reshape(-1, 1) if y.ndim == 1 else y
+        X = X.reshape(-1, 1) if X.ndim == 1 else X
+        return y, X, n
+
+    def _event_study_coefs(self, coef) -> dict:
+        """Turn a coefficient vector into per-cohort event-study paths.
+
+        Returns:
+            dict keyed by cohort (as ``str``); each value is a one-column
+            (``est``) DataFrame of ``treatment_time_<cohort>_<i>`` coefficients
+            with the cohort intercept added.
+        """
+        res = pd.DataFrame({"est": np.asarray(coef).ravel()}, index=self._rhs_list)
+        periods = self._periods()
+        paths = {}
+        for cohort in self.cohorts:
+            c = str(cohort)
+            # Exact label lookups: a "^cohort_1" prefix regex would also match
+            # "cohort_15" and corrupt the offset when cohort labels share a prefix.
+            offset = res.loc[f"cohort_{c}", "est"]
+            paths[c] = res.loc[[f"treatment_time_{c}_{i}" for i in periods]] + offset
+        return paths
+
+    def estimate(self):
+        """Fit weighted least squares on the compressed data.
+
+        Returns:
+            dict of per-cohort event-study coefficient frames.
+        """
+        y, X, n = self.collect_data(data=self.df_compressed)
+        return self._event_study_coefs(wls(X, y, n))
+
+    def bootstrap(self):
+        """Cluster bootstrap of the event-study coefficients.
+
+        Clusters are drawn with replacement using ``self.rng``. A cluster drawn
+        k times contributes its rows with weight k to the compressed counts and
+        outcome sums, which is equivalent to replicating its rows k times.
+
+        Returns:
+            dict keyed by cohort (as ``str``) of bootstrap covariance matrices.
+        """
+        # Sorted so that, for a fixed seed, the draws do not depend on the
+        # order in which the database happens to return the clusters.
+        clusters = (
+            self.conn.execute(
+                f"SELECT DISTINCT {self.cluster_col} FROM transformed_panel_data ORDER BY 1"
+            )
+            .fetchdf()
+            .iloc[:, 0]
+            .to_numpy()
+        )
+        design_cols = self.rhs.replace(";", "")
+        boot_query = f"""
+        SELECT
+            {design_cols},
+            CAST(SUM(b.boot_weight) AS BIGINT) AS count,
+            SUM(t.{self.outcome_var} * b.boot_weight) AS sum_{self.outcome_var}
+        FROM
+            transformed_panel_data AS t
+        JOIN
+            bootstrap_draws AS b ON t.{self.cluster_col} = b.boot_cluster
+        GROUP BY
+            {design_cols}
+        """
+        boot_coefs = {str(cohort): [] for cohort in self.cohorts}
+        for _ in tqdm(range(self.n_bootstraps)):
+            drawn = self.rng.choice(clusters, size=clusters.size, replace=True)
+            ids, weights = np.unique(drawn, return_counts=True)
+            draws = pd.DataFrame({"boot_cluster": ids, "boot_weight": weights})
+            self.conn.register("bootstrap_draws", draws)
+            try:
+                df_boot = self.conn.execute(boot_query).fetchdf()
+            finally:
+                # Always drop the view so a failed replicate leaves nothing behind.
+                self.conn.unregister("bootstrap_draws")
+            df_boot[f"mean_{self.outcome_var}"] = (
+                df_boot[f"sum_{self.outcome_var}"] / df_boot["count"]
+            )
+
+            y, X, n = self.collect_data(data=df_boot)
+            for cohort, path in self._event_study_coefs(wls(X, y, n)).items():
+                boot_coefs[cohort].append(path.values.flatten())
+
+        # atleast_2d: np.cov returns a 0-d array when there is a single coefficient.
+        return {
+            cohort: np.atleast_2d(np.cov(np.array(coefs).T))
+            for cohort, coefs in boot_coefs.items()
+        }
+
+    def estimate_feols(self):
+        """Not available for this estimator.
+
+        Raises:
+            NotImplementedError: Always.
+        """
+        raise NotImplementedError(
+            "feols solver not implemented for Mundlak event study estimator"
+        )
+
+    def summary(self) -> dict:
+        """Summary of event study regression (overrides the parent class method)
+
+        Returns:
+            With bootstrapping enabled, a dict keyed by cohort of DataFrames
+            with columns ``point_estimate`` and ``se``; otherwise
+            ``{"point_estimate": <dict of per-cohort coefficient frames>}``.
+        """
+        if self.n_bootstraps > 0:
+            summary_tables = {}
+            for cohort, point_estimate in self.point_estimate.items():
+                se = np.sqrt(np.diag(self.vcov[cohort]))
+                summary_tables[cohort] = pd.DataFrame(
+                    np.c_[point_estimate, se],
+                    columns=["point_estimate", "se"],
+                    index=point_estimate.index,
+                )
+            return summary_tables
+        return {"point_estimate": self.point_estimate}
+
+
+################################################################################
class DuckDoubleDemeaning(DuckReg):
def __init__(
self,
@@ -416,12 +668,14 @@ def __init__(
seed: int,
n_bootstraps: int = 100,
cluster_col: str = None,
+ **kwargs,
):
super().__init__(
db_name=db_name,
table_name=table_name,
seed=seed,
n_bootstraps=n_bootstraps,
+ **kwargs,
)
self.outcome_var = outcome_var
self.treatment_var = treatment_var
@@ -486,7 +740,6 @@ def compress_data(self):
)
def collect_data(self, data: pd.DataFrame):
-
X = data[f"ddot_{self.treatment_var}"].values
X = np.c_[np.ones(X.shape[0]), X]
y = data[f"mean_{self.outcome_var}"].values
@@ -558,4 +811,4 @@ def bootstrap(self):
return np.cov(boot_coefs.T)
-######################################################################
+################################################################################
diff --git a/notebooks/event_study.ipynb b/notebooks/event_study.ipynb
new file mode 100644
index 0000000..6166df5
--- /dev/null
+++ b/notebooks/event_study.ipynb
@@ -0,0 +1,828 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# `duckreg` for panel data: Applying Mundlak Regression to estimate Event Studies at scale"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n",
+ " \n",
+ " "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ " "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n",
+ "from duckreg.estimators import DuckMundlak, DuckMundlakEventStudy\n",
+ "import duckdb\n",
+ "import pyfixest as pf\n",
+ "\n",
+ "np.random.seed(42)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def panel_dgp(\n",
+ " num_units=1000,\n",
+ " num_periods=30,\n",
+ " num_treated=50,\n",
+ " treatment_start=15,\n",
+ " hetfx=False,\n",
+ " base_treatment_effect=np.repeat(0, 15),\n",
+ " ar_coef=0.2, # Autoregressive coefficient for epsilon_it\n",
+ " sigma_unit=1,\n",
+ " sigma_time=0.5,\n",
+ " sigma_epsilon=0.5, # Standard deviation of epsilon_it\n",
+ "):\n",
+ " unit_intercepts = np.random.normal(0, sigma_unit, num_units)\n",
+ "\n",
+ " # Generate day-of-the-week pattern\n",
+ " day_effects = np.array(\n",
+ " [-0.1, 0.1, 0, 0, 0.1, 0.5, 0.5]\n",
+ " ) # Stronger effects on weekends\n",
+ " day_pattern = np.tile(day_effects, num_periods // 7 + 1)[:num_periods]\n",
+ "\n",
+ " # Generate autoregressive structure\n",
+ " ar_coef_time = 0.2\n",
+ " ar_noise_time = np.random.normal(0, sigma_time, num_periods)\n",
+ " time_intercepts = np.zeros(num_periods)\n",
+ " time_intercepts[0] = ar_noise_time[0]\n",
+ " for t in range(1, num_periods):\n",
+ " time_intercepts[t] = ar_coef_time * time_intercepts[t - 1] + ar_noise_time[t]\n",
+ " # Combine day-of-the-week pattern and autoregressive structure\n",
+ " time_intercepts = day_pattern + time_intercepts - np.mean(time_intercepts)\n",
+ " # Generate autoregressive noise for each unit\n",
+ " ar_noise = np.random.normal(0, sigma_epsilon, (num_units, num_periods))\n",
+ " noise = np.zeros((num_units, num_periods))\n",
+ " noise[:, 0] = ar_noise[:, 0]\n",
+ " for t in range(1, num_periods):\n",
+ " noise[:, t] = ar_coef * noise[:, t - 1] + ar_noise[:, t]\n",
+ " # N X T matrix of potential outcomes under control\n",
+ " Y0 = unit_intercepts[:, np.newaxis] + time_intercepts[np.newaxis, :] + noise\n",
+ " # Generate the base treatment effect (concave structure)\n",
+ " # Generate heterogeneous multipliers for each unit\n",
+ " if hetfx:\n",
+ " heterogeneous_multipliers = np.random.uniform(0.5, 1.5, num_units)\n",
+ " else:\n",
+ " heterogeneous_multipliers = np.ones(num_units)\n",
+ "\n",
+ " # Create a 2D array to store the heterogeneous treatment effects\n",
+ " treatment_effect = np.zeros((num_units, num_periods - treatment_start))\n",
+ " for i in range(num_units):\n",
+ " treatment_effect[i, :] = heterogeneous_multipliers[i] * base_treatment_effect\n",
+ "\n",
+ " # random assignment\n",
+ " treated_units = np.random.choice(num_units, num_treated, replace=False)\n",
+ " treatment_status = np.zeros((num_units, num_periods), dtype=bool)\n",
+ " treatment_status[treated_units, treatment_start:] = True\n",
+ "\n",
+ " # Apply the heterogeneous treatment effect to the treated units\n",
+ " Y1 = Y0.copy()\n",
+ " for t in range(treatment_start, num_periods):\n",
+ " Y1[:, t][treatment_status[:, t]] += treatment_effect[:, t - treatment_start][\n",
+ " treatment_status[:, t]\n",
+ " ]\n",
+ "\n",
+ " # Create a DataFrame\n",
+ " unit_ids = np.repeat(np.arange(num_units), num_periods)\n",
+ " time_ids = np.tile(np.arange(num_periods), num_units)\n",
+ " W_it = treatment_status.flatten()\n",
+ " Y_it = np.where(W_it, Y1.flatten(), Y0.flatten())\n",
+ " df = pd.DataFrame(\n",
+ " {\n",
+ " \"unit_id\": unit_ids,\n",
+ " \"time_id\": time_ids,\n",
+ " \"W_it\": W_it.astype(int),\n",
+ " \"Y_it\": Y_it,\n",
+ " }\n",
+ " )\n",
+ " return df\n",
+ "\n",
+ "\n",
+ "# Function to create and populate DuckDB database\n",
+ "def create_duckdb_database(df, db_name=\"large_dataset.db\", table=\"panel_data\"):\n",
+ " conn = duckdb.connect(db_name)\n",
+ " conn.execute(f\"DROP TABLE IF EXISTS {table}\")\n",
+ " conn.execute(f\"CREATE TABLE {table} AS SELECT * FROM df\")\n",
+ " conn.close()\n",
+ " print(f\"Data loaded into DuckDB database: {db_name}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "num_periods = 30\n",
+ "treat_start_period = 15\n",
+ "treat_effect_vector= 0.2 * np.log(2 * np.arange(1, num_periods - treat_start_period + 1))\n",
+ "treat_effect_vector[8:] = 0 # switch off effects after a week\n",
+ "sigma_i, sigma_t = 2, 1\n",
+ "event_study_true = np.r_[np.repeat(0, num_periods-treat_start_period), treat_effect_vector]\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Data loaded into DuckDB database: event_study_data.db\n"
+ ]
+ }
+ ],
+ "source": [
+ "df = panel_dgp(\n",
+ " num_units=10_000, num_treated= 5_000,\n",
+ " num_periods=30,\n",
+ " treatment_start = treat_start_period,\n",
+ " hetfx=False,\n",
+ " base_treatment_effect = treat_effect_vector,\n",
+ " sigma_unit = sigma_i, sigma_time = sigma_t,\n",
+ ")\n",
+ "\n",
+ "db_name = 'event_study_data.db'\n",
+ "create_duckdb_database(df, db_name)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " unit_id | \n",
+ " time_id | \n",
+ " W_it | \n",
+ " Y_it | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0.436139 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 1.016883 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0 | \n",
+ " 2 | \n",
+ " 0 | \n",
+ " 0.480516 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0 | \n",
+ " 3 | \n",
+ " 0 | \n",
+ " 0.578209 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0 | \n",
+ " 4 | \n",
+ " 0 | \n",
+ " 1.066747 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " unit_id time_id W_it Y_it\n",
+ "0 0 0 0 0.436139\n",
+ "1 0 1 0 1.016883\n",
+ "2 0 2 0 0.480516\n",
+ "3 0 3 0 0.578209\n",
+ "4 0 4 0 1.066747"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Static Specification\n",
+ "\n",
+ "### Two-way Mundlak"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 100/100 [00:01<00:00, 65.08it/s]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "point_estimate 0.203918\n",
+ "standard_error 0.003637\n",
+ "Name: 1, dtype: float64"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "mundlak = DuckMundlak(\n",
+ " db_name=\"event_study_data.db\",\n",
+ " table_name=\"panel_data\",\n",
+ " outcome_var=\"Y_it\",\n",
+ " covariates=[\"W_it\"],\n",
+ " unit_col=\"unit_id\",\n",
+ " time_col=\"time_id\",\n",
+ " cluster_col=\"unit_id\",\n",
+ " n_bootstraps=100,\n",
+ " seed = 42\n",
+ ")\n",
+ "mundlak.fit()\n",
+ "\n",
+ "mundlak_results = mundlak.summary()\n",
+ "\n",
+ "restab = pd.DataFrame(\n",
+ " np.c_[mundlak_results[\"point_estimate\"], mundlak_results[\"standard_error\"]],\n",
+ " columns=[\"point_estimate\", \"standard_error\"],\n",
+ ").iloc[1, :]\n",
+ "restab"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "0.21533040462966416"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "treat_effect_vector.mean()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The Two-way Mundlak specification consistently recovers the average treatment effect in the post-treatment period. Under staggered adoption, however, this guarantee is lost. More on this later. "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Dynamic Specification\n",
+ "\n",
+ "### Single Treatment Cohort"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df[\"ever_treated\"] = df.groupby(\"unit_id\")[\"W_it\"].transform(\"max\")\n",
+ "m2 = pf.feols(\"\"\"\n",
+ " Y_it ~ i(time_id, ever_treated, ref = 14) |\n",
+ " unit_id + time_id\n",
+ " \"\"\",\n",
+ " df\n",
+ " )\n",
+ "evstudy_coefs = m2.coef().values\n",
+ "# insert zero in reference period\n",
+ "evstudy_coefs = np.insert(evstudy_coefs, 14, 0)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### compressed estimation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " unit_id time_id W_it Y_it\n",
+ "0 0 0 0 0.436139\n",
+ "1 0 1 0 1.016883\n",
+ "2 0 2 0 0.480516\n",
+ "3 0 3 0 0.578209\n",
+ "4 0 4 0 1.066747\n"
+ ]
+ }
+ ],
+ "source": [
+ "conn = duckdb.connect(\"event_study_data.db\")\n",
+ "print(conn.execute(\"SELECT * FROM panel_data LIMIT 5\").fetchdf())\n",
+ "conn.close()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mundlak = DuckMundlakEventStudy(\n",
+ " db_name=\"event_study_data.db\",\n",
+ " table_name=\"panel_data\",\n",
+ " outcome_var=\"Y_it\",\n",
+ " treatment_col=\"W_it\",\n",
+ " unit_col=\"unit_id\",\n",
+ " time_col=\"time_id\",\n",
+ " cluster_col=\"unit_id\",\n",
+ " n_bootstraps=0, # set to nonzero to get block-bootstrapped standard errors\n",
+ " seed=42,\n",
+ ")\n",
+ "\n",
+ "mundlak.fit()\n",
+ "evsum = mundlak.summary()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ "