diff --git a/Cargo.toml b/Cargo.toml index d5809c1..9f4612a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,9 +30,10 @@ rayon = { version = "1.10.0", optional = true } debug = true [features] -default = ["parallel", "dep:rayon"] +default = ["parallel"] #default = [] parallel = [ + "dep:rayon", "ark-poly/parallel", "ark-ff/parallel", "ark-crypto-primitives/parallel", diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 0248157..039b38f 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -1,3 +1,2 @@ pub mod fields; pub mod merkle_tree; -pub mod ntt; diff --git a/src/lib.rs b/src/lib.rs index 8cbc934..153baf3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,7 @@ pub mod cmdline_utils; pub mod crypto; // Crypto utils pub mod domain; // Domain that we are evaluating over pub mod fs_utils; +pub mod ntt; pub mod parameters; pub mod poly_utils; // Utils for polynomials pub mod sumcheck; // Sumcheck specialised diff --git a/src/ntt/matrix.rs b/src/ntt/matrix.rs new file mode 100644 index 0000000..4bbd1eb --- /dev/null +++ b/src/ntt/matrix.rs @@ -0,0 +1,170 @@ +//! Minimal matrix class that supports strided access. +//! This abstracts over the unsafe pointer arithmetic required for transpose-like algorithms. + +#![allow(unsafe_code)] + +use std::{ + marker::PhantomData, + ops::{Index, IndexMut}, + ptr, slice, +}; + +/// Mutable reference to a matrix. +/// +/// The invariant this data structure maintains is that `data` has lifetime +/// `'a` and points to a collection of `rows` rowws, at intervals `row_stride`, +/// each of length `cols`. +pub struct MatrixMut<'a, T> { + data: *mut T, + rows: usize, + cols: usize, + row_stride: usize, + _lifetime: PhantomData<&'a mut T>, +} + +unsafe impl<'a, T: Send> Send for MatrixMut<'_, T> {} + +unsafe impl<'a, T: Sync> Sync for MatrixMut<'_, T> {} + +impl<'a, T> MatrixMut<'a, T> { + pub fn from_mut_slice(slice: &'a mut [T], rows: usize, cols: usize) -> Self { + assert_eq!(slice.len(), rows * cols); + // Safety: The input slice is valid for the lifetime `'a` and has + // `rows` contiguous rows of length `cols`. + Self { + data: slice.as_mut_ptr(), + rows, + cols, + row_stride: cols, + _lifetime: PhantomData, + } + } + + pub fn rows(&self) -> usize { + self.rows + } + + pub fn cols(&self) -> usize { + self.cols + } + + pub fn is_square(&self) -> bool { + self.rows == self.cols + } + + pub fn row(&mut self, row: usize) -> &mut [T] { + assert!(row < self.rows); + // Safety: The structure invariant guarantees that at offset `row * self.row_stride` + // there is valid data of length `self.cols`. + unsafe { slice::from_raw_parts_mut(self.data.add(row * self.row_stride), self.cols) } + } + + /// Split the matrix into two vertically. + /// + /// [A] = self + /// [B] + pub fn split_vertical(self, row: usize) -> (Self, Self) { + assert!(row <= self.rows); + ( + Self { + data: self.data, + rows: row, + cols: self.cols, + row_stride: self.row_stride, + _lifetime: PhantomData, + }, + Self { + data: unsafe { self.data.add(row * self.row_stride) }, + rows: self.rows - row, + cols: self.cols, + row_stride: self.row_stride, + _lifetime: PhantomData, + }, + ) + } + + /// Split the matrix into two horizontally. + /// + /// [A B] = self + pub fn split_horizontal(self, col: usize) -> (Self, Self) { + assert!(col <= self.cols); + ( + // Safety: This reduces the number of cols, keeping all else the same. 
+ Self { + data: self.data, + rows: self.rows, + cols: col, + row_stride: self.row_stride, + _lifetime: PhantomData, + }, + // Safety: This reduces the number of cols and offsets and, keeping all else the same. + Self { + data: unsafe { self.data.add(col) }, + rows: self.rows, + cols: self.cols - col, + row_stride: self.row_stride, + _lifetime: PhantomData, + }, + ) + } + + /// Split the matrix into four quadrants. + /// + /// [A B] = self + /// [C D] + pub fn split_quadrants(self, row: usize, col: usize) -> (Self, Self, Self, Self) { + let (u, d) = self.split_vertical(row); + let (a, b) = u.split_horizontal(col); + let (c, d) = d.split_horizontal(col); + (a, b, c, d) + } + + /// Swap two elements in the matrix. + pub fn swap(&mut self, a: (usize, usize), b: (usize, usize)) { + if a != b { + unsafe { + let a = self.ptr_at_mut(a.0, a.1); + let b = self.ptr_at_mut(b.0, b.1); + ptr::swap_nonoverlapping(a, b, 1) + } + } + } + + unsafe fn ptr_at(&self, row: usize, col: usize) -> *const T { + assert!(row < self.rows); + assert!(col < self.cols); + // Safety: The structure invariant guarantees that at offset `row * self.row_stride + col` + // there is valid data. + self.data.add(row * self.row_stride + col) + } + + unsafe fn ptr_at_mut(&mut self, row: usize, col: usize) -> *mut T { + assert!(row < self.rows); + assert!(col < self.cols); + // Safety: The structure invariant guarantees that at offset `row * self.row_stride + col` + // there is valid data. + self.data.add(row * self.row_stride + col) + } +} + +impl Index<(usize, usize)> for MatrixMut<'_, T> { + type Output = T; + + fn index(&self, (row, col): (usize, usize)) -> &T { + assert!(row < self.rows); + assert!(col < self.cols); + // Safety: The structure invariant guarantees that at offset `row * self.row_stride + col` + // there is valid data. + unsafe { &*self.ptr_at(row, col) } + } +} + +impl IndexMut<(usize, usize)> for MatrixMut<'_, T> { + fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut T { + assert!(row < self.rows); + assert!(col < self.cols); + // Safety: The structure invariant guarantees that at offset `row * self.row_stride + col` + // there is valid data. + unsafe { &mut *self.ptr_at_mut(row, col) } + } +} diff --git a/src/ntt/mod.rs b/src/ntt/mod.rs new file mode 100644 index 0000000..3fd176a --- /dev/null +++ b/src/ntt/mod.rs @@ -0,0 +1,61 @@ +//! NTT and related algorithms. + +mod matrix; +mod ntt; +mod transpose; +mod utils; +mod wavelet; + +use self::matrix::MatrixMut; +use ark_ff::FftField; + +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +pub use self::{ + ntt::{intt, intt_batch, ntt, ntt_batch}, + transpose::transpose, + wavelet::wavelet_transform, +}; + +/// RS encode at a rate 1/`expansion`. +pub fn expand_from_coeff(coeffs: &[F], expansion: usize) -> Vec { + let engine = ntt::NttEngine::::new_from_cache(); + let expanded_size = coeffs.len() * expansion; + let mut result = Vec::with_capacity(expanded_size); + // Note: We can also zero-extend the coefficients and do a larger NTT. + // But this is more efficient. + + // Do coset NTT. 
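// Illustrative sketch (not part of this change): the coset expansion below is
// equivalent to the following naive O(N^2) reference, assuming the engine's root
// for the expanded domain agrees with `F::get_root_of_unity`. The helper name is
// hypothetical and only meant for checking the fast path on small inputs.
fn expand_from_coeff_naive<F: ark_ff::FftField>(coeffs: &[F], expansion: usize) -> Vec<F> {
    let n = coeffs.len() * expansion;
    let omega = F::get_root_of_unity(n as u64).expect("no root of unity of order n");
    (0..n)
        .map(|i| {
            // Horner evaluation of the polynomial at omega^i.
            let x = omega.pow([i as u64]);
            coeffs.iter().rev().fold(F::ZERO, |acc, c| acc * x + *c)
        })
        .collect()
}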
+ let root = engine.root(expanded_size); + result.extend_from_slice(coeffs); + #[cfg(not(feature = "parallel"))] + for i in 1..expansion { + let root = root.pow([i as u64]); + let mut offset = F::ONE; + result.extend(coeffs.iter().map(|x| { + let val = *x * offset; + offset *= root; + val + })); + } + #[cfg(feature = "parallel")] + result.par_extend((1..expansion).into_par_iter().flat_map(|i| { + let root_i = root.pow([i as u64]); + coeffs + .par_iter() + .enumerate() + .map_with(F::ZERO, move |root_j, (j, coeff)| { + if root_j.is_zero() { + *root_j = root_i.pow([j as u64]); + } else { + *root_j *= root_i; + } + *coeff * *root_j + }) + })); + + ntt_batch(&mut result, coeffs.len()); + transpose(&mut result, expansion, coeffs.len()); + result +} diff --git a/src/crypto/ntt.rs b/src/ntt/ntt.rs similarity index 63% rename from src/crypto/ntt.rs rename to src/ntt/ntt.rs index bdc6c8d..3a86d1a 100644 --- a/src/crypto/ntt.rs +++ b/src/ntt/ntt.rs @@ -3,25 +3,19 @@ //! Implements the √N Cooley-Tukey six-step algorithm to achieve parallelism with good locality. //! A global cache is used for twiddle factors. +use super::{ + transpose, + utils::{lcm, sqrt_factor, workload_size}, +}; use ark_ff::{FftField, Field}; use std::{ any::{Any, TypeId}, collections::HashMap, - mem::swap, sync::{Arc, LazyLock, Mutex, RwLock, RwLockReadGuard}, }; #[cfg(feature = "parallel")] -use { - rayon::prelude::*, - std::{cmp::max, mem::size_of}, -}; - -/// Target thread workload size for parallel NTTs in bytes. -/// Should ideally be a multiple of a cache line (64 bytes) -/// and close to the L1 cache size (32 KB). -#[cfg(feature = "parallel")] -const WORKLOAD_SIZE: usize = 1 << 15; // 32 KB +use {rayon::prelude::*, std::cmp::max}; /// Global cache for NTT engines, indexed by field. // TODO: Skip `LazyLock` when `HashMap::with_hasher` becomes const. @@ -97,9 +91,6 @@ impl NttEngine { } impl NttEngine { - #[cfg(feature = "parallel")] - const WORKLOAD_SIZE: usize = WORKLOAD_SIZE / (2 * size_of::()); - pub fn new(order: usize, omega_order: F) -> Self { assert!(order.trailing_zeros() > 0, "Order must be a power of 2."); // TODO: Assert that omega_order factors into 2s and 3s. @@ -191,6 +182,7 @@ impl NttEngine { // Race condition: check if another thread updated the cache. if roots.is_empty() || roots.len() % order != 0 { // Compute minimal size to support all sizes seen so far. + // TODO: Do we really need all of these? Can we leverage omege_2 = -1? let size = if roots.is_empty() { order } else { @@ -201,11 +193,23 @@ impl NttEngine { // Compute powers of roots of unity. let root = self.root(size); - let mut root_i = F::ONE; - while roots.len() < size { - roots.push(root_i); - root_i *= root; + #[cfg(not(feature = "parallel"))] + { + let mut root_i = F::ONE; + for _ in 0..size { + roots.push(root_i); + root_i *= root; + } } + #[cfg(feature = "parallel")] + roots.par_extend((0..size).into_par_iter().map_with(F::ZERO, |root_i, i| { + if root_i.is_zero() { + *root_i = root.pow([i as u64]); + } else { + *root_i *= root; + } + *root_i + })); } // Back to read lock. drop(roots); @@ -218,39 +222,87 @@ impl NttEngine { /// Compute NTTs in place by splititng into two factors. /// Recurses using the sqrt(N) Cooley-Tukey Six step NTT algorithm. 
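// Concretely, with the length-`size` chunk viewed as an n1 x n2 row-major matrix
// (n1 = sqrt_factor(size), n2 = size / n1) and ω a root of unity of order `size`,
// the six steps below are:
//   1. transpose                     n1 x n2  ->  n2 x n1
//   2. n2 NTTs of length n1          (the columns of the original layout)
//   3. transpose back                n2 x n1  ->  n1 x n2
//   4. scale entry (i, j) by the twiddle ω^(i·j)   (`apply_twiddles`)
//   5. n1 NTTs of length n2          (the rows)
//   6. transpose                     n1 x n2  ->  n2 x n1
// For a natural-order input this produces the NTT in natural order, with no
// separate bit-reversal pass.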
fn ntt_recurse(&self, values: &mut [F], roots: &[F], size: usize) { + debug_assert_eq!(values.len() % size, 0); let n1 = sqrt_factor(size); let n2 = size / n1; - let step = roots.len() / size; transpose(values, n1, n2); self.ntt_dispatch(values, roots, n1); transpose(values, n2, n1); // TODO: When (n1, n2) are coprime we can use the // Good-Thomas NTT algorithm and avoid the twiddle loop. - // TODO: Parallelize the twiddle loop when values.len() is large. - for values in values.chunks_exact_mut(size) { - for i in 1..n1 { + self.apply_twiddles(values, roots, n1, n2); + self.ntt_dispatch(values, roots, n2); + transpose(values, n1, n2); + } + + #[cfg(not(feature = "parallel"))] + fn apply_twiddles(&self, values: &mut [F], roots: &[F], rows: usize, cols: usize) { + debug_assert_eq!(values.len() % (rows * cols), 0); + let step = roots.len() / (rows * cols); + for values in values.chunks_exact_mut(rows * cols) { + for (i, row) in values.chunks_exact_mut(cols).enumerate().skip(1) { let step = (i * step) % roots.len(); let mut index = step; - for j in 1..n2 { + for value in row.iter_mut().skip(1) { index %= roots.len(); - values[i * n2 + j] *= roots[index]; + *value *= roots[index]; index += step; } } } - self.ntt_dispatch(values, roots, n2); - transpose(values, n1, n2); + } + + #[cfg(feature = "parallel")] + fn apply_twiddles(&self, values: &mut [F], roots: &[F], rows: usize, cols: usize) { + debug_assert_eq!(values.len() % (rows * cols), 0); + if values.len() > workload_size::() { + let size = rows * cols; + if values.len() != size { + let workload_size = size * max(1, workload_size::() / size); + values.par_chunks_mut(workload_size).for_each(|values| { + self.apply_twiddles(values, roots, rows, cols); + }); + } else { + let step = roots.len() / (rows * cols); + values + .par_chunks_exact_mut(cols) + .enumerate() + .skip(1) + .for_each(|(i, row)| { + let step = (i * step) % roots.len(); + let mut index = step; + for value in row.iter_mut().skip(1) { + index %= roots.len(); + *value *= roots[index]; + index += step; + } + }); + } + } else { + let step = roots.len() / (rows * cols); + for values in values.chunks_exact_mut(rows * cols) { + for (i, row) in values.chunks_exact_mut(cols).enumerate().skip(1) { + let step = (i * step) % roots.len(); + let mut index = step; + for value in row.iter_mut().skip(1) { + index %= roots.len(); + *value *= roots[index]; + index += step; + } + } + } + } } fn ntt_dispatch(&self, values: &mut [F], roots: &[F], size: usize) { debug_assert_eq!(values.len() % size, 0); debug_assert_eq!(roots.len() % size, 0); #[cfg(feature = "parallel")] - if values.len() > Self::WORKLOAD_SIZE && values.len() != size { + if values.len() > workload_size::() && values.len() != size { // Multiple NTTs, compute in parallel. // Work size is largest multiple of `size` smaller than `WORKLOAD_SIZE`. - let workload_size = size * max(1, Self::WORKLOAD_SIZE / size); + let workload_size = size * max(1, workload_size::() / size); return values.par_chunks_mut(workload_size).for_each(|values| { self.ntt_dispatch(values, roots, size); }); @@ -350,184 +402,3 @@ impl NttEngine { } } } - -/// Fast Wavelet Transform. -/// -/// The input slice must have a length that is a power of two. 
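// Illustrative sketch (not part of this change): for a single rows x cols block,
// `apply_twiddles` above scales entry (i, j) by ω^(i·j), where ω = roots[step] has
// order rows·cols; the incremental `index` bookkeeping is equivalent to this
// hypothetical scalar reference.
fn apply_twiddles_reference<F: ark_ff::Field>(values: &mut [F], roots: &[F], cols: usize) {
    // For one block, values.len() == rows * cols.
    let step = roots.len() / values.len();
    for (i, row) in values.chunks_exact_mut(cols).enumerate() {
        for (j, value) in row.iter_mut().enumerate() {
            *value *= roots[(i * j * step) % roots.len()];
        }
    }
}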
-/// Recursively applies the kernel -/// [1 0] -/// [1 1] -pub fn wavelet_transform(values: &mut [F]) { - debug_assert!(values.len().is_power_of_two()); - wavelet_transform_batch(values, values.len()) -} - -pub fn wavelet_transform_batch(values: &mut [F], size: usize) { - debug_assert_eq!(values.len() % size, 0); - debug_assert!(size.is_power_of_two()); - #[cfg(feature = "parallel")] - if values.len() > NttEngine::::WORKLOAD_SIZE && values.len() != size { - // Multiple wavelet transforms, compute in parallel. - // Work size is largest multiple of `size` smaller than `WORKLOAD_SIZE`. - let workload_size = size * max(1, NttEngine::::WORKLOAD_SIZE / size); - return values.par_chunks_mut(workload_size).for_each(|values| { - wavelet_transform_batch(values, size); - }); - } - match size { - 0 | 1 => {} - 2 => { - for v in values.chunks_exact_mut(2) { - v[1] += v[0] - } - } - 4 => { - for v in values.chunks_exact_mut(4) { - v[1] += v[0]; - v[3] += v[2]; - v[2] += v[0]; - v[3] += v[1]; - } - } - 8 => { - for v in values.chunks_exact_mut(8) { - v[1] += v[0]; - v[3] += v[2]; - v[2] += v[0]; - v[3] += v[1]; - v[5] += v[4]; - v[7] += v[6]; - v[6] += v[4]; - v[7] += v[5]; - v[4] += v[0]; - v[5] += v[1]; - v[6] += v[2]; - v[7] += v[3]; - } - } - 16 => { - for v in values.chunks_exact_mut(16) { - for v in v.chunks_exact_mut(4) { - v[1] += v[0]; - v[3] += v[2]; - v[2] += v[0]; - v[3] += v[1]; - } - let (a, v) = v.split_at_mut(4); - let (b, v) = v.split_at_mut(4); - let (c, d) = v.split_at_mut(4); - for i in 0..4 { - b[i] += a[i]; - d[i] += c[i]; - c[i] += a[i]; - d[i] += b[i]; - } - } - } - n => { - let n1 = 1 << (n.trailing_zeros() / 2); - let n2 = n / n1; - wavelet_transform_batch(values, n1); - transpose(values, n2, n1); - wavelet_transform_batch(values, n2); - transpose(values, n1, n2); - } - } -} - -/// Transpose a matrix in-place. -/// Will batch transpose multiple matrices if the length of the slice is a multiple of rows * cols. -pub fn transpose(matrix: &mut [F], rows: usize, cols: usize) { - debug_assert_eq!(matrix.len() % rows * cols, 0); - if rows == cols { - for matrix in matrix.chunks_exact_mut(rows * cols) { - transpose_square(matrix, rows, cols); - } - } else { - // TODO: Special case for rows = 2 * cols and cols = 2 * rows. - let mut scratch = vec![F::ZERO; rows * cols]; - for matrix in matrix.chunks_exact_mut(rows * cols) { - scratch.copy_from_slice(matrix); - for i in 0..rows { - for j in 0..cols { - matrix[j * rows + i] = scratch[i * cols + j]; - } - } - } - } -} - -// Transpose a square power-of-two matrix in-place. -fn transpose_square(matrix: &mut [F], size: usize, stride: usize) { - debug_assert!(matrix.len() >= (size - 1) * stride + size); - debug_assert!(size.is_power_of_two()); - if size * size > NttEngine::::WORKLOAD_SIZE { - // Recurse into quadrants. - // This results in a cache-oblivious algorithm. - let n = size / 2; - let (upper, lower) = matrix.split_at_mut(n * stride); - // Ideally we'd parallelize this, but its not possible to - // express the strided matrices without unsafe code. - transpose_square(upper, n, stride); - transpose_square_swap(&mut upper[n..], lower, n, stride); - transpose_square(&mut lower[n..], n, stride); - } else { - for i in 0..size { - for j in (i + 1)..size { - matrix.swap(i * stride + j, j * stride + i); - } - } - } -} - -/// Transpose and swap two square power-of-two size matrices. 
-fn transpose_square_swap(a: &mut [F], b: &mut [F], size: usize, stride: usize) { - debug_assert!(a.len() >= (size - 1) * stride + size); - debug_assert!(b.len() >= (size - 1) * stride + size); - debug_assert!(size.is_power_of_two()); - if size * size > NttEngine::::WORKLOAD_SIZE { - // Recurse into quadrants. - // This results in a cache-oblivious algorithm. - let n = size / 2; - let (a_upper, a_lower) = a.split_at_mut(n * stride); - let (b_upper, b_lower) = b.split_at_mut(n * stride); - // Ideally we'd parallelize this, but its not possible to - // express the strided matrices without unsafe code. - transpose_square_swap(a_upper, b_upper, n, stride); - transpose_square_swap(&mut a_upper[n..], b_lower, n, stride); - transpose_square_swap(a_lower, &mut b_upper[n..], n, stride); - transpose_square_swap(&mut a_lower[n..], &mut b_lower[n..], n, stride); - } else { - for i in 0..size { - for j in 0..size { - // The compiler does not eliminate the bounds checks here, - // but this doesn't matter as it is bottlenecked by memory bandwidth. - swap(&mut a[i * stride + j], &mut b[j * stride + i]); - } - } - } -} - -/// Compute the largest factor of n that is <= sqrt(n). -/// Assumes n is of the form 2^k * {1,3,9}. -fn sqrt_factor(n: usize) -> usize { - let twos = n.trailing_zeros(); - match n >> twos { - 1 => 1 << (twos / 2), - 3 | 9 => 3 << (twos / 2), - _ => panic!(), - } -} - -/// Least common multiple. -fn lcm(a: usize, b: usize) -> usize { - a * b / gcd(a, b) -} - -// Greatest common divisor. -fn gcd(mut a: usize, mut b: usize) -> usize { - while b != 0 { - (a, b) = (b, a % b); - } - a -} diff --git a/src/ntt/transpose.rs b/src/ntt/transpose.rs new file mode 100644 index 0000000..3af84a4 --- /dev/null +++ b/src/ntt/transpose.rs @@ -0,0 +1,135 @@ +use super::{utils::workload_size, MatrixMut}; +use std::mem::swap; + +#[cfg(feature = "parallel")] +use rayon::join; + +/// Transpose a matrix in-place. +/// Will batch transpose multiple matrices if the length of the slice is a multiple of rows * cols. +pub fn transpose(matrix: &mut [F], rows: usize, cols: usize) { + debug_assert_eq!(matrix.len() % rows * cols, 0); + // eprintln!( + // "Transpose {} x {rows} x {cols} matrix.", + // matrix.len() / (rows * cols) + // ); + if rows == cols { + for matrix in matrix.chunks_exact_mut(rows * cols) { + let matrix = MatrixMut::from_mut_slice(matrix, rows, cols); + transpose_square(matrix); + } + } else { + // TODO: Special case for rows = 2 * cols and cols = 2 * rows. + // TODO: Special case for very wide matrices (e.g. n x 16). + let mut scratch = vec![matrix[0]; rows * cols]; + for matrix in matrix.chunks_exact_mut(rows * cols) { + scratch.copy_from_slice(matrix); + let src = MatrixMut::from_mut_slice(scratch.as_mut_slice(), rows, cols); + let dst = MatrixMut::from_mut_slice(matrix, cols, rows); + transpose_copy(src, dst); + } + } +} + +fn transpose_copy(src: MatrixMut, mut dst: MatrixMut) { + assert_eq!(src.rows(), dst.cols()); + assert_eq!(src.cols(), dst.rows()); + if src.rows() * src.cols() > workload_size::() { + // Split along longest axis and recurse. + // This results in a cache-oblivious algorithm. 
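// Illustrative use of the batch semantics of the public `transpose` above; the
// function name, the values and the Field64 type are only for the example.
fn transpose_batch_example() {
    use crate::{crypto::fields::Field64 as F, ntt::transpose};
    let mut m: Vec<F> = (0u64..12).map(F::from).collect(); // two 2x3 matrices, back to back
    transpose(&mut m, 2, 3); // each block is transposed into a 3x2 matrix
    let expected: Vec<F> = [0u64, 3, 1, 4, 2, 5, 6, 9, 7, 10, 8, 11].map(F::from).to_vec();
    assert_eq!(m, expected);
}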
+ let ((a, b), (x, y)) = if src.rows() > src.cols() { + let n = src.rows() / 2; + (src.split_vertical(n), dst.split_horizontal(n)) + } else { + let n = src.cols() / 2; + (src.split_horizontal(n), dst.split_vertical(n)) + }; + #[cfg(not(feature = "parallel"))] + { + transpose_copy(a, x); + transpose_copy(b, y); + } + #[cfg(feature = "parallel")] + join(|| transpose_copy(a, x), || transpose_copy(b, y)); + } else { + for i in 0..src.rows() { + for j in 0..src.cols() { + dst[(j, i)] = src[(i, j)]; + } + } + } +} + +/// Transpose a square matrix in-place. +fn transpose_square(mut m: MatrixMut) { + debug_assert!(m.is_square()); + debug_assert!(m.rows().is_power_of_two()); + let size = m.rows(); + if size * size > workload_size::() { + // Recurse into quadrants. + // This results in a cache-oblivious algorithm. + let n = size / 2; + let (a, b, c, d) = m.split_quadrants(n, n); + + #[cfg(not(feature = "parallel"))] + { + transpose_square(a); + transpose_square(d); + transpose_square_swap(b, c); + } + #[cfg(feature = "parallel")] + join( + || transpose_square(a), + || join(|| transpose_square(d), || transpose_square_swap(b, c)), + ); + } else { + for i in 0..size { + for j in (i + 1)..size { + m.swap((i, j), (j, i)); + } + } + } +} + +/// Transpose and swap two square size matrices. +fn transpose_square_swap(mut a: MatrixMut, mut b: MatrixMut) { + debug_assert!(a.is_square()); + debug_assert_eq!(a.rows(), b.cols()); + debug_assert_eq!(a.cols(), b.rows()); + let size = a.rows(); + if 2 * size * size > workload_size::() { + // Recurse into quadrants. + // This results in a cache-oblivious algorithm. + let n = size / 2; + let (aa, ab, ac, ad) = a.split_quadrants(n, n); + let (ba, bb, bc, bd) = b.split_quadrants(n, n); + + #[cfg(not(feature = "parallel"))] + { + transpose_square_swap(aa, ba); + transpose_square_swap(ab, bc); + transpose_square_swap(ac, bb); + transpose_square_swap(ad, bd); + } + #[cfg(feature = "parallel")] + join( + || { + join( + || transpose_square_swap(aa, ba), + || transpose_square_swap(ab, bc), + ) + }, + || { + join( + || transpose_square_swap(ac, bb), + || transpose_square_swap(ad, bd), + ) + }, + ); + } else { + for i in 0..size { + for j in 0..size { + swap(&mut a[(i, j)], &mut b[(j, i)]) + } + } + } +} diff --git a/src/ntt/utils.rs b/src/ntt/utils.rs new file mode 100644 index 0000000..d71223b --- /dev/null +++ b/src/ntt/utils.rs @@ -0,0 +1,48 @@ +/// Target single-thread workload size for `T`. +/// Should ideally be a multiple of a cache line (64 bytes) +/// and close to the L1 cache size (32 KB). +pub const fn workload_size() -> usize { + const CACHE_SIZE: usize = 1 << 15; + CACHE_SIZE / size_of::() +} + +/// Cast a slice into chunks of size N. +/// +/// TODO: Replace with `slice::as_chunks` when stable. +pub fn as_chunks_exact_mut(slice: &mut [T]) -> &mut [[T; N]] { + assert!(N != 0, "chunk size must be non-zero"); + assert_eq!( + slice.len() % N, + 0, + "slice length must be a multiple of chunk size" + ); + // SAFETY: Caller must guarantee that `N` is nonzero and exactly divides the slice length + let new_len = slice.len() / N; + // SAFETY: We cast a slice of `new_len * N` elements into + // a slice of `new_len` many `N` elements chunks. + unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr().cast(), new_len) } +} + +/// Compute the largest factor of n that is <= sqrt(n). +/// Assumes n is of the form 2^k * {1,3,9}. 
+pub fn sqrt_factor(n: usize) -> usize { + let twos = n.trailing_zeros(); + match n >> twos { + 1 => 1 << (twos / 2), + 3 | 9 => 3 << (twos / 2), + _ => panic!(), + } +} + +/// Least common multiple. +pub fn lcm(a: usize, b: usize) -> usize { + a * b / gcd(a, b) +} + +// Greatest common divisor. +pub fn gcd(mut a: usize, mut b: usize) -> usize { + while b != 0 { + (a, b) = (b, a % b); + } + a +} diff --git a/src/ntt/wavelet.rs b/src/ntt/wavelet.rs new file mode 100644 index 0000000..88ed2ae --- /dev/null +++ b/src/ntt/wavelet.rs @@ -0,0 +1,90 @@ +use super::{transpose, utils::workload_size}; +use ark_ff::Field; +use std::cmp::max; + +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +/// Fast Wavelet Transform. +/// +/// The input slice must have a length that is a power of two. +/// Recursively applies the kernel +/// [1 0] +/// [1 1] +pub fn wavelet_transform(values: &mut [F]) { + debug_assert!(values.len().is_power_of_two()); + wavelet_transform_batch(values, values.len()) +} + +pub fn wavelet_transform_batch(values: &mut [F], size: usize) { + debug_assert_eq!(values.len() % size, 0); + debug_assert!(size.is_power_of_two()); + #[cfg(feature = "parallel")] + if values.len() > workload_size::() && values.len() != size { + // Multiple wavelet transforms, compute in parallel. + // Work size is largest multiple of `size` smaller than `WORKLOAD_SIZE`. + let workload_size = size * max(1, workload_size::() / size); + return values.par_chunks_mut(workload_size).for_each(|values| { + wavelet_transform_batch(values, size); + }); + } + match size { + 0 | 1 => {} + 2 => { + for v in values.chunks_exact_mut(2) { + v[1] += v[0] + } + } + 4 => { + for v in values.chunks_exact_mut(4) { + v[1] += v[0]; + v[3] += v[2]; + v[2] += v[0]; + v[3] += v[1]; + } + } + 8 => { + for v in values.chunks_exact_mut(8) { + v[1] += v[0]; + v[3] += v[2]; + v[2] += v[0]; + v[3] += v[1]; + v[5] += v[4]; + v[7] += v[6]; + v[6] += v[4]; + v[7] += v[5]; + v[4] += v[0]; + v[5] += v[1]; + v[6] += v[2]; + v[7] += v[3]; + } + } + 16 => { + for v in values.chunks_exact_mut(16) { + for v in v.chunks_exact_mut(4) { + v[1] += v[0]; + v[3] += v[2]; + v[2] += v[0]; + v[3] += v[1]; + } + let (a, v) = v.split_at_mut(4); + let (b, v) = v.split_at_mut(4); + let (c, d) = v.split_at_mut(4); + for i in 0..4 { + b[i] += a[i]; + d[i] += c[i]; + c[i] += a[i]; + d[i] += b[i]; + } + } + } + n => { + let n1 = 1 << (n.trailing_zeros() / 2); + let n2 = n / n1; + wavelet_transform_batch(values, n1); + transpose(values, n2, n1); + wavelet_transform_batch(values, n2); + transpose(values, n1, n2); + } + } +} diff --git a/src/poly_utils/coeffs.rs b/src/poly_utils/coeffs.rs index d0efcb8..83629a7 100644 --- a/src/poly_utils/coeffs.rs +++ b/src/poly_utils/coeffs.rs @@ -1,5 +1,5 @@ use super::{evals::EvaluationsList, hypercube::BinaryHypercubePoint, MultilinearPoint}; -use crate::crypto::ntt::wavelet_transform; +use crate::ntt::wavelet_transform; use ark_ff::Field; use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial, Polynomial}; #[cfg(feature = "parallel")] diff --git a/src/poly_utils/fold.rs b/src/poly_utils/fold.rs index 91c25ba..00016cb 100644 --- a/src/poly_utils/fold.rs +++ b/src/poly_utils/fold.rs @@ -1,4 +1,4 @@ -use crate::crypto::ntt::intt_batch; +use crate::ntt::intt_batch; use crate::parameters::FoldType; use ark_ff::{FftField, Field}; diff --git a/src/utils.rs b/src/utils.rs index 4f8f708..cc30724 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,6 +1,6 @@ -use std::collections::BTreeSet; - +use crate::ntt::transpose; use 
ark_ff::Field; +use std::collections::BTreeSet; pub fn is_power_of_two(n: usize) -> bool { n & (n - 1) == 0 @@ -50,19 +50,12 @@ pub fn dedup(v: impl IntoIterator) -> Vec { // Takes the vector of evaluations (assume that evals[i] = f(omega^i)) // and folds them into a vector of such that folded_evals[i] = [f(omega^(i + k * j)) for j in 0..folding_factor] -pub fn stack_evaluations(evals: Vec, folding_factor: usize) -> Vec { +pub fn stack_evaluations(mut evals: Vec, folding_factor: usize) -> Vec { let folding_factor_exp = 1 << folding_factor; assert!(evals.len() % folding_factor_exp == 0); let size_of_new_domain = evals.len() / folding_factor_exp; - - let mut stacked_evaluations = Vec::with_capacity(evals.len()); - for i in 0..size_of_new_domain { - for j in 0..folding_factor_exp { - stacked_evaluations.push(evals[i + j * size_of_new_domain]); - } - } - - stacked_evaluations + transpose(&mut evals, folding_factor_exp, size_of_new_domain); + evals } #[cfg(test)] @@ -71,11 +64,13 @@ mod tests { #[test] fn test_evaluations_stack() { + use crate::crypto::fields::Field64 as F; + let num = 256; let folding_factor = 3; let fold_size = 1 << folding_factor; assert_eq!(num % fold_size, 0); - let evals: Vec<_> = (0..num).collect(); + let evals: Vec<_> = (0..num as u64).map(F::from).collect(); let stacked = stack_evaluations(evals, folding_factor); assert_eq!(stacked.len(), num); @@ -83,7 +78,7 @@ mod tests { for (i, fold) in stacked.chunks_exact(fold_size).enumerate() { assert_eq!(fold.len(), fold_size); for j in 0..fold_size { - assert_eq!(fold[j], i + j * num / fold_size); + assert_eq!(fold[j], F::from((i + j * num / fold_size) as u64)); } } } diff --git a/src/whir/committer.rs b/src/whir/committer.rs index 0fc5945..762750a 100644 --- a/src/whir/committer.rs +++ b/src/whir/committer.rs @@ -1,11 +1,12 @@ use super::parameters::WhirConfig; use crate::{ + ntt::expand_from_coeff, poly_utils::{coeffs::CoefficientList, fold::restructure_evaluations, MultilinearPoint}, utils, }; use ark_crypto_primitives::merkle_tree::{Config, MerkleTree}; use ark_ff::FftField; -use ark_poly::{univariate::DensePolynomial, EvaluationDomain}; +use ark_poly::EvaluationDomain; use nimue::{ plugins::ark::{FieldChallenges, FieldWriter}, ByteWriter, Merlin, ProofResult, @@ -49,11 +50,10 @@ where Merlin: FieldChallenges + ByteWriter, { let base_domain = self.0.starting_domain.base_domain.unwrap(); - let univariate: DensePolynomial<_> = polynomial.clone().into(); - let evals = univariate - .evaluate_over_domain_by_ref(self.0.starting_domain.base_domain.unwrap()) - .evals; - + let expansion = base_domain.size() / polynomial.num_coeffs(); + let evals = expand_from_coeff(polynomial.coeffs(), expansion); + // TODO: `stack_evaluations` and `restructure_evaluations` are really in-place algorithms. + // They also partially overlap and undo one another. We should merge them. 
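// Small sanity check (hypothetical values) for the transpose-based rewrite of
// `stack_evaluations` in utils.rs above: leaf i collects the evaluations whose
// indices are congruent to i modulo the new domain size, exactly as the old
// gather loop did.
fn stack_evaluations_example() {
    use crate::crypto::fields::Field64 as F;
    let evals: Vec<F> = (0u64..8).map(F::from).collect();
    let stacked = crate::utils::stack_evaluations(evals, 2); // folding factor 2, fold size 4
    let expected: Vec<F> = [0u64, 2, 4, 6, 1, 3, 5, 7].map(F::from).to_vec();
    assert_eq!(stacked, expected);
}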
let folded_evals = utils::stack_evaluations(evals, self.0.folding_factor); let folded_evals = restructure_evaluations( folded_evals, diff --git a/src/whir/prover.rs b/src/whir/prover.rs index 420147e..a157fe9 100644 --- a/src/whir/prover.rs +++ b/src/whir/prover.rs @@ -1,6 +1,7 @@ use super::{committer::Witness, parameters::WhirConfig, Statement, WhirProof}; use crate::{ domain::Domain, + ntt::expand_from_coeff, parameters::FoldType, poly_utils::{ coeffs::CoefficientList, @@ -12,7 +13,7 @@ use crate::{ }; use ark_crypto_primitives::merkle_tree::{Config, MerkleTree, MultiPath}; use ark_ff::FftField; -use ark_poly::{univariate::DensePolynomial, EvaluationDomain}; +use ark_poly::EvaluationDomain; use nimue::{ plugins::{ ark::{FieldChallenges, FieldWriter}, @@ -170,11 +171,10 @@ where // Fold the coefficients, and compute fft of polynomial (and commit) let new_domain = round_state.domain.scale(2); - let univariate: DensePolynomial<_> = folded_coefficients.clone().into(); - let evals = univariate - .evaluate_over_domain_by_ref(new_domain.backing_domain) - .evals; - + let expansion = new_domain.size() / folded_coefficients.num_coeffs(); + let evals = expand_from_coeff(folded_coefficients.coeffs(), expansion); + // TODO: `stack_evaluations` and `restructure_evaluations` are really in-place algorithms. + // They also partially overlap and undo one another. We should merge them. let folded_evals = utils::stack_evaluations(evals, self.0.folding_factor); let folded_evals = restructure_evaluations( folded_evals, @@ -183,6 +183,7 @@ where new_domain.backing_domain.group_gen_inv(), self.0.folding_factor, ); + #[cfg(not(feature = "parallel"))] let leafs_iter = folded_evals.chunks_exact(1 << self.0.folding_factor); #[cfg(feature = "parallel")] diff --git a/src/whir_ldt/committer.rs b/src/whir_ldt/committer.rs index a3749aa..692d2a8 100644 --- a/src/whir_ldt/committer.rs +++ b/src/whir_ldt/committer.rs @@ -1,11 +1,12 @@ use super::parameters::WhirConfig; use crate::{ + ntt::expand_from_coeff, poly_utils::{coeffs::CoefficientList, fold::restructure_evaluations}, utils, }; use ark_crypto_primitives::merkle_tree::{Config, MerkleTree}; use ark_ff::FftField; -use ark_poly::{univariate::DensePolynomial, EvaluationDomain}; +use ark_poly::EvaluationDomain; use nimue::{plugins::ark::FieldChallenges, ByteWriter, Merlin, ProofResult}; #[cfg(feature = "parallel")] @@ -44,9 +45,10 @@ where Merlin: FieldChallenges + ByteWriter, { let base_domain = self.0.starting_domain.base_domain.unwrap(); - let univariate: DensePolynomial<_> = polynomial.clone().into(); - let evals = univariate.evaluate_over_domain_by_ref(base_domain).evals; - + let expansion = base_domain.size() / polynomial.num_coeffs(); + let evals = expand_from_coeff(polynomial.coeffs(), expansion); + // TODO: `stack_evaluations` and `restructure_evaluations` are really in-place algorithms. + // They also partially overlap and undo one another. We should merge them. 
let folded_evals = utils::stack_evaluations(evals, self.0.folding_factor); let folded_evals = restructure_evaluations( folded_evals, diff --git a/src/whir_ldt/prover.rs b/src/whir_ldt/prover.rs index 187f6ac..e042991 100644 --- a/src/whir_ldt/prover.rs +++ b/src/whir_ldt/prover.rs @@ -1,6 +1,7 @@ use super::{committer::Witness, parameters::WhirConfig, WhirProof}; use crate::{ domain::Domain, + ntt::expand_from_coeff, parameters::FoldType, poly_utils::{ coeffs::CoefficientList, @@ -12,7 +13,7 @@ use crate::{ }; use ark_crypto_primitives::merkle_tree::{Config, MerkleTree, MultiPath}; use ark_ff::FftField; -use ark_poly::{univariate::DensePolynomial, EvaluationDomain}; +use ark_poly::EvaluationDomain; use nimue::{ plugins::{ ark::{FieldChallenges, FieldWriter}, @@ -144,11 +145,10 @@ where // Fold the coefficients, and compute fft of polynomial (and commit) let new_domain = round_state.domain.scale(2); - let univariate: DensePolynomial<_> = folded_coefficients.clone().into(); - let evals = univariate - .evaluate_over_domain_by_ref(new_domain.backing_domain) - .evals; - + let expansion = new_domain.size() / folded_coefficients.num_coeffs(); + let evals = expand_from_coeff(folded_coefficients.coeffs(), expansion); + // TODO: `stack_evaluations` and `restructure_evaluations` are really in-place algorithms. + // They also partially overlap and undo one another. We should merge them. let folded_evals = utils::stack_evaluations(evals, self.0.folding_factor); let folded_evals = restructure_evaluations( folded_evals,
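// The committer/prover changes above rely on `expand_from_coeff(coeffs, domain.size() /
// num_coeffs)` reproducing what `evaluate_over_domain_by_ref` returned before. A sketch of
// an equivalence test (hypothetical helper; assumes the field has enough 2-adicity and that
// the NTT engine uses the same domain generator as ark-poly's radix-2 domain):
fn check_expand_matches_ark() {
    use crate::{crypto::fields::Field64 as F, ntt::expand_from_coeff};
    use ark_poly::{
        univariate::DensePolynomial, DenseUVPolynomial, EvaluationDomain, GeneralEvaluationDomain,
    };
    let coeffs: Vec<F> = (1u64..=64).map(F::from).collect();
    let expansion = 4;
    let domain = GeneralEvaluationDomain::<F>::new(coeffs.len() * expansion).unwrap();
    let expected = DensePolynomial::from_coefficients_slice(&coeffs)
        .evaluate_over_domain_by_ref(domain)
        .evals;
    assert_eq!(expand_from_coeff(&coeffs, expansion), expected);
}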