From 69129e5f00634758794c56fd8848e45152e29ab5 Mon Sep 17 00:00:00 2001 From: CodeSlinger Date: Tue, 30 Jan 2024 09:03:20 -0600 Subject: [PATCH] Fix clock overflow issue reported by Alex Wells. Remove unnecessary condition checks and unused variables (#56) * Added "Profile PyTorch" section in README.md * Fix clock overflow issue reported by Alex Wells. Remove unnecessary condition checks and unused variables * Remove unused argument 'result' in OnEnter*() calls * Remove dead code --- .../unitrace/scripts/gen_tracing_callbacks.py | 4 +- tools/unitrace/src/levelzero/ze_collector.h | 322 +++++++++--------- 2 files changed, 164 insertions(+), 162 deletions(-) diff --git a/tools/unitrace/scripts/gen_tracing_callbacks.py b/tools/unitrace/scripts/gen_tracing_callbacks.py index f64509b..f15c4e4 100644 --- a/tools/unitrace/scripts/gen_tracing_callbacks.py +++ b/tools/unitrace/scripts/gen_tracing_callbacks.py @@ -360,7 +360,7 @@ def gen_enter_callback(f, func, command_list_func_list, command_queue_func_list, if (cb != ""): f.write(" if (collector->options_.kernel_tracing) { \n") if (func in synchronize_func_list): - f.write(" " + cb + "(params, result, global_user_data, instance_user_data, &kids); \n") + f.write(" " + cb + "(params, global_user_data, instance_user_data, &kids); \n") f.write(" if (kids.size() != 0) {\n") f.write(" ze_instance_data.kid = kids[0];\n") # pass kid to the exit callback f.write(" }\n") @@ -368,7 +368,7 @@ def gen_enter_callback(f, func, command_list_func_list, command_queue_func_list, f.write(" ze_instance_data.kid = (uint64_t)(-1);\n") f.write(" }\n") else: - f.write(" " + cb + "(params, result, global_user_data, instance_user_data); \n") + f.write(" " + cb + "(params, global_user_data, instance_user_data); \n") f.write(" }\n") f.write("\n") f.write(" PTI_ASSERT(collector->correlator_ != nullptr);\n") diff --git a/tools/unitrace/src/levelzero/ze_collector.h b/tools/unitrace/src/levelzero/ze_collector.h index ee08782..de76962 100644 --- a/tools/unitrace/src/levelzero/ze_collector.h +++ b/tools/unitrace/src/levelzero/ze_collector.h @@ -180,8 +180,9 @@ struct ZeMetricQueryPools { }; struct ZeInstanceData { - uint64_t start_time_host; - uint64_t end_time_host; + uint64_t start_time_host; // in ns + uint64_t timestamp_host; // in ns + uint64_t timestamp_device; // in ticks uint64_t kid; // passing kid from enter callback to exit callback }; @@ -680,7 +681,7 @@ struct ZeCommand { uint64_t metric_timer_mask_; uint64_t append_time_ = 0; uint64_t submit_time_ = 0; //in ns - uint64_t submit_time_device_ = 0; //in ns + uint64_t submit_time_device_ = 0; //in ticks ze_command_list_handle_t command_list_ = nullptr; ze_command_queue_handle_t queue_ = nullptr; ze_fence_handle_t fence_; @@ -703,8 +704,10 @@ std::set *global_device_submissions_ = nullptr; struct ZeDeviceSubmissions { std::list commands_submitted_; + std::list commands_staged_; std::list commands_free_pool_; std::list metric_queries_submitted_; + std::list metric_queries_staged_; std::list metric_queries_free_pool_; std::map device_time_stats_; std::map host_time_stats_; @@ -754,6 +757,10 @@ struct ZeDeviceSubmissions { } } + inline void StageKernelCommand(ZeCommand *command) { + commands_staged_.push_back(command); + } + inline ZeCommand *GetKernelCommand(void) { ZeCommand *command; @@ -778,6 +785,52 @@ struct ZeDeviceSubmissions { } } + inline void StageCommandMetricQuery(ZeCommandMetricQuery *query) { + metric_queries_staged_.push_back(query); + } + + inline void SubmitStagedKernelCommandAndMetricQueries(ZeEventCache& event_cache, std::vector *kids) { + auto cit = commands_staged_.begin(); + auto mit = metric_queries_staged_.begin(); + for (; cit != commands_staged_.end(); cit++, mit++) { + ZeCommand *cmd = *cit; + ZeCommandMetricQuery *cmd_query = *mit; + + // back fill kernel instance id and reset event + cmd->instance_id_ = UniKernelInstanceId::GetKernelInstanceId(); + event_cache.ResetEvent(cmd->event_); + + if (kids) { + kids->push_back(cmd->instance_id_); + } + SubmitKernelCommand(cmd); + + if (cmd_query != nullptr) { + cmd_query->instance_id_ = cmd->instance_id_; + event_cache.ResetEvent(cmd_query->metric_query_event_); + SubmitCommandMetricQuery(cmd_query); + } + } + commands_staged_.clear(); + metric_queries_staged_.clear(); + } + + inline void RevertStagedKernelCommandAndMetricQueries(void) { + auto cit = commands_staged_.begin(); + auto mit = metric_queries_staged_.begin(); + for (; cit != commands_staged_.end(); cit++, mit++) { + ZeCommand *cmd = *cit; + ZeCommandMetricQuery *cmd_query = *mit; + + commands_free_pool_.push_back(cmd); + if (cmd_query != nullptr) { + metric_queries_free_pool_.push_back(cmd_query); + } + } + commands_staged_.clear(); + metric_queries_staged_.clear(); + } + inline ZeCommandMetricQuery *GetCommandMetricQuery(void) { ZeCommandMetricQuery *query; @@ -2021,7 +2074,7 @@ class ZeCollector { if (start <= end) { duration = (end - start) * static_cast(NSEC_IN_SEC) / freq; } else { // Timer Overflow - duration = ((mask + 1ull) + end - start) * static_cast(NSEC_IN_SEC) / freq; + duration = (mask - start + 1 + end) * static_cast(NSEC_IN_SEC) / freq; } return duration; } @@ -2029,32 +2082,23 @@ class ZeCollector { inline void GetHostTime(const ZeCommand *command, const ze_kernel_timestamp_result_t& ts, uint64_t& start, uint64_t& end) { uint64_t device_freq = command->device_timer_frequency_; uint64_t device_mask = command->device_timer_mask_; - uint64_t tspan = (device_mask + 1ull) * NSEC_IN_SEC / device_freq; // time span of universe uint64_t device_start = ts.global.kernelStart & device_mask; uint64_t device_end = ts.global.kernelEnd & device_mask; - uint64_t start_ns = (device_start * NSEC_IN_SEC / device_freq); - uint64_t device_submit_time = command->submit_time_device_; + uint64_t device_submit_time = (command->submit_time_device_ & device_mask); - int64_t time_shift = start_ns - device_submit_time; + uint64_t time_shift; - if (start_ns > device_submit_time) { - uint64_t diff = device_submit_time + tspan - start_ns; - if (diff < time_shift) { - // overflow - time_shift = -diff; - } + if (device_start > device_submit_time) { + time_shift = (device_start - device_submit_time) * NSEC_IN_SEC / device_freq; } else { - uint64_t diff = start_ns + tspan - device_submit_time; - if (diff < (-time_shift)) { - // overflow - time_shift = diff; - } + // overflow + time_shift = (device_mask - device_submit_time + 1 + device_start) * NSEC_IN_SEC / device_freq; } - int64_t duration = ComputeDuration(device_start, device_end, device_freq, device_mask); + uint64_t duration = ComputeDuration(device_start, device_end, device_freq, device_mask); start = command->submit_time_ + time_shift; end = start + duration; @@ -2311,14 +2355,8 @@ class ZeCollector { command_lists_mutex_.unlock(); } - void ExecuteCommandLists( - ze_command_list_handle_t *cmdlists, uint32_t count, - ze_command_queue_handle_t queue, ze_fence_handle_t fence, const uint64_t ts, - std::vector *kids) { - - if (local_device_submissions_.IsFinalized()) { - return; - } + void PrepareToExecuteCommandLists( + ze_command_list_handle_t *cmdlists, uint32_t count, ze_command_queue_handle_t queue, ze_fence_handle_t fence) { command_queues_mutex_.lock_shared(); auto qit = command_queues_.find(queue); @@ -2327,9 +2365,9 @@ class ZeCollector { for (uint32_t i = 0; i < count; i++) { ze_command_list_handle_t cmdlist = cmdlists[i]; - + auto it = command_lists_.find(cmdlist); - + if (it == command_lists_.end()) { std::cerr << "[ERROR] Command list (" << cmdlist << ") is not found to execute." << std::endl; continue; @@ -2339,7 +2377,7 @@ class ZeCollector { for (auto command : it->second->commands_) { ZeCommand *cmd = nullptr; ZeCommandMetricQuery *cmd_query = nullptr; - + cmd = local_device_submissions_.GetKernelCommand(); if (command->command_metric_query_ != nullptr) { @@ -2348,35 +2386,28 @@ class ZeCollector { *cmd = *command; uint64_t host_timestamp; - uint64_t ticks; uint64_t device_timestamp; ze_result_t status; - status = zeDeviceGetGlobalTimestamps(cmd->device_, &host_timestamp, &ticks); + status = zeDeviceGetGlobalTimestamps(cmd->device_, &host_timestamp, &device_timestamp); PTI_ASSERT(status == ZE_RESULT_SUCCESS); - device_timestamp = ticks & cmd->device_timer_mask_; - device_timestamp = device_timestamp * NSEC_IN_SEC / cmd->device_timer_frequency_; - cmd->engine_ordinal_ = qit->second.engine_ordinal_; cmd->engine_index_ = qit->second.engine_index_; - cmd->instance_id_ = UniKernelInstanceId::GetKernelInstanceId(); cmd->submit_time_ = host_timestamp; //in ns - cmd->submit_time_device_ = device_timestamp;//in ns + cmd->submit_time_device_ = device_timestamp; //in ticks cmd->tid_ = utils::GetTid();; cmd->fence_ = fence; - event_cache_.ResetEvent(cmd->event_); - local_device_submissions_.SubmitKernelCommand(cmd); - - if (kids) { - kids->push_back(cmd->instance_id_); - } + // Exit callback will reset cmd->event_ and backfill cmd->instance_id_ + local_device_submissions_.StageKernelCommand(cmd); if (cmd_query) { *cmd_query = *(command->command_metric_query_); - cmd_query->instance_id_ = cmd->instance_id_; - event_cache_.ResetEvent(cmd_query->metric_query_event_); - local_device_submissions_.SubmitCommandMetricQuery(cmd_query); + // Exit callback will reset cmd_query->metric_query_event_ and backfill cmd_query->instance_id_ + local_device_submissions_.StageCommandMetricQuery(cmd_query); + } + else { + local_device_submissions_.StageCommandMetricQuery(nullptr); } } } @@ -2419,7 +2450,7 @@ class ZeCollector { private: // Callbacks - static void OnEnterEventPoolCreate(ze_event_pool_create_params_t *params, ze_result_t result, void *global_data, void **instance_data) { + static void OnEnterEventPoolCreate(ze_event_pool_create_params_t *params, void *global_data, void **instance_data) { const ze_event_pool_desc_t* desc = *(params->pdesc); if (desc == nullptr) { return; @@ -2454,7 +2485,7 @@ class ZeCollector { static void OnEnterEventDestroy( ze_event_destroy_params_t *params, - ze_result_t result, void *global_data, void **instance_data, std::vector *kids) { + void *global_data, void **instance_data, std::vector *kids) { if (*(params->phEvent) != nullptr) { ZeCollector* collector = reinterpret_cast(global_data); @@ -2465,7 +2496,7 @@ class ZeCollector { } static void OnEnterEventHostReset( - ze_event_host_reset_params_t *params, ze_result_t result, + ze_event_host_reset_params_t *params, void *global_data, void **instance_data, std::vector *kids) { if (*(params->phEvent) != nullptr) { ZeCollector* collector = reinterpret_cast(global_data); @@ -2584,6 +2615,7 @@ class ZeCollector { PTI_ASSERT(signal_event != nullptr); } + ze_result_t status; zet_metric_query_handle_t query = nullptr; if (collector->options_.metric_query && iskernel) { devices_mutex_.lock_shared(); @@ -2595,10 +2627,19 @@ class ZeCollector { devices_mutex_.unlock_shared(); - ze_result_t status = zetCommandListAppendMetricQueryBegin(command_list, query); + status = zetCommandListAppendMetricQueryBegin(command_list, query); PTI_ASSERT(status == ZE_RESULT_SUCCESS); } + uint64_t host_timestamp; + uint64_t device_timestamp; // in ticks + + status = zeDeviceGetGlobalTimestamps(device, &host_timestamp, &device_timestamp); + PTI_ASSERT(status == ZE_RESULT_SUCCESS); + + ze_instance_data.timestamp_host = host_timestamp; + ze_instance_data.timestamp_device = device_timestamp; + return query; } @@ -2677,23 +2718,13 @@ class ZeCollector { desc_query->device_ = it->second->device_; } + uint64_t host_timestamp = ze_instance_data.timestamp_host; if (it->second->immediate_) { - uint64_t host_timestamp; - uint64_t ticks; - uint64_t device_timestamp; - ze_result_t status; - - status = zeDeviceGetGlobalTimestamps(desc->device_, &host_timestamp, &ticks); - PTI_ASSERT(status == ZE_RESULT_SUCCESS); - - device_timestamp = ticks & desc->device_timer_mask_; - device_timestamp = device_timestamp * NSEC_IN_SEC / desc->device_timer_frequency_; - desc->immediate_ = true; desc->instance_id_ = UniKernelInstanceId::GetKernelInstanceId(); desc->append_time_ = host_timestamp; desc->submit_time_ = host_timestamp; - desc->submit_time_device_ = device_timestamp; + desc->submit_time_device_ = ze_instance_data.timestamp_device; // append time and submit time are the same desc->command_metric_query_ = nullptr; // don't care metric query in case of immediate command list local_device_submissions_.SubmitKernelCommand(desc); kids->push_back(desc->instance_id_); @@ -2705,8 +2736,6 @@ class ZeCollector { } } else { - uint64_t host_timestamp = ze_instance_data.start_time_host; - desc->append_time_ = host_timestamp; desc->immediate_ = false; desc->command_metric_query_ = desc_query; // need metric query upon submission @@ -2799,23 +2828,13 @@ class ZeCollector { desc_query->device_ = it->second->device_; } + uint64_t host_timestamp = ze_instance_data.timestamp_host; if (it->second->immediate_) { - uint64_t host_timestamp; - uint64_t ticks; - uint64_t device_timestamp; - ze_result_t status; - - status = zeDeviceGetGlobalTimestamps(desc->device_, &host_timestamp, &ticks); - PTI_ASSERT(status == ZE_RESULT_SUCCESS); - - device_timestamp = ticks & desc->device_timer_mask_; - device_timestamp = device_timestamp * NSEC_IN_SEC / desc->device_timer_frequency_; - desc->immediate_ = true; desc->instance_id_ = UniKernelInstanceId::GetKernelInstanceId(); desc->append_time_ = host_timestamp; desc->submit_time_ = host_timestamp; - desc->submit_time_device_ = device_timestamp; + desc->submit_time_device_ = ze_instance_data.timestamp_device; // append time and submit time are the same desc->command_metric_query_ = nullptr; // do not care metric query in case of immediate command list local_device_submissions_.SubmitKernelCommand(desc); @@ -2829,8 +2848,6 @@ class ZeCollector { } } else { - uint64_t host_timestamp = ze_instance_data.start_time_host; - desc->append_time_ = host_timestamp; desc->immediate_ = false; desc->command_metric_query_ = desc_query; @@ -2918,23 +2935,14 @@ class ZeCollector { desc_query->metric_query_ = query; desc_query->metric_query_event_ = metric_query_event; } - if (it->second->immediate_) { - uint64_t host_timestamp; - uint64_t ticks; - uint64_t device_timestamp; - ze_result_t status; - - status = zeDeviceGetGlobalTimestamps(desc->device_, &host_timestamp, &ticks); - PTI_ASSERT(status == ZE_RESULT_SUCCESS); - - device_timestamp = ticks & desc->device_timer_mask_; - device_timestamp = device_timestamp * NSEC_IN_SEC / desc->device_timer_frequency_; + uint64_t host_timestamp = ze_instance_data.timestamp_host; + if (it->second->immediate_) { desc->immediate_ = true; desc->instance_id_ = UniKernelInstanceId::GetKernelInstanceId(); desc->append_time_ = host_timestamp; desc->submit_time_ = host_timestamp; - desc->submit_time_device_ = device_timestamp; + desc->submit_time_device_ = ze_instance_data.timestamp_device; // append time and submit time are the same desc->command_metric_query_ = nullptr; local_device_submissions_.SubmitKernelCommand(desc); @@ -2947,8 +2955,6 @@ class ZeCollector { } } else { - uint64_t host_timestamp = ze_instance_data.start_time_host; - desc->append_time_ = host_timestamp; desc->immediate_ = false; desc->command_metric_query_ = desc_query; @@ -3040,23 +3046,14 @@ class ZeCollector { desc_query->metric_query_ = query; desc_query->metric_query_event_ = metric_query_event; } - if (it->second->immediate_) { - uint64_t host_timestamp; - uint64_t ticks; - uint64_t device_timestamp; - ze_result_t status; - - status = zeDeviceGetGlobalTimestamps(desc->device_, &host_timestamp, &ticks); - PTI_ASSERT(status == ZE_RESULT_SUCCESS); - - device_timestamp = ticks & desc->device_timer_mask_; - device_timestamp = device_timestamp * NSEC_IN_SEC / desc->device_timer_frequency_; + uint64_t host_timestamp = ze_instance_data.timestamp_host; + if (it->second->immediate_) { desc->immediate_ = true; desc->instance_id_ = UniKernelInstanceId::GetKernelInstanceId(); desc->append_time_ = host_timestamp; desc->submit_time_ = host_timestamp; - desc->submit_time_device_ = device_timestamp; + desc->submit_time_device_ = ze_instance_data.timestamp_device; // append time and submit time are the same desc->command_metric_query_ = nullptr; local_device_submissions_.SubmitKernelCommand(desc); @@ -3069,8 +3066,6 @@ class ZeCollector { } } else { - uint64_t host_timestamp = ze_instance_data.start_time_host; - desc->append_time_ = host_timestamp; desc->immediate_ = false; desc->command_metric_query_ = desc_query; @@ -3147,23 +3142,14 @@ class ZeCollector { ze_result_t status = zetCommandListAppendMetricQueryEnd(command_list, query, metric_query_event, 0, nullptr); PTI_ASSERT(status == ZE_RESULT_SUCCESS); } - if (it->second->immediate_) { - uint64_t host_timestamp; - uint64_t ticks; - uint64_t device_timestamp; - ze_result_t status; - - status = zeDeviceGetGlobalTimestamps(desc->device_, &host_timestamp, &ticks); - PTI_ASSERT(status == ZE_RESULT_SUCCESS); - - device_timestamp = ticks & desc->device_timer_mask_; - device_timestamp = device_timestamp * NSEC_IN_SEC / desc->device_timer_frequency_; + uint64_t host_timestamp = ze_instance_data.timestamp_host; + if (it->second->immediate_) { desc->immediate_ = true; desc->instance_id_ = UniKernelInstanceId::GetKernelInstanceId(); desc->append_time_ = host_timestamp; desc->submit_time_ = host_timestamp; - desc->submit_time_device_ = device_timestamp; + desc->submit_time_device_ = ze_instance_data.timestamp_device; // append time and submit time are the same desc->command_metric_query_ = nullptr; local_device_submissions_.SubmitKernelCommand(desc); @@ -3175,8 +3161,6 @@ class ZeCollector { } } else { - uint64_t host_timestamp = ze_instance_data.start_time_host; - desc->append_time_ = host_timestamp; desc->immediate_ = false; desc->command_metric_query_ = desc_query; @@ -3196,8 +3180,8 @@ class ZeCollector { static void OnEnterCommandListAppendLaunchKernel( ze_command_list_append_launch_kernel_params_t* params, - ze_result_t result, void* global_data, void** instance_data) { - if ((result == ZE_RESULT_SUCCESS) && (UniController::IsCollectionEnabled())) { + void* global_data, void** instance_data) { + if (UniController::IsCollectionEnabled()) { ZeCollector* collector = reinterpret_cast(global_data); zet_metric_query_handle_t query = PrepareToAppendKernelCommand(collector, *(params->phSignalEvent), *(params->phCommandList), true); *instance_data = reinterpret_cast(query); @@ -3231,8 +3215,8 @@ class ZeCollector { static void OnEnterCommandListAppendLaunchCooperativeKernel( ze_command_list_append_launch_cooperative_kernel_params_t* params, - ze_result_t result, void* global_data, void** instance_data) { - if ((result == ZE_RESULT_SUCCESS) && (UniController::IsCollectionEnabled())) { + void* global_data, void** instance_data) { + if (UniController::IsCollectionEnabled()) { ZeCollector* collector = reinterpret_cast(global_data); zet_metric_query_handle_t query = PrepareToAppendKernelCommand(collector, *(params->phSignalEvent), *(params->phCommandList), true); *instance_data = reinterpret_cast(query); @@ -3265,8 +3249,8 @@ class ZeCollector { static void OnEnterCommandListAppendLaunchKernelIndirect( ze_command_list_append_launch_kernel_indirect_params_t* params, - ze_result_t result, void* global_data, void** instance_data) { - if ((result == ZE_RESULT_SUCCESS) && (UniController::IsCollectionEnabled())) { + void* global_data, void** instance_data) { + if (UniController::IsCollectionEnabled()) { ZeCollector* collector = reinterpret_cast(global_data); zet_metric_query_handle_t query = PrepareToAppendKernelCommand(collector, *(params->phSignalEvent), *(params->phCommandList), true); *instance_data = reinterpret_cast(query); @@ -3387,8 +3371,8 @@ class ZeCollector { static void OnEnterCommandListAppendMemoryCopy( ze_command_list_append_memory_copy_params_t* params, - ze_result_t result, void* global_data, void** instance_data) { - if ((result == ZE_RESULT_SUCCESS) && (UniController::IsCollectionEnabled())) { + void* global_data, void** instance_data) { + if (UniController::IsCollectionEnabled()) { ZeCollector* collector = reinterpret_cast(global_data); zet_metric_query_handle_t query = PrepareToAppendKernelCommand(collector, *(params->phSignalEvent), *(params->phCommandList), false); *instance_data = reinterpret_cast(query); @@ -3416,8 +3400,8 @@ class ZeCollector { static void OnEnterCommandListAppendMemoryFill( ze_command_list_append_memory_fill_params_t* params, - ze_result_t result, void* global_data, void** instance_data) { - if ((result == ZE_RESULT_SUCCESS) && (UniController::IsCollectionEnabled())) { + void* global_data, void** instance_data) { + if (UniController::IsCollectionEnabled()) { ZeCollector* collector = reinterpret_cast(global_data); zet_metric_query_handle_t query = PrepareToAppendKernelCommand(collector, *(params->phSignalEvent), *(params->phCommandList), false); *instance_data = reinterpret_cast(query); @@ -3445,8 +3429,8 @@ class ZeCollector { static void OnEnterCommandListAppendBarrier( ze_command_list_append_barrier_params_t* params, - ze_result_t result, void* global_data, void** instance_data) { - if ((result == ZE_RESULT_SUCCESS) && (UniController::IsCollectionEnabled())) { + void* global_data, void** instance_data) { + if (UniController::IsCollectionEnabled()) { ZeCollector* collector = reinterpret_cast(global_data); zet_metric_query_handle_t query = PrepareToAppendKernelCommand(collector, *(params->phSignalEvent), *(params->phCommandList), false); *instance_data = reinterpret_cast(query); @@ -3473,8 +3457,8 @@ class ZeCollector { static void OnEnterCommandListAppendMemoryRangesBarrier( ze_command_list_append_memory_ranges_barrier_params_t* params, - ze_result_t result, void* global_data, void** instance_data) { - if ((result == ZE_RESULT_SUCCESS) && (UniController::IsCollectionEnabled())) { + void* global_data, void** instance_data) { + if (UniController::IsCollectionEnabled()) { ZeCollector* collector = reinterpret_cast(global_data); zet_metric_query_handle_t query = PrepareToAppendKernelCommand(collector, *(params->phSignalEvent), *(params->phCommandList), false); *instance_data = reinterpret_cast(query); @@ -3501,8 +3485,8 @@ class ZeCollector { static void OnEnterCommandListAppendMemoryCopyRegion( ze_command_list_append_memory_copy_region_params_t* params, - ze_result_t result, void* global_data, void** instance_data) { - if ((result == ZE_RESULT_SUCCESS) && (UniController::IsCollectionEnabled())) { + void* global_data, void** instance_data) { + if (UniController::IsCollectionEnabled()) { ZeCollector* collector = reinterpret_cast(global_data); zet_metric_query_handle_t query = PrepareToAppendKernelCommand(collector, *(params->phSignalEvent), *(params->phCommandList), false); *instance_data = reinterpret_cast(query); @@ -3540,8 +3524,8 @@ class ZeCollector { static void OnEnterCommandListAppendMemoryCopyFromContext( ze_command_list_append_memory_copy_from_context_params_t* params, - ze_result_t result, void* global_data, void** instance_data) { - if ((result == ZE_RESULT_SUCCESS) && (UniController::IsCollectionEnabled())) { + void* global_data, void** instance_data) { + if (UniController::IsCollectionEnabled()) { ZeCollector* collector = reinterpret_cast(global_data); zet_metric_query_handle_t query = PrepareToAppendKernelCommand(collector, *(params->phSignalEvent), *(params->phCommandList), false); *instance_data = reinterpret_cast(query); @@ -3571,8 +3555,8 @@ class ZeCollector { static void OnEnterCommandListAppendImageCopy( ze_command_list_append_image_copy_params_t* params, - ze_result_t result, void* global_data, void** instance_data) { - if ((result == ZE_RESULT_SUCCESS) && (UniController::IsCollectionEnabled())) { + void* global_data, void** instance_data) { + if (UniController::IsCollectionEnabled()) { ZeCollector* collector = reinterpret_cast(global_data); zet_metric_query_handle_t query = PrepareToAppendKernelCommand(collector, *(params->phSignalEvent), *(params->phCommandList), false); *instance_data = reinterpret_cast(query); @@ -3600,8 +3584,8 @@ class ZeCollector { static void OnEnterCommandListAppendImageCopyRegion( ze_command_list_append_image_copy_region_params_t* params, - ze_result_t result, void* global_data, void** instance_data) { - if ((result == ZE_RESULT_SUCCESS) && (UniController::IsCollectionEnabled())) { + void* global_data, void** instance_data) { + if (UniController::IsCollectionEnabled()) { ZeCollector* collector = reinterpret_cast(global_data); zet_metric_query_handle_t query = PrepareToAppendKernelCommand(collector, *(params->phSignalEvent), *(params->phCommandList), false); *instance_data = reinterpret_cast(query); @@ -3629,8 +3613,8 @@ class ZeCollector { static void OnEnterCommandListAppendImageCopyToMemory( ze_command_list_append_image_copy_to_memory_params_t* params, - ze_result_t result, void* global_data, void** instance_data) { - if ((result == ZE_RESULT_SUCCESS) && (UniController::IsCollectionEnabled())) { + void* global_data, void** instance_data) { + if (UniController::IsCollectionEnabled()) { ZeCollector* collector = reinterpret_cast(global_data); zet_metric_query_handle_t query = PrepareToAppendKernelCommand(collector, *(params->phSignalEvent), *(params->phCommandList), false); *instance_data = reinterpret_cast(query); @@ -3658,8 +3642,8 @@ class ZeCollector { static void OnEnterCommandListAppendImageCopyFromMemory( ze_command_list_append_image_copy_from_memory_params_t* params, - ze_result_t result, void* global_data, void** instance_data) { - if ((result == ZE_RESULT_SUCCESS) && (UniController::IsCollectionEnabled())) { + void* global_data, void** instance_data) { + if (UniController::IsCollectionEnabled()) { ZeCollector* collector = reinterpret_cast(global_data); zet_metric_query_handle_t query = PrepareToAppendKernelCommand(collector, *(params->phSignalEvent), *(params->phCommandList), false); *instance_data = reinterpret_cast(query); @@ -3752,6 +3736,32 @@ class ZeCollector { } } + static void OnEnterCommandQueueExecuteCommandLists( + ze_command_queue_execute_command_lists_params_t* params, + void* global_data, void** instance_data) { + + ZeCollector* collector = reinterpret_cast(global_data); + + if (UniController::IsCollectionEnabled()) { + uint32_t count = *params->pnumCommandLists; + if (count == 0) { + return; + } + + ze_command_list_handle_t* cmdlists = *params->pphCommandLists; + if (cmdlists == nullptr) { + return; + } + + if (local_device_submissions_.IsFinalized()) { + return; + } + + ze_command_queue_handle_t queue = *(params->phCommandQueue); + collector->PrepareToExecuteCommandLists(cmdlists, count, queue, *(params->phFence)); + } + } + static void OnExitCommandQueueExecuteCommandLists( ze_command_queue_execute_command_lists_params_t* params, ze_result_t result, void* global_data, void** instance_data, std::vector *kids) { @@ -3760,20 +3770,12 @@ class ZeCollector { ZeCollector* collector = reinterpret_cast(global_data); if (UniController::IsCollectionEnabled()) { - uint32_t count = *params->pnumCommandLists; - if (count == 0) { - return; - } - - ze_command_list_handle_t* cmdlists = *params->pphCommandLists; - if (cmdlists == nullptr) { - return; - } - - uint64_t ts = ze_instance_data.start_time_host; - collector->ExecuteCommandLists(cmdlists, count, *(params->phCommandQueue), *(params->phFence), ts, kids); + local_device_submissions_.SubmitStagedKernelCommandAndMetricQueries(collector->event_cache_, kids); } } + else { + local_device_submissions_.RevertStagedKernelCommandAndMetricQueries(); + } } static void OnExitCommandQueueSynchronize( @@ -3839,7 +3841,7 @@ class ZeCollector { } } - static void OnEnterModuleDestroy(ze_module_destroy_params_t* params, ze_result_t result, void* global_data, void** instance_user_data) { + static void OnEnterModuleDestroy(ze_module_destroy_params_t* params, void* global_data, void** instance_user_data) { ZeCollector* collector = reinterpret_cast(global_data); ze_module_handle_t mod = *(params->phModule); modules_on_devices_mutex_.lock();