From 0db5e9035379e0cd61f05f06af1d94385f9f4e42 Mon Sep 17 00:00:00 2001 From: Amr Kayid Date: Tue, 21 May 2024 17:47:14 +0000 Subject: [PATCH] upgrade jax --- poetry.lock | 108 ++++++++++++++++++++++++++++--------------------- pyproject.toml | 28 ++++++++++--- 2 files changed, 83 insertions(+), 53 deletions(-) diff --git a/poetry.lock b/poetry.lock index 3722afe..097ee91 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1129,13 +1129,13 @@ numpy = ">=1.17.3" [[package]] name = "huggingface-hub" -version = "0.23.0" +version = "0.23.1" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" optional = false python-versions = ">=3.8.0" files = [ - {file = "huggingface_hub-0.23.0-py3-none-any.whl", hash = "sha256:075c30d48ee7db2bba779190dc526d2c11d422aed6f9044c5e2fdc2c432fdb91"}, - {file = "huggingface_hub-0.23.0.tar.gz", hash = "sha256:7126dedd10a4c6fac796ced4d87a8cf004efc722a5125c2c09299017fa366fa9"}, + {file = "huggingface_hub-0.23.1-py3-none-any.whl", hash = "sha256:720a5bffd2b1b449deb793da8b0df7a9390a7e238534d5a08c9fbcdecb1dd3cb"}, + {file = "huggingface_hub-0.23.1.tar.gz", hash = "sha256:4f62dbf6ae94f400c6d3419485e52bce510591432a5248a65d0cb72e4d479eb4"}, ] [package.dependencies] @@ -1257,13 +1257,13 @@ test-extra = ["curio", "ipython[test]", "matplotlib (!=3.2.0)", "nbformat", "num [[package]] name = "jax" -version = "0.4.26" +version = "0.4.28" description = "Differentiate, compile, and transform Numpy code." optional = false python-versions = ">=3.9" files = [ - {file = "jax-0.4.26-py3-none-any.whl", hash = "sha256:50dc795148ee6b0735b48b477e5abc556aa3a4c7af5d6940dad08024a908b02f"}, - {file = "jax-0.4.26.tar.gz", hash = "sha256:2cce025d0a279ec630d550524749bc8efe25d2ff47240d2a7d4cfbc5090c5383"}, + {file = "jax-0.4.28-py3-none-any.whl", hash = "sha256:6a181e6b5a5b1140e19cdd2d5c4aa779e4cb4ec627757b918be322d8e81035ba"}, + {file = "jax-0.4.28.tar.gz", hash = "sha256:dcf0a44aff2e1713f0a2b369281cd5b79d8c18fc1018905c4125897cb06b37e9"}, ] [package.dependencies] @@ -1274,43 +1274,43 @@ scipy = ">=1.9" [package.extras] australis = ["protobuf (>=3.13,<4)"] -ci = ["jaxlib (==0.4.25)"] -cpu = ["jaxlib (==0.4.26)"] -cuda = ["jaxlib (==0.4.26+cuda12.cudnn89)"] -cuda12 = ["jax-cuda12-plugin (==0.4.26)", "jaxlib (==0.4.26)", "nvidia-cublas-cu12 (>=12.1.3.1)", "nvidia-cuda-cupti-cu12 (>=12.1.105)", "nvidia-cuda-nvcc-cu12 (>=12.1.105)", "nvidia-cuda-runtime-cu12 (>=12.1.105)", "nvidia-cudnn-cu12 (>=8.9.2.26,<9.0)", "nvidia-cufft-cu12 (>=11.0.2.54)", "nvidia-cusolver-cu12 (>=11.4.5.107)", "nvidia-cusparse-cu12 (>=12.1.0.106)", "nvidia-nccl-cu12 (>=2.18.1)", "nvidia-nvjitlink-cu12 (>=12.1.105)"] -cuda12-cudnn89 = ["jaxlib (==0.4.26+cuda12.cudnn89)"] -cuda12-local = ["jaxlib (==0.4.26+cuda12.cudnn89)"] -cuda12-pip = ["jaxlib (==0.4.26+cuda12.cudnn89)", "nvidia-cublas-cu12 (>=12.1.3.1)", "nvidia-cuda-cupti-cu12 (>=12.1.105)", "nvidia-cuda-nvcc-cu12 (>=12.1.105)", "nvidia-cuda-runtime-cu12 (>=12.1.105)", "nvidia-cudnn-cu12 (>=8.9.2.26,<9.0)", "nvidia-cufft-cu12 (>=11.0.2.54)", "nvidia-cusolver-cu12 (>=11.4.5.107)", "nvidia-cusparse-cu12 (>=12.1.0.106)", "nvidia-nccl-cu12 (>=2.18.1)", "nvidia-nvjitlink-cu12 (>=12.1.105)"] -minimum-jaxlib = ["jaxlib (==0.4.20)"] -tpu = ["jaxlib (==0.4.26)", "libtpu-nightly (==0.1.dev20240403)", "requests"] +ci = ["jaxlib (==0.4.27)"] +cpu = ["jaxlib (==0.4.28)"] +cuda = ["jaxlib (==0.4.28+cuda12.cudnn89)"] +cuda12 = ["jax-cuda12-plugin (==0.4.28)", "jaxlib (==0.4.28)", "nvidia-cublas-cu12 (>=12.1.3.1)", "nvidia-cuda-cupti-cu12 (>=12.1.105)", "nvidia-cuda-nvcc-cu12 (>=12.1.105)", "nvidia-cuda-runtime-cu12 (>=12.1.105)", "nvidia-cudnn-cu12 (>=8.9.2.26,<9.0)", "nvidia-cufft-cu12 (>=11.0.2.54)", "nvidia-cusolver-cu12 (>=11.4.5.107)", "nvidia-cusparse-cu12 (>=12.1.0.106)", "nvidia-nccl-cu12 (>=2.18.1)", "nvidia-nvjitlink-cu12 (>=12.1.105)"] +cuda12-cudnn89 = ["jaxlib (==0.4.28+cuda12.cudnn89)"] +cuda12-local = ["jaxlib (==0.4.28+cuda12.cudnn89)"] +cuda12-pip = ["jaxlib (==0.4.28+cuda12.cudnn89)", "nvidia-cublas-cu12 (>=12.1.3.1)", "nvidia-cuda-cupti-cu12 (>=12.1.105)", "nvidia-cuda-nvcc-cu12 (>=12.1.105)", "nvidia-cuda-runtime-cu12 (>=12.1.105)", "nvidia-cudnn-cu12 (>=8.9.2.26,<9.0)", "nvidia-cufft-cu12 (>=11.0.2.54)", "nvidia-cusolver-cu12 (>=11.4.5.107)", "nvidia-cusparse-cu12 (>=12.1.0.106)", "nvidia-nccl-cu12 (>=2.18.1)", "nvidia-nvjitlink-cu12 (>=12.1.105)"] +minimum-jaxlib = ["jaxlib (==0.4.27)"] +tpu = ["jaxlib (==0.4.28)", "libtpu-nightly (==0.1.dev20240508)", "requests"] [[package]] name = "jaxlib" -version = "0.4.26" +version = "0.4.28" description = "XLA library for JAX" optional = false python-versions = ">=3.9" files = [ - {file = "jaxlib-0.4.26-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:f9060fd81d1b2a2c2069e998db2be04853c40a244ab9edb1caf1c5cbd2f70881"}, - {file = "jaxlib-0.4.26-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e61ded57e05764350f7065583d85ab9270e7f7ed6b9f9d9394fe6ff64d96aab7"}, - {file = "jaxlib-0.4.26-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:d483aff58898bf37e341d6241cecb3e107aebe4ca237fe6267d4c18b7c09ea90"}, - {file = "jaxlib-0.4.26-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:22e943ed10faa7d85fd804ddc20581bf6fbcc60e114435c3b3327b9f1ebff895"}, - {file = "jaxlib-0.4.26-cp310-cp310-win_amd64.whl", hash = "sha256:eb0cc16efc6313eb100688a38078061caa3c907ebfa1d315485a08fd27f374dc"}, - {file = "jaxlib-0.4.26-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:de4b9a54cd96e6a732c1cf65ae2defdf6a01558a15db8bf6dbd8f40d363b085d"}, - {file = "jaxlib-0.4.26-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8cbdcb95ac73f80ea3a82a53b8f0621f37dfb01ab0203de6fc6691a4e2396984"}, - {file = "jaxlib-0.4.26-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:6754ee0d4dd44f708c3826c51ce648c5e08cfc56cabf23d4f3b428971ab00094"}, - {file = "jaxlib-0.4.26-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:3069da7d75f5b4dd15350fffe6e6b86ca09c4b9fde60b10515edb09cef653335"}, - {file = "jaxlib-0.4.26-cp311-cp311-win_amd64.whl", hash = "sha256:516d2b573975bd666278badd650620d5edf3abb835978459d78f135e95419b04"}, - {file = "jaxlib-0.4.26-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:2c66fe8285fed13bcd44b7e10aa90a25a4a58af82450a4b18d0f1573c04a7797"}, - {file = "jaxlib-0.4.26-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9275761ae907ec0e812031bbae644f7a217e314e62d518c85d60ce686d3a3b0b"}, - {file = "jaxlib-0.4.26-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:597df919f00646d3f9c6feb2a39c9fa0fca00032f19cfe758916db3db30d416a"}, - {file = "jaxlib-0.4.26-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:f084dd65f2b3cd804102d9cecf938d876cbbd54cb95308634020fc71b98fac79"}, - {file = "jaxlib-0.4.26-cp312-cp312-win_amd64.whl", hash = "sha256:72f117535d6dbc568adbcf6e1740037e0fe1d6e5b9558ea4158556005cf72bfc"}, - {file = "jaxlib-0.4.26-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:0fea35b04cce0a6a758fd005132c02122dd49be5914d70c7d54e8eafdf3f352b"}, - {file = "jaxlib-0.4.26-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:520f71795c411b41cbea13488f1b17610780d7d9afc02ac5f9931a8c975780cb"}, - {file = "jaxlib-0.4.26-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:d7bbb75f8e63c5ada57a386b7bfaac301f689149ba132509ccd0c865b2ebd4d2"}, - {file = "jaxlib-0.4.26-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:80b2440072da25d85634e98f755b8381781bd4c1ab4023b2ae0956c360124080"}, - {file = "jaxlib-0.4.26-cp39-cp39-win_amd64.whl", hash = "sha256:96c9a183d7a56572a5c1508de317a05badddfbbc8370a8fa8a2e548d5e059dc3"}, + {file = "jaxlib-0.4.28-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:a421d237f8c25d2850166d334603c673ddb9b6c26f52bc496704b8782297bd66"}, + {file = "jaxlib-0.4.28-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f038e68bd10d1a3554722b0bbe36e6a448384437a75aa9d283f696f0ed9f8c09"}, + {file = "jaxlib-0.4.28-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:fabe77c174e9e196e9373097cefbb67e00c7e5f9d864583a7cfcf9dabd2429b6"}, + {file = "jaxlib-0.4.28-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:e3bcdc6f8e60f8554f415c14d930134e602e3ca33c38e546274fd545f875769b"}, + {file = "jaxlib-0.4.28-cp310-cp310-win_amd64.whl", hash = "sha256:a8b31c0e5eea36b7915696b9be40ea8646edc395a3e5437bf7ef26b7239a567a"}, + {file = "jaxlib-0.4.28-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:2ff8290edc7b92c7eae52517f65492633e267b2e9067bad3e4c323d213e77cf5"}, + {file = "jaxlib-0.4.28-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:793857faf37f371cafe752fea5fc811f435e43b8fb4b502058444a7f5eccf829"}, + {file = "jaxlib-0.4.28-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:b41a6b0d506c09f86a18ecc05bd376f072b548af89c333107e49bb0c09c1a3f8"}, + {file = "jaxlib-0.4.28-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:45ce0f3c840cff8236cff26c37f26c9ff078695f93e0c162c320c281f5041275"}, + {file = "jaxlib-0.4.28-cp311-cp311-win_amd64.whl", hash = "sha256:d4d762c3971d74e610a0e85a7ee063cea81a004b365b2a7dc65133f08b04fac5"}, + {file = "jaxlib-0.4.28-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:d6c09a545329722461af056e735146d2c8c74c22ac7426a845eb69f326b4f7a0"}, + {file = "jaxlib-0.4.28-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8dd8bffe3853702f63cd924da0ee25734a4d19cd5c926be033d772ba7d1c175d"}, + {file = "jaxlib-0.4.28-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:de2e8521eb51e16e85093a42cb51a781773fa1040dcf9245d7ea160a14ee5a5b"}, + {file = "jaxlib-0.4.28-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:46a1aa857f4feee8a43fcba95c0e0ab62d40c26cc9730b6c69655908ba359f8d"}, + {file = "jaxlib-0.4.28-cp312-cp312-win_amd64.whl", hash = "sha256:eee428eac31697a070d655f1f24f6ab39ced76750d93b1de862377a52dcc2401"}, + {file = "jaxlib-0.4.28-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:4f98cc837b2b6c6dcfe0ab7ff9eb109314920946119aa3af9faa139718ff2787"}, + {file = "jaxlib-0.4.28-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b01562ec8ad75719b7d0389752489e97eb6b4dcb4c8c113be491634d5282ad3c"}, + {file = "jaxlib-0.4.28-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:aa77a9360a395ba9faf6932df637686fb0c14ddcf4fdc1d2febe04bc88a580a6"}, + {file = "jaxlib-0.4.28-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:4a56ebf05b4a4c1791699d874e072f3f808f0986b4010b14fb549a69c90ca9dc"}, + {file = "jaxlib-0.4.28-cp39-cp39-win_amd64.whl", hash = "sha256:459a4ddcc3e120904b9f13a245430d7801d707bca48925981cbdc59628057dc8"}, ] [package.dependencies] @@ -1461,6 +1461,21 @@ files = [ {file = "libclang-18.1.1.tar.gz", hash = "sha256:a1214966d08d73d971287fc3ead8dfaf82eb07fb197680d8b3859dbbbbf78250"}, ] +[[package]] +name = "libtpu-nightly" +version = "0.1.dev20240521" +description = "Wrapper for the libtpu library." +optional = false +python-versions = "*" +files = [ + {file = "libtpu_nightly-0.1.dev20240521-py3-none-any.whl", hash = "sha256:24bc4b9ecaaa47fc1cba920cf274eb93f9853ce207abb9516181eeaa4b447633"}, +] + +[package.source] +type = "legacy" +url = "https://storage.googleapis.com/jax-releases/libtpu_releases.html" +reference = "jax_tpu" + [[package]] name = "linkify-it-py" version = "2.0.3" @@ -3383,13 +3398,13 @@ test = ["array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "mpmath", "po [[package]] name = "sentry-sdk" -version = "2.2.0" +version = "2.2.1" description = "Python client for Sentry (https://sentry.io)" optional = false python-versions = ">=3.6" files = [ - {file = "sentry_sdk-2.2.0-py2.py3-none-any.whl", hash = "sha256:674f58da37835ea7447fe0e34c57b4a4277fad558b0a7cb4a6c83bcb263086be"}, - {file = "sentry_sdk-2.2.0.tar.gz", hash = "sha256:70eca103cf4c6302365a9d7cf522e7ed7720828910eb23d43ada8e50d1ecda9d"}, + {file = "sentry_sdk-2.2.1-py2.py3-none-any.whl", hash = "sha256:7d617a1b30e80c41f3b542347651fcf90bb0a36f3a398be58b4f06b79c8d85bc"}, + {file = "sentry_sdk-2.2.1.tar.gz", hash = "sha256:8aa2ec825724d8d9d645cab68e6034928b1a6a148503af3e361db3fa6401183f"}, ] [package.dependencies] @@ -3411,7 +3426,7 @@ django = ["django (>=1.8)"] falcon = ["falcon (>=1.4)"] fastapi = ["fastapi (>=0.79.0)"] flask = ["blinker (>=1.1)", "flask (>=0.11)", "markupsafe"] -grpcio = ["grpcio (>=1.21.1)"] +grpcio = ["grpcio (>=1.21.1)", "protobuf (>=3.8.0)"] httpx = ["httpx (>=0.16.0)"] huey = ["huey (>=2)"] huggingface-hub = ["huggingface-hub (>=0.22)"] @@ -3533,19 +3548,18 @@ test = ["pytest"] [[package]] name = "setuptools" -version = "69.5.1" +version = "70.0.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-69.5.1-py3-none-any.whl", hash = "sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32"}, - {file = "setuptools-69.5.1.tar.gz", hash = "sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987"}, + {file = "setuptools-70.0.0-py3-none-any.whl", hash = "sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4"}, + {file = "setuptools-70.0.0.tar.gz", hash = "sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0"}, ] [package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] -testing = ["build[virtualenv]", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] -testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +testing = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] [[package]] name = "six" @@ -4619,4 +4633,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more [metadata] lock-version = "2.0" python-versions = "3.10.12" -content-hash = "b3cba81a7708c1b27f617da00f81fde5fea8666ecfa72b06ddae7e093e44ee1f" +content-hash = "48440346be87ac80ea863235aaf74f1091accd69c505e0d3ea0b9fb446d7940d" diff --git a/pyproject.toml b/pyproject.toml index 92880dd..55fa33b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,9 +8,8 @@ readme = "README.md" [tool.poetry.dependencies] python = "3.10.12" -ruff = "0.4.2" -jax = "0.4.26" -jaxlib = "0.4.26" +jax = "0.4.28" +jaxlib = "0.4.28" ray = {version = "2.20.0", extras = ["default", "data"]} flax = "0.8.3" optax = "0.2.2" @@ -24,18 +23,35 @@ rich = "13.7.1" pydantic = "2.7.1" jmp = "0.0.4" jaxtyping = "0.2.28" -pre-commit = "3.7.0" -ipdb = "0.13.13" beartype = "0.18.5" tensorflow = "2.16.1" tensorflow-datasets = "4.9.4" pillow = "10.3.0" -coverage = "7.5.1" wandb = "0.17.0" ml-collections = "0.1.1" +[[tool.poetry.source]] +name = "jax_tpu" +url = "https://storage.googleapis.com/jax-releases/libtpu_releases.html" +priority = "supplemental" + +[tool.poetry.group.tpu] +optional = true + +[tool.poetry.group.tpu.dependencies] +libtpu-nightly = { version="0.1.dev20240521", source="jax_tpu" } + +[tool.poetry.group.dev] +optional = true + +[tool.poetry.group.dev.dependencies] +ruff = "0.4.2" +pre-commit = "3.7.0" +ipdb = "0.13.13" +coverage = "7.5.1" + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api"