diff --git a/.github/workflows/gh-pages.yml b/.github/workflows/gh-pages.yml index e106f4b5b1..9b48b622fa 100644 --- a/.github/workflows/gh-pages.yml +++ b/.github/workflows/gh-pages.yml @@ -37,10 +37,10 @@ jobs: uses: actions/configure-pages@v4 - name: Upload artifact - uses: actions/upload-pages-artifact@v1 + uses: actions/upload-pages-artifact@v3 with: path: './target/doc' - name: Deploy to GitHub Pages id: deployment - uses: actions/deploy-pages@v1 + uses: actions/deploy-pages@v4 diff --git a/examples/monty-hall.rs b/examples/monty-hall.rs index 23e1117896..7499193bce 100644 --- a/examples/monty-hall.rs +++ b/examples/monty-hall.rs @@ -61,7 +61,7 @@ fn simulate(random_door: &Uniform, rng: &mut R) -> SimulationResult // Returns the door the game host opens given our choice and knowledge of // where the car is. The game host will never open the door with the car. fn game_host_open(car: u32, choice: u32, rng: &mut R) -> u32 { - use rand::seq::SliceRandom; + use rand::seq::IndexedRandom; *free_doors(&[car, choice]).choose(rng).unwrap() } diff --git a/rand_distr/Cargo.toml b/rand_distr/Cargo.toml index b2bc28e1a2..36533d4646 100644 --- a/rand_distr/Cargo.toml +++ b/rand_distr/Cargo.toml @@ -30,7 +30,7 @@ serde1 = ["serde", "rand/serde1"] rand = { path = "..", version = "=0.9.0-alpha.0", default-features = false } num-traits = { version = "0.2", default-features = false, features = ["libm"] } serde = { version = "1.0.103", features = ["derive"], optional = true } -serde_with = { version = "1.14.0", optional = true } +serde_with = { version = "3.6.1", optional = true } [dev-dependencies] rand_pcg = { version = "=0.9.0-alpha.0", path = "../rand_pcg" } diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs index 1e28aaaa79..dc155bb5d5 100644 --- a/rand_distr/src/lib.rs +++ b/rand_distr/src/lib.rs @@ -130,7 +130,7 @@ pub use self::weibull::{Error as WeibullError, Weibull}; pub use self::zipf::{Zeta, ZetaError, Zipf, ZipfError}; #[cfg(feature = "alloc")] #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -pub use rand::distributions::{WeightedError, WeightedIndex}; +pub use rand::distributions::{WeightError, WeightedIndex}; #[cfg(feature = "alloc")] #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] pub use weighted_alias::WeightedAliasIndex; diff --git a/rand_distr/src/weighted_alias.rs b/rand_distr/src/weighted_alias.rs index 170de80c4a..236e2ad734 100644 --- a/rand_distr/src/weighted_alias.rs +++ b/rand_distr/src/weighted_alias.rs @@ -9,7 +9,7 @@ //! This module contains an implementation of alias method for sampling random //! indices with probabilities proportional to a collection of weights. -use super::WeightedError; +use super::WeightError; use crate::{uniform::SampleUniform, Distribution, Uniform}; use core::fmt; use core::iter::Sum; @@ -79,18 +79,15 @@ pub struct WeightedAliasIndex { impl WeightedAliasIndex { /// Creates a new [`WeightedAliasIndex`]. /// - /// Returns an error if: - /// - The vector is empty. - /// - The vector is longer than `u32::MAX`. - /// - For any weight `w`: `w < 0` or `w > max` where `max = W::MAX / - /// weights.len()`. - /// - The sum of weights is zero. - pub fn new(weights: Vec) -> Result { + /// Error cases: + /// - [`WeightError::InvalidInput`] when `weights.len()` is zero or greater than `u32::MAX`. + /// - [`WeightError::InvalidWeight`] when a weight is not-a-number, + /// negative or greater than `max = W::MAX / weights.len()`. + /// - [`WeightError::InsufficientNonZero`] when the sum of all weights is zero. + pub fn new(weights: Vec) -> Result { let n = weights.len(); - if n == 0 { - return Err(WeightedError::NoItem); - } else if n > ::core::u32::MAX as usize { - return Err(WeightedError::TooMany); + if n == 0 || n > ::core::u32::MAX as usize { + return Err(WeightError::InvalidInput); } let n = n as u32; @@ -101,7 +98,7 @@ impl WeightedAliasIndex { .iter() .all(|&w| W::ZERO <= w && w <= max_weight_size) { - return Err(WeightedError::InvalidWeight); + return Err(WeightError::InvalidWeight); } // The sum of weights will represent 100% of no alias odds. @@ -113,7 +110,7 @@ impl WeightedAliasIndex { weight_sum }; if weight_sum == W::ZERO { - return Err(WeightedError::AllWeightsZero); + return Err(WeightError::InsufficientNonZero); } // `weight_sum` would have been zero if `try_from_lossy` causes an error here. @@ -382,23 +379,23 @@ mod test { // Floating point special cases assert_eq!( WeightedAliasIndex::new(vec![::core::f32::INFINITY]).unwrap_err(), - WeightedError::InvalidWeight + WeightError::InvalidWeight ); assert_eq!( WeightedAliasIndex::new(vec![-0_f32]).unwrap_err(), - WeightedError::AllWeightsZero + WeightError::InsufficientNonZero ); assert_eq!( WeightedAliasIndex::new(vec![-1_f32]).unwrap_err(), - WeightedError::InvalidWeight + WeightError::InvalidWeight ); assert_eq!( WeightedAliasIndex::new(vec![-::core::f32::INFINITY]).unwrap_err(), - WeightedError::InvalidWeight + WeightError::InvalidWeight ); assert_eq!( WeightedAliasIndex::new(vec![::core::f32::NAN]).unwrap_err(), - WeightedError::InvalidWeight + WeightError::InvalidWeight ); } @@ -416,11 +413,11 @@ mod test { // Signed integer special cases assert_eq!( WeightedAliasIndex::new(vec![-1_i128]).unwrap_err(), - WeightedError::InvalidWeight + WeightError::InvalidWeight ); assert_eq!( WeightedAliasIndex::new(vec![::core::i128::MIN]).unwrap_err(), - WeightedError::InvalidWeight + WeightError::InvalidWeight ); } @@ -438,11 +435,11 @@ mod test { // Signed integer special cases assert_eq!( WeightedAliasIndex::new(vec![-1_i8]).unwrap_err(), - WeightedError::InvalidWeight + WeightError::InvalidWeight ); assert_eq!( WeightedAliasIndex::new(vec![::core::i8::MIN]).unwrap_err(), - WeightedError::InvalidWeight + WeightError::InvalidWeight ); } @@ -486,15 +483,15 @@ mod test { assert_eq!( WeightedAliasIndex::::new(vec![]).unwrap_err(), - WeightedError::NoItem + WeightError::InvalidInput ); assert_eq!( WeightedAliasIndex::new(vec![W::ZERO]).unwrap_err(), - WeightedError::AllWeightsZero + WeightError::InsufficientNonZero ); assert_eq!( WeightedAliasIndex::new(vec![W::MAX, W::MAX]).unwrap_err(), - WeightedError::InvalidWeight + WeightError::InvalidWeight ); } diff --git a/rand_distr/src/weighted_tree.rs b/rand_distr/src/weighted_tree.rs index b308cdb2c0..d5b4ef467d 100644 --- a/rand_distr/src/weighted_tree.rs +++ b/rand_distr/src/weighted_tree.rs @@ -11,7 +11,7 @@ use core::ops::SubAssign; -use super::WeightedError; +use super::WeightError; use crate::Distribution; use alloc::vec::Vec; use rand::distributions::uniform::{SampleBorrow, SampleUniform}; @@ -98,15 +98,19 @@ impl + Weight> WeightedTreeIndex { /// Creates a new [`WeightedTreeIndex`] from a slice of weights. - pub fn new(weights: I) -> Result + /// + /// Error cases: + /// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative. + /// - [`WeightError::Overflow`] when the sum of all weights overflows. + pub fn new(weights: I) -> Result where I: IntoIterator, I::Item: SampleBorrow, { let mut subtotals: Vec = weights.into_iter().map(|x| x.borrow().clone()).collect(); for weight in subtotals.iter() { - if *weight < W::ZERO { - return Err(WeightedError::InvalidWeight); + if !(*weight >= W::ZERO) { + return Err(WeightError::InvalidWeight); } } let n = subtotals.len(); @@ -115,7 +119,7 @@ impl + Weight> let parent = (i - 1) / 2; subtotals[parent] .checked_add_assign(&w) - .map_err(|()| WeightedError::Overflow)?; + .map_err(|()| WeightError::Overflow)?; } Ok(Self { subtotals }) } @@ -164,14 +168,18 @@ impl + Weight> } /// Appends a new weight at the end. - pub fn push(&mut self, weight: W) -> Result<(), WeightedError> { - if weight < W::ZERO { - return Err(WeightedError::InvalidWeight); + /// + /// Error cases: + /// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative. + /// - [`WeightError::Overflow`] when the sum of all weights overflows. + pub fn push(&mut self, weight: W) -> Result<(), WeightError> { + if !(weight >= W::ZERO) { + return Err(WeightError::InvalidWeight); } if let Some(total) = self.subtotals.first() { let mut total = total.clone(); if total.checked_add_assign(&weight).is_err() { - return Err(WeightedError::Overflow); + return Err(WeightError::Overflow); } } let mut index = self.len(); @@ -184,9 +192,13 @@ impl + Weight> } /// Updates the weight at an index. - pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), WeightedError> { - if weight < W::ZERO { - return Err(WeightedError::InvalidWeight); + /// + /// Error cases: + /// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative. + /// - [`WeightError::Overflow`] when the sum of all weights overflows. + pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), WeightError> { + if !(weight >= W::ZERO) { + return Err(WeightError::InvalidWeight); } let old_weight = self.get(index); if weight > old_weight { @@ -195,7 +207,7 @@ impl + Weight> if let Some(total) = self.subtotals.first() { let mut total = total.clone(); if total.checked_add_assign(&difference).is_err() { - return Err(WeightedError::Overflow); + return Err(WeightError::Overflow); } } self.subtotals[index] @@ -235,13 +247,10 @@ impl + Weight> /// /// Returns an error if there are no elements or all weights are zero. This /// is unlike [`Distribution::sample`], which panics in those cases. - fn try_sample(&self, rng: &mut R) -> Result { - if self.subtotals.is_empty() { - return Err(WeightedError::NoItem); - } - let total_weight = self.subtotals[0].clone(); + fn try_sample(&self, rng: &mut R) -> Result { + let total_weight = self.subtotals.first().cloned().unwrap_or(W::ZERO); if total_weight == W::ZERO { - return Err(WeightedError::AllWeightsZero); + return Err(WeightError::InsufficientNonZero); } let mut target_weight = rng.gen_range(W::ZERO..total_weight); let mut index = 0; @@ -296,7 +305,7 @@ mod test { let tree = WeightedTreeIndex::::new(&[]).unwrap(); assert_eq!( tree.try_sample(&mut rng).unwrap_err(), - WeightedError::NoItem + WeightError::InsufficientNonZero ); } @@ -304,11 +313,11 @@ mod test { fn test_overflow_error() { assert_eq!( WeightedTreeIndex::new(&[i32::MAX, 2]), - Err(WeightedError::Overflow) + Err(WeightError::Overflow) ); let mut tree = WeightedTreeIndex::new(&[i32::MAX - 2, 1]).unwrap(); - assert_eq!(tree.push(3), Err(WeightedError::Overflow)); - assert_eq!(tree.update(1, 4), Err(WeightedError::Overflow)); + assert_eq!(tree.push(3), Err(WeightError::Overflow)); + assert_eq!(tree.update(1, 4), Err(WeightError::Overflow)); tree.update(1, 2).unwrap(); } @@ -318,7 +327,7 @@ mod test { let mut rng = crate::test::rng(0x9c9fa0b0580a7031); assert_eq!( tree.try_sample(&mut rng).unwrap_err(), - WeightedError::AllWeightsZero + WeightError::InsufficientNonZero ); } @@ -326,14 +335,14 @@ mod test { fn test_invalid_weight_error() { assert_eq!( WeightedTreeIndex::::new(&[1, -1]).unwrap_err(), - WeightedError::InvalidWeight + WeightError::InvalidWeight ); let mut tree = WeightedTreeIndex::::new(&[]).unwrap(); - assert_eq!(tree.push(-1).unwrap_err(), WeightedError::InvalidWeight); + assert_eq!(tree.push(-1).unwrap_err(), WeightError::InvalidWeight); tree.push(1).unwrap(); assert_eq!( tree.update(0, -1).unwrap_err(), - WeightedError::InvalidWeight + WeightError::InvalidWeight ); } diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs index 5adb82f811..39d967d4f6 100644 --- a/src/distributions/mod.rs +++ b/src/distributions/mod.rs @@ -108,13 +108,6 @@ pub mod hidden_export { pub use super::float::IntoFloat; // used by rand_distr } pub mod uniform; -#[deprecated( - since = "0.8.0", - note = "use rand::distributions::{WeightedIndex, WeightedError} instead" -)] -#[cfg(feature = "alloc")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -pub mod weighted; pub use self::bernoulli::{Bernoulli, BernoulliError}; pub use self::distribution::{Distribution, DistIter, DistMap}; @@ -126,7 +119,7 @@ pub use self::slice::Slice; #[doc(inline)] pub use self::uniform::Uniform; #[cfg(feature = "alloc")] -pub use self::weighted_index::{Weight, WeightedError, WeightedIndex}; +pub use self::weighted_index::{Weight, WeightError, WeightedIndex}; #[allow(unused)] use crate::Rng; diff --git a/src/distributions/slice.rs b/src/distributions/slice.rs index 224bf1712c..5fc08751f6 100644 --- a/src/distributions/slice.rs +++ b/src/distributions/slice.rs @@ -15,7 +15,7 @@ use alloc::string::String; /// [`Slice::new`] constructs a distribution referencing a slice and uniformly /// samples references from the items in the slice. It may do extra work up /// front to make sampling of multiple values faster; if only one sample from -/// the slice is required, [`SliceRandom::choose`] can be more efficient. +/// the slice is required, [`IndexedRandom::choose`] can be more efficient. /// /// Steps are taken to avoid bias which might be present in naive /// implementations; for example `slice[rng.gen() % slice.len()]` samples from @@ -25,7 +25,7 @@ use alloc::string::String; /// This distribution samples with replacement; each sample is independent. /// Sampling without replacement requires state to be retained, and therefore /// cannot be handled by a distribution; you should instead consider methods -/// on [`SliceRandom`], such as [`SliceRandom::choose_multiple`]. +/// on [`IndexedRandom`], such as [`IndexedRandom::choose_multiple`]. /// /// # Example /// @@ -48,11 +48,11 @@ use alloc::string::String; /// assert!(vowel_string.chars().all(|c| vowels.contains(&c))); /// ``` /// -/// For a single sample, [`SliceRandom::choose`][crate::seq::SliceRandom::choose] +/// For a single sample, [`IndexedRandom::choose`][crate::seq::IndexedRandom::choose] /// may be preferred: /// /// ``` -/// use rand::seq::SliceRandom; +/// use rand::seq::IndexedRandom; /// /// let vowels = ['a', 'e', 'i', 'o', 'u']; /// let mut rng = rand::thread_rng(); @@ -60,9 +60,9 @@ use alloc::string::String; /// println!("{}", vowels.choose(&mut rng).unwrap()) /// ``` /// -/// [`SliceRandom`]: crate::seq::SliceRandom -/// [`SliceRandom::choose`]: crate::seq::SliceRandom::choose -/// [`SliceRandom::choose_multiple`]: crate::seq::SliceRandom::choose_multiple +/// [`IndexedRandom`]: crate::seq::IndexedRandom +/// [`IndexedRandom::choose`]: crate::seq::IndexedRandom::choose +/// [`IndexedRandom::choose_multiple`]: crate::seq::IndexedRandom::choose_multiple #[derive(Debug, Clone, Copy)] pub struct Slice<'a, T> { slice: &'a [T], diff --git a/src/distributions/weighted.rs b/src/distributions/weighted.rs deleted file mode 100644 index 846b9df9c2..0000000000 --- a/src/distributions/weighted.rs +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! Weighted index sampling -//! -//! This module is deprecated. Use [`crate::distributions::WeightedIndex`] and -//! [`crate::distributions::WeightedError`] instead. - -pub use super::{WeightedIndex, WeightedError}; - -#[allow(missing_docs)] -#[deprecated(since = "0.8.0", note = "moved to rand_distr crate")] -pub mod alias_method { - // This module exists to provide a deprecation warning which minimises - // compile errors, but still fails to compile if ever used. - use core::marker::PhantomData; - use alloc::vec::Vec; - use super::WeightedError; - - #[derive(Debug)] - pub struct WeightedIndex { - _phantom: PhantomData, - } - impl WeightedIndex { - pub fn new(_weights: Vec) -> Result { - Err(WeightedError::NoItem) - } - } - - pub trait Weight {} - macro_rules! impl_weight { - () => {}; - ($T:ident, $($more:ident,)*) => { - impl Weight for $T {} - impl_weight!($($more,)*); - }; - } - impl_weight!(f64, f32,); - impl_weight!(u8, u16, u32, u64, usize,); - impl_weight!(i8, i16, i32, i64, isize,); - impl_weight!(u128, i128,); -} diff --git a/src/distributions/weighted_index.rs b/src/distributions/weighted_index.rs index 0b1b4da947..49cb02d6ad 100644 --- a/src/distributions/weighted_index.rs +++ b/src/distributions/weighted_index.rs @@ -94,22 +94,25 @@ impl WeightedIndex { /// in `weights`. The weights can use any type `X` for which an /// implementation of [`Uniform`] exists. /// - /// Returns an error if the iterator is empty, if any weight is `< 0`, or - /// if its total value is 0. + /// Error cases: + /// - [`WeightError::InvalidInput`] when the iterator `weights` is empty. + /// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative. + /// - [`WeightError::InsufficientNonZero`] when the sum of all weights is zero. + /// - [`WeightError::Overflow`] when the sum of all weights overflows. /// /// [`Uniform`]: crate::distributions::uniform::Uniform - pub fn new(weights: I) -> Result, WeightedError> + pub fn new(weights: I) -> Result, WeightError> where I: IntoIterator, I::Item: SampleBorrow, X: Weight, { let mut iter = weights.into_iter(); - let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone(); + let mut total_weight: X = iter.next().ok_or(WeightError::InvalidInput)?.borrow().clone(); let zero = X::ZERO; if !(total_weight >= zero) { - return Err(WeightedError::InvalidWeight); + return Err(WeightError::InvalidWeight); } let mut weights = Vec::::with_capacity(iter.size_hint().0); @@ -117,17 +120,17 @@ impl WeightedIndex { // Note that `!(w >= x)` is not equivalent to `w < x` for partially // ordered types due to NaNs which are equal to nothing. if !(w.borrow() >= &zero) { - return Err(WeightedError::InvalidWeight); + return Err(WeightError::InvalidWeight); } weights.push(total_weight.clone()); if let Err(()) = total_weight.checked_add_assign(w.borrow()) { - return Err(WeightedError::Overflow); + return Err(WeightError::Overflow); } } if total_weight == zero { - return Err(WeightedError::AllWeightsZero); + return Err(WeightError::InsufficientNonZero); } let distr = X::Sampler::new(zero, total_weight.clone()).unwrap(); @@ -146,16 +149,19 @@ impl WeightedIndex { /// weights is modified. No allocations are performed, unless the weight type `X` uses /// allocation internally. /// - /// In case of error, `self` is not modified. + /// In case of error, `self` is not modified. Error cases: + /// - [`WeightError::InvalidInput`] when `new_weights` are not ordered by + /// index or an index is too large. + /// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative. + /// - [`WeightError::InsufficientNonZero`] when the sum of all weights is zero. + /// Note that due to floating-point loss of precision, this case is not + /// always correctly detected; usage of a fixed-point weight type may be + /// preferred. /// /// Updates take `O(N)` time. If you need to frequently update weights, consider /// [`rand_distr::weighted_tree`](https://docs.rs/rand_distr/*/rand_distr/weighted_tree/index.html) /// as an alternative where an update is `O(log N)`. - /// - /// Note: Updating floating-point weights may cause slight inaccuracies in the total weight. - /// This method may not return `WeightedError::AllWeightsZero` when all weights - /// are zero if using floating-point weights. - pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError> + pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightError> where X: for<'a> ::core::ops::AddAssign<&'a X> + for<'a> ::core::ops::SubAssign<&'a X> @@ -176,14 +182,14 @@ impl WeightedIndex { for &(i, w) in new_weights { if let Some(old_i) = prev_i { if old_i >= i { - return Err(WeightedError::InvalidWeight); + return Err(WeightError::InvalidInput); } } if !(*w >= zero) { - return Err(WeightedError::InvalidWeight); + return Err(WeightError::InvalidWeight); } if i > self.cumulative_weights.len() { - return Err(WeightedError::TooMany); + return Err(WeightError::InvalidInput); } let mut old_w = if i < self.cumulative_weights.len() { @@ -200,7 +206,7 @@ impl WeightedIndex { prev_i = Some(i); } if total_weight <= zero { - return Err(WeightedError::AllWeightsZero); + return Err(WeightError::InsufficientNonZero); } // Update the weights. Because we checked all the preconditions in the @@ -328,15 +334,15 @@ mod test { fn test_accepting_nan() { assert_eq!( WeightedIndex::new(&[core::f32::NAN, 0.5]).unwrap_err(), - WeightedError::InvalidWeight, + WeightError::InvalidWeight, ); assert_eq!( WeightedIndex::new(&[core::f32::NAN]).unwrap_err(), - WeightedError::InvalidWeight, + WeightError::InvalidWeight, ); assert_eq!( WeightedIndex::new(&[0.5, core::f32::NAN]).unwrap_err(), - WeightedError::InvalidWeight, + WeightError::InvalidWeight, ); assert_eq!( @@ -344,7 +350,7 @@ mod test { .unwrap() .update_weights(&[(0, &core::f32::NAN)]) .unwrap_err(), - WeightedError::InvalidWeight, + WeightError::InvalidWeight, ) } @@ -404,23 +410,23 @@ mod test { assert_eq!( WeightedIndex::new(&[10][0..0]).unwrap_err(), - WeightedError::NoItem + WeightError::InvalidInput ); assert_eq!( WeightedIndex::new(&[0]).unwrap_err(), - WeightedError::AllWeightsZero + WeightError::InsufficientNonZero ); assert_eq!( WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), - WeightedError::InvalidWeight + WeightError::InvalidWeight ); assert_eq!( WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), - WeightedError::InvalidWeight + WeightError::InvalidWeight ); assert_eq!( WeightedIndex::new(&[-10]).unwrap_err(), - WeightedError::InvalidWeight + WeightError::InvalidWeight ); } @@ -497,43 +503,42 @@ mod test { fn overflow() { assert_eq!( WeightedIndex::new([2, usize::MAX]), - Err(WeightedError::Overflow) + Err(WeightError::Overflow) ); } } -/// Error type returned from `WeightedIndex::new`. +/// Errors returned by weighted distributions #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum WeightedError { - /// The provided weight collection contains no items. - NoItem, +pub enum WeightError { + /// The input weight sequence is empty, too long, or wrongly ordered + InvalidInput, - /// A weight is either less than zero, greater than the supported maximum, - /// NaN, or otherwise invalid. + /// A weight is negative, too large for the distribution, or not a valid number InvalidWeight, - /// All items in the provided weight collection are zero. - AllWeightsZero, - - /// Too many weights are provided (length greater than `u32::MAX`) - TooMany, + /// Not enough non-zero weights are available to sample values + /// + /// When attempting to sample a single value this implies that all weights + /// are zero. When attempting to sample `amount` values this implies that + /// less than `amount` weights are greater than zero. + InsufficientNonZero, - /// The sum of weights overflows + /// Overflow when calculating the sum of weights Overflow, } #[cfg(feature = "std")] -impl std::error::Error for WeightedError {} +impl std::error::Error for WeightError {} -impl fmt::Display for WeightedError { +impl fmt::Display for WeightError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.write_str(match *self { - WeightedError::NoItem => "No weights provided in distribution", - WeightedError::InvalidWeight => "A weight is invalid in distribution", - WeightedError::AllWeightsZero => "All weights are zero in distribution", - WeightedError::TooMany => "Too many weights (hit u32::MAX) in distribution", - WeightedError::Overflow => "The sum of weights overflowed", + WeightError::InvalidInput => "Weights sequence is empty/too long/unordered", + WeightError::InvalidWeight => "A weight is negative, too large or not a valid number", + WeightError::InsufficientNonZero => "Not enough weights > zero", + WeightError::Overflow => "Overflow when summing weights", }) } } diff --git a/src/prelude.rs b/src/prelude.rs index 1ce747b625..35fee3d73f 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -27,7 +27,8 @@ pub use crate::rngs::SmallRng; #[doc(no_inline)] #[cfg(all(feature = "std", feature = "std_rng", feature = "getrandom"))] pub use crate::rngs::ThreadRng; -#[doc(no_inline)] pub use crate::seq::{IteratorRandom, SliceRandom}; +#[doc(no_inline)] +pub use crate::seq::{IndexedMutRandom, IndexedRandom, IteratorRandom, SliceRandom}; #[doc(no_inline)] #[cfg(all(feature = "std", feature = "std_rng", feature = "getrandom"))] pub use crate::{random, thread_rng}; diff --git a/src/seq/index.rs b/src/seq/index.rs index 956ea60ed8..e98b7ec106 100644 --- a/src/seq/index.rs +++ b/src/seq/index.rs @@ -17,7 +17,7 @@ use alloc::collections::BTreeSet; #[cfg(feature = "std")] use std::collections::HashSet; #[cfg(feature = "std")] -use crate::distributions::WeightedError; +use super::WeightError; #[cfg(feature = "alloc")] use crate::{Rng, distributions::{uniform::SampleUniform, Distribution, Uniform}}; @@ -267,14 +267,16 @@ where R: Rng + ?Sized { /// sometimes be useful to have the indices themselves so this is provided as /// an alternative. /// -/// This implementation uses `O(length + amount)` space and `O(length)` time. +/// Error cases: +/// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative. +/// - [`WeightError::InsufficientNonZero`] when fewer than `amount` weights are positive. /// -/// Panics if `amount > length`. +/// This implementation uses `O(length + amount)` space and `O(length)` time. #[cfg(feature = "std")] #[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] pub fn sample_weighted( rng: &mut R, length: usize, weight: F, amount: usize, -) -> Result +) -> Result where R: Rng + ?Sized, F: Fn(usize) -> X, @@ -300,11 +302,13 @@ where /// in this paper: https://doi.org/10.1016/j.ipl.2005.11.003 /// It uses `O(length + amount)` space and `O(length)` time. /// -/// Panics if `amount > length`. +/// Error cases: +/// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative. +/// - [`WeightError::InsufficientNonZero`] when fewer than `amount` weights are positive. #[cfg(feature = "std")] fn sample_efraimidis_spirakis( rng: &mut R, length: N, weight: F, amount: N, -) -> Result +) -> Result where R: Rng + ?Sized, F: Fn(usize) -> X, @@ -316,10 +320,6 @@ where return Ok(IndexVec::U32(Vec::new())); } - if amount > length { - panic!("`amount` of samples must be less than or equal to `length`"); - } - struct Element { index: N, key: f64, @@ -347,22 +347,27 @@ where let mut index = N::zero(); while index < length { let weight = weight(index.as_usize()).into(); - if !(weight >= 0.) { - return Err(WeightedError::InvalidWeight); + if weight > 0.0 { + let key = rng.gen::().powf(1.0 / weight); + candidates.push(Element { index, key }); + } else if !(weight >= 0.0) { + return Err(WeightError::InvalidWeight); } - let key = rng.gen::().powf(1.0 / weight); - candidates.push(Element { index, key }); - index += N::one(); } + let avail = candidates.len(); + if avail < amount.as_usize() { + return Err(WeightError::InsufficientNonZero); + } + // Partially sort the array to find the `amount` elements with the greatest // keys. Do this by using `select_nth_unstable` to put the elements with // the *smallest* keys at the beginning of the list in `O(n)` time, which // provides equivalent information about the elements with the *greatest* keys. let (_, mid, greater) - = candidates.select_nth_unstable(length.as_usize() - amount.as_usize()); + = candidates.select_nth_unstable(avail - amount.as_usize()); let mut result: Vec = Vec::with_capacity(amount.as_usize()); result.push(mid.index); @@ -576,7 +581,7 @@ mod test { #[test] fn test_sample_weighted() { let seed_rng = crate::test::rng; - for &(amount, len) in &[(0, 10), (5, 10), (10, 10)] { + for &(amount, len) in &[(0, 10), (5, 10), (9, 10)] { let v = sample_weighted(&mut seed_rng(423), len, |i| i as f64, amount).unwrap(); match v { IndexVec::U32(mut indices) => { @@ -591,6 +596,9 @@ mod test { IndexVec::USize(_) => panic!("expected `IndexVec::U32`"), } } + + let r = sample_weighted(&mut seed_rng(423), 10, |i| i as f64, 10); + assert_eq!(r.unwrap_err(), WeightError::InsufficientNonZero); } #[test] diff --git a/src/seq/mod.rs b/src/seq/mod.rs index 9012b21b90..f5cbc6008e 100644 --- a/src/seq/mod.rs +++ b/src/seq/mod.rs @@ -10,8 +10,10 @@ //! //! This module provides: //! -//! * [`SliceRandom`] slice sampling and mutation -//! * [`IteratorRandom`] iterator sampling +//! * [`IndexedRandom`] for sampling slices and other indexable lists +//! * [`IndexedMutRandom`] for sampling slices and other mutably indexable lists +//! * [`SliceRandom`] for mutating slices +//! * [`IteratorRandom`] for sampling iterators //! * [`index::sample`] low-level API to choose multiple indices from //! `0..length` //! @@ -32,41 +34,36 @@ pub mod index; mod increasing_uniform; #[cfg(feature = "alloc")] -use core::ops::Index; +#[doc(no_inline)] +pub use crate::distributions::WeightError; + +use core::ops::{Index, IndexMut}; #[cfg(feature = "alloc")] use alloc::vec::Vec; #[cfg(feature = "alloc")] use crate::distributions::uniform::{SampleBorrow, SampleUniform}; -#[cfg(feature = "alloc")] -use crate::distributions::{Weight, WeightedError}; +#[cfg(feature = "alloc")] use crate::distributions::Weight; use crate::Rng; use self::coin_flipper::CoinFlipper; use self::increasing_uniform::IncreasingUniform; -/// Extension trait on slices, providing random mutation and sampling methods. -/// -/// This trait is implemented on all `[T]` slice types, providing several -/// methods for choosing and shuffling elements. You must `use` this trait: -/// -/// ``` -/// use rand::seq::SliceRandom; +/// Extension trait on indexable lists, providing random sampling methods. /// -/// let mut rng = rand::thread_rng(); -/// let mut bytes = "Hello, random!".to_string().into_bytes(); -/// bytes.shuffle(&mut rng); -/// let str = String::from_utf8(bytes).unwrap(); -/// println!("{}", str); -/// ``` -/// Example output (non-deterministic): -/// ```none -/// l,nmroHado !le -/// ``` -pub trait SliceRandom { - /// The element type. - type Item; +/// This trait is implemented on `[T]` slice types. Other types supporting +/// [`std::ops::Index`] may implement this (only [`Self::len`] must be +/// specified). +pub trait IndexedRandom: Index { + /// The length + fn len(&self) -> usize; + + /// True when the length is zero + #[inline] + fn is_empty(&self) -> bool { + self.len() == 0 + } /// Uniformly sample one element /// @@ -79,26 +76,23 @@ pub trait SliceRandom { /// /// ``` /// use rand::thread_rng; - /// use rand::seq::SliceRandom; + /// use rand::seq::IndexedRandom; /// /// let choices = [1, 2, 4, 8, 16, 32]; /// let mut rng = thread_rng(); /// println!("{:?}", choices.choose(&mut rng)); /// assert_eq!(choices[..0].choose(&mut rng), None); /// ``` - fn choose(&self, rng: &mut R) -> Option<&Self::Item> - where - R: Rng + ?Sized; - - /// Uniformly sample one element (mut) - /// - /// Returns a mutable reference to one uniformly-sampled random element of - /// the slice, or `None` if the slice is empty. - /// - /// For slices, complexity is `O(1)`. - fn choose_mut(&mut self, rng: &mut R) -> Option<&mut Self::Item> + fn choose(&self, rng: &mut R) -> Option<&Self::Output> where - R: Rng + ?Sized; + R: Rng + ?Sized, + { + if self.is_empty() { + None + } else { + Some(&self[gen_index(rng, self.len())]) + } + } /// Uniformly sample `amount` distinct elements /// @@ -112,7 +106,7 @@ pub trait SliceRandom { /// /// # Example /// ``` - /// use rand::seq::SliceRandom; + /// use rand::seq::IndexedRandom; /// /// let mut rng = &mut rand::thread_rng(); /// let sample = "Hello, audience!".as_bytes(); @@ -128,9 +122,18 @@ pub trait SliceRandom { /// ``` #[cfg(feature = "alloc")] #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] - fn choose_multiple(&self, rng: &mut R, amount: usize) -> SliceChooseIter + fn choose_multiple(&self, rng: &mut R, amount: usize) -> SliceChooseIter where - R: Rng + ?Sized; + Self::Output: Sized, + R: Rng + ?Sized, + { + let amount = ::core::cmp::min(amount, self.len()); + SliceChooseIter { + slice: self, + _phantom: Default::default(), + indices: index::sample(rng, self.len(), amount).into_iter(), + } + } /// Biased sampling for one element /// @@ -158,48 +161,24 @@ pub trait SliceRandom { /// // and 'd' will never be printed /// println!("{:?}", choices.choose_weighted(&mut rng, |item| item.1).unwrap().0); /// ``` - /// [`choose`]: SliceRandom::choose - /// [`choose_weighted_mut`]: SliceRandom::choose_weighted_mut + /// [`choose`]: IndexedRandom::choose + /// [`choose_weighted_mut`]: IndexedMutRandom::choose_weighted_mut /// [`distributions::WeightedIndex`]: crate::distributions::WeightedIndex #[cfg(feature = "alloc")] #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] fn choose_weighted( &self, rng: &mut R, weight: F, - ) -> Result<&Self::Item, WeightedError> + ) -> Result<&Self::Output, WeightError> where R: Rng + ?Sized, - F: Fn(&Self::Item) -> B, + F: Fn(&Self::Output) -> B, B: SampleBorrow, - X: SampleUniform + Weight + ::core::cmp::PartialOrd; - - /// Biased sampling for one element (mut) - /// - /// Returns a mutable reference to one element of the slice, sampled according - /// to the provided weights. Returns `None` only if the slice is empty. - /// - /// The specified function `weight` maps each item `x` to a relative - /// likelihood `weight(x)`. The probability of each item being selected is - /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. - /// - /// For slices of length `n`, complexity is `O(n)`. - /// For more information about the underlying algorithm, - /// see [`distributions::WeightedIndex`]. - /// - /// See also [`choose_weighted`]. - /// - /// [`choose_mut`]: SliceRandom::choose_mut - /// [`choose_weighted`]: SliceRandom::choose_weighted - /// [`distributions::WeightedIndex`]: crate::distributions::WeightedIndex - #[cfg(feature = "alloc")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] - fn choose_weighted_mut( - &mut self, rng: &mut R, weight: F, - ) -> Result<&mut Self::Item, WeightedError> - where - R: Rng + ?Sized, - F: Fn(&Self::Item) -> B, - B: SampleBorrow, - X: SampleUniform + Weight + ::core::cmp::PartialOrd; + X: SampleUniform + Weight + ::core::cmp::PartialOrd, + { + use crate::distributions::{Distribution, WeightedIndex}; + let distr = WeightedIndex::new((0..self.len()).map(|idx| weight(&self[idx])))?; + Ok(&self[distr.sample(rng)]) + } /// Biased sampling of `amount` distinct elements /// @@ -232,7 +211,7 @@ pub trait SliceRandom { /// // (25% * 33%) + (25% * 33%) = 16.6% chance that the output is `['b', 'c']` in some order. /// println!("{:?}", choices.choose_multiple_weighted(&mut rng, 2, |item| item.1).unwrap().collect::>()); /// ``` - /// [`choose_multiple`]: SliceRandom::choose_multiple + /// [`choose_multiple`]: IndexedRandom::choose_multiple // // Note: this is feature-gated on std due to usage of f64::powf. // If necessary, we may use alloc+libm as an alternative (see PR #1089). @@ -240,12 +219,106 @@ pub trait SliceRandom { #[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] fn choose_multiple_weighted( &self, rng: &mut R, amount: usize, weight: F, - ) -> Result, WeightedError> + ) -> Result, WeightError> + where + Self::Output: Sized, + R: Rng + ?Sized, + F: Fn(&Self::Output) -> X, + X: Into, + { + let amount = ::core::cmp::min(amount, self.len()); + Ok(SliceChooseIter { + slice: self, + _phantom: Default::default(), + indices: index::sample_weighted( + rng, + self.len(), + |idx| weight(&self[idx]).into(), + amount, + )? + .into_iter(), + }) + } +} + +/// Extension trait on indexable lists, providing random sampling methods. +/// +/// This trait is implemented automatically for every type implementing +/// [`IndexedRandom`] and [`std::ops::IndexMut`]. +pub trait IndexedMutRandom: IndexedRandom + IndexMut { + /// Uniformly sample one element (mut) + /// + /// Returns a mutable reference to one uniformly-sampled random element of + /// the slice, or `None` if the slice is empty. + /// + /// For slices, complexity is `O(1)`. + fn choose_mut(&mut self, rng: &mut R) -> Option<&mut Self::Output> + where + R: Rng + ?Sized, + { + if self.is_empty() { + None + } else { + let len = self.len(); + Some(&mut self[gen_index(rng, len)]) + } + } + + /// Biased sampling for one element (mut) + /// + /// Returns a mutable reference to one element of the slice, sampled according + /// to the provided weights. Returns `None` only if the slice is empty. + /// + /// The specified function `weight` maps each item `x` to a relative + /// likelihood `weight(x)`. The probability of each item being selected is + /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. + /// + /// For slices of length `n`, complexity is `O(n)`. + /// For more information about the underlying algorithm, + /// see [`distributions::WeightedIndex`]. + /// + /// See also [`choose_weighted`]. + /// + /// [`choose_mut`]: IndexedMutRandom::choose_mut + /// [`choose_weighted`]: IndexedRandom::choose_weighted + /// [`distributions::WeightedIndex`]: crate::distributions::WeightedIndex + #[cfg(feature = "alloc")] + #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] + fn choose_weighted_mut( + &mut self, rng: &mut R, weight: F, + ) -> Result<&mut Self::Output, WeightError> where R: Rng + ?Sized, - F: Fn(&Self::Item) -> X, - X: Into; + F: Fn(&Self::Output) -> B, + B: SampleBorrow, + X: SampleUniform + Weight + ::core::cmp::PartialOrd, + { + use crate::distributions::{Distribution, WeightedIndex}; + let distr = WeightedIndex::new((0..self.len()).map(|idx| weight(&self[idx])))?; + let index = distr.sample(rng); + Ok(&mut self[index]) + } +} +/// Extension trait on slices, providing shuffling methods. +/// +/// This trait is implemented on all `[T]` slice types, providing several +/// methods for choosing and shuffling elements. You must `use` this trait: +/// +/// ``` +/// use rand::seq::SliceRandom; +/// +/// let mut rng = rand::thread_rng(); +/// let mut bytes = "Hello, random!".to_string().into_bytes(); +/// bytes.shuffle(&mut rng); +/// let str = String::from_utf8(bytes).unwrap(); +/// println!("{}", str); +/// ``` +/// Example output (non-deterministic): +/// ```none +/// l,nmroHado !le +/// ``` +pub trait SliceRandom: IndexedMutRandom { /// Shuffle a mutable slice in place. /// /// For slices of length `n`, complexity is `O(n)`. @@ -286,8 +359,9 @@ pub trait SliceRandom { /// For slices, complexity is `O(m)` where `m = amount`. fn partial_shuffle( &mut self, rng: &mut R, amount: usize, - ) -> (&mut [Self::Item], &mut [Self::Item]) + ) -> (&mut [Self::Output], &mut [Self::Output]) where + Self::Output: Sized, R: Rng + ?Sized; } @@ -460,7 +534,7 @@ pub trait IteratorRandom: Iterator + Sized { /// case this equals the number of elements available. /// /// Complexity is `O(n)` where `n` is the length of the iterator. - /// For slices, prefer [`SliceRandom::choose_multiple`]. + /// For slices, prefer [`IndexedRandom::choose_multiple`]. fn choose_multiple_fill(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize where R: Rng + ?Sized, @@ -500,7 +574,7 @@ pub trait IteratorRandom: Iterator + Sized { /// elements available. /// /// Complexity is `O(n)` where `n` is the length of the iterator. - /// For slices, prefer [`SliceRandom::choose_multiple`]. + /// For slices, prefer [`IndexedRandom::choose_multiple`]. #[cfg(feature = "alloc")] #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] fn choose_multiple(mut self, rng: &mut R, amount: usize) -> Vec @@ -530,98 +604,15 @@ pub trait IteratorRandom: Iterator + Sized { } } -impl SliceRandom for [T] { - type Item = T; - - fn choose(&self, rng: &mut R) -> Option<&Self::Item> - where - R: Rng + ?Sized, - { - if self.is_empty() { - None - } else { - Some(&self[gen_index(rng, self.len())]) - } - } - - fn choose_mut(&mut self, rng: &mut R) -> Option<&mut Self::Item> - where - R: Rng + ?Sized, - { - if self.is_empty() { - None - } else { - let len = self.len(); - Some(&mut self[gen_index(rng, len)]) - } - } - - #[cfg(feature = "alloc")] - fn choose_multiple(&self, rng: &mut R, amount: usize) -> SliceChooseIter - where - R: Rng + ?Sized, - { - let amount = ::core::cmp::min(amount, self.len()); - SliceChooseIter { - slice: self, - _phantom: Default::default(), - indices: index::sample(rng, self.len(), amount).into_iter(), - } - } - - #[cfg(feature = "alloc")] - fn choose_weighted( - &self, rng: &mut R, weight: F, - ) -> Result<&Self::Item, WeightedError> - where - R: Rng + ?Sized, - F: Fn(&Self::Item) -> B, - B: SampleBorrow, - X: SampleUniform + Weight + ::core::cmp::PartialOrd, - { - use crate::distributions::{Distribution, WeightedIndex}; - let distr = WeightedIndex::new(self.iter().map(weight))?; - Ok(&self[distr.sample(rng)]) +impl IndexedRandom for [T] { + fn len(&self) -> usize { + self.len() } +} - #[cfg(feature = "alloc")] - fn choose_weighted_mut( - &mut self, rng: &mut R, weight: F, - ) -> Result<&mut Self::Item, WeightedError> - where - R: Rng + ?Sized, - F: Fn(&Self::Item) -> B, - B: SampleBorrow, - X: SampleUniform + Weight + ::core::cmp::PartialOrd, - { - use crate::distributions::{Distribution, WeightedIndex}; - let distr = WeightedIndex::new(self.iter().map(weight))?; - Ok(&mut self[distr.sample(rng)]) - } - - #[cfg(feature = "std")] - fn choose_multiple_weighted( - &self, rng: &mut R, amount: usize, weight: F, - ) -> Result, WeightedError> - where - R: Rng + ?Sized, - F: Fn(&Self::Item) -> X, - X: Into, - { - let amount = ::core::cmp::min(amount, self.len()); - Ok(SliceChooseIter { - slice: self, - _phantom: Default::default(), - indices: index::sample_weighted( - rng, - self.len(), - |idx| weight(&self[idx]).into(), - amount, - )? - .into_iter(), - }) - } +impl + ?Sized> IndexedMutRandom for IR {} +impl SliceRandom for [T] { fn shuffle(&mut self, rng: &mut R) where R: Rng + ?Sized, @@ -635,7 +626,7 @@ impl SliceRandom for [T] { fn partial_shuffle( &mut self, rng: &mut R, amount: usize, - ) -> (&mut [Self::Item], &mut [Self::Item]) + ) -> (&mut [T], &mut [T]) where R: Rng + ?Sized, { @@ -672,7 +663,7 @@ impl IteratorRandom for I where I: Iterator + Sized {} /// An iterator over multiple slice elements. /// /// This struct is created by -/// [`SliceRandom::choose_multiple`](trait.SliceRandom.html#tymethod.choose_multiple). +/// [`IndexedRandom::choose_multiple`](trait.IndexedRandom.html#tymethod.choose_multiple). #[cfg(feature = "alloc")] #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] #[derive(Debug)] @@ -1187,23 +1178,23 @@ mod test { let empty_slice = &mut [10][0..0]; assert_eq!( empty_slice.choose_weighted(&mut r, |_| 1), - Err(WeightedError::NoItem) + Err(WeightError::InvalidInput) ); assert_eq!( empty_slice.choose_weighted_mut(&mut r, |_| 1), - Err(WeightedError::NoItem) + Err(WeightError::InvalidInput) ); assert_eq!( ['x'].choose_weighted_mut(&mut r, |_| 0), - Err(WeightedError::AllWeightsZero) + Err(WeightError::InsufficientNonZero) ); assert_eq!( [0, -1].choose_weighted_mut(&mut r, |x| *x), - Err(WeightedError::InvalidWeight) + Err(WeightError::InvalidWeight) ); assert_eq!( [-1, 0].choose_weighted_mut(&mut r, |x| *x), - Err(WeightedError::InvalidWeight) + Err(WeightError::InvalidWeight) ); } @@ -1340,42 +1331,23 @@ mod test { // Case 2: All of the weights are 0 let choices = [('a', 0), ('b', 0), ('c', 0)]; - - assert_eq!( - choices - .choose_multiple_weighted(&mut rng, 2, |item| item.1) - .unwrap() - .count(), - 2 - ); + let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1); + assert_eq!(r.unwrap_err(), WeightError::InsufficientNonZero); // Case 3: Negative weights let choices = [('a', -1), ('b', 1), ('c', 1)]; - assert_eq!( - choices - .choose_multiple_weighted(&mut rng, 2, |item| item.1) - .unwrap_err(), - WeightedError::InvalidWeight - ); + let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1); + assert_eq!(r.unwrap_err(), WeightError::InvalidWeight); // Case 4: Empty list let choices = []; - assert_eq!( - choices - .choose_multiple_weighted(&mut rng, 0, |_: &()| 0) - .unwrap() - .count(), - 0 - ); + let r = choices.choose_multiple_weighted(&mut rng, 0, |_: &()| 0); + assert_eq!(r.unwrap().count(), 0); // Case 5: NaN weights let choices = [('a', core::f64::NAN), ('b', 1.0), ('c', 1.0)]; - assert_eq!( - choices - .choose_multiple_weighted(&mut rng, 2, |item| item.1) - .unwrap_err(), - WeightedError::InvalidWeight - ); + let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1); + assert_eq!(r.unwrap_err(), WeightError::InvalidWeight); // Case 6: +infinity weights let choices = [('a', core::f64::INFINITY), ('b', 1.0), ('c', 1.0)]; @@ -1390,18 +1362,13 @@ mod test { // Case 7: -infinity weights let choices = [('a', core::f64::NEG_INFINITY), ('b', 1.0), ('c', 1.0)]; - assert_eq!( - choices - .choose_multiple_weighted(&mut rng, 2, |item| item.1) - .unwrap_err(), - WeightedError::InvalidWeight - ); + let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1); + assert_eq!(r.unwrap_err(), WeightError::InvalidWeight); // Case 8: -0 weights let choices = [('a', -0.0), ('b', 1.0), ('c', 1.0)]; - assert!(choices - .choose_multiple_weighted(&mut rng, 2, |item| item.1) - .is_ok()); + let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1); + assert!(r.is_ok()); } #[test]