Skip to content

Commit

Permalink
ArrayExtをマクロ内に押し込める
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Nov 16, 2023
1 parent af828eb commit b6b7975
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 28 deletions.
10 changes: 0 additions & 10 deletions crates/voicevox_core/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,16 +155,6 @@ impl<D: PartialEq> ParamInfo<D> {
}
}

pub(crate) trait ArrayExt {
type Scalar;
type Dimension: Dimension;
}

impl<A, D: Dimension> ArrayExt for Array<A, D> {
type Scalar = A;
type Dimension = D;
}

#[derive(new, Clone, Copy, PartialEq, Debug)]
pub(crate) struct InferenceSessionOptions {
pub(crate) cpu_num_threads: u16,
Expand Down
52 changes: 34 additions & 18 deletions crates/voicevox_core_macros/src/inference_domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,8 @@ pub(crate) fn derive_inference_input_signature(
quote! {
crate::infer::ParamInfo {
name: ::std::borrow::Cow::Borrowed(#name),
dt: <
<#ty as crate::infer::ArrayExt>::Scalar as crate::infer::InputScalar
>::KIND,
ndim: <
<#ty as crate::infer::ArrayExt>::Dimension as ::ndarray::Dimension
>::NDIM,
dt: <<#ty as __ArrayExt>::Scalar as crate::infer::InputScalar>::KIND,
ndim: <<#ty as __ArrayExt>::Dimension as ::ndarray::Dimension>::NDIM,
},
}
})
Expand All @@ -208,9 +204,21 @@ pub(crate) fn derive_inference_input_signature(

const PARAM_INFOS: &'static [crate::infer::ParamInfo<
crate::infer::InputScalarKind
>] = &[
#param_infos
];
>] = {
trait __ArrayExt {
type Scalar: crate::infer::InputScalar;
type Dimension: ::ndarray::Dimension + 'static;
}

impl<A: crate::infer::InputScalar, D: ::ndarray::Dimension + 'static> __ArrayExt
for ::ndarray::Array<A, D>
{
type Scalar = A;
type Dimension = D;
}

&[#param_infos]
};

fn make_run_context<R: crate::infer::InferenceRuntime>(
self,
Expand Down Expand Up @@ -261,12 +269,8 @@ pub(crate) fn derive_inference_output_signature(
quote! {
crate::infer::ParamInfo {
name: ::std::borrow::Cow::Borrowed(#name),
dt: <
<#ty as crate::infer::ArrayExt>::Scalar as crate::infer::OutputScalar
>::KIND,
ndim: <
<#ty as crate::infer::ArrayExt>::Dimension as ::ndarray::Dimension
>::NDIM,
dt: <<#ty as __ArrayExt>::Scalar as crate::infer::OutputScalar>::KIND,
ndim: <<#ty as __ArrayExt>::Dimension as ::ndarray::Dimension>::NDIM,
},
}
})
Expand All @@ -280,9 +284,21 @@ pub(crate) fn derive_inference_output_signature(
{
const PARAM_INFOS: &'static [crate::infer::ParamInfo<
crate::infer::OutputScalarKind
>] = &[
#param_infos
];
>] = {
trait __ArrayExt {
type Scalar: crate::infer::OutputScalar;
type Dimension: ::ndarray::Dimension + 'static;
}

impl<A: crate::infer::OutputScalar, D: ::ndarray::Dimension + 'static> __ArrayExt
for ::ndarray::Array<A, D>
{
type Scalar = A;
type Dimension = D;
}

&[#param_infos]
};
}

impl #impl_generics ::std::convert::TryFrom<::std::vec::Vec<crate::infer::OutputTensor>>
Expand Down

0 comments on commit b6b7975

Please sign in to comment.