diff --git a/tests/test_metalearner.py b/tests/test_metalearner.py index 0951aba..13d55f9 100644 --- a/tests/test_metalearner.py +++ b/tests/test_metalearner.py @@ -100,8 +100,6 @@ def predict_conditional_average_outcomes(self, X, is_oos, oos_method=None): return np.zeros((len(X), 2, 1)) -@pytest.mark.parametrize("nuisance_model_factory", [LGBMRegressor]) -@pytest.mark.parametrize("treatment_model_factory", [LGBMRegressor]) @pytest.mark.parametrize("is_classification", [True, False]) @pytest.mark.parametrize("nuisance_model_params", [None, {}, {"n_estimators": 5}]) @pytest.mark.parametrize("treatment_model_params", [None, {}, {"n_estimators": 5}]) @@ -110,24 +108,36 @@ def predict_conditional_average_outcomes(self, X, is_oos, oos_method=None): [ None, { - "nuisance1": ["X1"], - "nuisance2": ["X2"], - "propensity_model": ["Xp"], - "treatment1": ["X1"], - "treatment2": ["X2"], + VARIANT_OUTCOME_MODEL: ["X1"], + CONTROL_EFFECT_MODEL: ["X2"], + TREATMENT_EFFECT_MODEL: ["X1"], + TREATMENT_MODEL: ["X2"], + PROPENSITY_MODEL: ["Xp"], + OUTCOME_MODEL: ["X1"], }, ], ) @pytest.mark.parametrize( - "n_folds", [5, {"nuisance1": 1, "nuisance2": 1, "treatment1": 5, "treatment2": 10}] + "n_folds", + [ + 5, + { + VARIANT_OUTCOME_MODEL: 5, + CONTROL_EFFECT_MODEL: 5, + TREATMENT_EFFECT_MODEL: 5, + TREATMENT_MODEL: 5, + PROPENSITY_MODEL: 5, + OUTCOME_MODEL: 5, + }, + ], ) -@pytest.mark.parametrize("propensity_model_factory", [None, LGBMClassifier]) @pytest.mark.parametrize("propensity_model_params", [None, {}, {"n_estimators": 5}]) @pytest.mark.parametrize("n_variants", [2, 5, 10]) +@pytest.mark.parametrize( + "implementation", + [TLearner, XLearner, RLearner, DRLearner], +) def test_metalearner_init( - nuisance_model_factory, - treatment_model_factory, - propensity_model_factory, is_classification, n_variants, nuisance_model_params, @@ -135,8 +145,12 @@ def test_metalearner_init( propensity_model_params, feature_set, n_folds, + implementation, ): - _TestMetaLearner( + propensity_model_factory = LGBMClassifier + nuisance_model_factory = LGBMClassifier if is_classification else LGBMRegressor + treatment_model_factory = LGBMRegressor + model = implementation( nuisance_model_factory=nuisance_model_factory, is_classification=is_classification, n_variants=n_variants, @@ -148,6 +162,12 @@ def test_metalearner_init( feature_set=feature_set, n_folds=n_folds, ) + all_base_models = set(model.nuisance_model_specifications().keys()) | set( + model.treatment_model_specifications().keys() + ) + assert set(model.n_folds.keys()) == all_base_models + assert all(isinstance(n_fold, int) for n_fold in model.n_folds.values()) + assert set(model.feature_set.keys()) == all_base_models @pytest.mark.parametrize(