Skip to content

Commit

Permalink
fix: correct sampling for multinomial, rely on Binomial
Browse files Browse the repository at this point in the history
  • Loading branch information
YeungOnion committed Sep 23, 2024
1 parent 1dc7207 commit 41e32c5
Showing 1 changed file with 28 additions and 28 deletions.
56 changes: 28 additions & 28 deletions src/distribution/multinomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use crate::distribution::Discrete;
use crate::function::factorial;
use crate::statistics::*;
use nalgebra::{DVector, Dim, Dyn, OMatrix, OVector};
use std::cmp::Ordering;

/// Implements the
/// [Multinomial](https://en.wikipedia.org/wiki/Multinomial_distribution)
Expand Down Expand Up @@ -100,16 +99,6 @@ where
if p.len() < 2 {
return Err(MultinomialError::NotEnoughProbabilities);
}
// sort decreasing, place NAN at front
p.as_mut_slice().sort_unstable_by(|a, b| {
if a.is_nan() {
Ordering::Less
} else if b.is_nan() {
Ordering::Greater
} else {
b.partial_cmp(a).unwrap()
}
});

let mut sum = 0.0;
for &val in &p {
Expand Down Expand Up @@ -178,26 +167,37 @@ where
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<f64, D>,
{
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> OVector<f64, D> {
let dim = self.p.shape_generic().0;
let mut res = OVector::zeros_generic(dim, nalgebra::U1);
let mut probs_not_taken = 1.0;
let mut samples_left = self.n;
for (p, s) in self.p.as_slice()[..self.p.len() - 1]
.iter()
.zip(res.iter_mut())
{
if !(0.0..=1.0).contains(&probs_not_taken) {

let mut p_sorted_inds: Vec<_> = (0..self.p().len()).collect();

// unwrap because NAN elements not allowed from this struct's `new`
p_sorted_inds.sort_unstable_by(|&i, &j| self.p[j].partial_cmp(&self.p[i]).unwrap());

// allocate result and write into it
let mut res = OVector::zeros_generic(self.p.shape_generic().0, nalgebra::U1);

for ind in p_sorted_inds.into_iter().take(self.p.len() - 1) {
let p = self.p[ind];
if p == 0.0 {
continue;
}
if !(0.0..=1.0).contains(&probs_not_taken) || samples_left == 0 {
break;
}
let p_binom = p / probs_not_taken;
*s = super::Binomial::new(p_binom, samples_left)
.expect("probability already on [0,1]")
res[ind] = super::Binomial::new(p_binom, samples_left)
.unwrap_or_else(|_| panic!("expected clamped probability to lead to valid Binom dist, got prob = {p_binom}"))
.sample(rng);
samples_left -= *s as u64;
samples_left -= res[ind] as u64;
probs_not_taken -= p;
}
if let Some(x) = res.as_mut_slice().last_mut() {
*x = samples_left as f64;

if samples_left > 0 {
if let Some(x) = res.as_mut_slice().last_mut() {
*x = samples_left as f64;
}
}
res
}
Expand Down Expand Up @@ -242,11 +242,11 @@ where
/// where `n` is the number of trials, `p_i` is the `i`th probability,
/// and `k` is the total number of probabilities
fn variance(&self) -> Option<OMatrix<f64, D, D>> {
let mut cov = OMatrix::from_diagonal(&self.p.map(|x| x * (1.0 - x)));
let mut offdiag = |x: usize, y: usize| {
let elt = -self.p[x] * self.p[y];
let mut cov = OMatrix::from_diagonal(&self.p.map(|p| p * (1.0 - p)));
let mut offdiag = |i: usize, j: usize| {
let elt = -self.p[i] * self.p[j];
// cov[(x, y)] = elt;
cov[(y, x)] = elt;
cov[(j, i)] = elt;
};

for i in 0..self.p.len() {
Expand Down Expand Up @@ -543,7 +543,7 @@ mod tests {

#[cfg(feature = "rand")]
#[test]
#[ignore = "stochastic tests will not always pass, need a solid rerun strategy"]
#[ignore = "this test is designed to assess approximately normal results within 2σ"]
fn test_stochastic_uniform_samples() {
use crate::statistics::Statistics;
use ::rand::{distributions::Distribution, prelude::thread_rng};
Expand Down

0 comments on commit 41e32c5

Please sign in to comment.