diff --git a/review/pr-177/README.html b/review/pr-177/README.html
index 23b869bf4..67584f26b 100644
--- a/review/pr-177/README.html
+++ b/review/pr-177/README.html
@@ -435,31 +435,31 @@
On pure HBM mode:0.50 |
-1.153 |
-2.580 |
-1.706 |
-1.798 |
-4.111 |
-1.801 |
-1.083 |
+1.093 |
+2.470 |
+1.478 |
+1.770 |
+3.726 |
+1.447 |
+1.075 |
0.75 |
-1.010 |
-2.561 |
-0.683 |
-1.863 |
-1.939 |
-1.305 |
-0.994 |
+1.045 |
+2.452 |
+1.335 |
+1.807 |
+3.374 |
+1.309 |
+1.013 |
1.00 |
-0.366 |
-2.580 |
-0.371 |
-1.894 |
-0.932 |
-0.393 |
-0.517 |
+0.655 |
+2.481 |
+0.612 |
+1.815 |
+1.865 |
+0.619 |
+0.511 |
@@ -473,19 +473,19 @@ On pure HBM mode:0.50 |
-2.180 |
-18.766 |
-3.046 |
+2.087 |
+12.258 |
+3.121 |
0.75 |
-2.179 |
-18.873 |
-3.006 |
+2.045 |
+12.447 |
+3.094 |
1.00 |
-2.081 |
-2.813 |
-2.988 |
+1.950 |
+2.657 |
+3.096 |
@@ -506,31 +506,31 @@ On pure HBM mode:0.50 |
-1.068 |
-2.325 |
-1.277 |
-1.759 |
-4.113 |
-1.807 |
-0.932 |
+0.961 |
+2.272 |
+1.278 |
+1.706 |
+3.718 |
+1.435 |
+0.931 |
0.75 |
-0.862 |
-2.309 |
-0.651 |
-1.754 |
-1.934 |
-1.304 |
-0.874 |
+0.930 |
+2.238 |
+1.177 |
+1.693 |
+3.369 |
+1.316 |
+0.866 |
1.00 |
-0.360 |
-2.349 |
-0.346 |
-1.825 |
-0.926 |
-0.373 |
-0.470 |
+0.646 |
+2.321 |
+0.572 |
+1.783 |
+1.873 |
+0.618 |
+0.469 |
@@ -544,19 +544,19 @@ On pure HBM mode:0.50 |
-0.701 |
-15.185 |
-3.013 |
+0.692 |
+10.784 |
+3.100 |
0.75 |
-0.578 |
-13.840 |
-2.981 |
+0.569 |
+10.240 |
+3.075 |
1.00 |
-0.567 |
-0.777 |
-2.972 |
+0.551 |
+0.765 |
+3.096 |
@@ -577,31 +577,31 @@ On pure HBM mode:0.50 |
-0.862 |
-2.041 |
-0.930 |
-1.543 |
-4.399 |
-1.854 |
-0.807 |
+0.834 |
+1.982 |
+1.113 |
+1.499 |
+3.950 |
+1.502 |
+0.805 |
0.75 |
-0.662 |
-2.022 |
-0.579 |
-1.541 |
-1.983 |
-1.299 |
-0.771 |
+0.801 |
+1.951 |
+1.033 |
+1.493 |
+3.545 |
+1.359 |
+0.773 |
1.00 |
-0.334 |
-2.052 |
-0.335 |
-1.590 |
-0.939 |
-0.393 |
-0.473 |
+0.621 |
+2.021 |
+0.608 |
+1.541 |
+1.965 |
+0.613 |
+0.481 |
@@ -615,19 +615,19 @@ On pure HBM mode:0.50 |
-0.321 |
-10.791 |
-3.124 |
+0.316 |
+8.199 |
+3.239 |
0.75 |
-0.298 |
-10.864 |
-3.101 |
+0.296 |
+8.549 |
+3.198 |
1.00 |
-0.293 |
-0.393 |
-3.096 |
+0.288 |
+0.395 |
+3.225 |
@@ -651,27 +651,27 @@ On HBM+HMEM hybrid mode:0.50 |
0.083 |
-0.123 |
-0.115 |
+0.124 |
+0.109 |
0.131 |
-4.015 |
-1.804 |
+3.705 |
+1.435 |
0.75 |
-0.082 |
-0.123 |
+0.083 |
+0.122 |
0.111 |
0.129 |
-1.911 |
-1.136 |
+3.221 |
+1.274 |
1.00 |
-0.069 |
-0.110 |
-0.085 |
-0.106 |
-0.929 |
-0.393 |
+0.073 |
+0.123 |
+0.095 |
+0.126 |
+1.854 |
+0.617 |
@@ -685,19 +685,19 @@ On HBM+HMEM hybrid mode:0.50 |
-0.322 |
-10.720 |
-3.050 |
+0.318 |
+8.086 |
+3.122 |
0.75 |
-0.299 |
-10.587 |
-3.020 |
+0.294 |
+5.549 |
+3.111 |
1.00 |
-0.293 |
-0.394 |
-2.994 |
+0.287 |
+0.393 |
+3.075 |
@@ -718,27 +718,27 @@ On HBM+HMEM hybrid mode:0.50 |
0.049 |
-0.073 |
-0.048 |
-0.070 |
-3.530 |
-1.730 |
+0.069 |
+0.049 |
+0.069 |
+3.484 |
+1.370 |
0.75 |
0.049 |
-0.072 |
-0.048 |
0.069 |
-1.849 |
-1.261 |
+0.049 |
+0.069 |
+3.116 |
+1.242 |
1.00 |
-0.044 |
-0.068 |
-0.044 |
-0.062 |
-0.911 |
-0.393 |
+0.047 |
+0.072 |
+0.047 |
+0.070 |
+1.771 |
+0.607 |
@@ -752,19 +752,19 @@ On HBM+HMEM hybrid mode:0.50 |
-0.321 |
-11.148 |
-2.908 |
+0.316 |
+8.181 |
+3.073 |
0.75 |
-0.299 |
-11.269 |
-2.898 |
+0.293 |
+8.950 |
+3.052 |
1.00 |
-0.293 |
-0.396 |
-2.782 |
+0.292 |
+0.394 |
+3.026 |
diff --git a/review/pr-177/_sources/README.md.txt b/review/pr-177/_sources/README.md.txt
index 76a647607..798819e34 100644
--- a/review/pr-177/_sources/README.md.txt
+++ b/review/pr-177/_sources/README.md.txt
@@ -224,43 +224,43 @@ For Benchmark:
| λ | insert_or_assign | find | find_or_insert | assign | find* | find_or_insert* | insert_and_evict |
|-----:|-----------------:|-------:|---------------:|-------:|-------:|----------------:|-----------------:|
-| 0.50 | 1.153 | 2.580 | 1.706 | 1.798 | 4.111 | 1.801 | 1.083 |
-| 0.75 | 1.010 | 2.561 | 0.683 | 1.863 | 1.939 | 1.305 | 0.994 |
-| 1.00 | 0.366 | 2.580 | 0.371 | 1.894 | 0.932 | 0.393 | 0.517 |
+| 0.50 | 1.093 | 2.470 | 1.478 | 1.770 | 3.726 | 1.447 | 1.075 |
+| 0.75 | 1.045 | 2.452 | 1.335 | 1.807 | 3.374 | 1.309 | 1.013 |
+| 1.00 | 0.655 | 2.481 | 0.612 | 1.815 | 1.865 | 0.619 | 0.511 |
-| λ | export_batch | export_batch_if | contains |
-|-----:|-------------:|----------------:|----------:|
-| 0.50 | 2.180 | 18.766 | 3.046 |
-| 0.75 | 2.179 | 18.873 | 3.006 |
-| 1.00 | 2.081 | 2.813 | 2.988 |
+| λ | export_batch | export_batch_if | contains |
+|-----:|-------------:|----------------:|---------:|
+| 0.50 | 2.087 | 12.258 | 3.121 |
+| 0.75 | 2.045 | 12.447 | 3.094 |
+| 1.00 | 1.950 | 2.657 | 3.096 |
* dim = 32, capacity = 128 Million-KV, HBM = 16 GB, HMEM = 0 GB
| λ | insert_or_assign | find | find_or_insert | assign | find* | find_or_insert* | insert_and_evict |
|-----:|-----------------:|-------:|---------------:|-------:|-------:|----------------:|-----------------:|
-| 0.50 | 1.068 | 2.325 | 1.277 | 1.759 | 4.113 | 1.807 | 0.932 |
-| 0.75 | 0.862 | 2.309 | 0.651 | 1.754 | 1.934 | 1.304 | 0.874 |
-| 1.00 | 0.360 | 2.349 | 0.346 | 1.825 | 0.926 | 0.373 | 0.470 |
+| 0.50 | 0.961 | 2.272 | 1.278 | 1.706 | 3.718 | 1.435 | 0.931 |
+| 0.75 | 0.930 | 2.238 | 1.177 | 1.693 | 3.369 | 1.316 | 0.866 |
+| 1.00 | 0.646 | 2.321 | 0.572 | 1.783 | 1.873 | 0.618 | 0.469 |
-| λ | export_batch | export_batch_if | contains |
-|-----:|-------------:|----------------:|----------:|
-| 0.50 | 0.701 | 15.185 | 3.013 |
-| 0.75 | 0.578 | 13.840 | 2.981 |
-| 1.00 | 0.567 | 0.777 | 2.972 |
+| λ | export_batch | export_batch_if | contains |
+|-----:|-------------:|----------------:|---------:|
+| 0.50 | 0.692 | 10.784 | 3.100 |
+| 0.75 | 0.569 | 10.240 | 3.075 |
+| 1.00 | 0.551 | 0.765 | 3.096 |
* dim = 64, capacity = 64 Million-KV, HBM = 16 GB, HMEM = 0 GB
| λ | insert_or_assign | find | find_or_insert | assign | find* | find_or_insert* | insert_and_evict |
|-----:|-----------------:|-------:|---------------:|-------:|-------:|----------------:|-----------------:|
-| 0.50 | 0.862 | 2.041 | 0.930 | 1.543 | 4.399 | 1.854 | 0.807 |
-| 0.75 | 0.662 | 2.022 | 0.579 | 1.541 | 1.983 | 1.299 | 0.771 |
-| 1.00 | 0.334 | 2.052 | 0.335 | 1.590 | 0.939 | 0.393 | 0.473 |
+| 0.50 | 0.834 | 1.982 | 1.113 | 1.499 | 3.950 | 1.502 | 0.805 |
+| 0.75 | 0.801 | 1.951 | 1.033 | 1.493 | 3.545 | 1.359 | 0.773 |
+| 1.00 | 0.621 | 2.021 | 0.608 | 1.541 | 1.965 | 0.613 | 0.481 |
-| λ | export_batch | export_batch_if | contains |
-|-----:|-------------:|----------------:|----------:|
-| 0.50 | 0.321 | 10.791 | 3.124 |
-| 0.75 | 0.298 | 10.864 | 3.101 |
-| 1.00 | 0.293 | 0.393 | 3.096 |
+| λ | export_batch | export_batch_if | contains |
+|-----:|-------------:|----------------:|---------:|
+| 0.50 | 0.316 | 8.199 | 3.239 |
+| 0.75 | 0.296 | 8.549 | 3.198 |
+| 1.00 | 0.288 | 0.395 | 3.225 |
### On HBM+HMEM hybrid mode:
@@ -268,29 +268,29 @@ For Benchmark:
| λ | insert_or_assign | find | find_or_insert | assign | find* | find_or_insert* |
|-----:|-----------------:|-------:|---------------:|-------:|-------:|----------------:|
-| 0.50 | 0.083 | 0.123 | 0.115 | 0.131 | 4.015 | 1.804 |
-| 0.75 | 0.082 | 0.123 | 0.111 | 0.129 | 1.911 | 1.136 |
-| 1.00 | 0.069 | 0.110 | 0.085 | 0.106 | 0.929 | 0.393 |
+| 0.50 | 0.083 | 0.124 | 0.109 | 0.131 | 3.705 | 1.435 |
+| 0.75 | 0.083 | 0.122 | 0.111 | 0.129 | 3.221 | 1.274 |
+| 1.00 | 0.073 | 0.123 | 0.095 | 0.126 | 1.854 | 0.617 |
-| λ | export_batch | export_batch_if | contains |
-|-----:|-------------:|----------------:|----------:|
-| 0.50 | 0.322 | 10.720 | 3.050 |
-| 0.75 | 0.299 | 10.587 | 3.020 |
-| 1.00 | 0.293 | 0.394 | 2.994 |
+| λ | export_batch | export_batch_if | contains |
+|-----:|-------------:|----------------:|---------:|
+| 0.50 | 0.318 | 8.086 | 3.122 |
+| 0.75 | 0.294 | 5.549 | 3.111 |
+| 1.00 | 0.287 | 0.393 | 3.075 |
* dim = 64, capacity = 512 Million-KV, HBM = 32 GB, HMEM = 96 GB
| λ | insert_or_assign | find | find_or_insert | assign | find* | find_or_insert* |
|-----:|-----------------:|-------:|---------------:|-------:|-------:|----------------:|
-| 0.50 | 0.049 | 0.073 | 0.048 | 0.070 | 3.530 | 1.730 |
-| 0.75 | 0.049 | 0.072 | 0.048 | 0.069 | 1.849 | 1.261 |
-| 1.00 | 0.044 | 0.068 | 0.044 | 0.062 | 0.911 | 0.393 |
-
-| λ | export_batch | export_batch_if | contains |
-|-----:|-------------:|----------------:|----------:|
-| 0.50 | 0.321 | 11.148 | 2.908 |
-| 0.75 | 0.299 | 11.269 | 2.898 |
-| 1.00 | 0.293 | 0.396 | 2.782 |
+| 0.50 | 0.049 | 0.069 | 0.049 | 0.069 | 3.484 | 1.370 |
+| 0.75 | 0.049 | 0.069 | 0.049 | 0.069 | 3.116 | 1.242 |
+| 1.00 | 0.047 | 0.072 | 0.047 | 0.070 | 1.771 | 0.607 |
+
+| λ | export_batch | export_batch_if | contains |
+|-----:|-------------:|----------------:|---------:|
+| 0.50 | 0.316 | 8.181 | 3.073 |
+| 0.75 | 0.293 | 8.950 | 3.052 |
+| 1.00 | 0.292 | 0.394 | 3.026 |
### Support and Feedback:
diff --git a/review/pr-177/_sources/api/program_listing_file_merlin_hashtable.cuh.rst.txt b/review/pr-177/_sources/api/program_listing_file_merlin_hashtable.cuh.rst.txt
index ec63cb12e..44e7a9eeb 100644
--- a/review/pr-177/_sources/api/program_listing_file_merlin_hashtable.cuh.rst.txt
+++ b/review/pr-177/_sources/api/program_listing_file_merlin_hashtable.cuh.rst.txt
@@ -208,7 +208,7 @@ Program Listing for File merlin_hashtable.cuh
const key_type* keys, // (n)
const value_type* values, // (n, DIM)
const score_type* scores = nullptr, // (n)
- cudaStream_t stream = 0,
+ cudaStream_t stream = 0, bool unique_key = true,
bool ignore_evict_strategy = false) {
if (n == 0) {
return;
@@ -226,8 +226,6 @@ Program Listing for File merlin_hashtable.cuh
insert_unique_lock lock(mutex_, stream);
if (is_fast_mode()) {
- using Selector = SelectUpsertKernelWithIO;
static thread_local int step_counter = 0;
static thread_local float load_factor = 0.0;
@@ -235,20 +233,55 @@ Program Listing for File merlin_hashtable.cuh
load_factor = fast_load_factor(0, stream, false);
}
- Selector::execute_kernel(
- load_factor, options_.block_size, options_.max_bucket_size,
- table_->buckets_num, options_.dim, stream, n, d_table_,
- table_->buckets, keys, reinterpret_cast(values),
- scores, EvictStrategyParam.global_epoch);
+ using Selector = KernelSelector_Upsert;
+ if (Selector::callable(unique_key,
+ static_cast(options_.max_bucket_size),
+ static_cast(options_.dim))) {
+ typename Selector::Params kernelParams(
+ load_factor, table_->buckets, table_->buckets_size,
+ table_->buckets_num,
+ static_cast(options_.max_bucket_size),
+ static_cast(options_.dim), keys, values, scores, n,
+ EvictStrategyParam.global_epoch);
+ Selector::select_kernel(kernelParams, stream);
+ } else {
+ using Selector = SelectUpsertKernelWithIO;
+ Selector::execute_kernel(
+ load_factor, options_.block_size, options_.max_bucket_size,
+ table_->buckets_num, options_.dim, stream, n, d_table_,
+ table_->buckets, keys, reinterpret_cast(values),
+ scores, EvictStrategyParam.global_epoch);
+ }
} else {
- const size_type dev_ws_size{n * (sizeof(value_type*) + sizeof(int))};
+ const size_type dev_ws_size{
+ n * (sizeof(value_type*) + sizeof(int) + sizeof(key_type*))};
auto dev_ws{dev_mem_pool_->get_workspace<1>(dev_ws_size, stream)};
auto d_dst{dev_ws.get(0)};
- auto d_src_offset{reinterpret_cast(d_dst + n)};
+ auto keys_ptr{reinterpret_cast(d_dst + n)};
+ auto d_src_offset{reinterpret_cast(keys_ptr + n)};
CUDA_CHECK(cudaMemsetAsync(d_dst, 0, dev_ws_size, stream));
- {
+ constexpr uint32_t MinBucketCapacityFilter =
+ sizeof(VecD_Load) / sizeof(D);
+
+ bool filter_condition =
+ unique_key && options_.max_bucket_size >= MinBucketCapacityFilter &&
+ !options_.io_by_cpu;
+
+ if (filter_condition) {
+ constexpr uint32_t BLOCK_SIZE = 128;
+
+ upsert_kernel_lock_key_hybrid
+ <<<(n + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
+ table_->buckets, table_->buckets_size, table_->buckets_num,
+ options_.max_bucket_size, options_.dim, keys, d_dst, scores,
+ keys_ptr, d_src_offset, n, EvictStrategyParam.global_epoch);
+
+ } else {
const size_t block_size = options_.block_size;
const size_t N = n * TILE_SIZE;
const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);
@@ -269,7 +302,16 @@ Program Listing for File merlin_hashtable.cuh
d_src_offset_ptr, thrust::less());
}
- if (options_.io_by_cpu) {
+ if (filter_condition) {
+ const size_t block_size = options_.io_block_size;
+ const size_t N = n * dim();
+ const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);
+
+ write_kernel_unlock_key
+ <<>>(values, d_dst, d_src_offset,
+ dim(), keys, keys_ptr, N);
+
+ } else if (options_.io_by_cpu) {
const size_type host_ws_size{dev_ws_size +
n * sizeof(value_type) * dim()};
auto host_ws{host_mem_pool_->get_workspace<1>(host_ws_size, stream)};
@@ -502,7 +544,7 @@ Program Listing for File merlin_hashtable.cuh
void find_or_insert(const size_type n, const key_type* keys, // (n)
value_type* values, // (n * DIM)
score_type* scores = nullptr, // (n)
- cudaStream_t stream = 0,
+ cudaStream_t stream = 0, bool unique_key = true,
bool ignore_evict_strategy = false) {
if (n == 0) {
return;
@@ -520,31 +562,66 @@ Program Listing for File merlin_hashtable.cuh
insert_unique_lock lock(mutex_, stream);
if (is_fast_mode()) {
- using Selector =
- SelectFindOrInsertKernelWithIO;
static thread_local int step_counter = 0;
static thread_local float load_factor = 0.0;
if (((step_counter++) % kernel_select_interval_) == 0) {
load_factor = fast_load_factor(0, stream, false);
}
- Selector::execute_kernel(load_factor, options_.block_size,
- options_.max_bucket_size, table_->buckets_num,
- options_.dim, stream, n, d_table_,
- table_->buckets, keys, values, scores,
- EvictStrategyParam.global_epoch);
+
+ using Selector =
+ KernelSelector_FindOrInsert;
+ if (Selector::callable(unique_key,
+ static_cast(options_.max_bucket_size),
+ static_cast(options_.dim))) {
+ typename Selector::Params kernelParams(
+ load_factor, table_->buckets, table_->buckets_size,
+ table_->buckets_num,
+ static_cast(options_.max_bucket_size),
+ static_cast(options_.dim), keys, values, scores, n,
+ EvictStrategyParam.global_epoch);
+ Selector::select_kernel(kernelParams, stream);
+ } else {
+ using Selector =
+ SelectFindOrInsertKernelWithIO;
+ Selector::execute_kernel(load_factor, options_.block_size,
+ options_.max_bucket_size, table_->buckets_num,
+ options_.dim, stream, n, d_table_,
+ table_->buckets, keys, values, scores,
+ EvictStrategyParam.global_epoch);
+ }
} else {
- const size_type dev_ws_size{
- n * (sizeof(value_type*) + sizeof(int) + sizeof(bool))};
+ const size_type dev_ws_size{n * (sizeof(value_type*) + sizeof(int) +
+ sizeof(bool) + sizeof(key_type*))};
auto dev_ws{dev_mem_pool_->get_workspace<1>(dev_ws_size, stream)};
auto d_table_value_addrs{dev_ws.get(0)};
- auto param_key_index{reinterpret_cast(d_table_value_addrs + n)};
+ auto keys_ptr{reinterpret_cast(d_table_value_addrs + n)};
+ auto param_key_index{reinterpret_cast(keys_ptr + n)};
auto founds{reinterpret_cast(param_key_index + n)};
CUDA_CHECK(cudaMemsetAsync(d_table_value_addrs, 0, dev_ws_size, stream));
- {
+ constexpr uint32_t MinBucketCapacityFilter =
+ sizeof(VecD_Load) / sizeof(D);
+
+ bool filter_condition =
+ unique_key && options_.max_bucket_size >= MinBucketCapacityFilter &&
+ !options_.io_by_cpu;
+
+ if (filter_condition) {
+ constexpr uint32_t BLOCK_SIZE = 128;
+
+ find_or_insert_kernel_lock_key_hybrid
+ <<<(n + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
+ table_->buckets, table_->buckets_size, table_->buckets_num,
+ options_.max_bucket_size, options_.dim, keys,
+ d_table_value_addrs, scores, keys_ptr, param_key_index, founds,
+ n, EvictStrategyParam.global_epoch);
+
+ } else {
const size_t block_size = options_.block_size;
const size_t N = n * TILE_SIZE;
const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);
@@ -567,7 +644,17 @@ Program Listing for File merlin_hashtable.cuh
thrust::less());
}
- if (options_.io_by_cpu) {
+ if (filter_condition) {
+ const size_t block_size = options_.io_block_size;
+ const size_t N = n * dim();
+ const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);
+
+ read_or_write_kernel_unlock_key
+ <<>>(d_table_value_addrs, values,
+ founds, param_key_index,
+ keys_ptr, keys, dim(), N);
+
+ } else if (options_.io_by_cpu) {
const size_type host_ws_size{
dev_ws_size + n * (sizeof(bool) + sizeof(value_type) * dim())};
auto host_ws{host_mem_pool_->get_workspace<1>(host_ws_size, stream)};
@@ -609,7 +696,7 @@ Program Listing for File merlin_hashtable.cuh
value_type** values, // (n)
bool* founds, // (n)
score_type* scores = nullptr, // (n)
- cudaStream_t stream = 0,
+ cudaStream_t stream = 0, bool unique_key = true,
bool ignore_evict_strategy = false) {
if (n == 0) {
return;
@@ -626,18 +713,41 @@ Program Listing for File merlin_hashtable.cuh
insert_unique_lock lock(mutex_, stream);
- using Selector = SelectFindOrInsertPtrKernel;
- static thread_local int step_counter = 0;
- static thread_local float load_factor = 0.0;
+ constexpr uint32_t MinBucketCapacityFilter = sizeof(VecD_Load) / sizeof(D);
- if (((step_counter++) % kernel_select_interval_) == 0) {
- load_factor = fast_load_factor(0, stream, false);
+ if (unique_key && options_.max_bucket_size >= MinBucketCapacityFilter) {
+ constexpr uint32_t BLOCK_SIZE = 128U;
+
+ const size_type dev_ws_size{n * sizeof(key_type**)};
+ auto dev_ws{dev_mem_pool_->get_workspace<1>(dev_ws_size, stream)};
+ auto keys_ptr{dev_ws.get(0)};
+ CUDA_CHECK(cudaMemsetAsync(keys_ptr, 0, dev_ws_size, stream));
+
+ find_or_insert_ptr_kernel_lock_key
+ <<<(n + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
+ table_->buckets, table_->buckets_size, table_->buckets_num,
+ options_.max_bucket_size, options_.dim, keys, values, scores,
+ keys_ptr, n, founds, EvictStrategyParam.global_epoch);
+
+ find_or_insert_ptr_kernel_unlock_key
+ <<<(n + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
+ keys, keys_ptr, n);
+ } else {
+ using Selector = SelectFindOrInsertPtrKernel;
+ static thread_local int step_counter = 0;
+ static thread_local float load_factor = 0.0;
+
+ if (((step_counter++) % kernel_select_interval_) == 0) {
+ load_factor = fast_load_factor(0, stream, false);
+ }
+ Selector::execute_kernel(load_factor, options_.block_size,
+ options_.max_bucket_size, table_->buckets_num,
+ options_.dim, stream, n, d_table_,
+ table_->buckets, keys, values, scores, founds,
+ EvictStrategyParam.global_epoch);
}
- Selector::execute_kernel(
- load_factor, options_.block_size, options_.max_bucket_size,
- table_->buckets_num, options_.dim, stream, n, d_table_, table_->buckets,
- keys, values, scores, founds, EvictStrategyParam.global_epoch);
CudaCheckError();
}
@@ -680,14 +790,33 @@ Program Listing for File merlin_hashtable.cuh
EvictStrategyParam.global_epoch);
}
} else {
- const size_type dev_ws_size{n * (sizeof(value_type*) + sizeof(int))};
+ const size_type dev_ws_size{
+ n * (sizeof(value_type*) + sizeof(key_type) + sizeof(int))};
auto dev_ws{dev_mem_pool_->get_workspace<1>(dev_ws_size, stream)};
auto d_dst{dev_ws.get(0)};
- auto d_src_offset{reinterpret_cast(d_dst + n)};
+ auto keys_ptr{reinterpret_cast(d_dst + n)};
+ auto d_src_offset{reinterpret_cast(keys_ptr + n)};
CUDA_CHECK(cudaMemsetAsync(d_dst, 0, dev_ws_size, stream));
- {
+ constexpr uint32_t MinBucketCapacityFilter =
+ sizeof(VecD_Load) / sizeof(D);
+
+ bool filter_condition =
+ options_.max_bucket_size >= MinBucketCapacityFilter &&
+ !options_.io_by_cpu && unique_key;
+
+ if (filter_condition) {
+ constexpr uint32_t BLOCK_SIZE = 128U;
+
+ tlp_update_kernel_hybrid
+ <<<(n + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
+ table_->buckets, table_->buckets_num, options_.max_bucket_size,
+ options_.dim, keys, d_dst, scores, keys_ptr, d_src_offset,
+ EvictStrategyParam.global_epoch, n);
+
+ } else {
const size_t block_size = options_.block_size;
const size_t N = n * TILE_SIZE;
const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);
@@ -708,7 +837,16 @@ Program Listing for File merlin_hashtable.cuh
d_src_offset_ptr, thrust::less());
}
- if (options_.io_by_cpu) {
+ if (filter_condition) {
+ const size_t block_size = options_.io_block_size;
+ const size_t N = n * dim();
+ const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);
+
+ write_kernel_unlock_key
+ <<>>(values, d_dst, d_src_offset,
+ dim(), keys, keys_ptr, N);
+
+ } else if (options_.io_by_cpu) {
const size_type host_ws_size{dev_ws_size +
n * sizeof(value_type) * dim()};
auto host_ws{host_mem_pool_->get_workspace<1>(host_ws_size, stream)};
@@ -822,7 +960,20 @@ Program Listing for File merlin_hashtable.cuh
CUDA_CHECK(cudaMemsetAsync(src, 0, dev_ws_size, stream));
- {
+ constexpr uint32_t MinBucketCapacityFilter =
+ sizeof(VecD_Load) / sizeof(D);
+
+ bool filter_condition =
+ options_.max_bucket_size >= MinBucketCapacityFilter;
+
+ if (filter_condition) {
+ constexpr uint32_t BLOCK_SIZE = 128U;
+
+ tlp_lookup_kernel_hybrid
+ <<<(n + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
+ table_->buckets, table_->buckets_num, options_.max_bucket_size,
+ options_.dim, keys, src, scores, dst_offset, founds, n);
+ } else {
const size_t block_size = options_.block_size;
const size_t N = n * TILE_SIZE;
const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);
@@ -859,7 +1010,7 @@ Program Listing for File merlin_hashtable.cuh
value_type** values, // (n)
bool* founds, // (n)
score_type* scores = nullptr, // (n)
- cudaStream_t stream = 0) const {
+ cudaStream_t stream = 0, bool unique_key = true) const {
if (n == 0) {
return;
}
@@ -868,17 +1019,27 @@ Program Listing for File merlin_hashtable.cuh
read_shared_lock lock(mutex_, stream);
- using Selector = SelectLookupPtrKernel;
- static thread_local int step_counter = 0;
- static thread_local float load_factor = 0.0;
+ constexpr uint32_t MinBucketCapacityFilter = sizeof(VecD_Load) / sizeof(D);
+ if (unique_key && options_.max_bucket_size >= MinBucketCapacityFilter) {
+ constexpr uint32_t BLOCK_SIZE = 128U;
+ tlp_lookup_ptr_kernel_with_filter
+ <<<(n + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
+ table_->buckets, table_->buckets_num, options_.max_bucket_size,
+ options_.dim, keys, values, scores, founds, n);
+ } else {
+ using Selector = SelectLookupPtrKernel;
+ static thread_local int step_counter = 0;
+ static thread_local float load_factor = 0.0;
- if (((step_counter++) % kernel_select_interval_) == 0) {
- load_factor = fast_load_factor(0, stream, false);
+ if (((step_counter++) % kernel_select_interval_) == 0) {
+ load_factor = fast_load_factor(0, stream, false);
+ }
+
+ Selector::execute_kernel(load_factor, options_.block_size,
+ options_.max_bucket_size, table_->buckets_num,
+ options_.dim, stream, n, d_table_,
+ table_->buckets, keys, values, scores, founds);
}
- Selector::execute_kernel(load_factor, options_.block_size,
- options_.max_bucket_size, table_->buckets_num,
- options_.dim, stream, n, d_table_, table_->buckets,
- keys, values, scores, founds);
CudaCheckError();
}
@@ -1273,7 +1434,7 @@ Program Listing for File merlin_hashtable.cuh
}
EvictStrategy::set_global_epoch(static_cast(IGNORED_GLOBAL_EPOCH));
- insert_or_assign(count, d_keys, d_values, d_scores, stream, true);
+ insert_or_assign(count, d_keys, d_values, d_scores, stream, true, true);
total_count += count;
// Read next batch.
diff --git a/review/pr-177/api/classnv_1_1merlin_1_1HashTable.html b/review/pr-177/api/classnv_1_1merlin_1_1HashTable.html
index aa0bcfb22..938093233 100644
--- a/review/pr-177/api/classnv_1_1merlin_1_1HashTable.html
+++ b/review/pr-177/api/classnv_1_1merlin_1_1HashTable.html
@@ -175,8 +175,8 @@ Class Documentation