Skip to content

Commit

Permalink
[opt] Allocate hbm uniformly for buckets to avoid fragmentation.
Browse files Browse the repository at this point in the history
  • Loading branch information
LinGeLin authored and rhdong committed Dec 7, 2023
1 parent 8a4b781 commit 770be38
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 16 deletions.
26 changes: 10 additions & 16 deletions include/merlin/core_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,14 @@ void initialize_buckets(Table<K, V, S>** table, BaseAllocator* allocator,
uint32_t reserve_size =
bucket_max_size < CACHE_LINE_SIZE ? CACHE_LINE_SIZE : bucket_max_size;
bucket_memory_size += reserve_size * sizeof(uint8_t);
uint8_t* address = nullptr;
allocator->alloc(MemoryType::Device, (void**)&(address),
bucket_memory_size * (end - start));
(*table)->buckets_address.push_back(address);
for (int i = start; i < end; i++) {
uint8_t* address = nullptr;
allocator->alloc(MemoryType::Device, (void**)&(address),
bucket_memory_size);
allocate_bucket_others<K, V, S><<<1, 1>>>((*table)->buckets, i, address,
reserve_size, bucket_max_size);
allocate_bucket_others<K, V, S><<<1, 1>>>(
(*table)->buckets, i, address + (bucket_memory_size * (i - start)),
reserve_size, bucket_max_size);
}
CUDA_CHECK(cudaDeviceSynchronize());

Expand Down Expand Up @@ -365,17 +367,9 @@ void double_capacity(Table<K, V, S>** table, BaseAllocator* allocator) {
/* free all of the resource of a Table. */
template <class K, class V, class S>
void destroy_table(Table<K, V, S>** table, BaseAllocator* allocator) {
uint8_t** d_address = nullptr;
CUDA_CHECK(cudaMalloc((void**)&d_address, sizeof(uint8_t*)));
for (int i = 0; i < (*table)->buckets_num; i++) {
uint8_t* h_address;
get_bucket_others_address<K, V, S>
<<<1, 1>>>((*table)->buckets, i, d_address);
CUDA_CHECK(cudaMemcpy(&h_address, d_address, sizeof(uint8_t*),
cudaMemcpyDeviceToHost));
allocator->free(MemoryType::Device, h_address);
}
CUDA_CHECK(cudaFree(d_address));
for (auto addr : (*table)->buckets_address) {
allocator->free(MemoryType::Device, addr);
}

for (int i = 0; i < (*table)->num_of_memory_slices; i++) {
if (is_on_device((*table)->slices[i])) {
Expand Down
2 changes: 2 additions & 0 deletions include/merlin/types.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <stddef.h>
#include <cstdint>
#include <cuda/std/semaphore>
#include <vector>

namespace nv {
namespace merlin {
Expand Down Expand Up @@ -161,6 +162,7 @@ struct Table {
int slots_number = 0; // unused
int device_id = 0; // Device id
int tile_size;
std::vector<uint8_t*> buckets_address;
};

template <class K, class S>
Expand Down

0 comments on commit 770be38

Please sign in to comment.