Skip to content

Commit

Permalink
InferenceDomainの不在に対してのエラーを作成
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Mar 10, 2024
1 parent dcc5ee1 commit e4d3eec
Show file tree
Hide file tree
Showing 12 changed files with 86 additions and 19 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/voicevox_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ regex.workspace = true
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true, features = ["preserve_order"] }
smallvec.workspace = true
strum = { workspace = true, features = ["derive"] }
tempfile.workspace = true
thiserror.workspace = true
tokio = { workspace = true, features = ["rt"] } # FIXME: feature-gateする
Expand Down
7 changes: 6 additions & 1 deletion crates/voicevox_core/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
engine::{FullContextLabelError, KanaParseError},
user_dict::InvalidWordError,
StyleId, VoiceModelId,
StyleId, StyleType, VoiceModelId,
};
//use engine::
use duplicate::duplicate_item;
Expand Down Expand Up @@ -38,6 +38,7 @@ impl Error {
LoadModelErrorKind::ReadZipEntry { .. } => ErrorKind::ReadZipEntry,
LoadModelErrorKind::ModelAlreadyLoaded { .. } => ErrorKind::ModelAlreadyLoaded,
LoadModelErrorKind::StyleAlreadyLoaded { .. } => ErrorKind::StyleAlreadyLoaded,
LoadModelErrorKind::MissingModelData { .. } => ErrorKind::MissingModelData,
LoadModelErrorKind::InvalidModelData => ErrorKind::InvalidModelData,
},
ErrorRepr::GetSupportedDevices(_) => ErrorKind::GetSupportedDevices,
Expand Down Expand Up @@ -121,6 +122,8 @@ pub enum ErrorKind {
ModelAlreadyLoaded,
/// すでに読み込まれているスタイルを読み込もうとした。
StyleAlreadyLoaded,
/// モデルデータが見つからなかった。
MissingModelData,
/// 無効なモデルデータ。
InvalidModelData,
/// サポートされているデバイス情報取得に失敗した。
Expand Down Expand Up @@ -169,6 +172,8 @@ pub(crate) enum LoadModelErrorKind {
ModelAlreadyLoaded { id: VoiceModelId },
#[display(fmt = "スタイル`{id}`は既に読み込まれています")]
StyleAlreadyLoaded { id: StyleId },
#[display(fmt = "`{style_type}`に対応するモデルデータがありませんでした")]
MissingModelData { style_type: StyleType },
#[display(fmt = "モデルデータを読むことができませんでした")]
InvalidModelData,
}
16 changes: 15 additions & 1 deletion crates/voicevox_core/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use enum_map::{Enum, EnumMap};
use ndarray::{Array, ArrayD, Dimension, ShapeError};
use thiserror::Error;

use crate::SupportedDevices;
use crate::{StyleType, SupportedDevices};

pub(crate) trait InferenceRuntime: 'static {
type Session: Sized + Send + 'static;
Expand Down Expand Up @@ -39,6 +39,10 @@ pub(crate) trait InferenceDomainGroup {
pub(crate) trait InferenceDomainMap<A: InferenceDomainAssociation> {
type Group: InferenceDomainGroup;

fn contains_for(&self, style_type: StyleType) -> bool
where
A: InferenceDomainOptionAssociation;

fn try_ref_map<
F: ConvertInferenceDomainAssociationTarget<Self::Group, A, A2, E>,
A2: InferenceDomainAssociation,
Expand Down Expand Up @@ -66,12 +70,22 @@ pub(crate) trait InferenceDomainAssociation {
type Target<D: InferenceDomain>;
}

pub(crate) trait InferenceDomainOptionAssociation: InferenceDomainAssociation {
fn is_some<D: InferenceDomain>(x: &Self::Target<D>) -> bool;
}

pub(crate) struct Optional<A>(Infallible, PhantomData<fn() -> A>);

impl<A: InferenceDomainAssociation> InferenceDomainAssociation for Optional<A> {
type Target<D: InferenceDomain> = Option<A::Target<D>>;
}

impl<A: InferenceDomainAssociation> InferenceDomainOptionAssociation for Optional<A> {
fn is_some<D: InferenceDomain>(x: &Self::Target<D>) -> bool {
x.is_some()
}
}

/// ある`VoiceModel`が提供する推論操作の集合を示す。
pub(crate) trait InferenceDomain: Sized {
type Group: InferenceDomainGroup;
Expand Down
13 changes: 12 additions & 1 deletion crates/voicevox_core/src/infer/domains.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
mod talk;

use crate::StyleType;

pub(crate) use self::talk::{
DecodeInput, DecodeOutput, PredictDurationInput, PredictDurationOutput, PredictIntonationInput,
PredictIntonationOutput, TalkDomain, TalkOperation,
};

use super::{
ConvertInferenceDomainAssociationTarget, InferenceDomainAssociation, InferenceDomainGroup,
InferenceDomainMap,
InferenceDomainMap, InferenceDomainOptionAssociation,
};

pub(crate) enum InferenceDomainGroupImpl {}
Expand All @@ -23,6 +25,15 @@ pub(crate) struct InferenceDomainMapImpl<A: InferenceDomainAssociation> {
impl<A: InferenceDomainAssociation> InferenceDomainMap<A> for InferenceDomainMapImpl<A> {
type Group = InferenceDomainGroupImpl;

fn contains_for(&self, style_type: StyleType) -> bool
where
A: InferenceDomainOptionAssociation,
{
match style_type {
StyleType::Talk => A::is_some(&self.talk),
}
}

fn try_ref_map<
F: ConvertInferenceDomainAssociationTarget<Self::Group, A, A2, E>,
A2: InferenceDomainAssociation,
Expand Down
49 changes: 38 additions & 11 deletions crates/voicevox_core/src/infer/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl<R: InferenceRuntime, S: InferenceDomainGroup> Status<R, S> {
self.loaded_models
.lock()
.unwrap()
.ensure_acceptable(model_header)?;
.ensure_acceptable(model_header, model_bytes)?;

let session_set = model_bytes
.try_ref_map(CreateSessionSet {
Expand Down Expand Up @@ -131,7 +131,10 @@ impl<R: InferenceRuntime, S: InferenceDomainGroup> Status<R, S> {
///
/// # Panics
///
/// `self`が`model_id`を含んでいないとき、パニックする。
/// 次の場合にパニックする。
///
/// - `self`が`model_id`を含んでいないとき
/// - 対応する`InferenceDomain`が欠けているとき
pub(crate) fn run_session<I>(
&self,
model_id: &VoiceModelId,
Expand Down Expand Up @@ -193,15 +196,25 @@ impl<R: InferenceRuntime, S: InferenceDomainGroup> LoadedModels<R, S> {

/// # Panics
///
/// `self`が`model_id`を含んでいないとき、パニックする。
/// 次の場合にパニックする。
///
/// - `self`が`model_id`を含んでいないとき
/// - 対応する`InferenceDomain`が欠けているとき
fn get<I>(&self, model_id: &VoiceModelId) -> SessionCell<R, I>
where
I: InferenceInputSignature,
<I::Signature as InferenceSignature>::Domain: InferenceDomain<Group = S>,
{
<I::Signature as InferenceSignature>::Domain::visit(&self.0[model_id].session_sets)
.as_ref()
.unwrap_or_else(|| todo!("`ensure_acceptable`で検査する"))
.unwrap_or_else(|| {
let type_name =
std::any::type_name::<<I::Signature as InferenceSignature>::Domain>()
.split("::")
.last()
.unwrap();
panic!("missing session set for `{type_name}`");
})
.get()
}

Expand All @@ -219,14 +232,25 @@ impl<R: InferenceRuntime, S: InferenceDomainGroup> LoadedModels<R, S> {
///
/// 次の場合にエラーを返す。
///
/// - 音声モデルIDかスタイルIDが`model_header`と重複するとき
fn ensure_acceptable(&self, model_header: &VoiceModelHeader) -> LoadModelResult<()> {
/// - 現在持っている音声モデルIDかスタイルIDが`model_header`と重複するとき
/// - 必要であるはずの`InferenceDomain`のモデルデータが欠けているとき
fn ensure_acceptable(
&self,
model_header: &VoiceModelHeader,
model_bytes_or_sessions: &S::Map<Optional<impl InferenceDomainAssociation>>,
) -> LoadModelResult<()> {
let error = |context| LoadModelError {
path: model_header.path.clone(),
context,
source: None,
};

if self.0.contains_key(&model_header.id) {
return Err(error(LoadModelErrorKind::ModelAlreadyLoaded {
id: model_header.id.clone(),
}));
}

let loaded = self.speakers();
let external = model_header.metas.iter();
for (loaded, external) in iproduct!(loaded, external) {
Expand All @@ -240,10 +264,13 @@ impl<R: InferenceRuntime, S: InferenceDomainGroup> LoadedModels<R, S> {
.metas
.iter()
.flat_map(|speaker| speaker.styles());
if self.0.contains_key(&model_header.id) {
return Err(error(LoadModelErrorKind::ModelAlreadyLoaded {
id: model_header.id.clone(),
}));
if let Some(style_type) = external
.clone()
.map(StyleMeta::r#type)
.copied()
.find(|&t| !model_bytes_or_sessions.contains_for(t))
{
return Err(error(LoadModelErrorKind::MissingModelData { style_type }));
}
if let Some((style, _)) =
iproduct!(loaded, external).find(|(loaded, external)| loaded.id() == external.id())
Expand All @@ -260,7 +287,7 @@ impl<R: InferenceRuntime, S: InferenceDomainGroup> LoadedModels<R, S> {
model_header: &VoiceModelHeader,
session_sets: S::Map<Optional<SessionSetByDomain<R>>>,
) -> Result<()> {
self.ensure_acceptable(model_header)?;
self.ensure_acceptable(model_header, &session_sets)?;

let prev = self.0.insert(
model_header.id.clone(),
Expand Down
3 changes: 2 additions & 1 deletion crates/voicevox_core/src/metas.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ pub struct StyleMeta {
}

/// **スタイル**(_style_)に対応するモデルの種類。
#[derive(Default, Clone, Copy, Deserialize, Serialize)]
#[derive(Default, Clone, Copy, Debug, strum::Display, Deserialize, Serialize)]
#[strum(serialize_all = "snake_case")]
#[serde(rename_all = "snake_case")]
pub enum StyleType {
/// 音声合成クエリの作成と音声合成が可能。
Expand Down
1 change: 1 addition & 0 deletions crates/voicevox_core_c_api/src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ pub(crate) fn into_result_code_with_error(result: CApiResult<()>) -> VoicevoxRes
ReadZipEntry => VOICEVOX_RESULT_READ_ZIP_ENTRY_ERROR,
ModelAlreadyLoaded => VOICEVOX_RESULT_MODEL_ALREADY_LOADED_ERROR,
StyleAlreadyLoaded => VOICEVOX_RESULT_STYLE_ALREADY_LOADED_ERROR,
MissingModelData => VOICEVOX_RESULT_MISSING_MODEL_DATA_ERROR,
InvalidModelData => VOICEVOX_RESULT_INVALID_MODEL_DATA_ERROR,
GetSupportedDevices => VOICEVOX_RESULT_GET_SUPPORTED_DEVICES_ERROR,
StyleNotFound => VOICEVOX_RESULT_STYLE_NOT_FOUND_ERROR,
Expand Down
3 changes: 3 additions & 0 deletions crates/voicevox_core_c_api/src/result_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ pub enum VoicevoxResultCode {
VOICEVOX_RESULT_MODEL_ALREADY_LOADED_ERROR = 18,
/// すでに読み込まれているスタイルを読み込もうとした
VOICEVOX_RESULT_STYLE_ALREADY_LOADED_ERROR = 26,
/// モデルデータが見つからなかった
VOICEVOX_RESULT_MISSING_MODEL_DATA_ERROR = 28,
/// 無効なモデルデータ
VOICEVOX_RESULT_INVALID_MODEL_DATA_ERROR = 27,
/// ユーザー辞書を読み込めなかった
Expand Down Expand Up @@ -94,6 +96,7 @@ pub(crate) const fn error_result_to_message(result_code: VoicevoxResultCode) ->
VOICEVOX_RESULT_STYLE_ALREADY_LOADED_ERROR => {
cstr!("同じIDのスタイルを読むことはできません")
}
VOICEVOX_RESULT_MISSING_MODEL_DATA_ERROR => cstr!("モデルデータがありませんでした"),
VOICEVOX_RESULT_INVALID_MODEL_DATA_ERROR => {
cstr!("モデルデータを読むことができませんでした")
}
Expand Down
1 change: 1 addition & 0 deletions crates/voicevox_core_java_api/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ where
ReadZipEntry,
ModelAlreadyLoaded,
StyleAlreadyLoaded,
MissingModelData,
InvalidModelData,
GetSupportedDevices,
StyleNotFound,
Expand Down
9 changes: 5 additions & 4 deletions crates/voicevox_core_python_api/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ use voicevox_core::{

use crate::{
ExtractFullContextLabelError, GetSupportedDevicesError, GpuSupportError, InferenceFailedError,
InvalidModelDataError, InvalidWordError, LoadUserDictError, ModelAlreadyLoadedError,
ModelNotFoundError, NotLoadedOpenjtalkDictError, OpenZipFileError, ParseKanaError,
ReadZipEntryError, SaveUserDictError, StyleAlreadyLoadedError, StyleNotFoundError,
UseUserDictError, WordNotFoundError,
InvalidModelDataError, InvalidWordError, LoadUserDictError, MissingModelDataError,
ModelAlreadyLoadedError, ModelNotFoundError, NotLoadedOpenjtalkDictError, OpenZipFileError,
ParseKanaError, ReadZipEntryError, SaveUserDictError, StyleAlreadyLoadedError,
StyleNotFoundError, UseUserDictError, WordNotFoundError,
};

pub(crate) fn from_acceleration_mode(ob: &PyAny) -> PyResult<AccelerationMode> {
Expand Down Expand Up @@ -194,6 +194,7 @@ pub(crate) impl<T> voicevox_core::Result<T> {
ErrorKind::ReadZipEntry => ReadZipEntryError::new_err(msg),
ErrorKind::ModelAlreadyLoaded => ModelAlreadyLoadedError::new_err(msg),
ErrorKind::StyleAlreadyLoaded => StyleAlreadyLoadedError::new_err(msg),
ErrorKind::MissingModelData => MissingModelDataError::new_err(msg),
ErrorKind::InvalidModelData => InvalidModelDataError::new_err(msg),
ErrorKind::GetSupportedDevices => GetSupportedDevicesError::new_err(msg),
ErrorKind::StyleNotFound => StyleNotFoundError::new_err(msg),
Expand Down
1 change: 1 addition & 0 deletions crates/voicevox_core_python_api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ exceptions! {
ReadZipEntryError: PyException;
ModelAlreadyLoadedError: PyException;
StyleAlreadyLoadedError: PyException;
MissingModelDataError: PyException;
InvalidModelDataError: PyException;
GetSupportedDevicesError: PyException;
StyleNotFoundError: PyKeyError;
Expand Down

0 comments on commit e4d3eec

Please sign in to comment.