Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a groupby_blockwise function for use in Flox #448

Merged
merged 1 commit into from
Apr 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 142 additions & 2 deletions cubed/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from cubed.array_api.manipulation_functions import broadcast_to, expand_dims
from cubed.backend_array_api import namespace as nxp
from cubed.core.ops import map_blocks, reduction_new
from cubed.core.ops import map_blocks, map_direct, reduction_new
from cubed.utils import array_memory, get_item
from cubed.vendor.dask.array.core import normalize_chunks

if TYPE_CHECKING:
from cubed.array_api.array_object import Array
Expand All @@ -22,7 +24,7 @@ def groupby_reduction(
num_groups=None,
extra_func_kwargs=None,
) -> "Array":
"""A reduction that performs groupby aggregations.
"""A reduction operation that performs groupby aggregations.

Parameters
----------
Expand Down Expand Up @@ -116,3 +118,141 @@ def wrapper(a, by, **kwargs):
combine_sizes={axis: num_groups}, # group axis doesn't have size 1
extra_func_kwargs=dict(dtype=intermediate_dtype, dummy_axis=dummy_axis),
)


def groupby_blockwise(
    x: "Array",
    by,
    func,
    axis=None,
    dtype=None,
    num_groups=None,
    extra_func_kwargs=None,
):
    """A blockwise operation that performs groupby aggregations.

    The input is read with a chunking along ``axis`` that is derived from the
    group labels (see ``_get_chunks_for_groups``), so that every output chunk
    is produced from a single read of the input by one call to ``func``.

    Parameters
    ----------
    x: Array
        Array being grouped along one axis.
    by: nxp.array
        Array of non-negative integers to be used as labels with which to group
        the values in ``x`` along the reduction axis. Must be a 1D array.
        NOTE(review): chunk boundaries are located with ``searchsorted``, so
        the labels are presumably expected to be sorted -- confirm.
    func: callable
        Function to apply to each chunk of data. The output of the
        function is a chunk with size corresponding to the number of groups in the
        input chunk along the reduction axis. It is called with the data chunk
        and the matching slice of ``by`` as positional arguments, plus the
        keyword arguments ``axis``, ``start_group`` and ``num_groups``.
    axis: int or sequence of ints
        Axis to aggregate along. Only supports a single axis, given either as
        an int or as a one-element tuple or list. A negative value counts from
        the last axis of ``x``.
    dtype: dtype
        Data type of output.
    num_groups: int
        The number of groups in the grouping array ``by``.
    extra_func_kwargs: dict, optional
        Extra keyword arguments to pass to ``func``.

    Returns
    -------
    The result of ``map_direct``: an array with the same shape as ``x`` except
    that the grouped axis has length ``num_groups``.

    Raises
    ------
    ValueError
        If ``by`` is not 1D, or if ``axis`` is a sequence that does not hold
        exactly one axis.
    """

    if by.ndim != 1:
        raise ValueError(f"Array `by` must be 1D, but has {by.ndim} dimensions.")

    if isinstance(axis, (tuple, list)):
        if len(axis) != 1:
            raise ValueError(
                f"Only a single axis is supported for groupby_blockwise: {axis}"
            )
        axis = axis[0]

    # The comparisons against enumerate() indexes below only match a
    # non-negative axis, so convert a negative axis up front.
    if axis is not None and axis < 0:
        axis += len(x.shape)

    newchunks, groups_per_chunk = _get_chunks_for_groups(
        x.numblocks[axis],
        by,
        num_groups=num_groups,
    )

    # calculate the chunking used to read the input array 'x'
    read_chunks = tuple(newchunks if i == axis else c for i, c in enumerate(x.chunks))

    # 'by' is not a cubed array, but we still read it in chunks
    by_read_chunks = (newchunks,)

    # find shape and chunks for the output
    shape = tuple(num_groups if i == axis else s for i, s in enumerate(x.shape))
    chunks = tuple(
        groups_per_chunk if i == axis else c for i, c in enumerate(x.chunksize)
    )
    target_chunks = normalize_chunks(chunks, shape, dtype=dtype)

    # memory allocated by reading one chunk from input array
    # note that although read_chunks will overlap multiple input chunks, zarr will
    # read the chunks in series, reusing the buffer
    extra_projected_mem = x.chunkmem

    # memory allocated for largest of (variable sized) read_chunks
    read_chunksize = tuple(max(c) for c in read_chunks)
    extra_projected_mem += array_memory(x.dtype, read_chunksize)

    return map_direct(
        _process_blockwise_chunk,
        x,
        shape=shape,
        dtype=dtype,
        chunks=target_chunks,
        extra_projected_mem=extra_projected_mem,
        axis=axis,
        by=by,
        blockwise_func=func,
        read_chunks=read_chunks,
        by_read_chunks=by_read_chunks,
        target_chunks=target_chunks,
        groups_per_chunk=groups_per_chunk,
        extra_func_kwargs=extra_func_kwargs,
    )


