Skip to content

Commit

Permalink
Merge pull request #957 from danielmasny/precomputed_lagrange
Browse files Browse the repository at this point in the history
Precomputed Lagrange
  • Loading branch information
benjaminsavage authored Mar 16, 2024
2 parents d574a7c + 0a81a8e commit c50e800
Show file tree
Hide file tree
Showing 4 changed files with 330 additions and 1 deletion.
46 changes: 45 additions & 1 deletion ipa-core/src/ff/prime_field.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::fmt::Display;
use std::{fmt::Display, mem};

use generic_array::GenericArray;

Expand All @@ -14,6 +14,42 @@ pub trait PrimeField: Field + U128Conversions {
type PrimeInteger: Into<u128>;

const PRIME: Self::PrimeInteger;

/// Invert function that returns the multiplicative inverse
/// the default implementation uses the extended Euclidean algorithm,
/// follows inversion algorithm in
/// (with the modification that it works for unsigned integers by keeping track of `sign`):
/// `https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm`
///
/// The function operates on `u128` rather than field elements since we need divisions
///
/// ## Panics
/// When `self` is `Zero`
#[must_use]
fn invert(&self) -> Self {
assert_ne!(*self, Self::ZERO);

let mut t = 0u128;
let mut newt = 1u128;
let mut r = Self::PRIME.into();
let mut newr = self.as_u128();
let mut sign = 1u128;

while newr != 0 {
let quotient = r / newr;
mem::swap(&mut t, &mut newt);
mem::swap(&mut r, &mut newr);
newt += quotient * t;
newr -= quotient * r;

// flip sign
sign = 1 - sign;
}

// when sign is negative, output `PRIME-t` otherwise `t`
// unwrap is safe
Self::try_from((1 - sign) * t + sign * (Self::PRIME.into() - t)).unwrap()
}
}

#[derive(thiserror::Error, Debug)]
Expand Down Expand Up @@ -295,6 +331,14 @@ macro_rules! field_impl {
let err = $field::deserialize(&buf).unwrap_err();
assert!(matches!(err, GreaterThanPrimeError(..)))
}

#[test]
fn invert(element: $field) {
if element != $field::ZERO
{
assert_eq!($field::ONE,element * element.invert() );
}
}
}
}

Expand Down
282 changes: 282 additions & 0 deletions ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
use std::fmt::Debug;

use generic_array::{sequence::GenericSequence, ArrayLength, GenericArray};
use typenum::{Unsigned, U1};

use crate::ff::{Field, PrimeField, Serializable};

/// A degree `N-1` polynomial is stored as `N` points `(x,y)`
/// where the "x coordinates" of the input points are `x_0` to `x_N` are `F::ZERO` to `(N-1)*F::ONE`
/// Therefore, we only need to store the `y` coordinates.
#[derive(Debug, PartialEq, Clone)]
pub struct Polynomial<F: Field, N: ArrayLength> {
y_coordinates: GenericArray<F, N>,
}

/// The Canonical Lagrange denominator is defined as the denominator of the Lagrange base polynomials
/// `https://en.wikipedia.org/wiki/Lagrange_polynomial`
/// where the "x coordinates" of the input points are `x_0` to `x_N` are `F::ZERO` to `(N-1)*F::ONE`
/// the degree of the polynomials is `N-1`
pub struct CanonicalLagrangeDenominator<F: Field, N: ArrayLength> {
denominator: GenericArray<F, N>,
}

impl<F, N> CanonicalLagrangeDenominator<F, N>
where
F: PrimeField + TryFrom<u128>,
<F as TryFrom<u128>>::Error: Debug,
N: ArrayLength,
{
/// generates canonical Lagrange denominators
///
/// ## Panics
/// When the field size is too small for `N` evaluation points
pub fn new() -> Self {
// assertion that field is large enough
// when it is large enough, `F::try_from().unwrap()` below does not panic
assert!(
F::BITS > usize::BITS - N::USIZE.leading_zeros(),
"Field size {} is not large enough to hold {} points",
F::BITS,
N::USIZE
);

// assertion that table is not too large for the stack
assert!(<F as Serializable>::Size::USIZE * N::USIZE < 2024);

let mut denominator = GenericArray::generate(|_| F::ONE);
for (d, i) in denominator.iter_mut().zip(0u64..) {
for j in (0..N::U64).filter(|&j| i != j) {
*d *= F::try_from(u128::from(i)).unwrap() - F::try_from(u128::from(j)).unwrap();
}
*d = d.invert();
}
Self { denominator }
}
}

