Skip to content

Commit

Permalink
Provide helper method to initialize MetaLearner.
Browse files Browse the repository at this point in the history
  • Loading branch information
kklein committed Jul 24, 2024
1 parent a1ee27c commit 46dd4ea
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 2 deletions.
6 changes: 5 additions & 1 deletion metalearners/drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np
from joblib import Parallel, delayed
from typing_extensions import Self
from typing_extensions import Any, Self

from metalearners._typing import (
Features,
Expand Down Expand Up @@ -398,3 +398,7 @@ def _pseudo_outcome(
)

return pseudo_outcome

@property
def init_args(self) -> dict[str, Any]:
return super().init_args | {"adaptive_clipping": self.adaptive_clipping}
33 changes: 32 additions & 1 deletion metalearners/metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import Callable, Collection, Sequence
from copy import deepcopy
from dataclasses import dataclass
from typing import TypedDict
from typing import Any, TypedDict

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -1123,6 +1123,37 @@ def _default_scoring() -> Scoring:
return default_scoring
return dict(default_scoring) | dict(scoring)

@property
def init_args(self) -> dict[str, Any]:
"""Create initiliazation parameters for a new MetaLearner."""
if self._prefitted_nuisance_models:
raise ValueError(
"Cannot recreate MetaLearner if a pre-fitted model was used."
)
return {
"is_classification": self.is_classification,
"n_variants": self.n_variants,
"nuisance_model_factory": {
k: v
for k, v in self.nuisance_model_factory.items()
if k != PROPENSITY_MODEL
},
"treatment_model_factory": self.treatment_model_factory,
"propensity_model_factory": self.nuisance_model_factory.get(
PROPENSITY_MODEL
),
"nuisance_model_params": {
k: v
for k, v in self.nuisance_model_params.items()
if k != PROPENSITY_MODEL
},
"treatment_model_params": self.treatment_model_params,
"propensity_model_params": self.nuisance_model_params.get(PROPENSITY_MODEL),
"feature_set": self.feature_set,
"n_folds": self.n_folds,
"random_state": self.random_state,
}


class _ConditionalAverageOutcomeMetaLearner(MetaLearner, ABC):

Expand Down
15 changes: 15 additions & 0 deletions tests/test_metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,3 +1119,18 @@ def test_validate_outcome_different_classes(implementation, use_pandas, rng):
ValueError, match="have seen different sets of classification outcomes."
):
ml.fit(X, y, w)


@pytest.mark.parametrize(
"implementation",
[TLearner, SLearner, XLearner, RLearner, DRLearner],
)
def test_init_args_smoke(implementation):
ml = implementation(
True,
2,
LogisticRegression,
LinearRegression,
LogisticRegression,
)
implementation(**ml.init_args)

0 comments on commit 46dd4ea

Please sign in to comment.