From e9f546343d80930c7a651e054971555e869a9014 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francesc=20Mart=C3=AD=20Escofet?= Date: Fri, 5 Jul 2024 11:22:12 +0200 Subject: [PATCH] Abstract method and raise error for SLearner --- metalearners/drlearner.py | 6 +----- metalearners/metalearner.py | 3 +++ metalearners/rlearner.py | 6 +----- metalearners/slearner.py | 8 ++++++++ metalearners/tlearner.py | 7 +------ metalearners/xlearner.py | 6 +----- tests/test_metalearner.py | 3 +++ 7 files changed, 18 insertions(+), 21 deletions(-) diff --git a/metalearners/drlearner.py b/metalearners/drlearner.py index fb3239e..c9397f2 100644 --- a/metalearners/drlearner.py +++ b/metalearners/drlearner.py @@ -324,11 +324,7 @@ def _pseudo_outcome( return pseudo_outcome - def build_onnx( - self, - models: Mapping[str, Sequence], - output_name: str = "tau", - ): + def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"): check_onnx_installed() check_spox_installed() import spox.opset.ai.onnx.v21 as op diff --git a/metalearners/metalearner.py b/metalearners/metalearner.py index d38cb85..5b013df 100644 --- a/metalearners/metalearner.py +++ b/metalearners/metalearner.py @@ -1088,6 +1088,9 @@ def _validate_feature_set_all(self): "features." ) + @abstractmethod + def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"): ... + class _ConditionalAverageOutcomeMetaLearner(MetaLearner, ABC): diff --git a/metalearners/rlearner.py b/metalearners/rlearner.py index 8b4bd31..44b59b0 100644 --- a/metalearners/rlearner.py +++ b/metalearners/rlearner.py @@ -510,11 +510,7 @@ def _pseudo_outcome_and_weights( return pseudo_outcomes, weights - def build_onnx( - self, - models: Mapping[str, Sequence], - output_name: str = "tau", - ): + def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"): check_onnx_installed() check_spox_installed() import spox.opset.ai.onnx.v21 as op diff --git a/metalearners/slearner.py b/metalearners/slearner.py index 9d49e20..2fb5a8b 100644 --- a/metalearners/slearner.py +++ b/metalearners/slearner.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause import warnings +from collections.abc import Mapping, Sequence import numpy as np import pandas as pd @@ -154,6 +155,7 @@ def fit( self._validate_treatment(w) self._validate_outcome(y) self._fitted_treatments = convert_treatment(w) + self._n_features = X.shape[1] mock_model = self.nuisance_model_factory[_BASE_MODEL]( **self.nuisance_model_params[_BASE_MODEL] @@ -284,3 +286,9 @@ def predict_conditional_average_outcomes( return np.stack(conditional_average_outcomes_list, axis=1).reshape( n_obs, self.n_variants, -1 ) + + def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"): + raise ValueError( + "The SLearner does not implement this method. Please refer to the tutorial " + "on the documentation on how to do this." + ) diff --git a/metalearners/tlearner.py b/metalearners/tlearner.py index 037a01e..578d131 100644 --- a/metalearners/tlearner.py +++ b/metalearners/tlearner.py @@ -136,12 +136,7 @@ def evaluate( is_treatment_model=False, ) - # TODO: Fix typing without importing onnx - def build_onnx( - self, - models: Mapping[str, Sequence], - output_name: str = "tau", - ): + def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"): check_onnx_installed() check_spox_installed() import spox.opset.ai.onnx.v21 as op diff --git a/metalearners/xlearner.py b/metalearners/xlearner.py index bd682a5..316eda5 100644 --- a/metalearners/xlearner.py +++ b/metalearners/xlearner.py @@ -406,11 +406,7 @@ def _pseudo_outcome( return imputed_te_control, imputed_te_treatment - def build_onnx( - self, - models: Mapping[str, Sequence], - output_name: str = "tau", - ): + def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"): check_onnx_installed() check_spox_installed() import spox.opset.ai.onnx.v21 as op diff --git a/tests/test_metalearner.py b/tests/test_metalearner.py index a21877d..f2f998d 100644 --- a/tests/test_metalearner.py +++ b/tests/test_metalearner.py @@ -1,6 +1,7 @@ # Copyright (c) QuantCo 2024-2024 # SPDX-License-Identifier: BSD-3-Clause +from collections.abc import Mapping, Sequence from itertools import chain import matplotlib.pyplot as plt @@ -99,6 +100,8 @@ def evaluate(self, X, y, w, is_oos, oos_method=None): def predict_conditional_average_outcomes(self, X, is_oos, oos_method=None): return np.zeros((len(X), 2, 1)) + def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"): ... + @pytest.mark.parametrize("is_classification", [True, False]) @pytest.mark.parametrize("nuisance_model_params", [None, {}, {"n_estimators": 5}])