Skip to content

Commit

Permalink
Improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel committed Nov 10, 2024
1 parent 8d5bf12 commit be3f83d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 21 deletions.
6 changes: 3 additions & 3 deletions include/cuco/detail/open_addressing/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -445,11 +445,11 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void retrieve(InputProbeIt input_probe,

auto const block = cg::this_thread_block();
auto constexpr tiles_in_block = BlockSize / Ref::cg_size;
// make sure all but the last block are always occupied
auto const items_per_block = detail::int_div_ceil(n, tiles_in_block * gridDim.x) * tiles_in_block;
auto const items_per_block = tiles_in_block;

auto const block_begin_offset = block.group_index().x * items_per_block;
auto const block_end_offset = min(n, block_begin_offset + items_per_block);
auto const block_end_offset =
min(n, static_cast<cuco::detail::index_type>(block_begin_offset + items_per_block));

if (block_begin_offset < block_end_offset) {
if constexpr (IsOuter) {
Expand Down
31 changes: 13 additions & 18 deletions include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1134,7 +1134,7 @@ class open_addressing_ref_impl {

auto constexpr num_flushing_tiles = BlockSize / flushing_tile_size;
auto constexpr max_matches_per_step = flushing_tile_size * bucket_size;
auto constexpr buffer_size = buffer_multiplier * max_matches_per_step;
auto constexpr buffer_size = buffer_multiplier * max_matches_per_step + flushing_tile_size;

auto const flushing_tile = cg::tiled_partition<flushing_tile_size>(block);
auto const probing_tile = cg::tiled_partition<probing_tile_size>(block);
Expand All @@ -1144,11 +1144,10 @@ class open_addressing_ref_impl {
auto idx = probing_tile.meta_group_rank();

// TODO align to 16B?
__shared__ probe_type probe_buffers[num_flushing_tiles][buffer_size];
__shared__ value_type match_buffers[num_flushing_tiles][buffer_size];
__shared__ cuco::pair<probe_type, value_type> buffers[num_flushing_tiles][buffer_size];
size_type num_matches = 0;

auto flush_buffers = [&](cg::coalesced_group const& tile) {
auto flush_buffers = [&](auto const& tile) {
auto const rank = tile.thread_rank();

#if defined(CUCO_HAS_CG_INVOKE_ONE)
Expand All @@ -1165,8 +1164,8 @@ class open_addressing_ref_impl {

// flush_buffers
for (size_type i = rank; i < num_matches; i += tile.size()) {
*(output_probe + offset + i) = probe_buffers[flushing_tile_id][i];
*(output_match + offset + i) = match_buffers[flushing_tile_id][i];
*(output_probe + offset + i) = buffers[flushing_tile_id][i].first;
*(output_match + offset + i) = buffers[flushing_tile_id][i].second;
}
};

Expand Down Expand Up @@ -1213,10 +1212,8 @@ class open_addressing_ref_impl {
auto const matching_tile = cg::binary_partition(active_flushing_tile, match_found);
// stage matches in shmem buffer
if (match_found) {
probe_buffers[flushing_tile_id][num_matches + matching_tile.thread_rank()] =
probe_key;
match_buffers[flushing_tile_id][num_matches + matching_tile.thread_rank()] =
bucket_slots[i];
buffers[flushing_tile_id][num_matches + matching_tile.thread_rank()] = {
probe_key, bucket_slots[i]};
}

// add number of new matches to the buffer counter
Expand Down Expand Up @@ -1246,9 +1243,9 @@ class open_addressing_ref_impl {
auto const sentinel_writers =
cg::binary_partition(active_flushing_tile, writes_sentinel);
if (writes_sentinel) {
auto const rank = sentinel_writers.thread_rank();
probe_buffers[flushing_tile_id][num_matches + rank] = probe_key;
match_buffers[flushing_tile_id][num_matches + rank] = this->empty_slot_sentinel();
auto const rank = sentinel_writers.thread_rank();
buffers[flushing_tile_id][num_matches + rank] = {probe_key,
this->empty_slot_sentinel()};
}
// add number of new matches to the buffer counter
num_matches += (writes_sentinel)
Expand All @@ -1271,16 +1268,14 @@ class open_addressing_ref_impl {
// onto the next probing bucket
++probing_iter;
}

// entire flushing_tile has finished; flush remaining elements
if (num_matches != 0 and active_flushing_tile.all((idx + stride) >= n)) {
flush_buffers(active_flushing_tile);
}
}

// onto the next key
idx += stride;
}

// entire flushing_tile has finished; flush remaining elements
if (num_matches > 0) { flush_buffers(flushing_tile); }
}

/**
Expand Down

0 comments on commit be3f83d

Please sign in to comment.