diff --git a/Cargo.lock b/Cargo.lock index d400b0e9d..923484c8e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -241,6 +241,7 @@ dependencies = [ "anyhow", "ceno-examples", "elf", + "itertools 0.13.0", "num-derive", "num-traits", "strum", @@ -1914,9 +1915,9 @@ dependencies = [ [[package]] name = "tracing" -version = "0.1.40" +version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ "pin-project-lite", "tracing-attributes", @@ -1925,9 +1926,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.27" +version = "0.1.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" +checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", @@ -1936,9 +1937,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.32" +version = "0.1.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" dependencies = [ "once_cell", "valuable", @@ -1968,9 +1969,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.18" +version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" +checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" dependencies = [ "matchers", "nu-ansi-term", diff --git a/Cargo.toml b/Cargo.toml index 4a638dad0..daa44f156 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,8 +14,12 @@ members = [ resolver = "2" [workspace.package] +categories = ["cryptography", "zk", "blockchain", "ceno"] edition = "2021" +keywords = ["cryptography", "zk", "blockchain", "ceno"] license = "MIT OR Apache-2.0" +readme = "README.md" +repository = "https://github.com/scroll-tech/ceno" version = "0.1.0" [workspace.dependencies] diff --git a/README.md b/README.md index 4c6fd6114..82f2e8bfe 100644 --- a/README.md +++ b/README.md @@ -4,11 +4,11 @@ Please see [the slightly outdated paper](https://eprint.iacr.org/2024/387) for a 🚧 This project is currently under construction and not suitable for use in production. 🚧 -If you are unfamiliar with the RISC-V instruction set, please have a look at the [RISC-V instruction set reference](https://github.com/jameslzhu/riscv-card/blob/master/riscv-card.pdf). +If you are unfamiliar with the RISC-V instruction set, please have a look at the [RISC-V instruction set reference](https://github.com/jameslzhu/riscv-card/releases/download/latest/riscv-card.pdf). ## Local build requirements -Ceno is built in Rust, so [installing the Rust toolchain](https://www.rust-lang.org/tools/install) is a pre-requisite, if you want to develop on your local machine. We also use [cargo-make](https://sagiegurari.github.io/cargo-make/) to build Ceno. You can install cargo-make with the following command: +Ceno is built in Rust, so [installing the Rust toolchain](https://www.rust-lang.org/tools/install) is a pre-requisite if you want to develop on your local machine. We also use [cargo-make](https://sagiegurari.github.io/cargo-make/) to build Ceno. 
You can install cargo-make with the following command: ```sh cargo install cargo-make diff --git a/ceno_emul/Cargo.toml b/ceno_emul/Cargo.toml index 6b93a3f0a..38f0a8bfd 100644 --- a/ceno_emul/Cargo.toml +++ b/ceno_emul/Cargo.toml @@ -1,12 +1,18 @@ [package] +categories.workspace = true +description = "A Risc-V emulator for Ceno" edition.workspace = true +keywords.workspace = true license.workspace = true name = "ceno_emul" +readme.workspace = true +repository.workspace = true version.workspace = true [dependencies] anyhow = { version = "1.0", default-features = false } elf = "0.7" +itertools.workspace = true num-derive.workspace = true num-traits.workspace = true strum.workspace = true diff --git a/ceno_emul/src/addr.rs b/ceno_emul/src/addr.rs index 200739ce7..78b01563e 100644 --- a/ceno_emul/src/addr.rs +++ b/ceno_emul/src/addr.rs @@ -14,7 +14,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{fmt, ops}; +use std::{ + fmt, + ops::{self, Range}, +}; pub const WORD_SIZE: usize = 4; pub const PC_WORD_SIZE: usize = 4; @@ -192,3 +195,29 @@ impl ops::AddAssign for ByteAddr { self.0 += rhs; } } + +pub trait IterAddresses { + fn iter_addresses(&self) -> impl ExactSizeIterator; +} + +impl IterAddresses for Range { + fn iter_addresses(&self) -> impl ExactSizeIterator { + self.clone().step_by(WORD_SIZE) + } +} + +impl<'a, T: GetAddr> IterAddresses for &'a [T] { + fn iter_addresses(&self) -> impl ExactSizeIterator { + self.iter().map(T::get_addr) + } +} + +pub trait GetAddr { + fn get_addr(&self) -> Addr; +} + +impl GetAddr for Addr { + fn get_addr(&self) -> Addr { + *self + } +} diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index d8ec28ab0..c734b1794 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -1,3 +1,4 @@ +#![deny(clippy::cargo)] mod addr; pub use addr::*; diff --git a/ceno_emul/src/platform.rs b/ceno_emul/src/platform.rs index 8836bdfbb..c28bcda40 100644 --- a/ceno_emul/src/platform.rs +++ b/ceno_emul/src/platform.rs @@ -1,3 +1,5 @@ +use std::ops::Range; + use crate::addr::{Addr, RegIdx}; /// The Platform struct holds the parameters of the VM. @@ -7,20 +9,20 @@ use crate::addr::{Addr, RegIdx}; /// - codes of environment calls. #[derive(Clone, Debug)] pub struct Platform { - pub rom_start: Addr, - pub rom_end: Addr, - pub ram_start: Addr, - pub ram_end: Addr, + pub rom: Range, + pub ram: Range, + pub public_io: Range, + pub hints: Range, pub stack_top: Addr, /// If true, ecall instructions are no-op instead of trap. Testing only. pub unsafe_ecall_nop: bool, } pub const CENO_PLATFORM: Platform = Platform { - rom_start: 0x2000_0000, - rom_end: 0x3000_0000 - 1, - ram_start: 0x8000_0000, - ram_end: 0xFFFF_0000 - 1, + rom: 0x2000_0000..0x3000_0000, + ram: 0x8000_0000..0xFFFF_0000, + public_io: 0x3000_1000..0x3000_2000, + hints: 0x4000_0000..0x5000_0000, stack_top: 0xC0000000, unsafe_ecall_nop: false, }; @@ -28,53 +30,20 @@ pub const CENO_PLATFORM: Platform = Platform { impl Platform { // Virtual memory layout. 
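
A minimal usage sketch (not part of the diff) for the `Range`-based memory layout and the `IterAddresses` trait added in `ceno_emul/src/addr.rs` above; the helper name `hint_word_addresses` is hypothetical, while `CENO_PLATFORM`, `IterAddresses`, `WORD_SIZE`, and `Addr` are the items defined in this change:

```rust
use ceno_emul::{Addr, CENO_PLATFORM, IterAddresses, WORD_SIZE};

// Hypothetical helper: enumerate the word-aligned addresses of the hints
// segment via the `IterAddresses` impl for `Range<Addr>` shown above.
fn hint_word_addresses() -> Vec<Addr> {
    let hints = CENO_PLATFORM.hints.clone();
    let addrs: Vec<Addr> = hints.iter_addresses().collect();
    // One address per word; the `Range` end is exclusive, so `hints.end`
    // itself is never yielded.
    assert_eq!(addrs.len(), hints.len() / WORD_SIZE);
    assert!(addrs.iter().all(|&a| a < hints.end));
    addrs
}
```

Because the exclusive `Range` end replaces the old inclusive `*_end` fields, the word count of a segment is simply `len() / WORD_SIZE`.
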
- pub const fn rom_start(&self) -> Addr { - self.rom_start - } - - pub const fn rom_end(&self) -> Addr { - self.rom_end - } - pub fn is_rom(&self, addr: Addr) -> bool { - (self.rom_start()..=self.rom_end()).contains(&addr) - } - - // TODO figure out a proper region for public io - pub const fn public_io_start(&self) -> Addr { - 0x3000_1000 - } - - pub const fn public_io_end(&self) -> Addr { - 0x3000_2000 - 1 - } - - pub const fn ram_start(&self) -> Addr { - if cfg!(feature = "forbid_overflow") { - // -1<<11 == 0x800 is the smallest negative 'immediate' - // offset we can have in memory instructions. - // So if we stay away from it, we are safe. - assert!(self.ram_start >= 0x800); - } - self.ram_start - } - - pub const fn ram_end(&self) -> Addr { - if cfg!(feature = "forbid_overflow") { - // (1<<11) - 1 == 0x7ff is the largest positive 'immediate' - // offset we can have in memory instructions. - // So if we stay away from it, we are safe. - assert!(self.ram_end < -(1_i32 << 11) as u32) - } - self.ram_end + self.rom.contains(&addr) } pub fn is_ram(&self, addr: Addr) -> bool { - (self.ram_start()..=self.ram_end()).contains(&addr) + self.ram.contains(&addr) } pub fn is_pub_io(&self, addr: Addr) -> bool { - (self.public_io_start()..=self.public_io_end()).contains(&addr) + self.public_io.contains(&addr) + } + + pub fn is_hints(&self, addr: Addr) -> bool { + self.hints.contains(&addr) } /// Virtual address of a register. @@ -91,13 +60,13 @@ impl Platform { // Startup. pub const fn pc_base(&self) -> Addr { - self.rom_start() + self.rom.start } // Permissions. pub fn can_read(&self, addr: Addr) -> bool { - self.is_rom(addr) || self.is_ram(addr) || self.is_pub_io(addr) + self.is_rom(addr) || self.is_ram(addr) || self.is_pub_io(addr) || self.is_hints(addr) } pub fn can_write(&self, addr: Addr) -> bool { @@ -139,17 +108,17 @@ impl Platform { #[cfg(test)] mod tests { use super::*; - use crate::VMState; + use crate::{VMState, WORD_SIZE}; #[test] fn test_no_overlap() { let p = CENO_PLATFORM; assert!(p.can_execute(p.pc_base())); // ROM and RAM do not overlap. - assert!(!p.is_rom(p.ram_start())); - assert!(!p.is_rom(p.ram_end())); - assert!(!p.is_ram(p.rom_start())); - assert!(!p.is_ram(p.rom_end())); + assert!(!p.is_rom(p.ram.start)); + assert!(!p.is_rom(p.ram.end - WORD_SIZE as Addr)); + assert!(!p.is_ram(p.rom.start)); + assert!(!p.is_ram(p.rom.end - WORD_SIZE as Addr)); // Registers do not overlap with ROM or RAM. for reg in [ Platform::register_vma(0), diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index 19dcd2ecf..4554bdb08 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -15,6 +15,7 @@ // limitations under the License. 
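
A companion sketch (illustrative only; `check_layout` is a hypothetical name, all items come from the diff) of the exclusive-end semantics that the updated `test_no_overlap` above relies on, and of `can_read` now covering the hints segment:

```rust
use ceno_emul::{Addr, CENO_PLATFORM, WORD_SIZE};

// Illustrative checks mirroring the updated `test_no_overlap`: with exclusive
// ranges, the last valid word of a segment is `end - WORD_SIZE`, and
// `can_read` additionally accepts hint addresses.
fn check_layout() {
    let p = &CENO_PLATFORM;
    let last_ram_word: Addr = p.ram.end - WORD_SIZE as Addr;
    assert!(p.is_ram(last_ram_word));
    assert!(!p.is_ram(p.ram.end)); // the exclusive end lies outside the segment
    assert!(!p.is_rom(last_ram_word)); // ROM and RAM stay disjoint
    assert!(p.can_read(p.hints.start)); // hints are readable
    assert_eq!(p.pc_base(), p.rom.start);
}
```
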
use anyhow::{Result, anyhow}; +use itertools::enumerate; use num_derive::ToPrimitive; use std::sync::OnceLock; use strum_macros::EnumIter; @@ -403,7 +404,7 @@ struct FastDecodeTable { impl FastDecodeTable { fn new() -> Self { let mut table: FastInstructionTable = [0; 1 << 10]; - for (isa_idx, insn) in RV32IM_ISA.iter().enumerate() { + for (isa_idx, insn) in enumerate(&RV32IM_ISA) { Self::add_insn(&mut table, insn, isa_idx); } Self { table } diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index d9f87b695..bbb1e6a51 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -199,7 +199,7 @@ impl StepRecord { Some(value), Some(Change::new(value, value)), Some(WriteOp { - addr: CENO_PLATFORM.ram_start().into(), + addr: CENO_PLATFORM.ram.start.into(), value: Change { before: value, after: value, diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index 68dd4adaf..d8eff9701 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use super::rv32im::EmuContext; use crate::{ - PC_STEP_SIZE, Program, + PC_STEP_SIZE, Program, WORD_SIZE, addr::{ByteAddr, RegIdx, Word, WordAddr}, platform::Platform, rv32im::{DecodedInstruction, Emulator, TrapCause}, @@ -44,7 +44,7 @@ impl VMState { }; // init memory from program.image - for (&addr, &value) in program.image.iter() { + for (&addr, &value) in &program.image { vm.init_memory(ByteAddr(addr).waddr(), value); } @@ -123,7 +123,8 @@ impl EmuContext for VMState { // Read two registers, write one register, write one memory word, and branch. tracing::warn!("ecall ignored: syscall_id={}", function); self.store_register(DecodedInstruction::RD_NULL as RegIdx, 0)?; - let addr = self.platform.ram_start().into(); + // Example ecall effect - any writable address will do. 
+ let addr = (self.platform.stack_top - WORD_SIZE as u32).into(); self.store_memory(addr, self.peek_memory(addr))?; self.set_pc(ByteAddr(self.pc) + PC_STEP_SIZE); Ok(true) diff --git a/ceno_emul/tests/test_elf.rs b/ceno_emul/tests/test_elf.rs index 3168d265a..ca7a14c1c 100644 --- a/ceno_emul/tests/test_elf.rs +++ b/ceno_emul/tests/test_elf.rs @@ -27,7 +27,7 @@ fn test_ceno_rt_mem() -> Result<()> { let mut state = VMState::new_from_elf(CENO_PLATFORM, program_elf)?; let _steps = run(&mut state)?; - let value = state.peek_memory(CENO_PLATFORM.ram_start().into()); + let value = state.peek_memory(CENO_PLATFORM.ram.start.into()); assert_eq!(value, 6765, "Expected Fibonacci 20, got {}", value); Ok(()) } diff --git a/ceno_rt/Cargo.toml b/ceno_rt/Cargo.toml index 505f67dd8..dfdc87ad2 100644 --- a/ceno_rt/Cargo.toml +++ b/ceno_rt/Cargo.toml @@ -1,7 +1,12 @@ [package] +categories.workspace = true +description = "Ceno runtime library" edition.workspace = true +keywords.workspace = true license.workspace = true name = "ceno_rt" +readme = "README.md" +repository.workspace = true version.workspace = true [dependencies] diff --git a/ceno_rt/src/lib.rs b/ceno_rt/src/lib.rs index 07c53a2bc..8de456c41 100644 --- a/ceno_rt/src/lib.rs +++ b/ceno_rt/src/lib.rs @@ -1,3 +1,4 @@ +#![deny(clippy::cargo)] #![feature(strict_overflow_ops)] #![no_std] diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index a6cf65cd4..13f4f2810 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -1,7 +1,12 @@ [package] +categories.workspace = true +description = "Ceno ZKVM" edition.workspace = true +keywords.workspace = true license.workspace = true name = "ceno_zkvm" +readme.workspace = true +repository.workspace = true version.workspace = true [dependencies] diff --git a/ceno_zkvm/examples/riscv_opcodes.rs b/ceno_zkvm/examples/riscv_opcodes.rs index d0ad76c1c..67a314683 100644 --- a/ceno_zkvm/examples/riscv_opcodes.rs +++ b/ceno_zkvm/examples/riscv_opcodes.rs @@ -42,10 +42,10 @@ const PROGRAM_CODE: [u32; PROGRAM_SIZE] = { let mut program: [u32; PROGRAM_SIZE] = [ECALL_HALT; PROGRAM_SIZE]; declare_program!( program, - encode_rv32(LUI, 0, 0, 10, CENO_PLATFORM.public_io_start()), // lui x10, public_io - encode_rv32(LW, 10, 0, 1, 0), // lw x1, 0(x10) - encode_rv32(LW, 10, 0, 2, 4), // lw x2, 4(x10) - encode_rv32(LW, 10, 0, 3, 8), // lw x3, 8(x10) + encode_rv32(LUI, 0, 0, 10, CENO_PLATFORM.public_io.start), // lui x10, public_io + encode_rv32(LW, 10, 0, 1, 0), // lw x1, 0(x10) + encode_rv32(LW, 10, 0, 2, 4), // lw x2, 4(x10) + encode_rv32(LW, 10, 0, 3, 8), // lw x3, 8(x10) // Main loop. 
encode_rv32(ADD, 1, 4, 4, 0), // add x4, x1, x4 encode_rv32(ADD, 2, 3, 3, 0), // add x3, x2, x3 @@ -90,8 +90,8 @@ fn main() { }) .collect(), ); - let mem_addresses = CENO_PLATFORM.ram_start()..=CENO_PLATFORM.ram_end(); - let io_addresses = CENO_PLATFORM.public_io_start()..=CENO_PLATFORM.public_io_end(); + let mem_addresses = CENO_PLATFORM.ram.clone(); + let io_addresses = CENO_PLATFORM.public_io.clone(); let (flame_layer, _guard) = FlameLayer::with_file("./tracing.folded").unwrap(); let mut fmt_layer = fmt::layer() @@ -275,6 +275,7 @@ fn main() { ®_final, &mem_final, &public_io_final, + &[], ) .unwrap(); diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index 5c79c27dc..bb4f3c0ed 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -1,6 +1,6 @@ use ceno_emul::{ - ByteAddr, CENO_PLATFORM, EmuContext, InsnKind::EANY, Platform, StepRecord, Tracer, VMState, - WORD_SIZE, WordAddr, + ByteAddr, CENO_PLATFORM, EmuContext, InsnKind::EANY, IterAddresses, Platform, Program, + StepRecord, Tracer, VMState, WORD_SIZE, Word, WordAddr, }; use ceno_zkvm::{ instructions::riscv::{DummyExtraConfig, MemPadder, MmuConfig, Rv32imConfig}, @@ -19,7 +19,9 @@ use itertools::{Itertools, MinMaxResult, chain, enumerate}; use mpcs::{Basefold, BasefoldRSParams, PolynomialCommitmentScheme}; use std::{ collections::{HashMap, HashSet}, - fs, panic, + fs, + iter::zip, + panic, time::Instant, }; use tracing::level_filters::LevelFilter; @@ -41,6 +43,20 @@ struct Args { /// The preset configuration to use. #[arg(short, long, value_enum, default_value_t = Preset::Ceno)] platform: Preset, + + /// Hints: prover-private unconstrained input. + /// This is a raw file mapped as a memory segment. + /// Zero-padded to the right to the next power-of-two size. + #[arg(long)] + hints: Option, + + /// Stack size in bytes. + #[arg(long, default_value = "32768")] + stack_size: u32, + + /// Heap size in bytes. + #[arg(long, default_value = "2097152")] + heap_size: u32, } #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] @@ -50,11 +66,15 @@ enum Preset { } fn main() { - let args = Args::parse(); + let args = { + let mut args = Args::parse(); + args.stack_size = args.stack_size.next_multiple_of(WORD_SIZE as u32); + args.heap_size = args.heap_size.next_multiple_of(WORD_SIZE as u32); + args + }; type E = GoldilocksExt2; type Pcs = Basefold; - const PROGRAM_SIZE: usize = 1 << 14; type ExampleProgramTableCircuit = ProgramTableCircuit; // set up logger @@ -74,33 +94,76 @@ fn main() { .with(flame_layer.with_threads_collapsed(true)); tracing::subscriber::set_global_default(subscriber).unwrap(); + let elf_bytes = fs::read(&args.elf).expect("read elf file"); + let program = Program::load_elf(&elf_bytes, u32::MAX).unwrap(); + let platform = match args.platform { Preset::Ceno => CENO_PLATFORM, Preset::Sp1 => Platform { // The stack section is not mentioned in ELF headers, so we repeat the constant STACK_TOP here. stack_top: 0x0020_0400, - rom_start: 0x0020_0800, - rom_end: 0x003f_ffff, - ram_start: 0x0020_0000, - ram_end: 0xFFFF_0000 - 1, + rom: program.base_address + ..program.base_address + (program.instructions.len() * WORD_SIZE) as u32, + ram: 0x0010_0000..0xFFFF_0000, unsafe_ecall_nop: true, + ..CENO_PLATFORM }, }; - tracing::info!("Running on platform {:?}", args.platform); + tracing::info!("Running on platform {:?} {:?}", args.platform, platform); + tracing::info!( + "Stack: {} bytes. 
Heap: {} bytes.", + args.stack_size, + args.heap_size + ); + + let stack_addrs = platform.stack_top - args.stack_size..platform.stack_top; + + // Detect heap as starting after program data. + let heap_start = program.image.keys().max().unwrap() + WORD_SIZE as u32; + let heap_addrs = heap_start..heap_start + args.heap_size; - const STACK_SIZE: u32 = 256; - let mut mem_padder = MemPadder::new(platform.ram_start()..=platform.ram_end()); + let mut mem_padder = MemPadder::new(heap_addrs.end..platform.ram.end); + + let mem_init = { + let program_addrs = program.image.iter().map(|(addr, value)| MemInitRecord { + addr: *addr, + value: *value, + }); + + let stack = stack_addrs + .iter_addresses() + .map(|addr| MemInitRecord { addr, value: 0 }); + + let heap = heap_addrs + .iter_addresses() + .map(|addr| MemInitRecord { addr, value: 0 }); + + let mem_init = chain!(program_addrs, stack, heap).collect_vec(); + + mem_padder.padded_sorted(mem_init.len().next_power_of_two(), mem_init) + }; tracing::info!("Loading ELF file: {}", args.elf); - let elf_bytes = fs::read(&args.elf).expect("read elf file"); - let mut vm = VMState::new_from_elf(platform.clone(), &elf_bytes).unwrap(); + let mut vm = VMState::new(platform.clone(), program); + + tracing::info!("Loading hints file: {:?}", args.hints); + let hints = memory_from_file(&args.hints); + assert!( + hints.len() <= platform.hints.iter_addresses().len(), + "hints must fit in {} bytes", + platform.hints.len() + ); + for (addr, value) in zip(platform.hints.iter_addresses(), &hints) { + vm.init_memory(addr.into(), *value); + } // keygen let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup"); let (pp, vp) = Pcs::trim(pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); let program_params = ProgramParams { platform: platform.clone(), - program_size: PROGRAM_SIZE, + program_size: vm.program().instructions.len(), + static_memory_len: mem_init.len(), ..ProgramParams::default() }; let mut zkvm_cs = ZKVMConstraintSystem::new_with_platform(program_params); @@ -119,25 +182,6 @@ fn main() { vm.program(), ); - let mem_init = { - let program_addrs = vm - .program() - .image - .iter() - .map(|(addr, value)| MemInitRecord { - addr: *addr, - value: *value, - }); - - let stack_addrs = (1..=STACK_SIZE) - .map(|i| platform.stack_top - i * WORD_SIZE as u32) - .map(|addr| MemInitRecord { addr, value: 0 }); - - let mem_init = chain!(program_addrs, stack_addrs).collect_vec(); - - mem_padder.padded_sorted(mmu_config.static_mem_len(), mem_init) - }; - // IO is not used in this program, but it must have a particular size at the moment. 
let io_init = mem_padder.padded_sorted(mmu_config.public_io_len(), vec![]); @@ -169,7 +213,8 @@ fn main() { .collect::, _>>() .expect("vm exec failed"); - tracing::info!("Proving {} execution steps", all_records.len()); + let cycle_num = all_records.len(); + tracing::info!("Proving {} execution steps", cycle_num); for (i, step) in enumerate(&all_records).rev().take(5).rev() { tracing::trace!("Step {i}: {:?} - {:?}\n", step.insn().codes().kind, step); } @@ -250,6 +295,14 @@ fn main() { .map(|rec| *final_access.get(&rec.addr.into()).unwrap_or(&0)) .collect_vec(); + let priv_io_final = zip(platform.hints.iter_addresses(), &hints) + .map(|(addr, &value)| MemFinalRecord { + addr, + value, + cycle: *final_access.get(&addr.into()).unwrap_or(&0), + }) + .collect_vec(); + // assign table circuits config .assign_table_circuit(&zkvm_cs, &mut zkvm_witness) @@ -261,6 +314,7 @@ fn main() { ®_final, &mem_final, &io_final, + &priv_io_final, ) .unwrap(); // assign program circuit @@ -279,10 +333,22 @@ fn main() { .create_proof(zkvm_witness, pi, transcript) .expect("create_proof failed"); + let proving_time = timer.elapsed().as_secs_f64(); + let e2e_time = e2e_start.elapsed().as_secs_f64(); + let witgen_time = e2e_time - proving_time; println!( - "fibonacci create_proof, time = {}, e2e = {:?}", - timer.elapsed().as_secs_f64(), - e2e_start.elapsed(), + "Proving finished.\n\ +\tProving time = {:.3}s, freq = {:.3}khz\n\ +\tWitgen time = {:.3}s, freq = {:.3}khz\n\ +\tTotal time = {:.3}s, freq = {:.3}khz\n\ +\tthread num: {}", + proving_time, + cycle_num as f64 / proving_time / 1000.0, + witgen_time, + cycle_num as f64 / witgen_time / 1000.0, + e2e_time, + cycle_num as f64 / e2e_time / 1000.0, + rayon::current_num_threads() ); let transcript = Transcript::new(b"riscv"); @@ -333,6 +399,18 @@ fn main() { }; } +fn memory_from_file(path: &Option) -> Vec { + path.as_ref() + .map(|path| { + let mut buf = fs::read(path).expect("could not read file"); + buf.resize(buf.len().next_multiple_of(WORD_SIZE), 0); + buf.chunks_exact(WORD_SIZE) + .map(|word| Word::from_le_bytes(word.try_into().unwrap())) + .collect_vec() + }) + .unwrap_or_default() +} + fn debug_memory_ranges(vm: &VMState, mem_final: &[MemFinalRecord]) { let accessed_addrs = vm .tracer() diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index a663c6f77..d76adb948 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -1,5 +1,6 @@ use ceno_emul::StepRecord; use ff_ext::ExtensionField; +use multilinear_extensions::util::max_usable_threads; use rayon::{ iter::{IndexedParallelIterator, ParallelIterator}, slice::ParallelSlice, @@ -44,8 +45,7 @@ pub trait Instruction { num_witin: usize, steps: Vec, ) -> Result<(RowMajorMatrix, LkMultiplicity), ZKVMError> { - let nthreads = - std::env::var("RAYON_NUM_THREADS").map_or(8, |s| s.parse::().unwrap_or(8)); + let nthreads = max_usable_threads(); let num_instance_per_batch = if steps.len() > 256 { steps.len().div_ceil(nthreads) } else { diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index 32bb5ce41..a271c21ab 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -1,6 +1,6 @@ -use std::{collections::HashSet, iter::zip, ops::RangeInclusive}; +use std::{collections::HashSet, iter::zip, ops::Range}; -use ceno_emul::{Addr, Cycle, WORD_SIZE, Word}; +use ceno_emul::{Addr, Cycle, IterAddresses, WORD_SIZE, Word}; use ff_ext::ExtensionField; use 
itertools::{Itertools, chain}; @@ -8,8 +8,8 @@ use crate::{ error::ZKVMError, structs::{ProgramParams, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, tables::{ - MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOCircuit, PubIOTable, RegTable, - RegTableCircuit, StaticMemCircuit, StaticMemTable, TableCircuit, + HintsCircuit, MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOCircuit, PubIOTable, + RegTable, RegTableCircuit, StaticMemCircuit, StaticMemTable, TableCircuit, }, }; @@ -20,6 +20,8 @@ pub struct MmuConfig { pub static_mem_config: as TableCircuit>::TableConfig, /// Initialization of public IO. pub public_io_config: as TableCircuit>::TableConfig, + /// Initialization of hints. + pub hints_config: as TableCircuit>::TableConfig, pub params: ProgramParams, } @@ -30,11 +32,13 @@ impl MmuConfig { let static_mem_config = cs.register_table_circuit::>(); let public_io_config = cs.register_table_circuit::>(); + let hints_config = cs.register_table_circuit::>(); Self { reg_config, static_mem_config, public_io_config, + hints_config, params: cs.params.clone(), } } @@ -48,9 +52,11 @@ impl MmuConfig { io_addrs: &[Addr], ) { assert!( - chain( - static_mem_init.iter().map(|record| record.addr), - io_addrs.iter().copied(), + chain!( + static_mem_init.iter_addresses(), + io_addrs.iter_addresses(), + // TODO: optimize with min_max and Range. + self.params.platform.hints.iter_addresses(), ) .all_unique(), "memory addresses must be unique" @@ -65,6 +71,7 @@ impl MmuConfig { ); fixed.register_table_circuit::>(cs, &self.public_io_config, io_addrs); + fixed.register_table_circuit::>(cs, &self.hints_config, &()); } pub fn assign_table_circuit( @@ -74,6 +81,7 @@ impl MmuConfig { reg_final: &[MemFinalRecord], static_mem_final: &[MemFinalRecord], io_cycles: &[Cycle], + hints_final: &[MemFinalRecord], ) -> Result<(), ZKVMError> { witness.assign_table_circuit::>(cs, &self.reg_config, reg_final)?; @@ -85,6 +93,8 @@ impl MmuConfig { witness.assign_table_circuit::>(cs, &self.public_io_config, io_cycles)?; + witness.assign_table_circuit::>(cs, &self.hints_config, hints_final)?; + Ok(()) } @@ -107,7 +117,7 @@ impl MmuConfig { } pub struct MemPadder { - valid_addresses: RangeInclusive, + valid_addresses: Range, used_addresses: HashSet, } @@ -118,7 +128,7 @@ impl MemPadder { /// /// Require: `values.len() <= padded_len <= address_range.len()` pub fn init_mem( - address_range: RangeInclusive, + address_range: Range, padded_len: usize, values: &[Word], ) -> Vec { @@ -129,7 +139,7 @@ impl MemPadder { records } - pub fn new(valid_addresses: RangeInclusive) -> Self { + pub fn new(valid_addresses: Range) -> Self { Self { valid_addresses, used_addresses: HashSet::new(), diff --git a/ceno_zkvm/src/lib.rs b/ceno_zkvm/src/lib.rs index a3c2ff02f..35013d448 100644 --- a/ceno_zkvm/src/lib.rs +++ b/ceno_zkvm/src/lib.rs @@ -1,3 +1,4 @@ +#![deny(clippy::cargo)] #![feature(box_patterns)] #![feature(stmt_expr_attributes)] #![feature(variant_count)] diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index a7a081f29..5889616fe 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -769,7 +769,7 @@ Hints: let mut rom_inputs = HashMap::, String, String, Vec>)>>::new(); let mut rom_tables = HashMap::>::new(); - for (circuit_name, cs) in cs.circuit_css.iter() { + for (circuit_name, cs) in &cs.circuit_css { let is_opcode = cs.lk_table_expressions.is_empty() && cs.r_table_expressions.is_empty() && cs.w_table_expressions.is_empty(); @@ -954,7 +954,7 @@ 
Hints: let mut writes_grp_by_annotations = HashMap::new(); // store (pc, timestamp) for $ram_type == RAMType::GlobalState let mut gs = HashMap::new(); - for (circuit_name, cs) in cs.circuit_css.iter() { + for (circuit_name, cs) in &cs.circuit_css { let fixed = fixed_mles.get(circuit_name).unwrap(); let witness = wit_mles.get(circuit_name).unwrap(); let num_rows = num_instances.get(circuit_name).unwrap(); @@ -1016,7 +1016,7 @@ Hints: let mut reads = HashSet::new(); let mut reads_grp_by_annotations = HashMap::new(); - for (circuit_name, cs) in cs.circuit_css.iter() { + for (circuit_name, cs) in &cs.circuit_css { let fixed = fixed_mles.get(circuit_name).unwrap(); let witness = wit_mles.get(circuit_name).unwrap(); let num_rows = num_instances.get(circuit_name).unwrap(); @@ -1061,7 +1061,7 @@ Hints: } macro_rules! find_rw_mismatch { ($reads:ident,$reads_grp_by_annotations:ident,$writes:ident,$writes_grp_by_annotations:ident,$ram_type:expr,$gs:expr) => { - for (annotation, (reads, circuit_name)) in $reads_grp_by_annotations.iter() { + for (annotation, (reads, circuit_name)) in &$reads_grp_by_annotations { // (pc, timestamp) let gs_of_circuit = $gs.get(circuit_name); let num_missing = reads @@ -1098,7 +1098,7 @@ Hints: } num_rw_mismatch_errors += num_missing; } - for (annotation, (writes, circuit_name)) in $writes_grp_by_annotations.iter() { + for (annotation, (writes, circuit_name)) in &$writes_grp_by_annotations { let gs_of_circuit = $gs.get(circuit_name); let num_missing = writes .iter() diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 30d8d9a6d..85e071ea5 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -5,7 +5,7 @@ use std::{ }; use ff::Field; -use itertools::{Itertools, izip}; +use itertools::{Itertools, enumerate, izip}; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ mle::{IntoMLE, MultilinearExtension}, @@ -62,10 +62,9 @@ impl> ZKVMProver { let mut vm_proof = ZKVMProof::empty(pi); // including raw public input to transcript - vm_proof - .raw_pi - .iter() - .for_each(|v| v.iter().for_each(|v| transcript.append_field_element(v))); + for v in vm_proof.raw_pi.iter().flatten() { + transcript.append_field_element(v); + } let pi: Vec> = vm_proof .raw_pi @@ -77,7 +76,7 @@ impl> ZKVMProver { .collect(); // commit to fixed commitment - for (_, pk) in self.pk.circuit_pks.iter() { + for pk in self.pk.circuit_pks.values() { if let Some(fixed_commit) = &pk.vk.fixed_commit { PCS::write_commitment(fixed_commit, &mut transcript) .map_err(ZKVMError::PCSError)?; @@ -147,7 +146,7 @@ impl> ZKVMProver { cs.w_expressions.len(), cs.lk_expressions.len(), ); - for lk_s in cs.lk_expressions_namespace_map.iter() { + for lk_s in &cs.lk_expressions_namespace_map { tracing::debug!("opcode circuit {}: {}", circuit_name, lk_s); } let opcode_proof = self.create_opcode_proof( @@ -1188,7 +1187,7 @@ impl TowerProver { let eq: ArcMultilinearExtension = build_eq_x_r_vec(&out_rt).into_mle().into(); let mut virtual_polys = VirtualPolynomials::::new(num_threads, out_rt.len()); - for (s, alpha) in prod_specs.iter().zip(alpha_pows.iter()) { + for (s, alpha) in izip!(&prod_specs, &alpha_pows) { if round < s.witness.len() { let layer_polys = &s.witness[round]; @@ -1210,9 +1209,7 @@ impl TowerProver { } } - for (s, alpha) in logup_specs - .iter() - .zip(alpha_pows[prod_specs.len()..].chunks(2)) + for (s, alpha) in izip!(&logup_specs, alpha_pows[prod_specs.len()..].chunks(2)) { if round < s.witness.len() { let layer_polys = &s.witness[round]; @@ 
-1270,7 +1267,7 @@ impl TowerProver { let evals = state.get_mle_final_evaluations(); let mut evals_iter = evals.iter(); evals_iter.next(); // skip first eq - for (i, s) in prod_specs.iter().enumerate() { + for (i, s) in enumerate(&prod_specs) { if round < s.witness.len() { // collect evals belong to current spec proofs.push_prod_evals_and_point( @@ -1282,7 +1279,7 @@ impl TowerProver { ); } } - for (i, s) in logup_specs.iter().enumerate() { + for (i, s) in enumerate(&logup_specs) { if round < s.witness.len() { // collect evals belong to current spec // p1, q2, p2, q1 diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 056c9ea54..fa662f338 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -1,5 +1,6 @@ use std::marker::PhantomData; +use ark_std::test_rng; use ceno_emul::{ CENO_PLATFORM, InsnKind::{ADD, EANY}, @@ -10,6 +11,9 @@ use ff_ext::ExtensionField; use goldilocks::GoldilocksExt2; use itertools::Itertools; use mpcs::{Basefold, BasefoldDefault, BasefoldRSParams, PolynomialCommitmentScheme}; +use multilinear_extensions::{ + mle::IntoMLE, util::ceil_log2, virtual_poly_v2::ArcMultilinearExtension, +}; use transcript::Transcript; use crate::{ @@ -23,7 +27,8 @@ use crate::{ }, set_val, structs::{ - PointAndEval, RAMType::Register, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses, + PointAndEval, RAMType::Register, TowerProver, TowerProverSpec, ZKVMConstraintSystem, + ZKVMFixedTraces, ZKVMWitnesses, }, tables::{ProgramTableCircuit, U16TableCircuit}, witness::LkMultiplicity, @@ -33,7 +38,8 @@ use super::{ PublicValues, constants::{MAX_NUM_VARIABLES, NUM_FANIN}, prover::ZKVMProver, - verifier::ZKVMVerifier, + utils::infer_tower_product_witness, + verifier::{TowerVerify, ZKVMVerifier}, }; struct TestConfig { @@ -311,3 +317,63 @@ fn test_single_add_instance_e2e() { .expect("verify proof return with error"), ); } + +/// test various product argument size, starting from minimal leaf size 2 +#[test] +fn test_tower_proof_various_prod_size() { + fn _test_tower_proof_prod_size_2(leaf_layer_size: usize) { + let num_vars = ceil_log2(leaf_layer_size); + let mut rng = test_rng(); + type E = GoldilocksExt2; + let mut transcript = Transcript::new(b"test_tower_proof"); + let leaf_layer: ArcMultilinearExtension = (0..leaf_layer_size) + .map(|_| E::random(&mut rng)) + .collect_vec() + .into_mle() + .into(); + let (first, second): (&[E], &[E]) = leaf_layer + .get_ext_field_vec() + .split_at(leaf_layer.evaluations().len() / 2); + let last_layer_splitted_fanin: Vec> = vec![ + first.to_vec().into_mle().into(), + second.to_vec().into_mle().into(), + ]; + let layers = infer_tower_product_witness(num_vars, last_layer_splitted_fanin, 2); + let (rt_tower_p, tower_proof) = TowerProver::create_proof( + vec![TowerProverSpec { + witness: layers.clone(), + }], + vec![], + 2, + &mut transcript, + ); + + let mut transcript = Transcript::new(b"test_tower_proof"); + let (rt_tower_v, prod_point_and_eval, _, _) = TowerVerify::verify( + vec![ + layers[0] + .iter() + .flat_map(|mle| mle.get_ext_field_vec().to_vec()) + .collect_vec(), + ], + vec![], + &tower_proof, + vec![num_vars], + 2, + &mut transcript, + ) + .unwrap(); + + assert_eq!(rt_tower_p, rt_tower_v); + assert_eq!(rt_tower_v.len(), num_vars); + assert_eq!(prod_point_and_eval.len(), 1); + assert_eq!( + leaf_layer.evaluate(&rt_tower_v), + prod_point_and_eval[0].eval + ); + } + + for leaf_layer_size in 1..10 { + _test_tower_proof_prod_size_2(1 << leaf_layer_size); + } +} diff --git a/ceno_zkvm/src/scheme/utils.rs 
b/ceno_zkvm/src/scheme/utils.rs index 6b2ab6482..c8ec6453a 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -77,9 +77,7 @@ pub(crate) fn interleaving_mles_to_mles<'a, E: ExtensionField>( .with_min_len(MIN_PAR_SIZE) .for_each(|(value, instance)| { assert_eq!(instance.len(), per_instance_size); - instance[i] = <::BaseField as Into< - E, - >>::into(*value); + instance[i] = E::from(*value); }), _ => unreachable!(), }); diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index bdce2bd09..78f70793f 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -848,22 +848,41 @@ impl TowerVerify { // out_j[rt] := (record_{j}[rt]) // out_j[rt] := (logup_p{j}[rt]) // out_j[rt] := (logup_q{j}[rt]) - let initial_claim = izip!(prod_out_evals, alpha_pows.iter()) - .map(|(evals, alpha)| evals.into_mle().evaluate(&initial_rt) * alpha) + + // bookkeeping records of latest (point, evaluation) of each layer + // prod argument + let mut prod_spec_point_n_eval = prod_out_evals + .into_iter() + .map(|evals| { + PointAndEval::new(initial_rt.clone(), evals.into_mle().evaluate(&initial_rt)) + }) + .collect::>(); + // logup argument for p, q + let (mut logup_spec_p_point_n_eval, mut logup_spec_q_point_n_eval) = logup_out_evals + .into_iter() + .map(|evals| { + let (p1, p2, q1, q2) = (evals[0], evals[1], evals[2], evals[3]); + ( + PointAndEval::new( + initial_rt.clone(), + vec![p1, p2].into_mle().evaluate(&initial_rt), + ), + PointAndEval::new( + initial_rt.clone(), + vec![q1, q2].into_mle().evaluate(&initial_rt), + ), + ) + }) + .unzip::<_, _, Vec<_>, Vec<_>>(); + let initial_claim = izip!(&prod_spec_point_n_eval, &alpha_pows) + .map(|(point_n_eval, alpha)| point_n_eval.eval * alpha) .sum::() - + izip!(logup_out_evals, alpha_pows[num_prod_spec..].chunks(2)) - .map(|(evals, alpha)| { - let (alpha_numerator, alpha_denominator) = (&alpha[0], &alpha[1]); - let (p1, p2, q1, q2) = (evals[0], evals[1], evals[2], evals[3]); - vec![p1, p2].into_mle().evaluate(&initial_rt) * alpha_numerator - + vec![q1, q2].into_mle().evaluate(&initial_rt) * alpha_denominator - }) - .sum::(); - - // evaluation in the tower input layer - let mut prod_spec_input_layer_eval = vec![PointAndEval::default(); num_prod_spec]; - let mut logup_spec_p_input_layer_eval = vec![PointAndEval::default(); num_logup_spec]; - let mut logup_spec_q_input_layer_eval = vec![PointAndEval::default(); num_logup_spec]; + + izip!( + interleave(&logup_spec_p_point_n_eval, &logup_spec_q_point_n_eval), + &alpha_pows[num_prod_spec..] 
+ ) + .map(|(point_n_eval, alpha)| point_n_eval.eval * alpha) + .sum::(); let max_num_variables = num_variables.iter().max().unwrap(); @@ -954,7 +973,7 @@ impl TowerVerify { .map(|(a, b)| *a * b) .sum::(); // this will keep update until round > evaluation - prod_spec_input_layer_eval[spec_index] = PointAndEval::new(rt_prime.clone(), evals); + prod_spec_point_n_eval[spec_index] = PointAndEval::new(rt_prime.clone(), evals); if next_round < max_round -1 { *alpha * evals } else { @@ -987,8 +1006,8 @@ impl TowerVerify { .sum::(); // this will keep update until round > evaluation - logup_spec_p_input_layer_eval[spec_index] = PointAndEval::new(rt_prime.clone(), p_evals); - logup_spec_q_input_layer_eval[spec_index] = PointAndEval::new(rt_prime.clone(), q_evals); + logup_spec_p_point_n_eval[spec_index] = PointAndEval::new(rt_prime.clone(), p_evals); + logup_spec_q_point_n_eval[spec_index] = PointAndEval::new(rt_prime.clone(), q_evals); if next_round < max_round -1 { *alpha_numerator * p_evals + *alpha_denominator * q_evals @@ -1011,9 +1030,9 @@ impl TowerVerify { Ok(( next_rt.point, - prod_spec_input_layer_eval, - logup_spec_p_input_layer_eval, - logup_spec_q_input_layer_eval, + prod_spec_point_n_eval, + logup_spec_p_point_n_eval, + logup_spec_q_point_n_eval, )) } } diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index 88a5834de..cf78a9b69 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -1,9 +1,9 @@ -use crate::{ - circuit_builder::CircuitBuilder, error::ZKVMError, - witness::RowMajorMatrix, -}; +use crate::{circuit_builder::CircuitBuilder, error::ZKVMError, witness::RowMajorMatrix}; +use ff::Field; use ff_ext::ExtensionField; -use std::collections::HashMap; +use multilinear_extensions::util::max_usable_threads; +use rayon::iter::{IndexedParallelIterator, ParallelIterator}; +use std::{collections::HashMap, mem::MaybeUninit}; mod range; pub use range::*; diff --git a/ceno_zkvm/src/tables/ops.rs b/ceno_zkvm/src/tables/ops.rs index f1b2e0612..789f7f744 100644 --- a/ceno_zkvm/src/tables/ops.rs +++ b/ceno_zkvm/src/tables/ops.rs @@ -87,7 +87,49 @@ impl OpsTable for PowTable { } fn content() -> Vec<[u64; 3]> { - (0..Self::len() as u64).map(|b| [2, b, 1 << b]).collect() + (0..Self::len() as u64) + .map(|exponent| [2, exponent, 1 << exponent]) + .collect() + } + + fn pack(base: u64, exponent: u64) -> u64 { + assert_eq!(base, 2); + exponent + } + + fn unpack(exponent: u64) -> (u64, u64) { + (2, exponent) } } pub type PowTableCircuit = OpsTableCircuit; + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + tables::TableCircuit, + }; + use goldilocks::{GoldilocksExt2 as E, SmallField}; + + #[test] + fn test_ops_pow_table_assign() { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + + let config = PowTableCircuit::::construct_circuit(&mut cb).unwrap(); + + let fixed = PowTableCircuit::::generate_fixed_traces(&config, cb.cs.num_fixed, &()); + + for (i, row) in fixed.iter_rows().enumerate() { + let (base, exp) = PowTable::unpack(i as u64); + assert_eq!(PowTable::pack(base, exp), i as u64); + assert_eq!(base, unsafe { row[0].assume_init() }.to_canonical_u64()); + assert_eq!(exp, unsafe { row[1].assume_init() }.to_canonical_u64()); + assert_eq!( + base.pow(exp.try_into().unwrap()), + unsafe { row[2].assume_init() }.to_canonical_u64() + ); + } + } +} diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 0f1be6558..d86fd58bc 
100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -137,7 +137,7 @@ impl TableCircuit for ProgramTableCircuit { cb.lk_table_record( || "prog table", - cb.params.program_size, + cb.params.program_size.next_power_of_two(), ROMType::Instruction, record_exprs, mlt.expr(), @@ -219,25 +219,72 @@ impl TableCircuit for ProgramTableCircuit { } #[cfg(test)] -#[test] -#[allow(clippy::identity_op)] -fn test_decode_imm() { - for (i, expected) in [ - // Example of I-type: ADDI. - // imm | rs1 | funct3 | rd | opcode - (89 << 20 | 1 << 15 | 0b000 << 12 | 1 << 7 | 0x13, 89), - // Shifts get a precomputed power of 2: SLLI, SRLI, SRAI. - (31 << 20 | 1 << 15 | 0b001 << 12 | 1 << 7 | 0x13, 1 << 31), - (31 << 20 | 1 << 15 | 0b101 << 12 | 1 << 7 | 0x13, 1 << 31), - ( - 1 << 30 | 31 << 20 | 1 << 15 | 0b101 << 12 | 1 << 7 | 0x13, - 1 << 31, - ), - // Example of R-type with funct7: SUB. - // funct7 | rs2 | rs1 | funct3 | rd | opcode - (0x20 << 25 | 1 << 20 | 1 << 15 | 0 << 12 | 1 << 7 | 0x33, 0), - ] { - let imm = InsnRecord::imm_internal(&DecodedInstruction::new(i)); - assert_eq!(imm, expected); +mod tests { + use super::*; + use crate::{circuit_builder::ConstraintSystem, witness::LkMultiplicity}; + use ceno_emul::encode_rv32; + use ff::Field; + use goldilocks::{Goldilocks as F, GoldilocksExt2 as E}; + + #[test] + #[allow(clippy::identity_op)] + fn test_decode_imm() { + for (i, expected) in [ + // Example of I-type: ADDI. + // imm | rs1 | funct3 | rd | opcode + (89 << 20 | 1 << 15 | 0b000 << 12 | 1 << 7 | 0x13, 89), + // Shifts get a precomputed power of 2: SLLI, SRLI, SRAI. + (31 << 20 | 1 << 15 | 0b001 << 12 | 1 << 7 | 0x13, 1 << 31), + (31 << 20 | 1 << 15 | 0b101 << 12 | 1 << 7 | 0x13, 1 << 31), + ( + 1 << 30 | 31 << 20 | 1 << 15 | 0b101 << 12 | 1 << 7 | 0x13, + 1 << 31, + ), + // Example of R-type with funct7: SUB. 
+ // funct7 | rs2 | rs1 | funct3 | rd | opcode + (0x20 << 25 | 1 << 20 | 1 << 15 | 0 << 12 | 1 << 7 | 0x33, 0), + ] { + let imm = InsnRecord::imm_internal(&DecodedInstruction::new(i)); + assert_eq!(imm, expected); + } + } + + #[test] + fn test_program_padding() { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + + let actual_len = 3; + let instructions = vec![encode_rv32(ADD, 1, 2, 3, 0); actual_len]; + let program = Program::new(0x2000_0000, 0x2000_0000, instructions, Default::default()); + + let config = ProgramTableCircuit::construct_circuit(&mut cb).unwrap(); + + let check = |matrix: &RowMajorMatrix| { + assert_eq!( + matrix.num_instances() + matrix.num_padding_instances(), + cb.params.program_size + ); + for row in matrix.iter_rows().skip(actual_len) { + for col in row.iter() { + assert_eq!(unsafe { col.assume_init() }, F::ZERO); + } + } + }; + + let fixed = + ProgramTableCircuit::::generate_fixed_traces(&config, cb.cs.num_fixed, &program); + check(&fixed); + + let lkm = LkMultiplicity::default().into_finalize_result(); + + let witness = ProgramTableCircuit::::assign_instances( + &config, + cb.cs.num_witin as usize, + &lkm, + &program, + ) + .unwrap(); + check(&witness); } } diff --git a/ceno_zkvm/src/tables/ram.rs b/ceno_zkvm/src/tables/ram.rs index 182892603..2c45294e4 100644 --- a/ceno_zkvm/src/tables/ram.rs +++ b/ceno_zkvm/src/tables/ram.rs @@ -19,11 +19,11 @@ impl DynVolatileRamTable for DynMemTable { const ZERO_INIT: bool = true; fn offset_addr(params: &ProgramParams) -> Addr { - params.platform.ram_start() + params.platform.ram.start } fn end_addr(params: &ProgramParams) -> Addr { - params.platform.ram_end() + params.platform.ram.end } fn name() -> &'static str { @@ -34,25 +34,25 @@ impl DynVolatileRamTable for DynMemTable { pub type DynMemCircuit = DynVolatileRamCircuit; #[derive(Clone)] -pub struct PrivateMemTable; -impl DynVolatileRamTable for PrivateMemTable { +pub struct HintsTable; +impl DynVolatileRamTable for HintsTable { const RAM_TYPE: RAMType = RAMType::Memory; const V_LIMBS: usize = 1; // See `MemoryExpr`. const ZERO_INIT: bool = false; fn offset_addr(params: &ProgramParams) -> Addr { - params.platform.ram_start() + params.platform.hints.start } fn end_addr(params: &ProgramParams) -> Addr { - params.platform.ram_end() + params.platform.hints.end } fn name() -> &'static str { - "PrivateMemTable" + "HintsTable" } } -pub type PrivateMemCircuit = DynVolatileRamCircuit; +pub type HintsCircuit = DynVolatileRamCircuit; /// RegTable, fix size without offset #[derive(Clone)] diff --git a/ceno_zkvm/src/tables/ram/ram_circuit.rs b/ceno_zkvm/src/tables/ram/ram_circuit.rs index 9ed7f9067..26927c54d 100644 --- a/ceno_zkvm/src/tables/ram/ram_circuit.rs +++ b/ceno_zkvm/src/tables/ram/ram_circuit.rs @@ -1,6 +1,6 @@ use std::{collections::HashMap, marker::PhantomData}; -use ceno_emul::{Addr, Cycle, WORD_SIZE, Word}; +use ceno_emul::{Addr, Cycle, GetAddr, WORD_SIZE, Word}; use ff_ext::ExtensionField; use crate::{ @@ -27,6 +27,18 @@ pub struct MemFinalRecord { pub value: Word, } +impl GetAddr for MemInitRecord { + fn get_addr(&self) -> Addr { + self.addr + } +} + +impl GetAddr for MemFinalRecord { + fn get_addr(&self) -> Addr { + self.addr + } +} + /// - **Non-Volatile**: The initial values can be set to any arbitrary value. 
/// /// **Special Note**: @@ -178,7 +190,7 @@ impl TableC type WitnessInput = [MemFinalRecord]; fn name() -> String { - format!("RAM_{:?}", DVRAM::RAM_TYPE) + format!("RAM_{:?}_{}", DVRAM::RAM_TYPE, DVRAM::name()) } fn construct_circuit(cb: &mut CircuitBuilder) -> Result { diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index 98aa77ad7..35f3b0720 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -389,16 +389,23 @@ impl DynVolatileRamTableConfig ) -> Result, ZKVMError> { assert!(final_mem.len() <= DVRAM::max_len(&self.params)); assert!(DVRAM::max_len(&self.params).is_power_of_two()); +<<<<<<< HEAD dbg!(final_mem.len(), num_witness); let mut final_table = RowMajorMatrix::::new(final_mem.len(), num_witness, InstancePaddingStrategy::Zero); +======= + let mut final_table = RowMajorMatrix::::new(final_mem.len(), num_witness); +>>>>>>> master final_table .par_iter_mut() .with_min_len(MIN_PAR_SIZE) .zip(final_mem.into_par_iter()) - .for_each(|(row, rec)| { + .enumerate() + .for_each(|(i, (row, rec))| { + assert_eq!(rec.addr, DVRAM::addr(&self.params, i)); set_val!(row, self.addr, rec.addr as u64); + if self.final_v.len() == 1 { // Assign value directly. set_val!(row, self.final_v[0], rec.value as u64); @@ -412,6 +419,24 @@ impl DynVolatileRamTableConfig set_val!(row, self.final_cycle, rec.cycle); }); +<<<<<<< HEAD +======= + // set padding with well-form address + final_table + .par_iter_mut() + .enumerate() + .skip(final_mem.len()) + .with_min_len(MIN_PAR_SIZE) + .for_each(|(i, row)| { + // Assign value limbs. + self.final_v.iter().for_each(|limb| { + set_val!(row, limb, 0u64); + }); + set_val!(row, self.addr, DVRAM::addr(&self.params, i) as u64); + set_val!(row, self.final_cycle, 0_u64); + }); + +>>>>>>> master Ok(final_table) } } diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index 1f540b472..0af75208d 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -17,7 +17,7 @@ use ark_std::iterable::Iterable; use ff::Field; use ff_ext::ExtensionField; use goldilocks::SmallField; -use itertools::Itertools; +use itertools::{Itertools, enumerate}; use std::{ borrow::Cow, mem::{self}, @@ -34,15 +34,36 @@ pub enum UintLimb { Expression(Vec>), } -impl UintLimb { - pub fn iter(&self) -> impl Iterator { +impl IntoIterator for UintLimb { + type Item = WitIn; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { match self { - UintLimb::WitIn(vec) => vec.iter(), + UintLimb::WitIn(wits) => wits.into_iter(), _ => unimplemented!(), } } } +impl<'a, E: ExtensionField> IntoIterator for &'a UintLimb { + type Item = &'a WitIn; + type IntoIter = std::slice::Iter<'a, WitIn>; + + fn into_iter(self) -> Self::IntoIter { + match self { + UintLimb::WitIn(wits) => wits.iter(), + _ => unimplemented!(), + } + } +} + +impl UintLimb { + pub fn iter(&self) -> impl Iterator { + self.into_iter() + } +} + impl Index for UintLimb { type Output = WitIn; @@ -735,8 +756,8 @@ impl<'a, T: Into + From + Copy + Default> Value<'a, T> { let mut c_limbs = vec![0u16; num_limbs]; let mut carries = vec![0u64; num_limbs]; let mut tmp = vec![0u64; num_limbs]; - a_limbs.iter().enumerate().for_each(|(i, &a_limb)| { - b_limbs.iter().enumerate().for_each(|(j, &b_limb)| { + enumerate(a_limbs).for_each(|(i, &a_limb)| { + enumerate(b_limbs).for_each(|(j, &b_limb)| { let idx = i + j; if idx < num_limbs { tmp[idx] += a_limb as u64 * b_limb as u64; diff --git a/ceno_zkvm/src/uint/arithmetic.rs 
b/ceno_zkvm/src/uint/arithmetic.rs index 85eb703c8..dfe33b076 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -267,14 +267,12 @@ impl UIntLimbs { rhs: &UIntLimbs, ) -> Result { let n_limbs = Self::NUM_LIMBS; - let (is_equal_per_limb, diff_inv_per_limb): (Vec, Vec) = self - .limbs - .iter() - .zip_eq(rhs.limbs.iter()) - .map(|(a, b)| circuit_builder.is_equal(a.expr(), b.expr())) - .collect::, ZKVMError>>()? - .into_iter() - .unzip(); + let (is_equal_per_limb, diff_inv_per_limb): (Vec, Vec) = + izip!(&self.limbs, &rhs.limbs) + .map(|(a, b)| circuit_builder.is_equal(a.expr(), b.expr())) + .collect::, ZKVMError>>()? + .into_iter() + .unzip(); let sum_expr = is_equal_per_limb .iter() diff --git a/ceno_zkvm/src/uint/logic.rs b/ceno_zkvm/src/uint/logic.rs index b340df982..024d09d73 100644 --- a/ceno_zkvm/src/uint/logic.rs +++ b/ceno_zkvm/src/uint/logic.rs @@ -18,7 +18,7 @@ impl UIntLimbs { b: &Self, c: &Self, ) -> Result<(), ZKVMError> { - for (a_byte, b_byte, c_byte) in izip!(a.limbs.iter(), b.limbs.iter(), c.limbs.iter()) { + for (a_byte, b_byte, c_byte) in izip!(&a.limbs, &b.limbs, &c.limbs) { cb.logic_u8(rom_type, a_byte.expr(), b_byte.expr(), c_byte.expr())?; } Ok(()) diff --git a/ceno_zkvm/src/virtual_polys.rs b/ceno_zkvm/src/virtual_polys.rs index 75c930191..6302888d1 100644 --- a/ceno_zkvm/src/virtual_polys.rs +++ b/ceno_zkvm/src/virtual_polys.rs @@ -246,8 +246,8 @@ mod tests { }, ); - <::BaseField as std::convert::Into>::into( - evals.iter().sum::<::BaseField>() + GoldilocksExt2::from( + evals.iter().sum::() * base_2.pow([(max_num_vars - fs[0].num_vars()) as u64]), ) }; diff --git a/clippy.toml b/clippy.toml new file mode 100644 index 000000000..6e64e6f38 --- /dev/null +++ b/clippy.toml @@ -0,0 +1,10 @@ +# TODO(Matthias): review and see which exception we can remove over time. +# Eg removing syn is blocked by ark-ff-asm cutting a new release +# (https://github.com/arkworks-rs/algebra/issues/813) amongst other things. 
+allowed-duplicate-crates = [ + "syn", + "windows-sys", + "regex-automata", + "regex-syntax", + "itertools", +] diff --git a/examples-builder/Cargo.toml b/examples-builder/Cargo.toml index 6862e806b..00104ec3c 100644 --- a/examples-builder/Cargo.toml +++ b/examples-builder/Cargo.toml @@ -1,5 +1,9 @@ [package] +categories.workspace = true +description = "Build scripts for ceno examples" edition.workspace = true +keywords.workspace = true license.workspace = true name = "ceno-examples" +repository.workspace = true version.workspace = true diff --git a/examples-builder/src/lib.rs b/examples-builder/src/lib.rs index fdb344ff9..430c4d1de 100644 --- a/examples-builder/src/lib.rs +++ b/examples-builder/src/lib.rs @@ -1 +1,2 @@ +#![deny(clippy::cargo)] include!(concat!(env!("OUT_DIR"), "/vars.rs")); diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 08082d0bd..85c975ac1 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -1,6 +1,7 @@ [package] edition = "2021" name = "examples" +readme = "README.md" resolver = "2" version = "0.1.0" diff --git a/ff_ext/Cargo.toml b/ff_ext/Cargo.toml index 5c6a0d937..3b55f3581 100644 --- a/ff_ext/Cargo.toml +++ b/ff_ext/Cargo.toml @@ -1,7 +1,12 @@ [package] +categories.workspace = true +description = "Extra functionality for the ff ((finite fields) crate" edition.workspace = true +keywords.workspace = true license.workspace = true name = "ff_ext" +readme.workspace = true +repository.workspace = true version.workspace = true [dependencies] diff --git a/ff_ext/src/lib.rs b/ff_ext/src/lib.rs index ba34bfa4b..32d77a565 100644 --- a/ff_ext/src/lib.rs +++ b/ff_ext/src/lib.rs @@ -1,3 +1,4 @@ +#![deny(clippy::cargo)] pub use ff; use ff::FromUniformBytes; use goldilocks::SmallField; diff --git a/mpcs/Cargo.toml b/mpcs/Cargo.toml index 26ee123ba..f977328cc 100644 --- a/mpcs/Cargo.toml +++ b/mpcs/Cargo.toml @@ -1,7 +1,12 @@ [package] +categories.workspace = true +description = "Multilinear Polynomial Commitment Scheme" edition.workspace = true +keywords.workspace = true license.workspace = true name = "mpcs" +readme.workspace = true +repository.workspace = true version.workspace = true [dependencies] @@ -38,11 +43,7 @@ sanity-check = [] [[bench]] harness = false -name = "commit_open_verify_rs" - -[[bench]] -harness = false -name = "commit_open_verify_basecode" +name = "basefold" [[bench]] harness = false diff --git a/mpcs/benches/commit_open_verify_rs.rs b/mpcs/benches/basefold.rs similarity index 61% rename from mpcs/benches/commit_open_verify_rs.rs rename to mpcs/benches/basefold.rs index 1401f5127..965d55035 100644 --- a/mpcs/benches/commit_open_verify_rs.rs +++ b/mpcs/benches/basefold.rs @@ -1,12 +1,16 @@ use std::time::Duration; use criterion::*; -use ff::Field; +use ff_ext::ExtensionField; use goldilocks::GoldilocksExt2; use itertools::{Itertools, chain}; use mpcs::{ - Basefold, BasefoldRSParams, Evaluation, PolynomialCommitmentScheme, + Basefold, BasefoldBasecodeParams, BasefoldRSParams, Evaluation, PolynomialCommitmentScheme, + test_util::{ + commit_polys_individually, gen_rand_poly_base, gen_rand_poly_ext, gen_rand_polys, + get_point_from_challenge, get_points_from_challenge, setup_pcs, + }, util::plonky2_util::log2_ceil, }; @@ -14,12 +18,10 @@ use multilinear_extensions::{ mle::{DenseMultilinearExtension, MultilinearExtension}, virtual_poly_v2::ArcMultilinearExtension, }; -use rand::{SeedableRng, rngs::OsRng}; -use rand_chacha::ChaCha8Rng; -use rayon::iter::{IntoParallelIterator, ParallelIterator}; use transcript::Transcript; -type Pcs = Basefold; 
+type PcsGoldilocksRSCode = Basefold; +type PcsGoldilocksBasecode = Basefold; type T = Transcript; type E = GoldilocksExt2; @@ -29,10 +31,19 @@ const NUM_VARS_END: usize = 20; const BATCH_SIZE_LOG_START: usize = 6; const BATCH_SIZE_LOG_END: usize = 6; -fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { +struct Switch<'a, E: ExtensionField> { + name: &'a str, + gen_rand_poly: fn(usize) -> DenseMultilinearExtension, +} + +fn bench_commit_open_verify_goldilocks>( + c: &mut Criterion, + switch: Switch, + id: &str, +) { let mut group = c.benchmark_group(format!( - "commit_open_verify_goldilocks_rs_{}", - if is_base { "base" } else { "ext2" } + "commit_open_verify_goldilocks_{}_{}", + id, switch.name, )); group.sample_size(NUM_SAMPLES); // Challenge is over extension field, poly over the base field @@ -50,18 +61,7 @@ fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { }; let mut transcript = T::new(b"BaseFold"); - let poly = if is_base { - DenseMultilinearExtension::random(num_vars, &mut OsRng) - } else { - DenseMultilinearExtension::from_evaluations_ext_vec( - num_vars, - (0..1 << num_vars) - .into_par_iter() - .map(|_| E::random(&mut OsRng)) - .collect(), - ) - }; - + let poly = (switch.gen_rand_poly)(num_vars); let comm = Pcs::commit_and_write(&pp, &poly, &mut transcript).unwrap(); group.bench_function(BenchmarkId::new("commit", format!("{}", num_vars)), |b| { @@ -70,9 +70,7 @@ fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { }) }); - let point = (0..num_vars) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>(); + let point = get_point_from_challenge(num_vars, &mut transcript); let eval = poly.evaluate(point.as_slice()); transcript.append_field_element_ext(&eval); let transcript_for_bench = transcript.clone(); @@ -91,9 +89,7 @@ fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { let comm = Pcs::get_pure_commitment(&comm); let mut transcript = T::new(b"BaseFold"); Pcs::write_commitment(&comm, &mut transcript).unwrap(); - let point = (0..num_vars) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>(); + let point = get_point_from_challenge(num_vars, &mut transcript); transcript.append_field_element_ext(&eval); let transcript_for_bench = transcript.clone(); Pcs::verify(&vp, &comm, &point, &eval, &proof, &mut transcript).unwrap(); @@ -109,10 +105,24 @@ fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { } } -fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { +const BASE: Switch = Switch { + name: "base", + gen_rand_poly: gen_rand_poly_base, +}; + +const EXT: Switch = Switch { + name: "ext", + gen_rand_poly: gen_rand_poly_ext, +}; + +fn bench_batch_commit_open_verify_goldilocks>( + c: &mut Criterion, + switch: Switch, + id: &str, +) { let mut group = c.benchmark_group(format!( - "batch_commit_open_verify_goldilocks_rs_{}", - if is_base { "base" } else { "ext2" } + "batch_commit_open_verify_goldilocks_{}_{}", + id, switch.name, )); group.sample_size(NUM_SAMPLES); // Challenge is over extension field, poly over the base field @@ -120,13 +130,8 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { let batch_size = 1 << batch_size_log; let num_points = batch_size >> 1; - let rng = ChaCha8Rng::from_seed([0u8; 32]); // Setup - let (pp, vp) = { - let poly_size = 1 << num_vars; - let param = 
Pcs::setup(poly_size).unwrap(); - Pcs::trim(param, poly_size).unwrap() - }; + let (pp, vp) = setup_pcs::(num_vars); // Batch commit and open let evals = chain![ (0..num_points).map(|point| (point * 2, point)), // Every point matches two polys @@ -136,37 +141,18 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { .collect_vec(); let mut transcript = T::new(b"BaseFold"); - let polys = (0..batch_size) - .map(|i| { - if is_base { - DenseMultilinearExtension::random( - num_vars - log2_ceil((i >> 1) + 1), - &mut rng.clone(), - ) - } else { - DenseMultilinearExtension::from_evaluations_ext_vec( - num_vars - log2_ceil((i >> 1) + 1), - (0..1 << (num_vars - log2_ceil((i >> 1) + 1))) - .into_par_iter() - .map(|_| E::random(&mut OsRng)) - .collect(), - ) - } - }) - .collect_vec(); - let comms = polys - .iter() - .map(|poly| Pcs::commit_and_write(&pp, poly, &mut transcript).unwrap()) - .collect_vec(); + let polys = gen_rand_polys( + |i| num_vars - log2_ceil((i >> 1) + 1), + batch_size, + switch.gen_rand_poly, + ); + let comms = commit_polys_individually::(&pp, &polys, &mut transcript); - let points = (0..num_points) - .map(|i| { - (0..num_vars - i) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>() - }) - .take(num_points) - .collect_vec(); + let points = get_points_from_challenge( + |i| num_vars - log2_ceil(i + 1), + num_points, + &mut transcript, + ); let evals = evals .iter() @@ -175,11 +161,7 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { Evaluation::new(poly, point, polys[poly].evaluate(&points[point])) }) .collect_vec(); - let values: Vec = evals - .iter() - .map(Evaluation::value) - .copied() - .collect::>(); + let values: Vec = evals.iter().map(Evaluation::value).copied().collect(); transcript.append_field_element_exts(values.as_slice()); let transcript_for_bench = transcript.clone(); let proof = @@ -208,14 +190,11 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { comm }) .collect_vec(); - let points = (0..num_points) - .map(|i| { - (0..num_vars - i) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>() - }) - .take(num_points) - .collect_vec(); + let points = get_points_from_challenge( + |i| num_vars - log2_ceil(i + 1), + num_points, + &mut transcript, + ); let values: Vec = evals .iter() @@ -252,38 +231,23 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { } } -fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { +fn bench_simple_batch_commit_open_verify_goldilocks>( + c: &mut Criterion, + switch: Switch, + id: &str, +) { let mut group = c.benchmark_group(format!( - "simple_batch_commit_open_verify_goldilocks_rs_{}", - if is_base { "base" } else { "extension" } + "simple_batch_commit_open_verify_goldilocks_{}_{}", + id, switch.name, )); group.sample_size(NUM_SAMPLES); // Challenge is over extension field, poly over the base field for num_vars in NUM_VARS_START..=NUM_VARS_END { for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { let batch_size = 1 << batch_size_log; - let rng = ChaCha8Rng::from_seed([0u8; 32]); - let (pp, vp) = { - let poly_size = 1 << num_vars; - let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(param, poly_size).unwrap() - }; + let (pp, vp) = setup_pcs::(num_vars); let mut transcript = T::new(b"BaseFold"); - let polys = (0..batch_size) - .map(|_| { - if is_base { - DenseMultilinearExtension::random(num_vars, &mut rng.clone()) - 
} else { - DenseMultilinearExtension::from_evaluations_ext_vec( - num_vars, - (0..1 << num_vars) - .into_par_iter() - .map(|_| E::random(&mut OsRng)) - .collect(), - ) - } - }) - .collect_vec(); + let polys = gen_rand_polys(|_| num_vars, batch_size, switch.gen_rand_poly); let comm = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap(); group.bench_function( @@ -294,20 +258,14 @@ fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: }) }, ); - - let polys: Vec> = - polys.into_iter().map(|poly| poly.into()).collect_vec(); - - let point = (0..num_vars) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>(); - - let evals = (0..batch_size) - .map(|i| polys[i].evaluate(&point)) - .collect_vec(); - + let point = get_point_from_challenge(num_vars, &mut transcript); + let evals = polys.iter().map(|poly| poly.evaluate(&point)).collect_vec(); transcript.append_field_element_exts(&evals); let transcript_for_bench = transcript.clone(); + let polys = polys + .iter() + .map(|poly| ArcMultilinearExtension::from(poly.clone())) + .collect::>(); let proof = Pcs::simple_batch_open(&pp, &polys, &comm, &point, &evals, &mut transcript) .unwrap(); @@ -337,9 +295,7 @@ fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: let mut transcript = Transcript::new(b"BaseFold"); Pcs::write_commitment(&comm, &mut transcript).unwrap(); - let point = (0..num_vars) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>(); + let point = get_point_from_challenge(num_vars, &mut transcript); transcript.append_field_element_exts(&evals); let backup_transcript = transcript.clone(); @@ -369,34 +325,61 @@ fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: } } -fn bench_commit_open_verify_goldilocks_2(c: &mut Criterion) { - bench_commit_open_verify_goldilocks(c, false); +fn bench_commit_open_verify_goldilocks_ext_rs(c: &mut Criterion) { + bench_commit_open_verify_goldilocks::(c, EXT, "rs"); +} + +fn bench_commit_open_verify_goldilocks_ext_basecode(c: &mut Criterion) { + bench_commit_open_verify_goldilocks::(c, EXT, "basecode"); +} + +fn bench_commit_open_verify_goldilocks_base_rs(c: &mut Criterion) { + bench_commit_open_verify_goldilocks::(c, BASE, "rs"); +} + +fn bench_commit_open_verify_goldilocks_base_basecode(c: &mut Criterion) { + bench_commit_open_verify_goldilocks::(c, BASE, "basecode"); +} + +fn bench_batch_commit_open_verify_goldilocks_ext_rs(c: &mut Criterion) { + bench_batch_commit_open_verify_goldilocks::(c, EXT, "rs"); +} + +fn bench_batch_commit_open_verify_goldilocks_ext_basecode(c: &mut Criterion) { + bench_batch_commit_open_verify_goldilocks::(c, EXT, "basecode"); +} + +fn bench_batch_commit_open_verify_goldilocks_base_rs(c: &mut Criterion) { + bench_batch_commit_open_verify_goldilocks::(c, BASE, "rs"); } -fn bench_commit_open_verify_goldilocks_base(c: &mut Criterion) { - bench_commit_open_verify_goldilocks(c, true); +fn bench_batch_commit_open_verify_goldilocks_base_basecode(c: &mut Criterion) { + bench_batch_commit_open_verify_goldilocks::(c, BASE, "basecode"); } -fn bench_batch_commit_open_verify_goldilocks_2(c: &mut Criterion) { - bench_batch_commit_open_verify_goldilocks(c, false); +fn bench_simple_batch_commit_open_verify_goldilocks_ext_rs(c: &mut Criterion) { + bench_simple_batch_commit_open_verify_goldilocks::(c, EXT, "rs"); } -fn bench_batch_commit_open_verify_goldilocks_base(c: &mut Criterion) { - bench_batch_commit_open_verify_goldilocks(c, true); +fn 
bench_simple_batch_commit_open_verify_goldilocks_ext_basecode(c: &mut Criterion) { + bench_simple_batch_commit_open_verify_goldilocks::(c, EXT, "basecode"); } -fn bench_simple_batch_commit_open_verify_goldilocks_2(c: &mut Criterion) { - bench_simple_batch_commit_open_verify_goldilocks(c, false); +fn bench_simple_batch_commit_open_verify_goldilocks_base_rs(c: &mut Criterion) { + bench_simple_batch_commit_open_verify_goldilocks::(c, BASE, "rs"); } -fn bench_simple_batch_commit_open_verify_goldilocks_base(c: &mut Criterion) { - bench_simple_batch_commit_open_verify_goldilocks(c, true); +fn bench_simple_batch_commit_open_verify_goldilocks_base_basecode(c: &mut Criterion) { + bench_simple_batch_commit_open_verify_goldilocks::(c, BASE, "basecode"); } criterion_group! { name = bench_basefold; config = Criterion::default().warm_up_time(Duration::from_millis(3000)); - targets = bench_simple_batch_commit_open_verify_goldilocks_base, bench_simple_batch_commit_open_verify_goldilocks_2, bench_batch_commit_open_verify_goldilocks_base, bench_batch_commit_open_verify_goldilocks_2, bench_commit_open_verify_goldilocks_base, bench_commit_open_verify_goldilocks_2, + targets = + bench_simple_batch_commit_open_verify_goldilocks_base_rs, bench_simple_batch_commit_open_verify_goldilocks_ext_rs, + bench_batch_commit_open_verify_goldilocks_base_rs, bench_batch_commit_open_verify_goldilocks_ext_rs, bench_commit_open_verify_goldilocks_base_rs, bench_commit_open_verify_goldilocks_ext_rs, + bench_simple_batch_commit_open_verify_goldilocks_base_basecode, bench_simple_batch_commit_open_verify_goldilocks_ext_basecode, bench_batch_commit_open_verify_goldilocks_base_basecode, bench_batch_commit_open_verify_goldilocks_ext_basecode, bench_commit_open_verify_goldilocks_base_basecode, bench_commit_open_verify_goldilocks_ext_basecode, } criterion_main!(bench_basefold); diff --git a/mpcs/benches/commit_open_verify_basecode.rs b/mpcs/benches/commit_open_verify_basecode.rs deleted file mode 100644 index 91baa5f73..000000000 --- a/mpcs/benches/commit_open_verify_basecode.rs +++ /dev/null @@ -1,400 +0,0 @@ -use std::time::Duration; - -use criterion::*; -use ff::Field; -use goldilocks::GoldilocksExt2; - -use itertools::{Itertools, chain}; -use mpcs::{ - Basefold, BasefoldBasecodeParams, Evaluation, PolynomialCommitmentScheme, - util::plonky2_util::log2_ceil, -}; - -use multilinear_extensions::mle::{DenseMultilinearExtension, MultilinearExtension}; -use rand::{SeedableRng, rngs::OsRng}; -use rand_chacha::ChaCha8Rng; -use transcript::Transcript; - -type Pcs = Basefold; -type T = Transcript; -type E = GoldilocksExt2; - -const NUM_SAMPLES: usize = 10; -const NUM_VARS_START: usize = 20; -const NUM_VARS_END: usize = 20; -const BATCH_SIZE_LOG_START: usize = 6; -const BATCH_SIZE_LOG_END: usize = 6; - -fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { - let mut group = c.benchmark_group(format!( - "commit_open_verify_goldilocks_basecode_{}", - if is_base { "base" } else { "ext2" } - )); - group.sample_size(NUM_SAMPLES); - // Challenge is over extension field, poly over the base field - for num_vars in NUM_VARS_START..=NUM_VARS_END { - let (pp, vp) = { - let poly_size = 1 << num_vars; - let param = Pcs::setup(poly_size).unwrap(); - - group.bench_function(BenchmarkId::new("setup", format!("{}", num_vars)), |b| { - b.iter(|| { - Pcs::setup(poly_size).unwrap(); - }) - }); - Pcs::trim(param, poly_size).unwrap() - }; - - let mut transcript = T::new(b"BaseFold"); - let poly = if is_base { - 
DenseMultilinearExtension::random(num_vars, &mut OsRng) - } else { - DenseMultilinearExtension::from_evaluations_ext_vec( - num_vars, - (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), - ) - }; - - let comm = Pcs::commit_and_write(&pp, &poly, &mut transcript).unwrap(); - - group.bench_function(BenchmarkId::new("commit", format!("{}", num_vars)), |b| { - b.iter(|| { - Pcs::commit(&pp, &poly).unwrap(); - }) - }); - - let point = (0..num_vars) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>(); - let eval = poly.evaluate(point.as_slice()); - transcript.append_field_element_ext(&eval); - let transcript_for_bench = transcript.clone(); - let proof = Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap(); - - group.bench_function(BenchmarkId::new("open", format!("{}", num_vars)), |b| { - b.iter_batched( - || transcript_for_bench.clone(), - |mut transcript| { - Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap(); - }, - BatchSize::SmallInput, - ); - }); - // Verify - let comm = Pcs::get_pure_commitment(&comm); - let mut transcript = T::new(b"BaseFold"); - Pcs::write_commitment(&comm, &mut transcript).unwrap(); - let point = (0..num_vars) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>(); - transcript.append_field_element_ext(&eval); - let transcript_for_bench = transcript.clone(); - Pcs::verify(&vp, &comm, &point, &eval, &proof, &mut transcript).unwrap(); - group.bench_function(BenchmarkId::new("verify", format!("{}", num_vars)), |b| { - b.iter_batched( - || transcript_for_bench.clone(), - |mut transcript| { - Pcs::verify(&vp, &comm, &point, &eval, &proof, &mut transcript).unwrap(); - }, - BatchSize::SmallInput, - ); - }); - } -} - -fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { - let mut group = c.benchmark_group(format!( - "batch_commit_open_verify_goldilocks_basecode_{}", - if is_base { "base" } else { "ext2" } - )); - group.sample_size(NUM_SAMPLES); - // Challenge is over extension field, poly over the base field - for num_vars in NUM_VARS_START..=NUM_VARS_END { - for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { - let batch_size = 1 << batch_size_log; - let num_points = batch_size >> 1; - let rng = ChaCha8Rng::from_seed([0u8; 32]); - // Setup - let (pp, vp) = { - let poly_size = 1 << num_vars; - let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(param, poly_size).unwrap() - }; - // Batch commit and open - let evals = chain![ - (0..num_points).map(|point| (point * 2, point)), // Every point matches two polys - (0..num_points).map(|point| (point * 2 + 1, point)), - ] - .unique() - .collect_vec(); - - let mut transcript = T::new(b"BaseFold"); - let polys = (0..batch_size) - .map(|i| { - if is_base { - DenseMultilinearExtension::random( - num_vars - log2_ceil((i >> 1) + 1), - &mut rng.clone(), - ) - } else { - DenseMultilinearExtension::from_evaluations_ext_vec( - num_vars - log2_ceil((i >> 1) + 1), - (0..1 << (num_vars - log2_ceil((i >> 1) + 1))) - .map(|_| E::random(&mut OsRng)) - .collect(), - ) - } - }) - .collect_vec(); - let comms = polys - .iter() - .map(|poly| Pcs::commit_and_write(&pp, poly, &mut transcript).unwrap()) - .collect_vec(); - - let points = (0..num_points) - .map(|i| { - (0..num_vars - i) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>() - }) - .take(num_points) - .collect_vec(); - - let evals = evals - .iter() - .copied() - .map(|(poly, point)| { - Evaluation::new(poly, point, 
polys[poly].evaluate(&points[point])) - }) - .collect_vec(); - let values: Vec = evals - .iter() - .map(Evaluation::value) - .copied() - .collect::>(); - transcript.append_field_element_exts(values.as_slice()); - let transcript_for_bench = transcript.clone(); - let proof = - Pcs::batch_open(&pp, &polys, &comms, &points, &evals, &mut transcript).unwrap(); - - group.bench_function( - BenchmarkId::new("batch_open", format!("{}-{}", num_vars, batch_size)), - |b| { - b.iter_batched( - || transcript_for_bench.clone(), - |mut transcript| { - Pcs::batch_open(&pp, &polys, &comms, &points, &evals, &mut transcript) - .unwrap(); - }, - BatchSize::SmallInput, - ); - }, - ); - // Batch verify - let mut transcript = T::new(b"BaseFold"); - let comms = comms - .iter() - .map(|comm| { - let comm = Pcs::get_pure_commitment(comm); - Pcs::write_commitment(&comm, &mut transcript).unwrap(); - comm - }) - .collect_vec(); - let points = (0..num_points) - .map(|i| { - (0..num_vars - i) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>() - }) - .take(num_points) - .collect_vec(); - - let values: Vec = evals - .iter() - .map(Evaluation::value) - .copied() - .collect::>(); - transcript.append_field_element_exts(values.as_slice()); - - let backup_transcript = transcript.clone(); - - Pcs::batch_verify(&vp, &comms, &points, &evals, &proof, &mut transcript).unwrap(); - - group.bench_function( - BenchmarkId::new("batch_verify", format!("{}-{}", num_vars, batch_size)), - |b| { - b.iter_batched( - || backup_transcript.clone(), - |mut transcript| { - Pcs::batch_verify( - &vp, - &comms, - &points, - &evals, - &proof, - &mut transcript, - ) - .unwrap(); - }, - BatchSize::SmallInput, - ); - }, - ); - } - } -} - -fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { - let mut group = c.benchmark_group(format!( - "simple_batch_commit_open_verify_goldilocks_basecode_{}", - if is_base { "base" } else { "extension" } - )); - group.sample_size(NUM_SAMPLES); - // Challenge is over extension field, poly over the base field - for num_vars in NUM_VARS_START..=NUM_VARS_END { - for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { - let batch_size = 1 << batch_size_log; - let rng = ChaCha8Rng::from_seed([0u8; 32]); - let (pp, vp) = { - let poly_size = 1 << num_vars; - let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(param, poly_size).unwrap() - }; - let mut transcript = T::new(b"BaseFold"); - let polys = (0..batch_size) - .map(|_| { - if is_base { - DenseMultilinearExtension::random(num_vars, &mut rng.clone()) - } else { - DenseMultilinearExtension::from_evaluations_ext_vec( - num_vars, - (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), - ) - } - }) - .collect_vec(); - let comm = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap(); - - group.bench_function( - BenchmarkId::new("batch_commit", format!("{}-{}", num_vars, batch_size)), - |b| { - b.iter(|| { - Pcs::batch_commit(&pp, &polys).unwrap(); - }) - }, - ); - - let point = (0..num_vars) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>(); - - let evals = (0..batch_size) - .map(|i| polys[i].evaluate(&point)) - .collect_vec(); - - transcript.append_field_element_exts(&evals); - let transcript_for_bench = transcript.clone(); - let polys = polys - .clone() - .into_iter() - .map(|x| x.into()) - .collect::>(); - let proof = Pcs::simple_batch_open( - &pp, - polys.as_slice(), - &comm, - &point, - &evals, - &mut transcript, - ) - .unwrap(); - - 
group.bench_function( - BenchmarkId::new("batch_open", format!("{}-{}", num_vars, batch_size)), - |b| { - b.iter_batched( - || transcript_for_bench.clone(), - |mut transcript| { - Pcs::simple_batch_open( - &pp, - polys.as_slice(), - &comm, - &point, - &evals, - &mut transcript, - ) - .unwrap(); - }, - BatchSize::SmallInput, - ); - }, - ); - let comm = Pcs::get_pure_commitment(&comm); - - // Batch verify - let mut transcript = Transcript::new(b"BaseFold"); - Pcs::write_commitment(&comm, &mut transcript).unwrap(); - - let point = (0..num_vars) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>(); - transcript.append_field_element_exts(&evals); - let backup_transcript = transcript.clone(); - - Pcs::simple_batch_verify(&vp, &comm, &point, &evals, &proof, &mut transcript).unwrap(); - - group.bench_function( - BenchmarkId::new("batch_verify", format!("{}-{}", num_vars, batch_size)), - |b| { - b.iter_batched( - || backup_transcript.clone(), - |mut transcript| { - Pcs::simple_batch_verify( - &vp, - &comm, - &point, - &evals, - &proof, - &mut transcript, - ) - .unwrap(); - }, - BatchSize::SmallInput, - ); - }, - ); - } - } -} - -fn bench_commit_open_verify_goldilocks_2(c: &mut Criterion) { - bench_commit_open_verify_goldilocks(c, false); -} - -fn bench_commit_open_verify_goldilocks_base(c: &mut Criterion) { - bench_commit_open_verify_goldilocks(c, true); -} - -fn bench_batch_commit_open_verify_goldilocks_2(c: &mut Criterion) { - bench_batch_commit_open_verify_goldilocks(c, false); -} - -fn bench_batch_commit_open_verify_goldilocks_base(c: &mut Criterion) { - bench_batch_commit_open_verify_goldilocks(c, true); -} - -fn bench_simple_batch_commit_open_verify_goldilocks_2(c: &mut Criterion) { - bench_simple_batch_commit_open_verify_goldilocks(c, false); -} - -fn bench_simple_batch_commit_open_verify_goldilocks_base(c: &mut Criterion) { - bench_simple_batch_commit_open_verify_goldilocks(c, true); -} - -criterion_group! 
{ - name = bench_basefold; - config = Criterion::default().warm_up_time(Duration::from_millis(3000)); - targets = bench_simple_batch_commit_open_verify_goldilocks_base, bench_simple_batch_commit_open_verify_goldilocks_2,bench_batch_commit_open_verify_goldilocks_base, bench_batch_commit_open_verify_goldilocks_2, bench_commit_open_verify_goldilocks_base, bench_commit_open_verify_goldilocks_2, -} - -criterion_main!(bench_basefold); diff --git a/mpcs/benches/hashing.rs b/mpcs/benches/hashing.rs index 4dc107ed7..818cd7b6c 100644 --- a/mpcs/benches/hashing.rs +++ b/mpcs/benches/hashing.rs @@ -5,28 +5,19 @@ use goldilocks::Goldilocks; use mpcs::util::hash::{Digest, hash_two_digests}; use poseidon::poseidon_hash::PoseidonHash; +fn random_ceno_goldy() -> Goldilocks { + Goldilocks::random(&mut test_rng()) +} pub fn criterion_benchmark(c: &mut Criterion) { - let left = Digest( - vec![Goldilocks::random(&mut test_rng()); 4] - .try_into() - .unwrap(), - ); - let right = Digest( - vec![Goldilocks::random(&mut test_rng()); 4] - .try_into() - .unwrap(), - ); + let left = Digest(vec![random_ceno_goldy(); 4].try_into().unwrap()); + let right = Digest(vec![random_ceno_goldy(); 4].try_into().unwrap()); c.bench_function("ceno hash 2 to 1", |bencher| { bencher.iter(|| hash_two_digests(&left, &right)) }); - let values = (0..60) - .map(|_| Goldilocks::random(&mut test_rng())) - .collect::>(); + let values = (0..60).map(|_| random_ceno_goldy()).collect::>(); c.bench_function("ceno hash 60 to 1", |bencher| { - bencher.iter(|| { - PoseidonHash::hash_or_noop(&values); - }) + bencher.iter(|| PoseidonHash::hash_or_noop(&values)) }); } diff --git a/mpcs/src/basefold.rs b/mpcs/src/basefold.rs index 5c225c75a..713bd2ec3 100644 --- a/mpcs/src/basefold.rs +++ b/mpcs/src/basefold.rs @@ -1179,8 +1179,8 @@ mod test { use crate::{ basefold::Basefold, test_util::{ - run_batch_commit_open_verify, run_commit_open_verify, - run_simple_batch_commit_open_verify, + gen_rand_poly_base, gen_rand_poly_ext, run_batch_commit_open_verify, + run_commit_open_verify, run_simple_batch_commit_open_verify, }, }; use goldilocks::GoldilocksExt2; @@ -1191,108 +1191,79 @@ mod test { type PcsGoldilocksBaseCode = Basefold; #[test] - fn commit_open_verify_goldilocks_basecode_base() { - // Challenge is over extension field, poly over the base field - run_commit_open_verify::(true, 10, 11); - // Test trivial proof with small num vars - run_commit_open_verify::(true, 4, 6); - } - - #[test] - fn commit_open_verify_goldilocks_rscode_base() { - // Challenge is over extension field, poly over the base field - run_commit_open_verify::(true, 10, 11); - // Test trivial proof with small num vars - run_commit_open_verify::(true, 4, 6); - } - - #[test] - fn commit_open_verify_goldilocks_basecode_2() { - // Both challenge and poly are over extension field - run_commit_open_verify::(false, 10, 11); - // Test trivial proof with small num vars - run_commit_open_verify::(false, 4, 6); - } - - #[test] - fn commit_open_verify_goldilocks_rscode_2() { - // Both challenge and poly are over extension field - run_commit_open_verify::(false, 10, 11); - // Test trivial proof with small num vars - run_commit_open_verify::(false, 4, 6); - } - - #[test] - fn simple_batch_commit_open_verify_goldilocks_basecode_base() { - // Both challenge and poly are over base field - run_simple_batch_commit_open_verify::( - true, 10, 11, 1, - ); - run_simple_batch_commit_open_verify::( - true, 10, 11, 4, - ); - // Test trivial proof with small num vars - run_simple_batch_commit_open_verify::(true, 
4, 6, 4); - } - - #[test] - fn simple_batch_commit_open_verify_goldilocks_rscode_base() { - // Both challenge and poly are over base field - run_simple_batch_commit_open_verify::(true, 10, 11, 1); - run_simple_batch_commit_open_verify::(true, 10, 11, 4); - // Test trivial proof with small num vars - run_simple_batch_commit_open_verify::(true, 4, 6, 4); - } - - #[test] - fn simple_batch_commit_open_verify_goldilocks_basecode_2() { - // Both challenge and poly are over extension field - run_simple_batch_commit_open_verify::( - false, 10, 11, 1, - ); - run_simple_batch_commit_open_verify::( - false, 10, 11, 4, - ); - // Test trivial proof with small num vars - run_simple_batch_commit_open_verify::( - false, 4, 6, 4, - ); - } - - #[test] - fn simple_batch_commit_open_verify_goldilocks_rscode_2() { - // Both challenge and poly are over extension field - run_simple_batch_commit_open_verify::( - false, 10, 11, 1, - ); - run_simple_batch_commit_open_verify::( - false, 10, 11, 4, - ); - // Test trivial proof with small num vars - run_simple_batch_commit_open_verify::(false, 4, 6, 4); - } - - #[test] - fn batch_commit_open_verify_goldilocks_basecode_base() { - // Both challenge and poly are over base field - run_batch_commit_open_verify::(true, 10, 11); - } - - #[test] - fn batch_commit_open_verify_goldilocks_rscode_base() { - // Both challenge and poly are over base field - run_batch_commit_open_verify::(true, 10, 11); + fn commit_open_verify_goldilocks() { + for gen_rand_poly in [gen_rand_poly_base, gen_rand_poly_ext] { + // Challenge is over extension field, poly over the base field + run_commit_open_verify::(gen_rand_poly, 10, 11); + // Test trivial proof with small num vars + run_commit_open_verify::(gen_rand_poly, 4, 6); + // Challenge is over extension field, poly over the base field + run_commit_open_verify::(gen_rand_poly, 10, 11); + // Test trivial proof with small num vars + run_commit_open_verify::(gen_rand_poly, 4, 6); + } } #[test] - fn batch_commit_open_verify_goldilocks_basecode_2() { - // Both challenge and poly are over extension field - run_batch_commit_open_verify::(false, 10, 11); + fn simple_batch_commit_open_verify_goldilocks() { + for gen_rand_poly in [gen_rand_poly_base, gen_rand_poly_ext] { + // Both challenge and poly are over base field + run_simple_batch_commit_open_verify::( + gen_rand_poly, + 10, + 11, + 1, + ); + run_simple_batch_commit_open_verify::( + gen_rand_poly, + 10, + 11, + 4, + ); + // Test trivial proof with small num vars + run_simple_batch_commit_open_verify::( + gen_rand_poly, + 4, + 6, + 4, + ); + // Both challenge and poly are over base field + run_simple_batch_commit_open_verify::( + gen_rand_poly, + 10, + 11, + 1, + ); + run_simple_batch_commit_open_verify::( + gen_rand_poly, + 10, + 11, + 4, + ); + // Test trivial proof with small num vars + run_simple_batch_commit_open_verify::( + gen_rand_poly, + 4, + 6, + 4, + ); + } } #[test] - fn batch_commit_open_verify_goldilocks_rscode_2() { - // Both challenge and poly are over extension field - run_batch_commit_open_verify::(false, 10, 11); + fn batch_commit_open_verify() { + for gen_rand_poly in [gen_rand_poly_base, gen_rand_poly_ext] { + // Both challenge and poly are over base field + run_batch_commit_open_verify::( + gen_rand_poly, + 10, + 11, + ); + run_batch_commit_open_verify::( + gen_rand_poly, + 10, + 11, + ); + } } } diff --git a/mpcs/src/lib.rs b/mpcs/src/lib.rs index 19b3d16b6..b6716c19e 100644 --- a/mpcs/src/lib.rs +++ b/mpcs/src/lib.rs @@ -1,3 +1,4 @@ +#![deny(clippy::cargo)] use 
ff_ext::ExtensionField; use itertools::Itertools; use multilinear_extensions::mle::DenseMultilinearExtension; @@ -357,49 +358,109 @@ fn err_too_many_variates(function: &str, upto: usize, got: usize) -> Error { }) } -#[cfg(test)] +// TODO: Need to use some functions here in the integration benchmarks. But +// unfortunately integration benchmarks do not compile the #[cfg(test)] +// code. So remove the gate for the entire module, only gate the test +// functions. +// This is not the best way: the test utility functions should not be +// compiled in the release build. Need a better solution. +#[doc(hidden)] pub mod test_util { - - use crate::{Evaluation, PolynomialCommitmentScheme}; + #[cfg(test)] + use crate::Evaluation; + use crate::PolynomialCommitmentScheme; use ff_ext::ExtensionField; - use itertools::{Itertools, chain}; - use multilinear_extensions::mle::{DenseMultilinearExtension, MultilinearExtension}; - use rand::{prelude::*, rngs::OsRng}; - use rand_chacha::ChaCha8Rng; + use itertools::Itertools; + #[cfg(test)] + use itertools::chain; + use multilinear_extensions::mle::DenseMultilinearExtension; + #[cfg(test)] + use multilinear_extensions::{ + mle::MultilinearExtension, virtual_poly_v2::ArcMultilinearExtension, + }; + use rand::rngs::OsRng; use transcript::Transcript; + pub fn setup_pcs>( + num_vars: usize, + ) -> (Pcs::ProverParam, Pcs::VerifierParam) { + let poly_size = 1 << num_vars; + let param = Pcs::setup(poly_size).unwrap(); + Pcs::trim(param, poly_size).unwrap() + } + + pub fn gen_rand_poly_base(num_vars: usize) -> DenseMultilinearExtension { + DenseMultilinearExtension::random(num_vars, &mut OsRng) + } + + pub fn gen_rand_poly_ext(num_vars: usize) -> DenseMultilinearExtension { + DenseMultilinearExtension::from_evaluations_ext_vec( + num_vars, + (0..(1 << num_vars)) + .map(|_| E::random(&mut OsRng)) + .collect_vec(), + ) + } + + pub fn gen_rand_polys( + num_vars: impl Fn(usize) -> usize, + batch_size: usize, + gen_rand_poly: fn(usize) -> DenseMultilinearExtension, + ) -> Vec> { + (0..batch_size) + .map(|i| gen_rand_poly(num_vars(i))) + .collect_vec() + } + + pub fn get_point_from_challenge( + num_vars: usize, + transcript: &mut Transcript, + ) -> Vec { + (0..num_vars) + .map(|_| transcript.get_and_append_challenge(b"Point").elements) + .collect() + } + pub fn get_points_from_challenge( + num_vars: impl Fn(usize) -> usize, + num_points: usize, + transcript: &mut Transcript, + ) -> Vec> { + (0..num_points) + .map(|i| get_point_from_challenge(num_vars(i), transcript)) + .collect() + } + + pub fn commit_polys_individually>( + pp: &Pcs::ProverParam, + polys: &[DenseMultilinearExtension], + transcript: &mut Transcript, + ) -> Vec { + polys + .iter() + .map(|poly| Pcs::commit_and_write(pp, poly, transcript).unwrap()) + .collect_vec() + } + + #[cfg(test)] pub fn run_commit_open_verify( - base: bool, + gen_rand_poly: fn(usize) -> DenseMultilinearExtension, num_vars_start: usize, num_vars_end: usize, ) where Pcs: PolynomialCommitmentScheme, { for num_vars in num_vars_start..num_vars_end { - // Setup - let (pp, vp) = { - let poly_size = 1 << num_vars; - let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(param, poly_size).unwrap() - }; + let (pp, vp) = setup_pcs::(num_vars); + // Commit and open let (comm, eval, proof, challenge) = { let mut transcript = Transcript::new(b"BaseFold"); - let poly = if base { - DenseMultilinearExtension::random(num_vars, &mut OsRng) - } else { - DenseMultilinearExtension::from_evaluations_ext_vec( - num_vars, - (0..1 << num_vars).map(|_| 
E::random(&mut OsRng)).collect(), - ) - }; - + let poly = gen_rand_poly(num_vars); let comm = Pcs::commit_and_write(&pp, &poly, &mut transcript).unwrap(); - let point = (0..num_vars) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>(); + let point = get_point_from_challenge(num_vars, &mut transcript); let eval = poly.evaluate(point.as_slice()); transcript.append_field_element_ext(&eval); + ( Pcs::get_pure_commitment(&comm), eval, @@ -408,26 +469,22 @@ pub mod test_util { ) }; // Verify - let result = { + { let mut transcript = Transcript::new(b"BaseFold"); Pcs::write_commitment(&comm, &mut transcript).unwrap(); - let point = (0..num_vars) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>(); + let point = get_point_from_challenge(num_vars, &mut transcript); transcript.append_field_element_ext(&eval); - let result = Pcs::verify(&vp, &comm, &point, &eval, &proof, &mut transcript); + Pcs::verify(&vp, &comm, &point, &eval, &proof, &mut transcript).unwrap(); let v_challenge = transcript.read_challenge(); assert_eq!(challenge, v_challenge); - - result - }; - result.unwrap(); + } } } + #[cfg(test)] pub fn run_batch_commit_open_verify( - base: bool, + gen_rand_poly: fn(usize) -> DenseMultilinearExtension, num_vars_start: usize, num_vars_end: usize, ) where @@ -437,13 +494,8 @@ pub mod test_util { for num_vars in num_vars_start..num_vars_end { let batch_size = 2; let num_points = batch_size >> 1; - let rng = ChaCha8Rng::from_seed([0u8; 32]); - // Setup - let (pp, vp) = { - let poly_size = 1 << num_vars; - let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(param, poly_size).unwrap() - }; + let (pp, vp) = setup_pcs::(num_vars); + // Batch commit and open let evals = chain![ (0..num_points).map(|point| (point * 2, point)), // Every point matches two polys @@ -452,34 +504,15 @@ pub mod test_util { .unique() .collect_vec(); - let (comms, points, evals, proof, challenge) = { + let (comms, evals, proof, challenge) = { let mut transcript = Transcript::new(b"BaseFold"); - let polys = (0..batch_size) - .map(|i| { - if base { - DenseMultilinearExtension::random(num_vars - (i >> 1), &mut rng.clone()) - } else { - DenseMultilinearExtension::from_evaluations_ext_vec( - num_vars, - (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), - ) - } - }) - .collect_vec(); + let polys = gen_rand_polys(|i| num_vars - (i >> 1), batch_size, gen_rand_poly); - let comms = polys - .iter() - .map(|poly| Pcs::commit_and_write(&pp, poly, &mut transcript).unwrap()) - .collect_vec(); + let comms = + commit_polys_individually::(&pp, polys.as_slice(), &mut transcript); - let points = (0..num_points) - .map(|i| { - (0..num_vars - i) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>() - }) - .take(num_points) - .collect_vec(); + let points = + get_points_from_challenge(|i| num_vars - i, num_points, &mut transcript); let evals = evals .iter() @@ -499,10 +532,10 @@ pub mod test_util { let proof = Pcs::batch_open(&pp, &polys, &comms, &points, &evals, &mut transcript).unwrap(); - (comms, points, evals, proof, transcript.read_challenge()) + (comms, evals, proof, transcript.read_challenge()) }; // Batch verify - let result = { + { let mut transcript = Transcript::new(b"BaseFold"); let comms = comms .iter() @@ -513,16 +546,9 @@ pub mod test_util { }) .collect_vec(); - let old_points = points; - let points = (0..num_points) - .map(|i| { - (0..num_vars - i) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>() - 
}) - .take(num_points) - .collect_vec(); - assert_eq!(points, old_points); + let points = + get_points_from_challenge(|i| num_vars - i, num_points, &mut transcript); + let values: Vec = evals .iter() .map(Evaluation::value) @@ -530,19 +556,16 @@ pub mod test_util { .collect::>(); transcript.append_field_element_exts(values.as_slice()); - let result = - Pcs::batch_verify(&vp, &comms, &points, &evals, &proof, &mut transcript); + Pcs::batch_verify(&vp, &comms, &points, &evals, &proof, &mut transcript).unwrap(); let v_challenge = transcript.read_challenge(); assert_eq!(challenge, v_challenge); - result - }; - - result.unwrap(); + } } } + #[cfg(test)] pub(super) fn run_simple_batch_commit_open_verify( - base: bool, + gen_rand_poly: fn(usize) -> DenseMultilinearExtension, num_vars_start: usize, num_vars_end: usize, batch_size: usize, @@ -551,52 +574,24 @@ pub mod test_util { Pcs: PolynomialCommitmentScheme, { for num_vars in num_vars_start..num_vars_end { - let rng = ChaCha8Rng::from_seed([0u8; 32]); - // Setup - let (pp, vp) = { - let poly_size = 1 << num_vars; - let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(param, poly_size).unwrap() - }; + let (pp, vp) = setup_pcs::(num_vars); let (comm, evals, proof, challenge) = { let mut transcript = Transcript::new(b"BaseFold"); - let polys = (0..batch_size) - .map(|_| { - if base { - DenseMultilinearExtension::random(num_vars, &mut rng.clone()) - } else { - DenseMultilinearExtension::from_evaluations_ext_vec( - num_vars, - (0..1 << num_vars).map(|_| E::random(&mut OsRng)).collect(), - ) - } - }) - .collect_vec(); - let comm = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap(); - - let point = (0..num_vars) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>(); + let polys = gen_rand_polys(|_| num_vars, batch_size, gen_rand_poly); + let comm = + Pcs::batch_commit_and_write(&pp, polys.as_slice(), &mut transcript).unwrap(); + let point = get_point_from_challenge(num_vars, &mut transcript); + let evals = polys.iter().map(|poly| poly.evaluate(&point)).collect_vec(); + transcript.append_field_element_exts(&evals); - let evals = (0..batch_size) - .map(|i| polys[i].evaluate(&point)) + let polys = polys + .iter() + .map(|poly| ArcMultilinearExtension::from(poly.clone())) .collect_vec(); - - transcript.append_field_element_exts(&evals); - let proof = Pcs::simple_batch_open( - &pp, - polys - .into_iter() - .map(|x| x.into()) - .collect::>() - .as_slice(), - &comm, - &point, - &evals, - &mut transcript, - ) - .unwrap(); + let proof = + Pcs::simple_batch_open(&pp, &polys, &comm, &point, &evals, &mut transcript) + .unwrap(); ( Pcs::get_pure_commitment(&comm), evals, @@ -605,25 +600,19 @@ pub mod test_util { ) }; // Batch verify - let result = { + { let mut transcript = Transcript::new(b"BaseFold"); Pcs::write_commitment(&comm, &mut transcript).unwrap(); - let point = (0..num_vars) - .map(|_| transcript.get_and_append_challenge(b"Point").elements) - .collect::>(); - + let point = get_point_from_challenge(num_vars, &mut transcript); transcript.append_field_element_exts(&evals); - let result = - Pcs::simple_batch_verify(&vp, &comm, &point, &evals, &proof, &mut transcript); + Pcs::simple_batch_verify(&vp, &comm, &point, &evals, &proof, &mut transcript) + .unwrap(); let v_challenge = transcript.read_challenge(); assert_eq!(challenge, v_challenge); - result - }; - - result.unwrap(); + } } } } diff --git a/mpcs/src/util.rs b/mpcs/src/util.rs index 80ecf2535..7688b53ec 100644 --- a/mpcs/src/util.rs +++ b/mpcs/src/util.rs 
@@ -6,7 +6,7 @@ pub mod plonky2_util; use ff::{Field, PrimeField}; use ff_ext::ExtensionField; use goldilocks::SmallField; -use itertools::{Itertools, izip}; +use itertools::{Either, Itertools, izip}; use multilinear_extensions::mle::{DenseMultilinearExtension, FieldType}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; pub mod merkle_tree; @@ -148,38 +148,45 @@ pub fn field_type_index_set_ext( } } -pub struct FieldTypeIterExt<'a, E: ExtensionField> { - inner: &'a FieldType, - index: usize, +pub fn poly_iter_ext( + poly: &DenseMultilinearExtension, +) -> impl Iterator + '_ { + field_type_iter_ext(&poly.evaluations) } -impl<'a, E: ExtensionField> Iterator for FieldTypeIterExt<'a, E> { - type Item = E; +pub fn field_type_iter_ext( + evaluations: &FieldType, +) -> impl Iterator + '_ { + match evaluations { + FieldType::Ext(coeffs) => Either::Left(coeffs.iter().copied()), + FieldType::Base(coeffs) => Either::Right(coeffs.iter().map(|x| (*x).into())), + _ => unreachable!(), + } +} - fn next(&mut self) -> Option { - if self.index >= self.inner.len() { - None - } else { - let res = field_type_index_ext(self.inner, self.index); - self.index += 1; - Some(res) - } +pub fn field_type_to_ext_vec(evaluations: &FieldType) -> Vec { + match evaluations { + FieldType::Ext(coeffs) => coeffs.to_vec(), + FieldType::Base(coeffs) => coeffs.iter().map(|&x| x.into()).collect(), + _ => unreachable!(), } } -pub fn poly_iter_ext( - poly: &DenseMultilinearExtension, -) -> FieldTypeIterExt { - FieldTypeIterExt { - inner: &poly.evaluations, - index: 0, +pub fn field_type_as_ext(values: &FieldType) -> &Vec { + match values { + FieldType::Ext(coeffs) => coeffs, + FieldType::Base(_) => panic!("Expected ext field"), + _ => unreachable!(), } } -pub fn field_type_iter_ext(evaluations: &FieldType) -> FieldTypeIterExt { - FieldTypeIterExt { - inner: evaluations, - index: 0, +pub fn field_type_iter_base( + values: &FieldType, +) -> impl Iterator + '_ { + match values { + FieldType::Ext(coeffs) => Either::Left(coeffs.iter().flat_map(|x| x.as_bases())), + FieldType::Base(coeffs) => Either::Right(coeffs.iter()), + _ => unreachable!(), } } diff --git a/multilinear_extensions/Cargo.toml b/multilinear_extensions/Cargo.toml index ad364def1..1a8777641 100644 --- a/multilinear_extensions/Cargo.toml +++ b/multilinear_extensions/Cargo.toml @@ -1,7 +1,12 @@ [package] +categories.workspace = true +description = "Multilinear extensions for the Ceno project" edition.workspace = true +keywords.workspace = true license.workspace = true name = "multilinear_extensions" +readme.workspace = true +repository.workspace = true version.workspace = true [dependencies] diff --git a/multilinear_extensions/src/lib.rs b/multilinear_extensions/src/lib.rs index 7f0a0c089..9f669e348 100644 --- a/multilinear_extensions/src/lib.rs +++ b/multilinear_extensions/src/lib.rs @@ -1,3 +1,4 @@ +#![deny(clippy::cargo)] pub mod mle; pub mod util; pub mod virtual_poly; diff --git a/poseidon/Cargo.toml b/poseidon/Cargo.toml index 489f4efc1..eff0f50b7 100644 --- a/poseidon/Cargo.toml +++ b/poseidon/Cargo.toml @@ -1,7 +1,12 @@ [package] +categories.workspace = true +description = "Poseidon hash function" edition.workspace = true +keywords.workspace = true license.workspace = true name = "poseidon" +readme.workspace = true +repository.workspace = true version.workspace = true [dependencies] diff --git a/poseidon/src/lib.rs b/poseidon/src/lib.rs index 31ed313ad..17db28f72 100644 --- a/poseidon/src/lib.rs +++ b/poseidon/src/lib.rs @@ -1,3 +1,4 @@ 
+#![deny(clippy::cargo)] extern crate core; pub(crate) mod constants; diff --git a/sumcheck/Cargo.toml b/sumcheck/Cargo.toml index 540495c0a..092187cba 100644 --- a/sumcheck/Cargo.toml +++ b/sumcheck/Cargo.toml @@ -1,7 +1,12 @@ [package] +categories.workspace = true +description = "Sumcheck protocol implementation" edition.workspace = true +keywords.workspace = true license.workspace = true name = "sumcheck" +readme.workspace = true +repository.workspace = true version.workspace = true [dependencies] diff --git a/sumcheck/src/lib.rs b/sumcheck/src/lib.rs index 14ed79aed..0d0b95adf 100644 --- a/sumcheck/src/lib.rs +++ b/sumcheck/src/lib.rs @@ -1,3 +1,4 @@ +#![deny(clippy::cargo)] #[cfg(feature = "non_pow2_rayon_thread")] pub mod local_thread_pool; mod macros; diff --git a/transcript/Cargo.toml b/transcript/Cargo.toml index 4769afe00..f784b689f 100644 --- a/transcript/Cargo.toml +++ b/transcript/Cargo.toml @@ -1,7 +1,12 @@ [package] +categories.workspace = true +description = "Transcript generation for Ceno" edition.workspace = true +keywords.workspace = true license.workspace = true name = "transcript" +readme.workspace = true +repository.workspace = true version.workspace = true [dependencies] diff --git a/transcript/src/lib.rs b/transcript/src/lib.rs index b291fc58b..376bdb1ca 100644 --- a/transcript/src/lib.rs +++ b/transcript/src/lib.rs @@ -1,3 +1,4 @@ +#![deny(clippy::cargo)] //! This repo is not properly implemented //! Transcript APIs are placeholders; the actual logic is to be implemented later. #![feature(generic_arg_infer)]
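The benchmark rewrite above folds the separate RS/basecode benchmark files into one generic driver: the old `is_base: bool` flag becomes a `Switch` holding a `gen_rand_poly` function pointer (`gen_rand_poly_base` / `gen_rand_poly_ext`), and the per-benchmark setup boilerplate moves into the new `mpcs::test_util` helpers (`setup_pcs`, `gen_rand_polys`, `get_point_from_challenge`, `get_points_from_challenge`, `commit_polys_individually`). The sketch below shows how those helpers compose into a single commit/open/verify round, mirroring `bench_commit_open_verify_goldilocks`; the concrete generic arguments (`Basefold<E, BasefoldRSParams>`, `Transcript<E>`, the `setup_pcs::<E, Pcs>` turbofish) are assumptions, since the diff elides type parameters.

```rust
// Sketch only: one commit/open/verify round built from the new
// mpcs::test_util helpers, following the call shapes used in
// bench_commit_open_verify_goldilocks. Generic arguments are assumed.
use goldilocks::GoldilocksExt2;
use mpcs::{
    Basefold, BasefoldRSParams, PolynomialCommitmentScheme,
    test_util::{gen_rand_poly_base, get_point_from_challenge, setup_pcs},
};
use multilinear_extensions::mle::MultilinearExtension;
use transcript::Transcript;

type E = GoldilocksExt2;
type Pcs = Basefold<E, BasefoldRSParams>; // assumed type arguments
type T = Transcript<E>; // assumed type argument

fn one_round(num_vars: usize) {
    // Shared setup/trim helper instead of the old inlined block.
    let (pp, vp) = setup_pcs::<E, Pcs>(num_vars);

    // Prover: commit, derive the evaluation point from the transcript, open.
    let mut transcript = T::new(b"BaseFold");
    let poly = gen_rand_poly_base::<E>(num_vars);
    let comm = Pcs::commit_and_write(&pp, &poly, &mut transcript).unwrap();
    let point = get_point_from_challenge(num_vars, &mut transcript);
    let eval = poly.evaluate(point.as_slice());
    transcript.append_field_element_ext(&eval);
    let proof = Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap();

    // Verifier: replay the same transcript steps and check the opening.
    let comm = Pcs::get_pure_commitment(&comm);
    let mut transcript = T::new(b"BaseFold");
    Pcs::write_commitment(&comm, &mut transcript).unwrap();
    let point = get_point_from_challenge(num_vars, &mut transcript);
    transcript.append_field_element_ext(&eval);
    Pcs::verify(&vp, &comm, &point, &eval, &proof, &mut transcript).unwrap();
}
```

Keeping the base/ext choice as a plain function-pointer value (the `Switch` consts `BASE` and `EXT`) lets the same driver serve both field settings and both codes (RS and basecode) without duplicating the body, which is what the rename to `benches/basefold.rs` and the deletion of `commit_open_verify_basecode.rs` rely on.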
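Similarly, the `mpcs/src/util.rs` change drops the hand-rolled `FieldTypeIterExt` iterator in favour of functions that return `itertools::Either`. `Either` implements `Iterator` whenever both of its variants do, so the base-field and extension-field match arms can produce different concrete iterator types behind a single `impl Iterator` return. A minimal, self-contained illustration of that pattern (the enum and names here are illustrative only, not the crate's actual types):

```rust
// Demonstrates the Either-based return used by field_type_iter_ext:
// two match arms with different iterator types, unified by itertools::Either.
use itertools::Either;

enum Values {
    Small(Vec<u8>),
    Large(Vec<u64>),
}

fn iter_as_u64(values: &Values) -> impl Iterator<Item = u64> + '_ {
    match values {
        // `copied()` and `map(...)` are different types; Either bridges them.
        Values::Large(xs) => Either::Left(xs.iter().copied()),
        Values::Small(xs) => Either::Right(xs.iter().map(|&x| u64::from(x))),
    }
}

fn main() {
    let v = Values::Small(vec![1, 2, 3]);
    assert_eq!(iter_as_u64(&v).sum::<u64>(), 6);
}
```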