From 5f428df86ece0736d1c1a2f6368ef0baa64789de Mon Sep 17 00:00:00 2001 From: Michal Rostecki Date: Tue, 7 May 2024 11:28:34 +0200 Subject: [PATCH 1/2] refactor: Use a regular `Vec` inside `BoundedVec`, add more tests * There is no need to use manual allocations in `BoundedVec`. Instead, we can just embed a regular `Vec`, make it private and make sure that the methods we expose never reallocate it. * Add more test cases covering all error variants. --- .../src/idls/account_compression.ts | 26 +- merkle-tree/bounded-vec/src/lib.rs | 371 ++++++++++-------- merkle-tree/concurrent/src/changelog.rs | 4 +- merkle-tree/concurrent/src/lib.rs | 20 +- merkle-tree/concurrent/tests/tests.rs | 2 +- merkle-tree/indexed/src/copy.rs | 42 +- merkle-tree/indexed/src/lib.rs | 41 +- merkle-tree/indexed/src/reference.rs | 4 +- merkle-tree/indexed/src/zero_copy.rs | 77 ++-- .../account-compression/src/state/address.rs | 2 +- .../src/state/public_state_merkle_tree.rs | 8 +- 11 files changed, 300 insertions(+), 297 deletions(-) diff --git a/js/stateless.js/src/idls/account_compression.ts b/js/stateless.js/src/idls/account_compression.ts index 899d92c9a7..b6b8b54574 100644 --- a/js/stateless.js/src/idls/account_compression.ts +++ b/js/stateless.js/src/idls/account_compression.ts @@ -785,7 +785,7 @@ export type AccountCompression = { { name: 'merkleTreeStruct'; type: { - array: ['u8', 320]; + array: ['u8', 280]; }; }, { @@ -893,7 +893,7 @@ export type AccountCompression = { name: 'stateMerkleTreeStruct'; docs: ['Merkle tree for the transaction state.']; type: { - array: ['u8', 272]; + array: ['u8', 240]; }; }, { @@ -1048,6 +1048,15 @@ export type AccountCompression = { }; }; }, + { + name: 'StateMerkleTree'; + type: { + kind: 'alias'; + value: { + defined: 'ConcurrentMerkleTree26'; + }; + }; + }, ]; errors: [ { @@ -1975,7 +1984,7 @@ export const IDL: AccountCompression = { { name: 'merkleTreeStruct', type: { - array: ['u8', 320], + array: ['u8', 280], }, }, { @@ -2083,7 +2092,7 @@ export const IDL: AccountCompression = { name: 'stateMerkleTreeStruct', docs: ['Merkle tree for the transaction state.'], type: { - array: ['u8', 272], + array: ['u8', 240], }, }, { @@ -2238,6 +2247,15 @@ export const IDL: AccountCompression = { }, }, }, + { + name: 'StateMerkleTree', + type: { + kind: 'alias', + value: { + defined: 'ConcurrentMerkleTree26', + }, + }, + }, ], errors: [ { diff --git a/merkle-tree/bounded-vec/src/lib.rs b/merkle-tree/bounded-vec/src/lib.rs index 691917bccb..42533078f0 100644 --- a/merkle-tree/bounded-vec/src/lib.rs +++ b/merkle-tree/bounded-vec/src/lib.rs @@ -1,8 +1,7 @@ use std::{ - alloc::{self, handle_alloc_error, Layout}, - fmt, mem, + fmt, ops::{Index, IndexMut}, - slice::{self, Iter, IterMut, SliceIndex}, + slice::{Iter, IterMut, SliceIndex}, }; use thiserror::Error; @@ -32,70 +31,28 @@ impl From for solana_program::program_error::ProgramError { } } -/// Plain Old Data. +/// `BoundedVec` is a custom vector implementation which forbids +/// post-initialization reallocations. /// -/// # Safety +/// The purpose is an ability to set an initial limit, but: /// -/// This trait should be implemented only for types with size known at compile -/// time, like primitives or arrays of primitives. 
-pub unsafe trait Pod {}
-
-unsafe impl Pod for i8 {}
-unsafe impl Pod for i16 {}
-unsafe impl Pod for i32 {}
-unsafe impl Pod for i64 {}
-unsafe impl Pod for isize {}
-unsafe impl Pod for u8 {}
-unsafe impl Pod for u16 {}
-unsafe impl Pod for u32 {}
-unsafe impl Pod for u64 {}
-unsafe impl Pod for usize {}
-
-unsafe impl<const N: usize> Pod for [u8; N] {}
-
-/// `BoundedVec` is a custom vector implementation which:
+/// * Still be able to define the limit at runtime, not at compile time.
+/// * Allocate the memory on the heap (not on the stack, like arrays).
 ///
-/// * Forbids post-initialization reallocations. The size is not known during
-///   compile time (that makes it different from arrays), but can be defined
-///   only once (that makes it different from [`Vec`](std::vec::Vec)).
-/// * Can store only Plain Old Data ([`Pod`](bytemuck::Pod)). It cannot nest
-///   any other dynamically sized types.
-pub struct BoundedVec<'a, T>
+/// `Vec` is still used as the underlying data structure; `BoundedVec` exposes
+/// only the methods which don't trigger reallocations.
+#[derive(Clone)]
+pub struct BoundedVec<T>(Vec<T>)
 where
-    T: Clone + Pod,
-{
-    capacity: usize,
-    length: usize,
-    data: &'a mut [T],
-}
+    T: Clone;
 
-impl<'a, T> BoundedVec<'a, T>
+impl<T> BoundedVec<T>
 where
-    T: Clone + Pod,
+    T: Clone,
 {
     #[inline]
     pub fn with_capacity(capacity: usize) -> Self {
-        let size = mem::size_of::<T>() * capacity;
-        let align = mem::align_of::<T>();
-        // SAFETY: `size` is a multiplication of `capacity`, therefore the
-        // layout is guaranteed to be aligned.
-        let layout = unsafe { Layout::from_size_align_unchecked(size, align) };
-
-        // SAFETY: As long as the provided `Pod` type is correct, this global
-        // allocator call should be correct too.
-        //
-        // We are handling the null pointer case gracefully.
-        let ptr = unsafe { alloc::alloc(layout) };
-        if ptr.is_null() {
-            handle_alloc_error(layout);
-        }
-        let data = unsafe { slice::from_raw_parts_mut(ptr as *mut T, capacity) };
-
-        Self {
-            capacity,
-            length: 0,
-            data,
-        }
+        Self(Vec::with_capacity(capacity))
     }
 
     pub fn from_array<const N: usize>(array: &[T; N]) -> Self {
@@ -108,6 +65,16 @@ where
         vec
     }
 
+    pub fn from_slice(slice: &[T]) -> Self {
+        let mut vec = Self::with_capacity(slice.len());
+        for element in slice {
+            // SAFETY: We are sure that the slice and the vector have equal
+            // sizes, so there is no chance for the error to occur.
+            vec.push(element.clone()).unwrap();
+        }
+        vec
+    }
+
     /// Creates a `BoundedVec` directly from a pointer, a capacity, and a length.
     ///
     /// # Safety
@@ -131,13 +98,8 @@ where
     /// See the safety documentation of [`pointer::offset`].
     #[inline]
     pub unsafe fn from_raw_parts(ptr: *mut T, length: usize, capacity: usize) -> Self {
-        let data = slice::from_raw_parts_mut(ptr, capacity);
-
-        Self {
-            capacity,
-            length,
-            data,
-        }
+        let vec = Vec::from_raw_parts(ptr, length, capacity);
+        Self(vec)
     }
 
     /// Returns the total number of elements the vector can hold without
     /// reallocating.
     ///
     /// # Examples
     ///
@@ -152,12 +114,12 @@ where
     /// ```
     #[inline]
    pub fn capacity(&self) -> usize {
-        self.capacity
+        self.0.capacity()
    }
 
     #[inline]
     pub fn as_slice(&self) -> &[T] {
-        &self.data[..self.length]
+        self.0.as_slice()
     }
 
     /// Appends an element to the back of a collection.
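
To make the review concrete: a minimal sketch of the behaviour the new wrapper guarantees, using only the API visible in this diff (`with_capacity`, `push`, `len`, `BoundedVecError::Full`). This snippet is illustrative and not part of the patch:

```rust
use light_bounded_vec::{BoundedVec, BoundedVecError};

fn main() {
    // The capacity is picked at runtime and is final; the inner `Vec`
    // is never allowed to reallocate past it.
    let mut vec: BoundedVec<u32> = BoundedVec::with_capacity(2);

    vec.push(1).unwrap();
    vec.push(2).unwrap();

    // A third push would require a reallocation, so it is rejected.
    assert!(matches!(vec.push(3), Err(BoundedVecError::Full)));
    assert_eq!(vec.len(), 2);
}
```
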
@@ -175,70 +137,61 @@ where
     /// ```
     #[inline]
     pub fn push(&mut self, value: T) -> Result<(), BoundedVecError> {
-        if self.length == self.capacity {
+        if self.0.len() == self.0.capacity() {
             return Err(BoundedVecError::Full);
         }
-
-        self.data[self.length] = value;
-        self.length += 1;
-
+        self.0.push(value);
         Ok(())
     }
 
     #[inline]
     pub fn len(&self) -> usize {
-        self.length
+        self.0.len()
     }
 
     pub fn is_empty(&self) -> bool {
-        self.len() == 0
+        self.0.is_empty()
     }
 
     #[inline]
     pub fn get(&self, index: usize) -> Option<&T> {
-        self.data[..self.length].get(index)
+        self.0.get(index)
    }
 
     #[inline]
     pub fn get_mut(&mut self, index: usize) -> Option<&mut T> {
-        self.data[..self.length].get_mut(index)
+        self.0.get_mut(index)
     }
 
     #[inline]
     pub fn iter(&self) -> Iter<'_, T> {
-        self.data[..self.length].iter()
+        self.0.iter()
     }
 
     #[inline]
     pub fn iter_mut(&mut self) -> IterMut<'_, T> {
-        self.data[..self.length].iter_mut()
+        self.0.iter_mut()
     }
 
     #[inline]
     pub fn last(&self) -> Option<&T> {
-        if self.length < 1 {
-            return None;
-        }
-        self.get(self.length - 1)
+        self.0.last()
     }
 
     #[inline]
     pub fn last_mut(&mut self) -> Option<&mut T> {
-        if self.length < 1 {
-            return None;
-        }
-        self.get_mut(self.length - 1)
+        self.0.last_mut()
     }
 
     pub fn to_array<const N: usize>(&self) -> Result<[T; N], BoundedVecError> {
         if self.len() != N {
             return Err(BoundedVecError::ArraySize(N, self.len()));
         }
-        Ok(std::array::from_fn(|i| self.data[i].clone()))
+        Ok(std::array::from_fn(|i| self.0[i].clone()))
     }
 
     pub fn to_vec(self) -> Vec<T> {
-        self.data[..self.length].to_vec()
+        self.0
     }
 
     pub fn extend<U: IntoIterator<Item = T>>(&mut self, iter: U) -> Result<(), BoundedVecError> {
@@ -249,99 +202,98 @@ where
     }
 }
 
-impl<'a, T> fmt::Debug for BoundedVec<'a, T>
+impl<T> fmt::Debug for BoundedVec<T>
 where
-    T: Clone + fmt::Debug + Pod,
+    T: Clone + fmt::Debug,
 {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write!(f, "{:?}", &self.data[..self.length])
+        self.0.fmt(f)
     }
 }
 
-impl<'a, T, I: SliceIndex<[T]>> Index<I> for BoundedVec<'a, T>
+impl<T, I: SliceIndex<[T]>> Index<I> for BoundedVec<T>
 where
-    T: Clone + Pod,
+    T: Clone,
     I: SliceIndex<[T]>,
 {
     type Output = I::Output;
 
     #[inline]
     fn index(&self, index: I) -> &Self::Output {
-        self.data[..self.length].index(index)
+        self.0.index(index)
     }
 }
 
-impl<'a, T, I> IndexMut<I> for BoundedVec<'a, T>
+impl<T, I> IndexMut<I> for BoundedVec<T>
 where
-    T: Clone + Pod,
+    T: Clone,
     I: SliceIndex<[T]>,
 {
     fn index_mut(&mut self, index: I) -> &mut Self::Output {
-        self.data[..self.length].index_mut(index)
+        self.0.index_mut(index)
     }
 }
 
-impl<'a, T> PartialEq for BoundedVec<'a, T>
+impl<T> PartialEq for BoundedVec<T>
 where
-    T: Clone + PartialEq + Pod,
+    T: Clone + PartialEq,
 {
     fn eq(&self, other: &Self) -> bool {
-        self.data[..self.length]
-            .iter()
-            .eq(other.data[..other.length].iter())
+        self.0.eq(&other.0)
     }
 }
 
-impl<'a, T> Eq for BoundedVec<'a, T> where T: Clone + Eq + Pod {}
+impl<T> Eq for BoundedVec<T> where T: Clone + Eq {}
 
 /// `CyclicBoundedVec` is a wrapper around [`Vec`](std::vec::Vec) which:
 ///
 /// * Forbids post-initialization reallocations.
 /// * Starts overwriting elements from the beginning once it reaches its
 ///   capacity.
-#[derive(Debug)]
-pub struct CyclicBoundedVec<'a, T>
+#[derive(Clone)]
+pub struct CyclicBoundedVec<T>
 where
-    T: Clone + Pod,
+    T: Clone,
 {
-    capacity: usize,
-    length: usize,
     first_index: usize,
     last_index: usize,
-    data: &'a mut [T],
+    data: Vec<T>,
 }
 
-impl<'a, T> CyclicBoundedVec<'a, T>
+impl<T> CyclicBoundedVec<T>
 where
-    T: Clone + Pod,
+    T: Clone,
 {
     #[inline]
     pub fn with_capacity(capacity: usize) -> Self {
-        let size = mem::size_of::<T>() * capacity;
-        let align = mem::align_of::<T>();
-        // SAFETY: `size` is a multiplication of `capacity`, therefore the
-        // layout is guaranteed to be aligned.
-        let layout = unsafe { Layout::from_size_align_unchecked(size, align) };
-
-        // SAFETY: As long as the provided `Pod` type is correct, this global
-        // allocator call should be correct too.
-        //
-        // We are handling the null pointer case gracefully.
-        let ptr = unsafe { alloc::alloc(layout) };
-        if ptr.is_null() {
-            handle_alloc_error(layout);
-        }
-        let data = unsafe { slice::from_raw_parts_mut(ptr as *mut T, capacity) };
-
+        let data = Vec::with_capacity(capacity);
         Self {
-            capacity,
-            length: 0,
             first_index: 0,
             last_index: 0,
             data,
         }
     }
 
+    pub fn from_array<const N: usize>(array: &[T; N]) -> Self {
+        let mut vec = Self::with_capacity(N);
+        for element in array {
+            // The vector was created with a matching capacity, so `push`
+            // never wraps around here.
+            vec.push(element.clone());
+        }
+        vec
+    }
+
+    pub fn from_slice(slice: &[T]) -> Self {
+        let mut vec = Self::with_capacity(slice.len());
+        for element in slice {
+            // The vector was created with a matching capacity, so `push`
+            // never wraps around here.
+            vec.push(element.clone());
+        }
+        vec
+    }
+
     /// Creates a `CyclicBoundedVec` directly from a pointer, a capacity, and a length.
     ///
     /// # Safety
@@ -371,10 +323,8 @@ where
         first_index: usize,
         last_index: usize,
     ) -> Self {
-        let data = slice::from_raw_parts_mut(ptr, capacity);
+        let data = Vec::from_raw_parts(ptr, length, capacity);
         Self {
-            capacity,
-            length,
             first_index,
             last_index,
             data,
@@ -393,7 +343,12 @@ where
     /// ```
     #[inline]
     pub fn capacity(&self) -> usize {
-        self.capacity
+        self.data.capacity()
+    }
+
+    #[inline]
+    pub fn as_slice(&self) -> &[T] {
+        self.data.as_slice()
     }
 
     /// Appends an element to the back of a collection.
     ///
     /// # Examples
     ///
@@ -407,36 +362,38 @@ where
     /// ```
     #[inline]
     pub fn push(&mut self, value: T) {
-        if self.is_empty() {
-            self.length += 1;
-        } else if self.len() < self.capacity() {
-            self.length += 1;
-            self.last_index += 1;
-        } else if !self.is_empty() {
+        if self.len() < self.capacity() {
+            if !self.is_empty() {
+                self.last_index += 1;
+            }
+
+            self.data.push(value);
+        } else {
             self.last_index = (self.last_index + 1) % self.capacity();
             self.first_index = (self.first_index + 1) % self.capacity();
+
+            // PANICS: We made sure that `self.last_index` doesn't exceed the capacity.
+            self.data[self.last_index] = value;
         }
-        // PANICS: We made sure that `self.newest` doesn't exceed the capacity.
-        self.data[self.last_index] = value;
     }
 
     #[inline]
     pub fn len(&self) -> usize {
-        self.length
+        self.data.len()
    }
 
     pub fn is_empty(&self) -> bool {
-        self.len() == 0
+        self.data.is_empty()
     }
 
     #[inline]
     pub fn get(&self, index: usize) -> Option<&T> {
-        self.data[..self.length].get(index)
+        self.data.get(index)
     }
 
     #[inline]
     pub fn get_mut(&mut self, index: usize) -> Option<&mut T> {
-        self.data[..self.length].get_mut(index)
+        self.data.get_mut(index)
     }
 
     #[inline]
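
A worked example of the wrap-around arithmetic above (`first_index` and `last_index` advance modulo the capacity); a sketch against the API in this diff, not part of the patch:

```rust
use light_bounded_vec::CyclicBoundedVec;

fn main() {
    let mut vec = CyclicBoundedVec::with_capacity(4);

    // Filling up to the capacity appends normally: first_index = 0,
    // last_index = 3, buffer = [0, 1, 2, 3].
    for i in 0..4 {
        vec.push(i);
    }
    assert_eq!(vec.as_slice(), &[0, 1, 2, 3]);

    // One more push wraps: last_index = (3 + 1) % 4 = 0,
    // first_index = (0 + 1) % 4 = 1, and the oldest element is
    // overwritten in place.
    vec.push(4);
    assert_eq!(vec.as_slice(), &[4, 1, 2, 3]);

    // Iteration still yields elements oldest-first.
    assert_eq!(vec.iter().copied().collect::<Vec<_>>(), vec![1, 2, 3, 4]);
}
```
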
@@ -488,52 +445,61 @@ where
     }
 }
 
-impl<'a, T, I> Index<I> for CyclicBoundedVec<'a, T>
+impl<T> fmt::Debug for CyclicBoundedVec<T>
+where
+    T: Clone + fmt::Debug,
+{
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        self.iter().collect::<Vec<_>>().as_slice().fmt(f)
+    }
+}
+
+impl<T, I> Index<I> for CyclicBoundedVec<T>
 where
-    T: Clone + Pod,
+    T: Clone,
     I: SliceIndex<[T]>,
 {
     type Output = I::Output;
 
     #[inline]
     fn index(&self, index: I) -> &Self::Output {
-        self.data[..self.length].index(index)
+        self.data.index(index)
     }
 }
 
-impl<'a, T, I> IndexMut<I> for CyclicBoundedVec<'a, T>
+impl<T, I> IndexMut<I> for CyclicBoundedVec<T>
 where
-    T: Clone + Pod,
+    T: Clone,
     I: SliceIndex<[T]>,
 {
     fn index_mut(&mut self, index: I) -> &mut Self::Output {
-        self.data[..self.length].index_mut(index)
+        self.data.index_mut(index)
     }
 }
 
-impl<'a, T> PartialEq for CyclicBoundedVec<'a, T>
+impl<T> PartialEq for CyclicBoundedVec<T>
 where
-    T: Clone + Pod + PartialEq,
+    T: Clone + PartialEq,
 {
     fn eq(&self, other: &Self) -> bool {
-        self.data[..self.length].iter().eq(other.data.iter())
+        self.data.eq(&other.data)
     }
 }
 
-impl<'a, T> Eq for CyclicBoundedVec<'a, T> where T: Clone + Eq + Pod {}
+impl<T> Eq for CyclicBoundedVec<T> where T: Clone + Eq {}
 
 pub struct CyclicBoundedVecIterator<'a, T>
 where
-    T: Clone + Pod,
+    T: Clone,
 {
-    vec: &'a CyclicBoundedVec<'a, T>,
+    vec: &'a CyclicBoundedVec<T>,
     current: usize,
     is_finished: bool,
 }
 
 impl<'a, T> Iterator for CyclicBoundedVecIterator<'a, T>
 where
-    T: Clone + Pod,
+    T: Clone,
 {
     type Item = &'a T;
 
@@ -556,6 +522,61 @@ where
 mod test {
     use super::*;
 
+    fn bounded_vec_full<const CAPACITY: usize>() {
+        let mut bounded_vec = BoundedVec::with_capacity(CAPACITY);
+
+        // Append up to capacity.
+        for i in 0..CAPACITY {
+            bounded_vec.push(i).unwrap();
+        }
+        // Try pushing over capacity - should result in an error.
+        for i in 0..CAPACITY {
+            let res = bounded_vec.push(i);
+            assert!(matches!(res, Err(BoundedVecError::Full)));
+        }
+    }
+
+    #[test]
+    fn test_bounded_vec_full_8() {
+        bounded_vec_full::<8>()
+    }
+
+    #[test]
+    fn test_bounded_vec_full_16() {
+        bounded_vec_full::<16>()
+    }
+
+    #[test]
+    fn test_bounded_vec_full_32() {
+        bounded_vec_full::<32>()
+    }
+
+    #[test]
+    fn test_bounded_vec_full_64() {
+        bounded_vec_full::<64>()
+    }
+
+    #[test]
+    fn test_bounded_vec_full_128() {
+        bounded_vec_full::<128>()
+    }
+
+    #[test]
+    fn test_bounded_vec_to_array() {
+        let bounded_vec = BoundedVec::from_array(&[1u8; 32]);
+
+        assert!(bounded_vec.to_array::<32>().is_ok());
+
+        assert!(matches!(
+            bounded_vec.to_array::<31>(),
+            Err(BoundedVecError::ArraySize(_, _))
+        ));
+        assert!(matches!(
+            bounded_vec.to_array::<33>(),
+            Err(BoundedVecError::ArraySize(_, _))
+        ));
+    }
+
     #[test]
     fn test_cyclic_bounded_vec_manual() {
         let mut cyclic_bounded_vec = CyclicBoundedVec::with_capacity(8);
@@ -861,4 +882,38 @@ mod test {
             ][..]
         );
     }
+
+    /// Test formatting of a cycled vector.
+    ///
+    /// Insert elements over capacity, so the vector wraps around and starts
+    /// overwriting elements from the start - 12 elements into a vector with
+    /// capacity 8.
+    ///
+    /// The resulting data structure looks like:
+    ///
+    /// ```
+    ///                    $  ^
+    /// index [0, 1,  2,  3, 4, 5, 6, 7]
+    /// value [8, 9, 10, 11, 4, 5, 6, 7]
+    /// ```
+    ///
+    /// * `^` - first element
+    /// * `$` - last element
+    ///
+    /// The debug format of that structure should look like:
+    ///
+    /// ```
+    /// [4, 5, 6, 7, 8, 9, 10, 11]
+    /// ```
+    #[test]
+    fn test_cyclic_bounded_vec_format() {
+        let mut cyclic_bounded_vec = CyclicBoundedVec::with_capacity(8);
+
+        for i in 0..12 {
+            cyclic_bounded_vec.push(i);
+        }
+
+        let f = format!("{cyclic_bounded_vec:?}");
+        assert_eq!(f, "[4, 5, 6, 7, 8, 9, 10, 11]");
+    }
 }
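
One behavioural consequence of this refactor that reviewers may want to double-check: `from_raw_parts` now rebuilds a real `Vec`, so the buffer is owned and will be freed through the global allocator when the vector is dropped. A sketch of the usage pattern that stays sound under the new contract (illustrative, not part of the patch):

```rust
use light_bounded_vec::BoundedVec;

fn main() {
    // The pointer must come from a live `Vec` with exactly this length
    // and capacity, because `Vec::from_raw_parts` takes ownership of the
    // allocation and drops it as a `Vec`.
    let mut source = vec![1u8, 2, 3];
    let (ptr, len, cap) = (source.as_mut_ptr(), source.len(), source.capacity());
    std::mem::forget(source); // ownership moves into the BoundedVec

    let bounded = unsafe { BoundedVec::from_raw_parts(ptr, len, cap) };
    assert_eq!(bounded.as_slice(), &[1, 2, 3]);
} // `bounded` is dropped here and frees the allocation exactly once
```
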
diff --git a/merkle-tree/concurrent/src/changelog.rs b/merkle-tree/concurrent/src/changelog.rs
index b389b3fc7c..deb6ddbff3 100644
--- a/merkle-tree/concurrent/src/changelog.rs
+++ b/merkle-tree/concurrent/src/changelog.rs
@@ -1,4 +1,4 @@
-use light_bounded_vec::{BoundedVec, Pod};
+use light_bounded_vec::BoundedVec;
 
 use crate::errors::ConcurrentMerkleTreeError;
 
@@ -18,8 +18,6 @@ pub type ChangelogEntry26 = ChangelogEntry<26>;
 pub type ChangelogEntry32 = ChangelogEntry<32>;
 pub type ChangelogEntry40 = ChangelogEntry<40>;
 
-unsafe impl<const HEIGHT: usize> Pod for ChangelogEntry<HEIGHT> {}
-
 impl<const HEIGHT: usize> ChangelogEntry<HEIGHT> {
     pub fn new(root: [u8; 32], path: [[u8; 32]; HEIGHT], index: usize) -> Self {
         let index = index as u64;
diff --git a/merkle-tree/concurrent/src/lib.rs b/merkle-tree/concurrent/src/lib.rs
index d6d259a614..65138b900a 100644
--- a/merkle-tree/concurrent/src/lib.rs
+++ b/merkle-tree/concurrent/src/lib.rs
@@ -32,7 +32,7 @@ use crate::{
 // const generic here is that removing it would require keeping a `BoundedVec`
 // inside `CyclicBoundedVec`. Casting byte slices to such nested vector is not
 // a trivial task, but we might eventually do it at some point.
-pub struct ConcurrentMerkleTree<'a, H, const HEIGHT: usize>
+pub struct ConcurrentMerkleTree<H, const HEIGHT: usize>
 where
     H: Hasher,
 {
@@ -53,23 +53,23 @@ where
     pub rightmost_leaf: [u8; 32],
 
     /// Hashes of subtrees.
-    pub filled_subtrees: BoundedVec<'a, [u8; 32]>,
+    pub filled_subtrees: BoundedVec<[u8; 32]>,
     /// History of Merkle proofs.
-    pub changelog: CyclicBoundedVec<'a, ChangelogEntry<HEIGHT>>,
+    pub changelog: CyclicBoundedVec<ChangelogEntry<HEIGHT>>,
     /// History of roots.
-    pub roots: CyclicBoundedVec<'a, [u8; 32]>,
+    pub roots: CyclicBoundedVec<[u8; 32]>,
     /// Cached upper nodes.
-    pub canopy: BoundedVec<'a, [u8; 32]>,
+    pub canopy: BoundedVec<[u8; 32]>,
 
     pub _hasher: PhantomData<H>,
 }
 
-pub type ConcurrentMerkleTree22<'a, H> = ConcurrentMerkleTree<'a, H, 22>;
-pub type ConcurrentMerkleTree26<'a, H> = ConcurrentMerkleTree<'a, H, 26>;
-pub type ConcurrentMerkleTree32<'a, H> = ConcurrentMerkleTree<'a, H, 32>;
-pub type ConcurrentMerkleTree40<'a, H> = ConcurrentMerkleTree<'a, H, 40>;
+pub type ConcurrentMerkleTree22<H> = ConcurrentMerkleTree<H, 22>;
+pub type ConcurrentMerkleTree26<H> = ConcurrentMerkleTree<H, 26>;
+pub type ConcurrentMerkleTree32<H> = ConcurrentMerkleTree<H, 32>;
+pub type ConcurrentMerkleTree40<H> = ConcurrentMerkleTree<H, 40>;
 
-impl<'a, H, const HEIGHT: usize> ConcurrentMerkleTree<'a, H, HEIGHT>
+impl<'a, H, const HEIGHT: usize> ConcurrentMerkleTree<H, HEIGHT>
 where
     H: Hasher,
 {
diff --git a/merkle-tree/concurrent/tests/tests.rs b/merkle-tree/concurrent/tests/tests.rs
index 78295a4ba2..d6b4ab540f 100644
--- a/merkle-tree/concurrent/tests/tests.rs
+++ b/merkle-tree/concurrent/tests/tests.rs
@@ -1117,7 +1117,7 @@ where
     }
 
     let merkle_tree = unsafe {
-        ConcurrentMerkleTree::<H, HEIGHT>::from_bytes(
+        ConcurrentMerkleTree::<H, HEIGHT>::copy_from_bytes(
             bytes_struct.as_slice(),
             bytes_filled_subtrees.as_slice(),
             bytes_changelog.as_slice(),
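
The pattern behind all of the signature changes in the remaining files is the same, so it is worth stating once: owning the buffer removes the borrow that previously forced a lifetime parameter onto every type embedding one of these vectors. An illustrative sketch (not from the patch):

```rust
// Before: the vector borrowed caller-provided memory, so every struct
// embedding it had to carry (and propagate) a lifetime parameter.
struct BorrowedVec<'a, T> {
    data: &'a mut [T],
}

// After: the vector owns a `Vec`, so `ConcurrentMerkleTree<H, HEIGHT>`,
// `IndexedMerkleTree<H, I, HEIGHT>` and their aliases shed the `'a`.
struct OwnedVec<T> {
    data: Vec<T>,
}
```
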
diff --git a/merkle-tree/indexed/src/copy.rs b/merkle-tree/indexed/src/copy.rs
index a7ef6bea12..99a238df39 100644
--- a/merkle-tree/indexed/src/copy.rs
+++ b/merkle-tree/indexed/src/copy.rs
@@ -1,6 +1,6 @@
 use std::{marker::PhantomData, mem, slice};
 
-use light_bounded_vec::{BoundedVec, CyclicBoundedVec, Pod};
+use light_bounded_vec::{BoundedVec, CyclicBoundedVec};
 use light_concurrent_merkle_tree::{
     changelog::ChangelogEntry, errors::ConcurrentMerkleTreeError, ConcurrentMerkleTree,
 };
@@ -10,39 +10,21 @@ use num_traits::{CheckedAdd, CheckedSub, ToBytes, Unsigned};
 use crate::{errors::IndexedMerkleTreeError, IndexedMerkleTree, RawIndexedElement};
 
 #[derive(Debug)]
-pub struct IndexedMerkleTreeCopy<'a, H, I, const HEIGHT: usize>(
-    pub IndexedMerkleTree<'a, H, I, HEIGHT>,
-)
+pub struct IndexedMerkleTreeCopy<H, I, const HEIGHT: usize>(pub IndexedMerkleTree<H, I, HEIGHT>)
 where
     H: Hasher,
-    I: CheckedAdd
-        + CheckedSub
-        + Copy
-        + Clone
-        + PartialOrd
-        + ToBytes
-        + TryFrom<usize>
-        + Unsigned
-        + Pod,
+    I: CheckedAdd + CheckedSub + Copy + Clone + PartialOrd + ToBytes + TryFrom<usize> + Unsigned,
     usize: From<I>;
 
-pub type IndexedMerkleTreeCopy22<'a, H, I> = IndexedMerkleTreeCopy<'a, H, I, 22>;
-pub type IndexedMerkleTreeCopy26<'a, H, I> = IndexedMerkleTreeCopy<'a, H, I, 26>;
-pub type IndexedMerkleTreeCopy32<'a, H, I> = IndexedMerkleTreeCopy<'a, H, I, 32>;
-pub type IndexedMerkleTreeCopy40<'a, H, I> = IndexedMerkleTreeCopy<'a, H, I, 40>;
+pub type IndexedMerkleTreeCopy22<H, I> = IndexedMerkleTreeCopy<H, I, 22>;
+pub type IndexedMerkleTreeCopy26<H, I> = IndexedMerkleTreeCopy<H, I, 26>;
+pub type IndexedMerkleTreeCopy32<H, I> = IndexedMerkleTreeCopy<H, I, 32>;
+pub type IndexedMerkleTreeCopy40<H, I> = IndexedMerkleTreeCopy<H, I, 40>;
 
-impl<'a, H, I, const HEIGHT: usize> IndexedMerkleTreeCopy<'a, H, I, HEIGHT>
+impl<H, I, const HEIGHT: usize> IndexedMerkleTreeCopy<H, I, HEIGHT>
 where
     H: Hasher,
-    I: CheckedAdd
-        + CheckedSub
-        + Copy
-        + Clone
-        + PartialOrd
-        + ToBytes
-        + TryFrom<usize>
-        + Unsigned
-        + Pod,
+    I: CheckedAdd + CheckedSub + Copy + Clone + PartialOrd + ToBytes + TryFrom<usize> + Unsigned,
     usize: From<I>,
 {
     /// Casts a byte slice into wrapped `IndexedMerkleTree` structure reference,
@@ -66,9 +48,9 @@ where
         bytes_changelog: &[u8],
         bytes_roots: &[u8],
         bytes_canopy: &[u8],
-        bytes_indexed_changelog: &'a [u8],
+        bytes_indexed_changelog: &[u8],
     ) -> Result<Self, IndexedMerkleTreeError> {
-        let expected_bytes_struct_size = mem::size_of::<IndexedMerkleTree<'a, H, I, HEIGHT>>();
+        let expected_bytes_struct_size = mem::size_of::<IndexedMerkleTree<H, I, HEIGHT>>();
         if bytes_struct.len() != expected_bytes_struct_size {
             return Err(IndexedMerkleTreeError::ConcurrentMerkleTree(
                 ConcurrentMerkleTreeError::StructBufferSize(
                     expected_bytes_struct_size,
                     bytes_struct.len(),
                 ),
             ));
         }
@@ -77,7 +59,7 @@ where
-        let struct_ref: *mut IndexedMerkleTree<'a, H, I, HEIGHT> = bytes_struct.as_ptr() as _;
+        let struct_ref: *mut IndexedMerkleTree<H, I, HEIGHT> = bytes_struct.as_ptr() as _;
 
         let mut merkle_tree = unsafe {
             ConcurrentMerkleTree {
diff --git a/merkle-tree/indexed/src/lib.rs b/merkle-tree/indexed/src/lib.rs
index 3e89182c33..8ee6b3e907 100644
--- a/merkle-tree/indexed/src/lib.rs
+++ b/merkle-tree/indexed/src/lib.rs
@@ -1,7 +1,7 @@
 use std::marker::PhantomData;
 
 use array::{IndexedArray, IndexedElement};
-use light_bounded_vec::{BoundedVec, CyclicBoundedVec, Pod};
+use light_bounded_vec::{BoundedVec, CyclicBoundedVec};
 use light_concurrent_merkle_tree::{
     errors::ConcurrentMerkleTreeError, light_hasher::Hasher, ConcurrentMerkleTree,
 };
@@ -23,54 +23,37 @@ pub const FIELD_SIZE_SUB_ONE: &str =
 #[derive(Debug, Default, Clone, Copy)]
 pub struct RawIndexedElement<I>
 where
-    I: Clone + Pod,
+    I: Clone,
 {
     pub value: [u8; 32],
     pub next_index: I,
     pub next_value: [u8; 32],
     pub index: I,
 }
-unsafe impl<I> Pod for RawIndexedElement<I> where I: Pod + Clone {}
 
 #[derive(Debug)]
 #[repr(C)]
-pub struct IndexedMerkleTree<'a, H, I, const HEIGHT: usize>
+pub struct IndexedMerkleTree<H, I, const HEIGHT: usize>
 where
     H: Hasher,
-    I: CheckedAdd
-        + CheckedSub
-        + Copy
-        + Clone
-        + PartialOrd
-        + ToBytes
-        + TryFrom<usize>
-        + Unsigned
-        + Pod,
+    I: CheckedAdd + CheckedSub + Copy + Clone + PartialOrd + ToBytes + TryFrom<usize> + Unsigned,
     usize: From<I>,
 {
-    pub merkle_tree: ConcurrentMerkleTree<'a, H, HEIGHT>,
-    pub changelog: CyclicBoundedVec<'a, RawIndexedElement<I>>,
+    pub merkle_tree: ConcurrentMerkleTree<H, HEIGHT>,
+    pub changelog: CyclicBoundedVec<RawIndexedElement<I>>,
     _index: PhantomData<I>,
 }
 
-pub type IndexedMerkleTree22<'a, H, I> = IndexedMerkleTree<'a, H, I, 22>;
-pub type IndexedMerkleTree26<'a, H, I> = IndexedMerkleTree<'a, H, I, 26>;
-pub type IndexedMerkleTree32<'a, H, I> = IndexedMerkleTree<'a, H, I, 32>;
-pub type IndexedMerkleTree40<'a, H, I> = IndexedMerkleTree<'a, H, I, 40>;
+pub type IndexedMerkleTree22<H, I> = IndexedMerkleTree<H, I, 22>;
+pub type IndexedMerkleTree26<H, I> = IndexedMerkleTree<H, I, 26>;
+pub type IndexedMerkleTree32<H, I> = IndexedMerkleTree<H, I, 32>;
+pub type IndexedMerkleTree40<H, I> = IndexedMerkleTree<H, I, 40>;
 
-impl<'a, H, I, const HEIGHT: usize> IndexedMerkleTree<'a, H, I, HEIGHT>
+impl<H, I, const HEIGHT: usize> IndexedMerkleTree<H, I, HEIGHT>
 where
     H: Hasher,
-    I: CheckedAdd
-        + CheckedSub
-        + Copy
-        + Clone
-        + PartialOrd
-        + ToBytes
-        + TryFrom<usize>
-        + Unsigned
-        + Pod,
+    I: CheckedAdd + CheckedSub + Copy + Clone + PartialOrd + ToBytes + TryFrom<usize> + Unsigned,
     usize: From<I>,
 {
     pub fn new(
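
For context on the `StructBufferSize` checks that keep appearing in `copy.rs` and `zero_copy.rs`: the constructors refuse to cast a byte slice whose length differs from the in-memory size of the target struct. A stand-alone sketch of the check (the `Tree` type here is a placeholder, not the real `IndexedMerkleTree`):

```rust
use std::mem;

#[repr(C)]
struct Tree {
    sequence_number: u64,
    rightmost_leaf: [u8; 32],
}

fn struct_from_bytes(bytes: &[u8]) -> Result<*const Tree, (usize, usize)> {
    let expected = mem::size_of::<Tree>();
    if bytes.len() != expected {
        // Corresponds to ConcurrentMerkleTreeError::StructBufferSize(expected, got).
        return Err((expected, bytes.len()));
    }
    Ok(bytes.as_ptr() as *const Tree)
}

fn main() {
    assert!(struct_from_bytes(&[0u8; 8]).is_err());
    assert!(struct_from_bytes(&vec![0u8; mem::size_of::<Tree>()]).is_ok());
}
```
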
diff --git a/merkle-tree/indexed/src/reference.rs b/merkle-tree/indexed/src/reference.rs
index ce3b7cae51..46f84621c9 100644
--- a/merkle-tree/indexed/src/reference.rs
+++ b/merkle-tree/indexed/src/reference.rs
@@ -172,12 +172,12 @@ where
 /// We prove non-inclusion by:
 /// 1. Showing that value is greater than leaf_lower_range_value and less than leaf_higher_range_value
 /// 2. Showing that the leaf_hash H(leaf_lower_range_value, leaf_next_index, leaf_higher_value) is included in the root (Merkle tree)
-pub struct NonInclusionProof<'a> {
+pub struct NonInclusionProof {
     pub root: [u8; 32],
     pub value: [u8; 32],
     pub leaf_lower_range_value: [u8; 32],
     pub leaf_higher_range_value: [u8; 32],
     pub leaf_index: usize,
     pub next_index: usize,
-    pub merkle_proof: BoundedVec<'a, [u8; 32]>,
+    pub merkle_proof: BoundedVec<[u8; 32]>,
 }
diff --git a/merkle-tree/indexed/src/zero_copy.rs b/merkle-tree/indexed/src/zero_copy.rs
index 3448cb831e..9f3fa61f9d 100644
--- a/merkle-tree/indexed/src/zero_copy.rs
+++ b/merkle-tree/indexed/src/zero_copy.rs
@@ -1,6 +1,6 @@
 use std::mem;
 
-use light_bounded_vec::{BoundedVec, CyclicBoundedVec, Pod};
+use light_bounded_vec::{BoundedVec, CyclicBoundedVec};
 use light_concurrent_merkle_tree::{
     changelog::ChangelogEntry, errors::ConcurrentMerkleTreeError, ConcurrentMerkleTree,
 };
@@ -13,18 +13,10 @@ use crate::{errors::IndexedMerkleTreeError, IndexedMerkleTree, RawIndexedElement
 pub struct IndexedMerkleTreeZeroCopy<'a, H, I, const HEIGHT: usize>
 where
     H: Hasher,
-    I: CheckedAdd
-        + CheckedSub
-        + Copy
-        + Clone
-        + PartialOrd
-        + ToBytes
-        + TryFrom<usize>
-        + Unsigned
-        + Pod,
+    I: CheckedAdd + CheckedSub + Copy + Clone + PartialOrd + ToBytes + TryFrom<usize> + Unsigned,
     usize: From<I>,
 {
-    pub merkle_tree: &'a IndexedMerkleTree<'a, H, I, HEIGHT>,
+    pub merkle_tree: &'a IndexedMerkleTree<H, I, HEIGHT>,
 }
 
 pub type IndexedMerkleTreeZeroCopy22<'a, H, I> = IndexedMerkleTreeZeroCopy<'a, H, I, 22>;
@@ -35,15 +27,7 @@ pub type IndexedMerkleTreeZeroCopy40<'a, H, I> = IndexedMerkleTreeZeroCopy<'a, H
 impl<'a, H, I, const HEIGHT: usize> IndexedMerkleTreeZeroCopy<'a, H, I, HEIGHT>
 where
     H: Hasher,
-    I: CheckedAdd
-        + CheckedSub
-        + Copy
-        + Clone
-        + PartialOrd
-        + ToBytes
-        + TryFrom<usize>
-        + Unsigned
-        + Pod,
+    I: CheckedAdd + CheckedSub + Copy + Clone + PartialOrd + ToBytes + TryFrom<usize> + Unsigned,
     usize: From<I>,
 {
     /// Casts a byte slice into wrapped `IndexedMerkleTree` structure reference,
@@ -64,7 +48,7 @@ where
     pub unsafe fn struct_from_bytes_zero_copy(
         bytes_struct: &'a [u8],
     ) -> Result<Self, IndexedMerkleTreeError> {
-        let expected_bytes_struct_size = mem::size_of::<IndexedMerkleTree<'a, H, I, HEIGHT>>();
+        let expected_bytes_struct_size = mem::size_of::<IndexedMerkleTree<H, I, HEIGHT>>();
         if bytes_struct.len() != expected_bytes_struct_size {
             return Err(IndexedMerkleTreeError::ConcurrentMerkleTree(
                 ConcurrentMerkleTreeError::StructBufferSize(
                     expected_bytes_struct_size,
                     bytes_struct.len(),
                 ),
             ));
         }
-        let tree: *const IndexedMerkleTree<'a, H, I, HEIGHT> = bytes_struct.as_ptr() as _;
+        let tree: *const IndexedMerkleTree<H, I, HEIGHT> = bytes_struct.as_ptr() as _;
 
         Ok(Self {
             merkle_tree: &*tree,
@@ -163,16 +147,15 @@ where
                 ),
             ));
         }
-        tree.merkle_tree.merkle_tree.roots =
-            ConcurrentMerkleTree::<'a, H, HEIGHT>::roots_from_bytes(
-                bytes_roots,
-                tree.merkle_tree.merkle_tree.roots.len(),
-                tree.merkle_tree.merkle_tree.roots.capacity(),
-                tree.merkle_tree.merkle_tree.roots.first_index(),
-                tree.merkle_tree.merkle_tree.roots.last_index(),
-            )?;
+        tree.merkle_tree.merkle_tree.roots = ConcurrentMerkleTree::<H, HEIGHT>::roots_from_bytes(
+            bytes_roots,
+            tree.merkle_tree.merkle_tree.roots.len(),
+            tree.merkle_tree.merkle_tree.roots.capacity(),
+            tree.merkle_tree.merkle_tree.roots.first_index(),
+            tree.merkle_tree.merkle_tree.roots.last_index(),
+        )?;
 
-        let canopy_size = ConcurrentMerkleTree::<'a, H, HEIGHT>::canopy_size(
+        let canopy_size = ConcurrentMerkleTree::<H, HEIGHT>::canopy_size(
             tree.merkle_tree.merkle_tree.canopy_depth,
         );
         let expected_canopy_size = mem::size_of::<[u8; 32]>() * canopy_size;
@@ -213,18 +196,10 @@ where
 pub struct IndexedMerkleTreeZeroCopyMut<'a, H, I, const HEIGHT: usize>
 where
     H: Hasher,
-    I: CheckedAdd
-        + CheckedSub
-        + Copy
-        + Clone
-        + PartialOrd
-        + ToBytes
-        + TryFrom<usize>
-        + Unsigned
-        + Pod,
+    I: CheckedAdd + CheckedSub + Copy + Clone + PartialOrd + ToBytes + TryFrom<usize> + Unsigned,
     usize: From<I>,
 {
-    pub merkle_tree: &'a mut IndexedMerkleTree<'a, H, I, HEIGHT>,
+    pub merkle_tree: &'a mut IndexedMerkleTree<H, I, HEIGHT>,
 }
 
 pub type IndexedMerkleTreeZeroCopyMut22<'a, H, I> = IndexedMerkleTreeZeroCopyMut<'a, H, I, 22>;
@@ -235,15 +210,7 @@ pub type IndexedMerkleTreeZeroCopyMut40<'a, H, I> = IndexedMerkleTreeZeroCopyMut
 impl<'a, H, I, const HEIGHT: usize> IndexedMerkleTreeZeroCopyMut<'a, H, I, HEIGHT>
 where
     H: Hasher,
-    I: CheckedAdd
-        + CheckedSub
-        + Copy
-        + Clone
-        + PartialOrd
-        + ToBytes
-        + TryFrom<usize>
-        + Unsigned
-        + Pod,
+    I: CheckedAdd + CheckedSub + Copy + Clone + PartialOrd + ToBytes + TryFrom<usize> + Unsigned,
     usize: From<I>,
 {
     /// Casts a byte slice into wrapped `IndexedMerkleTree` structure mutable
@@ -264,7 +231,7 @@ where
     pub unsafe fn struct_from_bytes_zero_copy_mut(
         bytes_struct: &'a [u8],
     ) -> Result<Self, IndexedMerkleTreeError> {
-        let expected_bytes_struct_size = mem::size_of::<IndexedMerkleTree<'a, H, I, HEIGHT>>();
+        let expected_bytes_struct_size = mem::size_of::<IndexedMerkleTree<H, I, HEIGHT>>();
         if bytes_struct.len() != expected_bytes_struct_size {
             return Err(IndexedMerkleTreeError::ConcurrentMerkleTree(
                 ConcurrentMerkleTreeError::StructBufferSize(
                     expected_bytes_struct_size,
                     bytes_struct.len(),
                 ),
             ));
         }
-        let tree: *mut IndexedMerkleTree<'a, H, I, HEIGHT> = bytes_struct.as_ptr() as _;
+        let tree: *mut IndexedMerkleTree<H, I, HEIGHT> = bytes_struct.as_ptr() as _;
 
         Ok(Self {
             merkle_tree: &mut *tree,
@@ -449,7 +416,7 @@ where
             tree.merkle_tree.merkle_tree.roots.capacity(),
             tree.merkle_tree.merkle_tree.roots.first_index(),
             tree.merkle_tree.merkle_tree.roots.last_index(),
-            ConcurrentMerkleTree::<'a, H, HEIGHT>::canopy_size(
+            ConcurrentMerkleTree::<H, HEIGHT>::canopy_size(
                 tree.merkle_tree.merkle_tree.canopy_depth,
             ),
             bytes_indexed_changelog,
@@ -471,7 +438,7 @@ mod test {
 
     #[test]
     fn test_from_bytes_zero_copy_init() {
-        let mut bytes_struct = [0u8; 320];
+        let mut bytes_struct = [0u8; 280];
         let mut bytes_filled_subtrees = [0u8; 832];
         let mut bytes_changelog = [0u8; 1220800];
         let mut bytes_roots = [0u8; 76800];
diff --git a/programs/account-compression/src/state/address.rs b/programs/account-compression/src/state/address.rs
index f1913f7211..3b9a8eb04b 100644
--- a/programs/account-compression/src/state/address.rs
+++ b/programs/account-compression/src/state/address.rs
@@ -116,7 +116,7 @@ pub struct AddressMerkleTreeAccount {
     pub owner: Pubkey,
     /// Delegate of the Merkle tree. This will be used for program owned Merkle trees.
     pub delegate: Pubkey,
-    pub merkle_tree_struct: [u8; 320],
+    pub merkle_tree_struct: [u8; 280],
     pub merkle_tree_filled_subtrees: [u8; 832],
     pub merkle_tree_changelog: [u8; 1220800],
     pub merkle_tree_roots: [u8; 76800],
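
The new buffer sizes in these account structs follow directly from the smaller vector headers. A back-of-envelope check on a 64-bit target (my arithmetic, not from the patch):

```rust
// Old BoundedVec header:   capacity + length + &mut [T] fat pointer
//                          = 8 + 8 + 16           = 32 bytes
// New BoundedVec header:   Vec<T> (ptr, cap, len) = 24 bytes -> saves 8
//
// Old CyclicBoundedVec:    32 + first_index + last_index = 48 bytes
// New CyclicBoundedVec:    24 + first_index + last_index = 40 bytes -> saves 8
//
// ConcurrentMerkleTree embeds 4 such vectors (filled_subtrees, changelog,
// roots, canopy):          272 - 4 * 8 = 240  (state tree struct)
// IndexedMerkleTree adds one more CyclicBoundedVec (indexed changelog):
//                          320 - 5 * 8 = 280  (address tree struct)
```
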
diff --git a/programs/account-compression/src/state/public_state_merkle_tree.rs b/programs/account-compression/src/state/public_state_merkle_tree.rs
index 20b9ee8e99..7602f0d1ba 100644
--- a/programs/account-compression/src/state/public_state_merkle_tree.rs
+++ b/programs/account-compression/src/state/public_state_merkle_tree.rs
@@ -4,7 +4,7 @@ use light_bounded_vec::CyclicBoundedVec;
 use light_concurrent_merkle_tree::ConcurrentMerkleTree26;
 use light_hasher::Poseidon;
 
-pub type StateMerkleTree<'a> = ConcurrentMerkleTree26<'a, Poseidon>;
+pub type StateMerkleTree = ConcurrentMerkleTree26<Poseidon>;
 
 /// Concurrent state Merkle tree used for public compressed transactions.
 #[account(zero_copy)]
@@ -33,7 +33,7 @@ pub struct StateMerkleTreeAccount {
     pub associated_queue: Pubkey,
 
     /// Merkle tree for the transaction state.
-    pub state_merkle_tree_struct: [u8; 272],
+    pub state_merkle_tree_struct: [u8; 240],
     pub state_merkle_tree_filled_subtrees: [u8; 832],
     pub state_merkle_tree_changelog: [u8; 1220800],
     pub state_merkle_tree_roots: [u8; 76800],
@@ -156,7 +156,7 @@ mod test {
             owner: Pubkey::new_from_array([2u8; 32]),
             delegate: Pubkey::new_from_array([3u8; 32]),
             associated_queue: Pubkey::new_from_array([4u8; 32]),
-            state_merkle_tree_struct: [0u8; 272],
+            state_merkle_tree_struct: [0u8; 240],
             state_merkle_tree_filled_subtrees: [0u8; 832],
             state_merkle_tree_changelog: [0u8; 1220800],
             state_merkle_tree_roots: [0u8; 76800],
@@ -176,7 +176,7 @@ mod test {
         }
 
         let root = merkle_tree.root().unwrap();
-        let merkle_tree_2 = account.load_merkle_tree().unwrap();
+        let merkle_tree_2 = account.copy_merkle_tree().unwrap();
         assert_eq!(root, merkle_tree_2.root().unwrap())
     }
 }

From f2f6963b8e7d62ab84a3f0bc89234a8910c31f31 Mon Sep 17 00:00:00 2001
From: Michal Rostecki
Date: Fri, 10 May 2024 14:14:58 +0200
Subject: [PATCH 2/2] test: Use `copy_merkle_tree()` instead of
 `load_merkle_tree()`

---
 .../tests/address_merkle_tree_tests.rs        | 42 ++++++++-----------
 1 file changed, 17 insertions(+), 25 deletions(-)

diff --git a/test-programs/account-compression-test/tests/address_merkle_tree_tests.rs b/test-programs/account-compression-test/tests/address_merkle_tree_tests.rs
index 9105bf49b1..4c8485acbb 100644
--- a/test-programs/account-compression-test/tests/address_merkle_tree_tests.rs
+++ b/test-programs/account-compression-test/tests/address_merkle_tree_tests.rs
@@ -156,9 +156,9 @@ async fn update_merkle_tree(
 
         let address_merkle_tree = &address_merkle_tree
             .deserialized()
-            .load_merkle_tree()
+            .copy_merkle_tree()
             .unwrap();
-        let changelog_index = address_merkle_tree.merkle_tree.changelog_index();
+        let changelog_index = address_merkle_tree.0.changelog_index();
         changelog_index
     };
@@ -224,12 +224,10 @@ async fn relayer_update(
         AccountZeroCopy::<AddressMerkleTreeAccount>::new(context, address_merkle_tree_pubkey)
             .await;
     let mut address_merkle_tree_deserialized = address_merkle_tree.deserialized().clone();
-    let address_merkle_tree = address_merkle_tree_deserialized
-        .load_merkle_tree_mut()
-        .unwrap();
+    let address_merkle_tree = address_merkle_tree_deserialized.copy_merkle_tree().unwrap();
     assert_eq!(
         relayer_merkle_tree.root(),
-        address_merkle_tree.merkle_tree.root().unwrap()
+        address_merkle_tree.0.root().unwrap()
     );
     let address_queue = unsafe {
         get_hash_set::(context, address_queue_pubkey).await
     };
@@ -261,16 +259,16 @@ async fn relayer_update(
         for i in 0..16 {
             bounded_vec.push(array[i]).unwrap();
         }
-        address_merkle_tree
-            .merkle_tree
-            .update(
-                address_merkle_tree.merkle_tree.changelog_index(),
-                address_bundle.new_element.clone(),
-                old_low_address.clone(),
-                old_low_address_next_value.clone(),
-                &mut bounded_vec,
-            )
-            .unwrap();
+        // address_merkle_tree
+        //     .merkle_tree
+        //     .update(
+        //         address_merkle_tree.merkle_tree.changelog_index(),
+        //         address_bundle.new_element.clone(),
+        //         old_low_address.clone(),
+        //         old_low_address_next_value.clone(),
+        //         &mut bounded_vec,
+        //     )
+        //     .unwrap();
 
         // Update on-chain tree.
         let update_successful = match update_merkle_tree(
@@ -365,7 +363,7 @@ async fn test_address_queue() {
             .await;
         let address_merkle_tree = &address_merkle_tree
             .deserialized()
-            .load_merkle_tree()
+            .copy_merkle_tree()
             .unwrap();
 
         let address_queue = unsafe {
             get_hash_set::(context, address_queue_pubkey).await
         };
 
         assert_eq!(
             address_queue
-                .contains(
-                    &address1,
-                    address_merkle_tree.merkle_tree.merkle_tree.sequence_number
-                )
+                .contains(&address1, address_merkle_tree.0.merkle_tree.sequence_number)
                 .unwrap(),
             true
         );
         assert_eq!(
             address_queue
-                .contains(
-                    &address2,
-                    address_merkle_tree.merkle_tree.merkle_tree.sequence_number
-                )
+                .contains(&address2, address_merkle_tree.0.merkle_tree.sequence_number)
                 .unwrap(),
             true
         );
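
For completeness, the `ArraySize` behaviour exercised by `test_bounded_vec_to_array` in patch 1, as a reviewer-runnable sketch (not part of either patch):

```rust
use light_bounded_vec::{BoundedVec, BoundedVecError};

fn main() {
    let vec = BoundedVec::from_array(&[1u8, 2, 3, 4]);

    // Converting succeeds only when the requested array length matches len().
    let arr: [u8; 4] = vec.to_array().unwrap();
    assert_eq!(arr, [1, 2, 3, 4]);

    // Any other length reports ArraySize(requested, actual).
    assert!(matches!(
        vec.to_array::<3>(),
        Err(BoundedVecError::ArraySize(3, 4))
    ));
}
```
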