Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SerializeBits/DeserializeBits derive macros #84

Merged
merged 11 commits into from
May 17, 2024
8 changes: 4 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,19 @@ jobs:
- name: Lint stable
if: ${{ matrix.toolchain == 'stable' }}
run: |
cargo clippy --workspace -- -D warnings
cargo clippy --workspace --features serde -- -D warnings

- name: Lint nightly
if: ${{ matrix.toolchain == 'nightly-2022-11-03' }}
run: |
cargo clippy --workspace --features nightly -- -D warnings
cargo clippy --workspace --features nightly serde -- -D warnings

- name: Test stable
if: ${{ matrix.toolchain == 'stable' }}
run: |
cargo test --workspace
cargo test --workspace --features serde

- name: Test nightly
if: ${{ matrix.toolchain == 'nightly-2022-11-03' }}
run: |
cargo test --workspace --features nightly
cargo test --workspace --features nightly serde
5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ rust-version = "1.65"
default = []
# Enables constness, see README.md: only usable with nightly-2022-11-03
nightly = ["arbitrary-int/const_convert_and_const_trait_impl", "bilge-impl/nightly"]
serde = ["bilge-impl/serde", "arbitrary-int/serde"]

[dependencies]
# cargo clippy workaround, we can't add `path = "../arbitrary-int"` as well
arbitrary-int = { version = "1.2.6" }
arbitrary-int = "1.2.7"
bilge-impl = { version = "=0.2.0", path = "bilge-impl" }

[dev-dependencies]
Expand All @@ -48,6 +49,8 @@ rustversion = "1.0"
trybuild = "1.0"
custom_bits = { path = "tests/custom_bits" }
assert_matches = "1.5.0"
serde = "1.0"
serde_test = "1.0"

# examples
# volatile = { git = "https://github.com/theseus-os/volatile" }
Expand Down
1 change: 1 addition & 0 deletions bilge-impl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ proc-macro = true
default = []
# Enables constness, see README.md for the specific nightly version
nightly = []
serde = []

[dependencies]
syn = { version = "2.0", features = ["full"] }
Expand Down
23 changes: 23 additions & 0 deletions bilge-impl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ mod debug_bits;
mod default_bits;
mod fmt_bits;
mod from_bits;
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
mod serde_bits;
mod try_from_bits;

