[Fix] out-of-bound issue and add test case for it
rhdong committed Jun 12, 2024
1 parent 1207b13 commit 97bea93
Showing 2 changed files with 260 additions and 6 deletions.
13 changes: 7 additions & 6 deletions include/merlin/core_kernels.cuh
@@ -87,11 +87,11 @@ __global__ void allocate_bucket_vectors(Bucket<K, V, S>* __restrict buckets,
 template <class K, class V, class S>
 __global__ void allocate_bucket_others(Bucket<K, V, S>* __restrict buckets,
                                        size_t total_size_per_bucket,
-                                       size_t num_of_buckets_per_alloc,
+                                       size_t num_of_buckets,
                                        const int start_index, uint8_t* address,
                                        const uint32_t reserve_size,
                                        const size_t bucket_max_size) {
-  for (size_t step = 0; step < num_of_buckets_per_alloc; step++) {
+  for (size_t step = 0; step < num_of_buckets; step++) {
     size_t index = start_index + step;
     buckets[index].digests_ = address;
     buckets[index].keys_ =
@@ -238,12 +238,13 @@ void initialize_buckets(Table<K, V, S>** table, BaseAllocator* allocator,
   */
   for (int i = start; i < end; i += (*table)->num_of_buckets_per_alloc) {
     uint8_t* address = nullptr;
+    size_t num_of_buckets =
+        std::min(end - i, (*table)->num_of_buckets_per_alloc);
     allocator->alloc(MemoryType::Device, (void**)&(address),
-                     bucket_memory_size * (*table)->num_of_buckets_per_alloc);
+                     bucket_memory_size * num_of_buckets);
     allocate_bucket_others<K, V, S>
-        <<<1, 1>>>((*table)->buckets, bucket_memory_size,
-                   (*table)->num_of_buckets_per_alloc, i, address, reserve_size,
-                   bucket_max_size);
+        <<<1, 1>>>((*table)->buckets, bucket_memory_size, num_of_buckets, i,
+                   address, reserve_size, bucket_max_size);
   }
   CUDA_CHECK(cudaDeviceSynchronize());

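Before this change, every iteration of the chunked loop in initialize_buckets allocated and initialized num_of_buckets_per_alloc buckets, even on the final chunk where fewer than that remain (end - i), so allocate_bucket_others could step over bucket indices beyond end and the device allocation was larger than the remaining buckets need. Clamping with std::min sizes both the allocation and the kernel's loop to the buckets actually left. The following is a minimal host-side sketch of the same clamping pattern; the names total, chunk, and buckets are illustrative stand-ins, not identifiers from the library.

// A minimal host-side sketch of clamped, chunked initialization
// (illustrative names; not the library's API).
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  const std::size_t total = 522241;  // e.g. 255 * 2048 + 1 buckets
  const std::size_t chunk = 2048;    // buckets handled per allocation
  std::vector<int> buckets(total, 0);

  for (std::size_t i = 0; i < total; i += chunk) {
    // The fix: on the final, partial chunk, touch only the remaining buckets.
    const std::size_t num_of_buckets = std::min(total - i, chunk);
    for (std::size_t step = 0; step < num_of_buckets; ++step) {
      buckets.at(i + step) = 1;  // .at() would throw with the unclamped bound
    }
  }
  std::cout << "initialized " << total << " buckets without overrun\n";
}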
253 changes: 253 additions & 0 deletions tests/merlin_hashtable_test.cc.cu
@@ -373,6 +373,255 @@ void test_basic(size_t max_hbm_for_vectors) {
CudaCheckError();
}

void test_basic_without_rehash(size_t max_hbm_for_vectors) {
constexpr uint64_t BUCKET_MAX_SIZE = 128;
constexpr uint64_t NUM_OF_BUCKETS_PER_ALLOC = 2048;
constexpr uint64_t INIT_CAPACITY =
64 * 1024 * 1024UL - (NUM_OF_BUCKETS_PER_ALLOC * BUCKET_MAX_SIZE) + 1;
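// With max_bucket_size = 128 this capacity yields 522241 buckets
// (255 * 2048 + 1), so the final allocation group holds a single bucket,
// which is the partial final chunk the fix above guards against.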
constexpr uint64_t MAX_CAPACITY = INIT_CAPACITY;
constexpr uint64_t KEY_NUM = 1 * 1024 * 1024UL;
constexpr uint64_t TEST_TIMES = 1;

K* h_keys;
S* h_scores;
V* h_vectors;
bool* h_found;

TableOptions options;

options.init_capacity = INIT_CAPACITY;
options.max_capacity = MAX_CAPACITY;
options.dim = DIM;
options.max_bucket_size = BUCKET_MAX_SIZE;
options.max_hbm_for_vectors = nv::merlin::GB(max_hbm_for_vectors);
options.reserved_key_start_bit = 2;
options.num_of_buckets_per_alloc = NUM_OF_BUCKETS_PER_ALLOC;

using Table = nv::merlin::HashTable<K, V, S, EvictStrategy::kCustomized>;

CUDA_CHECK(cudaMallocHost(&h_keys, KEY_NUM * sizeof(K)));
CUDA_CHECK(cudaMallocHost(&h_scores, KEY_NUM * sizeof(S)));
CUDA_CHECK(cudaMallocHost(&h_vectors, KEY_NUM * sizeof(V) * options.dim));
CUDA_CHECK(cudaMallocHost(&h_found, KEY_NUM * sizeof(bool)));

CUDA_CHECK(cudaMemset(h_vectors, 0, KEY_NUM * sizeof(V) * options.dim));

test_util::create_random_keys<K, S, V, DIM>(h_keys, h_scores, h_vectors,
KEY_NUM);

K* d_keys;
S* d_scores = nullptr;
V* d_vectors;
V* d_new_vectors;
bool* d_found;
size_t dump_counter = 0;

CUDA_CHECK(cudaMalloc(&d_keys, KEY_NUM * sizeof(K)));
CUDA_CHECK(cudaMalloc(&d_scores, KEY_NUM * sizeof(S)));
CUDA_CHECK(cudaMalloc(&d_vectors, KEY_NUM * sizeof(V) * options.dim));
CUDA_CHECK(cudaMalloc(&d_new_vectors, KEY_NUM * sizeof(V) * options.dim));
CUDA_CHECK(cudaMalloc(&d_found, KEY_NUM * sizeof(bool)));

CUDA_CHECK(
cudaMemcpy(d_keys, h_keys, KEY_NUM * sizeof(K), cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(d_scores, h_scores, KEY_NUM * sizeof(S),
cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(d_vectors, h_vectors, KEY_NUM * sizeof(V) * options.dim,
cudaMemcpyHostToDevice));

CUDA_CHECK(cudaMemset(d_found, 0, KEY_NUM * sizeof(bool)));

cudaStream_t stream;
CUDA_CHECK(cudaStreamCreate(&stream));

uint64_t total_size = 0;
for (int i = 0; i < TEST_TIMES; i++) {
std::unique_ptr<Table> table = std::make_unique<Table>();
table->init(options);

ASSERT_EQ(table->bucket_count(),
522241); // 1 + (INIT_CAPACITY / options.max_bucket_size)
total_size = table->size(stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
ASSERT_EQ(total_size, 0);

table->insert_or_assign(KEY_NUM, d_keys, d_vectors, d_scores, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));

total_size = table->size(stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
ASSERT_EQ(total_size, KEY_NUM);

CUDA_CHECK(cudaMemset(d_vectors, 0, KEY_NUM * sizeof(V) * options.dim));
table->find(KEY_NUM, d_keys, d_vectors, d_found, nullptr, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
int found_num = 0;
CUDA_CHECK(cudaMemcpy(h_found, d_found, KEY_NUM * sizeof(bool),
cudaMemcpyDeviceToHost));
CUDA_CHECK(cudaMemcpy(h_scores, d_scores, KEY_NUM * sizeof(S),
cudaMemcpyDeviceToHost));
CUDA_CHECK(cudaMemcpy(h_vectors, d_vectors,
KEY_NUM * sizeof(V) * options.dim,
cudaMemcpyDeviceToHost));

for (int i = 0; i < KEY_NUM; i++) {
if (h_found[i]) found_num++;
ASSERT_EQ(h_scores[i], h_keys[i]);
for (int j = 0; j < options.dim; j++) {
ASSERT_EQ(h_vectors[i * options.dim + j],
static_cast<float>(h_keys[i] * 0.00001));
}
}

CUDA_CHECK(cudaMemset(d_found, 0, KEY_NUM * sizeof(bool)));
table->contains(KEY_NUM, d_keys, d_found, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
int contains_num = 0;
CUDA_CHECK(cudaMemcpy(h_found, d_found, KEY_NUM * sizeof(bool),
cudaMemcpyDeviceToHost));
for (int i = 0; i < KEY_NUM; i++) {
if (h_found[i]) contains_num++;
}
ASSERT_EQ(contains_num, found_num);

CUDA_CHECK(cudaMemset(d_new_vectors, 2, KEY_NUM * sizeof(V) * options.dim));
table->insert_or_assign(KEY_NUM, d_keys,
reinterpret_cast<float*>(d_new_vectors), d_scores,
stream);
CUDA_CHECK(cudaStreamSynchronize(stream));

total_size = table->size(stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
ASSERT_EQ(total_size, KEY_NUM);

CUDA_CHECK(cudaMemset(d_new_vectors, 0, KEY_NUM * sizeof(V) * options.dim));
table->find(KEY_NUM, d_keys, reinterpret_cast<float*>(d_new_vectors),
d_found, nullptr, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
CUDA_CHECK(cudaMemcpy(h_found, d_found, KEY_NUM * sizeof(bool),
cudaMemcpyDeviceToHost));
CUDA_CHECK(cudaMemcpy(h_vectors, d_new_vectors,
KEY_NUM * sizeof(V) * options.dim,
cudaMemcpyDeviceToHost));
found_num = 0;
uint32_t i_value = 0x2020202;
for (int i = 0; i < KEY_NUM; i++) {
if (h_found[i]) found_num++;
for (int j = 0; j < options.dim; j++) {
ASSERT_EQ(h_vectors[i * options.dim + j],
*(reinterpret_cast<float*>(&i_value)));
}
}
ASSERT_EQ(found_num, KEY_NUM);

CUDA_CHECK(cudaMemset(d_found, 0, KEY_NUM * sizeof(bool)));
table->contains(KEY_NUM, d_keys, d_found, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
contains_num = 0;
CUDA_CHECK(cudaMemcpy(h_found, d_found, KEY_NUM * sizeof(bool),
cudaMemcpyDeviceToHost));
for (int i = 0; i < KEY_NUM; i++) {
if (h_found[i]) contains_num++;
}
ASSERT_EQ(contains_num, found_num);

table->accum_or_assign(KEY_NUM, d_keys, d_vectors, d_found, d_scores,
stream);
CUDA_CHECK(cudaStreamSynchronize(stream));

total_size = table->size(stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
ASSERT_EQ(total_size, KEY_NUM);

table->erase(KEY_NUM >> 1, d_keys, stream);
size_t total_size_after_erase = table->size(stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
ASSERT_EQ(total_size_after_erase, total_size >> 1);

table->clear(stream);
total_size = table->size(stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
ASSERT_EQ(total_size, 0);

table->insert_or_assign(KEY_NUM, d_keys, d_vectors, d_scores, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));

CUDA_CHECK(cudaMemset(d_scores, 0, KEY_NUM * sizeof(S)));
CUDA_CHECK(cudaMemset(d_vectors, 0, KEY_NUM * sizeof(V) * options.dim));

table->find(KEY_NUM, d_keys, d_vectors, d_found, d_scores, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));

found_num = 0;
CUDA_CHECK(cudaMemcpy(h_found, d_found, KEY_NUM * sizeof(bool),
cudaMemcpyDeviceToHost));
CUDA_CHECK(cudaMemcpy(h_scores, d_scores, KEY_NUM * sizeof(S),
cudaMemcpyDeviceToHost));
CUDA_CHECK(cudaMemcpy(h_vectors, d_vectors,
KEY_NUM * sizeof(V) * options.dim,
cudaMemcpyDeviceToHost));

for (int i = 0; i < KEY_NUM; i++) {
if (h_found[i]) found_num++;
ASSERT_EQ(h_scores[i], h_keys[i]);
for (int j = 0; j < options.dim; j++) {
ASSERT_EQ(h_vectors[i * options.dim + j],
static_cast<float>(h_keys[i] * 0.00001));
}
}
ASSERT_EQ(found_num, KEY_NUM);

CUDA_CHECK(cudaMemset(d_found, 0, KEY_NUM * sizeof(bool)));
table->contains(KEY_NUM, d_keys, d_found, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
contains_num = 0;
CUDA_CHECK(cudaMemcpy(h_found, d_found, KEY_NUM * sizeof(bool),
cudaMemcpyDeviceToHost));
for (int i = 0; i < KEY_NUM; i++) {
if (h_found[i]) contains_num++;
}
ASSERT_EQ(contains_num, found_num);

CUDA_CHECK(cudaMemset(d_keys, 0, KEY_NUM * sizeof(K)));
CUDA_CHECK(cudaMemset(d_scores, 0, KEY_NUM * sizeof(S)));
CUDA_CHECK(cudaMemset(d_vectors, 0, KEY_NUM * sizeof(V) * options.dim));
dump_counter = table->export_batch(table->capacity(), 0, d_keys, d_vectors,
d_scores, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));

ASSERT_EQ(dump_counter, KEY_NUM);
CUDA_CHECK(cudaMemcpy(h_keys, d_keys, KEY_NUM * sizeof(K),
cudaMemcpyDeviceToHost));
CUDA_CHECK(cudaMemcpy(h_vectors, d_vectors,
KEY_NUM * sizeof(V) * options.dim,
cudaMemcpyDeviceToHost));
CUDA_CHECK(cudaMemcpy(h_scores, d_scores, KEY_NUM * sizeof(S),
cudaMemcpyDeviceToHost));
for (int i = 0; i < KEY_NUM; i++) {
ASSERT_EQ(h_scores[i], h_keys[i]);
for (int j = 0; j < options.dim; j++) {
ASSERT_EQ(h_vectors[i * options.dim + j],
static_cast<float>(h_keys[i] * 0.00001));
}
}
}
CUDA_CHECK(cudaStreamDestroy(stream));

CUDA_CHECK(cudaFreeHost(h_keys));
CUDA_CHECK(cudaFreeHost(h_scores));
CUDA_CHECK(cudaFreeHost(h_vectors));
CUDA_CHECK(cudaFreeHost(h_found));

CUDA_CHECK(cudaFree(d_keys));
CUDA_CHECK(cudaFree(d_scores));
CUDA_CHECK(cudaFree(d_vectors));
CUDA_CHECK(cudaFree(d_new_vectors));
CUDA_CHECK(cudaFree(d_found));
CUDA_CHECK(cudaDeviceSynchronize());

CudaCheckError();
}

template <typename V>
void test_find_using_pipeline(int dim, bool load_scores) {
using TableOptions = nv::merlin::HashTableOptions;
@@ -3484,6 +3733,10 @@ TEST(MerlinHashTableTest, test_basic) {
test_basic(16);
test_basic(0);
}
TEST(MerlinHashTableTest, test_basic_without_rehash) {
test_basic_without_rehash(16);
test_basic_without_rehash(0);
}
TEST(MerlinHashTableTest, test_bucket_size) { test_bucket_size(); }
TEST(MerlinHashTableTest, test_find_using_pipeline) {
test_find_using_pipeline<int32_t>(224, true);
