Skip to content

Commit

Permalink
Drop the ALTER TABLE statements (build a temp table instead) to adhere to the rule of never writing to the source table
Browse files (browse the repository at this point in the history)
  • Loading branch information
apoorvalal committed Aug 23, 2024
1 parent 04f16a9 commit ccf6933
Show file tree
Hide file tree
Showing 2 changed files with 229 additions and 55 deletions.
27 changes: 13 additions & 14 deletions duckreg/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,27 +439,26 @@ def __init__(

def prepare_data(self):
# create_cohort_and_ever_treated_columns
self.cohort_query = f"""
ALTER TABLE {self.table_name} ADD COLUMN cohort INTEGER;
UPDATE {self.table_name} SET cohort = (
self.temp_table_query = f"""
CREATE TEMP TABLE temp_{self.table_name} AS
SELECT *, NULL::INTEGER AS cohort, NULL::INTEGER AS ever_treated
FROM {self.table_name};
UPDATE temp_{self.table_name} SET cohort = (
SELECT MIN({self.time_col})
FROM {self.table_name} AS p2
WHERE p2.{self.unit_col} = {self.table_name}.{self.unit_col} AND p2.{self.treatment_col} = 1
WHERE p2.{self.unit_col} = temp_{self.table_name}.{self.unit_col} AND
p2.{self.treatment_col} = 1
);
UPDATE temp_{self.table_name} SET cohort = NULL WHERE cohort = 2147483647;
UPDATE temp_{self.table_name} SET ever_treated = CASE WHEN cohort IS NOT NULL THEN 1 ELSE 0 END;
"""
self.conn.execute(self.cohort_query)
self.ever_treated_query = f"""
UPDATE {self.table_name} SET cohort = NULL WHERE cohort = 2147483647; -- Set to NULL if never treated
ALTER TABLE {self.table_name} ADD COLUMN ever_treated INTEGER;
UPDATE {self.table_name} SET ever_treated = CASE WHEN cohort IS NOT NULL THEN 1 ELSE 0 END;
"""
self.conn.execute(self.ever_treated_query)
self.conn.execute(self.temp_table_query)
# retrieve_num_periods_and_cohorts
self.num_periods = self.conn.execute(
f"SELECT MAX({self.time_col}) FROM {self.table_name}"
f"SELECT MAX({self.time_col}) FROM temp_{self.table_name}"
).fetchone()[0]
cohorts = self.conn.execute(
f"SELECT DISTINCT cohort FROM {self.table_name} WHERE cohort IS NOT NULL"
f"SELECT DISTINCT cohort FROM temp_{self.table_name} WHERE cohort IS NOT NULL"
).fetchall()
self.cohorts = [row[0] for row in cohorts]
# generate_time_dummies
Expand Down Expand Up @@ -506,7 +505,7 @@ def prepare_data(self):
-- Treated group interacted with treatment time dummies
{self.treatment_dummies}
FROM
{self.table_name} p;
temp_{self.table_name} p;
"""
self.conn.execute(self.design_matrix_query)

Expand Down
Loading

0 comments on commit ccf6933

Please sign in to comment.