@@ -112,10 +112,23 @@ namespace cp_algo::math::fft {
112
112
113
113
void ifft () {
114
114
size_t n = size ();
115
- for (size_t i = flen; i <= n / 2 ; i *= 2 ) {
116
- if (4 * i <= n) { // radix-4
117
- exec_on_evals<4 >(n / (4 * i), [&](size_t k, point rt) {
118
- k *= 4 * i;
115
+ bool parity = std::countr_zero (n) % 2 ;
116
+ if (parity) {
117
+ exec_on_evals<2 >(n / (2 * flen), [&](size_t k, point rt) {
118
+ k *= 2 * flen;
119
+ vpoint cvrt = {vz + real (rt), vz - imag (rt)};
120
+ auto B = at (k) - at (k + flen);
121
+ at (k) += at (k + flen);
122
+ at (k + flen) = B * cvrt;
123
+ });
124
+ }
125
+
126
+ for (size_t leaf = 3 * flen; leaf < n; leaf += 4 * flen) {
127
+ size_t level = std::countr_one (leaf + 3 );
128
+ for (size_t lvl = 4 + parity; lvl <= level; lvl += 2 ) {
129
+ size_t i = (1 << lvl) / 4 ;
130
+ exec_on_eval<4 >(n >> lvl, leaf >> lvl, [&](size_t k, point rt) {
131
+ k <<= lvl;
119
132
vpoint v1 = {vz + real (rt), vz - imag (rt)};
120
133
vpoint v2 = v1 * v1;
121
134
vpoint v3 = v1 * v2;
@@ -124,21 +137,10 @@ namespace cp_algo::math::fft {
124
137
auto B = at (j + i);
125
138
auto C = at (j + 2 * i);
126
139
auto D = at (j + 3 * i);
127
- at (j) = (A + B + C + D);
128
- at (j + 2 * i) = (A + B - C - D) * v2;
129
- at (j + i) = (A - B - vi (C - D)) * v1;
130
- at (j + 3 * i) = (A - B + vi (C - D)) * v3;
131
- }
132
- });
133
- i *= 2 ;
134
- } else { // radix-2 fallback
135
- exec_on_evals<2 >(n / (2 * i), [&](size_t k, point rt) {
136
- k *= 2 * i;
137
- vpoint cvrt = {vz + real (rt), vz - imag (rt)};
138
- for (size_t j = k; j < k + i; j += flen) {
139
- auto B = at (j) - at (j + i);
140
- at (j) += at (j + i);
141
- at (j + i) = B * cvrt;
140
+ at (j) = ((A + B) + (C + D));
141
+ at (j + 2 * i) = ((A + B) - (C + D)) * v2;
142
+ at (j + i) = ((A - B) - vi (C - D)) * v1;
143
+ at (j + 3 * i) = ((A - B) + vi (C - D)) * v3;
142
144
}
143
145
});
144
146
}
@@ -150,11 +152,14 @@ namespace cp_algo::math::fft {
150
152
}
151
153
void fft () {
152
154
size_t n = size ();
153
- for (size_t i = n / 2 ; i >= flen; i /= 2 ) {
154
- if (i / 2 >= flen) { // radix-4
155
- i /= 2 ;
156
- exec_on_evals<4 >(n / (4 * i), [&](size_t k, point rt) {
157
- k *= 4 * i;
155
+ bool parity = std::countr_zero (n) % 2 ;
156
+ for (size_t leaf = 0 ; leaf < n; leaf += 4 * flen) {
157
+ size_t level = std::countr_zero (n + leaf);
158
+ level -= level % 2 != parity;
159
+ for (size_t lvl = level; lvl >= 4 ; lvl -= 2 ) {
160
+ size_t i = (1 << lvl) / 4 ;
161
+ exec_on_eval<4 >(n >> lvl, leaf >> lvl, [&](size_t k, point rt) {
162
+ k <<= lvl;
158
163
vpoint v1 = {vz + real (rt), vz + imag (rt)};
159
164
vpoint v2 = v1 * v1;
160
165
vpoint v3 = v1 * v2;
@@ -169,18 +174,17 @@ namespace cp_algo::math::fft {
169
174
at (j + 3 * i) = (A - C) - vi (B - D);
170
175
}
171
176
});
172
- } else { // radix-2 fallback
173
- exec_on_evals<2 >(n / (2 * i), [&](size_t k, point rt) {
174
- k *= 2 * i;
175
- vpoint vrt = {vz + real (rt), vz + imag (rt)};
176
- for (size_t j = k; j < k + i; j += flen) {
177
- auto t = at (j + i) * vrt;
178
- at (j + i) = at (j) - t;
179
- at (j) += t;
180
- }
181
- });
182
177
}
183
178
}
179
+ if (parity) {
180
+ exec_on_evals<2 >(n / (2 * flen), [&](size_t k, point rt) {
181
+ k *= 2 * flen;
182
+ vpoint vrt = {vz + real (rt), vz + imag (rt)};
183
+ auto t = at (k + flen) * vrt;
184
+ at (k + flen) = at (k) - t;
185
+ at (k) += t;
186
+ });
187
+ }
184
188
checkpoint (" fft" );
185
189
}
186
190
static constexpr size_t pre_evals = 1 << 16 ;
0 commit comments