diff --git a/cubed/__init__.py b/cubed/__init__.py index 63580391..82fceac3 100644 --- a/cubed/__init__.py +++ b/cubed/__init__.py @@ -24,6 +24,7 @@ ) from .core.gufunc import apply_gufunc from .core.ops import from_array, from_zarr, map_blocks, store, to_zarr +from .nan_functions import nanmean, nansum __all__ = [ "__version__", @@ -38,6 +39,8 @@ "map_blocks", "measure_reserved_mem", "measure_reserved_memory", + "nanmean", + "nansum", "store", "to_zarr", "visualize", diff --git a/cubed/core/ops.py b/cubed/core/ops.py index 8a81fee7..97bc2643 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -729,24 +729,20 @@ def reduction( allowed_mem = x.spec.allowed_mem max_mem = allowed_mem - x.spec.reserved_mem - # reduce initial chunks (if any axis chunksize is > 1) - if ( - any(s > 1 for i, s in enumerate(result.chunksize) if i in axis) - or func != combine_func - ): - args = (result, inds) - adjust_chunks = { - i: (1,) * len(c) if i in axis else c for i, c in enumerate(result.chunks) - } - result = blockwise( - func, - inds, - *args, - axis=axis, - keepdims=True, - dtype=intermediate_dtype, - adjust_chunks=adjust_chunks, - ) + # reduce initial chunks + args = (result, inds) + adjust_chunks = { + i: (1,) * len(c) if i in axis else c for i, c in enumerate(result.chunks) + } + result = blockwise( + func, + inds, + *args, + axis=axis, + keepdims=True, + dtype=intermediate_dtype, + adjust_chunks=adjust_chunks, + ) # merge/reduce along axis in multiple rounds until there's a single block in each reduction axis while any(n > 1 for i, n in enumerate(result.numblocks) if i in axis): diff --git a/cubed/nan_functions.py b/cubed/nan_functions.py new file mode 100644 index 00000000..836fb1fa --- /dev/null +++ b/cubed/nan_functions.py @@ -0,0 +1,73 @@ +import numpy as np + +from cubed.array_api.dtypes import ( + _numeric_dtypes, + _signed_integer_dtypes, + _unsigned_integer_dtypes, + complex64, + complex128, + float32, + float64, + int64, + uint64, +) +from cubed.core import reduction + +# TODO: refactor once nan functions are standardized: +# https://github.com/data-apis/array-api/issues/621 + + +def nanmean(x, /, *, axis=None, keepdims=False): + """Compute the arithmetic mean along the specified axis, ignoring NaNs.""" + dtype = x.dtype + intermediate_dtype = [("n", np.int64), ("total", np.float64)] + return reduction( + x, + _nanmean_func, + combine_func=_nanmean_combine, + aggegrate_func=_nanmean_aggregate, + axis=axis, + intermediate_dtype=intermediate_dtype, + dtype=dtype, + keepdims=keepdims, + ) + + +def _nanmean_func(a, **kwargs): + n = _nannumel(a, **kwargs) + total = np.nansum(a, **kwargs) + return {"n": n, "total": total} + + +def _nanmean_combine(a, **kwargs): + n = np.nansum(a["n"], **kwargs) + total = np.nansum(a["total"], **kwargs) + return {"n": n, "total": total} + + +def _nanmean_aggregate(a): + with np.errstate(divide="ignore", invalid="ignore"): + return np.divide(a["total"], a["n"]) + + +def _nannumel(x, **kwargs): + """A reduction to count the number of elements, excluding nans""" + return np.sum(~(np.isnan(x)), **kwargs) + + +def nansum(x, /, *, axis=None, dtype=None, keepdims=False): + """Return the sum of array elements over a given axis treating NaNs as zero.""" + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in nansum") + if dtype is None: + if x.dtype in _signed_integer_dtypes: + dtype = int64 + elif x.dtype in _unsigned_integer_dtypes: + dtype = uint64 + elif x.dtype == float32: + dtype = float64 + elif x.dtype == complex64: + dtype = complex128 + else: + dtype = x.dtype + return reduction(x, np.nansum, axis=axis, dtype=dtype, keepdims=keepdims) diff --git a/cubed/tests/test_nan_functions.py b/cubed/tests/test_nan_functions.py new file mode 100644 index 00000000..53264e79 --- /dev/null +++ b/cubed/tests/test_nan_functions.py @@ -0,0 +1,40 @@ +import numpy as np +import pytest +from numpy.testing import assert_array_equal + +import cubed +import cubed.array_api as xp + + +@pytest.fixture() +def spec(tmp_path): + return cubed.Spec(tmp_path, allowed_mem=100000) + + +def test_nanmean(spec): + a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, xp.nan]], chunks=(2, 2), spec=spec) + b = cubed.nanmean(a) + assert_array_equal( + b.compute(), np.nanmean(np.array([[1, 2, 3], [4, 5, 6], [7, 8, np.nan]])) + ) + + +@pytest.mark.filterwarnings("ignore::RuntimeWarning") +def test_nanmean_allnan(spec): + a = xp.asarray([xp.nan], spec=spec) + b = cubed.nanmean(a) + assert_array_equal(b.compute(), np.nanmean(np.array([np.nan]))) + + +def test_nansum(spec): + a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, xp.nan]], chunks=(2, 2), spec=spec) + b = cubed.nansum(a) + assert_array_equal( + b.compute(), np.nansum(np.array([[1, 2, 3], [4, 5, 6], [7, 8, np.nan]])) + ) + + +def test_nansum_allnan(spec): + a = xp.asarray([xp.nan], spec=spec) + b = cubed.nansum(a) + assert_array_equal(b.compute(), np.nansum(np.array([np.nan]))) diff --git a/docs/api.rst b/docs/api.rst index be4ad533..87672809 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -46,6 +46,17 @@ Chunk-specific functions apply_gufunc map_blocks +Non-standardised functions +========================== + +.. currentmodule:: cubed +.. autosummary:: + :nosignatures: + :toctree: generated/ + + nanmean + nansum + Random number generation ========================