diff --git a/tests/static_map/shared_memory_test.cu b/tests/static_map/shared_memory_test.cu index 70e2def8d..ca730e08e 100644 --- a/tests/static_map/shared_memory_test.cu +++ b/tests/static_map/shared_memory_test.cu @@ -31,29 +31,30 @@ #include -template +/* +template __global__ void shared_memory_test_kernel( - typename MapType::device_view const* const device_views, - typename MapType::device_view::key_type const* const insterted_keys, - typename MapType::device_view::mapped_type const* const inserted_values, - const size_t number_of_elements, + Ref* maps, + typename Ref::key_type const* const insterted_keys, + typename Ref::mapped_type const* const inserted_values, bool* const keys_exist, bool* const keys_and_values_correct) { // Each block processes one map const size_t map_id = blockIdx.x; - const size_t offset = map_id * number_of_elements; + const size_t offset = map_id * maps[map_id].capacity(); - __shared__ typename MapType::pair_atomic_type sm_buffer[CAPACITY]; + __shared__ typename Ref::window_type sm_buffer[NumWindows]; auto g = cuco::test::cg::this_thread_block(); - typename MapType::device_view sm_device_view = - MapType::device_view::make_copy(g, sm_buffer, device_views[map_id]); + auto insert_ref = + maps[map_id].make_copy(g, sm_buffer); + auto find_ref = insert_ref.with(cuco::experimental::op::find); - for (int i = g.thread_rank(); i < number_of_elements; i += g.size()) { - auto found_pair_it = sm_device_view.find(insterted_keys[offset + i]); + for (int i = g.thread_rank(); i < maps[map_id].capacity(); i += g.size()) { + auto found_pair_it = find_ref.find(insterted_keys[offset + i]); - if (found_pair_it != sm_device_view.end()) { + if (found_pair_it != find_ref.end()) { keys_exist[offset + i] = true; if (found_pair_it->first == insterted_keys[offset + i] and found_pair_it->second == inserted_values[offset + i]) { @@ -76,13 +77,13 @@ TEMPLATE_TEST_CASE_SIG("Shared memory static map", (int64_t, int32_t), (int64_t, int64_t)) { - using MapType = cuco::static_map; - using DeviceViewType = typename MapType::device_view; - constexpr std::size_t number_of_maps = 1000; constexpr std::size_t elements_in_map = 500; constexpr std::size_t map_capacity = 2 * elements_in_map; + using extent_type = cuco::experimental::extent; + using map_type = cuco::experimental::static_map; + // one array for all maps, first elements_in_map element belong to map 0, second to map 1 and so // on thrust::device_vector d_keys(number_of_maps * elements_in_map); @@ -93,34 +94,38 @@ TEMPLATE_TEST_CASE_SIG("Shared memory static map", // using std::unique_ptr because static_map does not have copy/move constructor/assignment // operator yet - std::vector> maps; + std::vector> maps; for (std::size_t map_id = 0; map_id < number_of_maps; ++map_id) { - maps.push_back(std::make_unique( - map_capacity, cuco::empty_key{-1}, cuco::empty_value{-1})); + maps.push_back(std::make_unique( + extent_type{}, cuco::empty_key{-1}, cuco::empty_value{-1})); } thrust::device_vector d_keys_exist(number_of_maps * elements_in_map); thrust::device_vector d_keys_and_values_correct(number_of_maps * elements_in_map); + using ref_type = decltype(maps.front()->ref(cuco::experimental::op::insert)); + SECTION("Keys are all found after insertion.") { + auto pairs_begin = thrust::make_zip_iterator(thrust::make_tuple(d_keys.begin(), d_values.begin())); - std::vector h_device_views; + std::vector h_refs; for (std::size_t map_id = 0; map_id < number_of_maps; ++map_id) { const std::size_t offset = map_id * elements_in_map; - MapType* map = maps[map_id].get(); + map_type* map = maps[map_id].get(); map->insert(pairs_begin + offset, pairs_begin + offset + elements_in_map); - h_device_views.push_back(map->get_device_view()); + h_refs.push_back(map->ref(cuco::experimental::op::insert)); } - thrust::device_vector d_device_views(h_device_views); + thrust::device_vector d_refs(h_refs); + + auto constexpr num_windows = h_refs[0].window_extent(); - shared_memory_test_kernel - <<>>(d_device_views.data().get(), + shared_memory_test_kernel(num_windows), ref_type> + <<>>(d_refs.data().get(), d_keys.data().get(), d_values.data().get(), - elements_in_map, d_keys_exist.data().get(), d_keys_and_values_correct.data().get()); @@ -137,17 +142,18 @@ TEMPLATE_TEST_CASE_SIG("Shared memory static map", SECTION("No key is found before insertion.") { - std::vector h_device_views; + std::vector h_refs; for (std::size_t map_id = 0; map_id < number_of_maps; ++map_id) { - h_device_views.push_back(maps[map_id].get()->get_device_view()); + h_refs.push_back(maps[map_id].get()->ref(cuco::experimental::op::insert)); } - thrust::device_vector d_device_views(h_device_views); + thrust::device_vector d_refs(h_refs); - shared_memory_test_kernel - <<>>(d_device_views.data().get(), + auto constexpr num_windows = h_refs[0].window_extent(); + + shared_memory_test_kernel(num_windows), ref_type> + <<>>(d_refs.data().get(), d_keys.data().get(), d_values.data().get(), - elements_in_map, d_keys_exist.data().get(), d_keys_and_values_correct.data().get()); @@ -155,6 +161,7 @@ TEMPLATE_TEST_CASE_SIG("Shared memory static map", } } +/* template __global__ void shared_memory_hash_table_kernel(bool* key_found) { @@ -188,3 +195,4 @@ TEMPLATE_TEST_CASE("Shared memory slots.", "", int32_t) REQUIRE(cuco::test::all_of(key_found.begin(), key_found.end(), thrust::identity{})); } +*/