Skip to content

Commit

Permalink
Add a groupby_blockwise function for use in Flox
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Apr 24, 2024
1 parent 0f3021a commit 3bc15c3
Show file tree
Hide file tree
Showing 2 changed files with 234 additions and 3 deletions.
144 changes: 142 additions & 2 deletions cubed/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from cubed.array_api.manipulation_functions import broadcast_to, expand_dims
from cubed.backend_array_api import namespace as nxp
from cubed.core.ops import map_blocks, reduction_new
from cubed.core.ops import map_blocks, map_direct, reduction_new
from cubed.utils import array_memory, get_item
from cubed.vendor.dask.array.core import normalize_chunks

if TYPE_CHECKING:
from cubed.array_api.array_object import Array
Expand All @@ -22,7 +24,7 @@ def groupby_reduction(
num_groups=None,
extra_func_kwargs=None,
) -> "Array":
"""A reduction that performs groupby aggregations.
"""A reduction operation that performs groupby aggregations.
Parameters
----------
Expand Down Expand Up @@ -116,3 +118,141 @@ def wrapper(a, by, **kwargs):
combine_sizes={axis: num_groups}, # group axis doesn't have size 1
extra_func_kwargs=dict(dtype=intermediate_dtype, dummy_axis=dummy_axis),
)


def groupby_blockwise(
x: "Array",
by,
func,
axis=None,
dtype=None,
num_groups=None,
extra_func_kwargs=None,
):
"""A blockwise operation that performs groupby aggregations.
Parameters
----------
x: Array
Array being grouped along one axis.
by: nxp.array
Array of non-negative integers to be used as labels with which to group
the values in ``x`` along the reduction axis. Must be a 1D array.
func: callable
Function to apply to each chunk of data. The output of the
function is a chunk with size corresponding to the number of groups in the
input chunk along the reduction axis.
axis: int or sequence of ints, optional
Axis to aggregate along. Only supports a single axis.
dtype: dtype
Data type of output.
num_groups: int
The number of groups in the grouping array ``by``.
extra_func_kwargs: dict, optional
Extra keyword arguments to pass to ``func``.
"""

if by.ndim != 1:
raise ValueError(f"Array `by` must be 1D, but has {by.ndim} dimensions.")

if isinstance(axis, tuple):
if len(axis) != 1:
raise ValueError(
f"Only a single axis is supported for groupby_reduction: {axis}"
)
axis = axis[0]

newchunks, groups_per_chunk = _get_chunks_for_groups(
x.numblocks[axis],
by,
num_groups=num_groups,
)

# calculate the chunking used to read the input array 'x'
read_chunks = tuple(newchunks if i == axis else c for i, c in enumerate(x.chunks))

# 'by' is not a cubed array, but we still read it in chunks
by_read_chunks = (newchunks,)

# find shape and chunks for the output
shape = tuple(num_groups if i == axis else s for i, s in enumerate(x.shape))
chunks = tuple(
groups_per_chunk if i == axis else c for i, c in enumerate(x.chunksize)
)
target_chunks = normalize_chunks(chunks, shape, dtype=dtype)

# memory allocated by reading one chunk from input array
# note that although read_chunks will overlap multiple input chunks, zarr will
# read the chunks in series, reusing the buffer
extra_projected_mem = x.chunkmem

# memory allocated for largest of (variable sized) read_chunks
read_chunksize = tuple(max(c) for c in read_chunks)
extra_projected_mem += array_memory(x.dtype, read_chunksize)

return map_direct(
_process_blockwise_chunk,
x,
shape=shape,
dtype=dtype,
chunks=target_chunks,
extra_projected_mem=extra_projected_mem,
axis=axis,
by=by,
blockwise_func=func,
read_chunks=read_chunks,
by_read_chunks=by_read_chunks,
target_chunks=target_chunks,
groups_per_chunk=groups_per_chunk,
extra_func_kwargs=extra_func_kwargs,
)


def _process_blockwise_chunk(
x,
*arrays,
axis=None,
by=None,
blockwise_func=None,
read_chunks=None,
by_read_chunks=None,
target_chunks=None,
groups_per_chunk=None,
block_id=None,
**kwargs,
):
array = arrays[0].zarray # underlying Zarr array (or virtual array)
idx = block_id
bi = idx[axis]

result = array[get_item(read_chunks, idx)]
by = by[get_item(by_read_chunks, (bi,))]

