Skip to content

Commit

Permalink
Finishing the prover
Browse files Browse the repository at this point in the history
  • Loading branch information
dlubarov committed Jan 23, 2024
1 parent 9bfe764 commit 213198c
Show file tree
Hide file tree
Showing 28 changed files with 1,227 additions and 107 deletions.
2 changes: 1 addition & 1 deletion alu_u32/src/add/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use valida_range::MachineWithRangeChip;
use p3_air::VirtualPairCol;
use p3_field::{AbstractField, Field, PrimeField};
use p3_matrix::dense::RowMajorMatrix;
use p3_maybe_rayon::*;
use p3_maybe_rayon::prelude::*;
use valida_machine::config::StarkConfig;
use valida_util::pad_to_power_of_two;

Expand Down
2 changes: 1 addition & 1 deletion alu_u32/src/bitwise/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use valida_opcodes::{AND32, OR32, XOR32};
use p3_air::VirtualPairCol;
use p3_field::{AbstractField, Field, PrimeField};
use p3_matrix::dense::RowMajorMatrix;
use p3_maybe_rayon::*;
use p3_maybe_rayon::prelude::*;
use valida_machine::config::StarkConfig;
use valida_util::pad_to_power_of_two;

Expand Down
2 changes: 1 addition & 1 deletion alu_u32/src/div/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use valida_range::MachineWithRangeChip;
use p3_air::VirtualPairCol;
use p3_field::{AbstractField, Field, PrimeField};
use p3_matrix::dense::RowMajorMatrix;
use p3_maybe_rayon::*;
use p3_maybe_rayon::prelude::*;
use valida_machine::config::StarkConfig;
use valida_util::pad_to_power_of_two;

Expand Down
2 changes: 1 addition & 1 deletion alu_u32/src/lt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use valida_opcodes::LT32;
use p3_air::VirtualPairCol;
use p3_field::{AbstractField, Field, PrimeField};
use p3_matrix::dense::RowMajorMatrix;
use p3_maybe_rayon::*;
use p3_maybe_rayon::prelude::*;
use valida_machine::config::StarkConfig;
use valida_util::pad_to_power_of_two;

Expand Down
2 changes: 1 addition & 1 deletion alu_u32/src/shift/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use valida_opcodes::{DIV32, MUL32, SDIV32, SHL32, SHR32, SRA32};
use p3_air::VirtualPairCol;
use p3_field::{AbstractField, Field, PrimeField};
use p3_matrix::dense::RowMajorMatrix;
use p3_maybe_rayon::*;
use p3_maybe_rayon::prelude::*;
use valida_machine::config::StarkConfig;
use valida_util::pad_to_power_of_two;

Expand Down
2 changes: 1 addition & 1 deletion alu_u32/src/sub/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use valida_range::MachineWithRangeChip;
use p3_air::VirtualPairCol;
use p3_field::{AbstractField, Field, PrimeField};
use p3_matrix::dense::RowMajorMatrix;
use p3_maybe_rayon::*;
use p3_maybe_rayon::prelude::*;
use valida_machine::config::StarkConfig;
use valida_util::pad_to_power_of_two;

Expand Down
2 changes: 1 addition & 1 deletion basic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use valida_output::{MachineWithOutputChip, OutputChip, WriteInstruction};
use valida_program::{MachineWithProgramChip, ProgramChip};
use valida_range::{MachineWithRangeChip, RangeCheckerChip};

use p3_maybe_rayon::*;
use p3_maybe_rayon::prelude::*;
use valida_machine::config::StarkConfig;

