Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
xmakro committed Feb 7, 2024
1 parent 7a0e234 commit a23e842
Showing 1 changed file with 35 additions and 18 deletions.
53 changes: 35 additions & 18 deletions rand_distr/src/weighted_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,17 @@
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! This module contains an implementation of a tree sttructure for sampling random
//! This module contains an implementation of a tree structure for sampling random
//! indices with probabilities proportional to a collection of weights.
use core::ops::SubAssign;

use super::WeightedError;
use crate::Distribution;
use alloc::vec::Vec;
use rand::{
distributions::{
uniform::{SampleBorrow, SampleUniform},
Weight,
},
Rng,
};
use rand::distributions::uniform::{SampleBorrow, SampleUniform};
use rand::distributions::Weight;
use rand::Rng;
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};

Expand All @@ -29,7 +25,7 @@ use serde::{Deserialize, Serialize};
/// Sampling a [`WeightedTreeIndex<W>`] distribution returns the index of a randomly
/// selected element from the vector used to create the [`WeightedTreeIndex<W>`].
/// The chance of a given element being picked is proportional to the value of
/// the element. The weights can have any type `W` for which a implementation of
/// the element. The weights can have any type `W` for which an implementation of
/// [`Weight`] exists.
///
/// # Key differences
Expand Down Expand Up @@ -71,15 +67,16 @@ use serde::{Deserialize, Serialize};
/// dist.push(1).unwrap();
/// dist.update(1, 1).unwrap();
/// let mut rng = thread_rng();
/// let mut samples = [0; 3];
/// for _ in 0..100 {
/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
/// let i = dist.sample(&mut rng).unwrap();
/// println!("{}", choices[i]);
/// let i = dist.sample(&mut rng);
/// samples[i] += 1;
/// }
/// println!("Results: {:?}", choices.iter().zip(samples.iter()).collect::<Vec<_>>());
/// ```
///
/// [`WeightedTreeIndex<W>`]: WeightedTreeIndex
/// [`Uniform<W>::sample`]: Distribution::sample
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(
Expand Down Expand Up @@ -132,7 +129,7 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + Weight> WeightedTreeInd
/// Returns `true` if we can sample.
///
/// This is the case if the total weight of the tree is greater than zero.
pub fn can_sample(&self) -> bool {
pub fn is_valid(&self) -> bool {
if let Some(weight) = self.subtotals.first() {
*weight > W::ZERO
} else {
Expand Down Expand Up @@ -229,9 +226,13 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + Weight> WeightedTreeInd
}

impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
Distribution<Result<usize, WeightedError>> for WeightedTreeIndex<W>
WeightedTreeIndex<W>
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<usize, WeightedError> {
/// Samples a randomly selected index from the weighted distribution.
///
/// Returns an error if there are no elements or all weights are zero. This
/// is unlike [`Distribution::sample`], which panics in those cases.
fn safe_sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<usize, WeightedError> {
if self.subtotals.is_empty() {
return Err(WeightedError::NoItem);
}
Expand Down Expand Up @@ -269,6 +270,19 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
}
}

impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight> Distribution<usize>
for WeightedTreeIndex<W>
{
/// Samples a randomly selected index from the weighted distribution.
///
/// Caution: This method panics if there are no elements or all weights are zero. However,
/// it is guaranteed that this method will not panic if a call to [`WeightedTreeIndex::is_valid`]
/// returns `true`.
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
self.safe_sample(rng).unwrap()
}
}

#[cfg(test)]
mod test {
use super::*;
Expand All @@ -277,7 +291,10 @@ mod test {
fn test_no_item_error() {
let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
let tree = WeightedTreeIndex::<f64>::new(&[]).unwrap();
assert_eq!(tree.sample(&mut rng).unwrap_err(), WeightedError::NoItem);
assert_eq!(
tree.safe_sample(&mut rng).unwrap_err(),
WeightedError::NoItem
);
}

#[test]
Expand All @@ -297,7 +314,7 @@ mod test {
let tree = WeightedTreeIndex::<f64>::new(&[0.0, 0.0]).unwrap();
let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
assert_eq!(
tree.sample(&mut rng).unwrap_err(),
tree.safe_sample(&mut rng).unwrap_err(),
WeightedError::AllWeightsZero
);
}
Expand Down Expand Up @@ -350,7 +367,7 @@ mod test {
}
let mut counts = alloc::vec![0_usize; end];
for _ in 0..samples {
let i = tree.sample(&mut rng).unwrap();
let i = tree.sample(&mut rng);
counts[i] += 1;
}
for i in 0..start {
Expand Down

0 comments on commit a23e842

Please sign in to comment.