"kind"を直接"group"と呼ぶことにする
qryxip committed Nov 11, 2023
1 parent c39f48c commit c316209
Showing 6 changed files with 45 additions and 49 deletions.
6 changes: 2 additions & 4 deletions crates/voicevox_core/src/infer.rs
@@ -31,15 +31,13 @@ pub(crate) trait InferenceRuntime: 'static {
     fn run(ctx: Self::RunContext<'_>) -> anyhow::Result<Vec<OutputTensor>>;
 }
 
-pub(crate) trait InferenceGroup {
-    type Kind: Copy + Enum;
-}
+pub(crate) trait InferenceGroup: Copy + Enum {}
 
 pub(crate) trait InferenceSignature: Sized + Send + 'static {
     type Group: InferenceGroup;
     type Input: InferenceInputSignature<Signature = Self>;
     type Output: TryFrom<Vec<OutputTensor>, Error = anyhow::Error> + Send;
-    const KIND: <Self::Group as InferenceGroup>::Kind;
+    const KIND: Self::Group;
 }
 
 pub(crate) trait InferenceInputSignature: Send + 'static {
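In effect, the group's associated `Kind` is folded into the group itself: any fieldless `Copy` enum usable as an `enum_map` key can now serve as the group directly, and each signature's `KIND` constant is typed as `Self::Group`. A standalone sketch of the new shape, using stand-in names (`DemoGroup`, `Foo`) that are not part of this commit:

    use enum_map::Enum;

    // Stand-ins for the crate-private traits changed above.
    trait InferenceGroup: Copy + Enum {}

    trait InferenceSignature: Sized + Send + 'static {
        type Group: InferenceGroup;
        const KIND: Self::Group;
    }

    // The group is now just a plain enum; there is no separate carrier type
    // with an associated `Kind`.
    #[derive(Clone, Copy, Enum)]
    enum DemoGroup {
        Foo,
        Bar,
    }

    impl InferenceGroup for DemoGroup {}

    // An (uninhabited) signature type names its variant directly.
    enum Foo {}

    impl InferenceSignature for Foo {
        type Group = DemoGroup;
        const KIND: DemoGroup = DemoGroup::Foo;
    }

    fn main() {
        // `KIND` is an ordinary enum value of the group type.
        assert!(matches!(Foo::KIND, DemoGroup::Foo));
    }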
26 changes: 10 additions & 16 deletions crates/voicevox_core/src/infer/signatures.rs
@@ -1,17 +1,11 @@
 use enum_map::Enum;
-use macros::{InferenceInputSignature, TryFromVecOutputTensor};
+use macros::{InferenceGroup, InferenceInputSignature, TryFromVecOutputTensor};
 use ndarray::{Array0, Array1, Array2};
 
-use super::{InferenceGroup, InferenceSignature, OutputTensor};
+use super::{InferenceSignature, OutputTensor};
 
-pub(crate) enum InferenceGroupImpl {}
-
-impl InferenceGroup for InferenceGroupImpl {
-    type Kind = InferencelKindImpl;
-}
-
-#[derive(Clone, Copy, Enum)]
-pub(crate) enum InferencelKindImpl {
+#[derive(Clone, Copy, Enum, InferenceGroup)]
+pub(crate) enum InferenceKind {
     PredictDuration,
     PredictIntonation,
     Decode,
@@ -20,10 +14,10 @@ pub(crate) enum InferencelKindImpl {
 pub(crate) enum PredictDuration {}
 
 impl InferenceSignature for PredictDuration {
-    type Group = InferenceGroupImpl;
+    type Group = InferenceKind;
     type Input = PredictDurationInput;
     type Output = PredictDurationOutput;
-    const KIND: InferencelKindImpl = InferencelKindImpl::PredictDuration;
+    const KIND: InferenceKind = InferenceKind::PredictDuration;
 }
 
 #[derive(InferenceInputSignature)]
@@ -41,10 +35,10 @@ pub(crate) struct PredictDurationOutput {
 pub(crate) enum PredictIntonation {}
 
 impl InferenceSignature for PredictIntonation {
-    type Group = InferenceGroupImpl;
+    type Group = InferenceKind;
     type Input = PredictIntonationInput;
     type Output = PredictIntonationOutput;
-    const KIND: InferencelKindImpl = InferencelKindImpl::PredictIntonation;
+    const KIND: InferenceKind = InferenceKind::PredictIntonation;
 }
 
 #[derive(InferenceInputSignature)]
@@ -68,10 +62,10 @@ pub(crate) struct PredictIntonationOutput {
 pub(crate) enum Decode {}
 
 impl InferenceSignature for Decode {
-    type Group = InferenceGroupImpl;
+    type Group = InferenceKind;
     type Input = DecodeInput;
     type Output = DecodeOutput;
-    const KIND: InferencelKindImpl = InferencelKindImpl::Decode;
+    const KIND: InferenceKind = InferenceKind::Decode;
 }
 
 #[derive(InferenceInputSignature)]
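For reference, the new `InferenceGroup` entry in the derive list expands to exactly the marker impl that the deleted block wrote by hand (minus the now-removed associated type). Per the macro added to crates/voicevox_core_macros/src/lib.rs below, the generated code for this enum is:

    impl crate::infer::InferenceGroup for InferenceKind {}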
40 changes: 19 additions & 21 deletions crates/voicevox_core/src/infer/status.rs
@@ -5,7 +5,7 @@ use std::{
 };
 
 use educe::Educe;
-use enum_map::{Enum as _, EnumMap};
+use enum_map::EnumMap;
 use itertools::iproduct;
 
 use crate::{
@@ -23,11 +23,11 @@ use super::{
 
 pub(crate) struct Status<R: InferenceRuntime, G: InferenceGroup> {
     loaded_models: std::sync::Mutex<LoadedModels<R, G>>,
-    session_options: EnumMap<G::Kind, InferenceSessionOptions>,
+    session_options: EnumMap<G, InferenceSessionOptions>,
 }
 
 impl<R: InferenceRuntime, G: InferenceGroup> Status<R, G> {
-    pub fn new(session_options: EnumMap<G::Kind, InferenceSessionOptions>) -> Self {
+    pub fn new(session_options: EnumMap<G, InferenceSessionOptions>) -> Self {
         Self {
             loaded_models: Default::default(),
             session_options,
@@ -37,7 +37,7 @@ impl<R: InferenceRuntime, G: InferenceGroup> Status<R, G> {
     pub async fn load_model(
         &self,
         model: &VoiceModel,
-        model_bytes: &EnumMap<G::Kind, Vec<u8>>,
+        model_bytes: &EnumMap<G, Vec<u8>>,
     ) -> Result<()> {
         self.loaded_models
             .lock()
@@ -238,13 +238,13 @@ impl<R: InferenceRuntime, G: InferenceGroup> LoadedModels<R, G> {
 }
 
 struct SessionSet<R: InferenceRuntime, G: InferenceGroup>(
-    EnumMap<G::Kind, Arc<std::sync::Mutex<R::Session>>>,
+    EnumMap<G, Arc<std::sync::Mutex<R::Session>>>,
 );
 
 impl<R: InferenceRuntime, G: InferenceGroup> SessionSet<R, G> {
     fn new(
-        model_bytes: &EnumMap<G::Kind, Vec<u8>>,
-        options: &EnumMap<G::Kind, InferenceSessionOptions>,
+        model_bytes: &EnumMap<G, Vec<u8>>,
+        options: &EnumMap<G, InferenceSessionOptions>,
     ) -> anyhow::Result<Self> {
         let mut sessions = model_bytes
             .iter()
@@ -254,7 +254,7 @@ impl<R: InferenceRuntime, G: InferenceGroup> SessionSet<R, G> {
             })
             .collect::<anyhow::Result<HashMap<_, _>>>()?;
 
-        Ok(Self(EnumMap::<G::Kind, _>::from_fn(|k| {
+        Ok(Self(EnumMap::<G, _>::from_fn(|k| {
             sessions.remove(&k.into_usize()).expect("should exist")
         })))
     }
@@ -295,10 +295,8 @@ mod tests {
     use rstest::rstest;
 
     use crate::{
-        infer::signatures::{InferenceGroupImpl, InferencelKindImpl},
-        macros::tests::assert_debug_fmt_eq,
-        synthesizer::InferenceRuntimeImpl,
-        test_util::open_default_vvm_file,
+        infer::signatures::InferenceKind, macros::tests::assert_debug_fmt_eq,
+        synthesizer::InferenceRuntimeImpl, test_util::open_default_vvm_file,
     };
 
     use super::{super::InferenceSessionOptions, Status};
@@ -315,23 +313,23 @@
         let light_session_options = InferenceSessionOptions::new(cpu_num_threads, false);
         let heavy_session_options = InferenceSessionOptions::new(cpu_num_threads, use_gpu);
         let session_options = enum_map! {
-            InferencelKindImpl::PredictDuration
-            | InferencelKindImpl::PredictIntonation => light_session_options,
-            InferencelKindImpl::Decode => heavy_session_options,
+            InferenceKind::PredictDuration
+            | InferenceKind::PredictIntonation => light_session_options,
+            InferenceKind::Decode => heavy_session_options,
         };
-        let status = Status::<InferenceRuntimeImpl, InferenceGroupImpl>::new(session_options);
+        let status = Status::<InferenceRuntimeImpl, InferenceKind>::new(session_options);
 
         assert_eq!(
             light_session_options,
-            status.session_options[InferencelKindImpl::PredictDuration],
+            status.session_options[InferenceKind::PredictDuration],
         );
         assert_eq!(
             light_session_options,
-            status.session_options[InferencelKindImpl::PredictIntonation],
+            status.session_options[InferenceKind::PredictIntonation],
         );
         assert_eq!(
             heavy_session_options,
-            status.session_options[InferencelKindImpl::Decode],
+            status.session_options[InferenceKind::Decode],
         );
 
         assert!(status.loaded_models.lock().unwrap().0.is_empty());
@@ -340,7 +338,7 @@
     #[rstest]
     #[tokio::test]
     async fn status_load_model_works() {
-        let status = Status::<InferenceRuntimeImpl, InferenceGroupImpl>::new(
+        let status = Status::<InferenceRuntimeImpl, InferenceKind>::new(
            enum_map!(_ => InferenceSessionOptions::new(0, false)),
         );
         let model = &open_default_vvm_file().await;
@@ -353,7 +351,7 @@
     #[rstest]
     #[tokio::test]
     async fn status_is_model_loaded_works() {
-        let status = Status::<InferenceRuntimeImpl, InferenceGroupImpl>::new(
+        let status = Status::<InferenceRuntimeImpl, InferenceKind>::new(
            enum_map!(_ => InferenceSessionOptions::new(0, false)),
         );
         let vvm = open_default_vvm_file().await;
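Since `G` itself is now the `EnumMap` key, the map is indexed with plain enum values, and `EnumMap::from_fn` / `Enum::into_usize` operate on `G` directly. A self-contained sketch of the pattern used in `SessionSet::new` and the tests above, with illustrative names and values rather than the crate's real types:

    use enum_map::{enum_map, Enum, EnumMap};

    #[derive(Clone, Copy, Enum)]
    enum DemoKind {
        PredictDuration,
        PredictIntonation,
        Decode,
    }

    fn main() {
        // `enum_map!` matches on the key enum, as in the session-options test.
        let costs: EnumMap<DemoKind, u32> = enum_map! {
            DemoKind::PredictDuration | DemoKind::PredictIntonation => 1,
            DemoKind::Decode => 8,
        };
        assert_eq!(costs[DemoKind::Decode], 8);

        // `EnumMap::from_fn` builds one entry per variant; `into_usize` gives
        // the discriminant index that `SessionSet::new` keys its HashMap by.
        let indices: EnumMap<DemoKind, usize> = EnumMap::from_fn(|k: DemoKind| k.into_usize());
        assert_eq!(indices[DemoKind::Decode], 2);
    }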
12 changes: 6 additions & 6 deletions crates/voicevox_core/src/inference_core.rs
@@ -2,8 +2,8 @@ use enum_map::enum_map;
 
 use crate::infer::{
     signatures::{
-        DecodeInput, DecodeOutput, InferenceGroupImpl, InferencelKindImpl, PredictDurationInput,
-        PredictDurationOutput, PredictIntonationInput, PredictIntonationOutput,
+        DecodeInput, DecodeOutput, InferenceKind, PredictDurationInput, PredictDurationOutput,
+        PredictIntonationInput, PredictIntonationOutput,
     },
     status::Status,
     InferenceRuntime, InferenceSessionOptions,
@@ -14,7 +14,7 @@ use super::*;
 const PHONEME_LENGTH_MINIMAL: f32 = 0.01;
 
 pub(crate) struct InferenceCore<R: InferenceRuntime> {
-    status: Status<R, InferenceGroupImpl>,
+    status: Status<R, InferenceKind>,
 }
 
 impl<R: InferenceRuntime> InferenceCore<R> {
@@ -27,9 +27,9 @@ impl<R: InferenceRuntime> InferenceCore<R> {
             let heavy_session_options = InferenceSessionOptions::new(cpu_num_threads, use_gpu);
 
             let status = Status::new(enum_map! {
-                InferencelKindImpl::PredictDuration
-                | InferencelKindImpl::PredictIntonation => light_session_options,
-                InferencelKindImpl::Decode => heavy_session_options,
+                InferenceKind::PredictDuration
+                | InferenceKind::PredictIntonation => light_session_options,
+                InferenceKind::Decode => heavy_session_options,
             });
             Ok(Self { status })
         } else {
4 changes: 2 additions & 2 deletions crates/voicevox_core/src/voice_model.rs
@@ -4,7 +4,7 @@ use futures::future::join3;
 use serde::{de::DeserializeOwned, Deserialize};
 
 use super::*;
-use crate::infer::signatures::InferencelKindImpl;
+use crate::infer::signatures::InferenceKind;
 use std::{
     collections::{BTreeMap, HashMap},
     io,
@@ -40,7 +40,7 @@ pub struct VoiceModel {
 impl VoiceModel {
     pub(crate) async fn read_inference_models(
         &self,
-    ) -> LoadModelResult<EnumMap<InferencelKindImpl, Vec<u8>>> {
+    ) -> LoadModelResult<EnumMap<InferenceKind, Vec<u8>>> {
         let reader = VvmEntryReader::open(&self.path).await?;
         let (decode_model_result, predict_duration_model_result, predict_intonation_model_result) =
             join3(
6 changes: 6 additions & 0 deletions crates/voicevox_core_macros/src/lib.rs
@@ -8,6 +8,12 @@ use syn::{
     Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields, Token,
 };
 
+#[proc_macro_derive(InferenceGroup)]
+pub fn derive_inference_group(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+    let DeriveInput { ident, .. } = parse_macro_input!(input as DeriveInput);
+    quote!(impl crate::infer::InferenceGroup for #ident {}).into()
+}
+
 #[proc_macro_derive(InferenceInputSignature, attributes(input_signature))]
 pub fn derive_inference_input_signature(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
     return derive_inference_input_signature(&parse_macro_input!(input))
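A note on the design, as an observation rather than anything stated in the commit: a blanket `impl<T: Copy + Enum> InferenceGroup for T {}` would also satisfy the new supertrait bound, but it would make every such enum a group; the derive keeps the marker impl opt-in per type. On the consumer side it costs a single extra identifier in the derive list, as in signatures.rs above:

    #[derive(Clone, Copy, Enum, InferenceGroup)]
    pub(crate) enum InferenceKind {
        PredictDuration,
        PredictIntonation,
        Decode,
    }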
