Skip to content

Commit

Permalink
refactor(perf): Multinomial samples from Binomial
Browse files Browse the repository at this point in the history
lint: fix lints
  • Loading branch information
YeungOnion committed Sep 24, 2024
1 parent 993a4b5 commit d32bd1a
Showing 1 changed file with 76 additions and 11 deletions.
87 changes: 76 additions & 11 deletions src/distribution/multinomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::distribution::Discrete;
use crate::function::factorial;
use crate::statistics::*;
use nalgebra::{DVector, Dim, Dyn, OMatrix, OVector};
use std::cmp::Ordering;

/// Implements the
/// [Multinomial](https://en.wikipedia.org/wiki/Multinomial_distribution)
Expand Down Expand Up @@ -99,6 +100,16 @@ where
if p.len() < 2 {
return Err(MultinomialError::NotEnoughProbabilities);
}
// sort decreasing, place NAN at front
p.as_mut_slice().sort_unstable_by(|a, b| {
if a.is_nan() {
Ordering::Less
} else if b.is_nan() {
Ordering::Greater
} else {
b.partial_cmp(a).unwrap()
}
});

let mut sum = 0.0;
for &val in &p {
Expand Down Expand Up @@ -168,7 +179,7 @@ where
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<u64, D>,
{
/// Draws one multinomial sample as a vector of `u64` counts by delegating to
/// `sample_generic`.
// NOTE(review): this span comes from a rendered diff with the +/- markers stripped, so
// both the pre-change call (`sample_generic(self, rng)`) and the post-change call
// (slice, trial count and dimension passed separately) appear below. Only one of them
// belongs in the real source file.
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> OVector<u64, D> {
sample_generic(self, rng)
sample_generic(self.p.as_slice(), self.n, self.p.shape_generic().0, rng)
}
}

Expand All @@ -180,26 +191,40 @@ where
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<f64, D>,
{
/// Draws one multinomial sample as a vector of `f64` counts by delegating to
/// `sample_generic`.
// NOTE(review): this span comes from a rendered diff with the +/- markers stripped, so
// both the pre-change call (`sample_generic(self, rng)`) and the post-change call
// (slice, trial count and dimension passed separately) appear below. Only one of them
// belongs in the real source file.
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> OVector<f64, D> {
sample_generic(self, rng)
sample_generic(self.p.as_slice(), self.n, self.p.shape_generic().0, rng)
}
}

// NOTE(review): this span is a rendered diff hunk with the +/- markers stripped. The
// pre-change and post-change versions of the function are interleaved (two signatures,
// two bounds on `T`, two bodies), so it is not expected to compile as written.
/// Shared sampling routine behind both `sample` implementations; returns one vector
/// of per-category counts.
///
/// Pre-change version: builds a CDF from the probabilities and draws `dist.n`
/// categorical indices, incrementing the matching count for each draw.
///
/// Post-change version: walks every probability except the last, drawing each count
/// from a `Binomial` over the trials still unassigned with the probability rescaled
/// by the mass not yet consumed; the last category receives whatever trials remain.
#[cfg(feature = "rand")]
// Pre-change signature (takes the whole distribution):
fn sample_generic<D, R, T>(dist: &Multinomial<D>, rng: &mut R) -> OVector<T, D>
// Post-change signature (takes the probability slice, trial count and dimension):
fn sample_generic<D, R, T>(p: &[f64], n: u64, dim: D, rng: &mut R) -> OVector<T, D>
where
D: Dim,
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<f64, D>,
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<T, D>,
R: ::rand::Rng + ?Sized,
// Pre-change bound on `T` (needs `+=` for the counting loop):
T: ::num_traits::Num + ::nalgebra::Scalar + ::std::ops::AddAssign<T>,
// Post-change bound on `T` (needs conversion to and from `u64`, and must be
// producible by sampling a `Binomial`):
T: ::num_traits::Num
+ ::nalgebra::Scalar
+ num_traits::AsPrimitive<u64>
+ num_traits::FromPrimitive,
super::Binomial: rand::distributions::Distribution<T>,
{
// ---- pre-change body: one categorical draw per trial ----
use nalgebra::Const;

let p_cdf = super::categorical::prob_mass_to_cdf(dist.p().as_slice());
let mut res = OVector::zeros_generic(dist.p.shape_generic().0, Const::<1>);
for _ in 0..dist.n {
// presumably `sample_unchecked` returns the index of the drawn category — confirm
let i = super::categorical::sample_unchecked(rng, &p_cdf);
res[i] += T::one();
// ---- post-change body: one binomial draw per category ----
use rand::distributions::Distribution;
let mut res = OVector::zeros_generic(dim, nalgebra::U1);
// Probability mass not yet assigned to an earlier category.
let mut probs_not_taken = 1.0;
// Trials not yet assigned to an earlier category.
let mut samples_left = n;
// The last probability is skipped here; its count is filled in after the loop.
// NOTE(review): `p.len() - 1` underflows for an empty slice; this relies on the
// caller passing at least one probability — confirm against the constructor.
for (p, s) in p[..p.len() - 1].iter().zip(res.iter_mut()) {
// Stop once the remaining mass has left [0, 1] (NaN also fails this test);
// entries not reached keep their zero initialisation.
if !(0.0..=1.0).contains(&probs_not_taken) {
break;
}
// NOTE(review): the guard above still passes when `probs_not_taken` is exactly
// 0.0, in which case this division yields NaN or infinity. Whether
// `Binomial::new` rejects that (making the `expect` panic) should be confirmed.
let p_binom = p / probs_not_taken;
*s = super::Binomial::new(p_binom, samples_left)
.expect("probability already on [0,1]")
.sample(rng);
samples_left -= s.as_();
probs_not_taken -= p;
}
// Whatever trials were not assigned above go to the last category.
if let Some(x) = res.as_mut_slice().last_mut() {
*x = T::from_u64(samples_left).unwrap();
}
res
}
Expand Down Expand Up @@ -576,4 +601,44 @@ mod tests {
// let n = Multinomial::new(&[0.3, 0.7], 10).unwrap();
// n.ln_pmf(&[1, 3]);
// }

/// With every weight zero except a tiny final one, all trials are expected to be
/// reported at index 3 of the sample.
#[cfg(feature = "rand")]
#[test]
fn test_almost_zero_sample() {
    use ::rand::{distributions::Distribution, prelude::thread_rng};

    let trials = 10;
    let dist = Multinomial::new(vec![0.0, 0.0, 0.0, 0.000000001], trials).unwrap();

    let mut rng = thread_rng();
    let drawn: OVector<f64, Dyn> = dist.sample(&mut rng);

    assert_relative_eq!(drawn[3], trials as f64);
}

/// Samples a multinomial with 20 equally weighted categories and 1000 trials, then
/// checks that
///
/// * every count lies within three standard deviations of the expected `n / k`,
/// * the spread of the counts is close to the `(0, 0)` entry of the variance matrix,
/// * the counts add up to exactly `n`.
///
/// The draw comes from `thread_rng`, so the first two checks are statistical
/// tolerances rather than exact equalities.
#[test]
#[cfg(feature = "rand")]
fn test_uniform_samples() {
    use crate::statistics::Statistics;
    use ::rand::{distributions::Distribution, prelude::thread_rng};

    let n: f64 = 1000.0;
    let k: usize = 20;
    let weights = vec![1.0; k];
    let multinomial = Multinomial::new(weights, n as u64).unwrap();
    let samples: OVector<f64, Dyn> = multinomial.sample(&mut thread_rng());

    // Every comparison uses the same (0, 0) variance entry, so build the variance
    // matrix once instead of once per sampled count.
    let category_variance = multinomial.variance().unwrap()[(0, 0)];
    let expected_count = n / k as f64;

    for &count in samples.iter() {
        assert_abs_diff_eq!(
            count,
            expected_count,
            epsilon = 3.0 * category_variance.sqrt(),
        );
    }

    assert_abs_diff_eq!(
        samples.iter().population_variance(),
        category_variance,
        epsilon = n.sqrt()
    );

    // No trial may be lost or duplicated by the sampler.
    assert_eq!(samples.iter().map(|&x| x as u64).sum::<u64>(), n as u64);
}
}

0 comments on commit d32bd1a

Please sign in to comment.