Skip to content

Commit

Permalink
Minor refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Nov 6, 2023
1 parent 20db67a commit cc84068
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions crates/voicevox_core/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ mod model_file;
pub(crate) mod runtimes;
pub(crate) mod signatures;

use std::{fmt::Debug, marker::PhantomData, sync::Arc};
use std::{collections::HashMap, fmt::Debug, marker::PhantomData, sync::Arc};

use derive_new::new;
use enum_map::{Enum, EnumMap};
Expand All @@ -29,13 +29,13 @@ pub(crate) trait RunBuilder<'a>: From<&'a mut Self::Session> {
fn input(&mut self, tensor: Array<impl InputScalar, impl Dimension + 'static>) -> &mut Self;
}

pub(crate) trait InputScalar: LinalgScalar + Debug + sealed::OnnxruntimeInputScalar {}
pub(crate) trait InputScalar: LinalgScalar + Debug + sealed::InputScalar {}

impl InputScalar for i64 {}
impl InputScalar for f32 {}

pub(crate) trait Signature: Sized + Send + 'static {
type Kind: Enum;
type Kind: Enum + Copy;
type Output;
const KIND: Self::Kind;
fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>);
Expand All @@ -49,7 +49,7 @@ pub(crate) struct SessionSet<K: Enum, R: InferenceRuntime>(
EnumMap<K, Arc<std::sync::Mutex<R::Session>>>,
);

impl<K: Enum, R: InferenceRuntime> SessionSet<K, R> {
impl<K: Enum + Copy, R: InferenceRuntime> SessionSet<K, R> {
pub(crate) fn new(
model_bytes: &EnumMap<K, Vec<u8>>,
mut options: impl FnMut(K) -> SessionOptions,
Expand All @@ -58,12 +58,12 @@ impl<K: Enum, R: InferenceRuntime> SessionSet<K, R> {
.iter()
.map(|(k, m)| {
let sess = R::Session::new(|| model_file::decrypt(m), options(k))?;
Ok(Some(Arc::new(std::sync::Mutex::new(sess))))
Ok((k.into_usize(), std::sync::Mutex::new(sess).into()))
})
.collect::<anyhow::Result<Vec<_>>>()?;
.collect::<anyhow::Result<HashMap<_, _>>>()?;

Ok(Self(EnumMap::<K, _>::from_fn(|k| {
sessions[k.into_usize()].take().expect("should exist")
sessions.remove(&k.into_usize()).expect("should exist")
})))
}
}
Expand Down Expand Up @@ -105,11 +105,15 @@ pub(crate) struct SessionOptions {
pub(crate) struct DecryptModelError;

mod sealed {
pub(crate) trait InputScalar: OnnxruntimeInputScalar {}

impl InputScalar for i64 {}
impl InputScalar for f32 {}

pub(crate) trait OnnxruntimeInputScalar:
onnxruntime::TypeToTensorElementDataType
{
}

impl OnnxruntimeInputScalar for i64 {}
impl OnnxruntimeInputScalar for f32 {}
impl<T: onnxruntime::TypeToTensorElementDataType> OnnxruntimeInputScalar for T {}
}

0 comments on commit cc84068

Please sign in to comment.