Skip to content

Commit

Permalink
Improve OA retrieve implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel committed Nov 18, 2024
1 parent 5db1066 commit edd129e
Showing 1 changed file with 97 additions and 82 deletions.
179 changes: 97 additions & 82 deletions include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
#include <cuda/atomic>
#include <cuda/std/type_traits>
#include <thrust/distance.h>
#include <thrust/execution_policy.h>
#include <thrust/logical.h>
#include <thrust/reduce.h>
#include <thrust/tuple.h>
#if defined(CUCO_HAS_CUDA_BARRIER)
#include <cuda/barrier>
Expand Down Expand Up @@ -1016,7 +1019,7 @@ class open_addressing_ref_impl {
InputProbeIt input_probe_end,
OutputProbeIt output_probe,
OutputMatchIt output_match,
AtomicCounter& atomic_counter) const
AtomicCounter* atomic_counter) const
{
auto constexpr is_outer = false;
auto const n = cuco::detail::distance(input_probe_begin, input_probe_end); // TODO include
Expand Down Expand Up @@ -1065,7 +1068,7 @@ class open_addressing_ref_impl {
InputProbeIt input_probe_end,
OutputProbeIt output_probe,
OutputMatchIt output_match,
AtomicCounter& atomic_counter) const
AtomicCounter* atomic_counter) const
{
auto constexpr is_outer = true;
auto const n = cuco::detail::distance(input_probe_begin, input_probe_end); // TODO include
Expand Down Expand Up @@ -1116,7 +1119,7 @@ class open_addressing_ref_impl {
cuco::detail::index_type n,
OutputProbeIt output_probe,
OutputMatchIt output_match,
AtomicCounter& atomic_counter) const
AtomicCounter* atomic_counter) const
{
namespace cg = cooperative_groups;

Expand All @@ -1143,26 +1146,24 @@ class open_addressing_ref_impl {
auto const stride = probing_tile.meta_group_size();
auto idx = probing_tile.meta_group_rank();

// TODO align to 16B?
__shared__ cuco::pair<probe_type, value_type> buffers[num_flushing_tiles][buffer_size];
size_type num_matches = 0;
__shared__ int32_t counters[num_flushing_tiles];

if (flushing_tile.thread_rank() == 0) { counters[flushing_tile_id] = 0; }
flushing_tile.sync();

auto flush_buffers = [&](auto const& tile) {
size_type offset = 0;
/*
if (tile.thread_rank() == 0) {
offset = atomic_counter.fetch_add(num_matches, cuda::std::memory_order_relaxed);
}
*/
auto const count = counters[flushing_tile_id];
auto const rank = tile.thread_rank();
if (rank == 0) { offset = atomic_counter->fetch_add(count, cuda::memory_order_relaxed); }
offset = tile.shfl(offset, 0);

/*
// flush_buffers
for (size_type i = rank; i < num_matches; i += tile.size()) {
for (auto i = rank; i < count; i += tile.size()) {
*(output_probe + offset + i) = buffers[flushing_tile_id][i].first;
*(output_match + offset + i) = buffers[flushing_tile_id][i].second;
}
*/
};

while (flushing_tile.any(idx < n)) {
Expand All @@ -1176,102 +1177,116 @@ class open_addressing_ref_impl {
auto const& probe_key = *(input_probe + idx);
auto probing_iter =
this->probing_scheme_(probing_tile, probe_key, this->storage_ref_.bucket_extent());
bool running = true;
bool match_found = false;
[[maybe_unused]] bool found_any_match = false; // only needed if `IsOuter == true`

while (true) {
// TODO atomic_ref::load if insert operator is present
auto const bucket_slots = this->storage_ref_[*probing_iter];

for (int32_t i = 0; i < bucket_size; ++i) {
if (running) {
// inspect slot content
switch (this->predicate_.operator()<is_insert::NO>(
probe_key, this->extract_key(bucket_slots[i]))) {
case detail::equal_result::EMPTY: {
running = false;
break;
}
case detail::equal_result::EQUAL: {
if constexpr (!AllowsDuplicates) { running = false; }
match_found = true;
break;
}
default: {
break;

bool running = true;
[[maybe_unused]] bool found_match = false;

bool equals[buffer_size];
uint32_t exists[buffer_size];

while (active_flushing_tile.any(running)) {
if (running) {
// TODO atomic_ref::load if insert operator is present
auto const bucket_slots = this->storage_ref_[*probing_iter];

#pragma unroll buffer_size
for (int32_t i = 0; i < bucket_size; ++i) {
equals[i] = false;
if (running) {
// inspect slot content
switch (this->predicate_.operator()<is_insert::NO>(
probe_key, this->extract_key(bucket_slots[i]))) {
case detail::equal_result::EMPTY: {
running = false;
break;
}
case detail::equal_result::EQUAL: {
if constexpr (!AllowsDuplicates) { running = false; }
equals[i] = true;
break;
}
default: {
break;
}
}
}
}

if (active_flushing_tile.any(match_found)) {
auto const matching_tile = cg::binary_partition(active_flushing_tile, match_found);
// stage matches in shmem buffer
if (match_found) {
buffers[flushing_tile_id][num_matches + matching_tile.thread_rank()] = {
probe_key, bucket_slots[i]};
probing_tile.sync();
running = probing_tile.all(running);
#pragma unroll buffer_size
for (int32_t i = 0; i < bucket_size; ++i) {
exists[i] = probing_tile.ballot(equals[i]);
}

if (thrust::any_of(thrust::seq, exists, exists + bucket_size, thrust::identity{})) {
if constexpr (IsOuter) { found_match = true; }

int32_t num_matches[bucket_size];

for (int32_t i = 0; i < bucket_size; ++i) {
num_matches[i] = __popc(exists[i]);
}

// add number of new matches to the buffer counter
num_matches += (match_found) ? matching_tile.size()
: active_flushing_tile.size() - matching_tile.size();
}
auto const total_matches =
thrust::reduce(thrust::seq, num_matches, num_matches + bucket_size);

if constexpr (IsOuter) {
if (not found_any_match /*yet*/ and probing_tile.any(match_found) /*now*/) {
found_any_match = true;
int32_t output_idx;
if (probing_tile.thread_rank() == 0) {
auto ref =
cuda::atomic_ref<int32_t, cuda::thread_scope_block>{counters[flushing_tile_id]};
output_idx = ref.fetch_add(total_matches, cuda::memory_order_relaxed);
}
output_idx = probing_tile.shfl(output_idx, 0);

int32_t matche_offset = 0;
#pragma unroll buffer_size
for (int32_t i = 0; i < bucket_size; ++i) {
if (equals[i]) {
auto const lane_offset =
detail::count_least_significant_bits(exists[i], probing_tile.thread_rank());
buffers[flushing_tile_id][output_idx + matche_offset + lane_offset] = {
probe_key, bucket_slots[i]};
}
matche_offset += num_matches[i];
}
}

// reset flag for next iteration
match_found = false;
}
running = probing_tile.all(running);

// check if all probing tiles have finished their work
bool const finished = !active_flushing_tile.any(running);

if constexpr (IsOuter) {
if (finished) {
bool const writes_sentinel =
((probing_tile.thread_rank() == 0) and not found_any_match);

auto const sentinel_writers =
cg::binary_partition(active_flushing_tile, writes_sentinel);
if (writes_sentinel) {
auto const rank = sentinel_writers.thread_rank();
buffers[flushing_tile_id][num_matches + rank] = {probe_key,
this->empty_slot_sentinel()};
if constexpr (IsOuter) {
if (!running) {
if (!found_match and probing_tile.thread_rank() == 0) {
auto ref =
cuda::atomic_ref<int32_t, cuda::thread_scope_block>{counters[flushing_tile_id]};
auto const output_idx = ref.fetch_add(1, cuda::memory_order_relaxed);
buffers[flushing_tile_id][output_idx] = {probe_key, this->empty_slot_sentinel()};
}
}
// add number of new matches to the buffer counter
num_matches += (writes_sentinel)
? sentinel_writers.size()
: active_flushing_tile.size() - sentinel_writers.size();
}
}
} // if running

active_flushing_tile.sync();
// if the buffer does not have enough empty slots for the next iteration
if (num_matches > (buffer_size - max_matches_per_step)) {
if (counters[flushing_tile_id] > (buffer_size - max_matches_per_step)) {
flush_buffers(active_flushing_tile);
active_flushing_tile.sync();

// reset buffer counter
num_matches = 0;
if (active_flushing_tile.thread_rank() == 0) { counters[flushing_tile_id] = 0; }
active_flushing_tile.sync();
}

// the entire flushing tile has finished its work
if (finished) { break; }

// onto the next probing bucket
++probing_iter;
}
}
} // while running
} // if active_flag

// onto the next key
idx += stride;
}

flushing_tile.sync();
// entire flushing_tile has finished; flush remaining elements
if (num_matches > 0) { flush_buffers(flushing_tile); }
if (counters[flushing_tile_id] > 0) { flush_buffers(flushing_tile); }
}

/**
Expand Down

0 comments on commit edd129e

Please sign in to comment.