Skip to content

Commit

Permalink
append polynomial
Browse files Browse the repository at this point in the history
  • Loading branch information
PatStiles committed Jan 30, 2024
1 parent 69083d1 commit 61e5bd8
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 43 deletions.
93 changes: 52 additions & 41 deletions crypto/src/subprotocols/sumcheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use core::fmt::Display;
use std::marker::PhantomData;

use crate::fiat_shamir::transcript::Transcript;
use lambdaworks_math::traits::AsBytes;
use lambdaworks_math::field::element::FieldElement;
use lambdaworks_math::field::traits::{IsField, IsPrimeField};
use lambdaworks_math::polynomial::{
Expand All @@ -25,7 +26,7 @@ where
let iter = 0..len;
#[cfg(feature = "parallel")]
let iter = (0..len).into_par_iter();
iter.map(|i| {
let res = iter.map(|i| {
// eval_0: A(low)
let eval_0 = comb_func(&poly_a[i], &poly_b[i]);

Expand All @@ -34,11 +35,17 @@ where
let poly_b_eval_2 = &poly_b[len + i] + &poly_b[len + i] - &poly_b[i];
let eval_2 = comb_func(&poly_a_eval_2, &poly_b_eval_2);
(eval_0, eval_2)
})
.reduce(
|| (FieldElement::<F>::zero(), FieldElement::<F>::zero()),
});
#[cfg(not(feature = "parallel"))]
let res = res.fold((FieldElement::zero(), FieldElement::zero()),
|a, b| (a.0 + b.0, a.1 + b.1));
#[cfg(feature = "parallel")]
let res = res.reduce(
|| (FieldElement::zero(), FieldElement::zero()),
|a, b| (a.0 + b.0, a.1 + b.1),
)
);

res
}

fn eval_points_cubic<F: IsField, E>(
Expand All @@ -56,7 +63,7 @@ where
let iter = 0..len;
#[cfg(feature = "parallel")]
let iter = (0..len).into_par_iter();
iter.map(|i| {
let res = iter.map(|i| {
// eval_0: A(low)
let eval_0 = comb_func(&poly_a[i], &poly_b[i], &poly_c[i]);

Expand All @@ -73,17 +80,17 @@ where
let eval_3 = comb_func(&poly_a_eval_3, &poly_b_eval_3, &poly_c_eval_3);

(eval_0, eval_2, eval_3)
})
.reduce(
|| {
(
FieldElement::<F>::zero(),
FieldElement::<F>::zero(),
FieldElement::<F>::zero(),
)
},
});
#[cfg(not(feature = "parallel"))]
let res = res.fold((FieldElement::zero(), FieldElement::zero(), FieldElement::zero()),
|a, b| (a.0 + b.0, a.1 + b.1, a.2 + b.2));
#[cfg(feature = "parallel")]
let res = res.reduce(
|| (FieldElement::zero(), FieldElement::zero(), FieldElement::zero()),
|a, b| (a.0 + b.0, a.1 + b.1, a.2 + b.2),
)
);

res
}

#[derive(Debug)]
Expand Down Expand Up @@ -124,7 +131,7 @@ where
impl<F: IsField + IsPrimeField> Sumcheck<F>
where
<F as IsField>::BaseType: Send + Sync,
FieldElement<F>: ByteConversion,
FieldElement<F>: ByteConversion + AsBytes,
{
//Used for sum_{(a * b)}
pub fn prove_quadratic<E>(
Expand Down Expand Up @@ -343,22 +350,24 @@ where
evals_combined_2,
evals_combined_3,
];
let poly = Polynomial::new(&evals);
let round_poly = Polynomial::new(&evals);

// TODO append the prover's message to the transcript
// TODO: Check if order matters
transcript.append(&round_poly.as_bytes());

// Squeeze Verifier Challenge for next round
let challenge = FieldElement::from_bytes_be(&transcript.challenge()).unwrap();
challenges.push(challenge.clone());

// TODO: rayon::join and gate
// bound all tables to the verifier's challenege
for (poly_a, poly_b) in poly_a.iter_mut().zip(poly_b.iter_mut()) {
poly_a.fix_variable(&challenge);
poly_b.fix_variable(&challenge);
}

prev_round_claim = poly.evaluate(&challenge);
round_uni_polys.push(poly);
prev_round_claim = round_poly.evaluate(&challenge);
round_uni_polys.push(round_poly);
}

