Skip to content

Commit

Permalink
Guard device lambda with proclaim_return_type
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel committed Nov 20, 2023
1 parent 47a8fff commit 6fa4e56
Show file tree
Hide file tree
Showing 9 changed files with 119 additions and 74 deletions.
59 changes: 35 additions & 24 deletions tests/static_map/custom_type_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

#include <catch2/catch_template_test_macros.hpp>

#include <cuda/functional>

#include <tuple>

// User-defined key type
Expand Down Expand Up @@ -123,17 +125,18 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type",
thrust::counting_iterator<int>(0),
thrust::counting_iterator<int>(num),
insert_keys.begin(),
[] __device__(auto i) { return Key{i}; });
cuda::proclaim_return_type<Key>([] __device__(auto i) { return Key{i}; }));

thrust::transform(thrust::device,
thrust::counting_iterator<int>(0),
thrust::counting_iterator<int>(num),
insert_values.begin(),
[] __device__(auto i) { return Value{i}; });
cuda::proclaim_return_type<Value>([] __device__(auto i) { return Value{i}; }));

auto insert_pairs =
thrust::make_transform_iterator(thrust::make_counting_iterator<int>(0),
[] __device__(auto i) { return cuco::pair<Key, Value>(i, i); });
auto insert_pairs = thrust::make_transform_iterator(
thrust::make_counting_iterator<int>(0),
cuda::proclaim_return_type<cuco::pair<Key, Value>>(
[] __device__(auto i) { return cuco::pair<Key, Value>(i, i); }));

SECTION("All inserted keys-value pairs should be correctly recovered during find")
{
Expand All @@ -151,9 +154,9 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type",
REQUIRE(cuco::test::equal(insert_values.begin(),
insert_values.end(),
found_values.begin(),
[] __device__(Value lhs, Value rhs) {
cuda::proclaim_return_type<bool>([] __device__(Value lhs, Value rhs) {
return std::tie(lhs.f, lhs.s) == std::tie(rhs.f, rhs.s);
}));
})));
}

SECTION("All inserted keys-value pairs should be contained")
Expand All @@ -175,7 +178,7 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type",
insert_pairs,
insert_pairs + num,
thrust::counting_iterator<int>(0),
[] __device__(auto const& key) { return (key % 2) == 0; },
cuda::proclaim_return_type<bool>([] __device__(auto const& key) { return (key % 2) == 0; }),
hash_custom_key{},
custom_key_equals{});

Expand All @@ -187,12 +190,13 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type",
hash_custom_key{},
custom_key_equals{});

REQUIRE(cuco::test::equal(contained.begin(),
contained.end(),
thrust::counting_iterator<int>(0),
[] __device__(auto const& idx_contained, auto const& idx) {
return ((idx % 2) == 0) == idx_contained;
}));
REQUIRE(cuco::test::equal(
contained.begin(),
contained.end(),
thrust::counting_iterator<int>(0),
cuda::proclaim_return_type<bool>([] __device__(auto const& idx_contained, auto const& idx) {
return ((idx % 2) == 0) == idx_contained;
})));
}

SECTION("Non-inserted keys-value pairs should not be contained")
Expand All @@ -212,19 +216,23 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type",
map.insert(insert_pairs, insert_pairs + num, hash_custom_key{}, custom_key_equals{});
auto view = map.get_device_view();
REQUIRE(cuco::test::all_of(
insert_pairs, insert_pairs + num, [view] __device__(cuco::pair<Key, Value> const& pair) {
insert_pairs,
insert_pairs + num,
cuda::proclaim_return_type<bool>([view] __device__(cuco::pair<Key, Value> const& pair) {
return view.contains(pair.first, hash_custom_key{}, custom_key_equals{});
}));
})));
}

SECTION("Inserting unique keys should return insert success.")
{
auto m_view = map.get_device_mutable_view();
REQUIRE(cuco::test::all_of(insert_pairs,
insert_pairs + num,
[m_view] __device__(cuco::pair<Key, Value> const& pair) mutable {
return m_view.insert(pair, hash_custom_key{}, custom_key_equals{});
}));
cuda::proclaim_return_type<bool>(
[m_view] __device__(cuco::pair<Key, Value> const& pair) mutable {
return m_view.insert(
pair, hash_custom_key{}, custom_key_equals{});
})));
}

SECTION("Cannot find any key in an empty hash map")
Expand All @@ -235,18 +243,21 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type",
REQUIRE(cuco::test::all_of(
insert_pairs,
insert_pairs + num,
[view] __device__(cuco::pair<Key, Value> const& pair) mutable {
return view.find(pair.first, hash_custom_key{}, custom_key_equals{}) == view.end();
}));
cuda::proclaim_return_type<bool>(
[view] __device__(cuco::pair<Key, Value> const& pair) mutable {
return view.find(pair.first, hash_custom_key{}, custom_key_equals{}) == view.end();
})));
}

