Skip to content

Commit

Permalink
change: VVMにUUIDを割り振り、それをVoiceModelIdとする
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed May 23, 2024
1 parent 5a644ca commit 61725a2
Show file tree
Hide file tree
Showing 26 changed files with 174 additions and 148 deletions.
11 changes: 1 addition & 10 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ libc = "0.2.134"
libloading = "0.7.3"
libtest-mimic = "0.6.0"
log = "0.4.17"
nanoid = "0.4.0"
ndarray = "0.15.6"
ndarray-stats = "0.5.1"
octocrab = { version = "0.19.0", default-features = false }
Expand Down
1 change: 0 additions & 1 deletion crates/voicevox_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ futures.workspace = true
indexmap = { workspace = true, features = ["serde"] }
itertools.workspace = true
jlabel.workspace = true
nanoid.workspace = true
ndarray.workspace = true
once_cell.workspace = true
open_jtalk.workspace = true
Expand Down
5 changes: 4 additions & 1 deletion crates/voicevox_core/src/__internal/interop.rs
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
pub use crate::{metas::merge as merge_metas, synthesizer::blocking::PerformInference};
pub use crate::{
metas::merge as merge_metas, synthesizer::blocking::PerformInference,
voice_model::blocking::IdRef,
};
5 changes: 2 additions & 3 deletions crates/voicevox_core/src/manifest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use derive_new::new;
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DisplayFromStr};

use crate::StyleId;
use crate::{StyleId, VoiceModelId};

