From 3cad00eb94a4e856cf8680e082ccad0bb776f273 Mon Sep 17 00:00:00 2001 From: Kevin Klein <7267523+kklein@users.noreply.github.com> Date: Thu, 25 Jul 2024 15:01:58 +0200 Subject: [PATCH] Provide helper method to initialize a `MetaLearner` based on another `MetaLearner` (#71) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Provide helper method to initialize MetaLearner. * Fix logic revolving around pre-fitted models. * Add changelog entry. * Expand on docstring. * Compare attributes. * Update metalearners/metalearner.py Co-authored-by: Francesc Martí Escofet <154450563+FrancescMartiEscofetQC@users.noreply.github.com> * Update metalearners/metalearner.py Co-authored-by: Francesc Martí Escofet <154450563+FrancescMartiEscofetQC@users.noreply.github.com> --------- Co-authored-by: Francesc Martí Escofet <154450563+FrancescMartiEscofetQC@users.noreply.github.com> --- CHANGELOG.rst | 8 ++++++ metalearners/drlearner.py | 6 ++++- metalearners/metalearner.py | 51 ++++++++++++++++++++++++++++++++++++- tests/test_metalearner.py | 19 ++++++++++++++ 4 files changed, 82 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 50f2ee47..bd987b68 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -7,6 +7,14 @@ Changelog ========= +0.9.0 (2024-07-xx) +------------------ + +**New features** + +* Added :meth:`metalearners.metalearner.MetaLearner.init_args`. 
+ + 0.8.0 (2024-07-22) ------------------ diff --git a/metalearners/drlearner.py b/metalearners/drlearner.py index a7df898d..1a29487e 100644 --- a/metalearners/drlearner.py +++ b/metalearners/drlearner.py @@ -4,7 +4,7 @@ import numpy as np from joblib import Parallel, delayed -from typing_extensions import Self +from typing_extensions import Any, Self from metalearners._typing import ( Features, @@ -398,3 +398,7 @@ def _pseudo_outcome( ) return pseudo_outcome + + @property + def init_args(self) -> dict[str, Any]: + return super().init_args | {"adaptive_clipping": self.adaptive_clipping} diff --git a/metalearners/metalearner.py b/metalearners/metalearner.py index 06415dcc..5bed116a 100644 --- a/metalearners/metalearner.py +++ b/metalearners/metalearner.py @@ -5,7 +5,7 @@ from collections.abc import Callable, Collection, Sequence from copy import deepcopy from dataclasses import dataclass -from typing import TypedDict +from typing import Any, TypedDict import numpy as np import pandas as pd @@ -1123,6 +1123,55 @@ def _default_scoring() -> Scoring: return default_scoring return dict(default_scoring) | dict(scoring) + @property + def init_args(self) -> dict[str, Any]: + """Create initialization parameters for a new MetaLearner. + + Importantly, this does not copy further internal state, such as the weights or + parameters of trained base models. 
+ """ + return { + "is_classification": self.is_classification, + "n_variants": self.n_variants, + "nuisance_model_factory": { + k: v + for k, v in self.nuisance_model_factory.items() + if k != PROPENSITY_MODEL + if k not in self._prefitted_nuisance_models + }, + "treatment_model_factory": self.treatment_model_factory, + "propensity_model_factory": ( + self.nuisance_model_factory.get(PROPENSITY_MODEL) + if PROPENSITY_MODEL not in self._prefitted_nuisance_models + else None + ), + "nuisance_model_params": { + k: v + for k, v in self.nuisance_model_params.items() + if k != PROPENSITY_MODEL + if k not in self._prefitted_nuisance_models + }, + "treatment_model_params": self.treatment_model_params, + "propensity_model_params": ( + self.nuisance_model_params.get(PROPENSITY_MODEL) + if PROPENSITY_MODEL not in self._prefitted_nuisance_models + else None + ), + "fitted_nuisance_models": { + k: deepcopy(v) + for k, v in self._nuisance_models.items() + if k in self._prefitted_nuisance_models and k != PROPENSITY_MODEL + }, + "fitted_propensity_model": ( + deepcopy(self._nuisance_models.get(PROPENSITY_MODEL)) + if PROPENSITY_MODEL in self._prefitted_nuisance_models + else None + ), + "feature_set": self.feature_set, + "n_folds": self.n_folds, + "random_state": self.random_state, + } + class _ConditionalAverageOutcomeMetaLearner(MetaLearner, ABC): diff --git a/tests/test_metalearner.py b/tests/test_metalearner.py index 03633c2e..a65087cd 100644 --- a/tests/test_metalearner.py +++ b/tests/test_metalearner.py @@ -1119,3 +1119,22 @@ def test_validate_outcome_different_classes(implementation, use_pandas, rng): ValueError, match="have seen different sets of classification outcomes." 
): ml.fit(X, y, w) + + +@pytest.mark.parametrize( + "implementation", + [TLearner, SLearner, XLearner, RLearner, DRLearner], +) +def test_init_args(implementation): + ml = implementation( + True, + 2, + LogisticRegression, + LinearRegression, + LogisticRegression, + ) + ml2 = implementation(**ml.init_args) + + assert set(ml.__dict__.keys()) == set(ml2.__dict__.keys()) + for key in ml.__dict__: + assert ml.__dict__[key] == ml2.__dict__[key]