Skip to content

Commit

Permalink
Rename WeightedError → WeightError; add IndexedRandom, IndexedMutRand…
Browse files Browse the repository at this point in the history
…om (#1382)

* Remove deprecated module rand::distributions::weighted
* WeightedTree: return InvalidWeight on not-a-number
* WeightedTree::try_sample return AllWeightsZero given no weights
* Rename WeightedError -> WeightError and revise variants
* Re-export WeightError from rand::seq
* Revise errors of rand::index::sample_weighted
* Split SliceRandom into IndexedRandom, IndexedMutRandom and SliceRandom
  • Loading branch information
dhardy authored Feb 15, 2024
1 parent ef245fd commit dba696e
Show file tree
Hide file tree
Showing 11 changed files with 329 additions and 396 deletions.
2 changes: 1 addition & 1 deletion examples/monty-hall.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ fn simulate<R: Rng>(random_door: &Uniform<u32>, rng: &mut R) -> SimulationResult
// Returns the door the game host opens given our choice and knowledge of
// where the car is. The game host will never open the door with the car.
fn game_host_open<R: Rng>(car: u32, choice: u32, rng: &mut R) -> u32 {
use rand::seq::SliceRandom;
use rand::seq::IndexedRandom;
*free_doors(&[car, choice]).choose(rng).unwrap()
}

Expand Down
2 changes: 1 addition & 1 deletion rand_distr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ pub use self::weibull::{Error as WeibullError, Weibull};
pub use self::zipf::{Zeta, ZetaError, Zipf, ZipfError};
#[cfg(feature = "alloc")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
pub use rand::distributions::{WeightedError, WeightedIndex};
pub use rand::distributions::{WeightError, WeightedIndex};
#[cfg(feature = "alloc")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
pub use weighted_alias::WeightedAliasIndex;
Expand Down
49 changes: 23 additions & 26 deletions rand_distr/src/weighted_alias.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
//! This module contains an implementation of alias method for sampling random
//! indices with probabilities proportional to a collection of weights.
use super::WeightedError;
use super::WeightError;
use crate::{uniform::SampleUniform, Distribution, Uniform};
use core::fmt;
use core::iter::Sum;
Expand Down Expand Up @@ -79,18 +79,15 @@ pub struct WeightedAliasIndex<W: AliasableWeight> {
impl<W: AliasableWeight> WeightedAliasIndex<W> {
/// Creates a new [`WeightedAliasIndex`].
///
/// Returns an error if:
/// - The vector is empty.
/// - The vector is longer than `u32::MAX`.
/// - For any weight `w`: `w < 0` or `w > max` where `max = W::MAX /
/// weights.len()`.
/// - The sum of weights is zero.
pub fn new(weights: Vec<W>) -> Result<Self, WeightedError> {
/// Error cases:
/// - [`WeightError::InvalidInput`] when `weights.len()` is zero or greater than `u32::MAX`.
/// - [`WeightError::InvalidWeight`] when a weight is not-a-number,
/// negative or greater than `max = W::MAX / weights.len()`.
/// - [`WeightError::InsufficientNonZero`] when the sum of all weights is zero.
pub fn new(weights: Vec<W>) -> Result<Self, WeightError> {
let n = weights.len();
if n == 0 {
return Err(WeightedError::NoItem);
} else if n > ::core::u32::MAX as usize {
return Err(WeightedError::TooMany);
if n == 0 || n > ::core::u32::MAX as usize {
return Err(WeightError::InvalidInput);
}
let n = n as u32;

Expand All @@ -101,7 +98,7 @@ impl<W: AliasableWeight> WeightedAliasIndex<W> {
.iter()
.all(|&w| W::ZERO <= w && w <= max_weight_size)
{
return Err(WeightedError::InvalidWeight);
return Err(WeightError::InvalidWeight);
}

// The sum of weights will represent 100% of no alias odds.
Expand All @@ -113,7 +110,7 @@ impl<W: AliasableWeight> WeightedAliasIndex<W> {
weight_sum
};
if weight_sum == W::ZERO {
return Err(WeightedError::AllWeightsZero);
return Err(WeightError::InsufficientNonZero);
}

// `weight_sum` would have been zero if `try_from_lossy` causes an error here.
Expand Down Expand Up @@ -382,23 +379,23 @@ mod test {
// Floating point special cases
assert_eq!(
WeightedAliasIndex::new(vec![::core::f32::INFINITY]).unwrap_err(),
WeightedError::InvalidWeight
WeightError::InvalidWeight
);
assert_eq!(
WeightedAliasIndex::new(vec![-0_f32]).unwrap_err(),
WeightedError::AllWeightsZero
WeightError::InsufficientNonZero
);
assert_eq!(
WeightedAliasIndex::new(vec![-1_f32]).unwrap_err(),
WeightedError::InvalidWeight
WeightError::InvalidWeight
);
assert_eq!(
WeightedAliasIndex::new(vec![-::core::f32::INFINITY]).unwrap_err(),
WeightedError::InvalidWeight
WeightError::InvalidWeight
);
assert_eq!(
WeightedAliasIndex::new(vec![::core::f32::NAN]).unwrap_err(),
WeightedError::InvalidWeight
WeightError::InvalidWeight
);
}

Expand All @@ -416,11 +413,11 @@ mod test {
// Signed integer special cases
assert_eq!(
WeightedAliasIndex::new(vec![-1_i128]).unwrap_err(),
WeightedError::InvalidWeight
WeightError::InvalidWeight
);
assert_eq!(
WeightedAliasIndex::new(vec![::core::i128::MIN]).unwrap_err(),
WeightedError::InvalidWeight
WeightError::InvalidWeight
);
}

Expand All @@ -438,11 +435,11 @@ mod test {
// Signed integer special cases
assert_eq!(
WeightedAliasIndex::new(vec![-1_i8]).unwrap_err(),
WeightedError::InvalidWeight
WeightError::InvalidWeight
);
assert_eq!(
WeightedAliasIndex::new(vec![::core::i8::MIN]).unwrap_err(),
WeightedError::InvalidWeight
WeightError::InvalidWeight
);
}

Expand Down Expand Up @@ -486,15 +483,15 @@ mod test {

assert_eq!(
WeightedAliasIndex::<W>::new(vec![]).unwrap_err(),
WeightedError::NoItem
WeightError::InvalidInput
);
assert_eq!(
WeightedAliasIndex::new(vec![W::ZERO]).unwrap_err(),
WeightedError::AllWeightsZero
WeightError::InsufficientNonZero
);
assert_eq!(
WeightedAliasIndex::new(vec![W::MAX, W::MAX]).unwrap_err(),
WeightedError::InvalidWeight
WeightError::InvalidWeight
);
}

Expand Down
63 changes: 36 additions & 27 deletions rand_distr/src/weighted_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
use core::ops::SubAssign;

use super::WeightedError;
use super::WeightError;
use crate::Distribution;
use alloc::vec::Vec;
use rand::distributions::uniform::{SampleBorrow, SampleUniform};
Expand Down Expand Up @@ -98,15 +98,19 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
WeightedTreeIndex<W>
{
/// Creates a new [`WeightedTreeIndex`] from a slice of weights.
pub fn new<I>(weights: I) -> Result<Self, WeightedError>
///
/// Error cases:
/// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative.
/// - [`WeightError::Overflow`] when the sum of all weights overflows.
pub fn new<I>(weights: I) -> Result<Self, WeightError>
where
I: IntoIterator,
I::Item: SampleBorrow<W>,
{
let mut subtotals: Vec<W> = weights.into_iter().map(|x| x.borrow().clone()).collect();
for weight in subtotals.iter() {
if *weight < W::ZERO {
return Err(WeightedError::InvalidWeight);
if !(*weight >= W::ZERO) {
return Err(WeightError::InvalidWeight);
}
}
let n = subtotals.len();
Expand All @@ -115,7 +119,7 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
let parent = (i - 1) / 2;
subtotals[parent]
.checked_add_assign(&w)
.map_err(|()| WeightedError::Overflow)?;
.map_err(|()| WeightError::Overflow)?;
}
Ok(Self { subtotals })
}
Expand Down Expand Up @@ -164,14 +168,18 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
}

/// Appends a new weight at the end.
pub fn push(&mut self, weight: W) -> Result<(), WeightedError> {
if weight < W::ZERO {
return Err(WeightedError::InvalidWeight);
///
/// Error cases:
/// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative.
/// - [`WeightError::Overflow`] when the sum of all weights overflows.
pub fn push(&mut self, weight: W) -> Result<(), WeightError> {
if !(weight >= W::ZERO) {
return Err(WeightError::InvalidWeight);
}
if let Some(total) = self.subtotals.first() {
let mut total = total.clone();
if total.checked_add_assign(&weight).is_err() {
return Err(WeightedError::Overflow);
return Err(WeightError::Overflow);
}
}
let mut index = self.len();
Expand All @@ -184,9 +192,13 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
}

/// Updates the weight at an index.
pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), WeightedError> {
if weight < W::ZERO {
return Err(WeightedError::InvalidWeight);
///
/// Error cases:
/// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative.
/// - [`WeightError::Overflow`] when the sum of all weights overflows.
pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), WeightError> {
if !(weight >= W::ZERO) {
return Err(WeightError::InvalidWeight);
}
let old_weight = self.get(index);
if weight > old_weight {
Expand All @@ -195,7 +207,7 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
if let Some(total) = self.subtotals.first() {
let mut total = total.clone();
if total.checked_add_assign(&difference).is_err() {
return Err(WeightedError::Overflow);
return Err(WeightError::Overflow);
}
}
self.subtotals[index]
Expand Down Expand Up @@ -235,13 +247,10 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
///
/// Returns an error if there are no elements or all weights are zero. This
/// is unlike [`Distribution::sample`], which panics in those cases.
fn try_sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<usize, WeightedError> {
if self.subtotals.is_empty() {
return Err(WeightedError::NoItem);
}
let total_weight = self.subtotals[0].clone();
fn try_sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<usize, WeightError> {
let total_weight = self.subtotals.first().cloned().unwrap_or(W::ZERO);
if total_weight == W::ZERO {
return Err(WeightedError::AllWeightsZero);
return Err(WeightError::InsufficientNonZero);
}
let mut target_weight = rng.gen_range(W::ZERO..total_weight);
let mut index = 0;
Expand Down Expand Up @@ -296,19 +305,19 @@ mod test {
let tree = WeightedTreeIndex::<f64>::new(&[]).unwrap();
assert_eq!(
tree.try_sample(&mut rng).unwrap_err(),
WeightedError::NoItem
WeightError::InsufficientNonZero
);
}

#[test]
fn test_overflow_error() {
assert_eq!(
WeightedTreeIndex::new(&[i32::MAX, 2]),
Err(WeightedError::Overflow)
Err(WeightError::Overflow)
);
let mut tree = WeightedTreeIndex::new(&[i32::MAX - 2, 1]).unwrap();
assert_eq!(tree.push(3), Err(WeightedError::Overflow));
assert_eq!(tree.update(1, 4), Err(WeightedError::Overflow));
assert_eq!(tree.push(3), Err(WeightError::Overflow));
assert_eq!(tree.update(1, 4), Err(WeightError::Overflow));
tree.update(1, 2).unwrap();
}

Expand All @@ -318,22 +327,22 @@ mod test {
let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
assert_eq!(
tree.try_sample(&mut rng).unwrap_err(),
WeightedError::AllWeightsZero
WeightError::InsufficientNonZero
);
}

#[test]
fn test_invalid_weight_error() {
assert_eq!(
WeightedTreeIndex::<i32>::new(&[1, -1]).unwrap_err(),
WeightedError::InvalidWeight
WeightError::InvalidWeight
);
let mut tree = WeightedTreeIndex::<i32>::new(&[]).unwrap();
assert_eq!(tree.push(-1).unwrap_err(), WeightedError::InvalidWeight);
assert_eq!(tree.push(-1).unwrap_err(), WeightError::InvalidWeight);
tree.push(1).unwrap();
assert_eq!(
tree.update(0, -1).unwrap_err(),
WeightedError::InvalidWeight
WeightError::InvalidWeight
);
}

Expand Down
9 changes: 1 addition & 8 deletions src/distributions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,6 @@ pub mod hidden_export {
pub use super::float::IntoFloat; // used by rand_distr
}
pub mod uniform;
#[deprecated(
since = "0.8.0",
note = "use rand::distributions::{WeightedIndex, WeightedError} instead"
)]
#[cfg(feature = "alloc")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
pub mod weighted;

pub use self::bernoulli::{Bernoulli, BernoulliError};
pub use self::distribution::{Distribution, DistIter, DistMap};
Expand All @@ -126,7 +119,7 @@ pub use self::slice::Slice;
#[doc(inline)]
pub use self::uniform::Uniform;
#[cfg(feature = "alloc")]
pub use self::weighted_index::{Weight, WeightedError, WeightedIndex};
pub use self::weighted_index::{Weight, WeightError, WeightedIndex};

#[allow(unused)]
use crate::Rng;
Expand Down
14 changes: 7 additions & 7 deletions src/distributions/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use alloc::string::String;
/// [`Slice::new`] constructs a distribution referencing a slice and uniformly
/// samples references from the items in the slice. It may do extra work up
/// front to make sampling of multiple values faster; if only one sample from
/// the slice is required, [`SliceRandom::choose`] can be more efficient.
/// the slice is required, [`IndexedRandom::choose`] can be more efficient.
///
/// Steps are taken to avoid bias which might be present in naive
/// implementations; for example `slice[rng.gen() % slice.len()]` samples from
Expand All @@ -25,7 +25,7 @@ use alloc::string::String;
/// This distribution samples with replacement; each sample is independent.
/// Sampling without replacement requires state to be retained, and therefore
/// cannot be handled by a distribution; you should instead consider methods
/// on [`SliceRandom`], such as [`SliceRandom::choose_multiple`].
/// on [`IndexedRandom`], such as [`IndexedRandom::choose_multiple`].
///
/// # Example
///
Expand All @@ -48,21 +48,21 @@ use alloc::string::String;
/// assert!(vowel_string.chars().all(|c| vowels.contains(&c)));
/// ```
///
/// For a single sample, [`SliceRandom::choose`][crate::seq::SliceRandom::choose]
/// For a single sample, [`IndexedRandom::choose`][crate::seq::IndexedRandom::choose]
/// may be preferred:
///
/// ```
/// use rand::seq::SliceRandom;
/// use rand::seq::IndexedRandom;
///
/// let vowels = ['a', 'e', 'i', 'o', 'u'];
/// let mut rng = rand::thread_rng();
///
/// println!("{}", vowels.choose(&mut rng).unwrap())
/// ```
///
/// [`SliceRandom`]: crate::seq::SliceRandom
/// [`SliceRandom::choose`]: crate::seq::SliceRandom::choose
/// [`SliceRandom::choose_multiple`]: crate::seq::SliceRandom::choose_multiple
/// [`IndexedRandom`]: crate::seq::IndexedRandom
/// [`IndexedRandom::choose`]: crate::seq::IndexedRandom::choose
/// [`IndexedRandom::choose_multiple`]: crate::seq::IndexedRandom::choose_multiple
#[derive(Debug, Clone, Copy)]
pub struct Slice<'a, T> {
slice: &'a [T],
Expand Down
Loading

0 comments on commit dba696e

Please sign in to comment.