Skip to content

Commit

Permalink
Merge branch 'sasha/jax/compile-init' into sasha/jax/jit
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh committed Aug 17, 2024
2 parents 18f7fe2 + d8555c2 commit 26bc15e
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions tests/backend/test_jax_executable.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ def test_abs():

executable = JaxExecutable.compile(module)

assert executable.execute([array(-2, dtype=jax.numpy.int32)])[0] == array(
2, dtype=jax.numpy.int32
)
assert executable.execute([array(0, dtype=jax.numpy.int32)])[0] == array(
0, dtype=jax.numpy.int32
)
assert executable.execute([array(2, dtype=jax.numpy.int32)])[0] == array(
2, dtype=jax.numpy.int32
)
assert executable.execute([array(-2, dtype=jax.numpy.int32)]) == [
array(2, dtype=jax.numpy.int32)
]
assert executable.execute([array(0, dtype=jax.numpy.int32)]) == [
array(0, dtype=jax.numpy.int32)
]
assert executable.execute([array(2, dtype=jax.numpy.int32)]) == [
array(2, dtype=jax.numpy.int32)
]

@executable
def abs_tuple(a: jax.Array) -> tuple[jax.Array]: ...
Expand Down

0 comments on commit 26bc15e

Please sign in to comment.