10
10
#include " command_buffer.hpp"
11
11
#include " helpers/kernel_helpers.hpp"
12
12
#include " logger/ur_logger.hpp"
13
+ #include " ur_api.h"
13
14
#include " ur_interface_loader.hpp"
14
15
#include " ur_level_zero.hpp"
15
16
@@ -170,6 +171,65 @@ ur_result_t getEventsFromSyncPoints(
170
171
return UR_RESULT_SUCCESS;
171
172
}
172
173
174
+ /* *
175
+ * If necessary, it creates a signal event and appends it to the previous
176
+ * command list (copy or compute), to indicate when it's finished executing.
177
+ * @param[in] CommandBuffer The CommandBuffer where the command is appended.
178
+ * @param[in] ZeCommandList the CommandList that's currently in use.
179
+ * @param[out] WaitEventList The list of event for the future command list to
180
+ * wait on before execution.
181
+ * @return UR_RESULT_SUCCESS or an error code on failure
182
+ */
183
+ ur_result_t createSyncPointBetweenCopyAndCompute (
184
+ ur_exp_command_buffer_handle_t CommandBuffer,
185
+ ze_command_list_handle_t ZeCommandList,
186
+ std::vector<ze_event_handle_t > &WaitEventList) {
187
+
188
+ if (!CommandBuffer->ZeCopyCommandList ) {
189
+ return UR_RESULT_SUCCESS;
190
+ }
191
+
192
+ bool IsCopy{ZeCommandList == CommandBuffer->ZeCopyCommandList };
193
+
194
+ // Skip synchronization for the first node in a graph or if the current
195
+ // command list matches the previous one.
196
+ if (!CommandBuffer->MWasPrevCopyCommandList .has_value ()) {
197
+ CommandBuffer->MWasPrevCopyCommandList = IsCopy;
198
+ return UR_RESULT_SUCCESS;
199
+ } else if (IsCopy == CommandBuffer->MWasPrevCopyCommandList ) {
200
+ return UR_RESULT_SUCCESS;
201
+ }
202
+
203
+ /*
204
+ * If the current CommandList differs from the previously used one, we must
205
+ * append a signal event to the previous CommandList to track when
206
+ * its execution is complete.
207
+ */
208
+ ur_event_handle_t SignalPrevCommandEvent = nullptr ;
209
+ UR_CALL (EventCreate (CommandBuffer->Context , nullptr /* Queue*/ ,
210
+ false /* IsMultiDevice*/ , false , &SignalPrevCommandEvent,
211
+ false /* CounterBasedEventEnabled*/ ,
212
+ !CommandBuffer->IsProfilingEnabled ,
213
+ false /* InterruptBasedEventEnabled*/ ));
214
+
215
+ // Determine which command list to signal.
216
+ auto CommandListToSignal = (!IsCopy && CommandBuffer->MWasPrevCopyCommandList )
217
+ ? CommandBuffer->ZeCopyCommandList
218
+ : CommandBuffer->ZeComputeCommandList ;
219
+ CommandBuffer->MWasPrevCopyCommandList = IsCopy;
220
+
221
+ ZE2UR_CALL (zeCommandListAppendSignalEvent,
222
+ (CommandListToSignal, SignalPrevCommandEvent->ZeEvent ));
223
+
224
+ // Add the event to the dependencies for future command list to wait on.
225
+ WaitEventList.push_back (SignalPrevCommandEvent->ZeEvent );
226
+
227
+ // Mark the event for future reset.
228
+ CommandBuffer->ZeEventsList .push_back (SignalPrevCommandEvent->ZeEvent );
229
+
230
+ return UR_RESULT_SUCCESS;
231
+ }
232
+
173
233
/* *
174
234
* If needed, creates a sync point for a given command and returns the L0
175
235
* events associated with the sync point.
@@ -190,7 +250,7 @@ ur_result_t getEventsFromSyncPoints(
190
250
*/
191
251
ur_result_t createSyncPointAndGetZeEvents (
192
252
ur_command_t CommandType, ur_exp_command_buffer_handle_t CommandBuffer,
193
- uint32_t NumSyncPointsInWaitList,
253
+ ze_command_list_handle_t ZeCommandList, uint32_t NumSyncPointsInWaitList,
194
254
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
195
255
bool HostVisible, ur_exp_command_buffer_sync_point_t *RetSyncPoint,
196
256
std::vector<ze_event_handle_t > &ZeEventList,
@@ -199,6 +259,11 @@ ur_result_t createSyncPointAndGetZeEvents(
199
259
ZeLaunchEvent = nullptr ;
200
260
201
261
if (CommandBuffer->IsInOrderCmdList ) {
262
+ UR_CALL (createSyncPointBetweenCopyAndCompute (CommandBuffer, ZeCommandList,
263
+ ZeEventList));
264
+ if (!ZeEventList.empty ()) {
265
+ NumSyncPointsInWaitList = ZeEventList.size ();
266
+ }
202
267
return UR_RESULT_SUCCESS;
203
268
}
204
269
@@ -225,24 +290,24 @@ ur_result_t createSyncPointAndGetZeEvents(
225
290
return UR_RESULT_SUCCESS;
226
291
}
227
292
228
- // Shared by all memory read/write/copy PI interfaces.
229
- // Helper function for common code when enqueuing memory operations to a command
230
- // buffer.
293
+ // Shared by all memory read/write/copy UR interfaces.
294
+ // Helper function for common code when enqueuing memory operations to a
295
+ // command buffer.
231
296
ur_result_t enqueueCommandBufferMemCopyHelper (
232
297
ur_command_t CommandType, ur_exp_command_buffer_handle_t CommandBuffer,
233
298
void *Dst, const void *Src, size_t Size , bool PreferCopyEngine,
234
299
uint32_t NumSyncPointsInWaitList,
235
300
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
236
301
ur_exp_command_buffer_sync_point_t *RetSyncPoint) {
237
302
303
+ ze_command_list_handle_t ZeCommandList =
304
+ CommandBuffer->chooseCommandList (PreferCopyEngine);
305
+
238
306
std::vector<ze_event_handle_t > ZeEventList;
239
307
ze_event_handle_t ZeLaunchEvent = nullptr ;
240
308
UR_CALL (createSyncPointAndGetZeEvents (
241
- CommandType, CommandBuffer, NumSyncPointsInWaitList, SyncPointWaitList,
242
- false , RetSyncPoint, ZeEventList, ZeLaunchEvent));
243
-
244
- ze_command_list_handle_t ZeCommandList =
245
- CommandBuffer->chooseCommandList (PreferCopyEngine);
309
+ CommandType, CommandBuffer, ZeCommandList, NumSyncPointsInWaitList,
310
+ SyncPointWaitList, false , RetSyncPoint, ZeEventList, ZeLaunchEvent));
246
311
247
312
ZE2UR_CALL (zeCommandListAppendMemoryCopy,
248
313
(ZeCommandList, Dst, Src, Size , ZeLaunchEvent, ZeEventList.size (),
@@ -293,14 +358,14 @@ ur_result_t enqueueCommandBufferMemCopyRectHelper(
293
358
const ze_copy_region_t ZeDstRegion = {DstOriginX, DstOriginY, DstOriginZ,
294
359
Width, Height, Depth};
295
360
361
+ ze_command_list_handle_t ZeCommandList =
362
+ CommandBuffer->chooseCommandList (PreferCopyEngine);
363
+
296
364
std::vector<ze_event_handle_t > ZeEventList;
297
365
ze_event_handle_t ZeLaunchEvent = nullptr ;
298
366
UR_CALL (createSyncPointAndGetZeEvents (
299
- CommandType, CommandBuffer, NumSyncPointsInWaitList, SyncPointWaitList,
300
- false , RetSyncPoint, ZeEventList, ZeLaunchEvent));
301
-
302
- ze_command_list_handle_t ZeCommandList =
303
- CommandBuffer->chooseCommandList (PreferCopyEngine);
367
+ CommandType, CommandBuffer, ZeCommandList, NumSyncPointsInWaitList,
368
+ SyncPointWaitList, false , RetSyncPoint, ZeEventList, ZeLaunchEvent));
304
369
305
370
ZE2UR_CALL (zeCommandListAppendMemoryCopyRegion,
306
371
(ZeCommandList, Dst, &ZeDstRegion, DstPitch, DstSlicePitch, Src,
@@ -321,19 +386,19 @@ ur_result_t enqueueCommandBufferFillHelper(
321
386
UR_ASSERT ((PatternSize > 0 ) && ((PatternSize & (PatternSize - 1 )) == 0 ),
322
387
UR_RESULT_ERROR_INVALID_VALUE);
323
388
324
- std::vector<ze_event_handle_t > ZeEventList;
325
- ze_event_handle_t ZeLaunchEvent = nullptr ;
326
- UR_CALL (createSyncPointAndGetZeEvents (
327
- CommandType, CommandBuffer, NumSyncPointsInWaitList, SyncPointWaitList,
328
- true , RetSyncPoint, ZeEventList, ZeLaunchEvent));
329
-
330
389
bool PreferCopyEngine;
331
390
UR_CALL (
332
391
preferCopyEngineForFill (CommandBuffer, PatternSize, PreferCopyEngine));
333
392
334
393
ze_command_list_handle_t ZeCommandList =
335
394
CommandBuffer->chooseCommandList (PreferCopyEngine);
336
395
396
+ std::vector<ze_event_handle_t > ZeEventList;
397
+ ze_event_handle_t ZeLaunchEvent = nullptr ;
398
+ UR_CALL (createSyncPointAndGetZeEvents (
399
+ CommandType, CommandBuffer, ZeCommandList, NumSyncPointsInWaitList,
400
+ SyncPointWaitList, true , RetSyncPoint, ZeEventList, ZeLaunchEvent));
401
+
337
402
ZE2UR_CALL (zeCommandListAppendMemoryFill,
338
403
(ZeCommandList, Ptr , Pattern, PatternSize, Size , ZeLaunchEvent,
339
404
ZeEventList.size (), getPointerFromVector (ZeEventList)));
@@ -477,12 +542,12 @@ void ur_exp_command_buffer_handle_t_::registerSyncPoint(
477
542
478
543
// Returns the command list a command should be recorded on: the copy command
// list when the caller prefers the copy engine, useCopyEngine() reports it is
// usable and the command buffer is not in-order; the compute command list in
// every other case.
// NOTE(review): because of the !IsInOrderCmdList term, in-order command
// buffers always receive the compute list from this function - confirm this
// is intended alongside the copy/compute synchronization for in-order
// buffers.
ze_command_list_handle_t
ur_exp_command_buffer_handle_t_::chooseCommandList(bool PreferCopyEngine) {
  const bool RecordOnCopyList =
      PreferCopyEngine && useCopyEngine() && !IsInOrderCmdList;
  if (!RecordOnCopyList) {
    return ZeComputeCommandList;
  }

  // Flag that ZeCopyCommandList now holds commands that must be submitted.
  MCopyCommandListEmpty = false;
  return ZeCopyCommandList;
}
487
552
488
553
ur_result_t ur_exp_command_buffer_handle_t_::getFenceForQueue (
@@ -646,7 +711,7 @@ urCommandBufferCreateExp(ur_context_handle_t Context, ur_device_handle_t Device,
646
711
// the current implementation only uses the main copy engine and does not use
647
712
// the link engine even if available.
648
713
if (Device->hasMainCopyEngine ()) {
649
- UR_CALL (createMainCommandList (Context, Device, false , false , true ,
714
+ UR_CALL (createMainCommandList (Context, Device, IsInOrder , false , true ,
650
715
ZeCopyCommandList));
651
716
}
652
717
@@ -812,18 +877,25 @@ finalizeWaitEventPath(ur_exp_command_buffer_handle_t CommandBuffer) {
812
877
(CommandBuffer->ZeCommandListResetEvents ,
813
878
CommandBuffer->ExecutionFinishedEvent ->ZeEvent ));
814
879
815
- if (CommandBuffer->IsInOrderCmdList ) {
816
- ZE2UR_CALL (zeCommandListAppendSignalEvent,
817
- (CommandBuffer->ZeComputeCommandList ,
818
- CommandBuffer->ExecutionFinishedEvent ->ZeEvent ));
819
- } else {
820
- // Reset the L0 events we use for command-buffer sync-points to the
821
- // non-signaled state. This is required for multiple submissions.
880
+ // Reset the L0 events we use for command-buffer sync-points to the
881
+ // non-signaled state. This is required for multiple submissions.
882
+ auto resetEvents = [&CommandBuffer]() -> ur_result_t {
822
883
for (auto &Event : CommandBuffer->ZeEventsList ) {
823
884
ZE2UR_CALL (zeCommandListAppendEventReset,
824
885
(CommandBuffer->ZeCommandListResetEvents , Event));
825
886
}
887
+ return UR_RESULT_SUCCESS;
888
+ };
826
889
890
+ if (CommandBuffer->IsInOrderCmdList ) {
891
+ if (!CommandBuffer->MCopyCommandListEmpty ) {
892
+ resetEvents ();
893
+ }
894
+ ZE2UR_CALL (zeCommandListAppendSignalEvent,
895
+ (CommandBuffer->ZeComputeCommandList ,
896
+ CommandBuffer->ExecutionFinishedEvent ->ZeEvent ));
897
+ } else {
898
+ resetEvents ();
827
899
// Wait for all the user added commands to complete, and signal the
828
900
// command-buffer signal-event when they are done.
829
901
ZE2UR_CALL (zeCommandListAppendBarrier,
@@ -1073,7 +1145,8 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
1073
1145
std::vector<ze_event_handle_t > ZeEventList;
1074
1146
ze_event_handle_t ZeLaunchEvent = nullptr ;
1075
1147
UR_CALL (createSyncPointAndGetZeEvents (
1076
- UR_COMMAND_KERNEL_LAUNCH, CommandBuffer, NumSyncPointsInWaitList,
1148
+ UR_COMMAND_KERNEL_LAUNCH, CommandBuffer,
1149
+ CommandBuffer->ZeComputeCommandList , NumSyncPointsInWaitList,
1077
1150
SyncPointWaitList, false , RetSyncPoint, ZeEventList, ZeLaunchEvent));
1078
1151
1079
1152
ZE2UR_CALL (zeCommandListAppendLaunchKernel,
@@ -1306,29 +1379,25 @@ ur_result_t urCommandBufferAppendUSMPrefetchExp(
1306
1379
std::ignore = Command;
1307
1380
std::ignore = Flags;
1308
1381
1309
- if (CommandBuffer->IsInOrderCmdList ) {
1310
- // Add the prefetch command to the command-buffer.
1311
- // Note that L0 does not handle migration flags.
1312
- ZE2UR_CALL (zeCommandListAppendMemoryPrefetch,
1313
- (CommandBuffer->ZeComputeCommandList , Mem, Size ));
1314
- } else {
1315
- std::vector<ze_event_handle_t > ZeEventList;
1316
- ze_event_handle_t ZeLaunchEvent = nullptr ;
1317
- UR_CALL (createSyncPointAndGetZeEvents (
1318
- UR_COMMAND_USM_PREFETCH, CommandBuffer, NumSyncPointsInWaitList,
1319
- SyncPointWaitList, true , RetSyncPoint, ZeEventList, ZeLaunchEvent));
1320
-
1321
- if (NumSyncPointsInWaitList) {
1322
- ZE2UR_CALL (zeCommandListAppendWaitOnEvents,
1323
- (CommandBuffer->ZeComputeCommandList , NumSyncPointsInWaitList,
1324
- ZeEventList.data ()));
1325
- }
1382
+ std::vector<ze_event_handle_t > ZeEventList;
1383
+ ze_event_handle_t ZeLaunchEvent = nullptr ;
1384
+ UR_CALL (createSyncPointAndGetZeEvents (
1385
+ UR_COMMAND_USM_PREFETCH, CommandBuffer,
1386
+ CommandBuffer->ZeComputeCommandList , NumSyncPointsInWaitList,
1387
+ SyncPointWaitList, true , RetSyncPoint, ZeEventList, ZeLaunchEvent));
1388
+
1389
+ if (NumSyncPointsInWaitList) {
1390
+ ZE2UR_CALL (zeCommandListAppendWaitOnEvents,
1391
+ (CommandBuffer->ZeComputeCommandList , NumSyncPointsInWaitList,
1392
+ ZeEventList.data ()));
1393
+ }
1326
1394
1327
- // Add the prefetch command to the command-buffer.
1328
- // Note that L0 does not handle migration flags.
1329
- ZE2UR_CALL (zeCommandListAppendMemoryPrefetch,
1330
- (CommandBuffer->ZeComputeCommandList , Mem, Size ));
1395
+ // Add the prefetch command to the command-buffer.
1396
+ // Note that L0 does not handle migration flags.
1397
+ ZE2UR_CALL (zeCommandListAppendMemoryPrefetch,
1398
+ (CommandBuffer->ZeComputeCommandList , Mem, Size ));
1331
1399
1400
+ if (!CommandBuffer->IsInOrderCmdList ) {
1332
1401
// Level Zero does not have a completion "event" with the prefetch API,
1333
1402
// so manually add command to signal our event.
1334
1403
ZE2UR_CALL (zeCommandListAppendSignalEvent,
@@ -1376,27 +1445,24 @@ ur_result_t urCommandBufferAppendUSMAdviseExp(
1376
1445
1377
1446
ze_memory_advice_t ZeAdvice = static_cast <ze_memory_advice_t >(Value);
1378
1447
1379
- if (CommandBuffer->IsInOrderCmdList ) {
1380
- ZE2UR_CALL (zeCommandListAppendMemAdvise,
1381
- (CommandBuffer->ZeComputeCommandList ,
1382
- CommandBuffer->Device ->ZeDevice , Mem, Size , ZeAdvice));
1383
- } else {
1384
- std::vector<ze_event_handle_t > ZeEventList;
1385
- ze_event_handle_t ZeLaunchEvent = nullptr ;
1386
- UR_CALL (createSyncPointAndGetZeEvents (
1387
- UR_COMMAND_USM_ADVISE, CommandBuffer, NumSyncPointsInWaitList,
1388
- SyncPointWaitList, true , RetSyncPoint, ZeEventList, ZeLaunchEvent));
1389
-
1390
- if (NumSyncPointsInWaitList) {
1391
- ZE2UR_CALL (zeCommandListAppendWaitOnEvents,
1392
- (CommandBuffer->ZeComputeCommandList , NumSyncPointsInWaitList,
1393
- ZeEventList.data ()));
1394
- }
1448
+ std::vector<ze_event_handle_t > ZeEventList;
1449
+ ze_event_handle_t ZeLaunchEvent = nullptr ;
1450
+ UR_CALL (createSyncPointAndGetZeEvents (
1451
+ UR_COMMAND_USM_ADVISE, CommandBuffer, CommandBuffer->ZeComputeCommandList ,
1452
+ NumSyncPointsInWaitList, SyncPointWaitList, true , RetSyncPoint,
1453
+ ZeEventList, ZeLaunchEvent));
1395
1454
1396
- ZE2UR_CALL (zeCommandListAppendMemAdvise,
1397
- (CommandBuffer->ZeComputeCommandList ,
1398
- CommandBuffer->Device ->ZeDevice , Mem, Size , ZeAdvice));
1455
+ if (NumSyncPointsInWaitList) {
1456
+ ZE2UR_CALL (zeCommandListAppendWaitOnEvents,
1457
+ (CommandBuffer->ZeComputeCommandList , NumSyncPointsInWaitList,
1458
+ ZeEventList.data ()));
1459
+ }
1460
+
1461
+ ZE2UR_CALL (zeCommandListAppendMemAdvise,
1462
+ (CommandBuffer->ZeComputeCommandList ,
1463
+ CommandBuffer->Device ->ZeDevice , Mem, Size , ZeAdvice));
1399
1464
1465
+ if (!CommandBuffer->IsInOrderCmdList ) {
1400
1466
// Level Zero does not have a completion "event" with the advise API,
1401
1467
// so manually add command to signal our event.
1402
1468
ZE2UR_CALL (zeCommandListAppendSignalEvent,
0 commit comments