Skip to content

Commit

Permalink
Make do_hash optional and rewrite expression
Browse files Browse the repository at this point in the history
Signed-off-by: Nghia Truong <[email protected]>
  • Loading branch information
ttnghia committed Dec 18, 2024
1 parent 1a68728 commit e51fc98
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 16 deletions.
17 changes: 11 additions & 6 deletions cpp/include/cudf/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -613,10 +613,12 @@ std::unique_ptr<Base> make_udf_aggregation(udf_type type,
* @brief The interface for host-based UDF implementation.
*
* An implementation of host-based UDF needs to be derived from this base class, defining
* its own version of the required functions. In particular, the derived class must define the
* following function: `get_empty_output`, `operator()`, `do_hash`, `is_equal` and `clone`.
* The function `get_required_data` can also be optionally overridden to facilitate selective
* access to the input data as well as intermediate data provided by libcudf.
* its own version of the required functions. In particular:
* - The derived class is required to implement `get_empty_output`, `operator()`, `is_equal`,
* and `clone` functions.
* - If necessary, the derived class can also override `do_hash` to compute hashing for its
* instance, and `get_required_data` to selectively access the input data as well as
* intermediate data provided by libcudf.
*/
struct host_udf_base {
host_udf_base() = default;
Expand Down Expand Up @@ -821,10 +823,13 @@ struct host_udf_base {
rmm::device_async_resource_ref mr) const = 0;

/**
* @brief Computes hash value of the derived class's instance.
* @brief Computes hash value of the class's instance.
* @return The hash value of the instance
*/
[[nodiscard]] virtual std::size_t do_hash() const = 0;
[[nodiscard]] virtual std::size_t do_hash() const
{
// Default implementation: the hash is derived solely from the HOST_UDF aggregation kind,
// so it does not depend on any state of the instance. Derived classes can override this.
auto const kind_value = static_cast<int>(aggregation::Kind::HOST_UDF);
std::hash<int> const hasher{};
return hasher(kind_value);
}

/**
* @brief Compares two instances of the derived class for equality.
Expand Down
21 changes: 11 additions & 10 deletions cpp/src/groupby/sort/aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -800,16 +800,17 @@ void aggregate_result_functor::operator()<aggregation::HOST_UDF>(aggregation con
{
if (cache.has_result(values, agg)) { return; }

auto const& udf_ptr = dynamic_cast<cudf::detail::host_udf_aggregation const&>(agg).udf_ptr;
auto data_attrs = udf_ptr->get_required_data();
if (data_attrs.empty()) { // empty means everything
data_attrs = {host_udf_base::groupby_data_attribute::INPUT_VALUES,
host_udf_base::groupby_data_attribute::GROUPED_VALUES,
host_udf_base::groupby_data_attribute::SORTED_GROUPED_VALUES,
host_udf_base::groupby_data_attribute::NUM_GROUPS,
host_udf_base::groupby_data_attribute::GROUP_OFFSETS,
host_udf_base::groupby_data_attribute::GROUP_LABELS};
}
auto const& udf_ptr = dynamic_cast<cudf::detail::host_udf_aggregation const&>(agg).udf_ptr;
auto const data_attrs = [&]() -> host_udf_base::data_attribute_set_t {
if (auto tmp = udf_ptr->get_required_data(); !tmp.empty()) { return tmp; }
// Empty attribute set means everything.
return {host_udf_base::groupby_data_attribute::INPUT_VALUES,
host_udf_base::groupby_data_attribute::GROUPED_VALUES,
host_udf_base::groupby_data_attribute::SORTED_GROUPED_VALUES,
host_udf_base::groupby_data_attribute::NUM_GROUPS,
host_udf_base::groupby_data_attribute::GROUP_OFFSETS,
host_udf_base::groupby_data_attribute::GROUP_LABELS};
}();

// Do not cache udf_input, as the actual input data may change from run to run.
host_udf_base::input_map_t udf_input;
Expand Down

0 comments on commit e51fc98

Please sign in to comment.