Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement HOST_UDF aggregation for groupby #17592

Merged
merged 24 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
254 changes: 251 additions & 3 deletions cpp/include/cudf/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,20 @@

#pragma once

#include <cudf/detail/utilities/visitor_overload.hpp>
#include <cudf/types.hpp>
#include <cudf/utilities/export.hpp>
#include <cudf/utilities/span.hpp>
#include <cudf/utilities/traits.hpp>

#include <rmm/cuda_stream_view.hpp>

#include <functional>
#include <memory>
#include <optional>
#include <unordered_map>
#include <unordered_set>
#include <variant>
#include <vector>

/**
Expand Down Expand Up @@ -110,8 +119,9 @@ class aggregation {
COLLECT_SET, ///< collect values into a list without duplicate entries
LEAD, ///< window function, accesses row at specified offset following current row
LAG, ///< window function, accesses row at specified offset preceding current row
PTX, ///< PTX UDF based reduction
CUDA, ///< CUDA UDF based reduction
PTX, ///< PTX based UDF aggregation
CUDA, ///< CUDA based UDF aggregation
HOST_UDF, ///< host based UDF aggregation
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
MERGE_LISTS, ///< merge multiple lists values into one list
MERGE_SETS, ///< merge multiple lists values into one list then drop duplicate entries
MERGE_M2, ///< merge partial values of M2 aggregation,
Expand All @@ -120,7 +130,7 @@ class aggregation {
TDIGEST, ///< create a tdigest from a set of input values
MERGE_TDIGEST, ///< create a tdigest by merging multiple tdigests together
HISTOGRAM, ///< compute frequency of each element
MERGE_HISTOGRAM ///< merge partial values of HISTOGRAM aggregation,
MERGE_HISTOGRAM ///< merge partial values of HISTOGRAM aggregation
};

aggregation() = delete;
Expand Down Expand Up @@ -599,6 +609,244 @@ std::unique_ptr<Base> make_udf_aggregation(udf_type type,
std::string const& user_defined_aggregator,
data_type output_type);

/**
* @brief The interface for host-based UDF implementation.
*
* An implementation of host-based UDF needs to be derived from this base class, defining
* its own version of the required functions.
*/
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
struct host_udf_base {
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
host_udf_base() = default;
virtual ~host_udf_base() = default;

/**
* @brief Define the possible data needed for groupby aggregations.
*
* Note that only sort-based groupby aggregations are supported.
*/
enum class groupby_data_attribute : int32_t {
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
INPUT_VALUES, ///< The input values column.
GROUPED_VALUES, ///< The input values grouped according to the input `keys` for which the
///< values within each group maintain their original order.
SORTED_GROUPED_VALUES, ///< The input values grouped according to the input `keys` and
///< sorted within each group.
GROUP_OFFSETS, ///< The offsets separating groups.
GROUP_LABELS ///< Group labels (which is also the same as group indices).
};

/**
* @brief Describe possible data that may be needed in the derived class for its operations.
*
* Such data can be either intermediate data such as sorted values or group labels etc, or the
* results of other aggregations.
*
* Each derived host-based UDF class may need a different set of data. It is inefficient to
* evaluate and pass down all these possible data at once from libcudf. A solution for that is,
* the derived class can define a subset of data that it needs and libcudf will evaluate
* and pass down only data requested from that set.
*/
struct data_attribute {
/**
* @brief Hold all possible data types for the input of the aggregation in the derived class.
*/
using value_type = std::variant<groupby_data_attribute, std::unique_ptr<aggregation>>;
value_type value; ///< The actual data attribute, wrapped by this struct
///< as a wrapper is needed to define `hash` and `equal_to` functors.

data_attribute() = default; ///< Default constructor
data_attribute(data_attribute&&) = default; ///< Move constructor

/**
* @brief Construct a new data attribute from an aggregation attribute.
* @param value_ An aggregation attribute
*/
template <typename T, CUDF_ENABLE_IF(std::is_same_v<T, groupby_data_attribute>)>
data_attribute(T value_) : value{value_}
{
}

/**
* @brief Construct a new data attribute from another aggregation request.
* @param value_ An aggregation request
*/
template <typename T,
CUDF_ENABLE_IF(std::is_same_v<T, aggregation> ||
std::is_same_v<T, groupby_aggregation>)>
data_attribute(std::unique_ptr<T> value_) : value{std::move(value_)}
{
if constexpr (std::is_same_v<T, aggregation>) {
CUDF_EXPECTS(
dynamic_cast<groupby_aggregation*>(std::get<std::unique_ptr<T>>(value).get()) != nullptr,
"Requesting results from other aggregations is only supported in groupby "
"aggregations.");
}
CUDF_EXPECTS(std::get<std::unique_ptr<aggregation>>(value) != nullptr,
"Invalid aggregation request.");
}

/**
* @brief Copy constructor.
* @param other The other data attribute to copy from
*/
data_attribute(data_attribute const& other)
: value{std::visit(
cudf::detail::visitor_overload{
[](auto const& val) { return value_type{val}; },
[](std::unique_ptr<aggregation> const& val) { return value_type{val->clone()}; }},
other.value)}
{
}

/**
* @brief Hash functor for `data_attribute`.
*/
struct hash {
/**
* @brief Compute the hash value of a data attribute.
* @param attr The data attribute to hash
* @return The hash value of the data attribute
*/
std::size_t operator()(data_attribute const& attr) const
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
{
auto const& value = attr.value;
auto const hash_value =
std::visit(cudf::detail::visitor_overload{
[](auto const& val) { return std::hash<int>{}(static_cast<int>(val)); },
[](std::unique_ptr<aggregation> const& val) { return val->do_hash(); }},
value);
return std::hash<std::size_t>{}(value.index()) ^ hash_value;
}
}; // struct hash

/**
* @brief Equality comparison functor for `data_attribute`.
*/
struct equal_to {
/**
* @brief Check if two data attributes are equal.
* @param lhs The left-hand side data attribute
* @param rhs The right-hand side data attribute
* @return True if the two data attributes are equal
*/
bool operator()(data_attribute const& lhs, data_attribute const& rhs) const
{
auto const& lhs_val = lhs.value;
auto const& rhs_val = rhs.value;
if (lhs_val.index() != rhs_val.index()) { return false; }
return std::visit(cudf::detail::visitor_overload{
[](auto const& lhs_val, auto const& rhs_val) {
if constexpr (std::is_same_v<decltype(lhs_val), decltype(rhs_val)>) {
return lhs_val == rhs_val;
} else {
return false;
}
},
[](std::unique_ptr<aggregation> const& lhs_val,
std::unique_ptr<aggregation> const& rhs_val) {
return lhs_val->is_equal(*rhs_val);
}},
lhs_val,
rhs_val);
}
}; // struct equal_to
}; // struct data_attribute

/**
* @brief Set of attributes for the input data that is needed for computing the aggregation.
*/
using data_attributes_set_t =
std::unordered_set<data_attribute, data_attribute::hash, data_attribute::equal_to>;

/**
* @brief Return a set of attributes for the data that is needed for computing the aggregation.
*
* If this function is not overridden, an empty set is returned. That means all the data
* attributes (except results from other aggregations in groupby) will be needed.
*
* @return A set of `data_attribute`
*/
[[nodiscard]] virtual data_attributes_set_t get_required_data() const { return {}; }

/**
* @brief Hold all possible types of the data that is passed to the derived class for executing
* the aggregation.
*/
using input_data_t = std::variant<column_view, device_span<size_type const>>;

/**
* @brief Input to the aggregation, mapping from each data attribute to its actual data.
*/
using host_udf_input = std::
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
unordered_map<data_attribute, input_data_t, data_attribute::hash, data_attribute::equal_to>;

/**
* @brief Output type of the aggregation.
*
* Currently only a single type is supported as the output of the aggregation, but it will hold
* more type in the future when reduction is supported.
*/
using output_t = std::variant<std::unique_ptr<column>>;

/**
* @brief Get the output when the input values column is empty.
*
* This is called in libcudf when the input values column is empty. In such situations libcudf
* tries to generate the output directly without unnecessarily evaluating the intermediate data.
*
* @param output_dtype The expected output data type
* @param stream The CUDA stream to use for any kernel launches
* @param mr Device memory resource to use for any allocations
* @return The output result of the aggregation when input values is empty
*/
[[nodiscard]] virtual output_t get_empty_output(std::optional<data_type> output_dtype,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr) const = 0;

/**
* @brief Perform the main computation for the host-based UDF.
*
* @param input The input data needed for performing all computation
* @param stream The CUDA stream to use for any kernel launches
* @param mr Device memory resource to use for any allocations
* @return The output result of the aggregation
*/
[[nodiscard]] virtual output_t operator()(host_udf_input const& input,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr) const = 0;

/**
* @brief Computes hash value of the derived class's instance.
* @return The hash value of the instance
*/
[[nodiscard]] virtual std::size_t do_hash() const = 0;
ttnghia marked this conversation as resolved.
Show resolved Hide resolved

/**
* @brief Compares two instances of the derived class for equality.
* @param other The other derived class's instance to compare with
* @return True if the two instances are equal
*/
[[nodiscard]] virtual bool is_equal(host_udf_base const& other) const = 0;

/**
* @brief Clones the instance.
*
* A class derived from `host_udf_base` should not store too much data such that its instances
* remain lightweight for efficient cloning.
*
* @return A new instance cloned from this
*/
[[nodiscard]] virtual std::unique_ptr<host_udf_base> clone() const = 0;
};

