diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index c816c9899..5b07686fd 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -70,7 +70,10 @@ pub(crate) trait InferenceSignature: Sized + Send + 'static { pub(crate) trait InferenceInputSignature: Send + 'static { type Signature: InferenceSignature; const PARAM_INFOS: &'static [ParamInfo]; - fn make_run_context(self, sess: &mut R::Session) -> R::RunContext<'_>; + fn make_run_context( + self, + sess: &mut R::Session, + ) -> anyhow::Result>; } pub(crate) trait InputScalar: Sized { @@ -79,7 +82,7 @@ pub(crate) trait InputScalar: Sized { fn push_tensor_to_ctx( tensor: Array, visitor: &mut impl PushInputTensor, - ); + ) -> anyhow::Result<()>; } #[duplicate_item( @@ -93,8 +96,8 @@ impl InputScalar for T { fn push_tensor_to_ctx( tensor: Array, ctx: &mut impl PushInputTensor, - ) { - ctx.push(tensor); + ) -> anyhow::Result<()> { + ctx.push(tensor) } } @@ -108,8 +111,8 @@ pub(crate) enum InputScalarKind { } pub(crate) trait PushInputTensor { - fn push_int64(&mut self, tensor: Array); - fn push_float32(&mut self, tensor: Array); + fn push_int64(&mut self, tensor: Array) -> anyhow::Result<()>; + fn push_float32(&mut self, tensor: Array) -> anyhow::Result<()>; } /// 推論操作の出力シグネチャ。 diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs index 787b647f9..abc225670 100644 --- a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs +++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs @@ -1,6 +1,6 @@ use std::{fmt::Debug, vec}; -use anyhow::{anyhow, ensure}; +use anyhow::{anyhow, bail, ensure}; use duplicate::duplicate_item; use ndarray::{Array, Dimension}; use ort::{ @@ -80,7 +80,11 @@ impl InferenceRuntime for Onnxruntime { .iter() .map(|info| { let ValueType::Tensor { ty, .. } = info.input_type else { - todo!() + bail!( + "unexpected input value type for `{}`. currently `ONNX_TYPE_TENSOR` and \ + `ONNX_TYPE_SPARSETENSOR` is supported", + info.name, + ); }; let dt = match ty { @@ -92,12 +96,12 @@ impl InferenceRuntime for Onnxruntime { TensorElementType::Int32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32"), TensorElementType::Int64 => Ok(InputScalarKind::Int64), TensorElementType::String => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING"), - TensorElementType::Bfloat16 => todo!(), - TensorElementType::Float16 => todo!(), + TensorElementType::Bfloat16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16"), + TensorElementType::Float16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16"), TensorElementType::Float64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE"), TensorElementType::Uint32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32"), TensorElementType::Uint64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64"), - TensorElementType::Bool => todo!(), + TensorElementType::Bool => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL"), } .map_err(|actual| { anyhow!("unsupported input datatype `{actual}` for `{}`", info.name) @@ -116,7 +120,11 @@ impl InferenceRuntime for Onnxruntime { .iter() .map(|info| { let ValueType::Tensor { ty, .. } = info.output_type else { - todo!() + bail!( + "unexpected output value type for `{}`. currently `ONNX_TYPE_TENSOR` and \ + `ONNX_TYPE_SPARSETENSOR` is supported", + info.name, + ); }; let dt = match ty { @@ -128,12 +136,12 @@ impl InferenceRuntime for Onnxruntime { TensorElementType::Int32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32"), TensorElementType::Int64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64"), TensorElementType::String => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING"), - TensorElementType::Bfloat16 => todo!(), - TensorElementType::Float16 => todo!(), + TensorElementType::Bfloat16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16"), + TensorElementType::Float16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16"), TensorElementType::Float64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE"), TensorElementType::Uint32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32"), TensorElementType::Uint64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64"), - TensorElementType::Bool => todo!(), + TensorElementType::Bool => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL"), } .map_err(|actual| { anyhow!("unsupported output datatype `{actual}` for `{}`", info.name) @@ -159,18 +167,20 @@ impl InferenceRuntime for Onnxruntime { let output = &outputs[i]; let dtype = output.dtype()?; - if !matches!( - dtype, - ValueType::Tensor { - ty: TensorElementType::Float32, - .. + let ValueType::Tensor { ty, .. } = dtype else { + bail!( + "unexpected output. currently `ONNX_TYPE_TENSOR` and \ + `ONNX_TYPE_SPARSETENSOR` is supported", + ); + }; + + match ty { + TensorElementType::Float32 => { + let tensor = output.extract_tensor::()?; + Ok(OutputTensor::Float32(tensor.view().clone().into_owned())) } - ) { - todo!(); + _ => bail!("unexpected output tensor element data type"), } - - let tensor = output.extract_tensor::()?; - Ok(OutputTensor::Float32(tensor.view().clone().into_owned())) }) .collect() } @@ -197,9 +207,10 @@ impl OnnxruntimeRunContext<'_> { impl IntoTensorElementType + Debug + Clone + 'static, impl Dimension + 'static, >, - ) { - self.inputs - .push(input.try_into().unwrap_or_else(|_| todo!())); + ) -> anyhow::Result<()> { + let input = input.try_into()?; + self.inputs.push(input); + Ok(()) } } @@ -218,7 +229,7 @@ impl PushInputTensor for OnnxruntimeRunContext<'_> { [ push_int64 ] [ i64 ]; [ push_float32 ] [ f32 ]; )] - fn method(&mut self, tensor: Array) { - self.push_input(tensor); + fn method(&mut self, tensor: Array) -> anyhow::Result<()> { + self.push_input(tensor) } } diff --git a/crates/voicevox_core/src/infer/status.rs b/crates/voicevox_core/src/infer/status.rs index 12367dda2..6f82a9991 100644 --- a/crates/voicevox_core/src/infer/status.rs +++ b/crates/voicevox_core/src/infer/status.rs @@ -330,10 +330,10 @@ struct SessionCell { impl SessionCell { fn run(self, input: I) -> crate::Result<::Output> { let inner = &mut self.inner.lock().unwrap(); - let ctx = input.make_run_context::(inner); - R::run(ctx) - .and_then(TryInto::try_into) - .map_err(|e| ErrorRepr::InferenceFailed(e).into()) + + (|| R::run(input.make_run_context::(inner)?)?.try_into())() + .map_err(ErrorRepr::InferenceFailed) + .map_err(Into::into) } } diff --git a/crates/voicevox_core_macros/src/inference_domain.rs b/crates/voicevox_core_macros/src/inference_domain.rs index 72bc4d18a..d24a20ab1 100644 --- a/crates/voicevox_core_macros/src/inference_domain.rs +++ b/crates/voicevox_core_macros/src/inference_domain.rs @@ -223,22 +223,28 @@ pub(crate) fn derive_inference_input_signature( fn make_run_context( self, sess: &mut R::Session, - ) -> R::RunContext<'_> { + ) -> ::anyhow::Result> { let mut ctx = as ::std::convert::From<_>>::from(sess); #( - __ArrayExt::push_to_ctx(self.#field_names, &mut ctx); + __ArrayExt::push_to_ctx(self.#field_names, &mut ctx)?; )* - return ctx; + return ::std::result::Result::Ok(ctx); trait __ArrayExt { - fn push_to_ctx(self, ctx: &mut impl crate::infer::PushInputTensor); + fn push_to_ctx( + self, + ctx: &mut impl crate::infer::PushInputTensor, + ) -> ::anyhow::Result<()>; } impl __ArrayExt for ::ndarray::Array { - fn push_to_ctx(self, ctx: &mut impl crate::infer::PushInputTensor) { - A::push_tensor_to_ctx(self, ctx); + fn push_to_ctx( + self, + ctx: &mut impl crate::infer::PushInputTensor, + ) -> ::anyhow::Result<()> { + A::push_tensor_to_ctx(self, ctx) } } }