Commit 3bc15c3

Add a groupby_blockwise function for use in Flox

1 parent 0f3021a

2 files changed (+234 -3 lines)

2 files changed

+234
-3
lines changed

cubed/core/groupby.py (+142 -2)
@@ -2,7 +2,9 @@

 from cubed.array_api.manipulation_functions import broadcast_to, expand_dims
 from cubed.backend_array_api import namespace as nxp
-from cubed.core.ops import map_blocks, reduction_new
+from cubed.core.ops import map_blocks, map_direct, reduction_new
+from cubed.utils import array_memory, get_item
+from cubed.vendor.dask.array.core import normalize_chunks

 if TYPE_CHECKING:
     from cubed.array_api.array_object import Array
@@ -22,7 +24,7 @@ def groupby_reduction(
     num_groups=None,
     extra_func_kwargs=None,
 ) -> "Array":
-    """A reduction that performs groupby aggregations.
+    """A reduction operation that performs groupby aggregations.

     Parameters
     ----------
@@ -116,3 +118,141 @@ def wrapper(a, by, **kwargs):
         combine_sizes={axis: num_groups},  # group axis doesn't have size 1
         extra_func_kwargs=dict(dtype=intermediate_dtype, dummy_axis=dummy_axis),
     )
+
+
+def groupby_blockwise(
+    x: "Array",
+    by,
+    func,
+    axis=None,
+    dtype=None,
+    num_groups=None,
+    extra_func_kwargs=None,
+):
+    """A blockwise operation that performs groupby aggregations.
+
+    Parameters
+    ----------
+    x: Array
+        Array being grouped along one axis.
+    by: nxp.array
+        Array of non-negative integers to be used as labels with which to group
+        the values in ``x`` along the reduction axis. Must be a 1D array.
+    func: callable
+        Function to apply to each chunk of data. The output of the function is
+        a chunk with size corresponding to the number of groups in the input
+        chunk along the reduction axis.
+    axis: int or sequence of ints, optional
+        Axis to aggregate along. Only supports a single axis.
+    dtype: dtype
+        Data type of output.
+    num_groups: int
+        The number of groups in the grouping array ``by``.
+    extra_func_kwargs: dict, optional
+        Extra keyword arguments to pass to ``func``.
+    """
+
+    if by.ndim != 1:
+        raise ValueError(f"Array `by` must be 1D, but has {by.ndim} dimensions.")
+
+    if isinstance(axis, tuple):
+        if len(axis) != 1:
+            raise ValueError(
+                f"Only a single axis is supported for groupby_blockwise: {axis}"
+            )
+        axis = axis[0]
+
+    newchunks, groups_per_chunk = _get_chunks_for_groups(
+        x.numblocks[axis],
+        by,
+        num_groups=num_groups,
+    )
+
+    # calculate the chunking used to read the input array 'x'
+    read_chunks = tuple(newchunks if i == axis else c for i, c in enumerate(x.chunks))
+
+    # 'by' is not a cubed array, but we still read it in chunks
+    by_read_chunks = (newchunks,)
+
+    # find shape and chunks for the output
+    shape = tuple(num_groups if i == axis else s for i, s in enumerate(x.shape))
+    chunks = tuple(
+        groups_per_chunk if i == axis else c for i, c in enumerate(x.chunksize)
+    )
+    target_chunks = normalize_chunks(chunks, shape, dtype=dtype)
+
+    # memory allocated by reading one chunk from the input array
+    # note that although read_chunks will overlap multiple input chunks, zarr will
+    # read the chunks in series, reusing the buffer
+    extra_projected_mem = x.chunkmem
+
+    # memory allocated for the largest of the (variable-sized) read_chunks
+    read_chunksize = tuple(max(c) for c in read_chunks)
+    extra_projected_mem += array_memory(x.dtype, read_chunksize)
+
+    return map_direct(
+        _process_blockwise_chunk,
+        x,
+        shape=shape,
+        dtype=dtype,
+        chunks=target_chunks,
+        extra_projected_mem=extra_projected_mem,
+        axis=axis,
+        by=by,
+        blockwise_func=func,
+        read_chunks=read_chunks,
+        by_read_chunks=by_read_chunks,
+        target_chunks=target_chunks,
+        groups_per_chunk=groups_per_chunk,
+        extra_func_kwargs=extra_func_kwargs,
+    )
+
+
+def _process_blockwise_chunk(
+    x,
+    *arrays,
+    axis=None,
+    by=None,
+    blockwise_func=None,
+    read_chunks=None,
+    by_read_chunks=None,
+    target_chunks=None,
+    groups_per_chunk=None,
+    block_id=None,
+    **kwargs,
+):
+    array = arrays[0].zarray  # underlying Zarr array (or virtual array)
+    idx = block_id
+    bi = idx[axis]
+
+    result = array[get_item(read_chunks, idx)]
+    by = by[get_item(by_read_chunks, (bi,))]
+
+    start_group = bi * groups_per_chunk
+
+    return blockwise_func(
+        result,
+        by,
+        axis=axis,
+        start_group=start_group,
+        num_groups=target_chunks[axis][bi],
+        **kwargs,
+    )
+
+
+def _get_chunks_for_groups(num_chunks, labels, num_groups):
+    """Find a new chunking so that there is an equal number of group labels per chunk."""
+
+    # find the start indexes of each group
+    start_indexes = nxp.searchsorted(labels, nxp.arange(num_groups))
+
+    # find the number of groups per chunk
+    groups_per_chunk = max(num_groups // num_chunks, 1)
+
+    # each chunk has groups_per_chunk groups in it (except possibly the last one)
+    chunk_boundaries = start_indexes[::groups_per_chunk]
+
+    # successive differences give the new chunk sizes (include the end index for the last chunk)
+    newchunks = nxp.diff(chunk_boundaries, append=len(labels))
+
+    return tuple(newchunks), groups_per_chunk
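
As a worked example (not part of the commit), here is the _get_chunks_for_groups calculation replayed in plain NumPy rather than the nxp namespace, using the same labels as the new tests below with num_groups=5 and num_chunks=2:

import numpy as np

labels = np.asarray([0, 0, 0, 1, 1, 2, 2, 4, 4, 4])  # group 3 has no data
num_groups, num_chunks = 5, 2

# start index of each group (group 3 "starts" where group 4 does)
start_indexes = np.searchsorted(labels, np.arange(num_groups))  # [0 3 5 7 7]

groups_per_chunk = max(num_groups // num_chunks, 1)             # 2

# every groups_per_chunk-th group start becomes a chunk boundary
chunk_boundaries = start_indexes[::groups_per_chunk]            # [0 5 7]

# successive differences give the new chunk sizes
newchunks = np.diff(chunk_boundaries, append=len(labels))       # [5 2 3]
print(tuple(newchunks), groups_per_chunk)                       # (5, 2, 3) 2

This matches the [2, (5, 2, 3), 2] case in test_get_chunks_for_groups below.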

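The extra_projected_mem bookkeeping can also be checked by hand. The sketch below uses the shapes from the axis-0 test that follows and assumes array_memory(dtype, shape) amounts to itemsize times the number of elements, which matches how it is used in the diff:

import numpy as np

# x from the axis-0 test below: shape (10, 3), dtype int32, chunks (6, 2)
chunkmem = np.dtype(np.int32).itemsize * 6 * 2          # x.chunkmem = 48 bytes

# _get_chunks_for_groups gives newchunks (7, 3) for the test's labels with
# num_groups=6 and two input blocks along axis 0, so:
read_chunks = ((7, 3), (2, 1))
read_chunksize = tuple(max(c) for c in read_chunks)     # (7, 2)
extra = np.dtype(np.int32).itemsize * int(np.prod(read_chunksize))  # 56 bytes

print(chunkmem + extra)  # 104 bytes of extra projected memory per task
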
cubed/tests/test_groupby.py (+92 -1)
@@ -1,10 +1,15 @@
 import numpy as np
 import numpy_groupies as npg
+import pytest
 from numpy.testing import assert_array_equal

 import cubed.array_api as xp
 from cubed.backend_array_api import namespace as nxp
-from cubed.core.groupby import groupby_reduction
+from cubed.core.groupby import (
+    _get_chunks_for_groups,
+    groupby_blockwise,
+    groupby_reduction,
+)


 def test_groupby_reduction_axis0():
@@ -59,3 +64,89 @@ def _mean_groupby_combine(a, axis, dummy_axis, dtype, keepdims):

 def _mean_groupby_aggregate(a):
     return nxp.divide(a["total"], a["n"])
+
+
+@pytest.mark.parametrize(
+    "num_chunks, expected_newchunks, expected_groups_per_chunk",
+    [
+        [10, (3, 2, 2, 0, 3), 1],
+        [5, (3, 2, 2, 0, 3), 1],
+        [4, (3, 2, 2, 0, 3), 1],
+        [3, (3, 2, 2, 0, 3), 1],
+        [2, (5, 2, 3), 2],
+        [1, (10,), 5],
+    ],
+)
+def test_get_chunks_for_groups(
+    num_chunks, expected_newchunks, expected_groups_per_chunk
+):
+    # group 3 has no data
+    labels = nxp.asarray([0, 0, 0, 1, 1, 2, 2, 4, 4, 4])
+    newchunks, groups_per_chunk = _get_chunks_for_groups(
+        num_chunks, labels, num_groups=5
+    )
+    assert_array_equal(newchunks, expected_newchunks)
+    assert groups_per_chunk == expected_groups_per_chunk
+
+
+def test_groupby_blockwise_axis0():
+    a = xp.ones((10, 3), dtype=nxp.int32, chunks=(6, 2))
+    b = nxp.asarray([0, 0, 0, 1, 1, 2, 2, 4, 4, 4])
+    extra_func_kwargs = dict(dtype=nxp.int32)
+    c = groupby_blockwise(
+        a,
+        b,
+        func=_sum_reduction_func,
+        axis=0,
+        dtype=nxp.int64,
+        num_groups=6,
+        extra_func_kwargs=extra_func_kwargs,
+    )
+    assert_array_equal(
+        c.compute(),
+        nxp.asarray(
+            [
+                [3, 3, 3],
+                [2, 2, 2],
+                [2, 2, 2],
+                [0, 0, 0],  # group 3 has no data
+                [3, 3, 3],
+                [0, 0, 0],  # final group since we specified num_groups=6
+            ]
+        ),
+    )
+
+
+def test_groupby_blockwise_axis1():
+    a = xp.ones((3, 10), dtype=nxp.int32, chunks=(6, 2))
+    b = nxp.asarray([0, 0, 0, 1, 1, 2, 2, 4, 4, 4])
+    extra_func_kwargs = dict(dtype=nxp.int32)
+    c = groupby_blockwise(
+        a,
+        b,
+        func=_sum_reduction_func,
+        axis=1,
+        dtype=nxp.int64,
+        num_groups=6,
+        extra_func_kwargs=extra_func_kwargs,
+    )
+    assert_array_equal(
+        c.compute(),
+        nxp.asarray(
+            [
+                [3, 2, 2, 0, 3, 0],
+                [3, 2, 2, 0, 3, 0],
+                [3, 2, 2, 0, 3, 0],
+            ]
+        ),
+    )
+
+
+def _sum_reduction_func(arr, by, axis, start_group, num_groups, dtype):
+    # change 'by' so it starts from 0 for each chunk
+    by = by - start_group
+    return npg.aggregate(by, arr, func="sum", dtype=dtype, axis=axis, size=num_groups)

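_sum_reduction_func delegates to numpy_groupies. As a standalone sketch of the per-chunk behaviour (not part of the commit), note that the size argument forces a fixed number of output groups and pads groups with no data using the default fill value of 0, which is what produces the all-zero rows asserted above:

import numpy as np
import numpy_groupies as npg

arr = np.ones((5, 3), dtype=np.int32)   # one read chunk of data
by = np.asarray([2, 2, 2, 3, 3])        # global labels for this chunk
start_group = 2                         # first group handled by this chunk
by = by - start_group                   # relabel so the chunk starts at group 0

# size=3 forces three output groups; the empty third group is filled with zeros
out = npg.aggregate(by, arr, func="sum", dtype=np.int32, axis=0, size=3)
print(out)  # [[3 3 3]
            #  [2 2 2]
            #  [0 0 0]]
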