/**
* @brief Factory to create a HOST_UDF aggregation
*
* @param host_udf An instance of a class derived from `host_udf_base` to perform aggregation
* @return A HOST_UDF aggregation object
*/
template <typename Base = aggregation>
std::unique_ptr<Base> make_host_udf_aggregation(std::unique_ptr<host_udf_base> host_udf);

/**
* @brief Factory to create a MERGE_LISTS aggregation.
*
Expand Down
53 changes: 53 additions & 0 deletions cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <cudf/detail/utilities/assert.cuh>
#include <cudf/types.hpp>
#include <cudf/utilities/error.hpp>
#include <cudf/utilities/span.hpp>
#include <cudf/utilities/traits.hpp>

#include <functional>
Expand Down Expand Up @@ -88,6 +89,8 @@ class simple_aggregations_collector { // Declares the interface for the simple
class lead_lag_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class udf_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class host_udf_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class merge_lists_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
Expand Down Expand Up @@ -135,6 +138,7 @@ class aggregation_finalizer { // Declares the interface for the finalizer
virtual void visit(class collect_set_aggregation const& agg);
virtual void visit(class lead_lag_aggregation const& agg);
virtual void visit(class udf_aggregation const& agg);
virtual void visit(class host_udf_aggregation const& agg);
virtual void visit(class merge_lists_aggregation const& agg);
virtual void visit(class merge_sets_aggregation const& agg);
virtual void visit(class merge_m2_aggregation const& agg);
Expand Down Expand Up @@ -960,6 +964,47 @@ class udf_aggregation final : public rolling_aggregation {
}
};

/**
* @brief Derived class for specifying a custom aggregation specified in host-based UDF.
*/
class host_udf_aggregation final : public groupby_aggregation {
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
public:
std::unique_ptr<host_udf_base> udf_ptr;

host_udf_aggregation() = delete;
host_udf_aggregation(host_udf_aggregation const&) = delete;

explicit host_udf_aggregation(std::unique_ptr<host_udf_base> udf_ptr_)
: aggregation{HOST_UDF}, udf_ptr{std::move(udf_ptr_)}
{
CUDF_EXPECTS(udf_ptr != nullptr, "Invalid host-based UDF instance.");
}

[[nodiscard]] bool is_equal(aggregation const& _other) const override
{
if (!this->aggregation::is_equal(_other)) { return false; }
auto const& other = dynamic_cast<host_udf_aggregation const&>(_other);
return udf_ptr->is_equal(*other.udf_ptr);
}

[[nodiscard]] size_t do_hash() const override
{
return this->aggregation::do_hash() ^ udf_ptr->do_hash();
}

[[nodiscard]] std::unique_ptr<aggregation> clone() const override
{
return std::make_unique<host_udf_aggregation>(udf_ptr->clone());
}

std::vector<std::unique_ptr<aggregation>> get_simple_aggregations(
data_type col_type, simple_aggregations_collector& collector) const override
{
return collector.visit(col_type, *this);
}
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
};

/**
* @brief Derived aggregation class for specifying MERGE_LISTS aggregation
*/
Expand Down Expand Up @@ -1462,6 +1507,12 @@ struct target_type_impl<Source,
using type = struct_view;
};

template <typename SourceType>
struct target_type_impl<SourceType, aggregation::HOST_UDF> {
// Just a placeholder. The actual return type is unknown.
using type = struct_view;
};

/**
* @brief Helper alias to get the accumulator type for performing aggregation
* `k` on elements of type `Source`
Expand Down Expand Up @@ -1579,6 +1630,8 @@ CUDF_HOST_DEVICE inline decltype(auto) aggregation_dispatcher(aggregation::Kind
return f.template operator()<aggregation::MERGE_TDIGEST>(std::forward<Ts>(args)...);
case aggregation::EWMA:
return f.template operator()<aggregation::EWMA>(std::forward<Ts>(args)...);
case aggregation::HOST_UDF:
return f.template operator()<aggregation::HOST_UDF>(std::forward<Ts>(args)...);
default: {
#ifndef __CUDA_ARCH__
CUDF_FAIL("Unsupported aggregation.");
Expand Down
Loading
Loading