From 5ff2b5948addd5045dfa100ee0c17077993764b7 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Mon, 6 Nov 2023 03:28:44 +0900 Subject: [PATCH] =?UTF-8?q?ONNX=20Runtime=E3=81=A8=E3=83=A2=E3=83=87?= =?UTF-8?q?=E3=83=AB=E3=81=AE=E3=82=B7=E3=82=B0=E3=83=8D=E3=83=81=E3=83=A3?= =?UTF-8?q?=E3=82=92=E9=9A=94=E9=9B=A2=E3=81=99=E3=82=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.lock | 1 + Cargo.toml | 1 + crates/voicevox_core/Cargo.toml | 1 + crates/voicevox_core/src/infer.rs | 92 +++++++ crates/voicevox_core/src/infer/runtimes.rs | 3 + .../src/infer/runtimes/onnxruntime.rs | 136 +++++++++ crates/voicevox_core/src/infer/signatures.rs | 87 ++++++ crates/voicevox_core/src/inference_core.rs | 83 +++--- crates/voicevox_core/src/lib.rs | 1 + crates/voicevox_core/src/status.rs | 259 +++--------------- crates/voicevox_core_c_api/Cargo.toml | 2 +- 11 files changed, 397 insertions(+), 269 deletions(-) create mode 100644 crates/voicevox_core/src/infer.rs create mode 100644 crates/voicevox_core/src/infer/runtimes.rs create mode 100644 crates/voicevox_core/src/infer/runtimes/onnxruntime.rs create mode 100644 crates/voicevox_core/src/infer/signatures.rs diff --git a/Cargo.lock b/Cargo.lock index 50868f63b..08e25f93e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4280,6 +4280,7 @@ dependencies = [ "indexmap 2.0.0", "itertools", "nanoid", + "ndarray", "once_cell", "onnxruntime", "open_jtalk", diff --git a/Cargo.toml b/Cargo.toml index b6237098a..bb98404f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ easy-ext = "1.0.1" fs-err = { version = "2.9.0", features = ["tokio"] } futures = "0.3.26" itertools = "0.10.5" +ndarray = "0.15.6" once_cell = "1.18.0" regex = "1.10.0" rstest = "0.15.0" diff --git a/crates/voicevox_core/Cargo.toml b/crates/voicevox_core/Cargo.toml index 3a23b794a..bee2f822c 100644 --- a/crates/voicevox_core/Cargo.toml +++ b/crates/voicevox_core/Cargo.toml @@ -22,6 +22,7 @@ futures.workspace = true indexmap = { version = "2.0.0", features = ["serde"] } itertools.workspace = true nanoid = "0.4.0" +ndarray.workspace = true once_cell.workspace = true regex.workspace = true serde.workspace = true diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs new file mode 100644 index 000000000..efc66b1e8 --- /dev/null +++ b/crates/voicevox_core/src/infer.rs @@ -0,0 +1,92 @@ +pub(crate) mod runtimes; +pub(crate) mod signatures; + +use std::{fmt::Debug, marker::PhantomData, sync::Arc}; + +use derive_new::new; +use ndarray::{Array, Dimension, LinalgScalar}; +use thiserror::Error; + +pub(crate) trait InferenceRuntime: Copy { + type Session: Session; + type RunBuilder<'a>: RunBuilder<'a, Runtime = Self>; +} + +pub(crate) trait Session: Sized + 'static { + fn new( + model: impl FnOnce() -> std::result::Result, DecryptModelError>, + options: SessionOptions, + ) -> anyhow::Result; +} + +pub(crate) trait RunBuilder<'a>: + From<&'a mut ::Session> +{ + type Runtime: InferenceRuntime; + fn input(&mut self, tensor: Array) -> &mut Self; +} + +pub(crate) trait InputScalar: LinalgScalar + Debug + sealed::OnnxruntimeInputScalar {} + +impl InputScalar for i64 {} +impl InputScalar for f32 {} + +pub(crate) trait Signature: Sized + Send + Sync + 'static { + type SessionSet; + type Output; + fn get_session( + session_set: &Self::SessionSet, + ) -> &Arc>>; + fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>); +} + +pub(crate) trait Output: Sized + Send { + fn run(ctx: R::RunBuilder<'_>) -> anyhow::Result; +} + +pub(crate) struct TypedSession { + inner: R::Session, + marker: PhantomData, +} + +impl TypedSession { + pub(crate) fn new( + model: impl FnOnce() -> std::result::Result, DecryptModelError>, + options: SessionOptions, + ) -> anyhow::Result { + let inner = R::Session::new(model, options)?; + Ok(Self { + inner, + marker: PhantomData, + }) + } + + pub(crate) fn run(&mut self, sig: S) -> anyhow::Result + where + S::Output: Output, + { + let mut ctx = R::RunBuilder::from(&mut self.inner); + sig.input(&mut ctx); + S::Output::run(ctx) + } +} + +#[derive(new, Clone, Copy)] +pub(crate) struct SessionOptions { + pub(crate) cpu_num_threads: u16, + pub(crate) use_gpu: bool, +} + +#[derive(Error, Debug)] +#[error("不正なモデルファイルです")] +pub(crate) struct DecryptModelError; + +mod sealed { + pub(crate) trait OnnxruntimeInputScalar: + onnxruntime::TypeToTensorElementDataType + { + } + + impl OnnxruntimeInputScalar for i64 {} + impl OnnxruntimeInputScalar for f32 {} +} diff --git a/crates/voicevox_core/src/infer/runtimes.rs b/crates/voicevox_core/src/infer/runtimes.rs new file mode 100644 index 000000000..7934027b6 --- /dev/null +++ b/crates/voicevox_core/src/infer/runtimes.rs @@ -0,0 +1,3 @@ +mod onnxruntime; + +pub(crate) use self::onnxruntime::Onnxruntime; diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs new file mode 100644 index 000000000..a26abbb74 --- /dev/null +++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs @@ -0,0 +1,136 @@ +use ndarray::{Array, Dimension}; +use once_cell::sync::Lazy; +use onnxruntime::{environment::Environment, GraphOptimizationLevel, LoggingLevel}; + +use crate::infer::{ + DecryptModelError, InferenceRuntime, InputScalar, Output, RunBuilder, Session, SessionOptions, +}; + +pub(crate) use self::assert_send::AssertSend; + +#[derive(Clone, Copy)] +pub(crate) enum Onnxruntime {} + +impl InferenceRuntime for Onnxruntime { + type Session = AssertSend>; + type RunBuilder<'a> = OnnxruntimeInferenceBuilder<'a>; +} + +impl Session for AssertSend> { + fn new( + model: impl FnOnce() -> std::result::Result, DecryptModelError>, + options: SessionOptions, + ) -> anyhow::Result { + let mut builder = ENVIRONMENT + .new_session_builder()? + .with_optimization_level(GraphOptimizationLevel::Basic)? + .with_intra_op_num_threads(options.cpu_num_threads.into())? + .with_inter_op_num_threads(options.cpu_num_threads.into())?; + + if options.use_gpu { + #[cfg(feature = "directml")] + { + use onnxruntime::ExecutionMode; + + builder = builder + .with_disable_mem_pattern()? + .with_execution_mode(ExecutionMode::ORT_SEQUENTIAL)? + .with_append_execution_provider_directml(0)?; + } + + #[cfg(not(feature = "directml"))] + { + builder = builder.with_append_execution_provider_cuda(Default::default())?; + } + } + + let model = model()?; + let this = builder.with_model_from_memory(model)?.into(); + return Ok(this); + + static ENVIRONMENT: Lazy = Lazy::new(|| { + Environment::builder() + .with_name(env!("CARGO_PKG_NAME")) + .with_log_level(LOGGING_LEVEL) + .build() + .unwrap() + }); + + const LOGGING_LEVEL: LoggingLevel = if cfg!(debug_assertions) { + LoggingLevel::Verbose + } else { + LoggingLevel::Warning + }; + } +} + +pub(crate) struct OnnxruntimeInferenceBuilder<'sess> { + sess: &'sess mut AssertSend>, + inputs: Vec>, +} + +impl<'sess> From<&'sess mut AssertSend>> + for OnnxruntimeInferenceBuilder<'sess> +{ + fn from(sess: &'sess mut AssertSend>) -> Self { + Self { + sess, + inputs: vec![], + } + } +} + +impl<'sess> RunBuilder<'sess> for OnnxruntimeInferenceBuilder<'sess> { + type Runtime = Onnxruntime; + + fn input(&mut self, tensor: Array) -> &mut Self { + self.inputs + .push(Box::new(onnxruntime::session::NdArray::new(tensor))); + self + } +} + +impl Output for (Vec,) { + fn run( + OnnxruntimeInferenceBuilder { sess, mut inputs }: OnnxruntimeInferenceBuilder<'_>, + ) -> anyhow::Result { + let outputs = sess.run(inputs.iter_mut().map(|t| &mut **t as &mut _).collect())?; + + // FIXME: 2個以上の出力や二次元以上の出力をちゃんとしたやりかたで弾く + Ok((outputs[0].as_slice().unwrap().to_owned(),)) + } +} + +// FIXME: 以下のことをちゃんと確認した後、onnxruntime-rs側で`Session`が`Send`であると宣言する。 +// https://github.com/VOICEVOX/voicevox_core/issues/307#issuecomment-1276184614 +mod assert_send { + use std::ops::{Deref, DerefMut}; + + pub(crate) struct AssertSend(T); + + impl From> + for AssertSend> + { + fn from(session: onnxruntime::session::Session<'static>) -> Self { + Self(session) + } + } + + impl Deref for AssertSend { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl DerefMut for AssertSend { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } + } + + // SAFETY: `Session` is probably "send"able. + #[allow(unsafe_code)] + unsafe impl Send for AssertSend {} +} diff --git a/crates/voicevox_core/src/infer/signatures.rs b/crates/voicevox_core/src/infer/signatures.rs new file mode 100644 index 000000000..764d70b8d --- /dev/null +++ b/crates/voicevox_core/src/infer/signatures.rs @@ -0,0 +1,87 @@ +use std::sync::Arc; + +use ndarray::{Array0, Array1, Array2}; + +use crate::infer::{InferenceRuntime, RunBuilder, Signature, TypedSession}; + +pub(crate) struct SessionSet { + pub(crate) predict_duration: Arc>>, + pub(crate) predict_intonation: Arc>>, + pub(crate) decode: Arc>>, +} + +pub(crate) struct PredictDuration { + pub(crate) phoneme: Array1, + pub(crate) speaker_id: Array1, +} + +impl Signature for PredictDuration { + type SessionSet = SessionSet; + type Output = (Vec,); + + fn get_session( + session_set: &Self::SessionSet, + ) -> &Arc>> { + &session_set.predict_duration + } + + fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>) { + ctx.input(self.phoneme).input(self.speaker_id); + } +} + +pub(crate) struct PredictIntonation { + pub(crate) length: Array0, + pub(crate) vowel_phoneme: Array1, + pub(crate) consonant_phoneme: Array1, + pub(crate) start_accent: Array1, + pub(crate) end_accent: Array1, + pub(crate) start_accent_phrase: Array1, + pub(crate) end_accent_phrase: Array1, + pub(crate) speaker_id: Array1, +} + +impl Signature for PredictIntonation { + type SessionSet = SessionSet; + type Output = (Vec,); + + fn get_session( + session_set: &Self::SessionSet, + ) -> &Arc>> { + &session_set.predict_intonation + } + + fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>) { + ctx.input(self.length) + .input(self.vowel_phoneme) + .input(self.consonant_phoneme) + .input(self.start_accent) + .input(self.end_accent) + .input(self.start_accent_phrase) + .input(self.end_accent_phrase) + .input(self.speaker_id); + } +} + +pub(crate) struct Decode { + pub(crate) f0: Array2, + pub(crate) phoneme: Array2, + pub(crate) speaker_id: Array1, +} + +impl Signature for Decode { + type SessionSet = SessionSet; + type Output = (Vec,); + + fn get_session( + session_set: &Self::SessionSet, + ) -> &Arc>> { + &session_set.decode + } + + fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>) { + ctx.input(self.f0) + .input(self.phoneme) + .input(self.speaker_id); + } +} diff --git a/crates/voicevox_core/src/inference_core.rs b/crates/voicevox_core/src/inference_core.rs index 4b0d08be2..d1c7831c5 100644 --- a/crates/voicevox_core/src/inference_core.rs +++ b/crates/voicevox_core/src/inference_core.rs @@ -1,6 +1,6 @@ use self::status::*; use super::*; -use onnxruntime::{ndarray, session::NdArray}; +use crate::infer::signatures::{Decode, PredictDuration, PredictIntonation}; const PHONEME_LENGTH_MINIMAL: f32 = 0.01; @@ -60,12 +60,15 @@ impl InferenceCore { let (model_id, model_inner_id) = self.status.ids_for(style_id)?; - let phoneme_vector_array = NdArray::new(ndarray::arr1(phoneme_vector)); - let speaker_id_array = NdArray::new(ndarray::arr1(&[model_inner_id.raw_id().into()])); - - let mut output = self + let (mut output,) = self .status - .predict_duration_session_run(&model_id, phoneme_vector_array, speaker_id_array) + .run_session( + &model_id, + PredictDuration { + phoneme: ndarray::arr1(phoneme_vector), + speaker_id: ndarray::arr1(&[model_inner_id.raw_id().into()]), + }, + ) .await?; for output_item in output.iter_mut() { @@ -95,29 +98,24 @@ impl InferenceCore { let (model_id, model_inner_id) = self.status.ids_for(style_id)?; - let length_array = NdArray::new(ndarray::arr0(length as i64)); - let vowel_phoneme_vector_array = NdArray::new(ndarray::arr1(vowel_phoneme_vector)); - let consonant_phoneme_vector_array = NdArray::new(ndarray::arr1(consonant_phoneme_vector)); - let start_accent_vector_array = NdArray::new(ndarray::arr1(start_accent_vector)); - let end_accent_vector_array = NdArray::new(ndarray::arr1(end_accent_vector)); - let start_accent_phrase_vector_array = - NdArray::new(ndarray::arr1(start_accent_phrase_vector)); - let end_accent_phrase_vector_array = NdArray::new(ndarray::arr1(end_accent_phrase_vector)); - let speaker_id_array = NdArray::new(ndarray::arr1(&[model_inner_id.raw_id().into()])); - - self.status - .predict_intonation_session_run( + let (output,) = self + .status + .run_session( &model_id, - length_array, - vowel_phoneme_vector_array, - consonant_phoneme_vector_array, - start_accent_vector_array, - end_accent_vector_array, - start_accent_phrase_vector_array, - end_accent_phrase_vector_array, - speaker_id_array, + PredictIntonation { + length: ndarray::arr0(length as i64), + vowel_phoneme: ndarray::arr1(vowel_phoneme_vector), + consonant_phoneme: ndarray::arr1(consonant_phoneme_vector), + start_accent: ndarray::arr1(start_accent_vector), + end_accent: ndarray::arr1(end_accent_vector), + start_accent_phrase: ndarray::arr1(start_accent_phrase_vector), + end_accent_phrase: ndarray::arr1(end_accent_phrase_vector), + speaker_id: ndarray::arr1(&[model_inner_id.raw_id().into()]), + }, ) - .await + .await?; + + Ok(output) } pub async fn decode( @@ -150,22 +148,23 @@ impl InferenceCore { padding_size, ); - let f0_array = NdArray::new( - ndarray::arr1(&f0_with_padding) - .into_shape([length_with_padding, 1]) - .unwrap(), - ); - let phoneme_array = NdArray::new( - ndarray::arr1(&phoneme_with_padding) - .into_shape([length_with_padding, phoneme_size]) - .unwrap(), - ); - let speaker_id_array = NdArray::new(ndarray::arr1(&[model_inner_id.raw_id().into()])); + let (output,) = self + .status + .run_session( + &model_id, + Decode { + f0: ndarray::arr1(&f0_with_padding) + .into_shape([length_with_padding, 1]) + .unwrap(), + phoneme: ndarray::arr1(&phoneme_with_padding) + .into_shape([length_with_padding, phoneme_size]) + .unwrap(), + speaker_id: ndarray::arr1(&[model_inner_id.raw_id().into()]), + }, + ) + .await?; - self.status - .decode_session_run(&model_id, f0_array, phoneme_array, speaker_id_array) - .await - .map(|output| Self::trim_padding_from_output(output, padding_size)) + Ok(Self::trim_padding_from_output(output, padding_size)) } fn make_f0_with_padding( diff --git a/crates/voicevox_core/src/lib.rs b/crates/voicevox_core/src/lib.rs index 798515fb9..407f0b8f4 100644 --- a/crates/voicevox_core/src/lib.rs +++ b/crates/voicevox_core/src/lib.rs @@ -6,6 +6,7 @@ mod devices; /// cbindgen:ignore mod engine; mod error; +mod infer; mod inference_core; mod macros; mod manifest; diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs index 64a402683..46e462d1a 100644 --- a/crates/voicevox_core/src/status.rs +++ b/crates/voicevox_core/src/status.rs @@ -1,23 +1,16 @@ use super::*; -use itertools::iproduct; -use once_cell::sync::Lazy; -use onnxruntime::{ - environment::Environment, - ndarray::{Ix0, Ix1, Ix2}, - session::{NdArray, Session}, - GraphOptimizationLevel, LoggingLevel, +use crate::infer::{ + runtimes::Onnxruntime, + signatures::{Decode, PredictDuration, PredictIntonation, SessionSet}, + DecryptModelError, Output, SessionOptions, Signature, TypedSession, }; +use derive_more::Index; +use itertools::iproduct; +use std::path::Path; use std::sync::Arc; -use std::{env, path::Path}; -use tracing::error; mod model_file; -cfg_if! { - if #[cfg(not(feature="directml"))]{ - use onnxruntime::CudaProviderOptions; - } -} use std::collections::BTreeMap; pub struct Status { @@ -26,31 +19,6 @@ pub struct Status { heavy_session_options: SessionOptions, // 重いモデルはこちらを使う } -#[derive(new, Getters)] -struct SessionOptions { - cpu_num_threads: u16, - use_gpu: bool, -} - -#[derive(thiserror::Error, Debug)] -#[error("不正なモデルファイルです")] -struct DecryptModelError; - -static ENVIRONMENT: Lazy = Lazy::new(|| { - cfg_if! { - if #[cfg(debug_assertions)]{ - const LOGGING_LEVEL: LoggingLevel = LoggingLevel::Verbose; - } else{ - const LOGGING_LEVEL: LoggingLevel = LoggingLevel::Warning; - } - } - Environment::builder() - .with_name(env!("CARGO_PKG_NAME")) - .with_log_level(LOGGING_LEVEL) - .build() - .unwrap() -}); - impl Status { pub fn new(use_gpu: bool, cpu_num_threads: u16) -> Self { Self { @@ -116,13 +84,13 @@ impl Status { self.loaded_models.lock().unwrap().contains_style(style_id) } - fn new_session( + fn new_session( &self, model: &[u8], session_options: &SessionOptions, path: impl AsRef, - ) -> LoadModelResult> { - self.new_session_from_bytes(|| model_file::decrypt(model), session_options) + ) -> LoadModelResult> { + TypedSession::::new(|| model_file::decrypt(model), *session_options) .map_err(|source| LoadModelError { path: path.as_ref().to_owned(), context: LoadModelErrorKind::InvalidModelData, @@ -130,36 +98,6 @@ impl Status { }) } - fn new_session_from_bytes( - &self, - model_bytes: impl FnOnce() -> std::result::Result, DecryptModelError>, - session_options: &SessionOptions, - ) -> anyhow::Result> { - let session_builder = ENVIRONMENT - .new_session_builder()? - .with_optimization_level(GraphOptimizationLevel::Basic)? - .with_intra_op_num_threads(*session_options.cpu_num_threads() as i32)? - .with_inter_op_num_threads(*session_options.cpu_num_threads() as i32)?; - - let session_builder = if *session_options.use_gpu() { - cfg_if! { - if #[cfg(feature = "directml")]{ - session_builder - .with_disable_mem_pattern()? - .with_execution_mode(onnxruntime::ExecutionMode::ORT_SEQUENTIAL)? - .with_append_execution_provider_directml(0)? - } else { - let options = CudaProviderOptions::default(); - session_builder.with_append_execution_provider_cuda(options)? - } - } - } else { - session_builder - }; - - Ok(session_builder.with_model_from_memory(model_bytes()?)?) - } - pub fn validate_speaker_id(&self, style_id: StyleId) -> bool { self.is_loaded_model_by_style_id(style_id) } @@ -167,102 +105,25 @@ impl Status { /// # Panics /// /// `self`が`model_id`を含んでいないとき、パニックする。 - pub async fn predict_duration_session_run( - &self, - model_id: &VoiceModelId, - mut phoneme_vector_array: NdArray, - mut speaker_id_array: NdArray, - ) -> Result> { - let predict_duration = self.loaded_models.lock().unwrap().get( - model_id, - |SessionSet { - predict_duration, .. - }| predict_duration, - ); - - tokio::task::spawn_blocking(move || { - let mut predict_duration = predict_duration.lock().unwrap(); - - let output_tensors = predict_duration - .run(vec![&mut phoneme_vector_array, &mut speaker_id_array]) - .map_err(|e| ErrorRepr::InferenceFailed(e.into()))?; - Ok(output_tensors[0].as_slice().unwrap().to_owned()) - }) - .await - .unwrap() - } - - /// # Panics - /// - /// `self`が`model_id`を含んでいないとき、パニックする。 - #[allow(clippy::too_many_arguments)] - pub async fn predict_intonation_session_run( + pub(crate) async fn run_session( &self, model_id: &VoiceModelId, - mut length_array: NdArray, - mut vowel_phoneme_vector_array: NdArray, - mut consonant_phoneme_vector_array: NdArray, - mut start_accent_vector_array: NdArray, - mut end_accent_vector_array: NdArray, - mut start_accent_phrase_vector_array: NdArray, - mut end_accent_phrase_vector_array: NdArray, - mut speaker_id_array: NdArray, - ) -> Result> { - let predict_intonation = self.loaded_models.lock().unwrap().get( - model_id, - |SessionSet { - predict_intonation, .. - }| predict_intonation, - ); + input: S, + ) -> Result + where + S: Signature, + for<'a> &'a S::SessionSet: From<&'a SessionSet>, + S::Output: Output, + { + let sess = S::get_session::( + (&self.loaded_models.lock().unwrap()[model_id].session_set).into(), + ) + .clone(); tokio::task::spawn_blocking(move || { - let mut predict_intonation = predict_intonation.lock().unwrap(); - - let output_tensors = predict_intonation - .run(vec![ - &mut length_array, - &mut vowel_phoneme_vector_array, - &mut consonant_phoneme_vector_array, - &mut start_accent_vector_array, - &mut end_accent_vector_array, - &mut start_accent_phrase_vector_array, - &mut end_accent_phrase_vector_array, - &mut speaker_id_array, - ]) - .map_err(|e| ErrorRepr::InferenceFailed(e.into()))?; - Ok(output_tensors[0].as_slice().unwrap().to_owned()) - }) - .await - .unwrap() - } - - /// # Panics - /// - /// `self`が`model_id`を含んでいないとき、パニックする。 - pub async fn decode_session_run( - &self, - model_id: &VoiceModelId, - mut f0_array: NdArray, - mut phoneme_array: NdArray, - mut speaker_id_array: NdArray, - ) -> Result> { - let decode = self - .loaded_models - .lock() - .unwrap() - .get(model_id, |SessionSet { decode, .. }| decode); - - tokio::task::spawn_blocking(move || { - let mut decode = decode.lock().unwrap(); - - let output_tensors = decode - .run(vec![ - &mut f0_array, - &mut phoneme_array, - &mut speaker_id_array, - ]) - .map_err(|e| ErrorRepr::InferenceFailed(e.into()))?; - Ok(output_tensors[0].as_slice().unwrap().to_owned()) + let mut sess = sess.lock().unwrap(); + sess.run(input) + .map_err(|e| ErrorRepr::InferenceFailed(e).into()) }) .await .unwrap() @@ -272,13 +133,13 @@ impl Status { /// 読み込んだモデルの`Session`とそのメタ情報を保有し、追加/削除/取得の操作を提供する。 /// /// この構造体のメソッドは、すべて一瞬で完了すべきである。 -#[derive(Default)] +#[derive(Default, Index)] struct LoadedModels(BTreeMap); struct LoadedModel { model_inner_ids: BTreeMap, metas: VoiceModelMeta, - session_set: SessionSet, + session_set: SessionSet, } impl LoadedModels { @@ -314,17 +175,6 @@ impl LoadedModels { Ok((model_id.clone(), model_inner_id)) } - /// # Panics - /// - /// `self`が`model_id`を含んでいないとき、パニックする。 - fn get( - &self, - model_id: &VoiceModelId, - which: fn(&SessionSet) -> &Arc>>>, - ) -> Arc>>> { - which(&self.0[model_id].session_set).clone() - } - fn contains_voice_model(&self, model_id: &VoiceModelId) -> bool { self.0.contains_key(model_id) } @@ -366,9 +216,9 @@ impl LoadedModels { fn insert( &mut self, model: &VoiceModel, - predict_duration: Session<'static>, - predict_intonation: Session<'static>, - decode: Session<'static>, + predict_duration: TypedSession, + predict_intonation: TypedSession, + decode: TypedSession, ) -> Result<()> { self.ensure_acceptable(model)?; @@ -378,9 +228,9 @@ impl LoadedModels { model_inner_ids: model.model_inner_ids(), metas: model.metas().clone(), session_set: SessionSet { - predict_duration: Arc::new(std::sync::Mutex::new(predict_duration.into())), - predict_intonation: Arc::new(std::sync::Mutex::new(predict_intonation.into())), - decode: Arc::new(std::sync::Mutex::new(decode.into())), + predict_duration: Arc::new(std::sync::Mutex::new(predict_duration)), + predict_intonation: Arc::new(std::sync::Mutex::new(predict_intonation)), + decode: Arc::new(std::sync::Mutex::new(decode)), }, }, ); @@ -406,49 +256,6 @@ impl LoadedModels { } } -struct SessionSet { - predict_duration: Arc>>>, - predict_intonation: Arc>>>, - decode: Arc>>>, -} - -// FIXME: 以下のことをちゃんと確認した後、onnxruntime-rs側で`Session`が`Send`であると宣言する。 -// https://github.com/VOICEVOX/voicevox_core/issues/307#issuecomment-1276184614 - -use self::assert_send::AssertSend; - -mod assert_send { - use std::ops::{Deref, DerefMut}; - - use onnxruntime::session::Session; - - pub(super) struct AssertSend(T); - - impl From> for AssertSend> { - fn from(session: Session<'static>) -> Self { - Self(session) - } - } - - impl Deref for AssertSend { - type Target = T; - - fn deref(&self) -> &Self::Target { - &self.0 - } - } - - impl DerefMut for AssertSend { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } - } - - // SAFETY: `Session` is probably "send"able. - #[allow(unsafe_code)] - unsafe impl Send for AssertSend {} -} - #[cfg(test)] mod tests { diff --git a/crates/voicevox_core_c_api/Cargo.toml b/crates/voicevox_core_c_api/Cargo.toml index f187f0001..fad0e1b7b 100644 --- a/crates/voicevox_core_c_api/Cargo.toml +++ b/crates/voicevox_core_c_api/Cargo.toml @@ -52,7 +52,7 @@ easy-ext.workspace = true inventory = "0.3.4" libloading = "0.7.3" libtest-mimic = "0.6.0" -ndarray = "0.15.6" +ndarray.workspace = true ndarray-stats = "0.5.1" regex.workspace = true serde.workspace = true