Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
ananas-block committed Jan 9, 2025
1 parent ea575c1 commit 2acc67c
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 110 deletions.
9 changes: 9 additions & 0 deletions program-libs/batched-merkle-tree/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ pub enum BatchedMerkleTreeError {
ProgramError(#[from] ProgramError),
#[error("Verifier error {0}")]
VerifierErrorError(#[from] VerifierError),
#[error("Zero copy cast error {0}")]
ZeroCopyCastError(String),
#[error("Invalid batch index")]
InvalidBatchIndex,
#[error("Invalid index")]
InvalidIndex,
}

#[cfg(feature = "solana")]
Expand All @@ -50,6 +56,9 @@ impl From<BatchedMerkleTreeError> for u32 {
BatchedMerkleTreeError::InvalidNetworkFee => 14305,
BatchedMerkleTreeError::BatchSizeNotDivisibleByZkpBatchSize => 14306,
BatchedMerkleTreeError::InclusionProofByIndexFailed => 14307,
BatchedMerkleTreeError::ZeroCopyCastError(_) => 14308,
BatchedMerkleTreeError::InvalidBatchIndex => 14309,
BatchedMerkleTreeError::InvalidIndex => 14310,
BatchedMerkleTreeError::Hasher(e) => e.into(),
BatchedMerkleTreeError::ZeroCopy(e) => e.into(),
BatchedMerkleTreeError::MerkleTreeMetadata(e) => e.into(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,7 @@ pub fn init_batched_address_merkle_tree_account(
) -> Result<BatchedMerkleTreeAccount<'_>, BatchedMerkleTreeError> {
let num_batches_input_queue = params.input_queue_num_batches;
let height = params.height;
// let (discriminator, mt_account_data) = mt_account_data.split_at_mut(DISCRIMINATOR_LEN);
// let account_data_len = mt_account_data.len();
// println!("account_data_len {:?}", account_data_len);
// set_discriminator::<BatchedMerkleTreeAccount<'_>>(discriminator)?;

let rollover_fee = match params.rollover_threshold {
Some(rollover_threshold) => {
let rent = merkle_tree_rent;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ pub fn init_batched_state_merkle_tree_accounts<'a>(
associated_merkle_tree: mt_pubkey,
};

let batched_queue_account = BatchedQueueAccount::init(
BatchedQueueAccount::init(
output_queue_account_data,
metadata,
num_batches_output_queue,
Expand All @@ -217,7 +217,6 @@ pub fn init_batched_state_merkle_tree_accounts<'a>(
0,
0,
)?;
println!("batched_queue_account {:?}", batched_queue_account.batches);
}
let metadata = MerkleTreeMetadata {
next_merkle_tree: Pubkey::default(),
Expand Down
53 changes: 32 additions & 21 deletions program-libs/batched-merkle-tree/src/merkle_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,6 @@ pub struct BatchedMerkleTreeAccount<'a> {
pub value_vecs: Vec<ZeroCopyVecU64<'a, [u8; 32]>>,
pub bloom_filter_stores: Vec<ZeroCopySliceMutU64<'a, u8>>,
pub hashchain_store: Vec<ZeroCopyVecU64<'a, [u8; 32]>>,
phantom: std::marker::PhantomData<&'a ()>,
}

impl Deref for BatchedMerkleTreeAccount<'_> {
Expand Down Expand Up @@ -339,10 +338,9 @@ impl<'a> BatchedMerkleTreeAccount<'a> {
check_account_info_mut::<Self>(program_id, account_info)?;
let mut data = account_info.try_borrow_mut_data()?;

// Necessary to convince the borrow checker.
let data_slice: &'a mut [u8] =
unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr(), data.len()) };

// Pass the mutable slice to the function
Self::from_bytes_mut::<TREE_TYPE>(data_slice)
}

Expand All @@ -360,14 +358,15 @@ impl<'a> BatchedMerkleTreeAccount<'a> {
let (discriminator, account_data) = account_data.split_at_mut(DISCRIMINATOR_LEN);
check_discriminator::<Self>(discriminator)?;
let (metadata, account_data) =
Ref::<&'a mut [u8], BatchedMerkleTreeMetadata>::from_prefix(account_data).unwrap();
Ref::<&'a mut [u8], BatchedMerkleTreeMetadata>::from_prefix(account_data)
.map_err(|e| BatchedMerkleTreeError::ZeroCopyCastError(e.to_string()))?;
if metadata.tree_type != TREE_TYPE {
return Err(MerkleTreeMetadataError::InvalidTreeType.into());
}
if account_data_len != metadata.get_account_size()? {
return Err(ZeroCopyError::InvalidAccountSize.into());
}
// let mut start_offset = BatchedMerkleTreeMetadata::LEN;

let (root_history, account_data) = ZeroCopyCyclicVecU64::from_bytes_at(account_data)?;
let (batches, value_vecs, bloom_filter_stores, hashchain_store) = input_queue_bytes(
&metadata.queue_metadata,
Expand All @@ -382,7 +381,6 @@ impl<'a> BatchedMerkleTreeAccount<'a> {
value_vecs,
bloom_filter_stores,
hashchain_store,
phantom: std::marker::PhantomData,
})
}

Expand All @@ -404,7 +402,8 @@ impl<'a> BatchedMerkleTreeAccount<'a> {
set_discriminator::<Self>(discriminator)?;

let (mut account_metadata, account_data) =
Ref::<&'a mut [u8], BatchedMerkleTreeMetadata>::from_prefix(account_data).unwrap();
Ref::<&'a mut [u8], BatchedMerkleTreeMetadata>::from_prefix(account_data)
.map_err(|e| BatchedMerkleTreeError::ZeroCopyCastError(e.to_string()))?;
account_metadata.metadata = metadata;
account_metadata.root_history_capacity = root_history_capacity;
account_metadata.height = height;
Expand All @@ -424,7 +423,6 @@ impl<'a> BatchedMerkleTreeAccount<'a> {
);
return Err(ZeroCopyError::InvalidAccountSize.into());
}
// let mut start_offset = BatchedMerkleTreeMetadata::LEN;

let (mut root_history, account_data) = ZeroCopyCyclicVecU64::new_at(
account_metadata.root_history_capacity as u64,
Expand Down Expand Up @@ -452,7 +450,6 @@ impl<'a> BatchedMerkleTreeAccount<'a> {
value_vecs,
bloom_filter_stores,
hashchain_store,
phantom: std::marker::PhantomData,
})
}

Expand Down Expand Up @@ -482,18 +479,23 @@ impl<'a> BatchedMerkleTreeAccount<'a> {
let batch_index = queue_account.batch_metadata.next_full_batch_index;
let circuit_batch_size = queue_account.get_metadata().batch_metadata.zkp_batch_size;
let batches = &mut queue_account.batches;
let full_batch = batches.get_mut(batch_index as usize).unwrap();
let full_batch = batches
.get_mut(batch_index as usize)
.ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?;

let new_root = instruction_data.public_inputs.new_root;
let num_zkps = full_batch.get_first_ready_zkp_batch()?;

let leaves_hashchain = queue_account
.hashchain_store
.get(batch_index as usize)
.unwrap()
.ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?
.get(num_zkps as usize)
.unwrap();
let old_root = self.root_history.last().unwrap();
.ok_or(BatchedMerkleTreeError::InvalidIndex)?;
let old_root = self
.root_history
.last()
.ok_or(BatchedMerkleTreeError::InvalidIndex)?;
let start_index = self.get_metadata().next_index;
let mut start_index_bytes = [0u8; 32];
start_index_bytes[24..].copy_from_slice(&start_index.to_be_bytes());
Expand Down Expand Up @@ -567,20 +569,23 @@ impl<'a> BatchedMerkleTreeAccount<'a> {
) -> Result<BatchNullifyEvent, BatchedMerkleTreeError> {
let batch_index = self.get_metadata().queue_metadata.next_full_batch_index;

let full_batch = self.batches.get(batch_index as usize).unwrap();
let full_batch = self
.batches
.get(batch_index as usize)
.ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?;

let num_zkps = full_batch.get_first_ready_zkp_batch()?;

let leaves_hashchain = self
.hashchain_store
.get(batch_index as usize)
.unwrap()
.ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?
.get(num_zkps as usize)
.unwrap();
.ok_or(BatchedMerkleTreeError::InvalidIndex)?;
let old_root = self
.root_history
.get(instruction_data.public_inputs.old_root_index as usize)
.unwrap();
.ok_or(BatchedMerkleTreeError::InvalidIndex)?;
let new_root = instruction_data.public_inputs.new_root;

let public_input_hash = if QUEUE_TYPE == QueueType::Input as u64 {
Expand Down Expand Up @@ -608,7 +613,10 @@ impl<'a> BatchedMerkleTreeAccount<'a> {

let root_history_capacity = self.get_metadata().root_history_capacity;
let sequence_number = self.get_metadata().sequence_number;
let full_batch = self.batches.get_mut(batch_index as usize).unwrap();
let full_batch = self
.batches
.get_mut(batch_index as usize)
.ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?;
full_batch.mark_as_inserted_in_merkle_tree(
sequence_number,
self.root_history.last_index() as u32,
Expand Down Expand Up @@ -791,17 +799,20 @@ impl<'a> BatchedMerkleTreeAccount<'a> {
let num_inserted_elements = self
.batches
.get(current_batch as usize)
.unwrap()
.ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?
.get_num_inserted_elements();
let previous_full_batch = self.batches.get_mut(previous_full_batch_index).unwrap();
let previous_full_batch = self
.batches
.get_mut(previous_full_batch_index)
.ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?;
if previous_full_batch.get_state() == BatchState::Inserted
&& batch_size / 2 > num_inserted_elements
&& !previous_full_batch.bloom_filter_is_wiped()
{
let bloom_filter = self
.bloom_filter_stores
.get_mut(previous_full_batch_index)
.unwrap();
.ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?;
bloom_filter.as_mut_slice().iter_mut().for_each(|x| *x = 0);
previous_full_batch.set_bloom_filter_is_wiped();
let seq = previous_full_batch.sequence_number;
Expand Down
94 changes: 11 additions & 83 deletions program-libs/batched-merkle-tree/src/queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ pub struct BatchedQueueAccount<'a> {
pub bloom_filter_stores: Vec<ZeroCopySliceMutU64<'a, u8>>,
/// hashchain_store_capacity = batch_capacity / zkp_batch_size
pub hashchain_store: Vec<ZeroCopyVecU64<'a, [u8; 32]>>,
marker: std::marker::PhantomData<&'a ()>,
}

impl Deref for BatchedQueueAccount<'_> {
Expand Down Expand Up @@ -202,10 +201,10 @@ impl<'a> BatchedQueueAccount<'a> {
) -> Result<BatchedQueueAccount<'a>, BatchedMerkleTreeError> {
check_account_info_mut::<Self>(program_id, account_info)?;
let account_data = &mut account_info.try_borrow_mut_data()?;
// Necessary to convince the borrow checker.
let account_data: &'a mut [u8] = unsafe {
std::slice::from_raw_parts_mut(account_data.as_mut_ptr(), account_data.len())
};

Self::internal_from_bytes_mut::<OUTPUT_QUEUE_TYPE>(account_data)
}

Expand All @@ -228,9 +227,10 @@ impl<'a> BatchedQueueAccount<'a> {
) -> Result<BatchedQueueAccount<'a>, BatchedMerkleTreeError> {
let (discriminator, account_data) = account_data.split_at_mut(DISCRIMINATOR_LEN);
check_discriminator::<BatchedQueueAccount>(discriminator)?;
// TODO: remove unwrap

let (metadata, account_data) =
Ref::<&'a mut [u8], BatchedQueueMetadata>::from_prefix(account_data).unwrap();
Ref::<&'a mut [u8], BatchedQueueMetadata>::from_prefix(account_data)
.map_err(|e| BatchedMerkleTreeError::ZeroCopyCastError(e.to_string()))?;

if metadata.metadata.queue_type != QUEUE_TYPE {
return Err(MerkleTreeMetadataError::InvalidQueueType.into());
Expand All @@ -250,7 +250,6 @@ impl<'a> BatchedQueueAccount<'a> {
value_vecs,
bloom_filter_stores,
hashchain_store,
marker: std::marker::PhantomData,
})
}

Expand All @@ -268,7 +267,8 @@ impl<'a> BatchedQueueAccount<'a> {
set_discriminator::<Self>(discriminator)?;

let (mut account_metadata, account_data) =
Ref::<&mut [u8], BatchedQueueMetadata>::from_prefix(account_data).unwrap();
Ref::<&mut [u8], BatchedQueueMetadata>::from_prefix(account_data)
.map_err(|e| BatchedMerkleTreeError::ZeroCopyCastError(e.to_string()))?;

account_metadata.init(
metadata,
Expand Down Expand Up @@ -308,7 +308,6 @@ impl<'a> BatchedQueueAccount<'a> {
value_vecs,
bloom_filter_stores,
hashchain_store,
marker: std::marker::PhantomData,
})
}

Expand Down Expand Up @@ -435,7 +434,9 @@ pub fn insert_into_current_batch(
let mut value_store = value_vecs.get_mut(currently_processing_batch_index);
let mut hashchain_store = hashchain_store.get_mut(currently_processing_batch_index);

let current_batch = batches.get_mut(currently_processing_batch_index).unwrap();
let current_batch = batches
.get_mut(currently_processing_batch_index)
.ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?;
let mut wipe = false;
if current_batch.get_state() == BatchState::Inserted {
current_batch.advance_state_to_can_be_filled()?;
Expand Down Expand Up @@ -556,40 +557,12 @@ pub fn input_queue_bytes<'a>(
> {
let (num_value_stores, num_stores, hashchain_store_capacity) =
account.get_size_parameters(queue_type)?;
// if queue_type == QueueType::Output as u64 {
// *start_offset += BatchedQueueMetadata::LEN;
// }

let (batches, account_data) = ZeroCopySliceMutU64::from_bytes_at(account_data)?;
let (value_vecs, account_data) =
ZeroCopyVecU64::from_bytes_at_multiple(num_value_stores, account_data)?;
// let mut bloom_filter_stores = Vec::with_capacity(num_stores);
let (bloom_filter_stores, account_data) =
ZeroCopySliceMutU64::from_bytes_at_multiple(num_stores, account_data)?;
// for _ in 0..num_stores {

// }
println!(
"account.bloom_filter_capacity {:?}",
account.bloom_filter_capacity / 8
);
// let account_data = &mut account_data[*start_offset..];
// let (bloom_filter_store, account_data) = Ref::<&'a mut [u8], [u8]>::from_prefix_with_elems(
// account_data,
// (account.bloom_filter_capacity / 8) as usize,
// )
// .unwrap();
// // *start_offset += (account.bloom_filter_capacity / 8) as usize;
// bloom_filter_stores.push(bloom_filter_store);
// let (bloom_filter_store, account_data) = Ref::<&'a mut [u8], [u8]>::from_prefix_with_elems(
// account_data,
// (account.bloom_filter_capacity / 8) as usize,
// )
// .unwrap();
// // *start_offset += (account.bloom_filter_capacity / 8) as usize;
// bloom_filter_stores.push(bloom_filter_store);

// // reset start_offset to 0
// *start_offset = 0;

let (hashchain_store, _) =
ZeroCopyVecU64::from_bytes_at_multiple(hashchain_store_capacity, account_data)?;
Expand Down Expand Up @@ -617,10 +590,6 @@ pub fn init_queue<'a>(
let (num_value_stores, num_stores, num_hashchain_stores) =
account.get_size_parameters(queue_type)?;

// if queue_type == QueueType::Output as u64 {
// *start_offset += BatchedQueueMetadata::LEN;
// }

let (mut batches, account_data) =
ZeroCopySliceMutU64::new_at(account.num_batches, account_data)?;

Expand All @@ -641,48 +610,7 @@ pub fn init_queue<'a>(
account.bloom_filter_capacity / 8,
account_data,
)?;
// ZeroCopySliceMutU64::new_at_multiple(
// num_stores,
// account.bloom_filter_capacity as usize / 8,
// account_data,
// start_offset,
// )?;
// if num_stores == 2 {
// let account_data = &mut account_data[*start_offset..];
// let (bloom_filter_store, account_data) = Ref::<&mut [u8], [u8]>::from_prefix_with_elems(
// account_data,
// (account.bloom_filter_capacity / 8) as usize,
// )
// .unwrap();
// *start_offset += (account.bloom_filter_capacity / 8) as usize;
// bloom_filter_stores.push(bloom_filter_store);
// let (bloom_filter_store, account_data) = Ref::<&mut [u8], [u8]>::from_prefix_with_elems(
// account_data,
// (account.bloom_filter_capacity / 8) as usize,
// )
// .unwrap();
// *start_offset += (account.bloom_filter_capacity / 8) as usize;
// bloom_filter_stores.push(bloom_filter_store);
// *start_offset = 0;
// let hashchain_store = ZeroCopyVecU64::new_at_multiple(
// num_hashchain_stores,
// account.get_num_zkp_batches() as usize,
// account_data,
// start_offset,
// )?;

// Ok((batches, value_vecs, bloom_filter_stores, hashchain_store))
// } else {
// assert_eq!(num_stores, 0);
// let hashchain_store = ZeroCopyVecU64::new_at_multiple(
// num_hashchain_stores,
// account.get_num_zkp_batches() as usize,
// account_data,
// start_offset,
// )?;

// Ok((batches, value_vecs, bloom_filter_stores, hashchain_store))
// }

let (hashchain_store, _) = ZeroCopyVecU64::new_at_multiple(
num_hashchain_stores,
account.get_num_zkp_batches(),
Expand Down

0 comments on commit 2acc67c

Please sign in to comment.