diff --git a/include/cuco/detail/open_addressing/kernels.cuh b/include/cuco/detail/open_addressing/kernels.cuh index b0457c071..1abe320b7 100644 --- a/include/cuco/detail/open_addressing/kernels.cuh +++ b/include/cuco/detail/open_addressing/kernels.cuh @@ -399,77 +399,6 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void find(InputIt first, } } -/** - * @brief Retrieves the equivalent container elements of all keys in the range `[input_probe, - * input_probe + n)`. - * - * If key `k = *(input_probe + i)` has one or more matches in the container, copies `k` to - * `output_probe` and associated slot contents to `output_match`, respectively. The output order is - * unspecified. - * - * @tparam IsOuter Flag indicating whether it's an outer count or not - * @tparam block_size The size of the thread block - * @tparam InputProbeIt Device accessible input iterator - * @tparam OutputProbeIt Device accessible input iterator whose `value_type` is - * convertible to the `InputProbeIt`'s `value_type` - * @tparam OutputMatchIt Device accessible input iterator whose `value_type` is - * convertible to the container's `value_type` - * @tparam AtomicCounter Integral atomic type that follows the same semantics as - * `cuda::(std::)atomic(_ref)` - * @tparam Ref Type of non-owning device ref allowing access to storage - * - * @param input_probe Beginning of the sequence of input keys - * @param n Number of the keys to query - * @param output_probe Beginning of the sequence of keys corresponding to matching elements in - * `output_match` - * @param output_match Beginning of the sequence of matching elements - * @param atomic_counter Pointer to an atomic object of integral type that is used to count the - * number of output elements - * @param ref Non-owning container device ref used to access the slot storage - */ -template -CUCO_KERNEL __launch_bounds__(BlockSize) void retrieve(InputProbeIt input_probe, - cuco::detail::index_type n, - OutputProbeIt output_probe, - OutputMatchIt output_match, - AtomicCounter* atomic_counter, - Ref ref) -{ - namespace cg = cooperative_groups; - - auto const block = cg::this_thread_block(); - auto constexpr tiles_in_block = BlockSize / Ref::cg_size; - // make sure all but the last block are always occupied - auto const items_per_block = detail::int_div_ceil(n, tiles_in_block * gridDim.x) * tiles_in_block; - - auto const block_begin_offset = block.group_index().x * items_per_block; - auto const block_end_offset = min(n, block_begin_offset + items_per_block); - - if (block_begin_offset < block_end_offset) { - if constexpr (IsOuter) { - ref.retrieve_outer(block, - input_probe + block_begin_offset, - input_probe + block_end_offset, - output_probe, - output_match, - *atomic_counter); - } else { - ref.retrieve(block, - input_probe + block_begin_offset, - input_probe + block_end_offset, - output_probe, - output_match, - *atomic_counter); - } - } -} - /** * @brief Inserts all elements in the range `[first, last)`. * @@ -626,6 +555,75 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void count(InputIt first, if (threadIdx.x == 0) { count->fetch_add(block_count, cuda::std::memory_order_relaxed); } } +/** + * @brief Retrieves the equivalent container elements of all keys in the range `[input_probe, + * input_probe + n)`. + * + * If key `k = *(input_probe + i)` has one or more matches in the container, copies `k` to + * `output_probe` and associated slot contents to `output_match`, respectively. The output order is + * unspecified. + * + * @tparam IsOuter Flag indicating whether it's an outer count or not + * @tparam block_size The size of the thread block + * @tparam InputProbeIt Device accessible input iterator + * @tparam OutputProbeIt Device accessible input iterator whose `value_type` is + * convertible to the `InputProbeIt`'s `value_type` + * @tparam OutputMatchIt Device accessible input iterator whose `value_type` is + * convertible to the container's `value_type` + * @tparam AtomicCounter Integral atomic type that follows the same semantics as + * `cuda::(std::)atomic(_ref)` + * @tparam Ref Type of non-owning device ref allowing access to storage + * + * @param input_probe Beginning of the sequence of input keys + * @param n Number of the keys to query + * @param output_probe Beginning of the sequence of keys corresponding to matching elements in + * `output_match` + * @param output_match Beginning of the sequence of matching elements + * @param atomic_counter Pointer to an atomic object of integral type that is used to count the + * number of output elements + * @param ref Non-owning container device ref used to access the slot storage + */ +template +CUCO_KERNEL __launch_bounds__(BlockSize) void retrieve(InputProbeIt input_probe, + cuco::detail::index_type n, + OutputProbeIt output_probe, + OutputMatchIt output_match, + AtomicCounter* atomic_counter, + Ref ref) +{ + auto constexpr num_tiles = BlockSize / Ref::cg_size; + // make sure all but the last block are always occupied + auto const items_per_block = n / BlockSize; + + auto const block = cooperative_groups::this_thread_block(); + auto const block_begin_offset = block.group_index().x * items_per_block; + auto const block_end_offset = min(n, block_begin_offset + items_per_block); + + if (block_begin_offset < block_end_offset) { + if constexpr (IsOuter) { + ref.retrieve_outer(block, + input_probe + block_begin_offset, + input_probe + block_end_offset, + output_probe, + output_match, + *atomic_counter); + } else { + ref.retrieve(block, + input_probe + block_begin_offset, + input_probe + block_end_offset, + output_probe, + output_match, + *atomic_counter); + } + } +} + /** * @brief Calculates the number of filled slots for the given bucket storage. * diff --git a/include/cuco/detail/open_addressing/open_addressing_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_impl.cuh index a6fd9b3c1..448d2a63b 100644 --- a/include/cuco/detail/open_addressing/open_addressing_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_impl.cuh @@ -1131,8 +1131,7 @@ class open_addressing_impl { auto counter = counter_type{this->allocator()}; counter.reset(stream.get()); - int32_t constexpr block_size = cuco::detail::default_block_size(); - + auto constexpr block_size = cuco::detail::default_block_size(); auto constexpr grid_stride = 1; auto const grid_size = cuco::detail::grid_size(n, cg_size, grid_stride, block_size);