From 990ef5de0bd3e466cf607113002b1c07c152b8b1 Mon Sep 17 00:00:00 2001 From: natsukium Date: Fri, 20 Dec 2024 00:59:57 +0900 Subject: [PATCH] python312Packages.jax-cuda12-plugin: init at 0.4.36 --- .../jax-cuda12-pjrt/default.nix | 57 +++++++++++++++++++ .../jax-cuda12-plugin/default.nix | 13 +++++ pkgs/top-level/python-packages.nix | 4 ++ 3 files changed, 74 insertions(+) create mode 100644 pkgs/development/python-modules/jax-cuda12-pjrt/default.nix create mode 100644 pkgs/development/python-modules/jax-cuda12-plugin/default.nix diff --git a/pkgs/development/python-modules/jax-cuda12-pjrt/default.nix b/pkgs/development/python-modules/jax-cuda12-pjrt/default.nix new file mode 100644 index 0000000000000..b4d46a0133b26 --- /dev/null +++ b/pkgs/development/python-modules/jax-cuda12-pjrt/default.nix @@ -0,0 +1,57 @@ +{ + lib, + stdenv, + buildPythonPackage, + fetchurl, + pypaInstallHook, + wheelUnpackHook, + jaxlib, + cudaPackages, +}: +let + inherit (jaxlib) version; + inherit (cudaPackages) cudaVersion; + srcs = { + "x86_64-linux" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda12_plugin/jax_cuda12_pjrt-${version}-py3-none-manylinux2014_x86_64.whl"; + hash = "sha256-48NwXY231j2pq/rr8G9c0GZ/WssHSKXF6wDYAEHpIu0="; + }; + "aarch64-linux" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda12_plugin/jax_cuda12_pjrt-${version}-py3-none-manylinux2014_aarch64.whl"; + hash = "sha256-HfwL7AgguoAbYelCEGS25YI4xDC0rY9UBDMj2TwCF8Y="; + }; + }; +in +buildPythonPackage { + pname = "jax-cuda12-pjrt"; + inherit version; + pyproject = false; + + src = + srcs.${stdenv.hostPlatform.system} + or (throw "jax-cuda12-pjrt: No src for ${stdenv.hostPlatform.system}"); + + nativeBuildInputs = [ + pypaInstallHook + wheelUnpackHook + ]; + + # no tests + doCheck = false; + + # does not work on its own + dontUsePythonImportsCheck = true; + + meta = { + description = "JAX XLA PJRT Plugin for NVIDIA GPUs"; + homepage = "https://github.com/jax-ml/jax/tree/main/jax_plugins/cuda"; + sourceProvenance = [ lib.sourceTypes.binaryNativeCode ]; + license = lib.licenses.asl20; + maintainers = with lib.maintainers; [ natsukium ]; + platforms = lib.attrNames srcs; + # see CUDA compatibility matrix + # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder + broken = + !(lib.versionAtLeast cudaVersion "12.1") || !(lib.versionAtLeast cudaPackages.cudnn.version "9.1"); + }; +} diff --git a/pkgs/development/python-modules/jax-cuda12-plugin/default.nix b/pkgs/development/python-modules/jax-cuda12-plugin/default.nix new file mode 100644 index 0000000000000..67fb0ad893b80 --- /dev/null +++ b/pkgs/development/python-modules/jax-cuda12-plugin/default.nix @@ -0,0 +1,13 @@ +{ + mkPythonMetaPackage, + jax-cuda12-pjrt, +}: + +mkPythonMetaPackage { + pname = "jax-cuda12-plugin"; + inherit (jax-cuda12-pjrt) version; + dependencies = [ jax-cuda12-pjrt ]; + meta = { + inherit (jax-cuda12-pjrt.meta) description homepage; + }; +} diff --git a/pkgs/top-level/python-packages.nix b/pkgs/top-level/python-packages.nix index 0f2490e23e862..778f3b0e369ba 100644 --- a/pkgs/top-level/python-packages.nix +++ b/pkgs/top-level/python-packages.nix @@ -6560,6 +6560,10 @@ self: super: with self; { jax = callPackage ../development/python-modules/jax { }; + jax-cuda12-pjrt = callPackage ../development/python-modules/jax-cuda12-pjrt { }; + + jax-cuda12-plugin = callPackage ../development/python-modules/jax-cuda12-plugin { }; + jax-jumpy = callPackage ../development/python-modules/jax-jumpy { }; jaxlib-bin = callPackage ../development/python-modules/jaxlib/bin.nix {