-
Notifications
You must be signed in to change notification settings - Fork 448
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
922cad6
commit 051b0b3
Showing
3 changed files
with
59 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,12 +9,12 @@ | |
// granted to it by virtue of its status as an Intergovernmental Organization | ||
// or submit itself to any jurisdiction. | ||
|
||
/// \file ort_interface.h | ||
/// \file OrtInterface.h | ||
/// \author Christian Sonnabend <[email protected]> | ||
/// \brief A header library for loading ONNX models and running inference with them on CPU and GPU | ||
|
||
#ifndef O2_ML_ONNX_INTERFACE_H | ||
#define O2_ML_ONNX_INTERFACE_H | ||
#ifndef O2_ML_ORTINTERFACE_H | ||
#define O2_ML_ORTINTERFACE_H | ||
|
||
// C++ and system includes | ||
#include <vector> | ||
|
@@ -89,4 +89,4 @@ class OrtModel | |
|
||
} // namespace o2 | ||
|
||
#endif // O2_ML_ORT_INTERFACE_H | ||
#endif // O2_ML_ORTINTERFACE_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,11 +9,11 @@ | |
// granted to it by virtue of its status as an Intergovernmental Organization | ||
// or submit itself to any jurisdiction. | ||
|
||
/// \file ort_interface.cxx | ||
/// \file OrtInterface.cxx | ||
/// \author Christian Sonnabend <[email protected]> | ||
/// \brief Implementation of the library for loading ONNX models and running inference with them on CPU and GPU | ||
|
||
#include "ML/ort_interface.h" | ||
#include "ML/OrtInterface.h" | ||
#include "ML/3rdparty/GPUORTFloat16.h" | ||
|
||
// ONNX includes | ||
|
@@ -50,29 +50,35 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap) | |
deviceId = (optionsMap.contains("device-id") ? std::stoi(optionsMap["device-id"]) : 0); | ||
allocateDeviceMemory = (optionsMap.contains("allocate-device-memory") ? std::stoi(optionsMap["allocate-device-memory"]) : 0); | ||
intraOpNumThreads = (optionsMap.contains("intra-op-num-threads") ? std::stoi(optionsMap["intra-op-num-threads"]) : 0); | ||
loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 0); | ||
loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 2); | ||
enableProfiling = (optionsMap.contains("enable-profiling") ? std::stoi(optionsMap["enable-profiling"]) : 0); | ||
enableOptimizations = (optionsMap.contains("enable-optimizations") ? std::stoi(optionsMap["enable-optimizations"]) : 0); | ||
|
||
std::string dev_mem_str = "Hip"; | ||
#ifdef ORT_ROCM_BUILD | ||
#if defined(ORT_ROCM_BUILD) | ||
#if ORT_ROCM_BUILD == 1 | ||
if (device == "ROCM") { | ||
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(pImplOrt->sessionOptions, deviceId)); | ||
LOG(info) << "(ORT) ROCM execution provider set"; | ||
} | ||
#endif | ||
#ifdef ORT_MIGRAPHX_BUILD | ||
#endif | ||
#if defined(ORT_MIGRAPHX_BUILD) | ||
#if ORT_MIGRAPHX_BUILD == 1 | ||
if (device == "MIGRAPHX") { | ||
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_MIGraphX(pImplOrt->sessionOptions, deviceId)); | ||
LOG(info) << "(ORT) MIGraphX execution provider set"; | ||
} | ||
#endif | ||
#ifdef ORT_CUDA_BUILD | ||
#endif | ||
#if defined(ORT_CUDA_BUILD) | ||
#if ORT_CUDA_BUILD == 1 | ||
if (device == "CUDA") { | ||
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(pImplOrt->sessionOptions, deviceId)); | ||
LOG(info) << "(ORT) CUDA execution provider set"; | ||
dev_mem_str = "Cuda"; | ||
} | ||
#endif | ||
#endif | ||
|
||
if (allocateDeviceMemory) { | ||
|
@@ -106,7 +112,27 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap) | |
(pImplOrt->sessionOptions).SetGraphOptimizationLevel(GraphOptimizationLevel(enableOptimizations)); | ||
(pImplOrt->sessionOptions).SetLogSeverityLevel(OrtLoggingLevel(loggingLevel)); | ||
|
||
pImplOrt->env = std::make_shared<Ort::Env>(OrtLoggingLevel(loggingLevel), (optionsMap["onnx-environment-name"].empty() ? "onnx_model_inference" : optionsMap["onnx-environment-name"].c_str())); | ||
pImplOrt->env = std::make_shared<Ort::Env>( | ||
OrtLoggingLevel(loggingLevel), | ||
(optionsMap["onnx-environment-name"].empty() ? "onnx_model_inference" : optionsMap["onnx-environment-name"].c_str()), | ||
// Integrate ORT logging into Fairlogger | ||
[](void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, const char* message) { | ||
if (severity == ORT_LOGGING_LEVEL_VERBOSE) { | ||
LOG(debug) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message; | ||
} else if (severity == ORT_LOGGING_LEVEL_INFO) { | ||
LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message; | ||
} else if (severity == ORT_LOGGING_LEVEL_WARNING) { | ||
LOG(warning) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message; | ||
} else if (severity == ORT_LOGGING_LEVEL_ERROR) { | ||
LOG(error) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message; | ||
} else if (severity == ORT_LOGGING_LEVEL_FATAL) { | ||
LOG(fatal) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message; | ||
} else { | ||
LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message; | ||
} | ||
}, | ||
(void*)3); | ||
(pImplOrt->env)->DisableTelemetryEvents(); // Disable telemetry events | ||
pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env), modelPath.c_str(), pImplOrt->sessionOptions); | ||
|
||
for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) { | ||
|
@@ -130,16 +156,14 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap) | |
[&](const std::string& str) { return str.c_str(); }); | ||
|
||
// Print names | ||
if (loggingLevel > 1) { | ||
LOG(info) << "Input Nodes:"; | ||
for (size_t i = 0; i < mInputNames.size(); i++) { | ||
LOG(info) << "\t" << mInputNames[i] << " : " << printShape(mInputShapes[i]); | ||
} | ||
LOG(info) << "\tInput Nodes:"; | ||
for (size_t i = 0; i < mInputNames.size(); i++) { | ||
LOG(info) << "\t\t" << mInputNames[i] << " : " << printShape(mInputShapes[i]); | ||
} | ||
|
||
LOG(info) << "Output Nodes:"; | ||
for (size_t i = 0; i < mOutputNames.size(); i++) { | ||
LOG(info) << "\t" << mOutputNames[i] << " : " << printShape(mOutputShapes[i]); | ||
} | ||
LOG(info) << "\tOutput Nodes:"; | ||
for (size_t i = 0; i < mOutputNames.size(); i++) { | ||
LOG(info) << "\t\t" << mOutputNames[i] << " : " << printShape(mOutputShapes[i]); | ||
} | ||
} | ||
|
||
|