Skip to content

Commit

Permalink
perf: optimize batched Merkle tree & output queue deserialization
Browse files Browse the repository at this point in the history
  • Loading branch information
ananas-block committed Jan 20, 2025
1 parent 896a8f4 commit 84d3a51
Show file tree
Hide file tree
Showing 30 changed files with 600 additions and 643 deletions.
6 changes: 3 additions & 3 deletions forester-utils/src/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ where
.unwrap();

let full_batch_index = merkle_tree.queue_metadata.next_full_batch_index;
let batch = &merkle_tree.batches[full_batch_index as usize];
let batch = &merkle_tree.queue_metadata.batches[full_batch_index as usize];
let zkp_batch_index = batch.get_num_inserted_zkps();
let leaves_hashchain =
merkle_tree.hashchain_store[full_batch_index as usize][zkp_batch_index as usize];
Expand Down Expand Up @@ -232,7 +232,7 @@ pub async fn create_append_batch_ix_data<R: RpcConnection, I: Indexer<R>>(
let zkp_batch_size = output_queue.batch_metadata.zkp_batch_size;

let num_inserted_zkps =
output_queue.batches[full_batch_index as usize].get_num_inserted_zkps();
output_queue.batch_metadata.batches[full_batch_index as usize].get_num_inserted_zkps();

let leaves_hashchain =
output_queue.hashchain_store[full_batch_index as usize][num_inserted_zkps as usize];
Expand Down Expand Up @@ -341,7 +341,7 @@ pub async fn create_nullify_batch_ix_data<R: RpcConnection, I: Indexer<R>>(
BatchedMerkleTreeAccount::state_from_bytes(account.data.as_mut_slice()).unwrap();
let batch_idx = merkle_tree.queue_metadata.next_full_batch_index as usize;
let zkp_size = merkle_tree.queue_metadata.zkp_batch_size;
let batch = &merkle_tree.batches[batch_idx];
let batch = &merkle_tree.queue_metadata.batches[batch_idx];
let zkp_idx = batch.get_num_inserted_zkps();
let hashchain = merkle_tree.hashchain_store[batch_idx][zkp_idx as usize];
let root = *merkle_tree.root_history.last().unwrap();
Expand Down
18 changes: 13 additions & 5 deletions forester/src/batch_processor/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ impl<R: RpcConnection, I: Indexer<R> + IndexerType<R>> BatchProcessor<R, I> {
};

let batch_index = tree.queue_metadata.next_full_batch_index;
match tree.batches.get(batch_index as usize) {
match tree.queue_metadata.batches.get(batch_index as usize) {
Some(batch) => Self::calculate_completion(batch),
None => 0.0,
}
Expand All @@ -139,7 +139,7 @@ impl<R: RpcConnection, I: Indexer<R> + IndexerType<R>> BatchProcessor<R, I> {
};

let batch_index = queue.batch_metadata.next_full_batch_index;
match queue.batches.get(batch_index as usize) {
match queue.batch_metadata.batches.get(batch_index as usize) {
Some(batch) => Self::calculate_completion(batch),
None => 0.0,
}
Expand Down Expand Up @@ -181,7 +181,7 @@ impl<R: RpcConnection, I: Indexer<R> + IndexerType<R>> BatchProcessor<R, I> {
let zkp_batch_size = output_queue.batch_metadata.zkp_batch_size;

(
output_queue.batches[batch_index as usize].get_num_inserted_zkps(),
output_queue.batch_metadata.batches[batch_index as usize].get_num_inserted_zkps(),
zkp_batch_size as usize,
)
};
Expand All @@ -206,7 +206,11 @@ impl<R: RpcConnection, I: Indexer<R> + IndexerType<R>> BatchProcessor<R, I> {

if let Ok(tree) = merkle_tree {
let batch_index = tree.queue_metadata.next_full_batch_index;
let full_batch = tree.batches.get(batch_index as usize).unwrap();
let full_batch = tree
.queue_metadata
.batches
.get(batch_index as usize)
.unwrap();

full_batch.get_state() != BatchState::Inserted
&& full_batch.get_current_zkp_batch_index() > full_batch.get_num_inserted_zkps()
Expand All @@ -230,7 +234,11 @@ impl<R: RpcConnection, I: Indexer<R> + IndexerType<R>> BatchProcessor<R, I> {

if let Ok(queue) = output_queue {
let batch_index = queue.batch_metadata.next_full_batch_index;
let full_batch = queue.batches.get(batch_index as usize).unwrap();
let full_batch = queue
.batch_metadata
.batches
.get(batch_index as usize)
.unwrap();

full_batch.get_state() != BatchState::Inserted
&& full_batch.get_current_zkp_batch_index() > full_batch.get_num_inserted_zkps()
Expand Down
4 changes: 2 additions & 2 deletions forester/tests/batched_address_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,8 @@ async fn test_address_batched() {
let num_zkp_batches = batch_size / zkp_batch_size;

let mut completed_items = 0;
for batch_idx in 0..merkle_tree.batches.len() {
let batch = merkle_tree.batches.get(batch_idx).unwrap();
for batch_idx in 0..merkle_tree.queue_metadata.batches.len() {
let batch = merkle_tree.queue_metadata.batches.get(batch_idx).unwrap();
if batch.get_state() == BatchState::Inserted {
completed_items += batch_size;
}
Expand Down
4 changes: 2 additions & 2 deletions forester/tests/batched_state_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,8 @@ async fn test_state_batched() {
let num_zkp_batches = batch_size / zkp_batch_size;

let mut completed_items = 0;
for batch_idx in 0..output_queue.batches.len() {
let batch = output_queue.batches.get(batch_idx).unwrap();
for batch_idx in 0..output_queue.batch_metadata.batches.len() {
let batch = output_queue.batch_metadata.batches.get(batch_idx).unwrap();
if batch.get_state() == BatchState::Inserted {
completed_items += batch_size;
}
Expand Down
102 changes: 70 additions & 32 deletions program-libs/batched-merkle-tree/src/batch.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use borsh::{BorshDeserialize, BorshSerialize};
use light_bloom_filter::BloomFilter;
use light_hasher::{Hasher, Poseidon};
use light_zero_copy::{slice_mut::ZeroCopySliceMutU64, vec::ZeroCopyVecU64};
use light_zero_copy::vec::ZeroCopyVecU64;
use solana_program::msg;
use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};

Expand Down Expand Up @@ -43,13 +44,26 @@ impl From<BatchState> for u64 {
/// - is part of a queue, by default a queue has two batches.
/// - is inserted into the tree by zkp batch.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, KnownLayout, Immutable, IntoBytes, FromBytes)]
#[derive(
Clone,
Copy,
Debug,
PartialEq,
Eq,
KnownLayout,
Immutable,
IntoBytes,
FromBytes,
Default,
BorshSerialize,
BorshDeserialize,
)]
pub struct Batch {
/// Number of inserted elements in the zkp batch.
num_inserted: u64,
state: u64,
current_zkp_batch_index: u64,
num_inserted_zkps: u64,
pub num_inserted_zkps: u64,
/// Number of iterations for the bloom_filter.
pub num_iters: u64,
/// Theoretical capacity of the bloom_filter. We want to make it much larger
Expand Down Expand Up @@ -234,7 +248,7 @@ impl Batch {
&mut self,
bloom_filter_value: &[u8; 32],
hashchain_value: &[u8; 32],
bloom_filter_stores: &mut [ZeroCopySliceMutU64<u8>],
bloom_filter_stores: &mut [&mut [u8]],
hashchain_store: &mut ZeroCopyVecU64<[u8; 32]>,
bloom_filter_index: usize,
) -> Result<(), BatchedMerkleTreeError> {
Expand All @@ -251,13 +265,18 @@ impl Batch {
BloomFilter::new(
self.num_iters as usize,
self.bloom_filter_capacity,
bloom_filter.as_mut_slice(),
bloom_filter,
)?
.insert(bloom_filter_value)?;

// 3. Check that value is not in any other bloom filter.
for bf_store in before.iter_mut().chain(after.iter_mut()) {
self.check_non_inclusion(bloom_filter_value, bf_store.as_mut_slice())?;
Self::check_non_inclusion(
self.num_iters as usize,
self.bloom_filter_capacity,
bloom_filter_value,
bf_store,
)?;
}
}
Ok(())
Expand Down Expand Up @@ -310,12 +329,12 @@ impl Batch {

/// Checks that value is not in the bloom filter.
pub fn check_non_inclusion(
&self,
num_iters: usize,
bloom_filter_capacity: u64,
value: &[u8; 32],
store: &mut [u8],
) -> Result<(), BatchedMerkleTreeError> {
let mut bloom_filter =
BloomFilter::new(self.num_iters as usize, self.bloom_filter_capacity, store)?;
let mut bloom_filter = BloomFilter::new(num_iters, bloom_filter_capacity, store)?;
if bloom_filter.contains(value) {
return Err(BatchedMerkleTreeError::NonInclusionCheckFailed);
}
Expand Down Expand Up @@ -475,10 +494,10 @@ mod tests {
fn test_insert() {
// Behavior Input queue
let mut batch = get_test_batch();
let mut stores = vec![vec![0u8; 20_008]; 2];
let mut stores = vec![vec![0u8; 20_000]; 2];
let mut bloom_filter_stores = stores
.iter_mut()
.map(|store| ZeroCopySliceMutU64::new(20_000, store).unwrap())
.map(|store| &mut store[..])
.collect::<Vec<_>>();
let mut hashchain_store_bytes = vec![
0u8;
Expand Down Expand Up @@ -541,19 +560,24 @@ mod tests {
let mut bloom_filter = BloomFilter {
num_iters: batch.num_iters as usize,
capacity: batch.bloom_filter_capacity,
store: bloom_filter_stores[processing_index].as_mut_slice(),
store: bloom_filter_stores[processing_index],
};
assert!(bloom_filter.contains(&value));
let other_index = if processing_index == 0 { 1 } else { 0 };
batch
.check_non_inclusion(&value, bloom_filter_stores[other_index].as_mut_slice())
.unwrap();
batch
.check_non_inclusion(
&value,
bloom_filter_stores[processing_index].as_mut_slice(),
)
.unwrap_err();
Batch::check_non_inclusion(
batch.num_iters as usize,
batch.bloom_filter_capacity,
&value,
bloom_filter_stores[other_index],
)
.unwrap();
Batch::check_non_inclusion(
batch.num_iters as usize,
batch.bloom_filter_capacity,
&value,
bloom_filter_stores[processing_index],
)
.unwrap_err();

ref_batch.num_inserted += 1;
if ref_batch.num_inserted == ref_batch.zkp_batch_size {
Expand Down Expand Up @@ -611,10 +635,10 @@ mod tests {
let mut batch = get_test_batch();

let value = [1u8; 32];
let mut stores = vec![vec![0u8; 20_008]; 2];
let mut stores = vec![vec![0u8; 20_000]; 2];
let mut bloom_filter_stores = stores
.iter_mut()
.map(|store| ZeroCopySliceMutU64::new(20_000, store).unwrap())
.map(|store| &mut store[..])
.collect::<Vec<_>>();
let mut hashchain_store_bytes = vec![
0u8;
Expand All @@ -628,9 +652,15 @@ mod tests {
)
.unwrap();

assert!(batch
.check_non_inclusion(&value, bloom_filter_stores[processing_index].as_mut_slice())
.is_ok());
assert_eq!(
Batch::check_non_inclusion(
batch.num_iters as usize,
batch.bloom_filter_capacity,
&value,
bloom_filter_stores[processing_index]
),
Ok(())
);
let ref_batch = get_test_batch();
assert_eq!(batch, ref_batch);
batch
Expand All @@ -642,14 +672,22 @@ mod tests {
processing_index,
)
.unwrap();
assert!(batch
.check_non_inclusion(&value, bloom_filter_stores[processing_index].as_mut_slice())
.is_err());
assert!(Batch::check_non_inclusion(
batch.num_iters as usize,
batch.bloom_filter_capacity,
&value,
bloom_filter_stores[processing_index]
)
.is_err());

let other_index = if processing_index == 0 { 1 } else { 0 };
assert!(batch
.check_non_inclusion(&value, bloom_filter_stores[other_index].as_mut_slice())
.is_ok());
assert!(Batch::check_non_inclusion(
batch.num_iters as usize,
batch.bloom_filter_capacity,
&value,
bloom_filter_stores[other_index]
)
.is_ok());
}
}

Expand Down
Loading

0 comments on commit 84d3a51

Please sign in to comment.