diff --git a/rand_distr/src/weighted_tree.rs b/rand_distr/src/weighted_tree.rs index ce5aad4dc2..370c38914d 100644 --- a/rand_distr/src/weighted_tree.rs +++ b/rand_distr/src/weighted_tree.rs @@ -6,7 +6,7 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! This module contains an implementation of a tree sttructure for sampling random +//! This module contains an implementation of a tree structure for sampling random //! indices with probabilities proportional to a collection of weights. use core::ops::SubAssign; @@ -14,13 +14,9 @@ use core::ops::SubAssign; use super::WeightedError; use crate::Distribution; use alloc::vec::Vec; -use rand::{ - distributions::{ - uniform::{SampleBorrow, SampleUniform}, - Weight, - }, - Rng, -}; +use rand::distributions::uniform::{SampleBorrow, SampleUniform}; +use rand::distributions::Weight; +use rand::Rng; #[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; @@ -29,7 +25,7 @@ use serde::{Deserialize, Serialize}; /// Sampling a [`WeightedTreeIndex`] distribution returns the index of a randomly /// selected element from the vector used to create the [`WeightedTreeIndex`]. /// The chance of a given element being picked is proportional to the value of -/// the element. The weights can have any type `W` for which a implementation of +/// the element. The weights can have any type `W` for which an implementation of /// [`Weight`] exists. /// /// # Key differences @@ -71,15 +67,16 @@ use serde::{Deserialize, Serialize}; /// dist.push(1).unwrap(); /// dist.update(1, 1).unwrap(); /// let mut rng = thread_rng(); +/// let mut samples = [0; 3]; /// for _ in 0..100 { /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' -/// let i = dist.sample(&mut rng).unwrap(); -/// println!("{}", choices[i]); +/// let i = dist.sample(&mut rng); +/// samples[i] += 1; /// } +/// println!("Results: {:?}", choices.iter().zip(samples.iter()).collect::>()); /// ``` /// /// [`WeightedTreeIndex`]: WeightedTreeIndex -/// [`Uniform::sample`]: Distribution::sample #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr( @@ -132,7 +129,7 @@ impl WeightedTreeInd /// Returns `true` if we can sample. /// /// This is the case if the total weight of the tree is greater than zero. - pub fn can_sample(&self) -> bool { + pub fn is_valid(&self) -> bool { if let Some(weight) = self.subtotals.first() { *weight > W::ZERO } else { @@ -229,9 +226,13 @@ impl WeightedTreeInd } impl + Weight> - Distribution> for WeightedTreeIndex + WeightedTreeIndex { - fn sample(&self, rng: &mut R) -> Result { + /// Samples a randomly selected index from the weighted distribution. + /// + /// Returns an error if there are no elements or all weights are zero. This + /// is unlike [`Distribution::sample`], which panics in those cases. + fn safe_sample(&self, rng: &mut R) -> Result { if self.subtotals.is_empty() { return Err(WeightedError::NoItem); } @@ -269,6 +270,19 @@ impl + Weight> } } +impl + Weight> Distribution + for WeightedTreeIndex +{ + /// Samples a randomly selected index from the weighted distribution. + /// + /// Caution: This method panics if there are no elements or all weights are zero. However, + /// it is guaranteed that this method will not panic if a call to [`WeightedTreeIndex::is_valid`] + /// returns `true`. + fn sample(&self, rng: &mut R) -> usize { + self.safe_sample(rng).unwrap() + } +} + #[cfg(test)] mod test { use super::*; @@ -277,7 +291,10 @@ mod test { fn test_no_item_error() { let mut rng = crate::test::rng(0x9c9fa0b0580a7031); let tree = WeightedTreeIndex::::new(&[]).unwrap(); - assert_eq!(tree.sample(&mut rng).unwrap_err(), WeightedError::NoItem); + assert_eq!( + tree.safe_sample(&mut rng).unwrap_err(), + WeightedError::NoItem + ); } #[test] @@ -297,7 +314,7 @@ mod test { let tree = WeightedTreeIndex::::new(&[0.0, 0.0]).unwrap(); let mut rng = crate::test::rng(0x9c9fa0b0580a7031); assert_eq!( - tree.sample(&mut rng).unwrap_err(), + tree.safe_sample(&mut rng).unwrap_err(), WeightedError::AllWeightsZero ); } @@ -350,7 +367,7 @@ mod test { } let mut counts = alloc::vec![0_usize; end]; for _ in 0..samples { - let i = tree.sample(&mut rng).unwrap(); + let i = tree.sample(&mut rng); counts[i] += 1; } for i in 0..start {