Skip to content

Commit

Permalink
Add nansum and nanmean functions to cubed.array (not `cubed.arr…
Browse files Browse the repository at this point in the history
…ay_api` since they are not yet standard)
  • Loading branch information
tomwhite committed Sep 13, 2023
1 parent a45a8d5 commit be97926
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 0 deletions.
7 changes: 7 additions & 0 deletions cubed/array/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from ..array_api import * # noqa: F401, F403

__all__ = []

from .nan_functions import nanmean, nansum

__all__ += ["nanmean", "nansum"]
73 changes: 73 additions & 0 deletions cubed/array/nan_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import numpy as np

from cubed.array_api.dtypes import (
_numeric_dtypes,
_signed_integer_dtypes,
_unsigned_integer_dtypes,
complex64,
complex128,
float32,
float64,
int64,
uint64,
)
from cubed.core import reduction

# TODO: refactor once nan functions are standardized:
# https://github.com/data-apis/array-api/issues/621


def nanmean(x, /, *, axis=None, keepdims=False):
"""Compute the arithmetic mean along the specified axis, ignoring NaNs."""
dtype = x.dtype
intermediate_dtype = [("n", np.int64), ("total", np.float64)]
return reduction(
x,
_nanmean_func,
combine_func=_nanmean_combine,
aggegrate_func=_nanmean_aggregate,
axis=axis,
intermediate_dtype=intermediate_dtype,
dtype=dtype,
keepdims=keepdims,
)


def _nanmean_func(a, **kwargs):
n = _nannumel(a, **kwargs)
total = np.nansum(a, **kwargs)
return {"n": n, "total": total}


def _nanmean_combine(a, **kwargs):
n = np.nansum(a["n"], **kwargs)
total = np.nansum(a["total"], **kwargs)
return {"n": n, "total": total}


def _nanmean_aggregate(a):
with np.errstate(divide="ignore", invalid="ignore"):
return np.divide(a["total"], a["n"])


def _nannumel(x, **kwargs):
"""A reduction to count the number of elements, excluding nans"""
return np.sum(~(np.isnan(x)), **kwargs)


def nansum(x, /, *, axis=None, dtype=None, keepdims=False):
"""Return the sum of array elements over a given axis treating NaNs as zero."""
if x.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in nansum")
if dtype is None:
if x.dtype in _signed_integer_dtypes:
dtype = int64
elif x.dtype in _unsigned_integer_dtypes:
dtype = uint64
elif x.dtype == float32:
dtype = float64
elif x.dtype == complex64:
dtype = complex128
else:
dtype = x.dtype
return reduction(x, np.nansum, axis=axis, dtype=dtype, keepdims=keepdims)
40 changes: 40 additions & 0 deletions cubed/tests/test_array_nan_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import numpy as np
import pytest
from numpy.testing import assert_array_equal

import cubed
import cubed.array as xp


@pytest.fixture()
def spec(tmp_path):
return cubed.Spec(tmp_path, allowed_mem=100000)


def test_nanmean(spec):
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, xp.nan]], chunks=(2, 2), spec=spec)
b = xp.nanmean(a)
assert_array_equal(
b.compute(), np.nanmean(np.array([[1, 2, 3], [4, 5, 6], [7, 8, np.nan]]))
)


@pytest.mark.filterwarnings("ignore::RuntimeWarning")
def test_nanmean_allnan(spec):
a = xp.asarray([xp.nan], spec=spec)
b = xp.nanmean(a)
assert_array_equal(b.compute(), np.nanmean(np.array([np.nan])))


def test_nansum(spec):
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, xp.nan]], chunks=(2, 2), spec=spec)
b = xp.nansum(a)
assert_array_equal(
b.compute(), np.nansum(np.array([[1, 2, 3], [4, 5, 6], [7, 8, np.nan]]))
)


def test_nansum_allnan(spec):
a = xp.asarray([xp.nan], spec=spec)
b = xp.nansum(a)
assert_array_equal(b.compute(), np.nansum(np.array([np.nan])))
11 changes: 11 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ Chunk-specific functions
apply_gufunc
map_blocks

Non-standardised functions
==========================

.. currentmodule:: cubed.array
.. autosummary::
:nosignatures:
:toctree: generated/

nanmean
nansum

Random number generation
========================

Expand Down

0 comments on commit be97926

Please sign in to comment.