Skip to content

Commit

Permalink
python312Packages.jax-cuda12-plugin: init at 0.4.36
Browse files Browse the repository at this point in the history
  • Loading branch information
natsukium authored and GaetanLepage committed Dec 26, 2024
1 parent 873d0cb commit 990ef5d
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 0 deletions.
57 changes: 57 additions & 0 deletions pkgs/development/python-modules/jax-cuda12-pjrt/default.nix
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
{
lib,
stdenv,
buildPythonPackage,
fetchurl,
pypaInstallHook,
wheelUnpackHook,
jaxlib,
cudaPackages,
}:
let
inherit (jaxlib) version;
inherit (cudaPackages) cudaVersion;
srcs = {
"x86_64-linux" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda12_plugin/jax_cuda12_pjrt-${version}-py3-none-manylinux2014_x86_64.whl";
hash = "sha256-48NwXY231j2pq/rr8G9c0GZ/WssHSKXF6wDYAEHpIu0=";
};
"aarch64-linux" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda12_plugin/jax_cuda12_pjrt-${version}-py3-none-manylinux2014_aarch64.whl";
hash = "sha256-HfwL7AgguoAbYelCEGS25YI4xDC0rY9UBDMj2TwCF8Y=";
};
};
in
buildPythonPackage {
pname = "jax-cuda12-pjrt";
inherit version;
pyproject = false;

src =
srcs.${stdenv.hostPlatform.system}
or (throw "jax-cuda12-pjrt: No src for ${stdenv.hostPlatform.system}");

nativeBuildInputs = [
pypaInstallHook
wheelUnpackHook
];

# no tests
doCheck = false;

# does not work on its own
dontUsePythonImportsCheck = true;

meta = {
description = "JAX XLA PJRT Plugin for NVIDIA GPUs";
homepage = "https://github.com/jax-ml/jax/tree/main/jax_plugins/cuda";
sourceProvenance = [ lib.sourceTypes.binaryNativeCode ];
license = lib.licenses.asl20;
maintainers = with lib.maintainers; [ natsukium ];
platforms = lib.attrNames srcs;
# see CUDA compatibility matrix
# https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder
broken =
!(lib.versionAtLeast cudaVersion "12.1") || !(lib.versionAtLeast cudaPackages.cudnn.version "9.1");
};
}
13 changes: 13 additions & 0 deletions pkgs/development/python-modules/jax-cuda12-plugin/default.nix
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
mkPythonMetaPackage,
jax-cuda12-pjrt,
}:

mkPythonMetaPackage {
pname = "jax-cuda12-plugin";
inherit (jax-cuda12-pjrt) version;
dependencies = [ jax-cuda12-pjrt ];
meta = {
inherit (jax-cuda12-pjrt.meta) description homepage;
};
}
4 changes: 4 additions & 0 deletions pkgs/top-level/python-packages.nix
Original file line number Diff line number Diff line change
Expand Up @@ -6560,6 +6560,10 @@ self: super: with self; {

jax = callPackage ../development/python-modules/jax { };

jax-cuda12-pjrt = callPackage ../development/python-modules/jax-cuda12-pjrt { };

jax-cuda12-plugin = callPackage ../development/python-modules/jax-cuda12-plugin { };

jax-jumpy = callPackage ../development/python-modules/jax-jumpy { };

jaxlib-bin = callPackage ../development/python-modules/jaxlib/bin.nix {
Expand Down

0 comments on commit 990ef5d

Please sign in to comment.