diff --git a/Cargo.lock b/Cargo.lock index 76fac0dfe..13cca77b8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -845,6 +845,23 @@ version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" +[[package]] +name = "gkr_iop" +version = "0.1.0" +dependencies = [ + "ark-std", + "ff", + "ff_ext", + "goldilocks", + "itertools 0.13.0", + "multilinear_extensions", + "rand", + "rayon", + "subprotocols", + "thiserror 2.0.8", + "transcript", +] + [[package]] name = "glob" version = "0.3.2" @@ -2065,6 +2082,23 @@ dependencies = [ "syn 2.0.90", ] +[[package]] +name = "subprotocols" +version = "0.1.0" +dependencies = [ + "ark-std", + "criterion", + "ff", + "ff_ext", + "goldilocks", + "itertools 0.13.0", + "multilinear_extensions", + "rand", + "rayon", + "thiserror 2.0.8", + "transcript", +] + [[package]] name = "subtle" version = "2.6.1" diff --git a/Cargo.toml b/Cargo.toml index 2d0b62b5d..d9dd1f2a6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,8 @@ members = [ "multilinear_extensions", "sumcheck_macro", "poseidon", + "gkr_iop", + "subprotocols", "sumcheck", "transcript", ] @@ -31,6 +33,7 @@ cfg-if = "1.0" criterion = { version = "0.5", features = ["html_reports"] } crossbeam-channel = "0.5" ff = "0.13" +ff_ext = { path = "./ff_ext" } goldilocks = { git = "https://github.com/scroll-tech/ceno-Goldilocks" } itertools = "0.13" num-derive = "0.4" @@ -51,12 +54,15 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" strum = "0.26" strum_macros = "0.26" +subprotocols = { path = "./subprotocols" } +thiserror = "2.0.3" tiny-keccak = { version = "2.0.2", features = ["keccak"] } tracing = { version = "0.1", features = [ "attributes", ] } tracing-forest = { version = "0.1.6" } tracing-subscriber = { version = "0.3", features = ["env-filter"] } +transcript = { path = "./transcript" } [profile.dev] lto = "thin" diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 97a7c9364..bb444cb36 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -20,11 +20,11 @@ serde_json.workspace = true base64 = "0.22" ceno_emul = { path = "../ceno_emul" } -ff_ext = { path = "../ff_ext" } +ff_ext.workspace = true mpcs = { path = "../mpcs" } multilinear_extensions = { version = "0", path = "../multilinear_extensions" } sumcheck = { version = "0", path = "../sumcheck" } -transcript = { path = "../transcript" } +transcript.workspace = true itertools.workspace = true num-traits.workspace = true diff --git a/gkr_iop/Cargo.toml b/gkr_iop/Cargo.toml new file mode 100644 index 000000000..ab4e19ae5 --- /dev/null +++ b/gkr_iop/Cargo.toml @@ -0,0 +1,23 @@ +[package] +categories.workspace = true +description = "GKR IOP protocol implementation" +edition.workspace = true +keywords.workspace = true +license.workspace = true +name = "gkr_iop" +readme.workspace = true +repository.workspace = true +version.workspace = true + +[dependencies] +ark-std.workspace = true +ff.workspace = true +ff_ext.workspace = true +goldilocks.workspace = true +itertools.workspace = true +multilinear_extensions = { version = "0.1.0", path = "../multilinear_extensions" } +rand.workspace = true +rayon.workspace = true +subprotocols.workspace = true +thiserror.workspace = true +transcript.workspace = true diff --git a/gkr_iop/examples/multi_layer_logup.rs b/gkr_iop/examples/multi_layer_logup.rs new file mode 100644 index 000000000..857c78432 --- /dev/null +++ b/gkr_iop/examples/multi_layer_logup.rs @@ -0,0 +1,294 @@ +use std::{marker::PhantomData, mem, sync::Arc}; + +use ff_ext::ExtensionField; +use gkr_iop::{ + ProtocolBuilder, ProtocolWitnessGenerator, + chip::Chip, + evaluation::{EvalExpression, PointAndEval}, + gkr::{ + GKRCircuitWitness, GKRProverOutput, + layer::{Layer, LayerType, LayerWitness}, + }, +}; +use goldilocks::GoldilocksExt2; +use itertools::{Itertools, izip}; +use rand::{Rng, rngs::OsRng}; +use subprotocols::expression::{Constant, Expression}; +use transcript::{BasicTranscript, Transcript}; + +#[cfg(debug_assertions)] +use gkr_iop::gkr::mock::MockProver; +#[cfg(debug_assertions)] +use subprotocols::expression::VectorType; + +type E = GoldilocksExt2; + +#[derive(Clone, Debug, Default)] +struct TowerParams { + height: usize, +} + +#[derive(Clone, Debug, Default)] +struct TowerChipLayout { + params: TowerParams, + + // Committed poly indices. + committed_table_id: usize, + committed_count_id: usize, + + lookup_challenge: Constant, + + output_cumulative_sum: [EvalExpression; 2], + + _field: PhantomData, +} + +impl ProtocolBuilder for TowerChipLayout { + type Params = TowerParams; + + fn init(params: Self::Params) -> Self { + Self { + params, + ..Default::default() + } + } + + fn build_commit_phase(&mut self, chip: &mut Chip) { + [self.committed_table_id, self.committed_count_id] = chip.allocate_committed_base(); + [self.lookup_challenge] = chip.allocate_challenges(); + } + + fn build_gkr_phase(&mut self, chip: &mut Chip) { + let height = self.params.height; + let lookup_challenge = Expression::Const(self.lookup_challenge.clone()); + + self.output_cumulative_sum = chip.allocate_output_evals(); + + // Tower layers + let ([updated_table, count], challenges) = (0..height).fold( + (self.output_cumulative_sum.clone(), vec![]), + |([den, num], challenges), i| { + let [den_0, den_1, num_0, num_1] = if i == height - 1 { + // Allocate witnesses in the extension field, except numerator inputs in the base field. + let ([num_0, num_1], [den_0, den_1]) = chip.allocate_wits_in_layer(); + [den_0, den_1, num_0, num_1] + } else { + let ([], [den_0, den_1, num_0, num_1]) = chip.allocate_wits_in_layer(); + [den_0, den_1, num_0, num_1] + }; + + let [den_expr_0, den_expr_1, num_expr_0, num_expr_1]: [Expression; 4] = [ + den_0.0.into(), + den_1.0.into(), + num_0.0.into(), + num_1.0.into(), + ]; + let (in_bases, in_exts) = if i == height - 1 { + (vec![num_0.1.clone(), num_1.1.clone()], vec![ + den_0.1.clone(), + den_1.1.clone(), + ]) + } else { + (vec![], vec![ + den_0.1.clone(), + den_1.1.clone(), + num_0.1.clone(), + num_1.1.clone(), + ]) + }; + chip.add_layer(Layer::new( + format!("Tower_layer_{}", i), + LayerType::Zerocheck, + vec![ + den_expr_0.clone() * den_expr_1.clone(), + den_expr_0 * num_expr_1 + den_expr_1 * num_expr_0, + ], + challenges, + in_bases, + in_exts, + vec![den, num], + )); + let [challenge] = chip.allocate_challenges(); + ( + [ + EvalExpression::Partition( + vec![Box::new(den_0.1), Box::new(den_1.1)], + vec![(0, challenge.clone())], + ), + EvalExpression::Partition( + vec![Box::new(num_0.1), Box::new(num_1.1)], + vec![(0, challenge.clone())], + ), + ], + vec![challenge], + ) + }, + ); + + // Preprocessing layer, compute table + challenge + let ([table], []) = chip.allocate_wits_in_layer(); + + chip.add_layer(Layer::new( + "Update_table".to_string(), + LayerType::Linear, + vec![lookup_challenge + table.0.into()], + challenges, + vec![table.1.clone()], + vec![], + vec![updated_table], + )); + + chip.allocate_base_opening(self.committed_table_id, table.1); + chip.allocate_base_opening(self.committed_count_id, count); + } +} + +pub struct TowerChipTrace { + pub table: Vec, + pub multiplicity: Vec, +} + +impl ProtocolWitnessGenerator for TowerChipLayout +where + E: ExtensionField, +{ + type Trace = TowerChipTrace; + + fn phase1_witness(&self, phase1: Self::Trace) -> Vec> { + let mut res = vec![vec![]; 2]; + res[self.committed_table_id] = phase1.table.into_iter().map(E::BaseField::from).collect(); + res[self.committed_count_id] = phase1 + .multiplicity + .into_iter() + .map(E::BaseField::from) + .collect(); + res + } + + fn gkr_witness(&self, phase1: &[Vec], challenges: &[E]) -> GKRCircuitWitness { + // Generate witnesses. + let table = &phase1[self.committed_table_id]; + let count = &phase1[self.committed_count_id]; + let beta = self.lookup_challenge.entry(challenges); + + // Compute table + beta. + let n_layers = self.params.height + 1; + let mut layer_wits = Vec::>::with_capacity(n_layers); + layer_wits.push(LayerWitness::new(vec![table.clone()], vec![])); + + // Compute den_0, den_1, num_0, num_1 for each layer. + let updated_table = table.iter().map(|x| beta + x).collect_vec(); + + let (num_0, num_1): (Vec, Vec) = count.iter().tuples().unzip(); + let (den_0, den_1): (Vec, Vec) = updated_table.into_iter().tuples().unzip(); + let (mut last_den, mut last_num): (Vec<_>, Vec<_>) = izip!(&den_0, &den_1, &num_0, &num_1) + .map(|(&den_0, &den_1, &num_0, &num_1)| (den_0 * den_1, den_0 * num_1 + den_1 * num_0)) + .unzip(); + + layer_wits.push(LayerWitness::new(vec![num_0, num_1], vec![den_0, den_1])); + + layer_wits.extend((1..self.params.height).map(|_i| { + let (den_0, den_1): (Vec, Vec) = + mem::take(&mut last_den).into_iter().tuples().unzip(); + let (num_0, num_1): (Vec, Vec) = + mem::take(&mut last_num).into_iter().tuples().unzip(); + + (last_den, last_num) = izip!(&den_0, &den_1, &num_0, &num_1) + .map(|(&den_0, &den_1, &num_0, &num_1)| { + (den_0 * den_1, den_0 * num_1 + den_1 * num_0) + }) + .unzip(); + + LayerWitness::new(vec![], vec![den_0, den_1, num_0, num_1]) + })); + layer_wits.reverse(); + + GKRCircuitWitness { layers: layer_wits } + } +} + +fn main() { + let log_size = 3; + let params = TowerParams { height: log_size }; + let (layout, chip) = TowerChipLayout::build(params); + let gkr_circuit = chip.gkr_circuit(); + + let (out_evals, gkr_proof) = { + let table = (0..1 << log_size) + .map(|_| OsRng.gen_range(0..1 << log_size as u64)) + .collect_vec(); + let count = (0..1 << log_size) + .map(|_| OsRng.gen_range(0..1 << log_size as u64)) + .collect_vec(); + let phase1_witness = layout.phase1_witness(TowerChipTrace { + table, + multiplicity: count, + }); + + let mut prover_transcript = BasicTranscript::::new(b"protocol"); + + // Omit the commit phase1 and phase2. + + let challenges = vec![ + prover_transcript + .get_and_append_challenge(b"lookup challenge") + .elements, + ]; + let gkr_witness = layout.gkr_witness(&phase1_witness, &challenges); + + #[cfg(debug_assertions)] + { + let last = gkr_witness.layers[0].exts.clone(); + MockProver::check( + gkr_circuit.clone(), + &gkr_witness, + vec![ + VectorType::Ext(vec![last[0][0] * last[1][0]]), + VectorType::Ext(vec![last[0][0] * last[3][0] + last[1][0] * last[2][0]]), + ], + challenges.clone(), + ) + .expect("Mock prover failed"); + } + + let out_evals = { + let last = gkr_witness.layers[0].exts.clone(); + let point = Arc::new(vec![]); + assert_eq!(last[0].len(), 1); + vec![ + PointAndEval { + point: point.clone(), + eval: last[0][0] * last[1][0], + }, + PointAndEval { + point, + eval: last[0][0] * last[3][0] + last[1][0] * last[2][0], + }, + ] + }; + let GKRProverOutput { gkr_proof, .. } = gkr_circuit + .prove(gkr_witness, &out_evals, &challenges, &mut prover_transcript) + .expect("Failed to prove phase"); + + // Omit the PCS opening phase. + + (out_evals, gkr_proof) + }; + + { + let mut verifier_transcript = BasicTranscript::::new(b"protocol"); + + // Omit the commit phase1 and phase2. + let challenges = vec![ + verifier_transcript + .get_and_append_challenge(b"lookup challenge") + .elements, + ]; + + gkr_circuit + .verify(gkr_proof, &out_evals, &challenges, &mut verifier_transcript) + .expect("GKR verify failed"); + + // Omit the PCS opening phase. + } +} diff --git a/gkr_iop/src/chip.rs b/gkr_iop/src/chip.rs new file mode 100644 index 000000000..b40b66f9c --- /dev/null +++ b/gkr_iop/src/chip.rs @@ -0,0 +1,28 @@ +use crate::{evaluation::EvalExpression, gkr::layer::Layer}; + +pub mod builder; +pub mod protocol; + +/// Chip stores all information required in the GKR protocol, including the +/// commit phases, the GKR phase and the opening phase. +#[derive(Clone, Debug, Default)] +pub struct Chip { + /// The number of base inputs committed in the whole protocol. + pub n_committed_bases: usize, + /// The number of ext inputs committed in the whole protocol. + pub n_committed_exts: usize, + + /// The number of challenges generated through the whole protocols + /// (except the ones inside sumcheck protocols). + pub n_challenges: usize, + /// All input evaluations generated at the end of layer protocols will be stored + /// in a vector and this is the length. + pub n_evaluations: usize, + /// The layers of the GKR circuit, in the order outputs-to-inputs. + pub layers: Vec, + + /// The polynomial index and evaluation expressions of the base inputs. + pub base_openings: Vec<(usize, EvalExpression)>, + /// The polynomial index and evaluation expressions of the ext inputs. + pub ext_openings: Vec<(usize, EvalExpression)>, +} diff --git a/gkr_iop/src/chip/builder.rs b/gkr_iop/src/chip/builder.rs new file mode 100644 index 000000000..ac742fe04 --- /dev/null +++ b/gkr_iop/src/chip/builder.rs @@ -0,0 +1,91 @@ +use std::array; + +use subprotocols::expression::{Constant, Witness}; + +use crate::{ + evaluation::EvalExpression, + gkr::layer::{Layer, LayerType}, +}; + +use super::Chip; + +impl Chip { + /// Allocate indices for committing base field polynomials. + pub fn allocate_committed_base(&mut self) -> [usize; N] { + self.n_committed_bases += N; + array::from_fn(|i| i + self.n_committed_bases - N) + } + + /// Allocate indices for committing extension field polynomials. + pub fn allocate_committed_ext(&mut self) -> [usize; N] { + self.n_committed_exts += N; + array::from_fn(|i| i + self.n_committed_exts - N) + } + + /// Allocate `Witness` and `EvalExpression` for the input polynomials in a layer. + /// Where `Witness` denotes the index and `EvalExpression` denotes the position + /// to place the evaluation of the polynomial after processing the layer prover + /// for each polynomial. This should be called at most once for each layer! + #[allow(clippy::type_complexity)] + pub fn allocate_wits_in_layer( + &mut self, + ) -> ( + [(Witness, EvalExpression); M], + [(Witness, EvalExpression); N], + ) { + let bases = array::from_fn(|i| { + ( + Witness::BasePoly(i), + EvalExpression::Single(i + self.n_evaluations), + ) + }); + self.n_evaluations += M; + let exts = array::from_fn(|i| { + ( + Witness::ExtPoly(i), + EvalExpression::Single(i + self.n_evaluations), + ) + }); + self.n_evaluations += N; + (bases, exts) + } + + /// Generate the evaluation expression for each output. + pub fn allocate_output_evals(&mut self) -> [EvalExpression; N] { + self.n_evaluations += N; + array::from_fn(|i| EvalExpression::Single(i + self.n_evaluations - N)) + } + + /// Allocate challenges. + pub fn allocate_challenges(&mut self) -> [Constant; N] { + self.n_challenges += N; + array::from_fn(|i| Constant::Challenge(i + self.n_challenges - N)) + } + + /// Allocate a PCS opening action to a base polynomial with index `wit_index`. + /// The `EvalExpression` represents the expression to compute the evaluation. + pub fn allocate_base_opening(&mut self, wit_index: usize, eval: EvalExpression) { + self.base_openings.push((wit_index, eval)); + } + + /// Allocate a PCS opening action to an ext polynomial with index `wit_index`. + /// The `EvalExpression` represents the expression to compute the evaluation. + pub fn allocate_ext_opening(&mut self, wit_index: usize, eval: EvalExpression) { + self.ext_openings.push((wit_index, eval)); + } + + /// Add a layer to the circuit. + pub fn add_layer(&mut self, layer: Layer) { + assert_eq!(layer.outs.len(), layer.exprs.len()); + match layer.ty { + LayerType::Linear => { + assert!(layer.exprs.iter().all(|expr| expr.degree() == 1)); + } + LayerType::Sumcheck => { + assert_eq!(layer.exprs.len(), 1); + } + _ => {} + } + self.layers.push(layer); + } +} diff --git a/gkr_iop/src/chip/protocol.rs b/gkr_iop/src/chip/protocol.rs new file mode 100644 index 000000000..a633791eb --- /dev/null +++ b/gkr_iop/src/chip/protocol.rs @@ -0,0 +1,16 @@ +use crate::gkr::GKRCircuit; + +use super::Chip; + +impl Chip { + /// Extract information from Chip that required in the GKR phase. + pub fn gkr_circuit(&'_ self) -> GKRCircuit<'_> { + GKRCircuit { + layers: &self.layers, + n_challenges: self.n_challenges, + n_evaluations: self.n_evaluations, + base_openings: &self.base_openings, + ext_openings: &self.ext_openings, + } + } +} diff --git a/gkr_iop/src/error.rs b/gkr_iop/src/error.rs new file mode 100644 index 000000000..02c64f96e --- /dev/null +++ b/gkr_iop/src/error.rs @@ -0,0 +1,8 @@ +use subprotocols::error::VerifierError; +use thiserror::Error; + +#[derive(Clone, Debug, Error)] +pub enum BackendError { + #[error("layer verification failed: {0:?}, {1:?}")] + LayerVerificationFailed(String, VerifierError), +} diff --git a/gkr_iop/src/evaluation.rs b/gkr_iop/src/evaluation.rs new file mode 100644 index 000000000..f9d60914c --- /dev/null +++ b/gkr_iop/src/evaluation.rs @@ -0,0 +1,90 @@ +use std::sync::Arc; + +use ff_ext::ExtensionField; +use itertools::{Itertools, izip}; +use multilinear_extensions::virtual_poly::build_eq_x_r_vec_sequential; +use subprotocols::expression::{Constant, Point}; + +/// Evaluation expression for the gkr layer reduction and PCS opening preparation. +#[derive(Clone, Debug)] +pub enum EvalExpression { + /// Single entry in the evaluation vector. + Single(usize), + /// Linear expression of an entry with the scalar and offset. + Linear(usize, Constant, Constant), + /// Merging multiple evaluations which denotes a partition of the original + /// polynomial. `(usize, Constant)` denote the modification of the point. + /// For example, when it receive a point `(p0, p1, p2, p3)` from a succeeding + /// layer, `vec![(2, c0), (4, c1)]` will modify the point to `(p0, p1, c0, p2, c1, p3)`. + /// where the indices specify how the partition applied to the original polynomial. + Partition(Vec>, Vec<(usize, Constant)>), +} + +#[derive(Clone, Debug, Default)] +pub struct PointAndEval { + pub point: Point, + pub eval: E, +} + +impl Default for EvalExpression { + fn default() -> Self { + EvalExpression::Single(0) + } +} + +impl EvalExpression { + pub fn evaluate( + &self, + evals: &[PointAndEval], + challenges: &[E], + ) -> PointAndEval { + match self { + EvalExpression::Single(i) => evals[*i].clone(), + EvalExpression::Linear(i, c0, c1) => PointAndEval { + point: evals[*i].point.clone(), + eval: evals[*i].eval * c0.evaluate(challenges) + c1.evaluate(challenges), + }, + EvalExpression::Partition(parts, indices) => { + assert!(izip!(indices.iter(), indices.iter().skip(1)).all(|(a, b)| a.0 < b.0)); + let vars = indices + .iter() + .map(|(_, c)| c.evaluate(challenges)) + .collect_vec(); + + let parts = parts + .iter() + .map(|part| part.evaluate(evals, &vars)) + .collect_vec(); + assert_eq!(parts.len(), 1 << indices.len()); + assert!(parts.iter().all(|part| part.point == parts[0].point)); + + let mut new_point = parts[0].point.to_vec(); + for (index_in_point, c) in indices { + new_point.insert(*index_in_point, c.evaluate(challenges)); + } + + let eq = build_eq_x_r_vec_sequential(&vars); + let eval = izip!(parts, &eq).fold(E::ZERO, |acc, (part, eq)| acc + part.eval * eq); + + PointAndEval { + point: Arc::new(new_point), + eval, + } + } + } + } + + pub fn entry<'a, T>(&self, evals: &'a [T]) -> &'a T { + match self { + EvalExpression::Single(i) => &evals[*i], + _ => unreachable!(), + } + } + + pub fn entry_mut<'a, T>(&self, evals: &'a mut [T]) -> &'a mut T { + match self { + EvalExpression::Single(i) => &mut evals[*i], + _ => unreachable!(), + } + } +} diff --git a/gkr_iop/src/gkr.rs b/gkr_iop/src/gkr.rs new file mode 100644 index 000000000..1e632a18f --- /dev/null +++ b/gkr_iop/src/gkr.rs @@ -0,0 +1,110 @@ +use ff_ext::ExtensionField; +use itertools::{Itertools, chain, izip}; +use layer::{Layer, LayerWitness}; +use subprotocols::{expression::Point, sumcheck::SumcheckProof}; +use transcript::Transcript; + +use crate::{ + error::BackendError, + evaluation::{EvalExpression, PointAndEval}, +}; + +pub mod layer; +pub mod mock; + +#[derive(Clone, Debug)] +pub struct GKRCircuit<'a> { + pub layers: &'a [Layer], + + pub n_challenges: usize, + pub n_evaluations: usize, + pub base_openings: &'a [(usize, EvalExpression)], + pub ext_openings: &'a [(usize, EvalExpression)], +} + +#[derive(Clone, Debug)] +pub struct GKRCircuitWitness { + pub layers: Vec>, +} + +pub struct GKRProverOutput { + pub gkr_proof: GKRProof, + pub opening_evaluations: Vec, +} + +pub struct GKRProof(pub Vec>); + +pub struct Evaluation { + pub value: E, + pub point: Point, + pub poly: usize, +} + +pub struct GKRClaims(pub Vec); + +impl GKRCircuit<'_> { + pub fn prove( + &self, + circuit_wit: GKRCircuitWitness, + out_evals: &[PointAndEval], + challenges: &[E], + transcript: &mut impl Transcript, + ) -> Result>, BackendError> + where + E: ExtensionField, + { + let mut evaluations = out_evals.to_vec(); + evaluations.resize(self.n_evaluations, PointAndEval::default()); + let mut challenges = challenges.to_vec(); + let sumcheck_proofs = izip!(self.layers, circuit_wit.layers) + .map(|(layer, layer_wit)| { + layer.prove(layer_wit, &mut evaluations, &mut challenges, transcript) + }) + .collect_vec(); + + let opening_evaluations = self.opening_evaluations(&evaluations, &challenges); + + Ok(GKRProverOutput { + gkr_proof: GKRProof(sumcheck_proofs), + opening_evaluations, + }) + } + + pub fn verify( + &self, + gkr_proof: GKRProof, + out_evals: &[PointAndEval], + challenges: &[E], + transcript: &mut impl Transcript, + ) -> Result>, BackendError> + where + E: ExtensionField, + { + let GKRProof(sumcheck_proofs) = gkr_proof; + + let mut challenges = challenges.to_vec(); + let mut evaluations = out_evals.to_vec(); + evaluations.resize(self.n_evaluations, PointAndEval::default()); + for (layer, layer_proof) in izip!(self.layers, sumcheck_proofs) { + layer.verify(layer_proof, &mut evaluations, &mut challenges, transcript)?; + } + + Ok(GKRClaims( + self.opening_evaluations(&evaluations, &challenges), + )) + } + + fn opening_evaluations( + &self, + evaluations: &[PointAndEval], + challenges: &[E], + ) -> Vec> { + chain!(self.base_openings, self.ext_openings) + .map(|(poly, eval)| { + let poly = *poly; + let PointAndEval { point, eval: value } = eval.evaluate(evaluations, challenges); + Evaluation { value, point, poly } + }) + .collect_vec() + } +} diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs new file mode 100644 index 000000000..d77fa9d3c --- /dev/null +++ b/gkr_iop/src/gkr/layer.rs @@ -0,0 +1,239 @@ +use ark_std::log2; +use ff_ext::ExtensionField; +use itertools::{chain, izip}; +use linear_layer::LinearLayer; +use subprotocols::{ + expression::{Constant, Expression, Point}, + sumcheck::{SumcheckClaims, SumcheckProof, SumcheckProverOutput}, +}; +use sumcheck_layer::SumcheckLayer; +use transcript::Transcript; +use zerocheck_layer::ZerocheckLayer; + +use crate::{ + error::BackendError, + evaluation::{EvalExpression, PointAndEval}, + utils::SliceVector, +}; + +pub mod linear_layer; +pub mod sumcheck_layer; +pub mod zerocheck_layer; + +#[derive(Clone, Debug)] +pub enum LayerType { + Sumcheck, + Zerocheck, + Linear, +} + +#[derive(Clone, Debug)] +pub struct Layer { + pub name: String, + pub ty: LayerType, + /// Challenges generated at the beginning of the layer protocol. + pub challenges: Vec, + /// Expressions to prove in this layer. For zerocheck and linear layers, each + /// expression corresponds to an output. While in sumcheck, there is only 1 + /// expression, which corresponds to the sum of all outputs. This design is + /// for the convenience when building the following expression: + /// `e_0 + beta * e_1 = sum_x (eq(p_0, x) + beta * eq(p_1, x)) expr(x)`. + /// where `vec![e_0, beta * e_1]` will be the output evaluation expressions. + pub exprs: Vec, + /// Positions to place the evaluations of the base inputs of this layer. + pub in_bases: Vec, + /// Positions to place the evaluations of the ext inputs of this layer. + pub in_exts: Vec, + /// The expressions of the evaluations from the succeeding layers, which are + /// connected to the outputs of this layer. + pub outs: Vec, +} + +#[derive(Clone, Debug)] +pub struct LayerWitness { + pub bases: Vec>, + pub exts: Vec>, + pub num_vars: usize, +} + +impl Layer { + #[allow(clippy::too_many_arguments)] + pub fn new( + name: String, + ty: LayerType, + exprs: Vec, + challenges: Vec, + in_bases: Vec, + in_exts: Vec, + outs: Vec, + ) -> Self { + Self { + name, + ty, + challenges, + exprs, + in_bases, + in_exts, + outs, + } + } + + #[allow(clippy::too_many_arguments)] + pub fn prove>( + &self, + wit: LayerWitness, + claims: &mut [PointAndEval], + challenges: &mut Vec, + transcript: &mut Trans, + ) -> SumcheckProof { + self.update_challenges(challenges, transcript); + #[allow(unused)] + let (sigmas, out_points) = self.sigmas_and_points(claims, challenges); + + let SumcheckProverOutput { + point: in_point, + proof, + } = match self.ty { + LayerType::Sumcheck => >::prove( + self, + wit, + &out_points.slice_vector(), + challenges, + transcript, + ), + LayerType::Zerocheck => >::prove( + self, + wit, + &out_points.slice_vector(), + challenges, + transcript, + ), + LayerType::Linear => { + assert!(out_points.iter().all(|point| point == &out_points[0])); + >::prove(self, wit, &out_points[0], transcript) + } + }; + + self.update_claims( + claims, + &proof.base_mle_evals, + &proof.ext_mle_evals, + &in_point, + ); + + proof + } + + pub fn verify>( + &self, + proof: SumcheckProof, + claims: &mut [PointAndEval], + challenges: &mut Vec, + transcript: &mut Trans, + ) -> Result<(), BackendError> { + self.update_challenges(challenges, transcript); + let (sigmas, points) = self.sigmas_and_points(claims, challenges); + + let SumcheckClaims { + in_point, + base_mle_evals, + ext_mle_evals, + } = match self.ty { + LayerType::Sumcheck => >::verify( + self, + proof, + &sigmas.iter().sum(), + points.slice_vector(), + challenges, + transcript, + )?, + LayerType::Zerocheck => >::verify( + self, + proof, + sigmas, + points.slice_vector(), + challenges, + transcript, + )?, + LayerType::Linear => { + assert!(points.iter().all(|point| point == &points[0])); + >::verify( + self, proof, &sigmas, &points[0], challenges, transcript, + )? + } + }; + + self.update_claims(claims, &base_mle_evals, &ext_mle_evals, &in_point); + + Ok(()) + } + + fn sigmas_and_points( + &self, + claims: &[PointAndEval], + challenges: &[E], + ) -> (Vec, Vec>) { + self.outs + .iter() + .map(|out| { + let tmp = out.evaluate(claims, challenges); + (tmp.eval, tmp.point) + }) + .unzip() + } + + fn update_challenges( + &self, + challenges: &mut Vec, + transcript: &mut impl Transcript, + ) { + for challenge in &self.challenges { + let value = transcript.get_and_append_challenge(b"linear layer challenge"); + match challenge { + Constant::Challenge(i) => { + if challenges.len() <= *i { + challenges.resize(*i + 1, E::ZERO); + } + challenges[*i] = value.elements; + } + _ => unreachable!(), + } + } + } + + fn update_claims( + &self, + claims: &mut [PointAndEval], + base_mle_evals: &[E], + ext_mle_evals: &[E], + point: &Point, + ) { + for (value, pos) in izip!(chain![base_mle_evals, ext_mle_evals], chain![ + &self.in_bases, + &self.in_exts + ]) { + *(pos.entry_mut(claims)) = PointAndEval { + point: point.clone(), + eval: *value, + }; + } + } +} + +impl LayerWitness { + pub fn new(bases: Vec>, exts: Vec>) -> Self { + assert!(!bases.is_empty() || !exts.is_empty()); + let num_vars = if bases.is_empty() { + log2(exts[0].len()) + } else { + log2(bases[0].len()) + } as usize; + assert!(bases.iter().all(|b| b.len() == 1 << num_vars)); + assert!(exts.iter().all(|e| e.len() == 1 << num_vars)); + Self { + bases, + exts, + num_vars, + } + } +} diff --git a/gkr_iop/src/gkr/layer/linear_layer.rs b/gkr_iop/src/gkr/layer/linear_layer.rs new file mode 100644 index 000000000..40d8a6ec8 --- /dev/null +++ b/gkr_iop/src/gkr/layer/linear_layer.rs @@ -0,0 +1,105 @@ +use ff_ext::ExtensionField; +use itertools::{Itertools, izip}; +use subprotocols::{ + error::VerifierError, + expression::Point, + sumcheck::{SumcheckClaims, SumcheckProof, SumcheckProverOutput}, + utils::{evaluate_mle_ext, evaluate_mle_inplace}, +}; +use transcript::Transcript; + +use crate::error::BackendError; + +use super::{Layer, LayerWitness}; + +pub trait LinearLayer { + fn prove( + &self, + wit: LayerWitness, + out_point: &Point, + transcript: &mut impl Transcript, + ) -> SumcheckProverOutput; + + fn verify( + &self, + proof: SumcheckProof, + sigmas: &[E], + out_point: &Point, + challenges: &[E], + transcript: &mut impl Transcript, + ) -> Result, BackendError>; +} + +impl LinearLayer for Layer { + fn prove( + &self, + wit: LayerWitness, + out_point: &Point, + transcript: &mut impl Transcript, + ) -> SumcheckProverOutput { + let base_mle_evals = wit + .bases + .iter() + .map(|base| evaluate_mle_ext(base, out_point)) + .collect_vec(); + + transcript.append_field_element_exts(&base_mle_evals); + + let ext_mle_evals = wit + .exts + .into_iter() + .map(|mut ext| evaluate_mle_inplace(&mut ext, out_point)) + .collect_vec(); + + transcript.append_field_element_exts(&ext_mle_evals); + + SumcheckProverOutput { + proof: SumcheckProof { + univariate_polys: vec![], + ext_mle_evals, + base_mle_evals, + }, + point: out_point.clone(), + } + } + + fn verify( + &self, + proof: SumcheckProof, + sigmas: &[E], + out_point: &Point, + challenges: &[E], + transcript: &mut impl Transcript, + ) -> Result, BackendError> { + let SumcheckProof { + univariate_polys: _, + ext_mle_evals, + base_mle_evals, + } = proof; + + transcript.append_field_element_exts(&ext_mle_evals); + transcript.append_field_element_exts(&base_mle_evals); + + for (sigma, expr) in izip!(sigmas, &self.exprs) { + let got = expr.evaluate( + &ext_mle_evals, + &base_mle_evals, + &[out_point], + &[], + challenges, + ); + if *sigma != got { + return Err(BackendError::LayerVerificationFailed( + self.name.clone(), + VerifierError::::ClaimNotMatch(expr.clone(), *sigma, got), + )); + } + } + + Ok(SumcheckClaims { + base_mle_evals, + ext_mle_evals, + in_point: out_point.clone(), + }) + } +} diff --git a/gkr_iop/src/gkr/layer/sumcheck_layer.rs b/gkr_iop/src/gkr/layer/sumcheck_layer.rs new file mode 100644 index 000000000..afb65d9a2 --- /dev/null +++ b/gkr_iop/src/gkr/layer/sumcheck_layer.rs @@ -0,0 +1,75 @@ +use ff_ext::ExtensionField; +use subprotocols::sumcheck::{ + SumcheckClaims, SumcheckProof, SumcheckProverOutput, SumcheckProverState, SumcheckVerifierState, +}; +use transcript::Transcript; + +use crate::{ + error::BackendError, + utils::{SliceVector, SliceVectorMut}, +}; + +use super::{Layer, LayerWitness}; + +pub trait SumcheckLayer { + #[allow(clippy::too_many_arguments)] + fn prove( + &self, + wit: LayerWitness, + out_points: &[&[E]], + challenges: &[E], + transcript: &mut impl Transcript, + ) -> SumcheckProverOutput; + + fn verify( + &self, + proof: SumcheckProof, + sigma: &E, + out_points: Vec<&[E]>, + challenges: &[E], + transcript: &mut impl Transcript, + ) -> Result, BackendError>; +} + +impl SumcheckLayer for Layer { + fn prove( + &self, + mut wit: LayerWitness, + out_points: &[&[E]], + challenges: &[E], + transcript: &mut impl Transcript, + ) -> SumcheckProverOutput { + let prover_state = SumcheckProverState::new( + self.exprs[0].clone(), + out_points, + wit.exts.slice_vector_mut(), + wit.bases.slice_vector(), + challenges, + transcript, + ); + + prover_state.prove() + } + + fn verify( + &self, + proof: SumcheckProof, + sigma: &E, + out_points: Vec<&[E]>, + challenges: &[E], + transcript: &mut impl Transcript, + ) -> Result, BackendError> { + let verifier_state = SumcheckVerifierState::new( + *sigma, + self.exprs[0].clone(), + out_points, + proof, + challenges, + transcript, + ); + + verifier_state + .verify() + .map_err(|e| BackendError::LayerVerificationFailed(self.name.clone(), e)) + } +} diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs b/gkr_iop/src/gkr/layer/zerocheck_layer.rs new file mode 100644 index 000000000..e9e9bc577 --- /dev/null +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -0,0 +1,76 @@ +use ff_ext::ExtensionField; +use subprotocols::{ + sumcheck::{SumcheckClaims, SumcheckProof, SumcheckProverOutput}, + zerocheck::{ZerocheckProverState, ZerocheckVerifierState}, +}; +use transcript::Transcript; + +use crate::{ + error::BackendError, + utils::{SliceVector, SliceVectorMut}, +}; + +use super::{Layer, LayerWitness}; + +pub trait ZerocheckLayer { + #[allow(clippy::too_many_arguments)] + fn prove( + &self, + wit: LayerWitness, + out_points: &[&[E]], + challenges: &[E], + transcript: &mut impl Transcript, + ) -> SumcheckProverOutput; + + fn verify( + &self, + proof: SumcheckProof, + sigmas: Vec, + out_points: Vec<&[E]>, + challenges: &[E], + transcript: &mut impl Transcript, + ) -> Result, BackendError>; +} + +impl ZerocheckLayer for Layer { + fn prove( + &self, + mut wit: LayerWitness, + out_points: &[&[E]], + challenges: &[E], + transcript: &mut impl Transcript, + ) -> SumcheckProverOutput { + let prover_state = ZerocheckProverState::new( + self.exprs.clone(), + out_points, + wit.exts.slice_vector_mut(), + wit.bases.slice_vector(), + challenges, + transcript, + ); + + prover_state.prove() + } + + fn verify( + &self, + proof: SumcheckProof, + sigmas: Vec, + out_points: Vec<&[E]>, + challenges: &[E], + transcript: &mut impl Transcript, + ) -> Result, BackendError> { + let verifier_state = ZerocheckVerifierState::new( + sigmas, + self.exprs.clone(), + out_points, + proof, + challenges, + transcript, + ); + + verifier_state + .verify() + .map_err(|e| BackendError::LayerVerificationFailed(self.name.clone(), e)) + } +} diff --git a/gkr_iop/src/gkr/mock.rs b/gkr_iop/src/gkr/mock.rs new file mode 100644 index 000000000..660f5156e --- /dev/null +++ b/gkr_iop/src/gkr/mock.rs @@ -0,0 +1,173 @@ +use std::marker::PhantomData; + +use ff_ext::ExtensionField; +use itertools::{Itertools, izip}; +use rand::rngs::OsRng; +use subprotocols::{ + expression::{Expression, VectorType}, + test_utils::random_point, + utils::eq_vecs, +}; +use thiserror::Error; + +use crate::{evaluation::EvalExpression, utils::SliceIterator}; + +use super::{GKRCircuit, GKRCircuitWitness, layer::LayerType}; + +pub struct MockProver(PhantomData); + +#[derive(Clone, Debug, Error)] +pub enum MockProverError { + #[error("sumcheck layer should have only one expression, got {0}")] + SumcheckExprLenError(usize), + #[error("sumcheck expression not match, out: {0:?}, expr: {1:?}, expect: {2:?}. got: {3:?}")] + SumcheckExpressionNotMatch( + Vec, + Expression, + VectorType, + VectorType, + ), + #[error("zerocheck expression not match, out: {0:?}, expr: {1:?}, expect: {2:?}. got: {3:?}")] + ZerocheckExpressionNotMatch(EvalExpression, Expression, VectorType, VectorType), + #[error("linear expression not match, out: {0:?}, expr: {1:?}, expect: {2:?}. got: {3:?}")] + LinearExpressionNotMatch(EvalExpression, Expression, VectorType, VectorType), +} + +impl MockProver { + pub fn check( + circuit: GKRCircuit<'_>, + circuit_wit: &GKRCircuitWitness, + mut evaluations: Vec>, + mut challenges: Vec, + ) -> Result<(), MockProverError> { + evaluations.resize(circuit.n_evaluations, VectorType::Base(vec![])); + challenges.resize_with(circuit.n_challenges, || E::random(OsRng)); + for (layer, layer_wit) in izip!(circuit.layers, &circuit_wit.layers) { + let num_vars = layer_wit.num_vars; + let points = (0..layer.outs.len()) + .map(|_| random_point::(OsRng, num_vars)) + .collect_vec(); + let eqs = eq_vecs(points.slice_iter(), &vec![E::ONE; points.len()]); + let gots = layer + .exprs + .iter() + .map(|expr| expr.calc(&layer_wit.exts, &layer_wit.bases, &eqs, &challenges)) + .collect_vec(); + let expects = layer + .outs + .iter() + .map(|out| out.mock_evaluate(&evaluations, &challenges, 1 << num_vars)) + .collect_vec(); + match layer.ty { + LayerType::Sumcheck => { + if gots.len() != 1 { + return Err(MockProverError::SumcheckExprLenError(gots.len())); + } + let got = gots.into_iter().next().unwrap(); + let expect = expects.into_iter().reduce(|a, b| a + b).unwrap(); + if expect != got { + return Err(MockProverError::SumcheckExpressionNotMatch( + layer.outs.clone(), + layer.exprs[0].clone(), + expect, + got, + )); + } + } + LayerType::Zerocheck => { + for (got, expect, expr, out) in izip!(gots, expects, &layer.exprs, &layer.outs) + { + if expect != got { + return Err(MockProverError::ZerocheckExpressionNotMatch( + out.clone(), + expr.clone(), + expect, + got, + )); + } + } + } + LayerType::Linear => { + for (got, expect, expr, out) in izip!(gots, expects, &layer.exprs, &layer.outs) + { + if expect != got { + return Err(MockProverError::LinearExpressionNotMatch( + out.clone(), + expr.clone(), + expect, + got, + )); + } + } + } + } + for (in_pos, base) in izip!(&layer.in_bases, &layer_wit.bases) { + *(in_pos.entry_mut(&mut evaluations)) = VectorType::Base(base.clone()); + } + for (in_pos, ext) in izip!(&layer.in_exts, &layer_wit.exts) { + *(in_pos.entry_mut(&mut evaluations)) = VectorType::Ext(ext.clone()); + } + } + Ok(()) + } +} + +impl EvalExpression { + pub fn mock_evaluate( + &self, + evals: &[VectorType], + challenges: &[E], + len: usize, + ) -> VectorType { + match self { + EvalExpression::Single(i) => evals[*i].clone(), + EvalExpression::Linear(i, c0, c1) => { + evals[*i].clone() * VectorType::Ext(vec![c0.evaluate(challenges); len]) + + VectorType::Ext(vec![c1.evaluate(challenges); len]) + } + EvalExpression::Partition(parts, indices) => { + assert_eq!(parts.len(), 1 << indices.len()); + let parts = parts + .iter() + .map(|part| part.mock_evaluate(evals, challenges, len)) + .collect_vec(); + indices + .iter() + .fold(parts, |acc, (i, _c)| { + let step_size = 1 << i; + acc.chunks_exact(2) + .map(|chunk| match (&chunk[0], &chunk[1]) { + (VectorType::Base(v0), VectorType::Base(v1)) => { + let res = (0..v0.len()) + .step_by(step_size) + .flat_map(|j| { + v0[j..j + step_size] + .iter() + .chain(v1[j..j + step_size].iter()) + .cloned() + }) + .collect_vec(); + VectorType::Base(res) + } + (VectorType::Ext(v0), VectorType::Ext(v1)) => { + let res = (0..v0.len()) + .step_by(step_size) + .flat_map(|j| { + v0[j..j + step_size] + .iter() + .chain(v1[j..j + step_size].iter()) + .cloned() + }) + .collect_vec(); + VectorType::Ext(res) + } + _ => unreachable!(), + }) + .collect_vec() + }) + .pop() + .unwrap() + } + } + } +} diff --git a/gkr_iop/src/lib.rs b/gkr_iop/src/lib.rs new file mode 100644 index 000000000..dff90eb6d --- /dev/null +++ b/gkr_iop/src/lib.rs @@ -0,0 +1,59 @@ +use std::marker::PhantomData; + +use chip::Chip; +use ff_ext::ExtensionField; +use gkr::GKRCircuitWitness; +use transcript::Transcript; + +pub mod chip; +pub mod error; +pub mod evaluation; +pub mod gkr; +pub mod utils; + +pub trait ProtocolBuilder: Sized { + type Params; + + fn init(params: Self::Params) -> Self; + + /// Build the protocol for GKR IOP. + fn build(params: Self::Params) -> (Self, Chip) { + let mut chip_spec = Self::init(params); + let mut chip = Chip::default(); + chip_spec.build_commit_phase(&mut chip); + chip_spec.build_gkr_phase(&mut chip); + + (chip_spec, chip) + } + + /// Specify the polynomials and challenges to be committed and generated in + /// Phase 1. + fn build_commit_phase(&mut self, spec: &mut Chip); + /// Create the GKR layers in the reverse order. For each layer, specify the + /// polynomial expressions, evaluation expressions of outputs and evaluation + /// positions of the inputs. + fn build_gkr_phase(&mut self, spec: &mut Chip); +} + +pub trait ProtocolWitnessGenerator +where + E: ExtensionField, +{ + type Trace; + + /// The vectors to be committed in the phase1. + fn phase1_witness(&self, phase1: Self::Trace) -> Vec>; + + /// GKR witness. + fn gkr_witness(&self, phase1: &[Vec], challenges: &[E]) -> GKRCircuitWitness; +} + +// TODO: the following trait consists of `commit_phase1`, `commit_phase2`, `gkr_phase` and `opening_phase`. +pub struct ProtocolProver, PCS>( + PhantomData<(E, Trans, PCS)>, +); + +// TODO: the following trait consists of `commit_phase1`, `commit_phase2`, `gkr_phase` and `opening_phase`. +pub struct ProtocolVerifier, PCS>( + PhantomData<(E, Trans, PCS)>, +); diff --git a/gkr_iop/src/utils.rs b/gkr_iop/src/utils.rs new file mode 100644 index 000000000..f0d4a06b8 --- /dev/null +++ b/gkr_iop/src/utils.rs @@ -0,0 +1,43 @@ +use std::sync::Arc; + +pub trait SliceVector { + fn slice_vector(&self) -> Vec<&[T]>; +} + +pub trait SliceVectorMut { + fn slice_vector_mut(&mut self) -> Vec<&mut [T]>; +} + +pub trait SliceIterator<'a, T: 'a> { + fn slice_iter(&'a self) -> impl Iterator + Clone; +} + +impl SliceVector for Vec> { + fn slice_vector(&self) -> Vec<&[T]> { + self.iter().map(|v| v.as_slice()).collect() + } +} + +impl SliceVector for Vec>> { + fn slice_vector(&self) -> Vec<&[T]> { + self.iter().map(|v| v.as_slice()).collect() + } +} + +impl<'a, T: 'a> SliceIterator<'a, T> for Vec> { + fn slice_iter(&'a self) -> impl Iterator + Clone { + self.iter().map(|v| v.as_slice()) + } +} + +impl<'a, T: 'a> SliceIterator<'a, T> for Vec>> { + fn slice_iter(&'a self) -> impl Iterator + Clone { + self.iter().map(|v| v.as_slice()) + } +} + +impl SliceVectorMut for Vec> { + fn slice_vector_mut(&mut self) -> Vec<&mut [T]> { + self.iter_mut().map(|v| v.as_mut_slice()).collect() + } +} diff --git a/mpcs/Cargo.toml b/mpcs/Cargo.toml index f977328cc..0732e66e0 100644 --- a/mpcs/Cargo.toml +++ b/mpcs/Cargo.toml @@ -15,7 +15,7 @@ ark-std.workspace = true bitvec = "1.0" ctr = "0.9" ff.workspace = true -ff_ext = { path = "../ff_ext" } +ff_ext.workspace = true # TODO: move to version 1, once our dependencies are updated generic-array = { version = "0.14", features = ["serde"] } goldilocks.workspace = true @@ -29,7 +29,7 @@ rand.workspace = true rand_chacha.workspace = true rayon = { workspace = true, optional = true } serde.workspace = true -transcript = { path = "../transcript" } +transcript.workspace = true [dev-dependencies] criterion.workspace = true diff --git a/multilinear_extensions/Cargo.toml b/multilinear_extensions/Cargo.toml index 1a8777641..96d9a7df8 100644 --- a/multilinear_extensions/Cargo.toml +++ b/multilinear_extensions/Cargo.toml @@ -12,7 +12,7 @@ version.workspace = true [dependencies] ark-std.workspace = true ff.workspace = true -ff_ext = { path = "../ff_ext" } +ff_ext.workspace = true goldilocks.workspace = true itertools.workspace = true rayon.workspace = true diff --git a/multilinear_extensions/src/virtual_poly.rs b/multilinear_extensions/src/virtual_poly.rs index bd50d659a..a2e65178a 100644 --- a/multilinear_extensions/src/virtual_poly.rs +++ b/multilinear_extensions/src/virtual_poly.rs @@ -1,4 +1,6 @@ -use std::{cmp::max, collections::HashMap, marker::PhantomData, mem::MaybeUninit, sync::Arc}; +use std::{ + cmp::max, collections::HashMap, marker::PhantomData, mem::MaybeUninit, ops::Mul, sync::Arc, +}; use crate::{ mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension, MultilinearExtension}, @@ -261,11 +263,14 @@ pub fn build_eq_x_r_sequential(r: &[E]) -> ArcDenseMultilinea /// over r, which is /// eq(x,y) = \prod_i=1^num_var (x_i * r_i + (1-x_i)*(1-r_i)) -#[tracing::instrument(skip_all, name = "multilinear_extensions::build_eq_x_r_vec_sequential")] -pub fn build_eq_x_r_vec_sequential(r: &[E]) -> Vec { +#[tracing::instrument( + skip_all, + name = "multilinear_extensions::build_eq_x_r_vec_sequential_with_scalar" +)] +pub fn build_eq_x_r_vec_sequential_with_scalar(r: &[E], scalar: E) -> Vec { // avoid unnecessary allocation if r.is_empty() { - return vec![E::ONE]; + return vec![scalar]; } // we build eq(x,r) from its evaluations // we want to evaluate eq(x,r) over x \in {0, 1}^num_vars @@ -279,11 +284,17 @@ pub fn build_eq_x_r_vec_sequential(r: &[E]) -> Vec { // we will need 2^num_var evaluations let mut evals = create_uninit_vec(1 << r.len()); - build_eq_x_r_helper_sequential(r, &mut evals, E::ONE); + build_eq_x_r_helper_sequential(r, &mut evals, scalar); unsafe { std::mem::transmute(evals) } } +#[inline] +#[tracing::instrument(skip_all, name = "multilinear_extensions::build_eq_x_r_vec_sequential")] +pub fn build_eq_x_r_vec_sequential(r: &[E]) -> Vec { + build_eq_x_r_vec_sequential_with_scalar(r, E::ONE) +} + /// A helper function to build eq(x, r)*init via dynamic programing tricks. /// This function takes 2^num_var iterations, and per iteration with 1 multiplication. fn build_eq_x_r_helper_sequential(r: &[E], buf: &mut [MaybeUninit], init: E) { @@ -367,6 +378,56 @@ pub fn build_eq_x_r_vec(r: &[E]) -> Vec { } } +#[tracing::instrument( + skip_all, + name = "multilinear_extensions::build_eq_x_r_vec_with_scalar" +)] +pub fn build_eq_x_r_vec_with_scalar + From, F>( + r: &[E], + scalar: F, +) -> Vec { + // avoid unnecessary allocation + if r.is_empty() { + return vec![E::from(scalar)]; + } + // we build eq(x,r) from its evaluations + // we want to evaluate eq(x,r) over x \in {0, 1}^num_vars + // for example, with num_vars = 4, x is a binary vector of 4, then + // 0 0 0 0 -> (1-r0) * (1-r1) * (1-r2) * (1-r3) + // 1 0 0 0 -> r0 * (1-r1) * (1-r2) * (1-r3) + // 0 1 0 0 -> (1-r0) * r1 * (1-r2) * (1-r3) + // 1 1 0 0 -> r0 * r1 * (1-r2) * (1-r3) + // .... + // 1 1 1 1 -> r0 * r1 * r2 * r3 + // we will need 2^num_var evaluations + let nthreads = max_usable_threads(); + let nbits = nthreads.trailing_zeros() as usize; + assert_eq!(1 << nbits, nthreads); + + let mut evals = create_uninit_vec(1 << r.len()); + if r.len() < nbits { + build_eq_x_r_helper_sequential(r, &mut evals, E::from(scalar)); + } else { + let eq_ts = + build_eq_x_r_vec_sequential_with_scalar(&r[(r.len() - nbits)..], E::from(scalar)); + + // eq(x, r) = eq(x_lo, r_lo) * eq(x_hi, r_hi) + // where rlen = r.len(), x_lo = x[0..rlen-nbits], x_hi = x[rlen-nbits..] + // r_lo = r[0..rlen-nbits] and r_hi = r[rlen-nbits..] + // each thread is associated with x_hi, and it will computes the subset + // { eq(x_lo, r_lo) * eq(x_hi, r_hi) } whose cardinality equals to 2^{rlen-nbits} + evals + .par_chunks_mut(1 << (r.len() - nbits)) + .zip((0..nthreads).into_par_iter()) + .for_each(|(chunks, tid)| { + let eq_t = eq_ts[tid]; + + build_eq_x_r_helper_sequential(&r[..(r.len() - nbits)], chunks, eq_t); + }); + } + unsafe { std::mem::transmute::>, Vec>(evals) } +} + #[cfg(test)] mod tests { use crate::virtual_poly::{build_eq_x_r_vec, build_eq_x_r_vec_sequential}; diff --git a/subprotocols/Cargo.toml b/subprotocols/Cargo.toml new file mode 100644 index 000000000..1f0f195da --- /dev/null +++ b/subprotocols/Cargo.toml @@ -0,0 +1,29 @@ +[package] +categories.workspace = true +description = "Subprotocols" +edition.workspace = true +keywords.workspace = true +license.workspace = true +name = "subprotocols" +readme.workspace = true +repository.workspace = true +version.workspace = true + +[dependencies] +ark-std.workspace = true +ff.workspace = true +ff_ext.workspace = true +itertools.workspace = true +multilinear_extensions = { version = "0.1.0", path = "../multilinear_extensions" } +rand.workspace = true +rayon.workspace = true +thiserror.workspace = true +transcript.workspace = true + +[dev-dependencies] +criterion.workspace = true +goldilocks.workspace = true + +[[bench]] +harness = false +name = "expr_based_logup" diff --git a/subprotocols/benches/expr_based_logup.rs b/subprotocols/benches/expr_based_logup.rs new file mode 100644 index 000000000..d4a1f4a0d --- /dev/null +++ b/subprotocols/benches/expr_based_logup.rs @@ -0,0 +1,143 @@ +use std::{array, time::Duration}; + +use ark_std::test_rng; +use criterion::*; +use ff::Field; +use goldilocks::GoldilocksExt2; +use itertools::Itertools; +use subprotocols::{ + expression::{Constant, Expression, Witness}, + sumcheck::SumcheckProverState, + test_utils::{random_point, random_poly}, + zerocheck::ZerocheckProverState, +}; +use transcript::BasicTranscript as Transcript; + +criterion_group!(benches, zerocheck_fn, sumcheck_fn); +criterion_main!(benches); + +const NUM_SAMPLES: usize = 10; +const NV: [usize; 2] = [25, 26]; + +fn sumcheck_fn(c: &mut Criterion) { + type E = GoldilocksExt2; + + for nv in NV { + // expand more input size once runtime is acceptable + let mut group = c.benchmark_group(format!("logup_sumcheck_nv_{}", nv)); + group.sample_size(NUM_SAMPLES); + + // Benchmark the proving time + group.bench_function( + BenchmarkId::new("prove_sumcheck", format!("sumcheck_nv_{}", nv)), + |b| { + b.iter_custom(|iters| { + let mut time = Duration::new(0, 0); + for _ in 0..iters { + let mut rng = test_rng(); + // Initialize logup expression. + let eq = Expression::Wit(Witness::EqPoly(0)); + let beta = Expression::Const(Constant::Challenge(0)); + let [d0, d1, n0, n1] = + array::from_fn(|i| Expression::Wit(Witness::ExtPoly(i))); + let expr = eq * (d0.clone() * d1.clone() + beta * (d0 * n1 + d1 * n0)); + + // Randomly generate point and witness. + let point = random_point(&mut rng, nv); + + let d0 = random_poly(&mut rng, nv); + let d1 = random_poly(&mut rng, nv); + let n0 = random_poly(&mut rng, nv); + let n1 = random_poly(&mut rng, nv); + let mut ext_mles = [d0.clone(), d1.clone(), n0.clone(), n1.clone()]; + + let challenges = vec![E::random(&mut rng)]; + + let ext_mle_refs = + ext_mles.iter_mut().map(|v| v.as_mut_slice()).collect_vec(); + + let mut prover_transcript = Transcript::new(b"test"); + let prover = SumcheckProverState::new( + expr, + &[&point], + ext_mle_refs, + vec![], + &challenges, + &mut prover_transcript, + ); + + let instant = std::time::Instant::now(); + let _ = black_box(prover.prove()); + let elapsed = instant.elapsed(); + time += elapsed; + } + + time + }); + }, + ); + + group.finish(); + } +} + +fn zerocheck_fn(c: &mut Criterion) { + type E = GoldilocksExt2; + + for nv in NV { + // expand more input size once runtime is acceptable + let mut group = c.benchmark_group(format!("logup_sumcheck_nv_{}", nv)); + group.sample_size(NUM_SAMPLES); + + // Benchmark the proving time + group.bench_function( + BenchmarkId::new("prove_sumcheck", format!("sumcheck_nv_{}", nv)), + |b| { + b.iter_custom(|iters| { + let mut time = Duration::new(0, 0); + for _ in 0..iters { + // Initialize logup expression. + let mut rng = test_rng(); + let beta = Expression::Const(Constant::Challenge(0)); + let [d0, d1, n0, n1] = + array::from_fn(|i| Expression::Wit(Witness::ExtPoly(i))); + let expr = d0.clone() * d1.clone() + beta * (d0 * n1 + d1 * n0); + + // Randomly generate point and witness. + let point = random_point(&mut rng, nv); + + let d0 = random_poly(&mut rng, nv); + let d1 = random_poly(&mut rng, nv); + let n0 = random_poly(&mut rng, nv); + let n1 = random_poly(&mut rng, nv); + let mut ext_mles = [d0.clone(), d1.clone(), n0.clone(), n1.clone()]; + + let challenges = vec![E::random(&mut rng)]; + + let ext_mle_refs = + ext_mles.iter_mut().map(|v| v.as_mut_slice()).collect_vec(); + + let mut prover_transcript = Transcript::new(b"test"); + let prover = ZerocheckProverState::new( + vec![expr], + &[&point], + ext_mle_refs, + vec![], + &challenges, + &mut prover_transcript, + ); + + let instant = std::time::Instant::now(); + let _ = black_box(prover.prove()); + let elapsed = instant.elapsed(); + time += elapsed; + } + + time + }); + }, + ); + + group.finish(); + } +} diff --git a/subprotocols/examples/zerocheck_logup.rs b/subprotocols/examples/zerocheck_logup.rs new file mode 100644 index 000000000..50a6305b0 --- /dev/null +++ b/subprotocols/examples/zerocheck_logup.rs @@ -0,0 +1,90 @@ +use std::array; + +use ff::Field; +use ff_ext::ExtensionField; +use goldilocks::GoldilocksExt2 as E; +use itertools::{Itertools, izip}; +use rand::thread_rng; +use subprotocols::{ + expression::{Constant, Expression, Witness}, + sumcheck::{SumcheckProof, SumcheckProverOutput}, + test_utils::{random_point, random_poly}, + utils::eq_vecs, + zerocheck::{ZerocheckProverState, ZerocheckVerifierState}, +}; +use transcript::BasicTranscript; + +fn run_prover( + point: &[E], + ext_mles: &mut [Vec], + expr: Expression, + challenges: Vec, +) -> SumcheckProof { + let timer = std::time::Instant::now(); + let ext_mle_refs = ext_mles.iter_mut().map(|v| v.as_mut_slice()).collect_vec(); + + let mut prover_transcript = BasicTranscript::new(b"test"); + let prover = ZerocheckProverState::new( + vec![expr], + &[point], + ext_mle_refs, + vec![], + &challenges, + &mut prover_transcript, + ); + + let SumcheckProverOutput { proof, .. } = prover.prove(); + println!("Proving time: {:?}", timer.elapsed()); + proof +} + +fn run_verifier( + proof: SumcheckProof, + ans: &E, + point: &[E], + expr: Expression, + challenges: Vec, +) { + let mut verifier_transcript = BasicTranscript::new(b"test"); + let verifier = ZerocheckVerifierState::new( + vec![*ans], + vec![expr], + vec![point], + proof, + &challenges, + &mut verifier_transcript, + ); + + verifier.verify().expect("verification failed"); +} + +fn main() { + let num_vars = 24; + let mut rng = thread_rng(); + + // Initialize logup expression. + let beta = Expression::Const(Constant::Challenge(0)); + let [d0, d1, n0, n1] = array::from_fn(|i| Expression::Wit(Witness::ExtPoly(i))); + let expr = d0.clone() * d1.clone() + beta * (d0 * n1 + d1 * n0); + + // Randomly generate point and witness. + let point = random_point(&mut rng, num_vars); + + let d0 = random_poly(&mut rng, num_vars); + let d1 = random_poly(&mut rng, num_vars); + let n0 = random_poly(&mut rng, num_vars); + let n1 = random_poly(&mut rng, num_vars); + let mut ext_mles = [d0.clone(), d1.clone(), n0.clone(), n1.clone()]; + + let challenges = vec![E::random(&mut rng)]; + + let proof = run_prover(&point, &mut ext_mles, expr.clone(), challenges.clone()); + + let eqs = eq_vecs([point.as_slice()].into_iter(), &[E::ONE]); + + let ans: E = izip!(&eqs[0], &d0, &d1, &n0, &n1) + .map(|(eq, d0, d1, n0, n1)| *eq * (*d0 * *d1 + challenges[0] * (*d0 * *n1 + *d1 * *n0))) + .sum(); + + run_verifier(proof, &ans, &point, expr, challenges); +} diff --git a/subprotocols/src/error.rs b/subprotocols/src/error.rs new file mode 100644 index 000000000..5dad7bc9d --- /dev/null +++ b/subprotocols/src/error.rs @@ -0,0 +1,9 @@ +use thiserror::Error; + +use crate::expression::Expression; + +#[derive(Clone, Debug, Error)] +pub enum VerifierError { + #[error("Claim not match: expr: {0:?}, expect: {1:?}, got: {2:?}")] + ClaimNotMatch(Expression, E, E), +} diff --git a/subprotocols/src/expression.rs b/subprotocols/src/expression.rs new file mode 100644 index 000000000..7750280a6 --- /dev/null +++ b/subprotocols/src/expression.rs @@ -0,0 +1,166 @@ +use std::sync::Arc; + +use ff_ext::ExtensionField; + +mod evaluate; +mod op; + +mod macros; + +pub type Point = Arc>; + +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum Constant { + /// Base field + Base(i64), + /// Challenge + Challenge(usize), + /// Sum + Sum(Box, Box), + /// Product + Product(Box, Box), + /// Neg + Neg(Box), + /// Pow + Pow(Box, usize), +} + +impl Default for Constant { + fn default() -> Self { + Constant::Base(0) + } +} + +#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum Witness { + /// Base field polynomial (index). + BasePoly(usize), + /// Extension field polynomial (index). + ExtPoly(usize), + /// Eq polynomial + EqPoly(usize), +} + +impl Default for Witness { + fn default() -> Self { + Witness::BasePoly(0) + } +} + +#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum Expression { + /// Constant + Const(Constant), + /// Witness. + Wit(Witness), + /// This is the sum of two expressions, with `degree`. + Sum(Box, Box, usize), + /// This is the product of two expressions, with `degree`. + Product(Box, Box, usize), + /// Neg, with `degree`. + Neg(Box, usize), + /// Pow, with `D` and `degree`. + Pow(Box, usize, usize), +} + +impl std::fmt::Debug for Expression { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Expression::Const(c) => write!(f, "{:?}", c), + Expression::Wit(w) => write!(f, "{:?}", w), + Expression::Sum(a, b, _) => write!(f, "({:?} + {:?})", a, b), + Expression::Product(a, b, _) => write!(f, "({:?} * {:?})", a, b), + Expression::Neg(a, _) => write!(f, "(-{:?})", a), + Expression::Pow(a, n, _) => write!(f, "({:?})^({})", a, n), + } + } +} + +impl std::fmt::Debug for Witness { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Witness::BasePoly(i) => write!(f, "BP[{}]", i), + Witness::ExtPoly(i) => write!(f, "EP[{}]", i), + Witness::EqPoly(i) => write!(f, "EQ[{}]", i), + } + } +} + +/// Vector of univariate polys. +#[derive(Clone, Debug)] +enum UniPolyVectorType { + Base(Vec>), + Ext(Vec>), +} + +/// Vector of field type. +#[derive(Clone, PartialEq, Eq)] +pub enum VectorType { + Base(Vec), + Ext(Vec), +} + +impl std::fmt::Debug for VectorType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + VectorType::Base(v) => { + let mut v = v.iter(); + write!(f, "[")?; + if let Some(e) = v.next() { + write!(f, "{:?}", e)?; + } + for _ in 0..2 { + if let Some(e) = v.next() { + write!(f, ", {:?}", e)?; + } else { + break; + } + } + if v.next().is_some() { + write!(f, ", ...]")?; + } else { + write!(f, "]")?; + }; + Ok(()) + } + VectorType::Ext(v) => { + let mut v = v.iter(); + write!(f, "[")?; + if let Some(e) = v.next() { + write!(f, "{:?}", e)?; + } + for _ in 0..2 { + if let Some(e) = v.next() { + write!(f, ", {:?}", e)?; + } else { + break; + } + } + if v.next().is_some() { + write!(f, ", ...]")?; + } else { + write!(f, "]")?; + }; + Ok(()) + } + } + } +} + +#[derive(Clone, Debug)] +enum ScalarType { + Base(E::BaseField), + Ext(E), +} + +impl From for Expression { + fn from(w: Witness) -> Self { + Expression::Wit(w) + } +} + +impl From for Expression { + fn from(c: Constant) -> Self { + Expression::Const(c) + } +} diff --git a/subprotocols/src/expression/evaluate.rs b/subprotocols/src/expression/evaluate.rs new file mode 100644 index 000000000..b37f63cd2 --- /dev/null +++ b/subprotocols/src/expression/evaluate.rs @@ -0,0 +1,458 @@ +use ff::Field; +use ff_ext::ExtensionField; +use itertools::{Itertools, zip_eq}; +use multilinear_extensions::virtual_poly::eq_eval; + +use crate::{op_by_type, utils::i64_to_field}; + +use super::{Constant, Expression, ScalarType, UniPolyVectorType, VectorType, Witness}; + +impl Expression { + pub fn degree(&self) -> usize { + match self { + Expression::Const(_) => 0, + Expression::Wit(_) => 1, + Expression::Sum(_, _, degree) => *degree, + Expression::Product(_, _, degree) => *degree, + Expression::Neg(_, degree) => *degree, + Expression::Pow(_, _, degree) => *degree, + } + } + + pub fn is_ext(&self) -> bool { + match self { + Expression::Const(c) => c.is_ext(), + Expression::Wit(w) => w.is_ext(), + Expression::Sum(e0, e1, _) | Expression::Product(e0, e1, _) => { + e0.is_ext() || e1.is_ext() + } + Expression::Neg(e, _) => e.is_ext(), + Expression::Pow(e, d, _) => { + if *d > 0 { + e.is_ext() + } else { + false + } + } + } + } + + pub fn evaluate( + &self, + ext_mle_evals: &[E], + base_mle_evals: &[E], + out_points: &[&[E]], + in_point: &[E], + challenges: &[E], + ) -> E { + match self { + Expression::Const(constant) => constant.evaluate(challenges), + Expression::Wit(w) => w.evaluate(base_mle_evals, ext_mle_evals, out_points, in_point), + Expression::Sum(e0, e1, _) => { + e0.evaluate( + ext_mle_evals, + base_mle_evals, + out_points, + in_point, + challenges, + ) + e1.evaluate( + ext_mle_evals, + base_mle_evals, + out_points, + in_point, + challenges, + ) + } + Expression::Product(e0, e1, _) => { + e0.evaluate( + ext_mle_evals, + base_mle_evals, + out_points, + in_point, + challenges, + ) * e1.evaluate( + ext_mle_evals, + base_mle_evals, + out_points, + in_point, + challenges, + ) + } + Expression::Neg(e, _) => -e.evaluate( + ext_mle_evals, + base_mle_evals, + out_points, + in_point, + challenges, + ), + Expression::Pow(e, d, _) => e + .evaluate( + ext_mle_evals, + base_mle_evals, + out_points, + in_point, + challenges, + ) + .pow([*d as u64]), + } + } + + pub fn calc( + &self, + ext: &[Vec], + base: &[Vec], + eqs: &[Vec], + challenges: &[E], + ) -> VectorType { + assert!(!(ext.is_empty() && base.is_empty())); + let size = if !ext.is_empty() { + ext[0].len() + } else { + base[0].len() + }; + match self { + Expression::Const(constant) => { + VectorType::Ext(vec![constant.evaluate(challenges); size]) + } + Expression::Wit(w) => match w { + Witness::BasePoly(index) => VectorType::Base(base[*index].clone()), + Witness::ExtPoly(index) => VectorType::Ext(ext[*index].clone()), + Witness::EqPoly(index) => VectorType::Ext(eqs[*index].clone()), + }, + Expression::Sum(e0, e1, _) => { + e0.calc(ext, base, eqs, challenges) + e1.calc(ext, base, eqs, challenges) + } + Expression::Product(e0, e1, _) => { + e0.calc(ext, base, eqs, challenges) * e1.calc(ext, base, eqs, challenges) + } + Expression::Neg(e, _) => -e.calc(ext, base, eqs, challenges), + Expression::Pow(e, d, _) => { + let poly = e.calc(ext, base, eqs, challenges); + op_by_type!( + VectorType, + poly, + |poly| { poly.into_iter().map(|x| x.pow([*d as u64])).collect_vec() }, + |ext| VectorType::Ext(ext), + |base| VectorType::Base(base) + ) + } + } + } + + #[allow(clippy::too_many_arguments)] + pub fn sumcheck_uni_poly( + &self, + ext_mles: &[&mut [E]], + base_after_mles: &[Vec], + base_mles: &[&[E::BaseField]], + eqs: &[Vec], + challenges: &[E], + size: usize, + degree: usize, + ) -> Vec { + let poly = self.uni_poly_inner( + ext_mles, + base_after_mles, + base_mles, + eqs, + challenges, + size, + degree, + ); + op_by_type!(UniPolyVectorType, poly, |poly| { + poly.into_iter().fold(vec![E::ZERO; degree], |acc, x| { + zip_eq(acc, x).map(|(a, b)| a + b).collect_vec() + }) + }) + } + + /// Compute \sum_x (eq(0, x) + eq(1, x)) * expr_0(X, x) + #[allow(clippy::too_many_arguments)] + pub fn zerocheck_uni_poly<'a, E: ExtensionField>( + &self, + ext_mles: &[&mut [E]], + base_after_mles: &[Vec], + base_mles: &[&[E::BaseField]], + challenges: &[E], + coeffs: impl Iterator, + size: usize, + ) -> Vec { + let degree = self.degree(); + let poly = self.uni_poly_inner( + ext_mles, + base_after_mles, + base_mles, + &[], + challenges, + size, + degree, + ); + + op_by_type!(UniPolyVectorType, poly, |poly| { + zip_eq(coeffs, poly).fold(vec![E::ZERO; degree], |mut acc, (c, poly)| { + zip_eq(&mut acc, poly).for_each(|(a, x)| *a += *c * x); + acc + }) + }) + } + + /// Compute the extension field univariate polynomial evaluated on 1..degree + 1. + #[allow(clippy::too_many_arguments)] + fn uni_poly_inner( + &self, + ext_mles: &[&mut [E]], + base_after_mles: &[Vec], + base_mles: &[&[E::BaseField]], + eqs: &[Vec], + challenges: &[E], + size: usize, + degree: usize, + ) -> UniPolyVectorType { + match self { + Expression::Const(constant) => { + let value = constant.evaluate(challenges); + UniPolyVectorType::Ext(vec![vec![value; degree]; size >> 1]) + } + Expression::Wit(w) => match w { + Witness::BasePoly(index) => { + if !base_mles.is_empty() { + UniPolyVectorType::Base(uni_poly_helper(base_mles[*index], size, degree)) + } else { + UniPolyVectorType::Ext(uni_poly_helper( + &base_after_mles[*index], + size, + degree, + )) + } + } + Witness::ExtPoly(index) => { + UniPolyVectorType::Ext(uni_poly_helper(ext_mles[*index], size, degree)) + } + Witness::EqPoly(index) => { + UniPolyVectorType::Ext(uni_poly_helper(&eqs[*index], size, degree)) + } + }, + Expression::Sum(expr0, expr1, _) => { + let poly0 = expr0.uni_poly_inner( + ext_mles, + base_after_mles, + base_mles, + eqs, + challenges, + size, + degree, + ); + let poly1 = expr1.uni_poly_inner( + ext_mles, + base_after_mles, + base_mles, + eqs, + challenges, + size, + degree, + ); + poly0 + poly1 + } + Expression::Product(expr0, expr1, _) => { + let poly0 = expr0.uni_poly_inner( + ext_mles, + base_after_mles, + base_mles, + eqs, + challenges, + size, + degree, + ); + let poly1 = expr1.uni_poly_inner( + ext_mles, + base_after_mles, + base_mles, + eqs, + challenges, + size, + degree, + ); + poly0 * poly1 + } + Expression::Neg(expr, _) => { + let poly = expr.uni_poly_inner( + ext_mles, + base_after_mles, + base_mles, + eqs, + challenges, + size, + degree, + ); + -poly + } + Expression::Pow(expr, d, _) => { + let poly = expr.uni_poly_inner( + ext_mles, + base_after_mles, + base_mles, + eqs, + challenges, + size, + degree, + ); + op_by_type!( + UniPolyVectorType, + poly, + |poly| { + poly.into_iter() + .map(|x| x.iter().map(|x| x.pow([*d as u64])).collect_vec()) + .collect_vec() + }, + |ext| UniPolyVectorType::Ext(ext), + |base| UniPolyVectorType::Base(base) + ) + } + } + } +} + +impl Constant { + pub fn is_ext(&self) -> bool { + match self { + Constant::Base(_) => false, + Constant::Challenge(_) => true, + Constant::Sum(c0, c1) | Constant::Product(c0, c1) => c0.is_ext() || c1.is_ext(), + Constant::Neg(c) => c.is_ext(), + Constant::Pow(c, _) => c.is_ext(), + } + } + + pub fn evaluate(&self, challenges: &[E]) -> E { + let res = self.evaluate_inner(challenges); + op_by_type!(ScalarType, res, |b| b, |e| e, |bf| E::from(bf)) + } + + fn evaluate_inner(&self, challenges: &[E]) -> ScalarType { + match self { + Constant::Base(value) => ScalarType::Base(i64_to_field(*value)), + Constant::Challenge(index) => ScalarType::Ext(challenges[*index]), + Constant::Sum(c0, c1) => c0.evaluate_inner(challenges) + c1.evaluate_inner(challenges), + Constant::Product(c0, c1) => { + c0.evaluate_inner(challenges) * c1.evaluate_inner(challenges) + } + Constant::Neg(c) => -c.evaluate_inner(challenges), + Constant::Pow(c, degree) => { + let value = c.evaluate_inner(challenges); + op_by_type!( + ScalarType, + value, + |value| { value.pow([*degree as u64]) }, + |ext| ScalarType::Ext(ext), + |base| ScalarType::Base(base) + ) + } + } + } + + pub fn entry(&self, challenges: &[E]) -> E { + match self { + Constant::Challenge(index) => challenges[*index], + _ => unreachable!(), + } + } + + pub fn entry_mut<'a, E: ExtensionField>(&self, challenges: &'a mut [E]) -> &'a mut E { + match self { + Constant::Challenge(index) => &mut challenges[*index], + _ => unreachable!(), + } + } +} + +impl Witness { + pub fn is_ext(&self) -> bool { + match self { + Witness::BasePoly(_) => false, + Witness::ExtPoly(_) => true, + Witness::EqPoly(_) => true, + } + } + + pub fn evaluate( + &self, + base_mle_evals: &[E], + ext_mle_evals: &[E], + out_point: &[&[E]], + in_point: &[E], + ) -> E { + match self { + Witness::BasePoly(index) => base_mle_evals[*index], + Witness::ExtPoly(index) => ext_mle_evals[*index], + Witness::EqPoly(index) => eq_eval(out_point[*index], in_point), + } + } + + pub fn base<'a, T>(&self, base_mle_evals: &'a [T]) -> &'a T { + match self { + Witness::BasePoly(index) => &base_mle_evals[*index], + _ => unreachable!(), + } + } + + pub fn base_mut<'a, T>(&self, base_mle_evals: &'a mut [T]) -> &'a mut T { + match self { + Witness::BasePoly(index) => &mut base_mle_evals[*index], + _ => unreachable!(), + } + } + + pub fn ext<'a, T>(&self, ext_mle_evals: &'a [T]) -> &'a T { + match self { + Witness::ExtPoly(index) => &ext_mle_evals[*index], + _ => unreachable!(), + } + } + + pub fn ext_mut<'a, T>(&self, ext_mle_evals: &'a mut [T]) -> &'a mut T { + match self { + Witness::ExtPoly(index) => &mut ext_mle_evals[*index], + _ => unreachable!(), + } + } +} + +/// Compute the univariate polynomial evaluated on 1..degree. +#[inline] +fn uni_poly_helper(mle: &[F], size: usize, degree: usize) -> Vec> { + mle.chunks(2) + .take(size >> 1) + .map(|p| { + let start = p[0]; + let step = p[1] - start; + (0..degree) + .scan(start, |state, _| { + *state += step; + Some(*state) + }) + .collect_vec() + }) + .collect_vec() +} + +#[cfg(test)] +mod test { + use crate::field_vec; + use goldilocks::Goldilocks as F; + + #[test] + fn test_uni_poly_helper() { + // (x + 2), (3x + 4), (5x + 6), (7x + 8) + let mle = field_vec![F, 2, 3, 4, 7, 6, 11, 8, 15, 11, 13, 17, 19, 23, 29, 31, 37]; + let size = 8; + let degree = 3; + let expected = vec![ + field_vec![F, 3, 4, 5], + field_vec![F, 7, 10, 13], + field_vec![F, 11, 16, 21], + field_vec![F, 15, 22, 29], + ]; + let result = super::uni_poly_helper(&mle, size, degree); + assert_eq!(result, expected); + } +} diff --git a/subprotocols/src/expression/macros.rs b/subprotocols/src/expression/macros.rs new file mode 100644 index 000000000..6e4e22784 --- /dev/null +++ b/subprotocols/src/expression/macros.rs @@ -0,0 +1,100 @@ +#[macro_export] +macro_rules! op_by_type { + ($ele_type:ident, $ele:ident, |$x:ident| $op:expr, |$y_ext:ident| $convert_ext:expr, |$y_base:ident| $convert_base:expr) => { + match $ele { + $ele_type::Base($x) => { + let $y_base = $op; + $convert_base + } + $ele_type::Ext($x) => { + let $y_ext = $op; + $convert_ext + } + } + }; + + ($ele_type:ident, $ele:ident, |$x:ident| $op:expr, |$y_base:ident| $convert_base:expr) => { + match $ele { + $ele_type::Base($x) => { + let $y_base = $op; + $convert_base + } + $ele_type::Ext($x) => $op, + } + }; + + ($ele_type:ident, $ele:ident, |$x:ident| $op:expr) => { + match $ele { + $ele_type::Base($x) => $op, + $ele_type::Ext($x) => $op, + } + }; +} + +#[macro_export] +macro_rules! define_commutative_op_mle2 { + ($ele_type:ident, $trait_type:ident, $func_type:ident, |$x:ident, $y:ident| $op:expr) => { + impl $trait_type for $ele_type { + type Output = Self; + + fn $func_type(self, other: Self) -> Self::Output { + #[allow(unused)] + match (self, other) { + ($ele_type::Base(mut $x), $ele_type::Base($y)) => $ele_type::Base($op), + ($ele_type::Ext(mut $x), $ele_type::Base($y)) + | ($ele_type::Base($y), $ele_type::Ext(mut $x)) => $ele_type::Ext($op), + ($ele_type::Ext(mut $x), $ele_type::Ext($y)) => $ele_type::Ext($op), + } + } + } + + impl<'a, E: ExtensionField> $trait_type<&'a Self> for $ele_type { + type Output = Self; + + fn $func_type(self, other: &'a Self) -> Self::Output { + #[allow(unused)] + match (self, other) { + ($ele_type::Base(mut $x), $ele_type::Base($y)) => $ele_type::Base($op), + ($ele_type::Ext(mut $x), $ele_type::Base($y)) => $ele_type::Ext($op), + ($ele_type::Base($y), $ele_type::Ext($x)) => { + let mut $x = $x.clone(); + $ele_type::Ext($op) + } + ($ele_type::Ext(mut $x), $ele_type::Ext($y)) => $ele_type::Ext($op), + } + } + } + }; +} + +#[macro_export] +macro_rules! define_op_mle2 { + ($ele_type:ident, $trait_type:ident, $func_type:ident, |$x:ident, $y:ident| $op:expr) => { + impl $trait_type for $ele_type { + type Output = Self; + + fn $func_type(self, other: Self) -> Self::Output { + let $x = self; + let $y = other; + $op + } + } + }; +} + +#[macro_export] +macro_rules! define_op_mle { + ($ele_type:ident, $trait_type:ident, $func_type:ident, |$x:ident| $op:expr) => { + impl $trait_type for $ele_type { + type Output = Self; + + fn $func_type(self) -> Self::Output { + #[allow(unused)] + match (self) { + $ele_type::Base(mut $x) => $ele_type::Base($op), + $ele_type::Ext(mut $x) => $ele_type::Ext($op), + } + } + } + }; +} diff --git a/subprotocols/src/expression/op.rs b/subprotocols/src/expression/op.rs new file mode 100644 index 000000000..1690c24e4 --- /dev/null +++ b/subprotocols/src/expression/op.rs @@ -0,0 +1,81 @@ +use std::{ + cmp::max, + ops::{Add, Mul, Neg, Sub}, +}; + +use ff_ext::ExtensionField; +use itertools::zip_eq; + +use crate::{define_commutative_op_mle2, define_op_mle, define_op_mle2}; + +use super::{Expression, ScalarType, UniPolyVectorType, VectorType}; + +impl Add for Expression { + type Output = Self; + + fn add(self, other: Self) -> Self { + let degree = max(self.degree(), other.degree()); + Expression::Sum(Box::new(self), Box::new(other), degree) + } +} + +impl Mul for Expression { + type Output = Self; + + fn mul(self, other: Self) -> Self { + #[allow(clippy::suspicious_arithmetic_impl)] + let degree = self.degree() + other.degree(); + Expression::Product(Box::new(self), Box::new(other), degree) + } +} + +impl Neg for Expression { + type Output = Self; + + fn neg(self) -> Self { + let deg = self.degree(); + Expression::Neg(Box::new(self), deg) + } +} + +impl Sub for Expression { + type Output = Self; + + fn sub(self, other: Self) -> Self { + self + (-other) + } +} + +define_commutative_op_mle2!(UniPolyVectorType, Add, add, |x, y| { + zip_eq(&mut x, y).for_each(|(x, y)| zip_eq(x, y).for_each(|(x, y)| *x += y)); + x +}); +define_commutative_op_mle2!(UniPolyVectorType, Mul, mul, |x, y| { + zip_eq(&mut x, y).for_each(|(x, y)| zip_eq(x, y).for_each(|(x, y)| *x *= y)); + x +}); +define_op_mle2!(UniPolyVectorType, Sub, sub, |x, y| x + (-y)); +define_op_mle!(UniPolyVectorType, Neg, neg, |x| { + x.iter_mut() + .for_each(|x| x.iter_mut().for_each(|x| *x = -(*x))); + x +}); + +define_commutative_op_mle2!(VectorType, Add, add, |x, y| { + zip_eq(&mut x, y).for_each(|(x, y)| *x += y); + x +}); +define_commutative_op_mle2!(VectorType, Mul, mul, |x, y| { + zip_eq(&mut x, y).for_each(|(x, y)| *x *= y); + x +}); +define_op_mle2!(VectorType, Sub, sub, |x, y| x + (-y)); +define_op_mle!(VectorType, Neg, neg, |x| { + x.iter_mut().for_each(|x| *x = -(*x)); + x +}); + +define_commutative_op_mle2!(ScalarType, Add, add, |x, y| x + y); +define_commutative_op_mle2!(ScalarType, Mul, mul, |x, y| x * y); +define_op_mle2!(ScalarType, Sub, sub, |x, y| x + (-y)); +define_op_mle!(ScalarType, Neg, neg, |x| -x); diff --git a/subprotocols/src/lib.rs b/subprotocols/src/lib.rs new file mode 100644 index 000000000..a86f12c8f --- /dev/null +++ b/subprotocols/src/lib.rs @@ -0,0 +1,9 @@ +pub mod error; +pub mod expression; +pub mod points; +pub mod sumcheck; +pub mod utils; +pub mod zerocheck; + +#[macro_use] +pub mod test_utils; diff --git a/subprotocols/src/points.rs b/subprotocols/src/points.rs new file mode 100644 index 000000000..07e63f035 --- /dev/null +++ b/subprotocols/src/points.rs @@ -0,0 +1,75 @@ +use std::sync::Arc; + +use ff_ext::ExtensionField; +use itertools::izip; + +type Point = Arc>; + +pub trait PointBeforeMerge { + fn point_before_merge(&self, pos: &[usize]) -> Point; +} + +pub trait PointBeforePartition { + fn point_before_partition( + &self, + pos_and_var_ids: &[(usize, usize)], + challenges: &[E], + ) -> Point; +} + +/// Suppose we have several vectors v_0, ..., v_{N-1}, and want to merge it through n = log(N) variables, +/// x_0, ..., x_{n-1}, at the positions i_0, ..., i_{n - 1}. Suppose the output point is P, then the point +/// before it is P_0, ..., P_{i_0 - 1}, P_{i_0 + 1}, ..., P_{i_1 - 1}, ..., P_{i_{n - 1} + 1}, ..., P_{N - 1}. +impl PointBeforeMerge for Point { + fn point_before_merge(&self, pos: &[usize]) -> Point { + if pos.is_empty() { + return self.clone(); + } + + assert!(izip!(pos.iter(), pos.iter().skip(1)).all(|(i, j)| i < j)); + + let mut new_point = Vec::with_capacity(self.len() - pos.len()); + let mut i = 0usize; + for (j, p) in self.iter().enumerate() { + if j != pos[i] { + new_point.push(*p); + } else { + i += 1; + } + } + + Arc::new(new_point) + } +} + +/// Suppose we have a vector v, and want to partition it through n = log(N) variables +/// x_0, ..., x_{n-1}, at the positions i_0, ..., i_{n - 1}. Suppose the output point +/// is P, then the point before it is P after calling P.insert(i_0, x_0), ... +impl PointBeforePartition for Point { + fn point_before_partition( + &self, + pos_and_var_ids: &[(usize, usize)], + challenges: &[E], + ) -> Point { + if pos_and_var_ids.is_empty() { + return self.clone(); + } + + assert!( + izip!(pos_and_var_ids.iter(), pos_and_var_ids.iter().skip(1)).all(|(i, j)| i.0 < j.0) + ); + + let mut new_point = Vec::with_capacity(self.len() + pos_and_var_ids.len()); + let mut i = 0usize; + for (j, p) in self.iter().enumerate() { + if i + j != pos_and_var_ids[i].0 { + new_point.push(*p); + } else { + new_point.push(challenges[pos_and_var_ids[i].1]); + i += 1; + } + } + + Arc::new(new_point) + } +} diff --git a/subprotocols/src/sumcheck.rs b/subprotocols/src/sumcheck.rs new file mode 100644 index 000000000..4869f3acb --- /dev/null +++ b/subprotocols/src/sumcheck.rs @@ -0,0 +1,434 @@ +use std::{iter, mem, sync::Arc, vec}; + +use ark_std::log2; +use ff_ext::ExtensionField; +use itertools::chain; +use transcript::Transcript; + +use crate::{ + error::VerifierError, + expression::{Expression, Point}, + utils::eq_vecs, +}; + +use super::utils::{fix_variables_ext, fix_variables_inplace, interpolate_uni_poly}; + +/// This is an randomly combined sumcheck protocol for the following equation: +/// \sigma = \sum_x expr(x) +pub struct SumcheckProverState<'a, E, Trans> +where + E: ExtensionField, + Trans: Transcript, +{ + /// Expression. + expr: Expression, + + /// Extension field mles. + ext_mles: Vec<&'a mut [E]>, + /// Base field mles after the first round. + base_mles_after: Vec>, + /// Base field mles. + base_mles: Vec<&'a [E::BaseField]>, + /// Eq polys + eqs: Vec>, + /// Challenges occurred in expressions + challenges: &'a [E], + + transcript: &'a mut Trans, + + degree: usize, + num_vars: usize, +} + +pub struct SumcheckProof { + /// Messages for each round. + pub univariate_polys: Vec>>, + pub ext_mle_evals: Vec, + pub base_mle_evals: Vec, +} + +pub struct SumcheckProverOutput { + pub proof: SumcheckProof, + pub point: Point, +} + +impl<'a, E, Trans> SumcheckProverState<'a, E, Trans> +where + E: ExtensionField, + Trans: Transcript, +{ + #[allow(clippy::too_many_arguments)] + pub fn new( + expr: Expression, + points: &[&[E]], + ext_mles: Vec<&'a mut [E]>, + base_mles: Vec<&'a [E::BaseField]>, + challenges: &'a [E], + transcript: &'a mut Trans, + ) -> Self { + assert!(!(ext_mles.is_empty() && base_mles.is_empty())); + + let num_vars = if !ext_mles.is_empty() { + log2(ext_mles[0].len()) as usize + } else { + log2(base_mles[0].len()) as usize + }; + + // The length of all mles should be 2^{num_vars}. + assert!(ext_mles.iter().all(|mle| mle.len() == 1 << num_vars)); + assert!(base_mles.iter().all(|mle| mle.len() == 1 << num_vars)); + + let degree = expr.degree(); + + let eqs = eq_vecs(points.iter().copied(), &vec![E::ONE; points.len()]); + + Self { + expr, + ext_mles, + base_mles_after: vec![], + base_mles, + eqs, + challenges, + transcript, + num_vars, + degree, + } + } + + pub fn prove(mut self) -> SumcheckProverOutput { + let (univariate_polys, point) = (0..self.num_vars) + .map(|round| { + let round_msg = self.compute_univariate_poly(round); + self.transcript.append_field_element_exts(&round_msg); + + let r = self + .transcript + .get_and_append_challenge(b"sumcheck round") + .elements; + self.update_mles(&r, round); + (vec![round_msg], r) + }) + .unzip(); + let point = Arc::new(point); + + // Send the final evaluations + let SumcheckProverState { + ext_mles, + base_mles_after, + base_mles, + .. + } = self; + let ext_mle_evaluations = ext_mles.into_iter().map(|mle| mle[0]).collect(); + let base_mle_evaluations = if !base_mles.is_empty() { + base_mles.into_iter().map(|mle| E::from(mle[0])).collect() + } else { + base_mles_after.into_iter().map(|mle| mle[0]).collect() + }; + + SumcheckProverOutput { + proof: SumcheckProof { + univariate_polys, + ext_mle_evals: ext_mle_evaluations, + base_mle_evals: base_mle_evaluations, + }, + point, + } + } + + /// Compute f(X) = r^0 \sum_x expr_0(X || x) + r^1 \sum_x expr_1(X || x) + ... + fn compute_univariate_poly(&self, round: usize) -> Vec { + self.expr.sumcheck_uni_poly( + &self.ext_mles, + &self.base_mles_after, + &self.base_mles, + &self.eqs, + self.challenges, + 1 << (self.num_vars - round), + self.degree, + ) + } + + fn update_mles(&mut self, r: &E, round: usize) { + // fix variables of eq polynomials + self.eqs.iter_mut().for_each(|eq| { + fix_variables_inplace(eq, r); + }); + // fix variables of ext field polynomials. + self.ext_mles.iter_mut().for_each(|mle| { + fix_variables_inplace(mle, r); + }); + // fix variables of base field polynomials. + if round == 0 { + self.base_mles_after = mem::take(&mut self.base_mles) + .into_iter() + .map(|mle| fix_variables_ext(mle, r)) + .collect(); + } else { + self.base_mles_after + .iter_mut() + .for_each(|mle| fix_variables_inplace(mle, r)); + } + } +} + +pub struct SumcheckVerifierState<'a, E, Trans> +where + E: ExtensionField, + Trans: Transcript, +{ + sigma: E, + expr: Expression, + proof: SumcheckProof, + challenges: &'a [E], + transcript: &'a mut Trans, + out_points: Vec<&'a [E]>, +} + +pub struct SumcheckClaims { + pub in_point: Point, + pub base_mle_evals: Vec, + pub ext_mle_evals: Vec, +} + +impl<'a, E, Trans> SumcheckVerifierState<'a, E, Trans> +where + E: ExtensionField, + Trans: Transcript, +{ + pub fn new( + sigma: E, + expr: Expression, + out_points: Vec<&'a [E]>, + proof: SumcheckProof, + challenges: &'a [E], + transcript: &'a mut Trans, + ) -> Self { + Self { + sigma, + expr, + proof, + challenges, + transcript, + out_points, + } + } + + pub fn verify(self) -> Result, VerifierError> { + let SumcheckVerifierState { + sigma, + expr, + proof, + challenges, + transcript, + out_points, + } = self; + let SumcheckProof { + univariate_polys, + ext_mle_evals, + base_mle_evals, + } = proof; + + let (in_point, expected_claim) = univariate_polys.into_iter().fold( + (vec![], sigma), + |(mut last_point, last_sigma), msg| { + let msg = msg.into_iter().next().unwrap(); + transcript.append_field_element_exts(&msg); + + let len = msg.len() + 1; + let eval_at_0 = last_sigma - msg[0]; + + // Evaluations on degree, degree - 1, ..., 1, 0. + let evals_iter_rev = chain![msg.into_iter().rev(), iter::once(eval_at_0)]; + + let r = transcript + .get_and_append_challenge(b"sumcheck round") + .elements; + let sigma = interpolate_uni_poly(evals_iter_rev, len, r); + last_point.push(r); + (last_point, sigma) + }, + ); + + // Check the final evaluations. + let got_claim = expr.evaluate( + &ext_mle_evals, + &base_mle_evals, + &out_points, + &in_point, + challenges, + ); + if expected_claim != got_claim { + return Err(VerifierError::ClaimNotMatch( + expr, + expected_claim, + got_claim, + )); + } + + let in_point = Arc::new(in_point); + Ok(SumcheckClaims { + in_point, + base_mle_evals, + ext_mle_evals, + }) + } +} + +#[cfg(test)] +mod test { + use std::array; + + use ff::Field; + use ff_ext::ExtensionField; + use goldilocks::{Goldilocks as F, GoldilocksExt2 as E}; + use itertools::{Itertools, izip}; + use transcript::BasicTranscript; + + use crate::{ + expression::{Constant, Expression, Witness}, + field_vec, + utils::eq_vecs, + }; + + use super::{SumcheckProverOutput, SumcheckProverState, SumcheckVerifierState}; + + #[allow(clippy::too_many_arguments)] + fn run( + points: Vec<&[E]>, + expr: Expression, + ext_mle_refs: Vec<&mut [E]>, + base_mle_refs: Vec<&[E::BaseField]>, + challenges: Vec, + + sigma: E, + ) { + let mut prover_transcript = BasicTranscript::new(b"test"); + let prover = SumcheckProverState::new( + expr.clone(), + &points, + ext_mle_refs, + base_mle_refs, + &challenges, + &mut prover_transcript, + ); + + let SumcheckProverOutput { proof, .. } = prover.prove(); + + let mut verifier_transcript = BasicTranscript::new(b"test"); + let verifier = SumcheckVerifierState::new( + sigma, + expr, + points, + proof, + &challenges, + &mut verifier_transcript, + ); + + verifier.verify().expect("verification failed"); + } + + #[test] + fn test_sumcheck_trivial() { + let f = field_vec![F, 2]; + let g = field_vec![F, 3]; + let out_point = vec![]; + + let base_mle_refs = vec![f.as_slice(), g.as_slice()]; + let f = Expression::Wit(Witness::BasePoly(0)); + let g = Expression::Wit(Witness::BasePoly(1)); + let expr = f * g; + + run( + vec![out_point.as_slice()], + expr, + vec![], + base_mle_refs, + vec![], + E::from(6), + ); + } + + #[test] + fn test_sumcheck_simple() { + let f = field_vec![F, 1, 2, 3, 4]; + let ans = E::from(f.iter().fold(F::ZERO, |acc, x| acc + x)); + let base_mle_refs = vec![f.as_slice()]; + let expr = Expression::Wit(Witness::BasePoly(0)); + + run(vec![], expr, vec![], base_mle_refs, vec![], ans); + } + + #[test] + fn test_sumcheck_logup() { + let point = field_vec![E, 2, 3]; + + let eqs = eq_vecs([point.as_slice()].into_iter(), &[E::ONE]); + + let d0 = field_vec![E, 1, 2, 3, 4]; + let d1 = field_vec![E, 5, 6, 7, 8]; + let n0 = field_vec![E, 9, 10, 11, 12]; + let n1 = field_vec![E, 13, 14, 15, 16]; + + let challenges = vec![E::from(7)]; + let ans = izip!(&eqs[0], &d0, &d1, &n0, &n1) + .map(|(eq, d0, d1, n0, n1)| *eq * (*d0 * *d1 + challenges[0] * (*d0 * *n1 + *d1 * *n0))) + .sum(); + + let mut ext_mles = [d0, d1, n0, n1]; + let [d0, d1, n0, n1] = array::from_fn(|i| Expression::Wit(Witness::ExtPoly(i))); + let eq = Expression::Wit(Witness::EqPoly(0)); + let beta = Expression::Const(Constant::Challenge(0)); + + let expr = eq * (d0.clone() * d1.clone() + beta * (d0 * n1 + d1 * n0)); + + let ext_mle_refs = ext_mles.iter_mut().map(|v| v.as_mut_slice()).collect_vec(); + run( + vec![point.as_slice()], + expr, + ext_mle_refs, + vec![], + challenges, + ans, + ); + } + + #[test] + fn test_sumcheck_multi_points() { + let challenges = vec![E::from(2)]; + + let points = [field_vec![E, 2, 3], field_vec![E, 5, 7], field_vec![ + E, 2, 5 + ]]; + let point_refs = points.iter().map(|v| v.as_slice()).collect_vec(); + + let eqs = eq_vecs(point_refs.clone().into_iter(), &vec![E::ONE; points.len()]); + + let d0 = field_vec![F, 1, 2, 3, 4]; + let d1 = field_vec![F, 5, 6, 7, 8]; + let n0 = field_vec![F, 9, 10, 11, 12]; + let n1 = field_vec![F, 13, 14, 15, 16]; + + let ans_0 = izip!(&eqs[0], &d0, &d1) + .map(|(eq0, d0, d1)| eq0 * d0 * d1) + .sum::(); + let ans_1 = izip!(&eqs[1], &d0, &n1) + .map(|(eq1, d0, n1)| eq1 * d0 * n1) + .sum::(); + let ans_2 = izip!(&eqs[2], &d1, &n0) + .map(|(eq2, d1, n0)| eq2 * d1 * n0) + .sum::(); + let ans = (ans_0 * challenges[0] + ans_1) * challenges[0] + ans_2; + + let base_mles = [d0, d1, n0, n1]; + let [eq0, eq1, eq2] = array::from_fn(|i| Expression::Wit(Witness::EqPoly(i))); + let [d0, d1, n0, n1] = array::from_fn(|i| Expression::Wit(Witness::BasePoly(i))); + let rlc_challenge = Expression::Const(Constant::Challenge(0)); + + let expr = (eq0 * d0.clone() * d1.clone() * rlc_challenge.clone() + eq1 * d0 * n1) + * rlc_challenge + + eq2 * d1 * n0; + + let base_mle_refs = base_mles.iter().map(|v| v.as_slice()).collect_vec(); + run(point_refs, expr, vec![], base_mle_refs, challenges, ans); + } +} diff --git a/subprotocols/src/test_utils.rs b/subprotocols/src/test_utils.rs new file mode 100644 index 000000000..a7c8ad29f --- /dev/null +++ b/subprotocols/src/test_utils.rs @@ -0,0 +1,46 @@ +use ff::PrimeField; +use ff_ext::ExtensionField; +use itertools::Itertools; +use rand::RngCore; + +pub fn random_point(mut rng: impl RngCore, num_vars: usize) -> Vec { + (0..num_vars).map(|_| E::random(&mut rng)).collect_vec() +} + +pub fn random_vec(mut rng: impl RngCore, len: usize) -> Vec { + (0..len).map(|_| E::random(&mut rng)).collect_vec() +} + +pub fn random_poly(mut rng: impl RngCore, num_vars: usize) -> Vec { + (0..1 << num_vars) + .map(|_| E::random(&mut rng)) + .collect_vec() +} + +#[macro_export] +macro_rules! field_vec { + () => ( + $crate::vec::Vec::new() + ); + ($field_type:ident; $elem:expr; $n:expr) => ( + $crate::vec::from_elem({ + if $x < 0 { + -$field_type::from((-$x) as u64) + } else { + $field_type::from($x as u64) + } + }, $n) + ); + ($field_type:ident, $($x:expr),+ $(,)?) => ( + <[_]>::into_vec( + std::boxed::Box::new([$({ + let x = $x as i64; + if $x < 0 { + -$field_type::from((-x) as u64) + } else { + $field_type::from(x as u64) + } + }),+]) + ) + ); +} diff --git a/subprotocols/src/utils.rs b/subprotocols/src/utils.rs new file mode 100644 index 000000000..2ec1854f1 --- /dev/null +++ b/subprotocols/src/utils.rs @@ -0,0 +1,232 @@ +use std::{iter, ops::Mul}; + +use ff::{Field, PrimeField}; +use ff_ext::ExtensionField; +use itertools::{Itertools, chain, izip}; +use multilinear_extensions::virtual_poly::build_eq_x_r_vec_with_scalar; +use rayon::iter::{ + IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, +}; + +pub fn i64_to_field(i: i64) -> F { + if i < 0 { + -F::from(i.unsigned_abs()) + } else { + F::from(i as u64) + } +} + +pub fn power_list(ele: &F, size: usize) -> Vec { + (0..size) + .scan(F::ONE, |state, _| { + let last = *state; + *state *= *ele; + Some(last) + }) + .collect() +} + +/// Grand product of ele, start from 1, with length ele.len() + 1. +pub fn grand_product(ele: &[F]) -> Vec { + let one = F::ONE; + chain![iter::once(&one), ele.iter()] + .scan(F::ONE, |state, e| { + *state *= *e; + Some(*state) + }) + .collect() +} + +pub fn eq_vecs<'a, E: ExtensionField>( + points: impl Iterator, + scalars: &[E], +) -> Vec> { + izip!(points, scalars) + .map(|(point, scalar)| build_eq_x_r_vec_with_scalar(point, *scalar)) + .collect_vec() +} + +#[inline(always)] +pub fn eq(x: &F, y: &F) -> F { + // x * y + (1 - x) * (1 - y) + let xy = *x * y; + xy + xy - x - y + F::ONE +} + +pub fn fix_variables_ext(base_mle: &[E::BaseField], r: &E) -> Vec { + base_mle + .par_iter() + .chunks(2) + .with_min_len(64) + .map(|buf| *r * (*buf[1] - *buf[0]) + *buf[0]) + .collect() +} + +pub fn fix_variables_inplace(ext_mle: &mut [E], r: &E) { + ext_mle + .par_iter_mut() + .chunks(2) + .with_min_len(64) + .for_each(|mut buf| *buf[0] = *buf[0] + (*buf[1] - *buf[0]) * r); + // sequentially update buf[b1, b2,..bt] = buf[b1, b2,..bt, 0] + let half_len = ext_mle.len() >> 1; + for index in 0..half_len { + ext_mle[index] = ext_mle[index << 1]; + } +} + +pub fn evaluate_mle_inplace(mle: &mut [E], point: &[E]) -> E { + for r in point { + fix_variables_inplace(mle, r); + } + mle[0] +} + +pub fn evaluate_mle_ext(mle: &[E::BaseField], point: &[E]) -> E { + let mut ext_mle = fix_variables_ext(mle, &point[0]); + evaluate_mle_inplace(&mut ext_mle, &point[1..]) +} + +/// Interpolate a uni-variate degree-`p_i.len()-1` polynomial and evaluate this +/// polynomial at `eval_at`: +/// +/// \sum_{i=0}^len p_i * (\prod_{j!=i} (eval_at - j)/(i-j) ) +/// +/// This implementation is linear in number of inputs in terms of field +/// operations. It also has a quadratic term in primitive operations which is +/// negligible compared to field operations. +/// TODO: The quadratic term can be removed by precomputing the lagrange +/// coefficients. +pub(crate) fn interpolate_uni_poly>( + p_iter_rev: impl Iterator, + len: usize, + eval_at: E, +) -> E { + let mut evals = vec![eval_at]; + let mut prod = eval_at; + + // `prod = \prod_{j} (eval_at - j)` + for j in 1..len { + let tmp = eval_at - E::from(j as u64); + evals.push(tmp); + prod *= tmp; + } + let mut res = E::ZERO; + // we want to compute \prod (j!=i) (i-j) for a given i + // + // we start from the last step, which is + // denom[len-1] = (len-1) * (len-2) *... * 2 * 1 + // the step before that is + // denom[len-2] = (len-2) * (len-3) * ... * 2 * 1 * -1 + // and the step before that is + // denom[len-3] = (len-3) * (len-4) * ... * 2 * 1 * -1 * -2 + // + // i.e., for any i, the one before this will be derived from + // denom[i-1] = denom[i] * (len-i) / i + // + // that is, we only need to store + // - the last denom for i = len-1, and + // - the ratio between current step and fhe last step, which is the product of (len-i) / i from + // all previous steps and we store this product as a fraction number to reduce field + // divisions. + + let mut denom_up = field_factorial::(len - 1); + let mut denom_down = F::ONE; + + for (j, p_i) in p_iter_rev.enumerate() { + let i = len - j - 1; + res += prod * p_i * denom_down * (evals[i] * denom_up).invert().unwrap(); + + // compute denom for the next step is current_denom * (len-i)/i + if i != 0 { + denom_up *= -F::from((j + 1) as u64); + denom_down *= F::from(i as u64); + } + } + res +} + +/// compute the factorial(a) = 1 * 2 * ... * a +#[inline] +fn field_factorial(a: usize) -> F { + let mut res = F::ONE; + for i in 2..=a { + res *= F::from(i as u64); + } + res +} + +#[cfg(test)] +mod test { + use goldilocks::{Goldilocks as F, GoldilocksExt2 as E}; + use itertools::Itertools; + use multilinear_extensions::virtual_poly::eq_eval; + + use crate::field_vec; + + use super::*; + + #[test] + fn test_power_list() { + let ele = F::from(3u64); + let list = power_list(&ele, 4); + assert_eq!(list, field_vec![F, 1, 3, 9, 27]); + } + + #[test] + fn test_grand_product() { + let ele = field_vec![F, 2, 3, 4, 5]; + let expected = field_vec![F, 1, 2, 6, 24, 120]; + assert_eq!(grand_product(&ele), expected); + } + + #[test] + fn test_eq_vecs() { + let points = [field_vec![E, 2, 3, 5], field_vec![E, 7, 11, 13]]; + let point_refs = points.iter().map(|p| p.as_slice()).collect_vec(); + + let scalars = field_vec![E, 3, 5]; + + let eq_evals = eq_vecs(point_refs.into_iter(), &scalars); + + let expected = vec![ + field_vec![E, -24, 48, 36, -72, 30, -60, -45, 90], + field_vec![E, -3600, 4200, 3960, -4620, 3900, -4550, -4290, 5005], + ]; + assert_eq!(eq_evals, expected); + } + + #[test] + fn test_eq_eval() { + let xs = field_vec![E, 2, 3, 5]; + let ys = field_vec![E, 7, 11, 13]; + let expected = E::from(119780); + assert_eq!(eq_eval(&xs, &ys), expected); + } + + #[test] + fn test_fix_variables_ext() { + let base_mle = field_vec![F, 1, 2, 3, 4, 5, 6]; + let r = E::from(3u64); + let expected = field_vec![E, 4, 6, 8]; + assert_eq!(fix_variables_ext(&base_mle, &r), expected); + } + + #[test] + fn test_fix_variables_inplace() { + let mut ext_mle = field_vec![E, 1, 2, 3, 4, 5, 6]; + let r = E::from(3u64); + fix_variables_inplace(&mut ext_mle, &r); + let expected = field_vec![E, 4, 6, 8]; + assert_eq!(ext_mle[..3], expected); + } + + #[test] + fn test_interpolate_uni_poly() { + // p(x) = x^3 + 2x^2 + 3x + 4 + let p_iter = field_vec![F, 4, 10, 26, 58].into_iter().rev(); + let eval_at = E::from(11); + let expected = E::from(1610); + assert_eq!(interpolate_uni_poly(p_iter, 4, eval_at), expected); + } +} diff --git a/subprotocols/src/zerocheck.rs b/subprotocols/src/zerocheck.rs new file mode 100644 index 000000000..6756e0336 --- /dev/null +++ b/subprotocols/src/zerocheck.rs @@ -0,0 +1,472 @@ +use std::{iter, mem, sync::Arc, vec}; + +use ark_std::log2; +use ff::BatchInvert; +use ff_ext::ExtensionField; +use itertools::{Itertools, chain, izip, zip_eq}; +use transcript::Transcript; + +use crate::{ + error::VerifierError, + expression::Expression, + sumcheck::{SumcheckProof, SumcheckProverOutput}, +}; + +use super::{ + sumcheck::SumcheckClaims, + utils::{ + eq_vecs, fix_variables_ext, fix_variables_inplace, grand_product, interpolate_uni_poly, + }, +}; + +/// This is an randomly combined zerocheck protocol for the following equation: +/// \sigma = \sum_x (r^0 eq_0(X) \cdot expr_0(x) + r^1 eq_1(X) \cdot expr_1(x) + ...) +pub struct ZerocheckProverState<'a, E, Trans> +where + E: ExtensionField, + Trans: Transcript, +{ + /// Expressions and corresponding half eq reference. + exprs: Vec<(Expression, Vec)>, + + /// Extension field mles. + ext_mles: Vec<&'a mut [E]>, + /// Base field mles after the first round. + base_mles_after: Vec>, + /// Base field mles. + base_mles: Vec<&'a [E::BaseField]>, + /// Challenges occurred in expressions + challenges: &'a [E], + /// For each point in points, the inverse of prod_{j < i}(1 - point[i]) for 0 <= i < point.len(). + grand_prod_of_not_inv: Vec>, + + transcript: &'a mut Trans, + + num_vars: usize, +} + +impl<'a, E, Trans> ZerocheckProverState<'a, E, Trans> +where + E: ExtensionField, + Trans: Transcript, +{ + #[allow(clippy::too_many_arguments)] + pub fn new( + exprs: Vec, + points: &[&[E]], + ext_mles: Vec<&'a mut [E]>, + base_mles: Vec<&'a [E::BaseField]>, + challenges: &'a [E], + transcript: &'a mut Trans, + ) -> Self { + assert!(!(ext_mles.is_empty() && base_mles.is_empty())); + + let num_vars = if !ext_mles.is_empty() { + log2(ext_mles[0].len()) as usize + } else { + log2(base_mles[0].len()) as usize + }; + + // For each point, compute eq(point[1..], b) for b in [0, 2^{num_vars - 1}). + let (exprs, grand_prod_of_not_inv) = if num_vars > 0 { + let half_eq_evals = eq_vecs(points.iter().map(|point| &point[1..]), &vec![ + E::ONE; + exprs.len() + ]); + let exprs = zip_eq(exprs, half_eq_evals).collect_vec(); + let mut grand_prod_of_not_inv = points + .iter() + .flat_map(|point| point[1..].iter().map(|p| E::ONE - p).collect_vec()) + .collect_vec(); + BatchInvert::batch_invert(&mut grand_prod_of_not_inv); + let (_, grand_prod_of_not_inv) = + points + .iter() + .fold((0usize, vec![]), |(start, mut last_vec), point| { + let end = start + point.len() - 1; + last_vec.push(grand_product(&grand_prod_of_not_inv[start..end])); + (end, last_vec) + }); + (exprs, grand_prod_of_not_inv) + } else { + let expr = exprs.into_iter().map(|expr| (expr, vec![])).collect_vec(); + (expr, vec![]) + }; + + // The length of all mles should be 2^{num_vars}. + assert!(ext_mles.iter().all(|mle| mle.len() == 1 << num_vars)); + assert!(base_mles.iter().all(|mle| mle.len() == 1 << num_vars)); + + Self { + exprs, + ext_mles, + base_mles_after: vec![], + base_mles, + challenges, + grand_prod_of_not_inv, + transcript, + num_vars, + } + } + + pub fn prove(mut self) -> SumcheckProverOutput { + let (univariate_polys, point) = (0..self.num_vars) + .map(|round| { + let round_msg = self.compute_univariate_poly(round); + round_msg + .iter() + .for_each(|poly| self.transcript.append_field_element_exts(poly)); + + let r = self + .transcript + .get_and_append_challenge(b"sumcheck round") + .elements; + self.update_mles(&r, round); + (round_msg, r) + }) + .unzip(); + let point = Arc::new(point); + + // Send the final evaluations + let ZerocheckProverState { + ext_mles, + base_mles_after, + base_mles, + .. + } = self; + let ext_mle_evaluations = ext_mles.into_iter().map(|mle| mle[0]).collect(); + let base_mle_evaluations = if !base_mles.is_empty() { + base_mles.into_iter().map(|mle| E::from(mle[0])).collect() + } else { + base_mles_after.into_iter().map(|mle| mle[0]).collect() + }; + + SumcheckProverOutput { + proof: SumcheckProof { + univariate_polys, + ext_mle_evals: ext_mle_evaluations, + base_mle_evals: base_mle_evaluations, + }, + point, + } + } + + /// Compute f_i(X) = \sum_x eq_i(x) expr_i(X || x) + fn compute_univariate_poly(&self, round: usize) -> Vec> { + izip!(&self.exprs, &self.grand_prod_of_not_inv) + .map(|((expr, half_eq_mle), coeff)| { + let mut uni_poly = expr.zerocheck_uni_poly( + &self.ext_mles, + &self.base_mles_after, + &self.base_mles, + self.challenges, + half_eq_mle.iter().step_by(1 << round), + 1 << (self.num_vars - round), + ); + uni_poly.iter_mut().for_each(|x| *x *= coeff[round]); + uni_poly + }) + .collect_vec() + } + + fn update_mles(&mut self, r: &E, round: usize) { + // fix variables of base field polynomials. + self.ext_mles.iter_mut().for_each(|mle| { + fix_variables_inplace(mle, r); + }); + if round == 0 { + self.base_mles_after = mem::take(&mut self.base_mles) + .into_iter() + .map(|mle| fix_variables_ext(mle, r)) + .collect(); + } else { + self.base_mles_after + .iter_mut() + .for_each(|mle| fix_variables_inplace(mle, r)); + } + } +} + +pub struct ZerocheckVerifierState<'a, E, Trans> +where + E: ExtensionField, + Trans: Transcript, +{ + sigmas: Vec, + inv_of_one_minus_points: Vec>, + exprs: Vec<(Expression, &'a [E])>, + proof: SumcheckProof, + challenges: &'a [E], + transcript: &'a mut Trans, +} + +impl<'a, E, Trans> ZerocheckVerifierState<'a, E, Trans> +where + E: ExtensionField, + Trans: Transcript, +{ + pub fn new( + sigmas: Vec, + exprs: Vec, + points: Vec<&'a [E]>, + proof: SumcheckProof, + challenges: &'a [E], + transcript: &'a mut Trans, + ) -> Self { + let mut inv_of_one_minus_points = points + .iter() + .flat_map(|point| point.iter().map(|p| E::ONE - p).collect_vec()) + .collect_vec(); + BatchInvert::batch_invert(&mut inv_of_one_minus_points); + let (_, inv_of_one_minus_points) = + points + .iter() + .fold((0usize, vec![]), |(start, mut last_vec), point| { + let end = start + point.len(); + last_vec.push(inv_of_one_minus_points[start..start + point.len()].to_vec()); + (end, last_vec) + }); + + let exprs = zip_eq(exprs, points).collect_vec(); + Self { + sigmas, + inv_of_one_minus_points, + exprs, + proof, + challenges, + transcript, + } + } + + pub fn verify(self) -> Result, VerifierError> { + let ZerocheckVerifierState { + sigmas, + inv_of_one_minus_points, + exprs, + proof, + challenges, + transcript, + .. + } = self; + let SumcheckProof { + univariate_polys, + ext_mle_evals, + base_mle_evals, + } = proof; + + let (in_point, expected_claims) = univariate_polys.into_iter().enumerate().fold( + (vec![], sigmas), + |(mut last_point, last_sigmas), (round, round_msg)| { + round_msg + .iter() + .for_each(|poly| transcript.append_field_element_exts(poly)); + let r = transcript + .get_and_append_challenge(b"sumcheck round") + .elements; + last_point.push(r); + + let sigmas = izip!(&exprs, &inv_of_one_minus_points, round_msg, last_sigmas) + .map(|((_, point), inv_of_one_minus_point, poly, last_sigma)| { + let len = poly.len() + 1; + // last_sigma = (1 - point[round]) * eval_at_0 + point[round] * eval_at_1 + // eval_at_0 = (last_sigma - point[round] * eval_at_1) * inv(1 - point[round]) + let eval_at_0 = + (last_sigma - point[round] * poly[0]) * inv_of_one_minus_point[round]; + + // Evaluations on degree, degree - 1, ..., 1, 0. + let evals_iter_rev = chain![poly.into_iter().rev(), iter::once(eval_at_0)]; + + interpolate_uni_poly(evals_iter_rev, len, r) + }) + .collect_vec(); + + (last_point, sigmas) + }, + ); + + // Check the final evaluations. + for (expected_claim, (expr, _)) in izip!(expected_claims, exprs) { + let got_claim = expr.evaluate(&ext_mle_evals, &base_mle_evals, &[], &[], challenges); + + if expected_claim != got_claim { + return Err(VerifierError::ClaimNotMatch( + expr, + expected_claim, + got_claim, + )); + } + } + + let in_point = Arc::new(in_point); + Ok(SumcheckClaims { + in_point, + ext_mle_evals, + base_mle_evals, + }) + } +} + +#[cfg(test)] +mod test { + use std::array; + + use ff::Field; + use ff_ext::ExtensionField; + use goldilocks::{Goldilocks as F, GoldilocksExt2 as E}; + use itertools::{Itertools, izip}; + use transcript::BasicTranscript; + + use crate::{ + expression::{Constant, Expression, Witness}, + field_vec, + sumcheck::SumcheckProverOutput, + }; + + use super::{ZerocheckProverState, ZerocheckVerifierState}; + + #[allow(clippy::too_many_arguments)] + fn run<'a, E: ExtensionField>( + points: Vec<&[E]>, + exprs: Vec, + ext_mle_refs: Vec<&'a mut [E]>, + base_mle_refs: Vec<&'a [E::BaseField]>, + challenges: Vec, + + sigmas: Vec, + ) { + let mut prover_transcript = BasicTranscript::new(b"test"); + let prover = ZerocheckProverState::new( + exprs.clone(), + &points, + ext_mle_refs, + base_mle_refs, + &challenges, + &mut prover_transcript, + ); + + let SumcheckProverOutput { proof, .. } = prover.prove(); + + let mut verifier_transcript = BasicTranscript::new(b"test"); + let verifier = ZerocheckVerifierState::new( + sigmas, + exprs, + points, + proof, + &challenges, + &mut verifier_transcript, + ); + + verifier.verify().expect("verification failed"); + } + + #[test] + fn test_zerocheck_trivial() { + let f = field_vec![F, 2]; + let g = field_vec![F, 3]; + let out_point = vec![]; + + let base_mle_refs = vec![f.as_slice(), g.as_slice()]; + let f = Expression::Wit(Witness::BasePoly(0)); + let g = Expression::Wit(Witness::BasePoly(1)); + let expr = f * g; + + run( + vec![out_point.as_slice()], + vec![expr], + vec![], + base_mle_refs, + vec![], + vec![E::from(6)], + ); + } + + #[test] + fn test_zerocheck_simple() { + let f = field_vec![F, 1, 2, 3, 4, 5, 6, 7, 8]; + let out_point = field_vec![E, 2, 3, 5]; + let out_eq = field_vec![E, -8, 16, 12, -24, 10, -20, -15, 30]; + let ans = izip!(&out_eq, &f).fold(E::ZERO, |acc, (c, x)| acc + *c * x); + + let base_mle_refs = vec![f.as_slice()]; + let expr = Expression::Wit(Witness::BasePoly(0)); + run( + vec![out_point.as_slice()], + vec![expr.clone()], + vec![], + base_mle_refs, + vec![], + vec![ans], + ); + } + + #[test] + fn test_zerocheck_logup() { + let out_point = field_vec![E, 2, 3, 5]; + let out_eq = field_vec![E, -8, 16, 12, -24, 10, -20, -15, 30]; + + let d0 = field_vec![E, 1, 2, 3, 4, 5, 6, 7, 8]; + let d1 = field_vec![E, 9, 10, 11, 12, 13, 14, 15, 16]; + let n0 = field_vec![E, 17, 18, 19, 20, 21, 22, 23, 24]; + let n1 = field_vec![E, 25, 26, 27, 28, 29, 30, 31, 32]; + + let challenges = vec![E::from(7)]; + let ans = izip!(&out_eq, &d0, &d1, &n0, &n1) + .map(|(eq, d0, d1, n0, n1)| *eq * (*d0 * *d1 + challenges[0] * (*d0 * *n1 + *d1 * *n0))) + .sum(); + + let mut ext_mles = [d0, d1, n0, n1]; + let [d0, d1, n0, n1] = array::from_fn(|i| Expression::Wit(Witness::ExtPoly(i))); + let beta = Expression::Const(Constant::Challenge(0)); + let expr = d0.clone() * d1.clone() + beta * (d0 * n1 + d1 * n0); + + let ext_mles_refs = ext_mles.iter_mut().map(|v| v.as_mut_slice()).collect_vec(); + run( + vec![out_point.as_slice()], + vec![expr.clone()], + ext_mles_refs, + vec![], + challenges, + vec![ans], + ); + } + + #[test] + fn test_zerocheck_multi_points() { + let points = [ + field_vec![E, 2, 3, 5], + field_vec![E, 7, 11, 13], + field_vec![E, 17, 19, 23], + ]; + let out_eqs = [ + field_vec![E, -8, 16, 12, -24, 10, -20, -15, 30], + field_vec![E, -720, 840, 792, -924, 780, -910, -858, 1001], + field_vec![E, -6336, 6732, 6688, -7106, 6624, -7038, -6992, 7429], + ]; + let point_refs = points.iter().map(|v| v.as_slice()).collect_vec(); + + let d0 = field_vec![F, 1, 2, 3, 4, 5, 6, 7, 8]; + let d1 = field_vec![F, 9, 10, 11, 12, 13, 14, 15, 16]; + let n0 = field_vec![F, 17, 18, 19, 20, 21, 22, 23, 24]; + let n1 = field_vec![F, 25, 26, 27, 28, 29, 30, 31, 32]; + + let ans_0 = izip!(&out_eqs[0], &d0, &d1) + .map(|(eq0, d0, d1)| eq0 * d0 * d1) + .sum(); + let ans_1 = izip!(&out_eqs[1], &d0, &n1) + .map(|(eq1, d0, n1)| eq1 * d0 * n1) + .sum(); + let ans_2 = izip!(&out_eqs[2], &d1, &n0) + .map(|(eq2, d1, n0)| eq2 * d1 * n0) + .sum(); + + let base_mles = [d0, d1, n0, n1]; + let [d0, d1, n0, n1] = array::from_fn(|i| Expression::Wit(Witness::BasePoly(i))); + + let exprs = vec![d0.clone() * d1.clone(), d0 * n1, d1 * n0]; + + let base_mle_refs = base_mles.iter().map(|v| v.as_slice()).collect_vec(); + run(point_refs, exprs, vec![], base_mle_refs, vec![], vec![ + ans_0, ans_1, ans_2, + ]); + } +} diff --git a/sumcheck/Cargo.toml b/sumcheck/Cargo.toml index 4d9fe27fc..ccea110c7 100644 --- a/sumcheck/Cargo.toml +++ b/sumcheck/Cargo.toml @@ -12,7 +12,7 @@ version.workspace = true [dependencies] ark-std.workspace = true ff.workspace = true -ff_ext = { path = "../ff_ext" } +ff_ext.workspace = true goldilocks.workspace = true itertools.workspace = true rayon.workspace = true @@ -22,7 +22,7 @@ tracing.workspace = true crossbeam-channel.workspace = true multilinear_extensions = { path = "../multilinear_extensions", features = ["parallel"] } sumcheck_macro = { path = "../sumcheck_macro" } -transcript = { path = "../transcript" } +transcript.workspace = true [dev-dependencies] criterion.workspace = true @@ -30,3 +30,8 @@ criterion.workspace = true [[bench]] harness = false name = "devirgo_sumcheck" + + +[[bench]] +harness = false +name = "devirgo_sumcheck_logup" diff --git a/sumcheck/benches/devirgo_sumcheck_logup.rs b/sumcheck/benches/devirgo_sumcheck_logup.rs new file mode 100644 index 000000000..87c6e9002 --- /dev/null +++ b/sumcheck/benches/devirgo_sumcheck_logup.rs @@ -0,0 +1,215 @@ +#![allow(clippy::manual_memcpy)] +#![allow(clippy::needless_range_loop)] + +use std::{array, time::Duration}; + +use ark_std::test_rng; +use criterion::*; +use ff::Field; +use ff_ext::ExtensionField; +use itertools::Itertools; +use sumcheck::{structs::IOPProverState, util::ceil_log2}; + +use goldilocks::GoldilocksExt2; +use multilinear_extensions::{ + mle::DenseMultilinearExtension, + op_mle, + util::max_usable_threads, + virtual_poly::{ArcMultilinearExtension, VirtualPolynomial, build_eq_x_r_vec}, +}; +use transcript::BasicTranscript as Transcript; + +criterion_group!(benches, sumcheck_fn, devirgo_sumcheck_fn,); +criterion_main!(benches); + +const NUM_SAMPLES: usize = 10; +const NV: [usize; 2] = [25, 26]; + +/// transpose 2d vector without clone +pub fn transpose(v: Vec>) -> Vec> { + assert!(!v.is_empty()); + let len = v[0].len(); + let mut iters: Vec<_> = v.into_iter().map(|n| n.into_iter()).collect(); + (0..len) + .map(|_| { + iters + .iter_mut() + .map(|n| n.next().unwrap()) + .collect::>() + }) + .collect() +} + +fn prepare_input<'a, E: ExtensionField + Field>( + nv: usize, +) -> (E, VirtualPolynomial<'a, E>, Vec>) { + let mut rng = test_rng(); + let max_thread_id = max_usable_threads(); + let size_log2 = ceil_log2(max_thread_id); + let point = (0..nv).map(|_| E::random(&mut rng)).collect::>(); + // generate logup constraint sigma = f0 * f1 + beta * (f0 * f3 + f1 * f2) + let fs: [ArcMultilinearExtension<'a, E>; 4] = array::from_fn(|_| { + let eval = (0..1 << nv).map(|_| E::random(&mut rng)).collect_vec(); + DenseMultilinearExtension::from_evaluations_ext_vec(nv, eval).into() + }); + let eq = build_eq_x_r_vec(&point); + let eq = DenseMultilinearExtension::from_evaluations_ext_vec(nv, eq).into(); + let polys = [vec![eq], fs.to_vec()].concat(); + let beta = E::random(&mut rng); + + let mut virtual_poly_v1 = VirtualPolynomial::new(nv); + virtual_poly_v1.add_mle_list( + vec![polys[0].clone(), polys[1].clone(), polys[2].clone()], + E::ONE, + ); + virtual_poly_v1.add_mle_list( + vec![polys[0].clone(), polys[1].clone(), polys[4].clone()], + beta, + ); + virtual_poly_v1.add_mle_list( + vec![polys[0].clone(), polys[2].clone(), polys[3].clone()], + beta, + ); + + // devirgo version + let virtual_poly_v2: Vec>> = transpose( + polys + .iter() + .map(|f| match &f.evaluations() { + multilinear_extensions::mle::FieldType::Base(evaluations) => evaluations + .chunks((1 << nv) >> size_log2) + .map(|chunk| { + let mle: ArcMultilinearExtension<'a, E> = + DenseMultilinearExtension::::from_evaluations_vec( + nv - size_log2, + chunk.to_vec(), + ) + .into(); + mle + }) + .collect_vec(), + multilinear_extensions::mle::FieldType::Ext(evaluations) => evaluations + .chunks((1 << nv) >> size_log2) + .map(|chunk| { + let mle: ArcMultilinearExtension<'a, E> = + DenseMultilinearExtension::::from_evaluations_ext_vec( + nv - size_log2, + chunk.to_vec(), + ) + .into(); + mle + }) + .collect_vec(), + _ => unreachable!(), + }) + .collect(), + ); + let virtual_poly_v2: Vec> = virtual_poly_v2 + .into_iter() + .map(|polys| { + let mut virtual_poly = VirtualPolynomial::new(nv); + virtual_poly.add_mle_list( + vec![polys[0].clone(), polys[1].clone(), polys[2].clone()], + E::ONE, + ); + virtual_poly.add_mle_list( + vec![polys[0].clone(), polys[1].clone(), polys[4].clone()], + beta, + ); + virtual_poly.add_mle_list( + vec![polys[0].clone(), polys[2].clone(), polys[3].clone()], + beta, + ); + virtual_poly + }) + .collect(); + + let asserted_sum = fs + .iter() + .fold(vec![E::ONE; 1 << nv], |mut acc, f| { + op_mle!(f, |f| { + (0..f.len()).zip(acc.iter_mut()).for_each(|(i, acc)| { + *acc *= f[i]; + }); + acc + }) + }) + .iter() + .sum::(); + + (asserted_sum, virtual_poly_v1, virtual_poly_v2) +} + +fn sumcheck_fn(c: &mut Criterion) { + type E = GoldilocksExt2; + + for nv in NV { + // expand more input size once runtime is acceptable + let mut group = c.benchmark_group(format!("logup_sumcheck_nv_{}", nv)); + group.sample_size(NUM_SAMPLES); + + // Benchmark the proving time + group.bench_function( + BenchmarkId::new("prove_sumcheck", format!("sumcheck_nv_{}", nv)), + |b| { + b.iter_custom(|iters| { + let mut time = Duration::new(0, 0); + for _ in 0..iters { + let mut prover_transcript = Transcript::::new(b"test"); + let (_, virtual_poly, _) = { prepare_input(nv) }; + + let instant = std::time::Instant::now(); + #[allow(deprecated)] + let (_sumcheck_proof_v1, _) = IOPProverState::::prove_parallel( + virtual_poly.clone(), + &mut prover_transcript, + ); + let elapsed = instant.elapsed(); + time += elapsed; + } + + time + }); + }, + ); + + group.finish(); + } +} + +fn devirgo_sumcheck_fn(c: &mut Criterion) { + type E = GoldilocksExt2; + + let threads = max_usable_threads(); + for nv in NV { + // expand more input size once runtime is acceptable + let mut group = c.benchmark_group(format!("logup_devirgo_nv_{}", nv)); + group.sample_size(NUM_SAMPLES); + + // Benchmark the proving time + group.bench_function( + BenchmarkId::new("prove_sumcheck", format!("devirgo_nv_{}", nv)), + |b| { + b.iter_custom(|iters| { + let mut time = Duration::new(0, 0); + for _ in 0..iters { + let mut prover_transcript = Transcript::::new(b"test"); + let (_, _, virtual_poly_splitted) = { prepare_input(nv) }; + + let instant = std::time::Instant::now(); + let (_sumcheck_proof_v2, _) = IOPProverState::::prove_batch_polys( + threads, + virtual_poly_splitted, + &mut prover_transcript, + ); + let elapsed = instant.elapsed(); + time += elapsed; + } + time + }); + }, + ); + + group.finish(); + } +} diff --git a/transcript/Cargo.toml b/transcript/Cargo.toml index f784b689f..62f385399 100644 --- a/transcript/Cargo.toml +++ b/transcript/Cargo.toml @@ -12,7 +12,7 @@ version.workspace = true [dependencies] crossbeam-channel.workspace = true ff.workspace = true -ff_ext = { path = "../ff_ext" } +ff_ext.workspace = true goldilocks.workspace = true poseidon.workspace = true serde.workspace = true