Skip to content

Commit

Permalink
Minor cleanups + fix an int overflow bug
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel committed Aug 15, 2024
1 parent f7fa99f commit e0ed502
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions include/cuco/detail/open_addressing/open_addressing_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -639,28 +639,29 @@ class open_addressing_impl {
template <typename OutputIt>
[[nodiscard]] OutputIt retrieve_all(OutputIt output_begin, cuda::stream_ref stream) const
{
std::size_t temp_storage_bytes = 0;
using temp_allocator_type =
typename std::allocator_traits<allocator_type>::template rebind_alloc<char>;
auto temp_allocator = temp_allocator_type{this->allocator()};
auto d_num_out = reinterpret_cast<size_type*>(
std::allocator_traits<temp_allocator_type>::allocate(temp_allocator, sizeof(size_type)));

cuco::detail::index_type constexpr stride = std::numeric_limits<int32_t>::max();

cuco::detail::index_type h_num_out{0};
auto temp_allocator = temp_allocator_type{this->allocator()};
auto d_num_out = reinterpret_cast<size_type*>(
std::allocator_traits<temp_allocator_type>::allocate(temp_allocator, sizeof(size_type)));

for (cuco::detail::index_type offset = 0;
offset < static_cast<cuco::detail::index_type>(this->capacity());
offset += stride) {
auto const num_items =
std::min(static_cast<cuco::detail::index_type>(this->capacity()) - offset, stride);
auto const begin = thrust::make_transform_iterator(
thrust::counting_iterator{static_cast<int32_t>(offset)},
thrust::counting_iterator{static_cast<size_type>(offset)},
open_addressing_ns::detail::get_slot<has_payload, storage_ref_type>(this->storage_ref()));
auto const is_filled = open_addressing_ns::detail::slot_is_filled<has_payload, key_type>{
this->empty_key_sentinel(), this->erased_key_sentinel()};

std::size_t temp_storage_bytes = 0;

CUCO_CUDA_TRY(cub::DeviceSelect::If(nullptr,
temp_storage_bytes,
begin,
Expand Down Expand Up @@ -692,6 +693,7 @@ class open_addressing_impl {

std::allocator_traits<temp_allocator_type>::deallocate(
temp_allocator, reinterpret_cast<char*>(d_num_out), sizeof(size_type));

return output_begin + h_num_out;
}

Expand Down

0 comments on commit e0ed502

Please sign in to comment.