 
 from cubed.array_api.manipulation_functions import broadcast_to, expand_dims
 from cubed.backend_array_api import namespace as nxp
-from cubed.core.ops import map_blocks, reduction_new
+from cubed.core.ops import map_blocks, map_direct, reduction_new
+from cubed.utils import array_memory, get_item
+from cubed.vendor.dask.array.core import normalize_chunks
 
 if TYPE_CHECKING:
     from cubed.array_api.array_object import Array
@@ -116,3 +118,141 @@ def wrapper(a, by, **kwargs):
         combine_sizes={axis: num_groups},  # group axis doesn't have size 1
         extra_func_kwargs=dict(dtype=intermediate_dtype, dummy_axis=dummy_axis),
     )
+
+
+def groupby_blockwise(
+    x: "Array",
+    by,
+    func,
+    axis=None,
+    dtype=None,
+    num_groups=None,
+    extra_func_kwargs=None,
+):
+    """A blockwise operation that performs groupby aggregations.
+
+    Parameters
+    ----------
+    x: Array
+        Array being grouped along one axis.
+    by: nxp.array
+        Array of non-negative integers to be used as labels with which to group
+        the values in ``x`` along the reduction axis. Must be a 1D array.
+    func: callable
+        Function to apply to each chunk of data. The output of the function is
+        a chunk whose size along the reduction axis corresponds to the number
+        of groups in the input chunk.
+    axis: int or sequence of ints, optional
+        Axis to aggregate along. Only a single axis is supported.
+    dtype: dtype
+        Data type of output.
+    num_groups: int
+        The number of groups in the grouping array ``by``.
+    extra_func_kwargs: dict, optional
+        Extra keyword arguments to pass to ``func``.
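+
+    Examples
+    --------
+    A hypothetical sketch (``x``, ``by`` and ``_sum_chunk`` are illustrative
+    names, not part of this change); ``_sum_chunk`` would be called as
+    ``_sum_chunk(a, by, axis=..., start_group=..., num_groups=...)``::
+
+        out = groupby_blockwise(
+            x, by, func=_sum_chunk, axis=0, dtype=x.dtype, num_groups=4
+        )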
+    """
+
+    if by.ndim != 1:
+        raise ValueError(f"Array `by` must be 1D, but has {by.ndim} dimensions.")
+
+    if isinstance(axis, tuple):
+        if len(axis) != 1:
+            raise ValueError(
+                f"Only a single axis is supported for groupby_blockwise: {axis}"
+            )
+        axis = axis[0]
+
+    newchunks, groups_per_chunk = _get_chunks_for_groups(
+        x.numblocks[axis],
+        by,
+        num_groups=num_groups,
+    )
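+    # (illustrative: with 2 input blocks along the axis and num_groups == 4,
+    # this might give newchunks == (5, 3) and groups_per_chunk == 2)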
+
+    # calculate the chunking used to read the input array 'x'
+    read_chunks = tuple(newchunks if i == axis else c for i, c in enumerate(x.chunks))
+
+    # 'by' is not a cubed array, but we still read it in chunks
+    by_read_chunks = (newchunks,)
+
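+    # (illustrative: if x.chunks == ((4, 4), (10,)) and axis == 0 with
+    # newchunks == (5, 3), then read_chunks == ((5, 3), (10,)) and
+    # by_read_chunks == ((5, 3),))
+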
+    # find shape and chunks for the output
+    shape = tuple(num_groups if i == axis else s for i, s in enumerate(x.shape))
+    chunks = tuple(
+        groups_per_chunk if i == axis else c for i, c in enumerate(x.chunksize)
+    )
+    target_chunks = normalize_chunks(chunks, shape, dtype=dtype)
+
+    # memory allocated by reading one chunk from input array
+    # note that although read_chunks will overlap multiple input chunks, zarr will
+    # read the chunks in series, reusing the buffer
+    extra_projected_mem = x.chunkmem
+
+    # memory allocated for largest of (variable sized) read_chunks
+    read_chunksize = tuple(max(c) for c in read_chunks)
+    extra_projected_mem += array_memory(x.dtype, read_chunksize)
+
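+    # (illustrative arithmetic: for a float64 array with 64 MB regular chunks
+    # and a largest read chunk of 96 MB, extra_projected_mem == 64 + 96 MB)
+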
+    return map_direct(
+        _process_blockwise_chunk,
+        x,
+        shape=shape,
+        dtype=dtype,
+        chunks=target_chunks,
+        extra_projected_mem=extra_projected_mem,
+        axis=axis,
+        by=by,
+        blockwise_func=func,
+        read_chunks=read_chunks,
+        by_read_chunks=by_read_chunks,
+        target_chunks=target_chunks,
+        groups_per_chunk=groups_per_chunk,
+        extra_func_kwargs=extra_func_kwargs,
+    )
+
+
+def _process_blockwise_chunk(
+    x,
+    *arrays,
+    axis=None,
+    by=None,
+    blockwise_func=None,
+    read_chunks=None,
+    by_read_chunks=None,
+    target_chunks=None,
+    groups_per_chunk=None,
+    block_id=None,
+    **kwargs,
+):
+    array = arrays[0].zarray  # underlying Zarr array (or virtual array)
+    idx = block_id
+    bi = idx[axis]
+
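+    # get_item converts a block index into a tuple of slices, so these reads
+    # fetch the (variable sized) data chunk and its matching slice of labels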
+    result = array[get_item(read_chunks, idx)]
+    by = by[get_item(by_read_chunks, (bi,))]
+
+    start_group = bi * groups_per_chunk
+
+    return blockwise_func(
+        result,
+        by,
+        axis=axis,
+        start_group=start_group,
+        num_groups=target_chunks[axis][bi],
+        **kwargs,
+    )
+
+
+def _get_chunks_for_groups(num_chunks, labels, num_groups):
+    """Find new chunking so that there are an equal number of group labels per chunk."""
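+    # worked example (illustrative): num_chunks=2, num_groups=4, and
+    # labels=[0, 0, 1, 1, 1, 2, 2, 3] gives
+    #   start_indexes    = [0, 2, 5, 7]
+    #   groups_per_chunk = 2
+    #   chunk_boundaries = [0, 5]
+    #   newchunks        = [5, 3]
+    # i.e. chunks of (5, 3) with 2 groups in each chunk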
+
+    # find the start indexes of each group
+    start_indexes = nxp.searchsorted(labels, nxp.arange(num_groups))
+
+    # find the number of groups per chunk
+    groups_per_chunk = max(num_groups // num_chunks, 1)
+
+    # each chunk has groups_per_chunk groups in it (except possibly last one)
+    chunk_boundaries = start_indexes[::groups_per_chunk]
+
+    # successive differences give the new chunk sizes (include end index for last chunk)
+    newchunks = nxp.diff(chunk_boundaries, append=len(labels))
+
+    return tuple(newchunks), groups_per_chunk