From 0cd4da08be0289b20306ec44a68044668730c0a9 Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Fri, 15 Sep 2023 10:37:40 -0700 Subject: [PATCH 1/5] Clean up ref implementations with `has_payload` flag (#368) #356 introduces the `HasPayload` template boolean to distinguish code paths between map and set implementations thus the key input for base ref insert functions becomes redundant. This PR cleans up the base ref implementations by removing the key input and fixes a logical issue in #356: set doesn't have payload while map has. --- .../cuco/detail/open_addressing_ref_impl.cuh | 55 +++++++++++++------ .../cuco/detail/static_map/static_map_ref.inl | 16 +++--- .../cuco/detail/static_set/static_set_ref.inl | 16 +++--- 3 files changed, 54 insertions(+), 33 deletions(-) diff --git a/include/cuco/detail/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing_ref_impl.cuh index 46ef2bfd7..213d35af1 100644 --- a/include/cuco/detail/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing_ref_impl.cuh @@ -159,18 +159,23 @@ class open_addressing_ref_impl { * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Predicate Predicate type * - * @param key Key of the element to insert * @param value The element to insert * @param predicate Predicate used to compare slot content against `key` * * @return True if the given element is successfully inserted */ template - __device__ bool insert(key_type const& key, - value_type const& value, - Predicate const& predicate) noexcept + __device__ bool insert(value_type const& value, Predicate const& predicate) noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); + + auto const key = [&]() { + if constexpr (HasPayload) { + return value.first; + } else { + return value; + } + }(); auto probing_iter = probing_scheme_(key, storage_ref_.window_extent()); while (true) { @@ -202,7 +207,6 @@ class open_addressing_ref_impl { * @tparam 
Predicate Predicate type * * @param group The Cooperative Group used to perform group insert - * @param key Key of the element to insert * @param value The element to insert * @param predicate Predicate used to compare slot content against `key` * @@ -210,10 +214,16 @@ class open_addressing_ref_impl { */ template __device__ bool insert(cooperative_groups::thread_block_tile const& group, - key_type const& key, value_type const& value, Predicate const& predicate) noexcept { + auto const key = [&]() { + if constexpr (HasPayload) { + return value.first; + } else { + return value; + } + }(); auto probing_iter = probing_scheme_(group, key, storage_ref_.window_extent()); while (true) { @@ -269,7 +279,6 @@ class open_addressing_ref_impl { * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Predicate Predicate type * - * @param key Key of the element to insert * @param value The element to insert * @param predicate Predicate used to compare slot content against `key` * @@ -277,11 +286,18 @@ class open_addressing_ref_impl { * insertion is successful or not. 
*/ template - __device__ thrust::pair insert_and_find(key_type const& key, - value_type const& value, + __device__ thrust::pair insert_and_find(value_type const& value, Predicate const& predicate) noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); + + auto const key = [&]() { + if constexpr (HasPayload) { + return value.first; + } else { + return value; + } + }(); auto probing_iter = probing_scheme_(key, storage_ref_.window_extent()); while (true) { @@ -326,7 +342,6 @@ class open_addressing_ref_impl { * @tparam Predicate Predicate type * * @param group The Cooperative Group used to perform group insert_and_find - * @param key Key of the element to insert * @param value The element to insert * @param predicate Predicate used to compare slot content against `key` * @@ -336,10 +351,16 @@ class open_addressing_ref_impl { template __device__ thrust::pair insert_and_find( cooperative_groups::thread_block_tile const& group, - key_type const& key, value_type const& value, Predicate const& predicate) noexcept { + auto const key = [&]() { + if constexpr (HasPayload) { + return value.first; + } else { + return value; + } + }(); auto probing_iter = probing_scheme_(group, key, storage_ref_.window_extent()); while (true) { @@ -710,11 +731,11 @@ class open_addressing_ref_impl { auto* old_ptr = reinterpret_cast(&old); auto const inserted = [&]() { if constexpr (HasPayload) { - // If it's a set implementation, compare the whole slot content - return cuco::detail::bitwise_compare(*old_ptr, this->empty_slot_sentinel_); - } else { // If it's a map implementation, compare keys only return cuco::detail::bitwise_compare(old_ptr->first, this->empty_slot_sentinel_.first); + } else { + // If it's a set implementation, compare the whole slot content + return cuco::detail::bitwise_compare(*old_ptr, this->empty_slot_sentinel_); } }(); if (inserted) { @@ -723,11 +744,11 @@ class open_addressing_ref_impl { // Shouldn't use `predicate` 
operator directly since it includes a redundant bitwise compare auto const res = [&]() { if constexpr (HasPayload) { - // If it's a set implementation, compare the whole slot content - return predicate.equal_to(*old_ptr, value); - } else { // If it's a map implementation, compare keys only return predicate.equal_to(old_ptr->first, value.first); + } else { + // If it's a set implementation, compare the whole slot content + return predicate.equal_to(*old_ptr, value); } }(); return res == detail::equal_result::EQUAL ? insert_result::DUPLICATE diff --git a/include/cuco/detail/static_map/static_map_ref.inl b/include/cuco/detail/static_map/static_map_ref.inl index 13fc2ce47..250c84feb 100644 --- a/include/cuco/detail/static_map/static_map_ref.inl +++ b/include/cuco/detail/static_map/static_map_ref.inl @@ -210,8 +210,8 @@ class operator_impl< __device__ bool insert(value_type const& value) noexcept { ref_type& ref_ = static_cast(*this); - auto constexpr has_payload = false; - return ref_.impl_.insert(value.first, value, ref_.predicate_); + auto constexpr has_payload = true; + return ref_.impl_.insert(value, ref_.predicate_); } /** @@ -225,8 +225,8 @@ class operator_impl< value_type const& value) noexcept { auto& ref_ = static_cast(*this); - auto constexpr has_payload = false; - return ref_.impl_.insert(group, value.first, value, ref_.predicate_); + auto constexpr has_payload = true; + return ref_.impl_.insert(group, value, ref_.predicate_); } }; @@ -454,8 +454,8 @@ class operator_impl< __device__ thrust::pair insert_and_find(value_type const& value) noexcept { ref_type& ref_ = static_cast(*this); - auto constexpr has_payload = false; - return ref_.impl_.insert_and_find(value.first, value, ref_.predicate_); + auto constexpr has_payload = true; + return ref_.impl_.insert_and_find(value, ref_.predicate_); } /** @@ -475,8 +475,8 @@ class operator_impl< cooperative_groups::thread_block_tile const& group, value_type const& value) noexcept { ref_type& ref_ = static_cast(*this); 
- auto constexpr has_payload = false; - return ref_.impl_.insert_and_find(group, value.first, value, ref_.predicate_); + auto constexpr has_payload = true; + return ref_.impl_.insert_and_find(group, value, ref_.predicate_); } }; diff --git a/include/cuco/detail/static_set/static_set_ref.inl b/include/cuco/detail/static_set/static_set_ref.inl index 3131f3764..2bb7f0c6f 100644 --- a/include/cuco/detail/static_set/static_set_ref.inl +++ b/include/cuco/detail/static_set/static_set_ref.inl @@ -101,8 +101,8 @@ class operator_impl(*this); - auto constexpr has_payload = true; - return ref_.impl_.insert(value, value, ref_.predicate_); + auto constexpr has_payload = false; + return ref_.impl_.insert(value, ref_.predicate_); } /** @@ -117,8 +117,8 @@ class operator_impl(*this); - auto constexpr has_payload = true; - return ref_.impl_.insert(group, value, value, ref_.predicate_); + auto constexpr has_payload = false; + return ref_.impl_.insert(group, value, ref_.predicate_); } }; @@ -182,8 +182,8 @@ class operator_impl insert_and_find(value_type const& value) noexcept { ref_type& ref_ = static_cast(*this); - auto constexpr has_payload = true; - return ref_.impl_.insert_and_find(value, value, ref_.predicate_); + auto constexpr has_payload = false; + return ref_.impl_.insert_and_find(value, ref_.predicate_); } /** @@ -203,8 +203,8 @@ class operator_impl const& group, value_type const& value) noexcept { ref_type& ref_ = static_cast(*this); - auto constexpr has_payload = true; - return ref_.impl_.insert_and_find(group, value, value, ref_.predicate_); + auto constexpr has_payload = false; + return ref_.impl_.insert_and_find(group, value, ref_.predicate_); } }; From 359f5ae67e93b69a8df35ebd1d12f746aac8916e Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Tue, 26 Sep 2023 13:13:44 -0700 Subject: [PATCH 2/5] Add device subsets example (#346) Depends on #349 This PR adds an example demonstrating how to create multiple subsets with one single storage. 
It includes necessary changes and cleanups that will unblock orc/parquet dictionary encoding (https://github.com/rapidsai/cudf/issues/12261) to use the new map/set data structures. --------- Co-authored-by: Daniel Juenger <2955913+sleeepyjack@users.noreply.github.com> --- examples/CMakeLists.txt | 1 + examples/static_set/device_ref_example.cu | 16 +- examples/static_set/device_subsets_example.cu | 183 ++++++++++++++++++ include/cuco/aow_storage.cuh | 23 ++- include/cuco/detail/extent/extent.inl | 35 ++-- include/cuco/detail/open_addressing_impl.cuh | 8 +- .../cuco/detail/open_addressing_ref_impl.cuh | 10 +- .../cuco/detail/static_map/static_map_ref.inl | 39 ++++ .../cuco/detail/static_set/static_set_ref.inl | 34 ++++ include/cuco/detail/storage/aow_storage.inl | 8 + include/cuco/detail/storage/storage.cuh | 1 + include/cuco/extent.cuh | 16 +- include/cuco/static_map_ref.cuh | 42 ++++ include/cuco/static_set_ref.cuh | 41 ++++ include/cuco/storage.cuh | 1 + 15 files changed, 403 insertions(+), 55 deletions(-) create mode 100644 examples/static_set/device_subsets_example.cu diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index d78627eee..91e1417aa 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -35,6 +35,7 @@ endfunction(ConfigureExample) ConfigureExample(STATIC_SET_HOST_BULK_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_set/host_bulk_example.cu") ConfigureExample(STATIC_SET_DEVICE_REF_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_set/device_ref_example.cu") +ConfigureExample(STATIC_SET_DEVICE_SUBSETS_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_set/device_subsets_example.cu") ConfigureExample(STATIC_MAP_HOST_BULK_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_map/host_bulk_example.cu") ConfigureExample(STATIC_MAP_DEVICE_SIDE_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_map/device_view_example.cu") ConfigureExample(STATIC_MAP_CUSTOM_TYPE_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_map/custom_type_example.cu") diff --git 
a/examples/static_set/device_ref_example.cu b/examples/static_set/device_ref_example.cu index 136292f6b..52e41cf45 100644 --- a/examples/static_set/device_ref_example.cu +++ b/examples/static_set/device_ref_example.cu @@ -26,6 +26,14 @@ #include #include +/** + * @file device_ref_example.cu + * @brief Demonstrates usage of the static_set device-side APIs. + * + * static_set provides a non-owning reference which can be used to interact with + * the container from within device code. + */ + // insert a set of keys into a hash set using one cooperative group for each task template __global__ void custom_cooperative_insert(SetRef set, InputIterator keys, std::size_t n) @@ -60,14 +68,6 @@ __global__ void custom_contains(SetRef set, InputIterator keys, std::size_t n, O } } -/** - * @file device_reference_example.cu - * @brief Demonstrates usage of the static_set device-side APIs. - * - * static_set provides a non-owning reference which can be used to interact with - * the container from within device code. - * - */ int main(void) { using Key = int; diff --git a/examples/static_set/device_subsets_example.cu b/examples/static_set/device_subsets_example.cu new file mode 100644 index 000000000..827342f95 --- /dev/null +++ b/examples/static_set/device_subsets_example.cu @@ -0,0 +1,183 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +#include +#include +#include + +#include + +#include + +#include +#include +#include +#include + +/** + * @file device_subsets_example.cu + * @brief Demonstrates how to use one bulk set storage to create multiple subsets and perform + * individual operations via device-side ref APIs. + * + * To optimize memory usage, especially when dealing with expensive data allocation and multiple + * hashsets, a practical solution involves employing a single bulk storage for generating subsets. + * This eliminates the need for separate memory allocation and deallocation for each container. This + * can be achieved by using the lightweight non-owning ref type. + * + * @note This example is for demonstration purposes only. It is not intended to show the most + * performant way to do the example algorithm. + */ + +auto constexpr cg_size = 8; ///< A CUDA Cooperative Group of 8 threads to handle each subset +auto constexpr window_size = 1; ///< Number of concurrent slots handled by each thread +auto constexpr N = 10; ///< Number of elements to insert and query + +using key_type = int; ///< Key type +using probing_scheme_type = cuco::experimental::linear_probing< + cg_size, + cuco::default_hash_function>; ///< Type controls CG granularity and probing scheme + ///< (linear probing vs. double hashing) +/// Type of bulk allocation storage +using storage_type = cuco::experimental::aow_storage; +/// Lightweight non-owning storage ref type +using storage_ref_type = typename storage_type::ref_type; +using ref_type = cuco::experimental::static_set_ref, + probing_scheme_type, + storage_ref_type>; ///< Set ref type + +/// Sample data to insert and query +__device__ constexpr std::array data = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19}; +/// Empty slots are represented by reserved "sentinel" values. These values should be selected such +/// that they never occur in your input data. 
+key_type constexpr empty_key_sentinel = -1; + +/** + * @brief Inserts sample data into subsets by using cooperative group + * + * Each Cooperative Group creates its own subset and inserts `N` sample data. + * + * @param set_refs Pointer to the array of subset objects + */ +__global__ void insert(ref_type* set_refs) +{ + namespace cg = cooperative_groups; + + auto const tile = cg::tiled_partition(cg::this_thread_block()); + // Get subset (or CG) index + auto const idx = (blockDim.x * blockIdx.x + threadIdx.x) / cg_size; + + auto raw_set_ref = *(set_refs + idx); + auto insert_set_ref = std::move(raw_set_ref).with(cuco::experimental::insert); + + // Insert `N` elements into the set with CG insert + for (int i = 0; i < N; i++) { + insert_set_ref.insert(tile, data[i]); + } +} + +/** + * @brief All inserted data can be found + * + * Each Cooperative Group reconstructs its own subset ref based on the storage parameters and + * verifies all inserted data can be found. + * + * @param set_refs Pointer to the array of subset objects + */ +__global__ void find(ref_type* set_refs) +{ + namespace cg = cooperative_groups; + + auto const tile = cg::tiled_partition(cg::this_thread_block()); + auto const idx = (blockDim.x * blockIdx.x + threadIdx.x) / cg_size; + + auto raw_set_ref = *(set_refs + idx); + auto find_set_ref = std::move(raw_set_ref).with(cuco::experimental::find); + + // Result denoting if any of the inserted data is not found + __shared__ int result; + if (threadIdx.x == 0) { result = 0; } + __syncthreads(); + + for (int i = 0; i < N; i++) { + // Query the set with inserted data + auto const found = find_set_ref.find(tile, data[i]); + // Record if the inserted data has not been found + atomicOr(&result, *found != data[i]); + } + __syncthreads(); + + if (threadIdx.x == 0) { + // If the result is still 0, all inserted data are found. + if (result == 0) { printf("Success! 
Found all inserted elements.\n"); } + } +} + +int main() +{ + // Number of subsets to be created + auto constexpr num = 16; + // Each subset may have a different requested size + auto constexpr subset_sizes = + std::array{20, 20, 20, 20, 30, 30, 30, 30, 40, 40, 40, 40, 50, 50, 50, 50}; + + auto valid_sizes = std::vector(); + valid_sizes.reserve(num); + + for (size_t i = 0; i < num; ++i) { + valid_sizes.emplace_back( + static_cast(cuco::experimental::make_window_extent(subset_sizes[i]))); + } + + std::vector offsets(num + 1, 0); + + // prefix sum to compute offsets and total number of windows + std::size_t current_sum = 0; + for (std::size_t i = 0; i < valid_sizes.size(); ++i) { + current_sum += valid_sizes[i]; + offsets[i + 1] = current_sum; + } + + // total number of windows is located at the back of the offsets array + auto const total_num_windows = offsets.back(); + + // Create a single bulk storage used by all subsets + auto set_storage = storage_type{total_num_windows}; + // Initializes the storage with the given sentinel + set_storage.initialize(empty_key_sentinel); + + std::vector set_refs; + + // create subsets + for (std::size_t i = 0; i < num; ++i) { + storage_ref_type storage_ref{valid_sizes[i], set_storage.data() + offsets[i]}; + set_refs.emplace_back( + ref_type{cuco::empty_key{empty_key_sentinel}, {}, {}, storage_ref}); + } + + thrust::device_vector d_set_refs(set_refs); + + // Insert sample data + insert<<<1, 128>>>(d_set_refs.data().get()); + // Find all inserted data + find<<<1, 128>>>(d_set_refs.data().get()); + + return 0; +} diff --git a/include/cuco/aow_storage.cuh b/include/cuco/aow_storage.cuh index fdd970cf4..479246fac 100644 --- a/include/cuco/aow_storage.cuh +++ b/include/cuco/aow_storage.cuh @@ -16,10 +16,10 @@ #pragma once -#include - #include +#include #include +#include #include @@ -47,7 +47,10 @@ class aow_storage_ref; * @tparam Extent Type of extent denoting number of windows * @tparam Allocator Type of allocator used for device 
storage (de)allocation */ -template +template , + typename Allocator = cuco::cuda_allocator>> class aow_storage : public detail::aow_storage_base { public: using base_type = detail::aow_storage_base; ///< AoW base class type @@ -78,7 +81,7 @@ class aow_storage : public detail::aow_storage_base { * @param size Number of windows to (de)allocate * @param allocator Allocator used for (de)allocating device storage */ - explicit constexpr aow_storage(Extent size, Allocator const& allocator) noexcept; + explicit constexpr aow_storage(Extent size, Allocator const& allocator = {}) noexcept; aow_storage(aow_storage&&) = default; ///< Move constructor /** @@ -119,7 +122,15 @@ class aow_storage : public detail::aow_storage_base { * @param key Key to which all keys in `slots` are initialized * @param stream Stream used for executing the kernel */ - void initialize(value_type key, cuda_stream_ref stream) noexcept; + void initialize(value_type key, cuda_stream_ref stream = {}) noexcept; + + /** + * @brief Asynchronously initializes each slot in the AoW storage to contain `key`. 
+ * + * @param key Key to which all keys in `slots` are initialized + * @param stream Stream used for executing the kernel + */ + void initialize_async(value_type key, cuda_stream_ref stream = {}) noexcept; private: allocator_type allocator_; ///< Allocator used to (de)allocate windows @@ -134,7 +145,7 @@ class aow_storage : public detail::aow_storage_base { * @tparam WindowSize Number of slots in each window * @tparam Extent Type of extent denoting storage capacity */ -template +template > class aow_storage_ref : public detail::aow_storage_base { public: using base_type = detail::aow_storage_base; ///< AoW base class type diff --git a/include/cuco/detail/extent/extent.inl b/include/cuco/detail/extent/extent.inl index 911bda9b1..a7cd83dcd 100644 --- a/include/cuco/detail/extent/extent.inl +++ b/include/cuco/detail/extent/extent.inl @@ -27,13 +27,10 @@ namespace cuco { namespace experimental { -template +template struct window_extent { using value_type = SizeType; ///< Extent value type - static auto constexpr cg_size = CGSize; - static auto constexpr window_size = WindowSize; - __host__ __device__ constexpr value_type value() const noexcept { return N; } __host__ __device__ explicit constexpr operator value_type() const noexcept { return value(); } @@ -45,15 +42,11 @@ struct window_extent { friend auto constexpr make_window_extent(extent ext); }; -template -struct window_extent - : cuco::utility::fast_int { +template +struct window_extent : cuco::utility::fast_int { using value_type = typename cuco::utility::fast_int::fast_int::value_type; ///< Extent value type - static auto constexpr cg_size = CGSize; - static auto constexpr window_size = WindowSize; - private: using cuco::utility::fast_int::fast_int; @@ -67,10 +60,10 @@ template return make_window_extent(ext); } -template -[[nodiscard]] std::size_t constexpr make_window_extent(std::size_t size) +template +[[nodiscard]] auto constexpr make_window_extent(SizeType size) { - return make_window_extent(size); + return 
make_window_extent(extent{size}); } template @@ -86,15 +79,13 @@ template if (size > max_value) { CUCO_FAIL("Invalid input extent"); } if constexpr (N == dynamic_extent) { - return window_extent{static_cast( + return window_extent{static_cast( *cuco::detail::lower_bound( cuco::detail::primes.begin(), cuco::detail::primes.end(), static_cast(size)) * CGSize)}; } if constexpr (N != dynamic_extent) { - return window_extent( *cuco::detail::lower_bound(cuco::detail::primes.begin(), cuco::detail::primes.end(), @@ -103,10 +94,10 @@ template } } -template -[[nodiscard]] std::size_t constexpr make_window_extent(std::size_t size) +template +[[nodiscard]] auto constexpr make_window_extent(SizeType size) { - return static_cast(make_window_extent(extent{size})); + return make_window_extent(extent{size}); } namespace detail { @@ -115,8 +106,8 @@ template struct is_window_extent : std::false_type { }; -template -struct is_window_extent> : std::true_type { +template +struct is_window_extent> : std::true_type { }; template diff --git a/include/cuco/detail/open_addressing_impl.cuh b/include/cuco/detail/open_addressing_impl.cuh index ef4821b40..2bc3a7225 100644 --- a/include/cuco/detail/open_addressing_impl.cuh +++ b/include/cuco/detail/open_addressing_impl.cuh @@ -141,11 +141,7 @@ class open_addressing_impl { * * @param stream CUDA stream this operation is executed in */ - void clear(cuda_stream_ref stream) noexcept - { - this->clear_async(stream); - stream.synchronize(); - } + void clear(cuda_stream_ref stream) noexcept { storage_.initialize(empty_slot_sentinel_, stream); } /** * @brief Asynchronously erases all elements from the container. 
After this call, `size()` returns @@ -155,7 +151,7 @@ class open_addressing_impl { */ void clear_async(cuda_stream_ref stream) noexcept { - storage_.initialize(empty_slot_sentinel_, stream); + storage_.initialize_async(empty_slot_sentinel_, stream); } /** diff --git a/include/cuco/detail/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing_ref_impl.cuh index 213d35af1..cce691c21 100644 --- a/include/cuco/detail/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing_ref_impl.cuh @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -87,12 +88,9 @@ class open_addressing_ref_impl { ProbingScheme>, "ProbingScheme must inherit from cuco::detail::probing_scheme_base"); - static_assert(is_window_extent_v, - "Extent is not a valid cuco::window_extent"); - static_assert(ProbingScheme::cg_size == StorageRef::extent_type::cg_size, - "Extent has incompatible CG size"); - static_assert(StorageRef::window_size == StorageRef::extent_type::window_size, - "Extent has incompatible window size"); + // TODO: how to re-enable this check? + // static_assert(is_window_extent_v, + // "Extent is not a valid cuco::window_extent"); public: using key_type = Key; ///< Key type diff --git a/include/cuco/detail/static_map/static_map_ref.inl b/include/cuco/detail/static_map/static_map_ref.inl index 250c84feb..28b3ffaf2 100644 --- a/include/cuco/detail/static_map/static_map_ref.inl +++ b/include/cuco/detail/static_map/static_map_ref.inl @@ -50,6 +50,30 @@ __host__ __device__ constexpr static_map_ref< { } +template +template +__host__ __device__ constexpr static_map_ref:: + static_map_ref( + static_map_ref&& + other) noexcept + : impl_{std::move(other.impl_)}, + predicate_{std::move(other.predicate_)}, + empty_value_sentinel_{std::move(other.empty_value_sentinel_)} +{ +} + template return empty_value_sentinel_; } +template +template +auto static_map_ref::with( + NewOperators...) 
&& noexcept +{ + return static_map_ref( + std::move(*this)); +} + template +template +__host__ __device__ constexpr static_set_ref:: + static_set_ref( + static_set_ref&& + other) noexcept + : impl_{std::move(other.impl_)}, predicate_{std::move(other.predicate_)} +{ +} + template ::e return predicate_.empty_sentinel_; } +template +template +auto static_set_ref::with( + NewOperators...) && noexcept +{ + return static_set_ref( + std::move(*this)); +} + namespace detail { template ::ref() const noexcept template void aow_storage::initialize(value_type key, cuda_stream_ref stream) noexcept +{ + this->initialize_async(key, stream); + stream.synchronize(); +} + +template +void aow_storage::initialize_async( + value_type key, cuda_stream_ref stream) noexcept { auto constexpr cg_size = 1; auto constexpr stride = 4; diff --git a/include/cuco/detail/storage/storage.cuh b/include/cuco/detail/storage/storage.cuh index b9a00baa2..4dda179c9 100644 --- a/include/cuco/detail/storage/storage.cuh +++ b/include/cuco/detail/storage/storage.cuh @@ -45,6 +45,7 @@ class storage : StorageImpl::template impl { using impl_type::capacity; using impl_type::data; using impl_type::initialize; + using impl_type::initialize_async; using impl_type::num_windows; using impl_type::ref; diff --git a/include/cuco/extent.cuh b/include/cuco/extent.cuh index e45068d9e..50e7ae4aa 100644 --- a/include/cuco/extent.cuh +++ b/include/cuco/extent.cuh @@ -83,7 +83,7 @@ struct extent { * @tparam N Extent * */ -template +template struct window_extent; /** @@ -118,15 +118,16 @@ template * the capacity ctor argument for the container. 
* * @tparam Container Container type to compute the extent for + * @tparam SizeType Size type * * @param size The input size * * @throw If the input size is invalid * - * @return Resulting valid extent as `std::size_t` + * @return Resulting valid extent */ -template -[[nodiscard]] std::size_t constexpr make_window_extent(std::size_t size); +template +[[nodiscard]] auto constexpr make_window_extent(SizeType size); /** * @brief Computes valid window extent based on given parameters. @@ -162,15 +163,16 @@ template * * @tparam CGSize Number of elements handled per CG * @tparam WindowSize Number of elements handled per Window + * @tparam SizeType Size type * * @param size The input size * * @throw If the input size is invalid * - * @return Resulting valid extent as `std::size_t` + * @return Resulting valid extent */ -template -[[nodiscard]] std::size_t constexpr make_window_extent(std::size_t size); +template +[[nodiscard]] auto constexpr make_window_extent(SizeType size); } // namespace experimental } // namespace cuco diff --git a/include/cuco/static_map_ref.cuh b/include/cuco/static_map_ref.cuh index 2460f1f10..c41ed88f3 100644 --- a/include/cuco/static_map_ref.cuh +++ b/include/cuco/static_map_ref.cuh @@ -17,8 +17,11 @@ #pragma once #include +#include #include +#include #include +#include #include @@ -106,6 +109,18 @@ class static_map_ref probing_scheme_type const& probing_scheme, storage_ref_type storage_ref) noexcept; + /** + * @brief Operator-agnostic move constructor. + * + * @tparam OtherOperators Operator set of the `other` object + * + * @param other Object to construct `*this` from + */ + template + __host__ __device__ explicit constexpr static_map_ref( + static_map_ref&& + other) noexcept; + /** * @brief Gets the maximum number of elements the container can hold. 
* @@ -127,6 +142,23 @@ class static_map_ref */ [[nodiscard]] __host__ __device__ constexpr mapped_type empty_value_sentinel() const noexcept; + /** + * @brief Creates a reference with new operators from the current object. + * + * Note that this function uses move semantics and thus invalidates the current object. + * + * @warning Using two or more reference objects to the same container but with + * a different operator set at the same time results in undefined behavior. + * + * @tparam NewOperators List of `cuco::op::*_tag` types + * + * @param ops List of operators, e.g., `cuco::insert` + * + * @return `*this` with `NewOperators...` + */ + template + [[nodiscard]] __host__ __device__ auto with(NewOperators... ops) && noexcept; + private: struct predicate_wrapper; @@ -137,6 +169,16 @@ class static_map_ref // Mixins need to be friends with this class in order to access private members template friend class detail::operator_impl; + + // Refs with other operator sets need to be friends too + template + friend class static_map_ref; }; } // namespace experimental diff --git a/include/cuco/static_set_ref.cuh b/include/cuco/static_set_ref.cuh index cf9c00ee0..b2c8158e7 100644 --- a/include/cuco/static_set_ref.cuh +++ b/include/cuco/static_set_ref.cuh @@ -18,8 +18,11 @@ #include #include +#include #include +#include #include +#include #include @@ -94,6 +97,18 @@ class static_set_ref probing_scheme_type const& probing_scheme, storage_ref_type storage_ref) noexcept; + /** + * @brief Operator-agnostic move constructor. + * + * @tparam OtherOperators Operator set of the `other` object + * + * @param other Object to construct `*this` from + */ + template + __host__ __device__ explicit constexpr static_set_ref( + static_set_ref&& + other) noexcept; + /** * @brief Gets the maximum number of elements the container can hold. 
* @@ -108,6 +123,23 @@ class static_set_ref */ [[nodiscard]] __host__ __device__ constexpr key_type empty_key_sentinel() const noexcept; + /** + * @brief Creates a reference with new operators from the current object. + * + * Note that this function uses move semantics and thus invalidates the current object. + * + * @warning Using two or more reference objects to the same container but with + * a different operator set at the same time results in undefined behavior. + * + * @tparam NewOperators List of `cuco::op::*_tag` types + * + * @param ops List of operators, e.g., `cuco::insert` + * + * @return `*this` with `NewOperators...` + */ + template + [[nodiscard]] __host__ __device__ auto with(NewOperators... ops) && noexcept; + private: impl_type impl_; detail::equal_wrapper predicate_; ///< Key equality binary callable @@ -115,6 +147,15 @@ class static_set_ref // Mixins need to be friends with this class in order to access private members template friend class detail::operator_impl; + + // Refs with other operator sets need to be friends too + template + friend class static_set_ref; }; } // namespace experimental diff --git a/include/cuco/storage.cuh b/include/cuco/storage.cuh index e34e59c96..e2e0c6f46 100644 --- a/include/cuco/storage.cuh +++ b/include/cuco/storage.cuh @@ -20,6 +20,7 @@ namespace cuco { namespace experimental { + /** * @brief Public storage class. * From ee9c48abcdc7188df4833f9b391f6e84d798000d Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Fri, 29 Sep 2023 12:33:09 -0700 Subject: [PATCH 3/5] Add constructor overloads taking load factor as input (#369) This PR adds constructor overloads that take a size and load factor for the new map and set. 
--- include/cuco/detail/open_addressing_impl.cuh | 59 ++++++- include/cuco/detail/static_map/static_map.inl | 29 ++++ include/cuco/detail/static_set/static_set.inl | 26 +++ include/cuco/static_map.cuh | 46 ++++- include/cuco/static_set.cuh | 44 ++++- tests/CMakeLists.txt | 1 + tests/static_map/capacity_test.cu | 162 ++++++++++++++++++ tests/static_set/capacity_test.cu | 30 ++++ 8 files changed, 387 insertions(+), 10 deletions(-) create mode 100644 tests/static_map/capacity_test.cu diff --git a/include/cuco/detail/open_addressing_impl.cuh b/include/cuco/detail/open_addressing_impl.cuh index 2bc3a7225..556f821d4 100644 --- a/include/cuco/detail/open_addressing_impl.cuh +++ b/include/cuco/detail/open_addressing_impl.cuh @@ -34,6 +34,8 @@ #include +#include + namespace cuco { namespace experimental { namespace detail { @@ -120,8 +122,8 @@ class open_addressing_impl { * @param stream CUDA stream used to initialize the data structure */ constexpr open_addressing_impl(Extent capacity, - key_type empty_key_sentinel, - value_type empty_slot_sentinel, + Key empty_key_sentinel, + Value empty_slot_sentinel, KeyEqual const& pred, ProbingScheme const& probing_scheme, Allocator const& alloc, @@ -135,6 +137,59 @@ class open_addressing_impl { this->clear_async(stream); } + /** + * @brief Constructs a statically-sized open addressing data structure with the number of elements + * to insert `n`, the desired load factor, etc. + * + * @note This constructor helps users create a data structure based on the number of elements to + * insert and the desired load factor without manually computing the desired capacity. The actual + * capacity will be a size no smaller than `ceil(n / desired_load_factor)`. It's determined by + * multiple factors including the given `n`, the desired load factor, the probing scheme, the CG + * size, and the window size and is computed via the `make_window_extent` factory. + * @note Insert operations will not automatically grow the container. 
+ * @note Attempting to insert more unique keys than the capacity of the container results in + * undefined behavior. + * @note Any `*_sentinel`s are reserved and behavior is undefined when attempting to insert + * this sentinel value. + * @note This constructor doesn't synchronize the given stream. + * @note This overload will convert compile-time extents to runtime constants which might lead to + * performance regressions. + * + * @throw If the desired occupancy is no bigger than zero + * @throw If the desired occupancy is no smaller than one + * + * @param n The number of elements to insert + * @param desired_load_factor The desired load factor of the container, e.g., 0.5 implies a 50% + * load factor + * @param empty_key_sentinel The reserved key value for empty slots + * @param empty_slot_sentinel The reserved slot value for empty slots + * @param pred Key equality binary predicate + * @param probing_scheme Probing scheme + * @param alloc Allocator used for allocating device storage + * @param stream CUDA stream used to initialize the data structure + */ + constexpr open_addressing_impl(Extent n, + double desired_load_factor, + Key empty_key_sentinel, + Value empty_slot_sentinel, + KeyEqual const& pred, + ProbingScheme const& probing_scheme, + Allocator const& alloc, + cuda_stream_ref stream) + : empty_key_sentinel_{empty_key_sentinel}, + empty_slot_sentinel_{empty_slot_sentinel}, + predicate_{pred}, + probing_scheme_{probing_scheme}, + storage_{make_window_extent( + static_cast(std::ceil(static_cast(n) / desired_load_factor))), + alloc} + { + CUCO_EXPECTS(desired_load_factor > 0., "Desired occupancy must be larger than zero"); + CUCO_EXPECTS(desired_load_factor < 1., "Desired occupancy must be smaller than one"); + + this->clear_async(stream); + } + /** * @brief Erases all elements from the container. After this call, `size()` returns zero. * Invalidates any references, pointers, or iterators referring to contained elements. 
diff --git a/include/cuco/detail/static_map/static_map.inl b/include/cuco/detail/static_map/static_map.inl index d7274245e..1cc932aeb 100644 --- a/include/cuco/detail/static_map/static_map.inl +++ b/include/cuco/detail/static_map/static_map.inl @@ -54,6 +54,35 @@ constexpr static_map +constexpr static_map:: + static_map(Extent n, + double desired_load_factor, + empty_key empty_key_sentinel, + empty_value empty_value_sentinel, + KeyEqual const& pred, + ProbingScheme const& probing_scheme, + Allocator const& alloc, + cuda_stream_ref stream) + : impl_{std::make_unique(n, + desired_load_factor, + empty_key_sentinel, + cuco::pair{empty_key_sentinel, empty_value_sentinel}, + pred, + probing_scheme, + alloc, + stream)}, + empty_value_sentinel_{empty_value_sentinel} +{ +} + template +constexpr static_set::static_set( + Extent n, + double desired_load_factor, + empty_key empty_key_sentinel, + KeyEqual const& pred, + ProbingScheme const& probing_scheme, + Allocator const& alloc, + cuda_stream_ref stream) + : impl_{std::make_unique(n, + desired_load_factor, + empty_key_sentinel, + empty_key_sentinel, + pred, + probing_scheme, + alloc, + stream)} +{ +} + template , @@ -156,7 +155,7 @@ class static_map { /** * @brief Constructs a statically-sized map with the specified initial capacity, sentinel values - * and CUDA stream. + * and CUDA stream * * The actual map capacity depends on the given `capacity`, the probing scheme, CG size, and the * window size and it is computed via the `make_window_extent` factory. Insert operations will not @@ -165,8 +164,7 @@ class static_map { * * @note Any `*_sentinel`s are reserved and behavior is undefined when attempting to insert * this sentinel value. - * @note If a non-default CUDA stream is provided, the caller is responsible for synchronizing the - * stream before the object is first used. + * @note This constructor doesn't synchronize the given stream. 
* * @param capacity The requested lower-bound map size * @param empty_key_sentinel The reserved key value for empty slots @@ -184,6 +182,46 @@ class static_map { Allocator const& alloc = {}, cuda_stream_ref stream = {}); + /** + * @brief Constructs a statically-sized map with the number of elements to insert `n`, the desired + * load factor, etc + * + * @note This constructor helps users create a map based on the number of elements to insert and + * the desired load factor without manually computing the desired capacity. The actual map + * capacity will be a size no smaller than `ceil(n / desired_load_factor)`. It's determined by + * multiple factors including the given `n`, the desired load factor, the probing scheme, the CG + * size, and the window size and is computed via the `make_window_extent` factory. + * @note Insert operations will not automatically grow the container. + * @note Attempting to insert more unique keys than the capacity of the container results in + * undefined behavior. + * @note Any `*_sentinel`s are reserved and behavior is undefined when attempting to insert + * this sentinel value. + * @note This constructor doesn't synchronize the given stream. + * @note This overload will convert compile-time extents to runtime constants which might lead to + * performance regressions. 
+ * + * @throw If the desired occupancy is no bigger than zero + * @throw If the desired occupancy is no smaller than one + * + * @param n The number of elements to insert + * @param desired_load_factor The desired load factor of the container, e.g., 0.5 implies a 50% + * load factor + * @param empty_key_sentinel The reserved key value for empty slots + * @param empty_value_sentinel The reserved mapped value for empty slots + * @param pred Key equality binary predicate + * @param probing_scheme Probing scheme + * @param alloc Allocator used for allocating device storage + * @param stream CUDA stream used to initialize the map + */ + constexpr static_map(Extent n, + double desired_load_factor, + empty_key empty_key_sentinel, + empty_value empty_value_sentinel, + KeyEqual const& pred = {}, + ProbingScheme const& probing_scheme = {}, + Allocator const& alloc = {}, + cuda_stream_ref stream = {}); + /** * @brief Erases all elements from the container. After this call, `size()` returns zero. * Invalidates any references, pointers, or iterators referring to contained elements. diff --git a/include/cuco/static_set.cuh b/include/cuco/static_set.cuh index 613a99bd4..6d48d5dc8 100644 --- a/include/cuco/static_set.cuh +++ b/include/cuco/static_set.cuh @@ -79,7 +79,6 @@ namespace experimental { * @tparam Allocator Type of allocator used for device storage * @tparam Storage Slot window storage type */ - template , cuda::thread_scope Scope = cuda::thread_scope_device, @@ -131,7 +130,7 @@ class static_set { /** * @brief Constructs a statically-sized set with the specified initial capacity, sentinel values - * and CUDA stream. + * and CUDA stream * * The actual set capacity depends on the given `capacity`, the probing scheme, CG size, and the * window size and it is computed via the `make_window_extent` factory. 
Insert operations will not @@ -140,8 +139,7 @@ class static_set { * * @note Any `*_sentinel`s are reserved and behavior is undefined when attempting to insert * this sentinel value. - * @note If a non-default CUDA stream is provided, the caller is responsible for synchronizing the - * stream before the object is first used. + * @note This constructor doesn't synchronize the given stream. * * @param capacity The requested lower-bound set size * @param empty_key_sentinel The reserved key value for empty slots @@ -157,6 +155,44 @@ class static_set { Allocator const& alloc = {}, cuda_stream_ref stream = {}); + /** + * @brief Constructs a statically-sized set with the number of elements to insert `n`, the desired + * load factor, etc + * + * @note This constructor helps users create a set based on the number of elements to insert and + * the desired load factor without manually computing the desired capacity. The actual set + * capacity will be a size no smaller than `ceil(n / desired_load_factor)`. It's determined by + * multiple factors including the given `n`, the desired load factor, the probing scheme, the CG + * size, and the window size and is computed via the `make_window_extent` factory. + * @note Insert operations will not automatically grow the container. + * @note Attempting to insert more unique keys than the capacity of the container results in + * undefined behavior. + * @note Any `*_sentinel`s are reserved and behavior is undefined when attempting to insert + * this sentinel value. + * @note This constructor doesn't synchronize the given stream. + * @note This overload will convert compile-time extents to runtime constants which might lead to + * performance regressions. 
+ * + * @throw If the desired occupancy is no bigger than zero + * @throw If the desired occupancy is no smaller than one + * + * @param n The number of elements to insert + * @param desired_load_factor The desired load factor of the container, e.g., 0.5 implies a 50% + * load factor + * @param empty_key_sentinel The reserved key value for empty slots + * @param pred Key equality binary predicate + * @param probing_scheme Probing scheme + * @param alloc Allocator used for allocating device storage + * @param stream CUDA stream used to initialize the set + */ + constexpr static_set(Extent n, + double desired_load_factor, + empty_key empty_key_sentinel, + KeyEqual const& pred = {}, + ProbingScheme const& probing_scheme = {}, + Allocator const& alloc = {}, + cuda_stream_ref stream = {}); + /** * @brief Erases all elements from the container. After this call, `size()` returns zero. * Invalidates any references, pointers, or iterators referring to contained elements. diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 3deeeddf1..775b5b82f 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -69,6 +69,7 @@ ConfigureTest(STATIC_SET_TEST ################################################################################################### # - static_map tests ------------------------------------------------------------------------------ ConfigureTest(STATIC_MAP_TEST + static_map/capacity_test.cu static_map/custom_type_test.cu static_map/duplicate_keys_test.cu static_map/erase_test.cu diff --git a/tests/static_map/capacity_test.cu b/tests/static_map/capacity_test.cu new file mode 100644 index 000000000..13774fe8a --- /dev/null +++ b/tests/static_map/capacity_test.cu @@ -0,0 +1,162 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +TEST_CASE("Static map capacity", "") +{ + using Key = int32_t; + using T = int32_t; + using ProbeT = cuco::experimental::double_hashing<1, cuco::default_hash_function>; + using Equal = thrust::equal_to; + using AllocatorT = cuco::cuda_allocator; + using StorageT = cuco::experimental::storage<2>; + + SECTION("zero capacity is allowed.") + { + auto constexpr gold_capacity = 4; + + using extent_type = cuco::experimental::extent; + cuco::experimental::static_map + map{extent_type{}, cuco::empty_key{-1}, cuco::empty_value{-1}}; + auto const capacity = map.capacity(); + REQUIRE(capacity == gold_capacity); + + auto ref = map.ref(cuco::experimental::insert); + auto const ref_capacity = ref.capacity(); + REQUIRE(ref_capacity == gold_capacity); + } + + SECTION("negative capacity (ikr -_-||) is also allowed.") + { + auto constexpr gold_capacity = 4; + + using extent_type = cuco::experimental::extent; + cuco::experimental::static_map + map{extent_type{-10}, cuco::empty_key{-1}, cuco::empty_value{-1}}; + auto const capacity = map.capacity(); + REQUIRE(capacity == gold_capacity); + + auto ref = map.ref(cuco::experimental::insert); + auto const ref_capacity = ref.capacity(); + REQUIRE(ref_capacity == gold_capacity); + } + + constexpr std::size_t num_keys{400}; + + SECTION("Dynamic extent is evaluated at run time.") + { + auto constexpr gold_capacity = 422; // 211 x 2 + + using extent_type = cuco::experimental::extent; + cuco::experimental::static_map + map{num_keys, cuco::empty_key{-1}, cuco::empty_value{-1}}; + auto const capacity 
= map.capacity(); + REQUIRE(capacity == gold_capacity); + + auto ref = map.ref(cuco::experimental::insert); + auto const ref_capacity = ref.capacity(); + REQUIRE(ref_capacity == gold_capacity); + } + + SECTION("map can be constructed from plain integer.") + { + auto constexpr gold_capacity = 422; // 211 x 2 + + cuco::experimental::static_map + map{num_keys, cuco::empty_key{-1}, cuco::empty_value{-1}}; + auto const capacity = map.capacity(); + REQUIRE(capacity == gold_capacity); + + auto ref = map.ref(cuco::experimental::insert); + auto const ref_capacity = ref.capacity(); + REQUIRE(ref_capacity == gold_capacity); + } + + SECTION("map can be constructed from plain integer and load factor.") + { + auto constexpr gold_capacity = 502; // 251 x 2 + + cuco::experimental::static_map + map{num_keys, 0.8, cuco::empty_key{-1}, cuco::empty_value{-1}}; + auto const capacity = map.capacity(); + REQUIRE(capacity == gold_capacity); + + auto ref = map.ref(cuco::experimental::insert); + auto const ref_capacity = ref.capacity(); + REQUIRE(ref_capacity == gold_capacity); + } + + SECTION("Dynamic extent is evaluated at run time.") + { + auto constexpr gold_capacity = 412; // 103 x 2 x 2 + + using probe = cuco::experimental::linear_probing<2, cuco::default_hash_function>; + auto map = cuco::experimental::static_map, + cuda::thread_scope_device, + Equal, + probe, + AllocatorT, + StorageT>{ + num_keys, cuco::empty_key{-1}, cuco::empty_value{-1}}; + + auto const capacity = map.capacity(); + REQUIRE(capacity == gold_capacity); + + auto ref = map.ref(cuco::experimental::insert); + auto const ref_capacity = ref.capacity(); + REQUIRE(ref_capacity == gold_capacity); + } +} diff --git a/tests/static_set/capacity_test.cu b/tests/static_set/capacity_test.cu index 4c66a7ccc..f042cdb73 100644 --- a/tests/static_set/capacity_test.cu +++ b/tests/static_set/capacity_test.cu @@ -76,6 +76,36 @@ TEST_CASE("Static set capacity", "") REQUIRE(ref_capacity == gold_capacity); } + SECTION("Set can be 
constructed from plain integer.") + { + auto constexpr gold_capacity = 422; // 211 x 2 + + cuco::experimental:: + static_set + set{num_keys, cuco::empty_key{-1}}; + auto const capacity = set.capacity(); + REQUIRE(capacity == gold_capacity); + + auto ref = set.ref(cuco::experimental::insert); + auto const ref_capacity = ref.capacity(); + REQUIRE(ref_capacity == gold_capacity); + } + + SECTION("Set can be constructed from plain integer and load factor.") + { + auto constexpr gold_capacity = 502; // 251 x 2 + + cuco::experimental:: + static_set + set{num_keys, 0.8, cuco::empty_key{-1}}; + auto const capacity = set.capacity(); + REQUIRE(capacity == gold_capacity); + + auto ref = set.ref(cuco::experimental::insert); + auto const ref_capacity = ref.capacity(); + REQUIRE(ref_capacity == gold_capacity); + } + SECTION("Dynamic extent is evaluated at run time.") { auto constexpr gold_capacity = 412; // 103 x 2 x 2 From fd23a3dcb2c2e53c228e629014b1c625e9e6515b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20J=C3=BCnger?= <2955913+sleeepyjack@users.noreply.github.com> Date: Mon, 2 Oct 2023 18:05:58 +0200 Subject: [PATCH 4/5] Enable heterogeneous insert for static_set (#375) --- include/cuco/detail/common_kernels.cuh | 22 +++---- include/cuco/detail/equal_wrapper.cuh | 14 +++-- .../cuco/detail/open_addressing_ref_impl.cuh | 60 +++++++++++-------- include/cuco/detail/static_set/kernels.cuh | 4 +- .../cuco/detail/static_set/static_set_ref.inl | 20 +++++-- tests/static_set/heterogeneous_lookup_test.cu | 21 ++++--- 6 files changed, 87 insertions(+), 54 deletions(-) diff --git a/include/cuco/detail/common_kernels.cuh b/include/cuco/detail/common_kernels.cuh index 759041bad..223f20609 100644 --- a/include/cuco/detail/common_kernels.cuh +++ b/include/cuco/detail/common_kernels.cuh @@ -23,6 +23,8 @@ #include +#include + namespace cuco { namespace experimental { namespace detail { @@ -37,7 +39,7 @@ namespace detail { * * @tparam CGSize Number of threads in each CG * @tparam BlockSize 
Number of threads in each block - * @tparam InputIterator Device accessible input iterator whose `value_type` is + * @tparam InputIt Device accessible input iterator whose `value_type` is * convertible to the `value_type` of the data structure * @tparam StencilIt Device accessible random access iterator whose value_type is * convertible to Predicate's argument type @@ -55,12 +57,12 @@ namespace detail { */ template -__global__ void insert_if_n(InputIterator first, +__global__ void insert_if_n(InputIt first, cuco::detail::index_type n, StencilIt stencil, Predicate pred, @@ -76,7 +78,7 @@ __global__ void insert_if_n(InputIterator first, while (idx < n) { if (pred(*(stencil + idx))) { - typename Ref::value_type const insert_element{*(first + idx)}; + typename std::iterator_traits::value_type const& insert_element{*(first + idx)}; if constexpr (CGSize == 1) { if (ref.insert(insert_element)) { thread_num_successes++; }; } else { @@ -106,7 +108,7 @@ __global__ void insert_if_n(InputIterator first, * * @tparam CGSize Number of threads in each CG * @tparam BlockSize Number of threads in each block - * @tparam InputIterator Device accessible input iterator whose `value_type` is + * @tparam InputIt Device accessible input iterator whose `value_type` is * convertible to the `value_type` of the data structure * @tparam StencilIt Device accessible random access iterator whose value_type is * convertible to Predicate's argument type @@ -122,19 +124,19 @@ __global__ void insert_if_n(InputIterator first, */ template __global__ void insert_if_n( - InputIterator first, cuco::detail::index_type n, StencilIt stencil, Predicate pred, Ref ref) + InputIt first, cuco::detail::index_type n, StencilIt stencil, Predicate pred, Ref ref) { auto const loop_stride = cuco::detail::grid_stride() / CGSize; auto idx = cuco::detail::global_thread_id() / CGSize; while (idx < n) { if (pred(*(stencil + idx))) { - typename Ref::value_type const insert_element{*(first + idx)}; + typename 
std::iterator_traits::value_type const& insert_element{*(first + idx)}; if constexpr (CGSize == 1) { ref.insert(insert_element); } else { @@ -198,7 +200,7 @@ __global__ void contains_if_n(InputIt first, while (idx - thread_idx < n) { // the whole thread block falls into the same iteration if constexpr (CGSize == 1) { if (idx < n) { - auto const key = *(first + idx); + typename std::iterator_traits::value_type const& key = *(first + idx); /* * The ld.relaxed.gpu instruction causes L1 to flush more frequently, causing increased * sector stores from L2 to global memory. By writing results to shared memory and then @@ -212,7 +214,7 @@ __global__ void contains_if_n(InputIt first, } else { auto const tile = cg::tiled_partition(cg::this_thread_block()); if (idx < n) { - auto const key = *(first + idx); + typename std::iterator_traits::value_type const& key = *(first + idx); auto const found = pred(*(stencil + idx)) ? ref.contains(tile, key) : false; if (tile.thread_rank() == 0) { *(output_begin + idx) = found; } } diff --git a/include/cuco/detail/equal_wrapper.cuh b/include/cuco/detail/equal_wrapper.cuh index d2ded4a33..0c05a4a9c 100644 --- a/include/cuco/detail/equal_wrapper.cuh +++ b/include/cuco/detail/equal_wrapper.cuh @@ -55,15 +55,16 @@ struct equal_wrapper { /** * @brief Equality check with the given equality callable. * - * @tparam U Right-hand side Element type + * @tparam LHS Left-hand side Element type + * @tparam RHS Right-hand side Element type * * @param lhs Left-hand side element to check equality * @param rhs Right-hand side element to check equality * * @return `EQUAL` if `lhs` and `rhs` are equivalent. `UNEQUAL` otherwise. */ - template - __device__ constexpr equal_result equal_to(T const& lhs, U const& rhs) const noexcept + template + __device__ constexpr equal_result equal_to(LHS const& lhs, RHS const& rhs) const noexcept { return equal_(lhs, rhs) ? 
equal_result::EQUAL : equal_result::UNEQUAL; } @@ -75,15 +76,16 @@ struct equal_wrapper { * first then perform a equality check with the given `equal_` callable, i.e., `equal_(lhs, rhs)`. * @note Container (like set or map) keys MUST be always on the left-hand side. * - * @tparam U Right-hand side Element type + * @tparam LHS Left-hand side Element type + * @tparam RHS Right-hand side Element type * * @param lhs Left-hand side element to check equality * @param rhs Right-hand side element to check equality * * @return Three way equality comparison result */ - template - __device__ constexpr equal_result operator()(T const& lhs, U const& rhs) const noexcept + template + __device__ constexpr equal_result operator()(LHS const& lhs, RHS const& rhs) const noexcept { return cuco::detail::bitwise_compare(lhs, empty_sentinel_) ? equal_result::EMPTY : this->equal_to(lhs, rhs); diff --git a/include/cuco/detail/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing_ref_impl.cuh index cce691c21..3967cffa3 100644 --- a/include/cuco/detail/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing_ref_impl.cuh @@ -155,6 +155,7 @@ class open_addressing_ref_impl { * @brief Inserts an element. * * @tparam HasPayload Boolean indicating it's a set or map implementation + * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * * @param value The element to insert @@ -162,8 +163,8 @@ class open_addressing_ref_impl { * * @return True if the given element is successfully inserted */ - template - __device__ bool insert(value_type const& value, Predicate const& predicate) noexcept + template + __device__ bool insert(Value const& value, Predicate const& predicate) noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); @@ -202,6 +203,7 @@ class open_addressing_ref_impl { * @brief Inserts an element. 
* * @tparam HasPayload Boolean indicating it's a set or map implementation + * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * * @param group The Cooperative Group used to perform group insert @@ -210,9 +212,9 @@ class open_addressing_ref_impl { * * @return True if the given element is successfully inserted */ - template + template __device__ bool insert(cooperative_groups::thread_block_tile const& group, - value_type const& value, + Value const& value, Predicate const& predicate) noexcept { auto const key = [&]() { @@ -275,6 +277,7 @@ class open_addressing_ref_impl { * not. * * @tparam HasPayload Boolean indicating it's a set or map implementation + * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * * @param value The element to insert @@ -283,8 +286,8 @@ class open_addressing_ref_impl { * @return a pair consisting of an iterator to the element and a bool indicating whether the * insertion is successful or not. */ - template - __device__ thrust::pair insert_and_find(value_type const& value, + template + __device__ thrust::pair insert_and_find(Value const& value, Predicate const& predicate) noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); @@ -337,6 +340,7 @@ class open_addressing_ref_impl { * not. * * @tparam HasPayload Boolean indicating it's a set or map implementation + * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * * @param group The Cooperative Group used to perform group insert_and_find @@ -346,10 +350,10 @@ class open_addressing_ref_impl { * @return a pair consisting of an iterator to the element and a bool indicating whether the * insertion is successful or not. 
*/ - template + template __device__ thrust::pair insert_and_find( cooperative_groups::thread_block_tile const& group, - value_type const& value, + Value const& value, Predicate const& predicate) noexcept { auto const key = [&]() { @@ -712,6 +716,7 @@ class open_addressing_ref_impl { * @brief Inserts the specified element with one single CAS operation. * * @tparam HasPayload Boolean indicating it's a set or map implementation + * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * * @param slot Pointer to the slot in memory @@ -720,12 +725,12 @@ class open_addressing_ref_impl { * * @return Result of this operation, i.e., success/continue/duplicate */ - template + template [[nodiscard]] __device__ constexpr insert_result packed_cas(value_type* slot, - value_type const& value, + Value const& value, Predicate const& predicate) noexcept { - auto old = compare_and_swap(slot, this->empty_slot_sentinel_, value); + auto old = compare_and_swap(slot, this->empty_slot_sentinel_, static_cast(value)); auto* old_ptr = reinterpret_cast(&old); auto const inserted = [&]() { if constexpr (HasPayload) { @@ -757,6 +762,7 @@ class open_addressing_ref_impl { /** * @brief Inserts the specified element with two back-to-back CAS operations. 
* + * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * * @param slot Pointer to the slot in memory @@ -765,17 +771,18 @@ class open_addressing_ref_impl { * * @return Result of this operation, i.e., success/continue/duplicate */ - template + template [[nodiscard]] __device__ constexpr insert_result back_to_back_cas( - value_type* slot, value_type const& value, Predicate const& predicate) noexcept + value_type* slot, Value const& value, Predicate const& predicate) noexcept { + using mapped_type = decltype(this->empty_slot_sentinel_.second); + auto const expected_key = this->empty_slot_sentinel_.first; auto const expected_payload = this->empty_slot_sentinel_.second; - auto old_key = compare_and_swap(&slot->first, expected_key, value.first); - auto old_payload = compare_and_swap(&slot->second, expected_payload, value.second); - - using mapped_type = decltype(expected_payload); + auto old_key = compare_and_swap(&slot->first, expected_key, static_cast(value.first)); + auto old_payload = + compare_and_swap(&slot->second, expected_payload, static_cast(value.second)); auto* old_key_ptr = reinterpret_cast(&old_key); auto* old_payload_ptr = reinterpret_cast(&old_payload); @@ -783,7 +790,8 @@ class open_addressing_ref_impl { // if key success if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) { while (not cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) { - old_payload = compare_and_swap(&slot->second, expected_payload, value.second); + old_payload = + compare_and_swap(&slot->second, expected_payload, static_cast(value.second)); } return insert_result::SUCCESS; } else if (cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) { @@ -802,6 +810,7 @@ class open_addressing_ref_impl { /** * @brief Inserts the specified element with CAS-dependent write operations. 
* + * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * * @param slot Pointer to the slot in memory @@ -810,19 +819,21 @@ class open_addressing_ref_impl { * * @return Result of this operation, i.e., success/continue/duplicate */ - template + template [[nodiscard]] __device__ constexpr insert_result cas_dependent_write( - value_type* slot, value_type const& value, Predicate const& predicate) noexcept + value_type* slot, Value const& value, Predicate const& predicate) noexcept { + using mapped_type = decltype(this->empty_slot_sentinel_.second); + auto const expected_key = this->empty_slot_sentinel_.first; - auto old_key = compare_and_swap(&slot->first, expected_key, value.first); + auto old_key = compare_and_swap(&slot->first, expected_key, static_cast(value.first)); auto* old_key_ptr = reinterpret_cast(&old_key); // if key success if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) { - atomic_store(&slot->second, value.second); + atomic_store(&slot->second, static_cast(value.second)); return insert_result::SUCCESS; } @@ -842,6 +853,7 @@ class open_addressing_ref_impl { * type and presence of other operator mixins. 
* * @tparam HasPayload Boolean indicating it's a set or map implementation + * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * * @param slot Pointer to the slot in memory @@ -850,9 +862,9 @@ class open_addressing_ref_impl { * * @return Result of this operation, i.e., success/continue/duplicate */ - template + template [[nodiscard]] __device__ insert_result attempt_insert(value_type* slot, - value_type const& value, + Value const& value, Predicate const& predicate) noexcept { if constexpr (sizeof(value_type) <= 8) { diff --git a/include/cuco/detail/static_set/kernels.cuh b/include/cuco/detail/static_set/kernels.cuh index 72744f2b4..15d725f68 100644 --- a/include/cuco/detail/static_set/kernels.cuh +++ b/include/cuco/detail/static_set/kernels.cuh @@ -24,6 +24,8 @@ #include +#include + namespace cuco { namespace experimental { namespace static_set_ns { @@ -62,7 +64,7 @@ __global__ void find(InputIt first, cuco::detail::index_type n, OutputIt output_ while (idx - thread_idx < n) { // the whole thread block falls into the same iteration if (idx < n) { - auto const key = *(first + idx); + typename std::iterator_traits::value_type const& key = *(first + idx); if constexpr (CGSize == 1) { auto const found = ref.find(key); /* diff --git a/include/cuco/detail/static_set/static_set_ref.inl b/include/cuco/detail/static_set/static_set_ref.inl index 4c3853971..3dbda9bbf 100644 --- a/include/cuco/detail/static_set/static_set_ref.inl +++ b/include/cuco/detail/static_set/static_set_ref.inl @@ -128,11 +128,14 @@ class operator_impl + __device__ bool insert(Value const& value) noexcept { ref_type& ref_ = static_cast(*this); auto constexpr has_payload = false; @@ -142,13 +145,16 @@ class operator_impl __device__ bool insert(cooperative_groups::thread_block_tile const& group, - value_type const& value) noexcept + Value const& value) noexcept { auto& ref_ = static_cast(*this); auto constexpr has_payload = false; @@ -208,12 
+214,15 @@ class operator_impl insert_and_find(value_type const& value) noexcept + template + __device__ thrust::pair insert_and_find(Value const& value) noexcept { ref_type& ref_ = static_cast(*this); auto constexpr has_payload = false; @@ -227,14 +236,17 @@ class operator_impl __device__ thrust::pair insert_and_find( - cooperative_groups::thread_block_tile const& group, value_type const& value) noexcept + cooperative_groups::thread_block_tile const& group, Value const& value) noexcept { ref_type& ref_ = static_cast(*this); auto constexpr has_payload = false; diff --git a/tests/static_set/heterogeneous_lookup_test.cu b/tests/static_set/heterogeneous_lookup_test.cu index cbc0efac3..ddc799ed3 100644 --- a/tests/static_set/heterogeneous_lookup_test.cu +++ b/tests/static_set/heterogeneous_lookup_test.cu @@ -41,6 +41,8 @@ struct key_pair { // Device equality operator is mandatory due to libcudacxx bug: // https://github.com/NVIDIA/libcudacxx/issues/223 __device__ bool operator==(key_pair const& other) const { return a == other.a and b == other.b; } + + __device__ explicit operator T() const noexcept { return a; } }; // probe key type @@ -66,23 +68,24 @@ struct custom_hasher { template __device__ uint32_t operator()(CustomKey const& k) const { - return thrust::raw_reference_cast(k).a; + return k.a; }; }; // User-defined device key equality struct custom_key_equal { - template - __device__ bool operator()(LHS const& lhs, RHS const& rhs) const + template + __device__ bool operator()(SlotKey const& lhs, InputKey const& rhs) const { - return thrust::raw_reference_cast(lhs).a == thrust::raw_reference_cast(rhs).a; + return lhs == rhs.a; } }; TEMPLATE_TEST_CASE_SIG( "Heterogeneous lookup", "", ((typename T, int CGSize), T, CGSize), (int32_t, 1), (int32_t, 2)) { - using Key = key_pair; + using Key = T; + using InsertKey = key_pair; using ProbeKey = key_triplet; using probe_type = cuco::experimental::double_hashing; @@ -98,15 +101,15 @@ TEMPLATE_TEST_CASE_SIG( probe_type>{ 
capacity, cuco::empty_key{sentinel_key}, custom_key_equal{}, probe}; - auto insert_pairs = thrust::make_transform_iterator(thrust::counting_iterator(0), - [] __device__(auto i) { return Key{i}; }); - auto probe_keys = thrust::make_transform_iterator(thrust::counting_iterator(0), + auto insert_keys = thrust::make_transform_iterator( + thrust::counting_iterator(0), [] __device__(auto i) { return InsertKey(i); }); + auto probe_keys = thrust::make_transform_iterator(thrust::counting_iterator(0), [] __device__(auto i) { return ProbeKey(i); }); SECTION("All inserted keys should be contained") { thrust::device_vector contained(num); - my_set.insert(insert_pairs, insert_pairs + num); + my_set.insert(insert_keys, insert_keys + num); my_set.contains(probe_keys, probe_keys + num, contained.begin()); REQUIRE(cuco::test::all_of(contained.begin(), contained.end(), thrust::identity{})); } From b4657fda872b7032762cdb67c44506d6e2cf796d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20J=C3=BCnger?= <2955913+sleeepyjack@users.noreply.github.com> Date: Wed, 4 Oct 2023 17:18:57 +0200 Subject: [PATCH 5/5] Remove HasPayload tparam from OA impl classes (#377) --- .../open_addressing_impl.cuh | 0 .../open_addressing_ref_impl.cuh | 48 +++++++++---------- .../cuco/detail/static_map/static_map_ref.inl | 20 ++++---- .../cuco/detail/static_set/static_set_ref.inl | 20 ++++---- include/cuco/static_map.cuh | 2 +- include/cuco/static_map_ref.cuh | 2 +- include/cuco/static_set.cuh | 2 +- include/cuco/static_set_ref.cuh | 2 +- 8 files changed, 42 insertions(+), 54 deletions(-) rename include/cuco/detail/{ => open_addressing}/open_addressing_impl.cuh (100%) rename include/cuco/detail/{ => open_addressing}/open_addressing_ref_impl.cuh (95%) diff --git a/include/cuco/detail/open_addressing_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_impl.cuh similarity index 100% rename from include/cuco/detail/open_addressing_impl.cuh rename to 
include/cuco/detail/open_addressing/open_addressing_impl.cuh diff --git a/include/cuco/detail/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh similarity index 95% rename from include/cuco/detail/open_addressing_ref_impl.cuh rename to include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 3967cffa3..2432a81b0 100644 --- a/include/cuco/detail/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -106,6 +106,9 @@ class open_addressing_ref_impl { static constexpr auto cg_size = probing_scheme_type::cg_size; ///< Cooperative group size static constexpr auto window_size = storage_ref_type::window_size; ///< Number of elements handled per window + static constexpr auto has_payload = + not std::is_same_v; ///< Determines if the container is a key/value or + ///< key-only store /** * @brief Constructs open_addressing_ref_impl. @@ -154,7 +157,6 @@ class open_addressing_ref_impl { /** * @brief Inserts an element. 
* - * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * @@ -163,13 +165,13 @@ class open_addressing_ref_impl { * * @return True if the given element is successfully inserted */ - template + template __device__ bool insert(Value const& value, Predicate const& predicate) noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); auto const key = [&]() { - if constexpr (HasPayload) { + if constexpr (this->has_payload) { return value.first; } else { return value; @@ -187,7 +189,7 @@ class open_addressing_ref_impl { if (eq_res == detail::equal_result::EQUAL) { return false; } if (eq_res == detail::equal_result::EMPTY) { auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content); - switch (attempt_insert( + switch (attempt_insert( (storage_ref_.data() + *probing_iter)->data() + intra_window_index, value, predicate)) { case insert_result::CONTINUE: continue; case insert_result::SUCCESS: return true; @@ -202,7 +204,6 @@ class open_addressing_ref_impl { /** * @brief Inserts an element. * - * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * @@ -212,13 +213,13 @@ class open_addressing_ref_impl { * * @return True if the given element is successfully inserted */ - template + template __device__ bool insert(cooperative_groups::thread_block_tile const& group, Value const& value, Predicate const& predicate) noexcept { auto const key = [&]() { - if constexpr (HasPayload) { + if constexpr (this->has_payload) { return value.first; } else { return value; @@ -252,10 +253,9 @@ class open_addressing_ref_impl { auto const src_lane = __ffs(group_contains_empty) - 1; auto const status = (group.thread_rank() == src_lane) - ? 
attempt_insert( - (storage_ref_.data() + *probing_iter)->data() + intra_window_index, - value, - predicate) + ? attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index, + value, + predicate) : insert_result::CONTINUE; switch (group.shfl(status, src_lane)) { @@ -276,7 +276,6 @@ class open_addressing_ref_impl { * element that prevented the insertion) and a `bool` denoting whether the insertion took place or * not. * - * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * @@ -286,14 +285,14 @@ class open_addressing_ref_impl { * @return a pair consisting of an iterator to the element and a bool indicating whether the * insertion is successful or not. */ - template + template __device__ thrust::pair insert_and_find(Value const& value, Predicate const& predicate) noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); auto const key = [&]() { - if constexpr (HasPayload) { + if constexpr (this->has_payload) { return value.first; } else { return value; @@ -313,7 +312,7 @@ class open_addressing_ref_impl { if (eq_res == detail::equal_result::EMPTY) { switch ([&]() { if constexpr (sizeof(value_type) <= 8) { - return packed_cas(window_ptr + i, value, predicate); + return packed_cas(window_ptr + i, value, predicate); } else { return cas_dependent_write(window_ptr + i, value, predicate); } @@ -339,7 +338,6 @@ class open_addressing_ref_impl { * element that prevented the insertion) and a `bool` denoting whether the insertion took place or * not. 
* - * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * @@ -350,14 +348,14 @@ class open_addressing_ref_impl { * @return a pair consisting of an iterator to the element and a bool indicating whether the * insertion is successful or not. */ - template + template __device__ thrust::pair insert_and_find( cooperative_groups::thread_block_tile const& group, Value const& value, Predicate const& predicate) noexcept { auto const key = [&]() { - if constexpr (HasPayload) { + if constexpr (this->has_payload) { return value.first; } else { return value; @@ -399,7 +397,7 @@ class open_addressing_ref_impl { auto const status = [&]() { if (group.thread_rank() != src_lane) { return insert_result::CONTINUE; } if constexpr (sizeof(value_type) <= 8) { - return packed_cas(slot_ptr, value, predicate); + return packed_cas(slot_ptr, value, predicate); } else { return cas_dependent_write(slot_ptr, value, predicate); } @@ -715,7 +713,6 @@ class open_addressing_ref_impl { /** * @brief Inserts the specified element with one single CAS operation. 
* - * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * @@ -725,7 +722,7 @@ class open_addressing_ref_impl { * * @return Result of this operation, i.e., success/continue/duplicate */ - template + template [[nodiscard]] __device__ constexpr insert_result packed_cas(value_type* slot, Value const& value, Predicate const& predicate) noexcept @@ -733,7 +730,7 @@ class open_addressing_ref_impl { auto old = compare_and_swap(slot, this->empty_slot_sentinel_, static_cast(value)); auto* old_ptr = reinterpret_cast(&old); auto const inserted = [&]() { - if constexpr (HasPayload) { + if constexpr (this->has_payload) { // If it's a map implementation, compare keys only return cuco::detail::bitwise_compare(old_ptr->first, this->empty_slot_sentinel_.first); } else { @@ -746,7 +743,7 @@ class open_addressing_ref_impl { } else { // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare auto const res = [&]() { - if constexpr (HasPayload) { + if constexpr (this->has_payload) { // If it's a map implementation, compare keys only return predicate.equal_to(old_ptr->first, value.first); } else { @@ -852,7 +849,6 @@ class open_addressing_ref_impl { * @note Dispatches the correct implementation depending on the container * type and presence of other operator mixins. 
* - * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * @@ -862,13 +858,13 @@ class open_addressing_ref_impl { * * @return Result of this operation, i.e., success/continue/duplicate */ - template + template [[nodiscard]] __device__ insert_result attempt_insert(value_type* slot, Value const& value, Predicate const& predicate) noexcept { if constexpr (sizeof(value_type) <= 8) { - return packed_cas(slot, value, predicate); + return packed_cas(slot, value, predicate); } else { #if (_CUDA_ARCH__ < 700) return cas_dependent_write(slot, value, predicate); diff --git a/include/cuco/detail/static_map/static_map_ref.inl b/include/cuco/detail/static_map/static_map_ref.inl index 28b3ffaf2..e85b77509 100644 --- a/include/cuco/detail/static_map/static_map_ref.inl +++ b/include/cuco/detail/static_map/static_map_ref.inl @@ -248,9 +248,8 @@ class operator_impl< */ __device__ bool insert(value_type const& value) noexcept { - ref_type& ref_ = static_cast(*this); - auto constexpr has_payload = true; - return ref_.impl_.insert(value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + return ref_.impl_.insert(value, ref_.predicate_); } /** @@ -263,9 +262,8 @@ class operator_impl< __device__ bool insert(cooperative_groups::thread_block_tile const& group, value_type const& value) noexcept { - auto& ref_ = static_cast(*this); - auto constexpr has_payload = true; - return ref_.impl_.insert(group, value, ref_.predicate_); + auto& ref_ = static_cast(*this); + return ref_.impl_.insert(group, value, ref_.predicate_); } }; @@ -492,9 +490,8 @@ class operator_impl< */ __device__ thrust::pair insert_and_find(value_type const& value) noexcept { - ref_type& ref_ = static_cast(*this); - auto constexpr has_payload = true; - return ref_.impl_.insert_and_find(value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + return ref_.impl_.insert_and_find(value, 
ref_.predicate_); } /** @@ -513,9 +510,8 @@ class operator_impl< __device__ thrust::pair insert_and_find( cooperative_groups::thread_block_tile const& group, value_type const& value) noexcept { - ref_type& ref_ = static_cast(*this); - auto constexpr has_payload = true; - return ref_.impl_.insert_and_find(group, value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + return ref_.impl_.insert_and_find(group, value, ref_.predicate_); } }; diff --git a/include/cuco/detail/static_set/static_set_ref.inl b/include/cuco/detail/static_set/static_set_ref.inl index 3dbda9bbf..3b754d972 100644 --- a/include/cuco/detail/static_set/static_set_ref.inl +++ b/include/cuco/detail/static_set/static_set_ref.inl @@ -137,9 +137,8 @@ class operator_impl __device__ bool insert(Value const& value) noexcept { - ref_type& ref_ = static_cast(*this); - auto constexpr has_payload = false; - return ref_.impl_.insert(value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + return ref_.impl_.insert(value, ref_.predicate_); } /** @@ -156,9 +155,8 @@ class operator_impl const& group, Value const& value) noexcept { - auto& ref_ = static_cast(*this); - auto constexpr has_payload = false; - return ref_.impl_.insert(group, value, ref_.predicate_); + auto& ref_ = static_cast(*this); + return ref_.impl_.insert(group, value, ref_.predicate_); } }; @@ -224,9 +222,8 @@ class operator_impl __device__ thrust::pair insert_and_find(Value const& value) noexcept { - ref_type& ref_ = static_cast(*this); - auto constexpr has_payload = false; - return ref_.impl_.insert_and_find(value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + return ref_.impl_.insert_and_find(value, ref_.predicate_); } /** @@ -248,9 +245,8 @@ class operator_impl insert_and_find( cooperative_groups::thread_block_tile const& group, Value const& value) noexcept { - ref_type& ref_ = static_cast(*this); - auto constexpr has_payload = false; - return ref_.impl_.insert_and_find(group, value, ref_.predicate_); + ref_type& 
ref_ = static_cast(*this); + return ref_.impl_.insert_and_find(group, value, ref_.predicate_); } }; diff --git a/include/cuco/static_map.cuh b/include/cuco/static_map.cuh index 4db0d43e7..34fcfc805 100644 --- a/include/cuco/static_map.cuh +++ b/include/cuco/static_map.cuh @@ -18,7 +18,7 @@ #include #include -#include +#include #include #include #include diff --git a/include/cuco/static_map_ref.cuh b/include/cuco/static_map_ref.cuh index c41ed88f3..f65b4566b 100644 --- a/include/cuco/static_map_ref.cuh +++ b/include/cuco/static_map_ref.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include #include #include #include diff --git a/include/cuco/static_set.cuh b/include/cuco/static_set.cuh index 6d48d5dc8..979bdfead 100644 --- a/include/cuco/static_set.cuh +++ b/include/cuco/static_set.cuh @@ -17,7 +17,7 @@ #pragma once #include -#include +#include #include #include #include diff --git a/include/cuco/static_set_ref.cuh b/include/cuco/static_set_ref.cuh index b2c8158e7..af34b134e 100644 --- a/include/cuco/static_set_ref.cuh +++ b/include/cuco/static_set_ref.cuh @@ -17,7 +17,7 @@ #pragma once #include -#include +#include #include #include #include