-
Notifications
You must be signed in to change notification settings - Fork 120
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
397 additions
and
269 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
pub(crate) mod runtimes; | ||
pub(crate) mod signatures; | ||
|
||
use std::{fmt::Debug, marker::PhantomData, sync::Arc}; | ||
|
||
use derive_new::new; | ||
use ndarray::{Array, Dimension, LinalgScalar}; | ||
use thiserror::Error; | ||
|
||
pub(crate) trait InferenceRuntime: Copy { | ||
type Session: Session; | ||
type RunBuilder<'a>: RunBuilder<'a, Runtime = Self>; | ||
} | ||
|
||
pub(crate) trait Session: Sized + 'static { | ||
fn new( | ||
model: impl FnOnce() -> std::result::Result<Vec<u8>, DecryptModelError>, | ||
options: SessionOptions, | ||
) -> anyhow::Result<Self>; | ||
} | ||
|
||
pub(crate) trait RunBuilder<'a>: | ||
From<&'a mut <Self::Runtime as InferenceRuntime>::Session> | ||
{ | ||
type Runtime: InferenceRuntime; | ||
fn input(&mut self, tensor: Array<impl InputScalar, impl Dimension + 'static>) -> &mut Self; | ||
} | ||
|
||
pub(crate) trait InputScalar: LinalgScalar + Debug + sealed::OnnxruntimeInputScalar {} | ||
|
||
impl InputScalar for i64 {} | ||
impl InputScalar for f32 {} | ||
|
||
pub(crate) trait Signature: Sized + Send + Sync + 'static { | ||
type SessionSet<R: InferenceRuntime>; | ||
type Output; | ||
fn get_session<R: InferenceRuntime>( | ||
session_set: &Self::SessionSet<R>, | ||
) -> &Arc<std::sync::Mutex<TypedSession<R, Self>>>; | ||
fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>); | ||
} | ||
|
||
pub(crate) trait Output<R: InferenceRuntime>: Sized + Send { | ||
fn run(ctx: R::RunBuilder<'_>) -> anyhow::Result<Self>; | ||
} | ||
|
||
pub(crate) struct TypedSession<R: InferenceRuntime, I> { | ||
inner: R::Session, | ||
marker: PhantomData<fn(I)>, | ||
} | ||
|
||
impl<R: InferenceRuntime, S: Signature> TypedSession<R, S> { | ||
pub(crate) fn new( | ||
model: impl FnOnce() -> std::result::Result<Vec<u8>, DecryptModelError>, | ||
options: SessionOptions, | ||
) -> anyhow::Result<Self> { | ||
let inner = R::Session::new(model, options)?; | ||
Ok(Self { | ||
inner, | ||
marker: PhantomData, | ||
}) | ||
} | ||
|
||
pub(crate) fn run(&mut self, sig: S) -> anyhow::Result<S::Output> | ||
where | ||
S::Output: Output<R>, | ||
{ | ||
let mut ctx = R::RunBuilder::from(&mut self.inner); | ||
sig.input(&mut ctx); | ||
S::Output::run(ctx) | ||
} | ||
} | ||
|
||
#[derive(new, Clone, Copy)] | ||
pub(crate) struct SessionOptions { | ||
pub(crate) cpu_num_threads: u16, | ||
pub(crate) use_gpu: bool, | ||
} | ||
|
||
#[derive(Error, Debug)] | ||
#[error("不正なモデルファイルです")] | ||
pub(crate) struct DecryptModelError; | ||
|
||
mod sealed { | ||
pub(crate) trait OnnxruntimeInputScalar: | ||
onnxruntime::TypeToTensorElementDataType | ||
{ | ||
} | ||
|
||
impl OnnxruntimeInputScalar for i64 {} | ||
impl OnnxruntimeInputScalar for f32 {} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
mod onnxruntime; | ||
|
||
pub(crate) use self::onnxruntime::Onnxruntime; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
use ndarray::{Array, Dimension}; | ||
use once_cell::sync::Lazy; | ||
use onnxruntime::{environment::Environment, GraphOptimizationLevel, LoggingLevel}; | ||
|
||
use crate::infer::{ | ||
DecryptModelError, InferenceRuntime, InputScalar, Output, RunBuilder, Session, SessionOptions, | ||
}; | ||
|
||
pub(crate) use self::assert_send::AssertSend; | ||
|
||
#[derive(Clone, Copy)] | ||
pub(crate) enum Onnxruntime {} | ||
|
||
impl InferenceRuntime for Onnxruntime { | ||
type Session = AssertSend<onnxruntime::session::Session<'static>>; | ||
type RunBuilder<'a> = OnnxruntimeInferenceBuilder<'a>; | ||
} | ||
|
||
impl Session for AssertSend<onnxruntime::session::Session<'static>> { | ||
fn new( | ||
model: impl FnOnce() -> std::result::Result<Vec<u8>, DecryptModelError>, | ||
options: SessionOptions, | ||
) -> anyhow::Result<Self> { | ||
let mut builder = ENVIRONMENT | ||
.new_session_builder()? | ||
.with_optimization_level(GraphOptimizationLevel::Basic)? | ||
.with_intra_op_num_threads(options.cpu_num_threads.into())? | ||
.with_inter_op_num_threads(options.cpu_num_threads.into())?; | ||
|
||
if options.use_gpu { | ||
#[cfg(feature = "directml")] | ||
{ | ||
use onnxruntime::ExecutionMode; | ||
|
||
builder = builder | ||
.with_disable_mem_pattern()? | ||
.with_execution_mode(ExecutionMode::ORT_SEQUENTIAL)? | ||
.with_append_execution_provider_directml(0)?; | ||
} | ||
|
||
#[cfg(not(feature = "directml"))] | ||
{ | ||
builder = builder.with_append_execution_provider_cuda(Default::default())?; | ||
} | ||
} | ||
|
||
let model = model()?; | ||
let this = builder.with_model_from_memory(model)?.into(); | ||
return Ok(this); | ||
|
||
static ENVIRONMENT: Lazy<Environment> = Lazy::new(|| { | ||
Environment::builder() | ||
.with_name(env!("CARGO_PKG_NAME")) | ||
.with_log_level(LOGGING_LEVEL) | ||
.build() | ||
.unwrap() | ||
}); | ||
|
||
const LOGGING_LEVEL: LoggingLevel = if cfg!(debug_assertions) { | ||
LoggingLevel::Verbose | ||
} else { | ||
LoggingLevel::Warning | ||
}; | ||
} | ||
} | ||
|
||
pub(crate) struct OnnxruntimeInferenceBuilder<'sess> { | ||
sess: &'sess mut AssertSend<onnxruntime::session::Session<'static>>, | ||
inputs: Vec<Box<dyn onnxruntime::session::AnyArray>>, | ||
} | ||
|
||
impl<'sess> From<&'sess mut AssertSend<onnxruntime::session::Session<'static>>> | ||
for OnnxruntimeInferenceBuilder<'sess> | ||
{ | ||
fn from(sess: &'sess mut AssertSend<onnxruntime::session::Session<'static>>) -> Self { | ||
Self { | ||
sess, | ||
inputs: vec![], | ||
} | ||
} | ||
} | ||
|
||
impl<'sess> RunBuilder<'sess> for OnnxruntimeInferenceBuilder<'sess> { | ||
type Runtime = Onnxruntime; | ||
|
||
fn input(&mut self, tensor: Array<impl InputScalar, impl Dimension + 'static>) -> &mut Self { | ||
self.inputs | ||
.push(Box::new(onnxruntime::session::NdArray::new(tensor))); | ||
self | ||
} | ||
} | ||
|
||
impl Output<Onnxruntime> for (Vec<f32>,) { | ||
fn run( | ||
OnnxruntimeInferenceBuilder { sess, mut inputs }: OnnxruntimeInferenceBuilder<'_>, | ||
) -> anyhow::Result<Self> { | ||
let outputs = sess.run(inputs.iter_mut().map(|t| &mut **t as &mut _).collect())?; | ||
|
||
// FIXME: 2個以上の出力や二次元以上の出力をちゃんとしたやりかたで弾く | ||
Ok((outputs[0].as_slice().unwrap().to_owned(),)) | ||
} | ||
} | ||
|
||
// FIXME: 以下のことをちゃんと確認した後、onnxruntime-rs側で`Session`が`Send`であると宣言する。 | ||
// https://github.com/VOICEVOX/voicevox_core/issues/307#issuecomment-1276184614 | ||
mod assert_send { | ||
use std::ops::{Deref, DerefMut}; | ||
|
||
pub(crate) struct AssertSend<T>(T); | ||
|
||
impl From<onnxruntime::session::Session<'static>> | ||
for AssertSend<onnxruntime::session::Session<'static>> | ||
{ | ||
fn from(session: onnxruntime::session::Session<'static>) -> Self { | ||
Self(session) | ||
} | ||
} | ||
|
||
impl<T> Deref for AssertSend<T> { | ||
type Target = T; | ||
|
||
fn deref(&self) -> &Self::Target { | ||
&self.0 | ||
} | ||
} | ||
|
||
impl<T> DerefMut for AssertSend<T> { | ||
fn deref_mut(&mut self) -> &mut Self::Target { | ||
&mut self.0 | ||
} | ||
} | ||
|
||
// SAFETY: `Session` is probably "send"able. | ||
#[allow(unsafe_code)] | ||
unsafe impl<T> Send for AssertSend<T> {} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
use std::sync::Arc; | ||
|
||
use ndarray::{Array0, Array1, Array2}; | ||
|
||
use crate::infer::{InferenceRuntime, RunBuilder, Signature, TypedSession}; | ||
|
||
pub(crate) struct SessionSet<R: InferenceRuntime> { | ||
pub(crate) predict_duration: Arc<std::sync::Mutex<TypedSession<R, PredictDuration>>>, | ||
pub(crate) predict_intonation: Arc<std::sync::Mutex<TypedSession<R, PredictIntonation>>>, | ||
pub(crate) decode: Arc<std::sync::Mutex<TypedSession<R, Decode>>>, | ||
} | ||
|
||
pub(crate) struct PredictDuration { | ||
pub(crate) phoneme: Array1<i64>, | ||
pub(crate) speaker_id: Array1<i64>, | ||
} | ||
|
||
impl Signature for PredictDuration { | ||
type SessionSet<R: InferenceRuntime> = SessionSet<R>; | ||
type Output = (Vec<f32>,); | ||
|
||
fn get_session<R: InferenceRuntime>( | ||
session_set: &Self::SessionSet<R>, | ||
) -> &Arc<std::sync::Mutex<TypedSession<R, Self>>> { | ||
&session_set.predict_duration | ||
} | ||
|
||
fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>) { | ||
ctx.input(self.phoneme).input(self.speaker_id); | ||
} | ||
} | ||
|
||
pub(crate) struct PredictIntonation { | ||
pub(crate) length: Array0<i64>, | ||
pub(crate) vowel_phoneme: Array1<i64>, | ||
pub(crate) consonant_phoneme: Array1<i64>, | ||
pub(crate) start_accent: Array1<i64>, | ||
pub(crate) end_accent: Array1<i64>, | ||
pub(crate) start_accent_phrase: Array1<i64>, | ||
pub(crate) end_accent_phrase: Array1<i64>, | ||
pub(crate) speaker_id: Array1<i64>, | ||
} | ||
|
||
impl Signature for PredictIntonation { | ||
type SessionSet<R: InferenceRuntime> = SessionSet<R>; | ||
type Output = (Vec<f32>,); | ||
|
||
fn get_session<R: InferenceRuntime>( | ||
session_set: &Self::SessionSet<R>, | ||
) -> &Arc<std::sync::Mutex<TypedSession<R, Self>>> { | ||
&session_set.predict_intonation | ||
} | ||
|
||
fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>) { | ||
ctx.input(self.length) | ||
.input(self.vowel_phoneme) | ||
.input(self.consonant_phoneme) | ||
.input(self.start_accent) | ||
.input(self.end_accent) | ||
.input(self.start_accent_phrase) | ||
.input(self.end_accent_phrase) | ||
.input(self.speaker_id); | ||
} | ||
} | ||
|
||
pub(crate) struct Decode { | ||
pub(crate) f0: Array2<f32>, | ||
pub(crate) phoneme: Array2<f32>, | ||
pub(crate) speaker_id: Array1<i64>, | ||
} | ||
|
||
impl Signature for Decode { | ||
type SessionSet<R: InferenceRuntime> = SessionSet<R>; | ||
type Output = (Vec<f32>,); | ||
|
||
fn get_session<R: InferenceRuntime>( | ||
session_set: &Self::SessionSet<R>, | ||
) -> &Arc<std::sync::Mutex<TypedSession<R, Self>>> { | ||
&session_set.decode | ||
} | ||
|
||
fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>) { | ||
ctx.input(self.f0) | ||
.input(self.phoneme) | ||
.input(self.speaker_id); | ||
} | ||
} |
Oops, something went wrong.