Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/dev' into build-script
Browse files Browse the repository at this point in the history
  • Loading branch information
sleeepyjack committed Nov 21, 2023
2 parents 143df1d + c7d52a2 commit cf1cc8b
Show file tree
Hide file tree
Showing 11 changed files with 345 additions and 253 deletions.
89 changes: 81 additions & 8 deletions include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,13 @@ class open_addressing_ref_impl {
__device__ thrust::pair<iterator, bool> insert_and_find(Value const& value) noexcept
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
#if __CUDA_ARCH__ < 700
// Spinning to ensure that the write to the value part took place requires
// independent thread scheduling introduced with the Volta architecture.
static_assert(
cuco::detail::is_packable<value_type>(),
"insert_and_find is not supported for pair types larger than 8 bytes on pre-Volta GPUs.");
#endif

auto const key = this->extract_key(value);
auto probing_iter = probing_scheme_(key, storage_ref_.window_extent());
Expand All @@ -397,21 +404,36 @@ class open_addressing_ref_impl {
auto* window_ptr = (storage_ref_.data() + *probing_iter)->data();

// If the key is already in the container, return false
if (eq_res == detail::equal_result::EQUAL) { return {iterator{&window_ptr[i]}, false}; }
if (eq_res == detail::equal_result::EQUAL) {
if constexpr (has_payload) {
// wait to ensure that the write to the value part also took place
this->wait_for_payload((window_ptr + i)->second, this->empty_slot_sentinel_.second);
}
return {iterator{&window_ptr[i]}, false};
}
if (eq_res == detail::equal_result::EMPTY or
cuco::detail::bitwise_compare(this->extract_key(window_slots[i]),
this->erased_key_sentinel())) {
switch ([&]() {
auto const res = [&]() {
if constexpr (sizeof(value_type) <= 8) {
return packed_cas(window_ptr + i, window_slots[i], value);
} else {
return cas_dependent_write(window_ptr + i, window_slots[i], value);
}
}()) {
}();
switch (res) {
case insert_result::SUCCESS: {
if constexpr (has_payload) {
// wait to ensure that the write to the value part also took place
this->wait_for_payload((window_ptr + i)->second, this->empty_slot_sentinel_.second);
}
return {iterator{&window_ptr[i]}, true};
}
case insert_result::DUPLICATE: {
if constexpr (has_payload) {
// wait to ensure that the write to the value part also took place
this->wait_for_payload((window_ptr + i)->second, this->empty_slot_sentinel_.second);
}
return {iterator{&window_ptr[i]}, false};
}
default: continue;
Expand Down Expand Up @@ -441,6 +463,14 @@ class open_addressing_ref_impl {
__device__ thrust::pair<iterator, bool> insert_and_find(
cooperative_groups::thread_block_tile<cg_size> const& group, Value const& value) noexcept
{
#if __CUDA_ARCH__ < 700
// Spinning to ensure that the write to the value part took place requires
// independent thread scheduling introduced with the Volta architecture.
static_assert(
cuco::detail::is_packable<value_type>(),
"insert_and_find is not supported for pair types larger than 8 bytes on pre-Volta GPUs.");
#endif

auto const key = this->extract_key(value);
auto probing_iter = probing_scheme_(group, key, storage_ref_.window_extent());

Expand Down Expand Up @@ -475,6 +505,13 @@ class open_addressing_ref_impl {
if (group_finds_equal) {
auto const src_lane = __ffs(group_finds_equal) - 1;
auto const res = group.shfl(reinterpret_cast<intptr_t>(slot_ptr), src_lane);
if (group.thread_rank() == src_lane) {
if constexpr (has_payload) {
// wait to ensure that the write to the value part also took place
this->wait_for_payload(slot_ptr->second, this->empty_slot_sentinel_.second);
}
}
group.sync();
return {iterator{reinterpret_cast<value_type*>(res)}, false};
}

Expand All @@ -494,9 +531,23 @@ class open_addressing_ref_impl {

switch (group.shfl(status, src_lane)) {
case insert_result::SUCCESS: {
if (group.thread_rank() == src_lane) {
if constexpr (has_payload) {
// wait to ensure that the write to the value part also took place
this->wait_for_payload(slot_ptr->second, this->empty_slot_sentinel_.second);
}
}
group.sync();
return {iterator{reinterpret_cast<value_type*>(res)}, true};
}
case insert_result::DUPLICATE: {
if (group.thread_rank() == src_lane) {
if constexpr (has_payload) {
// wait to ensure that the write to the value part also took place
this->wait_for_payload(slot_ptr->second, this->empty_slot_sentinel_.second);
}
}
group.sync();
return {iterator{reinterpret_cast<value_type*>(res)}, false};
}
default: continue;
Expand Down Expand Up @@ -595,12 +646,12 @@ class open_addressing_ref_impl {
case insert_result::DUPLICATE: return false;
default: continue;
}
} else if (group.any(state == detail::equal_result::EMPTY)) {
// Key doesn't exist, return false
return false;
} else {
++probing_iter;
}

// Key doesn't exist, return false
if (group.any(state == detail::equal_result::EMPTY)) { return false; }

++probing_iter;
}
}

Expand Down Expand Up @@ -1010,6 +1061,7 @@ class open_addressing_ref_impl {
// if key success
if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) {
atomic_store(&address->second, static_cast<mapped_type>(thrust::get<1>(desired)));

return insert_result::SUCCESS;
}

Expand Down Expand Up @@ -1054,6 +1106,27 @@ class open_addressing_ref_impl {
}
}

/**
 * @brief Busy-waits until the observed payload stops comparing equal to the sentinel
 *
 * @note Each iteration performs a relaxed atomic load of `slot` through a
 * `cuda::atomic_ref<T, Scope>` and returns as soon as the loaded value is no longer
 * bitwise-equal to `sentinel`. Callers in this file pass the `second` member of a slot
 * together with the `second` member of the empty-slot sentinel.
 *
 * @tparam T Type of the payload being polled
 *
 * @param slot Reference to the payload that is polled
 * @param sentinel Payload value that keeps the loop spinning while it is still observed
 */
template <typename T>
__device__ void wait_for_payload(T& slot, T const& sentinel) const noexcept
{
  auto payload_ref = cuda::atomic_ref<T, Scope>{slot};
  // TODO exponential backoff strategy
  for (;;) {
    // Re-read the payload on every pass; no ordering stronger than relaxed is requested.
    T const observed = payload_ref.load(cuda::std::memory_order_relaxed);
    if (!cuco::detail::bitwise_compare(observed, sentinel)) { return; }
  }
}

// TODO: Clean up the sentinel handling since it's duplicated in ref and equal wrapper
value_type empty_slot_sentinel_; ///< Sentinel value indicating an empty slot
detail::equal_wrapper<key_type, key_equal> predicate_; ///< Key equality binary callable
Expand Down
59 changes: 35 additions & 24 deletions tests/static_map/custom_type_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

#include <catch2/catch_template_test_macros.hpp>

#include <cuda/functional>

#include <tuple>

// User-defined key type
Expand Down Expand Up @@ -123,17 +125,18 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type",
thrust::counting_iterator<int>(0),
thrust::counting_iterator<int>(num),
insert_keys.begin(),
[] __device__(auto i) { return Key{i}; });
cuda::proclaim_return_type<Key>([] __device__(auto i) { return Key{i}; }));

thrust::transform(thrust::device,
thrust::counting_iterator<int>(0),
thrust::counting_iterator<int>(num),
insert_values.begin(),
[] __device__(auto i) { return Value{i}; });
cuda::proclaim_return_type<Value>([] __device__(auto i) { return Value{i}; }));

auto insert_pairs =
thrust::make_transform_iterator(thrust::make_counting_iterator<int>(0),
[] __device__(auto i) { return cuco::pair<Key, Value>(i, i); });
auto insert_pairs = thrust::make_transform_iterator(
thrust::make_counting_iterator<int>(0),
cuda::proclaim_return_type<cuco::pair<Key, Value>>(
[] __device__(auto i) { return cuco::pair<Key, Value>(i, i); }));

SECTION("All inserted keys-value pairs should be correctly recovered during find")
{
Expand All @@ -151,9 +154,9 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type",
REQUIRE(cuco::test::equal(insert_values.begin(),
insert_values.end(),
found_values.begin(),
[] __device__(Value lhs, Value rhs) {
cuda::proclaim_return_type<bool>([] __device__(Value lhs, Value rhs) {
return std::tie(lhs.f, lhs.s) == std::tie(rhs.f, rhs.s);
}));
})));
}

SECTION("All inserted keys-value pairs should be contained")
Expand All @@ -175,7 +178,7 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type",
insert_pairs,
insert_pairs + num,
thrust::counting_iterator<int>(0),
[] __device__(auto const& key) { return (key % 2) == 0; },
cuda::proclaim_return_type<bool>([] __device__(auto const& key) { return (key % 2) == 0; }),
hash_custom_key{},
custom_key_equals{});

