From f3295e7e0d70f529685eac5e5fbc8448e9ddede7 Mon Sep 17 00:00:00 2001 From: Tom White Date: Mon, 12 Feb 2024 08:38:48 +0000 Subject: [PATCH] Run tests using JAX as the backend array API (on CPU) (#374) --- .github/workflows/jax-tests.yml | 49 +++++++++++++++++++++++++++++++++ cubed/backend_array_api.py | 33 +++++++++++++++++++--- cubed/storage/virtual.py | 5 ++-- 3 files changed, 81 insertions(+), 6 deletions(-) create mode 100644 .github/workflows/jax-tests.yml diff --git a/.github/workflows/jax-tests.yml b/.github/workflows/jax-tests.yml new file mode 100644 index 00000000..08a62df6 --- /dev/null +++ b/.github/workflows/jax-tests.yml @@ -0,0 +1,49 @@ +name: JAX tests + +on: + schedule: + # Every weekday at 03:53 UTC, see https://crontab.guru/ + - cron: "53 3 * * 1-5" + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + test: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: ["ubuntu-latest"] + python-version: ["3.9"] + + steps: + - name: Checkout source + uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + architecture: x64 + + - name: Setup Graphviz + uses: ts-graphviz/setup-graphviz@v1 + + - name: Install + run: | + python -m pip install --upgrade pip + python -m pip install -e '.[test]' 'jax[cpu]' + python -m pip uninstall -y lithops # tests don't run on Lithops + + - name: Run tests + run: | + # exclude tests that rely on structured types since JAX doesn't support these + pytest -k "not argmax and not argmin and not mean and not apply_reduction and not broadcast_trick" + env: + CUBED_BACKEND_ARRAY_API_MODULE: jax.numpy + JAX_ENABLE_X64: True diff --git a/cubed/backend_array_api.py b/cubed/backend_array_api.py index 58e4c1a4..b347a444 100644 --- a/cubed/backend_array_api.py +++ b/cubed/backend_array_api.py @@ -1,13 +1,38 @@ +import os +from importlib import import_module + import numpy as np -# The array implementation used for backend operations. -# This must be compatible with the Python Array API standard, although +# The array implementation used for backend operations is stored in the +# namespace variable, and defaults to array_api_compat.nump, unless it +# is overridden by an environment variable. +# It must be compatible with the Python Array API standard, although # some extra functions are used too (nan functions, take_along_axis), # which array_api_compat provides, but other Array API implementations # may not. -import array_api_compat.numpy # noqa: F401 isort:skip -namespace = array_api_compat.numpy +if "CUBED_BACKEND_ARRAY_API_MODULE" in os.environ: + # This code is based on similar code in array_api_tests + xp_name = os.environ["CUBED_BACKEND_ARRAY_API_MODULE"] + _module, _sub = xp_name, None + if "." in xp_name: + _module, _sub = xp_name.split(".", 1) + xp = import_module(_module) + if _sub: + try: + xp = getattr(xp, _sub) + except AttributeError: + # _sub may be a submodule that needs to be imported. WE can't + # do this in every case because some array modules are not + # submodules that can be imported (like mxnet.nd). + xp = import_module(xp_name) + namespace = xp + +else: + import array_api_compat.numpy + + namespace = array_api_compat.numpy + # These functions to convert to/from backend arrays # assume that no extra memory is allocated, by using the diff --git a/cubed/storage/virtual.py b/cubed/storage/virtual.py index c6eb86f3..185681b0 100644 --- a/cubed/storage/virtual.py +++ b/cubed/storage/virtual.py @@ -5,6 +5,7 @@ import zarr from zarr.indexing import BasicIndexer, is_slice +from cubed.backend_array_api import backend_array_to_numpy_array from cubed.backend_array_api import namespace as nxp from cubed.backend_array_api import numpy_array_to_backend_array from cubed.types import T_DType, T_RegularChunks, T_Shape @@ -107,7 +108,7 @@ class VirtualInMemoryArray: def __init__( self, - array: np.ndarray, # TODO: generalise + array: np.ndarray, # TODO: generalise to array API type chunks: T_RegularChunks, max_nbytes: int = 10**6, ): @@ -129,7 +130,7 @@ def __init__( self.chunks = template.chunks self.template = template if array.size > 0: - template[...] = array + template[...] = backend_array_to_numpy_array(array) def __getitem__(self, key): return self.array.__getitem__(key)