diff --git a/metalearners/xlearner.py b/metalearners/xlearner.py index d14aa5e..9c0b9b7 100644 --- a/metalearners/xlearner.py +++ b/metalearners/xlearner.py @@ -498,7 +498,7 @@ def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"): tau_hat_tv = op.add( op.mul(scaled_propensity, tau_hat_control[tv]), op.mul( - op.sub(op.const(1), scaled_propensity), + op.sub(op.constant(value_float=1), scaled_propensity), tau_hat_effect[tv], ), )