Skip to content

Commit

Permalink
Add option to validate parameters before __init__ in MetaLearner (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescMartiEscofetQC authored May 3, 2024
1 parent d6da8ff commit 3635e6a
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions metalearners/metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def nuisance_model_names(cls) -> set[str]: ...
@abstractmethod
def treatment_model_names(cls) -> set[str]: ...

def _validate_params(self, **kwargs): ...

def __init__(
self,
nuisance_model_factory: ModelFactory,
Expand Down Expand Up @@ -80,6 +82,17 @@ def __init__(
* a dictionary mapping from the relevant models (``model_kind``, a ``str``) to the
respective value
"""
self._validate_params(
nuisance_model_factory=nuisance_model_factory,
treatment_model_factory=treatment_model_factory,
is_classification=is_classification,
nuisance_model_params=nuisance_model_params,
treatment_model_params=treatment_model_params,
feature_set=feature_set,
n_folds=n_folds,
random_state=random_state,
)

nuisance_model_names = self.__class__.nuisance_model_names()
treatment_model_names = self.__class__.treatment_model_names()

Expand Down

0 comments on commit 3635e6a

Please sign in to comment.