diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index d5a55cfea..1fccadf0a 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -2,7 +2,7 @@ mod model_file; pub(crate) mod runtimes; pub(crate) mod signatures; -use std::{fmt::Debug, marker::PhantomData, sync::Arc}; +use std::{collections::HashMap, fmt::Debug, marker::PhantomData, sync::Arc}; use derive_new::new; use enum_map::{Enum, EnumMap}; @@ -29,13 +29,13 @@ pub(crate) trait RunBuilder<'a>: From<&'a mut Self::Session> { fn input(&mut self, tensor: Array) -> &mut Self; } -pub(crate) trait InputScalar: LinalgScalar + Debug + sealed::OnnxruntimeInputScalar {} +pub(crate) trait InputScalar: LinalgScalar + Debug + sealed::InputScalar {} impl InputScalar for i64 {} impl InputScalar for f32 {} pub(crate) trait Signature: Sized + Send + 'static { - type Kind: Enum; + type Kind: Enum + Copy; type Output; const KIND: Self::Kind; fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>); @@ -49,7 +49,7 @@ pub(crate) struct SessionSet( EnumMap>>, ); -impl SessionSet { +impl SessionSet { pub(crate) fn new( model_bytes: &EnumMap>, mut options: impl FnMut(K) -> SessionOptions, @@ -58,12 +58,12 @@ impl SessionSet { .iter() .map(|(k, m)| { let sess = R::Session::new(|| model_file::decrypt(m), options(k))?; - Ok(Some(Arc::new(std::sync::Mutex::new(sess)))) + Ok((k.into_usize(), std::sync::Mutex::new(sess).into())) }) - .collect::>>()?; + .collect::>>()?; Ok(Self(EnumMap::::from_fn(|k| { - sessions[k.into_usize()].take().expect("should exist") + sessions.remove(&k.into_usize()).expect("should exist") }))) } } @@ -105,11 +105,15 @@ pub(crate) struct SessionOptions { pub(crate) struct DecryptModelError; mod sealed { + pub(crate) trait InputScalar: OnnxruntimeInputScalar {} + + impl InputScalar for i64 {} + impl InputScalar for f32 {} + pub(crate) trait OnnxruntimeInputScalar: onnxruntime::TypeToTensorElementDataType { } - impl OnnxruntimeInputScalar for i64 {} - impl OnnxruntimeInputScalar for f32 {} + impl OnnxruntimeInputScalar for T {} }