Skip to content

Commit

Permalink
Support sampling integers from discrete distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
SabrinaJewson committed Aug 30, 2024
1 parent aa276c8 commit d09b3f4
Show file tree
Hide file tree
Showing 9 changed files with 92 additions and 23 deletions.
8 changes: 7 additions & 1 deletion src/distribution/bernoulli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,15 @@ impl std::fmt::Display for Bernoulli {
}
}

impl ::rand::distributions::Distribution<bool> for Bernoulli {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> bool {
rng.gen_bool(self.p())
}

Check warning on line 92 in src/distribution/bernoulli.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/bernoulli.rs#L90-L92

Added lines #L90 - L92 were not covered by tests
}

impl ::rand::distributions::Distribution<f64> for Bernoulli {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
rng.gen_bool(self.p()) as u8 as f64
rng.sample::<bool, _>(self) as u8 as f64

Check warning on line 97 in src/distribution/bernoulli.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/bernoulli.rs#L97

Added line #L97 was not covered by tests
}
}

Expand Down
14 changes: 10 additions & 4 deletions src/distribution/binomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,19 +92,25 @@ impl std::fmt::Display for Binomial {
}
}

impl ::rand::distributions::Distribution<f64> for Binomial {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
(0..self.n).fold(0.0, |acc, _| {
impl ::rand::distributions::Distribution<u64> for Binomial {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
(0..self.n).fold(0, |acc, _| {

Check warning on line 97 in src/distribution/binomial.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/binomial.rs#L96-L97

Added lines #L96 - L97 were not covered by tests
let n: f64 = rng.gen();
if n < self.p {
acc + 1.0
acc + 1

Check warning on line 100 in src/distribution/binomial.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/binomial.rs#L100

Added line #L100 was not covered by tests
} else {
acc
}
})
}
}

impl ::rand::distributions::Distribution<f64> for Binomial {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
rng.sample::<u64, _>(self) as f64
}

Check warning on line 111 in src/distribution/binomial.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/binomial.rs#L109-L111

Added lines #L109 - L111 were not covered by tests
}

impl DiscreteCDF<u64, f64> for Binomial {
/// Calculates the cumulative distribution function for the
/// binomial distribution at `x`
Expand Down
22 changes: 15 additions & 7 deletions src/distribution/categorical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,21 @@ impl std::fmt::Display for Categorical {
}
}

impl ::rand::distributions::Distribution<usize> for Categorical {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
sample_unchecked(rng, &self.cdf)
}

Check warning on line 89 in src/distribution/categorical.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/categorical.rs#L87-L89

Added lines #L87 - L89 were not covered by tests
}

impl ::rand::distributions::Distribution<u64> for Categorical {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
sample_unchecked(rng, &self.cdf) as u64
}

Check warning on line 95 in src/distribution/categorical.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/categorical.rs#L93-L95

Added lines #L93 - L95 were not covered by tests
}

impl ::rand::distributions::Distribution<f64> for Categorical {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
sample_unchecked(rng, &self.cdf)
sample_unchecked(rng, &self.cdf) as f64

Check warning on line 100 in src/distribution/categorical.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/categorical.rs#L100

Added line #L100 was not covered by tests
}
}

