From 97bea9363190100580a7b7d1e369b58e8cf18776 Mon Sep 17 00:00:00 2001 From: rhdong Date: Wed, 12 Jun 2024 14:19:24 -0700 Subject: [PATCH] [Fix] out-of-band issue and add test case for it --- include/merlin/core_kernels.cuh | 13 +- tests/merlin_hashtable_test.cc.cu | 253 ++++++++++++++++++++++++++++++ 2 files changed, 260 insertions(+), 6 deletions(-) diff --git a/include/merlin/core_kernels.cuh b/include/merlin/core_kernels.cuh index b8ffd8bd..48e67037 100644 --- a/include/merlin/core_kernels.cuh +++ b/include/merlin/core_kernels.cuh @@ -87,11 +87,11 @@ __global__ void allocate_bucket_vectors(Bucket* __restrict buckets, template __global__ void allocate_bucket_others(Bucket* __restrict buckets, size_t total_size_per_bucket, - size_t num_of_buckets_per_alloc, + size_t num_of_buckets, const int start_index, uint8_t* address, const uint32_t reserve_size, const size_t bucket_max_size) { - for (size_t step = 0; step < num_of_buckets_per_alloc; step++) { + for (size_t step = 0; step < num_of_buckets; step++) { size_t index = start_index + step; buckets[index].digests_ = address; buckets[index].keys_ = @@ -238,12 +238,13 @@ void initialize_buckets(Table** table, BaseAllocator* allocator, */ for (int i = start; i < end; i += (*table)->num_of_buckets_per_alloc) { uint8_t* address = nullptr; + size_t num_of_buckets = + std::min(end - i, (*table)->num_of_buckets_per_alloc); allocator->alloc(MemoryType::Device, (void**)&(address), - bucket_memory_size * (*table)->num_of_buckets_per_alloc); + bucket_memory_size * num_of_buckets); allocate_bucket_others - <<<1, 1>>>((*table)->buckets, bucket_memory_size, - (*table)->num_of_buckets_per_alloc, i, address, reserve_size, - bucket_max_size); + <<<1, 1>>>((*table)->buckets, bucket_memory_size, num_of_buckets, i, + address, reserve_size, bucket_max_size); } CUDA_CHECK(cudaDeviceSynchronize()); diff --git a/tests/merlin_hashtable_test.cc.cu b/tests/merlin_hashtable_test.cc.cu index 8614bdb8..0ca93e28 100644 --- 
a/tests/merlin_hashtable_test.cc.cu +++ b/tests/merlin_hashtable_test.cc.cu @@ -373,6 +373,255 @@ void test_basic(size_t max_hbm_for_vectors) { CudaCheckError(); } +/* Runs insert_or_assign / find / contains / accum_or_assign / erase / clear / export_batch on a table whose init_capacity equals max_capacity. The capacity is chosen so that the expected bucket count (522241) is not a multiple of num_of_buckets_per_alloc (2048). */ void test_basic_without_rehash(size_t max_hbm_for_vectors) { + constexpr uint64_t BUCKET_MAX_SIZE = 128; + constexpr uint64_t NUM_OF_BUCKETS_PER_ALLOC = 2048; + constexpr uint64_t INIT_CAPACITY = + 64 * 1024 * 1024UL - (NUM_OF_BUCKETS_PER_ALLOC * BUCKET_MAX_SIZE) + 1; + constexpr uint64_t MAX_CAPACITY = INIT_CAPACITY; + constexpr uint64_t KEY_NUM = 1 * 1024 * 1024UL; + constexpr uint64_t TEST_TIMES = 1; + + K* h_keys; + S* h_scores; + V* h_vectors; + bool* h_found; + + TableOptions options; + + options.init_capacity = INIT_CAPACITY; + options.max_capacity = MAX_CAPACITY; + options.dim = DIM; + options.max_bucket_size = BUCKET_MAX_SIZE; + options.max_hbm_for_vectors = nv::merlin::GB(max_hbm_for_vectors); + options.reserved_key_start_bit = 2; + options.num_of_buckets_per_alloc = NUM_OF_BUCKETS_PER_ALLOC; + + using Table = nv::merlin::HashTable; + + CUDA_CHECK(cudaMallocHost(&h_keys, KEY_NUM * sizeof(K))); + CUDA_CHECK(cudaMallocHost(&h_scores, KEY_NUM * sizeof(S))); + CUDA_CHECK(cudaMallocHost(&h_vectors, KEY_NUM * sizeof(V) * options.dim)); + CUDA_CHECK(cudaMallocHost(&h_found, KEY_NUM * sizeof(bool))); + + CUDA_CHECK(cudaMemset(h_vectors, 0, KEY_NUM * sizeof(V) * options.dim)); + + test_util::create_random_keys(h_keys, h_scores, h_vectors, + KEY_NUM); + + K* d_keys; + S* d_scores = nullptr; + V* d_vectors; + V* d_new_vectors; + bool* d_found; + size_t dump_counter = 0; + + CUDA_CHECK(cudaMalloc(&d_keys, KEY_NUM * sizeof(K))); + CUDA_CHECK(cudaMalloc(&d_scores, KEY_NUM * sizeof(S))); + CUDA_CHECK(cudaMalloc(&d_vectors, KEY_NUM * sizeof(V) * options.dim)); + CUDA_CHECK(cudaMalloc(&d_new_vectors, KEY_NUM * sizeof(V) * options.dim)); + CUDA_CHECK(cudaMalloc(&d_found, KEY_NUM * sizeof(bool))); + + CUDA_CHECK( + cudaMemcpy(d_keys, h_keys, KEY_NUM * sizeof(K), cudaMemcpyHostToDevice)); +
CUDA_CHECK(cudaMemcpy(d_scores, h_scores, KEY_NUM * sizeof(S), + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(d_vectors, h_vectors, KEY_NUM * sizeof(V) * options.dim, + cudaMemcpyHostToDevice)); + + CUDA_CHECK(cudaMemset(d_found, 0, KEY_NUM * sizeof(bool))); + + cudaStream_t stream; + CUDA_CHECK(cudaStreamCreate(&stream)); + + uint64_t total_size = 0; + for (int i = 0; i < TEST_TIMES; i++) { + std::unique_ptr table = std::make_unique
(); + table->init(options); + + ASSERT_EQ(table->bucket_count(), + 522241); // 1 + (INIT_CAPACITY / options.max_bucket_size) + total_size = table->size(stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + ASSERT_EQ(total_size, 0); + + table->insert_or_assign(KEY_NUM, d_keys, d_vectors, d_scores, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + total_size = table->size(stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + ASSERT_EQ(total_size, KEY_NUM); + + CUDA_CHECK(cudaMemset(d_vectors, 0, KEY_NUM * sizeof(V) * options.dim)); + table->find(KEY_NUM, d_keys, d_vectors, d_found, nullptr, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + int found_num = 0; + CUDA_CHECK(cudaMemcpy(h_found, d_found, KEY_NUM * sizeof(bool), + cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy(h_scores, d_scores, KEY_NUM * sizeof(S), + cudaMemcpyDeviceToHost)); /* d_scores is the source here, so the copy kind must be device-to-host */ + CUDA_CHECK(cudaMemcpy(h_vectors, d_vectors, + KEY_NUM * sizeof(V) * options.dim, + cudaMemcpyDeviceToHost)); + + for (int i = 0; i < KEY_NUM; i++) { + if (h_found[i]) found_num++; + ASSERT_EQ(h_scores[i], h_keys[i]); + for (int j = 0; j < options.dim; j++) { + ASSERT_EQ(h_vectors[i * options.dim + j], + static_cast(h_keys[i] * 0.00001)); + } + } + + CUDA_CHECK(cudaMemset(d_found, 0, KEY_NUM * sizeof(bool))); + table->contains(KEY_NUM, d_keys, d_found, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + int contains_num = 0; + CUDA_CHECK(cudaMemcpy(h_found, d_found, KEY_NUM * sizeof(bool), + cudaMemcpyDeviceToHost)); + for (int i = 0; i < KEY_NUM; i++) { + if (h_found[i]) contains_num++; + } + ASSERT_EQ(contains_num, found_num); + + CUDA_CHECK(cudaMemset(d_new_vectors, 2, KEY_NUM * sizeof(V) * options.dim)); + table->insert_or_assign(KEY_NUM, d_keys, + reinterpret_cast(d_new_vectors), d_scores, + stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + total_size = table->size(stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + ASSERT_EQ(total_size, KEY_NUM); + + CUDA_CHECK(cudaMemset(d_new_vectors, 0,
KEY_NUM * sizeof(V) * options.dim)); + table->find(KEY_NUM, d_keys, reinterpret_cast(d_new_vectors), + d_found, nullptr, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + CUDA_CHECK(cudaMemcpy(h_found, d_found, KEY_NUM * sizeof(bool), + cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy(h_vectors, d_new_vectors, + KEY_NUM * sizeof(V) * options.dim, + cudaMemcpyDeviceToHost)); + found_num = 0; + uint32_t i_value = 0x2020202; + for (int i = 0; i < KEY_NUM; i++) { + if (h_found[i]) found_num++; + for (int j = 0; j < options.dim; j++) { + ASSERT_EQ(h_vectors[i * options.dim + j], + *(reinterpret_cast(&i_value))); + } + } + ASSERT_EQ(found_num, KEY_NUM); + + CUDA_CHECK(cudaMemset(d_found, 0, KEY_NUM * sizeof(bool))); + table->contains(KEY_NUM, d_keys, d_found, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + contains_num = 0; + CUDA_CHECK(cudaMemcpy(h_found, d_found, KEY_NUM * sizeof(bool), + cudaMemcpyDeviceToHost)); + for (int i = 0; i < KEY_NUM; i++) { + if (h_found[i]) contains_num++; + } + ASSERT_EQ(contains_num, found_num); + + table->accum_or_assign(KEY_NUM, d_keys, d_vectors, d_found, d_scores, + stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + total_size = table->size(stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + ASSERT_EQ(total_size, KEY_NUM); + + table->erase(KEY_NUM >> 1, d_keys, stream); + size_t total_size_after_erase = table->size(stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + ASSERT_EQ(total_size_after_erase, total_size >> 1); + + table->clear(stream); + total_size = table->size(stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + ASSERT_EQ(total_size, 0); + + table->insert_or_assign(KEY_NUM, d_keys, d_vectors, d_scores, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + CUDA_CHECK(cudaMemset(d_scores, 0, KEY_NUM * sizeof(S))); + CUDA_CHECK(cudaMemset(d_vectors, 0, KEY_NUM * sizeof(V) * options.dim)); + + table->find(KEY_NUM, d_keys, d_vectors, d_found, d_scores, stream); +
CUDA_CHECK(cudaStreamSynchronize(stream)); + + found_num = 0; + CUDA_CHECK(cudaMemcpy(h_found, d_found, KEY_NUM * sizeof(bool), + cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy(h_scores, d_scores, KEY_NUM * sizeof(S), + cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy(h_vectors, d_vectors, + KEY_NUM * sizeof(V) * options.dim, + cudaMemcpyDeviceToHost)); + + for (int i = 0; i < KEY_NUM; i++) { + if (h_found[i]) found_num++; + ASSERT_EQ(h_scores[i], h_keys[i]); + for (int j = 0; j < options.dim; j++) { + ASSERT_EQ(h_vectors[i * options.dim + j], + static_cast(h_keys[i] * 0.00001)); + } + } + ASSERT_EQ(found_num, KEY_NUM); + + CUDA_CHECK(cudaMemset(d_found, 0, KEY_NUM * sizeof(bool))); + table->contains(KEY_NUM, d_keys, d_found, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + contains_num = 0; + CUDA_CHECK(cudaMemcpy(h_found, d_found, KEY_NUM * sizeof(bool), + cudaMemcpyDeviceToHost)); + for (int i = 0; i < KEY_NUM; i++) { + if (h_found[i]) contains_num++; + } + ASSERT_EQ(contains_num, found_num); + + CUDA_CHECK(cudaMemset(d_keys, 0, KEY_NUM * sizeof(K))); + CUDA_CHECK(cudaMemset(d_scores, 0, KEY_NUM * sizeof(S))); + CUDA_CHECK(cudaMemset(d_vectors, 0, KEY_NUM * sizeof(V) * options.dim)); + dump_counter = table->export_batch(table->capacity(), 0, d_keys, d_vectors, + d_scores, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + ASSERT_EQ(dump_counter, KEY_NUM); + CUDA_CHECK(cudaMemcpy(h_keys, d_keys, KEY_NUM * sizeof(K), + cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy(h_vectors, d_vectors, + KEY_NUM * sizeof(V) * options.dim, + cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy(h_scores, d_scores, KEY_NUM * sizeof(S), + cudaMemcpyDeviceToHost)); + for (int i = 0; i < KEY_NUM; i++) { + ASSERT_EQ(h_scores[i], h_keys[i]); + for (int j = 0; j < options.dim; j++) { + ASSERT_EQ(h_vectors[i * options.dim + j], + static_cast(h_keys[i] * 0.00001)); + } + } + } + CUDA_CHECK(cudaStreamDestroy(stream)); + + CUDA_CHECK(cudaFreeHost(h_keys)); +
CUDA_CHECK(cudaFreeHost(h_scores)); + CUDA_CHECK(cudaFreeHost(h_vectors)); + CUDA_CHECK(cudaFreeHost(h_found)); + + CUDA_CHECK(cudaFree(d_keys)); + CUDA_CHECK(cudaFree(d_scores)); + CUDA_CHECK(cudaFree(d_vectors)); + CUDA_CHECK(cudaFree(d_new_vectors)); + CUDA_CHECK(cudaFree(d_found)); + CUDA_CHECK(cudaDeviceSynchronize()); + + CudaCheckError(); +} + template void test_find_using_pipeline(int dim, bool load_scores) { using TableOptions = nv::merlin::HashTableOptions; @@ -3484,6 +3733,10 @@ TEST(MerlinHashTableTest, test_basic) { test_basic(16); test_basic(0); } +TEST(MerlinHashTableTest, test_basic_without_rehash) { + test_basic_without_rehash(16); + test_basic_without_rehash(0); +} TEST(MerlinHashTableTest, test_bucket_size) { test_bucket_size(); } TEST(MerlinHashTableTest, test_find_using_pipeline) { test_find_using_pipeline(224, true);