diff --git a/Cargo.lock b/Cargo.lock index e33dbc954..69b5d78ad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1230,6 +1230,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fb2a23ad36148a32085addb3ef1aa39805d044d4532ff258360d523a4eff38e5" dependencies = [ "enum-map-derive", + "serde", ] [[package]] diff --git a/crates/voicevox_core/Cargo.toml b/crates/voicevox_core/Cargo.toml index 82ada11be..984b62ec5 100644 --- a/crates/voicevox_core/Cargo.toml +++ b/crates/voicevox_core/Cargo.toml @@ -29,7 +29,7 @@ derive_more = { workspace = true, features = ["add", "deref", "display", "from", duplicate.workspace = true easy-ext.workspace = true educe.workspace = true -enum-map.workspace = true +enum-map = { workspace = true, features = ["serde"] } fs-err.workspace = true futures-io.workspace = true futures-lite.workspace = true diff --git a/crates/voicevox_core/src/manifest.rs b/crates/voicevox_core/src/manifest.rs index 0b1a005cd..740254f6f 100644 --- a/crates/voicevox_core/src/manifest.rs +++ b/crates/voicevox_core/src/manifest.rs @@ -1,13 +1,14 @@ use std::{ collections::BTreeMap, fmt::{self, Display}, + ops::Index, sync::Arc, }; use derive_getters::Getters; use derive_more::Deref; use derive_new::new; -use macros::IndexForFields; +use enum_map::{Enum, EnumMap}; use serde::{de, Deserialize, Deserializer, Serialize}; use serde_with::{serde_as, DisplayFromStr}; @@ -81,24 +82,43 @@ pub struct Manifest { pub(crate) type ManifestDomains = inference_domain_map_values!(for Option); -#[derive(Deserialize, IndexForFields)] +#[derive(Deserialize)] #[cfg_attr(test, derive(Default))] -#[index_for_fields(TalkOperation)] pub(crate) struct TalkManifest { - #[index_for_fields(TalkOperation::PredictDuration)] - pub(crate) predict_duration_filename: Arc, + #[serde(flatten)] + filenames: EnumMap>, - #[index_for_fields(TalkOperation::PredictIntonation)] - pub(crate) predict_intonation_filename: Arc, + #[serde(default)] + pub(crate) style_id_to_inner_voice_id: StyleIdToInnerVoiceId, +} - #[index_for_fields(TalkOperation::GenerateFullIntermediate)] - pub(crate) generate_full_intermediate_filename: Arc, +// TODO: #825 では`TalkOperation`と統合する。`Index`の実装もderive_moreで委譲する +#[derive(Enum, Deserialize)] +pub(crate) enum TalkOperationFilenameKey { + #[serde(rename = "predict_duration_filename")] + PredictDuration, + #[serde(rename = "predict_intonation_filename")] + PredictIntonation, + #[serde(rename = "generate_full_intermediate_filename")] + GenerateFullIntermediate, + #[serde(rename = "render_audio_segment_filename")] + RenderAudioSegment, +} - #[index_for_fields(TalkOperation::RenderAudioSegment)] - pub(crate) render_audio_segment_filename: Arc, +impl Index for TalkManifest { + type Output = Arc; - #[serde(default)] - pub(crate) style_id_to_inner_voice_id: StyleIdToInnerVoiceId, + fn index(&self, index: TalkOperation) -> &Self::Output { + let key = match index { + TalkOperation::PredictDuration => TalkOperationFilenameKey::PredictDuration, + TalkOperation::PredictIntonation => TalkOperationFilenameKey::PredictIntonation, + TalkOperation::GenerateFullIntermediate => { + TalkOperationFilenameKey::GenerateFullIntermediate + } + TalkOperation::RenderAudioSegment => TalkOperationFilenameKey::RenderAudioSegment, + }; + &self.filenames[key] + } } #[serde_as] diff --git a/crates/voicevox_core/src/voice_model.rs b/crates/voicevox_core/src/voice_model.rs index 9914409eb..d0dde486f 100644 --- a/crates/voicevox_core/src/voice_model.rs +++ b/crates/voicevox_core/src/voice_model.rs @@ -11,7 +11,7 @@ use std::{ use anyhow::{anyhow, Context as _}; use derive_more::From; use easy_ext::ext; -use enum_map::{enum_map, EnumMap}; +use enum_map::{Enum, EnumMap}; use futures_io::{AsyncBufRead, AsyncRead, AsyncSeek}; use futures_util::future::{OptionFuture, TryFutureExt as _}; use itertools::Itertools as _; @@ -23,7 +23,7 @@ use crate::{ asyncs::{Async, Mutex as _}, error::{LoadModelError, LoadModelErrorKind, LoadModelResult}, infer::{ - domains::{inference_domain_map_values, InferenceDomainMap, TalkDomain, TalkOperation}, + domains::{inference_domain_map_values, InferenceDomainMap, TalkDomain}, InferenceDomain, }, manifest::{Manifest, ManifestDomains, StyleIdToInnerVoiceId}, @@ -128,7 +128,7 @@ impl Inner { let header = VoiceModelHeader::new(manifest, metas, path)?.into(); - InnerTryBuilder { + return InnerTryBuilder { header, inference_model_entries_builder: |header| { let VoiceModelHeader { manifest, .. } = &**header; @@ -139,21 +139,8 @@ impl Inner { talk: |talk| { talk.as_ref() .map(|manifest| { - let indices = enum_map! { - TalkOperation::PredictDuration => { - find_entry_index(&manifest.predict_duration_filename)? - } - TalkOperation::PredictIntonation => { - find_entry_index(&manifest.predict_intonation_filename)? - } - TalkOperation::GenerateFullIntermediate => { - find_entry_index(&manifest.generate_full_intermediate_filename)? - } - TalkOperation::RenderAudioSegment => { - find_entry_index(&manifest.render_audio_segment_filename)? - } - }; - + let indices = EnumMap::from_fn(|k| &manifest[k]) + .try_map(|_, s| find_entry_index(s))?; Ok(InferenceModelEntry { indices, manifest }) }) .transpose() @@ -172,7 +159,26 @@ impl Inner { }, zip: zip.into_inner().into_inner().into(), } - .try_build() + .try_build(); + + #[ext] + impl EnumMap { + fn try_map( + self, + f: impl FnMut(K, V) -> Result, + ) -> Result, E> { + let mut elems = self + .map(f) + .into_iter() + .map(|(_, r)| r.map(Some)) + .collect::, _>>()?; + + Ok(EnumMap::::from_fn(|key| { + let key = key.into_usize(); + elems[key].take().expect("each `key` should be distinct") + })) + } + } } fn id(&self) -> VoiceModelId { diff --git a/crates/voicevox_core_macros/src/lib.rs b/crates/voicevox_core_macros/src/lib.rs index e96456ea8..f9362d8c9 100644 --- a/crates/voicevox_core_macros/src/lib.rs +++ b/crates/voicevox_core_macros/src/lib.rs @@ -3,7 +3,6 @@ mod extract; mod inference_domain; mod inference_domains; -mod manifest; use syn::parse_macro_input; @@ -103,38 +102,6 @@ pub fn derive_inference_output_signature( from_syn(inference_domain::derive_inference_output_signature(input)) } -/// 構造体のフィールドを取得できる`std::ops::Index`の実装を導出する。 -/// -/// # Example -/// -/// ``` -/// use macros::IndexForFields; -/// -/// #[derive(IndexForFields)] -/// #[index_for_fields(TalkOperation)] -/// pub(crate) struct TalkManifest { -/// #[index_for_fields(TalkOperation::PredictDuration)] -/// pub(crate) predict_duration_filename: Arc, -/// -/// #[index_for_fields(TalkOperation::PredictIntonation)] -/// pub(crate) predict_intonation_filename: Arc, -/// -/// #[index_for_fields(TalkOperation::GenerateFullIntermediate)] -/// pub(crate) generate_full_intermediate_filename: Arc, -/// -/// #[index_for_fields(TalkOperation::RenderAudioSegment)] -/// pub(crate) render_audio_segment_filename: Arc, -/// -/// // … -/// } -/// ``` -#[cfg(not(doctest))] -#[proc_macro_derive(IndexForFields, attributes(index_for_fields))] -pub fn derive_index_for_fields(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - let input = &parse_macro_input!(input); - from_syn(manifest::derive_index_for_fields(input)) -} - /// # Example /// /// ``` diff --git a/crates/voicevox_core_macros/src/manifest.rs b/crates/voicevox_core_macros/src/manifest.rs deleted file mode 100644 index 9560b1fd4..000000000 --- a/crates/voicevox_core_macros/src/manifest.rs +++ /dev/null @@ -1,72 +0,0 @@ -use proc_macro2::Span; -use quote::quote; -use syn::{Attribute, DeriveInput, Expr, Meta, Type}; - -pub(crate) fn derive_index_for_fields( - input: &DeriveInput, -) -> syn::Result { - const ATTR_NAME: &str = "index_for_fields"; - - let DeriveInput { - attrs, - ident, - generics, - data, - .. - } = input; - - let idx = attrs - .iter() - .find_map(|Attribute { meta, .. }| match meta { - Meta::List(list) if list.path.is_ident(ATTR_NAME) => Some(list), - _ => None, - }) - .ok_or_else(|| { - syn::Error::new( - Span::call_site(), - format!("missing `#[{ATTR_NAME}(…)]` in the struct itself"), - ) - })? - .parse_args::()?; - - let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - - let targets = crate::extract::struct_fields(data)? - .into_iter() - .flat_map(|(attrs, name, output)| { - let meta = attrs.iter().find_map(|Attribute { meta, .. }| match meta { - Meta::List(meta) if meta.path.is_ident(ATTR_NAME) => Some(meta), - _ => None, - })?; - Some((meta, name, output)) - }) - .map(|(meta, name, output)| { - let key = meta.parse_args::()?; - Ok((key, name, output)) - }) - .collect::>>()?; - - let (_, _, output) = targets.first().ok_or_else(|| { - syn::Error::new( - Span::call_site(), - format!("no fields have `#[{ATTR_NAME}(…)]`"), - ) - })?; - - let arms = targets - .iter() - .map(|(key, name, _)| Ok(quote!(#key => &self.#name))) - .collect::>>()?; - - Ok(quote! { - impl #impl_generics ::std::ops::Index<#idx> for #ident #ty_generics #where_clause { - type Output = #output; - - fn index(&self, index: #idx) -> &Self::Output { - match index { - #(#arms),* - } - } - } - }) -}