Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel committed Dec 12, 2023
1 parent 40d39b1 commit e71699f
Showing 1 changed file with 33 additions and 10 deletions.
43 changes: 33 additions & 10 deletions tests/static_map/shared_memory_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ TEMPLATE_TEST_CASE_SIG("Shared memory static map",

auto constexpr num_windows = cuco::experimental::make_window_extent<ref_type>(extent_type{});

shared_memory_test_kernel<static_cast<std::size_t>(num_windows), ref_type>
shared_memory_test_kernel<num_windows.value(), ref_type>
<<<number_of_maps, 64>>>(d_refs.data().get(),
d_keys.data().get(),
d_values.data().get(),
Expand Down Expand Up @@ -154,7 +154,7 @@ TEMPLATE_TEST_CASE_SIG("Shared memory static map",

auto constexpr num_windows = cuco::experimental::make_window_extent<ref_type>(extent_type{});

shared_memory_test_kernel<static_cast<std::size_t>(num_windows), ref_type>
shared_memory_test_kernel<num_windows.value(), ref_type>
<<<number_of_maps, 64>>>(d_refs.data().get(),
d_keys.data().get(),
d_values.data().get(),
Expand All @@ -166,14 +166,33 @@ TEMPLATE_TEST_CASE_SIG("Shared memory static map",
}
}

/*
template <typename K, typename V, std::size_t N>
auto constexpr cg_size = 1;
auto constexpr window_size = 1;

template <typename Key, typename Value, std::size_t NumWindows>
__global__ void shared_memory_hash_table_kernel(bool* key_found)
{
using slot_type = cuco::pair<Key, Value>;
__shared__ cuco::experimental::window<slot_type, window_size> windows[NumWindows];

using extent_type = cuco::experimental::extent<std::size_t, NumWindows>;
using storage_ref_type = cuco::experimental::aow_storage_ref<slot_type, window_size, extent_type>;

auto raw_ref = cuco::experimental::static_map_ref<
Key,
Value,
cuda::thread_scope_device,
thrust::equal_to<Key>,
cuco::experimental::linear_probing<cg_size, cuco::default_hash_function<Key>>,
storage_ref_type>{cuco::empty_key<Key>{-1},
cuco::empty_value<Value>{-1},
{},
{},
storage_ref_type{extent_type{}, windows}};
/*
namespace cg = cooperative_groups;
using map_type = typename cuco::static_map<K, V, cuda::thread_scope_block>::device_mutable_view;
using find_map_type = typename cuco::static_map<K, V, cuda::thread_scope_block>::device_view;
__shared__ typename map_type::slot_type slots[N];
auto map = map_type::make_from_uninitialized_slots(
cg::this_thread_block(), &slots[0], N, cuco::empty_key<K>{-1}, cuco::empty_value<V>{-1});
Expand All @@ -190,14 +209,18 @@ __global__ void shared_memory_hash_table_kernel(bool* key_found)
if (retrieved_pair != find_map.end() && retrieved_pair->second == rank) {
key_found[index] = true;
}
*/
}

TEMPLATE_TEST_CASE("Shared memory slots.", "", int32_t)
{
constexpr std::size_t N = 256;
constexpr std::size_t N = 256;
auto constexpr num_windows = cuco::experimental::make_window_extent<cg_size, window_size>(
cuco::experimental::extent<std::size_t, N>{});

thrust::device_vector<bool> key_found(N, false);
shared_memory_hash_table_kernel<TestType, TestType, N><<<8, 32>>>(key_found.data().get());
shared_memory_hash_table_kernel<TestType, TestType, num_windows.value()>
<<<8, 32>>>(key_found.data().get());

REQUIRE(cuco::test::all_of(key_found.begin(), key_found.end(), thrust::identity<bool>{}));
}
*/

0 comments on commit e71699f

Please sign in to comment.