Skip to content

Commit

Permalink
fix: correct sampling for multinomial, rely on Binomial
Browse files Browse the repository at this point in the history
  • Loading branch information
YeungOnion committed Sep 11, 2024
1 parent 5dda539 commit a0fd736
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions src/distribution/multinomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use crate::distribution::Discrete;
use crate::function::factorial;
use crate::statistics::*;
use nalgebra::{DVector, Dim, Dyn, OMatrix, OVector};
use std::cmp::Ordering;

/// Implements the
/// [Multinomial](https://en.wikipedia.org/wiki/Multinomial_distribution)
Expand Down Expand Up @@ -100,16 +99,6 @@ where
if p.len() < 2 {
return Err(MultinomialError::NotEnoughProbabilities);
}
// sort decreasing, place NAN at front
p.as_mut_slice().sort_unstable_by(|a, b| {
if a.is_nan() {
Ordering::Less
} else if b.is_nan() {
Ordering::Greater
} else {
b.partial_cmp(a).unwrap()
}
});

let mut sum = 0.0;
for &val in &p {
Expand Down Expand Up @@ -177,26 +166,37 @@ where
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<f64, D>,
{
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> OVector<f64, D> {
let dim = self.p.shape_generic().0;
let mut res = OVector::zeros_generic(dim, nalgebra::U1);
let mut probs_not_taken = 1.0;
let mut samples_left = self.n;
for (p, s) in self.p.as_slice()[..self.p.len() - 1]
.iter()
.zip(res.iter_mut())
{
if !(0.0..=1.0).contains(&probs_not_taken) {

let mut p_sorted_inds: Vec<_> = (0..self.p().len()).collect();

// unwrap because NAN elements not allowed from this struct's `new`
p_sorted_inds.sort_unstable_by(|&i, &j| self.p[j].partial_cmp(&self.p[i]).unwrap());

// allocate result and write into it
let mut res = OVector::zeros_generic(self.p.shape_generic().0, nalgebra::U1);

for ind in 0..p_sorted_inds.len() - 1 {
let p = self.p[ind];
if p == 0.0 {
continue;
}
if !(0.0..=1.0).contains(&probs_not_taken) || samples_left == 0 {
break;
}
let p_binom = p / probs_not_taken;
*s = super::Binomial::new(p_binom, samples_left)
.expect("probability already on [0,1]")
res[ind] = super::Binomial::new(p_binom, samples_left)
.unwrap_or_else(|_| panic!("expected clamped probability to lead to valid Binom dist, got prob = {p_binom}"))
.sample(rng);
samples_left -= *s as u64;
samples_left -= res[ind] as u64;
probs_not_taken -= p;
}

Check warning on line 194 in src/distribution/multinomial.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multinomial.rs#L184-L194

Added lines #L184 - L194 were not covered by tests
if let Some(x) = res.as_mut_slice().last_mut() {
*x = samples_left as f64;

if samples_left > 0 {
if let Some(x) = res.as_mut_slice().last_mut() {
*x = samples_left as f64;
}
}
res
}
Expand Down

0 comments on commit a0fd736

Please sign in to comment.