From 64163e138df1d4233de5940280583dc19c6771c9 Mon Sep 17 00:00:00 2001 From: Mizar Date: Sat, 21 Jan 2023 00:53:15 +0900 Subject: [PATCH 1/3] internal_math: mul_mod fix for 2^31 u32 { // [1] m = 1 // a = b = im = 0, so okay // [2] m >= 2 - // im = ceil(2^64 / m) + // im = ceil(2^64 / m) = floor((2^64 - 1) / m) + 1 // -> im * m = 2^64 + r (0 <= r < m) // let z = a*b = c*m + d (0 <= c, d < m) // a*b * im = (c*m + d) * im = c*(im*m) + d*im = c*2^64 + c*r + d*im // c*r + d*im < m * m + m * im < m * m + 2^64 + m <= 2^64 + m * (m + 1) < 2^64 * 2 // ((ab * im) >> 64) == c or c + 1 - let mut z = a as u64; - z *= b as u64; + let z = (a as u64) * (b as u64); let x = (((z as u128) * (im as u128)) >> 64) as u64; - let mut v = z.wrapping_sub(x.wrapping_mul(m as u64)) as u32; - if m <= v { - v = v.wrapping_add(m); + match z.overflowing_sub(x.wrapping_mul(m as u64)) { + (v, true) => (v as u32).wrapping_add(m), + (v, false) => v as u32, } - v } /// # Parameters @@ -280,6 +278,14 @@ mod tests { let b = Barrett::new(2147483647); assert_eq!(b.umod(), 2147483647); assert_eq!(b.mul(1073741824, 2147483645), 2147483646); + + // test `2^31 < self._m < 2^32` case. + let b = Barrett::new(3221225471); + assert_eq!(b.umod(), 3221225471); + assert_eq!(b.mul(3188445886, 2844002853), 1840468257); + assert_eq!(b.mul(2834869488, 2779159607), 2084027561); + assert_eq!(b.mul(3032263594, 3039996727), 2130247251); + assert_eq!(b.mul(3029175553, 3140869278), 1892378237); } #[test] From ab2c3ed512c4470374c09cd5ce0089af8c8894f4 Mon Sep 17 00:00:00 2001 From: Mizar Date: Sat, 21 Jan 2023 00:52:08 +0900 Subject: [PATCH 2/3] modint: add/sub impl fix for 2^31 Self { let modulus = Self::modulus(); - let mut val = lhs.val() + rhs.val(); - if val >= modulus { - val -= modulus; - } + let v = u64::from(lhs.val()) + u64::from(rhs.val()); + let val = match v.overflowing_sub(u64::from(modulus)) { + (_, true) => v as u32, + (w, false) => w as u32, + }; Self::raw(val) } #[inline] fn sub_impl(lhs: Self, rhs: Self) -> Self { let modulus = Self::modulus(); - let mut val = lhs.val().wrapping_sub(rhs.val()); - if val >= modulus { - val = val.wrapping_add(modulus) - } + let val = match lhs.val().overflowing_sub(rhs.val()) { + (v, true) => v.wrapping_add(modulus), + (v, false) => v, + }; Self::raw(val) } @@ -1050,6 +1051,8 @@ impl_folding! { #[cfg(test)] mod tests { + #![allow(clippy::unreadable_literal)] + use crate::modint::ModInt; use crate::modint::ModInt1000000007; #[test] @@ -1157,4 +1160,29 @@ mod tests { c /= b; assert_eq!(expected, c); } + + // test `2^31 < modulus < 2^32` case + // https://github.com/rust-lang-ja/ac-library-rs/issues/111 + #[test] + fn dynamic_modint_m32() { + let m = 3221225471; + ModInt::set_modulus(m); + let f = ModInt::new::; + assert_eq!(f(1398188832) + f(3184083880), f(1361047241)); + assert_eq!(f(3013899062) + f(2238406135), f(2031079726)); + assert_eq!(f(2699997885) + f(2745140255), f(2223912669)); + assert_eq!(f(2824399978) + f(2531872141), f(2135046648)); + assert_eq!(f(36496612) - f(2039504668), f(1218217415)); + assert_eq!(f(266176802) - f(1609833977), f(1877568296)); + assert_eq!(f(713535382) - f(2153383999), f(1781376854)); + assert_eq!(f(1249965147) - f(3144251805), f(1326938813)); + assert_eq!(f(2692223381) * f(2935379475), f(2084179397)); + assert_eq!(f(2800462205) * f(2822998916), f(2089431198)); + assert_eq!(f(3061947734) * f(3210920667), f(1962208034)); + assert_eq!(f(3138997926) * f(2994465129), f(1772479317)); + assert_eq!(f(2947552629) / f(576466398), f(2041593039)); + assert_eq!(f(2914694891) / f(399734126), f(1983162347)); + assert_eq!(f(2202862138) / f(1154428799), f(2139936238)); + assert_eq!(f(3037207894) / f(2865447143), f(1894581230)); + } } From 122e2dd4b9c395852b5a1c96d8439de074fb417f Mon Sep 17 00:00:00 2001 From: Mizar Date: Sat, 21 Jan 2023 16:38:25 +0900 Subject: [PATCH 3/3] Corner cases of "modint" when mod = 1 https://github.com/rust-lang-ja/ac-library-rs/issues/110 --- src/modint.rs | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/modint.rs b/src/modint.rs index 38ffde6..dbb093a 100644 --- a/src/modint.rs +++ b/src/modint.rs @@ -673,7 +673,7 @@ pub trait ModIntBase: #[inline] fn pow(self, mut n: u64) -> Self { let mut x = self; - let mut r = Self::raw(1); + let mut r = Self::raw(u32::from(Self::modulus() > 1)); while n > 0 { if n & 1 == 1 { r *= x; @@ -1044,9 +1044,9 @@ macro_rules! impl_folding { impl_folding! { impl Sum<_> for StaticModInt { fn sum(_) -> _ { _(Self::raw(0), Add::add) } } - impl Product<_> for StaticModInt { fn product(_) -> _ { _(Self::raw(1), Mul::mul) } } + impl Product<_> for StaticModInt { fn product(_) -> _ { _(Self::raw(u32::from(Self::modulus() > 1)), Mul::mul) } } impl Sum<_> for DynamicModInt { fn sum(_) -> _ { _(Self::raw(0), Add::add) } } - impl Product<_> for DynamicModInt { fn product(_) -> _ { _(Self::raw(1), Mul::mul) } } + impl Product<_> for DynamicModInt { fn product(_) -> _ { _(Self::raw(u32::from(Self::modulus() > 1)), Mul::mul) } } } #[cfg(test)] @@ -1161,6 +1161,19 @@ mod tests { assert_eq!(expected, c); } + // Corner cases of "modint" when mod = 1 + // https://github.com/rust-lang-ja/ac-library-rs/issues/110 + #[test] + fn mod1_corner_case() { + ModInt::set_modulus(1); // !! + + let x: ModInt = std::iter::empty::().product(); + assert_eq!(x.val(), 0); + + let y = ModInt::new(123).pow(0); + assert_eq!(y.val(), 0); + } + // test `2^31 < modulus < 2^32` case // https://github.com/rust-lang-ja/ac-library-rs/issues/111 #[test]