Skip to content

Commit

Permalink
Minor refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Dec 9, 2024
1 parent 438f14f commit ac6ca3f
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 11 deletions.
18 changes: 8 additions & 10 deletions crates/voicevox_core/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ mod model_file;
pub(crate) mod runtimes;
pub(crate) mod session_set;

use std::{borrow::Cow, collections::BTreeSet, fmt::Debug, future::Future, ops::Index, sync::Arc};
use std::{borrow::Cow, collections::BTreeSet, fmt::Debug, ops::Index, sync::Arc};

use derive_new::new;
use duplicate::duplicate_item;
Expand All @@ -27,7 +27,7 @@ impl AsyncExt for SingleTasked {
async fn run_session<R: InferenceRuntime>(
ctx: R::RunContext,
) -> anyhow::Result<Vec<OutputTensor>> {
R::run(ctx)
R::run_blocking(ctx)
}
}

Expand All @@ -43,7 +43,7 @@ pub(crate) trait InferenceRuntime: 'static {
// TODO: "session"とは何なのかを定め、ドキュメントを書く。`InferenceSessionSet`も同様。
type Session;

// 本当は`From<'_ Self::Session>`としたいが、 rust-lang/rust#100013 がある
// 本当は`From<&'_ Self::Session>`としたいが、 rust-lang/rust#100013 が立ち塞がる
type RunContext: From<Arc<Self::Session>> + PushInputTensor;

/// 名前。
Expand All @@ -70,11 +70,9 @@ pub(crate) trait InferenceRuntime: 'static {
Vec<ParamInfo<OutputScalarKind>>,
)>;

fn run(ctx: Self::RunContext) -> anyhow::Result<Vec<OutputTensor>>;
fn run_blocking(ctx: Self::RunContext) -> anyhow::Result<Vec<OutputTensor>>;

fn run_async(
ctx: Self::RunContext,
) -> impl Future<Output = anyhow::Result<Vec<OutputTensor>>> + Send;
async fn run_async(ctx: Self::RunContext) -> anyhow::Result<Vec<OutputTensor>>;
}

/// 共に扱われるべき推論操作の集合を示す。
Expand Down Expand Up @@ -115,7 +113,7 @@ pub(crate) trait InferenceOperation: Copy + Enum {
/// `InferenceDomain`の推論操作を表す列挙型。
///
/// `::macros::InferenceOperation`により、具体型ごと生成される。
pub(crate) trait InferenceSignature: Sized + Send + 'static {
pub(crate) trait InferenceSignature {
type Domain: InferenceDomain;
type Input: InferenceInputSignature<Signature = Self>;
type Output: InferenceOutputSignature;
Expand All @@ -125,7 +123,7 @@ pub(crate) trait InferenceSignature: Sized + Send + 'static {
/// 推論操作の入力シグネチャ。
///
/// `::macros::InferenceInputSignature`により導出される。
pub(crate) trait InferenceInputSignature: Send + 'static {
pub(crate) trait InferenceInputSignature {
type Signature: InferenceSignature<Input = Self>;
const PARAM_INFOS: &'static [ParamInfo<InputScalarKind>];
fn make_run_context<R: InferenceRuntime>(
Expand Down Expand Up @@ -189,7 +187,7 @@ pub(crate) trait PushInputTensor {
///
/// `::macros::InferenceOutputSignature`により、`TryFrom<OutputTensor>`も含めて導出される。
pub(crate) trait InferenceOutputSignature:
TryFrom<Vec<OutputTensor>, Error = anyhow::Error> + Send
TryFrom<Vec<OutputTensor>, Error = anyhow::Error>
{
const PARAM_INFOS: &'static [ParamInfo<OutputScalarKind>];
}
Expand Down
2 changes: 1 addition & 1 deletion crates/voicevox_core/src/infer/runtimes/onnxruntime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ impl InferenceRuntime for self::blocking::Onnxruntime {
Ok((sess.into(), input_param_infos, output_param_infos))
}

fn run(
fn run_blocking(
OnnxruntimeRunContext { sess, inputs }: Self::RunContext,
) -> anyhow::Result<Vec<OutputTensor>> {
extract_outputs(&sess.lock_blocking().run(inputs)?)
Expand Down

0 comments on commit ac6ca3f

Please sign in to comment.