Skip to content

Commit

Permalink
Add erase to new map and set (#365)
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel authored Oct 11, 2023
1 parent ed32bab commit 0f86edb
Show file tree
Hide file tree
Showing 16 changed files with 1,020 additions and 132 deletions.
32 changes: 32 additions & 0 deletions include/cuco/detail/common_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,38 @@ __global__ void insert_if_n(
}
}

/**
* @brief Asynchronously erases keys in the range `[first, first + n)`.
*
* @tparam CGSize Number of threads in each CG
* @tparam BlockSize Number of threads in each block
* @tparam InputIt Device accessible input iterator whose `value_type` is
* convertible to the `value_type` of the data structure
* @tparam Ref Type of non-owning device ref allowing access to storage
*
* @param first Beginning of the sequence of input elements
* @param n Number of input elements
* @param ref Non-owning container device ref used to access the slot storage
*/
template <int32_t CGSize, int32_t BlockSize, typename InputIt, typename Ref>
__global__ void erase(InputIt first, cuco::detail::index_type n, Ref ref)
{
auto const loop_stride = cuco::detail::grid_stride() / CGSize;
auto idx = cuco::detail::global_thread_id() / CGSize;

while (idx < n) {
typename std::iterator_traits<InputIt>::value_type const& erase_element{*(first + idx)};
if constexpr (CGSize == 1) {
ref.erase(erase_element);
} else {
auto const tile =
cooperative_groups::tiled_partition<CGSize>(cooperative_groups::this_thread_block());
ref.erase(tile, erase_element);
}
idx += loop_stride;
}
}

/**
* @brief Indicates whether the keys in the range `[first, first + n)` are contained in the data
* structure if `pred` of the corresponding stencil returns true.
Expand Down
16 changes: 10 additions & 6 deletions include/cuco/detail/equal_wrapper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace detail {
/**
* @brief Enum of equality comparison results.
*/
enum class equal_result : int32_t { UNEQUAL = 0, EMPTY = 1, EQUAL = 2 };
enum class equal_result : int32_t { UNEQUAL = 0, EMPTY = 1, EQUAL = 2, ERASED = 3 };

/**
* @brief Key equality wrapper.
Expand All @@ -39,17 +39,21 @@ enum class equal_result : int32_t { UNEQUAL = 0, EMPTY = 1, EQUAL = 2 };
template <typename T, typename Equal>
struct equal_wrapper {
// TODO: Clean up the sentinel handling since it's duplicated in ref and equal wrapper
T empty_sentinel_; ///< Sentinel value
Equal equal_; ///< Custom equality callable
T empty_sentinel_; ///< Empty sentinel value
T erased_sentinel_; ///< Erased sentinel value
Equal equal_; ///< Custom equality callable

/**
* @brief Equality wrapper ctor.
*
* @param sentinel Sentinel value
* @param empty_sentinel Empty sentinel value
* @param erased_sentinel Erased sentinel value
* @param equal Equality binary callable
*/
__host__ __device__ constexpr equal_wrapper(T sentinel, Equal const& equal) noexcept
: empty_sentinel_{sentinel}, equal_{equal}
__host__ __device__ constexpr equal_wrapper(T empty_sentinel,
T erased_sentinel,
Equal const& equal) noexcept
: empty_sentinel_{empty_sentinel}, erased_sentinel_{erased_sentinel}, equal_{equal}
{
}

Expand Down
97 changes: 97 additions & 0 deletions include/cuco/detail/open_addressing/open_addressing_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ class open_addressing_impl {
cuda_stream_ref stream) noexcept
: empty_key_sentinel_{empty_key_sentinel},
empty_slot_sentinel_{empty_slot_sentinel},
erased_key_sentinel_{empty_key_sentinel},
predicate_{pred},
probing_scheme_{probing_scheme},
storage_{make_window_extent<open_addressing_impl>(capacity), alloc}
Expand Down Expand Up @@ -190,6 +191,50 @@ class open_addressing_impl {
this->clear_async(stream);
}