#[derive(Machine, Default)]
Expand Down
3 changes: 2 additions & 1 deletion basic/tests/test_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,8 @@ fn prove_fibonacci() {

type Quotient = QuotientMmcs<Val, Challenge, ValMmcs>;
type MyFriConfig = FriConfigImpl<Val, Challenge, Quotient, ChallengeMmcs, Challenger>;
let fri_config = MyFriConfig::new(40, challenge_mmcs);
// TODO: Change log_blowup from 2 to 1 once degree >3 constraints are eliminated.
let fri_config = MyFriConfig::new(2, 40, challenge_mmcs);
let ldt = FriLdt { config: fri_config };

type Pcs = FriBasedPcs<MyFriConfig, ValMmcs, Dft, Challenger>;
Expand Down
5 changes: 3 additions & 2 deletions cpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use valida_util::batch_multiplicative_inverse;
use p3_air::VirtualPairCol;
use p3_field::{AbstractField, Field, PrimeField};
use p3_matrix::dense::RowMajorMatrix;
use p3_maybe_rayon::*;
use p3_maybe_rayon::prelude::*;
use valida_machine::config::StarkConfig;

pub mod columns;
Expand Down Expand Up @@ -71,7 +71,8 @@ where
fn generate_trace(&self, machine: &M) -> RowMajorMatrix<SC::Val> {
let mut rows = self
.operations
.par_iter()
.as_slice()
.into_par_iter()
.enumerate()
.map(|(n, op)| self.op_to_row::<M, SC>(n, op, machine))
.collect::<Vec<_>>();
Expand Down
113 changes: 83 additions & 30 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,22 +186,46 @@ fn prove_method(chips: &[&Field]) -> TokenStream2 {
})
.collect::<TokenStream2>();

let prove_starks = chips
let quotient_degree_calls = chips
.iter()
.map(|chip| {
let chip_name = chip.ident.as_ref().unwrap();
quote! {
get_log_quotient_degree::<Self, SC, _>(self, self.#chip_name()),
}
})
.collect::<TokenStream2>();

let compute_quotients = chips
.iter()
.enumerate()
.map(|(n, chip)| {
.map(|(i, chip)| {
let chip_name = chip.ident.as_ref().unwrap();
quote! {
#[cfg(debug_assertions)]
check_constraints::<Self, _, SC>(
self,
self.#chip_name(),
&main_traces[#n],
&perm_traces[#n],
&main_traces[#i],
&perm_traces[#i],
&perm_challenges,
);

chip_proofs.push(prove(self, config, self.#chip_name(), &mut challenger));
// TODO: Needlessly regenerating preprocessed_trace()
let ppt: Option<RowMajorMatrix<SC::Val>> = self.#chip_name().preprocessed_trace();
let preprocessed_trace_lde = ppt.map(|trace| preprocessed_trace_ldes.remove(0));

quotients.push(quotient(
self,
config,
self.#chip_name(),
log_degrees[#i],
preprocessed_trace_lde,
main_trace_ldes.remove(0),
perm_trace_ldes.remove(0),
&perm_challenges,
alpha,
));
}
})
.collect::<TokenStream2>();
Expand All @@ -211,26 +235,45 @@ fn prove_method(chips: &[&Field]) -> TokenStream2 {
fn prove<SC: StarkConfig<Val = F>>(&self, config: &SC) -> ::valida_machine::proof::MachineProof<SC>
{
use ::valida_machine::__internal::*;
use ::valida_machine::__internal::p3_air::{BaseAir};
use ::valida_machine::__internal::p3_challenger::{CanObserve, FieldChallenger};
use ::valida_machine::__internal::p3_commit::{Pcs, UnivariatePcs};
use ::valida_machine::__internal::p3_commit::{Pcs, UnivariatePcs, UnivariatePcsWithLde};
use ::valida_machine::__internal::p3_matrix::{Matrix, dense::RowMajorMatrix};
use ::valida_machine::__internal::p3_util::log2_strict_usize;
use ::valida_machine::chip::generate_permutation_trace;
use ::valida_machine::proof::MachineProof;
use ::valida_machine::proof::{MachineProof, ChipProof, Commitments};
use alloc::vec;
use alloc::vec::Vec;
use alloc::boxed::Box;

let mut chips: [Box<&dyn Chip<Self, SC>>; #num_chips] = [ #chip_list ];
let log_quotient_degrees: [usize; #num_chips] = [ #quotient_degree_calls ];

let mut challenger = config.challenger();

let main_traces: [RowMajorMatrix<SC::Val>; #num_chips] = tracing::info_span!("generate main traces")
.in_scope(||
chips.par_iter().map(|chip| {
chip.generate_trace(self)
}).collect::<Vec<_>>().try_into().unwrap()
);
let pcs = config.pcs();

let preprocessed_traces: Vec<RowMajorMatrix<SC::Val>> =
tracing::info_span!("generate preprocessed traces")
.in_scope(||
chips.par_iter()
.flat_map(|chip| chip.preprocessed_trace())
.collect::<Vec<_>>()
);

let (preprocessed_commit, preprocessed_data) =
tracing::info_span!("commit to preprocessed traces")
.in_scope(|| pcs.commit_batches(preprocessed_traces.to_vec()));
challenger.observe(preprocessed_commit.clone());
let mut preprocessed_trace_ldes = pcs.get_ldes(&preprocessed_data);

let main_traces: [RowMajorMatrix<SC::Val>; #num_chips] =
tracing::info_span!("generate main traces")
.in_scope(||
chips.par_iter()
.map(|chip| chip.generate_trace(self))
.collect::<Vec<_>>()
.try_into().unwrap()
);

let degrees: [usize; #num_chips] = main_traces.iter()
.map(|trace| trace.height())
Expand All @@ -240,10 +283,9 @@ fn prove_method(chips: &[&Field]) -> TokenStream2 {
let g_subgroups = log_degrees.map(|log_deg| SC::Val::two_adic_generator(log_deg));

let (main_commit, main_data) = tracing::info_span!("commit to main traces")
.in_scope(||
config.pcs().commit_batches(main_traces.to_vec())
);
.in_scope(|| pcs.commit_batches(main_traces.to_vec()));
challenger.observe(main_commit.clone());
let mut main_trace_ldes = pcs.get_ldes(&main_data);

let mut perm_challenges = Vec::new();
for _ in 0..3 {
Expand All @@ -259,12 +301,23 @@ fn prove_method(chips: &[&Field]) -> TokenStream2 {

let (perm_commit, perm_data) = tracing::info_span!("commit to permutation traces")
.in_scope(|| {
let flattened_perm_traces = perm_traces.iter().map(|trace| {
trace.flatten_to_base()
}).collect::<Vec<_>>();
config.pcs().commit_batches(flattened_perm_traces)
let flattened_perm_traces = perm_traces.iter()
.map(|trace| trace.flatten_to_base())
.collect::<Vec<_>>();
pcs.commit_batches(flattened_perm_traces)
});
challenger.observe(perm_commit.clone());
let mut perm_trace_ldes = pcs.get_ldes(&perm_data);

let alpha: SC::Challenge = challenger.sample_ext_element();

let mut quotients: Vec<RowMajorMatrix<SC::Val>> = vec![];
#compute_quotients
let (quotient_commit, quotient_data) = tracing::info_span!("commit to quotient chunks")
.in_scope(|| pcs.commit_batches(quotients.to_vec()));

#[cfg(debug_assertions)]
check_cumulative_sums(&perm_traces[..]);

let zeta: SC::Challenge = challenger.sample_ext_element();
let zeta_and_next: [Vec<SC::Challenge>; #num_chips] =
Expand All @@ -275,19 +328,19 @@ fn prove_method(chips: &[&Field]) -> TokenStream2 {
// TODO: Enable when we have quotient computation
// (&quotient_data, &[zeta.exp_power_of_2(log_quotient_degree)]),
];
let (openings, opening_proof) = config.pcs().open_multi_batches(
let (openings, opening_proof) = pcs.open_multi_batches(
&prover_data_and_points, &mut challenger);

let mut chip_proofs = vec![];
#prove_starks

#[cfg(debug_assertions)]
check_cumulative_sums(&perm_traces[..]);

let commitments = Commitments {
main_trace: main_commit,
perm_trace: perm_commit,
quotient_chunks: quotient_commit,
};
let chip_proofs = vec![]; // TODO
MachineProof {
// opening_proof,
commitments,
opening_proof,
chip_proofs,
phantom: core::marker::PhantomData,
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions machine/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ std = []
[dependencies]
byteorder = "1.4.3"
itertools = "0.10.3"
serde = { version = "1.0", default-features = false, features = ["derive"] }
tracing = "0.1.37"

p3-air = { path = "../../Plonky3/air" }
p3-baby-bear = { path = "../../Plonky3/baby-bear" }
Expand All @@ -20,6 +22,7 @@ p3-dft = { path = "../../Plonky3/dft" }
p3-field = { path = "../../Plonky3/field" }
p3-matrix = { path = "../../Plonky3/matrix" }
p3-maybe-rayon = { path = "../../Plonky3/maybe-rayon" }
p3-uni-stark = { path = "../../Plonky3/uni-stark" }
p3-util = { path = "../../Plonky3/util" }

valida-util = { path = "../util" }
8 changes: 4 additions & 4 deletions machine/src/__internal/check_constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ use crate::__internal::DebugConstraintBuilder;
use crate::chip::eval_permutation_constraints;
use crate::config::StarkConfig;
use crate::{Chip, Machine};
use p3_air::{Air, TwoRowMatrixView};
use p3_air::TwoRowMatrixView;
use p3_field::{AbstractField, Field};
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::Matrix;
use p3_matrix::MatrixRowSlices;
use p3_maybe_rayon::{MaybeIntoParIter, ParallelIterator};
use p3_maybe_rayon::prelude::*;

/// Check that all constraints vanish on the subgroup.
pub fn check_constraints<M, A, SC>(
Expand All @@ -17,8 +17,8 @@ pub fn check_constraints<M, A, SC>(
perm: &RowMajorMatrix<SC::Challenge>,
perm_challenges: &[SC::Challenge],
) where
M: Machine<SC::Val> + Sync,
A: Chip<M, SC> + for<'a> Air<DebugConstraintBuilder<'a, M, SC>>,
M: Machine<SC::Val>,
A: Chip<M, SC>,
SC: StarkConfig,
{
assert_eq!(main.height(), perm.height());
Expand Down
33 changes: 19 additions & 14 deletions machine/src/__internal/folding_builder.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
use crate::{Machine, ValidaAirBuilder};
use p3_air::{AirBuilder, PairBuilder, PermutationAirBuilder, TwoRowMatrixView};
use p3_field::AbstractField;
use valida_machine::config::StarkConfig;

pub struct ProverConstraintFolder<'a, M: Machine<SC::Val>, SC: StarkConfig> {
pub(crate) machine: &'a M,
pub(crate) main: TwoRowMatrixView<'a, SC::Val>,
pub(crate) preprocessed: TwoRowMatrixView<'a, SC::Val>,
pub(crate) perm: TwoRowMatrixView<'a, SC::Challenge>,
pub(crate) preprocessed: TwoRowMatrixView<'a, SC::PackedVal>,
pub(crate) main: TwoRowMatrixView<'a, SC::PackedVal>,
pub(crate) perm: TwoRowMatrixView<'a, SC::PackedChallenge>,
pub(crate) perm_challenges: &'a [SC::Challenge],
pub(crate) is_first_row: SC::Val,
pub(crate) is_last_row: SC::Val,
pub(crate) is_transition: SC::Val,
pub(crate) is_first_row: SC::PackedVal,
pub(crate) is_last_row: SC::PackedVal,
pub(crate) is_transition: SC::PackedVal,
pub(crate) alpha: SC::Challenge,
pub(crate) accumulator: SC::PackedChallenge,
}

impl<'a, M, SC> AirBuilder for ProverConstraintFolder<'a, M, SC>
Expand All @@ -19,9 +22,9 @@ where
SC: StarkConfig,
{
type F = SC::Val;
type Expr = SC::Val; // TODO: PackedVal
type Var = SC::Val; // TODO: PackedVal
type M = TwoRowMatrixView<'a, SC::Val>; // TODO: PackedVal
type Expr = SC::PackedVal;
type Var = SC::PackedVal;
type M = TwoRowMatrixView<'a, SC::PackedVal>;

fn main(&self) -> Self::M {
self.main
Expand All @@ -43,8 +46,10 @@ where
}
}

fn assert_zero<I: Into<Self::Expr>>(&mut self, _x: I) {
// TODO
fn assert_zero<I: Into<Self::Expr>>(&mut self, x: I) {
let x: SC::PackedVal = x.into();
self.accumulator *= SC::PackedChallenge::from_f(self.alpha);
self.accumulator += x;
}
}

Expand All @@ -64,9 +69,9 @@ where
SC: StarkConfig,
{
type EF = SC::Challenge;
type ExprEF = SC::Challenge;
type VarEF = SC::Challenge;
type MP = TwoRowMatrixView<'a, SC::Challenge>; // TODO: packed challenge?
type ExprEF = SC::PackedChallenge;
type VarEF = SC::PackedChallenge;
type MP = TwoRowMatrixView<'a, SC::PackedChallenge>;

fn permutation(&self) -> Self::MP {
self.perm
Expand Down
Loading

0 comments on commit 213198c

Please sign in to comment.