pub type RawManifestVersion = String;
#[derive(Deserialize, Clone, Debug, PartialEq, new)]
Expand Down Expand Up @@ -38,10 +38,9 @@ impl Display for ModelInnerId {

#[derive(Deserialize, Getters, Clone)]
pub struct Manifest {
// FIXME: UUIDにする
// https://github.com/VOICEVOX/voicevox_core/issues/581
#[allow(dead_code)]
manifest_version: ManifestVersion,
pub(crate) id: VoiceModelId,
metas_filename: String,
#[serde(flatten)]
domains: ManifestDomains,
Expand Down
35 changes: 16 additions & 19 deletions crates/voicevox_core/src/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ impl<R: InferenceRuntime> Status<R> {
Ok(())
}

pub(crate) fn unload_model(&self, voice_model_id: &VoiceModelId) -> Result<()> {
pub(crate) fn unload_model(&self, voice_model_id: VoiceModelId) -> Result<()> {
self.loaded_models.lock().unwrap().remove(voice_model_id)
}

Expand All @@ -77,7 +77,7 @@ impl<R: InferenceRuntime> Status<R> {
self.loaded_models.lock().unwrap().ids_for::<D>(style_id)
}

pub(crate) fn is_loaded_model(&self, voice_model_id: &VoiceModelId) -> bool {
pub(crate) fn is_loaded_model(&self, voice_model_id: VoiceModelId) -> bool {
self.loaded_models
.lock()
.unwrap()
Expand All @@ -101,7 +101,7 @@ impl<R: InferenceRuntime> Status<R> {
/// `self`が`model_id`を含んでいないとき、パニックする。
pub(crate) fn run_session<I>(
&self,
model_id: &VoiceModelId,
model_id: VoiceModelId,
input: I,
) -> Result<<I::Signature as InferenceSignature>::Output>
where
Expand Down Expand Up @@ -159,7 +159,7 @@ impl<R: InferenceRuntime> LoadedModels<R> {
.and_then(|(model_inner_ids, _)| model_inner_ids.get(&style_id).copied())
.unwrap_or_else(|| ModelInnerId::new(style_id.raw_id()));

Ok((model_id.clone(), model_inner_id))
Ok((*model_id, model_inner_id))
}

/// # Panics
Expand All @@ -168,12 +168,12 @@ impl<R: InferenceRuntime> LoadedModels<R> {
///
/// - `self`が`model_id`を含んでいないとき
/// - 対応する`InferenceDomain`が欠けているとき
fn get<I>(&self, model_id: &VoiceModelId) -> InferenceSessionCell<R, I>
fn get<I>(&self, model_id: VoiceModelId) -> InferenceSessionCell<R, I>
where
I: InferenceInputSignature,
<I::Signature as InferenceSignature>::Domain: InferenceDomainExt,
{
let (_, session_set) = self.0[model_id]
let (_, session_set) = self.0[&model_id]
.session_sets_with_inner_ids
.get::<<I::Signature as InferenceSignature>::Domain>()
.as_ref()
Expand All @@ -190,8 +190,8 @@ impl<R: InferenceRuntime> LoadedModels<R> {
session_set.get()
}

fn contains_voice_model(&self, model_id: &VoiceModelId) -> bool {
self.0.contains_key(model_id)
fn contains_voice_model(&self, model_id: VoiceModelId) -> bool {
self.0.contains_key(&model_id)
}

fn contains_style(&self, style_id: StyleId) -> bool {
Expand All @@ -216,9 +216,9 @@ impl<R: InferenceRuntime> LoadedModels<R> {
source: None,
};

if self.0.contains_key(&model_header.id) {
if self.0.contains_key(&model_header.manifest.id) {
return Err(error(LoadModelErrorKind::ModelAlreadyLoaded {
id: model_header.id.clone(),
id: model_header.manifest.id,
}));
}

Expand Down Expand Up @@ -255,7 +255,7 @@ impl<R: InferenceRuntime> LoadedModels<R> {
self.ensure_acceptable(model_header)?;

let prev = self.0.insert(
model_header.id.clone(),
model_header.manifest.id,
LoadedModel {
metas: model_header.metas.clone(),
session_sets_with_inner_ids,
Expand All @@ -265,12 +265,9 @@ impl<R: InferenceRuntime> LoadedModels<R> {
Ok(())
}

fn remove(&mut self, model_id: &VoiceModelId) -> Result<()> {
if self.0.remove(model_id).is_none() {
return Err(ErrorRepr::ModelNotFound {
model_id: model_id.clone(),
}
.into());
fn remove(&mut self, model_id: VoiceModelId) -> Result<()> {
if self.0.remove(&model_id).is_none() {
return Err(ErrorRepr::ModelNotFound { model_id }.into());
}
Ok(())
}
Expand Down Expand Up @@ -415,13 +412,13 @@ mod tests {
let model_header = vvm.header();
let model_contents = &vvm.read_inference_models().await.unwrap();
assert!(
!status.is_loaded_model(&model_header.id),
!status.is_loaded_model(model_header.manifest.id),
"model should not be loaded"
);
let result = status.insert_model(model_header, model_contents);
assert_debug_fmt_eq!(Ok(()), result);
assert!(
status.is_loaded_model(&model_header.id),
status.is_loaded_model(model_header.manifest.id),
"model should be loaded",
);
}
Expand Down
14 changes: 7 additions & 7 deletions crates/voicevox_core/src/synthesizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,12 +207,12 @@ pub(crate) mod blocking {
}

/// 音声モデルの読み込みを解除する。
pub fn unload_voice_model(&self, voice_model_id: &VoiceModelId) -> Result<()> {
pub fn unload_voice_model(&self, voice_model_id: VoiceModelId) -> Result<()> {
self.status.unload_model(voice_model_id)
}

/// 指定したIDの音声モデルが読み込まれているか判定する。
pub fn is_loaded_voice_model(&self, voice_model_id: &VoiceModelId) -> bool {
pub fn is_loaded_voice_model(&self, voice_model_id: VoiceModelId) -> bool {
self.status.is_loaded_model(voice_model_id)
}

Expand Down Expand Up @@ -841,7 +841,7 @@ pub(crate) mod blocking {
let PredictDurationOutput {
phoneme_length: output,
} = self.status.run_session(
&model_id,
model_id,
PredictDurationInput {
phoneme_list: ndarray::arr1(phoneme_vector),
speaker_id: ndarray::arr1(&[model_inner_id.raw_id().into()]),
Expand Down Expand Up @@ -874,7 +874,7 @@ pub(crate) mod blocking {
let (model_id, model_inner_id) = self.status.ids_for::<TalkDomain>(style_id)?;

let PredictIntonationOutput { f0_list: output } = self.status.run_session(
&model_id,
model_id,
PredictIntonationInput {
length: ndarray::arr0(length as i64),
vowel_phoneme_list: ndarray::arr1(vowel_phoneme_vector),
Expand Down Expand Up @@ -917,7 +917,7 @@ pub(crate) mod blocking {
);

let DecodeOutput { wave: output } = self.status.run_session(
&model_id,
model_id,
DecodeInput {
f0: ndarray::arr1(&f0_with_padding)
.into_shape([length_with_padding, 1])
Expand Down Expand Up @@ -1150,11 +1150,11 @@ pub(crate) mod tokio {
self.0.status.insert_model(model.header(), model_bytes)
}

pub fn unload_voice_model(&self, voice_model_id: &VoiceModelId) -> Result<()> {
pub fn unload_voice_model(&self, voice_model_id: VoiceModelId) -> Result<()> {
self.0.unload_voice_model(voice_model_id)
}

pub fn is_loaded_voice_model(&self, voice_model_id: &VoiceModelId) -> bool {
pub fn is_loaded_voice_model(&self, voice_model_id: VoiceModelId) -> bool {
self.0.is_loaded_voice_model(voice_model_id)
}

Expand Down
43 changes: 22 additions & 21 deletions crates/voicevox_core/src/voice_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
use anyhow::anyhow;
use derive_getters::Getters;
use derive_more::From;
use derive_new::new;
use easy_ext::ext;
use enum_map::EnumMap;
use itertools::Itertools as _;
use serde::Deserialize;
use uuid::Uuid;

use crate::{
error::{LoadModelError, LoadModelErrorKind, LoadModelResult},
Expand All @@ -24,7 +26,7 @@ use std::path::{Path, PathBuf};
/// [`VoiceModelId`]の実体。
///
/// [`VoiceModelId`]: VoiceModelId
pub type RawVoiceModelId = String;
pub type RawVoiceModelId = Uuid;

pub(crate) type ModelBytesWithInnerIdsByDomain =
(Option<(StyleIdToModelInnerId, EnumMap<TalkOperation, Vec<u8>>)>,);
Expand All @@ -34,6 +36,7 @@ pub(crate) type ModelBytesWithInnerIdsByDomain =
PartialEq,
Eq,
Clone,
Copy,
Ord,
Hash,
PartialOrd,
Expand All @@ -42,7 +45,9 @@ pub(crate) type ModelBytesWithInnerIdsByDomain =
Getters,
derive_more::Display,
Debug,
From,
)]
#[serde(transparent)]
pub struct VoiceModelId {
raw_voice_model_id: RawVoiceModelId,
}
Expand All @@ -53,9 +58,7 @@ pub struct VoiceModelId {
/// モデルの`[u8]`と分けて`Status`に渡す。
#[derive(Clone)]
pub(crate) struct VoiceModelHeader {
/// ID。
pub(crate) id: VoiceModelId,
manifest: Manifest,
pub(crate) manifest: Manifest,
/// メタ情報。
///
/// `manifest`が対応していない`StyleType`のスタイルは含まれるべきではない。
Expand All @@ -64,12 +67,7 @@ pub(crate) struct VoiceModelHeader {
}

impl VoiceModelHeader {
fn new(
id: VoiceModelId,
manifest: Manifest,
metas: &[u8],
path: &Path,
) -> LoadModelResult<Self> {
fn new(manifest: Manifest, metas: &[u8], path: &Path) -> LoadModelResult<Self> {
let metas =
serde_json::from_slice::<VoiceModelMeta>(metas).map_err(|source| LoadModelError {
path: path.to_owned(),
Expand All @@ -94,7 +92,6 @@ impl VoiceModelHeader {
})?;

Ok(Self {
id,
manifest,
metas,
path: path.to_owned(),
Expand Down Expand Up @@ -151,8 +148,8 @@ pub(crate) mod blocking {
path::Path,
};

use easy_ext::ext;
use enum_map::EnumMap;
use nanoid::nanoid;
use ouroboros::self_referencing;
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
use serde::de::DeserializeOwned;
Expand Down Expand Up @@ -220,14 +217,13 @@ pub(crate) mod blocking {
let reader = BlockingVvmEntryReader::open(path)?;
let manifest = reader.read_vvm_json::<Manifest>("manifest.json")?;
let metas = &reader.read_vvm_entry(manifest.metas_filename())?;
let id = VoiceModelId::new(nanoid!());
let header = VoiceModelHeader::new(id, manifest, metas, path)?;
let header = VoiceModelHeader::new(manifest, metas, path)?;
Ok(Self { header })
}

/// ID。
pub fn id(&self) -> &VoiceModelId {
&self.header.id
pub fn id(&self) -> VoiceModelId {
self.header.manifest.id
}

/// メタ情報。
Expand Down Expand Up @@ -289,6 +285,13 @@ pub(crate) mod blocking {
})
}
}

#[ext(IdRef)]
pub impl VoiceModel {
fn id_ref(&self) -> &VoiceModelId {
&self.header.manifest.id
}
}
}

pub(crate) mod tokio {
Expand All @@ -297,7 +300,6 @@ pub(crate) mod tokio {
use derive_new::new;
use enum_map::EnumMap;
use futures::future::{join3, OptionFuture};
use nanoid::nanoid;
use serde::de::DeserializeOwned;

use crate::{
Expand Down Expand Up @@ -360,14 +362,13 @@ pub(crate) mod tokio {
let reader = AsyncVvmEntryReader::open(path.as_ref()).await?;
let manifest = reader.read_vvm_json::<Manifest>("manifest.json").await?;
let metas = &reader.read_vvm_entry(manifest.metas_filename()).await?;
let id = VoiceModelId::new(nanoid!());
let header = VoiceModelHeader::new(id, manifest, metas, path.as_ref())?;
let header = VoiceModelHeader::new(manifest, metas, path.as_ref())?;
Ok(Self { header })
}

/// ID。
pub fn id(&self) -> &VoiceModelId {
&self.header.id
pub fn id(&self) -> VoiceModelId {
self.header.manifest.id
}

/// メタ情報。
Expand Down
1 change: 1 addition & 0 deletions crates/voicevox_core_c_api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ chrono = { workspace = true, default-features = false, features = ["clock"] }
colorchoice.workspace = true
cstr.workspace = true
derive-getters.workspace = true
easy-ext.workspace = true
futures.workspace = true
itertools.workspace = true
libc.workspace = true
Expand Down
Loading

0 comments on commit 61725a2

Please sign in to comment.