Skip to content

Commit

Permalink
feat(example): cli (#275)
Browse files Browse the repository at this point in the history
**Motivation**
Now we have a default path for time-profling & mem-profiling, but it is
not very convenient to vary the runtime parameters by hand. That's why I
put all runtime parameters in a separate cli example. This allows us to
do time-profiling & mem-profiling on different parameters directly on
the command line
Part of #272 

**Overview**
Fairly simple code that parses parameters with clap and runs circuit.
Covers all our examples

Once mem-profiling is in main, I'll add its description to the README
along with this example

```console
Usage: cli [OPTIONS] [PRIMARY_CIRCUIT] [SECONDARY_CIRCUIT]

Arguments:
  [PRIMARY_CIRCUIT]    [default: poseidon] [possible values: poseidon, trivial]
  [SECONDARY_CIRCUIT]  [default: trivial] [possible values: poseidon, trivial]

Options:
      --primary-circuit-k-table-size <PRIMARY_CIRCUIT_K_TABLE_SIZE>      [default: 17]
      --primary-commitment-key-size <PRIMARY_COMMITMENT_KEY_SIZE>        [default: 21]
      --primary-repeat-count <PRIMARY_REPEAT_COUNT>                      [default: 1]
      --primary-r-f <PRIMARY_R_F>                                        [default: 10]
      --primary-r-p <PRIMARY_R_P>                                        [default: 10]
      --secondary-circuit-k-table-size <SECONDARY_CIRCUIT_K_TABLE_SIZE>  [default: 17]
      --secondary-commitment-key-size <SECONDARY_COMMITMENT_KEY_SIZE>    [default: 21]
      --secondary-repeat-count <SECONDARY_REPEAT_COUNT>                  [default: 1]
      --secondary-r-f <SECONDARY_R_F>                                    [default: 10]
      --secondary-r-p <SECONDARY_R_P>                                    [default: 10]
      --limb-width <LIMB_WIDTH>                                          [default: 32]
      --limbs-count <LIMBS_COUNT>                                        [default: 10]
      --debug-mode
      --fold-step-count <FOLD_STEP_COUNT>                                [default: 1]
      --json-logs
  -h, --help                                                             Print help
  -V, --version                                                          Print version
```
  • Loading branch information
cyphersnake authored May 28, 2024
1 parent 0144059 commit e9ec84b
Show file tree
Hide file tree
Showing 5 changed files with 280 additions and 17 deletions.
1 change: 1 addition & 0 deletions .cargo/config.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[alias]
poseidon-example = ["run", "--example", "poseidon", "--release", "--", "--json"]
trivial-example = ["run", "--example", "trivial", "--release", "--", "--json"]
cli-example = ["run", "--example", "cli", "--release", "--"]
11 changes: 6 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,15 @@ thiserror = "1.0.48"
tracing = { version = "0.1.40", features = ["attributes"] }

[dev-dependencies]
prettytable-rs = "0.10.0"
tracing-test = "0.2.4"
halo2_gadgets = { git = "https://github.com/snarkify/halo2", branch = "snarkify/dev", features = ["unstable"] }
bincode = "1.3"
clap = { version = "4.5.4", features = ["derive"] }
criterion = "0.5.1"
halo2_gadgets = { git = "https://github.com/snarkify/halo2", branch = "snarkify/dev", features = ["unstable"] }
maplit = "1.0.2"
prettytable-rs = "0.10.0"
tempfile = "3.9.0"
tracing-subscriber = { version = "0.3.18", features = ["json"] }
maplit = "1.0.2"
criterion = "0.5.1"
tracing-test = "0.2.4"

[dev-dependencies.cargo-husky]
version = "1"
Expand Down
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ This code will run `fold_step_count` of folding steps, and also check the proof
For runnable examples, please check [examples](examples) folder.

```bash
# Alias to run IVC with parameterization via cli arguments
cargo cli-example --help

# Alias for run the IVC for trivial `StepCircuit` (just returns its input unchanged)
cargo trivial-example

Expand Down
207 changes: 207 additions & 0 deletions examples/cli.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
use std::{array, num::NonZeroUsize};

use clap::{Parser, ValueEnum};

#[allow(dead_code)]
mod poseidon;

use ff::Field;
use poseidon::poseidon_step_circuit::TestPoseidonCircuit;
use sirius::{
ivc::{step_circuit::trivial, CircuitPublicParamsInput, PublicParams, StepCircuit, IVC},
poseidon::ROPair,
};
use tracing::*;
use tracing_subscriber::{filter::LevelFilter, fmt::format::FmtSpan, EnvFilter};

#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Args {
#[arg(value_enum, default_value_t = Circuits::Poseidon) ]
primary_circuit: Circuits,
#[arg(long, default_value_t = 17)]
primary_circuit_k_table_size: u32,
#[arg(long, default_value_t = 21)]
primary_commitment_key_size: usize,
#[arg(long, default_value_t = 1)]
primary_repeat_count: usize,
#[arg(long, default_value_t = 10)]
primary_r_f: usize,
#[arg(long, default_value_t = 10)]
primary_r_p: usize,
#[arg(value_enum, default_value_t = Circuits::Trivial) ]
secondary_circuit: Circuits,
#[arg(long, default_value_t = 17)]
secondary_circuit_k_table_size: u32,
#[arg(long, default_value_t = 21)]
secondary_commitment_key_size: usize,
#[arg(long, default_value_t = 1)]
secondary_repeat_count: usize,
#[arg(long, default_value_t = 10)]
secondary_r_f: usize,
#[arg(long, default_value_t = 10)]
secondary_r_p: usize,
#[arg(long, default_value_t = NonZeroUsize::new(32).unwrap()) ]
limb_width: NonZeroUsize,
#[arg(long, default_value_t = NonZeroUsize::new(10).unwrap()) ]
limbs_count: NonZeroUsize,
#[arg(long, default_value_t = false)]
debug_mode: bool,
#[arg(long, default_value_t = NonZeroUsize::new(1).unwrap()) ]
fold_step_count: NonZeroUsize,
#[arg(long, default_value_t = false)]
json_logs: bool,
}

#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Debug)]
enum Circuits {
Poseidon,
Trivial,
}

use halo2curves::{bn256, grumpkin};

use bn256::G1 as C1;
use grumpkin::G1 as C2;

const MAIN_GATE_SIZE: usize = 5;
const RATE: usize = 4;

type RandomOracle = sirius::poseidon::PoseidonRO<MAIN_GATE_SIZE, RATE>;
type RandomOracleConstant<F> = <RandomOracle as ROPair<F>>::Args;

type C1Affine = <C1 as halo2curves::group::prime::PrimeCurve>::Affine;
type C1Scalar = <C1 as halo2curves::group::Group>::Scalar;

type C2Affine = <C2 as halo2curves::group::prime::PrimeCurve>::Affine;
type C2Scalar = <C2 as halo2curves::group::Group>::Scalar;

fn fold(
args: &Args,
primary: impl StepCircuit<1, C1Scalar>,
secondary: impl StepCircuit<1, C2Scalar>,
) {
let primary_commitment_key = poseidon::get_or_create_commitment_key::<C1Affine>(
args.primary_commitment_key_size,
"bn256",
)
.expect("Failed to get primary key");
let secondary_commitment_key = poseidon::get_or_create_commitment_key::<C2Affine>(
args.secondary_commitment_key_size,
"grumpkin",
)
.expect("Failed to get secondary key");

// Specifications for random oracle used as part of the IVC algorithm
let primary_spec = RandomOracleConstant::<C1Scalar>::new(args.primary_r_f, args.primary_r_p);
let secondary_spec =
RandomOracleConstant::<C2Scalar>::new(args.secondary_r_f, args.secondary_r_p);

let pp = PublicParams::<
'_,
1,
1,
MAIN_GATE_SIZE,
C1Affine,
C2Affine,
_,
_,
RandomOracle,
RandomOracle,
>::new(
CircuitPublicParamsInput::new(
args.primary_circuit_k_table_size,
&primary_commitment_key,
primary_spec,
&primary,
),
CircuitPublicParamsInput::new(
args.secondary_circuit_k_table_size,
&secondary_commitment_key,
secondary_spec,
&secondary,
),
args.limb_width,
args.limbs_count,
)
.unwrap();

let mut rnd = rand::thread_rng();
let primary_input = array::from_fn(|_| C1Scalar::random(&mut rnd));
let secondary_input = array::from_fn(|_| C2Scalar::random(&mut rnd));

if args.debug_mode {
IVC::fold_with_debug_mode(
&pp,
primary,
primary_input,
secondary,
secondary_input,
args.fold_step_count,
)
.unwrap();
} else {
IVC::fold(
&pp,
primary,
primary_input,
secondary,
secondary_input,
args.fold_step_count,
)
.unwrap();
}
}

fn main() {
let args = Args::parse();

let builder = tracing_subscriber::fmt()
// Adds events to track the entry and exit of the span, which are used to build
// time-profiling
.with_span_events(FmtSpan::ENTER | FmtSpan::CLOSE)
// Changes the default level to INFO
.with_env_filter(
EnvFilter::builder()
.with_default_directive(LevelFilter::INFO.into())
.from_env_lossy(),
);

// Structured logs are needed for time-profiling, while for simple run regular logs are
// more convenient.
//
// So this expr keeps track of the --json argument for turn-on json-logs
if args.json_logs {
builder.json().init();
} else {
builder.init();
}

// To osterize the total execution time of the example
let _span = info_span!("poseidon_example").entered();

// Such a redundant call design due to the fact that they are different function types for the
// compiler due to generics
match (args.primary_circuit, args.secondary_circuit) {
(Circuits::Poseidon, Circuits::Trivial) => fold(
&args,
TestPoseidonCircuit::new(args.primary_repeat_count),
trivial::Circuit::default(),
),
(Circuits::Poseidon, Circuits::Poseidon) => fold(
&args,
TestPoseidonCircuit::new(args.primary_repeat_count),
TestPoseidonCircuit::new(args.secondary_repeat_count),
),
(Circuits::Trivial, Circuits::Poseidon) => fold(
&args,
trivial::Circuit::default(),
TestPoseidonCircuit::new(args.secondary_repeat_count),
),
(Circuits::Trivial, Circuits::Trivial) => fold(
&args,
trivial::Circuit::default(),
trivial::Circuit::default(),
),
}
}
75 changes: 63 additions & 12 deletions examples/poseidon/main.rs → examples/poseidon.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/// This module represents an implementation of `StepCircuit` based on the poseidon chip
mod poseidon_step_circuit {
pub mod poseidon_step_circuit {
use std::marker::PhantomData;

use ff::{FromUniformBytes, PrimeFieldBits};
Expand All @@ -13,6 +13,7 @@ mod poseidon_step_circuit {
main_gate::{MainGate, MainGateConfig, RegionCtx, WrapValue},
poseidon::{poseidon_circuit::PoseidonChip, Spec},
};
use tracing::*;

/// Input and output size for `StepCircuit` within each step
pub const ARITY: usize = 1;
Expand All @@ -21,7 +22,7 @@ mod poseidon_step_circuit {
const POSEIDON_PERMUTATION_WIDTH: usize = 3;
const POSEIDON_RATE: usize = POSEIDON_PERMUTATION_WIDTH - 1;

type CircuitPoseidonSpec<F> = Spec<F, POSEIDON_PERMUTATION_WIDTH, POSEIDON_RATE>;
pub type CircuitPoseidonSpec<F> = Spec<F, POSEIDON_PERMUTATION_WIDTH, POSEIDON_RATE>;

const R_F1: usize = 4;
const R_P1: usize = 3;
Expand All @@ -31,11 +32,30 @@ mod poseidon_step_circuit {
pconfig: MainGateConfig<POSEIDON_PERMUTATION_WIDTH>,
}

#[derive(Default, Debug)]
#[derive(Debug)]
pub struct TestPoseidonCircuit<F: PrimeFieldBits> {
repeat_count: usize,
_p: PhantomData<F>,
}

impl<F: PrimeFieldBits> Default for TestPoseidonCircuit<F> {
fn default() -> Self {
Self {
repeat_count: 1,
_p: Default::default(),
}
}
}

impl<F: PrimeFieldBits> TestPoseidonCircuit<F> {
pub fn new(repeat_count: usize) -> Self {
Self {
repeat_count,
_p: Default::default(),
}
}
}

impl<F: PrimeFieldBits + FromUniformBytes<64>> StepCircuit<ARITY, F> for TestPoseidonCircuit<F> {
type Config = TestPoseidonCircuitConfig;

Expand All @@ -51,19 +71,50 @@ mod poseidon_step_circuit {
z_in: &[AssignedCell<F, F>; ARITY],
) -> Result<[AssignedCell<F, F>; ARITY], SynthesisError> {
let spec = CircuitPoseidonSpec::<F>::new(R_F1, R_P1);
let mut pchip = PoseidonChip::new(config.pconfig, spec);
let input = z_in.iter().map(|x| x.into()).collect::<Vec<WrapValue<F>>>();
pchip.update(&input);
let output = layouter

layouter
.assign_region(
|| "poseidon hash",
|region| {
move |region| {
let mut z_i = z_in.clone();
let ctx = &mut RegionCtx::new(region, 0);
pchip.squeeze(ctx)

for step in 0..=self.repeat_count {
let mut pchip = PoseidonChip::new(config.pconfig.clone(), spec.clone());

pchip.update(
&z_i.iter()
.cloned()
.map(WrapValue::Assigned)
.collect::<Vec<WrapValue<F>>>(),
);

info!(
"offset for {} hash repeat count is {} (log2 = {})",
step,
ctx.offset(),
(ctx.offset() as f64).log2()
);

z_i = [pchip.squeeze(ctx).inspect_err(|err| {
error!("at step {step}: {err:?}");
})?];
}

info!(
"total offset for {} hash repeat count is {} (log2 = {})",
self.repeat_count,
ctx.offset(),
(ctx.offset() as f64).log2()
);

Ok(z_i)
},
)
.map_err(SynthesisError::Halo2)?;
Ok([output])
.map_err(|err| {
error!("while synth {err:?}");
SynthesisError::Halo2(err)
})
}
}
}
Expand Down Expand Up @@ -117,7 +168,7 @@ type C2Scalar = <C2 as halo2curves::group::Group>::Scalar;

/// Either takes the key from [`CACHE_FOLDER`] or generates a new one and puts it in it
#[instrument]
fn get_or_create_commitment_key<C: CurveAffine>(
pub fn get_or_create_commitment_key<C: CurveAffine>(
k: usize,
label: &'static str,
) -> io::Result<CommitmentKey<C>> {
Expand Down

0 comments on commit e9ec84b

Please sign in to comment.