start_group = bi * groups_per_chunk

return blockwise_func(
result,
by,
axis=axis,
start_group=start_group,
num_groups=target_chunks[axis][bi],
**kwargs,
)


def _get_chunks_for_groups(num_chunks, labels, num_groups):
"""Find new chunking so that there are an equal number of group labels per chunk."""

# find the start indexes of each group
start_indexes = nxp.searchsorted(labels, nxp.arange(num_groups))

# find the number of groups per chunk
groups_per_chunk = max(num_groups // num_chunks, 1)

# each chunk has groups_per_chunk groups in it (except possibly last one)
chunk_boundaries = start_indexes[::groups_per_chunk]

# successive differences give the new chunk sizes (include end index for last chunk)
newchunks = nxp.diff(chunk_boundaries, append=len(labels))

return tuple(newchunks), groups_per_chunk
93 changes: 92 additions & 1 deletion cubed/tests/test_groupby.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import numpy as np
import numpy_groupies as npg
import pytest
from numpy.testing import assert_array_equal

import cubed.array_api as xp
from cubed.backend_array_api import namespace as nxp
from cubed.core.groupby import groupby_reduction
from cubed.core.groupby import (
_get_chunks_for_groups,
groupby_blockwise,
groupby_reduction,
)


def test_groupby_reduction_axis0():
Expand Down Expand Up @@ -59,3 +64,89 @@ def _mean_groupby_combine(a, axis, dummy_axis, dtype, keepdims):

def _mean_groupby_aggregate(a):
return nxp.divide(a["total"], a["n"])


@pytest.mark.parametrize(
"num_chunks, expected_newchunks, expected_groups_per_chunk",
[
[10, (3, 2, 2, 0, 3), 1],
[5, (3, 2, 2, 0, 3), 1],
[4, (3, 2, 2, 0, 3), 1],
[3, (3, 2, 2, 0, 3), 1],
[2, (5, 2, 3), 2],
[2, (5, 2, 3), 2],
[2, (5, 2, 3), 2],
[2, (5, 2, 3), 2],
[2, (5, 2, 3), 2],
[1, (10), 5],
],
)
def test_get_chunks_for_groups(
num_chunks, expected_newchunks, expected_groups_per_chunk
):
# group 3 has no data
labels = nxp.asarray([0, 0, 0, 1, 1, 2, 2, 4, 4, 4])
newchunks, groups_per_chunk = _get_chunks_for_groups(
num_chunks, labels, num_groups=5
)
assert_array_equal(newchunks, expected_newchunks)
assert groups_per_chunk == expected_groups_per_chunk


def test_groupby_blockwise_axis0():
a = xp.ones((10, 3), dtype=nxp.int32, chunks=(6, 2))
b = nxp.asarray([0, 0, 0, 1, 1, 2, 2, 4, 4, 4])
extra_func_kwargs = dict(dtype=nxp.int32)
c = groupby_blockwise(
a,
b,
func=_sum_reduction_func,
axis=0,
dtype=nxp.int64,
num_groups=6,
extra_func_kwargs=extra_func_kwargs,
)
assert_array_equal(
c.compute(),
nxp.asarray(
[
[3, 3, 3],
[2, 2, 2],
[2, 2, 2],
[0, 0, 0], # group 3 has no data
[3, 3, 3],
[0, 0, 0], # final group since we specified num_groups=6
]
),
)


def test_groupby_blockwise_axis1():
a = xp.ones((3, 10), dtype=nxp.int32, chunks=(6, 2))
b = nxp.asarray([0, 0, 0, 1, 1, 2, 2, 4, 4, 4])
extra_func_kwargs = dict(dtype=nxp.int32)
c = groupby_blockwise(
a,
b,
func=_sum_reduction_func,
axis=1,
dtype=nxp.int64,
num_groups=6,
extra_func_kwargs=extra_func_kwargs,
)
assert_array_equal(
c.compute(),
nxp.asarray(
[
[3, 2, 2, 0, 3, 0],
[3, 2, 2, 0, 3, 0],
[3, 2, 2, 0, 3, 0],
]
),
)


def _sum_reduction_func(arr, by, axis, start_group, num_groups, dtype):
# change 'by' so it starts from 0 for each chunk
by = by - start_group
return npg.aggregate(by, arr, func="sum", dtype=dtype, axis=axis, size=num_groups)

0 comments on commit 3bc15c3

Please sign in to comment.