Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
danielmasny committed Feb 20, 2024
1 parent ee6d2f5 commit 9d51abd
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ where
}
*d = d.invert();
}
CanonicalLagrangeDenominator { denominator }
Self { denominator }
}
}

Expand All @@ -63,9 +63,9 @@ where
{
/// generates a `CanonicalLagrangeTable` from `CanoncialLagrangeDenominators` for a single output point
/// The "x coordinate" of the output point is `x_output`.
pub fn new(denominator: &CanonicalLagrangeDenominator<F, N>, x_output: &F) -> Self {
let mut table = denominator.denominator.clone();
compute_table_row(x_output, &mut table);
pub fn new(denominator: CanonicalLagrangeDenominator<F, N>, x_output: &F) -> Self {
let mut table = denominator.denominator;
Self::compute_table_row(x_output, &mut table);
LagrangeTable::<F, N, U1> {
table: GenericArray::from_array([table; 1]),
}
Expand Down Expand Up @@ -102,6 +102,22 @@ where
.fold(F::ZERO, |acc, (&base, &y)| acc + base * y);
}
}

/// helper function to compute a single row of `CanonicalLagrangeTable`
///
/// ## Panics
/// When the field size is too small for `N` evaluation points
fn compute_table_row(x_output: &F, table_row: &mut GenericArray<F, N>)
where
F: Field,
N: ArrayLength,
{
for (i, entry) in table_row.iter_mut().enumerate() {
for j in (0usize..N::USIZE).filter(|&j| j != i) {
*entry *= *x_output - F::try_from(j as u128).unwrap();
}
}
}
}

impl<F, N, M> From<CanonicalLagrangeDenominator<F, N>> for LagrangeTable<F, N, M>
Expand All @@ -113,34 +129,18 @@ where
fn from(value: CanonicalLagrangeDenominator<F, N>) -> Self {
// assertion that field is large enough
// also checks that `try_from` for conversions from sufficiently small `u128` to `F` do not panic
debug_assert!(F::BITS > usize::BITS - N::USIZE.leading_zeros() - M::USIZE.leading_zeros());
debug_assert!(F::BITS > usize::BITS - (N::USIZE + M::USIZE).leading_zeros());

let mut table = iter::repeat(value.denominator.clone())
.take(M::USIZE)
.collect::<GenericArray<_, _>>();
table.iter_mut().enumerate().for_each(|(i, row)| {
compute_table_row(&F::try_from((i + N::USIZE) as u128).unwrap(), row);
Self::compute_table_row(&F::try_from((i + N::USIZE) as u128).unwrap(), row);
});
LagrangeTable { table }
}
}

/// helper function to compute a single row of `CanonicalLagrangeTable`
///
/// ## Panics
/// When the field size is too small for `N` evaluation points
fn compute_table_row<F, N>(x_output: &F, table_row: &mut GenericArray<F, N>)
where
F: Field,
N: ArrayLength,
{
for (i, entry) in table_row.iter_mut().enumerate() {
for j in (0usize..N::USIZE).filter(|&j| j != i) {
*entry *= *x_output - F::try_from(j as u128).unwrap();
}
}
}

#[cfg(all(test, unit_test))]
mod test {
use std::iter;
Expand Down Expand Up @@ -216,7 +216,7 @@ mod test {
let polynomial = Polynomial::from(polynomial_monomial_form.clone());
let denominator = CanonicalLagrangeDenominator::<TestField,U32>::new();
// generate table using new
let lagrange_table = LagrangeTable::<TestField,U32,U1>::new(&denominator,&output_point);
let lagrange_table = LagrangeTable::<TestField,U32,U1>::new(denominator,&output_point);
let output = lagrange_table.eval(&polynomial);
assert_eq!(output,output_expected);
}
Expand Down

0 comments on commit 9d51abd

Please sign in to comment.