diff --git a/der_derive/src/attributes.rs b/der_derive/src/attributes.rs index 74099703b..911adcbd6 100644 --- a/der_derive/src/attributes.rs +++ b/der_derive/src/attributes.rs @@ -2,11 +2,33 @@ use crate::{Asn1Type, Tag, TagMode, TagNumber}; use proc_macro2::{Span, TokenStream}; -use quote::quote; +use quote::{quote, ToTokens}; use std::{fmt::Debug, str::FromStr}; use syn::punctuated::Punctuated; use syn::{parse::Parse, parse::ParseStream, Attribute, Ident, LitStr, Path, Token}; +/// Error type used by the structure +#[derive(Debug, Clone, Default, Eq, PartialEq)] +pub(crate) enum ErrorType { + /// Represents the ::der::Error type + #[default] + Der, + /// Represents an error designed by Path + Custom(Path), +} + +impl ToTokens for ErrorType { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + Self::Der => { + let err = quote! { ::der::Error }; + err.to_tokens(tokens) + } + Self::Custom(path) => path.to_tokens(tokens), + } + } +} + /// Attribute name. pub(crate) const ATTR_NAME: &str = "asn1"; @@ -18,7 +40,7 @@ pub(crate) struct TypeAttrs { /// /// The default value is `EXPLICIT`. pub tag_mode: TagMode, - pub error: Option, + pub error: ErrorType, } impl TypeAttrs { @@ -44,7 +66,7 @@ impl TypeAttrs { abort!(attr, "duplicate ASN.1 `error` attribute"); } - error = Some(meta.value()?.parse()?); + error = Some(ErrorType::Custom(meta.value()?.parse()?)); } else { return Err(syn::Error::new_spanned( attr, @@ -58,7 +80,7 @@ impl TypeAttrs { Ok(Self { tag_mode: tag_mode.unwrap_or_default(), - error, + error: error.unwrap_or_default(), }) } } diff --git a/der_derive/src/choice.rs b/der_derive/src/choice.rs index 8f10aa89c..8cd50ca01 100644 --- a/der_derive/src/choice.rs +++ b/der_derive/src/choice.rs @@ -5,10 +5,10 @@ mod variant; use self::variant::ChoiceVariant; -use crate::{default_lifetime, TypeAttrs}; +use crate::{default_lifetime, ErrorType, TypeAttrs}; use proc_macro2::TokenStream; use quote::{quote, ToTokens}; -use syn::{DeriveInput, GenericParam, Generics, Ident, LifetimeParam, Path}; +use syn::{DeriveInput, GenericParam, Generics, Ident, LifetimeParam}; /// Derive the `Choice` trait for an enum. pub(crate) struct DeriveChoice { @@ -22,7 +22,7 @@ pub(crate) struct DeriveChoice { variants: Vec, /// Error type for `DecodeValue` implementation. - error: Option, + error: ErrorType, } impl DeriveChoice { @@ -36,7 +36,7 @@ impl DeriveChoice { ), }; - let mut type_attrs = TypeAttrs::parse(&input.attrs)?; + let type_attrs = TypeAttrs::parse(&input.attrs)?; let variants = data .variants .iter() @@ -47,7 +47,7 @@ impl DeriveChoice { ident: input.ident, generics: input.generics.clone(), variants, - error: type_attrs.error.take(), + error: type_attrs.error.clone(), }) } @@ -88,11 +88,7 @@ impl DeriveChoice { tagged_body.push(variant.to_tagged_tokens()); } - let error = self - .error - .as_ref() - .map(ToTokens::to_token_stream) - .unwrap_or_else(|| quote! { ::der::Error }); + let error = self.error.to_token_stream(); quote! { impl #impl_generics ::der::Choice<#lifetime> for #ident #ty_generics #where_clause { diff --git a/der_derive/src/lib.rs b/der_derive/src/lib.rs index 79e73b663..e5fd48b13 100644 --- a/der_derive/src/lib.rs +++ b/der_derive/src/lib.rs @@ -144,7 +144,7 @@ mod value_ord; use crate::{ asn1_type::Asn1Type, - attributes::{FieldAttrs, TypeAttrs, ATTR_NAME}, + attributes::{ErrorType, FieldAttrs, TypeAttrs, ATTR_NAME}, choice::DeriveChoice, enumerated::DeriveEnumerated, sequence::DeriveSequence, diff --git a/der_derive/src/sequence.rs b/der_derive/src/sequence.rs index 360525a03..f347c727f 100644 --- a/der_derive/src/sequence.rs +++ b/der_derive/src/sequence.rs @@ -3,11 +3,11 @@ mod field; -use crate::{default_lifetime, TypeAttrs}; +use crate::{default_lifetime, ErrorType, TypeAttrs}; use field::SequenceField; use proc_macro2::TokenStream; use quote::{quote, ToTokens}; -use syn::{DeriveInput, GenericParam, Generics, Ident, LifetimeParam, Path}; +use syn::{DeriveInput, GenericParam, Generics, Ident, LifetimeParam}; /// Derive the `Sequence` trait for a struct pub(crate) struct DeriveSequence { @@ -21,7 +21,7 @@ pub(crate) struct DeriveSequence { fields: Vec, /// Error type for `DecodeValue` implementation. - error: Option, + error: ErrorType, } impl DeriveSequence { @@ -35,7 +35,7 @@ impl DeriveSequence { ), }; - let mut type_attrs = TypeAttrs::parse(&input.attrs)?; + let type_attrs = TypeAttrs::parse(&input.attrs)?; let fields = data .fields @@ -47,7 +47,7 @@ impl DeriveSequence { ident: input.ident, generics: input.generics.clone(), fields, - error: type_attrs.error.take(), + error: type_attrs.error.clone(), }) } @@ -88,11 +88,7 @@ impl DeriveSequence { encode_fields.push(quote!(#field.encode(writer)?;)); } - let error = self - .error - .as_ref() - .map(ToTokens::to_token_stream) - .unwrap_or_else(|| quote! { ::der::Error }); + let error = self.error.to_token_stream(); quote! { impl #impl_generics ::der::DecodeValue<#lifetime> for #ident #ty_generics #where_clause {