diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 033a56931..6b064d09f 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -149,7 +149,7 @@ tower = { version = "0.4.13", optional = true } tower-http = { version = "0.4.0", optional = true, features = ["trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } -typenum = "1.16" +typenum = { version = "1.17", features = ["i128"] } # hpke is pinned to it x25519-dalek = "2.0.0-rc.3" diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index 2d26e2490..b2fb2e10b 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -1,6 +1,7 @@ use std::fmt::Debug; +use std::iter::repeat; -use generic_array::{sequence::GenericSequence, ArrayLength, GenericArray}; +use generic_array::{ArrayLength, GenericArray}; use typenum::{Unsigned, U1}; use crate::ff::{Field, PrimeField, Serializable}; @@ -35,20 +36,20 @@ where // assertion that field is large enough // when it is large enough, `F::try_from().unwrap()` below does not panic assert!( - u128::from(N::U64) < F::PRIME.into(), + N::U128 < F::PRIME.into(), "Field size {} is not large enough to hold {} points", F::PRIME.into(), - N::U64 + N::U128 ); // assertion that table is not too large for the stack assert!(<F as Serializable>::Size::USIZE * N::USIZE < 2024); Self { - denominator: (0..u128::from(N::U64)) + denominator: (0..N::U128) .into_iter() .map(|i| { - (0..u128::from(N::U64)) + (0..N::U128) .into_iter() .filter(|&j| i != j) .map(|j| F::try_from(i).unwrap() - F::try_from(j).unwrap()) @@ -84,12 +85,11 @@ where { /// generates a `CanonicalLagrangeTable` from `CanoncialLagrangeDenominators` for a single output point /// The "x coordinate" of the output point is `x_output`. - pub fn new(denominator: CanonicalLagrangeDenominator<F, N>, x_output: &F) -> Self { + pub fn new(denominator: &CanonicalLagrangeDenominator<F, N>, x_output: &F) -> Self { // assertion that table is not too large for the stack assert!(<F as Serializable>::Size::USIZE * N::USIZE < 2024); - let mut table = denominator.denominator; - Self::compute_table_row(x_output, &mut table); + let table = Self::compute_table_row(x_output, denominator); LagrangeTable::<F, N, U1> { table: GenericArray::from_array([table; 1]), } @@ -120,17 +120,25 @@ where /// /// ## Panics /// When the field size is too small for `N` evaluation points - fn compute_table_row(x_output: &F, table_row: &mut GenericArray<F, N>) + fn compute_table_row( + x_output: &F, + denominator: &CanonicalLagrangeDenominator<F, N>, + ) -> GenericArray<F, N> where F: Field + TryFrom<u128>, <F as TryFrom<u128>>::Error: Debug, N: ArrayLength, { - for (entry, i) in table_row.iter_mut().zip(0u64..) { - for j in (0..N::U64).filter(|&j| j != i) { - *entry *= *x_output - F::try_from(u128::from(j)).unwrap(); - } - } + (0..N::U128) + .zip(repeat(0..N::U128)) + .map(|(i, range)| { + range + .filter(|&j| j != i) + .fold(F::ONE, |acc, j| acc * (*x_output - F::try_from(j).unwrap())) + }) + .zip(&denominator.denominator) + .map(|(numerator, denominator)| *denominator * numerator) + .collect() } } @@ -144,21 +152,21 @@ where // assertion that field is large enough // when it is large enough, `F::try_from().unwrap()` below does not panic assert!( - u128::from(N::U64 + M::U64) < F::PRIME.into(), + N::U128 + M::U128 < F::PRIME.into(), "Field size {} is not large enough to hold {} + {} points", F::PRIME.into(), - N::U64, - M::U64 + N::U128, + M::U128 ); // assertion that table is not too large for the stack assert!(<F as Serializable>::Size::USIZE * N::USIZE * M::USIZE < 2024); - let mut table = GenericArray::generate(|_| value.denominator.clone()); - table.iter_mut().zip(0u64..).for_each(|(row, i)| { - Self::compute_table_row(&F::try_from(u128::from(i + N::U64)).unwrap(), row); - }); - LagrangeTable { table } + LagrangeTable { + table: (N::U128..(N::U128 + M::U128)) + .map(|i| Self::compute_table_row(&F::try_from(i).unwrap(), &value)) + .collect(), + } } } @@ -240,7 +248,7 @@ mod test { let polynomial = Polynomial::from(polynomial_monomial_form.clone()); let denominator = CanonicalLagrangeDenominator::<TestField, U32>::new(); // generate table using new - let lagrange_table = LagrangeTable::<TestField, U32, U1>::new(denominator, &output_point); + let lagrange_table = LagrangeTable::<TestField, U32, U1>::new(&denominator, &output_point); let output = lagrange_table.eval(&polynomial); assert_eq!(output, output_expected); }