Merge pull request #14 from GottfriedHerold/understanding_whir
Documentation and minor robustness changes
WizardOfMenlo authored Nov 11, 2024
2 parents 901f92e + 02b3c9c commit 3c764b4
Showing 14 changed files with 744 additions and 102 deletions.
4 changes: 4 additions & 0 deletions src/domain.rs
@@ -28,7 +28,11 @@ where
})
}

// returns the size of the domain after folding `folding_factor` many times.
//
// This asserts that the domain size is divisible by `1 << folding_factor`.
pub fn folded_size(&self, folding_factor: usize) -> usize {
assert!(self.backing_domain.size() % (1 << folding_factor) == 0);
self.backing_domain.size() / (1 << folding_factor)
}

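For illustration (not part of the diff), a small standalone sketch of the folded_size contract with made-up sizes:

fn main() {
    // Stand-in for self.backing_domain.size(); the concrete value is made up.
    let domain_size = 1usize << 10;
    let folding_factor = 2usize;
    // The new assert: the size must be divisible by 1 << folding_factor ...
    assert_eq!(domain_size % (1 << folding_factor), 0);
    // ... and folded_size then returns the quotient.
    assert_eq!(domain_size / (1 << folding_factor), 256);
}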
45 changes: 31 additions & 14 deletions src/ntt/matrix.rs
@@ -27,6 +27,7 @@ unsafe impl<'a, T: Send> Send for MatrixMut<'_, T> {}
unsafe impl<'a, T: Sync> Sync for MatrixMut<'_, T> {}

impl<'a, T> MatrixMut<'a, T> {
/// creates a MatrixMut from `slice`, where `slice` is the concatenation of `rows` rows, each consisting of `cols` many entries.
pub fn from_mut_slice(slice: &'a mut [T], rows: usize, cols: usize) -> Self {
assert_eq!(slice.len(), rows * cols);
// Safety: The input slice is valid for the lifetime `'a` and has
@@ -40,28 +41,33 @@ impl<'a, T> MatrixMut<'a, T> {
}
}

/// returns the number of rows
pub fn rows(&self) -> usize {
self.rows
}

/// returns the number of columns
pub fn cols(&self) -> usize {
self.cols
}

/// checks whether the matrix is a square matrix
pub fn is_square(&self) -> bool {
self.rows == self.cols
}

/// returns a mutable reference to the `row`'th row of the MatrixMut
pub fn row(&mut self, row: usize) -> &mut [T] {
assert!(row < self.rows);
// Safety: The structure invariant guarantees that at offset `row * self.row_stride`
// there is valid data of length `self.cols`.
unsafe { slice::from_raw_parts_mut(self.data.add(row * self.row_stride), self.cols) }
}
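For illustration (not part of the diff), a minimal sketch of the accessors documented above, written as a hypothetical unit test in this module:

#[test]
fn matrix_mut_accessors_sketch() {
    // View a flat buffer of 2 * 3 = 6 entries as a 2x3 matrix.
    let mut data = vec![0u64, 1, 2, 3, 4, 5];
    let mut m = MatrixMut::from_mut_slice(&mut data, 2, 3);
    assert_eq!((m.rows(), m.cols()), (2, 3));
    assert!(!m.is_square());
    // `row` hands out a mutable view into the backing slice.
    assert_eq!(m.row(1).to_vec(), vec![3, 4, 5]);
}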

/// Split the matrix into two vertically.
/// Split the matrix into two vertically at the `row`'th row (meaning that in the returned pair (A,B), the matrix A has `row` rows).
///
/// [A] = self
/// [A]
/// [ ] = self
/// [B]
pub fn split_vertical(self, row: usize) -> (Self, Self) {
assert!(row <= self.rows);
@@ -83,7 +89,7 @@ impl<'a, T> MatrixMut<'a, T> {
)
}

/// Split the matrix into two horizontally.
/// Split the matrix into two horizontally at the `col`'th column (meaning that in the returned pair (A,B), the matrix A has `col` columns).
///
/// [A B] = self
pub fn split_horizontal(self, col: usize) -> (Self, Self) {
@@ -108,19 +114,21 @@ impl<'a, T> MatrixMut<'a, T> {
)
}

/// Split the matrix into four quadrants.
/// Split the matrix into four quadrants at the indicated `row` and `col` (meaning that in the returned 4-tuple (A,B,C,D), the matrix A is a `row`x`col` matrix)
///
/// [A B] = self
/// [C D]
/// self = [A B]
/// [C D]
pub fn split_quadrants(self, row: usize, col: usize) -> (Self, Self, Self, Self) {
let (u, d) = self.split_vertical(row);
let (u, l) = self.split_vertical(row); // split into upper and lower parts
let (a, b) = u.split_horizontal(col);
let (c, d) = d.split_horizontal(col);
let (c, d) = l.split_horizontal(col);
(a, b, c, d)
}
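For illustration (not part of the diff), a sketch of the quadrant shapes produced by the call above, again as a hypothetical test:

#[test]
fn split_quadrants_shapes_sketch() {
    // Splitting a 4x6 matrix at row 1, col 2 gives blocks shaped
    // A: 1x2, B: 1x4, C: 3x2, D: 3x4, matching the layout above.
    let mut data = vec![0u32; 4 * 6];
    let m = MatrixMut::from_mut_slice(&mut data, 4, 6);
    let (a, b, c, d) = m.split_quadrants(1, 2);
    assert_eq!((a.rows(), a.cols()), (1, 2));
    assert_eq!((b.rows(), b.cols()), (1, 4));
    assert_eq!((c.rows(), c.cols()), (3, 2));
    assert_eq!((d.rows(), d.cols()), (3, 4));
}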

/// Swap two elements in the matrix.
pub fn swap(&mut self, a: (usize, usize), b: (usize, usize)) {
/// Swap two elements `a` and `b` in the matrix.
/// Each of `a`, `b` is given as a (row, column) pair.
/// If the given coordinates are out-of-bounds, the behaviour is undefined.
pub unsafe fn swap(&mut self, a: (usize, usize), b: (usize, usize)) {
if a != b {
unsafe {
let a = self.ptr_at_mut(a.0, a.1);
@@ -130,23 +138,32 @@ impl<'a, T> MatrixMut<'a, T> {
}
}

/// returns an immutable pointer to the element at (`row`, `col`). This performs no bounds checking and providing out-of-bounds indices is UB.
unsafe fn ptr_at(&self, row: usize, col: usize) -> *const T {
assert!(row < self.rows);
assert!(col < self.cols);
// Safe to call under the following assertion (checked by caller)
// assert!(row < self.rows);
// assert!(col < self.cols);

// Safety: The structure invariant guarantees that at offset `row * self.row_stride + col`
// there is valid data.
self.data.add(row * self.row_stride + col)
}

/// returns a mutable pointer to the element at (`row`, `col`). This performs no bounds checking and providing out-of-bounds indices is UB.
unsafe fn ptr_at_mut(&mut self, row: usize, col: usize) -> *mut T {
assert!(row < self.rows);
assert!(col < self.cols);
// Safe to call under the following assertion (checked by caller)
//
// assert!(row < self.rows);
// assert!(col < self.cols);

// Safety: The structure invariant guarantees that at offset `row * self.row_stride + col`
// there is valid data.
self.data.add(row * self.row_stride + col)
}
}

// Use MatrixMut::ptr_at and MatrixMut::ptr_at_mut to implement Index and IndexMut. The latter are not unsafe, since they contain bounds-checks.

impl<T> Index<(usize, usize)> for MatrixMut<'_, T> {
type Output = T;

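For illustration (not part of the diff), a sketch contrasting the bounds-checked Index impl with the now-unsafe swap, as a hypothetical test:

#[test]
fn index_and_unsafe_swap_sketch() {
    let mut data = vec![10u64, 11, 12, 13];
    let mut m = MatrixMut::from_mut_slice(&mut data, 2, 2);
    // Indexing goes through ptr_at/ptr_at_mut but keeps the bounds checks.
    assert_eq!(m[(1, 0)], 12);
    // swap is unsafe after this change: the caller promises in-bounds coordinates.
    unsafe { m.swap((0, 0), (1, 1)) };
    assert_eq!((m[(0, 0)], m[(1, 1)]), (13, 10));
}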
23 changes: 14 additions & 9 deletions src/ntt/ntt.rs
@@ -26,10 +26,10 @@ static ENGINE_CACHE: LazyLock<Mutex<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>
/// Engine for computing NTTs over arbitrary fields.
/// Assumes the field has large two-adicity.
pub struct NttEngine<F: Field> {
order: usize,
omega_order: F,
order: usize, // order of omega_order
omega_order: F, // primitive `order`'th root of unity.

// Small roots (zero if unavailable)
// Roots of small order (zero if unavailable). The naming convention is that omega_foo has order foo.
half_omega_3_1_plus_2: F, // ½(ω₃ + ω₃²)
half_omega_3_1_min_2: F, // ½(ω₃ - ω₃²)
omega_4_1: F,
Expand All @@ -45,20 +45,23 @@ pub struct NttEngine<F: Field> {

/// Compute the NTT of a slice of field elements using a cached engine.
pub fn ntt<F: FftField>(values: &mut [F]) {
NttEngine::new_from_cache().ntt(values);
NttEngine::<F>::new_from_cache().ntt(values);
}

/// Compute many NTTs of size `size` using a cached engine.
pub fn ntt_batch<F: FftField>(values: &mut [F], size: usize) {
NttEngine::new_from_cache().ntt_batch(values, size);
NttEngine::<F>::new_from_cache().ntt_batch(values, size);
}

/// Compute the inverse NTT of a slice of field elements without the 1/n scaling factor, using a cached engine.
pub fn intt<F: FftField>(values: &mut [F]) {
NttEngine::new_from_cache().intt(values);
NttEngine::<F>::new_from_cache().intt(values);
}


/// Compute the inverse NTT of multiple slices of field elements, each of size `size`, without the 1/n scaling factor and using a cached engine.
pub fn intt_batch<F: FftField>(values: &mut [F], size: usize) {
NttEngine::new_from_cache().intt_batch(values, size);
NttEngine::<F>::new_from_cache().intt_batch(values, size);
}
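For illustration (not part of the diff), a sketch of the cached-engine round trip; it assumes `ntt`/`intt` from this module are in scope and uses ark_bn254::Fr purely as an example of a field with large two-adicity (that dependency choice is an assumption, not something this crate prescribes):

use ark_bn254::Fr; // example field; any FftField with enough two-adicity works

fn main() {
    let original: Vec<Fr> = (1u64..=8).map(Fr::from).collect();
    let mut values = original.clone();
    ntt(&mut values);  // forward NTT of length 8, in place
    intt(&mut values); // inverse NTT, deliberately without the 1/n factor
    // The round trip therefore returns n * original rather than original itself.
    let n = Fr::from(values.len() as u64);
    for (v, o) in values.iter().zip(&original) {
        assert_eq!(*v, n * *o);
    }
}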

impl<F: FftField> NttEngine<F> {
@@ -90,10 +93,11 @@ impl<F: FftField> NttEngine<F> {
}
}

/// Creates a new NttEngine. `omega_order` must be a primitive root of unity of even order `order`.
impl<F: Field> NttEngine<F> {
pub fn new(order: usize, omega_order: F) -> Self {
assert!(order.trailing_zeros() > 0, "Order must be a power of 2.");
// TODO: Assert that omega_order factors into 2s and 3s.
assert!(order.trailing_zeros() > 0, "Order must be a multiple of 2.");
// TODO: Assert that omega factors into 2s and 3s.
assert_eq!(omega_order.pow([order as u64]), F::ONE);
assert_ne!(omega_order.pow([order as u64 / 2]), F::ONE);
let mut res = NttEngine {
Expand All @@ -112,6 +116,7 @@ impl<F: Field> NttEngine<F> {
if order % 3 == 0 {
let omega_3_1 = res.root(3);
let omega_3_2 = omega_3_1 * omega_3_1;
// Note: a primitive root of unity of even order exists, so char F cannot be 2 and division by 2 is well-defined.
res.half_omega_3_1_min_2 = (omega_3_1 - omega_3_2) / F::from(2u64);
res.half_omega_3_1_plus_2 = (omega_3_1 + omega_3_2) / F::from(2u64);
}
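A remark on the radix-3 constants set here (an observation, not part of the diff): any primitive third root of unity satisfies 1 + ω₃ + ω₃² = 0, so half_omega_3_1_plus_2 = ½(ω₃ + ω₃²) always equals -½; the genuinely field-dependent constant is half_omega_3_1_min_2 = ½(ω₃ - ω₃²), whose square is -¾.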