Skip to content

Commit

Permalink
Add warning in docstring and build_onnx
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescMartiEscofetQC committed Jul 15, 2024
1 parent e8d7a8d commit c6b7951
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 0 deletions.
8 changes: 8 additions & 0 deletions metalearners/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) QuantCo 2024-2024
# SPDX-License-Identifier: BSD-3-Clause

import warnings
from collections.abc import Callable
from inspect import signature
from operator import le, lt
Expand Down Expand Up @@ -557,3 +558,10 @@ def infer_input_dict(model) -> dict:
input_dict[input_tensor.name] = argument(Tensor(input_dtype, input_shape))

return input_dict


def warning_experimental_feature(function_name: str):
warnings.warn(
f"{function_name} is an experimental feature. Use it at your own risk!",
stacklevel=2,
)
2 changes: 2 additions & 0 deletions metalearners/drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
index_matrix,
infer_input_dict,
validate_valid_treatment_variant_not_control,
warning_experimental_feature,
)
from metalearners.cross_fit_estimator import OVERALL
from metalearners.metalearner import (
Expand Down Expand Up @@ -335,6 +336,7 @@ def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
* ``"treatment_model"``
"""
warning_experimental_feature("build_onnx")
check_onnx_installed()
check_spox_installed()
import spox.opset.ai.onnx.v21 as op
Expand Down
4 changes: 4 additions & 0 deletions metalearners/metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,6 +1110,10 @@ def necessary_onnx_models(cls) -> set[str]:
def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
"""Convert the MetaLearner to an ONNX model.
.. warning::
This is a experimental feature which is not subject to deprecation cycles. Use
it at your own risk!
``output_name`` can be used to change the output name of the ONNX model.
``models`` should be a dictionary of sequences with the necessary base models converted to
Expand Down
2 changes: 2 additions & 0 deletions metalearners/rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
infer_input_dict,
validate_all_vectors_same_index,
validate_valid_treatment_variant_not_control,
warning_experimental_feature,
)
from metalearners.cross_fit_estimator import OVERALL
from metalearners.metalearner import (
Expand Down Expand Up @@ -520,6 +521,7 @@ def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
* ``"treatment_model"``
"""
warning_experimental_feature("build_onnx")
check_onnx_installed()
check_spox_installed()
import spox.opset.ai.onnx.v21 as op
Expand Down
2 changes: 2 additions & 0 deletions metalearners/tlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
index_matrix,
infer_input_dict,
infer_probabilities_output,
warning_experimental_feature,
)
from metalearners.cross_fit_estimator import OVERALL
from metalearners.metalearner import (
Expand Down Expand Up @@ -147,6 +148,7 @@ def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
* ``"variant_outcome_model"``
"""
warning_experimental_feature("build_onnx")
check_onnx_installed()
check_spox_installed()
import spox.opset.ai.onnx.v21 as op
Expand Down
2 changes: 2 additions & 0 deletions metalearners/xlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
infer_input_dict,
infer_probabilities_output,
validate_valid_treatment_variant_not_control,
warning_experimental_feature,
)
from metalearners.cross_fit_estimator import MEDIAN, OVERALL
from metalearners.metalearner import (
Expand Down Expand Up @@ -419,6 +420,7 @@ def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
* ``"control_effect_model"``
* ``"treatment_effect_model"``
"""
warning_experimental_feature("build_onnx")
check_onnx_installed()
check_spox_installed()
import spox.opset.ai.onnx.v21 as op
Expand Down

0 comments on commit c6b7951

Please sign in to comment.