Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid double MultiIndex factorization in groupby index result #17644

Open
wants to merge 4 commits into
base: branch-25.02
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions python/cudf/cudf/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import copy
import functools
import itertools
import textwrap
import types
Expand Down Expand Up @@ -30,7 +31,7 @@
from cudf.core._internals import aggregation, sorting
from cudf.core.abc import Serializable
from cudf.core.buffer import acquire_spill_lock
from cudf.core.column.column import ColumnBase, as_column
from cudf.core.column.column import ColumnBase, as_column, column_empty
from cudf.core.column_accessor import ColumnAccessor
from cudf.core.copy_types import GatherMap
from cudf.core.dtypes import (
Expand Down Expand Up @@ -745,7 +746,7 @@ def _groupby(self) -> types.SimpleNamespace:
plc.Table(
[
col.to_pylibcudf(mode="read")
for col in self.grouping.keys._columns
for col in self.grouping._key_columns
]
),
plc.types.NullPolicy.EXCLUDE
Expand Down Expand Up @@ -1047,8 +1048,8 @@ def agg(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
) and not _is_all_scan_aggregate(normalized_aggs):
# Even with `sort=False`, pandas guarantees that
# groupby preserves the order of rows within each group.
left_cols = list(self.grouping.keys.drop_duplicates()._columns)
right_cols = list(result_index._columns)
left_cols = self.grouping.keys.drop_duplicates()._columns
right_cols = result_index._columns
join_keys = [
_match_join_keys(lcol, rcol, "inner")
for lcol, rcol in zip(left_cols, right_cols)
Expand Down Expand Up @@ -2480,7 +2481,7 @@ def _cov_or_corr(self, func, method_name):

column_pair_groupby = cudf.DataFrame._from_data(
column_pair_structs
).groupby(by=self.grouping.keys)
).groupby(by=self.grouping)

try:
gb_cov_corr = column_pair_groupby.agg(func)
Expand Down Expand Up @@ -3504,7 +3505,9 @@ def _handle_by_or_level(self, by=None, level=None):
self._handle_level(level)
else:
by_list = by if isinstance(by, list) else [by]

if not len(self._obj) and not len(by_list):
# We pretend to groupby an empty column
by_list = [cudf.Index._from_column(column_empty(0))]
for by in by_list:
if callable(by):
self._handle_callable(by)
Expand All @@ -3526,16 +3529,12 @@ def _handle_by_or_level(self, by=None, level=None):
except (KeyError, TypeError):
self._handle_misc(by)

@property
@functools.cached_property
def keys(self):
"""Return grouping key columns as index"""
nkeys = len(self._key_columns)

if nkeys == 0:
return cudf.Index([], name=None)
elif nkeys > 1:
if len(self._key_columns) > 1:
return cudf.MultiIndex._from_data(
dict(zip(range(nkeys), self._key_columns))
dict(enumerate(self._key_columns))
)._set_names(self.names)
else:
return cudf.Index._from_column(
Expand Down
Loading