Skip to content

Commit

Permalink
Merge branch 'dev' into add-count
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel authored Oct 30, 2024
2 parents 299f083 + 494f321 commit 1c6a7af
Show file tree
Hide file tree
Showing 6 changed files with 257 additions and 3 deletions.
31 changes: 31 additions & 0 deletions include/cuco/detail/static_multimap/static_multimap.inl
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,37 @@ void static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator,
impl_->contains_if_async(first, last, stencil, pred, output_begin, ref(op::contains), stream);
}

template <class Key,
class T,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt, typename OutputIt>
void static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::find(
InputIt first, InputIt last, OutputIt output_begin, cuda::stream_ref stream) const
{
this->find_async(first, last, output_begin, stream);
stream.wait();
}

template <class Key,
class T,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt, typename OutputIt>
void static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
find_async(InputIt first, InputIt last, OutputIt output_begin, cuda::stream_ref stream) const
{
impl_->find_async(first, last, output_begin, ref(op::find), stream);
}

template <class Key,
class T,
class Extent,
Expand Down
64 changes: 64 additions & 0 deletions include/cuco/detail/static_multimap/static_multimap_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,70 @@ class operator_impl<
}
};

template <typename Key,
typename T,
cuda::thread_scope Scope,
typename KeyEqual,
typename ProbingScheme,
typename StorageRef,
typename... Operators>
class operator_impl<
op::find_tag,
static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>> {
using base_type = static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef>;
using ref_type =
static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>;
using key_type = typename base_type::key_type;
using value_type = typename base_type::value_type;
using iterator = typename base_type::iterator;
using const_iterator = typename base_type::const_iterator;

static constexpr auto cg_size = base_type::cg_size;
static constexpr auto window_size = base_type::window_size;

public:
/**
* @brief Finds an element in the map with key equivalent to the probe key.
*
* @note Returns a un-incrementable input iterator to the element whose key is equivalent to
* `key`. If no such element exists, returns `end()`.
*
* @tparam ProbeKey Probe key type
*
* @param key The key to search for
*
* @return An iterator to the position at which the equivalent key is stored
*/
template <typename ProbeKey>
[[nodiscard]] __device__ const_iterator find(ProbeKey const& key) const noexcept
{
// CRTP: cast `this` to the actual ref type
auto const& ref_ = static_cast<ref_type const&>(*this);
return ref_.impl_.find(key);
}

/**
* @brief Finds an element in the map with key equivalent to the probe key.
*
* @note Returns a un-incrementable input iterator to the element whose key is equivalent to
* `key`. If no such element exists, returns `end()`.
*
* @tparam ProbeKey Probe key type
*
* @param group The Cooperative Group used to perform this operation
* @param key The key to search for
*
* @return An iterator to the position at which the equivalent key is stored
*/
template <typename ProbeKey>
[[nodiscard]] __device__ const_iterator find(
cooperative_groups::thread_block_tile<cg_size> const& group, ProbeKey const& key) const noexcept
{
auto const& ref_ = static_cast<ref_type const&>(*this);
return ref_.impl_.find(group, key);
}
};

template <typename Key,
typename T,
cuda::thread_scope Scope,
Expand Down
46 changes: 46 additions & 0 deletions include/cuco/static_multimap.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,52 @@ class static_multimap {
OutputIt output_begin,
cuda::stream_ref stream = {}) const noexcept;

/**
* @brief For all keys in the range `[first, last)`, finds a payload with its key equivalent to
* the query key.
*
* @note This function synchronizes the given stream. For asynchronous execution use `find_async`.
* @note If the key `*(first + i)` has a matched `element` in the map, copies the payload of
* `element` to `(output_begin + i)`. Else, copies the empty value sentinel.
* @note For a given key `*(first + i)`, if there are multiple matching elements in the multimap,
* it copies the payload of one match (unspecified which) to `(output_begin + i)`. If no match is
* found, it copies the empty value sentinel instead.
*
* @tparam InputIt Device accessible input iterator
* @tparam OutputIt Device accessible output iterator assignable from the map's `mapped_type`
*
* @param first Beginning of the sequence of keys
* @param last End of the sequence of keys
* @param output_begin Beginning of the sequence of payloads retrieved for each key
* @param stream Stream used for executing the kernels
*/
template <typename InputIt, typename OutputIt>
void find(InputIt first, InputIt last, OutputIt output_begin, cuda::stream_ref stream = {}) const;

/**
* @brief For all keys in the range `[first, last)`, asynchronously finds a payload with its key
* equivalent to the query key.
*
* @note If the key `*(first + i)` has a matched `element` in the map, copies the payload of
* `element` to `(output_begin + i)`. Else, copies the empty value sentinel.
* @note For a given key `*(first + i)`, if there are multiple matching elements in the multimap,
* it copies the payload of one match (unspecified which) to `(output_begin + i)`. If no match is
* found, it copies the empty value sentinel instead.
*
* @tparam InputIt Device accessible input iterator
* @tparam OutputIt Device accessible output iterator assignable from the map's `mapped_type`
*
* @param first Beginning of the sequence of keys
* @param last End of the sequence of keys
* @param output_begin Beginning of the sequence of payloads retrieved for each key
* @param stream Stream used for executing the kernels
*/
template <typename InputIt, typename OutputIt>
void find_async(InputIt first,
InputIt last,
OutputIt output_begin,
cuda::stream_ref stream = {}) const;

