Skip to content

Commit

Permalink
ids_fortypeをチェックする
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Mar 10, 2024
1 parent fd2905b commit ad47904
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 60 deletions.
13 changes: 9 additions & 4 deletions crates/voicevox_core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use crate::{
};
//use engine::
use duplicate::duplicate_item;
use std::path::PathBuf;
use itertools::Itertools as _;
use std::{collections::BTreeSet, path::PathBuf};
use thiserror::Error;
use uuid::Uuid;

Expand Down Expand Up @@ -71,10 +72,14 @@ pub(crate) enum ErrorRepr {
GetSupportedDevices(#[source] anyhow::Error),

#[error(
"`{style_id}`に対するスタイルが見つかりませんでした。音声モデルが読み込まれていないか、読\
み込みが解除されています"
"`{style_id}` ([{style_types}])に対するスタイルが見つかりませんでした。音声モデルが読み込まれていないか、読\
み込みが解除されています",
style_types = style_types.iter().format(", ")
)]
StyleNotFound { style_id: StyleId },
StyleNotFound {
style_id: StyleId,
style_types: &'static BTreeSet<StyleType>,
},

#[error(
"`{model_id}`に対する音声モデルが見つかりませんでした。読み込まれていないか、読み込みが既\
Expand Down
28 changes: 15 additions & 13 deletions crates/voicevox_core/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ mod model_file;
pub(crate) mod runtimes;
pub(crate) mod status;

use std::{borrow::Cow, convert::Infallible, fmt::Debug, marker::PhantomData};
use std::{
borrow::Cow, collections::BTreeSet, convert::Infallible, fmt::Debug, marker::PhantomData,
};