/// `LagrangeTable` is a precomputed table for the Lagrange evaluation.
/// Allows to compute points on the polynomial, i.e. output points,
/// given enough points on the polynomial, i.e. input points,
/// by using the `eval` function.
/// The "x coordinates" are implicit.
/// The "y coordinates" of the input points are inputs to `eval`.
/// The output of `eval` are the "y coordinates" of the output points .
/// The "x coordinates" of the input points `x_0` to `x_(N-1)` are `F::ZERO` to `(N-1)*F::ONE`.
/// The `LagrangeTable` also specifies `M` "x coordinates" for the output points.
/// The "x coordinates" of the output points `x_N` to `x_(N+M-1)` are `N*F::ONE` to `(N+M-1)*F::ONE`
/// when generated using `from(denominator)`
/// unless generated using `new(denominator, x_output)` for a specific output "x coordinate" `x_output`.
pub struct LagrangeTable<F: Field, N: ArrayLength, M: ArrayLength> {
table: GenericArray<GenericArray<F, N>, M>,
}

impl<F, N> LagrangeTable<F, N, U1>
where
F: Field + TryFrom<u128>,
<F as TryFrom<u128>>::Error: Debug,
N: ArrayLength,
{
/// generates a `CanonicalLagrangeTable` from `CanoncialLagrangeDenominators` for a single output point
/// The "x coordinate" of the output point is `x_output`.
pub fn new(denominator: CanonicalLagrangeDenominator<F, N>, x_output: &F) -> Self {
// assertion that table is not too large for the stack
assert!(<F as Serializable>::Size::USIZE * N::USIZE < 2024);

let mut table = denominator.denominator;
Self::compute_table_row(x_output, &mut table);
LagrangeTable::<F, N, U1> {
table: GenericArray::from_array([table; 1]),
}
}
}

impl<F, N, M> LagrangeTable<F, N, M>
where
F: Field,
N: ArrayLength,
M: ArrayLength,
{
/// This function uses the `LagrangeTable` to evaluate `polynomial` on the specified output "x coordinates"
/// outputs the "y coordinates" such that `(x,y)` lies on `polynomial`
pub fn eval(&self, polynomial: &Polynomial<F, N>) -> GenericArray<F, M> {
let mut result = GenericArray::generate(|_| F::ONE);
self.mult_result_by_evaluation(polynomial, &mut result);
result
}

/// This function uses the `LagrangeTable` to evaluate `polynomial` on the specified output "x coordinates"
/// the "y coordinates" of the evaluation are multiplied to `result`
pub fn mult_result_by_evaluation(
&self,
polynomial: &Polynomial<F, N>,
result: &mut GenericArray<F, M>,
) {
for (y, base) in result.iter_mut().zip(self.table.iter()) {
*y *= base
.iter()
.zip(polynomial.y_coordinates.iter())
.fold(F::ZERO, |acc, (&base, &y)| acc + base * y);
}
}

/// helper function to compute a single row of `LagrangeTable`
///
/// ## Panics
/// When the field size is too small for `N` evaluation points
fn compute_table_row(x_output: &F, table_row: &mut GenericArray<F, N>)
where
F: Field + TryFrom<u128>,
<F as TryFrom<u128>>::Error: Debug,
N: ArrayLength,
{
for (entry, i) in table_row.iter_mut().zip(0u64..) {
for j in (0..N::U64).filter(|&j| j != i) {
*entry *= *x_output - F::try_from(u128::from(j)).unwrap();
}
}
}
}

impl<F, N, M> From<CanonicalLagrangeDenominator<F, N>> for LagrangeTable<F, N, M>
where
F: Field + TryFrom<u128>,
<F as TryFrom<u128>>::Error: Debug,
N: ArrayLength,
M: ArrayLength,
{
fn from(value: CanonicalLagrangeDenominator<F, N>) -> Self {
// assertion that field is large enough
// when it is large enough, `F::try_from().unwrap()` below does not panic
assert!(
F::BITS > usize::BITS - (N::USIZE + M::USIZE).leading_zeros(),
"Field size {} is not large enough to hold {} + {} points",
F::BITS,
N::USIZE,
M::USIZE
);

// assertion that table is not too large for the stack
assert!(<F as Serializable>::Size::USIZE * N::USIZE * M::USIZE < 2024);

let mut table = GenericArray::generate(|_| value.denominator.clone());
table.iter_mut().zip(0u64..).for_each(|(row, i)| {
Self::compute_table_row(&F::try_from(u128::from(i + N::U64)).unwrap(), row);
});
LagrangeTable { table }
}
}

