Implement groupby multi level support #486
Changes from 3 commits
@@ -3,6 +3,7 @@
 import numpy as np
 from dask import is_dask_collection
+from dask.core import flatten
 from dask.dataframe.core import _concat, is_dataframe_like, is_series_like
 from dask.dataframe.dispatch import concat, make_meta, meta_nonempty
 from dask.dataframe.groupby import (
@@ -30,7 +31,7 @@
 )
 from dask.utils import M, is_index_like
 
-from dask_expr._collection import Index, Series, new_collection
+from dask_expr._collection import FrameBase, Index, Series, new_collection
 from dask_expr._expr import (
     Assign,
     Blockwise,
@@ -64,19 +65,57 @@ class GroupByChunk(Chunk):
     _parameters = Chunk._parameters + ["by"]
     _defaults = Chunk._defaults | {"by": None}
 
+    def dependencies(self):
+        deps = [operand for operand in self.operands if isinstance(operand, Expr)]
+        if isinstance(self.by, (list, tuple)):
+            deps.extend(op for op in self.by if isinstance(op, Expr))
+        return deps
+
     @functools.cached_property
     def _args(self) -> list:
-        return [self.frame, self.by]
+        if not isinstance(self.by, (list, tuple)):
+            return [self.frame, self.by]
+        else:
+            return [self.frame] + list(self.by)
 
 
 class GroupByApplyConcatApply(ApplyConcatApply):
     _chunk_cls = GroupByChunk
 
+    @functools.cached_property
+    def _by_meta(self):
+        if isinstance(self.by, Expr):
+            return [meta_nonempty(self.by._meta)]
+        elif is_scalar(self.by):
+            return [self.by]
+        else:
+            return [
+                meta_nonempty(x._meta) if isinstance(x, Expr) else x for x in self.by
+            ]
+
+    @functools.cached_property
+    def _by_columns(self):
+        if isinstance(self.by, Expr):
+            return []
+        else:
+            return [x for x in self.by if not isinstance(x, Expr)]
+
+    @property
+    def split_by(self):
+        if isinstance(self.by, Expr):
+            return self.by.columns
+        else:
+            return list(
+                flatten(
+                    [[x] if not isinstance(x, Expr) else x.columns for x in self.by],
+                    container=list,
+                )
+            )
+
     @functools.cached_property
     def _meta_chunk(self):
         meta = meta_nonempty(self.frame._meta)
-        by = self.by if not isinstance(self.by, Expr) else meta_nonempty(self.by._meta)
-        return self.chunk(meta, by, **self.chunk_kwargs)
+        return self.chunk(meta, *self._by_meta, **self.chunk_kwargs)
 
     @property
     def _chunk_cls_args(self):
@@ -156,16 +195,8 @@ class SingleAggregation(GroupByApplyConcatApply):
     groupby_chunk = None
     groupby_aggregate = None
 
-    @property
-    def split_by(self):
-        if isinstance(self.by, Expr):
-            return self.by.columns
-        return self.by
-
     @classmethod
-    def chunk(cls, df, by=None, **kwargs):
-        if hasattr(by, "dtype"):
-            by = [by]
+    def chunk(cls, df, *by, **kwargs):
         return _apply_chunk(df, *by, **kwargs)
 
     @classmethod
@@ -199,8 +230,7 @@ def aggregate_kwargs(self) -> dict: | |
|
||
def _simplify_up(self, parent): | ||
if isinstance(parent, Projection): | ||
by_columns = self.by if not isinstance(self.by, Expr) else [] | ||
columns = sorted(set(parent.columns + by_columns)) | ||
columns = sorted(set(parent.columns + self._by_columns)) | ||
if columns == self.frame.columns: | ||
return | ||
columns = [col for col in self.frame.columns if col in columns] | ||
|
@@ -254,10 +284,6 @@ class GroupbyAggregation(GroupByApplyConcatApply):
         "sort": None,
     }
 
-    @property
-    def split_by(self):
-        return self.by
-
     @functools.cached_property
     def spec(self):
         # Converts the arg operand into specific
@@ -343,10 +369,11 @@ def aggregate_kwargs(self) -> dict:
     def _simplify_down(self):
         # Use agg-spec information to add column projection
        column_projection = None
-        by_columns = self.by if not isinstance(self.by, Expr) else []
         if isinstance(self.arg, dict):
             column_projection = (
-                set(by_columns).union(self.arg.keys()).intersection(self.frame.columns)
+                set(self._by_columns)
+                .union(self.arg.keys())
+                .intersection(self.frame.columns)
             )
         if column_projection and column_projection < set(self.frame.columns):
             return type(self)(self.frame[list(column_projection)], *self.operands[1:])
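To illustrate the projection this computes, a small self-contained sketch (the column names and agg spec are made up, not taken from the PR): the projected columns are the group keys plus the columns named in a dict arg, restricted to columns that actually exist on the frame.

    frame_columns = ["a", "b", "c", "d"]
    by_columns = ["a"]                   # what _by_columns would return
    arg = {"b": "sum", "x": "mean"}      # "x" is not a column of the frame

    column_projection = set(by_columns).union(arg.keys()).intersection(frame_columns)
    print(sorted(column_projection))     # ['a', 'b'] -> only these columns are needed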
@@ -398,11 +425,28 @@ class GroupByReduction(Reduction):
     def _chunk_cls_args(self):
         return [self.by]
 
+    @functools.cached_property
+    def _by_meta(self):
+        if isinstance(self.by, Expr):
+            return meta_nonempty(self.by._meta)
+        elif is_scalar(self.by):
+            return self.by
+        else:
+            return [
+                meta_nonempty(x._meta) if isinstance(x, Expr) else x for x in self.by
+            ]
+
+    @functools.cached_property
+    def _by_columns(self):
+        if isinstance(self.by, Expr):
+            return []
+        else:
+            return [x for x in self.by if not isinstance(x, Expr)]
+
     @functools.cached_property
     def _meta_chunk(self):
         meta = meta_nonempty(self.frame._meta)
-        by = self.by if not isinstance(self.by, Expr) else meta_nonempty(self.by._meta)
-        return self.chunk(meta, by, **self.chunk_kwargs)
+        return self.chunk(meta, *self._by_meta, **self.chunk_kwargs)
 
 
 def _var_combine(g, levels, sort=False, observed=False, dropna=True):
@@ -428,12 +472,16 @@ class Var(GroupByReduction):
     def split_by(self):
         if isinstance(self.by, Expr):
             return self.by.columns
-        return self.by
+        else:
+            return list(
+                flatten(
+                    [[x] if not isinstance(x, Expr) else x.columns for x in self.by],
+                    container=list,
+                )
+            )
 
     @staticmethod
-    def chunk(frame, by, **kwargs):
-        if hasattr(by, "dtype"):
-            by = [by]
+    def chunk(frame, *by, **kwargs):
         return _var_chunk(frame, *by, **kwargs)
 
     @functools.cached_property
@@ -468,8 +516,7 @@ def _divisions(self):
 
     def _simplify_up(self, parent):
         if isinstance(parent, Projection):
-            by_columns = self.by if not isinstance(self.by, Expr) else []
-            columns = sorted(set(parent.columns + by_columns))
+            columns = sorted(set(parent.columns + self._by_columns))
             if columns == self.frame.columns:
                 return
             columns = [col for col in self.frame.columns if col in columns]
@@ -902,6 +949,22 @@ def _extract_meta(x, nonempty=False):
 ###
 
 
+def _validate_by_expr(obj, by):
+    if (
+        isinstance(by, Series)
+        and by.name in obj.columns
+        and by._name == obj[by.name]._name
+    ):
+        return by.name
+    elif isinstance(by, Index) and by._name == obj.index._name:
+        return by.expr
+    elif isinstance(by, Series):
+        if not are_co_aligned(obj.expr, by.expr):
+            raise ValueError("by must be in the DataFrames columns.")
+        return by.expr
+    return by
+
+
 class GroupBy:
     """Collection container for groupby aggregations
 

[Review comment on _validate_by_expr] Could you think of a better name, and possibly add some docstring?

[Review comment] Could you add a comment explaining what use cases are not collected by any of the above switches?

[Reply] Added, it can be a proper column name, e.g.
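As a rough usage sketch of the cases this validation is meant to cover (hypothetical example, not taken from the PR or its tests, assuming the usual dask.dataframe from_pandas entry point):

    import pandas as pd
    import dask.dataframe as dd

    pdf = pd.DataFrame({"a": [1, 1, 2], "b": [3, 4, 5], "c": [6, 7, 8]})
    df = dd.from_pandas(pdf, npartitions=2)

    df.groupby("a").b.sum()          # plain column name: passed through unchanged
    df.groupby(df.a).b.sum()         # Series backed by the same frame: reduced to "a"
    df.groupby(["a", df.c]).b.sum()  # multi-level by: mix of names and Series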
@@ -923,17 +986,10 @@ def __init__(
         dropna=None,
         slice=None,
     ):
-        if (
-            isinstance(by, Series)
-            and by.name in obj.columns
-            and by._name == obj[by.name]._name
-        ):
-            by = by.name
-        elif isinstance(by, Index) and by._name == obj.index._name:
-            pass
-        elif isinstance(by, Series):
-            if not are_co_aligned(obj.expr, by.expr):
-                raise ValueError("by must be in the DataFrames columns.")
+        if isinstance(by, (tuple, list)):
+            by = [_validate_by_expr(obj, x) for x in by]
+        else:
+            by = _validate_by_expr(obj, by)
 
         by_ = by if isinstance(by, (tuple, list)) else [by]
         self._slice = slice
@@ -957,11 +1013,7 @@ def __init__(
         self.observed = observed
         self.dropna = dropna
         self.group_keys = group_keys
-
-        if isinstance(by, Series):
-            self.by = by.expr
-        else:
-            self.by = [by] if np.isscalar(by) else list(by)
+        self.by = [by] if np.isscalar(by) or isinstance(by, Expr) else list(by)
 
     def _numeric_only_kwargs(self, numeric_only):
         kwargs = {"numeric_only": numeric_only}

[Review comment] Everywhere else you added code branches that allow the by attribute to be either a scalar (Expr or otherwise) or a list. This however says that by is always a list? I much prefer the latter. Test coverage for all these new code branches is quite spotty - something that coercing everything into a list as soon as it's acquired from the user would prevent.

[Reply] Very good point, that was an oversight on my part. It will always be a list after this PR is in.
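A minimal sketch of the coercion discussed in the comment above (the helper name is invented; the PR does this inline in __init__): everything the user passes ends up as a list.

    import numpy as np

    def _normalize_by(by):
        # scalars (e.g. column names) and single expression-like objects become
        # one-element lists; tuples and lists are copied into a fresh list
        if np.isscalar(by) or not isinstance(by, (list, tuple)):
            return [by]
        return list(by)

    assert _normalize_by("a") == ["a"]
    assert _normalize_by(("a", "b")) == ["a", "b"]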
@@ -1246,7 +1298,11 @@ def __init__(
     ):
         # Raise pandas errors if applicable
         if isinstance(obj, Series):
-            if isinstance(by, Series):
+            if (
+                isinstance(by, FrameBase)
+                or isinstance(by, (list, tuple))
+                and any(isinstance(x, FrameBase) for x in by)
+            ):
                 pass
             elif isinstance(by, list):
                 if len(by) == 0:
[Review comment] It feels not great to have identical and/or very similar code (GroupByReduction/GroupByApplyConcatApply._by_columns are exactly the same, and GroupByReduction/GroupByApplyConcatApply._by_meta are nearly the same) between classes. Do you think the logic can be combined into ApplyConcatApply?

[Reply] No, that's not groupby specific, so that's not the place for this logic. The groupby structure needs a refactor anyway to make this more consistent, but that's something for a follow-up once the actual implementation is ironed out.

[Reply] I rewrote parts of the implementation; still not happy with it, but there is less duplicated code now.