Skip to content

Commit

Permalink
der_derive: use an ErrorType to store the error attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
baloo committed Oct 8, 2024
1 parent 87bd37f commit 4e31347
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 25 deletions.
30 changes: 26 additions & 4 deletions der_derive/src/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,33 @@
use crate::{Asn1Type, Tag, TagMode, TagNumber};
use proc_macro2::{Span, TokenStream};
use quote::quote;
use quote::{quote, ToTokens};
use std::{fmt::Debug, str::FromStr};
use syn::punctuated::Punctuated;
use syn::{parse::Parse, parse::ParseStream, Attribute, Ident, LitStr, Path, Token};

/// Error type used by the structure
#[derive(Debug, Clone, Default, Eq, PartialEq)]
pub(crate) enum ErrorType {
/// Represents the ::der::Error type
#[default]
Der,
/// Represents an error designed by Path
Custom(Path),
}

impl ToTokens for ErrorType {
fn to_tokens(&self, tokens: &mut TokenStream) {
match self {
Self::Der => {
let err = quote! { ::der::Error };
err.to_tokens(tokens)
}
Self::Custom(path) => path.to_tokens(tokens),
}
}
}

/// Attribute name.
pub(crate) const ATTR_NAME: &str = "asn1";

Expand All @@ -18,7 +40,7 @@ pub(crate) struct TypeAttrs {
///
/// The default value is `EXPLICIT`.
pub tag_mode: TagMode,
pub error: Option<Path>,
pub error: ErrorType,
}

impl TypeAttrs {
Expand All @@ -44,7 +66,7 @@ impl TypeAttrs {
abort!(attr, "duplicate ASN.1 `error` attribute");
}

error = Some(meta.value()?.parse()?);
error = Some(ErrorType::Custom(meta.value()?.parse()?));
} else {
return Err(syn::Error::new_spanned(
attr,
Expand All @@ -58,7 +80,7 @@ impl TypeAttrs {

Ok(Self {
tag_mode: tag_mode.unwrap_or_default(),
error,
error: error.unwrap_or_default(),
})
}
}
Expand Down
16 changes: 6 additions & 10 deletions der_derive/src/choice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
mod variant;

use self::variant::ChoiceVariant;
use crate::{default_lifetime, TypeAttrs};
use crate::{default_lifetime, ErrorType, TypeAttrs};
use proc_macro2::TokenStream;
use quote::{quote, ToTokens};
use syn::{DeriveInput, GenericParam, Generics, Ident, LifetimeParam, Path};
use syn::{DeriveInput, GenericParam, Generics, Ident, LifetimeParam};

/// Derive the `Choice` trait for an enum.
pub(crate) struct DeriveChoice {
Expand All @@ -22,7 +22,7 @@ pub(crate) struct DeriveChoice {
variants: Vec<ChoiceVariant>,

/// Error type for `DecodeValue` implementation.
error: Option<Path>,
error: ErrorType,
}

impl DeriveChoice {
Expand All @@ -36,7 +36,7 @@ impl DeriveChoice {
),
};

let mut type_attrs = TypeAttrs::parse(&input.attrs)?;
let type_attrs = TypeAttrs::parse(&input.attrs)?;
let variants = data
.variants
.iter()
Expand All @@ -47,7 +47,7 @@ impl DeriveChoice {
ident: input.ident,
generics: input.generics.clone(),
variants,
error: type_attrs.error.take(),
error: type_attrs.error.clone(),
})
}

Expand Down Expand Up @@ -88,11 +88,7 @@ impl DeriveChoice {
tagged_body.push(variant.to_tagged_tokens());
}

let error = self
.error
.as_ref()
.map(ToTokens::to_token_stream)
.unwrap_or_else(|| quote! { ::der::Error });
let error = self.error.to_token_stream();

quote! {
impl #impl_generics ::der::Choice<#lifetime> for #ident #ty_generics #where_clause {
Expand Down
2 changes: 1 addition & 1 deletion der_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ mod value_ord;

use crate::{
asn1_type::Asn1Type,
attributes::{FieldAttrs, TypeAttrs, ATTR_NAME},
attributes::{ErrorType, FieldAttrs, TypeAttrs, ATTR_NAME},
choice::DeriveChoice,
enumerated::DeriveEnumerated,
sequence::DeriveSequence,
Expand Down
16 changes: 6 additions & 10 deletions der_derive/src/sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
mod field;

use crate::{default_lifetime, TypeAttrs};
use crate::{default_lifetime, ErrorType, TypeAttrs};
use field::SequenceField;
use proc_macro2::TokenStream;
use quote::{quote, ToTokens};
use syn::{DeriveInput, GenericParam, Generics, Ident, LifetimeParam, Path};
use syn::{DeriveInput, GenericParam, Generics, Ident, LifetimeParam};

/// Derive the `Sequence` trait for a struct
pub(crate) struct DeriveSequence {
Expand All @@ -21,7 +21,7 @@ pub(crate) struct DeriveSequence {
fields: Vec<SequenceField>,

/// Error type for `DecodeValue` implementation.
error: Option<Path>,
error: ErrorType,
}

impl DeriveSequence {
Expand All @@ -35,7 +35,7 @@ impl DeriveSequence {
),
};

let mut type_attrs = TypeAttrs::parse(&input.attrs)?;
let type_attrs = TypeAttrs::parse(&input.attrs)?;

let fields = data
.fields
Expand All @@ -47,7 +47,7 @@ impl DeriveSequence {
ident: input.ident,
generics: input.generics.clone(),
fields,
error: type_attrs.error.take(),
error: type_attrs.error.clone(),
})
}

Expand Down Expand Up @@ -88,11 +88,7 @@ impl DeriveSequence {
encode_fields.push(quote!(#field.encode(writer)?;));
}

let error = self
.error
.as_ref()
.map(ToTokens::to_token_stream)
.unwrap_or_else(|| quote! { ::der::Error });
let error = self.error.to_token_stream();

quote! {
impl #impl_generics ::der::DecodeValue<#lifetime> for #ident #ty_generics #where_clause {
Expand Down

0 comments on commit 4e31347

Please sign in to comment.