diff --git a/pkgs/development/python-modules/orbax-checkpoint/default.nix b/pkgs/development/python-modules/orbax-checkpoint/default.nix index 511542d0a4902..5c815f85dd801 100644 --- a/pkgs/development/python-modules/orbax-checkpoint/default.nix +++ b/pkgs/development/python-modules/orbax-checkpoint/default.nix @@ -11,14 +11,13 @@ # dependencies etils, humanize, - importlib-resources, jax, - jaxlib, msgpack, nest-asyncio, numpy, protobuf, pyyaml, + simplejson, tensorstore, typing-extensions, @@ -26,20 +25,21 @@ chex, google-cloud-logging, mock, + optax, pytest-xdist, pytestCheckHook, }: buildPythonPackage rec { pname = "orbax-checkpoint"; - version = "0.6.4"; + version = "0.10.3"; pyproject = true; src = fetchFromGitHub { owner = "google"; repo = "orbax"; - rev = "refs/tags/v${version}"; - hash = "sha256-xd75/AKBFUdA6a8sQnCB2rVbHl/Foy4LTb07jnwrTjA="; + tag = "v${version}"; + hash = "sha256-BTg4kUz5jfoK2uR/deqqJb8PYoj+FfkuoMZAeSjKKnA="; }; sourceRoot = "${src.name}/checkpoint"; @@ -50,14 +50,13 @@ buildPythonPackage rec { absl-py etils humanize - importlib-resources jax - jaxlib msgpack nest-asyncio numpy protobuf pyyaml + simplejson tensorstore typing-extensions ]; @@ -66,6 +65,7 @@ buildPythonPackage rec { chex google-cloud-logging mock + optax pytest-xdist pytestCheckHook ]; @@ -84,14 +84,19 @@ buildPythonPackage rec { disabledTestPaths = [ # Circular dependency flax + "orbax/checkpoint/_src/metadata/empty_values_test.py" + "orbax/checkpoint/_src/metadata/tree_rich_types_test.py" + "orbax/checkpoint/_src/metadata/tree_test.py" + "orbax/checkpoint/_src/testing/test_tree_utils.py" + "orbax/checkpoint/_src/tree/utils_test.py" + "orbax/checkpoint/single_host_test.py" "orbax/checkpoint/transform_utils_test.py" - "orbax/checkpoint/utils_test.py" ]; meta = { description = "Orbax provides common utility libraries for JAX users"; homepage = "https://github.com/google/orbax/tree/main/checkpoint"; - changelog = "https://github.com/google/orbax/releases/tag/v${version}"; + changelog = "https://github.com/google/orbax/blob/v${version}/checkpoint/CHANGELOG.md"; license = lib.licenses.asl20; maintainers = with lib.maintainers; [ fab ]; };