Skip to content

Commit

Permalink
Update JAX CUDA release
Browse files Browse the repository at this point in the history
  • Loading branch information
mariecwhite committed Apr 8, 2024
1 parent 7a2f9c4 commit 208593d
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ if [ -z "$WITH_CUDA" ]; then
python -m pip install --upgrade "jax[cpu]" "flax"
else
echo "Installing jax and dependencies with cuda support"
python -m pip install --upgrade "jax[cuda11_local]" "flax" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
python -m pip install --upgrade "jax[cuda11]" "flax" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
fi

# Install Tensorflow to generate TFLite models.
Expand Down
2 changes: 1 addition & 1 deletion comparative_benchmark/jax/setup_venv.sh
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ if [ -z "$WITH_CUDA" ]; then
python -m pip install --upgrade "jax[cpu]" "flax"
else
echo "Installing jax and dependencies with cuda support"
python -m pip install --upgrade "jax[cuda11_local]" "flax" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
python -m pip install --upgrade "jax[cuda11]" "flax" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
fi

python -m pip install --upgrade -r "${TD}/requirements.txt"
Expand Down

0 comments on commit 208593d

Please sign in to comment.