Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallel square transpose #10

Merged
merged 13 commits into from
Sep 10, 2024
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ rayon = { version = "1.10.0", optional = true }
debug = true

[features]
default = ["parallel", "dep:rayon"]
default = ["parallel"]
#default = []
parallel = [
"dep:rayon",
"ark-poly/parallel",
"ark-ff/parallel",
"ark-crypto-primitives/parallel",
Expand Down
1 change: 0 additions & 1 deletion src/crypto/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
pub mod fields;
pub mod merkle_tree;
pub mod ntt;
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub mod cmdline_utils;
pub mod crypto; // Crypto utils
pub mod domain; // Domain that we are evaluating over
pub mod fs_utils;
pub mod ntt;
pub mod parameters;
pub mod poly_utils; // Utils for polynomials
pub mod sumcheck; // Sumcheck specialised
Expand Down
170 changes: 170 additions & 0 deletions src/ntt/matrix.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
//! Minimal matrix class that supports strided access.
//! This abstracts over the unsafe pointer arithmetic required for transpose-like algorithms.

#![allow(unsafe_code)]

use std::{
marker::PhantomData,
ops::{Index, IndexMut},
ptr, slice,
};

/// Mutable reference to a matrix.
///
/// The invariant this data structure maintains is that `data` has lifetime
/// `'a` and points to a collection of `rows` rows, at intervals `row_stride`,
/// each of length `cols`. Every element reachable through the view is
/// exclusively borrowed for `'a`.
///
/// A view with zero rows or zero columns never dereferences `data`, so for
/// such views `data` is allowed to be any (possibly out-of-bounds) address.
pub struct MatrixMut<'a, T> {
    /// Address of element `(0, 0)`.
    data: *mut T,
    /// Number of rows in the view.
    rows: usize,
    /// Number of columns in the view.
    cols: usize,
    /// Distance, in elements, between the starts of consecutive rows.
    row_stride: usize,
    /// Ties the raw pointer to an exclusive borrow of lifetime `'a`.
    _lifetime: PhantomData<&'a mut T>,
}

// SAFETY: Owning a `MatrixMut` is exclusive access to its elements (like
// `&mut [T]`), so moving it to another thread is sound whenever `T: Send`.
unsafe impl<T: Send> Send for MatrixMut<'_, T> {}

// SAFETY: A shared `&MatrixMut` only hands out `&T` (through `Index`), so
// sharing it across threads is sound whenever `T: Sync`.
unsafe impl<T: Sync> Sync for MatrixMut<'_, T> {}

impl<'a, T> MatrixMut<'a, T> {
    /// Interpret `slice` as a row-major matrix with `rows` contiguous rows of
    /// `cols` elements each.
    ///
    /// # Panics
    ///
    /// Panics if `rows * cols` overflows `usize` or is not equal to
    /// `slice.len()`.
    pub fn from_mut_slice(slice: &'a mut [T], rows: usize, cols: usize) -> Self {
        // A wrapping multiplication could make the length check pass for
        // dimensions that do not fit in the slice, which would break the
        // structure invariant every `unsafe` block below relies on.
        let expected_len = rows
            .checked_mul(cols)
            .expect("matrix dimensions overflow usize");
        assert_eq!(slice.len(), expected_len);
        // Safety: The input slice is valid for the lifetime `'a` and has
        // `rows` contiguous rows of length `cols`.
        Self {
            data: slice.as_mut_ptr(),
            rows,
            cols,
            row_stride: cols,
            _lifetime: PhantomData,
        }
    }

    /// Number of rows.
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Number of columns.
    pub fn cols(&self) -> usize {
        self.cols
    }

    /// Whether the matrix has as many rows as columns.
    pub fn is_square(&self) -> bool {
        self.rows == self.cols
    }

    /// Mutable slice over row `row`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= self.rows()`.
    pub fn row(&mut self, row: usize) -> &mut [T] {
        assert!(row < self.rows);
        // Safety: The structure invariant guarantees that at offset `row * self.row_stride`
        // there is valid data of length `self.cols`.
        unsafe { slice::from_raw_parts_mut(self.data.add(row * self.row_stride), self.cols) }
    }

