diff --git a/src/jaxgym/jax/pytree_space.py b/src/jaxgym/jax/pytree_space.py index 081c66858..afdc19bb9 100644 --- a/src/jaxgym/jax/pytree_space.py +++ b/src/jaxgym/jax/pytree_space.py @@ -227,9 +227,7 @@ def contains(self, x: jtp.PyTree) -> bool: def is_inside_bounds(x, low, high): return jax.lax.select( - pred=jnp.alltrue( - jnp.array([jnp.alltrue(x >= low), jnp.alltrue(x <= high)]) - ), + pred=x.size == 0 or jnp.all((x >= low) & (x <= high)), on_true=True, on_false=False, )