From 81b5804037ea21fe0b3fe0ae934fe1d2a2548c74 Mon Sep 17 00:00:00 2001
From: Ryo Yamashita
Date: Sat, 11 Nov 2023 13:34:40 +0900
Subject: [PATCH] Assume the runtime can take any number of inputs and outputs
 of arbitrary rank
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../src/engine/synthesis_engine.rs            |  12 +-
 crates/voicevox_core/src/infer.rs             | 114 +++++++++------
 .../src/infer/runtimes/onnxruntime.rs         |  71 +++++----
 crates/voicevox_core/src/infer/signatures.rs  | 138 +++++++++++++-----
 crates/voicevox_core/src/inference_core.rs    |  29 ++--
 crates/voicevox_core/src/status.rs            |   5 +-
 6 files changed, 226 insertions(+), 143 deletions(-)

diff --git a/crates/voicevox_core/src/engine/synthesis_engine.rs b/crates/voicevox_core/src/engine/synthesis_engine.rs
index b171f978d..c70742f16 100644
--- a/crates/voicevox_core/src/engine/synthesis_engine.rs
+++ b/crates/voicevox_core/src/engine/synthesis_engine.rs
@@ -5,10 +5,7 @@ use std::sync::Arc;
 use super::full_context_label::Utterance;
 use super::open_jtalk::OpenJtalk;
 use super::*;
-use crate::infer::{
-    signatures::{Decode, PredictDuration, PredictIntonation},
-    InferenceRuntime, SupportsInferenceSignature,
-};
+use crate::infer::InferenceRuntime;
 use crate::numerics::F32Ext as _;
 use crate::InferenceCore;
 
@@ -26,12 +23,7 @@ pub(crate) struct SynthesisEngine<R: InferenceRuntime> {
     open_jtalk: Arc<OpenJtalk>,
 }
 
-impl<
-        R: SupportsInferenceSignature<PredictDuration>
-            + SupportsInferenceSignature<PredictIntonation>
-            + SupportsInferenceSignature<Decode>,
-    > SynthesisEngine<R>
-{
+impl<R: InferenceRuntime> SynthesisEngine<R> {
     pub fn inference_core(&self) -> &InferenceCore<R> {
         &self.inference_core
     }
diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs
index 9df7365f5..0eaea5b63 100644
--- a/crates/voicevox_core/src/infer.rs
+++ b/crates/voicevox_core/src/infer.rs
@@ -7,21 +7,28 @@ use std::{collections::HashMap, fmt::Debug, marker::PhantomData, sync::Arc};
 use derive_new::new;
 use easy_ext::ext;
 use enum_map::{Enum, EnumMap};
+use ndarray::{Array, ArrayD, Dimension, ShapeError};
 use thiserror::Error;
 
 use crate::{ErrorRepr, SupportedDevices};
 
 pub(crate) trait InferenceRuntime: 'static {
-    type Session: InferenceSession;
+    type Session: Sized + Send + 'static;
     type RunContext<'a>: RunContext<'a, Runtime = Self>;
+
     fn supported_devices() -> crate::Result<SupportedDevices>;
-}
 
-pub(crate) trait InferenceSession: Sized + Send + 'static {
-    fn new(
+    fn new_session(
         model: impl FnOnce() -> std::result::Result<Vec<u8>, DecryptModelError>,
         options: InferenceSessionOptions,
-    ) -> anyhow::Result<Self>;
+    ) -> anyhow::Result<Self::Session>;
+
+    fn push_input(
+        input: Array<impl InputScalar, impl Dimension + 'static>,
+        ctx: &mut Self::RunContext<'_>,
+    );
+
+    fn run(ctx: Self::RunContext<'_>) -> anyhow::Result<Vec<AnyTensor>>;
 }
 
 pub(crate) trait RunContext<'a>:
@@ -32,54 +39,56 @@
 #[ext(RunContextExt)]
 impl<'a, T: RunContext<'a>> T {
-    fn with_input<I>(mut self, tensor: I) -> Self
-    where
-        T::Runtime: SupportsInferenceInputTensor<I>,
-    {
+    fn with_input(mut self, tensor: Array<impl InputScalar, impl Dimension + 'static>) -> Self {
         T::Runtime::push_input(tensor, &mut self);
         self
     }
 }
 
-pub(crate) trait SupportsInferenceSignature<S: InferenceSignature>:
-    SupportsInferenceInputSignature<S::Input> + SupportsInferenceOutput<S::Output>
-{
+pub(crate) trait InferenceGroup {
+    type Kind: Copy + Enum;
 }
 
-impl<
-        R: SupportsInferenceInputSignature<S::Input> + SupportsInferenceOutput<S::Output>,
-        S: InferenceSignature,
-    > SupportsInferenceSignature<S> for R
-{
+pub(crate) trait InferenceSignature: Sized + Send + 'static {
+    type Group: InferenceGroup;
+    type Input: InferenceInputSignature<Signature = Self>;
+    type Output: TryFrom<Vec<AnyTensor>, Error = anyhow::Error> + Send;
+    const INFERENCE: <Self::Group as InferenceGroup>::Kind;
 }
 
-pub(crate) trait SupportsInferenceInputTensor<I>: InferenceRuntime {
-    fn push_input(input: I, ctx: &mut Self::RunContext<'_>);
+pub(crate) trait InferenceInputSignature: Send + 'static {
+    type Signature: InferenceSignature<Input = Self>;
+    fn make_run_context<R: InferenceRuntime>(self, sess: &mut R::Session) -> R::RunContext<'_>;
 }
 
-pub(crate) trait SupportsInferenceInputSignature<I: InferenceInputSignature>:
-    InferenceRuntime
-{
-    fn make_run_context(sess: &mut Self::Session, input: I) -> Self::RunContext<'_>;
-}
+pub(crate) trait InputScalar: sealed::InputScalar + Debug + 'static {}
+
+impl InputScalar for i64 {}
+impl InputScalar for f32 {}
 
-pub(crate) trait SupportsInferenceOutput<O: Send>: InferenceRuntime {
-    fn run(ctx: Self::RunContext<'_>) -> anyhow::Result<O>;
+pub(crate) trait OutputScalar: Sized {
+    fn extract_dyn_dim(tensor: AnyTensor) -> std::result::Result<ArrayD<Self>, ExtractError>;
 }
 
-pub(crate) trait InferenceGroup {
-    type Kind: Copy + Enum;
+impl OutputScalar for f32 {
+    fn extract_dyn_dim(tensor: AnyTensor) -> std::result::Result<ArrayD<Self>, ExtractError> {
+        match tensor {
+            AnyTensor::Float32(tensor) => Ok(tensor),
+        }
+    }
 }
 
-pub(crate) trait InferenceSignature: Sized + Send + 'static {
-    type Group: InferenceGroup;
-    type Input: InferenceInputSignature<Signature = Self>;
-    type Output: Send;
-    const INFERENCE: <Self::Group as InferenceGroup>::Kind;
+pub(crate) enum AnyTensor {
+    Float32(ArrayD<f32>),
 }
 
-pub(crate) trait InferenceInputSignature: Send + 'static {
-    type Signature: InferenceSignature<Input = Self>;
+impl<A: OutputScalar, D: Dimension> TryFrom<AnyTensor> for Array<A, D> {
+    type Error = ExtractError;
+
+    fn try_from(tensor: AnyTensor) -> Result<Self, Self::Error> {
+        let this = A::extract_dyn_dim(tensor)?.into_dimensionality()?;
+        Ok(this)
+    }
 }
 
 pub(crate) struct InferenceSessionSet<G: InferenceGroup, R: InferenceRuntime>(
@@ -94,7 +103,7 @@ impl<G: InferenceGroup, R: InferenceRuntime> InferenceSessionSet<G, R> {
         let mut sessions = model_bytes
             .iter()
             .map(|(k, m)| {
-                let sess = R::Session::new(|| model_file::decrypt(m), options(k))?;
+                let sess = R::new_session(|| model_file::decrypt(m), options(k))?;
                 Ok((k.into_usize(), std::sync::Mutex::new(sess).into()))
             })
             .collect::<anyhow::Result<HashMap<_, _>>>()?;
@@ -123,19 +132,16 @@ pub(crate) struct InferenceSessionCell<R: InferenceRuntime, I> {
     marker: PhantomData<fn(I)>,
 }
 
-impl<
-        R: SupportsInferenceInputSignature<I>
-            + SupportsInferenceOutput<<I::Signature as InferenceSignature>::Output>,
-        I: InferenceInputSignature,
-    > InferenceSessionCell<R, I>
-{
+impl<R: InferenceRuntime, I: InferenceInputSignature> InferenceSessionCell<R, I> {
     pub(crate) fn run(
         self,
         input: I,
     ) -> crate::Result<<I::Signature as InferenceSignature>::Output> {
         let inner = &mut self.inner.lock().unwrap();
-        let ctx = R::make_run_context(inner, input);
-        R::run(ctx).map_err(|e| ErrorRepr::InferenceFailed(e).into())
+        let ctx = input.make_run_context::<R>(inner);
+        R::run(ctx)
+            .and_then(TryInto::try_into)
+            .map_err(|e| ErrorRepr::InferenceFailed(e).into())
     }
 }
 
@@ -145,6 +151,26 @@ pub(crate) struct InferenceSessionOptions {
     pub(crate) use_gpu: bool,
 }
 
+#[derive(Error, Debug)]
+pub(crate) enum ExtractError {
+    #[error(transparent)]
+    Shape(#[from] ShapeError),
+}
+
 #[derive(Error, Debug)]
 #[error("不正なモデルファイルです")]
 pub(crate) struct DecryptModelError;
+
+mod sealed {
+    pub(crate) trait InputScalar: OnnxruntimeInputScalar {}
+
+    impl InputScalar for i64 {}
+    impl InputScalar for f32 {}
+
+    pub(crate) trait OnnxruntimeInputScalar:
+        onnxruntime::TypeToTensorElementDataType
+    {
+    }
+
+    impl<T: onnxruntime::TypeToTensorElementDataType> OnnxruntimeInputScalar for T {}
+}
diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs
index 8848c24e3..26bc93655 100644
--- a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs
+++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs
@@ -3,7 +3,7 @@ use std::fmt::Debug;
 use ndarray::{Array, Dimension};
 use once_cell::sync::Lazy;
 use onnxruntime::{
-    environment::Environment, GraphOptimizationLevel, LoggingLevel, TypeToTensorElementDataType,
+    environment::Environment, GraphOptimizationLevel, LoggingLevel, TensorElementDataType,
 };
 
 use self::assert_send::AssertSend;
@@ -11,8 +11,8 @@ use crate::{
     devices::SupportedDevices,
     error::ErrorRepr,
     infer::{
-        DecryptModelError, InferenceRuntime, InferenceSession, InferenceSessionOptions, RunContext,
-        SupportsInferenceInputTensor, SupportsInferenceOutput,
+        AnyTensor, DecryptModelError, InferenceRuntime, InferenceSessionOptions, InputScalar,
+        RunContext,
     },
 };
 
@@ -44,13 +44,11 @@ impl InferenceRuntime for Onnxruntime {
             dml: dml_support,
         })
     }
-}
 
-impl InferenceSession for AssertSend<onnxruntime::session::Session<'static>> {
-    fn new(
+    fn new_session(
         model: impl FnOnce() -> std::result::Result<Vec<u8>, DecryptModelError>,
         options: InferenceSessionOptions,
-    ) -> anyhow::Result<Self> {
+    ) -> anyhow::Result<Self::Session> {
         let mut builder = ENVIRONMENT
             .new_session_builder()?
             .with_optimization_level(GraphOptimizationLevel::Basic)?
@@ -75,8 +73,8 @@
         }
 
         let model = model()?;
-        let this = builder.with_model_from_memory(model)?.into();
-        return Ok(this);
+        let sess = builder.with_model_from_memory(model)?.into();
+        return Ok(sess);
 
         static ENVIRONMENT: Lazy<Environment> = Lazy::new(|| {
             Environment::builder()
@@ -92,6 +90,39 @@
             LoggingLevel::Warning
         };
     }
+
+    fn push_input(
+        input: Array<impl InputScalar, impl Dimension + 'static>,
+        ctx: &mut Self::RunContext<'_>,
+    ) {
+        ctx.inputs
+            .push(Box::new(onnxruntime::session::NdArray::new(input)));
+    }
+
+    fn run(
+        OnnxruntimeRunContext { sess, mut inputs }: OnnxruntimeRunContext<'_>,
+    ) -> anyhow::Result<Vec<AnyTensor>> {
+        // FIXME: Only `f32` is supported for now. The data type can be read from the session at
+        // runtime, so other types could probably be supported too, but the move to the `ort`
+        // crate will likely happen before that becomes necessary, so this stays as is.
+
+        if !sess
+            .outputs
+            .iter()
+            .all(|info| matches!(info.output_type, TensorElementDataType::Float))
+        {
+            unimplemented!(
+                "currently only `ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT` is supported for output",
+            );
+        }
+
+        let outputs = sess.run::<f32>(inputs.iter_mut().map(|t| &mut **t as &mut _).collect())?;
+
+        Ok(outputs
+            .iter()
+            .map(|o| AnyTensor::Float32((*o).clone().into_owned()))
+            .collect())
+    }
 }
 
 pub(crate) struct OnnxruntimeRunContext<'sess> {
@@ -114,28 +145,6 @@
 impl<'sess> RunContext<'sess> for OnnxruntimeRunContext<'sess> {
     type Runtime = Onnxruntime;
 }
 
-impl<A: TypeToTensorElementDataType + Debug, D: Dimension + 'static>
-    SupportsInferenceInputTensor<Array<A, D>> for Onnxruntime
-{
-    fn push_input(input: Array<A, D>, ctx: &mut Self::RunContext<'_>) {
-        ctx.inputs
-            .push(Box::new(onnxruntime::session::NdArray::new(input)));
-    }
-}
-
-impl SupportsInferenceOutput<(Vec<f32>,)>
-    for Onnxruntime
-{
-    fn run(
-        OnnxruntimeRunContext { sess, mut inputs }: OnnxruntimeRunContext<'_>,
-    ) -> anyhow::Result<(Vec<f32>,)> {
-        let outputs = sess.run(inputs.iter_mut().map(|t| &mut **t as &mut _).collect())?;
-
-        // FIXME: Properly reject outputs with two or more tensors, or with two or more dimensions.
-        Ok((outputs[0].as_slice().unwrap().to_owned(),))
-    }
-}
-
 // FIXME: After properly verifying the following, declare `Session` as `Send` on the
 // onnxruntime-rs side.
 // https://github.com/VOICEVOX/voicevox_core/issues/307#issuecomment-1276184614
 mod assert_send {
diff --git a/crates/voicevox_core/src/infer/signatures.rs b/crates/voicevox_core/src/infer/signatures.rs
index d04e7b3ad..ac46c6444 100644
--- a/crates/voicevox_core/src/infer/signatures.rs
+++ b/crates/voicevox_core/src/infer/signatures.rs
@@ -1,9 +1,10 @@
+use anyhow::ensure;
 use enum_map::Enum;
 use ndarray::{Array0, Array1, Array2};
 
-use crate::infer::{
-    InferenceGroup, InferenceInputSignature, InferenceSignature, RunContextExt as _,
-    SupportsInferenceInputSignature, SupportsInferenceInputTensor,
+use super::{
+    AnyTensor, InferenceGroup, InferenceInputSignature, InferenceRuntime, InferenceSignature,
+    RunContextExt as _,
 };
 
 pub(crate) enum InferenceGroupImpl {}
@@ -24,7 +25,7 @@ pub(crate) enum PredictDuration {}
 impl InferenceSignature for PredictDuration {
     type Group = InferenceGroupImpl;
     type Input = PredictDurationInput;
-    type Output = (Vec<f32>,);
+    type Output = PredictDurationOutput;
     const INFERENCE: InferencelKindImpl = InferencelKindImpl::PredictDuration;
 }
 
@@ -35,18 +36,36 @@ pub(crate) struct PredictDurationInput {
 
 impl InferenceInputSignature for PredictDurationInput {
     type Signature = PredictDuration;
+
+    fn make_run_context<R: InferenceRuntime>(self, sess: &mut R::Session) -> R::RunContext<'_> {
+        R::RunContext::from(sess)
+            .with_input(self.phoneme)
+            .with_input(self.speaker_id)
+    }
+}
+
+pub(crate) struct PredictDurationOutput {
+    pub(crate) phoneme_length: Array1<f32>,
 }
 
-impl<R: SupportsInferenceInputTensor<Array1<i64>>>
-    SupportsInferenceInputSignature<PredictDurationInput> for R
-{
-    fn make_run_context(
-        sess: &mut Self::Session,
-        input: PredictDurationInput,
-    ) -> Self::RunContext<'_> {
-        Self::RunContext::from(sess)
-            .with_input(input.phoneme)
-            .with_input(input.speaker_id)
+impl TryFrom<Vec<AnyTensor>> for PredictDurationOutput {
+    type Error = anyhow::Error;
+
+    fn try_from(tensors: Vec<AnyTensor>) -> Result<Self, Self::Error> {
+        ensure!(
+            tensors.len() == 1,
+            "expected 1 tensor(s), got {}",
+            tensors.len(),
+        );
+
+        let mut tensors = tensors.into_iter();
+        let this = Self {
+            phoneme_length: tensors
+                .next()
+                .expect("the length should have been checked")
+                .try_into()?,
+        };
+        Ok(this)
     }
 }
 
@@ -55,7 +74,7 @@ pub(crate) enum PredictIntonation {}
 impl InferenceSignature for PredictIntonation {
     type Group = InferenceGroupImpl;
     type Input = PredictIntonationInput;
-    type Output = (Vec<f32>,);
+    type Output = PredictIntonationOutput;
     const INFERENCE: InferencelKindImpl = InferencelKindImpl::PredictIntonation;
 }
 
@@ -72,24 +91,42 @@ pub(crate) struct PredictIntonationInput {
 
 impl InferenceInputSignature for PredictIntonationInput {
     type Signature = PredictIntonation;
+
+    fn make_run_context<R: InferenceRuntime>(self, sess: &mut R::Session) -> R::RunContext<'_> {
+        R::RunContext::from(sess)
+            .with_input(self.length)
+            .with_input(self.vowel_phoneme)
+            .with_input(self.consonant_phoneme)
+            .with_input(self.start_accent)
+            .with_input(self.end_accent)
+            .with_input(self.start_accent_phrase)
+            .with_input(self.end_accent_phrase)
+            .with_input(self.speaker_id)
+    }
 }
 
-impl<R: SupportsInferenceInputTensor<Array0<i64>> + SupportsInferenceInputTensor<Array1<i64>>>
-    SupportsInferenceInputSignature<PredictIntonationInput> for R
-{
-    fn make_run_context(
-        sess: &mut Self::Session,
-        input: PredictIntonationInput,
-    ) -> Self::RunContext<'_> {
-        Self::RunContext::from(sess)
-            .with_input(input.length)
-            .with_input(input.vowel_phoneme)
-            .with_input(input.consonant_phoneme)
-            .with_input(input.start_accent)
-            .with_input(input.end_accent)
-            .with_input(input.start_accent_phrase)
-            .with_input(input.end_accent_phrase)
-            .with_input(input.speaker_id)
+pub(crate) struct PredictIntonationOutput {
+    pub(crate) f0_list: Array1<f32>,
+}
+
+impl TryFrom<Vec<AnyTensor>> for PredictIntonationOutput {
+    type Error = anyhow::Error;
+
+    fn try_from(tensors: Vec<AnyTensor>) -> Result<Self, Self::Error> {
+        ensure!(
+            tensors.len() == 1,
+            "expected 1 tensor(s), got {}",
+            tensors.len(),
+        );
+
+        let mut tensors = tensors.into_iter();
+        let this = Self {
+            f0_list: tensors
+                .next()
have been checked") + .try_into()?, + }; + Ok(this) } } @@ -98,7 +135,7 @@ pub(crate) enum Decode {} impl InferenceSignature for Decode { type Group = InferenceGroupImpl; type Input = DecodeInput; - type Output = (Vec,); + type Output = DecodeOutput; const INFERENCE: InferencelKindImpl = InferencelKindImpl::Decode; } @@ -110,15 +147,36 @@ pub(crate) struct DecodeInput { impl InferenceInputSignature for DecodeInput { type Signature = Decode; + + fn make_run_context(self, sess: &mut R::Session) -> R::RunContext<'_> { + R::RunContext::from(sess) + .with_input(self.f0) + .with_input(self.phoneme) + .with_input(self.speaker_id) + } +} + +pub(crate) struct DecodeOutput { + pub(crate) wave: Array1, } -impl> + SupportsInferenceInputTensor>> - SupportsInferenceInputSignature for R -{ - fn make_run_context(sess: &mut Self::Session, input: DecodeInput) -> Self::RunContext<'_> { - Self::RunContext::from(sess) - .with_input(input.f0) - .with_input(input.phoneme) - .with_input(input.speaker_id) +impl TryFrom> for DecodeOutput { + type Error = anyhow::Error; + + fn try_from(tensors: Vec) -> Result { + ensure!( + tensors.len() == 1, + "expected 1 tensor(s), got {}", + tensors.len(), + ); + + let mut tensors = tensors.into_iter(); + let this = Self { + wave: tensors + .next() + .expect("the length should have been checked") + .try_into()?, + }; + Ok(this) } } diff --git a/crates/voicevox_core/src/inference_core.rs b/crates/voicevox_core/src/inference_core.rs index 264f56942..30dc37995 100644 --- a/crates/voicevox_core/src/inference_core.rs +++ b/crates/voicevox_core/src/inference_core.rs @@ -2,10 +2,10 @@ use self::status::*; use super::*; use crate::infer::{ signatures::{ - Decode, DecodeInput, PredictDuration, PredictDurationInput, PredictIntonation, - PredictIntonationInput, + DecodeInput, DecodeOutput, PredictDurationInput, PredictDurationOutput, + PredictIntonationInput, PredictIntonationOutput, }, - InferenceRuntime, SupportsInferenceSignature, + InferenceRuntime, }; const PHONEME_LENGTH_MINIMAL: f32 = 0.01; @@ -14,12 +14,7 @@ pub(crate) struct InferenceCore { status: Status, } -impl< - R: SupportsInferenceSignature - + SupportsInferenceSignature - + SupportsInferenceSignature, - > InferenceCore -{ +impl InferenceCore { pub(crate) fn new(use_gpu: bool, cpu_num_threads: u16) -> Result { if !use_gpu || Self::can_support_gpu_feature()? 
             let status = Status::new(use_gpu, cpu_num_threads);
@@ -71,7 +66,9 @@
 
         let (model_id, model_inner_id) = self.status.ids_for(style_id)?;
 
-        let (mut output,) = self
+        let PredictDurationOutput {
+            phoneme_length: output,
+        } = self
             .status
             .run_session(
                 &model_id,
@@ -81,6 +78,7 @@
                 },
             )
             .await?;
+        let mut output = output.into_raw_vec();
 
         for output_item in output.iter_mut() {
             if *output_item < PHONEME_LENGTH_MINIMAL {
@@ -109,7 +107,7 @@
 
         let (model_id, model_inner_id) = self.status.ids_for(style_id)?;
 
-        let (output,) = self
+        let PredictIntonationOutput { f0_list: output } = self
             .status
             .run_session(
                 &model_id,
@@ -126,7 +124,7 @@
             )
             .await?;
 
-        Ok(output)
+        Ok(output.into_raw_vec())
     }
 
     pub async fn decode(
@@ -159,7 +157,7 @@
             padding_size,
         );
 
-        let (output,) = self
+        let DecodeOutput { wave: output } = self
             .status
             .run_session(
                 &model_id,
@@ -175,7 +173,10 @@
             )
             .await?;
 
-        Ok(Self::trim_padding_from_output(output, padding_size))
+        Ok(Self::trim_padding_from_output(
+            output.into_raw_vec(),
+            padding_size,
+        ))
     }
 
     fn make_f0_with_padding(
diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs
index 16bc04a4d..51cabf20d 100644
--- a/crates/voicevox_core/src/status.rs
+++ b/crates/voicevox_core/src/status.rs
@@ -2,8 +2,7 @@ use super::*;
 use crate::infer::{
     signatures::{InferenceGroupImpl, InferencelKindImpl},
     InferenceInputSignature, InferenceRuntime, InferenceSessionCell, InferenceSessionOptions,
-    InferenceSessionSet, InferenceSignature, SupportsInferenceInputSignature,
-    SupportsInferenceOutput,
+    InferenceSessionSet, InferenceSignature,
 };
 use educe::Educe;
 use itertools::iproduct;
@@ -90,8 +89,6 @@ impl<R: InferenceRuntime> Status<R> {
     where
         I: InferenceInputSignature,
         I::Signature: InferenceSignature<Group = InferenceGroupImpl>,
-        R: SupportsInferenceInputSignature<I>
-            + SupportsInferenceOutput<<I::Signature as InferenceSignature>::Output>,
     {
         let sess = self.loaded_models.lock().unwrap().get(model_id);
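
Reviewer note (not part of the patch): after this change, the set of signatures is
open-ended, since a runtime no longer has to implement per-signature traits. As a
minimal sketch of what the new scheme asks of a future model, here is how a
hypothetical `PredictPitch` signature would be wired in. The names `PredictPitch`,
`PredictPitchInput`, `PredictPitchOutput`, the tensor shapes, and the
`InferencelKindImpl::PredictPitch` variant are all assumptions for illustration;
only the trait surface comes from this diff.

    use anyhow::ensure;
    use ndarray::Array1;

    use crate::infer::{
        signatures::{InferenceGroupImpl, InferencelKindImpl},
        AnyTensor, InferenceInputSignature, InferenceRuntime, InferenceSignature,
        RunContextExt as _,
    };

    pub(crate) enum PredictPitch {}

    impl InferenceSignature for PredictPitch {
        type Group = InferenceGroupImpl;
        type Input = PredictPitchInput;
        type Output = PredictPitchOutput;
        // Assumes a `PredictPitch` variant were added to `InferencelKindImpl`.
        const INFERENCE: InferencelKindImpl = InferencelKindImpl::PredictPitch;
    }

    pub(crate) struct PredictPitchInput {
        pub(crate) phoneme: Array1<i64>,
        pub(crate) speaker_id: Array1<i64>,
    }

    impl InferenceInputSignature for PredictPitchInput {
        type Signature = PredictPitch;

        fn make_run_context<R: InferenceRuntime>(self, sess: &mut R::Session) -> R::RunContext<'_> {
            // Inputs are pushed as plain ndarray `Array`s of any rank; the runtime
            // itself no longer constrains their number, scalar type, or shape.
            R::RunContext::from(sess)
                .with_input(self.phoneme)
                .with_input(self.speaker_id)
        }
    }

    pub(crate) struct PredictPitchOutput {
        pub(crate) pitch: Array1<f32>,
    }

    impl TryFrom<Vec<AnyTensor>> for PredictPitchOutput {
        type Error = anyhow::Error;

        fn try_from(tensors: Vec<AnyTensor>) -> Result<Self, Self::Error> {
            // The runtime hands back dynamically-ranked tensors; each output type
            // validates the tensor count and converts the ranks it expects.
            ensure!(tensors.len() == 1, "expected 1 tensor(s), got {}", tensors.len());
            let mut tensors = tensors.into_iter();
            Ok(Self {
                pitch: tensors
                    .next()
                    .expect("the length should have been checked")
                    .try_into()?,
            })
        }
    }

The per-signature `TryFrom<Vec<AnyTensor>>` conversion is what replaces the old
`(Vec<f32>,)` tuple outputs: shape and arity checking moves out of the runtime and
into the one type that knows what the model is supposed to return.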