Skip to content

Commit

Permalink
fix: correct sampling for multinomial, rely on Binomial
Browse files Browse the repository at this point in the history
  • Loading branch information
YeungOnion committed Sep 24, 2024
1 parent 7b8e8cb commit b653c8b
Showing 1 changed file with 25 additions and 26 deletions.
51 changes: 25 additions & 26 deletions src/distribution/multinomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use crate::distribution::Discrete;
use crate::function::factorial;
use crate::statistics::*;
use nalgebra::{DVector, Dim, Dyn, OMatrix, OVector};
use std::cmp::Ordering;

/// Implements the
/// [Multinomial](https://en.wikipedia.org/wiki/Multinomial_distribution)
Expand Down Expand Up @@ -100,16 +99,6 @@ where
if p.len() < 2 {
return Err(MultinomialError::NotEnoughProbabilities);
}
// sort decreasing, place NAN at front
p.as_mut_slice().sort_unstable_by(|a, b| {
if a.is_nan() {
Ordering::Less
} else if b.is_nan() {
Ordering::Greater
} else {
b.partial_cmp(a).unwrap()
}
});

let mut sum = 0.0;
for &val in &p {
Expand Down Expand Up @@ -212,19 +201,29 @@ where
let mut res = OVector::zeros_generic(dim, nalgebra::U1);
let mut probs_not_taken = 1.0;
let mut samples_left = n;
for (p, s) in p[..p.len() - 1].iter().zip(res.iter_mut()) {
if !(0.0..=1.0).contains(&probs_not_taken) {

let mut p_sorted_inds: Vec<_> = (0..p.len()).collect();

// unwrap because NAN elements not allowed from this struct's `new`
p_sorted_inds.sort_unstable_by(|&i, &j| p[j].partial_cmp(&p[i]).unwrap());

for ind in p_sorted_inds.into_iter().take(p.len() - 1) {
let pi = p[ind];
if pi == 0.0 {
continue;
}
if !(0.0..=1.0).contains(&probs_not_taken) || samples_left == 0 {
break;
}
let p_binom = p / probs_not_taken;
*s = super::Binomial::new(p_binom, samples_left)
.expect("probability already on [0,1]")
let p_binom = pi / probs_not_taken;
res[ind] = super::Binomial::new(p_binom, samples_left)
.unwrap()
.sample(rng);
samples_left -= s.as_();
probs_not_taken -= p;
samples_left -= res[ind].as_();
probs_not_taken -= pi;
}
if let Some(x) = res.as_mut_slice().last_mut() {
*x = T::from_u64(samples_left).unwrap();
if samples_left > 0 {
*res.as_mut_slice().last_mut().unwrap() = T::from_u64(samples_left).unwrap();
}
res
}
Expand Down Expand Up @@ -268,11 +267,11 @@ where
/// where `n` is the number of trials, `p_i` is the `i`th probability,
/// and `k` is the total number of probabilities
fn variance(&self) -> Option<OMatrix<f64, D, D>> {
let mut cov = OMatrix::from_diagonal(&self.p.map(|x| x * (1.0 - x)));
let mut offdiag = |x: usize, y: usize| {
let elt = -self.p[x] * self.p[y];
let mut cov = OMatrix::from_diagonal(&self.p.map(|p| p * (1.0 - p)));
let mut offdiag = |i: usize, j: usize| {
let elt = -self.p[i] * self.p[j];
// cov[(x, y)] = elt;
cov[(y, x)] = elt;
cov[(j, i)] = elt;
};

for i in 0..self.p.len() {
Expand Down Expand Up @@ -560,15 +559,15 @@ mod tests {
fn test_almost_zero_sample() {
use ::rand::{distributions::Distribution, prelude::thread_rng};
let n = 10;
let weights = vec![0.0, 0.0, 0.0, 0.000000001];
let weights = vec![0.0, 0.0, 0.0, 0.1];
let multinomial = Multinomial::new(weights, n).unwrap();
let sample: OVector<f64, Dyn> = multinomial.sample(&mut thread_rng());
assert_relative_eq!(sample[3], n as f64);
}

#[cfg(feature = "rand")]
#[test]
#[ignore = "stochastic tests will not always pass, need a solid rerun strategy"]
#[ignore = "this test is designed to assess approximately normal results within 2σ"]
fn test_stochastic_uniform_samples() {
use crate::statistics::Statistics;
use ::rand::{distributions::Distribution, prelude::thread_rng};
Expand Down

0 comments on commit b653c8b

Please sign in to comment.