From c0e59677cc4df5b88fd55519dd04effd8f056c98 Mon Sep 17 00:00:00 2001
From: rhdong
Date: Thu, 20 Jul 2023 19:18:36 +0800
Subject: [PATCH] [Fix] rehash dead loop on specific MAX_CAPACITY.

---
 include/merlin_hashtable.cuh      |  1 +
 tests/merlin_hashtable_test.cc.cu | 90 +++++++++++++++++++++++++++++++
 2 files changed, 91 insertions(+)

diff --git a/include/merlin_hashtable.cuh b/include/merlin_hashtable.cuh
index 97c80342a..b026f8be5 100644
--- a/include/merlin_hashtable.cuh
+++ b/include/merlin_hashtable.cuh
@@ -1421,6 +1421,7 @@ class HashTable {
    */
   void reserve(const size_type new_capacity, cudaStream_t stream = 0) {
     if (reach_max_capacity_ || new_capacity > options_.max_capacity) {
+      reach_max_capacity_ = (capacity() * 2 > options_.max_capacity);
       return;
     }
 
diff --git a/tests/merlin_hashtable_test.cc.cu b/tests/merlin_hashtable_test.cc.cu
index 37b71a3ae..94b289a2d 100644
--- a/tests/merlin_hashtable_test.cc.cu
+++ b/tests/merlin_hashtable_test.cc.cu
@@ -866,6 +866,92 @@ void test_rehash_on_big_batch(size_t max_hbm_for_vectors) {
   CudaCheckError();
 }
 
+// Regression test for the rehash dead loop on a specific MAX_CAPACITY.
+//
+// Builds a table with init_capacity = 50000 and max_capacity = 100000, inserts
+// 50000 random keys with a single insert_or_assign call, and expects
+// capacity() to equal 65536 afterwards.
+// NOTE(review): 65536 is presumably the initial capacity rounded up to a power
+// of two, with the next doubling (131072) exceeding max_capacity -- confirm
+// against HashTable::init() and HashTable::reserve().
+//
+// max_hbm_for_vectors is forwarded through nv::merlin::GB() into
+// options.max_hbm_for_vectors.
+void test_rehash_on_big_batch_specific(size_t max_hbm_for_vectors) {
+  constexpr uint64_t INIT_CAPACITY = 50000;
+  constexpr uint64_t MAX_CAPACITY = 100000;
+  constexpr uint64_t EXPECTED_MAX_CAPACITY = 65536;
+  constexpr uint64_t KEY_NUM = 50000;
+  K* h_keys;
+  S* h_scores;
+  V* h_vectors;
+
+  TableOptions options;
+
+  options.init_capacity = INIT_CAPACITY;
+  options.max_capacity = MAX_CAPACITY;
+  options.dim = DIM;
+  options.max_bucket_size = 128;
+  options.max_load_factor = 0.6;
+  options.max_hbm_for_vectors = nv::merlin::GB(max_hbm_for_vectors);
+  options.evict_strategy = nv::merlin::EvictStrategy::kLru;
+
+  CUDA_CHECK(cudaMallocHost(&h_keys, KEY_NUM * sizeof(K)));
+  CUDA_CHECK(cudaMallocHost(&h_scores, KEY_NUM * sizeof(S)));
+  CUDA_CHECK(cudaMallocHost(&h_vectors, KEY_NUM * sizeof(V) * options.dim));
+
+  K* d_keys;
+  S* d_scores = nullptr;
+  V* d_vectors;
+
+  CUDA_CHECK(cudaMalloc(&d_keys, KEY_NUM * sizeof(K)));
+  CUDA_CHECK(cudaMalloc(&d_scores, KEY_NUM * sizeof(S)));
+  CUDA_CHECK(cudaMalloc(&d_vectors, KEY_NUM * sizeof(V) * options.dim));
+
+  cudaStream_t stream;
+  CUDA_CHECK(cudaStreamCreate(&stream));
+
+  uint64_t total_size = 0;
+  std::unique_ptr<Table> table = std::make_unique<Table>();
+  table->init(options);
+
+  test_util::create_random_keys<K, S, V, DIM>(h_keys, h_scores, h_vectors,
+                                              KEY_NUM);
+
+  CUDA_CHECK(
+      cudaMemcpy(d_keys, h_keys, KEY_NUM * sizeof(K), cudaMemcpyHostToDevice));
+  CUDA_CHECK(cudaMemcpy(d_scores, h_scores, KEY_NUM * sizeof(S),
+                        cudaMemcpyHostToDevice));
+  CUDA_CHECK(cudaMemcpy(d_vectors, h_vectors, KEY_NUM * sizeof(V) * options.dim,
+                        cudaMemcpyHostToDevice));
+
+  // The table must report zero entries before anything is inserted.
+  total_size = table->size(stream);
+  CUDA_CHECK(cudaStreamSynchronize(stream));
+  ASSERT_EQ(total_size, 0);
+
+  // Insert the whole batch in one call. The fourth argument is nullptr
+  // (presumably the scores slot -- confirm); d_scores is filled above but is
+  // never handed to the table. Then check the resulting capacity.
+  table->insert_or_assign(KEY_NUM, d_keys, d_vectors, nullptr, stream);
+  CUDA_CHECK(cudaStreamSynchronize(stream));
+  ASSERT_EQ(table->capacity(), EXPECTED_MAX_CAPACITY);
+
+  // Release the stream and all host/device buffers.
+  CUDA_CHECK(cudaStreamDestroy(stream));
+
+  CUDA_CHECK(cudaFreeHost(h_keys));
+  CUDA_CHECK(cudaFreeHost(h_scores));
+  CUDA_CHECK(cudaFreeHost(h_vectors));
+
+  CUDA_CHECK(cudaFree(d_keys));
+  CUDA_CHECK(cudaFree(d_scores));
+  CUDA_CHECK(cudaFree(d_vectors));
+  CUDA_CHECK(cudaDeviceSynchronize());
+
+  CudaCheckError();
+}
+
 void test_dynamic_rehash_on_multi_threads(size_t max_hbm_for_vectors) {
   constexpr uint64_t BUCKET_MAX_SIZE = 128ul;
   constexpr uint64_t INIT_CAPACITY = 4 * 1024 - BUCKET_MAX_SIZE - 1;
@@ -2592,6 +2678,10 @@ TEST(MerlinHashTableTest, test_rehash) {
   test_rehash(16);
   test_rehash(0);
 }
+TEST(MerlinHashTableTest, test_rehash_on_big_batch_specific) {
+  test_rehash_on_big_batch_specific(16);
+  test_rehash_on_big_batch_specific(0);
+}
 TEST(MerlinHashTableTest, test_rehash_on_big_batch) {
   test_rehash_on_big_batch(16);
   test_rehash_on_big_batch(0);