Skip to content

Commit

Permalink
Next part of the ZKPs
Browse files Browse the repository at this point in the history
  • Loading branch information
benjaminsavage committed Mar 17, 2024
1 parent 661259c commit 92aaa8d
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 41 deletions.
67 changes: 48 additions & 19 deletions ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use std::fmt::Debug;
use std::{
borrow::Borrow,
fmt::Debug,
iter::{repeat, zip},
};

use generic_array::{ArrayLength, GenericArray};
use typenum::{Unsigned, U1};
Expand Down Expand Up @@ -91,16 +95,40 @@ where
N: ArrayLength,
M: ArrayLength,
{
pub fn print(&self) {
for table_row in &self.table {
println!("{:?}", table_row);
}
}

/// This function uses the `LagrangeTable` to evaluate `polynomial` on the specified output "x coordinates"
/// outputs the "y coordinates" such that `(x,y)` lies on `polynomial`
pub fn eval(&self, y_coordinates: &GenericArray<F, N>) -> GenericArray<F, M> {
pub fn eval<I, J>(&self, y_coordinates: I) -> GenericArray<F, M>
where
I: IntoIterator<Item = J> + Copy,
I::IntoIter: ExactSizeIterator,
J: Borrow<F>,
{
// let y_coordinates = y_coordinates.into_iter();
// debug_assert_eq!(y_coordinates.len(), N::USIZE);
// y_coordinates
// .enumerate()
// .map(|(i, y_coord)| {
// self.table
// .iter()
// .map(|table_row| table_row[i] * (*y_coord.borrow()))
// .collect::<GenericArray<F, _>>()
// })
// .reduce(|vec_a, vec_b| zip(vec_a, vec_b).map(|(a, b)| a + b).collect())
// .unwrap()

self.table
.iter()
.map(|table_row| {
table_row
.iter()
.zip(y_coordinates.iter())
.fold(F::ZERO, |acc, (&base, &y)| acc + base * y)
.zip(y_coordinates.into_iter())
.fold(F::ZERO, |acc, (&base, y)| acc + base * (*y.borrow()))
})
.collect()
}
Expand Down Expand Up @@ -160,7 +188,7 @@ where

#[cfg(all(test, unit_test))]
mod test {
use std::fmt::Debug;
use std::{borrow::Borrow, fmt::Debug};

use generic_array::{sequence::GenericSequence, ArrayLength, GenericArray};
use proptest::{prelude::*, proptest};
Expand All @@ -186,27 +214,31 @@ mod test {
N: ArrayLength,
{
fn gen_y_values_of_canonical_points(self) -> GenericArray<F, N> {
let canonical_points: GenericArray<F, N> =
GenericArray::generate(|i| F::try_from(u128::try_from(i).unwrap()).unwrap());
self.eval(&canonical_points)
// Sadly, we cannot just use the range (0..N::U128) because it does not implement ExactSizeIterator
let canonical_points =
(0..N::USIZE).map(|i| F::try_from(u128::try_from(i).unwrap()).unwrap());
self.eval(canonical_points)
}

/// test helper function that evaluates a polynomial in monomial form, i.e. `sum_i c_i x^i` on points `x_output`
/// where `c_0` to `c_N` are stored in `polynomial`
fn eval<M>(&self, x_output: &GenericArray<F, M>) -> GenericArray<F, M>
fn eval<M, I, J>(&self, x_output: I) -> GenericArray<F, M>
where
I: IntoIterator<Item = J>,
I::IntoIter: ExactSizeIterator,
J: Borrow<F>,
M: ArrayLength,
{
x_output
.iter()
.map(|&x| {
.into_iter()
.map(|x| {
// monomial base, i.e. `x^k`
// evaluate p via `sum_k coefficient_k * x^k`
let (_, y) = self
.coefficients
.iter()
.fold((F::ONE, F::ZERO), |(base, y), &coef| {
(base * x, y + coef * base)
(base * (*x.borrow()), y + coef * base)
});
y
})
Expand All @@ -221,9 +253,7 @@ mod test {
let polynomial_monomial_form = MonomialFormPolynomial {
coefficients: GenericArray::<TestField, U32>::from_array(input_points),
};
let output_expected = polynomial_monomial_form.eval(
&GenericArray::<TestField, U1>::from_array([output_point; 1]),
);
let output_expected = polynomial_monomial_form.eval(&[output_point]);
let denominator = CanonicalLagrangeDenominator::<TestField, U32>::new();
// generate table using new
let lagrange_table = LagrangeTable::<TestField, U32, U1>::new(&denominator, &output_point);
Expand All @@ -244,10 +274,9 @@ mod test {
coefficients: GenericArray::<TestField, U8>::from_array(input_points),
};
// the canonical x coordinates are 0..7, the outputs use coordinates 8..15:
let x_coordinates_output = GenericArray::<_, U7>::generate(|i| {
TestField::try_from(u128::try_from(i).unwrap() + 8).unwrap()
});
let output_expected = polynomial_monomial_form.eval(&x_coordinates_output);
let x_coordinates_output =
(0..7).map(|i| TestField::try_from(u128::try_from(i).unwrap() + 8).unwrap());
let output_expected = polynomial_monomial_form.eval(x_coordinates_output);
let denominator = CanonicalLagrangeDenominator::<TestField, U8>::new();
// generate table using from
let lagrange_table = LagrangeTable::<TestField, U8, U7>::from(denominator);
Expand Down
93 changes: 71 additions & 22 deletions ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ use crate::{
},
};

pub struct ZeroKnowledgeProof<F: PrimeField, N: ArrayLength> {
g: GenericArray<F, N>,
r: F,
}

pub struct ProofGenerator<F: PrimeField> {
u: Vec<F>,
v: Vec<F>,
Expand All @@ -27,7 +32,13 @@ where
F: PrimeField,
{
#![allow(non_camel_case_types)]
pub fn compute_proof<λ: ArrayLength>(self) -> GenericArray<F, Diff<Sum<λ, λ>, U1>>
pub fn compute_proof<λ: ArrayLength>(
self,
r: F,
) -> (
ZeroKnowledgeProof<F, Diff<Sum<λ, λ>, U1>>,
ProofGenerator<F>,
)
where
λ: ArrayLength + Add + Sub<U1>,
<λ as Add>::Output: Sub<U1>,
Expand All @@ -38,23 +49,42 @@ where

let s = self.u.len() / λ::USIZE;

if s <= 1 {
panic!("When the output is this small, you should call compute_final_proof");
}

let mut next_proof_generator = ProofGenerator {
u: Vec::<F>::with_capacity(s),
v: Vec::<F>::with_capacity(s),
};

let denominator = CanonicalLagrangeDenominator::<F, λ>::new();
let lagrange_table_r = LagrangeTable::<F, λ, U1>::new(&denominator, &r);
lagrange_table_r.print();
let lagrange_table = LagrangeTable::<F, λ, <λ as Sub<U1>>::Output>::from(denominator);
let extrapolated_points = (0..s).map(|i| {
let p = (0..λ::USIZE).map(|j| self.u[i * λ::USIZE + j]).collect();
let q = (0..λ::USIZE).map(|j| self.v[i * λ::USIZE + j]).collect();
let p_extrapolated = lagrange_table.eval(&p);
let q_extrapolated = lagrange_table.eval(&q);
zip(
p.into_iter().chain(p_extrapolated),
q.into_iter().chain(q_extrapolated),
)
.map(|(a, b)| a * b)
.collect::<GenericArray<F, _>>()
let start = i * λ::USIZE;
let end = start + λ::USIZE;
let p = &self.u[start..end];
let q = &self.v[start..end];
let p_extrapolated = lagrange_table.eval(p);
let q_extrapolated = lagrange_table.eval(q);
let p_r = lagrange_table_r.eval(p)[0];
let q_r = lagrange_table_r.eval(q)[0];
next_proof_generator.u.push(p_r);
next_proof_generator.v.push(q_r);
zip(p.into_iter(), q.into_iter())
.map(|(a, b)| *a * *b)
.chain(zip(p_extrapolated, q_extrapolated).map(|(a, b)| a * b))
.collect::<GenericArray<F, _>>()
});
extrapolated_points
.reduce(|acc, pts| zip(acc, pts).map(|(a, b)| a + b).collect())
.unwrap()
let proof = ZeroKnowledgeProof {
g: extrapolated_points
.reduce(|acc, pts| zip(acc, pts).map(|(a, b)| a + b).collect())
.unwrap(),
r,
};
(proof, next_proof_generator)
}
}

Expand All @@ -68,22 +98,41 @@ mod test {
#[test]
fn sample_proof() {
const U: [u128; 32] = [
0, 0, 1, 15, 0, 0, 0, 15, 2, 30, 30, 16, 29, 1, 1, 15, 0, 0, 0, 15, 0, 0, 0, 15, 2, 30,
30, 16, 0, 0, 1, 15,
0, 30, 0, 16, 0, 1, 0, 15, 0, 0, 0, 16, 0, 30, 0, 16, 29, 1, 1, 15, 0, 0, 1, 15, 2, 30,
30, 16, 0, 0, 30, 16,
];
const V: [u128; 32] = [
30, 30, 30, 30, 0, 1, 0, 1, 0, 0, 0, 30, 0, 30, 0, 30, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
30, 0, 0, 30, 30,
0, 0, 0, 30, 0, 0, 0, 1, 30, 30, 30, 30, 0, 0, 30, 30, 0, 30, 0, 30, 0, 0, 0, 1, 0, 0,
1, 1, 0, 0, 1, 1,
];
const EXPECTED: [u128; 7] = [0, 30, 29, 30, 3, 22, 6];
const EXPECTED: [u128; 7] = [0, 30, 29, 30, 5, 28, 13];
const R1: u128 = 22;
const EXPECTED_NEXT_U: [u128; 8] = [0, 0, 26, 0, 7, 18, 24, 13];
const EXPECTED_NEXT_V: [u128; 8] = [10, 21, 30, 28, 15, 21, 3, 3];
let pg: ProofGenerator<Fp31> = ProofGenerator {
u: U.into_iter().map(|x| Fp31::try_from(x).unwrap()).collect(),
v: V.into_iter().map(|x| Fp31::try_from(x).unwrap()).collect(),
};
let proof = pg.compute_proof::<U4>();
let (proof, next_proof_generator) = pg.compute_proof::<U4>(Fp31::try_from(R1).unwrap());
assert_eq!(
proof.g.into_iter().map(|x| x.as_u128()).collect::<Vec<_>>(),
EXPECTED,
);
assert_eq!(
next_proof_generator
.u
.into_iter()
.map(|x| x.as_u128())
.collect::<Vec<_>>(),
EXPECTED_NEXT_U,
);
assert_eq!(
proof.into_iter().map(|x| x.as_u128()).collect::<Vec<_>>(),
EXPECTED
next_proof_generator
.v
.into_iter()
.map(|x| x.as_u128())
.collect::<Vec<_>>(),
EXPECTED_NEXT_V,
);
}
}

0 comments on commit 92aaa8d

Please sign in to comment.