Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize Circuits and Benches #7

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aggregator/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ indexed-merkle-tree-halo2 = { git = "https://github.com/aerius-labs/indexed-merk
rand_chacha = "0.3.1"
snark-verifier-sdk = { git = "https://github.com/aerius-labs/snark-verifier.git", branch = "feat/custom" }
ark-std = "0.4.0"
num-traits = "0.2.18"

voter = { path = "../voter" }
voter-tests = { path = "../voter_tests" }
Expand Down
36 changes: 34 additions & 2 deletions aggregator/benches/state_transition_circuit.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use aggregator::state_transition::{state_transition_circuit, StateTransitionInput};
use aggregator::utils::generate_random_state_transition_circuit_inputs;
use aggregator::utils::{
assign_big_uint, compress_native_nullifier, generate_random_state_transition_circuit_inputs,
};
use ark_std::{end_timer, start_timer};
use halo2_base::gates::circuit::BaseCircuitParams;
use halo2_base::gates::circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage};
Expand Down Expand Up @@ -42,7 +44,37 @@ fn state_transition_circuit_bench(
let range = builder.range_chip();

let mut public_inputs = Vec::<AssignedValue<Fr>>::new();
state_transition_circuit(builder.main(0), &range, input, &mut public_inputs);
let mut state_transition_vec = Vec::<AssignedValue<Fr>>::new();
let ctx = builder.main(0);

state_transition_vec.extend(
compress_native_nullifier(&input.nullifier)
.iter()
.map(|&x| ctx.load_witness(x)),
);

let enc_g = input.pk_enc.g;
let enc_h = input.pk_enc.n;

state_transition_vec.extend(assign_big_uint(ctx, &enc_g));
state_transition_vec.extend(assign_big_uint(ctx, &enc_h));

state_transition_vec.extend(
input
.incoming_vote
.iter()
.flat_map(|x| assign_big_uint(ctx, x)),
);

state_transition_vec.extend(input.prev_vote.iter().flat_map(|x| assign_big_uint(ctx, x)));

state_transition_circuit(
ctx,
&range,
input.nullifier_tree,
state_transition_vec,
&mut public_inputs,
);

end_timer!(start0);
if !stage.witness_gen_only() {
Expand Down
75 changes: 24 additions & 51 deletions aggregator/benches/wrapper_circuit.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
use aggregator::state_transition::StateTransitionCircuit;
use aggregator::utils::generate_wrapper_circuit_input;
use aggregator::wrapper::common::gen_dummy_snark;
use aggregator::wrapper::common::gen_pk;
use aggregator::wrapper::common::gen_proof;
use aggregator::wrapper::common::gen_snark;
use aggregator::wrapper::recursion::RecursionCircuit;
use halo2_base::gates::circuit::BaseCircuitParams;
use halo2_base::gates::circuit::CircuitBuilderStage;
use halo2_base::halo2_proofs::{halo2curves::bn256::Fr, plonk::*};
use halo2_base::utils::decompose_biguint;
use halo2_base::utils::fs::gen_srs;
use halo2_base::{
halo2_proofs::{halo2curves::bn256::Fr, plonk::*},
utils::testing::gen_proof,
};

use criterion::{criterion_group, criterion_main};
use criterion::{BenchmarkId, Criterion};

use pprof::criterion::{Output, PProfProfiler};
use snark_verifier_sdk::CircuitExt;
use voter::VoterCircuit;

const K: u32 = 22;
Expand All @@ -37,27 +35,6 @@ fn bench(c: &mut Criterion) {
let voter_pk = gen_pk(&voter_params, &voter_circuit);
let voter_snark = gen_snark(&voter_params, &voter_pk, voter_circuit);

// Generating state transition proof
let state_transition_config = BaseCircuitParams {
k: 15,
num_advice_per_phase: vec![3],
num_lookup_advice_per_phase: vec![1, 0, 0],
num_fixed: 1,
lookup_bits: Some(14),
num_instance_columns: 1,
};
let state_transition_params = gen_srs(15);
let state_transition_circuit = StateTransitionCircuit::new(
state_transition_config.clone(),
state_transition_inputs[0].clone(),
);
let state_transition_pk = gen_pk(&state_transition_params, &state_transition_circuit);
let state_transition_snark = gen_snark(
&state_transition_params,
&state_transition_pk,
state_transition_circuit,
);

let recursion_config = BaseCircuitParams {
k: K as usize,
num_advice_per_phase: vec![4],
Expand All @@ -68,6 +45,11 @@ fn bench(c: &mut Criterion) {
};
let recursion_params = gen_srs(K);

let mut init_vote = Vec::<Fr>::new();
for x in state_transition_inputs[0].prev_vote.iter() {
init_vote.extend(decompose_biguint::<Fr>(x, 4, 88));
}

// Init Base Instances
let mut base_instances = [
Fr::zero(), // preprocessed_digest
Expand All @@ -77,24 +59,18 @@ fn bench(c: &mut Criterion) {
voter_snark.instances[0][3],
]
.to_vec();
base_instances.extend(state_transition_snark.instances[0][4..24].iter()); // init_vote
base_instances.extend(init_vote); // init_vote
base_instances.extend([
state_transition_snark.instances[0][68], // nullifier_old_root
state_transition_snark.instances[0][68], // nullifier_new_root
voter_snark.instances[0][28], // membership_root
voter_snark.instances[0][29], // proposal_id
Fr::from(0), // round
state_transition_inputs[0].nullifier_tree.get_old_root(), // nullifier_old_root
state_transition_inputs[0].nullifier_tree.get_old_root(), // nullifier_new_root
voter_snark.instances[0][28], // membership_root
voter_snark.instances[0][29], // proposal_id
Fr::from(0), // round
]);

let recursion_circuit = RecursionCircuit::new(
CircuitBuilderStage::Keygen,
&recursion_params,
gen_dummy_snark::<VoterCircuit<Fr>>(&voter_params, Some(voter_pk.get_vk()), voter_config),
gen_dummy_snark::<StateTransitionCircuit<Fr>>(
&state_transition_params,
Some(state_transition_pk.get_vk()),
state_transition_config,
),
RecursionCircuit::initial_snark(
&recursion_params,
None,
Expand All @@ -103,30 +79,24 @@ fn bench(c: &mut Criterion) {
),
0,
recursion_config,
state_transition_inputs[0].clone(),
);
let pk = gen_pk(&recursion_params, &recursion_circuit);
let config_params = recursion_circuit.inner().params();

let mut group = c.benchmark_group("plonk-prover");
group.sample_size(10);

group.bench_with_input(
BenchmarkId::new("wrapper circuit", K),
&(
&recursion_params,
&pk,
&voter_snark,
&state_transition_snark,
),
|bencher, &(params, pk, voter_snark, state_transition_snark)| {
&(&recursion_params, &pk, &voter_snark),
|bencher, &(params, pk, voter_snark)| {
let cloned_voter_snark = voter_snark;
let cloned_state_transition_snark = state_transition_snark;
bencher.iter(|| {
let cloned_config_params = config_params.clone();
let circuit = RecursionCircuit::new(
CircuitBuilderStage::Prover,
&params,
cloned_voter_snark.clone(),
cloned_state_transition_snark.clone(),
RecursionCircuit::initial_snark(
&params,
None,
Expand All @@ -135,9 +105,12 @@ fn bench(c: &mut Criterion) {
),
0,
cloned_config_params,
state_transition_inputs[0].clone(),
);

gen_proof(params, pk, circuit);
println!("reached proof generation");
let instances = circuit.inner().instances().clone();
gen_proof(params, pk, circuit, instances);
println!("completed proof generation")
})
},
);
Expand Down
Loading