diff --git a/src/modint.rs b/src/modint.rs index 10e261f..82f9634 100644 --- a/src/modint.rs +++ b/src/modint.rs @@ -673,7 +673,7 @@ pub trait ModIntBase: #[inline] fn pow(self, mut n: u64) -> Self { let mut x = self; - let mut r = Self::raw(1); + let mut r = Self::raw(u32::from(Self::modulus() > 1)); while n > 0 { if n & 1 == 1 { r *= x; @@ -1046,9 +1046,9 @@ macro_rules! impl_folding { impl_folding! { impl Sum<_> for StaticModInt { fn sum(_) -> _ { _(Self::raw(0), Add::add) } } - impl Product<_> for StaticModInt { fn product(_) -> _ { _(Self::raw(1), Mul::mul) } } + impl Product<_> for StaticModInt { fn product(_) -> _ { _(Self::raw(u32::from(Self::modulus() > 1)), Mul::mul) } } impl Sum<_> for DynamicModInt { fn sum(_) -> _ { _(Self::raw(0), Add::add) } } - impl Product<_> for DynamicModInt { fn product(_) -> _ { _(Self::raw(1), Mul::mul) } } + impl Product<_> for DynamicModInt { fn product(_) -> _ { _(Self::raw(u32::from(Self::modulus() > 1)), Mul::mul) } } } #[cfg(test)] @@ -1163,6 +1163,19 @@ mod tests { assert_eq!(expected, c); } + // Corner cases of "modint" when mod = 1 + // https://github.com/rust-lang-ja/ac-library-rs/issues/110 + #[test] + fn mod1_corner_case() { + ModInt::set_modulus(1); // !! + + let x: ModInt = std::iter::empty::().product(); + assert_eq!(x.val(), 0); + + let y = ModInt::new(123).pow(0); + assert_eq!(y.val(), 0); + } + // test `2^31 < modulus < 2^32` case // https://github.com/rust-lang-ja/ac-library-rs/issues/111 #[test]