mod shared;
Expand Down Expand Up @@ -74,3 +77,23 @@ pub fn derive_binary_bits(item: TokenStream) -> TokenStream {
pub fn derive_default_bits(item: TokenStream) -> TokenStream {
default_bits::default_bits(item.into()).into()
}

/// Generate an `impl serde::Serialize` for bitfield structs.
///
/// Please use normal #[derive(Serialize)] for enums.
#[cfg(feature = "serde")]
#[proc_macro_error]
#[proc_macro_derive(SerializeBits, attributes(bitsize_internal))]
pub fn serialize_bits(item: TokenStream) -> TokenStream {
serde_bits::serialize_bits(item.into()).into()
}

/// Generate an `impl serde::Deserialize` for bitfield structs.
///
/// Please use normal #[derive(Deserialize)] for enums.
#[cfg(feature = "serde")]
#[proc_macro_error]
#[proc_macro_derive(DeserializeBits, attributes(bitsize_internal))]
pub fn deserialize_bits(item: TokenStream) -> TokenStream {
serde_bits::deserialize_bits(item.into()).into()
}
224 changes: 224 additions & 0 deletions bilge-impl/src/serde_bits.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
use itertools::MultiUnzip;
use proc_macro2::{Ident, TokenStream};
use proc_macro_error::abort_call_site;
use quote::quote;
use syn::{Data, Field, Fields};

use crate::shared::{self, unreachable};

fn filter_not_reserved_or_padding(field: &&Field) -> bool {
let field_name_string = field.ident.as_ref().unwrap().to_string();
!field_name_string.starts_with("reserved_") && !field_name_string.starts_with("padding_")
}

pub(super) fn serialize_bits(item: TokenStream) -> TokenStream {
let derive_input = shared::parse_derive(item);
let name = &derive_input.ident;
let name_str = name.to_string();
let struct_data = match derive_input.data {
Data::Struct(s) => s,
Data::Enum(_) => abort_call_site!("use derive(Serialize) for enums"),
Data::Union(_) => unreachable(()),
};

let serialize_impl = match struct_data.fields {
Fields::Named(fields) => {
let calls = fields.named.iter().filter(filter_not_reserved_or_padding).map(|f| {
// We can unwrap since this is a named field
let call = f.ident.as_ref().unwrap();
let name = call.to_string();
quote!(state.serialize_field(#name, &self.#call())?;)
});
let len = fields.named.iter().filter(filter_not_reserved_or_padding).count();
quote! {
use ::serde::ser::SerializeStruct;
let mut state = serializer.serialize_struct(#name_str, #len)?;
// state.serialize_field("field1", &self.field1())?; state.serialize_field("field2", &self.field2())?; state.serialize_field("field3", &self.field3())?; state.end()
#(#calls)*
state.end()
}
}
Fields::Unnamed(fields) => {
let calls = fields.unnamed.iter().enumerate().map(|(i, _)| {
let call: Ident = syn::parse_str(&format!("val_{}", i)).unwrap_or_else(unreachable);
quote!(state.serialize_field(&self.#call())?;)
});
let len = fields.unnamed.len();
quote! {
use serde::ser::SerializeTupleStruct;
let mut state = serializer.serialize_tuple_struct(#name_str, #len)?;
// state.serialize_field(&self.val0())?; state.serialize_field(&self.val1())?; state.end()
#(#calls)*
state.end()
}
}
Fields::Unit => todo!("this is a unit struct, which is not supported right now"),
};

quote! {
impl ::serde::Serialize for #name {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: ::serde::Serializer,
{
#serialize_impl
}
}
}
}

fn deserialize_field_parts(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like this.

i: usize, field_ident: &Ident,
) -> (
TokenStream,
TokenStream,
TokenStream,
TokenStream,
TokenStream,
TokenStream,
TokenStream,
String,
) {
let field_name_string = field_ident.to_string();
(
quote!(#field_ident,),
quote!(#field_name_string => Ok(Field::#field_ident),),
quote!(#field_name_string,),
quote!(let #field_ident = seq.next_element()?.ok_or_else(|| ::serde::de::Error::invalid_length(#i, &self))?;),
quote!(let mut #field_ident = None;),
quote!(Field::#field_ident => {
if #field_ident.is_some() {
return Err(::serde::de::Error::duplicate_field(#field_name_string));
}
#field_ident = Some(map.next_value()?);
}),
quote!(let #field_ident = #field_ident.ok_or_else(|| ::serde::de::Error::missing_field(#field_name_string))?;),
format!("`{}`", field_name_string),
)
}

pub(super) fn deserialize_bits(item: TokenStream) -> TokenStream {
let derive_input = shared::parse_derive(item);
let name = &derive_input.ident;
let name_str = name.to_string();
let struct_name_str = format!("struct {}", name_str);
let struct_data = match derive_input.data {
Data::Struct(s) => s,
Data::Enum(_) => abort_call_site!("use derive(Serialize) for enums"),
Data::Union(_) => unreachable(()),
};

let should_have_visit_map = matches!(struct_data.fields, Fields::Named(_));

let (
field_names,
field_deserialize,
field_name_strings,
field_visit_seq,
field_visit_map_init,
field_visit_map_match,
field_visit_map_check,
mut field_expecting,
): (Vec<_>, Vec<_>, Vec<_>, Vec<_>, Vec<_>, Vec<_>, Vec<_>, Vec<_>) = match struct_data.fields {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and this.

Fields::Named(fields) => fields
.named
.iter()
.filter(filter_not_reserved_or_padding)
.enumerate()
.map(|(i, f)| deserialize_field_parts(i, f.ident.as_ref().unwrap()))
.multiunzip(),
Fields::Unnamed(fields) => fields
.unnamed
.iter()
.enumerate()
.map(|(i, _)| deserialize_field_parts(i, &syn::parse_str(&format!("val_{}", i)).unwrap_or_else(unreachable)))
.multiunzip(),
Fields::Unit => todo!("this is a unit struct, which is not supported right now"),
};

if field_expecting.len() > 1 {
field_expecting.last_mut().unwrap().insert_str(0, "or ");
}
let field_expecting = field_expecting.join(", ");

let visit_map = if should_have_visit_map {
quote!(fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
where
V: ::serde::de::MapAccess<'de>,
{
#(#field_visit_map_init)*
while let Some(key) = map.next_key()? {
match key {
#(#field_visit_map_match)*
}
}
#(#field_visit_map_check)*
Ok(#name::new(#(#field_names)*))
})
} else {
quote!()
};

quote! {
impl<'de> ::serde::Deserialize<'de> for #name {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: ::serde::Deserializer<'de>,
{
#[allow(non_camel_case_types)]
enum Field { #(#field_names)* }
impl<'de> ::serde::Deserialize<'de> for Field {
fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
where
D: ::serde::Deserializer<'de>,
{
struct FieldVisitor;

impl<'de> ::serde::de::Visitor<'de> for FieldVisitor {
type Value = Field;

fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
formatter.write_str(#field_expecting)
}

fn visit_str<E>(self, value: &str) -> Result<Field, E>
where
E: ::serde::de::Error,
{
match value {
#(#field_deserialize)*
_ => Err(::serde::de::Error::unknown_field(value, FIELDS)),
}
}
}

deserializer.deserialize_identifier(FieldVisitor)
}
}

struct Visitor;

impl<'de> ::serde::de::Visitor<'de> for Visitor {
type Value = #name;

fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
formatter.write_str(#struct_name_str)
}

fn visit_seq<V>(self, mut seq: V) -> Result<Self::Value, V::Error>
where
V: ::serde::de::SeqAccess<'de>,
{
#(#field_visit_seq)*
Ok(Self::Value::new(#(#field_names)*))
}

#visit_map
}

const FIELDS: &'static [&'static str] = &[#(#field_name_strings)*];
deserializer.deserialize_struct(#name_str, FIELDS, Visitor)
}
}
}
}
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use core::fmt;
#[doc(no_inline)]
pub use arbitrary_int;
pub use bilge_impl::{bitsize, bitsize_internal, BinaryBits, DebugBits, DefaultBits, FromBits, TryFromBits};
#[cfg(feature = "serde")]
pub use bilge_impl::{DeserializeBits, SerializeBits};

/// used for `use bilge::prelude::*;`
pub mod prelude {
Expand All @@ -17,6 +19,8 @@ pub mod prelude {
// we control the version, so this should not be a problem
arbitrary_int::*,
};
#[cfg(feature = "serde")]
pub use super::{DeserializeBits, SerializeBits};
}

/// This is internally used, but might be useful. No guarantees are given (for now).
Expand Down
Loading
Loading