Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decode with mem limit #616

Merged
merged 24 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 111 additions & 23 deletions src/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ use crate::{
},
compact::Compact,
encode_like::EncodeLike,
mem_tracking::DecodeWithMemTracking,
DecodeFinished, Error,
};

Expand Down Expand Up @@ -85,16 +86,21 @@ pub trait Input {
/// This is called when decoding reference-based type is finished.
fn ascend_ref(&mut self) {}

/// Hook that is called before allocating memory on the heap.
///
/// `_size` is the number of bytes that is about to be allocated. The default
/// implementation ignores it and always returns `Ok(())`. An implementor can override
/// this hook and return an error instead; callers propagate that error with `?`, which
/// aborts the decoding before the allocation happens.
fn on_before_alloc_mem(&mut self, _size: usize) -> Result<(), Error> {
Ok(())
}

/// !INTERNAL USE ONLY!
///
/// Decodes a `bytes::Bytes`.
/// Used when decoding a `bytes::Bytes` from a `BytesCursor` input.
#[cfg(feature = "bytes")]
#[doc(hidden)]
fn scale_internal_decode_bytes(&mut self) -> Result<bytes::Bytes, Error>
fn __private_bytes_cursor(&mut self) -> Option<&mut BytesCursor>
serban300 marked this conversation as resolved.
Show resolved Hide resolved
where
Self: Sized,
{
Vec::<u8>::decode(self).map(bytes::Bytes::from)
None
}
}

Expand Down Expand Up @@ -408,12 +414,32 @@ mod feature_wrapper_bytes {
impl EncodeLike<Bytes> for Vec<u8> {}
}

