Skip to content

Commit ede08ee

Browse files
committed
Use Montgomery in mod recovery
1 parent c2eccc0 commit ede08ee

File tree

3 files changed

+44
-19
lines changed

3 files changed

+44
-19
lines changed

cp-algo/math/cvector.hpp

+11-2
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,22 @@ namespace cp_algo::math::fft {
1919
// Absolute value of a: yields -a where a compares below zero and a otherwise.
// NOTE(review): vftype is declared outside this hunk; if it is a GCC vector type
// the comparison and ?: are applied per lane — confirm.
vftype abs(vftype a) {
2020
return a < 0 ? -a : a;
2121
}
22-
using vint [[gnu::vector_size(flen * sizeof(int64_t))]] = int64_t;
22+
using i64x4 [[gnu::vector_size(bytes)]] = int64_t;
23+
using u64x4 [[gnu::vector_size(bytes)]] = uint64_t;
2324
auto lround(vftype a) {
24-
return __builtin_convertvector(a < 0 ? a - 0.5 : a + 0.5, vint);
25+
return __builtin_convertvector(a < 0 ? a - 0.5 : a + 0.5, i64x4);
2526
}
2627
// Rounds a by passing it through lround and converting the integer result back
// to vftype with __builtin_convertvector.
auto round(vftype a) {
2728
return __builtin_convertvector(lround(a), vftype);
2829
}
30+
// Montgomery-style reduction step applied to the 64-bit lanes of x (implied radix 2^32).
//   x    : values to reduce, one per 64-bit lane.
//   mod  : modulus, expected to be replicated across the lanes (only its low 32 bits are used).
//   imod : presumably -mod^{-1} modulo 2^32, replicated across the lanes — TODO confirm
//          against inv2, which is defined outside this view.
// Returns x + ((x * imod) mod 2^32) * mod moved down by 32 bits; no final conditional
// subtraction of mod is performed here, so callers must normalise the result themselves.
u64x4 montgomery_reduce(u64x4 x, u64x4 mod, u64x4 imod) {
31+
// _mm256_mul_epu32 multiplies the low 32 bits of each 64-bit lane into a 64-bit product.
auto x_ninv = _mm256_mul_epu32(__m256i(x), __m256i(imod));
32+
// Again only the low 32 bits of x_ninv and mod take part; the product is added to x lane by lane.
auto x_res = _mm256_add_epi64(__m256i(x), _mm256_mul_epu32(x_ninv, __m256i(mod)));
33+
// _mm256_bsrli_epi128 shifts each 128-bit half right by 4 bytes, so the low 32 bits of
// the upper 64-bit lane slide into the top of the lower lane.
// NOTE(review): this equals a per-lane `>> 32` only when the low 32 bits of every lane
// of x_res are zero, which should hold if imod is the negated inverse of mod — confirm.
return u64x4(_mm256_bsrli_epi128(x_res, 4));
34+
}
35+
// Lane-wise Montgomery product: multiplies the low 32 bits of x and y in each 64-bit
// lane (_mm256_mul_epu32) and feeds the 64-bit products to montgomery_reduce with the
// same mod / imod. The high 32 bits of the lanes of x and y are ignored by the multiply.
u64x4 montgomery_mul(u64x4 x, u64x4 y, u64x4 mod, u64x4 imod) {
36+
return montgomery_reduce(u64x4(_mm256_mul_epu32(__m256i(x), __m256i(y))), mod, imod);
37+
}
2938

3039
struct cvector {
3140
std::vector<vpoint, big_alloc<vpoint>> r;

cp-algo/math/fft.hpp

+32-16
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,15 @@ namespace cp_algo::math::fft {
1414
using Int2 = base::Int2;
1515
static bool _init;
1616
static int split;
17+
static u64x4 mod, imod;
1718

1819
// One-time setup of the static state used by the transform; the _init flag makes
// every call after the first a no-op.
void init() {
1920
if(!_init) {
2021
// 1 plus a random value reduced modulo (mod - 1), so factor is never 0 modulo mod
// (assuming random::rng() returns a non-negative value — confirm).
factor = 1 + random::rng() % (base::mod() - 1);
2122
// Truncated square root of the modulus, plus one.
split = int(std::sqrt(base::mod())) + 1;
2223
// Multiplicative inverse of factor in base.
ifactor = base(1) / factor;
24+
// Adding a scalar to a zero vector replicates the modulus into every lane.
mod = u64x4() + base::mod();
25+
// Same replication for inv2(-mod); inv2 is defined elsewhere — presumably an
// inverse modulo a power of two, as montgomery_reduce needs. TODO confirm.
imod = u64x4() + inv2(-base::mod());
2326
_init = true;
2427
}
2528
}
@@ -79,28 +82,38 @@ namespace cp_algo::math::fft {
7982
}
8083

8184
void recover_mod(auto &&C, auto &res, size_t k) {
85+
assert(size(res) % flen == 0);
8286
size_t n = A.size();
8387
auto splitsplit = base(split * split).getr();
84-
base stepn = bpow(ifactor, n);
85-
base cur[] = {bpow(ifactor, 2), bpow(ifactor, 3), bpow(ifactor, 4), bpow(ifactor, 5)};
86-
base step4 = cur[2];
88+
base b2x32 = bpow(base(2), 32);
89+
base b2x64 = bpow(base(2), 64);
90+
u64x4 cur = {
91+
(bpow(ifactor, 2) * b2x64).getr(),
92+
(bpow(ifactor, 3) * b2x64).getr(),
93+
(bpow(ifactor, 4) * b2x64).getr(),
94+
(bpow(ifactor, 5) * b2x64).getr()
95+
};
96+
u64x4 step4 = u64x4{} + (bpow(ifactor, 4) * b2x32).getr();
97+
u64x4 stepn = u64x4{} + (bpow(ifactor, n) * b2x32).getr();
8798
for(size_t i = 0; i < std::min(n, k); i += flen) {
8899
auto [Ax, Ay] = A.at(i);
89100
auto [Bx, By] = B.at(i);
90101
auto [Cx, Cy] = C.at(i);
91-
auto A0 = lround(Ax), A1 = lround(Cx), A2 = lround(Bx);
92-
auto B0 = lround(Ay), B1 = lround(Cy), B2 = lround(By);
93-
for(size_t j = 0; j < flen; j++) {
94-
if(i + j < k) {
95-
res[i + j] = A0[j] + A1[j] * split + A2[j] * splitsplit;
96-
res[i + j] *= cur[j];
102+
auto set_i = [&](size_t i, auto A, auto B, auto C, auto mul) {
103+
auto A0 = lround(A), A1 = lround(C), A2 = lround(B);
104+
auto Ai = A0 + A1 * split + A2 * splitsplit + base::modmod();
105+
auto Au = montgomery_reduce(u64x4(Ai), mod, imod);
106+
Au = montgomery_mul(Au, mul, mod, imod);
107+
Au = Au >= base::mod() ? Au - base::mod() : Au;
108+
for(size_t j = 0; j < flen; j++) {
109+
res[i + j].setr(Au[j]);
97110
}
98-
if(i + j + n < k) {
99-
res[i + j + n] = B0[j] + B1[j] * split + B2[j] * splitsplit;
100-
res[i + j + n] *= cur[j] * stepn;
101-
}
102-
cur[j] *= step4;
111+
};
112+
set_i(i, Ax, Bx, Cx, cur);
113+
if(i + n < k) {
114+
set_i(i + n, Ay, By, Cy, montgomery_mul(cur, stepn, mod, imod));
103115
}
116+
cur = montgomery_mul(cur, step4, mod, imod);
104117
}
105118
checkpoint("recover mod");
106119
}
@@ -144,6 +157,8 @@ namespace cp_algo::math::fft {
144157
template<modint_type base> base dft<base>::ifactor = 1;
145158
template<modint_type base> bool dft<base>::_init = false;
146159
template<modint_type base> int dft<base>::split = 1;
160+
template<modint_type base> u64x4 dft<base>::mod = {};
161+
template<modint_type base> u64x4 dft<base>::imod = {};
147162

148163
void mul_slow(auto &a, auto const& b, size_t k) {
149164
if(empty(a) || empty(b)) {
@@ -176,13 +191,13 @@ namespace cp_algo::math::fft {
176191
std::min(k, size(a)) + std::min(k, size(b)) - 1
177192
) / 2);
178193
auto A = dft<base>(a | std::views::take(k), n);
179-
a.assign(k, 0);
180-
checkpoint("reset a");
194+
a.assign((k / flen + 1) * flen, 0);
181195
if(&a == &b) {
182196
A.mul(A, a, k);
183197
} else {
184198
A.mul_inplace(dft<base>(b | std::views::take(k), n), a, k);
185199
}
200+
a.resize(k);
186201
}
187202
void mul(auto &a, auto const& b) {
188203
size_t N = size(a) + size(b) - 1;
@@ -213,6 +228,7 @@ namespace cp_algo::math::fft {
213228
a[i + n] += ai;
214229
}
215230
a.resize(N);
231+
checkpoint("karatsuba join");
216232
} else if(size(a)) {
217233
mul_truncate(a, b, N);
218234
}

cp-algo/number_theory/modint.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ namespace cp_algo::math {
6161
auto operator > (const modint &t) const {return to_modint().getr() > t.getr();}
6262
Int rem() const {
6363
UInt R = to_modint().getr();
64-
return 2 * R > (UInt)mod() ? R - mod() : R;
64+
return R - (R > (UInt)mod() / 2) * mod();
6565
}
6666
void setr(UInt rr) {
6767
r = rr;

0 commit comments

Comments
 (0)