Skip to content

Commit

Permalink
fix: remove unneeded trait bounds
Browse files Browse the repository at this point in the history
also adds a comment demonstrating success of unwrap
  • Loading branch information
YeungOnion committed Sep 24, 2024
1 parent 7b4180c commit 66b8d7e
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions src/distribution/multinomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,10 @@ where
fn sample_generic<D, R, T>(p: &[f64], n: u64, dim: D, rng: &mut R) -> OVector<T, D>
where
D: Dim,
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<f64, D>,
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<T, D>,
R: ::rand::Rng + ?Sized,
T: ::num_traits::Num
+ ::nalgebra::Scalar
T: nalgebra::Scalar
+ num_traits::Zero
+ num_traits::AsPrimitive<u64>
+ num_traits::FromPrimitive,
super::Binomial: rand::distributions::Distribution<T>,
Expand All @@ -203,8 +202,7 @@ where
let mut samples_left = n;

let mut p_sorted_inds: Vec<_> = (0..p.len()).collect();

// unwrap because NAN elements not allowed from this struct's `new`
// unwrap succeeds as NAN elements not allowed from this struct's `new`
p_sorted_inds.sort_unstable_by(|&i, &j| p[j].partial_cmp(&p[i]).unwrap());

for ind in p_sorted_inds.into_iter().take(p.len() - 1) {
Expand All @@ -215,6 +213,11 @@ where
if !(0.0..=1.0).contains(&probs_not_taken) || samples_left == 0 {
break;

Check warning on line 214 in src/distribution/multinomial.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multinomial.rs#L214

Added line #L214 was not covered by tests
}
// since $p_j \le p_i \forall j < i$ and $\vec{p}$ is normalized, then
// $1 - sum(p_j, j, 0, i-1) = sum(p_j, j, i, n) = p_i + sum(p_j, j, i+1, n) > p_i$
// this guarantees that logically p_binom on [0,1]
// TODO: demonstrate that this behavior also behaves well with floating point
// the logical reasoning provides that `unwrap` of Binomial::new will typically succeed
let p_binom = pi / probs_not_taken;
res[ind] = super::Binomial::new(p_binom, samples_left)
.unwrap()
Expand Down

0 comments on commit 66b8d7e

Please sign in to comment.