Skip to content

Commit

Permalink
Update requirements-jax-cuda.txt for clarity and compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
saiabhinav001 authored Jan 6, 2025
1 parent 881d8da commit f606057
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions requirements-jax-cuda.txt
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
# Tensorflow cpu-only version (needed for testing).
tensorflow-cpu~=2.18.0
# TensorFlow CPU-only version (needed for testing)
tensorflow-cpu~=2.18.0 # Compatible with tf2onnx
tf2onnx

# Torch cpu-only version (needed for testing).
# Torch CPU-only version (needed for testing)
--extra-index-url https://download.pytorch.org/whl/cpu
torch>=2.1.0
torch>=2.1.0 # Ensure compatibility with torchvision>=0.16.0
torchvision>=0.16.0
torch-xla
torch-xla # Typically for TPU; no conflict with CPU builds

# Jax with cuda support.
# TODO: Higher version breaks CI.
# JAX with CUDA support
# Using pinned version to avoid CI issues with newer releases
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda12]==0.4.28
flax
jax[cuda12]==0.4.28 # CUDA 12 support; pinned for CI stability
flax # For neural network modeling with JAX

# Common requirements
-r requirements-common.txt

0 comments on commit f606057

Please sign in to comment.