diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9e767fff..2da7ec8f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,9 +18,8 @@ jobs: - name: Checkout Plonky3 uses: actions/checkout@v4 with: - repository: lita-xyz/Plonky3 + repository: Plonky3/Plonky3 path: Plonky3 - ref: batch_prover - name: Checkout Valida uses: actions/checkout@v4 @@ -49,9 +48,8 @@ jobs: - name: Checkout Plonky3 uses: actions/checkout@v4 with: - repository: lita-xyz/Plonky3 + repository: Plonky3/Plonky3 path: Plonky3 - ref: batch_prover - name: Checkout Valida uses: actions/checkout@v4 diff --git a/alu_u32/src/add/mod.rs b/alu_u32/src/add/mod.rs index 8a382fb0..6286f3cc 100644 --- a/alu_u32/src/add/mod.rs +++ b/alu_u32/src/add/mod.rs @@ -13,8 +13,8 @@ use valida_range::MachineWithRangeChip; use p3_air::VirtualPairCol; use p3_field::{AbstractField, Field, PrimeField}; use p3_matrix::dense::RowMajorMatrix; -use p3_maybe_rayon::*; -use p3_uni_stark::StarkConfig; +use p3_maybe_rayon::prelude::*; +use valida_machine::config::StarkConfig; use valida_util::pad_to_power_of_two; pub mod columns; diff --git a/alu_u32/src/bitwise/mod.rs b/alu_u32/src/bitwise/mod.rs index ffa4bc4c..fe590c41 100644 --- a/alu_u32/src/bitwise/mod.rs +++ b/alu_u32/src/bitwise/mod.rs @@ -12,8 +12,8 @@ use valida_opcodes::{AND32, OR32, XOR32}; use p3_air::VirtualPairCol; use p3_field::{AbstractField, Field, PrimeField}; use p3_matrix::dense::RowMajorMatrix; -use p3_maybe_rayon::*; -use p3_uni_stark::StarkConfig; +use p3_maybe_rayon::prelude::*; +use valida_machine::config::StarkConfig; use valida_util::pad_to_power_of_two; pub mod columns; diff --git a/alu_u32/src/div/mod.rs b/alu_u32/src/div/mod.rs index 2f65c3ee..f8b623b5 100644 --- a/alu_u32/src/div/mod.rs +++ b/alu_u32/src/div/mod.rs @@ -6,18 +6,18 @@ use columns::{Div32Cols, DIV_COL_MAP, NUM_DIV_COLS}; use core::mem::transmute; use valida_bus::MachineWithGeneralBus; use valida_cpu::MachineWithCpuChip; +use valida_machine::config::StarkConfig; use valida_machine::core::SDiv; use valida_machine::{instructions, Chip, Instruction, Interaction, Operands, Word}; use valida_opcodes::{DIV32, SDIV32}; use valida_range::MachineWithRangeChip; +use valida_util::pad_to_power_of_two; use p3_air::VirtualPairCol; use p3_field::{AbstractField, Field, PrimeField}; use p3_matrix::dense::RowMajorMatrix; -use p3_maybe_rayon::*; +use p3_maybe_rayon::prelude::*; -use p3_uni_stark::StarkConfig; -use valida_util::pad_to_power_of_two; pub mod columns; pub mod stark; diff --git a/alu_u32/src/lt/mod.rs b/alu_u32/src/lt/mod.rs index 1a53c5bd..c501dfa6 100644 --- a/alu_u32/src/lt/mod.rs +++ b/alu_u32/src/lt/mod.rs @@ -15,9 +15,9 @@ use valida_opcodes::LT32; use p3_air::VirtualPairCol; use p3_field::{AbstractField, Field, PrimeField}; use p3_matrix::dense::RowMajorMatrix; -use p3_maybe_rayon::*; +use p3_maybe_rayon::prelude::*; -use p3_uni_stark::StarkConfig; +use valida_machine::config::StarkConfig; use valida_util::pad_to_power_of_two; pub mod columns; pub mod stark; diff --git a/alu_u32/src/mul/mod.rs b/alu_u32/src/mul/mod.rs index c43155ee..45d4c800 100644 --- a/alu_u32/src/mul/mod.rs +++ b/alu_u32/src/mul/mod.rs @@ -13,7 +13,7 @@ use core::borrow::BorrowMut; use p3_air::VirtualPairCol; use p3_field::{AbstractField, Field, PrimeField}; use p3_matrix::dense::RowMajorMatrix; -use p3_uni_stark::StarkConfig; +use valida_machine::config::StarkConfig; pub mod columns; pub mod stark; diff --git a/alu_u32/src/shift/mod.rs b/alu_u32/src/shift/mod.rs index 6f5b2df3..332fbd24 100644 --- a/alu_u32/src/shift/mod.rs +++ b/alu_u32/src/shift/mod.rs @@ -14,8 +14,8 @@ use valida_opcodes::{DIV32, MUL32, SDIV32, SHL32, SHR32, SRA32}; use p3_air::VirtualPairCol; use p3_field::{AbstractField, Field, PrimeField}; use p3_matrix::dense::RowMajorMatrix; -use p3_maybe_rayon::*; -use p3_uni_stark::StarkConfig; +use p3_maybe_rayon::prelude::*; +use valida_machine::config::StarkConfig; use valida_util::pad_to_power_of_two; pub mod columns; diff --git a/alu_u32/src/sub/mod.rs b/alu_u32/src/sub/mod.rs index d935acc2..b79bf520 100644 --- a/alu_u32/src/sub/mod.rs +++ b/alu_u32/src/sub/mod.rs @@ -13,9 +13,8 @@ use valida_range::MachineWithRangeChip; use p3_air::VirtualPairCol; use p3_field::{AbstractField, Field, PrimeField}; use p3_matrix::dense::RowMajorMatrix; -use p3_maybe_rayon::*; -use p3_uni_stark::StarkConfig; - +use p3_maybe_rayon::prelude::*; +use valida_machine::config::StarkConfig; use valida_util::pad_to_power_of_two; pub mod columns; diff --git a/basic/src/lib.rs b/basic/src/lib.rs index b6c9bb1b..29029d56 100644 --- a/basic/src/lib.rs +++ b/basic/src/lib.rs @@ -1,4 +1,4 @@ -//#![no_std] +#![no_std] #![allow(unused)] extern crate alloc; @@ -9,13 +9,8 @@ use p3_air::Air; use p3_commit::{Pcs, UnivariatePcs, UnivariatePcsWithLde}; use p3_field::PrimeField32; use p3_field::{extension::BinomialExtensionField, TwoAdicField}; -use p3_goldilocks::Goldilocks; use p3_matrix::dense::RowMajorMatrix; use p3_maybe_rayon::*; -use p3_uni_stark::{ - get_max_constraint_degree, get_trace_and_quotient_ldes, open, Commitments, Proof, - ProverConstraintFolder, ProverData, StarkConfig, SymbolicAirBuilder, -}; use p3_util::log2_ceil_usize; use valida_alu_u32::{ add::{Add32Chip, Add32Instruction, MachineWithAdd32Chip}, @@ -52,81 +47,121 @@ use valida_output::{MachineWithOutputChip, OutputChip, WriteInstruction}; use valida_program::{MachineWithProgramChip, ProgramChip}; use valida_range::{MachineWithRangeChip, RangeCheckerChip}; -#[derive(Default)] -pub struct BasicMachine { +use p3_maybe_rayon::prelude::*; +use valida_machine::config::StarkConfig; + +#[derive(Machine, Default)] +#[machine_fields(F)] +pub struct BasicMachine { // Core instructions + #[instruction] load32: Load32Instruction, + #[instruction] store32: Store32Instruction, + #[instruction] jal: JalInstruction, + #[instruction] jalv: JalvInstruction, + #[instruction] beq: BeqInstruction, + #[instruction] bne: BneInstruction, + #[instruction] imm32: Imm32Instruction, + #[instruction] stop: StopInstruction, // ALU instructions + #[instruction(add_u32)] add32: Add32Instruction, + #[instruction(sub_u32)] sub32: Sub32Instruction, + #[instruction(mul_u32)] mul32: Mul32Instruction, + #[instruction(mul_u32)] mulhs32: Mulhs32Instruction, + #[instruction(mul_u32)] mulhu32: Mulhu32Instruction, + #[instruction(div_u32)] div32: Div32Instruction, + #[instruction(div_u32)] sdiv32: SDiv32Instruction, + #[instruction(shift_u32)] shl32: Shl32Instruction, + #[instruction(shift_u32)] shr32: Shr32Instruction, + #[instruction(shift_u32)] sra32: Sra32Instruction, + #[instruction(lt_u32)] lt32: Lt32Instruction, + #[instruction(bitwise_u32)] and32: And32Instruction, + #[instruction(bitwise_u32)] or32: Or32Instruction, + #[instruction(bitwise_u32)] xor32: Xor32Instruction, // Input/output instructions + #[instruction] read: ReadAdviceInstruction, + #[instruction(output)] write: WriteInstruction, + #[chip] cpu: CpuChip, + #[chip] program: ProgramChip, + #[chip] mem: MemoryChip, + #[chip] add_u32: Add32Chip, + #[chip] sub_u32: Sub32Chip, + #[chip] mul_u32: Mul32Chip, + #[chip] div_u32: Div32Chip, + #[chip] shift_u32: Shift32Chip, + #[chip] lt_u32: Lt32Chip, + #[chip] bitwise_u32: Bitwise32Chip, + #[chip] output: OutputChip, + #[chip] range: RangeCheckerChip<256>, _phantom_sc: PhantomData F>, @@ -275,288 +310,3 @@ impl MachineWithRangeChip for BasicMachi &mut self.range } } - -impl Machine for BasicMachine { - fn run(&mut self, program: &ProgramROM, advice: &mut Adv) { - loop { - let pc = self.cpu.pc; - let instruction = program.get_instruction(pc); - let opcode = instruction.opcode; - let ops = instruction.operands; - match opcode { - >::OPCODE => { - Load32Instruction::execute_with_advice(self, ops, advice); - } - >::OPCODE => { - Store32Instruction::execute_with_advice(self, ops, advice); - } - >::OPCODE => { - JalInstruction::execute_with_advice(self, ops, advice); - } - >::OPCODE => { - JalvInstruction::execute_with_advice(self, ops, advice); - } - >::OPCODE => { - BeqInstruction::execute_with_advice(self, ops, advice); - } - >::OPCODE => { - BneInstruction::execute_with_advice(self, ops, advice); - } - >::OPCODE => { - Imm32Instruction::execute_with_advice(self, ops, advice); - } - >::OPCODE => { - Add32Instruction::execute_with_advice(self, ops, advice); - } - >::OPCODE => { - Sub32Instruction::execute_with_advice(self, ops, advice); - } - >::OPCODE => { - Mul32Instruction::execute_with_advice(self, ops, advice); - } - >::OPCODE => { - Mulhs32Instruction::execute_with_advice(self, ops, advice); - } - >::OPCODE => { - Mulhu32Instruction::execute_with_advice(self, ops, advice); - } - >::OPCODE => { - Div32Instruction::execute_with_advice(self, ops, advice); - } - >::OPCODE => { - SDiv32Instruction::execute_with_advice(self, ops, advice); - } - >::OPCODE => { - Shl32Instruction::execute_with_advice(self, ops, advice); - } - >::OPCODE => { - Shr32Instruction::execute_with_advice(self, ops, advice); - } - >::OPCODE => { - Sra32Instruction::execute_with_advice(self, ops, advice); - } - >::OPCODE => { - Lt32Instruction::execute_with_advice(self, ops, advice); - } - >::OPCODE => { - And32Instruction::execute_with_advice(self, ops, advice); - } - >::OPCODE => { - Or32Instruction::execute_with_advice(self, ops, advice); - } - >::OPCODE => { - Xor32Instruction::execute_with_advice(self, ops, advice); - } - >::OPCODE => { - ReadAdviceInstruction::execute_with_advice(self, ops, advice); - } - >::OPCODE => { - WriteInstruction::execute_with_advice(self, ops, advice); - } - >::OPCODE => { - StopInstruction::execute_with_advice(self, ops, advice); - } - _ => {} - } - - self.read_word(pc as usize); - - if opcode == >::OPCODE { - break; - } - } - let n = self.cpu.clock.next_power_of_two() - self.cpu.clock; - for _ in 0..n { - self.read_word(self.cpu.pc as usize); - } - } - fn add_chip_trace( - &self, - config: &SC, - challenger: &mut SC::Challenger, - trace_commitments: &mut Vec>, - quotient_commitments: &mut Vec>, - log_degrees: &mut Vec, - log_quotient_degrees: &mut Vec, - chip: &A, - trace: RowMajorMatrix<::Val>, - ) where - SC: StarkConfig, - A: Air> + for<'a> Air>, - { - let (trace_lde, quotient_lde, log_degree, log_quotient_degree) = - get_trace_and_quotient_ldes(config, trace, chip, challenger); - trace_commitments.push(trace_lde); - quotient_commitments.push(quotient_lde); - log_degrees.push(log_degree); - log_quotient_degrees.push(log_quotient_degree); - } - fn prove(&self, config: &SC, challenger: &mut SC::Challenger) -> MachineProof - where - SC: StarkConfig, - { - let mut trace_commitments = Vec::new(); - let mut quotient_commitments = Vec::new(); - let mut log_degrees = Vec::new(); - let mut log_quotient_degrees = Vec::new(); - /* - let air = &self.cpu(); - assert_eq!(air.operations.len() > 0, true); - let trace = air.generate_trace(air, self); - self.add_chip_trace( - config, - challenger, - &mut trace_commitments, - &mut quotient_commitments, - &mut log_degrees, - &mut log_quotient_degrees, - air, - trace, - ); - */ - if self.add_u32.operations.len() > 0 { - let air = &self.add_u32; - let trace = , SC>>::generate_trace(air, self); - - self.add_chip_trace( - config, - challenger, - &mut trace_commitments, - &mut quotient_commitments, - &mut log_degrees, - &mut log_quotient_degrees, - air, - trace, - ); - } - if self.sub_u32.operations.len() > 0 { - let air = &self.sub_u32; - let trace = , SC>>::generate_trace(air, self); - - self.add_chip_trace( - config, - challenger, - &mut trace_commitments, - &mut quotient_commitments, - &mut log_degrees, - &mut log_quotient_degrees, - air, - trace, - ); - } - if self.mul_u32.operations.len() > 0 { - let air = &self.mul_u32; - let trace = , SC>>::generate_trace(air, self); - - self.add_chip_trace( - config, - challenger, - &mut trace_commitments, - &mut quotient_commitments, - &mut log_degrees, - &mut log_quotient_degrees, - air, - trace, - ); - } - if self.div_u32.operations.len() > 0 { - let air = &self.div_u32; - - let trace = , SC>>::generate_trace(air, self); - self.add_chip_trace( - config, - challenger, - &mut trace_commitments, - &mut quotient_commitments, - &mut log_degrees, - &mut log_quotient_degrees, - air, - trace, - ); - } - if self.shift_u32.operations.len() > 0 { - let air = &self.shift_u32; - let trace = , SC>>::generate_trace(air, self); - - self.add_chip_trace( - config, - challenger, - &mut trace_commitments, - &mut quotient_commitments, - &mut log_degrees, - &mut log_quotient_degrees, - air, - trace, - ); - } - if self.lt_u32.operations.len() > 0 { - let air = &self.lt_u32; - let trace = , SC>>::generate_trace(air, self); - - self.add_chip_trace( - config, - challenger, - &mut trace_commitments, - &mut quotient_commitments, - &mut log_degrees, - &mut log_quotient_degrees, - air, - trace, - ); - } - - if self.bitwise_u32.operations.len() > 0 { - let air = &self.bitwise_u32; - let trace = , SC>>::generate_trace(air, self); - - self.add_chip_trace( - config, - challenger, - &mut trace_commitments, - &mut quotient_commitments, - &mut log_degrees, - &mut log_quotient_degrees, - air, - trace, - ); - } - - let pcs = config.pcs(); - let (aggregated_commitment, aggregated_trace) = pcs.combine(&trace_commitments); - let (aggregated_quotient_commitment, aggregated_quotient_trace) = - pcs.combine("ient_commitments); - let max_log_degree = log_degrees.iter().max().unwrap(); - let max_quotient_degree = log_quotient_degrees.iter().max().unwrap(); - let (opening_proof, opened_values) = open( - config, - &aggregated_trace, - &aggregated_quotient_trace, - *max_log_degree, - *max_quotient_degree, - challenger, - ); - - let commitments = Commitments { - trace: aggregated_commitment, - quotient_chunks: aggregated_quotient_commitment, - }; - MachineProof { - chip_proof: ChipProof { - proof: Proof { - commitments, - opened_values, - opening_proof, - degree_bits: *max_log_degree, - }, - }, - phantom: PhantomData::default(), - } - } - - fn verify(proof: &MachineProof) -> Result<(), ()> - where - SC: StarkConfig, - { - Ok(()) - } -} diff --git a/basic/tests/test_prover.rs b/basic/tests/test_prover.rs index 9b6818f1..881d5859 100644 --- a/basic/tests/test_prover.rs +++ b/basic/tests/test_prover.rs @@ -8,8 +8,6 @@ use valida_cpu::{ MachineWithCpuChip, StopInstruction, }; -use p3_uni_stark::StarkConfigImpl; - use valida_machine::{ FixedAdviceProvider, Instruction, InstructionWord, Machine, Operands, ProgramROM, Word, }; @@ -28,9 +26,10 @@ use p3_ldt::QuotientMmcs; use p3_mds::coset_mds::CosetMds; use p3_merkle_tree::FieldMerkleTreeMmcs; use p3_poseidon::Poseidon; -use p3_symmetric::{CompressionFunctionFromHasher, CryptographicPermutation, SerializingHasher32}; +use p3_symmetric::{CompressionFunctionFromHasher, SerializingHasher32}; use rand::thread_rng; use valida_machine::__internal::p3_commit::ExtensionMmcs; +use valida_machine::config::StarkConfigImpl; #[test] fn prove_fibonacci() { @@ -223,7 +222,8 @@ fn prove_fibonacci() { type Quotient = QuotientMmcs; type MyFriConfig = FriConfigImpl; - let fri_config = MyFriConfig::new(40, challenge_mmcs); + // TODO: Change log_blowup from 2 to 1 once degree >3 constraints are eliminated. + let fri_config = MyFriConfig::new(2, 40, challenge_mmcs); let ldt = FriLdt { config: fri_config }; type Pcs = FriBasedPcs; @@ -231,14 +231,10 @@ fn prove_fibonacci() { let pcs = Pcs::new(dft, val_mmcs, ldt); - let config = MyConfig::new(pcs); - - let mut challenger = Challenger::new(perm16); - let out = machine.prove(&config, &mut challenger); - assert_eq!( - out.chip_proof.proof.opened_values.trace_local.len() > 0, - true - ); + let challenger = Challenger::new(perm16); + let config = MyConfig::new(pcs, challenger); + let proof = machine.prove(&config); + BasicMachine::verify(&config, &proof).expect("verification failed"); assert_eq!(machine.cpu().clock, 192); assert_eq!(machine.cpu().operations.len(), 192); assert_eq!(machine.mem().operations.values().flatten().count(), 401); diff --git a/cpu/src/lib.rs b/cpu/src/lib.rs index f2ff0c1c..3d8d7a93 100644 --- a/cpu/src/lib.rs +++ b/cpu/src/lib.rs @@ -17,13 +17,13 @@ use valida_memory::{MachineWithMemoryChip, Operation as MemoryOperation}; use valida_opcodes::{ BEQ, BNE, BYTES_PER_INSTR, IMM32, JAL, JALV, LOAD32, READ_ADVICE, STOP, STORE32, }; -use valida_util::batch_multiplicative_inverse; use p3_air::VirtualPairCol; use p3_field::{AbstractField, Field, PrimeField}; use p3_matrix::dense::RowMajorMatrix; -use p3_maybe_rayon::*; -use p3_uni_stark::StarkConfig; +use p3_maybe_rayon::prelude::*; +use valida_machine::config::StarkConfig; +use valida_util::batch_multiplicative_inverse_allowing_zero; pub mod columns; pub mod stark; @@ -71,7 +71,8 @@ where fn generate_trace(&self, machine: &M) -> RowMajorMatrix { let mut rows = self .operations - .par_iter() + .as_slice() + .into_par_iter() .enumerate() .map(|(n, op)| self.op_to_row::(n, op, machine)) .collect::>(); @@ -270,12 +271,12 @@ impl CpuChip { .map(|i| rows[n][i]) .collect::>(); for (a, b) in word_1.into_iter().zip(word_2) { - diff[n] += (a - b) * (a - b); + diff[n] += (a - b).square(); } } // Compute `diff_inv` - let diff_inv = batch_multiplicative_inverse(diff.clone()); + let diff_inv = batch_multiplicative_inverse_allowing_zero(diff.clone()); // Set trace values for n in 0..rows.len() { diff --git a/derive/src/lib.rs b/derive/src/lib.rs index f184fef0..76b87037 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -186,22 +186,46 @@ fn prove_method(chips: &[&Field]) -> TokenStream2 { }) .collect::(); - let prove_starks = chips + let quotient_degree_calls = chips + .iter() + .map(|chip| { + let chip_name = chip.ident.as_ref().unwrap(); + quote! { + get_log_quotient_degree::(self, self.#chip_name()), + } + }) + .collect::(); + + let compute_quotients = chips .iter() .enumerate() - .map(|(n, chip)| { + .map(|(i, chip)| { let chip_name = chip.ident.as_ref().unwrap(); quote! { #[cfg(debug_assertions)] check_constraints::( self, self.#chip_name(), - &main_traces[#n], - &perm_traces[#n], + &main_traces[#i], + &perm_traces[#i], &perm_challenges, ); - chip_proofs.push(prove(self, config, self.#chip_name(), &mut challenger)); + // TODO: Needlessly regenerating preprocessed_trace() + let ppt: Option> = self.#chip_name().preprocessed_trace(); + let preprocessed_trace_lde = ppt.map(|trace| preprocessed_trace_ldes.remove(0)); + + quotients.push(quotient( + self, + config, + self.#chip_name(), + log_degrees[#i], + preprocessed_trace_lde, + main_trace_ldes.remove(0), + perm_trace_ldes.remove(0), + &perm_challenges, + alpha, + )); } }) .collect::(); @@ -211,26 +235,46 @@ fn prove_method(chips: &[&Field]) -> TokenStream2 { fn prove>(&self, config: &SC) -> ::valida_machine::proof::MachineProof { use ::valida_machine::__internal::*; + use ::valida_machine::__internal::p3_air::{BaseAir}; use ::valida_machine::__internal::p3_challenger::{CanObserve, FieldChallenger}; - use ::valida_machine::__internal::p3_commit::{Pcs, UnivariatePcs}; + use ::valida_machine::__internal::p3_commit::{Pcs, UnivariatePcs, UnivariatePcsWithLde}; use ::valida_machine::__internal::p3_matrix::{Matrix, dense::RowMajorMatrix}; use ::valida_machine::__internal::p3_util::log2_strict_usize; use ::valida_machine::chip::generate_permutation_trace; - use ::valida_machine::proof::MachineProof; + use ::valida_machine::proof::{MachineProof, ChipProof, Commitments}; use alloc::vec; use alloc::vec::Vec; use alloc::boxed::Box; let mut chips: [Box<&dyn Chip>; #num_chips] = [ #chip_list ]; + let log_quotient_degrees: [usize; #num_chips] = [ #quotient_degree_calls ]; let mut challenger = config.challenger(); - - let main_traces: [RowMajorMatrix; #num_chips] = tracing::info_span!("generate main traces") - .in_scope(|| - chips.par_iter().map(|chip| { - chip.generate_trace(self) - }).collect::>().try_into().unwrap() - ); + // TODO: Seed challenger with digest of all constraints & trace lengths. + let pcs = config.pcs(); + + let preprocessed_traces: Vec> = + tracing::info_span!("generate preprocessed traces") + .in_scope(|| + chips.par_iter() + .flat_map(|chip| chip.preprocessed_trace()) + .collect::>() + ); + + let (preprocessed_commit, preprocessed_data) = + tracing::info_span!("commit to preprocessed traces") + .in_scope(|| pcs.commit_batches(preprocessed_traces.to_vec())); + challenger.observe(preprocessed_commit.clone()); + let mut preprocessed_trace_ldes = pcs.get_ldes(&preprocessed_data); + + let main_traces: [RowMajorMatrix; #num_chips] = + tracing::info_span!("generate main traces") + .in_scope(|| + chips.par_iter() + .map(|chip| chip.generate_trace(self)) + .collect::>() + .try_into().unwrap() + ); let degrees: [usize; #num_chips] = main_traces.iter() .map(|trace| trace.height()) @@ -240,10 +284,9 @@ fn prove_method(chips: &[&Field]) -> TokenStream2 { let g_subgroups = log_degrees.map(|log_deg| SC::Val::two_adic_generator(log_deg)); let (main_commit, main_data) = tracing::info_span!("commit to main traces") - .in_scope(|| - config.pcs().commit_batches(main_traces.to_vec()) - ); + .in_scope(|| pcs.commit_batches(main_traces.to_vec())); challenger.observe(main_commit.clone()); + let mut main_trace_ldes = pcs.get_ldes(&main_data); let mut perm_challenges = Vec::new(); for _ in 0..3 { @@ -259,35 +302,52 @@ fn prove_method(chips: &[&Field]) -> TokenStream2 { let (perm_commit, perm_data) = tracing::info_span!("commit to permutation traces") .in_scope(|| { - let flattened_perm_traces = perm_traces.iter().map(|trace| { - trace.flatten_to_base() - }).collect::>(); - config.pcs().commit_batches(flattened_perm_traces) + let flattened_perm_traces = perm_traces.iter() + .map(|trace| trace.flatten_to_base()) + .collect::>(); + pcs.commit_batches(flattened_perm_traces) }); challenger.observe(perm_commit.clone()); + let mut perm_trace_ldes = pcs.get_ldes(&perm_data); + + let alpha: SC::Challenge = challenger.sample_ext_element(); + + let mut quotients: Vec> = vec![]; + #compute_quotients + let (quotient_commit, quotient_data) = tracing::info_span!("commit to quotient chunks") + .in_scope(|| pcs.commit_batches(quotients.to_vec())); + + #[cfg(debug_assertions)] + check_cumulative_sums(&perm_traces[..]); let zeta: SC::Challenge = challenger.sample_ext_element(); let zeta_and_next: [Vec; #num_chips] = - core::array::from_fn(|i| vec![zeta, zeta * g_subgroups[i]]); + g_subgroups.map(|g| vec![zeta, zeta * g]); + let zeta_exp_quotient_degree: [Vec; #num_chips] = + log_quotient_degrees.map(|log_deg| vec![zeta.exp_power_of_2(log_deg)]); let prover_data_and_points = [ + // TODO: Causes some errors, probably related to the fact that not all chips have preprocessed traces? + // (&preprocessed_data, zeta_and_next.as_slice()), (&main_data, zeta_and_next.as_slice()), (&perm_data, zeta_and_next.as_slice()), - // TODO: Enable when we have quotient computation - // ("ient_data, &[zeta.exp_power_of_2(log_quotient_degree)]), + ("ient_data, zeta_exp_quotient_degree.as_slice()), ]; - let (openings, opening_proof) = config.pcs().open_multi_batches( + let (openings, opening_proof) = pcs.open_multi_batches( &prover_data_and_points, &mut challenger); - let mut chip_proofs = vec![]; - #prove_starks - - #[cfg(debug_assertions)] - check_cumulative_sums(&perm_traces[..]); + // let [preprocessed_openings, main_openings, perm_openings, quotient_openings] = + // openings.try_into().expect("Should have 4 rounds of openings"); + let commitments = Commitments { + main_trace: main_commit, + perm_trace: perm_commit, + quotient_chunks: quotient_commit, + }; + let chip_proofs = vec![]; // TODO MachineProof { - // opening_proof, + commitments, + opening_proof, chip_proofs, - phantom: core::marker::PhantomData, } } } @@ -296,6 +356,7 @@ fn prove_method(chips: &[&Field]) -> TokenStream2 { fn verify_method(_chips: &[&Field]) -> TokenStream2 { quote! { fn verify>( + config: &SC, proof: &::valida_machine::proof::MachineProof, ) -> core::result::Result<(), ()> { diff --git a/machine/Cargo.toml b/machine/Cargo.toml index 6cff4e23..6c6b37da 100644 --- a/machine/Cargo.toml +++ b/machine/Cargo.toml @@ -11,6 +11,8 @@ std = [] [dependencies] byteorder = "1.4.3" itertools = "0.10.3" +serde = { version = "1.0", default-features = false, features = ["derive"] } +tracing = "0.1.37" p3-air = { path = "../../Plonky3/air" } p3-baby-bear = { path = "../../Plonky3/baby-bear" } @@ -20,7 +22,6 @@ p3-dft = { path = "../../Plonky3/dft" } p3-field = { path = "../../Plonky3/field" } p3-matrix = { path = "../../Plonky3/matrix" } p3-maybe-rayon = { path = "../../Plonky3/maybe-rayon" } -p3-util = { path = "../../Plonky3/util" } p3-uni-stark = { path = "../../Plonky3/uni-stark" } +p3-util = { path = "../../Plonky3/util" } valida-util = { path = "../util" } - diff --git a/machine/src/__internal/check_constraints.rs b/machine/src/__internal/check_constraints.rs index 4fcaf766..a434a725 100644 --- a/machine/src/__internal/check_constraints.rs +++ b/machine/src/__internal/check_constraints.rs @@ -1,14 +1,14 @@ use crate::__internal::DebugConstraintBuilder; use crate::chip::eval_permutation_constraints; -use p3_uni_stark::StarkConfig; +use valida_machine::config::StarkConfig; use crate::{Chip, Machine}; -use p3_air::{Air, TwoRowMatrixView}; +use p3_air::TwoRowMatrixView; use p3_field::{AbstractField, Field}; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; use p3_matrix::MatrixRowSlices; -use p3_maybe_rayon::{MaybeIntoParIter, ParallelIterator}; +use p3_maybe_rayon::prelude::*; /// Check that all constraints vanish on the subgroup. pub fn check_constraints( @@ -18,8 +18,8 @@ pub fn check_constraints( perm: &RowMajorMatrix, perm_challenges: &[SC::Challenge], ) where - M: Machine + Sync, - A: Chip + for<'a> Air>, + M: Machine, + A: Chip, SC: StarkConfig, { assert_eq!(main.height(), perm.height()); diff --git a/machine/src/__internal/debug_builder.rs b/machine/src/__internal/debug_builder.rs index 0a9708b5..b738f754 100644 --- a/machine/src/__internal/debug_builder.rs +++ b/machine/src/__internal/debug_builder.rs @@ -1,7 +1,7 @@ use crate::{Machine, ValidaAirBuilder}; use p3_air::{AirBuilder, PairBuilder, PermutationAirBuilder, TwoRowMatrixView}; use p3_field::AbstractField; -use p3_uni_stark::StarkConfig; +use valida_machine::config::StarkConfig; /// An `AirBuilder` which asserts that each constraint is zero, allowing any failed constraints to /// be detected early. pub struct DebugConstraintBuilder<'a, M: Machine, SC: StarkConfig> { diff --git a/machine/src/__internal/folding_builder.rs b/machine/src/__internal/folding_builder.rs index 871ecaad..8ae1e4b5 100644 --- a/machine/src/__internal/folding_builder.rs +++ b/machine/src/__internal/folding_builder.rs @@ -1,16 +1,19 @@ use crate::{Machine, ValidaAirBuilder}; use p3_air::{AirBuilder, PairBuilder, PermutationAirBuilder, TwoRowMatrixView}; -use p3_uni_stark::StarkConfig; +use p3_field::AbstractField; +use valida_machine::config::StarkConfig; pub struct ProverConstraintFolder<'a, M: Machine, SC: StarkConfig> { pub(crate) machine: &'a M, - pub(crate) main: TwoRowMatrixView<'a, SC::Val>, - pub(crate) preprocessed: TwoRowMatrixView<'a, SC::Val>, - pub(crate) perm: TwoRowMatrixView<'a, SC::Challenge>, + pub(crate) preprocessed: TwoRowMatrixView<'a, SC::PackedVal>, + pub(crate) main: TwoRowMatrixView<'a, SC::PackedVal>, + pub(crate) perm: TwoRowMatrixView<'a, SC::PackedChallenge>, pub(crate) perm_challenges: &'a [SC::Challenge], - pub(crate) is_first_row: SC::Val, - pub(crate) is_last_row: SC::Val, - pub(crate) is_transition: SC::Val, + pub(crate) is_first_row: SC::PackedVal, + pub(crate) is_last_row: SC::PackedVal, + pub(crate) is_transition: SC::PackedVal, + pub(crate) alpha: SC::Challenge, + pub(crate) accumulator: SC::PackedChallenge, } impl<'a, M, SC> AirBuilder for ProverConstraintFolder<'a, M, SC> @@ -19,9 +22,9 @@ where SC: StarkConfig, { type F = SC::Val; - type Expr = SC::Val; // TODO: PackedVal - type Var = SC::Val; // TODO: PackedVal - type M = TwoRowMatrixView<'a, SC::Val>; // TODO: PackedVal + type Expr = SC::PackedVal; + type Var = SC::PackedVal; + type M = TwoRowMatrixView<'a, SC::PackedVal>; fn main(&self) -> Self::M { self.main @@ -43,8 +46,10 @@ where } } - fn assert_zero>(&mut self, _x: I) { - // TODO + fn assert_zero>(&mut self, x: I) { + let x: SC::PackedVal = x.into(); + self.accumulator *= SC::PackedChallenge::from_f(self.alpha); + self.accumulator += x; } } @@ -64,9 +69,9 @@ where SC: StarkConfig, { type EF = SC::Challenge; - type ExprEF = SC::Challenge; - type VarEF = SC::Challenge; - type MP = TwoRowMatrixView<'a, SC::Challenge>; // TODO: packed challenge? + type ExprEF = SC::PackedChallenge; + type VarEF = SC::PackedChallenge; + type MP = TwoRowMatrixView<'a, SC::PackedChallenge>; fn permutation(&self) -> Self::MP { self.perm diff --git a/machine/src/__internal/mod.rs b/machine/src/__internal/mod.rs index dcc893d4..a38fb54c 100644 --- a/machine/src/__internal/mod.rs +++ b/machine/src/__internal/mod.rs @@ -1,20 +1,21 @@ //! Items intended to be used only by `valida-derive`. -pub type DefaultField = BabyBear; -pub type DefaultExtensionField = BabyBear; // FIXME: Replace +// TODO: Move actual logic elsewhere, convert this whole module into a list of re-exports mod check_constraints; mod debug_builder; mod folding_builder; -mod prove; +mod quotient; pub use check_constraints::*; pub use debug_builder::*; pub use folding_builder::*; -use p3_baby_bear::BabyBear; -pub use prove::*; +pub use quotient::*; + +pub use crate::symbolic::symbolic_builder::*; // Re-export some Plonky3 crates so that derives can use them. +pub use p3_air; pub use p3_challenger; pub use p3_commit; pub use p3_matrix; diff --git a/machine/src/__internal/prove.rs b/machine/src/__internal/prove.rs deleted file mode 100644 index 7d201952..00000000 --- a/machine/src/__internal/prove.rs +++ /dev/null @@ -1,28 +0,0 @@ -//use crate::__internal::ConstraintFolder; -use crate::proof::ChipProof; -use crate::{Chip, Machine}; -use p3_air::Air; -use p3_uni_stark::{prove as stark_prove, ProverConstraintFolder, StarkConfig, SymbolicAirBuilder}; -/* - -pub fn prove( - machine: &M, - config: &SC, - air: &A, - challenger: &mut SC::Challenger, -) -> ChipProof -where - - M: Machine, - A: for<'a> Air> + Chip + Air>, - SC: StarkConfig, - -{ - let trace = air.generate_trace(&machine); - let proof = stark_prove(config,air,challenger,trace); - - ChipProof{ - proof - } -} -*/ diff --git a/machine/src/__internal/quotient.rs b/machine/src/__internal/quotient.rs new file mode 100644 index 00000000..9bd64239 --- /dev/null +++ b/machine/src/__internal/quotient.rs @@ -0,0 +1,234 @@ +use crate::__internal::ProverConstraintFolder; +use crate::config::StarkConfig; +use crate::symbolic::symbolic_builder::get_log_quotient_degree; +use crate::{Chip, Machine}; +use itertools::Itertools; +use p3_air::{Air, TwoRowMatrixView}; +use p3_commit::UnivariatePcsWithLde; +use p3_field::{ + cyclic_subgroup_coset_known_order, AbstractExtensionField, AbstractField, Field, PackedField, + TwoAdicField, +}; +use p3_matrix::dense::RowMajorMatrix; +use p3_matrix::{MatrixGet, MatrixRows}; +use p3_maybe_rayon::prelude::*; +use p3_uni_stark::{decompose_and_flatten, ZerofierOnCoset}; +use tracing::instrument; + +pub fn quotient( + machine: &M, + config: &SC, + air: &A, + log_degree: usize, + preprocessed_trace_lde: Option, + main_trace_lde: MainTraceLde, + perm_trace_lde: PermTraceLde, + perm_challenges: &[SC::Challenge], + alpha: SC::Challenge, +) -> RowMajorMatrix +where + M: Machine, + A: Chip, + SC: StarkConfig, + PreprocessedTraceLde: MatrixRows + MatrixGet + Sync, + MainTraceLde: MatrixRows + MatrixGet + Sync, + PermTraceLde: MatrixRows + MatrixGet + Sync, +{ + let pcs = config.pcs(); + let log_quotient_degree = get_log_quotient_degree::(machine, air); + + let log_stride_for_quotient = pcs.log_blowup() - log_quotient_degree; + let preprocessed_trace_lde_for_quotient = + preprocessed_trace_lde.map(|lde| lde.vertically_strided(1 << log_stride_for_quotient, 0)); + let main_trace_lde_for_quotient = + main_trace_lde.vertically_strided(1 << log_stride_for_quotient, 0); + let perm_trace_lde_for_quotient = + perm_trace_lde.vertically_strided(1 << log_stride_for_quotient, 0); + + let quotient_values = quotient_values::( + machine, + config, + air, + log_degree, + log_quotient_degree, + preprocessed_trace_lde_for_quotient, + main_trace_lde_for_quotient, + perm_trace_lde_for_quotient, + perm_challenges, + alpha, + ); + + decompose_and_flatten::( + quotient_values, + SC::Challenge::from_base(pcs.coset_shift()), + log_quotient_degree, + ) +} + +#[instrument(name = "compute quotient polynomial", skip_all)] +fn quotient_values( + machine: &M, + config: &SC, + air: &A, + log_degree: usize, + log_quotient_degree: usize, + preprocessed_trace_lde: Option, + main_trace_lde: MainTraceLde, + perm_trace_lde: PermTraceLde, + perm_challenges: &[SC::Challenge], + alpha: SC::Challenge, +) -> Vec +where + M: Machine, + SC: StarkConfig, + A: for<'a> Air>, + PreprocessedTraceLde: MatrixRows + MatrixGet + Sync, + MainTraceLde: MatrixRows + MatrixGet + Sync, + PermTraceLde: MatrixRows + MatrixGet + Sync, +{ + let degree = 1 << log_degree; + let log_quotient_size = log_degree + log_quotient_degree; + let quotient_size = 1 << log_quotient_size; + let g_subgroup = SC::Val::two_adic_generator(log_degree); + let g_extended = SC::Val::two_adic_generator(log_quotient_size); + let subgroup_last = g_subgroup.inverse(); + let coset_shift = config.pcs().coset_shift(); + let next_step = 1 << log_quotient_degree; + + let mut coset: Vec<_> = + cyclic_subgroup_coset_known_order(g_extended, coset_shift, quotient_size).collect(); + + let zerofier_on_coset = ZerofierOnCoset::new(log_degree, log_quotient_degree, coset_shift); + + // Evaluations of L_first(x) = Z_H(x) / (x - 1) on our coset s H. + let mut lagrange_first_evals = zerofier_on_coset.lagrange_basis_unnormalized(0); + let mut lagrange_last_evals = zerofier_on_coset.lagrange_basis_unnormalized(degree - 1); + + // We have a few vectors of length `quotient_size`, and we're going to take slices therein of + // length `WIDTH`. In the edge case where `quotient_size < WIDTH`, we need to pad those vectors + // in order for the slices to exist. The entries beyond quotient_size will be ignored, so we can + // just use default values. + for _ in quotient_size..SC::PackedVal::WIDTH { + coset.push(SC::Val::default()); + lagrange_first_evals.push(SC::Val::default()); + lagrange_last_evals.push(SC::Val::default()); + } + + (0..quotient_size) + .into_par_iter() + .step_by(SC::PackedVal::WIDTH) + .flat_map_iter(|i_local_start| { + let wrap = |i| i % quotient_size; + let i_next_start = wrap(i_local_start + next_step); + let i_range = i_local_start..i_local_start + SC::PackedVal::WIDTH; + + let x = *SC::PackedVal::from_slice(&coset[i_range.clone()]); + let is_transition = x - subgroup_last; + let is_first_row = *SC::PackedVal::from_slice(&lagrange_first_evals[i_range.clone()]); + let is_last_row = *SC::PackedVal::from_slice(&lagrange_last_evals[i_range]); + + let (preprocessed_local, preprocessed_next): (Vec<_>, Vec<_>) = + match &preprocessed_trace_lde { + Some(lde) => { + let local = (0..lde.width()) + .map(|col| { + SC::PackedVal::from_fn(|offset| { + let row = wrap(i_local_start + offset); + lde.get(row, col) + }) + }) + .collect(); + let next = (0..lde.width()) + .map(|col| { + SC::PackedVal::from_fn(|offset| { + let row = wrap(i_next_start + offset); + lde.get(row, col) + }) + }) + .collect(); + (local, next) + } + None => (vec![], vec![]), + }; + + let main_local: Vec<_> = (0..main_trace_lde.width()) + .map(|col| { + SC::PackedVal::from_fn(|offset| { + let row = wrap(i_local_start + offset); + main_trace_lde.get(row, col) + }) + }) + .collect(); + let main_next: Vec<_> = (0..main_trace_lde.width()) + .map(|col| { + SC::PackedVal::from_fn(|offset| { + let row = wrap(i_next_start + offset); + main_trace_lde.get(row, col) + }) + }) + .collect(); + + let ext_degree = >::D; + debug_assert_eq!(perm_trace_lde.width() % ext_degree, 0); + let perm_width_ext = perm_trace_lde.width() / ext_degree; + + let perm_local: Vec<_> = (0..perm_width_ext) + .map(|ext_col| { + SC::PackedChallenge::from_base_fn(|coeff_idx| { + SC::PackedVal::from_fn(|offset| { + let row = wrap(i_local_start + offset); + perm_trace_lde.get(row, ext_col * ext_degree + coeff_idx) + }) + }) + }) + .collect(); + let perm_next: Vec<_> = (0..perm_width_ext) + .map(|ext_col| { + SC::PackedChallenge::from_base_fn(|coeff_idx| { + SC::PackedVal::from_fn(|offset| { + let row = wrap(i_next_start + offset); + perm_trace_lde.get(row, ext_col * ext_degree + coeff_idx) + }) + }) + }) + .collect(); + + let accumulator = SC::PackedChallenge::zero(); + let mut folder = ProverConstraintFolder { + machine, + preprocessed: TwoRowMatrixView { + local: &preprocessed_local, + next: &preprocessed_next, + }, + main: TwoRowMatrixView { + local: &main_local, + next: &main_next, + }, + perm: TwoRowMatrixView { + local: &perm_local, + next: &perm_next, + }, + perm_challenges, + is_first_row, + is_last_row, + is_transition, + alpha, + accumulator, + }; + air.eval(&mut folder); + + // quotient(x) = constraints(x) / Z_H(x) + let zerofier_inv: SC::PackedVal = zerofier_on_coset.eval_inverse_packed(i_local_start); + let quotient = folder.accumulator * zerofier_inv; + + // "Transpose" D packed base coefficients into WIDTH scalar extension coefficients. + let limit = SC::PackedVal::WIDTH.min(quotient_size); + (0..limit).map(move |idx_in_packing| { + let quotient_value = (0..>::D) + .map(|coeff_idx| quotient.as_base_slice()[coeff_idx].as_slice()[idx_in_packing]) + .collect_vec(); + SC::Challenge::from_base_slice("ient_value) + }) + }) + .collect() +} diff --git a/machine/src/chip.rs b/machine/src/chip.rs index 4496fe52..ae3ba9c3 100644 --- a/machine/src/chip.rs +++ b/machine/src/chip.rs @@ -2,16 +2,18 @@ use crate::Machine; use crate::__internal::{DebugConstraintBuilder, ProverConstraintFolder}; use alloc::vec; use alloc::vec::Vec; -use valida_util::batch_multiplicative_inverse; -//use crate::config::StarkConfig; +use crate::config::StarkConfig; +use crate::symbolic::symbolic_builder::SymbolicAirBuilder; use p3_air::{Air, AirBuilder, PairBuilder, PermutationAirBuilder, VirtualPairCol}; -use p3_field::{AbstractExtensionField, AbstractField, ExtensionField, Field, Powers}; +use p3_field::{AbstractField, ExtensionField, Field, Powers}; use p3_matrix::{dense::RowMajorMatrix, Matrix, MatrixRowSlices}; -use p3_uni_stark::StarkConfig; +use valida_util::batch_multiplicative_inverse_allowing_zero; pub trait Chip, SC: StarkConfig>: - for<'a> Air> + for<'a> Air> + for<'a> Air> + + for<'a> Air> + + for<'a> Air> { /// Generate the main trace for the chip given the provided machine. fn generate_trace(&self, machine: &M) -> RowMajorMatrix; @@ -160,7 +162,9 @@ where } perm_values.extend(row); } - let perm_values = batch_multiplicative_inverse(perm_values); + // TODO: Switch to batch_multiplicative_inverse (not allowing zero)? + // Zero should be vanishingly unlikely if properly randomized? + let perm_values = batch_multiplicative_inverse_allowing_zero(perm_values); let mut perm = RowMajorMatrix::new(perm_values, perm_width); // Compute the running sum column @@ -180,10 +184,10 @@ where .apply::(preprocessed_row, main_row); match interaction_type { InteractionType::LocalSend | InteractionType::GlobalSend => { - phi[n] += SC::Challenge::from_base(mult) * perm_row[m]; + phi[n] += perm_row[m] * mult; } InteractionType::LocalReceive | InteractionType::GlobalReceive => { - phi[n] -= SC::Challenge::from_base(mult) * perm_row[m]; + phi[n] -= perm_row[m] * mult; } } } @@ -229,22 +233,22 @@ pub fn eval_permutation_constraints( let (alphas_local, alphas_global) = generate_rlc_elements(builder.machine(), chip, &rand_elems); let betas = rand_elems[2].powers(); - let lhs = phi_next - phi_local.clone(); - let mut rhs = AB::ExprEF::from_base(AB::Expr::zero()); - let mut phi_0 = AB::ExprEF::from_base(AB::Expr::zero()); + let lhs = phi_next.into() - phi_local.into(); + let mut rhs = AB::ExprEF::zero(); + let mut phi_0 = AB::ExprEF::zero(); for (m, (interaction, interaction_type)) in all_interactions.iter().enumerate() { // Reciprocal constraints - let mut rlc = AB::ExprEF::from_base(AB::Expr::zero()); + let mut rlc = AB::ExprEF::zero(); for (field, beta) in interaction.fields.iter().zip(betas.clone()) { let elem = field.apply::(preprocessed_local, main_local); - rlc += AB::ExprEF::from(beta) * elem; + rlc += AB::ExprEF::from_f(beta) * elem; } if interaction.is_local() { - rlc = rlc + alphas_local[interaction.argument_index()]; + rlc = rlc + AB::ExprEF::from_f(alphas_local[interaction.argument_index()]); } else { - rlc = rlc + alphas_global[interaction.argument_index()]; + rlc = rlc + AB::ExprEF::from_f(alphas_global[interaction.argument_index()]); } - builder.assert_one_ext::(rlc * perm_local[m]); + builder.assert_one_ext::(rlc * perm_local[m].into()); let mult_local = interaction .count @@ -256,12 +260,12 @@ pub fn eval_permutation_constraints( // Build the RHS of the permutation constraint match interaction_type { InteractionType::LocalSend | InteractionType::GlobalSend => { - phi_0 += AB::ExprEF::from_base(mult_local) * perm_local[m]; - rhs += AB::ExprEF::from_base(mult_next) * perm_next[m]; + phi_0 += perm_local[m].into() * mult_local; + rhs += perm_next[m].into() * mult_next; } InteractionType::LocalReceive | InteractionType::GlobalReceive => { - phi_0 -= AB::ExprEF::from_base(mult_local) * perm_local[m]; - rhs -= AB::ExprEF::from_base(mult_next) * perm_next[m]; + phi_0 -= perm_local[m].into() * mult_local; + rhs -= perm_next[m].into() * mult_next; } } } @@ -275,7 +279,7 @@ pub fn eval_permutation_constraints( .assert_eq_ext(perm_local.last().unwrap().clone(), phi_0); builder.when_last_row().assert_eq_ext( perm_local.last().unwrap().clone(), - AB::ExprEF::from(cumulative_sum), + AB::ExprEF::from_f(cumulative_sum), ); } diff --git a/machine/src/config.rs b/machine/src/config.rs index c6a7aea4..fbc5676a 100644 --- a/machine/src/config.rs +++ b/machine/src/config.rs @@ -11,7 +11,7 @@ pub trait StarkConfig { /// The field from which most random challenges are drawn. type Challenge: ExtensionField + TwoAdicField; - type PackedChallenge: AbstractExtensionField; + type PackedChallenge: AbstractExtensionField + Copy; /// The PCS used to commit to trace polynomials. type Pcs: UnivariatePcsWithLde< @@ -53,7 +53,7 @@ impl StarkConfig where Val: PrimeField32 + TwoAdicField, // TODO: Relax to Field? Challenge: ExtensionField + TwoAdicField, - PackedChallenge: AbstractExtensionField, + PackedChallenge: AbstractExtensionField + Copy, Pcs: UnivariatePcsWithLde, Challenger>, Challenger: FieldChallenger + Clone diff --git a/machine/src/lib.rs b/machine/src/lib.rs index 4faf352a..974eba10 100644 --- a/machine/src/lib.rs +++ b/machine/src/lib.rs @@ -8,17 +8,11 @@ use alloc::vec::Vec; use byteorder::{ByteOrder, LittleEndian}; pub use crate::core::Word; -pub use chip::{BusArgument, Chip, Interaction, InteractionType, ValidaAirBuilder}; -use p3_matrix::dense::RowMajorMatrix; - use crate::proof::MachineProof; -use p3_air::Air; +pub use chip::{BusArgument, Chip, Interaction, InteractionType, ValidaAirBuilder}; pub use p3_field::{ AbstractExtensionField, AbstractField, ExtensionField, Field, PrimeField, PrimeField64, }; -use p3_uni_stark::{ - Commitments, Proof, ProverConstraintFolder, ProverData, StarkConfig, SymbolicAirBuilder, -}; // TODO: some are also re-exported, so they shouldn't be pub? pub mod __internal; mod advice; @@ -26,7 +20,9 @@ pub mod chip; pub mod config; pub mod core; pub mod proof; +mod symbolic; +use crate::config::StarkConfig; pub use advice::*; pub use chip::*; pub use core::*; @@ -158,26 +154,16 @@ impl ProgramROM { } } -pub trait Machine { - fn run(&mut self, program: &ProgramROM, advice: &mut Adv); - - fn add_chip_trace( - &self, - config: &SC, - challenger: &mut SC::Challenger, - trace_commitments: &mut Vec>, - quotient_commitments: &mut Vec>, - log_degree: &mut Vec, - log_quotient_degrees: &mut Vec, - chip: &A, - trace: RowMajorMatrix<::Val>, - ) where - SC: StarkConfig, - A: Air> + for<'a> Air>; - - fn prove(&self, config: &SC, challenger: &mut SC::Challenger) -> MachineProof +pub trait Machine: Sync { + fn run(&mut self, program: &ProgramROM, advice: &mut Adv) + where + Adv: AdviceProvider; + + fn prove(&self, config: &SC) -> MachineProof where SC: StarkConfig; - fn verify>(proof: &MachineProof) -> Result<(), ()>; + fn verify(config: &SC, proof: &MachineProof) -> Result<(), ()> + where + SC: StarkConfig; } diff --git a/machine/src/proof.rs b/machine/src/proof.rs index 0975020b..c0e8f0c5 100644 --- a/machine/src/proof.rs +++ b/machine/src/proof.rs @@ -1,12 +1,41 @@ -use crate::{Chip, Machine}; +use crate::config::StarkConfig; use alloc::vec::Vec; -use p3_uni_stark::{Proof, StarkConfig}; -pub struct MachineProof { - //pub opening_proof: >>::Proof, - pub chip_proof: ChipProof, - pub phantom: core::marker::PhantomData, +use p3_commit::Pcs; +use p3_matrix::dense::RowMajorMatrix; +use serde::{Deserialize, Serialize}; + +type Val = ::Val; +type ValMat = RowMajorMatrix>; +type Com = <::Pcs as Pcs, ValMat>>::Commitment; + +#[derive(Serialize, Deserialize)] +#[serde(bound = "")] +pub struct MachineProof { + pub commitments: Commitments>, + pub opening_proof: >>::Proof, + pub chip_proofs: Vec>, +} + +#[derive(Serialize, Deserialize)] +pub struct Commitments { + pub main_trace: Com, + pub perm_trace: Com, + pub quotient_chunks: Com, +} + +#[derive(Serialize, Deserialize)] +pub struct ChipProof { + pub(crate) log_degree: usize, + pub(crate) opened_values: OpenedValues, } -pub struct ChipProof { - pub proof: Proof, +#[derive(Serialize, Deserialize)] +pub struct OpenedValues { + pub(crate) preprocessed_local: Vec, + pub(crate) preprocessed_next: Vec, + pub(crate) trace_local: Vec, + pub(crate) trace_next: Vec, + pub(crate) permutation_local: Vec, + pub(crate) permutation_next: Vec, + pub(crate) quotient_chunks: Vec, } diff --git a/machine/src/symbolic/mod.rs b/machine/src/symbolic/mod.rs new file mode 100644 index 00000000..2be2fc20 --- /dev/null +++ b/machine/src/symbolic/mod.rs @@ -0,0 +1,4 @@ +pub(crate) mod symbolic_builder; +mod symbolic_expression; +mod symbolic_expression_ext; +mod symbolic_variable; diff --git a/machine/src/symbolic/symbolic_builder.rs b/machine/src/symbolic/symbolic_builder.rs new file mode 100644 index 00000000..a0a0de23 --- /dev/null +++ b/machine/src/symbolic/symbolic_builder.rs @@ -0,0 +1,140 @@ +use alloc::vec; +use alloc::vec::Vec; + +use crate::config::StarkConfig; +use crate::{Machine, ValidaAirBuilder}; +use p3_air::{Air, AirBuilder, PairBuilder, PermutationAirBuilder}; +use p3_matrix::dense::RowMajorMatrix; +use p3_util::log2_ceil_usize; +use valida_machine::symbolic::symbolic_expression_ext::SymbolicExpressionExt; +use valida_machine::symbolic::symbolic_variable::Trace; + +use crate::symbolic::symbolic_expression::SymbolicExpression; +use crate::symbolic::symbolic_variable::SymbolicVariable; + +pub fn get_log_quotient_degree(machine: &M, air: &A) -> usize +where + M: Machine, + SC: StarkConfig, + A: for<'a> Air>, +{ + // We pad to at least degree 2, since a quotient argument doesn't make sense with smaller degrees. + let constraint_degree = get_max_constraint_degree(machine, air).max(2); + + // The quotient's actual degree is approximately (max_constraint_degree - 1) n, + // where subtracting 1 comes from division by the zerofier. + // But we pad it to a power of two so that we can efficiently decompose the quotient. + log2_ceil_usize(constraint_degree - 1) +} + +pub fn get_max_constraint_degree(machine: &M, air: &A) -> usize +where + M: Machine, + SC: StarkConfig, + A: for<'a> Air>, +{ + get_symbolic_constraints(machine, air) + .iter() + .map(|c| c.degree_multiple()) + .max() + .unwrap_or(0) +} + +pub fn get_symbolic_constraints(machine: &M, air: &A) -> Vec> +where + M: Machine, + SC: StarkConfig, + A: for<'a> Air>, +{ + let mut builder = SymbolicAirBuilder::new(machine, air.width()); + air.eval(&mut builder); + builder.constraints() +} + +/// An `AirBuilder` for evaluating constraints symbolically, and recording them for later use. +pub struct SymbolicAirBuilder<'a, M: Machine, SC: StarkConfig> { + machine: &'a M, + preprocessed: RowMajorMatrix>, + main: RowMajorMatrix>, + permutation: RowMajorMatrix>, + constraints: Vec>, +} + +impl<'a, M: Machine, SC: StarkConfig> SymbolicAirBuilder<'a, M, SC> { + pub(crate) fn new(machine: &'a M, width: usize) -> Self { + // TODO: `width` is for the main trace, what about others? + Self { + machine, + preprocessed: SymbolicVariable::window(Trace::Preprocessed, width), + main: SymbolicVariable::window(Trace::Main, width), + permutation: SymbolicVariable::window(Trace::Permutation, width), + constraints: vec![], + } + } + + pub(crate) fn constraints(self) -> Vec> { + self.constraints + } +} + +impl<'a, M: Machine, SC: StarkConfig> AirBuilder for SymbolicAirBuilder<'a, M, SC> { + type F = SC::Val; + type Expr = SymbolicExpression; + type Var = SymbolicVariable; + type M = RowMajorMatrix; + + fn main(&self) -> Self::M { + self.main.clone() + } + + fn is_first_row(&self) -> Self::Expr { + SymbolicExpression::IsFirstRow + } + + fn is_last_row(&self) -> Self::Expr { + SymbolicExpression::IsLastRow + } + + fn is_transition_window(&self, size: usize) -> Self::Expr { + if size == 2 { + SymbolicExpression::IsTransition + } else { + panic!("uni-stark only supports a window size of 2") + } + } + + fn assert_zero>(&mut self, x: I) { + self.constraints.push(x.into()); + } +} + +impl<'a, M: Machine, SC: StarkConfig> PairBuilder for SymbolicAirBuilder<'a, M, SC> { + fn preprocessed(&self) -> Self::M { + self.preprocessed.clone() + } +} + +impl<'a, M: Machine, SC: StarkConfig> PermutationAirBuilder + for SymbolicAirBuilder<'a, M, SC> +{ + type EF = SC::Challenge; + type ExprEF = SymbolicExpressionExt; + type VarEF = SymbolicVariable; + type MP = RowMajorMatrix; + + fn permutation(&self) -> Self::MP { + self.permutation.clone() + } + + fn permutation_randomness(&self) -> &[Self::EF] { + &[] // TODO + } +} + +impl<'a, M: Machine, SC: StarkConfig> ValidaAirBuilder for SymbolicAirBuilder<'a, M, SC> { + type Machine = M; + + fn machine(&self) -> &Self::Machine { + self.machine + } +} diff --git a/machine/src/symbolic/symbolic_expression.rs b/machine/src/symbolic/symbolic_expression.rs new file mode 100644 index 00000000..dfe96fe8 --- /dev/null +++ b/machine/src/symbolic/symbolic_expression.rs @@ -0,0 +1,225 @@ +use alloc::rc::Rc; +use core::fmt::Debug; +use core::iter::{Product, Sum}; +use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +use p3_field::{AbstractField, Field}; + +use crate::symbolic::symbolic_variable::SymbolicVariable; + +/// An expression over `SymbolicVariable`s. +#[derive(Clone, Debug)] +pub enum SymbolicExpression { + Variable(SymbolicVariable), + IsFirstRow, + IsLastRow, + IsTransition, + Constant(F), + Add(Rc, Rc), + Sub(Rc, Rc), + Neg(Rc), + Mul(Rc, Rc), +} + +impl SymbolicExpression { + /// Returns the multiple of `n` (the trace length) in this expression's degree. + pub(crate) fn degree_multiple(&self) -> usize { + match self { + SymbolicExpression::Variable(_) => 1, + SymbolicExpression::IsFirstRow => 1, + SymbolicExpression::IsLastRow => 1, + SymbolicExpression::IsTransition => 0, + SymbolicExpression::Constant(_) => 0, + SymbolicExpression::Add(x, y) => x.degree_multiple().max(y.degree_multiple()), + SymbolicExpression::Sub(x, y) => x.degree_multiple().max(y.degree_multiple()), + SymbolicExpression::Neg(x) => x.degree_multiple(), + SymbolicExpression::Mul(x, y) => x.degree_multiple() + y.degree_multiple(), + } + } +} + +impl Default for SymbolicExpression { + fn default() -> Self { + Self::Constant(F::zero()) + } +} + +impl From for SymbolicExpression { + fn from(value: F) -> Self { + Self::Constant(value) + } +} + +impl AbstractField for SymbolicExpression { + type F = F; + + fn zero() -> Self { + Self::Constant(F::zero()) + } + fn one() -> Self { + Self::Constant(F::one()) + } + fn two() -> Self { + Self::Constant(F::two()) + } + fn neg_one() -> Self { + Self::Constant(F::neg_one()) + } + + #[inline] + fn from_f(f: Self::F) -> Self { + f.into() + } + + fn from_bool(b: bool) -> Self { + Self::Constant(F::from_bool(b)) + } + + fn from_canonical_u8(n: u8) -> Self { + Self::Constant(F::from_canonical_u8(n)) + } + + fn from_canonical_u16(n: u16) -> Self { + Self::Constant(F::from_canonical_u16(n)) + } + + fn from_canonical_u32(n: u32) -> Self { + Self::Constant(F::from_canonical_u32(n)) + } + + fn from_canonical_u64(n: u64) -> Self { + Self::Constant(F::from_canonical_u64(n)) + } + + fn from_canonical_usize(n: usize) -> Self { + Self::Constant(F::from_canonical_usize(n)) + } + + fn from_wrapped_u32(n: u32) -> Self { + Self::Constant(F::from_wrapped_u32(n)) + } + + fn from_wrapped_u64(n: u64) -> Self { + Self::Constant(F::from_wrapped_u64(n)) + } + + fn generator() -> Self { + Self::Constant(F::generator()) + } +} + +impl Add for SymbolicExpression { + type Output = Self; + + fn add(self, rhs: Self) -> Self { + Self::Add(Rc::new(self), Rc::new(rhs)) + } +} + +impl Add for SymbolicExpression { + type Output = Self; + + fn add(self, rhs: F) -> Self { + self + Self::from(rhs) + } +} + +impl AddAssign for SymbolicExpression { + fn add_assign(&mut self, rhs: Self) { + *self = self.clone() + rhs; + } +} + +impl AddAssign for SymbolicExpression { + fn add_assign(&mut self, rhs: F) { + *self += Self::from(rhs); + } +} + +impl Sum for SymbolicExpression { + fn sum>(iter: I) -> Self { + iter.reduce(|x, y| x + y).unwrap_or(Self::zero()) + } +} + +impl Sum for SymbolicExpression { + fn sum>(iter: I) -> Self { + iter.map(|x| Self::from(x)).sum() + } +} + +impl Sub for SymbolicExpression { + type Output = Self; + + fn sub(self, rhs: Self) -> Self { + Self::Sub(Rc::new(self), Rc::new(rhs)) + } +} + +impl Sub for SymbolicExpression { + type Output = Self; + + fn sub(self, rhs: F) -> Self { + self - Self::from(rhs) + } +} + +impl SubAssign for SymbolicExpression { + fn sub_assign(&mut self, rhs: Self) { + *self = self.clone() - rhs; + } +} + +impl SubAssign for SymbolicExpression { + fn sub_assign(&mut self, rhs: F) { + *self -= Self::from(rhs); + } +} + +impl Neg for SymbolicExpression { + type Output = Self; + + fn neg(self) -> Self { + Self::Neg(Rc::new(self)) + } +} + +impl Mul for SymbolicExpression { + type Output = Self; + + fn mul(self, rhs: Self) -> Self { + Self::Mul(Rc::new(self), Rc::new(rhs)) + } +} + +impl Mul for SymbolicExpression { + type Output = Self; + + fn mul(self, rhs: F) -> Self { + self * Self::from(rhs) + } +} + +impl MulAssign for SymbolicExpression { + fn mul_assign(&mut self, rhs: Self) { + *self = self.clone() * rhs; + } +} + +impl MulAssign for SymbolicExpression { + fn mul_assign(&mut self, rhs: F) { + *self *= Self::from(rhs); + } +} + +impl Product for SymbolicExpression { + fn product>(iter: I) -> Self { + iter.reduce(|x, y| x * y).unwrap_or(Self::one()) + } +} + +impl Product for SymbolicExpression { + fn product>(iter: I) -> Self { + iter.map(|x| Self::from(x)).product() + } +} diff --git a/machine/src/symbolic/symbolic_expression_ext.rs b/machine/src/symbolic/symbolic_expression_ext.rs new file mode 100644 index 00000000..653aa481 --- /dev/null +++ b/machine/src/symbolic/symbolic_expression_ext.rs @@ -0,0 +1,283 @@ +use crate::symbolic::symbolic_expression::SymbolicExpression; +use crate::symbolic::symbolic_variable::SymbolicVariable; +use alloc::rc::Rc; +use p3_field::{AbstractExtensionField, AbstractField, ExtensionField, Field}; +use std::iter::{Product, Sum}; +use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +#[derive(Clone, Debug)] +pub struct SymbolicExpressionExt(pub SymbolicExpression) +where + EF: Field; + +impl Default for SymbolicExpressionExt { + fn default() -> Self { + Self(SymbolicExpression::zero()) + } +} + +impl From> for SymbolicExpressionExt { + fn from(value: SymbolicVariable) -> Self { + Self(value.into()) + } +} + +impl AbstractField for SymbolicExpressionExt +where + EF: Field, +{ + type F = EF; + + fn zero() -> Self { + Self(SymbolicExpression::zero()) + } + + fn one() -> Self { + Self(SymbolicExpression::one()) + } + + fn two() -> Self { + Self(SymbolicExpression::two()) + } + + fn neg_one() -> Self { + Self(SymbolicExpression::neg_one()) + } + + fn from_f(f: Self::F) -> Self { + Self(SymbolicExpression::from_f(f)) + } + + fn from_bool(b: bool) -> Self { + Self(SymbolicExpression::from_bool(b)) + } + + fn from_canonical_u8(n: u8) -> Self { + Self(SymbolicExpression::from_canonical_u8(n)) + } + + fn from_canonical_u16(n: u16) -> Self { + Self(SymbolicExpression::from_canonical_u16(n)) + } + + fn from_canonical_u32(n: u32) -> Self { + Self(SymbolicExpression::from_canonical_u32(n)) + } + + fn from_canonical_u64(n: u64) -> Self { + Self(SymbolicExpression::from_canonical_u64(n)) + } + + fn from_canonical_usize(n: usize) -> Self { + Self(SymbolicExpression::from_canonical_usize(n)) + } + + fn from_wrapped_u32(n: u32) -> Self { + Self(SymbolicExpression::from_wrapped_u32(n)) + } + + fn from_wrapped_u64(n: u64) -> Self { + Self(SymbolicExpression::from_wrapped_u64(n)) + } + + fn generator() -> Self { + Self(SymbolicExpression::generator()) + } +} + +fn map_rc(rc: Rc>) -> Rc> +where + F: Field, + EF: ExtensionField, +{ + Rc::new(SymbolicExpressionExt::from_base((*rc).clone()).0) +} + +impl AbstractExtensionField> for SymbolicExpressionExt +where + F: Field, + EF: ExtensionField, +{ + const D: usize = EF::D; + + fn from_base(b: SymbolicExpression) -> Self { + match b { + SymbolicExpression::Variable(v) => Self(SymbolicExpression::Variable(v.to_ext())), + SymbolicExpression::IsFirstRow => Self(SymbolicExpression::IsFirstRow), + SymbolicExpression::IsLastRow => Self(SymbolicExpression::IsLastRow), + SymbolicExpression::IsTransition => Self(SymbolicExpression::IsTransition), + SymbolicExpression::Constant(c) => Self(SymbolicExpression::Constant(EF::from_base(c))), + SymbolicExpression::Add(x, y) => Self(SymbolicExpression::Add(map_rc(x), map_rc(y))), + SymbolicExpression::Sub(x, y) => Self(SymbolicExpression::Sub(map_rc(x), map_rc(y))), + SymbolicExpression::Neg(x) => Self(SymbolicExpression::Neg(map_rc(x))), + SymbolicExpression::Mul(x, y) => Self(SymbolicExpression::Mul(map_rc(x), map_rc(y))), + } + } + + fn from_base_slice(_bs: &[SymbolicExpression]) -> Self { + todo!("from_base_slice") + } + + fn from_base_fn SymbolicExpression>(_f: FN) -> Self { + todo!("from_base_fn") + } + + fn as_base_slice(&self) -> &[SymbolicExpression] { + todo!("as_base_slice") + } +} + +impl Add for SymbolicExpressionExt +where + EF: Field, +{ + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Self(self.0 + rhs.0) + } +} + +impl Add> for SymbolicExpressionExt +where + F: Field, + EF: ExtensionField, +{ + type Output = Self; + + fn add(self, rhs: SymbolicExpression) -> Self::Output { + self + Self::from_base(rhs) + } +} + +impl AddAssign for SymbolicExpressionExt +where + EF: Field, +{ + fn add_assign(&mut self, rhs: Self) { + *self = self.clone() + rhs; + } +} + +impl AddAssign> for SymbolicExpressionExt +where + F: Field, + EF: ExtensionField, +{ + fn add_assign(&mut self, rhs: SymbolicExpression) { + *self = self.clone() + rhs; + } +} + +impl Sum for SymbolicExpressionExt +where + EF: Field, +{ + fn sum>(iter: I) -> Self { + iter.reduce(|x, y| x + y).unwrap_or(Self::zero()) + } +} + +impl Sub for SymbolicExpressionExt +where + EF: Field, +{ + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Self(self.0 - rhs.0) + } +} + +impl Sub> for SymbolicExpressionExt +where + F: Field, + EF: ExtensionField, +{ + type Output = Self; + + fn sub(self, rhs: SymbolicExpression) -> Self::Output { + self - Self::from_base(rhs) + } +} + +impl SubAssign for SymbolicExpressionExt +where + EF: Field, +{ + fn sub_assign(&mut self, rhs: Self) { + *self = self.clone() - rhs; + } +} + +impl SubAssign> for SymbolicExpressionExt +where + F: Field, + EF: ExtensionField, +{ + fn sub_assign(&mut self, rhs: SymbolicExpression) { + *self = self.clone() - rhs; + } +} + +impl Neg for SymbolicExpressionExt +where + EF: Field, +{ + type Output = Self; + + fn neg(self) -> Self { + Self(-self.0) + } +} + +impl Mul for SymbolicExpressionExt +where + EF: Field, +{ + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + Self(self.0 * rhs.0) + } +} + +impl Mul> for SymbolicExpressionExt +where + F: Field, + EF: ExtensionField, +{ + type Output = Self; + + fn mul(self, rhs: SymbolicExpression) -> Self::Output { + self * Self::from_base(rhs) + } +} + +impl MulAssign for SymbolicExpressionExt +where + EF: Field, +{ + fn mul_assign(&mut self, rhs: Self) { + *self = self.clone() * rhs; + } +} + +impl MulAssign> for SymbolicExpressionExt +where + F: Field, + EF: ExtensionField, +{ + fn mul_assign(&mut self, rhs: SymbolicExpression) { + *self = self.clone() * rhs; + } +} + +impl Product for SymbolicExpressionExt +where + EF: Field, +{ + fn product>(iter: I) -> Self { + iter.reduce(|x, y| x * y).unwrap_or(Self::one()) + } +} diff --git a/machine/src/symbolic/symbolic_variable.rs b/machine/src/symbolic/symbolic_variable.rs new file mode 100644 index 00000000..38190454 --- /dev/null +++ b/machine/src/symbolic/symbolic_variable.rs @@ -0,0 +1,151 @@ +use core::marker::PhantomData; +use core::ops::{Add, Mul, Sub}; + +use p3_field::{ExtensionField, Field}; +use p3_matrix::dense::RowMajorMatrix; + +use crate::symbolic::symbolic_expression::SymbolicExpression; + +#[derive(Copy, Clone, Debug)] +pub enum Trace { + Preprocessed, + Main, + Permutation, +} + +/// A variable within the evaluation window, i.e. a column in either the local or next row. +#[derive(Copy, Clone, Debug)] +pub struct SymbolicVariable { + pub trace: Trace, + pub is_next: bool, + pub column: usize, + pub(crate) _phantom: PhantomData, +} + +impl SymbolicVariable { + pub(crate) fn window(trace: Trace, width: usize) -> RowMajorMatrix { + let values = [false, true] + .into_iter() + .flat_map(|is_next| { + (0..width).map(move |column| SymbolicVariable { + trace, + is_next, + column, + _phantom: PhantomData, + }) + }) + .collect(); + RowMajorMatrix::new(values, width) + } + + pub(crate) fn to_ext>(self) -> SymbolicVariable { + SymbolicVariable { + trace: self.trace, + is_next: self.is_next, + column: self.column, + _phantom: PhantomData, + } + } +} + +impl From> for SymbolicExpression { + fn from(value: SymbolicVariable) -> Self { + SymbolicExpression::Variable(value) + } +} + +impl Add for SymbolicVariable { + type Output = SymbolicExpression; + + fn add(self, rhs: Self) -> Self::Output { + SymbolicExpression::from(self) + SymbolicExpression::from(rhs) + } +} + +impl Add for SymbolicVariable { + type Output = SymbolicExpression; + + fn add(self, rhs: F) -> Self::Output { + SymbolicExpression::from(self) + SymbolicExpression::from(rhs) + } +} + +impl Add> for SymbolicVariable { + type Output = SymbolicExpression; + + fn add(self, rhs: SymbolicExpression) -> Self::Output { + SymbolicExpression::from(self) + rhs + } +} + +impl Add> for SymbolicExpression { + type Output = Self; + + fn add(self, rhs: SymbolicVariable) -> Self::Output { + self + Self::from(rhs) + } +} + +impl Sub for SymbolicVariable { + type Output = SymbolicExpression; + + fn sub(self, rhs: Self) -> Self::Output { + SymbolicExpression::from(self) - SymbolicExpression::from(rhs) + } +} + +impl Sub for SymbolicVariable { + type Output = SymbolicExpression; + + fn sub(self, rhs: F) -> Self::Output { + SymbolicExpression::from(self) - SymbolicExpression::from(rhs) + } +} + +impl Sub> for SymbolicVariable { + type Output = SymbolicExpression; + + fn sub(self, rhs: SymbolicExpression) -> Self::Output { + SymbolicExpression::from(self) - rhs + } +} + +impl Sub> for SymbolicExpression { + type Output = Self; + + fn sub(self, rhs: SymbolicVariable) -> Self::Output { + self - Self::from(rhs) + } +} + +impl Mul for SymbolicVariable { + type Output = SymbolicExpression; + + fn mul(self, rhs: Self) -> Self::Output { + SymbolicExpression::from(self) * SymbolicExpression::from(rhs) + } +} + +impl Mul for SymbolicVariable { + type Output = SymbolicExpression; + + fn mul(self, rhs: F) -> Self::Output { + SymbolicExpression::from(self) * SymbolicExpression::from(rhs) + } +} + +impl Mul> for SymbolicVariable { + type Output = SymbolicExpression; + + fn mul(self, rhs: SymbolicExpression) -> Self::Output { + SymbolicExpression::from(self) * rhs + } +} + +impl Mul> for SymbolicExpression { + type Output = Self; + + fn mul(self, rhs: SymbolicVariable) -> Self::Output { + self * Self::from(rhs) + } +} diff --git a/memory/src/lib.rs b/memory/src/lib.rs index c82c2535..65973100 100644 --- a/memory/src/lib.rs +++ b/memory/src/lib.rs @@ -10,11 +10,11 @@ use core::mem::transmute; use p3_air::VirtualPairCol; use p3_field::{Field, PrimeField}; use p3_matrix::dense::RowMajorMatrix; -use p3_maybe_rayon::*; -use p3_uni_stark::StarkConfig; +use p3_maybe_rayon::prelude::*; use valida_bus::MachineWithMemBus; +use valida_machine::config::StarkConfig; use valida_machine::{BusArgument, Chip, Interaction, Machine, Word}; -use valida_util::batch_multiplicative_inverse; +use valida_util::batch_multiplicative_inverse_allowing_zero; pub mod columns; pub mod stark; @@ -293,7 +293,7 @@ impl MemoryChip { for n in 0..(rows.len() - 1) { let addr = ops[n].1.get_address(); let addr_next = ops[n + 1].1.get_address(); - let value = if (addr_next - addr) != 0 { + let value = if addr_next != addr { addr_next - addr } else { let clk = ops[n].0; @@ -305,7 +305,7 @@ impl MemoryChip { } // Compute `diff_inv` - let diff_inv = batch_multiplicative_inverse(diff.clone()); + let diff_inv = batch_multiplicative_inverse_allowing_zero(diff.clone()); // Set trace values for n in 0..(rows.len() - 1) { diff --git a/native_field/src/lib.rs b/native_field/src/lib.rs index d7534865..aac71118 100644 --- a/native_field/src/lib.rs +++ b/native_field/src/lib.rs @@ -16,8 +16,8 @@ use valida_util::pad_to_power_of_two; use p3_air::VirtualPairCol; use p3_field::{AbstractField, Field, PrimeField32}; use p3_matrix::dense::RowMajorMatrix; -use p3_maybe_rayon::*; -use p3_uni_stark::StarkConfig; +use p3_maybe_rayon::prelude::*; +use valida_machine::config::StarkConfig; pub mod columns; pub mod stark; diff --git a/output/src/lib.rs b/output/src/lib.rs index 158e24cb..87b38860 100644 --- a/output/src/lib.rs +++ b/output/src/lib.rs @@ -11,8 +11,8 @@ use valida_opcodes::WRITE; use p3_air::VirtualPairCol; use p3_field::{AbstractField, Field}; use p3_matrix::dense::RowMajorMatrix; -use p3_maybe_rayon::*; -use p3_uni_stark::StarkConfig; +use p3_maybe_rayon::prelude::*; +use valida_machine::config::StarkConfig; use valida_util::pad_to_power_of_two; pub mod columns; diff --git a/program/src/lib.rs b/program/src/lib.rs index bc49796b..151300cf 100644 --- a/program/src/lib.rs +++ b/program/src/lib.rs @@ -12,7 +12,7 @@ use valida_util::pad_to_power_of_two; use p3_air::VirtualPairCol; use p3_field::{AbstractField, Field}; use p3_matrix::dense::RowMajorMatrix; -use p3_uni_stark::StarkConfig; +use valida_machine::config::StarkConfig; pub mod columns; pub mod stark; diff --git a/range/src/lib.rs b/range/src/lib.rs index 3d8ad826..aa988357 100644 --- a/range/src/lib.rs +++ b/range/src/lib.rs @@ -14,7 +14,7 @@ use valida_machine::{Chip, Machine, Word}; use p3_air::VirtualPairCol; use p3_field::{AbstractField, Field}; use p3_matrix::dense::RowMajorMatrix; -use p3_uni_stark::StarkConfig; +use valida_machine::config::StarkConfig; pub mod columns; pub mod stark; diff --git a/util/src/lib.rs b/util/src/lib.rs index af7a0f0f..a7c92194 100644 --- a/util/src/lib.rs +++ b/util/src/lib.rs @@ -3,7 +3,6 @@ extern crate alloc; use alloc::vec::Vec; - use p3_field::Field; /// Returns `[0, ..., N - 1]`. @@ -17,9 +16,9 @@ pub const fn indices_arr() -> [usize; N] { indices_arr } -/// Calculates and returns the multiplicative inverses of a vector of field elements, with zero +/// Calculates and returns the multiplicative inverses of each field element, with zero /// values remaining unchanged. -pub fn batch_multiplicative_inverse(values: Vec) -> Vec { +pub fn batch_multiplicative_inverse_allowing_zero(values: Vec) -> Vec { // Check if values are zero, and construct a new vector with only nonzero values let mut nonzero_values = Vec::with_capacity(values.len()); let mut indices = Vec::with_capacity(values.len());