Skip to content

Commit

Permalink
vsort/vs_onnxruntime.cpp: replay the first dml execution
Browse files Browse the repository at this point in the history
  • Loading branch information
WolframRhodium committed Nov 30, 2024
1 parent a2b1a88 commit ac25053
Showing 1 changed file with 26 additions and 4 deletions.
30 changes: 26 additions & 4 deletions vsort/vs_onnxruntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ using namespace std::string_literals;
// Handle to this plugin.
// NOTE(review): presumably assigned during plugin initialization — confirm against the init code.
static const VSPlugin * myself = nullptr;

// ONNX Runtime C API table; inference calls such as RunWithBinding() and Run()
// in vsOrtGetFrame are dispatched through this pointer.
static const OrtApi * ortapi = nullptr;

// Atomic counter for logger ids.
// NOTE(review): looks like it provides a distinct id per created logger/environment — confirm at the use site.
static std::atomic<int64_t> logger_id = 0;

// Only the CUDA and DML backends have a "replay" first execution that needs serializing.
#if defined(ENABLE_CUDA) || defined(ENABLE_DML)
// Held (via std::lock_guard) in vsOrtGetFrame around the execution performed while
// resource.require_replay is set, so that those runs do not overlap with each other.
// This only guards captures issued through this plugin; it cannot protect against
// another plugin that also uses global-mode stream capture.
static std::mutex capture_lock;
#endif


// rename GridSample to com.microsoft::GridSample
Expand Down Expand Up @@ -446,8 +449,11 @@ struct Resource {
cudaStream_t stream;
CUDA_Resource_t input;
CUDA_Resource_t output;
bool require_replay;
#endif // ENABLE_CUDA

#if defined(ENABLE_CUDA) || defined(ENABLE_DML)
bool require_replay;
#endif
};

struct vsOrtData {
Expand Down Expand Up @@ -699,7 +705,7 @@ static const VSFrameRef *VS_CC vsOrtGetFrame(
}
#endif // ENABLE_CUDA

#ifdef ENABLE_CUDA
#if defined(ENABLE_CUDA) || defined(ENABLE_DML)
if (resource.require_replay) [[unlikely]] {
resource.require_replay = false;

Expand All @@ -710,11 +716,26 @@ static const VSFrameRef *VS_CC vsOrtGetFrame(
// note that this applies only to stream capture from the ort library
// this fails when another plugin also uses global-mode stream capture
std::lock_guard _ { capture_lock };
checkError(ortapi->RunWithBinding(resource.session, run_options, resource.binding));
if (d->backend == Backend::CUDA) {
checkError(ortapi->RunWithBinding(resource.session, run_options, resource.binding));
} else if (d->backend == Backend::DML) {
for (int i = 0; i < 2; i++) {
checkError(ortapi->Run(
resource.session,
run_options,
&resource.input_name,
&resource.input_tensor,
1,
&resource.output_name,
1,
&resource.output_tensor
));
}
}

// onnxruntime replays the graph itself in CUDAExecutionProvider::OnRunEnd
} else
#endif // ENABLE_CUDA
#endif // defined(ENABLE_CUDA) || defined(ENABLE_DML)
if (d->backend == Backend::CPU || d->backend == Backend::CUDA) {
checkError(ortapi->RunWithBinding(resource.session, run_options, resource.binding));
} else {
Expand Down Expand Up @@ -1263,6 +1284,7 @@ static void VS_CC vsOrtCreate(
const OrtDmlApi * ortdmlapi {};
checkError(ortapi->GetExecutionProviderApi("DML", ORT_API_VERSION, (const void **) &ortdmlapi));
checkError(ortdmlapi->SessionOptionsAppendExecutionProvider_DML(session_options, d->device_id));
resource.require_replay = true;
}
#endif // ENABLE_DML

Expand Down

0 comments on commit ac25053

Please sign in to comment.