Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix groupby argmin/max gather of sorted-order indices #17591

Merged
merged 6 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 15 additions & 16 deletions cpp/src/groupby/sort/group_argmax.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,21 @@ std::unique_ptr<column> group_argmax(column_view const& values,
stream,
mr);

// The functor returns the index of maximum in the sorted values.
// We need the index of maximum in the original unsorted values.
// So use indices to gather the sort order used to sort `values`.
// Gather map cannot be null so we make a view with the mask removed.
// The values in the data buffer of indices corresponding to null values were
// initialized to ARGMAX_SENTINEL, so gather_if is used to skip them.
// This can't use gather because nulls in gathered column will not store ARGMAX_SENTINEL.
auto indices_view = indices->mutable_view();
thrust::gather_if(rmm::exec_policy(stream),
indices_view.begin<size_type>(), // map first
indices_view.end<size_type>(), // map last
indices_view.begin<size_type>(), // stencil
key_sort_order.begin<size_type>(), // input
indices_view.begin<size_type>(), // result
[] __device__(auto i) { return (i != cudf::detail::ARGMAX_SENTINEL); });
return indices;
// The functor returns the indices of maximums in the sorted values.
// We need the indices of maximums from the original unsorted values
// so we use these indices to gather the sorted order values.
// We cannot use cudf::gather since indices should not have nulls.
auto indices_view = indices->view();
auto output = rmm::device_uvector<size_type>(indices_view.size(), stream, mr);
thrust::gather(rmm::exec_policy(stream),
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
indices_view.begin<size_type>(), // map first
indices_view.end<size_type>(), // map last
key_sort_order.begin<size_type>(), // input
output.data() // result (must not overlap map)
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
);
auto null_count = indices_view.null_count();
auto null_mask = indices->release().null_mask.release();
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
return std::make_unique<column>(std::move(output), std::move(*null_mask), null_count);
}

} // namespace detail
Expand Down
35 changes: 17 additions & 18 deletions cpp/src/groupby/sort/group_argmin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@

#include "groupby/sort/group_single_pass_reduction_util.cuh"

#include <cudf/detail/gather.hpp>
#include <cudf/utilities/memory_resource.hpp>
#include <cudf/utilities/span.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>

#include <thrust/gather.h>

Expand All @@ -40,24 +40,23 @@ std::unique_ptr<column> group_argmin(column_view const& values,
num_groups,
group_labels,
stream,
mr);
cudf::get_current_device_resource_ref());
davidwendt marked this conversation as resolved.
Show resolved Hide resolved

// The functor returns the index of minimum in the sorted values.
// We need the index of minimum in the original unsorted values.
// So use indices to gather the sort order used to sort `values`.
// The values in the data buffer of indices corresponding to null values were
// initialized to ARGMIN_SENTINEL, so gather_if is used to skip them.
// This can't use gather because nulls in gathered column will not store ARGMIN_SENTINEL.
auto indices_view = indices->mutable_view();
thrust::gather_if(rmm::exec_policy(stream),
indices_view.begin<size_type>(), // map first
indices_view.end<size_type>(), // map last
indices_view.begin<size_type>(), // stencil
key_sort_order.begin<size_type>(), // input
indices_view.begin<size_type>(), // result
[] __device__(auto i) { return (i != cudf::detail::ARGMIN_SENTINEL); });

return indices;
// The functor returns the indices of minimums in the sorted values.
// We need the indices of minimums from the original unsorted values
// so we use these indices to gather the sorted order values.
// We cannot use cudf::gather since indices should not have nulls.
auto indices_view = indices->view();
auto output = rmm::device_uvector<size_type>(indices_view.size(), stream, mr);
thrust::gather(rmm::exec_policy(stream),
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
indices_view.begin<size_type>(), // map first
indices_view.end<size_type>(), // map last
key_sort_order.begin<size_type>(), // input
output.data() // result (must not overlap map)
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
);
auto null_count = indices_view.null_count();
auto null_mask = indices->release().null_mask.release();
return std::make_unique<column>(std::move(output), std::move(*null_mask), null_count);
}

} // namespace detail
Expand Down
Loading