Expand All @@ -187,12 +190,13 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type",
hash_custom_key{},
custom_key_equals{});

REQUIRE(cuco::test::equal(contained.begin(),
contained.end(),
thrust::counting_iterator<int>(0),
[] __device__(auto const& idx_contained, auto const& idx) {
return ((idx % 2) == 0) == idx_contained;
}));
REQUIRE(cuco::test::equal(
contained.begin(),
contained.end(),
thrust::counting_iterator<int>(0),
cuda::proclaim_return_type<bool>([] __device__(auto const& idx_contained, auto const& idx) {
return ((idx % 2) == 0) == idx_contained;
})));
}

SECTION("Non-inserted keys-value pairs should not be contained")
Expand All @@ -212,19 +216,23 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type",
map.insert(insert_pairs, insert_pairs + num, hash_custom_key{}, custom_key_equals{});
auto view = map.get_device_view();
REQUIRE(cuco::test::all_of(
insert_pairs, insert_pairs + num, [view] __device__(cuco::pair<Key, Value> const& pair) {
insert_pairs,
insert_pairs + num,
cuda::proclaim_return_type<bool>([view] __device__(cuco::pair<Key, Value> const& pair) {
return view.contains(pair.first, hash_custom_key{}, custom_key_equals{});
}));
})));
}

