Skip to content

Commit

Permalink
シグネチャの実行時チェック機構を入れる
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Nov 12, 2023
1 parent c316209 commit 2274a34
Show file tree
Hide file tree
Showing 6 changed files with 317 additions and 55 deletions.
94 changes: 86 additions & 8 deletions crates/voicevox_core/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@ pub(crate) mod runtimes;
pub(crate) mod signatures;
pub(crate) mod status;

use std::fmt::Debug;
use std::{
borrow::Cow,
fmt::{self, Debug, Display},
};

use derive_new::new;
use enum_map::Enum;
use enum_map::{Enum, EnumMap};
use ndarray::{Array, ArrayD, Dimension, ShapeError};
use thiserror::Error;

Expand All @@ -18,10 +21,15 @@ pub(crate) trait InferenceRuntime: 'static {

fn supported_devices() -> crate::Result<SupportedDevices>;

#[allow(clippy::type_complexity)]
fn new_session(
model: impl FnOnce() -> std::result::Result<Vec<u8>, DecryptModelError>,
options: InferenceSessionOptions,
) -> anyhow::Result<Self::Session>;
) -> anyhow::Result<(
Self::Session,
Vec<ParamInfo<InputScalarKind>>,
Vec<ParamInfo<OutputScalarKind>>,
)>;

fn push_input(
input: Array<impl InputScalar, impl Dimension + 'static>,
Expand All @@ -31,37 +39,72 @@ pub(crate) trait InferenceRuntime: 'static {
fn run(ctx: Self::RunContext<'_>) -> anyhow::Result<Vec<OutputTensor>>;
}

pub(crate) trait InferenceGroup: Copy + Enum {}
/// A closed set of inference operations (modeled as a fieldless enum, e.g.
/// duration / intonation / decode), each carrying the tensor signatures the
/// loaded model is checked against at runtime.
pub(crate) trait InferenceGroup: Copy + Enum {
    /// Expected input parameter signatures, one slice per operation.
    const INPUT_PARAM_INFOS: EnumMap<Self, &'static [ParamInfo<InputScalarKind>]>;
    /// Expected output parameter signatures, one slice per operation.
    const OUTPUT_PARAM_INFOS: EnumMap<Self, &'static [ParamInfo<OutputScalarKind>]>;
}

/// Ties together the input type, output type, and operation kind of a single
/// inference operation.
pub(crate) trait InferenceSignature: Sized + Send + 'static {
    /// The group this operation belongs to.
    type Group: InferenceGroup;
    /// Typed input bundle; must point back to this signature.
    type Input: InferenceInputSignature<Signature = Self>;
    /// Typed output bundle, constructed from raw output tensors.
    type Output: InferenceOutputSignature;
    /// Which member of `Group` this signature implements.
    const KIND: Self::Group;
}

/// Typed input bundle for one inference operation.
pub(crate) trait InferenceInputSignature: Send + 'static {
    /// The signature this input belongs to.
    type Signature: InferenceSignature<Input = Self>;
    /// Compile-time description of the input tensors (name, scalar kind,
    /// rank), used to validate a loaded model against this signature.
    const PARAM_INFOS: &'static [ParamInfo<InputScalarKind>];
    /// Builds a run context for `sess` from this input, ready to be passed
    /// to `InferenceRuntime::run`.
    fn make_run_context<R: InferenceRuntime>(self, sess: &mut R::Session) -> R::RunContext<'_>;
}

pub(crate) trait InputScalar: sealed::InputScalar + Debug + 'static {}
/// Scalar types accepted as model inputs (sealed; implemented below for
/// `i64` and `f32`).
pub(crate) trait InputScalar: sealed::InputScalar + Debug + 'static {
    /// Runtime tag corresponding to this scalar type.
    const KIND: InputScalarKind;
}

impl InputScalar for i64 {}
impl InputScalar for f32 {}
// `i64` inputs are tagged as 64-bit integers.
impl InputScalar for i64 {
    const KIND: InputScalarKind = InputScalarKind::Int64;
}

// `f32` inputs are tagged as 32-bit floats.
impl InputScalar for f32 {
    const KIND: InputScalarKind = InputScalarKind::Float32;
}

