test: Add more unit tests for light-hash-set (#842)
Some methods were not covered. This change achieves full coverage.
vadorovsky authored Jun 20, 2024
1 parent 6f10474 commit 269730b
Showing 3 changed files with 210 additions and 64 deletions.
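
The methods gaining coverage are `get_bucket`, `get_bucket_mut`, `get_unmarked_bucket` and `first_no_seq` on `HashSet`, plus the zero-copy wrapper's pass-through behaviour. As orientation, here is a minimal sketch of the accessor API, assuming it works exactly as the diff below shows; the crate path and the 6857/2400 parameters are illustrative, not taken from the commit.

use num_bigint::ToBigUint;
use light_hash_set::HashSet; // crate path assumed; the crate lives under merkle-tree/hash-set

fn sketch_newly_covered_accessors() {
    // Same capacity / sequence-threshold pair as the new tests below.
    let mut hs = HashSet::new(6857, 2400).unwrap();
    hs.insert(&0.to_biguint().unwrap(), 0).unwrap();

    // `get_bucket` returns `Some(&Option<HashSetCell>)` for indices within the
    // capacity and `None` for indices beyond it; it performs no validity checks.
    assert!(hs.get_bucket(0).is_some());
    assert!(hs.get_bucket(10_000).is_none());

    // `first_no_seq` returns the first element regardless of sequence numbers.
    let element = hs.first_no_seq().unwrap().unwrap();
    assert_eq!(element.0.value_biguint(), 0.to_biguint().unwrap());
}
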
198 changes: 187 additions & 11 deletions merkle-tree/hash-set/src/lib.rs
@@ -124,15 +124,15 @@ impl HashSet {
}

/// Size which needs to be allocated on a Solana account to fit the hash set.
pub fn size_in_account(capacity_values: usize) -> Result<usize, HashSetError> {
pub fn size_in_account(capacity_values: usize) -> usize {
let dyn_fields_size = Self::non_dyn_fields_size();

let buckets_size_unaligned = mem::size_of::<Option<HashSetCell>>() * capacity_values;
// Make sure that the alignment of `values` matches the alignment of `usize`.
let buckets_size = buckets_size_unaligned + mem::align_of::<usize>()
- (buckets_size_unaligned % mem::align_of::<usize>());

Ok(dyn_fields_size + buckets_size)
dyn_fields_size + buckets_size
}

// Create a new hash set with the given capacity
@@ -181,7 +181,7 @@ impl HashSet {
let capacity_values = usize::from_ne_bytes(bytes[0..8].try_into().unwrap());
let sequence_threshold = usize::from_ne_bytes(bytes[8..16].try_into().unwrap());

let expected_size = Self::size_in_account(capacity_values)?;
let expected_size = Self::size_in_account(capacity_values);
if bytes.len() != expected_size {
return Err(HashSetError::BufferSize(expected_size, bytes.len()));
}
@@ -544,6 +544,14 @@ impl Drop for HashSet {
}
}

impl PartialEq for HashSet {
fn eq(&self, other: &Self) -> bool {
self.capacity.eq(&other.capacity)
&& self.sequence_threshold.eq(&other.sequence_threshold)
&& self.iter().eq(other.iter())
}
}

pub struct HashSetIterator<'a> {
hash_set: &'a HashSet,
current: usize,
@@ -567,11 +575,14 @@ impl<'a> Iterator for HashSetIterator<'a> {

#[cfg(test)]
mod test {
use super::*;
use ark_bn254::Fr;
use ark_ff::UniformRand;
use rand::{thread_rng, Rng};

use crate::zero_copy::HashSetZeroCopy;

use super::*;

#[test]
fn test_is_valid() {
let mut rng = thread_rng();
@@ -756,17 +767,27 @@ mod test {
assert_eq!(hs.contains(&nullifier, Some(seq)).unwrap(), false);
hs.insert(&nullifier, seq as usize).unwrap();
assert_eq!(hs.contains(&nullifier, Some(seq)).unwrap(), true);

let nullifier_bytes = bigint_to_be_bytes_array(&nullifier).unwrap();

let element = hs
.find_element(&nullifier, Some(seq))
.unwrap()
.unwrap()
.0
.clone();
assert_eq!(
hs.find_element(&nullifier, Some(seq))
.unwrap()
.unwrap()
.0
.clone(),
element,
HashSetCell {
value: bigint_to_be_bytes_array(&nullifier).unwrap(),
sequence_number: None,
}
);
assert_eq!(element.value_bytes(), nullifier_bytes);
assert_eq!(&element.value_biguint(), nullifier);
assert_eq!(element.sequence_number(), None);
assert!(!element.is_marked());
assert!(element.is_valid(seq));

hs.mark_with_sequence_number(&nullifier, seq).unwrap();
let element = hs
@@ -779,10 +800,15 @@
assert_eq!(
element,
HashSetCell {
value: bigint_to_be_bytes_array(&nullifier).unwrap(),
value: nullifier_bytes,
sequence_number: Some(2400 + seq)
}
);
assert_eq!(element.value_bytes(), nullifier_bytes);
assert_eq!(&element.value_biguint(), nullifier);
assert_eq!(element.sequence_number(), Some(2400 + seq));
assert!(element.is_marked());
assert!(element.is_valid(seq));

// Trying to insert the same nullifier, before reaching the
// sequence threshold, should fail.
@@ -796,11 +822,56 @@
}
}

fn hash_set_from_bytes_copy<
const CAPACITY: usize,
const SEQUENCE_THRESHOLD: usize,
const OPERATIONS: usize,
>() {
let mut hs_1 = HashSet::new(CAPACITY, SEQUENCE_THRESHOLD).unwrap();

let mut rng = thread_rng();

// Create a buffer with random bytes.
let mut bytes = vec![0u8; HashSet::size_in_account(CAPACITY)];
rng.fill(bytes.as_mut_slice());

// Initialize a hash set on top of a byte slice.
{
let mut hs_2 = unsafe {
HashSetZeroCopy::from_bytes_zero_copy_init(&mut bytes, CAPACITY, SEQUENCE_THRESHOLD)
.unwrap()
};

for seq in 0..OPERATIONS {
let value = BigUint::from(Fr::rand(&mut rng));
hs_1.insert(&value, seq).unwrap();
hs_2.insert(&value, seq).unwrap();
}

assert_eq!(hs_1, *hs_2);
}

// Create a copy on top of a byte slice.
{
let hs_2 = unsafe { HashSet::from_bytes_copy(&mut bytes).unwrap() };
assert_eq!(hs_1, hs_2);
}
}

#[test]
fn test_hash_set_from_bytes_copy_6857_2400_3600() {
hash_set_from_bytes_copy::<6857, 2400, 3600>()
}

#[test]
fn test_hash_set_from_bytes_copy_9601_2400_5000() {
hash_set_from_bytes_copy::<9601, 2400, 5000>()
}

fn hash_set_full<const CAPACITY: usize, const SEQUENCE_THRESHOLD: usize>() {
for _ in 0..100 {
let mut hs = HashSet::new(CAPACITY, SEQUENCE_THRESHOLD).unwrap();

// let mut rng = StdRng::seed_from_u64(1);
let mut rng = rand::thread_rng();

// Insert as many values as possible. The important point is to
@@ -1002,4 +1073,109 @@ mod test {
fn test_hash_set_iter_random_9601_2400() {
hash_set_iter_random::<5000, 9601, 2400>()
}

#[test]
fn test_hash_set_get_bucket() {
let mut hs = HashSet::new(6857, 2400).unwrap();

// Insert incremental elements, so they end up in consecutive buckets
// of the hash set.
for i in 0..3600 {
let bn_i = i.to_biguint().unwrap();
hs.insert(&bn_i, i).unwrap();
}

for i in 0..3600 {
let bn_i = i.to_biguint().unwrap();
let element = hs.get_bucket(i).unwrap().unwrap();
assert_eq!(element.value_biguint(), bn_i);
}
// Unused cells within the capacity should be `Some(None)`.
for i in 3600..6857 {
assert!(hs.get_bucket(i).unwrap().is_none());
}
// Cells over the capacity should be `None`.
for i in 6857..10_000 {
assert!(hs.get_bucket(i).is_none());
}
}

#[test]
fn test_hash_set_get_bucket_mut() {
let mut hs = HashSet::new(6857, 2400).unwrap();

// Insert incremental elements, so they end up in consecutive buckets
// of the hash set.
for i in 0..3600 {
let bn_i = i.to_biguint().unwrap();
hs.insert(&bn_i, i).unwrap();
}

for i in 0..3600 {
let bn_i = i.to_biguint().unwrap();
let element = hs.get_bucket_mut(i).unwrap();
assert_eq!(element.unwrap().value_biguint(), bn_i);

// "Nullify" the element.
*element = Some(HashSetCell {
value: [0_u8; 32],
sequence_number: None,
});
}
for i in 0..3600 {
let element = hs.get_bucket_mut(i).unwrap().unwrap();
assert_eq!(element.value_bytes(), [0_u8; 32]);
}
// Unused cells within the capacity should be `Some(None)`.
for i in 3600..6857 {
assert!(hs.get_bucket_mut(i).unwrap().is_none());
}
// Cells over the capacity should be `None`.
for i in 6857..10_000 {
assert!(hs.get_bucket_mut(i).is_none());
}
}

#[test]
fn test_hash_set_get_unmarked_bucket() {
let mut hs = HashSet::new(6857, 2400).unwrap();

// Insert incremental elements, so they end up in consecutive buckets
// of the hash set.
for i in 0..3600 {
let bn_i = i.to_biguint().unwrap();
hs.insert(&bn_i, i).unwrap();
}

for i in 0..3600 {
let element = hs.get_unmarked_bucket(i);
assert!(element.is_some());
}

// Mark the elements.
for i in 0..3600 {
let bn_i = i.to_biguint().unwrap();
hs.mark_with_sequence_number(&bn_i, i).unwrap();
}

for i in 0..3600 {
let element = hs.get_unmarked_bucket(i);
assert!(element.is_none());
}
}

#[test]
fn test_hash_set_first_no_seq() {
let mut hs = HashSet::new(6857, 2400).unwrap();

// Insert incremental elements, so they end up in consecutive buckets
// of the hash set.
for i in 0..3600 {
let bn_i = i.to_biguint().unwrap();
hs.insert(&bn_i, i).unwrap();

let element = hs.first_no_seq().unwrap().unwrap();
assert_eq!(element.0.value_biguint(), 0.to_biguint().unwrap());
}
}
}
73 changes: 22 additions & 51 deletions merkle-tree/hash-set/src/zero_copy.rs
@@ -1,6 +1,9 @@
use std::{marker::PhantomData, mem, ptr::NonNull};

use num_bigint::BigUint;
use std::{
marker::PhantomData,
mem,
ops::{Deref, DerefMut},
ptr::NonNull,
};

use crate::{HashSet, HashSetCell, HashSetError};

@@ -114,53 +117,6 @@ impl<'a> HashSetZeroCopy<'a> {

Ok(hash_set)
}

/// Returns a reference to a bucket under the given `index`. Does not check
/// the validity.
pub fn get_bucket(&self, index: usize) -> Option<&Option<HashSetCell>> {
self.hash_set.get_bucket(index)
}

/// Returns a mutable reference to a bucket under the given `index`. Does
/// not check the validity.
pub fn get_bucket_mut(&mut self, index: usize) -> Option<&mut Option<HashSetCell>> {
self.hash_set.get_bucket_mut(index)
}

/// Returns a reference to an unmarked bucket under the given index. If the
/// bucket is marked, returns `None`.
pub fn get_unmarked_bucket(&self, index: usize) -> Option<&Option<HashSetCell>> {
self.hash_set.get_unmarked_bucket(index)
}

/// Inserts a value into the hash set.
pub fn insert(&mut self, value: &BigUint, sequence_number: usize) -> Result<(), HashSetError> {
self.hash_set.insert(value, sequence_number)
}

/// Returns a first available element.
pub fn first(&self, sequence_number: usize) -> Result<Option<&HashSetCell>, HashSetError> {
self.hash_set.first(sequence_number)
}

/// Check if the hash set contains a value.
pub fn contains(
&self,
value: &BigUint,
sequence_number: Option<usize>,
) -> Result<bool, HashSetError> {
self.hash_set.contains(value, sequence_number)
}

/// Marks the given element with a given sequence number.
pub fn mark_with_sequence_number(
&mut self,
value: &BigUint,
sequence_number: usize,
) -> Result<(), HashSetError> {
self.hash_set
.mark_with_sequence_number(value, sequence_number)
}
}

impl<'a> Drop for HashSetZeroCopy<'a> {
@@ -178,10 +134,25 @@ impl<'a> Drop for HashSetZeroCopy<'a> {
}
}

impl<'a> Deref for HashSetZeroCopy<'a> {
type Target = HashSet;

fn deref(&self) -> &Self::Target {
&self.hash_set
}
}

impl<'a> DerefMut for HashSetZeroCopy<'a> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.hash_set
}
}

#[cfg(test)]
mod test {
use ark_bn254::Fr;
use ark_ff::UniformRand;
use num_bigint::BigUint;
use rand::{thread_rng, Rng};

use super::*;
@@ -192,7 +163,7 @@ mod test {
const SEQUENCE_THRESHOLD: usize = 2400;

// Create a buffer with random bytes.
let mut bytes = vec![0u8; HashSet::size_in_account(VALUES).unwrap()];
let mut bytes = vec![0u8; HashSet::size_in_account(VALUES)];
thread_rng().fill(bytes.as_mut_slice());

// Create random nullifiers.
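
Illustrative sketch (not part of the commit), assuming the module paths and the `HashSetError` return type of the constructor: with the `Deref`/`DerefMut` impls added above, `HashSet` methods such as `insert` and `contains` resolve directly on a `HashSetZeroCopy` handle, which is why the forwarding wrappers could be removed.

use num_bigint::BigUint;
use light_hash_set::{zero_copy::HashSetZeroCopy, HashSet, HashSetError}; // paths assumed

fn zero_copy_roundtrip() -> Result<(), HashSetError> {
    // Buffer sized with the now-infallible `size_in_account`.
    let mut bytes = vec![0u8; HashSet::size_in_account(6857)];
    // SAFETY: `bytes` is a freshly allocated buffer of exactly the reported size,
    // mirroring the setup in the tests above.
    let mut hs =
        unsafe { HashSetZeroCopy::from_bytes_zero_copy_init(&mut bytes, 6857, 2400)? };

    let value = BigUint::from(42u64);
    hs.insert(&value, 0)?;                  // resolves through DerefMut to HashSet::insert
    assert!(hs.contains(&value, Some(0))?); // resolves through Deref to HashSet::contains
    Ok(())
}
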
3 changes: 1 addition & 2 deletions programs/account-compression/src/state/queue.rs
@@ -112,8 +112,7 @@ impl<'info> GroupAccounts<'info> for InsertIntoQueues<'info> {

impl QueueAccount {
pub fn size(capacity: usize) -> Result<usize> {
Ok(8 + mem::size_of::<Self>()
+ HashSet::size_in_account(capacity).map_err(ProgramError::from)?)
Ok(8 + mem::size_of::<Self>() + HashSet::size_in_account(capacity))
}
}
