diff --git a/metalearners/metalearner.py b/metalearners/metalearner.py
index 6964547..c9539fd 100644
--- a/metalearners/metalearner.py
+++ b/metalearners/metalearner.py
@@ -227,8 +227,6 @@ def treatment_model_specifications(cls) -> dict[str, _ModelSpecifications]:
         """Return the specifications of all second-stage models."""
         ...
 
-    def _validate_params(self, **kwargs): ...
-
     @classmethod
     @abstractmethod
     def _supports_multi_treatment(cls) -> bool: ...
@@ -352,20 +350,6 @@ def __init__(
         n_folds: int | dict[str, int] = 10,
         random_state: int | None = None,
     ):
-        self._validate_params(
-            nuisance_model_factory=nuisance_model_factory,
-            treatment_model_factory=treatment_model_factory,
-            propensity_model_factory=propensity_model_factory,
-            is_classification=is_classification,
-            n_variants=n_variants,
-            nuisance_model_params=nuisance_model_params,
-            treatment_model_params=treatment_model_params,
-            propensity_model_params=propensity_model_params,
-            feature_set=feature_set,
-            n_folds=n_folds,
-            random_state=random_state,
-        )
-
         nuisance_model_specifications = self.nuisance_model_specifications()
         treatment_model_specifications = self.treatment_model_specifications()
 
diff --git a/metalearners/slearner.py b/metalearners/slearner.py
index 41b49a8..9d42522 100644
--- a/metalearners/slearner.py
+++ b/metalearners/slearner.py
@@ -8,13 +8,21 @@
 from sklearn.metrics import log_loss, root_mean_squared_error
 from typing_extensions import Self
 
-from metalearners._typing import Matrix, OosMethod, Vector
+from metalearners._typing import (
+    Features,
+    Matrix,
+    ModelFactory,
+    OosMethod,
+    Params,
+    Vector,
+    _ScikitModel,
+)
 from metalearners._utils import (
     convert_treatment,
     get_one,
     supports_categoricals,
 )
-from metalearners.cross_fit_estimator import OVERALL
+from metalearners.cross_fit_estimator import OVERALL, CrossFitEstimator
 from metalearners.metalearner import NUISANCE, MetaLearner, _ModelSpecifications
 
 _BASE_MODEL = "base_model"
@@ -88,7 +96,22 @@ def _supports_multi_treatment(cls) -> bool:
     def _supports_multi_class(cls) -> bool:
         return True
 
-    def _validate_params(self, feature_set, **kwargs):
+    def __init__(
+        self,
+        is_classification: bool,
+        n_variants: int,
+        nuisance_model_factory: ModelFactory | None = None,
+        treatment_model_factory: ModelFactory | None = None,
+        propensity_model_factory: type[_ScikitModel] | None = None,
+        nuisance_model_params: Params | dict[str, Params] | None = None,
+        treatment_model_params: Params | dict[str, Params] | None = None,
+        propensity_model_params: Params | None = None,
+        fitted_nuisance_models: dict[str, list[CrossFitEstimator]] | None = None,
+        fitted_propensity_model: CrossFitEstimator | None = None,
+        feature_set: Features | dict[str, Features] | None = None,
+        n_folds: int | dict[str, int] = 10,
+        random_state: int | None = None,
+    ):
         if feature_set is not None:
             # For SLearner it does not make sense to allow feature set as we only have one model
             # and having it would bring problems when using fit_nuisance and predict_nuisance
@@ -97,6 +120,21 @@ def _validate_params(self, feature_set, **kwargs):
                 "Base-model specific feature_sets were provided to S-Learner. "
                 "These will be ignored and all available features will be used instead."
             )
+        super().__init__(
+            is_classification=is_classification,
+            n_variants=n_variants,
+            nuisance_model_factory=nuisance_model_factory,
+            treatment_model_factory=treatment_model_factory,
+            propensity_model_factory=propensity_model_factory,
+            nuisance_model_params=nuisance_model_params,
+            treatment_model_params=treatment_model_params,
+            propensity_model_params=propensity_model_params,
+            fitted_nuisance_models=fitted_nuisance_models,
+            fitted_propensity_model=fitted_propensity_model,
+            feature_set=None,
+            n_folds=n_folds,
+            random_state=random_state,
+        )
 
     def fit(
         self,
diff --git a/tests/test_slearner.py b/tests/test_slearner.py
index 729f20b..37f6086 100644
--- a/tests/test_slearner.py
+++ b/tests/test_slearner.py
@@ -10,12 +10,21 @@
 from metalearners.slearner import SLearner, _append_treatment_to_covariates
 
 
-def test_feature_set_doesnt_raise():
-    SLearner(
+def test_feature_set_doesnt_raise(rng):
+    slearner = SLearner(
         nuisance_model_factory=LinearRegression,
         is_classification=False,
         n_variants=2,
-        feature_set="",
+        feature_set=[0],
+    )
+
+    X = rng.standard_normal((100, 2))
+    y = rng.standard_normal(100)
+    w = rng.integers(0, 2, 100)
+    slearner.fit(X, y, w)
+    assert (
+        slearner._nuisance_models["base_model"][0]._overall_estimator.n_features_in_  # type: ignore
+        == 3
     )
 
 