Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Precomputed Lagrange #957

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion ipa-core/src/ff/prime_field.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::fmt::Display;
use std::{fmt::Display, mem};

use generic_array::GenericArray;

Expand All @@ -14,6 +14,42 @@ pub trait PrimeField: Field {
type PrimeInteger: Into<u128>;

const PRIME: Self::PrimeInteger;

/// Invert function that returns the multiplicative inverse
/// the default implementation uses the extended Euclidean algorithm,
/// follows inversion algorithm in
/// (with the modification that it works for unsigned integers by keeping track of `sign`):
/// `https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm`
///
/// The function operates on `u128` rather than field elements since we need divisions
///
/// ## Panics
/// When `self` is `Zero`
#[must_use]
fn invert(&self) -> Self {
assert_ne!(*self, Self::ZERO);

let mut t = 0u128;
let mut newt = 1u128;
let mut r = Self::PRIME.into();
let mut newr = self.as_u128();
let mut sign = 1u128;

while newr != 0 {
let quotient = r / newr;
mem::swap(&mut t, &mut newt);
mem::swap(&mut r, &mut newr);
newt += quotient * t;
newr -= quotient * r;

// flip sign
sign = 1 - sign;
}

// when sign is negative, output `PRIME-t` otherwise `t`
// unwrap is safe
Self::try_from((1 - sign) * t + sign * (Self::PRIME.into() - t)).unwrap()
}
}

#[derive(thiserror::Error, Debug)]
Expand Down Expand Up @@ -294,6 +330,14 @@ macro_rules! field_impl {
let err = $field::deserialize(&buf).unwrap_err();
assert!(matches!(err, GreaterThanPrimeError(..)))
}

#[test]
fn invert(element: $field) {
if element != $field::ZERO
{
assert_eq!($field::ONE,element * element.invert() );
}
}
}
}

Expand Down
282 changes: 282 additions & 0 deletions ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
use std::fmt::Debug;

use generic_array::{sequence::GenericSequence, ArrayLength, GenericArray};
use typenum::{Unsigned, U1};

use crate::ff::{Field, PrimeField, Serializable};

/// A degree `N-1` polynomial is stored as `N` points `(x,y)`
/// where the "x coordinates" of the input points are `x_0` to `x_N` are `F::ZERO` to `(N-1)*F::ONE`
/// Therefore, we only need to store the `y` coordinates.
#[derive(Debug, PartialEq, Clone)]

Check warning on line 11 in ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs#L11

Added line #L11 was not covered by tests
pub struct Polynomial<F: Field, N: ArrayLength> {
y_coordinates: GenericArray<F, N>,
}

/// The Canonical Lagrange denominator is defined as the denominator of the Lagrange base polynomials
/// `https://en.wikipedia.org/wiki/Lagrange_polynomial`
/// where the "x coordinates" of the input points are `x_0` to `x_N` are `F::ZERO` to `(N-1)*F::ONE`
/// the degree of the polynomials is `N-1`
pub struct CanonicalLagrangeDenominator<F: Field, N: ArrayLength> {
benjaminsavage marked this conversation as resolved.
Show resolved Hide resolved
denominator: GenericArray<F, N>,
}

impl<F, N> CanonicalLagrangeDenominator<F, N>
where
F: PrimeField + TryFrom<u128>,
<F as TryFrom<u128>>::Error: Debug,
N: ArrayLength,
{
/// generates canonical Lagrange denominators
///
/// ## Panics
/// When the field size is too small for `N` evaluation points
pub fn new() -> Self {
// assertion that field is large enough
// when it is large enough, `F::try_from().unwrap()` below does not panic
assert!(
F::BITS > usize::BITS - N::USIZE.leading_zeros(),
"Field size {} is not large enough to hold {} points",

Check warning on line 39 in ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs#L39

Added line #L39 was not covered by tests
F::BITS,
N::USIZE
);

// assertion that table is not too large for the stack
assert!(<F as Serializable>::Size::USIZE * N::USIZE < 2024);

let mut denominator = GenericArray::generate(|_| F::ONE);
for (d, i) in denominator.iter_mut().zip(0u64..) {
for j in (0..N::U64).filter(|&j| i != j) {
*d *= F::try_from(u128::from(i)).unwrap() - F::try_from(u128::from(j)).unwrap();
}
*d = d.invert();
benjaminsavage marked this conversation as resolved.
Show resolved Hide resolved
}
Self { denominator }
}
}

/// `LagrangeTable` is a precomputed table for the Lagrange evaluation.
/// Allows to compute points on the polynomial, i.e. output points,
/// given enough points on the polynomial, i.e. input points,
/// by using the `eval` function.
/// The "x coordinates" are implicit.
/// The "y coordinates" of the input points are inputs to `eval`.
/// The output of `eval` are the "y coordinates" of the output points .
/// The "x coordinates" of the input points `x_0` to `x_(N-1)` are `F::ZERO` to `(N-1)*F::ONE`.
/// The `LagrangeTable` also specifies `M` "x coordinates" for the output points.
/// The "x coordinates" of the output points `x_N` to `x_(N+M-1)` are `N*F::ONE` to `(N+M-1)*F::ONE`
/// when generated using `from(denominator)`
/// unless generated using `new(denominator, x_output)` for a specific output "x coordinate" `x_output`.
pub struct LagrangeTable<F: Field, N: ArrayLength, M: ArrayLength> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am surprised that it declares that it works over fields and not prime fields only.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its more general this way. If you want to load something precomputed for a field for which we dont have invert implemented, e.g. from a file, you could. Invert is only needed to generate the denominator, otherwise we don't need inversion.

