Skip to content

Commit

Permalink
Merge pull request #15 from mmaker/main
Browse files Browse the repository at this point in the history
Update to arkworks 0.5 + cargo fmt
  • Loading branch information
WizardOfMenlo authored Nov 14, 2024
2 parents 3c764b4 + 141497d commit 0c0808d
Show file tree
Hide file tree
Showing 13 changed files with 230 additions and 159 deletions.
225 changes: 143 additions & 82 deletions Cargo.lock

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ edition = "2021"
default-run = "main"

[dependencies]
ark-std = "0.4"
ark-ff = { version = "0.4", features = ["asm"] }
ark-serialize = "0.4"
ark-crypto-primitives = { version = "0.4", features = ["merkle_tree"] }
ark-poly = "0.4"
ark-test-curves = { version = "0.4", features = ["bls12_381_curve"] }
ark-std = {version = "0.5", features = ["std"]}
ark-ff = { version = "0.5", features = ["asm", "std"] }
ark-serialize = "0.5"
ark-crypto-primitives = { version = "0.5", features = ["merkle_tree"] }
ark-poly = "0.5"
ark-test-curves = { version = "0.5", features = ["bls12_381_curve"] }
derivative = { version = "2", features = ["use_core"] }
blake3 = "1.5.0"
blake2 = "0.10"
Expand Down
2 changes: 1 addition & 1 deletion src/ntt/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ impl<'a, T> MatrixMut<'a, T> {

/// Split the matrix into four quadrants at the indicated `row` and `col` (meaning that in the returned 4-tuple (A,B,C,D), the matrix A is a `row`x`col` matrix)
///
/// self = [A B]
/// self = [A B]
/// [C D]
pub fn split_quadrants(self, row: usize, col: usize) -> (Self, Self, Self, Self) {
let (u, l) = self.split_vertical(row); // split into upper and lower parts
Expand Down
5 changes: 2 additions & 3 deletions src/ntt/ntt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ static ENGINE_CACHE: LazyLock<Mutex<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>
/// Enginge for computing NTTs over arbitrary fields.
/// Assumes the field has large two-adicity.
pub struct NttEngine<F: Field> {
order: usize, // order of omega_orger
omega_order: F, // primitive order'th root.
order: usize, // order of omega_orger
omega_order: F, // primitive order'th root.

// Roots of small order (zero if unavailable). The naming convention is that omega_foo has order foo.
half_omega_3_1_plus_2: F, // ½(ω₃ + ω₃²)
Expand Down Expand Up @@ -58,7 +58,6 @@ pub fn intt<F: FftField>(values: &mut [F]) {
NttEngine::<F>::new_from_cache().intt(values);
}


/// Compute the inverse NTT of multiple slice of field elements, each of size `size`, without the 1/n scaling factor and using a cached engine.
pub fn intt_batch<F: FftField>(values: &mut [F], size: usize) {
NttEngine::<F>::new_from_cache().intt_batch(values, size);
Expand Down
71 changes: 38 additions & 33 deletions src/ntt/transpose.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pub fn transpose<F: Sized + Copy + Send>(matrix: &mut [F], rows: usize, cols: us
}

// The following function have both a parallel and a non-parallel implementation.
// We fuly split those in a parallel and a non-parallel functions (rather than using #[cfg] within a single function)
// We fuly split those in a parallel and a non-parallel functions (rather than using #[cfg] within a single function)
// and have main entry point fun that just calls the appropriate version (either fun_parallel or fun_not_parallel).
// The sole reason is that this simplifies unit tests: We otherwise would need to build twice to cover both cases.
// For effiency, we assume the compiler inlines away the extra "indirection" that we add to the entry point function.
Expand Down Expand Up @@ -83,7 +83,6 @@ fn transpose_copy_parallel<'a, 'b, F: Sized + Copy + Send>(
}
}


/// Sets `dst` to the transpose of `src`. This will panic if the sizes of `src` and `dst` are not compatible.
/// This is the non-parallel version
fn transpose_copy_not_parallel<'a, 'b, F: Sized + Copy>(
Expand Down Expand Up @@ -308,7 +307,7 @@ mod tests {
funs.push(&transpose_copy_parallel::<Pair>);

for f in funs {
let rows: usize = workload_size::<Pair>() + 1; // intentionally not a power of two: The function is not described as only working for powers of two.
let rows: usize = workload_size::<Pair>() + 1; // intentionally not a power of two: The function is not described as only working for powers of two.
let columns: usize = 4;
let mut srcarray = make_example_matrix(rows, columns);
let mut dstarray: Vec<(usize, usize)> = vec![(0, 0); rows * columns];
Expand Down Expand Up @@ -338,7 +337,6 @@ mod tests {
funs.push(&transpose_square_swap_parallel::<Triple>);

for f in funs {

// Set rows manually. We want to be sure to trigger the actual recursion.
// (Computing this from workload_size was too much hassle.)
let rows = 1024; // workload_size::<Triple>();
Expand Down Expand Up @@ -370,43 +368,44 @@ mod tests {
}

#[test]
fn test_transpose_square(){
let mut funs: Vec<&dyn for <'a> Fn(MatrixMut<'a,_>)> = vec![
&transpose_square::<Pair>, &transpose_square_parallel::<Pair>
fn test_transpose_square() {
let mut funs: Vec<&dyn for<'a> Fn(MatrixMut<'a, _>)> = vec![
&transpose_square::<Pair>,
&transpose_square_parallel::<Pair>,
];
#[cfg(feature="parallel")]
#[cfg(feature = "parallel")]
funs.push(&transpose_square::<Pair>);
for f in funs{
for f in funs {
// Set rows manually. We want to be sure to trigger the actual recursion.
// (Computing this from workload_size was too much hassle.)
let size = 1024;
let size = 1024;
assert!(size * size > 2 * workload_size::<Pair>());

let mut example = make_example_matrix(size, size);
let view = MatrixMut::from_mut_slice(&mut example, size, size);
f(view);
let view = MatrixMut::from_mut_slice(&mut example, size, size);
for i in 0..size{
for j in 0..size{
assert_eq!(view[(i,j)], (j,i));
for i in 0..size {
for j in 0..size {
assert_eq!(view[(i, j)], (j, i));
}
}
}
}

#[test]
fn test_transpose(){
fn test_transpose() {
let size = 1024;

// rectangular matrix:
let rows = size;
let cols = 16;
let mut example = make_example_matrix(rows, cols);
transpose(&mut example, rows, cols);
let view = MatrixMut::from_mut_slice(&mut example, cols, rows);
for i in 0..cols{
for j in 0..rows{
assert_eq!(view[(i,j)], (j,i));
for i in 0..cols {
for j in 0..rows {
assert_eq!(view[(i, j)], (j, i));
}
}

Expand All @@ -416,9 +415,9 @@ mod tests {
let mut example = make_example_matrix(rows, cols);
transpose(&mut example, rows, cols);
let view = MatrixMut::from_mut_slice(&mut example, cols, rows);
for i in 0..cols{
for j in 0..rows{
assert_eq!(view[(i,j)], (j,i));
for i in 0..cols {
for j in 0..rows {
assert_eq!(view[(i, j)], (j, i));
}
}

Expand All @@ -428,11 +427,15 @@ mod tests {
let cols = 16;
let mut example = make_example_matrices(rows, cols, number_of_matrices);
transpose(&mut example, rows, cols);
for index in 0..number_of_matrices{
let view = MatrixMut::from_mut_slice(&mut example[index*rows*cols..(index+1)*rows*cols], cols, rows);
for i in 0..cols{
for j in 0..rows{
assert_eq!(view[(i,j)], (index,j,i));
for index in 0..number_of_matrices {
let view = MatrixMut::from_mut_slice(
&mut example[index * rows * cols..(index + 1) * rows * cols],
cols,
rows,
);
for i in 0..cols {
for j in 0..rows {
assert_eq!(view[(i, j)], (index, j, i));
}
}
}
Expand All @@ -443,15 +446,17 @@ mod tests {
let cols = size;
let mut example = make_example_matrices(rows, cols, number_of_matrices);
transpose(&mut example, rows, cols);
for index in 0..number_of_matrices{
let view = MatrixMut::from_mut_slice(&mut example[index*rows*cols..(index+1)*rows*cols], cols, rows);
for i in 0..cols{
for j in 0..rows{
assert_eq!(view[(i,j)], (index,j,i));
for index in 0..number_of_matrices {
let view = MatrixMut::from_mut_slice(
&mut example[index * rows * cols..(index + 1) * rows * cols],
cols,
rows,
);
for i in 0..cols {
for j in 0..rows {
assert_eq!(view[(i, j)], (index, j, i));
}
}
}


}
}
27 changes: 17 additions & 10 deletions src/poly_utils/coeffs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@ use {
std::mem::size_of,
};


/// A CoefficientList models a (multilinear) polynomial in `num_variable` variables in coefficient form.
///
///
/// The order of coefficients follows the following convention: coeffs[j] corresponds to the monomial
/// determined by the binary decomposition of j with an X_i-variable present if the
/// i-th highest-significant bit among the `num_variables` least significant bits is set.
///
///
/// e.g. is `num_variables` is 3 with variables X_0, X_1, X_2, then
/// - coeffs[0] is the coefficient of 1
/// - coeffs[1] is the coefficient of X_2
Expand All @@ -23,7 +22,7 @@ use {
#[derive(Debug, Clone)]
pub struct CoefficientList<F> {
coeffs: Vec<F>, // list of coefficients. For multilinear polynomials, we have coeffs.len() == 1 << num_variables.
num_variables: usize, // number of variables
num_variables: usize, // number of variables
}

impl<F> CoefficientList<F>
Expand Down Expand Up @@ -59,12 +58,16 @@ where
// NOTE (Gotti): This algorithm uses 2^{n+1}-1 multiplications for a polynomial in n variables.
// You could do with 2^{n}-1 by just doing a + x * b (and not forwarding scalar through the recursion at all).
// The difference comes from multiplications by E::ONE at the leaves of the recursion tree.

// recursive helper function for polynomial evaluation:
// Note that eval(coeffs, [X_0, X1,...]) = eval(coeffs_left, [X_1,...]) + X_0 * eval(coeffs_right, [X_1,...])

/// Recursively compute scalar * poly_eval(coeffs;eval) where poly_eval interprets coeffs as a polynomial and eval are the evaluation points.
fn eval_extension_nonparallel<E: Field<BasePrimeField = F>>(coeff: &[F], eval: &[E], scalar: E) -> E {
fn eval_extension_nonparallel<E: Field<BasePrimeField = F>>(
coeff: &[F],
eval: &[E],
scalar: E,
) -> E {
debug_assert_eq!(coeff.len(), 1 << eval.len());
if let Some((&x, tail)) = eval.split_first() {
let (low, high) = coeff.split_at(coeff.len() / 2);
Expand All @@ -77,7 +80,11 @@ where
}

#[cfg(feature = "parallel")]
fn eval_extension_parallel<E: Field<BasePrimeField = F>>(coeff: &[F], eval: &[E], scalar: E) -> E {
fn eval_extension_parallel<E: Field<BasePrimeField = F>>(
coeff: &[F],
eval: &[E],
scalar: E,
) -> E {
const PARALLEL_THRESHOLD: usize = 10;
debug_assert_eq!(coeff.len(), 1 << eval.len());
if let Some((&x, tail)) = eval.split_first() {
Expand All @@ -98,7 +105,7 @@ where
}

/// Evaluate self at `point`, where `point` is from a field extension extending the field over which the polynomial `self` is defined.
///
///
/// Note that we only support the case where F is a prime field.
pub fn evaluate_at_extension<E: Field<BasePrimeField = F>>(
&self,
Expand Down Expand Up @@ -226,7 +233,7 @@ where
{
/// fold folds the polynomial at the provided folding_randomness.
///
/// Namely, when self is interpreted as a multi-linear polynomial f in X_0, ..., X_{n-1},
/// Namely, when self is interpreted as a multi-linear polynomial f in X_0, ..., X_{n-1},
/// it partially evaluates f at the provided `folding_randomness`.
/// Our ordering convention is to evaluate at the higher indices, i.e. we return f(X_0,X_1,..., folding_randomness[0], folding_randomness[1],...)
pub fn fold(&self, folding_randomness: &MultilinearPoint<F>) -> Self {
Expand Down
4 changes: 2 additions & 2 deletions src/poly_utils/evals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use super::{sequential_lag_poly::LagrangePolynomialIterator, MultilinearPoint};

/// An EvaluationsList models a multi-linear polynomial f in `num_variables`
/// unknowns, stored via their evaluations at {0,1}^{num_variables}
///
///
/// `evals` stores the evaluation in lexicographic order.
#[derive(Debug)]
pub struct EvaluationsList<F> {
Expand All @@ -19,7 +19,7 @@ where
F: Field,
{
/// Constructs a EvaluationList from the given vector `eval` of evaluations.
///
///
/// The provided `evals` is supposed to be the list of evaluations, where the ordering of evaluation points in {0,1}^n
/// is lexicographic.
pub fn new(evals: Vec<F>) -> Self {
Expand Down
7 changes: 3 additions & 4 deletions src/poly_utils/hypercube.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]

// TODO (Gotti): Should pos rather be a u64? usize is platform-dependent, giving a platform-dependent limit on the number of variables.
// num_variables may be smaller as well.

// NOTE: Conversion BinaryHypercube <-> MultilinearPoint is Big Endian, using only the num_variables least significant bits of the number stored inside BinaryHypercube.

/// point on the binary hypercube {0,1}^n for some n.
///
///
/// The point is encoded via the n least significant bits of a usize in big endian order and we do not store n.
pub struct BinaryHypercubePoint(pub usize);

/// BinaryHypercube is an Iterator that is used to range over the points of the hypercube {0,1}^n, where n == `num_variables`
pub struct BinaryHypercube {
pos: usize, // current position, encoded via the bits of pos
pos: usize, // current position, encoded via the bits of pos
num_variables: usize, // dimension of the hypercube
}

impl BinaryHypercube {
pub fn new(num_variables: usize) -> Self {
debug_assert!(num_variables < usize::BITS as usize); // Note that we need strictly smaller, since some code would overflow otherwise.
debug_assert!(num_variables < usize::BITS as usize); // Note that we need strictly smaller, since some code would overflow otherwise.
BinaryHypercube {
pos: 0,
num_variables,
Expand Down
4 changes: 2 additions & 2 deletions src/poly_utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ where
self.0.len()
}

// NOTE: Conversion BinaryHypercube <-> MultilinearPoint converts a
// NOTE: Conversion BinaryHypercube <-> MultilinearPoint converts a
// multilinear point (x1,x2,...,x_n) into the number with bit-pattern 0...0 x_1 x_2 ... x_n, provided all x_i are in {0,1}.
// That means we pad zero bits in BinaryHypercube from the msb end and use big-endian for the actual conversion.

Expand Down Expand Up @@ -112,7 +112,7 @@ where
{
let mut point = point.0;
let n_variables = coords.n_variables();
assert!(point < (1 << n_variables)); // check that the lengths of coords and point match.
assert!(point < (1 << n_variables)); // check that the lengths of coords and point match.

let mut acc = F::ONE;

Expand Down
6 changes: 3 additions & 3 deletions src/poly_utils/sequential_lag_poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ use super::{hypercube::BinaryHypercubePoint, MultilinearPoint};
///
/// This means that y == eq_poly(point, x)
pub struct LagrangePolynomialIterator<F: Field> {
last_position: Option<usize>, // the previously output BinaryHypercubePoint (encoded as usize). None before the first output.
point: Vec<F>, // stores a copy of the `point` given when creating the iterator. For easier(?) bit-fiddling, we store in in reverse order.
last_position: Option<usize>, // the previously output BinaryHypercubePoint (encoded as usize). None before the first output.
point: Vec<F>, // stores a copy of the `point` given when creating the iterator. For easier(?) bit-fiddling, we store in in reverse order.
point_negated: Vec<F>, // stores the precomputed values 1-point[i] in the same ordering as point.
/// stack Stores the n+1 values (in order) 1, y_1, y_1*y_2, y_1*y_2*y_3, ..., y_1*...*y_n for the previously output y.
/// Before the first iteration (if last_position == None), it stores the values for the next (i.e. first) output instead.
stack: Vec<F>,
num_variables: usize, // dimension
num_variables: usize, // dimension
}

impl<F: Field> LagrangePolynomialIterator<F> {
Expand Down
12 changes: 6 additions & 6 deletions src/sumcheck/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
// Stored in evaluation form
#[derive(Debug, Clone)]
pub struct SumcheckPolynomial<F> {
n_variables: usize, // number of variables;
n_variables: usize, // number of variables;
// evaluations has length 3^{n_variables}
// The order in which it is stored is such that evaluations[i]
// corresponds to the evaluation at utils::base_decomposition(i, 3, n_variables),
Expand All @@ -29,7 +29,7 @@ where
}
}

/// Returns the vector of evaluations at {0,1,2}^n_variables of the polynomial f
/// Returns the vector of evaluations at {0,1,2}^n_variables of the polynomial f
/// in the following order: [f(0,0,..,0), f(0,0,..,1), f(0,0,...,2), f(0,0,...,1,0), ...]
/// (i.e. lexicographic wrt. to the evaluation points.
pub fn evaluations(&self) -> &[F] {
Expand All @@ -40,7 +40,7 @@ where
// TODO(Gotti): Make more efficient; the base_decomposition and filtering is unneccessary.

/// Returns the sum of evaluations of f, when summed only over {0,1}^n_variables
///
///
/// (and not over {0,1,2}^n_variable)
pub fn sum_over_hypercube(&self) -> F {
let num_evaluation_points = 3_usize.pow(self.n_variables as u32);
Expand All @@ -59,10 +59,10 @@ where
}

/// evaluates the polynomial at an arbitrary point, not neccessarily in {0,1,2}^n_variables.
///
/// We assert that point.n_variables() == self.n_variables
///
/// We assert that point.n_variables() == self.n_variables
pub fn evaluate_at_point(&self, point: &MultilinearPoint<F>) -> F {
assert!(point.n_variables() == self.n_variables);
assert!(point.n_variables() == self.n_variables);
let num_evaluation_points = 3_usize.pow(self.n_variables as u32);

let mut evaluation = F::ZERO;
Expand Down
Loading

0 comments on commit 0c0808d

Please sign in to comment.