From 8f1e9e00a48735610e61ab7e55003cd121bbd80d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francesc=20Mart=C3=AD=20Escofet?= Date: Thu, 25 Jul 2024 09:01:18 +0200 Subject: [PATCH] Fix XLearner float value --- metalearners/xlearner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metalearners/xlearner.py b/metalearners/xlearner.py index d14aa5e..9c0b9b7 100644 --- a/metalearners/xlearner.py +++ b/metalearners/xlearner.py @@ -498,7 +498,7 @@ def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"): tau_hat_tv = op.add( op.mul(scaled_propensity, tau_hat_control[tv]), op.mul( - op.sub(op.const(1), scaled_propensity), + op.sub(op.constant(value_float=1), scaled_propensity), tau_hat_effect[tv], ), )