Skip to content

Commit

Permalink
Check for superset instead of set equality in _initialize_model_dict (
Browse files Browse the repository at this point in the history
  • Loading branch information
kklein authored Jun 18, 2024
1 parent 2cce2a0 commit edad6bb
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 17 deletions.
7 changes: 6 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@ Changelog
0.5.0 (2024-06-18)
------------------

* No longer raise an error if ``feature_set`` is provided to
  :class:`metalearners.SLearner`.

* Fix a bug where base model dictionaries -- e.g. ``n_folds`` or
``feature-set`` -- were improperly initialized if the provided
dictionary's keys were a strict superset of the expected keys.

0.4.2 (2024-06-18)
------------------
Expand Down
7 changes: 4 additions & 3 deletions metalearners/metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def _get_result_skeleton(


def _initialize_model_dict(argument, expected_names: Collection[str]) -> dict:
if isinstance(argument, dict) and set(argument.keys()) == set(expected_names):
return argument
if isinstance(argument, dict) and set(argument.keys()) >= set(expected_names):
return {key: argument[key] for key in expected_names}
return {name: argument for name in expected_names}


Expand Down Expand Up @@ -199,7 +199,8 @@ class MetaLearner(ABC):
* contain a single value, such that the value will be used for all relevant models
of the respective MetaLearner or
* a dictionary mapping from the relevant models (``model_kind``, a ``str``) to the
  respective value; at least all relevant models need to be present in the
  dictionary, while additional entries are allowed and will be ignored
The possible values for defining ``feature_set`` (either one single value for all
the models or the values inside the dictionary specifying for each model) can be:
Expand Down
48 changes: 35 additions & 13 deletions tests/test_metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,6 @@ def predict_conditional_average_outcomes(self, X, is_oos, oos_method=None):
return np.zeros((len(X), 2, 1))


@pytest.mark.parametrize("nuisance_model_factory", [LGBMRegressor])
@pytest.mark.parametrize("treatment_model_factory", [LGBMRegressor])
@pytest.mark.parametrize("is_classification", [True, False])
@pytest.mark.parametrize("nuisance_model_params", [None, {}, {"n_estimators": 5}])
@pytest.mark.parametrize("treatment_model_params", [None, {}, {"n_estimators": 5}])
Expand All @@ -110,33 +108,51 @@ def predict_conditional_average_outcomes(self, X, is_oos, oos_method=None):
[
None,
{
"nuisance1": ["X1"],
"nuisance2": ["X2"],
"propensity_model": ["Xp"],
"treatment1": ["X1"],
"treatment2": ["X2"],
VARIANT_OUTCOME_MODEL: ["X1"],
CONTROL_EFFECT_MODEL: ["X2"],
TREATMENT_EFFECT_MODEL: ["X1"],
TREATMENT_MODEL: ["X2"],
PROPENSITY_MODEL: ["Xp"],
OUTCOME_MODEL: ["X1"],
_BASE_MODEL: ["X2"],
},
],
)
@pytest.mark.parametrize(
"n_folds", [5, {"nuisance1": 1, "nuisance2": 1, "treatment1": 5, "treatment2": 10}]
"n_folds",
[
5,
{
VARIANT_OUTCOME_MODEL: 5,
CONTROL_EFFECT_MODEL: 5,
TREATMENT_EFFECT_MODEL: 5,
TREATMENT_MODEL: 5,
PROPENSITY_MODEL: 5,
OUTCOME_MODEL: 5,
_BASE_MODEL: 5,
},
],
)
@pytest.mark.parametrize("propensity_model_factory", [None, LGBMClassifier])
@pytest.mark.parametrize("propensity_model_params", [None, {}, {"n_estimators": 5}])
@pytest.mark.parametrize("n_variants", [2, 5, 10])
@pytest.mark.parametrize(
"implementation",
[TLearner, SLearner, XLearner, RLearner, DRLearner],
)
def test_metalearner_init(
nuisance_model_factory,
treatment_model_factory,
propensity_model_factory,
is_classification,
n_variants,
nuisance_model_params,
treatment_model_params,
propensity_model_params,
feature_set,
n_folds,
implementation,
):
_TestMetaLearner(
propensity_model_factory = LGBMClassifier
nuisance_model_factory = LGBMClassifier if is_classification else LGBMRegressor
treatment_model_factory = LGBMRegressor
model = implementation(
nuisance_model_factory=nuisance_model_factory,
is_classification=is_classification,
n_variants=n_variants,
Expand All @@ -148,6 +164,12 @@ def test_metalearner_init(
feature_set=feature_set,
n_folds=n_folds,
)
all_base_models = set(model.nuisance_model_specifications().keys()) | set(
model.treatment_model_specifications().keys()
)
assert set(model.n_folds.keys()) == all_base_models
assert all(isinstance(n_fold, int) for n_fold in model.n_folds.values())
assert set(model.feature_set.keys()) == all_base_models


@pytest.mark.parametrize(
Expand Down

0 comments on commit edad6bb

Please sign in to comment.