From f20fc361e5f6658309723a924e42ff7f544fedce Mon Sep 17 00:00:00 2001 From: Apoorva Lal Date: Mon, 5 Aug 2024 22:39:24 -0700 Subject: [PATCH] configure tests --- duckreg/estimators.py | 7 ++- tests/conftest.py | 15 +++++++ tests/test_fitter.py | 89 +++++++++++++++++++++++++-------------- tests/test_vs_pyfixest.py | 66 ----------------------------- tests/utils.py | 24 +++++++---- 5 files changed, 93 insertions(+), 108 deletions(-) create mode 100644 tests/conftest.py delete mode 100644 tests/test_vs_pyfixest.py diff --git a/duckreg/estimators.py b/duckreg/estimators.py index 5b69d56..2daa59f 100644 --- a/duckreg/estimators.py +++ b/duckreg/estimators.py @@ -141,6 +141,7 @@ def estimate_feols(self): return fit def bootstrap(self): + self.se = "bootstrap" if self.fevars: boot_coefs = np.zeros( (self.n_bootstraps, len(self.covars) * len(self.outcome_vars)) @@ -207,8 +208,10 @@ def bootstrap(self): return vcov - def summary(self): # ovveride the summary method to include the heteroskedasticity-robust variance covariance matrix when available - if self.n_bootstraps > 0 or self.se == "hc1": + def summary( + self, + ): # ovveride the summary method to include the heteroskedasticity-robust variance covariance matrix when available + if self.n_bootstraps > 0 or (hasattr(self, "se") and self.se == "hc1"): return { "point_estimate": self.point_estimate, "standard_error": np.sqrt(np.diag(self.vcov)), diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..904866f --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,15 @@ +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--force-regen", + action="store_true", + default=False, + help="Force regeneration of test data", + ) + + +@pytest.fixture(scope="session") +def force_regen(request): + return request.config.getoption("--force-regen") diff --git a/tests/test_fitter.py b/tests/test_fitter.py index 3b2a206..c81e7d6 100644 --- a/tests/test_fitter.py +++ b/tests/test_fitter.py @@ -1,47 +1,74 @@ -import pytest import numpy as np +import pytest +import os from duckreg.estimators import DuckRegression from tests.utils import generate_sample_data, create_duckdb_database +import duckdb +import pandas as pd + @pytest.fixture(scope="session") -def database(): - df = generate_sample_data() - db_name = 'test_dataset.db' - create_duckdb_database(df, db_name) +def get_data(force_regen): + if force_regen: + return generate_sample_data(1_000_000, seed=42) + else: + return generate_sample_data(1_000_000, seed=42) -@pytest.mark.parametrize("fml", ["Y ~ D", "Y ~ D + f1", "Y ~ D + f1 + f2"]) -@pytest.mark.parametrize("cluster_col", ["f1"]) -def test_fitters(fml, cluster_col): - m_duck = DuckRegression( - db_name='test_dataset.db', - table_name='data', - formula=fml, - cluster_col=cluster_col, - n_bootstraps=20, - seed = 42 - ) - m_duck.fit() +@pytest.fixture(scope="session") +def database(get_data, force_regen): + df = get_data + db_name = "test_dataset.db" + if force_regen and os.path.exists(db_name): + os.remove(db_name) + db_path = create_duckdb_database(df, db_name) + return db_path - m_feols = DuckRegression( - db_name='test_dataset.db', - table_name='data', - formula=fml, - cluster_col=cluster_col, - n_bootstraps=20, - seed = 42, - fitter = "feols" - ).fit() +def get_numpy_coefficients(db_path, formula): + conn = duckdb.connect(db_path) + df = conn.execute("SELECT * FROM data").df() + conn.close() - results = m_duck.summary() - coefs = results["point_estimate"] - se = results["standard_error"] + y = df["Y"].values + X_cols = [x.strip() for x in formula.split("~")[1].strip().split("+")] + X = df[X_cols].values + X = np.column_stack([np.ones(X.shape[0]), X]) + + coeffs = np.linalg.inv(X.T @ X) @ X.T @ y + return coeffs[1:] - assert np.all(np.abs(coefs) - np.abs(m_feols.coef().values) < 1e-12), "Coeficients are not equal" - assert np.all(np.abs(se) - np.abs(m_feols.se().values) < 1e-12), "Standard errors are not equal" +@pytest.mark.parametrize( + "fml", + [ + "Y ~ D", + "Y ~ D + f1", + "Y ~ D + f1 + f2", + ], +) +def test_fitters(database, fml): + db_path = database + uncompressed_coeffs = get_numpy_coefficients(db_path, fml) + + m_duck = DuckRegression( + db_name=db_path, + table_name="data", + formula=fml, + cluster_col="", + n_bootstraps=0, + seed=42, + ) + m_duck.fit() + np.testing.assert_allclose( + m_duck.df_compressed["count"].sum(), 1_000_000, rtol=1e-4 + ), "Number of observations are not equal" + results = m_duck.summary() + compressed_coeffs = results["point_estimate"][1:] + np.testing.assert_allclose( + compressed_coeffs, uncompressed_coeffs, rtol=1e-4 + ), f"Coefficients are not equal for formula {fml}" diff --git a/tests/test_vs_pyfixest.py b/tests/test_vs_pyfixest.py deleted file mode 100644 index 292b0ac..0000000 --- a/tests/test_vs_pyfixest.py +++ /dev/null @@ -1,66 +0,0 @@ -import pytest -import numpy as np -from duckreg.estimators import DuckRegression -from tests.utils import generate_sample_data, create_duckdb_database -import pyfixest as pf - -@pytest.fixture(scope="session") -def get_data(): - return generate_sample_data() - -@pytest.fixture(scope="session") -def database(get_data): - df = get_data - db_name = 'test_dataset.db' - create_duckdb_database(df, db_name) - return df - - -@pytest.mark.parametrize("fml", ["Y ~ D", "Y ~ D + f1", "Y ~ D + f1 + f2"]) -@pytest.mark.parametrize("cluster_col", [""]) - -def test_vs_pyfixest_deterministic(get_data, fml, cluster_col): - - m_duck = DuckRegression( - db_name='test_dataset.db', - table_name='data', - formula=fml, - cluster_col=cluster_col, - n_bootstraps=0, - seed = 42 - ) - m_duck.fit() - m_duck.fit_vcov() - - m_feols = pf.feols( - fml, - data = get_data, - vcov = "hetero" if cluster_col == "" else {"CRV1": cluster_col}, - ssc = pf.ssc(adj = False, cluster_adj = True) - ) - - results = m_duck.summary() - coefs = results["point_estimate"] - se = results["standard_error"] - - assert np.all(np.abs(coefs) - np.abs(m_feols.coef().values) < 1e-8), "Coeficients are not equal" - assert np.all(np.abs(se) - np.abs(m_feols.se().values) < 1e-4), "Standard errors are not equal" - -def test_multiple_estimation_stochastic(): - - pass - - -def test_vs_pyfixest_stochastic(): - - pass - - -def test_mundlak(): - - pass - - -def test_double_demeaning(): - - pass \ No newline at end of file diff --git a/tests/utils.py b/tests/utils.py index bd13836..c06f1b1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,9 +1,11 @@ +import os import numpy as np import pandas as pd import duckdb + # Generate sample data -def generate_sample_data(N=10_000, seed=12345): +def generate_sample_data(N=10_000_000, seed=42): rng = np.random.default_rng(seed) D = rng.choice([0, 1], size=(N, 1)) X = rng.choice(range(20), (N, 2), True) @@ -11,14 +13,18 @@ def generate_sample_data(N=10_000, seed=12345): Y2 = -1 * D + X @ np.array([1, 2]).reshape(2, 1) + rng.normal(size=(N, 1)) df = pd.DataFrame( np.concatenate([Y, Y2, D, X], axis=1), columns=["Y", "Y2", "D", "f1", "f2"] - ).assign(rowid=range(N)) + ) return df -# Function to create and populate DuckDB database -def create_duckdb_database(df, db_name="large_dataset.db", table="data"): - conn = duckdb.connect(db_name) - conn.execute(f"DROP TABLE IF EXISTS {table}") - conn.execute(f"CREATE TABLE {table} AS SELECT * FROM df") - conn.close() - print(f"Data loaded into DuckDB database: {db_name}") \ No newline at end of file +def create_duckdb_database(df, db_name="test_dataset.db", table="data"): + db_path = os.path.abspath(db_name) + conn = duckdb.connect(db_path) + try: + conn.execute(f"DROP TABLE IF EXISTS {table}") + conn.execute(f"CREATE TABLE {table} AS SELECT * FROM df") + result = conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone() + print(f"Created table '{table}' with {result[0]} rows in database: {db_path}") + finally: + conn.close() + return db_path