#[cfg(all(test, unit_test))]
mod test {
use std::fmt::Debug;

use generic_array::{sequence::GenericSequence, ArrayLength, GenericArray};
use proptest::{prelude::*, proptest};
use typenum::{U1, U32, U7, U8};

use crate::{
ff::Field,
protocol::ipa_prf::malicious_security::lagrange::{
CanonicalLagrangeDenominator, LagrangeTable, Polynomial,
},
};

type TestField = crate::ff::Fp32BitPrime;

#[derive(Debug, PartialEq, Clone)]
struct MonomialFormPolynomial<F: Field, N: ArrayLength> {
coefficients: GenericArray<F, N>,
}

impl<F, N> MonomialFormPolynomial<F, N>
where
F: Field,
N: ArrayLength,
{
/// test helper function that evaluates a polynomial in monomial form, i.e. `sum_i c_i x^i` on points `x_output`
/// where `c_0` to `c_N` are stored in `polynomial`
fn eval<M>(&self, x_output: &GenericArray<F, M>) -> GenericArray<F, M>
where
M: ArrayLength,
{
// evaluate polynomial p at evaluation_points and random point using monomial base
let mut y_values = GenericArray::generate(|_| F::ZERO);
for (x, y) in x_output.iter().zip(y_values.iter_mut()) {
// monomial base, i.e. `x^k`
let mut base = F::ONE;
// evaluate p via `sum_k coefficient_k * x^k`
for coefficient in &self.coefficients {
*y += *coefficient * base;
base *= *x;
}
}
y_values
}
}

impl<F, N> From<MonomialFormPolynomial<F, N>> for Polynomial<F, N>
where
F: Field + TryFrom<u128>,
<F as TryFrom<u128>>::Error: Debug,
N: ArrayLength,
{
fn from(value: MonomialFormPolynomial<F, N>) -> Self {
let canonical_points: GenericArray<F, N> =
GenericArray::generate(|i| F::try_from(u128::try_from(i).unwrap()).unwrap());
Polynomial {
y_coordinates: value.eval(&canonical_points),
}
}
}

fn lagrange_single_output_point_using_new(
output_point: TestField,
input_points: [TestField; 32],
) {
let polynomial_monomial_form = MonomialFormPolynomial {
coefficients: GenericArray::<TestField, U32>::from_array(input_points),
};
let output_expected = polynomial_monomial_form.eval(
&GenericArray::<TestField, U1>::from_array([output_point; 1]),
);
let polynomial = Polynomial::from(polynomial_monomial_form.clone());
let denominator = CanonicalLagrangeDenominator::<TestField, U32>::new();
// generate table using new
let lagrange_table = LagrangeTable::<TestField, U32, U1>::new(denominator, &output_point);
let output = lagrange_table.eval(&polynomial);
assert_eq!(output, output_expected);
}

proptest! {
#[test]
fn proptest_lagrange_single_output_point_using_new(output_point: TestField, input_points in prop::array::uniform32(any::<TestField>())){
lagrange_single_output_point_using_new(output_point,input_points);
}
}

fn lagrange_canonical_using_from(input_points: [TestField; 8]) {
let polynomial_monomial_form = MonomialFormPolynomial {
coefficients: GenericArray::<TestField, U8>::from_array(input_points),
};
// the canonical x coordinates are 0..15, the outputs use coordinates 8..15:
let x_coordinates_output = GenericArray::<_, U7>::generate(|i| {
TestField::try_from(u128::try_from(i).unwrap() + 8).unwrap()
});
let output_expected = polynomial_monomial_form.eval(&x_coordinates_output);
let polynomial = Polynomial::from(polynomial_monomial_form.clone());
let denominator = CanonicalLagrangeDenominator::<TestField, U8>::new();
// generate table using from
let lagrange_table = LagrangeTable::<TestField, U8, U7>::from(denominator);
let output = lagrange_table.eval(&polynomial);
assert_eq!(output, output_expected);
}

proptest! {
#[test]
fn proptest_lagrange_canonical_using_from(input_points in prop::array::uniform8(any::<TestField>()))
{
lagrange_canonical_using_from(input_points);
}
}
}
1 change: 1 addition & 0 deletions ipa-core/src/protocol/ipa_prf/malicious_security/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod lagrange;
2 changes: 2 additions & 0 deletions ipa-core/src/protocol/ipa_prf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ mod boolean_ops;
pub mod prf_eval;
pub mod prf_sharding;

#[cfg(all(test, unit_test))]
mod malicious_security;
mod quicksort;
mod shuffle;

Expand Down

0 comments on commit c50e800

Please sign in to comment.