diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs index 0794d352..ef6baa19 100644 --- a/src/distribution/multinomial.rs +++ b/src/distribution/multinomial.rs @@ -2,6 +2,7 @@ use crate::distribution::Discrete; use crate::function::factorial; use crate::statistics::*; use nalgebra::{DVector, Dim, Dyn, OMatrix, OVector}; +use std::cmp::Ordering; /// Implements the /// [Multinomial](https://en.wikipedia.org/wiki/Multinomial_distribution) @@ -99,6 +100,16 @@ where if p.len() < 2 { return Err(MultinomialError::NotEnoughProbabilities); } + // sort decreasing, place NAN at front + p.as_mut_slice().sort_unstable_by(|a, b| { + if a.is_nan() { + Ordering::Less + } else if b.is_nan() { + Ordering::Greater + } else { + b.partial_cmp(a).unwrap() + } + }); let mut sum = 0.0; for &val in &p { @@ -168,7 +179,7 @@ where nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { fn sample(&self, rng: &mut R) -> OVector { - sample_generic(self, rng) + sample_generic(self.p.as_slice(), self.n, self.p.shape_generic().0, rng) } } @@ -180,26 +191,40 @@ where nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { fn sample(&self, rng: &mut R) -> OVector { - sample_generic(self, rng) + sample_generic(self.p.as_slice(), self.n, self.p.shape_generic().0, rng) } } #[cfg(feature = "rand")] -fn sample_generic(dist: &Multinomial, rng: &mut R) -> OVector +fn sample_generic(p: &[f64], n: u64, dim: D, rng: &mut R) -> OVector where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, R: ::rand::Rng + ?Sized, - T: ::num_traits::Num + ::nalgebra::Scalar + ::std::ops::AddAssign, + T: ::num_traits::Num + + ::nalgebra::Scalar + + num_traits::AsPrimitive + + num_traits::FromPrimitive, + super::Binomial: rand::distributions::Distribution, { - use nalgebra::Const; - - let p_cdf = super::categorical::prob_mass_to_cdf(dist.p().as_slice()); - let mut res = OVector::zeros_generic(dist.p.shape_generic().0, Const::<1>); - for _ in 0..dist.n { - let i = super::categorical::sample_unchecked(rng, &p_cdf); - res[i] += T::one(); + use rand::distributions::Distribution; + let mut res = OVector::zeros_generic(dim, nalgebra::U1); + let mut probs_not_taken = 1.0; + let mut samples_left = n; + for (p, s) in p[..p.len() - 1].iter().zip(res.iter_mut()) { + if !(0.0..=1.0).contains(&probs_not_taken) { + break; + } + let p_binom = p / probs_not_taken; + *s = super::Binomial::new(p_binom, samples_left) + .expect("probability already on [0,1]") + .sample(rng); + samples_left -= s.as_(); + probs_not_taken -= p; + } + if let Some(x) = res.as_mut_slice().last_mut() { + *x = T::from_u64(samples_left).unwrap(); } res } @@ -576,4 +601,44 @@ mod tests { // let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); // n.ln_pmf(&[1, 3]); // } + + #[cfg(feature = "rand")] + #[test] + fn test_almost_zero_sample() { + use ::rand::{distributions::Distribution, prelude::thread_rng}; + let n = 10; + let weights = vec![0.0, 0.0, 0.0, 0.000000001]; + let multinomial = Multinomial::new(weights, n).unwrap(); + let sample: OVector = multinomial.sample(&mut thread_rng()); + assert_relative_eq!(sample[3], n as f64); + } + + #[test] + #[cfg(feature = "rand")] + fn test_uniform_samples() { + use crate::statistics::Statistics; + use ::rand::{distributions::Distribution, prelude::thread_rng}; + let n: f64 = 1000.0; + let k: usize = 20; + let weights = vec![1.0; k]; + let multinomial = Multinomial::new(weights, n as u64).unwrap(); + let samples: OVector = multinomial.sample(&mut thread_rng()); + eprintln!("{samples}"); + samples.iter().for_each(|&s| { + assert_abs_diff_eq!( + s, + n / k as f64, + epsilon = 3.0 * multinomial.variance().unwrap()[(0, 0)].sqrt(), + ) + }); + assert_abs_diff_eq!( + samples.iter().population_variance(), + multinomial.variance().unwrap()[(0, 0)], + epsilon = n.sqrt() + ); + assert_eq!( + samples.into_iter().map(|&x| x as u64).sum::(), + n as u64 + ); + } }