Skip to content

Commit

Permalink
Array API backend (#317)
Browse files Browse the repository at this point in the history
* Array api backend prep

* Use math ceil rather than numpy as it's built in

* Add central definition of the array API namespace used by the backend

* Convert blockwise to use backend array API namespace

* Use backend array API namespace in Cubed array API implementation

* Convert to/from backend array

* Add array-api-compat to dependencies

* Ignore missing imports for mypy for array-api-compat

* Fix dtype in test for mean
  • Loading branch information
tomwhite authored Nov 23, 2023
1 parent d19caec commit e87671d
Show file tree
Hide file tree
Showing 22 changed files with 232 additions and 155 deletions.
77 changes: 43 additions & 34 deletions cubed/array_api/array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
float64,
)
from cubed.array_api.linear_algebra_functions import matmul
from cubed.backend_array_api import namespace as nxp
from cubed.core.array import CoreArray
from cubed.core.ops import elemwise
from cubed.utils import memory_repr
Expand Down Expand Up @@ -118,54 +119,54 @@ def T(self):
def __neg__(self, /):
if self.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in __neg__")
return elemwise(np.negative, self, dtype=self.dtype)
return elemwise(nxp.negative, self, dtype=self.dtype)

def __pos__(self, /):
if self.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in __pos__")
return elemwise(np.positive, self, dtype=self.dtype)
return elemwise(nxp.positive, self, dtype=self.dtype)

def __add__(self, other, /):
other = self._check_allowed_dtypes(other, "numeric", "__add__")
if other is NotImplemented:
return other
return elemwise(np.add, self, other, dtype=result_type(self, other))
return elemwise(nxp.add, self, other, dtype=result_type(self, other))

def __sub__(self, other, /):
other = self._check_allowed_dtypes(other, "numeric", "__sub__")
if other is NotImplemented:
return other
return elemwise(np.subtract, self, other, dtype=result_type(self, other))
return elemwise(nxp.subtract, self, other, dtype=result_type(self, other))

def __mul__(self, other, /):
other = self._check_allowed_dtypes(other, "numeric", "__mul__")
if other is NotImplemented:
return other
return elemwise(np.multiply, self, other, dtype=result_type(self, other))
return elemwise(nxp.multiply, self, other, dtype=result_type(self, other))

def __truediv__(self, other, /):
other = self._check_allowed_dtypes(other, "floating-point", "__truediv__")
if other is NotImplemented:
return other
return elemwise(np.divide, self, other, dtype=result_type(self, other))
return elemwise(nxp.divide, self, other, dtype=result_type(self, other))

def __floordiv__(self, other, /):
other = self._check_allowed_dtypes(other, "real numeric", "__floordiv__")
if other is NotImplemented:
return other
return elemwise(np.floor_divide, self, other, dtype=result_type(self, other))
return elemwise(nxp.floor_divide, self, other, dtype=result_type(self, other))

def __mod__(self, other, /):
other = self._check_allowed_dtypes(other, "real numeric", "__mod__")
if other is NotImplemented:
return other
return elemwise(np.remainder, self, other, dtype=result_type(self, other))
return elemwise(nxp.remainder, self, other, dtype=result_type(self, other))

def __pow__(self, other, /):
other = self._check_allowed_dtypes(other, "numeric", "__pow__")
if other is NotImplemented:
return other
return elemwise(np.power, self, other, dtype=result_type(self, other))
return elemwise(nxp.pow, self, other, dtype=result_type(self, other))

# Array Operators

Expand All @@ -180,75 +181,79 @@ def __matmul__(self, other, /):
def __invert__(self, /):
if self.dtype not in _integer_or_boolean_dtypes:
raise TypeError("Only integer or boolean dtypes are allowed in __invert__")
return elemwise(np.invert, self, dtype=self.dtype)
return elemwise(nxp.bitwise_invert, self, dtype=self.dtype)

def __and__(self, other, /):
other = self._check_allowed_dtypes(other, "integer or boolean", "__and__")
if other is NotImplemented:
return other
return elemwise(np.bitwise_and, self, other, dtype=result_type(self, other))
return elemwise(nxp.bitwise_and, self, other, dtype=result_type(self, other))

def __or__(self, other, /):
other = self._check_allowed_dtypes(other, "integer or boolean", "__or__")
if other is NotImplemented:
return other
return elemwise(np.bitwise_or, self, other, dtype=result_type(self, other))
return elemwise(nxp.bitwise_or, self, other, dtype=result_type(self, other))

def __xor__(self, other, /):
other = self._check_allowed_dtypes(other, "integer or boolean", "__xor__")
if other is NotImplemented:
return other
return elemwise(np.bitwise_xor, self, other, dtype=result_type(self, other))
return elemwise(nxp.bitwise_xor, self, other, dtype=result_type(self, other))

