Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Aggregation serialization in pylibcudf #17469

Open
wants to merge 5 commits into
base: branch-25.02
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions python/pylibcudf/pylibcudf/aggregation.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@ from pylibcudf.libcudf.aggregation cimport (
reduce_aggregation,
rolling_aggregation,
scan_aggregation,
std_var_aggregation,
quantile_aggregation,
nunique_aggregation,
nth_element_aggregation,
ewma_aggregation,
rank_aggregation,
collect_list_aggregation,
collect_set_aggregation,
udf_aggregation,
correlation_aggregation,
covariance_aggregation,
)
from pylibcudf.libcudf.types cimport (
interpolation,
Expand All @@ -32,6 +43,17 @@ ctypedef groupby_scan_aggregation * gbsa_ptr
ctypedef reduce_aggregation * ra_ptr
ctypedef scan_aggregation * sa_ptr
ctypedef rolling_aggregation * roa_ptr
ctypedef std_var_aggregation * std_var_ptr
ctypedef quantile_aggregation * quantile_ptr
ctypedef nunique_aggregation * nunique_ptr
ctypedef nth_element_aggregation * nth_element_ptr
ctypedef ewma_aggregation * ewma_ptr
ctypedef rank_aggregation * rank_ptr
ctypedef collect_list_aggregation * collect_list_ptr
ctypedef collect_set_aggregation * collect_set_ptr
ctypedef udf_aggregation * udf_ptr
ctypedef correlation_aggregation * correlation_ptr
ctypedef covariance_aggregation * covariance_ptr


cdef class Aggregation:
Expand Down
102 changes: 102 additions & 0 deletions python/pylibcudf/pylibcudf/aggregation.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,108 @@ cdef class Aggregation:
def __hash__(self):
return dereference(self.c_obj).do_hash()

def __reduce__(self):
cdef std_var_aggregation *std_var_cast
cdef quantile_aggregation *quantile_cast
cdef nunique_aggregation *nunique_cast
cdef nth_element_aggregation *nth_element_cast
cdef ewma_aggregation *ewma_cast
cdef rank_aggregation *rank_cast
cdef collect_list_aggregation *collect_list_cast
cdef collect_set_aggregation *collect_set_cast
cdef udf_aggregation *udf_cast
cdef correlation_aggregation *correlation_cast
cdef covariance_aggregation *covariance_cast

if self.kind() is Kind.SUM:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we replace this cascade with a singledispatch helper function? It's a lot of clauses to fail through in the worst case.

return (sum, ())
elif self.kind() is Kind.PRODUCT:
return (product, ())
elif self.kind() is Kind.MIN:
return (min, ())
elif self.kind() is Kind.MAX:
return (max, ())
elif self.kind() is Kind.COUNT_ALL:
return (count, (null_policy.INCLUDE,))
elif self.kind() is Kind.COUNT_VALID:
return (count, (null_policy.EXCLUDE,))
elif self.kind() is Kind.ANY:
return (any, ())
elif self.kind() is Kind.ALL:
return (all, ())
elif self.kind() is Kind.SUM_OF_SQUARES:
return (sum_of_squares, ())
elif self.kind() is Kind.MEAN:
return (mean, ())
elif self.kind() is Kind.VARIANCE:
std_var_cast = dynamic_cast[std_var_ptr](self.c_obj.get())
return (variance, (std_var_cast._ddof,))
elif self.kind() is Kind.STD:
std_var_cast = dynamic_cast[std_var_ptr](self.c_obj.get())
return (std, (std_var_cast._ddof,))
elif self.kind() is Kind.MEDIAN:
return (median, ())
elif self.kind() is Kind.QUANTILE:
quantile_cast = dynamic_cast[quantile_ptr](self.c_obj.get())
return (quantile, (quantile_cast._quantiles, quantile_cast._interpolation))
elif self.kind() is Kind.ARGMAX:
return (argmax, ())
elif self.kind() is Kind.ARGMIN:
return (argmin, ())
elif self.kind() is Kind.NUNIQUE:
nunique_cast = dynamic_cast[nunique_ptr](self.c_obj.get())
return (nunique, (nunique_cast._null_handling,))
elif self.kind() is Kind.NTH_ELEMENT:
nth_element_cast = dynamic_cast[nth_element_ptr](self.c_obj.get())
return (nth_element, (nth_element_cast._n, nth_element_cast._null_handling))
elif self.kind() is Kind.EWMA:
ewma_cast = dynamic_cast[ewma_ptr](self.c_obj.get())
return (ewma, (ewma_cast.center_of_mass, ewma_cast.history))
elif self.kind() is Kind.RANK:
rank_cast = dynamic_cast[rank_ptr](self.c_obj.get())
return (
rank, (
rank_cast._method,
rank_cast._column_order,
rank_cast._null_handling,
rank_cast._null_precedence,
rank_cast._percentage
)
)
elif self.kind() is Kind.COLLECT_LIST:
collect_list_cast = dynamic_cast[collect_list_ptr](self.c_obj.get())
return (collect_list, (collect_list_cast._null_handling, ))
elif self.kind() is Kind.COLLECT_SET:
collect_set_cast = dynamic_cast[collect_set_ptr](self.c_obj.get())
return (
collect_set, (
collect_set_cast._null_handling,
collect_set_cast._nulls_equal,
collect_set_cast._nans_equal
)
)
elif self.kind() in (Kind.CUDA, Kind.PTX):
udf_cast = dynamic_cast[udf_ptr](self.c_obj.get())
return (
udf, (
udf_cast._source.decode("utf-8"),
DataType.from_libcudf(udf_cast._output_type)
)
)
elif self.kind() is Kind.CORRELATION:
correlation_cast = dynamic_cast[correlation_ptr](self.c_obj.get())
return (
correlation, (
correlation_cast._type,
correlation_cast._min_periods
)
)
elif self.kind() is Kind.COVARIANCE:
covariance_cast = dynamic_cast[covariance_ptr](self.c_obj.get())
return (covariance, (covariance_cast._min_periods, covariance_cast._ddof))
else:
raise ValueError("Unsupported kind")

# TODO: Ideally we would include the return type here, but we need to do so
# in a way that Sphinx understands (currently have issues due to
# https://github.com/cython/cython/issues/5609).
Expand Down
67 changes: 67 additions & 0 deletions python/pylibcudf/pylibcudf/libcudf/aggregation.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ cdef extern from "cudf/aggregation.hpp" namespace "cudf" nogil:
CUDA
CORRELATION
COVARIANCE
EWMA

cdef cppclass aggregation:
Kind kind
Expand All @@ -70,6 +71,9 @@ cdef extern from "cudf/aggregation.hpp" namespace "cudf" nogil:
cdef cppclass scan_aggregation(aggregation):
pass

cdef cppclass segmented_reduce_aggregation(aggregation):
pass

cpdef enum class udf_type(bool):
CUDA
PTX
Expand Down Expand Up @@ -170,3 +174,66 @@ cdef extern from "cudf/aggregation.hpp" namespace "cudf" nogil:
null_policy null_handling,
null_order null_precedence,
rank_percentage percentage) except +libcudf_exception_handler

