Skip to content

Commit

Permalink
Migrate device ref example
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel committed Dec 18, 2023
1 parent 8c59522 commit 4b39b77
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 25 deletions.
2 changes: 1 addition & 1 deletion examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ ConfigureExample(STATIC_SET_HOST_BULK_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/stati
ConfigureExample(STATIC_SET_DEVICE_REF_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_set/device_ref_example.cu")
ConfigureExample(STATIC_SET_DEVICE_SUBSETS_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_set/device_subsets_example.cu")
ConfigureExample(STATIC_MAP_HOST_BULK_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_map/host_bulk_example.cu")
ConfigureExample(STATIC_MAP_DEVICE_SIDE_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_map/device_view_example.cu")
ConfigureExample(STATIC_MAP_DEVICE_SIDE_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_map/device_ref_example.cu")
ConfigureExample(STATIC_MAP_CUSTOM_TYPE_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_map/custom_type_example.cu")
ConfigureExample(STATIC_MAP_COUNT_BY_KEY_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_map/count_by_key_example.cu")
ConfigureExample(STATIC_MULTIMAP_HOST_BULK_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_multimap/host_bulk_example.cu")
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -29,12 +29,11 @@
#include <limits>

/**
* @file device_view_example.cu
* @file device_ref_example.cu
* @brief Demonstrates usage of the device side APIs for individual operations like insert/find.
*
* Individual operations like a single insert or find can be performed in device code via the
* static_map "device_view" types. Note that concurrent insert and find are not supported, and
* therefore there are separate view types for insert and find to help prevent undefined behavior.
* "static_map_ref" types.
*
* @note This example is for demonstration purposes only. It is not intended to show the most
* performant way to do the example algorithm.
Expand All @@ -44,12 +43,12 @@
/**
* @brief Inserts keys that pass the specified predicated into the map.
*
* @tparam Map Type of the map returned from static_map::get_device_mutable_view
* @tparam Map Type of the map device reference
* @tparam KeyIter Input iterator whose value_type convertible to Map::key_type
* @tparam ValueIter Input iterator whose value_type is convertible to Map::mapped_type
* @tparam Predicate Unary predicate
*
* @param[in] map_view View of the map into which inserts will be performed
* @param[in] map_ref Reference of the map into which inserts will be performed
* @param[in] key_begin The beginning of the range of keys to insert
* @param[in] value_begin The beginning of the range of values associated with each key to insert
* @param[in] num_keys The total number of keys and values
Expand All @@ -58,7 +57,7 @@
* @param[out] num_inserted The total number of keys successfully inserted
*/
template <typename Map, typename KeyIter, typename ValueIter, typename Predicate>
__global__ void filtered_insert(Map map_view,
__global__ void filtered_insert(Map map_ref,
KeyIter key_begin,
ValueIter value_begin,
std::size_t num_keys,
Expand All @@ -71,9 +70,9 @@ __global__ void filtered_insert(Map map_view,
while (tid < num_keys) {
// Only insert keys that pass the predicate
if (pred(key_begin[tid])) {
// device_mutable_view::insert returns `true` if it is the first time the given key was
// Map::insert returns `true` if it is the first time the given key was
// inserted and `false` if the key already existed
if (map_view.insert({key_begin[tid], value_begin[tid]})) {
if (map_ref.insert(cuco::pair{key_begin[tid], value_begin[tid]})) {
++counter; // Count number of successfully inserted keys
}
}
Expand All @@ -87,25 +86,26 @@ __global__ void filtered_insert(Map map_view,
/**
* @brief For keys that have a match in the map, increments their corresponding value by one.
*
* @tparam Map Type of the map returned from static_map::get_device_view
* @tparam Map Type of the map device reference
* @tparam KeyIter Input iterator whose value_type convertible to Map::key_type
*
* @param map_view View of the map into which queries will be performed
* @param map_ref Reference of the map into which queries will be performed
* @param key_begin The beginning of the range of keys to query
* @param num_keys The total number of keys
*/
template <typename Map, typename KeyIter>
__global__ void increment_values(Map map_view, KeyIter key_begin, std::size_t num_keys)
__global__ void increment_values(Map map_ref, KeyIter key_begin, std::size_t num_keys)
{
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
while (tid < num_keys) {
// If the key exists in the map, find returns an iterator to the specified key. Otherwise it
// returns map.end()
auto found = map_view.find(key_begin[tid]);
if (found != map_view.end()) {
auto found = map_ref.find(key_begin[tid]);
if (found != map_ref.end()) {
// If the key exists, atomically increment the associated value
// The value type of the iterator is pair<cuda::atomic<Key>, cuda::atomic<Value>>
found->second.fetch_add(1, cuda::memory_order_relaxed);
auto ref =
cuda::atomic_ref<typename Map::mapped_type, cuda::thread_scope_device>{found->second};
ref.fetch_add(1, cuda::memory_order_relaxed);
}
tid += gridDim.x * blockDim.x;
}
Expand Down Expand Up @@ -135,11 +135,16 @@ int main(void)
std::size_t const capacity = std::ceil(num_keys / load_factor);

// Constructs a map with "capacity" slots using -1 and -1 as the empty key/value sentinels.
cuco::static_map<Key, Value> map{
capacity, cuco::empty_key{empty_key_sentinel}, cuco::empty_value{empty_value_sentinel}};
auto map = cuco::experimental::static_map{
capacity,
cuco::empty_key{empty_key_sentinel},
cuco::empty_value{empty_value_sentinel},
thrust::equal_to<Key>{},
cuco::experimental::linear_probing<1, cuco::default_hash_function<Key>>{}};

// Get a non-owning, mutable view of the map that allows inserts to pass by value into the kernel
auto device_insert_view = map.get_device_mutable_view();
// Get a non-owning, mutable reference of the map that allows inserts to pass by value into the
// kernel
auto insert_ref = map.ref(cuco::experimental::op::insert);

// Predicate will only insert even keys
auto is_even = [] __device__(auto key) { return (key % 2) == 0; };
Expand All @@ -149,7 +154,7 @@ int main(void)

auto constexpr block_size = 256;
auto const grid_size = (num_keys + block_size - 1) / block_size;
filtered_insert<<<grid_size, block_size>>>(device_insert_view,
filtered_insert<<<grid_size, block_size>>>(insert_ref,
insert_keys.begin(),
insert_values.begin(),
num_keys,
Expand All @@ -158,10 +163,11 @@ int main(void)

std::cout << "Number of keys inserted: " << num_inserted[0] << std::endl;

// Get a non-owning view of the map that allows find operations to pass by value into the kernel
auto device_find_view = map.get_device_view();
// Get a non-owning reference of the map that allows find operations to pass by value into the
// kernel
auto find_ref = map.ref(cuco::experimental::op::find);

increment_values<<<grid_size, block_size>>>(device_find_view, insert_keys.begin(), num_keys);
increment_values<<<grid_size, block_size>>>(find_ref, insert_keys.begin(), num_keys);

// Retrieve contents of all the non-empty slots in the map
thrust::device_vector<Key> contained_keys(num_inserted[0]);
Expand Down

0 comments on commit 4b39b77

Please sign in to comment.