Skip to content

Commit

Permalink
signaturesのマクロ化を完了させる
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Nov 13, 2023
1 parent 2274a34 commit b7d48f3
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 81 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ derive_more = "0.99.17"
easy-ext = "1.0.1"
fs-err = { version = "2.9.0", features = ["tokio"] }
futures = "0.3.26"
indexmap = { version = "2.0.0", features = ["serde"] }
itertools = "0.10.5"
ndarray = "0.15.6"
once_cell = "1.18.0"
Expand Down
2 changes: 1 addition & 1 deletion crates/voicevox_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ educe = "0.4.23"
enum-map = "3.0.0-beta.1"
fs-err.workspace = true
futures.workspace = true
indexmap = { version = "2.0.0", features = ["serde"] }
indexmap.workspace = true
itertools.workspace = true
nanoid = "0.4.0"
ndarray.workspace = true
Expand Down
85 changes: 27 additions & 58 deletions crates/voicevox_core/src/infer/signatures.rs
Original file line number Diff line number Diff line change
@@ -1,51 +1,34 @@
use enum_map::{Enum, EnumMap};
use macros::{InferenceInputSignature, InferenceOutputSignature};
use enum_map::Enum;
use macros::{InferenceGroup, InferenceInputSignature, InferenceOutputSignature};
use ndarray::{Array0, Array1, Array2};

use super::{
InferenceGroup, InferenceInputSignature as _, InferenceOutputSignature as _,
InferenceSignature, OutputTensor,
};
use super::{InferenceInputSignature as _, InferenceOutputSignature as _, OutputTensor};

#[derive(Clone, Copy, Enum)]
#[derive(Clone, Copy, Enum, InferenceGroup)]
pub(crate) enum InferenceKind {
#[inference_group(
type Input = PredictDurationInput;
type Output = PredictDurationOutput;
)]
PredictDuration,
PredictIntonation,
Decode,
}

// FIXME: ここもマクロ化する
impl InferenceGroup for InferenceKind {
const INPUT_PARAM_INFOS: enum_map::EnumMap<
Self,
&'static [super::ParamInfo<super::InputScalarKind>],
> = EnumMap::from_array([
PredictDurationInput::PARAM_INFOS,
PredictIntonationInput::PARAM_INFOS,
DecodeInput::PARAM_INFOS,
]);

const OUTPUT_PARAM_INFOS: enum_map::EnumMap<
Self,
&'static [super::ParamInfo<super::OutputScalarKind>],
> = EnumMap::from_array([
PredictDurationOutput::PARAM_INFOS,
PredictIntonationOutput::PARAM_INFOS,
DecodeOutput::PARAM_INFOS,
]);
}

pub(crate) enum PredictDuration {}
#[inference_group(
type Input = PredictIntonationInput;
type Output = PredictIntonationOutput;
)]
PredictIntonation,

impl InferenceSignature for PredictDuration {
type Group = InferenceKind;
type Input = PredictDurationInput;
type Output = PredictDurationOutput;
const KIND: InferenceKind = InferenceKind::PredictDuration;
#[inference_group(
type Input = DecodeInput;
type Output = DecodeOutput;
)]
Decode,
}