def _process_blockwise_chunk(
    x,
    *arrays,
    axis=None,
    by=None,
    blockwise_func=None,
    read_chunks=None,
    by_read_chunks=None,
    target_chunks=None,
    groups_per_chunk=None,
    block_id=None,
    **kwargs,
):
    """Compute a single output block of a blockwise groupby.

    Reads the region of the input selected by ``read_chunks`` for this
    ``block_id``, slices the labels in ``by`` using ``by_read_chunks`` for the
    block's position along ``axis``, and hands both to ``blockwise_func``.

    ``x`` is not used here; the data is read from ``arrays[0].zarray``.
    Any remaining keyword arguments are forwarded to ``blockwise_func``.
    """
    # underlying Zarr array (or virtual array)
    source = arrays[0].zarray

    # index of this block along the grouped axis
    group_block = block_id[axis]

    # presumably get_item turns a chunking plus a block index into slices -- confirm
    data = source[get_item(read_chunks, block_id)]
    labels = by[get_item(by_read_chunks, (group_block,))]

    return blockwise_func(
        data,
        labels,
        axis=axis,
        # label of the first group that belongs to this output block
        start_group=group_block * groups_per_chunk,
        # size of this output block along the grouped axis
        num_groups=target_chunks[axis][group_block],
        **kwargs,
    )


