Skip to content

Commit

Permalink
Fix groupby argmin/max gather of sorted-order indices (#17591)
Browse files Browse the repository at this point in the history
Fixes the gather logic in `groupby_argmin.cu` and `groupby_argmax.cu` that gathers the sorted-order indices from the results of the groupby reduction functions. The resulting indices must be remapped to the sorted-order indices before returning. The `gather` call has been fixed to use an output vector since the [gather documentation indicates the map and result iterators must not overlap](https://nvidia.github.io/cccl/thrust/api/function_group__gathering_1ga6fdb1fe3ff0d9ce01f41a72fa94c56df.html).
Also, the `gather_if` is not needed since the groupby reduction does not use the min/max sentinel values in its logic.

Closes #16542

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Nghia Truong (https://github.com/ttnghia)
  - Bradley Dice (https://github.com/bdice)

URL: #17591
  • Loading branch information
davidwendt authored Dec 20, 2024
1 parent 27404bc commit 0f1bae8
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 40 deletions.
10 changes: 2 additions & 8 deletions cpp/src/groupby/sort/aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,7 @@ void aggregate_result_functor::operator()<aggregation::MIN>(aggregation const& a
operator()<aggregation::ARGMIN>(*argmin_agg);
column_view const argmin_result = cache.get_result(values, *argmin_agg);

// We make a view of ARGMIN result without a null mask and gather using
// this mask. The values in data buffer of ARGMIN result corresponding
// to null values was initialized to ARGMIN_SENTINEL which is an out of
// bounds index value and causes the gathered value to be null.
// Compute the ARGMIN result without the null mask in the gather map.
column_view const null_removed_map(
data_type(type_to_id<size_type>()),
argmin_result.size(),
Expand Down Expand Up @@ -251,10 +248,7 @@ void aggregate_result_functor::operator()<aggregation::MAX>(aggregation const& a
operator()<aggregation::ARGMAX>(*argmax_agg);
column_view const argmax_result = cache.get_result(values, *argmax_agg);

// We make a view of ARGMAX result without a null mask and gather using
// this mask. The values in data buffer of ARGMAX result corresponding
// to null values was initialized to ARGMAX_SENTINEL which is an out of
// bounds index value and causes the gathered value to be null.
// Compute the ARGMAX result without the null mask in the gather map.
column_view const null_removed_map(
data_type(type_to_id<size_type>()),
argmax_result.size(),
Expand Down
31 changes: 15 additions & 16 deletions cpp/src/groupby/sort/group_argmax.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,21 @@ std::unique_ptr<column> group_argmax(column_view const& values,
stream,
mr);

// The functor returns the index of maximum in the sorted values.
// We need the index of maximum in the original unsorted values.
// So use indices to gather the sort order used to sort `values`.
// Gather map cannot be null so we make a view with the mask removed.
// The values in data buffer of indices corresponding to null values was
// initialized to ARGMAX_SENTINEL. Using gather_if.
// This can't use gather because nulls in gathered column will not store ARGMAX_SENTINEL.
auto indices_view = indices->mutable_view();
thrust::gather_if(rmm::exec_policy(stream),
indices_view.begin<size_type>(), // map first
indices_view.end<size_type>(), // map last
indices_view.begin<size_type>(), // stencil
key_sort_order.begin<size_type>(), // input
indices_view.begin<size_type>(), // result
[] __device__(auto i) { return (i != cudf::detail::ARGMAX_SENTINEL); });
return indices;
// The functor returns the indices of maximums based on the sorted keys.
// We need the indices of maximums from the original unsorted keys
// so we use these indices and the key_sort_order to map to the correct indices.
// We do not use cudf::gather since we can move the null-mask separately.
auto indices_view = indices->view();
auto output = rmm::device_uvector<size_type>(indices_view.size(), stream, mr);
thrust::gather(rmm::exec_policy_nosync(stream),
indices_view.begin<size_type>(), // map first
indices_view.end<size_type>(), // map last
key_sort_order.begin<size_type>(), // input
output.data() // result (must not overlap map)
);
auto null_count = indices_view.null_count();
auto null_mask = indices->release().null_mask.release();
return std::make_unique<column>(std::move(output), std::move(*null_mask), null_count);
}

} // namespace detail
Expand Down
32 changes: 16 additions & 16 deletions cpp/src/groupby/sort/group_argmin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <cudf/utilities/span.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>

#include <thrust/gather.h>

Expand All @@ -42,22 +43,21 @@ std::unique_ptr<column> group_argmin(column_view const& values,
stream,
mr);

// The functor returns the index of minimum in the sorted values.
// We need the index of minimum in the original unsorted values.
// So use indices to gather the sort order used to sort `values`.
// The values in data buffer of indices corresponding to null values was
// initialized to ARGMIN_SENTINEL. Using gather_if.
// This can't use gather because nulls in gathered column will not store ARGMIN_SENTINEL.
auto indices_view = indices->mutable_view();
thrust::gather_if(rmm::exec_policy(stream),
indices_view.begin<size_type>(), // map first
indices_view.end<size_type>(), // map last
indices_view.begin<size_type>(), // stencil
key_sort_order.begin<size_type>(), // input
indices_view.begin<size_type>(), // result
[] __device__(auto i) { return (i != cudf::detail::ARGMIN_SENTINEL); });

return indices;
// The functor returns the indices of minimums based on the sorted keys.
// We need the indices of minimums from the original unsorted keys
// so we use these and the key_sort_order to map to the correct indices.
// We do not use cudf::gather since we can move the null-mask separately.
auto indices_view = indices->view();
auto output = rmm::device_uvector<size_type>(indices_view.size(), stream, mr);
thrust::gather(rmm::exec_policy_nosync(stream),
indices_view.begin<size_type>(), // map first
indices_view.end<size_type>(), // map last
key_sort_order.begin<size_type>(), // input
output.data() // result (must not overlap map)
);
auto null_count = indices_view.null_count();
auto null_mask = indices->release().null_mask.release();
return std::make_unique<column>(std::move(output), std::move(*null_mask), null_count);
}

} // namespace detail
Expand Down

0 comments on commit 0f1bae8

Please sign in to comment.