/// `Input` implementation optimized for decoding `bytes::Bytes`.
#[cfg(feature = "bytes")]
struct BytesCursor {
pub struct BytesCursor {
serban300 marked this conversation as resolved.
Show resolved Hide resolved
bytes: bytes::Bytes,
position: usize,
}

#[cfg(feature = "bytes")]
impl BytesCursor {
/// Create a new instance of `BytesCursor`.
///
/// The cursor takes ownership of `bytes` and starts with `position` set to 0.
pub fn new(bytes: bytes::Bytes) -> Self {
Self { bytes, position: 0 }
}

/// Splits `length` bytes off the front of the underlying buffer and returns them.
///
/// Returns the error "Not enough data to fill buffer" when `length` is larger than
/// what is left in the buffer after skipping `self.position` bytes.
fn decode_bytes_with_len(&mut self, length: usize) -> Result<bytes::Bytes, Error> {
// Advance the buffer by `self.position` bytes and reset the position, so that the
// length check and the split below both start from the current read offset.
bytes::Buf::advance(&mut self.bytes, self.position);
serban300 marked this conversation as resolved.
Show resolved Hide resolved
self.position = 0;

if length > self.bytes.len() {
return Err("Not enough data to fill buffer".into());
}

// NOTE(review): `split_to` presumably hands out a view into the same storage rather
// than copying it - confirm against the `bytes` documentation.
Ok(self.bytes.split_to(length))
}
}

#[cfg(feature = "bytes")]
impl Input for BytesCursor {
fn remaining_len(&mut self) -> Result<Option<usize>, Error> {
Expand All @@ -430,17 +456,11 @@ impl Input for BytesCursor {
Ok(())
}

fn scale_internal_decode_bytes(&mut self) -> Result<bytes::Bytes, Error> {
let length = <Compact<u32>>::decode(self)?.0 as usize;

bytes::Buf::advance(&mut self.bytes, self.position);
self.position = 0;

if length > self.bytes.len() {
return Err("Not enough data to fill buffer".into());
}

Ok(self.bytes.split_to(length))
/// Returns `Some(self)`: this input is itself the `BytesCursor`, so
/// `Decode for bytes::Bytes` can call `decode_bytes_with_len` on it directly.
fn __private_bytes_cursor(&mut self) -> Option<&mut BytesCursor>
where
Self: Sized,
{
Some(self)
}
}

Expand All @@ -466,17 +486,29 @@ where
// However, if `T` doesn't contain any `Bytes` then this extra allocation is
// technically unnecessary, and we can avoid it by tracking the position ourselves
// and treating the underlying `Bytes` as a fancy `&[u8]`.
let mut input = BytesCursor { bytes, position: 0 };
let mut input = BytesCursor::new(bytes);
T::decode(&mut input)
}

#[cfg(feature = "bytes")]
impl Decode for bytes::Bytes {
fn decode<I: Input>(input: &mut I) -> Result<Self, Error> {
input.scale_internal_decode_bytes()
let len = <Compact<u32>>::decode(input)?.0 as usize;
if input.__private_bytes_cursor().is_some() {
input.on_before_alloc_mem(len)?;
}

if let Some(bytes_cursor) = input.__private_bytes_cursor() {
bytes_cursor.decode_bytes_with_len(len)
} else {
gui1117 marked this conversation as resolved.
Show resolved Hide resolved
decode_vec_with_len::<u8, _>(input, len).map(bytes::Bytes::from)
}
}
}

#[cfg(feature = "bytes")]
impl DecodeWithMemTracking for bytes::Bytes {}

impl<T, X> Encode for X
where
T: Encode + ?Sized,
Expand Down Expand Up @@ -538,6 +570,7 @@ impl<T> WrapperTypeDecode for Box<T> {
// TODO: Use `Box::new_uninit` once that's stable.
let layout = core::alloc::Layout::new::<MaybeUninit<T>>();

input.on_before_alloc_mem(layout.size())?;
let ptr: *mut MaybeUninit<T> = if layout.size() == 0 {
core::ptr::NonNull::dangling().as_ptr()
} else {
Expand Down Expand Up @@ -576,6 +609,8 @@ impl<T> WrapperTypeDecode for Box<T> {
}
}

// `Box<T>` calls `on_before_alloc_mem()` with the layout size before allocating, so it
// supports `DecodeWithMemTracking`.
impl<T: DecodeWithMemTracking> DecodeWithMemTracking for Box<T> {}

impl<T> WrapperTypeDecode for Rc<T> {
type Wrapped = T;

Expand All @@ -588,6 +623,9 @@ impl<T> WrapperTypeDecode for Rc<T> {
}
}

// `Rc<T>` uses `Box::<T>::decode()` internally, so it supports `DecodeWithMemTracking`.
impl<T: DecodeWithMemTracking> DecodeWithMemTracking for Rc<T> {}

#[cfg(target_has_atomic = "ptr")]
impl<T> WrapperTypeDecode for Arc<T> {
type Wrapped = T;
Expand All @@ -601,6 +639,9 @@ impl<T> WrapperTypeDecode for Arc<T> {
}
}

// `Arc<T>` uses `Box::<T>::decode()` internally, so it supports `DecodeWithMemTracking`.
impl<T: DecodeWithMemTracking> DecodeWithMemTracking for Arc<T> {}

impl<T, X> Decode for X
where
T: Decode + Into<X>,
Expand Down Expand Up @@ -690,6 +731,8 @@ impl<T: Decode, E: Decode> Decode for Result<T, E> {
}
}

impl<T: DecodeWithMemTracking, E: DecodeWithMemTracking> DecodeWithMemTracking for Result<T, E> {}

/// Shim type because we can't do a specialised implementation for `Option<bool>` directly.
#[derive(Eq, PartialEq, Clone, Copy)]
pub struct OptionBool(pub Option<bool>);
Expand Down Expand Up @@ -727,6 +770,8 @@ impl Decode for OptionBool {
}
}

impl DecodeWithMemTracking for OptionBool {}

impl<T: EncodeLike<U>, U: Encode> EncodeLike<Option<U>> for Option<T> {}

impl<T: Encode> Encode for Option<T> {
Expand Down Expand Up @@ -763,6 +808,8 @@ impl<T: Decode> Decode for Option<T> {
}
}

impl<T: DecodeWithMemTracking> DecodeWithMemTracking for Option<T> {}

macro_rules! impl_for_non_zero {
( $( $name:ty ),* $(,)? ) => {
$(
Expand Down Expand Up @@ -792,6 +839,8 @@ macro_rules! impl_for_non_zero {
.ok_or_else(|| Error::from("cannot create non-zero number from 0"))
}
}

impl DecodeWithMemTracking for $name {}
)*
}
}
Expand Down Expand Up @@ -995,6 +1044,8 @@ impl<T: Decode, const N: usize> Decode for [T; N] {
}
}

