Skip to content

Commit

Permalink
todo!を消す
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Jan 20, 2024
1 parent 6740a9d commit be70dba
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 40 deletions.
15 changes: 9 additions & 6 deletions crates/voicevox_core/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ pub(crate) trait InferenceSignature: Sized + Send + 'static {
pub(crate) trait InferenceInputSignature: Send + 'static {
type Signature: InferenceSignature<Input = Self>;
const PARAM_INFOS: &'static [ParamInfo<InputScalarKind>];
fn make_run_context<R: InferenceRuntime>(self, sess: &mut R::Session) -> R::RunContext<'_>;
fn make_run_context<R: InferenceRuntime>(
self,
sess: &mut R::Session,
) -> anyhow::Result<R::RunContext<'_>>;
}

pub(crate) trait InputScalar: Sized {
Expand All @@ -79,7 +82,7 @@ pub(crate) trait InputScalar: Sized {
fn push_tensor_to_ctx(
tensor: Array<Self, impl Dimension + 'static>,
visitor: &mut impl PushInputTensor,
);
) -> anyhow::Result<()>;
}

#[duplicate_item(
Expand All @@ -93,8 +96,8 @@ impl InputScalar for T {
fn push_tensor_to_ctx(
tensor: Array<Self, impl Dimension + 'static>,
ctx: &mut impl PushInputTensor,
) {
ctx.push(tensor);
) -> anyhow::Result<()> {
ctx.push(tensor)
}
}

Expand All @@ -108,8 +111,8 @@ pub(crate) enum InputScalarKind {
}

