Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

入力テンソルをvisitor patternで捌く #680

Merged
merged 3 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 26 additions & 28 deletions crates/voicevox_core/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub(crate) mod status;
use std::{borrow::Cow, fmt::Debug};

use derive_new::new;
use duplicate::duplicate_item;
use enum_map::{Enum, EnumMap};
use ndarray::{Array, ArrayD, Dimension, ShapeError};
use thiserror::Error;
Expand All @@ -14,7 +15,7 @@ use crate::SupportedDevices;

pub(crate) trait InferenceRuntime: 'static {
type Session: Sized + Send + 'static;
type RunContext<'a>: From<&'a mut Self::Session>;
type RunContext<'a>: From<&'a mut Self::Session> + PushInputTensor;

fn supported_devices() -> crate::Result<SupportedDevices>;

Expand All @@ -28,11 +29,6 @@ pub(crate) trait InferenceRuntime: 'static {
Vec<ParamInfo<OutputScalarKind>>,
)>;

fn push_input(
input: Array<impl InputScalar, impl Dimension + 'static>,
ctx: &mut Self::RunContext<'_>,
);

fn run(ctx: Self::RunContext<'_>) -> anyhow::Result<Vec<OutputTensor>>;
}

Expand Down Expand Up @@ -77,16 +73,29 @@ pub(crate) trait InferenceInputSignature: Send + 'static {
fn make_run_context<R: InferenceRuntime>(self, sess: &mut R::Session) -> R::RunContext<'_>;
}

pub(crate) trait InputScalar: sealed::InputScalar + Debug + 'static {
pub(crate) trait InputScalar: Sized {
const KIND: InputScalarKind;
}

impl InputScalar for i64 {
const KIND: InputScalarKind = InputScalarKind::Int64;
fn push_tensor_to_ctx(
tensor: Array<Self, impl Dimension + 'static>,
visitor: &mut impl PushInputTensor,
);
}

impl InputScalar for f32 {
const KIND: InputScalarKind = InputScalarKind::Float32;
// Implements `InputScalar` once per supported element type. `duplicate_item`
// stamps out one impl per row: `T` is the scalar type, `KIND_VAL` its
// runtime-kind tag, and `push` the name of the matching `PushInputTensor`
// method — so `ctx.push(tensor)` below expands to `ctx.push_int64(tensor)`
// for `i64` and `ctx.push_float32(tensor)` for `f32`.
#[duplicate_item(
T KIND_VAL push;
[ i64 ] [ InputScalarKind::Int64 ] [ push_int64 ];
[ f32 ] [ InputScalarKind::Float32 ] [ push_float32 ];
)]
impl InputScalar for T {
const KIND: InputScalarKind = KIND_VAL;

// Visitor-pattern double dispatch: the scalar type selects the visitor
// method, so runtimes never need a type-erased tensor enum on the input side.
fn push_tensor_to_ctx(
tensor: Array<Self, impl Dimension + 'static>,
ctx: &mut impl PushInputTensor,
) {
ctx.push(tensor);
}
}

#[derive(Clone, Copy, PartialEq, derive_more::Display)]
Expand All @@ -98,6 +107,11 @@ pub(crate) enum InputScalarKind {
Float32,
}

