From 2667779458f905214dfa595fbdd54abeb7977991 Mon Sep 17 00:00:00 2001 From: Apoorva Lal Date: Wed, 7 Aug 2024 21:32:40 -0700 Subject: [PATCH] test HC1 SE --- tests/test_fitter.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/tests/test_fitter.py b/tests/test_fitter.py index c81e7d6..b68a5a6 100644 --- a/tests/test_fitter.py +++ b/tests/test_fitter.py @@ -5,6 +5,7 @@ from tests.utils import generate_sample_data, create_duckdb_database import duckdb import pandas as pd +import pyfixest as pf @pytest.fixture(scope="session") @@ -39,6 +40,14 @@ def get_numpy_coefficients(db_path, formula): return coeffs[1:] +def get_pyfixest_estimates(db_path, formula): + conn = duckdb.connect(db_path) + df = conn.execute("SELECT * FROM data").df() + conn.close() + m_pf = pf.feols(formula, data=df, vcov="hetero") + return m_pf + + @pytest.mark.parametrize( "fml", [ @@ -50,8 +59,6 @@ def get_numpy_coefficients(db_path, formula): def test_fitters(database, fml): db_path = database - uncompressed_coeffs = get_numpy_coefficients(db_path, fml) - m_duck = DuckRegression( db_name=db_path, table_name="data", @@ -61,14 +68,28 @@ def test_fitters(database, fml): seed=42, ) m_duck.fit() - + m_duck.fit_vcov() + # nobs np.testing.assert_allclose( m_duck.df_compressed["count"].sum(), 1_000_000, rtol=1e-4 ), "Number of observations are not equal" results = m_duck.summary() - compressed_coeffs = results["point_estimate"][1:] + compressed_coeffs, compressed_se = ( + results["point_estimate"][1:], + results["standard_error"][1:], + ) + uncompressed_coeffs = get_numpy_coefficients(db_path, fml) + uncompressed_coeffs2 = get_pyfixest_estimates(db_path, fml) np.testing.assert_allclose( compressed_coeffs, uncompressed_coeffs, rtol=1e-4 ), f"Coefficients are not equal for formula {fml}" + + np.testing.assert_allclose( + compressed_coeffs, uncompressed_coeffs2.coef().values[1:], rtol=1e-4 + ), f"Coefficients are not equal to pyfixest version for formula {fml}" + + np.testing.assert_allclose( + compressed_se, uncompressed_coeffs2.se().values[1:], rtol=1e-4 + ), f"Standard errors are not equal to pyfixest version for formula {fml}"