Skip to content

Commit

Permalink
Revert "Implement NumPy's __array_function__ protocol for array met…
Browse files Browse the repository at this point in the history
…hods tha…" (#469)

This reverts commit af1ab74.
  • Loading branch information
tomwhite authored May 22, 2024
1 parent af1ab74 commit 8143af7
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 52 deletions.
24 changes: 0 additions & 24 deletions cubed/array_api/array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,6 @@
120 # cubed doesn't have a config module like dask does so hard-code this for now
)

_HANDLED_FUNCTIONS = {}


def implements(*numpy_functions):
"""Register an __array_function__ implementation for cubed.Array
Note that this is **only** used for functions that are not defined in the
Array API Standard.
"""

def decorator(cubed_func):
for numpy_function in numpy_functions:
_HANDLED_FUNCTIONS[numpy_function] = cubed_func

return cubed_func

return decorator


class Array(CoreArray):
"""Chunked array backed by Zarr storage that conforms to the Python Array API standard."""
Expand All @@ -62,12 +44,6 @@ def __array__(self, dtype=None) -> np.ndarray:
x = np.array(x)
return x

def __array_function__(self, func, types, args, kwargs):
# Only dispatch to functions that are not defined in the Array API Standard
if func in _HANDLED_FUNCTIONS:
return _HANDLED_FUNCTIONS[func](*args, **kwargs)
return NotImplemented

def __repr__(self):
return f"cubed.Array<{self.name}, shape={self.shape}, dtype={self.dtype}, chunks={self.chunks}>"

Expand Down
9 changes: 2 additions & 7 deletions cubed/nan_functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np

from cubed.array_api.array_object import implements
from cubed.array_api.dtypes import (
_numeric_dtypes,
_signed_integer_dtypes,
Expand All @@ -19,12 +18,9 @@
# https://github.com/data-apis/array-api/issues/621


@implements(np.nanmean)
def nanmean(
x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None
):
def nanmean(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
"""Compute the arithmetic mean along the specified axis, ignoring NaNs."""
dtype = dtype or x.dtype
dtype = x.dtype
intermediate_dtype = [("n", nxp.int64), ("total", nxp.float64)]
return reduction(
x,
Expand Down Expand Up @@ -65,7 +61,6 @@ def _nannumel(x, **kwargs):
return nxp.sum(~(nxp.isnan(x)), **kwargs)


@implements(np.nansum)
def nansum(
x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None
):
Expand Down
7 changes: 0 additions & 7 deletions cubed/pad.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
import numpy as np

from cubed.array_api.array_object import implements
from cubed.array_api.manipulation_functions import concat

# TODO: refactor once pad is standardized:
# https://github.com/data-apis/array-api/issues/187


@implements(np.pad)
def pad(x, pad_width, mode=None, chunks=None):
"""Pad an array."""
if len(pad_width) != x.ndim:
Expand Down
12 changes: 4 additions & 8 deletions cubed/tests/test_nan_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@ def spec(tmp_path):
return cubed.Spec(tmp_path, allowed_mem=100000)


@pytest.mark.parametrize("namespace", [cubed, np])
def test_nanmean(spec, namespace):
def test_nanmean(spec):
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, xp.nan]], chunks=(2, 2), spec=spec)
b = namespace.nanmean(a)
assert isinstance(b, cubed.Array)
b = cubed.nanmean(a)
assert_array_equal(
b.compute(), np.nanmean(np.array([[1, 2, 3], [4, 5, 6], [7, 8, np.nan]]))
)
Expand All @@ -28,11 +26,9 @@ def test_nanmean_allnan(spec):
assert_array_equal(b.compute(), np.nanmean(np.array([np.nan])))


@pytest.mark.parametrize("namespace", [cubed, np])
def test_nansum(spec, namespace):
def test_nansum(spec):
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, xp.nan]], chunks=(2, 2), spec=spec)
b = namespace.nansum(a)
assert isinstance(b, cubed.Array)
b = cubed.nansum(a)
assert_array_equal(
b.compute(), np.nansum(np.array([[1, 2, 3], [4, 5, 6], [7, 8, np.nan]]))
)
Expand Down
8 changes: 2 additions & 6 deletions cubed/tests/test_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,11 @@ def spec(tmp_path):
return cubed.Spec(tmp_path, allowed_mem=100000)


@pytest.mark.parametrize("namespace", [cubed, np])
def test_pad(spec, namespace):
def test_pad(spec):
an = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
# check that we can dispatch via the numpy namespace (via __array_function__)
# since pad is not yet a part of the Array API Standard
b = namespace.pad(a, ((1, 0), (0, 0)), mode="symmetric")
assert isinstance(b, cubed.Array)
b = cubed.pad(a, ((1, 0), (0, 0)), mode="symmetric")
assert b.chunks == ((2, 2), (2, 1))

assert_array_equal(b.compute(), np.pad(an, ((1, 0), (0, 0)), mode="symmetric"))

0 comments on commit 8143af7

Please sign in to comment.