cdef extern from "cudf/detail/aggregation/aggregation.hpp" \
namespace "cudf::detail" nogil:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm we don't expose pretty much any cudf detail APIs to pylibcudf, and I don't want to start here. Can we open an issue about these? If these are attributes that are absolutely necessary to reconstruct the serialized types, then we should discuss exposing them publicly in libcudf.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the comment Vyas, I've opened #17630. Please let me know if there's anything else I can do.


cdef cppclass std_var_aggregation(
rolling_aggregation,
groupby_aggregation,
reduce_aggregation,
segmented_reduce_aggregation
):
size_type _ddof

cdef cppclass quantile_aggregation(groupby_aggregation, reduce_aggregation):
vector[double] _quantiles
interpolation _interpolation

cdef cppclass nunique_aggregation(
groupby_aggregation, reduce_aggregation, segmented_reduce_aggregation
):
null_policy _null_handling

cdef cppclass nth_element_aggregation(
groupby_aggregation, reduce_aggregation, rolling_aggregation
):
size_type _n
null_policy _null_handling

cdef cppclass ewma_aggregation(scan_aggregation):
double center_of_mass
ewm_history history

cdef cppclass rank_aggregation(
rolling_aggregation, groupby_scan_aggregation, reduce_aggregation
):
rank_method _method
order _column_order
null_policy _null_handling
null_order _null_precedence
rank_percentage _percentage

cdef cppclass collect_list_aggregation(
rolling_aggregation, groupby_aggregation, reduce_aggregation
):
null_policy _null_handling

cdef cppclass collect_set_aggregation(
rolling_aggregation, groupby_aggregation, reduce_aggregation
):
null_policy _null_handling
null_equality _nulls_equal
nan_equality _nans_equal

cdef cppclass udf_aggregation(rolling_aggregation):
string _source
data_type _output_type

cdef cppclass correlation_aggregation(groupby_aggregation):
correlation_type _type
size_type _min_periods

cdef cppclass covariance_aggregation(groupby_aggregation):
size_type _min_periods
size_type _ddof
Loading
Loading