/// Runtime tag for input scalar types; `Display` renders the C-style type
/// name used in error messages.
#[derive(Clone, Copy, PartialEq, derive_more::Display)]
pub(crate) enum InputScalarKind {
    /// 64-bit signed integer.
    #[display(fmt = "int64_t")]
    Int64,

    /// 32-bit float.
    #[display(fmt = "float")]
    Float32,
}

/// Typed output bundle for one inference operation, constructed from the raw
/// tensors the runtime returns.
pub(crate) trait InferenceOutputSignature:
    TryFrom<Vec<OutputTensor>, Error = anyhow::Error> + Send
{
    /// Compile-time description of the expected output tensors (name,
    /// scalar kind, rank).
    const PARAM_INFOS: &'static [ParamInfo<OutputScalarKind>];
}

/// Scalar types that can appear in model outputs.
pub(crate) trait OutputScalar: Sized {
    /// Runtime tag corresponding to this scalar type.
    const KIND: OutputScalarKind;
    /// Extracts a dynamically-dimensioned array of this scalar from an
    /// output tensor.
    fn extract(tensor: OutputTensor) -> std::result::Result<ArrayD<Self>, ExtractError>;
}

impl OutputScalar for f32 {
    const KIND: OutputScalarKind = OutputScalarKind::Float32;

    /// Pulls the `f32` array out of an output tensor. `OutputTensor` has a
    /// single variant today, so the destructuring is irrefutable.
    fn extract(tensor: OutputTensor) -> std::result::Result<ArrayD<Self>, ExtractError> {
        let OutputTensor::Float32(array) = tensor;
        Ok(array)
    }
}

/// Runtime tag for output scalar types; `Display` renders the C-style type
/// name used in error messages.
#[derive(Clone, Copy, PartialEq, derive_more::Display)]
pub(crate) enum OutputScalarKind {
    /// 32-bit float.
    #[display(fmt = "float")]
    Float32,
}

/// A single output tensor produced by a runtime; only `f32` is supported so far.
pub(crate) enum OutputTensor {
    Float32(ArrayD<f32>),
}
Expand All @@ -75,6 +118,41 @@ impl<A: OutputScalar, D: Dimension> TryFrom<OutputTensor> for Array<A, D> {
}
}

/// Description of one model input or output parameter: its name, its scalar
/// kind `D` (`InputScalarKind` or `OutputScalarKind`), and optionally its rank.
pub(crate) struct ParamInfo<D> {
    /// Parameter name.
    name: Cow<'static, str>,
    /// Scalar kind (datatype tag).
    dt: D,
    /// Expected rank; `None` accepts any rank.
    ndim: Option<usize>,
}

impl<D: PartialEq> ParamInfo<D> {
    /// Returns whether this (declared) parameter is compatible with `other`
    /// (the model's actual parameter): name and scalar kind must match
    /// exactly, and a declared rank must match too — a rank of `None`
    /// accepts anything.
    fn accepts(&self, other: &Self) -> bool {
        if self.name != other.name || self.dt != other.dt {
            return false;
        }
        match self.ndim {
            None => true,
            Some(_) => self.ndim == other.ndim,
        }
    }
}

impl<D: Display> Display for ParamInfo<D> {
    /// Renders as `name: type` followed by one `[]` per dimension, or by
    /// `[]...` when the rank is unconstrained.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}: {}", self.name, self.dt)?;
        match self.ndim {
            Some(ndim) => write!(f, "{}", "[]".repeat(ndim)),
            None => write!(f, "[]..."),
        }
    }
}

/// Helper trait to name the scalar and dimension types of an `ndarray::Array`
/// in generic code.
pub(crate) trait ArrayExt {
    /// Element type of the array.
    type Scalar;
    /// Dimensionality type of the array.
    type Dimension: Dimension;
}

// Blanket projection of `Array<A, D>`'s type parameters.
impl<A, D: Dimension> ArrayExt for Array<A, D> {
    type Scalar = A;
    type Dimension = D;
}

