Skip to content

Commit

Permalink
a
Browse files Browse the repository at this point in the history
  • Loading branch information
xmakro committed Feb 8, 2024
1 parent c8e5e35 commit 689ac48
Showing 1 changed file with 13 additions and 18 deletions.
31 changes: 13 additions & 18 deletions rand_distr/src/weighted_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,15 @@ use serde::{Deserialize, Serialize};
serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>"))
)]
#[derive(Clone, Default, Debug, PartialEq)]
pub struct WeightedTreeIndex<W: SampleUniform> {
pub struct WeightedTreeIndex<
W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight,
> {
subtotals: Vec<W>,
}

impl<W: Clone + PartialEq + PartialOrd + SampleUniform + Weight> WeightedTreeIndex<W> {
impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
WeightedTreeIndex<W>
{
/// Creates a new [`WeightedTreeIndex`] from a slice of weights.
pub fn new<I>(weights: I) -> Result<Self, WeightedError>
where
Expand Down Expand Up @@ -138,28 +142,22 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + Weight> WeightedTreeInd
}

/// Gets the weight at an index.
pub fn get(&self, index: usize) -> W
where
W: for<'a> SubAssign<&'a W>,
{
pub fn get(&self, index: usize) -> W {
let left_index = 2 * index + 1;
let right_index = 2 * index + 2;
let mut w = self.subtotals[index].clone();
w -= &self.subtotal(left_index);
w -= &self.subtotal(right_index);
w -= self.subtotal(left_index);
w -= self.subtotal(right_index);
w
}

/// Removes the last weight and returns it, or [`None`] if it is empty.
pub fn pop(&mut self) -> Option<W>
where
W: for<'a> SubAssign<&'a W>,
{
pub fn pop(&mut self) -> Option<W> {
self.subtotals.pop().map(|weight| {
let mut index = self.len();
while index != 0 {
index = (index - 1) / 2;
self.subtotals[index] -= &weight;
self.subtotals[index] -= weight.clone();
}
weight
})
Expand All @@ -186,15 +184,12 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + Weight> WeightedTreeInd
}

/// Updates the weight at an index.
pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), WeightedError>
where
W: for<'a> SubAssign<&'a W>,
{
pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), WeightedError> {
if weight < W::ZERO {
return Err(WeightedError::InvalidWeight);
}
let mut difference = weight;
difference -= &self.get(index);
difference -= self.get(index);
if difference == W::ZERO {
return Ok(());
}
Expand Down

0 comments on commit 689ac48

Please sign in to comment.