diff --git a/tests/test_xlearner.py b/tests/test_xlearner.py index 521e06b..4310789 100644 --- a/tests/test_xlearner.py +++ b/tests/test_xlearner.py @@ -129,17 +129,13 @@ def test_xlearner_onnx( onnx_models[PROPENSITY_MODEL].append(onnx_model) final = ml.build_onnx(onnx_models) - intermediate_tensor_name = "Div_1_C" - intermediate_layer_value_info = onnx.helper.ValueInfoProto() - intermediate_layer_value_info.name = intermediate_tensor_name - final.graph.output.extend([intermediate_layer_value_info]) sess = rt.InferenceSession( final.SerializeToString(), providers=rt.get_available_providers() ) pred_onnx = sess.run( - ["tau", "Div_1_C"], + ["tau"], {"input": X.astype(np.float32)}, ) np.testing.assert_allclose(ml.predict(X, True, "overall"), pred_onnx[0], atol=5e-4)