#[derive(InferenceInputSignature)]
#[input_signature(Signature = PredictDuration)]
#[inference_input_signature(
type Signature = PredictDuration;
)]
pub(crate) struct PredictDurationInput {
pub(crate) phoneme_list: Array1<i64>,
pub(crate) speaker_id: Array1<i64>,
Expand All @@ -56,17 +39,10 @@ pub(crate) struct PredictDurationOutput {
pub(crate) phoneme_length: Array1<f32>,
}

pub(crate) enum PredictIntonation {}

impl InferenceSignature for PredictIntonation {
type Group = InferenceKind;
type Input = PredictIntonationInput;
type Output = PredictIntonationOutput;
const KIND: InferenceKind = InferenceKind::PredictIntonation;
}

#[derive(InferenceInputSignature)]
#[input_signature(Signature = PredictIntonation)]
#[inference_input_signature(
type Signature = PredictIntonation;
)]
pub(crate) struct PredictIntonationInput {
pub(crate) length: Array0<i64>,
pub(crate) vowel_phoneme_list: Array1<i64>,
Expand All @@ -83,17 +59,10 @@ pub(crate) struct PredictIntonationOutput {
pub(crate) f0_list: Array1<f32>,
}

pub(crate) enum Decode {}

impl InferenceSignature for Decode {
type Group = InferenceKind;
type Input = DecodeInput;
type Output = DecodeOutput;
const KIND: InferenceKind = InferenceKind::Decode;
}

#[derive(InferenceInputSignature)]
#[input_signature(Signature = Decode)]
#[inference_input_signature(
type Signature = Decode;
)]
pub(crate) struct DecodeInput {
pub(crate) f0: Array2<f32>,
pub(crate) phoneme: Array2<f32>,
Expand Down
3 changes: 2 additions & 1 deletion crates/voicevox_core_macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ name = "macros"
proc-macro = true

[dependencies]
indexmap.workspace = true
proc-macro2 = "1.0.69"
quote = "1.0.33"
syn = "2.0.38"
syn = { version = "2.0.38", features = ["extra-traits"] }
178 changes: 157 additions & 21 deletions crates/voicevox_core_macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,142 @@
#![warn(rust_2018_idioms)]

use indexmap::IndexMap;
use quote::quote;
use syn::{
parse::{Parse, ParseStream},
parse_macro_input,
spanned::Spanned as _,
Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields, Token, Type,
Attribute, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields, Generics,
ItemType, Type, Variant,
};

#[proc_macro_derive(InferenceGroup)]
#[proc_macro_derive(InferenceGroup, attributes(inference_group))]
pub fn derive_inference_group(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let DeriveInput {
ident, generics, ..
} = parse_macro_input!(input as DeriveInput);
return derive_inference_group(&parse_macro_input!(input))
.unwrap_or_else(|e| e.to_compile_error())
.into();

fn derive_inference_group(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
let DeriveInput {
vis,
ident: group_name,
generics,
data,
..
} = input;

deny_generics(generics)?;

let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let variants = unit_enum_variants(data)?
.into_iter()
.map(|(attrs, variant_name)| {
let AssocTypes { input, output } = attrs
.iter()
.find(|a| a.path().is_ident("inference_group"))
.ok_or_else(|| {
syn::Error::new(
proc_macro2::Span::call_site(),
"missing `#[inference_group(…)]`",
)
})?
.parse_args()?;

Ok((variant_name, (input, output)))
})
.collect::<syn::Result<IndexMap<_, _>>>()?;

quote! {
impl #impl_generics crate::infer::InferenceGroup for #ident #ty_generics #where_clause {}
let variant_names = &variants.keys().collect::<Vec<_>>();

let signatures = variants
.iter()
.map(|(variant_name, (input_ty, output_ty))| {
quote! {
#vis enum #variant_name {}

impl crate::infer::InferenceSignature for #variant_name {
type Group = #group_name;
type Input = #input_ty;
type Output = #output_ty;
const KIND: Self::Group = #group_name :: #variant_name;
}
}
});

Ok(quote! {
impl crate::infer::InferenceGroup for #group_name {
const INPUT_PARAM_INFOS: ::enum_map::EnumMap<
Self,
&'static [crate::infer::ParamInfo<crate::infer::InputScalarKind>],
> = ::enum_map::EnumMap::from_array([
#(<#variant_names as crate::infer::InferenceSignature>::Input::PARAM_INFOS),*
]);

const OUTPUT_PARAM_INFOS: ::enum_map::EnumMap<
Self,
&'static [crate::infer::ParamInfo<crate::infer::OutputScalarKind>],
> = ::enum_map::EnumMap::from_array([
#(<#variant_names as crate::infer::InferenceSignature>::Output::PARAM_INFOS),*
]);
}

#(#signatures)*
})
}

struct AssocTypes {
input: Type,
output: Type,
}

impl Parse for AssocTypes {
fn parse(stream: ParseStream<'_>) -> syn::Result<Self> {
let mut input = None;
let mut output = None;

while !stream.is_empty() {
let ItemType {
ident,
generics,
ty,
..
} = stream.parse()?;

deny_generics(&generics)?;

*match &*ident.to_string() {
"Input" => &mut input,
"Output" => &mut output,
_ => {
return Err(syn::Error::new(
ident.span(),
"expected `Input` or `Output`",
))
}
} = Some(*ty);
}

let input =
input.ok_or_else(|| syn::Error::new(stream.span(), "missing `type Input = …;`"))?;

let output = output
.ok_or_else(|| syn::Error::new(stream.span(), "missing `type Output = …;`"))?;

Ok(Self { input, output })
}
}

fn deny_generics(generics: &Generics) -> syn::Result<()> {
if !generics.params.is_empty() {
return Err(syn::Error::new(generics.params.span(), "must be empty"));
}
if let Some(where_clause) = &generics.where_clause {
return Err(syn::Error::new(where_clause.span(), "must be empty"));
}
Ok(())
}
.into()
}

