Skip to content

Commit 08744bc

Browse files
committed
improve dft init, recursive fft for radix-4
1 parent 3ea50d0 commit 08744bc

File tree

2 files changed

+47
-42
lines changed

2 files changed

+47
-42
lines changed

cp-algo/math/cvector.hpp

+38-34
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,23 @@ namespace cp_algo::math::fft {
112112

113113
void ifft() {
114114
size_t n = size();
115-
for(size_t i = flen; i <= n / 2; i *= 2) {
116-
if (4 * i <= n) { // radix-4
117-
exec_on_evals<4>(n / (4 * i), [&](size_t k, point rt) {
118-
k *= 4 * i;
115+
bool parity = std::countr_zero(n) % 2;
116+
if(parity) {
117+
exec_on_evals<2>(n / (2 * flen), [&](size_t k, point rt) {
118+
k *= 2 * flen;
119+
vpoint cvrt = {vz + real(rt), vz - imag(rt)};
120+
auto B = at(k) - at(k + flen);
121+
at(k) += at(k + flen);
122+
at(k + flen) = B * cvrt;
123+
});
124+
}
125+
126+
for(size_t leaf = 3 * flen; leaf < n; leaf += 4 * flen) {
127+
size_t level = std::countr_one(leaf + 3);
128+
for(size_t lvl = 4 + parity; lvl <= level; lvl += 2) {
129+
size_t i = (1 << lvl) / 4;
130+
exec_on_eval<4>(n >> lvl, leaf >> lvl, [&](size_t k, point rt) {
131+
k <<= lvl;
119132
vpoint v1 = {vz + real(rt), vz - imag(rt)};
120133
vpoint v2 = v1 * v1;
121134
vpoint v3 = v1 * v2;
@@ -124,21 +137,10 @@ namespace cp_algo::math::fft {
124137
auto B = at(j + i);
125138
auto C = at(j + 2 * i);
126139
auto D = at(j + 3 * i);
127-
at(j) = (A + B + C + D);
128-
at(j + 2 * i) = (A + B - C - D) * v2;
129-
at(j + i) = (A - B - vi(C - D)) * v1;
130-
at(j + 3 * i) = (A - B + vi(C - D)) * v3;
131-
}
132-
});
133-
i *= 2;
134-
} else { // radix-2 fallback
135-
exec_on_evals<2>(n / (2 * i), [&](size_t k, point rt) {
136-
k *= 2 * i;
137-
vpoint cvrt = {vz + real(rt), vz - imag(rt)};
138-
for(size_t j = k; j < k + i; j += flen) {
139-
auto B = at(j) - at(j + i);
140-
at(j) += at(j + i);
141-
at(j + i) = B * cvrt;
140+
at(j) = ((A + B) + (C + D));
141+
at(j + 2 * i) = ((A + B) - (C + D)) * v2;
142+
at(j + i) = ((A - B) - vi(C - D)) * v1;
143+
at(j + 3 * i) = ((A - B) + vi(C - D)) * v3;
142144
}
143145
});
144146
}
@@ -150,11 +152,14 @@ namespace cp_algo::math::fft {
150152
}
151153
void fft() {
152154
size_t n = size();
153-
for(size_t i = n / 2; i >= flen; i /= 2) {
154-
if (i / 2 >= flen) { // radix-4
155-
i /= 2;
156-
exec_on_evals<4>(n / (4 * i), [&](size_t k, point rt) {
157-
k *= 4 * i;
155+
bool parity = std::countr_zero(n) % 2;
156+
for(size_t leaf = 0; leaf < n; leaf += 4 * flen) {
157+
size_t level = std::countr_zero(n + leaf);
158+
level -= level % 2 != parity;
159+
for(size_t lvl = level; lvl >= 4; lvl -= 2) {
160+
size_t i = (1 << lvl) / 4;
161+
exec_on_eval<4>(n >> lvl, leaf >> lvl, [&](size_t k, point rt) {
162+
k <<= lvl;
158163
vpoint v1 = {vz + real(rt), vz + imag(rt)};
159164
vpoint v2 = v1 * v1;
160165
vpoint v3 = v1 * v2;
@@ -169,18 +174,17 @@ namespace cp_algo::math::fft {
169174
at(j + 3 * i) = (A - C) - vi(B - D);
170175
}
171176
});
172-
} else { // radix-2 fallback
173-
exec_on_evals<2>(n / (2 * i), [&](size_t k, point rt) {
174-
k *= 2 * i;
175-
vpoint vrt = {vz + real(rt), vz + imag(rt)};
176-
for(size_t j = k; j < k + i; j += flen) {
177-
auto t = at(j + i) * vrt;
178-
at(j + i) = at(j) - t;
179-
at(j) += t;
180-
}
181-
});
182177
}
183178
}
179+
if(parity) {
180+
exec_on_evals<2>(n / (2 * flen), [&](size_t k, point rt) {
181+
k *= 2 * flen;
182+
vpoint vrt = {vz + real(rt), vz + imag(rt)};
183+
auto t = at(k + flen) * vrt;
184+
at(k + flen) = at(k) - t;
185+
at(k) += t;
186+
});
187+
}
184188
checkpoint("fft");
185189
}
186190
static constexpr size_t pre_evals = 1 << 16;

cp-algo/math/fft.hpp

+9-8
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@ namespace cp_algo::math::fft {
1313
static base factor, ifactor;
1414
using Int2 = base::Int2;
1515
static bool _init;
16-
static int split;
16+
static int split() {
17+
static const int splt = int(std::sqrt(base::mod())) + 1;
18+
return splt;
19+
}
1720
static u64x4 mod, imod;
1821

1922
void init() {
2023
if(!_init) {
2124
factor = 1 + random::rng() % (base::mod() - 1);
22-
split = int(std::sqrt(base::mod())) + 1;
2325
ifactor = base(1) / factor;
2426
mod = u64x4() + base::mod();
2527
imod = u64x4() + inv2(-base::mod());
@@ -51,9 +53,9 @@ namespace cp_algo::math::fft {
5153
};
5254
au = montgomery_mul(au, mul, mod, imod);
5355
au = au >= base::mod() ? au - base::mod() : au;
54-
auto ai = i64x4(au);
55-
ai = ai >= base::mod() / 2 ? ai - base::mod() : ai;
56-
return std::pair{to_double(ai % split), to_double(ai / split)};
56+
auto ai = to_double(i64x4(au >= base::mod() / 2 ? au - base::mod() : au));
57+
auto quo = round(ai / split());
58+
return std::pair{ai - quo * split(), quo};
5759
};
5860
auto [rai, qai] = splt(i, cur);
5961
auto [rani, qani] = splt(n + i, montgomery_mul(cur, stepn, mod, imod));
@@ -101,7 +103,7 @@ namespace cp_algo::math::fft {
101103
void recover_mod(auto &&C, auto &res, size_t k) {
102104
res.assign((k / flen + 1) * flen, base(0));
103105
size_t n = A.size();
104-
auto splitsplit = base(split * split).getr();
106+
auto const splitsplit = base(split() * split()).getr();
105107
base b2x32 = bpow(base(2), 32);
106108
base b2x64 = bpow(base(2), 64);
107109
u64x4 cur = {
@@ -118,7 +120,7 @@ namespace cp_algo::math::fft {
118120
auto [Cx, Cy] = C.at(i);
119121
auto set_i = [&](size_t i, auto A, auto B, auto C, auto mul) {
120122
auto A0 = lround(A), A1 = lround(C), A2 = lround(B);
121-
auto Ai = A0 + A1 * split + A2 * splitsplit + uint64_t(base::modmod());
123+
auto Ai = A0 + A1 * split() + A2 * splitsplit + uint64_t(base::modmod());
122124
auto Au = montgomery_reduce(u64x4(Ai), mod, imod);
123125
Au = montgomery_mul(Au, mul, mod, imod);
124126
Au = Au >= base::mod() ? Au - base::mod() : Au;
@@ -174,7 +176,6 @@ namespace cp_algo::math::fft {
174176
template<modint_type base> base dft<base>::factor = 1;
175177
template<modint_type base> base dft<base>::ifactor = 1;
176178
template<modint_type base> bool dft<base>::_init = false;
177-
template<modint_type base> int dft<base>::split = 1;
178179
template<modint_type base> u64x4 dft<base>::mod = {};
179180
template<modint_type base> u64x4 dft<base>::imod = {};
180181

0 commit comments

Comments
 (0)