diff --git a/cubed/core/groupby.py b/cubed/core/groupby.py index 1fbb72d1..84514fdb 100644 --- a/cubed/core/groupby.py +++ b/cubed/core/groupby.py @@ -71,7 +71,7 @@ def wrapper(a, by, **kwargs): # then reduce across blocks return reduction_new( out, - func=_identity_func, + func=None, combine_func=combine_func, aggregate_func=aggregate_func, axis=(dummy_axis, axis), # dummy and group axis @@ -82,8 +82,3 @@ def wrapper(a, by, **kwargs): combine_sizes={axis: num_groups}, # group axis doesn't have size 1 extra_func_kwargs=dict(dtype=intermediate_dtype, dummy_axis=dummy_axis), ) - - -def _identity_func(a, **kwargs): - # pass through - return a diff --git a/cubed/core/ops.py b/cubed/core/ops.py index 48a65608..e6e2317f 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -1022,6 +1022,10 @@ def reduction_new( ) -> "Array": """Apply a function to reduce an array along one or more axes.""" if combine_func is None: + if func is None: + raise ValueError( + "At least one of `func` and `combine_func` must be specified in reduction" + ) combine_func = func if axis is None: axis = tuple(range(x.ndim)) @@ -1033,12 +1037,16 @@ def reduction_new( split_every = _normalize_split_every(split_every, axis) + if func is None: + initial_func = None + else: + initial_func = partial( + func, axis=axis, keepdims=True, **(extra_func_kwargs or {}) + ) result = partial_reduce( x, partial(combine_func, **(extra_func_kwargs or {})), - initial_func=partial( - func, axis=axis, keepdims=True, **(extra_func_kwargs or {}) - ), + initial_func=initial_func, split_every=split_every, dtype=intermediate_dtype, combine_sizes=combine_sizes,