pub(crate) trait PushInputTensor {
fn push_int64(&mut self, tensor: Array<i64, impl Dimension + 'static>);
fn push_float32(&mut self, tensor: Array<f32, impl Dimension + 'static>);
fn push_int64(&mut self, tensor: Array<i64, impl Dimension + 'static>) -> anyhow::Result<()>;
fn push_float32(&mut self, tensor: Array<f32, impl Dimension + 'static>) -> anyhow::Result<()>;
}

/// 推論操作の出力シグネチャ。
Expand Down
59 changes: 35 additions & 24 deletions crates/voicevox_core/src/infer/runtimes/onnxruntime.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{fmt::Debug, vec};

use anyhow::{anyhow, ensure};
use anyhow::{anyhow, bail, ensure};
use duplicate::duplicate_item;
use ndarray::{Array, Dimension};
use ort::{
Expand Down Expand Up @@ -80,7 +80,11 @@ impl InferenceRuntime for Onnxruntime {
.iter()
.map(|info| {
let ValueType::Tensor { ty, .. } = info.input_type else {
todo!()
bail!(
"unexpected input value type for `{}`. currently `ONNX_TYPE_TENSOR` and \
`ONNX_TYPE_SPARSETENSOR` is supported",
info.name,
);
};

let dt = match ty {
Expand All @@ -92,12 +96,12 @@ impl InferenceRuntime for Onnxruntime {
TensorElementType::Int32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32"),
TensorElementType::Int64 => Ok(InputScalarKind::Int64),
TensorElementType::String => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING"),
TensorElementType::Bfloat16 => todo!(),
TensorElementType::Float16 => todo!(),
TensorElementType::Bfloat16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16"),
TensorElementType::Float16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16"),
TensorElementType::Float64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE"),
TensorElementType::Uint32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32"),
TensorElementType::Uint64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64"),
TensorElementType::Bool => todo!(),
TensorElementType::Bool => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL"),
}
.map_err(|actual| {
anyhow!("unsupported input datatype `{actual}` for `{}`", info.name)
Expand All @@ -116,7 +120,11 @@ impl InferenceRuntime for Onnxruntime {
.iter()
.map(|info| {
let ValueType::Tensor { ty, .. } = info.output_type else {
todo!()
bail!(
"unexpected output value type for `{}`. currently `ONNX_TYPE_TENSOR` and \
`ONNX_TYPE_SPARSETENSOR` is supported",
info.name,
);
};

let dt = match ty {
Expand All @@ -128,12 +136,12 @@ impl InferenceRuntime for Onnxruntime {
TensorElementType::Int32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32"),
TensorElementType::Int64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64"),
TensorElementType::String => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING"),
TensorElementType::Bfloat16 => todo!(),
TensorElementType::Float16 => todo!(),
TensorElementType::Bfloat16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16"),
TensorElementType::Float16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16"),
TensorElementType::Float64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE"),
TensorElementType::Uint32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32"),
TensorElementType::Uint64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64"),
TensorElementType::Bool => todo!(),
TensorElementType::Bool => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL"),
}
.map_err(|actual| {
anyhow!("unsupported output datatype `{actual}` for `{}`", info.name)
Expand All @@ -159,18 +167,20 @@ impl InferenceRuntime for Onnxruntime {
let output = &outputs[i];
let dtype = output.dtype()?;

if !matches!(
dtype,
ValueType::Tensor {
ty: TensorElementType::Float32,
..
let ValueType::Tensor { ty, .. } = dtype else {
bail!(
"unexpected output. currently `ONNX_TYPE_TENSOR` and \
`ONNX_TYPE_SPARSETENSOR` is supported",
);
};

match ty {
TensorElementType::Float32 => {
let tensor = output.extract_tensor::<f32>()?;
Ok(OutputTensor::Float32(tensor.view().clone().into_owned()))
}
) {
todo!();
_ => bail!("unexpected output tensor element data type"),
}

let tensor = output.extract_tensor::<f32>()?;
Ok(OutputTensor::Float32(tensor.view().clone().into_owned()))
})
.collect()
}
Expand All @@ -197,9 +207,10 @@ impl OnnxruntimeRunContext<'_> {
impl IntoTensorElementType + Debug + Clone + 'static,
impl Dimension + 'static,
>,
) {
self.inputs
.push(input.try_into().unwrap_or_else(|_| todo!()));
) -> anyhow::Result<()> {
let input = input.try_into()?;
self.inputs.push(input);
Ok(())
}
}

Expand All @@ -218,7 +229,7 @@ impl PushInputTensor for OnnxruntimeRunContext<'_> {
[ push_int64 ] [ i64 ];
[ push_float32 ] [ f32 ];
)]
fn method(&mut self, tensor: Array<T, impl Dimension + 'static>) {
self.push_input(tensor);
fn method(&mut self, tensor: Array<T, impl Dimension + 'static>) -> anyhow::Result<()> {
self.push_input(tensor)
}
}
8 changes: 4 additions & 4 deletions crates/voicevox_core/src/infer/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,10 +330,10 @@ struct SessionCell<R: InferenceRuntime, I> {
impl<R: InferenceRuntime, I: InferenceInputSignature> SessionCell<R, I> {
fn run(self, input: I) -> crate::Result<<I::Signature as InferenceSignature>::Output> {
let inner = &mut self.inner.lock().unwrap();
let ctx = input.make_run_context::<R>(inner);
R::run(ctx)
.and_then(TryInto::try_into)
.map_err(|e| ErrorRepr::InferenceFailed(e).into())

(|| R::run(input.make_run_context::<R>(inner)?)?.try_into())()
.map_err(ErrorRepr::InferenceFailed)
.map_err(Into::into)
}
}

Expand Down
18 changes: 12 additions & 6 deletions crates/voicevox_core_macros/src/inference_domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,22 +223,28 @@ pub(crate) fn derive_inference_input_signature(
fn make_run_context<R: crate::infer::InferenceRuntime>(
self,
sess: &mut R::Session,
) -> R::RunContext<'_> {
) -> ::anyhow::Result<R::RunContext<'_>> {
let mut ctx = <R::RunContext<'_> as ::std::convert::From<_>>::from(sess);
#(
__ArrayExt::push_to_ctx(self.#field_names, &mut ctx);
__ArrayExt::push_to_ctx(self.#field_names, &mut ctx)?;
)*
return ctx;
return ::std::result::Result::Ok(ctx);

trait __ArrayExt {
fn push_to_ctx(self, ctx: &mut impl crate::infer::PushInputTensor);
fn push_to_ctx(
self,
ctx: &mut impl crate::infer::PushInputTensor,
) -> ::anyhow::Result<()>;
}

impl<A: crate::infer::InputScalar, D: ::ndarray::Dimension + 'static> __ArrayExt
for ::ndarray::Array<A, D>
{
fn push_to_ctx(self, ctx: &mut impl crate::infer::PushInputTensor) {
A::push_tensor_to_ctx(self, ctx);
fn push_to_ctx(
self,
ctx: &mut impl crate::infer::PushInputTensor,
) -> ::anyhow::Result<()> {
A::push_tensor_to_ctx(self, ctx)
}
}
}
Expand Down

0 comments on commit be70dba

Please sign in to comment.