def __lshift__(self, other, /):
other = self._check_allowed_dtypes(other, "integer", "__lshift__")
if other is NotImplemented:
return other
return elemwise(np.left_shift, self, other, dtype=result_type(self, other))
return elemwise(
nxp.bitwise_left_shift, self, other, dtype=result_type(self, other)
)

def __rshift__(self, other, /):
other = self._check_allowed_dtypes(other, "integer", "__rshift__")
if other is NotImplemented:
return other
return elemwise(np.right_shift, self, other, dtype=result_type(self, other))
return elemwise(
nxp.bitwise_right_shift, self, other, dtype=result_type(self, other)
)

# Comparison Operators

def __eq__(self, other, /):
other = self._check_allowed_dtypes(other, "all", "__eq__")
if other is NotImplemented:
return other
return elemwise(np.equal, self, other, dtype=np.bool_)
return elemwise(nxp.equal, self, other, dtype=np.bool_)

def __ge__(self, other, /):
other = self._check_allowed_dtypes(other, "all", "__ge__")
if other is NotImplemented:
return other
return elemwise(np.greater_equal, self, other, dtype=np.bool_)
return elemwise(nxp.greater_equal, self, other, dtype=np.bool_)

def __gt__(self, other, /):
other = self._check_allowed_dtypes(other, "all", "__gt__")
if other is NotImplemented:
return other
return elemwise(np.greater, self, other, dtype=np.bool_)
return elemwise(nxp.greater, self, other, dtype=np.bool_)

def __le__(self, other, /):
other = self._check_allowed_dtypes(other, "all", "__le__")
if other is NotImplemented:
return other
return elemwise(np.less_equal, self, other, dtype=np.bool_)
return elemwise(nxp.less_equal, self, other, dtype=np.bool_)

def __lt__(self, other, /):
other = self._check_allowed_dtypes(other, "all", "__lt__")
if other is NotImplemented:
return other
return elemwise(np.less, self, other, dtype=np.bool_)
return elemwise(nxp.less, self, other, dtype=np.bool_)

def __ne__(self, other, /):
other = self._check_allowed_dtypes(other, "all", "__ne__")
if other is NotImplemented:
return other
return elemwise(np.not_equal, self, other, dtype=np.bool_)
return elemwise(nxp.not_equal, self, other, dtype=np.bool_)

# Reflected Operators

Expand All @@ -258,43 +263,43 @@ def __radd__(self, other, /):
other = self._check_allowed_dtypes(other, "numeric", "__radd__")
if other is NotImplemented:
return other
return elemwise(np.add, other, self, dtype=result_type(self, other))
return elemwise(nxp.add, other, self, dtype=result_type(self, other))

def __rsub__(self, other, /):
other = self._check_allowed_dtypes(other, "numeric", "__rsub__")
if other is NotImplemented:
return other
return elemwise(np.subtract, other, self, dtype=result_type(self, other))
return elemwise(nxp.subtract, other, self, dtype=result_type(self, other))

def __rmul__(self, other, /):
other = self._check_allowed_dtypes(other, "numeric", "__rmul__")
if other is NotImplemented:
return other
return elemwise(np.multiply, other, self, dtype=result_type(self, other))
return elemwise(nxp.multiply, other, self, dtype=result_type(self, other))

def __rtruediv__(self, other, /):
other = self._check_allowed_dtypes(other, "floating-point", "__rtruediv__")
if other is NotImplemented:
return other
return elemwise(np.divide, other, self, dtype=result_type(self, other))
return elemwise(nxp.divide, other, self, dtype=result_type(self, other))

def __rfloordiv__(self, other, /):
other = self._check_allowed_dtypes(other, "numeric", "__rfloordiv__")
if other is NotImplemented:
return other
return elemwise(np.floor_divide, other, self, dtype=result_type(self, other))
return elemwise(nxp.floor_divide, other, self, dtype=result_type(self, other))

def __rmod__(self, other, /):
other = self._check_allowed_dtypes(other, "numeric", "__rmod__")
if other is NotImplemented:
return other
return elemwise(np.remainder, other, self, dtype=result_type(self, other))
return elemwise(nxp.remainder, other, self, dtype=result_type(self, other))

def __rpow__(self, other, /):
other = self._check_allowed_dtypes(other, "numeric", "__rpow__")
if other is NotImplemented:
return other
return elemwise(np.power, other, self, dtype=result_type(self, other))
return elemwise(nxp.pow, other, self, dtype=result_type(self, other))

# (Reflected) Array Operators

Expand All @@ -310,31 +315,35 @@ def __rand__(self, other, /):
other = self._check_allowed_dtypes(other, "integer or boolean", "__rand__")
if other is NotImplemented:
return other
return elemwise(np.bitwise_and, other, self, dtype=result_type(self, other))
return elemwise(nxp.bitwise_and, other, self, dtype=result_type(self, other))

