Skip to content

Commit

Permalink
Adapt test.
Browse files Browse the repository at this point in the history
  • Loading branch information
kklein committed Jun 18, 2024
1 parent 6846325 commit ab23a4e
Showing 1 changed file with 33 additions and 13 deletions.
46 changes: 33 additions & 13 deletions tests/test_metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,6 @@ def predict_conditional_average_outcomes(self, X, is_oos, oos_method=None):
return np.zeros((len(X), 2, 1))


@pytest.mark.parametrize("nuisance_model_factory", [LGBMRegressor])
@pytest.mark.parametrize("treatment_model_factory", [LGBMRegressor])
@pytest.mark.parametrize("is_classification", [True, False])
@pytest.mark.parametrize("nuisance_model_params", [None, {}, {"n_estimators": 5}])
@pytest.mark.parametrize("treatment_model_params", [None, {}, {"n_estimators": 5}])
Expand All @@ -110,33 +108,49 @@ def predict_conditional_average_outcomes(self, X, is_oos, oos_method=None):
[
None,
{
"nuisance1": ["X1"],
"nuisance2": ["X2"],
"propensity_model": ["Xp"],
"treatment1": ["X1"],
"treatment2": ["X2"],
VARIANT_OUTCOME_MODEL: ["X1"],
CONTROL_EFFECT_MODEL: ["X2"],
TREATMENT_EFFECT_MODEL: ["X1"],
TREATMENT_MODEL: ["X2"],
PROPENSITY_MODEL: ["Xp"],
OUTCOME_MODEL: ["X1"],
},
],
)
@pytest.mark.parametrize(
"n_folds", [5, {"nuisance1": 1, "nuisance2": 1, "treatment1": 5, "treatment2": 10}]
"n_folds",
[
5,
{
VARIANT_OUTCOME_MODEL: 5,
CONTROL_EFFECT_MODEL: 5,
TREATMENT_EFFECT_MODEL: 5,
TREATMENT_MODEL: 5,
PROPENSITY_MODEL: 5,
OUTCOME_MODEL: 5,
},
],
)
@pytest.mark.parametrize("propensity_model_factory", [None, LGBMClassifier])
@pytest.mark.parametrize("propensity_model_params", [None, {}, {"n_estimators": 5}])
@pytest.mark.parametrize("n_variants", [2, 5, 10])
@pytest.mark.parametrize(
"implementation",
[TLearner, XLearner, RLearner, DRLearner],
)
def test_metalearner_init(
nuisance_model_factory,
treatment_model_factory,
propensity_model_factory,
is_classification,
n_variants,
nuisance_model_params,
treatment_model_params,
propensity_model_params,
feature_set,
n_folds,
implementation,
):
_TestMetaLearner(
propensity_model_factory = LGBMClassifier
nuisance_model_factory = LGBMClassifier if is_classification else LGBMRegressor
treatment_model_factory = LGBMRegressor
model = implementation(
nuisance_model_factory=nuisance_model_factory,
is_classification=is_classification,
n_variants=n_variants,
Expand All @@ -148,6 +162,12 @@ def test_metalearner_init(
feature_set=feature_set,
n_folds=n_folds,
)
all_base_models = set(model.nuisance_model_specifications().keys()) | set(
model.treatment_model_specifications().keys()
)
assert set(model.n_folds.keys()) == all_base_models
assert all(isinstance(n_fold, int) for n_fold in model.n_folds.values())
assert set(model.feature_set.keys()) == all_base_models


@pytest.mark.parametrize(
Expand Down

0 comments on commit ab23a4e

Please sign in to comment.