From b653c8baafc2dbe61227203a976a912ac1fa73eb Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Wed, 11 Sep 2024 13:48:56 -0500 Subject: [PATCH] fix: correct sampling for multinomial, rely on Binomial --- src/distribution/multinomial.rs | 51 ++++++++++++++++----------------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs index 1520727b..b8587ab6 100644 --- a/src/distribution/multinomial.rs +++ b/src/distribution/multinomial.rs @@ -2,7 +2,6 @@ use crate::distribution::Discrete; use crate::function::factorial; use crate::statistics::*; use nalgebra::{DVector, Dim, Dyn, OMatrix, OVector}; -use std::cmp::Ordering; /// Implements the /// [Multinomial](https://en.wikipedia.org/wiki/Multinomial_distribution) @@ -100,16 +99,6 @@ where if p.len() < 2 { return Err(MultinomialError::NotEnoughProbabilities); } - // sort decreasing, place NAN at front - p.as_mut_slice().sort_unstable_by(|a, b| { - if a.is_nan() { - Ordering::Less - } else if b.is_nan() { - Ordering::Greater - } else { - b.partial_cmp(a).unwrap() - } - }); let mut sum = 0.0; for &val in &p { @@ -212,19 +201,29 @@ where let mut res = OVector::zeros_generic(dim, nalgebra::U1); let mut probs_not_taken = 1.0; let mut samples_left = n; - for (p, s) in p[..p.len() - 1].iter().zip(res.iter_mut()) { - if !(0.0..=1.0).contains(&probs_not_taken) { + + let mut p_sorted_inds: Vec<_> = (0..p.len()).collect(); + + // unwrap because NAN elements not allowed from this struct's `new` + p_sorted_inds.sort_unstable_by(|&i, &j| p[j].partial_cmp(&p[i]).unwrap()); + + for ind in p_sorted_inds.into_iter().take(p.len() - 1) { + let pi = p[ind]; + if pi == 0.0 { + continue; + } + if !(0.0..=1.0).contains(&probs_not_taken) || samples_left == 0 { break; } - let p_binom = p / probs_not_taken; - *s = super::Binomial::new(p_binom, samples_left) - .expect("probability already on [0,1]") + let p_binom = pi / probs_not_taken; + res[ind] = super::Binomial::new(p_binom, samples_left) + .unwrap() .sample(rng); - samples_left -= s.as_(); - probs_not_taken -= p; + samples_left -= res[ind].as_(); + probs_not_taken -= pi; } - if let Some(x) = res.as_mut_slice().last_mut() { - *x = T::from_u64(samples_left).unwrap(); + if samples_left > 0 { + *res.as_mut_slice().last_mut().unwrap() = T::from_u64(samples_left).unwrap(); } res } @@ -268,11 +267,11 @@ where /// where `n` is the number of trials, `p_i` is the `i`th probability, /// and `k` is the total number of probabilities fn variance(&self) -> Option> { - let mut cov = OMatrix::from_diagonal(&self.p.map(|x| x * (1.0 - x))); - let mut offdiag = |x: usize, y: usize| { - let elt = -self.p[x] * self.p[y]; + let mut cov = OMatrix::from_diagonal(&self.p.map(|p| p * (1.0 - p))); + let mut offdiag = |i: usize, j: usize| { + let elt = -self.p[i] * self.p[j]; // cov[(x, y)] = elt; - cov[(y, x)] = elt; + cov[(j, i)] = elt; }; for i in 0..self.p.len() { @@ -560,7 +559,7 @@ mod tests { fn test_almost_zero_sample() { use ::rand::{distributions::Distribution, prelude::thread_rng}; let n = 10; - let weights = vec![0.0, 0.0, 0.0, 0.000000001]; + let weights = vec![0.0, 0.0, 0.0, 0.1]; let multinomial = Multinomial::new(weights, n).unwrap(); let sample: OVector = multinomial.sample(&mut thread_rng()); assert_relative_eq!(sample[3], n as f64); @@ -568,7 +567,7 @@ mod tests { #[cfg(feature = "rand")] #[test] - #[ignore = "stochastic tests will not always pass, need a solid rerun strategy"] + #[ignore = "this test is designed to assess approximately normal results within 2σ"] fn test_stochastic_uniform_samples() { use crate::statistics::Statistics; use ::rand::{distributions::Distribution, prelude::thread_rng};