diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs index 7172e8d432..1e28aaaa79 100644 --- a/rand_distr/src/lib.rs +++ b/rand_distr/src/lib.rs @@ -76,8 +76,9 @@ //! - [`UnitBall`] distribution //! - [`UnitCircle`] distribution //! - [`UnitDisc`] distribution -//! - Alternative implementation for weighted index sampling +//! - Alternative implementations for weighted index sampling //! - [`WeightedAliasIndex`] distribution +//! - [`WeightedTreeIndex`] distribution //! - Misc. distributions //! - [`InverseGaussian`] distribution //! - [`NormalInverseGaussian`] distribution diff --git a/rand_distr/src/weighted_tree.rs b/rand_distr/src/weighted_tree.rs index 045c9e1f2b..2f1349126e 100644 --- a/rand_distr/src/weighted_tree.rs +++ b/rand_distr/src/weighted_tree.rs @@ -9,12 +9,12 @@ //! This module contains an implementation of a tree sttructure for sampling random //! indices with probabilities proportional to a collection of weights. -use core::ops::{Add, AddAssign, Sub, SubAssign}; +use core::ops::{Sub, SubAssign}; use super::WeightedError; use crate::Distribution; use alloc::{vec, vec::Vec}; -use num_traits::Zero; +use num_traits::{Zero, CheckedAdd}; use rand::{distributions::uniform::SampleUniform, Rng}; #[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; @@ -113,7 +113,8 @@ impl WeightedTreeIndex { } else { W::zero() }; - subtotals[i] = weights[i] + left_subtotal + right_subtotal; + let children_subtotal = left_subtotal.checked_add(&right_subtotal).ok_or(WeightedError::Overflow)?; + subtotals[i] = weights[i].checked_add(&children_subtotal).ok_or(WeightedError::Overflow)?; } Ok(Self { subtotals }) } @@ -163,11 +164,16 @@ impl WeightedTreeIndex { if weight < W::zero() { return Err(WeightedError::InvalidWeight); } + if let Some(total) = self.subtotals.first() { + if total.checked_add(&weight).is_none() { + return Err(WeightedError::Overflow); + } + } let mut index = self.len(); self.subtotals.push(weight); while index != 0 { index = (index - 1) / 2; - self.subtotals[index] += weight; + self.subtotals[index] = self.subtotals[index].checked_add(&weight).unwrap(); } Ok(()) } @@ -181,10 +187,15 @@ impl WeightedTreeIndex { if difference == W::zero() { return Ok(()); } - self.subtotals[index] += difference; + if let Some(total) = self.subtotals.first() { + if total.checked_add(&difference).is_none() { + return Err(WeightedError::Overflow); + } + } + self.subtotals[index] = self.subtotals[index].checked_add(&difference).unwrap(); while index != 0 { index = (index - 1) / 2; - self.subtotals[index] += difference; + self.subtotals[index] = self.subtotals[index].checked_add(&difference).unwrap(); } Ok(()) } @@ -246,27 +257,49 @@ pub trait Weight: + Copy + SampleUniform + PartialOrd - + Add - + AddAssign + Sub + SubAssign + Zero { + /// Adds two numbers, checking for overflow. If overflow happens, None is returned. + fn checked_add(&self, b: &Self) -> Option; } -impl Weight for T where - T: Sized - + Copy - + SampleUniform - + PartialOrd - + Add - + AddAssign - + Sub - + SubAssign - + Zero -{ +macro_rules! impl_weight_for_float { + ($T: ident) => { + impl Weight for $T { + fn checked_add(&self, b: &Self) -> Option { + Some(self + b) + } + } + }; } +macro_rules! impl_weight_for_int { + ($T: ident) => { + impl Weight for $T { + fn checked_add(&self, b: &Self) -> Option { + CheckedAdd::checked_add(self, b) + } + } + }; +} + +impl_weight_for_float!(f64); +impl_weight_for_float!(f32); +impl_weight_for_int!(usize); +impl_weight_for_int!(u128); +impl_weight_for_int!(u64); +impl_weight_for_int!(u32); +impl_weight_for_int!(u16); +impl_weight_for_int!(u8); +impl_weight_for_int!(isize); +impl_weight_for_int!(i128); +impl_weight_for_int!(i64); +impl_weight_for_int!(i32); +impl_weight_for_int!(i16); +impl_weight_for_int!(i8); + #[cfg(test)] mod test { use super::*; @@ -278,6 +311,15 @@ mod test { assert_eq!(tree.sample(&mut rng).unwrap_err(), WeightedError::NoItem); } + #[test] + fn test_overflow_error() { + assert_eq!(WeightedTreeIndex::new(&[i32::MAX, 2]), Err(WeightedError::Overflow)); + let mut tree = WeightedTreeIndex::new(&[i32::MAX - 2, 1]).unwrap(); + assert_eq!(tree.push(3), Err(WeightedError::Overflow)); + assert_eq!(tree.update(1, 4), Err(WeightedError::Overflow)); + tree.update(1, 2).unwrap(); + } + #[test] fn test_all_weights_zero_error() { let tree = WeightedTreeIndex::::new(&[0.0, 0.0]).unwrap(); diff --git a/src/distributions/weighted_index.rs b/src/distributions/weighted_index.rs index de3628b5ea..5223af594d 100644 --- a/src/distributions/weighted_index.rs +++ b/src/distributions/weighted_index.rs @@ -35,7 +35,9 @@ use serde::{Serialize, Deserialize}; /// Time complexity of sampling from `WeightedIndex` is `O(log N)` where /// `N` is the number of weights. As an alternative, /// [`rand_distr::weighted_alias`](https://docs.rs/rand_distr/*/rand_distr/weighted_alias/index.html) -/// supports `O(1)` sampling, but with much higher initialisation cost. +/// supports `O(1)` sampling, but with much higher initialisation cost, +/// and [`rand_distr::weighted_tree`](https://docs.rs/rand_distr/*/rand_distr/weighted_tree/index.html) +/// supports `O(log n)` updates with O /// /// A `WeightedIndex` contains a `Vec` and a [`Uniform`] and so its /// size is the sum of the size of those objects, possibly plus some alignment.