Skip to content

Commit

Permalink
Fix up
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Nov 8, 2023
1 parent e0f29c6 commit cb1db34
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 41 deletions.
26 changes: 20 additions & 6 deletions crates/voicevox_core/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@ pub(crate) mod signatures;
use std::{collections::HashMap, fmt::Debug, marker::PhantomData, sync::Arc};

use derive_new::new;
use easy_ext::ext;
use enum_map::{Enum, EnumMap};
use thiserror::Error;

use crate::{ErrorRepr, SupportedDevices};

pub(crate) trait InferenceRuntime: 'static {
type Session: InferenceSession;
type RunContext<'a>: RunContext<'a, Session = Self::Session>;
type RunContext<'a>: RunContext<'a, Runtime = Self>;
fn supported_devices() -> crate::Result<SupportedDevices>;
}

Expand All @@ -23,8 +24,21 @@ pub(crate) trait InferenceSession: Sized + Send + 'static {
) -> anyhow::Result<Self>;
}

pub(crate) trait RunContext<'a>: From<&'a mut Self::Session> {
type Session: InferenceSession;
pub(crate) trait RunContext<'a>:
From<&'a mut <Self::Runtime as InferenceRuntime>::Session>
{
type Runtime: InferenceRuntime<RunContext<'a> = Self>;
}

#[ext(RunContextExt)]
impl<'a, T: RunContext<'a>> T {
fn input<I>(&mut self, tensor: I) -> &mut Self
where
T::Runtime: SupportsInferenceInputTensor<I>,
{
<T::Runtime as SupportsInferenceInputTensor<_>>::input(tensor, self);
self
}
}

pub(crate) trait SupportsInferenceSignature<S: InferenceSignature>:
Expand All @@ -40,11 +54,11 @@ impl<
}

pub(crate) trait SupportsInferenceInputTensor<I>: InferenceRuntime {
fn input(ctx: &mut Self::RunContext<'_>, tensor: I);
fn input(tensor: I, ctx: &mut Self::RunContext<'_>);
}

pub(crate) trait SupportsInferenceInputTensors<I: InferenceInput>: InferenceRuntime {
fn input(ctx: &mut Self::RunContext<'_>, tensors: I);
fn input(tensors: I, ctx: &mut Self::RunContext<'_>);
}

pub(crate) trait SupportsInferenceOutput<O: Send>: InferenceRuntime {
Expand Down Expand Up @@ -115,7 +129,7 @@ impl<
) -> crate::Result<<I::Signature as InferenceSignature>::Output> {
let mut inner = self.inner.lock().unwrap();
let mut ctx = R::RunContext::from(&mut inner);
R::input(&mut ctx, input);
R::input(input, &mut ctx);
R::run(ctx).map_err(|e| ErrorRepr::InferenceFailed(e).into())
}
}
Expand Down
4 changes: 2 additions & 2 deletions crates/voicevox_core/src/infer/runtimes/onnxruntime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,13 @@ impl<'sess> From<&'sess mut AssertSend<onnxruntime::session::Session<'static>>>
}

impl<'sess> RunContext<'sess> for OnnxruntimeInferenceBuilder<'sess> {
type Session = AssertSend<onnxruntime::session::Session<'static>>;
type Runtime = Onnxruntime;
}

impl<A: TypeToTensorElementDataType + Debug + 'static, D: Dimension + 'static>
SupportsInferenceInputTensor<Array<A, D>> for Onnxruntime
{
fn input(ctx: &mut Self::RunContext<'_>, tensor: Array<A, D>) {
fn input(tensor: Array<A, D>, ctx: &mut Self::RunContext<'_>) {
ctx.inputs
.push(Box::new(onnxruntime::session::NdArray::new(tensor)));
}
Expand Down
50 changes: 17 additions & 33 deletions crates/voicevox_core/src/infer/signatures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,10 @@ use enum_map::Enum;
use ndarray::{Array0, Array1, Array2};

use crate::infer::{
InferenceInput, InferenceSignature, SupportsInferenceInputTensor,
SupportsInferenceInputTensors, SupportsInferenceSignature,
InferenceInput, InferenceSignature, RunContextExt as _, SupportsInferenceInputTensor,
SupportsInferenceInputTensors,
};

pub(crate) trait SupportsAllSignatures:
SupportsInferenceSignature<PredictDuration>
+ SupportsInferenceSignature<PredictIntonation>
+ SupportsInferenceSignature<Decode>
{
}

impl<
R: SupportsInferenceSignature<PredictDuration>
+ SupportsInferenceSignature<PredictIntonation>
+ SupportsInferenceSignature<Decode>,
> SupportsAllSignatures for R
{
}

#[derive(Clone, Copy, Enum)]
pub(crate) enum SignatureKind {
PredictDuration,
Expand Down Expand Up @@ -49,9 +34,8 @@ impl InferenceInput for PredictDurationInput {
impl<R: SupportsInferenceInputTensor<Array1<i64>>>
SupportsInferenceInputTensors<PredictDurationInput> for R
{
fn input(ctx: &mut R::RunContext<'_>, input: PredictDurationInput) {
R::input(ctx, input.phoneme);
R::input(ctx, input.speaker_id);
fn input(input: PredictDurationInput, ctx: &mut R::RunContext<'_>) {
ctx.input(input.phoneme).input(input.speaker_id);
}
}

Expand Down Expand Up @@ -82,15 +66,15 @@ impl InferenceInput for PredictIntonationInput {
impl<R: SupportsInferenceInputTensor<Array0<i64>> + SupportsInferenceInputTensor<Array1<i64>>>
SupportsInferenceInputTensors<PredictIntonationInput> for R
{
fn input(ctx: &mut R::RunContext<'_>, input: PredictIntonationInput) {
R::input(ctx, input.length);
R::input(ctx, input.vowel_phoneme);
R::input(ctx, input.consonant_phoneme);
R::input(ctx, input.start_accent);
R::input(ctx, input.end_accent);
R::input(ctx, input.start_accent_phrase);
R::input(ctx, input.end_accent_phrase);
R::input(ctx, input.speaker_id);
fn input(input: PredictIntonationInput, ctx: &mut R::RunContext<'_>) {
ctx.input(input.length)
.input(input.vowel_phoneme)
.input(input.consonant_phoneme)
.input(input.start_accent)
.input(input.end_accent)
.input(input.start_accent_phrase)
.input(input.end_accent_phrase)
.input(input.speaker_id);
}
}

Expand All @@ -116,9 +100,9 @@ impl InferenceInput for DecodeInput {
impl<R: SupportsInferenceInputTensor<Array1<i64>> + SupportsInferenceInputTensor<Array2<f32>>>
SupportsInferenceInputTensors<DecodeInput> for R
{
fn input(ctx: &mut R::RunContext<'_>, input: DecodeInput) {
R::input(ctx, input.f0);
R::input(ctx, input.phoneme);
R::input(ctx, input.speaker_id);
fn input(input: DecodeInput, ctx: &mut R::RunContext<'_>) {
ctx.input(input.f0)
.input(input.phoneme)
.input(input.speaker_id);
}
}

0 comments on commit cb1db34

Please sign in to comment.