From 7e1644091322dc2a87c26e28e37fc4ccd9faa4fd Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Thu, 19 Dec 2024 14:32:33 -0800 Subject: [PATCH] Avoid using anything from `cudf::detail::` in the example Signed-off-by: Nghia Truong --- cpp/tests/groupby/host_udf_example_tests.cu | 40 ++++++++++++--------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/cpp/tests/groupby/host_udf_example_tests.cu b/cpp/tests/groupby/host_udf_example_tests.cu index 13b6c611ca4..0ec424d1cc1 100644 --- a/cpp/tests/groupby/host_udf_example_tests.cu +++ b/cpp/tests/groupby/host_udf_example_tests.cu @@ -20,8 +20,8 @@ #include #include #include +#include #include -#include #include #include #include @@ -96,8 +96,6 @@ struct host_udf_groupby_example : cudf::host_udf_base { // For simplicity, this example only accepts double input and always produces double output. using InputType = double; using OutputType = double; - using MaxType = cudf::detail::target_type_t; - using SumType = cudf::detail::target_type_t; template )> output_t operator()(Args...) const @@ -131,36 +129,44 @@ struct host_udf_groupby_example : cudf::host_udf_base { cudf::mask_state::UNALLOCATED, stream, mr); - rmm::device_uvector validity(num_groups, stream); + + // Store row index if it is valid, otherwise store a negative value denoting a null row. + rmm::device_uvector valid_idx(num_groups, stream); thrust::transform( rmm::exec_policy(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(num_groups), - thrust::make_zip_iterator(output->mutable_view().begin(), validity.begin()), + thrust::make_zip_iterator(output->mutable_view().begin(), valid_idx.begin()), transform_fn{*values_dv_ptr, offsets, group_indices, - group_max.begin(), - group_sum.begin()}); - auto [null_mask, null_count] = - cudf::detail::valid_if(validity.begin(), validity.end(), thrust::identity<>{}, stream, mr); - if (null_count > 0) { output->set_null_mask(std::move(null_mask), null_count); } - return output; + group_max.begin(), + group_sum.begin()}); + + auto const valid_idx_cv = cudf::column_view{ + cudf::data_type{cudf::type_id::INT32}, num_groups, valid_idx.begin(), nullptr, 0}; + return std::move(cudf::gather(cudf::table_view{{output->view()}}, + valid_idx_cv, + cudf::out_of_bounds_policy::NULLIFY, + stream, + mr) + ->release() + .front()); } struct transform_fn { cudf::column_device_view values; cudf::device_span offsets; cudf::device_span group_indices; - MaxType const* group_max; - SumType const* group_sum; + InputType const* group_max; + InputType const* group_sum; - thrust::tuple __device__ operator()(cudf::size_type idx) const + thrust::tuple __device__ operator()(cudf::size_type idx) const { auto const start = offsets[idx]; auto const end = offsets[idx + 1]; - if (start == end) { return {OutputType{0}, false}; } + if (start == end) { return {OutputType{0}, -1}; } auto sum_sqr = OutputType{0}; bool has_valid{false}; @@ -171,10 +177,10 @@ struct host_udf_groupby_example : cudf::host_udf_base { sum_sqr += val * val; } - if (!has_valid) { return {OutputType{0}, false}; } + if (!has_valid) { return {OutputType{0}, -1}; } return {static_cast(group_indices[start] + 1) * sum_sqr - static_cast(group_max[idx]) * static_cast(group_sum[idx]), - true}; + idx}; } }; };