diff --git a/vit_jax/requirements.txt b/vit_jax/requirements.txt index 35efce6..fb1cf46 100644 --- a/vit_jax/requirements.txt +++ b/vit_jax/requirements.txt @@ -6,7 +6,7 @@ einops>=0.3.0 flax>=0.6.4 git+https://github.com/google/flaxformer -jax[cuda11_cudnn82]>=0.4.2 +jax[cuda11_cudnn86]>=0.4.2 --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html ml-collections>=0.1.0