diff --git a/experimental/rocm/dynamic_symbol_tables.h b/experimental/rocm/dynamic_symbol_tables.h
index 4214c2b92775e..b28acef5471bc 100644
--- a/experimental/rocm/dynamic_symbol_tables.h
+++ b/experimental/rocm/dynamic_symbol_tables.h
@@ -25,6 +25,7 @@ RC_PFN_DECL(hipInit, unsigned int)
 RC_PFN_DECL(hipModuleLaunchKernel, hipFunction_t, unsigned int, unsigned int,
             unsigned int, unsigned int, unsigned int, unsigned int,
             unsigned int, hipStream_t, void **, void **)
+RC_PFN_DECL(hipMemAdvise, const void *, size_t, int, int)
 RC_PFN_DECL(hipMemset, void *, int, size_t)
 RC_PFN_DECL(hipMemsetAsync, void *, int, size_t, hipStream_t)
 RC_PFN_DECL(hipMemsetD32Async, void *, int, size_t, hipStream_t)
diff --git a/experimental/rocm/rocm_allocator.c b/experimental/rocm/rocm_allocator.c
index 84dfb32f81c87..dbd0ea4b99366 100644
--- a/experimental/rocm/rocm_allocator.c
+++ b/experimental/rocm/rocm_allocator.c
@@ -285,6 +285,21 @@ static iree_status_t iree_hal_rocm_allocator_allocate_buffer(
       status = ROCM_RESULT_TO_STATUS(
           allocator->context->syms,
           hipMallocManaged(&device_ptr, allocation_size, hipMemAttachGlobal));
+      if (iree_status_is_ok(status)) {
+        // Advise allocator->device as the preferred location of the range.
+        status = ROCM_RESULT_TO_STATUS(
+            allocator->context->syms,
+            hipMemAdvise(device_ptr, allocation_size,
+                         hipMemAdviseSetPreferredLocation, allocator->device));
+      }
+      if (iree_status_is_ok(status)) {
+        // Issued only when the previous advise succeeded so that a failure
+        // status from it is propagated instead of being overwritten.
+        status = ROCM_RESULT_TO_STATUS(
+            allocator->context->syms,
+            hipMemAdvise(device_ptr, allocation_size,
+                         hipMemAdviseSetCoarseGrain, allocator->device));
+      }
       if (iree_status_is_ok(status) &&
           allocator->supports_concurrent_managed_access) {
         // Prefetch the buffer on the GPU device.