diff --git a/include/merlin/core_kernels.cuh b/include/merlin/core_kernels.cuh index d8a85002..e250d234 100644 --- a/include/merlin/core_kernels.cuh +++ b/include/merlin/core_kernels.cuh @@ -174,6 +174,7 @@ void initialize_buckets(Table** table, BaseAllocator* allocator, ((*table)->num_of_memory_slices + num_of_memory_slices) * sizeof(V*), allocator); + bool mixed_hbm = false; for (size_t i = (*table)->num_of_memory_slices; i < (*table)->num_of_memory_slices + num_of_memory_slices; i++) { if (i == (*table)->num_of_memory_slices + num_of_memory_slices - 1) { @@ -183,6 +184,9 @@ void initialize_buckets(Table** table, BaseAllocator* allocator, (*table)->bucket_max_size * sizeof(V) * (*table)->dim; if ((*table)->remaining_hbm_for_vectors >= slice_real_size) { + if (!(*table)->is_pure_hbm) { + mixed_hbm = true; + } allocator->alloc(MemoryType::Device, (void**)&((*table)->slices[i]), slice_real_size); (*table)->remaining_hbm_for_vectors -= slice_real_size; @@ -192,7 +196,7 @@ void initialize_buckets(Table** table, BaseAllocator* allocator, slice_real_size, cudaHostAllocMapped); } for (int j = 0; j < num_of_buckets_in_one_slice; j++) { - if ((*table)->is_pure_hbm) { + if ((*table)->is_pure_hbm || mixed_hbm) { size_t index = start + num_of_allocated_buckets + j; V* address = (*table)->slices[i] + j * (*table)->bucket_max_size * (*table)->dim;