Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
ananas-block committed Jan 19, 2025
1 parent 896a8f4 commit 6b30abd
Show file tree
Hide file tree
Showing 15 changed files with 646 additions and 435 deletions.
103 changes: 71 additions & 32 deletions program-libs/batched-merkle-tree/src/batch.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use borsh::{BorshDeserialize, BorshSerialize};
use light_bloom_filter::BloomFilter;
use light_hasher::{Hasher, Poseidon};
use light_zero_copy::{slice_mut::ZeroCopySliceMutU64, vec::ZeroCopyVecU64};
use light_zero_copy::vec::ZeroCopyVecU64;
use solana_program::msg;
use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};

Expand Down Expand Up @@ -43,13 +44,26 @@ impl From<BatchState> for u64 {
/// - is part of a queue, by default a queue has two batches.
/// - is inserted into the tree by zkp batch.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, KnownLayout, Immutable, IntoBytes, FromBytes)]
#[derive(
Clone,
Copy,
Debug,
PartialEq,
Eq,
KnownLayout,
Immutable,
IntoBytes,
FromBytes,
Default,
BorshSerialize,
BorshDeserialize,
)]
pub struct Batch {
/// Number of inserted elements in the zkp batch.
num_inserted: u64,
state: u64,
current_zkp_batch_index: u64,
num_inserted_zkps: u64,
pub num_inserted_zkps: u64,
/// Number of iterations for the bloom_filter.
pub num_iters: u64,
/// Theoretical capacity of the bloom_filter. We want to make it much larger
Expand Down Expand Up @@ -234,7 +248,7 @@ impl Batch {
&mut self,
bloom_filter_value: &[u8; 32],
hashchain_value: &[u8; 32],
bloom_filter_stores: &mut [ZeroCopySliceMutU64<u8>],
bloom_filter_stores: &mut [&mut [u8]],
hashchain_store: &mut ZeroCopyVecU64<[u8; 32]>,
bloom_filter_index: usize,
) -> Result<(), BatchedMerkleTreeError> {
Expand All @@ -251,13 +265,18 @@ impl Batch {
BloomFilter::new(
self.num_iters as usize,
self.bloom_filter_capacity,
bloom_filter.as_mut_slice(),
bloom_filter,
)?
.insert(bloom_filter_value)?;

// 3. Check that value is not in any other bloom filter.
for bf_store in before.iter_mut().chain(after.iter_mut()) {
self.check_non_inclusion(bloom_filter_value, bf_store.as_mut_slice())?;
Self::check_non_inclusion(
self.num_iters as usize,
self.bloom_filter_capacity,
bloom_filter_value,
bf_store,
)?;
}
}
Ok(())
Expand Down Expand Up @@ -310,12 +329,13 @@ impl Batch {

/// Checks that value is not in the bloom filter.
pub fn check_non_inclusion(
&self,
num_iters: usize,
bloom_filter_capacity: u64,
value: &[u8; 32],
store: &mut [u8],
) -> Result<(), BatchedMerkleTreeError> {
let mut bloom_filter =
BloomFilter::new(self.num_iters as usize, self.bloom_filter_capacity, store)?;
let mut bloom_filter = BloomFilter::new(num_iters, bloom_filter_capacity, store)?;
println!("Checking non inclusion");
if bloom_filter.contains(value) {
return Err(BatchedMerkleTreeError::NonInclusionCheckFailed);
}
Expand Down Expand Up @@ -475,10 +495,10 @@ mod tests {
fn test_insert() {
// Behavior Input queue
let mut batch = get_test_batch();
let mut stores = vec![vec![0u8; 20_008]; 2];
let mut stores = vec![vec![0u8; 20_000]; 2];
let mut bloom_filter_stores = stores
.iter_mut()
.map(|store| ZeroCopySliceMutU64::new(20_000, store).unwrap())
.map(|store| &mut store[..])
.collect::<Vec<_>>();
let mut hashchain_store_bytes = vec![
0u8;
Expand Down Expand Up @@ -541,19 +561,24 @@ mod tests {
let mut bloom_filter = BloomFilter {
num_iters: batch.num_iters as usize,
capacity: batch.bloom_filter_capacity,
store: bloom_filter_stores[processing_index].as_mut_slice(),
store: bloom_filter_stores[processing_index],
};
assert!(bloom_filter.contains(&value));
let other_index = if processing_index == 0 { 1 } else { 0 };
batch
.check_non_inclusion(&value, bloom_filter_stores[other_index].as_mut_slice())
.unwrap();
batch
.check_non_inclusion(
&value,
bloom_filter_stores[processing_index].as_mut_slice(),
)
.unwrap_err();
Batch::check_non_inclusion(
batch.num_iters as usize,
batch.bloom_filter_capacity,
&value,
bloom_filter_stores[other_index],
)
.unwrap();
Batch::check_non_inclusion(
batch.num_iters as usize,
batch.bloom_filter_capacity,
&value,
bloom_filter_stores[processing_index],
)
.unwrap_err();

ref_batch.num_inserted += 1;
if ref_batch.num_inserted == ref_batch.zkp_batch_size {
Expand Down Expand Up @@ -611,10 +636,10 @@ mod tests {
let mut batch = get_test_batch();

let value = [1u8; 32];
let mut stores = vec![vec![0u8; 20_008]; 2];
let mut stores = vec![vec![0u8; 20_000]; 2];
let mut bloom_filter_stores = stores
.iter_mut()
.map(|store| ZeroCopySliceMutU64::new(20_000, store).unwrap())
.map(|store| &mut store[..])
.collect::<Vec<_>>();
let mut hashchain_store_bytes = vec![
0u8;
Expand All @@ -628,9 +653,15 @@ mod tests {
)
.unwrap();

assert!(batch
.check_non_inclusion(&value, bloom_filter_stores[processing_index].as_mut_slice())
.is_ok());
assert_eq!(
Batch::check_non_inclusion(
batch.num_iters as usize,
batch.bloom_filter_capacity,
&value,
bloom_filter_stores[processing_index]
),
Ok(())
);
let ref_batch = get_test_batch();
assert_eq!(batch, ref_batch);
batch
Expand All @@ -642,14 +673,22 @@ mod tests {
processing_index,
)
.unwrap();
assert!(batch
.check_non_inclusion(&value, bloom_filter_stores[processing_index].as_mut_slice())
.is_err());
assert!(Batch::check_non_inclusion(
batch.num_iters as usize,
batch.bloom_filter_capacity,
&value,
bloom_filter_stores[processing_index]
)
.is_err());

let other_index = if processing_index == 0 { 1 } else { 0 };
assert!(batch
.check_non_inclusion(&value, bloom_filter_stores[other_index].as_mut_slice())
.is_ok());
assert!(Batch::check_non_inclusion(
batch.num_iters as usize,
batch.bloom_filter_capacity,
&value,
bloom_filter_stores[other_index]
)
.is_ok());
}
}

Expand Down
82 changes: 57 additions & 25 deletions program-libs/batched-merkle-tree/src/batch_metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pub struct BatchMetadata {
pub currently_processing_batch_index: u64,
/// Next batch to be inserted into the tree.
pub next_full_batch_index: u64,
pub batches: [Batch; 2],
}

impl BatchMetadata {
Expand All @@ -45,6 +46,14 @@ impl BatchMetadata {
self.batch_size / self.zkp_batch_size
}

/// Returns a shared reference to the batch that is currently being filled.
///
/// The active batch is selected by `currently_processing_batch_index`,
/// which callers advance via
/// `increment_currently_processing_batch_index_if_full`.
pub fn get_current_batch(&self) -> &Batch {
    let active = self.currently_processing_batch_index as usize;
    &self.batches[active]
}

/// Returns an exclusive reference to the batch that is currently being
/// filled (selected by `currently_processing_batch_index`).
pub fn get_current_batch_mut(&mut self) -> &mut Batch {
    let active = self.currently_processing_batch_index as usize;
    &mut self.batches[active]
}

/// Validates that the batch size is properly divisible by the ZKP batch size.
fn validate_batch_sizes(
batch_size: u64,
Expand All @@ -68,7 +77,12 @@ impl BatchMetadata {
batch_size,
currently_processing_batch_index: 0,
next_full_batch_index: 0,
// Output queues don't use bloom filters.
bloom_filter_capacity: 0,
batches: [
Batch::new(0, 0, batch_size, zkp_batch_size, 0),
Batch::new(0, 0, batch_size, zkp_batch_size, batch_size),
],
})
}

Expand All @@ -77,6 +91,8 @@ impl BatchMetadata {
bloom_filter_capacity: u64,
zkp_batch_size: u64,
num_batches: u64,
num_iters: u64,
start_index: u64,
) -> Result<Self, BatchedMerkleTreeError> {
Self::validate_batch_sizes(batch_size, zkp_batch_size)?;

Expand All @@ -87,6 +103,22 @@ impl BatchMetadata {
currently_processing_batch_index: 0,
next_full_batch_index: 0,
bloom_filter_capacity,
batches: [
Batch::new(
num_iters,
bloom_filter_capacity,
batch_size,
zkp_batch_size,
start_index,
),
Batch::new(
num_iters,
bloom_filter_capacity,
batch_size,
zkp_batch_size,
batch_size + start_index,
),
],
})
}

Expand All @@ -98,7 +130,8 @@ impl BatchMetadata {
}

/// Increment the currently_processing_batch_index if current state is BatchState::Full.
pub fn increment_currently_processing_batch_index_if_full(&mut self, state: BatchState) {
pub fn increment_currently_processing_batch_index_if_full(&mut self) {
let state = self.get_current_batch().get_state();
if state == BatchState::Full {
self.currently_processing_batch_index =
(self.currently_processing_batch_index + 1) % self.num_batches;
Expand Down Expand Up @@ -151,19 +184,18 @@ impl BatchMetadata {
} else {
BatchedQueueMetadata::LEN
};
let batches_size =
ZeroCopySliceMutU64::<Batch>::required_size_for_capacity(self.num_batches);
// let batches_size =
// ZeroCopySliceMutU64::<Batch>::required_size_for_capacity(self.num_batches);
let value_vecs_size =
ZeroCopyVecU64::<[u8; 32]>::required_size_for_capacity(self.batch_size) * num_value_vec;
// Bloomfilter capacity is in bits.
let bloom_filter_stores_size =
ZeroCopySliceMutU64::<u8>::required_size_for_capacity(self.bloom_filter_capacity / 8)
* num_bloom_filter_stores;
(self.bloom_filter_capacity / 8) as usize * num_bloom_filter_stores;
let hashchain_store_size =
ZeroCopyVecU64::<[u8; 32]>::required_size_for_capacity(self.get_num_zkp_batches())
* num_hashchain_store;
let size = account_size
+ batches_size
// + batches_size
+ value_vecs_size
+ bloom_filter_stores_size
+ hashchain_store_size;
Expand All @@ -173,7 +205,7 @@ impl BatchMetadata {

#[test]
fn test_increment_next_full_batch_index_if_inserted() {
let mut metadata = BatchMetadata::new_input_queue(10, 10, 10, 2).unwrap();
let mut metadata = BatchMetadata::new_input_queue(10, 10, 10, 2, 3, 0).unwrap();
assert_eq!(metadata.next_full_batch_index, 0);
// increment next full batch index
metadata.increment_next_full_batch_index_if_inserted(BatchState::Inserted);
Expand All @@ -188,30 +220,30 @@ fn test_increment_next_full_batch_index_if_inserted() {
assert_eq!(metadata.next_full_batch_index, 0);
}

#[test]
fn test_increment_currently_processing_batch_index_if_full() {
let mut metadata = BatchMetadata::new_input_queue(10, 10, 10, 2).unwrap();
assert_eq!(metadata.currently_processing_batch_index, 0);
// increment currently_processing_batch_index
metadata.increment_currently_processing_batch_index_if_full(BatchState::Full);
assert_eq!(metadata.currently_processing_batch_index, 1);
// increment currently_processing_batch_index
metadata.increment_currently_processing_batch_index_if_full(BatchState::Full);
assert_eq!(metadata.currently_processing_batch_index, 0);
// try incrementing next full batch index with state not full
metadata.increment_currently_processing_batch_index_if_full(BatchState::Fill);
assert_eq!(metadata.currently_processing_batch_index, 0);
metadata.increment_currently_processing_batch_index_if_full(BatchState::Inserted);
assert_eq!(metadata.currently_processing_batch_index, 0);
}
// #[test]
// fn test_increment_currently_processing_batch_index_if_full() {
// let mut metadata = BatchMetadata::new_input_queue(10, 10, 10, 2).unwrap();
// assert_eq!(metadata.currently_processing_batch_index, 0);
// // increment currently_processing_batch_index
// metadata.increment_currently_processing_batch_index_if_full(BatchState::Full);
// assert_eq!(metadata.currently_processing_batch_index, 1);
// // increment currently_processing_batch_index
// metadata.increment_currently_processing_batch_index_if_full(BatchState::Full);
// assert_eq!(metadata.currently_processing_batch_index, 0);
// // try incrementing next full batch index with state not full
// metadata.increment_currently_processing_batch_index_if_full(BatchState::Fill);
// assert_eq!(metadata.currently_processing_batch_index, 0);
// metadata.increment_currently_processing_batch_index_if_full(BatchState::Inserted);
// assert_eq!(metadata.currently_processing_batch_index, 0);
// }

#[test]
fn test_batch_size_validation() {
    // NOTE(review): the diff rendering left both the old 4-argument and the
    // new 6-argument `new_input_queue` calls in this span; only the
    // post-commit 6-argument form (batch_size, bloom_filter_capacity,
    // zkp_batch_size, num_batches, num_iters, start_index) is kept here.

    // Invalid: batch_size (10) is not divisible by zkp_batch_size (3),
    // so `validate_batch_sizes` must reject both queue kinds.
    assert!(BatchMetadata::new_input_queue(10, 10, 3, 2, 3, 0).is_err());
    assert!(BatchMetadata::new_output_queue(10, 3, 2).is_err());

    // Valid: batch_size (9) is divisible by zkp_batch_size (3).
    assert!(BatchMetadata::new_input_queue(9, 10, 3, 2, 3, 0).is_ok());
    assert!(BatchMetadata::new_output_queue(9, 3, 2).is_ok());
}
Loading

0 comments on commit 6b30abd

Please sign in to comment.