diff --git a/pyproject.toml b/pyproject.toml index 4af3e929..5e3db131 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ classifiers = [ description = "Differentiable and accelerated spherical transforms with JAX" dependencies = [ "numpy>=1.20", - "jax>=0.3.13", + "jax>=0.3.13,<0.6.0", "jaxlib", "torch", ]