Skip to content

Commit

Permalink
InferenceModelsの定義をsignaturesに移動
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Nov 5, 2023
1 parent f959911 commit 55b04d3
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 19 deletions.
6 changes: 6 additions & 0 deletions crates/voicevox_core/src/infer/signatures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ use ndarray::{Array0, Array1, Array2};

use crate::infer::{InferenceRuntime, RunBuilder, Signature, TypedSession};

pub(crate) struct ModelBytesSet {
pub(crate) predict_duration: Vec<u8>,
pub(crate) predict_intonation: Vec<u8>,
pub(crate) decode: Vec<u8>,
}

pub(crate) struct SessionSet<R: InferenceRuntime> {
pub(crate) predict_duration: Arc<std::sync::Mutex<TypedSession<R, PredictDuration>>>,
pub(crate) predict_intonation: Arc<std::sync::Mutex<TypedSession<R, PredictIntonation>>>,
Expand Down
11 changes: 4 additions & 7 deletions crates/voicevox_core/src/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,17 @@ impl<R: InferenceRuntime> Status<R> {
let models = model.read_inference_models().await?;

let predict_duration_session = self.new_session(
models.predict_duration_model(),
&models.predict_duration,
&self.light_session_options,
model.path(),
)?;
let predict_intonation_session = self.new_session(
models.predict_intonation_model(),
&models.predict_intonation,
&self.light_session_options,
model.path(),
)?;
let decode_model = self.new_session(
models.decode_model(),
&self.heavy_session_options,
model.path(),
)?;
let decode_model =
self.new_session(&models.decode, &self.heavy_session_options, model.path())?;

self.loaded_models.lock().unwrap().insert(
model,
Expand Down
18 changes: 6 additions & 12 deletions crates/voicevox_core/src/voice_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use futures::future::join3;
use serde::{de::DeserializeOwned, Deserialize};

use super::*;
use crate::infer::signatures::ModelBytesSet;
use std::{
collections::{BTreeMap, HashMap},
io,
Expand Down Expand Up @@ -35,15 +36,8 @@ pub struct VoiceModel {
path: PathBuf,
}

#[derive(Getters)]
pub(crate) struct InferenceModels {
decode_model: Vec<u8>,
predict_duration_model: Vec<u8>,
predict_intonation_model: Vec<u8>,
}

impl VoiceModel {
pub(crate) async fn read_inference_models(&self) -> LoadModelResult<InferenceModels> {
pub(crate) async fn read_inference_models(&self) -> LoadModelResult<ModelBytesSet> {
let reader = VvmEntryReader::open(&self.path).await?;
let (decode_model_result, predict_duration_model_result, predict_intonation_model_result) =
join3(
Expand All @@ -53,10 +47,10 @@ impl VoiceModel {
)
.await;

Ok(InferenceModels {
predict_duration_model: predict_duration_model_result?,
predict_intonation_model: predict_intonation_model_result?,
decode_model: decode_model_result?,
Ok(ModelBytesSet {
predict_duration: predict_duration_model_result?,
predict_intonation: predict_intonation_model_result?,
decode: decode_model_result?,
})
}
/// VVMファイルから`VoiceModel`をコンストラクトする。
Expand Down

0 comments on commit 55b04d3

Please sign in to comment.