diff --git a/aggregator/Cargo.toml b/aggregator/Cargo.toml index d524714..1f2b21a 100644 --- a/aggregator/Cargo.toml +++ b/aggregator/Cargo.toml @@ -16,10 +16,11 @@ k256 = { version = "0.13.3", features = ["arithmetic", "hash2curve", "expose-fie pse-poseidon = { git = "https://github.com/aerius-labs/pse-poseidon.git", branch = "feat/stateless-hash" } rand = "0.8.5" sha2 = "0.10.8" -indexed-merkle-tree-halo2 = { git = "https://github.com/aerius-labs/indexed-merkle-tree-halo2.git" , branch = "feat/aggr-indexed" } +indexed-merkle-tree-halo2 = { git = "https://github.com/aerius-labs/indexed-merkle-tree-halo2.git" , branch = "fix/imt" } rand_chacha = "0.3.1" snark-verifier-sdk = { git = "https://github.com/aerius-labs/snark-verifier.git", branch = "feat/custom" } ark-std = "0.4.0" +num-traits = "0.2.18" voter = { path = "../voter" } voter-tests = { path = "../voter_tests" } diff --git a/aggregator/benches/state_transition_circuit.rs b/aggregator/benches/state_transition_circuit.rs index b534168..8b0b9de 100644 --- a/aggregator/benches/state_transition_circuit.rs +++ b/aggregator/benches/state_transition_circuit.rs @@ -1,5 +1,7 @@ use aggregator::state_transition::{state_transition_circuit, StateTransitionInput}; -use aggregator::utils::generate_random_state_transition_circuit_inputs; +use aggregator::utils::{ + assign_big_uint, compress_native_nullifier, generate_random_state_transition_circuit_inputs, +}; use ark_std::{end_timer, start_timer}; use halo2_base::gates::circuit::BaseCircuitParams; use halo2_base::gates::circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}; @@ -42,7 +44,37 @@ fn state_transition_circuit_bench( let range = builder.range_chip(); let mut public_inputs = Vec::>::new(); - state_transition_circuit(builder.main(0), &range, input, &mut public_inputs); + let mut state_transition_vec = Vec::>::new(); + let ctx = builder.main(0); + + state_transition_vec.extend( + compress_native_nullifier(&input.nullifier) + .iter() + .map(|&x| ctx.load_witness(x)), + ); + + let enc_g = input.pk_enc.g; + let enc_h = input.pk_enc.n; + + state_transition_vec.extend(assign_big_uint(ctx, &enc_g)); + state_transition_vec.extend(assign_big_uint(ctx, &enc_h)); + + state_transition_vec.extend( + input + .incoming_vote + .iter() + .flat_map(|x| assign_big_uint(ctx, x)), + ); + + state_transition_vec.extend(input.prev_vote.iter().flat_map(|x| assign_big_uint(ctx, x))); + + state_transition_circuit( + ctx, + &range, + input.nullifier_tree, + state_transition_vec, + &mut public_inputs, + ); end_timer!(start0); if !stage.witness_gen_only() { diff --git a/aggregator/benches/wrapper_circuit.rs b/aggregator/benches/wrapper_circuit.rs index d47d466..0f22d38 100644 --- a/aggregator/benches/wrapper_circuit.rs +++ b/aggregator/benches/wrapper_circuit.rs @@ -1,21 +1,19 @@ -use aggregator::state_transition::StateTransitionCircuit; use aggregator::utils::generate_wrapper_circuit_input; use aggregator::wrapper::common::gen_dummy_snark; use aggregator::wrapper::common::gen_pk; +use aggregator::wrapper::common::gen_proof; use aggregator::wrapper::common::gen_snark; use aggregator::wrapper::recursion::RecursionCircuit; use halo2_base::gates::circuit::BaseCircuitParams; -use halo2_base::gates::circuit::CircuitBuilderStage; +use halo2_base::halo2_proofs::{halo2curves::bn256::Fr, plonk::*}; +use halo2_base::utils::decompose_biguint; use halo2_base::utils::fs::gen_srs; -use halo2_base::{ - halo2_proofs::{halo2curves::bn256::Fr, plonk::*}, - utils::testing::gen_proof, -}; use criterion::{criterion_group, criterion_main}; use criterion::{BenchmarkId, Criterion}; use pprof::criterion::{Output, PProfProfiler}; +use snark_verifier_sdk::CircuitExt; use voter::VoterCircuit; const K: u32 = 22; @@ -37,27 +35,6 @@ fn bench(c: &mut Criterion) { let voter_pk = gen_pk(&voter_params, &voter_circuit); let voter_snark = gen_snark(&voter_params, &voter_pk, voter_circuit); - // Generating state transition proof - let state_transition_config = BaseCircuitParams { - k: 15, - num_advice_per_phase: vec![3], - num_lookup_advice_per_phase: vec![1, 0, 0], - num_fixed: 1, - lookup_bits: Some(14), - num_instance_columns: 1, - }; - let state_transition_params = gen_srs(15); - let state_transition_circuit = StateTransitionCircuit::new( - state_transition_config.clone(), - state_transition_inputs[0].clone(), - ); - let state_transition_pk = gen_pk(&state_transition_params, &state_transition_circuit); - let state_transition_snark = gen_snark( - &state_transition_params, - &state_transition_pk, - state_transition_circuit, - ); - let recursion_config = BaseCircuitParams { k: K as usize, num_advice_per_phase: vec![4], @@ -68,6 +45,11 @@ fn bench(c: &mut Criterion) { }; let recursion_params = gen_srs(K); + let mut init_vote = Vec::::new(); + for x in state_transition_inputs[0].prev_vote.iter() { + init_vote.extend(decompose_biguint::(x, 4, 88)); + } + // Init Base Instances let mut base_instances = [ Fr::zero(), // preprocessed_digest @@ -77,24 +59,18 @@ fn bench(c: &mut Criterion) { voter_snark.instances[0][3], ] .to_vec(); - base_instances.extend(state_transition_snark.instances[0][4..24].iter()); // init_vote + base_instances.extend(init_vote); // init_vote base_instances.extend([ - state_transition_snark.instances[0][68], // nullifier_old_root - state_transition_snark.instances[0][68], // nullifier_new_root - voter_snark.instances[0][28], // membership_root - voter_snark.instances[0][29], // proposal_id - Fr::from(0), // round + state_transition_inputs[0].nullifier_tree.get_old_root(), // nullifier_old_root + state_transition_inputs[0].nullifier_tree.get_old_root(), // nullifier_new_root + voter_snark.instances[0][28], // membership_root + voter_snark.instances[0][29], // proposal_id + Fr::from(0), // round ]); let recursion_circuit = RecursionCircuit::new( - CircuitBuilderStage::Keygen, &recursion_params, gen_dummy_snark::>(&voter_params, Some(voter_pk.get_vk()), voter_config), - gen_dummy_snark::>( - &state_transition_params, - Some(state_transition_pk.get_vk()), - state_transition_config, - ), RecursionCircuit::initial_snark( &recursion_params, None, @@ -103,30 +79,24 @@ fn bench(c: &mut Criterion) { ), 0, recursion_config, + state_transition_inputs[0].clone(), ); let pk = gen_pk(&recursion_params, &recursion_circuit); let config_params = recursion_circuit.inner().params(); let mut group = c.benchmark_group("plonk-prover"); group.sample_size(10); + group.bench_with_input( BenchmarkId::new("wrapper circuit", K), - &( - &recursion_params, - &pk, - &voter_snark, - &state_transition_snark, - ), - |bencher, &(params, pk, voter_snark, state_transition_snark)| { + &(&recursion_params, &pk, &voter_snark), + |bencher, &(params, pk, voter_snark)| { let cloned_voter_snark = voter_snark; - let cloned_state_transition_snark = state_transition_snark; bencher.iter(|| { let cloned_config_params = config_params.clone(); let circuit = RecursionCircuit::new( - CircuitBuilderStage::Prover, ¶ms, cloned_voter_snark.clone(), - cloned_state_transition_snark.clone(), RecursionCircuit::initial_snark( ¶ms, None, @@ -135,9 +105,12 @@ fn bench(c: &mut Criterion) { ), 0, cloned_config_params, + state_transition_inputs[0].clone(), ); - - gen_proof(params, pk, circuit); + println!("reached proof generation"); + let instances = circuit.inner().instances().clone(); + gen_proof(params, pk, circuit, instances); + println!("completed proof generation") }) }, ); diff --git a/aggregator/src/state_transition.rs b/aggregator/src/state_transition.rs index eb88ea8..33e751d 100644 --- a/aggregator/src/state_transition.rs +++ b/aggregator/src/state_transition.rs @@ -1,32 +1,25 @@ -use halo2_base::gates::circuit::builder::BaseCircuitBuilder; -use halo2_base::gates::circuit::{BaseCircuitParams, BaseConfig}; -use halo2_base::gates::GateInstructions; -use halo2_base::halo2_proofs::circuit::{Layouter, SimpleFloorPlanner}; -use halo2_base::halo2_proofs::plonk::{Circuit, ConstraintSystem, Error}; +use biguint_halo2::big_uint::chip::BigUintChip; +use biguint_halo2::big_uint::{AssignedBigUint, Fresh}; +use halo2_base::halo2_proofs::halo2curves::secp256k1::Secp256k1Affine; use halo2_base::poseidon::hasher::{spec::OptimizedPoseidonSpec, PoseidonHasher}; +use halo2_base::utils::fe_to_biguint; use halo2_base::{ gates::{RangeChip, RangeInstructions}, halo2_proofs::circuit::Value, utils::BigPrimeField, AssignedValue, Context, }; -use halo2_ecc::ecc::EccChip; -use halo2_ecc::fields::fp::FpChip; +use halo2_ecc::bigint::OverflowInteger; use indexed_merkle_tree_halo2::indexed_merkle_tree::{insert_leaf, IndexedMerkleTreeLeaf}; use indexed_merkle_tree_halo2::utils::IndexedMerkleTreeLeaf as IMTLeaf; use num_bigint::BigUint; - -use biguint_halo2::big_uint::chip::BigUintChip; -use halo2_base::halo2_proofs::halo2curves::secp256k1::{Fp, Secp256k1Affine}; use paillier_chip::paillier::{EncryptionPublicKeyAssigned, PaillierChip}; use serde::{Deserialize, Serialize}; -use voter::{compress_nullifier, CircuitExt, EncryptionPublicKey}; +use voter::EncryptionPublicKey; const ENC_BIT_LEN: usize = 176; const LIMB_BIT_LEN: usize = 88; -//TODO: Constrain the nullifier hash using x and y limbs - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct IndexedMerkleTreeInput { old_root: F, @@ -38,7 +31,6 @@ pub struct IndexedMerkleTreeInput { new_leaf_index: F, new_leaf_proof: Vec, new_leaf_proof_helper: Vec, - is_new_leaf_largest: F, } impl IndexedMerkleTreeInput { @@ -52,7 +44,6 @@ impl IndexedMerkleTreeInput { new_leaf_index: F, new_leaf_proof: Vec, new_leaf_proof_helper: Vec, - is_new_leaf_largest: F, ) -> Self { Self { old_root, @@ -64,11 +55,21 @@ impl IndexedMerkleTreeInput { new_leaf_index, new_leaf_proof, new_leaf_proof_helper, - is_new_leaf_largest, } } + pub fn get_old_root(&self) -> F { + self.old_root + } + pub fn get_new_root(&self) -> F { + self.new_root + } +} +fn limbs_to_biguint(x: Vec) -> BigUint { + x.iter() + .enumerate() + .map(|(i, limb)| fe_to_biguint(limb) * BigUint::from(2u64).pow(88 * (i as u32))) + .sum() } - #[derive(Debug, Clone)] pub struct StateTransitionInput { pub pk_enc: EncryptionPublicKey, @@ -98,7 +99,8 @@ impl StateTransitionInput { pub fn state_transition_circuit( ctx: &mut Context, range: &RangeChip, - input: StateTransitionInput, + nullifier_tree: IndexedMerkleTreeInput, + input_vec: Vec>, public_inputs: &mut Vec>, ) { let gate = range.gate(); @@ -108,44 +110,72 @@ pub fn state_transition_circuit( let biguint_chip = BigUintChip::construct(range, LIMB_BIT_LEN); let paillier_chip = PaillierChip::construct(&biguint_chip, ENC_BIT_LEN); - let fp_chip = FpChip::::new(range, LIMB_BIT_LEN, 3); - let ecc_chip = EccChip::>::new(&fp_chip); + let nullifier_hash = hasher.hash_fix_len_array(ctx, gate, &input_vec[0..4]); - let nullifier = ecc_chip.load_private_unchecked(ctx, (input.nullifier.x, input.nullifier.y)); - let compressed_nullifier = compress_nullifier(ctx, range, &nullifier); - let nullifier_hash = hasher.hash_fix_len_array(ctx, gate, &compressed_nullifier); - - let n_assigned = biguint_chip - .assign_integer(ctx, Value::known(input.pk_enc.n.clone()), ENC_BIT_LEN) - .unwrap(); - - let g_assigned = biguint_chip - .assign_integer(ctx, Value::known(input.pk_enc.g.clone()), ENC_BIT_LEN) - .unwrap(); + let n_fr = input_vec[4..6] + .iter() + .map(|vote| *vote.value()) + .collect::>(); + let n_assigned = AssignedBigUint::::new( + OverflowInteger::new(input_vec[4..6].to_vec(), 88), + Value::known(limbs_to_biguint(n_fr)), + ); + let g_fr = input_vec[6..8] + .iter() + .map(|vote| *vote.value()) + .collect::>(); + let g_assigned = AssignedBigUint::::new( + OverflowInteger::new(input_vec[6..8].to_vec(), 88), + Value::known(limbs_to_biguint(g_fr)), + ); let pk_enc = EncryptionPublicKeyAssigned { n: n_assigned, g: g_assigned, }; - - let incoming_vote = input - .incoming_vote + let incoming_vote_fr: Vec = input_vec[8..28].iter().map(|x| *x.value()).collect(); + let incoming_vote_biguint = incoming_vote_fr + .chunks(4) + .into_iter() + .map(|chunk| limbs_to_biguint(chunk.to_vec())) + .collect::>(); + let incoming_vote_overflow_int = input_vec[8..28] + .chunks(4) + .into_iter() + .map(|chunk| OverflowInteger::new(chunk.to_vec(), 88)) + .collect::>>(); + let incoming_vote = incoming_vote_overflow_int .iter() - .map(|x| { - biguint_chip - .assign_integer(ctx, Value::known(x.clone()), ENC_BIT_LEN * 2) - .unwrap() + .enumerate() + .map(|(i, over_flow)| { + AssignedBigUint::new( + over_flow.clone(), + Value::known(incoming_vote_biguint[i].clone()), + ) }) - .collect::>(); - let prev_vote = input - .prev_vote + .collect::>>(); + + let prev_vote_fr: Vec = input_vec[28..48].iter().map(|x| *x.value()).collect(); + let prev_vote_biguint = prev_vote_fr + .chunks(4) + .into_iter() + .map(|chunk| limbs_to_biguint(chunk.to_vec())) + .collect::>(); + let prev_vote_overflow_int = input_vec[28..48] + .chunks(4) + .into_iter() + .map(|chunk| OverflowInteger::new(chunk.to_vec(), 88)) + .collect::>>(); + let prev_vote = prev_vote_overflow_int .iter() - .map(|x| { - biguint_chip - .assign_integer(ctx, Value::known(x.clone()), ENC_BIT_LEN * 2) - .unwrap() + .enumerate() + .map(|(i, over_flow)| { + AssignedBigUint::new( + over_flow.clone(), + Value::known(prev_vote_biguint[i].clone()), + ) }) - .collect::>(); + .collect::>>(); // Step 1: Aggregate the votes let aggr_vote = incoming_vote @@ -155,46 +185,41 @@ pub fn state_transition_circuit( .collect::>(); // Step 2: Update the nullifier tree - let val = ctx.load_witness(input.nullifier_tree.low_leaf.val); - let next_val = ctx.load_witness(input.nullifier_tree.low_leaf.next_val); - let next_idx = ctx.load_witness(input.nullifier_tree.low_leaf.next_idx); + let val = ctx.load_witness(nullifier_tree.low_leaf.val); + let next_val = ctx.load_witness(nullifier_tree.low_leaf.next_val); + let next_idx = ctx.load_witness(nullifier_tree.low_leaf.next_idx); - let old_root = ctx.load_witness(input.nullifier_tree.old_root); + let old_root = ctx.load_witness(nullifier_tree.old_root); let low_leaf = IndexedMerkleTreeLeaf::new(val, next_val, next_idx); - let new_root = ctx.load_witness(input.nullifier_tree.new_root); + let new_root = ctx.load_witness(nullifier_tree.new_root); - let val = ctx.load_witness(input.nullifier_tree.new_leaf.val); - assert_eq!(val.value(), nullifier_hash.value()); + let val = ctx.load_witness(nullifier_tree.new_leaf.val); + //assert_eq!(val.value(), nullifier_hash.value()); ctx.constrain_equal(&val, &nullifier_hash); - let next_val = ctx.load_witness(input.nullifier_tree.new_leaf.next_val); - let next_idx = ctx.load_witness(input.nullifier_tree.new_leaf.next_idx); + let next_val = ctx.load_witness(nullifier_tree.new_leaf.next_val); + let next_idx = ctx.load_witness(nullifier_tree.new_leaf.next_idx); let new_leaf = IndexedMerkleTreeLeaf::new(val, next_val, next_idx); - let new_leaf_index = ctx.load_witness(input.nullifier_tree.new_leaf_index); - let is_new_leaf_largest = ctx.load_witness(input.nullifier_tree.is_new_leaf_largest); + let new_leaf_index = ctx.load_witness(nullifier_tree.new_leaf_index); - let low_leaf_proof = input - .nullifier_tree + let low_leaf_proof = nullifier_tree .low_leaf_proof .iter() .map(|x| ctx.load_witness(*x)) .collect::>(); - let low_leaf_proof_helper = input - .nullifier_tree + let low_leaf_proof_helper = nullifier_tree .low_leaf_proof_helper .iter() .map(|x| ctx.load_witness(*x)) .collect::>(); - let new_leaf_proof = input - .nullifier_tree + let new_leaf_proof = nullifier_tree .new_leaf_proof .iter() .map(|x| ctx.load_witness(*x)) .collect::>(); - let new_leaf_proof_helper = input - .nullifier_tree + let new_leaf_proof_helper = nullifier_tree .new_leaf_proof_helper .iter() .map(|x| ctx.load_witness(*x)) @@ -213,142 +238,15 @@ pub fn state_transition_circuit( &new_leaf_index, &new_leaf_proof, &new_leaf_proof_helper, - &is_new_leaf_largest, ); - // PK_ENC N - public_inputs.extend(pk_enc.n.limbs()); - - // PK_ENC G - public_inputs.extend(pk_enc.g.limbs()); - - // PREV_VOTE - for enc_vote in prev_vote { - public_inputs.extend(enc_vote.limbs()); - } - - // INCOMING_VOTE - for enc_vote in incoming_vote { - public_inputs.extend(enc_vote.limbs()); - } - - // AGGR_VOTE for enc_vote in aggr_vote { public_inputs.extend(enc_vote.limbs()); } - // NULLIFIER - public_inputs.extend(compressed_nullifier); - // NULLIFIER_OLD_ROOT public_inputs.extend([old_root]); // NULLIFIER_NEW_ROOT public_inputs.extend([new_root]); } - -pub struct StateTransitionCircuit { - input: StateTransitionInput, - pub inner: BaseCircuitBuilder, -} - -impl StateTransitionCircuit { - pub fn new(config: BaseCircuitParams, input: StateTransitionInput) -> Self { - let mut inner = BaseCircuitBuilder::default(); - inner.set_params(config); - - let range = inner.range_chip(); - let ctx = inner.main(0); - - let mut public_inputs = Vec::>::new(); - state_transition_circuit(ctx, &range, input.clone(), &mut public_inputs); - inner.assigned_instances[0].extend(public_inputs); - inner.calculate_params(Some(10)); - Self { input, inner } - } -} - -impl Circuit for StateTransitionCircuit { - type Config = BaseConfig; - type FloorPlanner = SimpleFloorPlanner; - type Params = BaseCircuitParams; - - fn params(&self) -> Self::Params { - self.inner.params() - } - - fn without_witnesses(&self) -> Self { - unimplemented!() - } - - fn configure_with_params(meta: &mut ConstraintSystem, params: Self::Params) -> Self::Config { - BaseCircuitBuilder::configure_with_params(meta, params) - } - - fn configure(_: &mut ConstraintSystem) -> Self::Config { - unreachable!() - } - - fn synthesize(&self, config: Self::Config, layouter: impl Layouter) -> Result<(), Error> { - self.inner.synthesize(config, layouter) - } -} - -impl CircuitExt for StateTransitionCircuit { - fn num_instance() -> Vec { - vec![70] - } - - fn instances(&self) -> Vec> { - vec![self.inner.assigned_instances[0] - .iter() - .map(|instance| *instance.value()) - .collect()] - } -} - -#[cfg(test)] -mod test { - use halo2_base::{ - gates::circuit::BaseCircuitParams, - halo2_proofs::{dev::MockProver, halo2curves::bn256::Fr}, - utils::testing::base_test, - AssignedValue, - }; - use voter::CircuitExt; - - use crate::utils::generate_wrapper_circuit_input; - - use super::{state_transition_circuit, StateTransitionCircuit}; - - #[test] - fn test_state_transition_circuit() { - let (_, multiple_input) = generate_wrapper_circuit_input(4); - - let config = BaseCircuitParams { - k: 15, - num_advice_per_phase: vec![3], - num_lookup_advice_per_phase: vec![1, 0, 0], - num_fixed: 1, - lookup_bits: Some(14), - num_instance_columns: 1, - }; - - for (round, input) in multiple_input.iter().enumerate() { - println!("------round[{}]--------", round); - - let circuit = StateTransitionCircuit::new(config.clone(), input.clone()); - let prover = MockProver::run(15, &circuit, circuit.instances()).unwrap(); - prover.verify().unwrap(); - } - - base_test() - .k(19) - .lookup_bits(18) - .expect_satisfied(true) - .run(|ctx, range| { - let mut public_inputs = Vec::>::new(); - state_transition_circuit(ctx, range, multiple_input[0].clone(), &mut public_inputs) - }); - } -} diff --git a/aggregator/src/utils.rs b/aggregator/src/utils.rs index d633ee6..d93136c 100644 --- a/aggregator/src/utils.rs +++ b/aggregator/src/utils.rs @@ -3,7 +3,8 @@ use halo2_base::halo2_proofs::halo2curves::bn256::Fr; use halo2_base::halo2_proofs::halo2curves::group::Curve; use halo2_base::halo2_proofs::halo2curves::secp256k1::{Fq, Secp256k1, Secp256k1Affine}; use halo2_base::halo2_proofs::halo2curves::secq256k1::Fp; -use halo2_base::utils::{fe_to_biguint, ScalarField}; +use halo2_base::utils::{fe_to_biguint, BigPrimeField, ScalarField}; +use halo2_base::{AssignedValue, Context}; use halo2_ecc::*; use num_bigint::{BigUint, RandBigInt}; use paillier_chip::paillier::{paillier_add_native, paillier_enc_native}; @@ -32,7 +33,7 @@ fn generate_voter_circuit_inputs( pk_voter: Secp256k1Affine, vote: Vec, r_enc: Vec, - members_tree: &MerkleTree<'_, Fr, T, RATE>, + members_tree: &IndexedMerkleTree, round: usize, ) -> VoterCircuitInput { let membership_root = members_tree.get_root(); @@ -115,7 +116,11 @@ fn generate_state_transition_circuit_inputs( let new_val = native_hasher.squeeze_and_reset(); let mut tree = - IndexedMerkleTree::::new(&mut native_hasher, leaves.clone()).unwrap(); + IndexedMerkleTree::::new_default_leaf(nullifier_tree_preimages.len()); + + for i in 0..nullifier_tree_leaves.len() { + tree.insert_leaf(&mut native_hasher, nullifier_tree_leaves[i], i); + } let old_root = tree.get_root(); @@ -123,9 +128,10 @@ fn generate_state_transition_circuit_inputs( update_idx_leaf(nullifier_tree_preimages.clone(), new_val, round); let low_leaf = nullifier_tree_preimages[low_leaf_idx].clone(); let (low_leaf_proof, low_leaf_proof_helper) = tree.get_proof(low_leaf_idx); + assert_eq!( tree.verify_proof( - &leaves[low_leaf_idx], + &mut native_hasher, low_leaf_idx, &tree.get_root(), &low_leaf_proof @@ -147,11 +153,12 @@ fn generate_state_transition_circuit_inputs( updated_idx_leaves[round as usize].next_idx, ]); leaves[round as usize] = new_native_hasher.squeeze_and_reset(); - tree = IndexedMerkleTree::::new(&mut new_native_hasher, leaves.clone()).unwrap(); + let (new_leaf_proof, new_leaf_proof_helper) = tree.get_proof(round as usize); assert_eq!( tree.verify_proof( - &leaves[round as usize], + &mut new_native_hasher, + // &leaves[round as usize], round as usize, &tree.get_root(), &new_leaf_proof @@ -166,11 +173,6 @@ fn generate_state_transition_circuit_inputs( next_idx: updated_idx_leaves[round as usize].next_idx, }; let new_leaf_index = Fr::from(round); - let is_new_leaf_largest = if new_leaf.next_val == Fr::zero() { - Fr::one() - } else { - Fr::zero() - }; let idx_input = IndexedMerkleTreeInput::new( old_root, @@ -182,7 +184,6 @@ fn generate_state_transition_circuit_inputs( new_leaf_index, new_leaf_proof, new_leaf_proof_helper, - is_new_leaf_largest, ); let input = StateTransitionInput::new( @@ -253,12 +254,17 @@ pub fn generate_wrapper_circuit_input( members_tree_leaves.push(native_hasher.squeeze_and_reset()); } + let mut members_tree = IndexedMerkleTree::::new_default_leaf(8); + for _ in num_round..8 { native_hasher.update(&[Fr::ZERO]); - members_tree_leaves.push(native_hasher.squeeze_and_reset()); + let hash = native_hasher.squeeze_and_reset(); + members_tree_leaves.push(hash.clone()); } - let members_tree = MerkleTree::new(&mut native_hasher, members_tree_leaves.clone()).unwrap(); + for i in 0..members_tree_leaves.len() { + members_tree.insert_leaf(&mut native_hasher, members_tree_leaves[i], i); + } let mut rng = thread_rng(); let mut prev_vote = Vec::::new(); @@ -383,6 +389,7 @@ pub fn generate_random_state_transition_circuit_inputs() -> StateTransitionInput let nullifier_affine = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); let mut native_hasher = Poseidon::::new(R_F, R_P); + let mut tree = IndexedMerkleTree::::new_default_leaf(tree_size as usize); // Filling leaves with dfault values. for _ in 0..tree_size { @@ -392,8 +399,8 @@ pub fn generate_random_state_transition_circuit_inputs() -> StateTransitionInput let nullifier_compress = compress_native_nullifier(&nullifier_affine); native_hasher.update(&nullifier_compress); let new_val = native_hasher.squeeze_and_reset(); - let mut tree = - IndexedMerkleTree::::new(&mut native_hasher, leaves.clone()).unwrap(); + + tree.insert_leaf(&mut native_hasher, leaves[0], 0); let old_root = tree.get_root(); let low_leaf = IMTLeaf:: { @@ -402,8 +409,9 @@ pub fn generate_random_state_transition_circuit_inputs() -> StateTransitionInput next_idx: Fr::from(0u64), }; let (low_leaf_proof, low_leaf_proof_helper) = tree.get_proof(0); + assert_eq!( - tree.verify_proof(&leaves[0], 0, &tree.get_root(), &low_leaf_proof), + tree.verify_proof(&mut native_hasher, 0, &tree.get_root(), &low_leaf_proof), true ); @@ -423,12 +431,9 @@ pub fn generate_random_state_transition_circuit_inputs() -> StateTransitionInput native_hasher.update(&[new_val, Fr::from(0u64), Fr::from(0u64)]); leaves[1] = native_hasher.squeeze_and_reset(); - tree = IndexedMerkleTree::::new(&mut native_hasher, leaves.clone()).unwrap(); - - let (new_low_leaf_proof, _) = tree.get_proof(0); let (new_leaf_proof, new_leaf_proof_helper) = tree.get_proof(1); assert_eq!( - tree.verify_proof(&leaves[1], 1, &tree.get_root(), &new_leaf_proof), + tree.verify_proof(&mut native_hasher, 1, &tree.get_root(), &new_leaf_proof), true ); @@ -439,7 +444,6 @@ pub fn generate_random_state_transition_circuit_inputs() -> StateTransitionInput next_idx: Fr::from(0u64), }; let new_leaf_index = Fr::from(1u64); - let is_new_leaf_largest = Fr::from(true); let idx_input = IndexedMerkleTreeInput::new( old_root, @@ -451,7 +455,6 @@ pub fn generate_random_state_transition_circuit_inputs() -> StateTransitionInput new_leaf_index, new_leaf_proof, new_leaf_proof_helper, - is_new_leaf_largest, ); let mut rng = thread_rng(); @@ -493,3 +496,14 @@ pub fn generate_random_state_transition_circuit_inputs() -> StateTransitionInput input } + +pub fn assign_big_uint( + ctx: &mut Context, + value: &BigUint, +) -> Vec> { + let limbs = value.to_u64_digits(); + limbs + .into_iter() + .map(|limb| ctx.load_witness(F::from(limb))) + .collect() +} diff --git a/aggregator/src/wrapper.rs b/aggregator/src/wrapper.rs index 0353862..61dae26 100644 --- a/aggregator/src/wrapper.rs +++ b/aggregator/src/wrapper.rs @@ -277,13 +277,11 @@ pub mod common { } pub mod recursion { - use std::mem; + use std::{mem, vec}; use halo2_base::{ gates::{ - circuit::{ - builder::BaseCircuitBuilder, BaseCircuitParams, BaseConfig, CircuitBuilderStage, - }, + circuit::{builder::BaseCircuitBuilder, BaseCircuitParams, BaseConfig}, GateInstructions, RangeInstructions, }, AssignedValue, @@ -292,7 +290,7 @@ pub mod recursion { use snark_verifier_sdk::snark_verifier::loader::halo2::{EccInstructions, IntegerInstructions}; use voter::{CircuitExt, VoterCircuit}; - use crate::state_transition::StateTransitionCircuit; + use crate::state_transition::{state_transition_circuit, StateTransitionInput}; use super::*; @@ -409,7 +407,6 @@ pub mod recursion { svk: Svk, default_accumulator: KzgAccumulator, voter: Snark, - state_transition: Snark, previous: Snark, #[allow(dead_code)] round: usize, @@ -431,13 +428,12 @@ pub mod recursion { const ROUND_ROW: usize = 4 * LIMBS + 29; pub fn new( - stage: CircuitBuilderStage, params: &ParamsKZG, voter: Snark, - state_transition: Snark, previous: Snark, round: usize, config_params: BaseCircuitParams, + state_transition_input: StateTransitionInput, ) -> Self { let svk = params.get_g()[0].into(); let default_accumulator = KzgAccumulator::new(params.get_g()[1], params.get_g()[0]); @@ -459,7 +455,6 @@ pub mod recursion { let accumulators = iter::empty() .chain(succinct_verify(&voter)) - .chain(succinct_verify(&state_transition)) .chain( (round > 0) .then(|| succinct_verify(&previous)) @@ -490,7 +485,6 @@ pub mod recursion { .collect_vec(); poseidon(&NativeLoader, &inputs) }; - let mut current_instances = [ voter.instances[0][0], voter.instances[0][1], @@ -498,14 +492,13 @@ pub mod recursion { voter.instances[0][3], ] .to_vec(); - current_instances.extend(state_transition.instances[0][44..64].iter()); + current_instances.extend(voter.instances[0][4..24].iter()); current_instances.extend([ - state_transition.instances[0][68], - state_transition.instances[0][69], + state_transition_input.nullifier_tree.get_old_root(), + state_transition_input.nullifier_tree.get_new_root(), voter.instances[0][28], voter.instances[0][29], ]); - let instances = [ accumulator.lhs.x, accumulator.lhs.y, @@ -519,25 +512,27 @@ pub mod recursion { .chain([Fr::from(round as u64)]) .collect(); - let inner = BaseCircuitBuilder::from_stage(stage).use_params(config_params); + let inner = BaseCircuitBuilder::new(false).use_params(config_params); + println!("reached here"); let mut circuit = Self { svk, default_accumulator, voter, - state_transition, previous, round, instances, as_proof, inner, }; - circuit.build(); + + circuit.build(state_transition_input); circuit } - fn build(&mut self) { + fn build(&mut self, state_transition_input: StateTransitionInput) { let range = self.inner.range_chip(); let main_gate = range.gate(); + let pool = self.inner.pool(0); let preprocessed_digest = @@ -550,10 +545,7 @@ pub mod recursion { .iter() .map(|instance| main_gate.assign_integer(pool, *instance)) .collect::>(); - let aggr_vote = self.instances[Self::VOTE_ROW..Self::VOTE_ROW + 20] - .iter() - .map(|instance| main_gate.assign_integer(pool, *instance)) - .collect::>(); + let nullifier_old_root = main_gate.assign_integer(pool, self.instances[Self::NULLIFIER_OLD_ROOT_ROW]); let nullifier_new_root = @@ -572,14 +564,14 @@ pub mod recursion { let (mut voter_instances, voter_accumulators) = succinct_verify(&self.svk, &loader, &self.voter, None); - let (mut state_transition_instances, state_transition_accumulators) = - succinct_verify(&self.svk, &loader, &self.state_transition, None); + println!("voter proof verified"); let (mut previous_instances, previous_accumulators) = succinct_verify( &self.svk, &loader, &self.previous, Some(preprocessed_digest), ); + println!("previous proof verified"); let default_accmulator = self.load_default_accumulator(&loader).unwrap(); let previous_accumulators = previous_accumulators @@ -597,23 +589,48 @@ pub mod recursion { let KzgAccumulator { lhs, rhs } = accumulate( &loader, - [ - voter_accumulators, - state_transition_accumulators, - previous_accumulators, - ] - .concat(), + [voter_accumulators, previous_accumulators].concat(), self.as_proof(), ); let lhs = lhs.into_assigned(); let rhs = rhs.into_assigned(); let voter_instances = voter_instances.pop().unwrap(); - let state_transition_instances = state_transition_instances.pop().unwrap(); let previous_instances = previous_instances.pop().unwrap(); let mut pool = loader.take_ctx(); let ctx = pool.main(); + + let mut state_transition_vec = Vec::>::new(); + + //nullifier + state_transition_vec.extend(voter_instances[24..28].iter()); + //n,g + state_transition_vec.extend(voter_instances[0..4].iter()); + //incoming_vote + state_transition_vec.extend(voter_instances[4..24].iter()); + //previous_vote + state_transition_vec + .extend(previous_instances[4 * LIMBS + 5..4 * LIMBS + 5 + 20].iter()); + + let mut public_inputs = Vec::>::new(); + state_transition_circuit( + ctx, + &range, + state_transition_input.nullifier_tree, + state_transition_vec, + &mut public_inputs, + ); + let aggr_vote_fr = public_inputs[0..20] + .iter() + .map(|x| *x.value()) + .collect::>(); + + self.instances[Self::VOTE_ROW..Self::VOTE_ROW + 20].copy_from_slice(&aggr_vote_fr); + let aggr_vote = public_inputs[0..20].to_vec(); + + //updated state transition public inputs + for (lhs, rhs) in [ // Propagate preprocessed_digest ( @@ -629,66 +646,18 @@ pub mod recursion { ctx.constrain_equal(lhs, rhs); } - // state_transition(pk_enc) == previous(pk_enc) == voter(pk_enc) - for i in 0..4 { - // assert_eq!( - // state_transition_instances[i].value(), - // voter_instances[i].value() - // ); - ctx.constrain_equal(&state_transition_instances[i], &voter_instances[i]); - - // assert_eq!( - // state_transition_instances[i].value(), - // previous_instances[4 * LIMBS + i + 1].value() - // ); - ctx.constrain_equal( - &state_transition_instances[i], - &previous_instances[4 * LIMBS + i + 1], - ); - } - - // state_transition(prev_vote) == previous(aggr_vote) - for i in 0..20 { - // assert_eq!( - // state_transition_instances[i + 4].value(), - // previous_instances[4 * LIMBS + i + 1 + 4].value() - // ); - ctx.constrain_equal( - &state_transition_instances[i + 4], - &previous_instances[4 * LIMBS + i + 1 + 4], - ); - } - - // state_transition(incoming_vote) == voter(vote) - for i in 0..20 { - // assert_eq!( - // state_transition_instances[i + 24].value(), - // voter_instances[i + 4].value() - // ); - ctx.constrain_equal(&state_transition_instances[i + 24], &voter_instances[i + 4]); - } - - // state_transition(nullifier) == voter(nullifier) + // previous(pk_enc) == voter(pk_enc) for i in 0..4 { - // assert_eq!( - // state_transition_instances[i + 64].value(), - // voter_instances[i + 24].value() - // ); - ctx.constrain_equal( - &state_transition_instances[i + 64], - &voter_instances[i + 24], - ); + // assert_eq!(previous_instances[i].value(), voter_instances[i].value()); + ctx.constrain_equal(&previous_instances[4 * LIMBS + i + 1], &voter_instances[i]); } // state_transition(nullifier_old_root) == previous(nullifier_new_root) // assert_eq!( - // state_transition_instances[68].value(), + // public_inputs[68].value(), // previous_instances[4 * LIMBS + 1 + 25].value() // ); - ctx.constrain_equal( - &state_transition_instances[68], - &previous_instances[4 * LIMBS + 1 + 25], - ); + ctx.constrain_equal(&public_inputs[20], &previous_instances[4 * LIMBS + 1 + 25]); // previous(membership_root]) == voter(membership_root) // assert_eq!( @@ -732,9 +701,6 @@ pub mod recursion { ) .copied(), ); - - self.inner.calculate_params(Some(10)); - println!("recursion params: {:?}", self.inner.params()); } pub fn initial_snark( @@ -833,24 +799,16 @@ pub mod recursion { pub fn gen_recursion_pk( voter_params: &ParamsKZG, - state_transition_params: &ParamsKZG, recursion_params: &ParamsKZG, voter_vk: &VerifyingKey, - state_transition_vk: &VerifyingKey, voter_config: BaseCircuitParams, - state_transition_config: BaseCircuitParams, recursion_config: BaseCircuitParams, init_aggr_instances: Vec, + state_transition_input: StateTransitionInput, ) -> ProvingKey { let recursion = RecursionCircuit::new( - CircuitBuilderStage::Keygen, recursion_params, gen_dummy_snark::>(voter_params, Some(voter_vk), voter_config), - gen_dummy_snark::>( - state_transition_params, - Some(state_transition_vk), - state_transition_config, - ), RecursionCircuit::initial_snark( recursion_params, None, @@ -858,22 +816,23 @@ pub mod recursion { init_aggr_instances, ), 0, - recursion_config, + recursion_config.clone(), + state_transition_input, ); // we cannot auto-configure the circuit because dummy_snark must know the configuration beforehand // uncomment the following line only in development to test and print out the optimal configuration ahead of time // recursion.inner.0.builder.borrow().config(recursion_params.k() as usize, Some(10)); + gen_pk(recursion_params, &recursion) } pub fn gen_recursion_snark( - stage: CircuitBuilderStage, recursion_params: &ParamsKZG, recursion_pk: &ProvingKey, recursion_config: BaseCircuitParams, voter_snarks: Vec, - state_transition_snarks: Vec, init_aggr_instances: Vec, + state_transition_inputs: Vec>, ) -> Snark { let mut previous = RecursionCircuit::initial_snark( recursion_params, @@ -881,19 +840,14 @@ pub mod recursion { recursion_config.clone(), init_aggr_instances, ); - for (round, (voter, state_transition)) in voter_snarks - .into_iter() - .zip(state_transition_snarks) - .enumerate() - { + for (round, voter) in voter_snarks.into_iter().enumerate() { let recursion = RecursionCircuit::new( - stage, recursion_params, voter, - state_transition, previous, round, recursion_config.clone(), + state_transition_inputs[round].clone(), ); println!("Generate recursion snark for round {}", round); previous = gen_snark(recursion_params, recursion_pk, recursion); @@ -905,11 +859,13 @@ pub mod recursion { #[cfg(test)] mod test { use std::path::{Path, PathBuf}; + use std::time::Instant; use std::{fs, io::BufReader}; use ark_std::{end_timer, start_timer}; + use halo2_base::utils::decompose_biguint; use halo2_base::{ - gates::circuit::{builder::BaseCircuitBuilder, BaseCircuitParams, CircuitBuilderStage}, + gates::circuit::{builder::BaseCircuitBuilder, BaseCircuitParams}, halo2_proofs::{ halo2curves::bn256::{Fr, G1Affine}, plonk::ProvingKey, @@ -920,7 +876,7 @@ mod test { use snark_verifier_sdk::{snark_verifier::verifier::SnarkVerifier, NativeLoader}; use voter::VoterCircuit; - use crate::{state_transition::StateTransitionCircuit, utils::generate_wrapper_circuit_input}; + use crate::utils::generate_wrapper_circuit_input; use super::{ gen_pk, gen_snark, @@ -943,10 +899,9 @@ mod test { #[test] fn test_recursion() { const GEN_VOTER_PK: bool = true; - const GEN_STATE_TRANSITION_PK: bool = true; const GEN_RECURSION_PK: bool = true; - - let num_round = 7; + println!("STOP UNDO"); + let num_round = 3; let (voter_input, state_transition_input) = generate_wrapper_circuit_input(num_round); @@ -990,55 +945,6 @@ mod test { println!("Generating voter snark"); let voter_snark = gen_snark(&voter_params, &voter_pk, voter_circuit); - let state_transition_config = BaseCircuitParams { - k: 15, - num_advice_per_phase: vec![3], - num_lookup_advice_per_phase: vec![1, 0, 0], - num_fixed: 1, - lookup_bits: Some(14), - num_instance_columns: 1, - }; - let state_transition_params = gen_srs(15); - let state_transition_circuit = StateTransitionCircuit::new( - state_transition_config.clone(), - state_transition_input[0].clone(), - ); - let state_transition_pk: ProvingKey; - if GEN_STATE_TRANSITION_PK { - println!("Generating state transition pk"); - state_transition_pk = gen_pk(&state_transition_params, &state_transition_circuit); - let mut state_transition_pk_bytes = Vec::new(); - state_transition_pk - .write( - &mut state_transition_pk_bytes, - halo2_base::halo2_proofs::SerdeFormat::RawBytesUnchecked, - ) - .unwrap(); - - fs::write( - build_dir.join("state_transition_pk.bin"), - state_transition_pk_bytes, - ) - .unwrap(); - } else { - println!("Reading state transition pk"); - let file = fs::read(build_dir.join("state_transition_pk.bin")).unwrap(); - let state_transition_pk_reader = &mut BufReader::new(file.as_slice()); - state_transition_pk = - ProvingKey::::read::, BaseCircuitBuilder>( - state_transition_pk_reader, - halo2_base::halo2_proofs::SerdeFormat::RawBytesUnchecked, - state_transition_config.clone(), - ) - .unwrap(); - } - println!("Generating state transition snark"); - let state_transition_snark = gen_snark( - &state_transition_params, - &state_transition_pk, - state_transition_circuit, - ); - let k = 22; let recursion_config = BaseCircuitParams { k, @@ -1050,6 +956,11 @@ mod test { }; let recursion_params = gen_srs(k as u32); + let mut init_vote = Vec::::new(); + for x in state_transition_input[0].prev_vote.iter() { + init_vote.extend(decompose_biguint::(x, 4, 88)); + } + // Init Base Instances let mut base_instances = [ Fr::zero(), // preprocessed_digest @@ -1059,29 +970,28 @@ mod test { voter_snark.instances[0][3], ] .to_vec(); - base_instances.extend(state_transition_snark.instances[0][4..24].iter()); // init_vote + base_instances.extend(init_vote); // init_vote base_instances.extend([ - state_transition_snark.instances[0][68], // nullifier_old_root - state_transition_snark.instances[0][68], // nullifier_new_root - voter_snark.instances[0][28], // membership_root - voter_snark.instances[0][29], // proposal_id - Fr::from(0), // round + state_transition_input[0].nullifier_tree.get_old_root(), // nullifier_old_root + state_transition_input[0].nullifier_tree.get_old_root(), // nullifier_new_root + voter_snark.instances[0][28], // membership_root + voter_snark.instances[0][29], // proposal_id + Fr::from(0), // round ]); let pk_time = start_timer!(|| "Generate recursion pk"); + let recursion_pk: ProvingKey; if GEN_RECURSION_PK { println!("Generating recursion pk"); recursion_pk = recursion::gen_recursion_pk( &voter_params, - &state_transition_params, &recursion_params, voter_pk.get_vk(), - state_transition_pk.get_vk(), voter_config.clone(), - state_transition_config.clone(), recursion_config.clone(), base_instances.clone(), + state_transition_input[0].clone(), ); let mut recursion_pk_bytes = Vec::new(); recursion_pk @@ -1107,33 +1017,22 @@ mod test { end_timer!(pk_time); let mut voter_snarks = vec![voter_snark]; - let mut state_transition_snarks = vec![state_transition_snark]; for i in 1..num_round { let voter_circuit = VoterCircuit::new(voter_config.clone(), voter_input[i].clone()); voter_snarks.push(gen_snark(&voter_params, &voter_pk, voter_circuit)); - - let state_transition_circuit = StateTransitionCircuit::new( - state_transition_config.clone(), - state_transition_input[i].clone(), - ); - state_transition_snarks.push(gen_snark( - &state_transition_params, - &state_transition_pk, - state_transition_circuit, - )); } println!("Starting recursion..."); let pf_time = start_timer!(|| "Generate full recursive snark"); + let start = Instant::now(); let final_snark = recursion::gen_recursion_snark( - CircuitBuilderStage::Mock, &recursion_params, &recursion_pk, recursion_config, voter_snarks, - state_transition_snarks, base_instances, + state_transition_input.clone(), ); end_timer!(pf_time); @@ -1156,5 +1055,8 @@ mod test { PlonkVerifier::verify(&dk, &final_snark.protocol, &final_snark.instances, &proof) .unwrap(); } + println!("Time taken for recursion: {:?}", start.elapsed()); } + + //function for Fr to Vec } diff --git a/voter/Cargo.toml b/voter/Cargo.toml index b96f310..b4e9c19 100644 --- a/voter/Cargo.toml +++ b/voter/Cargo.toml @@ -13,6 +13,7 @@ pse-poseidon = { git = "https://github.com/aerius-labs/pse-poseidon.git", branch num-bigint = { version = "0.4.4", features = ["serde"] } itertools = "0.12.0" serde ="1.0.196" +indexed-merkle-tree-halo2 = { git = "https://github.com/aerius-labs/indexed-merkle-tree-halo2.git" , branch = "fix/imt" } [dev-dependencies] criterion = "0.5.1" diff --git a/voter/src/lib.rs b/voter/src/lib.rs index 88de07c..8d22a1d 100644 --- a/voter/src/lib.rs +++ b/voter/src/lib.rs @@ -23,8 +23,9 @@ use halo2_ecc::{ fields::FieldChip, secp256k1::{sha256::Sha256Chip, FpChip, FqChip}, }; +use indexed_merkle_tree_halo2::indexed_merkle_tree::verify_merkle_proof; use itertools::Itertools; -use merkletree::verify_membership_proof; + use num_bigint::BigUint; use biguint_halo2::big_uint::chip::BigUintChip; @@ -226,74 +227,82 @@ pub fn voter_circuit( g: g_assigned, }; + let zero = ctx.load_zero(); + let one = ctx.load_constant(F::from(1)); + // 1. Verify if the voter is in the membership tree - // verify_membership_proof( - // ctx, - // gate, - // &hasher, - // &membership_root, - // &leaf, - // &membership_proof, - // &membership_proof_helper, - // ); + verify_merkle_proof( + ctx, + range, + &hasher, + &membership_root, + &leaf, + &membership_proof, + &membership_proof_helper, + &zero, + &one, + false, + ); // Check to verify correct votes have been passed. - // let _ = vote_assigned_fe.iter().map(|x| gate.assert_bit(ctx, *x)); - // let zero = ctx.load_zero(); - // let one = ctx.load_constant(F::ONE); - // let vote_sum_assigned = vote_assigned_fe - // .iter() - // .fold(zero, |zero, x| gate.add(ctx, zero, *x)); - // ctx.constrain_equal(&vote_sum_assigned, &one); + let _ = vote_assigned_fe.iter().map(|x| gate.assert_bit(ctx, *x)); + let zero = ctx.load_zero(); + let one = ctx.load_constant(F::ONE); + let vote_sum_assigned = vote_assigned_fe + .iter() + .fold(zero, |zero, x| gate.add(ctx, zero, *x)); + ctx.constrain_equal(&vote_sum_assigned, &one); //PK_ENC_n + println!("pk_enc n : {:?}", pk_enc.n.limbs().to_vec()); public_inputs.extend(pk_enc.n.limbs().to_vec()); //PK_ENC_g + println!("pk_enc n : {:?}", pk_enc.n.limbs().to_vec()); public_inputs.extend(pk_enc.g.limbs().to_vec()); // 2. Verify correct vote encryption for i in 0..input.vote.len() { - // let _vote_enc = paillier_chip - // .encrypt(ctx, &pk_enc, &vote_assigned_big[i], &r_assigned[i]) - // .unwrap(); + let _vote_enc = paillier_chip + .encrypt(ctx, &pk_enc, &vote_assigned_big[i], &r_assigned[i]) + .unwrap(); - // biguint_chip - // .assert_equal_fresh(ctx, &vote_enc_assigned_big[i], &_vote_enc) - // .unwrap(); + biguint_chip + .assert_equal_fresh(ctx, &vote_enc_assigned_big[i], &_vote_enc) + .unwrap(); //ENC_VOTE public_inputs.append(&mut vote_enc_assigned_big[i].limbs().to_vec()); } // 3. Verify nullifier - // let message = proposal_id.value().to_bytes_le()[..2] - // .iter() - // .map(|v| ctx.load_witness(F::from(*v as u64))) - // .collect::>(); - // { - // let mut _proposal_id = ctx.load_zero(); - // for i in 0..2 { - // _proposal_id = gate.mul_add( - // ctx, - // message[i], - // QuantumCell::Constant(F::from(1u64 << (8 * i))), - // _proposal_id, - // ); - // } - // ctx.constrain_equal(&_proposal_id, &proposal_id); - // } + let message = proposal_id.value().to_bytes_le()[..2] + .iter() + .map(|v| ctx.load_witness(F::from(*v as u64))) + .collect::>(); + { + let mut _proposal_id = ctx.load_zero(); + for i in 0..2 { + _proposal_id = gate.mul_add( + ctx, + message[i], + QuantumCell::Constant(F::from(1u64 << (8 * i))), + _proposal_id, + ); + } + ctx.constrain_equal(&_proposal_id, &proposal_id); + } let compressed_nullifier = compress_nullifier(ctx, range, &nullifier); - // let plume_input = PlumeInput::new( - // nullifier, - // s_nullifier.clone(), - // c_nullifier, - // pk_voter, - // message, - // ); - // verify_plume(ctx, &ecc_chip, &sha256_chip, 4, 4, plume_input); + let plume_input = PlumeInput::new( + nullifier, + s_nullifier.clone(), + c_nullifier, + pk_voter, + message, + ); + verify_plume(ctx, &ecc_chip, &sha256_chip, 4, 4, plume_input); //NULLIFIER public_inputs.extend(compressed_nullifier.to_vec()); diff --git a/voter_tests/Cargo.toml b/voter_tests/Cargo.toml index 6094f04..bf0add9 100644 --- a/voter_tests/Cargo.toml +++ b/voter_tests/Cargo.toml @@ -17,3 +17,4 @@ serde = "1.0.196" k256 = { version = "0.13.3", features = ["arithmetic", "hash2curve", "expose-field", "sha2"]} num-bigint = { version = "0.4.4", features = ["serde"] } voter = { path = "../voter" } +indexed-merkle-tree-halo2 = { git = "https://github.com/aerius-labs/indexed-merkle-tree-halo2.git" , branch = "fix/imt" } diff --git a/voter_tests/src/lib.rs b/voter_tests/src/lib.rs index 80575ce..411edad 100644 --- a/voter_tests/src/lib.rs +++ b/voter_tests/src/lib.rs @@ -5,6 +5,7 @@ use halo2_base::halo2_proofs::halo2curves::group::Curve; use halo2_base::halo2_proofs::halo2curves::secp256k1::{Fp, Fq, Secp256k1, Secp256k1Affine}; use halo2_base::utils::{fe_to_biguint, ScalarField}; use halo2_ecc::*; +use indexed_merkle_tree_halo2::utils::IndexedMerkleTree; use k256::elliptic_curve::hash2curve::GroupDigest; use k256::elliptic_curve::sec1::ToEncodedPoint; use k256::{ @@ -18,7 +19,7 @@ use rand::rngs::OsRng; use rand::thread_rng; use sha2::{Digest, Sha256}; -use voter::merkletree::native::MerkleTree; +use voter::merkletree::native::{self, MerkleTree}; use voter::{voter_circuit, EncryptionPublicKey, VoterCircuitInput}; pub fn compress_point(point: &Secp256k1Affine) -> [u8; 33] { @@ -158,6 +159,7 @@ pub fn generate_random_voter_circuit_inputs() -> VoterCircuitInput { } let mut native_hasher = Poseidon::::new(R_F, R_P); + let mut membership_tree = IndexedMerkleTree::::new_default_leaf(treesize as usize); let mut leaves = Vec::::new(); @@ -189,18 +191,27 @@ pub fn generate_random_voter_circuit_inputs() -> VoterCircuitInput { native_hasher.update(&[Fr::ZERO]); } leaves.push(native_hasher.squeeze_and_reset()); + membership_tree.insert_leaf(&mut native_hasher, leaves[i as usize], i as usize); } - let mut membership_tree = - MerkleTree::::new(&mut native_hasher, leaves.clone()).unwrap(); - let membership_root = membership_tree.get_root(); let (membership_proof, membership_proof_helper) = membership_tree.get_proof(0); + assert_eq!( - membership_tree.verify_proof(&leaves[0], 0, &membership_root, &membership_proof), + membership_tree.verify_proof(&mut native_hasher, 0, &membership_root, &membership_proof), true ); + // let mut membership_tree = + // MerkleTree::::new(&mut native_hasher, leaves.clone()).unwrap(); + + // let membership_root = membership_tree.get_root(); + // let (membership_proof, membership_proof_helper) = membership_tree.get_proof(0); + // assert_eq!( + // membership_tree.verify_proof(&leaves[0], 0, &membership_root, &membership_proof), + // true + // ); + let pk_enc = EncryptionPublicKey { n, g }; // Proposal id = 1