From f3ac181ffa3891b29bc0a167b665207807b4732d Mon Sep 17 00:00:00 2001 From: Tom White Date: Tue, 12 Sep 2023 16:18:00 +0100 Subject: [PATCH] Always reduce initial chunks in `reduction` --- cubed/core/ops.py | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/cubed/core/ops.py b/cubed/core/ops.py index 8a81fee70..97bc26431 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -729,24 +729,20 @@ def reduction( allowed_mem = x.spec.allowed_mem max_mem = allowed_mem - x.spec.reserved_mem - # reduce initial chunks (if any axis chunksize is > 1) - if ( - any(s > 1 for i, s in enumerate(result.chunksize) if i in axis) - or func != combine_func - ): - args = (result, inds) - adjust_chunks = { - i: (1,) * len(c) if i in axis else c for i, c in enumerate(result.chunks) - } - result = blockwise( - func, - inds, - *args, - axis=axis, - keepdims=True, - dtype=intermediate_dtype, - adjust_chunks=adjust_chunks, - ) + # reduce initial chunks + args = (result, inds) + adjust_chunks = { + i: (1,) * len(c) if i in axis else c for i, c in enumerate(result.chunks) + } + result = blockwise( + func, + inds, + *args, + axis=axis, + keepdims=True, + dtype=intermediate_dtype, + adjust_chunks=adjust_chunks, + ) # merge/reduce along axis in multiple rounds until there's a single block in each reduction axis while any(n > 1 for i, n in enumerate(result.numblocks) if i in axis):