diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index e0fe6e199..91deec7f2 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -1,4 +1,4 @@ -use std::{borrow::Borrow, fmt::Debug}; +use std::fmt::Debug; use typenum::Unsigned; @@ -79,8 +79,7 @@ pub struct LagrangeTable { impl LagrangeTable where - F: Field + TryFrom, - >::Error: Debug, + F: PrimeField, { /// generates a `CanonicalLagrangeTable` from `CanoncialLagrangeDenominators` for a single output point /// The "x coordinate" of the output point is `x_output`. @@ -95,31 +94,16 @@ where impl LagrangeTable where - F: Field, + F: PrimeField, { /// This function uses the `LagrangeTable` to evaluate `polynomial` on the _output_ "x coordinates" /// that were used to generate this table. /// It is assumed that the `y_coordinates` provided to this function correspond the values of the _input_ "x coordinates" /// that were used to generate this table. - pub fn eval(&self, y_coordinates: I) -> [F; M] - where - I: IntoIterator + Copy, - I::IntoIter: ExactSizeIterator, - I::Item: Borrow, - { - debug_assert_eq!(y_coordinates.into_iter().len(), N); - + pub fn eval(&self, y_coordinates: &[F; N]) -> [F; M] { self.table - .iter() - .map(|table_row| { - table_row - .iter() - .zip(y_coordinates) - .fold(F::ZERO, |acc, (&base, y)| acc + base * (*y.borrow())) - }) - .collect::>() - .try_into() - .unwrap() + .each_ref() + .map(|row| dot_product(row, y_coordinates)) } /// helper function to compute a single row of `LagrangeTable` @@ -176,6 +160,28 @@ where } } +/// Computes the dot product of two arrays of the same size. +/// It is isolated from Lagrange because there could be potential SIMD optimizations used +fn dot_product(a: &[F; N], b: &[F; N]) -> F { + // Staying in integers allows rustc to optimize this code properly, but puts a restriction + // on how large the prime field can be + debug_assert!( + 2 * F::BITS + N.next_power_of_two().ilog2() <= 128, + "The prime field {} is too large for this dot product implementation", + F::PRIME.into() + ); + + let mut sum = 0; + + // I am cautious about using zip in hot code + // https://github.com/rust-lang/rust/issues/103555 + for i in 0..N { + sum += a[i].as_u128() * b[i].as_u128(); + } + + F::truncate_from(sum) +} + #[cfg(all(test, unit_test))] mod test { use std::{borrow::Borrow, fmt::Debug}; diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/verifier.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/verifier.rs index 9bedcc9a2..e62465b73 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/verifier.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/verifier.rs @@ -159,7 +159,7 @@ where last_array[1..last_u_or_v_values.len()].copy_from_slice(&last_u_or_v_values[1..]); // compute and output p_or_q - tables.last().unwrap().eval(last_array)[0] + tables.last().unwrap().eval(&last_array)[0] } #[cfg(all(test, unit_test))]