diff --git a/Cargo.lock b/Cargo.lock index 1fdc11b9..92af63fb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -760,6 +760,24 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rand_pcg" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59cad018caf63deb318e5a4586d99a24424a364f40f1e5778c29aca23f4fc73e" +dependencies = [ + "rand_core", +] + +[[package]] +name = "rand_seeder" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf2890aaef0aa82719a50e808de264f9484b74b442e1a3a0e5ee38243ac40bdb" +dependencies = [ + "rand_core", +] + [[package]] name = "rayon" version = "1.9.0" @@ -1047,6 +1065,8 @@ dependencies = [ "p3-uni-stark", "p3-util", "rand", + "rand_pcg", + "rand_seeder", "tracing", "valida-alu-u32", "valida-assembler", diff --git a/alu_u32/src/com/mod.rs b/alu_u32/src/com/mod.rs index 5084062a..e05e98da 100644 --- a/alu_u32/src/com/mod.rs +++ b/alu_u32/src/com/mod.rs @@ -16,8 +16,7 @@ use valida_opcodes::{EQ32, NE32}; use p3_air::VirtualPairCol; use p3_field::{AbstractField, Field, PrimeField}; use p3_matrix::dense::RowMajorMatrix; -// use p3_maybe_rayon::*; -use p3_maybe_rayon::prelude::IntoParallelRefIterator; +use p3_maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; use valida_util::pad_to_power_of_two; pub mod columns; diff --git a/basic/Cargo.toml b/basic/Cargo.toml index 933c0cec..9b899126 100644 --- a/basic/Cargo.toml +++ b/basic/Cargo.toml @@ -8,9 +8,17 @@ license = "MIT OR Apache-2.0" name = "valida" path = "src/bin/valida.rs" +[[bin]] +name = "test_prover" +path = "src/bin/test_prover.rs" + [dependencies] byteorder = "1.4.3" +ciborium = "0.2.2" clap = { version = "4.3.19", features = ["derive"] } +rand = "0.8.5" +rand_pcg = "0.3.1" +rand_seeder = "0.2.3" tracing = "0.1.37" valida-alu-u32 = { path = "../alu_u32" } valida-assembler = { path = "../assembler" } @@ -32,10 +40,17 @@ p3-uni-stark = { workspace = true } p3-commit = { workspace = true } p3-air = { workspace = true } p3-matrix = { workspace = true } +p3-challenger = { workspace = true } +p3-dft = { workspace = true } +p3-fri = { workspace = true } +p3-keccak = { workspace = true } +p3-mds = { workspace = true } +p3-merkle-tree = { workspace = true } +p3-poseidon = { workspace = true } +p3-symmetric = { workspace = true } [dev-dependencies] ciborium = "0.2.2" -rand = "0.8.5" p3-challenger = { workspace = true } p3-dft = { workspace = true } p3-field = { workspace = true } diff --git a/basic/src/bin/test_prover.rs b/basic/src/bin/test_prover.rs new file mode 100644 index 00000000..f754bce7 --- /dev/null +++ b/basic/src/bin/test_prover.rs @@ -0,0 +1,265 @@ +extern crate core; + +use p3_baby_bear::BabyBear; +use p3_fri::{TwoAdicFriPcs, TwoAdicFriPcsConfig}; +use valida_alu_u32::add::{Add32Instruction, MachineWithAdd32Chip}; +use valida_basic::BasicMachine; +use valida_cpu::{ + BeqInstruction, BneInstruction, Imm32Instruction, JalInstruction, JalvInstruction, + MachineWithCpuChip, StopInstruction, +}; +use valida_machine::{ + FixedAdviceProvider, Instruction, InstructionWord, Machine, MachineProof, Operands, ProgramROM, + Word, +}; + +use valida_memory::MachineWithMemoryChip; +use valida_opcodes::BYTES_PER_INSTR; +use valida_program::MachineWithProgramChip; + +use p3_challenger::DuplexChallenger; +use p3_dft::Radix2Bowers; +use p3_field::extension::BinomialExtensionField; +use p3_field::Field; +use p3_fri::FriConfig; +use p3_keccak::Keccak256Hash; +use p3_mds::coset_mds::CosetMds; +use p3_merkle_tree::FieldMerkleTreeMmcs; +use p3_poseidon::Poseidon; +use p3_symmetric::{CompressionFunctionFromHasher, SerializingHasher32}; +use rand::thread_rng; +use valida_machine::StarkConfigImpl; +use valida_machine::__internal::p3_commit::ExtensionMmcs; + +fn main() { + prove_fibonacci() +} + +fn prove_fibonacci() { + let mut program = vec![]; + + // Label locations + let bytes_per_instr = BYTES_PER_INSTR as i32; + let fib_bb0 = 8 * bytes_per_instr; + let fib_bb0_1 = 13 * bytes_per_instr; + let fib_bb0_2 = 15 * bytes_per_instr; + let fib_bb0_3 = 19 * bytes_per_instr; + let fib_bb0_4 = 21 * bytes_per_instr; + + //main: ; @main + //; %bb.0: + // imm32 -4(fp), 0, 0, 0, 0 + // imm32 -8(fp), 0, 0, 0, 100 + // addi -16(fp), -8(fp), 0 + // imm32 -20(fp), 0, 0, 0, 28 + // jal -28(fp), fib, -28 + // addi -12(fp), -24(fp), 0 + // addi 4(fp), -12(fp), 0 + // exit + //... + program.extend([ + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([-4, 0, 0, 0, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([-8, 0, 0, 0, 100]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([-16, -8, 0, 0, 1]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([-20, 0, 0, 0, 28]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([-28, fib_bb0, -28, 0, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([-12, -24, 0, 0, 1]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([4, -12, 0, 0, 1]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands::default(), + }, + ]); + + //fib: ; @fib + //; %bb.0: + // addi -4(fp), 12(fp), 0 + // imm32 -8(fp), 0, 0, 0, 0 + // imm32 -12(fp), 0, 0, 0, 1 + // imm32 -16(fp), 0, 0, 0, 0 + // beq .LBB0_1, 0(fp), 0(fp) + program.extend([ + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([-4, 12, 0, 0, 1]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([-8, 0, 0, 0, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([-12, 0, 0, 0, 1]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([-16, 0, 0, 0, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([fib_bb0_1, 0, 0, 0, 0]), + }, + ]); + + //.LBB0_1: + // bne .LBB0_2, -16(fp), -4(fp) + // beq .LBB0_4, 0(fp), 0(fp) + program.extend([ + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([fib_bb0_2, -16, -4, 0, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([fib_bb0_4, 0, 0, 0, 0]), + }, + ]); + + //; %bb.2: + // add -20(fp), -8(fp), -12(fp) + // addi -8(fp), -12(fp), 0 + // addi -12(fp), -20(fp), 0 + // beq .LBB0_3, 0(fp), 0(fp) + program.extend([ + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([-20, -8, -12, 0, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([-8, -12, 0, 0, 1]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([-12, -20, 0, 0, 1]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([fib_bb0_3, 0, 0, 0, 0]), + }, + ]); + + //; %bb.3: + // addi -16(fp), -16(fp), 1 + // beq .LBB0_1, 0(fp), 0(fp) + program.extend([ + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([-16, -16, 1, 0, 1]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([fib_bb0_1, 0, 0, 0, 0]), + }, + ]); + + //.LBB0_4: + // addi 4(fp), -8(fp), 0 + // jalv -4(fp), 0(fp), 8(fp) + program.extend([ + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([4, -8, 0, 0, 1]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([-4, 0, 8, 0, 0]), + }, + ]); + + let mut machine = BasicMachine::::default(); + let rom = ProgramROM::new(program); + machine.program_mut().set_program_rom(&rom); + machine.cpu_mut().fp = 0x1000; + machine.cpu_mut().save_register_state(); // TODO: Initial register state should be saved + // automatically by the machine, not manually here + machine.run(&rom, &mut FixedAdviceProvider::empty()); + + type Val = BabyBear; + type Challenge = BinomialExtensionField; + type PackedChallenge = BinomialExtensionField<::Packing, 5>; + + type Mds16 = CosetMds; + let mds16 = Mds16::default(); + + type Perm16 = Poseidon; + let perm16 = Perm16::new_from_rng(4, 22, mds16, &mut thread_rng()); // TODO: Use deterministic RNG + + type MyHash = SerializingHasher32; + let hash = MyHash::new(Keccak256Hash {}); + + type MyCompress = CompressionFunctionFromHasher; + let compress = MyCompress::new(hash); + + type ValMmcs = FieldMerkleTreeMmcs; + let val_mmcs = ValMmcs::new(hash, compress); + + type ChallengeMmcs = ExtensionMmcs; + let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone()); + + type Dft = Radix2Bowers; + let dft = Dft::default(); + + type Challenger = DuplexChallenger; + + type MyFriConfig = TwoAdicFriPcsConfig; + let fri_config = FriConfig { + log_blowup: 1, + num_queries: 40, + proof_of_work_bits: 8, + mmcs: challenge_mmcs, + }; + + type Pcs = TwoAdicFriPcs; + type MyConfig = StarkConfigImpl; + + let pcs = Pcs::new(fri_config, dft, val_mmcs); + + let challenger = Challenger::new(perm16); + let config = MyConfig::new(pcs, challenger); + let proof = machine.prove(&config); + + let mut bytes = vec![]; + ciborium::into_writer(&proof, &mut bytes).expect("serialization failed"); + println!("Proof size: {} bytes", bytes.len()); + let deserialized_proof: MachineProof = + ciborium::from_reader(bytes.as_slice()).expect("deserialization failed"); + + machine + .verify(&config, &proof) + .expect("verification failed"); + machine + .verify(&config, &deserialized_proof) + .expect("verification failed"); + + // assert_eq!(machine.cpu().clock, 192); + // assert_eq!(machine.cpu().operations.len(), 192); + // assert_eq!(machine.mem().operations.values().flatten().count(), 401); + // assert_eq!(machine.add_u32().operations.len(), 105); + + // assert_eq!( + // *machine.mem().cells.get(&(0x1000 + 4)).unwrap(), // Return value + // Word([0, 1, 37, 17,]) // 25th fibonacci number (75025) + // ); +} diff --git a/basic/src/bin/valida.rs b/basic/src/bin/valida.rs index 6220a318..ebe12173 100644 --- a/basic/src/bin/valida.rs +++ b/basic/src/bin/valida.rs @@ -1,20 +1,46 @@ use clap::Parser; +use std::fs::File; use std::io::{stdout, Write}; use valida_basic::BasicMachine; + +use p3_baby_bear::BabyBear; + +use p3_fri::{FriConfig, TwoAdicFriPcs, TwoAdicFriPcsConfig}; use valida_cpu::MachineWithCpuChip; -use valida_machine::{Machine, ProgramROM, StdinAdviceProvider}; -use valida_output::MachineWithOutputChip; +use valida_machine::{Machine, MachineProof, ProgramROM, StdinAdviceProvider}; + use valida_program::MachineWithProgramChip; -use p3_baby_bear::BabyBear; +use p3_challenger::DuplexChallenger; +use p3_dft::Radix2DitParallel; +use p3_field::extension::BinomialExtensionField; +use p3_field::Field; +use p3_keccak::Keccak256Hash; +use p3_mds::coset_mds::CosetMds; +use p3_merkle_tree::FieldMerkleTreeMmcs; +use p3_poseidon::Poseidon; +use p3_symmetric::{CompressionFunctionFromHasher, SerializingHasher32}; +use rand_pcg::Pcg64; +use rand_seeder::Seeder; +use valida_machine::StarkConfigImpl; +use valida_machine::__internal::p3_commit::ExtensionMmcs; +use valida_output::MachineWithOutputChip; #[derive(Parser)] struct Args { + /// Command option either "run" or "prove" or "verify" + #[arg(name = "Action Option")] + action: String, + /// Program binary file - #[arg(name = "FILE")] + #[arg(name = "PROGRAM FILE")] program: String, + /// The output file for run or prove, or the input file for verify + #[arg(name = "ACTION FILE")] + action_file: String, + /// Stack height (which is also the initial frame pointer value) #[arg(long, default_value = "16777216")] stack_height: u32, @@ -35,6 +61,97 @@ fn main() { // Run the program machine.run(&rom, &mut StdinAdviceProvider); - // Write output chip values to standard output - stdout().write_all(&machine.output().bytes()).unwrap(); + type Val = BabyBear; + type Challenge = BinomialExtensionField; + type PackedChallenge = BinomialExtensionField<::Packing, 5>; + + type Mds16 = CosetMds; + let mds16 = Mds16::default(); + + type Perm16 = Poseidon; + let mut rng: Pcg64 = Seeder::from("validia seed").make_rng(); + let perm16 = Perm16::new_from_rng(4, 22, mds16, &mut rng); + + type MyHash = SerializingHasher32; + let hash = MyHash::new(Keccak256Hash {}); + + type MyCompress = CompressionFunctionFromHasher; + let compress = MyCompress::new(hash); + + type ValMmcs = FieldMerkleTreeMmcs; + let val_mmcs = ValMmcs::new(hash, compress); + + type ChallengeMmcs = ExtensionMmcs; + let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone()); + + type Dft = Radix2DitParallel; + let dft = Dft::default(); + + type Challenger = DuplexChallenger; + + type MyFriConfig = TwoAdicFriPcsConfig; + let fri_config = FriConfig { + log_blowup: 1, + num_queries: 40, + proof_of_work_bits: 8, + mmcs: challenge_mmcs, + }; + + type Pcs = TwoAdicFriPcs; + type MyConfig = StarkConfigImpl; + + let pcs = Pcs::new(fri_config, dft, val_mmcs); + + let challenger = Challenger::new(perm16); + let config = MyConfig::new(pcs, challenger); + + if args.action == "run" { + let mut action_file; + match File::create(args.action_file) { + Ok(file) => { + action_file = file; + } + Err(e) => { + stdout().write(e.to_string().as_bytes()).unwrap(); + return (); + } + } + action_file.write_all(&machine.output().bytes()).unwrap(); + } else if args.action == "prove" { + let mut action_file; + match File::create(args.action_file) { + Ok(file) => { + action_file = file; + } + Err(e) => { + stdout().write(e.to_string().as_bytes()).unwrap(); + return (); + } + } + let proof = machine.prove(&config); + debug_assert!(machine.verify(&config, &proof).is_ok()); + let mut bytes = vec![]; + ciborium::into_writer(&proof, &mut bytes).expect("Proof serialization failed"); + action_file.write(&bytes).expect("Writing proof failed"); + stdout().write("Proof successful\n".as_bytes()).unwrap(); + } else if args.action == "verify" { + let bytes = std::fs::read(args.action_file).expect("File reading failed"); + let proof: MachineProof = + ciborium::from_reader(bytes.as_slice()).expect("Proof deserialization failed"); + let verification_result = machine.verify(&config, &proof); + match verification_result { + Ok(_) => { + stdout().write("Proof verified\n".as_bytes()).unwrap(); + } + Err(_) => { + stdout() + .write("Proof verification failed\n".as_bytes()) + .unwrap(); + } + } + } else { + stdout() + .write("Action name unrecognized".as_bytes()) + .unwrap(); + } } diff --git a/cpu/src/lib.rs b/cpu/src/lib.rs index b276e280..cd84403c 100644 --- a/cpu/src/lib.rs +++ b/cpu/src/lib.rs @@ -330,10 +330,12 @@ impl CpuChip { }); } - fn set_imm_value(&self, cols: &mut CpuCols, imm: Option>) { + fn set_imm_value(&self, cols: &mut CpuCols, imm: Option>) { if let Some(imm) = imm { cols.opcode_flags.is_imm_op = F::one(); - cols.mem_channels[1].value = imm.transform(F::from_canonical_u8); + let imm = imm.transform(F::from_canonical_u8); + cols.mem_channels[1].value = imm; + cols.instruction.operands.0[2] = imm.reduce(); } } } diff --git a/derive/src/lib.rs b/derive/src/lib.rs index f9793981..b59bdc53 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -350,7 +350,6 @@ fn prove_method(chips: &[&Field]) -> TokenStream2 { let (openings, opening_proof) = pcs.open_multi_batches( &prover_data_and_points, &mut challenger); - // TODO: add preprocessed openings let [main_openings, perm_openings, quotient_openings] = openings.try_into().expect("Should have 3 rounds of openings");