From 4181e81fbceffe35fc4d02d518d47f5b2c98607a Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Sun, 6 Aug 2023 11:54:53 -0700 Subject: [PATCH] Move visitors to a shared module --- serdect/src/array.rs | 31 +++++--- serdect/src/common.rs | 181 ++++++++++++++++++++++++++++++++++++++++++ serdect/src/lib.rs | 30 +------ serdect/src/slice.rs | 155 +++--------------------------------- 4 files changed, 215 insertions(+), 182 deletions(-) create mode 100644 serdect/src/common.rs diff --git a/serdect/src/array.rs b/serdect/src/array.rs index 25ee4285b..551f05611 100644 --- a/serdect/src/array.rs +++ b/serdect/src/array.rs @@ -1,6 +1,6 @@ //! Serialization primitives for arrays. -// Unfortunately, we currently cannot assert generically that we are serializing +// Unfortunately, we currently cannot tell `serde` in a uniform fashion that we are serializing // a fixed-size byte array. // See https://github.com/serde-rs/serde/issues/2120 for the discussion. // Therefore we have to fall back to the slice methods, @@ -9,11 +9,12 @@ // to be exactly equal to the size of the buffer during deserialization, // while for slices the buffer can be larger than the deserialized data. +use core::fmt; use core::marker::PhantomData; use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use crate::slice; +use crate::common::{self, LengthCheck, SliceVisitor, StrIntoBufVisitor}; #[cfg(feature = "zeroize")] use zeroize::Zeroize; @@ -25,7 +26,7 @@ where S: Serializer, T: AsRef<[u8]>, { - slice::serialize_hex_lower_or_bin(value, serializer) + common::serialize_hex_lower_or_bin(value, serializer) } /// Serialize the given type as upper case hex when using human-readable @@ -35,7 +36,22 @@ where S: Serializer, T: AsRef<[u8]>, { - slice::serialize_hex_upper_or_bin(value, serializer) + common::serialize_hex_upper_or_bin(value, serializer) +} + +struct ExactLength; + +impl LengthCheck for ExactLength { + fn length_check(buffer_length: usize, data_length: usize) -> bool { + buffer_length == data_length + } + fn expecting( + formatter: &mut fmt::Formatter<'_>, + data_type: &str, + data_length: usize, + ) -> fmt::Result { + write!(formatter, "{} of length {}", data_type, data_length) + } } /// Deserialize from hex when using human-readable formats or binary if the @@ -46,12 +62,9 @@ where D: Deserializer<'de>, { if deserializer.is_human_readable() { - deserializer.deserialize_str(slice::StrVisitor::(buffer, PhantomData)) + deserializer.deserialize_str(StrIntoBufVisitor::(buffer, PhantomData)) } else { - deserializer.deserialize_byte_buf(slice::SliceVisitor::( - buffer, - PhantomData, - )) + deserializer.deserialize_byte_buf(SliceVisitor::(buffer, PhantomData)) } } diff --git a/serdect/src/common.rs b/serdect/src/common.rs new file mode 100644 index 000000000..52eb172d5 --- /dev/null +++ b/serdect/src/common.rs @@ -0,0 +1,181 @@ +use core::fmt; +use core::marker::PhantomData; + +use serde::{ + de::{Error, Visitor}, + Serializer, +}; + +#[cfg(feature = "alloc")] +use ::{alloc::vec::Vec, serde::Serialize}; + +#[cfg(not(feature = "alloc"))] +use serde::ser::Error as SerError; + +pub(crate) fn serialize_hex( + value: &T, + serializer: S, +) -> Result +where + S: Serializer, + T: AsRef<[u8]>, +{ + #[cfg(feature = "alloc")] + if UPPERCASE { + return base16ct::upper::encode_string(value.as_ref()).serialize(serializer); + } else { + return base16ct::lower::encode_string(value.as_ref()).serialize(serializer); + } + #[cfg(not(feature = "alloc"))] + { + let _ = value; + let _ = serializer; + return Err(S::Error::custom( + "serializer is human readable, which requires the `alloc` crate feature", + )); + } +} + +pub(crate) fn serialize_hex_lower_or_bin(value: &T, serializer: S) -> Result +where + S: Serializer, + T: AsRef<[u8]>, +{ + if serializer.is_human_readable() { + serialize_hex::<_, _, false>(value, serializer) + } else { + serializer.serialize_bytes(value.as_ref()) + } +} + +/// Serialize the given type as upper case hex when using human-readable +/// formats or binary if the format is binary. +pub(crate) fn serialize_hex_upper_or_bin(value: &T, serializer: S) -> Result +where + S: Serializer, + T: AsRef<[u8]>, +{ + if serializer.is_human_readable() { + serialize_hex::<_, _, true>(value, serializer) + } else { + serializer.serialize_bytes(value.as_ref()) + } +} + +pub(crate) trait LengthCheck { + fn length_check(buffer_length: usize, data_length: usize) -> bool; + fn expecting( + formatter: &mut fmt::Formatter<'_>, + data_type: &str, + data_length: usize, + ) -> fmt::Result; +} + +pub(crate) struct StrIntoBufVisitor<'b, T: LengthCheck>(pub &'b mut [u8], pub PhantomData); + +impl<'de, 'b, T: LengthCheck> Visitor<'de> for StrIntoBufVisitor<'b, T> { + type Value = (); + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + T::expecting(formatter, "a string", self.0.len() * 2) + } + + fn visit_str(self, v: &str) -> Result + where + E: Error, + { + if !T::length_check(self.0.len() * 2, v.len()) { + return Err(Error::invalid_length(v.len(), &self)); + } + // TODO: Map `base16ct::Error::InvalidLength` to `Error::invalid_length`. + base16ct::mixed::decode(v, self.0) + .map(|_| ()) + .map_err(E::custom) + } +} + +#[cfg(feature = "alloc")] +pub(crate) struct StrIntoVecVisitor; + +#[cfg(feature = "alloc")] +impl<'de> Visitor<'de> for StrIntoVecVisitor { + type Value = Vec; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(formatter, "a string") + } + + fn visit_str(self, v: &str) -> Result + where + E: Error, + { + base16ct::mixed::decode_vec(v).map_err(E::custom) + } +} + +pub(crate) struct SliceVisitor<'b, T: LengthCheck>(pub &'b mut [u8], pub PhantomData); + +impl<'de, 'b, T: LengthCheck> Visitor<'de> for SliceVisitor<'b, T> { + type Value = (); + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + T::expecting(formatter, "an array", self.0.len()) + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: Error, + { + // Workaround for + // https://github.com/rust-lang/rfcs/blob/b1de05846d9bc5591d753f611ab8ee84a01fa500/text/2094-nll.md#problem-case-3-conditional-control-flow-across-functions + if T::length_check(self.0.len(), v.len()) { + let buffer = &mut self.0[..v.len()]; + buffer.copy_from_slice(v); + return Ok(()); + } + + Err(E::invalid_length(v.len(), &self)) + } + + #[cfg(feature = "alloc")] + fn visit_byte_buf(self, mut v: Vec) -> Result + where + E: Error, + { + // Workaround for + // https://github.com/rust-lang/rfcs/blob/b1de05846d9bc5591d753f611ab8ee84a01fa500/text/2094-nll.md#problem-case-3-conditional-control-flow-across-functions + if T::length_check(self.0.len(), v.len()) { + let buffer = &mut self.0[..v.len()]; + buffer.swap_with_slice(&mut v); + return Ok(()); + } + + Err(E::invalid_length(v.len(), &self)) + } +} + +#[cfg(feature = "alloc")] +pub(crate) struct VecVisitor; + +#[cfg(feature = "alloc")] +impl<'de> Visitor<'de> for VecVisitor { + type Value = Vec; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(formatter, "a bytestring") + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: Error, + { + Ok(v.into()) + } + + fn visit_byte_buf(self, v: Vec) -> Result + where + E: Error, + { + Ok(v) + } +} diff --git a/serdect/src/lib.rs b/serdect/src/lib.rs index b38b2005d..7bc289f68 100644 --- a/serdect/src/lib.rs +++ b/serdect/src/lib.rs @@ -131,35 +131,7 @@ extern crate alloc; pub mod array; +mod common; pub mod slice; pub use serde; - -use serde::Serializer; - -#[cfg(not(feature = "alloc"))] -use serde::ser::Error; - -#[cfg(feature = "alloc")] -use serde::Serialize; - -fn serialize_hex(value: &T, serializer: S) -> Result -where - S: Serializer, - T: AsRef<[u8]>, -{ - #[cfg(feature = "alloc")] - if UPPERCASE { - return base16ct::upper::encode_string(value.as_ref()).serialize(serializer); - } else { - return base16ct::lower::encode_string(value.as_ref()).serialize(serializer); - } - #[cfg(not(feature = "alloc"))] - { - let _ = value; - let _ = serializer; - return Err(S::Error::custom( - "serializer is human readable, which requires the `alloc` crate feature", - )); - } -} diff --git a/serdect/src/slice.rs b/serdect/src/slice.rs index 7e7b3470f..9104d177a 100644 --- a/serdect/src/slice.rs +++ b/serdect/src/slice.rs @@ -3,14 +3,18 @@ use core::fmt; use core::marker::PhantomData; -use serde::de::{Error, Visitor}; use serde::{Deserializer, Serializer}; +use crate::common::{self, LengthCheck, SliceVisitor, StrIntoBufVisitor}; + #[cfg(feature = "alloc")] -use serde::Serialize; +use ::{ + alloc::vec::Vec, + serde::{Deserialize, Serialize}, +}; #[cfg(feature = "alloc")] -use ::{alloc::vec::Vec, serde::Deserialize}; +use crate::common::{StrIntoVecVisitor, VecVisitor}; #[cfg(feature = "zeroize")] use zeroize::Zeroize; @@ -22,11 +26,7 @@ where S: Serializer, T: AsRef<[u8]>, { - if serializer.is_human_readable() { - crate::serialize_hex::<_, _, false>(value, serializer) - } else { - serializer.serialize_bytes(value.as_ref()) - } + common::serialize_hex_lower_or_bin(value, serializer) } /// Serialize the given type as upper case hex when using human-readable @@ -36,35 +36,7 @@ where S: Serializer, T: AsRef<[u8]>, { - if serializer.is_human_readable() { - crate::serialize_hex::<_, _, true>(value, serializer) - } else { - serializer.serialize_bytes(value.as_ref()) - } -} - -pub(crate) trait LengthCheck { - fn length_check(buffer_length: usize, data_length: usize) -> bool; - fn expecting( - formatter: &mut fmt::Formatter<'_>, - data_type: &str, - data_length: usize, - ) -> fmt::Result; -} - -pub(crate) struct ExactLength; - -impl LengthCheck for ExactLength { - fn length_check(buffer_length: usize, data_length: usize) -> bool { - buffer_length == data_length - } - fn expecting( - formatter: &mut fmt::Formatter<'_>, - data_type: &str, - data_length: usize, - ) -> fmt::Result { - write!(formatter, "{} of length {}", data_type, data_length) - } + common::serialize_hex_upper_or_bin(value, serializer) } struct UpperBound; @@ -86,70 +58,6 @@ impl LengthCheck for UpperBound { } } -pub(crate) struct StrVisitor<'b, T: LengthCheck>(pub &'b mut [u8], pub PhantomData); - -impl<'de, 'b, T: LengthCheck> Visitor<'de> for StrVisitor<'b, T> { - type Value = (); - - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - T::expecting(formatter, "a string", self.0.len() * 2) - } - - fn visit_str(self, v: &str) -> Result - where - E: Error, - { - if !T::length_check(self.0.len() * 2, v.len()) { - return Err(Error::invalid_length(v.len(), &self)); - } - // TODO: Map `base16ct::Error::InvalidLength` to `Error::invalid_length`. - base16ct::mixed::decode(v, self.0) - .map(|_| ()) - .map_err(E::custom) - } -} - -pub(crate) struct SliceVisitor<'b, T: LengthCheck>(pub &'b mut [u8], pub PhantomData); - -impl<'de, 'b, T: LengthCheck> Visitor<'de> for SliceVisitor<'b, T> { - type Value = (); - - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - T::expecting(formatter, "an array", self.0.len()) - } - - fn visit_bytes(self, v: &[u8]) -> Result - where - E: Error, - { - // Workaround for - // https://github.com/rust-lang/rfcs/blob/b1de05846d9bc5591d753f611ab8ee84a01fa500/text/2094-nll.md#problem-case-3-conditional-control-flow-across-functions - if T::length_check(self.0.len(), v.len()) { - let buffer = &mut self.0[..v.len()]; - buffer.copy_from_slice(v); - return Ok(()); - } - - Err(E::invalid_length(v.len(), &self)) - } - - #[cfg(feature = "alloc")] - fn visit_byte_buf(self, mut v: Vec) -> Result - where - E: Error, - { - // Workaround for - // https://github.com/rust-lang/rfcs/blob/b1de05846d9bc5591d753f611ab8ee84a01fa500/text/2094-nll.md#problem-case-3-conditional-control-flow-across-functions - if T::length_check(self.0.len(), v.len()) { - let buffer = &mut self.0[..v.len()]; - buffer.swap_with_slice(&mut v); - return Ok(()); - } - - Err(E::invalid_length(v.len(), &self)) - } -} - /// Deserialize from hex when using human-readable formats or binary if the /// format is binary. Fails if the `buffer` is smaller then the resulting /// slice. @@ -158,7 +66,7 @@ where D: Deserializer<'de>, { if deserializer.is_human_readable() { - deserializer.deserialize_str(StrVisitor::(buffer, PhantomData)) + deserializer.deserialize_str(StrIntoBufVisitor::(buffer, PhantomData)) } else { deserializer.deserialize_byte_buf(SliceVisitor::(buffer, PhantomData)) } @@ -172,49 +80,8 @@ where D: Deserializer<'de>, { if deserializer.is_human_readable() { - struct StrVisitor; - - impl<'de> Visitor<'de> for StrVisitor { - type Value = Vec; - - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(formatter, "a string") - } - - fn visit_str(self, v: &str) -> Result - where - E: Error, - { - base16ct::mixed::decode_vec(v).map_err(E::custom) - } - } - - deserializer.deserialize_str(StrVisitor) + deserializer.deserialize_str(StrIntoVecVisitor) } else { - struct VecVisitor; - - impl<'de> Visitor<'de> for VecVisitor { - type Value = Vec; - - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(formatter, "a bytestring") - } - - fn visit_bytes(self, v: &[u8]) -> Result - where - E: Error, - { - Ok(v.into()) - } - - fn visit_byte_buf(self, v: Vec) -> Result - where - E: Error, - { - Ok(v) - } - } - deserializer.deserialize_byte_buf(VecVisitor) } }