@@ -12,23 +12,27 @@ namespace cp_algo::math::fft {
12
12
cvector A, B;
13
13
static base factor, ifactor;
14
14
using Int2 = base::Int2;
15
- static bool init ;
15
+ static bool _init ;
16
16
static int split;
17
17
18
- dft ( auto const & a, size_t n): A(n), B(n ) {
19
- if (!init ) {
18
+ void init ( ) {
19
+ if (!_init ) {
20
20
factor = 1 + random ::rng () % (base::mod () - 1 );
21
21
split = int (std::sqrt (base::mod ())) + 1 ;
22
22
ifactor = base (1 ) / factor;
23
- init = true ;
23
+ _init = true ;
24
24
}
25
+ }
26
+
27
+ dft (auto const & a, size_t n): A(n), B(n) {
28
+ init ();
25
29
base cur = factor;
26
30
base step = bpow (factor, n);
27
31
for (size_t i = 0 ; i < std::min (n, size (a)); i++) {
28
32
auto splt = [&](size_t i, auto mul) {
29
- auto ai = i < size (a) ? (a[i] * mul).rem () : 0 ;
30
- auto rem = ai % split;
33
+ auto ai = i < size (a) ? (a[i] * mul).getr () : 0 ;
31
34
auto quo = ai / split;
35
+ auto rem = ai % split;
32
36
return std::pair{(ftype)rem, (ftype)quo};
33
37
};
34
38
auto [rai, qai] = splt (i, cur);
@@ -74,6 +78,33 @@ namespace cp_algo::math::fft {
74
78
checkpoint (" dot" );
75
79
}
76
80
81
+ void recover_mod (auto &&C, auto &res, size_t k) {
82
+ size_t n = A.size ();
83
+ auto splitsplit = base (split * split).getr ();
84
+ base stepn = bpow (ifactor, n);
85
+ base cur[] = {bpow (ifactor, 2 ), bpow (ifactor, 3 ), bpow (ifactor, 4 ), bpow (ifactor, 5 )};
86
+ base step4 = cur[2 ];
87
+ for (size_t i = 0 ; i < std::min (n, k); i += flen) {
88
+ auto [Ax, Ay] = A.at (i);
89
+ auto [Bx, By] = B.at (i);
90
+ auto [Cx, Cy] = C.at (i);
91
+ auto A0 = lround (Ax), A1 = lround (Cx), A2 = lround (Bx);
92
+ auto B0 = lround (Ay), B1 = lround (Cy), B2 = lround (By);
93
+ for (size_t j = 0 ; j < flen; j++) {
94
+ if (i + j < k) {
95
+ res[i + j] = A0[j] + A1[j] * split + A2[j] * splitsplit;
96
+ res[i + j] *= cur[j];
97
+ }
98
+ if (i + j + n < k) {
99
+ res[i + j + n] = B0[j] + B1[j] * split + B2[j] * splitsplit;
100
+ res[i + j + n] *= cur[j] * stepn;
101
+ }
102
+ cur[j] *= step4;
103
+ }
104
+ }
105
+ checkpoint (" recover mod" );
106
+ }
107
+
77
108
void mul (auto &&C, auto const & D, auto &res, size_t k) {
78
109
assert (A.size () == C.size ());
79
110
size_t n = A.size ();
@@ -85,28 +116,7 @@ namespace cp_algo::math::fft {
85
116
A.ifft ();
86
117
B.ifft ();
87
118
C.ifft ();
88
- auto splitsplit = (base (split) * split).rem ();
89
- base cur = ifactor * ifactor;
90
- base step = bpow (ifactor, n);
91
- for (size_t i = 0 ; i < std::min (n, k); i++) {
92
- auto [Ax, Ay] = A.get (i);
93
- auto [Bx, By] = B.get (i);
94
- auto [Cx, Cy] = C.get (i);
95
- Int2 A0 = llround (Ax);
96
- Int2 A1 = llround (Cx);
97
- Int2 A2 = llround (Bx);
98
- res[i] = A0 + A1 * split + A2 * splitsplit;
99
- res[i] *= cur;
100
- if (n + i < k) {
101
- Int2 B0 = llround (Ay);
102
- Int2 B1 = llround (Cy);
103
- Int2 B2 = llround (By);
104
- res[n + i] = B0 + B1 * split + B2 * splitsplit;
105
- res[n + i] *= cur * step;
106
- }
107
- cur *= ifactor;
108
- }
109
- checkpoint (" recover mod" );
119
+ recover_mod (C, res, k);
110
120
}
111
121
void mul_inplace (auto &&B, auto & res, size_t k) {
112
122
mul (B.A , B.B , res, k);
@@ -132,7 +142,7 @@ namespace cp_algo::math::fft {
132
142
};
133
143
template <modint_type base> base dft<base>::factor = 1 ;
134
144
template <modint_type base> base dft<base>::ifactor = 1 ;
135
- template <modint_type base> bool dft<base>::init = false ;
145
+ template <modint_type base> bool dft<base>::_init = false ;
136
146
template <modint_type base> int dft<base>::split = 1 ;
137
147
138
148
void mul_slow (auto &a, auto const & b, size_t k) {
0 commit comments