(
Expand Down Expand Up @@ -397,14 +406,15 @@ where
let mut prev_round_claim = sum.clone();

for _ in 0..poly_a.num_vars() {
let poly = {
let round_poly = {
let (eval_point_0, eval_point_2, eval_point_3) = {
//TODO: remove this dedup if possible
let len = poly_a.len() / 2;
#[cfg(not(feature = "parallel"))]
let iter = 0..len;
#[cfg(feature = "parallel")]
let iter = (0..len).into_par_iter();
iter.map(|i| {
let res = iter.map(|i| {
// eval 0: bound_func is A(low)
let eval_point_0 =
comb_func(&poly_a[i], &poly_b[i], &poly_c[i], &poly_d[i]);
Expand Down Expand Up @@ -433,17 +443,16 @@ where
&poly_d_point_3,
);
(eval_point_0, eval_point_2, eval_point_3)
})
.reduce(
|| {
(
FieldElement::zero(),
FieldElement::zero(),
FieldElement::zero(),
)
},
});
#[cfg(not(feature = "parallel"))]
let res = res.fold((FieldElement::zero(), FieldElement::zero(), FieldElement::zero()),
|a, b| (a.0 + b.0, a.1 + b.1, a.2 + b.2));
#[cfg(feature = "parallel")]
let res = res.reduce(
|| (FieldElement::zero(), FieldElement::zero(), FieldElement::zero()),
|a, b| (a.0 + b.0, a.1 + b.1, a.2 + b.2),
)
);
res
};
let evals = vec![
eval_point_0.clone(),
Expand All @@ -454,15 +463,17 @@ where
Polynomial::new(&evals)
};

// TODO append the prover's message to the transcript
// TODO: Does it matter that its before the challenge???? -> Should be I believe
transcript.append(&round_poly.as_bytes());

// Squeeze Verifier Challenge for next round
let challenge = FieldElement::from_bytes_be(&transcript.challenge()).unwrap();
challenges.push(challenge.clone());

prev_round_claim = poly.evaluate(&challenge);
round_uni_polys.push(poly);
prev_round_claim = round_poly.evaluate(&challenge);
round_uni_polys.push(round_poly);

// TODO: rayon::join and gate
// bound all tables to the verifier's challenege
poly_a.fix_variable(&challenge);
poly_b.fix_variable(&challenge);
Expand Down Expand Up @@ -509,6 +520,7 @@ where
};

// TODO: Append poly to transcript -> Modify Transcript
transcript.append(&round_poly.as_bytes());

let challenge = FieldElement::from_bytes_be(&transcript.challenge()).unwrap();
challenges.push(challenge.clone());
Expand Down Expand Up @@ -554,7 +566,7 @@ where
{
return Err(SumcheckError::InvalidProof);
}
//transcript.append(poly);
transcript.append(&poly.as_bytes());

let challenge = FieldElement::from_bytes_be(&transcript.challenge()).unwrap();
r.push(challenge.clone());
Expand Down Expand Up @@ -640,10 +652,9 @@ mod test {
assert_eq!(challenges, verify_randomness);
assert_eq!(challenges, r);

// Consider this the opening proof to a(r) * b(r) * c(r)
let a = a.evaluate(challenges).unwrap();
let b = b.evaluate(challenges).unwrap();
let c = c.evaluate(challenges).unwrap();
let a = a.evaluate(challenges.clone()).unwrap();
let b = b.evaluate(challenges.clone()).unwrap();
let c = c.evaluate(challenges.clone()).unwrap();

let oracle_query = a * b * c;
assert_eq!(verify_evaluation, oracle_query);
Expand Down
21 changes: 19 additions & 2 deletions math/src/polynomial/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use super::field::element::FieldElement;
use crate::field::traits::{IsField, IsSubFieldOf};
use crate::{
traits::AsBytes,
field::element::FieldElement,
field::traits::{IsField, IsSubFieldOf}
};
use alloc::{borrow::ToOwned, vec, vec::Vec};
use core::{fmt::Display, ops};

Expand Down Expand Up @@ -804,6 +807,20 @@ impl Display for InterpolateError {
#[cfg(feature = "std")]
impl std::error::Error for InterpolateError {}

impl<F: IsField> AsBytes for Polynomial<FieldElement<F>>
where
FieldElement<F>: AsBytes
{
fn as_bytes(&self) -> Vec<u8> {
self.coefficients().into_iter().fold(Vec::new(), |mut acc, coeff| {
acc.extend_from_slice(&coeff.as_bytes());
acc
}
)
}

}

#[cfg(test)]
mod tests {
use crate::field::fields::u64_prime_field::U64PrimeField;
Expand Down

0 comments on commit 61e5bd8

Please sign in to comment.