From fdffa3d42e3bdabdf5fccd7f0eed6020814b0c22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francesc=20Mart=C3=AD=20Escofet?= Date: Wed, 3 Jul 2024 09:11:49 +0200 Subject: [PATCH] Refactor --- metalearners/metalearner.py | 8 ++++++++ metalearners/tlearner.py | 36 ++++++++++++++++-------------------- metalearners/xlearner.py | 26 ++++++++++++-------------- tests/test_xlearner.py | 2 +- 4 files changed, 37 insertions(+), 35 deletions(-) diff --git a/metalearners/metalearner.py b/metalearners/metalearner.py index 09220e5a..05dac8fb 100644 --- a/metalearners/metalearner.py +++ b/metalearners/metalearner.py @@ -1077,6 +1077,14 @@ def _validate_onnx_models( "with name 'probabilities' or 'output_probability'." ) + def _validate_feature_set_all(self): + for feature_set in self.feature_set.values(): + if feature_set is not None: + raise ValueError( + "ONNX conversion can only be used if all base models use all the " + "features." + ) + class _ConditionalAverageOutcomeMetaLearner(MetaLearner, ABC): diff --git a/metalearners/tlearner.py b/metalearners/tlearner.py index 1e3c3fd8..d6e4e25b 100644 --- a/metalearners/tlearner.py +++ b/metalearners/tlearner.py @@ -137,7 +137,12 @@ def evaluate( ) # TODO: Fix typing without importing onnx - def build_onnx(self, models: Mapping[str, Sequence]): + def build_onnx( + self, + models: Mapping[str, Sequence], + input_name: str = "input", + output_name: str = "tau", + ): check_onnx_installed() check_spox_installed() import spox.opset.ai.onnx.v21 as op @@ -145,45 +150,36 @@ def build_onnx(self, models: Mapping[str, Sequence]): from spox import Tensor, argument, build, inline self._validate_onnx_models(models, {VARIANT_OUTCOME_MODEL}) + self._validate_feature_set_all() input_dtype, input_shape = infer_dtype_and_shape_onnx( models[VARIANT_OUTCOME_MODEL][0].graph.input[0] ) if not self.is_classification: - output_index = 0 - output_name = models[VARIANT_OUTCOME_MODEL][0].graph.output[0].name + model_output_name = models[VARIANT_OUTCOME_MODEL][0].graph.output[0].name else: - output_index, output_name = infer_probabilities_output( + _, model_output_name = infer_probabilities_output( models[VARIANT_OUTCOME_MODEL][0] ) - output_dtype, output_shape = infer_dtype_and_shape_onnx( - models[VARIANT_OUTCOME_MODEL][0].graph.output[output_index] - ) - input_tensor = argument(Tensor(input_dtype, input_shape)) - a = argument(Tensor(output_dtype, output_shape)) - b = argument(Tensor(output_dtype, output_shape)) - subtraction = op.sub(a, b) - sub_model = build({"a": a, "b": b}, {"subtraction": subtraction}) - - output_0 = inline(models[VARIANT_OUTCOME_MODEL][0])(input_tensor) + output_0 = inline(models[VARIANT_OUTCOME_MODEL][0])(input_tensor)[ + model_output_name + ] variant_cates = [] for m in models[VARIANT_OUTCOME_MODEL][1:]: - variant_output = inline(m)(input_tensor) + variant_output = inline(m)(input_tensor)[model_output_name] + variant_cate = op.sub(variant_output, output_0) variant_cates.append( op.unsqueeze( - inline(sub_model)( - variant_output[output_name], - output_0[output_name], - )["subtraction"], + variant_cate, axes=op.constant(value_int=1), ) ) cate = op.concat(variant_cates, axis=1) - final_model = build({"input": input_tensor}, {"tau": cate}) + final_model = build({input_name: input_tensor}, {output_name: cate}) check_model(final_model, full_check=True) return final_model diff --git a/metalearners/xlearner.py b/metalearners/xlearner.py index 943db789..2e0506e8 100644 --- a/metalearners/xlearner.py +++ b/metalearners/xlearner.py @@ -407,7 +407,12 @@ def _pseudo_outcome( return imputed_te_control, imputed_te_treatment - def build_onnx(self, models: Mapping[str, Sequence]): + def build_onnx( + self, + models: Mapping[str, Sequence], + input_name: str = "input", + output_name: str = "tau", + ): check_onnx_installed() check_spox_installed() import spox.opset.ai.onnx.v21 as op @@ -417,14 +422,7 @@ def build_onnx(self, models: Mapping[str, Sequence]): self._validate_onnx_models( models, {PROPENSITY_MODEL, CONTROL_EFFECT_MODEL, TREATMENT_EFFECT_MODEL} ) - - # TODO: move this validation to metalearner level as it will be common for all - for model_kind, feature_set in self.feature_set.items(): - if feature_set is not None: - raise ValueError( - "ONNX conversion can only be used if all base models use all the " - "features." - ) + self._validate_feature_set_all() # All models should have the same input dtype and shape input_dtype, input_shape = infer_dtype_and_shape_onnx( @@ -432,16 +430,16 @@ def build_onnx(self, models: Mapping[str, Sequence]): ) input_tensor = argument(Tensor(input_dtype, input_shape)) - output_name = models[CONTROL_EFFECT_MODEL][0].graph.output[0].name + treatment_output_name = models[CONTROL_EFFECT_MODEL][0].graph.output[0].name tau_hat_control: list[Var] = [] for m in models[CONTROL_EFFECT_MODEL]: - tau_hat_control.append(inline(m)(input_tensor)[output_name]) + tau_hat_control.append(inline(m)(input_tensor)[treatment_output_name]) tau_hat_effect: list[Var] = [] for m in models[TREATMENT_EFFECT_MODEL]: - tau_hat_effect.append(inline(m)(input_tensor)[output_name]) + tau_hat_effect.append(inline(m)(input_tensor)[treatment_output_name]) - propensity_output_index, propensity_output_name = infer_probabilities_output( + _, propensity_output_name = infer_probabilities_output( models[PROPENSITY_MODEL][0] ) @@ -479,6 +477,6 @@ def build_onnx(self, models: Mapping[str, Sequence]): tau_hat.append(tau_hat_tv) cate = op.concat(tau_hat, axis=1) - final_model = build({"input": input_tensor}, {"tau": cate}) + final_model = build({input_name: input_tensor}, {output_name: cate}) check_model(final_model, full_check=True) return final_model diff --git a/tests/test_xlearner.py b/tests/test_xlearner.py index e8feecf8..be6c110b 100644 --- a/tests/test_xlearner.py +++ b/tests/test_xlearner.py @@ -142,4 +142,4 @@ def test_xlearner_onnx( ["tau", "Div_1_C"], {"input": X.astype(np.float32)}, ) - np.testing.assert_allclose(ml.predict(X, True, "overall"), pred_onnx[0], atol=1e-6) + np.testing.assert_allclose(ml.predict(X, True, "overall"), pred_onnx[0], atol=1e-5)