diff --git a/src/distribution/laplace.rs b/src/distribution/laplace.rs index a0a85e94..1e02b34e 100644 --- a/src/distribution/laplace.rs +++ b/src/distribution/laplace.rs @@ -82,7 +82,7 @@ impl Laplace { impl ::rand::distributions::Distribution for Laplace { fn sample(&self, rng: &mut R) -> f64 { let x: f64 = rng.gen_range(-0.5..0.5); - self.location - self.scale * x.signum() * (1. - 2. * x).ln() + self.location - self.scale * x.signum() * (1. - 2. * x.abs()).ln() } } @@ -458,4 +458,38 @@ mod tests { let l = try_create(0.1, 0.5); l.sample(&mut thread_rng()); } + + #[test] + fn test_sample_distribution() { + use ::rand::rngs::StdRng; + use ::rand::SeedableRng; + use rand::distributions::Distribution; + + // sanity check sampling + let location = 0.0; + let scale = 1.0; + let n = try_create(location, scale); + let trials = 10_000; + let tolerance = 250; + + for seed in 0..10 { + let mut r: StdRng = SeedableRng::seed_from_u64(seed); + + let result = (0..trials).map(|_| n.sample(&mut r)).fold(0, |sum, val| { + if val > 0.0 { + sum + 1 + } else if val < 0.0 { + sum - 1 + } else { + 0 + } + }); + assert!( + result > -tolerance && result < tolerance, + "Balance is {} for seed {}", + result, + seed + ); + } + } }