From 8a4b7814bf6b628e7722b4242791cc49e43f9cc8 Mon Sep 17 00:00:00 2001 From: rhdong Date: Mon, 30 Oct 2023 13:30:13 +0800 Subject: [PATCH] [Feat] add the API `assign_values` - Add `check_evict_strategy` for `assign` --- CMakeLists.txt | 5 + include/merlin/core_kernels.cuh | 1 + include/merlin/core_kernels/update_values.cuh | 865 ++++++++++++++++++ include/merlin_hashtable.cuh | 162 +++- tests/assign_values_test.cc.cu | 747 +++++++++++++++ 5 files changed, 1772 insertions(+), 8 deletions(-) create mode 100644 include/merlin/core_kernels/update_values.cuh create mode 100644 tests/assign_values_test.cc.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index 1f9fbbb7e..163262b82 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -145,3 +145,8 @@ add_executable(accum_or_assign_test tests/accum_or_assign_test.cc) target_compile_features(accum_or_assign_test PUBLIC cxx_std_14) set_target_properties(accum_or_assign_test PROPERTIES CUDA_ARCHITECTURES OFF) TARGET_LINK_LIBRARIES(accum_or_assign_test gtest_main) + +add_executable(assign_values_test tests/assign_values_test.cc.cu) +target_compile_features(assign_values_test PUBLIC cxx_std_14) +set_target_properties(assign_values_test PROPERTIES CUDA_ARCHITECTURES OFF) +TARGET_LINK_LIBRARIES(assign_values_test gtest_main) diff --git a/include/merlin/core_kernels.cuh b/include/merlin/core_kernels.cuh index 6b9df9475..6585efcf6 100644 --- a/include/merlin/core_kernels.cuh +++ b/include/merlin/core_kernels.cuh @@ -26,6 +26,7 @@ #include "core_kernels/lookup_ptr.cuh" #include "core_kernels/update.cuh" #include "core_kernels/update_score.cuh" +#include "core_kernels/update_values.cuh" #include "core_kernels/upsert.cuh" #include "core_kernels/upsert_and_evict.cuh" diff --git a/include/merlin/core_kernels/update_values.cuh b/include/merlin/core_kernels/update_values.cuh new file mode 100644 index 000000000..c2e3d17fb --- /dev/null +++ b/include/merlin/core_kernels/update_values.cuh @@ -0,0 +1,865 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http:///www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "kernel_utils.cuh" + +namespace nv { +namespace merlin { + +// Use 1 thread to deal with a KV-pair, including copying value. 
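The TLP kernel below assigns one thread per key and filters candidate slots by comparing packed one-byte key digests with the `__vcmpeq4` SIMD intrinsic before touching the full keys. The following standalone sketch is illustrative only and not part of the patch (the kernel name `digest_match_demo` and the literal digest values are made up); it shows how the byte-wise match mask produced by `__vcmpeq4` is reduced with the same `& 0x01010101` / `__ffs` pattern the kernel uses to recover matching slot indices. In the real kernel, `target_digests` is the key's digest broadcast to all four bytes by `digests_from_hashed`.

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void digest_match_demo(uint32_t probe_digests,
                                  uint32_t target_digests) {
  // Equal bytes produce 0xFF in the comparison result.
  uint32_t cmp_result = __vcmpeq4(probe_digests, target_digests);
  // Keep a single bit per matching byte.
  cmp_result &= 0x01010101;
  while (cmp_result != 0) {
    // CUDA is little endian: the lowest set bit maps to the lowest byte,
    // i.e. the first candidate slot in this group of four digests.
    uint32_t index = (__ffs(cmp_result) - 1) >> 3;
    printf("digest match at slot offset %u\n", index);
    cmp_result &= (cmp_result - 1);  // clear the bit and look for more matches
  }
}

int main() {
  // Bytes (low to high): 0xAA, 0x22, 0xAA, 0x11 -> matches at offsets 0 and 2.
  digest_match_demo<<<1, 1>>>(0x11AA22AAu, 0xAAAAAAAAu);
  cudaDeviceSynchronize();
  return 0;
}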
+template +__global__ void tlp_update_values_kernel_with_io( + Bucket* __restrict__ buckets, const uint64_t buckets_num, + uint32_t bucket_capacity, const uint32_t dim, const K* __restrict__ keys, + const VecV* __restrict__ values, uint64_t n) { + using BUCKET = Bucket; + using CopyValue = CopyValueMultipleGroup; + + uint32_t tx = threadIdx.x; + uint32_t kv_idx = blockIdx.x * blockDim.x + tx; + K key{static_cast(EMPTY_KEY)}; + OccupyResult occupy_result{OccupyResult::INITIAL}; + VecD_Comp target_digests{0}; + VecV* bucket_values_ptr{nullptr}; + K* bucket_keys_ptr{nullptr}; + uint32_t key_pos = {0}; + if (kv_idx < n) { + key = keys[kv_idx]; + + if (!IS_RESERVED_KEY(key)) { + const K hashed_key = Murmur3HashDevice(key); + target_digests = digests_from_hashed(hashed_key); + uint64_t global_idx = + static_cast(hashed_key % (buckets_num * bucket_capacity)); + key_pos = get_start_position(global_idx, bucket_capacity); + uint64_t bkt_idx = global_idx / bucket_capacity; + BUCKET* bucket = buckets + bkt_idx; + bucket_keys_ptr = reinterpret_cast(bucket->keys(0)); + bucket_values_ptr = reinterpret_cast(bucket->vectors); + } else { + return; + } + } else { + return; + } + + // Load `STRIDE` digests every time. + constexpr uint32_t STRIDE = sizeof(VecD_Load) / sizeof(D); + // One more loop to handle empty keys. + for (int offset = 0; offset < bucket_capacity + STRIDE; offset += STRIDE) { + if (occupy_result != OccupyResult::INITIAL) break; + + uint32_t pos_cur = align_to(key_pos); + pos_cur = (pos_cur + offset) & (bucket_capacity - 1); + + D* digests_ptr = BUCKET::digests(bucket_keys_ptr, bucket_capacity, pos_cur); + VecD_Load digests_vec = *(reinterpret_cast(digests_ptr)); + VecD_Comp digests_arr[4] = {digests_vec.x, digests_vec.y, digests_vec.z, + digests_vec.w}; + + for (int i = 0; i < 4; i++) { + VecD_Comp probe_digests = digests_arr[i]; + uint32_t possible_pos = 0; + bool result = false; + // Perform a vectorized comparison by byte, + // and if they are equal, set the corresponding byte in the result to + // 0xff. + int cmp_result = __vcmpeq4(probe_digests, target_digests); + cmp_result &= 0x01010101; + do { + if (cmp_result == 0) break; + // CUDA uses little endian, + // and the lowest byte in register stores in the lowest address. + uint32_t index = (__ffs(cmp_result) - 1) >> 3; + cmp_result &= (cmp_result - 1); + possible_pos = pos_cur + i * 4 + index; + auto current_key = BUCKET::keys(bucket_keys_ptr, possible_pos); + K expected_key = key; + // Modifications to the bucket will not before this instruction. + result = current_key->compare_exchange_strong( + expected_key, static_cast(LOCKED_KEY), + cuda::std::memory_order_acquire, cuda::std::memory_order_relaxed); + } while (!result); + if (result) { + occupy_result = OccupyResult::DUPLICATE; + key_pos = possible_pos; + VecV* bucket_value_ptr = bucket_values_ptr + key_pos * dim; + const VecV* param_value_ptr = values + kv_idx * dim; + CopyValue::ldg_stg(0, bucket_value_ptr, param_value_ptr, dim); + auto key_address = BUCKET::keys(bucket_keys_ptr, key_pos); + // memory_order_release: + // Modifications to the bucket will not after this instruction. 
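+        // (The slot was acquired above by CAS-ing its key to LOCKED_KEY;
+        // restoring the original key with memory_order_release publishes
+        // the value copy and releases the per-slot lock.)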
+ key_address->store(key, cuda::std::memory_order_release); + return; + } + VecD_Comp empty_digests_ = empty_digests(); + cmp_result = __vcmpeq4(probe_digests, empty_digests_); + cmp_result &= 0x01010101; + do { + if (cmp_result == 0) break; + uint32_t index = (__ffs(cmp_result) - 1) >> 3; + cmp_result &= (cmp_result - 1); + possible_pos = pos_cur + i * 4 + index; + if (offset == 0 && possible_pos < key_pos) continue; + auto current_key = BUCKET::keys(bucket_keys_ptr, possible_pos); + auto probe_key = current_key->load(cuda::std::memory_order_relaxed); + if (probe_key == static_cast(EMPTY_KEY)) { + return; + } + } while (true); + } + } +} +template +__global__ void pipeline_update_values_kernel_with_io( + Bucket* __restrict__ buckets, const uint64_t buckets_num, + const uint32_t dim, const K* __restrict__ keys, + const VecV* __restrict__ values, uint64_t n) { + constexpr uint32_t BUCKET_SIZE = 128; + constexpr int GROUP_NUM = BLOCK_SIZE / GROUP_SIZE; + constexpr uint32_t Comp_LEN = sizeof(VecD_Comp) / sizeof(D); + // Here, GROUP_SIZE * Load_LEN = BUCKET_SIZE. + using VecD_Load = byte8; + constexpr uint32_t Load_LEN = sizeof(VecD_Load) / sizeof(D); + constexpr int RESERVE = 8; + + using BUCKET = Bucket; + using CopyValue = CopyValueMultipleGroup; + + __shared__ VecD_Comp sm_target_digests[BLOCK_SIZE]; + __shared__ K sm_target_keys[BLOCK_SIZE]; + __shared__ K* sm_keys_ptr[BLOCK_SIZE]; + __shared__ VecV* sm_values_ptr[BLOCK_SIZE]; + // Reuse + int* sm_counts = reinterpret_cast(sm_target_digests); + int* sm_position = sm_counts; + // Double buffer + __shared__ D sm_digests[GROUP_NUM][2 * BUCKET_SIZE]; + __shared__ K sm_possible_keys[GROUP_NUM][2 * RESERVE]; + __shared__ int sm_possible_pos[GROUP_NUM][2 * RESERVE]; + __shared__ int sm_ranks[GROUP_NUM][2]; + // __shared__ VecV sm_values_buffer[GROUP_NUM][2 * dim]; + + extern __shared__ __align__(alignof(byte16)) byte sm_values_buffer[]; + + bool CAS_res[2]{false}; + + // Initialization + auto g = cg::tiled_partition(cg::this_thread_block()); + int groupID = threadIdx.x / GROUP_SIZE; + int rank = g.thread_rank(); + uint64_t key_idx_base = (blockIdx.x * blockDim.x) + groupID * GROUP_SIZE; + if (key_idx_base >= n) return; + int loop_num = + (n - key_idx_base) < GROUP_SIZE ? (n - key_idx_base) : GROUP_SIZE; + if (rank < loop_num) { + int idx_block = groupID * GROUP_SIZE + rank; + K key = keys[key_idx_base + rank]; + sm_target_keys[idx_block] = key; + const K hashed_key = Murmur3HashDevice(key); + sm_target_digests[idx_block] = digests_from_hashed(hashed_key); + uint64_t global_idx = hashed_key % (buckets_num * BUCKET_SIZE); + uint64_t bkt_idx = global_idx / BUCKET_SIZE; + Bucket* bucket = buckets + bkt_idx; + __pipeline_memcpy_async(sm_keys_ptr + idx_block, bucket->keys_addr(), + sizeof(K*)); + __pipeline_commit(); + __pipeline_memcpy_async(sm_values_ptr + idx_block, &(bucket->vectors), + sizeof(VecV*)); + } + __pipeline_wait_prior(0); + + // Pipeline loading + K* keys_ptr = sm_keys_ptr[groupID * GROUP_SIZE]; + D* digests_ptr = BUCKET::digests(keys_ptr, BUCKET_SIZE, rank * Load_LEN); + __pipeline_memcpy_async(sm_digests[groupID] + rank * Load_LEN, digests_ptr, + sizeof(VecD_Load)); + __pipeline_commit(); + // Padding, meet the param of the first `__pipeline_wait_prior` + // in the first loop. 
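+  // (The two empty commits below keep enough commit groups in flight that
+  // the `__pipeline_wait_prior(3)` issued in Step 2 of the first iteration
+  // waits exactly on the digest prefetch issued above.)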
+ __pipeline_commit(); + __pipeline_commit(); + + for (int i = 0; i < loop_num; i++) { + int key_idx_block = groupID * GROUP_SIZE + i; + + /* Step1: prefetch all digests in one bucket */ + if ((i + 1) < loop_num) { + K* keys_ptr = sm_keys_ptr[key_idx_block + 1]; + D* digests_ptr = BUCKET::digests(keys_ptr, BUCKET_SIZE, rank * Load_LEN); + __pipeline_memcpy_async( + sm_digests[groupID] + diff_buf(i) * BUCKET_SIZE + rank * Load_LEN, + digests_ptr, sizeof(VecD_Load)); + } + __pipeline_commit(); + + /* Step2: check digests and load possible keys */ + VecD_Comp target_digests = sm_target_digests[key_idx_block]; + sm_counts[key_idx_block] = 0; + __pipeline_wait_prior(3); + VecD_Comp probing_digests = *reinterpret_cast( + &sm_digests[groupID][same_buf(i) * BUCKET_SIZE + rank * Comp_LEN]); + uint32_t find_result_ = __vcmpeq4(probing_digests, target_digests); + uint32_t find_result = 0; + if ((find_result_ & 0x01) != 0) find_result |= 0x01; + if ((find_result_ & 0x0100) != 0) find_result |= 0x02; + if ((find_result_ & 0x010000) != 0) find_result |= 0x04; + if ((find_result_ & 0x01000000) != 0) find_result |= 0x08; + probing_digests = *reinterpret_cast( + &sm_digests[groupID][same_buf(i) * BUCKET_SIZE + + (GROUP_SIZE + rank) * Comp_LEN]); + find_result_ = __vcmpeq4(probing_digests, target_digests); + if ((find_result_ & 0x01) != 0) find_result |= 0x10; + if ((find_result_ & 0x0100) != 0) find_result |= 0x20; + if ((find_result_ & 0x010000) != 0) find_result |= 0x40; + if ((find_result_ & 0x01000000) != 0) find_result |= 0x80; + int find_number = __popc(find_result); + int group_base = 0; + if (find_number > 0) { + group_base = atomicAdd(sm_counts + key_idx_block, find_number); + } + bool gt_reserve = (group_base + find_number) > RESERVE; + int gt_vote = g.ballot(gt_reserve); + K* key_ptr = sm_keys_ptr[key_idx_block]; + if (gt_vote == 0) { + do { + int digest_idx = __ffs(find_result) - 1; + if (digest_idx >= 0) { + find_result &= (find_result - 1); + int key_pos = digest_idx < 4 + ? (rank * 4 + digest_idx) + : ((GROUP_SIZE + rank - 1) * 4 + digest_idx); + sm_possible_pos[groupID][same_buf(i) * RESERVE + group_base] = + key_pos; + __pipeline_memcpy_async( + sm_possible_keys[groupID] + same_buf(i) * RESERVE + group_base, + key_ptr + key_pos, sizeof(K)); + group_base += 1; + } else { + break; + } + } while (true); + } else { + K target_key = sm_target_keys[key_idx_block]; + sm_counts[key_idx_block] = 0; + int found_vote = 0; + bool found = false; + do { + int digest_idx = __ffs(find_result) - 1; + if (digest_idx >= 0) { + find_result &= (find_result - 1); + int key_pos = digest_idx < 4 + ? 
(rank * 4 + digest_idx) + : ((GROUP_SIZE + rank - 1) * 4 + digest_idx); + K possible_key = key_ptr[key_pos]; + if (possible_key == target_key) { + found = true; + sm_counts[key_idx_block] = 1; + sm_possible_pos[groupID][same_buf(i) * RESERVE] = key_pos; + sm_possible_keys[groupID][same_buf(i) * RESERVE] = possible_key; + } + } + found_vote = g.ballot(found); + if (found_vote) { + break; + } + found_vote = digest_idx >= 0; + } while (g.any(found_vote)); + } + __pipeline_commit(); + + /* Step3: check possible keys, and prefecth the value */ + if (i > 0) { + key_idx_block -= 1; + K target_key = sm_target_keys[key_idx_block]; + K* keys_ptr = sm_keys_ptr[key_idx_block]; + int possible_num = sm_counts[key_idx_block]; + sm_position[key_idx_block] = -1; + __pipeline_wait_prior(3); + int key_pos; + bool found_flag = false; + if (rank < possible_num) { + K possible_key = + sm_possible_keys[groupID][diff_buf(i) * RESERVE + rank]; + key_pos = sm_possible_pos[groupID][diff_buf(i) * RESERVE + rank]; + if (possible_key == target_key) { + found_flag = true; + auto key_ptr = BUCKET::keys(keys_ptr, key_pos); + sm_ranks[groupID][diff_buf(i)] = rank; + if (diff_buf(i) == 0) { + CAS_res[0] = key_ptr->compare_exchange_strong( + possible_key, static_cast(LOCKED_KEY), + cuda::std::memory_order_acquire, + cuda::std::memory_order_relaxed); + } else { + CAS_res[1] = key_ptr->compare_exchange_strong( + possible_key, static_cast(LOCKED_KEY), + cuda::std::memory_order_acquire, + cuda::std::memory_order_relaxed); + } + } + } + int found_vote = g.ballot(found_flag); + if (found_vote) { + int src_lane = __ffs(found_vote) - 1; + int target_pos = g.shfl(key_pos, src_lane); + sm_position[key_idx_block] = target_pos; + int key_idx_grid = blockIdx.x * blockDim.x + key_idx_block; + const VecV* v_src = values + key_idx_grid * dim; + auto tmp = reinterpret_cast(sm_values_buffer); + VecV* v_dst = tmp + (groupID * 2 + diff_buf(i)) * dim; + CopyValue::ldg_sts(rank, v_dst, v_src, dim); + } + } + __pipeline_commit(); + + /* Step4: write back value */ + if (i > 1) { + key_idx_block -= 1; + VecV* value_ptr = sm_values_ptr[key_idx_block]; + int target_pos = sm_position[key_idx_block]; + K target_key = sm_target_keys[key_idx_block]; + K* keys_ptr = sm_keys_ptr[key_idx_block]; + int src_lane = sm_ranks[groupID][same_buf(i)]; + __pipeline_wait_prior(3); + int succ = 0; + if (rank == src_lane) { + bool CAS_res_cur = same_buf(i) == 0 ? CAS_res[0] : CAS_res[1]; + succ = CAS_res_cur ? 
1 : 0; + } + succ = g.shfl(succ, src_lane); + if (target_pos >= 0 && succ == 1) { + auto tmp = reinterpret_cast(sm_values_buffer); + VecV* v_src = tmp + (groupID * 2 + same_buf(i)) * dim; + VecV* v_dst = value_ptr + target_pos * dim; + CopyValue::lds_stg(rank, v_dst, v_src, dim); + if (rank == 0) { + auto key_address = BUCKET::keys(keys_ptr, target_pos); + key_address->store(target_key, cuda::std::memory_order_release); + } + } + } + } // End loop + + /* Pipeline emptying: step3, i = loop_num */ + { + int key_idx_block = groupID * GROUP_SIZE + (loop_num - 1); + K target_key = sm_target_keys[key_idx_block]; + K* keys_ptr = sm_keys_ptr[key_idx_block]; + int possible_num = sm_counts[key_idx_block]; + sm_position[key_idx_block] = -1; + __pipeline_wait_prior(1); + int key_pos; + bool found_flag = false; + if (rank < possible_num) { + K possible_key = + sm_possible_keys[groupID][diff_buf(loop_num) * RESERVE + rank]; + key_pos = sm_possible_pos[groupID][diff_buf(loop_num) * RESERVE + rank]; + if (possible_key == target_key) { + found_flag = true; + auto key_ptr = BUCKET::keys(keys_ptr, key_pos); + sm_ranks[groupID][diff_buf(loop_num)] = rank; + if (diff_buf(loop_num) == 0) { + CAS_res[0] = key_ptr->compare_exchange_strong( + possible_key, static_cast(LOCKED_KEY), + cuda::std::memory_order_acquire, cuda::std::memory_order_relaxed); + } else { + CAS_res[1] = key_ptr->compare_exchange_strong( + possible_key, static_cast(LOCKED_KEY), + cuda::std::memory_order_acquire, cuda::std::memory_order_relaxed); + } + } + } + int found_vote = g.ballot(found_flag); + if (found_vote) { + int src_lane = __ffs(found_vote) - 1; + int target_pos = g.shfl(key_pos, src_lane); + sm_position[key_idx_block] = target_pos; + int key_idx_grid = blockIdx.x * blockDim.x + key_idx_block; + const VecV* v_src = values + key_idx_grid * dim; + auto tmp = reinterpret_cast(sm_values_buffer); + VecV* v_dst = tmp + (groupID * 2 + diff_buf(loop_num)) * dim; + CopyValue::ldg_sts(rank, v_dst, v_src, dim); + } + } + __pipeline_commit(); + + /* Pipeline emptying: step4, i = loop_num */ + if (loop_num > 1) { + int key_idx_block = groupID * GROUP_SIZE + loop_num - 2; + VecV* value_ptr = sm_values_ptr[key_idx_block]; + int target_pos = sm_position[key_idx_block]; + K target_key = sm_target_keys[key_idx_block]; + K* keys_ptr = sm_keys_ptr[key_idx_block]; + int src_lane = sm_ranks[groupID][same_buf(loop_num)]; + __pipeline_wait_prior(1); + int succ = 0; + if (rank == src_lane) { + bool CAS_res_cur = same_buf(loop_num) == 0 ? CAS_res[0] : CAS_res[1]; + succ = CAS_res_cur ? 1 : 0; + } + succ = g.shfl(succ, src_lane); + if (target_pos >= 0 && succ == 1) { + auto tmp = reinterpret_cast(sm_values_buffer); + VecV* v_src = tmp + (groupID * 2 + same_buf(loop_num)) * dim; + VecV* v_dst = value_ptr + target_pos * dim; + CopyValue::lds_stg(rank, v_dst, v_src, dim); + + auto key_ptr = BUCKET::keys(keys_ptr, target_pos); + if (rank == 0) { + auto key_address = BUCKET::keys(keys_ptr, target_pos); + key_address->store(target_key, cuda::std::memory_order_release); + } + } + } + + /* Pipeline emptying: step4, i = loop_num + 1 */ + { + int key_idx_block = groupID * GROUP_SIZE + loop_num - 1; + VecV* value_ptr = sm_values_ptr[key_idx_block]; + int target_pos = sm_position[key_idx_block]; + K target_key = sm_target_keys[key_idx_block]; + K* keys_ptr = sm_keys_ptr[key_idx_block]; + int src_lane = sm_ranks[groupID][same_buf(loop_num + 1)]; + __pipeline_wait_prior(0); + int succ = 0; + if (rank == src_lane) { + bool CAS_res_cur = same_buf(loop_num + 1) == 0 ? 
CAS_res[0] : CAS_res[1]; + succ = CAS_res_cur ? 1 : 0; + } + succ = g.shfl(succ, src_lane); + if (target_pos >= 0 && succ == 1) { + auto tmp = reinterpret_cast(sm_values_buffer); + VecV* v_src = tmp + (groupID * 2 + same_buf(loop_num + 1)) * dim; + VecV* v_dst = value_ptr + target_pos * dim; + CopyValue::lds_stg(rank, v_dst, v_src, dim); + if (rank == 0) { + auto key_address = BUCKET::keys(keys_ptr, target_pos); + key_address->store(target_key, cuda::std::memory_order_release); + } + } + } +} // End function + +template +struct Params_UpdateValues { + Params_UpdateValues(float load_factor_, + Bucket* __restrict__ buckets_, + size_t buckets_num_, uint32_t bucket_capacity_, + uint32_t dim_, const K* __restrict__ keys_, + const V* __restrict__ values_, size_t n_) + : load_factor(load_factor_), + buckets(buckets_), + buckets_num(buckets_num_), + bucket_capacity(bucket_capacity_), + dim(dim_), + keys(keys_), + values(values_), + n(n_) {} + float load_factor; + Bucket* __restrict__ buckets; + size_t buckets_num; + uint32_t bucket_capacity; + uint32_t dim; + const K* __restrict__ keys; + const V* __restrict__ values; + uint64_t n; +}; + +template +struct Launch_TLP_UpdateValues { + using Params = Params_UpdateValues; + inline static void launch_kernel(Params& params, cudaStream_t& stream) { + constexpr int BLOCK_SIZE = 128; + params.dim = params.dim * sizeof(V) / sizeof(VecV); + tlp_update_values_kernel_with_io + <<<(params.n + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>( + params.buckets, params.buckets_num, params.bucket_capacity, + params.dim, params.keys, + reinterpret_cast(params.values), params.n); + } +}; + +template +struct Launch_Pipeline_UpdateValues { + using Params = Params_UpdateValues; + inline static void launch_kernel(Params& params, cudaStream_t& stream) { + constexpr int BLOCK_SIZE = 128; + constexpr uint32_t GROUP_SIZE = 16; + constexpr uint32_t GROUP_NUM = BLOCK_SIZE / GROUP_SIZE; + + params.dim = params.dim * sizeof(V) / sizeof(VecV); + uint32_t shared_mem = GROUP_NUM * 2 * params.dim * sizeof(VecV); + shared_mem = + (shared_mem + sizeof(byte16) - 1) / sizeof(byte16) * sizeof(byte16); + pipeline_update_values_kernel_with_io + <<<(params.n + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, shared_mem, + stream>>>(params.buckets, params.buckets_num, params.dim, + params.keys, reinterpret_cast(params.values), + params.n); + } +}; + +template +struct ValueConfig_UpdateValues; + +/// TODO: support more arch. +template <> +struct ValueConfig_UpdateValues { + // Value size greater than it will bring poor performance for TLP. + static constexpr uint32_t size_tlp = 8 * sizeof(byte4); + // Value size greater than it will reduce the occupancy for Pipeline. + // When the value is very high, the kernel will fail to launch. + static constexpr uint32_t size_pipeline = 128 * sizeof(byte4); +}; + +template <> +struct ValueConfig_UpdateValues { + // Value size greater than it will bring poor performance for TLP. + static constexpr uint32_t size_tlp = 8 * sizeof(byte4); + // Value size greater than it will reduce the occupancy for Pipeline. + // When the value is very high, the kernel will fail to launch. 
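+  // Assuming byte4 is 4 bytes, this caps pipeline-kernel values at 256 bytes
+  // per key on Sm70 (vs. 512 bytes on Sm80 above); the limit comes from the
+  // per-group value staging buffer allocated in dynamic shared memory.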
+ static constexpr uint32_t size_pipeline = 64 * sizeof(byte4); +}; + +template +struct KernelSelector_UpdateValues { + using ValueConfig = ValueConfig_UpdateValues; + using Params = Params_UpdateValues; + + static bool callable(bool unique_key, uint32_t bucket_size, uint32_t dim) { + constexpr uint32_t MinBucketCap = sizeof(VecD_Load) / sizeof(D); + if (!unique_key || bucket_size < MinBucketCap) return false; + uint32_t value_size = dim * sizeof(V); + if (value_size <= ValueConfig::size_tlp) return true; + if (bucket_size == 128 && value_size <= ValueConfig::size_pipeline) { + return true; + } + return false; + } + + static void select_kernel(Params& params, cudaStream_t& stream) { + const uint32_t total_value_size = + static_cast(params.dim * sizeof(V)); + + auto launch_TLP = [&]() { + if (total_value_size % sizeof(byte16) == 0) { + using VecV = byte16; + Launch_TLP_UpdateValues::launch_kernel(params, stream); + } else if (total_value_size % sizeof(byte8) == 0) { + using VecV = byte8; + Launch_TLP_UpdateValues::launch_kernel(params, stream); + } else if (total_value_size % sizeof(byte4) == 0) { + using VecV = byte4; + Launch_TLP_UpdateValues::launch_kernel(params, stream); + } else if (total_value_size % sizeof(byte2) == 0) { + using VecV = byte2; + Launch_TLP_UpdateValues::launch_kernel(params, stream); + } else { + using VecV = byte; + Launch_TLP_UpdateValues::launch_kernel(params, stream); + } + }; + + auto launch_Pipeline = [&]() { + if (total_value_size % sizeof(byte16) == 0) { + using VecV = byte16; + Launch_Pipeline_UpdateValues::launch_kernel(params, + stream); + } else if (total_value_size % sizeof(byte8) == 0) { + using VecV = byte8; + Launch_Pipeline_UpdateValues::launch_kernel(params, + stream); + } else if (total_value_size % sizeof(byte4) == 0) { + using VecV = byte4; + Launch_Pipeline_UpdateValues::launch_kernel(params, + stream); + } else if (total_value_size % sizeof(byte2) == 0) { + using VecV = byte2; + Launch_Pipeline_UpdateValues::launch_kernel(params, + stream); + } else { + using VecV = byte; + Launch_Pipeline_UpdateValues::launch_kernel(params, + stream); + } + }; + // This part is according to the test on A100. + if (params.bucket_capacity != 128) { + launch_TLP(); + } else { + if (total_value_size <= ValueConfig::size_tlp) { + if (params.load_factor <= 0.60f) { + launch_TLP(); + } else { + launch_Pipeline(); + } + } else { + launch_Pipeline(); + } + } + } // End function +}; + +/* + * update with IO operation. This kernel is + * usually used for the pure HBM mode for better performance. 
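+ * It cooperates across a tile of threads per key and serves as the fallback
+ * path when KernelSelector_UpdateValues::callable() rejects the digest-based
+ * kernels (non-unique keys, a bucket capacity smaller than the digest load
+ * width, or values larger than the pipeline limit).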
+ */ +template +__global__ void update_values_kernel_with_io( + const Table* __restrict table, Bucket* buckets, + const size_t bucket_max_size, const size_t buckets_num, const size_t dim, + const K* __restrict keys, const V* __restrict values, const size_t N) { + auto g = cg::tiled_partition(cg::this_thread_block()); + int* buckets_size = table->buckets_size; + + for (size_t t = (blockIdx.x * blockDim.x) + threadIdx.x; t < N; + t += blockDim.x * gridDim.x) { + int key_pos = -1; + size_t key_idx = t / TILE_SIZE; + + const K update_key = keys[key_idx]; + + if (IS_RESERVED_KEY(update_key)) continue; + + const V* update_value = values + key_idx * dim; + + size_t bkt_idx = 0; + size_t start_idx = 0; + int src_lane = -1; + + Bucket* bucket = get_key_position( + buckets, update_key, bkt_idx, start_idx, buckets_num, bucket_max_size); + + OccupyResult occupy_result{OccupyResult::INITIAL}; + const int bucket_size = buckets_size[bkt_idx]; + + if (bucket_size >= bucket_max_size) { + start_idx = (start_idx / TILE_SIZE) * TILE_SIZE; + } + occupy_result = find_and_lock_for_update( + g, bucket, update_key, start_idx, key_pos, src_lane, bucket_max_size); + + occupy_result = g.shfl(occupy_result, src_lane); + + if (occupy_result == OccupyResult::REFUSED) continue; + + if (occupy_result == OccupyResult::DUPLICATE) { + copy_vector(g, update_value, + bucket->vectors + key_pos * dim, dim); + } + + if (g.thread_rank() == src_lane) { + (bucket->keys(key_pos)) + ->store(update_key, cuda::std::memory_order_relaxed); + } + } +} + +template +struct SelectUpdateValuesKernelWithIO { + static void execute_kernel(const float& load_factor, const int& block_size, + const size_t bucket_max_size, + const size_t buckets_num, const size_t dim, + cudaStream_t& stream, const size_t& n, + const Table* __restrict table, + Bucket* buckets, const K* __restrict keys, + const V* __restrict values) { + if (load_factor <= 0.75) { + const unsigned int tile_size = 4; + const size_t N = n * tile_size; + const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size); + update_values_kernel_with_io + <<>>(table, buckets, + bucket_max_size, buckets_num, + dim, keys, values, N); + } else { + const unsigned int tile_size = 32; + const size_t N = n * tile_size; + const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size); + update_values_kernel_with_io + <<>>(table, buckets, + bucket_max_size, buckets_num, + dim, keys, values, N); + } + return; + } +}; + +// Use 1 thread to deal with a KV-pair, including copying value. 
+template +__global__ void tlp_update_values_kernel_hybrid( + Bucket* __restrict__ buckets, const uint64_t buckets_num, + uint32_t bucket_capacity, const uint32_t dim, const K* __restrict__ keys, + V** __restrict__ values, K** __restrict__ key_ptrs, + int* __restrict src_offset, uint64_t n) { + using BUCKET = Bucket; + + uint32_t tx = threadIdx.x; + uint32_t kv_idx = blockIdx.x * blockDim.x + tx; + K key{static_cast(EMPTY_KEY)}; + OccupyResult occupy_result{OccupyResult::INITIAL}; + VecD_Comp target_digests{0}; + V* bucket_values_ptr{nullptr}; + K* bucket_keys_ptr{nullptr}; + uint32_t key_pos = {0}; + if (kv_idx < n) { + key = keys[kv_idx]; + if (src_offset) src_offset[kv_idx] = kv_idx; + if (!IS_RESERVED_KEY(key)) { + const K hashed_key = Murmur3HashDevice(key); + target_digests = digests_from_hashed(hashed_key); + uint64_t global_idx = + static_cast(hashed_key % (buckets_num * bucket_capacity)); + key_pos = get_start_position(global_idx, bucket_capacity); + uint64_t bkt_idx = global_idx / bucket_capacity; + BUCKET* bucket = buckets + bkt_idx; + bucket_keys_ptr = reinterpret_cast(bucket->keys(0)); + bucket_values_ptr = bucket->vectors; + } else { + key_ptrs[kv_idx] = nullptr; + return; + } + } else { + return; + } + + // Load `STRIDE` digests every time. + constexpr uint32_t STRIDE = sizeof(VecD_Load) / sizeof(D); + // One more loop to handle empty keys. + for (int offset = 0; offset < bucket_capacity + STRIDE; offset += STRIDE) { + if (occupy_result != OccupyResult::INITIAL) break; + + uint32_t pos_cur = align_to(key_pos); + pos_cur = (pos_cur + offset) & (bucket_capacity - 1); + + D* digests_ptr = BUCKET::digests(bucket_keys_ptr, bucket_capacity, pos_cur); + VecD_Load digests_vec = *(reinterpret_cast(digests_ptr)); + VecD_Comp digests_arr[4] = {digests_vec.x, digests_vec.y, digests_vec.z, + digests_vec.w}; + + for (int i = 0; i < 4; i++) { + VecD_Comp probe_digests = digests_arr[i]; + uint32_t possible_pos = 0; + bool result = false; + // Perform a vectorized comparison by byte, + // and if they are equal, set the corresponding byte in the result to + // 0xff. + int cmp_result = __vcmpeq4(probe_digests, target_digests); + cmp_result &= 0x01010101; + do { + if (cmp_result == 0) break; + // CUDA uses little endian, + // and the lowest byte in register stores in the lowest address. + uint32_t index = (__ffs(cmp_result) - 1) >> 3; + cmp_result &= (cmp_result - 1); + possible_pos = pos_cur + i * 4 + index; + auto current_key = BUCKET::keys(bucket_keys_ptr, possible_pos); + K expected_key = key; + // Modifications to the bucket will not before this instruction. 
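+        // (memory_order_acquire: accesses to the slot's value cannot be
+        // reordered before this CAS succeeds; the key stays LOCKED_KEY until
+        // the later write/unlock kernel restores it.)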
+ result = current_key->compare_exchange_strong( + expected_key, static_cast(LOCKED_KEY), + cuda::std::memory_order_acquire, cuda::std::memory_order_relaxed); + } while (!result); + if (result) { + key_pos = possible_pos; + V* bucket_value_ptr = bucket_values_ptr + key_pos * dim; + values[kv_idx] = bucket_value_ptr; + key_ptrs[kv_idx] = bucket_keys_ptr + key_pos; + return; + } + VecD_Comp empty_digests_ = empty_digests(); + cmp_result = __vcmpeq4(probe_digests, empty_digests_); + cmp_result &= 0x01010101; + do { + if (cmp_result == 0) break; + uint32_t index = (__ffs(cmp_result) - 1) >> 3; + cmp_result &= (cmp_result - 1); + possible_pos = pos_cur + i * 4 + index; + if (offset == 0 && possible_pos < key_pos) continue; + auto current_key = BUCKET::keys(bucket_keys_ptr, possible_pos); + auto probe_key = current_key->load(cuda::std::memory_order_relaxed); + if (probe_key == static_cast(EMPTY_KEY)) { + return; + } + } while (true); + } + } +} + +template +__global__ void update_values_kernel(const Table* __restrict table, + Bucket* buckets, + const size_t bucket_max_size, + const size_t buckets_num, const size_t dim, + const K* __restrict keys, + V** __restrict vectors, + int* __restrict src_offset, size_t N) { + auto g = cg::tiled_partition(cg::this_thread_block()); + int* buckets_size = table->buckets_size; + + for (size_t t = (blockIdx.x * blockDim.x) + threadIdx.x; t < N; + t += blockDim.x * gridDim.x) { + int key_pos = -1; + size_t key_idx = t / TILE_SIZE; + + const K update_key = keys[key_idx]; + + if (IS_RESERVED_KEY(update_key)) continue; + + size_t bkt_idx = 0; + size_t start_idx = 0; + int src_lane = -1; + + Bucket* bucket = get_key_position( + buckets, update_key, bkt_idx, start_idx, buckets_num, bucket_max_size); + + OccupyResult occupy_result{OccupyResult::INITIAL}; + const int bucket_size = buckets_size[bkt_idx]; + *(src_offset + key_idx) = key_idx; + + if (bucket_size >= bucket_max_size) { + start_idx = (start_idx / TILE_SIZE) * TILE_SIZE; + } + occupy_result = find_and_lock_for_update( + g, bucket, update_key, start_idx, key_pos, src_lane, bucket_max_size); + + occupy_result = g.shfl(occupy_result, src_lane); + + if (occupy_result == OccupyResult::REFUSED) continue; + + if (g.thread_rank() == src_lane) { + if (occupy_result == OccupyResult::DUPLICATE) { + *(vectors + key_idx) = (bucket->vectors + key_pos * dim); + } else { + *(vectors + key_idx) = nullptr; + } + } + + if (g.thread_rank() == src_lane) { + (bucket->keys(key_pos)) + ->store(update_key, cuda::std::memory_order_relaxed); + } + } +} + +} // namespace merlin +} // namespace nv \ No newline at end of file diff --git a/include/merlin_hashtable.cuh b/include/merlin_hashtable.cuh index 4f3c5def7..17c9e126e 100644 --- a/include/merlin_hashtable.cuh +++ b/include/merlin_hashtable.cuh @@ -1084,6 +1084,8 @@ class HashTable { return; } + check_evict_strategy(scores); + update_shared_lock lock(mutex_, stream); if (is_fast_mode()) { @@ -1201,14 +1203,12 @@ class HashTable { } /** - * @brief Assign new key-value-score tuples into the hash table. + * @brief Assign new scores for keys. * If the key doesn't exist, the operation on the key will be ignored. * - * @param n Number of key-value-score tuples to insert or assign. + * @param n Number of key-score pairs to assign. * @param keys The keys to insert on GPU-accessible memory with shape * (n). - * @param values The values to insert on GPU-accessible memory with - * shape (n, DIM). * @parblock * The scores should be a `uint64_t` value. 
You can specify a value that * such as the timestamp of the key insertion, number of the key @@ -1222,14 +1222,16 @@ class HashTable { * * @param unique_key If all keys in the same batch are unique. */ - void assign(const size_type n, - const key_type* keys, // (n) - const score_type* scores = nullptr, // (n) - cudaStream_t stream = 0, bool unique_key = true) { + void assign_scores(const size_type n, + const key_type* keys, // (n) + const score_type* scores = nullptr, // (n) + cudaStream_t stream = 0, bool unique_key = true) { if (n == 0) { return; } + check_evict_strategy(scores); + { update_shared_lock lock(mutex_, stream); static thread_local int step_counter = 0; @@ -1260,6 +1262,150 @@ class HashTable { CudaCheckError(); } + /** + * @brief Alias of `assign_scores`. + */ + void assign(const size_type n, + const key_type* keys, // (n) + const score_type* scores = nullptr, // (n) + cudaStream_t stream = 0, bool unique_key = true) { + assign_scores(n, keys, scores, stream, unique_key); + } + + /** + * @brief Assign new values for each keys . + * If the key doesn't exist, the operation on the key will be ignored. + * + * @param n Number of key-value pairs to assign. + * @param keys The keys need to be operated, which must be on GPU-accessible + * memory with shape (n). + * @param values The values need to be updated, which must be on + * GPU-accessible memory with shape (n, DIM). + * + * @param stream The CUDA stream that is used to execute the operation. + * + * @param unique_key If all keys in the same batch are unique. + */ + void assign_values(const size_type n, + const key_type* keys, // (n) + const value_type* values, // (n, DIM) + cudaStream_t stream = 0, bool unique_key = true) { + if (n == 0) { + return; + } + + update_shared_lock lock(mutex_, stream); + + if (is_fast_mode()) { + static thread_local int step_counter = 0; + static thread_local float load_factor = 0.0; + + if (((step_counter++) % kernel_select_interval_) == 0) { + load_factor = fast_load_factor(0, stream, false); + } + using Selector = KernelSelector_UpdateValues; + if (Selector::callable(unique_key, + static_cast(options_.max_bucket_size), + static_cast(options_.dim))) { + typename Selector::Params kernelParams( + load_factor, table_->buckets, table_->buckets_num, + static_cast(options_.max_bucket_size), + static_cast(options_.dim), keys, values, n); + Selector::select_kernel(kernelParams, stream); + } else { + using Selector = + SelectUpdateValuesKernelWithIO; + Selector::execute_kernel(load_factor, options_.block_size, + options_.max_bucket_size, table_->buckets_num, + options_.dim, stream, n, d_table_, + table_->buckets, keys, values); + } + } else { + const size_type dev_ws_size{ + n * (sizeof(value_type*) + sizeof(key_type) + sizeof(int))}; + auto dev_ws{dev_mem_pool_->get_workspace<1>(dev_ws_size, stream)}; + auto d_dst{dev_ws.get(0)}; + auto keys_ptr{reinterpret_cast(d_dst + n)}; + auto d_src_offset{reinterpret_cast(keys_ptr + n)}; + + CUDA_CHECK(cudaMemsetAsync(d_dst, 0, dev_ws_size, stream)); + + constexpr uint32_t MinBucketCapacityFilter = + sizeof(VecD_Load) / sizeof(D); + + bool filter_condition = + options_.max_bucket_size >= MinBucketCapacityFilter && + !options_.io_by_cpu && unique_key; + + if (filter_condition) { + constexpr uint32_t BLOCK_SIZE = 128U; + + tlp_update_values_kernel_hybrid + <<<(n + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>( + table_->buckets, table_->buckets_num, options_.max_bucket_size, + options_.dim, keys, d_dst, keys_ptr, d_src_offset, n); + + } else { + const 
size_t block_size = options_.block_size; + const size_t N = n * TILE_SIZE; + const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size); + + update_values_kernel + <<>>( + d_table_, table_->buckets, options_.max_bucket_size, + table_->buckets_num, options_.dim, keys, d_dst, d_src_offset, + N); + } + + { + thrust::device_ptr d_dst_ptr( + reinterpret_cast(d_dst)); + thrust::device_ptr d_src_offset_ptr(d_src_offset); + + thrust::sort_by_key(thrust_par(thrust_allocator_).on(stream), d_dst_ptr, + d_dst_ptr + n, d_src_offset_ptr, + thrust::less()); + } + + if (filter_condition) { + const size_t block_size = options_.io_block_size; + const size_t N = n * dim(); + const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size); + + write_kernel_unlock_key + <<>>(values, d_dst, d_src_offset, + dim(), keys, keys_ptr, N); + + } else if (options_.io_by_cpu) { + const size_type host_ws_size{dev_ws_size + + n * sizeof(value_type) * dim()}; + auto host_ws{host_mem_pool_->get_workspace<1>(host_ws_size, stream)}; + auto h_dst{host_ws.get(0)}; + auto h_src_offset{reinterpret_cast(h_dst + n)}; + auto h_values{reinterpret_cast(h_src_offset + n)}; + + CUDA_CHECK(cudaMemcpyAsync(h_dst, d_dst, dev_ws_size, + cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaMemcpyAsync(h_values, values, host_ws_size - dev_ws_size, + cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + write_by_cpu(h_dst, h_values, h_src_offset, dim(), n); + } else { + const size_t block_size = options_.io_block_size; + const size_t N = n * dim(); + const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size); + + write_kernel + <<>>(values, d_dst, d_src_offset, + dim(), N); + } + } + + CudaCheckError(); + } + /** * @brief Searches the hash table for the specified keys. * diff --git a/tests/assign_values_test.cc.cu b/tests/assign_values_test.cc.cu new file mode 100644 index 000000000..72216c9b2 --- /dev/null +++ b/tests/assign_values_test.cc.cu @@ -0,0 +1,747 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* + * test API: assign_values + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "merlin_hashtable.cuh" +#include "test_util.cuh" + +constexpr size_t DIM = 16; +using K = uint64_t; +using V = float; +using S = uint64_t; +using EvictStrategy = nv::merlin::EvictStrategy; +using TableOptions = nv::merlin::HashTableOptions; + +void test_evict_strategy_lru_basic(size_t max_hbm_for_vectors) { + constexpr uint64_t BUCKET_NUM = 8UL; + constexpr uint64_t BUCKET_MAX_SIZE = 128UL; + constexpr uint64_t INIT_CAPACITY = BUCKET_NUM * BUCKET_MAX_SIZE; // 1024UL; + constexpr uint64_t MAX_CAPACITY = INIT_CAPACITY; + constexpr uint64_t BASE_KEY_NUM = BUCKET_MAX_SIZE; + constexpr uint64_t TEST_KEY_NUM = 4; + constexpr uint64_t TEMP_KEY_NUM = std::max(BASE_KEY_NUM, TEST_KEY_NUM); + constexpr uint64_t TEST_TIMES = 128; + + TableOptions options; + + options.init_capacity = INIT_CAPACITY; + options.max_capacity = MAX_CAPACITY; + options.dim = DIM; + options.max_hbm_for_vectors = nv::merlin::GB(max_hbm_for_vectors); + using Table = nv::merlin::HashTable; + + std::array h_keys_base; + std::array h_scores_base; + std::array h_vectors_base; + + std::array h_keys_test; + std::array h_scores_test; + std::array h_vectors_test; + + std::array h_keys_temp; + std::array h_scores_temp; + std::array h_vectors_temp; + + K* d_keys_temp; + S* d_scores_temp = nullptr; + V* d_vectors_temp; + + CUDA_CHECK(cudaMalloc(&d_keys_temp, TEMP_KEY_NUM * sizeof(K))); + CUDA_CHECK(cudaMalloc(&d_scores_temp, TEMP_KEY_NUM * sizeof(S))); + CUDA_CHECK( + cudaMalloc(&d_vectors_temp, TEMP_KEY_NUM * sizeof(V) * options.dim)); + + test_util::create_keys_in_one_buckets( + h_keys_base.data(), h_scores_base.data(), h_vectors_base.data(), + BASE_KEY_NUM, INIT_CAPACITY, BUCKET_MAX_SIZE, 1, 0, 0x3FFFFFFFFFFFFFFF); + + test_util::create_keys_in_one_buckets( + h_keys_test.data(), h_scores_test.data(), h_vectors_test.data(), + TEST_KEY_NUM, INIT_CAPACITY, BUCKET_MAX_SIZE, 1, 0x3FFFFFFFFFFFFFFF, + 0xFFFFFFFFFFFFFFFD); + + h_keys_test[2] = h_keys_base[72]; + h_keys_test[3] = h_keys_base[73]; + + for (int i = 0; i < options.dim; i++) { + h_vectors_test[2 * options.dim + i] = + static_cast(h_keys_base[72] * 0.00002); + h_vectors_test[3 * options.dim + i] = + static_cast(h_keys_base[73] * 0.00002); + } + cudaStream_t stream; + CUDA_CHECK(cudaStreamCreate(&stream)); + + size_t total_size = 0; + size_t dump_counter = 0; + for (int i = 0; i < TEST_TIMES; i++) { + std::unique_ptr table = std::make_unique
(); + table->init(options); + + total_size = table->size(stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + ASSERT_EQ(total_size, 0); + + { + CUDA_CHECK(cudaMemcpy(d_keys_temp, h_keys_base.data(), + BASE_KEY_NUM * sizeof(K), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(d_scores_temp, h_scores_base.data(), + BASE_KEY_NUM * sizeof(S), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(d_vectors_temp, h_vectors_base.data(), + BASE_KEY_NUM * sizeof(V) * options.dim, + cudaMemcpyHostToDevice)); + S start_ts = test_util::host_nano(stream); + table->find_or_insert(BASE_KEY_NUM, d_keys_temp, d_vectors_temp, nullptr, + stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + S end_ts = test_util::host_nano(stream); + + size_t total_size = table->size(stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + ASSERT_EQ(total_size, BUCKET_MAX_SIZE); + + dump_counter = table->export_batch(table->capacity(), 0, d_keys_temp, + d_vectors_temp, d_scores_temp, stream); + ASSERT_EQ(dump_counter, BUCKET_MAX_SIZE); + + CUDA_CHECK(cudaMemcpy(h_keys_temp.data(), d_keys_temp, + BASE_KEY_NUM * sizeof(K), cudaMemcpyDefault)); + CUDA_CHECK(cudaMemcpy(h_scores_temp.data(), d_scores_temp, + BASE_KEY_NUM * sizeof(S), cudaMemcpyDefault)); + CUDA_CHECK(cudaMemcpy(h_vectors_temp.data(), d_vectors_temp, + BASE_KEY_NUM * sizeof(V) * options.dim, + cudaMemcpyDefault)); + + std::array h_scores_temp_sorted(h_scores_temp); + std::sort(h_scores_temp_sorted.begin(), h_scores_temp_sorted.end()); + + ASSERT_GE(h_scores_temp_sorted[0], start_ts); + ASSERT_LE(h_scores_temp_sorted[TEST_KEY_NUM - 1], end_ts); + for (int i = 0; i < dump_counter; i++) { + for (int j = 0; j < options.dim; j++) { + ASSERT_EQ(h_vectors_temp[i * options.dim + j], + static_cast(h_keys_temp[i] * 0.00001)); + } + } + } + + { + CUDA_CHECK(cudaMemcpy(d_keys_temp, h_keys_test.data(), + TEST_KEY_NUM * sizeof(K), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(d_scores_temp, h_scores_test.data(), + TEST_KEY_NUM * sizeof(S), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(d_vectors_temp, h_vectors_test.data(), + TEST_KEY_NUM * sizeof(V) * options.dim, + cudaMemcpyHostToDevice)); + + S start_ts = test_util::host_nano(stream); + table->assign_values(TEST_KEY_NUM, d_keys_temp, d_vectors_temp, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + size_t total_size = table->size(stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + ASSERT_EQ(total_size, BUCKET_MAX_SIZE); + + dump_counter = table->export_batch(table->capacity(), 0, d_keys_temp, + d_vectors_temp, d_scores_temp, stream); + ASSERT_EQ(dump_counter, BUCKET_MAX_SIZE); + + CUDA_CHECK(cudaMemcpy(h_keys_temp.data(), d_keys_temp, + TEMP_KEY_NUM * sizeof(K), cudaMemcpyDefault)); + CUDA_CHECK(cudaMemcpy(h_scores_temp.data(), d_scores_temp, + TEMP_KEY_NUM * sizeof(S), cudaMemcpyDefault)); + CUDA_CHECK(cudaMemcpy(h_vectors_temp.data(), d_vectors_temp, + TEMP_KEY_NUM * sizeof(V) * options.dim, + cudaMemcpyDefault)); + + for (int i = 0; i < TEMP_KEY_NUM; i++) { + V expected_v = (h_keys_temp[i] == h_keys_test[2] || + h_keys_temp[i] == h_keys_test[3]) + ? 
static_cast(h_keys_temp[i] * 0.00002) + : static_cast(h_keys_temp[i] * 0.00001); + for (int j = 0; j < options.dim; j++) { + ASSERT_EQ(h_vectors_temp[i * options.dim + j], expected_v); + } + ASSERT_LE(h_scores_temp[i], start_ts); + } + } + } + CUDA_CHECK(cudaStreamDestroy(stream)); + + CUDA_CHECK(cudaFree(d_keys_temp)); + CUDA_CHECK(cudaFree(d_scores_temp)); + CUDA_CHECK(cudaFree(d_vectors_temp)); + + CUDA_CHECK(cudaDeviceSynchronize()); + + CudaCheckError(); +} + +void test_evict_strategy_epochlfu_basic(size_t max_hbm_for_vectors) { + constexpr uint64_t BUCKET_NUM = 8UL; + constexpr uint64_t BUCKET_MAX_SIZE = 128UL; + constexpr uint64_t INIT_CAPACITY = BUCKET_NUM * BUCKET_MAX_SIZE; // 1024UL; + constexpr uint64_t MAX_CAPACITY = INIT_CAPACITY; + constexpr uint64_t BASE_KEY_NUM = BUCKET_MAX_SIZE; + constexpr uint64_t TEST_KEY_NUM = 4; + constexpr uint64_t TEMP_KEY_NUM = std::max(BASE_KEY_NUM, TEST_KEY_NUM); + constexpr uint64_t TEST_TIMES = 128; + + TableOptions options; + + options.init_capacity = INIT_CAPACITY; + options.max_capacity = MAX_CAPACITY; + options.dim = DIM; + options.max_hbm_for_vectors = nv::merlin::GB(max_hbm_for_vectors); + using Table = nv::merlin::HashTable; + + std::array h_keys_base; + std::array h_scores_base; + std::array h_vectors_base; + + std::array h_keys_test; + std::array h_scores_test; + std::array h_vectors_test; + + std::array h_keys_temp; + std::array h_scores_temp; + std::array h_vectors_temp; + + K* d_keys_temp; + S* d_scores_temp = nullptr; + V* d_vectors_temp; + + int freq_range = 1000; + + CUDA_CHECK(cudaMalloc(&d_keys_temp, TEMP_KEY_NUM * sizeof(K))); + CUDA_CHECK(cudaMalloc(&d_scores_temp, TEMP_KEY_NUM * sizeof(S))); + CUDA_CHECK( + cudaMalloc(&d_vectors_temp, TEMP_KEY_NUM * sizeof(V) * options.dim)); + + test_util::create_keys_in_one_buckets_lfu( + h_keys_base.data(), h_scores_base.data(), h_vectors_base.data(), + BASE_KEY_NUM, INIT_CAPACITY, BUCKET_MAX_SIZE, 1, 0, 0x3FFFFFFFFFFFFFFF, + freq_range); + + test_util::create_keys_in_one_buckets_lfu( + h_keys_test.data(), h_scores_test.data(), h_vectors_test.data(), + TEST_KEY_NUM, INIT_CAPACITY, BUCKET_MAX_SIZE, 1, 0x3FFFFFFFFFFFFFFF, + 0xFFFFFFFFFFFFFFFD, freq_range); + + // Simulate overflow of low 32bits. + h_scores_base[71] = static_cast(std::numeric_limits::max() - + static_cast(1)); + + h_keys_test[1] = h_keys_base[71]; + h_keys_test[2] = h_keys_base[72]; + h_keys_test[3] = h_keys_base[73]; + + h_scores_test[1] = h_scores_base[71]; + h_scores_test[2] = h_keys_base[72] % freq_range; + h_scores_test[3] = h_keys_base[73] % freq_range; + + for (int i = 0; i < options.dim; i++) { + h_vectors_test[1 * options.dim + i] = + static_cast(h_keys_base[71] * 0.00002); + h_vectors_test[2 * options.dim + i] = + static_cast(h_keys_base[72] * 0.00002); + h_vectors_test[3 * options.dim + i] = + static_cast(h_keys_base[73] * 0.00002); + } + cudaStream_t stream; + CUDA_CHECK(cudaStreamCreate(&stream)); + + size_t total_size = 0; + size_t dump_counter = 0; + S global_epoch = 1; + for (int i = 0; i < TEST_TIMES; i++) { + std::unique_ptr
<Table> table = std::make_unique<Table>
(); + table->init(options); + + total_size = table->size(stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + ASSERT_EQ(total_size, 0); + + { + CUDA_CHECK(cudaMemcpy(d_keys_temp, h_keys_base.data(), + BASE_KEY_NUM * sizeof(K), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(d_scores_temp, h_scores_base.data(), + BASE_KEY_NUM * sizeof(S), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(d_vectors_temp, h_vectors_base.data(), + BASE_KEY_NUM * sizeof(V) * options.dim, + cudaMemcpyHostToDevice)); + EvictStrategy::set_global_epoch(global_epoch); + table->find_or_insert(BASE_KEY_NUM, d_keys_temp, d_vectors_temp, + d_scores_temp, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + size_t total_size = table->size(stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + ASSERT_EQ(total_size, BUCKET_MAX_SIZE); + + dump_counter = table->export_batch(table->capacity(), 0, d_keys_temp, + d_vectors_temp, d_scores_temp, stream); + ASSERT_EQ(dump_counter, BUCKET_MAX_SIZE); + + CUDA_CHECK(cudaMemcpy(h_keys_temp.data(), d_keys_temp, + BASE_KEY_NUM * sizeof(K), cudaMemcpyDefault)); + CUDA_CHECK(cudaMemcpy(h_scores_temp.data(), d_scores_temp, + BASE_KEY_NUM * sizeof(S), cudaMemcpyDefault)); + CUDA_CHECK(cudaMemcpy(h_vectors_temp.data(), d_vectors_temp, + BASE_KEY_NUM * sizeof(V) * options.dim, + cudaMemcpyDefault)); + + for (int i = 0; i < dump_counter; i++) { + if (h_keys_temp[i] == h_keys_base[71]) { + S expected_score = test_util::make_expected_score_for_epochlfu( + global_epoch, h_scores_base[71]); + ASSERT_EQ(h_scores_temp[i], expected_score); + } else { + S expected_score = test_util::make_expected_score_for_epochlfu( + global_epoch, (h_keys_temp[i] % freq_range)); + ASSERT_EQ(h_scores_temp[i], expected_score); + } + for (int j = 0; j < options.dim; j++) { + ASSERT_EQ(h_vectors_temp[i * options.dim + j], + static_cast(h_keys_temp[i] * 0.00001)); + } + } + } + + { + CUDA_CHECK(cudaMemcpy(d_keys_temp, h_keys_test.data(), + TEST_KEY_NUM * sizeof(K), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(d_scores_temp, h_scores_test.data(), + TEST_KEY_NUM * sizeof(S), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(d_vectors_temp, h_vectors_test.data(), + TEST_KEY_NUM * sizeof(V) * options.dim, + cudaMemcpyHostToDevice)); + table->assign_values(TEST_KEY_NUM, d_keys_temp, d_vectors_temp, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + size_t total_size = table->size(stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + ASSERT_EQ(total_size, BUCKET_MAX_SIZE); + + dump_counter = table->export_batch(table->capacity(), 0, d_keys_temp, + d_vectors_temp, d_scores_temp, stream); + ASSERT_EQ(dump_counter, BUCKET_MAX_SIZE); + + CUDA_CHECK(cudaMemcpy(h_keys_temp.data(), d_keys_temp, + TEMP_KEY_NUM * sizeof(K), cudaMemcpyDefault)); + CUDA_CHECK(cudaMemcpy(h_scores_temp.data(), d_scores_temp, + TEMP_KEY_NUM * sizeof(S), cudaMemcpyDefault)); + CUDA_CHECK(cudaMemcpy(h_vectors_temp.data(), d_vectors_temp, + TEMP_KEY_NUM * sizeof(V) * options.dim, + cudaMemcpyDefault)); + + ASSERT_TRUE(h_keys_temp.end() != std::find(h_keys_temp.begin(), + h_keys_temp.end(), + h_keys_base[71])); + + for (int i = 0; i < dump_counter; i++) { + if (h_keys_temp[i] == h_keys_base[71]) { + S expected_score = test_util::make_expected_score_for_epochlfu( + global_epoch, h_scores_base[71]); + ASSERT_EQ(h_scores_temp[i], expected_score); + } else { + S expected_score = test_util::make_expected_score_for_epochlfu( + global_epoch, (h_keys_temp[i] % freq_range)); + ASSERT_EQ(h_scores_temp[i], expected_score); + } + for (int j = 0; j < 
options.dim; j++) { + V expected_v = (h_keys_temp[i] == h_keys_test[1] || + h_keys_temp[i] == h_keys_test[2] || + h_keys_temp[i] == h_keys_test[3]) + ? static_cast(h_keys_temp[i] * 0.00002) + : static_cast(h_keys_temp[i] * 0.00001); + ASSERT_EQ(h_vectors_temp[i * options.dim + j], expected_v); + } + } + } + } + CUDA_CHECK(cudaStreamDestroy(stream)); + + CUDA_CHECK(cudaFree(d_keys_temp)); + CUDA_CHECK(cudaFree(d_scores_temp)); + CUDA_CHECK(cudaFree(d_vectors_temp)); + + CUDA_CHECK(cudaDeviceSynchronize()); + + CudaCheckError(); +} + +template +void CheckAssignOnEpochLfu(Table* table, + test_util::KVMSBuffer* data_buffer, + test_util::KVMSBuffer* evict_buffer, + test_util::KVMSBuffer* pre_data_buffer, + size_t len, cudaStream_t stream, TableOptions& opt, + unsigned int global_epoch) { + std::map> values_map_before_insert; + std::map> values_map_after_insert; + + std::unordered_map scores_map_before_insert; + std::map scores_map_after_insert; + + std::map scores_map_current_batch; + std::map scores_map_current_evict; + + K* keys = data_buffer->keys_ptr(); + V* values = data_buffer->values_ptr(); + S* scores = data_buffer->scores_ptr(); + + K* evicted_keys = evict_buffer->keys_ptr(); + V* evicted_values = evict_buffer->values_ptr(); + S* evicted_scores = evict_buffer->scores_ptr(); + + for (size_t i = 0; i < len; i++) { + scores_map_current_batch[data_buffer->keys_ptr(false)[i]] = + data_buffer->scores_ptr(false)[i]; + } + + K* h_tmp_keys = nullptr; + V* h_tmp_values = nullptr; + S* h_tmp_scores = nullptr; + bool* h_tmp_founds = nullptr; + + K* d_tmp_keys = nullptr; + V* d_tmp_values = nullptr; + S* d_tmp_scores = nullptr; + bool* d_tmp_founds = nullptr; + + size_t table_size_before = table->size(stream); + size_t cap = table_size_before + len; + + CUDA_CHECK(cudaMallocAsync(&d_tmp_keys, cap * sizeof(K), stream)); + CUDA_CHECK(cudaMemsetAsync(d_tmp_keys, 0, cap * sizeof(K), stream)); + CUDA_CHECK(cudaMallocAsync(&d_tmp_values, cap * dim * sizeof(V), stream)); + CUDA_CHECK(cudaMemsetAsync(d_tmp_values, 0, cap * dim * sizeof(V), stream)); + CUDA_CHECK(cudaMallocAsync(&d_tmp_scores, cap * sizeof(S), stream)); + CUDA_CHECK(cudaMemsetAsync(d_tmp_scores, 0, cap * sizeof(S), stream)); + CUDA_CHECK(cudaMallocAsync(&d_tmp_founds, cap * sizeof(bool), stream)); + CUDA_CHECK(cudaMemsetAsync(d_tmp_founds, 0, cap * sizeof(bool), stream)); + h_tmp_keys = (K*)malloc(cap * sizeof(K)); + h_tmp_values = (V*)malloc(cap * dim * sizeof(V)); + h_tmp_scores = (S*)malloc(cap * sizeof(S)); + h_tmp_founds = (bool*)malloc(cap * sizeof(bool)); + + size_t table_size_verify0 = table->export_batch( + table->capacity(), 0, d_tmp_keys, d_tmp_values, d_tmp_scores, stream); + ASSERT_EQ(table_size_before, table_size_verify0); + + CUDA_CHECK(cudaMemcpyAsync(h_tmp_keys, d_tmp_keys, + table_size_before * sizeof(K), + cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaMemcpyAsync(h_tmp_values, d_tmp_values, + table_size_before * dim * sizeof(V), + cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaMemcpyAsync(h_tmp_scores, d_tmp_scores, + table_size_before * sizeof(S), + cudaMemcpyDeviceToHost, stream)); + + CUDA_CHECK(cudaMemcpyAsync(h_tmp_keys + table_size_before, keys, + len * sizeof(K), cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaMemcpyAsync(h_tmp_values + table_size_before * dim, values, + len * dim * sizeof(V), cudaMemcpyDeviceToHost, + stream)); + CUDA_CHECK(cudaMemcpyAsync(h_tmp_scores + table_size_before, scores, + len * sizeof(S), cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + for 
(size_t i = 0; i < cap; i++) { + test_util::ValueArray* vec = + reinterpret_cast*>(h_tmp_values + + i * dim); + values_map_before_insert[h_tmp_keys[i]] = *vec; + } + + for (size_t i = 0; i < table_size_before; i++) { + scores_map_before_insert[h_tmp_keys[i]] = h_tmp_scores[i]; + } + + table->assign_values(len, keys, values, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + { + size_t table_size_verify1 = table->export_batch( + table->capacity(), 0, d_tmp_keys, d_tmp_values, d_tmp_scores, stream); + + CUDA_CHECK(cudaMemcpyAsync(h_tmp_keys, d_tmp_keys, + table_size_before * sizeof(K), + cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaMemcpyAsync(h_tmp_values, d_tmp_values, + table_size_before * dim * sizeof(V), + cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaMemcpyAsync(h_tmp_scores, d_tmp_scores, + table_size_before * sizeof(S), + cudaMemcpyDeviceToHost, stream)); + + CUDA_CHECK(cudaStreamSynchronize(stream)); + ASSERT_EQ(table_size_verify1, table_size_before); + + size_t score_error_cnt = 0; + + for (int64_t i = table_size_before - 1; i >= 0; i--) { + test_util::ValueArray* vec = + reinterpret_cast*>(h_tmp_values + + i * dim); + values_map_after_insert[h_tmp_keys[i]] = *vec; + scores_map_after_insert[h_tmp_keys[i]] = h_tmp_scores[i]; + } + + for (auto it : scores_map_current_batch) { + const K key = it.first; + const K score = it.second; + S current_score = scores_map_after_insert[key]; + S score_before_insert = 0; + if (scores_map_before_insert.find(key) != + scores_map_before_insert.end()) { + score_before_insert = scores_map_before_insert[key]; + bool valid = ((current_score >> 32) < global_epoch) && + ((current_score & 0xFFFFFFFF) == + (0xFFFFFFFF & score_before_insert)); + + if (!valid) { + score_error_cnt++; + } + } + } + std::cout << "Check assign behavior got " + << ", score_error_cnt: " << score_error_cnt + << ", while len: " << len << std::endl; + ASSERT_EQ(score_error_cnt, 0); + } + + for (int64_t i = 0; i < table_size_before; i++) { + values_map_before_insert[h_tmp_keys[i]] = + values_map_after_insert[h_tmp_keys[i]]; + scores_map_before_insert[h_tmp_keys[i]] = + scores_map_after_insert[h_tmp_keys[i]]; + } + values_map_after_insert.clear(); + scores_map_after_insert.clear(); + + EvictStrategy::set_global_epoch(global_epoch); + auto start = std::chrono::steady_clock::now(); + size_t filtered_len = table->insert_and_evict( + len, keys, values, + (Table::evict_strategy == EvictStrategy::kLru || + Table::evict_strategy == EvictStrategy::kEpochLru) + ? 
nullptr + : scores, + evicted_keys, evicted_values, evicted_scores, stream); + evict_buffer->SyncData(false, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + auto end = std::chrono::steady_clock::now(); + auto diff = std::chrono::duration_cast(end - start); + + for (size_t i = 0; i < filtered_len; i++) { + scores_map_current_evict[evict_buffer->keys_ptr(false)[i]] = + evict_buffer->scores_ptr(false)[i]; + } + + float dur = diff.count(); + + size_t table_size_after = table->size(stream); + size_t table_size_verify1 = table->export_batch( + table->capacity(), 0, d_tmp_keys, d_tmp_values, d_tmp_scores, stream); + + ASSERT_EQ(table_size_verify1, table_size_after); + + size_t new_cap = table_size_after + filtered_len; + CUDA_CHECK(cudaMemcpyAsync(h_tmp_keys, d_tmp_keys, + table_size_after * sizeof(K), + cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaMemcpyAsync(h_tmp_values, d_tmp_values, + table_size_after * dim * sizeof(V), + cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaMemcpyAsync(h_tmp_scores, d_tmp_scores, + table_size_after * sizeof(S), + cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaMemcpyAsync(h_tmp_keys + table_size_after, evicted_keys, + filtered_len * sizeof(K), cudaMemcpyDeviceToHost, + stream)); + CUDA_CHECK(cudaMemcpyAsync(h_tmp_values + table_size_after * dim, + evicted_values, filtered_len * dim * sizeof(V), + cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaMemcpyAsync(h_tmp_scores + table_size_after, evicted_scores, + filtered_len * sizeof(S), cudaMemcpyDeviceToHost, + stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + size_t key_miss_cnt = 0; + size_t value_diff_cnt = 0; + size_t score_error_cnt1 = 0; + size_t score_error_cnt2 = 0; + + for (int64_t i = new_cap - 1; i >= 0; i--) { + test_util::ValueArray* vec = + reinterpret_cast*>(h_tmp_values + + i * dim); + values_map_after_insert[h_tmp_keys[i]] = *vec; + scores_map_after_insert[h_tmp_keys[i]] = h_tmp_scores[i]; + if (i >= (new_cap - filtered_len)) { + bool valid = ((h_tmp_scores[i] >> 32) < (global_epoch - 2)); + if (!valid) { + score_error_cnt1++; + } + } + } + + for (auto it : scores_map_current_batch) { + const K key = it.first; + const K score = it.second; + S current_score = scores_map_after_insert[key]; + S score_before_insert = 0; + if (values_map_after_insert.find(key) != values_map_after_insert.end() && + scores_map_current_evict.find(key) == scores_map_current_evict.end()) { + score_before_insert = scores_map_before_insert[key]; + } + bool valid = ((current_score >> 32) == global_epoch) && + ((current_score & 0xFFFFFFFF) == + ((0xFFFFFFFF & score_before_insert) + (0xFFFFFFFF & score))); + + if (!valid) { + score_error_cnt2++; + } + } + + for (auto& it : values_map_before_insert) { + if (values_map_after_insert.find(it.first) == + values_map_after_insert.end()) { + ++key_miss_cnt; + continue; + } + test_util::ValueArray& vec0 = it.second; + test_util::ValueArray& vec1 = values_map_after_insert.at(it.first); + for (size_t j = 0; j < dim; j++) { + if (vec0[j] != vec1[j]) { + ++value_diff_cnt; + break; + } + } + } + + std::cout << "Check insert_and_evict behavior got " + << "key_miss_cnt: " << key_miss_cnt + << ", value_diff_cnt: " << value_diff_cnt + << ", score_error_cnt1: " << score_error_cnt1 + << ", score_error_cnt2: " << score_error_cnt2 + << ", while table_size_before: " << table_size_before + << ", while table_size_after: " << table_size_after + << ", while len: " << len << std::endl; + + ASSERT_EQ(key_miss_cnt, 0); + ASSERT_EQ(value_diff_cnt, 0); + ASSERT_EQ(score_error_cnt1, 
0); + ASSERT_EQ(score_error_cnt2, 0); + + CUDA_CHECK(cudaFreeAsync(d_tmp_keys, stream)); + CUDA_CHECK(cudaFreeAsync(d_tmp_values, stream)); + CUDA_CHECK(cudaFreeAsync(d_tmp_scores, stream)); + CUDA_CHECK(cudaFreeAsync(d_tmp_founds, stream)); + free(h_tmp_keys); + free(h_tmp_values); + free(h_tmp_scores); + free(h_tmp_founds); + CUDA_CHECK(cudaStreamSynchronize(stream)); +} + +void test_assign_advanced_on_epochlfu(size_t max_hbm_for_vectors) { + const size_t U = 1024 * 1024; + const size_t B = 100000; + constexpr size_t dim = 16; + + TableOptions opt; + + opt.max_capacity = U; + opt.init_capacity = U; + opt.max_hbm_for_vectors = U * dim * sizeof(V); + opt.max_bucket_size = 128; + opt.max_hbm_for_vectors = nv::merlin::GB(max_hbm_for_vectors); + using Table = nv::merlin::HashTable; + opt.dim = dim; + + cudaStream_t stream; + CUDA_CHECK(cudaStreamCreate(&stream)); + + std::unique_ptr
<Table> table = std::make_unique<Table>
(); + table->init(opt); + + test_util::KVMSBuffer evict_buffer; + evict_buffer.Reserve(B, dim, stream); + evict_buffer.ToZeros(stream); + + test_util::KVMSBuffer data_buffer; + test_util::KVMSBuffer pre_data_buffer; + data_buffer.Reserve(B, dim, stream); + pre_data_buffer.Reserve(B, dim, stream); + + size_t offset = 0; + int freq_range = 100; + float repeat_rate = 0.9; + for (unsigned int global_epoch = 1; global_epoch <= 20; global_epoch++) { + repeat_rate = global_epoch <= 1 ? 0.0 : 0.1; + if (global_epoch <= 1) { + test_util::create_random_keys_advanced( + dim, data_buffer.keys_ptr(false), data_buffer.scores_ptr(false), + data_buffer.values_ptr(false), (int)B, B * 32, freq_range); + } else { + test_util::create_random_keys_advanced( + dim, data_buffer.keys_ptr(false), pre_data_buffer.keys_ptr(false), + data_buffer.scores_ptr(false), data_buffer.values_ptr(false), (int)B, + B * 32, freq_range, repeat_rate); + } + data_buffer.SyncData(true, stream); + if (global_epoch <= 1) { + pre_data_buffer.CopyFrom(data_buffer, stream); + } + + CheckAssignOnEpochLfu(table.get(), &data_buffer, + &evict_buffer, &pre_data_buffer, + B, stream, opt, global_epoch); + + pre_data_buffer.CopyFrom(data_buffer, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + offset += B; + } +} + +TEST(AssignValuesTest, test_evict_strategy_lru_basic) { + test_evict_strategy_lru_basic(16); + test_evict_strategy_lru_basic(0); +} +TEST(AssignValuesTest, test_evict_strategy_epochlfu_basic) { + test_evict_strategy_epochlfu_basic(16); + test_evict_strategy_epochlfu_basic(0); +} +TEST(AssignValuesTest, test_assign_advanced_on_epochlfu) { + test_assign_advanced_on_epochlfu(16); +} \ No newline at end of file
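
For reference, a minimal host-side sketch of how the new API is intended to be used. It is illustrative only and not part of the patch: the table template arguments and option values follow the LRU test above, the buffer-filling step is elided, and error checking is omitted.

#include <cuda_runtime.h>

#include "merlin_hashtable.cuh"

using K = uint64_t;
using V = float;
using S = uint64_t;
// Same evict strategy as the LRU test above.
using Table = nv::merlin::HashTable<K, V, S, nv::merlin::EvictStrategy::kLru>;

int main() {
  constexpr size_t n = 1024;
  constexpr size_t dim = 16;

  nv::merlin::HashTableOptions options;
  options.init_capacity = 1024 * 1024;
  options.max_capacity = 1024 * 1024;
  options.dim = dim;
  options.max_hbm_for_vectors = nv::merlin::GB(16);

  Table table;
  table.init(options);

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  K* d_keys;
  V* d_values;
  V* d_new_values;
  cudaMalloc(&d_keys, n * sizeof(K));
  cudaMalloc(&d_values, n * dim * sizeof(V));
  cudaMalloc(&d_new_values, n * dim * sizeof(V));
  // ... fill d_keys, d_values and d_new_values ...

  // Insert the keys first: assign_values only touches keys that already exist.
  table.find_or_insert(n, d_keys, d_values, /*scores=*/nullptr, stream);

  // Overwrite the values of existing keys; scores are left untouched and
  // missing keys are ignored.
  table.assign_values(n, d_keys, d_new_values, stream);

  // Update only the scores of existing keys (`assign` is an alias). With the
  // LRU strategy the scores argument stays nullptr, which the newly added
  // check_evict_strategy call enforces.
  table.assign_scores(n, d_keys, /*scores=*/nullptr, stream);

  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  cudaFree(d_keys);
  cudaFree(d_values);
  cudaFree(d_new_values);
  return 0;
}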