diff --git a/src/record.rs b/src/record.rs index 6807820..244e8bd 100644 --- a/src/record.rs +++ b/src/record.rs @@ -134,7 +134,11 @@ impl<'a> Serializable<'a> for MultiPlexedRecord<'a> { #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub(crate) struct MultiRecord<'a> { + /// The buffer contains concatenated items following this pattern: + /// + /// The two integers are encoded as little endian. buffer: &'a [u8], + /// Offset into the buffer above used while iterating over the serialized items. byte_offset: usize, } @@ -199,25 +203,25 @@ impl<'a> Iterator for MultiRecord<'a> { // no more record return None; } - + const HEADER_LEN: usize = 12; let buffer = &self.buffer[self.byte_offset..]; - if buffer.len() < 10 { + if buffer.len() < HEADER_LEN { // too short: corrupted self.byte_offset = buffer.len(); return Some(Err(MultiRecordCorruption)); } let position = u64::from_le_bytes(buffer[0..8].try_into().unwrap()); - let len = u32::from_le_bytes(buffer[8..12].try_into().unwrap()) as usize; + let len = u32::from_le_bytes(buffer[8..HEADER_LEN].try_into().unwrap()) as usize; - let buffer = &buffer[12..]; + let buffer = &buffer[HEADER_LEN..]; if buffer.len() < len { self.byte_offset = buffer.len(); return Some(Err(MultiRecordCorruption)); } - self.byte_offset += 12 + len; + self.byte_offset += HEADER_LEN + len; Some(Ok((position, &buffer[..len]))) } @@ -225,9 +229,9 @@ impl<'a> Iterator for MultiRecord<'a> { #[cfg(test)] mod tests { + use super::{MultiRecord, MultiPlexedRecord, RecordType}; use std::convert::TryFrom; - - use super::RecordType; + use crate::Serializable; #[test] fn test_record_type_serialize() { @@ -240,4 +244,91 @@ mod tests { } assert_eq!(num_record_types, 4); } + + #[test] + fn test_multirecord_deserialization_ok() { + let mut buffer: Vec = vec![]; + MultiRecord::serialize( + [b"123".as_slice(), b"4567".as_slice()].into_iter(), + 5, + &mut buffer, + ); + match MultiRecord::new(&buffer) { + Err(_) => panic!("Parsing serialized buffers should work"), + Ok(record) => { + let items: Vec<_> = record + .into_iter() + .map(|item| item.expect("Deserializing item should work")) + .collect(); + assert_eq!( + items, + vec![(5u64, b"123".as_slice()), (6u64, b"4567".as_slice())] + ); + } + } + } + + #[test] + fn test_multirecord_deserialization_corruption() { + let mut buffer: Vec = vec![]; + MultiRecord::serialize( + [b"123".as_slice(), b"4567".as_slice()].into_iter(), + 5, + &mut buffer, + ); + let mut num_errors = 0; + for num_truncated_bytes in 1..buffer.len() { + // This should not panic. Typically, this will be an error, but + // deserializing can also succeed (but will have wrong data). + num_errors += MultiRecord::new(&buffer[..buffer.len() - num_truncated_bytes]).is_err() as i32; + } + assert!(num_errors >= 1); + } + + #[test] + fn test_multiplexedrecord_deserialization_ok() { + let mut buffer_multirecord: Vec = vec![]; + MultiRecord::serialize( + [b"123".as_slice()].into_iter(), + 2, + &mut buffer_multirecord, + ); + let record = MultiPlexedRecord::AppendRecords { + queue: "queue_name", + position: 10, + records: MultiRecord::new_unchecked(&buffer_multirecord), + }; + let mut buffer_multiplexed: Vec = vec![]; + record.serialize(&mut buffer_multiplexed); + match MultiPlexedRecord::deserialize(&buffer_multiplexed) { + None => panic!("Deserialization should work"), + Some(parsed_record) => assert_eq!(parsed_record, record), + } + } + + #[test] + fn test_multiplexedrecord_deserialization_corruption() { + let mut buffer_multirecord: Vec = vec![]; + MultiRecord::serialize( + [b"123".as_slice()].into_iter(), + 2, + &mut buffer_multirecord, + ); + let record = MultiPlexedRecord::AppendRecords { + queue: "queue_name", + position: 10, + records: MultiRecord::new_unchecked(&buffer_multirecord), + }; + let mut buffer_multiplexed: Vec = vec![]; + record.serialize(&mut buffer_multiplexed); + + let mut num_errors = 0; + for num_truncated_bytes in 1..buffer_multiplexed.len() { + // This should not panic. Typically, this will be an error, but + // deserializing can also succeed (but will have wrong data). + num_errors += MultiPlexedRecord::deserialize( + &buffer_multiplexed[..buffer_multiplexed.len() - num_truncated_bytes]).is_none() as i32; + } + assert!(num_errors >= 1); + } }