From 59a6fb3958019b3cbe34145720d58864ae708cda Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 6 Nov 2024 16:50:57 -0800 Subject: [PATCH 1/3] Change the API of `eval` to take an array instead There is really no reason for taking an iterator, just makes things worse for the compiler --- .../src/protocol/ipa_prf/malicious_security/lagrange.rs | 8 +------- .../src/protocol/ipa_prf/malicious_security/verifier.rs | 2 +- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index e0fe6e199..c31649a41 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -101,14 +101,8 @@ where /// that were used to generate this table. /// It is assumed that the `y_coordinates` provided to this function correspond the values of the _input_ "x coordinates" /// that were used to generate this table. - pub fn eval(&self, y_coordinates: I) -> [F; M] - where - I: IntoIterator + Copy, - I::IntoIter: ExactSizeIterator, - I::Item: Borrow, + pub fn eval(&self, y_coordinates: &[F; N]) -> [F; M] { - debug_assert_eq!(y_coordinates.into_iter().len(), N); - self.table .iter() .map(|table_row| { diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/verifier.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/verifier.rs index 9bedcc9a2..e62465b73 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/verifier.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/verifier.rs @@ -159,7 +159,7 @@ where last_array[1..last_u_or_v_values.len()].copy_from_slice(&last_u_or_v_values[1..]); // compute and output p_or_q - tables.last().unwrap().eval(last_array)[0] + tables.last().unwrap().eval(&last_array)[0] } #[cfg(all(test, unit_test))] From de591944e91e6ad92e9a08b498eb0ec1d363a10c Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 6 Nov 2024 17:51:48 -0800 Subject: [PATCH 2/3] Lagrage evaluation improvements rustc didn't optimize the `eval` function and ended up doing 131072*32*992 loop iteration per Lagrange compute. For some unknown reason to me, optimizer works only if we operate on integers, not on fields. Maybe modulo reduction is to blame --- .../ipa_prf/malicious_security/lagrange.rs | 45 ++++++++++++------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index c31649a41..e578f0a51 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -1,4 +1,4 @@ -use std::{borrow::Borrow, fmt::Debug}; +use std::fmt::Debug; use typenum::Unsigned; @@ -79,8 +79,7 @@ pub struct LagrangeTable { impl LagrangeTable where - F: Field + TryFrom, - >::Error: Debug, + F: PrimeField, { /// generates a `CanonicalLagrangeTable` from `CanoncialLagrangeDenominators` for a single output point /// The "x coordinate" of the output point is `x_output`. @@ -95,25 +94,16 @@ where impl LagrangeTable where - F: Field, + F: PrimeField, { /// This function uses the `LagrangeTable` to evaluate `polynomial` on the _output_ "x coordinates" /// that were used to generate this table. /// It is assumed that the `y_coordinates` provided to this function correspond the values of the _input_ "x coordinates" /// that were used to generate this table. - pub fn eval(&self, y_coordinates: &[F; N]) -> [F; M] - { + pub fn eval(&self, y_coordinates: &[F; N]) -> [F; M] { self.table - .iter() - .map(|table_row| { - table_row - .iter() - .zip(y_coordinates) - .fold(F::ZERO, |acc, (&base, y)| acc + base * (*y.borrow())) - }) - .collect::>() - .try_into() - .unwrap() + .each_ref() + .map(|row| dot_product(row, y_coordinates)) } /// helper function to compute a single row of `LagrangeTable` @@ -170,6 +160,29 @@ where } } +/// Computes the dot product of two arrays of the same size. +/// It is isolated from Lagrange because there could be potential SIMD optimizations used +fn dot_product(a: &[F; N], b: &[F; N]) -> F { + // Staying in integers allows rustc to optimize this code properly + // with any reasonable N, we won't run into overflow with dot product. + // (N can be as large as 2^32 and still no chance of overflow for 61 bit prime fields) + debug_assert!( + F::PRIME.into() < (1 << 64), + "The prime {} is too large for this dot product implementation", + F::PRIME.into() + ); + + let mut sum = 0; + + // I am cautious about using zip in hot code + // https://github.com/rust-lang/rust/issues/103555 + for i in 0..N { + sum += a[i].as_u128() * b[i].as_u128(); + } + + F::truncate_from(sum) +} + #[cfg(all(test, unit_test))] mod test { use std::{borrow::Borrow, fmt::Debug}; From a0c6a78fea099ce1322c2025ac923029a7a7f085 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 8 Nov 2024 15:10:47 -0800 Subject: [PATCH 3/3] Tighten up overflow check in dot product It turns out the original check was wrong estimating overflow conditions --- .../src/protocol/ipa_prf/malicious_security/lagrange.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index e578f0a51..91deec7f2 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -163,12 +163,11 @@ where /// Computes the dot product of two arrays of the same size. /// It is isolated from Lagrange because there could be potential SIMD optimizations used fn dot_product(a: &[F; N], b: &[F; N]) -> F { - // Staying in integers allows rustc to optimize this code properly - // with any reasonable N, we won't run into overflow with dot product. - // (N can be as large as 2^32 and still no chance of overflow for 61 bit prime fields) + // Staying in integers allows rustc to optimize this code properly, but puts a restriction + // on how large the prime field can be debug_assert!( - F::PRIME.into() < (1 << 64), - "The prime {} is too large for this dot product implementation", + 2 * F::BITS + N.next_power_of_two().ilog2() <= 128, + "The prime field {} is too large for this dot product implementation", F::PRIME.into() );