From 2acc67cb158464d532ac9c972008506e3085e5ca Mon Sep 17 00:00:00 2001 From: ananas-block Date: Thu, 9 Jan 2025 20:49:00 +0000 Subject: [PATCH] cleanup --- .../batched-merkle-tree/src/errors.rs | 9 ++ .../src/initialize_address_tree.rs | 5 +- .../src/initialize_state_tree.rs | 3 +- .../batched-merkle-tree/src/merkle_tree.rs | 53 ++++++----- program-libs/batched-merkle-tree/src/queue.rs | 94 +++---------------- 5 files changed, 54 insertions(+), 110 deletions(-) diff --git a/program-libs/batched-merkle-tree/src/errors.rs b/program-libs/batched-merkle-tree/src/errors.rs index 44e10eff3..4ae53ff25 100644 --- a/program-libs/batched-merkle-tree/src/errors.rs +++ b/program-libs/batched-merkle-tree/src/errors.rs @@ -37,6 +37,12 @@ pub enum BatchedMerkleTreeError { ProgramError(#[from] ProgramError), #[error("Verifier error {0}")] VerifierErrorError(#[from] VerifierError), + #[error("Zero copy cast error {0}")] + ZeroCopyCastError(String), + #[error("Invalid batch index")] + InvalidBatchIndex, + #[error("Invalid index")] + InvalidIndex, } #[cfg(feature = "solana")] @@ -50,6 +56,9 @@ impl From for u32 { BatchedMerkleTreeError::InvalidNetworkFee => 14305, BatchedMerkleTreeError::BatchSizeNotDivisibleByZkpBatchSize => 14306, BatchedMerkleTreeError::InclusionProofByIndexFailed => 14307, + BatchedMerkleTreeError::ZeroCopyCastError(_) => 14308, + BatchedMerkleTreeError::InvalidBatchIndex => 14309, + BatchedMerkleTreeError::InvalidIndex => 14310, BatchedMerkleTreeError::Hasher(e) => e.into(), BatchedMerkleTreeError::ZeroCopy(e) => e.into(), BatchedMerkleTreeError::MerkleTreeMetadata(e) => e.into(), diff --git a/program-libs/batched-merkle-tree/src/initialize_address_tree.rs b/program-libs/batched-merkle-tree/src/initialize_address_tree.rs index 0391778fa..276a2b28d 100644 --- a/program-libs/batched-merkle-tree/src/initialize_address_tree.rs +++ b/program-libs/batched-merkle-tree/src/initialize_address_tree.rs @@ -131,10 +131,7 @@ pub fn init_batched_address_merkle_tree_account( ) -> Result, BatchedMerkleTreeError> { let num_batches_input_queue = params.input_queue_num_batches; let height = params.height; - // let (discriminator, mt_account_data) = mt_account_data.split_at_mut(DISCRIMINATOR_LEN); - // let account_data_len = mt_account_data.len(); - // println!("account_data_len {:?}", account_data_len); - // set_discriminator::>(discriminator)?; + let rollover_fee = match params.rollover_threshold { Some(rollover_threshold) => { let rent = merkle_tree_rent; diff --git a/program-libs/batched-merkle-tree/src/initialize_state_tree.rs b/program-libs/batched-merkle-tree/src/initialize_state_tree.rs index eb1688977..dc358aad7 100644 --- a/program-libs/batched-merkle-tree/src/initialize_state_tree.rs +++ b/program-libs/batched-merkle-tree/src/initialize_state_tree.rs @@ -208,7 +208,7 @@ pub fn init_batched_state_merkle_tree_accounts<'a>( associated_merkle_tree: mt_pubkey, }; - let batched_queue_account = BatchedQueueAccount::init( + BatchedQueueAccount::init( output_queue_account_data, metadata, num_batches_output_queue, @@ -217,7 +217,6 @@ pub fn init_batched_state_merkle_tree_accounts<'a>( 0, 0, )?; - println!("batched_queue_account {:?}", batched_queue_account.batches); } let metadata = MerkleTreeMetadata { next_merkle_tree: Pubkey::default(), diff --git a/program-libs/batched-merkle-tree/src/merkle_tree.rs b/program-libs/batched-merkle-tree/src/merkle_tree.rs index 9260c52ff..c76a592ca 100644 --- a/program-libs/batched-merkle-tree/src/merkle_tree.rs +++ b/program-libs/batched-merkle-tree/src/merkle_tree.rs @@ -244,7 +244,6 @@ pub struct BatchedMerkleTreeAccount<'a> { pub value_vecs: Vec>, pub bloom_filter_stores: Vec>, pub hashchain_store: Vec>, - phantom: std::marker::PhantomData<&'a ()>, } impl Deref for BatchedMerkleTreeAccount<'_> { @@ -339,10 +338,9 @@ impl<'a> BatchedMerkleTreeAccount<'a> { check_account_info_mut::(program_id, account_info)?; let mut data = account_info.try_borrow_mut_data()?; + // Necessary to convince the borrow checker. let data_slice: &'a mut [u8] = unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr(), data.len()) }; - - // Pass the mutable slice to the function Self::from_bytes_mut::(data_slice) } @@ -360,14 +358,15 @@ impl<'a> BatchedMerkleTreeAccount<'a> { let (discriminator, account_data) = account_data.split_at_mut(DISCRIMINATOR_LEN); check_discriminator::(discriminator)?; let (metadata, account_data) = - Ref::<&'a mut [u8], BatchedMerkleTreeMetadata>::from_prefix(account_data).unwrap(); + Ref::<&'a mut [u8], BatchedMerkleTreeMetadata>::from_prefix(account_data) + .map_err(|e| BatchedMerkleTreeError::ZeroCopyCastError(e.to_string()))?; if metadata.tree_type != TREE_TYPE { return Err(MerkleTreeMetadataError::InvalidTreeType.into()); } if account_data_len != metadata.get_account_size()? { return Err(ZeroCopyError::InvalidAccountSize.into()); } - // let mut start_offset = BatchedMerkleTreeMetadata::LEN; + let (root_history, account_data) = ZeroCopyCyclicVecU64::from_bytes_at(account_data)?; let (batches, value_vecs, bloom_filter_stores, hashchain_store) = input_queue_bytes( &metadata.queue_metadata, @@ -382,7 +381,6 @@ impl<'a> BatchedMerkleTreeAccount<'a> { value_vecs, bloom_filter_stores, hashchain_store, - phantom: std::marker::PhantomData, }) } @@ -404,7 +402,8 @@ impl<'a> BatchedMerkleTreeAccount<'a> { set_discriminator::(discriminator)?; let (mut account_metadata, account_data) = - Ref::<&'a mut [u8], BatchedMerkleTreeMetadata>::from_prefix(account_data).unwrap(); + Ref::<&'a mut [u8], BatchedMerkleTreeMetadata>::from_prefix(account_data) + .map_err(|e| BatchedMerkleTreeError::ZeroCopyCastError(e.to_string()))?; account_metadata.metadata = metadata; account_metadata.root_history_capacity = root_history_capacity; account_metadata.height = height; @@ -424,7 +423,6 @@ impl<'a> BatchedMerkleTreeAccount<'a> { ); return Err(ZeroCopyError::InvalidAccountSize.into()); } - // let mut start_offset = BatchedMerkleTreeMetadata::LEN; let (mut root_history, account_data) = ZeroCopyCyclicVecU64::new_at( account_metadata.root_history_capacity as u64, @@ -452,7 +450,6 @@ impl<'a> BatchedMerkleTreeAccount<'a> { value_vecs, bloom_filter_stores, hashchain_store, - phantom: std::marker::PhantomData, }) } @@ -482,7 +479,9 @@ impl<'a> BatchedMerkleTreeAccount<'a> { let batch_index = queue_account.batch_metadata.next_full_batch_index; let circuit_batch_size = queue_account.get_metadata().batch_metadata.zkp_batch_size; let batches = &mut queue_account.batches; - let full_batch = batches.get_mut(batch_index as usize).unwrap(); + let full_batch = batches + .get_mut(batch_index as usize) + .ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?; let new_root = instruction_data.public_inputs.new_root; let num_zkps = full_batch.get_first_ready_zkp_batch()?; @@ -490,10 +489,13 @@ impl<'a> BatchedMerkleTreeAccount<'a> { let leaves_hashchain = queue_account .hashchain_store .get(batch_index as usize) - .unwrap() + .ok_or(BatchedMerkleTreeError::InvalidBatchIndex)? .get(num_zkps as usize) - .unwrap(); - let old_root = self.root_history.last().unwrap(); + .ok_or(BatchedMerkleTreeError::InvalidIndex)?; + let old_root = self + .root_history + .last() + .ok_or(BatchedMerkleTreeError::InvalidIndex)?; let start_index = self.get_metadata().next_index; let mut start_index_bytes = [0u8; 32]; start_index_bytes[24..].copy_from_slice(&start_index.to_be_bytes()); @@ -567,20 +569,23 @@ impl<'a> BatchedMerkleTreeAccount<'a> { ) -> Result { let batch_index = self.get_metadata().queue_metadata.next_full_batch_index; - let full_batch = self.batches.get(batch_index as usize).unwrap(); + let full_batch = self + .batches + .get(batch_index as usize) + .ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?; let num_zkps = full_batch.get_first_ready_zkp_batch()?; let leaves_hashchain = self .hashchain_store .get(batch_index as usize) - .unwrap() + .ok_or(BatchedMerkleTreeError::InvalidBatchIndex)? .get(num_zkps as usize) - .unwrap(); + .ok_or(BatchedMerkleTreeError::InvalidIndex)?; let old_root = self .root_history .get(instruction_data.public_inputs.old_root_index as usize) - .unwrap(); + .ok_or(BatchedMerkleTreeError::InvalidIndex)?; let new_root = instruction_data.public_inputs.new_root; let public_input_hash = if QUEUE_TYPE == QueueType::Input as u64 { @@ -608,7 +613,10 @@ impl<'a> BatchedMerkleTreeAccount<'a> { let root_history_capacity = self.get_metadata().root_history_capacity; let sequence_number = self.get_metadata().sequence_number; - let full_batch = self.batches.get_mut(batch_index as usize).unwrap(); + let full_batch = self + .batches + .get_mut(batch_index as usize) + .ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?; full_batch.mark_as_inserted_in_merkle_tree( sequence_number, self.root_history.last_index() as u32, @@ -791,9 +799,12 @@ impl<'a> BatchedMerkleTreeAccount<'a> { let num_inserted_elements = self .batches .get(current_batch as usize) - .unwrap() + .ok_or(BatchedMerkleTreeError::InvalidBatchIndex)? .get_num_inserted_elements(); - let previous_full_batch = self.batches.get_mut(previous_full_batch_index).unwrap(); + let previous_full_batch = self + .batches + .get_mut(previous_full_batch_index) + .ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?; if previous_full_batch.get_state() == BatchState::Inserted && batch_size / 2 > num_inserted_elements && !previous_full_batch.bloom_filter_is_wiped() @@ -801,7 +812,7 @@ impl<'a> BatchedMerkleTreeAccount<'a> { let bloom_filter = self .bloom_filter_stores .get_mut(previous_full_batch_index) - .unwrap(); + .ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?; bloom_filter.as_mut_slice().iter_mut().for_each(|x| *x = 0); previous_full_batch.set_bloom_filter_is_wiped(); let seq = previous_full_batch.sequence_number; diff --git a/program-libs/batched-merkle-tree/src/queue.rs b/program-libs/batched-merkle-tree/src/queue.rs index 5810db842..83f40f43f 100644 --- a/program-libs/batched-merkle-tree/src/queue.rs +++ b/program-libs/batched-merkle-tree/src/queue.rs @@ -159,7 +159,6 @@ pub struct BatchedQueueAccount<'a> { pub bloom_filter_stores: Vec>, /// hashchain_store_capacity = batch_capacity / zkp_batch_size pub hashchain_store: Vec>, - marker: std::marker::PhantomData<&'a ()>, } impl Deref for BatchedQueueAccount<'_> { @@ -202,10 +201,10 @@ impl<'a> BatchedQueueAccount<'a> { ) -> Result, BatchedMerkleTreeError> { check_account_info_mut::(program_id, account_info)?; let account_data = &mut account_info.try_borrow_mut_data()?; + // Necessary to convince the borrow checker. let account_data: &'a mut [u8] = unsafe { std::slice::from_raw_parts_mut(account_data.as_mut_ptr(), account_data.len()) }; - Self::internal_from_bytes_mut::(account_data) } @@ -228,9 +227,10 @@ impl<'a> BatchedQueueAccount<'a> { ) -> Result, BatchedMerkleTreeError> { let (discriminator, account_data) = account_data.split_at_mut(DISCRIMINATOR_LEN); check_discriminator::(discriminator)?; - // TODO: remove unwrap + let (metadata, account_data) = - Ref::<&'a mut [u8], BatchedQueueMetadata>::from_prefix(account_data).unwrap(); + Ref::<&'a mut [u8], BatchedQueueMetadata>::from_prefix(account_data) + .map_err(|e| BatchedMerkleTreeError::ZeroCopyCastError(e.to_string()))?; if metadata.metadata.queue_type != QUEUE_TYPE { return Err(MerkleTreeMetadataError::InvalidQueueType.into()); @@ -250,7 +250,6 @@ impl<'a> BatchedQueueAccount<'a> { value_vecs, bloom_filter_stores, hashchain_store, - marker: std::marker::PhantomData, }) } @@ -268,7 +267,8 @@ impl<'a> BatchedQueueAccount<'a> { set_discriminator::(discriminator)?; let (mut account_metadata, account_data) = - Ref::<&mut [u8], BatchedQueueMetadata>::from_prefix(account_data).unwrap(); + Ref::<&mut [u8], BatchedQueueMetadata>::from_prefix(account_data) + .map_err(|e| BatchedMerkleTreeError::ZeroCopyCastError(e.to_string()))?; account_metadata.init( metadata, @@ -308,7 +308,6 @@ impl<'a> BatchedQueueAccount<'a> { value_vecs, bloom_filter_stores, hashchain_store, - marker: std::marker::PhantomData, }) } @@ -435,7 +434,9 @@ pub fn insert_into_current_batch( let mut value_store = value_vecs.get_mut(currently_processing_batch_index); let mut hashchain_store = hashchain_store.get_mut(currently_processing_batch_index); - let current_batch = batches.get_mut(currently_processing_batch_index).unwrap(); + let current_batch = batches + .get_mut(currently_processing_batch_index) + .ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?; let mut wipe = false; if current_batch.get_state() == BatchState::Inserted { current_batch.advance_state_to_can_be_filled()?; @@ -556,40 +557,12 @@ pub fn input_queue_bytes<'a>( > { let (num_value_stores, num_stores, hashchain_store_capacity) = account.get_size_parameters(queue_type)?; - // if queue_type == QueueType::Output as u64 { - // *start_offset += BatchedQueueMetadata::LEN; - // } + let (batches, account_data) = ZeroCopySliceMutU64::from_bytes_at(account_data)?; let (value_vecs, account_data) = ZeroCopyVecU64::from_bytes_at_multiple(num_value_stores, account_data)?; - // let mut bloom_filter_stores = Vec::with_capacity(num_stores); let (bloom_filter_stores, account_data) = ZeroCopySliceMutU64::from_bytes_at_multiple(num_stores, account_data)?; - // for _ in 0..num_stores { - - // } - println!( - "account.bloom_filter_capacity {:?}", - account.bloom_filter_capacity / 8 - ); - // let account_data = &mut account_data[*start_offset..]; - // let (bloom_filter_store, account_data) = Ref::<&'a mut [u8], [u8]>::from_prefix_with_elems( - // account_data, - // (account.bloom_filter_capacity / 8) as usize, - // ) - // .unwrap(); - // // *start_offset += (account.bloom_filter_capacity / 8) as usize; - // bloom_filter_stores.push(bloom_filter_store); - // let (bloom_filter_store, account_data) = Ref::<&'a mut [u8], [u8]>::from_prefix_with_elems( - // account_data, - // (account.bloom_filter_capacity / 8) as usize, - // ) - // .unwrap(); - // // *start_offset += (account.bloom_filter_capacity / 8) as usize; - // bloom_filter_stores.push(bloom_filter_store); - - // // reset start_offset to 0 - // *start_offset = 0; let (hashchain_store, _) = ZeroCopyVecU64::from_bytes_at_multiple(hashchain_store_capacity, account_data)?; @@ -617,10 +590,6 @@ pub fn init_queue<'a>( let (num_value_stores, num_stores, num_hashchain_stores) = account.get_size_parameters(queue_type)?; - // if queue_type == QueueType::Output as u64 { - // *start_offset += BatchedQueueMetadata::LEN; - // } - let (mut batches, account_data) = ZeroCopySliceMutU64::new_at(account.num_batches, account_data)?; @@ -641,48 +610,7 @@ pub fn init_queue<'a>( account.bloom_filter_capacity / 8, account_data, )?; - // ZeroCopySliceMutU64::new_at_multiple( - // num_stores, - // account.bloom_filter_capacity as usize / 8, - // account_data, - // start_offset, - // )?; - // if num_stores == 2 { - // let account_data = &mut account_data[*start_offset..]; - // let (bloom_filter_store, account_data) = Ref::<&mut [u8], [u8]>::from_prefix_with_elems( - // account_data, - // (account.bloom_filter_capacity / 8) as usize, - // ) - // .unwrap(); - // *start_offset += (account.bloom_filter_capacity / 8) as usize; - // bloom_filter_stores.push(bloom_filter_store); - // let (bloom_filter_store, account_data) = Ref::<&mut [u8], [u8]>::from_prefix_with_elems( - // account_data, - // (account.bloom_filter_capacity / 8) as usize, - // ) - // .unwrap(); - // *start_offset += (account.bloom_filter_capacity / 8) as usize; - // bloom_filter_stores.push(bloom_filter_store); - // *start_offset = 0; - // let hashchain_store = ZeroCopyVecU64::new_at_multiple( - // num_hashchain_stores, - // account.get_num_zkp_batches() as usize, - // account_data, - // start_offset, - // )?; - - // Ok((batches, value_vecs, bloom_filter_stores, hashchain_store)) - // } else { - // assert_eq!(num_stores, 0); - // let hashchain_store = ZeroCopyVecU64::new_at_multiple( - // num_hashchain_stores, - // account.get_num_zkp_batches() as usize, - // account_data, - // start_offset, - // )?; - - // Ok((batches, value_vecs, bloom_filter_stores, hashchain_store)) - // } + let (hashchain_store, _) = ZeroCopyVecU64::new_at_multiple( num_hashchain_stores, account.get_num_zkp_batches(),