Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SerializeBits/DeserializeBits derive macros #84

Merged
merged 11 commits into from
May 17, 2024
5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ rust-version = "1.65"
default = []
# Enables constness, see README.md: only usable with nightly-2022-11-03
nightly = ["arbitrary-int/const_convert_and_const_trait_impl", "bilge-impl/nightly"]
serde = ["bilge-impl/serde", "arbitrary-int/serde"]

[dependencies]
# cargo clippy workaround, we can't add `path = "../arbitrary-int"` as well
arbitrary-int = { version = "1.2.6" }
arbitrary-int = "1.2.7"
bilge-impl = { version = "=0.2.0", path = "bilge-impl" }

[dev-dependencies]
Expand All @@ -48,6 +49,8 @@ rustversion = "1.0"
trybuild = "1.0"
custom_bits = { path = "tests/custom_bits" }
assert_matches = "1.5.0"
serde = "1.0"
serde_test = "1.0"

# examples
# volatile = { git = "https://github.com/theseus-os/volatile" }
Expand Down
1 change: 1 addition & 0 deletions bilge-impl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ proc-macro = true
default = []
# Enables constness, see README.md for the specific nightly version
nightly = []
serde = []

[dependencies]
syn = { version = "2.0", features = ["full"] }
Expand Down
23 changes: 23 additions & 0 deletions bilge-impl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ mod debug_bits;
mod default_bits;
mod fmt_bits;
mod from_bits;
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
mod serde_bits;
mod try_from_bits;