impl<T: DecodeWithMemTracking, const N: usize> DecodeWithMemTracking for [T; N] {}

impl<T: EncodeLike<U>, U: Encode, const N: usize> EncodeLike<[U; N]> for [T; N] {}

impl Encode for str {
Expand Down Expand Up @@ -1024,6 +1075,11 @@ where
}
}

impl<'a, T: ToOwned + DecodeWithMemTracking> DecodeWithMemTracking for Cow<'a, T> where
Cow<'a, T>: Decode
{
}

impl<T> EncodeLike for PhantomData<T> {}

impl<T> Encode for PhantomData<T> {
Expand All @@ -1036,12 +1092,16 @@ impl<T> Decode for PhantomData<T> {
}
}

impl<T: DecodeWithMemTracking> DecodeWithMemTracking for PhantomData<T> where PhantomData<T>: Decode {}

impl Decode for String {
fn decode<I: Input>(input: &mut I) -> Result<Self, Error> {
Self::from_utf8(Vec::decode(input)?).map_err(|_| "Invalid utf8 sequence".into())
}
}

impl DecodeWithMemTracking for String {}

/// Writes the compact encoding of `len` to `dest`.
pub(crate) fn compact_encode_len_to<W: Output + ?Sized>(
dest: &mut W,
Expand All @@ -1067,9 +1127,13 @@ impl<T: Encode> Encode for [T] {
}
}

fn decode_vec_chunked<T, F>(len: usize, mut decode_chunk: F) -> Result<Vec<T>, Error>
fn decode_vec_chunked<T, I: Input, F>(
input: &mut I,
len: usize,
mut decode_chunk: F,
) -> Result<Vec<T>, Error>
where
F: FnMut(&mut Vec<T>, usize) -> Result<(), Error>,
F: FnMut(&mut I, &mut Vec<T>, usize) -> Result<(), Error>,
{
const { assert!(MAX_PREALLOCATION >= mem::size_of::<T>()) }
// we have to account for the fact that `mem::size_of::<T>` can be 0 for types like `()`
Expand All @@ -1080,9 +1144,10 @@ where
let mut num_undecoded_items = len;
while num_undecoded_items > 0 {
let chunk_len = chunk_len.min(num_undecoded_items);
input.on_before_alloc_mem(chunk_len.saturating_mul(mem::size_of::<T>()))?;
decoded_vec.reserve_exact(chunk_len);

decode_chunk(&mut decoded_vec, chunk_len)?;
decode_chunk(input, &mut decoded_vec, chunk_len)?;

num_undecoded_items -= chunk_len;
}
Expand Down Expand Up @@ -1110,7 +1175,7 @@ where
}
}

