From 0efb9d1b380e8eaa49add1fa3a5ca9cecc132c96 Mon Sep 17 00:00:00 2001 From: naure Date: Fri, 22 Nov 2024 14:50:07 +0100 Subject: [PATCH 01/21] feat/private-input-2: fix padding assignment (#623) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Did not work when `len=1` because it’s padded to length 2. Fixed and simplified. Co-authored-by: Aurélien Nicolas --- ceno_zkvm/src/tables/ram/ram_impl.rs | 31 +++++++++++----------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index 07f36cdab..d293e5d83 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -392,26 +392,19 @@ impl DynVolatileRamTableConfig }); // set padding with well-form address - if final_mem.len().next_power_of_two() - final_mem.len() > 0 { - let paddin_entry_start = final_mem.len(); - final_table - .par_iter_mut() - .skip(final_mem.len()) - .enumerate() - .with_min_len(MIN_PAR_SIZE) - .for_each(|(i, row)| { - // Assign value limbs. - self.final_v.iter().for_each(|limb| { - set_val!(row, limb, 0u64); - }); - set_val!( - row, - self.addr, - DVRAM::addr(&self.params, paddin_entry_start + i) as u64 - ); - set_val!(row, self.final_cycle, 0_u64); + final_table + .par_iter_mut() + .enumerate() + .skip(final_mem.len()) + .with_min_len(MIN_PAR_SIZE) + .for_each(|(i, row)| { + // Assign value limbs. + self.final_v.iter().for_each(|limb| { + set_val!(row, limb, 0u64); }); - } + set_val!(row, self.addr, DVRAM::addr(&self.params, i) as u64); + set_val!(row, self.final_cycle, 0_u64); + }); Ok(final_table) } From ea9fa2d6f445a7a62196156c8b26ec8d722af80b Mon Sep 17 00:00:00 2001 From: naure Date: Fri, 22 Nov 2024 15:07:20 +0100 Subject: [PATCH 02/21] cleanup/exclusive-ranges (#620) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Switch to half-open / usual range for platform addresses. 
- Expose the `Range` type (#621). - Introduce methods `iter_addresses` which is not the same as `Range::iter`. (This became easy to fix after recent changes and useful for upcoming changes.) --------- Co-authored-by: Aurélien Nicolas --- ceno_emul/src/addr.rs | 31 +++++++- ceno_emul/src/platform.rs | 71 +++++-------------- ceno_emul/src/tracer.rs | 2 +- ceno_emul/src/vm_state.rs | 2 +- ceno_emul/tests/test_elf.rs | 2 +- ceno_zkvm/examples/riscv_opcodes.rs | 12 ++-- ceno_zkvm/src/bin/e2e.rs | 9 ++- .../src/instructions/riscv/rv32im/mmu.rs | 16 ++--- ceno_zkvm/src/tables/ram.rs | 8 +-- ceno_zkvm/src/tables/ram/ram_circuit.rs | 14 +++- 10 files changed, 83 insertions(+), 84 deletions(-) diff --git a/ceno_emul/src/addr.rs b/ceno_emul/src/addr.rs index 200739ce7..dfd9056b7 100644 --- a/ceno_emul/src/addr.rs +++ b/ceno_emul/src/addr.rs @@ -14,7 +14,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{fmt, ops}; +use std::{ + fmt, + ops::{self, Range}, +}; pub const WORD_SIZE: usize = 4; pub const PC_WORD_SIZE: usize = 4; @@ -192,3 +195,29 @@ impl ops::AddAssign for ByteAddr { self.0 += rhs; } } + +pub trait IterAddresses { + fn iter_addresses(&self) -> impl Iterator; +} + +impl IterAddresses for Range { + fn iter_addresses(&self) -> impl Iterator { + self.clone().step_by(WORD_SIZE) + } +} + +impl<'a, T: GetAddr> IterAddresses for &'a [T] { + fn iter_addresses(&self) -> impl Iterator { + self.iter().map(T::get_addr) + } +} + +pub trait GetAddr { + fn get_addr(&self) -> Addr; +} + +impl GetAddr for Addr { + fn get_addr(&self) -> Addr { + *self + } +} diff --git a/ceno_emul/src/platform.rs b/ceno_emul/src/platform.rs index 8836bdfbb..fb11d166d 100644 --- a/ceno_emul/src/platform.rs +++ b/ceno_emul/src/platform.rs @@ -1,3 +1,5 @@ +use std::ops::Range; + use crate::addr::{Addr, RegIdx}; /// The Platform struct holds the parameters of the VM. 
@@ -7,20 +9,18 @@ use crate::addr::{Addr, RegIdx}; /// - codes of environment calls. #[derive(Clone, Debug)] pub struct Platform { - pub rom_start: Addr, - pub rom_end: Addr, - pub ram_start: Addr, - pub ram_end: Addr, + pub rom: Range, + pub ram: Range, + pub public_io: Range, pub stack_top: Addr, /// If true, ecall instructions are no-op instead of trap. Testing only. pub unsafe_ecall_nop: bool, } pub const CENO_PLATFORM: Platform = Platform { - rom_start: 0x2000_0000, - rom_end: 0x3000_0000 - 1, - ram_start: 0x8000_0000, - ram_end: 0xFFFF_0000 - 1, + rom: 0x2000_0000..0x3000_0000, + ram: 0x8000_0000..0xFFFF_0000, + public_io: 0x3000_1000..0x3000_2000, stack_top: 0xC0000000, unsafe_ecall_nop: false, }; @@ -28,53 +28,16 @@ pub const CENO_PLATFORM: Platform = Platform { impl Platform { // Virtual memory layout. - pub const fn rom_start(&self) -> Addr { - self.rom_start - } - - pub const fn rom_end(&self) -> Addr { - self.rom_end - } - pub fn is_rom(&self, addr: Addr) -> bool { - (self.rom_start()..=self.rom_end()).contains(&addr) - } - - // TODO figure out a proper region for public io - pub const fn public_io_start(&self) -> Addr { - 0x3000_1000 - } - - pub const fn public_io_end(&self) -> Addr { - 0x3000_2000 - 1 - } - - pub const fn ram_start(&self) -> Addr { - if cfg!(feature = "forbid_overflow") { - // -1<<11 == 0x800 is the smallest negative 'immediate' - // offset we can have in memory instructions. - // So if we stay away from it, we are safe. - assert!(self.ram_start >= 0x800); - } - self.ram_start - } - - pub const fn ram_end(&self) -> Addr { - if cfg!(feature = "forbid_overflow") { - // (1<<11) - 1 == 0x7ff is the largest positive 'immediate' - // offset we can have in memory instructions. - // So if we stay away from it, we are safe. 
- assert!(self.ram_end < -(1_i32 << 11) as u32) - } - self.ram_end + self.rom.contains(&addr) } pub fn is_ram(&self, addr: Addr) -> bool { - (self.ram_start()..=self.ram_end()).contains(&addr) + self.ram.contains(&addr) } pub fn is_pub_io(&self, addr: Addr) -> bool { - (self.public_io_start()..=self.public_io_end()).contains(&addr) + self.public_io.contains(&addr) } /// Virtual address of a register. @@ -91,7 +54,7 @@ impl Platform { // Startup. pub const fn pc_base(&self) -> Addr { - self.rom_start() + self.rom.start } // Permissions. @@ -139,17 +102,17 @@ impl Platform { #[cfg(test)] mod tests { use super::*; - use crate::VMState; + use crate::{VMState, WORD_SIZE}; #[test] fn test_no_overlap() { let p = CENO_PLATFORM; assert!(p.can_execute(p.pc_base())); // ROM and RAM do not overlap. - assert!(!p.is_rom(p.ram_start())); - assert!(!p.is_rom(p.ram_end())); - assert!(!p.is_ram(p.rom_start())); - assert!(!p.is_ram(p.rom_end())); + assert!(!p.is_rom(p.ram.start)); + assert!(!p.is_rom(p.ram.end - WORD_SIZE as Addr)); + assert!(!p.is_ram(p.rom.start)); + assert!(!p.is_ram(p.rom.end - WORD_SIZE as Addr)); // Registers do not overlap with ROM or RAM. for reg in [ Platform::register_vma(0), diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index d9f87b695..bbb1e6a51 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -199,7 +199,7 @@ impl StepRecord { Some(value), Some(Change::new(value, value)), Some(WriteOp { - addr: CENO_PLATFORM.ram_start().into(), + addr: CENO_PLATFORM.ram.start.into(), value: Change { before: value, after: value, diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index 68dd4adaf..fe8013448 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -123,7 +123,7 @@ impl EmuContext for VMState { // Read two registers, write one register, write one memory word, and branch. 
tracing::warn!("ecall ignored: syscall_id={}", function); self.store_register(DecodedInstruction::RD_NULL as RegIdx, 0)?; - let addr = self.platform.ram_start().into(); + let addr = self.platform.ram.start.into(); self.store_memory(addr, self.peek_memory(addr))?; self.set_pc(ByteAddr(self.pc) + PC_STEP_SIZE); Ok(true) diff --git a/ceno_emul/tests/test_elf.rs b/ceno_emul/tests/test_elf.rs index 3168d265a..ca7a14c1c 100644 --- a/ceno_emul/tests/test_elf.rs +++ b/ceno_emul/tests/test_elf.rs @@ -27,7 +27,7 @@ fn test_ceno_rt_mem() -> Result<()> { let mut state = VMState::new_from_elf(CENO_PLATFORM, program_elf)?; let _steps = run(&mut state)?; - let value = state.peek_memory(CENO_PLATFORM.ram_start().into()); + let value = state.peek_memory(CENO_PLATFORM.ram.start.into()); assert_eq!(value, 6765, "Expected Fibonacci 20, got {}", value); Ok(()) } diff --git a/ceno_zkvm/examples/riscv_opcodes.rs b/ceno_zkvm/examples/riscv_opcodes.rs index d0ad76c1c..0711bd58f 100644 --- a/ceno_zkvm/examples/riscv_opcodes.rs +++ b/ceno_zkvm/examples/riscv_opcodes.rs @@ -42,10 +42,10 @@ const PROGRAM_CODE: [u32; PROGRAM_SIZE] = { let mut program: [u32; PROGRAM_SIZE] = [ECALL_HALT; PROGRAM_SIZE]; declare_program!( program, - encode_rv32(LUI, 0, 0, 10, CENO_PLATFORM.public_io_start()), // lui x10, public_io - encode_rv32(LW, 10, 0, 1, 0), // lw x1, 0(x10) - encode_rv32(LW, 10, 0, 2, 4), // lw x2, 4(x10) - encode_rv32(LW, 10, 0, 3, 8), // lw x3, 8(x10) + encode_rv32(LUI, 0, 0, 10, CENO_PLATFORM.public_io.start), // lui x10, public_io + encode_rv32(LW, 10, 0, 1, 0), // lw x1, 0(x10) + encode_rv32(LW, 10, 0, 2, 4), // lw x2, 4(x10) + encode_rv32(LW, 10, 0, 3, 8), // lw x3, 8(x10) // Main loop. 
encode_rv32(ADD, 1, 4, 4, 0), // add x4, x1, x4 encode_rv32(ADD, 2, 3, 3, 0), // add x3, x2, x3 @@ -90,8 +90,8 @@ fn main() { }) .collect(), ); - let mem_addresses = CENO_PLATFORM.ram_start()..=CENO_PLATFORM.ram_end(); - let io_addresses = CENO_PLATFORM.public_io_start()..=CENO_PLATFORM.public_io_end(); + let mem_addresses = CENO_PLATFORM.ram.clone(); + let io_addresses = CENO_PLATFORM.public_io.clone(); let (flame_layer, _guard) = FlameLayer::with_file("./tracing.folded").unwrap(); let mut fmt_layer = fmt::layer() diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index 5c79c27dc..a4a2b6087 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -79,17 +79,16 @@ fn main() { Preset::Sp1 => Platform { // The stack section is not mentioned in ELF headers, so we repeat the constant STACK_TOP here. stack_top: 0x0020_0400, - rom_start: 0x0020_0800, - rom_end: 0x003f_ffff, - ram_start: 0x0020_0000, - ram_end: 0xFFFF_0000 - 1, + rom: 0x0020_0800..0x0040_0000, + ram: 0x0020_0000..0xFFFF_0000, unsafe_ecall_nop: true, + ..CENO_PLATFORM }, }; tracing::info!("Running on platform {:?}", args.platform); const STACK_SIZE: u32 = 256; - let mut mem_padder = MemPadder::new(platform.ram_start()..=platform.ram_end()); + let mut mem_padder = MemPadder::new(platform.ram.clone()); tracing::info!("Loading ELF file: {}", args.elf); let elf_bytes = fs::read(&args.elf).expect("read elf file"); diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index 32bb5ce41..66eafd891 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -1,6 +1,6 @@ -use std::{collections::HashSet, iter::zip, ops::RangeInclusive}; +use std::{collections::HashSet, iter::zip, ops::Range}; -use ceno_emul::{Addr, Cycle, WORD_SIZE, Word}; +use ceno_emul::{Addr, Cycle, IterAddresses, WORD_SIZE, Word}; use ff_ext::ExtensionField; use itertools::{Itertools, chain}; @@ -48,11 +48,7 
@@ impl MmuConfig { io_addrs: &[Addr], ) { assert!( - chain( - static_mem_init.iter().map(|record| record.addr), - io_addrs.iter().copied(), - ) - .all_unique(), + chain!(static_mem_init.iter_addresses(), io_addrs.iter_addresses()).all_unique(), "memory addresses must be unique" ); @@ -107,7 +103,7 @@ impl MmuConfig { } pub struct MemPadder { - valid_addresses: RangeInclusive, + valid_addresses: Range, used_addresses: HashSet, } @@ -118,7 +114,7 @@ impl MemPadder { /// /// Require: `values.len() <= padded_len <= address_range.len()` pub fn init_mem( - address_range: RangeInclusive, + address_range: Range, padded_len: usize, values: &[Word], ) -> Vec { @@ -129,7 +125,7 @@ impl MemPadder { records } - pub fn new(valid_addresses: RangeInclusive) -> Self { + pub fn new(valid_addresses: Range) -> Self { Self { valid_addresses, used_addresses: HashSet::new(), diff --git a/ceno_zkvm/src/tables/ram.rs b/ceno_zkvm/src/tables/ram.rs index 182892603..88e968e69 100644 --- a/ceno_zkvm/src/tables/ram.rs +++ b/ceno_zkvm/src/tables/ram.rs @@ -19,11 +19,11 @@ impl DynVolatileRamTable for DynMemTable { const ZERO_INIT: bool = true; fn offset_addr(params: &ProgramParams) -> Addr { - params.platform.ram_start() + params.platform.ram.start } fn end_addr(params: &ProgramParams) -> Addr { - params.platform.ram_end() + params.platform.ram.end } fn name() -> &'static str { @@ -41,11 +41,11 @@ impl DynVolatileRamTable for PrivateMemTable { const ZERO_INIT: bool = false; fn offset_addr(params: &ProgramParams) -> Addr { - params.platform.ram_start() + params.platform.ram.start } fn end_addr(params: &ProgramParams) -> Addr { - params.platform.ram_end() + params.platform.ram.end } fn name() -> &'static str { diff --git a/ceno_zkvm/src/tables/ram/ram_circuit.rs b/ceno_zkvm/src/tables/ram/ram_circuit.rs index 81f0fcfaa..584489781 100644 --- a/ceno_zkvm/src/tables/ram/ram_circuit.rs +++ b/ceno_zkvm/src/tables/ram/ram_circuit.rs @@ -1,6 +1,6 @@ use std::{collections::HashMap, marker::PhantomData}; 
-use ceno_emul::{Addr, Cycle, WORD_SIZE, Word}; +use ceno_emul::{Addr, Cycle, GetAddr, WORD_SIZE, Word}; use ff_ext::ExtensionField; use crate::{ @@ -26,6 +26,18 @@ pub struct MemFinalRecord { pub value: Word, } +impl GetAddr for MemInitRecord { + fn get_addr(&self) -> Addr { + self.addr + } +} + +impl GetAddr for MemFinalRecord { + fn get_addr(&self) -> Addr { + self.addr + } +} + /// - **Non-Volatile**: The initial values can be set to any arbitrary value. /// /// **Special Note**: From 3a72862d39a2a989d028d3bbeea98161b3c1263a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20G=C3=B6rgens?= Date: Mon, 25 Nov 2024 14:49:57 +0800 Subject: [PATCH 03/21] chore: remove `.iter()` (#582) --- Cargo.lock | 1 + ceno_emul/Cargo.toml | 1 + ceno_emul/src/rv32im.rs | 3 ++- ceno_emul/src/vm_state.rs | 2 +- ceno_zkvm/src/scheme/mock_prover.rs | 10 +++++----- ceno_zkvm/src/scheme/prover.rs | 23 ++++++++++------------- 6 files changed, 20 insertions(+), 20 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d400b0e9d..7a9651cc3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -241,6 +241,7 @@ dependencies = [ "anyhow", "ceno-examples", "elf", + "itertools 0.13.0", "num-derive", "num-traits", "strum", diff --git a/ceno_emul/Cargo.toml b/ceno_emul/Cargo.toml index 6b93a3f0a..6c172d97f 100644 --- a/ceno_emul/Cargo.toml +++ b/ceno_emul/Cargo.toml @@ -7,6 +7,7 @@ version.workspace = true [dependencies] anyhow = { version = "1.0", default-features = false } elf = "0.7" +itertools.workspace = true num-derive.workspace = true num-traits.workspace = true strum.workspace = true diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index 19dcd2ecf..4554bdb08 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -15,6 +15,7 @@ // limitations under the License. 
use anyhow::{Result, anyhow}; +use itertools::enumerate; use num_derive::ToPrimitive; use std::sync::OnceLock; use strum_macros::EnumIter; @@ -403,7 +404,7 @@ struct FastDecodeTable { impl FastDecodeTable { fn new() -> Self { let mut table: FastInstructionTable = [0; 1 << 10]; - for (isa_idx, insn) in RV32IM_ISA.iter().enumerate() { + for (isa_idx, insn) in enumerate(&RV32IM_ISA) { Self::add_insn(&mut table, insn, isa_idx); } Self { table } diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index fe8013448..296d25be8 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -44,7 +44,7 @@ impl VMState { }; // init memory from program.image - for (&addr, &value) in program.image.iter() { + for (&addr, &value) in &program.image { vm.init_memory(ByteAddr(addr).waddr(), value); } diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 2ee5d1a66..e78474d3e 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -773,7 +773,7 @@ Hints: let mut rom_inputs = HashMap::, String, String, Vec>)>>::new(); let mut rom_tables = HashMap::>::new(); - for (circuit_name, cs) in cs.circuit_css.iter() { + for (circuit_name, cs) in &cs.circuit_css { let is_opcode = cs.lk_table_expressions.is_empty() && cs.r_table_expressions.is_empty() && cs.w_table_expressions.is_empty(); @@ -958,7 +958,7 @@ Hints: let mut writes_grp_by_annotations = HashMap::new(); // store (pc, timestamp) for $ram_type == RAMType::GlobalState let mut gs = HashMap::new(); - for (circuit_name, cs) in cs.circuit_css.iter() { + for (circuit_name, cs) in &cs.circuit_css { let fixed = fixed_mles.get(circuit_name).unwrap(); let witness = wit_mles.get(circuit_name).unwrap(); let num_rows = num_instances.get(circuit_name).unwrap(); @@ -1020,7 +1020,7 @@ Hints: let mut reads = HashSet::new(); let mut reads_grp_by_annotations = HashMap::new(); - for (circuit_name, cs) in cs.circuit_css.iter() { + for (circuit_name, cs) in 
&cs.circuit_css { let fixed = fixed_mles.get(circuit_name).unwrap(); let witness = wit_mles.get(circuit_name).unwrap(); let num_rows = num_instances.get(circuit_name).unwrap(); @@ -1065,7 +1065,7 @@ Hints: } macro_rules! find_rw_mismatch { ($reads:ident,$reads_grp_by_annotations:ident,$writes:ident,$writes_grp_by_annotations:ident,$ram_type:expr,$gs:expr) => { - for (annotation, (reads, circuit_name)) in $reads_grp_by_annotations.iter() { + for (annotation, (reads, circuit_name)) in &$reads_grp_by_annotations { // (pc, timestamp) let gs_of_circuit = $gs.get(circuit_name); let num_missing = reads @@ -1102,7 +1102,7 @@ Hints: } num_rw_mismatch_errors += num_missing; } - for (annotation, (writes, circuit_name)) in $writes_grp_by_annotations.iter() { + for (annotation, (writes, circuit_name)) in &$writes_grp_by_annotations { let gs_of_circuit = $gs.get(circuit_name); let num_missing = writes .iter() diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 30d8d9a6d..85e071ea5 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -5,7 +5,7 @@ use std::{ }; use ff::Field; -use itertools::{Itertools, izip}; +use itertools::{Itertools, enumerate, izip}; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ mle::{IntoMLE, MultilinearExtension}, @@ -62,10 +62,9 @@ impl> ZKVMProver { let mut vm_proof = ZKVMProof::empty(pi); // including raw public input to transcript - vm_proof - .raw_pi - .iter() - .for_each(|v| v.iter().for_each(|v| transcript.append_field_element(v))); + for v in vm_proof.raw_pi.iter().flatten() { + transcript.append_field_element(v); + } let pi: Vec> = vm_proof .raw_pi @@ -77,7 +76,7 @@ impl> ZKVMProver { .collect(); // commit to fixed commitment - for (_, pk) in self.pk.circuit_pks.iter() { + for pk in self.pk.circuit_pks.values() { if let Some(fixed_commit) = &pk.vk.fixed_commit { PCS::write_commitment(fixed_commit, &mut transcript) .map_err(ZKVMError::PCSError)?; @@ -147,7 +146,7 @@ 
impl> ZKVMProver { cs.w_expressions.len(), cs.lk_expressions.len(), ); - for lk_s in cs.lk_expressions_namespace_map.iter() { + for lk_s in &cs.lk_expressions_namespace_map { tracing::debug!("opcode circuit {}: {}", circuit_name, lk_s); } let opcode_proof = self.create_opcode_proof( @@ -1188,7 +1187,7 @@ impl TowerProver { let eq: ArcMultilinearExtension = build_eq_x_r_vec(&out_rt).into_mle().into(); let mut virtual_polys = VirtualPolynomials::::new(num_threads, out_rt.len()); - for (s, alpha) in prod_specs.iter().zip(alpha_pows.iter()) { + for (s, alpha) in izip!(&prod_specs, &alpha_pows) { if round < s.witness.len() { let layer_polys = &s.witness[round]; @@ -1210,9 +1209,7 @@ impl TowerProver { } } - for (s, alpha) in logup_specs - .iter() - .zip(alpha_pows[prod_specs.len()..].chunks(2)) + for (s, alpha) in izip!(&logup_specs, alpha_pows[prod_specs.len()..].chunks(2)) { if round < s.witness.len() { let layer_polys = &s.witness[round]; @@ -1270,7 +1267,7 @@ impl TowerProver { let evals = state.get_mle_final_evaluations(); let mut evals_iter = evals.iter(); evals_iter.next(); // skip first eq - for (i, s) in prod_specs.iter().enumerate() { + for (i, s) in enumerate(&prod_specs) { if round < s.witness.len() { // collect evals belong to current spec proofs.push_prod_evals_and_point( @@ -1282,7 +1279,7 @@ impl TowerProver { ); } } - for (i, s) in logup_specs.iter().enumerate() { + for (i, s) in enumerate(&logup_specs) { if round < s.witness.len() { // collect evals belong to current spec // p1, q2, p2, q1 From a76d586859481af09a69f98bdeeaad4565ef1603 Mon Sep 17 00:00:00 2001 From: naure Date: Mon, 25 Nov 2024 18:53:15 +0100 Subject: [PATCH 04/21] Feat: Private input integration (#622) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit _Issue #614_. _Pending a solution in #625_ - Load the private input from a file. - Configuration of the address range. - Integrate the circuit (#617) in MMU. 
--------- Co-authored-by: Aurélien Nicolas --- ceno_emul/src/addr.rs | 6 +-- ceno_emul/src/platform.rs | 8 +++- ceno_zkvm/examples/riscv_opcodes.rs | 1 + ceno_zkvm/src/bin/e2e.rs | 45 +++++++++++++++++-- .../src/instructions/riscv/rv32im/mmu.rs | 24 ++++++++-- ceno_zkvm/src/tables/ram.rs | 12 ++--- ceno_zkvm/src/tables/ram/ram_circuit.rs | 2 +- ceno_zkvm/src/tables/ram/ram_impl.rs | 8 ++-- 8 files changed, 86 insertions(+), 20 deletions(-) diff --git a/ceno_emul/src/addr.rs b/ceno_emul/src/addr.rs index dfd9056b7..78b01563e 100644 --- a/ceno_emul/src/addr.rs +++ b/ceno_emul/src/addr.rs @@ -197,17 +197,17 @@ impl ops::AddAssign for ByteAddr { } pub trait IterAddresses { - fn iter_addresses(&self) -> impl Iterator; + fn iter_addresses(&self) -> impl ExactSizeIterator; } impl IterAddresses for Range { - fn iter_addresses(&self) -> impl Iterator { + fn iter_addresses(&self) -> impl ExactSizeIterator { self.clone().step_by(WORD_SIZE) } } impl<'a, T: GetAddr> IterAddresses for &'a [T] { - fn iter_addresses(&self) -> impl Iterator { + fn iter_addresses(&self) -> impl ExactSizeIterator { self.iter().map(T::get_addr) } } diff --git a/ceno_emul/src/platform.rs b/ceno_emul/src/platform.rs index fb11d166d..b0fa0afe7 100644 --- a/ceno_emul/src/platform.rs +++ b/ceno_emul/src/platform.rs @@ -12,6 +12,7 @@ pub struct Platform { pub rom: Range, pub ram: Range, pub public_io: Range, + pub private_io: Range, pub stack_top: Addr, /// If true, ecall instructions are no-op instead of trap. Testing only. pub unsafe_ecall_nop: bool, @@ -21,6 +22,7 @@ pub const CENO_PLATFORM: Platform = Platform { rom: 0x2000_0000..0x3000_0000, ram: 0x8000_0000..0xFFFF_0000, public_io: 0x3000_1000..0x3000_2000, + private_io: 0x4000_0000..0x5000_0000, stack_top: 0xC0000000, unsafe_ecall_nop: false, }; @@ -40,6 +42,10 @@ impl Platform { self.public_io.contains(&addr) } + pub fn is_priv_io(&self, addr: Addr) -> bool { + self.private_io.contains(&addr) + } + /// Virtual address of a register. 
pub const fn register_vma(index: RegIdx) -> Addr { // Register VMAs are aligned, cannot be confused with indices, and readable in hex. @@ -60,7 +66,7 @@ impl Platform { // Permissions. pub fn can_read(&self, addr: Addr) -> bool { - self.is_rom(addr) || self.is_ram(addr) || self.is_pub_io(addr) + self.is_rom(addr) || self.is_ram(addr) || self.is_pub_io(addr) || self.is_priv_io(addr) } pub fn can_write(&self, addr: Addr) -> bool { diff --git a/ceno_zkvm/examples/riscv_opcodes.rs b/ceno_zkvm/examples/riscv_opcodes.rs index 0711bd58f..67a314683 100644 --- a/ceno_zkvm/examples/riscv_opcodes.rs +++ b/ceno_zkvm/examples/riscv_opcodes.rs @@ -275,6 +275,7 @@ fn main() { &reg_final, &mem_final, &public_io_final, + &[], ) .unwrap(); diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index a4a2b6087..9da848b7e 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -1,6 +1,6 @@ use ceno_emul::{ - ByteAddr, CENO_PLATFORM, EmuContext, InsnKind::EANY, Platform, StepRecord, Tracer, VMState, - WORD_SIZE, WordAddr, + ByteAddr, CENO_PLATFORM, EmuContext, InsnKind::EANY, IterAddresses, Platform, StepRecord, + Tracer, VMState, WORD_SIZE, Word, WordAddr, }; use ceno_zkvm::{ instructions::riscv::{DummyExtraConfig, MemPadder, MmuConfig, Rv32imConfig}, @@ -19,7 +19,9 @@ use itertools::{Itertools, MinMaxResult, chain, enumerate}; use mpcs::{Basefold, BasefoldRSParams, PolynomialCommitmentScheme}; use std::{ collections::{HashMap, HashSet}, - fs, panic, + fs, + iter::zip, + panic, time::Instant, }; use tracing::level_filters::LevelFilter; @@ -41,6 +43,11 @@ struct Args { /// The preset configuration to use. #[arg(short, long, value_enum, default_value_t = Preset::Ceno)] platform: Preset, + + /// The private input or hints. This is a raw file mounted as a memory segment. + /// Zero-padded to the next power-of-two size. 
+ #[arg(long)] + private_input: Option, } #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] @@ -94,6 +101,17 @@ fn main() { let elf_bytes = fs::read(&args.elf).expect("read elf file"); let mut vm = VMState::new_from_elf(platform.clone(), &elf_bytes).unwrap(); + tracing::info!("Loading private input file: {:?}", args.private_input); + let priv_io = memory_from_file(&args.private_input); + assert!( + priv_io.len() <= platform.private_io.iter_addresses().len(), + "private input must fit in {} bytes", + platform.private_io.len() + ); + for (addr, value) in zip(platform.private_io.iter_addresses(), &priv_io) { + vm.init_memory(addr.into(), *value); + } + // keygen let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup"); let (pp, vp) = Pcs::trim(pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); @@ -249,6 +267,14 @@ fn main() { .map(|rec| *final_access.get(&rec.addr.into()).unwrap_or(&0)) .collect_vec(); + let priv_io_final = zip(platform.private_io.iter_addresses(), &priv_io) + .map(|(addr, &value)| MemFinalRecord { + addr, + value, + cycle: *final_access.get(&addr.into()).unwrap_or(&0), + }) + .collect_vec(); + // assign table circuits config .assign_table_circuit(&zkvm_cs, &mut zkvm_witness) @@ -260,6 +286,7 @@ fn main() { &reg_final, &mem_final, &io_final, + &priv_io_final, ) .unwrap(); // assign program circuit @@ -332,6 +359,18 @@ fn main() { }; } +fn memory_from_file(path: &Option) -> Vec { + path.as_ref() + .map(|path| { + let mut buf = fs::read(path).expect("could not read file"); + buf.resize(buf.len().next_multiple_of(WORD_SIZE), 0); + buf.chunks_exact(WORD_SIZE) + .map(|word| Word::from_le_bytes(word.try_into().unwrap())) + .collect_vec() + }) + .unwrap_or_default() +} + fn debug_memory_ranges(vm: &VMState, mem_final: &[MemFinalRecord]) { let accessed_addrs = vm .tracer() diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index 66eafd891..fe1722728 
100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -8,8 +8,8 @@ use crate::{ error::ZKVMError, structs::{ProgramParams, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, tables::{ - MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOCircuit, PubIOTable, RegTable, - RegTableCircuit, StaticMemCircuit, StaticMemTable, TableCircuit, + MemFinalRecord, MemInitRecord, NonVolatileTable, PrivateIOCircuit, PubIOCircuit, + PubIOTable, RegTable, RegTableCircuit, StaticMemCircuit, StaticMemTable, TableCircuit, }, }; @@ -20,6 +20,8 @@ pub struct MmuConfig { pub static_mem_config: as TableCircuit>::TableConfig, /// Initialization of public IO. pub public_io_config: as TableCircuit>::TableConfig, + /// Initialization of private IO. + pub private_io_config: as TableCircuit>::TableConfig, pub params: ProgramParams, } @@ -30,11 +32,13 @@ impl MmuConfig { let static_mem_config = cs.register_table_circuit::>(); let public_io_config = cs.register_table_circuit::>(); + let private_io_config = cs.register_table_circuit::>(); Self { reg_config, static_mem_config, public_io_config, + private_io_config, params: cs.params.clone(), } } @@ -48,7 +52,13 @@ impl MmuConfig { io_addrs: &[Addr], ) { assert!( - chain!(static_mem_init.iter_addresses(), io_addrs.iter_addresses()).all_unique(), + chain!( + static_mem_init.iter_addresses(), + io_addrs.iter_addresses(), + // TODO: optimize with min_max and Range. 
+ self.params.platform.private_io.iter_addresses(), + ) + .all_unique(), "memory addresses must be unique" ); @@ -61,6 +71,7 @@ impl MmuConfig { ); fixed.register_table_circuit::>(cs, &self.public_io_config, io_addrs); + fixed.register_table_circuit::>(cs, &self.private_io_config, &()); } pub fn assign_table_circuit( @@ -70,6 +81,7 @@ impl MmuConfig { reg_final: &[MemFinalRecord], static_mem_final: &[MemFinalRecord], io_cycles: &[Cycle], + private_io_final: &[MemFinalRecord], ) -> Result<(), ZKVMError> { witness.assign_table_circuit::>(cs, &self.reg_config, reg_final)?; @@ -81,6 +93,12 @@ impl MmuConfig { witness.assign_table_circuit::>(cs, &self.public_io_config, io_cycles)?; + witness.assign_table_circuit::>( + cs, + &self.private_io_config, + private_io_final, + )?; + Ok(()) } diff --git a/ceno_zkvm/src/tables/ram.rs b/ceno_zkvm/src/tables/ram.rs index 88e968e69..f6fcb0282 100644 --- a/ceno_zkvm/src/tables/ram.rs +++ b/ceno_zkvm/src/tables/ram.rs @@ -34,25 +34,25 @@ impl DynVolatileRamTable for DynMemTable { pub type DynMemCircuit = DynVolatileRamCircuit; #[derive(Clone)] -pub struct PrivateMemTable; -impl DynVolatileRamTable for PrivateMemTable { +pub struct PrivateIOTable; +impl DynVolatileRamTable for PrivateIOTable { const RAM_TYPE: RAMType = RAMType::Memory; const V_LIMBS: usize = 1; // See `MemoryExpr`. 
const ZERO_INIT: bool = false; fn offset_addr(params: &ProgramParams) -> Addr { - params.platform.ram.start + params.platform.private_io.start } fn end_addr(params: &ProgramParams) -> Addr { - params.platform.ram.end + params.platform.private_io.end } fn name() -> &'static str { - "PrivateMemTable" + "PrivateIOTable" } } -pub type PrivateMemCircuit = DynVolatileRamCircuit; +pub type PrivateIOCircuit = DynVolatileRamCircuit; /// RegTable, fix size without offset #[derive(Clone)] diff --git a/ceno_zkvm/src/tables/ram/ram_circuit.rs b/ceno_zkvm/src/tables/ram/ram_circuit.rs index 584489781..234ff7799 100644 --- a/ceno_zkvm/src/tables/ram/ram_circuit.rs +++ b/ceno_zkvm/src/tables/ram/ram_circuit.rs @@ -189,7 +189,7 @@ impl TableC type WitnessInput = [MemFinalRecord]; fn name() -> String { - format!("RAM_{:?}", DVRAM::RAM_TYPE) + format!("RAM_{:?}_{}", DVRAM::RAM_TYPE, DVRAM::name()) } fn construct_circuit(cb: &mut CircuitBuilder) -> Result { diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index d293e5d83..2f16c6bff 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -369,15 +369,17 @@ impl DynVolatileRamTableConfig ) -> Result, ZKVMError> { assert!(final_mem.len() <= DVRAM::max_len(&self.params)); assert!(DVRAM::max_len(&self.params).is_power_of_two()); - let mut final_table = - RowMajorMatrix::::new(final_mem.len().next_power_of_two(), num_witness); + let mut final_table = RowMajorMatrix::::new(final_mem.len(), num_witness); final_table .par_iter_mut() .with_min_len(MIN_PAR_SIZE) .zip(final_mem.into_par_iter()) - .for_each(|(row, rec)| { + .enumerate() + .for_each(|(i, (row, rec))| { + assert_eq!(rec.addr, DVRAM::addr(&self.params, i)); set_val!(row, self.addr, rec.addr as u64); + if self.final_v.len() == 1 { // Assign value directly. 
set_val!(row, self.final_v[0], rec.value as u64); From baab45fd75037fc2e11ba6f476ba052a62967aa2 Mon Sep 17 00:00:00 2001 From: t Date: Tue, 26 Nov 2024 09:35:47 +0600 Subject: [PATCH 05/21] RiscV card book link fix (#629) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit link fix Co-authored-by: Matthias Görgens --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 4c6fd6114..82f2e8bfe 100644 --- a/README.md +++ b/README.md @@ -4,11 +4,11 @@ Please see [the slightly outdated paper](https://eprint.iacr.org/2024/387) for a 🚧 This project is currently under construction and not suitable for use in production. 🚧 -If you are unfamiliar with the RISC-V instruction set, please have a look at the [RISC-V instruction set reference](https://github.com/jameslzhu/riscv-card/blob/master/riscv-card.pdf). +If you are unfamiliar with the RISC-V instruction set, please have a look at the [RISC-V instruction set reference](https://github.com/jameslzhu/riscv-card/releases/download/latest/riscv-card.pdf). ## Local build requirements -Ceno is built in Rust, so [installing the Rust toolchain](https://www.rust-lang.org/tools/install) is a pre-requisite, if you want to develop on your local machine. We also use [cargo-make](https://sagiegurari.github.io/cargo-make/) to build Ceno. You can install cargo-make with the following command: +Ceno is built in Rust, so [installing the Rust toolchain](https://www.rust-lang.org/tools/install) is a pre-requisite if you want to develop on your local machine. We also use [cargo-make](https://sagiegurari.github.io/cargo-make/) to build Ceno. 
You can install cargo-make with the following command: ```sh cargo install cargo-make From 33edc81e1e62370a309e4e5f4b82f3f1ef1c7bb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20G=C3=B6rgens?= Date: Tue, 26 Nov 2024 16:49:05 +0800 Subject: [PATCH 06/21] Fix docs (#633) We accidentally wrote about 'mounting' (like a file system), when we meant memory mapping. --- ceno_zkvm/src/bin/e2e.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index 9da848b7e..0a9d64a11 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -44,8 +44,8 @@ struct Args { #[arg(short, long, value_enum, default_value_t = Preset::Ceno)] platform: Preset, - /// The private input or hints. This is a raw file mounted as a memory segment. - /// Zero-padded to the next power-of-two size. + /// The private input or hints. This is a raw file mapped as a memory segment. + /// Zero-padded to the right to the next power-of-two size. #[arg(long)] private_input: Option, } From a062e13ce53ef4fb2ac419541b43f87bf7a6e663 Mon Sep 17 00:00:00 2001 From: naure Date: Tue, 26 Nov 2024 10:41:31 +0100 Subject: [PATCH 07/21] doc: Rename private input to hints (#636) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Aurélien Nicolas --- ceno_emul/src/platform.rs | 10 ++++---- ceno_zkvm/src/bin/e2e.rs | 19 ++++++++------- .../src/instructions/riscv/rv32im/mmu.rs | 24 ++++++++----------- ceno_zkvm/src/tables/ram.rs | 12 +++++----- 4 files changed, 31 insertions(+), 34 deletions(-) diff --git a/ceno_emul/src/platform.rs b/ceno_emul/src/platform.rs index b0fa0afe7..c28bcda40 100644 --- a/ceno_emul/src/platform.rs +++ b/ceno_emul/src/platform.rs @@ -12,7 +12,7 @@ pub struct Platform { pub rom: Range, pub ram: Range, pub public_io: Range, - pub private_io: Range, + pub hints: Range, pub stack_top: Addr, /// If true, ecall instructions are no-op instead of trap. Testing only. 
pub unsafe_ecall_nop: bool, @@ -22,7 +22,7 @@ pub const CENO_PLATFORM: Platform = Platform { rom: 0x2000_0000..0x3000_0000, ram: 0x8000_0000..0xFFFF_0000, public_io: 0x3000_1000..0x3000_2000, - private_io: 0x4000_0000..0x5000_0000, + hints: 0x4000_0000..0x5000_0000, stack_top: 0xC0000000, unsafe_ecall_nop: false, }; @@ -42,8 +42,8 @@ impl Platform { self.public_io.contains(&addr) } - pub fn is_priv_io(&self, addr: Addr) -> bool { - self.private_io.contains(&addr) + pub fn is_hints(&self, addr: Addr) -> bool { + self.hints.contains(&addr) } /// Virtual address of a register. @@ -66,7 +66,7 @@ impl Platform { // Permissions. pub fn can_read(&self, addr: Addr) -> bool { - self.is_rom(addr) || self.is_ram(addr) || self.is_pub_io(addr) || self.is_priv_io(addr) + self.is_rom(addr) || self.is_ram(addr) || self.is_pub_io(addr) || self.is_hints(addr) } pub fn can_write(&self, addr: Addr) -> bool { diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index 0a9d64a11..af0ae40de 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -44,10 +44,11 @@ struct Args { #[arg(short, long, value_enum, default_value_t = Preset::Ceno)] platform: Preset, - /// The private input or hints. This is a raw file mapped as a memory segment. + /// Hints: prover-private unconstrained input. + /// This is a raw file mapped as a memory segment. /// Zero-padded to the right to the next power-of-two size. 
#[arg(long)] - private_input: Option, + hints: Option, } #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] @@ -101,14 +102,14 @@ fn main() { let elf_bytes = fs::read(&args.elf).expect("read elf file"); let mut vm = VMState::new_from_elf(platform.clone(), &elf_bytes).unwrap(); - tracing::info!("Loading private input file: {:?}", args.private_input); - let priv_io = memory_from_file(&args.private_input); + tracing::info!("Loading hints file: {:?}", args.hints); + let hints = memory_from_file(&args.hints); assert!( - priv_io.len() <= platform.private_io.iter_addresses().len(), - "private input must fit in {} bytes", - platform.private_io.len() + hints.len() <= platform.hints.iter_addresses().len(), + "hints must fit in {} bytes", + platform.hints.len() ); - for (addr, value) in zip(platform.private_io.iter_addresses(), &priv_io) { + for (addr, value) in zip(platform.hints.iter_addresses(), &hints) { vm.init_memory(addr.into(), *value); } @@ -267,7 +268,7 @@ fn main() { .map(|rec| *final_access.get(&rec.addr.into()).unwrap_or(&0)) .collect_vec(); - let priv_io_final = zip(platform.private_io.iter_addresses(), &priv_io) + let priv_io_final = zip(platform.hints.iter_addresses(), &hints) .map(|(addr, &value)| MemFinalRecord { addr, value, diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index fe1722728..a271c21ab 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -8,8 +8,8 @@ use crate::{ error::ZKVMError, structs::{ProgramParams, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, tables::{ - MemFinalRecord, MemInitRecord, NonVolatileTable, PrivateIOCircuit, PubIOCircuit, - PubIOTable, RegTable, RegTableCircuit, StaticMemCircuit, StaticMemTable, TableCircuit, + HintsCircuit, MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOCircuit, PubIOTable, + RegTable, RegTableCircuit, StaticMemCircuit, StaticMemTable, 
TableCircuit, }, }; @@ -20,8 +20,8 @@ pub struct MmuConfig { pub static_mem_config: as TableCircuit>::TableConfig, /// Initialization of public IO. pub public_io_config: as TableCircuit>::TableConfig, - /// Initialization of private IO. - pub private_io_config: as TableCircuit>::TableConfig, + /// Initialization of hints. + pub hints_config: as TableCircuit>::TableConfig, pub params: ProgramParams, } @@ -32,13 +32,13 @@ impl MmuConfig { let static_mem_config = cs.register_table_circuit::>(); let public_io_config = cs.register_table_circuit::>(); - let private_io_config = cs.register_table_circuit::>(); + let hints_config = cs.register_table_circuit::>(); Self { reg_config, static_mem_config, public_io_config, - private_io_config, + hints_config, params: cs.params.clone(), } } @@ -56,7 +56,7 @@ impl MmuConfig { static_mem_init.iter_addresses(), io_addrs.iter_addresses(), // TODO: optimize with min_max and Range. - self.params.platform.private_io.iter_addresses(), + self.params.platform.hints.iter_addresses(), ) .all_unique(), "memory addresses must be unique" @@ -71,7 +71,7 @@ impl MmuConfig { ); fixed.register_table_circuit::>(cs, &self.public_io_config, io_addrs); - fixed.register_table_circuit::>(cs, &self.private_io_config, &()); + fixed.register_table_circuit::>(cs, &self.hints_config, &()); } pub fn assign_table_circuit( @@ -81,7 +81,7 @@ impl MmuConfig { reg_final: &[MemFinalRecord], static_mem_final: &[MemFinalRecord], io_cycles: &[Cycle], - private_io_final: &[MemFinalRecord], + hints_final: &[MemFinalRecord], ) -> Result<(), ZKVMError> { witness.assign_table_circuit::>(cs, &self.reg_config, reg_final)?; @@ -93,11 +93,7 @@ impl MmuConfig { witness.assign_table_circuit::>(cs, &self.public_io_config, io_cycles)?; - witness.assign_table_circuit::>( - cs, - &self.private_io_config, - private_io_final, - )?; + witness.assign_table_circuit::>(cs, &self.hints_config, hints_final)?; Ok(()) } diff --git a/ceno_zkvm/src/tables/ram.rs b/ceno_zkvm/src/tables/ram.rs 
index f6fcb0282..2c45294e4 100644 --- a/ceno_zkvm/src/tables/ram.rs +++ b/ceno_zkvm/src/tables/ram.rs @@ -34,25 +34,25 @@ impl DynVolatileRamTable for DynMemTable { pub type DynMemCircuit = DynVolatileRamCircuit; #[derive(Clone)] -pub struct PrivateIOTable; -impl DynVolatileRamTable for PrivateIOTable { +pub struct HintsTable; +impl DynVolatileRamTable for HintsTable { const RAM_TYPE: RAMType = RAMType::Memory; const V_LIMBS: usize = 1; // See `MemoryExpr`. const ZERO_INIT: bool = false; fn offset_addr(params: &ProgramParams) -> Addr { - params.platform.private_io.start + params.platform.hints.start } fn end_addr(params: &ProgramParams) -> Addr { - params.platform.private_io.end + params.platform.hints.end } fn name() -> &'static str { - "PrivateIOTable" + "HintsTable" } } -pub type PrivateIOCircuit = DynVolatileRamCircuit; +pub type HintsCircuit = DynVolatileRamCircuit; /// RegTable, fix size without offset #[derive(Clone)] From 0ea9248dc6bffbb3e7d6c8aaad7039567e8860e1 Mon Sep 17 00:00:00 2001 From: naure Date: Tue, 26 Nov 2024 12:59:48 +0100 Subject: [PATCH 08/21] Revert "fix/program-size2: refactor padding_zero (#615)" (#638) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit c548dc88fccac94fcb529105169ac44b32c66c0c. 
--------- Co-authored-by: Aurélien Nicolas --- ceno_zkvm/src/instructions.rs | 2 +- ceno_zkvm/src/tables/mod.rs | 49 ++++++-------- ceno_zkvm/src/tables/program.rs | 109 +++++++++++++++++++++++++------- ceno_zkvm/src/witness.rs | 9 ++- 4 files changed, 110 insertions(+), 59 deletions(-) diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index e87675dd6..63314cbee 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -94,7 +94,7 @@ pub trait Instruction { num_padding_instances }; raw_witin - .par_batch_iter_padding_mut(None, num_padding_instance_per_batch) + .par_batch_iter_padding_mut(num_padding_instance_per_batch) .with_min_len(MIN_PAR_SIZE) .for_each(|row| { row.chunks_mut(num_witin) diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index f498b868e..2ef7e293a 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -2,8 +2,8 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, scheme::constants::MIN_PAR_SIZE, witness::RowMajorMatrix, }; +use ff::Field; use ff_ext::ExtensionField; -use goldilocks::SmallField; use rayon::iter::{IndexedParallelIterator, ParallelIterator}; use std::{collections::HashMap, mem::MaybeUninit}; mod range; @@ -46,34 +46,25 @@ pub trait TableCircuit { table: &mut RowMajorMatrix, num_witin: usize, ) -> Result<(), ZKVMError> { - padding_zero(table, num_witin, None); + // Fill the padding with zeros, if any. 
+ let num_padding_instances = table.num_padding_instances(); + if num_padding_instances > 0 { + let nthreads = + std::env::var("RAYON_NUM_THREADS").map_or(8, |s| s.parse::().unwrap_or(8)); + let padding_instance = vec![MaybeUninit::new(E::BaseField::ZERO); num_witin]; + let num_padding_instance_per_batch = if num_padding_instances > 256 { + num_padding_instances.div_ceil(nthreads) + } else { + num_padding_instances + }; + table + .par_batch_iter_padding_mut(num_padding_instance_per_batch) + .with_min_len(MIN_PAR_SIZE) + .for_each(|row| { + row.chunks_mut(num_witin) + .for_each(|instance| instance.copy_from_slice(padding_instance.as_slice())); + }); + } Ok(()) } } - -/// Fill the padding with zeros. Start after the given `num_instances`, or detect it from the table. -pub fn padding_zero( - table: &mut RowMajorMatrix, - num_cols: usize, - num_instances: Option, -) { - // Fill the padding with zeros, if any. - let num_padding_instances = table.num_padding_instances(); - if num_padding_instances > 0 { - let nthreads = - std::env::var("RAYON_NUM_THREADS").map_or(8, |s| s.parse::().unwrap_or(8)); - let padding_instance = vec![MaybeUninit::new(F::ZERO); num_cols]; - let num_padding_instance_per_batch = if num_padding_instances > 256 { - num_padding_instances.div_ceil(nthreads) - } else { - num_padding_instances - }; - table - .par_batch_iter_padding_mut(num_instances, num_padding_instance_per_batch) - .with_min_len(MIN_PAR_SIZE) - .for_each(|row| { - row.chunks_mut(num_cols) - .for_each(|instance| instance.copy_from_slice(padding_instance.as_slice())); - }); - } -} diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index f1d38ada8..da063545e 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -7,7 +7,7 @@ use crate::{ scheme::constants::MIN_PAR_SIZE, set_fixed_val, set_val, structs::ROMType, - tables::{TableCircuit, padding_zero}, + tables::TableCircuit, utils::i64_to_base, witness::RowMajorMatrix, }; @@ -136,7 
+136,7 @@ impl TableCircuit for ProgramTableCircuit { cb.lk_table_record( || "prog table", - cb.params.program_size, + cb.params.program_size.next_power_of_two(), ROMType::Instruction, record_exprs, mlt.expr(), @@ -176,7 +176,15 @@ impl TableCircuit for ProgramTableCircuit { }); assert_eq!(INVALID as u64, 0, "0 padding must be invalid instructions"); - padding_zero(&mut fixed, num_fixed, Some(num_instructions)); + fixed + .par_iter_mut() + .with_min_len(MIN_PAR_SIZE) + .skip(num_instructions) + .for_each(|row| { + for col in config.record.as_slice() { + set_fixed_val!(row, *col, 0_u64.into()); + } + }); fixed } @@ -204,32 +212,85 @@ impl TableCircuit for ProgramTableCircuit { set_val!(row, config.mlt, E::BaseField::from(mlt as u64)); }); - padding_zero(&mut witness, num_witin, Some(program.instructions.len())); + witness + .par_iter_mut() + .with_min_len(MIN_PAR_SIZE) + .skip(program.instructions.len()) + .for_each(|row| { + set_val!(row, config.mlt, 0_u64); + }); Ok(witness) } } #[cfg(test)] -#[test] -#[allow(clippy::identity_op)] -fn test_decode_imm() { - for (i, expected) in [ - // Example of I-type: ADDI. - // imm | rs1 | funct3 | rd | opcode - (89 << 20 | 1 << 15 | 0b000 << 12 | 1 << 7 | 0x13, 89), - // Shifts get a precomputed power of 2: SLLI, SRLI, SRAI. - (31 << 20 | 1 << 15 | 0b001 << 12 | 1 << 7 | 0x13, 1 << 31), - (31 << 20 | 1 << 15 | 0b101 << 12 | 1 << 7 | 0x13, 1 << 31), - ( - 1 << 30 | 31 << 20 | 1 << 15 | 0b101 << 12 | 1 << 7 | 0x13, - 1 << 31, - ), - // Example of R-type with funct7: SUB. 
- // funct7 | rs2 | rs1 | funct3 | rd | opcode - (0x20 << 25 | 1 << 20 | 1 << 15 | 0 << 12 | 1 << 7 | 0x33, 0), - ] { - let imm = InsnRecord::imm_internal(&DecodedInstruction::new(i)); - assert_eq!(imm, expected); +mod tests { + use super::*; + use crate::{circuit_builder::ConstraintSystem, witness::LkMultiplicity}; + use ceno_emul::encode_rv32; + use ff::Field; + use goldilocks::{Goldilocks as F, GoldilocksExt2 as E}; + + #[test] + #[allow(clippy::identity_op)] + fn test_decode_imm() { + for (i, expected) in [ + // Example of I-type: ADDI. + // imm | rs1 | funct3 | rd | opcode + (89 << 20 | 1 << 15 | 0b000 << 12 | 1 << 7 | 0x13, 89), + // Shifts get a precomputed power of 2: SLLI, SRLI, SRAI. + (31 << 20 | 1 << 15 | 0b001 << 12 | 1 << 7 | 0x13, 1 << 31), + (31 << 20 | 1 << 15 | 0b101 << 12 | 1 << 7 | 0x13, 1 << 31), + ( + 1 << 30 | 31 << 20 | 1 << 15 | 0b101 << 12 | 1 << 7 | 0x13, + 1 << 31, + ), + // Example of R-type with funct7: SUB. + // funct7 | rs2 | rs1 | funct3 | rd | opcode + (0x20 << 25 | 1 << 20 | 1 << 15 | 0 << 12 | 1 << 7 | 0x33, 0), + ] { + let imm = InsnRecord::imm_internal(&DecodedInstruction::new(i)); + assert_eq!(imm, expected); + } + } + + #[test] + fn test_program_padding() { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + + let actual_len = 3; + let instructions = vec![encode_rv32(ADD, 1, 2, 3, 0); actual_len]; + let program = Program::new(0x2000_0000, 0x2000_0000, instructions, Default::default()); + + let config = ProgramTableCircuit::construct_circuit(&mut cb).unwrap(); + + let check = |matrix: &RowMajorMatrix| { + assert_eq!( + matrix.num_instances() + matrix.num_padding_instances(), + cb.params.program_size + ); + for row in matrix.iter_rows().skip(actual_len) { + for col in row.iter() { + assert_eq!(unsafe { col.assume_init() }, F::ZERO); + } + } + }; + + let fixed = + ProgramTableCircuit::::generate_fixed_traces(&config, cb.cs.num_fixed, &program); + check(&fixed); + + let lkm = 
LkMultiplicity::default().into_finalize_result(); + + let witness = ProgramTableCircuit::::assign_instances( + &config, + cb.cs.num_witin as usize, + &lkm, + &program, + ) + .unwrap(); + check(&witness); } } diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index e85360aee..7acc9ad50 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -87,13 +87,12 @@ impl RowMajorMatrix { pub fn par_batch_iter_padding_mut( &mut self, - num_instances: Option, - batch_size: usize, + num_rows: usize, ) -> rayon::slice::ChunksMut<'_, MaybeUninit> { - let num_instances = num_instances.unwrap_or(self.num_instances()); - self.values[num_instances * self.num_col..] + let valid_instance = self.num_instances(); + self.values[valid_instance * self.num_col..] .as_mut() - .par_chunks_mut(batch_size * self.num_col) + .par_chunks_mut(num_rows * self.num_col) } pub fn de_interleaving(mut self) -> Vec> { From 1edbf65010274acee38d445beba236e7293c66b4 Mon Sep 17 00:00:00 2001 From: Cyte Zhang Date: Tue, 26 Nov 2024 23:19:46 +0800 Subject: [PATCH 09/21] Refactor BaseFold hashing benchmark. 
(#556) Extract a small independent change from #294 Co-authored-by: Ming --- mpcs/benches/hashing.rs | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/mpcs/benches/hashing.rs b/mpcs/benches/hashing.rs index 4dc107ed7..818cd7b6c 100644 --- a/mpcs/benches/hashing.rs +++ b/mpcs/benches/hashing.rs @@ -5,28 +5,19 @@ use goldilocks::Goldilocks; use mpcs::util::hash::{Digest, hash_two_digests}; use poseidon::poseidon_hash::PoseidonHash; +fn random_ceno_goldy() -> Goldilocks { + Goldilocks::random(&mut test_rng()) +} pub fn criterion_benchmark(c: &mut Criterion) { - let left = Digest( - vec![Goldilocks::random(&mut test_rng()); 4] - .try_into() - .unwrap(), - ); - let right = Digest( - vec![Goldilocks::random(&mut test_rng()); 4] - .try_into() - .unwrap(), - ); + let left = Digest(vec![random_ceno_goldy(); 4].try_into().unwrap()); + let right = Digest(vec![random_ceno_goldy(); 4].try_into().unwrap()); c.bench_function("ceno hash 2 to 1", |bencher| { bencher.iter(|| hash_two_digests(&left, &right)) }); - let values = (0..60) - .map(|_| Goldilocks::random(&mut test_rng())) - .collect::>(); + let values = (0..60).map(|_| random_ceno_goldy()).collect::>(); c.bench_function("ceno hash 60 to 1", |bencher| { - bencher.iter(|| { - PoseidonHash::hash_or_noop(&values); - }) + bencher.iter(|| PoseidonHash::hash_or_noop(&values)) }); } From 846033ad805ae115507ed6923acf37e3ff478279 Mon Sep 17 00:00:00 2001 From: Ming Date: Wed, 27 Nov 2024 10:10:28 +0700 Subject: [PATCH 10/21] chores: simplify to max_usable_threads (#640) --- ceno_zkvm/src/instructions.rs | 4 ++-- ceno_zkvm/src/tables/mod.rs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 63314cbee..eb24b7008 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -2,6 +2,7 @@ use std::mem::MaybeUninit; use ceno_emul::StepRecord; use ff_ext::ExtensionField; +use 
multilinear_extensions::util::max_usable_threads; use rayon::{ iter::{IndexedParallelIterator, ParallelIterator}, slice::ParallelSlice, @@ -47,8 +48,7 @@ pub trait Instruction { num_witin: usize, steps: Vec, ) -> Result<(RowMajorMatrix, LkMultiplicity), ZKVMError> { - let nthreads = - std::env::var("RAYON_NUM_THREADS").map_or(8, |s| s.parse::().unwrap_or(8)); + let nthreads = max_usable_threads(); let num_instance_per_batch = if steps.len() > 256 { steps.len().div_ceil(nthreads) } else { diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index 2ef7e293a..b27f04e54 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -4,6 +4,7 @@ use crate::{ }; use ff::Field; use ff_ext::ExtensionField; +use multilinear_extensions::util::max_usable_threads; use rayon::iter::{IndexedParallelIterator, ParallelIterator}; use std::{collections::HashMap, mem::MaybeUninit}; mod range; @@ -49,8 +50,7 @@ pub trait TableCircuit { // Fill the padding with zeros, if any. let num_padding_instances = table.num_padding_instances(); if num_padding_instances > 0 { - let nthreads = - std::env::var("RAYON_NUM_THREADS").map_or(8, |s| s.parse::().unwrap_or(8)); + let nthreads = max_usable_threads(); let padding_instance = vec![MaybeUninit::new(E::BaseField::ZERO); num_witin]; let num_padding_instance_per_batch = if num_padding_instances > 256 { num_padding_instances.div_ceil(nthreads) From 23fbe0a95ffaa1adc472df3b63d9d315f3f005e9 Mon Sep 17 00:00:00 2001 From: naure Date: Wed, 27 Nov 2024 14:49:28 +0100 Subject: [PATCH 11/21] fix/pow-table-index: Correct table index calculation (#644) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit _Blocking sproll-evm._ PowTable content is organized differently than binary ops. This breaks the assignment of multiplicities. Fixed by overriding the mapping between table indexes and entries. This detail was not covered by the table checks of MockProver, hence why discovered only now. 
--------- Co-authored-by: Aurélien Nicolas --- ceno_zkvm/src/tables/ops.rs | 44 ++++++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/ceno_zkvm/src/tables/ops.rs b/ceno_zkvm/src/tables/ops.rs index f1b2e0612..789f7f744 100644 --- a/ceno_zkvm/src/tables/ops.rs +++ b/ceno_zkvm/src/tables/ops.rs @@ -87,7 +87,49 @@ impl OpsTable for PowTable { } fn content() -> Vec<[u64; 3]> { - (0..Self::len() as u64).map(|b| [2, b, 1 << b]).collect() + (0..Self::len() as u64) + .map(|exponent| [2, exponent, 1 << exponent]) + .collect() + } + + fn pack(base: u64, exponent: u64) -> u64 { + assert_eq!(base, 2); + exponent + } + + fn unpack(exponent: u64) -> (u64, u64) { + (2, exponent) } } pub type PowTableCircuit = OpsTableCircuit; + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + tables::TableCircuit, + }; + use goldilocks::{GoldilocksExt2 as E, SmallField}; + + #[test] + fn test_ops_pow_table_assign() { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + + let config = PowTableCircuit::::construct_circuit(&mut cb).unwrap(); + + let fixed = PowTableCircuit::::generate_fixed_traces(&config, cb.cs.num_fixed, &()); + + for (i, row) in fixed.iter_rows().enumerate() { + let (base, exp) = PowTable::unpack(i as u64); + assert_eq!(PowTable::pack(base, exp), i as u64); + assert_eq!(base, unsafe { row[0].assume_init() }.to_canonical_u64()); + assert_eq!(exp, unsafe { row[1].assume_init() }.to_canonical_u64()); + assert_eq!( + base.pow(exp.try_into().unwrap()), + unsafe { row[2].assume_init() }.to_canonical_u64() + ); + } + } +} From 304b2ae83913af410da97b893c0b4e05fa91a971 Mon Sep 17 00:00:00 2001 From: Cyte Zhang Date: Thu, 28 Nov 2024 13:44:01 +0800 Subject: [PATCH 12/21] Refactor BaseFold test code (#552) Extracting small PR from #294 Refactor the test code in `mpcs/src/lib.rs`. Define some utility functions to simplify the tests. 
Also merged the test functions in `mpcs/src/basefold.rs`. --- mpcs/src/basefold.rs | 139 ++++++++----------------- mpcs/src/lib.rs | 243 ++++++++++++++++++++----------------------- 2 files changed, 152 insertions(+), 230 deletions(-) diff --git a/mpcs/src/basefold.rs b/mpcs/src/basefold.rs index 5c225c75a..be59f42c7 100644 --- a/mpcs/src/basefold.rs +++ b/mpcs/src/basefold.rs @@ -1191,108 +1191,53 @@ mod test { type PcsGoldilocksBaseCode = Basefold; #[test] - fn commit_open_verify_goldilocks_basecode_base() { - // Challenge is over extension field, poly over the base field - run_commit_open_verify::(true, 10, 11); - // Test trivial proof with small num vars - run_commit_open_verify::(true, 4, 6); - } - - #[test] - fn commit_open_verify_goldilocks_rscode_base() { - // Challenge is over extension field, poly over the base field - run_commit_open_verify::(true, 10, 11); - // Test trivial proof with small num vars - run_commit_open_verify::(true, 4, 6); - } - - #[test] - fn commit_open_verify_goldilocks_basecode_2() { - // Both challenge and poly are over extension field - run_commit_open_verify::(false, 10, 11); - // Test trivial proof with small num vars - run_commit_open_verify::(false, 4, 6); - } - - #[test] - fn commit_open_verify_goldilocks_rscode_2() { - // Both challenge and poly are over extension field - run_commit_open_verify::(false, 10, 11); - // Test trivial proof with small num vars - run_commit_open_verify::(false, 4, 6); - } - - #[test] - fn simple_batch_commit_open_verify_goldilocks_basecode_base() { - // Both challenge and poly are over base field - run_simple_batch_commit_open_verify::( - true, 10, 11, 1, - ); - run_simple_batch_commit_open_verify::( - true, 10, 11, 4, - ); - // Test trivial proof with small num vars - run_simple_batch_commit_open_verify::(true, 4, 6, 4); - } - - #[test] - fn simple_batch_commit_open_verify_goldilocks_rscode_base() { - // Both challenge and poly are over base field - run_simple_batch_commit_open_verify::(true, 10, 
11, 1); - run_simple_batch_commit_open_verify::(true, 10, 11, 4); - // Test trivial proof with small num vars - run_simple_batch_commit_open_verify::(true, 4, 6, 4); - } - - #[test] - fn simple_batch_commit_open_verify_goldilocks_basecode_2() { - // Both challenge and poly are over extension field - run_simple_batch_commit_open_verify::( - false, 10, 11, 1, - ); - run_simple_batch_commit_open_verify::( - false, 10, 11, 4, - ); - // Test trivial proof with small num vars - run_simple_batch_commit_open_verify::( - false, 4, 6, 4, - ); - } - - #[test] - fn simple_batch_commit_open_verify_goldilocks_rscode_2() { - // Both challenge and poly are over extension field - run_simple_batch_commit_open_verify::( - false, 10, 11, 1, - ); - run_simple_batch_commit_open_verify::( - false, 10, 11, 4, - ); - // Test trivial proof with small num vars - run_simple_batch_commit_open_verify::(false, 4, 6, 4); - } - - #[test] - fn batch_commit_open_verify_goldilocks_basecode_base() { - // Both challenge and poly are over base field - run_batch_commit_open_verify::(true, 10, 11); - } - - #[test] - fn batch_commit_open_verify_goldilocks_rscode_base() { - // Both challenge and poly are over base field - run_batch_commit_open_verify::(true, 10, 11); + fn commit_open_verify_goldilocks() { + for base in [true, false] { + // Challenge is over extension field, poly over the base field + run_commit_open_verify::(base, 10, 11); + // Test trivial proof with small num vars + run_commit_open_verify::(base, 4, 6); + // Challenge is over extension field, poly over the base field + run_commit_open_verify::(base, 10, 11); + // Test trivial proof with small num vars + run_commit_open_verify::(base, 4, 6); + } } #[test] - fn batch_commit_open_verify_goldilocks_basecode_2() { - // Both challenge and poly are over extension field - run_batch_commit_open_verify::(false, 10, 11); + fn simple_batch_commit_open_verify_goldilocks() { + for base in [true, false] { + // Both challenge and poly are over base field 
+ run_simple_batch_commit_open_verify::( + base, 10, 11, 1, + ); + run_simple_batch_commit_open_verify::( + base, 10, 11, 4, + ); + // Test trivial proof with small num vars + run_simple_batch_commit_open_verify::( + base, 4, 6, 4, + ); + // Both challenge and poly are over base field + run_simple_batch_commit_open_verify::( + base, 10, 11, 1, + ); + run_simple_batch_commit_open_verify::( + base, 10, 11, 4, + ); + // Test trivial proof with small num vars + run_simple_batch_commit_open_verify::( + base, 4, 6, 4, + ); + } } #[test] - fn batch_commit_open_verify_goldilocks_rscode_2() { - // Both challenge and poly are over extension field - run_batch_commit_open_verify::(false, 10, 11); + fn batch_commit_open_verify() { + for base in [true, false] { + // Both challenge and poly are over base field + run_batch_commit_open_verify::(base, 10, 11); + run_batch_commit_open_verify::(base, 10, 11); + } } } diff --git a/mpcs/src/lib.rs b/mpcs/src/lib.rs index 19b3d16b6..3bf1310f6 100644 --- a/mpcs/src/lib.rs +++ b/mpcs/src/lib.rs @@ -359,15 +359,79 @@ fn err_too_many_variates(function: &str, upto: usize, got: usize) -> Error { #[cfg(test)] pub mod test_util { - use crate::{Evaluation, PolynomialCommitmentScheme}; use ff_ext::ExtensionField; use itertools::{Itertools, chain}; - use multilinear_extensions::mle::{DenseMultilinearExtension, MultilinearExtension}; - use rand::{prelude::*, rngs::OsRng}; - use rand_chacha::ChaCha8Rng; + use multilinear_extensions::{ + mle::{DenseMultilinearExtension, MultilinearExtension}, + virtual_poly_v2::ArcMultilinearExtension, + }; + use rand::rngs::OsRng; use transcript::Transcript; + pub fn setup_pcs>( + num_vars: usize, + ) -> (Pcs::ProverParam, Pcs::VerifierParam) { + let poly_size = 1 << num_vars; + let param = Pcs::setup(poly_size).unwrap(); + Pcs::trim(param, poly_size).unwrap() + } + + pub fn gen_rand_poly( + num_vars: usize, + base: bool, + ) -> DenseMultilinearExtension { + if base { + DenseMultilinearExtension::random(num_vars, 
&mut OsRng) + } else { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..(1 << num_vars)) + .map(|_| E::random(&mut OsRng)) + .collect_vec(), + ) + } + } + + pub fn gen_rand_polys( + num_vars: impl Fn(usize) -> usize, + batch_size: usize, + base: bool, + ) -> Vec> { + (0..batch_size) + .map(|i| gen_rand_poly(num_vars(i), base)) + .collect_vec() + } + + pub fn get_point_from_challenge( + num_vars: usize, + transcript: &mut Transcript, + ) -> Vec { + (0..num_vars) + .map(|_| transcript.get_and_append_challenge(b"Point").elements) + .collect() + } + pub fn get_points_from_challenge( + num_vars: impl Fn(usize) -> usize, + num_points: usize, + transcript: &mut Transcript, + ) -> Vec> { + (0..num_points) + .map(|i| get_point_from_challenge(num_vars(i), transcript)) + .collect() + } + + pub fn commit_polys_individually>( + pp: &Pcs::ProverParam, + polys: &[DenseMultilinearExtension], + transcript: &mut Transcript, + ) -> Vec { + polys + .iter() + .map(|poly| Pcs::commit_and_write(pp, poly, transcript).unwrap()) + .collect_vec() + } + pub fn run_commit_open_verify( base: bool, num_vars_start: usize, @@ -376,30 +440,17 @@ pub mod test_util { Pcs: PolynomialCommitmentScheme, { for num_vars in num_vars_start..num_vars_end { - // Setup - let (pp, vp) = { - let poly_size = 1 << num_vars; - let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(param, poly_size).unwrap() - }; + let (pp, vp) = setup_pcs::(num_vars); + // Commit and open let (comm, eval, proof, challenge) = { let mut transcript = Transcript::new(b"BaseFold"); - let poly = if base { - DenseMultilinearExtension::random(num_vars, &mut OsRng) - } else { - DenseMultilinearExtension::from_evaluations_ext_vec( - num_vars, - (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), - ) - }; - + let poly = gen_rand_poly(num_vars, base); let comm = Pcs::commit_and_write(&pp, &poly, &mut transcript).unwrap(); - let point = (0..num_vars) - .map(|_| 
transcript.get_and_append_challenge(b"Point").elements) - .collect::>(); + let point = get_point_from_challenge(num_vars, &mut transcript); let eval = poly.evaluate(point.as_slice()); transcript.append_field_element_ext(&eval); + ( Pcs::get_pure_commitment(&comm), eval, @@ -408,21 +459,16 @@ pub mod test_util { ) }; // Verify - let result = { + { let mut transcript = Transcript::new(b"BaseFold"); Pcs::write_commitment(&comm, &mut transcript).unwrap(); - let point = (0..num_vars) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>(); + let point = get_point_from_challenge(num_vars, &mut transcript); transcript.append_field_element_ext(&eval); - let result = Pcs::verify(&vp, &comm, &point, &eval, &proof, &mut transcript); + Pcs::verify(&vp, &comm, &point, &eval, &proof, &mut transcript).unwrap(); let v_challenge = transcript.read_challenge(); assert_eq!(challenge, v_challenge); - - result - }; - result.unwrap(); + } } } @@ -437,13 +483,8 @@ pub mod test_util { for num_vars in num_vars_start..num_vars_end { let batch_size = 2; let num_points = batch_size >> 1; - let rng = ChaCha8Rng::from_seed([0u8; 32]); - // Setup - let (pp, vp) = { - let poly_size = 1 << num_vars; - let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(param, poly_size).unwrap() - }; + let (pp, vp) = setup_pcs::(num_vars); + // Batch commit and open let evals = chain![ (0..num_points).map(|point| (point * 2, point)), // Every point matches two polys @@ -452,34 +493,15 @@ pub mod test_util { .unique() .collect_vec(); - let (comms, points, evals, proof, challenge) = { + let (comms, evals, proof, challenge) = { let mut transcript = Transcript::new(b"BaseFold"); - let polys = (0..batch_size) - .map(|i| { - if base { - DenseMultilinearExtension::random(num_vars - (i >> 1), &mut rng.clone()) - } else { - DenseMultilinearExtension::from_evaluations_ext_vec( - num_vars, - (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), - ) - } - }) - .collect_vec(); + let polys = 
gen_rand_polys(|i| num_vars - (i >> 1), batch_size, base); - let comms = polys - .iter() - .map(|poly| Pcs::commit_and_write(&pp, poly, &mut transcript).unwrap()) - .collect_vec(); + let comms = + commit_polys_individually::(&pp, polys.as_slice(), &mut transcript); - let points = (0..num_points) - .map(|i| { - (0..num_vars - i) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>() - }) - .take(num_points) - .collect_vec(); + let points = + get_points_from_challenge(|i| num_vars - i, num_points, &mut transcript); let evals = evals .iter() @@ -499,10 +521,10 @@ pub mod test_util { let proof = Pcs::batch_open(&pp, &polys, &comms, &points, &evals, &mut transcript).unwrap(); - (comms, points, evals, proof, transcript.read_challenge()) + (comms, evals, proof, transcript.read_challenge()) }; // Batch verify - let result = { + { let mut transcript = Transcript::new(b"BaseFold"); let comms = comms .iter() @@ -513,16 +535,9 @@ pub mod test_util { }) .collect_vec(); - let old_points = points; - let points = (0..num_points) - .map(|i| { - (0..num_vars - i) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>() - }) - .take(num_points) - .collect_vec(); - assert_eq!(points, old_points); + let points = + get_points_from_challenge(|i| num_vars - i, num_points, &mut transcript); + let values: Vec = evals .iter() .map(Evaluation::value) @@ -530,14 +545,10 @@ pub mod test_util { .collect::>(); transcript.append_field_element_exts(values.as_slice()); - let result = - Pcs::batch_verify(&vp, &comms, &points, &evals, &proof, &mut transcript); + Pcs::batch_verify(&vp, &comms, &points, &evals, &proof, &mut transcript).unwrap(); let v_challenge = transcript.read_challenge(); assert_eq!(challenge, v_challenge); - result - }; - - result.unwrap(); + } } } @@ -551,52 +562,24 @@ pub mod test_util { Pcs: PolynomialCommitmentScheme, { for num_vars in num_vars_start..num_vars_end { - let rng = ChaCha8Rng::from_seed([0u8; 32]); - // Setup - 
let (pp, vp) = { - let poly_size = 1 << num_vars; - let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(param, poly_size).unwrap() - }; + let (pp, vp) = setup_pcs::(num_vars); let (comm, evals, proof, challenge) = { let mut transcript = Transcript::new(b"BaseFold"); - let polys = (0..batch_size) - .map(|_| { - if base { - DenseMultilinearExtension::random(num_vars, &mut rng.clone()) - } else { - DenseMultilinearExtension::from_evaluations_ext_vec( - num_vars, - (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), - ) - } - }) - .collect_vec(); - let comm = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap(); - - let point = (0..num_vars) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>(); + let polys = gen_rand_polys(|_| num_vars, batch_size, base); + let comm = + Pcs::batch_commit_and_write(&pp, polys.as_slice(), &mut transcript).unwrap(); + let point = get_point_from_challenge(num_vars, &mut transcript); + let evals = polys.iter().map(|poly| poly.evaluate(&point)).collect_vec(); + transcript.append_field_element_exts(&evals); - let evals = (0..batch_size) - .map(|i| polys[i].evaluate(&point)) + let polys = polys + .iter() + .map(|poly| ArcMultilinearExtension::from(poly.clone())) .collect_vec(); - - transcript.append_field_element_exts(&evals); - let proof = Pcs::simple_batch_open( - &pp, - polys - .into_iter() - .map(|x| x.into()) - .collect::>() - .as_slice(), - &comm, - &point, - &evals, - &mut transcript, - ) - .unwrap(); + let proof = + Pcs::simple_batch_open(&pp, &polys, &comm, &point, &evals, &mut transcript) + .unwrap(); ( Pcs::get_pure_commitment(&comm), evals, @@ -605,25 +588,19 @@ pub mod test_util { ) }; // Batch verify - let result = { + { let mut transcript = Transcript::new(b"BaseFold"); Pcs::write_commitment(&comm, &mut transcript).unwrap(); - let point = (0..num_vars) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>(); - + let point = 
get_point_from_challenge(num_vars, &mut transcript); transcript.append_field_element_exts(&evals); - let result = - Pcs::simple_batch_verify(&vp, &comm, &point, &evals, &proof, &mut transcript); + Pcs::simple_batch_verify(&vp, &comm, &point, &evals, &proof, &mut transcript) + .unwrap(); let v_challenge = transcript.read_challenge(); assert_eq!(challenge, v_challenge); - result - }; - - result.unwrap(); + } } } } From 3e5262c39b004b2b743e059a288b284c981037ef Mon Sep 17 00:00:00 2001 From: Cyte Zhang Date: Thu, 28 Nov 2024 14:06:34 +0800 Subject: [PATCH 13/21] Merge two similar benchmarks in BaseFold into one (#554) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extracting small PR from #294. The two benchmarks `benches_commit_verify_rs` and `benches_commit_verify_basecode` are similar, differing only in the choice of parameter. Merge them into one benchmark file. Also simplify the codes using the test utility functions introduced in #552 Waiting for #552 to merge. 
--------- Co-authored-by: Matthias Görgens --- mpcs/Cargo.toml | 6 +- .../{commit_open_verify_rs.rs => basefold.rs} | 245 +++++------ mpcs/benches/commit_open_verify_basecode.rs | 400 ------------------ mpcs/src/basefold.rs | 60 ++- mpcs/src/lib.rs | 65 +-- 5 files changed, 196 insertions(+), 580 deletions(-) rename mpcs/benches/{commit_open_verify_rs.rs => basefold.rs} (61%) delete mode 100644 mpcs/benches/commit_open_verify_basecode.rs diff --git a/mpcs/Cargo.toml b/mpcs/Cargo.toml index 26ee123ba..1100cb133 100644 --- a/mpcs/Cargo.toml +++ b/mpcs/Cargo.toml @@ -38,11 +38,7 @@ sanity-check = [] [[bench]] harness = false -name = "commit_open_verify_rs" - -[[bench]] -harness = false -name = "commit_open_verify_basecode" +name = "basefold" [[bench]] harness = false diff --git a/mpcs/benches/commit_open_verify_rs.rs b/mpcs/benches/basefold.rs similarity index 61% rename from mpcs/benches/commit_open_verify_rs.rs rename to mpcs/benches/basefold.rs index 1401f5127..965d55035 100644 --- a/mpcs/benches/commit_open_verify_rs.rs +++ b/mpcs/benches/basefold.rs @@ -1,12 +1,16 @@ use std::time::Duration; use criterion::*; -use ff::Field; +use ff_ext::ExtensionField; use goldilocks::GoldilocksExt2; use itertools::{Itertools, chain}; use mpcs::{ - Basefold, BasefoldRSParams, Evaluation, PolynomialCommitmentScheme, + Basefold, BasefoldBasecodeParams, BasefoldRSParams, Evaluation, PolynomialCommitmentScheme, + test_util::{ + commit_polys_individually, gen_rand_poly_base, gen_rand_poly_ext, gen_rand_polys, + get_point_from_challenge, get_points_from_challenge, setup_pcs, + }, util::plonky2_util::log2_ceil, }; @@ -14,12 +18,10 @@ use multilinear_extensions::{ mle::{DenseMultilinearExtension, MultilinearExtension}, virtual_poly_v2::ArcMultilinearExtension, }; -use rand::{SeedableRng, rngs::OsRng}; -use rand_chacha::ChaCha8Rng; -use rayon::iter::{IntoParallelIterator, ParallelIterator}; use transcript::Transcript; -type Pcs = Basefold; +type PcsGoldilocksRSCode = Basefold; +type 
PcsGoldilocksBasecode = Basefold; type T = Transcript; type E = GoldilocksExt2; @@ -29,10 +31,19 @@ const NUM_VARS_END: usize = 20; const BATCH_SIZE_LOG_START: usize = 6; const BATCH_SIZE_LOG_END: usize = 6; -fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { +struct Switch<'a, E: ExtensionField> { + name: &'a str, + gen_rand_poly: fn(usize) -> DenseMultilinearExtension, +} + +fn bench_commit_open_verify_goldilocks>( + c: &mut Criterion, + switch: Switch, + id: &str, +) { let mut group = c.benchmark_group(format!( - "commit_open_verify_goldilocks_rs_{}", - if is_base { "base" } else { "ext2" } + "commit_open_verify_goldilocks_{}_{}", + id, switch.name, )); group.sample_size(NUM_SAMPLES); // Challenge is over extension field, poly over the base field @@ -50,18 +61,7 @@ fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { }; let mut transcript = T::new(b"BaseFold"); - let poly = if is_base { - DenseMultilinearExtension::random(num_vars, &mut OsRng) - } else { - DenseMultilinearExtension::from_evaluations_ext_vec( - num_vars, - (0..1 << num_vars) - .into_par_iter() - .map(|_| E::random(&mut OsRng)) - .collect(), - ) - }; - + let poly = (switch.gen_rand_poly)(num_vars); let comm = Pcs::commit_and_write(&pp, &poly, &mut transcript).unwrap(); group.bench_function(BenchmarkId::new("commit", format!("{}", num_vars)), |b| { @@ -70,9 +70,7 @@ fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { }) }); - let point = (0..num_vars) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>(); + let point = get_point_from_challenge(num_vars, &mut transcript); let eval = poly.evaluate(point.as_slice()); transcript.append_field_element_ext(&eval); let transcript_for_bench = transcript.clone(); @@ -91,9 +89,7 @@ fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { let comm = Pcs::get_pure_commitment(&comm); let mut transcript = T::new(b"BaseFold"); 
Pcs::write_commitment(&comm, &mut transcript).unwrap(); - let point = (0..num_vars) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>(); + let point = get_point_from_challenge(num_vars, &mut transcript); transcript.append_field_element_ext(&eval); let transcript_for_bench = transcript.clone(); Pcs::verify(&vp, &comm, &point, &eval, &proof, &mut transcript).unwrap(); @@ -109,10 +105,24 @@ fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { } } -fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { +const BASE: Switch = Switch { + name: "base", + gen_rand_poly: gen_rand_poly_base, +}; + +const EXT: Switch = Switch { + name: "ext", + gen_rand_poly: gen_rand_poly_ext, +}; + +fn bench_batch_commit_open_verify_goldilocks>( + c: &mut Criterion, + switch: Switch, + id: &str, +) { let mut group = c.benchmark_group(format!( - "batch_commit_open_verify_goldilocks_rs_{}", - if is_base { "base" } else { "ext2" } + "batch_commit_open_verify_goldilocks_{}_{}", + id, switch.name, )); group.sample_size(NUM_SAMPLES); // Challenge is over extension field, poly over the base field @@ -120,13 +130,8 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { let batch_size = 1 << batch_size_log; let num_points = batch_size >> 1; - let rng = ChaCha8Rng::from_seed([0u8; 32]); // Setup - let (pp, vp) = { - let poly_size = 1 << num_vars; - let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(param, poly_size).unwrap() - }; + let (pp, vp) = setup_pcs::(num_vars); // Batch commit and open let evals = chain![ (0..num_points).map(|point| (point * 2, point)), // Every point matches two polys @@ -136,37 +141,18 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { .collect_vec(); let mut transcript = T::new(b"BaseFold"); - let polys = (0..batch_size) - .map(|i| { - if is_base { - 
DenseMultilinearExtension::random( - num_vars - log2_ceil((i >> 1) + 1), - &mut rng.clone(), - ) - } else { - DenseMultilinearExtension::from_evaluations_ext_vec( - num_vars - log2_ceil((i >> 1) + 1), - (0..1 << (num_vars - log2_ceil((i >> 1) + 1))) - .into_par_iter() - .map(|_| E::random(&mut OsRng)) - .collect(), - ) - } - }) - .collect_vec(); - let comms = polys - .iter() - .map(|poly| Pcs::commit_and_write(&pp, poly, &mut transcript).unwrap()) - .collect_vec(); + let polys = gen_rand_polys( + |i| num_vars - log2_ceil((i >> 1) + 1), + batch_size, + switch.gen_rand_poly, + ); + let comms = commit_polys_individually::(&pp, &polys, &mut transcript); - let points = (0..num_points) - .map(|i| { - (0..num_vars - i) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>() - }) - .take(num_points) - .collect_vec(); + let points = get_points_from_challenge( + |i| num_vars - log2_ceil(i + 1), + num_points, + &mut transcript, + ); let evals = evals .iter() @@ -175,11 +161,7 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { Evaluation::new(poly, point, polys[poly].evaluate(&points[point])) }) .collect_vec(); - let values: Vec = evals - .iter() - .map(Evaluation::value) - .copied() - .collect::>(); + let values: Vec = evals.iter().map(Evaluation::value).copied().collect(); transcript.append_field_element_exts(values.as_slice()); let transcript_for_bench = transcript.clone(); let proof = @@ -208,14 +190,11 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { comm }) .collect_vec(); - let points = (0..num_points) - .map(|i| { - (0..num_vars - i) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>() - }) - .take(num_points) - .collect_vec(); + let points = get_points_from_challenge( + |i| num_vars - log2_ceil(i + 1), + num_points, + &mut transcript, + ); let values: Vec = evals .iter() @@ -252,38 +231,23 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut 
Criterion, is_base: bool) { } } -fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { +fn bench_simple_batch_commit_open_verify_goldilocks>( + c: &mut Criterion, + switch: Switch, + id: &str, +) { let mut group = c.benchmark_group(format!( - "simple_batch_commit_open_verify_goldilocks_rs_{}", - if is_base { "base" } else { "extension" } + "simple_batch_commit_open_verify_goldilocks_{}_{}", + id, switch.name, )); group.sample_size(NUM_SAMPLES); // Challenge is over extension field, poly over the base field for num_vars in NUM_VARS_START..=NUM_VARS_END { for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { let batch_size = 1 << batch_size_log; - let rng = ChaCha8Rng::from_seed([0u8; 32]); - let (pp, vp) = { - let poly_size = 1 << num_vars; - let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(param, poly_size).unwrap() - }; + let (pp, vp) = setup_pcs::(num_vars); let mut transcript = T::new(b"BaseFold"); - let polys = (0..batch_size) - .map(|_| { - if is_base { - DenseMultilinearExtension::random(num_vars, &mut rng.clone()) - } else { - DenseMultilinearExtension::from_evaluations_ext_vec( - num_vars, - (0..1 << num_vars) - .into_par_iter() - .map(|_| E::random(&mut OsRng)) - .collect(), - ) - } - }) - .collect_vec(); + let polys = gen_rand_polys(|_| num_vars, batch_size, switch.gen_rand_poly); let comm = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap(); group.bench_function( @@ -294,20 +258,14 @@ fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: }) }, ); - - let polys: Vec> = - polys.into_iter().map(|poly| poly.into()).collect_vec(); - - let point = (0..num_vars) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>(); - - let evals = (0..batch_size) - .map(|i| polys[i].evaluate(&point)) - .collect_vec(); - + let point = get_point_from_challenge(num_vars, &mut transcript); + let evals = polys.iter().map(|poly| 
poly.evaluate(&point)).collect_vec(); transcript.append_field_element_exts(&evals); let transcript_for_bench = transcript.clone(); + let polys = polys + .iter() + .map(|poly| ArcMultilinearExtension::from(poly.clone())) + .collect::>(); let proof = Pcs::simple_batch_open(&pp, &polys, &comm, &point, &evals, &mut transcript) .unwrap(); @@ -337,9 +295,7 @@ fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: let mut transcript = Transcript::new(b"BaseFold"); Pcs::write_commitment(&comm, &mut transcript).unwrap(); - let point = (0..num_vars) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>(); + let point = get_point_from_challenge(num_vars, &mut transcript); transcript.append_field_element_exts(&evals); let backup_transcript = transcript.clone(); @@ -369,34 +325,61 @@ fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: } } -fn bench_commit_open_verify_goldilocks_2(c: &mut Criterion) { - bench_commit_open_verify_goldilocks(c, false); +fn bench_commit_open_verify_goldilocks_ext_rs(c: &mut Criterion) { + bench_commit_open_verify_goldilocks::(c, EXT, "rs"); +} + +fn bench_commit_open_verify_goldilocks_ext_basecode(c: &mut Criterion) { + bench_commit_open_verify_goldilocks::(c, EXT, "basecode"); +} + +fn bench_commit_open_verify_goldilocks_base_rs(c: &mut Criterion) { + bench_commit_open_verify_goldilocks::(c, BASE, "rs"); +} + +fn bench_commit_open_verify_goldilocks_base_basecode(c: &mut Criterion) { + bench_commit_open_verify_goldilocks::(c, BASE, "basecode"); +} + +fn bench_batch_commit_open_verify_goldilocks_ext_rs(c: &mut Criterion) { + bench_batch_commit_open_verify_goldilocks::(c, EXT, "rs"); +} + +fn bench_batch_commit_open_verify_goldilocks_ext_basecode(c: &mut Criterion) { + bench_batch_commit_open_verify_goldilocks::(c, EXT, "basecode"); +} + +fn bench_batch_commit_open_verify_goldilocks_base_rs(c: &mut Criterion) { + bench_batch_commit_open_verify_goldilocks::(c, BASE, "rs"); 
} -fn bench_commit_open_verify_goldilocks_base(c: &mut Criterion) { - bench_commit_open_verify_goldilocks(c, true); +fn bench_batch_commit_open_verify_goldilocks_base_basecode(c: &mut Criterion) { + bench_batch_commit_open_verify_goldilocks::(c, BASE, "basecode"); } -fn bench_batch_commit_open_verify_goldilocks_2(c: &mut Criterion) { - bench_batch_commit_open_verify_goldilocks(c, false); +fn bench_simple_batch_commit_open_verify_goldilocks_ext_rs(c: &mut Criterion) { + bench_simple_batch_commit_open_verify_goldilocks::(c, EXT, "rs"); } -fn bench_batch_commit_open_verify_goldilocks_base(c: &mut Criterion) { - bench_batch_commit_open_verify_goldilocks(c, true); +fn bench_simple_batch_commit_open_verify_goldilocks_ext_basecode(c: &mut Criterion) { + bench_simple_batch_commit_open_verify_goldilocks::(c, EXT, "basecode"); } -fn bench_simple_batch_commit_open_verify_goldilocks_2(c: &mut Criterion) { - bench_simple_batch_commit_open_verify_goldilocks(c, false); +fn bench_simple_batch_commit_open_verify_goldilocks_base_rs(c: &mut Criterion) { + bench_simple_batch_commit_open_verify_goldilocks::(c, BASE, "rs"); } -fn bench_simple_batch_commit_open_verify_goldilocks_base(c: &mut Criterion) { - bench_simple_batch_commit_open_verify_goldilocks(c, true); +fn bench_simple_batch_commit_open_verify_goldilocks_base_basecode(c: &mut Criterion) { + bench_simple_batch_commit_open_verify_goldilocks::(c, BASE, "basecode"); } criterion_group! 
{ name = bench_basefold; config = Criterion::default().warm_up_time(Duration::from_millis(3000)); - targets = bench_simple_batch_commit_open_verify_goldilocks_base, bench_simple_batch_commit_open_verify_goldilocks_2, bench_batch_commit_open_verify_goldilocks_base, bench_batch_commit_open_verify_goldilocks_2, bench_commit_open_verify_goldilocks_base, bench_commit_open_verify_goldilocks_2, + targets = + bench_simple_batch_commit_open_verify_goldilocks_base_rs, bench_simple_batch_commit_open_verify_goldilocks_ext_rs, + bench_batch_commit_open_verify_goldilocks_base_rs, bench_batch_commit_open_verify_goldilocks_ext_rs, bench_commit_open_verify_goldilocks_base_rs, bench_commit_open_verify_goldilocks_ext_rs, + bench_simple_batch_commit_open_verify_goldilocks_base_basecode, bench_simple_batch_commit_open_verify_goldilocks_ext_basecode, bench_batch_commit_open_verify_goldilocks_base_basecode, bench_batch_commit_open_verify_goldilocks_ext_basecode, bench_commit_open_verify_goldilocks_base_basecode, bench_commit_open_verify_goldilocks_ext_basecode, } criterion_main!(bench_basefold); diff --git a/mpcs/benches/commit_open_verify_basecode.rs b/mpcs/benches/commit_open_verify_basecode.rs deleted file mode 100644 index 91baa5f73..000000000 --- a/mpcs/benches/commit_open_verify_basecode.rs +++ /dev/null @@ -1,400 +0,0 @@ -use std::time::Duration; - -use criterion::*; -use ff::Field; -use goldilocks::GoldilocksExt2; - -use itertools::{Itertools, chain}; -use mpcs::{ - Basefold, BasefoldBasecodeParams, Evaluation, PolynomialCommitmentScheme, - util::plonky2_util::log2_ceil, -}; - -use multilinear_extensions::mle::{DenseMultilinearExtension, MultilinearExtension}; -use rand::{SeedableRng, rngs::OsRng}; -use rand_chacha::ChaCha8Rng; -use transcript::Transcript; - -type Pcs = Basefold; -type T = Transcript; -type E = GoldilocksExt2; - -const NUM_SAMPLES: usize = 10; -const NUM_VARS_START: usize = 20; -const NUM_VARS_END: usize = 20; -const BATCH_SIZE_LOG_START: usize = 6; -const 
BATCH_SIZE_LOG_END: usize = 6; - -fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { - let mut group = c.benchmark_group(format!( - "commit_open_verify_goldilocks_basecode_{}", - if is_base { "base" } else { "ext2" } - )); - group.sample_size(NUM_SAMPLES); - // Challenge is over extension field, poly over the base field - for num_vars in NUM_VARS_START..=NUM_VARS_END { - let (pp, vp) = { - let poly_size = 1 << num_vars; - let param = Pcs::setup(poly_size).unwrap(); - - group.bench_function(BenchmarkId::new("setup", format!("{}", num_vars)), |b| { - b.iter(|| { - Pcs::setup(poly_size).unwrap(); - }) - }); - Pcs::trim(param, poly_size).unwrap() - }; - - let mut transcript = T::new(b"BaseFold"); - let poly = if is_base { - DenseMultilinearExtension::random(num_vars, &mut OsRng) - } else { - DenseMultilinearExtension::from_evaluations_ext_vec( - num_vars, - (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), - ) - }; - - let comm = Pcs::commit_and_write(&pp, &poly, &mut transcript).unwrap(); - - group.bench_function(BenchmarkId::new("commit", format!("{}", num_vars)), |b| { - b.iter(|| { - Pcs::commit(&pp, &poly).unwrap(); - }) - }); - - let point = (0..num_vars) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>(); - let eval = poly.evaluate(point.as_slice()); - transcript.append_field_element_ext(&eval); - let transcript_for_bench = transcript.clone(); - let proof = Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap(); - - group.bench_function(BenchmarkId::new("open", format!("{}", num_vars)), |b| { - b.iter_batched( - || transcript_for_bench.clone(), - |mut transcript| { - Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap(); - }, - BatchSize::SmallInput, - ); - }); - // Verify - let comm = Pcs::get_pure_commitment(&comm); - let mut transcript = T::new(b"BaseFold"); - Pcs::write_commitment(&comm, &mut transcript).unwrap(); - let point = (0..num_vars) - .map(|_| 
transcript.get_and_append_challenge(b"Point").elements) - .collect::>(); - transcript.append_field_element_ext(&eval); - let transcript_for_bench = transcript.clone(); - Pcs::verify(&vp, &comm, &point, &eval, &proof, &mut transcript).unwrap(); - group.bench_function(BenchmarkId::new("verify", format!("{}", num_vars)), |b| { - b.iter_batched( - || transcript_for_bench.clone(), - |mut transcript| { - Pcs::verify(&vp, &comm, &point, &eval, &proof, &mut transcript).unwrap(); - }, - BatchSize::SmallInput, - ); - }); - } -} - -fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { - let mut group = c.benchmark_group(format!( - "batch_commit_open_verify_goldilocks_basecode_{}", - if is_base { "base" } else { "ext2" } - )); - group.sample_size(NUM_SAMPLES); - // Challenge is over extension field, poly over the base field - for num_vars in NUM_VARS_START..=NUM_VARS_END { - for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { - let batch_size = 1 << batch_size_log; - let num_points = batch_size >> 1; - let rng = ChaCha8Rng::from_seed([0u8; 32]); - // Setup - let (pp, vp) = { - let poly_size = 1 << num_vars; - let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(param, poly_size).unwrap() - }; - // Batch commit and open - let evals = chain![ - (0..num_points).map(|point| (point * 2, point)), // Every point matches two polys - (0..num_points).map(|point| (point * 2 + 1, point)), - ] - .unique() - .collect_vec(); - - let mut transcript = T::new(b"BaseFold"); - let polys = (0..batch_size) - .map(|i| { - if is_base { - DenseMultilinearExtension::random( - num_vars - log2_ceil((i >> 1) + 1), - &mut rng.clone(), - ) - } else { - DenseMultilinearExtension::from_evaluations_ext_vec( - num_vars - log2_ceil((i >> 1) + 1), - (0..1 << (num_vars - log2_ceil((i >> 1) + 1))) - .map(|_| E::random(&mut OsRng)) - .collect(), - ) - } - }) - .collect_vec(); - let comms = polys - .iter() - .map(|poly| Pcs::commit_and_write(&pp, poly, &mut 
transcript).unwrap()) - .collect_vec(); - - let points = (0..num_points) - .map(|i| { - (0..num_vars - i) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>() - }) - .take(num_points) - .collect_vec(); - - let evals = evals - .iter() - .copied() - .map(|(poly, point)| { - Evaluation::new(poly, point, polys[poly].evaluate(&points[point])) - }) - .collect_vec(); - let values: Vec = evals - .iter() - .map(Evaluation::value) - .copied() - .collect::>(); - transcript.append_field_element_exts(values.as_slice()); - let transcript_for_bench = transcript.clone(); - let proof = - Pcs::batch_open(&pp, &polys, &comms, &points, &evals, &mut transcript).unwrap(); - - group.bench_function( - BenchmarkId::new("batch_open", format!("{}-{}", num_vars, batch_size)), - |b| { - b.iter_batched( - || transcript_for_bench.clone(), - |mut transcript| { - Pcs::batch_open(&pp, &polys, &comms, &points, &evals, &mut transcript) - .unwrap(); - }, - BatchSize::SmallInput, - ); - }, - ); - // Batch verify - let mut transcript = T::new(b"BaseFold"); - let comms = comms - .iter() - .map(|comm| { - let comm = Pcs::get_pure_commitment(comm); - Pcs::write_commitment(&comm, &mut transcript).unwrap(); - comm - }) - .collect_vec(); - let points = (0..num_points) - .map(|i| { - (0..num_vars - i) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>() - }) - .take(num_points) - .collect_vec(); - - let values: Vec = evals - .iter() - .map(Evaluation::value) - .copied() - .collect::>(); - transcript.append_field_element_exts(values.as_slice()); - - let backup_transcript = transcript.clone(); - - Pcs::batch_verify(&vp, &comms, &points, &evals, &proof, &mut transcript).unwrap(); - - group.bench_function( - BenchmarkId::new("batch_verify", format!("{}-{}", num_vars, batch_size)), - |b| { - b.iter_batched( - || backup_transcript.clone(), - |mut transcript| { - Pcs::batch_verify( - &vp, - &comms, - &points, - &evals, - &proof, - &mut transcript, - ) - 
.unwrap(); - }, - BatchSize::SmallInput, - ); - }, - ); - } - } -} - -fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { - let mut group = c.benchmark_group(format!( - "simple_batch_commit_open_verify_goldilocks_basecode_{}", - if is_base { "base" } else { "extension" } - )); - group.sample_size(NUM_SAMPLES); - // Challenge is over extension field, poly over the base field - for num_vars in NUM_VARS_START..=NUM_VARS_END { - for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { - let batch_size = 1 << batch_size_log; - let rng = ChaCha8Rng::from_seed([0u8; 32]); - let (pp, vp) = { - let poly_size = 1 << num_vars; - let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(param, poly_size).unwrap() - }; - let mut transcript = T::new(b"BaseFold"); - let polys = (0..batch_size) - .map(|_| { - if is_base { - DenseMultilinearExtension::random(num_vars, &mut rng.clone()) - } else { - DenseMultilinearExtension::from_evaluations_ext_vec( - num_vars, - (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), - ) - } - }) - .collect_vec(); - let comm = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap(); - - group.bench_function( - BenchmarkId::new("batch_commit", format!("{}-{}", num_vars, batch_size)), - |b| { - b.iter(|| { - Pcs::batch_commit(&pp, &polys).unwrap(); - }) - }, - ); - - let point = (0..num_vars) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>(); - - let evals = (0..batch_size) - .map(|i| polys[i].evaluate(&point)) - .collect_vec(); - - transcript.append_field_element_exts(&evals); - let transcript_for_bench = transcript.clone(); - let polys = polys - .clone() - .into_iter() - .map(|x| x.into()) - .collect::>(); - let proof = Pcs::simple_batch_open( - &pp, - polys.as_slice(), - &comm, - &point, - &evals, - &mut transcript, - ) - .unwrap(); - - group.bench_function( - BenchmarkId::new("batch_open", format!("{}-{}", num_vars, batch_size)), - |b| { - 
b.iter_batched( - || transcript_for_bench.clone(), - |mut transcript| { - Pcs::simple_batch_open( - &pp, - polys.as_slice(), - &comm, - &point, - &evals, - &mut transcript, - ) - .unwrap(); - }, - BatchSize::SmallInput, - ); - }, - ); - let comm = Pcs::get_pure_commitment(&comm); - - // Batch verify - let mut transcript = Transcript::new(b"BaseFold"); - Pcs::write_commitment(&comm, &mut transcript).unwrap(); - - let point = (0..num_vars) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>(); - transcript.append_field_element_exts(&evals); - let backup_transcript = transcript.clone(); - - Pcs::simple_batch_verify(&vp, &comm, &point, &evals, &proof, &mut transcript).unwrap(); - - group.bench_function( - BenchmarkId::new("batch_verify", format!("{}-{}", num_vars, batch_size)), - |b| { - b.iter_batched( - || backup_transcript.clone(), - |mut transcript| { - Pcs::simple_batch_verify( - &vp, - &comm, - &point, - &evals, - &proof, - &mut transcript, - ) - .unwrap(); - }, - BatchSize::SmallInput, - ); - }, - ); - } - } -} - -fn bench_commit_open_verify_goldilocks_2(c: &mut Criterion) { - bench_commit_open_verify_goldilocks(c, false); -} - -fn bench_commit_open_verify_goldilocks_base(c: &mut Criterion) { - bench_commit_open_verify_goldilocks(c, true); -} - -fn bench_batch_commit_open_verify_goldilocks_2(c: &mut Criterion) { - bench_batch_commit_open_verify_goldilocks(c, false); -} - -fn bench_batch_commit_open_verify_goldilocks_base(c: &mut Criterion) { - bench_batch_commit_open_verify_goldilocks(c, true); -} - -fn bench_simple_batch_commit_open_verify_goldilocks_2(c: &mut Criterion) { - bench_simple_batch_commit_open_verify_goldilocks(c, false); -} - -fn bench_simple_batch_commit_open_verify_goldilocks_base(c: &mut Criterion) { - bench_simple_batch_commit_open_verify_goldilocks(c, true); -} - -criterion_group! 
{ - name = bench_basefold; - config = Criterion::default().warm_up_time(Duration::from_millis(3000)); - targets = bench_simple_batch_commit_open_verify_goldilocks_base, bench_simple_batch_commit_open_verify_goldilocks_2,bench_batch_commit_open_verify_goldilocks_base, bench_batch_commit_open_verify_goldilocks_2, bench_commit_open_verify_goldilocks_base, bench_commit_open_verify_goldilocks_2, -} - -criterion_main!(bench_basefold); diff --git a/mpcs/src/basefold.rs b/mpcs/src/basefold.rs index be59f42c7..713bd2ec3 100644 --- a/mpcs/src/basefold.rs +++ b/mpcs/src/basefold.rs @@ -1179,8 +1179,8 @@ mod test { use crate::{ basefold::Basefold, test_util::{ - run_batch_commit_open_verify, run_commit_open_verify, - run_simple_batch_commit_open_verify, + gen_rand_poly_base, gen_rand_poly_ext, run_batch_commit_open_verify, + run_commit_open_verify, run_simple_batch_commit_open_verify, }, }; use goldilocks::GoldilocksExt2; @@ -1192,52 +1192,78 @@ mod test { #[test] fn commit_open_verify_goldilocks() { - for base in [true, false] { + for gen_rand_poly in [gen_rand_poly_base, gen_rand_poly_ext] { // Challenge is over extension field, poly over the base field - run_commit_open_verify::(base, 10, 11); + run_commit_open_verify::(gen_rand_poly, 10, 11); // Test trivial proof with small num vars - run_commit_open_verify::(base, 4, 6); + run_commit_open_verify::(gen_rand_poly, 4, 6); // Challenge is over extension field, poly over the base field - run_commit_open_verify::(base, 10, 11); + run_commit_open_verify::(gen_rand_poly, 10, 11); // Test trivial proof with small num vars - run_commit_open_verify::(base, 4, 6); + run_commit_open_verify::(gen_rand_poly, 4, 6); } } #[test] fn simple_batch_commit_open_verify_goldilocks() { - for base in [true, false] { + for gen_rand_poly in [gen_rand_poly_base, gen_rand_poly_ext] { // Both challenge and poly are over base field run_simple_batch_commit_open_verify::( - base, 10, 11, 1, + gen_rand_poly, + 10, + 11, + 1, ); 
run_simple_batch_commit_open_verify::( - base, 10, 11, 4, + gen_rand_poly, + 10, + 11, + 4, ); // Test trivial proof with small num vars run_simple_batch_commit_open_verify::( - base, 4, 6, 4, + gen_rand_poly, + 4, + 6, + 4, ); // Both challenge and poly are over base field run_simple_batch_commit_open_verify::( - base, 10, 11, 1, + gen_rand_poly, + 10, + 11, + 1, ); run_simple_batch_commit_open_verify::( - base, 10, 11, 4, + gen_rand_poly, + 10, + 11, + 4, ); // Test trivial proof with small num vars run_simple_batch_commit_open_verify::( - base, 4, 6, 4, + gen_rand_poly, + 4, + 6, + 4, ); } } #[test] fn batch_commit_open_verify() { - for base in [true, false] { + for gen_rand_poly in [gen_rand_poly_base, gen_rand_poly_ext] { // Both challenge and poly are over base field - run_batch_commit_open_verify::(base, 10, 11); - run_batch_commit_open_verify::(base, 10, 11); + run_batch_commit_open_verify::( + gen_rand_poly, + 10, + 11, + ); + run_batch_commit_open_verify::( + gen_rand_poly, + 10, + 11, + ); } } } diff --git a/mpcs/src/lib.rs b/mpcs/src/lib.rs index 3bf1310f6..a97dd0942 100644 --- a/mpcs/src/lib.rs +++ b/mpcs/src/lib.rs @@ -357,14 +357,25 @@ fn err_too_many_variates(function: &str, upto: usize, got: usize) -> Error { }) } -#[cfg(test)] +// TODO: Need to use some functions here in the integration benchmarks. But +// unfortunately integration benchmarks do not compile the #[cfg(test)] +// code. So remove the gate for the entire module, only gate the test +// functions. +// This is not the best way: the test utility functions should not be +// compiled in the release build. Need a better solution. 
+#[doc(hidden)] pub mod test_util { - use crate::{Evaluation, PolynomialCommitmentScheme}; + #[cfg(test)] + use crate::Evaluation; + use crate::PolynomialCommitmentScheme; use ff_ext::ExtensionField; - use itertools::{Itertools, chain}; + use itertools::Itertools; + #[cfg(test)] + use itertools::chain; + use multilinear_extensions::mle::DenseMultilinearExtension; + #[cfg(test)] use multilinear_extensions::{ - mle::{DenseMultilinearExtension, MultilinearExtension}, - virtual_poly_v2::ArcMultilinearExtension, + mle::MultilinearExtension, virtual_poly_v2::ArcMultilinearExtension, }; use rand::rngs::OsRng; use transcript::Transcript; @@ -377,29 +388,26 @@ pub mod test_util { Pcs::trim(param, poly_size).unwrap() } - pub fn gen_rand_poly( - num_vars: usize, - base: bool, - ) -> DenseMultilinearExtension { - if base { - DenseMultilinearExtension::random(num_vars, &mut OsRng) - } else { - DenseMultilinearExtension::from_evaluations_ext_vec( - num_vars, - (0..(1 << num_vars)) - .map(|_| E::random(&mut OsRng)) - .collect_vec(), - ) - } + pub fn gen_rand_poly_base(num_vars: usize) -> DenseMultilinearExtension { + DenseMultilinearExtension::random(num_vars, &mut OsRng) + } + + pub fn gen_rand_poly_ext(num_vars: usize) -> DenseMultilinearExtension { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..(1 << num_vars)) + .map(|_| E::random(&mut OsRng)) + .collect_vec(), + ) } pub fn gen_rand_polys( num_vars: impl Fn(usize) -> usize, batch_size: usize, - base: bool, + gen_rand_poly: fn(usize) -> DenseMultilinearExtension, ) -> Vec> { (0..batch_size) - .map(|i| gen_rand_poly(num_vars(i), base)) + .map(|i| gen_rand_poly(num_vars(i))) .collect_vec() } @@ -432,8 +440,9 @@ pub mod test_util { .collect_vec() } + #[cfg(test)] pub fn run_commit_open_verify( - base: bool, + gen_rand_poly: fn(usize) -> DenseMultilinearExtension, num_vars_start: usize, num_vars_end: usize, ) where @@ -445,7 +454,7 @@ pub mod test_util { // Commit and open let (comm, eval, proof, 
challenge) = { let mut transcript = Transcript::new(b"BaseFold"); - let poly = gen_rand_poly(num_vars, base); + let poly = gen_rand_poly(num_vars); let comm = Pcs::commit_and_write(&pp, &poly, &mut transcript).unwrap(); let point = get_point_from_challenge(num_vars, &mut transcript); let eval = poly.evaluate(point.as_slice()); @@ -472,8 +481,9 @@ pub mod test_util { } } + #[cfg(test)] pub fn run_batch_commit_open_verify( - base: bool, + gen_rand_poly: fn(usize) -> DenseMultilinearExtension, num_vars_start: usize, num_vars_end: usize, ) where @@ -495,7 +505,7 @@ pub mod test_util { let (comms, evals, proof, challenge) = { let mut transcript = Transcript::new(b"BaseFold"); - let polys = gen_rand_polys(|i| num_vars - (i >> 1), batch_size, base); + let polys = gen_rand_polys(|i| num_vars - (i >> 1), batch_size, gen_rand_poly); let comms = commit_polys_individually::(&pp, polys.as_slice(), &mut transcript); @@ -552,8 +562,9 @@ pub mod test_util { } } + #[cfg(test)] pub(super) fn run_simple_batch_commit_open_verify( - base: bool, + gen_rand_poly: fn(usize) -> DenseMultilinearExtension, num_vars_start: usize, num_vars_end: usize, batch_size: usize, @@ -566,7 +577,7 @@ pub mod test_util { let (comm, evals, proof, challenge) = { let mut transcript = Transcript::new(b"BaseFold"); - let polys = gen_rand_polys(|_| num_vars, batch_size, base); + let polys = gen_rand_polys(|_| num_vars, batch_size, gen_rand_poly); let comm = Pcs::batch_commit_and_write(&pp, polys.as_slice(), &mut transcript).unwrap(); let point = get_point_from_challenge(num_vars, &mut transcript); From 1f5b990d934daca4cb1912681419df32fdaadbf2 Mon Sep 17 00:00:00 2001 From: Cyte Zhang Date: Thu, 28 Nov 2024 15:00:06 +0800 Subject: [PATCH 14/21] BaseFold: Add and reimplement some utility functions. (#559) Extracting small PRs from #294 There was an implementation of Iterator for iterating through a `FieldType`. It turns out that it can be replaced with existing tools in `itertools`.
This proposes the new implementation. Also added some other utility functions for type conversion between `FieldType` and `Vec`, to avoid `match field_type` everywhere. --- mpcs/src/util.rs | 55 +++++++++++++++++++++++++++--------------------- 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/mpcs/src/util.rs b/mpcs/src/util.rs index 80ecf2535..7688b53ec 100644 --- a/mpcs/src/util.rs +++ b/mpcs/src/util.rs @@ -6,7 +6,7 @@ pub mod plonky2_util; use ff::{Field, PrimeField}; use ff_ext::ExtensionField; use goldilocks::SmallField; -use itertools::{Itertools, izip}; +use itertools::{Either, Itertools, izip}; use multilinear_extensions::mle::{DenseMultilinearExtension, FieldType}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; pub mod merkle_tree; @@ -148,38 +148,45 @@ pub fn field_type_index_set_ext( } } -pub struct FieldTypeIterExt<'a, E: ExtensionField> { - inner: &'a FieldType, - index: usize, +pub fn poly_iter_ext( + poly: &DenseMultilinearExtension, +) -> impl Iterator + '_ { + field_type_iter_ext(&poly.evaluations) } -impl<'a, E: ExtensionField> Iterator for FieldTypeIterExt<'a, E> { - type Item = E; +pub fn field_type_iter_ext( + evaluations: &FieldType, +) -> impl Iterator + '_ { + match evaluations { + FieldType::Ext(coeffs) => Either::Left(coeffs.iter().copied()), + FieldType::Base(coeffs) => Either::Right(coeffs.iter().map(|x| (*x).into())), + _ => unreachable!(), + } +} - fn next(&mut self) -> Option { - if self.index >= self.inner.len() { - None - } else { - let res = field_type_index_ext(self.inner, self.index); - self.index += 1; - Some(res) - } +pub fn field_type_to_ext_vec(evaluations: &FieldType) -> Vec { + match evaluations { + FieldType::Ext(coeffs) => coeffs.to_vec(), + FieldType::Base(coeffs) => coeffs.iter().map(|&x| x.into()).collect(), + _ => unreachable!(), } } -pub fn poly_iter_ext( - poly: &DenseMultilinearExtension, -) -> FieldTypeIterExt { - FieldTypeIterExt { - inner: &poly.evaluations, - index: 0, +pub fn 
field_type_as_ext(values: &FieldType) -> &Vec { + match values { + FieldType::Ext(coeffs) => coeffs, + FieldType::Base(_) => panic!("Expected ext field"), + _ => unreachable!(), } } -pub fn field_type_iter_ext(evaluations: &FieldType) -> FieldTypeIterExt { - FieldTypeIterExt { - inner: evaluations, - index: 0, +pub fn field_type_iter_base( + values: &FieldType, +) -> impl Iterator + '_ { + match values { + FieldType::Ext(coeffs) => Either::Left(coeffs.iter().flat_map(|x| x.as_bases())), + FieldType::Base(coeffs) => Either::Right(coeffs.iter()), + _ => unreachable!(), } } From b729caf160d3e6fe9b0f52ca2f550cf742508828 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20G=C3=B6rgens?= Date: Thu, 28 Nov 2024 15:27:35 +0800 Subject: [PATCH 15/21] Enable clippy check for `Cargo.toml` files (#401) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit One of the benefits is that we get warned when we depend on multiple versions of the same crate. That's inefficient and can lead to subtle errors. As an example, you can run this to see the different versions of `syn`: ```console $ cargo tree --invert syn error: There are multiple `syn` packages in your project, and the specification `syn` is ambiguous.
Please re-run this command with one of the following specifications: syn@1.0.109 syn@2.0.79 ``` And then be more specific to see the offenders that still depend on the old version of `syn`: ```console $ cargo tree --invert syn@1 syn v1.0.109 └── unroll v0.1.5 (proc-macro) ├── plonky2 v0.2.2 │ └── mpcs v0.1.0 (/home/matthias/scroll/prog/ceno6/mpcs) │ └── ceno_zkvm v0.1.0 (/home/matthias/scroll/prog/ceno6/ceno_zkvm) │ [dev-dependencies] │ └── poseidon v0.1.0 (/home/matthias/scroll/prog/ceno6/poseidon) │ ├── ff_ext v0.1.0 (/home/matthias/scroll/prog/ceno6/ff_ext) │ │ ├── ceno_zkvm v0.1.0 (/home/matthias/scroll/prog/ceno6/ceno_zkvm) │ │ ├── mpcs v0.1.0 (/home/matthias/scroll/prog/ceno6/mpcs) (*) │ │ ├── multilinear_extensions v0.1.0 (/home/matthias/scroll/prog/ceno6/multilinear_extensions) │ │ │ ├── ceno_zkvm v0.1.0 (/home/matthias/scroll/prog/ceno6/ceno_zkvm) │ │ │ ├── mpcs v0.1.0 (/home/matthias/scroll/prog/ceno6/mpcs) (*) │ │ │ └── sumcheck v0.1.0 (/home/matthias/scroll/prog/ceno6/sumcheck) │ │ │ └── ceno_zkvm v0.1.0 (/home/matthias/scroll/prog/ceno6/ceno_zkvm) │ │ ├── sumcheck v0.1.0 (/home/matthias/scroll/prog/ceno6/sumcheck) (*) │ │ └── transcript v0.1.0 (/home/matthias/scroll/prog/ceno6/transcript) │ │ ├── ceno_zkvm v0.1.0 (/home/matthias/scroll/prog/ceno6/ceno_zkvm) │ │ ├── mpcs v0.1.0 (/home/matthias/scroll/prog/ceno6/mpcs) (*) │ │ └── sumcheck v0.1.0 (/home/matthias/scroll/prog/ceno6/sumcheck) (*) │ ├── mpcs v0.1.0 (/home/matthias/scroll/prog/ceno6/mpcs) (*) │ └── transcript v0.1.0 (/home/matthias/scroll/prog/ceno6/transcript) (*) ├── plonky2_field v0.2.2 │ └── plonky2 v0.2.2 (*) └── poseidon v0.1.0 (/home/matthias/scroll/prog/ceno6/poseidon) (*) ``` For now I've just whitelisted the existing offenders, but in the future a new PR can work on shrinking that list. 
--------- Co-authored-by: Zhang Zhuo --- Cargo.toml | 4 ++++ ceno_emul/Cargo.toml | 5 +++++ ceno_emul/src/lib.rs | 1 + ceno_rt/Cargo.toml | 5 +++++ ceno_rt/src/lib.rs | 1 + ceno_zkvm/Cargo.toml | 5 +++++ ceno_zkvm/src/lib.rs | 1 + clippy.toml | 10 ++++++++++ examples-builder/Cargo.toml | 4 ++++ examples-builder/src/lib.rs | 1 + examples/Cargo.toml | 1 + ff_ext/Cargo.toml | 5 +++++ ff_ext/src/lib.rs | 1 + mpcs/Cargo.toml | 5 +++++ mpcs/src/lib.rs | 1 + multilinear_extensions/Cargo.toml | 5 +++++ multilinear_extensions/src/lib.rs | 1 + poseidon/Cargo.toml | 5 +++++ poseidon/src/lib.rs | 1 + sumcheck/Cargo.toml | 5 +++++ sumcheck/src/lib.rs | 1 + transcript/Cargo.toml | 5 +++++ transcript/src/lib.rs | 1 + 23 files changed, 74 insertions(+) create mode 100644 clippy.toml diff --git a/Cargo.toml b/Cargo.toml index 4a638dad0..daa44f156 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,8 +14,12 @@ members = [ resolver = "2" [workspace.package] +categories = ["cryptography", "zk", "blockchain", "ceno"] edition = "2021" +keywords = ["cryptography", "zk", "blockchain", "ceno"] license = "MIT OR Apache-2.0" +readme = "README.md" +repository = "https://github.com/scroll-tech/ceno" version = "0.1.0" [workspace.dependencies] diff --git a/ceno_emul/Cargo.toml b/ceno_emul/Cargo.toml index 6c172d97f..38f0a8bfd 100644 --- a/ceno_emul/Cargo.toml +++ b/ceno_emul/Cargo.toml @@ -1,7 +1,12 @@ [package] +categories.workspace = true +description = "A Risc-V emulator for Ceno" edition.workspace = true +keywords.workspace = true license.workspace = true name = "ceno_emul" +readme.workspace = true +repository.workspace = true version.workspace = true [dependencies] diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index d8ec28ab0..c734b1794 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -1,3 +1,4 @@ +#![deny(clippy::cargo)] mod addr; pub use addr::*; diff --git a/ceno_rt/Cargo.toml b/ceno_rt/Cargo.toml index 505f67dd8..dfdc87ad2 100644 --- a/ceno_rt/Cargo.toml +++ 
b/ceno_rt/Cargo.toml @@ -1,7 +1,12 @@ [package] +categories.workspace = true +description = "Ceno runtime library" edition.workspace = true +keywords.workspace = true license.workspace = true name = "ceno_rt" +readme = "README.md" +repository.workspace = true version.workspace = true [dependencies] diff --git a/ceno_rt/src/lib.rs b/ceno_rt/src/lib.rs index 07c53a2bc..8de456c41 100644 --- a/ceno_rt/src/lib.rs +++ b/ceno_rt/src/lib.rs @@ -1,3 +1,4 @@ +#![deny(clippy::cargo)] #![feature(strict_overflow_ops)] #![no_std] diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index a6cf65cd4..13f4f2810 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -1,7 +1,12 @@ [package] +categories.workspace = true +description = "Ceno ZKVM" edition.workspace = true +keywords.workspace = true license.workspace = true name = "ceno_zkvm" +readme.workspace = true +repository.workspace = true version.workspace = true [dependencies] diff --git a/ceno_zkvm/src/lib.rs b/ceno_zkvm/src/lib.rs index a3c2ff02f..35013d448 100644 --- a/ceno_zkvm/src/lib.rs +++ b/ceno_zkvm/src/lib.rs @@ -1,3 +1,4 @@ +#![deny(clippy::cargo)] #![feature(box_patterns)] #![feature(stmt_expr_attributes)] #![feature(variant_count)] diff --git a/clippy.toml b/clippy.toml new file mode 100644 index 000000000..6e64e6f38 --- /dev/null +++ b/clippy.toml @@ -0,0 +1,10 @@ +# TODO(Matthias): review and see which exception we can remove over time. +# Eg removing syn is blocked by ark-ff-asm cutting a new release +# (https://github.com/arkworks-rs/algebra/issues/813) amongst other things. 
+allowed-duplicate-crates = [ + "syn", + "windows-sys", + "regex-automata", + "regex-syntax", + "itertools", +] diff --git a/examples-builder/Cargo.toml b/examples-builder/Cargo.toml index 6862e806b..00104ec3c 100644 --- a/examples-builder/Cargo.toml +++ b/examples-builder/Cargo.toml @@ -1,5 +1,9 @@ [package] +categories.workspace = true +description = "Build scripts for ceno examples" edition.workspace = true +keywords.workspace = true license.workspace = true name = "ceno-examples" +repository.workspace = true version.workspace = true diff --git a/examples-builder/src/lib.rs b/examples-builder/src/lib.rs index fdb344ff9..430c4d1de 100644 --- a/examples-builder/src/lib.rs +++ b/examples-builder/src/lib.rs @@ -1 +1,2 @@ +#![deny(clippy::cargo)] include!(concat!(env!("OUT_DIR"), "/vars.rs")); diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 08082d0bd..85c975ac1 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -1,6 +1,7 @@ [package] edition = "2021" name = "examples" +readme = "README.md" resolver = "2" version = "0.1.0" diff --git a/ff_ext/Cargo.toml b/ff_ext/Cargo.toml index 5c6a0d937..3b55f3581 100644 --- a/ff_ext/Cargo.toml +++ b/ff_ext/Cargo.toml @@ -1,7 +1,12 @@ [package] +categories.workspace = true +description = "Extra functionality for the ff ((finite fields) crate" edition.workspace = true +keywords.workspace = true license.workspace = true name = "ff_ext" +readme.workspace = true +repository.workspace = true version.workspace = true [dependencies] diff --git a/ff_ext/src/lib.rs b/ff_ext/src/lib.rs index ba34bfa4b..32d77a565 100644 --- a/ff_ext/src/lib.rs +++ b/ff_ext/src/lib.rs @@ -1,3 +1,4 @@ +#![deny(clippy::cargo)] pub use ff; use ff::FromUniformBytes; use goldilocks::SmallField; diff --git a/mpcs/Cargo.toml b/mpcs/Cargo.toml index 1100cb133..f977328cc 100644 --- a/mpcs/Cargo.toml +++ b/mpcs/Cargo.toml @@ -1,7 +1,12 @@ [package] +categories.workspace = true +description = "Multilinear Polynomial Commitment Scheme" 
edition.workspace = true +keywords.workspace = true license.workspace = true name = "mpcs" +readme.workspace = true +repository.workspace = true version.workspace = true [dependencies] diff --git a/mpcs/src/lib.rs b/mpcs/src/lib.rs index a97dd0942..b6716c19e 100644 --- a/mpcs/src/lib.rs +++ b/mpcs/src/lib.rs @@ -1,3 +1,4 @@ +#![deny(clippy::cargo)] use ff_ext::ExtensionField; use itertools::Itertools; use multilinear_extensions::mle::DenseMultilinearExtension; diff --git a/multilinear_extensions/Cargo.toml b/multilinear_extensions/Cargo.toml index ad364def1..1a8777641 100644 --- a/multilinear_extensions/Cargo.toml +++ b/multilinear_extensions/Cargo.toml @@ -1,7 +1,12 @@ [package] +categories.workspace = true +description = "Multilinear extensions for the Ceno project" edition.workspace = true +keywords.workspace = true license.workspace = true name = "multilinear_extensions" +readme.workspace = true +repository.workspace = true version.workspace = true [dependencies] diff --git a/multilinear_extensions/src/lib.rs b/multilinear_extensions/src/lib.rs index 7f0a0c089..9f669e348 100644 --- a/multilinear_extensions/src/lib.rs +++ b/multilinear_extensions/src/lib.rs @@ -1,3 +1,4 @@ +#![deny(clippy::cargo)] pub mod mle; pub mod util; pub mod virtual_poly; diff --git a/poseidon/Cargo.toml b/poseidon/Cargo.toml index 489f4efc1..eff0f50b7 100644 --- a/poseidon/Cargo.toml +++ b/poseidon/Cargo.toml @@ -1,7 +1,12 @@ [package] +categories.workspace = true +description = "Poseidon hash function" edition.workspace = true +keywords.workspace = true license.workspace = true name = "poseidon" +readme.workspace = true +repository.workspace = true version.workspace = true [dependencies] diff --git a/poseidon/src/lib.rs b/poseidon/src/lib.rs index 31ed313ad..17db28f72 100644 --- a/poseidon/src/lib.rs +++ b/poseidon/src/lib.rs @@ -1,3 +1,4 @@ +#![deny(clippy::cargo)] extern crate core; pub(crate) mod constants; diff --git a/sumcheck/Cargo.toml b/sumcheck/Cargo.toml index 
540495c0a..092187cba 100644 --- a/sumcheck/Cargo.toml +++ b/sumcheck/Cargo.toml @@ -1,7 +1,12 @@ [package] +categories.workspace = true +description = "Sumcheck protocol implementation" edition.workspace = true +keywords.workspace = true license.workspace = true name = "sumcheck" +readme.workspace = true +repository.workspace = true version.workspace = true [dependencies] diff --git a/sumcheck/src/lib.rs b/sumcheck/src/lib.rs index 14ed79aed..0d0b95adf 100644 --- a/sumcheck/src/lib.rs +++ b/sumcheck/src/lib.rs @@ -1,3 +1,4 @@ +#![deny(clippy::cargo)] #[cfg(feature = "non_pow2_rayon_thread")] pub mod local_thread_pool; mod macros; diff --git a/transcript/Cargo.toml b/transcript/Cargo.toml index 4769afe00..f784b689f 100644 --- a/transcript/Cargo.toml +++ b/transcript/Cargo.toml @@ -1,7 +1,12 @@ [package] +categories.workspace = true +description = "Transcript generation for Ceno" edition.workspace = true +keywords.workspace = true license.workspace = true name = "transcript" +readme.workspace = true +repository.workspace = true version.workspace = true [dependencies] diff --git a/transcript/src/lib.rs b/transcript/src/lib.rs index b291fc58b..376bdb1ca 100644 --- a/transcript/src/lib.rs +++ b/transcript/src/lib.rs @@ -1,3 +1,4 @@ +#![deny(clippy::cargo)] //! This repo is not properly implemented //! Transcript APIs are placeholders; the actual logic is to be implemented later. #![feature(generic_arg_infer)] From cec7b82c29f0d06041bf20aec3e3c0932e6643fa Mon Sep 17 00:00:00 2001 From: Zhang Zhuo Date: Thu, 28 Nov 2024 21:28:32 +0800 Subject: [PATCH 16/21] feat: log proving khz in e2e.rs (#652) ``` Proving finished. 
Proving time = 76.888s, freq = 150.074khz Witgen time = 1.633s, freq = 7064.266khz Total time = 78.522s, freq = 146.952khz thread num: 8 ``` --- ceno_zkvm/src/bin/e2e.rs | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index af0ae40de..a4f3a8c30 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -187,7 +187,8 @@ fn main() { .collect::, _>>() .expect("vm exec failed"); - tracing::info!("Proving {} execution steps", all_records.len()); + let cycle_num = all_records.len(); + tracing::info!("Proving {} execution steps", cycle_num); for (i, step) in enumerate(&all_records).rev().take(5).rev() { tracing::trace!("Step {i}: {:?} - {:?}\n", step.insn().codes().kind, step); } @@ -306,10 +307,22 @@ fn main() { .create_proof(zkvm_witness, pi, transcript) .expect("create_proof failed"); + let proving_time = timer.elapsed().as_secs_f64(); + let e2e_time = e2e_start.elapsed().as_secs_f64(); + let witgen_time = e2e_time - proving_time; println!( - "fibonacci create_proof, time = {}, e2e = {:?}", - timer.elapsed().as_secs_f64(), - e2e_start.elapsed(), + "Proving finished.\n\ +\tProving time = {:.3}s, freq = {:.3}khz\n\ +\tWitgen time = {:.3}s, freq = {:.3}khz\n\ +\tTotal time = {:.3}s, freq = {:.3}khz\n\ +\tthread num: {}", + proving_time, + cycle_num as f64 / proving_time / 1000.0, + witgen_time, + cycle_num as f64 / witgen_time / 1000.0, + e2e_time, + cycle_num as f64 / e2e_time / 1000.0, + rayon::current_num_threads() ); let transcript = Transcript::new(b"riscv"); From 998e93340f2b65b5c2d06dde2da610931eada3ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20G=C3=B6rgens?= Date: Fri, 29 Nov 2024 16:58:37 +0800 Subject: [PATCH 17/21] Implement `IntoIterator` and use it for an example (#468) There's lots more in the code where we can use this to simplify things later. But I want to keep the PR simple. 
--- ceno_zkvm/src/uint.rs | 33 ++++++++++++++++++++++++++------ ceno_zkvm/src/uint/arithmetic.rs | 14 ++++++-------- ceno_zkvm/src/uint/logic.rs | 2 +- 3 files changed, 34 insertions(+), 15 deletions(-) diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index afb231d8a..193d34f13 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -17,7 +17,7 @@ use ark_std::iterable::Iterable; use ff::Field; use ff_ext::ExtensionField; use goldilocks::SmallField; -use itertools::Itertools; +use itertools::{Itertools, enumerate}; use std::{ borrow::Cow, mem::{self, MaybeUninit}, @@ -34,15 +34,36 @@ pub enum UintLimb { Expression(Vec>), } -impl UintLimb { - pub fn iter(&self) -> impl Iterator { +impl IntoIterator for UintLimb { + type Item = WitIn; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { match self { - UintLimb::WitIn(vec) => vec.iter(), + UintLimb::WitIn(wits) => wits.into_iter(), _ => unimplemented!(), } } } +impl<'a, E: ExtensionField> IntoIterator for &'a UintLimb { + type Item = &'a WitIn; + type IntoIter = std::slice::Iter<'a, WitIn>; + + fn into_iter(self) -> Self::IntoIter { + match self { + UintLimb::WitIn(wits) => wits.iter(), + _ => unimplemented!(), + } + } +} + +impl UintLimb { + pub fn iter(&self) -> impl Iterator { + self.into_iter() + } +} + impl Index for UintLimb { type Output = WitIn; @@ -735,8 +756,8 @@ impl<'a, T: Into + From + Copy + Default> Value<'a, T> { let mut c_limbs = vec![0u16; num_limbs]; let mut carries = vec![0u64; num_limbs]; let mut tmp = vec![0u64; num_limbs]; - a_limbs.iter().enumerate().for_each(|(i, &a_limb)| { - b_limbs.iter().enumerate().for_each(|(j, &b_limb)| { + enumerate(a_limbs).for_each(|(i, &a_limb)| { + enumerate(b_limbs).for_each(|(j, &b_limb)| { let idx = i + j; if idx < num_limbs { tmp[idx] += a_limb as u64 * b_limb as u64; diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index 85eb703c8..dfe33b076 100644 --- 
a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -267,14 +267,12 @@ impl UIntLimbs { rhs: &UIntLimbs, ) -> Result { let n_limbs = Self::NUM_LIMBS; - let (is_equal_per_limb, diff_inv_per_limb): (Vec, Vec) = self - .limbs - .iter() - .zip_eq(rhs.limbs.iter()) - .map(|(a, b)| circuit_builder.is_equal(a.expr(), b.expr())) - .collect::, ZKVMError>>()? - .into_iter() - .unzip(); + let (is_equal_per_limb, diff_inv_per_limb): (Vec, Vec) = + izip!(&self.limbs, &rhs.limbs) + .map(|(a, b)| circuit_builder.is_equal(a.expr(), b.expr())) + .collect::, ZKVMError>>()? + .into_iter() + .unzip(); let sum_expr = is_equal_per_limb .iter() diff --git a/ceno_zkvm/src/uint/logic.rs b/ceno_zkvm/src/uint/logic.rs index b340df982..024d09d73 100644 --- a/ceno_zkvm/src/uint/logic.rs +++ b/ceno_zkvm/src/uint/logic.rs @@ -18,7 +18,7 @@ impl UIntLimbs { b: &Self, c: &Self, ) -> Result<(), ZKVMError> { - for (a_byte, b_byte, c_byte) in izip!(a.limbs.iter(), b.limbs.iter(), c.limbs.iter()) { + for (a_byte, b_byte, c_byte) in izip!(&a.limbs, &b.limbs, &c.limbs) { cb.logic_u8(rom_type, a_byte.expr(), b_byte.expr(), c_byte.expr())?; } Ok(()) From 7cfa6932963ede139264f43af1cf877611f8f154 Mon Sep 17 00:00:00 2001 From: Ming Date: Fri, 29 Nov 2024 19:07:38 +0700 Subject: [PATCH 18/21] fix tower sumcheck default logic with minimal table size 2 (#646) A follow-up PR after https://github.com/scroll-tech/ceno/pull/622 In #622 the integration test failed in the first place. A fix proposed in comment https://github.com/scroll-tech/ceno/pull/622#discussion_r1855953467 is to skip the entire proof when the table size would be 0. This fix is reasonable and fits the expectation; however, the integration test should NOT have failed even BEFORE this fix. It should be garbage in, garbage out, and shouldn't break the proof. ### Root cause The root cause is that when the table size is 1 (which is padded to the next size, 2) or 2, there are empty rounds in the tower sumcheck.
In the tower sumcheck verifier, there are bookkeeping variables that track each layer (evaluation, point), and these bookkeeping variables are updated as verification proceeds layer by layer. However, when the table size is 2, there is only one layer, so the bookkeeping variables are never updated. In this case, the default value was wrong. The correct "default" value of the bookkeeping variables should be the output layer evaluations. This PR corrects the default value design. It also adds a simple test to verify the logic. Co-authored-by: Wu Sung-Ming Co-authored-by: naure --- ceno_zkvm/src/scheme/tests.rs | 70 +++++++++++++++++++++++++++++++- ceno_zkvm/src/scheme/verifier.rs | 61 ++++++++++++++++++---------- 2 files changed, 108 insertions(+), 23 deletions(-) diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index c66884ee7..260c17fae 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -1,5 +1,6 @@ use std::{marker::PhantomData, mem::MaybeUninit}; +use ark_std::test_rng; use ceno_emul::{ CENO_PLATFORM, InsnKind::{ADD, EANY}, @@ -10,6 +11,9 @@ use ff_ext::ExtensionField; use goldilocks::GoldilocksExt2; use itertools::Itertools; use mpcs::{Basefold, BasefoldDefault, BasefoldRSParams, PolynomialCommitmentScheme}; +use multilinear_extensions::{ + mle::IntoMLE, util::ceil_log2, virtual_poly_v2::ArcMultilinearExtension, +}; use transcript::Transcript; use crate::{ @@ -23,7 +27,8 @@ use crate::{ }, set_val, structs::{ - PointAndEval, RAMType::Register, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses, + PointAndEval, RAMType::Register, TowerProver, TowerProverSpec, ZKVMConstraintSystem, + ZKVMFixedTraces, ZKVMWitnesses, }, tables::{ProgramTableCircuit, U16TableCircuit}, witness::LkMultiplicity, @@ -33,7 +38,8 @@ use super::{ PublicValues, constants::{MAX_NUM_VARIABLES, NUM_FANIN}, prover::ZKVMProver, - verifier::ZKVMVerifier, + utils::infer_tower_product_witness, + verifier::{TowerVerify, ZKVMVerifier}, }; struct TestConfig { 
@@ -311,3 +317,63 @@ fn test_single_add_instance_e2e() { .expect("verify proof return with error"), ); } + +/// test various product argument size, starting from minimal leaf size 2 +#[test] +fn test_tower_proof_various_prod_size() { + fn _test_tower_proof_prod_size_2(leaf_layer_size: usize) { + let num_vars = ceil_log2(leaf_layer_size); + let mut rng = test_rng(); + type E = GoldilocksExt2; + let mut transcript = Transcript::new(b"test_tower_proof"); + let leaf_layer: ArcMultilinearExtension = (0..leaf_layer_size) + .map(|_| E::random(&mut rng)) + .collect_vec() + .into_mle() + .into(); + let (first, second): (&[E], &[E]) = leaf_layer + .get_ext_field_vec() + .split_at(leaf_layer.evaluations().len() / 2); + let last_layer_splitted_fanin: Vec> = vec![ + first.to_vec().into_mle().into(), + second.to_vec().into_mle().into(), + ]; + let layers = infer_tower_product_witness(num_vars, last_layer_splitted_fanin, 2); + let (rt_tower_p, tower_proof) = TowerProver::create_proof( + vec![TowerProverSpec { + witness: layers.clone(), + }], + vec![], + 2, + &mut transcript, + ); + + let mut transcript = Transcript::new(b"test_tower_proof"); + let (rt_tower_v, prod_point_and_eval, _, _) = TowerVerify::verify( + vec![ + layers[0] + .iter() + .flat_map(|mle| mle.get_ext_field_vec().to_vec()) + .collect_vec(), + ], + vec![], + &tower_proof, + vec![num_vars], + 2, + &mut transcript, + ) + .unwrap(); + + assert_eq!(rt_tower_p, rt_tower_v); + assert_eq!(rt_tower_v.len(), num_vars); + assert_eq!(prod_point_and_eval.len(), 1); + assert_eq!( + leaf_layer.evaluate(&rt_tower_v), + prod_point_and_eval[0].eval + ); + } + + for leaf_layer_size in 1..10 { + _test_tower_proof_prod_size_2(1 << leaf_layer_size); + } +} diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index bdce2bd09..78f70793f 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -848,22 +848,41 @@ impl TowerVerify { // out_j[rt] := (record_{j}[rt]) // out_j[rt] := 
(logup_p{j}[rt]) // out_j[rt] := (logup_q{j}[rt]) - let initial_claim = izip!(prod_out_evals, alpha_pows.iter()) - .map(|(evals, alpha)| evals.into_mle().evaluate(&initial_rt) * alpha) + + // bookkeeping records of latest (point, evaluation) of each layer + // prod argument + let mut prod_spec_point_n_eval = prod_out_evals + .into_iter() + .map(|evals| { + PointAndEval::new(initial_rt.clone(), evals.into_mle().evaluate(&initial_rt)) + }) + .collect::>(); + // logup argument for p, q + let (mut logup_spec_p_point_n_eval, mut logup_spec_q_point_n_eval) = logup_out_evals + .into_iter() + .map(|evals| { + let (p1, p2, q1, q2) = (evals[0], evals[1], evals[2], evals[3]); + ( + PointAndEval::new( + initial_rt.clone(), + vec![p1, p2].into_mle().evaluate(&initial_rt), + ), + PointAndEval::new( + initial_rt.clone(), + vec![q1, q2].into_mle().evaluate(&initial_rt), + ), + ) + }) + .unzip::<_, _, Vec<_>, Vec<_>>(); + let initial_claim = izip!(&prod_spec_point_n_eval, &alpha_pows) + .map(|(point_n_eval, alpha)| point_n_eval.eval * alpha) .sum::() - + izip!(logup_out_evals, alpha_pows[num_prod_spec..].chunks(2)) - .map(|(evals, alpha)| { - let (alpha_numerator, alpha_denominator) = (&alpha[0], &alpha[1]); - let (p1, p2, q1, q2) = (evals[0], evals[1], evals[2], evals[3]); - vec![p1, p2].into_mle().evaluate(&initial_rt) * alpha_numerator - + vec![q1, q2].into_mle().evaluate(&initial_rt) * alpha_denominator - }) - .sum::(); - - // evaluation in the tower input layer - let mut prod_spec_input_layer_eval = vec![PointAndEval::default(); num_prod_spec]; - let mut logup_spec_p_input_layer_eval = vec![PointAndEval::default(); num_logup_spec]; - let mut logup_spec_q_input_layer_eval = vec![PointAndEval::default(); num_logup_spec]; + + izip!( + interleave(&logup_spec_p_point_n_eval, &logup_spec_q_point_n_eval), + &alpha_pows[num_prod_spec..] 
+ ) + .map(|(point_n_eval, alpha)| point_n_eval.eval * alpha) + .sum::(); let max_num_variables = num_variables.iter().max().unwrap(); @@ -954,7 +973,7 @@ impl TowerVerify { .map(|(a, b)| *a * b) .sum::(); // this will keep update until round > evaluation - prod_spec_input_layer_eval[spec_index] = PointAndEval::new(rt_prime.clone(), evals); + prod_spec_point_n_eval[spec_index] = PointAndEval::new(rt_prime.clone(), evals); if next_round < max_round -1 { *alpha * evals } else { @@ -987,8 +1006,8 @@ impl TowerVerify { .sum::(); // this will keep update until round > evaluation - logup_spec_p_input_layer_eval[spec_index] = PointAndEval::new(rt_prime.clone(), p_evals); - logup_spec_q_input_layer_eval[spec_index] = PointAndEval::new(rt_prime.clone(), q_evals); + logup_spec_p_point_n_eval[spec_index] = PointAndEval::new(rt_prime.clone(), p_evals); + logup_spec_q_point_n_eval[spec_index] = PointAndEval::new(rt_prime.clone(), q_evals); if next_round < max_round -1 { *alpha_numerator * p_evals + *alpha_denominator * q_evals @@ -1011,9 +1030,9 @@ impl TowerVerify { Ok(( next_rt.point, - prod_spec_input_layer_eval, - logup_spec_p_input_layer_eval, - logup_spec_q_input_layer_eval, + prod_spec_point_n_eval, + logup_spec_p_point_n_eval, + logup_spec_q_point_n_eval, )) } } From b9b50da1095ac993d7579fa9dcc4b4ae16531274 Mon Sep 17 00:00:00 2001 From: naure Date: Fri, 29 Nov 2024 14:09:39 +0100 Subject: [PATCH 19/21] Test/ Sproll EVM prover (#645) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR contains configuration that can prove the [sproll-evm executable](https://github.com/scroll-tech/sproll-evm/pull/60/files#diff-0e528863de23811e1d04049eccabd887e785625f5f0fe7381122cfeeba24960d). 
``` RUST_LOG=info,e2e=debug,ceno_emul=debug,ceno_zkvm::scheme=warn MOCK_PROVING=1 cargo run --package ceno_zkvm --bin e2e --release -- --platform=sp1 --max-steps=10000 riscv32im-succinct-zkvm-elf ``` (remove `--max-steps` for the full run) Warning: ecalls WRITE, COMMIT, and COMMIT_DEFERRED_PROOFS are replaced with NOP. --------- Co-authored-by: Aurélien Nicolas --- ceno_emul/src/vm_state.rs | 5 ++- ceno_zkvm/src/bin/e2e.rs | 88 +++++++++++++++++++++++++-------------- 2 files changed, 60 insertions(+), 33 deletions(-) diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index 296d25be8..d8eff9701 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use super::rv32im::EmuContext; use crate::{ - PC_STEP_SIZE, Program, + PC_STEP_SIZE, Program, WORD_SIZE, addr::{ByteAddr, RegIdx, Word, WordAddr}, platform::Platform, rv32im::{DecodedInstruction, Emulator, TrapCause}, @@ -123,7 +123,8 @@ impl EmuContext for VMState { // Read two registers, write one register, write one memory word, and branch. tracing::warn!("ecall ignored: syscall_id={}", function); self.store_register(DecodedInstruction::RD_NULL as RegIdx, 0)?; - let addr = self.platform.ram.start.into(); + // Example ecall effect - any writable address will do. 
+ let addr = (self.platform.stack_top - WORD_SIZE as u32).into(); self.store_memory(addr, self.peek_memory(addr))?; self.set_pc(ByteAddr(self.pc) + PC_STEP_SIZE); Ok(true) diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index a4f3a8c30..bb4f3c0ed 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -1,6 +1,6 @@ use ceno_emul::{ - ByteAddr, CENO_PLATFORM, EmuContext, InsnKind::EANY, IterAddresses, Platform, StepRecord, - Tracer, VMState, WORD_SIZE, Word, WordAddr, + ByteAddr, CENO_PLATFORM, EmuContext, InsnKind::EANY, IterAddresses, Platform, Program, + StepRecord, Tracer, VMState, WORD_SIZE, Word, WordAddr, }; use ceno_zkvm::{ instructions::riscv::{DummyExtraConfig, MemPadder, MmuConfig, Rv32imConfig}, @@ -49,6 +49,14 @@ struct Args { /// Zero-padded to the right to the next power-of-two size. #[arg(long)] hints: Option, + + /// Stack size in bytes. + #[arg(long, default_value = "32768")] + stack_size: u32, + + /// Heap size in bytes. + #[arg(long, default_value = "2097152")] + heap_size: u32, } #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] @@ -58,11 +66,15 @@ enum Preset { } fn main() { - let args = Args::parse(); + let args = { + let mut args = Args::parse(); + args.stack_size = args.stack_size.next_multiple_of(WORD_SIZE as u32); + args.heap_size = args.heap_size.next_multiple_of(WORD_SIZE as u32); + args + }; type E = GoldilocksExt2; type Pcs = Basefold; - const PROGRAM_SIZE: usize = 1 << 14; type ExampleProgramTableCircuit = ProgramTableCircuit; // set up logger @@ -82,25 +94,57 @@ fn main() { .with(flame_layer.with_threads_collapsed(true)); tracing::subscriber::set_global_default(subscriber).unwrap(); + let elf_bytes = fs::read(&args.elf).expect("read elf file"); + let program = Program::load_elf(&elf_bytes, u32::MAX).unwrap(); + let platform = match args.platform { Preset::Ceno => CENO_PLATFORM, Preset::Sp1 => Platform { // The stack section is not mentioned in ELF headers, so we repeat the 
constant STACK_TOP here. stack_top: 0x0020_0400, - rom: 0x0020_0800..0x0040_0000, - ram: 0x0020_0000..0xFFFF_0000, + rom: program.base_address + ..program.base_address + (program.instructions.len() * WORD_SIZE) as u32, + ram: 0x0010_0000..0xFFFF_0000, unsafe_ecall_nop: true, ..CENO_PLATFORM }, }; - tracing::info!("Running on platform {:?}", args.platform); + tracing::info!("Running on platform {:?} {:?}", args.platform, platform); + tracing::info!( + "Stack: {} bytes. Heap: {} bytes.", + args.stack_size, + args.heap_size + ); + + let stack_addrs = platform.stack_top - args.stack_size..platform.stack_top; + + // Detect heap as starting after program data. + let heap_start = program.image.keys().max().unwrap() + WORD_SIZE as u32; + let heap_addrs = heap_start..heap_start + args.heap_size; + + let mut mem_padder = MemPadder::new(heap_addrs.end..platform.ram.end); + + let mem_init = { + let program_addrs = program.image.iter().map(|(addr, value)| MemInitRecord { + addr: *addr, + value: *value, + }); + + let stack = stack_addrs + .iter_addresses() + .map(|addr| MemInitRecord { addr, value: 0 }); + + let heap = heap_addrs + .iter_addresses() + .map(|addr| MemInitRecord { addr, value: 0 }); - const STACK_SIZE: u32 = 256; - let mut mem_padder = MemPadder::new(platform.ram.clone()); + let mem_init = chain!(program_addrs, stack, heap).collect_vec(); + + mem_padder.padded_sorted(mem_init.len().next_power_of_two(), mem_init) + }; tracing::info!("Loading ELF file: {}", args.elf); - let elf_bytes = fs::read(&args.elf).expect("read elf file"); - let mut vm = VMState::new_from_elf(platform.clone(), &elf_bytes).unwrap(); + let mut vm = VMState::new(platform.clone(), program); tracing::info!("Loading hints file: {:?}", args.hints); let hints = memory_from_file(&args.hints); @@ -118,7 +162,8 @@ fn main() { let (pp, vp) = Pcs::trim(pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); let program_params = ProgramParams { platform: platform.clone(), - program_size: PROGRAM_SIZE, 
+ program_size: vm.program().instructions.len(), + static_memory_len: mem_init.len(), ..ProgramParams::default() }; let mut zkvm_cs = ZKVMConstraintSystem::new_with_platform(program_params); @@ -137,25 +182,6 @@ fn main() { vm.program(), ); - let mem_init = { - let program_addrs = vm - .program() - .image - .iter() - .map(|(addr, value)| MemInitRecord { - addr: *addr, - value: *value, - }); - - let stack_addrs = (1..=STACK_SIZE) - .map(|i| platform.stack_top - i * WORD_SIZE as u32) - .map(|addr| MemInitRecord { addr, value: 0 }); - - let mem_init = chain!(program_addrs, stack_addrs).collect_vec(); - - mem_padder.padded_sorted(mmu_config.static_mem_len(), mem_init) - }; - // IO is not used in this program, but it must have a particular size at the moment. let io_init = mem_padder.padded_sorted(mmu_config.public_io_len(), vec![]); From fc34d2b0cca3597bb429532da4fbab58cc654178 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 2 Dec 2024 07:31:43 +0000 Subject: [PATCH 20/21] Bump tracing-subscriber from 0.3.18 to 0.3.19 (#664) --- Cargo.lock | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7a9651cc3..923484c8e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1915,9 +1915,9 @@ dependencies = [ [[package]] name = "tracing" -version = "0.1.40" +version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ "pin-project-lite", "tracing-attributes", @@ -1926,9 +1926,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.27" +version = "0.1.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" +checksum = 
"395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", @@ -1937,9 +1937,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.32" +version = "0.1.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" dependencies = [ "once_cell", "valuable", @@ -1969,9 +1969,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.18" +version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" +checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" dependencies = [ "matchers", "nu-ansi-term", From a88b95b95d892eefc2f44a75221bc395307e186b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20G=C3=B6rgens?= Date: Mon, 2 Dec 2024 15:37:45 +0800 Subject: [PATCH 21/21] refactor: Simplify type conversions (#662) Just to improve readability --- ceno_zkvm/src/scheme/utils.rs | 4 +--- ceno_zkvm/src/virtual_polys.rs | 4 ++-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 6b2ab6482..c8ec6453a 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -77,9 +77,7 @@ pub(crate) fn interleaving_mles_to_mles<'a, E: ExtensionField>( .with_min_len(MIN_PAR_SIZE) .for_each(|(value, instance)| { assert_eq!(instance.len(), per_instance_size); - instance[i] = <::BaseField as Into< - E, - >>::into(*value); + instance[i] = E::from(*value); }), _ => unreachable!(), }); diff --git a/ceno_zkvm/src/virtual_polys.rs b/ceno_zkvm/src/virtual_polys.rs index 75c930191..6302888d1 100644 --- a/ceno_zkvm/src/virtual_polys.rs +++ b/ceno_zkvm/src/virtual_polys.rs @@ -246,8 +246,8 @@ mod tests { }, ); - <::BaseField as 
std::convert::Into>::into( - evals.iter().sum::<::BaseField>() + GoldilocksExt2::from( + evals.iter().sum::() * base_2.pow([(max_num_vars - fs[0].num_vars()) as u64]), ) };