Skip to content

Commit

Permalink
Add ref initialize function
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel committed Dec 12, 2023
1 parent a95fc85 commit 80436a2
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 13 deletions.
23 changes: 23 additions & 0 deletions include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,29 @@ class open_addressing_ref_impl {
#endif
}

/**
 * @brief Initializes the container storage using the threads in the group `tile`.
 *
 * Every slot of every window is overwritten with the empty slot sentinel. Windows are distributed
 * over the threads of `tile` with a stride of `tile.size()`: the thread with rank `r` handles
 * windows `r`, `r + tile.size()`, `r + 2 * tile.size()`, ...
 *
 * @note This function synchronizes the group `tile`, so it has to be called by all threads of
 * `tile`.
 *
 * @tparam CG The type of the cooperative thread group
 *
 * @param tile The cooperative thread group used to initialize the container
 */
template <typename CG>
__device__ constexpr void initialize(CG const& tile) noexcept
{
  // Do all index arithmetic in `size_type`. The return types of `thread_rank()` and `size()`
  // depend on the group type and may be narrower than (or differ in signedness from) the extent,
  // which would make the comparison and the increment below error-prone for large containers.
  auto const num_windows = static_cast<size_type>(this->window_extent());
  auto const stride      = static_cast<size_type>(tile.size());

  // Loop-invariant: fetch the sentinel once instead of once per slot.
  auto const sentinel = this->empty_slot_sentinel();

  for (auto idx = static_cast<size_type>(tile.thread_rank()); idx < num_windows; idx += stride) {
#pragma unroll
    for (auto& slot : this->storage_ref_[idx]) {
      slot = sentinel;
    }
  }

  // Make the initialized storage visible to the whole group before anyone uses it.
  tile.sync();
}

/**
* @brief Inserts an element.
*
Expand Down
19 changes: 17 additions & 2 deletions include/cuco/detail/static_map/static_map_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,9 @@ template <typename Key,
template <typename CG>
__device__ constexpr auto
static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::make_copy(
CG const& g, window_type* const memory_to_use) const noexcept
CG const& tile, window_type* const memory_to_use) const noexcept
{
this->impl_.make_copy(g, memory_to_use);
this->impl_.make_copy(tile, memory_to_use);
return static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>{
cuco::empty_key<Key>{this->empty_key_sentinel()},
cuco::empty_value<T>{this->empty_value_sentinel()},
Expand All @@ -234,6 +234,21 @@ static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>
storage_ref_type{this->window_extent(), memory_to_use}};
}

// Out-of-line definition of `static_map_ref::initialize`.
// Pure forwarder: the cooperative group `tile` is handed unchanged to `impl_.initialize`.
template <typename Key,
          typename T,
          cuda::thread_scope Scope,
          typename KeyEqual,
          typename ProbingScheme,
          typename StorageRef,
          typename... Operators>
template <typename CG>
__device__ constexpr void
static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::initialize(
  CG const& tile) noexcept
{
  impl_.initialize(tile);
}

namespace detail {

template <typename Key,
Expand Down
18 changes: 16 additions & 2 deletions include/cuco/detail/static_set/static_set_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,9 @@ template <typename Key,
template <typename CG>
__device__ constexpr auto
static_set_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::make_copy(
CG const& g, window_type* const memory_to_use) const noexcept
CG const& tile, window_type* const memory_to_use) const noexcept
{
this->impl_.make_copy(g, memory_to_use);
this->impl_.make_copy(tile, memory_to_use);
return static_set_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>{
cuco::empty_key<Key>{this->empty_key_sentinel()},
cuco::erased_key<Key>{this->erased_key_sentinel()},
Expand All @@ -195,6 +195,20 @@ static_set_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::m
storage_ref_type{this->window_extent(), memory_to_use}};
}

// Out-of-line definition of `static_set_ref::initialize`.
// Pure forwarder: the cooperative group `tile` is handed unchanged to `impl_.initialize`.
template <typename Key,
          cuda::thread_scope Scope,
          typename KeyEqual,
          typename ProbingScheme,
          typename StorageRef,
          typename... Operators>
template <typename CG>
__device__ constexpr void
static_set_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::initialize(
  CG const& tile) noexcept
{
  impl_.initialize(tile);
}
namespace detail {

template <typename Key,
Expand Down
16 changes: 14 additions & 2 deletions include/cuco/static_map_ref.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -205,15 +205,27 @@ class static_map_ref
*
* @tparam CG The type of the cooperative thread group
*
* @param g The cooperative thread group used to copy the data structure
* @param tile The cooperative thread group used to copy the data structure
* @param memory_to_use Array large enough to support `capacity` elements. Object does not take
* the ownership of the memory
*
* @return Copy of the current device ref
*/
template <typename CG>
[[nodiscard]] __device__ constexpr auto make_copy(
CG const& g, window_type* const memory_to_use) const noexcept;
CG const& tile, window_type* const memory_to_use) const noexcept;

