Skip to content

Commit 069d24d

Browse files
Merge pull request #174 from GiacomoPope/improve_powmod_and_compose
add compose_mod and powmod with large exp
2 parents 1178e20 + e3c9d03 commit 069d24d

File tree

4 files changed

+332
-21
lines changed

4 files changed

+332
-21
lines changed

Diff for: src/flint/flintlib/nmod_poly.pxd

+1
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ cdef extern from "flint/nmod_poly.h":
5555
int nmod_poly_equal_trunc(const nmod_poly_t poly1, const nmod_poly_t poly2, slong n)
5656
int nmod_poly_is_zero(const nmod_poly_t poly)
5757
int nmod_poly_is_one(const nmod_poly_t poly)
58+
int nmod_poly_is_gen(const nmod_poly_t poly)
5859
void _nmod_poly_shift_left(mp_ptr res, mp_srcptr poly, slong len, slong k)
5960
void nmod_poly_shift_left(nmod_poly_t res, const nmod_poly_t poly, slong k)
6061
void _nmod_poly_shift_right(mp_ptr res, mp_srcptr poly, slong len, slong k)

Diff for: src/flint/test/test_all.py

+39-8
Original file line numberDiff line numberDiff line change
@@ -1422,7 +1422,15 @@ def test_nmod_poly():
14221422
assert raises(lambda: [] * s, TypeError)
14231423
assert raises(lambda: [] // s, TypeError)
14241424
assert raises(lambda: [] % s, TypeError)
1425-
assert raises(lambda: pow(P([1,2],3), 3, 4), NotImplementedError)
1425+
assert raises(lambda: [] % s, TypeError)
1426+
assert raises(lambda: s.reverse(-1), ValueError)
1427+
assert raises(lambda: s.compose("A"), TypeError)
1428+
assert raises(lambda: s.compose_mod(s, "A"), TypeError)
1429+
assert raises(lambda: s.compose_mod("A", P([3,6,9],17)), TypeError)
1430+
assert raises(lambda: s.compose_mod(s, P([0], 17)), ZeroDivisionError)
1431+
assert raises(lambda: pow(s, -1, P([3,6,9],17)), ValueError)
1432+
assert raises(lambda: pow(s, 1, "A"), TypeError)
1433+
assert raises(lambda: pow(s, "A", P([3,6,9],17)), TypeError)
14261434
assert str(P([1,2,3],17)) == "3*x^2 + 2*x + 1"
14271435
assert P([1,2,3],17).repr() == "nmod_poly([1, 2, 3], 17)"
14281436
p = P([3,4,5],17)
@@ -2087,6 +2095,18 @@ def test_fmpz_mod_poly():
20872095
assert f*f == f**2
20882096
assert f*f == f**fmpz(2)
20892097

2098+
# pow_mod
2099+
# assert ui and fmpz exp agree for polynomials and generators
2100+
R_gen = R_test.gen()
2101+
assert pow(f, 2**60, g) == pow(pow(f, 2**30, g), 2**30, g)
2102+
assert pow(R_gen, 2**60, g) == pow(pow(R_gen, 2**30, g), 2**30, g)
2103+
2104+
# Check other typechecks for pow_mod
2105+
assert raises(lambda: pow(f, -2, g), ValueError)
2106+
assert raises(lambda: pow(f, 1, "A"), TypeError)
2107+
assert raises(lambda: pow(f, "A", g), TypeError)
2108+
assert raises(lambda: f.pow_mod(2**32, g, mod_rev_inv="A"), TypeError)
2109+
20902110
# Shifts
20912111
assert raises(lambda: R_test([1,2,3]).left_shift(-1), ValueError)
20922112
assert raises(lambda: R_test([1,2,3]).right_shift(-1), ValueError)
@@ -2118,6 +2138,13 @@ def test_fmpz_mod_poly():
21182138
# compose
21192139
assert raises(lambda: h.compose("AAA"), TypeError)
21202140

2141+
# compose mod
2142+
mod = R_test([1,2,3,4])
2143+
assert f.compose(h) % mod == f.compose_mod(h, mod)
2144+
assert raises(lambda: h.compose_mod("AAA", mod), TypeError)
2145+
assert raises(lambda: h.compose_mod(f, "AAA"), TypeError)
2146+
assert raises(lambda: h.compose_mod(f, R_test(0)), ZeroDivisionError)
2147+
21212148
# Reverse
21222149
assert raises(lambda: h.reverse(degree=-100), ValueError)
21232150
assert R_test([-1,-2,-3]).reverse() == R_test([-3,-2,-1])
@@ -2135,9 +2162,9 @@ def test_fmpz_mod_poly():
21352162
assert raises(lambda: f.mulmod(f, "AAA"), TypeError)
21362163
assert raises(lambda: f.mulmod("AAA", g), TypeError)
21372164

2138-
# powmod
2139-
assert f.powmod(2, g) == (f*f) % g
2140-
assert raises(lambda: f.powmod(2, "AAA"), TypeError)
2165+
# pow_mod
2166+
assert f.pow_mod(2, g) == (f*f) % g
2167+
assert raises(lambda: f.pow_mod(2, "AAA"), TypeError)
21412168

21422169
# divmod
21432170
S, T = f.divmod(g)
@@ -2635,9 +2662,14 @@ def setbad(obj, i, val):
26352662
assert P([1, 1]) ** 2 == P([1, 2, 1])
26362663
assert raises(lambda: P([1, 1]) ** -1, ValueError)
26372664
assert raises(lambda: P([1, 1]) ** None, TypeError)
2638-
2639-
# # XXX: Not sure what this should do in general:
2640-
assert raises(lambda: pow(P([1, 1]), 2, 3), NotImplementedError)
2665+
2666+
# XXX: Not sure what this should do in general:
2667+
p = P([1, 1])
2668+
mod = P([1, 1])
2669+
if type(p) not in [flint.fmpz_mod_poly, flint.nmod_poly]:
2670+
assert raises(lambda: pow(p, 2, mod), NotImplementedError)
2671+
else:
2672+
assert p * p % mod == pow(p, 2, mod)
26412673

26422674
assert P([1, 2, 1]).gcd(P([1, 1])) == P([1, 1])
26432675
assert raises(lambda: P([1, 2, 1]).gcd(None), TypeError)
@@ -2667,7 +2699,6 @@ def setbad(obj, i, val):
26672699
if is_field:
26682700
assert P([1, 2, 1]).integral() == P([0, 1, 1, S(1)/3])
26692701

2670-
26712702
def _all_mpolys():
26722703
return [
26732704
(flint.fmpz_mpoly, flint.fmpz_mpoly_ctx, flint.fmpz, False),

Diff for: src/flint/types/fmpz_mod_poly.pyx

+86-11
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,7 @@ cdef class fmpz_mod_poly(flint_poly):
536536

537537
def __pow__(self, e, mod=None):
538538
if mod is not None:
539-
raise NotImplementedError
539+
return self.pow_mod(e, mod)
540540

541541
cdef fmpz_mod_poly res
542542
if e < 0:
@@ -778,11 +778,11 @@ cdef class fmpz_mod_poly(flint_poly):
778778

779779
return evaluations
780780

781-
def compose(self, input):
781+
def compose(self, other):
782782
"""
783783
Returns the composition of two polynomials
784784
785-
To be precise about the order of composition, given ``self``, and ``input``
785+
To be precise about the order of composition, given ``self``, and ``other``
786786
by `f(x)`, `g(x)`, returns `f(g(x))`.
787787
788788
>>> R = fmpz_mod_poly_ctx(163)
@@ -794,12 +794,45 @@ cdef class fmpz_mod_poly(flint_poly):
794794
9*x^4 + 12*x^3 + 10*x^2 + 4*x + 1
795795
"""
796796
cdef fmpz_mod_poly res
797-
val = self.ctx.any_as_fmpz_mod_poly(input)
797+
val = self.ctx.any_as_fmpz_mod_poly(other)
798798
if val is NotImplemented:
799-
raise TypeError(f"Cannot compose the polynomial with input: {input}")
799+
raise TypeError(f"Cannot compose the polynomial with input: {other}")
800800

801801
res = self.ctx.new_ctype_poly()
802802
fmpz_mod_poly_compose(res.val, self.val, (<fmpz_mod_poly>val).val, self.ctx.mod.val)
803+
return res
804+
805+
def compose_mod(self, other, modulus):
806+
"""
807+
Returns the composition of two polynomials modulo a third.
808+
809+
To be precise about the order of composition, given ``self``, and ``other``
810+
and ``modulus`` by `f(x)`, `g(x)` and `h(x)`, returns `f(g(x)) \mod h(x)`.
811+
We require that `h(x)` is non-zero.
812+
813+
>>> R = fmpz_mod_poly_ctx(163)
814+
>>> f = R([1,2,3,4,5])
815+
>>> g = R([3,2,1])
816+
>>> h = R([1,0,1,0,1])
817+
>>> f.compose_mod(g, h)
818+
63*x^3 + 100*x^2 + 17*x + 63
819+
>>> g.compose_mod(f, h)
820+
147*x^3 + 159*x^2 + 4*x + 7
821+
"""
822+
cdef fmpz_mod_poly res
823+
val = self.ctx.any_as_fmpz_mod_poly(other)
824+
if val is NotImplemented:
825+
raise TypeError(f"cannot compose the polynomial with input: {other}")
826+
827+
h = self.ctx.any_as_fmpz_mod_poly(modulus)
828+
if h is NotImplemented:
829+
raise TypeError(f"cannot reduce the polynomial with input: {modulus}")
830+
831+
if h.is_zero():
832+
raise ZeroDivisionError("cannot reduce modulo zero")
833+
834+
res = self.ctx.new_ctype_poly()
835+
fmpz_mod_poly_compose_mod(res.val, self.val, (<fmpz_mod_poly>val).val, (<fmpz_mod_poly>h).val, self.ctx.mod.val)
803836
return res
804837

805838
cpdef long length(self):
@@ -1104,30 +1137,72 @@ cdef class fmpz_mod_poly(flint_poly):
11041137
)
11051138
return res
11061139

1107-
def powmod(self, e, modulus):
1140+
def pow_mod(self, e, modulus, mod_rev_inv=None):
11081141
"""
11091142
Returns ``self`` raised to the power ``e`` modulo ``modulus``:
1110-
:math:`f^e \mod g`
1143+
:math:`f^e \mod g`/
1144+
1145+
``mod_rev_inv`` is the inverse of the reverse of the modulus,
1146+
precomputing it and passing it to ``pow_mod()`` can optimise
1147+
powering of polynomials with large exponents.
11111148
11121149
>>> R = fmpz_mod_poly_ctx(163)
11131150
>>> x = R.gen()
11141151
>>> f = 30*x**6 + 104*x**5 + 76*x**4 + 33*x**3 + 70*x**2 + 44*x + 65
11151152
>>> g = 43*x**6 + 91*x**5 + 77*x**4 + 113*x**3 + 71*x**2 + 132*x + 60
11161153
>>> mod = x**4 + 93*x**3 + 78*x**2 + 72*x + 149
11171154
>>>
1118-
>>> f.powmod(123, mod)
1155+
>>> f.pow_mod(123, mod)
11191156
3*x^3 + 25*x^2 + 115*x + 161
1157+
>>> f.pow_mod(2**64, mod)
1158+
52*x^3 + 96*x^2 + 136*x + 9
1159+
>>> mod_rev_inv = mod.reverse().inverse_series_trunc(4)
1160+
>>> f.pow_mod(2**64, mod, mod_rev_inv)
1161+
52*x^3 + 96*x^2 + 136*x + 9
11201162
"""
11211163
cdef fmpz_mod_poly res
11221164

1165+
if e < 0:
1166+
raise ValueError("Exponent must be non-negative")
1167+
11231168
modulus = self.ctx.any_as_fmpz_mod_poly(modulus)
11241169
if modulus is NotImplemented:
11251170
raise TypeError(f"Cannot interpret {modulus} as a polynomial")
11261171

1172+
# Output polynomial
11271173
res = self.ctx.new_ctype_poly()
1128-
fmpz_mod_poly_powmod_ui_binexp(
1129-
res.val, self.val, <ulong>e, (<fmpz_mod_poly>modulus).val, res.ctx.mod.val
1130-
)
1174+
1175+
# For small exponents, use a simple binary exponentiation method
1176+
if e.bit_length() < 32:
1177+
fmpz_mod_poly_powmod_ui_binexp(
1178+
res.val, self.val, <ulong>e, (<fmpz_mod_poly>modulus).val, res.ctx.mod.val
1179+
)
1180+
return res
1181+
1182+
# For larger exponents we need to cast e to an fmpz first
1183+
e_fmpz = any_as_fmpz(e)
1184+
if e_fmpz is NotImplemented:
1185+
raise TypeError(f"exponent cannot be cast to an fmpz type: {e = }")
1186+
1187+
# To optimise powering, we precompute the inverse of the reverse of the modulus
1188+
if mod_rev_inv is not None:
1189+
mod_rev_inv = self.ctx.any_as_fmpz_mod_poly(mod_rev_inv)
1190+
if mod_rev_inv is NotImplemented:
1191+
raise TypeError(f"Cannot interpret {mod_rev_inv} as a polynomial")
1192+
else:
1193+
mod_rev_inv = modulus.reverse().inverse_series_trunc(modulus.length())
1194+
1195+
# Use windowed exponentiation optimisation when self = x
1196+
if self.is_gen():
1197+
fmpz_mod_poly_powmod_x_fmpz_preinv(
1198+
res.val, (<fmpz>e_fmpz).val, (<fmpz_mod_poly>modulus).val, (<fmpz_mod_poly>mod_rev_inv).val, res.ctx.mod.val
1199+
)
1200+
return res
1201+
1202+
# Otherwise using binary exponentiation for all other inputs
1203+
fmpz_mod_poly_powmod_fmpz_binexp_preinv(
1204+
res.val, self.val, (<fmpz>e_fmpz).val, (<fmpz_mod_poly>modulus).val, (<fmpz_mod_poly>mod_rev_inv).val, res.ctx.mod.val
1205+
)
11311206
return res
11321207

11331208
def divmod(self, other):

0 commit comments

Comments
 (0)