diff --git a/cpp/include/cudf/aggregation/host_udf.hpp b/cpp/include/cudf/aggregation/host_udf.hpp index 128c02f2627..53a2171fa54 100644 --- a/cpp/include/cudf/aggregation/host_udf.hpp +++ b/cpp/include/cudf/aggregation/host_udf.hpp @@ -17,7 +17,6 @@ #pragma once #include -#include #include #include #include @@ -156,28 +155,21 @@ struct host_udf_base { std::is_same_v)> data_attribute(std::unique_ptr value_) : value{std::move(value_)} { + CUDF_EXPECTS(std::get>(value) != nullptr, + "Invalid aggregation request."); if constexpr (std::is_same_v) { CUDF_EXPECTS( dynamic_cast(std::get>(value).get()) != nullptr, "Requesting results from other aggregations is only supported in groupby " "aggregations."); } - CUDF_EXPECTS(std::get>(value) != nullptr, - "Invalid aggregation request."); } /** * @brief Copy constructor. * @param other The other data attribute to copy from */ - data_attribute(data_attribute const& other) - : value{std::visit( - cudf::detail::visitor_overload{ - [](auto const& val) { return value_type{val}; }, - [](std::unique_ptr const& val) { return value_type{val->clone()}; }}, - other.value)} - { - } + data_attribute(data_attribute const& other); /** * @brief Hash functor for `data_attribute`. @@ -188,15 +180,7 @@ struct host_udf_base { * @param attr The data attribute to hash * @return The hash value of the data attribute */ - std::size_t operator()(data_attribute const& attr) const - { - auto const hash_value = - std::visit(cudf::detail::visitor_overload{ - [](auto const& val) { return std::hash{}(static_cast(val)); }, - [](std::unique_ptr const& val) { return val->do_hash(); }}, - attr.value); - return std::hash{}(attr.value.index()) ^ hash_value; - } + std::size_t operator()(data_attribute const& attr) const; }; // struct hash /** @@ -209,26 +193,7 @@ struct host_udf_base { * @param rhs The right-hand side data attribute * @return True if the two data attributes are equal */ - bool operator()(data_attribute const& lhs, data_attribute const& rhs) const - { - auto const& lhs_val = lhs.value; - auto const& rhs_val = rhs.value; - if (lhs_val.index() != rhs_val.index()) { return false; } - return std::visit(cudf::detail::visitor_overload{ - [](auto const& lhs_val, auto const& rhs_val) { - if constexpr (std::is_same_v) { - return lhs_val == rhs_val; - } else { - return false; - } - }, - [](std::unique_ptr const& lhs_val, - std::unique_ptr const& rhs_val) { - return lhs_val->is_equal(*rhs_val); - }}, - lhs_val, - rhs_val); - } + bool operator()(data_attribute const& lhs, data_attribute const& rhs) const; }; // struct equal_to }; // struct data_attribute diff --git a/cpp/src/groupby/sort/host_udf_aggregation.cpp b/cpp/src/groupby/sort/host_udf_aggregation.cpp index 7d0f37646b9..0da47e17f48 100644 --- a/cpp/src/groupby/sort/host_udf_aggregation.cpp +++ b/cpp/src/groupby/sort/host_udf_aggregation.cpp @@ -16,8 +16,51 @@ #include #include +#include namespace cudf { + +host_udf_base::data_attribute::data_attribute(data_attribute const& other) + : value{std::visit(cudf::detail::visitor_overload{[](auto const& val) { return value_type{val}; }, + [](std::unique_ptr const& val) { + return value_type{val->clone()}; + }}, + other.value)} +{ +} + +std::size_t host_udf_base::data_attribute::hash::operator()(data_attribute const& attr) const +{ + auto const hash_value = + std::visit(cudf::detail::visitor_overload{ + [](auto const& val) { return std::hash{}(static_cast(val)); }, + [](std::unique_ptr const& val) { return val->do_hash(); }}, + attr.value); + return std::hash{}(attr.value.index()) ^ hash_value; +} + +bool host_udf_base::data_attribute::equal_to::operator()(data_attribute const& lhs, + data_attribute const& rhs) const +{ + auto const& lhs_val = lhs.value; + auto const& rhs_val = rhs.value; + if (lhs_val.index() != rhs_val.index()) { return false; } + return std::visit( + cudf::detail::visitor_overload{ + [](auto const& lhs_val, auto const& rhs_val) { + if constexpr (std::is_same_v) { + return lhs_val == rhs_val; + } else { + return false; + } + }, + [](std::unique_ptr const& lhs_val, std::unique_ptr const& rhs_val) { + return lhs_val->is_equal(*rhs_val); + }}, + lhs_val, + rhs_val); +} + namespace detail { host_udf_aggregation::host_udf_aggregation(std::unique_ptr udf_ptr_)