From f738a81c55b26b02a55c062098a4b31e58e40adf Mon Sep 17 00:00:00 2001 From: danielmasny Date: Fri, 23 Feb 2024 10:25:53 -0800 Subject: [PATCH] use generate --- .../ipa_prf/malicious_security/lagrange.rs | 34 ++++++------------- 1 file changed, 10 insertions(+), 24 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index 0a3ab3291..5356fb96f 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -1,6 +1,4 @@ -use std::iter; - -use generic_array::{ArrayLength, GenericArray}; +use generic_array::{sequence::GenericSequence, ArrayLength, GenericArray}; use typenum::{Unsigned, U1}; use crate::ff::{Field, PrimeField, Serializable}; @@ -43,9 +41,7 @@ where // assertion that table is not too large for the stack assert!(::Size::USIZE * N::USIZE < 2024); - let mut denominator = iter::repeat(F::ONE) - .take(N::USIZE) - .collect::>(); + let mut denominator = GenericArray::generate(|_| F::ONE); for (i, d) in denominator.iter_mut().enumerate() { for j in (0usize..N::USIZE).filter(|&j| i != j) { *d *= F::try_from(i as u128).unwrap() - F::try_from(j as u128).unwrap(); @@ -100,9 +96,7 @@ where /// This function uses the `LagrangeTable` to evaluate `polynomial` on the specified output "x coordinates" /// outputs the "y coordinates" such that `(x,y)` lies on `polynomial` pub fn eval(&self, polynomial: &Polynomial) -> GenericArray { - let mut result = iter::repeat(F::ONE) - .take(M::USIZE) - .collect::>(); + let mut result = GenericArray::generate(|_| F::ONE); self.mult_result_by_evaluation(polynomial, &mut result); result } @@ -159,9 +153,7 @@ where // assertion that table is not too large for the stack assert!(::Size::USIZE * N::USIZE * M::USIZE < 2024); - let mut table = iter::repeat(value.denominator.clone()) - .take(M::USIZE) - .collect::>(); + let mut table = GenericArray::generate(|_| value.denominator.clone()); table.iter_mut().enumerate().for_each(|(i, row)| { Self::compute_table_row(&F::try_from((i + N::USIZE) as u128).unwrap(), row); }); @@ -171,9 +163,7 @@ where #[cfg(all(test, unit_test))] mod test { - use std::iter; - - use generic_array::{ArrayLength, GenericArray}; + use generic_array::{sequence::GenericSequence, ArrayLength, GenericArray}; use proptest::{prelude::*, proptest}; use typenum::{U1, U32, U7, U8}; @@ -203,9 +193,7 @@ mod test { M: ArrayLength, { // evaluate polynomial p at evaluation_points and random point using monomial base - let mut y_values = iter::repeat(F::ZERO) - .take(M::USIZE) - .collect::>(); + let mut y_values = GenericArray::generate(|_| F::ZERO); for (x, y) in x_output.iter().zip(y_values.iter_mut()) { // monomial base, i.e. `x^k` let mut base = F::ONE; @@ -225,9 +213,8 @@ mod test { N: ArrayLength, { fn from(value: MonomialFormPolynomial) -> Self { - let canonical_points: GenericArray = (0..N::USIZE) - .map(|i| F::try_from(i as u128).unwrap()) - .collect::>(); + let canonical_points: GenericArray = + GenericArray::generate(|i| F::try_from(i as u128).unwrap()); Polynomial { y_coordinates: value.eval(&canonical_points), } @@ -264,9 +251,8 @@ mod test { coefficients: GenericArray::::from_array(input_points), }; // the canonical x coordinates are 0..15, the outputs use coordinates 8..15: - let x_coordinates_output = (8..15) - .map(|i| TestField::try_from(i).unwrap()) - .collect::>(); + let x_coordinates_output = + GenericArray::<_, U7>::generate(|i| TestField::try_from(i as u128 + 8).unwrap()); let output_expected = polynomial_monomial_form.eval(&x_coordinates_output); let polynomial = Polynomial::from(polynomial_monomial_form.clone()); let denominator = CanonicalLagrangeDenominator::::new();