Assume the runtime can handle inputs and outputs of arbitrary dimensions and arbitrary count
qryxip committed Nov 11, 2023
1 parent c40afd5 commit 81b5804
Showing 6 changed files with 226 additions and 143 deletions.
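
For orientation, this is the runtime abstraction after the change, consolidated from the infer.rs hunks below. Session construction, input pushing, and execution all live on `InferenceRuntime` itself; inputs are ndarray arrays of any supported scalar and any rank, and a run returns however many dynamically dimensioned tensors the model produces:

    pub(crate) trait InferenceRuntime: 'static {
        type Session: Sized + Send + 'static;
        type RunContext<'a>: RunContext<'a, Runtime = Self>;

        fn supported_devices() -> crate::Result<SupportedDevices>;

        fn new_session(
            model: impl FnOnce() -> std::result::Result<Vec<u8>, DecryptModelError>,
            options: InferenceSessionOptions,
        ) -> anyhow::Result<Self::Session>;

        // Any `InputScalar` (currently i64 or f32), any dimensionality.
        fn push_input(
            input: Array<impl InputScalar, impl Dimension + 'static>,
            ctx: &mut Self::RunContext<'_>,
        );

        // Any number of outputs, each with its rank checked at extraction time.
        fn run(ctx: Self::RunContext<'_>) -> anyhow::Result<Vec<AnyTensor>>;
    }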
crates/voicevox_core/src/engine/synthesis_engine.rs (12 changes: 2 additions & 10 deletions)

@@ -5,10 +5,7 @@ use std::sync::Arc;
 use super::full_context_label::Utterance;
 use super::open_jtalk::OpenJtalk;
 use super::*;
-use crate::infer::{
-    signatures::{Decode, PredictDuration, PredictIntonation},
-    InferenceRuntime, SupportsInferenceSignature,
-};
+use crate::infer::InferenceRuntime;
 use crate::numerics::F32Ext as _;
 use crate::InferenceCore;

@@ -26,12 +23,7 @@ pub(crate) struct SynthesisEngine<R: InferenceRuntime> {
     open_jtalk: Arc<OpenJtalk>,
 }

-impl<
-    R: SupportsInferenceSignature<PredictDuration>
-        + SupportsInferenceSignature<PredictIntonation>
-        + SupportsInferenceSignature<Decode>,
-> SynthesisEngine<R>
-{
+impl<R: InferenceRuntime> SynthesisEngine<R> {
     pub fn inference_core(&self) -> &InferenceCore<R> {
         &self.inference_core
     }
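
The effect of this file's change: `SynthesisEngine` no longer needs one `SupportsInferenceSignature<_>` bound per model signature, because the signature machinery (see the infer.rs diff below) dispatches through `InferenceInputSignature::make_run_context` and a runtime-agnostic `run`. A minimal sketch of the resulting pattern; `run_any` is a hypothetical helper, not part of the commit:

    // With the old design this needed a `SupportsInferenceSignature<_>` bound
    // for each of PredictDuration, PredictIntonation, and Decode; now the
    // single `InferenceRuntime` bound covers every signature.
    fn run_any<R, I>(
        cell: InferenceSessionCell<R, I>,
        input: I,
    ) -> crate::Result<<I::Signature as InferenceSignature>::Output>
    where
        R: InferenceRuntime,
        I: InferenceInputSignature,
    {
        cell.run(input)
    }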
crates/voicevox_core/src/infer.rs (114 changes: 70 additions & 44 deletions)
@@ -7,21 +7,28 @@ use std::{collections::HashMap, fmt::Debug, marker::PhantomData, sync::Arc};
 use derive_new::new;
 use easy_ext::ext;
 use enum_map::{Enum, EnumMap};
+use ndarray::{Array, ArrayD, Dimension, ShapeError};
 use thiserror::Error;

 use crate::{ErrorRepr, SupportedDevices};

 pub(crate) trait InferenceRuntime: 'static {
-    type Session: InferenceSession;
+    type Session: Sized + Send + 'static;
     type RunContext<'a>: RunContext<'a, Runtime = Self>;

     fn supported_devices() -> crate::Result<SupportedDevices>;
-}

-pub(crate) trait InferenceSession: Sized + Send + 'static {
-    fn new(
+    fn new_session(
         model: impl FnOnce() -> std::result::Result<Vec<u8>, DecryptModelError>,
         options: InferenceSessionOptions,
-    ) -> anyhow::Result<Self>;
+    ) -> anyhow::Result<Self::Session>;
+
+    fn push_input(
+        input: Array<impl InputScalar, impl Dimension + 'static>,
+        ctx: &mut Self::RunContext<'_>,
+    );
+
+    fn run(ctx: Self::RunContext<'_>) -> anyhow::Result<Vec<AnyTensor>>;
 }

 pub(crate) trait RunContext<'a>:
@@ -32,54 +39,56 @@ pub(crate) trait RunContext<'a>:

 #[ext(RunContextExt)]
 impl<'a, T: RunContext<'a>> T {
-    fn with_input<I>(mut self, tensor: I) -> Self
-    where
-        T::Runtime: SupportsInferenceInputTensor<I>,
-    {
+    fn with_input(mut self, tensor: Array<impl InputScalar, impl Dimension + 'static>) -> Self {
         T::Runtime::push_input(tensor, &mut self);
         self
     }
 }

-pub(crate) trait SupportsInferenceSignature<S: InferenceSignature>:
-    SupportsInferenceInputSignature<S::Input> + SupportsInferenceOutput<S::Output>
-{
+pub(crate) trait InferenceGroup {
+    type Kind: Copy + Enum;
 }

-impl<
-    R: SupportsInferenceInputSignature<S::Input> + SupportsInferenceOutput<S::Output>,
-    S: InferenceSignature,
-> SupportsInferenceSignature<S> for R
-{
+pub(crate) trait InferenceSignature: Sized + Send + 'static {
+    type Group: InferenceGroup;
+    type Input: InferenceInputSignature<Signature = Self>;
+    type Output: TryFrom<Vec<AnyTensor>, Error = anyhow::Error> + Send;
+    const INFERENCE: <Self::Group as InferenceGroup>::Kind;
 }

-pub(crate) trait SupportsInferenceInputTensor<I>: InferenceRuntime {
-    fn push_input(input: I, ctx: &mut Self::RunContext<'_>);
+pub(crate) trait InferenceInputSignature: Send + 'static {
+    type Signature: InferenceSignature<Input = Self>;
+    fn make_run_context<R: InferenceRuntime>(self, sess: &mut R::Session) -> R::RunContext<'_>;
 }

-pub(crate) trait SupportsInferenceInputSignature<I: InferenceInputSignature>:
-    InferenceRuntime
-{
-    fn make_run_context(sess: &mut Self::Session, input: I) -> Self::RunContext<'_>;
-}
+pub(crate) trait InputScalar: sealed::InputScalar + Debug + 'static {}
+
+impl InputScalar for i64 {}
+impl InputScalar for f32 {}

-pub(crate) trait SupportsInferenceOutput<O: Send>: InferenceRuntime {
-    fn run(ctx: Self::RunContext<'_>) -> anyhow::Result<O>;
+pub(crate) trait OutputScalar: Sized {
+    fn extract_dyn_dim(tensor: AnyTensor) -> std::result::Result<ArrayD<Self>, ExtractError>;
 }

-pub(crate) trait InferenceGroup {
-    type Kind: Copy + Enum;
+impl OutputScalar for f32 {
+    fn extract_dyn_dim(tensor: AnyTensor) -> std::result::Result<ArrayD<Self>, ExtractError> {
+        match tensor {
+            AnyTensor::Float32(tensor) => Ok(tensor),
+        }
+    }
 }

-pub(crate) trait InferenceSignature: Sized + Send + 'static {
-    type Group: InferenceGroup;
-    type Input: InferenceInputSignature<Signature = Self>;
-    type Output: Send;
-    const INFERENCE: <Self::Group as InferenceGroup>::Kind;
+pub(crate) enum AnyTensor {
+    Float32(ArrayD<f32>),
 }

-pub(crate) trait InferenceInputSignature: Send + 'static {
-    type Signature: InferenceSignature<Input = Self>;
+impl<A: OutputScalar, D: Dimension> TryFrom<AnyTensor> for Array<A, D> {
+    type Error = ExtractError;
+
+    fn try_from(tensor: AnyTensor) -> Result<Self, Self::Error> {
+        let this = A::extract_dyn_dim(tensor)?.into_dimensionality()?;
+        Ok(this)
+    }
 }

 pub(crate) struct InferenceSessionSet<G: InferenceGroup, R: InferenceRuntime>(
@@ -94,7 +103,7 @@ impl<G: InferenceGroup, R: InferenceRuntime> InferenceSessionSet<G, R> {
         let mut sessions = model_bytes
             .iter()
             .map(|(k, m)| {
-                let sess = R::Session::new(|| model_file::decrypt(m), options(k))?;
+                let sess = R::new_session(|| model_file::decrypt(m), options(k))?;
                 Ok((k.into_usize(), std::sync::Mutex::new(sess).into()))
             })
             .collect::<anyhow::Result<HashMap<_, _>>>()?;
@@ -123,19 +132,16 @@ pub(crate) struct InferenceSessionCell<R: InferenceRuntime, I> {
     marker: PhantomData<fn(I)>,
 }

-impl<
-    R: SupportsInferenceInputSignature<I>
-        + SupportsInferenceOutput<<I::Signature as InferenceSignature>::Output>,
-    I: InferenceInputSignature,
-> InferenceSessionCell<R, I>
-{
+impl<R: InferenceRuntime, I: InferenceInputSignature> InferenceSessionCell<R, I> {
     pub(crate) fn run(
         self,
         input: I,
     ) -> crate::Result<<I::Signature as InferenceSignature>::Output> {
         let inner = &mut self.inner.lock().unwrap();
-        let ctx = R::make_run_context(inner, input);
-        R::run(ctx).map_err(|e| ErrorRepr::InferenceFailed(e).into())
+        let ctx = input.make_run_context::<R>(inner);
+        R::run(ctx)
+            .and_then(TryInto::try_into)
+            .map_err(|e| ErrorRepr::InferenceFailed(e).into())
     }
 }

@@ -145,6 +151,26 @@ pub(crate) struct InferenceSessionOptions {
     pub(crate) use_gpu: bool,
 }

+#[derive(Error, Debug)]
+pub(crate) enum ExtractError {
+    #[error(transparent)]
+    Shape(#[from] ShapeError),
+}
+
 #[derive(Error, Debug)]
 #[error("不正なモデルファイルです")]
 pub(crate) struct DecryptModelError;
+
+mod sealed {
+    pub(crate) trait InputScalar: OnnxruntimeInputScalar {}
+
+    impl InputScalar for i64 {}
+    impl InputScalar for f32 {}
+
+    pub(crate) trait OnnxruntimeInputScalar:
+        onnxruntime::TypeToTensorElementDataType
+    {
+    }
+
+    impl<T: onnxruntime::TypeToTensorElementDataType> OnnxruntimeInputScalar for T {}
+}
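
Output extraction is now a two-step fallible conversion: `OutputScalar::extract_dyn_dim` selects the tensor's scalar variant, and `into_dimensionality` checks its rank, so each output signature can recover statically typed arrays from the runtime's `Vec<AnyTensor>`. A minimal usage sketch; `as_matrix` is hypothetical, the types are from the diff above:

    use ndarray::Array2;

    // Convert a dynamically dimensioned f32 tensor into a 2-D array,
    // failing with `ExtractError::Shape` if the rank does not match.
    fn as_matrix(tensor: AnyTensor) -> Result<Array2<f32>, ExtractError> {
        tensor.try_into()
    }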
crates/voicevox_core/src/infer/runtimes/onnxruntime.rs (71 changes: 40 additions & 31 deletions)
@@ -3,16 +3,16 @@ use std::fmt::Debug;
 use ndarray::{Array, Dimension};
 use once_cell::sync::Lazy;
 use onnxruntime::{
-    environment::Environment, GraphOptimizationLevel, LoggingLevel, TypeToTensorElementDataType,
+    environment::Environment, GraphOptimizationLevel, LoggingLevel, TensorElementDataType,
 };

 use self::assert_send::AssertSend;
 use crate::{
     devices::SupportedDevices,
     error::ErrorRepr,
     infer::{
-        DecryptModelError, InferenceRuntime, InferenceSession, InferenceSessionOptions, RunContext,
-        SupportsInferenceInputTensor, SupportsInferenceOutput,
+        AnyTensor, DecryptModelError, InferenceRuntime, InferenceSessionOptions, InputScalar,
+        RunContext,
     },
 };
@@ -44,13 +44,11 @@ impl InferenceRuntime for Onnxruntime {
             dml: dml_support,
         })
     }
-}

-impl InferenceSession for AssertSend<onnxruntime::session::Session<'static>> {
-    fn new(
+    fn new_session(
         model: impl FnOnce() -> std::result::Result<Vec<u8>, DecryptModelError>,
         options: InferenceSessionOptions,
-    ) -> anyhow::Result<Self> {
+    ) -> anyhow::Result<Self::Session> {
         let mut builder = ENVIRONMENT
             .new_session_builder()?
             .with_optimization_level(GraphOptimizationLevel::Basic)?
@@ -75,8 +73,8 @@ impl InferenceSession for AssertSend<onnxruntime::session::Session<'static>> {
         }

         let model = model()?;
-        let this = builder.with_model_from_memory(model)?.into();
-        return Ok(this);
+        let sess = builder.with_model_from_memory(model)?.into();
+        return Ok(sess);

         static ENVIRONMENT: Lazy<Environment> = Lazy::new(|| {
             Environment::builder()
@@ -92,6 +90,39 @@ impl InferenceSession for AssertSend<onnxruntime::session::Session<'static>> {
             LoggingLevel::Warning
         };
     }
+
+    fn push_input(
+        input: Array<impl InputScalar, impl Dimension + 'static>,
+        ctx: &mut Self::RunContext<'_>,
+    ) {
+        ctx.inputs
+            .push(Box::new(onnxruntime::session::NdArray::new(input)));
+    }
+
+    fn run(
+        OnnxruntimeRunContext { sess, mut inputs }: OnnxruntimeRunContext<'_>,
+    ) -> anyhow::Result<Vec<AnyTensor>> {
+        // FIXME: Only `f32` is supported for now. The data type can be read from the
+        // session at run time, so supporting other types is probably feasible, but the
+        // move to the ort crate will likely come before that is needed, so leave as is.
+
+        if !sess
+            .outputs
+            .iter()
+            .all(|info| matches!(info.output_type, TensorElementDataType::Float))
+        {
+            unimplemented!(
+                "currently only `ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT` is supported for output",
+            );
+        }
+
+        let outputs = sess.run::<f32>(inputs.iter_mut().map(|t| &mut **t as &mut _).collect())?;

+        Ok(outputs
+            .iter()
+            .map(|o| AnyTensor::Float32((*o).clone().into_owned()))
+            .collect())
+    }
 }

 pub(crate) struct OnnxruntimeRunContext<'sess> {
@@ -114,28 +145,6 @@ impl<'sess> RunContext<'sess> for OnnxruntimeRunContext<'sess> {
     type Runtime = Onnxruntime;
 }

-impl<A: TypeToTensorElementDataType + Debug + 'static, D: Dimension + 'static>
-    SupportsInferenceInputTensor<Array<A, D>> for Onnxruntime
-{
-    fn push_input(input: Array<A, D>, ctx: &mut Self::RunContext<'_>) {
-        ctx.inputs
-            .push(Box::new(onnxruntime::session::NdArray::new(input)));
-    }
-}
-
-impl<T: Send + TypeToTensorElementDataType + Debug + Clone> SupportsInferenceOutput<(Vec<T>,)>
-    for Onnxruntime
-{
-    fn run(
-        OnnxruntimeRunContext { sess, mut inputs }: OnnxruntimeRunContext<'_>,
-    ) -> anyhow::Result<(Vec<T>,)> {
-        let outputs = sess.run(inputs.iter_mut().map(|t| &mut **t as &mut _).collect())?;
-
-        // FIXME: properly reject outputs with two or more tensors or two or more dimensions
-        Ok((outputs[0].as_slice().unwrap().to_owned(),))
-    }
-}
-
 // FIXME: After confirming the following, declare `Session` as `Send` on the onnxruntime-rs side.
 // https://github.com/VOICEVOX/voicevox_core/issues/307#issuecomment-1276184614
 mod assert_send {
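
Tying this back to the commit message: nothing in the interface now fixes the count, scalar type, or rank of the tensors in a run. A minimal sketch under stated assumptions: it relies on the `From<&mut Session>` conversion that the `RunContext` trait requires (the corresponding impl sits in a collapsed part of this file), assumes `RunContextExt` from infer.rs is in scope for `with_input`, and `run_mixed` itself is hypothetical:

    use ndarray::{arr0, arr1, arr2};

    // One run context mixing a 0-D i64 scalar, a 1-D f32 vector, and a
    // 2-D f32 matrix; the runtime returns however many outputs the model has.
    fn run_mixed(
        sess: &mut AssertSend<onnxruntime::session::Session<'static>>,
    ) -> anyhow::Result<Vec<AnyTensor>> {
        let ctx = OnnxruntimeRunContext::from(sess)
            .with_input(arr0(42_i64))
            .with_input(arr1(&[1.0_f32, 2.0, 3.0]))
            .with_input(arr2(&[[0.5_f32], [0.25]]));
        Onnxruntime::run(ctx)
    }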