Skip to content

Commit

Permalink
入力テンソルをvisitor patternで捌く (#680)
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip authored Nov 17, 2023
1 parent 3174372 commit 3c9b09d
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 40 deletions.
54 changes: 26 additions & 28 deletions crates/voicevox_core/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub(crate) mod status;
use std::{borrow::Cow, fmt::Debug};

use derive_new::new;
use duplicate::duplicate_item;
use enum_map::{Enum, EnumMap};
use ndarray::{Array, ArrayD, Dimension, ShapeError};
use thiserror::Error;
Expand All @@ -14,7 +15,7 @@ use crate::SupportedDevices;

pub(crate) trait InferenceRuntime: 'static {
type Session: Sized + Send + 'static;
type RunContext<'a>: From<&'a mut Self::Session>;
type RunContext<'a>: From<&'a mut Self::Session> + PushInputTensor;

fn supported_devices() -> crate::Result<SupportedDevices>;

Expand All @@ -28,11 +29,6 @@ pub(crate) trait InferenceRuntime: 'static {
Vec<ParamInfo<OutputScalarKind>>,
)>;

fn push_input(
input: Array<impl InputScalar, impl Dimension + 'static>,
ctx: &mut Self::RunContext<'_>,
);

fn run(ctx: Self::RunContext<'_>) -> anyhow::Result<Vec<OutputTensor>>;
}

Expand Down Expand Up @@ -77,16 +73,29 @@ pub(crate) trait InferenceInputSignature: Send + 'static {
fn make_run_context<R: InferenceRuntime>(self, sess: &mut R::Session) -> R::RunContext<'_>;
}

pub(crate) trait InputScalar: sealed::InputScalar + Debug + 'static {
pub(crate) trait InputScalar: Sized {
const KIND: InputScalarKind;
}

impl InputScalar for i64 {
const KIND: InputScalarKind = InputScalarKind::Int64;
fn push_tensor_to_ctx(
tensor: Array<Self, impl Dimension + 'static>,
visitor: &mut impl PushInputTensor,
);
}

impl InputScalar for f32 {
const KIND: InputScalarKind = InputScalarKind::Float32;
#[duplicate_item(
T KIND_VAL push;
[ i64 ] [ InputScalarKind::Int64 ] [ push_int64 ];
[ f32 ] [ InputScalarKind::Float32 ] [ push_float32 ];
)]
impl InputScalar for T {
const KIND: InputScalarKind = KIND_VAL;

fn push_tensor_to_ctx(
tensor: Array<Self, impl Dimension + 'static>,
ctx: &mut impl PushInputTensor,
) {
ctx.push(tensor);
}
}

#[derive(Clone, Copy, PartialEq, derive_more::Display)]
Expand All @@ -98,6 +107,11 @@ pub(crate) enum InputScalarKind {
Float32,
}

