diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 9cbacee8e8d..8c6cd922747 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -446,7 +446,6 @@ add_library( src/groupby/sort/group_quantiles.cu src/groupby/sort/group_std.cu src/groupby/sort/group_sum.cu - src/groupby/sort/scan.cpp src/groupby/sort/group_count_scan.cu src/groupby/sort/group_max_scan.cu src/groupby/sort/group_min_scan.cu @@ -454,6 +453,8 @@ add_library( src/groupby/sort/group_rank_scan.cu src/groupby/sort/group_replace_nulls.cu src/groupby/sort/group_sum_scan.cu + src/groupby/sort/host_udf_aggregation.cpp + src/groupby/sort/scan.cpp src/groupby/sort/sort_helper.cu src/hash/md5_hash.cu src/hash/murmurhash3_x86_32.cu diff --git a/cpp/include/cudf/aggregation.hpp b/cpp/include/cudf/aggregation.hpp index f5f514d26d9..a1b7db5e08a 100644 --- a/cpp/include/cudf/aggregation.hpp +++ b/cpp/include/cudf/aggregation.hpp @@ -110,8 +110,9 @@ class aggregation { COLLECT_SET, ///< collect values into a list without duplicate entries LEAD, ///< window function, accesses row at specified offset following current row LAG, ///< window function, accesses row at specified offset preceding current row - PTX, ///< PTX UDF based reduction - CUDA, ///< CUDA UDF based reduction + PTX, ///< PTX based UDF aggregation + CUDA, ///< CUDA based UDF aggregation + HOST_UDF, ///< host based UDF aggregation MERGE_LISTS, ///< merge multiple lists values into one list MERGE_SETS, ///< merge multiple lists values into one list then drop duplicate entries MERGE_M2, ///< merge partial values of M2 aggregation, @@ -120,7 +121,7 @@ class aggregation { TDIGEST, ///< create a tdigest from a set of input values MERGE_TDIGEST, ///< create a tdigest by merging multiple tdigests together HISTOGRAM, ///< compute frequency of each element - MERGE_HISTOGRAM ///< merge partial values of HISTOGRAM aggregation, + MERGE_HISTOGRAM ///< merge partial values of HISTOGRAM aggregation }; aggregation() = delete; @@ -599,6 +600,18 @@ std::unique_ptr make_udf_aggregation(udf_type type, std::string const& user_defined_aggregator, data_type output_type); +// Forward declaration of `host_udf_base` for the factory function of `HOST_UDF` aggregation. +struct host_udf_base; + +/** + * @brief Factory to create a HOST_UDF aggregation. + * + * @param host_udf An instance of a class derived from `host_udf_base` to perform aggregation + * @return A HOST_UDF aggregation object + */ +template +std::unique_ptr make_host_udf_aggregation(std::unique_ptr host_udf); + /** * @brief Factory to create a MERGE_LISTS aggregation. * diff --git a/cpp/include/cudf/aggregation/host_udf.hpp b/cpp/include/cudf/aggregation/host_udf.hpp new file mode 100644 index 00000000000..bbce76dc5f3 --- /dev/null +++ b/cpp/include/cudf/aggregation/host_udf.hpp @@ -0,0 +1,294 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +/** + * @file host_udf.hpp + * @brief Declare the base class for host-side user-defined function (`HOST_UDF`) and example of + * subclass implementation. + */ + +namespace CUDF_EXPORT cudf { +/** + * @addtogroup aggregation_factories + * @{ + */ + +/** + * @brief The interface for host-based UDF implementation. + * + * An implementation of host-based UDF needs to be derived from this base class, defining + * its own version of the required functions. In particular: + * - The derived class is required to implement `get_empty_output`, `operator()`, `is_equal`, + * and `clone` functions. + * - If necessary, the derived class can also override `do_hash` to compute hashing for its + * instance, and `get_required_data` to selectively access to the input data as well as + * intermediate data provided by libcudf. + * + * Example of such implementation: + * @code{.cpp} + * struct my_udf_aggregation : cudf::host_udf_base { + * my_udf_aggregation() = default; + * + * // This UDF aggregation needs `GROUPED_VALUES` and `GROUP_OFFSETS`, + * // and the result from groupby `MAX` aggregation. + * [[nodiscard]] data_attribute_set_t get_required_data() const override + * { + * return {groupby_data_attribute::GROUPED_VALUES, + * groupby_data_attribute::GROUP_OFFSETS, + * cudf::make_max_aggregation()}; + * } + * + * [[nodiscard]] output_t get_empty_output( + * [[maybe_unused]] std::optional output_dtype, + * [[maybe_unused]] rmm::cuda_stream_view stream, + * [[maybe_unused]] rmm::device_async_resource_ref mr) const override + * { + * // This UDF aggregation always returns a column of type INT32. + * return cudf::make_empty_column(cudf::data_type{cudf::type_id::INT32}); + * } + * + * [[nodiscard]] output_t operator()(input_map_t const& input, + * rmm::cuda_stream_view stream, + * rmm::device_async_resource_ref mr) const override + * { + * // Perform UDF computation using the input data and return the result. + * } + * + * [[nodiscard]] bool is_equal(host_udf_base const& other) const override + * { + * // Check if the other object is also instance of this class. + * return dynamic_cast(&other) != nullptr; + * } + * + * [[nodiscard]] std::unique_ptr clone() const override + * { + * return std::make_unique(); + * } + * }; + * @endcode + */ +struct host_udf_base { + host_udf_base() = default; + virtual ~host_udf_base() = default; + + /** + * @brief Define the possible data needed for groupby aggregations. + * + * Note that only sort-based groupby aggregations are supported. + */ + enum class groupby_data_attribute : int32_t { + INPUT_VALUES, ///< The input values column. + GROUPED_VALUES, ///< The input values grouped according to the input `keys` for which the + ///< values within each group maintain their original order. + SORTED_GROUPED_VALUES, ///< The input values grouped according to the input `keys` and + ///< sorted within each group. + NUM_GROUPS, ///< The number of groups (i.e., number of distinct keys). + GROUP_OFFSETS, ///< The offsets separating groups. + GROUP_LABELS ///< Group labels (which is also the same as group indices). + }; + + /** + * @brief Describe possible data that may be needed in the derived class for its operations. + * + * Such data can be either intermediate data such as sorted values or group labels etc, or the + * results of other aggregations. + * + * Each derived host-based UDF class may need a different set of data. It is inefficient to + * evaluate and pass down all these possible data at once from libcudf. A solution for that is, + * the derived class can define a subset of data that it needs and libcudf will evaluate + * and pass down only data requested from that set. + */ + struct data_attribute { + /** + * @brief Hold all possible data types for the input of the aggregation in the derived class. + */ + using value_type = std::variant>; + value_type value; ///< The actual data attribute, wrapped by this struct + ///< as a wrapper is needed to define `hash` and `equal_to` functors. + + data_attribute() = default; ///< Default constructor + data_attribute(data_attribute&&) = default; ///< Move constructor + + /** + * @brief Construct a new data attribute from an aggregation attribute. + * @param value_ An aggregation attribute + */ + template )> + data_attribute(T value_) : value{value_} + { + } + + /** + * @brief Construct a new data attribute from another aggregation request. + * @param value_ An aggregation request + */ + template || + std::is_same_v)> + data_attribute(std::unique_ptr value_) : value{std::move(value_)} + { + CUDF_EXPECTS(std::get>(value) != nullptr, + "Invalid aggregation request."); + if constexpr (std::is_same_v) { + CUDF_EXPECTS( + dynamic_cast(std::get>(value).get()) != nullptr, + "Requesting results from other aggregations is only supported in groupby " + "aggregations."); + } + } + + /** + * @brief Copy constructor. + * @param other The other data attribute to copy from + */ + data_attribute(data_attribute const& other); + + /** + * @brief Hash functor for `data_attribute`. + */ + struct hash { + /** + * @brief Compute the hash value of a data attribute. + * @param attr The data attribute to hash + * @return The hash value of the data attribute + */ + std::size_t operator()(data_attribute const& attr) const; + }; // struct hash + + /** + * @brief Equality comparison functor for `data_attribute`. + */ + struct equal_to { + /** + * @brief Check if two data attributes are equal. + * @param lhs The left-hand side data attribute + * @param rhs The right-hand side data attribute + * @return True if the two data attributes are equal + */ + bool operator()(data_attribute const& lhs, data_attribute const& rhs) const; + }; // struct equal_to + }; // struct data_attribute + + /** + * @brief Set of attributes for the input data that is needed for computing the aggregation. + */ + using data_attribute_set_t = + std::unordered_set; + + /** + * @brief Return a set of attributes for the data that is needed for computing the aggregation. + * + * The derived class should return the attributes corresponding to only the data that it needs to + * avoid unnecessary computation performed in libcudf. If this function is not overridden, an + * empty set is returned. That means all the data attributes (except results from other + * aggregations in groupby) will be needed. + * + * @return A set of `data_attribute` + */ + [[nodiscard]] virtual data_attribute_set_t get_required_data() const { return {}; } + + /** + * @brief Hold all possible types of the data that is passed to the derived class for executing + * the aggregation. + */ + using input_data_t = std::variant>; + + /** + * @brief Input to the aggregation, mapping from each data attribute to its actual data. + */ + using input_map_t = std:: + unordered_map; + + /** + * @brief Output type of the aggregation. + * + * Currently only a single type is supported as the output of the aggregation, but it will hold + * more type in the future when reduction is supported. + */ + using output_t = std::variant>; + + /** + * @brief Get the output when the input values column is empty. + * + * This is called in libcudf when the input values column is empty. In such situations libcudf + * tries to generate the output directly without unnecessarily evaluating the intermediate data. + * + * @param output_dtype The expected output data type + * @param stream The CUDA stream to use for any kernel launches + * @param mr Device memory resource to use for any allocations + * @return The output result of the aggregation when input values is empty + */ + [[nodiscard]] virtual output_t get_empty_output(std::optional output_dtype, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) const = 0; + + /** + * @brief Perform the main computation for the host-based UDF. + * + * @param input The input data needed for performing all computation + * @param stream The CUDA stream to use for any kernel launches + * @param mr Device memory resource to use for any allocations + * @return The output result of the aggregation + */ + [[nodiscard]] virtual output_t operator()(input_map_t const& input, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) const = 0; + + /** + * @brief Computes hash value of the class's instance. + * @return The hash value of the instance + */ + [[nodiscard]] virtual std::size_t do_hash() const + { + return std::hash{}(static_cast(aggregation::Kind::HOST_UDF)); + } + + /** + * @brief Compares two instances of the derived class for equality. + * @param other The other derived class's instance to compare with + * @return True if the two instances are equal + */ + [[nodiscard]] virtual bool is_equal(host_udf_base const& other) const = 0; + + /** + * @brief Clones the instance. + * + * A class derived from `host_udf_base` should not store too much data such that its instances + * remain lightweight for efficient cloning. + * + * @return A new instance cloned from this + */ + [[nodiscard]] virtual std::unique_ptr clone() const = 0; +}; + +/** @} */ // end of group +} // namespace CUDF_EXPORT cudf diff --git a/cpp/include/cudf/detail/aggregation/aggregation.hpp b/cpp/include/cudf/detail/aggregation/aggregation.hpp index 6661a461b8b..d873e93bd20 100644 --- a/cpp/include/cudf/detail/aggregation/aggregation.hpp +++ b/cpp/include/cudf/detail/aggregation/aggregation.hpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -88,6 +89,8 @@ class simple_aggregations_collector { // Declares the interface for the simple class lead_lag_aggregation const& agg); virtual std::vector> visit(data_type col_type, class udf_aggregation const& agg); + virtual std::vector> visit(data_type col_type, + class host_udf_aggregation const& agg); virtual std::vector> visit(data_type col_type, class merge_lists_aggregation const& agg); virtual std::vector> visit(data_type col_type, @@ -135,6 +138,7 @@ class aggregation_finalizer { // Declares the interface for the finalizer virtual void visit(class collect_set_aggregation const& agg); virtual void visit(class lead_lag_aggregation const& agg); virtual void visit(class udf_aggregation const& agg); + virtual void visit(class host_udf_aggregation const& agg); virtual void visit(class merge_lists_aggregation const& agg); virtual void visit(class merge_sets_aggregation const& agg); virtual void visit(class merge_m2_aggregation const& agg); @@ -960,6 +964,35 @@ class udf_aggregation final : public rolling_aggregation { } }; +/** + * @brief Derived class for specifying host-based UDF aggregation. + */ +class host_udf_aggregation final : public groupby_aggregation { + public: + std::unique_ptr udf_ptr; + + host_udf_aggregation() = delete; + host_udf_aggregation(host_udf_aggregation const&) = delete; + + // Need to define the constructor and destructor in a separate source file where we have the + // complete declaration of `host_udf_base`. + explicit host_udf_aggregation(std::unique_ptr udf_ptr_); + ~host_udf_aggregation() override; + + [[nodiscard]] bool is_equal(aggregation const& _other) const override; + + [[nodiscard]] size_t do_hash() const override; + + [[nodiscard]] std::unique_ptr clone() const override; + + std::vector> get_simple_aggregations( + data_type col_type, simple_aggregations_collector& collector) const override + { + return collector.visit(col_type, *this); + } + void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); } +}; + /** * @brief Derived aggregation class for specifying MERGE_LISTS aggregation */ @@ -1462,6 +1495,12 @@ struct target_type_impl +struct target_type_impl { + // Just a placeholder. The actual return type is unknown. + using type = struct_view; +}; + /** * @brief Helper alias to get the accumulator type for performing aggregation * `k` on elements of type `Source` @@ -1579,6 +1618,8 @@ CUDF_HOST_DEVICE inline decltype(auto) aggregation_dispatcher(aggregation::Kind return f.template operator()(std::forward(args)...); case aggregation::EWMA: return f.template operator()(std::forward(args)...); + case aggregation::HOST_UDF: + return f.template operator()(std::forward(args)...); default: { #ifndef __CUDA_ARCH__ CUDF_FAIL("Unsupported aggregation."); diff --git a/cpp/src/aggregation/aggregation.cpp b/cpp/src/aggregation/aggregation.cpp index a60a7f63882..0d4400b891b 100644 --- a/cpp/src/aggregation/aggregation.cpp +++ b/cpp/src/aggregation/aggregation.cpp @@ -237,6 +237,12 @@ std::vector> simple_aggregations_collector::visit( return visit(col_type, static_cast(agg)); } +std::vector> simple_aggregations_collector::visit( + data_type col_type, host_udf_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + // aggregation_finalizer ---------------------------------------- void aggregation_finalizer::visit(aggregation const& agg) {} @@ -410,6 +416,11 @@ void aggregation_finalizer::visit(merge_tdigest_aggregation const& agg) visit(static_cast(agg)); } +void aggregation_finalizer::visit(host_udf_aggregation const& agg) +{ + visit(static_cast(agg)); +} + } // namespace detail std::vector> aggregation::get_simple_aggregations( diff --git a/cpp/src/groupby/groupby.cu b/cpp/src/groupby/groupby.cu index c42038026e5..4c90cd0eef5 100644 --- a/cpp/src/groupby/groupby.cu +++ b/cpp/src/groupby/groupby.cu @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include #include @@ -32,7 +33,6 @@ #include #include #include -#include #include #include #include @@ -99,6 +99,8 @@ namespace { struct empty_column_constructor { column_view values; aggregation const& agg; + rmm::cuda_stream_view stream; + rmm::device_async_resource_ref mr; template std::unique_ptr operator()() const @@ -108,7 +110,7 @@ struct empty_column_constructor { if constexpr (k == aggregation::Kind::COLLECT_LIST || k == aggregation::Kind::COLLECT_SET) { return make_lists_column( - 0, make_empty_column(type_to_id()), empty_like(values), 0, {}); + 0, make_empty_column(type_to_id()), empty_like(values), 0, {}, stream, mr); } if constexpr (k == aggregation::Kind::HISTOGRAM) { @@ -116,7 +118,9 @@ struct empty_column_constructor { make_empty_column(type_to_id()), cudf::reduction::detail::make_empty_histogram_like(values), 0, - {}); + {}, + stream, + mr); } if constexpr (k == aggregation::Kind::MERGE_HISTOGRAM) { return empty_like(values); } @@ -140,31 +144,41 @@ struct empty_column_constructor { return empty_like(values); } + if constexpr (k == aggregation::Kind::HOST_UDF) { + auto const& udf_ptr = dynamic_cast(agg).udf_ptr; + return std::get>(udf_ptr->get_empty_output(std::nullopt, stream, mr)); + } + return make_empty_column(target_type(values.type(), k)); } }; /// Make an empty table with appropriate types for requested aggs template -auto empty_results(host_span requests) +auto empty_results(host_span requests, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) { std::vector empty_results; - std::transform( - requests.begin(), requests.end(), std::back_inserter(empty_results), [](auto const& request) { - std::vector> results; - - std::transform( - request.aggregations.begin(), - request.aggregations.end(), - std::back_inserter(results), - [&request](auto const& agg) { - return cudf::detail::dispatch_type_and_aggregation( - request.values.type(), agg->kind, empty_column_constructor{request.values, *agg}); - }); - - return aggregation_result{std::move(results)}; - }); + std::transform(requests.begin(), + requests.end(), + std::back_inserter(empty_results), + [stream, mr](auto const& request) { + std::vector> results; + + std::transform(request.aggregations.begin(), + request.aggregations.end(), + std::back_inserter(results), + [&request, stream, mr](auto const& agg) { + return cudf::detail::dispatch_type_and_aggregation( + request.values.type(), + agg->kind, + empty_column_constructor{request.values, *agg, stream, mr}); + }); + + return aggregation_result{std::move(results)}; + }); return empty_results; } @@ -206,7 +220,7 @@ std::pair, std::vector> groupby::aggr verify_valid_requests(requests); - if (_keys.num_rows() == 0) { return {empty_like(_keys), empty_results(requests)}; } + if (_keys.num_rows() == 0) { return {empty_like(_keys), empty_results(requests, stream, mr)}; } return dispatch_aggregation(requests, stream, mr); } @@ -226,7 +240,9 @@ std::pair, std::vector> groupby::scan verify_valid_requests(requests); - if (_keys.num_rows() == 0) { return std::pair(empty_like(_keys), empty_results(requests)); } + if (_keys.num_rows() == 0) { + return std::pair(empty_like(_keys), empty_results(requests, stream, mr)); + } return sort_scan(requests, stream, mr); } diff --git a/cpp/src/groupby/sort/aggregate.cpp b/cpp/src/groupby/sort/aggregate.cpp index 7a8a1883ed4..e9f885a5917 100644 --- a/cpp/src/groupby/sort/aggregate.cpp +++ b/cpp/src/groupby/sort/aggregate.cpp @@ -19,6 +19,7 @@ #include "groupby/sort/group_reductions.hpp" #include +#include #include #include #include @@ -795,6 +796,65 @@ void aggregate_result_functor::operator()(aggregatio mr)); } +template <> +void aggregate_result_functor::operator()(aggregation const& agg) +{ + if (cache.has_result(values, agg)) { return; } + + auto const& udf_ptr = dynamic_cast(agg).udf_ptr; + auto const data_attrs = [&]() -> host_udf_base::data_attribute_set_t { + if (auto tmp = udf_ptr->get_required_data(); !tmp.empty()) { return tmp; } + // Empty attribute set means everything. + return {host_udf_base::groupby_data_attribute::INPUT_VALUES, + host_udf_base::groupby_data_attribute::GROUPED_VALUES, + host_udf_base::groupby_data_attribute::SORTED_GROUPED_VALUES, + host_udf_base::groupby_data_attribute::NUM_GROUPS, + host_udf_base::groupby_data_attribute::GROUP_OFFSETS, + host_udf_base::groupby_data_attribute::GROUP_LABELS}; + }(); + + // Do not cache udf_input, as the actual input data may change from run to run. + host_udf_base::input_map_t udf_input; + for (auto const& attr : data_attrs) { + CUDF_EXPECTS(std::holds_alternative(attr.value) || + std::holds_alternative>(attr.value), + "Invalid input data attribute for HOST_UDF groupby aggregation."); + if (std::holds_alternative(attr.value)) { + switch (std::get(attr.value)) { + case host_udf_base::groupby_data_attribute::INPUT_VALUES: + udf_input.emplace(attr, values); + break; + case host_udf_base::groupby_data_attribute::GROUPED_VALUES: + udf_input.emplace(attr, get_grouped_values()); + break; + case host_udf_base::groupby_data_attribute::SORTED_GROUPED_VALUES: + udf_input.emplace(attr, get_sorted_values()); + break; + case host_udf_base::groupby_data_attribute::NUM_GROUPS: + udf_input.emplace(attr, helper.num_groups(stream)); + break; + case host_udf_base::groupby_data_attribute::GROUP_OFFSETS: + udf_input.emplace(attr, helper.group_offsets(stream)); + break; + case host_udf_base::groupby_data_attribute::GROUP_LABELS: + udf_input.emplace(attr, helper.group_labels(stream)); + break; + default: CUDF_UNREACHABLE("Invalid input data attribute for HOST_UDF groupby aggregation."); + } + } else { // data is result from another aggregation + auto other_agg = std::get>(attr.value)->clone(); + cudf::detail::aggregation_dispatcher(other_agg->kind, *this, *other_agg); + auto result = cache.get_result(values, *other_agg); + udf_input.emplace(std::move(other_agg), std::move(result)); + } + } + + auto output = (*udf_ptr)(udf_input, stream, mr); + CUDF_EXPECTS(std::holds_alternative>(output), + "Invalid output type from HOST_UDF groupby aggregation."); + cache.add_result(values, agg, std::get>(std::move(output))); +} + } // namespace detail // Sort-based groupby diff --git a/cpp/src/groupby/sort/host_udf_aggregation.cpp b/cpp/src/groupby/sort/host_udf_aggregation.cpp new file mode 100644 index 00000000000..0da47e17f48 --- /dev/null +++ b/cpp/src/groupby/sort/host_udf_aggregation.cpp @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace cudf { + +host_udf_base::data_attribute::data_attribute(data_attribute const& other) + : value{std::visit(cudf::detail::visitor_overload{[](auto const& val) { return value_type{val}; }, + [](std::unique_ptr const& val) { + return value_type{val->clone()}; + }}, + other.value)} +{ +} + +std::size_t host_udf_base::data_attribute::hash::operator()(data_attribute const& attr) const +{ + auto const hash_value = + std::visit(cudf::detail::visitor_overload{ + [](auto const& val) { return std::hash{}(static_cast(val)); }, + [](std::unique_ptr const& val) { return val->do_hash(); }}, + attr.value); + return std::hash{}(attr.value.index()) ^ hash_value; +} + +bool host_udf_base::data_attribute::equal_to::operator()(data_attribute const& lhs, + data_attribute const& rhs) const +{ + auto const& lhs_val = lhs.value; + auto const& rhs_val = rhs.value; + if (lhs_val.index() != rhs_val.index()) { return false; } + return std::visit( + cudf::detail::visitor_overload{ + [](auto const& lhs_val, auto const& rhs_val) { + if constexpr (std::is_same_v) { + return lhs_val == rhs_val; + } else { + return false; + } + }, + [](std::unique_ptr const& lhs_val, std::unique_ptr const& rhs_val) { + return lhs_val->is_equal(*rhs_val); + }}, + lhs_val, + rhs_val); +} + +namespace detail { + +host_udf_aggregation::host_udf_aggregation(std::unique_ptr udf_ptr_) + : aggregation{HOST_UDF}, udf_ptr{std::move(udf_ptr_)} +{ + CUDF_EXPECTS(udf_ptr != nullptr, "Invalid host_udf_base instance."); +} + +host_udf_aggregation::~host_udf_aggregation() = default; + +bool host_udf_aggregation::is_equal(aggregation const& _other) const +{ + if (!this->aggregation::is_equal(_other)) { return false; } + auto const& other = dynamic_cast(_other); + return udf_ptr->is_equal(*other.udf_ptr); +} + +size_t host_udf_aggregation::do_hash() const +{ + return this->aggregation::do_hash() ^ udf_ptr->do_hash(); +} + +std::unique_ptr host_udf_aggregation::clone() const +{ + return std::make_unique(udf_ptr->clone()); +} + +} // namespace detail + +template +std::unique_ptr make_host_udf_aggregation(std::unique_ptr udf_ptr_) +{ + return std::make_unique(std::move(udf_ptr_)); +} +template CUDF_EXPORT std::unique_ptr make_host_udf_aggregation( + std::unique_ptr); +template CUDF_EXPORT std::unique_ptr + make_host_udf_aggregation(std::unique_ptr); + +} // namespace cudf diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index adf512811cc..e5c29314203 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -132,6 +132,8 @@ ConfigureTest( groupby/groupby_test_util.cpp groupby/groups_tests.cpp groupby/histogram_tests.cpp + groupby/host_udf_example_tests.cu + groupby/host_udf_tests.cpp groupby/keys_tests.cpp groupby/lists_tests.cpp groupby/m2_tests.cpp diff --git a/cpp/tests/groupby/host_udf_example_tests.cu b/cpp/tests/groupby/host_udf_example_tests.cu new file mode 100644 index 00000000000..a454bd692fc --- /dev/null +++ b/cpp/tests/groupby/host_udf_example_tests.cu @@ -0,0 +1,245 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +namespace { +/** + * @brief A host-based UDF implementation for groupby. + * + * For each group of values, the aggregation computes + * `(group_idx + 1) * group_sum_of_squares - group_max * group_sum`. + */ +struct host_udf_groupby_example : cudf::host_udf_base { + host_udf_groupby_example() = default; + + [[nodiscard]] data_attribute_set_t get_required_data() const override + { + // We need grouped values, group offsets, group labels, and also results from groups' + // MAX and SUM aggregations. + return {groupby_data_attribute::GROUPED_VALUES, + groupby_data_attribute::GROUP_OFFSETS, + groupby_data_attribute::GROUP_LABELS, + cudf::make_max_aggregation(), + cudf::make_sum_aggregation()}; + } + + [[nodiscard]] output_t get_empty_output( + [[maybe_unused]] std::optional output_dtype, + [[maybe_unused]] rmm::cuda_stream_view stream, + [[maybe_unused]] rmm::device_async_resource_ref mr) const override + { + return cudf::make_empty_column( + cudf::data_type{cudf::type_to_id()}); + } + + [[nodiscard]] output_t operator()(input_map_t const& input, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) const override + { + auto const& values = + std::get(input.at(groupby_data_attribute::GROUPED_VALUES)); + return cudf::type_dispatcher(values.type(), groupby_fn{this}, input, stream, mr); + } + + [[nodiscard]] std::size_t do_hash() const override + { + // Just return the same hash for all instances of this class. + return std::size_t{12345}; + } + + [[nodiscard]] bool is_equal(host_udf_base const& other) const override + { + // Just check if the other object is also instance of this class. + return dynamic_cast(&other) != nullptr; + } + + [[nodiscard]] std::unique_ptr clone() const override + { + return std::make_unique(); + } + + struct groupby_fn { + // Store pointer to the parent class so we can call its functions. + host_udf_groupby_example const* parent; + + // For simplicity, this example only accepts double input and always produces double output. + using InputType = double; + using OutputType = double; + + template )> + output_t operator()(Args...) const + { + CUDF_FAIL("Unsupported input type."); + } + + template )> + output_t operator()(input_map_t const& input, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) const + { + auto const& values = + std::get(input.at(groupby_data_attribute::GROUPED_VALUES)); + if (values.size() == 0) { return parent->get_empty_output(std::nullopt, stream, mr); } + + auto const offsets = std::get>( + input.at(groupby_data_attribute::GROUP_OFFSETS)); + CUDF_EXPECTS(offsets.size() > 0, "Invalid offsets."); + auto const num_groups = static_cast(offsets.size()) - 1; + auto const group_indices = std::get>( + input.at(groupby_data_attribute::GROUP_LABELS)); + auto const group_max = std::get( + input.at(cudf::make_max_aggregation())); + auto const group_sum = std::get( + input.at(cudf::make_sum_aggregation())); + + auto const values_dv_ptr = cudf::column_device_view::create(values, stream); + auto const output = cudf::make_numeric_column(cudf::data_type{cudf::type_to_id()}, + num_groups, + cudf::mask_state::UNALLOCATED, + stream, + mr); + + // Store row index if it is valid, otherwise store a negative value denoting a null row. + rmm::device_uvector valid_idx(num_groups, stream); + + thrust::transform( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_groups), + thrust::make_zip_iterator(output->mutable_view().begin(), valid_idx.begin()), + transform_fn{*values_dv_ptr, + offsets, + group_indices, + group_max.begin(), + group_sum.begin()}); + + auto const valid_idx_cv = cudf::column_view{ + cudf::data_type{cudf::type_id::INT32}, num_groups, valid_idx.begin(), nullptr, 0}; + return std::move(cudf::gather(cudf::table_view{{output->view()}}, + valid_idx_cv, + cudf::out_of_bounds_policy::NULLIFY, + stream, + mr) + ->release() + .front()); + } + + struct transform_fn { + cudf::column_device_view values; + cudf::device_span offsets; + cudf::device_span group_indices; + InputType const* group_max; + InputType const* group_sum; + + thrust::tuple __device__ operator()(cudf::size_type idx) const + { + auto const start = offsets[idx]; + auto const end = offsets[idx + 1]; + + auto constexpr invalid_idx = cuda::std::numeric_limits::lowest(); + if (start == end) { return {OutputType{0}, invalid_idx}; } + + auto sum_sqr = OutputType{0}; + bool has_valid{false}; + for (auto i = start; i < end; ++i) { + if (values.is_null(i)) { continue; } + has_valid = true; + auto const val = static_cast(values.element(i)); + sum_sqr += val * val; + } + + if (!has_valid) { return {OutputType{0}, invalid_idx}; } + return {static_cast(group_indices[start] + 1) * sum_sqr - + static_cast(group_max[idx]) * static_cast(group_sum[idx]), + idx}; + } + }; + }; +}; + +} // namespace + +using doubles_col = cudf::test::fixed_width_column_wrapper; +using int32s_col = cudf::test::fixed_width_column_wrapper; + +struct HostUDFGroupbyExampleTest : cudf::test::BaseFixture {}; + +TEST_F(HostUDFGroupbyExampleTest, SimpleInput) +{ + double constexpr null = 0.0; + auto const keys = int32s_col{0, 1, 2, 0, 1, 2, 0, 1, 2, 0}; + auto const vals = doubles_col{{0.0, null, 2.0, 3.0, null, 5.0, null, null, 8.0, 9.0}, + {true, false, true, true, false, true, false, false, true, true}}; + auto agg = cudf::make_host_udf_aggregation( + std::make_unique()); + + std::vector requests; + requests.emplace_back(); + requests[0].values = vals; + requests[0].aggregations.push_back(std::move(agg)); + cudf::groupby::groupby gb_obj( + cudf::table_view({keys}), cudf::null_policy::INCLUDE, cudf::sorted::NO, {}, {}); + + auto const grp_result = gb_obj.aggregate(requests, cudf::test::get_default_stream()); + auto const& result = grp_result.second[0].results[0]; + + // Output type of groupby is double. + // Values grouped by keys: [ {0, 3, null, 9}, {null, null, null}, {2, 5, 8} ] + // Group sum_sqr: [ 90, null, 93 ] + // Group max: [ 9, null, 8 ] + // Group sum: [ 12, null, 15 ] + // Output: [ 1 * 90 - 9 * 12, null, 3 * 93 - 8 * 15 ] + auto const expected = doubles_col{{-18.0, null, 159.0}, {true, false, true}}; + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); +} + +TEST_F(HostUDFGroupbyExampleTest, EmptyInput) +{ + auto const keys = int32s_col{}; + auto const vals = doubles_col{}; + auto agg = cudf::make_host_udf_aggregation( + std::make_unique()); + + std::vector requests; + requests.emplace_back(); + requests[0].values = vals; + requests[0].aggregations.push_back(std::move(agg)); + cudf::groupby::groupby gb_obj( + cudf::table_view({keys}), cudf::null_policy::INCLUDE, cudf::sorted::NO, {}, {}); + + auto const grp_result = gb_obj.aggregate(requests, cudf::test::get_default_stream()); + auto const& result = grp_result.second[0].results[0]; + CUDF_TEST_EXPECT_COLUMNS_EQUAL(vals, *result); +} diff --git a/cpp/tests/groupby/host_udf_tests.cpp b/cpp/tests/groupby/host_udf_tests.cpp new file mode 100644 index 00000000000..1a0f68c0c6c --- /dev/null +++ b/cpp/tests/groupby/host_udf_tests.cpp @@ -0,0 +1,241 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include +#include +#include + +#include +#include + +namespace { +/** + * @brief A host-based UDF implementation used for unit tests. + */ +struct host_udf_test_base : cudf::host_udf_base { + int test_location_line; // the location where testing is called + bool* test_run; // to check if the test is accidentally skipped + data_attribute_set_t input_attrs; + + host_udf_test_base(int test_location_line_, bool* test_run_, data_attribute_set_t input_attrs_) + : test_location_line{test_location_line_}, + test_run{test_run_}, + input_attrs(std::move(input_attrs_)) + { + } + + [[nodiscard]] data_attribute_set_t get_required_data() const override { return input_attrs; } + + // This is the main testing function, which checks for the correctness of input data. + // The rests are just to satisfy the interface. + [[nodiscard]] output_t operator()(input_map_t const& input, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) const override + { + SCOPED_TRACE("Test instance created at line: " + std::to_string(test_location_line)); + + test_data_attributes(input, stream, mr); + + *test_run = true; // test is run successfully + return get_empty_output(std::nullopt, stream, mr); + } + + [[nodiscard]] output_t get_empty_output( + [[maybe_unused]] std::optional output_dtype, + [[maybe_unused]] rmm::cuda_stream_view stream, + [[maybe_unused]] rmm::device_async_resource_ref mr) const override + { + // Unused function - dummy output. + return cudf::make_empty_column(cudf::data_type{cudf::type_id::INT32}); + } + + [[nodiscard]] std::size_t do_hash() const override { return 0; } + [[nodiscard]] bool is_equal(host_udf_base const& other) const override { return true; } + + // The main test function, which must be implemented for each kind of aggregations + // (groupby/reduction/segmented_reduction). + virtual void test_data_attributes(input_map_t const& input, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) const = 0; +}; + +/** + * @brief A host-based UDF implementation used for unit tests for groupby aggregation. + */ +struct host_udf_groupby_test : host_udf_test_base { + host_udf_groupby_test(int test_location_line_, + bool* test_run_, + data_attribute_set_t input_attrs_ = {}) + : host_udf_test_base(test_location_line_, test_run_, std::move(input_attrs_)) + { + } + + [[nodiscard]] std::unique_ptr clone() const override + { + return std::make_unique(test_location_line, test_run, input_attrs); + } + + void test_data_attributes(input_map_t const& input, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) const override + { + data_attribute_set_t check_attrs = input_attrs; + if (check_attrs.empty()) { + check_attrs = data_attribute_set_t{groupby_data_attribute::INPUT_VALUES, + groupby_data_attribute::GROUPED_VALUES, + groupby_data_attribute::SORTED_GROUPED_VALUES, + groupby_data_attribute::NUM_GROUPS, + groupby_data_attribute::GROUP_OFFSETS, + groupby_data_attribute::GROUP_LABELS}; + } + EXPECT_EQ(input.size(), check_attrs.size()); + for (auto const& attr : check_attrs) { + EXPECT_TRUE(input.count(attr) > 0); + EXPECT_TRUE(std::holds_alternative(attr.value) || + std::holds_alternative>(attr.value)); + if (std::holds_alternative(attr.value)) { + switch (std::get(attr.value)) { + case groupby_data_attribute::INPUT_VALUES: + EXPECT_TRUE(std::holds_alternative(input.at(attr))); + break; + case groupby_data_attribute::GROUPED_VALUES: + EXPECT_TRUE(std::holds_alternative(input.at(attr))); + break; + case groupby_data_attribute::SORTED_GROUPED_VALUES: + EXPECT_TRUE(std::holds_alternative(input.at(attr))); + break; + case groupby_data_attribute::NUM_GROUPS: + EXPECT_TRUE(std::holds_alternative(input.at(attr))); + break; + case groupby_data_attribute::GROUP_OFFSETS: + EXPECT_TRUE( + std::holds_alternative>(input.at(attr))); + break; + case groupby_data_attribute::GROUP_LABELS: + EXPECT_TRUE( + std::holds_alternative>(input.at(attr))); + break; + default:; + } + } else { // std::holds_alternative>(attr.value) + EXPECT_TRUE(std::holds_alternative(input.at(attr))); + } + } + } +}; + +/** + * @brief Get a random subset of input data attributes. + */ +cudf::host_udf_base::data_attribute_set_t get_subset( + cudf::host_udf_base::data_attribute_set_t const& attrs) +{ + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution size_distr(1, attrs.size() - 1); + auto const subset_size = size_distr(gen); + auto const elements = + std::vector(attrs.begin(), attrs.end()); + std::uniform_int_distribution idx_distr(0, attrs.size() - 1); + cudf::host_udf_base::data_attribute_set_t output; + while (output.size() < subset_size) { + output.insert(elements[idx_distr(gen)]); + } + return output; +} + +/** + * @brief Generate a random aggregation object from {min, max, sum, product}. + */ +std::unique_ptr get_random_agg() +{ + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution distr(1, 4); + switch (distr(gen)) { + case 1: return cudf::make_min_aggregation(); + case 2: return cudf::make_max_aggregation(); + case 3: return cudf::make_sum_aggregation(); + case 4: return cudf::make_product_aggregation(); + default: CUDF_UNREACHABLE("This should not be reached."); + } + return nullptr; +} + +} // namespace + +using int32s_col = cudf::test::fixed_width_column_wrapper; + +// Number of randomly testing on the input data attributes. +// For each test, a subset of data attributes will be randomly generated from all the possible input +// data attributes. The input data corresponding to that subset passed from libcudf will be tested +// for correctness. +constexpr int NUM_RANDOM_TESTS = 20; + +struct HostUDFTest : cudf::test::BaseFixture {}; + +TEST_F(HostUDFTest, GroupbyAllInput) +{ + bool test_run = false; + auto const keys = int32s_col{0, 1, 2}; + auto const vals = int32s_col{0, 1, 2}; + auto agg = cudf::make_host_udf_aggregation( + std::make_unique(__LINE__, &test_run)); + + std::vector requests; + requests.emplace_back(); + requests[0].values = vals; + requests[0].aggregations.push_back(std::move(agg)); + cudf::groupby::groupby gb_obj( + cudf::table_view({keys}), cudf::null_policy::INCLUDE, cudf::sorted::NO, {}, {}); + [[maybe_unused]] auto const grp_result = + gb_obj.aggregate(requests, cudf::test::get_default_stream()); + EXPECT_TRUE(test_run); +} + +TEST_F(HostUDFTest, GroupbySomeInput) +{ + auto const keys = int32s_col{0, 1, 2}; + auto const vals = int32s_col{0, 1, 2}; + auto const all_attrs = cudf::host_udf_base::data_attribute_set_t{ + cudf::host_udf_base::groupby_data_attribute::INPUT_VALUES, + cudf::host_udf_base::groupby_data_attribute::GROUPED_VALUES, + cudf::host_udf_base::groupby_data_attribute::SORTED_GROUPED_VALUES, + cudf::host_udf_base::groupby_data_attribute::NUM_GROUPS, + cudf::host_udf_base::groupby_data_attribute::GROUP_OFFSETS, + cudf::host_udf_base::groupby_data_attribute::GROUP_LABELS}; + for (int i = 0; i < NUM_RANDOM_TESTS; ++i) { + bool test_run = false; + auto input_attrs = get_subset(all_attrs); + input_attrs.insert(get_random_agg()); + auto agg = cudf::make_host_udf_aggregation( + std::make_unique(__LINE__, &test_run, std::move(input_attrs))); + + std::vector requests; + requests.emplace_back(); + requests[0].values = vals; + requests[0].aggregations.push_back(std::move(agg)); + cudf::groupby::groupby gb_obj( + cudf::table_view({keys}), cudf::null_policy::INCLUDE, cudf::sorted::NO, {}, {}); + [[maybe_unused]] auto const grp_result = + gb_obj.aggregate(requests, cudf::test::get_default_stream()); + EXPECT_TRUE(test_run); + } +} diff --git a/java/src/main/java/ai/rapids/cudf/Aggregation.java b/java/src/main/java/ai/rapids/cudf/Aggregation.java index 379750bb0b7..2276b223740 100644 --- a/java/src/main/java/ai/rapids/cudf/Aggregation.java +++ b/java/src/main/java/ai/rapids/cudf/Aggregation.java @@ -62,15 +62,16 @@ enum Kind { LAG(23), PTX(24), CUDA(25), - M2(26), - MERGE_M2(27), - RANK(28), - DENSE_RANK(29), - PERCENT_RANK(30), - TDIGEST(31), // This can take a delta argument for accuracy level - MERGE_TDIGEST(32), // This can take a delta argument for accuracy level - HISTOGRAM(33), - MERGE_HISTOGRAM(34); + HOST_UDF(26), + M2(27), + MERGE_M2(28), + RANK(29), + DENSE_RANK(30), + PERCENT_RANK(31), + TDIGEST(32), // This can take a delta argument for accuracy level + MERGE_TDIGEST(33), // This can take a delta argument for accuracy level + HISTOGRAM(34), + MERGE_HISTOGRAM(35); final int nativeId; @@ -385,6 +386,35 @@ public boolean equals(Object other) { } } + static final class HostUDFAggregation extends Aggregation { + private final HostUDFWrapper wrapper; + + private HostUDFAggregation(HostUDFWrapper wrapper) { + super(Kind.HOST_UDF); + this.wrapper = wrapper; + } + + @Override + long createNativeInstance() { + return Aggregation.createHostUDFAgg(wrapper.udfNativeHandle); + } + + @Override + public int hashCode() { + return 31 * kind.hashCode() + wrapper.hashCode(); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } else if (other instanceof HostUDFAggregation) { + return wrapper.equals(((HostUDFAggregation) other).wrapper); + } + return false; + } + } + protected final Kind kind; protected Aggregation(Kind kind) { @@ -837,6 +867,15 @@ static MergeSetsAggregation mergeSets(NullEquality nullEquality, NaNEquality nan return new MergeSetsAggregation(nullEquality, nanEquality); } + /** + * Host UDF aggregation, to execute a host-side user-defined function (UDF). + * @param wrapper The wrapper for the native host UDF instance. + * @return A new HostUDFAggregation instance + */ + static HostUDFAggregation hostUDF(HostUDFWrapper wrapper) { + return new HostUDFAggregation(wrapper); + } + static final class LeadAggregation extends LeadLagAggregation { private LeadAggregation(int offset, ColumnVector defaultOutput) { super(Kind.LEAD, offset, defaultOutput); @@ -990,4 +1029,9 @@ static MergeHistogramAggregation mergeHistogram() { * Create a TDigest aggregation. */ private static native long createTDigestAgg(int kind, int delta); + + /** + * Create a HOST_UDF aggregation. + */ + private static native long createHostUDFAgg(long udfNativeHandle); } diff --git a/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java b/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java index 0fae33927b6..27966ddfdd4 100644 --- a/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java +++ b/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java @@ -277,6 +277,15 @@ public static GroupByAggregation mergeSets() { return new GroupByAggregation(Aggregation.mergeSets()); } + /** + * Execute an aggregation using a host-side user-defined function (UDF). + * @param wrapper The wrapper for the native host UDF instance. + * @return A new GroupByAggregation instance + */ + public static GroupByAggregation hostUDF(HostUDFWrapper wrapper) { + return new GroupByAggregation(Aggregation.hostUDF(wrapper)); + } + /** * Merge the partial sets produced by multiple CollectSetAggregations. * diff --git a/java/src/main/java/ai/rapids/cudf/HostUDFWrapper.java b/java/src/main/java/ai/rapids/cudf/HostUDFWrapper.java new file mode 100644 index 00000000000..0b6ecf2e140 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/HostUDFWrapper.java @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.rapids.cudf; + +/** + * A wrapper around native host UDF aggregations. + *

+ * This class is used to store the native handle of a host UDF aggregation and is used as + * a proxy object to compute hash code and compare two host UDF aggregations for equality. + *

+ * A new host UDF aggregation implementation must extend this class and override the + * {@code hashCode} and {@code equals} methods for such purposes. + */ +public abstract class HostUDFWrapper { + public final long udfNativeHandle; + + public HostUDFWrapper(long udfNativeHandle) { + this.udfNativeHandle = udfNativeHandle; + } +} diff --git a/java/src/main/native/src/AggregationJni.cpp b/java/src/main/native/src/AggregationJni.cpp index c40f1c55500..dd41c677761 100644 --- a/java/src/main/native/src/AggregationJni.cpp +++ b/java/src/main/native/src/AggregationJni.cpp @@ -17,6 +17,7 @@ #include "cudf_jni_apis.hpp" #include +#include extern "C" { @@ -80,25 +81,28 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createNoParamAgg(JNIEnv* // case 23: LAG // case 24: PTX // case 25: CUDA - case 26: // M2 + // case 26: HOST_UDF + case 27: // M2 return cudf::make_m2_aggregation(); - case 27: // MERGE_M2 + case 28: // MERGE_M2 return cudf::make_merge_m2_aggregation(); - case 28: // RANK + case 29: // RANK return cudf::make_rank_aggregation( cudf::rank_method::MIN, {}, cudf::null_policy::INCLUDE); - case 29: // DENSE_RANK + case 30: // DENSE_RANK return cudf::make_rank_aggregation( cudf::rank_method::DENSE, {}, cudf::null_policy::INCLUDE); - case 30: // ANSI SQL PERCENT_RANK + case 31: // ANSI SQL PERCENT_RANK return cudf::make_rank_aggregation(cudf::rank_method::MIN, {}, cudf::null_policy::INCLUDE, {}, cudf::rank_percentage::ONE_NORMALIZED); - case 33: // HISTOGRAM + // case 32: TDIGEST + // case 33: MERGE_TDIGEST + case 34: // HISTOGRAM return cudf::make_histogram_aggregation(); - case 34: // MERGE_HISTOGRAM + case 35: // MERGE_HISTOGRAM return cudf::make_merge_histogram_aggregation(); default: throw std::logic_error("Unsupported No Parameter Aggregation Operation"); @@ -160,10 +164,10 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createTDigestAgg(JNIEnv* std::unique_ptr ret; // These numbers come from Aggregation.java and must stay in sync switch (kind) { - case 31: // TDIGEST + case 32: // TDIGEST ret = cudf::make_tdigest_aggregation(delta); break; - case 32: // MERGE_TDIGEST + case 33: // MERGE_TDIGEST ret = cudf::make_merge_tdigest_aggregation(delta); break; default: throw std::logic_error("Unsupported TDigest Aggregation Operation"); @@ -296,4 +300,18 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createMergeSetsAgg(JNIEn CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createHostUDFAgg(JNIEnv* env, + jclass class_object, + jlong udf_native_handle) +{ + JNI_NULL_CHECK(env, udf_native_handle, "udf_native_handle is null", 0); + try { + cudf::jni::auto_set_device(env); + auto const udf_ptr = reinterpret_cast(udf_native_handle); + auto output = cudf::make_host_udf_aggregation(udf_ptr->clone()); + return reinterpret_cast(output.release()); + } + CATCH_STD(env, 0); +} + } // extern "C"