Skip to content

Commit

Permalink
Use new reduction implementation by default and add use_new_impl and `split_every` to reduction functions in array API
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Mar 8, 2024
1 parent 672aca2 commit 9dd82ad
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 29 deletions.
36 changes: 28 additions & 8 deletions cubed/array_api/linear_algebra_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from cubed.core import blockwise, reduction, squeeze


def matmul(x1, x2, /):
def matmul(x1, x2, /, use_new_impl=True, split_every=None):
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in matmul")

Expand Down Expand Up @@ -47,7 +47,9 @@ def matmul(x1, x2, /):
dtype=dtype,
)

out = _sum_wo_cat(out, axis=-2, dtype=dtype)
out = _sum_wo_cat(
out, axis=-2, dtype=dtype, use_new_impl=use_new_impl, split_every=split_every
)

if x1_is_1d:
out = squeeze(out, -2)
Expand All @@ -62,13 +64,19 @@ def _matmul(a, b):
return chunk[..., nxp.newaxis, :]


def _sum_wo_cat(a, axis=None, dtype=None):
def _sum_wo_cat(a, axis=None, dtype=None, use_new_impl=True, split_every=None):
if a.shape[axis] == 1:
return squeeze(a, axis)

extra_func_kwargs = dict(dtype=dtype)
return reduction(
a, _chunk_sum, axis=axis, dtype=dtype, extra_func_kwargs=extra_func_kwargs
a,
_chunk_sum,
axis=axis,
dtype=dtype,
use_new_impl=use_new_impl,
split_every=split_every,
extra_func_kwargs=extra_func_kwargs,
)


Expand All @@ -91,7 +99,7 @@ def outer(x1, x2, /):
return blockwise(nxp.linalg.outer, "ij", x1, "i", x2, "j", dtype=x1.dtype)


def tensordot(x1, x2, /, *, axes=2):
def tensordot(x1, x2, /, *, axes=2, use_new_impl=True, split_every=None):
from cubed.array_api.statistical_functions import sum

if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
Expand Down Expand Up @@ -135,7 +143,13 @@ def tensordot(x1, x2, /, *, axes=2):
adjust_chunks=adjust_chunks,
axes=(x1_axes, x2_axes),
)
return sum(out, axis=x1_axes, dtype=dtype)
return sum(
out,
axis=x1_axes,
dtype=dtype,
use_new_impl=use_new_impl,
split_every=split_every,
)


def _tensordot(a, b, axes):
Expand All @@ -147,7 +161,13 @@ def _tensordot(a, b, axes):
return x


def vecdot(x1, x2, /, *, axis=-1):
def vecdot(x1, x2, /, *, axis=-1, use_new_impl=True, split_every=None):
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in vecdot")
return tensordot(x1, x2, axes=((axis,), (axis,)))
return tensordot(
x1,
x2,
axes=((axis,), (axis,)),
use_new_impl=use_new_impl,
split_every=split_every,
)
22 changes: 18 additions & 4 deletions cubed/array_api/searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,38 @@
from cubed.core.ops import arg_reduction, elemwise


