Skip to content

Commit

Permalink
Make func optional in reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Apr 12, 2024
1 parent 0574e2e commit 0041e34
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
7 changes: 1 addition & 6 deletions cubed/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def wrapper(a, by, **kwargs):
# then reduce across blocks
return reduction_new(
out,
func=_identity_func,
func=None,
combine_func=combine_func,
aggregate_func=aggregate_func,
axis=(dummy_axis, axis), # dummy and group axis
Expand All @@ -82,8 +82,3 @@ def wrapper(a, by, **kwargs):
combine_sizes={axis: num_groups}, # group axis doesn't have size 1
extra_func_kwargs=dict(dtype=intermediate_dtype, dummy_axis=dummy_axis),
)


def _identity_func(a, **kwargs):
# pass through
return a
14 changes: 11 additions & 3 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,10 @@ def reduction_new(
) -> "Array":
"""Apply a function to reduce an array along one or more axes."""
if combine_func is None:
if func is None:
raise ValueError(
"At least one of `func` and `combine_func` must be specified in reduction"
)
combine_func = func
if axis is None:
axis = tuple(range(x.ndim))
Expand All @@ -1033,12 +1037,16 @@ def reduction_new(

split_every = _normalize_split_every(split_every, axis)

if func is None:
initial_func = None
else:
initial_func = partial(
func, axis=axis, keepdims=True, **(extra_func_kwargs or {})
)
result = partial_reduce(
x,
partial(combine_func, **(extra_func_kwargs or {})),
initial_func=partial(
func, axis=axis, keepdims=True, **(extra_func_kwargs or {})
),
initial_func=initial_func,
split_every=split_every,
dtype=intermediate_dtype,
combine_sizes=combine_sizes,
Expand Down

0 comments on commit 0041e34

Please sign in to comment.