#[proc_macro_derive(InferenceInputSignature, attributes(input_signature))]
#[proc_macro_derive(InferenceInputSignature, attributes(inference_input_signature))]
pub fn derive_inference_input_signature(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
return derive_inference_input_signature(&parse_macro_input!(input))
.unwrap_or_else(|e| e.to_compile_error())
Expand All @@ -41,11 +155,11 @@ pub fn derive_inference_input_signature(input: proc_macro::TokenStream) -> proc_

let AssocTypeSignature(signature) = attrs
.iter()
.find(|a| a.path().is_ident("input_signature"))
.find(|a| a.path().is_ident("inference_input_signature"))
.ok_or_else(|| {
syn::Error::new(
proc_macro2::Span::call_site(),
"missing `#[input_signature(…)]`",
"missing `#[inference_input_signature(…)]`",
)
})?
.parse_args()?;
Expand Down Expand Up @@ -100,17 +214,16 @@ pub fn derive_inference_input_signature(input: proc_macro::TokenStream) -> proc_
})
}

struct AssocTypeSignature(syn::Ident);
struct AssocTypeSignature(Type);

impl Parse for AssocTypeSignature {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
let key = input.parse::<syn::Ident>()?;
if key != "Signature" {
return Err(syn::Error::new(key.span(), "expected `Signature`"));
let ItemType { ident, ty, .. } = input.parse()?;

if ident != "Signature" {
return Err(syn::Error::new(ident.span(), "expected `Signature`"));
}
input.parse::<Token![=]>()?;
let value = input.parse::<syn::Ident>()?;
Ok(Self(value))
Ok(Self(*ty))
}
}
}
Expand Down Expand Up @@ -211,10 +324,10 @@ fn struct_fields(data: &Data) -> syn::Result<Vec<(&syn::Ident, &Type)>> {
return Err(syn::Error::new(fields.span(), "expect named fields"));
}
Data::Enum(DataEnum { enum_token, .. }) => {
return Err(syn::Error::new(enum_token.span(), "expected an enum"));
return Err(syn::Error::new(enum_token.span(), "expected a struct"));
}
Data::Union(DataUnion { union_token, .. }) => {
return Err(syn::Error::new(union_token.span(), "expected an enum"));
return Err(syn::Error::new(union_token.span(), "expected a struct"));
}
};

Expand All @@ -224,3 +337,26 @@ fn struct_fields(data: &Data) -> syn::Result<Vec<(&syn::Ident, &Type)>> {
.map(|Field { ident, ty, .. }| (ident.as_ref().expect("should be named"), ty))
.collect())
}

fn unit_enum_variants(data: &Data) -> syn::Result<Vec<(&[Attribute], &syn::Ident)>> {
let variants = match data {
Data::Struct(DataStruct { struct_token, .. }) => {
return Err(syn::Error::new(struct_token.span(), "expected an enum"));
}
Data::Enum(DataEnum { variants, .. }) => variants,
Data::Union(DataUnion { union_token, .. }) => {
return Err(syn::Error::new(union_token.span(), "expected an enum"));
}
};

for Variant { fields, .. } in variants {
if *fields != Fields::Unit {
return Err(syn::Error::new(fields.span(), "must be unit"));
}
}

Ok(variants
.iter()
.map(|Variant { attrs, ident, .. }| (&**attrs, ident))
.collect())
}

0 comments on commit b7d48f3

Please sign in to comment.