From d09b3f48a915e0cd9cc2954e69295db70eefa05b Mon Sep 17 00:00:00 2001 From: SabrinaJewson Date: Fri, 30 Aug 2024 16:58:41 +0100 Subject: [PATCH] Support sampling integers from discrete distributions --- src/distribution/bernoulli.rs | 8 +++++++- src/distribution/binomial.rs | 14 ++++++++++---- src/distribution/categorical.rs | 22 +++++++++++++++------- src/distribution/discrete_uniform.rs | 6 ++++++ src/distribution/geometric.rs | 17 +++++++++++++---- src/distribution/hypergeometric.rs | 14 ++++++++++---- src/distribution/multinomial.rs | 20 ++++++++++++++++++-- src/distribution/negative_binomial.rs | 2 +- src/distribution/poisson.rs | 12 ++++++++++++ 9 files changed, 92 insertions(+), 23 deletions(-) diff --git a/src/distribution/bernoulli.rs b/src/distribution/bernoulli.rs index 61499ebd..79dc22f8 100644 --- a/src/distribution/bernoulli.rs +++ b/src/distribution/bernoulli.rs @@ -86,9 +86,15 @@ impl std::fmt::Display for Bernoulli { } } +impl ::rand::distributions::Distribution for Bernoulli { + fn sample(&self, rng: &mut R) -> bool { + rng.gen_bool(self.p()) + } +} + impl ::rand::distributions::Distribution for Bernoulli { fn sample(&self, rng: &mut R) -> f64 { - rng.gen_bool(self.p()) as u8 as f64 + rng.sample::(self) as u8 as f64 } } diff --git a/src/distribution/binomial.rs b/src/distribution/binomial.rs index 07060bfc..d4ce1757 100644 --- a/src/distribution/binomial.rs +++ b/src/distribution/binomial.rs @@ -92,12 +92,12 @@ impl std::fmt::Display for Binomial { } } -impl ::rand::distributions::Distribution for Binomial { - fn sample(&self, rng: &mut R) -> f64 { - (0..self.n).fold(0.0, |acc, _| { +impl ::rand::distributions::Distribution for Binomial { + fn sample(&self, rng: &mut R) -> u64 { + (0..self.n).fold(0, |acc, _| { let n: f64 = rng.gen(); if n < self.p { - acc + 1.0 + acc + 1 } else { acc } @@ -105,6 +105,12 @@ impl ::rand::distributions::Distribution for Binomial { } } +impl ::rand::distributions::Distribution for Binomial { + fn sample(&self, rng: &mut R) -> f64 { + rng.sample::(self) as f64 + } +} + impl DiscreteCDF for Binomial { /// Calculates the cumulative distribution function for the /// binomial distribution at `x` diff --git a/src/distribution/categorical.rs b/src/distribution/categorical.rs index 31bccf8b..631c2b29 100644 --- a/src/distribution/categorical.rs +++ b/src/distribution/categorical.rs @@ -83,9 +83,21 @@ impl std::fmt::Display for Categorical { } } +impl ::rand::distributions::Distribution for Categorical { + fn sample(&self, rng: &mut R) -> usize { + sample_unchecked(rng, &self.cdf) + } +} + +impl ::rand::distributions::Distribution for Categorical { + fn sample(&self, rng: &mut R) -> u64 { + sample_unchecked(rng, &self.cdf) as u64 + } +} + impl ::rand::distributions::Distribution for Categorical { fn sample(&self, rng: &mut R) -> f64 { - sample_unchecked(rng, &self.cdf) + sample_unchecked(rng, &self.cdf) as f64 } } @@ -281,13 +293,9 @@ impl Discrete for Categorical { /// Draws a sample from the categorical distribution described by `cdf` /// without doing any bounds checking -pub fn sample_unchecked(rng: &mut R, cdf: &[f64]) -> f64 { +pub fn sample_unchecked(rng: &mut R, cdf: &[f64]) -> usize { let draw = rng.gen::() * cdf.last().unwrap(); - cdf.iter() - .enumerate() - .find(|(_, val)| **val >= draw) - .map(|(i, _)| i) - .unwrap() as f64 + cdf.iter().position(|val| *val >= draw).unwrap() } /// Computes the cdf from the given probability masses. Performs diff --git a/src/distribution/discrete_uniform.rs b/src/distribution/discrete_uniform.rs index 361cadd8..7d95ac33 100644 --- a/src/distribution/discrete_uniform.rs +++ b/src/distribution/discrete_uniform.rs @@ -57,6 +57,12 @@ impl std::fmt::Display for DiscreteUniform { } } +impl ::rand::distributions::Distribution for DiscreteUniform { + fn sample(&self, rng: &mut R) -> i64 { + rng.gen_range(self.min..=self.max) + } +} + impl ::rand::distributions::Distribution for DiscreteUniform { fn sample(&self, rng: &mut R) -> f64 { rng.gen_range(self.min..=self.max) as f64 diff --git a/src/distribution/geometric.rs b/src/distribution/geometric.rs index 4df623ed..983468d7 100644 --- a/src/distribution/geometric.rs +++ b/src/distribution/geometric.rs @@ -74,17 +74,26 @@ impl std::fmt::Display for Geometric { } } -impl ::rand::distributions::Distribution for Geometric { - fn sample(&self, r: &mut R) -> f64 { +impl ::rand::distributions::Distribution for Geometric { + fn sample(&self, r: &mut R) -> u64 { if ulps_eq!(self.p, 1.0) { - 1.0 + 1 } else { let x: f64 = r.sample(OpenClosed01); - x.log(1.0 - self.p).ceil() + // This cast is safe, because the largest finite value this expression can take is when + // `x = 1.4e-45` and `1.0 - self.p = 0.9999999999999999`, in which case we get + // `930262250532780300`, which when casted to a `u64` is `930262250532780288`. + x.log(1.0 - self.p).ceil() as u64 } } } +impl ::rand::distributions::Distribution for Geometric { + fn sample(&self, r: &mut R) -> f64 { + r.sample::(self) as f64 + } +} + impl DiscreteCDF for Geometric { /// Calculates the cumulative distribution function for the geometric /// distribution at `x` diff --git a/src/distribution/hypergeometric.rs b/src/distribution/hypergeometric.rs index 8b6d8500..8960f2c3 100644 --- a/src/distribution/hypergeometric.rs +++ b/src/distribution/hypergeometric.rs @@ -120,17 +120,17 @@ impl std::fmt::Display for Hypergeometric { } } -impl ::rand::distributions::Distribution for Hypergeometric { - fn sample(&self, rng: &mut R) -> f64 { +impl ::rand::distributions::Distribution for Hypergeometric { + fn sample(&self, rng: &mut R) -> u64 { let mut population = self.population as f64; let mut successes = self.successes as f64; let mut draws = self.draws; - let mut x = 0.0; + let mut x = 0; loop { let p = successes / population; let next: f64 = rng.gen(); if next < p { - x += 1.0; + x += 1; successes -= 1.0; } population -= 1.0; @@ -143,6 +143,12 @@ impl ::rand::distributions::Distribution for Hypergeometric { } } +impl ::rand::distributions::Distribution for Hypergeometric { + fn sample(&self, rng: &mut R) -> f64 { + rng.sample::(self) as f64 + } +} + impl DiscreteCDF for Hypergeometric { /// Calculates the cumulative distribution function for the hypergeometric /// distribution at `x` diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs index dc402050..1f955031 100644 --- a/src/distribution/multinomial.rs +++ b/src/distribution/multinomial.rs @@ -119,6 +119,23 @@ where } } +impl ::rand::distributions::Distribution> for Multinomial +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ + fn sample(&self, rng: &mut R) -> OVector { + let p_cdf = super::categorical::prob_mass_to_cdf(self.p().as_slice()); + let mut res = OVector::zeros_generic(self.p.shape_generic().0, Const::<1>); + for _ in 0..self.n { + let i = super::categorical::sample_unchecked(rng, &p_cdf); + res[i] += 1; + } + res + } +} + impl ::rand::distributions::Distribution> for Multinomial where D: Dim, @@ -129,8 +146,7 @@ where let mut res = OVector::zeros_generic(self.p.shape_generic().0, Const::<1>); for _ in 0..self.n { let i = super::categorical::sample_unchecked(rng, &p_cdf); - let el = res.get_mut(i as usize).unwrap(); - *el += 1.0; + res[i] += 1.0; } res } diff --git a/src/distribution/negative_binomial.rs b/src/distribution/negative_binomial.rs index 065c2239..098d8d32 100644 --- a/src/distribution/negative_binomial.rs +++ b/src/distribution/negative_binomial.rs @@ -113,7 +113,7 @@ impl std::fmt::Display for NegativeBinomial { impl ::rand::distributions::Distribution for NegativeBinomial { fn sample(&self, r: &mut R) -> u64 { let lambda = distribution::gamma::sample_unchecked(r, self.r, (1.0 - self.p) / self.p); - poisson::sample_unchecked(r, lambda).floor() as u64 + poisson::sample_unchecked(r, lambda) as u64 } } diff --git a/src/distribution/poisson.rs b/src/distribution/poisson.rs index 7653ed20..6de2f999 100644 --- a/src/distribution/poisson.rs +++ b/src/distribution/poisson.rs @@ -72,6 +72,17 @@ impl std::fmt::Display for Poisson { } } +impl ::rand::distributions::Distribution for Poisson { + /// Generates one sample from the Poisson distribution either by + /// Knuth's method if lambda < 30.0 or Rejection method PA by + /// A. C. Atkinson from the Journal of the Royal Statistical Society + /// Series C (Applied Statistics) Vol. 28 No. 1. (1979) pp. 29 - 35 + /// otherwise + fn sample(&self, rng: &mut R) -> u64 { + sample_unchecked(rng, self.lambda) as u64 + } +} + impl ::rand::distributions::Distribution for Poisson { /// Generates one sample from the Poisson distribution either by /// Knuth's method if lambda < 30.0 or Rejection method PA by @@ -260,6 +271,7 @@ impl Discrete for Poisson { -self.lambda + x as f64 * self.lambda.ln() - factorial::ln_factorial(x) } } + /// Generates one sample from the Poisson distribution either by /// Knuth's method if lambda < 30.0 or Rejection method PA by /// A. C. Atkinson from the Journal of the Royal Statistical Society