From 2bbd88206a09471cba1210eb5095c96803063954 Mon Sep 17 00:00:00 2001 From: Serban Iorga Date: Tue, 23 Jul 2024 13:44:58 +0300 Subject: [PATCH 01/24] Add mem tracking primitives --- src/codec.rs | 5 +++++ src/lib.rs | 1 + src/mem_tracking.rs | 20 ++++++++++++++++++++ 3 files changed, 26 insertions(+) create mode 100644 src/mem_tracking.rs diff --git a/src/codec.rs b/src/codec.rs index 664116c7..901748b3 100644 --- a/src/codec.rs +++ b/src/codec.rs @@ -85,6 +85,11 @@ pub trait Input { /// This is called when decoding reference-based type is finished. fn ascend_ref(&mut self) {} + /// Hook that is called before allocating memory on the heap. + fn on_before_alloc_mem(&mut self, _size: usize) -> Result<(), Error> { + Ok(()) + } + /// !INTERNAL USE ONLY! /// /// Decodes a `bytes::Bytes`. diff --git a/src/lib.rs b/src/lib.rs index 9f95cf69..a8099635 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -57,6 +57,7 @@ mod joiner; mod keyedvec; #[cfg(feature = "max-encoded-len")] mod max_encoded_len; +mod mem_tracking; #[cfg(feature = "std")] pub use self::codec::IoReader; diff --git a/src/mem_tracking.rs b/src/mem_tracking.rs new file mode 100644 index 00000000..61f35c0a --- /dev/null +++ b/src/mem_tracking.rs @@ -0,0 +1,20 @@ +// Copyright (C) Parity Technologies (UK) Ltd. +// SPDX-License-Identifier: Apache-2.0 + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::Decode; + +/// Marker trait used for identifying types that call the mem tracking hooks exposed by `Input` +/// while decoding. +pub trait DecodeWithMemTracking: Decode {} From 7051378e86ab77c239026991b4695327f5a66393 Mon Sep 17 00:00:00 2001 From: Serban Iorga Date: Tue, 23 Jul 2024 14:27:15 +0300 Subject: [PATCH 02/24] Implement DecodeWithMemTracking for basic types --- src/codec.rs | 73 +++++++++++++++++++++++++++++++++++++++++---- src/mem_tracking.rs | 4 +++ 2 files changed, 71 insertions(+), 6 deletions(-) diff --git a/src/codec.rs b/src/codec.rs index 901748b3..b3e76e79 100644 --- a/src/codec.rs +++ b/src/codec.rs @@ -44,6 +44,7 @@ use crate::{ }, compact::Compact, encode_like::EncodeLike, + mem_tracking::DecodeWithMemTracking, DecodeFinished, Error, }; @@ -445,6 +446,7 @@ impl Input for BytesCursor { return Err("Not enough data to fill buffer".into()); } + self.on_before_alloc_mem(length)?; Ok(self.bytes.split_to(length)) } } @@ -482,6 +484,9 @@ impl Decode for bytes::Bytes { } } +#[cfg(feature = "bytes")] +impl DecodeWithMemTracking for bytes::Bytes {} + impl Encode for X where T: Encode + ?Sized, @@ -543,6 +548,7 @@ impl WrapperTypeDecode for Box { // TODO: Use `Box::new_uninit` once that's stable. 
let layout = core::alloc::Layout::new::>(); + input.on_before_alloc_mem(layout.size())?; let ptr: *mut MaybeUninit = if layout.size() == 0 { core::ptr::NonNull::dangling().as_ptr() } else { @@ -581,6 +587,8 @@ impl WrapperTypeDecode for Box { } } +impl DecodeWithMemTracking for Box {} + impl WrapperTypeDecode for Rc { type Wrapped = T; @@ -593,6 +601,9 @@ impl WrapperTypeDecode for Rc { } } +// `Rc` uses `Box::::decode()` internally, so it supports `DecodeWithMemTracking`. +impl DecodeWithMemTracking for Rc {} + #[cfg(target_has_atomic = "ptr")] impl WrapperTypeDecode for Arc { type Wrapped = T; @@ -606,6 +617,9 @@ impl WrapperTypeDecode for Arc { } } +// `Arc` uses `Box::::decode()` internally, so it supports `DecodeWithMemTracking`. +impl DecodeWithMemTracking for Arc {} + impl Decode for X where T: Decode + Into, @@ -695,6 +709,8 @@ impl Decode for Result { } } +impl DecodeWithMemTracking for Result {} + /// Shim type because we can't do a specialised implementation for `Option` directly. #[derive(Eq, PartialEq, Clone, Copy)] pub struct OptionBool(pub Option); @@ -732,6 +748,8 @@ impl Decode for OptionBool { } } +impl DecodeWithMemTracking for OptionBool {} + impl, U: Encode> EncodeLike> for Option {} impl Encode for Option { @@ -768,6 +786,8 @@ impl Decode for Option { } } +impl DecodeWithMemTracking for Option {} + macro_rules! impl_for_non_zero { ( $( $name:ty ),* $(,)? ) => { $( @@ -797,6 +817,8 @@ macro_rules! impl_for_non_zero { .ok_or_else(|| Error::from("cannot create non-zero number from 0")) } } + + impl DecodeWithMemTracking for $name {} )* } } @@ -1000,6 +1022,8 @@ impl Decode for [T; N] { } } +impl DecodeWithMemTracking for [T; N] {} + impl, U: Encode, const N: usize> EncodeLike<[U; N]> for [T; N] {} impl Encode for str { @@ -1029,6 +1053,11 @@ where } } +impl<'a, T: ToOwned + DecodeWithMemTracking> DecodeWithMemTracking for Cow<'a, T> where + Cow<'a, T>: Decode +{ +} + impl EncodeLike for PhantomData {} impl Encode for PhantomData { @@ -1041,12 +1070,16 @@ impl Decode for PhantomData { } } +impl DecodeWithMemTracking for PhantomData where PhantomData: Decode {} + impl Decode for String { fn decode(input: &mut I) -> Result { Self::from_utf8(Vec::decode(input)?).map_err(|_| "Invalid utf8 sequence".into()) } } +impl DecodeWithMemTracking for String {} + /// Writes the compact encoding of `len` do `dest`. 
pub(crate) fn compact_encode_len_to( dest: &mut W, @@ -1072,9 +1105,13 @@ impl Encode for [T] { } } -fn decode_vec_chunked(len: usize, mut decode_chunk: F) -> Result, Error> +fn decode_vec_chunked( + input: &mut I, + len: usize, + mut decode_chunk: F, +) -> Result, Error> where - F: FnMut(&mut Vec, usize) -> Result<(), Error>, + F: FnMut(&mut I, &mut Vec, usize) -> Result<(), Error>, { const { assert!(MAX_PREALLOCATION >= mem::size_of::()) } // we have to account for the fact that `mem::size_of::` can be 0 for types like `()` @@ -1085,9 +1122,10 @@ where let mut num_undecoded_items = len; while num_undecoded_items > 0 { let chunk_len = chunk_len.min(num_undecoded_items); + input.on_before_alloc_mem(chunk_len.saturating_mul(mem::size_of::()))?; decoded_vec.reserve_exact(chunk_len); - decode_chunk(&mut decoded_vec, chunk_len)?; + decode_chunk(input, &mut decoded_vec, chunk_len)?; num_undecoded_items -= chunk_len; } @@ -1115,7 +1153,7 @@ where } } - decode_vec_chunked(len, |decoded_vec, chunk_len| { + decode_vec_chunked(input, len, |input, decoded_vec, chunk_len| { let decoded_vec_len = decoded_vec.len(); let decoded_vec_size = decoded_vec_len * mem::size_of::(); unsafe { @@ -1133,7 +1171,7 @@ where I: Input, { input.descend_ref()?; - let vec = decode_vec_chunked(len, |decoded_vec, chunk_len| { + let vec = decode_vec_chunked(input, len, |input, decoded_vec, chunk_len| { for _ in 0..chunk_len { decoded_vec.push(T::decode(input)?); } @@ -1185,6 +1223,8 @@ impl Decode for Vec { } } +impl DecodeWithMemTracking for Vec {} + macro_rules! impl_codec_through_iterator { ($( $type:ident @@ -1212,13 +1252,20 @@ macro_rules! impl_codec_through_iterator { fn decode(input: &mut I) -> Result { >::decode(input).and_then(move |Compact(len)| { input.descend_ref()?; - let result = Result::from_iter((0..len).map(|_| Decode::decode(input))); + let result = Result::from_iter((0..len).map(|_| { + input.on_before_alloc_mem(0 $( + mem::size_of::<$generics>() )*)?; + Decode::decode(input) + })); input.ascend_ref(); result }) } } + impl<$( $generics: DecodeWithMemTracking ),*> DecodeWithMemTracking + for $type<$( $generics, )*> + where $type<$( $generics, )*>: Decode {} + impl<$( $impl_like_generics )*> EncodeLike<$type<$( $type_like_generics ),*>> for $type<$( $generics ),*> {} impl<$( $impl_like_generics )*> EncodeLike<&[( $( $type_like_generics, )* )]> @@ -1265,6 +1312,8 @@ impl Decode for VecDeque { } } +impl DecodeWithMemTracking for VecDeque {} + impl EncodeLike for () {} impl Encode for () { @@ -1445,6 +1494,8 @@ macro_rules! impl_endians { Some(mem::size_of::<$t>()) } } + + impl DecodeWithMemTracking for $t {} )* } } macro_rules! impl_one_byte { @@ -1470,6 +1521,8 @@ macro_rules! impl_one_byte { Ok(input.read_byte()? 
as $t) } } + + impl DecodeWithMemTracking for $t {} )* } } @@ -1505,6 +1558,8 @@ impl Decode for bool { } } +impl DecodeWithMemTracking for bool {} + impl Encode for Duration { fn size_hint(&self) -> usize { mem::size_of::() + mem::size_of::() @@ -1529,6 +1584,8 @@ impl Decode for Duration { } } +impl DecodeWithMemTracking for Duration {} + impl EncodeLike for Duration {} impl Encode for Range @@ -1555,6 +1612,8 @@ where } } +impl DecodeWithMemTracking for Range {} + impl Encode for RangeInclusive where T: Encode, @@ -1579,6 +1638,8 @@ where } } +impl DecodeWithMemTracking for RangeInclusive {} + #[cfg(test)] mod tests { use super::*; diff --git a/src/mem_tracking.rs b/src/mem_tracking.rs index 61f35c0a..5080d8f7 100644 --- a/src/mem_tracking.rs +++ b/src/mem_tracking.rs @@ -14,7 +14,11 @@ // limitations under the License. use crate::Decode; +use impl_trait_for_tuples::impl_for_tuples; /// Marker trait used for identifying types that call the mem tracking hooks exposed by `Input` /// while decoding. pub trait DecodeWithMemTracking: Decode {} + +#[impl_for_tuples(18)] +impl DecodeWithMemTracking for Tuple {} From 0313d3b2273e8e6d0588fc86a4ac924ae27ea375 Mon Sep 17 00:00:00 2001 From: Serban Iorga Date: Tue, 23 Jul 2024 17:40:22 +0300 Subject: [PATCH 03/24] Define MemTrackingInput --- src/lib.rs | 1 + src/mem_tracking.rs | 56 +++++++++++++++++++- tests/mem_tracking.rs | 118 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 174 insertions(+), 1 deletion(-) create mode 100644 tests/mem_tracking.rs diff --git a/src/lib.rs b/src/lib.rs index a8099635..d29673dd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -75,6 +75,7 @@ pub use self::{ error::Error, joiner::Joiner, keyedvec::KeyedVec, + mem_tracking::{DecodeWithMemTracking, MemTrackingInput}, }; #[cfg(feature = "max-encoded-len")] pub use const_encoded_len::ConstEncodedLen; diff --git a/src/mem_tracking.rs b/src/mem_tracking.rs index 5080d8f7..2042b57b 100644 --- a/src/mem_tracking.rs +++ b/src/mem_tracking.rs @@ -13,12 +13,66 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::Decode; +use crate::{Decode, Error, Input}; use impl_trait_for_tuples::impl_for_tuples; /// Marker trait used for identifying types that call the mem tracking hooks exposed by `Input` /// while decoding. pub trait DecodeWithMemTracking: Decode {} +const DECODE_OOM_MSG: &str = "Heap memory limit exceeded while decoding"; + #[impl_for_tuples(18)] impl DecodeWithMemTracking for Tuple {} + +/// `Input` implementation that can be used for limiting the heap memory usage while decoding. +pub struct MemTrackingInput<'a, I> { + input: &'a mut I, + used_mem: usize, + mem_limit: usize, +} + +impl<'a, I: Input> MemTrackingInput<'a, I> { + /// Create a new instance of `MemTrackingInput`. + pub fn new(input: &'a mut I, mem_limit: usize) -> Self { + Self { input, used_mem: 0, mem_limit } + } + + /// Get the `used_mem` field. 
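+	///
+	/// This is the total number of heap bytes reported to this input through
+	/// `on_before_alloc_mem` so far while decoding.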
+ pub fn used_mem(&self) -> usize { + self.used_mem + } +} + +impl<'a, I: Input> Input for MemTrackingInput<'a, I> { + fn remaining_len(&mut self) -> Result, Error> { + self.input.remaining_len() + } + + fn read(&mut self, into: &mut [u8]) -> Result<(), Error> { + self.input.read(into) + } + + fn read_byte(&mut self) -> Result { + self.input.read_byte() + } + + fn descend_ref(&mut self) -> Result<(), Error> { + self.input.descend_ref() + } + + fn ascend_ref(&mut self) { + self.input.ascend_ref() + } + + fn on_before_alloc_mem(&mut self, size: usize) -> Result<(), Error> { + self.input.on_before_alloc_mem(size)?; + + self.used_mem = self.used_mem.saturating_add(size); + if self.used_mem >= self.mem_limit { + return Err(DECODE_OOM_MSG.into()) + } + + Ok(()) + } +} diff --git a/tests/mem_tracking.rs b/tests/mem_tracking.rs new file mode 100644 index 00000000..b9b994ae --- /dev/null +++ b/tests/mem_tracking.rs @@ -0,0 +1,118 @@ +// Copyright (C) Parity Technologies (UK) Ltd. +// SPDX-License-Identifier: Apache-2.0 + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use core::fmt::Debug; +use parity_scale_codec::{ + alloc::{ + collections::{BTreeMap, BTreeSet, LinkedList, VecDeque}, + rc::Rc, + }, + DecodeWithMemTracking, Encode, Error, MemTrackingInput, +}; +use parity_scale_codec_derive::{Decode as DeriveDecode, Encode as DeriveEncode}; + +fn decode_object(obj: T, mem_limit: usize, expected_used_mem: usize) -> Result +where + T: Encode + DecodeWithMemTracking + PartialEq + Debug, +{ + let encoded_bytes = obj.encode(); + let raw_input = &mut &encoded_bytes[..]; + let mut input = MemTrackingInput::new(raw_input, mem_limit); + let decoded_obj = T::decode(&mut input)?; + assert_eq!(&decoded_obj, &obj); + assert_eq!(input.used_mem(), expected_used_mem); + Ok(decoded_obj) +} + +#[test] +fn decode_simple_objects_works() { + const ARRAY: [u8; 1000] = [11; 1000]; + + // Test simple objects + assert!(decode_object(ARRAY, usize::MAX, 0).is_ok()); + assert!(decode_object(Some(ARRAY), usize::MAX, 0).is_ok()); + assert!(decode_object((ARRAY, ARRAY), usize::MAX, 0).is_ok()); + assert!(decode_object(1u8, usize::MAX, 0).is_ok()); + assert!(decode_object(1u32, usize::MAX, 0).is_ok()); + assert!(decode_object(1f64, usize::MAX, 0).is_ok()); + + // Test heap objects + assert!(decode_object(Box::new(ARRAY), usize::MAX, 1000).is_ok()); + #[cfg(target_has_atomic = "ptr")] + { + use parity_scale_codec::alloc::sync::Arc; + assert!(decode_object(Arc::new(ARRAY), usize::MAX, 1000).is_ok()); + } + assert!(decode_object(Rc::new(ARRAY), usize::MAX, 1000).is_ok()); + // Simple collections + assert!(decode_object(vec![ARRAY; 3], usize::MAX, 3000).is_ok()); + assert!(decode_object(VecDeque::from(vec![ARRAY; 5]), usize::MAX, 5000).is_ok()); + assert!(decode_object(String::from("test"), usize::MAX, 4).is_ok()); + #[cfg(feature = "bytes")] + assert!(decode_object(bytes::Bytes::from(&ARRAY[..]), usize::MAX, 1000).is_ok()); + // Complex Collections + assert!(decode_object(BTreeMap::::from([(1, 2), (2, 3)]), usize::MAX, 
4).is_ok()); + assert!(decode_object( + BTreeMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()), + ]), + usize::MAX, + 116, + ) + .is_ok()); + assert!(decode_object(BTreeSet::::from([1, 2, 3, 4, 5]), usize::MAX, 5).is_ok()); + assert!(decode_object(LinkedList::::from([1, 2, 3, 4, 5]), usize::MAX, 5).is_ok()); +} + +#[test] +fn decode_complex_objects_works() { + assert!(decode_object(vec![vec![vec![vec![vec![1u8]]]]], usize::MAX, 97).is_ok()); + assert!(decode_object(Box::new(Rc::new(vec![String::from("test")])), usize::MAX, 60).is_ok()); +} + +#[test] +fn decode_complex_derived_struct_works() { + #[derive(DeriveEncode, DeriveDecode, PartialEq, Debug)] + #[allow(clippy::large_enum_variant)] + enum TestEnum { + Empty, + Array([u8; 1000]), + } + + impl DecodeWithMemTracking for TestEnum {} + + #[derive(DeriveEncode, DeriveDecode, PartialEq, Debug)] + struct ComplexStruct { + test_enum: TestEnum, + boxed_test_enum: Box, + box_field: Box, + vec: Vec, + } + + impl DecodeWithMemTracking for ComplexStruct {} + + assert!(decode_object( + ComplexStruct { + test_enum: TestEnum::Array([0; 1000]), + boxed_test_enum: Box::new(TestEnum::Empty), + box_field: Box::new(1), + vec: vec![1; 10], + }, + usize::MAX, + 1015 + ) + .is_ok()) +} From 5b1b387522a441b7f73184f2ce71cf24e3f6ac8f Mon Sep 17 00:00:00 2001 From: Serban Iorga Date: Wed, 24 Jul 2024 14:43:08 +0300 Subject: [PATCH 04/24] Fix `Bytes` decoding --- src/codec.rs | 58 +++++++++++++++++++++++++++++-------------- src/depth_limit.rs | 12 +++++++++ src/lib.rs | 2 +- src/mem_tracking.rs | 8 ++++++ tests/mem_tracking.rs | 14 +++++++++++ 5 files changed, 75 insertions(+), 19 deletions(-) diff --git a/src/codec.rs b/src/codec.rs index b3e76e79..65783693 100644 --- a/src/codec.rs +++ b/src/codec.rs @@ -93,14 +93,14 @@ pub trait Input { /// !INTERNAL USE ONLY! /// - /// Decodes a `bytes::Bytes`. + /// Used when decoding a `bytes::Bytes` from a `BytesCursor` input. #[cfg(feature = "bytes")] #[doc(hidden)] - fn scale_internal_decode_bytes(&mut self) -> Result + fn __private_bytes_cursor(&mut self) -> Option<&mut BytesCursor> where Self: Sized, { - Vec::::decode(self).map(bytes::Bytes::from) + None } } @@ -414,12 +414,32 @@ mod feature_wrapper_bytes { impl EncodeLike for Vec {} } +/// `Input` implementation optimized for decoding `bytes::Bytes`. #[cfg(feature = "bytes")] -struct BytesCursor { +pub struct BytesCursor { bytes: bytes::Bytes, position: usize, } +#[cfg(feature = "bytes")] +impl BytesCursor { + /// Create a new instance of `BytesCursor`. 
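+	/// The cursor starts reading at position 0 of the provided `bytes`.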
+ pub fn new(bytes: bytes::Bytes) -> Self { + Self { bytes, position: 0 } + } + + fn decode_bytes_with_len(&mut self, length: usize) -> Result { + bytes::Buf::advance(&mut self.bytes, self.position); + self.position = 0; + + if length > self.bytes.len() { + return Err("Not enough data to fill buffer".into()); + } + + Ok(self.bytes.split_to(length)) + } +} + #[cfg(feature = "bytes")] impl Input for BytesCursor { fn remaining_len(&mut self) -> Result, Error> { @@ -436,18 +456,11 @@ impl Input for BytesCursor { Ok(()) } - fn scale_internal_decode_bytes(&mut self) -> Result { - let length = >::decode(self)?.0 as usize; - - bytes::Buf::advance(&mut self.bytes, self.position); - self.position = 0; - - if length > self.bytes.len() { - return Err("Not enough data to fill buffer".into()); - } - - self.on_before_alloc_mem(length)?; - Ok(self.bytes.split_to(length)) + fn __private_bytes_cursor(&mut self) -> Option<&mut BytesCursor> + where + Self: Sized, + { + Some(self) } } @@ -473,14 +486,23 @@ where // However, if `T` doesn't contain any `Bytes` then this extra allocation is // technically unnecessary, and we can avoid it by tracking the position ourselves // and treating the underlying `Bytes` as a fancy `&[u8]`. - let mut input = BytesCursor { bytes, position: 0 }; + let mut input = BytesCursor::new(bytes); T::decode(&mut input) } #[cfg(feature = "bytes")] impl Decode for bytes::Bytes { fn decode(input: &mut I) -> Result { - input.scale_internal_decode_bytes() + let len = >::decode(input)?.0 as usize; + if input.__private_bytes_cursor().is_some() { + input.on_before_alloc_mem(len)?; + } + + if let Some(bytes_cursor) = input.__private_bytes_cursor() { + bytes_cursor.decode_bytes_with_len(len) + } else { + decode_vec_with_len::(input, len).map(bytes::Bytes::from) + } } } diff --git a/src/depth_limit.rs b/src/depth_limit.rs index 2af17843..7d4affb4 100644 --- a/src/depth_limit.rs +++ b/src/depth_limit.rs @@ -64,6 +64,18 @@ impl<'a, I: Input> Input for DepthTrackingInput<'a, I> { self.input.ascend_ref(); self.depth -= 1; } + + fn on_before_alloc_mem(&mut self, size: usize) -> Result<(), Error> { + self.input.on_before_alloc_mem(size) + } + + #[cfg(feature = "bytes")] + fn __private_bytes_cursor(&mut self) -> Option<&mut crate::BytesCursor> + where + Self: Sized, + { + self.input.__private_bytes_cursor() + } } impl DecodeLimit for T { diff --git a/src/lib.rs b/src/lib.rs index d29673dd..6bea868b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -130,4 +130,4 @@ pub use max_encoded_len::MaxEncodedLen; pub use parity_scale_codec_derive::MaxEncodedLen; #[cfg(feature = "bytes")] -pub use self::codec::decode_from_bytes; +pub use self::codec::{decode_from_bytes, BytesCursor}; diff --git a/src/mem_tracking.rs b/src/mem_tracking.rs index 2042b57b..409baa9c 100644 --- a/src/mem_tracking.rs +++ b/src/mem_tracking.rs @@ -75,4 +75,12 @@ impl<'a, I: Input> Input for MemTrackingInput<'a, I> { Ok(()) } + + #[cfg(feature = "bytes")] + fn __private_bytes_cursor(&mut self) -> Option<&mut crate::BytesCursor> + where + Self: Sized, + { + self.input.__private_bytes_cursor() + } } diff --git a/tests/mem_tracking.rs b/tests/mem_tracking.rs index b9b994ae..354c7d9a 100644 --- a/tests/mem_tracking.rs +++ b/tests/mem_tracking.rs @@ -83,6 +83,20 @@ fn decode_complex_objects_works() { assert!(decode_object(Box::new(Rc::new(vec![String::from("test")])), usize::MAX, 60).is_ok()); } +#[cfg(feature = "bytes")] +#[test] +fn decode_bytes_from_bytes_works() { + use parity_scale_codec::Decode; + + let obj = ([0u8; 100], Box::new(0u8), 
bytes::Bytes::from(vec![0u8; 50])); + let encoded_bytes = obj.encode(); + let mut bytes_cursor = parity_scale_codec::BytesCursor::new(bytes::Bytes::from(encoded_bytes)); + let mut input = MemTrackingInput::new(&mut bytes_cursor, usize::MAX); + let decoded_obj = <([u8; 100], Box, bytes::Bytes)>::decode(&mut input).unwrap(); + assert_eq!(&decoded_obj, &obj); + assert_eq!(input.used_mem(), 51); +} + #[test] fn decode_complex_derived_struct_works() { #[derive(DeriveEncode, DeriveDecode, PartialEq, Debug)] From a8c31846a93a91ef183c3fa11313cd5306c85b4f Mon Sep 17 00:00:00 2001 From: Serban Iorga Date: Mon, 9 Sep 2024 17:54:25 +0300 Subject: [PATCH 05/24] Add clarifying comment --- src/codec.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/codec.rs b/src/codec.rs index 65783693..d1d17055 100644 --- a/src/codec.rs +++ b/src/codec.rs @@ -87,6 +87,12 @@ pub trait Input { fn ascend_ref(&mut self) {} /// Hook that is called before allocating memory on the heap. + /// + /// The aim is to get a reasonable approximation of memory usage, especially with variably + /// sized types like `Vec`s. Depending on the structure, it is acceptable to be off by a bit. + /// For example for structures like `BTreeMap` we don't track the memory used by the internal + /// tree nodes etc. Also we don't take alignment or memory layouts into account. + /// But we should always track the memory used by the decoded data inside the type. fn on_before_alloc_mem(&mut self, _size: usize) -> Result<(), Error> { Ok(()) } From 577330dc60783ada62b9e9da5fe60ed2a1a31b33 Mon Sep 17 00:00:00 2001 From: Serban Iorga Date: Tue, 10 Sep 2024 10:49:13 +0300 Subject: [PATCH 06/24] Test mem limit exceeded --- tests/mem_tracking.rs | 70 +++++++++++++++++++++++++++++-------------- 1 file changed, 48 insertions(+), 22 deletions(-) diff --git a/tests/mem_tracking.rs b/tests/mem_tracking.rs index 354c7d9a..51d53968 100644 --- a/tests/mem_tracking.rs +++ b/tests/mem_tracking.rs @@ -23,6 +23,27 @@ use parity_scale_codec::{ }; use parity_scale_codec_derive::{Decode as DeriveDecode, Encode as DeriveEncode}; +const ARRAY: [u8; 1000] = [11; 1000]; + +#[derive(DeriveEncode, DeriveDecode, PartialEq, Debug)] +#[allow(clippy::large_enum_variant)] +enum TestEnum { + Empty, + Array([u8; 1000]), +} + +impl DecodeWithMemTracking for TestEnum {} + +#[derive(DeriveEncode, DeriveDecode, PartialEq, Debug)] +struct ComplexStruct { + test_enum: TestEnum, + boxed_test_enum: Box, + box_field: Box, + vec: Vec, +} + +impl DecodeWithMemTracking for ComplexStruct {} + fn decode_object(obj: T, mem_limit: usize, expected_used_mem: usize) -> Result where T: Encode + DecodeWithMemTracking + PartialEq + Debug, @@ -38,8 +59,6 @@ where #[test] fn decode_simple_objects_works() { - const ARRAY: [u8; 1000] = [11; 1000]; - // Test simple objects assert!(decode_object(ARRAY, usize::MAX, 0).is_ok()); assert!(decode_object(Some(ARRAY), usize::MAX, 0).is_ok()); @@ -99,25 +118,6 @@ fn decode_bytes_from_bytes_works() { #[test] fn decode_complex_derived_struct_works() { - #[derive(DeriveEncode, DeriveDecode, PartialEq, Debug)] - #[allow(clippy::large_enum_variant)] - enum TestEnum { - Empty, - Array([u8; 1000]), - } - - impl DecodeWithMemTracking for TestEnum {} - - #[derive(DeriveEncode, DeriveDecode, PartialEq, Debug)] - struct ComplexStruct { - test_enum: TestEnum, - boxed_test_enum: Box, - box_field: Box, - vec: Vec, - } - - impl DecodeWithMemTracking for ComplexStruct {} - assert!(decode_object( ComplexStruct { test_enum: TestEnum::Array([0; 1000]), @@ -128,5 +128,31 @@ fn 
decode_complex_derived_struct_works() { usize::MAX, 1015 ) - .is_ok()) + .is_ok()); +} + +#[test] +fn mem_limit_exceeded_is_triggered() { + // Test simple heap object + assert_eq!( + decode_object(Box::new(ARRAY), 999, 999).unwrap_err().to_string(), + "Heap memory limit exceeded while decoding" + ); + + // Test complex derived struct + assert_eq!( + decode_object( + ComplexStruct { + test_enum: TestEnum::Array([0; 1000]), + boxed_test_enum: Box::new(TestEnum::Empty), + box_field: Box::new(1), + vec: vec![1; 10], + }, + 1014, + 1014 + ) + .unwrap_err() + .to_string(), + "Could not decode `ComplexStruct::vec`:\n\tHeap memory limit exceeded while decoding\n" + ); } From adff127b45d5ebbb8bc3e679f77c40f22480ac00 Mon Sep 17 00:00:00 2001 From: Serban Iorga Date: Tue, 10 Sep 2024 11:04:11 +0300 Subject: [PATCH 07/24] Implement DecodeWithMemTracking for more structs --- src/bit_vec.rs | 7 ++++++- src/compact.rs | 21 ++++++++++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/src/bit_vec.rs b/src/bit_vec.rs index c4a00669..20d65120 100644 --- a/src/bit_vec.rs +++ b/src/bit_vec.rs @@ -15,7 +15,8 @@ //! `BitVec` specific serialization. use crate::{ - codec::decode_vec_with_len, Compact, Decode, Encode, EncodeLike, Error, Input, Output, + codec::decode_vec_with_len, Compact, Decode, DecodeWithMemTracking, Encode, EncodeLike, Error, + Input, Output, }; use bitvec::{ boxed::BitBox, order::BitOrder, slice::BitSlice, store::BitStore, vec::BitVec, view::BitView, @@ -74,6 +75,8 @@ impl Decode for BitVec { } } +impl DecodeWithMemTracking for BitVec {} + impl Encode for BitBox { fn encode_to(&self, dest: &mut W) { self.as_bitslice().encode_to(dest) @@ -88,6 +91,8 @@ impl Decode for BitBox { } } +impl DecodeWithMemTracking for BitBox {} + #[cfg(test)] mod tests { use super::*; diff --git a/src/compact.rs b/src/compact.rs index db5b1768..ffcd2ae4 100644 --- a/src/compact.rs +++ b/src/compact.rs @@ -22,7 +22,7 @@ use crate::{ alloc::vec::Vec, codec::{Decode, Encode, EncodeAsRef, Input, Output}, encode_like::EncodeLike, - Error, + DecodeWithMemTracking, Error, }; #[cfg(feature = "fuzz")] use arbitrary::Arbitrary; @@ -175,6 +175,13 @@ where } } +impl DecodeWithMemTracking for Compact +where + T: CompactAs, + Compact: DecodeWithMemTracking, +{ +} + macro_rules! 
impl_from_compact { ( $( $ty:ty ),* ) => { $( @@ -482,6 +489,8 @@ impl Decode for Compact<()> { } } +impl DecodeWithMemTracking for Compact<()> {} + const U8_OUT_OF_RANGE: &str = "out of range decoding Compact"; const U16_OUT_OF_RANGE: &str = "out of range decoding Compact"; const U32_OUT_OF_RANGE: &str = "out of range decoding Compact"; @@ -506,6 +515,8 @@ impl Decode for Compact { } } +impl DecodeWithMemTracking for Compact {} + impl Decode for Compact { fn decode(input: &mut I) -> Result { let prefix = input.read_byte()?; @@ -532,6 +543,8 @@ impl Decode for Compact { } } +impl DecodeWithMemTracking for Compact {} + impl Decode for Compact { fn decode(input: &mut I) -> Result { let prefix = input.read_byte()?; @@ -572,6 +585,8 @@ impl Decode for Compact { } } +impl DecodeWithMemTracking for Compact {} + impl Decode for Compact { fn decode(input: &mut I) -> Result { let prefix = input.read_byte()?; @@ -628,6 +643,8 @@ impl Decode for Compact { } } +impl DecodeWithMemTracking for Compact {} + impl Decode for Compact { fn decode(input: &mut I) -> Result { let prefix = input.read_byte()?; @@ -692,6 +709,8 @@ impl Decode for Compact { } } +impl DecodeWithMemTracking for Compact {} + #[cfg(test)] mod tests { use super::*; From 00a6f4fb435888b97b30bfb9d4a8ab5fc4a06a6c Mon Sep 17 00:00:00 2001 From: Serban Iorga Date: Thu, 12 Sep 2024 13:04:11 +0200 Subject: [PATCH 08/24] Address CR comments --- src/codec.rs | 4 ++-- src/mem_tracking.rs | 2 +- tests/mem_tracking.rs | 5 +++++ 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/codec.rs b/src/codec.rs index d1d17055..e3e517b3 100644 --- a/src/codec.rs +++ b/src/codec.rs @@ -90,7 +90,7 @@ pub trait Input { /// /// The aim is to get a reasonable approximation of memory usage, especially with variably /// sized types like `Vec`s. Depending on the structure, it is acceptable to be off by a bit. - /// For example for structures like `BTreeMap` we don't track the memory used by the internal + /// For example for structures like `BTreeMap` we don't track the memory used by the internal /// tree nodes etc. Also we don't take alignment or memory layouts into account. /// But we should always track the memory used by the decoded data inside the type. fn on_before_alloc_mem(&mut self, _size: usize) -> Result<(), Error> { @@ -1281,7 +1281,7 @@ macro_rules! impl_codec_through_iterator { >::decode(input).and_then(move |Compact(len)| { input.descend_ref()?; let result = Result::from_iter((0..len).map(|_| { - input.on_before_alloc_mem(0 $( + mem::size_of::<$generics>() )*)?; + input.on_before_alloc_mem(0usize$(.saturating_add(mem::size_of::<$generics>()))*)?; Decode::decode(input) })); input.ascend_ref(); diff --git a/src/mem_tracking.rs b/src/mem_tracking.rs index 409baa9c..5d231ea1 100644 --- a/src/mem_tracking.rs +++ b/src/mem_tracking.rs @@ -16,7 +16,7 @@ use crate::{Decode, Error, Input}; use impl_trait_for_tuples::impl_for_tuples; -/// Marker trait used for identifying types that call the mem tracking hooks exposed by `Input` +/// Marker trait used for identifying types that call the [`Input::on_before_alloc_mem`] hook /// while decoding. 
pub trait DecodeWithMemTracking: Decode {} diff --git a/tests/mem_tracking.rs b/tests/mem_tracking.rs index 51d53968..97b27125 100644 --- a/tests/mem_tracking.rs +++ b/tests/mem_tracking.rs @@ -54,6 +54,11 @@ where let decoded_obj = T::decode(&mut input)?; assert_eq!(&decoded_obj, &obj); assert_eq!(input.used_mem(), expected_used_mem); + + if expected_used_mem > 0 { + let mut input = MemTrackingInput::new(raw_input, expected_used_mem); + assert!(T::decode(&mut input).is_err()); + } Ok(decoded_obj) } From f74bd5cbf1e19f05c73c438cd0917ceb1764707d Mon Sep 17 00:00:00 2001 From: Serban Iorga Date: Thu, 12 Sep 2024 16:35:52 +0200 Subject: [PATCH 09/24] Revert "Fix `Bytes` decoding" This reverts commit 5b1b387522a441b7f73184f2ce71cf24e3f6ac8f. --- src/codec.rs | 58 ++++++++++++++----------------------------- src/depth_limit.rs | 8 ------ src/lib.rs | 2 +- src/mem_tracking.rs | 8 ------ tests/mem_tracking.rs | 14 ----------- 5 files changed, 19 insertions(+), 71 deletions(-) diff --git a/src/codec.rs b/src/codec.rs index e3e517b3..a4df0675 100644 --- a/src/codec.rs +++ b/src/codec.rs @@ -99,14 +99,14 @@ pub trait Input { /// !INTERNAL USE ONLY! /// - /// Used when decoding a `bytes::Bytes` from a `BytesCursor` input. + /// Decodes a `bytes::Bytes`. #[cfg(feature = "bytes")] #[doc(hidden)] - fn __private_bytes_cursor(&mut self) -> Option<&mut BytesCursor> + fn scale_internal_decode_bytes(&mut self) -> Result where Self: Sized, { - None + Vec::::decode(self).map(bytes::Bytes::from) } } @@ -420,32 +420,12 @@ mod feature_wrapper_bytes { impl EncodeLike for Vec {} } -/// `Input` implementation optimized for decoding `bytes::Bytes`. #[cfg(feature = "bytes")] -pub struct BytesCursor { +struct BytesCursor { bytes: bytes::Bytes, position: usize, } -#[cfg(feature = "bytes")] -impl BytesCursor { - /// Create a new instance of `BytesCursor`. - pub fn new(bytes: bytes::Bytes) -> Self { - Self { bytes, position: 0 } - } - - fn decode_bytes_with_len(&mut self, length: usize) -> Result { - bytes::Buf::advance(&mut self.bytes, self.position); - self.position = 0; - - if length > self.bytes.len() { - return Err("Not enough data to fill buffer".into()); - } - - Ok(self.bytes.split_to(length)) - } -} - #[cfg(feature = "bytes")] impl Input for BytesCursor { fn remaining_len(&mut self) -> Result, Error> { @@ -462,11 +442,18 @@ impl Input for BytesCursor { Ok(()) } - fn __private_bytes_cursor(&mut self) -> Option<&mut BytesCursor> - where - Self: Sized, - { - Some(self) + fn scale_internal_decode_bytes(&mut self) -> Result { + let length = >::decode(self)?.0 as usize; + + bytes::Buf::advance(&mut self.bytes, self.position); + self.position = 0; + + if length > self.bytes.len() { + return Err("Not enough data to fill buffer".into()); + } + + self.on_before_alloc_mem(length)?; + Ok(self.bytes.split_to(length)) } } @@ -492,23 +479,14 @@ where // However, if `T` doesn't contain any `Bytes` then this extra allocation is // technically unnecessary, and we can avoid it by tracking the position ourselves // and treating the underlying `Bytes` as a fancy `&[u8]`. 
- let mut input = BytesCursor::new(bytes); + let mut input = BytesCursor { bytes, position: 0 }; T::decode(&mut input) } #[cfg(feature = "bytes")] impl Decode for bytes::Bytes { fn decode(input: &mut I) -> Result { - let len = >::decode(input)?.0 as usize; - if input.__private_bytes_cursor().is_some() { - input.on_before_alloc_mem(len)?; - } - - if let Some(bytes_cursor) = input.__private_bytes_cursor() { - bytes_cursor.decode_bytes_with_len(len) - } else { - decode_vec_with_len::(input, len).map(bytes::Bytes::from) - } + input.scale_internal_decode_bytes() } } diff --git a/src/depth_limit.rs b/src/depth_limit.rs index 7d4affb4..8b3a7edf 100644 --- a/src/depth_limit.rs +++ b/src/depth_limit.rs @@ -68,14 +68,6 @@ impl<'a, I: Input> Input for DepthTrackingInput<'a, I> { fn on_before_alloc_mem(&mut self, size: usize) -> Result<(), Error> { self.input.on_before_alloc_mem(size) } - - #[cfg(feature = "bytes")] - fn __private_bytes_cursor(&mut self) -> Option<&mut crate::BytesCursor> - where - Self: Sized, - { - self.input.__private_bytes_cursor() - } } impl DecodeLimit for T { diff --git a/src/lib.rs b/src/lib.rs index 6bea868b..d29673dd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -130,4 +130,4 @@ pub use max_encoded_len::MaxEncodedLen; pub use parity_scale_codec_derive::MaxEncodedLen; #[cfg(feature = "bytes")] -pub use self::codec::{decode_from_bytes, BytesCursor}; +pub use self::codec::decode_from_bytes; diff --git a/src/mem_tracking.rs b/src/mem_tracking.rs index 5d231ea1..82db454d 100644 --- a/src/mem_tracking.rs +++ b/src/mem_tracking.rs @@ -75,12 +75,4 @@ impl<'a, I: Input> Input for MemTrackingInput<'a, I> { Ok(()) } - - #[cfg(feature = "bytes")] - fn __private_bytes_cursor(&mut self) -> Option<&mut crate::BytesCursor> - where - Self: Sized, - { - self.input.__private_bytes_cursor() - } } diff --git a/tests/mem_tracking.rs b/tests/mem_tracking.rs index 97b27125..f852f5ad 100644 --- a/tests/mem_tracking.rs +++ b/tests/mem_tracking.rs @@ -107,20 +107,6 @@ fn decode_complex_objects_works() { assert!(decode_object(Box::new(Rc::new(vec![String::from("test")])), usize::MAX, 60).is_ok()); } -#[cfg(feature = "bytes")] -#[test] -fn decode_bytes_from_bytes_works() { - use parity_scale_codec::Decode; - - let obj = ([0u8; 100], Box::new(0u8), bytes::Bytes::from(vec![0u8; 50])); - let encoded_bytes = obj.encode(); - let mut bytes_cursor = parity_scale_codec::BytesCursor::new(bytes::Bytes::from(encoded_bytes)); - let mut input = MemTrackingInput::new(&mut bytes_cursor, usize::MAX); - let decoded_obj = <([u8; 100], Box, bytes::Bytes)>::decode(&mut input).unwrap(); - assert_eq!(&decoded_obj, &obj); - assert_eq!(input.used_mem(), 51); -} - #[test] fn decode_complex_derived_struct_works() { assert!(decode_object( From 3bda5cafe7d28f1deb1bd6aadacce583906aabaf Mon Sep 17 00:00:00 2001 From: Serban Iorga Date: Wed, 25 Sep 2024 17:31:02 +0300 Subject: [PATCH 10/24] Deduplications --- derive/src/decode.rs | 48 +++++++++++++++----------------------------- derive/src/encode.rs | 27 +++++++++---------------- derive/src/utils.rs | 29 +++++++++++++++++++++----- 3 files changed, 50 insertions(+), 54 deletions(-) diff --git a/derive/src/decode.rs b/derive/src/decode.rs index 7f2d08b2..94ffd86b 100644 --- a/derive/src/decode.rs +++ b/derive/src/decode.rs @@ -31,33 +31,20 @@ pub fn quote( crate_path: &syn::Path, ) -> TokenStream { match *data { - Data::Struct(ref data) => match data.fields { - Fields::Named(_) | Fields::Unnamed(_) => create_instance( - quote! 
{ #type_name #type_generics }, - &type_name.to_string(), - input, - &data.fields, - crate_path, - ), - Fields::Unit => { - quote_spanned! { data.fields.span() => - ::core::result::Result::Ok(#type_name) - } - }, - }, + Data::Struct(ref data) => create_instance( + quote! { #type_name #type_generics }, + &type_name.to_string(), + input, + &data.fields, + crate_path, + ), Data::Enum(ref data) => { - let data_variants = - || data.variants.iter().filter(|variant| !utils::should_skip(&variant.attrs)); - - if data_variants().count() > 256 { - return Error::new( - data.variants.span(), - "Currently only enums with at most 256 variants are encodable.", - ) - .to_compile_error(); - } + let variants = match utils::try_get_variants(data) { + Ok(variants) => variants, + Err(e) => return e.to_compile_error(), + }; - let recurse = data_variants().enumerate().map(|(i, v)| { + let recurse = variants.iter().enumerate().map(|(i, v)| { let name = &v.ident; let index = utils::variant_index(v, i); @@ -198,12 +185,12 @@ fn create_decode_expr( crate_path: &syn::Path, ) -> TokenStream { let encoded_as = utils::get_encoded_as_type(field); - let compact = utils::is_compact(field); + let compact = utils::get_compact_type(field, crate_path); let skip = utils::should_skip(&field.attrs); let res = quote!(__codec_res_edqy); - if encoded_as.is_some() as u8 + compact as u8 + skip as u8 > 1 { + if encoded_as.is_some() as u8 + compact.is_some() as u8 + skip as u8 > 1 { return Error::new( field.span(), "`encoded_as`, `compact` and `skip` can only be used one at a time!", @@ -213,13 +200,10 @@ fn create_decode_expr( let err_msg = format!("Could not decode `{}`", name); - if compact { - let field_type = &field.ty; + if let Some(compact) = compact { quote_spanned! { field.span() => { - let #res = < - <#field_type as #crate_path::HasCompact>::Type as #crate_path::Decode - >::decode(#input); + let #res = <#compact as #crate_path::Decode>::decode(#input); match #res { ::core::result::Result::Err(e) => return ::core::result::Result::Err(e.chain(#err_msg)), ::core::result::Result::Ok(#res) => #res.into(), diff --git a/derive/src/encode.rs b/derive/src/encode.rs index 142bb439..df7af38a 100644 --- a/derive/src/encode.rs +++ b/derive/src/encode.rs @@ -28,7 +28,7 @@ fn encode_single_field( crate_path: &syn::Path, ) -> TokenStream { let encoded_as = utils::get_encoded_as_type(field); - let compact = utils::is_compact(field); + let compact = utils::get_compact_type(field, crate_path); if utils::should_skip(&field.attrs) { return Error::new( @@ -38,7 +38,7 @@ fn encode_single_field( .to_compile_error(); } - if encoded_as.is_some() && compact { + if encoded_as.is_some() && compact.is_some() { return Error::new( Span::call_site(), "`encoded_as` and `compact` can not be used at the same time!", @@ -46,12 +46,11 @@ fn encode_single_field( .to_compile_error(); } - let final_field_variable = if compact { + let final_field_variable = if let Some(compact) = compact { let field_type = &field.ty; quote_spanned! { field.span() => { - <<#field_type as #crate_path::HasCompact>::Type as - #crate_path::EncodeAsRef<'_, #field_type>>::RefType::from(#field_name) + <#compact as #crate_path::EncodeAsRef<'_, #field_type>>::RefType::from(#field_name) } } } else if let Some(encoded_as) = encoded_as { @@ -298,23 +297,17 @@ fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenS Fields::Unit => [quote! 
{ 0_usize }, quote!()], }, Data::Enum(ref data) => { - let data_variants = - || data.variants.iter().filter(|variant| !utils::should_skip(&variant.attrs)); - - if data_variants().count() > 256 { - return Error::new( - data.variants.span(), - "Currently only enums with at most 256 variants are encodable.", - ) - .to_compile_error(); - } + let variants = match utils::try_get_variants(data) { + Ok(variants) => variants, + Err(e) => return e.to_compile_error(), + }; // If the enum has no variants, we don't need to encode anything. - if data_variants().count() == 0 { + if variants.is_empty() { return quote!(); } - let recurse = data_variants().enumerate().map(|(i, f)| { + let recurse = variants.iter().enumerate().map(|(i, f)| { let name = &f.ident; let index = utils::variant_index(f, i); diff --git a/derive/src/utils.rs b/derive/src/utils.rs index 091a45ee..bd5805f4 100644 --- a/derive/src/utils.rs +++ b/derive/src/utils.rs @@ -22,8 +22,9 @@ use std::str::FromStr; use proc_macro2::TokenStream; use quote::{quote, ToTokens}; use syn::{ - parse::Parse, punctuated::Punctuated, spanned::Spanned, token, Attribute, Data, DeriveInput, - Field, Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, Path, Variant, + parse::Parse, punctuated::Punctuated, spanned::Spanned, token, Attribute, Data, DataEnum, + DeriveInput, Field, Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, + Path, Variant, }; fn find_meta_item<'a, F, R, I, M>(mut itr: I, mut pred: F) -> Option @@ -85,17 +86,22 @@ pub fn get_encoded_as_type(field: &Field) -> Option { } /// Look for a `#[codec(compact)]` outer attribute on the given `Field`. -pub fn is_compact(field: &Field) -> bool { +pub fn get_compact_type(field: &Field, crate_path: &syn::Path) -> Option { find_meta_item(field.attrs.iter(), |meta| { if let NestedMeta::Meta(Meta::Path(ref path)) = meta { if path.is_ident("compact") { - return Some(()); + let field_type = &field.ty; + return Some(quote! {<#field_type as #crate_path::HasCompact>::Type}); } } None }) - .is_some() +} + +/// Look for a `#[codec(compact)]` outer attribute on the given `Field`. +pub fn is_compact(field: &Field) -> bool { + get_compact_type(field, &parse_quote!(::crate)).is_some() } /// Look for a `#[codec(skip)]` in the given attributes. @@ -449,3 +455,16 @@ pub fn is_transparent(attrs: &[syn::Attribute]) -> bool { // TODO: When migrating to syn 2 the `"(transparent)"` needs to be changed into `"transparent"`. 
check_repr(attrs, "(transparent)") } + +pub fn try_get_variants(data: &DataEnum) -> Result, syn::Error> { + let data_variants = || data.variants.iter().filter(|variant| !should_skip(&variant.attrs)); + + if data_variants().count() > 256 { + return Err(syn::Error::new( + data.variants.span(), + "Currently only enums with at most 256 variants are encodable/decodable.", + )) + } + + Ok(data_variants().collect()) +} From 9b4ac0687893bda36ec52258f90c6603c5eb68cf Mon Sep 17 00:00:00 2001 From: Serban Iorga Date: Fri, 20 Sep 2024 10:40:46 +0300 Subject: [PATCH 11/24] Derive DecodeWithMemTracking --- Cargo.lock | 2 +- Cargo.toml | 4 +- derive/src/decode.rs | 46 +++++++++++- derive/src/lib.rs | 51 ++++++++++++- derive/src/utils.rs | 16 +++- src/codec.rs | 2 +- src/compact.rs | 14 +++- src/mem_tracking.rs | 5 +- tests/clippy.rs | 4 +- tests/mem_tracking.rs | 13 ++-- tests/mod.rs | 103 +++++++++++++++----------- tests/single_field_struct_encoding.rs | 67 +++++++++++++---- tests/skip.rs | 18 +++-- tests/type_inference.rs | 8 +- 14 files changed, 266 insertions(+), 87 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f28c33e9..7101073e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -575,7 +575,7 @@ checksum = "ceedf44fb00f2d1984b0bc98102627ce622e083e49a5bacdb3e514fa4238e267" [[package]] name = "parity-scale-codec" -version = "3.6.8" +version = "3.6.12" dependencies = [ "arbitrary", "arrayvec", diff --git a/Cargo.toml b/Cargo.toml index 24d553c9..061d30a4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "parity-scale-codec" description = "SCALE - Simple Concatenating Aggregated Little Endians" -version = "3.6.8" +version = "3.6.12" authors = ["Parity Technologies "] license = "Apache-2.0" repository = "https://github.com/paritytech/parity-scale-codec" @@ -13,7 +13,7 @@ rust-version = "1.60.0" arrayvec = { version = "0.7", default-features = false } serde = { version = "1.0.204", default-features = false, optional = true } parity-scale-codec-derive = { path = "derive", version = ">= 3.6.8", default-features = false, optional = true } -bitvec = { version = "1", default-features = false, features = [ "alloc" ], optional = true } +bitvec = { version = "1", default-features = false, features = ["alloc"], optional = true } bytes = { version = "1", default-features = false, optional = true } byte-slice-cast = { version = "1.2.2", default-features = false } generic-array = { version = "0.14.7", optional = true } diff --git a/derive/src/decode.rs b/derive/src/decode.rs index 94ffd86b..2ff28c4c 100644 --- a/derive/src/decode.rs +++ b/derive/src/decode.rs @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +use crate::utils; use proc_macro2::{Ident, Span, TokenStream}; +use quote::ToTokens; +use std::iter; use syn::{spanned::Spanned, Data, Error, Field, Fields}; -use crate::utils; - /// Generate function block for function `Decode::decode`. 
/// /// * data: data info of the type, @@ -284,3 +285,44 @@ fn create_instance( }, } } + +pub fn quote_decode_with_mem_tracking(data: &Data, crate_path: &syn::Path) -> TokenStream { + let fields: Box> = match data { + Data::Struct(data) => Box::new(data.fields.iter()), + Data::Enum(ref data) => { + let variants = match utils::try_get_variants(data) { + Ok(variants) => variants, + Err(e) => return e.to_compile_error(), + }; + + let mut fields: Box> = Box::new(iter::empty()); + for variant in variants { + fields = Box::new(fields.chain(variant.fields.iter())); + } + fields + }, + Data::Union(_) => { + return Error::new(Span::call_site(), "Union types are not supported.") + .to_compile_error(); + }, + }; + + let processed_fields = fields.filter_map(|field| { + if utils::should_skip(&field.attrs) { + return None; + } + + let field_type = if let Some(compact) = utils::get_compact_type(field, crate_path) { + compact + } else if let Some(encoded_as) = utils::get_encoded_as_type(field) { + encoded_as + } else { + field.ty.to_token_stream() + }; + Some(quote! {<#field_type as #crate_path::DecodeWithMemTracking>::__is_implemented();}) + }); + + quote! { + #(#processed_fields)* + } +} diff --git a/derive/src/lib.rs b/derive/src/lib.rs index a74f6262..6a0644aa 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -168,7 +168,7 @@ pub fn encode_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream wrap_with_dummy_const(input, impl_block) } -/// Derive `parity_scale_codec::Decode` and for struct and enum. +/// Derive `parity_scale_codec::Decode` for struct and enum. /// /// see derive `Encode` documentation. #[proc_macro_derive(Decode, attributes(codec))] @@ -240,6 +240,55 @@ pub fn decode_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream wrap_with_dummy_const(input, impl_block) } +/// Derive `parity_scale_codec::DecodeWithMemTracking` for struct and enum. +#[proc_macro_derive(DecodeWithMemTracking, attributes(codec))] +pub fn decode_with_mem_tracking_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let mut input: DeriveInput = match syn::parse(input) { + Ok(input) => input, + Err(e) => return e.to_compile_error().into(), + }; + + if let Err(e) = utils::check_attributes(&input) { + return e.to_compile_error().into(); + } + + let crate_path = match codec_crate_path(&input.attrs) { + Ok(crate_path) => crate_path, + Err(error) => return error.into_compile_error().into(), + }; + + if let Err(e) = trait_bounds::add( + &input.ident, + &mut input.generics, + &input.data, + utils::custom_decode_with_mem_tracking_trait_bound(&input.attrs), + parse_quote!(#crate_path::DecodeWithMemTracking), + Some(parse_quote!(Default)), + utils::has_dumb_trait_bound(&input.attrs), + &crate_path, + ) { + return e.to_compile_error().into(); + } + + let name = &input.ident; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + + let decode_with_mem_tracking_body = + decode::quote_decode_with_mem_tracking(&input.data, &crate_path); + let impl_block = quote! { + #[automatically_derived] + impl #impl_generics #crate_path::DecodeWithMemTracking for #name #ty_generics #where_clause { + fn __is_implemented() { + #decode_with_mem_tracking_body + } + } + }; + + // panic!("mem tracking impl_block: {}", impl_block.to_string()); + + wrap_with_dummy_const(input, impl_block) +} + /// Derive `parity_scale_codec::Compact` and `parity_scale_codec::CompactAs` for struct with single /// field. 
/// diff --git a/derive/src/utils.rs b/derive/src/utils.rs index bd5805f4..fbf7e619 100644 --- a/derive/src/utils.rs +++ b/derive/src/utils.rs @@ -235,6 +235,7 @@ impl Parse for CustomTraitBound { syn::custom_keyword!(encode_bound); syn::custom_keyword!(decode_bound); +syn::custom_keyword!(decode_with_mem_tracking_bound); syn::custom_keyword!(mel_bound); syn::custom_keyword!(skip_type_params); @@ -245,6 +246,15 @@ pub fn custom_decode_trait_bound(attrs: &[Attribute]) -> Option Option> { + find_meta_item(attrs.iter(), Some) +} + /// Look for a `#[codec(encode_bound(T: Encode))]` in the given attributes. /// /// If found, it should be used as trait bounds when deriving the `Encode` trait. @@ -411,11 +421,13 @@ fn check_variant_attribute(attr: &Attribute) -> syn::Result<()> { fn check_top_attribute(attr: &Attribute) -> syn::Result<()> { let top_error = "Invalid attribute: only `#[codec(dumb_trait_bound)]`, \ `#[codec(crate = path::to::crate)]`, `#[codec(encode_bound(T: Encode))]`, \ - `#[codec(decode_bound(T: Decode))]`, or `#[codec(mel_bound(T: MaxEncodedLen))]` \ - are accepted as top attribute"; + `#[codec(decode_bound(T: Decode))]`, \ + `#[codec(decode_bound_with_mem_tracking_bound(T: Decode))]` or \ + `#[codec(mel_bound(T: MaxEncodedLen))]` are accepted as top attribute"; if attr.path.is_ident("codec") && attr.parse_args::>().is_err() && attr.parse_args::>().is_err() && + attr.parse_args::>().is_err() && attr.parse_args::>().is_err() && codec_crate_path_inner(attr).is_none() { diff --git a/src/codec.rs b/src/codec.rs index a4df0675..cddcd843 100644 --- a/src/codec.rs +++ b/src/codec.rs @@ -1076,7 +1076,7 @@ impl Decode for PhantomData { } } -impl DecodeWithMemTracking for PhantomData where PhantomData: Decode {} +impl DecodeWithMemTracking for PhantomData where PhantomData: Decode {} impl Decode for String { fn decode(input: &mut I) -> Result { diff --git a/src/compact.rs b/src/compact.rs index ffcd2ae4..06c6cf4d 100644 --- a/src/compact.rs +++ b/src/compact.rs @@ -256,7 +256,12 @@ impl MaybeMaxEncodedLen for T {} /// Trait that tells you if a given type can be encoded/decoded in a compact way. pub trait HasCompact: Sized { /// The compact type; this can be - type Type: for<'a> EncodeAsRef<'a, Self> + Decode + From + Into + MaybeMaxEncodedLen; + type Type: for<'a> EncodeAsRef<'a, Self> + + Decode + + DecodeWithMemTracking + + From + + Into + + MaybeMaxEncodedLen; } impl<'a, T: 'a> EncodeAsRef<'a, T> for Compact @@ -280,7 +285,12 @@ where impl HasCompact for T where - Compact: for<'a> EncodeAsRef<'a, T> + Decode + From + Into + MaybeMaxEncodedLen, + Compact: for<'a> EncodeAsRef<'a, T> + + Decode + + DecodeWithMemTracking + + From + + Into + + MaybeMaxEncodedLen, { type Type = Compact; } diff --git a/src/mem_tracking.rs b/src/mem_tracking.rs index 82db454d..182a1b87 100644 --- a/src/mem_tracking.rs +++ b/src/mem_tracking.rs @@ -18,7 +18,10 @@ use impl_trait_for_tuples::impl_for_tuples; /// Marker trait used for identifying types that call the [`Input::on_before_alloc_mem`] hook /// while decoding. -pub trait DecodeWithMemTracking: Decode {} +pub trait DecodeWithMemTracking: Decode { + /// Internal method used to derive `DecodeWithMemTracking`. + fn __is_implemented() {} +} const DECODE_OOM_MSG: &str = "Heap memory limit exceeded while decoding"; diff --git a/tests/clippy.rs b/tests/clippy.rs index f1cb7e6d..3b552645 100644 --- a/tests/clippy.rs +++ b/tests/clippy.rs @@ -16,10 +16,10 @@ //! This file is checked by clippy to make sure that the code generated by the derive macro //! 
doesn't spew out warnings/errors in users' code. -use parity_scale_codec_derive::{Decode, Encode}; +use parity_scale_codec_derive::{Decode, DecodeWithMemTracking, Encode}; #[repr(u8)] -#[derive(Decode, Encode)] +#[derive(Decode, DecodeWithMemTracking, Encode)] pub enum CLike { Foo = 0, Bar = 1, diff --git a/tests/mem_tracking.rs b/tests/mem_tracking.rs index f852f5ad..3361ae71 100644 --- a/tests/mem_tracking.rs +++ b/tests/mem_tracking.rs @@ -21,20 +21,21 @@ use parity_scale_codec::{ }, DecodeWithMemTracking, Encode, Error, MemTrackingInput, }; -use parity_scale_codec_derive::{Decode as DeriveDecode, Encode as DeriveEncode}; +use parity_scale_codec_derive::{ + Decode as DeriveDecode, DecodeWithMemTracking as DeriveDecodeWithMemTracking, + Encode as DeriveEncode, +}; const ARRAY: [u8; 1000] = [11; 1000]; -#[derive(DeriveEncode, DeriveDecode, PartialEq, Debug)] +#[derive(DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking, PartialEq, Debug)] #[allow(clippy::large_enum_variant)] enum TestEnum { Empty, Array([u8; 1000]), } -impl DecodeWithMemTracking for TestEnum {} - -#[derive(DeriveEncode, DeriveDecode, PartialEq, Debug)] +#[derive(DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking, PartialEq, Debug)] struct ComplexStruct { test_enum: TestEnum, boxed_test_enum: Box, @@ -42,8 +43,6 @@ struct ComplexStruct { vec: Vec, } -impl DecodeWithMemTracking for ComplexStruct {} - fn decode_object(obj: T, mem_limit: usize, expected_used_mem: usize) -> Result where T: Encode + DecodeWithMemTracking + PartialEq + Debug, diff --git a/tests/mod.rs b/tests/mod.rs index ff86c646..08b88922 100644 --- a/tests/mod.rs +++ b/tests/mod.rs @@ -13,25 +13,29 @@ // limitations under the License. use parity_scale_codec::{ - Compact, CompactAs, Decode, Encode, EncodeAsRef, Error, HasCompact, Output, + Compact, CompactAs, Decode, DecodeWithMemTracking, Encode, EncodeAsRef, Error, HasCompact, + Output, +}; +use parity_scale_codec_derive::{ + Decode as DeriveDecode, DecodeWithMemTracking as DeriveDecodeWithMemTracking, + Encode as DeriveEncode, }; -use parity_scale_codec_derive::{Decode as DeriveDecode, Encode as DeriveEncode}; use serde_derive::{Deserialize, Serialize}; -#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode)] +#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] struct Unit; -#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode)] +#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] struct Indexed(u32, u64); -#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, Default)] +#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking, Default)] struct Struct { pub a: A, pub b: B, pub c: C, } -#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode)] +#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] struct StructWithPhantom { pub a: u32, pub b: u64, @@ -46,7 +50,7 @@ impl Struct { } } -#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode)] +#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] enum EnumType { #[codec(index = 15)] A, @@ -57,26 +61,26 @@ enum EnumType { }, } -#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode)] +#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] enum EnumWithDiscriminant { A = 1, B = 15, C = 255, } -#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode)] +#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] struct TestHasCompact { #[codec(encoded_as 
= "::Type")] bar: T, } -#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode)] +#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] struct TestCompactHasCompact { #[codec(compact)] bar: T, } -#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode)] +#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] enum TestHasCompactEnum { Unnamed(#[codec(encoded_as = "::Type")] T), Named { @@ -90,13 +94,13 @@ enum TestHasCompactEnum { }, } -#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode)] +#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] struct TestCompactAttribute { #[codec(compact)] bar: u64, } -#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode)] +#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] enum TestCompactAttributeEnum { Unnamed(#[codec(compact)] u64), Named { @@ -329,7 +333,7 @@ fn associated_type_bounds() { type NonEncodableType; } - #[derive(DeriveEncode, DeriveDecode, Debug, PartialEq)] + #[derive(DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking, Debug, PartialEq)] struct Struct { field: (Vec, Type), } @@ -372,7 +376,7 @@ fn generic_bound_encoded_as() { type RefType = &'a u32; } - #[derive(Debug, PartialEq, DeriveEncode, DeriveDecode)] + #[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] struct TestGeneric> where u32: for<'a> EncodeAsRef<'a, A>, @@ -408,7 +412,7 @@ fn generic_bound_hascompact() { } } - #[derive(Debug, PartialEq, DeriveEncode, DeriveDecode)] + #[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] enum TestGenericHasCompact { A { #[codec(compact)] @@ -429,14 +433,14 @@ fn generic_trait() { struct StructNoCodec; - #[derive(Debug, PartialEq, DeriveEncode, DeriveDecode)] + #[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] struct StructCodec; impl TraitNoCodec for StructNoCodec { type Type = StructCodec; } - #[derive(Debug, PartialEq, DeriveEncode, DeriveDecode)] + #[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] struct TestGenericTrait { t: T::Type, } @@ -448,7 +452,9 @@ fn generic_trait() { #[test] fn recursive_variant_1_encode_works() { - #[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, Default)] + #[derive( + Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking, Default, + )] struct Recursive { data: N, other: Vec>, @@ -460,7 +466,9 @@ fn recursive_variant_1_encode_works() { #[test] fn recursive_variant_2_encode_works() { - #[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, Default)] + #[derive( + Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking, Default, + )] struct Recursive { data: N, other: Vec>>, @@ -476,10 +484,14 @@ fn private_type_in_where_bound() { // an error. 
#![deny(warnings)] - #[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, Default)] + #[derive( + Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking, Default, + )] struct Private; - #[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, Default)] + #[derive( + Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking, Default, + )] #[codec(dumb_trait_bound)] pub struct Test { data: Vec<(N, Private)>, @@ -491,7 +503,7 @@ fn private_type_in_where_bound() { #[test] fn encode_decode_empty_enum() { - #[derive(DeriveEncode, DeriveDecode, PartialEq, Debug)] + #[derive(DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking, PartialEq, Debug)] enum EmptyEnumDerive {} fn impls_encode_decode() {} @@ -513,13 +525,13 @@ fn codec_vec_u8() { #[test] fn recursive_type() { - #[derive(DeriveEncode, DeriveDecode)] + #[derive(DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] pub enum Foo { T(Box), A, } - #[derive(DeriveEncode, DeriveDecode)] + #[derive(DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] pub struct Bar { field: Foo, } @@ -567,7 +579,7 @@ fn weird_derive() { }; } - make_struct!(#[derive(DeriveEncode, DeriveDecode)]); + make_struct!(#[derive(DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)]); } #[test] @@ -577,15 +589,16 @@ fn output_trait_object() { #[test] fn custom_trait_bound() { - #[derive(DeriveEncode, DeriveDecode)] + #[derive(DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] #[codec(encode_bound(N: Encode, T: Default))] #[codec(decode_bound(N: Decode, T: Default))] + #[codec(decode_with_mem_tracking_bound(N: DecodeWithMemTracking, T: Default))] struct Something { hello: Hello, val: N, } - #[derive(DeriveEncode, DeriveDecode)] + #[derive(DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] #[codec(encode_bound())] #[codec(decode_bound())] struct Hello { @@ -620,7 +633,7 @@ fn bit_vec_works() { assert_eq!(original_vec, original_vec_clone); - #[derive(DeriveDecode, DeriveEncode, PartialEq, Debug)] + #[derive(DeriveDecode, DeriveDecodeWithMemTracking, DeriveEncode, PartialEq, Debug)] struct MyStruct { v: BitVec, x: u8, @@ -637,7 +650,7 @@ fn bit_vec_works() { #[test] fn no_warning_for_deprecated() { - #[derive(DeriveEncode, DeriveDecode)] + #[derive(DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] pub enum MyEnum { VariantA, #[deprecated] @@ -665,11 +678,11 @@ fn decoding_a_huge_array_inside_of_arc_does_not_overflow_the_stack() { #[test] fn decoding_a_huge_boxed_newtype_array_does_not_overflow_the_stack() { - #[derive(DeriveDecode)] + #[derive(DeriveDecode, DeriveDecodeWithMemTracking)] #[repr(transparent)] struct HugeArrayNewtype([u8; 100 * 1024 * 1024]); - #[derive(DeriveDecode)] + #[derive(DeriveDecode, DeriveDecodeWithMemTracking)] struct HugeArrayNewtypeBox(#[allow(dead_code)] Box); let data = &[]; @@ -680,10 +693,10 @@ fn decoding_a_huge_boxed_newtype_array_does_not_overflow_the_stack() { fn decoding_two_indirectly_boxed_arrays_works() { // This test will fail if the check for `#[repr(transparent)]` in the derive crate // doesn't work when implementing `Decode::decode_into`. 
- #[derive(DeriveDecode, PartialEq, Eq, Debug)] + #[derive(DeriveDecode, DeriveDecodeWithMemTracking, PartialEq, Eq, Debug)] struct SmallArrays([u8; 2], [u8; 2]); - #[derive(DeriveDecode)] + #[derive(DeriveDecode, DeriveDecodeWithMemTracking)] struct SmallArraysBox(Box); let data = &[1, 2, 3, 4]; @@ -695,16 +708,16 @@ fn decoding_two_indirectly_boxed_arrays_works() { #[test] fn zero_sized_types_are_properly_decoded_in_a_transparent_boxed_struct() { - #[derive(DeriveDecode)] + #[derive(DeriveDecode, DeriveDecodeWithMemTracking)] #[repr(transparent)] struct ZstTransparent; - #[derive(DeriveDecode)] + #[derive(DeriveDecode, DeriveDecodeWithMemTracking)] struct ZstNonTransparent; struct ConsumeByte; - #[derive(DeriveDecode)] + #[derive(DeriveDecode, DeriveDecodeWithMemTracking)] #[repr(transparent)] struct NewtypeWithZst { _zst_1: ConsumeByte, @@ -714,7 +727,7 @@ fn zero_sized_types_are_properly_decoded_in_a_transparent_boxed_struct() { _zst_4: ConsumeByte, } - #[derive(DeriveDecode)] + #[derive(DeriveDecode, DeriveDecodeWithMemTracking)] struct NewtypeWithZstBox(#[allow(dead_code)] Box); impl Decode for ConsumeByte { @@ -725,21 +738,23 @@ fn zero_sized_types_are_properly_decoded_in_a_transparent_boxed_struct() { } } + impl DecodeWithMemTracking for ConsumeByte {} + let data = &[1, 2, 3]; assert_eq!(NewtypeWithZst::decode(&mut data.as_slice()).unwrap().field, [2]); } #[test] fn boxed_zero_sized_newtype_with_everything_being_transparent_is_decoded_correctly() { - #[derive(DeriveDecode)] + #[derive(DeriveDecode, DeriveDecodeWithMemTracking)] #[repr(transparent)] struct Zst; - #[derive(DeriveDecode)] + #[derive(DeriveDecode, DeriveDecodeWithMemTracking)] #[repr(transparent)] struct NewtypeWithZst(Zst); - #[derive(DeriveDecode)] + #[derive(DeriveDecode, DeriveDecodeWithMemTracking)] #[repr(transparent)] struct NewtypeWithZstBox(Box); @@ -765,7 +780,7 @@ fn incomplete_decoding_of_an_array_drops_partially_read_elements_if_reading_fail pub static COUNTER: core::cell::Cell = const { core::cell::Cell::new(0) }; } - #[derive(DeriveDecode)] + #[derive(DeriveDecode, DeriveDecodeWithMemTracking)] struct Foobar(#[allow(dead_code)] u8); impl Drop for Foobar { @@ -822,10 +837,10 @@ fn incomplete_decoding_of_an_array_drops_partially_read_elements_if_reading_pani #[test] fn deserializing_of_big_recursively_nested_enum_works() { - #[derive(PartialEq, Eq, DeriveDecode, DeriveEncode)] + #[derive(PartialEq, Eq, DeriveDecode, DeriveDecodeWithMemTracking, DeriveEncode)] struct Data([u8; 1472]); - #[derive(PartialEq, Eq, DeriveDecode, DeriveEncode)] + #[derive(PartialEq, Eq, DeriveDecode, DeriveDecodeWithMemTracking, DeriveEncode)] #[allow(clippy::large_enum_variant)] enum Enum { Nested(Vec), diff --git a/tests/single_field_struct_encoding.rs b/tests/single_field_struct_encoding.rs index 565a7d72..24ff2909 100644 --- a/tests/single_field_struct_encoding.rs +++ b/tests/single_field_struct_encoding.rs @@ -1,16 +1,27 @@ use parity_scale_codec::{Compact, Decode, Encode, HasCompact}; use parity_scale_codec_derive::{ - CompactAs as DeriveCompactAs, Decode as DeriveDecode, Encode as DeriveEncode, + CompactAs as DeriveCompactAs, Decode as DeriveDecode, + DecodeWithMemTracking as DeriveDecodeWithMemTracking, Encode as DeriveEncode, }; use serde_derive::{Deserialize, Serialize}; -#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode)] +#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] struct S { x: u32, } #[cfg_attr(feature = "std", derive(Serialize, Deserialize))] -#[derive(Debug, PartialEq, 
Eq, Clone, Copy, DeriveEncode, DeriveDecode, DeriveCompactAs)] +#[derive( + Debug, + PartialEq, + Eq, + Clone, + Copy, + DeriveEncode, + DeriveDecode, + DeriveDecodeWithMemTracking, + DeriveCompactAs, +)] struct SSkip { #[codec(skip)] s1: u32, @@ -19,45 +30,75 @@ struct SSkip { s2: u32, } -#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode)] +#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] struct Sc { #[codec(compact)] x: u32, } -#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode)] +#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] struct Sh { #[codec(encoded_as = "::Type")] x: T, } #[cfg_attr(feature = "std", derive(Serialize, Deserialize))] -#[derive(Debug, PartialEq, Eq, Clone, Copy, DeriveEncode, DeriveDecode, DeriveCompactAs)] +#[derive( + Debug, + PartialEq, + Eq, + Clone, + Copy, + DeriveEncode, + DeriveDecode, + DeriveDecodeWithMemTracking, + DeriveCompactAs, +)] struct U(u32); #[cfg_attr(feature = "std", derive(Serialize, Deserialize))] -#[derive(Debug, PartialEq, Eq, Clone, Copy, DeriveEncode, DeriveDecode, DeriveCompactAs)] +#[derive( + Debug, + PartialEq, + Eq, + Clone, + Copy, + DeriveEncode, + DeriveDecode, + DeriveDecodeWithMemTracking, + DeriveCompactAs, +)] struct U2 { a: u64, } #[cfg_attr(feature = "std", derive(Serialize, Deserialize))] -#[derive(Debug, PartialEq, Eq, Clone, Copy, DeriveEncode, DeriveDecode, DeriveCompactAs)] +#[derive( + Debug, + PartialEq, + Eq, + Clone, + Copy, + DeriveEncode, + DeriveDecode, + DeriveDecodeWithMemTracking, + DeriveCompactAs, +)] struct USkip(#[codec(skip)] u32, u32, #[codec(skip)] u32); -#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode)] +#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] struct Uc(#[codec(compact)] u32); -#[derive(Debug, PartialEq, Clone, DeriveEncode, DeriveDecode)] +#[derive(Debug, PartialEq, Clone, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] struct Ucas(#[codec(compact)] U); -#[derive(Debug, PartialEq, Clone, DeriveEncode, DeriveDecode)] +#[derive(Debug, PartialEq, Clone, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] struct USkipcas(#[codec(compact)] USkip); -#[derive(Debug, PartialEq, Clone, DeriveEncode, DeriveDecode)] +#[derive(Debug, PartialEq, Clone, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] struct SSkipcas(#[codec(compact)] SSkip); -#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode)] +#[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] struct Uh(#[codec(encoded_as = "::Type")] T); #[test] diff --git a/tests/skip.rs b/tests/skip.rs index 20f37311..16ccbce9 100644 --- a/tests/skip.rs +++ b/tests/skip.rs @@ -1,5 +1,8 @@ use parity_scale_codec::{Decode, Encode}; -use parity_scale_codec_derive::{Decode as DeriveDecode, Encode as DeriveEncode}; +use parity_scale_codec_derive::{ + Decode as DeriveDecode, DecodeWithMemTracking as DeriveDecodeWithMemTracking, + Encode as DeriveEncode, +}; #[test] fn enum_struct_test() { @@ -9,8 +12,11 @@ fn enum_struct_test() { #[derive(PartialEq, Debug)] struct UncodecUndefaultType; - use parity_scale_codec_derive::{Decode as DeriveDecode, Encode as DeriveEncode}; - #[derive(PartialEq, Debug, DeriveEncode, DeriveDecode)] + use parity_scale_codec_derive::{ + Decode as DeriveDecode, DecodeWithMemTracking as DeriveDecodeWithMemTracking, + Encode as DeriveEncode, + }; + #[derive(PartialEq, Debug, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] enum Enum { #[codec(skip)] A(S), @@ 
-22,14 +28,14 @@ fn enum_struct_test() { C(#[codec(skip)] T, u32), } - #[derive(PartialEq, Debug, DeriveEncode, DeriveDecode)] + #[derive(PartialEq, Debug, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] struct StructNamed { #[codec(skip)] a: T, b: u32, } - #[derive(PartialEq, Debug, DeriveEncode, DeriveDecode)] + #[derive(PartialEq, Debug, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] struct StructUnnamed(#[codec(skip)] T, u32); let ea: Enum = Enum::A(UncodecUndefaultType); @@ -56,7 +62,7 @@ fn skip_enum_struct_inner_variant() { // Make sure the skipping does not generates a warning. #![deny(warnings)] - #[derive(DeriveEncode, DeriveDecode)] + #[derive(DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] enum Enum { Data { some_named: u32, diff --git a/tests/type_inference.rs b/tests/type_inference.rs index dc91d7f5..9ef665d6 100644 --- a/tests/type_inference.rs +++ b/tests/type_inference.rs @@ -15,17 +15,19 @@ //! Test for type inference issue in decode. use parity_scale_codec::Decode; -use parity_scale_codec_derive::Decode as DeriveDecode; +use parity_scale_codec_derive::{ + Decode as DeriveDecode, DecodeWithMemTracking as DeriveDecodeWithMemTracking, +}; pub trait Trait { type Value; type AccountId: Decode; } -#[derive(DeriveDecode)] +#[derive(DeriveDecode, DeriveDecodeWithMemTracking)] pub enum A { _C((T::AccountId, T::AccountId), Vec<(T::Value, T::Value)>), } -#[derive(DeriveDecode)] +#[derive(DeriveDecode, DeriveDecodeWithMemTracking)] pub struct B((T::AccountId, T::AccountId), #[allow(dead_code)] Vec<(T::Value, T::Value)>); From 63432ad1cb688b3d123c593f823d1df8d51f9ec6 Mon Sep 17 00:00:00 2001 From: Serban Iorga Date: Wed, 25 Sep 2024 21:01:37 +0300 Subject: [PATCH 12/24] Collections mem tracking adjustments --- src/codec.rs | 108 +++++++++++++++++++++++++++++++++--------- src/lib.rs | 2 +- src/mem_tracking.rs | 19 ++++++++ tests/mem_tracking.rs | 19 +++++--- 4 files changed, 117 insertions(+), 31 deletions(-) diff --git a/src/codec.rs b/src/codec.rs index cddcd843..a204df99 100644 --- a/src/codec.rs +++ b/src/codec.rs @@ -1231,7 +1231,7 @@ impl Decode for Vec { impl DecodeWithMemTracking for Vec {} -macro_rules! impl_codec_through_iterator { +macro_rules! impl_encode_for_collection { ($( $type:ident { $( $generics:ident $( : $decode_additional:ident )? ),* } @@ -1240,7 +1240,7 @@ macro_rules! impl_codec_through_iterator { )*) => {$( impl<$( $generics: Encode ),*> Encode for $type<$( $generics, )*> { fn size_hint(&self) -> usize { - mem::size_of::() $( + mem::size_of::<$generics>() * self.len() )* + mem::size_of::() + mem::size_of::<($($generics,)*)>().saturating_mul(self.len()) } fn encode_to(&self, dest: &mut W) { @@ -1252,26 +1252,6 @@ macro_rules! impl_codec_through_iterator { } } - impl<$( $generics: Decode $( + $decode_additional )? 
),*> Decode - for $type<$( $generics, )*> - { - fn decode(input: &mut I) -> Result { - >::decode(input).and_then(move |Compact(len)| { - input.descend_ref()?; - let result = Result::from_iter((0..len).map(|_| { - input.on_before_alloc_mem(0usize$(.saturating_add(mem::size_of::<$generics>()))*)?; - Decode::decode(input) - })); - input.ascend_ref(); - result - }) - } - } - - impl<$( $generics: DecodeWithMemTracking ),*> DecodeWithMemTracking - for $type<$( $generics, )*> - where $type<$( $generics, )*>: Decode {} - impl<$( $impl_like_generics )*> EncodeLike<$type<$( $type_like_generics ),*>> for $type<$( $generics ),*> {} impl<$( $impl_like_generics )*> EncodeLike<&[( $( $type_like_generics, )* )]> @@ -1281,17 +1261,99 @@ macro_rules! impl_codec_through_iterator { )*} } -impl_codec_through_iterator! { +// Constants from rust's source: +// https://doc.rust-lang.org/src/alloc/collections/btree/node.rs.html#43-45 +const BTREE_B: usize = 6; +const BTREE_CAPACITY: usize = 2 * BTREE_B - 1; +const BTREE_MIN_LEN_AFTER_SPLIT: usize = BTREE_B - 1; + +/// Estimate the mem size of a btree. +fn mem_size_of_btree(len: u32) -> usize { + // We try to estimate the size of the `InternalNode` struct from: + // https://doc.rust-lang.org/src/alloc/collections/btree/node.rs.html#97 + // A btree `LeafNode` has 2*B - 1 (K,V) pairs and (usize, u16, u16) overhead. + // An `InternalNode` additionally has 2*B `usize` overhead. + let node_size = mem::size_of::<(usize, u16, u16, [T; BTREE_CAPACITY], [usize; 2 * BTREE_B])>(); + // A node can contain between B - 1 and 2*B - 1 elements, so we assume it has the midpoint. + let num_nodes = (len as usize).saturating_div((BTREE_CAPACITY + BTREE_MIN_LEN_AFTER_SPLIT) / 2); + core::cmp::max(num_nodes, 1).saturating_mul(node_size) +} + +impl_encode_for_collection! { BTreeMap { K: Ord, V } { LikeK, LikeV} { K: EncodeLike, LikeK: Encode, V: EncodeLike, LikeV: Encode } +} + +impl Decode for BTreeMap { + fn decode(input: &mut I) -> Result { + >::decode(input).and_then(move |Compact(len)| { + input.descend_ref()?; + input.on_before_alloc_mem(mem_size_of_btree::<(K, V)>(len))?; + let result = Result::from_iter((0..len).map(|_| Decode::decode(input))); + input.ascend_ref(); + result + }) + } +} + +impl DecodeWithMemTracking for BTreeMap where + BTreeMap: Decode +{ +} + +impl_encode_for_collection! { BTreeSet { T: Ord } { LikeT } { T: EncodeLike, LikeT: Encode } +} + +impl Decode for BTreeSet { + fn decode(input: &mut I) -> Result { + >::decode(input).and_then(move |Compact(len)| { + input.descend_ref()?; + input.on_before_alloc_mem(mem_size_of_btree::(len))?; + let result = Result::from_iter((0..len).map(|_| Decode::decode(input))); + input.ascend_ref(); + result + }) + } +} +impl DecodeWithMemTracking for BTreeSet where BTreeSet: Decode {} + +impl_encode_for_collection! { LinkedList { T } { LikeT } { T: EncodeLike, LikeT: Encode } +} + +impl Decode for LinkedList { + fn decode(input: &mut I) -> Result { + >::decode(input).and_then(move |Compact(len)| { + input.descend_ref()?; + let result = Result::from_iter((0..len).map(|_| { + // We account for the size of the `prev` and `next` pointers of each list node, + // plus the decoded element. + input.on_before_alloc_mem(mem::size_of::<(usize, usize, T)>())?; + Decode::decode(input) + })); + input.ascend_ref(); + result + }) + } +} + +impl DecodeWithMemTracking for LinkedList where LinkedList: Decode {} + +impl_encode_for_collection! 
{ BinaryHeap { T: Ord } { LikeT } { T: EncodeLike, LikeT: Encode } } +impl Decode for BinaryHeap { + fn decode(input: &mut I) -> Result { + Ok(Vec::decode(input)?.into()) + } +} +impl DecodeWithMemTracking for BinaryHeap where BinaryHeap: Decode {} + impl EncodeLike for VecDeque {} impl, U: Encode> EncodeLike<&[U]> for VecDeque {} impl, U: Encode> EncodeLike> for &[T] {} diff --git a/src/lib.rs b/src/lib.rs index d29673dd..8d6ab4c9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -75,7 +75,7 @@ pub use self::{ error::Error, joiner::Joiner, keyedvec::KeyedVec, - mem_tracking::{DecodeWithMemTracking, MemTrackingInput}, + mem_tracking::{DecodeWithMemLimit, DecodeWithMemTracking, MemTrackingInput}, }; #[cfg(feature = "max-encoded-len")] pub use const_encoded_len::ConstEncodedLen; diff --git a/src/mem_tracking.rs b/src/mem_tracking.rs index 182a1b87..c6043fb3 100644 --- a/src/mem_tracking.rs +++ b/src/mem_tracking.rs @@ -79,3 +79,22 @@ impl<'a, I: Input> Input for MemTrackingInput<'a, I> { Ok(()) } } + +/// Extension trait to [`Decode`] for decoding with a maximum memory limit. +pub trait DecodeWithMemLimit: DecodeWithMemTracking { + /// Decode `Self` with the given maximum memory limit and advance `input` by the number of + /// bytes consumed. + /// + /// If `mem_limit` is hit, an error is returned. + fn decode_with_mem_limit(input: &mut I, mem_limit: usize) -> Result; +} + +impl DecodeWithMemLimit for T +where + T: DecodeWithMemTracking, +{ + fn decode_with_mem_limit(input: &mut I, mem_limit: usize) -> Result { + let mut input = MemTrackingInput::new(input, mem_limit); + T::decode(&mut input) + } +} diff --git a/tests/mem_tracking.rs b/tests/mem_tracking.rs index 3361ae71..574650e7 100644 --- a/tests/mem_tracking.rs +++ b/tests/mem_tracking.rs @@ -19,7 +19,7 @@ use parity_scale_codec::{ collections::{BTreeMap, BTreeSet, LinkedList, VecDeque}, rc::Rc, }, - DecodeWithMemTracking, Encode, Error, MemTrackingInput, + DecodeWithMemLimit, DecodeWithMemTracking, Encode, Error, MemTrackingInput, }; use parity_scale_codec_derive::{ Decode as DeriveDecode, DecodeWithMemTracking as DeriveDecodeWithMemTracking, @@ -45,19 +45,24 @@ struct ComplexStruct { fn decode_object(obj: T, mem_limit: usize, expected_used_mem: usize) -> Result where - T: Encode + DecodeWithMemTracking + PartialEq + Debug, + T: Encode + DecodeWithMemTracking + DecodeWithMemLimit + PartialEq + Debug, { let encoded_bytes = obj.encode(); + + let decoded_obj = T::decode_with_mem_limit(&mut &encoded_bytes[..], mem_limit)?; + assert_eq!(&decoded_obj, &obj); + let raw_input = &mut &encoded_bytes[..]; let mut input = MemTrackingInput::new(raw_input, mem_limit); let decoded_obj = T::decode(&mut input)?; assert_eq!(&decoded_obj, &obj); assert_eq!(input.used_mem(), expected_used_mem); - if expected_used_mem > 0 { + let raw_input = &mut &encoded_bytes[..]; let mut input = MemTrackingInput::new(raw_input, expected_used_mem); assert!(T::decode(&mut input).is_err()); } + Ok(decoded_obj) } @@ -86,18 +91,18 @@ fn decode_simple_objects_works() { #[cfg(feature = "bytes")] assert!(decode_object(bytes::Bytes::from(&ARRAY[..]), usize::MAX, 1000).is_ok()); // Complex Collections - assert!(decode_object(BTreeMap::::from([(1, 2), (2, 3)]), usize::MAX, 4).is_ok()); + assert!(decode_object(BTreeMap::::from([(1, 2), (2, 3)]), usize::MAX, 136).is_ok()); assert!(decode_object( BTreeMap::from([ ("key1".to_string(), "value1".to_string()), ("key2".to_string(), "value2".to_string()), ]), usize::MAX, - 116, + 660, ) .is_ok()); - assert!(decode_object(BTreeSet::::from([1, 2, 
3, 4, 5]), usize::MAX, 5).is_ok()); - assert!(decode_object(LinkedList::::from([1, 2, 3, 4, 5]), usize::MAX, 5).is_ok()); + assert!(decode_object(BTreeSet::::from([1, 2, 3, 4, 5]), usize::MAX, 120).is_ok()); + assert!(decode_object(LinkedList::::from([1, 2, 3, 4, 5]), usize::MAX, 120).is_ok()); } #[test] From d0659d9eef6f88fbe8a79b3d2620f2b7c14341dc Mon Sep 17 00:00:00 2001 From: Serban Iorga Date: Thu, 26 Sep 2024 14:46:00 +0300 Subject: [PATCH 13/24] Fixes --- derive/src/utils.rs | 2 +- src/codec.rs | 4 ++-- tests/max_encoded_len_ui/crate_str.stderr | 2 +- tests/max_encoded_len_ui/incomplete_attr.stderr | 2 +- tests/max_encoded_len_ui/missing_crate_specifier.stderr | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/derive/src/utils.rs b/derive/src/utils.rs index fbf7e619..79a1e1d3 100644 --- a/derive/src/utils.rs +++ b/derive/src/utils.rs @@ -422,7 +422,7 @@ fn check_top_attribute(attr: &Attribute) -> syn::Result<()> { let top_error = "Invalid attribute: only `#[codec(dumb_trait_bound)]`, \ `#[codec(crate = path::to::crate)]`, `#[codec(encode_bound(T: Encode))]`, \ `#[codec(decode_bound(T: Decode))]`, \ - `#[codec(decode_bound_with_mem_tracking_bound(T: Decode))]` or \ + `#[codec(decode_bound_with_mem_tracking_bound(T: DecodeWithMemTracking))]` or \ `#[codec(mel_bound(T: MaxEncodedLen))]` are accepted as top attribute"; if attr.path.is_ident("codec") && attr.parse_args::>().is_err() && diff --git a/src/codec.rs b/src/codec.rs index a204df99..545a5e29 100644 --- a/src/codec.rs +++ b/src/codec.rs @@ -90,8 +90,8 @@ pub trait Input { /// /// The aim is to get a reasonable approximation of memory usage, especially with variably /// sized types like `Vec`s. Depending on the structure, it is acceptable to be off by a bit. - /// For example for structures like `BTreeMap` we don't track the memory used by the internal - /// tree nodes etc. Also we don't take alignment or memory layouts into account. + /// In some cases we might not track the memory used by internal sub-structures, and + /// also we don't take alignment or memory layouts into account. /// But we should always track the memory used by the decoded data inside the type. 
fn on_before_alloc_mem(&mut self, _size: usize) -> Result<(), Error> { Ok(()) diff --git a/tests/max_encoded_len_ui/crate_str.stderr b/tests/max_encoded_len_ui/crate_str.stderr index 0583b229..a90fdb79 100644 --- a/tests/max_encoded_len_ui/crate_str.stderr +++ b/tests/max_encoded_len_ui/crate_str.stderr @@ -1,4 +1,4 @@ -error: Invalid attribute: only `#[codec(dumb_trait_bound)]`, `#[codec(crate = path::to::crate)]`, `#[codec(encode_bound(T: Encode))]`, `#[codec(decode_bound(T: Decode))]`, or `#[codec(mel_bound(T: MaxEncodedLen))]` are accepted as top attribute +error: Invalid attribute: only `#[codec(dumb_trait_bound)]`, `#[codec(crate = path::to::crate)]`, `#[codec(encode_bound(T: Encode))]`, `#[codec(decode_bound(T: Decode))]`, `#[codec(decode_bound_with_mem_tracking_bound(T: DecodeWithMemTracking))]` or `#[codec(mel_bound(T: MaxEncodedLen))]` are accepted as top attribute --> tests/max_encoded_len_ui/crate_str.rs:4:9 | 4 | #[codec(crate = "parity_scale_codec")] diff --git a/tests/max_encoded_len_ui/incomplete_attr.stderr b/tests/max_encoded_len_ui/incomplete_attr.stderr index a7649942..2f17bdbf 100644 --- a/tests/max_encoded_len_ui/incomplete_attr.stderr +++ b/tests/max_encoded_len_ui/incomplete_attr.stderr @@ -1,4 +1,4 @@ -error: Invalid attribute: only `#[codec(dumb_trait_bound)]`, `#[codec(crate = path::to::crate)]`, `#[codec(encode_bound(T: Encode))]`, `#[codec(decode_bound(T: Decode))]`, or `#[codec(mel_bound(T: MaxEncodedLen))]` are accepted as top attribute +error: Invalid attribute: only `#[codec(dumb_trait_bound)]`, `#[codec(crate = path::to::crate)]`, `#[codec(encode_bound(T: Encode))]`, `#[codec(decode_bound(T: Decode))]`, `#[codec(decode_bound_with_mem_tracking_bound(T: DecodeWithMemTracking))]` or `#[codec(mel_bound(T: MaxEncodedLen))]` are accepted as top attribute --> tests/max_encoded_len_ui/incomplete_attr.rs:4:9 | 4 | #[codec(crate)] diff --git a/tests/max_encoded_len_ui/missing_crate_specifier.stderr b/tests/max_encoded_len_ui/missing_crate_specifier.stderr index b575f512..ab5a74aa 100644 --- a/tests/max_encoded_len_ui/missing_crate_specifier.stderr +++ b/tests/max_encoded_len_ui/missing_crate_specifier.stderr @@ -1,4 +1,4 @@ -error: Invalid attribute: only `#[codec(dumb_trait_bound)]`, `#[codec(crate = path::to::crate)]`, `#[codec(encode_bound(T: Encode))]`, `#[codec(decode_bound(T: Decode))]`, or `#[codec(mel_bound(T: MaxEncodedLen))]` are accepted as top attribute +error: Invalid attribute: only `#[codec(dumb_trait_bound)]`, `#[codec(crate = path::to::crate)]`, `#[codec(encode_bound(T: Encode))]`, `#[codec(decode_bound(T: Decode))]`, `#[codec(decode_bound_with_mem_tracking_bound(T: DecodeWithMemTracking))]` or `#[codec(mel_bound(T: MaxEncodedLen))]` are accepted as top attribute --> tests/max_encoded_len_ui/missing_crate_specifier.rs:4:9 | 4 | #[codec(parity_scale_codec)] From 37f7b71ac1610b49d3c64a1191794b62dcec9c21 Mon Sep 17 00:00:00 2001 From: Serban Iorga Date: Mon, 30 Sep 2024 11:11:26 +0300 Subject: [PATCH 14/24] Code review comments --- derive/src/decode.rs | 10 +++++++--- derive/src/lib.rs | 13 ++++++------- derive/src/utils.rs | 7 ++++--- src/codec.rs | 12 ++++++------ src/mem_tracking.rs | 5 +---- 5 files changed, 24 insertions(+), 23 deletions(-) diff --git a/derive/src/decode.rs b/derive/src/decode.rs index 2ff28c4c..6b2b0a42 100644 --- a/derive/src/decode.rs +++ b/derive/src/decode.rs @@ -286,7 +286,7 @@ fn create_instance( } } -pub fn quote_decode_with_mem_tracking(data: &Data, crate_path: &syn::Path) -> TokenStream { +pub fn 
quote_decode_with_mem_tracking_checks(data: &Data, crate_path: &syn::Path) -> TokenStream { let fields: Box> = match data { Data::Struct(data) => Box::new(data.fields.iter()), Data::Enum(ref data) => { @@ -319,10 +319,14 @@ pub fn quote_decode_with_mem_tracking(data: &Data, crate_path: &syn::Path) -> To } else { field.ty.to_token_stream() }; - Some(quote! {<#field_type as #crate_path::DecodeWithMemTracking>::__is_implemented();}) + Some(quote! {#field_type}) }); quote! { - #(#processed_fields)* + fn check_field() {} + + #( + check_field::<#processed_fields>(); + )* } } diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 6a0644aa..2f9ea7d9 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -273,19 +273,18 @@ pub fn decode_with_mem_tracking_derive(input: proc_macro::TokenStream) -> proc_m let name = &input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); - let decode_with_mem_tracking_body = - decode::quote_decode_with_mem_tracking(&input.data, &crate_path); + let decode_with_mem_tracking_checks = + decode::quote_decode_with_mem_tracking_checks(&input.data, &crate_path); let impl_block = quote! { + fn check_struct #impl_generics() #where_clause { + #decode_with_mem_tracking_checks + } + #[automatically_derived] impl #impl_generics #crate_path::DecodeWithMemTracking for #name #ty_generics #where_clause { - fn __is_implemented() { - #decode_with_mem_tracking_body - } } }; - // panic!("mem tracking impl_block: {}", impl_block.to_string()); - wrap_with_dummy_const(input, impl_block) } diff --git a/derive/src/utils.rs b/derive/src/utils.rs index 79a1e1d3..26a02840 100644 --- a/derive/src/utils.rs +++ b/derive/src/utils.rs @@ -469,14 +469,15 @@ pub fn is_transparent(attrs: &[syn::Attribute]) -> bool { } pub fn try_get_variants(data: &DataEnum) -> Result, syn::Error> { - let data_variants = || data.variants.iter().filter(|variant| !should_skip(&variant.attrs)); + let data_variants: Vec<_> = + data.variants.iter().filter(|variant| !should_skip(&variant.attrs)).collect(); - if data_variants().count() > 256 { + if data_variants.len() > 256 { return Err(syn::Error::new( data.variants.span(), "Currently only enums with at most 256 variants are encodable/decodable.", )) } - Ok(data_variants().collect()) + Ok(data_variants) } diff --git a/src/codec.rs b/src/codec.rs index 545a5e29..7035e7ef 100644 --- a/src/codec.rs +++ b/src/codec.rs @@ -1328,12 +1328,12 @@ impl Decode for LinkedList { fn decode(input: &mut I) -> Result { >::decode(input).and_then(move |Compact(len)| { input.descend_ref()?; - let result = Result::from_iter((0..len).map(|_| { - // We account for the size of the `prev` and `next` pointers of each list node, - // plus the decoded element. - input.on_before_alloc_mem(mem::size_of::<(usize, usize, T)>())?; - Decode::decode(input) - })); + // We account for the size of the `prev` and `next` pointers of each list node, + // plus the decoded element. + input.on_before_alloc_mem( + (len as usize).saturating_mul(size_of::<(usize, usize, T)>()), + )?; + let result = Result::from_iter((0..len).map(|_| Decode::decode(input))); input.ascend_ref(); result }) diff --git a/src/mem_tracking.rs b/src/mem_tracking.rs index c6043fb3..25ec3372 100644 --- a/src/mem_tracking.rs +++ b/src/mem_tracking.rs @@ -18,10 +18,7 @@ use impl_trait_for_tuples::impl_for_tuples; /// Marker trait used for identifying types that call the [`Input::on_before_alloc_mem`] hook /// while decoding. 
-pub trait DecodeWithMemTracking: Decode { - /// Internal method used to derive `DecodeWithMemTracking`. - fn __is_implemented() {} -} +pub trait DecodeWithMemTracking: Decode {} const DECODE_OOM_MSG: &str = "Heap memory limit exceeded while decoding"; From de1018a54aee49a25aa3a97f00e1167b0b76fea9 Mon Sep 17 00:00:00 2001 From: Serban Iorga Date: Mon, 30 Sep 2024 15:33:09 +0300 Subject: [PATCH 15/24] Test mem_size_of_btree() --- .github/workflows/ci.yml | 33 +++++++----- Cargo.toml | 1 + src/btree_utils.rs | 109 +++++++++++++++++++++++++++++++++++++++ src/codec.rs | 22 +------- src/lib.rs | 3 ++ tests/mem_tracking.rs | 6 +-- 6 files changed, 137 insertions(+), 37 deletions(-) create mode 100644 src/btree_utils.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4f762bfe..41b3a228 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,10 +22,10 @@ jobs: - id: set_image run: echo "IMAGE=${{ env.IMAGE }}" >> $GITHUB_OUTPUT -# Checks + # Checks fmt: runs-on: ubuntu-latest - needs: [set-image] + needs: [ set-image ] container: ${{ needs.set-image.outputs.IMAGE }} steps: - name: Checkout code @@ -39,7 +39,7 @@ jobs: clippy: runs-on: ubuntu-latest - needs: [set-image] + needs: [ set-image ] container: ${{ needs.set-image.outputs.IMAGE }} steps: - name: Checkout code/. @@ -64,7 +64,7 @@ jobs: checks: runs-on: ubuntu-latest - needs: [set-image] + needs: [ set-image ] container: ${{ needs.set-image.outputs.IMAGE }} steps: - name: Checkout code @@ -99,10 +99,10 @@ jobs: export RUSTFLAGS='-Cdebug-assertions=y -Dwarnings' time cargo +stable check --verbose --features max-encoded-len -# Tests + # Tests tests: runs-on: ubuntu-latest - needs: [set-image] + needs: [ set-image ] container: ${{ needs.set-image.outputs.IMAGE }} steps: - name: Checkout code @@ -139,13 +139,18 @@ jobs: export RUSTFLAGS='-Cdebug-assertions=y -Dwarnings' time cargo +stable test --verbose --features max-encoded-len,std --no-default-features -# Benches + - name: Run Unstable Rust Tests + run: | + export RUSTFLAGS='-Cdebug-assertions=y -Dwarnings' + time cargo +nightly test --verbose --features unstable-tests --lib btree_utils + + # Benches bench-rust-nightly: runs-on: ubuntu-latest - needs: [set-image] + needs: [ set-image ] strategy: matrix: - feature: [bit-vec,bytes,generic-array,derive,max-encoded-len] + feature: [ bit-vec,bytes,generic-array,derive,max-encoded-len ] container: ${{ needs.set-image.outputs.IMAGE }} steps: - name: Checkout code @@ -163,10 +168,10 @@ jobs: miri: runs-on: ubuntu-latest - needs: [set-image] + needs: [ set-image ] strategy: matrix: - feature: [bit-vec,bytes,generic-array,arbitrary] + feature: [ bit-vec,bytes,generic-array,arbitrary ] container: ${{ needs.set-image.outputs.IMAGE }} steps: - name: Checkout code @@ -184,11 +189,11 @@ jobs: export MIRIFLAGS='-Zmiri-disable-isolation' time cargo +nightly miri test --features ${{ matrix.feature }} --release -# Build + # Build build-linux-ubuntu-amd64: runs-on: ubuntu-latest - needs: [set-image, clippy, checks, tests] + needs: [ set-image, clippy, checks, tests ] container: ${{ needs.set-image.outputs.IMAGE }} steps: - name: Checkout code @@ -204,7 +209,7 @@ jobs: publish-dry-run: runs-on: ubuntu-latest - needs: [set-image, build-linux-ubuntu-amd64] + needs: [ set-image, build-linux-ubuntu-amd64 ] container: ${{ needs.set-image.outputs.IMAGE }} steps: - name: Checkout code diff --git a/Cargo.toml b/Cargo.toml index 061d30a4..ae8287d1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,6 +42,7 @@ derive = 
["parity-scale-codec-derive"] std = ["serde/std", "bitvec?/std", "byte-slice-cast/std", "chain-error"] bit-vec = ["bitvec"] fuzz = ["std", "arbitrary"] +unstable-tests = [] # Enables the new `MaxEncodedLen` trait. # NOTE: This is still considered experimental and is exempt from the usual diff --git a/src/btree_utils.rs b/src/btree_utils.rs new file mode 100644 index 00000000..4acb98a1 --- /dev/null +++ b/src/btree_utils.rs @@ -0,0 +1,109 @@ +// Copyright 2017-2018 Parity Technologies +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use core::mem::size_of; + +// Constants from rust's source: +// https://doc.rust-lang.org/src/alloc/collections/btree/node.rs.html#43-45 +const B: usize = 6; +const CAPACITY: usize = 2 * B - 1; +const MIN_LEN_AFTER_SPLIT: usize = B - 1; + +/// Estimate the mem size of a btree. +pub fn mem_size_of_btree(len: u32) -> usize { + if len == 0 { + return 0; + } + + // We try to estimate the size of the `InternalNode` struct from: + // https://doc.rust-lang.org/src/alloc/collections/btree/node.rs.html#97 + // A btree `LeafNode` has 2*B - 1 (K,V) pairs and (usize, u16, u16) overhead. + let leaf_node_size = size_of::<(usize, u16, u16, [T; CAPACITY])>(); + // An `InternalNode` additionally has 2*B `usize` overhead. + let internal_node_size = leaf_node_size + size_of::<[usize; 2 * B]>(); + // A node can contain between B - 1 and 2*B - 1 elements. We assume 2/3 occupancy. + let num_nodes = (len as usize).saturating_div((CAPACITY + MIN_LEN_AFTER_SPLIT) * 2 / 3); + + // If the tree has only one node, it's a leaf node. + if num_nodes == 0 { + return leaf_node_size; + } + num_nodes.saturating_mul(internal_node_size) +} + +#[cfg(all(test, feature = "unstable-tests"))] +mod test { + use super::*; + use crate::alloc::{ + collections::{BTreeMap, BTreeSet}, + sync::{Arc, Mutex}, + }; + use core::{ + alloc::{AllocError, Allocator, Layout}, + ptr::NonNull, + }; + + #[cfg(feature = "std")] + #[derive(Clone)] + struct MockAllocator { + total: Arc>, + } + + unsafe impl Allocator for MockAllocator { + fn allocate(&self, layout: Layout) -> Result, AllocError> { + let ptr = std::alloc::System.allocate(layout); + if ptr.is_ok() { + *self.total.lock().unwrap() += layout.size(); + } + ptr + } + + unsafe fn deallocate(&self, ptr: NonNull, layout: Layout) { + *self.total.lock().unwrap() -= layout.size(); + std::alloc::System.deallocate(ptr, layout) + } + } + + fn check_btree_size(expected_size: usize, actual_size: Arc>) { + /// Check that the margin of error is at most 25%. 
+ assert!(*actual_size.lock().unwrap() as f64 * 0.75 <= expected_size as f64); + assert!(*actual_size.lock().unwrap() as f64 * 1.25 >= expected_size as f64); + } + + #[test] + fn mem_size_of_btree_works() { + let map_allocator = MockAllocator { total: Arc::new(Mutex::new(0)) }; + let map_actual_size = map_allocator.total.clone(); + let mut map = BTreeMap::::new_in(map_allocator); + + let set_allocator = MockAllocator { total: Arc::new(Mutex::new(0)) }; + let set_actual_size = set_allocator.total.clone(); + let mut set = BTreeSet::::new_in(set_allocator); + + for i in 0..1000000 { + map.insert(i, 0); + set.insert(i as u128); + + /// For small number of elements, the differences between the expected size and + /// the actual size can be higher. + if i > 100 { + let map_expected_size = mem_size_of_btree::<(u32, u32)>(map.len() as u32); + check_btree_size(map_expected_size, map_actual_size.clone()); + + let set_expected_size = mem_size_of_btree::(set.len() as u32); + check_btree_size(set_expected_size, set_actual_size.clone()); + } + } + } +} diff --git a/src/codec.rs b/src/codec.rs index 7035e7ef..252497be 100644 --- a/src/codec.rs +++ b/src/codec.rs @@ -1261,24 +1261,6 @@ macro_rules! impl_encode_for_collection { )*} } -// Constants from rust's source: -// https://doc.rust-lang.org/src/alloc/collections/btree/node.rs.html#43-45 -const BTREE_B: usize = 6; -const BTREE_CAPACITY: usize = 2 * BTREE_B - 1; -const BTREE_MIN_LEN_AFTER_SPLIT: usize = BTREE_B - 1; - -/// Estimate the mem size of a btree. -fn mem_size_of_btree(len: u32) -> usize { - // We try to estimate the size of the `InternalNode` struct from: - // https://doc.rust-lang.org/src/alloc/collections/btree/node.rs.html#97 - // A btree `LeafNode` has 2*B - 1 (K,V) pairs and (usize, u16, u16) overhead. - // An `InternalNode` additionally has 2*B `usize` overhead. - let node_size = mem::size_of::<(usize, u16, u16, [T; BTREE_CAPACITY], [usize; 2 * BTREE_B])>(); - // A node can contain between B - 1 and 2*B - 1 elements, so we assume it has the midpoint. - let num_nodes = (len as usize).saturating_div((BTREE_CAPACITY + BTREE_MIN_LEN_AFTER_SPLIT) / 2); - core::cmp::max(num_nodes, 1).saturating_mul(node_size) -} - impl_encode_for_collection! { BTreeMap { K: Ord, V } { LikeK, LikeV} { K: EncodeLike, LikeK: Encode, V: EncodeLike, LikeV: Encode } @@ -1288,7 +1270,7 @@ impl Decode for BTreeMap { fn decode(input: &mut I) -> Result { >::decode(input).and_then(move |Compact(len)| { input.descend_ref()?; - input.on_before_alloc_mem(mem_size_of_btree::<(K, V)>(len))?; + input.on_before_alloc_mem(super::btree_utils::mem_size_of_btree::<(K, V)>(len))?; let result = Result::from_iter((0..len).map(|_| Decode::decode(input))); input.ascend_ref(); result @@ -1310,7 +1292,7 @@ impl Decode for BTreeSet { fn decode(input: &mut I) -> Result { >::decode(input).and_then(move |Compact(len)| { input.descend_ref()?; - input.on_before_alloc_mem(mem_size_of_btree::(len))?; + input.on_before_alloc_mem(super::btree_utils::mem_size_of_btree::(len))?; let result = Result::from_iter((0..len).map(|_| Decode::decode(input))); input.ascend_ref(); result diff --git a/src/lib.rs b/src/lib.rs index 8d6ab4c9..015f04e5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#![cfg_attr(all(test, feature = "unstable-tests"), feature(allocator_api))] +#![cfg_attr(all(test, feature = "unstable-tests"), feature(btreemap_alloc))] #![doc = include_str!("../README.md")] #![warn(missing_docs)] #![cfg_attr(not(feature = "std"), no_std)] @@ -41,6 +43,7 @@ pub mod alloc { #[cfg(feature = "bit-vec")] mod bit_vec; +mod btree_utils; mod codec; mod compact; #[cfg(feature = "max-encoded-len")] diff --git a/tests/mem_tracking.rs b/tests/mem_tracking.rs index 574650e7..721d3b51 100644 --- a/tests/mem_tracking.rs +++ b/tests/mem_tracking.rs @@ -91,17 +91,17 @@ fn decode_simple_objects_works() { #[cfg(feature = "bytes")] assert!(decode_object(bytes::Bytes::from(&ARRAY[..]), usize::MAX, 1000).is_ok()); // Complex Collections - assert!(decode_object(BTreeMap::::from([(1, 2), (2, 3)]), usize::MAX, 136).is_ok()); + assert!(decode_object(BTreeMap::::from([(1, 2), (2, 3)]), usize::MAX, 40).is_ok()); assert!(decode_object( BTreeMap::from([ ("key1".to_string(), "value1".to_string()), ("key2".to_string(), "value2".to_string()), ]), usize::MAX, - 660, + 564, ) .is_ok()); - assert!(decode_object(BTreeSet::::from([1, 2, 3, 4, 5]), usize::MAX, 120).is_ok()); + assert!(decode_object(BTreeSet::::from([1, 2, 3, 4, 5]), usize::MAX, 24).is_ok()); assert!(decode_object(LinkedList::::from([1, 2, 3, 4, 5]), usize::MAX, 120).is_ok()); } From 3e8fb6125f66ca4309be5b0c740a420af13e06c0 Mon Sep 17 00:00:00 2001 From: Serban Iorga Date: Tue, 1 Oct 2024 10:03:04 +0300 Subject: [PATCH 16/24] Bump minor version --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7101073e..67bd80c6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -575,7 +575,7 @@ checksum = "ceedf44fb00f2d1984b0bc98102627ce622e083e49a5bacdb3e514fa4238e267" [[package]] name = "parity-scale-codec" -version = "3.6.12" +version = "3.7.0" dependencies = [ "arbitrary", "arrayvec", diff --git a/Cargo.toml b/Cargo.toml index ae8287d1..8d43ee00 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "parity-scale-codec" description = "SCALE - Simple Concatenating Aggregated Little Endians" -version = "3.6.12" +version = "3.7.0" authors = ["Parity Technologies "] license = "Apache-2.0" repository = "https://github.com/paritytech/parity-scale-codec" From 9b4fcdc4ce8372b035f9383bd5208d540e59e895 Mon Sep 17 00:00:00 2001 From: Serban Iorga Date: Tue, 1 Oct 2024 10:08:36 +0300 Subject: [PATCH 17/24] Fix --- src/codec.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/codec.rs b/src/codec.rs index 252497be..1dce353a 100644 --- a/src/codec.rs +++ b/src/codec.rs @@ -1312,9 +1312,11 @@ impl Decode for LinkedList { input.descend_ref()?; // We account for the size of the `prev` and `next` pointers of each list node, // plus the decoded element. 
- input.on_before_alloc_mem( - (len as usize).saturating_mul(size_of::<(usize, usize, T)>()), - )?; + input.on_before_alloc_mem((len as usize).saturating_mul(mem::size_of::<( + usize, + usize, + T, + )>()))?; let result = Result::from_iter((0..len).map(|_| Decode::decode(input))); input.ascend_ref(); result From 8577b5c77344e9b2e3ca25b0355cab65d5a97a98 Mon Sep 17 00:00:00 2001 From: Serban Iorga Date: Tue, 1 Oct 2024 10:15:05 +0300 Subject: [PATCH 18/24] Fix rustdoc --- src/btree_utils.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/btree_utils.rs b/src/btree_utils.rs index 4acb98a1..3e537849 100644 --- a/src/btree_utils.rs +++ b/src/btree_utils.rs @@ -76,7 +76,7 @@ mod test { } fn check_btree_size(expected_size: usize, actual_size: Arc>) { - /// Check that the margin of error is at most 25%. + // Check that the margin of error is at most 25%. assert!(*actual_size.lock().unwrap() as f64 * 0.75 <= expected_size as f64); assert!(*actual_size.lock().unwrap() as f64 * 1.25 >= expected_size as f64); } @@ -95,8 +95,8 @@ mod test { map.insert(i, 0); set.insert(i as u128); - /// For small number of elements, the differences between the expected size and - /// the actual size can be higher. + // For small number of elements, the differences between the expected size and + // the actual size can be higher. if i > 100 { let map_expected_size = mem_size_of_btree::<(u32, u32)>(map.len() as u32); check_btree_size(map_expected_size, map_actual_size.clone()); From 9b6fe015ffcd9f97d7474ed3402d6b62c6093702 Mon Sep 17 00:00:00 2001 From: Serban Iorga Date: Tue, 1 Oct 2024 15:24:17 +0300 Subject: [PATCH 19/24] Remove unstable-tests feature --- .github/workflows/ci.yml | 4 ++-- Cargo.lock | 7 +++++++ Cargo.toml | 7 ++++++- build.rs | 6 ++++++ src/btree_utils.rs | 4 ++-- src/lib.rs | 4 ++-- 6 files changed, 25 insertions(+), 7 deletions(-) create mode 100644 build.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 41b3a228..471f61fa 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -139,10 +139,10 @@ jobs: export RUSTFLAGS='-Cdebug-assertions=y -Dwarnings' time cargo +stable test --verbose --features max-encoded-len,std --no-default-features - - name: Run Unstable Rust Tests + - name: Run Nightly Tests run: | export RUSTFLAGS='-Cdebug-assertions=y -Dwarnings' - time cargo +nightly test --verbose --features unstable-tests --lib btree_utils + time cargo +nightly test --verbose --lib btree_utils # Benches bench-rust-nightly: diff --git a/Cargo.lock b/Cargo.lock index 67bd80c6..a9277d86 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -589,6 +589,7 @@ dependencies = [ "paste", "proptest", "quickcheck", + "rustversion", "serde", "serde_derive", "trybuild", @@ -834,6 +835,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "rustversion" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" + [[package]] name = "rusty-fork" version = "0.3.0" diff --git a/Cargo.toml b/Cargo.toml index 8d43ee00..773a1198 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ license = "Apache-2.0" repository = "https://github.com/paritytech/parity-scale-codec" categories = ["encoding"] edition = "2021" +build = "build.rs" rust-version = "1.60.0" [dependencies] @@ -19,6 +20,7 @@ byte-slice-cast = { version = "1.2.2", default-features = false } generic-array = { version = "0.14.7", optional = true } arbitrary = { version = "1.3.2", features = 
["derive"], optional = true } impl-trait-for-tuples = "0.2.2" +rustversion = "1.0.17" [dev-dependencies] criterion = "0.4.0" @@ -28,6 +30,10 @@ quickcheck = "1.0" proptest = "1.5.0" trybuild = "1.0.97" paste = "1" +rustversion = "1" + +[build-dependencies] +rustversion = "1" [[bench]] name = "benches" @@ -42,7 +48,6 @@ derive = ["parity-scale-codec-derive"] std = ["serde/std", "bitvec?/std", "byte-slice-cast/std", "chain-error"] bit-vec = ["bitvec"] fuzz = ["std", "arbitrary"] -unstable-tests = [] # Enables the new `MaxEncodedLen` trait. # NOTE: This is still considered experimental and is exempt from the usual diff --git a/build.rs b/build.rs new file mode 100644 index 00000000..8454031d --- /dev/null +++ b/build.rs @@ -0,0 +1,6 @@ +fn main() { + if rustversion::cfg!(nightly) { + println!("cargo:rustc-check-cfg=cfg(nightly)"); + println!("cargo:rustc-cfg=nightly"); + } +} diff --git a/src/btree_utils.rs b/src/btree_utils.rs index 3e537849..50bfbe63 100644 --- a/src/btree_utils.rs +++ b/src/btree_utils.rs @@ -42,7 +42,8 @@ pub fn mem_size_of_btree(len: u32) -> usize { num_nodes.saturating_mul(internal_node_size) } -#[cfg(all(test, feature = "unstable-tests"))] +#[cfg(test)] +#[rustversion::nightly] mod test { use super::*; use crate::alloc::{ @@ -54,7 +55,6 @@ mod test { ptr::NonNull, }; - #[cfg(feature = "std")] #[derive(Clone)] struct MockAllocator { total: Arc>, diff --git a/src/lib.rs b/src/lib.rs index 015f04e5..478460a2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#![cfg_attr(all(test, feature = "unstable-tests"), feature(allocator_api))] -#![cfg_attr(all(test, feature = "unstable-tests"), feature(btreemap_alloc))] +#![cfg_attr(all(test, nightly), feature(allocator_api))] +#![cfg_attr(all(test, nightly), feature(btreemap_alloc))] #![doc = include_str!("../README.md")] #![warn(missing_docs)] #![cfg_attr(not(feature = "std"), no_std)] From 483c7802764aa6c78d8d04c5708b6eee486863b2 Mon Sep 17 00:00:00 2001 From: Serban Iorga Date: Wed, 2 Oct 2024 16:39:51 +0300 Subject: [PATCH 20/24] Disable miri for mem_size_of_btree_works() test --- src/btree_utils.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/btree_utils.rs b/src/btree_utils.rs index 50bfbe63..e6c49da5 100644 --- a/src/btree_utils.rs +++ b/src/btree_utils.rs @@ -43,6 +43,7 @@ pub fn mem_size_of_btree(len: u32) -> usize { } #[cfg(test)] +#[cfg(not(miri))] #[rustversion::nightly] mod test { use super::*; From 3e359b1ba000bd07aa5fab12ad4f9aa50e9abfb9 Mon Sep 17 00:00:00 2001 From: Serban Iorga Date: Thu, 3 Oct 2024 16:58:55 +0300 Subject: [PATCH 21/24] quote_spanned --- build.rs | 2 +- derive/src/decode.rs | 2 +- tests/decode_with_mem_tracking_ui.rs | 22 +++++++++++++++++++ .../trait_bound_not_satisfied.rs | 11 ++++++++++ .../trait_bound_not_satisfied.stderr | 22 +++++++++++++++++++ 5 files changed, 57 insertions(+), 2 deletions(-) create mode 100644 tests/decode_with_mem_tracking_ui.rs create mode 100644 tests/decode_with_mem_tracking_ui/trait_bound_not_satisfied.rs create mode 100644 tests/decode_with_mem_tracking_ui/trait_bound_not_satisfied.stderr diff --git a/build.rs b/build.rs index 8454031d..76854863 100644 --- a/build.rs +++ b/build.rs @@ -1,6 +1,6 @@ fn main() { + println!("cargo:rustc-check-cfg=cfg(nightly)"); if rustversion::cfg!(nightly) { - println!("cargo:rustc-check-cfg=cfg(nightly)"); println!("cargo:rustc-cfg=nightly"); } } diff --git a/derive/src/decode.rs b/derive/src/decode.rs index 
6b2b0a42..1e228776 100644 --- a/derive/src/decode.rs +++ b/derive/src/decode.rs @@ -319,7 +319,7 @@ pub fn quote_decode_with_mem_tracking_checks(data: &Data, crate_path: &syn::Path } else { field.ty.to_token_stream() }; - Some(quote! {#field_type}) + Some(quote_spanned! {field.span() => #field_type}) }); quote! { diff --git a/tests/decode_with_mem_tracking_ui.rs b/tests/decode_with_mem_tracking_ui.rs new file mode 100644 index 00000000..7a916829 --- /dev/null +++ b/tests/decode_with_mem_tracking_ui.rs @@ -0,0 +1,22 @@ +// Copyright (C) 2020-2021 Parity Technologies (UK) Ltd. +// SPDX-License-Identifier: Apache-2.0 + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#[test] +#[cfg(feature = "derive")] +fn derive_no_bound_ui() { + let t = trybuild::TestCases::new(); + t.compile_fail("tests/decode_with_mem_tracking_ui/*.rs"); + t.pass("tests/decode_with_mem_tracking_ui/pass/*.rs"); +} diff --git a/tests/decode_with_mem_tracking_ui/trait_bound_not_satisfied.rs b/tests/decode_with_mem_tracking_ui/trait_bound_not_satisfied.rs new file mode 100644 index 00000000..5384b536 --- /dev/null +++ b/tests/decode_with_mem_tracking_ui/trait_bound_not_satisfied.rs @@ -0,0 +1,11 @@ +use parity_scale_codec::{Decode, DecodeWithMemTracking}; + +#[derive(Decode)] +struct Base {} + +#[derive(Decode, DecodeWithMemTracking)] +struct Wrapper { + base: Base, +} + +fn main() {} diff --git a/tests/decode_with_mem_tracking_ui/trait_bound_not_satisfied.stderr b/tests/decode_with_mem_tracking_ui/trait_bound_not_satisfied.stderr new file mode 100644 index 00000000..c98fcf03 --- /dev/null +++ b/tests/decode_with_mem_tracking_ui/trait_bound_not_satisfied.stderr @@ -0,0 +1,22 @@ +error[E0277]: the trait bound `Base: DecodeWithMemTracking` is not satisfied + --> tests/decode_with_mem_tracking_ui/trait_bound_not_satisfied.rs:8:8 + | +8 | base: Base, + | ^^^^ the trait `DecodeWithMemTracking` is not implemented for `Base` + | + = help: the following other types implement trait `DecodeWithMemTracking`: + () + (TupleElement0, TupleElement1) + (TupleElement0, TupleElement1, TupleElement2) + (TupleElement0, TupleElement1, TupleElement2, TupleElement3) + (TupleElement0, TupleElement1, TupleElement2, TupleElement3, TupleElement4) + (TupleElement0, TupleElement1, TupleElement2, TupleElement3, TupleElement4, TupleElement5) + (TupleElement0, TupleElement1, TupleElement2, TupleElement3, TupleElement4, TupleElement5, TupleElement6) + (TupleElement0, TupleElement1, TupleElement2, TupleElement3, TupleElement4, TupleElement5, TupleElement6, TupleElement7) + and $N others +note: required by a bound in `check_field` + --> tests/decode_with_mem_tracking_ui/trait_bound_not_satisfied.rs:6:18 + | +6 | #[derive(Decode, DecodeWithMemTracking)] + | ^^^^^^^^^^^^^^^^^^^^^ required by this bound in `check_field` + = note: this error originates in the derive macro `DecodeWithMemTracking` (in Nightly builds, run with -Z macro-backtrace for more info) From 336c813bec1c27f4b413a536824ad46dc1be0865 Mon Sep 17 00:00:00 2001 From: Serban Iorga Date: Fri, 4 
Oct 2024 14:59:26 +0300 Subject: [PATCH 22/24] Avoid bumping major version --- derive/src/trait_bounds.rs | 10 ++++++---- src/compact.rs | 14 ++------------ 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/derive/src/trait_bounds.rs b/derive/src/trait_bounds.rs index 209e1bf6..40cb97db 100644 --- a/derive/src/trait_bounds.rs +++ b/derive/src/trait_bounds.rs @@ -162,10 +162,12 @@ pub fn add( .into_iter() .for_each(|ty| where_clause.predicates.push(parse_quote!(#ty : #codec_bound))); - let has_compact_bound: syn::Path = parse_quote!(#crate_path::HasCompact); - compact_types - .into_iter() - .for_each(|ty| where_clause.predicates.push(parse_quote!(#ty : #has_compact_bound))); + compact_types.into_iter().for_each(|ty| { + where_clause.predicates.push(parse_quote!(#ty : #crate_path::HasCompact)); + where_clause + .predicates + .push(parse_quote!(<#ty as #crate_path::HasCompact>::Type : #codec_bound)); + }); skip_types.into_iter().for_each(|ty| { let codec_skip_bound = codec_skip_bound.as_ref(); diff --git a/src/compact.rs b/src/compact.rs index 06c6cf4d..ffcd2ae4 100644 --- a/src/compact.rs +++ b/src/compact.rs @@ -256,12 +256,7 @@ impl MaybeMaxEncodedLen for T {} /// Trait that tells you if a given type can be encoded/decoded in a compact way. pub trait HasCompact: Sized { /// The compact type; this can be - type Type: for<'a> EncodeAsRef<'a, Self> - + Decode - + DecodeWithMemTracking - + From - + Into - + MaybeMaxEncodedLen; + type Type: for<'a> EncodeAsRef<'a, Self> + Decode + From + Into + MaybeMaxEncodedLen; } impl<'a, T: 'a> EncodeAsRef<'a, T> for Compact @@ -285,12 +280,7 @@ where impl HasCompact for T where - Compact: for<'a> EncodeAsRef<'a, T> - + Decode - + DecodeWithMemTracking - + From - + Into - + MaybeMaxEncodedLen, + Compact: for<'a> EncodeAsRef<'a, T> + Decode + From + Into + MaybeMaxEncodedLen, { type Type = Compact; } From fbd1865f2674cb0860cee0af7f64a58a15a37f16 Mon Sep 17 00:00:00 2001 From: Serban Iorga Date: Fri, 4 Oct 2024 15:29:54 +0300 Subject: [PATCH 23/24] Fixes --- tests/mod.rs | 5 ++++- tests/single_field_struct_encoding.rs | 11 ++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/mod.rs b/tests/mod.rs index 08b88922..f78feb7c 100644 --- a/tests/mod.rs +++ b/tests/mod.rs @@ -69,7 +69,10 @@ enum EnumWithDiscriminant { } #[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] -struct TestHasCompact { +struct TestHasCompact +where + ::Type: DecodeWithMemTracking, +{ #[codec(encoded_as = "::Type")] bar: T, } diff --git a/tests/single_field_struct_encoding.rs b/tests/single_field_struct_encoding.rs index 24ff2909..c3f83a4e 100644 --- a/tests/single_field_struct_encoding.rs +++ b/tests/single_field_struct_encoding.rs @@ -1,4 +1,4 @@ -use parity_scale_codec::{Compact, Decode, Encode, HasCompact}; +use parity_scale_codec::{Compact, Decode, DecodeWithMemTracking, Encode, HasCompact}; use parity_scale_codec_derive::{ CompactAs as DeriveCompactAs, Decode as DeriveDecode, DecodeWithMemTracking as DeriveDecodeWithMemTracking, Encode as DeriveEncode, @@ -37,7 +37,10 @@ struct Sc { } #[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] -struct Sh { +struct Sh +where + ::Type: DecodeWithMemTracking, +{ #[codec(encoded_as = "::Type")] x: T, } @@ -99,7 +102,9 @@ struct USkipcas(#[codec(compact)] USkip); struct SSkipcas(#[codec(compact)] SSkip); #[derive(Debug, PartialEq, DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking)] -struct Uh(#[codec(encoded_as = "::Type")] T); +struct 
Uh(#[codec(encoded_as = "::Type")] T) +where + ::Type: DecodeWithMemTracking; #[test] fn test_encoding() { From 70caacb85f8c3c5ee3c4656788bdcd8a98ae2022 Mon Sep 17 00:00:00 2001 From: Serban Iorga Date: Wed, 9 Oct 2024 09:34:59 +0300 Subject: [PATCH 24/24] Small fixes --- Cargo.toml | 1 - derive/src/utils.rs | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 773a1198..3890d8f8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,6 @@ byte-slice-cast = { version = "1.2.2", default-features = false } generic-array = { version = "0.14.7", optional = true } arbitrary = { version = "1.3.2", features = ["derive"], optional = true } impl-trait-for-tuples = "0.2.2" -rustversion = "1.0.17" [dev-dependencies] criterion = "0.4.0" diff --git a/derive/src/utils.rs b/derive/src/utils.rs index 26a02840..9f737afa 100644 --- a/derive/src/utils.rs +++ b/derive/src/utils.rs @@ -85,7 +85,8 @@ pub fn get_encoded_as_type(field: &Field) -> Option { }) } -/// Look for a `#[codec(compact)]` outer attribute on the given `Field`. +/// Look for a `#[codec(compact)]` outer attribute on the given `Field`. If the attribute is found, +/// return the compact type associated with the field type. pub fn get_compact_type(field: &Field, crate_path: &syn::Path) -> Option { find_meta_item(field.attrs.iter(), |meta| { if let NestedMeta::Meta(Meta::Path(ref path)) = meta {
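
Usage sketch (illustrative only, not part of the patches above): assuming the `derive` and `std` features are enabled, the snippet below shows how the `DecodeWithMemTracking` derive, the `DecodeWithMemLimit` extension trait and `MemTrackingInput` introduced in this series could be used by a consumer. The `Payload` type, its fields, the `main` wrapper and the 16 KiB limit are invented for this example; the API names come from the patches themselves.

use parity_scale_codec::{Decode, DecodeWithMemLimit, Encode, MemTrackingInput};
use parity_scale_codec_derive::{
    Decode as DeriveDecode, DecodeWithMemTracking as DeriveDecodeWithMemTracking,
    Encode as DeriveEncode,
};

// Hypothetical example type; any type whose fields implement
// `DecodeWithMemTracking` can derive the marker trait.
#[derive(DeriveEncode, DeriveDecode, DeriveDecodeWithMemTracking, PartialEq, Debug)]
struct Payload {
    id: u32,
    data: Vec<u8>,
}

fn main() {
    let encoded = Payload { id: 7, data: vec![0u8; 1024] }.encode();

    // One-shot decoding with a heap limit (the 16 KiB value is arbitrary).
    let decoded = Payload::decode_with_mem_limit(&mut &encoded[..], 16 * 1024)
        .expect("1 KiB of payload fits in a 16 KiB limit");
    assert_eq!(decoded.id, 7);

    // The same, done explicitly through `MemTrackingInput`, which also
    // reports how much heap memory was accounted for while decoding.
    let raw_input = &mut &encoded[..];
    let mut input = MemTrackingInput::new(raw_input, 16 * 1024);
    let decoded = Payload::decode(&mut input).expect("decoding within the limit succeeds");
    assert_eq!(decoded.data.len(), 1024);
    println!("heap bytes tracked while decoding: {}", input.used_mem());
}

As the `DecodeWithMemLimit` blanket impl in src/mem_tracking.rs shows, `decode_with_mem_limit` simply wraps the input in a `MemTrackingInput` and calls `decode`, so the two calls above are equivalent; exceeding the limit makes the `on_before_alloc_mem` hook return an error instead of allocating.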