@@ -14,12 +14,15 @@ namespace cp_algo::math::fft {
14
14
// Shorthand for the Int2 type exposed by the modular-integer type `base`.
using Int2 = base::Int2;
15
15
// Guard flag: init() runs its body only while this is false and sets it to
// true afterwards.
static bool _init;
16
16
// Set by init() to int(sqrt(base::mod())) + 1. recover_mod() uses it as the
// radix when recombining three rounded parts as A0 + A1 * split + A2 * split^2.
static int split;
17
// Vector constants filled in by init(): `mod` from base::mod() and `imod` from
// inv2(-base::mod()). Both are passed to montgomery_reduce / montgomery_mul
// in recover_mod().
+ static u64x4 mod, imod;
17
18
18
19
// One-time setup of the shared static constants (factor, split, ifactor, mod,
// imod). The body is skipped on every call after the first, because _init is
// set to true at the end of the first run.
void init () {
19
20
if (!_init) {
20
21
// Random value of the form 1 + (rng() % (mod - 1)); assuming rng() returns an
// unsigned integer, this lies in [1, mod - 1] and is therefore non-zero.
factor = 1 + random ::rng () % (base::mod () - 1 );
21
22
// Truncated square root of the modulus, plus one.
split = int (std::sqrt (base::mod ())) + 1 ;
22
23
// Modular inverse of factor (1 divided by factor in `base` arithmetic).
ifactor = base (1 ) / factor;
24
// Adds the scalar modulus to a value-initialized u64x4, i.e. a broadcast of
// base::mod() (assuming u64x4's operator+ works lane-wise with a scalar).
+ mod = u64x4 () + base::mod ();
25
// NOTE(review): inv2 is defined elsewhere; presumably it yields an inverse
// modulo a power of two as needed by Montgomery reduction — confirm.
+ imod = u64x4 () + inv2 (-base::mod ());
23
26
_init = true ;
24
27
}
25
28
}
@@ -79,28 +82,38 @@ namespace cp_algo::math::fft {
79
82
}
80
83
81
84
// Fills `res` with up to k output values recovered from the transforms A, B
// (both declared outside this function) and C.
//
// For each group of flen positions the components read from A, B and C are
// rounded with lround, recombined as A0 + A1 * split + A2 * splitsplit,
// multiplied by a running power of ifactor and stored into `res`.
//
// In this diff rendering, lines prefixed with '-' are the previous scalar
// implementation and lines prefixed with '+' are its vectorized replacement,
// which processes a whole u64x4 per call to set_i.
//
// Parameters:
//   C   - object whose at(i) yields a pair, read as (Cx, Cy).
//   res - output container; its size must be a multiple of flen (asserted).
//         Elements are written through setr().
//   k   - number of outputs requested; bounds the main loop (together with n)
//         and decides whether the block at offset n is written.
void recover_mod (auto &&C, auto &res, size_t k) {
85
// set_i below always writes flen consecutive elements with no per-element
// check against k, so res is required to consist of whole groups of flen.
+ assert (size (res) % flen == 0 );
82
86
size_t n = A.size ();
83
87
// NOTE(review): `split` is declared as int, so `split * split` is evaluated in
// int before the conversion to base; for moduli close to 2^31 this product may
// overflow — confirm the supported modulus range.
auto splitsplit = base (split * split).getr ();
84
- base stepn = bpow (ifactor, n);
85
- base cur[] = {bpow (ifactor, 2 ), bpow (ifactor, 3 ), bpow (ifactor, 4 ), bpow (ifactor, 5 )};
86
- base step4 = cur[2 ];
88
// 2^32 and 2^64 as values of `base`; they are folded into the multipliers
// below. NOTE(review): presumably these compensate for the scaling applied by
// montgomery_reduce / montgomery_mul — confirm against their definitions.
+ base b2x32 = bpow (base (2 ), 32 );
89
+ base b2x64 = bpow (base (2 ), 64 );
90
// Running multipliers: the four lanes start at ifactor^2 .. ifactor^5, each
// scaled by 2^64.
+ u64x4 cur = {
91
+ (bpow (ifactor, 2 ) * b2x64).getr (),
92
+ (bpow (ifactor, 3 ) * b2x64).getr (),
93
+ (bpow (ifactor, 4 ) * b2x64).getr (),
94
+ (bpow (ifactor, 5 ) * b2x64).getr ()
95
+ };
96
// step4 (ifactor^4 scaled by 2^32) advances `cur` once per loop iteration;
// stepn (ifactor^n scaled by 2^32) derives the multiplier for the block at
// offset n from `cur`.
+ u64x4 step4 = u64x4{} + (bpow (ifactor, 4 ) * b2x32).getr ();
97
+ u64x4 stepn = u64x4{} + (bpow (ifactor, n) * b2x32).getr ();
87
98
for (size_t i = 0 ; i < std::min (n, k); i += flen) {
88
99
// Each at(i) yields a pair; the first components feed the block at index i,
// the second components feed the block at index i + n.
auto [Ax, Ay] = A.at (i);
89
100
auto [Bx, By] = B.at (i);
90
101
auto [Cx, Cy] = C.at (i);
91
- auto A0 = lround (Ax), A1 = lround (Cx), A2 = lround (Bx);
92
- auto B0 = lround (Ay), B1 = lround (Cy), B2 = lround (By);
93
- for (size_t j = 0 ; j < flen; j++) {
94
- if (i + j < k) {
95
- res[i + j] = A0[j] + A1[j] * split + A2[j] * splitsplit;
96
- res[i + j] *= cur[j];
102
// Helper that computes and stores one group of flen results starting at index
// i. Its parameters i, A, B and C shadow the outer names of the same spelling.
+ auto set_i = [&](size_t i, auto A, auto B, auto C, auto mul) {
103
// Note the ordering: the middle term (times split) comes from C and the high
// term (times splitsplit) comes from B.
+ auto A0 = lround (A), A1 = lround (C), A2 = lround (B);
104
// NOTE(review): adding base::modmod() presumably keeps every lane non-negative
// before the conversion to u64x4 — confirm.
+ auto Ai = A0 + A1 * split + A2 * splitsplit + base::modmod ();
105
+ auto Au = montgomery_reduce (u64x4 (Ai), mod, imod);
106
+ Au = montgomery_mul (Au, mul, mod, imod);
107
// Conditional subtraction: lanes that are >= base::mod() are reduced by
// base::mod(), the others are left unchanged.
+ Au = Au >= base::mod () ? Au - base::mod () : Au;
108
+ for (size_t j = 0 ; j < flen; j++) {
109
+ res[i + j].setr (Au[j]);
97
110
}
98
- if (i + j + n < k) {
99
- res[i + j + n] = B0[j] + B1[j] * split + B2[j] * splitsplit;
100
- res[i + j + n] *= cur[j] * stepn;
101
- }
102
- cur[j] *= step4;
111
+ };
112
+ set_i (i, Ax, Bx, Cx, cur);
113
// The block at offset n is produced only when its first index is below k.
+ if (i + n < k) {
114
+ set_i (i + n, Ay, By, Cy, montgomery_mul (cur, stepn, mod, imod));
103
115
}
116
// Advance the running multipliers for the next group.
+ cur = montgomery_mul (cur, step4, mod, imod);
104
117
}
105
118
checkpoint (" recover mod" );
106
119
}
@@ -144,6 +157,8 @@ namespace cp_algo::math::fft {
144
157
// Out-of-class definitions of the static data members of dft<base>, with their
// initial values. These values are replaced by dft<base>::init() the first
// time it runs (it assigns ifactor, split, mod, imod and finally _init).
template <modint_type base> base dft<base>::ifactor = 1 ;
145
158
template <modint_type base> bool dft<base>::_init = false ;
146
159
template <modint_type base> int dft<base>::split = 1 ;
160
// Value-initialized vector constants; init() fills them from base::mod().
+ template <modint_type base> u64x4 dft<base>::mod = {};
161
+ template <modint_type base> u64x4 dft<base>::imod = {};
147
162
148
163
void mul_slow (auto &a, auto const & b, size_t k) {
149
164
if (empty (a) || empty (b)) {
@@ -176,13 +191,13 @@ namespace cp_algo::math::fft {
176
191
std::min (k, size (a)) + std::min (k, size (b)) - 1
177
192
) / 2 );
178
193
auto A = dft<base>(a | std::views::take (k), n);
179
- a.assign (k, 0 );
180
- checkpoint (" reset a" );
194
+ a.assign ((k / flen + 1 ) * flen, 0 );
181
195
if (&a == &b) {
182
196
A.mul (A, a, k);
183
197
} else {
184
198
A.mul_inplace (dft<base>(b | std::views::take (k), n), a, k);
185
199
}
200
+ a.resize (k);
186
201
}
187
202
void mul (auto &a, auto const & b) {
188
203
size_t N = size (a) + size (b) - 1 ;
@@ -213,6 +228,7 @@ namespace cp_algo::math::fft {
213
228
a[i + n] += ai;
214
229
}
215
230
a.resize (N);
231
+ checkpoint (" karatsuba join" );
216
232
} else if (size (a)) {
217
233
mul_truncate (a, b, N);
218
234
}
0 commit comments