Skip to content

Commit

Permalink
patch1
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong committed Oct 11, 2023
1 parent 31be3e8 commit 8a5cd29
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 32 deletions.
68 changes: 36 additions & 32 deletions include/merlin/core_kernels/contains.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,10 @@ __global__ void contains_kernel_pipeline(Bucket<K, V, S>* buckets,
/* Step3: check possible keys, and prefetch the value and score */
if (i > 0) {
key_idx_block -= 1;
int key_idx_grid = blockIdx.x * blockDim.x + key_idx_block;
K target_key = sm_target_keys[key_idx_block];
int possible_num = sm_counts[key_idx_block];
sm_founds[key_idx_block] = 0;
// sm_founds[key_idx_block] = 0;
__pipeline_wait_prior(2);
bool found_flag = false;
if (rank < possible_num) {
Expand All @@ -191,27 +192,29 @@ __global__ void contains_kernel_pipeline(Bucket<K, V, S>* buckets,
}
}
int found_vote = g.ballot(found_flag);
if (found_vote) {
sm_founds[key_idx_block] = 1;
}
// if (found_vote) {
// sm_founds[key_idx_block] = 1;
// }
founds[key_idx_grid] = (found_vote > 0);
}

/* Step4: write `found` */
if (i > 1) {
key_idx_block -= 1;
int key_idx_grid = blockIdx.x * blockDim.x + key_idx_block;
int found_flag = sm_founds[key_idx_block];
__pipeline_wait_prior(2);
founds[key_idx_grid] = (found_flag > 0);
}
// if (i > 1) {
// key_idx_block -= 1;
// int key_idx_grid = blockIdx.x * blockDim.x + key_idx_block;
// int found_flag = sm_founds[key_idx_block];
// __pipeline_wait_prior(2);
// founds[key_idx_grid] = (found_flag > 0);
// }
} // End loop

/* Pipeline emptying: step3, i = loop_num */
{
int key_idx_block = groupID * GROUP_SIZE + (loop_num - 1);
int key_idx_grid = blockIdx.x * blockDim.x + key_idx_block;
K target_key = sm_target_keys[key_idx_block];
int possible_num = sm_counts[key_idx_block];
sm_founds[key_idx_block] = 0;
// sm_founds[key_idx_block] = 0;
__pipeline_wait_prior(0);
bool found_flag = false;
if (rank < possible_num) {
Expand All @@ -222,28 +225,29 @@ __global__ void contains_kernel_pipeline(Bucket<K, V, S>* buckets,
}
}
int found_vote = g.ballot(found_flag);
if (found_vote) {
sm_founds[key_idx_block] = 1;
}
}
__pipeline_commit();

/* Pipeline emptying: step4, i = loop_num */
if (loop_num > 1) {
int key_idx_block = groupID * GROUP_SIZE + loop_num - 2;
int key_idx_grid = blockIdx.x * blockDim.x + key_idx_block;
int found_flag = sm_founds[key_idx_block];
__pipeline_wait_prior(0);
founds[key_idx_grid] = (found_flag > 0);
// if (found_vote) {
// sm_founds[key_idx_block] = 1;
// }
founds[key_idx_grid] = (found_vote > 0);
}
// __pipeline_commit();

/* Pipeline emptying: step4, i = loop_num + 1 */
{
int key_idx_block = groupID * GROUP_SIZE + loop_num - 1;
int key_idx_grid = blockIdx.x * blockDim.x + key_idx_block;
int found_flag = sm_founds[key_idx_block];
founds[key_idx_grid] = (found_flag > 0);
}
// /* Pipeline emptying: step4, i = loop_num */
// if (loop_num > 1) {
// int key_idx_block = groupID * GROUP_SIZE + loop_num - 2;
// int key_idx_grid = blockIdx.x * blockDim.x + key_idx_block;
// int found_flag = sm_founds[key_idx_block];
// __pipeline_wait_prior(0);
// founds[key_idx_grid] = (found_flag > 0);
// }
//
// /* Pipeline emptying: step4, i = loop_num + 1 */
// {
// int key_idx_block = groupID * GROUP_SIZE + loop_num - 1;
// int key_idx_grid = blockIdx.x * blockDim.x + key_idx_block;
// int found_flag = sm_founds[key_idx_block];
// founds[key_idx_grid] = (found_flag > 0);
// }
} // End function

