Skip to content

Commit

Permalink
Always reduce initial chunks in reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Sep 13, 2023
1 parent be97926 commit f3ac181
Showing 1 changed file with 14 additions and 18 deletions.
32 changes: 14 additions & 18 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,24 +729,20 @@ def reduction(
allowed_mem = x.spec.allowed_mem
max_mem = allowed_mem - x.spec.reserved_mem

# reduce initial chunks (if any axis chunksize is > 1)
if (
any(s > 1 for i, s in enumerate(result.chunksize) if i in axis)
or func != combine_func
):
args = (result, inds)
adjust_chunks = {
i: (1,) * len(c) if i in axis else c for i, c in enumerate(result.chunks)
}
result = blockwise(
func,
inds,
*args,
axis=axis,
keepdims=True,
dtype=intermediate_dtype,
adjust_chunks=adjust_chunks,
)
# reduce initial chunks
args = (result, inds)
adjust_chunks = {
i: (1,) * len(c) if i in axis else c for i, c in enumerate(result.chunks)
}
result = blockwise(
func,
inds,
*args,
axis=axis,
keepdims=True,
dtype=intermediate_dtype,
adjust_chunks=adjust_chunks,
)

# merge/reduce along axis in multiple rounds until there's a single block in each reduction axis
while any(n > 1 for i, n in enumerate(result.numblocks) if i in axis):
Expand Down

0 comments on commit f3ac181

Please sign in to comment.