Skip to content

Commit

Permalink
Pass split_every through to TreeReduce (#262)
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl authored Aug 8, 2023
1 parent 1a5571c commit 65163cd
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 2 deletions.
10 changes: 9 additions & 1 deletion dask_expr/_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,15 @@ class SingleAggregation(ApplyConcatApply):
"chunk_kwargs",
"aggregate_kwargs",
"_slice",
"split_every",
]
_defaults = {
"observed": None,
"dropna": None,
"chunk_kwargs": None,
"aggregate_kwargs": None,
"_slice": None,
"split_every": 8,
}

groupby_chunk = None
Expand Down Expand Up @@ -433,7 +435,12 @@ def _numeric_only_kwargs(self, numeric_only):
return {"chunk_kwargs": kwargs, "aggregate_kwargs": kwargs}

def _single_agg(
self, expr_cls, split_out=1, chunk_kwargs=None, aggregate_kwargs=None
self,
expr_cls,
split_every=8,
split_out=1,
chunk_kwargs=None,
aggregate_kwargs=None,
):
if split_out > 1:
raise NotImplementedError("split_out>1 not yet supported")
Expand All @@ -446,6 +453,7 @@ def _single_agg(
chunk_kwargs=chunk_kwargs,
aggregate_kwargs=aggregate_kwargs,
_slice=self._slice,
split_every=split_every,
)
)

Expand Down
5 changes: 4 additions & 1 deletion dask_expr/_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _layer(self):
# This is an abstract expression
raise NotImplementedError()

@property
@functools.cached_property
def _meta(self):
meta = meta_nonempty(self.frame._meta)
meta = self.chunk(meta, **self.chunk_kwargs)
Expand Down Expand Up @@ -98,6 +98,7 @@ def _lower(self):
aggregate,
combine_kwargs,
aggregate_kwargs,
split_every=getattr(self, "split_every", 0),
)


Expand Down Expand Up @@ -161,7 +162,9 @@ class TreeReduce(Expr):
"aggregate",
"combine_kwargs",
"aggregate_kwargs",
"split_every",
]
_defaults = {"split_every": 0}

def __dask_postcompute__(self):
return toolz.first, ()
Expand Down
14 changes: 14 additions & 0 deletions dask_expr/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dask.dataframe.utils import assert_eq

from dask_expr import from_pandas
from dask_expr._reductions import TreeReduce
from dask_expr.tests._util import _backend_library, xfail_gpu

# Set DataFrame backend for this module
Expand Down Expand Up @@ -122,3 +123,16 @@ def test_groupby_agg_column_projection(pdf, df):
assert list(agg.frame.columns) == ["x"]
expect = pdf.groupby("x").agg({"x": "count"})
assert_eq(agg, expect)


def test_groupby_split_every(pdf):
df = from_pandas(pdf, npartitions=16)
query = df.groupby("x").sum()
tree_reduce_node = list(query.optimize(fuse=False).find_operations(TreeReduce))
assert len(tree_reduce_node) == 1
assert tree_reduce_node[0].split_every == 8

query = df.groupby("x").aggregate({"y": "sum"})
tree_reduce_node = list(query.optimize(fuse=False).find_operations(TreeReduce))
assert len(tree_reduce_node) == 1
assert tree_reduce_node[0].split_every == 8

0 comments on commit 65163cd

Please sign in to comment.