diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 767758a8b..ec0de0b14 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -56,10 +56,12 @@ jobs: - name: Run example env: RAYON_NUM_THREADS: 2 - run: cargo run --release --package ceno_zkvm --example riscv_opcodes --target ${{ matrix.target }} -- --start 10 --end 11 + RUSTFLAGS: "-C opt-level=3" + run: cargo run --package ceno_zkvm --example riscv_opcodes --target ${{ matrix.target }} -- --start 10 --end 11 - name: Run fibonacci env: RAYON_NUM_THREADS: 8 RUST_LOG: debug - run: cargo run --release --package ceno_zkvm --bin e2e --target ${{ matrix.target }} -- --platform=sp1 ceno_zkvm/examples/fibonacci.elf \ No newline at end of file + RUSTFLAGS: "-C opt-level=3" + run: cargo run --package ceno_zkvm --bin e2e --target ${{ matrix.target }} -- --platform=sp1 ceno_zkvm/examples/fibonacci.elf diff --git a/.github/workflows/lints.yml b/.github/workflows/lints.yml index f87d32886..1014ba9a1 100644 --- a/.github/workflows/lints.yml +++ b/.github/workflows/lints.yml @@ -63,7 +63,7 @@ jobs: run: | cargo make --version || cargo install cargo-make - name: Check code format - run: cargo make fmt-all-check + run: cargo fmt --all --check - name: Run clippy env: diff --git a/Cargo.toml b/Cargo.toml index 1ee0b66ef..ead19b880 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,5 +52,8 @@ tracing = { version = "0.1", features = [ tracing-forest = { version = "0.1.6" } tracing-subscriber = { version = "0.3", features = ["env-filter"] } +[profile.dev] +lto = "thin" + [profile.release] lto = "thin" diff --git a/Makefile.toml b/Makefile.toml index b9f231db4..8d6ab5edf 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -33,26 +33,11 @@ args = [ command = "cargo" workspace = false -[tasks.fmt-all-check] -args = ["fmt", "--all", "--", "--check"] -command = "cargo" -workspace = false - -[tasks.fmt-all] -args = ["fmt", "--all"] -command = "cargo" -workspace = false - [tasks.clippy-all] args = ["clippy", "--all-features", "--all-targets", "--", "-D", "warnings"] command = "cargo" workspace = false -[tasks.fmt] -args = ["fmt", "-p", "ceno_zkvm", "--", "--check"] -command = "cargo" -workspace = false - [tasks.riscv_stats] args = ["run", "--bin", "riscv_stats"] command = "cargo" diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index 8f1746d79..e071cb000 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -29,9 +29,8 @@ impl VMState { /// 32 architectural registers + 1 register RD_NULL for dark writes to x0. pub const REG_COUNT: usize = 32 + 1; - pub fn new(platform: Platform, program: Program) -> Self { + pub fn new(platform: Platform, program: Arc) -> Self { let pc = program.entry; - let program = Arc::new(program); let mut vm = Self { pc, @@ -52,7 +51,7 @@ impl VMState { } pub fn new_from_elf(platform: Platform, elf: &[u8]) -> Result { - let program = Program::load_elf(elf, u32::MAX)?; + let program = Arc::new(Program::load_elf(elf, u32::MAX)?); Ok(Self::new(platform, program)) } diff --git a/ceno_emul/tests/test_vm_trace.rs b/ceno_emul/tests/test_vm_trace.rs index d931a6a9c..2aa5f0da2 100644 --- a/ceno_emul/tests/test_vm_trace.rs +++ b/ceno_emul/tests/test_vm_trace.rs @@ -1,6 +1,9 @@ #![allow(clippy::unusual_byte_groupings)] use anyhow::Result; -use std::collections::{BTreeMap, HashMap}; +use std::{ + collections::{BTreeMap, HashMap}, + sync::Arc, +}; use ceno_emul::{ CENO_PLATFORM, Cycle, EmuContext, InsnKind, Platform, Program, StepRecord, Tracer, VMState, @@ -24,7 +27,7 @@ fn test_vm_trace() -> Result<()> { }) .collect(), ); - let mut ctx = VMState::new(CENO_PLATFORM, program); + let mut ctx = VMState::new(CENO_PLATFORM, Arc::new(program)); let steps = run(&mut ctx)?; @@ -52,7 +55,7 @@ fn test_empty_program() -> Result<()> { vec![], BTreeMap::new(), ); - let mut ctx = VMState::new(CENO_PLATFORM, empty_program); + let mut ctx = VMState::new(CENO_PLATFORM, Arc::new(empty_program)); let res = run(&mut ctx); assert!(matches!(res, Err(e) if e.to_string().contains("InstructionAccessFault")),); Ok(()) diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index f5bd125e7..4bb427eb9 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -67,3 +67,7 @@ name = "riscv_add" [[bench]] harness = false name = "fibonacci" + +[[bench]] +harness = false +name = "fibonacci_witness" diff --git a/ceno_zkvm/benches/fibonacci.rs b/ceno_zkvm/benches/fibonacci.rs index a90521cfc..f75c4ddd1 100644 --- a/ceno_zkvm/benches/fibonacci.rs +++ b/ceno_zkvm/benches/fibonacci.rs @@ -7,7 +7,7 @@ use std::{ use ceno_emul::{CENO_PLATFORM, Platform, Program, WORD_SIZE}; use ceno_zkvm::{ self, - e2e::{run_e2e_gen_witness, run_e2e_proof}, + e2e::{Checkpoint, run_e2e_with_checkpoint}, }; use criterion::*; @@ -15,17 +15,20 @@ use goldilocks::GoldilocksExt2; use mpcs::BasefoldDefault; criterion_group! { - name = fibonacci; + name = fibonacci_prove_group; config = Criterion::default().warm_up_time(Duration::from_millis(20000)); - targets = bench_e2e + targets = fibonacci_prove, } -criterion_main!(fibonacci); +criterion_main!(fibonacci_prove_group); const NUM_SAMPLES: usize = 10; -fn bench_e2e(c: &mut Criterion) { - type Pcs = BasefoldDefault; +type Pcs = BasefoldDefault; +type E = GoldilocksExt2; + +// Relevant init data for fibonacci run +fn setup() -> (Program, Platform, u32, u32) { let mut file_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); file_path.push("examples/fibonacci.elf"); let stack_size = 32768; @@ -33,7 +36,6 @@ fn bench_e2e(c: &mut Criterion) { let elf_bytes = fs::read(&file_path).expect("read elf file"); let program = Program::load_elf(&elf_bytes, u32::MAX).unwrap(); - // use sp1 platform let platform = Platform { // The stack section is not mentioned in ELF headers, so we repeat the constant STACK_TOP here. stack_top: 0x0020_0400, @@ -44,6 +46,11 @@ fn bench_e2e(c: &mut Criterion) { ..CENO_PLATFORM }; + (program, platform, stack_size, heap_size) +} + +fn fibonacci_prove(c: &mut Criterion) { + let (program, platform, stack_size, heap_size) = setup(); for max_steps in [1usize << 20, 1usize << 21, 1usize << 22] { // expand more input size once runtime is acceptable let mut group = c.benchmark_group(format!("fibonacci_max_steps_{}", max_steps)); @@ -58,18 +65,20 @@ fn bench_e2e(c: &mut Criterion) { |b| { b.iter_with_setup( || { - run_e2e_gen_witness::( + run_e2e_with_checkpoint::( program.clone(), platform.clone(), stack_size, heap_size, vec![], max_steps, + Checkpoint::PrepE2EProving, ) }, - |(prover, _, zkvm_witness, pi, _, _, _)| { + |(_, run_e2e_proof)| { let timer = Instant::now(); - let _ = run_e2e_proof(prover, zkvm_witness, pi); + + run_e2e_proof(); println!( "Fibonacci::create_proof, max_steps = {}, time = {}", max_steps, @@ -82,6 +91,4 @@ fn bench_e2e(c: &mut Criterion) { group.finish(); } - - type E = GoldilocksExt2; } diff --git a/ceno_zkvm/benches/fibonacci_witness.rs b/ceno_zkvm/benches/fibonacci_witness.rs new file mode 100644 index 000000000..2f09adaee --- /dev/null +++ b/ceno_zkvm/benches/fibonacci_witness.rs @@ -0,0 +1,83 @@ +use std::{fs, path::PathBuf, time::Duration}; + +use ceno_emul::{CENO_PLATFORM, Platform, Program, WORD_SIZE}; +use ceno_zkvm::{ + self, + e2e::{Checkpoint, run_e2e_with_checkpoint}, +}; +use criterion::*; + +use goldilocks::GoldilocksExt2; +use mpcs::BasefoldDefault; + +criterion_group! { + name = fibonacci; + config = Criterion::default().warm_up_time(Duration::from_millis(20000)); + targets = fibonacci_witness +} + +criterion_main!(fibonacci); + +const NUM_SAMPLES: usize = 10; +type Pcs = BasefoldDefault; +type E = GoldilocksExt2; + +// Relevant init data for fibonacci run +fn setup() -> (Program, Platform, u32, u32) { + let mut file_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + file_path.push("examples/fibonacci.elf"); + let stack_size = 32768; + let heap_size = 2097152; + let elf_bytes = fs::read(&file_path).expect("read elf file"); + let program = Program::load_elf(&elf_bytes, u32::MAX).unwrap(); + + let platform = Platform { + // The stack section is not mentioned in ELF headers, so we repeat the constant STACK_TOP here. + stack_top: 0x0020_0400, + rom: program.base_address + ..program.base_address + (program.instructions.len() * WORD_SIZE) as u32, + ram: 0x0010_0000..0xFFFF_0000, + unsafe_ecall_nop: true, + ..CENO_PLATFORM + }; + + (program, platform, stack_size, heap_size) +} + +fn fibonacci_witness(c: &mut Criterion) { + let (program, platform, stack_size, heap_size) = setup(); + + let max_steps = usize::MAX; + let mut group = c.benchmark_group(format!("fib_wit_max_steps_{}", max_steps)); + group.sample_size(NUM_SAMPLES); + + // Benchmark the proving time + group.bench_function( + BenchmarkId::new( + "fibonacci_witness", + format!("fib_wit_max_steps_{}", max_steps), + ), + |b| { + b.iter_with_setup( + || { + run_e2e_with_checkpoint::( + program.clone(), + platform.clone(), + stack_size, + heap_size, + vec![], + max_steps, + Checkpoint::PrepWitnessGen, + ) + }, + |(_, generate_witness)| { + generate_witness(); + }, + ); + }, + ); + + group.finish(); + + type E = GoldilocksExt2; +} diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index d1e86ad98..bc3124440 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -15,7 +15,7 @@ use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; use mpcs::{BasefoldDefault, PolynomialCommitmentScheme}; use multilinear_extensions::mle::IntoMLE; -use transcript::Transcript; +use transcript::{BasicTranscript, Transcript}; cfg_if::cfg_if! { if #[cfg(feature = "flamegraph")] { @@ -87,7 +87,7 @@ fn bench_add(c: &mut Criterion) { |wits_in| { let timer = Instant::now(); let num_instances = 1 << instance_num_vars; - let mut transcript = Transcript::new(b"riscv"); + let mut transcript = BasicTranscript::new(b"riscv"); let commit = Pcs::batch_commit_and_write(&prover.pk.pp, &wits_in, &mut transcript) .unwrap(); diff --git a/ceno_zkvm/examples/riscv_opcodes.rs b/ceno_zkvm/examples/riscv_opcodes.rs index 99b86173f..fe119bb25 100644 --- a/ceno_zkvm/examples/riscv_opcodes.rs +++ b/ceno_zkvm/examples/riscv_opcodes.rs @@ -1,4 +1,4 @@ -use std::{panic, time::Instant}; +use std::{panic, sync::Arc, time::Instant}; use ceno_zkvm::{ declare_program, @@ -27,7 +27,7 @@ use itertools::Itertools; use mpcs::{Basefold, BasefoldRSParams, PolynomialCommitmentScheme}; use sumcheck::macros::{entered_span, exit_span}; use tracing_subscriber::{EnvFilter, Registry, fmt, fmt::format::FmtSpan, layer::SubscriberExt}; -use transcript::Transcript; +use transcript::BasicTranscript as Transcript; const PROGRAM_SIZE: usize = 16; // For now, we assume registers // - x0 is not touched, @@ -177,7 +177,7 @@ fn main() { // init vm.x1 = 1, vm.x2 = -1, vm.x3 = step_loop let public_io_init = init_public_io(&[1, u32::MAX, step_loop]); - let mut vm = VMState::new(CENO_PLATFORM, program.clone()); + let mut vm = VMState::new(CENO_PLATFORM, Arc::new(program.clone())); // init memory mapped IO for record in &public_io_init { @@ -290,12 +290,7 @@ fn main() { trace_report.save_json("report.json"); trace_report.save_table("report.txt"); - MockProver::assert_satisfied_full( - zkvm_cs.clone(), - zkvm_fixed_traces.clone(), - &zkvm_witness, - &pi, - ); + MockProver::assert_satisfied_full(&zkvm_cs, zkvm_fixed_traces.clone(), &zkvm_witness, &pi); let timer = Instant::now(); diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index 18f373d8b..07baf8998 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -1,6 +1,6 @@ use ceno_emul::{CENO_PLATFORM, IterAddresses, Platform, Program, WORD_SIZE, Word}; use ceno_zkvm::{ - e2e::{run_e2e_gen_witness, run_e2e_proof, run_e2e_verify}, + e2e::{Checkpoint, run_e2e_with_checkpoint}, with_panic_hook, }; use clap::{Parser, ValueEnum}; @@ -8,13 +8,13 @@ use ff_ext::ff::Field; use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; use mpcs::{Basefold, BasefoldRSParams}; -use std::{fs, panic, time::Instant}; +use std::{fs, panic}; use tracing::level_filters::LevelFilter; use tracing_forest::ForestLayer; use tracing_subscriber::{ EnvFilter, Registry, filter::filter_fn, fmt, layer::SubscriberExt, util::SubscriberInitExt, }; -use transcript::Transcript; +use transcript::BasicTranscript as Transcript; /// Prove the execution of a fixed RISC-V program. #[derive(Parser, Debug)] @@ -143,37 +143,17 @@ fn main() { type B = Goldilocks; type Pcs = Basefold; - let (prover, verifier, zkvm_witness, pi, cycle_num, e2e_start, exit_code) = - run_e2e_gen_witness::( - program, - platform, - args.stack_size, - args.heap_size, - hints, - max_steps, - ); - - let timer = Instant::now(); - let mut zkvm_proof = run_e2e_proof(prover, zkvm_witness, pi); - let proving_time = timer.elapsed().as_secs_f64(); - let e2e_time = e2e_start.elapsed().as_secs_f64(); - let witgen_time = e2e_time - proving_time; - println!( - "Proving finished.\n\ -\tProving time = {:.3}s, freq = {:.3}khz\n\ -\tWitgen time = {:.3}s, freq = {:.3}khz\n\ -\tTotal time = {:.3}s, freq = {:.3}khz\n\ -\tthread num: {}", - proving_time, - cycle_num as f64 / proving_time / 1000.0, - witgen_time, - cycle_num as f64 / witgen_time / 1000.0, - e2e_time, - cycle_num as f64 / e2e_time / 1000.0, - rayon::current_num_threads() + let (state, _) = run_e2e_with_checkpoint::( + program, + platform, + args.stack_size, + args.heap_size, + hints, + max_steps, + Checkpoint::PrepSanityCheck, ); - run_e2e_verify(&verifier, zkvm_proof.clone(), exit_code, max_steps); + let (mut zkvm_proof, verifier) = state.expect("PrepSanityCheck should yield state."); // do sanity check let transcript = Transcript::new(b"riscv"); @@ -207,7 +187,6 @@ fn main() { } }; } - fn memory_from_file(path: &Option) -> Vec { path.as_ref() .map(|path| { diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index c8bfbe811..b1d5b4d77 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -5,136 +5,69 @@ use crate::{ prover::ZKVMProver, verifier::ZKVMVerifier, }, state::GlobalState, - structs::{ProgramParams, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, - tables::{MemFinalRecord, MemInitRecord, ProgramTableCircuit}, + structs::{ + ProgramParams, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMProvingKey, ZKVMWitnesses, + }, + tables::{MemFinalRecord, MemInitRecord, ProgramTableCircuit, ProgramTableConfig}, }; use ceno_emul::{ ByteAddr, EmuContext, InsnKind::EANY, IterAddresses, Platform, Program, StepRecord, Tracer, VMState, WORD_SIZE, WordAddr, }; use ff_ext::ExtensionField; -use itertools::{Itertools, MinMaxResult, chain, enumerate}; +use itertools::{Itertools, MinMaxResult, chain}; use mpcs::PolynomialCommitmentScheme; use std::{ collections::{HashMap, HashSet}, iter::zip, - time::Instant, + ops::Deref, + sync::Arc, }; -use transcript::Transcript; - -type E2EWitnessGen = ( - ZKVMProver, - ZKVMVerifier, - ZKVMWitnesses, - PublicValues, - usize, // number of cycles - Instant, // e2e start, excluding key gen time - Option, -); - -pub fn run_e2e_gen_witness>( - program: Program, - platform: Platform, - stack_size: u32, - heap_size: u32, - hints: Vec, - max_steps: usize, -) -> E2EWitnessGen { - let stack_addrs = platform.stack_top - stack_size..platform.stack_top; - - // Detect heap as starting after program data. - let heap_start = program.image.keys().max().unwrap() + WORD_SIZE as u32; - let heap_addrs = heap_start..heap_start + heap_size; +use transcript::BasicTranscript as Transcript; - let mut mem_padder = MemPadder::new(heap_addrs.end..platform.ram.end); - - let mem_init = { - let program_addrs = program.image.iter().map(|(addr, value)| MemInitRecord { - addr: *addr, - value: *value, - }); - - let stack = stack_addrs - .iter_addresses() - .map(|addr| MemInitRecord { addr, value: 0 }); +pub struct FullMemState { + mem: Vec, + io: Vec, + reg: Vec, + priv_io: Vec, +} - let heap = heap_addrs - .iter_addresses() - .map(|addr| MemInitRecord { addr, value: 0 }); +type InitMemState = FullMemState; +type FinalMemState = FullMemState; - let mem_init = chain!(program_addrs, stack, heap).collect_vec(); +pub struct EmulationResult { + exit_code: Option, + all_records: Vec, + final_mem_state: FinalMemState, + pi: PublicValues, +} - mem_padder.padded_sorted(mem_init.len().next_power_of_two(), mem_init) - }; +fn emulate_program( + program: Arc, + max_steps: usize, + init_mem_state: InitMemState, + platform: &Platform, + hints: Vec, +) -> EmulationResult { + let InitMemState { + mem: mem_init, + io: io_init, + reg: reg_init, + priv_io: _, + } = init_mem_state; - let mut vm = VMState::new(platform.clone(), program); + let mut vm: VMState = VMState::new(platform.clone(), program); for (addr, value) in zip(platform.hints.iter_addresses(), &hints) { vm.init_memory(addr.into(), *value); } - // keygen - let pcs_param = PCS::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup"); - let (pp, vp) = PCS::trim(pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); - let program_params = ProgramParams { - platform: platform.clone(), - program_size: vm.program().instructions.len(), - static_memory_len: mem_init.len(), - ..ProgramParams::default() - }; - let mut zkvm_cs = ZKVMConstraintSystem::new_with_platform(program_params); - - let config = Rv32imConfig::::construct_circuits(&mut zkvm_cs); - let mmu_config = MmuConfig::::construct_circuits(&mut zkvm_cs); - let dummy_config = DummyExtraConfig::::construct_circuits(&mut zkvm_cs); - let prog_config = zkvm_cs.register_table_circuit::>(); - zkvm_cs.register_global_state::(); - - let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); - - zkvm_fixed_traces.register_table_circuit::>( - &zkvm_cs, - &prog_config, - vm.program(), - ); - - // IO is not used in this program, but it must have a particular size at the moment. - let io_init = mem_padder.padded_sorted(mmu_config.public_io_len(), vec![]); - - let reg_init = mmu_config.initial_registers(); - config.generate_fixed_traces(&zkvm_cs, &mut zkvm_fixed_traces); - mmu_config.generate_fixed_traces( - &zkvm_cs, - &mut zkvm_fixed_traces, - ®_init, - &mem_init, - &io_init.iter().map(|rec| rec.addr).collect_vec(), - ); - dummy_config.generate_fixed_traces(&zkvm_cs, &mut zkvm_fixed_traces); - - let pk = zkvm_cs - .clone() - .key_gen::(pp.clone(), vp.clone(), zkvm_fixed_traces.clone()) - .expect("keygen failed"); - let vk = pk.get_vk(); - - // proving - let e2e_start = Instant::now(); - let prover = ZKVMProver::new(pk); - let verifier = ZKVMVerifier::new(vk); - let all_records = vm .iter_until_halt() .take(max_steps) .collect::, _>>() .expect("vm exec failed"); - let cycle_num = all_records.len(); - tracing::info!("Proving {} execution steps", cycle_num); - for (i, step) in enumerate(&all_records).rev().take(5).rev() { - tracing::trace!("Step {i}: {:?} - {:?}\n", step.insn().codes().kind, step); - } - // Find the exit code from the HALT step, if halting at all. let exit_code = all_records .iter() @@ -158,16 +91,6 @@ pub fn run_e2e_gen_witness io_init.iter().map(|rec| rec.value).collect_vec(), ); - let mut zkvm_witness = ZKVMWitnesses::default(); - // assign opcode circuits - let dummy_records = config - .assign_opcode_circuit(&zkvm_cs, &mut zkvm_witness, all_records) - .unwrap(); - dummy_config - .assign_opcode_circuit(&zkvm_cs, &mut zkvm_witness, dummy_records) - .unwrap(); - zkvm_witness.finalize_lk_multiplicities(); - // Find the final register values and cycles. let reg_final = reg_init .iter() @@ -208,7 +131,11 @@ pub fn run_e2e_gen_witness // Find the final public IO cycles. let io_final = io_init .iter() - .map(|rec| *final_access.get(&rec.addr.into()).unwrap_or(&0)) + .map(|rec| MemFinalRecord { + addr: rec.addr, + value: rec.value, + cycle: *final_access.get(&rec.addr.into()).unwrap_or(&0), + }) .collect_vec(); let priv_io_final = zip(platform.hints.iter_addresses(), &hints) @@ -219,45 +146,331 @@ pub fn run_e2e_gen_witness }) .collect_vec(); + EmulationResult { + pi, + exit_code, + all_records, + final_mem_state: FinalMemState { + reg: reg_final, + io: io_final, + mem: mem_final, + priv_io: priv_io_final, + }, + } +} + +fn init_mem( + program: &Program, + platform: &Platform, + mem_padder: &mut MemPadder, + stack_size: u32, + heap_size: u32, +) -> Vec { + let stack_addrs = platform.stack_top - stack_size..platform.stack_top; + // Detect heap as starting after program data. + let heap_start = program.image.keys().max().unwrap() + WORD_SIZE as u32; + let heap_addrs = heap_start..heap_start + heap_size; + let program_addrs = program.image.iter().map(|(addr, value)| MemInitRecord { + addr: *addr, + value: *value, + }); + + let stack = stack_addrs + .iter_addresses() + .map(|addr| MemInitRecord { addr, value: 0 }); + + let heap = heap_addrs + .iter_addresses() + .map(|addr| MemInitRecord { addr, value: 0 }); + + let mem_init = chain!(program_addrs, stack, heap).collect_vec(); + + mem_padder.padded_sorted(mem_init.len().next_power_of_two(), mem_init) +} + +pub struct ConstraintSystemConfig { + zkvm_cs: ZKVMConstraintSystem, + config: Rv32imConfig, + mmu_config: MmuConfig, + dummy_config: DummyExtraConfig, + prog_config: ProgramTableConfig, +} + +fn construct_configs( + program_params: ProgramParams, +) -> ConstraintSystemConfig { + let mut zkvm_cs = ZKVMConstraintSystem::new_with_platform(program_params); + + let config = Rv32imConfig::::construct_circuits(&mut zkvm_cs); + let mmu_config = MmuConfig::::construct_circuits(&mut zkvm_cs); + let dummy_config = DummyExtraConfig::::construct_circuits(&mut zkvm_cs); + let prog_config = zkvm_cs.register_table_circuit::>(); + zkvm_cs.register_global_state::(); + ConstraintSystemConfig { + zkvm_cs, + config, + mmu_config, + dummy_config, + prog_config, + } +} + +fn generate_fixed_traces( + system_config: &ConstraintSystemConfig, + init_mem_state: &InitMemState, + program: &Program, +) -> ZKVMFixedTraces { + let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); + + zkvm_fixed_traces.register_table_circuit::>( + &system_config.zkvm_cs, + &system_config.prog_config, + program, + ); + + system_config + .config + .generate_fixed_traces(&system_config.zkvm_cs, &mut zkvm_fixed_traces); + system_config.mmu_config.generate_fixed_traces( + &system_config.zkvm_cs, + &mut zkvm_fixed_traces, + &init_mem_state.reg, + &init_mem_state.mem, + &init_mem_state.io.iter().map(|rec| rec.addr).collect_vec(), + ); + system_config + .dummy_config + .generate_fixed_traces(&system_config.zkvm_cs, &mut zkvm_fixed_traces); + + zkvm_fixed_traces +} + +pub fn generate_witness( + system_config: &ConstraintSystemConfig, + emul_result: EmulationResult, + program: &Program, +) -> ZKVMWitnesses { + let mut zkvm_witness = ZKVMWitnesses::default(); + // assign opcode circuits + let dummy_records = system_config + .config + .assign_opcode_circuit( + &system_config.zkvm_cs, + &mut zkvm_witness, + emul_result.all_records, + ) + .unwrap(); + system_config + .dummy_config + .assign_opcode_circuit(&system_config.zkvm_cs, &mut zkvm_witness, dummy_records) + .unwrap(); + zkvm_witness.finalize_lk_multiplicities(); + // assign table circuits - config - .assign_table_circuit(&zkvm_cs, &mut zkvm_witness) + system_config + .config + .assign_table_circuit(&system_config.zkvm_cs, &mut zkvm_witness) .unwrap(); - mmu_config + system_config + .mmu_config .assign_table_circuit( - &zkvm_cs, + &system_config.zkvm_cs, &mut zkvm_witness, - ®_final, - &mem_final, - &io_final, - &priv_io_final, + &emul_result.final_mem_state.reg, + &emul_result.final_mem_state.mem, + &emul_result + .final_mem_state + .io + .iter() + .map(|rec| rec.cycle) + .collect_vec(), + &emul_result.final_mem_state.priv_io, ) .unwrap(); // assign program circuit zkvm_witness - .assign_table_circuit::>(&zkvm_cs, &prog_config, vm.program()) + .assign_table_circuit::>( + &system_config.zkvm_cs, + &system_config.prog_config, + program, + ) .unwrap(); + zkvm_witness +} + +// Encodes useful early return points of the e2e pipeline +pub enum Checkpoint { + PrepE2EProving, + PrepWitnessGen, + PrepSanityCheck, + Complete, +} + +// Currently handles state required by the sanity check in `bin/e2e.rs` +// Future cases would require this to be an enum +pub type IntermediateState = (ZKVMProof, ZKVMVerifier); + +// Runs end-to-end pipeline, stopping at a certain checkpoint and yielding useful state. +// +// The return type is a pair of: +// 1. Explicit state +// 2. A no-input-no-ouptut closure +// +// (2.) is useful when you want to setup a certain action and run it +// elsewhere (i.e, in a benchmark) +// (1.) is useful for exposing state which must be further combined with +// state external to this pipeline (e.g, sanity check in bin/e2e.rs) + +#[allow(clippy::type_complexity)] +pub fn run_e2e_with_checkpoint + 'static>( + program: Program, + platform: Platform, + stack_size: u32, + heap_size: u32, + hints: Vec, + max_steps: usize, + checkpoint: Checkpoint, +) -> (Option>, Box) { + // Detect heap as starting after program data. + let heap_start = program.image.keys().max().unwrap() + WORD_SIZE as u32; + let heap_addrs = heap_start..heap_start + heap_size; + let mut mem_padder = MemPadder::new(heap_addrs.end..platform.ram.end); + let mem_init = init_mem(&program, &platform, &mut mem_padder, stack_size, heap_size); + + let program_params = ProgramParams { + platform: platform.clone(), + program_size: program.instructions.len(), + static_memory_len: mem_init.len(), + ..ProgramParams::default() + }; + + let program = Arc::new(program); + let system_config = construct_configs::(program_params); + + // IO is not used in this program, but it must have a particular size at the moment. + let io_init = mem_padder.padded_sorted(system_config.mmu_config.public_io_len(), vec![]); + let reg_init = system_config.mmu_config.initial_registers(); + + let init_full_mem = InitMemState { + mem: mem_init, + reg: reg_init, + io: io_init, + priv_io: vec![], + }; + + // Generate fixed traces + let zkvm_fixed_traces = generate_fixed_traces(&system_config, &init_full_mem, &program); + + // Keygen + let pcs_param = PCS::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup"); + let (pp, vp) = PCS::trim(pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); + let pk = system_config + .zkvm_cs + .clone() + .key_gen::(pp.clone(), vp.clone(), zkvm_fixed_traces.clone()) + .expect("keygen failed"); + let vk = pk.get_vk(); + + if let Checkpoint::PrepE2EProving = checkpoint { + return ( + None, + Box::new(move || { + _ = run_e2e_proof( + program, + max_steps, + init_full_mem, + platform, + hints, + &system_config, + pk, + zkvm_fixed_traces, + ) + }), + ); + } + + // Emulate program + let emul_result = emulate_program(program.clone(), max_steps, init_full_mem, &platform, hints); + + // Clone some emul_result fields before consuming + let pi = emul_result.pi.clone(); + let exit_code = emul_result.exit_code; + + if let Checkpoint::PrepWitnessGen = checkpoint { + return ( + None, + Box::new(move || _ = generate_witness(&system_config, emul_result, program.deref())), + ); + } + + // Generate witness + let zkvm_witness = generate_witness(&system_config, emul_result, &program); + + // proving + let prover = ZKVMProver::new(pk); + if std::env::var("MOCK_PROVING").is_ok() { - MockProver::assert_satisfied_full(zkvm_cs, zkvm_fixed_traces, &zkvm_witness, &pi); + MockProver::assert_satisfied_full( + &system_config.zkvm_cs, + zkvm_fixed_traces.clone(), + &zkvm_witness, + &pi, + ); tracing::info!("Mock proving passed"); } - ( - prover, - verifier, - zkvm_witness, - pi, - cycle_num, - e2e_start, - exit_code, - ) + + // Run proof phase + let transcript = Transcript::new(b"riscv"); + let zkvm_proof = prover + .create_proof(zkvm_witness, pi, transcript) + .expect("create_proof failed"); + + let verifier = ZKVMVerifier::new(vk); + + run_e2e_verify(&verifier, zkvm_proof.clone(), exit_code, max_steps); + + if let Checkpoint::PrepSanityCheck = checkpoint { + return (Some((zkvm_proof, verifier)), Box::new(|| ())); + } + + (None, Box::new(|| ())) } +// Runs program emulation + witness generation + proving +#[allow(clippy::too_many_arguments)] pub fn run_e2e_proof>( - prover: ZKVMProver, - zkvm_witness: ZKVMWitnesses, - pi: PublicValues, + program: Arc, + max_steps: usize, + init_full_mem: InitMemState, + platform: Platform, + hints: Vec, + system_config: &ConstraintSystemConfig, + pk: ZKVMProvingKey, + zkvm_fixed_traces: ZKVMFixedTraces, ) -> ZKVMProof { + // Emulate program + let emul_result = emulate_program(program.clone(), max_steps, init_full_mem, &platform, hints); + + // clone pi before consuming + let pi = emul_result.pi.clone(); + + // Generate witness + let zkvm_witness = generate_witness(system_config, emul_result, program.deref()); + + // proving + let prover = ZKVMProver::new(pk); + + if std::env::var("MOCK_PROVING").is_ok() { + MockProver::assert_satisfied_full( + &system_config.zkvm_cs, + zkvm_fixed_traces.clone(), + &zkvm_witness, + &pi, + ); + tracing::info!("Mock proving passed"); + } + let transcript = Transcript::new(b"riscv"); prover .create_proof(zkvm_witness, pi, transcript) diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs index d88f0336c..76894f7a0 100644 --- a/ceno_zkvm/src/instructions/riscv/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -150,29 +150,31 @@ mod test { #[test] fn test_sltiu_true() { - verify::("lt = true, 0 < 1", 0, 1, 1); - verify::("lt = true, 1 < 2", 1, 2, 1); - verify::("lt = true, 10 < 20", 10, 20, 1); - verify::("lt = true, 0 < imm upper boundary", 0, 2047, 1); + let verify = |name, a, imm| verify::(name, a, imm, true); + verify("lt = true, 0 < 1", 0, 1); + verify("lt = true, 1 < 2", 1, 2); + verify("lt = true, 10 < 20", 10, 20); + verify("lt = true, 0 < imm upper boundary", 0, 2047); // negative imm is treated as positive - verify::("lt = true, 0 < u32::MAX-1", 0, -1, 1); - verify::("lt = true, 1 < u32::MAX-1", 1, -1, 1); - verify::("lt = true, 0 < imm lower bondary", 0, -2048, 1); + verify("lt = true, 0 < u32::MAX-1", 0, -1); + verify("lt = true, 1 < u32::MAX-1", 1, -1); + verify("lt = true, 0 < imm lower bondary", 0, -2048); } #[test] fn test_sltiu_false() { - verify::("lt = false, 1 < 0", 1, 0, 0); - verify::("lt = false, 2 < 1", 2, 1, 0); - verify::("lt = false, 100 < 50", 100, 50, 0); - verify::("lt = false, 500 < 100", 500, 100, 0); - verify::("lt = false, 100000 < 2047", 100000, 2047, 0); - verify::("lt = false, 100000 < 0", 100000, 0, 0); - verify::("lt = false, 0 == 0", 0, 0, 0); - verify::("lt = false, 1 == 1", 1, 1, 0); - verify::("lt = false, imm upper bondary", u32::MAX, 2047, 0); + let verify = |name, a, imm| verify::(name, a, imm, false); + verify("lt = false, 1 < 0", 1, 0); + verify("lt = false, 2 < 1", 2, 1); + verify("lt = false, 100 < 50", 100, 50); + verify("lt = false, 500 < 100", 500, 100); + verify("lt = false, 100000 < 2047", 100000, 2047); + verify("lt = false, 100000 < 0", 100000, 0); + verify("lt = false, 0 == 0", 0, 0); + verify("lt = false, 1 == 1", 1, 1); + verify("lt = false, imm upper bondary", u32::MAX, 2047); // negative imm is treated as positive - verify::("lt = false, imm lower bondary", u32::MAX, -2048, 0); + verify("lt = false, imm lower bondary", u32::MAX, -2048); } #[test] @@ -181,34 +183,36 @@ mod test { let a: u32 = rng.gen::(); let b: i32 = rng.gen_range(-2048..2048); println!("random: {} ("random unsigned comparison", a, b, (a < (b as u32)) as u32); + verify::("random unsigned comparison", a, b, a < (b as u32)); } #[test] fn test_slti_true() { - verify::("lt = true, 0 < 1", 0, 1, 1); - verify::("lt = true, 1 < 2", 1, 2, 1); - verify::("lt = true, -1 < 0", -1i32 as u32, 0, 1); - verify::("lt = true, -1 < 1", -1i32 as u32, 1, 1); - verify::("lt = true, -2 < -1", -2i32 as u32, -1, 1); + let verify = |name, a: i32, imm| verify::(name, a as u32, imm, true); + verify("lt = true, 0 < 1", 0, 1); + verify("lt = true, 1 < 2", 1, 2); + verify("lt = true, -1 < 0", -1, 0); + verify("lt = true, -1 < 1", -1, 1); + verify("lt = true, -2 < -1", -2, -1); // -2048 <= imm <= 2047 - verify::("lt = true, imm upper bondary", i32::MIN as u32, 2047, 1); - verify::("lt = true, imm lower bondary", i32::MIN as u32, -2048, 1); + verify("lt = true, imm upper bondary", i32::MIN, 2047); + verify("lt = true, imm lower bondary", i32::MIN, -2048); } #[test] fn test_slti_false() { - verify::("lt = false, 1 < 0", 1, 0, 0); - verify::("lt = false, 2 < 1", 2, 1, 0); - verify::("lt = false, 0 < -1", 0, -1, 0); - verify::("lt = false, 1 < -1", 1, -1, 0); - verify::("lt = false, -1 < -2", -1i32 as u32, -2, 0); - verify::("lt = false, 0 == 0", 0, 0, 0); - verify::("lt = false, 1 == 1", 1, 1, 0); - verify::("lt = false, -1 == -1", -1i32 as u32, -1, 0); + let verify = |name, a: i32, imm| verify::(name, a as u32, imm, false); + verify("lt = false, 1 < 0", 1, 0); + verify("lt = false, 2 < 1", 2, 1); + verify("lt = false, 0 < -1", 0, -1); + verify("lt = false, 1 < -1", 1, -1); + verify("lt = false, -1 < -2", -1, -2); + verify("lt = false, 0 == 0", 0, 0); + verify("lt = false, 1 == 1", 1, 1); + verify("lt = false, -1 == -1", -1, -1); // -2048 <= imm <= 2047 - verify::("lt = false, imm upper bondary", i32::MAX as u32, 2047, 0); - verify::("lt = false, imm lower bondary", i32::MAX as u32, -2048, 0); + verify("lt = false, imm upper bondary", i32::MAX, 2047); + verify("lt = false, imm lower bondary", i32::MAX, -2048); } #[test] @@ -217,10 +221,11 @@ mod test { let a: i32 = rng.gen(); let b: i32 = rng.gen_range(-2048..2048); println!("random: {} ("random 1", a as u32, b, (a < b) as u32); + verify::("random 1", a as u32, b, a < b); } - fn verify(name: &'static str, rs1_read: u32, imm: i32, expected_rd: u32) { + fn verify(name: &'static str, rs1_read: u32, imm: i32, expected_rd: bool) { + let expected_rd = expected_rd as u32; let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 70b543090..ccd8e0a07 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -746,7 +746,7 @@ Hints: } pub fn assert_satisfied_full( - cs: ZKVMConstraintSystem, + cs: &ZKVMConstraintSystem, mut fixed_trace: ZKVMFixedTraces, witnesses: &ZKVMWitnesses, pi: &PublicValues, diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 2d31e656f..b0ff7b656 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -18,7 +18,7 @@ use sumcheck::{ macros::{entered_span, exit_span}, structs::{IOPProverMessage, IOPProverStateV2}, }; -use transcript::Transcript; +use transcript::{ForkableTranscript, Transcript}; use crate::{ circuit_builder::SetTableAddrType, @@ -62,7 +62,7 @@ impl> ZKVMProver { &self, witnesses: ZKVMWitnesses, pi: PublicValues, - mut transcript: Transcript, + mut transcript: impl ForkableTranscript, ) -> Result, ZKVMError> { let span = entered_span!("commit_to_fixed_commit", profiling_1 = true); let mut vm_proof = ZKVMProof::empty(pi); @@ -219,7 +219,7 @@ impl> ZKVMProver { wits_commit: PCS::CommitmentWithData, pi: &[ArcMultilinearExtension<'_, E>], num_instances: usize, - transcript: &mut Transcript, + transcript: &mut impl Transcript, challenges: &[E; 2], ) -> Result, ZKVMError> { let cs = circuit_pk.get_cs(); @@ -663,7 +663,7 @@ impl> ZKVMProver { witnesses: Vec>, wits_commit: PCS::CommitmentWithData, pi: &[ArcMultilinearExtension<'_, E>], - transcript: &mut Transcript, + transcript: &mut impl Transcript, challenges: &[E; 2], ) -> Result, ZKVMError> { let cs = circuit_pk.get_cs(); @@ -1152,7 +1152,7 @@ impl TowerProver { prod_specs: Vec>, logup_specs: Vec>, num_fanin: usize, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> (Point, TowerProofs) { // XXX to sumcheck batched product argument with logup, we limit num_product_fanin to 2 // TODO mayber give a better naming? diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 260c17fae..13ef29a66 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -14,7 +14,7 @@ use mpcs::{Basefold, BasefoldDefault, BasefoldRSParams, PolynomialCommitmentSche use multilinear_extensions::{ mle::IntoMLE, util::ceil_log2, virtual_poly_v2::ArcMultilinearExtension, }; -use transcript::Transcript; +use transcript::{BasicTranscript, Transcript}; use crate::{ circuit_builder::CircuitBuilder, @@ -126,7 +126,7 @@ fn test_rw_lk_expression_combination() { // get proof let prover = ZKVMProver::new(pk); - let mut transcript = Transcript::new(b"test"); + let mut transcript = BasicTranscript::new(b"test"); let wits_in = zkvm_witness .into_iter_sorted() .next() @@ -157,7 +157,7 @@ fn test_rw_lk_expression_combination() { // verify proof let verifier = ZKVMVerifier::new(vk.clone()); - let mut v_transcript = Transcript::new(b"test"); + let mut v_transcript = BasicTranscript::new(b"test"); // write commitment into transcript and derive challenges from it Pcs::write_commitment(&proof.wits_commit, &mut v_transcript).unwrap(); let verifier_challenges = [ @@ -261,7 +261,7 @@ fn test_single_add_instance_e2e() { let vk = pk.get_vk(); // single instance - let mut vm = VMState::new(CENO_PLATFORM, program.clone()); + let mut vm = VMState::new(CENO_PLATFORM, program.clone().into()); let all_records = vm .iter_until_halt() .collect::, _>>() @@ -305,12 +305,12 @@ fn test_single_add_instance_e2e() { .unwrap(); let pi = PublicValues::new(0, 0, 0, 0, 0, vec![0]); - let transcript = Transcript::new(b"riscv"); + let transcript = BasicTranscript::new(b"riscv"); let zkvm_proof = prover .create_proof(zkvm_witness, pi, transcript) .expect("create_proof failed"); - let transcript = Transcript::new(b"riscv"); + let transcript = BasicTranscript::new(b"riscv"); assert!( verifier .verify_proof(zkvm_proof, transcript) @@ -325,7 +325,7 @@ fn test_tower_proof_various_prod_size() { let num_vars = ceil_log2(leaf_layer_size); let mut rng = test_rng(); type E = GoldilocksExt2; - let mut transcript = Transcript::new(b"test_tower_proof"); + let mut transcript = BasicTranscript::new(b"test_tower_proof"); let leaf_layer: ArcMultilinearExtension = (0..leaf_layer_size) .map(|_| E::random(&mut rng)) .collect_vec() @@ -348,7 +348,7 @@ fn test_tower_proof_various_prod_size() { &mut transcript, ); - let mut transcript = Transcript::new(b"test_tower_proof"); + let mut transcript = BasicTranscript::new(b"test_tower_proof"); let (rt_tower_v, prod_point_and_eval, _, _) = TowerVerify::verify( vec![ layers[0] diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index c9593f24f..e3cb38b2d 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -12,7 +12,7 @@ use multilinear_extensions::{ virtual_poly::{VPAuxInfo, build_eq_x_r_vec_sequential, eq_eval}, }; use sumcheck::structs::{IOPProof, IOPVerifierState}; -use transcript::Transcript; +use transcript::{ForkableTranscript, Transcript}; use crate::{ circuit_builder::SetTableAddrType, @@ -48,7 +48,7 @@ impl> ZKVMVerifier pub fn verify_proof( &self, vm_proof: ZKVMProof, - transcript: Transcript, + transcript: impl ForkableTranscript, ) -> Result { self.verify_proof_halt(vm_proof, transcript, true) } @@ -57,7 +57,7 @@ impl> ZKVMVerifier pub fn verify_proof_halt( &self, vm_proof: ZKVMProof, - transcript: Transcript, + transcript: impl ForkableTranscript, does_halt: bool, ) -> Result { // require ecall/halt proof to exist, depending whether we expect a halt. @@ -79,7 +79,7 @@ impl> ZKVMVerifier fn verify_proof_validity( &self, vm_proof: ZKVMProof, - mut transcript: Transcript, + mut transcript: impl ForkableTranscript, ) -> Result { // main invariant between opcode circuits and table circuits let mut prod_r = E::ONE; @@ -255,7 +255,7 @@ impl> ZKVMVerifier circuit_vk: &VerifyingKey, proof: &ZKVMOpcodeProof, pi: &[E], - transcript: &mut Transcript, + transcript: &mut impl Transcript, num_product_fanin: usize, _out_evals: &PointAndEval, challenges: &[E; 2], // derive challenge from PCS @@ -504,7 +504,7 @@ impl> ZKVMVerifier proof: &ZKVMTableProof, raw_pi: &[Vec], pi: &[E], - transcript: &mut Transcript, + transcript: &mut impl Transcript, num_logup_fanin: usize, _out_evals: &PointAndEval, challenges: &[E; 2], @@ -812,7 +812,7 @@ impl TowerVerify { tower_proofs: &TowerProofs, num_variables: Vec, num_fanin: usize, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> TowerVerifyResult { // XXX to sumcheck batched product argument with logup, we limit num_product_fanin to 2 // TODO mayber give a better naming? diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index b27f04e54..892babb3f 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -14,7 +14,7 @@ mod ops; pub use ops::*; mod program; -pub use program::{InsnRecord, ProgramTableCircuit}; +pub use program::{InsnRecord, ProgramTableCircuit, ProgramTableConfig}; mod ram; pub use ram::*; diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index f1ca5856b..8b7d8cbde 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -61,7 +61,7 @@ pub(crate) fn add_one_to_big_num(limb_modulo: F, limbs: &[F]) -> Vec( size: usize, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Vec { // println!("alpha_pow"); let alpha = transcript diff --git a/ceno_zkvm/src/virtual_polys.rs b/ceno_zkvm/src/virtual_polys.rs index a3d8064c6..b3d32eab3 100644 --- a/ceno_zkvm/src/virtual_polys.rs +++ b/ceno_zkvm/src/virtual_polys.rs @@ -177,7 +177,7 @@ mod tests { virtual_poly_v2::{ArcMultilinearExtension, VirtualPolynomialV2}, }; use sumcheck::structs::{IOPProverStateV2, IOPVerifierState}; - use transcript::Transcript; + use transcript::BasicTranscript as Transcript; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, diff --git a/mpcs/benches/basefold.rs b/mpcs/benches/basefold.rs index 965d55035..19fb7aa03 100644 --- a/mpcs/benches/basefold.rs +++ b/mpcs/benches/basefold.rs @@ -18,11 +18,11 @@ use multilinear_extensions::{ mle::{DenseMultilinearExtension, MultilinearExtension}, virtual_poly_v2::ArcMultilinearExtension, }; -use transcript::Transcript; +use transcript::{BasicTranscript, Transcript}; type PcsGoldilocksRSCode = Basefold; type PcsGoldilocksBasecode = Basefold; -type T = Transcript; +type T = BasicTranscript; type E = GoldilocksExt2; const NUM_SAMPLES: usize = 10; @@ -292,7 +292,7 @@ fn bench_simple_batch_commit_open_verify_goldilocks, + transcript: &mut impl Transcript, ) -> Result<(), Error> { write_digest_to_transcript(&comm.root(), transcript); Ok(()) @@ -470,7 +470,7 @@ where comm: &Self::CommitmentWithData, point: &[E], _eval: &E, // Opening does not need eval, except for sanity check - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Result { let timer = start_timer!(|| "Basefold::open"); @@ -550,7 +550,7 @@ where comms: &[Self::CommitmentWithData], points: &[Vec], evals: &[Evaluation], - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Result { let timer = start_timer!(|| "Basefold::batch_open"); let num_vars = polys.iter().map(|poly| poly.num_vars).max().unwrap(); @@ -772,7 +772,7 @@ where comm: &Self::CommitmentWithData, point: &[E], evals: &[E], - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Result { let timer = start_timer!(|| "Basefold::batch_open"); let num_vars = polys[0].num_vars(); @@ -858,7 +858,7 @@ where point: &[E], eval: &E, proof: &Self::Proof, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Result<(), Error> { let timer = start_timer!(|| "Basefold::verify"); @@ -944,7 +944,7 @@ where points: &[Vec], evals: &[Evaluation], proof: &Self::Proof, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Result<(), Error> { let timer = start_timer!(|| "Basefold::batch_verify"); // let key = "RAYON_NUM_THREADS"; @@ -1071,7 +1071,7 @@ where point: &[E], evals: &[E], proof: &Self::Proof, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Result<(), Error> { let timer = start_timer!(|| "Basefold::simple batch verify"); let batch_size = evals.len(); diff --git a/mpcs/src/basefold/commit_phase.rs b/mpcs/src/basefold/commit_phase.rs index fe6b2b1e9..7d6e0b949 100644 --- a/mpcs/src/basefold/commit_phase.rs +++ b/mpcs/src/basefold/commit_phase.rs @@ -34,7 +34,7 @@ pub fn commit_phase>( pp: &>::ProverParameters, point: &[E], comm: &BasefoldCommitmentWithData, - transcript: &mut Transcript, + transcript: &mut impl Transcript, num_vars: usize, num_rounds: usize, ) -> (Vec>, BasefoldCommitPhaseProof) @@ -180,7 +180,7 @@ pub fn batch_commit_phase>( pp: &>::ProverParameters, point: &[E], comms: &[BasefoldCommitmentWithData], - transcript: &mut Transcript, + transcript: &mut impl Transcript, num_vars: usize, num_rounds: usize, coeffs: &[E], @@ -351,7 +351,7 @@ pub fn simple_batch_commit_phase>( point: &[E], batch_coeffs: &[E], comm: &BasefoldCommitmentWithData, - transcript: &mut Transcript, + transcript: &mut impl Transcript, num_vars: usize, num_rounds: usize, ) -> (Vec>, BasefoldCommitPhaseProof) diff --git a/mpcs/src/basefold/query_phase.rs b/mpcs/src/basefold/query_phase.rs index 946214c68..f0db40890 100644 --- a/mpcs/src/basefold/query_phase.rs +++ b/mpcs/src/basefold/query_phase.rs @@ -29,7 +29,7 @@ use super::{ }; pub fn prover_query_phase( - transcript: &mut Transcript, + transcript: &mut impl Transcript, comm: &BasefoldCommitmentWithData, trees: &[MerkleTree], num_verifier_queries: usize, @@ -65,7 +65,7 @@ where } pub fn batch_prover_query_phase( - transcript: &mut Transcript, + transcript: &mut impl Transcript, codeword_size: usize, comms: &[BasefoldCommitmentWithData], trees: &[MerkleTree], @@ -102,7 +102,7 @@ where } pub fn simple_batch_prover_query_phase( - transcript: &mut Transcript, + transcript: &mut impl Transcript, comm: &BasefoldCommitmentWithData, trees: &[MerkleTree], num_verifier_queries: usize, diff --git a/mpcs/src/lib.rs b/mpcs/src/lib.rs index b6716c19e..e9c67b866 100644 --- a/mpcs/src/lib.rs +++ b/mpcs/src/lib.rs @@ -4,7 +4,7 @@ use itertools::Itertools; use multilinear_extensions::mle::DenseMultilinearExtension; use serde::{Serialize, de::DeserializeOwned}; use std::fmt::Debug; -use transcript::Transcript; +use transcript::{BasicTranscript, Transcript}; use util::hash::Digest; pub mod sum_check; @@ -41,7 +41,7 @@ pub fn pcs_commit>( pub fn pcs_commit_and_write>( pp: &Pcs::ProverParam, poly: &DenseMultilinearExtension, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Result { Pcs::commit_and_write(pp, poly, transcript) } @@ -56,7 +56,7 @@ pub fn pcs_batch_commit>( pub fn pcs_batch_commit_and_write>( pp: &Pcs::ProverParam, polys: &[DenseMultilinearExtension], - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Result { Pcs::batch_commit_and_write(pp, polys, transcript) } @@ -67,7 +67,7 @@ pub fn pcs_open>( comm: &Pcs::CommitmentWithData, point: &[E], eval: &E, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Result { Pcs::open(pp, poly, comm, point, eval, transcript) } @@ -78,7 +78,7 @@ pub fn pcs_batch_open>( comms: &[Pcs::CommitmentWithData], points: &[Vec], evals: &[Evaluation], - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Result { Pcs::batch_open(pp, polys, comms, points, evals, transcript) } @@ -89,7 +89,7 @@ pub fn pcs_verify>( point: &[E], eval: &E, proof: &Pcs::Proof, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Result<(), Error> { Pcs::verify(vp, comm, point, eval, proof, transcript) } @@ -100,7 +100,7 @@ pub fn pcs_batch_verify<'a, E: ExtensionField, Pcs: PolynomialCommitmentScheme], evals: &[Evaluation], proof: &Pcs::Proof, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Result<(), Error> where Pcs::Commitment: 'a, @@ -132,7 +132,7 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { fn commit_and_write( pp: &Self::ProverParam, poly: &DenseMultilinearExtension, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Result { let comm = Self::commit(pp, poly)?; Self::write_commitment(&Self::get_pure_commitment(&comm), transcript)?; @@ -141,7 +141,7 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { fn write_commitment( comm: &Self::Commitment, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Result<(), Error>; fn get_pure_commitment(comm: &Self::CommitmentWithData) -> Self::Commitment; @@ -154,7 +154,7 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { fn batch_commit_and_write( pp: &Self::ProverParam, polys: &[DenseMultilinearExtension], - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Result { let comm = Self::batch_commit(pp, polys)?; Self::write_commitment(&Self::get_pure_commitment(&comm), transcript)?; @@ -167,7 +167,7 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { comm: &Self::CommitmentWithData, point: &[E], eval: &E, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Result; fn batch_open( @@ -176,7 +176,7 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { comms: &[Self::CommitmentWithData], points: &[Vec], evals: &[Evaluation], - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Result; /// This is a simple version of batch open: @@ -189,7 +189,7 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { comm: &Self::CommitmentWithData, point: &[E], evals: &[E], - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Result; fn verify( @@ -198,7 +198,7 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { point: &[E], eval: &E, proof: &Self::Proof, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Result<(), Error>; fn batch_verify( @@ -207,7 +207,7 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { points: &[Vec], evals: &[Evaluation], proof: &Self::Proof, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Result<(), Error>; fn simple_batch_verify( @@ -216,7 +216,7 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { point: &[E], evals: &[E], proof: &Self::Proof, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Result<(), Error>; } @@ -232,7 +232,7 @@ where point: &[E], eval: &E, ) -> Result { - let mut transcript = Transcript::::new(b"BaseFold"); + let mut transcript = BasicTranscript::::new(b"BaseFold"); Self::open(pp, poly, comm, point, eval, &mut transcript) } @@ -243,7 +243,7 @@ where points: &[Vec], evals: &[Evaluation], ) -> Result { - let mut transcript = Transcript::::new(b"BaseFold"); + let mut transcript = BasicTranscript::::new(b"BaseFold"); Self::batch_open(pp, polys, comms, points, evals, &mut transcript) } @@ -254,7 +254,7 @@ where eval: &E, proof: &Self::Proof, ) -> Result<(), Error> { - let mut transcript = Transcript::::new(b"BaseFold"); + let mut transcript = BasicTranscript::::new(b"BaseFold"); Self::verify(vp, comm, point, eval, proof, &mut transcript) } @@ -268,7 +268,7 @@ where where Self::Commitment: 'a, { - let mut transcript = Transcript::::new(b"BaseFold"); + let mut transcript = BasicTranscript::::new(b"BaseFold"); Self::batch_verify(vp, comms, points, evals, proof, &mut transcript) } } @@ -379,6 +379,8 @@ pub mod test_util { mle::MultilinearExtension, virtual_poly_v2::ArcMultilinearExtension, }; use rand::rngs::OsRng; + #[cfg(test)] + use transcript::BasicTranscript; use transcript::Transcript; pub fn setup_pcs>( @@ -414,7 +416,7 @@ pub mod test_util { pub fn get_point_from_challenge( num_vars: usize, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Vec { (0..num_vars) .map(|_| transcript.get_and_append_challenge(b"Point").elements) @@ -423,7 +425,7 @@ pub mod test_util { pub fn get_points_from_challenge( num_vars: impl Fn(usize) -> usize, num_points: usize, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Vec> { (0..num_points) .map(|i| get_point_from_challenge(num_vars(i), transcript)) @@ -433,7 +435,7 @@ pub mod test_util { pub fn commit_polys_individually>( pp: &Pcs::ProverParam, polys: &[DenseMultilinearExtension], - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Vec { polys .iter() @@ -454,7 +456,7 @@ pub mod test_util { // Commit and open let (comm, eval, proof, challenge) = { - let mut transcript = Transcript::new(b"BaseFold"); + let mut transcript = BasicTranscript::new(b"BaseFold"); let poly = gen_rand_poly(num_vars); let comm = Pcs::commit_and_write(&pp, &poly, &mut transcript).unwrap(); let point = get_point_from_challenge(num_vars, &mut transcript); @@ -470,7 +472,7 @@ pub mod test_util { }; // Verify { - let mut transcript = Transcript::new(b"BaseFold"); + let mut transcript = BasicTranscript::new(b"BaseFold"); Pcs::write_commitment(&comm, &mut transcript).unwrap(); let point = get_point_from_challenge(num_vars, &mut transcript); transcript.append_field_element_ext(&eval); @@ -505,7 +507,7 @@ pub mod test_util { .collect_vec(); let (comms, evals, proof, challenge) = { - let mut transcript = Transcript::new(b"BaseFold"); + let mut transcript = BasicTranscript::new(b"BaseFold"); let polys = gen_rand_polys(|i| num_vars - (i >> 1), batch_size, gen_rand_poly); let comms = @@ -536,7 +538,7 @@ pub mod test_util { }; // Batch verify { - let mut transcript = Transcript::new(b"BaseFold"); + let mut transcript = BasicTranscript::new(b"BaseFold"); let comms = comms .iter() .map(|comm| { @@ -577,7 +579,7 @@ pub mod test_util { let (pp, vp) = setup_pcs::(num_vars); let (comm, evals, proof, challenge) = { - let mut transcript = Transcript::new(b"BaseFold"); + let mut transcript = BasicTranscript::new(b"BaseFold"); let polys = gen_rand_polys(|_| num_vars, batch_size, gen_rand_poly); let comm = Pcs::batch_commit_and_write(&pp, polys.as_slice(), &mut transcript).unwrap(); @@ -601,7 +603,7 @@ pub mod test_util { }; // Batch verify { - let mut transcript = Transcript::new(b"BaseFold"); + let mut transcript = BasicTranscript::new(b"BaseFold"); Pcs::write_commitment(&comm, &mut transcript).unwrap(); let point = get_point_from_challenge(num_vars, &mut transcript); diff --git a/mpcs/src/sum_check.rs b/mpcs/src/sum_check.rs index cd42360df..f2fcf0e47 100644 --- a/mpcs/src/sum_check.rs +++ b/mpcs/src/sum_check.rs @@ -62,7 +62,7 @@ where num_vars: usize, virtual_poly: VirtualPolynomial, sum: E, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Result, Error>; fn verify( @@ -71,7 +71,7 @@ where degree: usize, sum: E, proof: &SumcheckProof, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Result<(E, Vec), Error>; } diff --git a/mpcs/src/sum_check/classic.rs b/mpcs/src/sum_check/classic.rs index 6d7a3e7a8..f99d832df 100644 --- a/mpcs/src/sum_check/classic.rs +++ b/mpcs/src/sum_check/classic.rs @@ -186,7 +186,7 @@ pub trait ClassicSumCheckProver: Clone + Debug { pub trait ClassicSumCheckRoundMessage: Sized + Debug { type Auxiliary: Default; - fn write(&self, transcript: &mut Transcript) -> Result<(), Error>; + fn write(&self, transcript: &mut impl Transcript) -> Result<(), Error>; fn sum(&self) -> E; @@ -234,7 +234,7 @@ where num_vars: usize, virtual_poly: VirtualPolynomial, sum: E, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Result<(Vec, Vec, SumcheckProof), Error> { let _timer = start_timer!(|| { let degree = virtual_poly.expression.degree(); @@ -290,7 +290,7 @@ where degree: usize, sum: E, proof: &SumcheckProof, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Result<(E, Vec), Error> { let (msgs, challenges) = { let mut msgs = Vec::with_capacity(num_vars); @@ -321,7 +321,7 @@ mod tests { sum_check::eq_xy_eval, util::{arithmetic::inner_product, expression::Query, poly_iter_ext}, }; - use transcript::Transcript; + use transcript::BasicTranscript; use super::*; use goldilocks::{Goldilocks as Fr, GoldilocksExt2 as E}; @@ -368,7 +368,7 @@ mod tests { &build_eq_x_r_vec(&points[1]), ) * Fr::from(4) * Fr::from(2); // The third polynomial is summed twice because the hypercube is larger - let mut transcript = Transcript::::new(b"sumcheck"); + let mut transcript = BasicTranscript::::new(b"sumcheck"); let (challenges, evals, proof) = > as SumCheck>::prove( &(), @@ -383,7 +383,7 @@ mod tests { assert_eq!(polys[1].evaluate(&challenges), evals[1]); assert_eq!(polys[2].evaluate(&challenges[..1]), evals[2]); - let mut transcript = Transcript::::new(b"sumcheck"); + let mut transcript = BasicTranscript::::new(b"sumcheck"); let (new_sum, verifier_challenges) = > as SumCheck< E, @@ -400,7 +400,7 @@ mod tests { + evals[2] * eq_xy_eval(&points[1], &challenges[..1]) * Fr::from(4) ); - let mut transcript = Transcript::::new(b"sumcheck"); + let mut transcript = BasicTranscript::::new(b"sumcheck"); > as SumCheck>::verify( &(), diff --git a/mpcs/src/sum_check/classic/coeff.rs b/mpcs/src/sum_check/classic/coeff.rs index a6107bfe7..12f46880f 100644 --- a/mpcs/src/sum_check/classic/coeff.rs +++ b/mpcs/src/sum_check/classic/coeff.rs @@ -37,7 +37,7 @@ pub struct Coefficients(FieldType); impl ClassicSumCheckRoundMessage for Coefficients { type Auxiliary = (); - fn write(&self, transcript: &mut Transcript) -> Result<(), Error> { + fn write(&self, transcript: &mut impl Transcript) -> Result<(), Error> { match &self.0 { FieldType::Ext(coeffs) => transcript.append_field_element_exts(coeffs), FieldType::Base(coeffs) => coeffs diff --git a/mpcs/src/util/hash.rs b/mpcs/src/util/hash.rs index 754768140..499a053b5 100644 --- a/mpcs/src/util/hash.rs +++ b/mpcs/src/util/hash.rs @@ -9,7 +9,7 @@ use poseidon::poseidon::Poseidon; pub fn write_digest_to_transcript( digest: &Digest, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) { digest .0 diff --git a/mpcs/src/util/merkle_tree.rs b/mpcs/src/util/merkle_tree.rs index 66f896710..d24840496 100644 --- a/mpcs/src/util/merkle_tree.rs +++ b/mpcs/src/util/merkle_tree.rs @@ -182,7 +182,7 @@ where self.inner.iter() } - pub fn write_transcript(&self, transcript: &mut Transcript) { + pub fn write_transcript(&self, transcript: &mut impl Transcript) { self.inner .iter() .for_each(|hash| write_digest_to_transcript(hash, transcript)); diff --git a/sumcheck/benches/devirgo_sumcheck.rs b/sumcheck/benches/devirgo_sumcheck.rs index 3a101fe53..757106529 100644 --- a/sumcheck/benches/devirgo_sumcheck.rs +++ b/sumcheck/benches/devirgo_sumcheck.rs @@ -16,7 +16,7 @@ use multilinear_extensions::{ util::max_usable_threads, virtual_poly_v2::{ArcMultilinearExtension, VirtualPolynomialV2 as VirtualPolynomial}, }; -use transcript::Transcript; +use transcript::BasicTranscript as Transcript; criterion_group!(benches, sumcheck_fn, devirgo_sumcheck_fn,); criterion_main!(benches); diff --git a/sumcheck/src/prover.rs b/sumcheck/src/prover.rs index f27dd435c..7cee2169b 100644 --- a/sumcheck/src/prover.rs +++ b/sumcheck/src/prover.rs @@ -33,7 +33,7 @@ impl IOPProverState { pub fn prove_batch_polys( max_thread_id: usize, mut polys: Vec>, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> (IOPProof, IOPProverState) { assert!(!polys.is_empty()); assert_eq!(polys.len(), max_thread_id); @@ -476,7 +476,7 @@ impl IOPProverState { #[tracing::instrument(skip_all, name = "sumcheck::prove_parallel")] pub fn prove_parallel( poly: VirtualPolynomial, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> (IOPProof, IOPProverState) { let (num_variables, max_degree) = (poly.aux_info.num_variables, poly.aux_info.max_degree); diff --git a/sumcheck/src/prover_v2.rs b/sumcheck/src/prover_v2.rs index 4e38c0911..3dae3a101 100644 --- a/sumcheck/src/prover_v2.rs +++ b/sumcheck/src/prover_v2.rs @@ -39,7 +39,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { pub fn prove_batch_polys( max_thread_id: usize, mut polys: Vec>, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> (IOPProof, IOPProverStateV2<'a, E>) { assert!(!polys.is_empty()); assert_eq!(polys.len(), max_thread_id); @@ -604,7 +604,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { #[tracing::instrument(skip_all, name = "sumcheck::prove_parallel")] pub fn prove_parallel( poly: VirtualPolynomialV2<'a, E>, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> (IOPProof, IOPProverStateV2<'a, E>) { let (num_variables, max_degree) = (poly.aux_info.max_num_variables, poly.aux_info.max_degree); diff --git a/sumcheck/src/test.rs b/sumcheck/src/test.rs index d0ea28e4f..0c24e83cd 100644 --- a/sumcheck/src/test.rs +++ b/sumcheck/src/test.rs @@ -6,7 +6,7 @@ use ff_ext::ExtensionField; use goldilocks::GoldilocksExt2; use multilinear_extensions::{mle::MultilinearExtension, virtual_poly::VirtualPolynomial}; use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; -use transcript::Transcript; +use transcript::{BasicTranscript, Transcript}; use crate::{ structs::{IOPProverState, IOPVerifierState}, @@ -21,7 +21,7 @@ fn test_sumcheck( num_products: usize, ) { let mut rng = test_rng(); - let mut transcript = Transcript::new(b"test"); + let mut transcript = BasicTranscript::new(b"test"); let (poly, asserted_sum) = VirtualPolynomial::::random(nv, num_multiplicands_range, num_products, &mut rng); @@ -29,7 +29,7 @@ fn test_sumcheck( #[allow(deprecated)] let (proof, _) = IOPProverState::::prove_parallel(poly.clone(), &mut transcript); - let mut transcript = Transcript::new(b"test"); + let mut transcript = BasicTranscript::new(b"test"); let subclaim = IOPVerifierState::::verify(asserted_sum, &proof, &poly_info, &mut transcript); assert!( poly.evaluate( @@ -58,7 +58,7 @@ fn test_sumcheck_internal( let mut verifier_state = IOPVerifierState::verifier_init(&poly_info); let mut challenge = None; - let mut transcript = Transcript::new(b"test"); + let mut transcript = BasicTranscript::new(b"test"); transcript.append_message(b"initializing transcript for testing"); @@ -145,7 +145,7 @@ fn test_extract_sum() { fn test_extract_sum_helper() { let mut rng = test_rng(); - let mut transcript = Transcript::::new(b"test"); + let mut transcript = BasicTranscript::::new(b"test"); let (poly, asserted_sum) = VirtualPolynomial::::random(8, (2, 3), 3, &mut rng); #[allow(deprecated)] let (proof, _) = IOPProverState::::prove_parallel(poly, &mut transcript); diff --git a/sumcheck/src/verifier.rs b/sumcheck/src/verifier.rs index f67c09103..4dcd8767e 100644 --- a/sumcheck/src/verifier.rs +++ b/sumcheck/src/verifier.rs @@ -13,7 +13,7 @@ impl IOPVerifierState { claimed_sum: E, proof: &IOPProof, aux_info: &VPAuxInfo, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> SumCheckSubClaim { if aux_info.num_variables == 0 { return SumCheckSubClaim { @@ -66,7 +66,7 @@ impl IOPVerifierState { pub(crate) fn verify_round_and_update_state( &mut self, prover_msg: &IOPProverMessage, - transcript: &mut Transcript, + transcript: &mut impl Transcript, ) -> Challenge { let start = start_timer!(|| format!("sum check verify {}-th round and update state", self.round)); diff --git a/transcript/src/basic.rs b/transcript/src/basic.rs index b38dfc27a..47c32cdf4 100644 --- a/transcript/src/basic.rs +++ b/transcript/src/basic.rs @@ -3,14 +3,14 @@ use ff_ext::ExtensionField; use goldilocks::SmallField; use poseidon::poseidon_permutation::PoseidonPermutation; -use crate::Challenge; +use crate::{Challenge, ForkableTranscript, Transcript}; #[derive(Clone)] -pub struct Transcript { +pub struct BasicTranscript { permutation: PoseidonPermutation, } -impl Transcript { +impl BasicTranscript { /// Create a new IOP transcript. pub fn new(label: &'static [u8]) -> Self { let mut perm = PoseidonPermutation::new(core::iter::repeat(E::BaseField::ZERO)); @@ -21,96 +21,43 @@ impl Transcript { } } -impl Transcript { - /// Fork this transcript into n different threads. - pub fn fork(self, n: usize) -> Vec { - let mut forks = Vec::with_capacity(n); - for i in 0..n { - let mut fork = self.clone(); - fork.append_field_element(&(i as u64).into()); - forks.push(fork); - } - forks - } - - // Append the message to the transcript. - pub fn append_message(&mut self, msg: &[u8]) { - let msg_f = E::BaseField::bytes_to_field_elements(msg); - self.permutation.set_from_slice(&msg_f, 0); +impl Transcript for BasicTranscript { + fn append_field_elements(&mut self, elements: &[E::BaseField]) { + self.permutation.set_from_slice(elements, 0); self.permutation.permute(); } - // Append the field extension element to the transcript. - pub fn append_field_element_ext(&mut self, element: &E) { - self.permutation.set_from_slice(element.as_bases(), 0); - self.permutation.permute(); + fn append_field_element_ext(&mut self, element: &E) { + self.append_field_elements(element.as_bases()) } - pub fn append_field_element_exts(&mut self, element: &[E]) { - for e in element { - self.append_field_element_ext(e); - } - } - - // Append the field elemetn to the transcript. - pub fn append_field_element(&mut self, element: &E::BaseField) { - self.permutation.set_from_slice(&[*element], 0); - self.permutation.permute(); - } - - // Append the challenge to the transcript. - pub fn append_challenge(&mut self, challenge: Challenge) { - self.permutation - .set_from_slice(challenge.elements.as_bases(), 0); - self.permutation.permute(); - } - - // // Append the message to the transcript. - // pub fn append_serializable_element( - // &mut self, - // _label: &'static [u8], - // _element: &S, - // ) { - // unimplemented!() - // } - - // Generate the challenge from the current transcript - // and append it to the transcript. - // - // The output field element is statistical uniform as long - // as the field has a size less than 2^384. - pub fn get_and_append_challenge(&mut self, label: &'static [u8]) -> Challenge { - self.append_message(label); - - let challenge = Challenge { - elements: E::from_limbs(self.permutation.squeeze()), - }; - challenge - } + fn read_challenge(&mut self) -> Challenge { + // notice `from_bases` and `from_limbs` has the same behavior but + // `from_bases` has a sanity check for length of input slices + // while `from_limbs` use the first two fields silently + // we select `from_base` here to make it being more clear that + // we only use the first 2 fields here to construct the + // extension field (i.e. the challenge) + let r = E::from_bases(&self.permutation.squeeze()[..2]); - pub fn commit_rolling(&mut self) { - // do nothing + Challenge { elements: r } } - pub fn read_field_element_ext(&self) -> E { + fn read_field_element_exts(&self) -> Vec { unimplemented!() } - pub fn read_field_element_exts(&self) -> Vec { + fn read_field_element(&self) -> E::BaseField { unimplemented!() } - pub fn read_field_element(&self) -> E::BaseField { + fn send_challenge(&self, _challenge: E) { unimplemented!() } - pub fn read_challenge(&mut self) -> Challenge { - let r = E::from_bases(&self.permutation.squeeze()[..2]); - - Challenge { elements: r } - } - - pub fn send_challenge(&self, _challenge: E) { - unimplemented!() + fn commit_rolling(&mut self) { + // do nothing } } + +impl ForkableTranscript for BasicTranscript {} diff --git a/transcript/src/lib.rs b/transcript/src/lib.rs index 376bdb1ca..3c8c56038 100644 --- a/transcript/src/lib.rs +++ b/transcript/src/lib.rs @@ -5,7 +5,7 @@ pub mod basic; pub mod syncronized; -pub use basic::Transcript; +pub use basic::BasicTranscript; pub use syncronized::TranscriptSyncronized; mod hasher; @@ -14,3 +14,85 @@ mod hasher; pub struct Challenge { pub elements: F, } + +use ff_ext::ExtensionField; +use goldilocks::SmallField; +/// The Transcript trait +pub trait Transcript { + /// Append slice of base field elemets to the transcript. Implement + /// has to override at least one of append_field_elements / append_field_element + fn append_field_elements(&mut self, elements: &[E::BaseField]) { + for e in elements { + self.append_field_element(e); + } + } + + // Append a single field element to the transcript. Implement + /// has to override at least one of append_field_elements / append_field_element + fn append_field_element(&mut self, element: &E::BaseField) { + self.append_field_elements(&[*element]) + } + + /// Append the message to the transcript. + fn append_message(&mut self, msg: &[u8]) { + let msg_f = E::BaseField::bytes_to_field_elements(msg); + self.append_field_elements(&msg_f); + } + + /// Append the field extension element to the transcript.Implement + /// has to override at least one of append_field_element_ext / append_field_element_exts + fn append_field_element_ext(&mut self, element: &E) { + self.append_field_element_exts(&[*element]) + } + + /// Append slice of field extension elements to the transcript. Implement + /// has to override at least one of append_field_element_ext / append_field_element_exts + fn append_field_element_exts(&mut self, element: &[E]) { + for e in element { + self.append_field_element_ext(e); + } + } + + /// Append the challenge to the transcript. + fn append_challenge(&mut self, challenge: Challenge) { + self.append_field_element_ext(&challenge.elements) + } + + /// Generate the challenge from the current transcript + /// and append it to the transcript. + /// + /// The output field element is statistical uniform as long + /// as the field has a size less than 2^384. + fn get_and_append_challenge(&mut self, label: &'static [u8]) -> Challenge { + self.append_message(label); + self.read_challenge() + } + + fn read_field_element_ext(&self) -> E { + self.read_field_element_exts()[0] + } + + fn read_field_element_exts(&self) -> Vec; + + fn read_field_element(&self) -> E::BaseField; + + fn read_challenge(&mut self) -> Challenge; + + fn send_challenge(&self, challenge: E); + + fn commit_rolling(&mut self); +} + +/// Forkable Transcript trait, enable fork method +pub trait ForkableTranscript: Transcript + Sized + Clone { + /// Fork this transcript into n different threads. + fn fork(self, n: usize) -> Vec { + let mut forks = Vec::with_capacity(n); + for i in 0..n { + let mut fork = self.clone(); + fork.append_field_element(&(i as u64).into()); + forks.push(fork); + } + forks + } +} diff --git a/transcript/src/syncronized.rs b/transcript/src/syncronized.rs index c7e1fdb83..51a32e2b6 100644 --- a/transcript/src/syncronized.rs +++ b/transcript/src/syncronized.rs @@ -2,9 +2,9 @@ use std::array; use crossbeam_channel::{Receiver, Sender, bounded}; use ff_ext::ExtensionField; -use goldilocks::SmallField; +// use goldilocks::SmallField; -use crate::Challenge; +use crate::{Challenge, Transcript}; #[derive(Clone)] pub struct TranscriptSyncronized { @@ -42,36 +42,24 @@ impl TranscriptSyncronized { } } -impl TranscriptSyncronized { - // Append the message to the transcript. - pub fn append_message(&mut self, msg: &[u8]) { - let msg_f = E::BaseField::bytes_to_field_elements(msg); - self.bf_append_tx[self.rolling_index].send(msg_f).unwrap(); - } - - pub fn append_field_element_ext(&mut self, element: &E) { - self.ef_append_tx[self.rolling_index] +impl Transcript for TranscriptSyncronized { + fn append_field_element(&mut self, element: &E::BaseField) { + self.bf_append_tx[self.rolling_index] .send(vec![*element]) .unwrap(); } - pub fn append_field_element_exts(&mut self, element: &[E]) { + fn append_field_element_exts(&mut self, element: &[E]) { self.ef_append_tx[self.rolling_index] .send(element.to_vec()) .unwrap(); } - pub fn append_field_element(&mut self, element: &E::BaseField) { - self.bf_append_tx[self.rolling_index] - .send(vec![*element]) - .unwrap(); - } - - pub fn append_challenge(&mut self, _challenge: Challenge) { + fn append_challenge(&mut self, _challenge: Challenge) { unimplemented!() } - pub fn get_and_append_challenge(&mut self, _label: &'static [u8]) -> Challenge { + fn get_and_append_challenge(&mut self, _label: &'static [u8]) -> Challenge { Challenge { elements: self.challenge_rx[self.rolling_index].recv().unwrap(), } @@ -91,29 +79,25 @@ impl TranscriptSyncronized { // } } - pub fn read_field_element_ext(&self) -> E { - self.ef_append_rx[self.rolling_index].recv().unwrap()[0] - } - - pub fn read_field_element_exts(&self) -> Vec { + fn read_field_element_exts(&self) -> Vec { self.ef_append_rx[self.rolling_index].recv().unwrap() } - pub fn read_field_element(&self) -> E::BaseField { + fn read_field_element(&self) -> E::BaseField { self.bf_append_rx[self.rolling_index].recv().unwrap()[0] } - pub fn read_challenge(&self, _challenge: Challenge) { + fn read_challenge(&mut self) -> Challenge { unimplemented!() } - pub fn send_challenge(&self, challenge: E) { + fn send_challenge(&self, challenge: E) { self.challenge_tx[self.rolling_index] .send(challenge) .unwrap(); } - pub fn commit_rolling(&mut self) { + fn commit_rolling(&mut self) { self.rolling_index = (self.rolling_index + 1) % 2 } }