SECTION("const view")
{
auto const view = map.get_device_view();
REQUIRE(cuco::test::all_of(
insert_pairs, insert_pairs + num, [view] __device__(cuco::pair<Key, Value> const& pair) {
insert_pairs,
insert_pairs + num,
cuda::proclaim_return_type<bool>([view] __device__(cuco::pair<Key, Value> const& pair) {
return view.find(pair.first, hash_custom_key{}, custom_key_equals{}) == view.end();
}));
})));
}
}
}
5 changes: 4 additions & 1 deletion tests/static_map/duplicate_keys_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

#include <catch2/catch_template_test_macros.hpp>

#include <cuda/functional>

using size_type = std::size_t;

TEMPLATE_TEST_CASE_SIG(
Expand Down Expand Up @@ -83,7 +85,8 @@ TEMPLATE_TEST_CASE_SIG(

auto pairs_begin = thrust::make_transform_iterator(
thrust::make_counting_iterator<int>(0),
[] __device__(auto i) { return cuco::pair<Key, Value>(i / 2, i / 2); });
cuda::proclaim_return_type<cuco::pair<Key, Value>>(
[] __device__(auto i) { return cuco::pair<Key, Value>(i / 2, i / 2); }));

thrust::device_vector<Value> d_results(num_keys);
thrust::device_vector<bool> d_contained(num_keys);
Expand Down
7 changes: 5 additions & 2 deletions tests/static_map/heterogeneous_lookup_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

#include <catch2/catch_template_test_macros.hpp>

#include <cuda/functional>

#include <tuple>

// insert key type
Expand Down Expand Up @@ -115,8 +117,9 @@ TEMPLATE_TEST_CASE_SIG("Heterogeneous lookup",
auto insert_pairs = thrust::make_transform_iterator(
thrust::counting_iterator<int>(0),
[] __device__(auto i) { return cuco::pair<InsertKey, Value>(i, i); });
auto probe_keys = thrust::make_transform_iterator(thrust::counting_iterator<int>(0),
[] __device__(auto i) { return ProbeKey(i); });
auto probe_keys = thrust::make_transform_iterator(
thrust::counting_iterator<int>(0),
cuda::proclaim_return_type<ProbeKey>([] __device__(auto i) { return ProbeKey{i}; }));

SECTION("All inserted keys-value pairs should be contained")
{
Expand Down
9 changes: 6 additions & 3 deletions tests/static_map/insert_and_find_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

#include <catch2/catch_template_test_macros.hpp>

#include <cuda/functional>

static constexpr int Iters = 10'000;

template <typename Ref>
Expand Down Expand Up @@ -129,7 +131,8 @@ TEMPLATE_TEST_CASE_SIG(
thrust::sequence(thrust::device, d_keys.begin(), d_keys.end());
map.find(d_keys.begin(), d_keys.end(), d_values.begin());

REQUIRE(cuco::test::all_of(d_values.begin(), d_values.end(), [] __device__(Value v) {
return v == (Blocks * Threads) / CGSize;
}));
REQUIRE(cuco::test::all_of(
d_values.begin(), d_values.end(), cuda::proclaim_return_type<bool>([] __device__(Value v) {
return v == (Blocks * Threads) / CGSize;
})));
}
15 changes: 10 additions & 5 deletions tests/static_map/insert_or_assign_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

#include <catch2/catch_template_test_macros.hpp>

#include <cuda/functional>

using size_type = std::size_t;

template <typename Map>
Expand All @@ -36,9 +38,11 @@ __inline__ void test_insert_or_assign(Map& map, size_type num_keys)
using Value = typename Map::mapped_type;

// Insert pairs
auto pairs_begin =
thrust::make_transform_iterator(thrust::counting_iterator<size_type>(0),
[] __device__(auto i) { return cuco::pair<Key, Value>(i, i); });
auto pairs_begin = thrust::make_transform_iterator(
thrust::counting_iterator<size_type>(0),
cuda::proclaim_return_type<cuco::pair<Key, Value>>([] __device__(auto i) {
return cuco::pair<Key, Value>{i, i};
}));

auto const initial_size = map.insert(pairs_begin, pairs_begin + num_keys);
REQUIRE(initial_size == num_keys); // all keys should be inserted
Expand All @@ -58,8 +62,9 @@ __inline__ void test_insert_or_assign(Map& map, size_type num_keys)
thrust::device_vector<Key> d_values(num_keys);
map.retrieve_all(d_keys.begin(), d_values.begin());

auto gold_values_begin = thrust::make_transform_iterator(thrust::counting_iterator<size_type>(0),
[] __device__(auto i) { return i * 2; });
auto gold_values_begin = thrust::make_transform_iterator(
thrust::counting_iterator<size_type>(0),
cuda::proclaim_return_type<size_type>([] __device__(auto i) { return i * 2; }));

thrust::sort(thrust::device, d_values.begin(), d_values.end());
REQUIRE(cuco::test::equal(
Expand Down
37 changes: 21 additions & 16 deletions tests/static_map/key_sentinel_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

#include <catch2/catch_template_test_macros.hpp>

#include <cuda/functional>

#define SIZE 10
__device__ int A[SIZE];

Expand Down Expand Up @@ -55,32 +57,35 @@ TEMPLATE_TEST_CASE_SIG(
}
CUCO_CUDA_TRY(cudaMemcpyToSymbol(A, h_A, SIZE * sizeof(int)));

auto pairs_begin =
thrust::make_transform_iterator(thrust::make_counting_iterator<T>(0),
[] __device__(auto i) { return cuco::pair<Key, Value>(i, i); });
auto pairs_begin = thrust::make_transform_iterator(
thrust::make_counting_iterator<T>(0),
cuda::proclaim_return_type<cuco::pair<Key, Value>>(
[] __device__(auto i) { return cuco::pair<Key, Value>(i, i); }));

SECTION(
"Tests of non-CG insert: The custom `key_equal` can never be used to compare against sentinel")
{
REQUIRE(cuco::test::all_of(pairs_begin,
pairs_begin + num_keys,
[insert_ref] __device__(cuco::pair<Key, Value> const& pair) mutable {
return insert_ref.insert(pair);
}));
REQUIRE(
cuco::test::all_of(pairs_begin,
pairs_begin + num_keys,
cuda::proclaim_return_type<bool>(
[insert_ref] __device__(cuco::pair<Key, Value> const& pair) mutable {
return insert_ref.insert(pair);
})));
}

SECTION(
"Tests of CG insert: The custom `key_equal` can never be used to compare against sentinel")
{
map.insert(pairs_begin, pairs_begin + num_keys);
// All keys inserted via custom `key_equal` should be found
REQUIRE(cuco::test::all_of(pairs_begin,
pairs_begin + num_keys,
[find_ref] __device__(cuco::pair<Key, Value> const& pair) {
auto const found = find_ref.find(pair.first);
return (found != find_ref.end()) and
(found->first == pair.first and
found->second == pair.second);
}));
REQUIRE(cuco::test::all_of(
pairs_begin,
pairs_begin + num_keys,
cuda::proclaim_return_type<bool>([find_ref] __device__(cuco::pair<Key, Value> const& pair) {
auto const found = find_ref.find(pair.first);
return (found != find_ref.end()) and
(found->first == pair.first and found->second == pair.second);
})));
}
}
10 changes: 7 additions & 3 deletions tests/static_map/shared_memory_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

#include <catch2/catch_template_test_macros.hpp>

#include <cuda/functional>

#include <limits>

template <typename MapType, int CAPACITY>
Expand Down Expand Up @@ -126,9 +128,11 @@ TEMPLATE_TEST_CASE_SIG("Shared memory static map",
auto zip = thrust::make_zip_iterator(
thrust::make_tuple(d_keys_exist.begin(), d_keys_and_values_correct.begin()));

REQUIRE(cuco::test::all_of(zip, zip + d_keys_exist.size(), [] __device__(auto const& z) {
return thrust::get<0>(z) and thrust::get<1>(z);
}));
REQUIRE(cuco::test::all_of(zip,
zip + d_keys_exist.size(),
cuda::proclaim_return_type<bool>([] __device__(auto const& z) {
return thrust::get<0>(z) and thrust::get<1>(z);
})));
}

SECTION("No key is found before insertion.")
Expand Down
20 changes: 12 additions & 8 deletions tests/static_map/stream_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

#include <catch2/catch_template_test_macros.hpp>

#include <cuda/functional>

TEMPLATE_TEST_CASE_SIG("static_map: unique sequence of keys on given stream",
"",
((typename Key, typename Value), Key, Value),
Expand Down Expand Up @@ -56,9 +58,10 @@ TEMPLATE_TEST_CASE_SIG("static_map: unique sequence of keys on given stream",
thrust::sequence(thrust::device, d_keys.begin(), d_keys.end());
thrust::sequence(thrust::device, d_values.begin(), d_values.end());

auto pairs_begin =
thrust::make_transform_iterator(thrust::make_counting_iterator<int>(0),
[] __device__(auto i) { return cuco::pair<Key, Value>(i, i); });
auto pairs_begin = thrust::make_transform_iterator(
thrust::make_counting_iterator<int>(0),
cuda::proclaim_return_type<cuco::pair<Key, Value>>(
[] __device__(auto i) { return cuco::pair<Key, Value>(i, i); }));

// bulk function test cases
SECTION("All inserted keys-value pairs should be correctly recovered during find")
Expand All @@ -69,11 +72,12 @@ TEMPLATE_TEST_CASE_SIG("static_map: unique sequence of keys on given stream",
map.find(d_keys.begin(), d_keys.end(), d_results.begin(), stream);
auto zip = thrust::make_zip_iterator(thrust::make_tuple(d_results.begin(), d_values.begin()));

REQUIRE(cuco::test::all_of(
zip,
zip + num_keys,
[] __device__(auto const& p) { return thrust::get<0>(p) == thrust::get<1>(p); },
stream));
REQUIRE(cuco::test::all_of(zip,
zip + num_keys,
cuda::proclaim_return_type<bool>([] __device__(auto const& p) {
return thrust::get<0>(p) == thrust::get<1>(p);
}),
stream));
}

SECTION("All inserted keys-value pairs should be contained")
Expand Down
Loading

0 comments on commit 6fa4e56

Please sign in to comment.