Skip to content

Commit

Permalink
Implement NumPy's __array_function__ protocol for array methods tha…
Browse files Browse the repository at this point in the history
…t are not in the Array API Standard (#468)

* nanmean, nansum, pad
  • Loading branch information
tomwhite committed Oct 1, 2024
1 parent 03f3e0e commit b4191e7
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 8 deletions.
24 changes: 24 additions & 0 deletions cubed/array_api/array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,24 @@
120 # cubed doesn't have a config module like dask does so hard-code this for now
)

_HANDLED_FUNCTIONS = {}


def implements(*numpy_functions):
"""Register an __array_function__ implementation for cubed.Array
Note that this is **only** used for functions that are not defined in the
Array API Standard.
"""

def decorator(cubed_func):
for numpy_function in numpy_functions:
_HANDLED_FUNCTIONS[numpy_function] = cubed_func

return cubed_func

return decorator


class Array(CoreArray):
"""Chunked array backed by Zarr storage that conforms to the Python Array API standard."""
Expand All @@ -44,6 +62,12 @@ def __array__(self, dtype=None) -> np.ndarray:
x = np.array(x)
return x

def __array_function__(self, func, types, args, kwargs):
# Only dispatch to functions that are not defined in the Array API Standard
if func in _HANDLED_FUNCTIONS:
return _HANDLED_FUNCTIONS[func](*args, **kwargs)
return NotImplemented

def __repr__(self):
return f"cubed.Array<{self.name}, shape={self.shape}, dtype={self.dtype}, chunks={self.chunks}>"

Expand Down
9 changes: 7 additions & 2 deletions cubed/nan_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np

from cubed.array_api.array_object import implements
from cubed.array_api.dtypes import (
_numeric_dtypes,
_signed_integer_dtypes,
Expand All @@ -18,9 +19,12 @@
# https://github.com/data-apis/array-api/issues/621


def nanmean(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
@implements(np.nanmean)
def nanmean(
x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None
):
"""Compute the arithmetic mean along the specified axis, ignoring NaNs."""
dtype = x.dtype
dtype = dtype or x.dtype
intermediate_dtype = [("n", nxp.int64), ("total", nxp.float64)]
return reduction(
x,
Expand Down Expand Up @@ -61,6 +65,7 @@ def _nannumel(x, **kwargs):
return nxp.sum(~(nxp.isnan(x)), **kwargs)


@implements(np.nansum)
def nansum(
x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None
):
Expand Down
7 changes: 7 additions & 0 deletions cubed/pad.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import numpy as np

from cubed.array_api.array_object import implements
from cubed.array_api.manipulation_functions import concat

# TODO: refactor once pad is standardized:
# https://github.com/data-apis/array-api/issues/187


@implements(np.pad)
def pad(x, pad_width, mode=None, chunks=None):
"""Pad an array."""
if len(pad_width) != x.ndim:
Expand Down
12 changes: 8 additions & 4 deletions cubed/tests/test_nan_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ def spec(tmp_path):
return cubed.Spec(tmp_path, allowed_mem=100000)


def test_nanmean(spec):
@pytest.mark.parametrize("namespace", [cubed, np])
def test_nanmean(spec, namespace):
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, xp.nan]], chunks=(2, 2), spec=spec)
b = cubed.nanmean(a)
b = namespace.nanmean(a)
assert isinstance(b, cubed.Array)
assert_array_equal(
b.compute(), np.nanmean(np.array([[1, 2, 3], [4, 5, 6], [7, 8, np.nan]]))
)
Expand All @@ -26,9 +28,11 @@ def test_nanmean_allnan(spec):
assert_array_equal(b.compute(), np.nanmean(np.array([np.nan])))


def test_nansum(spec):
@pytest.mark.parametrize("namespace", [cubed, np])
def test_nansum(spec, namespace):
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, xp.nan]], chunks=(2, 2), spec=spec)
b = cubed.nansum(a)
b = namespace.nansum(a)
assert isinstance(b, cubed.Array)
assert_array_equal(
b.compute(), np.nansum(np.array([[1, 2, 3], [4, 5, 6], [7, 8, np.nan]]))
)
Expand Down
8 changes: 6 additions & 2 deletions cubed/tests/test_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@ def spec(tmp_path):
return cubed.Spec(tmp_path, allowed_mem=100000)


def test_pad(spec):
@pytest.mark.parametrize("namespace", [cubed, np])
def test_pad(spec, namespace):
an = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
b = cubed.pad(a, ((1, 0), (0, 0)), mode="symmetric")
# check that we can dispatch via the numpy namespace (via __array_function__)
# since pad is not yet a part of the Array API Standard
b = namespace.pad(a, ((1, 0), (0, 0)), mode="symmetric")
assert isinstance(b, cubed.Array)
assert b.chunks == ((2, 2), (2, 1))

assert_array_equal(b.compute(), np.pad(an, ((1, 0), (0, 0)), mode="symmetric"))

0 comments on commit b4191e7

Please sign in to comment.