/**
* @brief Constructs a statically-sized open addressing data structure with the specified initial
* capacity, sentinel values and CUDA stream.
*
* @note The actual capacity depends on the given `capacity`, the probing scheme, CG size, and the
* window size and it is computed via the `make_window_extent` factory. Insert operations will not
* automatically grow the container. Attempting to insert more unique keys than the capacity of
* the container results in undefined behavior.
* @note Any `*_sentinel`s are reserved and behavior is undefined when attempting to insert
* this sentinel value.
* @note If a non-default CUDA stream is provided, the caller is responsible for synchronizing the
* stream before the object is first used.
*
* @param capacity The requested lower-bound size
* @param empty_key_sentinel The reserved key value for empty slots
* @param empty_slot_sentinel The reserved slot value for empty slots
* @param erased_key_sentinel The reserved key value for erased slots
* @param pred Key equality binary predicate
* @param probing_scheme Probing scheme
* @param alloc Allocator used for allocating device storage
* @param stream CUDA stream used to initialize the data structure
*/
constexpr open_addressing_impl(Extent capacity,
Key empty_key_sentinel,
Value empty_slot_sentinel,
Key erased_key_sentinel,
KeyEqual const& pred,
ProbingScheme const& probing_scheme,
Allocator const& alloc,
cuda_stream_ref stream)
: empty_key_sentinel_{empty_key_sentinel},
empty_slot_sentinel_{empty_slot_sentinel},
erased_key_sentinel_{erased_key_sentinel},
predicate_{pred},
probing_scheme_{probing_scheme},
storage_{make_window_extent<open_addressing_impl>(capacity), alloc}
{
CUCO_EXPECTS(empty_key_sentinel_ != erased_key_sentinel_,
"The empty key sentinel and erased key sentinel cannot be the same value.",
std::logic_error);

this->clear_async(stream);
}

/**
* @brief Erases all elements from the container. After this call, `size()` returns zero.
* Invalidates any references, pointers, or iterators referring to contained elements.
Expand Down Expand Up @@ -365,6 +410,46 @@ class open_addressing_impl {
first, num_keys, stencil, pred, container_ref);
}

/**
* @brief Asynchronously erases keys in the range `[first, last)`.
*
* @note For each key `k` in `[first, last)`, if contains(k) returns true, removes `k` and it's
* associated value from the container. Else, no effect.
*
* @note Side-effects:
* - `contains(k) == false`
* - `find(k) == end()`
* - `insert({k,v}) == true`
* - `size()` is reduced by the total number of erased keys
*
* @tparam InputIt Device accessible input iterator whose `value_type` is
* convertible to the container's `key_type`
*
* @param first Beginning of the sequence of keys
* @param last End of the sequence of keys
* @param container_ref Non-owning device container ref used to access the slot storage
* @param stream Stream used for executing the kernels
*
* @throw std::runtime_error if a unique erased key sentinel value was not
* provided at construction
*/
template <typename InputIt, typename Ref>
void erase_async(InputIt first, InputIt last, Ref container_ref, cuda_stream_ref stream = {})
{
CUCO_EXPECTS(empty_key_sentinel_ != erased_key_sentinel_,
"The empty key sentinel and erased key sentinel cannot be the same value.",
std::logic_error);

auto const num_keys = cuco::detail::distance(first, last);
if (num_keys == 0) { return; }

auto const grid_size = cuco::detail::grid_size(num_keys, cg_size);

detail::erase<cg_size, cuco::detail::default_block_size()>
<<<grid_size, cuco::detail::default_block_size(), 0, stream>>>(
first, num_keys, container_ref);
}

/**
* @brief Asynchronously indicates whether the keys in the range `[first, last)` are contained in
* the container.
Expand Down Expand Up @@ -556,6 +641,16 @@ class open_addressing_impl {
return empty_key_sentinel_;
}

/**
* @brief Gets the sentinel value used to represent an erased key slot.
*
* @return The sentinel value used to represent an erased key slot
*/
[[nodiscard]] constexpr key_type erased_key_sentinel() const noexcept
{
return erased_key_sentinel_;
}

/**
* @brief Gets the key comparator.
*
Expand Down Expand Up @@ -588,8 +683,10 @@ class open_addressing_impl {
[[nodiscard]] constexpr storage_ref_type storage_ref() const noexcept { return storage_.ref(); }

protected:
// TODO: cleanup by using equal wrapper as a data member
key_type empty_key_sentinel_; ///< Key value that represents an empty slot
value_type empty_slot_sentinel_; ///< Slot value that represents an empty slot
key_type erased_key_sentinel_; ///< Key value that represents an erased slot
key_equal predicate_; ///< Key equality binary predicate
probing_scheme_type probing_scheme_; ///< Probing scheme
storage_type storage_; ///< Slot window storage
Expand Down
Loading

0 comments on commit 0f86edb

Please sign in to comment.