diff --git a/crates/voicevox_core/src/error.rs b/crates/voicevox_core/src/error.rs index c0445f2e0..59624e544 100644 --- a/crates/voicevox_core/src/error.rs +++ b/crates/voicevox_core/src/error.rs @@ -5,7 +5,8 @@ use crate::{ }; //use engine:: use duplicate::duplicate_item; -use std::path::PathBuf; +use itertools::Itertools as _; +use std::{collections::BTreeSet, path::PathBuf}; use thiserror::Error; use uuid::Uuid; @@ -71,10 +72,14 @@ pub(crate) enum ErrorRepr { GetSupportedDevices(#[source] anyhow::Error), #[error( - "`{style_id}`に対するスタイルが見つかりませんでした。音声モデルが読み込まれていないか、読\ - み込みが解除されています" + "`{style_id}` ([{style_types}])に対するスタイルが見つかりませんでした。音声モデルが読み込まれていないか、読\ + み込みが解除されています", + style_types = style_types.iter().format(", ") )] - StyleNotFound { style_id: StyleId }, + StyleNotFound { + style_id: StyleId, + style_types: &'static BTreeSet, + }, #[error( "`{model_id}`に対する音声モデルが見つかりませんでした。読み込まれていないか、読み込みが既\ diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index 3d66e774a..2ddbb8009 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -3,7 +3,9 @@ mod model_file; pub(crate) mod runtimes; pub(crate) mod status; -use std::{borrow::Cow, convert::Infallible, fmt::Debug, marker::PhantomData}; +use std::{ + borrow::Cow, collections::BTreeSet, convert::Infallible, fmt::Debug, marker::PhantomData, +}; use derive_new::new; use duplicate::duplicate_item; @@ -39,9 +41,9 @@ pub(crate) trait InferenceDomainGroup { pub(crate) trait InferenceDomainMap { type Group: InferenceDomainGroup; - fn contains_for(&self, style_type: StyleType) -> bool + fn any

(&self, p: P) -> bool where - A: InferenceDomainOptionAssociation; + P: InferenceDomainAssociationTargetPredicate; fn try_ref_map< F: ConvertInferenceDomainAssociationTarget, @@ -53,6 +55,14 @@ pub(crate) trait InferenceDomainMap { ) -> Result<::Map, E>; } +pub(crate) trait InferenceDomainAssociationTargetPredicate { + type Association: InferenceDomainAssociation; + fn test( + &self, + x: &::Target, + ) -> bool; +} + pub(crate) trait ConvertInferenceDomainAssociationTarget< G: InferenceDomainGroup + ?Sized, A1: InferenceDomainAssociation, @@ -70,27 +80,19 @@ pub(crate) trait InferenceDomainAssociation { type Target; } -pub(crate) trait InferenceDomainOptionAssociation: InferenceDomainAssociation { - fn is_some(x: &Self::Target) -> bool; -} - pub(crate) struct Optional(Infallible, PhantomData A>); impl InferenceDomainAssociation for Optional { type Target = Option>; } -impl InferenceDomainOptionAssociation for Optional { - fn is_some(x: &Self::Target) -> bool { - x.is_some() - } -} - /// ある`VoiceModel`が提供する推論操作の集合を示す。 pub(crate) trait InferenceDomain: Sized { type Group: InferenceDomainGroup; type Operation: InferenceOperation; + fn style_types() -> &'static BTreeSet; + fn visit( map: &::Map, ) -> &A::Target; diff --git a/crates/voicevox_core/src/infer/domains.rs b/crates/voicevox_core/src/infer/domains.rs index b4d3771ba..a4743c2f7 100644 --- a/crates/voicevox_core/src/infer/domains.rs +++ b/crates/voicevox_core/src/infer/domains.rs @@ -1,15 +1,13 @@ mod talk; -use crate::StyleType; - pub(crate) use self::talk::{ DecodeInput, DecodeOutput, PredictDurationInput, PredictDurationOutput, PredictIntonationInput, PredictIntonationOutput, TalkDomain, TalkOperation, }; use super::{ - ConvertInferenceDomainAssociationTarget, InferenceDomainAssociation, InferenceDomainGroup, - InferenceDomainMap, InferenceDomainOptionAssociation, + ConvertInferenceDomainAssociationTarget, InferenceDomainAssociation, + InferenceDomainAssociationTargetPredicate, InferenceDomainGroup, InferenceDomainMap, }; pub(crate) enum InferenceDomainGroupImpl {} @@ -25,13 +23,11 @@ pub(crate) struct InferenceDomainMapImpl { impl InferenceDomainMap for InferenceDomainMapImpl { type Group = InferenceDomainGroupImpl; - fn contains_for(&self, style_type: StyleType) -> bool + fn any

