Skip to content

Commit bb3dea3

Browse files
committed
Vectorize recover_mod (a bit)
1 parent b61487e commit bb3dea3

File tree

3 files changed

+51
-41
lines changed

3 files changed

+51
-41
lines changed

cp-algo/math/cvector.hpp

+10
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,16 @@ namespace cp_algo::math::fft {
1616
using vpoint = complex<vftype>;
1717
static constexpr vftype vz = {};
1818
static constexpr vpoint vi = {vz, vz + 1};
19+
vftype abs(vftype a) {
20+
return a < 0 ? -a : a;
21+
}
22+
using vint [[gnu::vector_size(flen * sizeof(int64_t))]] = int64_t;
23+
auto lround(vftype a) {
24+
return __builtin_convertvector(a < 0 ? a - 0.5 : a + 0.5, vint);
25+
}
26+
auto round(vftype a) {
27+
return __builtin_convertvector(lround(a), vftype);
28+
}
1929

2030
struct cvector {
2131
std::vector<vpoint, big_alloc<vpoint>> r;

cp-algo/math/fft.hpp

+39-29
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,27 @@ namespace cp_algo::math::fft {
1212
cvector A, B;
1313
static base factor, ifactor;
1414
using Int2 = base::Int2;
15-
static bool init;
15+
static bool _init;
1616
static int split;
1717

18-
dft(auto const& a, size_t n): A(n), B(n) {
19-
if(!init) {
18+
void init() {
19+
if(!_init) {
2020
factor = 1 + random::rng() % (base::mod() - 1);
2121
split = int(std::sqrt(base::mod())) + 1;
2222
ifactor = base(1) / factor;
23-
init = true;
23+
_init = true;
2424
}
25+
}
26+
27+
dft(auto const& a, size_t n): A(n), B(n) {
28+
init();
2529
base cur = factor;
2630
base step = bpow(factor, n);
2731
for(size_t i = 0; i < std::min(n, size(a)); i++) {
2832
auto splt = [&](size_t i, auto mul) {
29-
auto ai = i < size(a) ? (a[i] * mul).rem() : 0;
30-
auto rem = ai % split;
33+
auto ai = i < size(a) ? (a[i] * mul).getr() : 0;
3134
auto quo = ai / split;
35+
auto rem = ai % split;
3236
return std::pair{(ftype)rem, (ftype)quo};
3337
};
3438
auto [rai, qai] = splt(i, cur);
@@ -74,6 +78,33 @@ namespace cp_algo::math::fft {
7478
checkpoint("dot");
7579
}
7680

81+
void recover_mod(auto &&C, auto &res, size_t k) {
82+
size_t n = A.size();
83+
auto splitsplit = base(split * split).getr();
84+
base stepn = bpow(ifactor, n);
85+
base cur[] = {bpow(ifactor, 2), bpow(ifactor, 3), bpow(ifactor, 4), bpow(ifactor, 5)};
86+
base step4 = cur[2];
87+
for(size_t i = 0; i < std::min(n, k); i += flen) {
88+
auto [Ax, Ay] = A.at(i);
89+
auto [Bx, By] = B.at(i);
90+
auto [Cx, Cy] = C.at(i);
91+
auto A0 = lround(Ax), A1 = lround(Cx), A2 = lround(Bx);
92+
auto B0 = lround(Ay), B1 = lround(Cy), B2 = lround(By);
93+
for(size_t j = 0; j < flen; j++) {
94+
if(i + j < k) {
95+
res[i + j] = A0[j] + A1[j] * split + A2[j] * splitsplit;
96+
res[i + j] *= cur[j];
97+
}
98+
if(i + j + n < k) {
99+
res[i + j + n] = B0[j] + B1[j] * split + B2[j] * splitsplit;
100+
res[i + j + n] *= cur[j] * stepn;
101+
}
102+
cur[j] *= step4;
103+
}
104+
}
105+
checkpoint("recover mod");
106+
}
107+
77108
void mul(auto &&C, auto const& D, auto &res, size_t k) {
78109
assert(A.size() == C.size());
79110
size_t n = A.size();
@@ -85,28 +116,7 @@ namespace cp_algo::math::fft {
85116
A.ifft();
86117
B.ifft();
87118
C.ifft();
88-
auto splitsplit = (base(split) * split).rem();
89-
base cur = ifactor * ifactor;
90-
base step = bpow(ifactor, n);
91-
for(size_t i = 0; i < std::min(n, k); i++) {
92-
auto [Ax, Ay] = A.get(i);
93-
auto [Bx, By] = B.get(i);
94-
auto [Cx, Cy] = C.get(i);
95-
Int2 A0 = llround(Ax);
96-
Int2 A1 = llround(Cx);
97-
Int2 A2 = llround(Bx);
98-
res[i] = A0 + A1 * split + A2 * splitsplit;
99-
res[i] *= cur;
100-
if(n + i < k) {
101-
Int2 B0 = llround(Ay);
102-
Int2 B1 = llround(Cy);
103-
Int2 B2 = llround(By);
104-
res[n + i] = B0 + B1 * split + B2 * splitsplit;
105-
res[n + i] *= cur * step;
106-
}
107-
cur *= ifactor;
108-
}
109-
checkpoint("recover mod");
119+
recover_mod(C, res, k);
110120
}
111121
void mul_inplace(auto &&B, auto& res, size_t k) {
112122
mul(B.A, B.B, res, k);
@@ -132,7 +142,7 @@ namespace cp_algo::math::fft {
132142
};
133143
template<modint_type base> base dft<base>::factor = 1;
134144
template<modint_type base> base dft<base>::ifactor = 1;
135-
template<modint_type base> bool dft<base>::init = false;
145+
template<modint_type base> bool dft<base>::_init = false;
136146
template<modint_type base> int dft<base>::split = 1;
137147

138148
void mul_slow(auto &a, auto const& b, size_t k) {

verify/poly/wildcard.test.cpp

+2-12
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,10 @@ void semicorr(auto &a, auto &b) {
2121
a.ifft();
2222
}
2323

24-
vftype abs(vftype a) {
25-
return a < 0 ? -a : a;
26-
}
27-
28-
using v4di [[gnu::vector_size(32)]] = long;
29-
30-
auto round(vftype a) {
31-
return __builtin_convertvector(__builtin_convertvector(a < 0 ? a - 0.5 : a + 0.5, v4di), vftype);
32-
}
33-
3424
auto is_integer(auto a) {
3525
static const double eps = 1e-8;
36-
return abs(imag(a)) < eps
37-
&& abs(real(a) - round(real(a))) < eps;
26+
return fft::abs(imag(a)) < eps
27+
&& fft::abs(real(a) - fft::round(real(a))) < eps;
3828
}
3929

4030
string matches(string const& A, string const& B, char wild = '*') {

0 commit comments

Comments
 (0)