Skip to content

Commit

Permalink
fixes var names
Browse files Browse the repository at this point in the history
  • Loading branch information
PatStiles committed Jan 31, 2024
1 parent 61e5bd8 commit 83619cb
Showing 1 changed file with 59 additions and 77 deletions.
136 changes: 59 additions & 77 deletions crypto/src/subprotocols/sumcheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ use core::fmt::Display;
use std::marker::PhantomData;

use crate::fiat_shamir::transcript::Transcript;
use lambdaworks_math::traits::AsBytes;
use lambdaworks_math::field::element::FieldElement;
use lambdaworks_math::field::traits::{IsField, IsPrimeField};
use lambdaworks_math::polynomial::{
dense_multilinear_poly::DenseMultilinearPolynomial, Polynomial,
};
use lambdaworks_math::traits::AsBytes;
use lambdaworks_math::traits::ByteConversion;
#[cfg(feature = "parallel")]
use rayon::iter::{IntoParallelIterator, IntoParallelIterator, ParallelIterator};
Expand Down Expand Up @@ -37,8 +37,9 @@ where
(eval_0, eval_2)
});
#[cfg(not(feature = "parallel"))]
let res = res.fold((FieldElement::zero(), FieldElement::zero()),
|a, b| (a.0 + b.0, a.1 + b.1));
let res = res.fold((FieldElement::zero(), FieldElement::zero()), |a, b| {
(a.0 + b.0, a.1 + b.1)
});
#[cfg(feature = "parallel")]
let res = res.reduce(
|| (FieldElement::zero(), FieldElement::zero()),
Expand Down Expand Up @@ -82,11 +83,23 @@ where
(eval_0, eval_2, eval_3)
});
#[cfg(not(feature = "parallel"))]
let res = res.fold((FieldElement::zero(), FieldElement::zero(), FieldElement::zero()),
|a, b| (a.0 + b.0, a.1 + b.1, a.2 + b.2));
let res = res.fold(
(
FieldElement::zero(),
FieldElement::zero(),
FieldElement::zero(),
),
|a, b| (a.0 + b.0, a.1 + b.1, a.2 + b.2),
);
#[cfg(feature = "parallel")]
let res = res.reduce(
|| (FieldElement::zero(), FieldElement::zero(), FieldElement::zero()),
|| {
(
FieldElement::zero(),
FieldElement::zero(),
FieldElement::zero(),
)
},
|a, b| (a.0 + b.0, a.1 + b.1, a.2 + b.2),
);

Expand Down Expand Up @@ -150,23 +163,24 @@ where
let mut prev_round_claim = sum.clone();

for _ in 0..poly_a.num_vars() {
let poly = {
let round_poly = {
let (eval_0, eval_2) = eval_points_quadratic(poly_a, poly_b, &comb_func);
let evals = vec![eval_0.clone(), prev_round_claim - eval_0, eval_2];
Polynomial::new(&evals)
};

// append round's Univariate polynomial to transcript
transcript.append(&round_poly.as_bytes());

// Squeeze Verifier Challenge for next round
let challenge = FieldElement::from_bytes_be(&transcript.challenge()).unwrap();
challenges.push(challenge.clone());

// compute next claim
prev_round_claim = poly.evaluate(&challenge);
prev_round_claim = round_poly.evaluate(&challenge);

// add univariate polynomial for this round to the proof
round_uni_polys.push(poly);
round_uni_polys.push(round_poly);

// fix next variable of poly
poly_a.fix_variable(&challenge);
Expand Down Expand Up @@ -225,9 +239,10 @@ where
prev_round_claim - evals_combined_0,
evals_combined_2,
];
let poly = Polynomial::new(&evals);
let round_poly = Polynomial::new(&evals);

// TODO append the prover's message to the transcript
transcript.append(&round_poly.as_bytes());

// Squeeze Verifier Challenge for next round
let challenge = FieldElement::from_bytes_be(&transcript.challenge()).unwrap();
Expand All @@ -239,8 +254,8 @@ where
poly_b.fix_variable(&challenge);
}

