diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs
index 260235543..2a0ac3318 100644
--- a/crates/voicevox_core/src/infer.rs
+++ b/crates/voicevox_core/src/infer.rs
@@ -31,15 +31,13 @@ pub(crate) trait InferenceRuntime: 'static {
     fn run(ctx: Self::RunContext<'_>) -> anyhow::Result<Vec<OutputTensor>>;
 }
 
-pub(crate) trait InferenceGroup {
-    type Kind: Copy + Enum;
-}
+pub(crate) trait InferenceGroup: Copy + Enum {}
 
 pub(crate) trait InferenceSignature: Sized + Send + 'static {
     type Group: InferenceGroup;
     type Input: InferenceInputSignature;
     type Output: TryFrom<Vec<OutputTensor>, Error = anyhow::Error> + Send;
-    const KIND: <Self::Group as InferenceGroup>::Kind;
+    const KIND: Self::Group;
 }
 
 pub(crate) trait InferenceInputSignature: Send + 'static {
diff --git a/crates/voicevox_core/src/infer/signatures.rs b/crates/voicevox_core/src/infer/signatures.rs
index b7efb6244..bce6f62da 100644
--- a/crates/voicevox_core/src/infer/signatures.rs
+++ b/crates/voicevox_core/src/infer/signatures.rs
@@ -1,17 +1,11 @@
 use enum_map::Enum;
-use macros::{InferenceInputSignature, TryFromVecOutputTensor};
+use macros::{InferenceGroup, InferenceInputSignature, TryFromVecOutputTensor};
 use ndarray::{Array0, Array1, Array2};
 
-use super::{InferenceGroup, InferenceSignature, OutputTensor};
+use super::{InferenceSignature, OutputTensor};
 
-pub(crate) enum InferenceGroupImpl {}
-
-impl InferenceGroup for InferenceGroupImpl {
-    type Kind = InferencelKindImpl;
-}
-
-#[derive(Clone, Copy, Enum)]
-pub(crate) enum InferencelKindImpl {
+#[derive(Clone, Copy, Enum, InferenceGroup)]
+pub(crate) enum InferenceKind {
     PredictDuration,
     PredictIntonation,
     Decode,
@@ -20,10 +14,10 @@ pub(crate) enum InferencelKindImpl {
 pub(crate) enum PredictDuration {}
 
 impl InferenceSignature for PredictDuration {
-    type Group = InferenceGroupImpl;
+    type Group = InferenceKind;
     type Input = PredictDurationInput;
     type Output = PredictDurationOutput;
-    const KIND: InferencelKindImpl = InferencelKindImpl::PredictDuration;
+    const KIND: InferenceKind = InferenceKind::PredictDuration;
 }
 
 #[derive(InferenceInputSignature)]
@@ -41,10 +35,10 @@ pub(crate) struct PredictDurationOutput {
 pub(crate) enum PredictIntonation {}
 
 impl InferenceSignature for PredictIntonation {
-    type Group = InferenceGroupImpl;
+    type Group = InferenceKind;
     type Input = PredictIntonationInput;
     type Output = PredictIntonationOutput;
-    const KIND: InferencelKindImpl = InferencelKindImpl::PredictIntonation;
+    const KIND: InferenceKind = InferenceKind::PredictIntonation;
 }
 
 #[derive(InferenceInputSignature)]
@@ -68,10 +62,10 @@ pub(crate) struct PredictIntonationOutput {
 pub(crate) enum Decode {}
 
 impl InferenceSignature for Decode {
-    type Group = InferenceGroupImpl;
+    type Group = InferenceKind;
     type Input = DecodeInput;
     type Output = DecodeOutput;
-    const KIND: InferencelKindImpl = InferencelKindImpl::Decode;
+    const KIND: InferenceKind = InferenceKind::Decode;
 }
 
 #[derive(InferenceInputSignature)]
diff --git a/crates/voicevox_core/src/infer/status.rs b/crates/voicevox_core/src/infer/status.rs
index 4938f89a6..587ce21fa 100644
--- a/crates/voicevox_core/src/infer/status.rs
+++ b/crates/voicevox_core/src/infer/status.rs
@@ -5,7 +5,7 @@ use std::{
 };
 
 use educe::Educe;
-use enum_map::{Enum as _, EnumMap};
+use enum_map::EnumMap;
 use itertools::iproduct;
 
 use crate::{
@@ -23,11 +23,11 @@ use super::{
 pub(crate) struct Status<R: InferenceRuntime, G: InferenceGroup> {
     loaded_models: std::sync::Mutex<LoadedModels<R, G>>,
-    session_options: EnumMap<G::Kind, InferenceSessionOptions>,
+    session_options: EnumMap<G, InferenceSessionOptions>,
 }
 
 impl<R: InferenceRuntime, G: InferenceGroup> Status<R, G> {
-    pub fn new(session_options: EnumMap<G::Kind, InferenceSessionOptions>) -> Self {
+    pub fn new(session_options: EnumMap<G, InferenceSessionOptions>) -> Self {
         Self {
             loaded_models: Default::default(),
             session_options,
         }
@@ -37,7 +37,7 @@ impl<R: InferenceRuntime, G: InferenceGroup> Status<R, G> {
     pub async fn load_model(
         &self,
         model: &VoiceModel,
-        model_bytes: &EnumMap<G::Kind, Vec<u8>>,
+        model_bytes: &EnumMap<G, Vec<u8>>,
     ) -> Result<()> {
         self.loaded_models
             .lock()
@@ -238,13 +238,13 @@ impl<R: InferenceRuntime, G: InferenceGroup> LoadedModels<R, G> {
 }
 
 struct SessionSet<R: InferenceRuntime, G: InferenceGroup>(
-    EnumMap<G::Kind, Arc<std::sync::Mutex<R::Session>>>,
+    EnumMap<G, Arc<std::sync::Mutex<R::Session>>>,
 );
 
 impl<R: InferenceRuntime, G: InferenceGroup> SessionSet<R, G> {
     fn new(
-        model_bytes: &EnumMap<G::Kind, Vec<u8>>,
-        options: &EnumMap<G::Kind, InferenceSessionOptions>,
+        model_bytes: &EnumMap<G, Vec<u8>>,
+        options: &EnumMap<G, InferenceSessionOptions>,
     ) -> anyhow::Result<Self> {
         let mut sessions = model_bytes
             .iter()
@@ -254,7 +254,7 @@ impl<R: InferenceRuntime, G: InferenceGroup> SessionSet<R, G> {
             })
             .collect::<anyhow::Result<HashMap<_, _>>>()?;
 
-        Ok(Self(EnumMap::<G::Kind, _>::from_fn(|k| {
+        Ok(Self(EnumMap::<G, _>::from_fn(|k| {
             sessions.remove(&k.into_usize()).expect("should exist")
         })))
     }
@@ -295,10 +295,8 @@ mod tests {
     use rstest::rstest;
 
     use crate::{
-        infer::signatures::{InferenceGroupImpl, InferencelKindImpl},
-        macros::tests::assert_debug_fmt_eq,
-        synthesizer::InferenceRuntimeImpl,
-        test_util::open_default_vvm_file,
+        infer::signatures::InferenceKind, macros::tests::assert_debug_fmt_eq,
+        synthesizer::InferenceRuntimeImpl, test_util::open_default_vvm_file,
     };
 
     use super::{super::InferenceSessionOptions, Status};
@@ -315,23 +313,23 @@ mod tests {
         let light_session_options = InferenceSessionOptions::new(cpu_num_threads, false);
         let heavy_session_options = InferenceSessionOptions::new(cpu_num_threads, use_gpu);
         let session_options = enum_map! {
-            InferencelKindImpl::PredictDuration
-            | InferencelKindImpl::PredictIntonation => light_session_options,
-            InferencelKindImpl::Decode => heavy_session_options,
+            InferenceKind::PredictDuration
+            | InferenceKind::PredictIntonation => light_session_options,
+            InferenceKind::Decode => heavy_session_options,
         };
-        let status = Status::<InferenceRuntimeImpl, InferenceGroupImpl>::new(session_options);
+        let status = Status::<InferenceRuntimeImpl, InferenceKind>::new(session_options);
         assert_eq!(
             light_session_options,
-            status.session_options[InferencelKindImpl::PredictDuration],
+            status.session_options[InferenceKind::PredictDuration],
         );
         assert_eq!(
             light_session_options,
-            status.session_options[InferencelKindImpl::PredictIntonation],
+            status.session_options[InferenceKind::PredictIntonation],
         );
         assert_eq!(
             heavy_session_options,
-            status.session_options[InferencelKindImpl::Decode],
+            status.session_options[InferenceKind::Decode],
         );
 
         assert!(status.loaded_models.lock().unwrap().0.is_empty());
     }
@@ -340,7 +338,7 @@ mod tests {
     #[rstest]
     #[tokio::test]
     async fn status_load_model_works() {
-        let status = Status::<InferenceRuntimeImpl, InferenceGroupImpl>::new(
+        let status = Status::<InferenceRuntimeImpl, InferenceKind>::new(
             enum_map!(_ => InferenceSessionOptions::new(0, false)),
         );
         let model = &open_default_vvm_file().await;
@@ -353,7 +351,7 @@ mod tests {
     #[rstest]
     #[tokio::test]
     async fn status_is_model_loaded_works() {
-        let status = Status::<InferenceRuntimeImpl, InferenceGroupImpl>::new(
+        let status = Status::<InferenceRuntimeImpl, InferenceKind>::new(
             enum_map!(_ => InferenceSessionOptions::new(0, false)),
         );
         let vvm = open_default_vvm_file().await;
diff --git a/crates/voicevox_core/src/inference_core.rs b/crates/voicevox_core/src/inference_core.rs
index 5987a0df4..5fdbdf6dd 100644
--- a/crates/voicevox_core/src/inference_core.rs
+++ b/crates/voicevox_core/src/inference_core.rs
@@ -2,8 +2,8 @@ use enum_map::enum_map;
 
 use crate::infer::{
     signatures::{
-        DecodeInput, DecodeOutput, InferenceGroupImpl, InferencelKindImpl, PredictDurationInput,
-        PredictDurationOutput, PredictIntonationInput, PredictIntonationOutput,
+        DecodeInput, DecodeOutput, InferenceKind, PredictDurationInput, PredictDurationOutput,
+        PredictIntonationInput, PredictIntonationOutput,
     },
     status::Status,
     InferenceRuntime, InferenceSessionOptions,
@@ -14,7 +14,7 @@ use super::*;
 const PHONEME_LENGTH_MINIMAL: f32 = 0.01;
 
 pub(crate) struct InferenceCore<R: InferenceRuntime> {
-    status: Status<R, InferenceGroupImpl>,
+    status: Status<R, InferenceKind>,
 }
 
 impl<R: InferenceRuntime> InferenceCore<R> {
@@ -27,9 +27,9 @@ impl<R: InferenceRuntime> InferenceCore<R> {
             let heavy_session_options = InferenceSessionOptions::new(cpu_num_threads, use_gpu);
 
             let status = Status::new(enum_map! {
-                InferencelKindImpl::PredictDuration
-                | InferencelKindImpl::PredictIntonation => light_session_options,
-                InferencelKindImpl::Decode => heavy_session_options,
+                InferenceKind::PredictDuration
+                | InferenceKind::PredictIntonation => light_session_options,
+                InferenceKind::Decode => heavy_session_options,
             });
             Ok(Self { status })
         } else {
diff --git a/crates/voicevox_core/src/voice_model.rs b/crates/voicevox_core/src/voice_model.rs
index 136bc3742..5b75bcacf 100644
--- a/crates/voicevox_core/src/voice_model.rs
+++ b/crates/voicevox_core/src/voice_model.rs
@@ -4,7 +4,7 @@ use futures::future::join3;
 use serde::{de::DeserializeOwned, Deserialize};
 
 use super::*;
-use crate::infer::signatures::InferencelKindImpl;
+use crate::infer::signatures::InferenceKind;
 use std::{
     collections::{BTreeMap, HashMap},
     io,
@@ -40,7 +40,7 @@ pub struct VoiceModel {
 impl VoiceModel {
     pub(crate) async fn read_inference_models(
         &self,
-    ) -> LoadModelResult<EnumMap<InferencelKindImpl, Vec<u8>>> {
+    ) -> LoadModelResult<EnumMap<InferenceKind, Vec<u8>>> {
         let reader = VvmEntryReader::open(&self.path).await?;
         let (decode_model_result, predict_duration_model_result, predict_intonation_model_result) =
             join3(
diff --git a/crates/voicevox_core_macros/src/lib.rs b/crates/voicevox_core_macros/src/lib.rs
index 192ca83ba..dcd235537 100644
--- a/crates/voicevox_core_macros/src/lib.rs
+++ b/crates/voicevox_core_macros/src/lib.rs
@@ -8,6 +8,12 @@ use syn::{
     Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields, Token,
 };
 
+#[proc_macro_derive(InferenceGroup)]
+pub fn derive_inference_group(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+    let DeriveInput { ident, .. } = parse_macro_input!(input as DeriveInput);
+    quote!(impl crate::infer::InferenceGroup for #ident {}).into()
+}
+
 #[proc_macro_derive(InferenceInputSignature, attributes(input_signature))]
 pub fn derive_inference_input_signature(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
     return derive_inference_input_signature(&parse_macro_input!(input))
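
Note on the pattern above (not part of the patch): after this change an inference "group" is just a field-less `Copy + Enum` enum, so `InferenceKind` can key `EnumMap`s directly (`EnumMap<G, …>` in `Status`) and `Self::Group` becomes the type of `KIND`. Below is a minimal, self-contained Rust sketch of that pattern using the `enum_map` crate. The `InferenceSessionOptions` struct here is a simplified stand-in for the real voicevox_core type, and the manual `impl InferenceGroup for InferenceKind {}` spells out roughly what the new `#[derive(InferenceGroup)]` proc macro generates (modulo the `crate::infer::` path).

// Sketch only: the names mirror the diff, but everything here is a
// simplified stand-in rather than the actual voicevox_core code.
use enum_map::{enum_map, Enum, EnumMap};

// Marker trait, as introduced in infer.rs: a group is just a field-less enum.
trait InferenceGroup: Copy + Enum {}

#[derive(Clone, Copy, Enum)]
enum InferenceKind {
    PredictDuration,
    PredictIntonation,
    Decode,
}

// Roughly what `#[derive(InferenceGroup)]` expands to for this enum.
impl InferenceGroup for InferenceKind {}

// Simplified stand-in for the real session options type.
#[derive(Clone, Copy, Debug, PartialEq)]
struct InferenceSessionOptions {
    cpu_num_threads: u16,
    use_gpu: bool,
}

// Generic code can lean on the `Enum` supertrait (e.g. `into_usize`),
// the same way `SessionSet::new` calls `k.into_usize()` in the diff.
fn kind_index<G: InferenceGroup>(kind: G) -> usize {
    kind.into_usize()
}

fn main() {
    let light = InferenceSessionOptions { cpu_num_threads: 0, use_gpu: false };
    let heavy = InferenceSessionOptions { cpu_num_threads: 0, use_gpu: true };

    // With the associated-type indirection gone, the enum itself is the map
    // key, matching `session_options: EnumMap<G, InferenceSessionOptions>`.
    let session_options: EnumMap<InferenceKind, InferenceSessionOptions> = enum_map! {
        InferenceKind::PredictDuration
        | InferenceKind::PredictIntonation => light,
        InferenceKind::Decode => heavy,
    };

    assert_eq!(light, session_options[InferenceKind::PredictIntonation]);
    assert_eq!(heavy, session_options[InferenceKind::Decode]);
    assert!(kind_index(InferenceKind::Decode) < InferenceKind::LENGTH);
    println!("decode options: {:?}", session_options[InferenceKind::Decode]);
}

This is the same shape the updated tests rely on when they index `status.session_options[InferenceKind::Decode]` without going through the old `InferencelKindImpl` indirection.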