from cubed.array_api.manipulation_functions import broadcast_to, expand_dims
from cubed.backend_array_api import namespace as nxp
- from cubed.core.ops import map_blocks, reduction_new
+ from cubed.core.ops import map_blocks, map_direct, reduction_new
+ from cubed.utils import array_memory, get_item
+ from cubed.vendor.dask.array.core import normalize_chunks

if TYPE_CHECKING:
    from cubed.array_api.array_object import Array
@@ -22,7 +24,7 @@ def groupby_reduction(
    num_groups=None,
    extra_func_kwargs=None,
) -> "Array":
-     """A reduction that performs groupby aggregations.
+     """A reduction operation that performs groupby aggregations.

    Parameters
    ----------
@@ -116,3 +118,141 @@ def wrapper(a, by, **kwargs):
        combine_sizes={axis: num_groups},  # group axis doesn't have size 1
        extra_func_kwargs=dict(dtype=intermediate_dtype, dummy_axis=dummy_axis),
    )
+
+
+ def groupby_blockwise(
+     x: "Array",
+     by,
+     func,
+     axis=None,
+     dtype=None,
+     num_groups=None,
+     extra_func_kwargs=None,
+ ):
+     """A blockwise operation that performs groupby aggregations.
+
+     Parameters
+     ----------
+     x: Array
+         Array being grouped along one axis.
+     by: nxp.array
+         Array of non-negative integers to be used as labels with which to group
+         the values in ``x`` along the reduction axis. Must be a 1D array.
+     func: callable
+         Function to apply to each chunk of data. The output of the
+         function is a chunk with size corresponding to the number of groups in the
+         input chunk along the reduction axis.
+     axis: int or sequence of ints, optional
+         Axis to aggregate along. Only supports a single axis.
+     dtype: dtype
+         Data type of output.
+     num_groups: int
+         The number of groups in the grouping array ``by``.
+     extra_func_kwargs: dict, optional
+         Extra keyword arguments to pass to ``func``.
+     """
+
+     if by.ndim != 1:
+         raise ValueError(f"Array `by` must be 1D, but has {by.ndim} dimensions.")
+
+     if isinstance(axis, tuple):
+         if len(axis) != 1:
+             raise ValueError(
+                 f"Only a single axis is supported for groupby_reduction: {axis}"
+             )
+         axis = axis[0]
+
+     newchunks, groups_per_chunk = _get_chunks_for_groups(
+         x.numblocks[axis],
+         by,
+         num_groups=num_groups,
+     )
+
+     # calculate the chunking used to read the input array 'x'
+     read_chunks = tuple(newchunks if i == axis else c for i, c in enumerate(x.chunks))
+
+     # 'by' is not a cubed array, but we still read it in chunks
+     by_read_chunks = (newchunks,)
+
+     # find shape and chunks for the output
+     shape = tuple(num_groups if i == axis else s for i, s in enumerate(x.shape))
+     chunks = tuple(
+         groups_per_chunk if i == axis else c for i, c in enumerate(x.chunksize)
+     )
+     target_chunks = normalize_chunks(chunks, shape, dtype=dtype)
+
+     # memory allocated by reading one chunk from input array
+     # note that although read_chunks will overlap multiple input chunks, zarr will
+     # read the chunks in series, reusing the buffer
+     extra_projected_mem = x.chunkmem
+
+     # memory allocated for largest of (variable sized) read_chunks
+     read_chunksize = tuple(max(c) for c in read_chunks)
+     extra_projected_mem += array_memory(x.dtype, read_chunksize)
+
+     return map_direct(
+         _process_blockwise_chunk,
+         x,
+         shape=shape,
+         dtype=dtype,
+         chunks=target_chunks,
+         extra_projected_mem=extra_projected_mem,
+         axis=axis,
+         by=by,
+         blockwise_func=func,
+         read_chunks=read_chunks,
+         by_read_chunks=by_read_chunks,
+         target_chunks=target_chunks,
+         groups_per_chunk=groups_per_chunk,
+         extra_func_kwargs=extra_func_kwargs,
+     )
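
As a usage illustration (not part of this commit), `groupby_blockwise` could be driven by a per-group aggregation function following the calling convention visible in `_process_blockwise_chunk` below: the function receives a chunk of data, the matching slice of `by`, plus `axis`, `start_group` and `num_groups` keywords. The helper `_sum_by_group`, the array values and the chunking here are assumptions chosen only to show the shapes involved, not the library's own aggregation code.

import numpy as np

import cubed.array_api as xp


def _sum_by_group(a, by, axis=None, start_group=None, num_groups=None):
    # Hypothetical per-group sum: accumulate the values of `a` into
    # `num_groups` buckets along `axis`, using the labels in `by`;
    # `start_group` is the first group label handled by this block.
    out_shape = tuple(num_groups if i == axis else s for i, s in enumerate(a.shape))
    out = np.zeros(out_shape, dtype=a.dtype)
    for group in range(num_groups):
        src = [slice(None)] * a.ndim
        src[axis] = by == (start_group + group)
        dst = [slice(None)] * a.ndim
        dst[axis] = group
        out[tuple(dst)] = a[tuple(src)].sum(axis=axis)
    return out


x = xp.asarray(np.arange(12, dtype=np.float64).reshape(3, 4), chunks=(1, 4))
by = np.array([0, 0, 1, 2])  # one sorted, non-negative label per column
grouped = groupby_blockwise(
    x, by, func=_sum_by_group, axis=1, dtype=np.float64, num_groups=3
)
# grouped has shape (3, 3): one column of per-group sums for each row of x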
+
+
+ def _process_blockwise_chunk(
+     x,
+     *arrays,
+     axis=None,
+     by=None,
+     blockwise_func=None,
+     read_chunks=None,
+     by_read_chunks=None,
+     target_chunks=None,
+     groups_per_chunk=None,
+     block_id=None,
+     **kwargs,
+ ):
+     array = arrays[0].zarray  # underlying Zarr array (or virtual array)
+     idx = block_id
+     bi = idx[axis]
+
+     result = array[get_item(read_chunks, idx)]
+     by = by[get_item(by_read_chunks, (bi,))]
+
+     start_group = bi * groups_per_chunk
+
+     return blockwise_func(
+         result,
+         by,
+         axis=axis,
+         start_group=start_group,
+         num_groups=target_chunks[axis][bi],
+         **kwargs,
+     )
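
The block-to-selection translation above relies on `get_item` from `cubed.utils`. A rough, hypothetical re-implementation of that behaviour (inferred from how it is used here, not the library's actual code) looks like the following sketch:

import itertools


def _get_item_sketch(chunks, idx):
    # Cumulative chunk sizes give the start offset of every block along each axis.
    starts = [tuple(itertools.accumulate((0,) + c)) for c in chunks]
    # One slice per axis, selecting the idx-th block.
    return tuple(slice(s[i], s[i + 1]) for s, i in zip(starts, idx))


# e.g. with read_chunks = ((5, 5), (4,)), block index (1, 0) reads rows 5:10
# and all four columns:
assert _get_item_sketch(((5, 5), (4,)), (1, 0)) == (slice(5, 10), slice(0, 4))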
+
+
+ def _get_chunks_for_groups(num_chunks, labels, num_groups):
+     """Find new chunking so that there are an equal number of group labels per chunk."""
+
+     # find the start indexes of each group
+     start_indexes = nxp.searchsorted(labels, nxp.arange(num_groups))
+
+     # find the number of groups per chunk
+     groups_per_chunk = max(num_groups // num_chunks, 1)
+
+     # each chunk has groups_per_chunk groups in it (except possibly last one)
+     chunk_boundaries = start_indexes[::groups_per_chunk]
+
+     # successive differences give the new chunk sizes (include end index for last chunk)
+     newchunks = nxp.diff(chunk_boundaries, append=len(labels))
+
+     return tuple(newchunks), groups_per_chunk
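
A small worked example of this chunking arithmetic (illustrative only, using numpy directly in place of the `nxp` backend namespace, and assuming the labels are already sorted as `searchsorted` requires):

import numpy as np

labels = np.array([0, 0, 0, 1, 1, 2, 2, 2, 3, 3])  # 10 values, 4 sorted groups
start_indexes = np.searchsorted(labels, np.arange(4))      # [0, 3, 5, 8]
groups_per_chunk = max(4 // 2, 1)                          # 2 groups per chunk
chunk_boundaries = start_indexes[::groups_per_chunk]       # [0, 5]
newchunks = np.diff(chunk_boundaries, append=len(labels))  # [5, 5]

# so for num_chunks=2: _get_chunks_for_groups(2, labels, 4) == ((5, 5), 2),
# i.e. each new chunk holds the values for exactly two whole groups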