Skip to content

Commit

Permalink
Enable heterogeneous insert for static_map
Browse files Browse the repository at this point in the history
  • Loading branch information
sleeepyjack committed Oct 10, 2023
1 parent c408fd4 commit 8eab8c1
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 35 deletions.
11 changes: 6 additions & 5 deletions include/cuco/detail/static_map/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <cub/block/block_reduce.cuh>

#include <cuda/atomic>
#include <iterator>

#include <cooperative_groups.h>

Expand All @@ -39,22 +40,22 @@ namespace detail {
*
* @tparam CGSize Number of threads in each CG
* @tparam BlockSize Number of threads in each block
* @tparam InputIterator Device accessible input iterator whose `value_type` is
* @tparam InputIt Device accessible input iterator whose `value_type` is
* convertible to the `value_type` of the data structure
* @tparam Ref Type of non-owning device ref allowing access to storage
*
* @param first Beginning of the sequence of input elements
* @param n Number of input elements
* @param ref Non-owning container device ref used to access the slot storage
*/
template <int32_t CGSize, int32_t BlockSize, typename InputIterator, typename Ref>
__global__ void insert_or_assign(InputIterator first, cuco::detail::index_type n, Ref ref)
template <int32_t CGSize, int32_t BlockSize, typename InputIt, typename Ref>
__global__ void insert_or_assign(InputIt first, cuco::detail::index_type n, Ref ref)
{
auto const loop_stride = cuco::detail::grid_stride() / CGSize;
auto idx = cuco::detail::global_thread_id() / CGSize;

while (idx < n) {
typename Ref::value_type const insert_pair{*(first + idx)};
typename std::iterator_traits<InputIt>::value_type const& insert_pair = *(first + idx);
if constexpr (CGSize == 1) {
ref.insert_or_assign(insert_pair);
} else {
Expand Down Expand Up @@ -100,7 +101,7 @@ __global__ void find(InputIt first, cuco::detail::index_type n, OutputIt output_

while (idx - thread_idx < n) { // the whole thread block falls into the same iteration
if (idx < n) {
auto const key = *(first + idx);
typename std::iterator_traits<InputIt>::value_type const& key = *(first + idx);
if constexpr (CGSize == 1) {
auto const found = ref.find(key);
/*
Expand Down
33 changes: 26 additions & 7 deletions include/cuco/detail/static_map/static_map_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,14 @@ class operator_impl<
/**
* @brief Inserts an element.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
*
* @param value The element to insert
*
* @return True if the given element is successfully inserted
*/
__device__ bool insert(value_type const& value) noexcept
template <typename Value>
__device__ bool insert(Value const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert(value);
Expand All @@ -164,12 +168,16 @@ class operator_impl<
/**
* @brief Inserts an element.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
*
* @param group The Cooperative Group used to perform group insert
* @param value The element to insert
*
* @return True if the given element is successfully inserted
*/
template <typename Value>
__device__ bool insert(cooperative_groups::thread_block_tile<cg_size> const& group,
value_type const& value) noexcept
Value const& value) noexcept
{
auto& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert(group, value);
Expand Down Expand Up @@ -202,9 +210,12 @@ class operator_impl<
* @brief Inserts a key-value pair `{k, v}` if it's not present in the map. Otherwise, assigns `v`
* to the mapped_type corresponding to the key `k`.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
*
* @param value The element to insert
*/
__device__ void insert_or_assign(value_type const& value) noexcept
template <typename Value>
__device__ void insert_or_assign(Value const& value) noexcept
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");

Expand Down Expand Up @@ -246,11 +257,14 @@ class operator_impl<
* @brief Inserts a key-value pair `{k, v}` if it's not present in the map. Otherwise, assigns `v`
* to the mapped_type corresponding to the key `k`.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
*
* @param group The Cooperative Group used to perform group insert
* @param value The element to insert
*/
template <typename Value>
__device__ void insert_or_assign(cooperative_groups::thread_block_tile<cg_size> const& group,
value_type const& value) noexcept
Value const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);

Expand Down Expand Up @@ -313,13 +327,15 @@ class operator_impl<
* @brief Inserts a key-value pair `{k, v}` if it's not present in the map. Otherwise, assigns `v`
* to the mapped_type corresponding to the key `k`.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
*
* @param slot Pointer to the slot on which the insert-or-assign is attempted
* @param value The element to insert
*
* @return Returns `true` if the given `value` is inserted or `value` has a match in the map.
*/
__device__ constexpr bool attempt_insert_or_assign(value_type* slot,
value_type const& value) noexcept
template <typename Value>
__device__ constexpr bool attempt_insert_or_assign(value_type* slot, Value const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
auto const expected_key = ref_.impl_.empty_slot_sentinel().first;
Expand Down Expand Up @@ -411,14 +427,17 @@ class operator_impl<
* element that prevented the insertion) and a `bool` denoting whether the insertion took place or
* not.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
*
* @param group The Cooperative Group used to perform group insert_and_find
* @param value The element to insert
*
* @return a pair consisting of an iterator to the element and a bool indicating whether the
* insertion is successful or not.
*/
template <typename Value>
__device__ thrust::pair<iterator, bool> insert_and_find(
cooperative_groups::thread_block_tile<cg_size> const& group, value_type const& value) noexcept
cooperative_groups::thread_block_tile<cg_size> const& group, Value const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert_and_find(group, value);
Expand Down
61 changes: 38 additions & 23 deletions tests/static_map/heterogeneous_lookup_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ struct key_pair {
// Device equality operator is mandatory due to libcudacxx bug:
// https://github.com/NVIDIA/libcudacxx/issues/223
__device__ bool operator==(key_pair const& other) const { return a == other.a and b == other.b; }

__device__ explicit operator T() const noexcept { return a; }
};

// probe key type
Expand All @@ -64,61 +66,74 @@ struct key_triplet {
// User-defined device hasher
// The call operator is templated on the key type and returns the key's `a` member as a
// `uint32_t`, so any key type exposing `a` hashes the same way.
// NOTE(review): this listing looks like a rendered diff with the +/- markers stripped — the two
// adjacent signatures and the two adjacent `return` lines appear to be the superseded line
// followed by its replacement (adds `const`, drops `thrust::raw_reference_cast`). Confirm against
// the repository before treating this span as compilable code.
struct custom_hasher {
template <typename CustomKey>
__device__ uint32_t operator()(CustomKey const& k)
__device__ uint32_t operator()(CustomKey const& k) const
{
return thrust::raw_reference_cast(k).a;
return k.a;
};
};

// User-defined device key equality
// Two call-operator forms are shown: the `LHS`/`RHS` form compares the `a` members of both
// arguments; the `SlotKey`/`InputKey` form compares `lhs` itself against `rhs.a`.
// NOTE(review): these look like the superseded and replacement versions of one operator from a
// rendered diff with the +/- markers stripped — confirm against the repository before treating
// this span as compilable code.
struct custom_key_equal {
template <typename LHS, typename RHS>
__device__ bool operator()(LHS const& lhs, RHS const& rhs)
template <typename SlotKey, typename InputKey>
__device__ bool operator()(SlotKey const& lhs, InputKey const& rhs) const
{
return thrust::raw_reference_cast(lhs).a == thrust::raw_reference_cast(rhs).a;
return lhs == rhs.a;
}
};

// Heterogeneous-lookup test: pairs keyed by `key_pair` are inserted and the map is then queried
// through `contains` with `key_triplet` probe keys.
// NOTE(review): this span reads like a rendered diff with the +/- markers stripped — superseded
// and replacement lines appear side by side (two test-case headers, two sets of type aliases,
// `map` vs. `my_map`). Confirm against the repository before treating it as compilable code.
TEMPLATE_TEST_CASE("Heterogeneous lookup",
"",
TEMPLATE_TEST_CASE_SIG("Heterogeneous lookup",
"",
((typename T, int CGSize), T, CGSize),
#if defined(CUCO_HAS_INDEPENDENT_THREADS) // Key type larger than 8B only supported for sm_70 and
// up
int64_t,
(int64_t, 1),
(int64_t, 2),
#endif
int32_t)

(int32_t, 1),
(int32_t, 2))
{
// Type aliases: the `TestType`-based set and the `T`-based set are shown together; in the
// `T`-based set the map's key is `T` and `key_pair<T>` is only the insert-side key type.
using Key = key_pair<TestType>;
using Value = TestType;
using ProbeKey = key_triplet<TestType>;
using Key = T;
using Value = T;
using InsertKey = key_pair<T>;
using ProbeKey = key_triplet<T>;
using probe_type = cuco::experimental::double_hashing<CGSize, custom_hasher, custom_hasher>;

// Sentinels handed to the map as its empty-key / empty-value markers.
auto const sentinel_key = Key{-1};
auto const sentinel_value = Value{-1};

// Capacity is twice the number of elements the test inserts.
constexpr std::size_t num = 100;
constexpr std::size_t capacity = num * 2;
cuco::static_map<Key, Value> map{
capacity, cuco::empty_key<Key>{sentinel_key}, cuco::empty_value<Value>{sentinel_value}};

// Input sequence generated from a counting iterator: element i maps to the pair (i, i).
auto insert_pairs =
thrust::make_transform_iterator(thrust::counting_iterator<int>(0),
[] __device__(auto i) { return cuco::pair<Key, Value>(i, i); });
// Map constructed with the custom key-equality functor and a double-hashing probe built from two
// `custom_hasher` instances.
auto const probe = probe_type{custom_hasher{}, custom_hasher{}};
auto my_map = cuco::experimental::static_map<Key,
Value,
cuco::experimental::extent<std::size_t>,
cuda::thread_scope_device,
custom_key_equal,
probe_type>{capacity,
cuco::empty_key<Key>{sentinel_key},
cuco::empty_value{sentinel_value},
custom_key_equal{},
probe};

auto insert_pairs = thrust::make_transform_iterator(
thrust::counting_iterator<int>(0),
[] __device__(auto i) { return cuco::pair<InsertKey, Value>(i, i); });
// Probe sequence: element i is `ProbeKey(i)`.
auto probe_keys = thrust::make_transform_iterator(thrust::counting_iterator<int>(0),
[] __device__(auto i) { return ProbeKey(i); });

SECTION("All inserted keys-value pairs should be contained")
{
// Insert the first `num` pairs, then require every `contains` result to be true.
thrust::device_vector<bool> contained(num);
map.insert(insert_pairs, insert_pairs + num, custom_hasher{}, custom_key_equal{});
map.contains(
probe_keys, probe_keys + num, contained.begin(), custom_hasher{}, custom_key_equal{});
my_map.insert(insert_pairs, insert_pairs + num);
my_map.contains(probe_keys, probe_keys + num, contained.begin());
REQUIRE(cuco::test::all_of(contained.begin(), contained.end(), thrust::identity{}));
}

SECTION("Non-inserted keys-value pairs should not be contained")
{
// No insert is issued in this section, so every `contains` result is required to be false.
thrust::device_vector<bool> contained(num);
map.contains(
probe_keys, probe_keys + num, contained.begin(), custom_hasher{}, custom_key_equal{});
my_map.contains(probe_keys, probe_keys + num, contained.begin());
REQUIRE(cuco::test::none_of(contained.begin(), contained.end(), thrust::identity{}));
}
}

0 comments on commit 8eab8c1

Please sign in to comment.