mod shared;
Expand Down Expand Up @@ -74,3 +77,23 @@ pub fn derive_binary_bits(item: TokenStream) -> TokenStream {
pub fn derive_default_bits(item: TokenStream) -> TokenStream {
default_bits::default_bits(item.into()).into()
}

/// Derives an `impl serde::Serialize` for a bitfield struct.
///
/// The actual expansion lives in `serde_bits::serialize_bits`; this function only converts
/// between the compiler's token stream type and the one used internally.
///
/// Enums are not handled here: use the regular `#[derive(Serialize)]` on them instead.
#[cfg(feature = "serde")]
#[proc_macro_error]
#[proc_macro_derive(SerializeBits, attributes(bitsize_internal))]
pub fn serialize_bits(item: TokenStream) -> TokenStream {
    let expanded = serde_bits::serialize_bits(item.into());
    expanded.into()
}

/// Derives an `impl serde::Deserialize` for a bitfield struct.
///
/// The actual expansion lives in `serde_bits::deserialize_bits`; this function only converts
/// between the compiler's token stream type and the one used internally.
///
/// Enums are not handled here: use the regular `#[derive(Deserialize)]` on them instead.
#[cfg(feature = "serde")]
#[proc_macro_error]
#[proc_macro_derive(DeserializeBits, attributes(bitsize_internal))]
pub fn deserialize_bits(item: TokenStream) -> TokenStream {
    let expanded = serde_bits::deserialize_bits(item.into());
    expanded.into()
}
226 changes: 226 additions & 0 deletions bilge-impl/src/serde_bits.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
use itertools::multiunzip;
use proc_macro2::{Ident, TokenStream};
use proc_macro_error::abort_call_site;
use quote::quote;
use syn::{Data, Field, Fields};

use crate::shared::{self, unreachable};

/// Returns `true` for fields that take part in (de)serialization, i.e. fields whose name
/// does not begin with `reserved_` or `padding_`.
///
/// The `&&Field` parameter lets this be passed straight to `Iterator::filter` when iterating
/// over `&Field` items.
///
/// # Panics
///
/// Panics if the field has no identifier, so it must only be applied to named fields.
fn filter_not_reserved_or_padding(field: &&Field) -> bool {
    const SKIPPED_PREFIXES: [&str; 2] = ["reserved_", "padding_"];

    let ident = field.ident.as_ref().unwrap().to_string();
    !SKIPPED_PREFIXES.iter().any(|prefix| ident.starts_with(*prefix))
}

/// Expands `#[derive(SerializeBits)]` into an `impl ::serde::Serialize` for a bitfield struct.
///
/// * Named-field structs are written with `serialize_struct`, one entry per field getter.
///   Fields whose names start with `reserved_` or `padding_` are left out.
/// * Tuple structs are written with `serialize_tuple_struct`, reading each element through
///   its `val_N` getter.
///
/// Enums and unit structs abort the expansion with an error at the derive call site.
pub(super) fn serialize_bits(item: TokenStream) -> TokenStream {
    let derive_input = shared::parse_derive(item);
    let name = &derive_input.ident;
    let name_str = name.to_string();
    let struct_data = match derive_input.data {
        Data::Struct(s) => s,
        Data::Enum(_) => abort_call_site!("use derive(Serialize) for enums"),
        Data::Union(_) => unreachable(()),
    };

    let serialize_impl = match struct_data.fields {
        Fields::Named(fields) => {
            // Collect once so the field count and the emitted calls cannot disagree.
            let calls: Vec<_> = fields
                .named
                .iter()
                .filter(filter_not_reserved_or_padding)
                .map(|f| {
                    // We can unwrap since this is a named field
                    let call = f.ident.as_ref().unwrap();
                    let name = call.to_string();
                    quote!(state.serialize_field(#name, &self.#call())?;)
                })
                .collect();
            let len = calls.len();
            quote! {
                use ::serde::ser::SerializeStruct;
                let mut state = serializer.serialize_struct(#name_str, #len)?;
                // state.serialize_field("field1", &self.field1())?; state.serialize_field("field2", &self.field2())?; state.end()
                #(#calls)*
                state.end()
            }
        }
        Fields::Unnamed(fields) => {
            let calls = fields.unnamed.iter().enumerate().map(|(i, _)| {
                let call: Ident = syn::parse_str(&format!("val_{}", i)).unwrap_or_else(unreachable);
                quote!(state.serialize_field(&self.#call())?;)
            });
            let len = fields.unnamed.len();
            quote! {
                // Absolute path, like the named-field branch, so a user-local `serde` module
                // cannot shadow the crate in the generated code.
                use ::serde::ser::SerializeTupleStruct;
                let mut state = serializer.serialize_tuple_struct(#name_str, #len)?;
                // state.serialize_field(&self.val_0())?; state.serialize_field(&self.val_1())?; state.end()
                #(#calls)*
                state.end()
            }
        }
        // Report a regular diagnostic instead of panicking inside the proc macro.
        Fields::Unit => abort_call_site!("this is a unit struct, which is not supported right now"),
    };

    quote! {
        impl ::serde::Serialize for #name {
            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
            where
                S: ::serde::Serializer,
            {
                #serialize_impl
            }
        }
    }
}

fn deserialize_field_parts(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like this.

i: usize, field_ident: &Ident,
) -> (
TokenStream,
TokenStream,
TokenStream,
TokenStream,
TokenStream,
TokenStream,
TokenStream,
String,
) {
let field_name_string = field_ident.to_string();
(
quote!(#field_ident,),
quote!(#field_name_string => Ok(Field::#field_ident),),
quote!(#field_name_string,),
quote!(let #field_ident = seq.next_element()?.ok_or_else(|| ::serde::de::Error::invalid_length(#i, &self))?;),
quote!(let mut #field_ident = None;),
quote!(Field::#field_ident => {
if #field_ident.is_some() {
return Err(::serde::de::Error::duplicate_field(#field_name_string));
}
#field_ident = Some(map.next_value()?);
}),
quote!(let #field_ident = #field_ident.ok_or_else(|| ::serde::de::Error::missing_field(#field_name_string))?;),
format!("`{}`", field_name_string),
)
}

pub(super) fn deserialize_bits(item: TokenStream) -> TokenStream {
let derive_input = shared::parse_derive(item);
let name = &derive_input.ident;
let name_str = name.to_string();
let struct_name_str = format!("struct {}", name_str);
let struct_data = match derive_input.data {
Data::Struct(s) => s,
Data::Enum(_) => abort_call_site!("use derive(Serialize) for enums"),
Data::Union(_) => unreachable(()),
};

let should_have_visit_map = matches!(struct_data.fields, Fields::Named(_));

let (
field_names,
field_deserialize,
field_name_strings,
field_visit_seq,
field_visit_map_init,
field_visit_map_match,
field_visit_map_check,
mut field_expecting,
): (Vec<_>, Vec<_>, Vec<_>, Vec<_>, Vec<_>, Vec<_>, Vec<_>, Vec<_>) = match struct_data.fields {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and this.

Fields::Named(fields) => multiunzip(
fields
.named
.iter()
.filter(filter_not_reserved_or_padding)
.enumerate()
.map(|(i, f)| deserialize_field_parts(i, f.ident.as_ref().unwrap())),
),
Fields::Unnamed(fields) => multiunzip(
fields
.unnamed
.iter()
.enumerate()
.map(|(i, _)| deserialize_field_parts(i, &syn::parse_str(&format!("val_{}", i)).unwrap_or_else(unreachable))),
),
Fields::Unit => todo!("this is a unit struct, which is not supported right now"),
};

if field_expecting.len() > 1 {
field_expecting.last_mut().unwrap().insert_str(0, "or ");
}
let field_expecting = field_expecting.join(", ");

let visit_map = if should_have_visit_map {
quote!(fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
where
V: ::serde::de::MapAccess<'de>,
{
#(#field_visit_map_init)*
while let Some(key) = map.next_key()? {
match key {
#(#field_visit_map_match)*
}
}
#(#field_visit_map_check)*
Ok(#name::new(#(#field_names)*))
})
} else {
quote!()
};

quote! {
impl<'de> ::serde::Deserialize<'de> for #name {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: ::serde::Deserializer<'de>,
{
#[allow(non_camel_case_types)]
enum Field { #(#field_names)* }
impl<'de> ::serde::Deserialize<'de> for Field {
fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
where
D: ::serde::Deserializer<'de>,
{
struct FieldVisitor;

impl<'de> ::serde::de::Visitor<'de> for FieldVisitor {
type Value = Field;

fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
formatter.write_str(#field_expecting)
}

fn visit_str<E>(self, value: &str) -> Result<Field, E>
where
E: ::serde::de::Error,
{
match value {
#(#field_deserialize)*
_ => Err(::serde::de::Error::unknown_field(value, FIELDS)),
}
}
}

deserializer.deserialize_identifier(FieldVisitor)
}
}

struct Visitor;

impl<'de> ::serde::de::Visitor<'de> for Visitor {
type Value = #name;

fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
formatter.write_str(#struct_name_str)
}

fn visit_seq<V>(self, mut seq: V) -> Result<Self::Value, V::Error>
where
V: ::serde::de::SeqAccess<'de>,
{
#(#field_visit_seq)*
Ok(Self::Value::new(#(#field_names)*))
}

#visit_map
}

const FIELDS: &'static [&'static str] = &[#(#field_name_strings)*];
deserializer.deserialize_struct(#name_str, FIELDS, Visitor)
}
}
}
}
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use core::fmt;
#[doc(no_inline)]
pub use arbitrary_int;
pub use bilge_impl::{bitsize, bitsize_internal, BinaryBits, DebugBits, DefaultBits, FromBits, TryFromBits};
#[cfg(feature = "serde")]
pub use bilge_impl::{DeserializeBits, SerializeBits};

/// used for `use bilge::prelude::*;`
pub mod prelude {
Expand All @@ -17,6 +19,8 @@ pub mod prelude {
// we control the version, so this should not be a problem
arbitrary_int::*,
};
#[cfg(feature = "serde")]
pub use super::{DeserializeBits, SerializeBits};
}

/// This is internally used, but might be useful. No guarantees are given (for now).
Expand Down
90 changes: 90 additions & 0 deletions tests/serde.rs
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests are nice, but we have to have one testing an arbitrary integer field ser-de 😄

Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#![cfg(feature = "serde")]

use bilge::prelude::*;
use serde_test::{assert_de_tokens_error, assert_tokens, Token};

// Named-field bitfield used by the struct tests: two data bytes plus one padding bit.
//
// The padding field carries the `padding_` prefix because the serde derives only skip fields
// whose names start with `reserved_` or `padding_`; a field named plain `padding` would be
// (de)serialized as a third field, contradicting the `len: 2` expectations in these tests.
// NOTE(review): assumes `#[bitsize]` also leaves `padding_`-prefixed fields out of `new` — confirm.
#[bitsize(17)]
#[derive(FromBits, PartialEq, SerializeBits, DeserializeBits, DebugBits)]
struct BitsStruct {
    field1: u8,
    field2: u8,
    padding_: u1,
}

/// A named-field bitfield round-trips as a two-field struct: the test expects `field1` to
/// carry the low byte of the raw value and `field2` the next byte.
#[test]
fn serde_struct() {
    let raw = u17::new(0b0_0000_0001_0010_0011);
    let value = BitsStruct::from(raw);

    let expected = [
        Token::Struct { name: "BitsStruct", len: 2 },
        Token::Str("field1"),
        Token::U8(0b0010_0011),
        Token::Str("field2"),
        Token::U8(0b0000_0001),
        Token::StructEnd,
    ];

    assert_tokens(&value, &expected);
}

/// Deserializing a struct that only provides `field1` must fail with a missing-field error
/// naming `field2`.
#[test]
fn serde_struct_missing_field() {
    let tokens = [
        Token::Struct { name: "BitsStruct", len: 1 },
        Token::Str("field1"),
        Token::U8(0b0010_0011),
        Token::StructEnd,
    ];

    assert_de_tokens_error::<BitsStruct>(&tokens, "missing field `field2`");
}

/// An unexpected key (`field3`) must be rejected with an unknown-field error that lists the
/// two accepted field names. The token stream stops at the offending key.
#[test]
fn serde_struct_extra_field() {
    let tokens = [
        Token::Struct { name: "BitsStruct", len: 3 },
        Token::Str("field1"),
        Token::U8(0b0010_0011),
        Token::Str("field2"),
        Token::U8(0b0000_0001),
        Token::Str("field3"),
    ];

    assert_de_tokens_error::<BitsStruct>(&tokens, "unknown field `field3`, expected `field1` or `field2`");
}

// Tuple-struct bitfield with two `u8` elements, used by the tuple struct tests.
#[bitsize(16)]
#[derive(FromBits, PartialEq, SerializeBits, DeserializeBits, DebugBits)]
struct BitsTupleStruct(u8, u8);

/// A tuple-struct bitfield round-trips as a two-element tuple struct: the test expects the
/// first element to carry the low byte of the raw value and the second the high byte.
#[test]
fn serde_tuple_struct() {
    let value = BitsTupleStruct::from(0b0000_0001_0010_0011);

    let expected = [
        Token::TupleStruct {
            name: "BitsTupleStruct",
            len: 2,
        },
        Token::U8(0b0010_0011),
        Token::U8(0b0000_0001),
        Token::TupleStructEnd,
    ];

    assert_tokens(&value, &expected);
}

/// Feeding a field-name string where the first tuple element is expected must fail with an
/// invalid-type error, i.e. tuple structs are not deserialized by key.
// NOTE(review): the token's name ("BitsStruct") and len (3) do not match `BitsTupleStruct`;
// this looks copied from the named-struct tests — confirm whether that is intentional.
#[test]
fn serde_tuple_struct_map() {
    let tokens = [
        Token::TupleStruct { name: "BitsStruct", len: 3 },
        Token::Str("val_0"),
    ];

    assert_de_tokens_error::<BitsTupleStruct>(&tokens, r#"invalid type: string "val_0", expected u8"#);
}
Loading