diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index a202dd424..43f0c7475 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -26,6 +26,9 @@ #include #include #include +#include +#include +#include #include #if defined(CUCO_HAS_CUDA_BARRIER) #include @@ -1016,7 +1019,7 @@ class open_addressing_ref_impl { InputProbeIt input_probe_end, OutputProbeIt output_probe, OutputMatchIt output_match, - AtomicCounter& atomic_counter) const + AtomicCounter* atomic_counter) const { auto constexpr is_outer = false; auto const n = cuco::detail::distance(input_probe_begin, input_probe_end); // TODO include @@ -1065,7 +1068,7 @@ class open_addressing_ref_impl { InputProbeIt input_probe_end, OutputProbeIt output_probe, OutputMatchIt output_match, - AtomicCounter& atomic_counter) const + AtomicCounter* atomic_counter) const { auto constexpr is_outer = true; auto const n = cuco::detail::distance(input_probe_begin, input_probe_end); // TODO include @@ -1116,7 +1119,7 @@ class open_addressing_ref_impl { cuco::detail::index_type n, OutputProbeIt output_probe, OutputMatchIt output_match, - AtomicCounter& atomic_counter) const + AtomicCounter* atomic_counter) const { namespace cg = cooperative_groups; @@ -1143,26 +1146,24 @@ class open_addressing_ref_impl { auto const stride = probing_tile.meta_group_size(); auto idx = probing_tile.meta_group_rank(); - // TODO align to 16B? __shared__ cuco::pair buffers[num_flushing_tiles][buffer_size]; - size_type num_matches = 0; + __shared__ int32_t counters[num_flushing_tiles]; + + if (flushing_tile.thread_rank() == 0) { counters[flushing_tile_id] = 0; } + flushing_tile.sync(); auto flush_buffers = [&](auto const& tile) { size_type offset = 0; - /* - if (tile.thread_rank() == 0) { - offset = atomic_counter.fetch_add(num_matches, cuda::std::memory_order_relaxed); - } - */ + auto const count = counters[flushing_tile_id]; + auto const rank = tile.thread_rank(); + if (rank == 0) { offset = atomic_counter->fetch_add(count, cuda::memory_order_relaxed); } offset = tile.shfl(offset, 0); - /* // flush_buffers - for (size_type i = rank; i < num_matches; i += tile.size()) { + for (auto i = rank; i < count; i += tile.size()) { *(output_probe + offset + i) = buffers[flushing_tile_id][i].first; *(output_match + offset + i) = buffers[flushing_tile_id][i].second; } - */ }; while (flushing_tile.any(idx < n)) { @@ -1176,102 +1177,116 @@ class open_addressing_ref_impl { auto const& probe_key = *(input_probe + idx); auto probing_iter = this->probing_scheme_(probing_tile, probe_key, this->storage_ref_.bucket_extent()); - bool running = true; - bool match_found = false; - [[maybe_unused]] bool found_any_match = false; // only needed if `IsOuter == true` - - while (true) { - // TODO atomic_ref::load if insert operator is present - auto const bucket_slots = this->storage_ref_[*probing_iter]; - - for (int32_t i = 0; i < bucket_size; ++i) { - if (running) { - // inspect slot content - switch (this->predicate_.operator()( - probe_key, this->extract_key(bucket_slots[i]))) { - case detail::equal_result::EMPTY: { - running = false; - break; - } - case detail::equal_result::EQUAL: { - if constexpr (!AllowsDuplicates) { running = false; } - match_found = true; - break; - } - default: { - break; + + bool running = true; + [[maybe_unused]] bool found_match = false; + + bool equals[buffer_size]; + uint32_t exists[buffer_size]; + + while (active_flushing_tile.any(running)) { + if (running) { + // TODO atomic_ref::load if insert operator is present + auto const bucket_slots = this->storage_ref_[*probing_iter]; + +#pragma unroll buffer_size + for (int32_t i = 0; i < bucket_size; ++i) { + equals[i] = false; + if (running) { + // inspect slot content + switch (this->predicate_.operator()( + probe_key, this->extract_key(bucket_slots[i]))) { + case detail::equal_result::EMPTY: { + running = false; + break; + } + case detail::equal_result::EQUAL: { + if constexpr (!AllowsDuplicates) { running = false; } + equals[i] = true; + break; + } + default: { + break; + } } } } - if (active_flushing_tile.any(match_found)) { - auto const matching_tile = cg::binary_partition(active_flushing_tile, match_found); - // stage matches in shmem buffer - if (match_found) { - buffers[flushing_tile_id][num_matches + matching_tile.thread_rank()] = { - probe_key, bucket_slots[i]}; + probing_tile.sync(); + running = probing_tile.all(running); +#pragma unroll buffer_size + for (int32_t i = 0; i < bucket_size; ++i) { + exists[i] = probing_tile.ballot(equals[i]); + } + + if (thrust::any_of(thrust::seq, exists, exists + bucket_size, thrust::identity{})) { + if constexpr (IsOuter) { found_match = true; } + + int32_t num_matches[bucket_size]; + + for (int32_t i = 0; i < bucket_size; ++i) { + num_matches[i] = __popc(exists[i]); } - // add number of new matches to the buffer counter - num_matches += (match_found) ? matching_tile.size() - : active_flushing_tile.size() - matching_tile.size(); - } + auto const total_matches = + thrust::reduce(thrust::seq, num_matches, num_matches + bucket_size); - if constexpr (IsOuter) { - if (not found_any_match /*yet*/ and probing_tile.any(match_found) /*now*/) { - found_any_match = true; + int32_t output_idx; + if (probing_tile.thread_rank() == 0) { + auto ref = + cuda::atomic_ref{counters[flushing_tile_id]}; + output_idx = ref.fetch_add(total_matches, cuda::memory_order_relaxed); + } + output_idx = probing_tile.shfl(output_idx, 0); + + int32_t matche_offset = 0; +#pragma unroll buffer_size + for (int32_t i = 0; i < bucket_size; ++i) { + if (equals[i]) { + auto const lane_offset = + detail::count_least_significant_bits(exists[i], probing_tile.thread_rank()); + buffers[flushing_tile_id][output_idx + matche_offset + lane_offset] = { + probe_key, bucket_slots[i]}; + } + matche_offset += num_matches[i]; } } - // reset flag for next iteration - match_found = false; - } - running = probing_tile.all(running); - - // check if all probing tiles have finished their work - bool const finished = !active_flushing_tile.any(running); - - if constexpr (IsOuter) { - if (finished) { - bool const writes_sentinel = - ((probing_tile.thread_rank() == 0) and not found_any_match); - - auto const sentinel_writers = - cg::binary_partition(active_flushing_tile, writes_sentinel); - if (writes_sentinel) { - auto const rank = sentinel_writers.thread_rank(); - buffers[flushing_tile_id][num_matches + rank] = {probe_key, - this->empty_slot_sentinel()}; + if constexpr (IsOuter) { + if (!running) { + if (!found_match and probing_tile.thread_rank() == 0) { + auto ref = + cuda::atomic_ref{counters[flushing_tile_id]}; + auto const output_idx = ref.fetch_add(1, cuda::memory_order_relaxed); + buffers[flushing_tile_id][output_idx] = {probe_key, this->empty_slot_sentinel()}; + } } - // add number of new matches to the buffer counter - num_matches += (writes_sentinel) - ? sentinel_writers.size() - : active_flushing_tile.size() - sentinel_writers.size(); } - } + } // if running + active_flushing_tile.sync(); // if the buffer has not enough empty slots for the next iteration - if (num_matches > (buffer_size - max_matches_per_step)) { + if (counters[flushing_tile_id] > (buffer_size - max_matches_per_step)) { flush_buffers(active_flushing_tile); + active_flushing_tile.sync(); // reset buffer counter - num_matches = 0; + if (active_flushing_tile.thread_rank() == 0) { counters[flushing_tile_id] = 0; } + active_flushing_tile.sync(); } - // the entire flushing tile has finished its work - if (finished) { break; } - // onto the next probing bucket ++probing_iter; - } - } + } // while running + } // if active_flag // onto the next key idx += stride; } + flushing_tile.sync(); // entire flusing_tile has finished; flush remaining elements - if (num_matches > 0) { flush_buffers(flushing_tile); } + if (counters[flushing_tile_id] > 0) { flush_buffers(flushing_tile); } } /**