Skip to content

Commit

Permalink
tls_codec: implement U24 (#1284)
Browse files Browse the repository at this point in the history
  • Loading branch information
tnytown authored Jan 4, 2024
1 parent 91c3aa6 commit 189b4d0
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 22 deletions.
50 changes: 47 additions & 3 deletions tls_codec/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ mod quic_vec;
mod tls_vec;

pub use tls_vec::{
SecretTlsVecU16, SecretTlsVecU32, SecretTlsVecU8, TlsByteSliceU16, TlsByteSliceU32,
TlsByteSliceU8, TlsByteVecU16, TlsByteVecU32, TlsByteVecU8, TlsSliceU16, TlsSliceU32,
TlsSliceU8, TlsVecU16, TlsVecU32, TlsVecU8,
SecretTlsVecU16, SecretTlsVecU24, SecretTlsVecU32, SecretTlsVecU8, TlsByteSliceU16,
TlsByteSliceU24, TlsByteSliceU32, TlsByteSliceU8, TlsByteVecU16, TlsByteVecU24, TlsByteVecU32,
TlsByteVecU8, TlsSliceU16, TlsSliceU24, TlsSliceU32, TlsSliceU8, TlsVecU16, TlsVecU24,
TlsVecU32, TlsVecU8,
};

#[cfg(feature = "std")]
Expand Down Expand Up @@ -226,3 +227,46 @@ pub trait DeserializeBytes: Size {
Ok(out)
}
}

/// A 3 byte wide unsigned integer type as defined in [RFC 5246].
///
/// [RFC 5246]: https://datatracker.ietf.org/doc/html/rfc5246#section-4.4
#[derive(Copy, Clone, Debug, Default, PartialEq)]
pub struct U24([u8; 3]);

impl U24 {
pub const MAX: Self = Self([255u8; 3]);
pub const MIN: Self = Self([0u8; 3]);

pub fn from_be_bytes(bytes: [u8; 3]) -> Self {
U24(bytes)
}

pub fn to_be_bytes(self) -> [u8; 3] {
self.0
}
}

impl From<U24> for usize {
fn from(value: U24) -> usize {
const LEN: usize = core::mem::size_of::<usize>();
let mut usize_bytes = [0u8; LEN];
usize_bytes[LEN - 3..].copy_from_slice(&value.0);
usize::from_be_bytes(usize_bytes)
}
}

impl TryFrom<usize> for U24 {
type Error = Error;

fn try_from(value: usize) -> Result<Self, Self::Error> {
const LEN: usize = core::mem::size_of::<usize>();
// In practice, our usages of this conversion should never be invalid, as the values
// have to come from `TryFrom<U24> for usize`.
if value > (1 << 24) - 1 {
Err(Error::LibraryError)
} else {
Ok(U24(value.to_be_bytes()[LEN - 3..].try_into()?))
}
}
}
5 changes: 3 additions & 2 deletions tls_codec/src/primitives.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use alloc::vec::Vec;

use crate::{DeserializeBytes, SerializeBytes};
use crate::{DeserializeBytes, SerializeBytes, U24};

use super::{Deserialize, Error, Serialize, Size};

