-
Notifications
You must be signed in to change notification settings - Fork 19.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update requirements-jax-cuda.txt for clarity and compatibility
- Loading branch information
1 parent
881d8da
commit f606057
Showing
1 changed file
with
10 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,18 @@ | ||
# Tensorflow cpu-only version (needed for testing). | ||
tensorflow-cpu~=2.18.0 | ||
# TensorFlow CPU-only version (needed for testing) | ||
tensorflow-cpu~=2.18.0 # Compatible with tf2onnx | ||
tf2onnx | ||
|
||
# Torch cpu-only version (needed for testing). | ||
# Torch CPU-only version (needed for testing) | ||
--extra-index-url https://download.pytorch.org/whl/cpu | ||
torch>=2.1.0 | ||
torch>=2.1.0 # Ensure compatibility with torchvision>=0.16.0 | ||
torchvision>=0.16.0 | ||
torch-xla | ||
torch-xla # Typically for TPU; no conflict with CPU builds | ||
|
||
# Jax with cuda support. | ||
# TODO: Higher version breaks CI. | ||
# JAX with CUDA support | ||
# Using pinned version to avoid CI issues with newer releases | ||
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html | ||
jax[cuda12]==0.4.28 | ||
flax | ||
jax[cuda12]==0.4.28 # CUDA 12 support; pinned for CI stability | ||
flax # For neural network modeling with JAX | ||
|
||
# Common requirements | ||
-r requirements-common.txt |