decode_vec_chunked(len, |decoded_vec, chunk_len| {
decode_vec_chunked(input, len, |input, decoded_vec, chunk_len| {
let decoded_vec_len = decoded_vec.len();
let decoded_vec_size = decoded_vec_len * mem::size_of::<T>();
unsafe {
Expand All @@ -1128,7 +1193,7 @@ where
I: Input,
{
input.descend_ref()?;
let vec = decode_vec_chunked(len, |decoded_vec, chunk_len| {
let vec = decode_vec_chunked(input, len, |input, decoded_vec, chunk_len| {
for _ in 0..chunk_len {
decoded_vec.push(T::decode(input)?);
}
Expand Down Expand Up @@ -1180,6 +1245,8 @@ impl<T: Decode> Decode for Vec<T> {
}
}

impl<T: DecodeWithMemTracking> DecodeWithMemTracking for Vec<T> {}

macro_rules! impl_codec_through_iterator {
($(
$type:ident
Expand Down Expand Up @@ -1207,13 +1274,20 @@ macro_rules! impl_codec_through_iterator {
fn decode<I: Input>(input: &mut I) -> Result<Self, Error> {
<Compact<u32>>::decode(input).and_then(move |Compact(len)| {
input.descend_ref()?;
let result = Result::from_iter((0..len).map(|_| Decode::decode(input)));
let result = Result::from_iter((0..len).map(|_| {
serban300 marked this conversation as resolved.
Show resolved Hide resolved
input.on_before_alloc_mem(0 $( + mem::size_of::<$generics>() )*)?;
gui1117 marked this conversation as resolved.
Show resolved Hide resolved
serban300 marked this conversation as resolved.
Show resolved Hide resolved
Decode::decode(input)
}));
input.ascend_ref();
result
})
}
}

impl<$( $generics: DecodeWithMemTracking ),*> DecodeWithMemTracking
for $type<$( $generics, )*>
where $type<$( $generics, )*>: Decode {}

impl<$( $impl_like_generics )*> EncodeLike<$type<$( $type_like_generics ),*>>
for $type<$( $generics ),*> {}
impl<$( $impl_like_generics )*> EncodeLike<&[( $( $type_like_generics, )* )]>
Expand Down Expand Up @@ -1260,6 +1334,8 @@ impl<T: Decode> Decode for VecDeque<T> {
}
}

impl<T: DecodeWithMemTracking> DecodeWithMemTracking for VecDeque<T> {}

impl EncodeLike for () {}

impl Encode for () {
gui1117 marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -1440,6 +1516,8 @@ macro_rules! impl_endians {
Some(mem::size_of::<$t>())
}
}

impl DecodeWithMemTracking for $t {}
)* }
}
macro_rules! impl_one_byte {
Expand All @@ -1465,6 +1543,8 @@ macro_rules! impl_one_byte {
Ok(input.read_byte()? as $t)
}
}

impl DecodeWithMemTracking for $t {}
)* }
}

Expand Down Expand Up @@ -1500,6 +1580,8 @@ impl Decode for bool {
}
}

impl DecodeWithMemTracking for bool {}

impl Encode for Duration {
fn size_hint(&self) -> usize {
mem::size_of::<u64>() + mem::size_of::<u32>()
Expand All @@ -1524,6 +1606,8 @@ impl Decode for Duration {
}
}

impl DecodeWithMemTracking for Duration {}

impl EncodeLike for Duration {}

impl<T> Encode for Range<T>
Expand All @@ -1550,6 +1634,8 @@ where
}
}

impl<T: DecodeWithMemTracking> DecodeWithMemTracking for Range<T> {}

impl<T> Encode for RangeInclusive<T>
where
T: Encode,
Expand All @@ -1574,6 +1660,8 @@ where
}
}

impl<T: DecodeWithMemTracking> DecodeWithMemTracking for RangeInclusive<T> {}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
12 changes: 12 additions & 0 deletions src/depth_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,18 @@ impl<'a, I: Input> Input for DepthTrackingInput<'a, I> {
self.input.ascend_ref();
self.depth -= 1;
}

/// Forwards the allocation hook to the wrapped input unchanged, so wrapping an input
/// in the depth tracker does not bypass the wrapped input's own hook.
fn on_before_alloc_mem(&mut self, size: usize) -> Result<(), Error> {
self.input.on_before_alloc_mem(size)
}

/// Forwards to the wrapped input, so a `BytesCursor` underneath the depth tracker
/// (if the wrapped input exposes one) is still reachable by the decoder.
#[cfg(feature = "bytes")]
fn __private_bytes_cursor(&mut self) -> Option<&mut crate::BytesCursor>
where
Self: Sized,
{
self.input.__private_bytes_cursor()
}
}

impl<T: Decode> DecodeLimit for T {
Expand Down
Loading