From 996654185c9704d18c606b27f9380cf5547badf4 Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Fri, 17 Nov 2023 16:02:31 -0800 Subject: [PATCH 1/3] Fix bugs in `insert_and_find` (#389) This PR fixes bugs in the new `insert_and_find` implementation where the function should not return before the payload is updated. It also migrates the related tests to test the new map. --------- Co-authored-by: Daniel Juenger <2955913+sleeepyjack@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../open_addressing_ref_impl.cuh | 79 ++++++++++++- tests/static_map/insert_and_find_test.cu | 104 ++++++++++++++---- 2 files changed, 156 insertions(+), 27 deletions(-) diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 683bf94b1..26e5a055a 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -385,6 +385,13 @@ class open_addressing_ref_impl { __device__ thrust::pair insert_and_find(Value const& value) noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); +#if __CUDA_ARCH__ < 700 + // Spinning to ensure that the write to the value part took place requires + // independent thread scheduling introduced with the Volta architecture. + static_assert( + cuco::detail::is_packable(), + "insert_and_find is not supported for pair types larger than 8 bytes on pre-Volta GPUs."); +#endif auto const key = this->extract_key(value); auto probing_iter = probing_scheme_(key, storage_ref_.window_extent()); @@ -397,21 +404,36 @@ class open_addressing_ref_impl { auto* window_ptr = (storage_ref_.data() + *probing_iter)->data(); // If the key is already in the container, return false - if (eq_res == detail::equal_result::EQUAL) { return {iterator{&window_ptr[i]}, false}; } + if (eq_res == detail::equal_result::EQUAL) { + if constexpr (has_payload) { + // wait to ensure that the write to the value part also took place + this->wait_for_payload((window_ptr + i)->second, this->empty_slot_sentinel_.second); + } + return {iterator{&window_ptr[i]}, false}; + } if (eq_res == detail::equal_result::EMPTY or cuco::detail::bitwise_compare(this->extract_key(window_slots[i]), this->erased_key_sentinel())) { - switch ([&]() { + auto const res = [&]() { if constexpr (sizeof(value_type) <= 8) { return packed_cas(window_ptr + i, window_slots[i], value); } else { return cas_dependent_write(window_ptr + i, window_slots[i], value); } - }()) { + }(); + switch (res) { case insert_result::SUCCESS: { + if constexpr (has_payload) { + // wait to ensure that the write to the value part also took place + this->wait_for_payload((window_ptr + i)->second, this->empty_slot_sentinel_.second); + } return {iterator{&window_ptr[i]}, true}; } case insert_result::DUPLICATE: { + if constexpr (has_payload) { + // wait to ensure that the write to the value part also took place + this->wait_for_payload((window_ptr + i)->second, this->empty_slot_sentinel_.second); + } return {iterator{&window_ptr[i]}, false}; } default: continue; @@ -441,6 +463,14 @@ class open_addressing_ref_impl { __device__ thrust::pair insert_and_find( cooperative_groups::thread_block_tile const& group, Value const& value) noexcept { +#if __CUDA_ARCH__ < 700 + // Spinning to ensure that the write to the value part took place requires + // independent thread scheduling introduced with the Volta architecture. + static_assert( + cuco::detail::is_packable(), + "insert_and_find is not supported for pair types larger than 8 bytes on pre-Volta GPUs."); +#endif + auto const key = this->extract_key(value); auto probing_iter = probing_scheme_(group, key, storage_ref_.window_extent()); @@ -475,6 +505,13 @@ class open_addressing_ref_impl { if (group_finds_equal) { auto const src_lane = __ffs(group_finds_equal) - 1; auto const res = group.shfl(reinterpret_cast(slot_ptr), src_lane); + if (group.thread_rank() == src_lane) { + if constexpr (has_payload) { + // wait to ensure that the write to the value part also took place + this->wait_for_payload(slot_ptr->second, this->empty_slot_sentinel_.second); + } + } + group.sync(); return {iterator{reinterpret_cast(res)}, false}; } @@ -494,9 +531,23 @@ class open_addressing_ref_impl { switch (group.shfl(status, src_lane)) { case insert_result::SUCCESS: { + if (group.thread_rank() == src_lane) { + if constexpr (has_payload) { + // wait to ensure that the write to the value part also took place + this->wait_for_payload(slot_ptr->second, this->empty_slot_sentinel_.second); + } + } + group.sync(); return {iterator{reinterpret_cast(res)}, true}; } case insert_result::DUPLICATE: { + if (group.thread_rank() == src_lane) { + if constexpr (has_payload) { + // wait to ensure that the write to the value part also took place + this->wait_for_payload(slot_ptr->second, this->empty_slot_sentinel_.second); + } + } + group.sync(); return {iterator{reinterpret_cast(res)}, false}; } default: continue; @@ -1010,6 +1061,7 @@ class open_addressing_ref_impl { // if key success if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) { atomic_store(&address->second, static_cast(thrust::get<1>(desired))); + return insert_result::SUCCESS; } @@ -1054,6 +1106,27 @@ class open_addressing_ref_impl { } } + /** + * @brief Waits until the slot payload has been updated + * + * @note The function will return once the slot payload is no longer equal to the sentinel value. + * + * @tparam T Map slot type + * + * @param slot The target slot to check payload with + * @param sentinel The slot sentinel value + */ + template + __device__ void wait_for_payload(T& slot, T const& sentinel) const noexcept + { + auto ref = cuda::atomic_ref{slot}; + T current; + // TODO exponential backoff strategy + do { + current = ref.load(cuda::std::memory_order_relaxed); + } while (cuco::detail::bitwise_compare(current, sentinel)); + } + // TODO: Clean up the sentinel handling since it's duplicated in ref and equal wrapper value_type empty_slot_sentinel_; ///< Sentinel value indicating an empty slot detail::equal_wrapper predicate_; ///< Key equality binary callable diff --git a/tests/static_map/insert_and_find_test.cu b/tests/static_map/insert_and_find_test.cu index 5784f786f..3afc27b9a 100644 --- a/tests/static_map/insert_and_find_test.cu +++ b/tests/static_map/insert_and_find_test.cu @@ -28,52 +28,108 @@ static constexpr int Iters = 10'000; -template -__global__ void parallel_sum(View v) +template +__global__ void parallel_sum(Ref v) { for (int i = 0; i < Iters; i++) { #if __CUDA_ARCH__ < 700 - if constexpr (cuco::detail::is_packable()) + if constexpr (cuco::detail::is_packable()) #endif { - auto [iter, inserted] = v.insert_and_find(thrust::make_pair(i, 1)); - // for debugging... - // if (iter->second < 0) { - // asm("trap;"); - // } - if (!inserted) { iter->second += 1; } + auto constexpr cg_size = Ref::cg_size; + if constexpr (cg_size == 1) { + auto [iter, inserted] = v.insert_and_find(cuco::pair{i, 1}); + // for debugging... + // if (iter->second < 0) { + // asm("trap;"); + // } + if (!inserted) { + auto ref = + cuda::atomic_ref{iter->second}; + ref.fetch_add(1); + } + } else { + auto const tile = + cooperative_groups::tiled_partition(cooperative_groups::this_thread_block()); + auto [iter, inserted] = v.insert_and_find(tile, cuco::pair{i, 1}); + if (!inserted and tile.thread_rank() == 0) { + auto ref = + cuda::atomic_ref{iter->second}; + ref.fetch_add(1); + } + } } #if __CUDA_ARCH__ < 700 else { - v.insert(thrust::make_pair(i, gridDim.x * blockDim.x)); + auto constexpr cg_size = Ref::cg_size; + if constexpr (cg_size == 1) { + v.insert(cuco::pair{i, gridDim.x * blockDim.x}); + } else { + auto const tile = + cooperative_groups::tiled_partition(cooperative_groups::this_thread_block()); + v.insert(tile, cuco::pair{i, gridDim.x * blockDim.x / cg_size}); + } } #endif } } -TEMPLATE_TEST_CASE_SIG("Parallel insert-or-update", - "", - ((typename Key, typename Value), Key, Value), - (int32_t, int32_t), - (int32_t, int64_t), - (int64_t, int32_t), - (int64_t, int64_t)) +TEMPLATE_TEST_CASE_SIG( + "static_map insert_and_find tests", + "", + ((typename Key, typename Value, cuco::test::probe_sequence Probe, int CGSize), + Key, + Value, + Probe, + CGSize), + (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), + (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), + (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), + (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), + (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), + (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), + (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), + (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), + (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), + (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), + (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2), + (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), + (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), + (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2)) { - cuco::empty_key empty_key_sentinel{-1}; - cuco::empty_value empty_value_sentinel{-1}; - cuco::static_map m(10 * Iters, empty_key_sentinel, empty_value_sentinel); + using probe = + std::conditional_t>, + cuco::experimental::double_hashing, + cuco::murmurhash3_32>>; + + auto map = cuco::experimental::static_map, + cuda::thread_scope_device, + thrust::equal_to, + probe, + cuco::cuda_allocator, + cuco::experimental::storage<2>>{ + 10 * Iters, cuco::empty_key{-1}, cuco::empty_value{-1}}; static constexpr int Blocks = 1024; static constexpr int Threads = 128; - parallel_sum<<>>(m.get_device_mutable_view()); + + parallel_sum<<>>( + map.ref(cuco::experimental::op::insert, cuco::experimental::op::insert_and_find)); CUCO_CUDA_TRY(cudaDeviceSynchronize()); thrust::device_vector d_keys(Iters); thrust::device_vector d_values(Iters); thrust::sequence(thrust::device, d_keys.begin(), d_keys.end()); - m.find(d_keys.begin(), d_keys.end(), d_values.begin()); + map.find(d_keys.begin(), d_keys.end(), d_values.begin()); - REQUIRE(cuco::test::all_of( - d_values.begin(), d_values.end(), [] __device__(Value v) { return v == Blocks * Threads; })); + REQUIRE(cuco::test::all_of(d_values.begin(), d_values.end(), [] __device__(Value v) { + return v == (Blocks * Threads) / CGSize; + })); } From c5f94e5d4c11402e442b782d8d3d23c50b90443d Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Sat, 18 Nov 2023 20:48:14 -0800 Subject: [PATCH 2/3] Fix potential logic issues in CG erase (#396) --- .../open_addressing/open_addressing_ref_impl.cuh | 10 +++++----- tests/static_map/erase_test.cu | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 26e5a055a..86d072a8a 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -646,12 +646,12 @@ class open_addressing_ref_impl { case insert_result::DUPLICATE: return false; default: continue; } - } else if (group.any(state == detail::equal_result::EMPTY)) { - // Key doesn't exist, return false - return false; - } else { - ++probing_iter; } + + // Key doesn't exist, return false + if (group.any(state == detail::equal_result::EMPTY)) { return false; } + + ++probing_iter; } } diff --git a/tests/static_map/erase_test.cu b/tests/static_map/erase_test.cu index 5e410c5cc..aab0df6d8 100644 --- a/tests/static_map/erase_test.cu +++ b/tests/static_map/erase_test.cu @@ -106,7 +106,7 @@ TEMPLATE_TEST_CASE_SIG( (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2)) { - constexpr size_type num_keys{400}; + constexpr size_type num_keys{1'000'000}; using probe = std::conditional_t Date: Mon, 20 Nov 2023 15:43:26 -0800 Subject: [PATCH 3/3] Migrate tests to excercise the new map (#393) This PR migrates static map tests to test the new static map: - Large type tests are not migrated since the new map doesn't support keys larger than 8 bytes - Shared memory tests require related functions to be added into ref code thus would be in a separate PR --- tests/static_map/custom_type_test.cu | 59 ++++--- tests/static_map/duplicate_keys_test.cu | 59 +++++-- tests/static_map/heterogeneous_lookup_test.cu | 7 +- tests/static_map/insert_and_find_test.cu | 9 +- tests/static_map/insert_or_assign_test.cu | 15 +- tests/static_map/key_sentinel_test.cu | 53 ++++--- tests/static_map/shared_memory_test.cu | 15 +- tests/static_map/stream_test.cu | 46 +++--- tests/static_map/unique_sequence_test.cu | 148 +++--------------- 9 files changed, 187 insertions(+), 224 deletions(-) diff --git a/tests/static_map/custom_type_test.cu b/tests/static_map/custom_type_test.cu index e23216ca3..536c83194 100644 --- a/tests/static_map/custom_type_test.cu +++ b/tests/static_map/custom_type_test.cu @@ -27,6 +27,8 @@ #include +#include + #include // User-defined key type @@ -123,17 +125,18 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type", thrust::counting_iterator(0), thrust::counting_iterator(num), insert_keys.begin(), - [] __device__(auto i) { return Key{i}; }); + cuda::proclaim_return_type([] __device__(auto i) { return Key{i}; })); thrust::transform(thrust::device, thrust::counting_iterator(0), thrust::counting_iterator(num), insert_values.begin(), - [] __device__(auto i) { return Value{i}; }); + cuda::proclaim_return_type([] __device__(auto i) { return Value{i}; })); - auto insert_pairs = - thrust::make_transform_iterator(thrust::make_counting_iterator(0), - [] __device__(auto i) { return cuco::pair(i, i); }); + auto insert_pairs = thrust::make_transform_iterator( + thrust::make_counting_iterator(0), + cuda::proclaim_return_type>( + [] __device__(auto i) { return cuco::pair(i, i); })); SECTION("All inserted keys-value pairs should be correctly recovered during find") { @@ -151,9 +154,9 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type", REQUIRE(cuco::test::equal(insert_values.begin(), insert_values.end(), found_values.begin(), - [] __device__(Value lhs, Value rhs) { + cuda::proclaim_return_type([] __device__(Value lhs, Value rhs) { return std::tie(lhs.f, lhs.s) == std::tie(rhs.f, rhs.s); - })); + }))); } SECTION("All inserted keys-value pairs should be contained") @@ -175,7 +178,7 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type", insert_pairs, insert_pairs + num, thrust::counting_iterator(0), - [] __device__(auto const& key) { return (key % 2) == 0; }, + cuda::proclaim_return_type([] __device__(auto const& key) { return (key % 2) == 0; }), hash_custom_key{}, custom_key_equals{}); @@ -187,12 +190,13 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type", hash_custom_key{}, custom_key_equals{}); - REQUIRE(cuco::test::equal(contained.begin(), - contained.end(), - thrust::counting_iterator(0), - [] __device__(auto const& idx_contained, auto const& idx) { - return ((idx % 2) == 0) == idx_contained; - })); + REQUIRE(cuco::test::equal( + contained.begin(), + contained.end(), + thrust::counting_iterator(0), + cuda::proclaim_return_type([] __device__(auto const& idx_contained, auto const& idx) { + return ((idx % 2) == 0) == idx_contained; + }))); } SECTION("Non-inserted keys-value pairs should not be contained") @@ -212,9 +216,11 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type", map.insert(insert_pairs, insert_pairs + num, hash_custom_key{}, custom_key_equals{}); auto view = map.get_device_view(); REQUIRE(cuco::test::all_of( - insert_pairs, insert_pairs + num, [view] __device__(cuco::pair const& pair) { + insert_pairs, + insert_pairs + num, + cuda::proclaim_return_type([view] __device__(cuco::pair const& pair) { return view.contains(pair.first, hash_custom_key{}, custom_key_equals{}); - })); + }))); } SECTION("Inserting unique keys should return insert success.") @@ -222,9 +228,11 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type", auto m_view = map.get_device_mutable_view(); REQUIRE(cuco::test::all_of(insert_pairs, insert_pairs + num, - [m_view] __device__(cuco::pair const& pair) mutable { - return m_view.insert(pair, hash_custom_key{}, custom_key_equals{}); - })); + cuda::proclaim_return_type( + [m_view] __device__(cuco::pair const& pair) mutable { + return m_view.insert( + pair, hash_custom_key{}, custom_key_equals{}); + }))); } SECTION("Cannot find any key in an empty hash map") @@ -235,18 +243,21 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type", REQUIRE(cuco::test::all_of( insert_pairs, insert_pairs + num, - [view] __device__(cuco::pair const& pair) mutable { - return view.find(pair.first, hash_custom_key{}, custom_key_equals{}) == view.end(); - })); + cuda::proclaim_return_type( + [view] __device__(cuco::pair const& pair) mutable { + return view.find(pair.first, hash_custom_key{}, custom_key_equals{}) == view.end(); + }))); } SECTION("const view") { auto const view = map.get_device_view(); REQUIRE(cuco::test::all_of( - insert_pairs, insert_pairs + num, [view] __device__(cuco::pair const& pair) { + insert_pairs, + insert_pairs + num, + cuda::proclaim_return_type([view] __device__(cuco::pair const& pair) { return view.find(pair.first, hash_custom_key{}, custom_key_equals{}) == view.end(); - })); + }))); } } } diff --git a/tests/static_map/duplicate_keys_test.cu b/tests/static_map/duplicate_keys_test.cu index 5620fa4e9..e17ec3af8 100644 --- a/tests/static_map/duplicate_keys_test.cu +++ b/tests/static_map/duplicate_keys_test.cu @@ -29,16 +29,52 @@ #include -TEMPLATE_TEST_CASE_SIG("Duplicate keys", - "", - ((typename Key, typename Value), Key, Value), - (int32_t, int32_t), - (int32_t, int64_t), - (int64_t, int32_t), - (int64_t, int64_t)) +#include + +using size_type = std::size_t; + +TEMPLATE_TEST_CASE_SIG( + "static_map duplicate keys", + "", + ((typename Key, typename Value, cuco::test::probe_sequence Probe, int CGSize), + Key, + Value, + Probe, + CGSize), + (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), + (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), + (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), + (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), + (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), + (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), + (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), + (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), + (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), + (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), + (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2), + (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), + (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), + (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2)) { - constexpr std::size_t num_keys{500'000}; - cuco::static_map map{ + constexpr size_type num_keys{500'000}; + + using probe = + std::conditional_t>, + cuco::experimental::double_hashing, + cuco::murmurhash3_32>>; + + auto map = cuco::experimental::static_map, + cuda::thread_scope_device, + thrust::equal_to, + probe, + cuco::cuda_allocator, + cuco::experimental::storage<2>>{ num_keys * 2, cuco::empty_key{-1}, cuco::empty_value{-1}}; thrust::device_vector d_keys(num_keys); @@ -49,7 +85,8 @@ TEMPLATE_TEST_CASE_SIG("Duplicate keys", auto pairs_begin = thrust::make_transform_iterator( thrust::make_counting_iterator(0), - [] __device__(auto i) { return cuco::pair(i / 2, i / 2); }); + cuda::proclaim_return_type>( + [] __device__(auto i) { return cuco::pair(i / 2, i / 2); })); thrust::device_vector d_results(num_keys); thrust::device_vector d_contained(num_keys); @@ -68,7 +105,7 @@ TEMPLATE_TEST_CASE_SIG("Duplicate keys", map.insert(pairs_begin, pairs_begin + num_keys); - auto const num_entries = map.get_size(); + auto const num_entries = map.size(); REQUIRE(num_entries == gold); auto [key_out_end, value_out_end] = diff --git a/tests/static_map/heterogeneous_lookup_test.cu b/tests/static_map/heterogeneous_lookup_test.cu index ed1ace9bd..f386d96a5 100644 --- a/tests/static_map/heterogeneous_lookup_test.cu +++ b/tests/static_map/heterogeneous_lookup_test.cu @@ -27,6 +27,8 @@ #include +#include + #include // insert key type @@ -115,8 +117,9 @@ TEMPLATE_TEST_CASE_SIG("Heterogeneous lookup", auto insert_pairs = thrust::make_transform_iterator( thrust::counting_iterator(0), [] __device__(auto i) { return cuco::pair(i, i); }); - auto probe_keys = thrust::make_transform_iterator(thrust::counting_iterator(0), - [] __device__(auto i) { return ProbeKey(i); }); + auto probe_keys = thrust::make_transform_iterator( + thrust::counting_iterator(0), + cuda::proclaim_return_type([] __device__(auto i) { return ProbeKey{i}; })); SECTION("All inserted keys-value pairs should be contained") { diff --git a/tests/static_map/insert_and_find_test.cu b/tests/static_map/insert_and_find_test.cu index 3afc27b9a..9941e46a6 100644 --- a/tests/static_map/insert_and_find_test.cu +++ b/tests/static_map/insert_and_find_test.cu @@ -26,6 +26,8 @@ #include +#include + static constexpr int Iters = 10'000; template @@ -129,7 +131,8 @@ TEMPLATE_TEST_CASE_SIG( thrust::sequence(thrust::device, d_keys.begin(), d_keys.end()); map.find(d_keys.begin(), d_keys.end(), d_values.begin()); - REQUIRE(cuco::test::all_of(d_values.begin(), d_values.end(), [] __device__(Value v) { - return v == (Blocks * Threads) / CGSize; - })); + REQUIRE(cuco::test::all_of( + d_values.begin(), d_values.end(), cuda::proclaim_return_type([] __device__(Value v) { + return v == (Blocks * Threads) / CGSize; + }))); } diff --git a/tests/static_map/insert_or_assign_test.cu b/tests/static_map/insert_or_assign_test.cu index 90c6553ce..4bca776f7 100644 --- a/tests/static_map/insert_or_assign_test.cu +++ b/tests/static_map/insert_or_assign_test.cu @@ -27,6 +27,8 @@ #include +#include + using size_type = std::size_t; template @@ -36,9 +38,11 @@ __inline__ void test_insert_or_assign(Map& map, size_type num_keys) using Value = typename Map::mapped_type; // Insert pairs - auto pairs_begin = - thrust::make_transform_iterator(thrust::counting_iterator(0), - [] __device__(auto i) { return cuco::pair(i, i); }); + auto pairs_begin = thrust::make_transform_iterator( + thrust::counting_iterator(0), + cuda::proclaim_return_type>([] __device__(auto i) { + return cuco::pair{i, i}; + })); auto const initial_size = map.insert(pairs_begin, pairs_begin + num_keys); REQUIRE(initial_size == num_keys); // all keys should be inserted @@ -58,8 +62,9 @@ __inline__ void test_insert_or_assign(Map& map, size_type num_keys) thrust::device_vector d_values(num_keys); map.retrieve_all(d_keys.begin(), d_values.begin()); - auto gold_values_begin = thrust::make_transform_iterator(thrust::counting_iterator(0), - [] __device__(auto i) { return i * 2; }); + auto gold_values_begin = thrust::make_transform_iterator( + thrust::counting_iterator(0), + cuda::proclaim_return_type([] __device__(auto i) { return i * 2; })); thrust::sort(thrust::device, d_values.begin(), d_values.end()); REQUIRE(cuco::test::equal( diff --git a/tests/static_map/key_sentinel_test.cu b/tests/static_map/key_sentinel_test.cu index 74a1badd1..dceaf6ec4 100644 --- a/tests/static_map/key_sentinel_test.cu +++ b/tests/static_map/key_sentinel_test.cu @@ -24,12 +24,14 @@ #include +#include + #define SIZE 10 __device__ int A[SIZE]; template struct custom_equals { - __device__ bool operator()(T lhs, T rhs) { return A[lhs] == A[rhs]; } + __device__ bool operator()(T lhs, T rhs) const { return A[lhs] == A[rhs]; } }; TEMPLATE_TEST_CASE_SIG( @@ -39,11 +41,15 @@ TEMPLATE_TEST_CASE_SIG( using Value = T; constexpr std::size_t num_keys{SIZE}; - cuco::static_map map{ - SIZE * 2, cuco::empty_key{-1}, cuco::empty_value{-1}}; + auto map = cuco::experimental::static_map{ + SIZE * 2, + cuco::empty_key{-1}, + cuco::empty_value{-1}, + custom_equals{}, + cuco::experimental::linear_probing<1, cuco::default_hash_function>{}}; - auto m_view = map.get_device_mutable_view(); - auto view = map.get_device_view(); + auto insert_ref = map.ref(cuco::experimental::op::insert); + auto find_ref = map.ref(cuco::experimental::op::find); int h_A[SIZE]; for (int i = 0; i < SIZE; i++) { @@ -51,34 +57,35 @@ TEMPLATE_TEST_CASE_SIG( } CUCO_CUDA_TRY(cudaMemcpyToSymbol(A, h_A, SIZE * sizeof(int))); - auto pairs_begin = - thrust::make_transform_iterator(thrust::make_counting_iterator(0), - [] __device__(auto i) { return cuco::pair(i, i); }); + auto pairs_begin = thrust::make_transform_iterator( + thrust::make_counting_iterator(0), + cuda::proclaim_return_type>( + [] __device__(auto i) { return cuco::pair(i, i); })); SECTION( "Tests of non-CG insert: The custom `key_equal` can never be used to compare against sentinel") { - REQUIRE(cuco::test::all_of(pairs_begin, - pairs_begin + num_keys, - [m_view] __device__(cuco::pair const& pair) mutable { - return m_view.insert( - pair, cuco::default_hash_function{}, custom_equals{}); - })); + REQUIRE( + cuco::test::all_of(pairs_begin, + pairs_begin + num_keys, + cuda::proclaim_return_type( + [insert_ref] __device__(cuco::pair const& pair) mutable { + return insert_ref.insert(pair); + }))); } SECTION( "Tests of CG insert: The custom `key_equal` can never be used to compare against sentinel") { - map.insert(pairs_begin, - pairs_begin + num_keys, - cuco::default_hash_function{}, - custom_equals{}); + map.insert(pairs_begin, pairs_begin + num_keys); // All keys inserted via custom `key_equal` should be found REQUIRE(cuco::test::all_of( - pairs_begin, pairs_begin + num_keys, [view] __device__(cuco::pair const& pair) { - auto const found = view.find(pair.first); - return (found != view.end()) and - (found->first.load() == pair.first and found->second.load() == pair.second); - })); + pairs_begin, + pairs_begin + num_keys, + cuda::proclaim_return_type([find_ref] __device__(cuco::pair const& pair) { + auto const found = find_ref.find(pair.first); + return (found != find_ref.end()) and + (found->first == pair.first and found->second == pair.second); + }))); } } diff --git a/tests/static_map/shared_memory_test.cu b/tests/static_map/shared_memory_test.cu index 444f1c7e7..70e2def8d 100644 --- a/tests/static_map/shared_memory_test.cu +++ b/tests/static_map/shared_memory_test.cu @@ -27,6 +27,8 @@ #include +#include + #include template @@ -74,9 +76,8 @@ TEMPLATE_TEST_CASE_SIG("Shared memory static map", (int64_t, int32_t), (int64_t, int64_t)) { - using MapType = cuco::static_map; - using DeviceViewType = typename MapType::device_view; - using DeviceViewIteratorType = typename DeviceViewType::iterator; + using MapType = cuco::static_map; + using DeviceViewType = typename MapType::device_view; constexpr std::size_t number_of_maps = 1000; constexpr std::size_t elements_in_map = 500; @@ -127,9 +128,11 @@ TEMPLATE_TEST_CASE_SIG("Shared memory static map", auto zip = thrust::make_zip_iterator( thrust::make_tuple(d_keys_exist.begin(), d_keys_and_values_correct.begin())); - REQUIRE(cuco::test::all_of(zip, zip + d_keys_exist.size(), [] __device__(auto const& z) { - return thrust::get<0>(z) and thrust::get<1>(z); - })); + REQUIRE(cuco::test::all_of(zip, + zip + d_keys_exist.size(), + cuda::proclaim_return_type([] __device__(auto const& z) { + return thrust::get<0>(z) and thrust::get<1>(z); + }))); } SECTION("No key is found before insertion.") diff --git a/tests/static_map/stream_test.cu b/tests/static_map/stream_test.cu index 6121cbd62..fe1b2ac65 100644 --- a/tests/static_map/stream_test.cu +++ b/tests/static_map/stream_test.cu @@ -29,7 +29,9 @@ #include -TEMPLATE_TEST_CASE_SIG("Unique sequence of keys on given stream", +#include + +TEMPLATE_TEST_CASE_SIG("static_map: unique sequence of keys on given stream", "", ((typename Key, typename Value), Key, Value), (int32_t, int32_t), @@ -41,11 +43,14 @@ TEMPLATE_TEST_CASE_SIG("Unique sequence of keys on given stream", CUCO_CUDA_TRY(cudaStreamCreate(&stream)); constexpr std::size_t num_keys{500'000}; - cuco::static_map map{1'000'000, - cuco::empty_key{-1}, - cuco::empty_value{-1}, - cuco::cuda_allocator{}, - stream}; + auto map = cuco::experimental::static_map{ + num_keys * 2, + cuco::empty_key{-1}, + cuco::empty_value{-1}, + thrust::equal_to{}, + cuco::experimental::linear_probing<1, cuco::default_hash_function>{}, + cuco::cuda_allocator{}, + stream}; thrust::device_vector d_keys(num_keys); thrust::device_vector d_values(num_keys); @@ -53,35 +58,34 @@ TEMPLATE_TEST_CASE_SIG("Unique sequence of keys on given stream", thrust::sequence(thrust::device, d_keys.begin(), d_keys.end()); thrust::sequence(thrust::device, d_values.begin(), d_values.end()); - auto pairs_begin = - thrust::make_transform_iterator(thrust::make_counting_iterator(0), - [] __device__(auto i) { return cuco::pair(i, i); }); - - auto hash_fn = cuco::default_hash_function{}; - auto equal_fn = thrust::equal_to{}; + auto pairs_begin = thrust::make_transform_iterator( + thrust::make_counting_iterator(0), + cuda::proclaim_return_type>( + [] __device__(auto i) { return cuco::pair(i, i); })); // bulk function test cases SECTION("All inserted keys-value pairs should be correctly recovered during find") { thrust::device_vector d_results(num_keys); - map.insert(pairs_begin, pairs_begin + num_keys, hash_fn, equal_fn, stream); - map.find(d_keys.begin(), d_keys.end(), d_results.begin(), hash_fn, equal_fn, stream); + map.insert(pairs_begin, pairs_begin + num_keys, stream); + map.find(d_keys.begin(), d_keys.end(), d_results.begin(), stream); auto zip = thrust::make_zip_iterator(thrust::make_tuple(d_results.begin(), d_values.begin())); - REQUIRE(cuco::test::all_of( - zip, - zip + num_keys, - [] __device__(auto const& p) { return thrust::get<0>(p) == thrust::get<1>(p); }, - stream)); + REQUIRE(cuco::test::all_of(zip, + zip + num_keys, + cuda::proclaim_return_type([] __device__(auto const& p) { + return thrust::get<0>(p) == thrust::get<1>(p); + }), + stream)); } SECTION("All inserted keys-value pairs should be contained") { thrust::device_vector d_contained(num_keys); - map.insert(pairs_begin, pairs_begin + num_keys, hash_fn, equal_fn, stream); - map.contains(d_keys.begin(), d_keys.end(), d_contained.begin(), hash_fn, equal_fn, stream); + map.insert(pairs_begin, pairs_begin + num_keys, stream); + map.contains(d_keys.begin(), d_keys.end(), d_contained.begin(), stream); REQUIRE(cuco::test::all_of(d_contained.begin(), d_contained.end(), thrust::identity{}, stream)); } diff --git a/tests/static_map/unique_sequence_test.cu b/tests/static_map/unique_sequence_test.cu index 6a0165cc2..69fa69fb0 100644 --- a/tests/static_map/unique_sequence_test.cu +++ b/tests/static_map/unique_sequence_test.cu @@ -31,122 +31,7 @@ #include -TEMPLATE_TEST_CASE_SIG("Unique sequence of keys", - "", - ((typename Key, typename Value), Key, Value), - (int32_t, int32_t), - (int32_t, int64_t), - (int64_t, int32_t), - (int64_t, int64_t)) -{ - constexpr std::size_t num_keys{500'000}; - cuco::static_map map{ - 1'000'000, cuco::empty_key{-1}, cuco::empty_value{-1}}; - - auto m_view = map.get_device_mutable_view(); - auto view = map.get_device_view(); - - thrust::device_vector d_keys(num_keys); - thrust::device_vector d_values(num_keys); - - thrust::sequence(thrust::device, d_keys.begin(), d_keys.end()); - thrust::sequence(thrust::device, d_values.begin(), d_values.end()); - - auto pairs_begin = - thrust::make_transform_iterator(thrust::make_counting_iterator(0), - [] __device__(auto i) { return cuco::pair(i, i); }); - - thrust::device_vector d_results(num_keys); - thrust::device_vector d_contained(num_keys); - - // bulk function test cases - SECTION("All inserted keys-value pairs should be correctly recovered during find") - { - map.insert(pairs_begin, pairs_begin + num_keys); - map.find(d_keys.begin(), d_keys.end(), d_results.begin()); - auto zip = thrust::make_zip_iterator(thrust::make_tuple(d_results.begin(), d_values.begin())); - - REQUIRE(cuco::test::all_of(zip, zip + num_keys, [] __device__(auto const& p) { - return thrust::get<0>(p) == thrust::get<1>(p); - })); - } - - SECTION("All inserted keys-value pairs should be contained") - { - map.insert(pairs_begin, pairs_begin + num_keys); - map.contains(d_keys.begin(), d_keys.end(), d_contained.begin()); - - REQUIRE(cuco::test::all_of(d_contained.begin(), d_contained.end(), thrust::identity{})); - } - - SECTION("Non-inserted keys-value pairs should not be contained") - { - map.contains(d_keys.begin(), d_keys.end(), d_contained.begin()); - - REQUIRE(cuco::test::none_of(d_contained.begin(), d_contained.end(), thrust::identity{})); - } - - SECTION("Inserting unique keys should return insert success.") - { - REQUIRE(cuco::test::all_of(pairs_begin, - pairs_begin + num_keys, - [m_view] __device__(cuco::pair const& pair) mutable { - return m_view.insert(pair); - })); - } - - SECTION("Cannot find any key in an empty hash map with non-const view") - { - SECTION("non-const view") - { - REQUIRE(cuco::test::all_of(pairs_begin, - pairs_begin + num_keys, - [view] __device__(cuco::pair const& pair) mutable { - return view.find(pair.first) == view.end(); - })); - } - SECTION("const view") - { - REQUIRE(cuco::test::all_of( - pairs_begin, pairs_begin + num_keys, [view] __device__(cuco::pair const& pair) { - return view.find(pair.first) == view.end(); - })); - } - } - - SECTION("Keys are all found after inserting many keys.") - { - // Bulk insert keys - thrust::for_each( - thrust::device, - pairs_begin, - pairs_begin + num_keys, - [m_view] __device__(cuco::pair const& pair) mutable { m_view.insert(pair); }); - - SECTION("non-const view") - { - // All keys should be found - REQUIRE(cuco::test::all_of(pairs_begin, - pairs_begin + num_keys, - [view] __device__(cuco::pair const& pair) mutable { - auto const found = view.find(pair.first); - return (found != view.end()) and - (found->first.load() == pair.first and - found->second.load() == pair.second); - })); - } - SECTION("const view") - { - // All keys should be found - REQUIRE(cuco::test::all_of( - pairs_begin, pairs_begin + num_keys, [view] __device__(cuco::pair const& pair) { - auto const found = view.find(pair.first); - return (found != view.end()) and - (found->first.load() == pair.first and found->second.load() == pair.second); - })); - } - } -} +#include using size_type = int32_t; @@ -160,14 +45,18 @@ __inline__ void test_unique_sequence(Map& map, size_type num_keys) thrust::sequence(thrust::device, d_keys.begin(), d_keys.end()); - auto keys_begin = d_keys.begin(); - auto pairs_begin = - thrust::make_transform_iterator(thrust::make_counting_iterator(0), - [] __device__(auto i) { return cuco::pair(i, i); }); + auto keys_begin = d_keys.begin(); + auto pairs_begin = thrust::make_transform_iterator( + thrust::make_counting_iterator(0), + cuda::proclaim_return_type>([] __device__(auto i) { + return cuco::pair{i, i}; + })); thrust::device_vector d_contained(num_keys); - auto zip_equal = [] __device__(auto const& p) { return thrust::get<0>(p) == thrust::get<1>(p); }; - auto is_even = [] __device__(auto const& i) { return i % 2 == 0; }; + auto zip_equal = cuda::proclaim_return_type( + [] __device__(auto const& p) { return thrust::get<0>(p) == thrust::get<1>(p); }); + auto is_even = + cuda::proclaim_return_type([] __device__(auto const& i) { return i % 2 == 0; }); SECTION("Non-inserted keys should not be contained.") { @@ -196,12 +85,13 @@ __inline__ void test_unique_sequence(Map& map, size_type num_keys) REQUIRE(map.size() == num_keys / 2); map.contains(keys_begin, keys_begin + num_keys, d_contained.begin()); - REQUIRE(cuco::test::equal(d_contained.begin(), - d_contained.end(), - thrust::counting_iterator(0), - [] __device__(auto const& idx_contained, auto const& idx) { - return ((idx % 2) == 0) == idx_contained; - })); + REQUIRE(cuco::test::equal( + d_contained.begin(), + d_contained.end(), + thrust::counting_iterator(0), + cuda::proclaim_return_type([] __device__(auto const& idx_contained, auto const& idx) { + return ((idx % 2) == 0) == idx_contained; + }))); } map.insert(pairs_begin, pairs_begin + num_keys); @@ -253,7 +143,7 @@ __inline__ void test_unique_sequence(Map& map, size_type num_keys) } TEMPLATE_TEST_CASE_SIG( - "Unique sequence", + "static_map: unique sequence", "", ((typename Key, typename Value, cuco::test::probe_sequence Probe, int CGSize), Key,