From d8555c2ed8ad382d7411b03c47ab634cf5338316 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Sat, 17 Aug 2024 19:23:26 +0100 Subject: [PATCH] array --- tests/backend/test_jax_executable.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/backend/test_jax_executable.py b/tests/backend/test_jax_executable.py index ff4264916d..3d49643ea4 100644 --- a/tests/backend/test_jax_executable.py +++ b/tests/backend/test_jax_executable.py @@ -24,15 +24,15 @@ def test_abs(): executable = JaxExecutable.compile(module) - assert executable.execute([array(-2, dtype=jax.numpy.int32)])[0] == array( - 2, dtype=jax.numpy.int32 - ) - assert executable.execute([array(0, dtype=jax.numpy.int32)])[0] == array( - 0, dtype=jax.numpy.int32 - ) - assert executable.execute([array(2, dtype=jax.numpy.int32)])[0] == array( - 2, dtype=jax.numpy.int32 - ) + assert executable.execute([array(-2, dtype=jax.numpy.int32)]) == [ + array(2, dtype=jax.numpy.int32) + ] + assert executable.execute([array(0, dtype=jax.numpy.int32)]) == [ + array(0, dtype=jax.numpy.int32) + ] + assert executable.execute([array(2, dtype=jax.numpy.int32)]) == [ + array(2, dtype=jax.numpy.int32) + ] def test_no_main():