From 79cb855fe94318d07f0e31cb6ea06c8032ee0146 Mon Sep 17 00:00:00 2001 From: PatStiles Date: Wed, 10 Jan 2024 00:17:04 -0600 Subject: [PATCH] finish impl --- crypto/src/subprotocols/sumcheck.rs | 392 ++++++++++++++++++++++------ 1 file changed, 311 insertions(+), 81 deletions(-) diff --git a/crypto/src/subprotocols/sumcheck.rs b/crypto/src/subprotocols/sumcheck.rs index fce4e842a4..8ac3c50e85 100644 --- a/crypto/src/subprotocols/sumcheck.rs +++ b/crypto/src/subprotocols/sumcheck.rs @@ -1,3 +1,4 @@ +use core::fmt::Display; use std::marker::PhantomData; use crate::fiat_shamir::transcript::Transcript; @@ -80,6 +81,19 @@ where ) } +#[derive(Debug)] +pub enum SumcheckError { + InvalidProof, +} + +impl Display for SumcheckError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + SumcheckError::InvalidProof => write!(f, "Sumcheck Proof Invalid") + } + } +} + // Proof attesting to sum over the boolean hypercube #[derive(Debug)] pub struct SumcheckProof @@ -114,7 +128,7 @@ where poly_b: &mut DenseMultilinearPolynomial, comb_func: E, transcript: &mut impl Transcript, - ) -> SumcheckProof + ) -> (SumcheckProof, Vec>) where E: Fn(&FieldElement, &FieldElement) -> FieldElement + Sync, { @@ -147,11 +161,12 @@ where poly_b.fix_variable(&challenge); } - SumcheckProof { + (SumcheckProof { poly: poly_a.clone(), sum: sum.clone(), round_uni_polys, - } + }, + challenges) } pub fn prove_quadratic_batched( @@ -228,7 +243,7 @@ where poly_c: &mut DenseMultilinearPolynomial, comb_func: E, transcript: &mut impl Transcript, - ) -> SumcheckProof + ) -> (SumcheckProof, Vec>) where E: Fn(&FieldElement, &FieldElement, &FieldElement) -> FieldElement + Sync, { let mut round_uni_polys: Vec>> = @@ -264,11 +279,12 @@ where round_uni_polys.push(poly); } - SumcheckProof { + (SumcheckProof { poly: poly_a.clone(), sum: sum.clone(), round_uni_polys, - } + }, + challenges) } pub fn prove_cubic_batched( @@ -279,7 +295,7 @@ where powers: Option<&[FieldElement]>, comb_func: E, transcript: &mut impl Transcript, - ) -> SumcheckProof + ) -> (SumcheckProof, Vec>) where E: Fn(&FieldElement, &FieldElement, &FieldElement) -> FieldElement + Sync, { let mut round_uni_polys: Vec>> = @@ -334,27 +350,105 @@ where round_uni_polys.push(poly); } - SumcheckProof { + (SumcheckProof { poly: poly_a[0].clone(), sum: sum.clone(), round_uni_polys, - } + }, + challenges) } // Special instance of sumcheck for a cubic polynomial with an additional additive term: // this is used in Spartan: (a * ((b * c) - d)) pub fn prove_cubic_additive_term( sum: &FieldElement, - poly_a: &DenseMultilinearPolynomial, - poly_b: &DenseMultilinearPolynomial, - poly_c: &DenseMultilinearPolynomial, - poly_d: &DenseMultilinearPolynomial, - comb_func: F, + poly_a: &mut DenseMultilinearPolynomial, + poly_b: &mut DenseMultilinearPolynomial, + poly_c: &mut DenseMultilinearPolynomial, + poly_d: &mut DenseMultilinearPolynomial, + comb_func: E, transcript: &mut impl Transcript, - ) -> SumcheckProof + ) -> (SumcheckProof, Vec>) where - E: Fn(&FieldElement, &FieldElement) -> FieldElement + Sync, { - todo!() + E: Fn(&FieldElement, &FieldElement, &FieldElement, &FieldElement) -> FieldElement + Sync, { + + let mut round_uni_polys: Vec>> = + Vec::with_capacity(poly_a.num_vars()); + let mut challenges = Vec::with_capacity(poly_a.num_vars()); + let mut prev_round_claim = sum.clone(); + + for _ in 0..poly_a.num_vars() { + let poly = { + let (eval_point_0, eval_point_2, eval_point_3) = { + let len = poly_a.len() / 2; + (0..len) + .into_par_iter() + .map(|i| { + // eval 0: bound_func is A(low) + let eval_point_0 = comb_func(&poly_a[i], &poly_b[i], &poly_c[i], &poly_d[i]); + + // eval 2: bound_func is -A(low) + 2*A(high) + let poly_a_point_2 = &poly_a[len + i] + &poly_a[len + i] - &poly_a[i]; + let poly_b_point_2 = &poly_b[len + i] + &poly_b[len + i] - &poly_b[i]; + let poly_c_point_2 = &poly_c[len + i] + &poly_c[len + i] - &poly_c[i]; + let poly_d_point_2 = &poly_d[len + i] + &poly_d[len + i] - &poly_c[i]; + let eval_point_2 = comb_func( + &poly_a_point_2, + &poly_b_point_2, + &poly_c_point_2, + &poly_d_point_2, + ); + + // eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2) + let poly_a_point_3 = poly_a_point_2 + &poly_a[len + i] - &poly_a[i]; + let poly_b_point_3 = poly_b_point_2 + &poly_b[len + i] - &poly_b[i]; + let poly_c_point_3 = poly_c_point_2 + &poly_c[len + i] - &poly_c[i]; + let poly_d_point_3 = poly_d_point_2 + &poly_d[len + i] - &poly_d[i]; + let eval_point_3 = comb_func( + &poly_a_point_3, + &poly_b_point_3, + &poly_c_point_3, + &poly_d_point_3, + ); + (eval_point_0, eval_point_2, eval_point_3) + }) + .reduce( + || (FieldElement::zero(), FieldElement + ::zero(), FieldElement::zero()), + |a, b| (a.0 + b.0, a.1 + b.1, a.2 + b.2), + ) + }; + let evals = vec![ + eval_point_0.clone(), + prev_round_claim - eval_point_0, + eval_point_2, + eval_point_3, + ]; + Polynomial::new(&evals) + }; + + // TODO append the prover's message to the transcript + + // Squeeze Verifier Challenge for next round + let challenge = FieldElement::from_bytes_be(&transcript.challenge()).unwrap(); + challenges.push(challenge.clone()); + + prev_round_claim = poly.evaluate(&challenge); + round_uni_polys.push(poly); + + // bound all tables to the verifier's challenege + poly_a.fix_variable(&challenge); + poly_b.fix_variable(&challenge); + poly_c.fix_variable(&challenge); + poly_d.fix_variable(&challenge); + } + + (SumcheckProof { + poly: poly_a.clone(), + sum: sum.clone(), + round_uni_polys, + }, + challenges) } // Create a test for this @@ -362,7 +456,7 @@ where poly: &mut DenseMultilinearPolynomial, sum: &FieldElement, transcript: &mut impl Transcript, - ) -> SumcheckProof { + ) -> (SumcheckProof, Vec>) { let mut round_uni_polys: Vec>> = Vec::with_capacity(poly.num_vars()); let mut challenges = Vec::with_capacity(poly.num_vars()); @@ -402,18 +496,18 @@ where poly.fix_variable(&challenge); } - SumcheckProof { + (SumcheckProof { poly: poly.clone(), sum: sum.clone(), round_uni_polys, - } + }, challenges) } // Verifies a sumcheck proof returning the claimed evaluation and random points used during sumcheck rounds pub fn verify( proof: SumcheckProof, transcript: &mut impl Transcript, - ) -> (FieldElement, Vec>) { + ) -> Result<(FieldElement, Vec>), SumcheckError> { let mut e = proof.sum.clone(); let mut r: Vec> = Vec::with_capacity(proof.poly.num_vars()); @@ -424,10 +518,9 @@ where // Verify degree bound // check if G_k(0) + G_k(1) = e - assert_eq!( - poly.evaluate(&FieldElement::::zero()) + poly.evaluate(&FieldElement::one()), - e - ); + if poly.evaluate(&FieldElement::::zero()) + poly.evaluate(&FieldElement::one()) != e { + return Err(SumcheckError::InvalidProof); + } //transcript.append(poly); let challenge = FieldElement::from_bytes_be(&transcript.challenge()).unwrap(); @@ -435,23 +528,21 @@ where e = poly.evaluate(&challenge); } - (proof.sum, r) + Ok((proof.sum, r)) } } #[cfg(test)] mod test { - use core::num; - use crate::fiat_shamir::default_transcript::DefaultTranscript; - use crate::subprotocols::sumcheck::{Sumcheck, SumcheckProof}; + use crate::subprotocols::sumcheck::Sumcheck; use lambdaworks_math::field::element::FieldElement; - use lambdaworks_math::field::fields::fft_friendly::babybear::Babybear31PrimeField; + use lambdaworks_math::field::fields::fft_friendly::u64_goldilocks::U64GoldilocksPrimeField; use lambdaworks_math::field::traits::IsField; use lambdaworks_math::polynomial::dense_multilinear_poly::DenseMultilinearPolynomial; - type F = Babybear31PrimeField; - type FE = FieldElement; + type F = U64GoldilocksPrimeField; + type FE = FieldElement; pub fn index_to_field_bitvector( value: usize, bits: usize) -> Vec> { let mut vec: Vec> = Vec::with_capacity(bits); @@ -468,67 +559,206 @@ mod test { #[test] fn prove_cubic() { - // Create three dense polynomials (all the same) - let num_vars = 3; - let num_evals = (2usize).pow(num_vars as u32); - let mut evals: Vec> = Vec::with_capacity(num_evals); - for i in 0..num_evals { - evals.push(FieldElement::from(8 + i as u64)); - } + // Create three dense polynomials (all the same) + let num_vars = 3; + let num_evals = (2usize).pow(num_vars as u32); + let mut evals: Vec> = Vec::with_capacity(num_evals); + for i in 0..num_evals { + evals.push(FieldElement::from(8 + i as u64)); + } - let a: DenseMultilinearPolynomial = DenseMultilinearPolynomial::new(evals.clone()); - let b: DenseMultilinearPolynomial = DenseMultilinearPolynomial::new(evals.clone()); - let c: DenseMultilinearPolynomial = DenseMultilinearPolynomial::new(evals.clone()); + let mut a: DenseMultilinearPolynomial = DenseMultilinearPolynomial::new(evals.clone()); + let mut b: DenseMultilinearPolynomial = DenseMultilinearPolynomial::new(evals.clone()); + let mut c: DenseMultilinearPolynomial = DenseMultilinearPolynomial::new(evals.clone()); - let mut claim = FieldElement::::zero(); - for i in 0..num_evals { + let mut claim = FieldElement::::zero(); + for i in 0..num_evals { + + claim += a.evaluate(&index_to_field_bitvector(i, num_vars)).unwrap() + * b.evaluate(&index_to_field_bitvector(i, num_vars)).unwrap() + * c.evaluate(&index_to_field_bitvector(i, num_vars)).unwrap(); + } - claim += a.evaluate(&index_to_field_bitvector(i, num_vars)).unwrap() - * b.evaluate(&index_to_field_bitvector(i, num_vars)).unwrap() - * c.evaluate(&index_to_field_bitvector(i, num_vars)).unwrap(); + let comb_func_prod = + |a: &FieldElement, b: &FieldElement, c: &FieldElement| -> FieldElement { a * b * c }; + + let r = vec![FieldElement::from(3), FieldElement::from(1), FieldElement::from(3)]; // point 0,0,0 within the boolean hypercube + + let mut transcript = DefaultTranscript::new(); + let (proof, challenges) = + Sumcheck::::prove_cubic( + &claim, + &mut a, + &mut b, + &mut c, + comb_func_prod, + &mut transcript, + ); + + let mut transcript = DefaultTranscript::new(); + let verify_result = Sumcheck::verify(proof, &mut transcript); + assert!(verify_result.is_ok()); + + let (verify_evaluation, verify_randomness) = verify_result.unwrap(); + assert_eq!(challenges, verify_randomness); + assert_eq!(challenges, r); + + // Consider this the opening proof to a(r) * b(r) * c(r) + let a = a.evaluate(&challenges.as_slice()).unwrap(); + let b = b.evaluate(&challenges.as_slice()).unwrap(); + let c = c.evaluate(&challenges.as_slice()).unwrap(); + + let oracle_query = a * b * c; + assert_eq!(verify_evaluation, oracle_query); } - let mut polys = [a.clone(), b.clone(), c.clone()]; - - let comb_func_prod = - |polys: &[FieldElement; 3]| -> FieldElement { polys.iter().fold(FieldElement::one(), |acc, poly| acc * *poly) }; - - let r = vec![FieldElement::from(3), FieldElement::from(1), FieldElement::from(3)]; // point 0,0,0 within the boolean hypercube - - let mut transcript = DefaultTranscript::new(); - let proof = - Sumcheck::::prove_cubic( - &claim, - &poly_a, - &poly_b, - &poly_c, - comb_func_prod, - &mut transcript, - ); - - let mut transcript = DefaultTranscript::new(); - let verify_result = Sumcheck::verify(proof, &mut transcript); - assert!(verify_result.is_ok()); - - let (verify_evaluation, verify_randomness) = verify_result.unwrap(); - assert_eq!(prove_randomness, verify_randomness); - assert_eq!(prove_randomness, r); - - // Consider this the opening proof to a(r) * b(r) * c(r) - let a = a.evaluate(prove_randomness.as_slice()).unwrap(); - let b = b.evaluate(prove_randomness.as_slice()).unwrap(); - let c = c.evaluate(prove_randomness.as_slice()).unwrap(); - - let oracle_query = a * b * c; - assert_eq!(verify_evaluation, oracle_query); + + #[test] + fn prove_cubic_batched() { + + } + + #[test] + fn prove_cubic_additive() { + } #[test] fn prove_quad() { - + // Create three dense polynomials (all the same) + let num_vars = 3; + let num_evals = (2usize).pow(num_vars as u32); + let mut evals: Vec> = Vec::with_capacity(num_evals); + for i in 0..num_evals { + evals.push(FieldElement::from(8 + i as u64)); + } + + let mut a: DenseMultilinearPolynomial = DenseMultilinearPolynomial::new(evals.clone()); + let mut b: DenseMultilinearPolynomial = DenseMultilinearPolynomial::new(evals.clone()); + + let mut claim = FieldElement::::zero(); + for i in 0..num_evals { + + claim += a.evaluate(&index_to_field_bitvector(i, num_vars)).unwrap() + * b.evaluate(&index_to_field_bitvector(i, num_vars)).unwrap(); + } + + let comb_func_prod = + |a: &FieldElement, b: &FieldElement| -> FieldElement { a * b }; + + let r = vec![FieldElement::from(3), FieldElement::from(1), FieldElement::from(3)]; // point 0,0,0 within the boolean hypercube + + let mut transcript = DefaultTranscript::new(); + let (proof, challenges) = + Sumcheck::::prove_quadratic( + &claim, + &mut a, + &mut b, + comb_func_prod, + &mut transcript, + ); + + let mut transcript = DefaultTranscript::new(); + let verify_result = Sumcheck::verify(proof, &mut transcript); + assert!(verify_result.is_ok()); + + let (verify_evaluation, verify_randomness) = verify_result.unwrap(); + assert_eq!(challenges, verify_randomness); + assert_eq!(challenges, r); + + // Consider this the opening proof to a(r) * b(r) + let a = a.evaluate(&challenges.as_slice()).unwrap(); + let b = b.evaluate(&challenges.as_slice()).unwrap(); + + let oracle_query = a * b; + assert_eq!(verify_evaluation, oracle_query); + } + + #[test] + fn prove_quad_batched() { + // Create three dense polynomials (all the same) + let num_vars = 3; + let num_evals = (2usize).pow(num_vars as u32); + let mut evals: Vec> = Vec::with_capacity(num_evals); + for i in 0..num_evals { + evals.push(FieldElement::from(8 + i as u64)); + } + + let mut a: Vec> = vec![DenseMultilinearPolynomial::new(evals.clone()); 3]; + let mut b: Vec> = vec![DenseMultilinearPolynomial::new(evals.clone()); 3]; + + let mut claim = FieldElement::::zero(); + for i in 0..num_evals { + + claim += a.evaluate(&index_to_field_bitvector(i, num_vars)).unwrap() + * b.evaluate(&index_to_field_bitvector(i, num_vars)).unwrap(); + } + + let comb_func_prod = + |a: &FieldElement, b: &FieldElement| -> FieldElement { a * b }; + + let r = vec![FieldElement::from(3), FieldElement::from(1), FieldElement::from(3)]; // point 0,0,0 within the boolean hypercube + + let mut transcript = DefaultTranscript::new(); + let (proof, challenges) = + Sumcheck::::prove_quadratic( + &claim, + &mut a, + &mut b, + comb_func_prod, + &mut transcript, + ); + + let mut transcript = DefaultTranscript::new(); + let verify_result = Sumcheck::verify(proof, &mut transcript); + assert!(verify_result.is_ok()); + + let (verify_evaluation, verify_randomness) = verify_result.unwrap(); + assert_eq!(challenges, verify_randomness); + assert_eq!(challenges, r); + + // Consider this the opening proof to a(r) * b(r) + let a = a.evaluate(&challenges.as_slice()).unwrap(); + let b = b.evaluate(&challenges.as_slice()).unwrap(); + + let oracle_query = a * b; + assert_eq!(verify_evaluation, oracle_query); } #[test] fn prove_single() { + // Create three dense polynomials (all the same) + let num_vars = 3; + let num_evals = (2usize).pow(num_vars as u32); + let mut evals: Vec> = Vec::with_capacity(num_evals); + for i in 0..num_evals { + evals.push(FieldElement::from(8 + i as u64)); + } + + let mut a: DenseMultilinearPolynomial = DenseMultilinearPolynomial::new(evals.clone()); + + let mut claim = FieldElement::::zero(); + for i in 0..num_evals { + claim += a.evaluate(&index_to_field_bitvector(i, num_vars)).unwrap() + } + + let r = vec![FieldElement::from(3), FieldElement::from(1), FieldElement::from(3)]; // point 0,0,0 within the boolean hypercube + + let mut transcript = DefaultTranscript::new(); + let (proof, challenges) = + Sumcheck::::prove_single( + &mut a, + &claim, + &mut transcript, + ); + + let mut transcript = DefaultTranscript::new(); + let verify_result = Sumcheck::verify(proof, &mut transcript); + assert!(verify_result.is_ok()); + + let (verify_evaluation, verify_randomness) = verify_result.unwrap(); + assert_eq!(challenges, verify_randomness); + assert_eq!(challenges, r); + assert_eq!(verify_evaluation, a.evaluate(&challenges.as_slice()).unwrap()); } }