diff --git a/Cargo.toml b/Cargo.toml index df3a2233..3fe47f2d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,7 +31,7 @@ name = "shootout-pidigits" [dependencies] [dependencies.num-integer] -version = "0.1.38" +version = "0.1.39" default-features = false [dependencies.num-traits] diff --git a/benches/bigint.rs b/benches/bigint.rs index 7a763c2d..6893f137 100644 --- a/benches/bigint.rs +++ b/benches/bigint.rs @@ -4,6 +4,7 @@ extern crate test; extern crate num_bigint; extern crate num_traits; +extern crate num_integer; extern crate rand; use std::mem::replace; @@ -342,3 +343,27 @@ fn modpow_even(b: &mut Bencher) { b.iter(|| base.modpow(&e, &m)); } + +#[bench] +fn roots_sqrt(b: &mut Bencher) { + let mut rng = get_rng(); + let x = rng.gen_biguint(2048); + + b.iter(|| x.sqrt()); +} + +#[bench] +fn roots_cbrt(b: &mut Bencher) { + let mut rng = get_rng(); + let x = rng.gen_biguint(2048); + + b.iter(|| x.cbrt()); +} + +#[bench] +fn roots_nth_100(b: &mut Bencher) { + let mut rng = get_rng(); + let x = rng.gen_biguint(2048); + + b.iter(|| x.nth_root(100)); +} diff --git a/src/bigint.rs b/src/bigint.rs index 3c8d2962..93bb6b26 100644 --- a/src/bigint.rs +++ b/src/bigint.rs @@ -16,7 +16,7 @@ use std::iter::{Product, Sum}; #[cfg(feature = "serde")] use serde; -use integer::Integer; +use integer::{Integer, Roots}; use traits::{ToPrimitive, FromPrimitive, Num, CheckedAdd, CheckedSub, CheckedMul, CheckedDiv, Signed, Zero, One}; @@ -1802,6 +1802,25 @@ impl Integer for BigInt { } } +impl Roots for BigInt { + fn nth_root(&self, n: u32) -> Self { + assert!(!(self.is_negative() && n.is_even()), + "root of degree {} is imaginary", n); + + BigInt::from_biguint(self.sign, self.data.nth_root(n)) + } + + fn sqrt(&self) -> Self { + assert!(!self.is_negative(), "square root is imaginary"); + + BigInt::from_biguint(self.sign, self.data.sqrt()) + } + + fn cbrt(&self) -> Self { + BigInt::from_biguint(self.sign, self.data.cbrt()) + } +} + impl ToPrimitive for BigInt { #[inline] fn to_i64(&self) -> Option { @@ -2538,6 +2557,24 @@ impl BigInt { }; BigInt::from_biguint(sign, mag) } + + /// Returns the truncated principal square root of `self` -- + /// see [Roots::sqrt](https://docs.rs/num-integer/0.1/num_integer/trait.Roots.html#method.sqrt). + pub fn sqrt(&self) -> Self { + Roots::sqrt(self) + } + + /// Returns the truncated principal cube root of `self` -- + /// see [Roots::cbrt](https://docs.rs/num-integer/0.1/num_integer/trait.Roots.html#method.cbrt). + pub fn cbrt(&self) -> Self { + Roots::cbrt(self) + } + + /// Returns the truncated principal `n`th root of `self` -- + /// See [Roots::nth_root](https://docs.rs/num-integer/0.1/num_integer/trait.Roots.html#tymethod.nth_root). + pub fn nth_root(&self, n: u32) -> Self { + Roots::nth_root(self, n) + } } impl_sum_iter_type!(BigInt); diff --git a/src/biguint.rs b/src/biguint.rs index 5d4aaf89..e7a3ce19 100644 --- a/src/biguint.rs +++ b/src/biguint.rs @@ -17,9 +17,9 @@ use std::ascii::AsciiExt; #[cfg(feature = "serde")] use serde; -use integer::Integer; +use integer::{Integer, Roots}; use traits::{ToPrimitive, FromPrimitive, Float, Num, Unsigned, CheckedAdd, CheckedSub, CheckedMul, - CheckedDiv, Zero, One}; + CheckedDiv, Zero, One, pow}; use big_digit::{self, BigDigit, DoubleBigDigit}; @@ -1026,6 +1026,94 @@ impl Integer for BigUint { } } +impl Roots for BigUint { + // nth_root, sqrt and cbrt use Newton's method to compute + // principal root of a given degree for a given integer. + + // Reference: + // Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.14 + fn nth_root(&self, n: u32) -> Self { + assert!(n > 0, "root degree n must be at least 1"); + + if self.is_zero() || self.is_one() { + return self.clone() + } + + match n { // Optimize for small n + 1 => return self.clone(), + 2 => return self.sqrt(), + 3 => return self.cbrt(), + _ => (), + } + + let n = n as usize; + let n_min_1 = n - 1; + + let guess = BigUint::one() << (self.bits()/n + 1); + + let mut u = guess; + let mut s: BigUint; + + loop { + s = u; + let q = self / pow(s.clone(), n_min_1); + let t: BigUint = n_min_1 * &s + q; + + u = t / n; + + if u >= s { break; } + } + + s + } + + // Reference: + // Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13 + fn sqrt(&self) -> Self { + if self.is_zero() || self.is_one() { + return self.clone() + } + + let guess = BigUint::one() << (self.bits()/2 + 1); + + let mut u = guess; + let mut s: BigUint; + + loop { + s = u; + let q = self / &s; + let t: BigUint = &s + q; + u = t >> 1; + + if u >= s { break; } + } + + s + } + + fn cbrt(&self) -> Self { + if self.is_zero() || self.is_one() { + return self.clone() + } + + let guess = BigUint::one() << (self.bits()/3 + 1); + + let mut u = guess; + let mut s: BigUint; + + loop { + s = u; + let q = self / (&s * &s); + let t: BigUint = (&s << 1) + q; + u = t / 3u32; + + if u >= s { break; } + } + + s + } +} + fn high_bits_to_u64(v: &BigUint) -> u64 { match v.data.len() { 0 => 0, @@ -1749,6 +1837,24 @@ impl BigUint { } acc } + + /// Returns the truncated principal square root of `self` -- + /// see [Roots::sqrt](https://docs.rs/num-integer/0.1/num_integer/trait.Roots.html#method.sqrt) + pub fn sqrt(&self) -> Self { + Roots::sqrt(self) + } + + /// Returns the truncated principal cube root of `self` -- + /// see [Roots::cbrt](https://docs.rs/num-integer/0.1/num_integer/trait.Roots.html#method.cbrt). + pub fn cbrt(&self) -> Self { + Roots::cbrt(self) + } + + /// Returns the truncated principal `n`th root of `self` -- + /// see [Roots::nth_root](https://docs.rs/num-integer/0.1/num_integer/trait.Roots.html#tymethod.nth_root). + pub fn nth_root(&self, n: u32) -> Self { + Roots::nth_root(self, n) + } } /// Returns the number of least-significant bits that are zero, diff --git a/tests/roots.rs b/tests/roots.rs new file mode 100644 index 00000000..58838b76 --- /dev/null +++ b/tests/roots.rs @@ -0,0 +1,104 @@ +extern crate num_bigint; +extern crate num_integer; +extern crate num_traits; + +mod biguint { + use num_bigint::BigUint; + use num_traits::pow; + use std::str::FromStr; + + fn check(x: u64, n: u32) { + let big_x = BigUint::from(x); + let res = big_x.nth_root(n); + + if n == 2 { + assert_eq!(&res, &big_x.sqrt()) + } else if n == 3 { + assert_eq!(&res, &big_x.cbrt()) + } + + assert!(pow(res.clone(), n as usize) <= big_x); + assert!(pow(res.clone() + 1u32, n as usize) > big_x); + } + + #[test] + fn test_sqrt() { + check(99, 2); + check(100, 2); + check(120, 2); + } + + #[test] + fn test_cbrt() { + check(8, 3); + check(26, 3); + } + + #[test] + fn test_nth_root() { + check(0, 1); + check(10, 1); + check(100, 4); + } + + #[test] + #[should_panic] + fn test_nth_root_n_is_zero() { + check(4, 0); + } + + #[test] + fn test_nth_root_big() { + let x = BigUint::from_str("123_456_789").unwrap(); + let expected = BigUint::from(6u32); + + assert_eq!(x.nth_root(10), expected); + } +} + +mod bigint { + use num_bigint::BigInt; + use num_traits::{Signed, pow}; + + fn check(x: i64, n: u32) { + let big_x = BigInt::from(x); + let res = big_x.nth_root(n); + + if n == 2 { + assert_eq!(&res, &big_x.sqrt()) + } else if n == 3 { + assert_eq!(&res, &big_x.cbrt()) + } + + if big_x.is_negative() { + assert!(pow(res.clone() - 1u32, n as usize) < big_x); + assert!(pow(res.clone(), n as usize) >= big_x); + } else { + assert!(pow(res.clone(), n as usize) <= big_x); + assert!(pow(res.clone() + 1u32, n as usize) > big_x); + } + } + + #[test] + fn test_nth_root() { + check(-100, 3); + } + + #[test] + #[should_panic] + fn test_nth_root_x_neg_n_even() { + check(-100, 4); + } + + #[test] + #[should_panic] + fn test_sqrt_x_neg() { + check(-4, 2); + } + + #[test] + fn test_cbrt() { + check(8, 3); + check(-8, 3); + } +}