Skip to content

Commit

Permalink
onnx: fix provider detection
Browse files Browse the repository at this point in the history
  • Loading branch information
koush committed Nov 17, 2024
1 parent 2ec6760 commit ea4922d
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 4 deletions.
4 changes: 2 additions & 2 deletions plugins/onnx/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion plugins/onnx/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,5 @@
"devDependencies": {
"@scrypted/sdk": "file:../../sdk"
},
"version": "0.1.116"
"version": "0.1.117"
}
9 changes: 8 additions & 1 deletion plugins/onnx/src/ort/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,14 @@ def executor_initializer():
thread_name = threading.current_thread().name
interpreter = compiled_models.pop()
self.compiled_models[thread_name] = interpreter
self.provider = "CUDAExecutionProvider" if "CUDAExecutionProvider" in interpreter.get_providers() else "CPUExecutionProvider"
# remove CPUExecutionProider from providers
providers = interpreter.get_providers()
if not len(providers):
providers = ["CPUExecutionProvider"]
if "CPUExecutionProvider" in providers:
providers.remove("CPUExecutionProvider")
# join the remaining providers string
self.provider = ", ".join(providers)
print('Runtime initialized on thread {}'.format(thread_name))

self.executor = concurrent.futures.ThreadPoolExecutor(
Expand Down

0 comments on commit ea4922d

Please sign in to comment.