use derive_new::new;
use duplicate::duplicate_item;
Expand Down Expand Up @@ -39,9 +41,9 @@ pub(crate) trait InferenceDomainGroup {
pub(crate) trait InferenceDomainMap<A: InferenceDomainAssociation> {
type Group: InferenceDomainGroup;

fn contains_for(&self, style_type: StyleType) -> bool
fn any<P>(&self, p: P) -> bool
where
A: InferenceDomainOptionAssociation;
P: InferenceDomainAssociationTargetPredicate<Association = A>;

fn try_ref_map<
F: ConvertInferenceDomainAssociationTarget<Self::Group, A, A2, E>,
Expand All @@ -53,6 +55,14 @@ pub(crate) trait InferenceDomainMap<A: InferenceDomainAssociation> {
) -> Result<<Self::Group as InferenceDomainGroup>::Map<A2>, E>;
}

/// A boolean predicate over the per-domain targets of an `InferenceDomainAssociation`.
///
/// Values implementing this trait are what `InferenceDomainMap::any` accepts as its argument.
pub(crate) trait InferenceDomainAssociationTargetPredicate {
/// The association whose `Target<D>` values this predicate inspects.
type Association: InferenceDomainAssociation;
/// Tests `x`, the target associated with inference domain `D`.
///
/// The method is generic over `D`, so one predicate value can be applied to
/// targets of different domains.
fn test<D: InferenceDomain>(
&self,
x: &<Self::Association as InferenceDomainAssociation>::Target<D>,
) -> bool;
}

pub(crate) trait ConvertInferenceDomainAssociationTarget<
G: InferenceDomainGroup + ?Sized,
A1: InferenceDomainAssociation,
Expand All @@ -70,27 +80,19 @@ pub(crate) trait InferenceDomainAssociation {
type Target<D: InferenceDomain>;
}

/// An `InferenceDomainAssociation` whose targets can be asked whether they hold a value.
pub(crate) trait InferenceDomainOptionAssociation: InferenceDomainAssociation {
/// Returns whether `x`, the target for inference domain `D`, holds a value.
fn is_some<D: InferenceDomain>(x: &Self::Target<D>) -> bool;
}

/// Type-level marker parameterised by an association `A`.
///
/// The `Infallible` field makes this type impossible to construct, and
/// `PhantomData<fn() -> A>` records `A` without storing a value of it.
pub(crate) struct Optional<A>(Infallible, PhantomData<fn() -> A>);

/// `Optional<A>` is itself an association: it maps each domain to an optional `A` target.
impl<A: InferenceDomainAssociation> InferenceDomainAssociation for Optional<A> {
/// The target `A` assigns to domain `D`, wrapped in `Option`.
type Target<D: InferenceDomain> = Option<A::Target<D>>;
}

/// Targets of `Optional<A>` are `Option`s, so "holds a value" means "is `Some`".
impl<A: InferenceDomainAssociation> InferenceDomainOptionAssociation for Optional<A> {
    /// Returns `true` exactly when the target `x` is `Some`.
    fn is_some<D: InferenceDomain>(x: &Self::Target<D>) -> bool {
        matches!(x, Some(_))
    }
}

/// ある`VoiceModel`が提供する推論操作の集合を示す。
pub(crate) trait InferenceDomain: Sized {
type Group: InferenceDomainGroup;
type Operation: InferenceOperation;

fn style_types() -> &'static BTreeSet<StyleType>;

fn visit<A: InferenceDomainAssociation>(
map: &<Self::Group as InferenceDomainGroup>::Map<A>,
) -> &A::Target<Self>;
Expand Down
14 changes: 5 additions & 9 deletions crates/voicevox_core/src/infer/domains.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
mod talk;

use crate::StyleType;

pub(crate) use self::talk::{
DecodeInput, DecodeOutput, PredictDurationInput, PredictDurationOutput, PredictIntonationInput,
PredictIntonationOutput, TalkDomain, TalkOperation,
};

use super::{
ConvertInferenceDomainAssociationTarget, InferenceDomainAssociation, InferenceDomainGroup,
InferenceDomainMap, InferenceDomainOptionAssociation,
ConvertInferenceDomainAssociationTarget, InferenceDomainAssociation,
InferenceDomainAssociationTargetPredicate, InferenceDomainGroup, InferenceDomainMap,
};

pub(crate) enum InferenceDomainGroupImpl {}
Expand All @@ -25,13 +23,11 @@ pub(crate) struct InferenceDomainMapImpl<A: InferenceDomainAssociation> {
impl<A: InferenceDomainAssociation> InferenceDomainMap<A> for InferenceDomainMapImpl<A> {
type Group = InferenceDomainGroupImpl;

fn contains_for(&self, style_type: StyleType) -> bool
fn any<P>(&self, p: P) -> bool
where
A: InferenceDomainOptionAssociation,
P: InferenceDomainAssociationTargetPredicate<Association = A>,
{
match style_type {
StyleType::Talk => A::is_some(&self.talk),
}
p.test(&self.talk)
}

fn try_ref_map<
Expand Down
11 changes: 11 additions & 0 deletions crates/voicevox_core/src/infer/domains/talk.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
use std::collections::BTreeSet;

use enum_map::Enum;
use macros::{InferenceInputSignature, InferenceOperation, InferenceOutputSignature};
use ndarray::{Array0, Array1, Array2};
use once_cell::sync::Lazy;

use crate::StyleType;

use super::{
super::{
Expand All @@ -16,6 +21,12 @@ impl InferenceDomain for TalkDomain {
type Group = InferenceDomainGroupImpl;
type Operation = TalkOperation;

/// Returns the style types this domain covers: the set containing only `StyleType::Talk`.
///
/// The set is built on first use and kept in a function-local static afterwards.
fn style_types() -> &'static BTreeSet<StyleType> {
    static TALK_STYLE_TYPES: Lazy<BTreeSet<StyleType>> =
        Lazy::new(|| std::iter::once(StyleType::Talk).collect());
    &*TALK_STYLE_TYPES
}

fn visit<A: InferenceDomainAssociation>(
map: &<Self::Group as InferenceDomainGroup>::Map<A>,
) -> &A::Target<Self> {
Expand Down
61 changes: 43 additions & 18 deletions crates/voicevox_core/src/infer/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ use crate::{
manifest::ModelInnerId,
metas::{self, SpeakerMeta, StyleId, StyleMeta, VoiceModelMeta},
voice_model::{ModelData, ModelDataByInferenceDomain, VoiceModelHeader, VoiceModelId},
Result,
Result, StyleType,
};

use super::{
model_file, InferenceDomain, InferenceDomainMap as _, InferenceInputSignature,
InferenceRuntime, InferenceSessionOptions, InferenceSignature, Optional,
model_file, InferenceDomain, InferenceDomainAssociationTargetPredicate,
InferenceDomainMap as _, InferenceInputSignature, InferenceRuntime, InferenceSessionOptions,
InferenceSignature, Optional,
};

pub(crate) struct Status<R: InferenceRuntime, G: InferenceDomainGroup> {
Expand Down Expand Up @@ -132,14 +133,12 @@ impl<R: InferenceRuntime, G: InferenceDomainGroup> Status<R, G> {
.contains_voice_model(voice_model_id)
}

// FIXME: this function is used only by compatible_engine and by tests; for the
// tests' sake, make it take a `StyleType` argument as well
/// Reports whether `loaded_models` contains the style `style_id`.
///
/// # Panics
///
/// Panics if `lock()` on `loaded_models` returns an error (the result is unwrapped).
pub(crate) fn is_loaded_model_by_style_id(&self, style_id: StyleId) -> bool {
    let loaded_models = self.loaded_models.lock().unwrap();
    loaded_models.contains_style(style_id)
}

pub(crate) fn validate_speaker_id(&self, style_id: StyleId) -> bool {
self.is_loaded_model_by_style_id(style_id)
}

/// 推論を実行する。
///
/// # Performance
Expand Down Expand Up @@ -191,12 +190,14 @@ impl<R: InferenceRuntime, G: InferenceDomainGroup> LoadedModels<R, G> {
.0
.iter()
.find(|(_, LoadedModel { metas, .. })| {
metas
.iter()
.flat_map(SpeakerMeta::styles)
.any(|style| *style.id() == style_id)
metas.iter().flat_map(SpeakerMeta::styles).any(|style| {
*style.id() == style_id && D::style_types().contains(style.r#type())
})
})
.ok_or(ErrorRepr::StyleNotFound { style_id })?;
.ok_or(ErrorRepr::StyleNotFound {
style_id,
style_types: D::style_types(),
})?;

let model_inner_id = D::visit(by_domain)
.as_ref()
Expand Down Expand Up @@ -280,11 +281,17 @@ impl<R: InferenceRuntime, G: InferenceDomainGroup> LoadedModels<R, G> {
.metas
.iter()
.flat_map(|speaker| speaker.styles());
if let Some(style_type) = external
.clone()
.map(StyleMeta::r#type)
.copied()
.find(|&t| !model_bytes_or_sessions.contains_for(t))
if let Some(style_type) =
external
.clone()
.map(StyleMeta::r#type)
.copied()
.find(|&style_type| {
!model_bytes_or_sessions.any(ContainsForStyleType {
style_type,
marker: PhantomData,
})
})
{
return Err(error(LoadModelErrorKind::MissingModelData { style_type }));
}
Expand All @@ -295,7 +302,25 @@ impl<R: InferenceRuntime, G: InferenceDomainGroup> LoadedModels<R, G> {
id: *style.id(),
}));
}
Ok(())
return Ok(());

/// Predicate handed to `InferenceDomainMap::any`: it holds for a domain `D` when
/// `D::style_types()` contains `style_type` and that domain's optional target is `Some`.
struct ContainsForStyleType<A> {
// Style type to look for among a domain's style types.
style_type: StyleType,
// Ties the predicate to association `A` without storing a value of it.
marker: PhantomData<fn() -> A>,
}

/// Predicate logic for `ContainsForStyleType`, evaluated against the target of a domain `D`.
impl<A: InferenceDomainAssociation> InferenceDomainAssociationTargetPredicate
    for ContainsForStyleType<A>
{
    /// The inspected targets are `Option`s, hence `A` wrapped in `Optional`.
    type Association = Optional<A>;

    /// Returns `true` only when `D` lists `self.style_type` among its style types
    /// and the target `x` is `Some`.
    fn test<D: InferenceDomain>(
        &self,
        x: &<Self::Association as InferenceDomainAssociation>::Target<D>,
    ) -> bool {
        // A domain that does not cover the style type can never satisfy the predicate.
        if !D::style_types().contains(&self.style_type) {
            return false;
        }
        x.is_some()
    }
}
}

fn insert(
Expand Down
15 changes: 14 additions & 1 deletion crates/voicevox_core/src/metas.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,20 @@ pub struct StyleMeta {
}

/// **スタイル**(_style_)に対応するモデルの種類。
#[derive(Default, Clone, Copy, Debug, strum::Display, Deserialize, Serialize)]
#[derive(
Default,
Clone,
Copy,
PartialEq,
Eq,
PartialOrd,
Ord,
Hash,
Debug,
strum::Display,
Deserialize,
Serialize,
)]
#[strum(serialize_all = "snake_case")]
#[serde(rename_all = "snake_case")]
pub enum StyleType {
Expand Down
15 changes: 0 additions & 15 deletions crates/voicevox_core/src/synthesizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -833,11 +833,6 @@ pub(crate) mod blocking {

impl<O> PerformInference for self::Synthesizer<O> {
fn predict_duration(&self, phoneme_vector: &[i64], style_id: StyleId) -> Result<Vec<f32>> {
// FIXME: `Status::ids_for`があるため、ここは不要なはず
if !self.status.validate_speaker_id(style_id) {
return Err(ErrorRepr::StyleNotFound { style_id }.into());
}

let (model_id, model_inner_id) = self.status.ids_for::<TalkDomain>(style_id)?;

let PredictDurationOutput {
Expand Down Expand Up @@ -873,11 +868,6 @@ pub(crate) mod blocking {
end_accent_phrase_vector: &[i64],
style_id: StyleId,
) -> Result<Vec<f32>> {
// FIXME: `Status::ids_for`があるため、ここは不要なはず
if !self.status.validate_speaker_id(style_id) {
return Err(ErrorRepr::StyleNotFound { style_id }.into());
}

let (model_id, model_inner_id) = self.status.ids_for::<TalkDomain>(style_id)?;

let PredictIntonationOutput { f0_list: output } = self.status.run_session(
Expand Down Expand Up @@ -905,11 +895,6 @@ pub(crate) mod blocking {
phoneme_vector: &[f32],
style_id: StyleId,
) -> Result<Vec<f32>> {
// FIXME: `Status::ids_for`があるため、ここは不要なはず
if !self.status.validate_speaker_id(style_id) {
return Err(ErrorRepr::StyleNotFound { style_id }.into());
}

let (model_id, model_inner_id) = self.status.ids_for::<TalkDomain>(style_id)?;

// 音が途切れてしまうのを避けるworkaround処理が入っている
Expand Down

0 comments on commit ad47904

Please sign in to comment.