Expand Down Expand Up @@ -281,13 +293,9 @@ impl Discrete<u64, f64> for Categorical {

/// Draws a sample from the categorical distribution described by `cdf`
/// without doing any bounds checking
pub fn sample_unchecked<R: Rng + ?Sized>(rng: &mut R, cdf: &[f64]) -> f64 {
pub fn sample_unchecked<R: Rng + ?Sized>(rng: &mut R, cdf: &[f64]) -> usize {

Check warning on line 296 in src/distribution/categorical.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/categorical.rs#L296

Added line #L296 was not covered by tests
let draw = rng.gen::<f64>() * cdf.last().unwrap();
cdf.iter()
.enumerate()
.find(|(_, val)| **val >= draw)
.map(|(i, _)| i)
.unwrap() as f64
cdf.iter().position(|val| *val >= draw).unwrap()

Check warning on line 298 in src/distribution/categorical.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/categorical.rs#L298

Added line #L298 was not covered by tests
}

/// Computes the cdf from the given probability masses. Performs
Expand Down
6 changes: 6 additions & 0 deletions src/distribution/discrete_uniform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ impl std::fmt::Display for DiscreteUniform {
}
}

impl ::rand::distributions::Distribution<i64> for DiscreteUniform {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> i64 {
rng.gen_range(self.min..=self.max)
}

Check warning on line 63 in src/distribution/discrete_uniform.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/discrete_uniform.rs#L61-L63

Added lines #L61 - L63 were not covered by tests
}

impl ::rand::distributions::Distribution<f64> for DiscreteUniform {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
rng.gen_range(self.min..=self.max) as f64
Expand Down
17 changes: 13 additions & 4 deletions src/distribution/geometric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,26 @@ impl std::fmt::Display for Geometric {
}
}

impl ::rand::distributions::Distribution<f64> for Geometric {
fn sample<R: Rng + ?Sized>(&self, r: &mut R) -> f64 {
impl ::rand::distributions::Distribution<u64> for Geometric {
fn sample<R: Rng + ?Sized>(&self, r: &mut R) -> u64 {

Check warning on line 78 in src/distribution/geometric.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/geometric.rs#L78

Added line #L78 was not covered by tests
if ulps_eq!(self.p, 1.0) {
1.0
1

Check warning on line 80 in src/distribution/geometric.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/geometric.rs#L80

Added line #L80 was not covered by tests
} else {
let x: f64 = r.sample(OpenClosed01);
x.log(1.0 - self.p).ceil()
// This cast is safe, because the largest finite value this expression can take is when
// `x = 1.4e-45` and `1.0 - self.p = 0.9999999999999999`, in which case we get
// `930262250532780300`, which when casted to a `u64` is `930262250532780288`.
x.log(1.0 - self.p).ceil() as u64

Check warning on line 86 in src/distribution/geometric.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/geometric.rs#L83-L86

Added lines #L83 - L86 were not covered by tests
}
}
}

impl ::rand::distributions::Distribution<f64> for Geometric {
fn sample<R: Rng + ?Sized>(&self, r: &mut R) -> f64 {
r.sample::<u64, _>(self) as f64
}

Check warning on line 94 in src/distribution/geometric.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/geometric.rs#L92-L94

Added lines #L92 - L94 were not covered by tests
}

impl DiscreteCDF<u64, f64> for Geometric {
/// Calculates the cumulative distribution function for the geometric
/// distribution at `x`
Expand Down
14 changes: 10 additions & 4 deletions src/distribution/hypergeometric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,17 @@ impl std::fmt::Display for Hypergeometric {
}
}

impl ::rand::distributions::Distribution<f64> for Hypergeometric {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
impl ::rand::distributions::Distribution<u64> for Hypergeometric {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {

Check warning on line 124 in src/distribution/hypergeometric.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/hypergeometric.rs#L124

Added line #L124 was not covered by tests
let mut population = self.population as f64;
let mut successes = self.successes as f64;
let mut draws = self.draws;
let mut x = 0.0;
let mut x = 0;

Check warning on line 128 in src/distribution/hypergeometric.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/hypergeometric.rs#L128

Added line #L128 was not covered by tests
loop {
let p = successes / population;
let next: f64 = rng.gen();
if next < p {
x += 1.0;
x += 1;

Check warning on line 133 in src/distribution/hypergeometric.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/hypergeometric.rs#L133

Added line #L133 was not covered by tests
successes -= 1.0;
}
population -= 1.0;
Expand All @@ -143,6 +143,12 @@ impl ::rand::distributions::Distribution<f64> for Hypergeometric {
}
}

impl ::rand::distributions::Distribution<f64> for Hypergeometric {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
rng.sample::<u64, _>(self) as f64
}

Check warning on line 149 in src/distribution/hypergeometric.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/hypergeometric.rs#L147-L149

Added lines #L147 - L149 were not covered by tests
}

impl DiscreteCDF<u64, f64> for Hypergeometric {
/// Calculates the cumulative distribution function for the hypergeometric
/// distribution at `x`
Expand Down
20 changes: 18 additions & 2 deletions src/distribution/multinomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,23 @@ where
}
}

impl<D> ::rand::distributions::Distribution<OVector<u64, D>> for Multinomial<D>
where
D: Dim,
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<f64, D>,
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<u64, D>,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> OVector<u64, D> {
let p_cdf = super::categorical::prob_mass_to_cdf(self.p().as_slice());
let mut res = OVector::zeros_generic(self.p.shape_generic().0, Const::<1>);
for _ in 0..self.n {
let i = super::categorical::sample_unchecked(rng, &p_cdf);
res[i] += 1;
}
res
}

Check warning on line 136 in src/distribution/multinomial.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multinomial.rs#L128-L136

Added lines #L128 - L136 were not covered by tests
}

impl<D> ::rand::distributions::Distribution<OVector<f64, D>> for Multinomial<D>
where
D: Dim,
Expand All @@ -129,8 +146,7 @@ where
let mut res = OVector::zeros_generic(self.p.shape_generic().0, Const::<1>);
for _ in 0..self.n {
let i = super::categorical::sample_unchecked(rng, &p_cdf);
let el = res.get_mut(i as usize).unwrap();
*el += 1.0;
res[i] += 1.0;

Check warning on line 149 in src/distribution/multinomial.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multinomial.rs#L149

Added line #L149 was not covered by tests
}
res
}
Expand Down
2 changes: 1 addition & 1 deletion src/distribution/negative_binomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ impl std::fmt::Display for NegativeBinomial {
impl ::rand::distributions::Distribution<u64> for NegativeBinomial {
fn sample<R: Rng + ?Sized>(&self, r: &mut R) -> u64 {
let lambda = distribution::gamma::sample_unchecked(r, self.r, (1.0 - self.p) / self.p);
poisson::sample_unchecked(r, lambda).floor() as u64
poisson::sample_unchecked(r, lambda) as u64

Check warning on line 116 in src/distribution/negative_binomial.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/negative_binomial.rs#L116

Added line #L116 was not covered by tests
}
}

Expand Down
12 changes: 12 additions & 0 deletions src/distribution/poisson.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,17 @@ impl std::fmt::Display for Poisson {
}
}

impl ::rand::distributions::Distribution<u64> for Poisson {
/// Generates one sample from the Poisson distribution either by
/// Knuth's method if lambda < 30.0 or Rejection method PA by
/// A. C. Atkinson from the Journal of the Royal Statistical Society
/// Series C (Applied Statistics) Vol. 28 No. 1. (1979) pp. 29 - 35
/// otherwise
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
sample_unchecked(rng, self.lambda) as u64
}

Check warning on line 83 in src/distribution/poisson.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/poisson.rs#L81-L83

Added lines #L81 - L83 were not covered by tests
}

impl ::rand::distributions::Distribution<f64> for Poisson {
/// Generates one sample from the Poisson distribution either by
/// Knuth's method if lambda < 30.0 or Rejection method PA by
Expand Down Expand Up @@ -260,6 +271,7 @@ impl Discrete<u64, f64> for Poisson {
-self.lambda + x as f64 * self.lambda.ln() - factorial::ln_factorial(x)
}
}

/// Generates one sample from the Poisson distribution either by
/// Knuth's method if lambda < 30.0 or Rejection method PA by
/// A. C. Atkinson from the Journal of the Royal Statistical Society
Expand Down

0 comments on commit d09b3f4

Please sign in to comment.