Skip to content

Commit

Permalink
Avoid indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescMartiEscofetQC committed Jul 24, 2024
1 parent de5d495 commit beba0a2
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 10 deletions.
4 changes: 2 additions & 2 deletions docs/examples/example_onnx.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@
" onnx_model.SerializeToString(), providers=rt.get_available_providers()\n",
")\n",
"\n",
"pred_onnx = sess.run(\n",
"(pred_onnx,) = sess.run(\n",
" [\"tau\"],\n",
" {\"X\": X_onnx},\n",
")"
Expand All @@ -307,7 +307,7 @@
"outputs": [],
"source": [
"np.testing.assert_allclose(\n",
" xlearner.predict(df[feature_columns], True, \"overall\"), pred_onnx[0], atol=1e-6\n",
" xlearner.predict(df[feature_columns], True, \"overall\"), pred_onnx, atol=1e-6\n",
")"
]
},
Expand Down
4 changes: 2 additions & 2 deletions tests/test_drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ def test_drlearner_onnx(
else:
onnx_X = X.astype(np.float32)

pred_onnx = sess.run(
(pred_onnx,) = sess.run(
["tau"],
{"X": onnx_X},
)
np.testing.assert_allclose(ml.predict(X, True, "overall"), pred_onnx[0], atol=5e-4)
np.testing.assert_allclose(ml.predict(X, True, "overall"), pred_onnx, atol=5e-4)
4 changes: 2 additions & 2 deletions tests/test_rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ def test_rlearner_onnx(
else:
onnx_X = X.astype(np.float32)

pred_onnx = sess.run(
(pred_onnx,) = sess.run(
["tau"],
{"X": onnx_X},
)
np.testing.assert_allclose(ml.predict(X, True, "overall"), pred_onnx[0], atol=5e-4)
np.testing.assert_allclose(ml.predict(X, True, "overall"), pred_onnx, atol=5e-4)
4 changes: 2 additions & 2 deletions tests/test_tlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ def test_tlearner_onnx(
else:
onnx_X = X.astype(np.float32)

pred_onnx = sess.run(
(pred_onnx,) = sess.run(
["tau"],
{"X": onnx_X},
)
np.testing.assert_allclose(ml.predict(X, True, "overall"), pred_onnx[0], atol=5e-4)
np.testing.assert_allclose(ml.predict(X, True, "overall"), pred_onnx, atol=5e-4)
4 changes: 2 additions & 2 deletions tests/test_xlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ def test_xlearner_onnx(
final.SerializeToString(), providers=rt.get_available_providers()
)

pred_onnx = sess.run(
(pred_onnx,) = sess.run(
["tau"],
{"X": X.astype(np.float32)},
)
np.testing.assert_allclose(ml.predict(X, True, "overall"), pred_onnx[0], atol=5e-4)
np.testing.assert_allclose(ml.predict(X, True, "overall"), pred_onnx, atol=5e-4)

0 comments on commit beba0a2

Please sign in to comment.