Hello. When importing the latest packaged version of jax-triton (`jax-triton==0.1.3`), I get the following error:
```
/usr/local/lib/python3.10/dist-packages/jax_triton/pallas/lowering.py in <module>
     33 from jax._src.state import primitives as sp
     34 from jax._src.state import discharge
---> 35 from jax._src.state import ShapedArrayRef
     36 from jax_triton.triton_lib import get_triton_type
     37 import jax.numpy as jnp

ImportError: cannot import name 'ShapedArrayRef' from 'jax._src.state' (/usr/local/lib/python3.10/dist-packages/jax/_src/state/__init__.py)
```
It appears that the removal of `ShapedArrayRef` from `jax._src.state` in JAX 0.4.12 breaks the package's initialization, as far as I have been able to test.
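To check whether a given environment is affected, a small diagnostic along these lines may help. It is only a sketch: the helper names `parse_version` and `jax_exports_shaped_array_ref` are hypothetical, the version parsing is deliberately simplified, and it only inspects the private `jax._src.state` module named in the traceback. It reports the problem but does not fix it.

```python
import importlib
from importlib import metadata


def parse_version(version: str) -> tuple:
    """Hypothetical helper: turn a release string like '0.4.12' into (0, 4, 12).

    Simplified on purpose (an assumption, not PEP 440 parsing): only the
    leading digits of each dot-separated piece are kept, so suffixes such
    as 'rc1' or '.dev2023...' are ignored.
    """
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def jax_exports_shaped_array_ref() -> bool:
    """Hypothetical helper: True if the installed JAX still exposes the private
    name that jax_triton/pallas/lowering.py imports on line 35."""
    try:
        state = importlib.import_module("jax._src.state")
    except ImportError:
        # JAX is not installed, or the private module is missing.
        return False
    return hasattr(state, "ShapedArrayRef")


if __name__ == "__main__":
    try:
        installed = metadata.version("jax")
    except metadata.PackageNotFoundError:
        installed = None
    print("jax version:", installed)
    if installed is not None:
        # 0.4.12 is the release reported above as breaking the import.
        print("at or after 0.4.12:", parse_version(installed) >= (0, 4, 12))
    print("ShapedArrayRef importable:", jax_exports_shaped_array_ref())
```

If the last line prints `False` while JAX is installed, importing `jax_triton` 0.1.3 would be expected to fail with the `ImportError` shown above.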