From 9f1ecfeaba4a36a2965793c9c8e9a96399db5646 Mon Sep 17 00:00:00 2001 From: natsukium Date: Fri, 20 Dec 2024 01:00:54 +0900 Subject: [PATCH] fixup! python312Packages.jax: 0.4.28 -> 0.4.36 --- .../python-modules/jax/default.nix | 19 ++++++++++++------- .../python-modules/jax/test-cuda.nix | 4 +--- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/pkgs/development/python-modules/jax/default.nix b/pkgs/development/python-modules/jax/default.nix index c5bd0fc5c7f85..fd61f154069f6 100644 --- a/pkgs/development/python-modules/jax/default.nix +++ b/pkgs/development/python-modules/jax/default.nix @@ -55,9 +55,6 @@ buildPythonPackage rec { # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3 JAX_RELEASE = "1"; - # jaxlib is _not_ included in propagatedBuildInputs because there are - # different versions of jaxlib depending on the desired target hardware. The - # JAX project ships separate wheels for CPU, GPU, and TPU. dependencies = [ jaxlib ml-dtypes @@ -66,6 +63,13 @@ buildPythonPackage rec { scipy ]; + optional-dependencies = rec { + cuda = [ jax-cuda12-plugin ]; + cuda12 = cuda; + cuda12_pip = cuda; + cuda12_local = cuda; + }; + nativeCheckInputs = [ cloudpickle hypothesis @@ -158,11 +162,12 @@ buildPythonPackage rec { # # NIXPKGS_ALLOW_UNFREE=1 nixglhost -- nix run --impure .#python3Packages.jax.passthru.tests.test_cuda_jaxlibBin passthru.tests = { - test_cuda_jaxlibSource = callPackage ./test-cuda.nix { - jaxlib = jaxlib-build.override { cudaSupport = true; }; - }; + # jaxlib-build is broken as of 2024-12-20 + # test_cuda_jaxlibSource = callPackage ./test-cuda.nix { + # jax = jax.override { jaxlib = jaxlib-build; }; + # }; test_cuda_jaxlibBin = callPackage ./test-cuda.nix { - jaxlib = jaxlib-bin.override { cudaSupport = true; }; + jax = jax.override { jaxlib = jaxlib-bin; }; }; }; diff --git a/pkgs/development/python-modules/jax/test-cuda.nix b/pkgs/development/python-modules/jax/test-cuda.nix index 5aca523f31775..ebeebb7ed88cb 100644 --- a/pkgs/development/python-modules/jax/test-cuda.nix +++ b/pkgs/development/python-modules/jax/test-cuda.nix @@ -1,6 +1,5 @@ { jax, - jaxlib, pkgs, }: @@ -8,8 +7,7 @@ pkgs.writers.writePython3Bin "jax-test-cuda" { libraries = [ jax - jaxlib - ]; + ] ++ jax.optional-dependencies.cuda; } '' import jax