diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index c00182bfd..8c90644fb 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -5,14 +5,6 @@ use typenum::{Unsigned, U1}; use crate::ff::{Field, PrimeField, Serializable}; -/// A degree `N-1` polynomial is stored as `N` points `(x,y)` -/// where the "x coordinates" of the input points are `x_0` to `x_N` are `F::ZERO` to `(N-1)*F::ONE` -/// Therefore, we only need to store the `y` coordinates. -#[derive(Debug, PartialEq, Clone)] -pub struct Polynomial { - y_coordinates: GenericArray, -} - /// The Canonical Lagrange denominator is defined as the denominator of the Lagrange base polynomials /// `https://en.wikipedia.org/wiki/Lagrange_polynomial` /// where the "x coordinates" of the input points are `x_0` to `x_N` are `F::ZERO` to `(N-1)*F::ONE` @@ -101,13 +93,13 @@ where { /// This function uses the `LagrangeTable` to evaluate `polynomial` on the specified output "x coordinates" /// outputs the "y coordinates" such that `(x,y)` lies on `polynomial` - pub fn eval(&self, polynomial: &Polynomial) -> GenericArray { + pub fn eval(&self, y_coordinates: &GenericArray) -> GenericArray { self.table .iter() .map(|table_row| { table_row .iter() - .zip(polynomial.y_coordinates.iter()) + .zip(y_coordinates.iter()) .fold(F::ZERO, |acc, (&base, &y)| acc + base * y) }) .collect() @@ -175,24 +167,30 @@ mod test { use typenum::{U1, U32, U7, U8}; use crate::{ - ff::Field, + ff::PrimeField, protocol::ipa_prf::malicious_security::lagrange::{ - CanonicalLagrangeDenominator, LagrangeTable, Polynomial, + CanonicalLagrangeDenominator, LagrangeTable, }, }; type TestField = crate::ff::Fp32BitPrime; #[derive(Debug, PartialEq, Clone)] - struct MonomialFormPolynomial { + struct MonomialFormPolynomial { coefficients: GenericArray, } impl MonomialFormPolynomial where - F: Field, + F: PrimeField, N: ArrayLength, { + fn gen_y_values_of_canonical_points(self) -> GenericArray { + let canonical_points: GenericArray = + GenericArray::generate(|i| F::try_from(u128::try_from(i).unwrap()).unwrap()); + self.eval(&canonical_points) + } + /// test helper function that evaluates a polynomial in monomial form, i.e. `sum_i c_i x^i` on points `x_output` /// where `c_0` to `c_N` are stored in `polynomial` fn eval(&self, x_output: &GenericArray) -> GenericArray @@ -216,21 +214,6 @@ mod test { } } - impl From> for Polynomial - where - F: Field + TryFrom, - >::Error: Debug, - N: ArrayLength, - { - fn from(value: MonomialFormPolynomial) -> Self { - let canonical_points: GenericArray = - GenericArray::generate(|i| F::try_from(u128::try_from(i).unwrap()).unwrap()); - Polynomial { - y_coordinates: value.eval(&canonical_points), - } - } - } - fn lagrange_single_output_point_using_new( output_point: TestField, input_points: [TestField; 32], @@ -241,11 +224,11 @@ mod test { let output_expected = polynomial_monomial_form.eval( &GenericArray::::from_array([output_point; 1]), ); - let polynomial = Polynomial::from(polynomial_monomial_form.clone()); let denominator = CanonicalLagrangeDenominator::::new(); // generate table using new let lagrange_table = LagrangeTable::::new(&denominator, &output_point); - let output = lagrange_table.eval(&polynomial); + let output = + lagrange_table.eval(&polynomial_monomial_form.gen_y_values_of_canonical_points()); assert_eq!(output, output_expected); } @@ -265,11 +248,11 @@ mod test { TestField::try_from(u128::try_from(i).unwrap() + 8).unwrap() }); let output_expected = polynomial_monomial_form.eval(&x_coordinates_output); - let polynomial = Polynomial::from(polynomial_monomial_form.clone()); let denominator = CanonicalLagrangeDenominator::::new(); // generate table using from let lagrange_table = LagrangeTable::::from(denominator); - let output = lagrange_table.eval(&polynomial); + let output = + lagrange_table.eval(&polynomial_monomial_form.gen_y_values_of_canonical_points()); assert_eq!(output, output_expected); } diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/mod.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/mod.rs index ea0ac6eef..0e7f6bf3a 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/mod.rs @@ -1 +1,2 @@ pub mod lagrange; +pub mod prover; diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs new file mode 100644 index 000000000..dd0263243 --- /dev/null +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs @@ -0,0 +1,89 @@ +use std::{ + iter::zip, + ops::{Add, Sub}, +}; + +use generic_array::{ArrayLength, GenericArray}; +use typenum::{Diff, Sum, U1}; + +use crate::{ + ff::PrimeField, + protocol::ipa_prf::malicious_security::lagrange::{ + CanonicalLagrangeDenominator, LagrangeTable, + }, +}; + +pub struct ProofGenerator { + u: Vec, + v: Vec, +} + +/// +/// Distributed Zero Knowledge Proofs algorithm drawn from +/// `https://eprint.iacr.org/2023/909.pdf` +/// +impl ProofGenerator +where + F: PrimeField, +{ + #![allow(non_camel_case_types)] + pub fn compute_proof<λ: ArrayLength>(self) -> GenericArray, U1>> + where + λ: ArrayLength + Add + Sub, + <λ as Add>::Output: Sub, + <<λ as Add>::Output as Sub>::Output: ArrayLength, + <λ as Sub>::Output: ArrayLength, + { + assert!(self.u.len() % λ::USIZE == 0); // We should pad with zeroes eventually + + let s = self.u.len() / λ::USIZE; + + let denominator = CanonicalLagrangeDenominator::::new(); + let lagrange_table = LagrangeTable::>::Output>::from(denominator); + let extrapolated_points = (0..s).map(|i| { + let p = (0..λ::USIZE).map(|j| self.u[i * λ::USIZE + j]).collect(); + let q = (0..λ::USIZE).map(|j| self.v[i * λ::USIZE + j]).collect(); + let p_extrapolated = lagrange_table.eval(&p); + let q_extrapolated = lagrange_table.eval(&q); + zip( + p.into_iter().chain(p_extrapolated), + q.into_iter().chain(q_extrapolated), + ) + .map(|(a, b)| a * b) + .collect::>() + }); + extrapolated_points + .reduce(|acc, pts| zip(acc, pts).map(|(a, b)| a + b).collect()) + .unwrap() + } +} + +#[cfg(all(test, unit_test))] +mod test { + use typenum::U4; + + use super::ProofGenerator; + use crate::ff::{Fp31, U128Conversions}; + + #[test] + fn sample_proof() { + const U: [u128; 32] = [ + 0, 0, 1, 15, 0, 0, 0, 15, 2, 30, 30, 16, 29, 1, 1, 15, 0, 0, 0, 15, 0, 0, 0, 15, 2, 30, + 30, 16, 0, 0, 1, 15, + ]; + const V: [u128; 32] = [ + 30, 30, 30, 30, 0, 1, 0, 1, 0, 0, 0, 30, 0, 30, 0, 30, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, + 30, 0, 0, 30, 30, + ]; + const EXPECTED: [u128; 7] = [0, 30, 29, 30, 3, 22, 6]; + let pg: ProofGenerator = ProofGenerator { + u: U.into_iter().map(|x| Fp31::try_from(x).unwrap()).collect(), + v: V.into_iter().map(|x| Fp31::try_from(x).unwrap()).collect(), + }; + let proof = pg.compute_proof::(); + assert_eq!( + proof.into_iter().map(|x| x.as_u128()).collect::>(), + EXPECTED + ); + } +}