Implement groupby multi level support #486

Merged · 10 commits · Dec 15, 2023
Changes from 3 commits
148 changes: 102 additions & 46 deletions dask_expr/_groupby.py
@@ -3,6 +3,7 @@

 import numpy as np
 from dask import is_dask_collection
+from dask.core import flatten
 from dask.dataframe.core import _concat, is_dataframe_like, is_series_like
 from dask.dataframe.dispatch import concat, make_meta, meta_nonempty
 from dask.dataframe.groupby import (
@@ -30,7 +31,7 @@
 )
 from dask.utils import M, is_index_like

-from dask_expr._collection import Index, Series, new_collection
+from dask_expr._collection import FrameBase, Index, Series, new_collection
 from dask_expr._expr import (
     Assign,
     Blockwise,
@@ -64,19 +65,57 @@ class GroupByChunk(Chunk):
     _parameters = Chunk._parameters + ["by"]
     _defaults = Chunk._defaults | {"by": None}

+    def dependencies(self):
+        deps = [operand for operand in self.operands if isinstance(operand, Expr)]
+        if isinstance(self.by, (list, tuple)):
+            deps.extend(op for op in self.by if isinstance(op, Expr))
+        return deps
+
     @functools.cached_property
     def _args(self) -> list:
-        return [self.frame, self.by]
+        if not isinstance(self.by, (list, tuple)):
+            return [self.frame, self.by]
+        else:
+            return [self.frame] + list(self.by)


 class GroupByApplyConcatApply(ApplyConcatApply):
     _chunk_cls = GroupByChunk

+    @functools.cached_property
+    def _by_meta(self):
+        if isinstance(self.by, Expr):
+            return [meta_nonempty(self.by._meta)]
+        elif is_scalar(self.by):
+            return [self.by]
+        else:
+            return [
+                meta_nonempty(x._meta) if isinstance(x, Expr) else x for x in self.by
+            ]
+
+    @functools.cached_property
+    def _by_columns(self):
+        if isinstance(self.by, Expr):
+            return []
+        else:
+            return [x for x in self.by if not isinstance(x, Expr)]
+
+    @property
+    def split_by(self):
+        if isinstance(self.by, Expr):
+            return self.by.columns
+        else:
+            return list(
+                flatten(
+                    [[x] if not isinstance(x, Expr) else x.columns for x in self.by],
+                    container=list,
+                )
+            )
+
     @functools.cached_property
     def _meta_chunk(self):
         meta = meta_nonempty(self.frame._meta)
-        by = self.by if not isinstance(self.by, Expr) else meta_nonempty(self.by._meta)
-        return self.chunk(meta, by, **self.chunk_kwargs)
+        return self.chunk(meta, *self._by_meta, **self.chunk_kwargs)

     @property
     def _chunk_cls_args(self):
@@ -156,16 +195,8 @@ class SingleAggregation(GroupByApplyConcatApply):
     groupby_chunk = None
     groupby_aggregate = None

-    @property
-    def split_by(self):
-        if isinstance(self.by, Expr):
-            return self.by.columns
-        return self.by
-
     @classmethod
-    def chunk(cls, df, by=None, **kwargs):
-        if hasattr(by, "dtype"):
-            by = [by]
+    def chunk(cls, df, *by, **kwargs):
         return _apply_chunk(df, *by, **kwargs)

     @classmethod
@@ -199,8 +230,7 @@ def aggregate_kwargs(self) -> dict:

     def _simplify_up(self, parent):
         if isinstance(parent, Projection):
-            by_columns = self.by if not isinstance(self.by, Expr) else []
-            columns = sorted(set(parent.columns + by_columns))
+            columns = sorted(set(parent.columns + self._by_columns))
             if columns == self.frame.columns:
                 return
             columns = [col for col in self.frame.columns if col in columns]
@@ -254,10 +284,6 @@ class GroupbyAggregation(GroupByApplyConcatApply):
         "sort": None,
     }

-    @property
-    def split_by(self):
-        return self.by
-
     @functools.cached_property
     def spec(self):
         # Converts the `arg` operand into specific
@@ -343,10 +369,11 @@ def aggregate_kwargs(self) -> dict:
     def _simplify_down(self):
         # Use agg-spec information to add column projection
         column_projection = None
-        by_columns = self.by if not isinstance(self.by, Expr) else []
         if isinstance(self.arg, dict):
             column_projection = (
-                set(by_columns).union(self.arg.keys()).intersection(self.frame.columns)
+                set(self._by_columns)
+                .union(self.arg.keys())
+                .intersection(self.frame.columns)
             )
         if column_projection and column_projection < set(self.frame.columns):
             return type(self)(self.frame[list(column_projection)], *self.operands[1:])
@@ -398,11 +425,28 @@ class GroupByReduction(Reduction):
     def _chunk_cls_args(self):
         return [self.by]

+    @functools.cached_property
+    def _by_meta(self):
+        if isinstance(self.by, Expr):
+            return meta_nonempty(self.by._meta)
+        elif is_scalar(self.by):
+            return self.by
+        else:
+            return [
+                meta_nonempty(x._meta) if isinstance(x, Expr) else x for x in self.by
+            ]
+
+    @functools.cached_property
+    def _by_columns(self):
+        if isinstance(self.by, Expr):
+            return []
+        else:
+            return [x for x in self.by if not isinstance(x, Expr)]
+
Collaborator:
It doesn't feel great to have identical and/or very similar code between classes (GroupByReduction._by_columns and GroupByApplyConcatApply._by_columns are exactly the same, and GroupByReduction._by_meta and GroupByApplyConcatApply._by_meta are nearly the same). Do you think the logic can be combined into ApplyConcatApply?

Collaborator Author:
No, that's not groupby specific, so ApplyConcatApply is not the place for this logic.

The groupby structure needs a refactor anyway to make this more consistent, but that's something for a follow-up once the actual implementation is ironed out.

Collaborator Author:
I rewrote parts of the implementation; still not happy with it, but there is less duplicated code now.
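A rough sketch of what hoisting the shared pieces could look like, per the discussion above. This is illustrative only: the _GroupByBase name is hypothetical, not what the PR ultimately does, and the helpers assume the module's existing Expr, is_scalar, and meta_nonempty imports:

import functools

# Hypothetical sketch: a shared base holding the duplicated helpers, so
# GroupByApplyConcatApply and GroupByReduction could both inherit them.
class _GroupByBase:
    @functools.cached_property
    def _by_columns(self):
        # Plain column-name keys; an Expr key contributes no column names here.
        if isinstance(self.by, Expr):
            return []
        return [x for x in self.by if not isinstance(x, Expr)]

    @functools.cached_property
    def _by_meta(self):
        # Swap each Expr key for non-empty meta; pass scalars through as-is.
        if isinstance(self.by, Expr):
            return [meta_nonempty(self.by._meta)]
        if is_scalar(self.by):
            return [self.by]
        return [meta_nonempty(x._meta) if isinstance(x, Expr) else x for x in self.by]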

     @functools.cached_property
     def _meta_chunk(self):
         meta = meta_nonempty(self.frame._meta)
-        by = self.by if not isinstance(self.by, Expr) else meta_nonempty(self.by._meta)
-        return self.chunk(meta, by, **self.chunk_kwargs)
+        return self.chunk(meta, *self._by_meta, **self.chunk_kwargs)


 def _var_combine(g, levels, sort=False, observed=False, dropna=True):
@@ -428,12 +472,16 @@ class Var(GroupByReduction):
     def split_by(self):
         if isinstance(self.by, Expr):
             return self.by.columns
-        return self.by
+        else:
+            return list(
+                flatten(
+                    [[x] if not isinstance(x, Expr) else x.columns for x in self.by],
+                    container=list,
+                )
+            )

     @staticmethod
-    def chunk(frame, by, **kwargs):
-        if hasattr(by, "dtype"):
-            by = [by]
+    def chunk(frame, *by, **kwargs):
         return _var_chunk(frame, *by, **kwargs)

     @functools.cached_property
@@ -468,8 +516,7 @@ def _divisions(self):

     def _simplify_up(self, parent):
         if isinstance(parent, Projection):
-            by_columns = self.by if not isinstance(self.by, Expr) else []
-            columns = sorted(set(parent.columns + by_columns))
+            columns = sorted(set(parent.columns + self._by_columns))
             if columns == self.frame.columns:
                 return
             columns = [col for col in self.frame.columns if col in columns]
@@ -902,6 +949,22 @@ def _extract_meta(x, nonempty=False):
 ###


+def _validate_by_expr(obj, by):
Collaborator:
Could you think of a better name, and possibly add a docstring?
Elsewhere in dask, "validate" functions typically contain a bunch of asserts and return None.
This function seems to extract sometimes a column name, sometimes an Expr, and sometimes something else.
Maybe "clean_by_expr" or "preprocess_by_expr"?

Collaborator Author:
_clean_by_expr seems fine

+    if (
+        isinstance(by, Series)
+        and by.name in obj.columns
+        and by._name == obj[by.name]._name
+    ):
+        return by.name
+    elif isinstance(by, Index) and by._name == obj.index._name:
+        return by.expr
+    elif isinstance(by, Series):
+        if not are_co_aligned(obj.expr, by.expr):
+            raise ValueError("by must be in the DataFrames columns.")
+        return by.expr
+    return by
Collaborator:
Could you add a comment explaining what use cases are not covered by any of the above branches?
I gather from elsewhere that Expr is a possible use case; are there others?

Collaborator Author:
Added. It can also be a proper column name, e.g. by="a".
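Taking both review points together, a possible shape for the renamed helper with a docstring (an illustrative sketch; the exact name and wording in the merged version may differ):

def _clean_by_expr(obj, by):
    """Normalize a single groupby key against ``obj``.

    Returns a plain column name when ``by`` is a Series that aliases a
    column of ``obj``, the underlying Expr for the frame's own Index or
    for any other co-aligned Series, and the input unchanged otherwise,
    e.g. a proper column name such as by="a".
    """
    if (
        isinstance(by, Series)
        and by.name in obj.columns
        and by._name == obj[by.name]._name
    ):
        return by.name
    elif isinstance(by, Index) and by._name == obj.index._name:
        return by.expr
    elif isinstance(by, Series):
        if not are_co_aligned(obj.expr, by.expr):
            raise ValueError("by must be in the DataFrames columns.")
        return by.expr
    # Fall-through: e.g. a proper column name, by="a"
    return by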



 class GroupBy:
     """Collection container for groupby aggregations

Expand All @@ -923,17 +986,10 @@ def __init__(
dropna=None,
slice=None,
):
if (
isinstance(by, Series)
and by.name in obj.columns
and by._name == obj[by.name]._name
):
by = by.name
elif isinstance(by, Index) and by._name == obj.index._name:
pass
elif isinstance(by, Series):
if not are_co_aligned(obj.expr, by.expr):
raise ValueError("by must be in the DataFrames columns.")
if isinstance(by, (tuple, list)):
by = [_validate_by_expr(obj, x) for x in by]
else:
by = _validate_by_expr(obj, by)

by_ = by if isinstance(by, (tuple, list)) else [by]
self._slice = slice
@@ -957,11 +1013,7 @@ def __init__(
         self.observed = observed
         self.dropna = dropna
         self.group_keys = group_keys
-
-        if isinstance(by, Series):
-            self.by = by.expr
-        else:
-            self.by = [by] if np.isscalar(by) else list(by)
+        self.by = [by] if np.isscalar(by) or isinstance(by, Expr) else list(by)
Collaborator:
Everywhere else you added code branches that allow the by attribute to be either a scalar (Expr or otherwise) or a list. This, however, says that by is always a list?

I much prefer the latter. Test coverage for all these new code branches is quite spotty; coercing everything into a list as soon as it's acquired from the user would prevent that.

Collaborator Author:
Very good point, that was an oversight on my part. It will always be a list after this PR is in.
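A minimal sketch of the coercion being agreed on here, so downstream code never needs scalar-vs-list branches (the helper name is hypothetical, not from the PR):

def _normalize_by(by):
    # Always hand downstream code a list of groupby keys.
    if isinstance(by, (tuple, list)):
        return list(by)
    return [by]

# The attribute assignment above would then collapse to:
# self.by = _normalize_by(by)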


     def _numeric_only_kwargs(self, numeric_only):
         kwargs = {"numeric_only": numeric_only}
@@ -1246,7 +1298,11 @@ def __init__(
     ):
         # Raise pandas errors if applicable
         if isinstance(obj, Series):
-            if isinstance(by, Series):
+            if (
+                isinstance(by, FrameBase)
+                or isinstance(by, (list, tuple))
+                and any(isinstance(x, FrameBase) for x in by)
+            ):
                 pass
             elif isinstance(by, list):
                 if len(by) == 0:
21 changes: 21 additions & 0 deletions dask_expr/tests/test_groupby.py
@@ -58,6 +58,22 @@ def test_groupby_numeric(pdf, df, api, numeric_only):
     expect = getattr(pdf.groupby("x"), api)(numeric_only=numeric_only)["y"]
     assert_eq(agg, expect)

+    g = df.groupby([df.x])
+    agg = getattr(g, api)(numeric_only=numeric_only)["y"]
+
+    expect = getattr(pdf.groupby([pdf.x]), api)(numeric_only=numeric_only)["y"]
+    assert_eq(agg, expect)
+
+    g = df.groupby([df.x, df.z])
+    agg = getattr(g, api)()
+    expect = getattr(pdf.groupby([pdf.x, pdf.z]), api)(numeric_only=numeric_only)
+    assert_eq(agg, expect)
+
+    g = df.groupby([df.x, "z"])
+    agg = getattr(g, api)()
+    expect = getattr(pdf.groupby([pdf.x, "z"]), api)(numeric_only=numeric_only)
+    assert_eq(agg, expect)
+
     pdf = pdf.set_index("x")
     df = from_pandas(pdf, npartitions=10, sort=False)
     g = df.groupby("x")
@@ -311,6 +327,11 @@ def test_groupby_single_agg_split_out(pdf, df, api, sort, split_out):
     expect = getattr(pdf.y.groupby(pdf.x, sort=sort), api)()
     assert_eq(agg, expect, sort_results=not sort)

+    g = df.y.groupby([df.x, df.z], sort=sort)
+    agg = getattr(g, api)(split_out=split_out)
+    expect = getattr(pdf.y.groupby([pdf.x, pdf.z], sort=sort), api)()
+    assert_eq(agg, expect, sort_results=not sort)
+

 @pytest.mark.parametrize(
     "spec",