diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index c6b81348a..c816c9899 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -6,6 +6,7 @@ pub(crate) mod status; use std::{borrow::Cow, fmt::Debug}; use derive_new::new; +use duplicate::duplicate_item; use enum_map::{Enum, EnumMap}; use ndarray::{Array, ArrayD, Dimension, ShapeError}; use thiserror::Error; @@ -14,7 +15,7 @@ use crate::SupportedDevices; pub(crate) trait InferenceRuntime: 'static { type Session: Sized + Send + 'static; - type RunContext<'a>: From<&'a mut Self::Session>; + type RunContext<'a>: From<&'a mut Self::Session> + PushInputTensor; fn supported_devices() -> crate::Result; @@ -28,11 +29,6 @@ pub(crate) trait InferenceRuntime: 'static { Vec>, )>; - fn push_input( - input: Array, - ctx: &mut Self::RunContext<'_>, - ); - fn run(ctx: Self::RunContext<'_>) -> anyhow::Result>; } @@ -77,16 +73,29 @@ pub(crate) trait InferenceInputSignature: Send + 'static { fn make_run_context(self, sess: &mut R::Session) -> R::RunContext<'_>; } -pub(crate) trait InputScalar: sealed::InputScalar + Debug + 'static { +pub(crate) trait InputScalar: Sized { const KIND: InputScalarKind; -} -impl InputScalar for i64 { - const KIND: InputScalarKind = InputScalarKind::Int64; + fn push_tensor_to_ctx( + tensor: Array, + visitor: &mut impl PushInputTensor, + ); } -impl InputScalar for f32 { - const KIND: InputScalarKind = InputScalarKind::Float32; +#[duplicate_item( + T KIND_VAL push; + [ i64 ] [ InputScalarKind::Int64 ] [ push_int64 ]; + [ f32 ] [ InputScalarKind::Float32 ] [ push_float32 ]; +)] +impl InputScalar for T { + const KIND: InputScalarKind = KIND_VAL; + + fn push_tensor_to_ctx( + tensor: Array, + ctx: &mut impl PushInputTensor, + ) { + ctx.push(tensor); + } } #[derive(Clone, Copy, PartialEq, derive_more::Display)] @@ -98,6 +107,11 @@ pub(crate) enum InputScalarKind { Float32, } +pub(crate) trait PushInputTensor { + fn push_int64(&mut self, tensor: Array); + fn push_float32(&mut self, tensor: Array); +} + /// 推論操作の出力シグネチャ。 /// /// `::macros::InferenceOutputSignature`により、`TryFrom`も含めて導出される。 @@ -170,19 +184,3 @@ pub(crate) enum ExtractError { #[derive(Error, Debug)] #[error("不正なモデルファイルです")] pub(crate) struct DecryptModelError; - -// FIXME: `onnxruntime::TypeToTensorElementDataType`に依存する代わりに、`InputScalar`から`runtimes` -// まではvisitor patternでつなぐ -mod sealed { - pub(crate) trait InputScalar: OnnxruntimeInputScalar {} - - impl InputScalar for i64 {} - impl InputScalar for f32 {} - - pub(crate) trait OnnxruntimeInputScalar: - onnxruntime::TypeToTensorElementDataType - { - } - - impl OnnxruntimeInputScalar for T {} -} diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs index ca5b28aaa..a22503055 100644 --- a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs +++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs @@ -1,10 +1,12 @@ use std::{fmt::Debug, vec}; use anyhow::anyhow; +use duplicate::duplicate_item; use ndarray::{Array, Dimension}; use once_cell::sync::Lazy; use onnxruntime::{ environment::Environment, GraphOptimizationLevel, LoggingLevel, TensorElementDataType, + TypeToTensorElementDataType, }; use crate::{devices::SupportedDevices, error::ErrorRepr}; @@ -12,8 +14,8 @@ use crate::{devices::SupportedDevices, error::ErrorRepr}; use self::assert_send::AssertSend; use super::super::{ - DecryptModelError, InferenceRuntime, InferenceSessionOptions, InputScalar, InputScalarKind, - OutputScalarKind, OutputTensor, ParamInfo, + DecryptModelError, InferenceRuntime, InferenceSessionOptions, InputScalarKind, + OutputScalarKind, OutputTensor, ParamInfo, PushInputTensor, }; #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] @@ -154,14 +156,6 @@ impl InferenceRuntime for Onnxruntime { }; } - fn push_input( - input: Array, - ctx: &mut Self::RunContext<'_>, - ) { - ctx.inputs - .push(Box::new(onnxruntime::session::NdArray::new(input))); - } - fn run( OnnxruntimeRunContext { sess, mut inputs }: OnnxruntimeRunContext<'_>, ) -> anyhow::Result> { @@ -193,6 +187,16 @@ pub(crate) struct OnnxruntimeRunContext<'sess> { inputs: Vec>, } +impl OnnxruntimeRunContext<'_> { + fn push_input( + &mut self, + input: Array, + ) { + self.inputs + .push(Box::new(onnxruntime::session::NdArray::new(input))); + } +} + impl<'sess> From<&'sess mut AssertSend>> for OnnxruntimeRunContext<'sess> { @@ -204,6 +208,17 @@ impl<'sess> From<&'sess mut AssertSend>> } } +impl PushInputTensor for OnnxruntimeRunContext<'_> { + #[duplicate_item( + method T; + [ push_int64 ] [ i64 ]; + [ push_float32 ] [ f32 ]; + )] + fn method(&mut self, tensor: Array) { + self.push_input(tensor); + } +} + // FIXME: 以下のことをちゃんと確認した後、onnxruntime-rs側で`Session`が`Send`であると宣言する。 // https://github.com/VOICEVOX/voicevox_core/issues/307#issuecomment-1276184614 mod assert_send { diff --git a/crates/voicevox_core_macros/src/inference_domain.rs b/crates/voicevox_core_macros/src/inference_domain.rs index 4a447d37d..72bc4d18a 100644 --- a/crates/voicevox_core_macros/src/inference_domain.rs +++ b/crates/voicevox_core_macros/src/inference_domain.rs @@ -226,9 +226,21 @@ pub(crate) fn derive_inference_input_signature( ) -> R::RunContext<'_> { let mut ctx = as ::std::convert::From<_>>::from(sess); #( - R::push_input(self.#field_names, &mut ctx); + __ArrayExt::push_to_ctx(self.#field_names, &mut ctx); )* - ctx + return ctx; + + trait __ArrayExt { + fn push_to_ctx(self, ctx: &mut impl crate::infer::PushInputTensor); + } + + impl __ArrayExt + for ::ndarray::Array + { + fn push_to_ctx(self, ctx: &mut impl crate::infer::PushInputTensor) { + A::push_tensor_to_ctx(self, ctx); + } + } } } });