diff --git a/README.md b/README.md index 3260515..8ca5e63 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ println!("{}", timestamp); // Timestamp(1701620628123456789) - Adding the following attributes to `#[strong_type(...)]` allows for additional features: - `auto_operators`: Automatically implements relevant arithmetic (for numeric types) or logical (for boolean types) operators. - `custom_display`: Allows users to manually implement the `Display` trait, providing an alternative to the default display format. + - `conversion`: Automatically implements `From` and `Into` traits for the underlying type. - `underlying`: Specifies the underlying primitive type for nested strong types. ## Installation diff --git a/strong-type-derive/src/detail/conversion.rs b/strong-type-derive/src/detail/conversion.rs new file mode 100644 index 0000000..09c05cf --- /dev/null +++ b/strong-type-derive/src/detail/conversion.rs @@ -0,0 +1,28 @@ +use proc_macro2::TokenStream; +use quote::quote; + +pub(crate) fn implement_conversion(name: &syn::Ident, value_type: &syn::Ident) -> TokenStream { + quote! { + impl From<#value_type> for #name { + fn from(value: #value_type) -> Self { + Self::new(value) + } + } + + impl From<#name> for #value_type { + fn from(value: #name) -> #value_type { + value.0 + } + } + } +} + +pub(crate) fn implement_str_conversion(name: &syn::Ident) -> TokenStream { + quote! { + impl From<&str> for #name { + fn from(value: &str) -> Self { + Self::new(value) + } + } + } +} diff --git a/strong-type-derive/src/detail/mod.rs b/strong-type-derive/src/detail/mod.rs index dd4a8a2..0211066 100644 --- a/strong-type-derive/src/detail/mod.rs +++ b/strong-type-derive/src/detail/mod.rs @@ -5,6 +5,7 @@ mod basic_string; mod bit_ops; mod bool_ops; mod constants; +mod conversion; mod display; mod hash; mod nan; @@ -26,6 +27,7 @@ pub(crate) use bool_ops::implement_bool_ops; pub(crate) use constants::{ implement_constants, implement_constants_derived, implement_infinity, implement_limit, }; +pub(crate) use conversion::{implement_conversion, implement_str_conversion}; pub(crate) use display::implement_display; pub(crate) use hash::implement_hash; pub(crate) use nan::implement_nan; diff --git a/strong-type-derive/src/detail/utils.rs b/strong-type-derive/src/detail/utils.rs index 7e25591..5f52c47 100644 --- a/strong-type-derive/src/detail/utils.rs +++ b/strong-type-derive/src/detail/utils.rs @@ -5,6 +5,7 @@ use syn::{Data, DeriveInput, Fields, Visibility}; pub(crate) struct StrongTypeAttributes { pub has_auto_operators: bool, pub has_custom_display: bool, + pub has_conversion: bool, pub type_info: TypeInfo, } @@ -12,6 +13,7 @@ pub(crate) fn get_attributes(input: &DeriveInput) -> StrongTypeAttributes { let mut attributes = StrongTypeAttributes { has_auto_operators: false, has_custom_display: false, + has_conversion: false, type_info: get_type(input), }; @@ -24,6 +26,9 @@ pub(crate) fn get_attributes(input: &DeriveInput) -> StrongTypeAttributes { } else if meta.path.is_ident("custom_display") { attributes.has_custom_display = true; Ok(()) + } else if meta.path.is_ident("conversion") { + attributes.has_conversion = true; + Ok(()) } else if meta.path.is_ident("underlying") { if let Ok(strm) = meta.value() { if let Ok(primitive_type) = strm.parse::() { @@ -35,7 +40,7 @@ pub(crate) fn get_attributes(input: &DeriveInput) -> StrongTypeAttributes { } Ok(()) } else { - Err(meta.error(format!("Invalid strong_type attribute {}, should be one of {{auto_operators, custom_display, underlying=type}}", + Err(meta.error(format!("Invalid strong_type attribute {}, should be one of {{auto_operators, custom_display, conversion, underlying=type}}", meta.path.get_ident().expect("Failed to parse strong_type attributes.")))) } }) { diff --git a/strong-type-derive/src/strong_type.rs b/strong-type-derive/src/strong_type.rs index 60734ac..7b86f6f 100644 --- a/strong-type-derive/src/strong_type.rs +++ b/strong-type-derive/src/strong_type.rs @@ -1,11 +1,12 @@ use crate::detail::{ get_attributes, implement_arithmetic, implement_basic, implement_basic_primitive, implement_basic_string, implement_bit_shift, implement_bool_ops, implement_constants, - implement_constants_derived, implement_display, implement_hash, implement_infinity, - implement_limit, implement_nan, implement_negate, implement_primitive_accessor, - implement_primitive_accessor_derived, implement_primitive_str_accessor, - implement_primitive_str_accessor_derived, is_struct_valid, StrongTypeAttributes, TypeInfo, - UnderlyingType, ValueTypeGroup, + implement_constants_derived, implement_conversion, implement_display, implement_hash, + implement_infinity, implement_limit, implement_nan, implement_negate, + implement_primitive_accessor, implement_primitive_accessor_derived, + implement_primitive_str_accessor, implement_primitive_str_accessor_derived, + implement_str_conversion, is_struct_valid, StrongTypeAttributes, TypeInfo, UnderlyingType, + ValueTypeGroup, }; use proc_macro2::TokenStream; use quote::quote; @@ -20,6 +21,7 @@ pub(super) fn expand_strong_type(input: DeriveInput) -> TokenStream { let StrongTypeAttributes { has_auto_operators, has_custom_display, + has_conversion, type_info: TypeInfo { primitive_type, @@ -37,6 +39,13 @@ pub(super) fn expand_strong_type(input: DeriveInput) -> TokenStream { ast.extend(implement_display(name)); }; + if has_conversion { + ast.extend(implement_conversion(name, &value_type)); + if let ValueTypeGroup::String(UnderlyingType::Primitive) = &type_group { + ast.extend(implement_str_conversion(name)); + } + } + match &type_group { ValueTypeGroup::Int(underlying_type) | ValueTypeGroup::UInt(underlying_type) diff --git a/strong-type-tests/tests/conversion.rs b/strong-type-tests/tests/conversion.rs new file mode 100644 index 0000000..51f8933 --- /dev/null +++ b/strong-type-tests/tests/conversion.rs @@ -0,0 +1,33 @@ +#[cfg(test)] +mod tests { + use std::fmt::Debug; + use strong_type::StrongType; + + fn test(value: T, underlying: T) { + assert_eq!(value, underlying); + } + + #[test] + fn test_conversion() { + #[derive(StrongType)] + #[strong_type(conversion)] + struct NamedI32(i64); + + let i64_value = 64; + let value = NamedI32::new(i64_value); + + test(value, i64_value.into()); + test(value.into(), i64_value); + + #[derive(StrongType)] + #[strong_type(conversion)] + struct NamedString(String); + let str_value = "test"; + let string_value = String::from(str_value); + let value = NamedString::new(string_value.clone()); + + test(value.clone(), str_value.into()); + test(value.clone(), string_value.clone().into()); + test(value.into(), string_value); + } +} diff --git a/strong-type-tests/tests/tests.rs b/strong-type-tests/tests/tests.rs index 1d60509..d8bf539 100644 --- a/strong-type-tests/tests/tests.rs +++ b/strong-type-tests/tests/tests.rs @@ -1,4 +1,5 @@ mod auto_operators; +mod conversion; mod custom_underlying; mod display; mod strong_type;