Skip to content

Commit

Permalink
ONNX Runtimeとモデルのシグネチャを隔離する
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Nov 5, 2023
1 parent 9ae1110 commit 5ff2b59
Show file tree
Hide file tree
Showing 11 changed files with 397 additions and 269 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ easy-ext = "1.0.1"
fs-err = { version = "2.9.0", features = ["tokio"] }
futures = "0.3.26"
itertools = "0.10.5"
ndarray = "0.15.6"
once_cell = "1.18.0"
regex = "1.10.0"
rstest = "0.15.0"
Expand Down
1 change: 1 addition & 0 deletions crates/voicevox_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ futures.workspace = true
indexmap = { version = "2.0.0", features = ["serde"] }
itertools.workspace = true
nanoid = "0.4.0"
ndarray.workspace = true
once_cell.workspace = true
regex.workspace = true
serde.workspace = true
Expand Down
92 changes: 92 additions & 0 deletions crates/voicevox_core/src/infer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
pub(crate) mod runtimes;
pub(crate) mod signatures;

use std::{fmt::Debug, marker::PhantomData, sync::Arc};

use derive_new::new;
use ndarray::{Array, Dimension, LinalgScalar};
use thiserror::Error;

/// An inference backend, described purely at the type level.
///
/// An implementor names the session type it loads models into and the builder
/// type used to assemble a single run against such a session.
pub(crate) trait InferenceRuntime: Copy {
/// A loaded model for this runtime.
type Session: Session;
/// Builder that collects the inputs for one run; its `Runtime` must be this runtime.
type RunBuilder<'a>: RunBuilder<'a, Runtime = Self>;
}

/// A model loaded into an inference runtime.
pub(crate) trait Session: Sized + 'static {
/// Creates a session.
///
/// `model` is a closure that yields the model bytes or a `DecryptModelError`;
/// `options` carries the thread count and GPU flag. Any failure is reported
/// through `anyhow::Result`.
fn new(
model: impl FnOnce() -> std::result::Result<Vec<u8>, DecryptModelError>,
options: SessionOptions,
) -> anyhow::Result<Self>;
}

/// Builder for a single inference run.
///
/// It is created (via `From`) from a mutable borrow of the runtime's session
/// that lives for `'a`.
pub(crate) trait RunBuilder<'a>:
From<&'a mut <Self::Runtime as InferenceRuntime>::Session>
{
/// The runtime this builder belongs to.
type Runtime: InferenceRuntime;
/// Supplies one input tensor and returns `self` so calls can be chained.
fn input(&mut self, tensor: Array<impl InputScalar, impl Dimension + 'static>) -> &mut Self;
}

/// Element types that may be used in input tensors.
///
/// Besides `LinalgScalar` and `Debug`, an implementor must also implement
/// `sealed::OnnxruntimeInputScalar`, which is declared in a private module.
pub(crate) trait InputScalar: LinalgScalar + Debug + sealed::OnnxruntimeInputScalar {}

// The two element types accepted as inputs here.
impl InputScalar for i64 {}
impl InputScalar for f32 {}

/// The input side of a model: a value of the implementing type holds the
/// tensors for one run, and the associated items say where to run it and what
/// comes back.
pub(crate) trait Signature: Sized + Send + Sync + 'static {
/// The collection of sessions this signature's session is taken from.
type SessionSet<R: InferenceRuntime>;
/// The type produced by running this signature.
type Output;
/// Picks, out of `session_set`, the shared session typed for this signature.
fn get_session<R: InferenceRuntime>(
session_set: &Self::SessionSet<R>,
) -> &Arc<std::sync::Mutex<TypedSession<R, Self>>>;
/// Consumes `self`, handing its tensors to the run builder `ctx`.
fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>);
}

