Skip to content

Commit

Permalink
Fix `SLearner` feature set (#20)
Browse files Browse the repository at this point in the history
* Add failing test

* Fix failing test

* Remove _validate_params and override __init__
  • Loading branch information
FrancescMartiEscofetQC authored Jun 25, 2024
1 parent 4f8fef3 commit 3c5a3a4
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 22 deletions.
16 changes: 0 additions & 16 deletions metalearners/metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,6 @@ def treatment_model_specifications(cls) -> dict[str, _ModelSpecifications]:
"""Return the specifications of all second-stage models."""
...

def _validate_params(self, **kwargs): ...

@classmethod
@abstractmethod
def _supports_multi_treatment(cls) -> bool: ...
Expand Down Expand Up @@ -352,20 +350,6 @@ def __init__(
n_folds: int | dict[str, int] = 10,
random_state: int | None = None,
):
self._validate_params(
nuisance_model_factory=nuisance_model_factory,
treatment_model_factory=treatment_model_factory,
propensity_model_factory=propensity_model_factory,
is_classification=is_classification,
n_variants=n_variants,
nuisance_model_params=nuisance_model_params,
treatment_model_params=treatment_model_params,
propensity_model_params=propensity_model_params,
feature_set=feature_set,
n_folds=n_folds,
random_state=random_state,
)

nuisance_model_specifications = self.nuisance_model_specifications()
treatment_model_specifications = self.treatment_model_specifications()

Expand Down
44 changes: 41 additions & 3 deletions metalearners/slearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,21 @@
from sklearn.metrics import log_loss, root_mean_squared_error
from typing_extensions import Self

from metalearners._typing import Matrix, OosMethod, Vector
from metalearners._typing import (
Features,
Matrix,
ModelFactory,
OosMethod,
Params,
Vector,
_ScikitModel,
)
from metalearners._utils import (
convert_treatment,
get_one,
supports_categoricals,
)
from metalearners.cross_fit_estimator import OVERALL
from metalearners.cross_fit_estimator import OVERALL, CrossFitEstimator
from metalearners.metalearner import NUISANCE, MetaLearner, _ModelSpecifications

_BASE_MODEL = "base_model"
Expand Down Expand Up @@ -88,7 +96,22 @@ def _supports_multi_treatment(cls) -> bool:
def _supports_multi_class(cls) -> bool:
return True

def _validate_params(self, feature_set, **kwargs):
def __init__(
self,
is_classification: bool,
n_variants: int,
nuisance_model_factory: ModelFactory | None = None,
treatment_model_factory: ModelFactory | None = None,
propensity_model_factory: type[_ScikitModel] | None = None,
nuisance_model_params: Params | dict[str, Params] | None = None,
treatment_model_params: Params | dict[str, Params] | None = None,
propensity_model_params: Params | None = None,
fitted_nuisance_models: dict[str, list[CrossFitEstimator]] | None = None,
fitted_propensity_model: CrossFitEstimator | None = None,
feature_set: Features | dict[str, Features] | None = None,
n_folds: int | dict[str, int] = 10,
random_state: int | None = None,
):
if feature_set is not None:
# For SLearner it does not make sense to allow feature set as we only have one model
# and having it would bring problems when using fit_nuisance and predict_nuisance
Expand All @@ -97,6 +120,21 @@ def _validate_params(self, feature_set, **kwargs):
"Base-model specific feature_sets were provided to S-Learner. "
"These will be ignored and all available features will be used instead."
)
super().__init__(
is_classification=is_classification,
n_variants=n_variants,
nuisance_model_factory=nuisance_model_factory,
treatment_model_factory=treatment_model_factory,
propensity_model_factory=propensity_model_factory,
nuisance_model_params=nuisance_model_params,
treatment_model_params=treatment_model_params,
propensity_model_params=propensity_model_params,
fitted_nuisance_models=fitted_nuisance_models,
fitted_propensity_model=fitted_propensity_model,
feature_set=None,
n_folds=n_folds,
random_state=random_state,
)

def fit(
self,
Expand Down
15 changes: 12 additions & 3 deletions tests/test_slearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,21 @@
from metalearners.slearner import SLearner, _append_treatment_to_covariates


def test_feature_set_doesnt_raise():
SLearner(
def test_feature_set_doesnt_raise(rng):
slearner = SLearner(
nuisance_model_factory=LinearRegression,
is_classification=False,
n_variants=2,
feature_set="",
feature_set=[0],
)

X = rng.standard_normal((100, 2))
y = rng.standard_normal(100)
w = rng.integers(0, 2, 100)
slearner.fit(X, y, w)
assert (
slearner._nuisance_models["base_model"][0]._overall_estimator.n_features_in_ # type: ignore
== 3
)


Expand Down

0 comments on commit 3c5a3a4

Please sign in to comment.