/**
 * @brief Initializes the map storage using the threads in the group `tile`.
 *
 * The call is forwarded to the ref implementation object held by this map ref. It is meant to be
 * invoked on a ref whose storage has not been initialized yet, before any other operation is
 * performed through that ref (NOTE(review): e.g. a ref built on top of a raw shared-memory
 * array — confirm against the shared-memory test).
 *
 * @note This function synchronizes the group `tile`. Since it ends with a group-wide
 * synchronization, it is expected to be called by all threads of `tile`.
 *
 * @tparam CG The type of the cooperative thread group
 *
 * @param tile The cooperative thread group used to initialize the map
 */
template <typename CG>
__device__ constexpr void initialize(CG const& tile) noexcept;

private:
impl_type impl_; ///< Static map ref implementation
Expand Down
16 changes: 14 additions & 2 deletions include/cuco/static_set_ref.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -183,15 +183,27 @@ class static_set_ref
*
* @tparam CG The type of the cooperative thread group
*
* @param g The cooperative thread group used to copy the data structure
* @param tile The cooperative thread group used to copy the data structure
* @param memory_to_use Array large enough to support `capacity` elements. Object does not take
* the ownership of the memory
*
* @return Copy of the current device ref
*/
template <typename CG>
[[nodiscard]] __device__ constexpr auto make_copy(
CG const& g, window_type* const memory_to_use) const noexcept;
CG const& tile, window_type* const memory_to_use) const noexcept;

/**
 * @brief Initializes the set storage using the threads in the group `tile`.
 *
 * The call is forwarded to the ref implementation object held by this set ref. It is meant to be
 * invoked on a ref whose storage has not been initialized yet, before any other operation is
 * performed through that ref (NOTE(review): e.g. a ref built on top of a raw shared-memory
 * array — confirm against the callers).
 *
 * @note This function synchronizes the group `tile`. Since it ends with a group-wide
 * synchronization, it is expected to be called by all threads of `tile`.
 *
 * @tparam CG The type of the cooperative thread group
 *
 * @param tile The cooperative thread group used to initialize the set
 */
template <typename CG>
__device__ constexpr void initialize(CG const& tile) noexcept;

private:
impl_type impl_;
Expand Down
10 changes: 5 additions & 5 deletions tests/static_map/shared_memory_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ __global__ void shared_memory_hash_table_kernel(bool* key_found)
using extent_type = cuco::experimental::extent<std::size_t, NumWindows>;
using storage_ref_type = cuco::experimental::aow_storage_ref<slot_type, window_size, extent_type>;

// CTAD doesn't work for container ref types
auto raw_ref = cuco::experimental::static_map_ref<
Key,
Value,
Expand All @@ -190,12 +191,11 @@ __global__ void shared_memory_hash_table_kernel(bool* key_found)
{},
storage_ref_type{extent_type{}, map}};

namespace cg = cooperative_groups;
auto const block = cg::this_thread_block();
// map.initialize(block);
auto const block = cooperative_groups::this_thread_block();
raw_ref.initialize(block);

std::size_t index = threadIdx.x + blockIdx.x * blockDim.x;
auto const rank = block.thread_rank();
auto const index = threadIdx.x + blockIdx.x * blockDim.x;
auto const rank = block.thread_rank();

// insert {thread_rank, thread_rank} for each thread in thread-block
auto insert_ref = std::move(raw_ref).with(cuco::experimental::op::insert);
Expand Down

0 comments on commit 80436a2

Please sign in to comment.