From 13eeed14847fd3862c9b18bca4f3a963fcc4874d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francesc=20Mart=C3=AD=20Escofet?= Date: Thu, 4 Jul 2024 10:50:10 +0200 Subject: [PATCH] Add test propensity model reuse --- tests/test_grid_search.py | 57 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 55 insertions(+), 2 deletions(-) diff --git a/tests/test_grid_search.py b/tests/test_grid_search.py index 762f163..aeae5ec 100644 --- a/tests/test_grid_search.py +++ b/tests/test_grid_search.py @@ -8,7 +8,7 @@ from metalearners.drlearner import DRLearner from metalearners.grid_search import MetaLearnerGridSearch -from metalearners.metalearner import VARIANT_OUTCOME_MODEL +from metalearners.metalearner import PROPENSITY_MODEL, VARIANT_OUTCOME_MODEL from metalearners.rlearner import RLearner from metalearners.slearner import SLearner from metalearners.tlearner import TLearner @@ -131,7 +131,7 @@ def test_metalearnergridsearch_smoke( assert train_scores_cols == test_scores_cols -def test_metalearnergridsearch_reuse_smoke(rng): +def test_metalearnergridsearch_reuse_nuisance_smoke(rng): n_variants = 3 n_samples = 250 n_test_samples = 100 @@ -188,3 +188,56 @@ def test_metalearnergridsearch_reuse_smoke(rng): } assert gs.results_ is not None assert gs.results_.shape[0] == 8 + + +def test_metalearnergridsearch_reuse_propensity_smoke(rng): + n_variants = 3 + n_samples = 250 + n_test_samples = 100 + + X = rng.standard_normal((n_samples, 3)) + X_test = rng.standard_normal((n_test_samples, 3)) + y = rng.standard_normal(n_samples) + y_test = rng.standard_normal(n_test_samples) + w = rng.integers(0, n_variants, n_samples) + w_test = rng.integers(0, n_variants, n_test_samples) + + rl = RLearner( + False, + n_variants, + LGBMRegressor, + LGBMRegressor, + LGBMClassifier, + nuisance_model_params={"verbose": -1, "n_estimators": 1}, + treatment_model_params={"verbose": -1, "n_estimators": 1}, + propensity_model_params={"verbose": -1, "n_estimators": 1}, + n_folds=2, + ) + rl.fit(X, y, w) + + gs = MetaLearnerGridSearch( + DRLearner, + { + "is_classification": False, + "n_variants": n_variants, + "n_folds": 5, # To test with different n_folds than the pretrained + "fitted_propensity_model": rl._nuisance_models[PROPENSITY_MODEL][0], + }, + { + "treatment_model": [LGBMRegressor], + "variant_outcome_model": [LinearRegression], + }, + { + "treatment_model": { + "LGBMRegressor": {"n_estimators": [1, 2], "verbose": [-1]} + }, + }, + verbose=3, + random_state=1, + ) + gs.fit(X, y, w, X_test, y_test, w_test) + assert gs.raw_results_ is not None + for raw_result in gs.raw_results_: + assert raw_result.metalearner._prefitted_nuisance_models == {PROPENSITY_MODEL} + assert gs.results_ is not None + assert gs.results_.shape[0] == 2