From 89f2e8d4f47d64ac6fe59f2d3b6713e1c03d7ef3 Mon Sep 17 00:00:00 2001 From: sarah el kazdadi Date: Mon, 6 Feb 2023 12:47:05 +0100 Subject: [PATCH] feat: implement 128bit fft --- Cargo.toml | 13 +- README.md | 57 +- benches/bench.rs | 308 ++++++ benches/fft.rs | 190 ---- benches/lib.rs | 3 - src/fft128/f128_impl.rs | 1075 +++++++++++++++++++ src/fft128/mod.rs | 2204 +++++++++++++++++++++++++++++++++++++++ src/lib.rs | 7 + src/ordered.rs | 2 +- src/unordered.rs | 8 +- 10 files changed, 3640 insertions(+), 227 deletions(-) create mode 100644 benches/bench.rs delete mode 100644 benches/fft.rs delete mode 100644 benches/lib.rs create mode 100644 src/fft128/f128_impl.rs create mode 100644 src/fft128/mod.rs diff --git a/Cargo.toml b/Cargo.toml index c6675ff..c52fb55 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,13 +13,16 @@ keywords = ["fft"] [dependencies] num-complex = "0.4" dyn-stack = { version = "0.8", default-features = false } +pulp = "0.10" +bytemuck = "1.13" aligned-vec = { version = "0.5", default-features = false } serde = { version = "1.0", optional = true, default-features = false } [features] -default = ["std"] -nightly = [] +default = ["std", "fft128"] +nightly = ["pulp/nightly", "bytemuck/nightly_stdsimd"] std = [] +fft128 = [] serde = ["dep:serde", "num-complex/serde"] [dev-dependencies] @@ -28,9 +31,13 @@ rustfft = "6.0" fftw-sys = { version = "0.6", default-features = false, features = ["system"] } rand = "0.8" bincode = "1.3" +more-asserts = "0.3.1" + +[target.'cfg(target_os = "linux")'.dev-dependencies] +rug = "1.19.0" [[bench]] -name = "fft" +name = "bench" harness = false [package.metadata.docs.rs] diff --git a/README.md b/README.md index f3002c6..2cec244 100644 --- a/README.md +++ b/README.md @@ -3,33 +3,38 @@ that processes vectors of sizes that are powers of two. It was made to be used as a backend in Zama's `concrete` library. This library provides two FFT modules: - - The ordered module FFT applies a forward/inverse FFT that takes its input in standard - order, and outputs the result in standard order. For more detail on what the FFT - computes, check the ordered module-level documentation. - - The unordered module FFT applies a forward FFT that takes its input in standard order, - and outputs the result in a certain permuted order that may depend on the FFT plan. On the - other hand, the inverse FFT takes its input in that same permuted order and outputs its result - in standard order. This is useful for cases where the order of the coefficients in the - Fourier domain is not important. An example is using the Fourier transform for vector - convolution. The only operations that are performed in the Fourier domain are elementwise, and - so the order of the coefficients does not affect the results. + +- The ordered module FFT applies a forward/inverse FFT that takes its input in standard + order, and outputs the result in standard order. For more detail on what the FFT + computes, check the ordered module-level documentation. +- The unordered module FFT applies a forward FFT that takes its input in standard order, + and outputs the result in a certain permuted order that may depend on the FFT plan. On the + other hand, the inverse FFT takes its input in that same permuted order and outputs its result + in standard order. This is useful for cases where the order of the coefficients in the + Fourier domain is not important. An example is using the Fourier transform for vector + convolution. 
The only operations that are performed in the Fourier domain are elementwise, and
+  so the order of the coefficients does not affect the results.
+
+Additionally, an optional 128-bit negacyclic FFT module is provided.
 
 ## Features
 
- - `std` (default): This enables runtime arch detection for accelerated SIMD
-   instructions, and an FFT plan that measures the various implementations to
-   choose the fastest one at runtime.
- - `nightly`: This enables unstable Rust features to further speed up the FFT,
-   by enabling AVX512F instructions on CPUs that support them. This feature
-   requires a nightly Rust
-   toolchain.
- - `serde`: This enables serialization and deserialization functions for the
-   unordered plan. These allow for data in the Fourier domain to be serialized
-   from the permuted order to the standard order, and deserialized from the
-   standard order to the permuted order. This is needed since the inverse
-   transform must be used with the same plan that computed/deserialized the
-   forward transform (or more specifically, a plan with the same internal base
-   FFT size).
+- `std` (default): This enables runtime arch detection for accelerated SIMD
+  instructions, and an FFT plan that measures the various implementations to
+  choose the fastest one at runtime.
+- `fft128` (default): This enables the 128-bit FFT, which is exposed in the
+  `concrete_fft::fft128` module.
+- `nightly`: This enables unstable Rust features to further speed up the FFT,
+  by enabling AVX512F instructions on CPUs that support them. This feature
+  requires a nightly Rust
+  toolchain.
+- `serde`: This enables serialization and deserialization functions for the
+  unordered plan. These allow for data in the Fourier domain to be serialized
+  from the permuted order to the standard order, and deserialized from the
+  standard order to the permuted order. This is needed since the inverse
+  transform must be used with the same plan that computed/deserialized the
+  forward transform (or more specifically, a plan with the same internal base
+  FFT size).
 
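The 128-bit FFT (enabled by the default `fft128` feature) takes its data in split form: four `f64` slices holding the high and low limbs of the real and imaginary parts. Below is a minimal sketch of the calling convention, taken from the benchmark added in this patch; the twiddle tables are zero-filled placeholders rather than valid twiddle factors, so it only illustrates the argument layout.

```rust
use concrete_fft::fft128::negacyclic_fwd_fft_scalar;

let n = 1024;
// High/low limbs of the real and imaginary parts of the data.
let mut data_re0 = vec![0.0; n];
let mut data_re1 = vec![0.0; n];
let mut data_im0 = vec![0.0; n];
let mut data_im1 = vec![0.0; n];
// Twiddle tables, also split into high/low limbs (placeholders here; a real
// transform needs the proper negacyclic twiddle factors).
let twid_re0 = vec![0.0; n];
let twid_re1 = vec![0.0; n];
let twid_im0 = vec![0.0; n];
let twid_im1 = vec![0.0; n];

negacyclic_fwd_fft_scalar(
    &mut data_re0, &mut data_re1, &mut data_im0, &mut data_im1,
    &twid_re0, &twid_re1, &twid_im0, &twid_im1,
);
```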
## Example @@ -65,8 +70,8 @@ for (actual, expected) in transformed_inv.iter().map(|z| z / N as f64).zip(data) ## Links - - [Zama](https://www.zama.ai/) - - [Concrete](https://github.com/zama-ai/concrete) +- [Zama](https://www.zama.ai/) +- [Concrete](https://github.com/zama-ai/concrete) ## License diff --git a/benches/bench.rs b/benches/bench.rs new file mode 100644 index 0000000..53aac52 --- /dev/null +++ b/benches/bench.rs @@ -0,0 +1,308 @@ +use concrete_fft::c64; +use core::ptr::NonNull; +use criterion::{criterion_group, criterion_main, Criterion}; +use dyn_stack::{DynStack, ReborrowMut, StackReq}; + +struct FftwAlloc { + bytes: NonNull, +} + +impl Drop for FftwAlloc { + fn drop(&mut self) { + unsafe { + fftw_sys::fftw_free(self.bytes.as_ptr()); + } + } +} + +impl FftwAlloc { + pub fn new(size_bytes: usize) -> FftwAlloc { + unsafe { + let bytes = fftw_sys::fftw_malloc(size_bytes); + if bytes.is_null() { + use std::alloc::{handle_alloc_error, Layout}; + handle_alloc_error(Layout::from_size_align_unchecked(size_bytes, 1)); + } + FftwAlloc { + bytes: NonNull::new_unchecked(bytes), + } + } + } +} + +pub struct PlanInterleavedC64 { + plan: fftw_sys::fftw_plan, + n: usize, +} + +impl Drop for PlanInterleavedC64 { + fn drop(&mut self) { + unsafe { + fftw_sys::fftw_destroy_plan(self.plan); + } + } +} + +pub enum Sign { + Forward, + Backward, +} + +impl PlanInterleavedC64 { + pub fn new(n: usize, sign: Sign) -> Self { + let size_bytes = n.checked_mul(core::mem::size_of::()).unwrap(); + let src = FftwAlloc::new(size_bytes); + let dst = FftwAlloc::new(size_bytes); + unsafe { + let p = fftw_sys::fftw_plan_dft_1d( + n.try_into().unwrap(), + src.bytes.as_ptr() as _, + dst.bytes.as_ptr() as _, + match sign { + Sign::Forward => fftw_sys::FFTW_FORWARD as _, + Sign::Backward => fftw_sys::FFTW_BACKWARD as _, + }, + fftw_sys::FFTW_MEASURE, + ); + PlanInterleavedC64 { plan: p, n } + } + } + + pub fn print(&self) { + unsafe { + fftw_sys::fftw_print_plan(self.plan); + } + } + + pub fn execute(&self, src: &mut [c64], dst: &mut [c64]) { + assert_eq!(src.len(), self.n); + assert_eq!(dst.len(), self.n); + let src = src.as_mut_ptr(); + let dst = dst.as_mut_ptr(); + unsafe { + use fftw_sys::{fftw_alignment_of, fftw_execute_dft}; + assert_eq!(fftw_alignment_of(src as _), 0); + assert_eq!(fftw_alignment_of(dst as _), 0); + fftw_execute_dft(self.plan, src as _, dst as _); + } + } +} + +pub fn criterion_benchmark(c: &mut Criterion) { + for n in [ + 1 << 8, + 1 << 9, + 1 << 10, + 1 << 11, + 1 << 12, + 1 << 13, + 1 << 14, + 1 << 15, + 1 << 16, + ] { + let mut mem = dyn_stack::GlobalMemBuffer::new( + StackReq::new_aligned::(n, 64) // scratch + .and( + StackReq::new_aligned::(2 * n, 64).or(StackReq::new_aligned::(n, 64)), // src | twiddles + ) + .and(StackReq::new_aligned::(n, 64)), // dst + ); + let mut stack = DynStack::new(&mut mem); + let z = c64::new(0.0, 0.0); + + { + let mut scratch = []; + let bench_duration = std::time::Duration::from_millis(10); + + { + let (mut dst, stack) = stack.rb_mut().make_aligned_with::(n, 64, |_| z); + let (mut src, _) = stack.make_aligned_with::(n, 64, |_| z); + + c.bench_function(&format!("rustfft-fwd-{n}"), |b| { + use rustfft::FftPlannerAvx; + let mut planner = FftPlannerAvx::::new().unwrap(); + let fwd_rustfft = planner.plan_fft_forward(n); + + b.iter(|| { + fwd_rustfft.process_outofplace_with_scratch( + &mut src, + &mut dst, + &mut scratch, + ) + }) + }); + + c.bench_function(&format!("fftw-fwd-{n}"), |b| { + let fwd_fftw = PlanInterleavedC64::new(n, Sign::Forward); + + b.iter(|| { + 
fwd_fftw.execute(&mut src, &mut dst); + }) + }); + } + { + let (mut dst, mut stack) = stack.rb_mut().make_aligned_with::(n, 64, |_| z); + + c.bench_function(&format!("concrete-fwd-{n}"), |b| { + let ordered = concrete_fft::ordered::Plan::new( + n, + concrete_fft::ordered::Method::Measure(bench_duration), + ); + + b.iter(|| ordered.fwd(&mut dst, stack.rb_mut())) + }); + } + { + let (mut dst, mut stack) = stack.rb_mut().make_aligned_with::(n, 64, |_| z); + + c.bench_function(&format!("unordered-fwd-{n}"), |b| { + let unordered = concrete_fft::unordered::Plan::new( + n, + concrete_fft::unordered::Method::Measure(bench_duration), + ); + + b.iter(|| unordered.fwd(&mut dst, stack.rb_mut())); + }); + } + { + let (mut dst, mut stack) = stack.rb_mut().make_aligned_with::(n, 64, |_| z); + + c.bench_function(&format!("unordered-inv-{n}"), |b| { + let unordered = concrete_fft::unordered::Plan::new( + n, + concrete_fft::unordered::Method::Measure(bench_duration), + ); + + b.iter(|| unordered.inv(&mut dst, stack.rb_mut())); + }); + } + } + + // memcpy + { + let (mut dst, stack) = stack.rb_mut().make_aligned_with::(n, 64, |_| z); + let (src, _) = stack.make_aligned_with::(n, 64, |_| z); + + c.bench_function(&format!("memcpy-{n}"), |b| { + b.iter(|| unsafe { + std::ptr::copy_nonoverlapping(src.as_ptr(), dst.as_mut_ptr(), n); + }) + }); + } + } + + use concrete_fft::fft128::*; + for n in [64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384] { + let twid_re0 = vec![0.0; n]; + let twid_re1 = vec![0.0; n]; + let twid_im0 = vec![0.0; n]; + let twid_im1 = vec![0.0; n]; + + let mut data_re0 = vec![0.0; n]; + let mut data_re1 = vec![0.0; n]; + let mut data_im0 = vec![0.0; n]; + let mut data_im1 = vec![0.0; n]; + + c.bench_function(&format!("concrete-fft128-fwd-{n}"), |bench| { + bench.iter(|| { + negacyclic_fwd_fft_scalar( + &mut data_re0, + &mut data_re1, + &mut data_im0, + &mut data_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + }); + }); + + c.bench_function(&format!("concrete-fft128-inv-{n}"), |bench| { + bench.iter(|| { + negacyclic_inv_fft_scalar( + &mut data_re0, + &mut data_re1, + &mut data_im0, + &mut data_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + }); + }); + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + if let Some(simd) = Avx::try_new() { + c.bench_function(&format!("concrete-fft128-avx-fwd-{n}"), |bench| { + bench.iter(|| { + negacyclic_fwd_fft_avxfma( + simd, + &mut data_re0, + &mut data_re1, + &mut data_im0, + &mut data_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + }); + }); + c.bench_function(&format!("concrete-fft128-avx-inv-{n}"), |bench| { + bench.iter(|| { + negacyclic_inv_fft_avxfma( + simd, + &mut data_re0, + &mut data_re1, + &mut data_im0, + &mut data_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + }); + }); + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + #[cfg(feature = "nightly")] + if let Some(simd) = Avx512::try_new() { + c.bench_function(&format!("concrete-fft128-avx512-fwd-{n}"), |bench| { + bench.iter(|| { + negacyclic_fwd_fft_avx512( + simd, + &mut data_re0, + &mut data_re1, + &mut data_im0, + &mut data_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + }); + }); + c.bench_function(&format!("concrete-fft128-avx512-inv-{n}"), |bench| { + bench.iter(|| { + negacyclic_inv_fft_avx512( + simd, + &mut data_re0, + &mut data_re1, + &mut data_im0, + &mut data_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + }); + }); + } + } +} + +criterion_group!(benches, 
criterion_benchmark); +criterion_main!(benches); diff --git a/benches/fft.rs b/benches/fft.rs deleted file mode 100644 index 7d93b6b..0000000 --- a/benches/fft.rs +++ /dev/null @@ -1,190 +0,0 @@ -use concrete_fft::c64; -use core::ptr::NonNull; -use criterion::{criterion_group, criterion_main, Criterion}; -use dyn_stack::{DynStack, ReborrowMut, StackReq}; - -struct FftwAlloc { - bytes: NonNull, -} - -impl Drop for FftwAlloc { - fn drop(&mut self) { - unsafe { - fftw_sys::fftw_free(self.bytes.as_ptr()); - } - } -} - -impl FftwAlloc { - pub fn new(size_bytes: usize) -> FftwAlloc { - unsafe { - let bytes = fftw_sys::fftw_malloc(size_bytes); - if bytes.is_null() { - use std::alloc::{handle_alloc_error, Layout}; - handle_alloc_error(Layout::from_size_align_unchecked(size_bytes, 1)); - } - FftwAlloc { - bytes: NonNull::new_unchecked(bytes), - } - } - } -} - -pub struct PlanInterleavedC64 { - plan: fftw_sys::fftw_plan, - n: usize, -} - -impl Drop for PlanInterleavedC64 { - fn drop(&mut self) { - unsafe { - fftw_sys::fftw_destroy_plan(self.plan); - } - } -} - -pub enum Sign { - Forward, - Backward, -} - -impl PlanInterleavedC64 { - pub fn new(n: usize, sign: Sign) -> Self { - let size_bytes = n.checked_mul(core::mem::size_of::()).unwrap(); - let src = FftwAlloc::new(size_bytes); - let dst = FftwAlloc::new(size_bytes); - unsafe { - let p = fftw_sys::fftw_plan_dft_1d( - n.try_into().unwrap(), - src.bytes.as_ptr() as _, - dst.bytes.as_ptr() as _, - match sign { - Sign::Forward => fftw_sys::FFTW_FORWARD as _, - Sign::Backward => fftw_sys::FFTW_BACKWARD as _, - }, - fftw_sys::FFTW_MEASURE, - ); - PlanInterleavedC64 { plan: p, n } - } - } - - pub fn print(&self) { - unsafe { - fftw_sys::fftw_print_plan(self.plan); - } - } - - pub fn execute(&self, src: &mut [c64], dst: &mut [c64]) { - assert_eq!(src.len(), self.n); - assert_eq!(dst.len(), self.n); - let src = src.as_mut_ptr(); - let dst = dst.as_mut_ptr(); - unsafe { - use fftw_sys::{fftw_alignment_of, fftw_execute_dft}; - assert_eq!(fftw_alignment_of(src as _), 0); - assert_eq!(fftw_alignment_of(dst as _), 0); - fftw_execute_dft(self.plan, src as _, dst as _); - } - } -} - -pub fn criterion_benchmark(c: &mut Criterion) { - for n in [ - 1 << 8, - 1 << 9, - 1 << 10, - 1 << 11, - 1 << 12, - 1 << 13, - 1 << 14, - 1 << 15, - 1 << 16, - ] { - let mut mem = dyn_stack::GlobalMemBuffer::new( - StackReq::new_aligned::(n, 64) // scratch - .and( - StackReq::new_aligned::(2 * n, 64).or(StackReq::new_aligned::(n, 64)), // src | twiddles - ) - .and(StackReq::new_aligned::(n, 64)), // dst - ); - let mut stack = DynStack::new(&mut mem); - let z = c64::new(0.0, 0.0); - - { - use rustfft::FftPlannerAvx; - let mut planner = FftPlannerAvx::::new().unwrap(); - - let fwd_rustfft = planner.plan_fft_forward(n); - let mut scratch = []; - - let fwd_fftw = PlanInterleavedC64::new(n, Sign::Forward); - - let bench_duration = std::time::Duration::from_millis(10); - let ordered = concrete_fft::ordered::Plan::new( - n, - concrete_fft::ordered::Method::Measure(bench_duration), - ); - let unordered = concrete_fft::unordered::Plan::new( - n, - concrete_fft::unordered::Method::Measure(bench_duration), - ); - - { - let (mut dst, stack) = stack.rb_mut().make_aligned_with::(n, 64, |_| z); - let (mut src, _) = stack.make_aligned_with::(n, 64, |_| z); - - c.bench_function(&format!("rustfft-fwd-{}", n), |b| { - b.iter(|| { - fwd_rustfft.process_outofplace_with_scratch( - &mut src, - &mut dst, - &mut scratch, - ) - }) - }); - - c.bench_function(&format!("fftw-fwd-{}", n), |b| { - b.iter(|| { - 
fwd_fftw.execute(&mut src, &mut dst); - }) - }); - } - { - let (mut dst, mut stack) = stack.rb_mut().make_aligned_with::(n, 64, |_| z); - - c.bench_function(&format!("concrete-fwd-{}", n), |b| { - b.iter(|| ordered.fwd(&mut *dst, stack.rb_mut())) - }); - } - { - let (mut dst, mut stack) = stack.rb_mut().make_aligned_with::(n, 64, |_| z); - - c.bench_function(&format!("unordered-fwd-{}", n), |b| { - b.iter(|| unordered.fwd(&mut dst, stack.rb_mut())); - }); - } - { - let (mut dst, mut stack) = stack.rb_mut().make_aligned_with::(n, 64, |_| z); - - c.bench_function(&format!("unordered-inv-{}", n), |b| { - b.iter(|| unordered.inv(&mut dst, stack.rb_mut())); - }); - } - } - - // memcpy - { - let (mut dst, stack) = stack.rb_mut().make_aligned_with::(n, 64, |_| z); - let (src, _) = stack.make_aligned_with::(n, 64, |_| z); - - c.bench_function(&format!("memcpy-{}", n), |b| { - b.iter(|| unsafe { - std::ptr::copy_nonoverlapping(src.as_ptr(), dst.as_mut_ptr(), n); - }) - }); - } - } -} - -criterion_group!(benches, criterion_benchmark); -criterion_main!(benches); diff --git a/benches/lib.rs b/benches/lib.rs deleted file mode 100644 index 2ee2f80..0000000 --- a/benches/lib.rs +++ /dev/null @@ -1,3 +0,0 @@ -#![allow(dead_code)] - -mod fft; diff --git a/src/fft128/f128_impl.rs b/src/fft128/f128_impl.rs new file mode 100644 index 0000000..6b426f0 --- /dev/null +++ b/src/fft128/f128_impl.rs @@ -0,0 +1,1075 @@ +use super::f128; + +/// Computes $\operatorname{fl}(a+b)$ and $\operatorname{err}(a+b)$. +/// Assumes $|a| \geq |b|$. +#[inline(always)] +fn quick_two_sum(a: f64, b: f64) -> (f64, f64) { + let s = a + b; + (s, b - (s - a)) +} + +/// Computes $\operatorname{fl}(a-b)$ and $\operatorname{err}(a-b)$. +/// Assumes $|a| \geq |b|$. +#[allow(dead_code)] +#[inline(always)] +fn quick_two_diff(a: f64, b: f64) -> (f64, f64) { + let s = a - b; + (s, (a - s) - b) +} + +/// Computes $\operatorname{fl}(a+b)$ and $\operatorname{err}(a+b)$. +#[inline(always)] +fn two_sum(a: f64, b: f64) -> (f64, f64) { + let s = a + b; + let bb = s - a; + (s, (a - (s - bb)) + (b - bb)) +} + +/// Computes $\operatorname{fl}(a-b)$ and $\operatorname{err}(a-b)$. 
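+/// Unlike `quick_two_diff`, this does not assume $|a| \geq |b|$.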
+#[inline(always)] +fn two_diff(a: f64, b: f64) -> (f64, f64) { + let s = a - b; + let bb = s - a; + (s, (a - (s - bb)) - (b + bb)) +} + +#[inline(always)] +fn two_prod(a: f64, b: f64) -> (f64, f64) { + let p = a * b; + (p, f64::mul_add(a, b, -p)) +} + +use core::{ + cmp::Ordering, + convert::From, + ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}, +}; + +impl From for f128 { + #[inline(always)] + fn from(value: f64) -> Self { + Self(value, 0.0) + } +} + +impl Add for f128 { + type Output = f128; + + #[inline(always)] + fn add(self, rhs: f128) -> Self::Output { + f128::add_f128_f128(self, rhs) + } +} + +impl Add for f128 { + type Output = f128; + + #[inline(always)] + fn add(self, rhs: f64) -> Self::Output { + f128::add_f128_f64(self, rhs) + } +} + +impl Add for f64 { + type Output = f128; + + #[inline(always)] + fn add(self, rhs: f128) -> Self::Output { + f128::add_f64_f128(self, rhs) + } +} + +impl AddAssign for f128 { + #[inline(always)] + fn add_assign(&mut self, rhs: f64) { + *self = *self + rhs + } +} + +impl AddAssign for f128 { + #[inline(always)] + fn add_assign(&mut self, rhs: f128) { + *self = *self + rhs + } +} + +impl Sub for f128 { + type Output = f128; + + #[inline(always)] + fn sub(self, rhs: f128) -> Self::Output { + f128::sub_f128_f128(self, rhs) + } +} + +impl Sub for f128 { + type Output = f128; + + #[inline(always)] + fn sub(self, rhs: f64) -> Self::Output { + f128::sub_f128_f64(self, rhs) + } +} + +impl Sub for f64 { + type Output = f128; + + #[inline(always)] + fn sub(self, rhs: f128) -> Self::Output { + f128::sub_f64_f128(self, rhs) + } +} + +impl SubAssign for f128 { + #[inline(always)] + fn sub_assign(&mut self, rhs: f64) { + *self = *self - rhs + } +} + +impl SubAssign for f128 { + #[inline(always)] + fn sub_assign(&mut self, rhs: f128) { + *self = *self - rhs + } +} + +impl Mul for f128 { + type Output = f128; + + #[inline(always)] + fn mul(self, rhs: f128) -> Self::Output { + f128::mul_f128_f128(self, rhs) + } +} + +impl Mul for f128 { + type Output = f128; + + #[inline(always)] + fn mul(self, rhs: f64) -> Self::Output { + f128::mul_f128_f64(self, rhs) + } +} + +impl Mul for f64 { + type Output = f128; + + #[inline(always)] + fn mul(self, rhs: f128) -> Self::Output { + f128::mul_f64_f128(self, rhs) + } +} + +impl MulAssign for f128 { + #[inline(always)] + fn mul_assign(&mut self, rhs: f64) { + *self = *self * rhs + } +} + +impl MulAssign for f128 { + #[inline(always)] + fn mul_assign(&mut self, rhs: f128) { + *self = *self * rhs + } +} + +impl Div for f128 { + type Output = f128; + + #[inline(always)] + fn div(self, rhs: f128) -> Self::Output { + f128::div_f128_f128(self, rhs) + } +} + +impl Div for f128 { + type Output = f128; + + #[inline(always)] + fn div(self, rhs: f64) -> Self::Output { + f128::div_f128_f64(self, rhs) + } +} + +impl Div for f64 { + type Output = f128; + + #[inline(always)] + fn div(self, rhs: f128) -> Self::Output { + f128::div_f64_f128(self, rhs) + } +} + +impl DivAssign for f128 { + #[inline(always)] + fn div_assign(&mut self, rhs: f64) { + *self = *self / rhs + } +} + +impl DivAssign for f128 { + #[inline(always)] + fn div_assign(&mut self, rhs: f128) { + *self = *self / rhs + } +} + +impl Neg for f128 { + type Output = f128; + + #[inline(always)] + fn neg(self) -> Self::Output { + Self(-self.0, -self.1) + } +} + +impl PartialEq for f128 { + #[inline(always)] + fn eq(&self, other: &f128) -> bool { + matches!((self.0 == other.0, self.1 == other.1), (true, true)) + } +} + +impl PartialEq for f128 { + 
#[inline(always)] + fn eq(&self, other: &f64) -> bool { + (*self).eq(&f128(*other, 0.0)) + } +} + +impl PartialEq for f64 { + #[inline(always)] + fn eq(&self, other: &f128) -> bool { + (*other).eq(self) + } +} + +impl PartialOrd for f128 { + #[inline(always)] + fn partial_cmp(&self, other: &f128) -> Option { + let first_cmp = self.0.partial_cmp(&other.0); + let second_cmp = self.1.partial_cmp(&other.1); + + match first_cmp { + Some(Ordering::Equal) => second_cmp, + _ => first_cmp, + } + } +} + +impl PartialOrd for f128 { + #[inline(always)] + fn partial_cmp(&self, other: &f64) -> Option { + (*self).partial_cmp(&f128(*other, 0.0)) + } +} + +impl PartialOrd for f64 { + #[inline(always)] + fn partial_cmp(&self, other: &f128) -> Option { + f128(*self, 0.0).partial_cmp(other) + } +} + +impl f128 { + /// Adds `a` and `b` and returns the result. + #[inline(always)] + pub fn add_f64_f64(a: f64, b: f64) -> Self { + let (s, e) = two_sum(a, b); + Self(s, e) + } + + /// Adds `a` and `b` and returns the result. + #[inline(always)] + pub fn add_f128_f64(a: f128, b: f64) -> Self { + let (s1, s2) = two_sum(a.0, b); + let s2 = s2 + a.1; + let (s1, s2) = quick_two_sum(s1, s2); + Self(s1, s2) + } + + /// Adds `a` and `b` and returns the result. + #[inline(always)] + pub fn add_f64_f128(a: f64, b: f128) -> Self { + Self::add_f128_f64(b, a) + } + + /// Adds `a` and `b` and returns the result. + /// This function has a slightly higher error bound than [`Self::add_f128_f128`] + #[inline(always)] + pub fn add_estimate_f128_f128(a: f128, b: f128) -> Self { + let (s, e) = two_sum(a.0, b.0); + let e = e + (a.1 + b.1); + let (s, e) = quick_two_sum(s, e); + Self(s, e) + } + + /// Adds `a` and `b` and returns the result. + #[inline(always)] + pub fn add_f128_f128(a: f128, b: f128) -> Self { + let (s1, s2) = two_sum(a.0, b.0); + let (t1, t2) = two_sum(a.1, b.1); + + let s2 = s2 + t1; + let (s1, s2) = quick_two_sum(s1, s2); + let s2 = s2 + t2; + let (s1, s2) = quick_two_sum(s1, s2); + Self(s1, s2) + } + + /// Subtracts `b` from `a` and returns the result. + #[inline(always)] + pub fn sub_f64_f64(a: f64, b: f64) -> Self { + let (s, e) = two_diff(a, b); + Self(s, e) + } + + /// Subtracts `b` from `a` and returns the result. + #[inline(always)] + pub fn sub_f128_f64(a: f128, b: f64) -> Self { + let (s1, s2) = two_diff(a.0, b); + let s2 = s2 + a.1; + let (s1, s2) = quick_two_sum(s1, s2); + Self(s1, s2) + } + + /// Subtracts `b` from `a` and returns the result. + #[inline(always)] + pub fn sub_f64_f128(a: f64, b: f128) -> Self { + let (s1, s2) = two_diff(a, b.0); + let s2 = s2 - b.1; + let (s1, s2) = quick_two_sum(s1, s2); + Self(s1, s2) + } + + /// Subtracts `b` from `a` and returns the result. + /// This function has a slightly higher error bound than [`Self::sub_f128_f128`] + #[inline(always)] + pub fn sub_estimate_f128_f128(a: f128, b: f128) -> Self { + let (s, e) = two_diff(a.0, b.0); + let e = e + a.1; + let e = e - b.1; + let (s, e) = quick_two_sum(s, e); + Self(s, e) + } + + /// Subtracts `b` from `a` and returns the result. + #[inline(always)] + pub fn sub_f128_f128(a: f128, b: f128) -> Self { + let (s1, s2) = two_diff(a.0, b.0); + let (t1, t2) = two_diff(a.1, b.1); + + let s2 = s2 + t1; + let (s1, s2) = quick_two_sum(s1, s2); + let s2 = s2 + t2; + let (s1, s2) = quick_two_sum(s1, s2); + Self(s1, s2) + } + + /// Multiplies `a` and `b` and returns the result. 
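+    /// The two components of the result sum to `a * b` exactly (barring overflow or
+    /// underflow), since the rounding error of the product is recovered with a fused
+    /// multiply-add.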
+ #[inline(always)] + pub fn mul_f64_f64(a: f64, b: f64) -> Self { + let (p, e) = two_prod(a, b); + Self(p, e) + } + + /// Multiplies `a` and `b` and returns the result. + #[inline(always)] + pub fn mul_f128_f64(a: f128, b: f64) -> Self { + let (p1, p2) = two_prod(a.0, b); + let p2 = p2 + (a.1 * b); + let (p1, p2) = quick_two_sum(p1, p2); + Self(p1, p2) + } + + /// Multiplies `a` and `b` and returns the result. + #[inline(always)] + pub fn mul_f64_f128(a: f64, b: f128) -> Self { + Self::mul_f128_f64(b, a) + } + + /// Multiplies `a` and `b` and returns the result. + #[inline(always)] + pub fn mul_f128_f128(a: f128, b: f128) -> Self { + let (p1, p2) = two_prod(a.0, b.0); + let p2 = p2 + (a.0 * b.1 + a.1 * b.0); + let (p1, p2) = quick_two_sum(p1, p2); + Self(p1, p2) + } + + /// Squares `self` and returns the result. + #[inline(always)] + pub fn sqr(self) -> Self { + let (p1, p2) = two_prod(self.0, self.0); + let p2 = p2 + 2.0 * (self.0 * self.1); + let (p1, p2) = quick_two_sum(p1, p2); + Self(p1, p2) + } + + /// Divides `a` by `b` and returns the result. + #[inline(always)] + pub fn div_f64_f64(a: f64, b: f64) -> Self { + let q1 = a / b; + + // Compute a - q1 * b + let (p1, p2) = two_prod(q1, b); + let (s, e) = two_diff(a, p1); + let e = e - p2; + + // get next approximation + let q2 = (s + e) / b; + + let (s, e) = quick_two_sum(q1, q2); + f128(s, e) + } + + /// Divides `a` by `b` and returns the result. + #[inline(always)] + pub fn div_f128_f64(a: f128, b: f64) -> Self { + // approximate quotient + let q1 = a.0 / b; + + // Compute a - q1 * b + let (p1, p2) = two_prod(q1, b); + let (s, e) = two_diff(a.0, p1); + let e = e + a.1; + let e = e - p2; + + // get next approximation + let q2 = (s + e) / b; + + // renormalize + let (r0, r1) = quick_two_sum(q1, q2); + Self(r0, r1) + } + + /// Divides `a` by `b` and returns the result. + #[inline(always)] + pub fn div_f64_f128(a: f64, b: f128) -> Self { + Self::div_f128_f128(a.into(), b) + } + + /// Divides `a` by `b` and returns the result. + /// This function has a slightly higher error bound than [`Self::div_f128_f128`] + #[inline(always)] + pub fn div_estimate_f128_f128(a: f128, b: f128) -> Self { + // approximate quotient + let q1 = a.0 / b.0; + + // compute a - q1 * b + let r = b * q1; + let (s1, s2) = two_diff(a.0, r.0); + let s2 = s2 - r.1; + let s2 = s2 + a.1; + + // get next approximation + let q2 = (s1 + s2) / b.0; + + // renormalize + let (r0, r1) = quick_two_sum(q1, q2); + Self(r0, r1) + } + + /// Divides `a` by `b` and returns the result. + #[inline(always)] + pub fn div_f128_f128(a: f128, b: f128) -> Self { + // approximate quotient + let q1 = a.0 / b.0; + + let r = a - b * q1; + + let q2 = r.0 / b.0; + let r = r - q2 * b; + + let q3 = r.0 / b.0; + + let (q1, q2) = quick_two_sum(q1, q2); + Self(q1, q2) + q3 + } + + /// Casts `self` to an `f64`. + #[inline(always)] + pub fn to_f64(self) -> f64 { + self.0 + } + + /// Checks if `self` is `NaN`. + #[inline(always)] + pub fn is_nan(self) -> bool { + !matches!((self.0.is_nan(), self.1.is_nan()), (false, false)) + } + + /// Returns the absolute value of `self`. 
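+    /// The sign of a normalized `f128` is carried by its leading component, so only
+    /// `self.0` is inspected.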
+ #[inline(always)] + pub fn abs(self) -> Self { + if self.0 < 0.0 { + -self + } else { + self + } + } + + fn sincospi_taylor(self) -> (Self, Self) { + let mut sinc = Self::PI; + let mut cos = f128(1.0, 0.0); + + let sqr = self.sqr(); + let mut pow = f128(1.0, 0.0); + for (s, c) in Self::SINPI_TAYLOR + .iter() + .copied() + .zip(Self::COSPI_TAYLOR.iter().copied()) + { + pow *= sqr; + sinc += s * pow; + cos += c * pow; + } + + (sinc * self, cos) + } + + /// Takes and input in `(-1.0, 1.0)`, and returns the sine and cosine of `self`. + pub fn sincospi(self) -> (Self, Self) { + #[allow(clippy::manual_range_contains)] + if self > 1.0 || self < -1.0 { + panic!("only inputs in [-1, 1] are currently supported, received: {self:?}"); + } + // approximately reduce modulo 1/2 + let p = (self.0 * 2.0).round(); + let r = self - p * 0.5; + + // approximately reduce modulo 1/16 + let q = (r.0 * 16.0).round(); + let r = r - q * (1.0 / 16.0); + + let p = p as isize; + let q = q as isize; + + let q_abs = q.unsigned_abs(); + + let (sin_r, cos_r) = r.sincospi_taylor(); + + let (s, c) = if q == 0 { + (sin_r, cos_r) + } else { + let u = Self::COS_K_PI_OVER_16_TABLE[q_abs - 1]; + let v = Self::SIN_K_PI_OVER_16_TABLE[q_abs - 1]; + if q > 0 { + (u * sin_r + v * cos_r, u * cos_r - v * sin_r) + } else { + (u * sin_r - v * cos_r, u * cos_r + v * sin_r) + } + }; + + if p == 0 { + (s, c) + } else if p == 1 { + (c, -s) + } else if p == -1 { + (-c, s) + } else { + (-s, -c) + } + } +} + +#[allow(clippy::approx_constant)] +impl f128 { + pub const PI: Self = f128(3.141592653589793, 1.2246467991473532e-16); + + const SINPI_TAYLOR: &'static [Self; 9] = &[ + f128(-5.16771278004997, 2.2665622825789447e-16), + f128(2.5501640398773455, -7.931006345326556e-17), + f128(-0.5992645293207921, 2.845026112698218e-17), + f128(0.08214588661112823, -3.847292805297656e-18), + f128(-0.0073704309457143504, -3.328281165603432e-19), + f128(0.00046630280576761255, 1.0704561733683463e-20), + f128(-2.1915353447830217e-5, 1.4648526682685598e-21), + f128(7.952054001475513e-7, 1.736540361519021e-23), + f128(-2.2948428997269873e-8, -7.376346207041088e-26), + ]; + + const COSPI_TAYLOR: &'static [Self; 9] = &[ + f128(-4.934802200544679, -3.1326477543698557e-16), + f128(4.0587121264167685, -2.6602000824298645e-16), + f128(-1.3352627688545895, 3.1815237892149862e-18), + f128(0.2353306303588932, -1.2583065576724427e-18), + f128(-0.02580689139001406, 1.170191067939226e-18), + f128(0.0019295743094039231, -9.669517939986956e-20), + f128(-0.0001046381049248457, -2.421206183964864e-21), + f128(4.303069587032947e-6, -2.864010082936791e-22), + f128(-1.3878952462213771e-7, -7.479362090417238e-24), + ]; + + const SIN_K_PI_OVER_16_TABLE: &'static [Self; 4] = &[ + f128(0.19509032201612828, -7.991079068461731e-18), + f128(0.3826834323650898, -1.0050772696461588e-17), + f128(0.5555702330196022, 4.709410940561677e-17), + f128(0.7071067811865476, -4.833646656726457e-17), + ]; + + const COS_K_PI_OVER_16_TABLE: &'static [Self; 4] = &[ + f128(0.9807852804032304, 1.8546939997825006e-17), + f128(0.9238795325112867, 1.7645047084336677e-17), + f128(0.8314696123025452, 1.4073856984728024e-18), + f128(0.7071067811865476, -4.833646656726457e-17), + ]; +} + +#[cfg(any(target_arch = "x86", target_arch = "x86_64", doc))] +#[cfg_attr(docsrs, doc(cfg(any(target_arch = "x86", target_arch = "x86_64"))))] +pub mod x86 { + #[cfg(target_arch = "x86")] + use core::arch::x86::*; + #[cfg(target_arch = "x86_64")] + use core::arch::x86_64::*; + + pulp::simd_type! { + /// Avx SIMD type. 
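+        /// An instance can be obtained with `Avx::try_new()`, which returns `Some` only
+        /// when the listed target features are detected at runtime.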
+ pub struct Avx { + pub sse: "sse", + pub sse2: "sse2", + pub avx: "avx", + pub fma: "fma", + } + } + + #[cfg(feature = "nightly")] + pulp::simd_type! { + /// Avx512 SIMD type. + pub struct Avx512 { + pub sse: "sse", + pub sse2: "sse2", + pub avx: "avx", + pub avx2: "avx2", + pub fma: "fma", + pub avx512f: "avx512f", + } + } + + #[inline(always)] + pub(crate) fn _mm256_quick_two_sum(simd: Avx, a: __m256d, b: __m256d) -> (__m256d, __m256d) { + let s = simd.avx._mm256_add_pd(a, b); + (s, simd.avx._mm256_sub_pd(b, simd.avx._mm256_sub_pd(s, a))) + } + + #[inline(always)] + pub(crate) fn _mm256_two_sum(simd: Avx, a: __m256d, b: __m256d) -> (__m256d, __m256d) { + let s = simd.avx._mm256_add_pd(a, b); + let bb = simd.avx._mm256_sub_pd(s, a); + ( + s, + simd.avx._mm256_add_pd( + simd.avx._mm256_sub_pd(a, simd.avx._mm256_sub_pd(s, bb)), + simd.avx._mm256_sub_pd(b, bb), + ), + ) + } + + #[inline(always)] + pub(crate) fn _mm256_two_diff(simd: Avx, a: __m256d, b: __m256d) -> (__m256d, __m256d) { + let s = simd.avx._mm256_sub_pd(a, b); + let bb = simd.avx._mm256_sub_pd(s, a); + ( + s, + simd.avx._mm256_sub_pd( + simd.avx._mm256_sub_pd(a, simd.avx._mm256_sub_pd(s, bb)), + simd.avx._mm256_add_pd(b, bb), + ), + ) + } + + #[inline(always)] + pub(crate) fn _mm256_two_prod(simd: Avx, a: __m256d, b: __m256d) -> (__m256d, __m256d) { + let p = simd.avx._mm256_mul_pd(a, b); + (p, simd.fma._mm256_fmsub_pd(a, b, p)) + } + + #[cfg(feature = "nightly")] + #[inline(always)] + pub(crate) fn _mm512_quick_two_sum(simd: Avx512, a: __m512d, b: __m512d) -> (__m512d, __m512d) { + let s = simd.avx512f._mm512_add_pd(a, b); + ( + s, + simd.avx512f + ._mm512_sub_pd(b, simd.avx512f._mm512_sub_pd(s, a)), + ) + } + + #[cfg(feature = "nightly")] + #[inline(always)] + pub(crate) fn _mm512_two_sum(simd: Avx512, a: __m512d, b: __m512d) -> (__m512d, __m512d) { + let s = simd.avx512f._mm512_add_pd(a, b); + let bb = simd.avx512f._mm512_sub_pd(s, a); + ( + s, + simd.avx512f._mm512_add_pd( + simd.avx512f + ._mm512_sub_pd(a, simd.avx512f._mm512_sub_pd(s, bb)), + simd.avx512f._mm512_sub_pd(b, bb), + ), + ) + } + + #[cfg(feature = "nightly")] + #[inline(always)] + pub(crate) fn _mm512_two_diff(simd: Avx512, a: __m512d, b: __m512d) -> (__m512d, __m512d) { + let s = simd.avx512f._mm512_sub_pd(a, b); + let bb = simd.avx512f._mm512_sub_pd(s, a); + ( + s, + simd.avx512f._mm512_sub_pd( + simd.avx512f + ._mm512_sub_pd(a, simd.avx512f._mm512_sub_pd(s, bb)), + simd.avx512f._mm512_add_pd(b, bb), + ), + ) + } + + #[cfg(feature = "nightly")] + #[inline(always)] + pub(crate) fn _mm512_two_prod(simd: Avx512, a: __m512d, b: __m512d) -> (__m512d, __m512d) { + let p = simd.avx512f._mm512_mul_pd(a, b); + (p, simd.avx512f._mm512_fmsub_pd(a, b, p)) + } + + impl Avx { + #[inline(always)] + pub fn _mm256_add_estimate_f128_f128( + self, + a0: __m256d, + a1: __m256d, + b0: __m256d, + b1: __m256d, + ) -> (__m256d, __m256d) { + let (s, e) = _mm256_two_sum(self, a0, b0); + let e = self.avx._mm256_add_pd(e, self.avx._mm256_add_pd(a1, b1)); + _mm256_quick_two_sum(self, s, e) + } + + #[inline(always)] + pub fn _mm256_sub_estimate_f128_f128( + self, + a0: __m256d, + a1: __m256d, + b0: __m256d, + b1: __m256d, + ) -> (__m256d, __m256d) { + let (s, e) = _mm256_two_diff(self, a0, b0); + let e = self.avx._mm256_add_pd(e, a1); + let e = self.avx._mm256_sub_pd(e, b1); + _mm256_quick_two_sum(self, s, e) + } + + #[inline(always)] + pub fn _mm256_mul_f128_f128( + self, + a0: __m256d, + a1: __m256d, + b0: __m256d, + b1: __m256d, + ) -> (__m256d, __m256d) { + let (p1, p2) = 
_mm256_two_prod(self, a0, b0); + let p2 = self.avx._mm256_add_pd( + p2, + self.avx._mm256_add_pd( + self.avx._mm256_mul_pd(a0, b1), + self.avx._mm256_mul_pd(a1, b0), + ), + ); + _mm256_quick_two_sum(self, p1, p2) + } + } + + #[cfg(feature = "nightly")] + impl Avx512 { + #[inline(always)] + pub fn _mm512_add_estimate_f128_f128( + self, + a0: __m512d, + a1: __m512d, + b0: __m512d, + b1: __m512d, + ) -> (__m512d, __m512d) { + let (s, e) = _mm512_two_sum(self, a0, b0); + let e = self + .avx512f + ._mm512_add_pd(e, self.avx512f._mm512_add_pd(a1, b1)); + _mm512_quick_two_sum(self, s, e) + } + + #[inline(always)] + pub fn _mm512_sub_estimate_f128_f128( + self, + a0: __m512d, + a1: __m512d, + b0: __m512d, + b1: __m512d, + ) -> (__m512d, __m512d) { + let (s, e) = _mm512_two_diff(self, a0, b0); + let e = self.avx512f._mm512_add_pd(e, a1); + let e = self.avx512f._mm512_sub_pd(e, b1); + _mm512_quick_two_sum(self, s, e) + } + + #[inline(always)] + pub fn _mm512_mul_f128_f128( + self, + a0: __m512d, + a1: __m512d, + b0: __m512d, + b1: __m512d, + ) -> (__m512d, __m512d) { + let (p1, p2) = _mm512_two_prod(self, a0, b0); + let p2 = self.avx512f._mm512_add_pd( + p2, + self.avx512f._mm512_add_pd( + self.avx512f._mm512_mul_pd(a0, b1), + self.avx512f._mm512_mul_pd(a1, b0), + ), + ); + _mm512_quick_two_sum(self, p1, p2) + } + } +} + +#[cfg(all(test, target_os = "linux"))] +mod tests { + use super::*; + use more_asserts::assert_le; + use rug::{ops::Pow, Float, Integer}; + + const PREC: u32 = 1024; + + fn float_to_f128(value: &Float) -> f128 { + let x0: f64 = value.to_f64(); + let diff = value.clone() - x0; + let x1 = diff.to_f64(); + f128(x0, x1) + } + + fn f128_to_float(value: f128) -> Float { + Float::with_val(PREC, value.0) + Float::with_val(PREC, value.1) + } + + #[test] + fn test_add() { + let mut rng = rug::rand::RandState::new(); + rng.seed(&Integer::from(0u64)); + + for _ in 0..100 { + let a = Float::with_val(PREC, Float::random_normal(&mut rng)); + let b = Float::with_val(PREC, Float::random_normal(&mut rng)); + + let a_f128 = float_to_f128(&a); + let b_f128 = float_to_f128(&b); + let a = f128_to_float(a_f128); + let b = f128_to_float(b_f128); + + let sum = Float::with_val(PREC, &a + &b); + let sum_rug_f128 = float_to_f128(&sum); + let sum_f128 = a_f128 + b_f128; + + assert_le!( + (sum_f128 - sum_rug_f128).abs(), + 2.0f64.powi(-104) * sum_f128.abs() + ); + } + } + + #[test] + fn test_sub() { + let mut rng = rug::rand::RandState::new(); + rng.seed(&Integer::from(1u64)); + + for _ in 0..100 { + let a = Float::with_val(PREC, Float::random_normal(&mut rng)); + let b = Float::with_val(PREC, Float::random_normal(&mut rng)); + + let a_f128 = float_to_f128(&a); + let b_f128 = float_to_f128(&b); + let a = f128_to_float(a_f128); + let b = f128_to_float(b_f128); + + let diff = Float::with_val(PREC, &a - &b); + let diff_rug_f128 = float_to_f128(&diff); + let diff_f128 = a_f128 - b_f128; + + assert_le!( + (diff_f128 - diff_rug_f128).abs(), + 2.0f64.powi(-104) * diff_f128.abs() + ); + } + } + + #[test] + fn test_mul() { + let mut rng = rug::rand::RandState::new(); + rng.seed(&Integer::from(2u64)); + + for _ in 0..100 { + let a = Float::with_val(PREC, Float::random_normal(&mut rng)); + let b = Float::with_val(PREC, Float::random_normal(&mut rng)); + + let a_f128 = float_to_f128(&a); + let b_f128 = float_to_f128(&b); + let a = f128_to_float(a_f128); + let b = f128_to_float(b_f128); + + let prod = Float::with_val(PREC, &a * &b); + let prod_rug_f128 = float_to_f128(&prod); + let prod_f128 = a_f128 * b_f128; + + 
assert_le!( + (prod_f128 - prod_rug_f128).abs(), + 2.0f64.powi(-104) * prod_f128.abs() + ); + } + } + + #[test] + fn test_div() { + let mut rng = rug::rand::RandState::new(); + rng.seed(&Integer::from(3u64)); + + for _ in 0..100 { + let a = Float::with_val(PREC, Float::random_normal(&mut rng)); + let b = Float::with_val(PREC, Float::random_normal(&mut rng)); + + let a_f128 = float_to_f128(&a); + let b_f128 = float_to_f128(&b); + let a = f128_to_float(a_f128); + let b = f128_to_float(b_f128); + + let quot = Float::with_val(PREC, &a / &b); + let quot_rug_f128 = float_to_f128("); + let quot_f128 = a_f128 / b_f128; + + assert_le!( + (quot_f128 - quot_rug_f128).abs(), + 2.0f64.powi(-104) * quot_f128.abs() + ); + } + } + + #[test] + fn test_sincos_taylor() { + let mut rng = rug::rand::RandState::new(); + rng.seed(&Integer::from(4u64)); + + for _ in 0..10000 { + let a = (Float::with_val(PREC, Float::random_bits(&mut rng)) * 2.0 - 1.0) / 32; + let a_f128 = float_to_f128(&a); + let a = f128_to_float(a_f128); + + let sin = Float::with_val(PREC, a.clone().sin_pi()); + let cos = Float::with_val(PREC, a.clone().cos_pi()); + let sin_rug_f128 = float_to_f128(&sin); + let cos_rug_f128 = float_to_f128(&cos); + let (sin_f128, cos_f128) = a_f128.sincospi_taylor(); + assert_le!( + (cos_f128 - cos_rug_f128).abs(), + 2.0f64.powi(-103) * cos_f128.abs() + ); + assert_le!( + (sin_f128 - sin_rug_f128).abs(), + 2.0f64.powi(-103) * sin_f128.abs() + ); + } + } + + #[test] + fn test_sincos() { + let mut rng = rug::rand::RandState::new(); + rng.seed(&Integer::from(5u64)); + + #[track_caller] + fn test_sincos(a: Float) { + let a_f128 = float_to_f128(&a); + let a = f128_to_float(a_f128); + + let sin = Float::with_val(PREC, a.clone().sin_pi()); + let cos = Float::with_val(PREC, a.cos_pi()); + let sin_rug_f128 = float_to_f128(&sin); + let cos_rug_f128 = float_to_f128(&cos); + let (sin_f128, cos_f128) = a_f128.sincospi(); + assert_le!( + (cos_f128 - cos_rug_f128).abs(), + 2.0f64.powi(-103) * cos_f128.abs() + ); + assert_le!( + (sin_f128 - sin_rug_f128).abs(), + 2.0f64.powi(-103) * sin_f128.abs() + ); + } + + test_sincos(Float::with_val(PREC, 0.00)); + test_sincos(Float::with_val(PREC, 0.25)); + test_sincos(Float::with_val(PREC, 0.50)); + test_sincos(Float::with_val(PREC, 0.75)); + test_sincos(Float::with_val(PREC, 1.00)); + + for _ in 0..10000 { + test_sincos(Float::with_val(PREC, Float::random_bits(&mut rng)) * 2.0 - 1.0); + } + } + + #[cfg(feature = "std")] + #[test] + fn generate_constants() { + let pi = Float::with_val(PREC, rug::float::Constant::Pi); + + println!(); + println!("###############################################################################"); + println!("impl f128 {{"); + println!(" pub const PI: Self = {:?};", float_to_f128(&pi)); + + println!(); + println!(" const SINPI_TAYLOR: &'static [Self; 9] = &["); + let mut factorial = 1_u64; + for i in 1..10 { + let k = 2 * i + 1; + factorial *= (k - 1) * k; + println!( + " {:?},", + (-1.0f64).powi(i as i32) * float_to_f128(&(pi.clone().pow(k) / factorial)), + ); + } + println!(" ];"); + + println!(); + println!(" const COSPI_TAYLOR: &'static [Self; 9] = &["); + let mut factorial = 1_u64; + for i in 1..10 { + let k = 2 * i; + factorial *= (k - 1) * k; + println!( + " {:?},", + (-1.0f64).powi(i as i32) * float_to_f128(&(pi.clone().pow(k) / factorial)), + ); + } + println!(" ];"); + + println!(); + println!(" const SIN_K_PI_OVER_16_TABLE: &'static [Self; 4] = &["); + for k in 1..5 { + let x: Float = Float::with_val(PREC, k as f64 / 16.0); + println!(" {:?},", 
float_to_f128(&x.clone().sin_pi()),); + } + println!(" ];"); + + println!(); + println!(" const COS_K_PI_OVER_16_TABLE: &'static [Self; 4] = &["); + for k in 1..5 { + let x: Float = Float::with_val(PREC, k as f64 / 16.0); + println!(" {:?},", float_to_f128(&x.clone().cos_pi()),); + } + println!(" ];"); + + println!("}}"); + println!("###############################################################################"); + assert_eq!(float_to_f128(&pi), f128::PI); + } +} diff --git a/src/fft128/mod.rs b/src/fft128/mod.rs new file mode 100644 index 0000000..eda2bbd --- /dev/null +++ b/src/fft128/mod.rs @@ -0,0 +1,2204 @@ +mod f128_impl; + +/// 128-bit floating point number. +#[allow(non_camel_case_types)] +#[derive(Copy, Clone, Debug)] +#[repr(C)] +pub struct f128(pub f64, pub f64); + +use aligned_vec::{avec, ABox}; +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +pub use f128_impl::x86::Avx; +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +#[cfg(feature = "nightly")] +#[cfg_attr(docsrs, doc(cfg(feature = "nightly")))] +pub use f128_impl::x86::Avx512; + +use pulp::{as_arrays, as_arrays_mut, cast}; + +#[allow(unused_macros)] +macro_rules! izip { + (@ __closure @ ($a:expr)) => { |a| (a,) }; + (@ __closure @ ($a:expr, $b:expr)) => { |(a, b)| (a, b) }; + (@ __closure @ ($a:expr, $b:expr, $c:expr)) => { |((a, b), c)| (a, b, c) }; + (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr)) => { |(((a, b), c), d)| (a, b, c, d) }; + (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr)) => { |((((a, b), c), d), e)| (a, b, c, d, e) }; + (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr)) => { |(((((a, b), c), d), e), f)| (a, b, c, d, e, f) }; + (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr)) => { |((((((a, b), c), d), e), f), g)| (a, b, c, d, e, f, g) }; + (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr)) => { |(((((((a, b), c), d), e), f), g), h)| (a, b, c, d, e, f, g, h) }; + (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr)) => { |((((((((a, b), c), d), e), f), g), h), i)| (a, b, c, d, e, f, g, h, i) }; + (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr)) => { |(((((((((a, b), c), d), e), f), g), h), i), j)| (a, b, c, d, e, f, g, h, i, j) }; + (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr)) => { |((((((((((a, b), c), d), e), f), g), h), i), j), k)| (a, b, c, d, e, f, g, h, i, j, k) }; + (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr, $l: expr)) => { |(((((((((((a, b), c), d), e), f), g), h), i), j), k), l)| (a, b, c, d, e, f, g, h, i, j, k, l) }; + (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr, $l: expr, $m:expr)) => { |((((((((((((a, b), c), d), e), f), g), h), i), j), k), l), m)| (a, b, c, d, e, f, g, h, i, j, k, l, m) }; + (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr, $l: expr, $m:expr, $n:expr)) => { |(((((((((((((a, b), c), d), e), f), g), h), i), j), k), l), m), n)| (a, b, c, d, e, f, g, h, i, j, k, l, m, n) }; + (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr, $l: expr, $m:expr, $n:expr, $o:expr)) => { |((((((((((((((a, b), 
c), d), e), f), g), h), i), j), k), l), m), n), o)| (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o) }; + + ( $first:expr $(,)?) => { + { + ::core::iter::IntoIterator::into_iter($first) + } + }; + ( $first:expr, $($rest:expr),+ $(,)?) => { + { + ::core::iter::IntoIterator::into_iter($first) + $(.zip($rest))* + .map(izip!(@ __closure @ ($first, $($rest),*))) + } + }; +} + +trait FftSimdF128: Copy { + type Reg: Copy + core::fmt::Debug; + + fn splat(self, value: f64) -> Self::Reg; + fn add(self, a: (Self::Reg, Self::Reg), b: (Self::Reg, Self::Reg)) -> (Self::Reg, Self::Reg); + fn sub(self, a: (Self::Reg, Self::Reg), b: (Self::Reg, Self::Reg)) -> (Self::Reg, Self::Reg); + fn mul(self, a: (Self::Reg, Self::Reg), b: (Self::Reg, Self::Reg)) -> (Self::Reg, Self::Reg); +} + +#[derive(Copy, Clone)] +struct Scalar; + +impl FftSimdF128 for Scalar { + type Reg = f64; + + #[inline(always)] + fn splat(self, value: f64) -> Self::Reg { + value + } + + #[inline(always)] + fn add(self, a: (Self::Reg, Self::Reg), b: (Self::Reg, Self::Reg)) -> (Self::Reg, Self::Reg) { + let f128(o0, o1) = f128(a.0, a.1) + f128(b.0, b.1); + (o0, o1) + } + + #[inline(always)] + fn sub(self, a: (Self::Reg, Self::Reg), b: (Self::Reg, Self::Reg)) -> (Self::Reg, Self::Reg) { + let f128(o0, o1) = f128(a.0, a.1) - f128(b.0, b.1); + (o0, o1) + } + + #[inline(always)] + fn mul(self, a: (Self::Reg, Self::Reg), b: (Self::Reg, Self::Reg)) -> (Self::Reg, Self::Reg) { + let f128(o0, o1) = f128(a.0, a.1) * f128(b.0, b.1); + (o0, o1) + } +} + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +impl FftSimdF128 for Avx { + type Reg = [f64; 4]; + + #[inline(always)] + fn splat(self, value: f64) -> Self::Reg { + cast(self.avx._mm256_set1_pd(value)) + } + + #[inline(always)] + fn add(self, a: (Self::Reg, Self::Reg), b: (Self::Reg, Self::Reg)) -> (Self::Reg, Self::Reg) { + let result = self._mm256_add_estimate_f128_f128(cast(a.0), cast(a.1), cast(b.0), cast(b.1)); + (cast(result.0), cast(result.1)) + } + + #[inline(always)] + fn sub(self, a: (Self::Reg, Self::Reg), b: (Self::Reg, Self::Reg)) -> (Self::Reg, Self::Reg) { + let result = self._mm256_sub_estimate_f128_f128(cast(a.0), cast(a.1), cast(b.0), cast(b.1)); + (cast(result.0), cast(result.1)) + } + + #[inline(always)] + fn mul(self, a: (Self::Reg, Self::Reg), b: (Self::Reg, Self::Reg)) -> (Self::Reg, Self::Reg) { + let result = self._mm256_mul_f128_f128(cast(a.0), cast(a.1), cast(b.0), cast(b.1)); + (cast(result.0), cast(result.1)) + } +} + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +#[cfg(feature = "nightly")] +impl FftSimdF128 for Avx512 { + type Reg = [f64; 8]; + + #[inline(always)] + fn splat(self, value: f64) -> Self::Reg { + cast(self.avx512f._mm512_set1_pd(value)) + } + + #[inline(always)] + fn add(self, a: (Self::Reg, Self::Reg), b: (Self::Reg, Self::Reg)) -> (Self::Reg, Self::Reg) { + let result = self._mm512_add_estimate_f128_f128(cast(a.0), cast(a.1), cast(b.0), cast(b.1)); + (cast(result.0), cast(result.1)) + } + + #[inline(always)] + fn sub(self, a: (Self::Reg, Self::Reg), b: (Self::Reg, Self::Reg)) -> (Self::Reg, Self::Reg) { + let result = self._mm512_sub_estimate_f128_f128(cast(a.0), cast(a.1), cast(b.0), cast(b.1)); + (cast(result.0), cast(result.1)) + } + + #[inline(always)] + fn mul(self, a: (Self::Reg, Self::Reg), b: (Self::Reg, Self::Reg)) -> (Self::Reg, Self::Reg) { + let result = self._mm512_mul_f128_f128(cast(a.0), cast(a.1), cast(b.0), cast(b.1)); + (cast(result.0), cast(result.1)) + } +} + +trait FftSimdF128Ext: FftSimdF128 { + #[inline(always)] + fn 
cplx_add( + self, + a_re: (Self::Reg, Self::Reg), + a_im: (Self::Reg, Self::Reg), + b_re: (Self::Reg, Self::Reg), + b_im: (Self::Reg, Self::Reg), + ) -> ((Self::Reg, Self::Reg), (Self::Reg, Self::Reg)) { + (self.add(a_re, b_re), self.add(a_im, b_im)) + } + + #[inline(always)] + fn cplx_sub( + self, + a_re: (Self::Reg, Self::Reg), + a_im: (Self::Reg, Self::Reg), + b_re: (Self::Reg, Self::Reg), + b_im: (Self::Reg, Self::Reg), + ) -> ((Self::Reg, Self::Reg), (Self::Reg, Self::Reg)) { + (self.sub(a_re, b_re), self.sub(a_im, b_im)) + } + + /// `a * b` + #[inline(always)] + fn cplx_mul( + self, + a_re: (Self::Reg, Self::Reg), + a_im: (Self::Reg, Self::Reg), + b_re: (Self::Reg, Self::Reg), + b_im: (Self::Reg, Self::Reg), + ) -> ((Self::Reg, Self::Reg), (Self::Reg, Self::Reg)) { + let a_re_x_b_re = self.mul(a_re, b_re); + let a_re_x_b_im = self.mul(a_re, b_im); + let a_im_x_b_re = self.mul(a_im, b_re); + let a_im_x_b_im = self.mul(a_im, b_im); + + ( + self.sub(a_re_x_b_re, a_im_x_b_im), + self.add(a_im_x_b_re, a_re_x_b_im), + ) + } + + /// `a * conj(b)` + #[inline(always)] + fn cplx_mul_conj( + self, + a_re: (Self::Reg, Self::Reg), + a_im: (Self::Reg, Self::Reg), + b_re: (Self::Reg, Self::Reg), + b_im: (Self::Reg, Self::Reg), + ) -> ((Self::Reg, Self::Reg), (Self::Reg, Self::Reg)) { + let a_re_x_b_re = self.mul(a_re, b_re); + let a_re_x_b_im = self.mul(a_re, b_im); + let a_im_x_b_re = self.mul(a_im, b_re); + let a_im_x_b_im = self.mul(a_im, b_im); + + ( + self.add(a_re_x_b_re, a_im_x_b_im), + self.sub(a_im_x_b_re, a_re_x_b_im), + ) + } +} + +impl FftSimdF128Ext for T {} + +#[doc(hidden)] +pub fn negacyclic_fwd_fft_scalar( + data_re0: &mut [f64], + data_re1: &mut [f64], + data_im0: &mut [f64], + data_im1: &mut [f64], + twid_re0: &[f64], + twid_re1: &[f64], + twid_im0: &[f64], + twid_im1: &[f64], +) { + let n = data_re0.len(); + let mut t = n; + let mut m = 1; + let simd = Scalar; + + while m < n { + t /= 2; + + for i in 0..m { + let w1_re = (twid_re0[m + i], twid_re1[m + i]); + let w1_im = (twid_im0[m + i], twid_im1[m + i]); + + let start = 2 * i * t; + + let data_re0 = &mut data_re0[start..][..2 * t]; + let data_re1 = &mut data_re1[start..][..2 * t]; + let data_im0 = &mut data_im0[start..][..2 * t]; + let data_im1 = &mut data_im1[start..][..2 * t]; + + let (z0_re0, z1_re0) = data_re0.split_at_mut(t); + let (z0_re1, z1_re1) = data_re1.split_at_mut(t); + let (z0_im0, z1_im0) = data_im0.split_at_mut(t); + let (z0_im1, z1_im1) = data_im1.split_at_mut(t); + + for (z0_re0, z0_re1, z0_im0, z0_im1, z1_re0, z1_re1, z1_im0, z1_im1) in + izip!(z0_re0, z0_re1, z0_im0, z0_im1, z1_re0, z1_re1, z1_im0, z1_im1) + { + let (z0_re, z0_im) = ((*z0_re0, *z0_re1), (*z0_im0, *z0_im1)); + let (z1_re, z1_im) = ((*z1_re0, *z1_re1), (*z1_im0, *z1_im1)); + let (z1w_re, z1w_im) = simd.cplx_mul(z1_re, z1_im, w1_re, w1_im); + + ((*z0_re0, *z0_re1), (*z0_im0, *z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1w_re, z1w_im); + ((*z1_re0, *z1_re1), (*z1_im0, *z1_im1)) = + simd.cplx_sub(z0_re, z0_im, z1w_re, z1w_im); + } + } + + m *= 2; + } +} + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +#[doc(hidden)] +pub fn negacyclic_fwd_fft_avxfma( + simd: Avx, + data_re0: &mut [f64], + data_re1: &mut [f64], + data_im0: &mut [f64], + data_im1: &mut [f64], + twid_re0: &[f64], + twid_re1: &[f64], + twid_im0: &[f64], + twid_im1: &[f64], +) { + let n = data_re0.len(); + assert!(n >= 32); + + simd.vectorize({ + #[inline(always)] + || { + { + let mut t = n; + let mut m = 1; + + while m < n / 4 { + t /= 2; + + let twid_re0 = &twid_re0[m..]; 
+ let twid_re1 = &twid_re1[m..]; + let twid_im0 = &twid_im0[m..]; + let twid_im1 = &twid_im1[m..]; + + let iter = izip!( + data_re0.chunks_mut(2 * t), + data_re1.chunks_mut(2 * t), + data_im0.chunks_mut(2 * t), + data_im1.chunks_mut(2 * t), + twid_re0, + twid_re1, + twid_im0, + twid_im1, + ); + for (data_re0, data_re1, data_im0, data_im1, w1_re0, w1_re1, w1_im0, w1_im1) in + iter + { + let w1_re = (*w1_re0, *w1_re1); + let w1_im = (*w1_im0, *w1_im1); + + let w1_re = (simd.splat(w1_re.0), simd.splat(w1_re.1)); + let w1_im = (simd.splat(w1_im.0), simd.splat(w1_im.1)); + + let (z0_re0, z1_re0) = data_re0.split_at_mut(t); + let (z0_re1, z1_re1) = data_re1.split_at_mut(t); + let (z0_im0, z1_im0) = data_im0.split_at_mut(t); + let (z0_im1, z1_im1) = data_im1.split_at_mut(t); + + let z0_re0 = as_arrays_mut::<4, _>(z0_re0).0; + let z0_re1 = as_arrays_mut::<4, _>(z0_re1).0; + let z0_im0 = as_arrays_mut::<4, _>(z0_im0).0; + let z0_im1 = as_arrays_mut::<4, _>(z0_im1).0; + let z1_re0 = as_arrays_mut::<4, _>(z1_re0).0; + let z1_re1 = as_arrays_mut::<4, _>(z1_re1).0; + let z1_im0 = as_arrays_mut::<4, _>(z1_im0).0; + let z1_im1 = as_arrays_mut::<4, _>(z1_im1).0; + + let iter = + izip!(z0_re0, z0_re1, z0_im0, z0_im1, z1_re0, z1_re1, z1_im0, z1_im1); + for (z0_re0, z0_re1, z0_im0, z0_im1, z1_re0, z1_re1, z1_im0, z1_im1) in iter + { + let (z0_re, z0_im) = ((*z0_re0, *z0_re1), (*z0_im0, *z0_im1)); + let (z1_re, z1_im) = ((*z1_re0, *z1_re1), (*z1_im0, *z1_im1)); + let (z1w_re, z1w_im) = simd.cplx_mul(z1_re, z1_im, w1_re, w1_im); + + ((*z0_re0, *z0_re1), (*z0_im0, *z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1w_re, z1w_im); + ((*z1_re0, *z1_re1), (*z1_im0, *z1_im1)) = + simd.cplx_sub(z0_re, z0_im, z1w_re, z1w_im); + } + } + + m *= 2; + } + } + + // m = n / 4 + // t = 2 + { + let m = n / 4; + + let twid_re0 = as_arrays::<2, _>(&twid_re0[m..]).0; + let twid_re1 = as_arrays::<2, _>(&twid_re1[m..]).0; + let twid_im0 = as_arrays::<2, _>(&twid_im0[m..]).0; + let twid_im1 = as_arrays::<2, _>(&twid_im1[m..]).0; + + let data_re0 = as_arrays_mut::<4, _>(data_re0).0; + let data_re1 = as_arrays_mut::<4, _>(data_re1).0; + let data_im0 = as_arrays_mut::<4, _>(data_im0).0; + let data_im1 = as_arrays_mut::<4, _>(data_im1).0; + + let data_re0 = as_arrays_mut::<2, _>(data_re0).0; + let data_re1 = as_arrays_mut::<2, _>(data_re1).0; + let data_im0 = as_arrays_mut::<2, _>(data_im0).0; + let data_im1 = as_arrays_mut::<2, _>(data_im1).0; + + let iter = izip!( + data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, twid_im1 + ); + for ( + z0z0z1z1_re0, + z0z0z1z1_re1, + z0z0z1z1_im0, + z0z0z1z1_im1, + w1_re0, + w1_re1, + w1_im0, + w1_im1, + ) in iter + { + // 0 1 2 3 | 4 5 6 7 -> 0 1 4 5 | 2 3 6 7 + // + // is its own inverse since: + // 0 1 4 5 | 2 3 6 7 -> 0 1 2 3 | 4 5 6 7 + let interleave = { + #[inline(always)] + |z0z0z1z1: [[f64; 4]; 2]| -> [[f64; 4]; 2] { + [ + cast(simd.avx._mm256_permute2f128_pd::<0b00100000>( + cast(z0z0z1z1[0]), + cast(z0z0z1z1[1]), + )), + cast(simd.avx._mm256_permute2f128_pd::<0b00110001>( + cast(z0z0z1z1[0]), + cast(z0z0z1z1[1]), + )), + ] + } + }; + + let splat2 = { + #[inline(always)] + |w: [f64; 2]| -> [f64; 4] { + let w00 = simd.sse2._mm_set1_pd(w[0]); + let w11 = simd.sse2._mm_set1_pd(w[1]); + + let w0011 = simd.avx._mm256_insertf128_pd::<0b1>( + simd.avx._mm256_castpd128_pd256(w00), + w11, + ); + + cast(w0011) + } + }; + + let w1_re = (splat2(*w1_re0), splat2(*w1_re1)); + let w1_im = (splat2(*w1_im0), splat2(*w1_im1)); + + let [mut z0_re0, mut z1_re0] = interleave(*z0z0z1z1_re0); + 
let [mut z0_re1, mut z1_re1] = interleave(*z0z0z1z1_re1); + let [mut z0_im0, mut z1_im0] = interleave(*z0z0z1z1_im0); + let [mut z0_im1, mut z1_im1] = interleave(*z0z0z1z1_im1); + + let (z0_re, z0_im) = ((z0_re0, z0_re1), (z0_im0, z0_im1)); + let (z1_re, z1_im) = ((z1_re0, z1_re1), (z1_im0, z1_im1)); + let (z1w_re, z1w_im) = simd.cplx_mul(z1_re, z1_im, w1_re, w1_im); + + ((z0_re0, z0_re1), (z0_im0, z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1w_re, z1w_im); + ((z1_re0, z1_re1), (z1_im0, z1_im1)) = + simd.cplx_sub(z0_re, z0_im, z1w_re, z1w_im); + + *z0z0z1z1_re0 = interleave([z0_re0, z1_re0]); + *z0z0z1z1_re1 = interleave([z0_re1, z1_re1]); + *z0z0z1z1_im0 = interleave([z0_im0, z1_im0]); + *z0z0z1z1_im1 = interleave([z0_im1, z1_im1]); + } + } + + // m = n / 2 + // t = 1 + { + let m = n / 2; + + let twid_re0 = as_arrays::<4, _>(&twid_re0[m..]).0; + let twid_re1 = as_arrays::<4, _>(&twid_re1[m..]).0; + let twid_im0 = as_arrays::<4, _>(&twid_im0[m..]).0; + let twid_im1 = as_arrays::<4, _>(&twid_im1[m..]).0; + + let data_re0 = as_arrays_mut::<4, _>(data_re0).0; + let data_re1 = as_arrays_mut::<4, _>(data_re1).0; + let data_im0 = as_arrays_mut::<4, _>(data_im0).0; + let data_im1 = as_arrays_mut::<4, _>(data_im1).0; + + let data_re0 = as_arrays_mut::<2, _>(data_re0).0; + let data_re1 = as_arrays_mut::<2, _>(data_re1).0; + let data_im0 = as_arrays_mut::<2, _>(data_im0).0; + let data_im1 = as_arrays_mut::<2, _>(data_im1).0; + + let iter = izip!( + data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, twid_im1 + ); + for (z0z1_re0, z0z1_re1, z0z1_im0, z0z1_im1, w1_re0, w1_re1, w1_im0, w1_im1) in iter + { + // 0 1 2 3 -> 0 2 1 3 + let permute = { + #[inline(always)] + |w: [f64; 4]| -> [f64; 4] { + let avx = simd.avx; + let w0123 = cast(w); + let w0101 = avx._mm256_permute2f128_pd::<0b00000000>(w0123, w0123); + let w2323 = avx._mm256_permute2f128_pd::<0b00110011>(w0123, w0123); + let w0213 = avx._mm256_shuffle_pd::<0b1100>(w0101, w2323); + cast(w0213) + } + }; + + // 0 1 2 3 | 4 5 6 7 -> 0 4 2 6 | 1 5 3 7 + // + // is its own inverse since: + // 0 4 2 6 | 1 5 3 7 -> 0 1 2 3 | 4 5 6 7 + let interleave = { + #[inline(always)] + |z0z1: [[f64; 4]; 2]| -> [[f64; 4]; 2] { + [ + cast(simd.avx._mm256_unpacklo_pd(cast(z0z1[0]), cast(z0z1[1]))), + cast(simd.avx._mm256_unpackhi_pd(cast(z0z1[0]), cast(z0z1[1]))), + ] + } + }; + + let w1_re = (permute(*w1_re0), permute(*w1_re1)); + let w1_im = (permute(*w1_im0), permute(*w1_im1)); + + let [mut z0_re0, mut z1_re0] = interleave(*z0z1_re0); + let [mut z0_re1, mut z1_re1] = interleave(*z0z1_re1); + let [mut z0_im0, mut z1_im0] = interleave(*z0z1_im0); + let [mut z0_im1, mut z1_im1] = interleave(*z0z1_im1); + + let (z0_re, z0_im) = ((z0_re0, z0_re1), (z0_im0, z0_im1)); + let (z1_re, z1_im) = ((z1_re0, z1_re1), (z1_im0, z1_im1)); + let (z1w_re, z1w_im) = simd.cplx_mul(z1_re, z1_im, w1_re, w1_im); + + ((z0_re0, z0_re1), (z0_im0, z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1w_re, z1w_im); + ((z1_re0, z1_re1), (z1_im0, z1_im1)) = + simd.cplx_sub(z0_re, z0_im, z1w_re, z1w_im); + + *z0z1_re0 = interleave([z0_re0, z1_re0]); + *z0z1_re1 = interleave([z0_re1, z1_re1]); + *z0z1_im0 = interleave([z0_im0, z1_im0]); + *z0z1_im1 = interleave([z0_im1, z1_im1]); + } + } + } + }); +} + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +#[cfg(feature = "nightly")] +#[doc(hidden)] +pub fn negacyclic_fwd_fft_avx512( + simd: Avx512, + data_re0: &mut [f64], + data_re1: &mut [f64], + data_im0: &mut [f64], + data_im1: &mut [f64], + twid_re0: &[f64], + twid_re1: &[f64], + 
twid_im0: &[f64], + twid_im1: &[f64], +) { + let n = data_re0.len(); + assert!(n >= 32); + + simd.vectorize({ + #[inline(always)] + || { + { + let mut t = n; + let mut m = 1; + + while m < n / 8 { + t /= 2; + + let twid_re0 = &twid_re0[m..]; + let twid_re1 = &twid_re1[m..]; + let twid_im0 = &twid_im0[m..]; + let twid_im1 = &twid_im1[m..]; + + let iter = izip!( + data_re0.chunks_mut(2 * t), + data_re1.chunks_mut(2 * t), + data_im0.chunks_mut(2 * t), + data_im1.chunks_mut(2 * t), + twid_re0, + twid_re1, + twid_im0, + twid_im1, + ); + for (data_re0, data_re1, data_im0, data_im1, w1_re0, w1_re1, w1_im0, w1_im1) in + iter + { + let w1_re = (*w1_re0, *w1_re1); + let w1_im = (*w1_im0, *w1_im1); + + let w1_re = (simd.splat(w1_re.0), simd.splat(w1_re.1)); + let w1_im = (simd.splat(w1_im.0), simd.splat(w1_im.1)); + + let (z0_re0, z1_re0) = data_re0.split_at_mut(t); + let (z0_re1, z1_re1) = data_re1.split_at_mut(t); + let (z0_im0, z1_im0) = data_im0.split_at_mut(t); + let (z0_im1, z1_im1) = data_im1.split_at_mut(t); + + let z0_re0 = as_arrays_mut::<8, _>(z0_re0).0; + let z0_re1 = as_arrays_mut::<8, _>(z0_re1).0; + let z0_im0 = as_arrays_mut::<8, _>(z0_im0).0; + let z0_im1 = as_arrays_mut::<8, _>(z0_im1).0; + let z1_re0 = as_arrays_mut::<8, _>(z1_re0).0; + let z1_re1 = as_arrays_mut::<8, _>(z1_re1).0; + let z1_im0 = as_arrays_mut::<8, _>(z1_im0).0; + let z1_im1 = as_arrays_mut::<8, _>(z1_im1).0; + + let iter = + izip!(z0_re0, z0_re1, z0_im0, z0_im1, z1_re0, z1_re1, z1_im0, z1_im1); + for (z0_re0, z0_re1, z0_im0, z0_im1, z1_re0, z1_re1, z1_im0, z1_im1) in iter + { + let (z0_re, z0_im) = ((*z0_re0, *z0_re1), (*z0_im0, *z0_im1)); + let (z1_re, z1_im) = ((*z1_re0, *z1_re1), (*z1_im0, *z1_im1)); + let (z1w_re, z1w_im) = simd.cplx_mul(z1_re, z1_im, w1_re, w1_im); + + ((*z0_re0, *z0_re1), (*z0_im0, *z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1w_re, z1w_im); + ((*z1_re0, *z1_re1), (*z1_im0, *z1_im1)) = + simd.cplx_sub(z0_re, z0_im, z1w_re, z1w_im); + } + } + + m *= 2; + } + } + + // m = n / 8 + // t = 4 + { + let m = n / 8; + + let twid_re0 = as_arrays::<2, _>(&twid_re0[m..]).0; + let twid_re1 = as_arrays::<2, _>(&twid_re1[m..]).0; + let twid_im0 = as_arrays::<2, _>(&twid_im0[m..]).0; + let twid_im1 = as_arrays::<2, _>(&twid_im1[m..]).0; + + let data_re0 = as_arrays_mut::<8, _>(data_re0).0; + let data_re1 = as_arrays_mut::<8, _>(data_re1).0; + let data_im0 = as_arrays_mut::<8, _>(data_im0).0; + let data_im1 = as_arrays_mut::<8, _>(data_im1).0; + + let data_re0 = as_arrays_mut::<2, _>(data_re0).0; + let data_re1 = as_arrays_mut::<2, _>(data_re1).0; + let data_im0 = as_arrays_mut::<2, _>(data_im0).0; + let data_im1 = as_arrays_mut::<2, _>(data_im1).0; + + let iter = izip!( + data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, twid_im1 + ); + for (z0z1_re0, z0z1_re1, z0z1_im0, z0z1_im1, w1_re0, w1_re1, w1_im0, w1_im1) in iter + { + // 0 1 -> 0 0 0 0 1 1 1 1 + let permute = { + #[inline(always)] + |w: [f64; 2]| -> [f64; 8] { + let avx512f = simd.avx512f; + let w = cast(w); + let w01xxxxxx = avx512f._mm512_castpd128_pd512(w); + let idx = avx512f._mm512_setr_epi64(0, 0, 0, 0, 1, 1, 1, 1); + + cast(avx512f._mm512_permutexvar_pd(idx, w01xxxxxx)) + } + }; + + // 0 1 2 3 4 5 6 7 | 8 9 a b c d e f -> 0 1 2 3 8 9 a b | 4 5 6 7 c d e f + let interleave = { + #[inline(always)] + |z0z0z1z1: [[f64; 8]; 2]| -> [[f64; 8]; 2] { + let avx512f = simd.avx512f; + let idx_0 = + avx512f._mm512_setr_epi64(0x0, 0x1, 0x2, 0x3, 0x8, 0x9, 0xa, 0xb); + let idx_1 = + avx512f._mm512_setr_epi64(0x4, 0x5, 0x6, 0x7, 0xc, 0xd, 
0xe, 0xf); + [ + cast(avx512f._mm512_permutex2var_pd( + cast(z0z0z1z1[0]), + idx_0, + cast(z0z0z1z1[1]), + )), + cast(avx512f._mm512_permutex2var_pd( + cast(z0z0z1z1[0]), + idx_1, + cast(z0z0z1z1[1]), + )), + ] + } + }; + + let w1_re = (permute(*w1_re0), permute(*w1_re1)); + let w1_im = (permute(*w1_im0), permute(*w1_im1)); + + let [mut z0_re0, mut z1_re0] = interleave(*z0z1_re0); + let [mut z0_re1, mut z1_re1] = interleave(*z0z1_re1); + let [mut z0_im0, mut z1_im0] = interleave(*z0z1_im0); + let [mut z0_im1, mut z1_im1] = interleave(*z0z1_im1); + + let (z0_re, z0_im) = ((z0_re0, z0_re1), (z0_im0, z0_im1)); + let (z1_re, z1_im) = ((z1_re0, z1_re1), (z1_im0, z1_im1)); + let (z1w_re, z1w_im) = simd.cplx_mul(z1_re, z1_im, w1_re, w1_im); + + ((z0_re0, z0_re1), (z0_im0, z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1w_re, z1w_im); + ((z1_re0, z1_re1), (z1_im0, z1_im1)) = + simd.cplx_sub(z0_re, z0_im, z1w_re, z1w_im); + + *z0z1_re0 = interleave([z0_re0, z1_re0]); + *z0z1_re1 = interleave([z0_re1, z1_re1]); + *z0z1_im0 = interleave([z0_im0, z1_im0]); + *z0z1_im1 = interleave([z0_im1, z1_im1]); + } + } + + // m = n / 4 + // t = 2 + { + let m = n / 4; + + let twid_re0 = as_arrays::<4, _>(&twid_re0[m..]).0; + let twid_re1 = as_arrays::<4, _>(&twid_re1[m..]).0; + let twid_im0 = as_arrays::<4, _>(&twid_im0[m..]).0; + let twid_im1 = as_arrays::<4, _>(&twid_im1[m..]).0; + + let data_re0 = as_arrays_mut::<8, _>(data_re0).0; + let data_re1 = as_arrays_mut::<8, _>(data_re1).0; + let data_im0 = as_arrays_mut::<8, _>(data_im0).0; + let data_im1 = as_arrays_mut::<8, _>(data_im1).0; + + let data_re0 = as_arrays_mut::<2, _>(data_re0).0; + let data_re1 = as_arrays_mut::<2, _>(data_re1).0; + let data_im0 = as_arrays_mut::<2, _>(data_im0).0; + let data_im1 = as_arrays_mut::<2, _>(data_im1).0; + + let iter = izip!( + data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, twid_im1 + ); + for (z0z1_re0, z0z1_re1, z0z1_im0, z0z1_im1, w1_re0, w1_re1, w1_im0, w1_im1) in iter + { + // 0 1 2 3 -> 0 0 2 2 1 1 3 3 + let permute = { + #[inline(always)] + |w: [f64; 4]| -> [f64; 8] { + let avx512f = simd.avx512f; + let w = cast(w); + let w0123xxxx = avx512f._mm512_castpd256_pd512(w); + let idx = avx512f._mm512_setr_epi64(0, 0, 2, 2, 1, 1, 3, 3); + + cast(avx512f._mm512_permutexvar_pd(idx, w0123xxxx)) + } + }; + + // 0 1 2 3 4 5 6 7 | 8 9 a b c d e f -> 0 1 8 9 4 5 c d | 2 3 a b 6 7 e f + let interleave = { + #[inline(always)] + |z0z0z1z1: [[f64; 8]; 2]| -> [[f64; 8]; 2] { + let avx512f = simd.avx512f; + let idx_0 = + avx512f._mm512_setr_epi64(0x0, 0x1, 0x8, 0x9, 0x4, 0x5, 0xc, 0xd); + let idx_1 = + avx512f._mm512_setr_epi64(0x2, 0x3, 0xa, 0xb, 0x6, 0x7, 0xe, 0xf); + [ + cast(avx512f._mm512_permutex2var_pd( + cast(z0z0z1z1[0]), + idx_0, + cast(z0z0z1z1[1]), + )), + cast(avx512f._mm512_permutex2var_pd( + cast(z0z0z1z1[0]), + idx_1, + cast(z0z0z1z1[1]), + )), + ] + } + }; + + let w1_re = (permute(*w1_re0), permute(*w1_re1)); + let w1_im = (permute(*w1_im0), permute(*w1_im1)); + + let [mut z0_re0, mut z1_re0] = interleave(*z0z1_re0); + let [mut z0_re1, mut z1_re1] = interleave(*z0z1_re1); + let [mut z0_im0, mut z1_im0] = interleave(*z0z1_im0); + let [mut z0_im1, mut z1_im1] = interleave(*z0z1_im1); + + let (z0_re, z0_im) = ((z0_re0, z0_re1), (z0_im0, z0_im1)); + let (z1_re, z1_im) = ((z1_re0, z1_re1), (z1_im0, z1_im1)); + let (z1w_re, z1w_im) = simd.cplx_mul(z1_re, z1_im, w1_re, w1_im); + + ((z0_re0, z0_re1), (z0_im0, z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1w_re, z1w_im); + ((z1_re0, z1_re1), (z1_im0, z1_im1)) = + 
simd.cplx_sub(z0_re, z0_im, z1w_re, z1w_im); + + *z0z1_re0 = interleave([z0_re0, z1_re0]); + *z0z1_re1 = interleave([z0_re1, z1_re1]); + *z0z1_im0 = interleave([z0_im0, z1_im0]); + *z0z1_im1 = interleave([z0_im1, z1_im1]); + } + } + + // m = n / 2 + // t = 1 + { + let m = n / 2; + + let twid_re0 = as_arrays::<8, _>(&twid_re0[m..]).0; + let twid_re1 = as_arrays::<8, _>(&twid_re1[m..]).0; + let twid_im0 = as_arrays::<8, _>(&twid_im0[m..]).0; + let twid_im1 = as_arrays::<8, _>(&twid_im1[m..]).0; + + let data_re0 = as_arrays_mut::<8, _>(data_re0).0; + let data_re1 = as_arrays_mut::<8, _>(data_re1).0; + let data_im0 = as_arrays_mut::<8, _>(data_im0).0; + let data_im1 = as_arrays_mut::<8, _>(data_im1).0; + + let data_re0 = as_arrays_mut::<2, _>(data_re0).0; + let data_re1 = as_arrays_mut::<2, _>(data_re1).0; + let data_im0 = as_arrays_mut::<2, _>(data_im0).0; + let data_im1 = as_arrays_mut::<2, _>(data_im1).0; + + let iter = izip!( + data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, twid_im1 + ); + for (z0z1_re0, z0z1_re1, z0z1_im0, z0z1_im1, w1_re0, w1_re1, w1_im0, w1_im1) in iter + { + // 0 1 2 3 4 5 6 7 -> 0 4 1 5 2 6 3 7 + let permute = { + #[inline(always)] + |w: [f64; 8]| -> [f64; 8] { + let avx512f = simd.avx512f; + let w = cast(w); + let idx = avx512f._mm512_setr_epi64(0, 4, 1, 5, 2, 6, 3, 7); + cast(avx512f._mm512_permutexvar_pd(idx, w)) + } + }; + + // 0 1 2 3 4 5 6 7 | 8 9 a b c d e f -> 0 8 2 a 4 c 6 e | 1 9 3 b 5 d 7 f + // + // is its own inverse since: + // 0 8 2 a 4 c 6 e | 1 9 3 b 5 d 7 f -> 0 1 2 3 4 5 6 7 | 8 9 a b c d e f + let interleave = { + #[inline(always)] + |z0z1: [[f64; 8]; 2]| -> [[f64; 8]; 2] { + let avx512f = simd.avx512f; + [ + cast(avx512f._mm512_unpacklo_pd(cast(z0z1[0]), cast(z0z1[1]))), + cast(avx512f._mm512_unpackhi_pd(cast(z0z1[0]), cast(z0z1[1]))), + ] + } + }; + + let w1_re = (permute(*w1_re0), permute(*w1_re1)); + let w1_im = (permute(*w1_im0), permute(*w1_im1)); + + let [mut z0_re0, mut z1_re0] = interleave(*z0z1_re0); + let [mut z0_re1, mut z1_re1] = interleave(*z0z1_re1); + let [mut z0_im0, mut z1_im0] = interleave(*z0z1_im0); + let [mut z0_im1, mut z1_im1] = interleave(*z0z1_im1); + + let (z0_re, z0_im) = ((z0_re0, z0_re1), (z0_im0, z0_im1)); + let (z1_re, z1_im) = ((z1_re0, z1_re1), (z1_im0, z1_im1)); + let (z1w_re, z1w_im) = simd.cplx_mul(z1_re, z1_im, w1_re, w1_im); + + ((z0_re0, z0_re1), (z0_im0, z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1w_re, z1w_im); + ((z1_re0, z1_re1), (z1_im0, z1_im1)) = + simd.cplx_sub(z0_re, z0_im, z1w_re, z1w_im); + + *z0z1_re0 = interleave([z0_re0, z1_re0]); + *z0z1_re1 = interleave([z0_re1, z1_re1]); + *z0z1_im0 = interleave([z0_im0, z1_im0]); + *z0z1_im1 = interleave([z0_im1, z1_im1]); + } + } + } + }); +} + +#[doc(hidden)] +pub fn negacyclic_fwd_fft( + data_re0: &mut [f64], + data_re1: &mut [f64], + data_im0: &mut [f64], + data_im1: &mut [f64], + twid_re0: &[f64], + twid_re1: &[f64], + twid_im0: &[f64], + twid_im1: &[f64], +) { + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + #[cfg(feature = "nightly")] + if let Some(simd) = Avx512::try_new() { + return negacyclic_fwd_fft_avx512( + simd, data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, + twid_im1, + ); + } + if let Some(simd) = Avx::try_new() { + return negacyclic_fwd_fft_avxfma( + simd, data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, + twid_im1, + ); + } + } + negacyclic_fwd_fft_scalar( + data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, twid_im1, + ) +} + 
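+// Like `negacyclic_fwd_fft` above, `negacyclic_inv_fft` below dispatches at runtime to the
+// widest implementation available on the host CPU: AVX512F (behind the `nightly` feature),
+// then AVX+FMA, then the scalar fallback. The vectorized paths assert a transform size of at
+// least 32, which `Plan::new` also enforces. Each of the four data slices carries one of the
+// two `f64` words that make up the 128-bit real or imaginary part of each coefficient,
+// matching the layout used by the tests at the bottom of this module.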
+#[doc(hidden)] +pub fn negacyclic_inv_fft( + data_re0: &mut [f64], + data_re1: &mut [f64], + data_im0: &mut [f64], + data_im1: &mut [f64], + twid_re0: &[f64], + twid_re1: &[f64], + twid_im0: &[f64], + twid_im1: &[f64], +) { + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + #[cfg(feature = "nightly")] + if let Some(simd) = Avx512::try_new() { + return negacyclic_inv_fft_avx512( + simd, data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, + twid_im1, + ); + } + if let Some(simd) = Avx::try_new() { + return negacyclic_inv_fft_avxfma( + simd, data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, + twid_im1, + ); + } + } + negacyclic_inv_fft_scalar( + data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, twid_im1, + ) +} + +#[doc(hidden)] +pub fn negacyclic_inv_fft_scalar( + data_re0: &mut [f64], + data_re1: &mut [f64], + data_im0: &mut [f64], + data_im1: &mut [f64], + twid_re0: &[f64], + twid_re1: &[f64], + twid_im0: &[f64], + twid_im1: &[f64], +) { + let n = data_re0.len(); + let mut t = 1; + let mut m = n; + let simd = Scalar; + + while m > 1 { + m /= 2; + + for i in 0..m { + let w1_re = (twid_re0[m + i], twid_re1[m + i]); + let w1_im = (twid_im0[m + i], twid_im1[m + i]); + + let start = 2 * i * t; + + let data_re0 = &mut data_re0[start..][..2 * t]; + let data_re1 = &mut data_re1[start..][..2 * t]; + let data_im0 = &mut data_im0[start..][..2 * t]; + let data_im1 = &mut data_im1[start..][..2 * t]; + + let (z0_re0, z1_re0) = data_re0.split_at_mut(t); + let (z0_re1, z1_re1) = data_re1.split_at_mut(t); + let (z0_im0, z1_im0) = data_im0.split_at_mut(t); + let (z0_im1, z1_im1) = data_im1.split_at_mut(t); + + for (z0_re0, z0_re1, z0_im0, z0_im1, z1_re0, z1_re1, z1_im0, z1_im1) in + izip!(z0_re0, z0_re1, z0_im0, z0_im1, z1_re0, z1_re1, z1_im0, z1_im1) + { + let (z0_re, z0_im) = ((*z0_re0, *z0_re1), (*z0_im0, *z0_im1)); + let (z1_re, z1_im) = ((*z1_re0, *z1_re1), (*z1_im0, *z1_im1)); + let (z0mz1_re, z0mz1_im) = simd.cplx_sub(z0_re, z0_im, z1_re, z1_im); + + ((*z0_re0, *z0_re1), (*z0_im0, *z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1_re, z1_im); + ((*z1_re0, *z1_re1), (*z1_im0, *z1_im1)) = + simd.cplx_mul_conj(z0mz1_re, z0mz1_im, w1_re, w1_im); + } + } + + t *= 2; + } +} + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +#[doc(hidden)] +pub fn negacyclic_inv_fft_avxfma( + simd: Avx, + data_re0: &mut [f64], + data_re1: &mut [f64], + data_im0: &mut [f64], + data_im1: &mut [f64], + twid_re0: &[f64], + twid_re1: &[f64], + twid_im0: &[f64], + twid_im1: &[f64], +) { + let n = data_re0.len(); + assert!(n >= 32); + + simd.vectorize({ + #[inline(always)] + || { + let mut t = 1; + let mut m = n; + + // m = n / 2 + // t = 1 + { + m /= 2; + + let twid_re0 = as_arrays::<4, _>(&twid_re0[m..]).0; + let twid_re1 = as_arrays::<4, _>(&twid_re1[m..]).0; + let twid_im0 = as_arrays::<4, _>(&twid_im0[m..]).0; + let twid_im1 = as_arrays::<4, _>(&twid_im1[m..]).0; + + let data_re0 = as_arrays_mut::<4, _>(data_re0).0; + let data_re1 = as_arrays_mut::<4, _>(data_re1).0; + let data_im0 = as_arrays_mut::<4, _>(data_im0).0; + let data_im1 = as_arrays_mut::<4, _>(data_im1).0; + + let data_re0 = as_arrays_mut::<2, _>(data_re0).0; + let data_re1 = as_arrays_mut::<2, _>(data_re1).0; + let data_im0 = as_arrays_mut::<2, _>(data_im0).0; + let data_im1 = as_arrays_mut::<2, _>(data_im1).0; + + let iter = izip!( + data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, twid_im1 + ); + for (z0z1_re0, z0z1_re1, z0z1_im0, z0z1_im1, w1_re0, w1_re1, w1_im0, 
w1_im1) in iter + { + // 0 1 2 3 -> 0 2 1 3 + let permute = { + #[inline(always)] + |w: [f64; 4]| -> [f64; 4] { + let avx = simd.avx; + let w0123 = cast(w); + let w0101 = avx._mm256_permute2f128_pd::<0b00000000>(w0123, w0123); + let w2323 = avx._mm256_permute2f128_pd::<0b00110011>(w0123, w0123); + let w0213 = avx._mm256_shuffle_pd::<0b1100>(w0101, w2323); + cast(w0213) + } + }; + + // 0 1 2 3 | 4 5 6 7 -> 0 4 2 6 | 1 5 3 7 + // + // is its own inverse since: + // 0 4 2 6 | 1 5 3 7 -> 0 1 2 3 | 4 5 6 7 + let interleave = { + #[inline(always)] + |z0z1: [[f64; 4]; 2]| -> [[f64; 4]; 2] { + [ + cast(simd.avx._mm256_unpacklo_pd(cast(z0z1[0]), cast(z0z1[1]))), + cast(simd.avx._mm256_unpackhi_pd(cast(z0z1[0]), cast(z0z1[1]))), + ] + } + }; + + let w1_re = (permute(*w1_re0), permute(*w1_re1)); + let w1_im = (permute(*w1_im0), permute(*w1_im1)); + + let [mut z0_re0, mut z1_re0] = interleave(*z0z1_re0); + let [mut z0_re1, mut z1_re1] = interleave(*z0z1_re1); + let [mut z0_im0, mut z1_im0] = interleave(*z0z1_im0); + let [mut z0_im1, mut z1_im1] = interleave(*z0z1_im1); + + let (z0_re, z0_im) = ((z0_re0, z0_re1), (z0_im0, z0_im1)); + let (z1_re, z1_im) = ((z1_re0, z1_re1), (z1_im0, z1_im1)); + let (z0mz1_re, z0mz1_im) = simd.cplx_sub(z0_re, z0_im, z1_re, z1_im); + + ((z0_re0, z0_re1), (z0_im0, z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1_re, z1_im); + ((z1_re0, z1_re1), (z1_im0, z1_im1)) = + simd.cplx_mul_conj(z0mz1_re, z0mz1_im, w1_re, w1_im); + + *z0z1_re0 = interleave([z0_re0, z1_re0]); + *z0z1_re1 = interleave([z0_re1, z1_re1]); + *z0z1_im0 = interleave([z0_im0, z1_im0]); + *z0z1_im1 = interleave([z0_im1, z1_im1]); + } + + t *= 2; + } + + // m = n / 4 + // t = 2 + { + m /= 2; + + let twid_re0 = as_arrays::<2, _>(&twid_re0[m..]).0; + let twid_re1 = as_arrays::<2, _>(&twid_re1[m..]).0; + let twid_im0 = as_arrays::<2, _>(&twid_im0[m..]).0; + let twid_im1 = as_arrays::<2, _>(&twid_im1[m..]).0; + + let data_re0 = as_arrays_mut::<4, _>(data_re0).0; + let data_re1 = as_arrays_mut::<4, _>(data_re1).0; + let data_im0 = as_arrays_mut::<4, _>(data_im0).0; + let data_im1 = as_arrays_mut::<4, _>(data_im1).0; + + let data_re0 = as_arrays_mut::<2, _>(data_re0).0; + let data_re1 = as_arrays_mut::<2, _>(data_re1).0; + let data_im0 = as_arrays_mut::<2, _>(data_im0).0; + let data_im1 = as_arrays_mut::<2, _>(data_im1).0; + + let iter = izip!( + data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, twid_im1 + ); + for ( + z0z0z1z1_re0, + z0z0z1z1_re1, + z0z0z1z1_im0, + z0z0z1z1_im1, + w1_re0, + w1_re1, + w1_im0, + w1_im1, + ) in iter + { + // 0 1 2 3 | 4 5 6 7 -> 0 1 4 5 | 2 3 6 7 + // + // is its own inverse since: + // 0 1 4 5 | 2 3 6 7 -> 0 1 2 3 | 4 5 6 7 + let interleave = { + #[inline(always)] + |z0z0z1z1: [[f64; 4]; 2]| -> [[f64; 4]; 2] { + [ + cast(simd.avx._mm256_permute2f128_pd::<0b00100000>( + cast(z0z0z1z1[0]), + cast(z0z0z1z1[1]), + )), + cast(simd.avx._mm256_permute2f128_pd::<0b00110001>( + cast(z0z0z1z1[0]), + cast(z0z0z1z1[1]), + )), + ] + } + }; + + let splat2 = { + #[inline(always)] + |w: [f64; 2]| -> [f64; 4] { + let w00 = simd.sse2._mm_set1_pd(w[0]); + let w11 = simd.sse2._mm_set1_pd(w[1]); + + let w0011 = simd.avx._mm256_insertf128_pd::<0b1>( + simd.avx._mm256_castpd128_pd256(w00), + w11, + ); + + cast(w0011) + } + }; + + let w1_re = (splat2(*w1_re0), splat2(*w1_re1)); + let w1_im = (splat2(*w1_im0), splat2(*w1_im1)); + + let [mut z0_re0, mut z1_re0] = interleave(*z0z0z1z1_re0); + let [mut z0_re1, mut z1_re1] = interleave(*z0z0z1z1_re1); + let [mut z0_im0, mut z1_im0] = 
interleave(*z0z0z1z1_im0); + let [mut z0_im1, mut z1_im1] = interleave(*z0z0z1z1_im1); + + let (z0_re, z0_im) = ((z0_re0, z0_re1), (z0_im0, z0_im1)); + let (z1_re, z1_im) = ((z1_re0, z1_re1), (z1_im0, z1_im1)); + let (z0mz1_re, z0mz1_im) = simd.cplx_sub(z0_re, z0_im, z1_re, z1_im); + + ((z0_re0, z0_re1), (z0_im0, z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1_re, z1_im); + ((z1_re0, z1_re1), (z1_im0, z1_im1)) = + simd.cplx_mul_conj(z0mz1_re, z0mz1_im, w1_re, w1_im); + + *z0z0z1z1_re0 = interleave([z0_re0, z1_re0]); + *z0z0z1z1_re1 = interleave([z0_re1, z1_re1]); + *z0z0z1z1_im0 = interleave([z0_im0, z1_im0]); + *z0z0z1z1_im1 = interleave([z0_im1, z1_im1]); + } + + t *= 2; + } + + while m > 1 { + m /= 2; + + let twid_re0 = &twid_re0[m..]; + let twid_re1 = &twid_re1[m..]; + let twid_im0 = &twid_im0[m..]; + let twid_im1 = &twid_im1[m..]; + + let iter = izip!( + data_re0.chunks_mut(2 * t), + data_re1.chunks_mut(2 * t), + data_im0.chunks_mut(2 * t), + data_im1.chunks_mut(2 * t), + twid_re0, + twid_re1, + twid_im0, + twid_im1, + ); + for (data_re0, data_re1, data_im0, data_im1, w1_re0, w1_re1, w1_im0, w1_im1) in iter + { + let w1_re = (*w1_re0, *w1_re1); + let w1_im = (*w1_im0, *w1_im1); + + let w1_re = (simd.splat(w1_re.0), simd.splat(w1_re.1)); + let w1_im = (simd.splat(w1_im.0), simd.splat(w1_im.1)); + + let (z0_re0, z1_re0) = data_re0.split_at_mut(t); + let (z0_re1, z1_re1) = data_re1.split_at_mut(t); + let (z0_im0, z1_im0) = data_im0.split_at_mut(t); + let (z0_im1, z1_im1) = data_im1.split_at_mut(t); + + let z0_re0 = as_arrays_mut::<4, _>(z0_re0).0; + let z0_re1 = as_arrays_mut::<4, _>(z0_re1).0; + let z0_im0 = as_arrays_mut::<4, _>(z0_im0).0; + let z0_im1 = as_arrays_mut::<4, _>(z0_im1).0; + let z1_re0 = as_arrays_mut::<4, _>(z1_re0).0; + let z1_re1 = as_arrays_mut::<4, _>(z1_re1).0; + let z1_im0 = as_arrays_mut::<4, _>(z1_im0).0; + let z1_im1 = as_arrays_mut::<4, _>(z1_im1).0; + + let iter = + izip!(z0_re0, z0_re1, z0_im0, z0_im1, z1_re0, z1_re1, z1_im0, z1_im1); + for (z0_re0, z0_re1, z0_im0, z0_im1, z1_re0, z1_re1, z1_im0, z1_im1) in iter { + let (z0_re, z0_im) = ((*z0_re0, *z0_re1), (*z0_im0, *z0_im1)); + let (z1_re, z1_im) = ((*z1_re0, *z1_re1), (*z1_im0, *z1_im1)); + let (z0mz1_re, z0mz1_im) = simd.cplx_sub(z0_re, z0_im, z1_re, z1_im); + + ((*z0_re0, *z0_re1), (*z0_im0, *z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1_re, z1_im); + ((*z1_re0, *z1_re1), (*z1_im0, *z1_im1)) = + simd.cplx_mul_conj(z0mz1_re, z0mz1_im, w1_re, w1_im); + } + } + + t *= 2; + } + } + }); +} + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +#[cfg(feature = "nightly")] +#[doc(hidden)] +pub fn negacyclic_inv_fft_avx512( + simd: Avx512, + data_re0: &mut [f64], + data_re1: &mut [f64], + data_im0: &mut [f64], + data_im1: &mut [f64], + twid_re0: &[f64], + twid_re1: &[f64], + twid_im0: &[f64], + twid_im1: &[f64], +) { + let n = data_re0.len(); + assert!(n >= 32); + + simd.vectorize({ + #[inline(always)] + || { + let mut t = 1; + let mut m = n; + + // m = n / 2 + // t = 1 + { + m /= 2; + + let twid_re0 = as_arrays::<8, _>(&twid_re0[m..]).0; + let twid_re1 = as_arrays::<8, _>(&twid_re1[m..]).0; + let twid_im0 = as_arrays::<8, _>(&twid_im0[m..]).0; + let twid_im1 = as_arrays::<8, _>(&twid_im1[m..]).0; + + let data_re0 = as_arrays_mut::<8, _>(data_re0).0; + let data_re1 = as_arrays_mut::<8, _>(data_re1).0; + let data_im0 = as_arrays_mut::<8, _>(data_im0).0; + let data_im1 = as_arrays_mut::<8, _>(data_im1).0; + + let data_re0 = as_arrays_mut::<2, _>(data_re0).0; + let data_re1 = as_arrays_mut::<2, _>(data_re1).0; + 
let data_im0 = as_arrays_mut::<2, _>(data_im0).0; + let data_im1 = as_arrays_mut::<2, _>(data_im1).0; + + let iter = izip!( + data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, twid_im1 + ); + for (z0z1_re0, z0z1_re1, z0z1_im0, z0z1_im1, w1_re0, w1_re1, w1_im0, w1_im1) in iter + { + // 0 1 2 3 4 5 6 7 -> 0 4 1 5 2 6 3 7 + let permute = { + #[inline(always)] + |w: [f64; 8]| -> [f64; 8] { + let avx512f = simd.avx512f; + let w = cast(w); + let idx = avx512f._mm512_setr_epi64(0, 4, 1, 5, 2, 6, 3, 7); + cast(avx512f._mm512_permutexvar_pd(idx, w)) + } + }; + + // 0 1 2 3 4 5 6 7 | 8 9 a b c d e f -> 0 8 2 a 4 c 6 e | 1 9 3 b 5 d 7 f + // + // is its own inverse since: + // 0 8 2 a 4 c 6 e | 1 9 3 b 5 d 7 f -> 0 1 2 3 4 5 6 7 | 8 9 a b c d e f + let interleave = { + #[inline(always)] + |z0z1: [[f64; 8]; 2]| -> [[f64; 8]; 2] { + let avx512f = simd.avx512f; + [ + cast(avx512f._mm512_unpacklo_pd(cast(z0z1[0]), cast(z0z1[1]))), + cast(avx512f._mm512_unpackhi_pd(cast(z0z1[0]), cast(z0z1[1]))), + ] + } + }; + + let w1_re = (permute(*w1_re0), permute(*w1_re1)); + let w1_im = (permute(*w1_im0), permute(*w1_im1)); + + let [mut z0_re0, mut z1_re0] = interleave(*z0z1_re0); + let [mut z0_re1, mut z1_re1] = interleave(*z0z1_re1); + let [mut z0_im0, mut z1_im0] = interleave(*z0z1_im0); + let [mut z0_im1, mut z1_im1] = interleave(*z0z1_im1); + + let (z0_re, z0_im) = ((z0_re0, z0_re1), (z0_im0, z0_im1)); + let (z1_re, z1_im) = ((z1_re0, z1_re1), (z1_im0, z1_im1)); + let (z0mz1_re, z0mz1_im) = simd.cplx_sub(z0_re, z0_im, z1_re, z1_im); + + ((z0_re0, z0_re1), (z0_im0, z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1_re, z1_im); + ((z1_re0, z1_re1), (z1_im0, z1_im1)) = + simd.cplx_mul_conj(z0mz1_re, z0mz1_im, w1_re, w1_im); + + *z0z1_re0 = interleave([z0_re0, z1_re0]); + *z0z1_re1 = interleave([z0_re1, z1_re1]); + *z0z1_im0 = interleave([z0_im0, z1_im0]); + *z0z1_im1 = interleave([z0_im1, z1_im1]); + } + + t *= 2; + } + + // m = n / 4 + // t = 2 + { + m /= 2; + + let twid_re0 = as_arrays::<4, _>(&twid_re0[m..]).0; + let twid_re1 = as_arrays::<4, _>(&twid_re1[m..]).0; + let twid_im0 = as_arrays::<4, _>(&twid_im0[m..]).0; + let twid_im1 = as_arrays::<4, _>(&twid_im1[m..]).0; + + let data_re0 = as_arrays_mut::<8, _>(data_re0).0; + let data_re1 = as_arrays_mut::<8, _>(data_re1).0; + let data_im0 = as_arrays_mut::<8, _>(data_im0).0; + let data_im1 = as_arrays_mut::<8, _>(data_im1).0; + + let data_re0 = as_arrays_mut::<2, _>(data_re0).0; + let data_re1 = as_arrays_mut::<2, _>(data_re1).0; + let data_im0 = as_arrays_mut::<2, _>(data_im0).0; + let data_im1 = as_arrays_mut::<2, _>(data_im1).0; + + let iter = izip!( + data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, twid_im1 + ); + for ( + z0z0z1z1_re0, + z0z0z1z1_re1, + z0z0z1z1_im0, + z0z0z1z1_im1, + w1_re0, + w1_re1, + w1_im0, + w1_im1, + ) in iter + { + // 0 1 2 3 -> 0 0 2 2 1 1 3 3 + let permute = { + #[inline(always)] + |w: [f64; 4]| -> [f64; 8] { + let avx512f = simd.avx512f; + let w = cast(w); + let w0123xxxx = avx512f._mm512_castpd256_pd512(w); + let idx = avx512f._mm512_setr_epi64(0, 0, 2, 2, 1, 1, 3, 3); + + cast(avx512f._mm512_permutexvar_pd(idx, w0123xxxx)) + } + }; + + // 0 1 2 3 4 5 6 7 | 8 9 a b c d e f -> 0 1 8 9 4 5 c d | 2 3 a b 6 7 e f + let interleave = { + #[inline(always)] + |z0z0z1z1: [[f64; 8]; 2]| -> [[f64; 8]; 2] { + let avx512f = simd.avx512f; + let idx_0 = + avx512f._mm512_setr_epi64(0x0, 0x1, 0x8, 0x9, 0x4, 0x5, 0xc, 0xd); + let idx_1 = + avx512f._mm512_setr_epi64(0x2, 0x3, 0xa, 0xb, 0x6, 0x7, 0xe, 0xf); + [ + 
cast(avx512f._mm512_permutex2var_pd( + cast(z0z0z1z1[0]), + idx_0, + cast(z0z0z1z1[1]), + )), + cast(avx512f._mm512_permutex2var_pd( + cast(z0z0z1z1[0]), + idx_1, + cast(z0z0z1z1[1]), + )), + ] + } + }; + + let w1_re = (permute(*w1_re0), permute(*w1_re1)); + let w1_im = (permute(*w1_im0), permute(*w1_im1)); + + let [mut z0_re0, mut z1_re0] = interleave(*z0z0z1z1_re0); + let [mut z0_re1, mut z1_re1] = interleave(*z0z0z1z1_re1); + let [mut z0_im0, mut z1_im0] = interleave(*z0z0z1z1_im0); + let [mut z0_im1, mut z1_im1] = interleave(*z0z0z1z1_im1); + + let (z0_re, z0_im) = ((z0_re0, z0_re1), (z0_im0, z0_im1)); + let (z1_re, z1_im) = ((z1_re0, z1_re1), (z1_im0, z1_im1)); + let (z0mz1_re, z0mz1_im) = simd.cplx_sub(z0_re, z0_im, z1_re, z1_im); + + ((z0_re0, z0_re1), (z0_im0, z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1_re, z1_im); + ((z1_re0, z1_re1), (z1_im0, z1_im1)) = + simd.cplx_mul_conj(z0mz1_re, z0mz1_im, w1_re, w1_im); + + *z0z0z1z1_re0 = interleave([z0_re0, z1_re0]); + *z0z0z1z1_re1 = interleave([z0_re1, z1_re1]); + *z0z0z1z1_im0 = interleave([z0_im0, z1_im0]); + *z0z0z1z1_im1 = interleave([z0_im1, z1_im1]); + } + + t *= 2; + } + + // m = n / 8 + // t = 4 + { + m /= 2; + + let twid_re0 = as_arrays::<2, _>(&twid_re0[m..]).0; + let twid_re1 = as_arrays::<2, _>(&twid_re1[m..]).0; + let twid_im0 = as_arrays::<2, _>(&twid_im0[m..]).0; + let twid_im1 = as_arrays::<2, _>(&twid_im1[m..]).0; + + let data_re0 = as_arrays_mut::<8, _>(data_re0).0; + let data_re1 = as_arrays_mut::<8, _>(data_re1).0; + let data_im0 = as_arrays_mut::<8, _>(data_im0).0; + let data_im1 = as_arrays_mut::<8, _>(data_im1).0; + + let data_re0 = as_arrays_mut::<2, _>(data_re0).0; + let data_re1 = as_arrays_mut::<2, _>(data_re1).0; + let data_im0 = as_arrays_mut::<2, _>(data_im0).0; + let data_im1 = as_arrays_mut::<2, _>(data_im1).0; + + let iter = izip!( + data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, twid_im1 + ); + for ( + z0z0z1z1_re0, + z0z0z1z1_re1, + z0z0z1z1_im0, + z0z0z1z1_im1, + w1_re0, + w1_re1, + w1_im0, + w1_im1, + ) in iter + { + // 0 1 -> 0 0 0 0 1 1 1 1 + let permute = { + #[inline(always)] + |w: [f64; 2]| -> [f64; 8] { + let avx512f = simd.avx512f; + let w = cast(w); + let w01xxxxxx = avx512f._mm512_castpd128_pd512(w); + let idx = avx512f._mm512_setr_epi64(0, 0, 0, 0, 1, 1, 1, 1); + + cast(avx512f._mm512_permutexvar_pd(idx, w01xxxxxx)) + } + }; + + // 0 1 2 3 4 5 6 7 | 8 9 a b c d e f -> 0 1 2 3 8 9 a b | 4 5 6 7 c d e f + let interleave = { + #[inline(always)] + |z0z0z1z1: [[f64; 8]; 2]| -> [[f64; 8]; 2] { + let avx512f = simd.avx512f; + let idx_0 = + avx512f._mm512_setr_epi64(0x0, 0x1, 0x2, 0x3, 0x8, 0x9, 0xa, 0xb); + let idx_1 = + avx512f._mm512_setr_epi64(0x4, 0x5, 0x6, 0x7, 0xc, 0xd, 0xe, 0xf); + [ + cast(avx512f._mm512_permutex2var_pd( + cast(z0z0z1z1[0]), + idx_0, + cast(z0z0z1z1[1]), + )), + cast(avx512f._mm512_permutex2var_pd( + cast(z0z0z1z1[0]), + idx_1, + cast(z0z0z1z1[1]), + )), + ] + } + }; + + let w1_re = (permute(*w1_re0), permute(*w1_re1)); + let w1_im = (permute(*w1_im0), permute(*w1_im1)); + + let [mut z0_re0, mut z1_re0] = interleave(*z0z0z1z1_re0); + let [mut z0_re1, mut z1_re1] = interleave(*z0z0z1z1_re1); + let [mut z0_im0, mut z1_im0] = interleave(*z0z0z1z1_im0); + let [mut z0_im1, mut z1_im1] = interleave(*z0z0z1z1_im1); + + let (z0_re, z0_im) = ((z0_re0, z0_re1), (z0_im0, z0_im1)); + let (z1_re, z1_im) = ((z1_re0, z1_re1), (z1_im0, z1_im1)); + let (z0mz1_re, z0mz1_im) = simd.cplx_sub(z0_re, z0_im, z1_re, z1_im); + + ((z0_re0, z0_re1), (z0_im0, z0_im1)) = + 
simd.cplx_add(z0_re, z0_im, z1_re, z1_im); + ((z1_re0, z1_re1), (z1_im0, z1_im1)) = + simd.cplx_mul_conj(z0mz1_re, z0mz1_im, w1_re, w1_im); + + *z0z0z1z1_re0 = interleave([z0_re0, z1_re0]); + *z0z0z1z1_re1 = interleave([z0_re1, z1_re1]); + *z0z0z1z1_im0 = interleave([z0_im0, z1_im0]); + *z0z0z1z1_im1 = interleave([z0_im1, z1_im1]); + } + + t *= 2; + } + + while m > 1 { + m /= 2; + + let twid_re0 = &twid_re0[m..]; + let twid_re1 = &twid_re1[m..]; + let twid_im0 = &twid_im0[m..]; + let twid_im1 = &twid_im1[m..]; + + let iter = izip!( + data_re0.chunks_mut(2 * t), + data_re1.chunks_mut(2 * t), + data_im0.chunks_mut(2 * t), + data_im1.chunks_mut(2 * t), + twid_re0, + twid_re1, + twid_im0, + twid_im1, + ); + for (data_re0, data_re1, data_im0, data_im1, w1_re0, w1_re1, w1_im0, w1_im1) in iter + { + let w1_re = (*w1_re0, *w1_re1); + let w1_im = (*w1_im0, *w1_im1); + + let w1_re = (simd.splat(w1_re.0), simd.splat(w1_re.1)); + let w1_im = (simd.splat(w1_im.0), simd.splat(w1_im.1)); + + let (z0_re0, z1_re0) = data_re0.split_at_mut(t); + let (z0_re1, z1_re1) = data_re1.split_at_mut(t); + let (z0_im0, z1_im0) = data_im0.split_at_mut(t); + let (z0_im1, z1_im1) = data_im1.split_at_mut(t); + + let z0_re0 = as_arrays_mut::<8, _>(z0_re0).0; + let z0_re1 = as_arrays_mut::<8, _>(z0_re1).0; + let z0_im0 = as_arrays_mut::<8, _>(z0_im0).0; + let z0_im1 = as_arrays_mut::<8, _>(z0_im1).0; + let z1_re0 = as_arrays_mut::<8, _>(z1_re0).0; + let z1_re1 = as_arrays_mut::<8, _>(z1_re1).0; + let z1_im0 = as_arrays_mut::<8, _>(z1_im0).0; + let z1_im1 = as_arrays_mut::<8, _>(z1_im1).0; + + let iter = + izip!(z0_re0, z0_re1, z0_im0, z0_im1, z1_re0, z1_re1, z1_im0, z1_im1); + for (z0_re0, z0_re1, z0_im0, z0_im1, z1_re0, z1_re1, z1_im0, z1_im1) in iter { + let (z0_re, z0_im) = ((*z0_re0, *z0_re1), (*z0_im0, *z0_im1)); + let (z1_re, z1_im) = ((*z1_re0, *z1_re1), (*z1_im0, *z1_im1)); + let (z0mz1_re, z0mz1_im) = simd.cplx_sub(z0_re, z0_im, z1_re, z1_im); + + ((*z0_re0, *z0_re1), (*z0_im0, *z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1_re, z1_im); + ((*z1_re0, *z1_re1), (*z1_im0, *z1_im1)) = + simd.cplx_mul_conj(z0mz1_re, z0mz1_im, w1_re, w1_im); + } + } + + t *= 2; + } + } + }); +} + +fn bitreverse(i: usize, n: usize) -> usize { + let logn = n.trailing_zeros(); + let mut result = 0; + for k in 0..logn { + let kth_bit = (i >> k) & 1_usize; + result |= kth_bit << (logn - k - 1); + } + result +} + +#[doc(hidden)] +pub fn init_negacyclic_twiddles( + twid_re0: &mut [f64], + twid_re1: &mut [f64], + twid_im0: &mut [f64], + twid_im1: &mut [f64], +) { + let n = twid_re0.len(); + let mut m = 1_usize; + + while m < n { + for i in 0..m { + let k = 2 * m + i; + let pos = m + i; + + let theta_over_pi = f128(bitreverse(k, 2 * n) as f64 / (2 * n) as f64, 0.0); + let (s, c) = theta_over_pi.sincospi(); + twid_re0[pos] = c.0; + twid_re1[pos] = c.1; + twid_im0[pos] = s.0; + twid_im1[pos] = s.1; + } + m *= 2; + } +} + +/// 128-bit negacyclic FFT plan. 
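+///
+/// The plan precomputes the negacyclic twiddle factors for a fixed power-of-two transform
+/// size, each split into two `f64` words. The transforms run in place over four `f64`
+/// buffers holding the corresponding words of the real and imaginary parts of the input;
+/// see `fwd` and `inv` below.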
+#[derive(Clone)] +pub struct Plan { + twid_re0: ABox<[f64]>, + twid_re1: ABox<[f64]>, + twid_im0: ABox<[f64]>, + twid_im1: ABox<[f64]>, +} + +impl core::fmt::Debug for Plan { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("Plan") + .field("fft_size", &self.fft_size()) + .finish() + } +} + +impl Plan { + /// Returns a new negacyclic FFT plan for the given vector size, following the algorithm in + /// [Fast and Error-Free Negacyclic Integer Convolution using Extended Fourier Transform][paper] + /// + /// # Panics + /// + /// - Panics if `n` is not a power of two, or if it is less than `32`. + /// + /// # Example + /// + /// ``` + /// use concrete_fft::fft128::Plan; + /// let plan = Plan::new(32); + /// ``` + /// + /// [paper]: https://eprint.iacr.org/2021/480 + #[track_caller] + pub fn new(n: usize) -> Self { + assert!(n.is_power_of_two()); + assert!(n >= 32); + + let mut twid_re0 = avec![0.0f64; n].into_boxed_slice(); + let mut twid_re1 = avec![0.0f64; n].into_boxed_slice(); + let mut twid_im0 = avec![0.0f64; n].into_boxed_slice(); + let mut twid_im1 = avec![0.0f64; n].into_boxed_slice(); + + init_negacyclic_twiddles(&mut twid_re0, &mut twid_re1, &mut twid_im0, &mut twid_im1); + + Self { + twid_re0, + twid_re1, + twid_im0, + twid_im1, + } + } + + /// Returns the vector size of the negacyclic FFT. + /// + /// # Example + /// + /// ``` + /// use concrete_fft::fft128::Plan; + /// let plan = Plan::new(32); + /// assert_eq!(plan.fft_size(), 32); + /// ``` + pub fn fft_size(&self) -> usize { + self.twid_re0.len() + } + + /// Performs a forward negacyclic FFT in place. + /// + /// # Note + /// + /// The values in `buf_re0`, `buf_re1`, `buf_im0`, `buf_im1` must be in standard order prior to + /// calling this function. When this function returns, the values in `buf_re0`, `buf_re1`, `buf_im0`, `buf_im1` will contain the + /// terms of the forward transform in bit-reversed order. + #[track_caller] + pub fn fwd( + &self, + buf_re0: &mut [f64], + buf_re1: &mut [f64], + buf_im0: &mut [f64], + buf_im1: &mut [f64], + ) { + assert_eq!(buf_re0.len(), self.fft_size()); + assert_eq!(buf_re1.len(), self.fft_size()); + assert_eq!(buf_im0.len(), self.fft_size()); + assert_eq!(buf_im1.len(), self.fft_size()); + + negacyclic_fwd_fft( + buf_re0, + buf_re1, + buf_im0, + buf_im1, + &self.twid_re0, + &self.twid_re1, + &self.twid_im0, + &self.twid_im1, + ); + } + + /// Performs an inverse negacyclic FFT in place. + /// + /// # Note + /// + /// The values in `buf_re0`, `buf_re1`, `buf_im0`, `buf_im1` must be in bit-reversed order + /// prior to calling this function. When this function returns, the values in `buf_re0`, + /// `buf_re1`, `buf_im0`, `buf_im1` will contain the terms of the inverse transform in standard + /// order. 
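+    ///
+    /// # Example
+    ///
+    /// A minimal forward/inverse round trip on zero-filled buffers (for illustration only;
+    /// as in this module's tests, any rescaling of the inverse output is applied by the
+    /// caller):
+    ///
+    /// ```
+    /// use concrete_fft::fft128::Plan;
+    ///
+    /// let plan = Plan::new(32);
+    ///
+    /// let mut buf_re0 = vec![0.0; 32];
+    /// let mut buf_re1 = vec![0.0; 32];
+    /// let mut buf_im0 = vec![0.0; 32];
+    /// let mut buf_im1 = vec![0.0; 32];
+    ///
+    /// plan.fwd(&mut buf_re0, &mut buf_re1, &mut buf_im0, &mut buf_im1);
+    /// plan.inv(&mut buf_re0, &mut buf_re1, &mut buf_im0, &mut buf_im1);
+    /// ```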
+ #[track_caller] + pub fn inv( + &self, + buf_re0: &mut [f64], + buf_re1: &mut [f64], + buf_im0: &mut [f64], + buf_im1: &mut [f64], + ) { + assert_eq!(buf_re0.len(), self.fft_size()); + assert_eq!(buf_re1.len(), self.fft_size()); + assert_eq!(buf_im0.len(), self.fft_size()); + assert_eq!(buf_im1.len(), self.fft_size()); + + negacyclic_inv_fft( + buf_re0, + buf_re1, + buf_im0, + buf_im1, + &self.twid_re0, + &self.twid_re1, + &self.twid_im0, + &self.twid_im1, + ); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use alloc::vec; + use rand::random; + + extern crate alloc; + + #[test] + fn test_wrapper() { + let n = 1024; + + let mut lhs = vec![f128(0.0, 0.0); n]; + let mut rhs = vec![f128(0.0, 0.0); n]; + let mut result = vec![f128(0.0, 0.0); n]; + + for x in &mut lhs { + x.0 = random(); + } + for x in &mut rhs { + x.0 = random(); + } + + let mut full_convolution = vec![f128(0.0, 0.0); 2 * n]; + let mut negacyclic_convolution = vec![f128(0.0, 0.0); n]; + for i in 0..n { + for j in 0..n { + full_convolution[i + j] += lhs[i] * rhs[j]; + } + } + for i in 0..n { + negacyclic_convolution[i] = full_convolution[i] - full_convolution[i + n]; + } + + let mut lhs_fourier_re0 = vec![0.0; n / 2]; + let mut lhs_fourier_re1 = vec![0.0; n / 2]; + let mut lhs_fourier_im0 = vec![0.0; n / 2]; + let mut lhs_fourier_im1 = vec![0.0; n / 2]; + + let mut rhs_fourier_re0 = vec![0.0; n / 2]; + let mut rhs_fourier_re1 = vec![0.0; n / 2]; + let mut rhs_fourier_im0 = vec![0.0; n / 2]; + let mut rhs_fourier_im1 = vec![0.0; n / 2]; + + for i in 0..n / 2 { + lhs_fourier_re0[i] = lhs[i].0; + lhs_fourier_re1[i] = lhs[i].1; + lhs_fourier_im0[i] = lhs[i + n / 2].0; + lhs_fourier_im1[i] = lhs[i + n / 2].1; + + rhs_fourier_re0[i] = rhs[i].0; + rhs_fourier_re1[i] = rhs[i].1; + rhs_fourier_im0[i] = rhs[i + n / 2].0; + rhs_fourier_im1[i] = rhs[i + n / 2].1; + } + + let plan = Plan::new(n / 2); + + plan.fwd( + &mut lhs_fourier_re0, + &mut lhs_fourier_re1, + &mut lhs_fourier_im0, + &mut lhs_fourier_im1, + ); + plan.fwd( + &mut rhs_fourier_re0, + &mut rhs_fourier_re1, + &mut rhs_fourier_im0, + &mut rhs_fourier_im1, + ); + + let factor = 2.0 / n as f64; + let simd = Scalar; + for i in 0..n / 2 { + let (prod_re, prod_im) = simd.cplx_mul( + (lhs_fourier_re0[i], lhs_fourier_re1[i]), + (lhs_fourier_im0[i], lhs_fourier_im1[i]), + (rhs_fourier_re0[i], rhs_fourier_re1[i]), + (rhs_fourier_im0[i], rhs_fourier_im1[i]), + ); + + lhs_fourier_re0[i] = prod_re.0 * factor; + lhs_fourier_re1[i] = prod_re.1 * factor; + lhs_fourier_im0[i] = prod_im.0 * factor; + lhs_fourier_im1[i] = prod_im.1 * factor; + } + + plan.inv( + &mut lhs_fourier_re0, + &mut lhs_fourier_re1, + &mut lhs_fourier_im0, + &mut lhs_fourier_im1, + ); + + for i in 0..n / 2 { + result[i] = f128(lhs_fourier_re0[i], lhs_fourier_re1[i]); + result[i + n / 2] = f128(lhs_fourier_im0[i], lhs_fourier_im1[i]); + } + + for i in 0..n { + assert!((result[i] - negacyclic_convolution[i]).abs() < 1e-28); + } + } + + #[test] + fn test_product() { + let n = 1024; + + let mut lhs = vec![f128(0.0, 0.0); n]; + let mut rhs = vec![f128(0.0, 0.0); n]; + let mut result = vec![f128(0.0, 0.0); n]; + + for x in &mut lhs { + x.0 = random(); + } + for x in &mut rhs { + x.0 = random(); + } + + let mut full_convolution = vec![f128(0.0, 0.0); 2 * n]; + let mut negacyclic_convolution = vec![f128(0.0, 0.0); n]; + for i in 0..n { + for j in 0..n { + full_convolution[i + j] += lhs[i] * rhs[j]; + } + } + for i in 0..n { + negacyclic_convolution[i] = full_convolution[i] - full_convolution[i + n]; + } + + let mut 
twid_re0 = vec![0.0; n / 2]; + let mut twid_re1 = vec![0.0; n / 2]; + let mut twid_im0 = vec![0.0; n / 2]; + let mut twid_im1 = vec![0.0; n / 2]; + + let mut lhs_fourier_re0 = vec![0.0; n / 2]; + let mut lhs_fourier_re1 = vec![0.0; n / 2]; + let mut lhs_fourier_im0 = vec![0.0; n / 2]; + let mut lhs_fourier_im1 = vec![0.0; n / 2]; + + let mut rhs_fourier_re0 = vec![0.0; n / 2]; + let mut rhs_fourier_re1 = vec![0.0; n / 2]; + let mut rhs_fourier_im0 = vec![0.0; n / 2]; + let mut rhs_fourier_im1 = vec![0.0; n / 2]; + + init_negacyclic_twiddles(&mut twid_re0, &mut twid_re1, &mut twid_im0, &mut twid_im1); + + for i in 0..n / 2 { + lhs_fourier_re0[i] = lhs[i].0; + lhs_fourier_re1[i] = lhs[i].1; + lhs_fourier_im0[i] = lhs[i + n / 2].0; + lhs_fourier_im1[i] = lhs[i + n / 2].1; + + rhs_fourier_re0[i] = rhs[i].0; + rhs_fourier_re1[i] = rhs[i].1; + rhs_fourier_im0[i] = rhs[i + n / 2].0; + rhs_fourier_im1[i] = rhs[i + n / 2].1; + } + + negacyclic_fwd_fft_scalar( + &mut lhs_fourier_re0, + &mut lhs_fourier_re1, + &mut lhs_fourier_im0, + &mut lhs_fourier_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + negacyclic_fwd_fft_scalar( + &mut rhs_fourier_re0, + &mut rhs_fourier_re1, + &mut rhs_fourier_im0, + &mut rhs_fourier_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + + let factor = 2.0 / n as f64; + let simd = Scalar; + for i in 0..n / 2 { + let (prod_re, prod_im) = simd.cplx_mul( + (lhs_fourier_re0[i], lhs_fourier_re1[i]), + (lhs_fourier_im0[i], lhs_fourier_im1[i]), + (rhs_fourier_re0[i], rhs_fourier_re1[i]), + (rhs_fourier_im0[i], rhs_fourier_im1[i]), + ); + + lhs_fourier_re0[i] = prod_re.0 * factor; + lhs_fourier_re1[i] = prod_re.1 * factor; + lhs_fourier_im0[i] = prod_im.0 * factor; + lhs_fourier_im1[i] = prod_im.1 * factor; + } + + negacyclic_inv_fft_scalar( + &mut lhs_fourier_re0, + &mut lhs_fourier_re1, + &mut lhs_fourier_im0, + &mut lhs_fourier_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + + for i in 0..n / 2 { + result[i] = f128(lhs_fourier_re0[i], lhs_fourier_re1[i]); + result[i + n / 2] = f128(lhs_fourier_im0[i], lhs_fourier_im1[i]); + } + + for i in 0..n { + assert!((result[i] - negacyclic_convolution[i]).abs() < 1e-28); + } + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + #[test] + fn test_product_avxfma() { + if let Some(simd) = Avx::try_new() { + let n = 1024; + + let mut lhs = vec![f128(0.0, 0.0); n]; + let mut rhs = vec![f128(0.0, 0.0); n]; + let mut result = vec![f128(0.0, 0.0); n]; + + for x in &mut lhs { + x.0 = random(); + } + for x in &mut rhs { + x.0 = random(); + } + + let mut full_convolution = vec![f128(0.0, 0.0); 2 * n]; + let mut negacyclic_convolution = vec![f128(0.0, 0.0); n]; + for i in 0..n { + for j in 0..n { + full_convolution[i + j] += lhs[i] * rhs[j]; + } + } + for i in 0..n { + negacyclic_convolution[i] = full_convolution[i] - full_convolution[i + n]; + } + + let mut twid_re0 = vec![0.0; n / 2]; + let mut twid_re1 = vec![0.0; n / 2]; + let mut twid_im0 = vec![0.0; n / 2]; + let mut twid_im1 = vec![0.0; n / 2]; + + let mut lhs_fourier_re0 = vec![0.0; n / 2]; + let mut lhs_fourier_re1 = vec![0.0; n / 2]; + let mut lhs_fourier_im0 = vec![0.0; n / 2]; + let mut lhs_fourier_im1 = vec![0.0; n / 2]; + + let mut rhs_fourier_re0 = vec![0.0; n / 2]; + let mut rhs_fourier_re1 = vec![0.0; n / 2]; + let mut rhs_fourier_im0 = vec![0.0; n / 2]; + let mut rhs_fourier_im1 = vec![0.0; n / 2]; + + init_negacyclic_twiddles(&mut twid_re0, &mut twid_re1, &mut twid_im0, &mut twid_im1); + + for i in 0..n / 2 { + 
lhs_fourier_re0[i] = lhs[i].0; + lhs_fourier_re1[i] = lhs[i].1; + lhs_fourier_im0[i] = lhs[i + n / 2].0; + lhs_fourier_im1[i] = lhs[i + n / 2].1; + + rhs_fourier_re0[i] = rhs[i].0; + rhs_fourier_re1[i] = rhs[i].1; + rhs_fourier_im0[i] = rhs[i + n / 2].0; + rhs_fourier_im1[i] = rhs[i + n / 2].1; + } + + negacyclic_fwd_fft_avxfma( + simd, + &mut lhs_fourier_re0, + &mut lhs_fourier_re1, + &mut lhs_fourier_im0, + &mut lhs_fourier_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + negacyclic_fwd_fft_avxfma( + simd, + &mut rhs_fourier_re0, + &mut rhs_fourier_re1, + &mut rhs_fourier_im0, + &mut rhs_fourier_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + + let factor = 2.0 / n as f64; + let scalar = Scalar; + for i in 0..n / 2 { + let (prod_re, prod_im) = scalar.cplx_mul( + (lhs_fourier_re0[i], lhs_fourier_re1[i]), + (lhs_fourier_im0[i], lhs_fourier_im1[i]), + (rhs_fourier_re0[i], rhs_fourier_re1[i]), + (rhs_fourier_im0[i], rhs_fourier_im1[i]), + ); + + lhs_fourier_re0[i] = prod_re.0 * factor; + lhs_fourier_re1[i] = prod_re.1 * factor; + lhs_fourier_im0[i] = prod_im.0 * factor; + lhs_fourier_im1[i] = prod_im.1 * factor; + } + + negacyclic_inv_fft_avxfma( + simd, + &mut lhs_fourier_re0, + &mut lhs_fourier_re1, + &mut lhs_fourier_im0, + &mut lhs_fourier_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + + for i in 0..n / 2 { + result[i] = f128(lhs_fourier_re0[i], lhs_fourier_re1[i]); + result[i + n / 2] = f128(lhs_fourier_im0[i], lhs_fourier_im1[i]); + } + + for i in 0..n { + assert!((result[i] - negacyclic_convolution[i]).abs() < 1e-28); + } + } + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + #[cfg(feature = "nightly")] + #[test] + fn test_product_avx512() { + if let Some(simd) = Avx512::try_new() { + let n = 1024; + + let mut lhs = vec![f128(0.0, 0.0); n]; + let mut rhs = vec![f128(0.0, 0.0); n]; + let mut result = vec![f128(0.0, 0.0); n]; + + for x in &mut lhs { + x.0 = random(); + } + for x in &mut rhs { + x.0 = random(); + } + + let mut full_convolution = vec![f128(0.0, 0.0); 2 * n]; + let mut negacyclic_convolution = vec![f128(0.0, 0.0); n]; + for i in 0..n { + for j in 0..n { + full_convolution[i + j] += lhs[i] * rhs[j]; + } + } + for i in 0..n { + negacyclic_convolution[i] = full_convolution[i] - full_convolution[i + n]; + } + + let mut twid_re0 = vec![0.0; n / 2]; + let mut twid_re1 = vec![0.0; n / 2]; + let mut twid_im0 = vec![0.0; n / 2]; + let mut twid_im1 = vec![0.0; n / 2]; + + let mut lhs_fourier_re0 = vec![0.0; n / 2]; + let mut lhs_fourier_re1 = vec![0.0; n / 2]; + let mut lhs_fourier_im0 = vec![0.0; n / 2]; + let mut lhs_fourier_im1 = vec![0.0; n / 2]; + + let mut rhs_fourier_re0 = vec![0.0; n / 2]; + let mut rhs_fourier_re1 = vec![0.0; n / 2]; + let mut rhs_fourier_im0 = vec![0.0; n / 2]; + let mut rhs_fourier_im1 = vec![0.0; n / 2]; + + init_negacyclic_twiddles(&mut twid_re0, &mut twid_re1, &mut twid_im0, &mut twid_im1); + + for i in 0..n / 2 { + lhs_fourier_re0[i] = lhs[i].0; + lhs_fourier_re1[i] = lhs[i].1; + lhs_fourier_im0[i] = lhs[i + n / 2].0; + lhs_fourier_im1[i] = lhs[i + n / 2].1; + + rhs_fourier_re0[i] = rhs[i].0; + rhs_fourier_re1[i] = rhs[i].1; + rhs_fourier_im0[i] = rhs[i + n / 2].0; + rhs_fourier_im1[i] = rhs[i + n / 2].1; + } + + negacyclic_fwd_fft_avx512( + simd, + &mut lhs_fourier_re0, + &mut lhs_fourier_re1, + &mut lhs_fourier_im0, + &mut lhs_fourier_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + negacyclic_fwd_fft_avx512( + simd, + &mut rhs_fourier_re0, + &mut rhs_fourier_re1, + &mut 
rhs_fourier_im0, + &mut rhs_fourier_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + + let factor = 2.0 / n as f64; + let scalar = Scalar; + for i in 0..n / 2 { + let (prod_re, prod_im) = scalar.cplx_mul( + (lhs_fourier_re0[i], lhs_fourier_re1[i]), + (lhs_fourier_im0[i], lhs_fourier_im1[i]), + (rhs_fourier_re0[i], rhs_fourier_re1[i]), + (rhs_fourier_im0[i], rhs_fourier_im1[i]), + ); + + lhs_fourier_re0[i] = prod_re.0 * factor; + lhs_fourier_re1[i] = prod_re.1 * factor; + lhs_fourier_im0[i] = prod_im.0 * factor; + lhs_fourier_im1[i] = prod_im.1 * factor; + } + + negacyclic_inv_fft_avx512( + simd, + &mut lhs_fourier_re0, + &mut lhs_fourier_re1, + &mut lhs_fourier_im0, + &mut lhs_fourier_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + + for i in 0..n / 2 { + result[i] = f128(lhs_fourier_re0[i], lhs_fourier_re1[i]); + result[i + n / 2] = f128(lhs_fourier_im0[i], lhs_fourier_im1[i]); + } + + for i in 0..n { + assert!((result[i] - negacyclic_convolution[i]).abs() < 1e-28); + } + } + } +} diff --git a/src/lib.rs b/src/lib.rs index cf76154..d95cae6 100755 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,10 +13,14 @@ //! convolution. The only operations that are performed in the Fourier domain are elementwise, and //! so the order of the coefficients does not affect the results. //! +//! Additionally, an optional 128-bit negacyclic FFT module is provided. +//! //! # Features //! //! - `std` (default): This enables runtime arch detection for accelerated SIMD instructions, and //! an FFT plan that measures the various implementations to choose the fastest one at runtime. +//! - `fft128` (default): This flag provides access to the 128-bit FFT, which is accessible in the +//! [`fft128`] module. //! - `nightly`: This enables unstable Rust features to further speed up the FFT, by enabling //! AVX512F instructions on CPUs that support them. This feature requires a nightly Rust //! toolchain. @@ -89,6 +93,9 @@ pub(crate) mod dit2; pub(crate) mod dit4; pub(crate) mod dit8; +#[cfg(feature = "fft128")] +#[cfg_attr(docsrs, doc(cfg(feature = "fft128")))] +pub mod fft128; pub mod ordered; pub mod unordered; diff --git a/src/ordered.rs b/src/ordered.rs index 4cd24da..737c645 100644 --- a/src/ordered.rs +++ b/src/ordered.rs @@ -87,7 +87,7 @@ fn measure_n_runs( #[cfg(feature = "std")] fn duration_div_f64(duration: Duration, n: f64) -> Duration { - Duration::from_secs_f64(duration.as_secs_f64() / n as f64) + Duration::from_secs_f64(duration.as_secs_f64() / n) } #[cfg(feature = "std")] diff --git a/src/unordered.rs b/src/unordered.rs index 962a76b..e565fa4 100644 --- a/src/unordered.rs +++ b/src/unordered.rs @@ -958,7 +958,7 @@ impl Plan { /// # Note /// /// The values in `buf` must be in permuted order prior to calling this function. - /// When this function returns, the values in `buf` will contain the terms of the forward + /// When this function returns, the values in `buf` will contain the terms of the inverse /// transform in standard order. 
/// /// # Example @@ -1144,7 +1144,7 @@ mod tests { ); let base_n = plan.algo().1; let mut mem = GlobalMemBuffer::new(plan.fft_scratch().unwrap()); - let stack = DynStack::new(&mut *mem); + let stack = DynStack::new(&mut mem); plan.fwd(&mut z, stack); for i in 0..n { @@ -1177,7 +1177,7 @@ mod tests { }, ); let mut mem = GlobalMemBuffer::new(plan.fft_scratch().unwrap()); - let mut stack = DynStack::new(&mut *mem); + let mut stack = DynStack::new(&mut mem); plan.fwd(&mut z, stack.rb_mut()); plan.inv(&mut z, stack); @@ -1232,7 +1232,7 @@ mod tests_serde { .unwrap() .or(plan2.fft_scratch().unwrap()), ); - let mut stack = DynStack::new(&mut *mem); + let mut stack = DynStack::new(&mut mem); plan1.fwd(&mut z, stack.rb_mut());