From 0fe5ab87d0b3b7402e5457850c65592c18a97d92 Mon Sep 17 00:00:00 2001 From: Yun-Jhong Wu Date: Wed, 31 Jan 2024 22:14:38 -0600 Subject: [PATCH] NestedStrongType with custom_underlying --- strong-type-derive/src/detail/constants.rs | 13 ++++ strong-type-derive/src/detail/mod.rs | 4 +- .../src/detail/underlying_type.rs | 65 ++++++++++++++++++- strong-type-derive/src/lib.rs | 2 +- strong-type-derive/src/strong_type.rs | 47 +++++++++----- strong-type-tests/tests/custom_underlying.rs | 26 ++++++++ strong-type-tests/tests/tests.rs | 1 + 7 files changed, 136 insertions(+), 22 deletions(-) create mode 100644 strong-type-tests/tests/custom_underlying.rs diff --git a/strong-type-derive/src/detail/constants.rs b/strong-type-derive/src/detail/constants.rs index 04cb505..7e0e7a8 100644 --- a/strong-type-derive/src/detail/constants.rs +++ b/strong-type-derive/src/detail/constants.rs @@ -11,3 +11,16 @@ pub(crate) fn implement_constants(name: &syn::Ident, value_type: &syn::Ident) -> } } } +pub(crate) fn implement_constants_derive( + name: &syn::Ident, + value_type: &syn::Ident, +) -> TokenStream { + quote! { + impl #name { + pub const MIN: Self = Self(#value_type::MIN); + pub const MAX: Self = Self(#value_type::MAX); + pub const ZERO: Self = Self(#value_type::ZERO); + pub const ONE: Self = Self(#value_type::ONE); + } + } +} diff --git a/strong-type-derive/src/detail/mod.rs b/strong-type-derive/src/detail/mod.rs index e112118..913ab13 100644 --- a/strong-type-derive/src/detail/mod.rs +++ b/strong-type-derive/src/detail/mod.rs @@ -18,10 +18,10 @@ pub(crate) use basic_primitive::implement_basic_primitive; pub(crate) use basic_string::implement_basic_string; pub(crate) use bit_ops::implement_bit_shift; pub(crate) use bool_ops::implement_bool_ops; -pub(crate) use constants::implement_constants; +pub(crate) use constants::{implement_constants, implement_constants_derive}; pub(crate) use display::implement_display; pub(crate) use hash::implement_hash; pub(crate) use nan::implement_nan; pub(crate) use negate::implement_negate; -pub(crate) use underlying_type::{get_type_group, get_type_ident, UnderlyingTypeGroup}; +pub(crate) use underlying_type::{get_type, TypeInfo, UnderlyingTypeGroup}; pub(crate) use utils::{get_attributes, is_struct_valid, StrongTypeAttributes}; diff --git a/strong-type-derive/src/detail/underlying_type.rs b/strong-type-derive/src/detail/underlying_type.rs index e43dd9b..f74c746 100644 --- a/strong-type-derive/src/detail/underlying_type.rs +++ b/strong-type-derive/src/detail/underlying_type.rs @@ -9,15 +9,62 @@ pub(crate) enum UnderlyingTypeGroup { Bool, Char, String, + IntDerive, + FloatDerive, + UIntDerive, + BoolDerive, + CharDerive, + StringDerive, } -pub(crate) fn get_type_ident(input: &DeriveInput) -> &syn::Ident { +pub(crate) struct TypeInfo<'a> { + pub value_type: &'a syn::Ident, + pub type_group: UnderlyingTypeGroup, +} + +fn get_type_ident(input: &DeriveInput) -> Option<&syn::Ident> { if let Data::Struct(ref data_struct) = input.data { if let Type::Path(ref path) = &data_struct.fields.iter().next().unwrap().ty { - return &path.path.segments.last().unwrap().ident; + return Some(&path.path.segments.last().unwrap().ident); + } + } + None +} + +fn get_group_from_custom_underlying(input: &DeriveInput) -> Option { + for attr in input.attrs.iter() { + if attr.path().is_ident("custom_underlying") { + let mut type_group = None; + attr.parse_nested_meta(|meta| { + if let Some(ident) = meta.path.get_ident() { + type_group = Some(get_type_group(ident)); + Ok(()) + } else { + Err(meta.error("Unsupported attribute")) + } + }) + .ok()?; + return type_group; } } - panic!("Unsupported input") + + None +} + +pub(crate) fn get_type(input: &DeriveInput) -> TypeInfo { + if let Some(value_type) = get_type_ident(input) { + let type_group = match get_group_from_custom_underlying(input) { + Some(type_group) => type_group_to_derive(type_group), + None => get_type_group(value_type), + }; + + TypeInfo { + value_type, + type_group, + } + } else { + panic!("Unsupported input") + } } pub(crate) fn get_type_group(value_type: &syn::Ident) -> UnderlyingTypeGroup { @@ -53,3 +100,15 @@ pub(crate) fn get_type_group(value_type: &syn::Ident) -> UnderlyingTypeGroup { } panic!("Unsupported type: {}", value_type); } + +pub fn type_group_to_derive(type_group: UnderlyingTypeGroup) -> UnderlyingTypeGroup { + match type_group { + UnderlyingTypeGroup::Int => UnderlyingTypeGroup::IntDerive, + UnderlyingTypeGroup::UInt => UnderlyingTypeGroup::UIntDerive, + UnderlyingTypeGroup::Float => UnderlyingTypeGroup::FloatDerive, + UnderlyingTypeGroup::Bool => UnderlyingTypeGroup::BoolDerive, + UnderlyingTypeGroup::Char => UnderlyingTypeGroup::CharDerive, + UnderlyingTypeGroup::String => UnderlyingTypeGroup::StringDerive, + _ => panic!("Unsupported type group {:?}", type_group), + } +} diff --git a/strong-type-derive/src/lib.rs b/strong-type-derive/src/lib.rs index 0d06ed1..ca75986 100644 --- a/strong-type-derive/src/lib.rs +++ b/strong-type-derive/src/lib.rs @@ -6,7 +6,7 @@ use syn::{parse_macro_input, DeriveInput}; use crate::strong_type::expand_strong_type; -#[proc_macro_derive(StrongType, attributes(strong_type))] +#[proc_macro_derive(StrongType, attributes(strong_type, custom_underlying))] pub fn strong_type(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); expand_strong_type(input).into() diff --git a/strong-type-derive/src/strong_type.rs b/strong-type-derive/src/strong_type.rs index 71ac51d..08aa01a 100644 --- a/strong-type-derive/src/strong_type.rs +++ b/strong-type-derive/src/strong_type.rs @@ -1,8 +1,8 @@ use crate::detail::{ - get_attributes, get_type_group, get_type_ident, implement_arithmetic, implement_basic, - implement_basic_primitive, implement_basic_string, implement_bit_shift, implement_bool_ops, - implement_constants, implement_display, implement_hash, implement_nan, implement_negate, - is_struct_valid, StrongTypeAttributes, UnderlyingTypeGroup, + get_attributes, get_type, implement_arithmetic, implement_basic, implement_basic_primitive, + implement_basic_string, implement_bit_shift, implement_bool_ops, implement_constants, + implement_constants_derive, implement_display, implement_hash, implement_nan, implement_negate, + is_struct_valid, StrongTypeAttributes, TypeInfo, UnderlyingTypeGroup, }; use proc_macro2::TokenStream; use quote::quote; @@ -14,8 +14,10 @@ pub(super) fn expand_strong_type(input: DeriveInput) -> TokenStream { } let name = &input.ident; - let value_type = get_type_ident(&input); - let group = get_type_group(value_type); + let TypeInfo { + value_type, + type_group, + } = get_type(&input); let StrongTypeAttributes { has_auto_operators, has_custom_display, @@ -28,50 +30,63 @@ pub(super) fn expand_strong_type(input: DeriveInput) -> TokenStream { ast.extend(implement_display(name)); }; - match &group { + match &type_group { UnderlyingTypeGroup::Int | UnderlyingTypeGroup::UInt => { ast.extend(implement_basic_primitive(name, value_type)); ast.extend(implement_constants(name, value_type)); ast.extend(implement_hash(name)); } + UnderlyingTypeGroup::IntDerive | UnderlyingTypeGroup::UIntDerive => { + ast.extend(implement_basic_primitive(name, value_type)); + ast.extend(implement_constants_derive(name, value_type)); + ast.extend(implement_hash(name)); + } UnderlyingTypeGroup::Float => { ast.extend(implement_basic_primitive(name, value_type)); ast.extend(implement_constants(name, value_type)); ast.extend(implement_nan(name, value_type)); } - UnderlyingTypeGroup::Bool => { + UnderlyingTypeGroup::FloatDerive => { + ast.extend(implement_basic_primitive(name, value_type)); + ast.extend(implement_constants_derive(name, value_type)); + ast.extend(implement_nan(name, value_type)); + } + UnderlyingTypeGroup::Bool | UnderlyingTypeGroup::BoolDerive => { ast.extend(implement_basic_primitive(name, value_type)); ast.extend(implement_hash(name)); } - UnderlyingTypeGroup::Char => { + UnderlyingTypeGroup::Char | UnderlyingTypeGroup::CharDerive => { ast.extend(implement_basic_primitive(name, value_type)); ast.extend(implement_hash(name)); } - UnderlyingTypeGroup::String => { + UnderlyingTypeGroup::String | UnderlyingTypeGroup::StringDerive => { ast.extend(implement_basic_string(name)); ast.extend(implement_hash(name)); } } if has_auto_operators { - match &group { - UnderlyingTypeGroup::Float => { + match &type_group { + UnderlyingTypeGroup::Float | UnderlyingTypeGroup::FloatDerive => { ast.extend(implement_arithmetic(name)); ast.extend(implement_negate(name)); } - UnderlyingTypeGroup::Int => { + UnderlyingTypeGroup::Int | UnderlyingTypeGroup::IntDerive => { ast.extend(implement_arithmetic(name)); ast.extend(implement_negate(name)); ast.extend(implement_bit_shift(name)); } - UnderlyingTypeGroup::UInt => { + UnderlyingTypeGroup::UInt | UnderlyingTypeGroup::UIntDerive => { ast.extend(implement_arithmetic(name)); ast.extend(implement_bit_shift(name)); } - UnderlyingTypeGroup::Bool => { + UnderlyingTypeGroup::Bool | UnderlyingTypeGroup::BoolDerive => { ast.extend(implement_bool_ops(name)); } - _ => {} + UnderlyingTypeGroup::Char + | UnderlyingTypeGroup::CharDerive + | UnderlyingTypeGroup::String + | UnderlyingTypeGroup::StringDerive => {} } } diff --git a/strong-type-tests/tests/custom_underlying.rs b/strong-type-tests/tests/custom_underlying.rs new file mode 100644 index 0000000..532f82a --- /dev/null +++ b/strong-type-tests/tests/custom_underlying.rs @@ -0,0 +1,26 @@ +#[cfg(test)] +mod tests { + use std::mem; + use strong_type::StrongType; + + fn test_type() {} + + #[test] + fn test_custom_underlying() { + #[derive(StrongType)] + #[strong_type(auto_operators)] + struct Dollar(i32); + + #[derive(StrongType)] + #[strong_type(auto_operators)] + #[custom_underlying(i32)] + struct Cash(Dollar); + test_type::(); + assert_eq!(mem::size_of::(), mem::size_of::()); + + assert_eq!( + Cash::new(Dollar::new(10)), + Cash::new(Dollar::new(2)) + Cash::new(Dollar::new(8)) + ); + } +} diff --git a/strong-type-tests/tests/tests.rs b/strong-type-tests/tests/tests.rs index 88e41f7..1d60509 100644 --- a/strong-type-tests/tests/tests.rs +++ b/strong-type-tests/tests/tests.rs @@ -1,3 +1,4 @@ mod auto_operators; +mod custom_underlying; mod display; mod strong_type;