template <typename K, typename V, typename S>
Expand Down
6 changes: 6 additions & 0 deletions tests/insert_and_evict_test.cc.cu
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ void CheckInsertAndEvict(Table* table, K* keys, V* values, S* scores,
std::cout << "filtered_len:" << filtered_len
<< ", miss counter:" << len - found_counter << std::endl;

CUDA_CHECK(cudaMemset(d_tmp_founds, 0, len * sizeof(bool)));
table->contains(len, keys, d_tmp_founds, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
int contains_counter = 0;
Expand Down Expand Up @@ -475,6 +476,7 @@ void CheckInsertAndEvictOnLfu(Table* table,
if (h_tmp_founds[i]) found_counter++;
}

CUDA_CHECK(cudaMemset(d_tmp_founds, 0, len * sizeof(bool)));
table->contains(len, keys, d_tmp_founds, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
int contains_counter = 0;
Expand Down Expand Up @@ -749,6 +751,7 @@ void CheckInsertAndEvictOnEpochLru(Table* table,
<< ", miss counter:" << len - found_counter << std::endl;
ASSERT_EQ(len, found_counter);

CUDA_CHECK(cudaMemset(d_tmp_founds, 0, len * sizeof(bool)));
table->contains(len, keys, d_tmp_founds, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
int contains_counter = 0;
Expand Down Expand Up @@ -1029,6 +1032,7 @@ void CheckInsertAndEvictOnEpochLfu(
<< ", pre_data miss counter:" << len - found_counter << std::endl;
ASSERT_EQ(len, found_counter);

CUDA_CHECK(cudaMemset(d_tmp_founds, 0, len * sizeof(bool)));
table->contains(len, pre_data_buffer->keys_ptr(), d_tmp_founds, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
int contains_counter = 0;
Expand Down Expand Up @@ -1059,6 +1063,7 @@ void CheckInsertAndEvictOnEpochLfu(
<< ", miss counter:" << len - found_counter << std::endl;
ASSERT_EQ(len, found_counter);

CUDA_CHECK(cudaMemset(d_tmp_founds, 0, len * sizeof(bool)));
table->contains(len, keys, d_tmp_founds, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
int contains_counter = 0;
Expand Down Expand Up @@ -1606,6 +1611,7 @@ void BatchCheckFind(Table* table, K* keys, V* values, S* scores, size_t len,
}
ASSERT_EQ(value_diff_cnt, 0);

CUDA_CHECK(cudaMemset(d_tmp_founds, 0, cap * sizeof(bool)));
table->contains(cap, keys, d_tmp_founds, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
int contains_num = 0;
Expand Down
13 changes: 13 additions & 0 deletions tests/merlin_hashtable_test.cc.cu
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ void test_basic(size_t max_hbm_for_vectors) {
}
}

CUDA_CHECK(cudaMemset(d_found, 0, KEY_NUM * sizeof(bool)));
table->contains(KEY_NUM, d_keys, d_found, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
int contains_num = 0;
Expand Down Expand Up @@ -261,6 +262,7 @@ void test_basic(size_t max_hbm_for_vectors) {
}
ASSERT_EQ(found_num, KEY_NUM);

CUDA_CHECK(cudaMemset(d_found, 0, KEY_NUM * sizeof(bool)));
table->contains(KEY_NUM, d_keys, d_found, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
contains_num = 0;
Expand Down Expand Up @@ -317,6 +319,7 @@ void test_basic(size_t max_hbm_for_vectors) {
}
ASSERT_EQ(found_num, KEY_NUM);

CUDA_CHECK(cudaMemset(d_found, 0, KEY_NUM * sizeof(bool)));
table->contains(KEY_NUM, d_keys, d_found, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
contains_num = 0;
Expand Down Expand Up @@ -466,6 +469,7 @@ void test_find_using_pipeline(int dim, bool load_scores) {
}
ASSERT_EQ(found_num, KEY_NUM);

CUDA_CHECK(cudaMemset(d_found, 0, KEY_NUM * sizeof(bool)));
table->contains(KEY_NUM, d_keys, d_found, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
int contains_num = 0;
Expand Down Expand Up @@ -698,6 +702,7 @@ void test_erase_if_pred(size_t max_hbm_for_vectors) {
}
ASSERT_EQ(found_num, (BUCKET_MAX_SIZE - erase_num));

CUDA_CHECK(cudaMemset(d_found, 0, KEY_NUM * sizeof(bool)));
table->contains(KEY_NUM, d_keys, d_found, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
int contains_num = 0;
Expand Down Expand Up @@ -839,6 +844,7 @@ void test_rehash(size_t max_hbm_for_vectors) {
}
ASSERT_EQ(found_num, BUCKET_MAX_SIZE);

CUDA_CHECK(cudaMemset(d_found, 0, KEY_NUM * sizeof(bool)));
table->contains(KEY_NUM, d_keys, d_found, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
int contains_num = 0;
Expand Down Expand Up @@ -983,6 +989,7 @@ void test_rehash_on_big_batch(size_t max_hbm_for_vectors) {
}
ASSERT_EQ(found_num, KEY_NUM);

CUDA_CHECK(cudaMemset(d_found, 0, KEY_NUM * sizeof(bool)));
table->contains(KEY_NUM, d_keys, d_found, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
int contains_num = 0;
Expand Down Expand Up @@ -1171,6 +1178,7 @@ void test_dynamic_rehash_on_multi_threads(size_t max_hbm_for_vectors) {
}
ASSERT_EQ(found_num, KEY_NUM);

CUDA_CHECK(cudaMemset(d_found, 0, KEY_NUM * sizeof(bool)));
table->contains(KEY_NUM, d_keys, d_found, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
int contains_num = 0;
Expand Down Expand Up @@ -1315,6 +1323,7 @@ void test_export_batch_if(size_t max_hbm_for_vectors) {
}
ASSERT_EQ(found_num, KEY_NUM);

CUDA_CHECK(cudaMemset(d_found, 0, KEY_NUM * sizeof(bool)));
table->contains(KEY_NUM, d_keys, d_found, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
int contains_num = 0;
Expand Down Expand Up @@ -1492,6 +1501,7 @@ void test_basic_for_cpu_io() {
}
ASSERT_EQ(found_num, KEY_NUM);

CUDA_CHECK(cudaMemset(d_found, 0, KEY_NUM * sizeof(bool)));
table->contains(KEY_NUM, d_keys, d_found, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
int contains_num = 0;
Expand Down Expand Up @@ -2908,6 +2918,7 @@ void test_insert_or_assign_multi_threads(size_t max_hbm_for_vectors,
}
}

CUDA_CHECK(cudaMemset(d_found, 0, KEY_NUM * sizeof(bool)));
table->contains(KEY_NUM, d_keys, d_found, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
int contains_num = 0;
Expand Down Expand Up @@ -3040,6 +3051,7 @@ void test_insert_or_assign_multi_threads(size_t max_hbm_for_vectors,
}
}

CUDA_CHECK(cudaMemset(d_found, 0, KEY_NUM * sizeof(bool)));
table->contains(KEY_NUM, d_keys, d_found, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
int contains_num = 0;
Expand Down Expand Up @@ -3375,6 +3387,7 @@ void test_bucket_size(bool load_scores = true) {
}
ASSERT_EQ(found_num, KEY_NUM);

CUDA_CHECK(cudaMemset(d_found, 0, KEY_NUM * sizeof(bool)));
table->contains(KEY_NUM, d_keys, d_found, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
int contains_num = 0;
Expand Down

0 comments on commit 8a5cd29

Please sign in to comment.