diff --git a/include/cuco/detail/open_addressing/kernels.cuh b/include/cuco/detail/open_addressing/kernels.cuh index b0457c071..9c1f4079f 100644 --- a/include/cuco/detail/open_addressing/kernels.cuh +++ b/include/cuco/detail/open_addressing/kernels.cuh @@ -445,11 +445,11 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void retrieve(InputProbeIt input_probe, auto const block = cg::this_thread_block(); auto constexpr tiles_in_block = BlockSize / Ref::cg_size; - // make sure all but the last block are always occupied - auto const items_per_block = detail::int_div_ceil(n, tiles_in_block * gridDim.x) * tiles_in_block; + auto const items_per_block = tiles_in_block; auto const block_begin_offset = block.group_index().x * items_per_block; - auto const block_end_offset = min(n, block_begin_offset + items_per_block); + auto const block_end_offset = + min(n, static_cast(block_begin_offset + items_per_block)); if (block_begin_offset < block_end_offset) { if constexpr (IsOuter) { diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index f1c0194ca..03810531b 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -1134,7 +1134,7 @@ class open_addressing_ref_impl { auto constexpr num_flushing_tiles = BlockSize / flushing_tile_size; auto constexpr max_matches_per_step = flushing_tile_size * bucket_size; - auto constexpr buffer_size = buffer_multiplier * max_matches_per_step; + auto constexpr buffer_size = buffer_multiplier * max_matches_per_step + flushing_tile_size; auto const flushing_tile = cg::tiled_partition(block); auto const probing_tile = cg::tiled_partition(block); @@ -1144,11 +1144,10 @@ class open_addressing_ref_impl { auto idx = probing_tile.meta_group_rank(); // TODO align to 16B? - __shared__ probe_type probe_buffers[num_flushing_tiles][buffer_size]; - __shared__ value_type match_buffers[num_flushing_tiles][buffer_size]; + __shared__ cuco::pair buffers[num_flushing_tiles][buffer_size]; size_type num_matches = 0; - auto flush_buffers = [&](cg::coalesced_group const& tile) { + auto flush_buffers = [&](auto const& tile) { auto const rank = tile.thread_rank(); #if defined(CUCO_HAS_CG_INVOKE_ONE) @@ -1165,8 +1164,8 @@ class open_addressing_ref_impl { // flush_buffers for (size_type i = rank; i < num_matches; i += tile.size()) { - *(output_probe + offset + i) = probe_buffers[flushing_tile_id][i]; - *(output_match + offset + i) = match_buffers[flushing_tile_id][i]; + *(output_probe + offset + i) = buffers[flushing_tile_id][i].first; + *(output_match + offset + i) = buffers[flushing_tile_id][i].second; } }; @@ -1213,10 +1212,8 @@ class open_addressing_ref_impl { auto const matching_tile = cg::binary_partition(active_flushing_tile, match_found); // stage matches in shmem buffer if (match_found) { - probe_buffers[flushing_tile_id][num_matches + matching_tile.thread_rank()] = - probe_key; - match_buffers[flushing_tile_id][num_matches + matching_tile.thread_rank()] = - bucket_slots[i]; + buffers[flushing_tile_id][num_matches + matching_tile.thread_rank()] = { + probe_key, bucket_slots[i]}; } // add number of new matches to the buffer counter @@ -1246,9 +1243,9 @@ class open_addressing_ref_impl { auto const sentinel_writers = cg::binary_partition(active_flushing_tile, writes_sentinel); if (writes_sentinel) { - auto const rank = sentinel_writers.thread_rank(); - probe_buffers[flushing_tile_id][num_matches + rank] = probe_key; - match_buffers[flushing_tile_id][num_matches + rank] = this->empty_slot_sentinel(); + auto const rank = sentinel_writers.thread_rank(); + buffers[flushing_tile_id][num_matches + rank] = {probe_key, + this->empty_slot_sentinel()}; } // add number of new matches to the buffer counter num_matches += (writes_sentinel) @@ -1271,16 +1268,14 @@ class open_addressing_ref_impl { // onto the next probing bucket ++probing_iter; } - - // entire flusing_tile has finished; flush remaining elements - if (num_matches != 0 and active_flushing_tile.all((idx + stride) >= n)) { - flush_buffers(active_flushing_tile); - } } // onto the next key idx += stride; } + + // entire flusing_tile has finished; flush remaining elements + if (num_matches > 0) { flush_buffers(flushing_tile); } } /**