From 697e45926b5d4a8630605ea7d6b81ffe7ebe82ba Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Tue, 10 Dec 2024 02:08:25 +0900 Subject: [PATCH] =?UTF-8?q?feat!:=20`RunAsync`=E3=82=92=E4=BD=BF=E3=81=86?= =?UTF-8?q?=20(#889)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 非同期APIにおける推論に`ort::Session::run_async`を使うようにする。これに より、推論がキャンセル可能になる。 #888 の影響を受け推論が失敗するので、CIを通すため、Python APIのテストと exampleについては(すべてのプラットフォームで) `cpu_num_threads=max(multiprocessing.cpu_cout(), 2)`とする。 Resolves: #687 Refs: VOICEVOX/ort#11 BREAKING-CHANGE: 非同期APIにおいて、INTRA Thread Countが`1`だと推論がすべてエラーになる。 See-also: https://docs.rs/ort/2.0.0-rc.4/ort/struct.Session.html#method.run_async --- Cargo.lock | 4 +- Cargo.toml | 2 +- crates/voicevox_core/src/infer.rs | 60 +++++-- .../src/infer/runtimes/onnxruntime.rs | 88 +++++----- crates/voicevox_core/src/infer/session_set.rs | 19 ++- crates/voicevox_core/src/status.rs | 6 +- crates/voicevox_core/src/synthesizer.rs | 153 +++++++++--------- .../src/inference_domain.rs | 16 +- .../test/test_asyncio_user_dict_load.py | 9 +- example/python/run-asyncio.py | 8 +- example/python/run.py | 8 +- 11 files changed, 226 insertions(+), 147 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2005f6ffa..c187c0cf1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4251,7 +4251,7 @@ checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" [[package]] name = "voicevox-ort" version = "2.0.0-rc.4" -source = "git+https://github.com/VOICEVOX/ort.git?rev=17f741301db0bb08da0eafe8a338e5efd8a4b5df#17f741301db0bb08da0eafe8a338e5efd8a4b5df" +source = "git+https://github.com/VOICEVOX/ort.git?rev=09a9fe1619c1561efafc02f68f0bda4aad879771#09a9fe1619c1561efafc02f68f0bda4aad879771" dependencies = [ "anyhow", "half", @@ -4268,7 +4268,7 @@ dependencies = [ [[package]] name = "voicevox-ort-sys" version = "2.0.0-rc.4" -source = "git+https://github.com/VOICEVOX/ort.git?rev=17f741301db0bb08da0eafe8a338e5efd8a4b5df#17f741301db0bb08da0eafe8a338e5efd8a4b5df" +source = "git+https://github.com/VOICEVOX/ort.git?rev=09a9fe1619c1561efafc02f68f0bda4aad879771#09a9fe1619c1561efafc02f68f0bda4aad879771" dependencies = [ "flate2", "pkg-config", diff --git a/Cargo.toml b/Cargo.toml index a3a606317..3c3c184c5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -99,7 +99,7 @@ zip = "0.6.3" [workspace.dependencies.voicevox-ort] git = "https://github.com/VOICEVOX/ort.git" -rev = "17f741301db0bb08da0eafe8a338e5efd8a4b5df" +rev = "09a9fe1619c1561efafc02f68f0bda4aad879771" [workspace.dependencies.open_jtalk] git = "https://github.com/VOICEVOX/open_jtalk-rs.git" diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index e827ddd7d..7b81b7c5d 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -12,14 +12,39 @@ use ndarray::{Array, ArrayD, Dimension, ShapeError}; use thiserror::Error; use crate::{ + asyncs::{Async, BlockingThreadPool, SingleTasked}, devices::{DeviceSpec, GpuSpec}, StyleType, SupportedDevices, }; +pub(crate) trait AsyncExt: Async { + async fn run_session( + ctx: R::RunContext, + ) -> anyhow::Result>; +} + +impl AsyncExt for SingleTasked { + async fn run_session( + ctx: R::RunContext, + ) -> anyhow::Result> { + R::run_blocking(ctx) + } +} + +impl AsyncExt for BlockingThreadPool { + async fn run_session( + ctx: R::RunContext, + ) -> anyhow::Result> { + R::run_async(ctx).await + } +} + pub(crate) trait InferenceRuntime: 'static { // TODO: "session"とは何なのかを定め、ドキュメントを書く。`InferenceSessionSet`も同様。 - type Session: Sized + Send + 'static; - type RunContext<'a>: From<&'a mut Self::Session> + PushInputTensor; + type Session; + + // 本当は`From<&'_ Self::Session>`としたいが、 rust-lang/rust#100013 が立ち塞がる + type RunContext: From> + PushInputTensor; /// 名前。 const DISPLAY_NAME: &'static str; @@ -45,7 +70,9 @@ pub(crate) trait InferenceRuntime: 'static { Vec>, )>; - fn run(ctx: Self::RunContext<'_>) -> anyhow::Result>; + fn run_blocking(ctx: Self::RunContext) -> anyhow::Result>; + + async fn run_async(ctx: Self::RunContext) -> anyhow::Result>; } /// 共に扱われるべき推論操作の集合を示す。 @@ -86,7 +113,7 @@ pub(crate) trait InferenceOperation: Copy + Enum { /// `InferenceDomain`の推論操作を表す列挙型。 /// /// `::macros::InferenceOperation`により、具体型ごと生成される。 -pub(crate) trait InferenceSignature: Sized + Send + 'static { +pub(crate) trait InferenceSignature { type Domain: InferenceDomain; type Input: InferenceInputSignature; type Output: InferenceOutputSignature; @@ -96,13 +123,13 @@ pub(crate) trait InferenceSignature: Sized + Send + 'static { /// 推論操作の入力シグネチャ。 /// /// `::macros::InferenceInputSignature`により導出される。 -pub(crate) trait InferenceInputSignature: Send + 'static { +pub(crate) trait InferenceInputSignature { type Signature: InferenceSignature; const PARAM_INFOS: &'static [ParamInfo]; fn make_run_context( self, - sess: &mut R::Session, - ) -> anyhow::Result>; + sess: Arc, + ) -> anyhow::Result; } pub(crate) trait InputScalar: Sized { @@ -110,6 +137,7 @@ pub(crate) trait InputScalar: Sized { // TODO: `Array`ではなく`ArrayView`を取ることができるかもしれない fn push_tensor_to_ctx( + name: &'static str, tensor: Array, visitor: &mut impl PushInputTensor, ) -> anyhow::Result<()>; @@ -124,10 +152,11 @@ impl InputScalar for T { const KIND: InputScalarKind = KIND_VAL; fn push_tensor_to_ctx( + name: &'static str, tensor: Array, ctx: &mut impl PushInputTensor, ) -> anyhow::Result<()> { - ctx.push(tensor) + ctx.push(name, tensor) } } @@ -141,15 +170,24 @@ pub(crate) enum InputScalarKind { } pub(crate) trait PushInputTensor { - fn push_int64(&mut self, tensor: Array) -> anyhow::Result<()>; - fn push_float32(&mut self, tensor: Array) -> anyhow::Result<()>; + fn push_int64( + &mut self, + name: &'static str, + tensor: Array, + ) -> anyhow::Result<()>; + + fn push_float32( + &mut self, + name: &'static str, + tensor: Array, + ) -> anyhow::Result<()>; } /// 推論操作の出力シグネチャ。 /// /// `::macros::InferenceOutputSignature`により、`TryFrom`も含めて導出される。 pub(crate) trait InferenceOutputSignature: - TryFrom, Error = anyhow::Error> + Send + TryFrom, Error = anyhow::Error> { const PARAM_INFOS: &'static [ParamInfo]; } diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs index ffeb55e5f..f2cc4fac7 100644 --- a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs +++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs @@ -11,7 +11,7 @@ // } // ``` -use std::{fmt::Debug, vec}; +use std::{fmt::Debug, sync::Arc, vec}; use anyhow::{anyhow, bail, ensure}; use duplicate::duplicate_item; @@ -32,8 +32,8 @@ use super::super::{ }; impl InferenceRuntime for self::blocking::Onnxruntime { - type Session = ort::Session; - type RunContext<'a> = OnnxruntimeRunContext<'a>; + type Session = async_lock::Mutex; // WASMでは`ort`を利用しないので、ここはasync-lockを用いてよいはず + type RunContext = OnnxruntimeRunContext; const DISPLAY_NAME: &'static str = if cfg!(feature = "load-onnxruntime") { "現在ロードされているONNX Runtime" @@ -179,58 +179,44 @@ impl InferenceRuntime for self::blocking::Onnxruntime { }) .collect::>()?; - Ok((sess, input_param_infos, output_param_infos)) + Ok((sess.into(), input_param_infos, output_param_infos)) } - fn run( - OnnxruntimeRunContext { sess, inputs }: OnnxruntimeRunContext<'_>, + fn run_blocking( + OnnxruntimeRunContext { sess, inputs }: Self::RunContext, ) -> anyhow::Result> { - let outputs = sess.run(&*inputs)?; - - (0..outputs.len()) - .map(|i| { - let output = &outputs[i]; - - let ValueType::Tensor { ty, .. } = output.dtype()? else { - bail!( - "unexpected output. currently `ONNX_TYPE_TENSOR` and \ - `ONNX_TYPE_SPARSETENSOR` is supported", - ); - }; + extract_outputs(&sess.lock_blocking().run(inputs)?) + } - match ty { - TensorElementType::Float32 => { - let output = output.try_extract_tensor::()?; - Ok(OutputTensor::Float32(output.into_owned())) - } - _ => bail!("unexpected output tensor element data type"), - } - }) - .collect() + async fn run_async( + OnnxruntimeRunContext { sess, inputs }: Self::RunContext, + ) -> anyhow::Result> { + extract_outputs(&sess.lock().await.run_async(inputs)?.await?) } } -pub(crate) struct OnnxruntimeRunContext<'sess> { - sess: &'sess ort::Session, - inputs: Vec>, +pub(crate) struct OnnxruntimeRunContext { + sess: Arc>, + inputs: Vec<(&'static str, ort::SessionInputValue<'static>)>, } -impl OnnxruntimeRunContext<'_> { +impl OnnxruntimeRunContext { fn push_input( &mut self, + name: &'static str, input: Array< impl PrimitiveTensorElementType + Debug + Clone + 'static, impl Dimension + 'static, >, ) -> anyhow::Result<()> { let input = ort::Value::from_array(input)?.into(); - self.inputs.push(input); + self.inputs.push((name, input)); Ok(()) } } -impl<'sess> From<&'sess mut ort::Session> for OnnxruntimeRunContext<'sess> { - fn from(sess: &'sess mut ort::Session) -> Self { +impl From>> for OnnxruntimeRunContext { + fn from(sess: Arc>) -> Self { Self { sess, inputs: vec![], @@ -238,17 +224,45 @@ impl<'sess> From<&'sess mut ort::Session> for OnnxruntimeRunContext<'sess> { } } -impl PushInputTensor for OnnxruntimeRunContext<'_> { +impl PushInputTensor for OnnxruntimeRunContext { #[duplicate_item( method T; [ push_int64 ] [ i64 ]; [ push_float32 ] [ f32 ]; )] - fn method(&mut self, tensor: Array) -> anyhow::Result<()> { - self.push_input(tensor) + fn method( + &mut self, + name: &'static str, + tensor: Array, + ) -> anyhow::Result<()> { + self.push_input(name, tensor) } } +// FIXME: use ouroboros to reduce copies +fn extract_outputs(outputs: &ort::SessionOutputs<'_, '_>) -> anyhow::Result> { + (0..outputs.len()) + .map(|i| { + let output = &outputs[i]; + + let ValueType::Tensor { ty, .. } = output.dtype()? else { + bail!( + "unexpected output. currently `ONNX_TYPE_TENSOR` and `ONNX_TYPE_SPARSETENSOR` + is supported", + ); + }; + + match ty { + TensorElementType::Float32 => { + let output = output.try_extract_tensor::()?; + Ok(OutputTensor::Float32(output.into_owned())) + } + _ => bail!("unexpected output tensor element data type"), + } + }) + .collect() +} + pub(crate) mod blocking { use ort::EnvHandle; use ref_cast::{ref_cast_custom, RefCastCustom}; diff --git a/crates/voicevox_core/src/infer/session_set.rs b/crates/voicevox_core/src/infer/session_set.rs index e94fff962..1bac459b7 100644 --- a/crates/voicevox_core/src/infer/session_set.rs +++ b/crates/voicevox_core/src/infer/session_set.rs @@ -12,7 +12,7 @@ use super::{ }; pub(crate) struct InferenceSessionSet( - EnumMap>>, + EnumMap>, ); impl InferenceSessionSet { @@ -33,7 +33,7 @@ impl InferenceSessionSet { check_param_infos(expected_input_param_infos, &actual_input_param_infos)?; check_param_infos(expected_output_param_infos, &actual_output_param_infos)?; - Ok((op.into_usize(), std::sync::Mutex::new(sess).into())) + Ok((op.into_usize(), sess.into())) }) .collect::>>()?; @@ -84,18 +84,21 @@ impl InferenceSessionSet { } pub(crate) struct InferenceSessionCell { - inner: Arc>, + inner: Arc, marker: PhantomData, } impl InferenceSessionCell { - pub(crate) fn run( + pub(crate) async fn run( self, input: I, ) -> crate::Result<::Output> { - let inner = &mut self.inner.lock().unwrap(); - (|| R::run(input.make_run_context::(inner)?)?.try_into())() - .map_err(ErrorRepr::RunModel) - .map_err(Into::into) + async { + let ctx = input.make_run_context::(self.inner.clone())?; + A::run_session::(ctx).await?.try_into() + } + .await + .map_err(ErrorRepr::RunModel) + .map_err(Into::into) } } diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs index 64986f627..c59573412 100644 --- a/crates/voicevox_core/src/status.rs +++ b/crates/voicevox_core/src/status.rs @@ -9,6 +9,7 @@ use itertools::iproduct; use crate::{ error::{ErrorRepr, LoadModelError, LoadModelErrorKind, LoadModelResult}, infer::{ + self, domains::{inference_domain_map_values, InferenceDomainMap, TalkDomain}, session_set::{InferenceSessionCell, InferenceSessionSet}, InferenceDomain, InferenceInputSignature, InferenceRuntime, InferenceSessionOptions, @@ -104,17 +105,18 @@ impl Status { /// # Panics /// /// `self`が`model_id`を含んでいないとき、パニックする。 - pub(crate) fn run_session( + pub(crate) async fn run_session( &self, model_id: VoiceModelId, input: I, ) -> Result<::Output> where + A: infer::AsyncExt, I: InferenceInputSignature, ::Domain: InferenceDomainExt, { let sess = self.loaded_models.lock().unwrap().get(model_id); - sess.run(input) + sess.run::(input).await } } diff --git a/crates/voicevox_core/src/synthesizer.rs b/crates/voicevox_core/src/synthesizer.rs index 31aec6b72..230a26294 100644 --- a/crates/voicevox_core/src/synthesizer.rs +++ b/crates/voicevox_core/src/synthesizer.rs @@ -1,4 +1,7 @@ -use crate::asyncs::{Async, BlockingThreadPool, SingleTasked}; +use crate::{ + asyncs::{BlockingThreadPool, SingleTasked}, + infer, +}; pub use self::inner::MARGIN; @@ -70,14 +73,14 @@ pub struct InitializeOptions { pub cpu_num_threads: u16, } -trait AsyncForOnnxruntime: Async { +trait AsyncExt: infer::AsyncExt { async fn unblock(f: F) -> T where F: FnOnce() -> T + Send + 'static, T: Send + 'static; } -impl AsyncForOnnxruntime for SingleTasked { +impl AsyncExt for SingleTasked { async fn unblock(f: F) -> T where F: FnOnce() -> T + Send + 'static, @@ -87,7 +90,7 @@ impl AsyncForOnnxruntime for SingleTasked { } } -impl AsyncForOnnxruntime for BlockingThreadPool { +impl AsyncExt for BlockingThreadPool { async fn unblock(f: F) -> T where F: FnOnce() -> T + Send + 'static, @@ -108,11 +111,12 @@ mod inner { use tracing::info; use crate::{ - asyncs::{BlockingThreadPool, SingleTasked}, + asyncs::{Async, BlockingThreadPool, SingleTasked}, devices::{DeviceSpec, GpuSpec}, engine::{create_kana, mora_to_text, wav_from_s16le, Mora, OjtPhoneme}, error::ErrorRepr, infer::{ + self, domains::{ GenerateFullIntermediateInput, GenerateFullIntermediateOutput, InferenceDomainMap, PredictDurationInput, PredictDurationOutput, PredictIntonationInput, @@ -127,7 +131,7 @@ mod inner { SynthesisOptions, VoiceModelId, VoiceModelMeta, }; - use super::{AccelerationMode, AsyncForOnnxruntime, InitializeOptions, TtsOptions}; + use super::{AccelerationMode, AsyncExt, InitializeOptions, TtsOptions}; const DEFAULT_SAMPLING_RATE: u32 = 24000; /// 音が途切れてしまうのを避けるworkaround処理のためのパディング幅(フレーム数) @@ -169,7 +173,7 @@ mod inner { audio_query: AudioQuery, } - pub struct Inner { + pub struct Inner { status: Arc>, open_jtalk_analyzer: OpenJTalkAnalyzer, kana_analyzer: KanaAnalyzer, @@ -189,7 +193,7 @@ mod inner { } } - impl Inner { + impl Inner { pub(super) fn new( onnxruntime: &'static crate::blocking::Onnxruntime, open_jtalk: O, @@ -751,7 +755,7 @@ mod inner { } } - impl Inner { + impl Inner { pub(super) async fn create_accent_phrases( &self, text: &str, @@ -782,7 +786,8 @@ mod inner { } } - impl Inner { + // TODO: この層を破壊する + impl Inner { pub(super) async fn predict_duration( &self, phoneme_vector: &[i64], @@ -790,7 +795,7 @@ mod inner { ) -> Result> { let status = self.status.clone(); let phoneme_vector = ndarray::arr1(phoneme_vector); - A::unblock(move || status.predict_duration(phoneme_vector, style_id)).await + status.predict_duration::(phoneme_vector, style_id).await } #[expect( @@ -816,8 +821,8 @@ mod inner { let end_accent_vector = ndarray::arr1(end_accent_vector); let start_accent_phrase_vector = ndarray::arr1(start_accent_phrase_vector); let end_accent_phrase_vector = ndarray::arr1(end_accent_phrase_vector); - A::unblock(move || { - status.predict_intonation( + status + .predict_intonation::( length, vowel_phoneme_vector, consonant_phoneme_vector, @@ -827,8 +832,7 @@ mod inner { end_accent_phrase_vector, style_id, ) - }) - .await + .await } pub(super) async fn generate_full_intermediate( @@ -842,16 +846,9 @@ mod inner { let status = self.status.clone(); let f0 = ndarray::arr1(f0); let phoneme_vector = ndarray::arr1(phoneme_vector); - A::unblock(move || { - status.generate_full_intermediate( - length, - phoneme_size, - f0, - phoneme_vector, - style_id, - ) - }) - .await + status + .generate_full_intermediate::(length, phoneme_size, f0, phoneme_vector, style_id) + .await } pub(super) async fn render_audio_segment( @@ -860,7 +857,7 @@ mod inner { style_id: StyleId, ) -> Result> { let status = self.status.clone(); - A::unblock(move || status.render_audio_segment(spec, style_id)).await + status.render_audio_segment::(spec, style_id).await } pub(super) async fn decode( @@ -874,14 +871,14 @@ mod inner { let status = self.status.clone(); let f0 = ndarray::arr1(f0); let phoneme_vector = ndarray::arr1(phoneme_vector); - A::unblock(move || status.decode(length, phoneme_size, f0, phoneme_vector, style_id)) + status + .decode::(length, phoneme_size, f0, phoneme_vector, style_id) .await } } impl Status { - /// CPU-boundな操作なので、非同期ランタイム上では直接実行されるべきではない。 - fn predict_duration( + pub(super) async fn predict_duration( &self, phoneme_vector: ndarray::Array1, style_id: StyleId, @@ -890,13 +887,15 @@ mod inner { let PredictDurationOutput { phoneme_length: output, - } = self.run_session( - model_id, - PredictDurationInput { - phoneme_list: phoneme_vector, - speaker_id: ndarray::arr1(&[inner_voice_id.raw_id().into()]), - }, - )?; + } = self + .run_session::( + model_id, + PredictDurationInput { + phoneme_list: phoneme_vector, + speaker_id: ndarray::arr1(&[inner_voice_id.raw_id().into()]), + }, + ) + .await?; let mut output = output.into_raw_vec(); for output_item in output.iter_mut() { @@ -910,13 +909,12 @@ mod inner { const PHONEME_LENGTH_MINIMAL: f32 = 0.01; } - /// CPU-boundな操作なので、非同期ランタイム上では直接実行されるべきではない。 #[expect( clippy::too_many_arguments, reason = "compatible_engineでの`predict_intonation`の形を考えると、ここの引数を構造体に\ まとめたりしても可読性に寄与しない" )] - fn predict_intonation( + pub(super) async fn predict_intonation( &self, length: usize, vowel_phoneme_vector: ndarray::Array1, @@ -929,19 +927,21 @@ mod inner { ) -> Result> { let (model_id, inner_voice_id) = self.ids_for::(style_id)?; - let PredictIntonationOutput { f0_list: output } = self.run_session( - model_id, - PredictIntonationInput { - length: ndarray::arr0(length as i64), - vowel_phoneme_list: vowel_phoneme_vector, - consonant_phoneme_list: consonant_phoneme_vector, - start_accent_list: start_accent_vector, - end_accent_list: end_accent_vector, - start_accent_phrase_list: start_accent_phrase_vector, - end_accent_phrase_list: end_accent_phrase_vector, - speaker_id: ndarray::arr1(&[inner_voice_id.raw_id().into()]), - }, - )?; + let PredictIntonationOutput { f0_list: output } = self + .run_session::( + model_id, + PredictIntonationInput { + length: ndarray::arr0(length as i64), + vowel_phoneme_list: vowel_phoneme_vector, + consonant_phoneme_list: consonant_phoneme_vector, + start_accent_list: start_accent_vector, + end_accent_list: end_accent_vector, + start_accent_phrase_list: start_accent_phrase_vector, + end_accent_phrase_list: end_accent_phrase_vector, + speaker_id: ndarray::arr1(&[inner_voice_id.raw_id().into()]), + }, + ) + .await?; Ok(output.into_raw_vec()) } @@ -949,9 +949,7 @@ mod inner { /// モデル`generate_full_intermediate`の実行と、その前後の処理を行う。 /// /// 無音パディングを付加して音声特徴量を計算し、マージン込みの音声特徴量を返す。 - /// - /// CPU-boundな操作なので、非同期ランタイム上では直接実行されるべきではない。 - fn generate_full_intermediate( + pub(super) async fn generate_full_intermediate( &self, length: usize, phoneme_size: usize, @@ -973,16 +971,18 @@ mod inner { let GenerateFullIntermediateOutput { spec: spec_with_padding, - } = self.run_session( - model_id, - GenerateFullIntermediateInput { - f0: f0_with_padding - .into_shape([length_with_padding, 1]) - .unwrap(), - phoneme: phoneme_with_padding, - speaker_id: ndarray::arr1(&[inner_voice_id.raw_id().into()]), - }, - )?; + } = self + .run_session::( + model_id, + GenerateFullIntermediateInput { + f0: f0_with_padding + .into_shape([length_with_padding, 1]) + .unwrap(), + phoneme: phoneme_with_padding, + speaker_id: ndarray::arr1(&[inner_voice_id.raw_id().into()]), + }, + ) + .await?; // マージンがデータからはみ出さないことを保証 // cf. https://github.com/VOICEVOX/voicevox_core/pull/854#discussion_r1803691291 @@ -1024,20 +1024,19 @@ mod inner { } /// 与えられた音声特徴量で音声生成。 - /// CPU/GPU-boundな操作なので、非同期ランタイム上では直接実行されるべきではない。 - fn render_audio_segment( + pub(super) async fn render_audio_segment( &self, spec: ndarray::Array2, style_id: StyleId, ) -> Result> { let (model_id, _inner_voice_id) = self.ids_for::(style_id)?; - let RenderAudioSegmentOutput { wave } = - self.run_session(model_id, RenderAudioSegmentInput { spec })?; + let RenderAudioSegmentOutput { wave } = self + .run_session::(model_id, RenderAudioSegmentInput { spec }) + .await?; Ok(wave) } - /// CPU/GPU-boundな操作なので、非同期ランタイム上では直接実行されるべきではない。 - fn decode( + pub(super) async fn decode( &self, length: usize, phoneme_size: usize, @@ -1045,14 +1044,12 @@ mod inner { phoneme_vector: ndarray::Array1, style_id: StyleId, ) -> Result> { - let intermediate = self.generate_full_intermediate( - length, - phoneme_size, - f0, - phoneme_vector, - style_id, - )?; - let output_with_margin = self.render_audio_segment(intermediate, style_id)?; + let intermediate = self + .generate_full_intermediate::(length, phoneme_size, f0, phoneme_vector, style_id) + .await?; + let output_with_margin = self + .render_audio_segment::(intermediate, style_id) + .await?; let output = trim_margin_from_wave(output_with_margin); Ok(output.to_vec()) } diff --git a/crates/voicevox_core_macros/src/inference_domain.rs b/crates/voicevox_core_macros/src/inference_domain.rs index f959982e4..6570791d2 100644 --- a/crates/voicevox_core_macros/src/inference_domain.rs +++ b/crates/voicevox_core_macros/src/inference_domain.rs @@ -222,17 +222,22 @@ pub(crate) fn derive_inference_input_signature( fn make_run_context( self, - sess: &mut R::Session, - ) -> ::anyhow::Result> { - let mut ctx = as ::std::convert::From<_>>::from(sess); + sess: ::std::sync::Arc, + ) -> ::anyhow::Result { + let mut ctx = >::from(sess); #( - __ArrayExt::push_to_ctx(self.#field_names, &mut ctx)?; + __ArrayExt::push_to_ctx( + self.#field_names, + ::std::stringify!(#field_names), + &mut ctx, + )?; )* return ::std::result::Result::Ok(ctx); trait __ArrayExt { fn push_to_ctx( self, + name: &'static str, ctx: &mut impl crate::infer::PushInputTensor, ) -> ::anyhow::Result<()>; } @@ -242,9 +247,10 @@ pub(crate) fn derive_inference_input_signature( { fn push_to_ctx( self, + name: &'static str, ctx: &mut impl crate::infer::PushInputTensor, ) -> ::anyhow::Result<()> { - A::push_tensor_to_ctx(self, ctx) + A::push_tensor_to_ctx(name, self, ctx) } } } diff --git a/crates/voicevox_core_python_api/python/test/test_asyncio_user_dict_load.py b/crates/voicevox_core_python_api/python/test/test_asyncio_user_dict_load.py index e433a7442..135822ca7 100644 --- a/crates/voicevox_core_python_api/python/test/test_asyncio_user_dict_load.py +++ b/crates/voicevox_core_python_api/python/test/test_asyncio_user_dict_load.py @@ -6,6 +6,7 @@ # AudioQueryのkanaを比較して変化するかどうかで判断する。 +import multiprocessing from uuid import UUID import conftest # noqa: F401 @@ -20,7 +21,13 @@ async def test_user_dict_load() -> None: ) open_jtalk = await voicevox_core.asyncio.OpenJtalk.new(conftest.open_jtalk_dic_dir) model = await voicevox_core.asyncio.VoiceModelFile.open(conftest.model_dir) - synthesizer = voicevox_core.asyncio.Synthesizer(onnxruntime, open_jtalk) + synthesizer = voicevox_core.asyncio.Synthesizer( + onnxruntime, + open_jtalk, + cpu_num_threads=max( + multiprocessing.cpu_count(), 2 + ), # https://github.com/VOICEVOX/voicevox_core/issues/888 + ) await synthesizer.load_voice_model(model) diff --git a/example/python/run-asyncio.py b/example/python/run-asyncio.py index 0d7460f90..3cb8a7e48 100644 --- a/example/python/run-asyncio.py +++ b/example/python/run-asyncio.py @@ -4,6 +4,7 @@ import dataclasses import json import logging +import multiprocessing from argparse import ArgumentParser from pathlib import Path @@ -91,7 +92,12 @@ async def main() -> None: logger.info("%s", f"Initializing ({args.mode=}, {args.dict_dir=})") synthesizer = Synthesizer( - onnxruntime, await OpenJtalk.new(args.dict_dir), acceleration_mode=args.mode + onnxruntime, + await OpenJtalk.new(args.dict_dir), + acceleration_mode=args.mode, + cpu_num_threads=max( + multiprocessing.cpu_count(), 2 + ), # https://github.com/VOICEVOX/voicevox_core/issues/888 ) logger.debug("%s", f"{synthesizer.metas=}") diff --git a/example/python/run.py b/example/python/run.py index 4a22e709c..5970f5dff 100644 --- a/example/python/run.py +++ b/example/python/run.py @@ -1,6 +1,7 @@ import dataclasses import json import logging +import multiprocessing from argparse import ArgumentParser from pathlib import Path @@ -95,7 +96,12 @@ def main() -> None: logger.info("%s", f"Initializing ({args.mode=}, {args.dict_dir=})") synthesizer = Synthesizer( - onnxruntime, OpenJtalk(args.dict_dir), acceleration_mode=args.mode + onnxruntime, + OpenJtalk(args.dict_dir), + acceleration_mode=args.mode, + cpu_num_threads=max( + multiprocessing.cpu_count(), 2 + ), # https://github.com/VOICEVOX/voicevox_core/issues/888 ) logger.debug("%s", f"{synthesizer.metas=}")