From 8143af70f172ed05ef546da3dfeae743c42eabc1 Mon Sep 17 00:00:00 2001 From: Tom White Date: Wed, 22 May 2024 11:34:41 +0100 Subject: [PATCH] =?UTF-8?q?Revert=20"Implement=20NumPy's=20`=5F=5Farray=5F?= =?UTF-8?q?function=5F=5F`=20protocol=20for=20array=20methods=20tha?= =?UTF-8?q?=E2=80=A6"=20(#469)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit af1ab743a76982297bca063aac93d8e5f4aef88a. --- cubed/array_api/array_object.py | 24 ------------------------ cubed/nan_functions.py | 9 ++------- cubed/pad.py | 7 ------- cubed/tests/test_nan_functions.py | 12 ++++-------- cubed/tests/test_pad.py | 8 ++------ 5 files changed, 8 insertions(+), 52 deletions(-) diff --git a/cubed/array_api/array_object.py b/cubed/array_api/array_object.py index f4ba89e6..3a2b23ff 100644 --- a/cubed/array_api/array_object.py +++ b/cubed/array_api/array_object.py @@ -29,24 +29,6 @@ 120 # cubed doesn't have a config module like dask does so hard-code this for now ) -_HANDLED_FUNCTIONS = {} - - -def implements(*numpy_functions): - """Register an __array_function__ implementation for cubed.Array - - Note that this is **only** used for functions that are not defined in the - Array API Standard. - """ - - def decorator(cubed_func): - for numpy_function in numpy_functions: - _HANDLED_FUNCTIONS[numpy_function] = cubed_func - - return cubed_func - - return decorator - class Array(CoreArray): """Chunked array backed by Zarr storage that conforms to the Python Array API standard.""" @@ -62,12 +44,6 @@ def __array__(self, dtype=None) -> np.ndarray: x = np.array(x) return x - def __array_function__(self, func, types, args, kwargs): - # Only dispatch to functions that are not defined in the Array API Standard - if func in _HANDLED_FUNCTIONS: - return _HANDLED_FUNCTIONS[func](*args, **kwargs) - return NotImplemented - def __repr__(self): return f"cubed.Array<{self.name}, shape={self.shape}, dtype={self.dtype}, chunks={self.chunks}>" diff --git a/cubed/nan_functions.py b/cubed/nan_functions.py index 2de161df..2acd308b 100644 --- a/cubed/nan_functions.py +++ b/cubed/nan_functions.py @@ -1,6 +1,5 @@ import numpy as np -from cubed.array_api.array_object import implements from cubed.array_api.dtypes import ( _numeric_dtypes, _signed_integer_dtypes, @@ -19,12 +18,9 @@ # https://github.com/data-apis/array-api/issues/621 -@implements(np.nanmean) -def nanmean( - x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None -): +def nanmean(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None): """Compute the arithmetic mean along the specified axis, ignoring NaNs.""" - dtype = dtype or x.dtype + dtype = x.dtype intermediate_dtype = [("n", nxp.int64), ("total", nxp.float64)] return reduction( x, @@ -65,7 +61,6 @@ def _nannumel(x, **kwargs): return nxp.sum(~(nxp.isnan(x)), **kwargs) -@implements(np.nansum) def nansum( x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None ): diff --git a/cubed/pad.py b/cubed/pad.py index afb0c4d9..c292c65e 100644 --- a/cubed/pad.py +++ b/cubed/pad.py @@ -1,13 +1,6 @@ -import numpy as np - -from cubed.array_api.array_object import implements from cubed.array_api.manipulation_functions import concat -# TODO: refactor once pad is standardized: -# https://github.com/data-apis/array-api/issues/187 - -@implements(np.pad) def pad(x, pad_width, mode=None, chunks=None): """Pad an array.""" if len(pad_width) != x.ndim: diff --git a/cubed/tests/test_nan_functions.py b/cubed/tests/test_nan_functions.py index f67ce71d..53264e79 100644 --- a/cubed/tests/test_nan_functions.py +++ b/cubed/tests/test_nan_functions.py @@ -11,11 +11,9 @@ def spec(tmp_path): return cubed.Spec(tmp_path, allowed_mem=100000) -@pytest.mark.parametrize("namespace", [cubed, np]) -def test_nanmean(spec, namespace): +def test_nanmean(spec): a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, xp.nan]], chunks=(2, 2), spec=spec) - b = namespace.nanmean(a) - assert isinstance(b, cubed.Array) + b = cubed.nanmean(a) assert_array_equal( b.compute(), np.nanmean(np.array([[1, 2, 3], [4, 5, 6], [7, 8, np.nan]])) ) @@ -28,11 +26,9 @@ def test_nanmean_allnan(spec): assert_array_equal(b.compute(), np.nanmean(np.array([np.nan]))) -@pytest.mark.parametrize("namespace", [cubed, np]) -def test_nansum(spec, namespace): +def test_nansum(spec): a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, xp.nan]], chunks=(2, 2), spec=spec) - b = namespace.nansum(a) - assert isinstance(b, cubed.Array) + b = cubed.nansum(a) assert_array_equal( b.compute(), np.nansum(np.array([[1, 2, 3], [4, 5, 6], [7, 8, np.nan]])) ) diff --git a/cubed/tests/test_pad.py b/cubed/tests/test_pad.py index 027348da..7ba985f4 100644 --- a/cubed/tests/test_pad.py +++ b/cubed/tests/test_pad.py @@ -11,15 +11,11 @@ def spec(tmp_path): return cubed.Spec(tmp_path, allowed_mem=100000) -@pytest.mark.parametrize("namespace", [cubed, np]) -def test_pad(spec, namespace): +def test_pad(spec): an = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec) - # check that we can dispatch via the numpy namespace (via __array_function__) - # since pad is not yet a part of the Array API Standard - b = namespace.pad(a, ((1, 0), (0, 0)), mode="symmetric") - assert isinstance(b, cubed.Array) + b = cubed.pad(a, ((1, 0), (0, 0)), mode="symmetric") assert b.chunks == ((2, 2), (2, 1)) assert_array_equal(b.compute(), np.pad(an, ((1, 0), (0, 0)), mode="symmetric"))