Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support sampling integers from discrete distributions #155

Merged
merged 2 commits into from
Sep 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/distribution/bernoulli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,19 @@
}
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<bool> for Bernoulli {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> bool {
rng.gen_bool(self.p())
}

Check warning on line 92 in src/distribution/bernoulli.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/bernoulli.rs#L90-L92

Added lines #L90 - L92 were not covered by tests
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Bernoulli {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
rng.gen_bool(self.p()) as u8 as f64
rng.sample::<bool, _>(self) as u8 as f64

Check warning on line 99 in src/distribution/bernoulli.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/bernoulli.rs#L99

Added line #L99 was not covered by tests
}
}

Expand Down
16 changes: 12 additions & 4 deletions src/distribution/binomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,19 +111,27 @@

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Binomial {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
(0..self.n).fold(0.0, |acc, _| {
impl ::rand::distributions::Distribution<u64> for Binomial {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> u64 {
(0..self.n).fold(0, |acc, _| {

Check warning on line 116 in src/distribution/binomial.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/binomial.rs#L115-L116

Added lines #L115 - L116 were not covered by tests
let n: f64 = rng.gen();
if n < self.p {
acc + 1.0
acc + 1

Check warning on line 119 in src/distribution/binomial.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/binomial.rs#L119

Added line #L119 was not covered by tests
} else {
acc
}
})
}
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Binomial {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
rng.sample::<u64, _>(self) as f64
}

Check warning on line 132 in src/distribution/binomial.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/binomial.rs#L130-L132

Added lines #L130 - L132 were not covered by tests
}

impl DiscreteCDF<u64, f64> for Binomial {
/// Calculates the cumulative distribution function for the
/// binomial distribution at `x`
Expand Down
26 changes: 19 additions & 7 deletions src/distribution/categorical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,27 @@
}
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<usize> for Categorical {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> usize {
sample_unchecked(rng, &self.cdf)
}

Check warning on line 131 in src/distribution/categorical.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/categorical.rs#L129-L131

Added lines #L129 - L131 were not covered by tests
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<u64> for Categorical {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> u64 {
sample_unchecked(rng, &self.cdf) as u64
}

Check warning on line 139 in src/distribution/categorical.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/categorical.rs#L137-L139

Added lines #L137 - L139 were not covered by tests
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Categorical {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
sample_unchecked(rng, &self.cdf)
sample_unchecked(rng, &self.cdf) as f64

Check warning on line 146 in src/distribution/categorical.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/categorical.rs#L146

Added line #L146 was not covered by tests
}
}

Expand Down Expand Up @@ -325,13 +341,9 @@
/// without doing any bounds checking
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
pub fn sample_unchecked<R: ::rand::Rng + ?Sized>(rng: &mut R, cdf: &[f64]) -> f64 {
pub fn sample_unchecked<R: ::rand::Rng + ?Sized>(rng: &mut R, cdf: &[f64]) -> usize {

Check warning on line 344 in src/distribution/categorical.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/categorical.rs#L344

Added line #L344 was not covered by tests
let draw = rng.gen::<f64>() * cdf.last().unwrap();
cdf.iter()
.enumerate()
.find(|(_, val)| **val >= draw)
.map(|(i, _)| i)
.unwrap() as f64
cdf.iter().position(|val| *val >= draw).unwrap()

Check warning on line 346 in src/distribution/categorical.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/categorical.rs#L346

Added line #L346 was not covered by tests
}

/// Computes the cdf from the given probability masses. Performs
Expand Down
10 changes: 9 additions & 1 deletion src/distribution/discrete_uniform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,19 @@
}
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<i64> for DiscreteUniform {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> i64 {
rng.gen_range(self.min..=self.max)
}

Check warning on line 82 in src/distribution/discrete_uniform.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/discrete_uniform.rs#L80-L82

Added lines #L80 - L82 were not covered by tests
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for DiscreteUniform {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
rng.gen_range(self.min..=self.max) as f64
rng.sample::<i64, _>(self) as f64

Check warning on line 89 in src/distribution/discrete_uniform.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/discrete_uniform.rs#L89

Added line #L89 was not covered by tests
}
}

Expand Down
23 changes: 16 additions & 7 deletions src/distribution/geometric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,19 +92,28 @@

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Geometric {
fn sample<R: ::rand::Rng + ?Sized>(&self, r: &mut R) -> f64 {
use ::rand::distributions::OpenClosed01;

impl ::rand::distributions::Distribution<u64> for Geometric {
fn sample<R: ::rand::Rng + ?Sized>(&self, r: &mut R) -> u64 {

Check warning on line 96 in src/distribution/geometric.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/geometric.rs#L96

Added line #L96 was not covered by tests
if ulps_eq!(self.p, 1.0) {
1.0
1

Check warning on line 98 in src/distribution/geometric.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/geometric.rs#L98

Added line #L98 was not covered by tests
} else {
let x: f64 = r.sample(OpenClosed01);
x.log(1.0 - self.p).ceil()
let x: f64 = r.sample(::rand::distributions::OpenClosed01);
// This cast is safe, because the largest finite value this expression can take is when
// `x = 1.4e-45` and `1.0 - self.p = 0.9999999999999999`, in which case we get
// `930262250532780300`, which when casted to a `u64` is `930262250532780288`.
x.log(1.0 - self.p).ceil() as u64

Check warning on line 104 in src/distribution/geometric.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/geometric.rs#L100-L104

Added lines #L100 - L104 were not covered by tests
}
}
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Geometric {
fn sample<R: ::rand::Rng + ?Sized>(&self, r: &mut R) -> f64 {
r.sample::<u64, _>(self) as f64
}

Check warning on line 114 in src/distribution/geometric.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/geometric.rs#L112-L114

Added lines #L112 - L114 were not covered by tests
}

impl DiscreteCDF<u64, f64> for Geometric {
/// Calculates the cumulative distribution function for the geometric
/// distribution at `x`
Expand Down
16 changes: 12 additions & 4 deletions src/distribution/hypergeometric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,17 +147,17 @@

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Hypergeometric {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
impl ::rand::distributions::Distribution<u64> for Hypergeometric {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> u64 {

Check warning on line 151 in src/distribution/hypergeometric.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/hypergeometric.rs#L151

Added line #L151 was not covered by tests
let mut population = self.population as f64;
let mut successes = self.successes as f64;
let mut draws = self.draws;
let mut x = 0.0;
let mut x = 0;

Check warning on line 155 in src/distribution/hypergeometric.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/hypergeometric.rs#L155

Added line #L155 was not covered by tests
loop {
let p = successes / population;
let next: f64 = rng.gen();
if next < p {
x += 1.0;
x += 1;

Check warning on line 160 in src/distribution/hypergeometric.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/hypergeometric.rs#L160

Added line #L160 was not covered by tests
successes -= 1.0;
}
population -= 1.0;
Expand All @@ -170,6 +170,14 @@
}
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Hypergeometric {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
rng.sample::<u64, _>(self) as f64
}

Check warning on line 178 in src/distribution/hypergeometric.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/hypergeometric.rs#L176-L178

Added lines #L176 - L178 were not covered by tests
}

impl DiscreteCDF<u64, f64> for Hypergeometric {
/// Calculates the cumulative distribution function for the hypergeometric
/// distribution at `x`
Expand Down
44 changes: 34 additions & 10 deletions src/distribution/multinomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,19 @@
}
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl<D> ::rand::distributions::Distribution<OVector<u64, D>> for Multinomial<D>
where
D: Dim,
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<f64, D>,
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<u64, D>,
{
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> OVector<u64, D> {
sample_generic(self, rng)
}

Check warning on line 172 in src/distribution/multinomial.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multinomial.rs#L170-L172

Added lines #L170 - L172 were not covered by tests
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl<D> ::rand::distributions::Distribution<OVector<f64, D>> for Multinomial<D>
Expand All @@ -167,17 +180,28 @@
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<f64, D>,
{
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> OVector<f64, D> {
use nalgebra::Const;

let p_cdf = super::categorical::prob_mass_to_cdf(self.p().as_slice());
let mut res = OVector::zeros_generic(self.p.shape_generic().0, Const::<1>);
for _ in 0..self.n {
let i = super::categorical::sample_unchecked(rng, &p_cdf);
let el = res.get_mut(i as usize).unwrap();
*el += 1.0;
}
res
sample_generic(self, rng)
}

Check warning on line 184 in src/distribution/multinomial.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multinomial.rs#L183-L184

Added lines #L183 - L184 were not covered by tests
}

#[cfg(feature = "rand")]
fn sample_generic<D, R, T>(dist: &Multinomial<D>, rng: &mut R) -> OVector<T, D>
where
D: Dim,
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<f64, D>,
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<T, D>,
R: ::rand::Rng + ?Sized,
T: ::num_traits::Num + ::nalgebra::Scalar + ::std::ops::AddAssign<T>,
{

Check warning on line 195 in src/distribution/multinomial.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multinomial.rs#L188-L195

Added lines #L188 - L195 were not covered by tests
use nalgebra::Const;

let p_cdf = super::categorical::prob_mass_to_cdf(dist.p().as_slice());
let mut res = OVector::zeros_generic(dist.p.shape_generic().0, Const::<1>);
for _ in 0..dist.n {
let i = super::categorical::sample_unchecked(rng, &p_cdf);
res[i] += T::one();

Check warning on line 202 in src/distribution/multinomial.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multinomial.rs#L198-L202

Added lines #L198 - L202 were not covered by tests
}
res

Check warning on line 204 in src/distribution/multinomial.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multinomial.rs#L204

Added line #L204 was not covered by tests
}

impl<D> MeanN<DVector<f64>> for Multinomial<D>
Expand Down
14 changes: 14 additions & 0 deletions src/distribution/poisson.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,19 @@
}
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<u64> for Poisson {
/// Generates one sample from the Poisson distribution either by
/// Knuth's method if lambda < 30.0 or Rejection method PA by
/// A. C. Atkinson from the Journal of the Royal Statistical Society
/// Series C (Applied Statistics) Vol. 28 No. 1. (1979) pp. 29 - 35
/// otherwise
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> u64 {
sample_unchecked(rng, self.lambda) as u64
}

Check warning on line 102 in src/distribution/poisson.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/poisson.rs#L100-L102

Added lines #L100 - L102 were not covered by tests
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Poisson {
Expand Down Expand Up @@ -279,6 +292,7 @@
-self.lambda + x as f64 * self.lambda.ln() - factorial::ln_factorial(x)
}
}

/// Generates one sample from the Poisson distribution either by
/// Knuth's method if lambda < 30.0 or Rejection method PA by
/// A. C. Atkinson from the Journal of the Royal Statistical Society
Expand Down