Skip to content

Commit

Permalink
InferenceInputInferenceInputSignature
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Nov 8, 2023
1 parent a5dbbdd commit 525f4b1
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 26 deletions.
22 changes: 12 additions & 10 deletions crates/voicevox_core/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,23 +42,25 @@ impl<'a, T: RunContext<'a>> T {
}

pub(crate) trait SupportsInferenceSignature<S: InferenceSignature>:
SupportsInferenceInputTensors<S::Input> + SupportsInferenceOutput<S::Output>
SupportsInferenceInputSignature<S::Input> + SupportsInferenceOutput<S::Output>
{
}

impl<
R: SupportsInferenceInputTensors<S::Input> + SupportsInferenceOutput<S::Output>,
R: SupportsInferenceInputSignature<S::Input> + SupportsInferenceOutput<S::Output>,
S: InferenceSignature,
> SupportsInferenceSignature<S> for R
{
}

pub(crate) trait SupportsInferenceInputTensor<I>: InferenceRuntime {
fn input(tensor: I, ctx: &mut Self::RunContext<'_>);
fn input(input: I, ctx: &mut Self::RunContext<'_>);
}

pub(crate) trait SupportsInferenceInputTensors<I: InferenceInput>: InferenceRuntime {
fn input(tensors: I, ctx: &mut Self::RunContext<'_>);
pub(crate) trait SupportsInferenceInputSignature<I: InferenceInputSignature>:
InferenceRuntime
{
fn input(input: I, ctx: &mut Self::RunContext<'_>);
}

pub(crate) trait SupportsInferenceOutput<O: Send>: InferenceRuntime {
Expand All @@ -67,12 +69,12 @@ pub(crate) trait SupportsInferenceOutput<O: Send>: InferenceRuntime {

pub(crate) trait InferenceSignature: Sized + Send + 'static {
type Kind: Enum + Copy;
type Input: InferenceInput;
type Input: InferenceInputSignature;
type Output: Send;
const KIND: Self::Kind;
}

pub(crate) trait InferenceInput: Send + 'static {
pub(crate) trait InferenceInputSignature: Send + 'static {
type Signature: InferenceSignature;
}

Expand Down Expand Up @@ -102,7 +104,7 @@ impl<K: Enum + Copy, R: InferenceRuntime> InferenceSessionSet<K, R> {
impl<K: Enum, R: InferenceRuntime> InferenceSessionSet<K, R> {
pub(crate) fn get<I>(&self) -> InferenceSessionCell<R, I>
where
I: InferenceInput,
I: InferenceInputSignature,
I::Signature: InferenceSignature<Kind = K>,
{
InferenceSessionCell {
Expand All @@ -118,9 +120,9 @@ pub(crate) struct InferenceSessionCell<R: InferenceRuntime, I> {
}

impl<
R: SupportsInferenceInputTensors<I>
R: SupportsInferenceInputSignature<I>
+ SupportsInferenceOutput<<I::Signature as InferenceSignature>::Output>,
I: InferenceInput,
I: InferenceInputSignature,
> InferenceSessionCell<R, I>
{
pub(crate) fn run(
Expand Down
4 changes: 2 additions & 2 deletions crates/voicevox_core/src/infer/runtimes/onnxruntime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ impl<'sess> RunContext<'sess> for OnnxruntimeRunContext<'sess> {
impl<A: TypeToTensorElementDataType + Debug + 'static, D: Dimension + 'static>
SupportsInferenceInputTensor<Array<A, D>> for Onnxruntime
{
fn input(tensor: Array<A, D>, ctx: &mut Self::RunContext<'_>) {
fn input(input: Array<A, D>, ctx: &mut Self::RunContext<'_>) {
ctx.inputs
.push(Box::new(onnxruntime::session::NdArray::new(tensor)));
.push(Box::new(onnxruntime::session::NdArray::new(input)));
}
}

Expand Down
16 changes: 8 additions & 8 deletions crates/voicevox_core/src/infer/signatures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use enum_map::Enum;
use ndarray::{Array0, Array1, Array2};

use crate::infer::{
InferenceInput, InferenceSignature, RunContextExt as _, SupportsInferenceInputTensor,
SupportsInferenceInputTensors,
InferenceInputSignature, InferenceSignature, RunContextExt as _,
SupportsInferenceInputSignature, SupportsInferenceInputTensor,
};

#[derive(Clone, Copy, Enum)]
Expand All @@ -27,12 +27,12 @@ pub(crate) struct PredictDurationInput {
pub(crate) speaker_id: Array1<i64>,
}

impl InferenceInput for PredictDurationInput {
impl InferenceInputSignature for PredictDurationInput {
type Signature = PredictDuration;
}

impl<R: SupportsInferenceInputTensor<Array1<i64>>>
SupportsInferenceInputTensors<PredictDurationInput> for R
SupportsInferenceInputSignature<PredictDurationInput> for R
{
fn input(input: PredictDurationInput, ctx: &mut R::RunContext<'_>) {
ctx.input(input.phoneme).input(input.speaker_id);
Expand All @@ -59,12 +59,12 @@ pub(crate) struct PredictIntonationInput {
pub(crate) speaker_id: Array1<i64>,
}

impl InferenceInput for PredictIntonationInput {
impl InferenceInputSignature for PredictIntonationInput {
type Signature = PredictIntonation;
}

impl<R: SupportsInferenceInputTensor<Array0<i64>> + SupportsInferenceInputTensor<Array1<i64>>>
SupportsInferenceInputTensors<PredictIntonationInput> for R
SupportsInferenceInputSignature<PredictIntonationInput> for R
{
fn input(input: PredictIntonationInput, ctx: &mut R::RunContext<'_>) {
ctx.input(input.length)
Expand Down Expand Up @@ -93,12 +93,12 @@ pub(crate) struct DecodeInput {
pub(crate) speaker_id: Array1<i64>,
}

impl InferenceInput for DecodeInput {
impl InferenceInputSignature for DecodeInput {
type Signature = Decode;
}

impl<R: SupportsInferenceInputTensor<Array1<i64>> + SupportsInferenceInputTensor<Array2<f32>>>
SupportsInferenceInputTensors<DecodeInput> for R
SupportsInferenceInputSignature<DecodeInput> for R
{
fn input(input: DecodeInput, ctx: &mut R::RunContext<'_>) {
ctx.input(input.f0)
Expand Down
12 changes: 6 additions & 6 deletions crates/voicevox_core/src/status.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use super::*;
use crate::infer::{
signatures::InferenceSignatureKind, InferenceInput, InferenceRuntime, InferenceSessionCell,
InferenceSessionOptions, InferenceSessionSet, InferenceSignature,
SupportsInferenceInputTensors, SupportsInferenceOutput,
signatures::InferenceSignatureKind, InferenceInputSignature, InferenceRuntime,
InferenceSessionCell, InferenceSessionOptions, InferenceSessionSet, InferenceSignature,
SupportsInferenceInputSignature, SupportsInferenceOutput,
};
use educe::Educe;
use itertools::iproduct;
Expand Down Expand Up @@ -87,9 +87,9 @@ impl<R: InferenceRuntime> Status<R> {
input: I,
) -> Result<<I::Signature as InferenceSignature>::Output>
where
I: InferenceInput,
I: InferenceInputSignature,
I::Signature: InferenceSignature<Kind = InferenceSignatureKind>,
R: SupportsInferenceInputTensors<I>
R: SupportsInferenceInputSignature<I>
+ SupportsInferenceOutput<<I::Signature as InferenceSignature>::Output>,
{
let sess = self.loaded_models.lock().unwrap().get(model_id);
Expand Down Expand Up @@ -151,7 +151,7 @@ impl<R: InferenceRuntime> LoadedModels<R> {
/// `self`が`model_id`を含んでいないとき、パニックする。
fn get<I>(&self, model_id: &VoiceModelId) -> InferenceSessionCell<R, I>
where
I: InferenceInput,
I: InferenceInputSignature,
I::Signature: InferenceSignature<Kind = InferenceSignatureKind>,
{
self.0[model_id].session_set.get()
Expand Down

0 comments on commit 525f4b1

Please sign in to comment.