diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs index 587f7f78..bad51b66 100644 --- a/src/distribution/multinomial.rs +++ b/src/distribution/multinomial.rs @@ -188,11 +188,10 @@ where fn sample_generic(p: &[f64], n: u64, dim: D, rng: &mut R) -> OVector where D: Dim, - nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, R: ::rand::Rng + ?Sized, - T: ::num_traits::Num - + ::nalgebra::Scalar + T: nalgebra::Scalar + + num_traits::Zero + num_traits::AsPrimitive + num_traits::FromPrimitive, super::Binomial: rand::distributions::Distribution, @@ -203,8 +202,7 @@ where let mut samples_left = n; let mut p_sorted_inds: Vec<_> = (0..p.len()).collect(); - - // unwrap because NAN elements not allowed from this struct's `new` + // unwrap succeeds as NAN elements not allowed from this struct's `new` p_sorted_inds.sort_unstable_by(|&i, &j| p[j].partial_cmp(&p[i]).unwrap()); for ind in p_sorted_inds.into_iter().take(p.len() - 1) { @@ -215,6 +213,11 @@ where if !(0.0..=1.0).contains(&probs_not_taken) || samples_left == 0 { break; } + // since $p_j \le p_i \forall j < i$ and $\vec{p}$ is normalized, then + // $1 - sum(p_j, j, 0, i-1) = sum(p_j, j, i, n) = p_i + sum(p_j, j, i+1, n) > p_i$ + // this guarantees that logically p_binom on [0,1] + // TODO: demonstrate that this behavior also behaves well with floating point + // the logical reasoning provides that `unwrap` of Binomial::new will typically succeed let p_binom = pi / probs_not_taken; res[ind] = super::Binomial::new(p_binom, samples_left) .unwrap()