Skip to content

Commit

Permalink
fixup! python312Packages.jax: 0.4.28 -> 0.4.36
Browse files Browse the repository at this point in the history
  • Loading branch information
natsukium authored and GaetanLepage committed Dec 26, 2024
1 parent 990ef5d commit 9f1ecfe
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
19 changes: 12 additions & 7 deletions pkgs/development/python-modules/jax/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,6 @@ buildPythonPackage rec {
# https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
JAX_RELEASE = "1";

# jaxlib is _not_ included in propagatedBuildInputs because there are
# different versions of jaxlib depending on the desired target hardware. The
# JAX project ships separate wheels for CPU, GPU, and TPU.
dependencies = [
jaxlib
ml-dtypes
Expand All @@ -66,6 +63,13 @@ buildPythonPackage rec {
scipy
];

optional-dependencies = rec {
cuda = [ jax-cuda12-plugin ];
cuda12 = cuda;
cuda12_pip = cuda;
cuda12_local = cuda;
};

nativeCheckInputs = [
cloudpickle
hypothesis
Expand Down Expand Up @@ -158,11 +162,12 @@ buildPythonPackage rec {
#
# NIXPKGS_ALLOW_UNFREE=1 nixglhost -- nix run --impure .#python3Packages.jax.passthru.tests.test_cuda_jaxlibBin
passthru.tests = {
test_cuda_jaxlibSource = callPackage ./test-cuda.nix {
jaxlib = jaxlib-build.override { cudaSupport = true; };
};
# jaxlib-build is broken as of 2024-12-20
# test_cuda_jaxlibSource = callPackage ./test-cuda.nix {
# jax = jax.override { jaxlib = jaxlib-build; };
# };
test_cuda_jaxlibBin = callPackage ./test-cuda.nix {
jaxlib = jaxlib-bin.override { cudaSupport = true; };
jax = jax.override { jaxlib = jaxlib-bin; };
};
};

Expand Down
4 changes: 1 addition & 3 deletions pkgs/development/python-modules/jax/test-cuda.nix
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
{
jax,
jaxlib,
pkgs,
}:

pkgs.writers.writePython3Bin "jax-test-cuda"
{
libraries = [
jax
jaxlib
];
] ++ jax.optional-dependencies.cuda;
}
''
import jax
Expand Down

0 comments on commit 9f1ecfe

Please sign in to comment.