    /// Split the matrix into two vertically.
    ///
    /// [A] = self
    /// [B]
    ///
    /// `A` holds the first `row` rows and `B` the remaining ones.
    ///
    /// # Panics
    ///
    /// Panics if `row > self.rows()`.
    pub fn split_vertical(self, row: usize) -> (Self, Self) {
        assert!(row <= self.rows);
        (
            // Safety: This reduces the number of rows, keeping all else the same.
            Self {
                data: self.data,
                rows: row,
                cols: self.cols,
                row_stride: self.row_stride,
                _lifetime: PhantomData,
            },
            // Safety: This skips the first `row` rows and keeps the rest.
            // `wrapping_add` is used because for `row == self.rows` on a
            // strided view the start address can lie past the end of the
            // underlying allocation; that view is empty and never read.
            Self {
                data: self.data.wrapping_add(row * self.row_stride),
                rows: self.rows - row,
                cols: self.cols,
                row_stride: self.row_stride,
                _lifetime: PhantomData,
            },
        )
    }

    /// Split the matrix into two horizontally.
    ///
    /// [A B] = self
    ///
    /// `A` holds the first `col` columns and `B` the remaining ones.
    ///
    /// # Panics
    ///
    /// Panics if `col > self.cols()`.
    pub fn split_horizontal(self, col: usize) -> (Self, Self) {
        assert!(col <= self.cols);
        (
            // Safety: This reduces the number of cols, keeping all else the same.
            Self {
                data: self.data,
                rows: self.rows,
                cols: col,
                row_stride: self.row_stride,
                _lifetime: PhantomData,
            },
            // Safety: This reduces the number of cols and offsets the start,
            // keeping all else the same. `wrapping_add` is used because a
            // view with zero rows may sit at the end of (or outside) the
            // allocation; that view is empty and never read.
            Self {
                data: self.data.wrapping_add(col),
                rows: self.rows,
                cols: self.cols - col,
                row_stride: self.row_stride,
                _lifetime: PhantomData,
            },
        )
    }

    /// Split the matrix into four quadrants.
    ///
    /// [A B] = self
    /// [C D]
    ///
    /// Returns `(A, B, C, D)`, where `A` has `row` rows and `col` columns.
    ///
    /// # Panics
    ///
    /// Panics if `row > self.rows()` or `col > self.cols()`.
    pub fn split_quadrants(self, row: usize, col: usize) -> (Self, Self, Self, Self) {
        let (top, bottom) = self.split_vertical(row);
        let (a, b) = top.split_horizontal(col);
        let (c, d) = bottom.split_horizontal(col);
        (a, b, c, d)
    }

    /// Swap two elements in the matrix.
    ///
    /// Swapping a position with itself is a no-op.
    ///
    /// # Panics
    ///
    /// Panics if either position is out of bounds (and the positions differ).
    pub fn swap(&mut self, a: (usize, usize), b: (usize, usize)) {
        if a != b {
            // Safety: Both pointers are bounds-checked by `ptr_at_mut`, and
            // distinct in-bounds positions address distinct elements, so the
            // two single-element regions do not overlap.
            unsafe {
                let a = self.ptr_at_mut(a.0, a.1);
                let b = self.ptr_at_mut(b.0, b.1);
                ptr::swap_nonoverlapping(a, b, 1)
            }
        }
    }

    /// Pointer to the element at `(row, col)`.
    ///
    /// # Panics
    ///
    /// Panics if the position is out of bounds.
    unsafe fn ptr_at(&self, row: usize, col: usize) -> *const T {
        assert!(row < self.rows);
        assert!(col < self.cols);
        // Safety: The structure invariant guarantees that at offset `row * self.row_stride + col`
        // there is valid data.
        unsafe { self.data.add(row * self.row_stride + col) }
    }