(&self, p: P) -> bool where - A: InferenceDomainOptionAssociation, + P: InferenceDomainAssociationTargetPredicate, { - match style_type { - StyleType::Talk => A::is_some(&self.talk), - } + p.test(&self.talk) } fn try_ref_map< diff --git a/crates/voicevox_core/src/infer/domains/talk.rs b/crates/voicevox_core/src/infer/domains/talk.rs index 4aff5f999..90ea56b34 100644 --- a/crates/voicevox_core/src/infer/domains/talk.rs +++ b/crates/voicevox_core/src/infer/domains/talk.rs @@ -1,6 +1,11 @@ +use std::collections::BTreeSet; + use enum_map::Enum; use macros::{InferenceInputSignature, InferenceOperation, InferenceOutputSignature}; use ndarray::{Array0, Array1, Array2}; +use once_cell::sync::Lazy; + +use crate::StyleType; use super::{ super::{ @@ -16,6 +21,12 @@ impl InferenceDomain for TalkDomain { type Group = InferenceDomainGroupImpl; type Operation = TalkOperation; + fn style_types() -> &'static BTreeSet { + static STYLE_TYPES: Lazy> = + Lazy::new(|| BTreeSet::from([StyleType::Talk])); + &STYLE_TYPES + } + fn visit( map: &::Map, ) -> &A::Target { diff --git a/crates/voicevox_core/src/infer/status.rs b/crates/voicevox_core/src/infer/status.rs index bea598aa9..6d270f139 100644 --- a/crates/voicevox_core/src/infer/status.rs +++ b/crates/voicevox_core/src/infer/status.rs @@ -22,12 +22,13 @@ use crate::{ manifest::ModelInnerId, metas::{self, SpeakerMeta, StyleId, StyleMeta, VoiceModelMeta}, voice_model::{ModelData, ModelDataByInferenceDomain, VoiceModelHeader, VoiceModelId}, - Result, + Result, StyleType, }; use super::{ - model_file, InferenceDomain, InferenceDomainMap as _, InferenceInputSignature, - InferenceRuntime, InferenceSessionOptions, InferenceSignature, Optional, + model_file, InferenceDomain, InferenceDomainAssociationTargetPredicate, + InferenceDomainMap as _, InferenceInputSignature, InferenceRuntime, InferenceSessionOptions, + InferenceSignature, Optional, }; pub(crate) struct Status { @@ -132,14 +133,12 @@ impl Status { .contains_voice_model(voice_model_id) } + // FIXME: この関数はcompatible_engineとテストでのみ使われるが、テストのために`StyleType`を + // 引数に含めるようにする pub(crate) fn is_loaded_model_by_style_id(&self, style_id: StyleId) -> bool { self.loaded_models.lock().unwrap().contains_style(style_id) } - pub(crate) fn validate_speaker_id(&self, style_id: StyleId) -> bool { - self.is_loaded_model_by_style_id(style_id) - } - /// 推論を実行する。 /// /// # Performance @@ -191,12 +190,14 @@ impl LoadedModels { .0 .iter() .find(|(_, LoadedModel { metas, .. })| { - metas - .iter() - .flat_map(SpeakerMeta::styles) - .any(|style| *style.id() == style_id) + metas.iter().flat_map(SpeakerMeta::styles).any(|style| { + *style.id() == style_id && D::style_types().contains(style.r#type()) + }) }) - .ok_or(ErrorRepr::StyleNotFound { style_id })?; + .ok_or(ErrorRepr::StyleNotFound { + style_id, + style_types: D::style_types(), + })?; let model_inner_id = D::visit(by_domain) .as_ref() @@ -280,11 +281,17 @@ impl LoadedModels { .metas .iter() .flat_map(|speaker| speaker.styles()); - if let Some(style_type) = external - .clone() - .map(StyleMeta::r#type) - .copied() - .find(|&t| !model_bytes_or_sessions.contains_for(t)) + if let Some(style_type) = + external + .clone() + .map(StyleMeta::r#type) + .copied() + .find(|&style_type| { + !model_bytes_or_sessions.any(ContainsForStyleType { + style_type, + marker: PhantomData, + }) + }) { return Err(error(LoadModelErrorKind::MissingModelData { style_type })); } @@ -295,7 +302,25 @@ impl LoadedModels { id: *style.id(), })); } - Ok(()) + return Ok(()); + + struct ContainsForStyleType { + style_type: StyleType, + marker: PhantomData A>, + } + + impl InferenceDomainAssociationTargetPredicate + for ContainsForStyleType + { + type Association = Optional; + + fn test( + &self, + x: &::Target, + ) -> bool { + D::style_types().contains(&self.style_type) && x.is_some() + } + } } fn insert( diff --git a/crates/voicevox_core/src/metas.rs b/crates/voicevox_core/src/metas.rs index df8f65800..ff0161720 100644 --- a/crates/voicevox_core/src/metas.rs +++ b/crates/voicevox_core/src/metas.rs @@ -164,7 +164,20 @@ pub struct StyleMeta { } /// **スタイル**(_style_)に対応するモデルの種類。 -#[derive(Default, Clone, Copy, Debug, strum::Display, Deserialize, Serialize)] +#[derive( + Default, + Clone, + Copy, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + Debug, + strum::Display, + Deserialize, + Serialize, +)] #[strum(serialize_all = "snake_case")] #[serde(rename_all = "snake_case")] pub enum StyleType { diff --git a/crates/voicevox_core/src/synthesizer.rs b/crates/voicevox_core/src/synthesizer.rs index b0eb3953f..11f8fe380 100644 --- a/crates/voicevox_core/src/synthesizer.rs +++ b/crates/voicevox_core/src/synthesizer.rs @@ -833,11 +833,6 @@ pub(crate) mod blocking { impl PerformInference for self::Synthesizer { fn predict_duration(&self, phoneme_vector: &[i64], style_id: StyleId) -> Result> { - // FIXME: `Status::ids_for`があるため、ここは不要なはず - if !self.status.validate_speaker_id(style_id) { - return Err(ErrorRepr::StyleNotFound { style_id }.into()); - } - let (model_id, model_inner_id) = self.status.ids_for::(style_id)?; let PredictDurationOutput { @@ -873,11 +868,6 @@ pub(crate) mod blocking { end_accent_phrase_vector: &[i64], style_id: StyleId, ) -> Result> { - // FIXME: `Status::ids_for`があるため、ここは不要なはず - if !self.status.validate_speaker_id(style_id) { - return Err(ErrorRepr::StyleNotFound { style_id }.into()); - } - let (model_id, model_inner_id) = self.status.ids_for::(style_id)?; let PredictIntonationOutput { f0_list: output } = self.status.run_session( @@ -905,11 +895,6 @@ pub(crate) mod blocking { phoneme_vector: &[f32], style_id: StyleId, ) -> Result> { - // FIXME: `Status::ids_for`があるため、ここは不要なはず - if !self.status.validate_speaker_id(style_id) { - return Err(ErrorRepr::StyleNotFound { style_id }.into()); - } - let (model_id, model_inner_id) = self.status.ids_for::(style_id)?; // 音が途切れてしまうのを避けるworkaround処理が入っている