Skip to content

Commit

Permalink
Avoid using anything from cudf::detail:: in the example
Browse files Browse the repository at this point in the history
Signed-off-by: Nghia Truong <[email protected]>
  • Loading branch information
ttnghia committed Dec 19, 2024
1 parent 322fb25 commit 7e16440
Showing 1 changed file with 23 additions and 17 deletions.
40 changes: 23 additions & 17 deletions cpp/tests/groupby/host_udf_example_tests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
#include <cudf/aggregation/host_udf.hpp>
#include <cudf/column/column_device_view.cuh>
#include <cudf/column/column_factories.hpp>
#include <cudf/copying.hpp>
#include <cudf/detail/aggregation/aggregation.hpp>
#include <cudf/detail/valid_if.cuh>
#include <cudf/groupby.hpp>
#include <cudf/reduction.hpp>
#include <cudf/utilities/type_dispatcher.hpp>
Expand Down Expand Up @@ -96,8 +96,6 @@ struct host_udf_groupby_example : cudf::host_udf_base {
// For simplicity, this example only accepts double input and always produces double output.
using InputType = double;
using OutputType = double;
using MaxType = cudf::detail::target_type_t<InputType, cudf::aggregation::Kind::MAX>;
using SumType = cudf::detail::target_type_t<InputType, cudf::aggregation::Kind::SUM>;

template <typename T, typename... Args, CUDF_ENABLE_IF(!std::is_same_v<InputType, T>)>
output_t operator()(Args...) const
Expand Down Expand Up @@ -131,36 +129,44 @@ struct host_udf_groupby_example : cudf::host_udf_base {
cudf::mask_state::UNALLOCATED,
stream,
mr);
rmm::device_uvector<bool> validity(num_groups, stream);

// Store row index if it is valid, otherwise store a negative value denoting a null row.
rmm::device_uvector<cudf::size_type> valid_idx(num_groups, stream);

thrust::transform(
rmm::exec_policy(stream),
thrust::make_counting_iterator(0),
thrust::make_counting_iterator(num_groups),
thrust::make_zip_iterator(output->mutable_view().begin<OutputType>(), validity.begin()),
thrust::make_zip_iterator(output->mutable_view().begin<OutputType>(), valid_idx.begin()),
transform_fn{*values_dv_ptr,
offsets,
group_indices,
group_max.begin<MaxType>(),
group_sum.begin<SumType>()});
auto [null_mask, null_count] =
cudf::detail::valid_if(validity.begin(), validity.end(), thrust::identity<>{}, stream, mr);
if (null_count > 0) { output->set_null_mask(std::move(null_mask), null_count); }
return output;
group_max.begin<InputType>(),
group_sum.begin<InputType>()});

auto const valid_idx_cv = cudf::column_view{
cudf::data_type{cudf::type_id::INT32}, num_groups, valid_idx.begin(), nullptr, 0};
return std::move(cudf::gather(cudf::table_view{{output->view()}},
valid_idx_cv,
cudf::out_of_bounds_policy::NULLIFY,
stream,
mr)
->release()
.front());
}

struct transform_fn {
cudf::column_device_view values;
cudf::device_span<cudf::size_type const> offsets;
cudf::device_span<cudf::size_type const> group_indices;
MaxType const* group_max;
SumType const* group_sum;
InputType const* group_max;
InputType const* group_sum;

thrust::tuple<OutputType, bool> __device__ operator()(cudf::size_type idx) const
thrust::tuple<OutputType, cudf::size_type> __device__ operator()(cudf::size_type idx) const
{
auto const start = offsets[idx];
auto const end = offsets[idx + 1];
if (start == end) { return {OutputType{0}, false}; }
if (start == end) { return {OutputType{0}, -1}; }

auto sum_sqr = OutputType{0};
bool has_valid{false};
Expand All @@ -171,10 +177,10 @@ struct host_udf_groupby_example : cudf::host_udf_base {
sum_sqr += val * val;
}

if (!has_valid) { return {OutputType{0}, false}; }
if (!has_valid) { return {OutputType{0}, -1}; }
return {static_cast<OutputType>(group_indices[start] + 1) * sum_sqr -
static_cast<OutputType>(group_max[idx]) * static_cast<OutputType>(group_sum[idx]),
true};
idx};
}
};
};
Expand Down

0 comments on commit 7e16440

Please sign in to comment.