From 3ecf05d66e2e04435fde3c8200e5208ce2707eb7 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Sun, 1 Sep 2024 23:05:45 +0900 Subject: [PATCH] change: liberate VOICEVOX CORE (#8) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * change: liberate VOICEVOX CORE * `session.use_vv_bin` * VOICEVOX/voicevox_core#722 用エラー --- src/environment.rs | 4 ++-- src/error.rs | 4 +++- src/lib.rs | 25 +++++++++++++++++++++++-- src/session/builder.rs | 9 +++++++++ 4 files changed, 37 insertions(+), 5 deletions(-) diff --git a/src/environment.rs b/src/environment.rs index eb56862a..83c2e8b1 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -293,8 +293,8 @@ extern_system_fn! { ); match severity { - // TODO: https://github.com/VOICEVOX/voicevox_project/issues/24 をやる際に、libonnxruntime側で`WARNING`未満のログを遮断する - ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE | ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO => {} + ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE => tracing::event!(parent: &span, Level::DEBUG, "{message}"), + ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO => tracing::event!(parent: &span, Level::INFO, "{message}"), ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING => tracing::event!(parent: &span, Level::WARN, "{message}"), ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR | ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL => { tracing::event!(parent: &span, Level::ERROR, "{message}"); diff --git a/src/error.rs b/src/error.rs index 2e84580a..2017c57e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -263,7 +263,9 @@ pub enum Error { #[error("Could't get device ID from memory info: {0}")] GetDeviceId(ErrorInternal), #[error("Training API is not enabled in this build of ONNX Runtime.")] - TrainingNotEnabled + TrainingNotEnabled, + #[error("This ONNX Runtime does not support \"vv-bin\" format (note: load/link `voicevox_onnxruntime` instead of `onnxruntime`)")] + VvBinNotSupported } impl Error { diff --git a/src/lib.rs b/src/lib.rs index 8e83d6d7..f5e6ed6c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -185,6 +185,7 @@ thread_local! { #[cfg_attr(docsrs, doc(cfg(feature = "__init-for-voicevox")))] #[derive(Debug)] pub struct EnvHandle { + is_voicevox_onnxruntime: bool, _env: std::sync::Arc, api: AssertSendSync>, #[cfg(feature = "load-dynamic")] @@ -265,7 +266,14 @@ pub fn try_init_from(filename: &std::ffi::OsStr, tp_options: Option) -> anyho let _env = create_env(api.0, tp_options)?; - Ok(EnvHandle { _env, api }) + let is_voicevox_onnxruntime = is_voicevox_onnxruntime(api.0); + + Ok(EnvHandle { is_voicevox_onnxruntime, _env, api }) }) } @@ -317,6 +327,17 @@ fn create_env(api: NonNull, tp_options: Option) -> bool { + unsafe { + let build_info = api.as_ref().GetBuildInfoString.expect("`GetBuildInfoString` must be present")(); + CStr::from_ptr(build_info) + .to_str() + .expect("should be UTF-8") + .starts_with("VOICEVOX ORT Build Info: ") + } +} + pub(crate) static G_ORT_API: OnceLock> = OnceLock::new(); /// Returns a pointer to the global [`ort_sys::OrtApi`] object. diff --git a/src/session/builder.rs b/src/session/builder.rs index 458c6ade..f2b2576a 100644 --- a/src/session/builder.rs +++ b/src/session/builder.rs @@ -445,6 +445,15 @@ impl SessionBuilder { }; Ok(session) } + + #[cfg(feature = "__init-for-voicevox")] + pub fn commit_from_vv_bin(self, bin: &[u8]) -> Result { + if !crate::EnvHandle::get().expect("should be present").is_voicevox_onnxruntime { + return Err(Error::VvBinNotSupported); + } + ortsys![unsafe AddSessionConfigEntry(self.session_options_ptr.as_ptr(), c"session.use_vv_bin".as_ptr(), c"1".as_ptr())]; + self.commit_from_memory(bin) + } } /// ONNX Runtime provides various graph optimizations to improve performance. Graph optimizations are essentially