#[derive(new, Clone, Copy, PartialEq, Debug)]
pub(crate) struct InferenceSessionOptions {
pub(crate) cpu_num_threads: u16,
Expand Down
84 changes: 74 additions & 10 deletions crates/voicevox_core/src/infer/runtimes/onnxruntime.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
use std::fmt::Debug;
use std::{fmt::Debug, vec};

use anyhow::anyhow;
use ndarray::{Array, Dimension};
use once_cell::sync::Lazy;
use onnxruntime::{
environment::Environment, GraphOptimizationLevel, LoggingLevel, TensorElementDataType,
};

use crate::{devices::SupportedDevices, error::ErrorRepr};

use self::assert_send::AssertSend;
use crate::{
devices::SupportedDevices,
error::ErrorRepr,
infer::{
DecryptModelError, InferenceRuntime, InferenceSessionOptions, InputScalar, OutputTensor,
},

use super::super::{
DecryptModelError, InferenceRuntime, InferenceSessionOptions, InputScalar, InputScalarKind,
OutputScalarKind, OutputTensor, ParamInfo,
};

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
Expand Down Expand Up @@ -47,7 +48,11 @@ impl InferenceRuntime for Onnxruntime {
fn new_session(
model: impl FnOnce() -> std::result::Result<Vec<u8>, DecryptModelError>,
options: InferenceSessionOptions,
) -> anyhow::Result<Self::Session> {
) -> anyhow::Result<(
Self::Session,
Vec<ParamInfo<InputScalarKind>>,
Vec<ParamInfo<OutputScalarKind>>,
)> {
let mut builder = ENVIRONMENT
.new_session_builder()?
.with_optimization_level(GraphOptimizationLevel::Basic)?
Expand All @@ -72,8 +77,67 @@ impl InferenceRuntime for Onnxruntime {
}

let model = model()?;
let sess = builder.with_model_from_memory(model)?.into();
return Ok(sess);
let sess = AssertSend::from(builder.with_model_from_memory(model)?);

let input_param_infos = sess
.inputs
.iter()
.map(|info| {
let dt = match info.input_type {
TensorElementDataType::Float => Ok(InputScalarKind::Float32),
TensorElementDataType::Uint8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8"),
TensorElementDataType::Int8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8"),
TensorElementDataType::Uint16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16"),
TensorElementDataType::Int16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16"),
TensorElementDataType::Int32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32"),
TensorElementDataType::Int64 => Ok(InputScalarKind::Int64),
TensorElementDataType::String => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING"),
TensorElementDataType::Double => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE"),
TensorElementDataType::Uint32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32"),
TensorElementDataType::Uint64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64"),
}
.map_err(|actual| {
anyhow!("unsupported input datatype `{actual}` for `{}`", info.name)
})?;

Ok(ParamInfo {
name: info.name.clone().into(),
dt,
ndim: Some(info.dimensions.len()),
})
})
.collect::<anyhow::Result<_>>()?;

let output_param_infos = sess
.outputs
.iter()
.map(|info| {
let dt = match info.output_type {
TensorElementDataType::Float => Ok(OutputScalarKind::Float32),
TensorElementDataType::Uint8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8"),
TensorElementDataType::Int8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8"),
TensorElementDataType::Uint16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16"),
TensorElementDataType::Int16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16"),
TensorElementDataType::Int32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32"),
TensorElementDataType::Int64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64"),
TensorElementDataType::String => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING"),
TensorElementDataType::Double => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE"),
TensorElementDataType::Uint32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32"),
TensorElementDataType::Uint64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64"),
}
.map_err(|actual| {
anyhow!("unsupported output datatype `{actual}` for `{}`", info.name)
})?;

Ok(ParamInfo {
name: info.name.clone().into(),
dt,
ndim: Some(info.dimensions.len()),
})
})
.collect::<anyhow::Result<_>>()?;

return Ok((sess, input_param_infos, output_param_infos));

static ENVIRONMENT: Lazy<Environment> = Lazy::new(|| {
Environment::builder()
Expand Down
52 changes: 38 additions & 14 deletions crates/voicevox_core/src/infer/signatures.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,40 @@
use enum_map::Enum;
use macros::{InferenceGroup, InferenceInputSignature, TryFromVecOutputTensor};
use enum_map::{Enum, EnumMap};
use macros::{InferenceInputSignature, InferenceOutputSignature};
use ndarray::{Array0, Array1, Array2};

use super::{InferenceSignature, OutputTensor};
use super::{
InferenceGroup, InferenceInputSignature as _, InferenceOutputSignature as _,
InferenceSignature, OutputTensor,
};

#[derive(Clone, Copy, Enum, InferenceGroup)]
/// The set of inference operations this crate performs.
#[derive(Clone, Copy, Enum)]
pub(crate) enum InferenceKind {
    /// Phoneme-duration prediction.
    PredictDuration,
    /// Intonation (F0) prediction.
    PredictIntonation,
    /// Waveform decoding.
    Decode,
}

// FIXME: generate this impl with a macro as well.
/// Wires each `InferenceKind` variant to the parameter-info tables of its
/// input/output signature types. The array order must match the declaration
/// order of `InferenceKind`'s variants, since `EnumMap::from_array` maps by
/// enum ordinal.
impl InferenceGroup for InferenceKind {
    const INPUT_PARAM_INFOS: enum_map::EnumMap<
        Self,
        &'static [super::ParamInfo<super::InputScalarKind>],
    > = EnumMap::from_array([
        PredictDurationInput::PARAM_INFOS,
        PredictIntonationInput::PARAM_INFOS,
        DecodeInput::PARAM_INFOS,
    ]);

    const OUTPUT_PARAM_INFOS: enum_map::EnumMap<
        Self,
        &'static [super::ParamInfo<super::OutputScalarKind>],
    > = EnumMap::from_array([
        PredictDurationOutput::PARAM_INFOS,
        PredictIntonationOutput::PARAM_INFOS,
        DecodeOutput::PARAM_INFOS,
    ]);
}

pub(crate) enum PredictDuration {}

impl InferenceSignature for PredictDuration {
Expand All @@ -23,11 +47,11 @@ impl InferenceSignature for PredictDuration {
/// Inputs to the `PredictDuration` operation.
#[derive(InferenceInputSignature)]
#[input_signature(Signature = PredictDuration)]
pub(crate) struct PredictDurationInput {
    /// Phoneme id sequence.
    pub(crate) phoneme_list: Array1<i64>,
    /// Speaker id.
    pub(crate) speaker_id: Array1<i64>,
}

#[derive(TryFromVecOutputTensor)]
/// Output of the `PredictDuration` operation.
#[derive(InferenceOutputSignature)]
pub(crate) struct PredictDurationOutput {
    /// Predicted length per phoneme.
    pub(crate) phoneme_length: Array1<f32>,
}
Expand All @@ -45,16 +69,16 @@ impl InferenceSignature for PredictIntonation {
#[input_signature(Signature = PredictIntonation)]
// Inputs to the `PredictIntonation` operation: accent/phoneme feature
// sequences plus the speaker id.
pub(crate) struct PredictIntonationInput {
    /// Sequence length (scalar tensor).
    pub(crate) length: Array0<i64>,
    /// Vowel phoneme ids.
    pub(crate) vowel_phoneme_list: Array1<i64>,
    /// Consonant phoneme ids.
    pub(crate) consonant_phoneme_list: Array1<i64>,
    /// Accent-start flags.
    pub(crate) start_accent_list: Array1<i64>,
    /// Accent-end flags.
    pub(crate) end_accent_list: Array1<i64>,
    /// Accent-phrase-start flags.
    pub(crate) start_accent_phrase_list: Array1<i64>,
    /// Accent-phrase-end flags.
    pub(crate) end_accent_phrase_list: Array1<i64>,
    /// Speaker id.
    pub(crate) speaker_id: Array1<i64>,
}

#[derive(TryFromVecOutputTensor)]
/// Output of the `PredictIntonation` operation.
#[derive(InferenceOutputSignature)]
pub(crate) struct PredictIntonationOutput {
    /// Predicted F0 values.
    pub(crate) f0_list: Array1<f32>,
}
Expand All @@ -76,7 +100,7 @@ pub(crate) struct DecodeInput {
pub(crate) speaker_id: Array1<i64>,
}

#[derive(TryFromVecOutputTensor)]
/// Output of the `Decode` operation.
#[derive(InferenceOutputSignature)]
pub(crate) struct DecodeOutput {
    /// Decoded waveform samples.
    pub(crate) wave: Array1<f32>,
}
Loading

0 comments on commit 2274a34

Please sign in to comment.