/// A result type that can be produced by runtime `R`.
pub(crate) trait Output<R: InferenceRuntime>: Sized + Send {
/// Consumes the run builder `ctx` and returns the output, or an error.
fn run(ctx: R::RunBuilder<'_>) -> anyhow::Result<Self>;
}

/// A runtime session tagged at the type level with the signature `I` it is
/// meant to be run with.
pub(crate) struct TypedSession<R: InferenceRuntime, I> {
// The underlying runtime session.
inner: R::Session,
// Records `I` in the type only; no value of type `I` is stored.
marker: PhantomData<fn(I)>,
}

impl<R: InferenceRuntime, S: Signature> TypedSession<R, S> {
    /// Creates the underlying runtime session from `model` and `options` and
    /// tags it with the signature type `S`.
    ///
    /// Any error from the runtime's `Session::new` is returned unchanged.
    pub(crate) fn new(
        model: impl FnOnce() -> std::result::Result<Vec<u8>, DecryptModelError>,
        options: SessionOptions,
    ) -> anyhow::Result<Self> {
        R::Session::new(model, options).map(|inner| Self {
            inner,
            marker: PhantomData,
        })
    }

    /// Runs the session once with the inputs held by `sig`.
    ///
    /// A run builder is created over the inner session, `sig` feeds its
    /// tensors into it, and the builder is then consumed to produce
    /// `S::Output`.
    pub(crate) fn run(&mut self, sig: S) -> anyhow::Result<S::Output>
    where
        S::Output: Output<R>,
    {
        let mut builder = R::RunBuilder::from(&mut self.inner);
        sig.input(&mut builder);
        <S::Output as Output<R>>::run(builder)
    }
}

/// Options passed to `Session::new` when a session is created.
#[derive(new, Clone, Copy)]
pub(crate) struct SessionOptions {
/// Number of CPU threads the session should use.
pub(crate) cpu_num_threads: u16,
/// Whether a GPU execution provider should be requested.
pub(crate) use_gpu: bool,
}

/// Error type of the model-providing closure given to `Session::new`.
///
/// Its display message is Japanese and reads roughly "invalid model file".
#[derive(Error, Debug)]
#[error("不正なモデルファイルです")]
pub(crate) struct DecryptModelError;

// Private module holding the supertrait required by `InputScalar`.
mod sealed {
/// Scalar types that onnxruntime can map to a tensor element data type.
pub(crate) trait OnnxruntimeInputScalar:
onnxruntime::TypeToTensorElementDataType
{
}

// Mirrors the `InputScalar` impls: `i64` and `f32`.
impl OnnxruntimeInputScalar for i64 {}
impl OnnxruntimeInputScalar for f32 {}
}
3 changes: 3 additions & 0 deletions crates/voicevox_core/src/infer/runtimes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
mod onnxruntime;

pub(crate) use self::onnxruntime::Onnxruntime;
136 changes: 136 additions & 0 deletions crates/voicevox_core/src/infer/runtimes/onnxruntime.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
use ndarray::{Array, Dimension};
use once_cell::sync::Lazy;
use onnxruntime::{environment::Environment, GraphOptimizationLevel, LoggingLevel};

use crate::infer::{
DecryptModelError, InferenceRuntime, InputScalar, Output, RunBuilder, Session, SessionOptions,
};

pub(crate) use self::assert_send::AssertSend;

/// Type-level marker for the ONNX Runtime backend.
///
/// The enum has no variants, so no value of it can exist; it is only used as a
/// type argument.
#[derive(Clone, Copy)]
pub(crate) enum Onnxruntime {}

// Wires the ONNX Runtime backend into the `InferenceRuntime` abstraction.
impl InferenceRuntime for Onnxruntime {
// Sessions are onnxruntime sessions wrapped in `AssertSend`.
type Session = AssertSend<onnxruntime::session::Session<'static>>;
// Runs are assembled with `OnnxruntimeInferenceBuilder`.
type RunBuilder<'a> = OnnxruntimeInferenceBuilder<'a>;
}

impl Session for AssertSend<onnxruntime::session::Session<'static>> {
fn new(
model: impl FnOnce() -> std::result::Result<Vec<u8>, DecryptModelError>,
options: SessionOptions,
) -> anyhow::Result<Self> {
let mut builder = ENVIRONMENT
.new_session_builder()?
.with_optimization_level(GraphOptimizationLevel::Basic)?
.with_intra_op_num_threads(options.cpu_num_threads.into())?
.with_inter_op_num_threads(options.cpu_num_threads.into())?;

if options.use_gpu {
#[cfg(feature = "directml")]
{
use onnxruntime::ExecutionMode;

builder = builder
.with_disable_mem_pattern()?
.with_execution_mode(ExecutionMode::ORT_SEQUENTIAL)?
.with_append_execution_provider_directml(0)?;
}

#[cfg(not(feature = "directml"))]
{
builder = builder.with_append_execution_provider_cuda(Default::default())?;
}
}

let model = model()?;
let this = builder.with_model_from_memory(model)?.into();
return Ok(this);

static ENVIRONMENT: Lazy<Environment> = Lazy::new(|| {
Environment::builder()
.with_name(env!("CARGO_PKG_NAME"))
.with_log_level(LOGGING_LEVEL)
.build()
.unwrap()
});

const LOGGING_LEVEL: LoggingLevel = if cfg!(debug_assertions) {
LoggingLevel::Verbose
} else {
LoggingLevel::Warning
};
}
}

