Skip to content

Commit

Permalink
Checked adds and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
xmakro committed Jan 12, 2024
1 parent 30866d6 commit 3b6229e
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 134 deletions.
190 changes: 81 additions & 109 deletions rand_distr/src/weighted_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,18 @@
//! This module contains an implementation of a tree sttructure for sampling random
//! indices with probabilities proportional to a collection of weights.

use core::ops::{Sub, SubAssign};
use core::ops::SubAssign;

use super::WeightedError;
use crate::Distribution;
use alloc::{vec, vec::Vec};
use num_traits::{Zero, CheckedAdd};
use rand::{distributions::uniform::SampleUniform, Rng};
use alloc::vec::Vec;
use rand::{
distributions::{
uniform::{SampleBorrow, SampleUniform},
Weight,
},
Rng,
};
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};

Expand All @@ -32,17 +37,17 @@ use serde::{Deserialize, Serialize};
/// The main distinction between [`WeightedTreeIndex<W>`] and [`rand::distributions::WeightedIndex<W>`]
/// lies in the internal representation of weights. In [`WeightedTreeIndex<W>`],
/// weights are structured as a tree, which is optimized for frequent updates of the weights.
///
///
/// # Caution: Floating point types
///
///
/// When utilizing [`WeightedTreeIndex<W>`] with floating point types (such as f32 or f64),
/// exercise caution due to the inherent nature of floating point arithmetic. Floating point types
/// are susceptible to numerical rounding errors. Since operations on floating point weights are
/// repeated numerous times, rounding errors can accumulate, potentially leading to noticeable
/// deviations from the expected behavior.
///
///
/// Ideally, use fixed point or integer types whenever possible.
///
///
/// # Performance
///
/// A [`WeightedTreeIndex<W>`] with `n` elements requires `O(n)` memory.
Expand Down Expand Up @@ -86,35 +91,30 @@ use serde::{Deserialize, Serialize};
serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>"))
)]
#[derive(Clone, Default, Debug, PartialEq)]
pub struct WeightedTreeIndex<W: Weight> {
pub struct WeightedTreeIndex<W> {
subtotals: Vec<W>,
}

impl<W: Weight> WeightedTreeIndex<W> {
impl<W: Clone + PartialEq + PartialOrd + Weight> WeightedTreeIndex<W> {
/// Creates a new [`WeightedTreeIndex`] from a slice of weights.
pub fn new(weights: &[W]) -> Result<Self, WeightedError> {
for &weight in weights {
if weight < W::zero() {
pub fn new<I>(weights: I) -> Result<Self, WeightedError>
where
I: IntoIterator,
I::Item: SampleBorrow<W>,
{
let mut subtotals: Vec<W> = weights.into_iter().map(|x| x.borrow().clone()).collect();
for weight in subtotals.iter() {
if *weight < W::ZERO {
return Err(WeightedError::InvalidWeight);
}
}
let n = weights.len();
let mut subtotals = vec![W::zero(); n];
for i in (0..n).rev() {
let left_index = 2 * i + 1;
let left_subtotal = if left_index < n {
subtotals[left_index]
} else {
W::zero()
};
let right_index = 2 * i + 2;
let right_subtotal = if right_index < n {
subtotals[right_index]
} else {
W::zero()
};
let children_subtotal = left_subtotal.checked_add(&right_subtotal).ok_or(WeightedError::Overflow)?;
subtotals[i] = weights[i].checked_add(&children_subtotal).ok_or(WeightedError::Overflow)?;
let n = subtotals.len();
for i in (1..n).rev() {
let w = subtotals[i].clone();
let parent = (i - 1) / 2;
subtotals[parent]
.checked_add_assign(&w)
.map_err(|()| WeightedError::Overflow)?;
}
Ok(Self { subtotals })
}
Expand All @@ -133,92 +133,113 @@ impl<W: Weight> WeightedTreeIndex<W> {
///
/// This is the case if the total weight of the tree is greater than zero.
pub fn can_sample(&self) -> bool {
if let Some(&w) = self.subtotals.first() {
w > W::zero()
if let Some(weight) = self.subtotals.first() {
*weight > W::ZERO
} else {
false
}
}

/// Gets the weight at an index.
pub fn get(&self, index: usize) -> W {
pub fn get(&self, index: usize) -> W
where
W: for<'a> SubAssign<&'a W>,
{
let left_index = 2 * index + 1;
let right_index = 2 * index + 2;
self.subtotals[index] - self.subtotal(left_index) - self.subtotal(right_index)
let mut w = self.subtotals[index].clone();
w -= &self.subtotal(left_index);
w -= &self.subtotal(right_index);
w
}

/// Removes the last weight and returns it, or [`None`] if it is empty.
pub fn pop(&mut self) -> Option<W> {
pub fn pop(&mut self) -> Option<W>
where
W: for<'a> SubAssign<&'a W>,
{
self.subtotals.pop().map(|weight| {
let mut index = self.len();
while index != 0 {
index = (index - 1) / 2;
self.subtotals[index] -= weight;
self.subtotals[index] -= &weight;
}
weight
})
}

/// Appends a new weight at the end.
pub fn push(&mut self, weight: W) -> Result<(), WeightedError> {
if weight < W::zero() {
if weight < W::ZERO {
return Err(WeightedError::InvalidWeight);
}
if let Some(total) = self.subtotals.first() {
if total.checked_add(&weight).is_none() {
let mut total = total.clone();
if total.checked_add_assign(&weight).is_err() {
return Err(WeightedError::Overflow);
}
}
let mut index = self.len();
self.subtotals.push(weight);
self.subtotals.push(weight.clone());
while index != 0 {
index = (index - 1) / 2;
self.subtotals[index] = self.subtotals[index].checked_add(&weight).unwrap();
self.subtotals[index].checked_add_assign(&weight).unwrap();
}
Ok(())
}

/// Updates the weight at an index.
pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), WeightedError> {
if weight < W::zero() {
pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), WeightedError>
where
W: for<'a> SubAssign<&'a W>,
{
if weight < W::ZERO {
return Err(WeightedError::InvalidWeight);
}
let difference = weight - self.get(index);
if difference == W::zero() {
let mut difference = weight;
difference -= &self.get(index);
if difference == W::ZERO {
return Ok(());
}
if let Some(total) = self.subtotals.first() {
if total.checked_add(&difference).is_none() {
let mut total = total.clone();
if total.checked_add_assign(&difference).is_err() {
return Err(WeightedError::Overflow);
}
}
self.subtotals[index] = self.subtotals[index].checked_add(&difference).unwrap();
self.subtotals[index]
.checked_add_assign(&difference)
.unwrap();
while index != 0 {
index = (index - 1) / 2;
self.subtotals[index] = self.subtotals[index].checked_add(&difference).unwrap();
self.subtotals[index]
.checked_add_assign(&difference)
.unwrap();
}
Ok(())
}

fn subtotal(&self, index: usize) -> W {
if index < self.subtotals.len() {
self.subtotals[index]
self.subtotals[index].clone()
} else {
W::zero()
W::ZERO
}
}
}

impl<W: Weight> Distribution<Result<usize, WeightedError>> for WeightedTreeIndex<W> {
impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
Distribution<Result<usize, WeightedError>> for WeightedTreeIndex<W>
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<usize, WeightedError> {
if self.subtotals.is_empty() {
return Err(WeightedError::NoItem);
}
let total_weight = self.subtotals[0];
if total_weight == W::zero() {
let total_weight = self.subtotals[0].clone();
if total_weight == W::ZERO {
return Err(WeightedError::AllWeightsZero);
}
let mut target_weight = rng.gen_range(W::zero()..total_weight);
let mut target_weight = rng.gen_range(W::ZERO..total_weight);
let mut index = 0;
loop {
// Maybe descend into the left sub tree.
Expand All @@ -242,64 +263,12 @@ impl<W: Weight> Distribution<Result<usize, WeightedError>> for WeightedTreeIndex
// Otherwise we found the index with the target weight.
break;
}
assert!(target_weight >= W::zero());
assert!(target_weight >= W::ZERO);
assert!(target_weight < self.subtotal(index));
Ok(index)
}
}

/// Trait that must be implemented for weights, that are used with
/// [`WeightedTreeIndex`]. Currently no guarantees on the correctness of
/// [`WeightedTreeIndex`] are given for custom implementations of this trait.
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
pub trait Weight:
Sized
+ Copy
+ SampleUniform
+ PartialOrd
+ Sub<Output = Self>
+ SubAssign
+ Zero
{
/// Adds two numbers, checking for overflow. If overflow happens, None is returned.
fn checked_add(&self, b: &Self) -> Option<Self>;
}

macro_rules! impl_weight_for_float {
($T: ident) => {
impl Weight for $T {
fn checked_add(&self, b: &Self) -> Option<Self> {
Some(self + b)
}
}
};
}

macro_rules! impl_weight_for_int {
($T: ident) => {
impl Weight for $T {
fn checked_add(&self, b: &Self) -> Option<Self> {
CheckedAdd::checked_add(self, b)
}
}
};
}

impl_weight_for_float!(f64);
impl_weight_for_float!(f32);
impl_weight_for_int!(usize);
impl_weight_for_int!(u128);
impl_weight_for_int!(u64);
impl_weight_for_int!(u32);
impl_weight_for_int!(u16);
impl_weight_for_int!(u8);
impl_weight_for_int!(isize);
impl_weight_for_int!(i128);
impl_weight_for_int!(i64);
impl_weight_for_int!(i32);
impl_weight_for_int!(i16);
impl_weight_for_int!(i8);

#[cfg(test)]
mod test {
use super::*;
Expand All @@ -313,7 +282,10 @@ mod test {

#[test]
fn test_overflow_error() {
assert_eq!(WeightedTreeIndex::new(&[i32::MAX, 2]), Err(WeightedError::Overflow));
assert_eq!(
WeightedTreeIndex::new(&[i32::MAX, 2]),
Err(WeightedError::Overflow)
);
let mut tree = WeightedTreeIndex::new(&[i32::MAX - 2, 1]).unwrap();
assert_eq!(tree.push(3), Err(WeightedError::Overflow));
assert_eq!(tree.update(1, 4), Err(WeightedError::Overflow));
Expand Down Expand Up @@ -365,7 +337,7 @@ mod test {
let weights: Vec<_> = (0..end).map(|_| rng.gen()).collect();
let mut tree = WeightedTreeIndex::new(&weights).unwrap();
let mut total_weight = 0.0;
let mut weights = vec![0.0; end];
let mut weights = alloc::vec![0.0; end];
for i in 0..end {
tree.update(i, i as f64).unwrap();
weights[i] = i as f64;
Expand All @@ -376,7 +348,7 @@ mod test {
weights[i] = 0.0;
total_weight -= i as f64;
}
let mut counts = vec![0_usize; end];
let mut counts = alloc::vec![0_usize; end];
for _ in 0..samples {
let i = tree.sample(&mut rng).unwrap();
counts[i] += 1;
Expand Down
Loading

0 comments on commit 3b6229e

Please sign in to comment.