Skip to content

Commit

Permalink
Merge pull request #256 from dhardy/distribution
Browse files Browse the repository at this point in the history
Replace distribution::Sample with Distribution + polymorphism over Rng
  • Loading branch information
dhardy authored Feb 20, 2018
2 parents 18e8e91 + 11b2b45 commit 8ce7435
Show file tree
Hide file tree
Showing 14 changed files with 788 additions and 687 deletions.
4 changes: 2 additions & 2 deletions benches/distributions/exponential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ use std::mem::size_of;
use test::Bencher;
use rand;
use rand::distributions::exponential::Exp;
use rand::distributions::Sample;
use rand::distributions::Distribution;

#[bench]
fn rand_exp(b: &mut Bencher) {
let mut rng = rand::weak_rng();
let mut exp = Exp::new(2.71828 * 3.14159);
let exp = Exp::new(2.71828 * 3.14159);

b.iter(|| {
for _ in 0..::RAND_BENCH_N {
Expand Down
6 changes: 3 additions & 3 deletions benches/distributions/gamma.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::mem::size_of;
use test::Bencher;
use rand;
use rand::distributions::IndependentSample;
use rand::distributions::Distribution;
use rand::distributions::gamma::Gamma;

#[bench]
Expand All @@ -11,7 +11,7 @@ fn bench_gamma_large_shape(b: &mut Bencher) {

b.iter(|| {
for _ in 0..::RAND_BENCH_N {
gamma.ind_sample(&mut rng);
gamma.sample(&mut rng);
}
});
b.bytes = size_of::<f64>() as u64 * ::RAND_BENCH_N;
Expand All @@ -24,7 +24,7 @@ fn bench_gamma_small_shape(b: &mut Bencher) {

b.iter(|| {
for _ in 0..::RAND_BENCH_N {
gamma.ind_sample(&mut rng);
gamma.sample(&mut rng);
}
});
b.bytes = size_of::<f64>() as u64 * ::RAND_BENCH_N;
Expand Down
4 changes: 2 additions & 2 deletions benches/distributions/normal.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use std::mem::size_of;
use test::Bencher;
use rand;
use rand::distributions::Sample;
use rand::distributions::Distribution;
use rand::distributions::normal::Normal;

#[bench]
fn rand_normal(b: &mut Bencher) {
let mut rng = rand::weak_rng();
let mut normal = Normal::new(-2.71828, 3.14159);
let normal = Normal::new(-2.71828, 3.14159);

b.iter(|| {
for _ in 0..::RAND_BENCH_N {
Expand Down
8 changes: 4 additions & 4 deletions benches/generators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ const BYTES_LEN: usize = 1024;
use std::mem::size_of;
use test::{black_box, Bencher};

use rand::{RngCore, Rng, NewRng, StdRng, OsRng, JitterRng, EntropyRng};
use rand::{RngCore, Rng, SeedableRng, NewRng, StdRng, OsRng, JitterRng, EntropyRng};
use rand::{XorShiftRng, Hc128Rng, IsaacRng, Isaac64Rng, ChaChaRng};
use rand::reseeding::ReseedingRng;

Expand Down Expand Up @@ -41,7 +41,7 @@ macro_rules! gen_uint {
($fnn:ident, $ty:ty, $gen:ident) => {
#[bench]
fn $fnn(b: &mut Bencher) {
let mut rng: $gen = OsRng::new().unwrap().gen();
let mut rng = $gen::new().unwrap();
b.iter(|| {
for _ in 0..RAND_BENCH_N {
black_box(rng.gen::<$ty>());
Expand Down Expand Up @@ -96,9 +96,9 @@ macro_rules! init_gen {
($fnn:ident, $gen:ident) => {
#[bench]
fn $fnn(b: &mut Bencher) {
let mut rng: XorShiftRng = OsRng::new().unwrap().gen();
let mut rng = XorShiftRng::new().unwrap();
b.iter(|| {
let r2: $gen = rng.gen();
let r2 = $gen::from_rng(&mut rng).unwrap();
black_box(r2);
});
}
Expand Down
51 changes: 25 additions & 26 deletions src/distributions/exponential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@

//! The exponential distribution.
use {Rng, Rand};
use distributions::{ziggurat, ziggurat_tables, Sample, IndependentSample};
use {Rng};
use distributions::{ziggurat, ziggurat_tables, Distribution};

/// A wrapper around an `f64` to generate Exp(1) random numbers.
/// Samples floating-point numbers according to the exponential distribution,
/// with rate parameter `λ = 1`. This is equivalent to `Exp::new(1.0)` or
/// sampling with `-rng.gen::<f64>().ln()`, but faster.
///
/// See `Exp` for the general exponential distribution.
///
Expand All @@ -27,33 +29,33 @@ use distributions::{ziggurat, ziggurat_tables, Sample, IndependentSample};
/// College, Oxford
///
/// # Example
///
/// ```rust
/// use rand::distributions::exponential::Exp1;
/// use rand::{weak_rng, Rng};
/// use rand::distributions::Exp1;
///
/// let Exp1(x) = rand::random();
/// println!("{}", x);
/// let val: f64 = weak_rng().sample(Exp1);
/// println!("{}", val);
/// ```
#[derive(Clone, Copy, Debug)]
pub struct Exp1(pub f64);
pub struct Exp1;

// This could be done via `-rng.gen::<f64>().ln()` but that is slower.
impl Rand for Exp1 {
impl Distribution<f64> for Exp1 {
#[inline]
fn rand<R:Rng>(rng: &mut R) -> Exp1 {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
#[inline]
fn pdf(x: f64) -> f64 {
(-x).exp()
}
#[inline]
fn zero_case<R:Rng>(rng: &mut R, _u: f64) -> f64 {
fn zero_case<R: Rng + ?Sized>(rng: &mut R, _u: f64) -> f64 {
ziggurat_tables::ZIG_EXP_R - rng.gen::<f64>().ln()
}

Exp1(ziggurat(rng, false,
&ziggurat_tables::ZIG_EXP_X,
&ziggurat_tables::ZIG_EXP_F,
pdf, zero_case))
ziggurat(rng, false,
&ziggurat_tables::ZIG_EXP_X,
&ziggurat_tables::ZIG_EXP_F,
pdf, zero_case)
}
}

Expand All @@ -65,10 +67,10 @@ impl Rand for Exp1 {
/// # Example
///
/// ```rust
/// use rand::distributions::{Exp, IndependentSample};
/// use rand::distributions::{Exp, Distribution};
///
/// let exp = Exp::new(2.0);
/// let v = exp.ind_sample(&mut rand::thread_rng());
/// let v = exp.sample(&mut rand::thread_rng());
/// println!("{} is from a Exp(2) distribution", v);
/// ```
#[derive(Clone, Copy, Debug)]
Expand All @@ -87,28 +89,25 @@ impl Exp {
}
}

impl Sample<f64> for Exp {
fn sample<R: Rng>(&mut self, rng: &mut R) -> f64 { self.ind_sample(rng) }
}
impl IndependentSample<f64> for Exp {
fn ind_sample<R: Rng>(&self, rng: &mut R) -> f64 {
let Exp1(n) = rng.gen::<Exp1>();
impl Distribution<f64> for Exp {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
let n: f64 = rng.sample(Exp1);
n * self.lambda_inverse
}
}

#[cfg(test)]
mod test {
use distributions::{Sample, IndependentSample};
use distributions::Distribution;
use super::Exp;

#[test]
fn test_exp() {
let mut exp = Exp::new(10.0);
let exp = Exp::new(10.0);
let mut rng = ::test::rng(221);
for _ in 0..1000 {
assert!(exp.sample(&mut rng) >= 0.0);
assert!(exp.ind_sample(&mut rng) >= 0.0);
assert!(exp.sample(&mut rng) >= 0.0);
}
}
#[test]
Expand Down
199 changes: 199 additions & 0 deletions src/distributions/float.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
// Copyright 2017 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// https://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! Basic floating-point number distributions

/// A distribution to sample floating point numbers uniformly in the open
/// interval `(0, 1)` (not including either endpoint).
///
/// See also: [`Closed01`] for the closed `[0, 1]`; [`Uniform`] for the
/// half-open `[0, 1)`.
///
/// # Example
/// ```rust
/// use rand::{weak_rng, Rng};
/// use rand::distributions::Open01;
///
/// let val: f32 = weak_rng().sample(Open01);
/// println!("f32 from (0,1): {}", val);
/// ```
///
/// [`Uniform`]: struct.Uniform.html
/// [`Closed01`]: struct.Closed01.html
#[derive(Clone, Copy, Debug)]
pub struct Open01;

/// A distribution to sample floating point numbers uniformly in the closed
/// interval `[0, 1]` (including both endpoints).
///
/// See also: [`Open01`] for the open `(0, 1)`; [`Uniform`] for the half-open
/// `[0, 1)`.
///
/// # Example
/// ```rust
/// use rand::{weak_rng, Rng};
/// use rand::distributions::Closed01;
///
/// let val: f32 = weak_rng().sample(Closed01);
/// println!("f32 from [0,1]: {}", val);
/// ```
///
/// [`Uniform`]: struct.Uniform.html
/// [`Open01`]: struct.Open01.html
#[derive(Clone, Copy, Debug)]
pub struct Closed01;


macro_rules! float_impls {
($mod_name:ident, $ty:ty, $mantissa_bits:expr, $method_name:ident) => {
mod $mod_name {
use Rng;
use distributions::{Distribution, Uniform};
use super::{Open01, Closed01};

const SCALE: $ty = (1u64 << $mantissa_bits) as $ty;

impl Distribution<$ty> for Uniform {
/// Generate a floating point number in the half-open
/// interval `[0,1)`.
///
/// See `Closed01` for the closed interval `[0,1]`,
/// and `Open01` for the open interval `(0,1)`.
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty {
rng.$method_name()
}
}
impl Distribution<$ty> for Open01 {
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty {
// add 0.5 * epsilon, so that smallest number is
// greater than 0, and largest number is still
// less than 1, specifically 1 - 0.5 * epsilon.
rng.$method_name() + 0.5 / SCALE
}
}
impl Distribution<$ty> for Closed01 {
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty {
// rescale so that 1.0 - epsilon becomes 1.0
// precisely.
rng.$method_name() * SCALE / (SCALE - 1.0)
}
}
}
}
}
float_impls! { f64_rand_impls, f64, 52, next_f64 }
float_impls! { f32_rand_impls, f32, 23, next_f32 }


#[cfg(test)]
mod tests {
use {Rng, RngCore, impls};
use distributions::{Open01, Closed01};

const EPSILON32: f32 = ::core::f32::EPSILON;
const EPSILON64: f64 = ::core::f64::EPSILON;

struct ConstantRng(u64);
impl RngCore for ConstantRng {
fn next_u32(&mut self) -> u32 {
let ConstantRng(v) = *self;
v as u32
}
fn next_u64(&mut self) -> u64 {
let ConstantRng(v) = *self;
v
}

fn fill_bytes(&mut self, dest: &mut [u8]) {
impls::fill_bytes_via_u64(self, dest)
}
}

#[test]
fn floating_point_edge_cases() {
let mut zeros = ConstantRng(0);
assert_eq!(zeros.gen::<f32>(), 0.0);
assert_eq!(zeros.gen::<f64>(), 0.0);

let mut one = ConstantRng(1);
assert_eq!(one.gen::<f32>(), EPSILON32);
assert_eq!(one.gen::<f64>(), EPSILON64);

let mut max = ConstantRng(!0);
assert_eq!(max.gen::<f32>(), 1.0 - EPSILON32);
assert_eq!(max.gen::<f64>(), 1.0 - EPSILON64);
}

#[test]
fn fp_closed_edge_cases() {
let mut zeros = ConstantRng(0);
assert_eq!(zeros.sample::<f32, _>(Closed01), 0.0);
assert_eq!(zeros.sample::<f64, _>(Closed01), 0.0);

let mut one = ConstantRng(1);
let one32 = one.sample::<f32, _>(Closed01);
let one64 = one.sample::<f64, _>(Closed01);
assert!(EPSILON32 < one32 && one32 < EPSILON32 * 1.01);
assert!(EPSILON64 < one64 && one64 < EPSILON64 * 1.01);

let mut max = ConstantRng(!0);
assert_eq!(max.sample::<f32, _>(Closed01), 1.0);
assert_eq!(max.sample::<f64, _>(Closed01), 1.0);
}

#[test]
fn fp_open_edge_cases() {
let mut zeros = ConstantRng(0);
assert_eq!(zeros.sample::<f32, _>(Open01), 0.0 + EPSILON32 / 2.0);
assert_eq!(zeros.sample::<f64, _>(Open01), 0.0 + EPSILON64 / 2.0);

let mut one = ConstantRng(1);
let one32 = one.sample::<f32, _>(Open01);
let one64 = one.sample::<f64, _>(Open01);
assert!(EPSILON32 < one32 && one32 < EPSILON32 * 2.0);
assert!(EPSILON64 < one64 && one64 < EPSILON64 * 2.0);

let mut max = ConstantRng(!0);
assert_eq!(max.sample::<f32, _>(Open01), 1.0 - EPSILON32 / 2.0);
assert_eq!(max.sample::<f64, _>(Open01), 1.0 - EPSILON64 / 2.0);
}

#[test]
fn rand_open() {
// this is unlikely to catch an incorrect implementation that
// generates exactly 0 or 1, but it keeps it sane.
let mut rng = ::test::rng(510);
for _ in 0..1_000 {
// strict inequalities
let f: f64 = rng.sample(Open01);
assert!(0.0 < f && f < 1.0);

let f: f32 = rng.sample(Open01);
assert!(0.0 < f && f < 1.0);
}
}

#[test]
fn rand_closed() {
let mut rng = ::test::rng(511);
for _ in 0..1_000 {
// strict inequalities
let f: f64 = rng.sample(Closed01);
assert!(0.0 <= f && f <= 1.0);

let f: f32 = rng.sample(Closed01);
assert!(0.0 <= f && f <= 1.0);
}
}
}
Loading

0 comments on commit 8ce7435

Please sign in to comment.