From fd3724b569348a86659a9f3a725941c0d9f4aa80 Mon Sep 17 00:00:00 2001 From: Vishal Date: Tue, 26 Mar 2024 11:19:48 +0530 Subject: [PATCH 1/7] feat(wip): moved state-transition logic into wrapper circuit --- aggregator/src/state_transition.rs | 78 ++-- aggregator/src/wrapper.rs | 721 ++++++++++++++--------------- 2 files changed, 389 insertions(+), 410 deletions(-) diff --git a/aggregator/src/state_transition.rs b/aggregator/src/state_transition.rs index eb88ea8..74f133a 100644 --- a/aggregator/src/state_transition.rs +++ b/aggregator/src/state_transition.rs @@ -1,26 +1,27 @@ use halo2_base::gates::circuit::builder::BaseCircuitBuilder; -use halo2_base::gates::circuit::{BaseCircuitParams, BaseConfig}; +use halo2_base::gates::circuit::{ BaseCircuitParams, BaseConfig }; use halo2_base::gates::GateInstructions; -use halo2_base::halo2_proofs::circuit::{Layouter, SimpleFloorPlanner}; -use halo2_base::halo2_proofs::plonk::{Circuit, ConstraintSystem, Error}; -use halo2_base::poseidon::hasher::{spec::OptimizedPoseidonSpec, PoseidonHasher}; +use halo2_base::halo2_proofs::circuit::{ Layouter, SimpleFloorPlanner }; +use halo2_base::halo2_proofs::plonk::{ Circuit, ConstraintSystem, Error }; +use halo2_base::poseidon::hasher::{ spec::OptimizedPoseidonSpec, PoseidonHasher }; use halo2_base::{ - gates::{RangeChip, RangeInstructions}, + gates::{ RangeChip, RangeInstructions }, halo2_proofs::circuit::Value, utils::BigPrimeField, - AssignedValue, Context, + AssignedValue, + Context, }; use halo2_ecc::ecc::EccChip; use halo2_ecc::fields::fp::FpChip; -use indexed_merkle_tree_halo2::indexed_merkle_tree::{insert_leaf, IndexedMerkleTreeLeaf}; +use indexed_merkle_tree_halo2::indexed_merkle_tree::{ insert_leaf, IndexedMerkleTreeLeaf }; use indexed_merkle_tree_halo2::utils::IndexedMerkleTreeLeaf as IMTLeaf; use num_bigint::BigUint; use biguint_halo2::big_uint::chip::BigUintChip; -use halo2_base::halo2_proofs::halo2curves::secp256k1::{Fp, Secp256k1Affine}; -use paillier_chip::paillier::{EncryptionPublicKeyAssigned, PaillierChip}; -use serde::{Deserialize, Serialize}; -use voter::{compress_nullifier, CircuitExt, EncryptionPublicKey}; +use halo2_base::halo2_proofs::halo2curves::secp256k1::{ Fp, Secp256k1Affine }; +use paillier_chip::paillier::{ EncryptionPublicKeyAssigned, PaillierChip }; +use serde::{ Deserialize, Serialize }; +use voter::{ compress_nullifier, CircuitExt, EncryptionPublicKey }; const ENC_BIT_LEN: usize = 176; const LIMB_BIT_LEN: usize = 88; @@ -52,7 +53,7 @@ impl IndexedMerkleTreeInput { new_leaf_index: F, new_leaf_proof: Vec, new_leaf_proof_helper: Vec, - is_new_leaf_largest: F, + is_new_leaf_largest: F ) -> Self { Self { old_root, @@ -67,6 +68,9 @@ impl IndexedMerkleTreeInput { is_new_leaf_largest, } } + pub fn get_old_root(&self) -> F { + self.old_root + } } #[derive(Debug, Clone)] @@ -83,7 +87,7 @@ impl StateTransitionInput { incoming_vote: Vec, prev_vote: Vec, nullifier_tree: IndexedMerkleTreeInput, - nullifier: Secp256k1Affine, + nullifier: Secp256k1Affine ) -> Self { Self { pk_enc, @@ -99,7 +103,7 @@ pub fn state_transition_circuit( ctx: &mut Context, range: &RangeChip, input: StateTransitionInput, - public_inputs: &mut Vec>, + public_inputs: &mut Vec> ) { let gate = range.gate(); let mut hasher = PoseidonHasher::::new(OptimizedPoseidonSpec::new::<8, 57, 0>()); @@ -128,22 +132,16 @@ pub fn state_transition_circuit( g: g_assigned, }; - let incoming_vote = input - .incoming_vote + let incoming_vote = input.incoming_vote .iter() .map(|x| { - biguint_chip - .assign_integer(ctx, Value::known(x.clone()), ENC_BIT_LEN * 2) - .unwrap() + biguint_chip.assign_integer(ctx, Value::known(x.clone()), ENC_BIT_LEN * 2).unwrap() }) .collect::>(); - let prev_vote = input - .prev_vote + let prev_vote = input.prev_vote .iter() .map(|x| { - biguint_chip - .assign_integer(ctx, Value::known(x.clone()), ENC_BIT_LEN * 2) - .unwrap() + biguint_chip.assign_integer(ctx, Value::known(x.clone()), ENC_BIT_LEN * 2).unwrap() }) .collect::>(); @@ -175,27 +173,19 @@ pub fn state_transition_circuit( let new_leaf_index = ctx.load_witness(input.nullifier_tree.new_leaf_index); let is_new_leaf_largest = ctx.load_witness(input.nullifier_tree.is_new_leaf_largest); - let low_leaf_proof = input - .nullifier_tree - .low_leaf_proof + let low_leaf_proof = input.nullifier_tree.low_leaf_proof .iter() .map(|x| ctx.load_witness(*x)) .collect::>(); - let low_leaf_proof_helper = input - .nullifier_tree - .low_leaf_proof_helper + let low_leaf_proof_helper = input.nullifier_tree.low_leaf_proof_helper .iter() .map(|x| ctx.load_witness(*x)) .collect::>(); - let new_leaf_proof = input - .nullifier_tree - .new_leaf_proof + let new_leaf_proof = input.nullifier_tree.new_leaf_proof .iter() .map(|x| ctx.load_witness(*x)) .collect::>(); - let new_leaf_proof_helper = input - .nullifier_tree - .new_leaf_proof_helper + let new_leaf_proof_helper = input.nullifier_tree.new_leaf_proof_helper .iter() .map(|x| ctx.load_witness(*x)) .collect::>(); @@ -213,7 +203,7 @@ pub fn state_transition_circuit( &new_leaf_index, &new_leaf_proof, &new_leaf_proof_helper, - &is_new_leaf_largest, + &is_new_leaf_largest ); // PK_ENC N @@ -300,10 +290,12 @@ impl CircuitExt for StateTransitionCircuit { } fn instances(&self) -> Vec> { - vec![self.inner.assigned_instances[0] - .iter() - .map(|instance| *instance.value()) - .collect()] + vec![ + self.inner.assigned_instances[0] + .iter() + .map(|instance| *instance.value()) + .collect() + ] } } @@ -311,7 +303,7 @@ impl CircuitExt for StateTransitionCircuit { mod test { use halo2_base::{ gates::circuit::BaseCircuitParams, - halo2_proofs::{dev::MockProver, halo2curves::bn256::Fr}, + halo2_proofs::{ dev::MockProver, halo2curves::bn256::Fr }, utils::testing::base_test, AssignedValue, }; @@ -319,7 +311,7 @@ mod test { use crate::utils::generate_wrapper_circuit_input; - use super::{state_transition_circuit, StateTransitionCircuit}; + use super::{ state_transition_circuit, StateTransitionCircuit }; #[test] fn test_state_transition_circuit() { diff --git a/aggregator/src/wrapper.rs b/aggregator/src/wrapper.rs index 0353862..0e6d23a 100644 --- a/aggregator/src/wrapper.rs +++ b/aggregator/src/wrapper.rs @@ -1,22 +1,26 @@ -use ark_std::{end_timer, start_timer}; +use ark_std::{ end_timer, start_timer }; use common::*; use halo2_base::halo2_proofs; use halo2_proofs::{ - circuit::{Layouter, SimpleFloorPlanner}, + circuit::{ Layouter, SimpleFloorPlanner }, dev::MockProver, - halo2curves::{ - bn256::{Bn256, Fr, G1Affine}, - group::ff::Field, - }, + halo2curves::{ bn256::{ Bn256, Fr, G1Affine }, group::ff::Field }, plonk::{ - create_proof, keygen_pk, keygen_vk, Circuit, ConstraintSystem, Error, ProvingKey, Selector, + create_proof, + keygen_pk, + keygen_vk, + Circuit, + ConstraintSystem, + Error, + ProvingKey, + Selector, VerifyingKey, }, poly::{ commitment::ParamsProver, kzg::{ commitment::ParamsKZG, - multiopen::{ProverGWC, VerifierGWC}, + multiopen::{ ProverGWC, VerifierGWC }, strategy::AccumulatorStrategy, }, VerificationStrategy, @@ -25,23 +29,17 @@ use halo2_proofs::{ use itertools::Itertools; use rand_chacha::rand_core::OsRng; use snark_verifier_sdk::snark_verifier::{ - loader::{self, native::NativeLoader, Loader, ScalarLoader}, + loader::{ self, native::NativeLoader, Loader, ScalarLoader }, pcs::{ - kzg::{Gwc19, KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey, LimbsEncoding}, - AccumulationScheme, AccumulationSchemeProver, - }, - system::halo2::{self, compile, Config}, - util::{ - arithmetic::{fe_to_fe, fe_to_limbs}, - hash, - }, - verifier::{ - self, - plonk::{PlonkProof, PlonkProtocol}, - SnarkVerifier, + kzg::{ Gwc19, KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey, LimbsEncoding }, + AccumulationScheme, + AccumulationSchemeProver, }, + system::halo2::{ self, compile, Config }, + util::{ arithmetic::{ fe_to_fe, fe_to_limbs }, hash }, + verifier::{ self, plonk::{ PlonkProof, PlonkProtocol }, SnarkVerifier }, }; -use std::{iter, marker::PhantomData, rc::Rc}; +use std::{ iter, marker::PhantomData, rc::Rc }; const LIMBS: usize = 3; const BITS: usize = 88; @@ -56,21 +54,29 @@ type As = KzgAs; type PlonkVerifier = verifier::plonk::PlonkVerifier>; type PlonkSuccinctVerifier = verifier::plonk::PlonkSuccinctVerifier>; type Poseidon = hash::Poseidon; -type PoseidonTranscript = - halo2::transcript::halo2::PoseidonTranscript; +type PoseidonTranscript = halo2::transcript::halo2::PoseidonTranscript< + G1Affine, + L, + S, + T, + RATE, + R_F, + R_P +>; pub mod common { use super::*; - use halo2_proofs::{plonk::verify_proof, poly::commitment::Params}; - use serde::{Deserialize, Serialize}; + use halo2_proofs::{ plonk::verify_proof, poly::commitment::Params }; + use serde::{ Deserialize, Serialize }; use snark_verifier_sdk::snark_verifier::{ - cost::CostEstimation, util::transcript::TranscriptWrite, + cost::CostEstimation, + util::transcript::TranscriptWrite, }; use voter::CircuitExt; pub fn poseidon>( loader: &L, - inputs: &[L::LoadedScalar], + inputs: &[L::LoadedScalar] ) -> L::LoadedScalar { // warning: generating a new spec is time intensive, use lazy_static in production let mut hasher = Poseidon::new::(loader); @@ -78,6 +84,13 @@ pub mod common { hasher.squeeze() } + pub fn limbs_to_biguint(x: Vec) -> BigUint { + x.iter() + .enumerate() + .map(|(i, limb)| fe_to_biguint(limb) * BigUint::from(2u64).pow(88 * (i as u32))) + .sum() + } + #[derive(Clone, Serialize, Deserialize)] pub struct Snark { pub protocol: PlonkProtocol, @@ -89,7 +102,7 @@ pub mod common { pub fn new( protocol: PlonkProtocol, instances: Vec>, - proof: Vec, + proof: Vec ) -> Self { Self { protocol, @@ -112,44 +125,42 @@ pub mod common { params: &ParamsKZG, pk: &ProvingKey, circuit: C, - instances: Vec>, + instances: Vec> ) -> Vec { if params.k() > 3 { let mock = start_timer!(|| "Mock prover"); - MockProver::run(params.k(), &circuit, instances.clone()) - .unwrap() - .assert_satisfied(); + MockProver::run(params.k(), &circuit, instances.clone()).unwrap().assert_satisfied(); end_timer!(mock); } let instances = instances.iter().map(Vec::as_slice).collect_vec(); let proof = { - let mut transcript = - PoseidonTranscript::::new::(Vec::new()); + let mut transcript = PoseidonTranscript::::new::( + Vec::new() + ); create_proof::<_, ProverGWC<_>, _, _, _, _>( params, pk, &[circuit], &[instances.as_slice()], OsRng, - &mut transcript, - ) - .unwrap(); + &mut transcript + ).unwrap(); transcript.finalize() }; let accept = { - let mut transcript = - PoseidonTranscript::::new::(proof.as_slice()); + let mut transcript = PoseidonTranscript::::new::( + proof.as_slice() + ); VerificationStrategy::<_, VerifierGWC<_>>::finalize( verify_proof::<_, VerifierGWC<_>, _, _, _>( params.verifier_params(), pk.get_vk(), AccumulatorStrategy::new(params.verifier_params()), &[instances.as_slice()], - &mut transcript, - ) - .unwrap(), + &mut transcript + ).unwrap() ) }; assert!(accept); @@ -160,14 +171,14 @@ pub mod common { pub fn gen_snark>( params: &ParamsKZG, pk: &ProvingKey, - circuit: ConcreteCircuit, + circuit: ConcreteCircuit ) -> Snark { let protocol = compile( params, pk.get_vk(), Config::kzg() .with_num_instance(ConcreteCircuit::num_instance()) - .with_accumulator_indices(ConcreteCircuit::accumulator_indices()), + .with_accumulator_indices(ConcreteCircuit::accumulator_indices()) ); let instances = circuit.instances(); @@ -179,17 +190,13 @@ pub mod common { pub fn gen_dummy_snark>( params: &ParamsKZG, vk: Option<&VerifyingKey>, - config_params: ConcreteCircuit::Params, + config_params: ConcreteCircuit::Params ) -> Snark - where - ConcreteCircuit::Params: Clone, + where ConcreteCircuit::Params: Clone { struct CsProxy>(C::Params, PhantomData<(F, C)>); - impl> Circuit for CsProxy - where - C::Params: Clone, - { + impl> Circuit for CsProxy where C::Params: Clone { type Config = C::Config; type FloorPlanner = C::FloorPlanner; type Params = C::Params; @@ -204,7 +211,7 @@ pub mod common { fn configure_with_params( meta: &mut ConstraintSystem, - params: Self::Params, + params: Self::Params ) -> Self::Config { C::configure_with_params(meta, params) } @@ -216,7 +223,7 @@ pub mod common { fn synthesize( &self, config: Self::Config, - mut layouter: impl Layouter, + mut layouter: impl Layouter ) -> Result<(), Error> { // when `C` has simple selectors, we tell `CsProxy` not to over-optimize the selectors (e.g., compressing them all into one) by turning all selectors on in the first row // currently this only works if all simple selector columns are used in the actual circuit and there are overlaps amongst all enabled selectors (i.e., the actual circuit will not optimize constraint system further) @@ -227,39 +234,44 @@ pub mod common { q.enable(&mut region, 0)?; } Ok(()) - }, + } )?; Ok(()) } } - let dummy_vk = vk.is_none().then(|| { - keygen_vk( - params, - &CsProxy::(config_params, PhantomData), - ) - .unwrap() - }); + let dummy_vk = vk + .is_none() + .then(|| { + keygen_vk( + params, + &CsProxy::(config_params, PhantomData) + ).unwrap() + }); let protocol = compile( params, vk.or(dummy_vk.as_ref()).unwrap(), Config::kzg() .with_num_instance(ConcreteCircuit::num_instance()) - .with_accumulator_indices(ConcreteCircuit::accumulator_indices()), + .with_accumulator_indices(ConcreteCircuit::accumulator_indices()) ); let instances = ConcreteCircuit::num_instance() .into_iter() - .map(|n| iter::repeat_with(|| Fr::random(OsRng)).take(n).collect()) + .map(|n| + iter + ::repeat_with(|| Fr::random(OsRng)) + .take(n) + .collect() + ) .collect(); let proof = { - let mut transcript = - PoseidonTranscript::::new::(Vec::new()); - for _ in 0..protocol - .num_witness + let mut transcript = PoseidonTranscript::::new::( + Vec::new() + ); + for _ in 0..protocol.num_witness .iter() .chain(Some(&protocol.quotient.num_chunk())) - .sum::() - { + .sum::() { transcript.write_ec_point(G1Affine::random(OsRng)).unwrap(); } for _ in 0..protocol.evaluations.len() { @@ -282,17 +294,27 @@ pub mod recursion { use halo2_base::{ gates::{ circuit::{ - builder::BaseCircuitBuilder, BaseCircuitParams, BaseConfig, CircuitBuilderStage, + builder::BaseCircuitBuilder, + BaseCircuitParams, + BaseConfig, + CircuitBuilderStage, }, - GateInstructions, RangeInstructions, + GateInstructions, + RangeInstructions, }, + halo2_proofs::halo2curves::secp256k1::Secp256k1Affine, AssignedValue, }; - use halo2_ecc::{bn254::FpChip, ecc::EcPoint}; - use snark_verifier_sdk::snark_verifier::loader::halo2::{EccInstructions, IntegerInstructions}; - use voter::{CircuitExt, VoterCircuit}; - - use crate::state_transition::StateTransitionCircuit; + use halo2_ecc::{ bn254::FpChip, ecc::EcPoint }; + use num_bigint::BigUint; + use snark_verifier_sdk::snark_verifier::loader::halo2::{ EccInstructions, IntegerInstructions }; + use voter::{ CircuitExt, EncryptionPublicKey, VoterCircuit }; + + use crate::state_transition::{ + state_transition_circuit, + StateTransitionCircuit, + StateTransitionInput, + }; use super::*; @@ -303,21 +325,18 @@ pub mod recursion { svk: &Svk, loader: &Rc>, snark: &Snark, - preprocessed_digest: Option>, - ) -> ( - Vec>>, - Vec>>>, - ) { + preprocessed_digest: Option> + ) -> (Vec>>, Vec>>>) { let protocol = if let Some(preprocessed_digest) = preprocessed_digest { let preprocessed_digest = loader.scalar_from_assigned(preprocessed_digest); let protocol = snark.protocol.loaded_preprocessed_as_witness(loader, false); - let inputs = protocol - .preprocessed + let inputs = protocol.preprocessed .iter() .flat_map(|preprocessed| { let assigned = preprocessed.assigned(); - [assigned.x(), assigned.y()] - .map(|coordinate| loader.scalar_from_assigned(*coordinate.native())) + [assigned.x(), assigned.y()].map(|coordinate| + loader.scalar_from_assigned(*coordinate.native()) + ) }) .chain(protocol.transcript_initial_state.clone()) .collect_vec(); @@ -327,8 +346,7 @@ pub mod recursion { snark.protocol.loaded(loader) }; - let instances = snark - .instances + let instances = snark.instances .iter() .map(|instances| { instances @@ -337,12 +355,22 @@ pub mod recursion { .collect_vec() }) .collect_vec(); - let mut transcript = - PoseidonTranscript::, _>::new::(loader, snark.proof()); - let proof = - PlonkSuccinctVerifier::read_proof(svk, &protocol, &instances, &mut transcript).unwrap(); - let accumulators = - PlonkSuccinctVerifier::verify(svk, &protocol, &instances, &proof).unwrap(); + let mut transcript = PoseidonTranscript::, _>::new::( + loader, + snark.proof() + ); + let proof = PlonkSuccinctVerifier::read_proof( + svk, + &protocol, + &instances, + &mut transcript + ).unwrap(); + let accumulators = PlonkSuccinctVerifier::verify( + svk, + &protocol, + &instances, + &proof + ).unwrap(); ( instances @@ -362,35 +390,41 @@ pub mod recursion { loader: &Rc>, condition: &AssignedValue, lhs: &KzgAccumulator>>, - rhs: &KzgAccumulator>>, + rhs: &KzgAccumulator>> ) -> Result>>, Error> { let [lhs, rhs]: [_; 2] = [lhs.lhs.assigned(), lhs.rhs.assigned()] .iter() .zip([rhs.lhs.assigned(), rhs.rhs.assigned()].iter()) .map(|(lhs, rhs)| { - loader.ecc_chip().select( - loader.ctx_mut().main(), - EcPoint::clone(lhs), - EcPoint::clone(rhs), - *condition, - ) + loader + .ecc_chip() + .select( + loader.ctx_mut().main(), + EcPoint::clone(lhs), + EcPoint::clone(rhs), + *condition + ) }) .collect::>() .try_into() .unwrap(); - Ok(KzgAccumulator::new( - loader.ec_point_from_assigned(lhs), - loader.ec_point_from_assigned(rhs), - )) + Ok( + KzgAccumulator::new( + loader.ec_point_from_assigned(lhs), + loader.ec_point_from_assigned(rhs) + ) + ) } fn accumulate<'a>( loader: &Rc>, accumulators: Vec>>>, - as_proof: &[u8], + as_proof: &[u8] ) -> KzgAccumulator>> { - let mut transcript = - PoseidonTranscript::, _>::new::(loader, as_proof); + let mut transcript = PoseidonTranscript::, _>::new::( + loader, + as_proof + ); let proof = As::read_proof(&Default::default(), &accumulators, &mut transcript).unwrap(); As::verify(&Default::default(), &accumulators, &proof).unwrap() } @@ -409,7 +443,6 @@ pub mod recursion { svk: Svk, default_accumulator: KzgAccumulator, voter: Snark, - state_transition: Snark, previous: Snark, #[allow(dead_code)] round: usize, @@ -434,55 +467,60 @@ pub mod recursion { stage: CircuitBuilderStage, params: &ParamsKZG, voter: Snark, - state_transition: Snark, previous: Snark, + nullifier_tree: IndexedMerkleTreeInput, round: usize, - config_params: BaseCircuitParams, + config_params: BaseCircuitParams ) -> Self { let svk = params.get_g()[0].into(); let default_accumulator = KzgAccumulator::new(params.get_g()[1], params.get_g()[0]); let succinct_verify = |snark: &Snark| { let mut transcript = PoseidonTranscript::::new::( - snark.proof.as_slice(), + snark.proof.as_slice() ); let proof = PlonkSuccinctVerifier::read_proof( &svk, &snark.protocol, &snark.instances, - &mut transcript, - ) - .unwrap(); - PlonkSuccinctVerifier::verify(&svk, &snark.protocol, &snark.instances, &proof) - .unwrap() + &mut transcript + ).unwrap(); + PlonkSuccinctVerifier::verify( + &svk, + &snark.protocol, + &snark.instances, + &proof + ).unwrap() }; - let accumulators = iter::empty() + let accumulators = iter + ::empty() .chain(succinct_verify(&voter)) - .chain(succinct_verify(&state_transition)) .chain( (round > 0) .then(|| succinct_verify(&previous)) .unwrap_or_else(|| { let num_accumulator = 1 + previous.protocol.accumulator_indices.len(); vec![default_accumulator.clone(); num_accumulator] - }), + }) ) .collect_vec(); let (accumulator, as_proof) = { - let mut transcript = - PoseidonTranscript::::new::(Vec::new()); - let accumulator = - As::create_proof(&Default::default(), &accumulators, &mut transcript, OsRng) - .unwrap(); + let mut transcript = PoseidonTranscript::::new::( + Vec::new() + ); + let accumulator = As::create_proof( + &Default::default(), + &accumulators, + &mut transcript, + OsRng + ).unwrap(); (accumulator, transcript.finalize()) }; let preprocessed_digest = { - let inputs = previous - .protocol - .preprocessed + let inputs = previous.protocol.preprocessed .iter() .flat_map(|preprocessed| [preprocessed.x, preprocessed.y]) .map(fe_to_fe) @@ -491,17 +529,74 @@ pub mod recursion { poseidon(&NativeLoader, &inputs) }; + //State transition circuit input + + //state_transition(pk_enc) + let st_pk_enc_fr = [ + voter.instances[0][0], + voter.instances[0][1], + voter.instances[0][2], + voter.instances[0][3], + ]; + let sk_pk_enc = EncryptionPublicKey { + n: limbs_to_biguint(st_pk_enc[0..2]), + g: limbs_to_biguint(st_pk_enc[2..4]), + }; + + //state_transition(prev_vote) + let st_prev_vote_fr = (0..20) + .map(|i| previous.instances[0][4 * LIMBS + i + 1 + 4]) + .collect::>(); + let st_prev_vote = (0..5) + .map(|i| limbs_to_biguint(st_prev_vote_fr[4 * i + 4 * i + 4])) + .collect::>(); + + //state_transition(incoming_vote) + let st_incoming_vote_fr = (0..20) + .map(|i| voter.instances[0][i + 4]) + .collect::>(); + let st_incoming_vote = (0..5) + .map(|i| limbs_to_biguint(st_incoming_vote_fr[4 * i + 4 * i + 4])) + .collect::>(); + + //state_transition(nullifier) + let st_nullifier = (0..4).map(|i| voter.instances[0][i + 24]).collect::>(); + let nullifier_x = limbs_to_biguint(st_nullifier[0..2]); + let nullifier_y = limbs_to_biguint(st_nullifier[2..4]); + + let st_nullifier = Secp256k1Affine::new( + fe_to_fe::(nullifier_x), + fe_to_fe::(nullifier_y) + ); + let state_transition_input = StateTransitionInput::new( + sk_pk_enc, + st_prev_vote, + st_incoming_vote, + nullifier_tree, + st_nullifier + ); + //state_transition(nullifier_old_root) + // let st_nullifier_old_root = previous.instances[0][4 * LIMBS + 1 + 25]; + + let inner = BaseCircuitBuilder::from_stage(stage).use_params(config_params); + let range = inner.range(); + let mut public_inputs = Vec::>::new(); + state_transition_circuit( + inner.main(0), + range, + state_transition_input, + &mut public_inputs + ); let mut current_instances = [ voter.instances[0][0], voter.instances[0][1], voter.instances[0][2], voter.instances[0][3], - ] - .to_vec(); - current_instances.extend(state_transition.instances[0][44..64].iter()); + ].to_vec(); + current_instances.extend(public_inputs[44..64].iter()); current_instances.extend([ - state_transition.instances[0][68], - state_transition.instances[0][69], + public_inputs[68], + public_inputs[69], voter.instances[0][28], voter.instances[0][29], ]); @@ -512,19 +607,17 @@ pub mod recursion { accumulator.rhs.x, accumulator.rhs.y, ] - .into_iter() - .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) - .chain([preprocessed_digest]) - .chain(current_instances) - .chain([Fr::from(round as u64)]) - .collect(); + .into_iter() + .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) + .chain([preprocessed_digest]) + .chain(current_instances) + .chain([Fr::from(round as u64)]) + .collect(); - let inner = BaseCircuitBuilder::from_stage(stage).use_params(config_params); let mut circuit = Self { svk, default_accumulator, voter, - state_transition, previous, round, instances, @@ -540,8 +633,10 @@ pub mod recursion { let main_gate = range.gate(); let pool = self.inner.pool(0); - let preprocessed_digest = - main_gate.assign_integer(pool, self.instances[Self::PREPROCESSED_DIGEST_ROW]); + let preprocessed_digest = main_gate.assign_integer( + pool, + self.instances[Self::PREPROCESSED_DIGEST_ROW] + ); let pk_enc_n = self.instances[Self::PK_ENC_N_ROW..Self::PK_ENC_N_ROW + 2] .iter() .map(|instance| main_gate.assign_integer(pool, *instance)) @@ -554,12 +649,18 @@ pub mod recursion { .iter() .map(|instance| main_gate.assign_integer(pool, *instance)) .collect::>(); - let nullifier_old_root = - main_gate.assign_integer(pool, self.instances[Self::NULLIFIER_OLD_ROOT_ROW]); - let nullifier_new_root = - main_gate.assign_integer(pool, self.instances[Self::NULLIFIER_NEW_ROOT_ROW]); - let membership_root = - main_gate.assign_integer(pool, self.instances[Self::MEMBERSHIP_ROOT_ROW]); + let nullifier_old_root = main_gate.assign_integer( + pool, + self.instances[Self::NULLIFIER_OLD_ROOT_ROW] + ); + let nullifier_new_root = main_gate.assign_integer( + pool, + self.instances[Self::NULLIFIER_NEW_ROOT_ROW] + ); + let membership_root = main_gate.assign_integer( + pool, + self.instances[Self::MEMBERSHIP_ROOT_ROW] + ); let proposal_id = main_gate.assign_integer(pool, self.instances[Self::PROPOSAL_ID_ROW]); let round = main_gate.assign_integer(pool, self.instances[Self::ROUND_ROW]); @@ -570,15 +671,18 @@ pub mod recursion { let ecc_chip = BaseFieldEccChip::new(&fp_chip); let loader = Halo2Loader::new(ecc_chip, mem::take(self.inner.pool(0))); - let (mut voter_instances, voter_accumulators) = - succinct_verify(&self.svk, &loader, &self.voter, None); - let (mut state_transition_instances, state_transition_accumulators) = - succinct_verify(&self.svk, &loader, &self.state_transition, None); + let (mut voter_instances, voter_accumulators) = succinct_verify( + &self.svk, + &loader, + &self.voter, + None + ); + let (mut previous_instances, previous_accumulators) = succinct_verify( &self.svk, &loader, &self.previous, - Some(preprocessed_digest), + Some(preprocessed_digest) ); let default_accmulator = self.load_default_accumulator(&loader).unwrap(); @@ -589,27 +693,20 @@ pub mod recursion { &loader, &first_round, &default_accmulator, - previous_accumulator, - ) - .unwrap() + previous_accumulator + ).unwrap() }) .collect::>(); let KzgAccumulator { lhs, rhs } = accumulate( &loader, - [ - voter_accumulators, - state_transition_accumulators, - previous_accumulators, - ] - .concat(), - self.as_proof(), + [voter_accumulators, previous_accumulators].concat(), + self.as_proof() ); let lhs = lhs.into_assigned(); let rhs = rhs.into_assigned(); let voter_instances = voter_instances.pop().unwrap(); - let state_transition_instances = state_transition_instances.pop().unwrap(); let previous_instances = previous_instances.pop().unwrap(); let mut pool = loader.take_ctx(); @@ -621,13 +718,11 @@ pub mod recursion { &previous_instances[Self::PREPROCESSED_DIGEST_ROW], ), // Verify round is increased by 1 when not at first round - ( - &round, - &main_gate.add(ctx, not_first_round, previous_instances[Self::ROUND_ROW]), - ), + (&round, &main_gate.add(ctx, not_first_round, previous_instances[Self::ROUND_ROW])), ] { ctx.constrain_equal(lhs, rhs); } + //TODO: Add constrain nullifier_tree.old_root // state_transition(pk_enc) == previous(pk_enc) == voter(pk_enc) for i in 0..4 { @@ -635,49 +730,13 @@ pub mod recursion { // state_transition_instances[i].value(), // voter_instances[i].value() // ); - ctx.constrain_equal(&state_transition_instances[i], &voter_instances[i]); + // ctx.constrain_equal(&state_transition_instances[i], &voter_instances[i]); // assert_eq!( - // state_transition_instances[i].value(), + // voter_instances[i].value(), // previous_instances[4 * LIMBS + i + 1].value() // ); - ctx.constrain_equal( - &state_transition_instances[i], - &previous_instances[4 * LIMBS + i + 1], - ); - } - - // state_transition(prev_vote) == previous(aggr_vote) - for i in 0..20 { - // assert_eq!( - // state_transition_instances[i + 4].value(), - // previous_instances[4 * LIMBS + i + 1 + 4].value() - // ); - ctx.constrain_equal( - &state_transition_instances[i + 4], - &previous_instances[4 * LIMBS + i + 1 + 4], - ); - } - - // state_transition(incoming_vote) == voter(vote) - for i in 0..20 { - // assert_eq!( - // state_transition_instances[i + 24].value(), - // voter_instances[i + 4].value() - // ); - ctx.constrain_equal(&state_transition_instances[i + 24], &voter_instances[i + 4]); - } - - // state_transition(nullifier) == voter(nullifier) - for i in 0..4 { - // assert_eq!( - // state_transition_instances[i + 64].value(), - // voter_instances[i + 24].value() - // ); - ctx.constrain_equal( - &state_transition_instances[i + 64], - &voter_instances[i + 24], - ); + ctx.constrain_equal(&voter_instances[i], &previous_instances[4 * LIMBS + i + 1]); } // state_transition(nullifier_old_root) == previous(nullifier_new_root) @@ -685,30 +744,25 @@ pub mod recursion { // state_transition_instances[68].value(), // previous_instances[4 * LIMBS + 1 + 25].value() // ); - ctx.constrain_equal( - &state_transition_instances[68], - &previous_instances[4 * LIMBS + 1 + 25], - ); + + // ctx.constrain_equal( + // &state_transition_instances[68], + // &previous_instances[4 * LIMBS + 1 + 25] + // ); // previous(membership_root]) == voter(membership_root) // assert_eq!( // previous_instances[4 * LIMBS + 1 + 26].value(), // voter_instances[28].value() // ); - ctx.constrain_equal( - &previous_instances[4 * LIMBS + 1 + 26], - &voter_instances[28], - ); + ctx.constrain_equal(&previous_instances[4 * LIMBS + 1 + 26], &voter_instances[28]); // voter(proposal_id) == previous(proposal_id) // assert_eq!( // voter_instances[29].value(), // previous_instances[4 * LIMBS + 1 + 27].value() // ); - ctx.constrain_equal( - &voter_instances[29], - &previous_instances[4 * LIMBS + 1 + 27], - ); + ctx.constrain_equal(&voter_instances[29], &previous_instances[4 * LIMBS + 1 + 27]); *self.inner.pool(0) = pool; @@ -727,10 +781,9 @@ pub mod recursion { membership_root, proposal_id, round, - ] - .iter(), + ].iter() ) - .copied(), + .copied() ); self.inner.calculate_params(Some(10)); @@ -741,15 +794,17 @@ pub mod recursion { params: &ParamsKZG, vk: Option<&VerifyingKey>, config_params: BaseCircuitParams, - init_aggr_instances: Vec, + init_aggr_instances: Vec ) -> Snark { let mut snark = gen_dummy_snark::(params, vk, config_params); let g = params.get_g(); - snark.instances = vec![[g[1].x, g[1].y, g[0].x, g[0].y] - .into_iter() - .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) - .chain(init_aggr_instances) - .collect_vec()]; + snark.instances = vec![ + [g[1].x, g[1].y, g[0].x, g[0].y] + .into_iter() + .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) + .chain(init_aggr_instances) + .collect_vec() + ]; snark } @@ -759,15 +814,16 @@ pub mod recursion { fn load_default_accumulator<'a>( &self, - loader: &Rc>, + loader: &Rc> ) -> Result>>, Error> { - let [lhs, rhs] = - [self.default_accumulator.lhs, self.default_accumulator.rhs].map(|default| { + let [lhs, rhs] = [self.default_accumulator.lhs, self.default_accumulator.rhs].map( + |default| { let assigned = loader .ecc_chip() .assign_constant(&mut loader.ctx_mut(), default); loader.ec_point_from_assigned(assigned) - }); + } + ); Ok(KzgAccumulator::new(lhs, rhs)) } @@ -791,7 +847,7 @@ pub mod recursion { fn configure_with_params( meta: &mut ConstraintSystem, - params: Self::Params, + params: Self::Params ) -> Self::Config { BaseCircuitBuilder::configure_with_params(meta, params) } @@ -803,7 +859,7 @@ pub mod recursion { fn synthesize( &self, config: Self::Config, - layouter: impl Layouter, + layouter: impl Layouter ) -> Result<(), Error> { self.inner.synthesize(config, layouter) } @@ -824,8 +880,9 @@ pub mod recursion { } fn selectors(config: &Self::Config) -> Vec { - config.gate().basic_gates[0] - .iter() + config + .gate() + .basic_gates[0].iter() .map(|gate| gate.q_enable) .collect() } @@ -833,32 +890,26 @@ pub mod recursion { pub fn gen_recursion_pk( voter_params: &ParamsKZG, - state_transition_params: &ParamsKZG, recursion_params: &ParamsKZG, voter_vk: &VerifyingKey, - state_transition_vk: &VerifyingKey, voter_config: BaseCircuitParams, - state_transition_config: BaseCircuitParams, recursion_config: BaseCircuitParams, - init_aggr_instances: Vec, + nullifier_tree: IndexedMerkleTreeInput, + init_aggr_instances: Vec ) -> ProvingKey { let recursion = RecursionCircuit::new( CircuitBuilderStage::Keygen, recursion_params, gen_dummy_snark::>(voter_params, Some(voter_vk), voter_config), - gen_dummy_snark::>( - state_transition_params, - Some(state_transition_vk), - state_transition_config, - ), RecursionCircuit::initial_snark( recursion_params, None, recursion_config.clone(), - init_aggr_instances, + init_aggr_instances ), + nullifier_tree, 0, - recursion_config, + recursion_config ); // we cannot auto-configure the circuit because dummy_snark must know the configuration beforehand // uncomment the following line only in development to test and print out the optimal configuration ahead of time @@ -872,28 +923,24 @@ pub mod recursion { recursion_pk: &ProvingKey, recursion_config: BaseCircuitParams, voter_snarks: Vec, - state_transition_snarks: Vec, init_aggr_instances: Vec, + nullifier_tree: Vec> ) -> Snark { let mut previous = RecursionCircuit::initial_snark( recursion_params, Some(recursion_pk.get_vk()), recursion_config.clone(), - init_aggr_instances, + init_aggr_instances ); - for (round, (voter, state_transition)) in voter_snarks - .into_iter() - .zip(state_transition_snarks) - .enumerate() - { + for (round, voter) in voter_snarks.into_iter().enumerate() { let recursion = RecursionCircuit::new( stage, recursion_params, voter, - state_transition, previous, + nullifier_tree[round], round, - recursion_config.clone(), + recursion_config.clone() ); println!("Generate recursion snark for round {}", round); previous = gen_snark(recursion_params, recursion_pk, recursion); @@ -904,38 +951,35 @@ pub mod recursion { #[cfg(test)] mod test { - use std::path::{Path, PathBuf}; - use std::{fs, io::BufReader}; + use std::path::{ Path, PathBuf }; + use std::{ fs, io::BufReader }; - use ark_std::{end_timer, start_timer}; + use ark_std::{ end_timer, start_timer }; use halo2_base::{ - gates::circuit::{builder::BaseCircuitBuilder, BaseCircuitParams, CircuitBuilderStage}, + gates::circuit::{ builder::BaseCircuitBuilder, BaseCircuitParams, CircuitBuilderStage }, halo2_proofs::{ - halo2curves::bn256::{Fr, G1Affine}, + halo2curves::bn256::{ Fr, G1Affine }, plonk::ProvingKey, poly::commitment::ParamsProver, }, utils::fs::gen_srs, }; - use snark_verifier_sdk::{snark_verifier::verifier::SnarkVerifier, NativeLoader}; + use snark_verifier_sdk::{ snark_verifier::verifier::SnarkVerifier, NativeLoader }; + use voter::merkletree::native::MerkleTree; use voter::VoterCircuit; - use crate::{state_transition::StateTransitionCircuit, utils::generate_wrapper_circuit_input}; + use crate::{ state_transition::StateTransitionCircuit, utils::generate_wrapper_circuit_input }; - use super::{ - gen_pk, gen_snark, - recursion::{self}, - PlonkVerifier, PoseidonTranscript, - }; + use super::{ gen_pk, gen_snark, recursion::{ self }, PlonkVerifier, PoseidonTranscript }; fn workspace_dir() -> PathBuf { - let output = std::process::Command::new(env!("CARGO")) + let output = std::process::Command + ::new(env!("CARGO")) .arg("locate-project") .arg("--workspace") .arg("--message-format=plain") .output() - .unwrap() - .stdout; + .unwrap().stdout; let cargo_path = Path::new(std::str::from_utf8(&output).unwrap().trim()); cargo_path.parent().unwrap().to_path_buf() } @@ -943,12 +987,14 @@ mod test { #[test] fn test_recursion() { const GEN_VOTER_PK: bool = true; - const GEN_STATE_TRANSITION_PK: bool = true; const GEN_RECURSION_PK: bool = true; let num_round = 7; let (voter_input, state_transition_input) = generate_wrapper_circuit_input(num_round); + let nullifier_tree_preimages = (0..num_round) + .map(|i| state_transition_input[i].nullifier_tree.clone()) + .collect::>(); let voter_config = BaseCircuitParams { k: 15, @@ -971,7 +1017,7 @@ mod test { voter_pk .write( &mut voter_pk_bytes, - halo2_base::halo2_proofs::SerdeFormat::RawBytesUnchecked, + halo2_base::halo2_proofs::SerdeFormat::RawBytesUnchecked ) .unwrap(); // write voter pk to build folder and make sure folder exists @@ -980,64 +1026,16 @@ mod test { println!("Reading voter pk"); let file = fs::read(build_dir.join("voter_pk.bin")).unwrap(); let voter_pk_reader = &mut BufReader::new(file.as_slice()); - voter_pk = ProvingKey::::read::, BaseCircuitBuilder>( - voter_pk_reader, - halo2_base::halo2_proofs::SerdeFormat::RawBytesUnchecked, - voter_config.clone(), - ) - .unwrap(); - } - println!("Generating voter snark"); - let voter_snark = gen_snark(&voter_params, &voter_pk, voter_circuit); - - let state_transition_config = BaseCircuitParams { - k: 15, - num_advice_per_phase: vec![3], - num_lookup_advice_per_phase: vec![1, 0, 0], - num_fixed: 1, - lookup_bits: Some(14), - num_instance_columns: 1, - }; - let state_transition_params = gen_srs(15); - let state_transition_circuit = StateTransitionCircuit::new( - state_transition_config.clone(), - state_transition_input[0].clone(), - ); - let state_transition_pk: ProvingKey; - if GEN_STATE_TRANSITION_PK { - println!("Generating state transition pk"); - state_transition_pk = gen_pk(&state_transition_params, &state_transition_circuit); - let mut state_transition_pk_bytes = Vec::new(); - state_transition_pk - .write( - &mut state_transition_pk_bytes, + voter_pk = ProvingKey:: + ::read::, BaseCircuitBuilder>( + voter_pk_reader, halo2_base::halo2_proofs::SerdeFormat::RawBytesUnchecked, - ) - .unwrap(); - - fs::write( - build_dir.join("state_transition_pk.bin"), - state_transition_pk_bytes, - ) - .unwrap(); - } else { - println!("Reading state transition pk"); - let file = fs::read(build_dir.join("state_transition_pk.bin")).unwrap(); - let state_transition_pk_reader = &mut BufReader::new(file.as_slice()); - state_transition_pk = - ProvingKey::::read::, BaseCircuitBuilder>( - state_transition_pk_reader, - halo2_base::halo2_proofs::SerdeFormat::RawBytesUnchecked, - state_transition_config.clone(), + voter_config.clone() ) .unwrap(); } - println!("Generating state transition snark"); - let state_transition_snark = gen_snark( - &state_transition_params, - &state_transition_pk, - state_transition_circuit, - ); + println!("Generating voter snark"); + let voter_snark = gen_snark(&voter_params, &voter_pk, voter_circuit); let k = 22; let recursion_config = BaseCircuitParams { @@ -1052,20 +1050,19 @@ mod test { // Init Base Instances let mut base_instances = [ - Fr::zero(), // preprocessed_digest + Fr::zero(), // preprocessed_digest voter_snark.instances[0][0], // pk_enc_n voter_snark.instances[0][1], voter_snark.instances[0][2], // pk_enc_g voter_snark.instances[0][3], - ] - .to_vec(); - base_instances.extend(state_transition_snark.instances[0][4..24].iter()); // init_vote + ].to_vec(); + base_instances.extend(voter_snark.instances[0][4..24].iter()); // init_vote base_instances.extend([ - state_transition_snark.instances[0][68], // nullifier_old_root - state_transition_snark.instances[0][68], // nullifier_new_root - voter_snark.instances[0][28], // membership_root - voter_snark.instances[0][29], // proposal_id - Fr::from(0), // round + nullifier_tree_preimages[0].get_old_root(), // nullifier_old_root + nullifier_tree_preimages[0].get_old_root(), // nullifier_new_root + voter_snark.instances[0][28], // membership_root + voter_snark.instances[0][29], // proposal_id + Fr::from(0), // round ]); let pk_time = start_timer!(|| "Generate recursion pk"); @@ -1074,20 +1071,18 @@ mod test { println!("Generating recursion pk"); recursion_pk = recursion::gen_recursion_pk( &voter_params, - &state_transition_params, &recursion_params, voter_pk.get_vk(), - state_transition_pk.get_vk(), voter_config.clone(), - state_transition_config.clone(), recursion_config.clone(), - base_instances.clone(), + nullifier_tree[0].clone(), + base_instances.clone() ); let mut recursion_pk_bytes = Vec::new(); recursion_pk .write( &mut recursion_pk_bytes, - halo2_base::halo2_proofs::SerdeFormat::RawBytesUnchecked, + halo2_base::halo2_proofs::SerdeFormat::RawBytesUnchecked ) .unwrap(); @@ -1096,32 +1091,21 @@ mod test { println!("Reading recursion pk"); let file = fs::read(build_dir.join("recursion_pk.bin")).unwrap(); let recursion_pk_reader = &mut BufReader::new(file.as_slice()); - recursion_pk = - ProvingKey::::read::, BaseCircuitBuilder>( + recursion_pk = ProvingKey:: + ::read::, BaseCircuitBuilder>( recursion_pk_reader, halo2_base::halo2_proofs::SerdeFormat::RawBytesUnchecked, - recursion_config.clone(), + recursion_config.clone() ) .unwrap(); } end_timer!(pk_time); let mut voter_snarks = vec![voter_snark]; - let mut state_transition_snarks = vec![state_transition_snark]; for i in 1..num_round { let voter_circuit = VoterCircuit::new(voter_config.clone(), voter_input[i].clone()); voter_snarks.push(gen_snark(&voter_params, &voter_pk, voter_circuit)); - - let state_transition_circuit = StateTransitionCircuit::new( - state_transition_config.clone(), - state_transition_input[i].clone(), - ); - state_transition_snarks.push(gen_snark( - &state_transition_params, - &state_transition_pk, - state_transition_circuit, - )); } println!("Starting recursion..."); @@ -1132,8 +1116,8 @@ mod test { &recursion_pk, recursion_config, voter_snarks, - state_transition_snarks, - base_instances, + nullifier_tree_preimages, + base_instances ); end_timer!(pf_time); @@ -1142,19 +1126,22 @@ mod test { recursion_params.get_g()[0], recursion_params.g2(), recursion_params.s_g2(), - ) - .into(); - let mut transcript = - PoseidonTranscript::::new::<0>(final_snark.proof.as_slice()); + ).into(); + let mut transcript = PoseidonTranscript::::new::<0>( + final_snark.proof.as_slice() + ); let proof = PlonkVerifier::read_proof( &dk, &final_snark.protocol, &final_snark.instances, - &mut transcript, - ) - .unwrap(); - PlonkVerifier::verify(&dk, &final_snark.protocol, &final_snark.instances, &proof) - .unwrap(); + &mut transcript + ).unwrap(); + PlonkVerifier::verify( + &dk, + &final_snark.protocol, + &final_snark.instances, + &proof + ).unwrap(); } } } From b57e4d53923da54e5d468c8e72362b422d555be5 Mon Sep 17 00:00:00 2001 From: Rahul Ghangas Date: Wed, 27 Mar 2024 11:50:53 +0530 Subject: [PATCH 2/7] feat(wip): tested wrapper circuit for 3 rounds --- aggregator/Cargo.toml | 1 + aggregator/benches/wrapper_circuit.rs | 281 +++++------ aggregator/src/state_transition.rs | 3 + aggregator/src/wrapper.rs | 674 ++++++++++++-------------- 4 files changed, 453 insertions(+), 506 deletions(-) diff --git a/aggregator/Cargo.toml b/aggregator/Cargo.toml index d524714..c258fe6 100644 --- a/aggregator/Cargo.toml +++ b/aggregator/Cargo.toml @@ -20,6 +20,7 @@ indexed-merkle-tree-halo2 = { git = "https://github.com/aerius-labs/indexed-merk rand_chacha = "0.3.1" snark-verifier-sdk = { git = "https://github.com/aerius-labs/snark-verifier.git", branch = "feat/custom" } ark-std = "0.4.0" +num-traits = "0.2.18" voter = { path = "../voter" } voter-tests = { path = "../voter_tests" } diff --git a/aggregator/benches/wrapper_circuit.rs b/aggregator/benches/wrapper_circuit.rs index d47d466..069f41c 100644 --- a/aggregator/benches/wrapper_circuit.rs +++ b/aggregator/benches/wrapper_circuit.rs @@ -1,152 +1,153 @@ -use aggregator::state_transition::StateTransitionCircuit; -use aggregator::utils::generate_wrapper_circuit_input; -use aggregator::wrapper::common::gen_dummy_snark; -use aggregator::wrapper::common::gen_pk; -use aggregator::wrapper::common::gen_snark; -use aggregator::wrapper::recursion::RecursionCircuit; -use halo2_base::gates::circuit::BaseCircuitParams; -use halo2_base::gates::circuit::CircuitBuilderStage; -use halo2_base::utils::fs::gen_srs; -use halo2_base::{ - halo2_proofs::{halo2curves::bn256::Fr, plonk::*}, - utils::testing::gen_proof, -}; +// use aggregator::state_transition::StateTransitionCircuit; +// use aggregator::utils::generate_wrapper_circuit_input; +// use aggregator::wrapper::common::gen_dummy_snark; +// use aggregator::wrapper::common::gen_pk; +// use aggregator::wrapper::common::gen_snark; +// use aggregator::wrapper::recursion::RecursionCircuit; +// use halo2_base::gates::circuit::BaseCircuitParams; +// use halo2_base::gates::circuit::CircuitBuilderStage; +// use halo2_base::utils::fs::gen_srs; +// use halo2_base::{ +// halo2_proofs::{halo2curves::bn256::Fr, plonk::*}, +// utils::testing::gen_proof, +// }; -use criterion::{criterion_group, criterion_main}; -use criterion::{BenchmarkId, Criterion}; +// use criterion::{criterion_group, criterion_main}; +// use criterion::{BenchmarkId, Criterion}; -use pprof::criterion::{Output, PProfProfiler}; -use voter::VoterCircuit; +// use pprof::criterion::{Output, PProfProfiler}; +// use voter::VoterCircuit; -const K: u32 = 22; +// const K: u32 = 22; -fn bench(c: &mut Criterion) { - let (voter_inputs, state_transition_inputs) = generate_wrapper_circuit_input(1); +// fn bench(c: &mut Criterion) { +// let (voter_inputs, state_transition_inputs) = generate_wrapper_circuit_input(1); - // Generating voter proof - let voter_config = BaseCircuitParams { - k: 15, - num_advice_per_phase: vec![1], - num_lookup_advice_per_phase: vec![1, 0, 0], - num_fixed: 1, - lookup_bits: Some(14), - num_instance_columns: 1, - }; - let voter_params = gen_srs(15); - let voter_circuit = VoterCircuit::new(voter_config.clone(), voter_inputs[0].clone()); - let voter_pk = gen_pk(&voter_params, &voter_circuit); - let voter_snark = gen_snark(&voter_params, &voter_pk, voter_circuit); +// // Generating voter proof +// let voter_config = BaseCircuitParams { +// k: 15, +// num_advice_per_phase: vec![1], +// num_lookup_advice_per_phase: vec![1, 0, 0], +// num_fixed: 1, +// lookup_bits: Some(14), +// num_instance_columns: 1, +// }; +// let voter_params = gen_srs(15); +// let voter_circuit = VoterCircuit::new(voter_config.clone(), voter_inputs[0].clone()); +// let voter_pk = gen_pk(&voter_params, &voter_circuit); +// let voter_snark = gen_snark(&voter_params, &voter_pk, voter_circuit); - // Generating state transition proof - let state_transition_config = BaseCircuitParams { - k: 15, - num_advice_per_phase: vec![3], - num_lookup_advice_per_phase: vec![1, 0, 0], - num_fixed: 1, - lookup_bits: Some(14), - num_instance_columns: 1, - }; - let state_transition_params = gen_srs(15); - let state_transition_circuit = StateTransitionCircuit::new( - state_transition_config.clone(), - state_transition_inputs[0].clone(), - ); - let state_transition_pk = gen_pk(&state_transition_params, &state_transition_circuit); - let state_transition_snark = gen_snark( - &state_transition_params, - &state_transition_pk, - state_transition_circuit, - ); +// // Generating state transition proof +// let state_transition_config = BaseCircuitParams { +// k: 15, +// num_advice_per_phase: vec![3], +// num_lookup_advice_per_phase: vec![1, 0, 0], +// num_fixed: 1, +// lookup_bits: Some(14), +// num_instance_columns: 1, +// }; +// let state_transition_params = gen_srs(15); +// let state_transition_circuit = StateTransitionCircuit::new( +// state_transition_config.clone(), +// state_transition_inputs[0].clone(), +// ); +// let state_transition_pk = gen_pk(&state_transition_params, &state_transition_circuit); +// let state_transition_snark = gen_snark( +// &state_transition_params, +// &state_transition_pk, +// state_transition_circuit, +// ); - let recursion_config = BaseCircuitParams { - k: K as usize, - num_advice_per_phase: vec![4], - num_lookup_advice_per_phase: vec![1, 0, 0], - num_fixed: 1, - lookup_bits: Some((K - 1) as usize), - num_instance_columns: 1, - }; - let recursion_params = gen_srs(K); +// let recursion_config = BaseCircuitParams { +// k: K as usize, +// num_advice_per_phase: vec![4], +// num_lookup_advice_per_phase: vec![1, 0, 0], +// num_fixed: 1, +// lookup_bits: Some((K - 1) as usize), +// num_instance_columns: 1, +// }; +// let recursion_params = gen_srs(K); - // Init Base Instances - let mut base_instances = [ - Fr::zero(), // preprocessed_digest - voter_snark.instances[0][0], // pk_enc_n - voter_snark.instances[0][1], - voter_snark.instances[0][2], // pk_enc_g - voter_snark.instances[0][3], - ] - .to_vec(); - base_instances.extend(state_transition_snark.instances[0][4..24].iter()); // init_vote - base_instances.extend([ - state_transition_snark.instances[0][68], // nullifier_old_root - state_transition_snark.instances[0][68], // nullifier_new_root - voter_snark.instances[0][28], // membership_root - voter_snark.instances[0][29], // proposal_id - Fr::from(0), // round - ]); +// // Init Base Instances +// let mut base_instances = [ +// Fr::zero(), // preprocessed_digest +// voter_snark.instances[0][0], // pk_enc_n +// voter_snark.instances[0][1], +// voter_snark.instances[0][2], // pk_enc_g +// voter_snark.instances[0][3], +// ] +// .to_vec(); +// base_instances.extend(state_transition_snark.instances[0][4..24].iter()); // init_vote +// base_instances.extend([ +// state_transition_snark.instances[0][68], // nullifier_old_root +// state_transition_snark.instances[0][68], // nullifier_new_root +// voter_snark.instances[0][28], // membership_root +// voter_snark.instances[0][29], // proposal_id +// Fr::from(0), // round +// ]); - let recursion_circuit = RecursionCircuit::new( - CircuitBuilderStage::Keygen, - &recursion_params, - gen_dummy_snark::>(&voter_params, Some(voter_pk.get_vk()), voter_config), - gen_dummy_snark::>( - &state_transition_params, - Some(state_transition_pk.get_vk()), - state_transition_config, - ), - RecursionCircuit::initial_snark( - &recursion_params, - None, - recursion_config.clone(), - base_instances.clone(), - ), - 0, - recursion_config, - ); - let pk = gen_pk(&recursion_params, &recursion_circuit); - let config_params = recursion_circuit.inner().params(); +// let recursion_circuit = RecursionCircuit::new( +// CircuitBuilderStage::Keygen, +// &recursion_params, +// gen_dummy_snark::>(&voter_params, Some(voter_pk.get_vk()), voter_config), +// gen_dummy_snark::>( +// &state_transition_params, +// Some(state_transition_pk.get_vk()), +// state_transition_config, +// ), +// RecursionCircuit::initial_snark( +// &recursion_params, +// None, +// recursion_config.clone(), +// base_instances.clone(), +// ), +// 0, +// recursion_config, +// ); +// let pk = gen_pk(&recursion_params, &recursion_circuit); +// let config_params = recursion_circuit.inner().params(); - let mut group = c.benchmark_group("plonk-prover"); - group.sample_size(10); - group.bench_with_input( - BenchmarkId::new("wrapper circuit", K), - &( - &recursion_params, - &pk, - &voter_snark, - &state_transition_snark, - ), - |bencher, &(params, pk, voter_snark, state_transition_snark)| { - let cloned_voter_snark = voter_snark; - let cloned_state_transition_snark = state_transition_snark; - bencher.iter(|| { - let cloned_config_params = config_params.clone(); - let circuit = RecursionCircuit::new( - CircuitBuilderStage::Prover, - ¶ms, - cloned_voter_snark.clone(), - cloned_state_transition_snark.clone(), - RecursionCircuit::initial_snark( - ¶ms, - None, - cloned_config_params.clone(), - base_instances.clone(), - ), - 0, - cloned_config_params, - ); +// let mut group = c.benchmark_group("plonk-prover"); +// group.sample_size(10); +// group.bench_with_input( +// BenchmarkId::new("wrapper circuit", K), +// &( +// &recursion_params, +// &pk, +// &voter_snark, +// &state_transition_snark, +// ), +// |bencher, &(params, pk, voter_snark, state_transition_snark)| { +// let cloned_voter_snark = voter_snark; +// let cloned_state_transition_snark = state_transition_snark; +// bencher.iter(|| { +// let cloned_config_params = config_params.clone(); +// let circuit = RecursionCircuit::new( +// CircuitBuilderStage::Prover, +// ¶ms, +// cloned_voter_snark.clone(), +// cloned_state_transition_snark.clone(), +// RecursionCircuit::initial_snark( +// ¶ms, +// None, +// cloned_config_params.clone(), +// base_instances.clone(), +// ), +// 0, +// cloned_config_params, +// ); - gen_proof(params, pk, circuit); - }) - }, - ); - group.finish() -} +// gen_proof(params, pk, circuit); +// }) +// }, +// ); +// group.finish() +// } -criterion_group! { - name = benches; - config = Criterion::default().with_profiler(PProfProfiler::new(10, Output::Flamegraph(None))); - targets = bench -} -criterion_main!(benches); +// criterion_group! { +// name = benches; +// config = Criterion::default().with_profiler(PProfProfiler::new(10, Output::Flamegraph(None))); +// targets = bench +// } +// criterion_main!(benches); +fn main() {} diff --git a/aggregator/src/state_transition.rs b/aggregator/src/state_transition.rs index 74f133a..45006ac 100644 --- a/aggregator/src/state_transition.rs +++ b/aggregator/src/state_transition.rs @@ -71,6 +71,9 @@ impl IndexedMerkleTreeInput { pub fn get_old_root(&self) -> F { self.old_root } + pub fn get_new_root(&self) -> F { + self.new_root + } } #[derive(Debug, Clone)] diff --git a/aggregator/src/wrapper.rs b/aggregator/src/wrapper.rs index 0e6d23a..6fb6692 100644 --- a/aggregator/src/wrapper.rs +++ b/aggregator/src/wrapper.rs @@ -1,26 +1,22 @@ -use ark_std::{ end_timer, start_timer }; +use ark_std::{end_timer, start_timer}; use common::*; use halo2_base::halo2_proofs; use halo2_proofs::{ - circuit::{ Layouter, SimpleFloorPlanner }, + circuit::{Layouter, SimpleFloorPlanner}, dev::MockProver, - halo2curves::{ bn256::{ Bn256, Fr, G1Affine }, group::ff::Field }, + halo2curves::{ + bn256::{Bn256, Fr, G1Affine}, + group::ff::Field, + }, plonk::{ - create_proof, - keygen_pk, - keygen_vk, - Circuit, - ConstraintSystem, - Error, - ProvingKey, - Selector, + create_proof, keygen_pk, keygen_vk, Circuit, ConstraintSystem, Error, ProvingKey, Selector, VerifyingKey, }, poly::{ commitment::ParamsProver, kzg::{ commitment::ParamsKZG, - multiopen::{ ProverGWC, VerifierGWC }, + multiopen::{ProverGWC, VerifierGWC}, strategy::AccumulatorStrategy, }, VerificationStrategy, @@ -29,17 +25,23 @@ use halo2_proofs::{ use itertools::Itertools; use rand_chacha::rand_core::OsRng; use snark_verifier_sdk::snark_verifier::{ - loader::{ self, native::NativeLoader, Loader, ScalarLoader }, + loader::{self, native::NativeLoader, Loader, ScalarLoader}, pcs::{ - kzg::{ Gwc19, KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey, LimbsEncoding }, - AccumulationScheme, - AccumulationSchemeProver, + kzg::{Gwc19, KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey, LimbsEncoding}, + AccumulationScheme, AccumulationSchemeProver, + }, + system::halo2::{self, compile, Config}, + util::{ + arithmetic::{fe_to_fe, fe_to_limbs}, + hash, + }, + verifier::{ + self, + plonk::{PlonkProof, PlonkProtocol}, + SnarkVerifier, }, - system::halo2::{ self, compile, Config }, - util::{ arithmetic::{ fe_to_fe, fe_to_limbs }, hash }, - verifier::{ self, plonk::{ PlonkProof, PlonkProtocol }, SnarkVerifier }, }; -use std::{ iter, marker::PhantomData, rc::Rc }; +use std::{iter, marker::PhantomData, rc::Rc}; const LIMBS: usize = 3; const BITS: usize = 88; @@ -54,29 +56,21 @@ type As = KzgAs; type PlonkVerifier = verifier::plonk::PlonkVerifier>; type PlonkSuccinctVerifier = verifier::plonk::PlonkSuccinctVerifier>; type Poseidon = hash::Poseidon; -type PoseidonTranscript = halo2::transcript::halo2::PoseidonTranscript< - G1Affine, - L, - S, - T, - RATE, - R_F, - R_P ->; +type PoseidonTranscript = + halo2::transcript::halo2::PoseidonTranscript; pub mod common { use super::*; - use halo2_proofs::{ plonk::verify_proof, poly::commitment::Params }; - use serde::{ Deserialize, Serialize }; + use halo2_proofs::{plonk::verify_proof, poly::commitment::Params}; + use serde::{Deserialize, Serialize}; use snark_verifier_sdk::snark_verifier::{ - cost::CostEstimation, - util::transcript::TranscriptWrite, + cost::CostEstimation, util::transcript::TranscriptWrite, }; use voter::CircuitExt; pub fn poseidon>( loader: &L, - inputs: &[L::LoadedScalar] + inputs: &[L::LoadedScalar], ) -> L::LoadedScalar { // warning: generating a new spec is time intensive, use lazy_static in production let mut hasher = Poseidon::new::(loader); @@ -84,13 +78,6 @@ pub mod common { hasher.squeeze() } - pub fn limbs_to_biguint(x: Vec) -> BigUint { - x.iter() - .enumerate() - .map(|(i, limb)| fe_to_biguint(limb) * BigUint::from(2u64).pow(88 * (i as u32))) - .sum() - } - #[derive(Clone, Serialize, Deserialize)] pub struct Snark { pub protocol: PlonkProtocol, @@ -102,7 +89,7 @@ pub mod common { pub fn new( protocol: PlonkProtocol, instances: Vec>, - proof: Vec + proof: Vec, ) -> Self { Self { protocol, @@ -125,42 +112,44 @@ pub mod common { params: &ParamsKZG, pk: &ProvingKey, circuit: C, - instances: Vec> + instances: Vec>, ) -> Vec { if params.k() > 3 { let mock = start_timer!(|| "Mock prover"); - MockProver::run(params.k(), &circuit, instances.clone()).unwrap().assert_satisfied(); + MockProver::run(params.k(), &circuit, instances.clone()) + .unwrap() + .assert_satisfied(); end_timer!(mock); } let instances = instances.iter().map(Vec::as_slice).collect_vec(); let proof = { - let mut transcript = PoseidonTranscript::::new::( - Vec::new() - ); + let mut transcript = + PoseidonTranscript::::new::(Vec::new()); create_proof::<_, ProverGWC<_>, _, _, _, _>( params, pk, &[circuit], &[instances.as_slice()], OsRng, - &mut transcript - ).unwrap(); + &mut transcript, + ) + .unwrap(); transcript.finalize() }; let accept = { - let mut transcript = PoseidonTranscript::::new::( - proof.as_slice() - ); + let mut transcript = + PoseidonTranscript::::new::(proof.as_slice()); VerificationStrategy::<_, VerifierGWC<_>>::finalize( verify_proof::<_, VerifierGWC<_>, _, _, _>( params.verifier_params(), pk.get_vk(), AccumulatorStrategy::new(params.verifier_params()), &[instances.as_slice()], - &mut transcript - ).unwrap() + &mut transcript, + ) + .unwrap(), ) }; assert!(accept); @@ -171,14 +160,14 @@ pub mod common { pub fn gen_snark>( params: &ParamsKZG, pk: &ProvingKey, - circuit: ConcreteCircuit + circuit: ConcreteCircuit, ) -> Snark { let protocol = compile( params, pk.get_vk(), Config::kzg() .with_num_instance(ConcreteCircuit::num_instance()) - .with_accumulator_indices(ConcreteCircuit::accumulator_indices()) + .with_accumulator_indices(ConcreteCircuit::accumulator_indices()), ); let instances = circuit.instances(); @@ -190,13 +179,17 @@ pub mod common { pub fn gen_dummy_snark>( params: &ParamsKZG, vk: Option<&VerifyingKey>, - config_params: ConcreteCircuit::Params + config_params: ConcreteCircuit::Params, ) -> Snark - where ConcreteCircuit::Params: Clone + where + ConcreteCircuit::Params: Clone, { struct CsProxy>(C::Params, PhantomData<(F, C)>); - impl> Circuit for CsProxy where C::Params: Clone { + impl> Circuit for CsProxy + where + C::Params: Clone, + { type Config = C::Config; type FloorPlanner = C::FloorPlanner; type Params = C::Params; @@ -211,7 +204,7 @@ pub mod common { fn configure_with_params( meta: &mut ConstraintSystem, - params: Self::Params + params: Self::Params, ) -> Self::Config { C::configure_with_params(meta, params) } @@ -223,7 +216,7 @@ pub mod common { fn synthesize( &self, config: Self::Config, - mut layouter: impl Layouter + mut layouter: impl Layouter, ) -> Result<(), Error> { // when `C` has simple selectors, we tell `CsProxy` not to over-optimize the selectors (e.g., compressing them all into one) by turning all selectors on in the first row // currently this only works if all simple selector columns are used in the actual circuit and there are overlaps amongst all enabled selectors (i.e., the actual circuit will not optimize constraint system further) @@ -234,44 +227,39 @@ pub mod common { q.enable(&mut region, 0)?; } Ok(()) - } + }, )?; Ok(()) } } - let dummy_vk = vk - .is_none() - .then(|| { - keygen_vk( - params, - &CsProxy::(config_params, PhantomData) - ).unwrap() - }); + let dummy_vk = vk.is_none().then(|| { + keygen_vk( + params, + &CsProxy::(config_params, PhantomData), + ) + .unwrap() + }); let protocol = compile( params, vk.or(dummy_vk.as_ref()).unwrap(), Config::kzg() .with_num_instance(ConcreteCircuit::num_instance()) - .with_accumulator_indices(ConcreteCircuit::accumulator_indices()) + .with_accumulator_indices(ConcreteCircuit::accumulator_indices()), ); let instances = ConcreteCircuit::num_instance() .into_iter() - .map(|n| - iter - ::repeat_with(|| Fr::random(OsRng)) - .take(n) - .collect() - ) + .map(|n| iter::repeat_with(|| Fr::random(OsRng)).take(n).collect()) .collect(); let proof = { - let mut transcript = PoseidonTranscript::::new::( - Vec::new() - ); - for _ in 0..protocol.num_witness + let mut transcript = + PoseidonTranscript::::new::(Vec::new()); + for _ in 0..protocol + .num_witness .iter() .chain(Some(&protocol.quotient.num_chunk())) - .sum::() { + .sum::() + { transcript.write_ec_point(G1Affine::random(OsRng)).unwrap(); } for _ in 0..protocol.evaluations.len() { @@ -294,26 +282,19 @@ pub mod recursion { use halo2_base::{ gates::{ circuit::{ - builder::BaseCircuitBuilder, - BaseCircuitParams, - BaseConfig, - CircuitBuilderStage, + builder::BaseCircuitBuilder, BaseCircuitParams, BaseConfig, CircuitBuilderStage, }, - GateInstructions, - RangeInstructions, + range, GateInstructions, RangeInstructions, }, - halo2_proofs::halo2curves::secp256k1::Secp256k1Affine, + poseidon::hasher::state, AssignedValue, }; - use halo2_ecc::{ bn254::FpChip, ecc::EcPoint }; - use num_bigint::BigUint; - use snark_verifier_sdk::snark_verifier::loader::halo2::{ EccInstructions, IntegerInstructions }; - use voter::{ CircuitExt, EncryptionPublicKey, VoterCircuit }; + use halo2_ecc::{bn254::FpChip, ecc::EcPoint}; + use snark_verifier_sdk::snark_verifier::loader::halo2::{EccInstructions, IntegerInstructions}; + use voter::{CircuitExt, VoterCircuit}; use crate::state_transition::{ - state_transition_circuit, - StateTransitionCircuit, - StateTransitionInput, + state_transition_circuit, StateTransitionCircuit, StateTransitionInput, }; use super::*; @@ -325,18 +306,21 @@ pub mod recursion { svk: &Svk, loader: &Rc>, snark: &Snark, - preprocessed_digest: Option> - ) -> (Vec>>, Vec>>>) { + preprocessed_digest: Option>, + ) -> ( + Vec>>, + Vec>>>, + ) { let protocol = if let Some(preprocessed_digest) = preprocessed_digest { let preprocessed_digest = loader.scalar_from_assigned(preprocessed_digest); let protocol = snark.protocol.loaded_preprocessed_as_witness(loader, false); - let inputs = protocol.preprocessed + let inputs = protocol + .preprocessed .iter() .flat_map(|preprocessed| { let assigned = preprocessed.assigned(); - [assigned.x(), assigned.y()].map(|coordinate| - loader.scalar_from_assigned(*coordinate.native()) - ) + [assigned.x(), assigned.y()] + .map(|coordinate| loader.scalar_from_assigned(*coordinate.native())) }) .chain(protocol.transcript_initial_state.clone()) .collect_vec(); @@ -346,7 +330,8 @@ pub mod recursion { snark.protocol.loaded(loader) }; - let instances = snark.instances + let instances = snark + .instances .iter() .map(|instances| { instances @@ -355,22 +340,12 @@ pub mod recursion { .collect_vec() }) .collect_vec(); - let mut transcript = PoseidonTranscript::, _>::new::( - loader, - snark.proof() - ); - let proof = PlonkSuccinctVerifier::read_proof( - svk, - &protocol, - &instances, - &mut transcript - ).unwrap(); - let accumulators = PlonkSuccinctVerifier::verify( - svk, - &protocol, - &instances, - &proof - ).unwrap(); + let mut transcript = + PoseidonTranscript::, _>::new::(loader, snark.proof()); + let proof = + PlonkSuccinctVerifier::read_proof(svk, &protocol, &instances, &mut transcript).unwrap(); + let accumulators = + PlonkSuccinctVerifier::verify(svk, &protocol, &instances, &proof).unwrap(); ( instances @@ -390,41 +365,35 @@ pub mod recursion { loader: &Rc>, condition: &AssignedValue, lhs: &KzgAccumulator>>, - rhs: &KzgAccumulator>> + rhs: &KzgAccumulator>>, ) -> Result>>, Error> { let [lhs, rhs]: [_; 2] = [lhs.lhs.assigned(), lhs.rhs.assigned()] .iter() .zip([rhs.lhs.assigned(), rhs.rhs.assigned()].iter()) .map(|(lhs, rhs)| { - loader - .ecc_chip() - .select( - loader.ctx_mut().main(), - EcPoint::clone(lhs), - EcPoint::clone(rhs), - *condition - ) + loader.ecc_chip().select( + loader.ctx_mut().main(), + EcPoint::clone(lhs), + EcPoint::clone(rhs), + *condition, + ) }) .collect::>() .try_into() .unwrap(); - Ok( - KzgAccumulator::new( - loader.ec_point_from_assigned(lhs), - loader.ec_point_from_assigned(rhs) - ) - ) + Ok(KzgAccumulator::new( + loader.ec_point_from_assigned(lhs), + loader.ec_point_from_assigned(rhs), + )) } fn accumulate<'a>( loader: &Rc>, accumulators: Vec>>>, - as_proof: &[u8] + as_proof: &[u8], ) -> KzgAccumulator>> { - let mut transcript = PoseidonTranscript::, _>::new::( - loader, - as_proof - ); + let mut transcript = + PoseidonTranscript::, _>::new::(loader, as_proof); let proof = As::read_proof(&Default::default(), &accumulators, &mut transcript).unwrap(); As::verify(&Default::default(), &accumulators, &proof).unwrap() } @@ -467,34 +436,30 @@ pub mod recursion { stage: CircuitBuilderStage, params: &ParamsKZG, voter: Snark, + state_transition_input: StateTransitionInput, previous: Snark, - nullifier_tree: IndexedMerkleTreeInput, round: usize, - config_params: BaseCircuitParams + config_params: BaseCircuitParams, ) -> Self { let svk = params.get_g()[0].into(); let default_accumulator = KzgAccumulator::new(params.get_g()[1], params.get_g()[0]); let succinct_verify = |snark: &Snark| { let mut transcript = PoseidonTranscript::::new::( - snark.proof.as_slice() + snark.proof.as_slice(), ); let proof = PlonkSuccinctVerifier::read_proof( &svk, &snark.protocol, &snark.instances, - &mut transcript - ).unwrap(); - PlonkSuccinctVerifier::verify( - &svk, - &snark.protocol, - &snark.instances, - &proof - ).unwrap() + &mut transcript, + ) + .unwrap(); + PlonkSuccinctVerifier::verify(&svk, &snark.protocol, &snark.instances, &proof) + .unwrap() }; - let accumulators = iter - ::empty() + let accumulators = iter::empty() .chain(succinct_verify(&voter)) .chain( (round > 0) @@ -502,25 +467,23 @@ pub mod recursion { .unwrap_or_else(|| { let num_accumulator = 1 + previous.protocol.accumulator_indices.len(); vec![default_accumulator.clone(); num_accumulator] - }) + }), ) .collect_vec(); let (accumulator, as_proof) = { - let mut transcript = PoseidonTranscript::::new::( - Vec::new() - ); - let accumulator = As::create_proof( - &Default::default(), - &accumulators, - &mut transcript, - OsRng - ).unwrap(); + let mut transcript = + PoseidonTranscript::::new::(Vec::new()); + let accumulator = + As::create_proof(&Default::default(), &accumulators, &mut transcript, OsRng) + .unwrap(); (accumulator, transcript.finalize()) }; let preprocessed_digest = { - let inputs = previous.protocol.preprocessed + let inputs = previous + .protocol + .preprocessed .iter() .flat_map(|preprocessed| [preprocessed.x, preprocessed.y]) .map(fe_to_fe) @@ -529,74 +492,17 @@ pub mod recursion { poseidon(&NativeLoader, &inputs) }; - //State transition circuit input - - //state_transition(pk_enc) - let st_pk_enc_fr = [ - voter.instances[0][0], - voter.instances[0][1], - voter.instances[0][2], - voter.instances[0][3], - ]; - let sk_pk_enc = EncryptionPublicKey { - n: limbs_to_biguint(st_pk_enc[0..2]), - g: limbs_to_biguint(st_pk_enc[2..4]), - }; - - //state_transition(prev_vote) - let st_prev_vote_fr = (0..20) - .map(|i| previous.instances[0][4 * LIMBS + i + 1 + 4]) - .collect::>(); - let st_prev_vote = (0..5) - .map(|i| limbs_to_biguint(st_prev_vote_fr[4 * i + 4 * i + 4])) - .collect::>(); - - //state_transition(incoming_vote) - let st_incoming_vote_fr = (0..20) - .map(|i| voter.instances[0][i + 4]) - .collect::>(); - let st_incoming_vote = (0..5) - .map(|i| limbs_to_biguint(st_incoming_vote_fr[4 * i + 4 * i + 4])) - .collect::>(); - - //state_transition(nullifier) - let st_nullifier = (0..4).map(|i| voter.instances[0][i + 24]).collect::>(); - let nullifier_x = limbs_to_biguint(st_nullifier[0..2]); - let nullifier_y = limbs_to_biguint(st_nullifier[2..4]); - - let st_nullifier = Secp256k1Affine::new( - fe_to_fe::(nullifier_x), - fe_to_fe::(nullifier_y) - ); - let state_transition_input = StateTransitionInput::new( - sk_pk_enc, - st_prev_vote, - st_incoming_vote, - nullifier_tree, - st_nullifier - ); - //state_transition(nullifier_old_root) - // let st_nullifier_old_root = previous.instances[0][4 * LIMBS + 1 + 25]; - - let inner = BaseCircuitBuilder::from_stage(stage).use_params(config_params); - let range = inner.range(); - let mut public_inputs = Vec::>::new(); - state_transition_circuit( - inner.main(0), - range, - state_transition_input, - &mut public_inputs - ); let mut current_instances = [ voter.instances[0][0], voter.instances[0][1], voter.instances[0][2], voter.instances[0][3], - ].to_vec(); - current_instances.extend(public_inputs[44..64].iter()); + ] + .to_vec(); + current_instances.extend(voter.instances[0][4..24].iter().clone()); current_instances.extend([ - public_inputs[68], - public_inputs[69], + state_transition_input.nullifier_tree.get_old_root(), + state_transition_input.nullifier_tree.get_new_root(), voter.instances[0][28], voter.instances[0][29], ]); @@ -607,13 +513,22 @@ pub mod recursion { accumulator.rhs.x, accumulator.rhs.y, ] - .into_iter() - .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) - .chain([preprocessed_digest]) - .chain(current_instances) - .chain([Fr::from(round as u64)]) - .collect(); + .into_iter() + .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) + .chain([preprocessed_digest]) + .chain(current_instances) + .chain([Fr::from(round as u64)]) + .collect(); + let mut inner = BaseCircuitBuilder::from_stage(stage).use_params(config_params); + let range = inner.range_chip(); + let mut public_inputs = Vec::>::new(); + state_transition_circuit( + inner.main(0), + &range, + state_transition_input, + &mut public_inputs, + ); let mut circuit = Self { svk, default_accumulator, @@ -624,19 +539,17 @@ pub mod recursion { as_proof, inner, }; - circuit.build(); + circuit.build(public_inputs); circuit } - fn build(&mut self) { + fn build(&mut self, public_inputs: Vec>) { let range = self.inner.range_chip(); let main_gate = range.gate(); let pool = self.inner.pool(0); - let preprocessed_digest = main_gate.assign_integer( - pool, - self.instances[Self::PREPROCESSED_DIGEST_ROW] - ); + let preprocessed_digest = + main_gate.assign_integer(pool, self.instances[Self::PREPROCESSED_DIGEST_ROW]); let pk_enc_n = self.instances[Self::PK_ENC_N_ROW..Self::PK_ENC_N_ROW + 2] .iter() .map(|instance| main_gate.assign_integer(pool, *instance)) @@ -649,18 +562,12 @@ pub mod recursion { .iter() .map(|instance| main_gate.assign_integer(pool, *instance)) .collect::>(); - let nullifier_old_root = main_gate.assign_integer( - pool, - self.instances[Self::NULLIFIER_OLD_ROOT_ROW] - ); - let nullifier_new_root = main_gate.assign_integer( - pool, - self.instances[Self::NULLIFIER_NEW_ROOT_ROW] - ); - let membership_root = main_gate.assign_integer( - pool, - self.instances[Self::MEMBERSHIP_ROOT_ROW] - ); + let nullifier_old_root = + main_gate.assign_integer(pool, self.instances[Self::NULLIFIER_OLD_ROOT_ROW]); + let nullifier_new_root = + main_gate.assign_integer(pool, self.instances[Self::NULLIFIER_NEW_ROOT_ROW]); + let membership_root = + main_gate.assign_integer(pool, self.instances[Self::MEMBERSHIP_ROOT_ROW]); let proposal_id = main_gate.assign_integer(pool, self.instances[Self::PROPOSAL_ID_ROW]); let round = main_gate.assign_integer(pool, self.instances[Self::ROUND_ROW]); @@ -671,18 +578,14 @@ pub mod recursion { let ecc_chip = BaseFieldEccChip::new(&fp_chip); let loader = Halo2Loader::new(ecc_chip, mem::take(self.inner.pool(0))); - let (mut voter_instances, voter_accumulators) = succinct_verify( - &self.svk, - &loader, - &self.voter, - None - ); + let (mut voter_instances, voter_accumulators) = + succinct_verify(&self.svk, &loader, &self.voter, None); let (mut previous_instances, previous_accumulators) = succinct_verify( &self.svk, &loader, &self.previous, - Some(preprocessed_digest) + Some(preprocessed_digest), ); let default_accmulator = self.load_default_accumulator(&loader).unwrap(); @@ -693,24 +596,32 @@ pub mod recursion { &loader, &first_round, &default_accmulator, - previous_accumulator - ).unwrap() + previous_accumulator, + ) + .unwrap() }) .collect::>(); let KzgAccumulator { lhs, rhs } = accumulate( &loader, - [voter_accumulators, previous_accumulators].concat(), - self.as_proof() + [ + voter_accumulators, + // state_transition_accumulators, + previous_accumulators, + ] + .concat(), + self.as_proof(), ); let lhs = lhs.into_assigned(); let rhs = rhs.into_assigned(); let voter_instances = voter_instances.pop().unwrap(); + // let state_transition_instances = state_transition_instances.pop().unwrap(); let previous_instances = previous_instances.pop().unwrap(); let mut pool = loader.take_ctx(); let ctx = pool.main(); + println!("public inputs length: {}", public_inputs.len()); for (lhs, rhs) in [ // Propagate preprocessed_digest ( @@ -718,51 +629,85 @@ pub mod recursion { &previous_instances[Self::PREPROCESSED_DIGEST_ROW], ), // Verify round is increased by 1 when not at first round - (&round, &main_gate.add(ctx, not_first_round, previous_instances[Self::ROUND_ROW])), + ( + &round, + &main_gate.add(ctx, not_first_round, previous_instances[Self::ROUND_ROW]), + ), ] { ctx.constrain_equal(lhs, rhs); } - //TODO: Add constrain nullifier_tree.old_root // state_transition(pk_enc) == previous(pk_enc) == voter(pk_enc) for i in 0..4 { // assert_eq!( - // state_transition_instances[i].value(), + // public_inputs[i].value(), // voter_instances[i].value() // ); // ctx.constrain_equal(&state_transition_instances[i], &voter_instances[i]); - // assert_eq!( - // voter_instances[i].value(), - // previous_instances[4 * LIMBS + i + 1].value() + assert_eq!( + public_inputs[i].value(), + previous_instances[4 * LIMBS + i + 1].value() + ); + // ctx.constrain_equal(&public_inputs[i], &previous_instances[4 * LIMBS + i + 1]); + } + + // state_transition(prev_vote) == previous(aggr_vote) + for i in 0..20 { + assert_eq!( + public_inputs[i + 4].value(), + previous_instances[4 * LIMBS + i + 1 + 4].value() + ); + // ctx.constrain_equal( + // &public_inputs[i + 4], + // &previous_instances[4 * LIMBS + i + 1 + 4], // ); - ctx.constrain_equal(&voter_instances[i], &previous_instances[4 * LIMBS + i + 1]); } - // state_transition(nullifier_old_root) == previous(nullifier_new_root) - // assert_eq!( - // state_transition_instances[68].value(), - // previous_instances[4 * LIMBS + 1 + 25].value() - // ); + // state_transition(incoming_vote) == voter(vote) + for i in 0..20 { + assert_eq!( + public_inputs[i + 24].value(), + voter_instances[i + 4].value() + ); + // ctx.constrain_equal(&public_inputs[i + 24], &voter_instances[i + 4]); + } - // ctx.constrain_equal( - // &state_transition_instances[68], - // &previous_instances[4 * LIMBS + 1 + 25] - // ); + // state_transition(nullifier) == voter(nullifier) + for i in 0..4 { + assert_eq!( + public_inputs[i + 64].value(), + voter_instances[i + 24].value() + ); + // ctx.constrain_equal(&public_inputs[i + 64], &voter_instances[i + 24]); + } + + // state_transition(nullifier_old_root) == previous(nullifier_new_root) + assert_eq!( + public_inputs[68].value(), + previous_instances[4 * LIMBS + 1 + 25].value() + ); + // ctx.constrain_equal(&public_inputs[68], &previous_instances[4 * LIMBS + 1 + 25]); // previous(membership_root]) == voter(membership_root) - // assert_eq!( - // previous_instances[4 * LIMBS + 1 + 26].value(), - // voter_instances[28].value() + assert_eq!( + previous_instances[4 * LIMBS + 1 + 26].value(), + voter_instances[28].value() + ); + // ctx.constrain_equal( + // &previous_instances[4 * LIMBS + 1 + 26], + // &voter_instances[28], // ); - ctx.constrain_equal(&previous_instances[4 * LIMBS + 1 + 26], &voter_instances[28]); // voter(proposal_id) == previous(proposal_id) - // assert_eq!( - // voter_instances[29].value(), - // previous_instances[4 * LIMBS + 1 + 27].value() + assert_eq!( + voter_instances[29].value(), + previous_instances[4 * LIMBS + 1 + 27].value() + ); + // ctx.constrain_equal( + // &voter_instances[29], + // &previous_instances[4 * LIMBS + 1 + 27], // ); - ctx.constrain_equal(&voter_instances[29], &previous_instances[4 * LIMBS + 1 + 27]); *self.inner.pool(0) = pool; @@ -781,30 +726,29 @@ pub mod recursion { membership_root, proposal_id, round, - ].iter() + ] + .iter(), ) - .copied() + .copied(), ); - self.inner.calculate_params(Some(10)); - println!("recursion params: {:?}", self.inner.params()); + // self.inner.calculate_params(Some(10)); + // println!("recursion params: {:?}", self.inner.params()); } pub fn initial_snark( params: &ParamsKZG, vk: Option<&VerifyingKey>, config_params: BaseCircuitParams, - init_aggr_instances: Vec + init_aggr_instances: Vec, ) -> Snark { let mut snark = gen_dummy_snark::(params, vk, config_params); let g = params.get_g(); - snark.instances = vec![ - [g[1].x, g[1].y, g[0].x, g[0].y] - .into_iter() - .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) - .chain(init_aggr_instances) - .collect_vec() - ]; + snark.instances = vec![[g[1].x, g[1].y, g[0].x, g[0].y] + .into_iter() + .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) + .chain(init_aggr_instances) + .collect_vec()]; snark } @@ -814,16 +758,15 @@ pub mod recursion { fn load_default_accumulator<'a>( &self, - loader: &Rc> + loader: &Rc>, ) -> Result>>, Error> { - let [lhs, rhs] = [self.default_accumulator.lhs, self.default_accumulator.rhs].map( - |default| { + let [lhs, rhs] = + [self.default_accumulator.lhs, self.default_accumulator.rhs].map(|default| { let assigned = loader .ecc_chip() .assign_constant(&mut loader.ctx_mut(), default); loader.ec_point_from_assigned(assigned) - } - ); + }); Ok(KzgAccumulator::new(lhs, rhs)) } @@ -847,7 +790,7 @@ pub mod recursion { fn configure_with_params( meta: &mut ConstraintSystem, - params: Self::Params + params: Self::Params, ) -> Self::Config { BaseCircuitBuilder::configure_with_params(meta, params) } @@ -859,7 +802,7 @@ pub mod recursion { fn synthesize( &self, config: Self::Config, - layouter: impl Layouter + layouter: impl Layouter, ) -> Result<(), Error> { self.inner.synthesize(config, layouter) } @@ -880,9 +823,8 @@ pub mod recursion { } fn selectors(config: &Self::Config) -> Vec { - config - .gate() - .basic_gates[0].iter() + config.gate().basic_gates[0] + .iter() .map(|gate| gate.q_enable) .collect() } @@ -894,22 +836,23 @@ pub mod recursion { voter_vk: &VerifyingKey, voter_config: BaseCircuitParams, recursion_config: BaseCircuitParams, - nullifier_tree: IndexedMerkleTreeInput, - init_aggr_instances: Vec + init_aggr_instances: Vec, + state_transition_input: StateTransitionInput, ) -> ProvingKey { + println!("pk recursion_config: {:?}", recursion_config); let recursion = RecursionCircuit::new( CircuitBuilderStage::Keygen, recursion_params, gen_dummy_snark::>(voter_params, Some(voter_vk), voter_config), + state_transition_input, RecursionCircuit::initial_snark( recursion_params, None, recursion_config.clone(), - init_aggr_instances + init_aggr_instances, ), - nullifier_tree, 0, - recursion_config + recursion_config, ); // we cannot auto-configure the circuit because dummy_snark must know the configuration beforehand // uncomment the following line only in development to test and print out the optimal configuration ahead of time @@ -923,24 +866,25 @@ pub mod recursion { recursion_pk: &ProvingKey, recursion_config: BaseCircuitParams, voter_snarks: Vec, + state_transition_inputs: Vec>, init_aggr_instances: Vec, - nullifier_tree: Vec> ) -> Snark { + println!("snark recursion_config: {:?}", recursion_config); let mut previous = RecursionCircuit::initial_snark( recursion_params, Some(recursion_pk.get_vk()), recursion_config.clone(), - init_aggr_instances + init_aggr_instances, ); for (round, voter) in voter_snarks.into_iter().enumerate() { let recursion = RecursionCircuit::new( stage, recursion_params, voter, + state_transition_inputs[round].clone(), previous, - nullifier_tree[round], round, - recursion_config.clone() + recursion_config.clone(), ); println!("Generate recursion snark for round {}", round); previous = gen_snark(recursion_params, recursion_pk, recursion); @@ -951,35 +895,38 @@ pub mod recursion { #[cfg(test)] mod test { - use std::path::{ Path, PathBuf }; - use std::{ fs, io::BufReader }; + use std::path::{Path, PathBuf}; + use std::{fs, io::BufReader}; - use ark_std::{ end_timer, start_timer }; + use ark_std::{end_timer, start_timer}; use halo2_base::{ - gates::circuit::{ builder::BaseCircuitBuilder, BaseCircuitParams, CircuitBuilderStage }, + gates::circuit::{builder::BaseCircuitBuilder, BaseCircuitParams, CircuitBuilderStage}, halo2_proofs::{ - halo2curves::bn256::{ Fr, G1Affine }, + halo2curves::bn256::{Fr, G1Affine}, plonk::ProvingKey, poly::commitment::ParamsProver, }, utils::fs::gen_srs, }; - use snark_verifier_sdk::{ snark_verifier::verifier::SnarkVerifier, NativeLoader }; - use voter::merkletree::native::MerkleTree; + use snark_verifier_sdk::{snark_verifier::verifier::SnarkVerifier, NativeLoader}; use voter::VoterCircuit; - use crate::{ state_transition::StateTransitionCircuit, utils::generate_wrapper_circuit_input }; + use crate::{state_transition::StateTransitionCircuit, utils::generate_wrapper_circuit_input}; - use super::{ gen_pk, gen_snark, recursion::{ self }, PlonkVerifier, PoseidonTranscript }; + use super::{ + gen_pk, gen_snark, + recursion::{self}, + PlonkVerifier, PoseidonTranscript, + }; fn workspace_dir() -> PathBuf { - let output = std::process::Command - ::new(env!("CARGO")) + let output = std::process::Command::new(env!("CARGO")) .arg("locate-project") .arg("--workspace") .arg("--message-format=plain") .output() - .unwrap().stdout; + .unwrap() + .stdout; let cargo_path = Path::new(std::str::from_utf8(&output).unwrap().trim()); cargo_path.parent().unwrap().to_path_buf() } @@ -987,14 +934,12 @@ mod test { #[test] fn test_recursion() { const GEN_VOTER_PK: bool = true; + const GEN_STATE_TRANSITION_PK: bool = true; const GEN_RECURSION_PK: bool = true; - let num_round = 7; + let num_round = 3; let (voter_input, state_transition_input) = generate_wrapper_circuit_input(num_round); - let nullifier_tree_preimages = (0..num_round) - .map(|i| state_transition_input[i].nullifier_tree.clone()) - .collect::>(); let voter_config = BaseCircuitParams { k: 15, @@ -1017,7 +962,7 @@ mod test { voter_pk .write( &mut voter_pk_bytes, - halo2_base::halo2_proofs::SerdeFormat::RawBytesUnchecked + halo2_base::halo2_proofs::SerdeFormat::RawBytesUnchecked, ) .unwrap(); // write voter pk to build folder and make sure folder exists @@ -1026,13 +971,12 @@ mod test { println!("Reading voter pk"); let file = fs::read(build_dir.join("voter_pk.bin")).unwrap(); let voter_pk_reader = &mut BufReader::new(file.as_slice()); - voter_pk = ProvingKey:: - ::read::, BaseCircuitBuilder>( - voter_pk_reader, - halo2_base::halo2_proofs::SerdeFormat::RawBytesUnchecked, - voter_config.clone() - ) - .unwrap(); + voter_pk = ProvingKey::::read::, BaseCircuitBuilder>( + voter_pk_reader, + halo2_base::halo2_proofs::SerdeFormat::RawBytesUnchecked, + voter_config.clone(), + ) + .unwrap(); } println!("Generating voter snark"); let voter_snark = gen_snark(&voter_params, &voter_pk, voter_circuit); @@ -1050,19 +994,20 @@ mod test { // Init Base Instances let mut base_instances = [ - Fr::zero(), // preprocessed_digest + Fr::zero(), // preprocessed_digest voter_snark.instances[0][0], // pk_enc_n voter_snark.instances[0][1], voter_snark.instances[0][2], // pk_enc_g voter_snark.instances[0][3], - ].to_vec(); + ] + .to_vec(); base_instances.extend(voter_snark.instances[0][4..24].iter()); // init_vote base_instances.extend([ - nullifier_tree_preimages[0].get_old_root(), // nullifier_old_root - nullifier_tree_preimages[0].get_old_root(), // nullifier_new_root - voter_snark.instances[0][28], // membership_root - voter_snark.instances[0][29], // proposal_id - Fr::from(0), // round + state_transition_input[0].nullifier_tree.get_old_root(), // nullifier_old_root + state_transition_input[0].nullifier_tree.get_old_root(), // nullifier_new_root + voter_snark.instances[0][28], // membership_root + voter_snark.instances[0][29], // proposal_id + Fr::from(0), // round ]); let pk_time = start_timer!(|| "Generate recursion pk"); @@ -1075,14 +1020,14 @@ mod test { voter_pk.get_vk(), voter_config.clone(), recursion_config.clone(), - nullifier_tree[0].clone(), - base_instances.clone() + base_instances.clone(), + state_transition_input[0].clone(), ); let mut recursion_pk_bytes = Vec::new(); recursion_pk .write( &mut recursion_pk_bytes, - halo2_base::halo2_proofs::SerdeFormat::RawBytesUnchecked + halo2_base::halo2_proofs::SerdeFormat::RawBytesUnchecked, ) .unwrap(); @@ -1091,11 +1036,11 @@ mod test { println!("Reading recursion pk"); let file = fs::read(build_dir.join("recursion_pk.bin")).unwrap(); let recursion_pk_reader = &mut BufReader::new(file.as_slice()); - recursion_pk = ProvingKey:: - ::read::, BaseCircuitBuilder>( + recursion_pk = + ProvingKey::::read::, BaseCircuitBuilder>( recursion_pk_reader, halo2_base::halo2_proofs::SerdeFormat::RawBytesUnchecked, - recursion_config.clone() + recursion_config.clone(), ) .unwrap(); } @@ -1116,8 +1061,8 @@ mod test { &recursion_pk, recursion_config, voter_snarks, - nullifier_tree_preimages, - base_instances + state_transition_input.clone(), + base_instances, ); end_timer!(pf_time); @@ -1126,22 +1071,19 @@ mod test { recursion_params.get_g()[0], recursion_params.g2(), recursion_params.s_g2(), - ).into(); - let mut transcript = PoseidonTranscript::::new::<0>( - final_snark.proof.as_slice() - ); + ) + .into(); + let mut transcript = + PoseidonTranscript::::new::<0>(final_snark.proof.as_slice()); let proof = PlonkVerifier::read_proof( &dk, &final_snark.protocol, &final_snark.instances, - &mut transcript - ).unwrap(); - PlonkVerifier::verify( - &dk, - &final_snark.protocol, - &final_snark.instances, - &proof - ).unwrap(); + &mut transcript, + ) + .unwrap(); + PlonkVerifier::verify(&dk, &final_snark.protocol, &final_snark.instances, &proof) + .unwrap(); } } } From 610c738d36c55cda8d8aaa06ab1e2c6cca3e865d Mon Sep 17 00:00:00 2001 From: Rahul Ghangas Date: Thu, 4 Apr 2024 18:05:07 +0530 Subject: [PATCH 3/7] feat: tested wrapper circuit with witness_gen_only(false) for 3 rounds --- .../benches/state_transition_circuit.rs | 171 +++++----- aggregator/benches/wrapper_circuit.rs | 75 ++--- aggregator/src/state_transition.rs | 318 +++++++----------- aggregator/src/wrapper.rs | 209 +++++------- voter/src/lib.rs | 2 + 5 files changed, 318 insertions(+), 457 deletions(-) diff --git a/aggregator/benches/state_transition_circuit.rs b/aggregator/benches/state_transition_circuit.rs index b534168..64db975 100644 --- a/aggregator/benches/state_transition_circuit.rs +++ b/aggregator/benches/state_transition_circuit.rs @@ -1,96 +1,97 @@ -use aggregator::state_transition::{state_transition_circuit, StateTransitionInput}; -use aggregator::utils::generate_random_state_transition_circuit_inputs; -use ark_std::{end_timer, start_timer}; -use halo2_base::gates::circuit::BaseCircuitParams; -use halo2_base::gates::circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}; -use halo2_base::gates::flex_gate::MultiPhaseThreadBreakPoints; -use halo2_base::AssignedValue; -use halo2_base::{ - halo2_proofs::{ - halo2curves::bn256::{Bn256, Fr}, - plonk::*, - poly::kzg::commitment::ParamsKZG, - }, - utils::testing::gen_proof, -}; -use pprof::criterion::{Output, PProfProfiler}; -use rand::rngs::OsRng; +// use aggregator::state_transition::{state_transition_circuit, StateTransitionInput}; +// use aggregator::utils::generate_random_state_transition_circuit_inputs; +// use ark_std::{end_timer, start_timer}; +// use halo2_base::gates::circuit::BaseCircuitParams; +// use halo2_base::gates::circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}; +// use halo2_base::gates::flex_gate::MultiPhaseThreadBreakPoints; +// use halo2_base::AssignedValue; +// use halo2_base::{ +// halo2_proofs::{ +// halo2curves::bn256::{Bn256, Fr}, +// plonk::*, +// poly::kzg::commitment::ParamsKZG, +// }, +// utils::testing::gen_proof, +// }; +// use pprof::criterion::{Output, PProfProfiler}; +// use rand::rngs::OsRng; -use criterion::{criterion_group, criterion_main}; -use criterion::{BenchmarkId, Criterion}; +// use criterion::{criterion_group, criterion_main}; +// use criterion::{BenchmarkId, Criterion}; -const K: u32 = 15; +// const K: u32 = 15; -fn state_transition_circuit_bench( - stage: CircuitBuilderStage, - input: StateTransitionInput, - config_params: Option, - break_points: Option, -) -> RangeCircuitBuilder { - let k = K as usize; - let lookup_bits = k - 1; - let mut builder = match stage { - CircuitBuilderStage::Prover => { - RangeCircuitBuilder::prover(config_params.unwrap(), break_points.unwrap()) - } - _ => RangeCircuitBuilder::from_stage(stage) - .use_k(k) - .use_lookup_bits(lookup_bits), - }; +// fn state_transition_circuit_bench( +// stage: CircuitBuilderStage, +// input: StateTransitionInput, +// config_params: Option, +// break_points: Option, +// ) -> RangeCircuitBuilder { +// let k = K as usize; +// let lookup_bits = k - 1; +// let mut builder = match stage { +// CircuitBuilderStage::Prover => { +// RangeCircuitBuilder::prover(config_params.unwrap(), break_points.unwrap()) +// } +// _ => RangeCircuitBuilder::from_stage(stage) +// .use_k(k) +// .use_lookup_bits(lookup_bits), +// }; - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - let range = builder.range_chip(); +// let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); +// let range = builder.range_chip(); - let mut public_inputs = Vec::>::new(); - state_transition_circuit(builder.main(0), &range, input, &mut public_inputs); +// let mut public_inputs = Vec::>::new(); +// state_transition_circuit(builder.main(0), &range, input, &mut public_inputs); - end_timer!(start0); - if !stage.witness_gen_only() { - builder.calculate_params(Some(20)); - } - builder -} +// end_timer!(start0); +// if !stage.witness_gen_only() { +// builder.calculate_params(Some(20)); +// } +// builder +// } -fn bench(c: &mut Criterion) { - let state_transition_input = generate_random_state_transition_circuit_inputs(); - let circuit = state_transition_circuit_bench( - CircuitBuilderStage::Keygen, - state_transition_input.clone(), - None, - None, - ); - let config_params = circuit.params(); +// fn bench(c: &mut Criterion) { +// let state_transition_input = generate_random_state_transition_circuit_inputs(); +// let circuit = state_transition_circuit_bench( +// CircuitBuilderStage::Keygen, +// state_transition_input.clone(), +// None, +// None, +// ); +// let config_params = circuit.params(); - let params = ParamsKZG::::setup(K, OsRng); - let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); - let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); - let break_points = circuit.break_points(); +// let params = ParamsKZG::::setup(K, OsRng); +// let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); +// let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); +// let break_points = circuit.break_points(); - let mut group = c.benchmark_group("plonk-prover"); - group.sample_size(10); - group.bench_with_input( - BenchmarkId::new("state transition circuit", K), - &(¶ms, &pk, &state_transition_input), - |bencher, &(params, pk, state_transition_input)| { - let input = state_transition_input.clone(); - bencher.iter(|| { - let circuit = state_transition_circuit_bench( - CircuitBuilderStage::Prover, - input.clone(), - Some(config_params.clone()), - Some(break_points.clone()), - ); +// let mut group = c.benchmark_group("plonk-prover"); +// group.sample_size(10); +// group.bench_with_input( +// BenchmarkId::new("state transition circuit", K), +// &(¶ms, &pk, &state_transition_input), +// |bencher, &(params, pk, state_transition_input)| { +// let input = state_transition_input.clone(); +// bencher.iter(|| { +// let circuit = state_transition_circuit_bench( +// CircuitBuilderStage::Prover, +// input.clone(), +// Some(config_params.clone()), +// Some(break_points.clone()), +// ); - gen_proof(params, pk, circuit); - }) - }, - ); - group.finish() -} +// gen_proof(params, pk, circuit); +// }) +// }, +// ); +// group.finish() +// } -criterion_group! { - name = benches; - config = Criterion::default().with_profiler(PProfProfiler::new(10, Output::Flamegraph(None))); - targets = bench -} -criterion_main!(benches); +// criterion_group! { +// name = benches; +// config = Criterion::default().with_profiler(PProfProfiler::new(10, Output::Flamegraph(None))); +// targets = bench +// } +// criterion_main!(benches); +fn main() {} diff --git a/aggregator/benches/wrapper_circuit.rs b/aggregator/benches/wrapper_circuit.rs index 069f41c..9aa27f2 100644 --- a/aggregator/benches/wrapper_circuit.rs +++ b/aggregator/benches/wrapper_circuit.rs @@ -1,21 +1,21 @@ -// use aggregator::state_transition::StateTransitionCircuit; // use aggregator::utils::generate_wrapper_circuit_input; // use aggregator::wrapper::common::gen_dummy_snark; // use aggregator::wrapper::common::gen_pk; +// use aggregator::wrapper::common::gen_proof; // use aggregator::wrapper::common::gen_snark; // use aggregator::wrapper::recursion::RecursionCircuit; // use halo2_base::gates::circuit::BaseCircuitParams; // use halo2_base::gates::circuit::CircuitBuilderStage; +// use halo2_base::halo2_proofs::dev::MockProver; +// use halo2_base::halo2_proofs::{halo2curves::bn256::Fr, plonk::*}; +// use halo2_base::utils::decompose_biguint; // use halo2_base::utils::fs::gen_srs; -// use halo2_base::{ -// halo2_proofs::{halo2curves::bn256::Fr, plonk::*}, -// utils::testing::gen_proof, -// }; // use criterion::{criterion_group, criterion_main}; // use criterion::{BenchmarkId, Criterion}; // use pprof::criterion::{Output, PProfProfiler}; +// use snark_verifier_sdk::CircuitExt; // use voter::VoterCircuit; // const K: u32 = 22; @@ -37,27 +37,6 @@ // let voter_pk = gen_pk(&voter_params, &voter_circuit); // let voter_snark = gen_snark(&voter_params, &voter_pk, voter_circuit); -// // Generating state transition proof -// let state_transition_config = BaseCircuitParams { -// k: 15, -// num_advice_per_phase: vec![3], -// num_lookup_advice_per_phase: vec![1, 0, 0], -// num_fixed: 1, -// lookup_bits: Some(14), -// num_instance_columns: 1, -// }; -// let state_transition_params = gen_srs(15); -// let state_transition_circuit = StateTransitionCircuit::new( -// state_transition_config.clone(), -// state_transition_inputs[0].clone(), -// ); -// let state_transition_pk = gen_pk(&state_transition_params, &state_transition_circuit); -// let state_transition_snark = gen_snark( -// &state_transition_params, -// &state_transition_pk, -// state_transition_circuit, -// ); - // let recursion_config = BaseCircuitParams { // k: K as usize, // num_advice_per_phase: vec![4], @@ -68,6 +47,11 @@ // }; // let recursion_params = gen_srs(K); +// let mut init_vote = Vec::::new(); +// for x in state_transition_inputs[0].prev_vote.iter() { +// init_vote.extend(decompose_biguint::(x, 4, 88)); +// } + // // Init Base Instances // let mut base_instances = [ // Fr::zero(), // preprocessed_digest @@ -77,24 +61,19 @@ // voter_snark.instances[0][3], // ] // .to_vec(); -// base_instances.extend(state_transition_snark.instances[0][4..24].iter()); // init_vote +// base_instances.extend(init_vote); // init_vote // base_instances.extend([ -// state_transition_snark.instances[0][68], // nullifier_old_root -// state_transition_snark.instances[0][68], // nullifier_new_root -// voter_snark.instances[0][28], // membership_root -// voter_snark.instances[0][29], // proposal_id -// Fr::from(0), // round +// state_transition_inputs[0].nullifier_tree.get_old_root(), // nullifier_old_root +// state_transition_inputs[0].nullifier_tree.get_old_root(), // nullifier_new_root +// voter_snark.instances[0][28], // membership_root +// voter_snark.instances[0][29], // proposal_id +// Fr::from(0), // round // ]); // let recursion_circuit = RecursionCircuit::new( // CircuitBuilderStage::Keygen, // &recursion_params, // gen_dummy_snark::>(&voter_params, Some(voter_pk.get_vk()), voter_config), -// gen_dummy_snark::>( -// &state_transition_params, -// Some(state_transition_pk.get_vk()), -// state_transition_config, -// ), // RecursionCircuit::initial_snark( // &recursion_params, // None, @@ -103,6 +82,7 @@ // ), // 0, // recursion_config, +// state_transition_inputs[0].clone(), // ); // let pk = gen_pk(&recursion_params, &recursion_circuit); // let config_params = recursion_circuit.inner().params(); @@ -111,22 +91,16 @@ // group.sample_size(10); // group.bench_with_input( // BenchmarkId::new("wrapper circuit", K), -// &( -// &recursion_params, -// &pk, -// &voter_snark, -// &state_transition_snark, -// ), -// |bencher, &(params, pk, voter_snark, state_transition_snark)| { +// &(&recursion_params, &pk, &voter_snark), +// |bencher, &(params, pk, voter_snark)| { // let cloned_voter_snark = voter_snark; -// let cloned_state_transition_snark = state_transition_snark; // bencher.iter(|| { // let cloned_config_params = config_params.clone(); // let circuit = RecursionCircuit::new( -// CircuitBuilderStage::Prover, +// CircuitBuilderStage::Prover +// , // ¶ms, // cloned_voter_snark.clone(), -// cloned_state_transition_snark.clone(), // RecursionCircuit::initial_snark( // ¶ms, // None, @@ -135,9 +109,12 @@ // ), // 0, // cloned_config_params, +// state_transition_inputs[0].clone(), // ); - -// gen_proof(params, pk, circuit); +// println!("reached proof generation"); +// let instances = circuit.inner().instances().clone(); +// gen_proof(params, pk, circuit, instances); +// println!("completed proof generation") // }) // }, // ); diff --git a/aggregator/src/state_transition.rs b/aggregator/src/state_transition.rs index 45006ac..765cb3f 100644 --- a/aggregator/src/state_transition.rs +++ b/aggregator/src/state_transition.rs @@ -1,33 +1,25 @@ -use halo2_base::gates::circuit::builder::BaseCircuitBuilder; -use halo2_base::gates::circuit::{ BaseCircuitParams, BaseConfig }; -use halo2_base::gates::GateInstructions; -use halo2_base::halo2_proofs::circuit::{ Layouter, SimpleFloorPlanner }; -use halo2_base::halo2_proofs::plonk::{ Circuit, ConstraintSystem, Error }; -use halo2_base::poseidon::hasher::{ spec::OptimizedPoseidonSpec, PoseidonHasher }; +use biguint_halo2::big_uint::chip::BigUintChip; +use biguint_halo2::big_uint::{AssignedBigUint, Fresh}; +use halo2_base::halo2_proofs::halo2curves::secp256k1::Secp256k1Affine; +use halo2_base::poseidon::hasher::{spec::OptimizedPoseidonSpec, PoseidonHasher}; +use halo2_base::utils::fe_to_biguint; use halo2_base::{ - gates::{ RangeChip, RangeInstructions }, + gates::{RangeChip, RangeInstructions}, halo2_proofs::circuit::Value, utils::BigPrimeField, - AssignedValue, - Context, + AssignedValue, Context, }; -use halo2_ecc::ecc::EccChip; -use halo2_ecc::fields::fp::FpChip; -use indexed_merkle_tree_halo2::indexed_merkle_tree::{ insert_leaf, IndexedMerkleTreeLeaf }; +use halo2_ecc::bigint::OverflowInteger; +use indexed_merkle_tree_halo2::indexed_merkle_tree::{insert_leaf, IndexedMerkleTreeLeaf}; use indexed_merkle_tree_halo2::utils::IndexedMerkleTreeLeaf as IMTLeaf; use num_bigint::BigUint; - -use biguint_halo2::big_uint::chip::BigUintChip; -use halo2_base::halo2_proofs::halo2curves::secp256k1::{ Fp, Secp256k1Affine }; -use paillier_chip::paillier::{ EncryptionPublicKeyAssigned, PaillierChip }; -use serde::{ Deserialize, Serialize }; -use voter::{ compress_nullifier, CircuitExt, EncryptionPublicKey }; +use paillier_chip::paillier::{EncryptionPublicKeyAssigned, PaillierChip}; +use serde::{Deserialize, Serialize}; +use voter::EncryptionPublicKey; const ENC_BIT_LEN: usize = 176; const LIMB_BIT_LEN: usize = 88; -//TODO: Constrain the nullifier hash using x and y limbs - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct IndexedMerkleTreeInput { old_root: F, @@ -53,7 +45,7 @@ impl IndexedMerkleTreeInput { new_leaf_index: F, new_leaf_proof: Vec, new_leaf_proof_helper: Vec, - is_new_leaf_largest: F + is_new_leaf_largest: F, ) -> Self { Self { old_root, @@ -75,7 +67,12 @@ impl IndexedMerkleTreeInput { self.new_root } } - +fn limbs_to_biguint(x: Vec) -> BigUint { + x.iter() + .enumerate() + .map(|(i, limb)| fe_to_biguint(limb) * BigUint::from(2u64).pow(88 * (i as u32))) + .sum() +} #[derive(Debug, Clone)] pub struct StateTransitionInput { pub pk_enc: EncryptionPublicKey, @@ -90,7 +87,7 @@ impl StateTransitionInput { incoming_vote: Vec, prev_vote: Vec, nullifier_tree: IndexedMerkleTreeInput, - nullifier: Secp256k1Affine + nullifier: Secp256k1Affine, ) -> Self { Self { pk_enc, @@ -105,8 +102,9 @@ impl StateTransitionInput { pub fn state_transition_circuit( ctx: &mut Context, range: &RangeChip, - input: StateTransitionInput, - public_inputs: &mut Vec> + nullifier_tree: IndexedMerkleTreeInput, + input_vec: Vec>, + public_inputs: &mut Vec>, ) { let gate = range.gate(); let mut hasher = PoseidonHasher::::new(OptimizedPoseidonSpec::new::<8, 57, 0>()); @@ -115,38 +113,72 @@ pub fn state_transition_circuit( let biguint_chip = BigUintChip::construct(range, LIMB_BIT_LEN); let paillier_chip = PaillierChip::construct(&biguint_chip, ENC_BIT_LEN); - let fp_chip = FpChip::::new(range, LIMB_BIT_LEN, 3); - let ecc_chip = EccChip::>::new(&fp_chip); - - let nullifier = ecc_chip.load_private_unchecked(ctx, (input.nullifier.x, input.nullifier.y)); - let compressed_nullifier = compress_nullifier(ctx, range, &nullifier); - let nullifier_hash = hasher.hash_fix_len_array(ctx, gate, &compressed_nullifier); - - let n_assigned = biguint_chip - .assign_integer(ctx, Value::known(input.pk_enc.n.clone()), ENC_BIT_LEN) - .unwrap(); + let nullifier_hash = hasher.hash_fix_len_array(ctx, gate, &input_vec[0..4]); - let g_assigned = biguint_chip - .assign_integer(ctx, Value::known(input.pk_enc.g.clone()), ENC_BIT_LEN) - .unwrap(); + let n_fr = input_vec[4..6] + .iter() + .map(|vote| *vote.value()) + .collect::>(); + let n_assigned = AssignedBigUint::::new( + OverflowInteger::new(input_vec[4..6].to_vec(), 88), + Value::known(limbs_to_biguint(n_fr)), + ); + let g_fr = input_vec[6..8] + .iter() + .map(|vote| *vote.value()) + .collect::>(); + let g_assigned = AssignedBigUint::::new( + OverflowInteger::new(input_vec[6..8].to_vec(), 88), + Value::known(limbs_to_biguint(g_fr)), + ); let pk_enc = EncryptionPublicKeyAssigned { n: n_assigned, g: g_assigned, }; - - let incoming_vote = input.incoming_vote + let incoming_vote_fr: Vec = input_vec[8..28].iter().map(|x| *x.value()).collect(); + let incoming_vote_biguint = incoming_vote_fr + .chunks(4) + .into_iter() + .map(|chunk| limbs_to_biguint(chunk.to_vec())) + .collect::>(); + let incoming_vote_overflow_int = input_vec[8..28] + .chunks(4) + .into_iter() + .map(|chunk| OverflowInteger::new(chunk.to_vec(), 88)) + .collect::>>(); + let incoming_vote = incoming_vote_overflow_int .iter() - .map(|x| { - biguint_chip.assign_integer(ctx, Value::known(x.clone()), ENC_BIT_LEN * 2).unwrap() + .enumerate() + .map(|(i, over_flow)| { + AssignedBigUint::new( + over_flow.clone(), + Value::known(incoming_vote_biguint[i].clone()), + ) }) - .collect::>(); - let prev_vote = input.prev_vote + .collect::>>(); + + let prev_vote_fr: Vec = input_vec[28..48].iter().map(|x| *x.value()).collect(); + let prev_vote_biguint = prev_vote_fr + .chunks(4) + .into_iter() + .map(|chunk| limbs_to_biguint(chunk.to_vec())) + .collect::>(); + let prev_vote_overflow_int = input_vec[28..48] + .chunks(4) + .into_iter() + .map(|chunk| OverflowInteger::new(chunk.to_vec(), 88)) + .collect::>>(); + let prev_vote = prev_vote_overflow_int .iter() - .map(|x| { - biguint_chip.assign_integer(ctx, Value::known(x.clone()), ENC_BIT_LEN * 2).unwrap() + .enumerate() + .map(|(i, over_flow)| { + AssignedBigUint::new( + over_flow.clone(), + Value::known(prev_vote_biguint[i].clone()), + ) }) - .collect::>(); + .collect::>>(); // Step 1: Aggregate the votes let aggr_vote = incoming_vote @@ -156,194 +188,70 @@ pub fn state_transition_circuit( .collect::>(); // Step 2: Update the nullifier tree - let val = ctx.load_witness(input.nullifier_tree.low_leaf.val); - let next_val = ctx.load_witness(input.nullifier_tree.low_leaf.next_val); - let next_idx = ctx.load_witness(input.nullifier_tree.low_leaf.next_idx); + let val = ctx.load_witness(nullifier_tree.low_leaf.val); + let next_val = ctx.load_witness(nullifier_tree.low_leaf.next_val); + let next_idx = ctx.load_witness(nullifier_tree.low_leaf.next_idx); - let old_root = ctx.load_witness(input.nullifier_tree.old_root); + let old_root = ctx.load_witness(nullifier_tree.old_root); let low_leaf = IndexedMerkleTreeLeaf::new(val, next_val, next_idx); - let new_root = ctx.load_witness(input.nullifier_tree.new_root); + let new_root = ctx.load_witness(nullifier_tree.new_root); - let val = ctx.load_witness(input.nullifier_tree.new_leaf.val); - assert_eq!(val.value(), nullifier_hash.value()); + let val = ctx.load_witness(nullifier_tree.new_leaf.val); + //assert_eq!(val.value(), nullifier_hash.value()); ctx.constrain_equal(&val, &nullifier_hash); - let next_val = ctx.load_witness(input.nullifier_tree.new_leaf.next_val); - let next_idx = ctx.load_witness(input.nullifier_tree.new_leaf.next_idx); + let next_val = ctx.load_witness(nullifier_tree.new_leaf.next_val); + let next_idx = ctx.load_witness(nullifier_tree.new_leaf.next_idx); let new_leaf = IndexedMerkleTreeLeaf::new(val, next_val, next_idx); - let new_leaf_index = ctx.load_witness(input.nullifier_tree.new_leaf_index); - let is_new_leaf_largest = ctx.load_witness(input.nullifier_tree.is_new_leaf_largest); + let new_leaf_index = ctx.load_witness(nullifier_tree.new_leaf_index); + let is_new_leaf_largest = ctx.load_witness(nullifier_tree.is_new_leaf_largest); - let low_leaf_proof = input.nullifier_tree.low_leaf_proof + let low_leaf_proof = nullifier_tree + .low_leaf_proof .iter() .map(|x| ctx.load_witness(*x)) .collect::>(); - let low_leaf_proof_helper = input.nullifier_tree.low_leaf_proof_helper + let low_leaf_proof_helper = nullifier_tree + .low_leaf_proof_helper .iter() .map(|x| ctx.load_witness(*x)) .collect::>(); - let new_leaf_proof = input.nullifier_tree.new_leaf_proof + let new_leaf_proof = nullifier_tree + .new_leaf_proof .iter() .map(|x| ctx.load_witness(*x)) .collect::>(); - let new_leaf_proof_helper = input.nullifier_tree.new_leaf_proof_helper + let new_leaf_proof_helper = nullifier_tree + .new_leaf_proof_helper .iter() .map(|x| ctx.load_witness(*x)) .collect::>(); - insert_leaf::( - ctx, - range, - &hasher, - &old_root, - &low_leaf, - &low_leaf_proof, - &low_leaf_proof_helper, - &new_root, - &new_leaf, - &new_leaf_index, - &new_leaf_proof, - &new_leaf_proof_helper, - &is_new_leaf_largest - ); + // insert_leaf::( + // ctx, + // range, + // &hasher, + // &old_root, + // &low_leaf, + // &low_leaf_proof, + // &low_leaf_proof_helper, + // &new_root, + // &new_leaf, + // &new_leaf_index, + // &new_leaf_proof, + // &new_leaf_proof_helper, + // &is_new_leaf_largest, + // ); - // PK_ENC N - public_inputs.extend(pk_enc.n.limbs()); - - // PK_ENC G - public_inputs.extend(pk_enc.g.limbs()); - - // PREV_VOTE - for enc_vote in prev_vote { - public_inputs.extend(enc_vote.limbs()); - } - - // INCOMING_VOTE - for enc_vote in incoming_vote { - public_inputs.extend(enc_vote.limbs()); - } - - // AGGR_VOTE for enc_vote in aggr_vote { public_inputs.extend(enc_vote.limbs()); } - // NULLIFIER - public_inputs.extend(compressed_nullifier); - // NULLIFIER_OLD_ROOT public_inputs.extend([old_root]); // NULLIFIER_NEW_ROOT public_inputs.extend([new_root]); } - -pub struct StateTransitionCircuit { - input: StateTransitionInput, - pub inner: BaseCircuitBuilder, -} - -impl StateTransitionCircuit { - pub fn new(config: BaseCircuitParams, input: StateTransitionInput) -> Self { - let mut inner = BaseCircuitBuilder::default(); - inner.set_params(config); - - let range = inner.range_chip(); - let ctx = inner.main(0); - - let mut public_inputs = Vec::>::new(); - state_transition_circuit(ctx, &range, input.clone(), &mut public_inputs); - inner.assigned_instances[0].extend(public_inputs); - inner.calculate_params(Some(10)); - Self { input, inner } - } -} - -impl Circuit for StateTransitionCircuit { - type Config = BaseConfig; - type FloorPlanner = SimpleFloorPlanner; - type Params = BaseCircuitParams; - - fn params(&self) -> Self::Params { - self.inner.params() - } - - fn without_witnesses(&self) -> Self { - unimplemented!() - } - - fn configure_with_params(meta: &mut ConstraintSystem, params: Self::Params) -> Self::Config { - BaseCircuitBuilder::configure_with_params(meta, params) - } - - fn configure(_: &mut ConstraintSystem) -> Self::Config { - unreachable!() - } - - fn synthesize(&self, config: Self::Config, layouter: impl Layouter) -> Result<(), Error> { - self.inner.synthesize(config, layouter) - } -} - -impl CircuitExt for StateTransitionCircuit { - fn num_instance() -> Vec { - vec![70] - } - - fn instances(&self) -> Vec> { - vec![ - self.inner.assigned_instances[0] - .iter() - .map(|instance| *instance.value()) - .collect() - ] - } -} - -#[cfg(test)] -mod test { - use halo2_base::{ - gates::circuit::BaseCircuitParams, - halo2_proofs::{ dev::MockProver, halo2curves::bn256::Fr }, - utils::testing::base_test, - AssignedValue, - }; - use voter::CircuitExt; - - use crate::utils::generate_wrapper_circuit_input; - - use super::{ state_transition_circuit, StateTransitionCircuit }; - - #[test] - fn test_state_transition_circuit() { - let (_, multiple_input) = generate_wrapper_circuit_input(4); - - let config = BaseCircuitParams { - k: 15, - num_advice_per_phase: vec![3], - num_lookup_advice_per_phase: vec![1, 0, 0], - num_fixed: 1, - lookup_bits: Some(14), - num_instance_columns: 1, - }; - - for (round, input) in multiple_input.iter().enumerate() { - println!("------round[{}]--------", round); - - let circuit = StateTransitionCircuit::new(config.clone(), input.clone()); - let prover = MockProver::run(15, &circuit, circuit.instances()).unwrap(); - prover.verify().unwrap(); - } - - base_test() - .k(19) - .lookup_bits(18) - .expect_satisfied(true) - .run(|ctx, range| { - let mut public_inputs = Vec::>::new(); - state_transition_circuit(ctx, range, multiple_input[0].clone(), &mut public_inputs) - }); - } -} diff --git a/aggregator/src/wrapper.rs b/aggregator/src/wrapper.rs index 6fb6692..61dae26 100644 --- a/aggregator/src/wrapper.rs +++ b/aggregator/src/wrapper.rs @@ -277,25 +277,20 @@ pub mod common { } pub mod recursion { - use std::mem; + use std::{mem, vec}; use halo2_base::{ gates::{ - circuit::{ - builder::BaseCircuitBuilder, BaseCircuitParams, BaseConfig, CircuitBuilderStage, - }, - range, GateInstructions, RangeInstructions, + circuit::{builder::BaseCircuitBuilder, BaseCircuitParams, BaseConfig}, + GateInstructions, RangeInstructions, }, - poseidon::hasher::state, AssignedValue, }; use halo2_ecc::{bn254::FpChip, ecc::EcPoint}; use snark_verifier_sdk::snark_verifier::loader::halo2::{EccInstructions, IntegerInstructions}; use voter::{CircuitExt, VoterCircuit}; - use crate::state_transition::{ - state_transition_circuit, StateTransitionCircuit, StateTransitionInput, - }; + use crate::state_transition::{state_transition_circuit, StateTransitionInput}; use super::*; @@ -433,13 +428,12 @@ pub mod recursion { const ROUND_ROW: usize = 4 * LIMBS + 29; pub fn new( - stage: CircuitBuilderStage, params: &ParamsKZG, voter: Snark, - state_transition_input: StateTransitionInput, previous: Snark, round: usize, config_params: BaseCircuitParams, + state_transition_input: StateTransitionInput, ) -> Self { let svk = params.get_g()[0].into(); let default_accumulator = KzgAccumulator::new(params.get_g()[1], params.get_g()[0]); @@ -491,7 +485,6 @@ pub mod recursion { .collect_vec(); poseidon(&NativeLoader, &inputs) }; - let mut current_instances = [ voter.instances[0][0], voter.instances[0][1], @@ -499,14 +492,13 @@ pub mod recursion { voter.instances[0][3], ] .to_vec(); - current_instances.extend(voter.instances[0][4..24].iter().clone()); + current_instances.extend(voter.instances[0][4..24].iter()); current_instances.extend([ state_transition_input.nullifier_tree.get_old_root(), state_transition_input.nullifier_tree.get_new_root(), voter.instances[0][28], voter.instances[0][29], ]); - let instances = [ accumulator.lhs.x, accumulator.lhs.y, @@ -520,15 +512,8 @@ pub mod recursion { .chain([Fr::from(round as u64)]) .collect(); - let mut inner = BaseCircuitBuilder::from_stage(stage).use_params(config_params); - let range = inner.range_chip(); - let mut public_inputs = Vec::>::new(); - state_transition_circuit( - inner.main(0), - &range, - state_transition_input, - &mut public_inputs, - ); + let inner = BaseCircuitBuilder::new(false).use_params(config_params); + println!("reached here"); let mut circuit = Self { svk, default_accumulator, @@ -539,13 +524,15 @@ pub mod recursion { as_proof, inner, }; - circuit.build(public_inputs); + + circuit.build(state_transition_input); circuit } - fn build(&mut self, public_inputs: Vec>) { + fn build(&mut self, state_transition_input: StateTransitionInput) { let range = self.inner.range_chip(); let main_gate = range.gate(); + let pool = self.inner.pool(0); let preprocessed_digest = @@ -558,10 +545,7 @@ pub mod recursion { .iter() .map(|instance| main_gate.assign_integer(pool, *instance)) .collect::>(); - let aggr_vote = self.instances[Self::VOTE_ROW..Self::VOTE_ROW + 20] - .iter() - .map(|instance| main_gate.assign_integer(pool, *instance)) - .collect::>(); + let nullifier_old_root = main_gate.assign_integer(pool, self.instances[Self::NULLIFIER_OLD_ROOT_ROW]); let nullifier_new_root = @@ -580,13 +564,14 @@ pub mod recursion { let (mut voter_instances, voter_accumulators) = succinct_verify(&self.svk, &loader, &self.voter, None); - + println!("voter proof verified"); let (mut previous_instances, previous_accumulators) = succinct_verify( &self.svk, &loader, &self.previous, Some(preprocessed_digest), ); + println!("previous proof verified"); let default_accmulator = self.load_default_accumulator(&loader).unwrap(); let previous_accumulators = previous_accumulators @@ -604,24 +589,48 @@ pub mod recursion { let KzgAccumulator { lhs, rhs } = accumulate( &loader, - [ - voter_accumulators, - // state_transition_accumulators, - previous_accumulators, - ] - .concat(), + [voter_accumulators, previous_accumulators].concat(), self.as_proof(), ); let lhs = lhs.into_assigned(); let rhs = rhs.into_assigned(); let voter_instances = voter_instances.pop().unwrap(); - // let state_transition_instances = state_transition_instances.pop().unwrap(); let previous_instances = previous_instances.pop().unwrap(); let mut pool = loader.take_ctx(); let ctx = pool.main(); - println!("public inputs length: {}", public_inputs.len()); + + let mut state_transition_vec = Vec::>::new(); + + //nullifier + state_transition_vec.extend(voter_instances[24..28].iter()); + //n,g + state_transition_vec.extend(voter_instances[0..4].iter()); + //incoming_vote + state_transition_vec.extend(voter_instances[4..24].iter()); + //previous_vote + state_transition_vec + .extend(previous_instances[4 * LIMBS + 5..4 * LIMBS + 5 + 20].iter()); + + let mut public_inputs = Vec::>::new(); + state_transition_circuit( + ctx, + &range, + state_transition_input.nullifier_tree, + state_transition_vec, + &mut public_inputs, + ); + let aggr_vote_fr = public_inputs[0..20] + .iter() + .map(|x| *x.value()) + .collect::>(); + + self.instances[Self::VOTE_ROW..Self::VOTE_ROW + 20].copy_from_slice(&aggr_vote_fr); + let aggr_vote = public_inputs[0..20].to_vec(); + + //updated state transition public inputs + for (lhs, rhs) in [ // Propagate preprocessed_digest ( @@ -637,77 +646,38 @@ pub mod recursion { ctx.constrain_equal(lhs, rhs); } - // state_transition(pk_enc) == previous(pk_enc) == voter(pk_enc) - for i in 0..4 { - // assert_eq!( - // public_inputs[i].value(), - // voter_instances[i].value() - // ); - // ctx.constrain_equal(&state_transition_instances[i], &voter_instances[i]); - - assert_eq!( - public_inputs[i].value(), - previous_instances[4 * LIMBS + i + 1].value() - ); - // ctx.constrain_equal(&public_inputs[i], &previous_instances[4 * LIMBS + i + 1]); - } - - // state_transition(prev_vote) == previous(aggr_vote) - for i in 0..20 { - assert_eq!( - public_inputs[i + 4].value(), - previous_instances[4 * LIMBS + i + 1 + 4].value() - ); - // ctx.constrain_equal( - // &public_inputs[i + 4], - // &previous_instances[4 * LIMBS + i + 1 + 4], - // ); - } - - // state_transition(incoming_vote) == voter(vote) - for i in 0..20 { - assert_eq!( - public_inputs[i + 24].value(), - voter_instances[i + 4].value() - ); - // ctx.constrain_equal(&public_inputs[i + 24], &voter_instances[i + 4]); - } - - // state_transition(nullifier) == voter(nullifier) + // previous(pk_enc) == voter(pk_enc) for i in 0..4 { - assert_eq!( - public_inputs[i + 64].value(), - voter_instances[i + 24].value() - ); - // ctx.constrain_equal(&public_inputs[i + 64], &voter_instances[i + 24]); + // assert_eq!(previous_instances[i].value(), voter_instances[i].value()); + ctx.constrain_equal(&previous_instances[4 * LIMBS + i + 1], &voter_instances[i]); } // state_transition(nullifier_old_root) == previous(nullifier_new_root) - assert_eq!( - public_inputs[68].value(), - previous_instances[4 * LIMBS + 1 + 25].value() - ); - // ctx.constrain_equal(&public_inputs[68], &previous_instances[4 * LIMBS + 1 + 25]); + // assert_eq!( + // public_inputs[68].value(), + // previous_instances[4 * LIMBS + 1 + 25].value() + // ); + ctx.constrain_equal(&public_inputs[20], &previous_instances[4 * LIMBS + 1 + 25]); // previous(membership_root]) == voter(membership_root) - assert_eq!( - previous_instances[4 * LIMBS + 1 + 26].value(), - voter_instances[28].value() - ); - // ctx.constrain_equal( - // &previous_instances[4 * LIMBS + 1 + 26], - // &voter_instances[28], + // assert_eq!( + // previous_instances[4 * LIMBS + 1 + 26].value(), + // voter_instances[28].value() // ); + ctx.constrain_equal( + &previous_instances[4 * LIMBS + 1 + 26], + &voter_instances[28], + ); // voter(proposal_id) == previous(proposal_id) - assert_eq!( - voter_instances[29].value(), - previous_instances[4 * LIMBS + 1 + 27].value() - ); - // ctx.constrain_equal( - // &voter_instances[29], - // &previous_instances[4 * LIMBS + 1 + 27], + // assert_eq!( + // voter_instances[29].value(), + // previous_instances[4 * LIMBS + 1 + 27].value() // ); + ctx.constrain_equal( + &voter_instances[29], + &previous_instances[4 * LIMBS + 1 + 27], + ); *self.inner.pool(0) = pool; @@ -731,9 +701,6 @@ pub mod recursion { ) .copied(), ); - - // self.inner.calculate_params(Some(10)); - // println!("recursion params: {:?}", self.inner.params()); } pub fn initial_snark( @@ -839,12 +806,9 @@ pub mod recursion { init_aggr_instances: Vec, state_transition_input: StateTransitionInput, ) -> ProvingKey { - println!("pk recursion_config: {:?}", recursion_config); let recursion = RecursionCircuit::new( - CircuitBuilderStage::Keygen, recursion_params, gen_dummy_snark::>(voter_params, Some(voter_vk), voter_config), - state_transition_input, RecursionCircuit::initial_snark( recursion_params, None, @@ -852,24 +816,24 @@ pub mod recursion { init_aggr_instances, ), 0, - recursion_config, + recursion_config.clone(), + state_transition_input, ); // we cannot auto-configure the circuit because dummy_snark must know the configuration beforehand // uncomment the following line only in development to test and print out the optimal configuration ahead of time // recursion.inner.0.builder.borrow().config(recursion_params.k() as usize, Some(10)); + gen_pk(recursion_params, &recursion) } pub fn gen_recursion_snark( - stage: CircuitBuilderStage, recursion_params: &ParamsKZG, recursion_pk: &ProvingKey, recursion_config: BaseCircuitParams, voter_snarks: Vec, - state_transition_inputs: Vec>, init_aggr_instances: Vec, + state_transition_inputs: Vec>, ) -> Snark { - println!("snark recursion_config: {:?}", recursion_config); let mut previous = RecursionCircuit::initial_snark( recursion_params, Some(recursion_pk.get_vk()), @@ -878,13 +842,12 @@ pub mod recursion { ); for (round, voter) in voter_snarks.into_iter().enumerate() { let recursion = RecursionCircuit::new( - stage, recursion_params, voter, - state_transition_inputs[round].clone(), previous, round, recursion_config.clone(), + state_transition_inputs[round].clone(), ); println!("Generate recursion snark for round {}", round); previous = gen_snark(recursion_params, recursion_pk, recursion); @@ -896,11 +859,13 @@ pub mod recursion { #[cfg(test)] mod test { use std::path::{Path, PathBuf}; + use std::time::Instant; use std::{fs, io::BufReader}; use ark_std::{end_timer, start_timer}; + use halo2_base::utils::decompose_biguint; use halo2_base::{ - gates::circuit::{builder::BaseCircuitBuilder, BaseCircuitParams, CircuitBuilderStage}, + gates::circuit::{builder::BaseCircuitBuilder, BaseCircuitParams}, halo2_proofs::{ halo2curves::bn256::{Fr, G1Affine}, plonk::ProvingKey, @@ -911,7 +876,7 @@ mod test { use snark_verifier_sdk::{snark_verifier::verifier::SnarkVerifier, NativeLoader}; use voter::VoterCircuit; - use crate::{state_transition::StateTransitionCircuit, utils::generate_wrapper_circuit_input}; + use crate::utils::generate_wrapper_circuit_input; use super::{ gen_pk, gen_snark, @@ -934,9 +899,8 @@ mod test { #[test] fn test_recursion() { const GEN_VOTER_PK: bool = true; - const GEN_STATE_TRANSITION_PK: bool = true; const GEN_RECURSION_PK: bool = true; - + println!("STOP UNDO"); let num_round = 3; let (voter_input, state_transition_input) = generate_wrapper_circuit_input(num_round); @@ -992,6 +956,11 @@ mod test { }; let recursion_params = gen_srs(k as u32); + let mut init_vote = Vec::::new(); + for x in state_transition_input[0].prev_vote.iter() { + init_vote.extend(decompose_biguint::(x, 4, 88)); + } + // Init Base Instances let mut base_instances = [ Fr::zero(), // preprocessed_digest @@ -1001,7 +970,7 @@ mod test { voter_snark.instances[0][3], ] .to_vec(); - base_instances.extend(voter_snark.instances[0][4..24].iter()); // init_vote + base_instances.extend(init_vote); // init_vote base_instances.extend([ state_transition_input[0].nullifier_tree.get_old_root(), // nullifier_old_root state_transition_input[0].nullifier_tree.get_old_root(), // nullifier_new_root @@ -1011,6 +980,7 @@ mod test { ]); let pk_time = start_timer!(|| "Generate recursion pk"); + let recursion_pk: ProvingKey; if GEN_RECURSION_PK { println!("Generating recursion pk"); @@ -1055,14 +1025,14 @@ mod test { println!("Starting recursion..."); let pf_time = start_timer!(|| "Generate full recursive snark"); + let start = Instant::now(); let final_snark = recursion::gen_recursion_snark( - CircuitBuilderStage::Mock, &recursion_params, &recursion_pk, recursion_config, voter_snarks, - state_transition_input.clone(), base_instances, + state_transition_input.clone(), ); end_timer!(pf_time); @@ -1085,5 +1055,8 @@ mod test { PlonkVerifier::verify(&dk, &final_snark.protocol, &final_snark.instances, &proof) .unwrap(); } + println!("Time taken for recursion: {:?}", start.elapsed()); } + + //function for Fr to Vec } diff --git a/voter/src/lib.rs b/voter/src/lib.rs index 88de07c..346b6a1 100644 --- a/voter/src/lib.rs +++ b/voter/src/lib.rs @@ -247,9 +247,11 @@ pub fn voter_circuit( // ctx.constrain_equal(&vote_sum_assigned, &one); //PK_ENC_n + println!("pk_enc n : {:?}", pk_enc.n.limbs().to_vec()); public_inputs.extend(pk_enc.n.limbs().to_vec()); //PK_ENC_g + println!("pk_enc n : {:?}", pk_enc.n.limbs().to_vec()); public_inputs.extend(pk_enc.g.limbs().to_vec()); // 2. Verify correct vote encryption From fd6ae5213d65cacbfd4fec7d67506053dad7d451 Mon Sep 17 00:00:00 2001 From: 0xvikasrushi <0xvikas@gmail.com> Date: Tue, 10 Sep 2024 20:33:21 +0530 Subject: [PATCH 4/7] fix: updated benches --- .../benches/state_transition_circuit.rs | 225 +++++++++-------- aggregator/benches/wrapper_circuit.rs | 233 +++++++++--------- aggregator/src/utils.rs | 14 +- 3 files changed, 255 insertions(+), 217 deletions(-) diff --git a/aggregator/benches/state_transition_circuit.rs b/aggregator/benches/state_transition_circuit.rs index 64db975..8b0b9de 100644 --- a/aggregator/benches/state_transition_circuit.rs +++ b/aggregator/benches/state_transition_circuit.rs @@ -1,97 +1,128 @@ -// use aggregator::state_transition::{state_transition_circuit, StateTransitionInput}; -// use aggregator::utils::generate_random_state_transition_circuit_inputs; -// use ark_std::{end_timer, start_timer}; -// use halo2_base::gates::circuit::BaseCircuitParams; -// use halo2_base::gates::circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}; -// use halo2_base::gates::flex_gate::MultiPhaseThreadBreakPoints; -// use halo2_base::AssignedValue; -// use halo2_base::{ -// halo2_proofs::{ -// halo2curves::bn256::{Bn256, Fr}, -// plonk::*, -// poly::kzg::commitment::ParamsKZG, -// }, -// utils::testing::gen_proof, -// }; -// use pprof::criterion::{Output, PProfProfiler}; -// use rand::rngs::OsRng; - -// use criterion::{criterion_group, criterion_main}; -// use criterion::{BenchmarkId, Criterion}; - -// const K: u32 = 15; - -// fn state_transition_circuit_bench( -// stage: CircuitBuilderStage, -// input: StateTransitionInput, -// config_params: Option, -// break_points: Option, -// ) -> RangeCircuitBuilder { -// let k = K as usize; -// let lookup_bits = k - 1; -// let mut builder = match stage { -// CircuitBuilderStage::Prover => { -// RangeCircuitBuilder::prover(config_params.unwrap(), break_points.unwrap()) -// } -// _ => RangeCircuitBuilder::from_stage(stage) -// .use_k(k) -// .use_lookup_bits(lookup_bits), -// }; - -// let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); -// let range = builder.range_chip(); - -// let mut public_inputs = Vec::>::new(); -// state_transition_circuit(builder.main(0), &range, input, &mut public_inputs); - -// end_timer!(start0); -// if !stage.witness_gen_only() { -// builder.calculate_params(Some(20)); -// } -// builder -// } - -// fn bench(c: &mut Criterion) { -// let state_transition_input = generate_random_state_transition_circuit_inputs(); -// let circuit = state_transition_circuit_bench( -// CircuitBuilderStage::Keygen, -// state_transition_input.clone(), -// None, -// None, -// ); -// let config_params = circuit.params(); - -// let params = ParamsKZG::::setup(K, OsRng); -// let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); -// let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); -// let break_points = circuit.break_points(); - -// let mut group = c.benchmark_group("plonk-prover"); -// group.sample_size(10); -// group.bench_with_input( -// BenchmarkId::new("state transition circuit", K), -// &(¶ms, &pk, &state_transition_input), -// |bencher, &(params, pk, state_transition_input)| { -// let input = state_transition_input.clone(); -// bencher.iter(|| { -// let circuit = state_transition_circuit_bench( -// CircuitBuilderStage::Prover, -// input.clone(), -// Some(config_params.clone()), -// Some(break_points.clone()), -// ); - -// gen_proof(params, pk, circuit); -// }) -// }, -// ); -// group.finish() -// } - -// criterion_group! { -// name = benches; -// config = Criterion::default().with_profiler(PProfProfiler::new(10, Output::Flamegraph(None))); -// targets = bench -// } -// criterion_main!(benches); -fn main() {} +use aggregator::state_transition::{state_transition_circuit, StateTransitionInput}; +use aggregator::utils::{ + assign_big_uint, compress_native_nullifier, generate_random_state_transition_circuit_inputs, +}; +use ark_std::{end_timer, start_timer}; +use halo2_base::gates::circuit::BaseCircuitParams; +use halo2_base::gates::circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}; +use halo2_base::gates::flex_gate::MultiPhaseThreadBreakPoints; +use halo2_base::AssignedValue; +use halo2_base::{ + halo2_proofs::{ + halo2curves::bn256::{Bn256, Fr}, + plonk::*, + poly::kzg::commitment::ParamsKZG, + }, + utils::testing::gen_proof, +}; +use pprof::criterion::{Output, PProfProfiler}; +use rand::rngs::OsRng; + +use criterion::{criterion_group, criterion_main}; +use criterion::{BenchmarkId, Criterion}; + +const K: u32 = 15; + +fn state_transition_circuit_bench( + stage: CircuitBuilderStage, + input: StateTransitionInput, + config_params: Option, + break_points: Option, +) -> RangeCircuitBuilder { + let k = K as usize; + let lookup_bits = k - 1; + let mut builder = match stage { + CircuitBuilderStage::Prover => { + RangeCircuitBuilder::prover(config_params.unwrap(), break_points.unwrap()) + } + _ => RangeCircuitBuilder::from_stage(stage) + .use_k(k) + .use_lookup_bits(lookup_bits), + }; + + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + let range = builder.range_chip(); + + let mut public_inputs = Vec::>::new(); + let mut state_transition_vec = Vec::>::new(); + let ctx = builder.main(0); + + state_transition_vec.extend( + compress_native_nullifier(&input.nullifier) + .iter() + .map(|&x| ctx.load_witness(x)), + ); + + let enc_g = input.pk_enc.g; + let enc_h = input.pk_enc.n; + + state_transition_vec.extend(assign_big_uint(ctx, &enc_g)); + state_transition_vec.extend(assign_big_uint(ctx, &enc_h)); + + state_transition_vec.extend( + input + .incoming_vote + .iter() + .flat_map(|x| assign_big_uint(ctx, x)), + ); + + state_transition_vec.extend(input.prev_vote.iter().flat_map(|x| assign_big_uint(ctx, x))); + + state_transition_circuit( + ctx, + &range, + input.nullifier_tree, + state_transition_vec, + &mut public_inputs, + ); + + end_timer!(start0); + if !stage.witness_gen_only() { + builder.calculate_params(Some(20)); + } + builder +} + +fn bench(c: &mut Criterion) { + let state_transition_input = generate_random_state_transition_circuit_inputs(); + let circuit = state_transition_circuit_bench( + CircuitBuilderStage::Keygen, + state_transition_input.clone(), + None, + None, + ); + let config_params = circuit.params(); + + let params = ParamsKZG::::setup(K, OsRng); + let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); + let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); + let break_points = circuit.break_points(); + + let mut group = c.benchmark_group("plonk-prover"); + group.sample_size(10); + group.bench_with_input( + BenchmarkId::new("state transition circuit", K), + &(¶ms, &pk, &state_transition_input), + |bencher, &(params, pk, state_transition_input)| { + let input = state_transition_input.clone(); + bencher.iter(|| { + let circuit = state_transition_circuit_bench( + CircuitBuilderStage::Prover, + input.clone(), + Some(config_params.clone()), + Some(break_points.clone()), + ); + + gen_proof(params, pk, circuit); + }) + }, + ); + group.finish() +} + +criterion_group! { + name = benches; + config = Criterion::default().with_profiler(PProfProfiler::new(10, Output::Flamegraph(None))); + targets = bench +} +criterion_main!(benches); diff --git a/aggregator/benches/wrapper_circuit.rs b/aggregator/benches/wrapper_circuit.rs index 9aa27f2..0f22d38 100644 --- a/aggregator/benches/wrapper_circuit.rs +++ b/aggregator/benches/wrapper_circuit.rs @@ -1,130 +1,125 @@ -// use aggregator::utils::generate_wrapper_circuit_input; -// use aggregator::wrapper::common::gen_dummy_snark; -// use aggregator::wrapper::common::gen_pk; -// use aggregator::wrapper::common::gen_proof; -// use aggregator::wrapper::common::gen_snark; -// use aggregator::wrapper::recursion::RecursionCircuit; -// use halo2_base::gates::circuit::BaseCircuitParams; -// use halo2_base::gates::circuit::CircuitBuilderStage; -// use halo2_base::halo2_proofs::dev::MockProver; -// use halo2_base::halo2_proofs::{halo2curves::bn256::Fr, plonk::*}; -// use halo2_base::utils::decompose_biguint; -// use halo2_base::utils::fs::gen_srs; +use aggregator::utils::generate_wrapper_circuit_input; +use aggregator::wrapper::common::gen_dummy_snark; +use aggregator::wrapper::common::gen_pk; +use aggregator::wrapper::common::gen_proof; +use aggregator::wrapper::common::gen_snark; +use aggregator::wrapper::recursion::RecursionCircuit; +use halo2_base::gates::circuit::BaseCircuitParams; +use halo2_base::halo2_proofs::{halo2curves::bn256::Fr, plonk::*}; +use halo2_base::utils::decompose_biguint; +use halo2_base::utils::fs::gen_srs; -// use criterion::{criterion_group, criterion_main}; -// use criterion::{BenchmarkId, Criterion}; +use criterion::{criterion_group, criterion_main}; +use criterion::{BenchmarkId, Criterion}; -// use pprof::criterion::{Output, PProfProfiler}; -// use snark_verifier_sdk::CircuitExt; -// use voter::VoterCircuit; +use pprof::criterion::{Output, PProfProfiler}; +use snark_verifier_sdk::CircuitExt; +use voter::VoterCircuit; -// const K: u32 = 22; +const K: u32 = 22; -// fn bench(c: &mut Criterion) { -// let (voter_inputs, state_transition_inputs) = generate_wrapper_circuit_input(1); +fn bench(c: &mut Criterion) { + let (voter_inputs, state_transition_inputs) = generate_wrapper_circuit_input(1); -// // Generating voter proof -// let voter_config = BaseCircuitParams { -// k: 15, -// num_advice_per_phase: vec![1], -// num_lookup_advice_per_phase: vec![1, 0, 0], -// num_fixed: 1, -// lookup_bits: Some(14), -// num_instance_columns: 1, -// }; -// let voter_params = gen_srs(15); -// let voter_circuit = VoterCircuit::new(voter_config.clone(), voter_inputs[0].clone()); -// let voter_pk = gen_pk(&voter_params, &voter_circuit); -// let voter_snark = gen_snark(&voter_params, &voter_pk, voter_circuit); + // Generating voter proof + let voter_config = BaseCircuitParams { + k: 15, + num_advice_per_phase: vec![1], + num_lookup_advice_per_phase: vec![1, 0, 0], + num_fixed: 1, + lookup_bits: Some(14), + num_instance_columns: 1, + }; + let voter_params = gen_srs(15); + let voter_circuit = VoterCircuit::new(voter_config.clone(), voter_inputs[0].clone()); + let voter_pk = gen_pk(&voter_params, &voter_circuit); + let voter_snark = gen_snark(&voter_params, &voter_pk, voter_circuit); -// let recursion_config = BaseCircuitParams { -// k: K as usize, -// num_advice_per_phase: vec![4], -// num_lookup_advice_per_phase: vec![1, 0, 0], -// num_fixed: 1, -// lookup_bits: Some((K - 1) as usize), -// num_instance_columns: 1, -// }; -// let recursion_params = gen_srs(K); + let recursion_config = BaseCircuitParams { + k: K as usize, + num_advice_per_phase: vec![4], + num_lookup_advice_per_phase: vec![1, 0, 0], + num_fixed: 1, + lookup_bits: Some((K - 1) as usize), + num_instance_columns: 1, + }; + let recursion_params = gen_srs(K); -// let mut init_vote = Vec::::new(); -// for x in state_transition_inputs[0].prev_vote.iter() { -// init_vote.extend(decompose_biguint::(x, 4, 88)); -// } + let mut init_vote = Vec::::new(); + for x in state_transition_inputs[0].prev_vote.iter() { + init_vote.extend(decompose_biguint::(x, 4, 88)); + } -// // Init Base Instances -// let mut base_instances = [ -// Fr::zero(), // preprocessed_digest -// voter_snark.instances[0][0], // pk_enc_n -// voter_snark.instances[0][1], -// voter_snark.instances[0][2], // pk_enc_g -// voter_snark.instances[0][3], -// ] -// .to_vec(); -// base_instances.extend(init_vote); // init_vote -// base_instances.extend([ -// state_transition_inputs[0].nullifier_tree.get_old_root(), // nullifier_old_root -// state_transition_inputs[0].nullifier_tree.get_old_root(), // nullifier_new_root -// voter_snark.instances[0][28], // membership_root -// voter_snark.instances[0][29], // proposal_id -// Fr::from(0), // round -// ]); + // Init Base Instances + let mut base_instances = [ + Fr::zero(), // preprocessed_digest + voter_snark.instances[0][0], // pk_enc_n + voter_snark.instances[0][1], + voter_snark.instances[0][2], // pk_enc_g + voter_snark.instances[0][3], + ] + .to_vec(); + base_instances.extend(init_vote); // init_vote + base_instances.extend([ + state_transition_inputs[0].nullifier_tree.get_old_root(), // nullifier_old_root + state_transition_inputs[0].nullifier_tree.get_old_root(), // nullifier_new_root + voter_snark.instances[0][28], // membership_root + voter_snark.instances[0][29], // proposal_id + Fr::from(0), // round + ]); -// let recursion_circuit = RecursionCircuit::new( -// CircuitBuilderStage::Keygen, -// &recursion_params, -// gen_dummy_snark::>(&voter_params, Some(voter_pk.get_vk()), voter_config), -// RecursionCircuit::initial_snark( -// &recursion_params, -// None, -// recursion_config.clone(), -// base_instances.clone(), -// ), -// 0, -// recursion_config, -// state_transition_inputs[0].clone(), -// ); -// let pk = gen_pk(&recursion_params, &recursion_circuit); -// let config_params = recursion_circuit.inner().params(); + let recursion_circuit = RecursionCircuit::new( + &recursion_params, + gen_dummy_snark::>(&voter_params, Some(voter_pk.get_vk()), voter_config), + RecursionCircuit::initial_snark( + &recursion_params, + None, + recursion_config.clone(), + base_instances.clone(), + ), + 0, + recursion_config, + state_transition_inputs[0].clone(), + ); + let pk = gen_pk(&recursion_params, &recursion_circuit); + let config_params = recursion_circuit.inner().params(); -// let mut group = c.benchmark_group("plonk-prover"); -// group.sample_size(10); -// group.bench_with_input( -// BenchmarkId::new("wrapper circuit", K), -// &(&recursion_params, &pk, &voter_snark), -// |bencher, &(params, pk, voter_snark)| { -// let cloned_voter_snark = voter_snark; -// bencher.iter(|| { -// let cloned_config_params = config_params.clone(); -// let circuit = RecursionCircuit::new( -// CircuitBuilderStage::Prover -// , -// ¶ms, -// cloned_voter_snark.clone(), -// RecursionCircuit::initial_snark( -// ¶ms, -// None, -// cloned_config_params.clone(), -// base_instances.clone(), -// ), -// 0, -// cloned_config_params, -// state_transition_inputs[0].clone(), -// ); -// println!("reached proof generation"); -// let instances = circuit.inner().instances().clone(); -// gen_proof(params, pk, circuit, instances); -// println!("completed proof generation") -// }) -// }, -// ); -// group.finish() -// } + let mut group = c.benchmark_group("plonk-prover"); + group.sample_size(10); -// criterion_group! { -// name = benches; -// config = Criterion::default().with_profiler(PProfProfiler::new(10, Output::Flamegraph(None))); -// targets = bench -// } -// criterion_main!(benches); -fn main() {} + group.bench_with_input( + BenchmarkId::new("wrapper circuit", K), + &(&recursion_params, &pk, &voter_snark), + |bencher, &(params, pk, voter_snark)| { + let cloned_voter_snark = voter_snark; + bencher.iter(|| { + let cloned_config_params = config_params.clone(); + let circuit = RecursionCircuit::new( + ¶ms, + cloned_voter_snark.clone(), + RecursionCircuit::initial_snark( + ¶ms, + None, + cloned_config_params.clone(), + base_instances.clone(), + ), + 0, + cloned_config_params, + state_transition_inputs[0].clone(), + ); + println!("reached proof generation"); + let instances = circuit.inner().instances().clone(); + gen_proof(params, pk, circuit, instances); + println!("completed proof generation") + }) + }, + ); + group.finish() +} + +criterion_group! { + name = benches; + config = Criterion::default().with_profiler(PProfProfiler::new(10, Output::Flamegraph(None))); + targets = bench +} +criterion_main!(benches); diff --git a/aggregator/src/utils.rs b/aggregator/src/utils.rs index d633ee6..078c065 100644 --- a/aggregator/src/utils.rs +++ b/aggregator/src/utils.rs @@ -3,7 +3,8 @@ use halo2_base::halo2_proofs::halo2curves::bn256::Fr; use halo2_base::halo2_proofs::halo2curves::group::Curve; use halo2_base::halo2_proofs::halo2curves::secp256k1::{Fq, Secp256k1, Secp256k1Affine}; use halo2_base::halo2_proofs::halo2curves::secq256k1::Fp; -use halo2_base::utils::{fe_to_biguint, ScalarField}; +use halo2_base::utils::{fe_to_biguint, BigPrimeField, ScalarField}; +use halo2_base::{AssignedValue, Context}; use halo2_ecc::*; use num_bigint::{BigUint, RandBigInt}; use paillier_chip::paillier::{paillier_add_native, paillier_enc_native}; @@ -493,3 +494,14 @@ pub fn generate_random_state_transition_circuit_inputs() -> StateTransitionInput input } + +pub fn assign_big_uint( + ctx: &mut Context, + value: &BigUint, +) -> Vec> { + let limbs = value.to_u64_digits(); + limbs + .into_iter() + .map(|limb| ctx.load_witness(F::from(limb))) + .collect() +} From 5a53f89f58d70fd53f9d5e66f3f6da8a1146b472 Mon Sep 17 00:00:00 2001 From: 0xvikasrushi <0xvikas@gmail.com> Date: Tue, 17 Sep 2024 17:55:03 +0530 Subject: [PATCH 5/7] fix: update new imt branch --- aggregator/Cargo.toml | 2 +- aggregator/src/state_transition.rs | 33 +++++++++++++----------------- aggregator/src/utils.rs | 31 +++++++++++----------------- 3 files changed, 27 insertions(+), 39 deletions(-) diff --git a/aggregator/Cargo.toml b/aggregator/Cargo.toml index c258fe6..1f2b21a 100644 --- a/aggregator/Cargo.toml +++ b/aggregator/Cargo.toml @@ -16,7 +16,7 @@ k256 = { version = "0.13.3", features = ["arithmetic", "hash2curve", "expose-fie pse-poseidon = { git = "https://github.com/aerius-labs/pse-poseidon.git", branch = "feat/stateless-hash" } rand = "0.8.5" sha2 = "0.10.8" -indexed-merkle-tree-halo2 = { git = "https://github.com/aerius-labs/indexed-merkle-tree-halo2.git" , branch = "feat/aggr-indexed" } +indexed-merkle-tree-halo2 = { git = "https://github.com/aerius-labs/indexed-merkle-tree-halo2.git" , branch = "fix/imt" } rand_chacha = "0.3.1" snark-verifier-sdk = { git = "https://github.com/aerius-labs/snark-verifier.git", branch = "feat/custom" } ark-std = "0.4.0" diff --git a/aggregator/src/state_transition.rs b/aggregator/src/state_transition.rs index 765cb3f..33e751d 100644 --- a/aggregator/src/state_transition.rs +++ b/aggregator/src/state_transition.rs @@ -31,7 +31,6 @@ pub struct IndexedMerkleTreeInput { new_leaf_index: F, new_leaf_proof: Vec, new_leaf_proof_helper: Vec, - is_new_leaf_largest: F, } impl IndexedMerkleTreeInput { @@ -45,7 +44,6 @@ impl IndexedMerkleTreeInput { new_leaf_index: F, new_leaf_proof: Vec, new_leaf_proof_helper: Vec, - is_new_leaf_largest: F, ) -> Self { Self { old_root, @@ -57,7 +55,6 @@ impl IndexedMerkleTreeInput { new_leaf_index, new_leaf_proof, new_leaf_proof_helper, - is_new_leaf_largest, } } pub fn get_old_root(&self) -> F { @@ -206,7 +203,6 @@ pub fn state_transition_circuit( let new_leaf = IndexedMerkleTreeLeaf::new(val, next_val, next_idx); let new_leaf_index = ctx.load_witness(nullifier_tree.new_leaf_index); - let is_new_leaf_largest = ctx.load_witness(nullifier_tree.is_new_leaf_largest); let low_leaf_proof = nullifier_tree .low_leaf_proof @@ -229,21 +225,20 @@ pub fn state_transition_circuit( .map(|x| ctx.load_witness(*x)) .collect::>(); - // insert_leaf::( - // ctx, - // range, - // &hasher, - // &old_root, - // &low_leaf, - // &low_leaf_proof, - // &low_leaf_proof_helper, - // &new_root, - // &new_leaf, - // &new_leaf_index, - // &new_leaf_proof, - // &new_leaf_proof_helper, - // &is_new_leaf_largest, - // ); + insert_leaf::( + ctx, + range, + &hasher, + &old_root, + &low_leaf, + &low_leaf_proof, + &low_leaf_proof_helper, + &new_root, + &new_leaf, + &new_leaf_index, + &new_leaf_proof, + &new_leaf_proof_helper, + ); for enc_vote in aggr_vote { public_inputs.extend(enc_vote.limbs()); diff --git a/aggregator/src/utils.rs b/aggregator/src/utils.rs index 078c065..9cc6630 100644 --- a/aggregator/src/utils.rs +++ b/aggregator/src/utils.rs @@ -116,7 +116,7 @@ fn generate_state_transition_circuit_inputs( let new_val = native_hasher.squeeze_and_reset(); let mut tree = - IndexedMerkleTree::::new(&mut native_hasher, leaves.clone()).unwrap(); + IndexedMerkleTree::::new_default_leaf(nullifier_tree_preimages.len()); let old_root = tree.get_root(); @@ -124,9 +124,10 @@ fn generate_state_transition_circuit_inputs( update_idx_leaf(nullifier_tree_preimages.clone(), new_val, round); let low_leaf = nullifier_tree_preimages[low_leaf_idx].clone(); let (low_leaf_proof, low_leaf_proof_helper) = tree.get_proof(low_leaf_idx); + assert_eq!( tree.verify_proof( - &leaves[low_leaf_idx], + &mut native_hasher, low_leaf_idx, &tree.get_root(), &low_leaf_proof @@ -148,11 +149,12 @@ fn generate_state_transition_circuit_inputs( updated_idx_leaves[round as usize].next_idx, ]); leaves[round as usize] = new_native_hasher.squeeze_and_reset(); - tree = IndexedMerkleTree::::new(&mut new_native_hasher, leaves.clone()).unwrap(); + let (new_leaf_proof, new_leaf_proof_helper) = tree.get_proof(round as usize); assert_eq!( tree.verify_proof( - &leaves[round as usize], + &mut new_native_hasher, + // &leaves[round as usize], round as usize, &tree.get_root(), &new_leaf_proof @@ -167,11 +169,6 @@ fn generate_state_transition_circuit_inputs( next_idx: updated_idx_leaves[round as usize].next_idx, }; let new_leaf_index = Fr::from(round); - let is_new_leaf_largest = if new_leaf.next_val == Fr::zero() { - Fr::one() - } else { - Fr::zero() - }; let idx_input = IndexedMerkleTreeInput::new( old_root, @@ -183,7 +180,6 @@ fn generate_state_transition_circuit_inputs( new_leaf_index, new_leaf_proof, new_leaf_proof_helper, - is_new_leaf_largest, ); let input = StateTransitionInput::new( @@ -384,6 +380,7 @@ pub fn generate_random_state_transition_circuit_inputs() -> StateTransitionInput let nullifier_affine = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); let mut native_hasher = Poseidon::::new(R_F, R_P); + let mut tree = IndexedMerkleTree::::new_default_leaf(tree_size as usize); // Filling leaves with dfault values. for _ in 0..tree_size { @@ -393,8 +390,8 @@ pub fn generate_random_state_transition_circuit_inputs() -> StateTransitionInput let nullifier_compress = compress_native_nullifier(&nullifier_affine); native_hasher.update(&nullifier_compress); let new_val = native_hasher.squeeze_and_reset(); - let mut tree = - IndexedMerkleTree::::new(&mut native_hasher, leaves.clone()).unwrap(); + + tree.insert_leaf(&mut native_hasher, leaves[0], 0); let old_root = tree.get_root(); let low_leaf = IMTLeaf:: { @@ -403,8 +400,9 @@ pub fn generate_random_state_transition_circuit_inputs() -> StateTransitionInput next_idx: Fr::from(0u64), }; let (low_leaf_proof, low_leaf_proof_helper) = tree.get_proof(0); + assert_eq!( - tree.verify_proof(&leaves[0], 0, &tree.get_root(), &low_leaf_proof), + tree.verify_proof(&mut native_hasher, 0, &tree.get_root(), &low_leaf_proof), true ); @@ -424,12 +422,9 @@ pub fn generate_random_state_transition_circuit_inputs() -> StateTransitionInput native_hasher.update(&[new_val, Fr::from(0u64), Fr::from(0u64)]); leaves[1] = native_hasher.squeeze_and_reset(); - tree = IndexedMerkleTree::::new(&mut native_hasher, leaves.clone()).unwrap(); - - let (new_low_leaf_proof, _) = tree.get_proof(0); let (new_leaf_proof, new_leaf_proof_helper) = tree.get_proof(1); assert_eq!( - tree.verify_proof(&leaves[1], 1, &tree.get_root(), &new_leaf_proof), + tree.verify_proof(&mut native_hasher, 1, &tree.get_root(), &new_leaf_proof), true ); @@ -440,7 +435,6 @@ pub fn generate_random_state_transition_circuit_inputs() -> StateTransitionInput next_idx: Fr::from(0u64), }; let new_leaf_index = Fr::from(1u64); - let is_new_leaf_largest = Fr::from(true); let idx_input = IndexedMerkleTreeInput::new( old_root, @@ -452,7 +446,6 @@ pub fn generate_random_state_transition_circuit_inputs() -> StateTransitionInput new_leaf_index, new_leaf_proof, new_leaf_proof_helper, - is_new_leaf_largest, ); let mut rng = thread_rng(); From 329c98dc8d22294d362405762c1ce52cd81dcb06 Mon Sep 17 00:00:00 2001 From: 0xvikasrushi <0xvikas@gmail.com> Date: Wed, 18 Sep 2024 11:32:14 +0530 Subject: [PATCH 6/7] chore: imt branch in voter circuits --- voter/Cargo.toml | 1 + voter/src/lib.rs | 101 ++++++++++++++++++++++------------------- voter_tests/Cargo.toml | 1 + voter_tests/src/lib.rs | 21 +++++++-- 4 files changed, 72 insertions(+), 52 deletions(-) diff --git a/voter/Cargo.toml b/voter/Cargo.toml index b96f310..b4e9c19 100644 --- a/voter/Cargo.toml +++ b/voter/Cargo.toml @@ -13,6 +13,7 @@ pse-poseidon = { git = "https://github.com/aerius-labs/pse-poseidon.git", branch num-bigint = { version = "0.4.4", features = ["serde"] } itertools = "0.12.0" serde ="1.0.196" +indexed-merkle-tree-halo2 = { git = "https://github.com/aerius-labs/indexed-merkle-tree-halo2.git" , branch = "fix/imt" } [dev-dependencies] criterion = "0.5.1" diff --git a/voter/src/lib.rs b/voter/src/lib.rs index 346b6a1..8d22a1d 100644 --- a/voter/src/lib.rs +++ b/voter/src/lib.rs @@ -23,8 +23,9 @@ use halo2_ecc::{ fields::FieldChip, secp256k1::{sha256::Sha256Chip, FpChip, FqChip}, }; +use indexed_merkle_tree_halo2::indexed_merkle_tree::verify_merkle_proof; use itertools::Itertools; -use merkletree::verify_membership_proof; + use num_bigint::BigUint; use biguint_halo2::big_uint::chip::BigUintChip; @@ -226,25 +227,31 @@ pub fn voter_circuit( g: g_assigned, }; + let zero = ctx.load_zero(); + let one = ctx.load_constant(F::from(1)); + // 1. Verify if the voter is in the membership tree - // verify_membership_proof( - // ctx, - // gate, - // &hasher, - // &membership_root, - // &leaf, - // &membership_proof, - // &membership_proof_helper, - // ); + verify_merkle_proof( + ctx, + range, + &hasher, + &membership_root, + &leaf, + &membership_proof, + &membership_proof_helper, + &zero, + &one, + false, + ); // Check to verify correct votes have been passed. - // let _ = vote_assigned_fe.iter().map(|x| gate.assert_bit(ctx, *x)); - // let zero = ctx.load_zero(); - // let one = ctx.load_constant(F::ONE); - // let vote_sum_assigned = vote_assigned_fe - // .iter() - // .fold(zero, |zero, x| gate.add(ctx, zero, *x)); - // ctx.constrain_equal(&vote_sum_assigned, &one); + let _ = vote_assigned_fe.iter().map(|x| gate.assert_bit(ctx, *x)); + let zero = ctx.load_zero(); + let one = ctx.load_constant(F::ONE); + let vote_sum_assigned = vote_assigned_fe + .iter() + .fold(zero, |zero, x| gate.add(ctx, zero, *x)); + ctx.constrain_equal(&vote_sum_assigned, &one); //PK_ENC_n println!("pk_enc n : {:?}", pk_enc.n.limbs().to_vec()); @@ -256,46 +263,46 @@ pub fn voter_circuit( // 2. Verify correct vote encryption for i in 0..input.vote.len() { - // let _vote_enc = paillier_chip - // .encrypt(ctx, &pk_enc, &vote_assigned_big[i], &r_assigned[i]) - // .unwrap(); + let _vote_enc = paillier_chip + .encrypt(ctx, &pk_enc, &vote_assigned_big[i], &r_assigned[i]) + .unwrap(); - // biguint_chip - // .assert_equal_fresh(ctx, &vote_enc_assigned_big[i], &_vote_enc) - // .unwrap(); + biguint_chip + .assert_equal_fresh(ctx, &vote_enc_assigned_big[i], &_vote_enc) + .unwrap(); //ENC_VOTE public_inputs.append(&mut vote_enc_assigned_big[i].limbs().to_vec()); } // 3. Verify nullifier - // let message = proposal_id.value().to_bytes_le()[..2] - // .iter() - // .map(|v| ctx.load_witness(F::from(*v as u64))) - // .collect::>(); - // { - // let mut _proposal_id = ctx.load_zero(); - // for i in 0..2 { - // _proposal_id = gate.mul_add( - // ctx, - // message[i], - // QuantumCell::Constant(F::from(1u64 << (8 * i))), - // _proposal_id, - // ); - // } - // ctx.constrain_equal(&_proposal_id, &proposal_id); - // } + let message = proposal_id.value().to_bytes_le()[..2] + .iter() + .map(|v| ctx.load_witness(F::from(*v as u64))) + .collect::>(); + { + let mut _proposal_id = ctx.load_zero(); + for i in 0..2 { + _proposal_id = gate.mul_add( + ctx, + message[i], + QuantumCell::Constant(F::from(1u64 << (8 * i))), + _proposal_id, + ); + } + ctx.constrain_equal(&_proposal_id, &proposal_id); + } let compressed_nullifier = compress_nullifier(ctx, range, &nullifier); - // let plume_input = PlumeInput::new( - // nullifier, - // s_nullifier.clone(), - // c_nullifier, - // pk_voter, - // message, - // ); - // verify_plume(ctx, &ecc_chip, &sha256_chip, 4, 4, plume_input); + let plume_input = PlumeInput::new( + nullifier, + s_nullifier.clone(), + c_nullifier, + pk_voter, + message, + ); + verify_plume(ctx, &ecc_chip, &sha256_chip, 4, 4, plume_input); //NULLIFIER public_inputs.extend(compressed_nullifier.to_vec()); diff --git a/voter_tests/Cargo.toml b/voter_tests/Cargo.toml index 6094f04..bf0add9 100644 --- a/voter_tests/Cargo.toml +++ b/voter_tests/Cargo.toml @@ -17,3 +17,4 @@ serde = "1.0.196" k256 = { version = "0.13.3", features = ["arithmetic", "hash2curve", "expose-field", "sha2"]} num-bigint = { version = "0.4.4", features = ["serde"] } voter = { path = "../voter" } +indexed-merkle-tree-halo2 = { git = "https://github.com/aerius-labs/indexed-merkle-tree-halo2.git" , branch = "fix/imt" } diff --git a/voter_tests/src/lib.rs b/voter_tests/src/lib.rs index 80575ce..411edad 100644 --- a/voter_tests/src/lib.rs +++ b/voter_tests/src/lib.rs @@ -5,6 +5,7 @@ use halo2_base::halo2_proofs::halo2curves::group::Curve; use halo2_base::halo2_proofs::halo2curves::secp256k1::{Fp, Fq, Secp256k1, Secp256k1Affine}; use halo2_base::utils::{fe_to_biguint, ScalarField}; use halo2_ecc::*; +use indexed_merkle_tree_halo2::utils::IndexedMerkleTree; use k256::elliptic_curve::hash2curve::GroupDigest; use k256::elliptic_curve::sec1::ToEncodedPoint; use k256::{ @@ -18,7 +19,7 @@ use rand::rngs::OsRng; use rand::thread_rng; use sha2::{Digest, Sha256}; -use voter::merkletree::native::MerkleTree; +use voter::merkletree::native::{self, MerkleTree}; use voter::{voter_circuit, EncryptionPublicKey, VoterCircuitInput}; pub fn compress_point(point: &Secp256k1Affine) -> [u8; 33] { @@ -158,6 +159,7 @@ pub fn generate_random_voter_circuit_inputs() -> VoterCircuitInput { } let mut native_hasher = Poseidon::::new(R_F, R_P); + let mut membership_tree = IndexedMerkleTree::::new_default_leaf(treesize as usize); let mut leaves = Vec::::new(); @@ -189,18 +191,27 @@ pub fn generate_random_voter_circuit_inputs() -> VoterCircuitInput { native_hasher.update(&[Fr::ZERO]); } leaves.push(native_hasher.squeeze_and_reset()); + membership_tree.insert_leaf(&mut native_hasher, leaves[i as usize], i as usize); } - let mut membership_tree = - MerkleTree::::new(&mut native_hasher, leaves.clone()).unwrap(); - let membership_root = membership_tree.get_root(); let (membership_proof, membership_proof_helper) = membership_tree.get_proof(0); + assert_eq!( - membership_tree.verify_proof(&leaves[0], 0, &membership_root, &membership_proof), + membership_tree.verify_proof(&mut native_hasher, 0, &membership_root, &membership_proof), true ); + // let mut membership_tree = + // MerkleTree::::new(&mut native_hasher, leaves.clone()).unwrap(); + + // let membership_root = membership_tree.get_root(); + // let (membership_proof, membership_proof_helper) = membership_tree.get_proof(0); + // assert_eq!( + // membership_tree.verify_proof(&leaves[0], 0, &membership_root, &membership_proof), + // true + // ); + let pk_enc = EncryptionPublicKey { n, g }; // Proposal id = 1 From d1bba77e1fb4d697ec3bc003ddefdef4b7cb9bfb Mon Sep 17 00:00:00 2001 From: 0xvikasrushi <0xvikas@gmail.com> Date: Wed, 18 Sep 2024 20:44:58 +0530 Subject: [PATCH 7/7] fix: test_recursion --- aggregator/src/utils.rs | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/aggregator/src/utils.rs b/aggregator/src/utils.rs index 9cc6630..d93136c 100644 --- a/aggregator/src/utils.rs +++ b/aggregator/src/utils.rs @@ -33,7 +33,7 @@ fn generate_voter_circuit_inputs( pk_voter: Secp256k1Affine, vote: Vec, r_enc: Vec, - members_tree: &MerkleTree<'_, Fr, T, RATE>, + members_tree: &IndexedMerkleTree, round: usize, ) -> VoterCircuitInput { let membership_root = members_tree.get_root(); @@ -118,6 +118,10 @@ fn generate_state_transition_circuit_inputs( let mut tree = IndexedMerkleTree::::new_default_leaf(nullifier_tree_preimages.len()); + for i in 0..nullifier_tree_leaves.len() { + tree.insert_leaf(&mut native_hasher, nullifier_tree_leaves[i], i); + } + let old_root = tree.get_root(); let (updated_idx_leaves, low_leaf_idx) = @@ -250,12 +254,17 @@ pub fn generate_wrapper_circuit_input( members_tree_leaves.push(native_hasher.squeeze_and_reset()); } + let mut members_tree = IndexedMerkleTree::::new_default_leaf(8); + for _ in num_round..8 { native_hasher.update(&[Fr::ZERO]); - members_tree_leaves.push(native_hasher.squeeze_and_reset()); + let hash = native_hasher.squeeze_and_reset(); + members_tree_leaves.push(hash.clone()); } - let members_tree = MerkleTree::new(&mut native_hasher, members_tree_leaves.clone()).unwrap(); + for i in 0..members_tree_leaves.len() { + members_tree.insert_leaf(&mut native_hasher, members_tree_leaves[i], i); + } let mut rng = thread_rng(); let mut prev_vote = Vec::::new();