diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index be3cc410174..612a0c571b5 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -2,6 +2,7 @@ from __future__ import annotations import copy +import functools import itertools import textwrap import types @@ -30,7 +31,7 @@ from cudf.core._internals import aggregation, sorting from cudf.core.abc import Serializable from cudf.core.buffer import acquire_spill_lock -from cudf.core.column.column import ColumnBase, as_column +from cudf.core.column.column import ColumnBase, as_column, column_empty from cudf.core.column_accessor import ColumnAccessor from cudf.core.copy_types import GatherMap from cudf.core.dtypes import ( @@ -745,7 +746,7 @@ def _groupby(self) -> types.SimpleNamespace: plc.Table( [ col.to_pylibcudf(mode="read") - for col in self.grouping.keys._columns + for col in self.grouping._key_columns ] ), plc.types.NullPolicy.EXCLUDE @@ -1047,8 +1048,8 @@ def agg(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs): ) and not _is_all_scan_aggregate(normalized_aggs): # Even with `sort=False`, pandas guarantees that # groupby preserves the order of rows within each group. - left_cols = list(self.grouping.keys.drop_duplicates()._columns) - right_cols = list(result_index._columns) + left_cols = self.grouping.keys.drop_duplicates()._columns + right_cols = result_index._columns join_keys = [ _match_join_keys(lcol, rcol, "inner") for lcol, rcol in zip(left_cols, right_cols) @@ -2480,7 +2481,7 @@ def _cov_or_corr(self, func, method_name): column_pair_groupby = cudf.DataFrame._from_data( column_pair_structs - ).groupby(by=self.grouping.keys) + ).groupby(by=self.grouping) try: gb_cov_corr = column_pair_groupby.agg(func) @@ -3504,7 +3505,9 @@ def _handle_by_or_level(self, by=None, level=None): self._handle_level(level) else: by_list = by if isinstance(by, list) else [by] - + if not len(self._obj) and not len(by_list): + # We pretend to groupby an empty column + by_list = [cudf.Index._from_column(column_empty(0))] for by in by_list: if callable(by): self._handle_callable(by) @@ -3526,16 +3529,12 @@ def _handle_by_or_level(self, by=None, level=None): except (KeyError, TypeError): self._handle_misc(by) - @property + @functools.cached_property def keys(self): """Return grouping key columns as index""" - nkeys = len(self._key_columns) - - if nkeys == 0: - return cudf.Index([], name=None) - elif nkeys > 1: + if len(self._key_columns) > 1: return cudf.MultiIndex._from_data( - dict(zip(range(nkeys), self._key_columns)) + dict(enumerate(self._key_columns)) )._set_names(self.names) else: return cudf.Index._from_column(