    /// Mutable pointer to the element at `(row, col)`.
    ///
    /// # Panics
    ///
    /// Panics if the position is out of bounds.
    unsafe fn ptr_at_mut(&mut self, row: usize, col: usize) -> *mut T {
        assert!(row < self.rows);
        assert!(col < self.cols);
        // Safety: The structure invariant guarantees that at offset `row * self.row_stride + col`
        // there is valid data.
        unsafe { self.data.add(row * self.row_stride + col) }
    }
}

impl<T> Index<(usize, usize)> for MatrixMut<'_, T> {
    type Output = T;

    /// Shared access to the element at `(row, col)`; panics when out of bounds.
    fn index(&self, (row, col): (usize, usize)) -> &T {
        // Safety: `ptr_at` bounds-checks the position, and the structure
        // invariant guarantees there is valid data at that position.
        unsafe { &*self.ptr_at(row, col) }
    }
}

impl<T> IndexMut<(usize, usize)> for MatrixMut<'_, T> {
    /// Exclusive access to the element at `(row, col)`; panics when out of bounds.
    fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut T {
        // Safety: `ptr_at_mut` bounds-checks the position, and the structure
        // invariant guarantees there is valid data at that position.
        unsafe { &mut *self.ptr_at_mut(row, col) }
    }
}
61 changes: 61 additions & 0 deletions src/ntt/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
//! NTT and related algorithms.

mod matrix;
mod ntt;
mod transpose;
mod utils;
mod wavelet;

use self::matrix::MatrixMut;
use ark_ff::FftField;

#[cfg(feature = "parallel")]
use rayon::prelude::*;

pub use self::{
ntt::{intt, intt_batch, ntt, ntt_batch},
transpose::transpose,
wavelet::wavelet_transform,
};

/// RS encode at a rate 1/`expansion`.
///
/// Builds a vector of `expansion` consecutive blocks of `coeffs.len()`
/// elements each. Block `0` is a copy of `coeffs`; block `i` holds
/// `coeffs[j] * root^(i * j)`, where `root = engine.root(coeffs.len() * expansion)`
/// (presumably a root of unity of that order; confirm against `NttEngine`).
/// The blocks are then transformed with `ntt_batch` (size `coeffs.len()`)
/// and the result is transposed from `expansion` rows by `coeffs.len()` columns.
pub fn expand_from_coeff<F: FftField>(coeffs: &[F], expansion: usize) -> Vec<F> {
    let engine = ntt::NttEngine::<F>::new_from_cache();
    let expanded_size = coeffs.len() * expansion;
    let mut result = Vec::with_capacity(expanded_size);
    // Note: Zero-extending the coefficients and running one larger NTT
    // would also work, but the coset approach below is more efficient.

    // Do coset NTT.
    let root = engine.root(expanded_size);
    // Coset 0 uses a twist of one everywhere, i.e. the coefficients themselves.
    result.extend_from_slice(coeffs);

    // Sequential twist: walk each coset keeping a running power of its step.
    #[cfg(not(feature = "parallel"))]
    for i in 1..expansion {
        let step = root.pow([i as u64]);
        let mut twiddle = F::ONE;
        for coeff in coeffs {
            result.push(*coeff * twiddle);
            twiddle *= step;
        }
    }

    // Parallel twist: `F::ZERO` marks per-split state that has not been
    // initialised yet. The first item seen with that state computes its power
    // directly; later items advance the running power by one multiplication.
    #[cfg(feature = "parallel")]
    result.par_extend((1..expansion).into_par_iter().flat_map(|i| {
        let step = root.pow([i as u64]);
        coeffs
            .par_iter()
            .enumerate()
            .map_with(F::ZERO, move |twiddle, (j, coeff)| {
                if twiddle.is_zero() {
                    *twiddle = step.pow([j as u64]);
                } else {
                    *twiddle *= step;
                }
                *coeff * *twiddle
            })
    }));

    ntt_batch(&mut result, coeffs.len());
    transpose(&mut result, expansion, coeffs.len());
    result
}
Loading