From c8be62366b789729125b141e1ded0080b9e9c638 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Sat, 16 Mar 2024 14:18:29 +1000 Subject: [PATCH 1/5] removing one more iter_mut in the denominators --- .../ipa_prf/malicious_security/lagrange.rs | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index 7ef171fe2..cbe10a465 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -35,23 +35,28 @@ where // assertion that field is large enough // when it is large enough, `F::try_from().unwrap()` below does not panic assert!( - F::BITS > usize::BITS - N::USIZE.leading_zeros(), + u128::from(N::U64) < F::PRIME.into(), "Field size {} is not large enough to hold {} points", - F::BITS, - N::USIZE + F::PRIME.into(), + N::U64 ); // assertion that table is not too large for the stack assert!(::Size::USIZE * N::USIZE < 2024); - let mut denominator = GenericArray::generate(|_| F::ONE); - for (d, i) in denominator.iter_mut().zip(0u64..) { - for j in (0..N::U64).filter(|&j| i != j) { - *d *= F::try_from(u128::from(i)).unwrap() - F::try_from(u128::from(j)).unwrap(); - } - *d = d.invert(); + Self { + denominator: (0..u128::from(N::U64)) + .into_iter() + .map(|i| { + (0..u128::from(N::U64)) + .into_iter() + .filter(|&j| i != j) + .map(|j| F::try_from(i).unwrap() - F::try_from(j).unwrap()) + .fold(F::ONE, |acc, a| acc * a) + .invert() + }) + .collect(), } - Self { denominator } } } From 98c0543fc79329cb29853eca1eaa0d5ca16d363e Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Sat, 16 Mar 2024 14:34:21 +1000 Subject: [PATCH 2/5] Improving error message --- .../protocol/ipa_prf/malicious_security/lagrange.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index cbe10a465..2d26e2490 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -136,8 +136,7 @@ where impl From> for LagrangeTable where - F: Field + TryFrom, - >::Error: Debug, + F: PrimeField, N: ArrayLength, M: ArrayLength, { @@ -145,11 +144,11 @@ where // assertion that field is large enough // when it is large enough, `F::try_from().unwrap()` below does not panic assert!( - F::BITS > usize::BITS - (N::USIZE + M::USIZE).leading_zeros(), + u128::from(N::U64 + M::U64) < F::PRIME.into(), "Field size {} is not large enough to hold {} + {} points", - F::BITS, - N::USIZE, - M::USIZE + F::PRIME.into(), + N::U64, + M::U64 ); // assertion that table is not too large for the stack From 65066c40d655b8dcc2159f5ad05a168730e3d821 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Sat, 16 Mar 2024 16:12:52 +1000 Subject: [PATCH 3/5] Removing the last iter_mut --- ipa-core/Cargo.toml | 2 +- .../ipa_prf/malicious_security/lagrange.rs | 54 +++++++++++-------- 2 files changed, 32 insertions(+), 24 deletions(-) diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 033a56931..6b064d09f 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -149,7 +149,7 @@ tower = { version = "0.4.13", optional = true } tower-http = { version = "0.4.0", optional = true, features = ["trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } -typenum = "1.16" +typenum = { version = "1.17", features = ["i128"] } # hpke is pinned to it x25519-dalek = "2.0.0-rc.3" diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index 2d26e2490..b2fb2e10b 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -1,6 +1,7 @@ use std::fmt::Debug; +use std::iter::repeat; -use generic_array::{sequence::GenericSequence, ArrayLength, GenericArray}; +use generic_array::{ArrayLength, GenericArray}; use typenum::{Unsigned, U1}; use crate::ff::{Field, PrimeField, Serializable}; @@ -35,20 +36,20 @@ where // assertion that field is large enough // when it is large enough, `F::try_from().unwrap()` below does not panic assert!( - u128::from(N::U64) < F::PRIME.into(), + N::U128 < F::PRIME.into(), "Field size {} is not large enough to hold {} points", F::PRIME.into(), - N::U64 + N::U128 ); // assertion that table is not too large for the stack assert!(::Size::USIZE * N::USIZE < 2024); Self { - denominator: (0..u128::from(N::U64)) + denominator: (0..N::U128) .into_iter() .map(|i| { - (0..u128::from(N::U64)) + (0..N::U128) .into_iter() .filter(|&j| i != j) .map(|j| F::try_from(i).unwrap() - F::try_from(j).unwrap()) @@ -84,12 +85,11 @@ where { /// generates a `CanonicalLagrangeTable` from `CanoncialLagrangeDenominators` for a single output point /// The "x coordinate" of the output point is `x_output`. - pub fn new(denominator: CanonicalLagrangeDenominator, x_output: &F) -> Self { + pub fn new(denominator: &CanonicalLagrangeDenominator, x_output: &F) -> Self { // assertion that table is not too large for the stack assert!(::Size::USIZE * N::USIZE < 2024); - let mut table = denominator.denominator; - Self::compute_table_row(x_output, &mut table); + let table = Self::compute_table_row(x_output, denominator); LagrangeTable:: { table: GenericArray::from_array([table; 1]), } @@ -120,17 +120,25 @@ where /// /// ## Panics /// When the field size is too small for `N` evaluation points - fn compute_table_row(x_output: &F, table_row: &mut GenericArray) + fn compute_table_row( + x_output: &F, + denominator: &CanonicalLagrangeDenominator, + ) -> GenericArray where F: Field + TryFrom, >::Error: Debug, N: ArrayLength, { - for (entry, i) in table_row.iter_mut().zip(0u64..) { - for j in (0..N::U64).filter(|&j| j != i) { - *entry *= *x_output - F::try_from(u128::from(j)).unwrap(); - } - } + (0..N::U128) + .zip(repeat(0..N::U128)) + .map(|(i, range)| { + range + .filter(|&j| j != i) + .fold(F::ONE, |acc, j| acc * (*x_output - F::try_from(j).unwrap())) + }) + .zip(&denominator.denominator) + .map(|(numerator, denominator)| *denominator * numerator) + .collect() } } @@ -144,21 +152,21 @@ where // assertion that field is large enough // when it is large enough, `F::try_from().unwrap()` below does not panic assert!( - u128::from(N::U64 + M::U64) < F::PRIME.into(), + N::U128 + M::U128 < F::PRIME.into(), "Field size {} is not large enough to hold {} + {} points", F::PRIME.into(), - N::U64, - M::U64 + N::U128, + M::U128 ); // assertion that table is not too large for the stack assert!(::Size::USIZE * N::USIZE * M::USIZE < 2024); - let mut table = GenericArray::generate(|_| value.denominator.clone()); - table.iter_mut().zip(0u64..).for_each(|(row, i)| { - Self::compute_table_row(&F::try_from(u128::from(i + N::U64)).unwrap(), row); - }); - LagrangeTable { table } + LagrangeTable { + table: (N::U128..(N::U128 + M::U128)) + .map(|i| Self::compute_table_row(&F::try_from(i).unwrap(), &value)) + .collect(), + } } } @@ -240,7 +248,7 @@ mod test { let polynomial = Polynomial::from(polynomial_monomial_form.clone()); let denominator = CanonicalLagrangeDenominator::::new(); // generate table using new - let lagrange_table = LagrangeTable::::new(denominator, &output_point); + let lagrange_table = LagrangeTable::::new(&denominator, &output_point); let output = lagrange_table.eval(&polynomial); assert_eq!(output, output_expected); } From 1f9eb85a40cfc73146bbb9721d25a68184efa1c7 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Sat, 16 Mar 2024 16:27:50 +1000 Subject: [PATCH 4/5] formatting --- ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index b2fb2e10b..f4a3276be 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -1,5 +1,4 @@ -use std::fmt::Debug; -use std::iter::repeat; +use std::{fmt::Debug, iter::repeat}; use generic_array::{ArrayLength, GenericArray}; use typenum::{Unsigned, U1}; From 8df8f2afb9f9eff9901609825f009dae4b9e1fbe Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Sat, 16 Mar 2024 16:55:33 +1000 Subject: [PATCH 5/5] a bit shorter --- .../src/protocol/ipa_prf/malicious_security/lagrange.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index f4a3276be..c00182bfd 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -1,4 +1,4 @@ -use std::{fmt::Debug, iter::repeat}; +use std::fmt::Debug; use generic_array::{ArrayLength, GenericArray}; use typenum::{Unsigned, U1}; @@ -46,10 +46,8 @@ where Self { denominator: (0..N::U128) - .into_iter() .map(|i| { (0..N::U128) - .into_iter() .filter(|&j| i != j) .map(|j| F::try_from(i).unwrap() - F::try_from(j).unwrap()) .fold(F::ONE, |acc, a| acc * a) @@ -129,9 +127,8 @@ where N: ArrayLength, { (0..N::U128) - .zip(repeat(0..N::U128)) - .map(|(i, range)| { - range + .map(|i| { + (0..N::U128) .filter(|&j| j != i) .fold(F::ONE, |acc, j| acc * (*x_output - F::try_from(j).unwrap())) })