/// Visitor through which input tensors are handed over, with one method per supported scalar type.
///
/// `InferenceRuntime::RunContext` is required to implement this trait.
pub(crate) trait PushInputTensor {
/// Accepts an input tensor whose elements are `i64`.
fn push_int64(&mut self, tensor: Array<i64, impl Dimension + 'static>);
/// Accepts an input tensor whose elements are `f32`.
fn push_float32(&mut self, tensor: Array<f32, impl Dimension + 'static>);
}

/// 推論操作の出力シグネチャ。
///
/// `::macros::InferenceOutputSignature`により、`TryFrom<OutputTensor>`も含めて導出される。
Expand Down Expand Up @@ -170,19 +184,3 @@ pub(crate) enum ExtractError {
// NOTE(review): judging by the name, presumably returned when decrypting a model fails —
// confirm at the use sites.
/// Error whose (Japanese) display message states that the model file is invalid.
#[derive(Error, Debug)]
#[error("不正なモデルファイルです")]
pub(crate) struct DecryptModelError;

// FIXME: Instead of depending on `onnxruntime::TypeToTensorElementDataType`, connect `InputScalar`
// to `runtimes` via the visitor pattern.
mod sealed {
// Marker trait implemented below only for `i64` and `f32`; it additionally requires
// `OnnxruntimeInputScalar`.
pub(crate) trait InputScalar: OnnxruntimeInputScalar {}

impl InputScalar for i64 {}
impl InputScalar for f32 {}

// Marker trait that simply re-exposes the `onnxruntime::TypeToTensorElementDataType` bound.
pub(crate) trait OnnxruntimeInputScalar:
onnxruntime::TypeToTensorElementDataType
{
}

// Blanket impl: every type implementing `onnxruntime::TypeToTensorElementDataType` is an
// `OnnxruntimeInputScalar`.
impl<T: onnxruntime::TypeToTensorElementDataType> OnnxruntimeInputScalar for T {}
}
35 changes: 25 additions & 10 deletions crates/voicevox_core/src/infer/runtimes/onnxruntime.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
use std::{fmt::Debug, vec};

use anyhow::anyhow;
use duplicate::duplicate_item;
use ndarray::{Array, Dimension};
use once_cell::sync::Lazy;
use onnxruntime::{
environment::Environment, GraphOptimizationLevel, LoggingLevel, TensorElementDataType,
TypeToTensorElementDataType,
};

use crate::{devices::SupportedDevices, error::ErrorRepr};

use self::assert_send::AssertSend;

use super::super::{
DecryptModelError, InferenceRuntime, InferenceSessionOptions, InputScalar, InputScalarKind,
OutputScalarKind, OutputTensor, ParamInfo,
DecryptModelError, InferenceRuntime, InferenceSessionOptions, InputScalarKind,
OutputScalarKind, OutputTensor, ParamInfo, PushInputTensor,
};

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
Expand Down Expand Up @@ -154,14 +156,6 @@ impl InferenceRuntime for Onnxruntime {
};
}

fn push_input(
input: Array<impl InputScalar, impl Dimension + 'static>,
ctx: &mut Self::RunContext<'_>,
) {
ctx.inputs
.push(Box::new(onnxruntime::session::NdArray::new(input)));
}

fn run(
OnnxruntimeRunContext { sess, mut inputs }: OnnxruntimeRunContext<'_>,
) -> anyhow::Result<Vec<OutputTensor>> {
Expand Down Expand Up @@ -193,6 +187,16 @@ pub(crate) struct OnnxruntimeRunContext<'sess> {
inputs: Vec<Box<dyn onnxruntime::session::AnyArray>>,
}

impl OnnxruntimeRunContext<'_> {
    /// Wraps `input` in an `onnxruntime::session::NdArray`, boxes it, and appends it to
    /// `self.inputs`.
    fn push_input(
        &mut self,
        input: Array<impl TypeToTensorElementDataType + Debug + 'static, impl Dimension + 'static>,
    ) {
        let wrapped = onnxruntime::session::NdArray::new(input);
        self.inputs.push(Box::new(wrapped));
    }
}

impl<'sess> From<&'sess mut AssertSend<onnxruntime::session::Session<'static>>>
for OnnxruntimeRunContext<'sess>
{
Expand All @@ -204,6 +208,17 @@ impl<'sess> From<&'sess mut AssertSend<onnxruntime::session::Session<'static>>>
}
}

// Visitor implementation for the onnxruntime run context: both typed entry points forward the
// tensor to `push_input`.
impl PushInputTensor for OnnxruntimeRunContext<'_> {
// `duplicate_item` emits this method once per row below, substituting `method` and `T`
// (`push_int64`/`i64` and `push_float32`/`f32`).
#[duplicate_item(
method T;
[ push_int64 ] [ i64 ];
[ push_float32 ] [ f32 ];
)]
fn method(&mut self, tensor: Array<T, impl Dimension + 'static>) {
self.push_input(tensor);
}
}

// FIXME: 以下のことをちゃんと確認した後、onnxruntime-rs側で`Session`が`Send`であると宣言する。
// https://github.com/VOICEVOX/voicevox_core/issues/307#issuecomment-1276184614
mod assert_send {
Expand Down
16 changes: 14 additions & 2 deletions crates/voicevox_core_macros/src/inference_domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,21 @@ pub(crate) fn derive_inference_input_signature(
) -> R::RunContext<'_> {
let mut ctx = <R::RunContext<'_> as ::std::convert::From<_>>::from(sess);
#(
R::push_input(self.#field_names, &mut ctx);
__ArrayExt::push_to_ctx(self.#field_names, &mut ctx);
)*
ctx
return ctx;

trait __ArrayExt {
fn push_to_ctx(self, ctx: &mut impl crate::infer::PushInputTensor);
}

impl<A: crate::infer::InputScalar, D: ::ndarray::Dimension + 'static> __ArrayExt
for ::ndarray::Array<A, D>
{
fn push_to_ctx(self, ctx: &mut impl crate::infer::PushInputTensor) {
A::push_tensor_to_ctx(self, ctx);
}
}
}
}
});
Expand Down

0 comments on commit 3c9b09d

Please sign in to comment.