def argmax(x, /, *, axis=None, keepdims=False):
def argmax(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
if x.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in argmax")
if axis is None:
x = reshape(x, (-1,))
axis = 0
keepdims = False
return arg_reduction(x, nxp.argmax, axis=axis, keepdims=keepdims)
return arg_reduction(
x,
nxp.argmax,
axis=axis,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
)


def argmin(x, /, *, axis=None, keepdims=False):
def argmin(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
if x.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in argmin")
if axis is None:
x = reshape(x, (-1,))
axis = 0
keepdims = False
return arg_reduction(x, nxp.argmin, axis=axis, keepdims=keepdims)
return arg_reduction(
x,
nxp.argmin,
axis=axis,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
)


def where(condition, x1, x2, /):
Expand Down
38 changes: 31 additions & 7 deletions cubed/array_api/statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,21 @@
from cubed.core import reduction


def max(x, /, *, axis=None, keepdims=False):
def max(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
if x.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in max")
return reduction(x, nxp.max, axis=axis, dtype=x.dtype, keepdims=keepdims)
return reduction(
x,
nxp.max,
axis=axis,
dtype=x.dtype,
use_new_impl=use_new_impl,
split_every=split_every,
keepdims=keepdims,
)


def mean(x, /, *, axis=None, keepdims=False, use_new_impl=False, split_every=None):
def mean(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
if x.dtype not in _real_floating_dtypes:
raise TypeError("Only real floating-point dtypes are allowed in mean")
# This implementation uses NumPy and Zarr's structured arrays to store a
Expand Down Expand Up @@ -99,13 +107,23 @@ def _numel(x, **kwargs):
return nxp.broadcast_to(nxp.asarray(prod, dtype=dtype), new_shape)


def min(x, /, *, axis=None, keepdims=False):
def min(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
if x.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in min")
return reduction(x, nxp.min, axis=axis, dtype=x.dtype, keepdims=keepdims)
return reduction(
x,
nxp.min,
axis=axis,
dtype=x.dtype,
use_new_impl=use_new_impl,
split_every=split_every,
keepdims=keepdims,
)


def prod(x, /, *, axis=None, dtype=None, keepdims=False):
def prod(
x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None
):
if x.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in prod")
if dtype is None:
Expand All @@ -126,11 +144,15 @@ def prod(x, /, *, axis=None, dtype=None, keepdims=False):
axis=axis,
dtype=dtype,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
extra_func_kwargs=extra_func_kwargs,
)


def sum(x, /, *, axis=None, dtype=None, keepdims=False):
def sum(
x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None
):
if x.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in sum")
if dtype is None:
Expand All @@ -151,5 +173,7 @@ def sum(x, /, *, axis=None, dtype=None, keepdims=False):
axis=axis,
dtype=dtype,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
extra_func_kwargs=extra_func_kwargs,
)
24 changes: 20 additions & 4 deletions cubed/array_api/utility_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,29 @@
from cubed.core import reduction


def all(x, /, *, axis=None, keepdims=False):
def all(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
if x.size == 0:
return asarray(True, dtype=x.dtype)
return reduction(x, nxp.all, axis=axis, dtype=bool, keepdims=keepdims)
return reduction(
x,
nxp.all,
axis=axis,
dtype=bool,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
)


def any(x, /, *, axis=None, keepdims=False):
def any(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
if x.size == 0:
return asarray(False, dtype=x.dtype)
return reduction(x, nxp.any, axis=axis, dtype=bool, keepdims=keepdims)
return reduction(
x,
nxp.any,
axis=axis,
dtype=bool,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
)
8 changes: 6 additions & 2 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,7 @@ def reduction(
intermediate_dtype=None,
dtype=None,
keepdims=False,
use_new_impl=False,
use_new_impl=True,
split_every=None,
extra_func_kwargs=None,
) -> "Array":
Expand Down Expand Up @@ -1174,7 +1174,9 @@ def _partial_reduce(arrays, reduce_func=None, initial_func=None, axis=None):
return result


def arg_reduction(x, /, arg_func, axis=None, *, keepdims=False):
def arg_reduction(
x, /, arg_func, axis=None, *, keepdims=False, use_new_impl=True, split_every=None
):
"""A reduction that returns the array indexes, not the values."""
dtype = nxp.int64 # index data type
intermediate_dtype = [("i", dtype), ("v", x.dtype)]
Expand All @@ -1200,6 +1202,8 @@ def arg_reduction(x, /, arg_func, axis=None, *, keepdims=False):
intermediate_dtype=intermediate_dtype,
dtype=dtype,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
)


Expand Down
18 changes: 15 additions & 3 deletions cubed/nan_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# https://github.com/data-apis/array-api/issues/621


def nanmean(x, /, *, axis=None, keepdims=False):
def nanmean(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
"""Compute the arithmetic mean along the specified axis, ignoring NaNs."""
dtype = x.dtype
intermediate_dtype = [("n", nxp.int64), ("total", nxp.float64)]
Expand All @@ -31,6 +31,8 @@ def nanmean(x, /, *, axis=None, keepdims=False):
intermediate_dtype=intermediate_dtype,
dtype=dtype,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
)


Expand Down Expand Up @@ -59,7 +61,9 @@ def _nannumel(x, **kwargs):
return nxp.sum(~(nxp.isnan(x)), **kwargs)


def nansum(x, /, *, axis=None, dtype=None, keepdims=False):
def nansum(
x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None
):
"""Return the sum of array elements over a given axis treating NaNs as zero."""
if x.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in nansum")
Expand All @@ -74,4 +78,12 @@ def nansum(x, /, *, axis=None, dtype=None, keepdims=False):
dtype = complex128
else:
dtype = x.dtype
return reduction(x, nxp.nansum, axis=axis, dtype=dtype, keepdims=keepdims)
return reduction(
x,
nxp.nansum,
axis=axis,
dtype=dtype,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
)
3 changes: 2 additions & 1 deletion cubed/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,8 @@ def test_reduction_not_enough_memory(tmp_path):
spec = cubed.Spec(tmp_path, allowed_mem=50)
a = xp.ones((100, 10), dtype=np.uint8, chunks=(1, 10), spec=spec)
with pytest.raises(ValueError, match=r"Not enough memory for reduction"):
xp.sum(a, axis=0, dtype=np.uint8)
# only a problem with the old implementation, so set use_new_impl=False
xp.sum(a, axis=0, dtype=np.uint8, use_new_impl=False)


def test_partial_reduce(spec):
Expand Down

0 comments on commit 9dd82ad

Please sign in to comment.