/// Run builder for the ONNX Runtime backend: a borrowed session plus the input
/// tensors collected so far.
pub(crate) struct OnnxruntimeInferenceBuilder<'sess> {
// The session the run will be executed on.
sess: &'sess mut AssertSend<onnxruntime::session::Session<'static>>,
// Type-erased input tensors, in the order they were supplied.
inputs: Vec<Box<dyn onnxruntime::session::AnyArray>>,
}

impl<'sess> From<&'sess mut AssertSend<onnxruntime::session::Session<'static>>>
    for OnnxruntimeInferenceBuilder<'sess>
{
    /// Starts a run on `sess` with no inputs collected yet.
    fn from(sess: &'sess mut AssertSend<onnxruntime::session::Session<'static>>) -> Self {
        let inputs = Vec::new();
        Self { sess, inputs }
    }
}

impl<'sess> RunBuilder<'sess> for OnnxruntimeInferenceBuilder<'sess> {
    type Runtime = Onnxruntime;

    /// Wraps `tensor` in an onnxruntime `NdArray`, boxes it as a type-erased
    /// `AnyArray`, and appends it to the list of inputs.
    fn input(&mut self, tensor: Array<impl InputScalar, impl Dimension + 'static>) -> &mut Self {
        let erased: Box<dyn onnxruntime::session::AnyArray> =
            Box::new(onnxruntime::session::NdArray::new(tensor));
        self.inputs.push(erased);
        self
    }
}

impl Output<Onnxruntime> for (Vec<f32>,) {
    /// Runs the session with the collected inputs and returns the first output
    /// tensor's elements as a `Vec<f32>`.
    ///
    /// # Errors
    ///
    /// Returns an error if the run itself fails, if the model yields no output
    /// at all, or if the first output cannot be viewed as one contiguous slice.
    fn run(
        OnnxruntimeInferenceBuilder { sess, mut inputs }: OnnxruntimeInferenceBuilder<'_>,
    ) -> anyhow::Result<Self> {
        let outputs = sess.run(inputs.iter_mut().map(|t| &mut **t as &mut _).collect())?;

        // FIXME: properly reject two or more outputs and outputs with two or more dimensions.
        // Until then, at least report the failure cases as errors instead of panicking.
        let first = outputs
            .first()
            .ok_or_else(|| anyhow::anyhow!("the model did not produce any output"))?;
        let values = first.as_slice().ok_or_else(|| {
            anyhow::anyhow!("the first output is not a contiguous, standard-layout tensor")
        })?;
        Ok((values.to_owned(),))
    }
}

// FIXME: once the following has been properly confirmed, declare `Session` as `Send` on the
// onnxruntime-rs side.
// https://github.com/VOICEVOX/voicevox_core/issues/307#issuecomment-1276184614
mod assert_send {
    use std::ops::{Deref, DerefMut};

    /// Wrapper used to assert that an onnxruntime session may be moved across
    /// threads.
    ///
    /// The field is private and the only conversion offered is from
    /// `onnxruntime::session::Session<'static>`.
    pub(crate) struct AssertSend<T>(T);

    impl From<onnxruntime::session::Session<'static>>
        for AssertSend<onnxruntime::session::Session<'static>>
    {
        fn from(session: onnxruntime::session::Session<'static>) -> Self {
            Self(session)
        }
    }

    impl<T> Deref for AssertSend<T> {
        type Target = T;

        fn deref(&self) -> &Self::Target {
            &self.0
        }
    }

    impl<T> DerefMut for AssertSend<T> {
        fn deref_mut(&mut self) -> &mut Self::Target {
            &mut self.0
        }
    }

