Skip to content

Commit

Permalink
Enable the same thing for static_map
Browse files Browse the repository at this point in the history
  • Loading branch information
sleeepyjack committed Sep 29, 2023
1 parent 01ae730 commit 224881a
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 43 deletions.
11 changes: 6 additions & 5 deletions include/cuco/detail/static_map/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <cub/block/block_reduce.cuh>

#include <cuda/atomic>
#include <iterator>

#include <cooperative_groups.h>

Expand All @@ -39,22 +40,22 @@ namespace detail {
*
* @tparam CGSize Number of threads in each CG
* @tparam BlockSize Number of threads in each block
* @tparam InputIterator Device accessible input iterator whose `value_type` is
* @tparam InputIt Device accessible input iterator whose `value_type` is
* convertible to the `value_type` of the data structure
* @tparam Ref Type of non-owning device ref allowing access to storage
*
* @param first Beginning of the sequence of input elements
* @param n Number of input elements
* @param ref Non-owning container device ref used to access the slot storage
*/
template <int32_t CGSize, int32_t BlockSize, typename InputIterator, typename Ref>
__global__ void insert_or_assign(InputIterator first, cuco::detail::index_type n, Ref ref)
template <int32_t CGSize, int32_t BlockSize, typename InputIt, typename Ref>
__global__ void insert_or_assign(InputIt first, cuco::detail::index_type n, Ref ref)
{
auto const loop_stride = cuco::detail::grid_stride() / CGSize;
auto idx = cuco::detail::global_thread_id() / CGSize;

while (idx < n) {
typename Ref::value_type const insert_pair{*(first + idx)};
typename std::iterator_traits<InputIt>::value_type const& insert_pair{*(first + idx)};
if constexpr (CGSize == 1) {
ref.insert_or_assign(insert_pair);
} else {
Expand Down Expand Up @@ -100,7 +101,7 @@ __global__ void find(InputIt first, cuco::detail::index_type n, OutputIt output_

while (idx - thread_idx < n) { // the whole thread block falls into the same iteration
if (idx < n) {
auto const key = *(first + idx);
typename std::iterator_traits<InputIt>::value_type const& key = *(first + idx);
if constexpr (CGSize == 1) {
auto const found = ref.find(key);
/*
Expand Down
52 changes: 37 additions & 15 deletions include/cuco/detail/static_map/static_map_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,13 @@ class operator_impl<
/**
* @brief Inserts an element.
*
+ * @tparam Value Input type which is implicitly convertible to 'value_type'
*
* @param value The element to insert
* @return True if the given element is successfully inserted
*/
__device__ bool insert(value_type const& value) noexcept
template <typename Value>
__device__ bool insert(Value const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = true;
Expand All @@ -256,12 +259,15 @@ class operator_impl<
/**
* @brief Inserts an element.
*
+ * @tparam Value Input type which is implicitly convertible to 'value_type'
*
* @param group The Cooperative Group used to perform group insert
* @param value The element to insert
* @return True if the given element is successfully inserted
*/
template <typename Value>
__device__ bool insert(cooperative_groups::thread_block_tile<cg_size> const& group,
value_type const& value) noexcept
Value const& value) noexcept
{
auto& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = true;
Expand All @@ -282,7 +288,8 @@ class operator_impl<
using base_type = static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef>;
using ref_type = static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>;
using key_type = typename base_type::key_type;
using value_type = typename base_type::value_type;
using value_type = typename base_type::value_type;
using mapped_type = typename base_type::mapped_type;

static constexpr auto cg_size = base_type::cg_size;
static constexpr auto window_size = base_type::window_size;
Expand All @@ -295,14 +302,17 @@ class operator_impl<
* @brief Inserts a key-value pair `{k, v}` if it's not present in the map. Otherwise, assigns `v`
* to the mapped_type corresponding to the key `k`.
*
+ * @tparam Value Input type which is implicitly convertible to 'value_type'
*
* @param value The element to insert
*/
__device__ void insert_or_assign(value_type const& value) noexcept
template <typename Value>
__device__ void insert_or_assign(Value const& value) noexcept
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");

ref_type& ref_ = static_cast<ref_type&>(*this);
auto const key = value.first;
auto const key = value.first; // TODO can we use auto here?
auto& probing_scheme = ref_.impl_.probing_scheme();
auto storage_ref = ref_.impl_.storage_ref();
auto probing_iter = probing_scheme(key, storage_ref.window_extent());
Expand All @@ -318,7 +328,7 @@ class operator_impl<
auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content);
ref_.impl_.atomic_store(
&((storage_ref.data() + *probing_iter)->data() + intra_window_index)->second,
value.second);
static_cast<mapped_type>(value.second));
return;
}
if (eq_res == detail::equal_result::EMPTY) {
Expand All @@ -339,15 +349,18 @@ class operator_impl<
* @brief Inserts a key-value pair `{k, v}` if it's not present in the map. Otherwise, assigns `v`
* to the mapped_type corresponding to the key `k`.
*
+ * @tparam Value Input type which is implicitly convertible to 'value_type'
*
* @param group The Cooperative Group used to perform group insert
* @param value The element to insert
*/
template <typename Value>
__device__ void insert_or_assign(cooperative_groups::thread_block_tile<cg_size> const& group,
value_type const& value) noexcept
Value const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);

auto const key = value.first;
auto const key = value.first; // TODO
auto& probing_scheme = ref_.impl_.probing_scheme();
auto storage_ref = ref_.impl_.storage_ref();
auto probing_iter = probing_scheme(group, key, storage_ref.window_extent());
Expand Down Expand Up @@ -375,7 +388,7 @@ class operator_impl<
if (group.thread_rank() == src_lane) {
ref_.impl_.atomic_store(
&((storage_ref.data() + *probing_iter)->data() + intra_window_index)->second,
value.second);
static_cast<mapped_type>(value.second));
}
group.sync();
return;
Expand Down Expand Up @@ -406,25 +419,28 @@ class operator_impl<
* @brief Inserts a key-value pair `{k, v}` if it's not present in the map. Otherwise, assigns `v`
* to the mapped_type corresponding to the key `k`.
*
+ * @tparam Value Input type which is implicitly convertible to 'value_type'
*
* @param group The Cooperative Group used to perform group insert
* @param value The element to insert
*
* @return Returns `true` if the given `value` is inserted or `value` has a match in the map.
*/
__device__ constexpr bool attempt_insert_or_assign(value_type* slot,
value_type const& value) noexcept
template <typename Value>
__device__ constexpr bool attempt_insert_or_assign(value_type* slot, Value const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
auto const expected_key = ref_.impl_.empty_slot_sentinel().first;

auto old_key = ref_.impl_.compare_and_swap(&slot->first, expected_key, value.first);
auto old_key =
ref_.impl_.compare_and_swap(&slot->first, expected_key, static_cast<key_type>(value.first));
auto* old_key_ptr = reinterpret_cast<key_type*>(&old_key);

// if key success or key was already present in the map
if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key) or
(ref_.predicate_.equal_to(*old_key_ptr, value.first) == detail::equal_result::EQUAL)) {
// Update payload
ref_.impl_.atomic_store(&slot->second, value.second);
ref_.impl_.atomic_store(&slot->second, static_cast<mapped_type>(value.second));
return true;
}
return false;
Expand Down Expand Up @@ -485,12 +501,15 @@ class operator_impl<
* element that prevented the insertion) and a `bool` denoting whether the insertion took place or
* not.
*
+ * @tparam Value Input type which is implicitly convertible to 'value_type'
*
* @param value The element to insert
*
* @return a pair consisting of an iterator to the element and a bool indicating whether the
* insertion is successful or not.
*/
__device__ thrust::pair<iterator, bool> insert_and_find(value_type const& value) noexcept
template <typename Value>
__device__ thrust::pair<iterator, bool> insert_and_find(Value const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = true;
Expand All @@ -504,14 +523,17 @@ class operator_impl<
* element that prevented the insertion) and a `bool` denoting whether the insertion took place or
* not.
*
+ * @tparam Value Input type which is implicitly convertible to 'value_type'
*
* @param group The Cooperative Group used to perform group insert_and_find
* @param value The element to insert
*
* @return a pair consisting of an iterator to the element and a bool indicating whether the
* insertion is successful or not.
*/
template <typename Value>
__device__ thrust::pair<iterator, bool> insert_and_find(
cooperative_groups::thread_block_tile<cg_size> const& group, value_type const& value) noexcept
cooperative_groups::thread_block_tile<cg_size> const& group, Value const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = true;
Expand Down
61 changes: 38 additions & 23 deletions tests/static_map/heterogeneous_lookup_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ struct key_pair {
// Device equality operator is mandatory due to libcudacxx bug:
// https://github.com/NVIDIA/libcudacxx/issues/223
__device__ bool operator==(key_pair const& other) const { return a == other.a and b == other.b; }

__device__ explicit operator T() const noexcept { return a; }
};

// probe key type
Expand All @@ -64,61 +66,74 @@ struct key_triplet {
// User-defined device hasher
struct custom_hasher {
template <typename CustomKey>
__device__ uint32_t operator()(CustomKey const& k)
__device__ uint32_t operator()(CustomKey const& k) const
{
return thrust::raw_reference_cast(k).a;
return k.a;
};
};

// User-defined device key equality
struct custom_key_equal {
template <typename LHS, typename RHS>
__device__ bool operator()(LHS const& lhs, RHS const& rhs)
template <typename SlotKey, typename InputKey>
__device__ bool operator()(SlotKey const& lhs, InputKey const& rhs) const
{
return thrust::raw_reference_cast(lhs).a == thrust::raw_reference_cast(rhs).a;
return lhs == rhs.a;
}
};

TEMPLATE_TEST_CASE("Heterogeneous lookup",
"",
TEMPLATE_TEST_CASE_SIG("Heterogeneous lookup",
"",
((typename T, int CGSize), T, CGSize),
#if defined(CUCO_HAS_INDEPENDENT_THREADS) // Key type larger than 8B only supported for sm_70 and
// up
int64_t,
(int64_t, 1),
(int64_t, 2),
#endif
int32_t)
(int32_t, 1),
(int32_t, 2))
{
using Key = key_pair<TestType>;
using Value = TestType;
using ProbeKey = key_triplet<TestType>;
using Key = T;
using Value = T;
using InsertKey = key_pair<Key>;
using ProbeKey = key_triplet<Key>;
using probe_type = cuco::experimental::double_hashing<CGSize, custom_hasher, custom_hasher>;

auto const sentinel_key = Key{-1};
auto const sentinel_value = Value{-1};

constexpr std::size_t num = 100;
constexpr std::size_t capacity = num * 2;
cuco::static_map<Key, Value> map{
capacity, cuco::empty_key<Key>{sentinel_key}, cuco::empty_value<Value>{sentinel_value}};

auto insert_pairs =
thrust::make_transform_iterator(thrust::counting_iterator<int>(0),
[] __device__(auto i) { return cuco::pair<Key, Value>(i, i); });
auto const probe = probe_type{custom_hasher{}, custom_hasher{}};

auto map = cuco::experimental::static_map<Key,
Value,
cuco::experimental::extent<std::size_t>,
cuda::thread_scope_device,
custom_key_equal,
probe_type>{capacity,
cuco::empty_key<Key>{sentinel_key},
cuco::empty_value<Value>{sentinel_value},
custom_key_equal{},
probe};

auto insert_pairs = thrust::make_transform_iterator(
thrust::counting_iterator<int>(0),
[] __device__(auto i) { return cuco::pair<InsertKey, Value>(i, i); });
auto probe_keys = thrust::make_transform_iterator(thrust::counting_iterator<int>(0),
[] __device__(auto i) { return ProbeKey(i); });

SECTION("All inserted keys-value pairs should be contained")
{
thrust::device_vector<bool> contained(num);
map.insert(insert_pairs, insert_pairs + num, custom_hasher{}, custom_key_equal{});
map.contains(
probe_keys, probe_keys + num, contained.begin(), custom_hasher{}, custom_key_equal{});
map.insert(insert_pairs, insert_pairs + num);
map.contains(probe_keys, probe_keys + num, contained.begin());
REQUIRE(cuco::test::all_of(contained.begin(), contained.end(), thrust::identity{}));
}

SECTION("Non-inserted keys-value pairs should not be contained")
{
thrust::device_vector<bool> contained(num);
map.contains(
probe_keys, probe_keys + num, contained.begin(), custom_hasher{}, custom_key_equal{});
map.contains(probe_keys, probe_keys + num, contained.begin());
REQUIRE(cuco::test::none_of(contained.begin(), contained.end(), thrust::identity{}));
}
}

0 comments on commit 224881a

Please sign in to comment.