Skip to content

Commit

Permalink
Extract more code to source file
Browse files Browse the repository at this point in the history
Signed-off-by: Nghia Truong <[email protected]>
  • Loading branch information
ttnghia committed Dec 19, 2024
1 parent 4b854b0 commit 757d3eb
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 40 deletions.
45 changes: 5 additions & 40 deletions cpp/include/cudf/aggregation/host_udf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#pragma once

#include <cudf/aggregation.hpp>
#include <cudf/detail/utilities/visitor_overload.hpp>
#include <cudf/types.hpp>
#include <cudf/utilities/export.hpp>
#include <cudf/utilities/span.hpp>
Expand Down Expand Up @@ -156,28 +155,21 @@ struct host_udf_base {
std::is_same_v<T, groupby_aggregation>)>
data_attribute(std::unique_ptr<T> value_) : value{std::move(value_)}
{
CUDF_EXPECTS(std::get<std::unique_ptr<aggregation>>(value) != nullptr,
"Invalid aggregation request.");
if constexpr (std::is_same_v<T, aggregation>) {
CUDF_EXPECTS(
dynamic_cast<groupby_aggregation*>(std::get<std::unique_ptr<T>>(value).get()) != nullptr,
"Requesting results from other aggregations is only supported in groupby "
"aggregations.");
}
CUDF_EXPECTS(std::get<std::unique_ptr<aggregation>>(value) != nullptr,
"Invalid aggregation request.");
}

/**
* @brief Copy constructor.
* @param other The other data attribute to copy from
*/
data_attribute(data_attribute const& other)
: value{std::visit(
cudf::detail::visitor_overload{
[](auto const& val) { return value_type{val}; },
[](std::unique_ptr<aggregation> const& val) { return value_type{val->clone()}; }},
other.value)}
{
}
data_attribute(data_attribute const& other);

/**
* @brief Hash functor for `data_attribute`.
Expand All @@ -188,15 +180,7 @@ struct host_udf_base {
* @param attr The data attribute to hash
* @return The hash value of the data attribute
*/
std::size_t operator()(data_attribute const& attr) const
{
auto const hash_value =
std::visit(cudf::detail::visitor_overload{
[](auto const& val) { return std::hash<int>{}(static_cast<int>(val)); },
[](std::unique_ptr<aggregation> const& val) { return val->do_hash(); }},
attr.value);
return std::hash<std::size_t>{}(attr.value.index()) ^ hash_value;
}
std::size_t operator()(data_attribute const& attr) const;
}; // struct hash

/**
Expand All @@ -209,26 +193,7 @@ struct host_udf_base {
* @param rhs The right-hand side data attribute
* @return True if the two data attributes are equal
*/
bool operator()(data_attribute const& lhs, data_attribute const& rhs) const
{
auto const& lhs_val = lhs.value;
auto const& rhs_val = rhs.value;
if (lhs_val.index() != rhs_val.index()) { return false; }
return std::visit(cudf::detail::visitor_overload{
[](auto const& lhs_val, auto const& rhs_val) {
if constexpr (std::is_same_v<decltype(lhs_val), decltype(rhs_val)>) {
return lhs_val == rhs_val;
} else {
return false;
}
},
[](std::unique_ptr<aggregation> const& lhs_val,
std::unique_ptr<aggregation> const& rhs_val) {
return lhs_val->is_equal(*rhs_val);
}},
lhs_val,
rhs_val);
}
bool operator()(data_attribute const& lhs, data_attribute const& rhs) const;
}; // struct equal_to
}; // struct data_attribute

Expand Down
43 changes: 43 additions & 0 deletions cpp/src/groupby/sort/host_udf_aggregation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,51 @@

#include <cudf/aggregation/host_udf.hpp>
#include <cudf/detail/aggregation/aggregation.hpp>
#include <cudf/detail/utilities/visitor_overload.hpp>

namespace cudf {

host_udf_base::data_attribute::data_attribute(data_attribute const& other)
: value{std::visit(cudf::detail::visitor_overload{[](auto const& val) { return value_type{val}; },
[](std::unique_ptr<aggregation> const& val) {
return value_type{val->clone()};
}},
other.value)}
{
}

std::size_t host_udf_base::data_attribute::hash::operator()(data_attribute const& attr) const
{
auto const hash_value =
std::visit(cudf::detail::visitor_overload{
[](auto const& val) { return std::hash<int>{}(static_cast<int>(val)); },
[](std::unique_ptr<aggregation> const& val) { return val->do_hash(); }},
attr.value);
return std::hash<std::size_t>{}(attr.value.index()) ^ hash_value;
}

bool host_udf_base::data_attribute::equal_to::operator()(data_attribute const& lhs,
data_attribute const& rhs) const
{
auto const& lhs_val = lhs.value;
auto const& rhs_val = rhs.value;
if (lhs_val.index() != rhs_val.index()) { return false; }
return std::visit(
cudf::detail::visitor_overload{
[](auto const& lhs_val, auto const& rhs_val) {
if constexpr (std::is_same_v<decltype(lhs_val), decltype(rhs_val)>) {
return lhs_val == rhs_val;
} else {
return false;
}
},
[](std::unique_ptr<aggregation> const& lhs_val, std::unique_ptr<aggregation> const& rhs_val) {
return lhs_val->is_equal(*rhs_val);
}},
lhs_val,
rhs_val);
}

namespace detail {

host_udf_aggregation::host_udf_aggregation(std::unique_ptr<host_udf_base> udf_ptr_)
Expand Down

0 comments on commit 757d3eb

Please sign in to comment.