Skip to content

Commit

Permalink
feat: support sampling integers from discrete distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
SabrinaJewson authored and YeungOnion committed Sep 22, 2024
1 parent a514992 commit 9fa3643
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 34 deletions.
10 changes: 9 additions & 1 deletion src/distribution/bernoulli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,19 @@ impl std::fmt::Display for Bernoulli {
}
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<bool> for Bernoulli {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> bool {
rng.gen_bool(self.p())
}
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Bernoulli {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
rng.gen_bool(self.p()) as u8 as f64
f64::from(rng.sample::<bool, _>(self))
}
}

Expand Down
16 changes: 12 additions & 4 deletions src/distribution/binomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,19 +111,27 @@ impl std::fmt::Display for Binomial {

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Binomial {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
(0..self.n).fold(0.0, |acc, _| {
impl ::rand::distributions::Distribution<u64> for Binomial {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> u64 {
(0..self.n).fold(0, |acc, _| {
let n: f64 = rng.gen();
if n < self.p {
acc + 1.0
acc + 1
} else {
acc
}
})
}
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Binomial {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
rng.sample::<u64, _>(self) as f64
}
}

impl DiscreteCDF<u64, f64> for Binomial {
/// Calculates the cumulative distribution function for the
/// binomial distribution at `x`
Expand Down
26 changes: 19 additions & 7 deletions src/distribution/categorical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,27 @@ impl std::fmt::Display for Categorical {
}
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<usize> for Categorical {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> usize {
sample_unchecked(rng, &self.cdf)
}
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<u64> for Categorical {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> u64 {
sample_unchecked(rng, &self.cdf) as u64
}
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Categorical {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
sample_unchecked(rng, &self.cdf)
sample_unchecked(rng, &self.cdf) as f64
}
}

Expand Down Expand Up @@ -325,13 +341,9 @@ impl Discrete<u64, f64> for Categorical {
/// without doing any bounds checking
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
pub fn sample_unchecked<R: ::rand::Rng + ?Sized>(rng: &mut R, cdf: &[f64]) -> f64 {
pub fn sample_unchecked<R: ::rand::Rng + ?Sized>(rng: &mut R, cdf: &[f64]) -> usize {
let draw = rng.gen::<f64>() * cdf.last().unwrap();
cdf.iter()
.enumerate()
.find(|(_, val)| **val >= draw)
.map(|(i, _)| i)
.unwrap() as f64
cdf.iter().position(|val| *val >= draw).unwrap()
}

/// Computes the cdf from the given probability masses. Performs
Expand Down
10 changes: 9 additions & 1 deletion src/distribution/discrete_uniform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,19 @@ impl std::fmt::Display for DiscreteUniform {
}
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<i64> for DiscreteUniform {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> i64 {
rng.gen_range(self.min..=self.max)
}
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for DiscreteUniform {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
rng.gen_range(self.min..=self.max) as f64
rng.sample::<i64, _>(self) as f64
}
}

Expand Down
23 changes: 16 additions & 7 deletions src/distribution/geometric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,19 +92,28 @@ impl std::fmt::Display for Geometric {

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Geometric {
fn sample<R: ::rand::Rng + ?Sized>(&self, r: &mut R) -> f64 {
use ::rand::distributions::OpenClosed01;

impl ::rand::distributions::Distribution<u64> for Geometric {
fn sample<R: ::rand::Rng + ?Sized>(&self, r: &mut R) -> u64 {
if ulps_eq!(self.p, 1.0) {
1.0
1
} else {
let x: f64 = r.sample(OpenClosed01);
x.log(1.0 - self.p).ceil()
let x: f64 = r.sample(::rand::distributions::OpenClosed01);
// This cast is safe, because the largest finite value this expression can take is when
// `x = 1.4e-45` and `1.0 - self.p = 0.9999999999999999`, in which case we get
// `930262250532780300`, which when casted to a `u64` is `930262250532780288`.
x.log(1.0 - self.p).ceil() as u64
}
}
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Geometric {
fn sample<R: ::rand::Rng + ?Sized>(&self, r: &mut R) -> f64 {
r.sample::<u64, _>(self) as f64
}
}

impl DiscreteCDF<u64, f64> for Geometric {
/// Calculates the cumulative distribution function for the geometric
/// distribution at `x`
Expand Down
16 changes: 12 additions & 4 deletions src/distribution/hypergeometric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,17 +147,17 @@ impl std::fmt::Display for Hypergeometric {

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Hypergeometric {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
impl ::rand::distributions::Distribution<u64> for Hypergeometric {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> u64 {
let mut population = self.population as f64;
let mut successes = self.successes as f64;
let mut draws = self.draws;
let mut x = 0.0;
let mut x = 0;
loop {
let p = successes / population;
let next: f64 = rng.gen();
if next < p {
x += 1.0;
x += 1;
successes -= 1.0;
}
population -= 1.0;
Expand All @@ -170,6 +170,14 @@ impl ::rand::distributions::Distribution<f64> for Hypergeometric {
}
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Hypergeometric {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
rng.sample::<u64, _>(self) as f64
}
}

impl DiscreteCDF<u64, f64> for Hypergeometric {
/// Calculates the cumulative distribution function for the hypergeometric
/// distribution at `x`
Expand Down
44 changes: 34 additions & 10 deletions src/distribution/multinomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,19 @@ where
}
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl<D> ::rand::distributions::Distribution<OVector<u64, D>> for Multinomial<D>
where
D: Dim,
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<f64, D>,
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<u64, D>,
{
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> OVector<u64, D> {
sample_generic(self, rng)
}
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl<D> ::rand::distributions::Distribution<OVector<f64, D>> for Multinomial<D>
Expand All @@ -167,17 +180,28 @@ where
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<f64, D>,
{
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> OVector<f64, D> {
use nalgebra::Const;

let p_cdf = super::categorical::prob_mass_to_cdf(self.p().as_slice());
let mut res = OVector::zeros_generic(self.p.shape_generic().0, Const::<1>);
for _ in 0..self.n {
let i = super::categorical::sample_unchecked(rng, &p_cdf);
let el = res.get_mut(i as usize).unwrap();
*el += 1.0;
}
res
sample_generic(self, rng)
}
}

#[cfg(feature = "rand")]
fn sample_generic<D, R, T>(dist: &Multinomial<D>, rng: &mut R) -> OVector<T, D>
where
D: Dim,
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<f64, D>,
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<T, D>,
R: ::rand::Rng + ?Sized,
T: ::num_traits::Num + ::nalgebra::Scalar + ::std::ops::AddAssign<T>,
{
use nalgebra::Const;

let p_cdf = super::categorical::prob_mass_to_cdf(dist.p().as_slice());
let mut res = OVector::zeros_generic(dist.p.shape_generic().0, Const::<1>);
for _ in 0..dist.n {
let i = super::categorical::sample_unchecked(rng, &p_cdf);
res[i] += T::one();
}
res
}

impl<D> MeanN<DVector<f64>> for Multinomial<D>
Expand Down
14 changes: 14 additions & 0 deletions src/distribution/poisson.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,19 @@ impl std::fmt::Display for Poisson {
}
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<u64> for Poisson {
/// Generates one sample from the Poisson distribution either by
/// Knuth's method if lambda < 30.0 or Rejection method PA by
/// A. C. Atkinson from the Journal of the Royal Statistical Society
/// Series C (Applied Statistics) Vol. 28 No. 1. (1979) pp. 29 - 35
/// otherwise
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> u64 {
sample_unchecked(rng, self.lambda) as u64
}
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Poisson {
Expand Down Expand Up @@ -279,6 +292,7 @@ impl Discrete<u64, f64> for Poisson {
-self.lambda + x as f64 * self.lambda.ln() - factorial::ln_factorial(x)
}
}

/// Generates one sample from the Poisson distribution either by
/// Knuth's method if lambda < 30.0 or Rejection method PA by
/// A. C. Atkinson from the Journal of the Royal Statistical Society
Expand Down

0 comments on commit 9fa3643

Please sign in to comment.