diff --git a/cpp/src/groupby/sort/aggregate.cpp b/cpp/src/groupby/sort/aggregate.cpp index 7a8a1883ed4..f667a3192cf 100644 --- a/cpp/src/groupby/sort/aggregate.cpp +++ b/cpp/src/groupby/sort/aggregate.cpp @@ -208,10 +208,7 @@ void aggregate_result_functor::operator()(aggregation const& a operator()(*argmin_agg); column_view const argmin_result = cache.get_result(values, *argmin_agg); - // We make a view of ARGMIN result without a null mask and gather using - // this mask. The values in data buffer of ARGMIN result corresponding - // to null values was initialized to ARGMIN_SENTINEL which is an out of - // bounds index value and causes the gathered value to be null. + // Compute the ARGMIN result without the null mask in the gather map. column_view const null_removed_map( data_type(type_to_id()), argmin_result.size(), @@ -250,10 +247,7 @@ void aggregate_result_functor::operator()(aggregation const& a operator()(*argmax_agg); column_view const argmax_result = cache.get_result(values, *argmax_agg); - // We make a view of ARGMAX result without a null mask and gather using - // this mask. The values in data buffer of ARGMAX result corresponding - // to null values was initialized to ARGMAX_SENTINEL which is an out of - // bounds index value and causes the gathered value to be null. + // Compute the ARGMAX result without the null mask in the gather map. column_view const null_removed_map( data_type(type_to_id()), argmax_result.size(), diff --git a/cpp/src/groupby/sort/group_argmax.cu b/cpp/src/groupby/sort/group_argmax.cu index 7dce341130e..329c7c4eb32 100644 --- a/cpp/src/groupby/sort/group_argmax.cu +++ b/cpp/src/groupby/sort/group_argmax.cu @@ -42,22 +42,21 @@ std::unique_ptr group_argmax(column_view const& values, stream, mr); - // The functor returns the index of maximum in the sorted values. - // We need the index of maximum in the original unsorted values. - // So use indices to gather the sort order used to sort `values`. - // Gather map cannot be null so we make a view with the mask removed. - // The values in data buffer of indices corresponding to null values was - // initialized to ARGMAX_SENTINEL. Using gather_if. - // This can't use gather because nulls in gathered column will not store ARGMAX_SENTINEL. - auto indices_view = indices->mutable_view(); - thrust::gather_if(rmm::exec_policy(stream), - indices_view.begin(), // map first - indices_view.end(), // map last - indices_view.begin(), // stencil - key_sort_order.begin(), // input - indices_view.begin(), // result - [] __device__(auto i) { return (i != cudf::detail::ARGMAX_SENTINEL); }); - return indices; + // The functor returns the indices of maximums based on the sorted keys. + // We need the indices of maximums from the original unsorted keys + // so we use these indices and the key_sort_order to map to the correct indices. + // We do not use cudf::gather since we can move the null-mask separately. + auto indices_view = indices->view(); + auto output = rmm::device_uvector(indices_view.size(), stream, mr); + thrust::gather(rmm::exec_policy_nosync(stream), + indices_view.begin(), // map first + indices_view.end(), // map last + key_sort_order.begin(), // input + output.data() // result (must not overlap map) + ); + auto null_count = indices_view.null_count(); + auto null_mask = indices->release().null_mask.release(); + return std::make_unique(std::move(output), std::move(*null_mask), null_count); } } // namespace detail diff --git a/cpp/src/groupby/sort/group_argmin.cu b/cpp/src/groupby/sort/group_argmin.cu index c4bed330b9f..dbfc375fc20 100644 --- a/cpp/src/groupby/sort/group_argmin.cu +++ b/cpp/src/groupby/sort/group_argmin.cu @@ -21,6 +21,7 @@ #include #include +#include #include @@ -42,22 +43,21 @@ std::unique_ptr group_argmin(column_view const& values, stream, mr); - // The functor returns the index of minimum in the sorted values. - // We need the index of minimum in the original unsorted values. - // So use indices to gather the sort order used to sort `values`. - // The values in data buffer of indices corresponding to null values was - // initialized to ARGMIN_SENTINEL. Using gather_if. - // This can't use gather because nulls in gathered column will not store ARGMIN_SENTINEL. - auto indices_view = indices->mutable_view(); - thrust::gather_if(rmm::exec_policy(stream), - indices_view.begin(), // map first - indices_view.end(), // map last - indices_view.begin(), // stencil - key_sort_order.begin(), // input - indices_view.begin(), // result - [] __device__(auto i) { return (i != cudf::detail::ARGMIN_SENTINEL); }); - - return indices; + // The functor returns the indices of minimums based on the sorted keys. + // We need the indices of minimums from the original unsorted keys + // so we use these and the key_sort_order to map to the correct indices. + // We do not use cudf::gather since we can move the null-mask separately. + auto indices_view = indices->view(); + auto output = rmm::device_uvector(indices_view.size(), stream, mr); + thrust::gather(rmm::exec_policy_nosync(stream), + indices_view.begin(), // map first + indices_view.end(), // map last + key_sort_order.begin(), // input + output.data() // result (must not overlap map) + ); + auto null_count = indices_view.null_count(); + auto null_mask = indices->release().null_mask.release(); + return std::make_unique(std::move(output), std::move(*null_mask), null_count); } } // namespace detail