From fa3dfd20e2e13a10f2cc44e60bad88b4067c6610 Mon Sep 17 00:00:00 2001 From: Remco Bloemen Date: Fri, 6 Sep 2024 17:59:10 +0200 Subject: [PATCH 1/4] Wavelet transform --- src/poly_utils/coeffs.rs | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/src/poly_utils/coeffs.rs b/src/poly_utils/coeffs.rs index 90f1a92..6e22054 100644 --- a/src/poly_utils/coeffs.rs +++ b/src/poly_utils/coeffs.rs @@ -1,3 +1,5 @@ +use std::ops::AddAssign; + use super::{evals::EvaluationsList, hypercube::BinaryHypercubePoint, MultilinearPoint}; use ark_ff::Field; use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial, Polynomial}; @@ -226,22 +228,24 @@ where { fn from(value: CoefficientList) -> Self { let mut evals = value.coeffs; - let num_coeffs = evals.len(); - let num_variables = value.num_variables; - - for var in 0..num_variables { - let step = 1 << var; - for i in (0..num_coeffs).step_by(step * 2) { - for j in 0..step { - if i + j + step < num_coeffs { - let sum = evals[i + j] + evals[i + j + step]; - evals[i + j + step] = sum; - } - } + wavelet_transform(&mut evals); + EvaluationsList::new(evals) + } +} + +fn wavelet_transform(values: &mut [F]) +where + F: for<'a> AddAssign<&'a F>, +{ + debug_assert!(values.len().is_power_of_two()); + eprintln!("wavelet_transform {}", values.len().trailing_zeros()); + for r in 0..values.len().trailing_zeros() { + for coeffs in values.chunks_mut(1 << (r + 1)) { + let (left, right) = coeffs.split_at_mut(1 << r); + for (left, right) in left.iter().zip(right.iter_mut()) { + *right += left; } } - - EvaluationsList::new(evals) } } From 5a211b84c12485f1c8a93781b89a2cc56b7b666a Mon Sep 17 00:00:00 2001 From: Remco Bloemen Date: Fri, 6 Sep 2024 18:58:42 +0200 Subject: [PATCH 2/4] Novel four step discrete wavelet transform --- src/crypto/ntt.rs | 64 ++++++++++++++++++++++++++++++++++++++-- src/poly_utils/coeffs.rs | 19 +----------- 2 files changed, 62 insertions(+), 21 deletions(-) diff --git a/src/crypto/ntt.rs b/src/crypto/ntt.rs index 9fe6a30..23508a2 100644 --- a/src/crypto/ntt.rs +++ b/src/crypto/ntt.rs @@ -4,9 +4,11 @@ //! A global cache is used for twiddle factors. use ark_ff::{FftField, Field}; -use std::any::{Any, TypeId}; -use std::collections::HashMap; -use std::sync::{Arc, LazyLock, Mutex, RwLock, RwLockReadGuard}; +use std::{ + any::{Any, TypeId}, + collections::HashMap, + sync::{Arc, LazyLock, Mutex, RwLock, RwLockReadGuard}, +}; #[cfg(feature = "parallel")] use { @@ -348,6 +350,62 @@ impl NttEngine { } } +/// Fast Wavelet Transform. +/// +/// The input slice must have a length that is a power of two. 
+/// Recursively applies the kernel +/// [1 0] +/// [1 1] +pub fn wavelet_transform(values: &mut [F]) { + debug_assert!(values.len().is_power_of_two()); + wavelet_transform_batch(values, values.len()) +} + +pub fn wavelet_transform_batch(values: &mut [F], size: usize) { + debug_assert_eq!(values.len() % size, 0); + debug_assert!(size.is_power_of_two()); + match size { + 0 | 1 => {} + 2 => { + for v in values.chunks_exact_mut(2) { + v[1] += v[0] + } + } + 4 => { + for v in values.chunks_exact_mut(4) { + v[1] += v[0]; + v[3] += v[2]; + v[2] += v[0]; + v[3] += v[1]; + } + } + 8 => { + for v in values.chunks_exact_mut(8) { + v[1] += v[0]; + v[3] += v[2]; + v[2] += v[0]; + v[3] += v[1]; + v[5] += v[4]; + v[7] += v[6]; + v[6] += v[4]; + v[7] += v[5]; + v[4] += v[0]; + v[5] += v[1]; + v[6] += v[2]; + v[7] += v[3]; + } + } + n => { + let n1 = 1 << (n.trailing_zeros() / 2); + let n2 = n / n1; + wavelet_transform_batch(values, n1); + transpose(values, n2, n1); + wavelet_transform_batch(values, n2); + transpose(values, n1, n2); + } + } +} + /// Transpose a matrix in-place. /// Will batch transpose multiple matrices if the length of the slice is a multiple of rows * cols. pub fn transpose(matrix: &mut [T], rows: usize, cols: usize) { diff --git a/src/poly_utils/coeffs.rs b/src/poly_utils/coeffs.rs index 6e22054..d0efcb8 100644 --- a/src/poly_utils/coeffs.rs +++ b/src/poly_utils/coeffs.rs @@ -1,6 +1,5 @@ -use std::ops::AddAssign; - use super::{evals::EvaluationsList, hypercube::BinaryHypercubePoint, MultilinearPoint}; +use crate::crypto::ntt::wavelet_transform; use ark_ff::Field; use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial, Polynomial}; #[cfg(feature = "parallel")] @@ -233,22 +232,6 @@ where } } -fn wavelet_transform(values: &mut [F]) -where - F: for<'a> AddAssign<&'a F>, -{ - debug_assert!(values.len().is_power_of_two()); - eprintln!("wavelet_transform {}", values.len().trailing_zeros()); - for r in 0..values.len().trailing_zeros() { - for coeffs in values.chunks_mut(1 << (r + 1)) { - let (left, right) = coeffs.split_at_mut(1 << r); - for (left, right) in left.iter().zip(right.iter_mut()) { - *right += left; - } - } - } -} - /* Previous recursive version impl From> for EvaluationsList where From e23a63201fc347859cf173e1936f0bc45f70f819 Mon Sep 17 00:00:00 2001 From: Giacomo Fenzi Date: Fri, 6 Sep 2024 22:23:31 +0200 Subject: [PATCH 3/4] Switch keccak for MT --- src/bin/benchmark.rs | 24 ++-- src/bin/main.rs | 24 ++-- src/cmdline_utils.rs | 6 +- src/crypto/merkle_tree/blake2.rs | 119 ------------------ src/crypto/merkle_tree/{sha3.rs => keccak.rs} | 38 +++--- src/crypto/merkle_tree/mod.rs | 3 +- 6 files changed, 47 insertions(+), 167 deletions(-) delete mode 100644 src/crypto/merkle_tree/blake2.rs rename src/crypto/merkle_tree/{sha3.rs => keccak.rs} (79%) diff --git a/src/bin/benchmark.rs b/src/bin/benchmark.rs index 4292a26..f857c13 100644 --- a/src/bin/benchmark.rs +++ b/src/bin/benchmark.rs @@ -110,9 +110,9 @@ fn main() { run_whir::>(args, leaf_hash_params, two_to_one_params); } - (AvailableFields::Goldilocks1, AvailableMerkle::SHA3) => { + (AvailableFields::Goldilocks1, AvailableMerkle::Keccak256) => { use fields::Field64 as F; - use merkle_tree::sha3 as mt; + use merkle_tree::keccak as mt; let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); run_whir::>(args, leaf_hash_params, two_to_one_params); @@ -126,9 +126,9 @@ fn main() { run_whir::>(args, leaf_hash_params, two_to_one_params); } - (AvailableFields::Goldilocks2, AvailableMerkle::SHA3) => { + 
(AvailableFields::Goldilocks2, AvailableMerkle::Keccak256) => { use fields::Field64_2 as F; - use merkle_tree::sha3 as mt; + use merkle_tree::keccak as mt; let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); run_whir::>(args, leaf_hash_params, two_to_one_params); @@ -142,9 +142,9 @@ fn main() { run_whir::>(args, leaf_hash_params, two_to_one_params); } - (AvailableFields::Goldilocks3, AvailableMerkle::SHA3) => { + (AvailableFields::Goldilocks3, AvailableMerkle::Keccak256) => { use fields::Field64_3 as F; - use merkle_tree::sha3 as mt; + use merkle_tree::keccak as mt; let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); run_whir::>(args, leaf_hash_params, two_to_one_params); @@ -158,9 +158,9 @@ fn main() { run_whir::>(args, leaf_hash_params, two_to_one_params); } - (AvailableFields::Field128, AvailableMerkle::SHA3) => { + (AvailableFields::Field128, AvailableMerkle::Keccak256) => { use fields::Field128 as F; - use merkle_tree::sha3 as mt; + use merkle_tree::keccak as mt; let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); run_whir::>(args, leaf_hash_params, two_to_one_params); @@ -174,9 +174,9 @@ fn main() { run_whir::>(args, leaf_hash_params, two_to_one_params); } - (AvailableFields::Field192, AvailableMerkle::SHA3) => { + (AvailableFields::Field192, AvailableMerkle::Keccak256) => { use fields::Field192 as F; - use merkle_tree::sha3 as mt; + use merkle_tree::keccak as mt; let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); run_whir::>(args, leaf_hash_params, two_to_one_params); @@ -190,9 +190,9 @@ fn main() { run_whir::>(args, leaf_hash_params, two_to_one_params); } - (AvailableFields::Field256, AvailableMerkle::SHA3) => { + (AvailableFields::Field256, AvailableMerkle::Keccak256) => { use fields::Field256 as F; - use merkle_tree::sha3 as mt; + use merkle_tree::keccak as mt; let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); run_whir::>(args, leaf_hash_params, two_to_one_params); diff --git a/src/bin/main.rs b/src/bin/main.rs index 9322340..7dd16f3 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -79,9 +79,9 @@ fn main() { run_whir::>(args, leaf_hash_params, two_to_one_params); } - (AvailableFields::Goldilocks1, AvailableMerkle::SHA3) => { + (AvailableFields::Goldilocks1, AvailableMerkle::Keccak256) => { use fields::Field64 as F; - use merkle_tree::sha3 as mt; + use merkle_tree::keccak as mt; let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); run_whir::>(args, leaf_hash_params, two_to_one_params); @@ -95,9 +95,9 @@ fn main() { run_whir::>(args, leaf_hash_params, two_to_one_params); } - (AvailableFields::Goldilocks2, AvailableMerkle::SHA3) => { + (AvailableFields::Goldilocks2, AvailableMerkle::Keccak256) => { use fields::Field64_2 as F; - use merkle_tree::sha3 as mt; + use merkle_tree::keccak as mt; let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); run_whir::>(args, leaf_hash_params, two_to_one_params); @@ -111,9 +111,9 @@ fn main() { run_whir::>(args, leaf_hash_params, two_to_one_params); } - (AvailableFields::Goldilocks3, AvailableMerkle::SHA3) => { + (AvailableFields::Goldilocks3, AvailableMerkle::Keccak256) => { use fields::Field64_3 as F; - use merkle_tree::sha3 as mt; + use merkle_tree::keccak as mt; let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); run_whir::>(args, leaf_hash_params, two_to_one_params); @@ -127,9 +127,9 @@ fn main() { run_whir::>(args, leaf_hash_params, 
two_to_one_params); } - (AvailableFields::Field128, AvailableMerkle::SHA3) => { + (AvailableFields::Field128, AvailableMerkle::Keccak256) => { use fields::Field128 as F; - use merkle_tree::sha3 as mt; + use merkle_tree::keccak as mt; let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); run_whir::>(args, leaf_hash_params, two_to_one_params); @@ -143,9 +143,9 @@ fn main() { run_whir::>(args, leaf_hash_params, two_to_one_params); } - (AvailableFields::Field192, AvailableMerkle::SHA3) => { + (AvailableFields::Field192, AvailableMerkle::Keccak256) => { use fields::Field192 as F; - use merkle_tree::sha3 as mt; + use merkle_tree::keccak as mt; let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); run_whir::>(args, leaf_hash_params, two_to_one_params); @@ -159,9 +159,9 @@ fn main() { run_whir::>(args, leaf_hash_params, two_to_one_params); } - (AvailableFields::Field256, AvailableMerkle::SHA3) => { + (AvailableFields::Field256, AvailableMerkle::Keccak256) => { use fields::Field256 as F; - use merkle_tree::sha3 as mt; + use merkle_tree::keccak as mt; let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); run_whir::>(args, leaf_hash_params, two_to_one_params); diff --git a/src/cmdline_utils.rs b/src/cmdline_utils.rs index a2cc1e9..9572fa3 100644 --- a/src/cmdline_utils.rs +++ b/src/cmdline_utils.rs @@ -56,7 +56,7 @@ impl FromStr for AvailableFields { #[derive(Debug, Clone, Copy, Serialize)] pub enum AvailableMerkle { - SHA3, + Keccak256, Blake3, } @@ -64,8 +64,8 @@ impl FromStr for AvailableMerkle { type Err = String; fn from_str(s: &str) -> Result { - if s == "SHA3" { - Ok(Self::SHA3) + if s == "Keccak" { + Ok(Self::Keccak256) } else if s == "Blake3" { Ok(Self::Blake3) } else { diff --git a/src/crypto/merkle_tree/blake2.rs b/src/crypto/merkle_tree/blake2.rs deleted file mode 100644 index f451ee1..0000000 --- a/src/crypto/merkle_tree/blake2.rs +++ /dev/null @@ -1,119 +0,0 @@ -use std::{borrow::Borrow, marker::PhantomData}; - -use ark_crypto_primitives::{ - crh::{CRHScheme, TwoToOneCRHScheme}, - merkle_tree::{Config, IdentityDigestConverter}, - sponge::Absorb, -}; -use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; -use blake2::Digest; -use rand::RngCore; - -use super::HashCounter; - -#[derive( - Debug, Default, Clone, Copy, Eq, PartialEq, Hash, CanonicalSerialize, CanonicalDeserialize, -)] -pub struct Blake2Digest([u8; 32]); - -impl Absorb for Blake2Digest { - fn to_sponge_bytes(&self, dest: &mut Vec) { - dest.extend_from_slice(&self.0); - } - - fn to_sponge_field_elements(&self, dest: &mut Vec) { - let mut buf = [0; 32]; - buf.copy_from_slice(&self.0); - dest.push(F::from_be_bytes_mod_order(&buf)); - } -} - -pub struct Blake2LeafHash(PhantomData); -pub struct Blake2TwoToOneCRHScheme; - -impl CRHScheme for Blake2LeafHash { - type Input = Vec; - type Output = Blake2Digest; - type Parameters = (); - - fn setup(_: &mut R) -> Result { - Ok(()) - } - - fn evaluate>( - _: &Self::Parameters, - input: T, - ) -> Result { - let mut buf = vec![]; - CanonicalSerialize::serialize_compressed(input.borrow(), &mut buf)?; - - let mut h = blake2::Blake2s256::new(); - h.update(&buf); - - let mut output = [0; 32]; - output.copy_from_slice(&h.finalize()[..]); - HashCounter::add(); - Ok(Blake2Digest(output)) - } -} - -impl TwoToOneCRHScheme for Blake2TwoToOneCRHScheme { - type Input = Blake2Digest; - type Output = Blake2Digest; - type Parameters = (); - - fn setup(_: &mut R) -> Result { - Ok(()) - } - - fn evaluate>( - _: &Self::Parameters, - 
left_input: T, - right_input: T, - ) -> Result { - let mut h = blake2::Blake2s256::new(); - h.update(&left_input.borrow().0); - h.update(&right_input.borrow().0); - let mut output = [0; 32]; - output.copy_from_slice(&h.finalize()[..]); - HashCounter::add(); - Ok(Blake2Digest(output)) - } - - fn compress>( - parameters: &Self::Parameters, - left_input: T, - right_input: T, - ) -> Result { - ::evaluate(parameters, left_input, right_input) - } -} - -pub type LeafH = Blake2LeafHash; -pub type CompressH = Blake2TwoToOneCRHScheme; - -#[derive(Debug, Default, Clone)] -pub struct MerkleTreeParams(PhantomData); - -impl Config for MerkleTreeParams { - type Leaf = Vec; - - type LeafDigest = as CRHScheme>::Output; - type LeafInnerDigestConverter = IdentityDigestConverter; - type InnerDigest = ::Output; - - type LeafHash = LeafH; - type TwoToOneHash = CompressH; -} - -pub fn default_config( - rng: &mut impl RngCore, -) -> ( - as CRHScheme>::Parameters, - ::Parameters, -) { - let leaf_hash_params = as CRHScheme>::setup(rng).unwrap(); - let two_to_one_params = ::setup(rng).unwrap(); - - (leaf_hash_params, two_to_one_params) -} diff --git a/src/crypto/merkle_tree/sha3.rs b/src/crypto/merkle_tree/keccak.rs similarity index 79% rename from src/crypto/merkle_tree/sha3.rs rename to src/crypto/merkle_tree/keccak.rs index 289e414..2e67c18 100644 --- a/src/crypto/merkle_tree/sha3.rs +++ b/src/crypto/merkle_tree/keccak.rs @@ -13,9 +13,9 @@ use sha3::Digest; #[derive( Debug, Default, Clone, Copy, Eq, PartialEq, Hash, CanonicalSerialize, CanonicalDeserialize, )] -pub struct SHA3Digest([u8; 32]); +pub struct KeccakDigest([u8; 32]); -impl Absorb for SHA3Digest { +impl Absorb for KeccakDigest { fn to_sponge_bytes(&self, dest: &mut Vec) { dest.extend_from_slice(&self.0); } @@ -27,24 +27,24 @@ impl Absorb for SHA3Digest { } } -impl From<[u8; 32]> for SHA3Digest { +impl From<[u8; 32]> for KeccakDigest { fn from(value: [u8; 32]) -> Self { - SHA3Digest(value) + KeccakDigest(value) } } -impl AsRef<[u8]> for SHA3Digest { +impl AsRef<[u8]> for KeccakDigest { fn as_ref(&self) -> &[u8] { &self.0 } } -pub struct SHA3LeafHash(PhantomData); -pub struct SHA3TwoToOneCRHScheme; +pub struct KeccakLeafHash(PhantomData); +pub struct KeccakTwoToOneCRHScheme; -impl CRHScheme for SHA3LeafHash { +impl CRHScheme for KeccakLeafHash { type Input = [F]; - type Output = SHA3Digest; + type Output = KeccakDigest; type Parameters = (); fn setup(_: &mut R) -> Result { @@ -58,19 +58,19 @@ impl CRHScheme for SHA3LeafHash { let mut buf = vec![]; CanonicalSerialize::serialize_compressed(input.borrow(), &mut buf)?; - let mut h = sha3::Sha3_256::new(); + let mut h = sha3::Keccak256::new(); h.update(&buf); let mut output = [0; 32]; output.copy_from_slice(&h.finalize()[..]); HashCounter::add(); - Ok(SHA3Digest(output)) + Ok(KeccakDigest(output)) } } -impl TwoToOneCRHScheme for SHA3TwoToOneCRHScheme { - type Input = SHA3Digest; - type Output = SHA3Digest; +impl TwoToOneCRHScheme for KeccakTwoToOneCRHScheme { + type Input = KeccakDigest; + type Output = KeccakDigest; type Parameters = (); fn setup(_: &mut R) -> Result { @@ -82,13 +82,13 @@ impl TwoToOneCRHScheme for SHA3TwoToOneCRHScheme { left_input: T, right_input: T, ) -> Result { - let mut h = sha3::Sha3_256::new(); + let mut h = sha3::Keccak256::new(); h.update(&left_input.borrow().0); h.update(&right_input.borrow().0); let mut output = [0; 32]; output.copy_from_slice(&h.finalize()[..]); HashCounter::add(); - Ok(SHA3Digest(output)) + Ok(KeccakDigest(output)) } fn compress>( @@ -100,8 +100,8 @@ impl 
TwoToOneCRHScheme for SHA3TwoToOneCRHScheme { } } -pub type LeafH = SHA3LeafHash; -pub type CompressH = SHA3TwoToOneCRHScheme; +pub type LeafH = KeccakLeafHash; +pub type CompressH = KeccakTwoToOneCRHScheme; #[derive(Debug, Default, Clone)] pub struct MerkleTreeParams(PhantomData); @@ -110,7 +110,7 @@ impl Config for MerkleTreeParams { type Leaf = [F]; type LeafDigest = as CRHScheme>::Output; - type LeafInnerDigestConverter = IdentityDigestConverter; + type LeafInnerDigestConverter = IdentityDigestConverter; type InnerDigest = ::Output; type LeafHash = LeafH; diff --git a/src/crypto/merkle_tree/mod.rs b/src/crypto/merkle_tree/mod.rs index d057571..f316fea 100644 --- a/src/crypto/merkle_tree/mod.rs +++ b/src/crypto/merkle_tree/mod.rs @@ -1,7 +1,6 @@ -pub mod blake2; pub mod blake3; +pub mod keccak; pub mod mock; -pub mod sha3; use std::{borrow::Borrow, marker::PhantomData, sync::atomic::AtomicUsize}; From 87b2468a5a888f89b09777cea9a67e16c22d87e9 Mon Sep 17 00:00:00 2001 From: Remco Bloemen Date: Mon, 9 Sep 2024 11:56:42 +0200 Subject: [PATCH 4/4] Parallel wavelet transform and cache-oblivious square transpose --- src/crypto/ntt.rs | 96 +++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 85 insertions(+), 11 deletions(-) diff --git a/src/crypto/ntt.rs b/src/crypto/ntt.rs index 23508a2..bdc6c8d 100644 --- a/src/crypto/ntt.rs +++ b/src/crypto/ntt.rs @@ -7,6 +7,7 @@ use ark_ff::{FftField, Field}; use std::{ any::{Any, TypeId}, collections::HashMap, + mem::swap, sync::{Arc, LazyLock, Mutex, RwLock, RwLockReadGuard}, }; @@ -364,6 +365,15 @@ pub fn wavelet_transform(values: &mut [F]) { pub fn wavelet_transform_batch(values: &mut [F], size: usize) { debug_assert_eq!(values.len() % size, 0); debug_assert!(size.is_power_of_two()); + #[cfg(feature = "parallel")] + if values.len() > NttEngine::::WORKLOAD_SIZE && values.len() != size { + // Multiple wavelet transforms, compute in parallel. + // Work size is largest multiple of `size` smaller than `WORKLOAD_SIZE`. + let workload_size = size * max(1, NttEngine::::WORKLOAD_SIZE / size); + return values.par_chunks_mut(workload_size).for_each(|values| { + wavelet_transform_batch(values, size); + }); + } match size { 0 | 1 => {} 2 => { @@ -395,6 +405,25 @@ pub fn wavelet_transform_batch(values: &mut [F], size: usize) { v[7] += v[3]; } } + 16 => { + for v in values.chunks_exact_mut(16) { + for v in v.chunks_exact_mut(4) { + v[1] += v[0]; + v[3] += v[2]; + v[2] += v[0]; + v[3] += v[1]; + } + let (a, v) = v.split_at_mut(4); + let (b, v) = v.split_at_mut(4); + let (c, d) = v.split_at_mut(4); + for i in 0..4 { + b[i] += a[i]; + d[i] += c[i]; + c[i] += a[i]; + d[i] += b[i]; + } + } + } n => { let n1 = 1 << (n.trailing_zeros() / 2); let n2 = n / n1; @@ -408,32 +437,77 @@ pub fn wavelet_transform_batch(values: &mut [F], size: usize) { /// Transpose a matrix in-place. /// Will batch transpose multiple matrices if the length of the slice is a multiple of rows * cols. -pub fn transpose(matrix: &mut [T], rows: usize, cols: usize) { +pub fn transpose(matrix: &mut [F], rows: usize, cols: usize) { debug_assert_eq!(matrix.len() % rows * cols, 0); if rows == cols { - // TODO: Cache-oblivious recursive parallel algorithm. for matrix in matrix.chunks_exact_mut(rows * cols) { - for i in 0..rows { - for j in (i + 1)..cols { - matrix.swap(i * cols + j, j * rows + i); - } - } + transpose_square(matrix, rows, cols); } } else { - // TODO: Re-use scratch space. - // TODO: Cache-oblivious recursive parallel algorithm. 
// TODO: Special case for rows = 2 * cols and cols = 2 * rows. + let mut scratch = vec![F::ZERO; rows * cols]; for matrix in matrix.chunks_exact_mut(rows * cols) { - let copy = matrix.to_vec(); + scratch.copy_from_slice(matrix); for i in 0..rows { for j in 0..cols { - matrix[j * rows + i] = copy[i * cols + j]; + matrix[j * rows + i] = scratch[i * cols + j]; } } } } } +// Transpose a square power-of-two matrix in-place. +fn transpose_square(matrix: &mut [F], size: usize, stride: usize) { + debug_assert!(matrix.len() >= (size - 1) * stride + size); + debug_assert!(size.is_power_of_two()); + if size * size > NttEngine::::WORKLOAD_SIZE { + // Recurse into quadrants. + // This results in a cache-oblivious algorithm. + let n = size / 2; + let (upper, lower) = matrix.split_at_mut(n * stride); + // Ideally we'd parallelize this, but its not possible to + // express the strided matrices without unsafe code. + transpose_square(upper, n, stride); + transpose_square_swap(&mut upper[n..], lower, n, stride); + transpose_square(&mut lower[n..], n, stride); + } else { + for i in 0..size { + for j in (i + 1)..size { + matrix.swap(i * stride + j, j * stride + i); + } + } + } +} + +/// Transpose and swap two square power-of-two size matrices. +fn transpose_square_swap(a: &mut [F], b: &mut [F], size: usize, stride: usize) { + debug_assert!(a.len() >= (size - 1) * stride + size); + debug_assert!(b.len() >= (size - 1) * stride + size); + debug_assert!(size.is_power_of_two()); + if size * size > NttEngine::::WORKLOAD_SIZE { + // Recurse into quadrants. + // This results in a cache-oblivious algorithm. + let n = size / 2; + let (a_upper, a_lower) = a.split_at_mut(n * stride); + let (b_upper, b_lower) = b.split_at_mut(n * stride); + // Ideally we'd parallelize this, but its not possible to + // express the strided matrices without unsafe code. + transpose_square_swap(a_upper, b_upper, n, stride); + transpose_square_swap(&mut a_upper[n..], b_lower, n, stride); + transpose_square_swap(a_lower, &mut b_upper[n..], n, stride); + transpose_square_swap(&mut a_lower[n..], &mut b_lower[n..], n, stride); + } else { + for i in 0..size { + for j in 0..size { + // The compiler does not eliminate the bounds checks here, + // but this doesn't matter as it is bottlenecked by memory bandwidth. + swap(&mut a[i * stride + j], &mut b[j * stride + i]); + } + } + } +} + /// Compute the largest factor of n that is <= sqrt(n). /// Assumes n is of the form 2^k * {1,3,9}. fn sqrt_factor(n: usize) -> usize {
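Note on PATCH 1/4 and 2/4 (editor's sketch, not part of the patches): the coefficient-to-evaluation conversion applies the kernel [1 0; 1 1] to every bit of the index, and the "four step" version in wavelet_transform_batch splits a length-n transform into n1 x n2 blocks separated by transposes, in the style of the four-step NTT. The standalone program below, using u64 with wrapping addition as a stand-in for field addition, checks that the blocked decomposition agrees with the simple radix-2 loop from PATCH 1/4.

// Reference radix-2 wavelet transform, as in PATCH 1/4.
fn wavelet_reference(values: &mut [u64]) {
    debug_assert!(values.len().is_power_of_two());
    for r in 0..values.len().trailing_zeros() {
        for chunk in values.chunks_mut(1 << (r + 1)) {
            let (lo, hi) = chunk.split_at_mut(1 << r);
            for (lo, hi) in lo.iter().zip(hi.iter_mut()) {
                *hi = hi.wrapping_add(*lo);
            }
        }
    }
}

// Out-of-place transpose of a rows x cols row-major matrix, matching the
// non-square branch of `transpose` in the patches.
fn transpose_copy(values: &mut [u64], rows: usize, cols: usize) {
    let copy = values.to_vec();
    for i in 0..rows {
        for j in 0..cols {
            values[j * rows + i] = copy[i * cols + j];
        }
    }
}

// Four-step variant: transform the low bits, transpose, transform the high
// bits, transpose back. Mirrors the `n =>` arm of `wavelet_transform_batch`.
fn wavelet_four_step(values: &mut [u64]) {
    let n = values.len();
    debug_assert!(n.is_power_of_two());
    if n <= 2 {
        return wavelet_reference(values);
    }
    let n1 = 1usize << (n.trailing_zeros() / 2);
    let n2 = n / n1;
    for block in values.chunks_exact_mut(n1) {
        wavelet_reference(block);
    }
    transpose_copy(values, n2, n1);
    for block in values.chunks_exact_mut(n2) {
        wavelet_reference(block);
    }
    transpose_copy(values, n1, n2);
}

fn main() {
    let mut a: Vec<u64> = (0u64..64).map(|i| 3 * i + 1).collect();
    let mut b = a.clone();
    wavelet_reference(&mut a);
    wavelet_four_step(&mut b);
    assert_eq!(a, b);
    println!("four-step wavelet transform matches the radix-2 reference");
}

Because the kernel acts independently on each bit of the index, the order in which the bits are processed does not matter; the two transposes only move the untouched high bits into contiguous position and back.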
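Note on PATCH 3/4 (editor's sketch, not part of the patches): the Merkle tree hasher switches from sha3::Sha3_256 to sha3::Keccak256, with the module renamed sha3.rs -> keccak.rs and the CLI option changing from "SHA3" to "Keccak". The two hash functions share the Keccak-f[1600] permutation but use different padding (SHA-3 adds a domain-separation suffix), so they are not interchangeable, as the small check below illustrates using the same sha3 crate API the patch relies on.

use sha3::{Digest, Keccak256, Sha3_256};

fn main() {
    // Hypothetical input; any byte string shows the difference.
    let data = b"whir merkle leaf";
    let keccak = Keccak256::digest(data);
    let sha3 = Sha3_256::digest(data);
    // Same permutation, different padding => different digests.
    assert_ne!(keccak, sha3);
    println!("Keccak-256: {:02x?}", keccak.as_slice());
    println!("SHA3-256:   {:02x?}", sha3.as_slice());
}

Callers that previously passed "SHA3" on the command line must now pass "Keccak" to select the AvailableMerkle::Keccak256 variant.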
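Note on PATCH 4/4 (editor's sketch, not part of the patches): the in-place square transpose recurses into quadrants, transposing the diagonal blocks in place and swap-transposing the off-diagonal pair, which keeps each leaf working set small without tuning for a particular cache size. The sketch below mirrors that recursion with a small fixed tile size standing in for the NttEngine::<F>::WORKLOAD_SIZE heuristic, and verifies it against a naive out-of-place transpose; element (i, j) lives at index i * stride + j.

use std::mem::swap;

const TILE: usize = 4; // recurse until quadrants are at most TILE x TILE

fn transpose_square(m: &mut [u64], size: usize, stride: usize) {
    debug_assert!(size.is_power_of_two());
    if size > TILE {
        let n = size / 2;
        let (upper, lower) = m.split_at_mut(n * stride);
        transpose_square(upper, n, stride);                       // top-left in place
        transpose_square_swap(&mut upper[n..], lower, n, stride); // top-right <-> bottom-left
        transpose_square(&mut lower[n..], n, stride);             // bottom-right in place
    } else {
        for i in 0..size {
            for j in (i + 1)..size {
                m.swap(i * stride + j, j * stride + i);
            }
        }
    }
}

// Transpose-and-swap two square blocks across the main diagonal.
fn transpose_square_swap(a: &mut [u64], b: &mut [u64], size: usize, stride: usize) {
    if size > TILE {
        let n = size / 2;
        let (au, al) = a.split_at_mut(n * stride);
        let (bu, bl) = b.split_at_mut(n * stride);
        transpose_square_swap(au, bu, n, stride);                     // A11 <-> B11
        transpose_square_swap(&mut au[n..], bl, n, stride);           // A12 <-> B21
        transpose_square_swap(al, &mut bu[n..], n, stride);           // A21 <-> B12
        transpose_square_swap(&mut al[n..], &mut bl[n..], n, stride); // A22 <-> B22
    } else {
        for i in 0..size {
            for j in 0..size {
                swap(&mut a[i * stride + j], &mut b[j * stride + i]);
            }
        }
    }
}

fn main() {
    let n = 16;
    let mut m: Vec<u64> = (0..(n * n) as u64).collect();
    // Naive transpose of the row-major matrix m[i][j] = i * n + j.
    let expected: Vec<u64> = (0..n * n).map(|k| ((k % n) * n + k / n) as u64).collect();
    transpose_square(&mut m, n, n);
    assert_eq!(m, expected);
    println!("cache-oblivious transpose matches the naive one");
}

The non-square path in the patch still round-trips through a scratch buffer (now reused across batches instead of reallocated per matrix), as flagged by the remaining TODO for the rows = 2 * cols and cols = 2 * rows special cases.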