From e87671dbbb82bb909046f617750138e43d36eb9a Mon Sep 17 00:00:00 2001 From: Tom White Date: Thu, 23 Nov 2023 12:55:49 +0000 Subject: [PATCH] Array API backend (#317) * Array api backend prep * Use math ceil rather than numpy as it's built in * Add central definition of the array API namespace used by the backend * Convert blockwise to use backend array API namespace * Use backend array API namespace in Cubed array API implementation * Convert to/from backend array * Add array-api-compat to dependencies * Ignore missing imports for mypy for array-api-compat * Fix dtype in test for mean --- cubed/array_api/array_object.py | 77 +++++++------ cubed/array_api/creation_functions.py | 18 +-- cubed/array_api/data_type_functions.py | 3 +- cubed/array_api/elementwise_functions.py | 119 ++++++++++---------- cubed/array_api/linear_algebra_functions.py | 9 +- cubed/array_api/manipulation_functions.py | 15 ++- cubed/array_api/searching_functions.py | 3 +- cubed/array_api/statistical_functions.py | 21 ++-- cubed/array_api/utility_functions.py | 7 +- cubed/backend_array_api.py | 23 ++++ cubed/core/ops.py | 21 +++- cubed/core/plan.py | 4 +- cubed/nan_functions.py | 16 ++- cubed/primitive/blockwise.py | 7 ++ cubed/random.py | 7 +- cubed/storage/virtual.py | 12 +- cubed/storage/zarr.py | 3 +- cubed/tests/primitive/test_blockwise.py | 7 +- cubed/tests/test_core.py | 5 +- cubed/tests/test_gufunc.py | 7 +- pyproject.toml | 1 + setup.cfg | 2 + 22 files changed, 232 insertions(+), 155 deletions(-) create mode 100644 cubed/backend_array_api.py diff --git a/cubed/array_api/array_object.py b/cubed/array_api/array_object.py index e4e81392..475acb2c 100644 --- a/cubed/array_api/array_object.py +++ b/cubed/array_api/array_object.py @@ -19,6 +19,7 @@ float64, ) from cubed.array_api.linear_algebra_functions import matmul +from cubed.backend_array_api import namespace as nxp from cubed.core.array import CoreArray from cubed.core.ops import elemwise from cubed.utils import memory_repr @@ -118,54 +119,54 @@ def T(self): def __neg__(self, /): if self.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in __neg__") - return elemwise(np.negative, self, dtype=self.dtype) + return elemwise(nxp.negative, self, dtype=self.dtype) def __pos__(self, /): if self.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in __pos__") - return elemwise(np.positive, self, dtype=self.dtype) + return elemwise(nxp.positive, self, dtype=self.dtype) def __add__(self, other, /): other = self._check_allowed_dtypes(other, "numeric", "__add__") if other is NotImplemented: return other - return elemwise(np.add, self, other, dtype=result_type(self, other)) + return elemwise(nxp.add, self, other, dtype=result_type(self, other)) def __sub__(self, other, /): other = self._check_allowed_dtypes(other, "numeric", "__sub__") if other is NotImplemented: return other - return elemwise(np.subtract, self, other, dtype=result_type(self, other)) + return elemwise(nxp.subtract, self, other, dtype=result_type(self, other)) def __mul__(self, other, /): other = self._check_allowed_dtypes(other, "numeric", "__mul__") if other is NotImplemented: return other - return elemwise(np.multiply, self, other, dtype=result_type(self, other)) + return elemwise(nxp.multiply, self, other, dtype=result_type(self, other)) def __truediv__(self, other, /): other = self._check_allowed_dtypes(other, "floating-point", "__truediv__") if other is NotImplemented: return other - return elemwise(np.divide, self, other, dtype=result_type(self, other)) + return elemwise(nxp.divide, self, other, dtype=result_type(self, other)) def __floordiv__(self, other, /): other = self._check_allowed_dtypes(other, "real numeric", "__floordiv__") if other is NotImplemented: return other - return elemwise(np.floor_divide, self, other, dtype=result_type(self, other)) + return elemwise(nxp.floor_divide, self, other, dtype=result_type(self, other)) def __mod__(self, other, /): other = self._check_allowed_dtypes(other, "real numeric", "__mod__") if other is NotImplemented: return other - return elemwise(np.remainder, self, other, dtype=result_type(self, other)) + return elemwise(nxp.remainder, self, other, dtype=result_type(self, other)) def __pow__(self, other, /): other = self._check_allowed_dtypes(other, "numeric", "__pow__") if other is NotImplemented: return other - return elemwise(np.power, self, other, dtype=result_type(self, other)) + return elemwise(nxp.pow, self, other, dtype=result_type(self, other)) # Array Operators @@ -180,37 +181,41 @@ def __matmul__(self, other, /): def __invert__(self, /): if self.dtype not in _integer_or_boolean_dtypes: raise TypeError("Only integer or boolean dtypes are allowed in __invert__") - return elemwise(np.invert, self, dtype=self.dtype) + return elemwise(nxp.bitwise_invert, self, dtype=self.dtype) def __and__(self, other, /): other = self._check_allowed_dtypes(other, "integer or boolean", "__and__") if other is NotImplemented: return other - return elemwise(np.bitwise_and, self, other, dtype=result_type(self, other)) + return elemwise(nxp.bitwise_and, self, other, dtype=result_type(self, other)) def __or__(self, other, /): other = self._check_allowed_dtypes(other, "integer or boolean", "__or__") if other is NotImplemented: return other - return elemwise(np.bitwise_or, self, other, dtype=result_type(self, other)) + return elemwise(nxp.bitwise_or, self, other, dtype=result_type(self, other)) def __xor__(self, other, /): other = self._check_allowed_dtypes(other, "integer or boolean", "__xor__") if other is NotImplemented: return other - return elemwise(np.bitwise_xor, self, other, dtype=result_type(self, other)) + return elemwise(nxp.bitwise_xor, self, other, dtype=result_type(self, other)) def __lshift__(self, other, /): other = self._check_allowed_dtypes(other, "integer", "__lshift__") if other is NotImplemented: return other - return elemwise(np.left_shift, self, other, dtype=result_type(self, other)) + return elemwise( + nxp.bitwise_left_shift, self, other, dtype=result_type(self, other) + ) def __rshift__(self, other, /): other = self._check_allowed_dtypes(other, "integer", "__rshift__") if other is NotImplemented: return other - return elemwise(np.right_shift, self, other, dtype=result_type(self, other)) + return elemwise( + nxp.bitwise_right_shift, self, other, dtype=result_type(self, other) + ) # Comparison Operators @@ -218,37 +223,37 @@ def __eq__(self, other, /): other = self._check_allowed_dtypes(other, "all", "__eq__") if other is NotImplemented: return other - return elemwise(np.equal, self, other, dtype=np.bool_) + return elemwise(nxp.equal, self, other, dtype=np.bool_) def __ge__(self, other, /): other = self._check_allowed_dtypes(other, "all", "__ge__") if other is NotImplemented: return other - return elemwise(np.greater_equal, self, other, dtype=np.bool_) + return elemwise(nxp.greater_equal, self, other, dtype=np.bool_) def __gt__(self, other, /): other = self._check_allowed_dtypes(other, "all", "__gt__") if other is NotImplemented: return other - return elemwise(np.greater, self, other, dtype=np.bool_) + return elemwise(nxp.greater, self, other, dtype=np.bool_) def __le__(self, other, /): other = self._check_allowed_dtypes(other, "all", "__le__") if other is NotImplemented: return other - return elemwise(np.less_equal, self, other, dtype=np.bool_) + return elemwise(nxp.less_equal, self, other, dtype=np.bool_) def __lt__(self, other, /): other = self._check_allowed_dtypes(other, "all", "__lt__") if other is NotImplemented: return other - return elemwise(np.less, self, other, dtype=np.bool_) + return elemwise(nxp.less, self, other, dtype=np.bool_) def __ne__(self, other, /): other = self._check_allowed_dtypes(other, "all", "__ne__") if other is NotImplemented: return other - return elemwise(np.not_equal, self, other, dtype=np.bool_) + return elemwise(nxp.not_equal, self, other, dtype=np.bool_) # Reflected Operators @@ -258,43 +263,43 @@ def __radd__(self, other, /): other = self._check_allowed_dtypes(other, "numeric", "__radd__") if other is NotImplemented: return other - return elemwise(np.add, other, self, dtype=result_type(self, other)) + return elemwise(nxp.add, other, self, dtype=result_type(self, other)) def __rsub__(self, other, /): other = self._check_allowed_dtypes(other, "numeric", "__rsub__") if other is NotImplemented: return other - return elemwise(np.subtract, other, self, dtype=result_type(self, other)) + return elemwise(nxp.subtract, other, self, dtype=result_type(self, other)) def __rmul__(self, other, /): other = self._check_allowed_dtypes(other, "numeric", "__rmul__") if other is NotImplemented: return other - return elemwise(np.multiply, other, self, dtype=result_type(self, other)) + return elemwise(nxp.multiply, other, self, dtype=result_type(self, other)) def __rtruediv__(self, other, /): other = self._check_allowed_dtypes(other, "floating-point", "__rtruediv__") if other is NotImplemented: return other - return elemwise(np.divide, other, self, dtype=result_type(self, other)) + return elemwise(nxp.divide, other, self, dtype=result_type(self, other)) def __rfloordiv__(self, other, /): other = self._check_allowed_dtypes(other, "numeric", "__rfloordiv__") if other is NotImplemented: return other - return elemwise(np.floor_divide, other, self, dtype=result_type(self, other)) + return elemwise(nxp.floor_divide, other, self, dtype=result_type(self, other)) def __rmod__(self, other, /): other = self._check_allowed_dtypes(other, "numeric", "__rmod__") if other is NotImplemented: return other - return elemwise(np.remainder, other, self, dtype=result_type(self, other)) + return elemwise(nxp.remainder, other, self, dtype=result_type(self, other)) def __rpow__(self, other, /): other = self._check_allowed_dtypes(other, "numeric", "__rpow__") if other is NotImplemented: return other - return elemwise(np.power, other, self, dtype=result_type(self, other)) + return elemwise(nxp.pow, other, self, dtype=result_type(self, other)) # (Reflected) Array Operators @@ -310,31 +315,35 @@ def __rand__(self, other, /): other = self._check_allowed_dtypes(other, "integer or boolean", "__rand__") if other is NotImplemented: return other - return elemwise(np.bitwise_and, other, self, dtype=result_type(self, other)) + return elemwise(nxp.bitwise_and, other, self, dtype=result_type(self, other)) def __ror__(self, other, /): other = self._check_allowed_dtypes(other, "integer or boolean", "__ror__") if other is NotImplemented: return other - return elemwise(np.bitwise_or, other, self, dtype=result_type(self, other)) + return elemwise(nxp.bitwise_or, other, self, dtype=result_type(self, other)) def __rxor__(self, other, /): other = self._check_allowed_dtypes(other, "integer or boolean", "__rxor__") if other is NotImplemented: return other - return elemwise(np.bitwise_xor, other, self, dtype=result_type(self, other)) + return elemwise(nxp.bitwise_xor, other, self, dtype=result_type(self, other)) def __rlshift__(self, other, /): other = self._check_allowed_dtypes(other, "integer", "__rlshift__") if other is NotImplemented: return other - return elemwise(np.left_shift, other, self, dtype=result_type(self, other)) + return elemwise( + nxp.bitwise_left_shift, other, self, dtype=result_type(self, other) + ) def __rrshift__(self, other, /): other = self._check_allowed_dtypes(other, "integer", "__rrshift__") if other is NotImplemented: return other - return elemwise(np.right_shift, other, self, dtype=result_type(self, other)) + return elemwise( + nxp.bitwise_right_shift, other, self, dtype=result_type(self, other) + ) # Methods @@ -347,7 +356,7 @@ def __abs__(self, /): dtype = float64 else: dtype = self.dtype - return elemwise(np.abs, self, dtype=dtype) + return elemwise(nxp.abs, self, dtype=dtype) def __array_namespace__(self, /, *, api_version=None): if api_version is not None and not api_version.startswith("2021."): diff --git a/cubed/array_api/creation_functions.py b/cubed/array_api/creation_functions.py index a2b758d9..d7bf512f 100644 --- a/cubed/array_api/creation_functions.py +++ b/cubed/array_api/creation_functions.py @@ -1,8 +1,10 @@ +import math from typing import TYPE_CHECKING, Iterable, List import numpy as np from zarr.util import normalize_shape +from cubed.backend_array_api import namespace as nxp from cubed.core import Plan, gensym, map_blocks from cubed.core.ops import map_direct from cubed.core.plan import new_temp_path @@ -20,9 +22,9 @@ def arange( ) -> "Array": if stop is None: start, stop = 0, start - num = int(max(np.ceil((stop - start) / step), 0)) + num = int(max(math.ceil((stop - start) / step), 0)) if dtype is None: - dtype = np.arange(start, stop, step * num if num else step).dtype + dtype = nxp.arange(start, stop, step * num if num else step).dtype chunks = normalize_chunks(chunks, shape=(num,), dtype=dtype) chunksize = chunks[0][0] numblocks = len(chunks[0]) @@ -43,10 +45,10 @@ def arange( def _arange(a, size, start, stop, step): - i = a[0] + i = int(a[0]) blockstart = start + (i * size * step) blockstop = start + ((i + 1) * size * step) - return np.arange(blockstart, min(blockstop, stop), step) + return nxp.arange(blockstart, min(blockstop, stop), step) def asarray( @@ -64,7 +66,7 @@ def asarray( return asarray(a.data) elif not isinstance(getattr(a, "shape", None), Iterable): # ensure blocks are arrays - a = np.asarray(a, dtype=dtype) + a = nxp.asarray(a, dtype=dtype) if dtype is None: dtype = a.dtype @@ -133,9 +135,9 @@ def _eye(x, *arrays, k=None, chunksize=None, block_id=None): i, j = block_id bk = (j - i) * chunksize if bk - chunksize <= k <= bk + chunksize: - return np.eye(x.shape[0], x.shape[1], k=k - bk, dtype=x.dtype) + return nxp.eye(x.shape[0], x.shape[1], k=k - bk, dtype=x.dtype) else: - return np.zeros_like(x) + return nxp.zeros_like(x) def full( @@ -225,7 +227,7 @@ def _linspace(x, *arrays, size, start, step, endpoint, linspace_dtype, block_id= adjusted_bs = bs - 1 if endpoint else bs blockstart = start + (i * size * step) blockstop = blockstart + (adjusted_bs * step) - return np.linspace( + return nxp.linspace( blockstart, blockstop, bs, endpoint=endpoint, dtype=linspace_dtype ) diff --git a/cubed/array_api/data_type_functions.py b/cubed/array_api/data_type_functions.py index 1a7dd3a8..79b9f9bd 100644 --- a/cubed/array_api/data_type_functions.py +++ b/cubed/array_api/data_type_functions.py @@ -3,6 +3,7 @@ import numpy as np from numpy.array_api._typing import Dtype +from cubed.backend_array_api import namespace as nxp from cubed.core import CoreArray, map_blocks from .dtypes import ( @@ -25,7 +26,7 @@ def astype(x, dtype, /, *, copy=True): def _astype(a, astype_dtype): - return a.astype(astype_dtype) + return nxp.astype(a, astype_dtype) def can_cast(from_, to, /): diff --git a/cubed/array_api/elementwise_functions.py b/cubed/array_api/elementwise_functions.py index b1d0c2b7..02af44b5 100644 --- a/cubed/array_api/elementwise_functions.py +++ b/cubed/array_api/elementwise_functions.py @@ -15,6 +15,7 @@ float32, float64, ) +from cubed.backend_array_api import namespace as nxp from cubed.core import elemwise @@ -27,55 +28,55 @@ def abs(x, /): dtype = float64 else: dtype = x.dtype - return elemwise(np.abs, x, dtype=dtype) + return elemwise(nxp.abs, x, dtype=dtype) def acos(x, /): if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in acos") - return elemwise(np.arccos, x, dtype=x.dtype) + return elemwise(nxp.acos, x, dtype=x.dtype) def acosh(x, /): if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in acosh") - return elemwise(np.arccosh, x, dtype=x.dtype) + return elemwise(nxp.acosh, x, dtype=x.dtype) def add(x1, x2, /): if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in add") - return elemwise(np.add, x1, x2, dtype=result_type(x1, x2)) + return elemwise(nxp.add, x1, x2, dtype=result_type(x1, x2)) def asin(x, /): if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in asin") - return elemwise(np.arcsin, x, dtype=x.dtype) + return elemwise(nxp.asin, x, dtype=x.dtype) def asinh(x, /): if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in asinh") - return elemwise(np.arcsinh, x, dtype=x.dtype) + return elemwise(nxp.asinh, x, dtype=x.dtype) def atan(x, /): if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in atan") - return elemwise(np.arctan, x, dtype=x.dtype) + return elemwise(nxp.atan, x, dtype=x.dtype) def atan2(x1, x2, /): if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in atan2") - return elemwise(np.arctan2, x1, x2, dtype=result_type(x1, x2)) + return elemwise(nxp.atan2, x1, x2, dtype=result_type(x1, x2)) def atanh(x, /): if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in atanh") - return elemwise(np.arctanh, x, dtype=x.dtype) + return elemwise(nxp.atanh, x, dtype=x.dtype) def bitwise_and(x1, x2, /): @@ -84,19 +85,19 @@ def bitwise_and(x1, x2, /): or x2.dtype not in _integer_or_boolean_dtypes ): raise TypeError("Only integer or boolean dtypes are allowed in bitwise_and") - return elemwise(np.bitwise_and, x1, x2, dtype=result_type(x1, x2)) + return elemwise(nxp.bitwise_and, x1, x2, dtype=result_type(x1, x2)) def bitwise_invert(x, /): if x.dtype not in _integer_or_boolean_dtypes: raise TypeError("Only integer or boolean dtypes are allowed in bitwise_invert") - return elemwise(np.invert, x, dtype=x.dtype) + return elemwise(nxp.bitwise_invert, x, dtype=x.dtype) def bitwise_left_shift(x1, x2, /): if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: raise TypeError("Only integer dtypes are allowed in bitwise_left_shift") - return elemwise(np.left_shift, x1, x2, dtype=result_type(x1, x2)) + return elemwise(nxp.bitwise_left_shift, x1, x2, dtype=result_type(x1, x2)) def bitwise_or(x1, x2, /): @@ -105,13 +106,13 @@ def bitwise_or(x1, x2, /): or x2.dtype not in _integer_or_boolean_dtypes ): raise TypeError("Only integer or boolean dtypes are allowed in bitwise_or") - return elemwise(np.bitwise_or, x1, x2, dtype=result_type(x1, x2)) + return elemwise(nxp.bitwise_or, x1, x2, dtype=result_type(x1, x2)) def bitwise_right_shift(x1, x2, /): if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: raise TypeError("Only integer dtypes are allowed in bitwise_right_shift") - return elemwise(np.right_shift, x1, x2, dtype=result_type(x1, x2)) + return elemwise(nxp.bitwise_right_shift, x1, x2, dtype=result_type(x1, x2)) def bitwise_xor(x1, x2, /): @@ -120,7 +121,7 @@ def bitwise_xor(x1, x2, /): or x2.dtype not in _integer_or_boolean_dtypes ): raise TypeError("Only integer or boolean dtypes are allowed in bitwise_xor") - return elemwise(np.bitwise_xor, x1, x2, dtype=result_type(x1, x2)) + return elemwise(nxp.bitwise_xor, x1, x2, dtype=result_type(x1, x2)) def ceil(x, /): @@ -129,47 +130,47 @@ def ceil(x, /): if x.dtype in _integer_dtypes: # Note: The return dtype of ceil is the same as the input return x - return elemwise(np.ceil, x, dtype=x.dtype) + return elemwise(nxp.ceil, x, dtype=x.dtype) def conj(x, /): if x.dtype not in _complex_floating_dtypes: raise TypeError("Only complex floating-point dtypes are allowed in conj") - return elemwise(np.conj, x, dtype=x.dtype) + return elemwise(nxp.conj, x, dtype=x.dtype) def cos(x, /): if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in cos") - return elemwise(np.cos, x, dtype=x.dtype) + return elemwise(nxp.cos, x, dtype=x.dtype) def cosh(x, /): if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in cosh") - return elemwise(np.cosh, x, dtype=x.dtype) + return elemwise(nxp.cosh, x, dtype=x.dtype) def divide(x1, x2, /): if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in divide") - return elemwise(np.divide, x1, x2, dtype=result_type(x1, x2)) + return elemwise(nxp.divide, x1, x2, dtype=result_type(x1, x2)) def exp(x, /): if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in exp") - return elemwise(np.exp, x, dtype=x.dtype) + return elemwise(nxp.exp, x, dtype=x.dtype) def expm1(x, /): if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in expm1") - return elemwise(np.expm1, x, dtype=x.dtype) + return elemwise(nxp.expm1, x, dtype=x.dtype) def equal(x1, x2, /): - return elemwise(np.equal, x1, x2, dtype=np.bool_) + return elemwise(nxp.equal, x1, x2, dtype=np.bool_) def floor(x, /): @@ -178,21 +179,21 @@ def floor(x, /): if x.dtype in _integer_dtypes: # Note: The return dtype of floor is the same as the input return x - return elemwise(np.floor, x, dtype=x.dtype) + return elemwise(nxp.floor, x, dtype=x.dtype) def floor_divide(x1, x2, /): if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in floor_divide") - return elemwise(np.floor_divide, x1, x2, dtype=result_type(x1, x2)) + return elemwise(nxp.floor_divide, x1, x2, dtype=result_type(x1, x2)) def greater(x1, x2, /): - return elemwise(np.greater, x1, x2, dtype=np.bool_) + return elemwise(nxp.greater, x1, x2, dtype=np.bool_) def greater_equal(x1, x2, /): - return elemwise(np.greater_equal, x1, x2, dtype=np.bool_) + return elemwise(nxp.greater_equal, x1, x2, dtype=np.bool_) def imag(x, /): @@ -202,115 +203,115 @@ def imag(x, /): dtype = float64 else: raise TypeError("Only complex floating-point dtypes are allowed in imag") - return elemwise(np.imag, x, dtype=dtype) + return elemwise(nxp.imag, x, dtype=dtype) def isfinite(x, /): if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in isfinite") - return elemwise(np.isfinite, x, dtype=np.bool_) + return elemwise(nxp.isfinite, x, dtype=np.bool_) def isinf(x, /): if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in isinf") - return elemwise(np.isinf, x, dtype=np.bool_) + return elemwise(nxp.isinf, x, dtype=np.bool_) def isnan(x, /): if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in isnan") - return elemwise(np.isnan, x, dtype=np.bool_) + return elemwise(nxp.isnan, x, dtype=np.bool_) def less(x1, x2, /): - return elemwise(np.less, x1, x2, dtype=np.bool_) + return elemwise(nxp.less, x1, x2, dtype=np.bool_) def less_equal(x1, x2, /): - return elemwise(np.less_equal, x1, x2, dtype=np.bool_) + return elemwise(nxp.less_equal, x1, x2, dtype=np.bool_) def log(x, /): if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in log") - return elemwise(np.log, x, dtype=x.dtype) + return elemwise(nxp.log, x, dtype=x.dtype) def log1p(x, /): if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in log1p") - return elemwise(np.log1p, x, dtype=x.dtype) + return elemwise(nxp.log1p, x, dtype=x.dtype) def log2(x, /): if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in log2") - return elemwise(np.log2, x, dtype=x.dtype) + return elemwise(nxp.log2, x, dtype=x.dtype) def log10(x, /): if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in log10") - return elemwise(np.log10, x, dtype=x.dtype) + return elemwise(nxp.log10, x, dtype=x.dtype) def logaddexp(x1, x2, /): if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in logaddexp") - return elemwise(np.logaddexp, x1, x2, dtype=result_type(x1, x2)) + return elemwise(nxp.logaddexp, x1, x2, dtype=result_type(x1, x2)) def logical_and(x1, x2, /): if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: raise TypeError("Only boolean dtypes are allowed in logical_and") - return elemwise(np.logical_and, x1, x2, dtype=np.bool_) + return elemwise(nxp.logical_and, x1, x2, dtype=np.bool_) def logical_not(x, /): if x.dtype not in _boolean_dtypes: raise TypeError("Only boolean dtypes are allowed in logical_not") - return elemwise(np.logical_not, x, dtype=np.bool_) + return elemwise(nxp.logical_not, x, dtype=np.bool_) def logical_or(x1, x2, /): if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: raise TypeError("Only boolean dtypes are allowed in logical_or") - return elemwise(np.logical_or, x1, x2, dtype=np.bool_) + return elemwise(nxp.logical_or, x1, x2, dtype=np.bool_) def logical_xor(x1, x2, /): if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: raise TypeError("Only boolean dtypes are allowed in logical_xor") - return elemwise(np.logical_xor, x1, x2, dtype=np.bool_) + return elemwise(nxp.logical_xor, x1, x2, dtype=np.bool_) def multiply(x1, x2, /): if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in multiply") - return elemwise(np.multiply, x1, x2, dtype=result_type(x1, x2)) + return elemwise(nxp.multiply, x1, x2, dtype=result_type(x1, x2)) def negative(x, /): if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in negative") - return elemwise(np.negative, x, dtype=x.dtype) + return elemwise(nxp.negative, x, dtype=x.dtype) def not_equal(x1, x2, /): - return elemwise(np.not_equal, x1, x2, dtype=np.bool_) + return elemwise(nxp.not_equal, x1, x2, dtype=np.bool_) def positive(x, /): if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in positive") - return elemwise(np.positive, x, dtype=x.dtype) + return elemwise(nxp.positive, x, dtype=x.dtype) def pow(x1, x2, /): if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in pow") - return elemwise(np.power, x1, x2, dtype=result_type(x1, x2)) + return elemwise(nxp.pow, x1, x2, dtype=result_type(x1, x2)) def real(x, /): @@ -320,67 +321,67 @@ def real(x, /): dtype = float64 else: raise TypeError("Only complex floating-point dtypes are allowed in real") - return elemwise(np.real, x, dtype=dtype) + return elemwise(nxp.real, x, dtype=dtype) def remainder(x1, x2, /): if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in remainder") - return elemwise(np.remainder, x1, x2, dtype=result_type(x1, x2)) + return elemwise(nxp.remainder, x1, x2, dtype=result_type(x1, x2)) def round(x, /): if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in round") - return elemwise(np.round, x, dtype=x.dtype) + return elemwise(nxp.round, x, dtype=x.dtype) def sign(x, /): if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in sign") - return elemwise(np.sign, x, dtype=x.dtype) + return elemwise(nxp.sign, x, dtype=x.dtype) def sin(x, /): if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in sin") - return elemwise(np.sin, x, dtype=x.dtype) + return elemwise(nxp.sin, x, dtype=x.dtype) def sinh(x, /): if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in sinh") - return elemwise(np.sinh, x, dtype=x.dtype) + return elemwise(nxp.sinh, x, dtype=x.dtype) def sqrt(x, /): if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in sqrt") - return elemwise(np.sqrt, x, dtype=x.dtype) + return elemwise(nxp.sqrt, x, dtype=x.dtype) def square(x, /): if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in square") - return elemwise(np.square, x, dtype=x.dtype) + return elemwise(nxp.square, x, dtype=x.dtype) def subtract(x1, x2, /): if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in subtract") - return elemwise(np.subtract, x1, x2, dtype=result_type(x1, x2)) + return elemwise(nxp.subtract, x1, x2, dtype=result_type(x1, x2)) def tan(x, /): if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in tan") - return elemwise(np.tan, x, dtype=x.dtype) + return elemwise(nxp.tan, x, dtype=x.dtype) def tanh(x, /): if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in tanh") - return elemwise(np.tanh, x, dtype=x.dtype) + return elemwise(nxp.tanh, x, dtype=x.dtype) def trunc(x, /): @@ -389,4 +390,4 @@ def trunc(x, /): if x.dtype in _integer_dtypes: # Note: The return dtype of trunc is the same as the input return x - return elemwise(np.trunc, x, dtype=x.dtype) + return elemwise(nxp.trunc, x, dtype=x.dtype) diff --git a/cubed/array_api/linear_algebra_functions.py b/cubed/array_api/linear_algebra_functions.py index b5cec4f0..217b4c35 100644 --- a/cubed/array_api/linear_algebra_functions.py +++ b/cubed/array_api/linear_algebra_functions.py @@ -6,6 +6,7 @@ from cubed.array_api.data_type_functions import result_type from cubed.array_api.dtypes import _numeric_dtypes from cubed.array_api.manipulation_functions import expand_dims +from cubed.backend_array_api import namespace as nxp from cubed.core import blockwise, reduction, squeeze @@ -59,7 +60,7 @@ def matmul(x1, x2, /): def _matmul(a, b): - chunk = np.matmul(a, b) + chunk = nxp.matmul(a, b) return chunk[..., np.newaxis, :] @@ -71,7 +72,7 @@ def _sum_wo_cat(a, axis=None, dtype=None): def _chunk_sum(a, axis=None, dtype=None, keepdims=None): - return np.sum(a, axis=axis, dtype=dtype, keepdims=True) + return nxp.sum(a, axis=axis, dtype=dtype, keepdims=True) def matrix_transpose(x, /): @@ -86,7 +87,7 @@ def matrix_transpose(x, /): def outer(x1, x2, /): - return blockwise(np.outer, "ij", x1, "i", x2, "j", dtype=x1.dtype) + return blockwise(nxp.linalg.outer, "ij", x1, "i", x2, "j", dtype=x1.dtype) def tensordot(x1, x2, /, *, axes=2): @@ -137,7 +138,7 @@ def tensordot(x1, x2, /, *, axes=2): def _tensordot(a, b, axes): - x = np.tensordot(a, b, axes=axes) + x = nxp.tensordot(a, b, axes=axes) ind = [slice(None, None)] * x.ndim for a in sorted(axes[0]): ind.insert(a, None) diff --git a/cubed/array_api/manipulation_functions.py b/cubed/array_api/manipulation_functions.py index 64662f87..21fe91ca 100644 --- a/cubed/array_api/manipulation_functions.py +++ b/cubed/array_api/manipulation_functions.py @@ -7,6 +7,8 @@ from toolz import reduce from cubed.array_api.creation_functions import empty +from cubed.backend_array_api import namespace as nxp +from cubed.backend_array_api import numpy_array_to_backend_array from cubed.core import squeeze # noqa: F401 from cubed.core import blockwise, rechunk, unify_chunks from cubed.core.ops import elemwise, map_blocks, map_direct @@ -70,7 +72,7 @@ def broadcast_to(x, /, shape, *, chunks=None): def _broadcast_like(x, template): - return np.broadcast_to(x, template.shape) + return nxp.broadcast_to(x, template.shape) def concat(arrays, /, *, axis=0): @@ -127,7 +129,7 @@ def _read_concat_chunk(x, *arrays, axis=None, offsets=None, block_id=None): for ai, sl in _array_slices(offsets, start, stop): key = tuple(sl if i == axis else k for i, k in enumerate(key)) parts.append(arrays[ai].zarray[key]) - return np.concatenate(parts, axis=axis) + return nxp.concat(parts, axis=axis) def _array_slices(offsets, start, stop): @@ -151,7 +153,7 @@ def expand_dims(x, /, *, axis): chunks = tuple(1 if i in axis else next(chunks_it) for i in range(ndim_new)) return map_blocks( - np.expand_dims, x, dtype=x.dtype, chunks=chunks, new_axis=axis, axis=axis + nxp.expand_dims, x, dtype=x.dtype, chunks=chunks, new_axis=axis, axis=axis ) @@ -197,7 +199,7 @@ def permute_dims(x, /, axes): # extra memory copy due to Zarr enforcing C order on transposed array extra_projected_mem = x.chunkmem return blockwise( - np.transpose, + nxp.permute_dims, axes, x, tuple(range(x.ndim)), @@ -268,7 +270,8 @@ def _reshape_chunk(e, x, inchunks=None, outchunks=None, block_id=None): out_keys = list(product(*[range(len(c)) for c in outchunks])) idx = in_keys[out_keys.index(block_id)] out = x.zarray[get_item(x.chunks, idx)] - return out.reshape(e.shape) + out = numpy_array_to_backend_array(out) + return nxp.reshape(out, e.shape) def stack(arrays, /, *, axis=0): @@ -305,5 +308,5 @@ def _read_stack_chunk(x, *arrays, axis=None, block_id=None): array = arrays[block_id[axis]] idx = tuple(v for i, v in enumerate(block_id) if i != axis) out = array.zarray[get_item(array.chunks, idx)] - out = np.expand_dims(out, axis=axis) + out = nxp.expand_dims(out, axis=axis) return out diff --git a/cubed/array_api/searching_functions.py b/cubed/array_api/searching_functions.py index 5f13879b..dfe3b318 100644 --- a/cubed/array_api/searching_functions.py +++ b/cubed/array_api/searching_functions.py @@ -3,6 +3,7 @@ from cubed.array_api.data_type_functions import result_type from cubed.array_api.dtypes import _real_numeric_dtypes from cubed.array_api.manipulation_functions import reshape +from cubed.backend_array_api import namespace as nxp from cubed.core.ops import arg_reduction, elemwise @@ -28,4 +29,4 @@ def argmin(x, /, *, axis=None, keepdims=False): def where(condition, x1, x2, /): dtype = result_type(x1, x2) - return elemwise(np.where, condition, x1, x2, dtype=dtype) + return elemwise(nxp.where, condition, x1, x2, dtype=dtype) diff --git a/cubed/array_api/statistical_functions.py b/cubed/array_api/statistical_functions.py index b4ff7694..3d69b6e4 100644 --- a/cubed/array_api/statistical_functions.py +++ b/cubed/array_api/statistical_functions.py @@ -15,13 +15,14 @@ int64, uint64, ) +from cubed.backend_array_api import namespace as nxp from cubed.core import reduction def max(x, /, *, axis=None, keepdims=False): if x.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in max") - return reduction(x, np.max, axis=axis, dtype=x.dtype, keepdims=keepdims) + return reduction(x, nxp.max, axis=axis, dtype=x.dtype, keepdims=keepdims) def mean(x, /, *, axis=None, keepdims=False): @@ -52,19 +53,19 @@ def mean(x, /, *, axis=None, keepdims=False): def _mean_func(a, **kwargs): dtype = dict(kwargs.pop("dtype")) n = _numel(a, dtype=dtype["n"], **kwargs) - total = np.sum(a, dtype=dtype["total"], **kwargs) + total = nxp.sum(a, dtype=dtype["total"], **kwargs) return {"n": n, "total": total} def _mean_combine(a, **kwargs): dtype = dict(kwargs.pop("dtype")) - n = np.sum(a["n"], dtype=dtype["n"], **kwargs) - total = np.sum(a["total"], dtype=dtype["total"], **kwargs) + n = nxp.sum(a["n"], dtype=dtype["n"], **kwargs) + total = nxp.sum(a["total"], dtype=dtype["total"], **kwargs) return {"n": n, "total": total} def _mean_aggregate(a): - return np.divide(a["total"], a["n"]) + return nxp.divide(a["total"], a["n"]) # based on dask @@ -82,7 +83,7 @@ def _numel(x, **kwargs): if keepdims is False: return prod - return np.full(shape=(1,) * len(shape), fill_value=prod, dtype=dtype) + return nxp.full(shape=(1,) * len(shape), fill_value=prod, dtype=dtype) if not isinstance(axis, (tuple, list)): axis = [axis] @@ -95,13 +96,13 @@ def _numel(x, **kwargs): else: new_shape = tuple(shape[dim] for dim in range(len(shape)) if dim not in axis) - return np.broadcast_to(np.array(prod, dtype=dtype), new_shape) + return nxp.broadcast_to(nxp.asarray(prod, dtype=dtype), new_shape) def min(x, /, *, axis=None, keepdims=False): if x.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in min") - return reduction(x, np.min, axis=axis, dtype=x.dtype, keepdims=keepdims) + return reduction(x, nxp.min, axis=axis, dtype=x.dtype, keepdims=keepdims) def prod(x, /, *, axis=None, dtype=None, keepdims=False): @@ -121,7 +122,7 @@ def prod(x, /, *, axis=None, dtype=None, keepdims=False): extra_func_kwargs = dict(dtype=dtype) return reduction( x, - np.prod, + nxp.prod, axis=axis, dtype=dtype, keepdims=keepdims, @@ -146,7 +147,7 @@ def sum(x, /, *, axis=None, dtype=None, keepdims=False): extra_func_kwargs = dict(dtype=dtype) return reduction( x, - np.sum, + nxp.sum, axis=axis, dtype=dtype, keepdims=keepdims, diff --git a/cubed/array_api/utility_functions.py b/cubed/array_api/utility_functions.py index a2d8859d..10916e5c 100644 --- a/cubed/array_api/utility_functions.py +++ b/cubed/array_api/utility_functions.py @@ -1,16 +1,15 @@ -import numpy as np - from cubed.array_api.creation_functions import asarray +from cubed.backend_array_api import namespace as nxp from cubed.core import reduction def all(x, /, *, axis=None, keepdims=False): if x.size == 0: return asarray(True, dtype=x.dtype) - return reduction(x, np.all, axis=axis, dtype=bool, keepdims=keepdims) + return reduction(x, nxp.all, axis=axis, dtype=bool, keepdims=keepdims) def any(x, /, *, axis=None, keepdims=False): if x.size == 0: return asarray(False, dtype=x.dtype) - return reduction(x, np.any, axis=axis, dtype=bool, keepdims=keepdims) + return reduction(x, nxp.any, axis=axis, dtype=bool, keepdims=keepdims) diff --git a/cubed/backend_array_api.py b/cubed/backend_array_api.py new file mode 100644 index 00000000..58e4c1a4 --- /dev/null +++ b/cubed/backend_array_api.py @@ -0,0 +1,23 @@ +import numpy as np + +# The array implementation used for backend operations. +# This must be compatible with the Python Array API standard, although +# some extra functions are used too (nan functions, take_along_axis), +# which array_api_compat provides, but other Array API implementations +# may not. +import array_api_compat.numpy # noqa: F401 isort:skip + +namespace = array_api_compat.numpy + +# These functions to convert to/from backend arrays +# assume that no extra memory is allocated, by using the +# Python buffer protocol. +# See https://data-apis.org/array-api/latest/API_specification/generated/array_api.asarray.html + + +def backend_array_to_numpy_array(arr): + return np.asarray(arr) + + +def numpy_array_to_backend_array(arr, *, dtype=None): + return namespace.asarray(arr, dtype=dtype) diff --git a/cubed/core/ops.py b/cubed/core/ops.py index 1a4b1b78..e2160638 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -17,6 +17,8 @@ replace_ellipsis, ) +from cubed.backend_array_api import namespace as nxp +from cubed.backend_array_api import numpy_array_to_backend_array from cubed.core.array import CoreArray, check_array_specs, compute, gensym from cubed.core.plan import Plan, new_temp_path from cubed.primitive.blockwise import blockwise as primitive_blockwise @@ -77,6 +79,7 @@ def _from_array(e, x, outchunks=None, asarray=None, block_id=None): out = x[get_item(outchunks, block_id)] if asarray: out = np.asarray(out) + out = numpy_array_to_backend_array(out) return out @@ -418,6 +421,7 @@ def _read_index_chunk(x, *arrays, target_chunks=None, selection=None, block_id=N array = arrays[0] idx = block_id out = array.zarray.oindex[_target_chunk_selection(target_chunks, idx, selection)] + out = numpy_array_to_backend_array(out) return out @@ -470,7 +474,8 @@ def offset_to_block_id(offset): def func_with_block_id(func): def wrap(*a, **kw): - block_id = offset_to_block_id(a[-1].item()) + offset = int(a[-1]) # convert from 0-d array + block_id = offset_to_block_id(offset) return func(*a[:-1], block_id=block_id, **kw) return wrap @@ -702,7 +707,9 @@ def merge_chunks(x, chunks): def _copy_chunk(e, x, target_chunks=None, block_id=None): - return x.zarray[get_item(target_chunks, block_id)] + out = x.zarray[get_item(target_chunks, block_id)] + out = numpy_array_to_backend_array(out) + return out def reduction( @@ -835,7 +842,8 @@ def arg_reduction(x, /, arg_func, axis=None, *, keepdims=False): def _arg_map_func(a, axis, arg_func=None, size=None, block_id=None): i = arg_func(a, axis=axis, keepdims=True) - v = np.take_along_axis(a, i, axis=axis) + # note that the array API doesn't have take_along_axis, so this may fail + v = nxp.take_along_axis(a, i, axis=axis) # add block offset to i so it is absolute index within whole array offset = block_id[axis] * size return {"i": i + offset, "v": v} @@ -855,8 +863,9 @@ def _arg_combine(a, arg_func=None, **kwargs): # find indexes of values in v and apply to i and v vi = arg_func(v, axis=axis, **kwargs) - i_combined = np.take_along_axis(i, vi, axis=axis) - v_combined = np.take_along_axis(v, vi, axis=axis) + # note that the array API doesn't have take_along_axis, so this may fail + i_combined = nxp.take_along_axis(i, vi, axis=axis) + v_combined = nxp.take_along_axis(v, vi, axis=axis) return {"i": i_combined, "v": v_combined} @@ -877,7 +886,7 @@ def squeeze(x, /, axis): chunks = tuple(c for i, c in enumerate(x.chunks) if i not in axis) return map_blocks( - np.squeeze, x, dtype=x.dtype, chunks=chunks, drop_axis=axis, axis=axis + nxp.squeeze, x, dtype=x.dtype, chunks=chunks, drop_axis=axis, axis=axis ) diff --git a/cubed/core/plan.py b/cubed/core/plan.py index 63175fc7..af753501 100644 --- a/cubed/core/plan.py +++ b/cubed/core/plan.py @@ -5,6 +5,7 @@ import networkx as nx +from cubed.backend_array_api import backend_array_to_numpy_array from cubed.primitive.blockwise import can_fuse_pipelines, fuse from cubed.runtime.pipeline import visit_nodes from cubed.runtime.types import CubedPipeline @@ -363,7 +364,8 @@ def create_zarr_arrays(lazy_zarr_arrays, reserved_mem): projected_mem = ( max( [ - lza.initial_values.nbytes + # TODO: calculate nbytes from size and dtype itemsize + backend_array_to_numpy_array(lza.initial_values).nbytes if lza.initial_values is not None else lza.dtype.itemsize for lza in lazy_zarr_arrays diff --git a/cubed/nan_functions.py b/cubed/nan_functions.py index 836fb1fa..e3789481 100644 --- a/cubed/nan_functions.py +++ b/cubed/nan_functions.py @@ -11,6 +11,7 @@ int64, uint64, ) +from cubed.backend_array_api import namespace as nxp from cubed.core import reduction # TODO: refactor once nan functions are standardized: @@ -33,26 +34,29 @@ def nanmean(x, /, *, axis=None, keepdims=False): ) +# note that the array API doesn't have nansum or nanmean, so these functions may fail + + def _nanmean_func(a, **kwargs): n = _nannumel(a, **kwargs) - total = np.nansum(a, **kwargs) + total = nxp.nansum(a, **kwargs) return {"n": n, "total": total} def _nanmean_combine(a, **kwargs): - n = np.nansum(a["n"], **kwargs) - total = np.nansum(a["total"], **kwargs) + n = nxp.nansum(a["n"], **kwargs) + total = nxp.nansum(a["total"], **kwargs) return {"n": n, "total": total} def _nanmean_aggregate(a): with np.errstate(divide="ignore", invalid="ignore"): - return np.divide(a["total"], a["n"]) + return nxp.divide(a["total"], a["n"]) def _nannumel(x, **kwargs): """A reduction to count the number of elements, excluding nans""" - return np.sum(~(np.isnan(x)), **kwargs) + return nxp.sum(~(nxp.isnan(x)), **kwargs) def nansum(x, /, *, axis=None, dtype=None, keepdims=False): @@ -70,4 +74,4 @@ def nansum(x, /, *, axis=None, dtype=None, keepdims=False): dtype = complex128 else: dtype = x.dtype - return reduction(x, np.nansum, axis=axis, dtype=dtype, keepdims=keepdims) + return reduction(x, nxp.nansum, axis=axis, dtype=dtype, keepdims=keepdims) diff --git a/cubed/primitive/blockwise.py b/cubed/primitive/blockwise.py index eba1b3a0..64d8ea3b 100644 --- a/cubed/primitive/blockwise.py +++ b/cubed/primitive/blockwise.py @@ -8,6 +8,10 @@ import zarr from toolz import map +from cubed.backend_array_api import ( + backend_array_to_numpy_array, + numpy_array_to_backend_array, +) from cubed.runtime.types import CubedPipeline from cubed.storage.zarr import T_ZarrArray, lazy_empty from cubed.types import T_Chunks, T_DType, T_Shape, T_Store @@ -66,13 +70,16 @@ def apply_blockwise(out_key: List[int], *, config: BlockwiseSpec) -> None: arr = config.reads_map[name].open() chunk_key = key_to_slices(chunk_ind, arr) arg = arr[chunk_key] + arg = numpy_array_to_backend_array(arg) args.append(arg) result = config.function(*args) if isinstance(result, dict): # structured array with named fields for k, v in result.items(): + v = backend_array_to_numpy_array(v) config.write.open().set_basic_selection(out_chunk_key, v, fields=k) else: + result = backend_array_to_numpy_array(result) config.write.open()[out_chunk_key] = result diff --git a/cubed/random.py b/cubed/random.py index 04148429..720f9090 100644 --- a/cubed/random.py +++ b/cubed/random.py @@ -1,10 +1,11 @@ import random as pyrandom import numpy as np -import numpy.array_api as nxp from numpy.random import Generator, Philox from zarr.util import normalize_shape +from cubed.backend_array_api import namespace as nxp +from cubed.backend_array_api import numpy_array_to_backend_array from cubed.core.ops import map_direct from cubed.vendor.dask.array.core import normalize_chunks @@ -36,7 +37,9 @@ def random(size, *, chunks=None, spec=None): def _random(x, *arrays, numblocks=None, root_seed=None, block_id=None): stream_id = block_id_to_offset(block_id, numblocks) rg = Generator(Philox(key=root_seed + stream_id)) - return rg.random(x.shape) + out = rg.random(x.shape) + out = numpy_array_to_backend_array(out) + return out def block_id_to_offset(block_id, numblocks): diff --git a/cubed/storage/virtual.py b/cubed/storage/virtual.py index e1408bd6..2a48f6ec 100644 --- a/cubed/storage/virtual.py +++ b/cubed/storage/virtual.py @@ -5,6 +5,8 @@ import zarr from zarr.indexing import BasicIndexer, is_slice +from cubed.backend_array_api import namespace as nxp +from cubed.backend_array_api import numpy_array_to_backend_array from cubed.types import T_DType, T_RegularChunks, T_Shape @@ -30,7 +32,7 @@ def __getitem__(self, key): if not isinstance(key, tuple): key = (key,) indexer = BasicIndexer(key, self.template) - return np.empty(indexer.shape, dtype=self.dtype) + return nxp.empty(indexer.shape, dtype=self.dtype) @property def oindex(self): @@ -65,7 +67,7 @@ def __getitem__(self, key): if not isinstance(key, tuple): key = (key,) indexer = BasicIndexer(key, self.template) - return np.full(indexer.shape, fill_value=self.fill_value, dtype=self.dtype) + return nxp.full(indexer.shape, fill_value=self.fill_value, dtype=self.dtype) @property def oindex(self): @@ -89,8 +91,10 @@ def __init__(self, shape: T_Shape): def __getitem__(self, key): if key == () and self.shape == (): - return np.array(0, dtype=self.dtype) - return np.ravel_multi_index(_key_to_index_tuple(key), self.shape) + return nxp.asarray(0, dtype=self.dtype) + return numpy_array_to_backend_array( + np.ravel_multi_index(_key_to_index_tuple(key), self.shape), dtype=self.dtype + ) def _key_to_index_tuple(selection): diff --git a/cubed/storage/zarr.py b/cubed/storage/zarr.py index cdbd09f7..60c33874 100644 --- a/cubed/storage/zarr.py +++ b/cubed/storage/zarr.py @@ -3,6 +3,7 @@ import zarr from numpy import ndarray +from cubed.backend_array_api import backend_array_to_numpy_array from cubed.types import T_DType, T_RegularChunks, T_Shape, T_Store @@ -60,7 +61,7 @@ def create(self, mode: str = "w-") -> zarr.Array: **self.kwargs, ) if self.initial_values is not None and self.initial_values.size > 0: - target[...] = self.initial_values + target[...] = backend_array_to_numpy_array(self.initial_values) return target def open(self) -> zarr.Array: diff --git a/cubed/tests/primitive/test_blockwise.py b/cubed/tests/primitive/test_blockwise.py index 5dfea04e..d1af9caf 100644 --- a/cubed/tests/primitive/test_blockwise.py +++ b/cubed/tests/primitive/test_blockwise.py @@ -3,6 +3,7 @@ import zarr from numpy.testing import assert_array_equal +from cubed.backend_array_api import namespace as nxp from cubed.primitive.blockwise import blockwise, make_blockwise_function from cubed.runtime.executors.python import PythonDagExecutor from cubed.tests.utils import create_zarr, execute_pipeline @@ -26,7 +27,7 @@ def test_blockwise(tmp_path, executor, reserved_mem): target_store = tmp_path / "target.zarr" pipeline = blockwise( - np.outer, + nxp.linalg.outer, "ij", source1, "i", @@ -74,7 +75,7 @@ def _permute_dims(x, /, axes, allowed_mem, reserved_mem, target_store): axes = tuple(range(x.ndim))[::-1] axes = tuple(d + x.ndim if d < 0 else d for d in axes) return blockwise( - np.transpose, + nxp.permute_dims, axes, x, tuple(range(x.ndim)), @@ -146,7 +147,7 @@ def test_blockwise_allowed_mem_exceeded(tmp_path, reserved_mem): match=r"Projected blockwise memory \(\d+\) exceeds allowed_mem \(100\), including reserved_mem \(\d+\)", ): blockwise( - np.outer, + nxp.linalg.outer, "ij", source1, "i", diff --git a/cubed/tests/test_core.py b/cubed/tests/test_core.py index d0e46ae4..85e3a3f5 100644 --- a/cubed/tests/test_core.py +++ b/cubed/tests/test_core.py @@ -9,6 +9,7 @@ import cubed import cubed.array_api as xp import cubed.random +from cubed.backend_array_api import namespace as nxp from cubed.core.ops import merge_chunks from cubed.tests.utils import ( ALL_EXECUTORS, @@ -148,14 +149,14 @@ def test_to_zarr(tmp_path, spec, executor): def test_map_blocks_with_kwargs(spec, executor): # based on dask test a = xp.asarray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], chunks=5, spec=spec) - b = cubed.map_blocks(np.max, a, axis=0, keepdims=True, dtype=a.dtype, chunks=(1,)) + b = cubed.map_blocks(nxp.max, a, axis=0, keepdims=True, dtype=a.dtype, chunks=(1,)) assert_array_equal(b.compute(executor=executor), np.array([4, 9])) def test_map_blocks_with_block_id(spec, executor): # based on dask test def func(block, block_id=None, c=0): - return np.ones_like(block) * sum(block_id) + c + return nxp.ones_like(block) * int(sum(block_id)) + c a = xp.arange(10, dtype="int64", chunks=(2,)) b = cubed.map_blocks(func, a, dtype="int64") diff --git a/cubed/tests/test_gufunc.py b/cubed/tests/test_gufunc.py index 7944d609..a06d2e65 100644 --- a/cubed/tests/test_gufunc.py +++ b/cubed/tests/test_gufunc.py @@ -5,6 +5,7 @@ import cubed import cubed.array_api as xp from cubed import apply_gufunc +from cubed.backend_array_api import namespace as nxp @pytest.fixture() @@ -16,12 +17,12 @@ def spec(tmp_path): def test_apply_reduction(spec, vectorize): def stats(x): # note dtype matches output_dtypes in apply_gufunc below - return np.mean(x, axis=-1, dtype=np.float32) + return nxp.mean(x, axis=-1, dtype=np.float32) r = np.random.normal(size=(10, 20, 30)) a = cubed.from_array(r, chunks=(5, 5, 30), spec=spec) actual = apply_gufunc(stats, "(i)->()", a, output_dtypes="f", vectorize=vectorize) - expected = stats(r) + expected = np.mean(r, axis=-1, dtype=np.float32) assert actual.compute().shape == expected.shape assert_allclose(actual.compute(), expected) @@ -85,7 +86,7 @@ def outer_product(x, y): def test_gufunc_output_sizes(spec): def foo(x): - return np.broadcast_to(x[:, np.newaxis], (x.shape[0], 3)) + return nxp.broadcast_to(x[:, np.newaxis], (x.shape[0], 3)) a = cubed.from_array(np.array([1, 2, 3, 4, 5], dtype=int), spec=spec) x = apply_gufunc(foo, "()->(i_0)", a, output_dtypes=int, output_sizes={"i_0": 3}) diff --git a/pyproject.toml b/pyproject.toml index 0ac9abaa..ae72c66a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ classifiers = [ requires-python = ">=3.9" dependencies = [ "aiostream", + "array-api-compat", "fsspec", "mypy_extensions", # for rechunker "networkx < 2.8.3", diff --git a/setup.cfg b/setup.cfg index 6d0e2d68..ff572ada 100644 --- a/setup.cfg +++ b/setup.cfg @@ -22,6 +22,8 @@ allow_redefinition = True ignore_missing_imports = True [mypy-aiostream.*] ignore_missing_imports = True +[mypy-array_api_compat.*] +ignore_missing_imports = True [mypy-coiled.*] ignore_missing_imports = True [mypy-dask.*]