/**
* @brief Counts the occurrences of keys in `[first, last)` contained in the multimap
*
Expand Down
12 changes: 9 additions & 3 deletions include/cuco/static_multiset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -437,12 +437,15 @@ class static_multiset {
cuda::stream_ref stream = {}) const noexcept;

/**
* @brief For all keys in the range `[first, last)`, finds an element with key equivalent to the
* query key.
* @brief For all keys in the range `[first, last)`, finds an element with its key equivalent to
* the query key.
*
* @note This function synchronizes the given stream. For asynchronous execution use `find_async`.
* @note If the key `*(first + i)` has a matched `element` in the multiset, copies `element` to
* `(output_begin + i)`. Else, copies the empty key sentinel.
* @note For a given key `*(first + i)`, if there are multiple matching elements in the multiset,
* it copies the payload of one match (unspecified which) to `(output_begin + i)`. If no match is
* found, it copies the empty key sentinel instead.
*
* @tparam InputIt Device accessible input iterator
* @tparam OutputIt Device accessible output iterator assignable from the set's `key_type`
Expand All @@ -456,11 +459,14 @@ class static_multiset {
void find(InputIt first, InputIt last, OutputIt output_begin, cuda::stream_ref stream = {}) const;

/**
* @brief For all keys in the range `[first, last)`, asynchronously finds an element with key
* @brief For all keys in the range `[first, last)`, asynchronously finds an element with its key
* equivalent to the query key.
*
* @note If the key `*(first + i)` has a matched `element` in the multiset, copies `element` to
* `(output_begin + i)`. Else, copies the empty key sentinel.
* @note For a given key `*(first + i)`, if there are multiple matching elements in the multiset,
* it copies the payload of one match (unspecified which) to `(output_begin + i)`. If no match is
* found, it copies the empty key sentinel instead.
*
* @tparam InputIt Device accessible input iterator
* @tparam OutputIt Device accessible output iterator assignable from the set's `key_type`
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ ConfigureTest(STATIC_MULTIMAP_TEST
static_multimap/count_test.cu
static_multimap/custom_pair_retrieve_test.cu
static_multimap/custom_type_test.cu
static_multimap/find_test.cu
static_multimap/heterogeneous_lookup_test.cu
static_multimap/insert_contains_test.cu
static_multimap/insert_if_test.cu
Expand Down
106 changes: 106 additions & 0 deletions tests/static_multimap/find_test.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <test_utils.hpp>

#include <cuco/static_multimap.cuh>

#include <cuda/functional>
#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/transform_iterator.h>

#include <catch2/catch_template_test_macros.hpp>

using size_type = int32_t;

template <typename Map>
void test_multimap_find(Map& map, size_type num_keys)
{
using Key = typename Map::key_type;
using Value = typename Map::mapped_type;

auto zip_equal = cuda::proclaim_return_type<bool>(
[] __device__(auto val) { return thrust::get<0>(val) == thrust::get<1>(val); });

auto const keys_begin = thrust::counting_iterator<Key>{0};

SECTION("Non-inserted keys have no matches")
{
thrust::device_vector<Value> found_vals(num_keys);

map.find(keys_begin, keys_begin + num_keys, found_vals.begin());
auto zip = thrust::make_zip_iterator(thrust::make_tuple(
found_vals.begin(), thrust::constant_iterator<Value>{map.empty_value_sentinel()}));

REQUIRE(cuco::test::all_of(zip, zip + num_keys, zip_equal));
}

auto const pairs_begin = thrust::make_transform_iterator(
thrust::make_counting_iterator<size_type>(0),
cuda::proclaim_return_type<cuco::pair<Key, Value>>(
[] __device__(auto i) { return cuco::pair<Key, Value>{i, i * 2}; }));

map.insert(pairs_begin, pairs_begin + num_keys);

SECTION("All inserted keys should be correctly recovered during find")
{
thrust::device_vector<Value> found_vals(num_keys);

map.find(keys_begin, keys_begin + num_keys, found_vals.begin());

auto const gold_vals_begin = thrust::make_transform_iterator(
thrust::make_counting_iterator<size_type>(0),
cuda::proclaim_return_type<Value>([] __device__(auto i) { return Value{i * 2}; }));
auto zip = thrust::make_zip_iterator(thrust::make_tuple(found_vals.begin(), gold_vals_begin));

REQUIRE(cuco::test::all_of(zip, zip + num_keys, zip_equal));
}
}

TEMPLATE_TEST_CASE_SIG(
"static_multimap find tests",
"",
((typename T, cuco::test::probe_sequence Probe, int CGSize), T, Probe, CGSize),
(int32_t, cuco::test::probe_sequence::double_hashing, 1),
(int32_t, cuco::test::probe_sequence::double_hashing, 2),
(int64_t, cuco::test::probe_sequence::double_hashing, 1),
(int64_t, cuco::test::probe_sequence::double_hashing, 2),
(int32_t, cuco::test::probe_sequence::linear_probing, 1),
(int32_t, cuco::test::probe_sequence::linear_probing, 2),
(int64_t, cuco::test::probe_sequence::linear_probing, 1),
(int64_t, cuco::test::probe_sequence::linear_probing, 2))
{
constexpr size_type num_keys{1'000};

using probe = std::conditional_t<
Probe == cuco::test::probe_sequence::linear_probing,
cuco::linear_probing<CGSize, cuco::default_hash_function<T>>,
cuco::double_hashing<CGSize, cuco::default_hash_function<T>, cuco::default_hash_function<T>>>;

auto map = cuco::experimental::static_multimap<T,
T,
cuco::extent<size_type>,
cuda::thread_scope_device,
thrust::equal_to<T>,
probe,
cuco::cuda_allocator<cuda::std::byte>,
cuco::storage<2>>{
num_keys, cuco::empty_key<T>{-1}, cuco::empty_value<T>{-2}};

test_multimap_find(map, num_keys);
}

0 comments on commit 1c6a7af

Please sign in to comment.