Skip to content

Commit

Permalink
Use singledispatch function
Browse files Browse the repository at this point in the history
  • Loading branch information
mroeschke committed Dec 18, 2024
1 parent 3649b10 commit 04c9896
Showing 1 changed file with 41 additions and 41 deletions.
82 changes: 41 additions & 41 deletions python/cudf/cudf/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from cudf._lib.types import size_type_dtype
from cudf.api.extensions import no_default
from cudf.api.types import (
_is_categorical_dtype,
is_list_like,
is_numeric_dtype,
is_string_dtype,
Expand Down Expand Up @@ -126,6 +125,46 @@ def _(dtype: DecimalDtype):
return _DECIMAL_AGGS


@singledispatch
def _is_unsupported_agg_for_type(dtype, str_agg: str) -> bool:
return False


@_is_unsupported_agg_for_type.register
def _(dtype: np.dtype, str_agg: str) -> bool:
# string specifically
cumulative_agg = str_agg in {"cumsum", "cummin", "cummax"}
basic_agg = any(
a in str_agg
for a in (
"count",
"max",
"min",
"first",
"last",
"nunique",
"unique",
"nth",
)
)
return (
dtype.kind == "O"
and str_agg not in _STRING_AGGS
and (cumulative_agg or not (basic_agg or str_agg == "<class 'list'>"))
)


@_is_unsupported_agg_for_type.register
def _(dtype: CategoricalDtype, str_agg: str) -> bool:
cumulative_agg = str_agg in {"cumsum", "cummin", "cummax"}
not_basic_agg = not any(
a in str_agg for a in ("count", "max", "min", "unique")
)
return str_agg not in _CATEGORICAL_AGGS and (
cumulative_agg or not_basic_agg
)


def _is_all_scan_aggregate(all_aggs: list[list[str]]) -> bool:
"""
Returns true if all are scan aggregations.
Expand Down Expand Up @@ -760,49 +799,10 @@ def _aggregate(
col_aggregations = []
for agg in aggs:
str_agg = str(agg)
if (
is_string_dtype(col)
and agg not in _STRING_AGGS
and (
str_agg in {"cumsum", "cummin", "cummax"}
or not (
any(
a in str_agg
for a in {
"count",
"max",
"min",
"first",
"last",
"nunique",
"unique",
"nth",
}
)
or (agg is list)
)
)
):
raise TypeError(
f"function is not supported for this dtype: {agg}"
)
elif (
_is_categorical_dtype(col)
and agg not in _CATEGORICAL_AGGS
and (
str_agg in {"cumsum", "cummin", "cummax"}
or not (
any(
a in str_agg
for a in {"count", "max", "min", "unique"}
)
)
)
):
if _is_unsupported_agg_for_type(col.dtype, str_agg):
raise TypeError(
f"{col.dtype} type does not support {agg} operations"
)

agg_obj = aggregation.make_aggregation(agg)
if (
valid_aggregations == "ALL"
Expand Down

0 comments on commit 04c9896

Please sign in to comment.