Skip to content

Commit

Permalink
Provide helper method to initialize a MetaLearner based on another …
Browse files Browse the repository at this point in the history
…`MetaLearner` (#71)

* Provide helper method to initialize MetaLearner.

* Fix logic revolving around pre-fitted models.

* Add changelog entry.

* Expand on docstring.

* Compare attributes.

* Update metalearners/metalearner.py

Co-authored-by: Francesc Martí Escofet <[email protected]>

* Update metalearners/metalearner.py

Co-authored-by: Francesc Martí Escofet <[email protected]>

---------

Co-authored-by: Francesc Martí Escofet <[email protected]>
  • Loading branch information
kklein and FrancescMartiEscofetQC authored Jul 25, 2024
1 parent 2f428fb commit 3cad00e
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 2 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@
Changelog
=========

0.9.0 (2024-07-xx)
------------------

**New features**

* Added :meth:`metalearners.metalearner.MetaLearner.init_args`.


0.8.0 (2024-07-22)
------------------

Expand Down
6 changes: 5 additions & 1 deletion metalearners/drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np
from joblib import Parallel, delayed
from typing_extensions import Self
from typing_extensions import Any, Self

from metalearners._typing import (
Features,
Expand Down Expand Up @@ -398,3 +398,7 @@ def _pseudo_outcome(
)

return pseudo_outcome

@property
def init_args(self) -> dict[str, Any]:
    """Return the initialization parameters of this DRLearner.

    Extends the parent's parameters with the DRLearner-specific
    ``adaptive_clipping`` flag.
    """
    # Copy the parent's parameters so we never mutate a shared dict.
    args = dict(super().init_args)
    args["adaptive_clipping"] = self.adaptive_clipping
    return args
51 changes: 50 additions & 1 deletion metalearners/metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import Callable, Collection, Sequence
from copy import deepcopy
from dataclasses import dataclass
from typing import TypedDict
from typing import Any, TypedDict

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -1123,6 +1123,55 @@ def _default_scoring() -> Scoring:
return default_scoring
return dict(default_scoring) | dict(scoring)

@property
def init_args(self) -> dict[str, Any]:
    """Create initialization parameters for a new MetaLearner.

    Importantly, this does not copy further internal state, such as the weights or
    parameters of trained base models.
    """
    prefitted = self._prefitted_nuisance_models
    propensity_is_prefitted = PROPENSITY_MODEL in prefitted

    def _fresh_models_only(mapping: dict) -> dict:
        # Keep only the models which a new MetaLearner must instantiate itself:
        # the propensity model is passed via dedicated arguments and pre-fitted
        # models are handed over as fitted objects instead.
        return {
            key: value
            for key, value in mapping.items()
            if key != PROPENSITY_MODEL and key not in prefitted
        }

    # Deep-copy fitted models so the new MetaLearner cannot mutate ours.
    fitted_nuisance_models = {
        key: deepcopy(model)
        for key, model in self._nuisance_models.items()
        if key in prefitted and key != PROPENSITY_MODEL
    }

    if propensity_is_prefitted:
        propensity_factory = None
        propensity_params = None
        fitted_propensity_model = deepcopy(self._nuisance_models.get(PROPENSITY_MODEL))
    else:
        propensity_factory = self.nuisance_model_factory.get(PROPENSITY_MODEL)
        propensity_params = self.nuisance_model_params.get(PROPENSITY_MODEL)
        fitted_propensity_model = None

    return {
        "is_classification": self.is_classification,
        "n_variants": self.n_variants,
        "nuisance_model_factory": _fresh_models_only(self.nuisance_model_factory),
        "treatment_model_factory": self.treatment_model_factory,
        "propensity_model_factory": propensity_factory,
        "nuisance_model_params": _fresh_models_only(self.nuisance_model_params),
        "treatment_model_params": self.treatment_model_params,
        "propensity_model_params": propensity_params,
        "fitted_nuisance_models": fitted_nuisance_models,
        "fitted_propensity_model": fitted_propensity_model,
        "feature_set": self.feature_set,
        "n_folds": self.n_folds,
        "random_state": self.random_state,
    }


class _ConditionalAverageOutcomeMetaLearner(MetaLearner, ABC):

Expand Down
19 changes: 19 additions & 0 deletions tests/test_metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,3 +1119,22 @@ def test_validate_outcome_different_classes(implementation, use_pandas, rng):
ValueError, match="have seen different sets of classification outcomes."
):
ml.fit(X, y, w)


@pytest.mark.parametrize(
    "implementation",
    [TLearner, SLearner, XLearner, RLearner, DRLearner],
)
def test_init_args(implementation):
    # Round-trip check: constructing a second learner from the first one's
    # init_args must reproduce exactly the same (unfitted) attribute state.
    original = implementation(
        True,
        2,
        LogisticRegression,
        LinearRegression,
        LogisticRegression,
    )
    clone = implementation(**original.init_args)

    assert set(original.__dict__) == set(clone.__dict__)
    for attribute, value in original.__dict__.items():
        assert value == clone.__dict__[attribute]

0 comments on commit 3cad00e

Please sign in to comment.