Skip to content

Commit

Permalink
fn sample_efraimidis_spirakis: use BinaryHeap
Browse files Browse the repository at this point in the history
  • Loading branch information
dhardy committed Nov 26, 2024
1 parent a039a7f commit 831353f
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions src/seq/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ where
N: UInt,
IndexVec: From<Vec<N>>,
{
use std::cmp::Ordering;
use std::{cmp::Ordering, collections::BinaryHeap};

if amount == N::zero() {
return Ok(IndexVec::U32(Vec::new()));
Expand All @@ -373,9 +373,9 @@ where

impl<N> Ord for Element<N> {
fn cmp(&self, other: &Self) -> Ordering {
// partial_cmp will always produce a value,
// because we check that the weights are not nan
self.key.partial_cmp(&other.key).unwrap()
// unwrap() should not panic since weights should not be NaN
// We reverse so that BinaryHeap::peek shows the smallest item
self.key.partial_cmp(&other.key).unwrap().reverse()
}
}

Expand All @@ -387,7 +387,7 @@ where

impl<N> Eq for Element<N> {}

let mut candidates = Vec::with_capacity(amount.as_usize());
let mut candidates = BinaryHeap::with_capacity(amount.as_usize());
let mut index = N::zero();
while index < length && candidates.len() < amount.as_usize() {
let weight = weight(index.as_usize()).into();
Expand All @@ -402,26 +402,23 @@ where

index += N::one();
}
candidates.sort_unstable();

if candidates.len() < amount.as_usize() {
return Err(WeightError::InsufficientNonZero);
}

let mut x = rng.random::<f64>().ln() / candidates[0].key;
let mut x = rng.random::<f64>().ln() / candidates.peek().unwrap().key;
while index < length {
let weight = weight(index.as_usize()).into();
if weight > 0.0 {
x -= weight;
if x <= 0.0 {
let t = (candidates[0].key * weight).exp();
let min_candidate = candidates.pop().unwrap();
let t = (min_candidate.key * weight).exp();
let key = rng.random_range(t..1.0).ln() / weight;
candidates[0] = Element { index, key };
// TODO: consider using a binary tree instead of sorting at each
// step. This should be faster for some THRESHOLD < amount.
candidates.sort_unstable();
candidates.push(Element { index, key });

x = rng.random::<f64>().ln() / candidates[0].key;
x = rng.random::<f64>().ln() / candidates.peek().unwrap().key;
}
} else if !(weight >= 0.0) {
return Err(WeightError::InvalidWeight);
Expand Down

0 comments on commit 831353f

Please sign in to comment.