def _get_chunks_for_groups(num_chunks, labels, num_groups):
    """Find new chunking so that there are an equal number of group labels per chunk.

    Returns a tuple ``(newchunks, groups_per_chunk)`` where ``newchunks`` is a
    tuple of chunk sizes along the labelled axis and ``groups_per_chunk`` is the
    number of groups assigned to each chunk (the last chunk may hold fewer).
    """
    # at least one group per chunk, even when there are more chunks than groups
    per_chunk = max(1, num_groups // num_chunks)

    # position in 'labels' at which each group 0..num_groups-1 begins
    group_ids = nxp.arange(num_groups)
    first_positions = nxp.searchsorted(labels, group_ids)

    # a new chunk begins at every per_chunk-th group
    boundaries = first_positions[::per_chunk]

    # chunk sizes are the gaps between boundaries; the last chunk runs to the end
    sizes = nxp.diff(boundaries, append=len(labels))

    return tuple(sizes), per_chunk
93 changes: 92 additions & 1 deletion cubed/tests/test_groupby.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import numpy as np
import numpy_groupies as npg
import pytest
from numpy.testing import assert_array_equal

import cubed.array_api as xp
from cubed.backend_array_api import namespace as nxp
from cubed.core.groupby import groupby_reduction
from cubed.core.groupby import (
_get_chunks_for_groups,
groupby_blockwise,
groupby_reduction,
)


def test_groupby_reduction_axis0():
Expand Down Expand Up @@ -59,3 +64,89 @@ def _mean_groupby_combine(a, axis, dummy_axis, dtype, keepdims):

def _mean_groupby_aggregate(a):
    """Return the mean by dividing the ``"total"`` field of ``a`` by its ``"n"`` field."""
    total = a["total"]
    count = a["n"]
    return nxp.divide(total, count)


@pytest.mark.parametrize(
    "num_chunks, expected_newchunks, expected_groups_per_chunk",
    [
        # more chunks than groups, or fewer than two groups per chunk:
        # one group per chunk
        [10, (3, 2, 2, 0, 3), 1],
        [5, (3, 2, 2, 0, 3), 1],
        [4, (3, 2, 2, 0, 3), 1],
        [3, (3, 2, 2, 0, 3), 1],
        # 5 // 2 == 2 groups per chunk, so the last chunk holds one group
        [2, (5, 2, 3), 2],
        # a single chunk holds all groups; note the one-element tuple
        [1, (10,), 5],
    ],
)
def test_get_chunks_for_groups(
    num_chunks, expected_newchunks, expected_groups_per_chunk
):
    """Check the chunk sizes and groups-per-chunk found for a fixed label array."""
    # group 3 has no data
    labels = nxp.asarray([0, 0, 0, 1, 1, 2, 2, 4, 4, 4])
    newchunks, groups_per_chunk = _get_chunks_for_groups(
        num_chunks, labels, num_groups=5
    )
    assert_array_equal(newchunks, expected_newchunks)
    assert groups_per_chunk == expected_groups_per_chunk


def test_groupby_blockwise_axis0():
    """Sum a (10, 3) array of ones within groups along axis 0."""
    data = xp.ones((10, 3), dtype=nxp.int32, chunks=(6, 2))
    labels = nxp.asarray([0, 0, 0, 1, 1, 2, 2, 4, 4, 4])

    grouped = groupby_blockwise(
        data,
        labels,
        func=_sum_reduction_func,
        axis=0,
        dtype=nxp.int64,
        num_groups=6,
        extra_func_kwargs={"dtype": nxp.int32},
    )
    actual = grouped.compute()

    # one row per group; each value is the number of labels in that group
    expected = nxp.asarray(
        [
            [3, 3, 3],
            [2, 2, 2],
            [2, 2, 2],
            [0, 0, 0],  # group 3 has no data
            [3, 3, 3],
            [0, 0, 0],  # final group since we specified num_groups=6
        ]
    )
    assert_array_equal(actual, expected)


def test_groupby_blockwise_axis1():
    """Sum a (3, 10) array of ones within groups along axis 1."""
    data = xp.ones((3, 10), dtype=nxp.int32, chunks=(6, 2))
    labels = nxp.asarray([0, 0, 0, 1, 1, 2, 2, 4, 4, 4])

    grouped = groupby_blockwise(
        data,
        labels,
        func=_sum_reduction_func,
        axis=1,
        dtype=nxp.int64,
        num_groups=6,
        extra_func_kwargs={"dtype": nxp.int32},
    )
    actual = grouped.compute()

    # one column per group: group 3 has no data and group 5 exists only
    # because num_groups=6 was requested
    expected = nxp.asarray(
        [
            [3, 2, 2, 0, 3, 0],
            [3, 2, 2, 0, 3, 0],
            [3, 2, 2, 0, 3, 0],
        ]
    )
    assert_array_equal(actual, expected)


def _sum_reduction_func(arr, by, axis, start_group, num_groups, dtype):
    """Sum ``arr`` along ``axis`` within the groups given by ``by``.

    The labels are shifted by ``start_group`` so that they start from 0 for
    each chunk, and the result has ``num_groups`` entries along ``axis``.
    """
    local_labels = by - start_group
    return npg.aggregate(
        local_labels,
        arr,
        func="sum",
        dtype=dtype,
        axis=axis,
        size=num_groups,
    )
Loading