diff --git a/crypto/src/subprotocols/sumcheck.rs b/crypto/src/subprotocols/sumcheck.rs index 8ac3c50e85..3439f27743 100644 --- a/crypto/src/subprotocols/sumcheck.rs +++ b/crypto/src/subprotocols/sumcheck.rs @@ -89,7 +89,7 @@ pub enum SumcheckError { impl Display for SumcheckError { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { - SumcheckError::InvalidProof => write!(f, "Sumcheck Proof Invalid") + SumcheckError::InvalidProof => write!(f, "Sumcheck Proof Invalid"), } } } @@ -161,12 +161,14 @@ where poly_b.fix_variable(&challenge); } - (SumcheckProof { - poly: poly_a.clone(), - sum: sum.clone(), - round_uni_polys, - }, - challenges) + ( + SumcheckProof { + poly: poly_a.clone(), + sum: sum.clone(), + round_uni_polys, + }, + challenges, + ) } pub fn prove_quadratic_batched( @@ -245,16 +247,17 @@ where transcript: &mut impl Transcript, ) -> (SumcheckProof, Vec>) where - E: Fn(&FieldElement, &FieldElement, &FieldElement) -> FieldElement + Sync, { + E: Fn(&FieldElement, &FieldElement, &FieldElement) -> FieldElement + Sync, + { let mut round_uni_polys: Vec>> = Vec::with_capacity(poly_a.num_vars()); let mut challenges = Vec::with_capacity(poly_a.num_vars()); let mut prev_round_claim = sum.clone(); for _ in 0..poly_a.num_vars() { - let poly = { - let (eval_point_0, eval_point_2, eval_point_3) = eval_points_cubic(poly_a, poly_b, poly_c, &comb_func); + let (eval_point_0, eval_point_2, eval_point_3) = + eval_points_cubic(poly_a, poly_b, poly_c, &comb_func); let evals = vec![ eval_point_0.clone(), prev_round_claim - eval_point_0, @@ -279,12 +282,14 @@ where round_uni_polys.push(poly); } - (SumcheckProof { - poly: poly_a.clone(), - sum: sum.clone(), - round_uni_polys, - }, - challenges) + ( + SumcheckProof { + poly: poly_a.clone(), + sum: sum.clone(), + round_uni_polys, + }, + challenges, + ) } pub fn prove_cubic_batched( @@ -297,7 +302,8 @@ where transcript: &mut impl Transcript, ) -> (SumcheckProof, Vec>) where - E: Fn(&FieldElement, &FieldElement, &FieldElement) -> FieldElement + Sync, { + E: Fn(&FieldElement, &FieldElement, &FieldElement) -> FieldElement + Sync, + { let mut round_uni_polys: Vec>> = Vec::with_capacity(poly_a[0].num_vars()); let mut challenges = Vec::with_capacity(poly_a[0].num_vars()); @@ -330,7 +336,7 @@ where evals_combined_0.clone(), prev_round_claim - evals_combined_0, evals_combined_2, - evals_combined_3 + evals_combined_3, ]; let poly = Polynomial::new(&evals); @@ -350,12 +356,14 @@ where round_uni_polys.push(poly); } - (SumcheckProof { - poly: poly_a[0].clone(), - sum: sum.clone(), - round_uni_polys, - }, - challenges) + ( + SumcheckProof { + poly: poly_a[0].clone(), + sum: sum.clone(), + round_uni_polys, + }, + challenges, + ) } // Special instance of sumcheck for a cubic polynomial with an additional additive term: @@ -370,8 +378,14 @@ where transcript: &mut impl Transcript, ) -> (SumcheckProof, Vec>) where - E: Fn(&FieldElement, &FieldElement, &FieldElement, &FieldElement) -> FieldElement + Sync, { - + E: Fn( + &FieldElement, + &FieldElement, + &FieldElement, + &FieldElement, + ) -> FieldElement + + Sync, + { let mut round_uni_polys: Vec>> = Vec::with_capacity(poly_a.num_vars()); let mut challenges = Vec::with_capacity(poly_a.num_vars()); @@ -382,41 +396,47 @@ where let (eval_point_0, eval_point_2, eval_point_3) = { let len = poly_a.len() / 2; (0..len) - .into_par_iter() - .map(|i| { - // eval 0: bound_func is A(low) - let eval_point_0 = comb_func(&poly_a[i], &poly_b[i], &poly_c[i], &poly_d[i]); - - // eval 2: bound_func is -A(low) + 2*A(high) - let poly_a_point_2 = &poly_a[len + i] + &poly_a[len + i] - &poly_a[i]; - let poly_b_point_2 = &poly_b[len + i] + &poly_b[len + i] - &poly_b[i]; - let poly_c_point_2 = &poly_c[len + i] + &poly_c[len + i] - &poly_c[i]; - let poly_d_point_2 = &poly_d[len + i] + &poly_d[len + i] - &poly_c[i]; - let eval_point_2 = comb_func( - &poly_a_point_2, - &poly_b_point_2, - &poly_c_point_2, - &poly_d_point_2, - ); - - // eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2) - let poly_a_point_3 = poly_a_point_2 + &poly_a[len + i] - &poly_a[i]; - let poly_b_point_3 = poly_b_point_2 + &poly_b[len + i] - &poly_b[i]; - let poly_c_point_3 = poly_c_point_2 + &poly_c[len + i] - &poly_c[i]; - let poly_d_point_3 = poly_d_point_2 + &poly_d[len + i] - &poly_d[i]; - let eval_point_3 = comb_func( - &poly_a_point_3, - &poly_b_point_3, - &poly_c_point_3, - &poly_d_point_3, - ); - (eval_point_0, eval_point_2, eval_point_3) - }) - .reduce( - || (FieldElement::zero(), FieldElement - ::zero(), FieldElement::zero()), - |a, b| (a.0 + b.0, a.1 + b.1, a.2 + b.2), - ) + .into_par_iter() + .map(|i| { + // eval 0: bound_func is A(low) + let eval_point_0 = + comb_func(&poly_a[i], &poly_b[i], &poly_c[i], &poly_d[i]); + + // eval 2: bound_func is -A(low) + 2*A(high) + let poly_a_point_2 = &poly_a[len + i] + &poly_a[len + i] - &poly_a[i]; + let poly_b_point_2 = &poly_b[len + i] + &poly_b[len + i] - &poly_b[i]; + let poly_c_point_2 = &poly_c[len + i] + &poly_c[len + i] - &poly_c[i]; + let poly_d_point_2 = &poly_d[len + i] + &poly_d[len + i] - &poly_c[i]; + let eval_point_2 = comb_func( + &poly_a_point_2, + &poly_b_point_2, + &poly_c_point_2, + &poly_d_point_2, + ); + + // eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2) + let poly_a_point_3 = poly_a_point_2 + &poly_a[len + i] - &poly_a[i]; + let poly_b_point_3 = poly_b_point_2 + &poly_b[len + i] - &poly_b[i]; + let poly_c_point_3 = poly_c_point_2 + &poly_c[len + i] - &poly_c[i]; + let poly_d_point_3 = poly_d_point_2 + &poly_d[len + i] - &poly_d[i]; + let eval_point_3 = comb_func( + &poly_a_point_3, + &poly_b_point_3, + &poly_c_point_3, + &poly_d_point_3, + ); + (eval_point_0, eval_point_2, eval_point_3) + }) + .reduce( + || { + ( + FieldElement::zero(), + FieldElement::zero(), + FieldElement::zero(), + ) + }, + |a, b| (a.0 + b.0, a.1 + b.1, a.2 + b.2), + ) }; let evals = vec![ eval_point_0.clone(), @@ -443,12 +463,14 @@ where poly_d.fix_variable(&challenge); } - (SumcheckProof { - poly: poly_a.clone(), - sum: sum.clone(), - round_uni_polys, - }, - challenges) + ( + SumcheckProof { + poly: poly_a.clone(), + sum: sum.clone(), + round_uni_polys, + }, + challenges, + ) } // Create a test for this @@ -496,11 +518,14 @@ where poly.fix_variable(&challenge); } - (SumcheckProof { - poly: poly.clone(), - sum: sum.clone(), - round_uni_polys, - }, challenges) + ( + SumcheckProof { + poly: poly.clone(), + sum: sum.clone(), + round_uni_polys, + }, + challenges, + ) } // Verifies a sumcheck proof returning the claimed evaluation and random points used during sumcheck rounds @@ -518,7 +543,8 @@ where // Verify degree bound // check if G_k(0) + G_k(1) = e - if poly.evaluate(&FieldElement::::zero()) + poly.evaluate(&FieldElement::one()) != e { + if poly.evaluate(&FieldElement::::zero()) + poly.evaluate(&FieldElement::one()) != e + { return Err(SumcheckError::InvalidProof); } //transcript.append(poly); @@ -544,7 +570,7 @@ mod test { type F = U64GoldilocksPrimeField; type FE = FieldElement; - pub fn index_to_field_bitvector( value: usize, bits: usize) -> Vec> { + pub fn index_to_field_bitvector(value: usize, bits: usize) -> Vec> { let mut vec: Vec> = Vec::with_capacity(bits); for i in (0..bits).rev() { @@ -564,7 +590,7 @@ mod test { let num_evals = (2usize).pow(num_vars as u32); let mut evals: Vec> = Vec::with_capacity(num_evals); for i in 0..num_evals { - evals.push(FieldElement::from(8 + i as u64)); + evals.push(FieldElement::from(8 + i as u64)); } let mut a: DenseMultilinearPolynomial = DenseMultilinearPolynomial::new(evals.clone()); @@ -573,20 +599,24 @@ mod test { let mut claim = FieldElement::::zero(); for i in 0..num_evals { - - claim += a.evaluate(&index_to_field_bitvector(i, num_vars)).unwrap() - * b.evaluate(&index_to_field_bitvector(i, num_vars)).unwrap() - * c.evaluate(&index_to_field_bitvector(i, num_vars)).unwrap(); + claim += a.evaluate(&index_to_field_bitvector(i, num_vars)).unwrap() + * b.evaluate(&index_to_field_bitvector(i, num_vars)).unwrap() + * c.evaluate(&index_to_field_bitvector(i, num_vars)).unwrap(); } - let comb_func_prod = - |a: &FieldElement, b: &FieldElement, c: &FieldElement| -> FieldElement { a * b * c }; + let comb_func_prod = |a: &FieldElement, + b: &FieldElement, + c: &FieldElement| + -> FieldElement { a * b * c }; - let r = vec![FieldElement::from(3), FieldElement::from(1), FieldElement::from(3)]; // point 0,0,0 within the boolean hypercube + let r = vec![ + FieldElement::from(3), + FieldElement::from(1), + FieldElement::from(3), + ]; // point 0,0,0 within the boolean hypercube let mut transcript = DefaultTranscript::new(); - let (proof, challenges) = - Sumcheck::::prove_cubic( + let (proof, challenges) = Sumcheck::::prove_cubic( &claim, &mut a, &mut b, @@ -613,14 +643,10 @@ mod test { } #[test] - fn prove_cubic_batched() { - - } + fn prove_cubic_batched() {} #[test] - fn prove_cubic_additive() { - - } + fn prove_cubic_additive() {} #[test] fn prove_quad() { @@ -629,7 +655,7 @@ mod test { let num_evals = (2usize).pow(num_vars as u32); let mut evals: Vec> = Vec::with_capacity(num_evals); for i in 0..num_evals { - evals.push(FieldElement::from(8 + i as u64)); + evals.push(FieldElement::from(8 + i as u64)); } let mut a: DenseMultilinearPolynomial = DenseMultilinearPolynomial::new(evals.clone()); @@ -637,25 +663,22 @@ mod test { let mut claim = FieldElement::::zero(); for i in 0..num_evals { - - claim += a.evaluate(&index_to_field_bitvector(i, num_vars)).unwrap() - * b.evaluate(&index_to_field_bitvector(i, num_vars)).unwrap(); + claim += a.evaluate(&index_to_field_bitvector(i, num_vars)).unwrap() + * b.evaluate(&index_to_field_bitvector(i, num_vars)).unwrap(); } let comb_func_prod = - |a: &FieldElement, b: &FieldElement| -> FieldElement { a * b }; + |a: &FieldElement, b: &FieldElement| -> FieldElement { a * b }; - let r = vec![FieldElement::from(3), FieldElement::from(1), FieldElement::from(3)]; // point 0,0,0 within the boolean hypercube + let r = vec![ + FieldElement::from(3), + FieldElement::from(1), + FieldElement::from(3), + ]; // point 0,0,0 within the boolean hypercube let mut transcript = DefaultTranscript::new(); let (proof, challenges) = - Sumcheck::::prove_quadratic( - &claim, - &mut a, - &mut b, - comb_func_prod, - &mut transcript, - ); + Sumcheck::::prove_quadratic(&claim, &mut a, &mut b, comb_func_prod, &mut transcript); let mut transcript = DefaultTranscript::new(); let verify_result = Sumcheck::verify(proof, &mut transcript); @@ -670,7 +693,7 @@ mod test { let b = b.evaluate(&challenges.as_slice()).unwrap(); let oracle_query = a * b; - assert_eq!(verify_evaluation, oracle_query); + assert_eq!(verify_evaluation, oracle_query); } #[test] @@ -680,33 +703,32 @@ mod test { let num_evals = (2usize).pow(num_vars as u32); let mut evals: Vec> = Vec::with_capacity(num_evals); for i in 0..num_evals { - evals.push(FieldElement::from(8 + i as u64)); + evals.push(FieldElement::from(8 + i as u64)); } - let mut a: Vec> = vec![DenseMultilinearPolynomial::new(evals.clone()); 3]; - let mut b: Vec> = vec![DenseMultilinearPolynomial::new(evals.clone()); 3]; + let mut a: Vec> = + vec![DenseMultilinearPolynomial::new(evals.clone()); 3]; + let mut b: Vec> = + vec![DenseMultilinearPolynomial::new(evals.clone()); 3]; let mut claim = FieldElement::::zero(); for i in 0..num_evals { - - claim += a.evaluate(&index_to_field_bitvector(i, num_vars)).unwrap() - * b.evaluate(&index_to_field_bitvector(i, num_vars)).unwrap(); + claim += a.evaluate(&index_to_field_bitvector(i, num_vars)).unwrap() + * b.evaluate(&index_to_field_bitvector(i, num_vars)).unwrap(); } let comb_func_prod = - |a: &FieldElement, b: &FieldElement| -> FieldElement { a * b }; + |a: &FieldElement, b: &FieldElement| -> FieldElement { a * b }; - let r = vec![FieldElement::from(3), FieldElement::from(1), FieldElement::from(3)]; // point 0,0,0 within the boolean hypercube + let r = vec![ + FieldElement::from(3), + FieldElement::from(1), + FieldElement::from(3), + ]; // point 0,0,0 within the boolean hypercube let mut transcript = DefaultTranscript::new(); let (proof, challenges) = - Sumcheck::::prove_quadratic( - &claim, - &mut a, - &mut b, - comb_func_prod, - &mut transcript, - ); + Sumcheck::::prove_quadratic(&claim, &mut a, &mut b, comb_func_prod, &mut transcript); let mut transcript = DefaultTranscript::new(); let verify_result = Sumcheck::verify(proof, &mut transcript); @@ -721,7 +743,7 @@ mod test { let b = b.evaluate(&challenges.as_slice()).unwrap(); let oracle_query = a * b; - assert_eq!(verify_evaluation, oracle_query); + assert_eq!(verify_evaluation, oracle_query); } #[test] @@ -731,7 +753,7 @@ mod test { let num_evals = (2usize).pow(num_vars as u32); let mut evals: Vec> = Vec::with_capacity(num_evals); for i in 0..num_evals { - evals.push(FieldElement::from(8 + i as u64)); + evals.push(FieldElement::from(8 + i as u64)); } let mut a: DenseMultilinearPolynomial = DenseMultilinearPolynomial::new(evals.clone()); @@ -741,15 +763,14 @@ mod test { claim += a.evaluate(&index_to_field_bitvector(i, num_vars)).unwrap() } - let r = vec![FieldElement::from(3), FieldElement::from(1), FieldElement::from(3)]; // point 0,0,0 within the boolean hypercube + let r = vec![ + FieldElement::from(3), + FieldElement::from(1), + FieldElement::from(3), + ]; // point 0,0,0 within the boolean hypercube let mut transcript = DefaultTranscript::new(); - let (proof, challenges) = - Sumcheck::::prove_single( - &mut a, - &claim, - &mut transcript, - ); + let (proof, challenges) = Sumcheck::::prove_single(&mut a, &claim, &mut transcript); let mut transcript = DefaultTranscript::new(); let verify_result = Sumcheck::verify(proof, &mut transcript); @@ -759,6 +780,9 @@ mod test { assert_eq!(challenges, verify_randomness); assert_eq!(challenges, r); - assert_eq!(verify_evaluation, a.evaluate(&challenges.as_slice()).unwrap()); + assert_eq!( + verify_evaluation, + a.evaluate(&challenges.as_slice()).unwrap() + ); } } diff --git a/math/src/fft/test_helpers.rs b/math/src/fft/test_helpers.rs index b6e483c08a..c12fa2cda1 100644 --- a/math/src/fft/test_helpers.rs +++ b/math/src/fft/test_helpers.rs @@ -34,7 +34,13 @@ pub fn naive_matrix_dft_test(input: &[FieldElement]) -> Vec]) -> Result, MultilinearError> { // r must have a value for each variable if r.len() != self.num_vars() { - return Err(MultilinearError::IncorrectNumberofEvaluationPoints(r.len(), self.num_vars())); + return Err(MultilinearError::IncorrectNumberofEvaluationPoints( + r.len(), + self.num_vars(), + )); } let mut chis: Vec> = @@ -77,7 +80,10 @@ where } } if chis.len() != self.evals.len() { - return Err(MultilinearError::ChisAndEvalsMismatch(chis.len(), self.evals.len())); + return Err(MultilinearError::ChisAndEvalsMismatch( + chis.len(), + self.evals.len(), + )); } Ok((0..chis.len()) .into_par_iter() @@ -101,7 +107,10 @@ where } } if chis.len() != evals.len() { - return Err(MultilinearError::ChisAndEvalsMismatch(chis.len(), evals.len())); + return Err(MultilinearError::ChisAndEvalsMismatch( + chis.len(), + evals.len(), + )); } Ok((0..evals.len()).map(|i| &evals[i] * &chis[i]).sum()) } diff --git a/math/src/polynomial/error.rs b/math/src/polynomial/error.rs index 2264c90df1..e2f7805ccd 100644 --- a/math/src/polynomial/error.rs +++ b/math/src/polynomial/error.rs @@ -18,4 +18,4 @@ impl Display for MultilinearError { } #[cfg(feature = "std")] -impl std::error::Error for MultilinearError {} \ No newline at end of file +impl std::error::Error for MultilinearError {}