diff --git a/vsort/vs_onnxruntime.cpp b/vsort/vs_onnxruntime.cpp index d3cd173..db044de 100644 --- a/vsort/vs_onnxruntime.cpp +++ b/vsort/vs_onnxruntime.cpp @@ -68,7 +68,10 @@ using namespace std::string_literals; static const VSPlugin * myself = nullptr; static const OrtApi * ortapi = nullptr; static std::atomic logger_id = 0; + +#if defined(ENABLE_CUDA) || defined(ENABLE_DML) static std::mutex capture_lock; +#endif // rename GridSample to com.microsoft::GridSample @@ -446,8 +449,11 @@ struct Resource { cudaStream_t stream; CUDA_Resource_t input; CUDA_Resource_t output; - bool require_replay; #endif // ENABLE_CUDA + +#if defined(ENABLE_CUDA) || defined(ENABLE_DML) + bool require_replay; +#endif }; struct vsOrtData { @@ -699,7 +705,7 @@ static const VSFrameRef *VS_CC vsOrtGetFrame( } #endif // ENABLE_CUDA -#ifdef ENABLE_CUDA +#if defined(ENABLE_CUDA) || defined(ENABLE_DML) if (resource.require_replay) [[unlikely]] { resource.require_replay = false; @@ -714,7 +720,7 @@ static const VSFrameRef *VS_CC vsOrtGetFrame( // onnxruntime replays the graph itself in CUDAExecutionProvider::OnRunEnd } else -#endif // ENABLE_CUDA +#endif // defined(ENABLE_CUDA) || defined(ENABLE_DML) if (d->backend == Backend::CPU || d->backend == Backend::CUDA) { checkError(ortapi->RunWithBinding(resource.session, run_options, resource.binding)); } else { @@ -1263,6 +1269,7 @@ static void VS_CC vsOrtCreate( const OrtDmlApi * ortdmlapi {}; checkError(ortapi->GetExecutionProviderApi("DML", ORT_API_VERSION, (const void **) &ortdmlapi)); checkError(ortdmlapi->SessionOptionsAppendExecutionProvider_DML(session_options, d->device_id)); + resource.require_replay = true; } #endif // ENABLE_DML