prev_round_claim = poly.evaluate(&challenge);
round_uni_polys.push(poly);
prev_round_claim = round_poly.evaluate(&challenge);
round_uni_polys.push(round_poly);
}

SumcheckProof {
Expand All @@ -267,7 +282,7 @@ where
let mut prev_round_claim = sum.clone();

for _ in 0..poly_a.num_vars() {
let poly = {
let round_poly = {
let (eval_point_0, eval_point_2, eval_point_3) =
eval_points_cubic(poly_a, poly_b, poly_c, &comb_func);
let evals = vec![
Expand All @@ -280,6 +295,7 @@ where
};

// TODO append the prover's message to the transcript
transcript.append(&round_poly.as_bytes());

// Squeeze Verifier Challenge for next round
let challenge = FieldElement::from_bytes_be(&transcript.challenge()).unwrap();
Expand All @@ -290,8 +306,8 @@ where
poly_b.fix_variable(&challenge);
poly_c.fix_variable(&challenge);

prev_round_claim = poly.evaluate(&challenge);
round_uni_polys.push(poly);
prev_round_claim = round_poly.evaluate(&challenge);
round_uni_polys.push(round_poly);
}

(
Expand Down Expand Up @@ -445,11 +461,23 @@ where
(eval_point_0, eval_point_2, eval_point_3)
});
#[cfg(not(feature = "parallel"))]
let res = res.fold((FieldElement::zero(), FieldElement::zero(), FieldElement::zero()),
|a, b| (a.0 + b.0, a.1 + b.1, a.2 + b.2));
let res = res.fold(
(
FieldElement::zero(),
FieldElement::zero(),
FieldElement::zero(),
),
|a, b| (a.0 + b.0, a.1 + b.1, a.2 + b.2),
);
#[cfg(feature = "parallel")]
let res = res.reduce(
|| (FieldElement::zero(), FieldElement::zero(), FieldElement::zero()),
|| {
(
FieldElement::zero(),
FieldElement::zero(),
FieldElement::zero(),
)
},
|a, b| (a.0 + b.0, a.1 + b.1, a.2 + b.2),
);
res
Expand Down Expand Up @@ -508,11 +536,7 @@ where
// Compute evaluation points of the Dense Multilinear Poly
let round_poly = {
let mle_half = poly.len() / 2;
// TODO: push check/error for empty poly to start of proving or into multilinear poly so we eliminate this problem entirely.
let eval_0 = (0..mle_half)
.map(|i| poly[i].clone())
.reduce(|a, b| (a + b))
.unwrap();
let eval_0: FieldElement<F> = (0..mle_half).map(|i| poly[i].clone()).sum();
// We evaluate the poly at each round and each random challenge at 0, 1 we can compute both of these evaluations by summing over the boolearn hypercube for 0, 1 at the fixed point
// An additional optimization is to sum over eval_0 and compute eval_1 = prev_round_claim - eval_0;
let evals = vec![eval_0.clone(), prev_round_claim - eval_0];
Expand Down Expand Up @@ -548,6 +572,9 @@ where
}

// Verifies a sumcheck proof returning the claimed evaluation and random points used during sumcheck rounds
/// Note: Verification does not execute the final check of sumcheck protocol: g_v(r_v) = oracle_g(r),
/// as the oracle is not passed in. Expected that the caller will implement.
///
pub fn verify(
proof: SumcheckProof<F>,
transcript: &mut impl Transcript,
Expand All @@ -556,7 +583,9 @@ where
let mut r: Vec<FieldElement<F>> = Vec::with_capacity(proof.poly.num_vars());

// verify there is a univariate polynomial for each round
// TODO: push this if up so that the proof struct enforces this invariant
if proof.round_uni_polys.len() != proof.poly.num_vars() {
return Err(SumcheckError::InvalidProof);
}

for poly in proof.round_uni_polys {
// Verify degree bound
Expand Down Expand Up @@ -587,7 +616,6 @@ mod test {
use lambdaworks_math::polynomial::dense_multilinear_poly::DenseMultilinearPolynomial;

type F = U64GoldilocksPrimeField;
type FE = FieldElement<U64GoldilocksPrimeField>;

pub fn index_to_field_bitvector<F: IsField>(value: usize, bits: usize) -> Vec<FieldElement<F>> {
let mut vec: Vec<FieldElement<F>> = Vec::with_capacity(bits);
Expand All @@ -603,6 +631,7 @@ mod test {
}

#[test]
#[ignore]
fn prove_cubic() {
// Create three dense polynomials (all the same)
let num_vars = 3;
Expand Down Expand Up @@ -661,15 +690,15 @@ mod test {
}

#[test]
#[ignore]
fn prove_cubic_batched() {}

#[test]
#[ignore]
fn prove_cubic_additive() {}

/*
#[test]
fn prove_quad() {
// Create three dense polynomials (all the same)
let num_vars = 3;
let num_evals = (2usize).pow(num_vars as u32);
let mut evals: Vec<FieldElement<F>> = Vec::with_capacity(num_evals);
Expand Down Expand Up @@ -708,66 +737,19 @@ mod test {
assert_eq!(challenges, r);

// Consider this the opening proof to a(r) * b(r)
let a = a.evaluate(challenges).unwrap();
let a = a.evaluate(challenges.clone()).unwrap();
let b = b.evaluate(challenges).unwrap();

let oracle_query = a * b;
assert_eq!(verify_evaluation, oracle_query);
}

#[test]
fn prove_quad_batched() {
// Create three dense polynomials (all the same)
let num_vars = 3;
let num_evals = (2usize).pow(num_vars as u32);
let mut evals: Vec<FieldElement<F>> = Vec::with_capacity(num_evals);
for i in 0..num_evals {
evals.push(FieldElement::from(8 + i as u64));
}
let mut a: Vec<DenseMultilinearPolynomial<F>> =
vec![DenseMultilinearPolynomial::new(evals.clone()); 3];
let mut b: Vec<DenseMultilinearPolynomial<F>> =
vec![DenseMultilinearPolynomial::new(evals.clone()); 3];
let mut claim = FieldElement::<F>::zero();
for i in 0..num_evals {
claim += a.evaluate(index_to_field_bitvector(i, num_vars)).unwrap()
* b.evaluate(index_to_field_bitvector(i, num_vars)).unwrap();
}
let comb_func_prod =
|a: &FieldElement<F>, b: &FieldElement<F>| -> FieldElement<F> { a * b };
let r = vec![
FieldElement::from(3),
FieldElement::from(1),
FieldElement::from(3),
]; // point 0,0,0 within the boolean hypercube
let mut transcript = DefaultTranscript::new();
let (proof, challenges) =
Sumcheck::<F>::prove_quadratic(&claim, &mut a, &mut b, comb_func_prod, &mut transcript);
let mut transcript = DefaultTranscript::new();
let verify_result = Sumcheck::verify(proof, &mut transcript);
assert!(verify_result.is_ok());
let (verify_evaluation, verify_randomness) = verify_result.unwrap();
assert_eq!(challenges, verify_randomness);
assert_eq!(challenges, r);
// Consider this the opening proof to a(r) * b(r)
let a = a.evaluate(challenges).unwrap();
let b = b.evaluate(challenges).unwrap();
let oracle_query = a * b;
assert_eq!(verify_evaluation, oracle_query);
}
*/
#[ignore]
fn prove_quad_batched() {}

#[test]
fn prove_single() {
// Create three dense polynomials (all the same)
let num_vars = 3;
let num_evals = (2usize).pow(num_vars as u32);
let mut evals: Vec<FieldElement<F>> = Vec::with_capacity(num_evals);
Expand Down

0 comments on commit 83619cb

Please sign in to comment.