def __ror__(self, other, /):
other = self._check_allowed_dtypes(other, "integer or boolean", "__ror__")
if other is NotImplemented:
return other
return elemwise(np.bitwise_or, other, self, dtype=result_type(self, other))
return elemwise(nxp.bitwise_or, other, self, dtype=result_type(self, other))

def __rxor__(self, other, /):
other = self._check_allowed_dtypes(other, "integer or boolean", "__rxor__")
if other is NotImplemented:
return other
return elemwise(np.bitwise_xor, other, self, dtype=result_type(self, other))
return elemwise(nxp.bitwise_xor, other, self, dtype=result_type(self, other))

def __rlshift__(self, other, /):
other = self._check_allowed_dtypes(other, "integer", "__rlshift__")
if other is NotImplemented:
return other
return elemwise(np.left_shift, other, self, dtype=result_type(self, other))
return elemwise(
nxp.bitwise_left_shift, other, self, dtype=result_type(self, other)
)

def __rrshift__(self, other, /):
other = self._check_allowed_dtypes(other, "integer", "__rrshift__")
if other is NotImplemented:
return other
return elemwise(np.right_shift, other, self, dtype=result_type(self, other))
return elemwise(
nxp.bitwise_right_shift, other, self, dtype=result_type(self, other)
)

# Methods

Expand All @@ -347,7 +356,7 @@ def __abs__(self, /):
dtype = float64
else:
dtype = self.dtype
return elemwise(np.abs, self, dtype=dtype)
return elemwise(nxp.abs, self, dtype=dtype)

def __array_namespace__(self, /, *, api_version=None):
if api_version is not None and not api_version.startswith("2021."):
Expand Down
18 changes: 10 additions & 8 deletions cubed/array_api/creation_functions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import math
from typing import TYPE_CHECKING, Iterable, List

import numpy as np
from zarr.util import normalize_shape

from cubed.backend_array_api import namespace as nxp
from cubed.core import Plan, gensym, map_blocks
from cubed.core.ops import map_direct
from cubed.core.plan import new_temp_path
Expand All @@ -20,9 +22,9 @@ def arange(
) -> "Array":
if stop is None:
start, stop = 0, start
num = int(max(np.ceil((stop - start) / step), 0))
num = int(max(math.ceil((stop - start) / step), 0))
if dtype is None:
dtype = np.arange(start, stop, step * num if num else step).dtype
dtype = nxp.arange(start, stop, step * num if num else step).dtype
chunks = normalize_chunks(chunks, shape=(num,), dtype=dtype)
chunksize = chunks[0][0]
numblocks = len(chunks[0])
Expand All @@ -43,10 +45,10 @@ def arange(


def _arange(a, size, start, stop, step):
i = a[0]
i = int(a[0])
blockstart = start + (i * size * step)
blockstop = start + ((i + 1) * size * step)
return np.arange(blockstart, min(blockstop, stop), step)
return nxp.arange(blockstart, min(blockstop, stop), step)


def asarray(
Expand All @@ -64,7 +66,7 @@ def asarray(
return asarray(a.data)
elif not isinstance(getattr(a, "shape", None), Iterable):
# ensure blocks are arrays
a = np.asarray(a, dtype=dtype)
a = nxp.asarray(a, dtype=dtype)
if dtype is None:
dtype = a.dtype

Expand Down Expand Up @@ -133,9 +135,9 @@ def _eye(x, *arrays, k=None, chunksize=None, block_id=None):
i, j = block_id
bk = (j - i) * chunksize
if bk - chunksize <= k <= bk + chunksize:
return np.eye(x.shape[0], x.shape[1], k=k - bk, dtype=x.dtype)
return nxp.eye(x.shape[0], x.shape[1], k=k - bk, dtype=x.dtype)
else:
return np.zeros_like(x)
return nxp.zeros_like(x)


def full(
Expand Down Expand Up @@ -225,7 +227,7 @@ def _linspace(x, *arrays, size, start, step, endpoint, linspace_dtype, block_id=
adjusted_bs = bs - 1 if endpoint else bs
blockstart = start + (i * size * step)
blockstop = blockstart + (adjusted_bs * step)
return np.linspace(
return nxp.linspace(
blockstart, blockstop, bs, endpoint=endpoint, dtype=linspace_dtype
)

Expand Down
3 changes: 2 additions & 1 deletion cubed/array_api/data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
from numpy.array_api._typing import Dtype

from cubed.backend_array_api import namespace as nxp
from cubed.core import CoreArray, map_blocks

from .dtypes import (
Expand All @@ -25,7 +26,7 @@ def astype(x, dtype, /, *, copy=True):


def _astype(a, astype_dtype):
return a.astype(astype_dtype)
return nxp.astype(a, astype_dtype)


def can_cast(from_, to, /):
Expand Down
Loading

0 comments on commit e87671d

Please sign in to comment.