Expand Down Expand Up @@ -115,7 +115,7 @@ macro_rules! impl_unsigned {
#[cfg(feature = "std")]
#[inline]
fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error> {
let mut x = (0 as $t).to_be_bytes();
let mut x = <$t>::default().to_be_bytes();
bytes.read_exact(&mut x)?;
Ok(<$t>::from_be_bytes(x))
}
Expand Down Expand Up @@ -187,6 +187,7 @@ macro_rules! impl_unsigned {

impl_unsigned!(u8, 1);
impl_unsigned!(u16, 2);
impl_unsigned!(U24, 3);
impl_unsigned!(u32, 4);
impl_unsigned!(u64, 8);

Expand Down
23 changes: 14 additions & 9 deletions tls_codec/src/tls_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use serde::ser::SerializeStruct;
use std::io::{Read, Write};
use zeroize::Zeroize;

use crate::{Deserialize, DeserializeBytes, Error, Serialize, SerializeBytes, Size};
use crate::{Deserialize, DeserializeBytes, Error, Serialize, SerializeBytes, Size, U24};

macro_rules! impl_size {
($self:ident, $size:ty, $name:ident, $len_len:literal) => {
Expand Down Expand Up @@ -42,7 +42,7 @@ macro_rules! impl_byte_deserialize {
#[cfg(feature = "std")]
#[inline(always)]
fn deserialize_bytes<R: Read>(bytes: &mut R) -> Result<Self, Error> {
let len = <$size>::tls_deserialize(bytes)? as usize;
let len = <$size>::tls_deserialize(bytes)?.try_into().unwrap();
// When fuzzing we limit the maximum size to allocate.
// XXX: We should think about a configurable limit for the allocation
// here.
Expand All @@ -63,7 +63,7 @@ macro_rules! impl_byte_deserialize {
#[inline(always)]
fn deserialize_bytes_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error> {
let (type_len, remainder) = <$size>::tls_deserialize_bytes(bytes)?;
let len = type_len as usize;
let len = type_len.try_into().unwrap();
// When fuzzing we limit the maximum size to allocate.
// XXX: We should think about a configurable limit for the allocation
// here.
Expand Down Expand Up @@ -92,7 +92,7 @@ macro_rules! impl_deserialize {
let len = <$size>::tls_deserialize(bytes)?;
let mut read = len.tls_serialized_len();
let len_len = read;
while (read - len_len) < len as usize {
while (read - len_len) < len.try_into().unwrap() {
let element = T::tls_deserialize(bytes)?;
read += element.tls_serialized_len();
result.push(element);
Expand All @@ -110,7 +110,7 @@ macro_rules! impl_deserialize_bytes {
let (len, mut remainder) = <$size>::tls_deserialize_bytes(bytes)?;
let mut read = len.tls_serialized_len();
let len_len = read;
while (read - len_len) < len as usize {
while (read - len_len) < len.try_into().unwrap() {
let (element, next_remainder) = T::tls_deserialize_bytes(remainder)?;
remainder = next_remainder;
read += element.tls_serialized_len();
Expand All @@ -130,7 +130,7 @@ macro_rules! impl_serialize {
// large and write it out.
let (tls_serialized_len, byte_length) = $self.get_content_lengths()?;

let mut written = <$size as Serialize>::tls_serialize(&(byte_length as $size), writer)?;
let mut written = <$size as Serialize>::tls_serialize(&<$size>::try_from(byte_length).unwrap(), writer)?;

// Now serialize the elements
for e in $self.as_slice().iter() {
Expand All @@ -152,7 +152,7 @@ macro_rules! impl_byte_serialize {
// large and write it out.
let (tls_serialized_len, byte_length) = $self.get_content_lengths()?;

let mut written = <$size as Serialize>::tls_serialize(&(byte_length as $size), writer)?;
let mut written = <$size as Serialize>::tls_serialize(&<$size>::try_from(byte_length).unwrap(), writer)?;

// Now serialize the elements
written += writer.write($self.as_slice())?;
Expand All @@ -170,7 +170,7 @@ macro_rules! impl_serialize_common {
let tls_serialized_len = $self.tls_serialized_len();
let byte_length = tls_serialized_len - $len_len;

let max_len = <$size>::MAX as usize;
let max_len = <$size>::MAX.try_into().unwrap();
debug_assert!(
byte_length <= max_len,
"Vector length can't be encoded in the vector length a {} >= {}",
Expand Down Expand Up @@ -207,7 +207,7 @@ macro_rules! impl_serialize_bytes_bytes {
let (tls_serialized_len, byte_length) = $self.get_content_lengths()?;

let mut vec = Vec::<u8>::with_capacity(tls_serialized_len);
let length_vec = <$size as SerializeBytes>::tls_serialize(&(byte_length as $size))?;
let length_vec = <$size as SerializeBytes>::tls_serialize(&byte_length.try_into().unwrap())?;
let mut written = length_vec.len();
vec.extend_from_slice(&length_vec);

Expand Down Expand Up @@ -885,15 +885,18 @@ macro_rules! impl_tls_byte_vec {

impl_public_tls_vec!(u8, TlsVecU8, 1);
impl_public_tls_vec!(u16, TlsVecU16, 2);
impl_public_tls_vec!(U24, TlsVecU24, 3);
impl_public_tls_vec!(u32, TlsVecU32, 4);

impl_tls_byte_vec!(u8, TlsByteVecU8, 1);
impl_tls_byte_vec!(u16, TlsByteVecU16, 2);
impl_tls_byte_vec!(U24, TlsByteVecU24, 3);
impl_tls_byte_vec!(u32, TlsByteVecU32, 4);

// Secrets should be put into these Secret tls vectors as they implement zeroize.
impl_secret_tls_vec!(u8, SecretTlsVecU8, 1);
impl_secret_tls_vec!(u16, SecretTlsVecU16, 2);
impl_secret_tls_vec!(U24, SecretTlsVecU24, 3);
impl_secret_tls_vec!(u32, SecretTlsVecU32, 4);

// We also implement shallow serialization for slices
Expand Down Expand Up @@ -948,6 +951,7 @@ macro_rules! impl_tls_byte_slice {

impl_tls_byte_slice!(u8, TlsByteSliceU8, 1);
impl_tls_byte_slice!(u16, TlsByteSliceU16, 2);
impl_tls_byte_slice!(U24, TlsByteSliceU24, 3);
impl_tls_byte_slice!(u32, TlsByteSliceU32, 4);

macro_rules! impl_tls_slice {
Expand Down Expand Up @@ -1003,6 +1007,7 @@ macro_rules! impl_tls_slice {

impl_tls_slice!(u8, TlsSliceU8, 1);
impl_tls_slice!(u16, TlsSliceU16, 2);
impl_tls_slice!(U24, TlsSliceU24, 3);
impl_tls_slice!(u32, TlsSliceU32, 4);

impl From<core::num::TryFromIntError> for Error {
Expand Down
7 changes: 5 additions & 2 deletions tls_codec/tests/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use tls_codec::{
Error, Serialize, Size, TlsByteSliceU16, TlsByteVecU16, TlsByteVecU8, TlsSliceU16, TlsVecU16,
TlsVecU32, TlsVecU8, VLByteSlice, VLBytes,
TlsVecU32, TlsVecU8, VLByteSlice, VLBytes, U24,
};

#[test]
Expand Down Expand Up @@ -41,7 +41,7 @@ fn deserialize_option_bytes() {
#[test]
fn deserialize_bytes_primitives() {
use tls_codec::DeserializeBytes;
let b = &[77u8, 88, 1, 99] as &[u8];
let b = &[77u8, 88, 1, 99, 1, 0, 73] as &[u8];

let (a, remainder) = u8::tls_deserialize_bytes(b).expect("Unable to tls_deserialize");
assert_eq!(1, a.tls_serialized_len());
Expand All @@ -52,6 +52,9 @@ fn deserialize_bytes_primitives() {
let (a, remainder) = u16::tls_deserialize_bytes(remainder).expect("Unable to tls_deserialize");
assert_eq!(2, a.tls_serialized_len());
assert_eq!(355, a);
let (a, remainder) = U24::tls_deserialize_bytes(remainder).expect("Unable to tls_deserialize");
assert_eq!(3, a.tls_serialized_len());
assert_eq!(U24::try_from(65609usize).unwrap(), a);

// It's empty now.
assert!(remainder.is_empty());
Expand Down
11 changes: 10 additions & 1 deletion tls_codec/tests/decode_bytes.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use tls_codec::{DeserializeBytes, TlsByteVecU16, TlsByteVecU32, TlsByteVecU8};
use tls_codec::{DeserializeBytes, TlsByteVecU16, TlsByteVecU24, TlsByteVecU32, TlsByteVecU8};

#[test]
fn deserialize_tls_byte_vec_u8() {
Expand All @@ -18,6 +18,15 @@ fn deserialize_tls_byte_vec_u16() {
assert_eq!(rest, []);
}

#[test]
fn deserialize_tls_byte_vec_u24() {
let bytes = [0, 0, 3, 2, 1, 0];
let (result, rest) = TlsByteVecU24::tls_deserialize_bytes(&bytes).unwrap();
let expected_result = [2, 1, 0];
assert_eq!(result.as_slice(), expected_result);
assert_eq!(rest, []);
}

#[test]
fn deserialize_tls_byte_vec_u32() {
let bytes = [0, 0, 0, 3, 2, 1, 0];
Expand Down
13 changes: 10 additions & 3 deletions tls_codec/tests/encode.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
#![cfg(feature = "std")]

use tls_codec::{Serialize, TlsVecU16, VLByteSlice, VLBytes};
use tls_codec::{Serialize, TlsVecU16, TlsVecU24, VLByteSlice, VLBytes, U24};

#[test]
fn serialize_primitives() {
let mut v = Vec::new();
77u8.tls_serialize(&mut v).expect("Error encoding u8");
88u8.tls_serialize(&mut v).expect("Error encoding u8");
355u16.tls_serialize(&mut v).expect("Error encoding u16");
let b = [77u8, 88, 1, 99];
U24::try_from(65609usize)
.unwrap()
.tls_serialize(&mut v)
.expect("Error encoding U24");
let b = [77u8, 88, 1, 99, 1, 0, 73];
assert_eq!(&b[..], &v[..]);
}

Expand All @@ -19,8 +23,11 @@ fn serialize_tls_vec() {
TlsVecU16::<u8>::from_slice(&[77, 88, 1, 99])
.tls_serialize(&mut v)
.expect("Error encoding u8");
TlsVecU24::<u8>::from_slice(&[255, 42, 73])
.tls_serialize(&mut v)
.expect("Error encoding u8");

let b = [1u8, 0, 4, 77, 88, 1, 99];
let b = [1u8, 0, 4, 77, 88, 1, 99, 0, 0, 3, 255, 42, 73];
assert_eq!(&b[..], &v[..]);
}

Expand Down
19 changes: 17 additions & 2 deletions tls_codec/tests/encode_bytes.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
use tls_codec::{SerializeBytes, TlsByteVecU16, TlsByteVecU32, TlsByteVecU8};
use tls_codec::{SerializeBytes, TlsByteVecU16, TlsByteVecU24, TlsByteVecU32, TlsByteVecU8, U24};

#[test]
fn serialize_primitives() {
let mut v = Vec::new();
v.append(&mut 77u8.tls_serialize().expect("Error encoding u8"));
v.append(&mut 88u8.tls_serialize().expect("Error encoding u8"));
v.append(&mut 355u16.tls_serialize().expect("Error encoding u16"));
let b = [77u8, 88, 1, 99];
v.append(
&mut U24::try_from(65609usize)
.unwrap()
.tls_serialize()
.expect("Error encoding U24"),
);
let b = [77u8, 88, 1, 99, 1, 0, 73];
assert_eq!(&b[..], &v[..]);
}

Expand Down Expand Up @@ -59,6 +65,15 @@ fn serialize_tls_byte_vec_u16() {
assert_eq!(actual_result, vec![0, 3, 1, 2, 3]);
}

#[test]
fn serialize_tls_byte_vec_u24() {
let byte_vec = TlsByteVecU24::from_slice(&[1, 2, 3]);
let actual_result = byte_vec
.tls_serialize()
.expect("Error encoding byte vector");
assert_eq!(actual_result, vec![0, 0, 3, 1, 2, 3]);
}

#[test]
fn serialize_tls_byte_vec_u32() {
let byte_vec = TlsByteVecU32::from_slice(&[1, 2, 3]);
Expand Down

0 comments on commit 189b4d0

Please sign in to comment.