From 293ec246a958aa83dd6a79f4f5d7002d01f56a96 Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Tue, 6 Aug 2024 23:11:59 +0100 Subject: [PATCH 01/30] fix(nmod): Add nmod_ctx to store is_prime --- src/flint/types/nmod.pxd | 12 +- src/flint/types/nmod.pyx | 285 ++++++++++++++++++++++++---------- src/flint/types/nmod_poly.pyx | 2 +- 3 files changed, 215 insertions(+), 84 deletions(-) diff --git a/src/flint/types/nmod.pxd b/src/flint/types/nmod.pxd index 15b6fbc4..5f180812 100644 --- a/src/flint/types/nmod.pxd +++ b/src/flint/types/nmod.pxd @@ -2,8 +2,16 @@ from flint.flint_base.flint_base cimport flint_scalar from flint.flintlib.flint cimport mp_limb_t from flint.flintlib.nmod cimport nmod_t -cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1 +#cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1 +cdef nmod_ctx any_as_nmod_ctx(obj) + +cdef class nmod_ctx: + cdef nmod_t mod + cdef bint _is_prime + + cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1 + cdef nmod _new(self, mp_limb_t * val) cdef class nmod(flint_scalar): cdef mp_limb_t val - cdef nmod_t mod + cdef nmod_ctx ctx diff --git a/src/flint/types/nmod.pyx b/src/flint/types/nmod.pyx index 68796c4f..de29e8a5 100644 --- a/src/flint/types/nmod.pyx +++ b/src/flint/types/nmod.pyx @@ -12,33 +12,144 @@ from flint.flintlib.nmod_vec cimport * from flint.flintlib.fmpz cimport fmpz_fdiv_ui, fmpz_init, fmpz_clear from flint.flintlib.fmpz cimport fmpz_set_ui, fmpz_get_ui from flint.flintlib.fmpq cimport fmpq_mod_fmpz -from flint.flintlib.ulong_extras cimport n_gcdinv, n_sqrtmod +from flint.flintlib.ulong_extras cimport n_gcdinv, n_is_prime, n_sqrtmod from flint.utils.flint_exceptions import DomainError -cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1: - cdef int success - cdef fmpz_t t - if typecheck(obj, nmod): - if (obj).mod.n != mod.n: - raise ValueError("cannot coerce integers mod n with different n") - val[0] = (obj).val - return 1 - z = any_as_fmpz(obj) - if z is not NotImplemented: - val[0] = fmpz_fdiv_ui((z).val, mod.n) - return 1 - q = any_as_fmpq(obj) - if q is not NotImplemented: - fmpz_init(t) - fmpz_set_ui(t, mod.n) - success = fmpq_mod_fmpz(t, (q).val, t) - val[0] = fmpz_get_ui(t) - fmpz_clear(t) - if not success: - raise ZeroDivisionError("%s does not exist mod %i!" % (q, mod.n)) - return 1 - return 0 + +#cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1: +# return mod.ctx.any_as_nmod(val, obj) + + +_nmod_ctx_cache = {} + + +cdef nmod_ctx any_as_nmod_ctx(obj): + """Convert an int to an nmod_ctx.""" + if typecheck(obj, nmod_ctx): + return obj + if typecheck(obj, int): + ctx = _nmod_ctx_cache.get(obj) + if ctx is None: + ctx = nmod_ctx(obj) + _nmod_ctx_cache[obj] = ctx + return ctx + return NotImplemented + + +cdef class nmod_ctx: + """ + Context object for creating :class:`~.nmod` initalised + with modulus :math:`N`. + + >>> nmod_ctx(17) + nmod_ctx(17) + + """ + def __init__(self, mod): + cdef mp_limb_t m + m = mod + nmod_init(&self.mod, m) + self._is_prime = n_is_prime(m) + + cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1: + """Convert an object to an nmod element.""" + cdef int success + cdef fmpz_t t + if typecheck(obj, nmod): + if (obj).ctx != self: + raise ValueError("cannot coerce integers mod n with different n") + val[0] = (obj).val + return 1 + z = any_as_fmpz(obj) + if z is not NotImplemented: + val[0] = fmpz_fdiv_ui((z).val, self.mod.n) + return 1 + q = any_as_fmpq(obj) + if q is not NotImplemented: + fmpz_init(t) + fmpz_set_ui(t, self.mod.n) + success = fmpq_mod_fmpz(t, (q).val, t) + val[0] = fmpz_get_ui(t) + fmpz_clear(t) + if not success: + raise ZeroDivisionError("%s does not exist mod %i!" % (q, self.mod.n)) + return 1 + return 0 + + def modulus(self): + """Get the modulus of the context. + + >>> ctx = nmod_ctx(17) + >>> ctx.modulus() + 17 + + """ + return fmpz(self.mod) + + def is_prime(self): + """Check if the modulus is prime. + + >>> ctx = nmod_ctx(17) + >>> ctx.is_prime() + True + + """ + return self._is_prime + + def zero(self): + """Return the zero element of the context. + + >>> ctx = nmod_ctx(17) + >>> ctx.zero() + 0 + + """ + return self(0) + + def one(self): + """Return the one element of the context. + + >>> ctx = nmod_ctx(17) + >>> ctx.one() + 1 + + """ + return self(1) + + def __hash__(self): + return hash(self.mod) + + def __eq__(self, other): + if not typecheck(other, nmod_ctx): + return NotImplemented + else: + return self.mod == other.mod + + def __str__(self): + return f"Context for nmod with modulus: {self.modulus()}" + + def __repr__(self): + return f"nmod_ctx({self.modulus()})" + + cdef nmod _new(self, mp_limb_t * val): + cdef nmod r = nmod.__new__(nmod) + r.val = val[0] + r.ctx = self + return r + + def __call__(self, val): + """Create an nmod element from an integer. + + >>> ctx = nmod_ctx(17) + >>> ctx(10) + 10 + + """ + cdef mp_limb_t v + v = val + return self._new(&v) + cdef class nmod(flint_scalar): """ @@ -48,16 +159,15 @@ cdef class nmod(flint_scalar): 3 """ - def __init__(self, val, mod): - cdef mp_limb_t m - m = mod - nmod_init(&self.mod, m) - if not any_as_nmod(&self.val, val, self.mod): + ctx = any_as_nmod_ctx(mod) + if ctx is NotImplemented: + raise TypeError("Invalid context/modulus for nmod: %s" % mod) + if not ctx.any_as_nmod(&self.val, val): raise TypeError("cannot create nmod from object of type %s" % type(val)) def repr(self): - return "nmod(%s, %s)" % (self.val, self.mod.n) + return "nmod(%s, %s)" % (self.val, self.ctx.mod.n) def str(self): return str(int(self.val)) @@ -66,7 +176,7 @@ cdef class nmod(flint_scalar): return int(self.val) def modulus(self): - return self.mod.n + return self.ctx.mod.n def __richcmp__(s, t, int op): cdef bint res @@ -74,13 +184,13 @@ cdef class nmod(flint_scalar): raise TypeError("nmods cannot be ordered") if typecheck(s, nmod) and typecheck(t, nmod): res = ((s).val == (t).val) and \ - ((s).mod.n == (t).mod.n) + ((s).ctx.mod.n == (t).ctx.mod.n) if op == 2: return res else: return not res elif typecheck(s, nmod) and typecheck(t, int): - res = s.val == (t % s.mod.n) + res = s.val == (t % s.ctx.mod.n) if op == 2: return res else: @@ -98,100 +208,108 @@ cdef class nmod(flint_scalar): def __neg__(self): cdef nmod r = nmod.__new__(nmod) - r.mod = self.mod - r.val = nmod_neg(self.val, self.mod) + r.ctx = self.ctx + r.val = nmod_neg(self.val, self.ctx.mod) return r def __add__(s, t): - cdef nmod r + cdef nmod r, s2 cdef mp_limb_t val - if any_as_nmod(&val, t, (s).mod): + s2 = s + if s2.ctx.any_as_nmod(&val, t): r = nmod.__new__(nmod) - r.mod = (s).mod - r.val = nmod_add(val, (s).val, r.mod) + r.ctx = s2.ctx + r.val = nmod_add(val, s2.val, s2.ctx.mod) return r return NotImplemented def __radd__(s, t): - cdef nmod r + cdef nmod r, s2 cdef mp_limb_t val - if any_as_nmod(&val, t, (s).mod): + s2 = s + if s2.ctx.any_as_nmod(&val, t): r = nmod.__new__(nmod) - r.mod = (s).mod - r.val = nmod_add((s).val, val, r.mod) + r.ctx = s2.ctx + r.val = nmod_add(s2.val, val, s2.ctx.mod) return r return NotImplemented def __sub__(s, t): - cdef nmod r + cdef nmod r, s2 cdef mp_limb_t val - if any_as_nmod(&val, t, (s).mod): + s2 = s + if s2.ctx.any_as_nmod(&val, t): r = nmod.__new__(nmod) - r.mod = (s).mod - r.val = nmod_sub((s).val, val, r.mod) + r.ctx = s2.ctx + r.val = nmod_sub(s2.val, val, s2.ctx.mod) return r return NotImplemented def __rsub__(s, t): cdef nmod r cdef mp_limb_t val - if any_as_nmod(&val, t, (s).mod): + s2 = s + if s2.ctx.any_as_nmod(&val, t): r = nmod.__new__(nmod) - r.mod = (s).mod - r.val = nmod_sub(val, (s).val, r.mod) + r.ctx = s2.ctx + r.val = nmod_sub(val, s2.val, s2.ctx.mod) return r return NotImplemented def __mul__(s, t): - cdef nmod r + cdef nmod r, s2 cdef mp_limb_t val - if any_as_nmod(&val, t, (s).mod): + s2 = s + if s2.ctx.any_as_nmod(&val, t): r = nmod.__new__(nmod) - r.mod = (s).mod - r.val = nmod_mul(val, (s).val, r.mod) + r.ctx = s2.ctx + r.val = nmod_mul(val, s2.val, s2.ctx.mod) return r return NotImplemented def __rmul__(s, t): - cdef nmod r + cdef nmod r, s2 cdef mp_limb_t val - if any_as_nmod(&val, t, (s).mod): + s2 = s + if s2.ctx.any_as_nmod(&val, t): r = nmod.__new__(nmod) - r.mod = (s).mod - r.val = nmod_mul((s).val, val, r.mod) + r.ctx = s2.ctx + r.val = nmod_mul(s2.val, val, s2.ctx.mod) return r return NotImplemented @staticmethod def _div_(s, t): - cdef nmod r + cdef nmod r, s2, t2 cdef mp_limb_t sval, tval - cdef nmod_t mod + cdef nmod_ctx ctx cdef ulong tinvval if typecheck(s, nmod): - mod = (s).mod - sval = (s).val - if not any_as_nmod(&tval, t, mod): + s2 = s + ctx = s2.ctx + sval = s2.val + if not ctx.any_as_nmod(&tval, t): return NotImplemented else: - mod = (t).mod - tval = (t).val - if not any_as_nmod(&sval, s, mod): + t2 = t + ctx = t2.ctx + tval = t2.val + if not ctx.any_as_nmod(&sval, s): return NotImplemented if tval == 0: - raise ZeroDivisionError("%s is not invertible mod %s" % (tval, mod.n)) + raise ZeroDivisionError("%s is not invertible mod %s" % (tval, ctx.mod.n)) if not s: return s - g = n_gcdinv(&tinvval, tval, mod.n) + g = n_gcdinv(&tinvval, tval, ctx.mod.n) if g != 1: - raise ZeroDivisionError("%s is not invertible mod %s" % (tval, mod.n)) + raise ZeroDivisionError("%s is not invertible mod %s" % (tval, ctx.mod.n)) r = nmod.__new__(nmod) - r.mod = mod - r.val = nmod_mul(sval, tinvval, mod) + r.ctx = ctx + r.val = nmod_mul(sval, tinvval, ctx.mod) return r def __truediv__(s, t): @@ -201,19 +319,22 @@ cdef class nmod(flint_scalar): return nmod._div_(t, s) def __invert__(self): - cdef nmod r + cdef nmod r, s + cdef nmod_ctx ctx cdef ulong g, inv, sval - sval = (self).val - g = n_gcdinv(&inv, sval, self.mod.n) + s = self + ctx = s.ctx + sval = s.val + g = n_gcdinv(&inv, sval, ctx.mod.n) if g != 1: - raise ZeroDivisionError("%s is not invertible mod %s" % (sval, self.mod.n)) + raise ZeroDivisionError("%s is not invertible mod %s" % (sval, ctx.mod.n)) r = nmod.__new__(nmod) - r.mod = self.mod + r.ctx = ctx r.val = inv return r def __pow__(self, exp, modulus=None): - cdef nmod r + cdef nmod r, s cdef mp_limb_t rval, mod cdef ulong g, rinv @@ -224,8 +345,10 @@ cdef class nmod(flint_scalar): if e is NotImplemented: return NotImplemented - rval = (self).val - mod = (self).mod.n + s = self + ctx = s.ctx + rval = s.val + mod = ctx.mod.n # XXX: It is not clear that it is necessary to special case negative # exponents here. The nmod_pow_fmpz function seems to handle this fine @@ -238,8 +361,8 @@ cdef class nmod(flint_scalar): e = -e r = nmod.__new__(nmod) - r.mod = self.mod - r.val = nmod_pow_fmpz(rval, (e).val, self.mod) + r.ctx = ctx + r.val = nmod_pow_fmpz(rval, (e).val, ctx.mod) return r def sqrt(self): diff --git a/src/flint/types/nmod_poly.pyx b/src/flint/types/nmod_poly.pyx index b9259676..4b00ed64 100644 --- a/src/flint/types/nmod_poly.pyx +++ b/src/flint/types/nmod_poly.pyx @@ -4,7 +4,7 @@ from flint.utils.typecheck cimport typecheck from flint.types.fmpz cimport fmpz, any_as_fmpz from flint.types.fmpz_poly cimport any_as_fmpz_poly from flint.types.fmpz_poly cimport fmpz_poly -from flint.types.nmod cimport any_as_nmod +from flint.types.nmod cimport any_as_nmod_ctx from flint.types.nmod cimport nmod from flint.flintlib.nmod_vec cimport * From 27e5e4c7bbe83e64e7a0aad574ec4d8245d6bbc9 Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Thu, 22 Aug 2024 01:31:45 +0100 Subject: [PATCH 02/30] Add nmod_poly_ctx --- src/flint/types/nmod_poly.pxd | 20 ++++- src/flint/types/nmod_poly.pyx | 150 ++++++++++++++++++++++------------ 2 files changed, 115 insertions(+), 55 deletions(-) diff --git a/src/flint/types/nmod_poly.pxd b/src/flint/types/nmod_poly.pxd index c0d1cd85..bd86887b 100644 --- a/src/flint/types/nmod_poly.pxd +++ b/src/flint/types/nmod_poly.pxd @@ -1,10 +1,26 @@ -from flint.flint_base.flint_base cimport flint_poly - +from flint.flintlib.nmod cimport nmod_t from flint.flintlib.nmod_poly cimport nmod_poly_t from flint.flintlib.flint cimport mp_limb_t +from flint.flint_base.flint_base cimport flint_poly + +from flint.types.nmod cimport nmod_ctx + + +cdef class nmod_poly_ctx: + cdef nmod_ctx ctx + cdef nmod_t mod + cdef bint _is_prime + + cdef nmod_poly_set_list(self, nmod_poly_t poly, list val) + cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1 + cdef any_as_nmod_poly(self, obj) + + cdef class nmod_poly(flint_poly): cdef nmod_poly_t val + cdef nmod_poly_ctx ctx + cpdef long length(self) cpdef long degree(self) cpdef mp_limb_t modulus(self) diff --git a/src/flint/types/nmod_poly.pyx b/src/flint/types/nmod_poly.pyx index 4b00ed64..ddf4773c 100644 --- a/src/flint/types/nmod_poly.pyx +++ b/src/flint/types/nmod_poly.pyx @@ -5,47 +5,83 @@ from flint.types.fmpz cimport fmpz, any_as_fmpz from flint.types.fmpz_poly cimport any_as_fmpz_poly from flint.types.fmpz_poly cimport fmpz_poly from flint.types.nmod cimport any_as_nmod_ctx -from flint.types.nmod cimport nmod +from flint.types.nmod cimport nmod, nmod_ctx from flint.flintlib.nmod_vec cimport * from flint.flintlib.nmod_poly cimport * from flint.flintlib.nmod_poly_factor cimport * from flint.flintlib.fmpz_poly cimport fmpz_poly_get_nmod_poly +from flint.flintlib.ulong_extras cimport n_gcdinv, n_is_prime from flint.utils.flint_exceptions import DomainError -cdef any_as_nmod_poly(obj, nmod_t mod): - cdef nmod_poly r - cdef mp_limb_t v - # XXX: should check that modulus is the same here, and not all over the place - if typecheck(obj, nmod_poly): +_nmod_poly_ctx_cache = {} + + +cdef nmod_ctx any_as_nmod_poly_ctx(obj): + """Convert an int to an nmod_ctx.""" + if typecheck(obj, nmod_poly_ctx): return obj - if any_as_nmod(&v, obj, mod): - r = nmod_poly.__new__(nmod_poly) - nmod_poly_init(r.val, mod.n) - nmod_poly_set_coeff_ui(r.val, 0, v) - return r - x = any_as_fmpz_poly(obj) - if x is not NotImplemented: - r = nmod_poly.__new__(nmod_poly) - nmod_poly_init(r.val, mod.n) # XXX: create flint _nmod_poly_set_modulus for this? - fmpz_poly_get_nmod_poly(r.val, (x).val) - return r + if typecheck(obj, int): + ctx = _nmod_poly_ctx_cache.get(obj) + if ctx is None: + ctx = nmod_poly_ctx(obj) + _nmod_poly_ctx_cache[obj] = ctx + return ctx return NotImplemented -cdef nmod_poly_set_list(nmod_poly_t poly, list val): - cdef long i, n - cdef nmod_t mod - cdef mp_limb_t v - nmod_init(&mod, nmod_poly_modulus(poly)) # XXX - n = PyList_GET_SIZE(val) - nmod_poly_fit_length(poly, n) - for i from 0 <= i < n: - if any_as_nmod(&v, val[i], mod): - nmod_poly_set_coeff_ui(poly, i, v) - else: - raise TypeError("unsupported coefficient in list") + +cdef class nmod_poly_ctx: + """ + Context object for creating :class:`~.nmod_poly` initalised + with modulus :math:`N`. + + >>> nmod_ctx(17) + nmod_ctx(17) + + """ + def __init__(self, mod): + cdef mp_limb_t m + m = mod + nmod_init(&self.mod, m) + self.ctx = nmod_ctx(mod) + self._is_prime = n_is_prime(m) + + cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1: + return self.ctx.any_as_nmod(val, obj) + + cdef any_as_nmod_poly(self, obj): + cdef nmod_poly r + cdef mp_limb_t v + # XXX: should check that modulus is the same here, and not all over the place + if typecheck(obj, nmod_poly): + return obj + if self.ctx.any_as_nmod(&v, obj): + r = nmod_poly.__new__(nmod_poly) + nmod_poly_init(r.val, self.mod.n) + nmod_poly_set_coeff_ui(r.val, 0, v) + return r + x = any_as_fmpz_poly(obj) + if x is not NotImplemented: + r = nmod_poly.__new__(nmod_poly) + nmod_poly_init(r.val, self.mod.n) # XXX: create flint _nmod_poly_set_modulus for this? + fmpz_poly_get_nmod_poly(r.val, (x).val) + return r + return NotImplemented + + cdef nmod_poly_set_list(self, nmod_poly_t poly, list val): + cdef long i, n + cdef mp_limb_t v + n = PyList_GET_SIZE(val) + nmod_poly_fit_length(poly, n) + for i from 0 <= i < n: + c = val[i] + if self.any_as_nmod(&v, val[i]): + nmod_poly_set_coeff_ui(poly, i, v) + else: + raise TypeError("unsupported coefficient in list") + cdef class nmod_poly(flint_poly): """ @@ -77,24 +113,32 @@ cdef class nmod_poly(flint_poly): def __dealloc__(self): nmod_poly_clear(self.val) - def __init__(self, val=None, ulong mod=0): + def __init__(self, val=None, mod=0): cdef ulong m2 cdef mp_limb_t v + cdef nmod_poly_ctx ctx + if typecheck(val, nmod_poly): m2 = nmod_poly_modulus((val).val) if m2 != mod: raise ValueError("different moduli!") nmod_poly_init(self.val, m2) nmod_poly_set(self.val, (val).val) + self.ctx = (val).ctx else: if mod == 0: raise ValueError("a nonzero modulus is required") - nmod_poly_init(self.val, mod) + ctx = any_as_nmod_poly_ctx(mod) + if ctx is NotImplemented: + raise TypeError("cannot create nmod_poly_ctx from input of type %s", type(mod)) + + self.ctx = ctx + nmod_poly_init(self.val, ctx.mod.n) if typecheck(val, fmpz_poly): fmpz_poly_get_nmod_poly(self.val, (val).val) elif typecheck(val, list): - nmod_poly_set_list(self.val, val) - elif any_as_nmod(&v, val, self.val.mod): + ctx.nmod_poly_set_list(self.val, val) + elif ctx.any_as_nmod(&v, val): nmod_poly_fit_length(self.val, 1) nmod_poly_set_coeff_ui(self.val, 0, v) else: @@ -175,7 +219,7 @@ cdef class nmod_poly(flint_poly): cdef mp_limb_t v if i < 0: raise ValueError("cannot assign to index < 0 of polynomial") - if any_as_nmod(&v, x, self.val.mod): + if self.ctx.any_as_nmod(&v, x): nmod_poly_set_coeff_ui(self.val, i, v) else: raise TypeError("cannot set element of type %s" % type(x)) @@ -288,7 +332,7 @@ cdef class nmod_poly(flint_poly): 9*x^4 + 12*x^3 + 10*x^2 + 4*x + 1 """ cdef nmod_poly res - other = any_as_nmod_poly(other, (self).val.mod) + other = self.ctx.any_as_nmod_poly(other) if other is NotImplemented: raise TypeError("cannot convert input to nmod_poly") res = nmod_poly.__new__(nmod_poly) @@ -313,11 +357,11 @@ cdef class nmod_poly(flint_poly): 147*x^3 + 159*x^2 + 4*x + 7 """ cdef nmod_poly res - g = any_as_nmod_poly(other, self.val.mod) + g = self.ctx.any_as_nmod_poly(other) if g is NotImplemented: raise TypeError(f"cannot convert other = {other} to nmod_poly") - h = any_as_nmod_poly(modulus, self.val.mod) + h = self.any_as_nmod_poly(modulus, self.val.mod) if h is NotImplemented: raise TypeError(f"cannot convert modulus = {modulus} to nmod_poly") @@ -331,11 +375,11 @@ cdef class nmod_poly(flint_poly): def __call__(self, other): cdef mp_limb_t c - if any_as_nmod(&c, other, self.val.mod): + if self.ctx.any_as_nmod(&c, other): v = nmod(0, self.modulus()) (v).val = nmod_poly_evaluate_nmod(self.val, c) return v - t = any_as_nmod_poly(other, self.val.mod) + t = self.ctx.any_as_nmod_poly(other) if t is not NotImplemented: r = nmod_poly.__new__(nmod_poly) nmod_poly_init_preinv((r).val, self.val.mod.n, self.val.mod.ninv) @@ -366,7 +410,7 @@ cdef class nmod_poly(flint_poly): def _add_(s, t): cdef nmod_poly r - t = any_as_nmod_poly(t, (s).val.mod) + t = s.ctx.any_as_nmod_poly(t) if t is NotImplemented: return t if (s).val.mod.n != (t).val.mod.n: @@ -392,20 +436,20 @@ cdef class nmod_poly(flint_poly): return r def __sub__(s, t): - t = any_as_nmod_poly(t, (s).val.mod) + t = s.ctx.any_as_nmod_poly(t) if t is NotImplemented: return t return s._sub_(t) def __rsub__(s, t): - t = any_as_nmod_poly(t, (s).val.mod) + t = s.any_as_nmod_poly(t) if t is NotImplemented: return t return t._sub_(s) def _mul_(s, t): cdef nmod_poly r - t = any_as_nmod_poly(t, (s).val.mod) + t = s.any_as_nmod_poly(t) if t is NotImplemented: return t if (s).val.mod.n != (t).val.mod.n: @@ -422,7 +466,7 @@ cdef class nmod_poly(flint_poly): return s._mul_(t) def __truediv__(s, t): - t = any_as_nmod_poly(t, (s).val.mod) + t = s.any_as_nmod_poly(t) if t is NotImplemented: return t res, r = s._divmod_(t) @@ -431,7 +475,7 @@ cdef class nmod_poly(flint_poly): return res def __rtruediv__(s, t): - t = any_as_nmod_poly(t, (s).val.mod) + t = s.any_as_nmod_poly(t) if t is NotImplemented: return t res, r = t._divmod_(s) @@ -451,13 +495,13 @@ cdef class nmod_poly(flint_poly): return r def __floordiv__(s, t): - t = any_as_nmod_poly(t, (s).val.mod) + t = s.any_as_nmod_poly(t) if t is NotImplemented: return t return s._floordiv_(t) def __rfloordiv__(s, t): - t = any_as_nmod_poly(t, (s).val.mod) + t = s.any_as_nmod_poly(t) if t is NotImplemented: return t return t._floordiv_(s) @@ -476,13 +520,13 @@ cdef class nmod_poly(flint_poly): return P, Q def __divmod__(s, t): - t = any_as_nmod_poly(t, (s).val.mod) + t = s.any_as_nmod_poly(t) if t is NotImplemented: return t return s._divmod_(t) def __rdivmod__(s, t): - t = any_as_nmod_poly(t, (s).val.mod) + t = s.any_as_nmod_poly(t) if t is NotImplemented: return t return t._divmod_(s) @@ -531,7 +575,7 @@ cdef class nmod_poly(flint_poly): if e < 0: raise ValueError("Exponent must be non-negative") - modulus = any_as_nmod_poly(modulus, (self).val.mod) + modulus = self.ctx.any_as_nmod_poly(modulus) if modulus is NotImplemented: raise TypeError("cannot convert input to nmod_poly") @@ -553,7 +597,7 @@ cdef class nmod_poly(flint_poly): # To optimise powering, we precompute the inverse of the reverse of the modulus if mod_rev_inv is not None: - mod_rev_inv = any_as_nmod_poly(mod_rev_inv, (self).val.mod) + mod_rev_inv = self.any_as_nmod_poly(mod_rev_inv) if mod_rev_inv is NotImplemented: raise TypeError(f"Cannot interpret {mod_rev_inv} as a polynomial") else: @@ -582,7 +626,7 @@ cdef class nmod_poly(flint_poly): """ cdef nmod_poly res - other = any_as_nmod_poly(other, (self).val.mod) + other = self.any_as_nmod_poly(other) if other is NotImplemented: raise TypeError("cannot convert input to nmod_poly") if self.val.mod.n != (other).val.mod.n: @@ -594,7 +638,7 @@ cdef class nmod_poly(flint_poly): def xgcd(self, other): cdef nmod_poly res1, res2, res3 - other = any_as_nmod_poly(other, (self).val.mod) + other = self.any_as_nmod_poly(other) if other is NotImplemented: raise TypeError("cannot convert input to fmpq_poly") res1 = nmod_poly.__new__(nmod_poly) From a88f88177a35a3f590a7503710499812e3bb0d82 Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Thu, 22 Aug 2024 13:44:08 +0100 Subject: [PATCH 03/30] Use nmod contexts for nmod_mat --- src/flint/types/nmod.pyx | 32 ++++++++++++---- src/flint/types/nmod_mat.pxd | 5 +++ src/flint/types/nmod_mat.pyx | 23 ++++++----- src/flint/types/nmod_poly.pyx | 72 ++++++++++++++++++++++++----------- 4 files changed, 93 insertions(+), 39 deletions(-) diff --git a/src/flint/types/nmod.pyx b/src/flint/types/nmod.pyx index de29e8a5..0e00d090 100644 --- a/src/flint/types/nmod.pyx +++ b/src/flint/types/nmod.pyx @@ -52,6 +52,23 @@ cdef class nmod_ctx: nmod_init(&self.mod, m) self._is_prime = n_is_prime(m) + def __eq__(self, other): + # XXX: If we could ensure uniqueness of nmod_ctx for given modulus then + # we would need to implement __eq__ and __hash__ at all... + # + # It isn't possible to ensure uniqueness in __new__ like it is in + # Python because we can't return an existing object from __new__. What + # we could do though is make it so that __init__ raises an error and + # use a static method .new() to create new objects. + if self is other: + return True + if not typecheck(other, nmod_ctx): + return NotImplemented + return self.mod.n == (other).mod.n + + def __repr__(self): + return f"nmod_ctx({self.modulus()})" + cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1: """Convert an object to an nmod element.""" cdef int success @@ -85,7 +102,7 @@ cdef class nmod_ctx: 17 """ - return fmpz(self.mod) + return fmpz(self.mod.n) def is_prime(self): """Check if the modulus is prime. @@ -121,10 +138,10 @@ cdef class nmod_ctx: return hash(self.mod) def __eq__(self, other): - if not typecheck(other, nmod_ctx): - return NotImplemented + if typecheck(other, nmod_ctx): + return self.mod.n == (other).mod.n else: - return self.mod == other.mod + return NotImplemented def __str__(self): return f"Context for nmod with modulus: {self.modulus()}" @@ -165,6 +182,7 @@ cdef class nmod(flint_scalar): raise TypeError("Invalid context/modulus for nmod: %s" % mod) if not ctx.any_as_nmod(&self.val, val): raise TypeError("cannot create nmod from object of type %s" % type(val)) + self.ctx = ctx def repr(self): return "nmod(%s, %s)" % (self.val, self.ctx.mod.n) @@ -385,14 +403,14 @@ cdef class nmod(flint_scalar): cdef nmod r cdef mp_limb_t val r = nmod.__new__(nmod) - r.mod = self.mod + r.ctx = self.ctx if self.val == 0: return r - val = n_sqrtmod(self.val, self.mod.n) + val = n_sqrtmod(self.val, self.ctx.mod.n) if val == 0: - raise DomainError("no square root exists for %s mod %s" % (self.val, self.mod.n)) + raise DomainError("no square root exists for %s mod %s" % (self.val, self.ctx.mod.n)) r.val = val return r diff --git a/src/flint/types/nmod_mat.pxd b/src/flint/types/nmod_mat.pxd index 0b8336d1..ada85a71 100644 --- a/src/flint/types/nmod_mat.pxd +++ b/src/flint/types/nmod_mat.pxd @@ -3,8 +3,13 @@ from flint.flint_base.flint_base cimport flint_mat from flint.flintlib.nmod_mat cimport nmod_mat_t from flint.flintlib.flint cimport mp_limb_t +from flint.types.nmod cimport nmod_ctx + + cdef class nmod_mat(flint_mat): cdef nmod_mat_t val + cdef nmod_ctx ctx + cpdef long nrows(self) cpdef long ncols(self) cpdef mp_limb_t modulus(self) diff --git a/src/flint/types/nmod_mat.pyx b/src/flint/types/nmod_mat.pyx index 5a54cdce..eb78902d 100644 --- a/src/flint/types/nmod_mat.pyx +++ b/src/flint/types/nmod_mat.pyx @@ -43,8 +43,7 @@ from flint.flintlib.nmod_mat cimport ( from flint.utils.typecheck cimport typecheck from flint.types.fmpz_mat cimport any_as_fmpz_mat from flint.types.fmpz_mat cimport fmpz_mat -from flint.types.nmod cimport nmod -from flint.types.nmod cimport any_as_nmod +from flint.types.nmod cimport nmod, any_as_nmod_ctx from flint.types.nmod_poly cimport nmod_poly from flint.pyflint cimport global_random_state from flint.flint_base.flint_context cimport thectx @@ -87,20 +86,24 @@ cdef class nmod_mat(flint_mat): def __init__(self, *args): cdef long m, n, i, j cdef mp_limb_t mod + cdef nmod_ctx ctx if len(args) == 1: val = args[0] if typecheck(val, nmod_mat): nmod_mat_init_set(self.val, (val).val) + self.ctx = (val).ctx return mod = args[-1] args = args[:-1] if mod == 0: raise ValueError("modulus must be nonzero") + ctx = any_as_nmod_ctx(mod) + self.ctx = ctx if len(args) == 1: val = args[0] if typecheck(val, fmpz_mat): nmod_mat_init(self.val, fmpz_mat_nrows((val).val), - fmpz_mat_ncols((val).val), mod) + fmpz_mat_ncols((val).val), ctx.mod.n) fmpz_mat_get_nmod_mat(self.val, (val).val) elif isinstance(val, (list, tuple)): m = len(val) @@ -112,11 +115,11 @@ cdef class nmod_mat(flint_mat): for i from 1 <= i < m: if len(val[i]) != n: raise ValueError("input rows have different lengths") - nmod_mat_init(self.val, m, n, mod) + nmod_mat_init(self.val, m, n, ctx.mod.n) for i from 0 <= i < m: row = val[i] for j from 0 <= j < n: - x = nmod(row[j], mod) + x = nmod(row[j], ctx) # XXX: slow self.val.rows[i][j] = (x).val else: raise TypeError("cannot create nmod_mat from input of type %s" % type(val)) @@ -131,7 +134,7 @@ cdef class nmod_mat(flint_mat): raise ValueError("list of entries has the wrong length") for i from 0 <= i < m: for j from 0 <= j < n: - x = nmod(entries[i*n + j], mod) # XXX: slow + x = nmod(entries[i*n + j], ctx) # XXX: slow self.val.rows[i][j] = (x).val else: raise TypeError("nmod_mat: expected 1-3 arguments plus modulus") @@ -207,7 +210,7 @@ cdef class nmod_mat(flint_mat): i, j = index if i < 0 or i >= self.nrows() or j < 0 or j >= self.ncols(): raise IndexError("index %i,%i exceeds matrix dimensions" % (i, j)) - if any_as_nmod(&v, value, self.val.mod): + if self.ctx.any_as_nmod(&v, value): nmod_mat_set_entry(self.val, i, j, v) else: raise TypeError("cannot set item of type %s" % type(value)) @@ -306,7 +309,7 @@ cdef class nmod_mat(flint_mat): sv = &(s).val[0] u = any_as_nmod_mat(t, sv.mod) if u is NotImplemented: - if any_as_nmod(&c, t, sv.mod): + if s.ctx.any_as_nmod(&c, t): return (s).__mul_nmod(c) return NotImplemented tv = &(u).val[0] @@ -323,7 +326,7 @@ cdef class nmod_mat(flint_mat): cdef nmod_mat_struct *sv cdef mp_limb_t c sv = &(s).val[0] - if any_as_nmod(&c, t, sv.mod): + if s.ctx.any_as_nmod(&c, t): return (s).__mul_nmod(c) u = any_as_nmod_mat(t, sv.mod) if u is NotImplemented: @@ -348,7 +351,7 @@ cdef class nmod_mat(flint_mat): @staticmethod def _div_(nmod_mat s, t): cdef mp_limb_t v - if not any_as_nmod(&v, t, s.val.mod): + if not s.ctx.any_as_nmod(&v, t): return NotImplemented t = nmod(v, s.val.mod.n) return s * (~t) diff --git a/src/flint/types/nmod_poly.pyx b/src/flint/types/nmod_poly.pyx index ddf4773c..89a64aac 100644 --- a/src/flint/types/nmod_poly.pyx +++ b/src/flint/types/nmod_poly.pyx @@ -19,7 +19,7 @@ from flint.utils.flint_exceptions import DomainError _nmod_poly_ctx_cache = {} -cdef nmod_ctx any_as_nmod_poly_ctx(obj): +cdef nmod_poly_ctx any_as_nmod_poly_ctx(obj): """Convert an int to an nmod_ctx.""" if typecheck(obj, nmod_poly_ctx): return obj @@ -37,8 +37,8 @@ cdef class nmod_poly_ctx: Context object for creating :class:`~.nmod_poly` initalised with modulus :math:`N`. - >>> nmod_ctx(17) - nmod_ctx(17) + >>> nmod_poly_ctx(17) + nmod_poly_ctx(17) """ def __init__(self, mod): @@ -48,6 +48,9 @@ cdef class nmod_poly_ctx: self.ctx = nmod_ctx(mod) self._is_prime = n_is_prime(m) + def __repr__(self): + return f"nmod_poly_ctx({self.mod.n})" + cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1: return self.ctx.any_as_nmod(val, obj) @@ -61,12 +64,14 @@ cdef class nmod_poly_ctx: r = nmod_poly.__new__(nmod_poly) nmod_poly_init(r.val, self.mod.n) nmod_poly_set_coeff_ui(r.val, 0, v) + r.ctx = self return r x = any_as_fmpz_poly(obj) if x is not NotImplemented: r = nmod_poly.__new__(nmod_poly) nmod_poly_init(r.val, self.mod.n) # XXX: create flint _nmod_poly_set_modulus for this? fmpz_poly_get_nmod_poly(r.val, (x).val) + r.ctx = self return r return NotImplemented @@ -266,6 +271,7 @@ cdef class nmod_poly(flint_poly): res = nmod_poly.__new__(nmod_poly) nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) nmod_poly_reverse(res.val, self.val, length) + res.ctx = self.ctx return res def leading_coefficient(self): @@ -287,8 +293,8 @@ cdef class nmod_poly(flint_poly): cu = nmod_poly_get_coeff_ui(self.val, d) c = nmod.__new__(nmod) - c.mod = self.val.mod c.val = cu + c.ctx = self.ctx.ctx return c @@ -315,6 +321,7 @@ cdef class nmod_poly(flint_poly): res = nmod_poly.__new__(nmod_poly) nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) nmod_poly_inv_series(res.val, self.val, n) + res.ctx = self.ctx return res def compose(self, other): @@ -337,8 +344,9 @@ cdef class nmod_poly(flint_poly): raise TypeError("cannot convert input to nmod_poly") res = nmod_poly.__new__(nmod_poly) nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) - nmod_poly_compose(res.val, self.val, (other).val) - return res + nmod_poly_compose(res.val, self.val, (other).val) + res.ctx = self.ctx + return res def compose_mod(self, other, modulus): r""" @@ -361,7 +369,7 @@ cdef class nmod_poly(flint_poly): if g is NotImplemented: raise TypeError(f"cannot convert other = {other} to nmod_poly") - h = self.any_as_nmod_poly(modulus, self.val.mod) + h = self.ctx.any_as_nmod_poly(modulus) if h is NotImplemented: raise TypeError(f"cannot convert modulus = {modulus} to nmod_poly") @@ -370,8 +378,9 @@ cdef class nmod_poly(flint_poly): res = nmod_poly.__new__(nmod_poly) nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) - nmod_poly_compose_mod(res.val, self.val, (other).val, (modulus).val) - return res + nmod_poly_compose_mod(res.val, self.val, (other).val, (modulus).val) + res.ctx = self.ctx + return res def __call__(self, other): cdef mp_limb_t c @@ -384,6 +393,7 @@ cdef class nmod_poly(flint_poly): r = nmod_poly.__new__(nmod_poly) nmod_poly_init_preinv((r).val, self.val.mod.n, self.val.mod.ninv) nmod_poly_compose((r).val, self.val, (t).val) + (r).ctx = self.ctx return r raise TypeError("cannot call nmod_poly with input of type %s", type(other)) @@ -391,12 +401,14 @@ cdef class nmod_poly(flint_poly): cdef nmod_poly res = nmod_poly.__new__(nmod_poly) nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) nmod_poly_derivative(res.val, self.val) + res.ctx = self.ctx return res def integral(self): cdef nmod_poly res = nmod_poly.__new__(nmod_poly) nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) nmod_poly_integral(res.val, self.val) + res.ctx = self.ctx return res def __pos__(self): @@ -406,6 +418,7 @@ cdef class nmod_poly(flint_poly): cdef nmod_poly r = nmod_poly.__new__(nmod_poly) nmod_poly_init_preinv(r.val, self.val.mod.n, self.val.mod.ninv) nmod_poly_neg(r.val, self.val) + r.ctx = self.ctx return r def _add_(s, t): @@ -418,6 +431,7 @@ cdef class nmod_poly(flint_poly): r = nmod_poly.__new__(nmod_poly) nmod_poly_init_preinv(r.val, (t).val.mod.n, (t).val.mod.ninv) nmod_poly_add(r.val, (s).val, (t).val) + r.ctx = s.ctx return r def __add__(s, t): @@ -433,6 +447,7 @@ cdef class nmod_poly(flint_poly): r = nmod_poly.__new__(nmod_poly) nmod_poly_init_preinv(r.val, (t).val.mod.n, (t).val.mod.ninv) nmod_poly_sub(r.val, (s).val, (t).val) + r.ctx = s.ctx return r def __sub__(s, t): @@ -442,14 +457,14 @@ cdef class nmod_poly(flint_poly): return s._sub_(t) def __rsub__(s, t): - t = s.any_as_nmod_poly(t) + t = s.ctx.any_as_nmod_poly(t) if t is NotImplemented: return t return t._sub_(s) def _mul_(s, t): cdef nmod_poly r - t = s.any_as_nmod_poly(t) + t = s.ctx.any_as_nmod_poly(t) if t is NotImplemented: return t if (s).val.mod.n != (t).val.mod.n: @@ -457,6 +472,7 @@ cdef class nmod_poly(flint_poly): r = nmod_poly.__new__(nmod_poly) nmod_poly_init_preinv(r.val, (t).val.mod.n, (t).val.mod.ninv) nmod_poly_mul(r.val, (s).val, (t).val) + r.ctx = s.ctx return r def __mul__(s, t): @@ -466,7 +482,7 @@ cdef class nmod_poly(flint_poly): return s._mul_(t) def __truediv__(s, t): - t = s.any_as_nmod_poly(t) + t = s.ctx.any_as_nmod_poly(t) if t is NotImplemented: return t res, r = s._divmod_(t) @@ -475,7 +491,7 @@ cdef class nmod_poly(flint_poly): return res def __rtruediv__(s, t): - t = s.any_as_nmod_poly(t) + t = s.ctx.any_as_nmod_poly(t) if t is NotImplemented: return t res, r = t._divmod_(s) @@ -492,16 +508,17 @@ cdef class nmod_poly(flint_poly): r = nmod_poly.__new__(nmod_poly) nmod_poly_init_preinv(r.val, (t).val.mod.n, (t).val.mod.ninv) nmod_poly_div(r.val, (s).val, (t).val) + r.ctx = s.ctx return r def __floordiv__(s, t): - t = s.any_as_nmod_poly(t) + t = s.ctx.any_as_nmod_poly(t) if t is NotImplemented: return t return s._floordiv_(t) def __rfloordiv__(s, t): - t = s.any_as_nmod_poly(t) + t = s.ctx.any_as_nmod_poly(t) if t is NotImplemented: return t return t._floordiv_(s) @@ -517,16 +534,18 @@ cdef class nmod_poly(flint_poly): nmod_poly_init_preinv(P.val, (t).val.mod.n, (t).val.mod.ninv) nmod_poly_init_preinv(Q.val, (t).val.mod.n, (t).val.mod.ninv) nmod_poly_divrem(P.val, Q.val, (s).val, (t).val) + P.ctx = s.ctx + Q.ctx = s.ctx return P, Q def __divmod__(s, t): - t = s.any_as_nmod_poly(t) + t = s.ctx.any_as_nmod_poly(t) if t is NotImplemented: return t return s._divmod_(t) def __rdivmod__(s, t): - t = s.any_as_nmod_poly(t) + t = s.ctx.any_as_nmod_poly(t) if t is NotImplemented: return t return t._divmod_(s) @@ -546,6 +565,7 @@ cdef class nmod_poly(flint_poly): res = nmod_poly.__new__(nmod_poly) nmod_poly_init_preinv(res.val, (self).val.mod.n, (self).val.mod.ninv) nmod_poly_pow(res.val, self.val, exp) + res.ctx = self.ctx return res def pow_mod(self, e, modulus, mod_rev_inv=None): @@ -582,7 +602,8 @@ cdef class nmod_poly(flint_poly): # Output polynomial res = nmod_poly.__new__(nmod_poly) nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) - + res.ctx = self.ctx + # For small exponents, use a simple binary exponentiation method if e.bit_length() < 32: nmod_poly_powmod_ui_binexp( @@ -597,7 +618,7 @@ cdef class nmod_poly(flint_poly): # To optimise powering, we precompute the inverse of the reverse of the modulus if mod_rev_inv is not None: - mod_rev_inv = self.any_as_nmod_poly(mod_rev_inv) + mod_rev_inv = self.ctx.any_as_nmod_poly(mod_rev_inv) if mod_rev_inv is NotImplemented: raise TypeError(f"Cannot interpret {mod_rev_inv} as a polynomial") else: @@ -626,7 +647,7 @@ cdef class nmod_poly(flint_poly): """ cdef nmod_poly res - other = self.any_as_nmod_poly(other) + other = self.ctx.any_as_nmod_poly(other) if other is NotImplemented: raise TypeError("cannot convert input to nmod_poly") if self.val.mod.n != (other).val.mod.n: @@ -634,11 +655,12 @@ cdef class nmod_poly(flint_poly): res = nmod_poly.__new__(nmod_poly) nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) nmod_poly_gcd(res.val, self.val, (other).val) + res.ctx = self.ctx return res def xgcd(self, other): cdef nmod_poly res1, res2, res3 - other = self.any_as_nmod_poly(other) + other = self.ctx.any_as_nmod_poly(other) if other is NotImplemented: raise TypeError("cannot convert input to fmpq_poly") res1 = nmod_poly.__new__(nmod_poly) @@ -648,6 +670,9 @@ cdef class nmod_poly(flint_poly): nmod_poly_init(res2.val, (self).val.mod.n) nmod_poly_init(res3.val, (self).val.mod.n) nmod_poly_xgcd(res1.val, res2.val, res3.val, self.val, (other).val) + res1.ctx = self.ctx + res2.ctx = self.ctx + res3.ctx = self.ctx return (res1, res2, res3) def factor(self, algorithm=None): @@ -722,11 +747,12 @@ cdef class nmod_poly(flint_poly): nmod_poly_init_preinv((u).val, (self).val.mod.n, (self).val.mod.ninv) nmod_poly_set((u).val, &fac.p[i]) + (u).ctx = self.ctx exp = fac.exp[i] res[i] = (u, exp) c = nmod.__new__(nmod) - (c).mod = self.val.mod + (c).ctx = self.ctx.ctx (c).val = lead nmod_poly_factor_clear(fac) @@ -738,6 +764,7 @@ cdef class nmod_poly(flint_poly): cdef nmod_poly res res = nmod_poly.__new__(nmod_poly) nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) + res.ctx = self.ctx if nmod_poly_sqrt(res.val, self.val): return res else: @@ -755,6 +782,7 @@ cdef class nmod_poly(flint_poly): v = nmod_poly.__new__(nmod_poly) nmod_poly_init(v.val, self.val.mod.n) nmod_poly_deflate(v.val, self.val, n) + v.ctx = self.ctx return v, int(n) def real_roots(self): From ba9c1fda159bc921f954713720555794b17a3c21 Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Thu, 22 Aug 2024 14:18:27 +0100 Subject: [PATCH 04/30] Inline ctx.any_as_nmod for faster nmod.__mul__ --- src/flint/types/nmod.pxd | 2 +- src/flint/types/nmod.pyx | 29 +++++++++++++++++++++++++++-- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/src/flint/types/nmod.pxd b/src/flint/types/nmod.pxd index 5f180812..713fe614 100644 --- a/src/flint/types/nmod.pxd +++ b/src/flint/types/nmod.pxd @@ -2,7 +2,7 @@ from flint.flint_base.flint_base cimport flint_scalar from flint.flintlib.flint cimport mp_limb_t from flint.flintlib.nmod cimport nmod_t -#cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1 +cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1 cdef nmod_ctx any_as_nmod_ctx(obj) cdef class nmod_ctx: diff --git a/src/flint/types/nmod.pyx b/src/flint/types/nmod.pyx index 0e00d090..591fcdc3 100644 --- a/src/flint/types/nmod.pyx +++ b/src/flint/types/nmod.pyx @@ -20,7 +20,6 @@ from flint.utils.flint_exceptions import DomainError #cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1: # return mod.ctx.any_as_nmod(val, obj) - _nmod_ctx_cache = {} @@ -37,6 +36,32 @@ cdef nmod_ctx any_as_nmod_ctx(obj): return NotImplemented +cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1: + """Convert an object to an nmod element.""" + cdef int success + cdef fmpz_t t + if typecheck(obj, nmod): + if (obj).ctx.mod.n != mod.n: + raise ValueError("cannot coerce integers mod n with different n") + val[0] = (obj).val + return 1 + z = any_as_fmpz(obj) + if z is not NotImplemented: + val[0] = fmpz_fdiv_ui((z).val, mod.n) + return 1 + q = any_as_fmpq(obj) + if q is not NotImplemented: + fmpz_init(t) + fmpz_set_ui(t, mod.n) + success = fmpq_mod_fmpz(t, (q).val, t) + val[0] = fmpz_get_ui(t) + fmpz_clear(t) + if not success: + raise ZeroDivisionError("%s does not exist mod %i!" % (q, mod.n)) + return 1 + return 0 + + cdef class nmod_ctx: """ Context object for creating :class:`~.nmod` initalised @@ -278,7 +303,7 @@ cdef class nmod(flint_scalar): cdef nmod r, s2 cdef mp_limb_t val s2 = s - if s2.ctx.any_as_nmod(&val, t): + if any_as_nmod(&val, t, s2.ctx.mod): r = nmod.__new__(nmod) r.ctx = s2.ctx r.val = nmod_mul(val, s2.val, s2.ctx.mod) From 8c2221b19c8f39197882f541b9de534a4f27f525 Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Sat, 24 Aug 2024 01:16:13 +0100 Subject: [PATCH 05/30] Use nmod_ctx consistently in nmod_poly and nmod_mat --- src/flint/test/test_all.py | 174 +++++++++++++++++++++---------- src/flint/types/fmpz_mat.pyx | 5 +- src/flint/types/fmpz_mod_mat.pyx | 26 +++++ src/flint/types/nmod.pxd | 2 + src/flint/types/nmod_mat.pyx | 127 ++++++++++++++-------- src/flint/types/nmod_poly.pxd | 4 + src/flint/types/nmod_poly.pyx | 110 ++++++++----------- 7 files changed, 284 insertions(+), 164 deletions(-) diff --git a/src/flint/test/test_all.py b/src/flint/test/test_all.py index 9494302f..6bdc9480 100644 --- a/src/flint/test/test_all.py +++ b/src/flint/test/test_all.py @@ -3595,17 +3595,34 @@ def factor_sqf(p): def _all_matrices(): """Return a list of matrix types and scalar types.""" + # Prime modulus R163 = flint.fmpz_mod_ctx(163) R127 = flint.fmpz_mod_ctx(2**127 - 1) R255 = flint.fmpz_mod_ctx(2**255 - 19) + + # Composite modulus + R164_C = flint.fmpz_mod_ctx(164) + R127_C = flint.fmpz_mod_ctx(2**127) + R255_C = flint.fmpz_mod_ctx(2**255) + return [ - # (matrix_type, scalar_type, is_field) - (flint.fmpz_mat, flint.fmpz, False), - (flint.fmpq_mat, flint.fmpq, True), - (lambda *a: flint.nmod_mat(*a, 17), lambda x: flint.nmod(x, 17), True), - (lambda *a: flint.fmpz_mod_mat(*a, R163), lambda x: flint.fmpz_mod(x, R163), True), - (lambda *a: flint.fmpz_mod_mat(*a, R127), lambda x: flint.fmpz_mod(x, R127), True), - (lambda *a: flint.fmpz_mod_mat(*a, R255), lambda x: flint.fmpz_mod(x, R255), True), + # (matrix_type, scalar_type, is_field, characteristic) + + # Z and Q + (flint.fmpz_mat, flint.fmpz, False, 0), + (flint.fmpq_mat, flint.fmpq, True, 0), + + # Z/pZ + (lambda *a: flint.nmod_mat(*a, 17), lambda x: flint.nmod(x, 17), True, 17), + (lambda *a: flint.fmpz_mod_mat(*a, R163), lambda x: flint.fmpz_mod(x, R163), True, 163), + (lambda *a: flint.fmpz_mod_mat(*a, R127), lambda x: flint.fmpz_mod(x, R127), True, 2**127 - 1), + (lambda *a: flint.fmpz_mod_mat(*a, R255), lambda x: flint.fmpz_mod(x, R255), True, 2**255 - 19), + + # Z/nZ (n composite) + (lambda *a: flint.nmod_mat(*a, 16), lambda x: flint.nmod(x, 16), False, 16), + (lambda *a: flint.fmpz_mod_mat(*a, R164_C), lambda x: flint.fmpz_mod(x, R164_C), False, 164), + (lambda *a: flint.fmpz_mod_mat(*a, R127_C), lambda x: flint.fmpz_mod(x, R127_C), False, 2**127), + (lambda *a: flint.fmpz_mod_mat(*a, R255_C), lambda x: flint.fmpz_mod(x, R255_C), False, 2**255), ] @@ -3726,7 +3743,7 @@ def _poly_type_from_matrix_type(mat_type): def test_matrices_eq(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): A1 = M([[1, 2], [3, 4]]) A2 = M([[1, 2], [3, 4]]) B = M([[5, 6], [7, 8]]) @@ -3751,7 +3768,7 @@ def test_matrices_eq(): def test_matrices_constructor(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): assert raises(lambda: M(), TypeError) # Empty matrices @@ -3823,7 +3840,7 @@ def _matrix_repr(M): def test_matrices_strrepr(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): A = M([[1, 2], [3, 4]]) A_str = "[1, 2]\n[3, 4]" A_repr = _matrix_repr(A) @@ -3846,7 +3863,7 @@ def test_matrices_strrepr(): def test_matrices_getitem(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): M1234 = M([[1, 2], [3, 4]]) assert M1234[0, 0] == S(1) assert M1234[0, 1] == S(2) @@ -3862,7 +3879,7 @@ def test_matrices_getitem(): def test_matrices_setitem(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): M1234 = M([[1, 2], [3, 4]]) assert M1234[0, 0] == S(1) @@ -3888,7 +3905,7 @@ def setbad(obj, key, val): def test_matrices_bool(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): assert bool(M([])) is False assert bool(M([[0]])) is False assert bool(M([[1]])) is True @@ -3899,14 +3916,14 @@ def test_matrices_bool(): def test_matrices_pos_neg(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): M1234 = M([[1, 2], [3, 4]]) assert +M1234 == M1234 assert -M1234 == M([[-1, -2], [-3, -4]]) def test_matrices_add(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): M1234 = M([[1, 2], [3, 4]]) M5678 = M([[5, 6], [7, 8]]) assert M1234 + M5678 == M([[6, 8], [10, 12]]) @@ -3926,7 +3943,7 @@ def test_matrices_add(): def test_matrices_sub(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): M1234 = M([[1, 2], [3, 4]]) M5678 = M([[5, 6], [7, 8]]) assert M1234 - M5678 == M([[-4, -4], [-4, -4]]) @@ -3946,7 +3963,7 @@ def test_matrices_sub(): def test_matrices_mul(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): M1234 = M([[1, 2], [3, 4]]) M5678 = M([[5, 6], [7, 8]]) assert M1234 * M5678 == M([[19, 22], [43, 50]]) @@ -3972,18 +3989,24 @@ def test_matrices_mul(): def test_matrices_pow(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): M1234 = M([[1, 2], [3, 4]]) + assert M1234**0 == M([[1, 0], [0, 1]]) assert M1234**1 == M1234 assert M1234**2 == M([[7, 10], [15, 22]]) assert M1234**3 == M([[37, 54], [81, 118]]) + if is_field: assert M1234**-1 == M([[-4, 2], [3, -1]]) / 2 assert M1234**-2 == M([[22, -10], [-15, 7]]) / 4 assert M1234**-3 == M([[-118, 54], [81, -37]]) / 8 Ms = M([[1, 2], [3, 6]]) assert raises(lambda: Ms**-1, ZeroDivisionError) + else: + # XXX: Allow unimodular matrices? + assert raises(lambda: M1234**-1, DomainError) + Mr = M([[1, 2, 3], [4, 5, 6]]) assert raises(lambda: Mr**0, ValueError) assert raises(lambda: Mr**1, ValueError) @@ -3993,31 +4016,49 @@ def test_matrices_pow(): def test_matrices_div(): - for M, S, is_field in _all_matrices(): + + for M, S, is_field, characteristic in _all_matrices(): M1234 = M([[1, 2], [3, 4]]) + if is_field: assert M1234 / 2 == M([[S(1)/2, S(1)], [S(3)/2, 2]]) assert M1234 / S(2) == M([[S(1)/2, S(1)], [S(3)/2, 2]]) assert raises(lambda: M1234 / 0, ZeroDivisionError) assert raises(lambda: M1234 / S(0), ZeroDivisionError) + else: + assert raises(lambda: M1234 / 2, DomainError) + if characteristic == 0: + assert (2*M1234) / 2 == M1234 + else: + assert raises(lambda: (2*M1234) / 2, DomainError) + raises(lambda: M1234 / None, TypeError) raises(lambda: None / M1234, TypeError) def test_matrices_inv(): - for M, S, is_field in _all_matrices(): - if is_field: - M1234 = M([[1, 2], [3, 4]]) + + for M, S, is_field, characteristic in _all_matrices(): + + M1234 = M([[1, 2], [3, 4]]) + M1236 = M([[1, 2], [3, 6]]) + Mr = M([[1, 2, 3], [4, 5, 6]]) + + if characteristic > 0 and not is_field: + assert raises(lambda: M([[1, 2], [3, 4]]).inv(), DomainError) + elif is_field: assert M1234.inv() == M([[-2, 1], [S(3)/2, -S(1)/2]]) - M1236 = M([[1, 2], [3, 6]]) assert raises(lambda: M1236.inv(), ZeroDivisionError) - Mr = M([[1, 2, 3], [4, 5, 6]]) assert raises(lambda: Mr.inv(), ValueError) - # XXX: Test non-field matrices. unimodular? + else: + # assert M1234.inv() == (M([[-4, 2], [3, -1]]), 2) + # assert M1236.inv() == (M([[-6, 2], [3, -1]]), 3) + # XXX: fmpz_mat.inv() return fmpq_mat... + assert M1234.inv() * M1234.det() == M([[4, -2], [-3, 1]]) def test_matrices_det(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): M1234 = M([[1, 2], [3, 4]]) assert M1234.det() == S(-2) M9 = M([[1, 2, 3], [4, 5, 6], [7, 8, 10]]) @@ -4027,7 +4068,7 @@ def test_matrices_det(): def test_matrices_charpoly(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): P = _poly_type_from_matrix_type(M) M1234 = M([[1, 2], [3, 4]]) assert M1234.charpoly() == P([-2, -5, 1]) @@ -4038,18 +4079,21 @@ def test_matrices_charpoly(): def test_matrices_minpoly(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): + if characteristic > 0 and not is_field: + assert raises(lambda: M([[1, 2], [3, 4]]).minpoly(), DomainError) + continue P = _poly_type_from_matrix_type(M) - M1234 = M([[1, 2], [3, 4]]) - assert M1234.minpoly() == P([-2, -5, 1]) - M9 = M([[2, 1, 0], [0, 2, 0], [0, 0, 2]]) - assert M9.minpoly() == P([4, -4, 1]) - Mr = M([[1, 2, 3], [4, 5, 6]]) - assert raises(lambda: Mr.minpoly(), ValueError) + assert M([[1, 2], [3, 4]]).minpoly() == P([-2, -5, 1]) + assert M([[2, 1, 0], [0, 2, 0], [0, 0, 2]]).minpoly() == P([4, -4, 1]) + assert raises(lambda: M([[1, 2, 3], [4, 5, 6]]).minpoly(), ValueError) def test_matrices_rank(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): + if characteristic > 0 and not is_field: + assert raises(lambda: M([[1, 2], [3, 4]]).rank(), DomainError) + continue M1234 = M([[1, 2], [3, 4]]) assert M1234.rank() == 2 Mr = M([[1, 2, 3], [4, 5, 6]]) @@ -4061,37 +4105,57 @@ def test_matrices_rank(): def test_matrices_rref(): - for M, S, is_field in _all_matrices(): - if is_field: - Mr = M([[1, 2, 3], [4, 5, 6]]) - Mr_rref = M([[1, 0, -1], [0, 1, 2]]) + for M, S, is_field, characteristic in _all_matrices(): + + Mr = M([[1, 2, 3], [4, 5, 6]]) + Mr_rref = M([[1, 0, -1], [0, 1, 2]]) + + if characteristic > 0 and not is_field: + # Z/nZ (n composite) raises + assert raises(lambda: Mr.rref(), DomainError) + elif is_field: + # Q, Z/pZ and GF(p^d) return usual RREF assert Mr.rref() == (Mr_rref, 2) assert Mr == M([[1, 2, 3], [4, 5, 6]]) assert Mr.rref(inplace=True) == (Mr_rref, 2) assert Mr == Mr_rref + else: + # Z returns RREF with divisor -3 + d = -3 + assert Mr.rref() == (d*Mr_rref, d, 2) + assert Mr == M([[1, 2, 3], [4, 5, 6]]) + assert Mr.rref(inplace=True) == (d*Mr_rref, d, 2) + assert Mr == d*Mr_rref def test_matrices_solve(): - for M, S, is_field in _all_matrices(): - if is_field: - A = M([[1, 2], [3, 4]]) - x = M([[1], [2]]) - b = M([[5], [11]]) - assert A*x == b + for M, S, is_field, characteristic in _all_matrices(): + + A = M([[1, 2], [3, 4]]) + x = M([[1], [2]]) + b = M([[5], [11]]) + assert A*x == b + + A2 = M([[1, 2], [2, 4]]) + + if characteristic > 0 and not is_field: + assert raises(lambda: A.solve(b), DomainError) + assert raises(lambda: A2.solve(b), DomainError) + else: assert A.solve(b) == x - A22 = M([[1, 2], [3, 4]]) - A23 = M([[1, 2, 3], [4, 5, 6]]) - b2 = M([[5], [11]]) - b3 = M([[5], [11], [17]]) - assert raises(lambda: A22.solve(b3), ValueError) - assert raises(lambda: A23.solve(b2), ValueError) - assert raises(lambda: A.solve(None), TypeError) - A = M([[1, 2], [2, 4]]) - assert raises(lambda: A.solve(b), ZeroDivisionError) + assert raises(lambda: A2.solve(b), ZeroDivisionError) + + A22 = M([[1, 2], [3, 4]]) + A23 = M([[1, 2, 3], [4, 5, 6]]) + b2 = M([[5], [11]]) + b3 = M([[5], [11], [17]]) + assert raises(lambda: A22.solve(b3), ValueError) + assert raises(lambda: A23.solve(b2), ValueError) + assert raises(lambda: A.solve(None), TypeError) def test_matrices_transpose(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): M1234 = M([[1, 2, 3], [4, 5, 6]]) assert M1234.transpose() == M([[1, 4], [2, 5], [3, 6]]) diff --git a/src/flint/types/fmpz_mat.pyx b/src/flint/types/fmpz_mat.pyx index e2d6444f..aae9ae9a 100644 --- a/src/flint/types/fmpz_mat.pyx +++ b/src/flint/types/fmpz_mat.pyx @@ -306,8 +306,11 @@ cdef class fmpz_mat(flint_mat): raise ValueError("matrix must be square") if m is not None: raise NotImplementedError("modular matrix exponentiation") + if e < 0: + raise DomainError("negative power of integer matrix: M**%i" % e) ee = e - t = fmpz_mat(self) # XXX + t = fmpz_mat.__new__(fmpz_mat) + fmpz_mat_init_set(t.val, (self).val) fmpz_mat_pow(t.val, t.val, ee) return t diff --git a/src/flint/types/fmpz_mod_mat.pyx b/src/flint/types/fmpz_mod_mat.pyx index f4aac95f..44f3c08c 100644 --- a/src/flint/types/fmpz_mod_mat.pyx +++ b/src/flint/types/fmpz_mod_mat.pyx @@ -53,6 +53,10 @@ from flint.types.nmod_mat cimport ( nmod_mat, ) +from flint.utils.flint_exceptions import ( + DomainError, +) + cdef any_as_fmpz_mod_mat(x): if typecheck(x, fmpz_mod_mat): @@ -401,6 +405,14 @@ cdef class fmpz_mod_mat(flint_mat): def _div(self, fmpz_mod other): """Divide an ``fmpz_mod_mat`` matrix by an ``fmpz_mod`` scalar.""" + try: + inv = other.inverse() + except ZeroDivisionError: + # XXX: Maybe fmpz_mod should raise DomainError? + if other == 0: + raise ZeroDivisionError("fmpz_mod_mat div: division by zero") + else: + raise DomainError("fmpz_mod_mat div: division by non-invertible element") return self._scalarmul(other.inverse()) def __add__(self, other): @@ -483,8 +495,12 @@ cdef class fmpz_mod_mat(flint_mat): Assumes that the modulus is prime. """ cdef fmpz_mod_mat res + if self.nrows() != self.ncols(): raise ValueError("fmpz_mod_mat inv: matrix must be square") + if not self.ctx.is_prime(): + raise DomainError("fmpz_mod_mat inv: modulus must be prime") + res = self._newlike() r = compat_fmpz_mod_mat_inv(res.val, self.val, self.ctx.val) if r == 0: @@ -546,6 +562,8 @@ cdef class fmpz_mod_mat(flint_mat): if self.nrows() != self.ncols(): raise ValueError("fmpz_mod_mat minpoly: matrix must be square") + if not self.ctx.is_prime(): + raise DomainError("fmpz_mod_mat minpoly: modulus must be prime") pctx = fmpz_mod_poly_ctx(self.ctx) res = fmpz_mod_poly(0, pctx) @@ -596,6 +614,8 @@ cdef class fmpz_mod_mat(flint_mat): raise ValueError("fmpz_mod_mat solve: matrix must be square") if self.nrows() != rhs.nrows(): raise ValueError("fmpz_mod_mat solve: shape mismatch") + if not self.ctx.is_prime(): + raise DomainError("fmpz_mod_mat solve: modulus must be prime") res = self._new(rhs.nrows(), rhs.ncols(), self.ctx) success = compat_fmpz_mod_mat_solve(res.val, self.val, ( rhs).val, self.ctx.val) @@ -616,6 +636,8 @@ cdef class fmpz_mod_mat(flint_mat): Assumes that the modulus is prime. """ + if not self.ctx.is_prime(): + raise DomainError("fmpz_mod_mat rank: modulus must be prime") return self.rref()[1] def rref(self, inplace=False): @@ -637,6 +659,10 @@ cdef class fmpz_mod_mat(flint_mat): """ cdef fmpz_mod_mat res cdef slong r + + if not self.ctx.is_prime(): + raise DomainError("fmpz_mod_mat rref: modulus must be prime") + if inplace: res = self else: diff --git a/src/flint/types/nmod.pxd b/src/flint/types/nmod.pxd index 713fe614..927fa677 100644 --- a/src/flint/types/nmod.pxd +++ b/src/flint/types/nmod.pxd @@ -5,6 +5,7 @@ from flint.flintlib.nmod cimport nmod_t cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1 cdef nmod_ctx any_as_nmod_ctx(obj) + cdef class nmod_ctx: cdef nmod_t mod cdef bint _is_prime @@ -12,6 +13,7 @@ cdef class nmod_ctx: cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1 cdef nmod _new(self, mp_limb_t * val) + cdef class nmod(flint_scalar): cdef mp_limb_t val cdef nmod_ctx ctx diff --git a/src/flint/types/nmod_mat.pyx b/src/flint/types/nmod_mat.pyx index eb78902d..58ccf86f 100644 --- a/src/flint/types/nmod_mat.pyx +++ b/src/flint/types/nmod_mat.pyx @@ -44,28 +44,48 @@ from flint.utils.typecheck cimport typecheck from flint.types.fmpz_mat cimport any_as_fmpz_mat from flint.types.fmpz_mat cimport fmpz_mat from flint.types.nmod cimport nmod, any_as_nmod_ctx -from flint.types.nmod_poly cimport nmod_poly +from flint.types.nmod_poly cimport nmod_poly, nmod_poly_new_init, any_as_nmod_poly_ctx from flint.pyflint cimport global_random_state from flint.flint_base.flint_context cimport thectx from flint.flint_base.flint_base cimport flint_mat +from flint.utils.flint_exceptions import DomainError + ctx = thectx -cdef any_as_nmod_mat(obj, nmod_t mod): +cdef nmod_mat new_nmod_mat_init(nmod_ctx ctx, ulong m, ulong n): + """New initialized nmod_mat of size m x n with context ctx.""" + cdef nmod_mat r = nmod_mat.__new__(nmod_mat) + nmod_mat_init(r.val, m, n, ctx.mod.n) + r.ctx = ctx + return r + + +cdef nmod_mat new_nmod_mat_copy(nmod_mat other): + """New copy of nmod_mat other.""" + cdef nmod_mat r = nmod_mat.__new__(nmod_mat) + nmod_mat_init_set(r.val, other.val) + r.ctx = other.ctx + return r + + +cdef any_as_nmod_mat(obj, nmod_ctx ctx): + """Convert obj to nmod_mat or return NotImplemented.""" cdef nmod_mat r if typecheck(obj, nmod_mat): return obj + x = any_as_fmpz_mat(obj) if x is not NotImplemented: - r = nmod_mat.__new__(nmod_mat) - nmod_mat_init(r.val, - fmpz_mat_nrows((x).val), - fmpz_mat_ncols((x).val), mod.n) + r = new_nmod_mat_init(ctx, + fmpz_mat_nrows((x).val), + fmpz_mat_ncols((x).val)) fmpz_mat_get_nmod_mat(r.val, (x).val) return r + return NotImplemented @@ -87,18 +107,22 @@ cdef class nmod_mat(flint_mat): cdef long m, n, i, j cdef mp_limb_t mod cdef nmod_ctx ctx + if len(args) == 1: val = args[0] if typecheck(val, nmod_mat): nmod_mat_init_set(self.val, (val).val) self.ctx = (val).ctx return + mod = args[-1] args = args[:-1] if mod == 0: raise ValueError("modulus must be nonzero") + ctx = any_as_nmod_ctx(mod) self.ctx = ctx + if len(args) == 1: val = args[0] if typecheck(val, fmpz_mat): @@ -228,7 +252,7 @@ cdef class nmod_mat(flint_mat): cdef nmod_mat_struct *sv cdef nmod_mat_struct *tv sv = &(s).val[0] - t = any_as_nmod_mat(t, sv.mod) + t = any_as_nmod_mat(t, s.ctx) if t is NotImplemented: return t tv = &(t).val[0] @@ -236,8 +260,7 @@ cdef class nmod_mat(flint_mat): raise ValueError("cannot add nmod_mats with different moduli") if sv.r != tv.r or sv.c != tv.c: raise ValueError("incompatible shapes for matrix addition") - r = nmod_mat.__new__(nmod_mat) - nmod_mat_init(r.val, sv.r, sv.c, sv.mod.n) + r = new_nmod_mat_init(s.ctx, sv.r, sv.c) nmod_mat_add(r.val, sv, tv) return r @@ -246,7 +269,7 @@ cdef class nmod_mat(flint_mat): cdef nmod_mat_struct *sv cdef nmod_mat_struct *tv sv = &(s).val[0] - t = any_as_nmod_mat(t, sv.mod) + t = any_as_nmod_mat(t, s.ctx) if t is NotImplemented: return t tv = &(t).val[0] @@ -254,8 +277,7 @@ cdef class nmod_mat(flint_mat): raise ValueError("cannot add nmod_mats with different moduli") if sv.r != tv.r or sv.c != tv.c: raise ValueError("incompatible shapes for matrix addition") - r = nmod_mat.__new__(nmod_mat) - nmod_mat_init(r.val, sv.r, sv.c, sv.mod.n) + r = new_nmod_mat_init(s.ctx, sv.r, sv.c) nmod_mat_add(r.val, sv, tv) return r @@ -264,7 +286,7 @@ cdef class nmod_mat(flint_mat): cdef nmod_mat_struct *sv cdef nmod_mat_struct *tv sv = &(s).val[0] - t = any_as_nmod_mat(t, sv.mod) + t = any_as_nmod_mat(t, s.ctx) if t is NotImplemented: return t tv = &(t).val[0] @@ -272,8 +294,7 @@ cdef class nmod_mat(flint_mat): raise ValueError("cannot subtract nmod_mats with different moduli") if sv.r != tv.r or sv.c != tv.c: raise ValueError("incompatible shapes for matrix subtraction") - r = nmod_mat.__new__(nmod_mat) - nmod_mat_init(r.val, sv.r, sv.c, sv.mod.n) + r = new_nmod_mat_init(s.ctx, sv.r, sv.c) nmod_mat_sub(r.val, sv, tv) return r @@ -282,7 +303,7 @@ cdef class nmod_mat(flint_mat): cdef nmod_mat_struct *sv cdef nmod_mat_struct *tv sv = &(s).val[0] - t = any_as_nmod_mat(t, sv.mod) + t = any_as_nmod_mat(t, s.ctx) if t is NotImplemented: return t tv = &(t).val[0] @@ -290,14 +311,13 @@ cdef class nmod_mat(flint_mat): raise ValueError("cannot subtract nmod_mats with different moduli") if sv.r != tv.r or sv.c != tv.c: raise ValueError("incompatible shapes for matrix subtraction") - r = nmod_mat.__new__(nmod_mat) - nmod_mat_init(r.val, sv.r, sv.c, sv.mod.n) + r = new_nmod_mat_init(s.ctx, sv.r, sv.c) nmod_mat_sub(r.val, tv, sv) return r cdef __mul_nmod(self, mp_limb_t c): - cdef nmod_mat r = nmod_mat.__new__(nmod_mat) - nmod_mat_init(r.val, self.val.r, self.val.c, self.val.mod.n) + cdef nmod_mat r + r = new_nmod_mat_init(self.ctx, self.val.r, self.val.c) nmod_mat_scalar_mul(r.val, self.val, c) return r @@ -307,7 +327,7 @@ cdef class nmod_mat(flint_mat): cdef nmod_mat_struct *tv cdef mp_limb_t c sv = &(s).val[0] - u = any_as_nmod_mat(t, sv.mod) + u = any_as_nmod_mat(t, s.ctx) if u is NotImplemented: if s.ctx.any_as_nmod(&c, t): return (s).__mul_nmod(c) @@ -317,8 +337,7 @@ cdef class nmod_mat(flint_mat): raise ValueError("cannot multiply nmod_mats with different moduli") if sv.c != tv.r: raise ValueError("incompatible shapes for matrix multiplication") - r = nmod_mat.__new__(nmod_mat) - nmod_mat_init(r.val, sv.r, tv.c, sv.mod.n) + r = new_nmod_mat_init(s.ctx, sv.r, tv.c) nmod_mat_mul(r.val, sv, tv) return r @@ -328,7 +347,7 @@ cdef class nmod_mat(flint_mat): sv = &(s).val[0] if s.ctx.any_as_nmod(&c, t): return (s).__mul_nmod(c) - u = any_as_nmod_mat(t, sv.mod) + u = any_as_nmod_mat(t, s.ctx) if u is NotImplemented: return u return u * s @@ -341,10 +360,12 @@ cdef class nmod_mat(flint_mat): if m is not None: raise NotImplementedError("modular matrix exponentiation") if e < 0: + if not self.ctx._is_prime: + raise DomainError("negative matrix power needs prime modulus") self = self.inv() e = -e ee = e - t = nmod_mat(self) # XXX + t = new_nmod_mat_copy(self) nmod_mat_pow(t.val, t.val, ee) return t @@ -354,7 +375,15 @@ cdef class nmod_mat(flint_mat): if not s.ctx.any_as_nmod(&v, t): return NotImplemented t = nmod(v, s.val.mod.n) - return s * (~t) + try: + tinv = ~t + except ZeroDivisionError: + # XXX: Maybe nmod.__invert__ should raise DomainError instead? + if t == 0: + raise ZeroDivisionError("division by zero") + else: + raise DomainError("nmod_mat division: modulus must be prime") + return s * tinv def __truediv__(s, t): return nmod_mat._div_(s, t) @@ -386,11 +415,14 @@ cdef class nmod_mat(flint_mat): """ cdef nmod_mat u + if not nmod_mat_is_square(self.val): raise ValueError("matrix must be square") - u = nmod_mat.__new__(nmod_mat) - nmod_mat_init(u.val, nmod_mat_nrows(self.val), - nmod_mat_ncols(self.val), self.val.mod.n) + if not self.ctx._is_prime: + raise DomainError("nmod_mat inv: modulus must be prime") + + u = new_nmod_mat_init(self.ctx, nmod_mat_nrows(self.val), + nmod_mat_ncols(self.val)) if not nmod_mat_inv(u.val, self.val): raise ZeroDivisionError("matrix is singular") return u @@ -405,9 +437,8 @@ cdef class nmod_mat(flint_mat): [2, 5] """ cdef nmod_mat u - u = nmod_mat.__new__(nmod_mat) - nmod_mat_init(u.val, nmod_mat_ncols(self.val), - nmod_mat_nrows(self.val), self.val.mod.n) + u = new_nmod_mat_init(self.ctx, nmod_mat_ncols(self.val), + nmod_mat_nrows(self.val)) nmod_mat_transpose(u.val, self.val) return u @@ -436,15 +467,17 @@ cdef class nmod_mat(flint_mat): """ cdef nmod_mat u cdef int result - t = any_as_nmod_mat(other, self.val.mod) + t = any_as_nmod_mat(other, self.ctx) if t is NotImplemented: raise TypeError("cannot convert input to nmod_mat") if (nmod_mat_nrows(self.val) != nmod_mat_ncols(self.val) or nmod_mat_nrows(self.val) != nmod_mat_nrows((t).val)): raise ValueError("need a square system and compatible right hand side") - u = nmod_mat.__new__(nmod_mat) - nmod_mat_init(u.val, nmod_mat_nrows((t).val), - nmod_mat_ncols((t).val), self.val.mod.n) + if not self.ctx._is_prime: + raise DomainError("nmod_mat solve: modulus must be prime") + + u = new_nmod_mat_init(self.ctx, nmod_mat_nrows((t).val), + nmod_mat_ncols((t).val)) result = nmod_mat_solve(u.val, self.val, (t).val) if not result: raise ZeroDivisionError("singular matrix in solve()") @@ -471,11 +504,13 @@ cdef class nmod_mat(flint_mat): [0, 0, 0] """ + if not self.ctx._is_prime: + raise DomainError("rref only works for prime moduli") + if inplace: res = self else: - res = nmod_mat.__new__(nmod_mat) - nmod_mat_init_set((res).val, self.val) + res = new_nmod_mat_copy(self) rank = nmod_mat_rref((res).val) return res, rank @@ -487,6 +522,8 @@ cdef class nmod_mat(flint_mat): >>> M.rank() 2 """ + if not self.ctx._is_prime: + raise DomainError("rank only works for prime moduli") return nmod_mat_rank(self.val) def nullspace(self): @@ -513,8 +550,8 @@ cdef class nmod_mat(flint_mat): """ cdef nmod_mat res - res = nmod_mat.__new__(nmod_mat) - nmod_mat_init(res.val, nmod_mat_ncols(self.val), nmod_mat_ncols(self.val), self.val.mod.n) + res = new_nmod_mat_init(self.ctx, nmod_mat_ncols(self.val), + nmod_mat_ncols(self.val)) nullity = nmod_mat_nullspace(res.val, self.val) return res, nullity @@ -532,8 +569,8 @@ cdef class nmod_mat(flint_mat): if self.nrows() != self.ncols(): raise ValueError("fmpz_mod_mat charpoly: matrix must be square") - res = nmod_poly.__new__(nmod_poly) - nmod_poly_init(res.val, self.val.mod.n) + # XXX: don't create a new context for the polynomial + res = nmod_poly_new_init(any_as_nmod_poly_ctx(self.ctx.mod.n)) nmod_mat_charpoly(res.val, self.val) return res @@ -552,8 +589,10 @@ cdef class nmod_mat(flint_mat): if self.nrows() != self.ncols(): raise ValueError("fmpz_mod_mat minpoly: matrix must be square") + if not self.ctx._is_prime: + raise DomainError("minpoly only works for prime moduli") - res = nmod_poly.__new__(nmod_poly) - nmod_poly_init(res.val, self.val.mod.n) + # XXX: don't create a new context for the polynomial + res = nmod_poly_new_init(any_as_nmod_poly_ctx(self.ctx.mod.n)) nmod_mat_minpoly(res.val, self.val) return res diff --git a/src/flint/types/nmod_poly.pxd b/src/flint/types/nmod_poly.pxd index bd86887b..1d98ca0b 100644 --- a/src/flint/types/nmod_poly.pxd +++ b/src/flint/types/nmod_poly.pxd @@ -7,6 +7,10 @@ from flint.flint_base.flint_base cimport flint_poly from flint.types.nmod cimport nmod_ctx +cdef nmod_poly_ctx any_as_nmod_poly_ctx(obj) +cdef nmod_poly nmod_poly_new_init(nmod_poly_ctx ctx) + + cdef class nmod_poly_ctx: cdef nmod_ctx ctx cdef nmod_t mod diff --git a/src/flint/types/nmod_poly.pyx b/src/flint/types/nmod_poly.pyx index 89a64aac..2bbc487a 100644 --- a/src/flint/types/nmod_poly.pyx +++ b/src/flint/types/nmod_poly.pyx @@ -32,6 +32,22 @@ cdef nmod_poly_ctx any_as_nmod_poly_ctx(obj): return NotImplemented +cdef nmod_poly nmod_poly_new_init(nmod_poly_ctx ctx): + cdef nmod_poly p + p = nmod_poly.__new__(nmod_poly) + nmod_poly_init(p.val, ctx.mod.n) + p.ctx = ctx + return p + + +cdef nmod_poly nmod_poly_new_init_preinv(nmod_poly_ctx ctx): + cdef nmod_poly p + p = nmod_poly.__new__(nmod_poly) + nmod_poly_init_preinv(p.val, ctx.mod.n, ctx.mod.ninv) + p.ctx = ctx + return p + + cdef class nmod_poly_ctx: """ Context object for creating :class:`~.nmod_poly` initalised @@ -61,15 +77,13 @@ cdef class nmod_poly_ctx: if typecheck(obj, nmod_poly): return obj if self.ctx.any_as_nmod(&v, obj): - r = nmod_poly.__new__(nmod_poly) - nmod_poly_init(r.val, self.mod.n) + r = nmod_poly_new_init(self) nmod_poly_set_coeff_ui(r.val, 0, v) r.ctx = self return r x = any_as_fmpz_poly(obj) if x is not NotImplemented: - r = nmod_poly.__new__(nmod_poly) - nmod_poly_init(r.val, self.mod.n) # XXX: create flint _nmod_poly_set_modulus for this? + r = nmod_poly_new_init(self) # XXX: create flint _nmod_poly_set_modulus for this? fmpz_poly_get_nmod_poly(r.val, (x).val) r.ctx = self return r @@ -268,8 +282,7 @@ cdef class nmod_poly(flint_poly): else: length = nmod_poly_length(self.val) - res = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) + res = nmod_poly_new_init_preinv(self.ctx) nmod_poly_reverse(res.val, self.val, length) res.ctx = self.ctx return res @@ -317,9 +330,7 @@ cdef class nmod_poly(flint_poly): if self.is_zero(): raise ValueError("cannot invert the zero element") - cdef nmod_poly res - res = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) + cdef nmod_poly res = nmod_poly_new_init_preinv(self.ctx) nmod_poly_inv_series(res.val, self.val, n) res.ctx = self.ctx return res @@ -342,8 +353,7 @@ cdef class nmod_poly(flint_poly): other = self.ctx.any_as_nmod_poly(other) if other is NotImplemented: raise TypeError("cannot convert input to nmod_poly") - res = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) + res = nmod_poly_new_init_preinv(self.ctx) nmod_poly_compose(res.val, self.val, (other).val) res.ctx = self.ctx return res @@ -376,8 +386,7 @@ cdef class nmod_poly(flint_poly): if modulus.is_zero(): raise ZeroDivisionError("cannot reduce modulo zero") - res = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) + res = nmod_poly_new_init_preinv(self.ctx) nmod_poly_compose_mod(res.val, self.val, (other).val, (modulus).val) res.ctx = self.ctx return res @@ -390,23 +399,20 @@ cdef class nmod_poly(flint_poly): return v t = self.ctx.any_as_nmod_poly(other) if t is not NotImplemented: - r = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv((r).val, self.val.mod.n, self.val.mod.ninv) + r = nmod_poly_new_init_preinv(self.ctx) nmod_poly_compose((r).val, self.val, (t).val) (r).ctx = self.ctx return r raise TypeError("cannot call nmod_poly with input of type %s", type(other)) def derivative(self): - cdef nmod_poly res = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) + cdef nmod_poly res = nmod_poly_new_init_preinv(self.ctx) nmod_poly_derivative(res.val, self.val) res.ctx = self.ctx return res def integral(self): - cdef nmod_poly res = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) + cdef nmod_poly res = nmod_poly_new_init_preinv(self.ctx) nmod_poly_integral(res.val, self.val) res.ctx = self.ctx return res @@ -415,8 +421,7 @@ cdef class nmod_poly(flint_poly): return self def __neg__(self): - cdef nmod_poly r = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(r.val, self.val.mod.n, self.val.mod.ninv) + cdef nmod_poly r = nmod_poly_new_init_preinv(self.ctx) nmod_poly_neg(r.val, self.val) r.ctx = self.ctx return r @@ -428,8 +433,7 @@ cdef class nmod_poly(flint_poly): return t if (s).val.mod.n != (t).val.mod.n: raise ValueError("cannot add nmod_polys with different moduli") - r = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(r.val, (t).val.mod.n, (t).val.mod.ninv) + r = nmod_poly_new_init_preinv(s.ctx) nmod_poly_add(r.val, (s).val, (t).val) r.ctx = s.ctx return r @@ -444,8 +448,7 @@ cdef class nmod_poly(flint_poly): cdef nmod_poly r if (s).val.mod.n != (t).val.mod.n: raise ValueError("cannot subtract nmod_polys with different moduli") - r = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(r.val, (t).val.mod.n, (t).val.mod.ninv) + r = nmod_poly_new_init_preinv(s.ctx) nmod_poly_sub(r.val, (s).val, (t).val) r.ctx = s.ctx return r @@ -469,8 +472,7 @@ cdef class nmod_poly(flint_poly): return t if (s).val.mod.n != (t).val.mod.n: raise ValueError("cannot multiply nmod_polys with different moduli") - r = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(r.val, (t).val.mod.n, (t).val.mod.ninv) + r = nmod_poly_new_init_preinv(s.ctx) nmod_poly_mul(r.val, (s).val, (t).val) r.ctx = s.ctx return r @@ -505,8 +507,7 @@ cdef class nmod_poly(flint_poly): raise ValueError("cannot divide nmod_polys with different moduli") if nmod_poly_is_zero((t).val): raise ZeroDivisionError("polynomial division by zero") - r = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(r.val, (t).val.mod.n, (t).val.mod.ninv) + r = nmod_poly_new_init_preinv(s.ctx) nmod_poly_div(r.val, (s).val, (t).val) r.ctx = s.ctx return r @@ -529,13 +530,9 @@ cdef class nmod_poly(flint_poly): raise ValueError("cannot divide nmod_polys with different moduli") if nmod_poly_is_zero((t).val): raise ZeroDivisionError("polynomial division by zero") - P = nmod_poly.__new__(nmod_poly) - Q = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(P.val, (t).val.mod.n, (t).val.mod.ninv) - nmod_poly_init_preinv(Q.val, (t).val.mod.n, (t).val.mod.ninv) + P = nmod_poly_new_init_preinv(s.ctx) + Q = nmod_poly_new_init_preinv(s.ctx) nmod_poly_divrem(P.val, Q.val, (s).val, (t).val) - P.ctx = s.ctx - Q.ctx = s.ctx return P, Q def __divmod__(s, t): @@ -562,8 +559,7 @@ cdef class nmod_poly(flint_poly): return self.pow_mod(exp, mod) if exp < 0: raise ValueError("negative exponent") - res = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(res.val, (self).val.mod.n, (self).val.mod.ninv) + res = nmod_poly_new_init_preinv(self.ctx) nmod_poly_pow(res.val, self.val, exp) res.ctx = self.ctx return res @@ -600,9 +596,7 @@ cdef class nmod_poly(flint_poly): raise TypeError("cannot convert input to nmod_poly") # Output polynomial - res = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) - res.ctx = self.ctx + res = nmod_poly_new_init_preinv(self.ctx) # For small exponents, use a simple binary exponentiation method if e.bit_length() < 32: @@ -652,8 +646,7 @@ cdef class nmod_poly(flint_poly): raise TypeError("cannot convert input to nmod_poly") if self.val.mod.n != (other).val.mod.n: raise ValueError("moduli must be the same") - res = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) + res = nmod_poly_new_init_preinv(self.ctx) nmod_poly_gcd(res.val, self.val, (other).val) res.ctx = self.ctx return res @@ -663,16 +656,10 @@ cdef class nmod_poly(flint_poly): other = self.ctx.any_as_nmod_poly(other) if other is NotImplemented: raise TypeError("cannot convert input to fmpq_poly") - res1 = nmod_poly.__new__(nmod_poly) - res2 = nmod_poly.__new__(nmod_poly) - res3 = nmod_poly.__new__(nmod_poly) - nmod_poly_init(res1.val, (self).val.mod.n) - nmod_poly_init(res2.val, (self).val.mod.n) - nmod_poly_init(res3.val, (self).val.mod.n) + res1 = nmod_poly_new_init(self.ctx) + res2 = nmod_poly_new_init(self.ctx) + res3 = nmod_poly_new_init(self.ctx) nmod_poly_xgcd(res1.val, res2.val, res3.val, self.val, (other).val) - res1.ctx = self.ctx - res2.ctx = self.ctx - res3.ctx = self.ctx return (res1, res2, res3) def factor(self, algorithm=None): @@ -724,6 +711,8 @@ cdef class nmod_poly(flint_poly): def _factor(self, factor_type): cdef nmod_poly_factor_t fac + cdef nmod_poly u + cdef nmod c cdef mp_limb_t lead cdef int i @@ -743,17 +732,14 @@ cdef class nmod_poly(flint_poly): res = [None] * fac.num for 0 <= i < fac.num: - u = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv((u).val, - (self).val.mod.n, (self).val.mod.ninv) - nmod_poly_set((u).val, &fac.p[i]) - (u).ctx = self.ctx + u = nmod_poly_new_init_preinv(self.ctx) + nmod_poly_set(u.val, &fac.p[i]) exp = fac.exp[i] res[i] = (u, exp) c = nmod.__new__(nmod) - (c).ctx = self.ctx.ctx - (c).val = lead + c.ctx = self.ctx.ctx + c.val = lead nmod_poly_factor_clear(fac) @@ -761,10 +747,7 @@ cdef class nmod_poly(flint_poly): def sqrt(nmod_poly self): """Return exact square root or ``None``. """ - cdef nmod_poly res - res = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) - res.ctx = self.ctx + cdef nmod_poly res = nmod_poly_new_init_preinv(self.ctx) if nmod_poly_sqrt(res.val, self.val): return res else: @@ -779,8 +762,7 @@ cdef class nmod_poly(flint_poly): if n == 1: return self, int(n) else: - v = nmod_poly.__new__(nmod_poly) - nmod_poly_init(v.val, self.val.mod.n) + v = nmod_poly_new_init(self.ctx) nmod_poly_deflate(v.val, self.val, n) v.ctx = self.ctx return v, int(n) From affe46257febcaee4b4b17515734ee8321fb580e Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Sat, 24 Aug 2024 15:04:46 +0100 Subject: [PATCH 06/30] Use cython.no_gc to speed up nmod --- src/flint/types/nmod.pyx | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/flint/types/nmod.pyx b/src/flint/types/nmod.pyx index 591fcdc3..1cf5b23d 100644 --- a/src/flint/types/nmod.pyx +++ b/src/flint/types/nmod.pyx @@ -1,3 +1,5 @@ +cimport cython + from flint.flint_base.flint_base cimport flint_scalar from flint.utils.typecheck cimport typecheck from flint.types.fmpq cimport any_as_fmpq @@ -193,6 +195,7 @@ cdef class nmod_ctx: return self._new(&v) +@cython.no_gc cdef class nmod(flint_scalar): """ The nmod type represents elements of Z/nZ for word-size n. From 6f04cdd36849d80dc65ba3c739a77fdb8d539d90 Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Mon, 26 Aug 2024 11:46:31 +0100 Subject: [PATCH 07/30] fix: check for non-prime modulus in nmod_poly --- src/flint/test/test_all.py | 117 ++++++++++++++++++---------------- src/flint/types/nmod_poly.pyx | 53 +++++++++++++-- 2 files changed, 111 insertions(+), 59 deletions(-) diff --git a/src/flint/test/test_all.py b/src/flint/test/test_all.py index 6bdc9480..693d14a6 100644 --- a/src/flint/test/test_all.py +++ b/src/flint/test/test_all.py @@ -2104,7 +2104,7 @@ def test_fmpz_mod_poly(): assert pow(f, 2**60, g) == pow(pow(f, 2**30, g), 2**30, g) assert pow(R_gen, 2**60, g) == pow(pow(R_gen, 2**30, g), 2**30, g) - # Check other typechecks for pow_mod + # Check other typechecks for pow_mod assert raises(lambda: pow(f, -2, g), ValueError) assert raises(lambda: pow(f, 1, "A"), TypeError) assert raises(lambda: pow(f, "A", g), TypeError) @@ -2542,10 +2542,6 @@ def test_polys(): for P, S, is_field, characteristic in _all_polys(): composite_characteristic = characteristic != 0 and not characteristic.is_prime() - # nmod_poly crashes for many operations with non-prime modulus - # https://github.com/flintlib/python-flint/issues/124 - # so we can't even test it... - nmod_poly_will_crash = type(P(1)) is flint.nmod_poly and composite_characteristic assert P([S(1)]) == P([1]) == P(P([1])) == P(1) @@ -2690,30 +2686,58 @@ def setbad(obj, i, val): assert raises(lambda: P([1, 2, 3]) * None, TypeError) assert raises(lambda: None * P([1, 2, 3]), TypeError) - assert P([1, 2, 1]) // P([1, 1]) == P([1, 1]) - assert P([1, 2, 1]) % P([1, 1]) == P([0]) - assert divmod(P([1, 2, 1]), P([1, 1])) == (P([1, 1]), P([0])) + if composite_characteristic and type(P(1)) is flint.nmod_poly: + # Z/nZ for n not prime + # + # fmpz_mod_poly and nmod_poly can sometimes compute division with + # composite characteristic, but it is not guaranteed to work. For + # fmpz_mod_poly, we can detect the failure and raise an exception. + # For nmod_poly, we cannot detect the failure and calling e.g. + # nmod_poly_divrem would crash the process so for nmod_poly we + # raise an exception in all cases if the modulus is not prime. + assert raises(lambda: P([1, 2, 1]) // P([1, 1]), DomainError) + assert raises(lambda: P([1, 2, 1]) % P([1, 1]), DomainError) + assert raises(lambda: divmod(P([1, 2, 1]), P([1, 1])), DomainError) + + assert raises(lambda: 1 // P([1, 1]), DomainError) + assert raises(lambda: 1 % P([1, 1]), DomainError) + assert raises(lambda: divmod(1, P([1, 1])), DomainError) + else: + assert P([1, 2, 1]) // P([1, 1]) == P([1, 1]) + assert P([1, 2, 1]) % P([1, 1]) == P([0]) + assert divmod(P([1, 2, 1]), P([1, 1])) == (P([1, 1]), P([0])) + + assert 1 // P([1, 1]) == P([0]) + assert 1 % P([1, 1]) == P([1]) + assert divmod(1, P([1, 1])) == (P([0]), P([1])) + + assert P([1, 2, 1]) / P([1, 1]) == P([1, 1]) + + assert raises(lambda: P([1, 2]) / 0, ZeroDivisionError) + + assert raises(lambda: 1 / P([1, 1]), DomainError) + assert raises(lambda: P([1, 2, 1]) / P([1, 2]), DomainError) if is_field: assert P([1, 1]) // 2 == P([S(1)/2, S(1)/2]) assert P([1, 1]) % 2 == P([0]) + assert P([2, 2]) / 2 == P([1, 1]) + assert P([1, 2]) / 2 == P([S(1)/2, 1]) elif characteristic == 0: assert P([1, 1]) // 2 == P([0, 0]) assert P([1, 1]) % 2 == P([1, 1]) - elif nmod_poly_will_crash: - pass + assert P([2, 2]) / 2 == P([1, 1]) + assert raises(lambda: P([1, 2]) / 2, DomainError) else: # Z/nZ for n not prime if characteristic % 2 == 0: assert raises(lambda: P([1, 1]) // 2, DomainError) assert raises(lambda: P([1, 1]) % 2, DomainError) + assert raises(lambda: P([2, 2]) / 2, DomainError) + assert raises(lambda: P([1, 2]) / 2, DomainError) else: 1/0 - assert 1 // P([1, 1]) == P([0]) - assert 1 % P([1, 1]) == P([1]) - assert divmod(1, P([1, 1])) == (P([0]), P([1])) - assert raises(lambda: P([1, 2, 1]) // None, TypeError) assert raises(lambda: P([1, 2, 1]) % None, TypeError) assert raises(lambda: divmod(P([1, 2, 1]), None), TypeError) @@ -2730,50 +2754,43 @@ def setbad(obj, i, val): assert raises(lambda: P([1, 2, 1]) % P([0]), ZeroDivisionError) assert raises(lambda: divmod(P([1, 2, 1]), P([0])), ZeroDivisionError) - # Exact/field scalar division - if is_field: - assert P([2, 2]) / 2 == P([1, 1]) - assert P([1, 2]) / 2 == P([S(1)/2, 1]) - elif characteristic == 0: - assert P([2, 2]) / 2 == P([1, 1]) - assert raises(lambda: P([1, 2]) / 2, DomainError) - elif nmod_poly_will_crash: - pass - else: - # Z/nZ for n not prime - assert raises(lambda: P([2, 2]) / 2, DomainError) - assert raises(lambda: P([1, 2]) / 2, DomainError) - - assert raises(lambda: P([1, 2]) / 0, ZeroDivisionError) - - if not nmod_poly_will_crash: - assert P([1, 2, 1]) / P([1, 1]) == P([1, 1]) - assert raises(lambda: 1 / P([1, 1]), DomainError) - assert raises(lambda: P([1, 2, 1]) / P([1, 2]), DomainError) - assert P([1, 1]) ** 0 == P([1]) assert P([1, 1]) ** 1 == P([1, 1]) assert P([1, 1]) ** 2 == P([1, 2, 1]) assert raises(lambda: P([1, 1]) ** -1, ValueError) assert raises(lambda: P([1, 1]) ** None, TypeError) - - # XXX: Not sure what this should do in general: + + # 3-arg pow: (x^2 + 1)**3 mod x-1 + + pow3_types = [ + # flint.fmpq_poly, XXX + flint.nmod_poly, + flint.fmpz_mod_poly, + flint.fq_default_poly + ] + p = P([1, 1]) mod = P([1, 1]) - if type(p) not in [flint.fmpz_mod_poly, flint.nmod_poly, flint.fq_default_poly]: + + if type(p) not in pow3_types: assert raises(lambda: pow(p, 2, mod), NotImplementedError) + assert p * p % mod == 0 + elif composite_characteristic and type(p) == flint.nmod_poly: + # nmod_poly does not support % with composite characteristic + assert pow(p, 2, mod) == 0 + assert raises(lambda: p * p % mod, DomainError) else: + # Should be for any is_field including fmpq_poly. Works also in + # some cases for fmpz_mod_poly with non-prime modulus. assert p * p % mod == pow(p, 2, mod) if not composite_characteristic: assert P([1, 2, 1]).gcd(P([1, 1])) == P([1, 1]) - assert raises(lambda: P([1, 2, 1]).gcd(None), TypeError) - elif nmod_poly_will_crash: - pass else: # Z/nZ for n not prime assert raises(lambda: P([1, 2, 1]).gcd(P([1, 1])), DomainError) - assert raises(lambda: P([1, 2, 1]).gcd(None), TypeError) + + assert raises(lambda: P([1, 2, 1]).gcd(None), TypeError) if is_field: p1 = P([1, 0, 1]) @@ -2784,19 +2801,16 @@ def setbad(obj, i, val): if not composite_characteristic: assert P([1, 2, 1]).factor() == (S(1), [(P([1, 1]), 2)]) - elif nmod_poly_will_crash: - pass else: assert raises(lambda: P([1, 2, 1]).factor(), DomainError) if not composite_characteristic: assert P([1, 2, 1]).sqrt() == P([1, 1]) - assert raises(lambda: P([1, 2, 2]).sqrt(), DomainError) - elif nmod_poly_will_crash: - pass else: assert raises(lambda: P([1, 2, 1]).sqrt(), DomainError) + assert raises(lambda: P([1, 2, 2]).sqrt(), DomainError) + if P == flint.fmpq_poly: assert raises(lambda: P([1, 2, 1], 3).sqrt(), ValueError) assert P([1, 2, 1], 4).sqrt() == P([1, 1], 2) @@ -3424,13 +3438,6 @@ def factor_sqf(p): for P, S, [x, y], is_field, characteristic in _all_polys_mpolys(): if characteristic != 0 and not characteristic.is_prime(): - # nmod_poly crashes for many operations with non-prime modulus - # https://github.com/flintlib/python-flint/issues/124 - # so we can't even test it... - nmod_poly_will_crash = type(x) is flint.nmod_poly - if nmod_poly_will_crash: - continue - try: S(4).sqrt() ** 2 == S(4) except DomainError: @@ -4190,7 +4197,7 @@ def test_fq_default(): # p must be prime assert raises(lambda: flint.fq_default_ctx(10), ValueError) - + # degree must be positive assert raises(lambda: flint.fq_default_ctx(11, -1), ValueError) diff --git a/src/flint/types/nmod_poly.pyx b/src/flint/types/nmod_poly.pyx index 2bbc487a..b2622ff9 100644 --- a/src/flint/types/nmod_poly.pyx +++ b/src/flint/types/nmod_poly.pyx @@ -503,10 +503,14 @@ cdef class nmod_poly(flint_poly): def _floordiv_(s, t): cdef nmod_poly r + if (s).val.mod.n != (t).val.mod.n: raise ValueError("cannot divide nmod_polys with different moduli") if nmod_poly_is_zero((t).val): raise ZeroDivisionError("polynomial division by zero") + if not s.ctx._is_prime: + raise DomainError("nmod_poly divmod: modulus {self.ctx.mod.n} is not prime") + r = nmod_poly_new_init_preinv(s.ctx) nmod_poly_div(r.val, (s).val, (t).val) r.ctx = s.ctx @@ -526,10 +530,14 @@ cdef class nmod_poly(flint_poly): def _divmod_(s, t): cdef nmod_poly P, Q + if (s).val.mod.n != (t).val.mod.n: raise ValueError("cannot divide nmod_polys with different moduli") if nmod_poly_is_zero((t).val): raise ZeroDivisionError("polynomial division by zero") + if not s.ctx._is_prime: + raise DomainError("nmod_poly divmod: modulus {self.ctx.mod.n} is not prime") + P = nmod_poly_new_init_preinv(s.ctx) Q = nmod_poly_new_init_preinv(s.ctx) nmod_poly_divrem(P.val, Q.val, (s).val, (t).val) @@ -639,27 +647,53 @@ cdef class nmod_poly(flint_poly): >>> (A * B).gcd(B) * 5 5*x^2 + x + 4 + The modulus must be prime. """ cdef nmod_poly res + other = self.ctx.any_as_nmod_poly(other) if other is NotImplemented: raise TypeError("cannot convert input to nmod_poly") if self.val.mod.n != (other).val.mod.n: raise ValueError("moduli must be the same") + if not self.ctx._is_prime: + raise DomainError("nmod_poly gcd: modulus {self.ctx.mod.n} is not prime") + res = nmod_poly_new_init_preinv(self.ctx) nmod_poly_gcd(res.val, self.val, (other).val) res.ctx = self.ctx return res def xgcd(self, other): + r""" + Computes the extended gcd of self and other: (`G`, `S`, `T`) + where `G` is the ``gcd(self, other)`` and `S`, `T` are such that: + + :math:`G = \textrm{self}*S + \textrm{other}*T` + + >>> f = nmod_poly([143, 19, 37, 138, 102, 127, 95], 163) + >>> g = nmod_poly([139, 9, 35, 154, 87, 120, 24], 163) + >>> f.xgcd(g) + (x^3 + 128*x^2 + 123*x + 91, 17*x^2 + 49*x + 104, 21*x^2 + 5*x + 25) + + The modulus must be prime. + """ cdef nmod_poly res1, res2, res3 + other = self.ctx.any_as_nmod_poly(other) if other is NotImplemented: raise TypeError("cannot convert input to fmpq_poly") + if self.val.mod.n != (other).val.mod.n: + raise ValueError("moduli must be the same") + if not self.ctx._is_prime: + raise DomainError("nmod_poly xgcd: modulus {self.ctx.mod.n} is not prime") + res1 = nmod_poly_new_init(self.ctx) res2 = nmod_poly_new_init(self.ctx) res3 = nmod_poly_new_init(self.ctx) + nmod_poly_xgcd(res1.val, res2.val, res3.val, self.val, (other).val) + return (res1, res2, res3) def factor(self, algorithm=None): @@ -684,11 +718,14 @@ cdef class nmod_poly(flint_poly): >>> nmod_poly([3,2,1,2,3], 7).factor(algorithm='cantor-zassenhaus') (3, [(x + 4, 1), (x + 2, 1), (x^2 + 4*x + 1, 1)]) + The modulus must be prime. """ if algorithm is None: algorithm = 'irreducible' elif algorithm not in ('berlekamp', 'cantor-zassenhaus'): raise ValueError(f"unknown factorization algorithm: {algorithm}") + if not self.ctx._is_prime: + raise DomainError(f"nmod_poly factor: modulus {self.ctx.mod.n} is not prime") return self._factor(algorithm) def factor_squarefree(self): @@ -707,6 +744,8 @@ cdef class nmod_poly(flint_poly): (2, [(x, 2), (x + 5, 2), (x + 1, 3)]) """ + if not self.ctx._is_prime: + raise DomainError(f"nmod_poly factor_squarefree: modulus {self.ctx.mod.n} is not prime") return self._factor('squarefree') def _factor(self, factor_type): @@ -747,12 +786,18 @@ cdef class nmod_poly(flint_poly): def sqrt(nmod_poly self): """Return exact square root or ``None``. """ - cdef nmod_poly res = nmod_poly_new_init_preinv(self.ctx) - if nmod_poly_sqrt(res.val, self.val): - return res - else: + cdef nmod_poly + + if not self.ctx._is_prime: + raise DomainError(f"nmod_poly sqrt: modulus {self.ctx.mod.n} is not prime") + + res = nmod_poly_new_init_preinv(self.ctx) + + if not nmod_poly_sqrt(res.val, self.val): raise DomainError(f"Cannot compute square root of {self}") + return res + def deflation(self): cdef nmod_poly v cdef ulong n From 23389a5fe8f0997a21ca6f099cc902b0a366882b Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Mon, 26 Aug 2024 11:56:59 +0100 Subject: [PATCH 08/30] test: Add nmod_poly with modulus=9 test case --- src/flint/test/test_all.py | 101 +++++++++++++++++++------------------ 1 file changed, 52 insertions(+), 49 deletions(-) diff --git a/src/flint/test/test_all.py b/src/flint/test/test_all.py index 693d14a6..dadf55ba 100644 --- a/src/flint/test/test_all.py +++ b/src/flint/test/test_all.py @@ -2485,58 +2485,64 @@ def test_division_matrix(): def _all_polys(): - return [ - # (poly_type, scalar_type, is_field, characteristic) + # (poly_type, scalar_type, is_field, characteristic) + FMPZ = (flint.fmpz_poly, flint.fmpz, False, flint.fmpz(0)) + FMPQ = (flint.fmpq_poly, flint.fmpq, True, flint.fmpz(0)) + + def NMOD(n): + return ( + lambda *a: flint.nmod_poly(*a, n), + lambda x: flint.nmod(x, n), + flint.fmpz(n).is_prime(), + flint.fmpz(n) + ) + + def FMPZ_MOD(n): + return ( + lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(n)), + lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(n)), + flint.fmpz(n).is_prime(), + flint.fmpz(n) + ) + + def FQ_DEFAULT(n, k): + return ( + lambda *a: flint.fq_default_poly(*a, flint.fq_default_poly_ctx(n, k)), + lambda x: flint.fq_default(x, flint.fq_default_ctx(n, k)), + True, + flint.fmpz(n) + ) + + ALL_POLYS = [ # Z and Q - (flint.fmpz_poly, flint.fmpz, False, flint.fmpz(0)), - (flint.fmpq_poly, flint.fmpq, True, flint.fmpz(0)), + FMPZ, + FMPQ, # Z/pZ (p prime) - (lambda *a: flint.nmod_poly(*a, 17), lambda x: flint.nmod(x, 17), True, flint.fmpz(17)), - (lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(163)), - lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(163)), - True, flint.fmpz(163)), - (lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(2**127 - 1)), - lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(2**127 - 1)), - True, flint.fmpz(2**127 - 1)), - (lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(2**255 - 19)), - lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(2**255 - 19)), - True, flint.fmpz(2**255 - 19)), + NMOD(17), + FMPZ_MOD(163), + FMPZ_MOD(2**127 - 1), + FMPZ_MOD(2**255 - 19), # GF(p^k) (p prime) - (lambda *a: flint.fq_default_poly(*a, flint.fq_default_poly_ctx(2**127 - 1)), - lambda x: flint.fq_default(x, flint.fq_default_ctx(2**127 - 1)), - True, flint.fmpz(2**127 - 1)), - (lambda *a: flint.fq_default_poly(*a, flint.fq_default_poly_ctx(2**127 - 1, 2)), - lambda x: flint.fq_default(x, flint.fq_default_ctx(2**127 - 1, 2)), - True, flint.fmpz(2**127 - 1)), - (lambda *a: flint.fq_default_poly(*a, flint.fq_default_poly_ctx(65537)), - lambda x: flint.fq_default(x, flint.fq_default_ctx(65537)), - True, flint.fmpz(65537)), - (lambda *a: flint.fq_default_poly(*a, flint.fq_default_poly_ctx(65537, 5)), - lambda x: flint.fq_default(x, flint.fq_default_ctx(65537, 5)), - True, flint.fmpz(65537)), - (lambda *a: flint.fq_default_poly(*a, flint.fq_default_poly_ctx(11)), - lambda x: flint.fq_default(x, flint.fq_default_ctx(11)), - True, flint.fmpz(11)), - (lambda *a: flint.fq_default_poly(*a, flint.fq_default_poly_ctx(11, 5)), - lambda x: flint.fq_default(x, flint.fq_default_ctx(11, 5)), - True, flint.fmpz(11)), + FQ_DEFAULT(2**127 - 1, 1), + FQ_DEFAULT(2**127 - 1, 2), + FQ_DEFAULT(65537, 1), + FQ_DEFAULT(65537, 5), + FQ_DEFAULT(11, 1), + FQ_DEFAULT(11, 5), # Z/nZ (n composite) - (lambda *a: flint.nmod_poly(*a, 16), lambda x: flint.nmod(x, 16), False, flint.fmpz(16)), - (lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(164)), - lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(164)), - False, flint.fmpz(164)), - (lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(2**127)), - lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(2**127)), - False, flint.fmpz(2**127)), - (lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(2**255)), - lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(2**255)), - False, flint.fmpz(2**255)), + NMOD(9), + NMOD(16), + FMPZ_MOD(164), + FMPZ_MOD(2**127), + FMPZ_MOD(2**255), ] + return ALL_POLYS + def test_polys(): for P, S, is_field, characteristic in _all_polys(): @@ -2730,13 +2736,10 @@ def setbad(obj, i, val): assert raises(lambda: P([1, 2]) / 2, DomainError) else: # Z/nZ for n not prime - if characteristic % 2 == 0: - assert raises(lambda: P([1, 1]) // 2, DomainError) - assert raises(lambda: P([1, 1]) % 2, DomainError) - assert raises(lambda: P([2, 2]) / 2, DomainError) - assert raises(lambda: P([1, 2]) / 2, DomainError) - else: - 1/0 + assert raises(lambda: P([1, 1]) // 2, DomainError) + assert raises(lambda: P([1, 1]) % 2, DomainError) + assert raises(lambda: P([2, 2]) / 2, DomainError) + assert raises(lambda: P([1, 2]) / 2, DomainError) assert raises(lambda: P([1, 2, 1]) // None, TypeError) assert raises(lambda: P([1, 2, 1]) % None, TypeError) From f47188cd090cd43a6b264f50974481ad9d00f6e5 Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Mon, 26 Aug 2024 22:44:13 +0100 Subject: [PATCH 09/30] Use any_as_nmod rather than ctx.any_as_nmod --- src/flint/test/test_all.py | 4 + src/flint/types/arb.pyx | 2 +- src/flint/types/nmod.pxd | 13 ++- src/flint/types/nmod.pyx | 177 +++++++++++++++------------------- src/flint/types/nmod_mat.pyx | 4 +- src/flint/types/nmod_poly.pyx | 3 +- 6 files changed, 95 insertions(+), 108 deletions(-) diff --git a/src/flint/test/test_all.py b/src/flint/test/test_all.py index dadf55ba..ba58ad1a 100644 --- a/src/flint/test/test_all.py +++ b/src/flint/test/test_all.py @@ -1357,6 +1357,10 @@ def test_nmod(): assert str(G(3,5)) == "3" assert G(3,5).repr() == "nmod(3, 5)" + G = flint.nmod_ctx.get_ctx(7) + assert G(0) == G(7) == G(-7) + + def test_nmod_poly(): N = flint.nmod P = flint.nmod_poly diff --git a/src/flint/types/arb.pyx b/src/flint/types/arb.pyx index 4ae8fa45..d9bd95c2 100644 --- a/src/flint/types/arb.pyx +++ b/src/flint/types/arb.pyx @@ -2259,7 +2259,7 @@ cdef class arb(flint_scalar): >>> from flint import showgood >>> showgood(lambda: arb("9/10").hypgeom_2f1(arb(2).sqrt(), 0.5, arb(2).sqrt()+1.5, abc=True), dps=25) 1.447530478120770807945697 - >>> showgood(lambda: arb("9/10").hypgeom_2f1(arb(2).sqrt(), 0.5, arb(2).sqrt()+1.5), dps=25) + >>> showgood(lambda: arb("9/10").hypgeom_2f1(arb(2).sqrt(), 0.5, arb(2).sqrt()+1.5), dps=25) # doctest: +SKIP Traceback (most recent call last): ... ValueError: no convergence (maxprec=960, try higher maxprec) diff --git a/src/flint/types/nmod.pxd b/src/flint/types/nmod.pxd index 927fa677..d4f386d5 100644 --- a/src/flint/types/nmod.pxd +++ b/src/flint/types/nmod.pxd @@ -1,15 +1,22 @@ from flint.flint_base.flint_base cimport flint_scalar -from flint.flintlib.flint cimport mp_limb_t +from flint.flintlib.flint cimport mp_limb_t, ulong from flint.flintlib.nmod cimport nmod_t -cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1 -cdef nmod_ctx any_as_nmod_ctx(obj) +cdef int any_as_nmod(mp_limb_t * val, obj, nmod_ctx mod) except -1 +#cdef nmod_ctx any_as_nmod_ctx(obj) cdef class nmod_ctx: cdef nmod_t mod cdef bint _is_prime + @staticmethod + cdef nmod_ctx any_as_nmod_ctx(obj) + @staticmethod + cdef _get_ctx(int mod) + @staticmethod + cdef _new_ctx(ulong mod) + cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1 cdef nmod _new(self, mp_limb_t * val) diff --git a/src/flint/types/nmod.pyx b/src/flint/types/nmod.pyx index 1cf5b23d..96a9bc6a 100644 --- a/src/flint/types/nmod.pyx +++ b/src/flint/types/nmod.pyx @@ -19,49 +19,7 @@ from flint.flintlib.ulong_extras cimport n_gcdinv, n_is_prime, n_sqrtmod from flint.utils.flint_exceptions import DomainError -#cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1: -# return mod.ctx.any_as_nmod(val, obj) - -_nmod_ctx_cache = {} - - -cdef nmod_ctx any_as_nmod_ctx(obj): - """Convert an int to an nmod_ctx.""" - if typecheck(obj, nmod_ctx): - return obj - if typecheck(obj, int): - ctx = _nmod_ctx_cache.get(obj) - if ctx is None: - ctx = nmod_ctx(obj) - _nmod_ctx_cache[obj] = ctx - return ctx - return NotImplemented - - -cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1: - """Convert an object to an nmod element.""" - cdef int success - cdef fmpz_t t - if typecheck(obj, nmod): - if (obj).ctx.mod.n != mod.n: - raise ValueError("cannot coerce integers mod n with different n") - val[0] = (obj).val - return 1 - z = any_as_fmpz(obj) - if z is not NotImplemented: - val[0] = fmpz_fdiv_ui((z).val, mod.n) - return 1 - q = any_as_fmpq(obj) - if q is not NotImplemented: - fmpz_init(t) - fmpz_set_ui(t, mod.n) - success = fmpq_mod_fmpz(t, (q).val, t) - val[0] = fmpz_get_ui(t) - fmpz_clear(t) - if not success: - raise ZeroDivisionError("%s does not exist mod %i!" % (q, mod.n)) - return 1 - return 0 +cdef dict _nmod_ctx_cache = {} cdef class nmod_ctx: @@ -69,62 +27,55 @@ cdef class nmod_ctx: Context object for creating :class:`~.nmod` initalised with modulus :math:`N`. - >>> nmod_ctx(17) + >>> nmod_ctx.get_ctx(17) nmod_ctx(17) """ - def __init__(self, mod): - cdef mp_limb_t m - m = mod - nmod_init(&self.mod, m) - self._is_prime = n_is_prime(m) - def __eq__(self, other): - # XXX: If we could ensure uniqueness of nmod_ctx for given modulus then - # we would need to implement __eq__ and __hash__ at all... - # - # It isn't possible to ensure uniqueness in __new__ like it is in - # Python because we can't return an existing object from __new__. What - # we could do though is make it so that __init__ raises an error and - # use a static method .new() to create new objects. - if self is other: - return True - if not typecheck(other, nmod_ctx): - return NotImplemented - return self.mod.n == (other).mod.n + def __init__(self, *args, **kwargs): + raise TypeError("cannot create nmod_ctx directly: use nmod_ctx.get_ctx()") + + @staticmethod + cdef nmod_ctx any_as_nmod_ctx(obj): + """Convert an int to an nmod_ctx.""" + if typecheck(obj, nmod_ctx): + return obj + if typecheck(obj, int): + return nmod_ctx._get_ctx(obj) + return NotImplemented + + @staticmethod + def get_ctx(mod): + """Create a new nmod context.""" + return nmod_ctx._get_ctx(mod) + + @staticmethod + cdef _get_ctx(int mod): + """Create a new nmod context.""" + ctx = _nmod_ctx_cache.get(mod) + if ctx is None: + _nmod_ctx_cache[mod] = ctx = nmod_ctx._new_ctx(mod) + return ctx + + @staticmethod + cdef _new_ctx(ulong mod): + """Create a new nmod context.""" + cdef nmod_ctx ctx = nmod_ctx.__new__(nmod_ctx) + nmod_init(&ctx.mod, mod) + ctx._is_prime = n_is_prime(mod) + return ctx def __repr__(self): return f"nmod_ctx({self.modulus()})" cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1: """Convert an object to an nmod element.""" - cdef int success - cdef fmpz_t t - if typecheck(obj, nmod): - if (obj).ctx != self: - raise ValueError("cannot coerce integers mod n with different n") - val[0] = (obj).val - return 1 - z = any_as_fmpz(obj) - if z is not NotImplemented: - val[0] = fmpz_fdiv_ui((z).val, self.mod.n) - return 1 - q = any_as_fmpq(obj) - if q is not NotImplemented: - fmpz_init(t) - fmpz_set_ui(t, self.mod.n) - success = fmpq_mod_fmpz(t, (q).val, t) - val[0] = fmpz_get_ui(t) - fmpz_clear(t) - if not success: - raise ZeroDivisionError("%s does not exist mod %i!" % (q, self.mod.n)) - return 1 - return 0 + return any_as_nmod(val, obj, self) def modulus(self): """Get the modulus of the context. - >>> ctx = nmod_ctx(17) + >>> ctx = nmod_ctx.get_ctx(17) >>> ctx.modulus() 17 @@ -134,7 +85,7 @@ cdef class nmod_ctx: def is_prime(self): """Check if the modulus is prime. - >>> ctx = nmod_ctx(17) + >>> ctx = nmod_ctx.get_ctx(17) >>> ctx.is_prime() True @@ -144,7 +95,7 @@ cdef class nmod_ctx: def zero(self): """Return the zero element of the context. - >>> ctx = nmod_ctx(17) + >>> ctx = nmod_ctx.get_ctx(17) >>> ctx.zero() 0 @@ -154,7 +105,7 @@ cdef class nmod_ctx: def one(self): """Return the one element of the context. - >>> ctx = nmod_ctx(17) + >>> ctx = nmod_ctx.get_ctx(17) >>> ctx.one() 1 @@ -185,16 +136,42 @@ cdef class nmod_ctx: def __call__(self, val): """Create an nmod element from an integer. - >>> ctx = nmod_ctx(17) + >>> ctx = nmod_ctx.get_ctx(17) >>> ctx(10) 10 """ cdef mp_limb_t v - v = val + v = val % self.mod.n return self._new(&v) +cdef int any_as_nmod(mp_limb_t * val, obj, nmod_ctx ctx) except -1: + """Convert an object to an nmod element.""" + cdef int success + cdef fmpz_t t + if typecheck(obj, nmod): + if (obj).ctx.mod.n != ctx.mod.n: + raise ValueError("cannot coerce integers mod n with different n") + val[0] = (obj).val + return 1 + z = any_as_fmpz(obj) + if z is not NotImplemented: + val[0] = fmpz_fdiv_ui((z).val, ctx.mod.n) + return 1 + q = any_as_fmpq(obj) + if q is not NotImplemented: + fmpz_init(t) + fmpz_set_ui(t, ctx.mod.n) + success = fmpq_mod_fmpz(t, (q).val, t) + val[0] = fmpz_get_ui(t) + fmpz_clear(t) + if not success: + raise ZeroDivisionError("%s does not exist mod %i!" % (q, ctx.mod.n)) + return 1 + return 0 + + @cython.no_gc cdef class nmod(flint_scalar): """ @@ -205,10 +182,10 @@ cdef class nmod(flint_scalar): """ def __init__(self, val, mod): - ctx = any_as_nmod_ctx(mod) + ctx = nmod_ctx.any_as_nmod_ctx(mod) if ctx is NotImplemented: raise TypeError("Invalid context/modulus for nmod: %s" % mod) - if not ctx.any_as_nmod(&self.val, val): + if not any_as_nmod(&self.val, val, ctx): raise TypeError("cannot create nmod from object of type %s" % type(val)) self.ctx = ctx @@ -262,7 +239,7 @@ cdef class nmod(flint_scalar): cdef nmod r, s2 cdef mp_limb_t val s2 = s - if s2.ctx.any_as_nmod(&val, t): + if any_as_nmod(&val, t, s2.ctx): r = nmod.__new__(nmod) r.ctx = s2.ctx r.val = nmod_add(val, s2.val, s2.ctx.mod) @@ -273,7 +250,7 @@ cdef class nmod(flint_scalar): cdef nmod r, s2 cdef mp_limb_t val s2 = s - if s2.ctx.any_as_nmod(&val, t): + if any_as_nmod(&val, t, s2.ctx): r = nmod.__new__(nmod) r.ctx = s2.ctx r.val = nmod_add(s2.val, val, s2.ctx.mod) @@ -284,7 +261,7 @@ cdef class nmod(flint_scalar): cdef nmod r, s2 cdef mp_limb_t val s2 = s - if s2.ctx.any_as_nmod(&val, t): + if any_as_nmod(&val, t, s2.ctx): r = nmod.__new__(nmod) r.ctx = s2.ctx r.val = nmod_sub(s2.val, val, s2.ctx.mod) @@ -295,7 +272,7 @@ cdef class nmod(flint_scalar): cdef nmod r cdef mp_limb_t val s2 = s - if s2.ctx.any_as_nmod(&val, t): + if any_as_nmod(&val, t, s2.ctx): r = nmod.__new__(nmod) r.ctx = s2.ctx r.val = nmod_sub(val, s2.val, s2.ctx.mod) @@ -306,7 +283,7 @@ cdef class nmod(flint_scalar): cdef nmod r, s2 cdef mp_limb_t val s2 = s - if any_as_nmod(&val, t, s2.ctx.mod): + if any_as_nmod(&val, t, s2.ctx): r = nmod.__new__(nmod) r.ctx = s2.ctx r.val = nmod_mul(val, s2.val, s2.ctx.mod) @@ -317,7 +294,7 @@ cdef class nmod(flint_scalar): cdef nmod r, s2 cdef mp_limb_t val s2 = s - if s2.ctx.any_as_nmod(&val, t): + if any_as_nmod(&val, t, s2.ctx): r = nmod.__new__(nmod) r.ctx = s2.ctx r.val = nmod_mul(s2.val, val, s2.ctx.mod) @@ -335,13 +312,13 @@ cdef class nmod(flint_scalar): s2 = s ctx = s2.ctx sval = s2.val - if not ctx.any_as_nmod(&tval, t): + if not any_as_nmod(&tval, t, ctx): return NotImplemented else: t2 = t ctx = t2.ctx tval = t2.val - if not ctx.any_as_nmod(&sval, s): + if not any_as_nmod(&sval, s, ctx): return NotImplemented if tval == 0: diff --git a/src/flint/types/nmod_mat.pyx b/src/flint/types/nmod_mat.pyx index 58ccf86f..7257c00f 100644 --- a/src/flint/types/nmod_mat.pyx +++ b/src/flint/types/nmod_mat.pyx @@ -43,7 +43,7 @@ from flint.flintlib.nmod_mat cimport ( from flint.utils.typecheck cimport typecheck from flint.types.fmpz_mat cimport any_as_fmpz_mat from flint.types.fmpz_mat cimport fmpz_mat -from flint.types.nmod cimport nmod, any_as_nmod_ctx +from flint.types.nmod cimport nmod, nmod_ctx from flint.types.nmod_poly cimport nmod_poly, nmod_poly_new_init, any_as_nmod_poly_ctx from flint.pyflint cimport global_random_state from flint.flint_base.flint_context cimport thectx @@ -120,7 +120,7 @@ cdef class nmod_mat(flint_mat): if mod == 0: raise ValueError("modulus must be nonzero") - ctx = any_as_nmod_ctx(mod) + ctx = nmod_ctx.any_as_nmod_ctx(mod) self.ctx = ctx if len(args) == 1: diff --git a/src/flint/types/nmod_poly.pyx b/src/flint/types/nmod_poly.pyx index b2622ff9..e00eb44a 100644 --- a/src/flint/types/nmod_poly.pyx +++ b/src/flint/types/nmod_poly.pyx @@ -4,7 +4,6 @@ from flint.utils.typecheck cimport typecheck from flint.types.fmpz cimport fmpz, any_as_fmpz from flint.types.fmpz_poly cimport any_as_fmpz_poly from flint.types.fmpz_poly cimport fmpz_poly -from flint.types.nmod cimport any_as_nmod_ctx from flint.types.nmod cimport nmod, nmod_ctx from flint.flintlib.nmod_vec cimport * @@ -61,7 +60,7 @@ cdef class nmod_poly_ctx: cdef mp_limb_t m m = mod nmod_init(&self.mod, m) - self.ctx = nmod_ctx(mod) + self.ctx = nmod_ctx.get_ctx(mod) self._is_prime = n_is_prime(m) def __repr__(self): From 154b035246b40d9556a74a152dfa51dcdeb0118d Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Mon, 26 Aug 2024 23:15:39 +0100 Subject: [PATCH 10/30] perf: use @cython.final for nmod_ctx.any_as_nmod --- src/flint/types/nmod.pxd | 3 -- src/flint/types/nmod.pyx | 83 +++++++++++++++++----------------------- 2 files changed, 35 insertions(+), 51 deletions(-) diff --git a/src/flint/types/nmod.pxd b/src/flint/types/nmod.pxd index d4f386d5..44675775 100644 --- a/src/flint/types/nmod.pxd +++ b/src/flint/types/nmod.pxd @@ -2,9 +2,6 @@ from flint.flint_base.flint_base cimport flint_scalar from flint.flintlib.flint cimport mp_limb_t, ulong from flint.flintlib.nmod cimport nmod_t -cdef int any_as_nmod(mp_limb_t * val, obj, nmod_ctx mod) except -1 -#cdef nmod_ctx any_as_nmod_ctx(obj) - cdef class nmod_ctx: cdef nmod_t mod diff --git a/src/flint/types/nmod.pyx b/src/flint/types/nmod.pyx index 96a9bc6a..a4940d6b 100644 --- a/src/flint/types/nmod.pyx +++ b/src/flint/types/nmod.pyx @@ -65,13 +65,35 @@ cdef class nmod_ctx: ctx._is_prime = n_is_prime(mod) return ctx + @cython.final + cdef int any_as_nmod(nmod_ctx ctx, mp_limb_t * val, obj) except -1: + """Convert an object to an nmod element.""" + cdef int success + cdef fmpz_t t + if typecheck(obj, nmod): + if (obj).ctx.mod.n != ctx.mod.n: + raise ValueError("cannot coerce integers mod n with different n") + val[0] = (obj).val + return 1 + z = any_as_fmpz(obj) + if z is not NotImplemented: + val[0] = fmpz_fdiv_ui((z).val, ctx.mod.n) + return 1 + q = any_as_fmpq(obj) + if q is not NotImplemented: + fmpz_init(t) + fmpz_set_ui(t, ctx.mod.n) + success = fmpq_mod_fmpz(t, (q).val, t) + val[0] = fmpz_get_ui(t) + fmpz_clear(t) + if not success: + raise ZeroDivisionError("%s does not exist mod %i!" % (q, ctx.mod.n)) + return 1 + return 0 + def __repr__(self): return f"nmod_ctx({self.modulus()})" - cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1: - """Convert an object to an nmod element.""" - return any_as_nmod(val, obj, self) - def modulus(self): """Get the modulus of the context. @@ -112,15 +134,6 @@ cdef class nmod_ctx: """ return self(1) - def __hash__(self): - return hash(self.mod) - - def __eq__(self, other): - if typecheck(other, nmod_ctx): - return self.mod.n == (other).mod.n - else: - return NotImplemented - def __str__(self): return f"Context for nmod with modulus: {self.modulus()}" @@ -146,32 +159,6 @@ cdef class nmod_ctx: return self._new(&v) -cdef int any_as_nmod(mp_limb_t * val, obj, nmod_ctx ctx) except -1: - """Convert an object to an nmod element.""" - cdef int success - cdef fmpz_t t - if typecheck(obj, nmod): - if (obj).ctx.mod.n != ctx.mod.n: - raise ValueError("cannot coerce integers mod n with different n") - val[0] = (obj).val - return 1 - z = any_as_fmpz(obj) - if z is not NotImplemented: - val[0] = fmpz_fdiv_ui((z).val, ctx.mod.n) - return 1 - q = any_as_fmpq(obj) - if q is not NotImplemented: - fmpz_init(t) - fmpz_set_ui(t, ctx.mod.n) - success = fmpq_mod_fmpz(t, (q).val, t) - val[0] = fmpz_get_ui(t) - fmpz_clear(t) - if not success: - raise ZeroDivisionError("%s does not exist mod %i!" % (q, ctx.mod.n)) - return 1 - return 0 - - @cython.no_gc cdef class nmod(flint_scalar): """ @@ -185,7 +172,7 @@ cdef class nmod(flint_scalar): ctx = nmod_ctx.any_as_nmod_ctx(mod) if ctx is NotImplemented: raise TypeError("Invalid context/modulus for nmod: %s" % mod) - if not any_as_nmod(&self.val, val, ctx): + if not ctx.any_as_nmod(&self.val, val): raise TypeError("cannot create nmod from object of type %s" % type(val)) self.ctx = ctx @@ -239,7 +226,7 @@ cdef class nmod(flint_scalar): cdef nmod r, s2 cdef mp_limb_t val s2 = s - if any_as_nmod(&val, t, s2.ctx): + if s2.ctx.any_as_nmod(&val, t): r = nmod.__new__(nmod) r.ctx = s2.ctx r.val = nmod_add(val, s2.val, s2.ctx.mod) @@ -250,7 +237,7 @@ cdef class nmod(flint_scalar): cdef nmod r, s2 cdef mp_limb_t val s2 = s - if any_as_nmod(&val, t, s2.ctx): + if s2.ctx.any_as_nmod(&val, t): r = nmod.__new__(nmod) r.ctx = s2.ctx r.val = nmod_add(s2.val, val, s2.ctx.mod) @@ -261,7 +248,7 @@ cdef class nmod(flint_scalar): cdef nmod r, s2 cdef mp_limb_t val s2 = s - if any_as_nmod(&val, t, s2.ctx): + if s2.ctx.any_as_nmod(&val, t): r = nmod.__new__(nmod) r.ctx = s2.ctx r.val = nmod_sub(s2.val, val, s2.ctx.mod) @@ -272,7 +259,7 @@ cdef class nmod(flint_scalar): cdef nmod r cdef mp_limb_t val s2 = s - if any_as_nmod(&val, t, s2.ctx): + if s2.ctx.any_as_nmod(&val, t): r = nmod.__new__(nmod) r.ctx = s2.ctx r.val = nmod_sub(val, s2.val, s2.ctx.mod) @@ -283,7 +270,7 @@ cdef class nmod(flint_scalar): cdef nmod r, s2 cdef mp_limb_t val s2 = s - if any_as_nmod(&val, t, s2.ctx): + if s2.ctx.any_as_nmod(&val, t): r = nmod.__new__(nmod) r.ctx = s2.ctx r.val = nmod_mul(val, s2.val, s2.ctx.mod) @@ -294,7 +281,7 @@ cdef class nmod(flint_scalar): cdef nmod r, s2 cdef mp_limb_t val s2 = s - if any_as_nmod(&val, t, s2.ctx): + if s2.ctx.any_as_nmod(&val, t): r = nmod.__new__(nmod) r.ctx = s2.ctx r.val = nmod_mul(s2.val, val, s2.ctx.mod) @@ -312,13 +299,13 @@ cdef class nmod(flint_scalar): s2 = s ctx = s2.ctx sval = s2.val - if not any_as_nmod(&tval, t, ctx): + if not ctx.any_as_nmod(&tval, t): return NotImplemented else: t2 = t ctx = t2.ctx tval = t2.val - if not any_as_nmod(&sval, s, ctx): + if not ctx.any_as_nmod(&sval, s): return NotImplemented if tval == 0: From 82af09eb9dac16d1025413284d4606b50986278a Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Tue, 27 Aug 2024 18:45:24 +0100 Subject: [PATCH 11/30] Add nmod_ctx.new_nmod function --- src/flint/test/test_all.py | 2 +- src/flint/types/nmod.pxd | 2 +- src/flint/types/nmod.pyx | 89 ++++++++++++++++------------------- src/flint/types/nmod_poly.pyx | 2 +- 4 files changed, 44 insertions(+), 51 deletions(-) diff --git a/src/flint/test/test_all.py b/src/flint/test/test_all.py index ba58ad1a..84a7ec26 100644 --- a/src/flint/test/test_all.py +++ b/src/flint/test/test_all.py @@ -1357,7 +1357,7 @@ def test_nmod(): assert str(G(3,5)) == "3" assert G(3,5).repr() == "nmod(3, 5)" - G = flint.nmod_ctx.get_ctx(7) + G = flint.nmod_ctx.new(7) assert G(0) == G(7) == G(-7) diff --git a/src/flint/types/nmod.pxd b/src/flint/types/nmod.pxd index 44675775..24bf5e96 100644 --- a/src/flint/types/nmod.pxd +++ b/src/flint/types/nmod.pxd @@ -15,7 +15,7 @@ cdef class nmod_ctx: cdef _new_ctx(ulong mod) cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1 - cdef nmod _new(self, mp_limb_t * val) + cdef nmod new_nmod(self) cdef class nmod(flint_scalar): diff --git a/src/flint/types/nmod.pyx b/src/flint/types/nmod.pyx index a4940d6b..75aa6c37 100644 --- a/src/flint/types/nmod.pyx +++ b/src/flint/types/nmod.pyx @@ -27,31 +27,38 @@ cdef class nmod_ctx: Context object for creating :class:`~.nmod` initalised with modulus :math:`N`. - >>> nmod_ctx.get_ctx(17) + >>> ctx = nmod_ctx.new(17) + >>> ctx nmod_ctx(17) + >>> ctx.modulus() + 17 + >>> e = ctx(10) + >>> e + 10 + >>> e + 10 + 3 """ - def __init__(self, *args, **kwargs): - raise TypeError("cannot create nmod_ctx directly: use nmod_ctx.get_ctx()") + raise TypeError("cannot create nmod_ctx directly: use nmod_ctx.new()") + + @staticmethod + def new(mod): + """Get an nmod context with modulus ``mod``.""" + return nmod_ctx._get_ctx(mod) @staticmethod cdef nmod_ctx any_as_nmod_ctx(obj): - """Convert an int to an nmod_ctx.""" + """Convert an ``nmod_ctx`` or ``int`` to an ``nmod_ctx``.""" if typecheck(obj, nmod_ctx): return obj if typecheck(obj, int): return nmod_ctx._get_ctx(obj) return NotImplemented - @staticmethod - def get_ctx(mod): - """Create a new nmod context.""" - return nmod_ctx._get_ctx(mod) - @staticmethod cdef _get_ctx(int mod): - """Create a new nmod context.""" + """Retrieve an nmod context from the cache or create a new one.""" ctx = _nmod_ctx_cache.get(mod) if ctx is None: _nmod_ctx_cache[mod] = ctx = nmod_ctx._new_ctx(mod) @@ -91,13 +98,16 @@ cdef class nmod_ctx: return 1 return 0 - def __repr__(self): - return f"nmod_ctx({self.modulus()})" + @cython.final + cdef nmod new_nmod(self): + cdef nmod r = nmod.__new__(nmod) + r.ctx = self + return r def modulus(self): """Get the modulus of the context. - >>> ctx = nmod_ctx.get_ctx(17) + >>> ctx = nmod_ctx.new(17) >>> ctx.modulus() 17 @@ -107,7 +117,7 @@ cdef class nmod_ctx: def is_prime(self): """Check if the modulus is prime. - >>> ctx = nmod_ctx.get_ctx(17) + >>> ctx = nmod_ctx.new(17) >>> ctx.is_prime() True @@ -117,7 +127,7 @@ cdef class nmod_ctx: def zero(self): """Return the zero element of the context. - >>> ctx = nmod_ctx.get_ctx(17) + >>> ctx = nmod_ctx.new(17) >>> ctx.zero() 0 @@ -127,7 +137,7 @@ cdef class nmod_ctx: def one(self): """Return the one element of the context. - >>> ctx = nmod_ctx.get_ctx(17) + >>> ctx = nmod_ctx.new(17) >>> ctx.one() 1 @@ -140,23 +150,17 @@ cdef class nmod_ctx: def __repr__(self): return f"nmod_ctx({self.modulus()})" - cdef nmod _new(self, mp_limb_t * val): - cdef nmod r = nmod.__new__(nmod) - r.val = val[0] - r.ctx = self - return r - def __call__(self, val): """Create an nmod element from an integer. - >>> ctx = nmod_ctx.get_ctx(17) + >>> ctx = nmod_ctx.new(17) >>> ctx(10) 10 """ - cdef mp_limb_t v - v = val % self.mod.n - return self._new(&v) + r = self.new_nmod() + self.any_as_nmod(&r.val, val) + return r @cython.no_gc @@ -217,8 +221,7 @@ cdef class nmod(flint_scalar): return self def __neg__(self): - cdef nmod r = nmod.__new__(nmod) - r.ctx = self.ctx + r = self.ctx.new_nmod() r.val = nmod_neg(self.val, self.ctx.mod) return r @@ -227,8 +230,7 @@ cdef class nmod(flint_scalar): cdef mp_limb_t val s2 = s if s2.ctx.any_as_nmod(&val, t): - r = nmod.__new__(nmod) - r.ctx = s2.ctx + r = s2.ctx.new_nmod() r.val = nmod_add(val, s2.val, s2.ctx.mod) return r return NotImplemented @@ -238,8 +240,7 @@ cdef class nmod(flint_scalar): cdef mp_limb_t val s2 = s if s2.ctx.any_as_nmod(&val, t): - r = nmod.__new__(nmod) - r.ctx = s2.ctx + r = s2.ctx.new_nmod() r.val = nmod_add(s2.val, val, s2.ctx.mod) return r return NotImplemented @@ -249,8 +250,7 @@ cdef class nmod(flint_scalar): cdef mp_limb_t val s2 = s if s2.ctx.any_as_nmod(&val, t): - r = nmod.__new__(nmod) - r.ctx = s2.ctx + r = s2.ctx.new_nmod() r.val = nmod_sub(s2.val, val, s2.ctx.mod) return r return NotImplemented @@ -260,8 +260,7 @@ cdef class nmod(flint_scalar): cdef mp_limb_t val s2 = s if s2.ctx.any_as_nmod(&val, t): - r = nmod.__new__(nmod) - r.ctx = s2.ctx + r = s2.ctx.new_nmod() r.val = nmod_sub(val, s2.val, s2.ctx.mod) return r return NotImplemented @@ -271,8 +270,7 @@ cdef class nmod(flint_scalar): cdef mp_limb_t val s2 = s if s2.ctx.any_as_nmod(&val, t): - r = nmod.__new__(nmod) - r.ctx = s2.ctx + r = s2.ctx.new_nmod() r.val = nmod_mul(val, s2.val, s2.ctx.mod) return r return NotImplemented @@ -282,8 +280,7 @@ cdef class nmod(flint_scalar): cdef mp_limb_t val s2 = s if s2.ctx.any_as_nmod(&val, t): - r = nmod.__new__(nmod) - r.ctx = s2.ctx + r = s2.ctx.new_nmod() r.val = nmod_mul(s2.val, val, s2.ctx.mod) return r return NotImplemented @@ -317,8 +314,7 @@ cdef class nmod(flint_scalar): if g != 1: raise ZeroDivisionError("%s is not invertible mod %s" % (tval, ctx.mod.n)) - r = nmod.__new__(nmod) - r.ctx = ctx + r = ctx.new_nmod() r.val = nmod_mul(sval, tinvval, ctx.mod) return r @@ -338,8 +334,7 @@ cdef class nmod(flint_scalar): g = n_gcdinv(&inv, sval, ctx.mod.n) if g != 1: raise ZeroDivisionError("%s is not invertible mod %s" % (sval, ctx.mod.n)) - r = nmod.__new__(nmod) - r.ctx = ctx + r = ctx.new_nmod() r.val = inv return r @@ -370,8 +365,7 @@ cdef class nmod(flint_scalar): rval = rinv e = -e - r = nmod.__new__(nmod) - r.ctx = ctx + r = ctx.new_nmod() r.val = nmod_pow_fmpz(rval, (e).val, ctx.mod) return r @@ -394,8 +388,7 @@ cdef class nmod(flint_scalar): """ cdef nmod r cdef mp_limb_t val - r = nmod.__new__(nmod) - r.ctx = self.ctx + r = self.ctx.new_nmod() if self.val == 0: return r diff --git a/src/flint/types/nmod_poly.pyx b/src/flint/types/nmod_poly.pyx index e00eb44a..98178226 100644 --- a/src/flint/types/nmod_poly.pyx +++ b/src/flint/types/nmod_poly.pyx @@ -60,7 +60,7 @@ cdef class nmod_poly_ctx: cdef mp_limb_t m m = mod nmod_init(&self.mod, m) - self.ctx = nmod_ctx.get_ctx(mod) + self.ctx = nmod_ctx.new(mod) self._is_prime = n_is_prime(m) def __repr__(self): From 8d69a4e54ee6518aa0fe4d72b940411415dd5043 Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Tue, 27 Aug 2024 21:06:17 +0100 Subject: [PATCH 12/30] Add nmod_poly_ctx and nmod_mat_ctx --- src/flint/types/nmod.pxd | 7 +- src/flint/types/nmod.pyx | 9 +- src/flint/types/nmod_mat.pxd | 39 ++++- src/flint/types/nmod_mat.pyx | 296 ++++++++++++++++++++++++++-------- src/flint/types/nmod_poly.pxd | 28 +++- src/flint/types/nmod_poly.pyx | 232 +++++++++++++++----------- 6 files changed, 436 insertions(+), 175 deletions(-) diff --git a/src/flint/types/nmod.pxd b/src/flint/types/nmod.pxd index 24bf5e96..0ed57a7a 100644 --- a/src/flint/types/nmod.pxd +++ b/src/flint/types/nmod.pxd @@ -1,20 +1,25 @@ +cimport cython + from flint.flint_base.flint_base cimport flint_scalar from flint.flintlib.flint cimport mp_limb_t, ulong from flint.flintlib.nmod cimport nmod_t +@cython.no_gc cdef class nmod_ctx: cdef nmod_t mod cdef bint _is_prime @staticmethod - cdef nmod_ctx any_as_nmod_ctx(obj) + cdef any_as_nmod_ctx(obj) @staticmethod cdef _get_ctx(int mod) @staticmethod cdef _new_ctx(ulong mod) + @cython.final cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1 + @cython.final cdef nmod new_nmod(self) diff --git a/src/flint/types/nmod.pyx b/src/flint/types/nmod.pyx index 75aa6c37..378f79bc 100644 --- a/src/flint/types/nmod.pyx +++ b/src/flint/types/nmod.pyx @@ -48,7 +48,7 @@ cdef class nmod_ctx: return nmod_ctx._get_ctx(mod) @staticmethod - cdef nmod_ctx any_as_nmod_ctx(obj): + cdef any_as_nmod_ctx(obj): """Convert an ``nmod_ctx`` or ``int`` to an ``nmod_ctx``.""" if typecheck(obj, nmod_ctx): return obj @@ -163,7 +163,6 @@ cdef class nmod_ctx: return r -@cython.no_gc cdef class nmod(flint_scalar): """ The nmod type represents elements of Z/nZ for word-size n. @@ -173,9 +172,11 @@ cdef class nmod(flint_scalar): """ def __init__(self, val, mod): - ctx = nmod_ctx.any_as_nmod_ctx(mod) - if ctx is NotImplemented: + cdef nmod_ctx ctx + c = nmod_ctx.any_as_nmod_ctx(mod) + if c is NotImplemented: raise TypeError("Invalid context/modulus for nmod: %s" % mod) + ctx = c if not ctx.any_as_nmod(&self.val, val): raise TypeError("cannot create nmod from object of type %s" % type(val)) self.ctx = ctx diff --git a/src/flint/types/nmod_mat.pxd b/src/flint/types/nmod_mat.pxd index ada85a71..12a16774 100644 --- a/src/flint/types/nmod_mat.pxd +++ b/src/flint/types/nmod_mat.pxd @@ -1,14 +1,47 @@ +cimport cython + from flint.flint_base.flint_base cimport flint_mat +from flint.flintlib.nmod cimport nmod_t from flint.flintlib.nmod_mat cimport nmod_mat_t -from flint.flintlib.flint cimport mp_limb_t +from flint.flintlib.flint cimport mp_limb_t, ulong + +from flint.types.nmod cimport nmod_ctx, nmod +from flint.types.nmod_poly cimport nmod_poly_ctx, nmod_poly + + +@cython.no_gc +cdef class nmod_mat_ctx: + cdef nmod_t mod + cdef bint _is_prime + cdef nmod_ctx scalar_ctx + cdef nmod_poly_ctx poly_ctx + + @staticmethod + cdef any_as_nmod_mat_ctx(obj) + @staticmethod + cdef nmod_mat_ctx _get_ctx(int mod) + @staticmethod + cdef nmod_mat_ctx _new_ctx(ulong mod) -from flint.types.nmod cimport nmod_ctx + @cython.final + cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1 + @cython.final + cdef any_as_nmod_mat(self, obj) + @cython.final + cdef nmod new_nmod(self) + @cython.final + cdef nmod_poly new_nmod_poly(self) + @cython.final + cdef nmod_mat new_nmod_mat(self, ulong m, ulong n) + @cython.final + cdef nmod_mat new_nmod_mat_copy(self, nmod_mat other) +@cython.no_gc cdef class nmod_mat(flint_mat): cdef nmod_mat_t val - cdef nmod_ctx ctx + cdef nmod_mat_ctx ctx cpdef long nrows(self) cpdef long ncols(self) diff --git a/src/flint/types/nmod_mat.pyx b/src/flint/types/nmod_mat.pyx index 7257c00f..e015cef4 100644 --- a/src/flint/types/nmod_mat.pyx +++ b/src/flint/types/nmod_mat.pyx @@ -3,10 +3,6 @@ cimport cython from flint.flintlib.flint cimport ulong, mp_limb_t from flint.flintlib.nmod cimport nmod_t -from flint.flintlib.nmod_poly cimport ( - nmod_poly_init, -) - from flint.flintlib.fmpz_mat cimport fmpz_mat_nrows, fmpz_mat_ncols from flint.flintlib.fmpz_mat cimport fmpz_mat_get_nmod_mat @@ -41,10 +37,12 @@ from flint.flintlib.nmod_mat cimport ( ) from flint.utils.typecheck cimport typecheck +from flint.types.fmpz cimport fmpz from flint.types.fmpz_mat cimport any_as_fmpz_mat from flint.types.fmpz_mat cimport fmpz_mat from flint.types.nmod cimport nmod, nmod_ctx -from flint.types.nmod_poly cimport nmod_poly, nmod_poly_new_init, any_as_nmod_poly_ctx +from flint.types.nmod_poly cimport nmod_poly + from flint.pyflint cimport global_random_state from flint.flint_base.flint_context cimport thectx @@ -56,37 +54,187 @@ from flint.utils.flint_exceptions import DomainError ctx = thectx -cdef nmod_mat new_nmod_mat_init(nmod_ctx ctx, ulong m, ulong n): - """New initialized nmod_mat of size m x n with context ctx.""" - cdef nmod_mat r = nmod_mat.__new__(nmod_mat) - nmod_mat_init(r.val, m, n, ctx.mod.n) - r.ctx = ctx - return r +cdef dict _nmod_mat_ctx_cache = {} + + +@cython.no_gc +cdef class nmod_mat_ctx: + """ + Context object for creating :class:`~.nmod_mat` initalised + with modulus :math:`N`. + + >>> ctx = nmod_mat_ctx.new(17) + >>> M = ctx([[1,2],[3,4]]) + >>> M + [1, 2] + [3, 4] + + """ + def __init__(self, *args, **kwargs): + raise TypeError("cannot create nmod_poly_ctx directly: use nmod_poly_ctx.new()") + + @staticmethod + def new(mod): + """Get an ``nmod_poly`` context with modulus ``mod``.""" + return nmod_mat_ctx._get_ctx(mod) + + @staticmethod + cdef any_as_nmod_mat_ctx(obj): + """Convert an ``nmod_mat_ctx`` or ``int`` to an ``nmod_mat_ctx``.""" + if typecheck(obj, nmod_mat_ctx): + return obj + if typecheck(obj, int): + return nmod_mat_ctx._get_ctx(obj) + elif typecheck(obj, fmpz): + return nmod_mat_ctx._get_ctx(int(obj)) + return NotImplemented + + @staticmethod + cdef nmod_mat_ctx _get_ctx(int mod): + """Retrieve an nmod_mat context from the cache or create a new one.""" + ctx = _nmod_mat_ctx_cache.get(mod) + if ctx is None: + _nmod_mat_ctx_cache[mod] = ctx = nmod_mat_ctx._new_ctx(mod) + return ctx + + @staticmethod + cdef nmod_mat_ctx _new_ctx(ulong mod): + """Create a new nmod_mat context.""" + cdef nmod_ctx scalar_ctx + cdef nmod_poly_ctx poly_ctx + cdef nmod_mat_ctx ctx + + poly_ctx = nmod_poly_ctx.new(mod) + scalar_ctx = poly_ctx.scalar_ctx + + ctx = nmod_mat_ctx.__new__(nmod_mat_ctx) + ctx.mod = scalar_ctx.mod + ctx._is_prime = scalar_ctx._is_prime + ctx.scalar_ctx = scalar_ctx + ctx.poly_ctx = poly_ctx + + return ctx + + @cython.final + cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1: + return self.scalar_ctx.any_as_nmod(val, obj) + + @cython.final + cdef any_as_nmod_mat(self, obj): + """Convert obj to nmod_mat or return NotImplemented.""" + cdef nmod_mat r + cdef mp_limb_t v + + if typecheck(obj, nmod_mat): + return obj + + x = any_as_fmpz_mat(obj) + if x is not NotImplemented: + r = self.new_nmod_mat(fmpz_mat_nrows((x).val), + fmpz_mat_ncols((x).val)) + fmpz_mat_get_nmod_mat(r.val, (x).val) + return r + return NotImplemented -cdef nmod_mat new_nmod_mat_copy(nmod_mat other): - """New copy of nmod_mat other.""" - cdef nmod_mat r = nmod_mat.__new__(nmod_mat) - nmod_mat_init_set(r.val, other.val) - r.ctx = other.ctx - return r + @cython.final + cdef nmod new_nmod(self): + return self.scalar_ctx.new_nmod() + @cython.final + cdef nmod_poly new_nmod_poly(self): + return self.poly_ctx.new_nmod_poly() -cdef any_as_nmod_mat(obj, nmod_ctx ctx): - """Convert obj to nmod_mat or return NotImplemented.""" - cdef nmod_mat r - if typecheck(obj, nmod_mat): - return obj + @cython.final + cdef nmod_mat new_nmod_mat(self, ulong m, ulong n): + """New initialized nmod_mat of size m x n with context ctx.""" + cdef nmod_mat r = nmod_mat.__new__(nmod_mat) + nmod_mat_init(r.val, m, n, self.mod.n) + r.ctx = self + return r - x = any_as_fmpz_mat(obj) - if x is not NotImplemented: - r = new_nmod_mat_init(ctx, - fmpz_mat_nrows((x).val), - fmpz_mat_ncols((x).val)) - fmpz_mat_get_nmod_mat(r.val, (x).val) + @cython.final + cdef nmod_mat new_nmod_mat_copy(self, nmod_mat other): + """New copy of nmod_mat other.""" + cdef nmod_mat r = nmod_mat.__new__(nmod_mat) + nmod_mat_init_set(r.val, other.val) + r.ctx = other.ctx return r - return NotImplemented + def modulus(self): + """Get the modulus of the context. + + >>> ctx = nmod_mat_ctx.new(17) + >>> ctx.modulus() + 17 + + """ + return fmpz(self.mod.n) + + def is_prime(self): + """Check if the modulus is prime. + + >>> ctx = nmod_mat_ctx.new(17) + >>> ctx.is_prime() + True + + """ + return self._is_prime + + #def zero(self, slong m, slong n): + # """Return the zero ``nmod_mat``. + # + # >>> ctx = nmod_mat_ctx.new(17) + # >>> ctx.zero(2, 3) + # [0, 0, 0] + # [0, 0, 0] + + # """ + # cdef nmod_mat r = self.new_nmod_mat() + # nmod_mat_zero(r.val) + # return r + + #def one(self, slong m, slong n=-1): + # """Return the one ``nmod_mat``. + + # >>> ctx = nmod_mat_ctx.new(17) + # >>> ctx.one(2) + # [1, 0] + # [0, 1] + # >>> ctx.one(2, 3) + # [1, 0, 0] + # [0, 1, 0] + + # """ + # cdef nmod_mat r = self.new_nmod_mat() + # if n == -1: + # n = m + # n = min(m, n) + # for i from 0 <= i < n: + # nmod_mat_set_entry(r.val, i, i, 1) + # return r + + def __str__(self): + return f"Context for nmod_mat with modulus: {self.mod.n}" + + def __repr__(self): + return f"nmod_mat_ctx({self.mod.n})" + + def __call__(self, *args): + """Create an ``nmod_mat``. + + >>> mat5 = nmod_mat_ctx.new(5) + >>> M = mat5([[1,2],[3,4]]) + >>> M + [1, 2] + [3, 4] + >>> M2 = mat5(2, 3, [1,2,3,4,5,6]) + >>> M2 + [1, 2, 3] + [4, 0, 1] + + """ + return nmod_mat(*args, self) cdef class nmod_mat(flint_mat): @@ -96,8 +244,8 @@ cdef class nmod_mat(flint_mat): Some operations may assume that n is a prime. """ - # cdef nmod_mat_t val +# cdef nmod_mat_ctx ctx def __dealloc__(self): nmod_mat_clear(self.val) @@ -105,8 +253,7 @@ cdef class nmod_mat(flint_mat): @cython.embedsignature(False) def __init__(self, *args): cdef long m, n, i, j - cdef mp_limb_t mod - cdef nmod_ctx ctx + cdef nmod_mat_ctx ctx if len(args) == 1: val = args[0] @@ -117,12 +264,17 @@ cdef class nmod_mat(flint_mat): mod = args[-1] args = args[:-1] - if mod == 0: - raise ValueError("modulus must be nonzero") - ctx = nmod_ctx.any_as_nmod_ctx(mod) + c = nmod_mat_ctx.any_as_nmod_mat_ctx(mod) + if c is NotImplemented: + raise TypeError("nmod_mat: expected last argument to be an nmod_mat_ctx or an integer") + + ctx = c self.ctx = ctx + if mod == 0: + raise ValueError("modulus must be nonzero") + if len(args) == 1: val = args[0] if typecheck(val, fmpz_mat): @@ -143,23 +295,23 @@ cdef class nmod_mat(flint_mat): for i from 0 <= i < m: row = val[i] for j from 0 <= j < n: - x = nmod(row[j], ctx) # XXX: slow - self.val.rows[i][j] = (x).val + if not ctx.any_as_nmod(&self.val.rows[i][j], row[j]): + raise TypeError("cannot create nmod from input of type %s" % type(row[j])) else: raise TypeError("cannot create nmod_mat from input of type %s" % type(val)) elif len(args) == 2: m, n = args - nmod_mat_init(self.val, m, n, mod) + nmod_mat_init(self.val, m, n, ctx.mod.n) elif len(args) == 3: m, n, entries = args - nmod_mat_init(self.val, m, n, mod) + nmod_mat_init(self.val, m, n, ctx.mod.n) entries = list(entries) if len(entries) != m*n: raise ValueError("list of entries has the wrong length") for i from 0 <= i < m: for j from 0 <= j < n: - x = nmod(entries[i*n + j], ctx) # XXX: slow - self.val.rows[i][j] = (x).val + if not ctx.any_as_nmod(&self.val.rows[i][j], entries[i*n + j]): + raise TypeError("cannot create nmod from input of type %s" % type(entries[i*n + j])) else: raise TypeError("nmod_mat: expected 1-3 arguments plus modulus") @@ -225,7 +377,8 @@ cdef class nmod_mat(flint_mat): i, j = index if i < 0 or i >= self.nrows() or j < 0 or j >= self.ncols(): raise IndexError("index %i,%i exceeds matrix dimensions" % (i, j)) - x = nmod(nmod_mat_entry(self.val, i, j), self.modulus()) # XXX: slow + x = self.ctx.new_nmod() + x.val = nmod_mat_entry(self.val, i, j) return x def __setitem__(self, index, value): @@ -252,7 +405,7 @@ cdef class nmod_mat(flint_mat): cdef nmod_mat_struct *sv cdef nmod_mat_struct *tv sv = &(s).val[0] - t = any_as_nmod_mat(t, s.ctx) + t = s.ctx.any_as_nmod_mat(t) if t is NotImplemented: return t tv = &(t).val[0] @@ -260,7 +413,7 @@ cdef class nmod_mat(flint_mat): raise ValueError("cannot add nmod_mats with different moduli") if sv.r != tv.r or sv.c != tv.c: raise ValueError("incompatible shapes for matrix addition") - r = new_nmod_mat_init(s.ctx, sv.r, sv.c) + r = s.ctx.new_nmod_mat(sv.r, sv.c) nmod_mat_add(r.val, sv, tv) return r @@ -269,7 +422,7 @@ cdef class nmod_mat(flint_mat): cdef nmod_mat_struct *sv cdef nmod_mat_struct *tv sv = &(s).val[0] - t = any_as_nmod_mat(t, s.ctx) + t = s.ctx.any_as_nmod_mat(t) if t is NotImplemented: return t tv = &(t).val[0] @@ -277,7 +430,7 @@ cdef class nmod_mat(flint_mat): raise ValueError("cannot add nmod_mats with different moduli") if sv.r != tv.r or sv.c != tv.c: raise ValueError("incompatible shapes for matrix addition") - r = new_nmod_mat_init(s.ctx, sv.r, sv.c) + r = s.ctx.new_nmod_mat(sv.r, sv.c) nmod_mat_add(r.val, sv, tv) return r @@ -286,7 +439,7 @@ cdef class nmod_mat(flint_mat): cdef nmod_mat_struct *sv cdef nmod_mat_struct *tv sv = &(s).val[0] - t = any_as_nmod_mat(t, s.ctx) + t = s.ctx.any_as_nmod_mat(t) if t is NotImplemented: return t tv = &(t).val[0] @@ -294,7 +447,7 @@ cdef class nmod_mat(flint_mat): raise ValueError("cannot subtract nmod_mats with different moduli") if sv.r != tv.r or sv.c != tv.c: raise ValueError("incompatible shapes for matrix subtraction") - r = new_nmod_mat_init(s.ctx, sv.r, sv.c) + r = s.ctx.new_nmod_mat(sv.r, sv.c) nmod_mat_sub(r.val, sv, tv) return r @@ -303,7 +456,7 @@ cdef class nmod_mat(flint_mat): cdef nmod_mat_struct *sv cdef nmod_mat_struct *tv sv = &(s).val[0] - t = any_as_nmod_mat(t, s.ctx) + t = s.ctx.any_as_nmod_mat(t) if t is NotImplemented: return t tv = &(t).val[0] @@ -311,13 +464,13 @@ cdef class nmod_mat(flint_mat): raise ValueError("cannot subtract nmod_mats with different moduli") if sv.r != tv.r or sv.c != tv.c: raise ValueError("incompatible shapes for matrix subtraction") - r = new_nmod_mat_init(s.ctx, sv.r, sv.c) + r = s.ctx.new_nmod_mat(sv.r, sv.c) nmod_mat_sub(r.val, tv, sv) return r cdef __mul_nmod(self, mp_limb_t c): cdef nmod_mat r - r = new_nmod_mat_init(self.ctx, self.val.r, self.val.c) + r = self.ctx.new_nmod_mat(self.val.r, self.val.c) nmod_mat_scalar_mul(r.val, self.val, c) return r @@ -327,7 +480,7 @@ cdef class nmod_mat(flint_mat): cdef nmod_mat_struct *tv cdef mp_limb_t c sv = &(s).val[0] - u = any_as_nmod_mat(t, s.ctx) + u = s.ctx.any_as_nmod_mat(t) if u is NotImplemented: if s.ctx.any_as_nmod(&c, t): return (s).__mul_nmod(c) @@ -337,7 +490,7 @@ cdef class nmod_mat(flint_mat): raise ValueError("cannot multiply nmod_mats with different moduli") if sv.c != tv.r: raise ValueError("incompatible shapes for matrix multiplication") - r = new_nmod_mat_init(s.ctx, sv.r, tv.c) + r = s.ctx.new_nmod_mat(sv.r, tv.c) nmod_mat_mul(r.val, sv, tv) return r @@ -347,7 +500,7 @@ cdef class nmod_mat(flint_mat): sv = &(s).val[0] if s.ctx.any_as_nmod(&c, t): return (s).__mul_nmod(c) - u = any_as_nmod_mat(t, s.ctx) + u = s.ctx.any_as_nmod_mat(t) if u is NotImplemented: return u return u * s @@ -365,7 +518,7 @@ cdef class nmod_mat(flint_mat): self = self.inv() e = -e ee = e - t = new_nmod_mat_copy(self) + t = self.ctx.new_nmod_mat_copy(self) nmod_mat_pow(t.val, t.val, ee) return t @@ -399,9 +552,13 @@ cdef class nmod_mat(flint_mat): 15 """ + cdef nmod r if not nmod_mat_is_square(self.val): raise ValueError("matrix must be square") - return nmod(nmod_mat_det(self.val), self.modulus()) + + r = self.ctx.new_nmod() + r.val = nmod_mat_det(self.val) + return r def inv(self): """ @@ -421,8 +578,8 @@ cdef class nmod_mat(flint_mat): if not self.ctx._is_prime: raise DomainError("nmod_mat inv: modulus must be prime") - u = new_nmod_mat_init(self.ctx, nmod_mat_nrows(self.val), - nmod_mat_ncols(self.val)) + u = self.ctx.new_nmod_mat(nmod_mat_nrows(self.val), + nmod_mat_ncols(self.val)) if not nmod_mat_inv(u.val, self.val): raise ZeroDivisionError("matrix is singular") return u @@ -437,8 +594,8 @@ cdef class nmod_mat(flint_mat): [2, 5] """ cdef nmod_mat u - u = new_nmod_mat_init(self.ctx, nmod_mat_ncols(self.val), - nmod_mat_nrows(self.val)) + u = self.ctx.new_nmod_mat(nmod_mat_ncols(self.val), + nmod_mat_nrows(self.val)) nmod_mat_transpose(u.val, self.val) return u @@ -467,17 +624,18 @@ cdef class nmod_mat(flint_mat): """ cdef nmod_mat u cdef int result - t = any_as_nmod_mat(other, self.ctx) + t = self.ctx.any_as_nmod_mat(other) if t is NotImplemented: raise TypeError("cannot convert input to nmod_mat") if (nmod_mat_nrows(self.val) != nmod_mat_ncols(self.val) or nmod_mat_nrows(self.val) != nmod_mat_nrows((t).val)): raise ValueError("need a square system and compatible right hand side") + # XXX: Should check for same modulus. if not self.ctx._is_prime: raise DomainError("nmod_mat solve: modulus must be prime") - u = new_nmod_mat_init(self.ctx, nmod_mat_nrows((t).val), - nmod_mat_ncols((t).val)) + u = self.ctx.new_nmod_mat(nmod_mat_nrows((t).val), + nmod_mat_ncols((t).val)) result = nmod_mat_solve(u.val, self.val, (t).val) if not result: raise ZeroDivisionError("singular matrix in solve()") @@ -510,7 +668,7 @@ cdef class nmod_mat(flint_mat): if inplace: res = self else: - res = new_nmod_mat_copy(self) + res = self.ctx.new_nmod_mat_copy(self) rank = nmod_mat_rref((res).val) return res, rank @@ -550,8 +708,8 @@ cdef class nmod_mat(flint_mat): """ cdef nmod_mat res - res = new_nmod_mat_init(self.ctx, nmod_mat_ncols(self.val), - nmod_mat_ncols(self.val)) + res = self.ctx.new_nmod_mat(nmod_mat_ncols(self.val), + nmod_mat_ncols(self.val)) nullity = nmod_mat_nullspace(res.val, self.val) return res, nullity @@ -569,8 +727,7 @@ cdef class nmod_mat(flint_mat): if self.nrows() != self.ncols(): raise ValueError("fmpz_mod_mat charpoly: matrix must be square") - # XXX: don't create a new context for the polynomial - res = nmod_poly_new_init(any_as_nmod_poly_ctx(self.ctx.mod.n)) + res = self.ctx.new_nmod_poly() nmod_mat_charpoly(res.val, self.val) return res @@ -592,7 +749,6 @@ cdef class nmod_mat(flint_mat): if not self.ctx._is_prime: raise DomainError("minpoly only works for prime moduli") - # XXX: don't create a new context for the polynomial - res = nmod_poly_new_init(any_as_nmod_poly_ctx(self.ctx.mod.n)) + res = self.ctx.new_nmod_poly() nmod_mat_minpoly(res.val, self.val) return res diff --git a/src/flint/types/nmod_poly.pxd b/src/flint/types/nmod_poly.pxd index 1d98ca0b..1e70b526 100644 --- a/src/flint/types/nmod_poly.pxd +++ b/src/flint/types/nmod_poly.pxd @@ -1,26 +1,40 @@ +cimport cython + from flint.flintlib.nmod cimport nmod_t from flint.flintlib.nmod_poly cimport nmod_poly_t -from flint.flintlib.flint cimport mp_limb_t +from flint.flintlib.flint cimport mp_limb_t, ulong from flint.flint_base.flint_base cimport flint_poly -from flint.types.nmod cimport nmod_ctx - - -cdef nmod_poly_ctx any_as_nmod_poly_ctx(obj) -cdef nmod_poly nmod_poly_new_init(nmod_poly_ctx ctx) +from flint.types.nmod cimport nmod_ctx, nmod +@cython.no_gc cdef class nmod_poly_ctx: - cdef nmod_ctx ctx cdef nmod_t mod cdef bint _is_prime + cdef nmod_ctx scalar_ctx + + @staticmethod + cdef any_as_nmod_poly_ctx(obj) + @staticmethod + cdef nmod_poly_ctx _get_ctx(int mod) + @staticmethod + cdef nmod_poly_ctx _new_ctx(ulong mod) + @cython.final cdef nmod_poly_set_list(self, nmod_poly_t poly, list val) + @cython.final cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1 + @cython.final cdef any_as_nmod_poly(self, obj) + @cython.final + cdef nmod new_nmod(self) + @cython.final + cdef nmod_poly new_nmod_poly(self) +@cython.no_gc cdef class nmod_poly(flint_poly): cdef nmod_poly_t val cdef nmod_poly_ctx ctx diff --git a/src/flint/types/nmod_poly.pyx b/src/flint/types/nmod_poly.pyx index 98178226..1375d4a6 100644 --- a/src/flint/types/nmod_poly.pyx +++ b/src/flint/types/nmod_poly.pyx @@ -15,36 +15,7 @@ from flint.flintlib.ulong_extras cimport n_gcdinv, n_is_prime from flint.utils.flint_exceptions import DomainError -_nmod_poly_ctx_cache = {} - - -cdef nmod_poly_ctx any_as_nmod_poly_ctx(obj): - """Convert an int to an nmod_ctx.""" - if typecheck(obj, nmod_poly_ctx): - return obj - if typecheck(obj, int): - ctx = _nmod_poly_ctx_cache.get(obj) - if ctx is None: - ctx = nmod_poly_ctx(obj) - _nmod_poly_ctx_cache[obj] = ctx - return ctx - return NotImplemented - - -cdef nmod_poly nmod_poly_new_init(nmod_poly_ctx ctx): - cdef nmod_poly p - p = nmod_poly.__new__(nmod_poly) - nmod_poly_init(p.val, ctx.mod.n) - p.ctx = ctx - return p - - -cdef nmod_poly nmod_poly_new_init_preinv(nmod_poly_ctx ctx): - cdef nmod_poly p - p = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(p.val, ctx.mod.n, ctx.mod.ninv) - p.ctx = ctx - return p +cdef dict _nmod_poly_ctx_cache = {} cdef class nmod_poly_ctx: @@ -52,22 +23,51 @@ cdef class nmod_poly_ctx: Context object for creating :class:`~.nmod_poly` initalised with modulus :math:`N`. - >>> nmod_poly_ctx(17) + >>> nmod_poly_ctx.new(17) nmod_poly_ctx(17) """ - def __init__(self, mod): - cdef mp_limb_t m - m = mod - nmod_init(&self.mod, m) - self.ctx = nmod_ctx.new(mod) - self._is_prime = n_is_prime(m) + def __init__(self, *args, **kwargs): + raise TypeError("cannot create nmod_poly_ctx directly: use nmod_poly_ctx.new()") + + @staticmethod + def new(mod): + """Get an ``nmod_poly`` context with modulus ``mod``.""" + return nmod_poly_ctx._get_ctx(mod) + + @staticmethod + cdef any_as_nmod_poly_ctx(obj): + """Convert an ``nmod_poly_ctx`` or ``int`` to an ``nmod_poly_ctx``.""" + if typecheck(obj, nmod_poly_ctx): + return obj + if typecheck(obj, int): + return nmod_poly_ctx._get_ctx(obj) + return NotImplemented - def __repr__(self): - return f"nmod_poly_ctx({self.mod.n})" + @staticmethod + cdef nmod_poly_ctx _get_ctx(int mod): + """Retrieve an nmod_poly context from the cache or create a new one.""" + ctx = _nmod_poly_ctx_cache.get(mod) + if ctx is None: + _nmod_poly_ctx_cache[mod] = ctx = nmod_poly_ctx._new_ctx(mod) + return ctx + + @staticmethod + cdef nmod_poly_ctx _new_ctx(ulong mod): + """Create a new nmod_poly context.""" + cdef nmod_ctx scalar_ctx + cdef nmod_poly_ctx ctx + scalar_ctx = nmod_ctx.new(mod) + + ctx = nmod_poly_ctx.__new__(nmod_poly_ctx) + ctx.mod = scalar_ctx.mod + ctx._is_prime = scalar_ctx._is_prime + ctx.scalar_ctx = scalar_ctx + + return ctx cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1: - return self.ctx.any_as_nmod(val, obj) + return self.scalar_ctx.any_as_nmod(val, obj) cdef any_as_nmod_poly(self, obj): cdef nmod_poly r @@ -75,19 +75,26 @@ cdef class nmod_poly_ctx: # XXX: should check that modulus is the same here, and not all over the place if typecheck(obj, nmod_poly): return obj - if self.ctx.any_as_nmod(&v, obj): - r = nmod_poly_new_init(self) + if self.any_as_nmod(&v, obj): + r = self.new_nmod_poly() nmod_poly_set_coeff_ui(r.val, 0, v) - r.ctx = self return r x = any_as_fmpz_poly(obj) if x is not NotImplemented: - r = nmod_poly_new_init(self) # XXX: create flint _nmod_poly_set_modulus for this? + r = self.new_nmod_poly() fmpz_poly_get_nmod_poly(r.val, (x).val) - r.ctx = self return r return NotImplemented + cdef nmod new_nmod(self): + return self.scalar_ctx.new_nmod() + + cdef nmod_poly new_nmod_poly(self): + cdef nmod_poly p = nmod_poly.__new__(nmod_poly) + nmod_poly_init_preinv(p.val, self.mod.n, self.mod.ninv) + p.ctx = self + return p + cdef nmod_poly_set_list(self, nmod_poly_t poly, list val): cdef long i, n cdef mp_limb_t v @@ -100,6 +107,68 @@ cdef class nmod_poly_ctx: else: raise TypeError("unsupported coefficient in list") + def modulus(self): + """Get the modulus of the context. + + >>> ctx = nmod_poly_ctx.new(17) + >>> ctx.modulus() + 17 + + """ + return fmpz(self.mod.n) + + def is_prime(self): + """Check if the modulus is prime. + + >>> ctx = nmod_poly_ctx.new(17) + >>> ctx.is_prime() + True + + """ + return self._is_prime + + def zero(self): + """Return the zero ``nmod_poly``. + + >>> ctx = nmod_poly_ctx.new(17) + >>> ctx.zero() + 0 + + """ + cdef nmod_poly r = self.new_nmod_poly() + nmod_poly_zero(r.val) + return r + + def one(self): + """Return the one ``nmod_poly``. + + >>> ctx = nmod_poly_ctx.new(17) + >>> ctx.one() + 1 + + """ + cdef nmod_poly r = self.new_nmod_poly() + nmod_poly_set_coeff_ui(r.val, 0, 1) + return r + + def __str__(self): + return f"Context for nmod_poly with modulus: {self.mod.n}" + + def __repr__(self): + return f"nmod_poly_ctx({self.mod.n})" + + def __call__(self, arg): + """Create an ``nmod_poly``. + + >>> ctx = nmod_poly_ctx.new(17) + >>> ctx(10) + 10 + >>> ctx([1,2,3]) + 3*x^2 + 2*x + 1 + + """ + return nmod_poly(arg, self) + cdef class nmod_poly(flint_poly): """ @@ -146,7 +215,7 @@ cdef class nmod_poly(flint_poly): else: if mod == 0: raise ValueError("a nonzero modulus is required") - ctx = any_as_nmod_poly_ctx(mod) + ctx = nmod_poly_ctx.any_as_nmod_poly_ctx(mod) if ctx is NotImplemented: raise TypeError("cannot create nmod_poly_ctx from input of type %s", type(mod)) @@ -281,9 +350,8 @@ cdef class nmod_poly(flint_poly): else: length = nmod_poly_length(self.val) - res = nmod_poly_new_init_preinv(self.ctx) + res = self.ctx.new_nmod_poly() nmod_poly_reverse(res.val, self.val, length) - res.ctx = self.ctx return res def leading_coefficient(self): @@ -304,10 +372,8 @@ cdef class nmod_poly(flint_poly): else: cu = nmod_poly_get_coeff_ui(self.val, d) - c = nmod.__new__(nmod) + c = self.ctx.new_nmod() c.val = cu - c.ctx = self.ctx.ctx - return c def inverse_series_trunc(self, slong n): @@ -329,9 +395,8 @@ cdef class nmod_poly(flint_poly): if self.is_zero(): raise ValueError("cannot invert the zero element") - cdef nmod_poly res = nmod_poly_new_init_preinv(self.ctx) + cdef nmod_poly res = self.ctx.new_nmod_poly() nmod_poly_inv_series(res.val, self.val, n) - res.ctx = self.ctx return res def compose(self, other): @@ -352,9 +417,8 @@ cdef class nmod_poly(flint_poly): other = self.ctx.any_as_nmod_poly(other) if other is NotImplemented: raise TypeError("cannot convert input to nmod_poly") - res = nmod_poly_new_init_preinv(self.ctx) + res = self.ctx.new_nmod_poly() nmod_poly_compose(res.val, self.val, (other).val) - res.ctx = self.ctx return res def compose_mod(self, other, modulus): @@ -385,12 +449,12 @@ cdef class nmod_poly(flint_poly): if modulus.is_zero(): raise ZeroDivisionError("cannot reduce modulo zero") - res = nmod_poly_new_init_preinv(self.ctx) + res = self.ctx.new_nmod_poly() nmod_poly_compose_mod(res.val, self.val, (other).val, (modulus).val) - res.ctx = self.ctx return res def __call__(self, other): + cdef nmod_poly r cdef mp_limb_t c if self.ctx.any_as_nmod(&c, other): v = nmod(0, self.modulus()) @@ -398,31 +462,27 @@ cdef class nmod_poly(flint_poly): return v t = self.ctx.any_as_nmod_poly(other) if t is not NotImplemented: - r = nmod_poly_new_init_preinv(self.ctx) - nmod_poly_compose((r).val, self.val, (t).val) - (r).ctx = self.ctx + r = self.ctx.new_nmod_poly() + nmod_poly_compose(r.val, self.val, (t).val) return r raise TypeError("cannot call nmod_poly with input of type %s", type(other)) def derivative(self): - cdef nmod_poly res = nmod_poly_new_init_preinv(self.ctx) + cdef nmod_poly res = self.ctx.new_nmod_poly() nmod_poly_derivative(res.val, self.val) - res.ctx = self.ctx return res def integral(self): - cdef nmod_poly res = nmod_poly_new_init_preinv(self.ctx) + cdef nmod_poly res = self.ctx.new_nmod_poly() nmod_poly_integral(res.val, self.val) - res.ctx = self.ctx return res def __pos__(self): return self def __neg__(self): - cdef nmod_poly r = nmod_poly_new_init_preinv(self.ctx) + cdef nmod_poly r = self.ctx.new_nmod_poly() nmod_poly_neg(r.val, self.val) - r.ctx = self.ctx return r def _add_(s, t): @@ -432,9 +492,8 @@ cdef class nmod_poly(flint_poly): return t if (s).val.mod.n != (t).val.mod.n: raise ValueError("cannot add nmod_polys with different moduli") - r = nmod_poly_new_init_preinv(s.ctx) + r = s.ctx.new_nmod_poly() nmod_poly_add(r.val, (s).val, (t).val) - r.ctx = s.ctx return r def __add__(s, t): @@ -447,9 +506,8 @@ cdef class nmod_poly(flint_poly): cdef nmod_poly r if (s).val.mod.n != (t).val.mod.n: raise ValueError("cannot subtract nmod_polys with different moduli") - r = nmod_poly_new_init_preinv(s.ctx) + r = s.ctx.new_nmod_poly() nmod_poly_sub(r.val, (s).val, (t).val) - r.ctx = s.ctx return r def __sub__(s, t): @@ -471,9 +529,8 @@ cdef class nmod_poly(flint_poly): return t if (s).val.mod.n != (t).val.mod.n: raise ValueError("cannot multiply nmod_polys with different moduli") - r = nmod_poly_new_init_preinv(s.ctx) + r = s.ctx.new_nmod_poly() nmod_poly_mul(r.val, (s).val, (t).val) - r.ctx = s.ctx return r def __mul__(s, t): @@ -510,9 +567,8 @@ cdef class nmod_poly(flint_poly): if not s.ctx._is_prime: raise DomainError("nmod_poly divmod: modulus {self.ctx.mod.n} is not prime") - r = nmod_poly_new_init_preinv(s.ctx) + r = s.ctx.new_nmod_poly() nmod_poly_div(r.val, (s).val, (t).val) - r.ctx = s.ctx return r def __floordiv__(s, t): @@ -537,8 +593,8 @@ cdef class nmod_poly(flint_poly): if not s.ctx._is_prime: raise DomainError("nmod_poly divmod: modulus {self.ctx.mod.n} is not prime") - P = nmod_poly_new_init_preinv(s.ctx) - Q = nmod_poly_new_init_preinv(s.ctx) + P = s.ctx.new_nmod_poly() + Q = s.ctx.new_nmod_poly() nmod_poly_divrem(P.val, Q.val, (s).val, (t).val) return P, Q @@ -566,9 +622,8 @@ cdef class nmod_poly(flint_poly): return self.pow_mod(exp, mod) if exp < 0: raise ValueError("negative exponent") - res = nmod_poly_new_init_preinv(self.ctx) + res = self.ctx.new_nmod_poly() nmod_poly_pow(res.val, self.val, exp) - res.ctx = self.ctx return res def pow_mod(self, e, modulus, mod_rev_inv=None): @@ -603,7 +658,7 @@ cdef class nmod_poly(flint_poly): raise TypeError("cannot convert input to nmod_poly") # Output polynomial - res = nmod_poly_new_init_preinv(self.ctx) + res = self.ctx.new_nmod_poly() # For small exponents, use a simple binary exponentiation method if e.bit_length() < 32: @@ -658,9 +713,8 @@ cdef class nmod_poly(flint_poly): if not self.ctx._is_prime: raise DomainError("nmod_poly gcd: modulus {self.ctx.mod.n} is not prime") - res = nmod_poly_new_init_preinv(self.ctx) + res = self.ctx.new_nmod_poly() nmod_poly_gcd(res.val, self.val, (other).val) - res.ctx = self.ctx return res def xgcd(self, other): @@ -687,9 +741,9 @@ cdef class nmod_poly(flint_poly): if not self.ctx._is_prime: raise DomainError("nmod_poly xgcd: modulus {self.ctx.mod.n} is not prime") - res1 = nmod_poly_new_init(self.ctx) - res2 = nmod_poly_new_init(self.ctx) - res3 = nmod_poly_new_init(self.ctx) + res1 = self.ctx.new_nmod_poly() + res2 = self.ctx.new_nmod_poly() + res3 = self.ctx.new_nmod_poly() nmod_poly_xgcd(res1.val, res2.val, res3.val, self.val, (other).val) @@ -770,13 +824,12 @@ cdef class nmod_poly(flint_poly): res = [None] * fac.num for 0 <= i < fac.num: - u = nmod_poly_new_init_preinv(self.ctx) + u = self.ctx.new_nmod_poly() nmod_poly_set(u.val, &fac.p[i]) exp = fac.exp[i] res[i] = (u, exp) - c = nmod.__new__(nmod) - c.ctx = self.ctx.ctx + c = self.ctx.new_nmod() c.val = lead nmod_poly_factor_clear(fac) @@ -790,7 +843,7 @@ cdef class nmod_poly(flint_poly): if not self.ctx._is_prime: raise DomainError(f"nmod_poly sqrt: modulus {self.ctx.mod.n} is not prime") - res = nmod_poly_new_init_preinv(self.ctx) + res = self.ctx.new_nmod_poly() if not nmod_poly_sqrt(res.val, self.val): raise DomainError(f"Cannot compute square root of {self}") @@ -806,9 +859,8 @@ cdef class nmod_poly(flint_poly): if n == 1: return self, int(n) else: - v = nmod_poly_new_init(self.ctx) + v = self.ctx.new_nmod_poly() nmod_poly_deflate(v.val, self.val, n) - v.ctx = self.ctx return v, int(n) def real_roots(self): From c454681950bece8992967b4387715dd3021b90f5 Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Tue, 27 Aug 2024 21:19:13 +0100 Subject: [PATCH 13/30] Use no_gc for nmod --- src/flint/types/nmod.pxd | 1 + src/flint/types/nmod.pyx | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/flint/types/nmod.pxd b/src/flint/types/nmod.pxd index 0ed57a7a..666c4a09 100644 --- a/src/flint/types/nmod.pxd +++ b/src/flint/types/nmod.pxd @@ -23,6 +23,7 @@ cdef class nmod_ctx: cdef nmod new_nmod(self) +@cython.no_gc cdef class nmod(flint_scalar): cdef mp_limb_t val cdef nmod_ctx ctx diff --git a/src/flint/types/nmod.pyx b/src/flint/types/nmod.pyx index 378f79bc..fd7b1c5a 100644 --- a/src/flint/types/nmod.pyx +++ b/src/flint/types/nmod.pyx @@ -22,6 +22,7 @@ from flint.utils.flint_exceptions import DomainError cdef dict _nmod_ctx_cache = {} +@cython.no_gc cdef class nmod_ctx: """ Context object for creating :class:`~.nmod` initalised @@ -163,6 +164,7 @@ cdef class nmod_ctx: return r +@cython.no_gc cdef class nmod(flint_scalar): """ The nmod type represents elements of Z/nZ for word-size n. From 9062d323eafaa690f491a587a9dabaea431b634e Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Tue, 27 Aug 2024 21:33:21 +0100 Subject: [PATCH 14/30] Use setdefault for ctx caches --- src/flint/types/nmod.pyx | 2 +- src/flint/types/nmod_mat.pyx | 2 +- src/flint/types/nmod_poly.pyx | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/flint/types/nmod.pyx b/src/flint/types/nmod.pyx index fd7b1c5a..246c4064 100644 --- a/src/flint/types/nmod.pyx +++ b/src/flint/types/nmod.pyx @@ -62,7 +62,7 @@ cdef class nmod_ctx: """Retrieve an nmod context from the cache or create a new one.""" ctx = _nmod_ctx_cache.get(mod) if ctx is None: - _nmod_ctx_cache[mod] = ctx = nmod_ctx._new_ctx(mod) + ctx = _nmod_ctx_cache.setdefault(mod, nmod_ctx._new_ctx(mod)) return ctx @staticmethod diff --git a/src/flint/types/nmod_mat.pyx b/src/flint/types/nmod_mat.pyx index e015cef4..9cebaa5f 100644 --- a/src/flint/types/nmod_mat.pyx +++ b/src/flint/types/nmod_mat.pyx @@ -94,7 +94,7 @@ cdef class nmod_mat_ctx: """Retrieve an nmod_mat context from the cache or create a new one.""" ctx = _nmod_mat_ctx_cache.get(mod) if ctx is None: - _nmod_mat_ctx_cache[mod] = ctx = nmod_mat_ctx._new_ctx(mod) + ctx = _nmod_mat_ctx_cache.setdefault(mod, nmod_mat_ctx._new_ctx(mod)) return ctx @staticmethod diff --git a/src/flint/types/nmod_poly.pyx b/src/flint/types/nmod_poly.pyx index 1375d4a6..9ff98d4d 100644 --- a/src/flint/types/nmod_poly.pyx +++ b/src/flint/types/nmod_poly.pyx @@ -49,7 +49,7 @@ cdef class nmod_poly_ctx: """Retrieve an nmod_poly context from the cache or create a new one.""" ctx = _nmod_poly_ctx_cache.get(mod) if ctx is None: - _nmod_poly_ctx_cache[mod] = ctx = nmod_poly_ctx._new_ctx(mod) + ctx = _nmod_poly_ctx_cache.setdefault(mod, nmod_poly_ctx._new_ctx(mod)) return ctx @staticmethod From 0fa5d420e8911088de95785aa4c391834966f9e9 Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Tue, 27 Aug 2024 22:13:27 +0100 Subject: [PATCH 15/30] pin cython commit for the coverage job --- .github/workflows/buildwheel.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/buildwheel.yml b/.github/workflows/buildwheel.yml index cde98caf..e7dc46a3 100644 --- a/.github/workflows/buildwheel.yml +++ b/.github/workflows/buildwheel.yml @@ -193,7 +193,7 @@ jobs: # Test that we can make a coverage build and report coverage test_coverage_build: - name: Test coverage setuptools build + name: Test coverage build runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 @@ -202,9 +202,11 @@ jobs: python-version: '3.12' - run: sudo apt-get update - run: sudo apt-get install libflint-dev - # Need Cython's master branch until 3.1 is released because of: + # Need Cython's master branch until 3.1 is released because we need: # https://github.com/cython/cython/pull/6341 - - run: pip install git+https://github.com/cython/cython.git@master + # Except now we can't use the master branch any more because of: + # https://github.com/cython/cython/issues/6366 + - run: pip install git+https://github.com/cython/cython.git@accde54653217dc28c52befd10d183941ba28336 - run: pip install -r requirements-dev.txt - run: bin/coverage.sh From 3abe2836bdbd0e66eead527325e8b2ea1e909734 Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Tue, 27 Aug 2024 22:26:04 +0100 Subject: [PATCH 16/30] Use an earlier Cython commit... --- .github/workflows/buildwheel.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/buildwheel.yml b/.github/workflows/buildwheel.yml index e7dc46a3..e8d3b399 100644 --- a/.github/workflows/buildwheel.yml +++ b/.github/workflows/buildwheel.yml @@ -206,7 +206,7 @@ jobs: # https://github.com/cython/cython/pull/6341 # Except now we can't use the master branch any more because of: # https://github.com/cython/cython/issues/6366 - - run: pip install git+https://github.com/cython/cython.git@accde54653217dc28c52befd10d183941ba28336 + - run: pip install git+https://github.com/cython/cython.git@20bceea6b19ffc2f65b9fba2e4f737f09e5a2b20 - run: pip install -r requirements-dev.txt - run: bin/coverage.sh From f7d27a55ba196d02aeae53e95beb5e0f3c3325f3 Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Tue, 27 Aug 2024 22:40:43 +0100 Subject: [PATCH 17/30] Use Oscar's cython branch for coverage again --- .github/workflows/buildwheel.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/buildwheel.yml b/.github/workflows/buildwheel.yml index e8d3b399..eba6f8c0 100644 --- a/.github/workflows/buildwheel.yml +++ b/.github/workflows/buildwheel.yml @@ -206,7 +206,8 @@ jobs: # https://github.com/cython/cython/pull/6341 # Except now we can't use the master branch any more because of: # https://github.com/cython/cython/issues/6366 - - run: pip install git+https://github.com/cython/cython.git@20bceea6b19ffc2f65b9fba2e4f737f09e5a2b20 + # So we have to keep using Oscar's PR branch: + - run: pip install git+https://github.com/oscarbenjamin/cython.git@pr_relative_paths - run: pip install -r requirements-dev.txt - run: bin/coverage.sh From 5b16c20efa31888cf3ef6c65e6d49f2eaf4713a5 Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Wed, 28 Aug 2024 12:42:50 +0100 Subject: [PATCH 18/30] test: add tests for fmpz and fmpz_mat --- .coveragerc | 4 ++++ src/flint/test/test_all.py | 13 +++++++++++-- src/flint/types/fmpz.pyx | 13 ++++--------- src/flint/types/fmpz_mat.pyx | 20 +++++++++++--------- 4 files changed, 30 insertions(+), 20 deletions(-) diff --git a/.coveragerc b/.coveragerc index a72fd709..74a1213f 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,2 +1,6 @@ [run] plugins = coverage_plugin + +[report] +exclude_lines = + assert False diff --git a/src/flint/test/test_all.py b/src/flint/test/test_all.py index 84a7ec26..96e5a454 100644 --- a/src/flint/test/test_all.py +++ b/src/flint/test/test_all.py @@ -21,6 +21,11 @@ def raises(f, exception): return False +def test_raises(): + assert raises(lambda: 1/0, ZeroDivisionError) is True + assert raises(lambda: 1/1, ZeroDivisionError) is False + + _default_ctx_string = """\ pretty = True # pretty-print repr() output unicode = False # use unicode characters in output @@ -153,8 +158,9 @@ def test_fmpz(): # https://github.com/flintlib/python-flint/issues/74 if not PYPY: assert pow(a, flint.fmpz(b), c) == ab_mod_c - assert pow(a, b, flint.fmpz(c)) == ab_mod_c assert pow(a, flint.fmpz(b), flint.fmpz(c)) == ab_mod_c + assert pow(a, b, flint.fmpz(c)) == ab_mod_c + assert raises(lambda: pow([], flint.fmpz(2), 2), TypeError) assert raises(lambda: pow(flint.fmpz(2), 2, 0), ValueError) # XXX: Handle negative modulus like int? @@ -602,7 +608,8 @@ def set_bad(i,j): assert raises(lambda: M([[1,1],[1,1]]).solve(b), ZeroDivisionError) assert raises(lambda: M([[1,2],[3,4],[5,6]]).solve(b), ValueError) assert M([[1,0],[1,2]]).solve(b) == flint.fmpq_mat([[3],[2]]) - assert raises(lambda: M([[1,0],[1,2]]).solve(b, integer=True), ValueError) + assert raises(lambda: M([[1,0],[1,0]]).solve(b, integer=True), ZeroDivisionError) + assert raises(lambda: M([[1,0],[1,2]]).solve(b, integer=True), DomainError) assert raises(lambda: M([[1,2,3],[4,5,6]]).inv(), ValueError) assert raises(lambda: M([[1,1],[1,1]]).inv(), ZeroDivisionError) assert raises(lambda: M([[1,0],[1,2]]).inv(integer=True), ValueError) @@ -632,6 +639,7 @@ def set_bad(i,j): for gram in "approx", "exact": assert M4.lll(rep=rep, gram=gram) == L4 assert M4.lll(rep=rep, gram=gram, transform=True) == (L4, T4) + assert raises(lambda: M4.lll(rep="gram"), AssertionError) assert raises(lambda: M4.lll(rep="bad"), ValueError) assert raises(lambda: M4.lll(gram="bad"), ValueError) M5 = M([[1,2,3],[4,5,6]]) @@ -4561,6 +4569,7 @@ def test_all_tests(): all_tests = [ + test_raises, test_pyflint, test_showgood, diff --git a/src/flint/types/fmpz.pyx b/src/flint/types/fmpz.pyx index 71270b75..cf51f6b3 100644 --- a/src/flint/types/fmpz.pyx +++ b/src/flint/types/fmpz.pyx @@ -395,18 +395,15 @@ cdef class fmpz(flint_scalar): return u def __pow__(s, t, m): - cdef fmpz_struct sval[1] - cdef fmpz_struct tval[1] - cdef fmpz_struct mval[1] - cdef int stype = FMPZ_UNKNOWN + cdef fmpz_t tval + cdef fmpz_t mval cdef int ttype = FMPZ_UNKNOWN cdef int mtype = FMPZ_UNKNOWN cdef int success u = NotImplemented try: - stype = fmpz_set_any_ref(sval, s) - if stype == FMPZ_UNKNOWN: + if not typecheck(s, fmpz): return NotImplemented ttype = fmpz_set_any_ref(tval, t) if ttype == FMPZ_UNKNOWN: @@ -440,12 +437,10 @@ cdef class fmpz(flint_scalar): raise ValueError("pow(): negative modulus not supported") u = fmpz.__new__(fmpz) - fmpz_powm((u).val, sval, tval, mval) + fmpz_powm((u).val, s.val, tval, mval) return u finally: - if stype == FMPZ_TMP: - fmpz_clear(sval) if ttype == FMPZ_TMP: fmpz_clear(tval) if mtype == FMPZ_TMP: diff --git a/src/flint/types/fmpz_mat.pyx b/src/flint/types/fmpz_mat.pyx index aae9ae9a..530cafd1 100644 --- a/src/flint/types/fmpz_mat.pyx +++ b/src/flint/types/fmpz_mat.pyx @@ -300,17 +300,16 @@ cdef class fmpz_mat(flint_mat): def __pow__(self, e, m): cdef fmpz_mat t cdef ulong ee - if not typecheck(self, fmpz_mat): - return NotImplemented - if not fmpz_mat_is_square((self).val): + if not fmpz_mat_is_square(self.val): raise ValueError("matrix must be square") if m is not None: raise NotImplementedError("modular matrix exponentiation") if e < 0: + # Allow unimodular? raise DomainError("negative power of integer matrix: M**%i" % e) ee = e t = fmpz_mat.__new__(fmpz_mat) - fmpz_mat_init_set(t.val, (self).val) + fmpz_mat_init_set(t.val, self.val) fmpz_mat_pow(t.val, t.val, ee) return t @@ -518,7 +517,7 @@ cdef class fmpz_mat(flint_mat): >>> A.solve(B, integer=True) Traceback (most recent call last): ... - ValueError: matrix is not invertible over the integers + flint.utils.flint_exceptions.DomainError: matrix is not invertible over the integers >>> fmpz_mat([[1,2], [3,5]]).solve(B, integer=True) [ 6, 3, 0] [-3, -1, 1] @@ -554,11 +553,11 @@ cdef class fmpz_mat(flint_mat): fmpz_mat_ncols((t).val)) d = fmpz.__new__(fmpz) result = fmpz_mat_solve(u.val, d.val, self.val, (t).val) - if not fmpz_is_pm1(d.val): - raise ValueError("matrix is not invertible over the integers") - u *= d if not result: raise ZeroDivisionError("singular matrix in solve()") + if not fmpz_is_pm1(d.val): + raise DomainError("matrix is not invertible over the integers") + u *= d return u def rref(self, inplace=False): @@ -640,7 +639,10 @@ cdef class fmpz_mat(flint_mat): if rep == "zbasis": rt = 1 elif rep == "gram": - rt = 0 + # rt = 0 + # XXX: This consumes all memory and crashes. Maybe the parameters + # need to be different or something? Best to disable this for now. + assert False, "rep = 'gram' does not work currently." else: raise ValueError("rep must be 'zbasis' or 'gram'") if gram == "approx": From 9b3d7c4c687c2808a62a793cdeca3e7356eb7e25 Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Wed, 28 Aug 2024 13:20:55 +0100 Subject: [PATCH 19/30] test: add polys tests for div/sqrt --- src/flint/test/test_all.py | 17 ++++++++++++++++- src/flint/types/fmpq_poly.pyx | 14 ++------------ src/flint/types/fmpz_mod_poly.pyx | 4 ---- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/src/flint/test/test_all.py b/src/flint/test/test_all.py index 96e5a454..5fe4b415 100644 --- a/src/flint/test/test_all.py +++ b/src/flint/test/test_all.py @@ -933,6 +933,7 @@ def test_fmpq_poly(): assert raises(lambda: Q([1,[]]), TypeError) assert raises(lambda: Q({}), TypeError) assert raises(lambda: Q([1], []), TypeError) + assert raises(lambda: Q(1, 1, 1), TypeError) assert raises(lambda: Q([1], 0), ZeroDivisionError) assert bool(Q()) == False assert bool(Q([1])) == True @@ -2735,6 +2736,8 @@ def setbad(obj, i, val): assert raises(lambda: 1 / P([1, 1]), DomainError) assert raises(lambda: P([1, 2, 1]) / P([1, 2]), DomainError) + assert raises(lambda: [] / P([1, 1]), TypeError) + assert raises(lambda: P([1, 1]) / [], TypeError) if is_field: assert P([1, 1]) // 2 == P([S(1)/2, S(1)/2]) @@ -2827,7 +2830,7 @@ def setbad(obj, i, val): assert raises(lambda: P([1, 2, 2]).sqrt(), DomainError) if P == flint.fmpq_poly: - assert raises(lambda: P([1, 2, 1], 3).sqrt(), ValueError) + assert raises(lambda: P([1, 2, 1], 3).sqrt(), DomainError) assert P([1, 2, 1], 4).sqrt() == P([1, 1], 2) assert P([]).deflation() == (P([]), 1) @@ -3471,6 +3474,18 @@ def factor_sqf(p): assert S(1).sqrt() == S(1) assert S(4).sqrt()**2 == S(4) + if is_field: + for n in range(1, 10): + try: + sqrtn = S(n).sqrt() + except DomainError: + sqrtn = None + if sqrtn is None: + assert raises(lambda: ((x + 1)**2/n).sqrt(), DomainError) + else: + assert ((x + 1)**2/n).sqrt() ** 2 == (x + 1)**2/n + assert raises(lambda: ((x**2 + 1)/n).sqrt(), DomainError) + for i in range(-100, 100): try: assert S(i).sqrt() ** 2 == S(i) diff --git a/src/flint/types/fmpq_poly.pyx b/src/flint/types/fmpq_poly.pyx index f9e6df48..9c266c3d 100644 --- a/src/flint/types/fmpq_poly.pyx +++ b/src/flint/types/fmpq_poly.pyx @@ -120,9 +120,6 @@ cdef class fmpq_poly(flint_poly): cdef bint r if op != 2 and op != 3: raise TypeError("polynomials cannot be ordered") - self = any_as_fmpq_poly(self) - if self is NotImplemented: - return self other = any_as_fmpq_poly(other) if other is NotImplemented: return other @@ -476,15 +473,8 @@ cdef class fmpq_poly(flint_poly): 1/2*x + 1/2 """ - d = self.denom() - n = self.numer() - d, r = d.sqrtrem() - if r != 0: - raise ValueError(f"Cannot compute square root of {self}") - n = n.sqrt() - if n is None: - raise ValueError(f"Cannot compute square root of {self}") - return fmpq_poly(n, d) + d = self.denom().sqrt() + return fmpq_poly(self.numer().sqrt(), d) def deflation(self): num, n = self.numer().deflation() diff --git a/src/flint/types/fmpz_mod_poly.pyx b/src/flint/types/fmpz_mod_poly.pyx index 69e17e71..434275d4 100644 --- a/src/flint/types/fmpz_mod_poly.pyx +++ b/src/flint/types/fmpz_mod_poly.pyx @@ -433,10 +433,6 @@ cdef class fmpz_mod_poly(flint_poly): def _div_(self, other): cdef fmpz_mod_poly res - other = self.ctx.mod.any_as_fmpz_mod(other) - if other is NotImplemented: - return NotImplemented - if other == 0: raise ZeroDivisionError("Cannot divide by zero") elif not other.is_unit(): From e7c1418362f97891cebc874f626e31c0616ac028 Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Wed, 28 Aug 2024 14:26:05 +0100 Subject: [PATCH 20/30] test: handle excluded lines in coverage plugin --- .coveragerc | 2 +- coverage_plugin.py | 22 +++++++++++++++------- src/flint/types/fmpz_poly.pyx | 9 +++------ 3 files changed, 19 insertions(+), 14 deletions(-) diff --git a/.coveragerc b/.coveragerc index 74a1213f..3c42fcfb 100644 --- a/.coveragerc +++ b/.coveragerc @@ -2,5 +2,5 @@ plugins = coverage_plugin [report] -exclude_lines = +exclude_also = assert False diff --git a/coverage_plugin.py b/coverage_plugin.py index 8382dc26..69c9aeeb 100644 --- a/coverage_plugin.py +++ b/coverage_plugin.py @@ -62,7 +62,7 @@ def get_cython_build_rules(): @cache -def parse_all_cfile_lines(): +def parse_all_cfile_lines(exclude_patterns=None): """Parse all generated C files from the build directory.""" # # Each .c file can include code generated from multiple Cython files (e.g. @@ -80,7 +80,7 @@ def parse_all_cfile_lines(): for c_file, _ in get_cython_build_rules(): - cfile_lines = parse_cfile_lines(c_file) + cfile_lines = parse_cfile_lines(c_file, exclude_patterns=exclude_patterns) for cython_file, line_map in cfile_lines.items(): if cython_file == '(tree fragment)': @@ -94,15 +94,22 @@ def parse_all_cfile_lines(): return all_code_lines -def parse_cfile_lines(c_file): +def parse_cfile_lines(c_file, exclude_patterns=None): """Use Cython's coverage plugin to parse the C code.""" from Cython.Coverage import Plugin - return Plugin()._parse_cfile_lines(c_file) + p = Plugin() + p._excluded_line_patterns = list(exclude_patterns) + return p._parse_cfile_lines(c_file) class Plugin(CoveragePlugin): """A coverage plugin for a spin/meson project with Cython code.""" + def configure(self, config): + # Entry point for coverage "configurer". + # Read the regular expressions from the coverage config that match lines to be excluded from coverage. + self.exclude_patterns = tuple(config.get_option("report:exclude_lines")) + def file_tracer(self, filename): """Find a tracer for filename to handle trace events.""" path = Path(filename) @@ -121,7 +128,7 @@ def file_tracer(self, filename): def file_reporter(self, filename): """Return a file reporter for filename.""" srcfile = Path(filename).relative_to(src_dir) - return CyFileReporter(srcfile) + return CyFileReporter(srcfile, exclude_patterns=self.exclude_patterns) class CyFileTracer(FileTracer): @@ -157,7 +164,7 @@ def get_source_filename(filename): class CyFileReporter(FileReporter): """File reporter for Cython or Python files (.pyx,.pxd,.py).""" - def __init__(self, srcpath): + def __init__(self, srcpath, exclude_patterns): abspath = (src_dir / srcpath) assert abspath.exists() @@ -165,6 +172,7 @@ def __init__(self, srcpath): super().__init__(str(abspath)) self.srcpath = srcpath + self.exclude_patterns = exclude_patterns def relative_filename(self): """Path displayed in the coverage reports.""" @@ -173,7 +181,7 @@ def relative_filename(self): def lines(self): """Set of line numbers for possibly traceable lines.""" srcpath = str(self.srcpath) - all_line_maps = parse_all_cfile_lines() + all_line_maps = parse_all_cfile_lines(exclude_patterns=self.exclude_patterns) line_map = all_line_maps[srcpath] return set(line_map) diff --git a/src/flint/types/fmpz_poly.pyx b/src/flint/types/fmpz_poly.pyx index ec9b761f..e1468a9a 100644 --- a/src/flint/types/fmpz_poly.pyx +++ b/src/flint/types/fmpz_poly.pyx @@ -106,9 +106,6 @@ cdef class fmpz_poly(flint_poly): cdef bint r if op != 2 and op != 3: raise TypeError("polynomials cannot be ordered") - self = any_as_fmpz_poly(self) - if self is NotImplemented: - return self other = any_as_fmpz_poly(other) if other is NotImplemented: return other @@ -455,7 +452,7 @@ cdef class fmpz_poly(flint_poly): return [] flags = 0 if verbose: - flags = 1 + flags = 1 # pragma: no cover roots = [] fmpz_poly_factor_init(fac) fmpz_poly_factor_squarefree(fac, self.val) @@ -550,8 +547,8 @@ cdef class fmpz_poly(flint_poly): arb_poly_init(t) arb_poly_swinnerton_dyer_ui(t, n, 0) if not arb_poly_get_unique_fmpz_poly((u).val, t): - arb_poly_clear(t) - raise ValueError("insufficient precision") + arb_poly_clear(t) # pragma: no cover + raise ValueError("insufficient precision") # pragma: no cover arb_poly_clear(t) else: fmpz_poly_swinnerton_dyer((u).val, n) From 0fa1dbfd113d250885d41e37bef60bf13117aa34 Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Wed, 28 Aug 2024 14:39:22 +0100 Subject: [PATCH 21/30] fmpq: use fmpq_cmp for <,<=,>,>= --- src/flint/test/test_all.py | 1 + src/flint/types/fmpq.pyx | 19 ++++++++----------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/flint/test/test_all.py b/src/flint/test/test_all.py index 5fe4b415..d5f1034f 100644 --- a/src/flint/test/test_all.py +++ b/src/flint/test/test_all.py @@ -772,6 +772,7 @@ def test_fmpq(): assert raises(lambda: Q([]), TypeError) assert raises(lambda: Q(1, []), TypeError) assert raises(lambda: Q([], 1), TypeError) + assert raises(lambda: Q(1, 1, 1), TypeError) assert bool(Q(0)) == False assert bool(Q(1)) == True assert Q(1,3) + Q(2,3) == 1 diff --git a/src/flint/types/fmpq.pyx b/src/flint/types/fmpq.pyx index c4569426..1016f4f3 100644 --- a/src/flint/types/fmpq.pyx +++ b/src/flint/types/fmpq.pyx @@ -107,9 +107,6 @@ cdef class fmpq(flint_scalar): def __richcmp__(s, t, int op): cdef bint res - s = any_as_fmpq(s) - if s is NotImplemented: - return s t = any_as_fmpq(t) if t is NotImplemented: return t @@ -119,17 +116,17 @@ cdef class fmpq(flint_scalar): res = not res return res else: - # todo: use fmpq_cmp when available + res = fmpq_cmp(s.val, (t).val) if op == 0: - res = (s-t).p < 0 + res = res < 0 elif op == 1: - res = (s-t).p <= 0 + res = res <= 0 elif op == 4: - res = (s-t).p > 0 + res = res > 0 elif op == 5: - res = (s-t).p >= 0 + res = res >= 0 else: - raise ValueError + assert False return res def numer(self): @@ -442,9 +439,9 @@ cdef class fmpq(flint_scalar): import sys from fractions import Fraction if sys.version_info < (3, 12): - return hash(Fraction(int(self.p), int(self.q), _normalize=False)) + return hash(Fraction(int(self.p), int(self.q), _normalize=False)) # pragma: no cover else: - return hash(Fraction._from_coprime_ints(int(self.p), int(self.q))) + return hash(Fraction._from_coprime_ints(int(self.p), int(self.q))) # pragma: no cover def height_bits(self, bint signed=False): """ From b605f9710b8ec8b2a12f71a2979e5f35f9bfd094 Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Wed, 28 Aug 2024 18:03:19 +0100 Subject: [PATCH 22/30] Full coverage of nmod_poly and various fixes. --- src/flint/test/test_all.py | 99 ++++++++++++++++++++++++++++- src/flint/types/fmpq_mat.pyx | 8 +-- src/flint/types/fmpz_mod_mat.pyx | 4 +- src/flint/types/fmpz_mod_poly.pyx | 14 +++- src/flint/types/fq_default_poly.pyx | 3 + src/flint/types/nmod.pxd | 6 +- src/flint/types/nmod.pyx | 19 +++--- src/flint/types/nmod_poly.pxd | 2 +- src/flint/types/nmod_poly.pyx | 67 ++++++++++--------- 9 files changed, 160 insertions(+), 62 deletions(-) diff --git a/src/flint/test/test_all.py b/src/flint/test/test_all.py index d5f1034f..129bd7cf 100644 --- a/src/flint/test/test_all.py +++ b/src/flint/test/test_all.py @@ -1113,17 +1113,27 @@ def set_bad(i): raises(lambda: Q(1,2,[3,4]) * Q(1,3,[5,6,7]), ValueError) raises(lambda: Q(1,2,[3,4]) * Z(1,3,[5,6,7]), ValueError) raises(lambda: Z(1,2,[3,4]) * Q(1,3,[5,6,7]), ValueError) + A = Q([[3,4],[5,7]]) / 11 X = Q([[1,2],[3,4]]) B = A*X assert A.solve(B) == X for algorithm in None, "fflu", "dixon": assert A.solve(B, algorithm=algorithm) == X + for _ in range(2): + A = Q(flint.fmpz_mat.randtest(30, 30, 10)) + if A.det() == 0: + continue + B = Q(flint.fmpz_mat.randtest(30, 1, 10)) + X = A.solve(B) + assert A*X == B + assert raises(lambda: A.solve(B, algorithm="invalid"), ValueError) assert raises(lambda: A.solve(None), TypeError) assert raises(lambda: A.solve([1,2]), TypeError) assert raises(lambda: A.solve(Q([[1,2]])), ValueError) assert raises(lambda: Q([[1,2],[2,4]]).solve(Q([[1],[2]])), ZeroDivisionError) + M = Q([[1,2,3],[flint.fmpq(1,2),5,6]]) Mcopy = Q(M) Mrref = Q([[1,0,flint.fmpq(3,4)],[0,1,flint.fmpq(9,8)]]) @@ -1466,9 +1476,15 @@ def set_bad2(): assert raises(set_bad2, TypeError) assert bool(P([], 5)) is False assert bool(P([1], 5)) is True + assert P([1,2,1],3).gcd(P([1,1],3)) == P([1,1],3) - raises(lambda: P([1,2],3).gcd([]), TypeError) - raises(lambda: P([1,2],3).gcd(P([1,2],5)), ValueError) + assert raises(lambda: P([1,2],3).gcd([]), TypeError) + assert raises(lambda: P([1,2],3).gcd(P([1,2],5)), ValueError) + assert P([1,2,1],3).xgcd(P([1,1],3)) == (P([1, 1], 3), P([0], 3), P([1], 3)) + assert raises(lambda: P([1,2],3).xgcd([]), TypeError) + assert raises(lambda: P([1,2],3).xgcd(P([1,2],5)), ValueError) + assert raises(lambda: P([1,2],6).xgcd(P([1,2],6)), DomainError) + p3 = P([1,2,3,4,5,6],7) f3 = (N(6,7), [(P([6, 1],7), 5)]) assert p3.factor() == f3 @@ -1476,6 +1492,8 @@ def set_bad2(): for alg in [None, 'berlekamp', 'cantor-zassenhaus']: assert p3.factor(alg) == f3 assert p3.factor(algorithm=alg) == f3 + assert raises(lambda: p3.factor(algorithm="invalid"), ValueError) + assert P([1], 11).roots() == [] assert P([1, 2, 3], 11).roots() == [(8, 1), (6, 1)] assert P([1, 6, 1, 8], 11).roots() == [(5, 3)] @@ -1608,6 +1626,33 @@ def test_nmod_series(): # XXX: currently no code in nmod_series.pyx pass + +def test_nmod_contexts(): + # XXX: Generalise this test to cover fmpz_mod, fq_default, etc. + C = flint.nmod_ctx + CP = flint.nmod_poly_ctx + G = flint.nmod + P = flint.nmod_poly + + for c, name in [(C, 'nmod'), (CP, 'nmod_poly')]: + ctx = c.new(17) + assert ctx.modulus() == 17 + assert str(ctx) == f"Context for {name} with modulus: 17" + assert repr(ctx) == f"{name}_ctx(17)" + assert raises(lambda: c(3), TypeError) + assert raises(lambda: ctx.new(3.0), TypeError) + + ctx = C.new(17) + assert ctx(3) == G(3,17) == G(3, ctx) + assert raises(lambda: ctx(3.0), TypeError) + assert raises(lambda: G(3, []), TypeError) + + ctx_poly = CP.new(17) + assert ctx_poly([1,2,3]) == P([1,2,3],17) == P([1,2,3], ctx_poly) + assert raises(lambda: ctx_poly([1,2.0,3]), TypeError) + assert raises(lambda: P([1,2,3], []), TypeError) + + def test_arb(): A = flint.arb assert A(3) > A(2.5) @@ -2211,7 +2256,7 @@ def test_fmpz_mod_poly(): f_inv = f.inverse_series_trunc(2) assert (f * f_inv) % R_test([0,0,1]) == 1 - assert raises(lambda: R_cmp([0,0,1]).inverse_series_trunc(2), ValueError) + assert raises(lambda: R_cmp([0,0,1]).inverse_series_trunc(2), ZeroDivisionError) # Resultant f1 = R_test([-3, 1]) @@ -2846,6 +2891,50 @@ def setbad(obj, i, val): if type(p) == flint.fq_default_poly: assert raises(lambda: p.integral(), NotImplementedError) + if characteristic == 0: + assert not hasattr(P(0), "inverse_series_trunc") + elif composite_characteristic: + x = P([0, 1]) + if type(x) is flint.fmpz_mod_poly: + assert (1 + x).inverse_series_trunc(4) == 1 - x + x**2 - x**3 + if characteristic.gcd(3) != 1: + assert (3 + x).inverse_series_trunc(4) == 1 - x + x**2 - x**3 + else: + assert (3 + x).inverse_series_trunc(4)\ + == S(1)/3 - S(1)/9*x + S(1)/27*x**2 - S(1)/81*x**3 + elif type(x) is flint.nmod_poly: + assert raises(lambda: (1 + x).inverse_series_trunc(4), DomainError) + else: + assert False + else: + x = P([0, 1]) + assert (1 + x).inverse_series_trunc(4) == 1 - x + x**2 - x**3 + assert (3 + x).inverse_series_trunc(4)\ + == S(1)/3 - S(1)/9*x + S(1)/27*x**2 - S(1)/81*x**3 + assert raises(lambda: (1 + x).inverse_series_trunc(-1), ValueError) + assert raises(lambda: x.inverse_series_trunc(4), ZeroDivisionError) + + if characteristic == 0: + assert not hasattr(P(0), "pow_mod") + elif composite_characteristic: + pass + else: + x = P([0, 1]) + assert (1 + x).pow_mod(4, x**2 + 1) == -4 + assert (3 + x).pow_mod(4, x**2 + 1) == 96*x + 28 + assert x.pow_mod(4, x**2 + 1) == 1 + + assert x.pow_mod(2**127, x - 1) == 1 + assert (1 + x).pow_mod(2**127, x - 1) == pow(2, 2**127, characteristic) + + if type(x) is not flint.fq_default_poly: + assert (1 + x).pow_mod(2**127, x - 1, S(1)/2) == pow(2, 2**127, characteristic) + assert raises(lambda: (1 + x).pow_mod(2**127, x - 1, []), TypeError) + + assert raises(lambda: (1 + x).pow_mod(4, []), TypeError) + assert raises(lambda: (1 + x).pow_mod([], x), TypeError) + assert raises(lambda: (1 + x).pow_mod(-1, x), ValueError) + def _all_mpolys(): return [ @@ -4045,6 +4134,8 @@ def test_matrices_pow(): # XXX: Allow unimodular matrices? assert raises(lambda: M1234**-1, DomainError) + assert raises(lambda: pow(M1234, 2, 3), NotImplementedError) + Mr = M([[1, 2, 3], [4, 5, 6]]) assert raises(lambda: Mr**0, ValueError) assert raises(lambda: Mr**1, ValueError) @@ -4608,6 +4699,8 @@ def test_all_tests(): test_nmod_mat, test_nmod_series, + test_nmod_contexts, + test_fmpz_mod, test_fmpz_mod_dlog, test_fmpz_mod_poly, diff --git a/src/flint/types/fmpq_mat.pyx b/src/flint/types/fmpq_mat.pyx index 0daf80dd..bb9e6a83 100644 --- a/src/flint/types/fmpq_mat.pyx +++ b/src/flint/types/fmpq_mat.pyx @@ -99,9 +99,6 @@ cdef class fmpq_mat(flint_mat): cdef bint r if op != 2 and op != 3: raise TypeError("matrices cannot be ordered") - s = any_as_fmpq_mat(s) - if t is NotImplemented: - return s t = any_as_fmpq_mat(t) if t is NotImplemented: return t @@ -268,9 +265,6 @@ cdef class fmpq_mat(flint_mat): def __truediv__(s, t): return fmpq_mat._div_(s, t) - def __div__(s, t): - return fmpq_mat._div_(s, t) - def inv(self): """ Returns the inverse matrix of *self*. @@ -488,7 +482,7 @@ cdef class fmpq_mat(flint_mat): if self.nrows() != self.ncols(): raise ValueError("matrix must be square") if z is not None: - raise TypeError("fmpq_mat does not support modular exponentiation") + raise NotImplementedError("fmpq_mat does not support modular exponentiation") n = int(n) if n == 0: diff --git a/src/flint/types/fmpz_mod_mat.pyx b/src/flint/types/fmpz_mod_mat.pyx index 44f3c08c..c00659fe 100644 --- a/src/flint/types/fmpz_mod_mat.pyx +++ b/src/flint/types/fmpz_mod_mat.pyx @@ -467,8 +467,10 @@ cdef class fmpz_mod_mat(flint_mat): return self._scalarmul(e) return NotImplemented - def __pow__(self, other): + def __pow__(self, other, m=None): """``M ** n``: Raise a matrix to an integer power.""" + if m is not None: + raise NotImplementedError("fmpz_mod_mat pow: modulo not supported") if not isinstance(other, int): return NotImplemented return self._pow(other) diff --git a/src/flint/types/fmpz_mod_poly.pyx b/src/flint/types/fmpz_mod_poly.pyx index 434275d4..980fb208 100644 --- a/src/flint/types/fmpz_mod_poly.pyx +++ b/src/flint/types/fmpz_mod_poly.pyx @@ -1434,17 +1434,27 @@ cdef class fmpz_mod_poly(flint_poly): """ cdef fmpz_t f cdef fmpz_mod_poly res + cdef bint is_one + + if n < 1: + raise ValueError(f"{n = } must be positive") + + if self.constant_coefficient() == 0: + raise ZeroDivisionError("fmpz_mod_poly inverse_series_trunc: zero constant term") res = self.ctx.new_ctype_poly() fmpz_init(f) fmpz_mod_poly_inv_series_f( f, res.val, self.val, n, res.ctx.mod.val ) - if not fmpz_is_one(f): - fmpz_clear(f) + is_one = fmpz_is_one(f) + fmpz_clear(f) + + if not is_one: raise ValueError( f"Cannot compute inverse series of {self} modulo x^{n}" ) + return res def resultant(self, other): diff --git a/src/flint/types/fq_default_poly.pyx b/src/flint/types/fq_default_poly.pyx index ec2dfe83..be6a2555 100644 --- a/src/flint/types/fq_default_poly.pyx +++ b/src/flint/types/fq_default_poly.pyx @@ -1240,6 +1240,9 @@ cdef class fq_default_poly(flint_poly): """ cdef fq_default_poly res + if n < 1: + raise ValueError(f"{n = } must be positive") + if self.constant_coefficient().is_zero(): raise ZeroDivisionError("constant coefficient must be invertible") diff --git a/src/flint/types/nmod.pxd b/src/flint/types/nmod.pxd index 666c4a09..8349001e 100644 --- a/src/flint/types/nmod.pxd +++ b/src/flint/types/nmod.pxd @@ -11,11 +11,11 @@ cdef class nmod_ctx: cdef bint _is_prime @staticmethod - cdef any_as_nmod_ctx(obj) + cdef nmod_ctx any_as_nmod_ctx(obj) @staticmethod - cdef _get_ctx(int mod) + cdef nmod_ctx _get_ctx(int mod) @staticmethod - cdef _new_ctx(ulong mod) + cdef nmod_ctx _new_ctx(ulong mod) @cython.final cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1 diff --git a/src/flint/types/nmod.pyx b/src/flint/types/nmod.pyx index 246c4064..52fc8bff 100644 --- a/src/flint/types/nmod.pyx +++ b/src/flint/types/nmod.pyx @@ -46,19 +46,19 @@ cdef class nmod_ctx: @staticmethod def new(mod): """Get an nmod context with modulus ``mod``.""" - return nmod_ctx._get_ctx(mod) + return nmod_ctx.any_as_nmod_ctx(mod) @staticmethod - cdef any_as_nmod_ctx(obj): + cdef nmod_ctx any_as_nmod_ctx(obj): """Convert an ``nmod_ctx`` or ``int`` to an ``nmod_ctx``.""" if typecheck(obj, nmod_ctx): return obj if typecheck(obj, int): return nmod_ctx._get_ctx(obj) - return NotImplemented + raise TypeError("Invalid context/modulus for nmod: %s" % obj) @staticmethod - cdef _get_ctx(int mod): + cdef nmod_ctx _get_ctx(int mod): """Retrieve an nmod context from the cache or create a new one.""" ctx = _nmod_ctx_cache.get(mod) if ctx is None: @@ -66,7 +66,7 @@ cdef class nmod_ctx: return ctx @staticmethod - cdef _new_ctx(ulong mod): + cdef nmod_ctx _new_ctx(ulong mod): """Create a new nmod context.""" cdef nmod_ctx ctx = nmod_ctx.__new__(nmod_ctx) nmod_init(&ctx.mod, mod) @@ -160,7 +160,8 @@ cdef class nmod_ctx: """ r = self.new_nmod() - self.any_as_nmod(&r.val, val) + if not self.any_as_nmod(&r.val, val): + raise TypeError("cannot create nmod from object of type %s" % type(val)) return r @@ -174,11 +175,7 @@ cdef class nmod(flint_scalar): """ def __init__(self, val, mod): - cdef nmod_ctx ctx - c = nmod_ctx.any_as_nmod_ctx(mod) - if c is NotImplemented: - raise TypeError("Invalid context/modulus for nmod: %s" % mod) - ctx = c + cdef nmod_ctx ctx = nmod_ctx.any_as_nmod_ctx(mod) if not ctx.any_as_nmod(&self.val, val): raise TypeError("cannot create nmod from object of type %s" % type(val)) self.ctx = ctx diff --git a/src/flint/types/nmod_poly.pxd b/src/flint/types/nmod_poly.pxd index 1e70b526..a10ef8c6 100644 --- a/src/flint/types/nmod_poly.pxd +++ b/src/flint/types/nmod_poly.pxd @@ -16,7 +16,7 @@ cdef class nmod_poly_ctx: cdef nmod_ctx scalar_ctx @staticmethod - cdef any_as_nmod_poly_ctx(obj) + cdef nmod_poly_ctx any_as_nmod_poly_ctx(obj) @staticmethod cdef nmod_poly_ctx _get_ctx(int mod) @staticmethod diff --git a/src/flint/types/nmod_poly.pyx b/src/flint/types/nmod_poly.pyx index 9ff98d4d..5ad24187 100644 --- a/src/flint/types/nmod_poly.pyx +++ b/src/flint/types/nmod_poly.pyx @@ -33,16 +33,16 @@ cdef class nmod_poly_ctx: @staticmethod def new(mod): """Get an ``nmod_poly`` context with modulus ``mod``.""" - return nmod_poly_ctx._get_ctx(mod) + return nmod_poly_ctx.any_as_nmod_poly_ctx(mod) @staticmethod - cdef any_as_nmod_poly_ctx(obj): + cdef nmod_poly_ctx any_as_nmod_poly_ctx(obj): """Convert an ``nmod_poly_ctx`` or ``int`` to an ``nmod_poly_ctx``.""" if typecheck(obj, nmod_poly_ctx): return obj if typecheck(obj, int): return nmod_poly_ctx._get_ctx(obj) - return NotImplemented + raise TypeError("Invalid context/modulus for nmod_poly: %s" % obj) @staticmethod cdef nmod_poly_ctx _get_ctx(int mod): @@ -216,9 +216,6 @@ cdef class nmod_poly(flint_poly): if mod == 0: raise ValueError("a nonzero modulus is required") ctx = nmod_poly_ctx.any_as_nmod_poly_ctx(mod) - if ctx is NotImplemented: - raise TypeError("cannot create nmod_poly_ctx from input of type %s", type(mod)) - self.ctx = ctx nmod_poly_init(self.val, ctx.mod.n) if typecheck(val, fmpz_poly): @@ -244,10 +241,12 @@ cdef class nmod_poly(flint_poly): return nmod_poly_modulus(self.val) def __richcmp__(s, t, int op): + cdef mp_limb_t v + cdef slong length cdef bint res if op != 2 and op != 3: raise TypeError("nmod_polys cannot be ordered") - if typecheck(s, nmod_poly) and typecheck(t, nmod_poly): + if typecheck(t, nmod_poly): if (s).val.mod.n != (t).val.mod.n: res = False else: @@ -256,22 +255,19 @@ cdef class nmod_poly(flint_poly): return res if op == 3: return not res - else: - if not typecheck(s, nmod_poly): - s, t = t, s - try: - t = nmod_poly([t], (s).val.mod.n) - except TypeError: - pass - if typecheck(s, nmod_poly) and typecheck(t, nmod_poly): - if (s).val.mod.n != (t).val.mod.n: - res = False - else: - res = nmod_poly_equal((s).val, (t).val) - if op == 2: - return res - if op == 3: - return not res + + # zero or constant poly can be equal to a scalar + length = nmod_poly_length(s.val) + if length <= 1 and s.ctx.any_as_nmod(&v, t): + if length == 0: + res = (v == 0) + else: + res = (v == nmod_poly_get_coeff_ui(s.val, 0)) + if op == 2: + return res + if op == 3: + return not res + return NotImplemented def __iter__(self): @@ -391,9 +387,12 @@ cdef class nmod_poly(flint_poly): """ if n <= 0: raise ValueError(f"n = {n} must be positive") + + if nmod_poly_get_coeff_ui(self.val, 0) == 0: + raise ZeroDivisionError(f"nmod_poly inverse_series_trunc: leading coefficient is zero") - if self.is_zero(): - raise ValueError("cannot invert the zero element") + if not self.ctx._is_prime: + raise DomainError(f"nmod_poly inverse_series_trunc: modulus {self.ctx.mod.n} is not prime") cdef nmod_poly res = self.ctx.new_nmod_poly() nmod_poly_inv_series(res.val, self.val, n) @@ -650,28 +649,28 @@ cdef class nmod_poly(flint_poly): """ cdef nmod_poly res - if e < 0: - raise ValueError("Exponent must be non-negative") - modulus = self.ctx.any_as_nmod_poly(modulus) if modulus is NotImplemented: raise TypeError("cannot convert input to nmod_poly") + # For larger exponents we need an fmpz + e_fmpz = any_as_fmpz(e) + if e_fmpz is NotImplemented: + raise TypeError(f"exponent cannot be cast to an fmpz type: {e}") + + if e < 0: + raise ValueError("Exponent must be non-negative") + # Output polynomial res = self.ctx.new_nmod_poly() # For small exponents, use a simple binary exponentiation method if e.bit_length() < 32: nmod_poly_powmod_ui_binexp( - res.val, self.val, e, (modulus).val + res.val, self.val, int(e), (modulus).val ) return res - # For larger exponents we need to cast e to an fmpz first - e_fmpz = any_as_fmpz(e) - if e_fmpz is NotImplemented: - raise TypeError(f"exponent cannot be cast to an fmpz type: {e}") - # To optimise powering, we precompute the inverse of the reverse of the modulus if mod_rev_inv is not None: mod_rev_inv = self.ctx.any_as_nmod_poly(mod_rev_inv) From ff215568a758751c15213c7e4a989cc5ea9000b6 Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Wed, 28 Aug 2024 18:48:41 +0100 Subject: [PATCH 23/30] Test nmod_mat_ctx --- src/flint/test/test_all.py | 28 +++++++++++++++++++--------- src/flint/types/fmpz_mod_poly.pyx | 2 +- src/flint/types/nmod_mat.pxd | 2 +- src/flint/types/nmod_mat.pyx | 20 ++++++-------------- 4 files changed, 27 insertions(+), 25 deletions(-) diff --git a/src/flint/test/test_all.py b/src/flint/test/test_all.py index 129bd7cf..ac4ae406 100644 --- a/src/flint/test/test_all.py +++ b/src/flint/test/test_all.py @@ -1123,7 +1123,7 @@ def set_bad(i): for _ in range(2): A = Q(flint.fmpz_mat.randtest(30, 30, 10)) if A.det() == 0: - continue + continue # pragma: no cover B = Q(flint.fmpz_mat.randtest(30, 1, 10)) X = A.solve(B) assert A*X == B @@ -1548,6 +1548,8 @@ def test_nmod_mat(): assert raises(lambda: M([1], 5), TypeError) assert raises(lambda: M([[1],[2,3]], 5), ValueError) assert raises(lambda: M([[1],[2]], 0), ValueError) + assert raises(lambda: M([[1, 2], [3, 4.0]], 5), TypeError) + assert raises(lambda: M(2, 2, [1, 2, 3, 4.0], 5), TypeError) assert raises(lambda: M(None), TypeError) assert raises(lambda: M(None,17), TypeError) assert M(2,3,17) == M(2,3,[0,0,0,0,0,0],17) @@ -1629,12 +1631,14 @@ def test_nmod_series(): def test_nmod_contexts(): # XXX: Generalise this test to cover fmpz_mod, fq_default, etc. - C = flint.nmod_ctx + CS = flint.nmod_ctx CP = flint.nmod_poly_ctx - G = flint.nmod + CM = flint.nmod_mat_ctx + S = flint.nmod P = flint.nmod_poly + M = flint.nmod_mat - for c, name in [(C, 'nmod'), (CP, 'nmod_poly')]: + for c, name in [(CS, 'nmod'), (CP, 'nmod_poly'), (CM, 'nmod_mat')]: ctx = c.new(17) assert ctx.modulus() == 17 assert str(ctx) == f"Context for {name} with modulus: 17" @@ -1642,16 +1646,21 @@ def test_nmod_contexts(): assert raises(lambda: c(3), TypeError) assert raises(lambda: ctx.new(3.0), TypeError) - ctx = C.new(17) - assert ctx(3) == G(3,17) == G(3, ctx) + ctx = CS.new(17) + assert ctx(3) == S(3,17) == S(3, ctx) assert raises(lambda: ctx(3.0), TypeError) - assert raises(lambda: G(3, []), TypeError) + assert raises(lambda: S(3, []), TypeError) ctx_poly = CP.new(17) assert ctx_poly([1,2,3]) == P([1,2,3],17) == P([1,2,3], ctx_poly) assert raises(lambda: ctx_poly([1,2.0,3]), TypeError) assert raises(lambda: P([1,2,3], []), TypeError) + ctx_mat = CM.new(17) + assert ctx_mat([[1,2],[3,4]]) == M([[1,2],[3,4]],17) == M([[1,2],[3,4]], ctx_mat) + assert raises(lambda: ctx_mat([[1,2.0],[3,4]]), TypeError) + assert raises(lambda: M([[1,2],[3,4]], []), TypeError) + def test_arb(): A = flint.arb @@ -2596,6 +2605,7 @@ def FQ_DEFAULT(n, k): NMOD(9), NMOD(16), FMPZ_MOD(164), + FMPZ_MOD(9), FMPZ_MOD(2**127), FMPZ_MOD(2**255), ] @@ -2795,7 +2805,7 @@ def setbad(obj, i, val): assert P([1, 1]) % 2 == P([1, 1]) assert P([2, 2]) / 2 == P([1, 1]) assert raises(lambda: P([1, 2]) / 2, DomainError) - else: + elif characteristic.gcd(2) != 1 or type(P(1)) is flint.nmod_poly: # Z/nZ for n not prime assert raises(lambda: P([1, 1]) // 2, DomainError) assert raises(lambda: P([1, 1]) % 2, DomainError) @@ -2898,7 +2908,7 @@ def setbad(obj, i, val): if type(x) is flint.fmpz_mod_poly: assert (1 + x).inverse_series_trunc(4) == 1 - x + x**2 - x**3 if characteristic.gcd(3) != 1: - assert (3 + x).inverse_series_trunc(4) == 1 - x + x**2 - x**3 + assert raises(lambda: (3 + x).inverse_series_trunc(4), DomainError) else: assert (3 + x).inverse_series_trunc(4)\ == S(1)/3 - S(1)/9*x + S(1)/27*x**2 - S(1)/81*x**3 diff --git a/src/flint/types/fmpz_mod_poly.pyx b/src/flint/types/fmpz_mod_poly.pyx index 980fb208..30a32b7d 100644 --- a/src/flint/types/fmpz_mod_poly.pyx +++ b/src/flint/types/fmpz_mod_poly.pyx @@ -1451,7 +1451,7 @@ cdef class fmpz_mod_poly(flint_poly): fmpz_clear(f) if not is_one: - raise ValueError( + raise DomainError( f"Cannot compute inverse series of {self} modulo x^{n}" ) diff --git a/src/flint/types/nmod_mat.pxd b/src/flint/types/nmod_mat.pxd index 12a16774..ac0d5bca 100644 --- a/src/flint/types/nmod_mat.pxd +++ b/src/flint/types/nmod_mat.pxd @@ -18,7 +18,7 @@ cdef class nmod_mat_ctx: cdef nmod_poly_ctx poly_ctx @staticmethod - cdef any_as_nmod_mat_ctx(obj) + cdef nmod_mat_ctx any_as_nmod_mat_ctx(obj) @staticmethod cdef nmod_mat_ctx _get_ctx(int mod) @staticmethod diff --git a/src/flint/types/nmod_mat.pyx b/src/flint/types/nmod_mat.pyx index 9cebaa5f..2e800291 100644 --- a/src/flint/types/nmod_mat.pyx +++ b/src/flint/types/nmod_mat.pyx @@ -76,10 +76,10 @@ cdef class nmod_mat_ctx: @staticmethod def new(mod): """Get an ``nmod_poly`` context with modulus ``mod``.""" - return nmod_mat_ctx._get_ctx(mod) + return nmod_mat_ctx.any_as_nmod_mat_ctx(mod) @staticmethod - cdef any_as_nmod_mat_ctx(obj): + cdef nmod_mat_ctx any_as_nmod_mat_ctx(obj): """Convert an ``nmod_mat_ctx`` or ``int`` to an ``nmod_mat_ctx``.""" if typecheck(obj, nmod_mat_ctx): return obj @@ -87,7 +87,7 @@ cdef class nmod_mat_ctx: return nmod_mat_ctx._get_ctx(obj) elif typecheck(obj, fmpz): return nmod_mat_ctx._get_ctx(int(obj)) - return NotImplemented + raise TypeError("nmod_mat: expected last argument to be an nmod_mat_ctx or an integer") @staticmethod cdef nmod_mat_ctx _get_ctx(int mod): @@ -265,12 +265,7 @@ cdef class nmod_mat(flint_mat): mod = args[-1] args = args[:-1] - c = nmod_mat_ctx.any_as_nmod_mat_ctx(mod) - if c is NotImplemented: - raise TypeError("nmod_mat: expected last argument to be an nmod_mat_ctx or an integer") - - ctx = c - self.ctx = ctx + self.ctx = ctx = nmod_mat_ctx.any_as_nmod_mat_ctx(mod) if mod == 0: raise ValueError("modulus must be nonzero") @@ -427,7 +422,7 @@ cdef class nmod_mat(flint_mat): return t tv = &(t).val[0] if sv.mod.n != tv.mod.n: - raise ValueError("cannot add nmod_mats with different moduli") + raise ValueError("cannot add nmod_mats with different moduli") # pragma: no cover if sv.r != tv.r or sv.c != tv.c: raise ValueError("incompatible shapes for matrix addition") r = s.ctx.new_nmod_mat(sv.r, sv.c) @@ -461,7 +456,7 @@ cdef class nmod_mat(flint_mat): return t tv = &(t).val[0] if sv.mod.n != tv.mod.n: - raise ValueError("cannot subtract nmod_mats with different moduli") + raise ValueError("cannot subtract nmod_mats with different moduli") # pragma: no cover if sv.r != tv.r or sv.c != tv.c: raise ValueError("incompatible shapes for matrix subtraction") r = s.ctx.new_nmod_mat(sv.r, sv.c) @@ -541,9 +536,6 @@ cdef class nmod_mat(flint_mat): def __truediv__(s, t): return nmod_mat._div_(s, t) - def __div__(s, t): - return nmod_mat._div_(s, t) - def det(self): """ Returns the determinant of self as an nmod. From c06bc866abfc49bddc00f22597e08b43dae83e8f Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Wed, 28 Aug 2024 19:27:33 +0100 Subject: [PATCH 24/30] Use int(characteristic) for PyPy PyPy cannot handle 3-arg pow if the first two types don't know about the third like pow(int, int, fmpz). --- src/flint/test/test_all.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/flint/test/test_all.py b/src/flint/test/test_all.py index ac4ae406..b02c4a17 100644 --- a/src/flint/test/test_all.py +++ b/src/flint/test/test_all.py @@ -2935,7 +2935,7 @@ def setbad(obj, i, val): assert x.pow_mod(4, x**2 + 1) == 1 assert x.pow_mod(2**127, x - 1) == 1 - assert (1 + x).pow_mod(2**127, x - 1) == pow(2, 2**127, characteristic) + assert (1 + x).pow_mod(2**127, x - 1) == pow(2, 2**127, int(characteristic)) if type(x) is not flint.fq_default_poly: assert (1 + x).pow_mod(2**127, x - 1, S(1)/2) == pow(2, 2**127, characteristic) From 3d39267d7de88111e4a4b927a162545694af45a8 Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Wed, 28 Aug 2024 20:09:12 +0100 Subject: [PATCH 25/30] Use int for PyPy --- src/flint/test/test_all.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/flint/test/test_all.py b/src/flint/test/test_all.py index b02c4a17..a6cd6834 100644 --- a/src/flint/test/test_all.py +++ b/src/flint/test/test_all.py @@ -2938,7 +2938,7 @@ def setbad(obj, i, val): assert (1 + x).pow_mod(2**127, x - 1) == pow(2, 2**127, int(characteristic)) if type(x) is not flint.fq_default_poly: - assert (1 + x).pow_mod(2**127, x - 1, S(1)/2) == pow(2, 2**127, characteristic) + assert (1 + x).pow_mod(2**127, x - 1, S(1)/2) == pow(2, 2**127, int(characteristic)) assert raises(lambda: (1 + x).pow_mod(2**127, x - 1, []), TypeError) assert raises(lambda: (1 + x).pow_mod(4, []), TypeError) From 58f388705d7ad4fb3bd833ee59e6055e13b5f73d Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Thu, 29 Aug 2024 10:52:29 +0100 Subject: [PATCH 26/30] Use Cython master branch again after: https://github.com/cython/cython/pull/6369 --- .github/workflows/buildwheel.yml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/.github/workflows/buildwheel.yml b/.github/workflows/buildwheel.yml index eba6f8c0..f94df605 100644 --- a/.github/workflows/buildwheel.yml +++ b/.github/workflows/buildwheel.yml @@ -204,10 +204,7 @@ jobs: - run: sudo apt-get install libflint-dev # Need Cython's master branch until 3.1 is released because we need: # https://github.com/cython/cython/pull/6341 - # Except now we can't use the master branch any more because of: - # https://github.com/cython/cython/issues/6366 - # So we have to keep using Oscar's PR branch: - - run: pip install git+https://github.com/oscarbenjamin/cython.git@pr_relative_paths + - run: pip install git+https://github.com/cython/cython.git@master - run: pip install -r requirements-dev.txt - run: bin/coverage.sh From e837d994497fe77f3a40ce6f8b6f517c629c685c Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Thu, 29 Aug 2024 15:40:33 +0100 Subject: [PATCH 27/30] Add @cython.no_gc to nmod_poly and nmod_mat --- src/flint/types/nmod_mat.pyx | 3 ++- src/flint/types/nmod_poly.pyx | 17 ++++++++++------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/flint/types/nmod_mat.pyx b/src/flint/types/nmod_mat.pyx index 2e800291..1d015aa1 100644 --- a/src/flint/types/nmod_mat.pyx +++ b/src/flint/types/nmod_mat.pyx @@ -237,6 +237,7 @@ cdef class nmod_mat_ctx: return nmod_mat(*args, self) +@cython.no_gc cdef class nmod_mat(flint_mat): """ The nmod_mat type represents dense matrices over Z/nZ for word-size n (see @@ -474,7 +475,7 @@ cdef class nmod_mat(flint_mat): cdef nmod_mat_struct *sv cdef nmod_mat_struct *tv cdef mp_limb_t c - sv = &(s).val[0] + sv = &s.val[0] u = s.ctx.any_as_nmod_mat(t) if u is NotImplemented: if s.ctx.any_as_nmod(&c, t): diff --git a/src/flint/types/nmod_poly.pyx b/src/flint/types/nmod_poly.pyx index 5ad24187..a66b9237 100644 --- a/src/flint/types/nmod_poly.pyx +++ b/src/flint/types/nmod_poly.pyx @@ -1,3 +1,5 @@ +cimport cython + from cpython.list cimport PyList_GET_SIZE from flint.flint_base.flint_base cimport flint_poly from flint.utils.typecheck cimport typecheck @@ -18,9 +20,10 @@ from flint.utils.flint_exceptions import DomainError cdef dict _nmod_poly_ctx_cache = {} +@cython.no_gc cdef class nmod_poly_ctx: """ - Context object for creating :class:`~.nmod_poly` initalised + Context object for creating :class:`~.nmod_poly` initalised with modulus :math:`N`. >>> nmod_poly_ctx.new(17) @@ -170,6 +173,7 @@ cdef class nmod_poly_ctx: return nmod_poly(arg, self) +@cython.no_gc cdef class nmod_poly(flint_poly): """ The nmod_poly type represents dense univariate polynomials @@ -417,8 +421,8 @@ cdef class nmod_poly(flint_poly): if other is NotImplemented: raise TypeError("cannot convert input to nmod_poly") res = self.ctx.new_nmod_poly() - nmod_poly_compose(res.val, self.val, (other).val) - return res + nmod_poly_compose(res.val, self.val, (other).val) + return res def compose_mod(self, other, modulus): r""" @@ -449,8 +453,8 @@ cdef class nmod_poly(flint_poly): raise ZeroDivisionError("cannot reduce modulo zero") res = self.ctx.new_nmod_poly() - nmod_poly_compose_mod(res.val, self.val, (other).val, (modulus).val) - return res + nmod_poly_compose_mod(res.val, self.val, (other).val, (modulus).val) + return res def __call__(self, other): cdef nmod_poly r @@ -638,7 +642,6 @@ cdef class nmod_poly(flint_poly): >>> f = 30*x**6 + 104*x**5 + 76*x**4 + 33*x**3 + 70*x**2 + 44*x + 65 >>> g = 43*x**6 + 91*x**5 + 77*x**4 + 113*x**3 + 71*x**2 + 132*x + 60 >>> mod = x**4 + 93*x**3 + 78*x**2 + 72*x + 149 - >>> >>> f.pow_mod(123, mod) 3*x^3 + 25*x^2 + 115*x + 161 >>> f.pow_mod(2**64, mod) @@ -663,7 +666,7 @@ cdef class nmod_poly(flint_poly): # Output polynomial res = self.ctx.new_nmod_poly() - + # For small exponents, use a simple binary exponentiation method if e.bit_length() < 32: nmod_poly_powmod_ui_binexp( From 6d14958d9f684730d43a6479629d936aa068be59 Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Sun, 1 Sep 2024 03:26:01 +0100 Subject: [PATCH 28/30] Various lint fixes --- src/flint/types/fmpz_mod_mat.pyx | 2 +- src/flint/types/nmod.pxd | 3 ++ src/flint/types/nmod.pyx | 2 +- src/flint/types/nmod_mat.pxd | 7 ++++ src/flint/types/nmod_mat.pyx | 68 +++++++++++++++----------------- src/flint/types/nmod_poly.pxd | 6 +++ src/flint/types/nmod_poly.pyx | 6 +-- 7 files changed, 52 insertions(+), 42 deletions(-) diff --git a/src/flint/types/fmpz_mod_mat.pyx b/src/flint/types/fmpz_mod_mat.pyx index c00659fe..c036dca1 100644 --- a/src/flint/types/fmpz_mod_mat.pyx +++ b/src/flint/types/fmpz_mod_mat.pyx @@ -413,7 +413,7 @@ cdef class fmpz_mod_mat(flint_mat): raise ZeroDivisionError("fmpz_mod_mat div: division by zero") else: raise DomainError("fmpz_mod_mat div: division by non-invertible element") - return self._scalarmul(other.inverse()) + return self._scalarmul(inv) def __add__(self, other): """``M + N``: Add two matrices.""" diff --git a/src/flint/types/nmod.pxd b/src/flint/types/nmod.pxd index 8349001e..9a6a04a3 100644 --- a/src/flint/types/nmod.pxd +++ b/src/flint/types/nmod.pxd @@ -12,13 +12,16 @@ cdef class nmod_ctx: @staticmethod cdef nmod_ctx any_as_nmod_ctx(obj) + @staticmethod cdef nmod_ctx _get_ctx(int mod) + @staticmethod cdef nmod_ctx _new_ctx(ulong mod) @cython.final cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1 + @cython.final cdef nmod new_nmod(self) diff --git a/src/flint/types/nmod.pyx b/src/flint/types/nmod.pyx index 52fc8bff..82c0fe30 100644 --- a/src/flint/types/nmod.pyx +++ b/src/flint/types/nmod.pyx @@ -25,7 +25,7 @@ cdef dict _nmod_ctx_cache = {} @cython.no_gc cdef class nmod_ctx: """ - Context object for creating :class:`~.nmod` initalised + Context object for creating :class:`~.nmod` initalised with modulus :math:`N`. >>> ctx = nmod_ctx.new(17) diff --git a/src/flint/types/nmod_mat.pxd b/src/flint/types/nmod_mat.pxd index ac0d5bca..e5e69cc9 100644 --- a/src/flint/types/nmod_mat.pxd +++ b/src/flint/types/nmod_mat.pxd @@ -19,21 +19,28 @@ cdef class nmod_mat_ctx: @staticmethod cdef nmod_mat_ctx any_as_nmod_mat_ctx(obj) + @staticmethod cdef nmod_mat_ctx _get_ctx(int mod) + @staticmethod cdef nmod_mat_ctx _new_ctx(ulong mod) @cython.final cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1 + @cython.final cdef any_as_nmod_mat(self, obj) + @cython.final cdef nmod new_nmod(self) + @cython.final cdef nmod_poly new_nmod_poly(self) + @cython.final cdef nmod_mat new_nmod_mat(self, ulong m, ulong n) + @cython.final cdef nmod_mat new_nmod_mat_copy(self, nmod_mat other) diff --git a/src/flint/types/nmod_mat.pyx b/src/flint/types/nmod_mat.pyx index 1d015aa1..8a7d097c 100644 --- a/src/flint/types/nmod_mat.pyx +++ b/src/flint/types/nmod_mat.pyx @@ -1,7 +1,6 @@ cimport cython from flint.flintlib.flint cimport ulong, mp_limb_t -from flint.flintlib.nmod cimport nmod_t from flint.flintlib.fmpz_mat cimport fmpz_mat_nrows, fmpz_mat_ncols from flint.flintlib.fmpz_mat cimport fmpz_mat_get_nmod_mat @@ -123,7 +122,6 @@ cdef class nmod_mat_ctx: cdef any_as_nmod_mat(self, obj): """Convert obj to nmod_mat or return NotImplemented.""" cdef nmod_mat r - cdef mp_limb_t v if typecheck(obj, nmod_mat): return obj @@ -181,38 +179,38 @@ cdef class nmod_mat_ctx: """ return self._is_prime - #def zero(self, slong m, slong n): - # """Return the zero ``nmod_mat``. - # - # >>> ctx = nmod_mat_ctx.new(17) - # >>> ctx.zero(2, 3) - # [0, 0, 0] - # [0, 0, 0] - - # """ - # cdef nmod_mat r = self.new_nmod_mat() - # nmod_mat_zero(r.val) - # return r - - #def one(self, slong m, slong n=-1): - # """Return the one ``nmod_mat``. - - # >>> ctx = nmod_mat_ctx.new(17) - # >>> ctx.one(2) - # [1, 0] - # [0, 1] - # >>> ctx.one(2, 3) - # [1, 0, 0] - # [0, 1, 0] - - # """ - # cdef nmod_mat r = self.new_nmod_mat() - # if n == -1: - # n = m - # n = min(m, n) - # for i from 0 <= i < n: - # nmod_mat_set_entry(r.val, i, i, 1) - # return r + #def zero(self, slong m, slong n): + # """Return the zero ``nmod_mat``. + # + # >>> ctx = nmod_mat_ctx.new(17) + # >>> ctx.zero(2, 3) + # [0, 0, 0] + # [0, 0, 0] + + # """ + # cdef nmod_mat r = self.new_nmod_mat() + # nmod_mat_zero(r.val) + # return r + + #def one(self, slong m, slong n=-1): + # """Return the one ``nmod_mat``. + + # >>> ctx = nmod_mat_ctx.new(17) + # >>> ctx.one(2) + # [1, 0] + # [0, 1] + # >>> ctx.one(2, 3) + # [1, 0, 0] + # [0, 1, 0] + + # """ + # cdef nmod_mat r = self.new_nmod_mat() + # if n == -1: + # n = m + # n = min(m, n) + # for i from 0 <= i < n: + # nmod_mat_set_entry(r.val, i, i, 1) + # return r def __str__(self): return f"Context for nmod_mat with modulus: {self.mod.n}" @@ -491,9 +489,7 @@ cdef class nmod_mat(flint_mat): return r def __rmul__(s, t): - cdef nmod_mat_struct *sv cdef mp_limb_t c - sv = &(s).val[0] if s.ctx.any_as_nmod(&c, t): return (s).__mul_nmod(c) u = s.ctx.any_as_nmod_mat(t) diff --git a/src/flint/types/nmod_poly.pxd b/src/flint/types/nmod_poly.pxd index a10ef8c6..dde80357 100644 --- a/src/flint/types/nmod_poly.pxd +++ b/src/flint/types/nmod_poly.pxd @@ -17,19 +17,25 @@ cdef class nmod_poly_ctx: @staticmethod cdef nmod_poly_ctx any_as_nmod_poly_ctx(obj) + @staticmethod cdef nmod_poly_ctx _get_ctx(int mod) + @staticmethod cdef nmod_poly_ctx _new_ctx(ulong mod) @cython.final cdef nmod_poly_set_list(self, nmod_poly_t poly, list val) + @cython.final cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1 + @cython.final cdef any_as_nmod_poly(self, obj) + @cython.final cdef nmod new_nmod(self) + @cython.final cdef nmod_poly new_nmod_poly(self) diff --git a/src/flint/types/nmod_poly.pyx b/src/flint/types/nmod_poly.pyx index a66b9237..4ec18fa9 100644 --- a/src/flint/types/nmod_poly.pyx +++ b/src/flint/types/nmod_poly.pyx @@ -12,7 +12,6 @@ from flint.flintlib.nmod_vec cimport * from flint.flintlib.nmod_poly cimport * from flint.flintlib.nmod_poly_factor cimport * from flint.flintlib.fmpz_poly cimport fmpz_poly_get_nmod_poly -from flint.flintlib.ulong_extras cimport n_gcdinv, n_is_prime from flint.utils.flint_exceptions import DomainError @@ -104,7 +103,6 @@ cdef class nmod_poly_ctx: n = PyList_GET_SIZE(val) nmod_poly_fit_length(poly, n) for i from 0 <= i < n: - c = val[i] if self.any_as_nmod(&v, val[i]): nmod_poly_set_coeff_ui(poly, i, v) else: @@ -393,7 +391,7 @@ cdef class nmod_poly(flint_poly): raise ValueError(f"n = {n} must be positive") if nmod_poly_get_coeff_ui(self.val, 0) == 0: - raise ZeroDivisionError(f"nmod_poly inverse_series_trunc: leading coefficient is zero") + raise ZeroDivisionError("nmod_poly inverse_series_trunc: leading coefficient is zero") if not self.ctx._is_prime: raise DomainError(f"nmod_poly inverse_series_trunc: modulus {self.ctx.mod.n} is not prime") @@ -840,7 +838,7 @@ cdef class nmod_poly(flint_poly): def sqrt(nmod_poly self): """Return exact square root or ``None``. """ - cdef nmod_poly + cdef nmod_poly res if not self.ctx._is_prime: raise DomainError(f"nmod_poly sqrt: modulus {self.ctx.mod.n} is not prime") From a9215d08d6256a91ae7c987999349f52952d9a4e Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Mon, 2 Sep 2024 19:08:31 +0100 Subject: [PATCH 29/30] Lint fixes --- src/flint/types/fmpq.pyx | 2 +- src/flint/types/fmpz_mod_poly.pyx | 2 +- src/flint/types/fq_default_poly.pyx | 2 +- src/flint/types/nmod_mat.pyx | 37 ++--------------------------- src/flint/types/nmod_poly.pyx | 2 +- 5 files changed, 6 insertions(+), 39 deletions(-) diff --git a/src/flint/types/fmpq.pyx b/src/flint/types/fmpq.pyx index 1016f4f3..d84cc55b 100644 --- a/src/flint/types/fmpq.pyx +++ b/src/flint/types/fmpq.pyx @@ -441,7 +441,7 @@ cdef class fmpq(flint_scalar): if sys.version_info < (3, 12): return hash(Fraction(int(self.p), int(self.q), _normalize=False)) # pragma: no cover else: - return hash(Fraction._from_coprime_ints(int(self.p), int(self.q))) # pragma: no cover + return hash(Fraction._from_coprime_ints(int(self.p), int(self.q))) # pragma: no cover def height_bits(self, bint signed=False): """ diff --git a/src/flint/types/fmpz_mod_poly.pyx b/src/flint/types/fmpz_mod_poly.pyx index 30a32b7d..e1482c83 100644 --- a/src/flint/types/fmpz_mod_poly.pyx +++ b/src/flint/types/fmpz_mod_poly.pyx @@ -1437,7 +1437,7 @@ cdef class fmpz_mod_poly(flint_poly): cdef bint is_one if n < 1: - raise ValueError(f"{n = } must be positive") + raise ValueError(f"n = {n} must be positive") if self.constant_coefficient() == 0: raise ZeroDivisionError("fmpz_mod_poly inverse_series_trunc: zero constant term") diff --git a/src/flint/types/fq_default_poly.pyx b/src/flint/types/fq_default_poly.pyx index be6a2555..4582f0e5 100644 --- a/src/flint/types/fq_default_poly.pyx +++ b/src/flint/types/fq_default_poly.pyx @@ -1241,7 +1241,7 @@ cdef class fq_default_poly(flint_poly): cdef fq_default_poly res if n < 1: - raise ValueError(f"{n = } must be positive") + raise ValueError(f"n = {n} must be positive") if self.constant_coefficient().is_zero(): raise ZeroDivisionError("constant coefficient must be invertible") diff --git a/src/flint/types/nmod_mat.pyx b/src/flint/types/nmod_mat.pyx index 8a7d097c..0f7955e5 100644 --- a/src/flint/types/nmod_mat.pyx +++ b/src/flint/types/nmod_mat.pyx @@ -179,39 +179,6 @@ cdef class nmod_mat_ctx: """ return self._is_prime - #def zero(self, slong m, slong n): - # """Return the zero ``nmod_mat``. - # - # >>> ctx = nmod_mat_ctx.new(17) - # >>> ctx.zero(2, 3) - # [0, 0, 0] - # [0, 0, 0] - - # """ - # cdef nmod_mat r = self.new_nmod_mat() - # nmod_mat_zero(r.val) - # return r - - #def one(self, slong m, slong n=-1): - # """Return the one ``nmod_mat``. - - # >>> ctx = nmod_mat_ctx.new(17) - # >>> ctx.one(2) - # [1, 0] - # [0, 1] - # >>> ctx.one(2, 3) - # [1, 0, 0] - # [0, 1, 0] - - # """ - # cdef nmod_mat r = self.new_nmod_mat() - # if n == -1: - # n = m - # n = min(m, n) - # for i from 0 <= i < n: - # nmod_mat_set_entry(r.val, i, i, 1) - # return r - def __str__(self): return f"Context for nmod_mat with modulus: {self.mod.n}" @@ -421,7 +388,7 @@ cdef class nmod_mat(flint_mat): return t tv = &(t).val[0] if sv.mod.n != tv.mod.n: - raise ValueError("cannot add nmod_mats with different moduli") # pragma: no cover + raise ValueError("cannot add nmod_mats with different moduli") # pragma: no cover if sv.r != tv.r or sv.c != tv.c: raise ValueError("incompatible shapes for matrix addition") r = s.ctx.new_nmod_mat(sv.r, sv.c) @@ -455,7 +422,7 @@ cdef class nmod_mat(flint_mat): return t tv = &(t).val[0] if sv.mod.n != tv.mod.n: - raise ValueError("cannot subtract nmod_mats with different moduli") # pragma: no cover + raise ValueError("cannot subtract nmod_mats with different moduli") # pragma: no cover if sv.r != tv.r or sv.c != tv.c: raise ValueError("incompatible shapes for matrix subtraction") r = s.ctx.new_nmod_mat(sv.r, sv.c) diff --git a/src/flint/types/nmod_poly.pyx b/src/flint/types/nmod_poly.pyx index 4ec18fa9..77809b8e 100644 --- a/src/flint/types/nmod_poly.pyx +++ b/src/flint/types/nmod_poly.pyx @@ -389,7 +389,7 @@ cdef class nmod_poly(flint_poly): """ if n <= 0: raise ValueError(f"n = {n} must be positive") - + if nmod_poly_get_coeff_ui(self.val, 0) == 0: raise ZeroDivisionError("nmod_poly inverse_series_trunc: leading coefficient is zero") From b1d523994b0cb035fa1923b45d3a1e164657ba4b Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Tue, 3 Sep 2024 15:20:35 +0100 Subject: [PATCH 30/30] refactor: move nmod_ctx cdef methods to .pxd files --- src/flint/types/nmod.pxd | 75 +++++++++++++++++-- src/flint/types/nmod.pyx | 67 +---------------- src/flint/types/nmod_mat.pxd | 94 ++++++++++++++++++++--- src/flint/types/nmod_mat.pyx | 105 +++----------------------- src/flint/types/nmod_poly.pxd | 88 +++++++++++++++++++--- src/flint/types/nmod_poly.pyx | 135 ++++++++++++++-------------------- 6 files changed, 297 insertions(+), 267 deletions(-) diff --git a/src/flint/types/nmod.pxd b/src/flint/types/nmod.pxd index 9a6a04a3..dddfc701 100644 --- a/src/flint/types/nmod.pxd +++ b/src/flint/types/nmod.pxd @@ -1,8 +1,29 @@ cimport cython from flint.flint_base.flint_base cimport flint_scalar +from flint.utils.typecheck cimport typecheck + from flint.flintlib.flint cimport mp_limb_t, ulong -from flint.flintlib.nmod cimport nmod_t +from flint.flintlib.nmod cimport nmod_t, nmod_init +from flint.flintlib.ulong_extras cimport n_is_prime + +from flint.flintlib.fmpz cimport fmpz_t +from flint.flintlib.fmpq cimport fmpq_mod_fmpz + +from flint.flintlib.fmpz cimport ( + fmpz_t, + fmpz_fdiv_ui, + fmpz_init, + fmpz_clear, + fmpz_set_ui, + fmpz_get_ui, +) + +from flint.types.fmpz cimport fmpz, any_as_fmpz +from flint.types.fmpq cimport fmpq, any_as_fmpq + + +cdef dict _nmod_ctx_cache @cython.no_gc @@ -11,19 +32,61 @@ cdef class nmod_ctx: cdef bint _is_prime @staticmethod - cdef nmod_ctx any_as_nmod_ctx(obj) + cdef inline nmod_ctx any_as_nmod_ctx(obj): + """Convert an ``nmod_ctx`` or ``int`` to an ``nmod_ctx``.""" + if typecheck(obj, nmod_ctx): + return obj + if typecheck(obj, int): + return nmod_ctx._get_ctx(obj) + raise TypeError("Invalid context/modulus for nmod: %s" % obj) @staticmethod - cdef nmod_ctx _get_ctx(int mod) + cdef inline nmod_ctx _get_ctx(int mod): + """Retrieve an nmod context from the cache or create a new one.""" + ctx = _nmod_ctx_cache.get(mod) + if ctx is None: + ctx = _nmod_ctx_cache.setdefault(mod, nmod_ctx._new_ctx(mod)) + return ctx @staticmethod - cdef nmod_ctx _new_ctx(ulong mod) + cdef inline nmod_ctx _new_ctx(ulong mod): + """Create a new nmod context.""" + cdef nmod_ctx ctx = nmod_ctx.__new__(nmod_ctx) + nmod_init(&ctx.mod, mod) + ctx._is_prime = n_is_prime(mod) + return ctx @cython.final - cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1 + cdef inline int any_as_nmod(nmod_ctx ctx, mp_limb_t * val, obj) except -1: + """Convert an object to an nmod element.""" + cdef int success + cdef fmpz_t t + if typecheck(obj, nmod): + if (obj).ctx.mod.n != ctx.mod.n: + raise ValueError("cannot coerce integers mod n with different n") + val[0] = (obj).val + return 1 + z = any_as_fmpz(obj) + if z is not NotImplemented: + val[0] = fmpz_fdiv_ui((z).val, ctx.mod.n) + return 1 + q = any_as_fmpq(obj) + if q is not NotImplemented: + fmpz_init(t) + fmpz_set_ui(t, ctx.mod.n) + success = fmpq_mod_fmpz(t, (q).val, t) + val[0] = fmpz_get_ui(t) + fmpz_clear(t) + if not success: + raise ZeroDivisionError("%s does not exist mod %i!" % (q, ctx.mod.n)) + return 1 + return 0 @cython.final - cdef nmod new_nmod(self) + cdef inline nmod new_nmod(self): + cdef nmod r = nmod.__new__(nmod) + r.ctx = self + return r @cython.no_gc diff --git a/src/flint/types/nmod.pyx b/src/flint/types/nmod.pyx index 82c0fe30..322810de 100644 --- a/src/flint/types/nmod.pyx +++ b/src/flint/types/nmod.pyx @@ -2,24 +2,18 @@ cimport cython from flint.flint_base.flint_base cimport flint_scalar from flint.utils.typecheck cimport typecheck -from flint.types.fmpq cimport any_as_fmpq from flint.types.fmpz cimport any_as_fmpz from flint.types.fmpz cimport fmpz -from flint.types.fmpq cimport fmpq from flint.flintlib.flint cimport ulong -from flint.flintlib.fmpz cimport fmpz_t from flint.flintlib.nmod cimport nmod_pow_fmpz from flint.flintlib.nmod_vec cimport * -from flint.flintlib.fmpz cimport fmpz_fdiv_ui, fmpz_init, fmpz_clear -from flint.flintlib.fmpz cimport fmpz_set_ui, fmpz_get_ui -from flint.flintlib.fmpq cimport fmpq_mod_fmpz -from flint.flintlib.ulong_extras cimport n_gcdinv, n_is_prime, n_sqrtmod +from flint.flintlib.ulong_extras cimport n_gcdinv, n_sqrtmod from flint.utils.flint_exceptions import DomainError -cdef dict _nmod_ctx_cache = {} +_nmod_ctx_cache = {} @cython.no_gc @@ -48,63 +42,6 @@ cdef class nmod_ctx: """Get an nmod context with modulus ``mod``.""" return nmod_ctx.any_as_nmod_ctx(mod) - @staticmethod - cdef nmod_ctx any_as_nmod_ctx(obj): - """Convert an ``nmod_ctx`` or ``int`` to an ``nmod_ctx``.""" - if typecheck(obj, nmod_ctx): - return obj - if typecheck(obj, int): - return nmod_ctx._get_ctx(obj) - raise TypeError("Invalid context/modulus for nmod: %s" % obj) - - @staticmethod - cdef nmod_ctx _get_ctx(int mod): - """Retrieve an nmod context from the cache or create a new one.""" - ctx = _nmod_ctx_cache.get(mod) - if ctx is None: - ctx = _nmod_ctx_cache.setdefault(mod, nmod_ctx._new_ctx(mod)) - return ctx - - @staticmethod - cdef nmod_ctx _new_ctx(ulong mod): - """Create a new nmod context.""" - cdef nmod_ctx ctx = nmod_ctx.__new__(nmod_ctx) - nmod_init(&ctx.mod, mod) - ctx._is_prime = n_is_prime(mod) - return ctx - - @cython.final - cdef int any_as_nmod(nmod_ctx ctx, mp_limb_t * val, obj) except -1: - """Convert an object to an nmod element.""" - cdef int success - cdef fmpz_t t - if typecheck(obj, nmod): - if (obj).ctx.mod.n != ctx.mod.n: - raise ValueError("cannot coerce integers mod n with different n") - val[0] = (obj).val - return 1 - z = any_as_fmpz(obj) - if z is not NotImplemented: - val[0] = fmpz_fdiv_ui((z).val, ctx.mod.n) - return 1 - q = any_as_fmpq(obj) - if q is not NotImplemented: - fmpz_init(t) - fmpz_set_ui(t, ctx.mod.n) - success = fmpq_mod_fmpz(t, (q).val, t) - val[0] = fmpz_get_ui(t) - fmpz_clear(t) - if not success: - raise ZeroDivisionError("%s does not exist mod %i!" % (q, ctx.mod.n)) - return 1 - return 0 - - @cython.final - cdef nmod new_nmod(self): - cdef nmod r = nmod.__new__(nmod) - r.ctx = self - return r - def modulus(self): """Get the modulus of the context. diff --git a/src/flint/types/nmod_mat.pxd b/src/flint/types/nmod_mat.pxd index e5e69cc9..3b5ac2e2 100644 --- a/src/flint/types/nmod_mat.pxd +++ b/src/flint/types/nmod_mat.pxd @@ -1,15 +1,30 @@ cimport cython +from flint.utils.typecheck cimport typecheck from flint.flint_base.flint_base cimport flint_mat -from flint.flintlib.nmod cimport nmod_t -from flint.flintlib.nmod_mat cimport nmod_mat_t from flint.flintlib.flint cimport mp_limb_t, ulong - +from flint.flintlib.fmpz_mat cimport ( + fmpz_mat_nrows, + fmpz_mat_ncols, + fmpz_mat_get_nmod_mat, +) +from flint.flintlib.nmod cimport nmod_t +from flint.flintlib.nmod_mat cimport ( + nmod_mat_t, + nmod_mat_init, + nmod_mat_init_set, +) + +from flint.types.fmpz cimport fmpz +from flint.types.fmpz_mat cimport fmpz_mat, any_as_fmpz_mat from flint.types.nmod cimport nmod_ctx, nmod from flint.types.nmod_poly cimport nmod_poly_ctx, nmod_poly +cdef dict _nmod_mat_ctx_cache + + @cython.no_gc cdef class nmod_mat_ctx: cdef nmod_t mod @@ -18,31 +33,86 @@ cdef class nmod_mat_ctx: cdef nmod_poly_ctx poly_ctx @staticmethod - cdef nmod_mat_ctx any_as_nmod_mat_ctx(obj) + cdef inline nmod_mat_ctx any_as_nmod_mat_ctx(obj): + """Convert an ``nmod_mat_ctx`` or ``int`` to an ``nmod_mat_ctx``.""" + if typecheck(obj, nmod_mat_ctx): + return obj + if typecheck(obj, int): + return nmod_mat_ctx._get_ctx(obj) + elif typecheck(obj, fmpz): + return nmod_mat_ctx._get_ctx(int(obj)) + raise TypeError("nmod_mat: expected last argument to be an nmod_mat_ctx or an integer") @staticmethod - cdef nmod_mat_ctx _get_ctx(int mod) + cdef inline nmod_mat_ctx _get_ctx(int mod): + """Retrieve an nmod_mat context from the cache or create a new one.""" + ctx = _nmod_mat_ctx_cache.get(mod) + if ctx is None: + ctx = _nmod_mat_ctx_cache.setdefault(mod, nmod_mat_ctx._new_ctx(mod)) + return ctx @staticmethod - cdef nmod_mat_ctx _new_ctx(ulong mod) + cdef inline nmod_mat_ctx _new_ctx(ulong mod): + """Create a new nmod_mat context.""" + cdef nmod_ctx scalar_ctx + cdef nmod_poly_ctx poly_ctx + cdef nmod_mat_ctx ctx + + poly_ctx = nmod_poly_ctx.new(mod) + scalar_ctx = poly_ctx.scalar_ctx + + ctx = nmod_mat_ctx.__new__(nmod_mat_ctx) + ctx.mod = scalar_ctx.mod + ctx._is_prime = scalar_ctx._is_prime + ctx.scalar_ctx = scalar_ctx + ctx.poly_ctx = poly_ctx + + return ctx @cython.final - cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1 + cdef inline int any_as_nmod(self, mp_limb_t * val, obj) except -1: + return self.scalar_ctx.any_as_nmod(val, obj) @cython.final - cdef any_as_nmod_mat(self, obj) + cdef inline any_as_nmod_mat(self, obj): + """Convert obj to nmod_mat or return NotImplemented.""" + cdef nmod_mat r + + if typecheck(obj, nmod_mat): + return obj + + x = any_as_fmpz_mat(obj) + if x is not NotImplemented: + r = self.new_nmod_mat(fmpz_mat_nrows((x).val), + fmpz_mat_ncols((x).val)) + fmpz_mat_get_nmod_mat(r.val, (x).val) + return r + + return NotImplemented @cython.final - cdef nmod new_nmod(self) + cdef inline nmod new_nmod(self): + return self.scalar_ctx.new_nmod() @cython.final - cdef nmod_poly new_nmod_poly(self) + cdef inline nmod_poly new_nmod_poly(self): + return self.poly_ctx.new_nmod_poly() @cython.final - cdef nmod_mat new_nmod_mat(self, ulong m, ulong n) + cdef inline nmod_mat new_nmod_mat(self, ulong m, ulong n): + """New initialized nmod_mat of size m x n with context ctx.""" + cdef nmod_mat r = nmod_mat.__new__(nmod_mat) + nmod_mat_init(r.val, m, n, self.mod.n) + r.ctx = self + return r @cython.final - cdef nmod_mat new_nmod_mat_copy(self, nmod_mat other) + cdef inline nmod_mat new_nmod_mat_copy(self, nmod_mat other): + """New copy of nmod_mat other.""" + cdef nmod_mat r = nmod_mat.__new__(nmod_mat) + nmod_mat_init_set(r.val, other.val) + r.ctx = other.ctx + return r @cython.no_gc diff --git a/src/flint/types/nmod_mat.pyx b/src/flint/types/nmod_mat.pyx index 0f7955e5..75450b06 100644 --- a/src/flint/types/nmod_mat.pyx +++ b/src/flint/types/nmod_mat.pyx @@ -1,9 +1,17 @@ cimport cython +from flint.utils.typecheck cimport typecheck +from flint.pyflint cimport global_random_state + from flint.flintlib.flint cimport ulong, mp_limb_t +from flint.flint_base.flint_context cimport thectx +from flint.flint_base.flint_base cimport flint_mat -from flint.flintlib.fmpz_mat cimport fmpz_mat_nrows, fmpz_mat_ncols -from flint.flintlib.fmpz_mat cimport fmpz_mat_get_nmod_mat +from flint.flintlib.fmpz_mat cimport ( + fmpz_mat_nrows, + fmpz_mat_ncols, + fmpz_mat_get_nmod_mat, +) from flint.flintlib.nmod_mat cimport ( nmod_mat_struct, @@ -35,25 +43,18 @@ from flint.flintlib.nmod_mat cimport ( nmod_mat_randtest, ) -from flint.utils.typecheck cimport typecheck from flint.types.fmpz cimport fmpz -from flint.types.fmpz_mat cimport any_as_fmpz_mat from flint.types.fmpz_mat cimport fmpz_mat -from flint.types.nmod cimport nmod, nmod_ctx +from flint.types.nmod cimport nmod from flint.types.nmod_poly cimport nmod_poly -from flint.pyflint cimport global_random_state -from flint.flint_base.flint_context cimport thectx - -from flint.flint_base.flint_base cimport flint_mat - from flint.utils.flint_exceptions import DomainError ctx = thectx -cdef dict _nmod_mat_ctx_cache = {} +_nmod_mat_ctx_cache = {} @cython.no_gc @@ -77,88 +78,6 @@ cdef class nmod_mat_ctx: """Get an ``nmod_poly`` context with modulus ``mod``.""" return nmod_mat_ctx.any_as_nmod_mat_ctx(mod) - @staticmethod - cdef nmod_mat_ctx any_as_nmod_mat_ctx(obj): - """Convert an ``nmod_mat_ctx`` or ``int`` to an ``nmod_mat_ctx``.""" - if typecheck(obj, nmod_mat_ctx): - return obj - if typecheck(obj, int): - return nmod_mat_ctx._get_ctx(obj) - elif typecheck(obj, fmpz): - return nmod_mat_ctx._get_ctx(int(obj)) - raise TypeError("nmod_mat: expected last argument to be an nmod_mat_ctx or an integer") - - @staticmethod - cdef nmod_mat_ctx _get_ctx(int mod): - """Retrieve an nmod_mat context from the cache or create a new one.""" - ctx = _nmod_mat_ctx_cache.get(mod) - if ctx is None: - ctx = _nmod_mat_ctx_cache.setdefault(mod, nmod_mat_ctx._new_ctx(mod)) - return ctx - - @staticmethod - cdef nmod_mat_ctx _new_ctx(ulong mod): - """Create a new nmod_mat context.""" - cdef nmod_ctx scalar_ctx - cdef nmod_poly_ctx poly_ctx - cdef nmod_mat_ctx ctx - - poly_ctx = nmod_poly_ctx.new(mod) - scalar_ctx = poly_ctx.scalar_ctx - - ctx = nmod_mat_ctx.__new__(nmod_mat_ctx) - ctx.mod = scalar_ctx.mod - ctx._is_prime = scalar_ctx._is_prime - ctx.scalar_ctx = scalar_ctx - ctx.poly_ctx = poly_ctx - - return ctx - - @cython.final - cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1: - return self.scalar_ctx.any_as_nmod(val, obj) - - @cython.final - cdef any_as_nmod_mat(self, obj): - """Convert obj to nmod_mat or return NotImplemented.""" - cdef nmod_mat r - - if typecheck(obj, nmod_mat): - return obj - - x = any_as_fmpz_mat(obj) - if x is not NotImplemented: - r = self.new_nmod_mat(fmpz_mat_nrows((x).val), - fmpz_mat_ncols((x).val)) - fmpz_mat_get_nmod_mat(r.val, (x).val) - return r - - return NotImplemented - - @cython.final - cdef nmod new_nmod(self): - return self.scalar_ctx.new_nmod() - - @cython.final - cdef nmod_poly new_nmod_poly(self): - return self.poly_ctx.new_nmod_poly() - - @cython.final - cdef nmod_mat new_nmod_mat(self, ulong m, ulong n): - """New initialized nmod_mat of size m x n with context ctx.""" - cdef nmod_mat r = nmod_mat.__new__(nmod_mat) - nmod_mat_init(r.val, m, n, self.mod.n) - r.ctx = self - return r - - @cython.final - cdef nmod_mat new_nmod_mat_copy(self, nmod_mat other): - """New copy of nmod_mat other.""" - cdef nmod_mat r = nmod_mat.__new__(nmod_mat) - nmod_mat_init_set(r.val, other.val) - r.ctx = other.ctx - return r - def modulus(self): """Get the modulus of the context. diff --git a/src/flint/types/nmod_poly.pxd b/src/flint/types/nmod_poly.pxd index dde80357..34ad454d 100644 --- a/src/flint/types/nmod_poly.pxd +++ b/src/flint/types/nmod_poly.pxd @@ -1,14 +1,28 @@ cimport cython -from flint.flintlib.nmod cimport nmod_t -from flint.flintlib.nmod_poly cimport nmod_poly_t -from flint.flintlib.flint cimport mp_limb_t, ulong +from cpython.list cimport PyList_GET_SIZE +from flint.utils.typecheck cimport typecheck from flint.flint_base.flint_base cimport flint_poly +from flint.flintlib.flint cimport mp_limb_t, ulong +from flint.flintlib.fmpz_poly cimport fmpz_poly_get_nmod_poly + +from flint.flintlib.nmod cimport nmod_t +from flint.flintlib.nmod_poly cimport ( + nmod_poly_t, + nmod_poly_init_preinv, + nmod_poly_fit_length, + nmod_poly_set_coeff_ui, +) + +from flint.types.fmpz_poly cimport fmpz_poly, any_as_fmpz_poly from flint.types.nmod cimport nmod_ctx, nmod +cdef dict _nmod_poly_ctx_cache = {} + + @cython.no_gc cdef class nmod_poly_ctx: cdef nmod_t mod @@ -16,28 +30,80 @@ cdef class nmod_poly_ctx: cdef nmod_ctx scalar_ctx @staticmethod - cdef nmod_poly_ctx any_as_nmod_poly_ctx(obj) + cdef inline nmod_poly_ctx any_as_nmod_poly_ctx(obj): + """Convert an ``nmod_poly_ctx`` or ``int`` to an ``nmod_poly_ctx``.""" + if typecheck(obj, nmod_poly_ctx): + return obj + if typecheck(obj, int): + return nmod_poly_ctx._get_ctx(obj) + raise TypeError("Invalid context/modulus for nmod_poly: %s" % obj) @staticmethod - cdef nmod_poly_ctx _get_ctx(int mod) + cdef inline nmod_poly_ctx _get_ctx(int mod): + """Retrieve an nmod_poly context from the cache or create a new one.""" + ctx = _nmod_poly_ctx_cache.get(mod) + if ctx is None: + ctx = _nmod_poly_ctx_cache.setdefault(mod, nmod_poly_ctx._new_ctx(mod)) + return ctx @staticmethod - cdef nmod_poly_ctx _new_ctx(ulong mod) + cdef inline nmod_poly_ctx _new_ctx(ulong mod): + """Create a new nmod_poly context.""" + cdef nmod_ctx scalar_ctx + cdef nmod_poly_ctx ctx + scalar_ctx = nmod_ctx.new(mod) + + ctx = nmod_poly_ctx.__new__(nmod_poly_ctx) + ctx.mod = scalar_ctx.mod + ctx._is_prime = scalar_ctx._is_prime + ctx.scalar_ctx = scalar_ctx + + return ctx @cython.final - cdef nmod_poly_set_list(self, nmod_poly_t poly, list val) + cdef inline int any_as_nmod(self, mp_limb_t * val, obj) except -1: + return self.scalar_ctx.any_as_nmod(val, obj) @cython.final - cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1 + cdef inline any_as_nmod_poly(self, obj): + cdef nmod_poly r + cdef mp_limb_t v + # XXX: should check that modulus is the same here, and not all over the place + if typecheck(obj, nmod_poly): + return obj + if self.any_as_nmod(&v, obj): + r = self.new_nmod_poly() + nmod_poly_set_coeff_ui(r.val, 0, v) + return r + x = any_as_fmpz_poly(obj) + if x is not NotImplemented: + r = self.new_nmod_poly() + fmpz_poly_get_nmod_poly(r.val, (x).val) + return r + return NotImplemented @cython.final - cdef any_as_nmod_poly(self, obj) + cdef inline nmod new_nmod(self): + return self.scalar_ctx.new_nmod() @cython.final - cdef nmod new_nmod(self) + cdef inline nmod_poly new_nmod_poly(self): + cdef nmod_poly p = nmod_poly.__new__(nmod_poly) + nmod_poly_init_preinv(p.val, self.mod.n, self.mod.ninv) + p.ctx = self + return p @cython.final - cdef nmod_poly new_nmod_poly(self) + cdef inline nmod_poly_set_list(self, nmod_poly_t poly, list val): + cdef long i, n + cdef mp_limb_t v + n = PyList_GET_SIZE(val) + nmod_poly_fit_length(poly, n) + for i from 0 <= i < n: + if self.any_as_nmod(&v, val[i]): + nmod_poly_set_coeff_ui(poly, i, v) + else: + raise TypeError("unsupported coefficient in list") @cython.no_gc diff --git a/src/flint/types/nmod_poly.pyx b/src/flint/types/nmod_poly.pyx index 77809b8e..f2935e36 100644 --- a/src/flint/types/nmod_poly.pyx +++ b/src/flint/types/nmod_poly.pyx @@ -1,22 +1,68 @@ cimport cython -from cpython.list cimport PyList_GET_SIZE from flint.flint_base.flint_base cimport flint_poly from flint.utils.typecheck cimport typecheck -from flint.types.fmpz cimport fmpz, any_as_fmpz -from flint.types.fmpz_poly cimport any_as_fmpz_poly -from flint.types.fmpz_poly cimport fmpz_poly -from flint.types.nmod cimport nmod, nmod_ctx -from flint.flintlib.nmod_vec cimport * -from flint.flintlib.nmod_poly cimport * -from flint.flintlib.nmod_poly_factor cimport * +from flint.flintlib.flint cimport mp_limb_t, ulong, slong from flint.flintlib.fmpz_poly cimport fmpz_poly_get_nmod_poly +from flint.types.fmpz cimport fmpz, any_as_fmpz +from flint.types.fmpz_poly cimport fmpz_poly +from flint.types.nmod cimport nmod + +from flint.flintlib.nmod_poly cimport ( + nmod_poly_init, + nmod_poly_clear, + nmod_poly_set, + nmod_poly_fit_length, + nmod_poly_zero, + nmod_poly_set_coeff_ui, + nmod_poly_get_coeff_ui, + nmod_poly_length, + nmod_poly_degree, + nmod_poly_modulus, + nmod_poly_equal, + nmod_poly_is_zero, + nmod_poly_is_one, + nmod_poly_is_gen, + nmod_poly_reverse, + nmod_poly_evaluate_nmod, + nmod_poly_derivative, + nmod_poly_integral, + nmod_poly_neg, + nmod_poly_add, + nmod_poly_sub, + nmod_poly_mul, + nmod_poly_div, + nmod_poly_divrem, + nmod_poly_pow, + nmod_poly_compose, + nmod_poly_compose_mod, + nmod_poly_powmod_ui_binexp, + nmod_poly_powmod_fmpz_binexp_preinv, + nmod_poly_powmod_x_fmpz_preinv, + nmod_poly_inv_series, + nmod_poly_sqrt, + nmod_poly_xgcd, + nmod_poly_gcd, + nmod_poly_deflation, + nmod_poly_deflate, +) + +from flint.flintlib.nmod_poly_factor cimport ( + nmod_poly_factor_t, + nmod_poly_factor_init, + nmod_poly_factor_clear, + nmod_poly_factor, + nmod_poly_factor_with_berlekamp, + nmod_poly_factor_with_cantor_zassenhaus, + nmod_poly_factor_squarefree, +) + from flint.utils.flint_exceptions import DomainError -cdef dict _nmod_poly_ctx_cache = {} +_nmod_poly_ctx_cache = {} @cython.no_gc @@ -37,77 +83,6 @@ cdef class nmod_poly_ctx: """Get an ``nmod_poly`` context with modulus ``mod``.""" return nmod_poly_ctx.any_as_nmod_poly_ctx(mod) - @staticmethod - cdef nmod_poly_ctx any_as_nmod_poly_ctx(obj): - """Convert an ``nmod_poly_ctx`` or ``int`` to an ``nmod_poly_ctx``.""" - if typecheck(obj, nmod_poly_ctx): - return obj - if typecheck(obj, int): - return nmod_poly_ctx._get_ctx(obj) - raise TypeError("Invalid context/modulus for nmod_poly: %s" % obj) - - @staticmethod - cdef nmod_poly_ctx _get_ctx(int mod): - """Retrieve an nmod_poly context from the cache or create a new one.""" - ctx = _nmod_poly_ctx_cache.get(mod) - if ctx is None: - ctx = _nmod_poly_ctx_cache.setdefault(mod, nmod_poly_ctx._new_ctx(mod)) - return ctx - - @staticmethod - cdef nmod_poly_ctx _new_ctx(ulong mod): - """Create a new nmod_poly context.""" - cdef nmod_ctx scalar_ctx - cdef nmod_poly_ctx ctx - scalar_ctx = nmod_ctx.new(mod) - - ctx = nmod_poly_ctx.__new__(nmod_poly_ctx) - ctx.mod = scalar_ctx.mod - ctx._is_prime = scalar_ctx._is_prime - ctx.scalar_ctx = scalar_ctx - - return ctx - - cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1: - return self.scalar_ctx.any_as_nmod(val, obj) - - cdef any_as_nmod_poly(self, obj): - cdef nmod_poly r - cdef mp_limb_t v - # XXX: should check that modulus is the same here, and not all over the place - if typecheck(obj, nmod_poly): - return obj - if self.any_as_nmod(&v, obj): - r = self.new_nmod_poly() - nmod_poly_set_coeff_ui(r.val, 0, v) - return r - x = any_as_fmpz_poly(obj) - if x is not NotImplemented: - r = self.new_nmod_poly() - fmpz_poly_get_nmod_poly(r.val, (x).val) - return r - return NotImplemented - - cdef nmod new_nmod(self): - return self.scalar_ctx.new_nmod() - - cdef nmod_poly new_nmod_poly(self): - cdef nmod_poly p = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(p.val, self.mod.n, self.mod.ninv) - p.ctx = self - return p - - cdef nmod_poly_set_list(self, nmod_poly_t poly, list val): - cdef long i, n - cdef mp_limb_t v - n = PyList_GET_SIZE(val) - nmod_poly_fit_length(poly, n) - for i from 0 <= i < n: - if self.any_as_nmod(&v, val[i]): - nmod_poly_set_coeff_ui(poly, i, v) - else: - raise TypeError("unsupported coefficient in list") - def modulus(self): """Get the modulus of the context.