diff --git a/include/cuco/detail/static_multimap/static_multimap.inl b/include/cuco/detail/static_multimap/static_multimap.inl index 9d0bdbcff..7048a5426 100644 --- a/include/cuco/detail/static_multimap/static_multimap.inl +++ b/include/cuco/detail/static_multimap/static_multimap.inl @@ -284,6 +284,37 @@ void static_multimapcontains_if_async(first, last, stencil, pred, output_begin, ref(op::contains), stream); } +template +template +void static_multimap::find( + InputIt first, InputIt last, OutputIt output_begin, cuda::stream_ref stream) const +{ + this->find_async(first, last, output_begin, stream); + stream.wait(); +} + +template +template +void static_multimap:: + find_async(InputIt first, InputIt last, OutputIt output_begin, cuda::stream_ref stream) const +{ + impl_->find_async(first, last, output_begin, ref(op::find), stream); +} + template +class operator_impl< + op::find_tag, + static_multimap_ref> { + using base_type = static_multimap_ref; + using ref_type = + static_multimap_ref; + using key_type = typename base_type::key_type; + using value_type = typename base_type::value_type; + using iterator = typename base_type::iterator; + using const_iterator = typename base_type::const_iterator; + + static constexpr auto cg_size = base_type::cg_size; + static constexpr auto window_size = base_type::window_size; + + public: + /** + * @brief Finds an element in the map with key equivalent to the probe key. + * + * @note Returns a un-incrementable input iterator to the element whose key is equivalent to + * `key`. If no such element exists, returns `end()`. + * + * @tparam ProbeKey Probe key type + * + * @param key The key to search for + * + * @return An iterator to the position at which the equivalent key is stored + */ + template + [[nodiscard]] __device__ const_iterator find(ProbeKey const& key) const noexcept + { + // CRTP: cast `this` to the actual ref type + auto const& ref_ = static_cast(*this); + return ref_.impl_.find(key); + } + + /** + * @brief Finds an element in the map with key equivalent to the probe key. + * + * @note Returns a un-incrementable input iterator to the element whose key is equivalent to + * `key`. If no such element exists, returns `end()`. + * + * @tparam ProbeKey Probe key type + * + * @param group The Cooperative Group used to perform this operation + * @param key The key to search for + * + * @return An iterator to the position at which the equivalent key is stored + */ + template + [[nodiscard]] __device__ const_iterator find( + cooperative_groups::thread_block_tile const& group, ProbeKey const& key) const noexcept + { + auto const& ref_ = static_cast(*this); + return ref_.impl_.find(group, key); + } +}; + template + void find(InputIt first, InputIt last, OutputIt output_begin, cuda::stream_ref stream = {}) const; + + /** + * @brief For all keys in the range `[first, last)`, asynchronously finds a payload with its key + * equivalent to the query key. + * + * @note If the key `*(first + i)` has a matched `element` in the map, copies the payload of + * `element` to `(output_begin + i)`. Else, copies the empty value sentinel. + * @note For a given key `*(first + i)`, if there are multiple matching elements in the multimap, + * it copies the payload of one match (unspecified which) to `(output_begin + i)`. If no match is + * found, it copies the empty value sentinel instead. + * + * @tparam InputIt Device accessible input iterator + * @tparam OutputIt Device accessible output iterator assignable from the map's `mapped_type` + * + * @param first Beginning of the sequence of keys + * @param last End of the sequence of keys + * @param output_begin Beginning of the sequence of payloads retrieved for each key + * @param stream Stream used for executing the kernels + */ + template + void find_async(InputIt first, + InputIt last, + OutputIt output_begin, + cuda::stream_ref stream = {}) const; + /** * @brief Counts the occurrences of keys in `[first, last)` contained in the multimap * diff --git a/include/cuco/static_multiset.cuh b/include/cuco/static_multiset.cuh index c4b1dff76..943465c51 100644 --- a/include/cuco/static_multiset.cuh +++ b/include/cuco/static_multiset.cuh @@ -437,12 +437,15 @@ class static_multiset { cuda::stream_ref stream = {}) const noexcept; /** - * @brief For all keys in the range `[first, last)`, finds an element with key equivalent to the - * query key. + * @brief For all keys in the range `[first, last)`, finds an element with its key equivalent to + * the query key. * * @note This function synchronizes the given stream. For asynchronous execution use `find_async`. * @note If the key `*(first + i)` has a matched `element` in the multiset, copies `element` to * `(output_begin + i)`. Else, copies the empty key sentinel. + * @note For a given key `*(first + i)`, if there are multiple matching elements in the multiset, + * it copies the payload of one match (unspecified which) to `(output_begin + i)`. If no match is + * found, it copies the empty key sentinel instead. * * @tparam InputIt Device accessible input iterator * @tparam OutputIt Device accessible output iterator assignable from the set's `key_type` @@ -456,11 +459,14 @@ class static_multiset { void find(InputIt first, InputIt last, OutputIt output_begin, cuda::stream_ref stream = {}) const; /** - * @brief For all keys in the range `[first, last)`, asynchronously finds an element with key + * @brief For all keys in the range `[first, last)`, asynchronously finds an element with its key * equivalent to the query key. * * @note If the key `*(first + i)` has a matched `element` in the multiset, copies `element` to * `(output_begin + i)`. Else, copies the empty key sentinel. + * @note For a given key `*(first + i)`, if there are multiple matching elements in the multiset, + * it copies the payload of one match (unspecified which) to `(output_begin + i)`. If no match is + * found, it copies the empty key sentinel instead. * * @tparam InputIt Device accessible input iterator * @tparam OutputIt Device accessible output iterator assignable from the set's `key_type` diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 1b28746ea..bc7cc697f 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -114,6 +114,7 @@ ConfigureTest(STATIC_MULTIMAP_TEST static_multimap/count_test.cu static_multimap/custom_pair_retrieve_test.cu static_multimap/custom_type_test.cu + static_multimap/find_test.cu static_multimap/heterogeneous_lookup_test.cu static_multimap/insert_contains_test.cu static_multimap/insert_if_test.cu diff --git a/tests/static_multimap/find_test.cu b/tests/static_multimap/find_test.cu new file mode 100644 index 000000000..51456b088 --- /dev/null +++ b/tests/static_multimap/find_test.cu @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +#include +#include +#include +#include +#include + +#include + +using size_type = int32_t; + +template +void test_multimap_find(Map& map, size_type num_keys) +{ + using Key = typename Map::key_type; + using Value = typename Map::mapped_type; + + auto zip_equal = cuda::proclaim_return_type( + [] __device__(auto val) { return thrust::get<0>(val) == thrust::get<1>(val); }); + + auto const keys_begin = thrust::counting_iterator{0}; + + SECTION("Non-inserted keys have no matches") + { + thrust::device_vector found_vals(num_keys); + + map.find(keys_begin, keys_begin + num_keys, found_vals.begin()); + auto zip = thrust::make_zip_iterator(thrust::make_tuple( + found_vals.begin(), thrust::constant_iterator{map.empty_value_sentinel()})); + + REQUIRE(cuco::test::all_of(zip, zip + num_keys, zip_equal)); + } + + auto const pairs_begin = thrust::make_transform_iterator( + thrust::make_counting_iterator(0), + cuda::proclaim_return_type>( + [] __device__(auto i) { return cuco::pair{i, i * 2}; })); + + map.insert(pairs_begin, pairs_begin + num_keys); + + SECTION("All inserted keys should be correctly recovered during find") + { + thrust::device_vector found_vals(num_keys); + + map.find(keys_begin, keys_begin + num_keys, found_vals.begin()); + + auto const gold_vals_begin = thrust::make_transform_iterator( + thrust::make_counting_iterator(0), + cuda::proclaim_return_type([] __device__(auto i) { return Value{i * 2}; })); + auto zip = thrust::make_zip_iterator(thrust::make_tuple(found_vals.begin(), gold_vals_begin)); + + REQUIRE(cuco::test::all_of(zip, zip + num_keys, zip_equal)); + } +} + +TEMPLATE_TEST_CASE_SIG( + "static_multimap find tests", + "", + ((typename T, cuco::test::probe_sequence Probe, int CGSize), T, Probe, CGSize), + (int32_t, cuco::test::probe_sequence::double_hashing, 1), + (int32_t, cuco::test::probe_sequence::double_hashing, 2), + (int64_t, cuco::test::probe_sequence::double_hashing, 1), + (int64_t, cuco::test::probe_sequence::double_hashing, 2), + (int32_t, cuco::test::probe_sequence::linear_probing, 1), + (int32_t, cuco::test::probe_sequence::linear_probing, 2), + (int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int64_t, cuco::test::probe_sequence::linear_probing, 2)) +{ + constexpr size_type num_keys{1'000}; + + using probe = std::conditional_t< + Probe == cuco::test::probe_sequence::linear_probing, + cuco::linear_probing>, + cuco::double_hashing, cuco::default_hash_function>>; + + auto map = cuco::experimental::static_multimap, + cuda::thread_scope_device, + thrust::equal_to, + probe, + cuco::cuda_allocator, + cuco::storage<2>>{ + num_keys, cuco::empty_key{-1}, cuco::empty_value{-2}}; + + test_multimap_find(map, num_keys); +}