diff --git a/tests/static_map/shared_memory_test.cu b/tests/static_map/shared_memory_test.cu index 520f5179a..91a1b52a6 100644 --- a/tests/static_map/shared_memory_test.cu +++ b/tests/static_map/shared_memory_test.cu @@ -204,12 +204,14 @@ __global__ void shared_memory_hash_table_kernel(bool* key_found) auto find_ref = std::move(insert_ref).with(cuco::experimental::op::find); auto const retrieved_pair = find_ref.find(rank); + block.sync(); + if (retrieved_pair != find_ref.end() && retrieved_pair->second == rank) { key_found[index] = true; } } -TEMPLATE_TEST_CASE("Shared memory slots.", "", int32_t) +TEMPLATE_TEST_CASE("static map shared memory slots.", "", int32_t) { constexpr std::size_t N = 256; auto constexpr num_windows = cuco::experimental::make_window_extent( @@ -218,6 +220,7 @@ TEMPLATE_TEST_CASE("Shared memory slots.", "", int32_t) thrust::device_vector key_found(N, false); shared_memory_hash_table_kernel <<<8, 32>>>(key_found.data().get()); + CUCO_CUDA_TRY(cudaDeviceSynchronize()); REQUIRE(cuco::test::all_of(key_found.begin(), key_found.end(), thrust::identity{})); }