From be97926446869f0f6fe2a36ad5355a72f0ff95e7 Mon Sep 17 00:00:00 2001 From: Tom White Date: Mon, 27 Mar 2023 17:07:38 +0100 Subject: [PATCH] Add `nansum` and `nanmean` functions to `cubed.array` (not `cubed.array_api` since they are not yet standard) --- cubed/array/__init__.py | 7 +++ cubed/array/nan_functions.py | 73 +++++++++++++++++++++++++ cubed/tests/test_array_nan_functions.py | 40 ++++++++++++++ docs/api.rst | 11 ++++ 4 files changed, 131 insertions(+) create mode 100644 cubed/array/__init__.py create mode 100644 cubed/array/nan_functions.py create mode 100644 cubed/tests/test_array_nan_functions.py diff --git a/cubed/array/__init__.py b/cubed/array/__init__.py new file mode 100644 index 000000000..ca88ca0d1 --- /dev/null +++ b/cubed/array/__init__.py @@ -0,0 +1,7 @@ +from ..array_api import * # noqa: F401, F403 + +__all__ = [] + +from .nan_functions import nanmean, nansum + +__all__ += ["nanmean", "nansum"] diff --git a/cubed/array/nan_functions.py b/cubed/array/nan_functions.py new file mode 100644 index 000000000..836fb1faf --- /dev/null +++ b/cubed/array/nan_functions.py @@ -0,0 +1,73 @@ +import numpy as np + +from cubed.array_api.dtypes import ( + _numeric_dtypes, + _signed_integer_dtypes, + _unsigned_integer_dtypes, + complex64, + complex128, + float32, + float64, + int64, + uint64, +) +from cubed.core import reduction + +# TODO: refactor once nan functions are standardized: +# https://github.com/data-apis/array-api/issues/621 + + +def nanmean(x, /, *, axis=None, keepdims=False): + """Compute the arithmetic mean along the specified axis, ignoring NaNs.""" + dtype = x.dtype + intermediate_dtype = [("n", np.int64), ("total", np.float64)] + return reduction( + x, + _nanmean_func, + combine_func=_nanmean_combine, + aggegrate_func=_nanmean_aggregate, + axis=axis, + intermediate_dtype=intermediate_dtype, + dtype=dtype, + keepdims=keepdims, + ) + + +def _nanmean_func(a, **kwargs): + n = _nannumel(a, **kwargs) + total = np.nansum(a, **kwargs) + return {"n": n, "total": total} + + +def _nanmean_combine(a, **kwargs): + n = np.nansum(a["n"], **kwargs) + total = np.nansum(a["total"], **kwargs) + return {"n": n, "total": total} + + +def _nanmean_aggregate(a): + with np.errstate(divide="ignore", invalid="ignore"): + return np.divide(a["total"], a["n"]) + + +def _nannumel(x, **kwargs): + """A reduction to count the number of elements, excluding nans""" + return np.sum(~(np.isnan(x)), **kwargs) + + +def nansum(x, /, *, axis=None, dtype=None, keepdims=False): + """Return the sum of array elements over a given axis treating NaNs as zero.""" + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in nansum") + if dtype is None: + if x.dtype in _signed_integer_dtypes: + dtype = int64 + elif x.dtype in _unsigned_integer_dtypes: + dtype = uint64 + elif x.dtype == float32: + dtype = float64 + elif x.dtype == complex64: + dtype = complex128 + else: + dtype = x.dtype + return reduction(x, np.nansum, axis=axis, dtype=dtype, keepdims=keepdims) diff --git a/cubed/tests/test_array_nan_functions.py b/cubed/tests/test_array_nan_functions.py new file mode 100644 index 000000000..7e1ded617 --- /dev/null +++ b/cubed/tests/test_array_nan_functions.py @@ -0,0 +1,40 @@ +import numpy as np +import pytest +from numpy.testing import assert_array_equal + +import cubed +import cubed.array as xp + + +@pytest.fixture() +def spec(tmp_path): + return cubed.Spec(tmp_path, allowed_mem=100000) + + +def test_nanmean(spec): + a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, xp.nan]], chunks=(2, 2), spec=spec) + b = xp.nanmean(a) + assert_array_equal( + b.compute(), np.nanmean(np.array([[1, 2, 3], [4, 5, 6], [7, 8, np.nan]])) + ) + + +@pytest.mark.filterwarnings("ignore::RuntimeWarning") +def test_nanmean_allnan(spec): + a = xp.asarray([xp.nan], spec=spec) + b = xp.nanmean(a) + assert_array_equal(b.compute(), np.nanmean(np.array([np.nan]))) + + +def test_nansum(spec): + a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, xp.nan]], chunks=(2, 2), spec=spec) + b = xp.nansum(a) + assert_array_equal( + b.compute(), np.nansum(np.array([[1, 2, 3], [4, 5, 6], [7, 8, np.nan]])) + ) + + +def test_nansum_allnan(spec): + a = xp.asarray([xp.nan], spec=spec) + b = xp.nansum(a) + assert_array_equal(b.compute(), np.nansum(np.array([np.nan]))) diff --git a/docs/api.rst b/docs/api.rst index be4ad5333..9df12b1b3 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -46,6 +46,17 @@ Chunk-specific functions apply_gufunc map_blocks +Non-standardised functions +========================== + +.. currentmodule:: cubed.array +.. autosummary:: + :nosignatures: + :toctree: generated/ + + nanmean + nansum + Random number generation ========================