Skip to content

Commit

Permalink
Fix groupby tests running on tensorstore (#608)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored Nov 6, 2024
1 parent 8f4d2f7 commit 10001f8
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
1 change: 1 addition & 0 deletions .github/workflows/tensorstore-tests.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
name: TensorStore tests

on:
pull_request:
schedule:
# Every weekday at 03:58 UTC, see https://crontab.guru/
- cron: "58 3 * * 1-5"
Expand Down
5 changes: 2 additions & 3 deletions cubed/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def groupby_blockwise(
axis=None,
dtype=None,
num_groups=None,
extra_func_kwargs=None,
**kwargs,
):
"""A blockwise operation that performs groupby aggregations.
Expand All @@ -148,8 +148,6 @@ def groupby_blockwise(
Data type of output.
num_groups: int
The number of groups in the grouping array ``by``.
extra_func_kwargs: dict, optional
Extra keyword arguments to pass to ``func``.
"""

if by.ndim != 1:
Expand Down Expand Up @@ -203,6 +201,7 @@ def selection_function(out_key):
by_read_chunks=by_read_chunks,
target_chunks=target_chunks,
groups_per_chunk=groups_per_chunk,
**kwargs,
)


Expand Down
8 changes: 6 additions & 2 deletions cubed/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def test_groupby_blockwise_axis0():
axis=0,
dtype=nxp.int64,
num_groups=6,
groupby_dtype=nxp.int32,
)
assert_array_equal(
c.compute(),
Expand All @@ -129,6 +130,7 @@ def test_groupby_blockwise_axis1():
axis=1,
dtype=nxp.int64,
num_groups=6,
groupby_dtype=nxp.int32,
)
assert_array_equal(
c.compute(),
Expand All @@ -142,7 +144,9 @@ def test_groupby_blockwise_axis1():
)


def _sum_reduction_func(arr, by, axis, start_group, num_groups):
def _sum_reduction_func(arr, by, axis, start_group, num_groups, groupby_dtype):
# change 'by' so it starts from 0 for each chunk
by = by - start_group
return npg.aggregate(by, arr, func="sum", axis=axis, size=num_groups)
return npg.aggregate(
by, arr, func="sum", dtype=groupby_dtype, axis=axis, size=num_groups
)

0 comments on commit 10001f8

Please sign in to comment.