From ae49674469559d98ea3c28305c2f6b706eeaf6fc Mon Sep 17 00:00:00 2001 From: Kevin Klein <7267523+kklein@users.noreply.github.com> Date: Mon, 26 Feb 2024 20:46:21 +0100 Subject: [PATCH] ABC for MetaLearner (#5) Co-authored-by: MatthiasLuxQC <96533068+MatthiasLuxQC@users.noreply.github.com> --- metalearners/_utils.py | 20 +++- metalearners/cross_fit_estimator.py | 8 ++ metalearners/metalearner.py | 169 ++++++++++++++++++++++++++++ tests/test_metalearner.py | 59 ++++++++++ 4 files changed, 251 insertions(+), 5 deletions(-) create mode 100644 metalearners/metalearner.py create mode 100644 tests/test_metalearner.py diff --git a/metalearners/_utils.py b/metalearners/_utils.py index 535a96b..2858b29 100644 --- a/metalearners/_utils.py +++ b/metalearners/_utils.py @@ -1,6 +1,7 @@ # Copyright (c) QuantCo 2024-2024 # SPDX-License-Identifier: LicenseRef-QuantCo +import operator from typing import Protocol, Union import numpy as np @@ -12,13 +13,11 @@ class _ScikitModel(Protocol): # https://stackoverflow.com/questions/54868698/what-type-is-a-sklearn-model/60542986#60542986 - def __call__(self, **kwargs): ... + def fit(self, X, y, *params, **kwargs): ... - def fit(self, X, y, sample_weight=None, **kwargs): ... + def predict(self, X, *params, **kwargs): ... - def predict(self, X, **kwargs): ... - - def score(self, X, y, sample_weight=None, **kwargs): ... + def score(self, X, y, **kwargs): ... def set_params(self, **params): ... 
def validate_number_positive(
    value: Union[int, float], name: str, strict: bool = True
) -> None:
    """Raise a ``ValueError`` if ``value`` is not positive.

    ``name`` is only used to build the error message.

    If ``strict`` is ``True`` (the default), ``value`` has to be strictly greater
    than zero. If ``strict`` is ``False``, zero is accepted as well and only
    negative values are rejected.

    Calling the function without ``strict`` rejects zero and negative values.
    """
    # The chosen operator describes the values to *reject*: ``value <= 0`` for
    # strict positivity and ``value < 0`` when zero is acceptable.
    is_invalid = operator.le if strict else operator.lt
    if is_invalid(value, 0):
        raise ValueError(f"{name} was expected to be positive but was {value}.")
def _initialize_model_dict(argument, expected_names: Collection[str]) -> Dict:
    """Map every name in ``expected_names`` to its model-specific value.

    If ``argument`` already is a dictionary whose keys are exactly
    ``expected_names``, it is returned unchanged. Any other value -- including a
    dictionary with different keys, e.g. a dictionary of hyperparameters -- is
    treated as one value which is shared by all names.
    """
    if isinstance(argument, dict) and set(argument.keys()) == set(expected_names):
        return argument
    return {name: argument for name in expected_names}


def _initialize_params_dict(params, expected_names: Collection[str]) -> Dict:
    """Like ``_initialize_model_dict`` but interpret ``None`` as 'no parameters'.

    Every model receives its own empty dictionary, such that no mutable object
    is shared between the models.
    """
    if params is None:
        return {name: {} for name in expected_names}
    return _initialize_model_dict(params, expected_names)


def _filter_x_columns(X: Matrix, features: Optional[Features]) -> Matrix:
    """Restrict ``X`` to the columns listed in ``features``.

    ``None`` means that all columns are kept.
    """
    if features is None:
        return X
    # A list is a valid column indexer for both kinds of matrices, whereas e.g.
    # a tuple would be interpreted as a single key by a DataFrame.
    columns = list(features)
    if isinstance(X, np.ndarray):
        # Plain ``X[columns]`` would select rows of an array.
        return X[:, columns]
    # Presumably a pandas DataFrame, where ``X[columns]`` selects columns by label.
    return X[columns]


class MetaLearner(ABC):
    """Abstract base class of a MetaLearner.

    A MetaLearner consists of nuisance models and treatment models. Concrete
    subclasses declare the names of these models via ``nuisance_model_names``
    and ``treatment_model_names`` and implement ``fit``, ``predict``,
    ``evaluate`` and ``_pseudo_outcome``.
    """

    @classmethod
    @abstractmethod
    def nuisance_model_names(cls) -> List[str]:
        """Return the names (``model_kind``) of all nuisance models."""
        ...

    @classmethod
    @abstractmethod
    def treatment_model_names(cls) -> List[str]:
        """Return the names (``model_kind``) of all treatment models."""
        ...

    def __init__(
        self,
        nuisance_model_factory: Union[_ScikitModel, Dict[str, _ScikitModel]],
        treatment_model_factory: Union[_ScikitModel, Dict[str, _ScikitModel]],
        nuisance_model_params: Optional[Union[Params, Dict[str, Params]]] = None,
        treatment_model_params: Optional[Union[Params, Dict[str, Params]]] = None,
        feature_set: Optional[Union[Features, Dict[str, Features]]] = None,
        # TODO: Consider implementing selection of number of folds for various estimators.
        n_folds: int = 10,
    ):
        """Initialize a MetaLearner.

        All of

        * ``nuisance_model_factory``
        * ``treatment_model_factory``
        * ``nuisance_model_params``
        * ``treatment_model_params``
        * ``feature_set``

        can either be

        * a single value, such that the value will be used for all relevant models
          of the respective MetaLearner, or
        * a dictionary mapping from the relevant models (``model_kind``, a ``str``) to
          the respective value.

        A dictionary is only interpreted as such a mapping if its keys are exactly the
        names of the relevant models.

        ``nuisance_model_params`` and ``treatment_model_params`` default to no
        parameters. If ``feature_set`` is ``None``, every model uses all columns.

        ``n_folds`` is passed on to the ``CrossFitEstimator`` wrapping every model. It
        is checked with ``validate_number_positive``, which raises a ``ValueError``
        for values that are not positive.
        """
        nuisance_model_names = self.__class__.nuisance_model_names()
        treatment_model_names = self.__class__.treatment_model_names()

        validate_number_positive(n_folds, "n_folds")
        self.n_folds = n_folds

        self.nuisance_model_factory = _initialize_model_dict(
            nuisance_model_factory, nuisance_model_names
        )
        self.nuisance_model_params = _initialize_params_dict(
            nuisance_model_params, nuisance_model_names
        )
        self.treatment_model_factory = _initialize_model_dict(
            treatment_model_factory, treatment_model_names
        )
        self.treatment_model_params = _initialize_params_dict(
            treatment_model_params, treatment_model_names
        )

        if feature_set is None:
            self.feature_set = None
        else:
            # Feature sets are shared between both kinds of models, hence the
            # mapping covers nuisance and treatment model names.
            self.feature_set = _initialize_model_dict(
                feature_set, [*nuisance_model_names, *treatment_model_names]
            )

        self._nuisance_models: Dict[str, _ScikitModel] = {
            name: CrossFitEstimator(
                n_folds=self.n_folds,
                estimator_factory=self.nuisance_model_factory[name],
                estimator_params=self.nuisance_model_params[name],
            )
            for name in nuisance_model_names
        }
        self._treatment_models: Dict[str, _ScikitModel] = {
            name: CrossFitEstimator(
                n_folds=self.n_folds,
                estimator_factory=self.treatment_model_factory[name],
                estimator_params=self.treatment_model_params[name],
            )
            for name in treatment_model_names
        }

    def _select_features(self, X: Matrix, model_kind: str) -> Matrix:
        """Return the columns of ``X`` which the model ``model_kind`` works on."""
        if self.feature_set is None:
            return X
        return _filter_x_columns(X, self.feature_set[model_kind])

    def fit_nuisance(self, X: Matrix, y: Vector, model_kind: str) -> Self:
        """Fit a given nuisance model of a MetaLearner.

        ``y`` represents the objective of the given nuisance model, not necessarily
        the outcome of the experiment. ``X`` is restricted to the columns of the
        ``feature_set`` of ``model_kind`` before fitting.
        """
        self._nuisance_models[model_kind].fit(self._select_features(X, model_kind), y)
        return self

    def fit_treatment(self, X: Matrix, y: Vector, model_kind: str) -> Self:
        """Fit a given treatment model of a MetaLearner.

        ``y`` represents the objective of the given treatment model, not necessarily
        the outcome of the experiment. ``X`` is restricted to the columns of the
        ``feature_set`` of ``model_kind`` before fitting.
        """
        self._treatment_models[model_kind].fit(self._select_features(X, model_kind), y)
        return self

    @abstractmethod
    def fit(self, X: Matrix, y: Vector, w: Vector) -> Self:
        """Fit all models of a MetaLearner."""
        ...

    def predict_nuisance(self, X: Matrix, model_kind: str) -> np.ndarray:
        """Estimate based on a given nuisance model.

        Importantly, this method implements the subselection of ``X`` based on the
        ``feature_set`` field of ``MetaLearner``.
        """
        return self._nuisance_models[model_kind].predict(
            self._select_features(X, model_kind)
        )

    def predict_treatment(self, X: Matrix, model_kind: str) -> np.ndarray:
        """Estimate based on a given treatment model.

        Importantly, this method implements the subselection of ``X`` based on the
        ``feature_set`` field of ``MetaLearner``.
        """
        return self._treatment_models[model_kind].predict(
            self._select_features(X, model_kind)
        )

    @abstractmethod
    def predict(self, X: Matrix) -> np.ndarray:
        """Estimate the Conditional Average Treatment Effect.

        This method can be identical to ``predict_treatment`` but doesn't need to be.
        """
        ...

    @abstractmethod
    def evaluate(self, X: Matrix, y: Vector, w: Vector) -> Dict[str, Union[float, int]]:
        """Evaluate all models contained in a MetaLearner."""
        ...

    @abstractmethod
    def _pseudo_outcome(self, *args, **kwargs) -> Vector:
        """Compute the vector of pseudo outcomes of the respective MetaLearner."""
        ...
class _TestMetaLearner(MetaLearner):
    """Minimal concrete ``MetaLearner`` with two nuisance and two treatment models."""

    @classmethod
    def nuisance_model_names(cls):
        return ["nuisance1", "nuisance2"]

    @classmethod
    def treatment_model_names(cls):
        return ["treatment1", "treatment2"]

    def fit(self, X, y, w):
        # Every model is fitted on the very same ``X`` and ``y``; ``w`` is not used.
        learner_class = type(self)
        for nuisance_kind in learner_class.nuisance_model_names():
            self._nuisance_models[nuisance_kind].fit(X, y)
        for treatment_kind in learner_class.treatment_model_names():
            self._treatment_models[treatment_kind].fit(X, y)
        return self

    def predict(self, X):
        # Dummy estimate: one zero per row of ``X``.
        n_rows = len(X)
        return np.zeros(n_rows)

    def evaluate(self, X, y, w):
        return {}

    def _pseudo_outcome(self, X):
        # Dummy pseudo outcome: one zero per row of ``X``.
        n_rows = len(X)
        return np.zeros(n_rows)


@pytest.mark.parametrize("nuisance_model_factory", [LGBMRegressor])
@pytest.mark.parametrize("treatment_model_factory", [LGBMRegressor])
@pytest.mark.parametrize("nuisance_model_params", [None, {}, {"n_estimators": 5}])
@pytest.mark.parametrize("treatment_model_params", [None, {}, {"n_estimators": 5}])
@pytest.mark.parametrize("feature_set", [None])
@pytest.mark.parametrize("n_folds", [5])
def test_metalearner_init(
    mindset_data,
    nuisance_model_factory,
    treatment_model_factory,
    nuisance_model_params,
    treatment_model_params,
    feature_set,
    n_folds,
):
    """Smoke test: a ``MetaLearner`` can be constructed for every parametrization."""
    init_kwargs = {
        "nuisance_model_factory": nuisance_model_factory,
        "treatment_model_factory": treatment_model_factory,
        "nuisance_model_params": nuisance_model_params,
        "treatment_model_params": treatment_model_params,
        "feature_set": feature_set,
        "n_folds": n_folds,
    }
    _TestMetaLearner(**init_kwargs)