Skip to content

Commit

Permalink
Expose erased key sentinel API + test updates
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel committed Dec 5, 2023
1 parent bd57dc9 commit 9cb6c7c
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 18 deletions.
16 changes: 15 additions & 1 deletion include/cuco/detail/static_map/static_map_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,20 @@ static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>
return impl_.empty_value_sentinel();
}

template <typename Key,
typename T,
cuda::thread_scope Scope,
typename KeyEqual,
typename ProbingScheme,
typename StorageRef,
typename... Operators>
__host__ __device__ constexpr Key
static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::
erased_key_sentinel() const noexcept
{
return impl_.erased_key_sentinel();
}

template <typename Key,
typename T,
cuda::thread_scope Scope,
Expand Down Expand Up @@ -216,7 +230,7 @@ static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>
cuco::empty_value<T>{this->empty_value_sentinel()},
cuco::erased_key<Key>{this->erased_key_sentinel()},
this->key_eq(),
this->probing_scheme(),
this->impl_.probing_scheme(),
storage_ref_type{this->window_extent(), memory_to_use}};
}

Expand Down
15 changes: 14 additions & 1 deletion include/cuco/detail/static_set/static_set_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,19 @@ static_set_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::e
return impl_.empty_key_sentinel();
}

template <typename Key,
cuda::thread_scope Scope,
typename KeyEqual,
typename ProbingScheme,
typename StorageRef,
typename... Operators>
__host__ __device__ constexpr Key
static_set_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::erased_key_sentinel()
const noexcept
{
return impl_.erased_key_sentinel();
}

template <typename Key,
cuda::thread_scope Scope,
typename KeyEqual,
Expand Down Expand Up @@ -178,7 +191,7 @@ static_set_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::m
cuco::empty_key<Key>{this->empty_key_sentinel()},
cuco::erased_key<Key>{this->erased_key_sentinel()},
this->key_eq(),
this->probing_scheme(),
this->impl_.probing_scheme(),
storage_ref_type{this->window_extent(), memory_to_use}};
}

Expand Down
7 changes: 7 additions & 0 deletions include/cuco/static_map_ref.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,13 @@ class static_map_ref
*/
[[nodiscard]] __host__ __device__ constexpr mapped_type empty_value_sentinel() const noexcept;

/**
* @brief Gets the sentinel value used to represent an erased key slot.
*
* @return The sentinel value used to represent an erased key slot
*/
[[nodiscard]] __host__ __device__ constexpr key_type erased_key_sentinel() const noexcept;

/**
* @brief Gets the key comparator.
*
Expand Down
7 changes: 7 additions & 0 deletions include/cuco/static_set_ref.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@ class static_set_ref
*/
[[nodiscard]] __host__ __device__ constexpr key_type empty_key_sentinel() const noexcept;

/**
* @brief Gets the sentinel value used to represent an erased key slot.
*
* @return The sentinel value used to represent an erased key slot
*/
[[nodiscard]] __host__ __device__ constexpr key_type erased_key_sentinel() const noexcept;

/**
* @brief Gets the key comparator.
*
Expand Down
34 changes: 18 additions & 16 deletions tests/static_map/shared_memory_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,22 @@

#include <limits>

/*
template <std::size_t NumWindows, typename Ref>
__global__ void shared_memory_test_kernel(
Ref* maps,
typename Ref::key_type const* const insterted_keys,
typename Ref::mapped_type const* const inserted_values,
bool* const keys_exist,
bool* const keys_and_values_correct)
__global__ void shared_memory_test_kernel(Ref* maps,
typename Ref::key_type const* const insterted_keys,
typename Ref::mapped_type const* const inserted_values,
bool* const keys_exist,
bool* const keys_and_values_correct)
{
// Each block processes one map
const size_t map_id = blockIdx.x;
const size_t offset = map_id * maps[map_id].capacity();

__shared__ typename Ref::window_type sm_buffer[NumWindows];

auto g = cuco::test::cg::this_thread_block();
auto insert_ref =
maps[map_id].make_copy(g, sm_buffer);
auto find_ref = insert_ref.with(cuco::experimental::op::find);
auto g = cuco::test::cg::this_thread_block();
auto insert_ref = maps[map_id].make_copy(g, sm_buffer);
auto find_ref = std::move(insert_ref).with(cuco::experimental::op::find);

for (int i = g.thread_rank(); i < maps[map_id].capacity(); i += g.size()) {
auto found_pair_it = find_ref.find(insterted_keys[offset + i]);
Expand Down Expand Up @@ -82,7 +79,13 @@ TEMPLATE_TEST_CASE_SIG("Shared memory static map",
constexpr std::size_t map_capacity = 2 * elements_in_map;

using extent_type = cuco::experimental::extent<std::size_t, map_capacity>;
using map_type = cuco::experimental::static_map<Key, Value, extent_type>;
using map_type = cuco::experimental::static_map<
Key,
Value,
extent_type,
cuda::thread_scope_device,
thrust::equal_to<Key>,
cuco::experimental::linear_probing<1, cuco::default_hash_function<Key>>>;

// one array for all maps, first elements_in_map element belong to map 0, second to map 1 and so
// on
Expand All @@ -103,11 +106,10 @@ TEMPLATE_TEST_CASE_SIG("Shared memory static map",
thrust::device_vector<bool> d_keys_exist(number_of_maps * elements_in_map);
thrust::device_vector<bool> d_keys_and_values_correct(number_of_maps * elements_in_map);

using ref_type = decltype(maps.front()->ref(cuco::experimental::op::insert));
using ref_type = decltype(maps.front()->ref(cuco::experimental::op::insert));

SECTION("Keys are all found after insertion.")
{
auto pairs_begin =
thrust::make_zip_iterator(thrust::make_tuple(d_keys.begin(), d_values.begin()));
std::vector<ref_type> h_refs;
Expand All @@ -120,7 +122,7 @@ TEMPLATE_TEST_CASE_SIG("Shared memory static map",
}
thrust::device_vector<ref_type> d_refs(h_refs);

auto constexpr num_windows = h_refs[0].window_extent();
auto constexpr num_windows = cuco::experimental::make_window_extent<ref_type>(extent_type{});

shared_memory_test_kernel<static_cast<std::size_t>(num_windows), ref_type>
<<<number_of_maps, 64>>>(d_refs.data().get(),
Expand Down Expand Up @@ -148,7 +150,7 @@ TEMPLATE_TEST_CASE_SIG("Shared memory static map",
}
thrust::device_vector<ref_type> d_refs(h_refs);

auto constexpr num_windows = h_refs[0].window_extent();
auto constexpr num_windows = cuco::experimental::make_window_extent<ref_type>(extent_type{});

shared_memory_test_kernel<static_cast<std::size_t>(num_windows), ref_type>
<<<number_of_maps, 64>>>(d_refs.data().get(),
Expand Down

0 comments on commit 9cb6c7c

Please sign in to comment.