From fc2611686def7ef27ac581ad7e3952d3cc7c2b83 Mon Sep 17 00:00:00 2001 From: danielmasny Date: Thu, 15 Feb 2024 14:53:08 -0800 Subject: [PATCH 1/9] inversion for prime fields (smaller than 128 bit) --- ipa-core/src/ff/prime_field.rs | 47 +++++++++++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/ipa-core/src/ff/prime_field.rs b/ipa-core/src/ff/prime_field.rs index 27e22a74e..5f32eb337 100644 --- a/ipa-core/src/ff/prime_field.rs +++ b/ipa-core/src/ff/prime_field.rs @@ -1,4 +1,4 @@ -use std::fmt::Display; +use std::{fmt::Display, mem}; use generic_array::GenericArray; @@ -14,6 +14,43 @@ pub trait PrimeField: Field { type PrimeInteger: Into; const PRIME: Self::PrimeInteger; + + /// Invert function that returns the multiplicative inverse + /// the default implementation uses the extended Euclidean algorithm, + /// follows inversion algorithm in + /// (with the modification that it works for unsigned integers by keeping track of `sign`): + /// `https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm` + /// + /// The function operates on `u128` rather than field elements since we need divisions + /// + /// ## Panics + /// When `self` is `Zero` + + #[must_use] + fn invert(&self) -> Self { + assert_ne!(*self, Self::ZERO); + + let mut t = 0u128; + let mut newt = 1u128; + let mut r = Self::PRIME.into(); + let mut newr = self.as_u128(); + let mut sign = 1u128; + + while newr != 0u128 { + let quotient = r / newr; + mem::swap(&mut t, &mut newt); + mem::swap(&mut r, &mut newr); + newt += quotient * t; + newr -= quotient * r; + + // flip sign + sign = 1 - sign; + } + + // when sign is negative, output `PRIME-t` otherwise `t` + // unwrap is safe + Self::try_from((1 - sign) * t + sign * (Self::PRIME.into() - t)).unwrap() + } } #[derive(thiserror::Error, Debug)] @@ -294,6 +331,14 @@ macro_rules! field_impl { let err = $field::deserialize(&buf).unwrap_err(); assert!(matches!(err, GreaterThanPrimeError(..))) } + + #[test] + fn test_invert(element: $field) { + if element != $field::ZERO + { + assert_eq!($field::ONE,element * element.invert() ); + } + } } } From ee6d2f50e7b241dd03aa577bbed8fff9dc2a103c Mon Sep 17 00:00:00 2001 From: danielmasny Date: Fri, 16 Feb 2024 20:47:41 -0800 Subject: [PATCH 2/9] lagrange interpolation --- .../ipa_prf/malicious_security/lagrange.rs | 240 ++++++++++++++++++ .../ipa_prf/malicious_security/mod.rs | 1 + ipa-core/src/protocol/ipa_prf/mod.rs | 2 + 3 files changed, 243 insertions(+) create mode 100644 ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs create mode 100644 ipa-core/src/protocol/ipa_prf/malicious_security/mod.rs diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs new file mode 100644 index 000000000..10d79dd25 --- /dev/null +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -0,0 +1,240 @@ +use std::iter; + +use generic_array::{ArrayLength, GenericArray}; +use typenum::U1; + +use crate::ff::{Field, PrimeField}; + +/// A degree `N-1` polynomial is stored as `N` points `(x,y)` +/// where the "x coordinates" of the input points are `x_0` to `x_N` are `F::ZERO` to `(N-1)*F::ONE` +/// Therefore, we only need to store the `y` coordinates. +#[derive(Debug, PartialEq, Clone)] +pub struct Polynomial { + y_coordinates: GenericArray, +} + +/// The Canonical Lagrange denominator is defined as the denominator of the Lagrange base polynomials +/// `https://en.wikipedia.org/wiki/Lagrange_polynomial` +/// where the "x coordinates" of the input points are `x_0` to `x_N` are `F::ZERO` to `(N-1)*F::ONE` +/// the degree of the polynomials is `N-1` +pub struct CanonicalLagrangeDenominator { + denominator: GenericArray, +} + +impl CanonicalLagrangeDenominator +where + F: PrimeField, + N: ArrayLength, +{ + /// generates canonical Lagrange denominators + /// + /// ## Panics + /// When the field size is too small for `N` evaluation points + pub fn new() -> Self { + // assertion that field is large enough + // also checks that `try_from` for conversions from sufficiently small `u128` to `F` do not panic + debug_assert!(F::BITS > usize::BITS - N::USIZE.leading_zeros()); + + let mut denominator = iter::repeat(F::ONE) + .take(N::USIZE) + .collect::>(); + for (i, d) in denominator.iter_mut().enumerate() { + for j in (0usize..N::USIZE).filter(|&j| i != j) { + *d *= F::try_from(i as u128).unwrap() - F::try_from(j as u128).unwrap(); + } + *d = d.invert(); + } + CanonicalLagrangeDenominator { denominator } + } +} + +/// `LagrangeTable` is a precomputation table for the Lagrange evaluation. +/// The "x coordinates" of the input points are `x_0` to `x_(N-1)` are `F::ZERO` to `(N-1)*F::ONE`. +/// The `LagrangeTable` also specifies `M` "x coordinates" for the output points +/// The "x coordinates" of the output points are `x_N` to `x_(N+M-1)` are `N*F::ONE` to `(N+M-1)*F::ONE`. +pub struct LagrangeTable { + table: GenericArray, M>, +} + +impl LagrangeTable +where + F: Field, + N: ArrayLength, +{ + /// generates a `CanonicalLagrangeTable` from `CanoncialLagrangeDenominators` for a single output point + /// The "x coordinate" of the output point is `x_output`. + pub fn new(denominator: &CanonicalLagrangeDenominator, x_output: &F) -> Self { + let mut table = denominator.denominator.clone(); + compute_table_row(x_output, &mut table); + LagrangeTable:: { + table: GenericArray::from_array([table; 1]), + } + } +} + +impl LagrangeTable +where + F: Field, + N: ArrayLength, + M: ArrayLength, +{ + /// This function uses the `LagrangeTable` to evaluate `polynomial` on the specified output "x coordinates" + /// outputs the "y coordinates" such that `(x,y)` lies on `polynomial` + pub fn eval(&self, polynomial: &Polynomial) -> GenericArray { + let mut result = iter::repeat(F::ONE) + .take(M::USIZE) + .collect::>(); + self.mult_result_by_evaluation(polynomial, &mut result); + result + } + + /// This function uses the `LagrangeTable` to evaluate `polynomial` on the specified output "x coordinates" + /// the "y coordinates" of the evaluation are multiplied to `result` + pub fn mult_result_by_evaluation( + &self, + polynomial: &Polynomial, + result: &mut GenericArray, + ) { + for (y, base) in result.iter_mut().zip(self.table.iter()) { + *y *= base + .iter() + .zip(polynomial.y_coordinates.iter()) + .fold(F::ZERO, |acc, (&base, &y)| acc + base * y); + } + } +} + +impl From> for LagrangeTable +where + F: Field, + N: ArrayLength, + M: ArrayLength, +{ + fn from(value: CanonicalLagrangeDenominator) -> Self { + // assertion that field is large enough + // also checks that `try_from` for conversions from sufficiently small `u128` to `F` do not panic + debug_assert!(F::BITS > usize::BITS - N::USIZE.leading_zeros() - M::USIZE.leading_zeros()); + + let mut table = iter::repeat(value.denominator.clone()) + .take(M::USIZE) + .collect::>(); + table.iter_mut().enumerate().for_each(|(i, row)| { + compute_table_row(&F::try_from((i + N::USIZE) as u128).unwrap(), row); + }); + LagrangeTable { table } + } +} + +/// helper function to compute a single row of `CanonicalLagrangeTable` +/// +/// ## Panics +/// When the field size is too small for `N` evaluation points +fn compute_table_row(x_output: &F, table_row: &mut GenericArray) +where + F: Field, + N: ArrayLength, +{ + for (i, entry) in table_row.iter_mut().enumerate() { + for j in (0usize..N::USIZE).filter(|&j| j != i) { + *entry *= *x_output - F::try_from(j as u128).unwrap(); + } + } +} + +#[cfg(all(test, unit_test))] +mod test { + use std::iter; + + use generic_array::{ArrayLength, GenericArray}; + use proptest::{prelude::*, proptest}; + use typenum::{U1, U32, U7, U8}; + + use crate::{ + ff::Field, + protocol::ipa_prf::malicious_security::lagrange::{ + CanonicalLagrangeDenominator, LagrangeTable, Polynomial, + }, + }; + + type TestField = crate::ff::Fp32BitPrime; + + #[derive(Debug, PartialEq, Clone)] + struct MonomialFormPolynomial { + coefficients: GenericArray, + } + + impl MonomialFormPolynomial + where + F: Field, + N: ArrayLength, + { + /// test helper function that evaluates a polynomial in monomial form, i.e. `sum_i c_i x^i` on points `x_output` + /// where `c_0` to `c_N` are stored in `polynomial` + fn eval(&self, x_output: &GenericArray) -> GenericArray + where + M: ArrayLength, + { + // evaluate polynomial p at evaluation_points and random point using monomial base + let mut y_values = iter::repeat(F::ZERO) + .take(M::USIZE) + .collect::>(); + for (x, y) in x_output.iter().zip(y_values.iter_mut()) { + // monomial base, i.e. `x^k` + let mut base = F::ONE; + // evaluate p via `sum_k coefficient_k * x^k` + for coefficient in &self.coefficients { + *y += *coefficient * base; + base *= *x; + } + } + y_values + } + } + + impl From> for Polynomial + where + F: Field, + N: ArrayLength, + { + fn from(value: MonomialFormPolynomial) -> Self { + let canonical_points: GenericArray = (0..N::USIZE) + .map(|i| F::try_from(i as u128).unwrap()) + .collect::>(); + Polynomial { + y_coordinates: value.eval(&canonical_points), + } + } + } + + proptest! { + #[test] + fn test_lagrange_single_output_point_using_new(output_point: TestField, input_points in prop::array::uniform32(any::())){ + let polynomial_monomial_form = MonomialFormPolynomial{ + coefficients: GenericArray::::from_array(input_points)}; + let output_expected = polynomial_monomial_form.eval( + &GenericArray::::from_array([output_point;1])); + let polynomial = Polynomial::from(polynomial_monomial_form.clone()); + let denominator = CanonicalLagrangeDenominator::::new(); + // generate table using new + let lagrange_table = LagrangeTable::::new(&denominator,&output_point); + let output = lagrange_table.eval(&polynomial); + assert_eq!(output,output_expected); + } + + #[test] + fn test_lagrange_cannonical_using_from(input_points in prop::array::uniform8(any::())) + { + let polynomial_monomial_form = MonomialFormPolynomial{ + coefficients: GenericArray::::from_array(input_points)}; + // the canonical x coordinates are 0..15, the outputs use coordinates 8..15: + let x_coordinates_output = (8..15).map(|i|TestField::try_from(i).unwrap()).collect::>(); + let output_expected = polynomial_monomial_form.eval(&x_coordinates_output); + let polynomial = Polynomial::from(polynomial_monomial_form.clone()); + let denominator = CanonicalLagrangeDenominator::::new(); + // generate table using from + let lagrange_table = LagrangeTable::::from(denominator); + let output = lagrange_table.eval(&polynomial); + assert_eq!(output,output_expected); + } + } +} diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/mod.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/mod.rs new file mode 100644 index 000000000..ea0ac6eef --- /dev/null +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/mod.rs @@ -0,0 +1 @@ +pub mod lagrange; diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs index 47616d929..54530456b 100644 --- a/ipa-core/src/protocol/ipa_prf/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/mod.rs @@ -28,6 +28,8 @@ mod boolean_ops; pub mod prf_eval; pub mod prf_sharding; +#[cfg(all(test, unit_test))] +mod malicious_security; mod quicksort; mod shuffle; From 9d51abd843da44a3391721a158638ecf970865d1 Mon Sep 17 00:00:00 2001 From: danielmasny Date: Tue, 20 Feb 2024 12:02:55 -0800 Subject: [PATCH 3/9] address comments --- .../ipa_prf/malicious_security/lagrange.rs | 46 +++++++++---------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index 10d79dd25..17254ba96 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -44,7 +44,7 @@ where } *d = d.invert(); } - CanonicalLagrangeDenominator { denominator } + Self { denominator } } } @@ -63,9 +63,9 @@ where { /// generates a `CanonicalLagrangeTable` from `CanoncialLagrangeDenominators` for a single output point /// The "x coordinate" of the output point is `x_output`. - pub fn new(denominator: &CanonicalLagrangeDenominator, x_output: &F) -> Self { - let mut table = denominator.denominator.clone(); - compute_table_row(x_output, &mut table); + pub fn new(denominator: CanonicalLagrangeDenominator, x_output: &F) -> Self { + let mut table = denominator.denominator; + Self::compute_table_row(x_output, &mut table); LagrangeTable:: { table: GenericArray::from_array([table; 1]), } @@ -102,6 +102,22 @@ where .fold(F::ZERO, |acc, (&base, &y)| acc + base * y); } } + + /// helper function to compute a single row of `CanonicalLagrangeTable` + /// + /// ## Panics + /// When the field size is too small for `N` evaluation points + fn compute_table_row(x_output: &F, table_row: &mut GenericArray) + where + F: Field, + N: ArrayLength, + { + for (i, entry) in table_row.iter_mut().enumerate() { + for j in (0usize..N::USIZE).filter(|&j| j != i) { + *entry *= *x_output - F::try_from(j as u128).unwrap(); + } + } + } } impl From> for LagrangeTable @@ -113,34 +129,18 @@ where fn from(value: CanonicalLagrangeDenominator) -> Self { // assertion that field is large enough // also checks that `try_from` for conversions from sufficiently small `u128` to `F` do not panic - debug_assert!(F::BITS > usize::BITS - N::USIZE.leading_zeros() - M::USIZE.leading_zeros()); + debug_assert!(F::BITS > usize::BITS - (N::USIZE + M::USIZE).leading_zeros()); let mut table = iter::repeat(value.denominator.clone()) .take(M::USIZE) .collect::>(); table.iter_mut().enumerate().for_each(|(i, row)| { - compute_table_row(&F::try_from((i + N::USIZE) as u128).unwrap(), row); + Self::compute_table_row(&F::try_from((i + N::USIZE) as u128).unwrap(), row); }); LagrangeTable { table } } } -/// helper function to compute a single row of `CanonicalLagrangeTable` -/// -/// ## Panics -/// When the field size is too small for `N` evaluation points -fn compute_table_row(x_output: &F, table_row: &mut GenericArray) -where - F: Field, - N: ArrayLength, -{ - for (i, entry) in table_row.iter_mut().enumerate() { - for j in (0usize..N::USIZE).filter(|&j| j != i) { - *entry *= *x_output - F::try_from(j as u128).unwrap(); - } - } -} - #[cfg(all(test, unit_test))] mod test { use std::iter; @@ -216,7 +216,7 @@ mod test { let polynomial = Polynomial::from(polynomial_monomial_form.clone()); let denominator = CanonicalLagrangeDenominator::::new(); // generate table using new - let lagrange_table = LagrangeTable::::new(&denominator,&output_point); + let lagrange_table = LagrangeTable::::new(denominator,&output_point); let output = lagrange_table.eval(&polynomial); assert_eq!(output,output_expected); } From b95d9f84c64bc0371576176a3503cd73f4b128b6 Mon Sep 17 00:00:00 2001 From: danielmasny Date: Tue, 20 Feb 2024 12:07:03 -0800 Subject: [PATCH 4/9] address comments --- ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index 17254ba96..2ed9eee0f 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -208,7 +208,7 @@ mod test { proptest! { #[test] - fn test_lagrange_single_output_point_using_new(output_point: TestField, input_points in prop::array::uniform32(any::())){ + fn lagrange_single_output_point_using_new(output_point: TestField, input_points in prop::array::uniform32(any::())){ let polynomial_monomial_form = MonomialFormPolynomial{ coefficients: GenericArray::::from_array(input_points)}; let output_expected = polynomial_monomial_form.eval( @@ -222,7 +222,7 @@ mod test { } #[test] - fn test_lagrange_cannonical_using_from(input_points in prop::array::uniform8(any::())) + fn lagrange_cannonical_using_from(input_points in prop::array::uniform8(any::())) { let polynomial_monomial_form = MonomialFormPolynomial{ coefficients: GenericArray::::from_array(input_points)}; From 2de66ed48a932e77b67007f14c7c6359865b10fd Mon Sep 17 00:00:00 2001 From: danielmasny Date: Wed, 21 Feb 2024 15:32:35 -0800 Subject: [PATCH 5/9] address Alex's comments --- ipa-core/src/ff/prime_field.rs | 7 +- .../ipa_prf/malicious_security/lagrange.rs | 118 ++++++++++++------ 2 files changed, 85 insertions(+), 40 deletions(-) diff --git a/ipa-core/src/ff/prime_field.rs b/ipa-core/src/ff/prime_field.rs index 5f32eb337..ef000cc0b 100644 --- a/ipa-core/src/ff/prime_field.rs +++ b/ipa-core/src/ff/prime_field.rs @@ -25,7 +25,6 @@ pub trait PrimeField: Field { /// /// ## Panics /// When `self` is `Zero` - #[must_use] fn invert(&self) -> Self { assert_ne!(*self, Self::ZERO); @@ -36,7 +35,7 @@ pub trait PrimeField: Field { let mut newr = self.as_u128(); let mut sign = 1u128; - while newr != 0u128 { + while newr != 0 { let quotient = r / newr; mem::swap(&mut t, &mut newt); mem::swap(&mut r, &mut newr); @@ -332,8 +331,8 @@ macro_rules! field_impl { assert!(matches!(err, GreaterThanPrimeError(..))) } - #[test] - fn test_invert(element: $field) { + #[test] + fn invert(element: $field) { if element != $field::ZERO { assert_eq!($field::ONE,element * element.invert() ); diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index 2ed9eee0f..0a3ab3291 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -1,9 +1,9 @@ use std::iter; use generic_array::{ArrayLength, GenericArray}; -use typenum::U1; +use typenum::{Unsigned, U1}; -use crate::ff::{Field, PrimeField}; +use crate::ff::{Field, PrimeField, Serializable}; /// A degree `N-1` polynomial is stored as `N` points `(x,y)` /// where the "x coordinates" of the input points are `x_0` to `x_N` are `F::ZERO` to `(N-1)*F::ONE` @@ -32,8 +32,16 @@ where /// When the field size is too small for `N` evaluation points pub fn new() -> Self { // assertion that field is large enough - // also checks that `try_from` for conversions from sufficiently small `u128` to `F` do not panic - debug_assert!(F::BITS > usize::BITS - N::USIZE.leading_zeros()); + // when it is large enough, `F::try_from().unwrap()` below does not panic + assert!( + F::BITS > usize::BITS - N::USIZE.leading_zeros(), + "Field size {} is not large enough to hold {} points", + F::BITS, + N::USIZE + ); + + // assertion that table is not too large for the stack + assert!(::Size::USIZE * N::USIZE < 2024); let mut denominator = iter::repeat(F::ONE) .take(N::USIZE) @@ -48,10 +56,18 @@ where } } -/// `LagrangeTable` is a precomputation table for the Lagrange evaluation. -/// The "x coordinates" of the input points are `x_0` to `x_(N-1)` are `F::ZERO` to `(N-1)*F::ONE`. -/// The `LagrangeTable` also specifies `M` "x coordinates" for the output points -/// The "x coordinates" of the output points are `x_N` to `x_(N+M-1)` are `N*F::ONE` to `(N+M-1)*F::ONE`. +/// `LagrangeTable` is a precomputed table for the Lagrange evaluation. +/// Allows to compute points on the polynomial, i.e. output points, +/// given enough points on the polynomial, i.e. input points, +/// by using the `eval` function. +/// The "x coordinates" are implicit. +/// The "y coordinates" of the input points are inputs to `eval`. +/// The output of `eval` are the "y coordinates" of the output points . +/// The "x coordinates" of the input points `x_0` to `x_(N-1)` are `F::ZERO` to `(N-1)*F::ONE`. +/// The `LagrangeTable` also specifies `M` "x coordinates" for the output points. +/// The "x coordinates" of the output points `x_N` to `x_(N+M-1)` are `N*F::ONE` to `(N+M-1)*F::ONE` +/// when generated using `from(denominator)` +/// unless generated using `new(denominator, x_output)` for a specific output "x coordinate" `x_output`. pub struct LagrangeTable { table: GenericArray, M>, } @@ -64,6 +80,9 @@ where /// generates a `CanonicalLagrangeTable` from `CanoncialLagrangeDenominators` for a single output point /// The "x coordinate" of the output point is `x_output`. pub fn new(denominator: CanonicalLagrangeDenominator, x_output: &F) -> Self { + // assertion that table is not too large for the stack + assert!(::Size::USIZE * N::USIZE < 2024); + let mut table = denominator.denominator; Self::compute_table_row(x_output, &mut table); LagrangeTable:: { @@ -103,7 +122,7 @@ where } } - /// helper function to compute a single row of `CanonicalLagrangeTable` + /// helper function to compute a single row of `LagrangeTable` /// /// ## Panics /// When the field size is too small for `N` evaluation points @@ -128,8 +147,17 @@ where { fn from(value: CanonicalLagrangeDenominator) -> Self { // assertion that field is large enough - // also checks that `try_from` for conversions from sufficiently small `u128` to `F` do not panic - debug_assert!(F::BITS > usize::BITS - (N::USIZE + M::USIZE).leading_zeros()); + // when it is large enough, `F::try_from().unwrap()` below does not panic + assert!( + F::BITS > usize::BITS - (N::USIZE + M::USIZE).leading_zeros(), + "Field size {} is not large enough to hold {} + {} points", + F::BITS, + N::USIZE, + M::USIZE + ); + + // assertion that table is not too large for the stack + assert!(::Size::USIZE * N::USIZE * M::USIZE < 2024); let mut table = iter::repeat(value.denominator.clone()) .take(M::USIZE) @@ -206,35 +234,53 @@ mod test { } } + fn lagrange_single_output_point_using_new( + output_point: TestField, + input_points: [TestField; 32], + ) { + let polynomial_monomial_form = MonomialFormPolynomial { + coefficients: GenericArray::::from_array(input_points), + }; + let output_expected = polynomial_monomial_form.eval( + &GenericArray::::from_array([output_point; 1]), + ); + let polynomial = Polynomial::from(polynomial_monomial_form.clone()); + let denominator = CanonicalLagrangeDenominator::::new(); + // generate table using new + let lagrange_table = LagrangeTable::::new(denominator, &output_point); + let output = lagrange_table.eval(&polynomial); + assert_eq!(output, output_expected); + } + proptest! { - #[test] - fn lagrange_single_output_point_using_new(output_point: TestField, input_points in prop::array::uniform32(any::())){ - let polynomial_monomial_form = MonomialFormPolynomial{ - coefficients: GenericArray::::from_array(input_points)}; - let output_expected = polynomial_monomial_form.eval( - &GenericArray::::from_array([output_point;1])); - let polynomial = Polynomial::from(polynomial_monomial_form.clone()); - let denominator = CanonicalLagrangeDenominator::::new(); - // generate table using new - let lagrange_table = LagrangeTable::::new(denominator,&output_point); - let output = lagrange_table.eval(&polynomial); - assert_eq!(output,output_expected); - } + #[test] + fn proptest_lagrange_single_output_point_using_new(output_point: TestField, input_points in prop::array::uniform32(any::())){ + lagrange_single_output_point_using_new(output_point,input_points); + } + } + fn lagrange_canonical_using_from(input_points: [TestField; 8]) { + let polynomial_monomial_form = MonomialFormPolynomial { + coefficients: GenericArray::::from_array(input_points), + }; + // the canonical x coordinates are 0..15, the outputs use coordinates 8..15: + let x_coordinates_output = (8..15) + .map(|i| TestField::try_from(i).unwrap()) + .collect::>(); + let output_expected = polynomial_monomial_form.eval(&x_coordinates_output); + let polynomial = Polynomial::from(polynomial_monomial_form.clone()); + let denominator = CanonicalLagrangeDenominator::::new(); + // generate table using from + let lagrange_table = LagrangeTable::::from(denominator); + let output = lagrange_table.eval(&polynomial); + assert_eq!(output, output_expected); + } + + proptest! { #[test] - fn lagrange_cannonical_using_from(input_points in prop::array::uniform8(any::())) + fn proptest_lagrange_canonical_using_from(input_points in prop::array::uniform8(any::())) { - let polynomial_monomial_form = MonomialFormPolynomial{ - coefficients: GenericArray::::from_array(input_points)}; - // the canonical x coordinates are 0..15, the outputs use coordinates 8..15: - let x_coordinates_output = (8..15).map(|i|TestField::try_from(i).unwrap()).collect::>(); - let output_expected = polynomial_monomial_form.eval(&x_coordinates_output); - let polynomial = Polynomial::from(polynomial_monomial_form.clone()); - let denominator = CanonicalLagrangeDenominator::::new(); - // generate table using from - let lagrange_table = LagrangeTable::::from(denominator); - let output = lagrange_table.eval(&polynomial); - assert_eq!(output,output_expected); + lagrange_canonical_using_from(input_points); } } } From f738a81c55b26b02a55c062098a4b31e58e40adf Mon Sep 17 00:00:00 2001 From: danielmasny Date: Fri, 23 Feb 2024 10:25:53 -0800 Subject: [PATCH 6/9] use generate --- .../ipa_prf/malicious_security/lagrange.rs | 34 ++++++------------- 1 file changed, 10 insertions(+), 24 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index 0a3ab3291..5356fb96f 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -1,6 +1,4 @@ -use std::iter; - -use generic_array::{ArrayLength, GenericArray}; +use generic_array::{sequence::GenericSequence, ArrayLength, GenericArray}; use typenum::{Unsigned, U1}; use crate::ff::{Field, PrimeField, Serializable}; @@ -43,9 +41,7 @@ where // assertion that table is not too large for the stack assert!(::Size::USIZE * N::USIZE < 2024); - let mut denominator = iter::repeat(F::ONE) - .take(N::USIZE) - .collect::>(); + let mut denominator = GenericArray::generate(|_| F::ONE); for (i, d) in denominator.iter_mut().enumerate() { for j in (0usize..N::USIZE).filter(|&j| i != j) { *d *= F::try_from(i as u128).unwrap() - F::try_from(j as u128).unwrap(); @@ -100,9 +96,7 @@ where /// This function uses the `LagrangeTable` to evaluate `polynomial` on the specified output "x coordinates" /// outputs the "y coordinates" such that `(x,y)` lies on `polynomial` pub fn eval(&self, polynomial: &Polynomial) -> GenericArray { - let mut result = iter::repeat(F::ONE) - .take(M::USIZE) - .collect::>(); + let mut result = GenericArray::generate(|_| F::ONE); self.mult_result_by_evaluation(polynomial, &mut result); result } @@ -159,9 +153,7 @@ where // assertion that table is not too large for the stack assert!(::Size::USIZE * N::USIZE * M::USIZE < 2024); - let mut table = iter::repeat(value.denominator.clone()) - .take(M::USIZE) - .collect::>(); + let mut table = GenericArray::generate(|_| value.denominator.clone()); table.iter_mut().enumerate().for_each(|(i, row)| { Self::compute_table_row(&F::try_from((i + N::USIZE) as u128).unwrap(), row); }); @@ -171,9 +163,7 @@ where #[cfg(all(test, unit_test))] mod test { - use std::iter; - - use generic_array::{ArrayLength, GenericArray}; + use generic_array::{sequence::GenericSequence, ArrayLength, GenericArray}; use proptest::{prelude::*, proptest}; use typenum::{U1, U32, U7, U8}; @@ -203,9 +193,7 @@ mod test { M: ArrayLength, { // evaluate polynomial p at evaluation_points and random point using monomial base - let mut y_values = iter::repeat(F::ZERO) - .take(M::USIZE) - .collect::>(); + let mut y_values = GenericArray::generate(|_| F::ZERO); for (x, y) in x_output.iter().zip(y_values.iter_mut()) { // monomial base, i.e. `x^k` let mut base = F::ONE; @@ -225,9 +213,8 @@ mod test { N: ArrayLength, { fn from(value: MonomialFormPolynomial) -> Self { - let canonical_points: GenericArray = (0..N::USIZE) - .map(|i| F::try_from(i as u128).unwrap()) - .collect::>(); + let canonical_points: GenericArray = + GenericArray::generate(|i| F::try_from(i as u128).unwrap()); Polynomial { y_coordinates: value.eval(&canonical_points), } @@ -264,9 +251,8 @@ mod test { coefficients: GenericArray::::from_array(input_points), }; // the canonical x coordinates are 0..15, the outputs use coordinates 8..15: - let x_coordinates_output = (8..15) - .map(|i| TestField::try_from(i).unwrap()) - .collect::>(); + let x_coordinates_output = + GenericArray::<_, U7>::generate(|i| TestField::try_from(i as u128 + 8).unwrap()); let output_expected = polynomial_monomial_form.eval(&x_coordinates_output); let polynomial = Polynomial::from(polynomial_monomial_form.clone()); let denominator = CanonicalLagrangeDenominator::::new(); From 0c88a3947db1e6c022fb97b76f11baf197cf5300 Mon Sep 17 00:00:00 2001 From: danielmasny Date: Fri, 23 Feb 2024 17:27:22 -0800 Subject: [PATCH 7/9] remove primitive cast --- .../ipa_prf/malicious_security/lagrange.rs | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index 5356fb96f..51fb33f97 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -42,9 +42,9 @@ where assert!(::Size::USIZE * N::USIZE < 2024); let mut denominator = GenericArray::generate(|_| F::ONE); - for (i, d) in denominator.iter_mut().enumerate() { - for j in (0usize..N::USIZE).filter(|&j| i != j) { - *d *= F::try_from(i as u128).unwrap() - F::try_from(j as u128).unwrap(); + for (d, i) in denominator.iter_mut().zip(0u64..) { + for j in (0..N::U64).filter(|&j| i != j) { + *d *= F::try_from(u128::from(i)).unwrap() - F::try_from(u128::from(j)).unwrap(); } *d = d.invert(); } @@ -125,9 +125,9 @@ where F: Field, N: ArrayLength, { - for (i, entry) in table_row.iter_mut().enumerate() { - for j in (0usize..N::USIZE).filter(|&j| j != i) { - *entry *= *x_output - F::try_from(j as u128).unwrap(); + for (entry, i) in table_row.iter_mut().zip(0u64..) { + for j in (0..N::U64).filter(|&j| j != i) { + *entry *= *x_output - F::try_from(u128::from(j)).unwrap(); } } } @@ -154,8 +154,8 @@ where assert!(::Size::USIZE * N::USIZE * M::USIZE < 2024); let mut table = GenericArray::generate(|_| value.denominator.clone()); - table.iter_mut().enumerate().for_each(|(i, row)| { - Self::compute_table_row(&F::try_from((i + N::USIZE) as u128).unwrap(), row); + table.iter_mut().zip(0u64..).for_each(|(row, i)| { + Self::compute_table_row(&F::try_from(u128::from(i + N::U64)).unwrap(), row); }); LagrangeTable { table } } @@ -214,7 +214,7 @@ mod test { { fn from(value: MonomialFormPolynomial) -> Self { let canonical_points: GenericArray = - GenericArray::generate(|i| F::try_from(i as u128).unwrap()); + GenericArray::generate(|i| F::try_from(u128::try_from(i).unwrap()).unwrap()); Polynomial { y_coordinates: value.eval(&canonical_points), } @@ -251,8 +251,9 @@ mod test { coefficients: GenericArray::::from_array(input_points), }; // the canonical x coordinates are 0..15, the outputs use coordinates 8..15: - let x_coordinates_output = - GenericArray::<_, U7>::generate(|i| TestField::try_from(i as u128 + 8).unwrap()); + let x_coordinates_output = GenericArray::<_, U7>::generate(|i| { + TestField::try_from(u128::try_from(i).unwrap() + 8).unwrap() + }); let output_expected = polynomial_monomial_form.eval(&x_coordinates_output); let polynomial = Polynomial::from(polynomial_monomial_form.clone()); let denominator = CanonicalLagrangeDenominator::::new(); From 5ea2171ea8bc17bc89aeeff7477242bfb7f23932 Mon Sep 17 00:00:00 2001 From: danielmasny Date: Mon, 26 Feb 2024 14:44:59 -0800 Subject: [PATCH 8/9] add TryFrom trait bound --- .../protocol/ipa_prf/malicious_security/lagrange.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index 51fb33f97..66e92e782 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -21,7 +21,7 @@ pub struct CanonicalLagrangeDenominator { impl CanonicalLagrangeDenominator where - F: PrimeField, + F: PrimeField + TryFrom, N: ArrayLength, { /// generates canonical Lagrange denominators @@ -70,7 +70,7 @@ pub struct LagrangeTable { impl LagrangeTable where - F: Field, + F: Field + TryFrom, N: ArrayLength, { /// generates a `CanonicalLagrangeTable` from `CanoncialLagrangeDenominators` for a single output point @@ -122,7 +122,7 @@ where /// When the field size is too small for `N` evaluation points fn compute_table_row(x_output: &F, table_row: &mut GenericArray) where - F: Field, + F: Field + TryFrom, N: ArrayLength, { for (entry, i) in table_row.iter_mut().zip(0u64..) { @@ -135,7 +135,7 @@ where impl From> for LagrangeTable where - F: Field, + F: Field + TryFrom, N: ArrayLength, M: ArrayLength, { @@ -209,7 +209,7 @@ mod test { impl From> for Polynomial where - F: Field, + F: Field + TryFrom, N: ArrayLength, { fn from(value: MonomialFormPolynomial) -> Self { From 0a81a8e5e6c32880c1d932d2647d972adb005de6 Mon Sep 17 00:00:00 2001 From: danielmasny Date: Mon, 26 Feb 2024 15:00:30 -0800 Subject: [PATCH 9/9] add debug trait bound --- .../src/protocol/ipa_prf/malicious_security/lagrange.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index 66e92e782..04ddeae4b 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -1,3 +1,5 @@ +use std::fmt::Debug; + use generic_array::{sequence::GenericSequence, ArrayLength, GenericArray}; use typenum::{Unsigned, U1}; @@ -22,6 +24,7 @@ pub struct CanonicalLagrangeDenominator { impl CanonicalLagrangeDenominator where F: PrimeField + TryFrom, + >::Error: Debug, N: ArrayLength, { /// generates canonical Lagrange denominators @@ -71,6 +74,7 @@ pub struct LagrangeTable { impl LagrangeTable where F: Field + TryFrom, + >::Error: Debug, N: ArrayLength, { /// generates a `CanonicalLagrangeTable` from `CanoncialLagrangeDenominators` for a single output point @@ -123,6 +127,7 @@ where fn compute_table_row(x_output: &F, table_row: &mut GenericArray) where F: Field + TryFrom, + >::Error: Debug, N: ArrayLength, { for (entry, i) in table_row.iter_mut().zip(0u64..) { @@ -136,6 +141,7 @@ where impl From> for LagrangeTable where F: Field + TryFrom, + >::Error: Debug, N: ArrayLength, M: ArrayLength, { @@ -163,6 +169,8 @@ where #[cfg(all(test, unit_test))] mod test { + use std::fmt::Debug; + use generic_array::{sequence::GenericSequence, ArrayLength, GenericArray}; use proptest::{prelude::*, proptest}; use typenum::{U1, U32, U7, U8}; @@ -210,6 +218,7 @@ mod test { impl From> for Polynomial where F: Field + TryFrom, + >::Error: Debug, N: ArrayLength, { fn from(value: MonomialFormPolynomial) -> Self {