Skip to content

Commit

Permalink
Abstract method and raise error for SLearner
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescMartiEscofetQC committed Jul 5, 2024
1 parent 5863769 commit e9f5463
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 21 deletions.
6 changes: 1 addition & 5 deletions metalearners/drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,11 +324,7 @@ def _pseudo_outcome(

return pseudo_outcome

def build_onnx(
self,
models: Mapping[str, Sequence],
output_name: str = "tau",
):
def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
check_onnx_installed()
check_spox_installed()
import spox.opset.ai.onnx.v21 as op
Expand Down
3 changes: 3 additions & 0 deletions metalearners/metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,6 +1088,9 @@ def _validate_feature_set_all(self):
"features."
)

@abstractmethod
def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"): ...


class _ConditionalAverageOutcomeMetaLearner(MetaLearner, ABC):

Expand Down
6 changes: 1 addition & 5 deletions metalearners/rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,11 +510,7 @@ def _pseudo_outcome_and_weights(

return pseudo_outcomes, weights

def build_onnx(
self,
models: Mapping[str, Sequence],
output_name: str = "tau",
):
def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
check_onnx_installed()
check_spox_installed()
import spox.opset.ai.onnx.v21 as op
Expand Down
8 changes: 8 additions & 0 deletions metalearners/slearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import warnings
from collections.abc import Mapping, Sequence

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -154,6 +155,7 @@ def fit(
self._validate_treatment(w)
self._validate_outcome(y)
self._fitted_treatments = convert_treatment(w)
self._n_features = X.shape[1]

mock_model = self.nuisance_model_factory[_BASE_MODEL](
**self.nuisance_model_params[_BASE_MODEL]
Expand Down Expand Up @@ -284,3 +286,9 @@ def predict_conditional_average_outcomes(
return np.stack(conditional_average_outcomes_list, axis=1).reshape(
n_obs, self.n_variants, -1
)

def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
raise ValueError(
"The SLearner does not implement this method. Please refer to the tutorial "
"on the documentation on how to do this."
)
7 changes: 1 addition & 6 deletions metalearners/tlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,7 @@ def evaluate(
is_treatment_model=False,
)

# TODO: Fix typing without importing onnx
def build_onnx(
self,
models: Mapping[str, Sequence],
output_name: str = "tau",
):
def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
check_onnx_installed()
check_spox_installed()
import spox.opset.ai.onnx.v21 as op
Expand Down
6 changes: 1 addition & 5 deletions metalearners/xlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,11 +406,7 @@ def _pseudo_outcome(

return imputed_te_control, imputed_te_treatment

def build_onnx(
self,
models: Mapping[str, Sequence],
output_name: str = "tau",
):
def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
check_onnx_installed()
check_spox_installed()
import spox.opset.ai.onnx.v21 as op
Expand Down
3 changes: 3 additions & 0 deletions tests/test_metalearner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) QuantCo 2024-2024
# SPDX-License-Identifier: BSD-3-Clause

from collections.abc import Mapping, Sequence
from itertools import chain

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -99,6 +100,8 @@ def evaluate(self, X, y, w, is_oos, oos_method=None):
def predict_conditional_average_outcomes(self, X, is_oos, oos_method=None):
return np.zeros((len(X), 2, 1))

def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"): ...


@pytest.mark.parametrize("is_classification", [True, False])
@pytest.mark.parametrize("nuisance_model_params", [None, {}, {"n_estimators": 5}])
Expand Down

0 comments on commit e9f5463

Please sign in to comment.