From d295ecd132d6379beca8f7d6acc3ad3c471b8017 Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Mon, 13 Nov 2023 15:46:07 -0800 Subject: [PATCH] Add CG insert_and_find tests --- .../open_addressing_ref_impl.cuh | 34 +++++++ tests/static_map/insert_and_find_test.cu | 96 +++++++++++++------ 2 files changed, 101 insertions(+), 29 deletions(-) diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 5b44fa066..b7880ca8d 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -385,6 +385,12 @@ class open_addressing_ref_impl { __device__ thrust::pair insert_and_find(Value const& value) noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); +#if __CUDA_ARCH__ < 700 + // Spinning to ensure that the write to the value part took place requires + // independent thread scheduling introduced with the Volta architecture.
+ static_assert(cuco::detail::is_packable(), + "insert_and_find is not supported for unpackable data on pre-Volta GPUs."); +#endif auto const key = this->extract_key(value); auto probing_iter = probing_scheme_(key, storage_ref_.window_extent()); @@ -399,6 +405,7 @@ class open_addressing_ref_impl { // If the key is already in the container, return false if (eq_res == detail::equal_result::EQUAL) { if constexpr (has_payload) { + // wait to ensure that the write to the value part also took place this->wait_for_payload((window_ptr + i)->second, this->empty_slot_sentinel_.second); } return {iterator{&window_ptr[i]}, false}; @@ -418,12 +425,14 @@ class open_addressing_ref_impl { switch (res) { case insert_result::SUCCESS: { if constexpr (has_payload) { + // wait to ensure that the write to the value part also took place this->wait_for_payload((window_ptr + i)->second, this->empty_slot_sentinel_.second); } return {iterator{&window_ptr[i]}, true}; } case insert_result::DUPLICATE: { if constexpr (has_payload) { + // wait to ensure that the write to the value part also took place this->wait_for_payload((window_ptr + i)->second, this->empty_slot_sentinel_.second); } return {iterator{&window_ptr[i]}, false}; @@ -455,6 +464,13 @@ class open_addressing_ref_impl { __device__ thrust::pair insert_and_find( cooperative_groups::thread_block_tile const& group, Value const& value) noexcept { +#if __CUDA_ARCH__ < 700 + // Spinning to ensure that the write to the value part took place requires + // independent thread scheduling introduced with the Volta architecture.
+ static_assert(cuco::detail::is_packable(), + "insert_and_find is not supported for unpackable data on pre-Volta GPUs."); +#endif + auto const key = this->extract_key(value); auto probing_iter = probing_scheme_(group, key, storage_ref_.window_extent()); @@ -489,6 +505,12 @@ class open_addressing_ref_impl { if (group_finds_equal) { auto const src_lane = __ffs(group_finds_equal) - 1; auto const res = group.shfl(reinterpret_cast(slot_ptr), src_lane); + if (group.thread_rank() == src_lane) { + if constexpr (has_payload) { + // wait to ensure that the write to the value part also took place + this->wait_for_payload(slot_ptr->second, this->empty_slot_sentinel_.second); + } + } return {iterator{reinterpret_cast(res)}, false}; } @@ -508,9 +530,21 @@ class open_addressing_ref_impl { switch (group.shfl(status, src_lane)) { case insert_result::SUCCESS: { + if (group.thread_rank() == src_lane) { + if constexpr (has_payload) { + // wait to ensure that the write to the value part also took place + this->wait_for_payload(slot_ptr->second, this->empty_slot_sentinel_.second); + } + } return {iterator{reinterpret_cast(res)}, true}; } case insert_result::DUPLICATE: { + if (group.thread_rank() == src_lane) { + if constexpr (has_payload) { + // wait to ensure that the write to the value part also took place + this->wait_for_payload(slot_ptr->second, this->empty_slot_sentinel_.second); + } + } return {iterator{reinterpret_cast(res)}, false}; } default: continue; diff --git a/tests/static_map/insert_and_find_test.cu b/tests/static_map/insert_and_find_test.cu index 8e6cddfdc..6d0e2567b 100644 --- a/tests/static_map/insert_and_find_test.cu +++ b/tests/static_map/insert_and_find_test.cu @@ -26,7 +26,7 @@ #include -static constexpr int Iters = 10; +static constexpr int Iters = 10'000; template __global__ void parallel_sum(Ref v) @@ -36,15 +36,27 @@ __global__ void parallel_sum(Ref v) if constexpr (cuco::detail::is_packable()) #endif { - auto [iter, inserted] =
v.insert_and_find(cuco::pair{i, 1}); - // for debugging... - // if (iter->second < 0) { - // asm("trap;"); - // } - if (!inserted) { - auto ref = - cuda::atomic_ref{iter->second}; - ref.fetch_add(1); + auto constexpr cg_size = Ref::cg_size; + if constexpr (cg_size == 1) { + auto [iter, inserted] = v.insert_and_find(cuco::pair{i, 1}); + // for debugging... + // if (iter->second < 0) { + // asm("trap;"); + // } + if (!inserted) { + auto ref = + cuda::atomic_ref{iter->second}; + ref.fetch_add(1); + } + } else { + auto const tile = + cooperative_groups::tiled_partition(cooperative_groups::this_thread_block()); + auto [iter, inserted] = v.insert_and_find(tile, cuco::pair{i, 1}); + if (!inserted and tile.thread_rank() == 0) { + auto ref = + cuda::atomic_ref{iter->second}; + ref.fetch_add(1); + } } } #if __CUDA_ARCH__ < 700 @@ -55,36 +67,62 @@ __global__ void parallel_sum(Ref v) } } -TEMPLATE_TEST_CASE_SIG("Parallel insert-or-update", - "", - ((typename Key, typename Value), Key, Value), - (int32_t, int32_t), - (int32_t, int64_t), - (int64_t, int32_t), - (int64_t, int64_t)) +TEMPLATE_TEST_CASE_SIG( + "static_map insert_and_find tests", + "", + ((typename Key, typename Value, cuco::test::probe_sequence Probe, int CGSize), + Key, + Value, + Probe, + CGSize), + (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), + (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), + (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), + (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), + (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), + (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), + (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), + (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), + (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), + (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int32_t, int32_t,
cuco::test::probe_sequence::linear_probing, 2), + (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2), + (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), + (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), + (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2)) { - cuco::empty_key empty_key_sentinel{-1}; - cuco::empty_value empty_value_sentinel{-1}; - auto m = cuco::experimental::static_map{ - 10 * Iters, - empty_key_sentinel, - empty_value_sentinel, - thrust::equal_to{}, - cuco::experimental::linear_probing<1, cuco::murmurhash3_32>{}}; + using probe = + std::conditional_t>, + cuco::experimental::double_hashing, + cuco::murmurhash3_32>>; + + auto map = cuco::experimental::static_map, + cuda::thread_scope_device, + thrust::equal_to, + probe, + cuco::cuda_allocator, + cuco::experimental::storage<2>>{ + 10 * Iters, cuco::empty_key{-1}, cuco::empty_value{-1}}; static constexpr int Blocks = 1024; static constexpr int Threads = 128; parallel_sum<<>>( - m.ref(cuco::experimental::op::insert, cuco::experimental::op::insert_and_find)); + map.ref(cuco::experimental::op::insert, cuco::experimental::op::insert_and_find)); CUCO_CUDA_TRY(cudaDeviceSynchronize()); thrust::device_vector d_keys(Iters); thrust::device_vector d_values(Iters); thrust::sequence(thrust::device, d_keys.begin(), d_keys.end()); - m.find(d_keys.begin(), d_keys.end(), d_values.begin()); + map.find(d_keys.begin(), d_keys.end(), d_values.begin()); - REQUIRE(cuco::test::all_of( - d_values.begin(), d_values.end(), [] __device__(Value v) { return v == Blocks * Threads; })); + REQUIRE(cuco::test::all_of(d_values.begin(), d_values.end(), [] __device__(Value v) { + return v == (Blocks * Threads) / CGSize; + })); }