diff --git a/metalearners/_utils.py b/metalearners/_utils.py index c02b024..df496ff 100644 --- a/metalearners/_utils.py +++ b/metalearners/_utils.py @@ -1,6 +1,7 @@ # Copyright (c) QuantCo 2024-2024 # SPDX-License-Identifier: BSD-3-Clause +import warnings from collections.abc import Callable from inspect import signature from operator import le, lt @@ -557,3 +558,10 @@ def infer_input_dict(model) -> dict: input_dict[input_tensor.name] = argument(Tensor(input_dtype, input_shape)) return input_dict + + +def warning_experimental_feature(function_name: str): + warnings.warn( + f"{function_name} is an experimental feature. Use it at your own risk!", + stacklevel=2, + ) diff --git a/metalearners/drlearner.py b/metalearners/drlearner.py index efe776e..7c898b5 100644 --- a/metalearners/drlearner.py +++ b/metalearners/drlearner.py @@ -20,6 +20,7 @@ index_matrix, infer_input_dict, validate_valid_treatment_variant_not_control, + warning_experimental_feature, ) from metalearners.cross_fit_estimator import OVERALL from metalearners.metalearner import ( @@ -335,6 +336,7 @@ def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"): * ``"treatment_model"`` """ + warning_experimental_feature("build_onnx") check_onnx_installed() check_spox_installed() import spox.opset.ai.onnx.v21 as op diff --git a/metalearners/metalearner.py b/metalearners/metalearner.py index 4bf1cdd..e9be774 100644 --- a/metalearners/metalearner.py +++ b/metalearners/metalearner.py @@ -1110,6 +1110,10 @@ def necessary_onnx_models(cls) -> set[str]: def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"): """Convert the MetaLearner to an ONNX model. + .. warning:: + This is a experimental feature which is not subject to deprecation cycles. Use + it at your own risk! + ``output_name`` can be used to change the output name of the ONNX model. ``models`` should be a dictionary of sequences with the necessary base models converted to diff --git a/metalearners/rlearner.py b/metalearners/rlearner.py index e9ee2fa..aa5912a 100644 --- a/metalearners/rlearner.py +++ b/metalearners/rlearner.py @@ -23,6 +23,7 @@ infer_input_dict, validate_all_vectors_same_index, validate_valid_treatment_variant_not_control, + warning_experimental_feature, ) from metalearners.cross_fit_estimator import OVERALL from metalearners.metalearner import ( @@ -520,6 +521,7 @@ def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"): * ``"treatment_model"`` """ + warning_experimental_feature("build_onnx") check_onnx_installed() check_spox_installed() import spox.opset.ai.onnx.v21 as op diff --git a/metalearners/tlearner.py b/metalearners/tlearner.py index 649e115..98c8aed 100644 --- a/metalearners/tlearner.py +++ b/metalearners/tlearner.py @@ -15,6 +15,7 @@ index_matrix, infer_input_dict, infer_probabilities_output, + warning_experimental_feature, ) from metalearners.cross_fit_estimator import OVERALL from metalearners.metalearner import ( @@ -147,6 +148,7 @@ def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"): * ``"variant_outcome_model"`` """ + warning_experimental_feature("build_onnx") check_onnx_installed() check_spox_installed() import spox.opset.ai.onnx.v21 as op diff --git a/metalearners/xlearner.py b/metalearners/xlearner.py index d5fe257..dab02f2 100644 --- a/metalearners/xlearner.py +++ b/metalearners/xlearner.py @@ -20,6 +20,7 @@ infer_input_dict, infer_probabilities_output, validate_valid_treatment_variant_not_control, + warning_experimental_feature, ) from metalearners.cross_fit_estimator import MEDIAN, OVERALL from metalearners.metalearner import ( @@ -419,6 +420,7 @@ def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"): * ``"control_effect_model"`` * ``"treatment_effect_model"`` """ + warning_experimental_feature("build_onnx") check_onnx_installed() check_spox_installed() import spox.opset.ai.onnx.v21 as op