diff --git a/.gitignore b/.gitignore index 6ccf378c2..17647a0f5 100644 --- a/.gitignore +++ b/.gitignore @@ -49,6 +49,7 @@ nosetests.xml coverage.xml *.cover .hypothesis/ +/tests/Testing/ ## Patching *.diff @@ -139,6 +140,14 @@ ENV/ # clang compile_commands.json +/.clangd/ # figures *.eps + +# Github +/.config/ +/.devcontainer.json + +# AWS cache +/.aws/ diff --git a/include/cuco/detail/common_kernels.cuh b/include/cuco/detail/common_kernels.cuh index 223f20609..cecd50735 100644 --- a/include/cuco/detail/common_kernels.cuh +++ b/include/cuco/detail/common_kernels.cuh @@ -149,6 +149,38 @@ __global__ void insert_if_n( } } +/** + * @brief Asynchronously erases keys in the range `[first, first + n)`. + * + * @tparam CGSize Number of threads in each CG + * @tparam BlockSize Number of threads in each block + * @tparam InputIt Device accessible input iterator whose `value_type` is + * convertible to the `value_type` of the data structure + * @tparam Ref Type of non-owning device ref allowing access to storage + * + * @param first Beginning of the sequence of input elements + * @param n Number of input elements + * @param ref Non-owning container device ref used to access the slot storage + */ +template +__global__ void erase(InputIt first, cuco::detail::index_type n, Ref ref) +{ + auto const loop_stride = cuco::detail::grid_stride() / CGSize; + auto idx = cuco::detail::global_thread_id() / CGSize; + + while (idx < n) { + typename std::iterator_traits::value_type const& erase_element{*(first + idx)}; + if constexpr (CGSize == 1) { + ref.erase(erase_element); + } else { + auto const tile = + cooperative_groups::tiled_partition(cooperative_groups::this_thread_block()); + ref.erase(tile, erase_element); + } + idx += loop_stride; + } +} + /** * @brief Indicates whether the keys in the range `[first, first + n)` are contained in the data * structure if `pred` of the corresponding stencil returns true. diff --git a/include/cuco/detail/equal_wrapper.cuh b/include/cuco/detail/equal_wrapper.cuh index 2768c37b5..e42caa401 100644 --- a/include/cuco/detail/equal_wrapper.cuh +++ b/include/cuco/detail/equal_wrapper.cuh @@ -26,7 +26,7 @@ namespace detail { /** * @brief Enum of equality comparison results. */ -enum class equal_result : int32_t { UNEQUAL = 0, EMPTY = 1, EQUAL = 2 }; +enum class equal_result : int32_t { UNEQUAL = 0, EMPTY = 1, EQUAL = 2, ERASED = 3 }; /** * @brief Key equality wrapper. @@ -39,17 +39,21 @@ enum class equal_result : int32_t { UNEQUAL = 0, EMPTY = 1, EQUAL = 2 }; template struct equal_wrapper { // TODO: Clean up the sentinel handling since it's duplicated in ref and equal wrapper - T empty_sentinel_; ///< Sentinel value - Equal equal_; ///< Custom equality callable + T empty_sentinel_; ///< Empty sentinel value + T erased_sentinel_; ///< Erased sentinel value + Equal equal_; ///< Custom equality callable /** * @brief Equality wrapper ctor. * - * @param sentinel Sentinel value + * @param empty_sentinel Empty sentinel value + * @param erased_sentinel Erased sentinel value * @param equal Equality binary callable */ - __host__ __device__ constexpr equal_wrapper(T sentinel, Equal const& equal) noexcept - : empty_sentinel_{sentinel}, equal_{equal} + __host__ __device__ constexpr equal_wrapper(T empty_sentinel, + T erased_sentinel, + Equal const& equal) noexcept + : empty_sentinel_{empty_sentinel}, erased_sentinel_{erased_sentinel}, equal_{equal} { } diff --git a/include/cuco/detail/open_addressing/open_addressing_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_impl.cuh index 556f821d4..84865583d 100644 --- a/include/cuco/detail/open_addressing/open_addressing_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_impl.cuh @@ -130,6 +130,7 @@ class open_addressing_impl { cuda_stream_ref stream) noexcept : empty_key_sentinel_{empty_key_sentinel}, empty_slot_sentinel_{empty_slot_sentinel}, + erased_key_sentinel_{empty_key_sentinel}, predicate_{pred}, probing_scheme_{probing_scheme}, storage_{make_window_extent(capacity), alloc} @@ -190,6 +191,50 @@ class open_addressing_impl { this->clear_async(stream); } + /** + * @brief Constructs a statically-sized open addressing data structure with the specified initial + * capacity, sentinel values and CUDA stream. + * + * @note The actual capacity depends on the given `capacity`, the probing scheme, CG size, and the + * window size and it is computed via the `make_window_extent` factory. Insert operations will not + * automatically grow the container. Attempting to insert more unique keys than the capacity of + * the container results in undefined behavior. + * @note Any `*_sentinel`s are reserved and behavior is undefined when attempting to insert + * this sentinel value. + * @note If a non-default CUDA stream is provided, the caller is responsible for synchronizing the + * stream before the object is first used. + * + * @param capacity The requested lower-bound size + * @param empty_key_sentinel The reserved key value for empty slots + * @param empty_slot_sentinel The reserved slot value for empty slots + * @param erased_key_sentinel The reserved key value for erased slots + * @param pred Key equality binary predicate + * @param probing_scheme Probing scheme + * @param alloc Allocator used for allocating device storage + * @param stream CUDA stream used to initialize the data structure + */ + constexpr open_addressing_impl(Extent capacity, + Key empty_key_sentinel, + Value empty_slot_sentinel, + Key erased_key_sentinel, + KeyEqual const& pred, + ProbingScheme const& probing_scheme, + Allocator const& alloc, + cuda_stream_ref stream) + : empty_key_sentinel_{empty_key_sentinel}, + empty_slot_sentinel_{empty_slot_sentinel}, + erased_key_sentinel_{erased_key_sentinel}, + predicate_{pred}, + probing_scheme_{probing_scheme}, + storage_{make_window_extent(capacity), alloc} + { + CUCO_EXPECTS(empty_key_sentinel_ != erased_key_sentinel_, + "The empty key sentinel and erased key sentinel cannot be the same value.", + std::logic_error); + + this->clear_async(stream); + } + /** * @brief Erases all elements from the container. After this call, `size()` returns zero. * Invalidates any references, pointers, or iterators referring to contained elements. @@ -365,6 +410,46 @@ class open_addressing_impl { first, num_keys, stencil, pred, container_ref); } + /** + * @brief Asynchronously erases keys in the range `[first, last)`. + * + * @note For each key `k` in `[first, last)`, if contains(k) returns true, removes `k` and it's + * associated value from the container. Else, no effect. + * + * @note Side-effects: + * - `contains(k) == false` + * - `find(k) == end()` + * - `insert({k,v}) == true` + * - `size()` is reduced by the total number of erased keys + * + * @tparam InputIt Device accessible input iterator whose `value_type` is + * convertible to the container's `key_type` + * + * @param first Beginning of the sequence of keys + * @param last End of the sequence of keys + * @param container_ref Non-owning device container ref used to access the slot storage + * @param stream Stream used for executing the kernels + * + * @throw std::runtime_error if a unique erased key sentinel value was not + * provided at construction + */ + template + void erase_async(InputIt first, InputIt last, Ref container_ref, cuda_stream_ref stream = {}) + { + CUCO_EXPECTS(empty_key_sentinel_ != erased_key_sentinel_, + "The empty key sentinel and erased key sentinel cannot be the same value.", + std::logic_error); + + auto const num_keys = cuco::detail::distance(first, last); + if (num_keys == 0) { return; } + + auto const grid_size = cuco::detail::grid_size(num_keys, cg_size); + + detail::erase + <<>>( + first, num_keys, container_ref); + } + /** * @brief Asynchronously indicates whether the keys in the range `[first, last)` are contained in * the container. @@ -556,6 +641,16 @@ class open_addressing_impl { return empty_key_sentinel_; } + /** + * @brief Gets the sentinel value used to represent an erased key slot. + * + * @return The sentinel value used to represent an erased key slot + */ + [[nodiscard]] constexpr key_type erased_key_sentinel() const noexcept + { + return erased_key_sentinel_; + } + /** * @brief Gets the key comparator. * @@ -588,8 +683,10 @@ class open_addressing_impl { [[nodiscard]] constexpr storage_ref_type storage_ref() const noexcept { return storage_.ref(); } protected: + // TODO: cleanup by using equal wrapper as a data member key_type empty_key_sentinel_; ///< Key value that represents an empty slot value_type empty_slot_sentinel_; ///< Slot value that represents an empty slot + key_type erased_key_sentinel_; ///< Key value that represents an erased slot key_equal predicate_; ///< Key equality binary predicate probing_scheme_type probing_scheme_; ///< Probing scheme storage_type storage_; ///< Slot window storage diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index e610eeb0c..565af7aa4 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -130,7 +130,30 @@ class open_addressing_ref_impl { probing_scheme_type const& probing_scheme, storage_ref_type storage_ref) noexcept : empty_slot_sentinel_{empty_slot_sentinel}, - predicate_{this->extract_key(empty_slot_sentinel), predicate}, + predicate_{ + this->extract_key(empty_slot_sentinel), this->extract_key(empty_slot_sentinel), predicate}, + probing_scheme_{probing_scheme}, + storage_ref_{storage_ref} + { + } + + /** + * @brief Constructs open_addressing_ref_impl. + * + * @param empty_slot_sentinel Sentinel indicating an empty slot + * @param erased_key_sentinel Sentinel indicating an erased key + * @param predicate Key equality binary callable + * @param probing_scheme Probing scheme + * @param storage_ref Non-owning ref of slot storage + */ + __host__ __device__ explicit constexpr open_addressing_ref_impl( + value_type empty_slot_sentinel, + key_type erased_key_sentinel, + key_equal const& predicate, + probing_scheme_type const& probing_scheme, + storage_ref_type storage_ref) noexcept + : empty_slot_sentinel_{empty_slot_sentinel}, + predicate_{this->extract_key(empty_slot_sentinel), erased_key_sentinel, predicate}, probing_scheme_{probing_scheme}, storage_ref_{storage_ref} { @@ -157,6 +180,16 @@ class open_addressing_ref_impl { return this->extract_payload(this->empty_slot_sentinel()); } + /** + * @brief Gets the sentinel value used to represent an erased key slot. + * + * @return The sentinel value used to represent an erased key slot + */ + [[nodiscard]] __host__ __device__ constexpr key_type const& erased_key_sentinel() const noexcept + { + return this->predicate_.erased_sentinel_; + } + /** * @brief Gets the sentinel used to represent an empty slot. * @@ -250,9 +283,12 @@ class open_addressing_ref_impl { // If the key is already in the container, return false if (eq_res == detail::equal_result::EQUAL) { return false; } - if (eq_res == detail::equal_result::EMPTY) { + if (eq_res == detail::equal_result::EMPTY or + cuco::detail::bitwise_compare(this->extract_key(slot_content), + this->erased_key_sentinel())) { auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content); switch (attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index, + slot_content, value)) { case insert_result::CONTINUE: continue; case insert_result::SUCCESS: return true; @@ -291,7 +327,14 @@ class open_addressing_ref_impl { return window_probing_results{detail::equal_result::EMPTY, i}; case detail::equal_result::EQUAL: return window_probing_results{detail::equal_result::EQUAL, i}; - default: continue; + default: { + if (cuco::detail::bitwise_compare(this->extract_key(window_slots[i]), + this->erased_key_sentinel())) { + return window_probing_results{detail::equal_result::ERASED, i}; + } else { + continue; + } + } } } // returns dummy index `-1` for UNEQUAL @@ -301,13 +344,14 @@ class open_addressing_ref_impl { // If the key is already in the container, return false if (group.any(state == detail::equal_result::EQUAL)) { return false; } - auto const group_contains_empty = group.ballot(state == detail::equal_result::EMPTY); - - if (group_contains_empty) { - auto const src_lane = __ffs(group_contains_empty) - 1; + auto const group_contains_available = + group.ballot(state == detail::equal_result::EMPTY or state == detail::equal_result::ERASED); + if (group_contains_available) { + auto const src_lane = __ffs(group_contains_available) - 1; auto const status = (group.thread_rank() == src_lane) ? attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index, + window_slots[intra_window_index], value) : insert_result::CONTINUE; @@ -353,12 +397,14 @@ class open_addressing_ref_impl { // If the key is already in the container, return false if (eq_res == detail::equal_result::EQUAL) { return {iterator{&window_ptr[i]}, false}; } - if (eq_res == detail::equal_result::EMPTY) { + if (eq_res == detail::equal_result::EMPTY or + cuco::detail::bitwise_compare(this->extract_key(window_slots[i]), + this->erased_key_sentinel())) { switch ([&]() { if constexpr (sizeof(value_type) <= 8) { - return packed_cas(window_ptr + i, value); + return packed_cas(window_ptr + i, window_slots[i], value); } else { - return cas_dependent_write(window_ptr + i, value); + return cas_dependent_write(window_ptr + i, window_slots[i], value); } }()) { case insert_result::SUCCESS: { @@ -407,7 +453,14 @@ class open_addressing_ref_impl { return window_probing_results{detail::equal_result::EMPTY, i}; case detail::equal_result::EQUAL: return window_probing_results{detail::equal_result::EQUAL, i}; - default: continue; + default: { + if (cuco::detail::bitwise_compare(this->extract_key(window_slots[i]), + this->erased_key_sentinel())) { + return window_probing_results{detail::equal_result::ERASED, i}; + } else { + continue; + } + } } } // returns dummy index `-1` for UNEQUAL @@ -424,16 +477,17 @@ class open_addressing_ref_impl { return {iterator{reinterpret_cast(res)}, false}; } - auto const group_contains_empty = group.ballot(state == detail::equal_result::EMPTY); - if (group_contains_empty) { - auto const src_lane = __ffs(group_contains_empty) - 1; + auto const group_contains_available = + group.ballot(state == detail::equal_result::EMPTY or state == detail::equal_result::ERASED); + if (group_contains_available) { + auto const src_lane = __ffs(group_contains_available) - 1; auto const res = group.shfl(reinterpret_cast(slot_ptr), src_lane); - auto const status = [&]() { + auto const status = [&, target_idx = intra_window_index]() { if (group.thread_rank() != src_lane) { return insert_result::CONTINUE; } if constexpr (sizeof(value_type) <= 8) { - return packed_cas(slot_ptr, value); + return packed_cas(slot_ptr, window_slots[target_idx], value); } else { - return cas_dependent_write(slot_ptr, value); + return cas_dependent_write(slot_ptr, window_slots[target_idx], value); } }(); @@ -452,6 +506,103 @@ class open_addressing_ref_impl { } } + /** + * @brief Erases an element. + * + * @tparam ProbeKey Input type which is implicitly convertible to 'key_type' + * + * @param value The element to erase + * + * @return True if the given element is successfully erased + */ + template + __device__ bool erase(ProbeKey const& key) noexcept + { + static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); + + auto probing_iter = probing_scheme_(key, storage_ref_.window_extent()); + + while (true) { + auto const window_slots = storage_ref_[*probing_iter]; + + for (auto& slot_content : window_slots) { + auto const eq_res = this->predicate_(this->extract_key(slot_content), key); + + // Key doesn't exist, return false + if (eq_res == detail::equal_result::EMPTY) { return false; } + // Key exists, return true if successfully deleted + if (eq_res == detail::equal_result::EQUAL) { + auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content); + switch (attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index, + slot_content, + this->erased_slot_sentinel())) { + case insert_result::SUCCESS: return true; + case insert_result::DUPLICATE: return false; + default: continue; + } + } + } + ++probing_iter; + } + } + + /** + * @brief Erases an element. + * + * @tparam ProbeKey Input type which is implicitly convertible to 'key_type' + * + * @param group The Cooperative Group used to perform group erase + * @param value The element to erase + * + * @return True if the given element is successfully erased + */ + template + __device__ bool erase(cooperative_groups::thread_block_tile const& group, + ProbeKey const& key) noexcept + { + auto probing_iter = probing_scheme_(group, key, storage_ref_.window_extent()); + + while (true) { + auto const window_slots = storage_ref_[*probing_iter]; + + auto const [state, intra_window_index] = [&]() { + for (auto i = 0; i < window_size; ++i) { + switch (this->predicate_(this->extract_key(window_slots[i]), key)) { + case detail::equal_result::EMPTY: + return window_probing_results{detail::equal_result::EMPTY, i}; + case detail::equal_result::EQUAL: + return window_probing_results{detail::equal_result::EQUAL, i}; + default: continue; + } + } + // returns dummy index `-1` for UNEQUAL + return window_probing_results{detail::equal_result::UNEQUAL, -1}; + }(); + + auto const group_contains_equal = group.ballot(state == detail::equal_result::EQUAL); + if (group_contains_equal) { + auto const src_lane = __ffs(group_contains_equal) - 1; + auto const status = + (group.thread_rank() == src_lane) + ? attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index, + window_slots[intra_window_index], + this->erased_slot_sentinel()) + : insert_result::CONTINUE; + + switch (group.shfl(status, src_lane)) { + case insert_result::SUCCESS: return true; + case insert_result::DUPLICATE: return false; + default: continue; + } + } else if (group.any(state == detail::equal_result::EMPTY)) { + // Key doesn't exist, return false + return false; + } else { + ++probing_iter; + } + } + } + /** * @brief Indicates whether the probe key `key` was inserted into the container. * @@ -738,46 +889,45 @@ class open_addressing_ref_impl { return value.second; } + /** + * @brief Gets the sentinel used to represent an erased slot. + * + * @return The sentinel value used to represent an erased slot + */ + [[nodiscard]] __device__ constexpr value_type const erased_slot_sentinel() const noexcept + { + if constexpr (this->has_payload) { + return cuco::pair{this->erased_key_sentinel(), this->empty_slot_sentinel().second}; + } else { + return this->erased_key_sentinel(); + } + } + /** * @brief Inserts the specified element with one single CAS operation. * * @tparam Value Input type which is implicitly convertible to 'value_type' * - * @param slot Pointer to the slot in memory - * @param value Element to insert + * @param address Pointer to the slot in memory + * @param expected Element to compare against + * @param desired Element to insert * * @return Result of this operation, i.e., success/continue/duplicate */ template - [[nodiscard]] __device__ constexpr insert_result packed_cas(value_type* slot, - Value const& value) noexcept + [[nodiscard]] __device__ constexpr insert_result packed_cas(value_type* address, + value_type const& expected, + Value const& desired) noexcept { - auto old = compare_and_swap(slot, this->empty_slot_sentinel_, static_cast(value)); - auto* old_ptr = reinterpret_cast(&old); - auto const inserted = [&]() { - if constexpr (this->has_payload) { - // If it's a map implementation, compare keys only - return cuco::detail::bitwise_compare(old_ptr->first, this->empty_slot_sentinel_.first); - } else { - // If it's a set implementation, compare the whole slot content - return cuco::detail::bitwise_compare(*old_ptr, this->empty_slot_sentinel_); - } - }(); - if (inserted) { + auto old = compare_and_swap(address, expected, static_cast(desired)); + auto* old_ptr = reinterpret_cast(&old); + if (cuco::detail::bitwise_compare(this->extract_key(*old_ptr), this->extract_key(expected))) { return insert_result::SUCCESS; } else { - // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare - auto const res = [&]() { - if constexpr (this->has_payload) { - // If it's a map implementation, compare keys only - return this->predicate_.equal_to(old_ptr->first, value.first); - } else { - // If it's a set implementation, compare the whole slot content - return this->predicate_.equal_to(*old_ptr, value); - } - }(); - return res == detail::equal_result::EQUAL ? insert_result::DUPLICATE - : insert_result::CONTINUE; + return this->predicate_.equal_to(this->extract_key(*old_ptr), this->extract_key(desired)) == + detail::equal_result::EQUAL + ? insert_result::DUPLICATE + : insert_result::CONTINUE; } } @@ -786,23 +936,26 @@ class open_addressing_ref_impl { * * @tparam Value Input type which is implicitly convertible to 'value_type' * - * @param slot Pointer to the slot in memory - * @param value Element to insert + * @param address Pointer to the slot in memory + * @param expected Element to compare against + * @param desired Element to insert * * @return Result of this operation, i.e., success/continue/duplicate */ template - [[nodiscard]] __device__ constexpr insert_result back_to_back_cas(value_type* slot, - Value const& value) noexcept + [[nodiscard]] __device__ constexpr insert_result back_to_back_cas(value_type* address, + value_type const& expected, + Value const& desired) noexcept { using mapped_type = decltype(this->empty_slot_sentinel_.second); - auto const expected_key = this->empty_slot_sentinel_.first; - auto const expected_payload = this->empty_slot_sentinel_.second; + auto const expected_key = expected.first; + auto const expected_payload = expected.second; - auto old_key = compare_and_swap(&slot->first, expected_key, static_cast(value.first)); - auto old_payload = - compare_and_swap(&slot->second, expected_payload, static_cast(value.second)); + auto old_key = + compare_and_swap(&address->first, expected_key, static_cast(desired.first)); + auto old_payload = compare_and_swap( + &address->second, expected_payload, static_cast(desired.second)); auto* old_key_ptr = reinterpret_cast(&old_key); auto* old_payload_ptr = reinterpret_cast(&old_payload); @@ -810,17 +963,17 @@ class open_addressing_ref_impl { // if key success if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) { while (not cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) { - old_payload = - compare_and_swap(&slot->second, expected_payload, static_cast(value.second)); + old_payload = compare_and_swap( + &address->second, expected_payload, static_cast(desired.second)); } return insert_result::SUCCESS; } else if (cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) { - atomic_store(&slot->second, expected_payload); + atomic_store(&address->second, expected_payload); } // Our key was already present in the slot, so our key is a duplicate // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare - if (this->predicate_.equal_to(*old_key_ptr, value.first) == detail::equal_result::EQUAL) { + if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) { return insert_result::DUPLICATE; } @@ -832,32 +985,34 @@ class open_addressing_ref_impl { * * @tparam Value Input type which is implicitly convertible to 'value_type' * - * @param slot Pointer to the slot in memory - * @param value Element to insert + * @param address Pointer to the slot in memory + * @param expected Element to compare against + * @param desired Element to insert * * @return Result of this operation, i.e., success/continue/duplicate */ template - [[nodiscard]] __device__ constexpr insert_result cas_dependent_write(value_type* slot, - Value const& value) noexcept + [[nodiscard]] __device__ constexpr insert_result cas_dependent_write( + value_type* address, value_type const& expected, Value const& desired) noexcept { using mapped_type = decltype(this->empty_slot_sentinel_.second); - auto const expected_key = this->empty_slot_sentinel_.first; + auto const expected_key = expected.first; - auto old_key = compare_and_swap(&slot->first, expected_key, static_cast(value.first)); + auto old_key = + compare_and_swap(&address->first, expected_key, static_cast(desired.first)); auto* old_key_ptr = reinterpret_cast(&old_key); // if key success if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) { - atomic_store(&slot->second, static_cast(value.second)); + atomic_store(&address->second, static_cast(desired.second)); return insert_result::SUCCESS; } // Our key was already present in the slot, so our key is a duplicate // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare - if (this->predicate_.equal_to(*old_key_ptr, value.first) == detail::equal_result::EQUAL) { + if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) { return insert_result::DUPLICATE; } @@ -872,22 +1027,24 @@ class open_addressing_ref_impl { * * @tparam Value Input type which is implicitly convertible to 'value_type' * - * @param slot Pointer to the slot in memory - * @param value Element to insert + * @param address Pointer to the slot in memory + * @param expected Element to compare against + * @param desired Element to insert * * @return Result of this operation, i.e., success/continue/duplicate */ template - [[nodiscard]] __device__ insert_result attempt_insert(value_type* slot, - Value const& value) noexcept + [[nodiscard]] __device__ insert_result attempt_insert(value_type* address, + value_type const& expected, + Value const& desired) noexcept { if constexpr (sizeof(value_type) <= 8) { - return packed_cas(slot, value); + return packed_cas(address, expected, desired); } else { #if (_CUDA_ARCH__ < 700) - return cas_dependent_write(slot, value); + return cas_dependent_write(address, expected, desired); #else - return back_to_back_cas(slot, value); + return back_to_back_cas(address, expected, desired); #endif } } diff --git a/include/cuco/detail/static_map/functors.cuh b/include/cuco/detail/static_map/functors.cuh index f508206f0..c877c9e07 100644 --- a/include/cuco/detail/static_map/functors.cuh +++ b/include/cuco/detail/static_map/functors.cuh @@ -63,14 +63,19 @@ struct get_slot { */ template struct slot_is_filled { - T empty_sentinel_; ///< The value of the empty key sentinel + T empty_sentinel_; ///< The value of the empty key sentinel + T erased_sentinel_; ///< Key value that represents an erased slot /** - * @brief Constructs `slot_is_filled` functor with the given empty sentinel. + * @brief Constructs `slot_is_filled` functor with the given sentinels. * - * @param s Sentinel indicating empty slot + * @param empty_sentinel Sentinel indicating empty slot + * @param erased_sentinel Sentinel indicating erased slot */ - explicit constexpr slot_is_filled(T const& s) noexcept : empty_sentinel_{s} {} + explicit constexpr slot_is_filled(T const& empty_sentinel, T const& erased_sentinel) noexcept + : empty_sentinel_{empty_sentinel}, erased_sentinel_{erased_sentinel} + { + } /** * @brief Indicates if the target slot `slot` is filled. @@ -84,7 +89,8 @@ struct slot_is_filled { template __device__ constexpr bool operator()(Slot const& slot) const noexcept { - return not cuco::detail::bitwise_compare(empty_sentinel_, thrust::get<0>(slot)); + return not(cuco::detail::bitwise_compare(empty_sentinel_, thrust::get<0>(slot)) or + cuco::detail::bitwise_compare(erased_sentinel_, thrust::get<0>(slot))); } /** @@ -96,7 +102,8 @@ struct slot_is_filled { */ __device__ constexpr bool operator()(cuco::pair const& slot) const noexcept { - return not cuco::detail::bitwise_compare(empty_sentinel_, slot.first); + return not(cuco::detail::bitwise_compare(empty_sentinel_, slot.first) or + cuco::detail::bitwise_compare(erased_sentinel_, slot.first)); } }; diff --git a/include/cuco/detail/static_map/static_map.inl b/include/cuco/detail/static_map/static_map.inl index 1cc932aeb..36eb74f5f 100644 --- a/include/cuco/detail/static_map/static_map.inl +++ b/include/cuco/detail/static_map/static_map.inl @@ -83,6 +83,35 @@ constexpr static_map +constexpr static_map:: + static_map(Extent capacity, + empty_key empty_key_sentinel, + empty_value empty_value_sentinel, + erased_key erased_key_sentinel, + KeyEqual const& pred, + ProbingScheme const& probing_scheme, + Allocator const& alloc, + cuda_stream_ref stream) + : impl_{std::make_unique(capacity, + empty_key_sentinel, + cuco::pair{empty_key_sentinel, empty_value_sentinel}, + erased_key_sentinel, + pred, + probing_scheme, + alloc, + stream)}, + empty_value_sentinel_{empty_value_sentinel} +{ +} + template +template +void static_map::erase( + InputIt first, InputIt last, cuda_stream_ref stream) +{ + erase_async(first, last, stream); + stream.synchronize(); +} + +template +template +void static_map::erase_async( + InputIt first, InputIt last, cuda_stream_ref stream) +{ + impl_->erase_async(first, last, ref(op::erase), stream); +} + template :: auto const begin = thrust::make_transform_iterator( thrust::counting_iterator{0}, static_map_ns::detail::get_slot(impl_->storage_ref())); - auto const is_filled = static_map_ns::detail::slot_is_filled(this->empty_key_sentinel()); + auto const is_filled = static_map_ns::detail::slot_is_filled(this->empty_key_sentinel(), + this->erased_key_sentinel()); auto zipped_out_begin = thrust::make_zip_iterator(thrust::make_tuple(keys_out, values_out)); auto const zipped_out_end = impl_->retrieve_all(begin, zipped_out_begin, is_filled, stream); auto const num = std::distance(zipped_out_begin, zipped_out_end); @@ -358,7 +419,8 @@ static_map:: static_map::size( cuda_stream_ref stream) const noexcept { - auto const is_filled = static_map_ns::detail::slot_is_filled(this->empty_key_sentinel()); + auto const is_filled = static_map_ns::detail::slot_is_filled(this->empty_key_sentinel(), + this->erased_key_sentinel()); return impl_->size(is_filled, stream); } @@ -408,6 +470,21 @@ constexpr static_mapempty_value_sentinel_; } +template +constexpr static_map::key_type +static_map:: + erased_key_sentinel() const noexcept +{ + return impl_->erased_key_sentinel(); +} + template {cuco::empty_key(this->empty_key_sentinel()), - cuco::empty_value(this->empty_value_sentinel()), - impl_->key_eq(), - impl_->probing_scheme(), - impl_->storage_ref()}; + return this->empty_key_sentinel() == this->erased_key_sentinel() + ? ref_type{cuco::empty_key(this->empty_key_sentinel()), + cuco::empty_value(this->empty_value_sentinel()), + impl_->key_eq(), + impl_->probing_scheme(), + impl_->storage_ref()} + : ref_type{cuco::empty_key(this->empty_key_sentinel()), + cuco::empty_value(this->empty_value_sentinel()), + cuco::erased_key(this->erased_key_sentinel()), + impl_->key_eq(), + impl_->probing_scheme(), + impl_->storage_ref()}; } } // namespace experimental } // namespace cuco diff --git a/include/cuco/detail/static_map/static_map_ref.inl b/include/cuco/detail/static_map/static_map_ref.inl index 1f793138c..f27f21e76 100644 --- a/include/cuco/detail/static_map/static_map_ref.inl +++ b/include/cuco/detail/static_map/static_map_ref.inl @@ -49,6 +49,34 @@ __host__ __device__ constexpr static_map_ref< { } +template +__host__ __device__ constexpr static_map_ref< + Key, + T, + Scope, + KeyEqual, + ProbingScheme, + StorageRef, + Operators...>::static_map_ref(cuco::empty_key empty_key_sentinel, + cuco::empty_value empty_value_sentinel, + cuco::erased_key erased_key_sentinel, + KeyEqual const& predicate, + ProbingScheme const& probing_scheme, + StorageRef storage_ref) noexcept + : impl_{cuco::pair{empty_key_sentinel, empty_value_sentinel}, + erased_key_sentinel, + predicate, + probing_scheme, + storage_ref} +{ +} + template data() + intra_window_index, value)) { @@ -283,7 +312,14 @@ class operator_impl< return detail::window_probing_results{detail::equal_result::EMPTY, i}; case detail::equal_result::EQUAL: return detail::window_probing_results{detail::equal_result::EQUAL, i}; - default: continue; + default: { + if (cuco::detail::bitwise_compare(window_slots[i].first, + ref_.impl_.erased_key_sentinel())) { + return window_probing_results{detail::equal_result::ERASED, i}; + } else { + continue; + } + } } } // returns dummy index `-1` for UNEQUAL @@ -302,9 +338,10 @@ class operator_impl< return; } - auto const group_contains_empty = group.ballot(state == detail::equal_result::EMPTY); - if (group_contains_empty) { - auto const src_lane = __ffs(group_contains_empty) - 1; + auto const group_contains_available = + group.ballot(state == detail::equal_result::EMPTY or state == detail::equal_result::ERASED); + if (group_contains_available) { + auto const src_lane = __ffs(group_contains_available) - 1; auto const status = (group.thread_rank() == src_lane) ? attempt_insert_or_assign( @@ -447,6 +484,60 @@ class operator_impl< } }; +template +class operator_impl< + op::erase_tag, + static_map_ref> { + using base_type = static_map_ref; + using ref_type = static_map_ref; + using key_type = typename base_type::key_type; + using value_type = typename base_type::value_type; + + static constexpr auto cg_size = base_type::cg_size; + static constexpr auto window_size = base_type::window_size; + + public: + /** + * @brief Erases an element. + * + * @tparam ProbeKey Input type which is implicitly convertible to 'key_type' + * + * @param key The element to erase + * + * @return True if the given element is successfully erased + */ + template + __device__ bool erase(ProbeKey const& key) noexcept + { + ref_type& ref_ = static_cast(*this); + return ref_.impl_.erase(key); + } + + /** + * @brief Erases an element. + * + * @tparam ProbeKey Input type which is implicitly convertible to 'key_type' + * + * @param group The Cooperative Group used to perform group insert + * @param key The element to erase + * + * @return True if the given element is successfully erased + */ + template + __device__ bool erase(cooperative_groups::thread_block_tile const& group, + ProbeKey const& key) noexcept + { + auto& ref_ = static_cast(*this); + return ref_.impl_.erase(group, key); + } +}; + template struct slot_is_filled { - T empty_sentinel_; ///< The value of the empty key sentinel + T empty_sentinel_; ///< The value of the empty key sentinel + T erased_sentinel_; ///< Key value that represents an erased slot /** - * @brief Constructs `slot_is_filled` functor with the given empty sentinel. + * @brief Constructs `slot_is_filled` functor with the given sentinels. * - * @param s Sentinel indicating empty slot + * @param empty_sentinel Sentinel indicating empty slot + * @param erased_sentinel Sentinel indicating erased slot */ - explicit constexpr slot_is_filled(T const& s) noexcept : empty_sentinel_{s} {} + explicit constexpr slot_is_filled(T const& empty_sentinel, T const& erased_sentinel) noexcept + : empty_sentinel_{empty_sentinel}, erased_sentinel_{erased_sentinel} + { + } /** * @brief Indicates if the target slot `slot` is filled. @@ -49,7 +54,8 @@ struct slot_is_filled { */ __device__ constexpr bool operator()(T const& slot) const noexcept { - return not cuco::detail::bitwise_compare(empty_sentinel_, slot); + return not(cuco::detail::bitwise_compare(empty_sentinel_, slot) or + cuco::detail::bitwise_compare(erased_sentinel_, slot)); } }; diff --git a/include/cuco/detail/static_set/static_set.inl b/include/cuco/detail/static_set/static_set.inl index 888f0d67c..9fa87196a 100644 --- a/include/cuco/detail/static_set/static_set.inl +++ b/include/cuco/detail/static_set/static_set.inl @@ -72,6 +72,32 @@ constexpr static_set +constexpr static_set::static_set( + Extent capacity, + empty_key empty_key_sentinel, + erased_key erased_key_sentinel, + KeyEqual const& pred, + ProbingScheme const& probing_scheme, + Allocator const& alloc, + cuda_stream_ref stream) + : impl_{std::make_unique(capacity, + empty_key_sentinel, + empty_key_sentinel, + erased_key_sentinel, + pred, + probing_scheme, + alloc, + stream)} +{ +} + template impl_->insert_if_async(first, last, stencil, pred, ref(op::insert), stream); } +template +template +void static_set::erase( + InputIt first, InputIt last, cuda_stream_ref stream) +{ + erase_async(first, last, stream); + stream.synchronize(); +} + +template +template +void static_set::erase_async( + InputIt first, InputIt last, cuda_stream_ref stream) +{ + impl_->erase_async(first, last, ref(op::erase), stream); +} + template {0}, detail::get_slot(impl_->storage_ref())); - auto const is_filled = static_set_ns::detail::slot_is_filled(this->empty_key_sentinel()); + auto const is_filled = + static_set_ns::detail::slot_is_filled(this->empty_key_sentinel(), this->erased_key_sentinel()); return impl_->retrieve_all(begin, output_begin, is_filled, stream); } @@ -290,7 +346,8 @@ static_set::siz static_set::size( cuda_stream_ref stream) const noexcept { - auto const is_filled = static_set_ns::detail::slot_is_filled(this->empty_key_sentinel()); + auto const is_filled = + static_set_ns::detail::slot_is_filled(this->empty_key_sentinel(), this->erased_key_sentinel()); return impl_->size(is_filled, stream); } @@ -322,6 +379,20 @@ static_set::emp return impl_->empty_key_sentinel(); } +template +constexpr static_set::key_type +static_set::erased_key_sentinel() + const noexcept +{ + return impl_->erased_key_sentinel(); +} + template Operators...) const noexcept { static_assert(sizeof...(Operators), "No operators specified"); - return ref_type{cuco::empty_key(this->empty_key_sentinel()), - impl_->key_eq(), - impl_->probing_scheme(), - impl_->storage_ref()}; + return this->empty_key_sentinel() == this->erased_key_sentinel() + ? ref_type{cuco::empty_key(this->empty_key_sentinel()), + impl_->key_eq(), + impl_->probing_scheme(), + impl_->storage_ref()} + : ref_type{cuco::empty_key(this->empty_key_sentinel()), + cuco::erased_key(this->erased_key_sentinel()), + impl_->key_eq(), + impl_->probing_scheme(), + impl_->storage_ref()}; } } // namespace experimental } // namespace cuco diff --git a/include/cuco/detail/static_set/static_set_ref.inl b/include/cuco/detail/static_set/static_set_ref.inl index 08fc4b6db..c21e0dddb 100644 --- a/include/cuco/detail/static_set/static_set_ref.inl +++ b/include/cuco/detail/static_set/static_set_ref.inl @@ -45,6 +45,27 @@ __host__ __device__ constexpr static_set_ref< { } +template +__host__ __device__ constexpr static_set_ref< + Key, + Scope, + KeyEqual, + ProbingScheme, + StorageRef, + Operators...>::static_set_ref(cuco::empty_key empty_key_sentinel, + cuco::erased_key erased_key_sentinel, + KeyEqual const& predicate, + ProbingScheme const& probing_scheme, + StorageRef storage_ref) noexcept + : impl_{empty_key_sentinel, erased_key_sentinel, predicate, probing_scheme, storage_ref} +{ +} + template +class operator_impl> { + using base_type = static_set_ref; + using ref_type = static_set_ref; + using key_type = typename base_type::key_type; + using value_type = typename base_type::value_type; + + static constexpr auto cg_size = base_type::cg_size; + static constexpr auto window_size = base_type::window_size; + + public: + /** + * @brief Erases an element. + * + * @tparam ProbeKey Input type which is implicitly convertible to 'key_type' + * + * @param key The element to erase + * + * @return True if the given element is successfully erased + */ + template + __device__ bool erase(ProbeKey const& key) noexcept + { + ref_type& ref_ = static_cast(*this); + return ref_.impl_.erase(key); + } + + /** + * @brief Erases an element. + * + * @tparam ProbeKey Input type which is implicitly convertible to 'key_type' + * + * @param group The Cooperative Group used to perform group erase + * @param value The element to erase + * + * @return True if the given element is successfully erased + */ + template + __device__ bool erase(cooperative_groups::thread_block_tile const& group, + ProbeKey const& key) noexcept + { + auto& ref_ = static_cast(*this); + return ref_.impl_.erase(group, key); + } +}; + template empty_key_sentinel, + empty_value empty_value_sentinel, + erased_key erased_key_sentinel, + KeyEqual const& pred = {}, + ProbingScheme const& probing_scheme = {}, + Allocator const& alloc = {}, + cuda_stream_ref stream = {}); + /** * @brief Erases all elements from the container. After this call, `size()` returns zero. * Invalidates any references, pointers, or iterators referring to contained elements. @@ -367,6 +399,58 @@ class static_map { template void insert_or_assign_async(InputIt first, InputIt last, cuda_stream_ref stream = {}) noexcept; + /** + * @brief Erases keys in the range `[first, last)`. + * + * @note For each key `k` in `[first, last)`, if contains(k) returns true, removes `k` and it's + * associated value from the map. Else, no effect. + * + * @note This function synchronizes `stream`. + * + * @note Side-effects: + * - `contains(k) == false` + * - `find(k) == end()` + * - `insert({k,v}) == true` + * - `size()` is reduced by the total number of erased keys + * + * @tparam InputIt Device accessible input iterator whose `value_type` is + * convertible to the map's `key_type` + * + * @param first Beginning of the sequence of keys + * @param last End of the sequence of keys + * @param stream Stream used for executing the kernels + * + * @throw std::runtime_error if a unique erased key sentinel value was not + * provided at construction + */ + template + void erase(InputIt first, InputIt last, cuda_stream_ref stream = {}); + + /** + * @brief Asynchronously erases keys in the range `[first, last)`. + * + * @note For each key `k` in `[first, last)`, if contains(k) returns true, removes `k` and it's + * associated value from the map. Else, no effect. + * + * @note Side-effects: + * - `contains(k) == false` + * - `find(k) == end()` + * - `insert({k,v}) == true` + * - `size()` is reduced by the total number of erased keys + * + * @tparam InputIt Device accessible input iterator whose `value_type` is + * convertible to the map's `key_type` + * + * @param first Beginning of the sequence of keys + * @param last End of the sequence of keys + * @param stream Stream used for executing the kernels + * + * @throw std::runtime_error if a unique erased key sentinel value was not + * provided at construction + */ + template + void erase_async(InputIt first, InputIt last, cuda_stream_ref stream = {}); + /** * @brief Indicates whether the keys in the range `[first, last)` are contained in the map. * @@ -567,6 +651,13 @@ class static_map { */ [[nodiscard]] constexpr mapped_type empty_value_sentinel() const noexcept; + /** + * @brief Gets the sentinel value used to represent an erased key slot. + * + * @return The sentinel value used to represent an erased key slot + */ + [[nodiscard]] constexpr key_type erased_key_sentinel() const noexcept; + /** * @brief Get device ref with operators. * diff --git a/include/cuco/static_map_ref.cuh b/include/cuco/static_map_ref.cuh index 93bb98aea..88e40f86c 100644 --- a/include/cuco/static_map_ref.cuh +++ b/include/cuco/static_map_ref.cuh @@ -103,12 +103,28 @@ class static_map_ref * @param probing_scheme Probing scheme * @param storage_ref Non-owning ref of slot storage */ - __host__ __device__ explicit constexpr static_map_ref( - cuco::empty_key empty_key_sentinel, - cuco::empty_value empty_value_sentinel, - key_equal const& predicate, - probing_scheme_type const& probing_scheme, - storage_ref_type storage_ref) noexcept; + __host__ __device__ explicit constexpr static_map_ref(cuco::empty_key empty_key_sentinel, + cuco::empty_value empty_value_sentinel, + KeyEqual const& predicate, + ProbingScheme const& probing_scheme, + StorageRef storage_ref) noexcept; + + /** + * @brief Constructs static_map_ref. + * + * @param empty_key_sentinel Sentinel indicating empty key + * @param empty_value_sentinel Sentinel indicating empty payload + * @param erased_key_sentinel Sentinel indicating erased key + * @param predicate Key equality binary callable + * @param probing_scheme Probing scheme + * @param storage_ref Non-owning ref of slot storage + */ + __host__ __device__ explicit constexpr static_map_ref(cuco::empty_key empty_key_sentinel, + cuco::empty_value empty_value_sentinel, + cuco::erased_key erased_key_sentinel, + KeyEqual const& predicate, + ProbingScheme const& probing_scheme, + StorageRef storage_ref) noexcept; /** * @brief Operator-agnostic move constructor. diff --git a/include/cuco/static_set.cuh b/include/cuco/static_set.cuh index 979bdfead..fdb65f5f8 100644 --- a/include/cuco/static_set.cuh +++ b/include/cuco/static_set.cuh @@ -193,6 +193,36 @@ class static_set { Allocator const& alloc = {}, cuda_stream_ref stream = {}); + /** + * @brief Constructs a statically-sized set with the specified initial capacity, sentinel values + * and CUDA stream. + * + * The actual set capacity depends on the given `capacity`, the probing scheme, CG size, and the + * window size and it is computed via the `make_window_extent` factory. Insert operations will not + * automatically grow the set. Attempting to insert more unique keys than the capacity of the map + * results in undefined behavior. + * + * @note Any `*_sentinel`s are reserved and behavior is undefined when attempting to insert + * this sentinel value. + * @note If a non-default CUDA stream is provided, the caller is responsible for synchronizing the + * stream before the object is first used. + * + * @param capacity The requested lower-bound set size + * @param empty_key_sentinel The reserved key value for empty slots + * @param erased_key_sentinel The reserved key to denote erased slots + * @param pred Key equality binary predicate + * @param probing_scheme Probing scheme + * @param alloc Allocator used for allocating device storage + * @param stream CUDA stream used to initialize the set + */ + constexpr static_set(Extent capacity, + empty_key empty_key_sentinel, + erased_key erased_key_sentinel, + KeyEqual const& pred = {}, + ProbingScheme const& probing_scheme = {}, + Allocator const& alloc = {}, + cuda_stream_ref stream = {}); + /** * @brief Erases all elements from the container. After this call, `size()` returns zero. * Invalidates any references, pointers, or iterators referring to contained elements. @@ -298,6 +328,58 @@ class static_set { Predicate pred, cuda_stream_ref stream = {}) noexcept; + /** + * @brief Erases keys in the range `[first, last)`. + * + * @note For each key `k` in `[first, last)`, if contains(k) returns true, removes `k` and it's + * associated value from the container. Else, no effect. + * + * @note This function synchronizes `stream`. + * + * @note Side-effects: + * - `contains(k) == false` + * - `find(k) == end()` + * - `insert({k,v}) == true` + * - `size()` is reduced by the total number of erased keys + * + * @tparam InputIt Device accessible input iterator whose `value_type` is + * convertible to the container's `key_type` + * + * @param first Beginning of the sequence of keys + * @param last End of the sequence of keys + * @param stream Stream used for executing the kernels + * + * @throw std::runtime_error if a unique erased key sentinel value was not + * provided at construction + */ + template + void erase(InputIt first, InputIt last, cuda_stream_ref stream = {}); + + /** + * @brief Asynchronously erases keys in the range `[first, last)`. + * + * @note For each key `k` in `[first, last)`, if contains(k) returns true, removes `k` and it's + * associated value from the container. Else, no effect. + * + * @note Side-effects: + * - `contains(k) == false` + * - `find(k) == end()` + * - `insert({k,v}) == true` + * - `size()` is reduced by the total number of erased keys + * + * @tparam InputIt Device accessible input iterator whose `value_type` is + * convertible to the container's `key_type` + * + * @param first Beginning of the sequence of keys + * @param last End of the sequence of keys + * @param stream Stream used for executing the kernels + * + * @throw std::runtime_error if a unique erased key sentinel value was not + * provided at construction + */ + template + void erase_async(InputIt first, InputIt last, cuda_stream_ref stream = {}); + /** * @brief Indicates whether the keys in the range `[first, last)` are contained in the set. * @@ -484,6 +566,13 @@ class static_set { */ [[nodiscard]] constexpr key_type empty_key_sentinel() const noexcept; + /** + * @brief Gets the sentinel value used to represent an erased key slot. + * + * @return The sentinel value used to represent an erased key slot + */ + [[nodiscard]] constexpr key_type erased_key_sentinel() const noexcept; + /** * @brief Get device ref with operators. * diff --git a/include/cuco/static_set_ref.cuh b/include/cuco/static_set_ref.cuh index 517f6b488..de9f54d62 100644 --- a/include/cuco/static_set_ref.cuh +++ b/include/cuco/static_set_ref.cuh @@ -91,11 +91,25 @@ class static_set_ref * @param probing_scheme Probing scheme * @param storage_ref Non-owning ref of slot storage */ - __host__ __device__ explicit constexpr static_set_ref( - cuco::empty_key empty_key_sentinel, - key_equal const& predicate, - probing_scheme_type const& probing_scheme, - storage_ref_type storage_ref) noexcept; + __host__ __device__ explicit constexpr static_set_ref(cuco::empty_key empty_key_sentinel, + KeyEqual const& predicate, + ProbingScheme const& probing_scheme, + StorageRef storage_ref) noexcept; + + /** + * @brief Constructs static_set_ref. + * + * @param empty_key_sentinel Sentinel indicating empty key + * @param erased_key_sentinel Sentinel indicating erased key + * @param predicate Key equality binary callable + * @param probing_scheme Probing scheme + * @param storage_ref Non-owning ref of slot storage + */ + __host__ __device__ explicit constexpr static_set_ref(cuco::empty_key empty_key_sentinel, + cuco::erased_key erased_key_sentinel, + KeyEqual const& predicate, + ProbingScheme const& probing_scheme, + StorageRef storage_ref) noexcept; /** * @brief Operator-agnostic move constructor. diff --git a/tests/static_map/erase_test.cu b/tests/static_map/erase_test.cu index 26cbd3fd3..5e410c5cc 100644 --- a/tests/static_map/erase_test.cu +++ b/tests/static_map/erase_test.cu @@ -27,16 +27,13 @@ #include -TEMPLATE_TEST_CASE_SIG("erase key", "", ((typename T), T), (int32_t), (int64_t)) -{ - using Key = T; - using Value = T; - - constexpr std::size_t num_keys = 1'000'000; - constexpr std::size_t capacity = 1'100'000; +using size_type = int32_t; - cuco::static_map map{ - capacity, cuco::empty_key{-1}, cuco::empty_value{-1}, cuco::erased_key{-2}}; +template +void test_erase(Map& map, size_type num_keys) +{ + using Key = typename Map::key_type; + using Value = typename Map::mapped_type; thrust::device_vector d_keys(num_keys); thrust::device_vector d_values(num_keys); @@ -52,11 +49,11 @@ TEMPLATE_TEST_CASE_SIG("erase key", "", ((typename T), T), (int32_t), (int64_t)) { map.insert(pairs_begin, pairs_begin + num_keys); - REQUIRE(map.get_size() == num_keys); + REQUIRE(map.size() == num_keys); map.erase(d_keys.begin(), d_keys.end()); - REQUIRE(map.get_size() == 0); + REQUIRE(map.size() == 0); map.contains(d_keys.begin(), d_keys.end(), d_keys_exist.begin()); @@ -64,7 +61,7 @@ TEMPLATE_TEST_CASE_SIG("erase key", "", ((typename T), T), (int32_t), (int64_t)) map.insert(pairs_begin, pairs_begin + num_keys); - REQUIRE(map.get_size() == num_keys); + REQUIRE(map.size() == num_keys); map.contains(d_keys.begin(), d_keys.end(), d_keys_exist.begin()); @@ -80,6 +77,53 @@ TEMPLATE_TEST_CASE_SIG("erase key", "", ((typename T), T), (int32_t), (int64_t)) d_keys_exist.begin() + num_keys / 2, d_keys_exist.end(), thrust::identity{})); map.erase(d_keys.begin() + num_keys / 2, d_keys.end()); - REQUIRE(map.get_size() == 0); + REQUIRE(map.size() == 0); } } + +TEMPLATE_TEST_CASE_SIG( + "static_map erase tests", + "", + ((typename Key, typename Value, cuco::test::probe_sequence Probe, int CGSize), + Key, + Value, + Probe, + CGSize), + (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), + (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), + (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), + (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), + (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), + (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), + (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), + (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), + (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), + (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), + (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2), + (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), + (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), + (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2)) +{ + constexpr size_type num_keys{400}; + + using probe = + std::conditional_t>, + cuco::experimental::double_hashing, + cuco::murmurhash3_32>>; + + auto map = cuco::experimental::static_map, + cuda::thread_scope_device, + thrust::equal_to, + probe, + cuco::cuda_allocator, + cuco::experimental::storage<2>>{ + num_keys, cuco::empty_key{-1}, cuco::empty_value{-1}, cuco::erased_key{-2}}; + + test_erase(map, num_keys); +}