diff --git a/src/lib.rs b/src/lib.rs index 0281954..c8c551d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,10 +21,11 @@ extern crate std; extern crate num_traits as traits; +use core::cmp::Ordering; use core::mem; -use core::ops::Add; +use core::ops::{Add, Neg, Shr}; -use traits::{Num, Signed, Zero}; +use traits::{Num, NumRef, RefNum, Signed, Zero}; mod roots; pub use roots::Roots; @@ -1013,6 +1014,84 @@ impl_integer_for_usize!(usize, test_integer_usize); #[cfg(has_i128)] impl_integer_for_usize!(u128, test_integer_u128); +/// Calculate greatest common divisor and the corresponding coefficients. +pub fn extended_gcd(a: T, b: T) -> ExtendedGcd +where + for<'a> &'a T: RefNum, +{ + // Euclid's extended algorithm + let (mut s, mut old_s) = (T::zero(), T::one()); + let (mut t, mut old_t) = (T::one(), T::zero()); + let (mut r, mut old_r) = (b, a); + + while r != T::zero() { + let quotient = &old_r / &r; + old_r = old_r - "ient * &r; + mem::swap(&mut old_r, &mut r); + old_s = old_s - "ient * &s; + mem::swap(&mut old_s, &mut s); + old_t = old_t - quotient * &t; + mem::swap(&mut old_t, &mut t); + } + + let _quotients = (t, s); // == (a, b) / gcd + + ExtendedGcd { + gcd: old_r, + x: old_s, + y: old_t, + _hidden: (), + } +} + +/// Find the standard representation of a (mod n). +pub fn normalize(a: T, n: &T) -> T { + let a = a % n; + match a.cmp(&T::zero()) { + Ordering::Less => a + n, + _ => a, + } +} + +/// Calculate the inverse of a (mod n). +pub fn inverse(a: T, n: &T) -> Option +where + for<'a> &'a T: RefNum, +{ + let ExtendedGcd { gcd, x: c, .. } = extended_gcd(a, n.clone()); + if gcd == T::one() { + Some(normalize(c, n)) + } else { + None + } +} + +/// Calculate base^exp (mod modulus). +pub fn powm(base: &T, exp: &T, modulus: &T) -> T +where + T: Integer + NumRef + Clone + Neg + Shr, + for<'a> &'a T: RefNum, +{ + let zero = T::zero(); + let one = T::one(); + let two = &one + &one; + let mut exp = exp.clone(); + let mut result = one.clone(); + let mut base = base % modulus; + if exp < zero { + exp = -exp; + base = inverse(base, modulus).unwrap(); + } + while exp > zero { + if &exp % &two == one { + result = (result * &base) % modulus; + } + exp = exp >> 1; + base = (&base * &base) % modulus; + } + result +} + /// An iterator over binomial coefficients. pub struct IterBinomial { a: T, @@ -1169,6 +1248,38 @@ fn test_lcm_overflow() { check!(u64, 0x8000_0000_0000_0000, 0x02, 0x8000_0000_0000_0000); } +#[test] +fn test_extended_gcd() { + assert_eq!( + extended_gcd(240, 46), + ExtendedGcd { + gcd: 2, + x: -9, + y: 47, + _hidden: () + } + ); +} + +#[test] +fn test_normalize() { + assert_eq!(normalize(10, &7), 3); + assert_eq!(normalize(7, &7), 0); + assert_eq!(normalize(5, &7), 5); + assert_eq!(normalize(-3, &7), 4); +} + +#[test] +fn test_inverse() { + assert_eq!(inverse(5, &7).unwrap(), 3); +} + +#[test] +fn test_powm() { + // `i64::pow` would overflow. + assert_eq!(powm(&11, &19, &7), 4); +} + #[test] fn test_iter_binomial() { macro_rules! check_simple {