Skip to content

Commit

Permalink
[fix] Fix init_bucket failure caused by max_hbm being smaller than sl…
Browse files Browse the repository at this point in the history
…ice size
  • Loading branch information
LinGeLin authored and rhdong committed Jun 13, 2024
1 parent d831ff5 commit c895d54
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion include/merlin/core_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ void initialize_buckets(Table<K, V, S>** table, BaseAllocator* allocator,
((*table)->num_of_memory_slices + num_of_memory_slices) * sizeof(V*),
allocator);

bool mixed_hbm = false;
for (size_t i = (*table)->num_of_memory_slices;
i < (*table)->num_of_memory_slices + num_of_memory_slices; i++) {
if (i == (*table)->num_of_memory_slices + num_of_memory_slices - 1) {
Expand All @@ -189,6 +190,9 @@ void initialize_buckets(Table<K, V, S>** table, BaseAllocator* allocator,
(*table)->bucket_max_size * sizeof(V) *
(*table)->dim;
if ((*table)->remaining_hbm_for_vectors >= slice_real_size) {
if (!(*table)->is_pure_hbm) {
mixed_hbm = true;
}
allocator->alloc(MemoryType::Device, (void**)&((*table)->slices[i]),
slice_real_size);
(*table)->remaining_hbm_for_vectors -= slice_real_size;
Expand All @@ -198,7 +202,7 @@ void initialize_buckets(Table<K, V, S>** table, BaseAllocator* allocator,
slice_real_size, cudaHostAllocMapped);
}
for (int j = 0; j < num_of_buckets_in_one_slice; j++) {
if ((*table)->is_pure_hbm) {
if ((*table)->is_pure_hbm || mixed_hbm) {
size_t index = start + num_of_allocated_buckets + j;
V* address =
(*table)->slices[i] + j * (*table)->bucket_max_size * (*table)->dim;
Expand Down

0 comments on commit c895d54

Please sign in to comment.