From 92aaa8d1813536ea6b48116cecb166e8f128c9b0 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Sun, 17 Mar 2024 13:43:54 +1000 Subject: [PATCH] Next part of the ZKPs --- .../ipa_prf/malicious_security/lagrange.rs | 67 +++++++++---- .../ipa_prf/malicious_security/prover.rs | 93 ++++++++++++++----- 2 files changed, 119 insertions(+), 41 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index 8c90644fb..b1a4396fe 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -1,4 +1,8 @@ -use std::fmt::Debug; +use std::{ + borrow::Borrow, + fmt::Debug, + iter::{repeat, zip}, +}; use generic_array::{ArrayLength, GenericArray}; use typenum::{Unsigned, U1}; @@ -91,16 +95,40 @@ where N: ArrayLength, M: ArrayLength, { + pub fn print(&self) { + for table_row in &self.table { + println!("{:?}", table_row); + } + } + /// This function uses the `LagrangeTable` to evaluate `polynomial` on the specified output "x coordinates" /// outputs the "y coordinates" such that `(x,y)` lies on `polynomial` - pub fn eval(&self, y_coordinates: &GenericArray) -> GenericArray { + pub fn eval(&self, y_coordinates: I) -> GenericArray + where + I: IntoIterator + Copy, + I::IntoIter: ExactSizeIterator, + J: Borrow, + { + // let y_coordinates = y_coordinates.into_iter(); + // debug_assert_eq!(y_coordinates.len(), N::USIZE); + // y_coordinates + // .enumerate() + // .map(|(i, y_coord)| { + // self.table + // .iter() + // .map(|table_row| table_row[i] * (*y_coord.borrow())) + // .collect::>() + // }) + // .reduce(|vec_a, vec_b| zip(vec_a, vec_b).map(|(a, b)| a + b).collect()) + // .unwrap() + self.table .iter() .map(|table_row| { table_row .iter() - .zip(y_coordinates.iter()) - .fold(F::ZERO, |acc, (&base, &y)| acc + base * y) + .zip(y_coordinates.into_iter()) + .fold(F::ZERO, |acc, (&base, y)| acc + base * (*y.borrow())) }) .collect() } @@ -160,7 +188,7 @@ where #[cfg(all(test, unit_test))] mod test { - use std::fmt::Debug; + use std::{borrow::Borrow, fmt::Debug}; use generic_array::{sequence::GenericSequence, ArrayLength, GenericArray}; use proptest::{prelude::*, proptest}; @@ -186,27 +214,31 @@ mod test { N: ArrayLength, { fn gen_y_values_of_canonical_points(self) -> GenericArray { - let canonical_points: GenericArray = - GenericArray::generate(|i| F::try_from(u128::try_from(i).unwrap()).unwrap()); - self.eval(&canonical_points) + // Sadly, we cannot just use the range (0..N::U128) because it does not implement ExactSizeIterator + let canonical_points = + (0..N::USIZE).map(|i| F::try_from(u128::try_from(i).unwrap()).unwrap()); + self.eval(canonical_points) } /// test helper function that evaluates a polynomial in monomial form, i.e. `sum_i c_i x^i` on points `x_output` /// where `c_0` to `c_N` are stored in `polynomial` - fn eval(&self, x_output: &GenericArray) -> GenericArray + fn eval(&self, x_output: I) -> GenericArray where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + J: Borrow, M: ArrayLength, { x_output - .iter() - .map(|&x| { + .into_iter() + .map(|x| { // monomial base, i.e. `x^k` // evaluate p via `sum_k coefficient_k * x^k` let (_, y) = self .coefficients .iter() .fold((F::ONE, F::ZERO), |(base, y), &coef| { - (base * x, y + coef * base) + (base * (*x.borrow()), y + coef * base) }); y }) @@ -221,9 +253,7 @@ mod test { let polynomial_monomial_form = MonomialFormPolynomial { coefficients: GenericArray::::from_array(input_points), }; - let output_expected = polynomial_monomial_form.eval( - &GenericArray::::from_array([output_point; 1]), - ); + let output_expected = polynomial_monomial_form.eval(&[output_point]); let denominator = CanonicalLagrangeDenominator::::new(); // generate table using new let lagrange_table = LagrangeTable::::new(&denominator, &output_point); @@ -244,10 +274,9 @@ mod test { coefficients: GenericArray::::from_array(input_points), }; // the canonical x coordinates are 0..7, the outputs use coordinates 8..15: - let x_coordinates_output = GenericArray::<_, U7>::generate(|i| { - TestField::try_from(u128::try_from(i).unwrap() + 8).unwrap() - }); - let output_expected = polynomial_monomial_form.eval(&x_coordinates_output); + let x_coordinates_output = + (0..7).map(|i| TestField::try_from(u128::try_from(i).unwrap() + 8).unwrap()); + let output_expected = polynomial_monomial_form.eval(x_coordinates_output); let denominator = CanonicalLagrangeDenominator::::new(); // generate table using from let lagrange_table = LagrangeTable::::from(denominator); diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs index dd0263243..cef5665d6 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs @@ -13,6 +13,11 @@ use crate::{ }, }; +pub struct ZeroKnowledgeProof { + g: GenericArray, + r: F, +} + pub struct ProofGenerator { u: Vec, v: Vec, @@ -27,7 +32,13 @@ where F: PrimeField, { #![allow(non_camel_case_types)] - pub fn compute_proof<λ: ArrayLength>(self) -> GenericArray, U1>> + pub fn compute_proof<λ: ArrayLength>( + self, + r: F, + ) -> ( + ZeroKnowledgeProof, U1>>, + ProofGenerator, + ) where λ: ArrayLength + Add + Sub, <λ as Add>::Output: Sub, @@ -38,23 +49,42 @@ where let s = self.u.len() / λ::USIZE; + if s <= 1 { + panic!("When the output is this small, you should call compute_final_proof"); + } + + let mut next_proof_generator = ProofGenerator { + u: Vec::::with_capacity(s), + v: Vec::::with_capacity(s), + }; + let denominator = CanonicalLagrangeDenominator::::new(); + let lagrange_table_r = LagrangeTable::::new(&denominator, &r); + lagrange_table_r.print(); let lagrange_table = LagrangeTable::>::Output>::from(denominator); let extrapolated_points = (0..s).map(|i| { - let p = (0..λ::USIZE).map(|j| self.u[i * λ::USIZE + j]).collect(); - let q = (0..λ::USIZE).map(|j| self.v[i * λ::USIZE + j]).collect(); - let p_extrapolated = lagrange_table.eval(&p); - let q_extrapolated = lagrange_table.eval(&q); - zip( - p.into_iter().chain(p_extrapolated), - q.into_iter().chain(q_extrapolated), - ) - .map(|(a, b)| a * b) - .collect::>() + let start = i * λ::USIZE; + let end = start + λ::USIZE; + let p = &self.u[start..end]; + let q = &self.v[start..end]; + let p_extrapolated = lagrange_table.eval(p); + let q_extrapolated = lagrange_table.eval(q); + let p_r = lagrange_table_r.eval(p)[0]; + let q_r = lagrange_table_r.eval(q)[0]; + next_proof_generator.u.push(p_r); + next_proof_generator.v.push(q_r); + zip(p.into_iter(), q.into_iter()) + .map(|(a, b)| *a * *b) + .chain(zip(p_extrapolated, q_extrapolated).map(|(a, b)| a * b)) + .collect::>() }); - extrapolated_points - .reduce(|acc, pts| zip(acc, pts).map(|(a, b)| a + b).collect()) - .unwrap() + let proof = ZeroKnowledgeProof { + g: extrapolated_points + .reduce(|acc, pts| zip(acc, pts).map(|(a, b)| a + b).collect()) + .unwrap(), + r, + }; + (proof, next_proof_generator) } } @@ -68,22 +98,41 @@ mod test { #[test] fn sample_proof() { const U: [u128; 32] = [ - 0, 0, 1, 15, 0, 0, 0, 15, 2, 30, 30, 16, 29, 1, 1, 15, 0, 0, 0, 15, 0, 0, 0, 15, 2, 30, - 30, 16, 0, 0, 1, 15, + 0, 30, 0, 16, 0, 1, 0, 15, 0, 0, 0, 16, 0, 30, 0, 16, 29, 1, 1, 15, 0, 0, 1, 15, 2, 30, + 30, 16, 0, 0, 30, 16, ]; const V: [u128; 32] = [ - 30, 30, 30, 30, 0, 1, 0, 1, 0, 0, 0, 30, 0, 30, 0, 30, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, - 30, 0, 0, 30, 30, + 0, 0, 0, 30, 0, 0, 0, 1, 30, 30, 30, 30, 0, 0, 30, 30, 0, 30, 0, 30, 0, 0, 0, 1, 0, 0, + 1, 1, 0, 0, 1, 1, ]; - const EXPECTED: [u128; 7] = [0, 30, 29, 30, 3, 22, 6]; + const EXPECTED: [u128; 7] = [0, 30, 29, 30, 5, 28, 13]; + const R1: u128 = 22; + const EXPECTED_NEXT_U: [u128; 8] = [0, 0, 26, 0, 7, 18, 24, 13]; + const EXPECTED_NEXT_V: [u128; 8] = [10, 21, 30, 28, 15, 21, 3, 3]; let pg: ProofGenerator = ProofGenerator { u: U.into_iter().map(|x| Fp31::try_from(x).unwrap()).collect(), v: V.into_iter().map(|x| Fp31::try_from(x).unwrap()).collect(), }; - let proof = pg.compute_proof::(); + let (proof, next_proof_generator) = pg.compute_proof::(Fp31::try_from(R1).unwrap()); + assert_eq!( + proof.g.into_iter().map(|x| x.as_u128()).collect::>(), + EXPECTED, + ); + assert_eq!( + next_proof_generator + .u + .into_iter() + .map(|x| x.as_u128()) + .collect::>(), + EXPECTED_NEXT_U, + ); assert_eq!( - proof.into_iter().map(|x| x.as_u128()).collect::>(), - EXPECTED + next_proof_generator + .v + .into_iter() + .map(|x| x.as_u128()) + .collect::>(), + EXPECTED_NEXT_V, ); } }