SECTION("Inserting unique keys should return insert success.")
{
auto m_view = map.get_device_mutable_view();
REQUIRE(cuco::test::all_of(insert_pairs,
insert_pairs + num,
[m_view] __device__(cuco::pair<Key, Value> const& pair) mutable {
return m_view.insert(pair, hash_custom_key{}, custom_key_equals{});
}));
cuda::proclaim_return_type<bool>(
[m_view] __device__(cuco::pair<Key, Value> const& pair) mutable {
return m_view.insert(
pair, hash_custom_key{}, custom_key_equals{});
})));
}

SECTION("Cannot find any key in an empty hash map")
Expand All @@ -235,18 +243,21 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type",
REQUIRE(cuco::test::all_of(
insert_pairs,
insert_pairs + num,
[view] __device__(cuco::pair<Key, Value> const& pair) mutable {
return view.find(pair.first, hash_custom_key{}, custom_key_equals{}) == view.end();
}));
cuda::proclaim_return_type<bool>(
[view] __device__(cuco::pair<Key, Value> const& pair) mutable {
return view.find(pair.first, hash_custom_key{}, custom_key_equals{}) == view.end();
})));
}

SECTION("const view")
{
auto const view = map.get_device_view();
REQUIRE(cuco::test::all_of(
insert_pairs, insert_pairs + num, [view] __device__(cuco::pair<Key, Value> const& pair) {
insert_pairs,
insert_pairs + num,
cuda::proclaim_return_type<bool>([view] __device__(cuco::pair<Key, Value> const& pair) {
return view.find(pair.first, hash_custom_key{}, custom_key_equals{}) == view.end();
}));
})));
}
}
}
59 changes: 48 additions & 11 deletions tests/static_map/duplicate_keys_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,52 @@

#include <catch2/catch_template_test_macros.hpp>

TEMPLATE_TEST_CASE_SIG("Duplicate keys",
"",
((typename Key, typename Value), Key, Value),
(int32_t, int32_t),
(int32_t, int64_t),
(int64_t, int32_t),
(int64_t, int64_t))
#include <cuda/functional>

using size_type = std::size_t;

TEMPLATE_TEST_CASE_SIG(
"static_map duplicate keys",
"",
((typename Key, typename Value, cuco::test::probe_sequence Probe, int CGSize),
Key,
Value,
Probe,
CGSize),
(int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1),
(int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1),
(int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2),
(int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 2),
(int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1),
(int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1),
(int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2),
(int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2),
(int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 1),
(int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1),
(int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2),
(int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2),
(int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1),
(int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1),
(int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2),
(int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2))
{
constexpr std::size_t num_keys{500'000};
cuco::static_map<Key, Value> map{
constexpr size_type num_keys{500'000};

using probe =
std::conditional_t<Probe == cuco::test::probe_sequence::linear_probing,
cuco::experimental::linear_probing<CGSize, cuco::murmurhash3_32<Key>>,
cuco::experimental::double_hashing<CGSize,
cuco::murmurhash3_32<Key>,
cuco::murmurhash3_32<Key>>>;

auto map = cuco::experimental::static_map<Key,
Value,
cuco::experimental::extent<size_type>,
cuda::thread_scope_device,
thrust::equal_to<Key>,
probe,
cuco::cuda_allocator<std::byte>,
cuco::experimental::storage<2>>{
num_keys * 2, cuco::empty_key<Key>{-1}, cuco::empty_value<Value>{-1}};

thrust::device_vector<Key> d_keys(num_keys);
Expand All @@ -49,7 +85,8 @@ TEMPLATE_TEST_CASE_SIG("Duplicate keys",

auto pairs_begin = thrust::make_transform_iterator(
thrust::make_counting_iterator<int>(0),
[] __device__(auto i) { return cuco::pair<Key, Value>(i / 2, i / 2); });
cuda::proclaim_return_type<cuco::pair<Key, Value>>(
[] __device__(auto i) { return cuco::pair<Key, Value>(i / 2, i / 2); }));

thrust::device_vector<Value> d_results(num_keys);
thrust::device_vector<bool> d_contained(num_keys);
Expand All @@ -68,7 +105,7 @@ TEMPLATE_TEST_CASE_SIG("Duplicate keys",

map.insert(pairs_begin, pairs_begin + num_keys);

auto const num_entries = map.get_size();
auto const num_entries = map.size();
REQUIRE(num_entries == gold);

auto [key_out_end, value_out_end] =
Expand Down
2 changes: 1 addition & 1 deletion tests/static_map/erase_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ TEMPLATE_TEST_CASE_SIG(
(int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2),
(int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2))
{
constexpr size_type num_keys{400};
constexpr size_type num_keys{1'000'000};

using probe =
std::conditional_t<Probe == cuco::test::probe_sequence::linear_probing,
Expand Down
Loading

0 comments on commit cf1cc8b

Please sign in to comment.