table: GenericArray<GenericArray<F, N>, M>,
}

impl<F, N> LagrangeTable<F, N, U1>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need this specialization?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can make it more general, however we won't need it. M is the amount of output x-coordinates and this implementation allows to generate a Lagrange table for a single output point using new(x_output, denominator). In the proofs, we need to evaluate the polynomials on a single x-coordinate r which is random. Otherwise we always evaluate polynomials at fixed x-coordinates N...N+M in which case we can just use the implementation of From(denominator) below

where
F: Field + TryFrom<u128>,
<F as TryFrom<u128>>::Error: Debug,
N: ArrayLength,
{
/// generates a `CanonicalLagrangeTable` from `CanoncialLagrangeDenominators` for a single output point
/// The "x coordinate" of the output point is `x_output`.
pub fn new(denominator: CanonicalLagrangeDenominator<F, N>, x_output: &F) -> Self {
// assertion that table is not too large for the stack
assert!(<F as Serializable>::Size::USIZE * N::USIZE < 2024);

let mut table = denominator.denominator;
Self::compute_table_row(x_output, &mut table);
LagrangeTable::<F, N, U1> {
table: GenericArray::from_array([table; 1]),
}
}
}

impl<F, N, M> LagrangeTable<F, N, M>
where
F: Field,
N: ArrayLength,
M: ArrayLength,
{
/// This function uses the `LagrangeTable` to evaluate `polynomial` on the specified output "x coordinates"
/// outputs the "y coordinates" such that `(x,y)` lies on `polynomial`
pub fn eval(&self, polynomial: &Polynomial<F, N>) -> GenericArray<F, M> {
let mut result = GenericArray::generate(|_| F::ONE);
self.mult_result_by_evaluation(polynomial, &mut result);
result
}

/// This function uses the `LagrangeTable` to evaluate `polynomial` on the specified output "x coordinates"
/// the "y coordinates" of the evaluation are multiplied to `result`
pub fn mult_result_by_evaluation(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be public? It looks like it assumes a particular form for result that a caller isn't obligated to guarantee.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some parts in the DZKP (within the provers computation) where I want to compute q(x')*p(x') for some x coordinate x' so it is convenient do use this function and pass q(x') as result which is mutated to q(x')*p(x'). There is the more general function eval for the general case.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer not to embed this logic into the Lagrange module. I understand it might even be slightly more efficient this way - but it feels like we are violating an important separation of concerns.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am fine with removing it since you seem to prefer it that way.

&self,
polynomial: &Polynomial<F, N>,
result: &mut GenericArray<F, M>,
) {
for (y, base) in result.iter_mut().zip(self.table.iter()) {
*y *= base
.iter()
.zip(polynomial.y_coordinates.iter())
.fold(F::ZERO, |acc, (&base, &y)| acc + base * y);
}
}

/// helper function to compute a single row of `LagrangeTable`
///
/// ## Panics
/// When the field size is too small for `N` evaluation points
fn compute_table_row(x_output: &F, table_row: &mut GenericArray<F, N>)
where
F: Field + TryFrom<u128>,
<F as TryFrom<u128>>::Error: Debug,
N: ArrayLength,
{
for (entry, i) in table_row.iter_mut().zip(0u64..) {
for j in (0..N::U64).filter(|&j| j != i) {
*entry *= *x_output - F::try_from(u128::from(j)).unwrap();
}
}
}
}

impl<F, N, M> From<CanonicalLagrangeDenominator<F, N>> for LagrangeTable<F, N, M>
where
F: Field + TryFrom<u128>,
<F as TryFrom<u128>>::Error: Debug,
N: ArrayLength,
M: ArrayLength,
{
fn from(value: CanonicalLagrangeDenominator<F, N>) -> Self {
// assertion that field is large enough
// when it is large enough, `F::try_from().unwrap()` below does not panic
assert!(
F::BITS > usize::BITS - (N::USIZE + M::USIZE).leading_zeros(),
"Field size {} is not large enough to hold {} + {} points",

Check warning on line 153 in ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs#L153

Added line #L153 was not covered by tests
F::BITS,
N::USIZE,
M::USIZE
);

// assertion that table is not too large for the stack
assert!(<F as Serializable>::Size::USIZE * N::USIZE * M::USIZE < 2024);

let mut table = GenericArray::generate(|_| value.denominator.clone());
table.iter_mut().zip(0u64..).for_each(|(row, i)| {
Self::compute_table_row(&F::try_from(u128::from(i + N::U64)).unwrap(), row);
});
LagrangeTable { table }
}
}

#[cfg(all(test, unit_test))]
mod test {
use std::fmt::Debug;

use generic_array::{sequence::GenericSequence, ArrayLength, GenericArray};
use proptest::{prelude::*, proptest};
use typenum::{U1, U32, U7, U8};

use crate::{
ff::Field,
protocol::ipa_prf::malicious_security::lagrange::{
CanonicalLagrangeDenominator, LagrangeTable, Polynomial,
},
};

type TestField = crate::ff::Fp32BitPrime;

#[derive(Debug, PartialEq, Clone)]
struct MonomialFormPolynomial<F: Field, N: ArrayLength> {
coefficients: GenericArray<F, N>,
}

impl<F, N> MonomialFormPolynomial<F, N>
where
F: Field,
N: ArrayLength,
{
/// test helper function that evaluates a polynomial in monomial form, i.e. `sum_i c_i x^i` on points `x_output`
/// where `c_0` to `c_N` are stored in `polynomial`
fn eval<M>(&self, x_output: &GenericArray<F, M>) -> GenericArray<F, M>
where
M: ArrayLength,
{
// evaluate polynomial p at evaluation_points and random point using monomial base
let mut y_values = GenericArray::generate(|_| F::ZERO);
for (x, y) in x_output.iter().zip(y_values.iter_mut()) {
// monomial base, i.e. `x^k`
let mut base = F::ONE;
// evaluate p via `sum_k coefficient_k * x^k`
for coefficient in &self.coefficients {
*y += *coefficient * base;
base *= *x;
}
}
y_values
}
}

impl<F, N> From<MonomialFormPolynomial<F, N>> for Polynomial<F, N>
where
F: Field + TryFrom<u128>,
<F as TryFrom<u128>>::Error: Debug,
N: ArrayLength,
{
fn from(value: MonomialFormPolynomial<F, N>) -> Self {
let canonical_points: GenericArray<F, N> =
GenericArray::generate(|i| F::try_from(u128::try_from(i).unwrap()).unwrap());
Polynomial {
y_coordinates: value.eval(&canonical_points),
}
}
}

fn lagrange_single_output_point_using_new(
output_point: TestField,
input_points: [TestField; 32],
) {
let polynomial_monomial_form = MonomialFormPolynomial {
coefficients: GenericArray::<TestField, U32>::from_array(input_points),
};
let output_expected = polynomial_monomial_form.eval(
&GenericArray::<TestField, U1>::from_array([output_point; 1]),
);
let polynomial = Polynomial::from(polynomial_monomial_form.clone());
let denominator = CanonicalLagrangeDenominator::<TestField, U32>::new();
// generate table using new
let lagrange_table = LagrangeTable::<TestField, U32, U1>::new(denominator, &output_point);
let output = lagrange_table.eval(&polynomial);
assert_eq!(output, output_expected);
}

proptest! {
danielmasny marked this conversation as resolved.
Show resolved Hide resolved
#[test]
fn proptest_lagrange_single_output_point_using_new(output_point: TestField, input_points in prop::array::uniform32(any::<TestField>())){
lagrange_single_output_point_using_new(output_point,input_points);
}
}
Comment on lines +255 to +256
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what is going on here... but it looks wrong....

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just use the proptest macro here. I added a function since the formatting is not checked within macro calls.


fn lagrange_canonical_using_from(input_points: [TestField; 8]) {
let polynomial_monomial_form = MonomialFormPolynomial {
coefficients: GenericArray::<TestField, U8>::from_array(input_points),
};
// the canonical x coordinates are 0..15, the outputs use coordinates 8..15:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is an error in this comment. I think @danielmasny means:
the canonical x coordinates are 0..7, the outputs use coordinates 8..15

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

let x_coordinates_output = GenericArray::<_, U7>::generate(|i| {
TestField::try_from(u128::try_from(i).unwrap() + 8).unwrap()
});
Comment on lines +263 to +265
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly, I find the generics here(e.g. U7, U8) really confusing. I wish there was some better way.

let output_expected = polynomial_monomial_form.eval(&x_coordinates_output);
let polynomial = Polynomial::from(polynomial_monomial_form.clone());
let denominator = CanonicalLagrangeDenominator::<TestField, U8>::new();
// generate table using from
let lagrange_table = LagrangeTable::<TestField, U8, U7>::from(denominator);
let output = lagrange_table.eval(&polynomial);
assert_eq!(output, output_expected);
}

proptest! {
#[test]
fn proptest_lagrange_canonical_using_from(input_points in prop::array::uniform8(any::<TestField>()))
{
lagrange_canonical_using_from(input_points);
}
}
}
1 change: 1 addition & 0 deletions ipa-core/src/protocol/ipa_prf/malicious_security/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod lagrange;
2 changes: 2 additions & 0 deletions ipa-core/src/protocol/ipa_prf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ mod boolean_ops;
pub mod prf_eval;
pub mod prf_sharding;

#[cfg(all(test, unit_test))]
mod malicious_security;
mod quicksort;
mod shuffle;

Expand Down
Loading