diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index 2a0ac3318..a8577bd24 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -3,10 +3,13 @@ pub(crate) mod runtimes; pub(crate) mod signatures; pub(crate) mod status; -use std::fmt::Debug; +use std::{ + borrow::Cow, + fmt::{self, Debug, Display}, +}; use derive_new::new; -use enum_map::Enum; +use enum_map::{Enum, EnumMap}; use ndarray::{Array, ArrayD, Dimension, ShapeError}; use thiserror::Error; @@ -18,10 +21,15 @@ pub(crate) trait InferenceRuntime: 'static { fn supported_devices() -> crate::Result; + #[allow(clippy::type_complexity)] fn new_session( model: impl FnOnce() -> std::result::Result, DecryptModelError>, options: InferenceSessionOptions, - ) -> anyhow::Result; + ) -> anyhow::Result<( + Self::Session, + Vec>, + Vec>, + )>; fn push_input( input: Array, @@ -31,30 +39,59 @@ pub(crate) trait InferenceRuntime: 'static { fn run(ctx: Self::RunContext<'_>) -> anyhow::Result>; } -pub(crate) trait InferenceGroup: Copy + Enum {} +pub(crate) trait InferenceGroup: Copy + Enum { + const INPUT_PARAM_INFOS: EnumMap]>; + const OUTPUT_PARAM_INFOS: EnumMap]>; +} pub(crate) trait InferenceSignature: Sized + Send + 'static { type Group: InferenceGroup; type Input: InferenceInputSignature; - type Output: TryFrom, Error = anyhow::Error> + Send; + type Output: InferenceOutputSignature; const KIND: Self::Group; } pub(crate) trait InferenceInputSignature: Send + 'static { type Signature: InferenceSignature; + const PARAM_INFOS: &'static [ParamInfo]; fn make_run_context(self, sess: &mut R::Session) -> R::RunContext<'_>; } -pub(crate) trait InputScalar: sealed::InputScalar + Debug + 'static {} +pub(crate) trait InputScalar: sealed::InputScalar + Debug + 'static { + const KIND: InputScalarKind; +} -impl InputScalar for i64 {} -impl InputScalar for f32 {} +impl InputScalar for i64 { + const KIND: InputScalarKind = InputScalarKind::Int64; +} + +impl InputScalar for f32 { + const KIND: InputScalarKind = InputScalarKind::Float32; +} + +#[derive(Clone, Copy, PartialEq, derive_more::Display)] +pub(crate) enum InputScalarKind { + #[display(fmt = "int64_t")] + Int64, + + #[display(fmt = "float")] + Float32, +} + +pub(crate) trait InferenceOutputSignature: + TryFrom, Error = anyhow::Error> + Send +{ + const PARAM_INFOS: &'static [ParamInfo]; +} pub(crate) trait OutputScalar: Sized { + const KIND: OutputScalarKind; fn extract(tensor: OutputTensor) -> std::result::Result, ExtractError>; } impl OutputScalar for f32 { + const KIND: OutputScalarKind = OutputScalarKind::Float32; + fn extract(tensor: OutputTensor) -> std::result::Result, ExtractError> { match tensor { OutputTensor::Float32(tensor) => Ok(tensor), @@ -62,6 +99,12 @@ impl OutputScalar for f32 { } } +#[derive(Clone, Copy, PartialEq, derive_more::Display)] +pub(crate) enum OutputScalarKind { + #[display(fmt = "float")] + Float32, +} + pub(crate) enum OutputTensor { Float32(ArrayD), } @@ -75,6 +118,41 @@ impl TryFrom for Array { } } +pub(crate) struct ParamInfo { + name: Cow<'static, str>, + dt: D, + ndim: Option, +} + +impl ParamInfo { + fn accepts(&self, other: &Self) -> bool { + self.name == other.name + && self.dt == other.dt + && (self.ndim.is_none() || self.ndim == other.ndim) + } +} + +impl Display for ParamInfo { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}: {}", self.name, self.dt)?; + if let Some(ndim) = self.ndim { + f.write_str(&"[]".repeat(ndim)) + } else { + f.write_str("[]...") + } + } +} + +pub(crate) trait ArrayExt { + type Scalar; + type Dimension: Dimension; +} + +impl ArrayExt for Array { + type Scalar = A; + type Dimension = D; +} + #[derive(new, Clone, Copy, PartialEq, Debug)] pub(crate) struct InferenceSessionOptions { pub(crate) cpu_num_threads: u16, diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs index 6c3901d04..ca5b28aaa 100644 --- a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs +++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs @@ -1,18 +1,19 @@ -use std::fmt::Debug; +use std::{fmt::Debug, vec}; +use anyhow::anyhow; use ndarray::{Array, Dimension}; use once_cell::sync::Lazy; use onnxruntime::{ environment::Environment, GraphOptimizationLevel, LoggingLevel, TensorElementDataType, }; +use crate::{devices::SupportedDevices, error::ErrorRepr}; + use self::assert_send::AssertSend; -use crate::{ - devices::SupportedDevices, - error::ErrorRepr, - infer::{ - DecryptModelError, InferenceRuntime, InferenceSessionOptions, InputScalar, OutputTensor, - }, + +use super::super::{ + DecryptModelError, InferenceRuntime, InferenceSessionOptions, InputScalar, InputScalarKind, + OutputScalarKind, OutputTensor, ParamInfo, }; #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] @@ -47,7 +48,11 @@ impl InferenceRuntime for Onnxruntime { fn new_session( model: impl FnOnce() -> std::result::Result, DecryptModelError>, options: InferenceSessionOptions, - ) -> anyhow::Result { + ) -> anyhow::Result<( + Self::Session, + Vec>, + Vec>, + )> { let mut builder = ENVIRONMENT .new_session_builder()? .with_optimization_level(GraphOptimizationLevel::Basic)? @@ -72,8 +77,67 @@ impl InferenceRuntime for Onnxruntime { } let model = model()?; - let sess = builder.with_model_from_memory(model)?.into(); - return Ok(sess); + let sess = AssertSend::from(builder.with_model_from_memory(model)?); + + let input_param_infos = sess + .inputs + .iter() + .map(|info| { + let dt = match info.input_type { + TensorElementDataType::Float => Ok(InputScalarKind::Float32), + TensorElementDataType::Uint8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8"), + TensorElementDataType::Int8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8"), + TensorElementDataType::Uint16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16"), + TensorElementDataType::Int16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16"), + TensorElementDataType::Int32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32"), + TensorElementDataType::Int64 => Ok(InputScalarKind::Int64), + TensorElementDataType::String => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING"), + TensorElementDataType::Double => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE"), + TensorElementDataType::Uint32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32"), + TensorElementDataType::Uint64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64"), + } + .map_err(|actual| { + anyhow!("unsupported input datatype `{actual}` for `{}`", info.name) + })?; + + Ok(ParamInfo { + name: info.name.clone().into(), + dt, + ndim: Some(info.dimensions.len()), + }) + }) + .collect::>()?; + + let output_param_infos = sess + .outputs + .iter() + .map(|info| { + let dt = match info.output_type { + TensorElementDataType::Float => Ok(OutputScalarKind::Float32), + TensorElementDataType::Uint8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8"), + TensorElementDataType::Int8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8"), + TensorElementDataType::Uint16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16"), + TensorElementDataType::Int16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16"), + TensorElementDataType::Int32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32"), + TensorElementDataType::Int64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64"), + TensorElementDataType::String => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING"), + TensorElementDataType::Double => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE"), + TensorElementDataType::Uint32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32"), + TensorElementDataType::Uint64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64"), + } + .map_err(|actual| { + anyhow!("unsupported output datatype `{actual}` for `{}`", info.name) + })?; + + Ok(ParamInfo { + name: info.name.clone().into(), + dt, + ndim: Some(info.dimensions.len()), + }) + }) + .collect::>()?; + + return Ok((sess, input_param_infos, output_param_infos)); static ENVIRONMENT: Lazy = Lazy::new(|| { Environment::builder() diff --git a/crates/voicevox_core/src/infer/signatures.rs b/crates/voicevox_core/src/infer/signatures.rs index bce6f62da..c4633658d 100644 --- a/crates/voicevox_core/src/infer/signatures.rs +++ b/crates/voicevox_core/src/infer/signatures.rs @@ -1,16 +1,40 @@ -use enum_map::Enum; -use macros::{InferenceGroup, InferenceInputSignature, TryFromVecOutputTensor}; +use enum_map::{Enum, EnumMap}; +use macros::{InferenceInputSignature, InferenceOutputSignature}; use ndarray::{Array0, Array1, Array2}; -use super::{InferenceSignature, OutputTensor}; +use super::{ + InferenceGroup, InferenceInputSignature as _, InferenceOutputSignature as _, + InferenceSignature, OutputTensor, +}; -#[derive(Clone, Copy, Enum, InferenceGroup)] +#[derive(Clone, Copy, Enum)] pub(crate) enum InferenceKind { PredictDuration, PredictIntonation, Decode, } +// FIXME: ここもマクロ化する +impl InferenceGroup for InferenceKind { + const INPUT_PARAM_INFOS: enum_map::EnumMap< + Self, + &'static [super::ParamInfo], + > = EnumMap::from_array([ + PredictDurationInput::PARAM_INFOS, + PredictIntonationInput::PARAM_INFOS, + DecodeInput::PARAM_INFOS, + ]); + + const OUTPUT_PARAM_INFOS: enum_map::EnumMap< + Self, + &'static [super::ParamInfo], + > = EnumMap::from_array([ + PredictDurationOutput::PARAM_INFOS, + PredictIntonationOutput::PARAM_INFOS, + DecodeOutput::PARAM_INFOS, + ]); +} + pub(crate) enum PredictDuration {} impl InferenceSignature for PredictDuration { @@ -23,11 +47,11 @@ impl InferenceSignature for PredictDuration { #[derive(InferenceInputSignature)] #[input_signature(Signature = PredictDuration)] pub(crate) struct PredictDurationInput { - pub(crate) phoneme: Array1, + pub(crate) phoneme_list: Array1, pub(crate) speaker_id: Array1, } -#[derive(TryFromVecOutputTensor)] +#[derive(InferenceOutputSignature)] pub(crate) struct PredictDurationOutput { pub(crate) phoneme_length: Array1, } @@ -45,16 +69,16 @@ impl InferenceSignature for PredictIntonation { #[input_signature(Signature = PredictIntonation)] pub(crate) struct PredictIntonationInput { pub(crate) length: Array0, - pub(crate) vowel_phoneme: Array1, - pub(crate) consonant_phoneme: Array1, - pub(crate) start_accent: Array1, - pub(crate) end_accent: Array1, - pub(crate) start_accent_phrase: Array1, - pub(crate) end_accent_phrase: Array1, + pub(crate) vowel_phoneme_list: Array1, + pub(crate) consonant_phoneme_list: Array1, + pub(crate) start_accent_list: Array1, + pub(crate) end_accent_list: Array1, + pub(crate) start_accent_phrase_list: Array1, + pub(crate) end_accent_phrase_list: Array1, pub(crate) speaker_id: Array1, } -#[derive(TryFromVecOutputTensor)] +#[derive(InferenceOutputSignature)] pub(crate) struct PredictIntonationOutput { pub(crate) f0_list: Array1, } @@ -76,7 +100,7 @@ pub(crate) struct DecodeInput { pub(crate) speaker_id: Array1, } -#[derive(TryFromVecOutputTensor)] +#[derive(InferenceOutputSignature)] pub(crate) struct DecodeOutput { pub(crate) wave: Array1, } diff --git a/crates/voicevox_core/src/infer/status.rs b/crates/voicevox_core/src/infer/status.rs index 587ce21fa..e1d2a8e3a 100644 --- a/crates/voicevox_core/src/infer/status.rs +++ b/crates/voicevox_core/src/infer/status.rs @@ -1,15 +1,18 @@ use std::{ collections::{BTreeMap, HashMap}, + fmt::Display, marker::PhantomData, sync::Arc, }; +use anyhow::bail; use educe::Educe; use enum_map::EnumMap; -use itertools::iproduct; +use itertools::{iproduct, Itertools as _}; use crate::{ error::{ErrorRepr, LoadModelError, LoadModelErrorKind, LoadModelResult}, + infer::ParamInfo, manifest::ModelInnerId, metas::{SpeakerMeta, StyleId, StyleMeta, VoiceModelMeta}, voice_model::{VoiceModel, VoiceModelId}, @@ -249,14 +252,39 @@ impl SessionSet { let mut sessions = model_bytes .iter() .map(|(k, m)| { - let sess = R::new_session(|| model_file::decrypt(m), options[k])?; + let expected_input_param_infos = G::INPUT_PARAM_INFOS[k]; + let expected_output_param_infos = G::OUTPUT_PARAM_INFOS[k]; + + let (sess, actual_input_param_infos, actual_output_param_infos) = + R::new_session(|| model_file::decrypt(m), options[k])?; + + check_param_infos(expected_input_param_infos, &actual_input_param_infos)?; + check_param_infos(expected_output_param_infos, &actual_output_param_infos)?; + Ok((k.into_usize(), std::sync::Mutex::new(sess).into())) }) .collect::>>()?; - Ok(Self(EnumMap::::from_fn(|k| { + return Ok(Self(EnumMap::::from_fn(|k| { sessions.remove(&k.into_usize()).expect("should exist") - }))) + }))); + + fn check_param_infos( + expected: &[ParamInfo], + actual: &[ParamInfo], + ) -> anyhow::Result<()> { + if !(expected.len() == actual.len() + && itertools::zip_eq(expected, actual) + .all(|(expected, actual)| expected.accepts(actual))) + { + bail!( + "expected {{{}}}, got {{{}}}", + expected.iter().join(", "), + actual.iter().join(", "), + ) + } + Ok(()) + } } } diff --git a/crates/voicevox_core/src/inference_core.rs b/crates/voicevox_core/src/inference_core.rs index 5fdbdf6dd..6de29c201 100644 --- a/crates/voicevox_core/src/inference_core.rs +++ b/crates/voicevox_core/src/inference_core.rs @@ -87,7 +87,7 @@ impl InferenceCore { .run_session( &model_id, PredictDurationInput { - phoneme: ndarray::arr1(phoneme_vector), + phoneme_list: ndarray::arr1(phoneme_vector), speaker_id: ndarray::arr1(&[model_inner_id.raw_id().into()]), }, ) @@ -127,12 +127,12 @@ impl InferenceCore { &model_id, PredictIntonationInput { length: ndarray::arr0(length as i64), - vowel_phoneme: ndarray::arr1(vowel_phoneme_vector), - consonant_phoneme: ndarray::arr1(consonant_phoneme_vector), - start_accent: ndarray::arr1(start_accent_vector), - end_accent: ndarray::arr1(end_accent_vector), - start_accent_phrase: ndarray::arr1(start_accent_phrase_vector), - end_accent_phrase: ndarray::arr1(end_accent_phrase_vector), + vowel_phoneme_list: ndarray::arr1(vowel_phoneme_vector), + consonant_phoneme_list: ndarray::arr1(consonant_phoneme_vector), + start_accent_list: ndarray::arr1(start_accent_vector), + end_accent_list: ndarray::arr1(end_accent_vector), + start_accent_phrase_list: ndarray::arr1(start_accent_phrase_vector), + end_accent_phrase_list: ndarray::arr1(end_accent_phrase_vector), speaker_id: ndarray::arr1(&[model_inner_id.raw_id().into()]), }, ) diff --git a/crates/voicevox_core_macros/src/lib.rs b/crates/voicevox_core_macros/src/lib.rs index dcd235537..a39aa2f20 100644 --- a/crates/voicevox_core_macros/src/lib.rs +++ b/crates/voicevox_core_macros/src/lib.rs @@ -5,13 +5,21 @@ use syn::{ parse::{Parse, ParseStream}, parse_macro_input, spanned::Spanned as _, - Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields, Token, + Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields, Token, Type, }; #[proc_macro_derive(InferenceGroup)] pub fn derive_inference_group(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - let DeriveInput { ident, .. } = parse_macro_input!(input as DeriveInput); - quote!(impl crate::infer::InferenceGroup for #ident {}).into() + let DeriveInput { + ident, generics, .. + } = parse_macro_input!(input as DeriveInput); + + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + quote! { + impl #impl_generics crate::infer::InferenceGroup for #ident #ty_generics #where_clause {} + } + .into() } #[proc_macro_derive(InferenceInputSignature, attributes(input_signature))] @@ -43,7 +51,28 @@ pub fn derive_inference_input_signature(input: proc_macro::TokenStream) -> proc_ .parse_args()?; let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - let field_names = struct_field_names(data)?; + + let fields = struct_fields(data)?; + + let param_infos = fields + .iter() + .map(|(name, ty)| { + let name = name.to_string(); + quote! { + crate::infer::ParamInfo { + name: ::std::borrow::Cow::Borrowed(#name), + dt: < + <#ty as crate::infer::ArrayExt>::Scalar as crate::infer::InputScalar + >::KIND, + ndim: < + <#ty as crate::infer::ArrayExt>::Dimension as ::ndarray::Dimension + >::NDIM, + }, + } + }) + .collect::(); + + let field_names = fields.iter().map(|(name, _)| name); Ok(quote! { impl #impl_generics crate::infer::InferenceInputSignature for #ident #ty_generics @@ -51,6 +80,12 @@ pub fn derive_inference_input_signature(input: proc_macro::TokenStream) -> proc_ { type Signature = #signature; + const PARAM_INFOS: &'static [crate::infer::ParamInfo< + crate::infer::InputScalarKind + >] = &[ + #param_infos + ]; + fn make_run_context( self, sess: &mut R::Session, @@ -80,13 +115,15 @@ pub fn derive_inference_input_signature(input: proc_macro::TokenStream) -> proc_ } } -#[proc_macro_derive(TryFromVecOutputTensor)] -pub fn derive_try_from_vec_any_tensor(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - return derive_try_from_vec_any_tensor(&parse_macro_input!(input)) +#[proc_macro_derive(InferenceOutputSignature)] +pub fn derive_inference_output_signature( + input: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + return derive_inference_output_signature(&parse_macro_input!(input)) .unwrap_or_else(|e| e.to_compile_error()) .into(); - fn derive_try_from_vec_any_tensor( + fn derive_inference_output_signature( input: &DeriveInput, ) -> syn::Result { let DeriveInput { @@ -97,10 +134,41 @@ pub fn derive_try_from_vec_any_tensor(input: proc_macro::TokenStream) -> proc_ma } = input; let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - let field_names = struct_field_names(data)?; - let num_fields = field_names.len(); + + let fields = struct_fields(data)?; + let num_fields = fields.len(); + + let param_infos = fields + .iter() + .map(|(name, ty)| { + let name = name.to_string(); + quote! { + crate::infer::ParamInfo { + name: ::std::borrow::Cow::Borrowed(#name), + dt: < + <#ty as crate::infer::ArrayExt>::Scalar as crate::infer::OutputScalar + >::KIND, + ndim: < + <#ty as crate::infer::ArrayExt>::Dimension as ::ndarray::Dimension + >::NDIM, + }, + } + }) + .collect::(); + + let field_names = fields.iter().map(|(name, _)| name); Ok(quote! { + impl #impl_generics crate::infer::InferenceOutputSignature for #ident #ty_generics + #where_clause + { + const PARAM_INFOS: &'static [crate::infer::ParamInfo< + crate::infer::OutputScalarKind + >] = &[ + #param_infos + ]; + } + impl #impl_generics ::std::convert::TryFrom<::std::vec::Vec> for #ident #ty_generics @@ -133,7 +201,7 @@ pub fn derive_try_from_vec_any_tensor(input: proc_macro::TokenStream) -> proc_ma } } -fn struct_field_names(data: &Data) -> syn::Result> { +fn struct_fields(data: &Data) -> syn::Result> { let fields = match data { Data::Struct(DataStruct { fields: Fields::Named(fields), @@ -153,6 +221,6 @@ fn struct_field_names(data: &Data) -> syn::Result> { Ok(fields .named .iter() - .map(|Field { ident, .. }| ident.as_ref().expect("should be named")) + .map(|Field { ident, ty, .. }| (ident.as_ref().expect("should be named"), ty)) .collect()) }