Implement HOST_UDF aggregation for reduction and groupby #17249

Draft: wants to merge 68 commits into base branch-25.02
Showing changes from 2 commits
Commits (68):
bba150c
Implement host udf aggregation
ttnghia Nov 5, 2024
04e2bda
Add test
ttnghia Nov 5, 2024
5f7ab2b
Change example to compute aggregation on each group
ttnghia Nov 5, 2024
7c9316a
Merge branch 'branch-25.02' into host_udf
ttnghia Nov 19, 2024
57674e1
Add `host_udf_base` class
ttnghia Nov 19, 2024
47c7a7c
Rename variable
ttnghia Nov 19, 2024
5e6017a
Rewrite docs
ttnghia Nov 21, 2024
174678f
Implement `host_udf_aggregation`
ttnghia Nov 21, 2024
cee28f6
Change the `host_udf_base` interface
ttnghia Nov 21, 2024
0da4988
Remove `target_type_impl` for `HOST_UDF`
ttnghia Nov 21, 2024
c9c9ee6
Rewrite comments
ttnghia Nov 21, 2024
227016b
Construct empty output when the input is empty
ttnghia Nov 21, 2024
15732cf
Implement `HOST_UDF` for reduction
ttnghia Nov 21, 2024
ee28be8
Implement `HOST_UDF` for segmented reduction
ttnghia Nov 21, 2024
8333964
Merge branch 'branch-25.02' into host_udf
ttnghia Nov 21, 2024
754ee58
Implementing tests
ttnghia Nov 22, 2024
5a7ea45
Fix error
ttnghia Nov 22, 2024
e0999bb
Change `host_udf_base` interface
ttnghia Nov 22, 2024
52e0acd
Implement `test_udf_simple_type`
ttnghia Nov 22, 2024
a1b568b
Implement a simple test
ttnghia Nov 22, 2024
7ec2dd9
Fix compile issues
ttnghia Nov 22, 2024
237bb72
Fix test
ttnghia Nov 22, 2024
b5b8f5b
Remove `init` value from `get_empty_output`
ttnghia Nov 22, 2024
bfec6a2
Fix test
ttnghia Nov 22, 2024
3bc9ae3
Fix compile issues
ttnghia Nov 22, 2024
1fd4c8a
Merge branch 'branch-25.02' into host_udf
ttnghia Nov 22, 2024
26be262
Enable `segmented_reduce_aggregation`
ttnghia Nov 24, 2024
6b3e3f7
Implement test for `segmented_reduce`
ttnghia Nov 24, 2024
3d505da
Fix empty output
ttnghia Nov 24, 2024
9d1ac9a
Fix empty input handling
ttnghia Nov 24, 2024
3aefaf3
Fix comment
ttnghia Nov 24, 2024
9b61fe3
Rename tests
ttnghia Nov 24, 2024
b87e2a7
Fix groupby type
ttnghia Nov 24, 2024
697993b
Add test `GroupbySimpleInput`
ttnghia Nov 24, 2024
7a81754
Add the ability to call other aggregations
ttnghia Nov 25, 2024
1b8fb92
Add anonymous namespace
ttnghia Nov 25, 2024
91489c1
Refactor
ttnghia Nov 26, 2024
b597192
Revert cmake
ttnghia Nov 26, 2024
26c3ec4
Fix style
ttnghia Nov 26, 2024
9c168e5
Add docs
ttnghia Nov 26, 2024
b10e924
Fix docs
ttnghia Nov 26, 2024
9dd26d3
Still fix docs
ttnghia Nov 26, 2024
d289528
Implement Java & JNI for `HostUDFAggregation`
ttnghia Nov 26, 2024
a5133e6
Fix instantiating code
ttnghia Nov 26, 2024
0043472
Remove unused headers
ttnghia Nov 26, 2024
f190fd8
Fix style
ttnghia Nov 26, 2024
4d559cf
Add unit tests
ttnghia Nov 26, 2024
22df331
Implement random tests
ttnghia Nov 27, 2024
82379ca
Fix compile issue
ttnghia Nov 27, 2024
ecfb879
Rename test file
ttnghia Nov 27, 2024
91a4724
Merge branch 'branch-25.02' into host_udf
ttnghia Nov 27, 2024
ef4392e
Rewrite tests, adding more check
ttnghia Nov 27, 2024
93ac14c
Add more Java classes
ttnghia Nov 27, 2024
26866f0
Merge branch 'branch-25.02' into host_udf
ttnghia Nov 27, 2024
4716373
Rewrite `host_udf_base`
ttnghia Nov 28, 2024
baa7991
Rewrite tests
ttnghia Nov 28, 2024
6d013ff
Merge branch 'branch-25.02' into host_udf
ttnghia Nov 28, 2024
8405167
Rewrite switch statements
ttnghia Nov 28, 2024
bbdc699
Fix out of sync enums
ttnghia Dec 2, 2024
81ce190
Merge branch 'branch-25.02' into host_udf
ttnghia Dec 2, 2024
3bd496d
Merge branch 'branch-25.02' into host_udf
ttnghia Dec 5, 2024
3f4d450
Rewrite example
ttnghia Dec 5, 2024
069600b
Instantiate `HostUDFAggregation` from `HostUDFWrapper`
ttnghia Dec 18, 2024
a63000d
Fix Java
ttnghia Dec 18, 2024
05084a4
Apply new wrapper
ttnghia Dec 18, 2024
28af2de
Move `HostUDFWrapper`
ttnghia Dec 18, 2024
d75d3da
Fix compile error
ttnghia Dec 18, 2024
9e0f996
Merge branch 'branch-25.02' into host_udf
ttnghia Dec 18, 2024
101 changes: 64 additions & 37 deletions cpp/include/cudf/aggregation.hpp
@@ -18,6 +18,9 @@

#include <cudf/types.hpp>
#include <cudf/utilities/export.hpp>
#include <cudf/utilities/span.hpp>

#include <rmm/cuda_stream_view.hpp>

#include <functional>
#include <memory>
@@ -84,43 +87,44 @@ class aggregation {
* @brief Possible aggregation operations
*/
enum Kind {
SUM, ///< sum reduction
PRODUCT, ///< product reduction
MIN, ///< min reduction
MAX, ///< max reduction
COUNT_VALID, ///< count number of valid elements
COUNT_ALL, ///< count number of elements
ANY, ///< any reduction
ALL, ///< all reduction
SUM_OF_SQUARES, ///< sum of squares reduction
MEAN, ///< arithmetic mean reduction
M2, ///< sum of squares of differences from the mean
VARIANCE, ///< variance
STD, ///< standard deviation
MEDIAN, ///< median reduction
QUANTILE, ///< compute specified quantile(s)
ARGMAX, ///< Index of max element
ARGMIN, ///< Index of min element
NUNIQUE, ///< count number of unique elements
NTH_ELEMENT, ///< get the nth element
ROW_NUMBER, ///< get row-number of current index (relative to rolling window)
EWMA, ///< get exponential weighted moving average at current index
RANK, ///< get rank of current index
COLLECT_LIST, ///< collect values into a list
COLLECT_SET, ///< collect values into a list without duplicate entries
LEAD, ///< window function, accesses row at specified offset following current row
LAG, ///< window function, accesses row at specified offset preceding current row
PTX, ///< PTX UDF based reduction
CUDA, ///< CUDA UDF based reduction
MERGE_LISTS, ///< merge multiple lists values into one list
MERGE_SETS, ///< merge multiple lists values into one list then drop duplicate entries
MERGE_M2, ///< merge partial values of M2 aggregation,
COVARIANCE, ///< covariance between two sets of elements
CORRELATION, ///< correlation between two sets of elements
TDIGEST, ///< create a tdigest from a set of input values
MERGE_TDIGEST, ///< create a tdigest by merging multiple tdigests together
HISTOGRAM, ///< compute frequency of each element
MERGE_HISTOGRAM ///< merge partial values of HISTOGRAM aggregation,
SUM, ///< sum reduction
PRODUCT, ///< product reduction
MIN, ///< min reduction
MAX, ///< max reduction
COUNT_VALID, ///< count number of valid elements
COUNT_ALL, ///< count number of elements
ANY, ///< any reduction
ALL, ///< all reduction
SUM_OF_SQUARES, ///< sum of squares reduction
MEAN, ///< arithmetic mean reduction
M2, ///< sum of squares of differences from the mean
VARIANCE, ///< variance
STD, ///< standard deviation
MEDIAN, ///< median reduction
QUANTILE, ///< compute specified quantile(s)
ARGMAX, ///< Index of max element
ARGMIN, ///< Index of min element
NUNIQUE, ///< count number of unique elements
NTH_ELEMENT, ///< get the nth element
ROW_NUMBER, ///< get row-number of current index (relative to rolling window)
EWMA, ///< get exponential weighted moving average at current index
RANK, ///< get rank of current index
COLLECT_LIST, ///< collect values into a list
COLLECT_SET, ///< collect values into a list without duplicate entries
LEAD, ///< window function, accesses row at specified offset following current row
LAG, ///< window function, accesses row at specified offset preceding current row
PTX, ///< PTX UDF based reduction
CUDA, ///< CUDA UDF based reduction
MERGE_LISTS, ///< merge multiple lists values into one list
MERGE_SETS, ///< merge multiple lists values into one list then drop duplicate entries
MERGE_M2, ///< merge partial values of M2 aggregation,
COVARIANCE, ///< covariance between two sets of elements
CORRELATION, ///< correlation between two sets of elements
TDIGEST, ///< create a tdigest from a set of input values
MERGE_TDIGEST, ///< create a tdigest by merging multiple tdigests together
HISTOGRAM, ///< compute frequency of each element
MERGE_HISTOGRAM, ///< merge partial values of HISTOGRAM aggregation
HOST_UDF ///< host side UDF aggregation
};

aggregation() = delete;
@@ -770,5 +774,28 @@ std::unique_ptr<Base> make_tdigest_aggregation(int max_centroids = 1000);
template <typename Base>
std::unique_ptr<Base> make_merge_tdigest_aggregation(int max_centroids = 1000);

// Pass as many parameters as possible to this callable so that the UDF has
// everything it needs to perform its operations.
// Current parameters, in order (modify if needed):
// column_view const& input,
// cudf::device_span<size_type const> group_offsets,
// cudf::device_span<size_type const> group_labels,
// size_type num_groups,
// rmm::cuda_stream_view stream,
// rmm::device_async_resource_ref mr
using host_udf_func_type = std::function<std::unique_ptr<column>(column_view const&,
device_span<size_type const>,
device_span<size_type const>,
size_type,
rmm::cuda_stream_view,
rmm::device_async_resource_ref)>;
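/**
* @brief Factory to create a HOST_UDF aggregation.
*
* @param udf_func_ The host-side callable invoked to compute the aggregation
* @return A HOST_UDF aggregation object
*/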
template <typename Base>
std::unique_ptr<Base> make_host_udf_aggregation(host_udf_func_type udf_func_);

/** @} */ // end of group
} // namespace CUDF_EXPORT cudf
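The `host_udf_func_type` alias in the hunk above is a `std::function` that receives the grouped values together with the group offsets, the group labels, and the number of groups. A minimal host-only sketch of that calling convention follows. It is not cudf code: `std::vector` stands in for `column_view` and `device_span`, the stream and memory resource parameters are dropped, and the names `toy_udf_type` and `group_sums` are hypothetical. It also assumes that the offsets array holds `num_groups + 1` entries.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical stand-in for host_udf_func_type: grouped values, group offsets,
// group labels, and the number of groups go in; one result per group comes out.
using toy_udf_type = std::function<std::vector<double>(std::vector<double> const& values,
                                                       std::vector<int> const& group_offsets,
                                                       std::vector<int> const& group_labels,
                                                       int num_groups)>;

// Example UDF: the sum of each group, computed from the offsets alone.
// Assumption: group g spans [group_offsets[g], group_offsets[g + 1]).
inline std::vector<double> group_sums(std::vector<double> const& values,
                                      std::vector<int> const& group_offsets,
                                      std::vector<int> const& /*group_labels*/,
                                      int num_groups)
{
  assert(static_cast<int>(group_offsets.size()) == num_groups + 1);
  std::vector<double> out(num_groups, 0.0);
  for (int g = 0; g < num_groups; ++g) {
    for (int i = group_offsets[g]; i < group_offsets[g + 1]; ++i) { out[g] += values[i]; }
  }
  return out;
}
```

Any callable with a matching signature (a free function, a lambda, or a functor) can be stored in the alias, which is what lets the aggregation carry arbitrary user logic.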
35 changes: 35 additions & 0 deletions cpp/include/cudf/detail/aggregation/aggregation.hpp
@@ -20,6 +20,7 @@
#include <cudf/detail/utilities/assert.cuh>
#include <cudf/types.hpp>
#include <cudf/utilities/error.hpp>
#include <cudf/utilities/span.hpp>
#include <cudf/utilities/traits.hpp>

#include <functional>
@@ -104,6 +105,8 @@ class simple_aggregations_collector { // Declares the interface for the simple
class tdigest_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(
data_type col_type, class merge_tdigest_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class host_udf_aggregation const& agg);
};

class aggregation_finalizer { // Declares the interface for the finalizer
@@ -144,6 +147,7 @@ class aggregation_finalizer { // Declares the interface for the finalizer
virtual void visit(class tdigest_aggregation const& agg);
virtual void visit(class merge_tdigest_aggregation const& agg);
virtual void visit(class ewma_aggregation const& agg);
virtual void visit(class host_udf_aggregation const& agg);
};

/**
@@ -1186,6 +1190,30 @@ class merge_tdigest_aggregation final : public groupby_aggregation, public reduc
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
};

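/**
* @brief Derived aggregation class for specifying a HOST_UDF aggregation.
*
* Stores the host-side callable that is invoked to compute the result.
*/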
class host_udf_aggregation final : public groupby_aggregation, public reduce_aggregation {
public:
host_udf_func_type host_udf_ptr;

explicit host_udf_aggregation(host_udf_func_type host_udf_ptr_)
: aggregation{HOST_UDF}, host_udf_ptr{std::move(host_udf_ptr_)}
{
}

[[nodiscard]] std::unique_ptr<aggregation> clone() const override
{
return std::make_unique<host_udf_aggregation>(*this);
}
std::vector<std::unique_ptr<aggregation>> get_simple_aggregations(
data_type col_type, simple_aggregations_collector& collector) const override
{
return collector.visit(col_type, *this);
}
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
};

/**
* @brief Sentinel value used for `ARGMAX` aggregation.
*
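The `host_udf_aggregation` class in the hunk above plugs into the visitor interfaces through `finalize()` and `get_simple_aggregations()`, each of which calls `visit(*this)` so that the overload for the concrete type is selected. A reduced sketch of that double dispatch follows. All `toy_*` names and `udf_aware_finalizer` are hypothetical and only illustrate the pattern; they are not cudf types.

```cpp
#include <memory>
#include <string>

struct toy_aggregation;
struct toy_host_udf;  // forward declaration, like `class host_udf_aggregation` in the visitor

// Visitor with a catch-all overload plus one overload for the concrete kind.
struct toy_finalizer {
  virtual ~toy_finalizer() = default;
  virtual void visit(toy_aggregation const&) { last = "generic"; }
  virtual void visit(toy_host_udf const& agg);  // by default forwards to the catch-all
  std::string last;
};

struct toy_aggregation {
  virtual ~toy_aggregation() = default;
  virtual std::unique_ptr<toy_aggregation> clone() const = 0;
  virtual void finalize(toy_finalizer& f) const  = 0;
};

struct toy_host_udf final : toy_aggregation {
  std::unique_ptr<toy_aggregation> clone() const override
  {
    return std::make_unique<toy_host_udf>(*this);
  }
  // `*this` has static type toy_host_udf, so the specific overload is chosen.
  void finalize(toy_finalizer& f) const override { f.visit(*this); }
};

// Default behaviour, as in the diff: forward to the catch-all overload.
inline void toy_finalizer::visit(toy_host_udf const& agg)
{
  visit(static_cast<toy_aggregation const&>(agg));
}

// A finalizer that does handle the host UDF kind overrides the specific overload.
struct udf_aware_finalizer : toy_finalizer {
  using toy_finalizer::visit;
  void visit(toy_host_udf const&) override { last = "host_udf"; }
};
```

The forwarding default means existing finalizers keep working unchanged when a new aggregation kind is added, and only the ones that care about it need an override.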
@@ -1462,6 +1490,11 @@ struct target_type_impl<Source,
using type = struct_view;
};

template <typename SourceType>
struct target_type_impl<SourceType, aggregation::HOST_UDF> {
using type = struct_view;
};

/**
* @brief Helper alias to get the accumulator type for performing aggregation
* `k` on elements of type `Source`
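The `target_type_impl` specialization in the hunk above maps every source type to `struct_view` when the kind is `HOST_UDF`. (The commit list includes a later commit titled "Remove `target_type_impl` for `HOST_UDF`", so this mapping may not survive in the final PR.) The following sketch shows the general shape of such a trait under C++17. `toy_kind`, `toy_struct_view`, and `toy_target_type_impl` are hypothetical stand-ins, and the `SUM` rule is invented purely for contrast.

```cpp
#include <type_traits>

enum class toy_kind { SUM, HOST_UDF };

struct toy_struct_view {};  // hypothetical stand-in for cudf::struct_view

// Primary template: no nested `type`, so unsupported pairs fail to compile.
template <typename Source, toy_kind k, typename Enable = void>
struct toy_target_type_impl {};

// Illustrative rule only: SUM over an integral source accumulates in long long.
template <typename Source>
struct toy_target_type_impl<Source,
                            toy_kind::SUM,
                            std::enable_if_t<std::is_integral_v<Source>>> {
  using type = long long;
};

// HOST_UDF maps every source type to the same placeholder type, as in the diff.
template <typename Source>
struct toy_target_type_impl<Source, toy_kind::HOST_UDF> {
  using type = toy_struct_view;
};

// Helper alias, analogous in spirit to the accumulator-type alias that follows.
template <typename Source, toy_kind k>
using toy_target_type_t = typename toy_target_type_impl<Source, k>::type;
```

Because the output type of a user-defined function cannot be known from the source type alone, a fixed placeholder is the most such a trait can offer.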
@@ -1579,6 +1612,8 @@ CUDF_HOST_DEVICE inline decltype(auto) aggregation_dispatcher(aggregation::Kind
return f.template operator()<aggregation::MERGE_TDIGEST>(std::forward<Ts>(args)...);
case aggregation::EWMA:
return f.template operator()<aggregation::EWMA>(std::forward<Ts>(args)...);
case aggregation::HOST_UDF:
return f.template operator()<aggregation::HOST_UDF>(std::forward<Ts>(args)...);
default: {
#ifndef __CUDA_ARCH__
CUDF_FAIL("Unsupported aggregation.");
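The dispatcher hunk above adds one more `case` that converts the runtime `aggregation::Kind` value into a compile-time template argument. A standalone C++17 sketch of that technique follows. `demo_kind`, `demo_dispatch`, and `name_of` are hypothetical names, and a thrown exception stands in for `CUDF_FAIL`.

```cpp
#include <stdexcept>
#include <string>
#include <utility>

enum class demo_kind { SUM, HOST_UDF };

// Turns a runtime enum value into a compile-time template argument by calling
// the functor's templated call operator, with one `case` per enumerator.
template <typename F, typename... Ts>
decltype(auto) demo_dispatch(demo_kind k, F&& f, Ts&&... args)
{
  switch (k) {
    case demo_kind::SUM:
      return f.template operator()<demo_kind::SUM>(std::forward<Ts>(args)...);
    case demo_kind::HOST_UDF:
      return f.template operator()<demo_kind::HOST_UDF>(std::forward<Ts>(args)...);
    default: throw std::invalid_argument("Unsupported aggregation.");
  }
}

// Example functor: reports which instantiation was selected.
struct name_of {
  template <demo_kind k>
  std::string operator()() const
  {
    if constexpr (k == demo_kind::SUM) {
      return "SUM";
    } else {
      return "HOST_UDF";
    }
  }
};
```

An enumerator without a matching `case` falls through to the default branch at run time, which is why the new kind has to be added to the switch as well as to the enum.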
23 changes: 23 additions & 0 deletions cpp/src/aggregation/aggregation.cpp
@@ -237,6 +237,12 @@ std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
return visit(col_type, static_cast<aggregation const&>(agg));
}

std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
data_type col_type, host_udf_aggregation const& agg)
{
return visit(col_type, static_cast<aggregation const&>(agg));
}

// aggregation_finalizer ----------------------------------------

void aggregation_finalizer::visit(aggregation const& agg) {}
@@ -410,6 +416,11 @@ void aggregation_finalizer::visit(merge_tdigest_aggregation const& agg)
visit(static_cast<aggregation const&>(agg));
}

void aggregation_finalizer::visit(host_udf_aggregation const& agg)
{
visit(static_cast<aggregation const&>(agg));
}

} // namespace detail

std::vector<std::unique_ptr<aggregation>> aggregation::get_simple_aggregations(
@@ -917,6 +928,18 @@ make_merge_tdigest_aggregation<groupby_aggregation>(int max_centroids);
template CUDF_EXPORT std::unique_ptr<reduce_aggregation>
make_merge_tdigest_aggregation<reduce_aggregation>(int max_centroids);

template <typename Base>
std::unique_ptr<Base> make_host_udf_aggregation(host_udf_func_type udf_func_)
{
return std::make_unique<detail::host_udf_aggregation>(std::move(udf_func_));
}
template CUDF_EXPORT std::unique_ptr<aggregation> make_host_udf_aggregation<aggregation>(
host_udf_func_type);
template CUDF_EXPORT std::unique_ptr<groupby_aggregation>
make_host_udf_aggregation<groupby_aggregation>(host_udf_func_type);
template CUDF_EXPORT std::unique_ptr<reduce_aggregation>
make_host_udf_aggregation<reduce_aggregation>(host_udf_func_type);

namespace detail {
namespace {
struct target_type_functor {
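The factory in the hunk above is a function template defined in a `.cpp` file, so it is explicitly instantiated once per base type that callers may request. A single-file sketch of the same arrangement follows. `base_agg`, `groupby_agg`, `reduce_agg`, `udf_agg`, `toy_func`, and `make_udf_agg` are hypothetical, and the virtual inheritance is an assumption made here so that the conversion to the common base is unambiguous.

```cpp
#include <functional>
#include <memory>
#include <utility>

struct base_agg {
  virtual ~base_agg() = default;
};
// Assumption for this sketch: the two interfaces share one virtual base.
struct groupby_agg : virtual base_agg {};
struct reduce_agg : virtual base_agg {};

using toy_func = std::function<int(int)>;

struct udf_agg final : groupby_agg, reduce_agg {
  explicit udf_agg(toy_func f) : func{std::move(f)} {}
  toy_func func;
};

// Generic factory: builds the concrete type and returns it as the requested base.
template <typename Base>
std::unique_ptr<Base> make_udf_agg(toy_func f)
{
  return std::make_unique<udf_agg>(std::move(f));
}

// Explicit instantiations: with the definition hidden in a source file, these
// are the only specializations other translation units can link against.
template std::unique_ptr<base_agg> make_udf_agg<base_agg>(toy_func);
template std::unique_ptr<groupby_agg> make_udf_agg<groupby_agg>(toy_func);
template std::unique_ptr<reduce_agg> make_udf_agg<reduce_agg>(toy_func);
```

Requesting a base type that was not explicitly instantiated would compile against the declaration but fail at link time, which is why each supported base is listed.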
17 changes: 17 additions & 0 deletions cpp/src/groupby/sort/aggregate.cpp
@@ -791,6 +791,23 @@ void aggregate_result_functor::operator()<aggregation::MERGE_TDIGEST>(aggregatio
mr));
}

template <>
void aggregate_result_functor::operator()<aggregation::HOST_UDF>(aggregation const& agg)
{
// TODO: Add a name string to the aggregation so that we can look up different host UDFs.
if (cache.has_result(values, agg)) { return; }
auto const udf_ptr = dynamic_cast<cudf::detail::host_udf_aggregation const&>(agg).host_udf_ptr;
CUDF_EXPECTS(udf_ptr != nullptr, "Invalid host UDF: the function object is empty.");
cache.add_result(values,
agg,
udf_ptr(get_grouped_values(),
helper.group_offsets(stream),
helper.group_labels(stream),
helper.num_groups(stream),
stream,
mr));
}

} // namespace detail

// Sort-based groupby
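The `HOST_UDF` specialization in the hunk above follows a cache-then-compute flow: return early if the result already exists, validate the callable, invoke it with the grouped values and group metadata, and store the output. A host-only sketch of that control flow follows. `toy_cache`, `toy_udf`, `toy_result`, and `compute_host_udf` are hypothetical, the string cache key is a simplification of the (values, aggregation) pair used in the diff, and an exception stands in for `CUDF_EXPECTS`.

```cpp
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

using toy_result = std::vector<double>;
using toy_udf =
  std::function<toy_result(std::vector<double> const& values, std::vector<int> const& offsets)>;

// Hypothetical stand-in for the groupby result cache, keyed by a string.
class toy_cache {
 public:
  bool has_result(std::string const& key) const { return results_.count(key) != 0; }
  void add_result(std::string const& key, toy_result r) { results_[key] = std::move(r); }
  toy_result const& get_result(std::string const& key) const { return results_.at(key); }

 private:
  std::map<std::string, toy_result> results_;
};

// Mirrors the control flow of the specialization: return early on a cache hit,
// reject an empty callable, then store whatever the UDF produces.
inline void compute_host_udf(toy_cache& cache,
                             std::string const& key,
                             toy_udf const& udf,
                             std::vector<double> const& grouped_values,
                             std::vector<int> const& group_offsets)
{
  if (cache.has_result(key)) { return; }
  if (!udf) { throw std::invalid_argument("Host UDF callable must not be empty."); }
  cache.add_result(key, udf(grouped_values, group_offsets));
}
```

The early return matters when several requests share the same values and aggregation, since the UDF then runs only once.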
3 changes: 3 additions & 0 deletions cpp/src/reductions/reductions.cpp
@@ -144,6 +144,9 @@ struct reduce_dispatch_functor {
auto td_agg = static_cast<cudf::detail::merge_tdigest_aggregation const&>(agg);
return tdigest::detail::reduce_merge_tdigest(col, td_agg.max_centroids, stream, mr);
}
case aggregation::HOST_UDF: {
CUDF_FAIL("Host UDF aggregation is not implemented in `reduction`");
}
default: CUDF_FAIL("Unsupported reduction operator");
}
}
39 changes: 1 addition & 38 deletions cpp/tests/CMakeLists.txt
@@ -122,44 +122,7 @@ ConfigureTest(TIMESTAMPS_TEST wrappers/timestamps_test.cu)
# * groupby tests ---------------------------------------------------------------------------------
ConfigureTest(
GROUPBY_TEST
groupby/argmin_tests.cpp
[Review comment, Contributor] We don't intend for this change to be checked in, right?
[Reply, Contributor Author] Oh, this was temporarily modified for local tests, and should be reverted now in the new PR.

groupby/argmax_tests.cpp
groupby/collect_list_tests.cpp
groupby/collect_set_tests.cpp
groupby/correlation_tests.cpp
groupby/count_scan_tests.cpp
groupby/count_tests.cpp
groupby/covariance_tests.cpp
groupby/groupby_test_util.cpp
groupby/groups_tests.cpp
groupby/histogram_tests.cpp
groupby/keys_tests.cpp
groupby/lists_tests.cpp
groupby/m2_tests.cpp
groupby/min_tests.cpp
groupby/max_scan_tests.cpp
groupby/max_tests.cpp
groupby/mean_tests.cpp
groupby/median_tests.cpp
groupby/merge_m2_tests.cpp
groupby/merge_lists_tests.cpp
groupby/merge_sets_tests.cpp
groupby/min_scan_tests.cpp
groupby/nth_element_tests.cpp
groupby/nunique_tests.cpp
groupby/product_scan_tests.cpp
groupby/product_tests.cpp
groupby/quantile_tests.cpp
groupby/rank_scan_tests.cpp
groupby/replace_nulls_tests.cpp
groupby/shift_tests.cpp
groupby/std_tests.cpp
groupby/structs_tests.cpp
groupby/sum_of_squares_tests.cpp
groupby/sum_scan_tests.cpp
groupby/sum_tests.cpp
groupby/tdigest_tests.cu
groupby/var_tests.cpp
groupby/host_udf_tests.cu
GPUS 1
PERCENT 100
)