Skip to content

Commit

Permalink
[experimental][ROCM] Add shared memory support on ROCM RT and Target. (
Browse files Browse the repository at this point in the history
…iree-org#15097)

Currently ToM IREE ROCM compiler path is expecting shared memory.
However, since our runtime is not supporting it this can and will cause
correctness issues. We fix this issue by adding shared memory support on
runtime and target side.
  • Loading branch information
raikonenfnu authored Oct 4, 2023
1 parent b3e5a43 commit ad64ecc
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 6 deletions.
10 changes: 10 additions & 0 deletions compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ class ROCMTargetBackend final : public TargetBackend {
exportOps[op.getSymName()] = op;
}
std::vector<std::array<int32_t, 3>> workgroupSizes;
SmallVector<uint32_t> workgroupLocalMemories;
for (auto func : innerModuleOp.getOps<LLVM::LLVMFuncOp>()) {
int32_t flatWgSize = 1;
auto *llvmFunc = llvmModule->getFunction(func.getName());
Expand All @@ -166,6 +167,11 @@ class ROCMTargetBackend final : public TargetBackend {
workgroupSize = {1, 1, 1};
}
workgroupSizes.push_back(workgroupSize);
uint32_t workgroupLocalMemory = 0;
if (auto workgroupLocalMemoryAttr = exportOp.getWorkgroupLocalMemory()) {
workgroupLocalMemory = workgroupLocalMemoryAttr->getSExtValue();
}
workgroupLocalMemories.push_back(workgroupLocalMemory);
// For GPU kernels,
// 1. Insert AMDGPU_KERNEL calling convention.
// 2. Insert amdgpu-flat-workgroup-size(1, 256) attribute.
Expand Down Expand Up @@ -230,10 +236,14 @@ class ROCMTargetBackend final : public TargetBackend {
builder, (*blockSizes)[0], (*blockSizes)[1], (*blockSizes)[2]);
++blockSizes;
}
auto workgroupLocalMemoriesRef =
builder.createInt32Vec(workgroupLocalMemories);
auto blockSizesRef = iree_hal_rocm_BlockSizeDef_vec_end(builder);

iree_hal_rocm_ExecutableDef_entry_points_add(builder, entryPointsRef);
iree_hal_rocm_ExecutableDef_block_sizes_add(builder, blockSizesRef);
iree_hal_rocm_ExecutableDef_shared_memory_sizes_add(
builder, workgroupLocalMemoriesRef);
iree_hal_rocm_ExecutableDef_hsaco_image_add(builder, hsacoRef);
iree_hal_rocm_ExecutableDef_end_as_root(builder);

Expand Down
1 change: 1 addition & 0 deletions experimental/rocm/context_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
// Structure to wrap all objects constant within a context. This makes it
// simpler to pass it to the different objects and saves memory.
typedef struct iree_hal_rocm_context_wrapper_t {
hipDevice_t rocm_device;
hipCtx_t rocm_context;
iree_allocator_t host_allocator;
iree_hal_rocm_dynamic_symbols_t *syms;
Expand Down
12 changes: 6 additions & 6 deletions experimental/rocm/direct_command_buffer.c
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ static iree_status_t iree_hal_rocm_direct_command_buffer_fill_buffer(
ROCM_RETURN_IF_ERROR(
command_buffer->context->syms,
hipMemsetD8Async(dst, *(const uint8_t*)(pattern), num_elements, 0),
"hipMemsetD*Async");
"hipMemsetD8Async");
break;
}
default:
Expand Down Expand Up @@ -369,11 +369,11 @@ static iree_status_t iree_hal_rocm_direct_command_buffer_dispatch(
// access proper stream from command buffer
ROCM_RETURN_IF_ERROR(
command_buffer->context->syms,
hipModuleLaunchKernel(kernel_params.function, workgroup_x, workgroup_y,
workgroup_z, kernel_params.block_size[0],
kernel_params.block_size[1],
kernel_params.block_size[2], 0, 0,
command_buffer->current_descriptor, NULL),
hipModuleLaunchKernel(
kernel_params.function, workgroup_x, workgroup_y, workgroup_z,
kernel_params.block_size[0], kernel_params.block_size[1],
kernel_params.block_size[2], kernel_params.shared_memory_size, 0,
command_buffer->current_descriptor, NULL),
"hipModuleLaunchKernel");

IREE_ROCM_TRACE_ZONE_END(command_buffer->tracing_context, 0);
Expand Down
2 changes: 2 additions & 0 deletions experimental/rocm/dynamic_symbol_tables.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,5 @@ RC_PFN_DECL(hipEventElapsedTime, float *, hipEvent_t, hipEvent_t)
RC_PFN_DECL(hipEventQuery, hipEvent_t)
RC_PFN_DECL(hipEventRecord, hipEvent_t, hipStream_t)
RC_PFN_DECL(hipEventSynchronize, hipEvent_t)
RC_PFN_DECL(hipDeviceGetAttribute, int *, hipDeviceAttribute_t, int)
RC_PFN_DECL(hipFuncSetAttribute, const void *, hipFuncAttribute, int)
28 changes: 28 additions & 0 deletions experimental/rocm/native_executable.c
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ iree_status_t iree_hal_rocm_native_executable_create(
iree_hal_rocm_ExecutableDef_entry_points_get(executable_def);
iree_hal_rocm_BlockSizeDef_vec_t block_sizes_vec =
iree_hal_rocm_ExecutableDef_block_sizes_get(executable_def);
flatbuffers_uint32_vec_t shared_memory_sizes =
iree_hal_rocm_ExecutableDef_shared_memory_sizes_get(executable_def);
iree_host_size_t entry_count = flatbuffers_string_vec_len(entry_points_vec);

// Calculate the total number of characters across all entry point names. This
Expand Down Expand Up @@ -94,6 +96,17 @@ iree_status_t iree_hal_rocm_native_executable_create(
"hipModuleLoadDataEx");
}

// Query allowed max shared memory.
int32_t max_shared_mem = 0;
if (iree_status_is_ok(status)) {
status = ROCM_RESULT_TO_STATUS(
context->syms,
hipDeviceGetAttribute(&max_shared_mem,
hipDeviceAttributeMaxSharedMemoryPerBlock,
context->rocm_device),
"hipDeviceGetAttribute");
}

if (iree_status_is_ok(status)) {
executable->entry_count = entry_count;
for (iree_host_size_t i = 0; i < entry_count; i++) {
Expand All @@ -111,6 +124,20 @@ iree_status_t iree_hal_rocm_native_executable_create(
entry_name);
break;
}
if (shared_memory_sizes[i] > max_shared_mem) {
status =
iree_make_status(IREE_STATUS_INTERNAL,
"ROCM driver error: Requested shared memory "
"size of %d larger than allowed size of %d",
shared_memory_sizes[i], max_shared_mem);
} else if (shared_memory_sizes[i] != 0) {
status = ROCM_RESULT_TO_STATUS(
context->syms,
hipFuncSetAttribute(
function, HIP_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_memory_sizes[i]),
"hipFuncSetAttribute");
}
// Package required parameters for kernel launches for each entry point.
iree_hal_rocm_kernel_params_t* params = &executable->entry_points[i];
params->layout = executable_params->pipeline_layouts[i];
Expand All @@ -119,6 +146,7 @@ iree_status_t iree_hal_rocm_native_executable_create(
params->block_size[0] = block_sizes_vec[i].x;
params->block_size[1] = block_sizes_vec[i].y;
params->block_size[2] = block_sizes_vec[i].z;
params->shared_memory_size = shared_memory_sizes[i];
// Stash the entry point name in the string table for use when tracing.
IREE_TRACE({
iree_host_size_t entry_name_length =
Expand Down
1 change: 1 addition & 0 deletions experimental/rocm/native_executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ typedef struct iree_hal_rocm_kernel_params_t {
iree_hal_pipeline_layout_t* layout;
hipFunction_t function;
uint32_t block_size[3];
uint32_t shared_memory_size;
IREE_TRACE(iree_string_view_t function_name;)
IREE_TRACE(iree_string_view_t source_filename;)
IREE_TRACE(uint32_t source_line;)
Expand Down
1 change: 1 addition & 0 deletions experimental/rocm/rocm_device.c
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ static iree_status_t iree_hal_rocm_device_create_internal(
device->device = rocm_device;
device->stream = stream;
device->context_wrapper.rocm_context = context;
device->context_wrapper.rocm_device = rocm_device;
device->context_wrapper.host_allocator = host_allocator;
device->context_wrapper.syms = syms;
// Enable tracing for the (currently only) stream - no-op if disabled.
Expand Down
3 changes: 3 additions & 0 deletions runtime/src/iree/schemas/rocm_executable_def.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ table ExecutableDef {
//
block_sizes:[BlockSizeDef];

// Size of dynamic shared memory.
shared_memory_sizes:[uint32];

// HSACO string of the module.
hsaco_image:string;
}
Expand Down

0 comments on commit ad64ecc

Please sign in to comment.