/// Visitor through which an input tensor is handed to a runtime's run context.
///
/// One method per supported element type; `InputScalar::push_tensor_to_ctx`
/// dispatches to the method matching `Self`. Per the PR author's note, this is
/// the input-side counterpart of the `OutputTensor` enum.
pub(crate) trait PushInputTensor {
fn push_int64(&mut self, tensor: Array<i64, impl Dimension + 'static>);
fn push_float32(&mut self, tensor: Array<f32, impl Dimension + 'static>);
}
Comment on lines +110 to +113
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

補足: これはenum OutputTensorと対になる感じかと思います。

enum InputTensor { Int64(_), Float32(_) }というのも考えたのですが、経路を遠くしてまで統一性を出す必要は無いかなと思い、この形にしました。


/// 推論操作の出力シグネチャ。
///
/// `::macros::InferenceOutputSignature`により、`TryFrom<OutputTensor>`も含めて導出される。
Expand Down Expand Up @@ -170,19 +184,3 @@ pub(crate) enum ExtractError {
/// Error raised when a model file cannot be decrypted.
// NOTE: the display message is user-facing Japanese ("invalid model file")
// and must not be altered or translated.
#[derive(Error, Debug)]
#[error("不正なモデルファイルです")]
pub(crate) struct DecryptModelError;

// FIXME: `onnxruntime::TypeToTensorElementDataType`に依存する代わりに、`InputScalar`から`runtimes`
// まではvisitor patternでつなぐ
// Sealing module: restricts which types may implement `InputScalar` to those
// ONNX Runtime accepts as tensor elements (via `TypeToTensorElementDataType`).
// NOTE(review): this is the code the surrounding diff deletes — the visitor
// trait `PushInputTensor` replaces it, per the FIXME above.
mod sealed {
// Public-in-private supertrait: external code cannot name it, so only the
// impls written here (i64, f32) can satisfy the outer `InputScalar` bound.
pub(crate) trait InputScalar: OnnxruntimeInputScalar {}

impl InputScalar for i64 {}
impl InputScalar for f32 {}

// Marker tying the seal to onnxruntime's own element-type trait.
pub(crate) trait OnnxruntimeInputScalar:
onnxruntime::TypeToTensorElementDataType
{
}

// Blanket impl: every onnxruntime element type qualifies automatically.
impl<T: onnxruntime::TypeToTensorElementDataType> OnnxruntimeInputScalar for T {}
}
35 changes: 25 additions & 10 deletions crates/voicevox_core/src/infer/runtimes/onnxruntime.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
use std::{fmt::Debug, vec};

use anyhow::anyhow;
use duplicate::duplicate_item;
use ndarray::{Array, Dimension};
use once_cell::sync::Lazy;
use onnxruntime::{
environment::Environment, GraphOptimizationLevel, LoggingLevel, TensorElementDataType,
TypeToTensorElementDataType,
};

use crate::{devices::SupportedDevices, error::ErrorRepr};

use self::assert_send::AssertSend;

use super::super::{
DecryptModelError, InferenceRuntime, InferenceSessionOptions, InputScalar, InputScalarKind,
OutputScalarKind, OutputTensor, ParamInfo,
DecryptModelError, InferenceRuntime, InferenceSessionOptions, InputScalarKind,
OutputScalarKind, OutputTensor, ParamInfo, PushInputTensor,
};

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
Expand Down Expand Up @@ -154,14 +156,6 @@ impl InferenceRuntime for Onnxruntime {
};
}

fn push_input(
input: Array<impl InputScalar, impl Dimension + 'static>,
ctx: &mut Self::RunContext<'_>,
) {
ctx.inputs
.push(Box::new(onnxruntime::session::NdArray::new(input)));
}

fn run(
OnnxruntimeRunContext { sess, mut inputs }: OnnxruntimeRunContext<'_>,
) -> anyhow::Result<Vec<OutputTensor>> {
Expand Down Expand Up @@ -193,6 +187,16 @@ pub(crate) struct OnnxruntimeRunContext<'sess> {
inputs: Vec<Box<dyn onnxruntime::session::AnyArray>>,
}

impl OnnxruntimeRunContext<'_> {
/// Appends one input tensor to the list passed to the session on `run`.
///
/// The element type only has to satisfy onnxruntime's own
/// `TypeToTensorElementDataType`; the typed `PushInputTensor` methods
/// funnel into this single helper.
fn push_input(
&mut self,
input: Array<impl TypeToTensorElementDataType + Debug + 'static, impl Dimension + 'static>,
) {
let wrapped = onnxruntime::session::NdArray::new(input);
self.inputs.push(Box::new(wrapped));
}
}

impl<'sess> From<&'sess mut AssertSend<onnxruntime::session::Session<'static>>>
for OnnxruntimeRunContext<'sess>
{
Expand All @@ -204,6 +208,17 @@ impl<'sess> From<&'sess mut AssertSend<onnxruntime::session::Session<'static>>>
}
}

// Wires the visitor interface to the ONNX Runtime context: each typed
// `push_*` method forwards to the generic `push_input` helper above.
// `duplicate_item` generates one forwarding method per row, substituting
// `method` with the trait method name and `T` with the element type.
impl PushInputTensor for OnnxruntimeRunContext<'_> {
#[duplicate_item(
method T;
[ push_int64 ] [ i64 ];
[ push_float32 ] [ f32 ];
)]
fn method(&mut self, tensor: Array<T, impl Dimension + 'static>) {
self.push_input(tensor);
}
}

// FIXME: 以下のことをちゃんと確認した後、onnxruntime-rs側で`Session`が`Send`であると宣言する。
// https://github.com/VOICEVOX/voicevox_core/issues/307#issuecomment-1276184614
mod assert_send {
Expand Down
16 changes: 14 additions & 2 deletions crates/voicevox_core_macros/src/inference_domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,21 @@ pub(crate) fn derive_inference_input_signature(
) -> R::RunContext<'_> {
let mut ctx = <R::RunContext<'_> as ::std::convert::From<_>>::from(sess);
#(
R::push_input(self.#field_names, &mut ctx);
__ArrayExt::push_to_ctx(self.#field_names, &mut ctx);
)*
ctx
return ctx;

trait __ArrayExt {
fn push_to_ctx(self, ctx: &mut impl crate::infer::PushInputTensor);
}

impl<A: crate::infer::InputScalar, D: ::ndarray::Dimension + 'static> __ArrayExt
for ::ndarray::Array<A, D>
{
fn push_to_ctx(self, ctx: &mut impl crate::infer::PushInputTensor) {
A::push_tensor_to_ctx(self, ctx);
}
}
}
}
});
Expand Down
Loading