Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[hip][cuda] Merged pending_queue_actions implementations. #18220

Merged
merged 9 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions runtime/src/iree/hal/drivers/cuda/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ iree_runtime_cc_library(
"nccl_channel.h",
"nop_executable_cache.c",
"nop_executable_cache.h",
"pending_queue_actions.c",
"pending_queue_actions.h",
"pipeline_layout.c",
"pipeline_layout.h",
"stream_command_buffer.c",
Expand Down Expand Up @@ -66,6 +64,7 @@ iree_runtime_cc_library(
"//runtime/src/iree/hal",
"//runtime/src/iree/hal/utils:collective_batch",
"//runtime/src/iree/hal/utils:deferred_command_buffer",
"//runtime/src/iree/hal/utils:deferred_work_queue",
"//runtime/src/iree/hal/utils:file_transfer",
"//runtime/src/iree/hal/utils:memory_file",
"//runtime/src/iree/hal/utils:resource_set",
Expand Down
3 changes: 1 addition & 2 deletions runtime/src/iree/hal/drivers/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ iree_cc_library(
"nccl_channel.h"
"nop_executable_cache.c"
"nop_executable_cache.h"
"pending_queue_actions.c"
"pending_queue_actions.h"
"pipeline_layout.c"
"pipeline_layout.h"
"stream_command_buffer.c"
Expand All @@ -63,6 +61,7 @@ iree_cc_library(
iree::hal
iree::hal::utils::collective_batch
iree::hal::utils::deferred_command_buffer
iree::hal::utils::deferred_work_queue
iree::hal::utils::file_transfer
iree::hal::utils::memory_file
iree::hal::utils::resource_set
Expand Down
229 changes: 210 additions & 19 deletions runtime/src/iree/hal/drivers/cuda/cuda_device.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
#include "iree/hal/drivers/cuda/nccl_channel.h"
#include "iree/hal/drivers/cuda/nccl_dynamic_symbols.h"
#include "iree/hal/drivers/cuda/nop_executable_cache.h"
#include "iree/hal/drivers/cuda/pending_queue_actions.h"
#include "iree/hal/drivers/cuda/pipeline_layout.h"
#include "iree/hal/drivers/cuda/stream_command_buffer.h"
#include "iree/hal/drivers/cuda/timepoint_pool.h"
#include "iree/hal/drivers/cuda/tracing.h"
#include "iree/hal/utils/deferred_command_buffer.h"
#include "iree/hal/utils/deferred_work_queue.h"
#include "iree/hal/utils/file_transfer.h"
#include "iree/hal/utils/memory_file.h"

Expand Down Expand Up @@ -76,7 +76,7 @@ typedef struct iree_hal_cuda_device_t {
// are met. It buffers submissions and allocations internally before they
// are ready. This queue couples with HAL semaphores backed by iree_event_t
// and CUevent objects.
iree_hal_cuda_pending_queue_actions_t* pending_queue_actions;
iree_hal_deferred_work_queue_t* work_queue;

// Device memory pools and allocators.
bool supports_memory_pools;
Expand All @@ -88,6 +88,154 @@ typedef struct iree_hal_cuda_device_t {
} iree_hal_cuda_device_t;

static const iree_hal_device_vtable_t iree_hal_cuda_device_vtable;
static const iree_hal_deferred_work_queue_device_interface_vtable_t
iree_hal_cuda_deferred_work_queue_device_interface_vtable;

// We put a CUEvent into a void*.
static_assert(sizeof(CUevent) <= sizeof(void*), "Unexpected event size");
typedef struct iree_hal_cuda_deferred_work_queue_device_interface_t {
iree_hal_deferred_work_queue_device_interface_t base;
iree_hal_device_t* device;
CUdevice cu_device;
CUcontext cu_context;
CUstream dispatch_cu_stream;
iree_allocator_t host_allocator;
const iree_hal_cuda_dynamic_symbols_t* cuda_symbols;
} iree_hal_cuda_deferred_work_queue_device_interface_t;

void iree_hal_cuda_deferred_work_queue_symbol_table_destroy(
iree_hal_deferred_work_queue_device_interface_t* symbol_table) {
iree_hal_cuda_deferred_work_queue_device_interface_t* table =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(symbol_table);
iree_allocator_free(table->host_allocator, table);
}

iree_status_t iree_hal_cuda_deferred_work_queue_symbol_table_bind_to_thread(
iree_hal_deferred_work_queue_device_interface_t* symbol_table) {
iree_hal_cuda_deferred_work_queue_device_interface_t* table =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(symbol_table);
return IREE_CURESULT_TO_STATUS(table->cuda_symbols,
cuCtxSetCurrent(table->cu_context),
"cuCtxSetCurrent");
}

iree_status_t iree_hal_cuda_deferred_work_queue_symbol_table_wait_native_event(
iree_hal_deferred_work_queue_device_interface_t* symbol_table,
void* event) {
iree_hal_cuda_deferred_work_queue_device_interface_t* table =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(symbol_table);
CUevent cu_event = (CUevent)(event);
return IREE_CURESULT_TO_STATUS(
table->cuda_symbols,
cuStreamWaitEvent(table->dispatch_cu_stream, cu_event,
CU_EVENT_WAIT_DEFAULT),
"cuStreamWaitEvent");
}

iree_status_t
iree_hal_cuda_deferred_work_queue_symbol_table_create_native_event(
iree_hal_deferred_work_queue_device_interface_t* symbol_table,
void** out_event) {
iree_hal_cuda_deferred_work_queue_device_interface_t* table =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(symbol_table);
CUevent* out = (CUevent*)(out_event);
return IREE_CURESULT_TO_STATUS(table->cuda_symbols,
cuEventCreate(out, CU_EVENT_WAIT_DEFAULT),
"cuEventCreate");
}
iree_status_t
iree_hal_cuda_deferred_work_queue_symbol_table_record_native_event(
iree_hal_deferred_work_queue_device_interface_t* symbol_table,
void* event) {
iree_hal_cuda_deferred_work_queue_device_interface_t* table =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(symbol_table);
CUevent cu_event = (CUevent)(event);
return IREE_CURESULT_TO_STATUS(
table->cuda_symbols, cuEventRecord(cu_event, table->dispatch_cu_stream),
"cuEventCreate");
}

iree_status_t
iree_hal_cuda_deferred_work_queue_symbol_table_synchronize_native_event(
iree_hal_deferred_work_queue_device_interface_t* symbol_table,
void* event) {
iree_hal_cuda_deferred_work_queue_device_interface_t* table =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(symbol_table);
CUevent cu_event = (CUevent)(event);
return IREE_CURESULT_TO_STATUS(table->cuda_symbols,
cuEventSynchronize(cu_event));
}
iree_status_t
iree_hal_cuda_deferred_work_queue_symbol_table_destroy_native_event(
iree_hal_deferred_work_queue_device_interface_t* symbol_table,
void* event) {
iree_hal_cuda_deferred_work_queue_device_interface_t* table =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(symbol_table);
CUevent cu_event = (CUevent)(event);
return IREE_CURESULT_TO_STATUS(table->cuda_symbols, cuEventDestroy(cu_event));
}

iree_status_t
iree_hal_cuda_deferred_work_queue_symbol_table_semaphore_acquire_timepoint_device_signal_native_event(
iree_hal_deferred_work_queue_device_interface_t* symbol_table,
struct iree_hal_semaphore_t* semaphore, uint64_t value, void** out_event) {
CUevent* out = (CUevent*)(out_event);
return iree_hal_cuda_event_semaphore_acquire_timepoint_device_signal(
semaphore, value, out);
}

bool iree_hal_cuda_deferred_work_queue_symbol_table_acquire_host_wait_event(
iree_hal_deferred_work_queue_device_interface_t* symbol_table,
struct iree_hal_semaphore_t* semaphore, uint64_t value, void** out_event) {
return iree_hal_cuda_semaphore_acquire_event_host_wait(
semaphore, value, (iree_hal_cuda_event_t**)out_event);
}

void iree_hal_cuda_deferred_work_queue_symbol_table_release_wait_event(
iree_hal_deferred_work_queue_device_interface_t* symbol_table,
void* wait_event) {
iree_hal_cuda_event_release(wait_event);
}

void* iree_hal_cuda_deferred_work_queue_symbol_table_native_event_from_wait_event(
iree_hal_deferred_work_queue_device_interface_t* symbol_table,
void* event) {
iree_hal_cuda_event_t* wait_event = (iree_hal_cuda_event_t*)event;
return iree_hal_cuda_event_handle(wait_event);
}

iree_status_t
iree_hal_cuda_deferred_work_queue_symbol_table_create_command_buffer_for_deferred(
iree_hal_deferred_work_queue_device_interface_t* symbol_table,
iree_hal_command_buffer_mode_t mode, iree_hal_command_category_t categories,
iree_hal_command_buffer_t** out) {
iree_hal_cuda_deferred_work_queue_device_interface_t* table =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(symbol_table);
return iree_hal_cuda_device_create_stream_command_buffer(table->device, mode,
categories, 0, out);
}

iree_status_t
iree_hal_cuda_deferred_work_queue_symbol_table_submit_command_buffer(
iree_hal_deferred_work_queue_device_interface_t* symbol_table,
iree_hal_command_buffer_t* command_buffer) {
iree_hal_cuda_deferred_work_queue_device_interface_t* table =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(symbol_table);
iree_status_t status = iree_ok_status();
if (iree_hal_cuda_stream_command_buffer_isa(command_buffer)) {
// Stream command buffer so nothing to do but notify it was submitted.
iree_hal_cuda_stream_notify_submitted_commands(command_buffer);
} else {
CUgraphExec exec =
iree_hal_cuda_graph_command_buffer_handle(command_buffer);
status = IREE_CURESULT_TO_STATUS(
table->cuda_symbols, cuGraphLaunch(exec, table->dispatch_cu_stream));
if (IREE_LIKELY(iree_status_is_ok(status))) {
iree_hal_cuda_graph_tracing_notify_submitted_commands(command_buffer);
}
}
return status;
}

static iree_hal_cuda_device_t* iree_hal_cuda_device_cast(
iree_hal_device_t* base_value) {
Expand Down Expand Up @@ -152,9 +300,27 @@ static iree_status_t iree_hal_cuda_device_create_internal(
device->dispatch_cu_stream = dispatch_stream;
device->host_allocator = host_allocator;

iree_status_t status = iree_hal_cuda_pending_queue_actions_create(
cuda_symbols, cu_device, context, &device->block_pool, host_allocator,
&device->pending_queue_actions);
iree_hal_cuda_deferred_work_queue_device_interface_t* symbol_table;
iree_status_t status = iree_allocator_malloc(
host_allocator,
sizeof(iree_hal_cuda_deferred_work_queue_device_interface_t),
(void**)&symbol_table);
if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
iree_hal_device_release((iree_hal_device_t*)device);
return status;
}
symbol_table->base._vtable =
&iree_hal_cuda_deferred_work_queue_device_interface_vtable;
symbol_table->cu_context = context;
symbol_table->cuda_symbols = cuda_symbols;
symbol_table->cu_device = cu_device;
symbol_table->device = (iree_hal_device_t*)device;
symbol_table->dispatch_cu_stream = dispatch_stream;
symbol_table->host_allocator = host_allocator;

status = iree_hal_deferred_work_queue_create(
(iree_hal_deferred_work_queue_device_interface_t*)symbol_table,
&device->block_pool, host_allocator, &device->work_queue);

// Enable tracing for the (currently only) stream - no-op if disabled.
if (iree_status_is_ok(status) && device->params.stream_tracing) {
Expand Down Expand Up @@ -297,8 +463,7 @@ static void iree_hal_cuda_device_destroy(iree_hal_device_t* base_device) {
IREE_TRACE_ZONE_BEGIN(z0);

// Destroy the pending workload queue.
iree_hal_cuda_pending_queue_actions_destroy(
(iree_hal_resource_t*)device->pending_queue_actions);
iree_hal_deferred_work_queue_destroy(device->work_queue);

// There should be no more buffers live that use the allocator.
iree_hal_allocator_release(device->device_allocator);
Expand Down Expand Up @@ -620,7 +785,7 @@ static iree_status_t iree_hal_cuda_device_create_semaphore(
iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
return iree_hal_cuda_event_semaphore_create(
initial_value, device->cuda_symbols, device->timepoint_pool,
device->pending_queue_actions, device->host_allocator, out_semaphore);
device->work_queue, device->host_allocator, out_semaphore);
}

static iree_hal_semaphore_compatibility_t
Expand Down Expand Up @@ -765,15 +930,13 @@ static iree_status_t iree_hal_cuda_device_queue_execute(
iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
IREE_TRACE_ZONE_BEGIN(z0);

iree_status_t status = iree_hal_cuda_pending_queue_actions_enqueue_execution(
base_device, device->dispatch_cu_stream, device->pending_queue_actions,
iree_hal_cuda_device_collect_tracing_context, device->tracing_context,
wait_semaphore_list, signal_semaphore_list, command_buffer_count,
command_buffers, binding_tables);
iree_status_t status = iree_hal_deferred_work_queue_enque(
device->work_queue, iree_hal_cuda_device_collect_tracing_context,
device->tracing_context, wait_semaphore_list, signal_semaphore_list,
command_buffer_count, command_buffers, binding_tables);
if (iree_status_is_ok(status)) {
// Try to advance the pending workload queue.
status = iree_hal_cuda_pending_queue_actions_issue(
device->pending_queue_actions);
// Try to advance the deferred work queue.
status = iree_hal_deferred_work_queue_issue(device->work_queue);
}

IREE_TRACE_ZONE_END(z0);
Expand All @@ -784,9 +947,8 @@ static iree_status_t iree_hal_cuda_device_queue_flush(
iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity) {
iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
IREE_TRACE_ZONE_BEGIN(z0);
// Try to advance the pending workload queue.
iree_status_t status =
iree_hal_cuda_pending_queue_actions_issue(device->pending_queue_actions);
// Try to advance the deferred work queue.
iree_status_t status = iree_hal_deferred_work_queue_issue(device->work_queue);
IREE_TRACE_ZONE_END(z0);
return status;
}
Expand Down Expand Up @@ -850,3 +1012,32 @@ static const iree_hal_device_vtable_t iree_hal_cuda_device_vtable = {
.profiling_flush = iree_hal_cuda_device_profiling_flush,
.profiling_end = iree_hal_cuda_device_profiling_end,
};

static const iree_hal_deferred_work_queue_device_interface_vtable_t
iree_hal_cuda_deferred_work_queue_device_interface_vtable = {
.destroy = iree_hal_cuda_deferred_work_queue_symbol_table_destroy,
.bind_to_thread =
iree_hal_cuda_deferred_work_queue_symbol_table_bind_to_thread,
.wait_native_event =
iree_hal_cuda_deferred_work_queue_symbol_table_wait_native_event,
.create_native_event =
iree_hal_cuda_deferred_work_queue_symbol_table_create_native_event,
.record_native_event =
iree_hal_cuda_deferred_work_queue_symbol_table_record_native_event,
.synchronize_native_event =
iree_hal_cuda_deferred_work_queue_symbol_table_synchronize_native_event,
.destroy_native_event =
iree_hal_cuda_deferred_work_queue_symbol_table_destroy_native_event,
.semaphore_acquire_timepoint_device_signal_native_event =
iree_hal_cuda_deferred_work_queue_symbol_table_semaphore_acquire_timepoint_device_signal_native_event,
.acquire_host_wait_event =
iree_hal_cuda_deferred_work_queue_symbol_table_acquire_host_wait_event,
.release_wait_event =
iree_hal_cuda_deferred_work_queue_symbol_table_release_wait_event,
.native_event_from_wait_event =
iree_hal_cuda_deferred_work_queue_symbol_table_native_event_from_wait_event,
.create_command_buffer_for_deferred =
iree_hal_cuda_deferred_work_queue_symbol_table_create_command_buffer_for_deferred,
.submit_command_buffer =
iree_hal_cuda_deferred_work_queue_symbol_table_submit_command_buffer,
};
22 changes: 11 additions & 11 deletions runtime/src/iree/hal/drivers/cuda/event_semaphore.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "iree/hal/drivers/cuda/cuda_dynamic_symbols.h"
#include "iree/hal/drivers/cuda/cuda_status_util.h"
#include "iree/hal/drivers/cuda/timepoint_pool.h"
#include "iree/hal/utils/deferred_work_queue.h"
#include "iree/hal/utils/semaphore_base.h"

typedef struct iree_hal_cuda_semaphore_t {
Expand All @@ -28,7 +29,7 @@ typedef struct iree_hal_cuda_semaphore_t {

// The list of pending queue actions that this semaphore need to advance on
// new signaled values.
iree_hal_cuda_pending_queue_actions_t* pending_queue_actions;
iree_hal_deferred_work_queue_t* work_queue;

// Guards value and status. We expect low contention on semaphores and since
// iree_slim_mutex_t is (effectively) just a CAS this keeps things simpler
Expand Down Expand Up @@ -57,11 +58,11 @@ static iree_hal_cuda_semaphore_t* iree_hal_cuda_semaphore_cast(
iree_status_t iree_hal_cuda_event_semaphore_create(
uint64_t initial_value, const iree_hal_cuda_dynamic_symbols_t* symbols,
iree_hal_cuda_timepoint_pool_t* timepoint_pool,
iree_hal_cuda_pending_queue_actions_t* pending_queue_actions,
iree_allocator_t host_allocator, iree_hal_semaphore_t** out_semaphore) {
iree_hal_deferred_work_queue_t* work_queue, iree_allocator_t host_allocator,
iree_hal_semaphore_t** out_semaphore) {
IREE_ASSERT_ARGUMENT(symbols);
IREE_ASSERT_ARGUMENT(timepoint_pool);
IREE_ASSERT_ARGUMENT(pending_queue_actions);
IREE_ASSERT_ARGUMENT(work_queue);
IREE_ASSERT_ARGUMENT(out_semaphore);
IREE_TRACE_ZONE_BEGIN(z0);

Expand All @@ -75,7 +76,7 @@ iree_status_t iree_hal_cuda_event_semaphore_create(
semaphore->host_allocator = host_allocator;
semaphore->symbols = symbols;
semaphore->timepoint_pool = timepoint_pool;
semaphore->pending_queue_actions = pending_queue_actions;
semaphore->work_queue = work_queue;
iree_slim_mutex_initialize(&semaphore->mutex);
semaphore->current_value = initial_value;
semaphore->failure_status = iree_ok_status();
Expand Down Expand Up @@ -149,10 +150,10 @@ static iree_status_t iree_hal_cuda_semaphore_signal(
// Notify timepoints - note that this must happen outside the lock.
iree_hal_semaphore_notify(&semaphore->base, new_value, IREE_STATUS_OK);

// Advance the pending queue actions if possible. This also must happen
// Advance the deferred work queue if possible. This also must happen
// outside the lock to avoid nesting.
iree_status_t status = iree_hal_cuda_pending_queue_actions_issue(
semaphore->pending_queue_actions);
iree_status_t status =
iree_hal_deferred_work_queue_issue(semaphore->work_queue);

IREE_TRACE_ZONE_END(z0);
return status;
Expand Down Expand Up @@ -188,10 +189,9 @@ static void iree_hal_cuda_semaphore_fail(iree_hal_semaphore_t* base_semaphore,
iree_hal_semaphore_notify(&semaphore->base, IREE_HAL_SEMAPHORE_FAILURE_VALUE,
status_code);

// Advance the pending queue actions if possible. This also must happen
// Advance the deferred work queue if possible. This also must happen
// outside the lock to avoid nesting.
status = iree_hal_cuda_pending_queue_actions_issue(
semaphore->pending_queue_actions);
status = iree_hal_deferred_work_queue_issue(semaphore->work_queue);
iree_status_ignore(status);

IREE_TRACE_ZONE_END(z0);
Expand Down
Loading
Loading