    // SAFETY: `Session` is probably "send"able (see the issue linked above this module).
    // The impl is limited to the session type on purpose: that assumption says nothing about
    // arbitrary `T`, so a blanket `impl<T>` would assert `Send` for types it was never
    // checked against.
    // `unsafe_code` is allowed here because asserting `Send` is this wrapper's whole purpose.
    #[allow(unsafe_code)]
    unsafe impl Send for AssertSend<onnxruntime::session::Session<'static>> {}
}
87 changes: 87 additions & 0 deletions crates/voicevox_core/src/infer/signatures.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
use std::sync::Arc;

use ndarray::{Array0, Array1, Array2};

use crate::infer::{InferenceRuntime, RunBuilder, Signature, TypedSession};

/// The three sessions used for inference, each typed with the signature it
/// accepts and shared behind `Arc<Mutex<..>>`.
pub(crate) struct SessionSet<R: InferenceRuntime> {
/// Session run with `PredictDuration` inputs.
pub(crate) predict_duration: Arc<std::sync::Mutex<TypedSession<R, PredictDuration>>>,
/// Session run with `PredictIntonation` inputs.
pub(crate) predict_intonation: Arc<std::sync::Mutex<TypedSession<R, PredictIntonation>>>,
/// Session run with `Decode` inputs.
pub(crate) decode: Arc<std::sync::Mutex<TypedSession<R, Decode>>>,
}

/// Input tensors for the `predict_duration` session.
///
/// Both fields are one-dimensional `i64` arrays; they are fed to the model in
/// declaration order.
pub(crate) struct PredictDuration {
pub(crate) phoneme: Array1<i64>,
pub(crate) speaker_id: Array1<i64>,
}

impl Signature for PredictDuration {
type SessionSet<R: InferenceRuntime> = SessionSet<R>;
type Output = (Vec<f32>,);

fn get_session<R: InferenceRuntime>(
session_set: &Self::SessionSet<R>,
) -> &Arc<std::sync::Mutex<TypedSession<R, Self>>> {
&session_set.predict_duration
}

fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>) {
ctx.input(self.phoneme).input(self.speaker_id);
}
}

/// Input tensors for the `predict_intonation` session.
///
/// `length` is a zero-dimensional `i64` array; every other field is a
/// one-dimensional `i64` array. They are fed to the model in declaration order.
pub(crate) struct PredictIntonation {
pub(crate) length: Array0<i64>,
pub(crate) vowel_phoneme: Array1<i64>,
pub(crate) consonant_phoneme: Array1<i64>,
pub(crate) start_accent: Array1<i64>,
pub(crate) end_accent: Array1<i64>,
pub(crate) start_accent_phrase: Array1<i64>,
pub(crate) end_accent_phrase: Array1<i64>,
pub(crate) speaker_id: Array1<i64>,
}

impl Signature for PredictIntonation {
type SessionSet<R: InferenceRuntime> = SessionSet<R>;
type Output = (Vec<f32>,);

fn get_session<R: InferenceRuntime>(
session_set: &Self::SessionSet<R>,
) -> &Arc<std::sync::Mutex<TypedSession<R, Self>>> {
&session_set.predict_intonation
}

fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>) {
ctx.input(self.length)
.input(self.vowel_phoneme)
.input(self.consonant_phoneme)
.input(self.start_accent)
.input(self.end_accent)
.input(self.start_accent_phrase)
.input(self.end_accent_phrase)
.input(self.speaker_id);
}
}

/// Input tensors for the `decode` session.
///
/// `f0` and `phoneme` are two-dimensional `f32` arrays and `speaker_id` is a
/// one-dimensional `i64` array; they are fed to the model in declaration order.
pub(crate) struct Decode {
pub(crate) f0: Array2<f32>,
pub(crate) phoneme: Array2<f32>,
pub(crate) speaker_id: Array1<i64>,
}

impl Signature for Decode {
type SessionSet<R: InferenceRuntime> = SessionSet<R>;
type Output = (Vec<f32>,);

fn get_session<R: InferenceRuntime>(
session_set: &Self::SessionSet<R>,
) -> &Arc<std::sync::Mutex<TypedSession<R, Self>>> {
&session_set.decode
}

fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>) {
ctx.input(self.f0)
.input(self.phoneme)
.input(self.speaker_id);
}
}
Loading

0 comments on commit 5ff2b59

Please sign in to comment.