diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs
index c816c9899..5b07686fd 100644
--- a/crates/voicevox_core/src/infer.rs
+++ b/crates/voicevox_core/src/infer.rs
@@ -70,7 +70,10 @@ pub(crate) trait InferenceSignature: Sized + Send + 'static {
pub(crate) trait InferenceInputSignature: Send + 'static {
type Signature: InferenceSignature;
const PARAM_INFOS: &'static [ParamInfo<InputScalarKind>];
- fn make_run_context<R: InferenceRuntime>(self, sess: &mut R::Session) -> R::RunContext<'_>;
+ fn make_run_context<R: InferenceRuntime>(
+ self,
+ sess: &mut R::Session,
+ ) -> anyhow::Result<R::RunContext<'_>>;
}
pub(crate) trait InputScalar: Sized {
@@ -79,7 +82,7 @@ pub(crate) trait InputScalar: Sized {
fn push_tensor_to_ctx(
tensor: Array<Self, impl Dimension + 'static>,
visitor: &mut impl PushInputTensor,
- );
+ ) -> anyhow::Result<()>;
}
#[duplicate_item(
@@ -93,8 +96,8 @@ impl InputScalar for T {
fn push_tensor_to_ctx(
tensor: Array<Self, impl Dimension + 'static>,
ctx: &mut impl PushInputTensor,
- ) {
- ctx.push(tensor);
+ ) -> anyhow::Result<()> {
+ ctx.push(tensor)
}
}
@@ -108,8 +111,8 @@ pub(crate) enum InputScalarKind {
}
pub(crate) trait PushInputTensor {
- fn push_int64(&mut self, tensor: Array<i64, impl Dimension + 'static>);
- fn push_float32(&mut self, tensor: Array<f32, impl Dimension + 'static>);
+ fn push_int64(&mut self, tensor: Array<i64, impl Dimension + 'static>) -> anyhow::Result<()>;
+ fn push_float32(&mut self, tensor: Array<f32, impl Dimension + 'static>) -> anyhow::Result<()>;
}
/// 推論操作の出力シグネチャ。
diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs
index 787b647f9..abc225670 100644
--- a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs
+++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs
@@ -1,6 +1,6 @@
use std::{fmt::Debug, vec};
-use anyhow::{anyhow, ensure};
+use anyhow::{anyhow, bail, ensure};
use duplicate::duplicate_item;
use ndarray::{Array, Dimension};
use ort::{
@@ -80,7 +80,11 @@ impl InferenceRuntime for Onnxruntime {
.iter()
.map(|info| {
let ValueType::Tensor { ty, .. } = info.input_type else {
- todo!()
+ bail!(
+ "unexpected input value type for `{}`. currently `ONNX_TYPE_TENSOR` and \
+ `ONNX_TYPE_SPARSETENSOR` is supported",
+ info.name,
+ );
};
let dt = match ty {
@@ -92,12 +96,12 @@ impl InferenceRuntime for Onnxruntime {
TensorElementType::Int32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32"),
TensorElementType::Int64 => Ok(InputScalarKind::Int64),
TensorElementType::String => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING"),
- TensorElementType::Bfloat16 => todo!(),
- TensorElementType::Float16 => todo!(),
+ TensorElementType::Bfloat16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16"),
+ TensorElementType::Float16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16"),
TensorElementType::Float64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE"),
TensorElementType::Uint32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32"),
TensorElementType::Uint64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64"),
- TensorElementType::Bool => todo!(),
+ TensorElementType::Bool => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL"),
}
.map_err(|actual| {
anyhow!("unsupported input datatype `{actual}` for `{}`", info.name)
@@ -116,7 +120,11 @@ impl InferenceRuntime for Onnxruntime {
.iter()
.map(|info| {
let ValueType::Tensor { ty, .. } = info.output_type else {
- todo!()
+ bail!(
+ "unexpected output value type for `{}`. currently `ONNX_TYPE_TENSOR` and \
+ `ONNX_TYPE_SPARSETENSOR` is supported",
+ info.name,
+ );
};
let dt = match ty {
@@ -128,12 +136,12 @@ impl InferenceRuntime for Onnxruntime {
TensorElementType::Int32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32"),
TensorElementType::Int64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64"),
TensorElementType::String => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING"),
- TensorElementType::Bfloat16 => todo!(),
- TensorElementType::Float16 => todo!(),
+ TensorElementType::Bfloat16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16"),
+ TensorElementType::Float16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16"),
TensorElementType::Float64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE"),
TensorElementType::Uint32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32"),
TensorElementType::Uint64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64"),
- TensorElementType::Bool => todo!(),
+ TensorElementType::Bool => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL"),
}
.map_err(|actual| {
anyhow!("unsupported output datatype `{actual}` for `{}`", info.name)
@@ -159,18 +167,20 @@ impl InferenceRuntime for Onnxruntime {
let output = &outputs[i];
let dtype = output.dtype()?;
- if !matches!(
- dtype,
- ValueType::Tensor {
- ty: TensorElementType::Float32,
- ..
+ let ValueType::Tensor { ty, .. } = dtype else {
+ bail!(
+ "unexpected output. currently `ONNX_TYPE_TENSOR` and \
+ `ONNX_TYPE_SPARSETENSOR` is supported",
+ );
+ };
+
+ match ty {
+ TensorElementType::Float32 => {
+ let tensor = output.extract_tensor::<f32>()?;
+ Ok(OutputTensor::Float32(tensor.view().clone().into_owned()))
}
- ) {
- todo!();
+ _ => bail!("unexpected output tensor element data type"),
}
-
- let tensor = output.extract_tensor::<f32>()?;
- Ok(OutputTensor::Float32(tensor.view().clone().into_owned()))
})
.collect()
}
@@ -197,9 +207,10 @@ impl OnnxruntimeRunContext<'_> {
impl IntoTensorElementType + Debug + Clone + 'static,
impl Dimension + 'static,
>,
- ) {
- self.inputs
- .push(input.try_into().unwrap_or_else(|_| todo!()));
+ ) -> anyhow::Result<()> {
+ let input = input.try_into()?;
+ self.inputs.push(input);
+ Ok(())
}
}
@@ -218,7 +229,7 @@ impl PushInputTensor for OnnxruntimeRunContext<'_> {
[ push_int64 ] [ i64 ];
[ push_float32 ] [ f32 ];
)]
- fn method(&mut self, tensor: Array<dtype, impl Dimension + 'static>) {
- self.push_input(tensor);
+ fn method(&mut self, tensor: Array<dtype, impl Dimension + 'static>) -> anyhow::Result<()> {
+ self.push_input(tensor)
}
}
diff --git a/crates/voicevox_core/src/infer/status.rs b/crates/voicevox_core/src/infer/status.rs
index 12367dda2..6f82a9991 100644
--- a/crates/voicevox_core/src/infer/status.rs
+++ b/crates/voicevox_core/src/infer/status.rs
@@ -330,10 +330,10 @@ struct SessionCell {
impl<R: InferenceRuntime, I: InferenceInputSignature> SessionCell<R, I> {
fn run(self, input: I) -> crate::Result<<I::Signature as InferenceSignature>::Output> {
let inner = &mut self.inner.lock().unwrap();
- let ctx = input.make_run_context::<R>(inner);
- R::run(ctx)
- .and_then(TryInto::try_into)
- .map_err(|e| ErrorRepr::InferenceFailed(e).into())
+
+ (|| R::run(input.make_run_context::<R>(inner)?)?.try_into())()
+ .map_err(ErrorRepr::InferenceFailed)
+ .map_err(Into::into)
}
}
diff --git a/crates/voicevox_core_macros/src/inference_domain.rs b/crates/voicevox_core_macros/src/inference_domain.rs
index 72bc4d18a..d24a20ab1 100644
--- a/crates/voicevox_core_macros/src/inference_domain.rs
+++ b/crates/voicevox_core_macros/src/inference_domain.rs
@@ -223,22 +223,28 @@ pub(crate) fn derive_inference_input_signature(
fn make_run_context<R: crate::infer::InferenceRuntime>(
self,
sess: &mut R::Session,
- ) -> R::RunContext<'_> {
+ ) -> ::anyhow::Result<R::RunContext<'_>> {
let mut ctx = <R::RunContext<'_> as ::std::convert::From<_>>::from(sess);
#(
- __ArrayExt::push_to_ctx(self.#field_names, &mut ctx);
+ __ArrayExt::push_to_ctx(self.#field_names, &mut ctx)?;
)*
- return ctx;
+ return ::std::result::Result::Ok(ctx);
trait __ArrayExt {
- fn push_to_ctx(self, ctx: &mut impl crate::infer::PushInputTensor);
+ fn push_to_ctx(
+ self,
+ ctx: &mut impl crate::infer::PushInputTensor,
+ ) -> ::anyhow::Result<()>;
}
impl<A: crate::infer::InputScalar, D: ::ndarray::Dimension + 'static> __ArrayExt
for ::ndarray::Array<A, D>
{
- fn push_to_ctx(self, ctx: &mut impl crate::infer::PushInputTensor) {
- A::push_tensor_to_ctx(self, ctx);
+ fn push_to_ctx(
+ self,
+ ctx: &mut impl crate::infer::PushInputTensor,
+ ) -> ::anyhow::Result<()> {
+ A::push_tensor_to_ctx(self, ctx)
}
}
}