Skip to content

Commit

Permalink
Minor refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Mar 10, 2024
1 parent ad47904 commit c272d3d
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 24 deletions.
16 changes: 9 additions & 7 deletions crates/voicevox_core/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@ mod model_file;
pub(crate) mod runtimes;
pub(crate) mod status;

use std::{
borrow::Cow, collections::BTreeSet, convert::Infallible, fmt::Debug, marker::PhantomData,
};
use std::{borrow::Cow, collections::BTreeSet, fmt::Debug};

use derive_new::new;
use duplicate::duplicate_item;
Expand Down Expand Up @@ -34,7 +32,7 @@ pub(crate) trait InferenceRuntime: 'static {
fn run(ctx: Self::RunContext<'_>) -> anyhow::Result<Vec<OutputTensor>>;
}

pub(crate) trait InferenceDomainGroup {
pub(crate) trait InferenceDomainGroup: Sized {
type Map<A: InferenceDomainAssociation>: InferenceDomainMap<A, Group = Self>;
}

Expand Down Expand Up @@ -64,7 +62,7 @@ pub(crate) trait InferenceDomainAssociationTargetPredicate {
}

pub(crate) trait ConvertInferenceDomainAssociationTarget<
G: InferenceDomainGroup + ?Sized,
G: InferenceDomainGroup,
A1: InferenceDomainAssociation,
A2: InferenceDomainAssociation,
E,
Expand All @@ -80,9 +78,13 @@ pub(crate) trait InferenceDomainAssociation {
type Target<D: InferenceDomain>;
}

pub(crate) struct Optional<A>(Infallible, PhantomData<fn() -> A>);
impl<A1: InferenceDomainAssociation, A2: InferenceDomainAssociation> InferenceDomainAssociation
for (A1, A2)
{
type Target<D: InferenceDomain> = (A1::Target<D>, A2::Target<D>);
}

impl<A: InferenceDomainAssociation> InferenceDomainAssociation for Optional<A> {
impl<A: InferenceDomainAssociation> InferenceDomainAssociation for Option<A> {
type Target<D: InferenceDomain> = Option<A::Target<D>>;
}

Expand Down
32 changes: 19 additions & 13 deletions crates/voicevox_core/src/infer/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::{
use super::{
model_file, InferenceDomain, InferenceDomainAssociationTargetPredicate,
InferenceDomainMap as _, InferenceInputSignature, InferenceRuntime, InferenceSessionOptions,
InferenceSignature, Optional,
InferenceSignature,
};

pub(crate) struct Status<R: InferenceRuntime, G: InferenceDomainGroup> {
Expand All @@ -47,7 +47,7 @@ impl<R: InferenceRuntime, G: InferenceDomainGroup> Status<R, G> {
pub(crate) fn insert_model(
&self,
model_header: &VoiceModelHeader,
model_bytes: &G::Map<Optional<ModelDataByInferenceDomain>>,
model_bytes: &G::Map<Option<ModelDataByInferenceDomain>>,
) -> Result<()> {
self.loaded_models
.lock()
Expand Down Expand Up @@ -79,16 +79,16 @@ impl<R: InferenceRuntime, G: InferenceDomainGroup> Status<R, G> {
impl<R: InferenceRuntime, G: InferenceDomainGroup>
ConvertInferenceDomainAssociationTarget<
G,
Optional<ModelDataByInferenceDomain>,
Optional<ModelInnerIdsAndSessionSetByDomain<R>>,
Option<ModelDataByInferenceDomain>,
Option<(ModelInnerIdsByDomain, SessionSetByDomain<R>)>,
anyhow::Error,
> for CreateSessionSet<'_, R, G>
{
fn try_ref_map<D: InferenceDomain<Group = G>>(
&self,
model_data: &<Optional<ModelDataByInferenceDomain> as InferenceDomainAssociation>::Target<D>,
model_data: &<Option<ModelDataByInferenceDomain> as InferenceDomainAssociation>::Target<D>,
) -> anyhow::Result<
<Optional<ModelInnerIdsAndSessionSetByDomain<R>> as InferenceDomainAssociation>::Target<D>,
<Option<(ModelInnerIdsByDomain, SessionSetByDomain<R>)> as InferenceDomainAssociation>::Target<D>,
>{
model_data
.as_ref()
Expand Down Expand Up @@ -174,7 +174,7 @@ struct LoadedModels<R: InferenceRuntime, G: InferenceDomainGroup>(

struct LoadedModel<R: InferenceRuntime, G: InferenceDomainGroup> {
metas: VoiceModelMeta,
by_domain: G::Map<Optional<ModelInnerIdsAndSessionSetByDomain<R>>>,
by_domain: G::Map<Option<(ModelInnerIdsByDomain, SessionSetByDomain<R>)>>,
}

impl<R: InferenceRuntime, G: InferenceDomainGroup> LoadedModels<R, G> {
Expand Down Expand Up @@ -254,7 +254,7 @@ impl<R: InferenceRuntime, G: InferenceDomainGroup> LoadedModels<R, G> {
fn ensure_acceptable(
&self,
model_header: &VoiceModelHeader,
model_bytes_or_sessions: &G::Map<Optional<impl InferenceDomainAssociation>>,
model_bytes_or_sessions: &G::Map<Option<impl InferenceDomainAssociation>>,
) -> LoadModelResult<()> {
let error = |context| LoadModelError {
path: model_header.path.clone(),
Expand Down Expand Up @@ -312,7 +312,7 @@ impl<R: InferenceRuntime, G: InferenceDomainGroup> LoadedModels<R, G> {
impl<A: InferenceDomainAssociation> InferenceDomainAssociationTargetPredicate
for ContainsForStyleType<A>
{
type Association = Optional<A>;
type Association = Option<A>;

fn test<D: InferenceDomain>(
&self,
Expand All @@ -326,7 +326,7 @@ impl<R: InferenceRuntime, G: InferenceDomainGroup> LoadedModels<R, G> {
fn insert(
&mut self,
model_header: &VoiceModelHeader,
session_sets: G::Map<Optional<ModelInnerIdsAndSessionSetByDomain<R>>>,
session_sets: G::Map<Option<(ModelInnerIdsByDomain, SessionSetByDomain<R>)>>,
) -> Result<()> {
self.ensure_acceptable(model_header, &session_sets)?;

Expand Down Expand Up @@ -453,10 +453,16 @@ impl InferenceDomainAssociation for SessionOptionsByDomain {
type Target<D: InferenceDomain> = EnumMap<D::Operation, InferenceSessionOptions>;
}

struct ModelInnerIdsAndSessionSetByDomain<R>(Infallible, PhantomData<fn() -> R>);
enum ModelInnerIdsByDomain {}

impl<R: InferenceRuntime> InferenceDomainAssociation for ModelInnerIdsAndSessionSetByDomain<R> {
type Target<D: InferenceDomain> = (BTreeMap<StyleId, ModelInnerId>, SessionSet<R, D>);
impl InferenceDomainAssociation for ModelInnerIdsByDomain {
type Target<D: InferenceDomain> = BTreeMap<StyleId, ModelInnerId>;
}

struct SessionSetByDomain<R>(Infallible, PhantomData<fn() -> R>);

impl<R: InferenceRuntime> InferenceDomainAssociation for SessionSetByDomain<R> {
type Target<D: InferenceDomain> = SessionSet<R, D>;
}

#[cfg(test)]
Expand Down
8 changes: 4 additions & 4 deletions crates/voicevox_core/src/voice_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ pub(crate) mod blocking {

use crate::{
error::{LoadModelError, LoadModelErrorKind, LoadModelResult},
infer::{domains::InferenceDomainMapImpl, Optional},
infer::domains::InferenceDomainMapImpl,
manifest::{Manifest, TalkManifest},
VoiceModelMeta,
};
Expand All @@ -90,7 +90,7 @@ pub(crate) mod blocking {
impl self::VoiceModel {
pub(crate) fn read_inference_models(
&self,
) -> LoadModelResult<InferenceDomainMapImpl<Optional<ModelDataByInferenceDomain>>> {
) -> LoadModelResult<InferenceDomainMapImpl<Option<ModelDataByInferenceDomain>>> {
let reader = BlockingVvmEntryReader::open(&self.header.path)?;

let talk = self
Expand Down Expand Up @@ -221,7 +221,7 @@ pub(crate) mod tokio {

use crate::{
error::{LoadModelError, LoadModelErrorKind, LoadModelResult},
infer::{domains::InferenceDomainMapImpl, Optional},
infer::domains::InferenceDomainMapImpl,
manifest::{Manifest, TalkManifest},
Result, VoiceModelMeta,
};
Expand All @@ -239,7 +239,7 @@ pub(crate) mod tokio {
impl self::VoiceModel {
pub(crate) async fn read_inference_models(
&self,
) -> LoadModelResult<InferenceDomainMapImpl<Optional<ModelDataByInferenceDomain>>> {
) -> LoadModelResult<InferenceDomainMapImpl<Option<ModelDataByInferenceDomain>>> {
let reader = AsyncVvmEntryReader::open(&self.header.path).await?;

let talk = OptionFuture::from(self.header.manifest.talk().as_ref().map(
Expand Down

0 comments on commit c272d3d

Please sign in to comment.