Skip to content

Commit

Permalink
refactor: replace max value with index to avoid cloning
Browse files Browse the repository at this point in the history
Signed-off-by: bsbds <[email protected]>
  • Loading branch information
bsbds committed Mar 12, 2024
1 parent 4f3bba2 commit 0cfa232
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 61 deletions.
118 changes: 58 additions & 60 deletions crates/utils/src/interval_map/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub struct IntervalMap<T, V, Ix = DefaultIx> {

impl<T, V, Ix> IntervalMap<T, V, Ix>
where
T: Clone + Ord,
T: Ord,
Ix: IndexType,
{
/// Creates a new `IntervalMap` with estimated capacity.
Expand All @@ -39,8 +39,8 @@ where
/// This method panics when the tree is at the maximum number of nodes for its index
#[inline]
pub fn insert(&mut self, interval: Interval<T>, value: V) -> Option<V> {
let node = Self::new_node(interval, value);
let node_idx = NodeIndex::new(self.nodes.len());
let node = Self::new_node(interval, value, node_idx);
// check for max capacity, except if we use usize
assert!(
<Ix as IndexType>::max().index() == !0 || NodeIndex::end() != node_idx,
Expand All @@ -56,9 +56,7 @@ where
pub fn remove(&mut self, interval: &Interval<T>) -> Option<V> {
if let Some(node_idx) = self.search_exact(interval) {
self.remove_inner(node_idx);
// To achieve an O(1) time complexity for node removal, we swap the node
// with the last node stored in the vector and update parent/left/right
// nodes of the last node.
// Swap the node with the last node stored in the vector and update indices
let mut node = self.nodes.swap_remove(node_idx.index());
let old = NodeIndex::<Ix>::new(self.nodes.len());
self.update_idx(old, node_idx);
Expand Down Expand Up @@ -150,7 +148,7 @@ where

impl<T, V> IntervalMap<T, V>
where
T: Clone + Ord,
T: Ord,
{
/// Creates an empty `IntervalMap`
#[must_use]
Expand All @@ -166,7 +164,7 @@ where

impl<T, V> Default for IntervalMap<T, V>
where
T: Clone + Ord,
T: Ord,
{
#[inline]
fn default() -> Self {
Expand All @@ -176,15 +174,15 @@ where

impl<T, V, Ix> IntervalMap<T, V, Ix>
where
T: Clone + Ord,
T: Ord,
Ix: IndexType,
{
/// Creates a new sentinel node
fn new_sentinel() -> Node<T, V, Ix> {
Node {
interval: None,
value: None,
max: None,
max_index: None,
left: None,
right: None,
parent: None,
Expand All @@ -193,9 +191,9 @@ where
}

/// Creates a new tree node
fn new_node(interval: Interval<T>, value: V) -> Node<T, V, Ix> {
fn new_node(interval: Interval<T>, value: V, index: NodeIndex<Ix>) -> Node<T, V, Ix> {
Node {
max: Some(interval.high.clone()),
max_index: Some(index),
interval: Some(interval),
value: Some(value),
left: Some(Self::sentinel()),
Expand All @@ -213,7 +211,7 @@ where

impl<T, V, Ix> IntervalMap<T, V, Ix>
where
T: Ord + Clone,
T: Ord,
Ix: IndexType,
{
/// Inserts a node into the tree.
Expand Down Expand Up @@ -300,16 +298,12 @@ where
if self.node_ref(x, Node::interval).overlap(interval) {
list.push(self.node_ref(x, |nx| (nx.interval(), nx.value())));
}
if self
.left_ref(x, Node::sentinel)
.map(Node::max)
.is_some_and(|lm| lm >= &interval.low)
{
if self.max(self.node_ref(x, Node::left)) >= Some(&interval.low) {
list.extend(self.find_all_overlap_inner(self.node_ref(x, Node::left), interval));
}
if self
.right_ref(x, Node::sentinel)
.map(|r| IntervalRef::new(&self.node_ref(x, Node::interval).low, r.max()))
.max(self.node_ref(x, Node::right))
.map(|rmax| IntervalRef::new(&self.node_ref(x, Node::interval).low, rmax))
.is_some_and(|i| i.overlap(interval))
{
list.extend(self.find_all_overlap_inner(self.node_ref(x, Node::right), interval));
Expand All @@ -325,11 +319,7 @@ where
.map(Node::interval)
.is_some_and(|xi| !xi.overlap(interval))
{
if self
.left_ref(x, Node::sentinel)
.map(Node::max)
.is_some_and(|lm| lm > &interval.low)
{
if self.max(self.node_ref(x, Node::left)) > Some(&interval.low) {
x = self.node_ref(x, Node::left);
} else {
x = self.node_ref(x, Node::right);
Expand All @@ -345,7 +335,7 @@ where
if self.node_ref(x, Node::interval) == interval {
return Some(x);
}
if self.node_ref(x, Node::max) < &interval.high {
if self.max(x) < Some(&interval.high) {
return None;
}
if self.node_ref(x, Node::interval) > interval {
Expand Down Expand Up @@ -512,36 +502,35 @@ where

/// Updates the max value after a rotation.
fn rotate_update_max(&mut self, x: NodeIndex<Ix>, y: NodeIndex<Ix>) {
self.node_mut(y, Node::set_max(self.node_ref(x, Node::max_owned)));
let mut max = &self.node_ref(x, Node::interval).high;
if let Some(lmax) = self.left_ref(x, Node::sentinel).map(Node::max) {
max = max.max(lmax);
}
if let Some(rmax) = self.right_ref(x, Node::sentinel).map(Node::max) {
max = max.max(rmax);
}
self.node_mut(x, Node::set_max(max.clone()));
self.node_mut(y, Node::set_max_index(self.node_ref(x, Node::max_index)));
self.recaculate_max(x);
}

/// Updates the max value towards the root
fn update_max_bottom_up(&mut self, x: NodeIndex<Ix>) {
let mut p = x;
while !self.node_ref(p, Node::is_sentinel) {
self.node_mut(
p,
Node::set_max(self.node_ref(p, Node::interval).high.clone()),
);
self.max_from(p, self.node_ref(p, Node::left));
self.max_from(p, self.node_ref(p, Node::right));
self.recaculate_max(p);
p = self.node_ref(p, Node::parent);
}
}

/// Updates a nodes value from a child node.
fn max_from(&mut self, x: NodeIndex<Ix>, c: NodeIndex<Ix>) {
if let Some(cmax) = self.node_ref(c, Node::sentinel).map(Node::max) {
let max = self.node_ref(x, Node::max).max(cmax).clone();
self.node_mut(x, Node::set_max(max));
/// Recaculate max value from left and right childrens
fn recaculate_max(&mut self, x: NodeIndex<Ix>) {
self.node_mut(x, Node::set_max_index(x));
let x_left = self.node_ref(x, Node::left);
let x_right = self.node_ref(x, Node::right);
if self.max(x_left) > self.max(x) {
self.node_mut(
x,
Node::set_max_index(self.node_ref(x_left, Node::max_index)),
);
}
if self.max(x_right) > self.max(x) {
self.node_mut(
x,
Node::set_max_index(self.node_ref(x_right, Node::max_index)),
);
}
}

Expand Down Expand Up @@ -578,7 +567,10 @@ where
self.parent_ref(node, Node::right) == node
}

/// Updates nodes index after remove
/// Updates nodes indices after remove
///
/// This method has a time complexity of `O(logn)`, as we need to
/// update the max index from bottom to top.
fn update_idx(&mut self, old: NodeIndex<Ix>, new: NodeIndex<Ix>) {
if self.root == old {
self.root = new;
Expand All @@ -593,6 +585,14 @@ where
}
self.left_mut(new, Node::set_parent(new));
self.right_mut(new, Node::set_parent(new));

let mut p = new;
while !self.node_ref(p, Node::is_sentinel) {
if self.node_ref(p, Node::max_index) == old {
self.node_mut(p, Node::set_max_index(new));
}
p = self.node_ref(p, Node::parent);
}
}
}
}
Expand Down Expand Up @@ -693,6 +693,11 @@ where
let grand_parent_idx = self.nodes[parent_idx].parent().index();
op(&mut self.nodes[grand_parent_idx])
}

fn max(&self, node: NodeIndex<Ix>) -> Option<&T> {
let max_index = self.nodes[node.index()].max_index?.index();
self.nodes[max_index].interval.as_ref().map(|i| &i.high)
}
}

/// An iterator over the entries of a `IntervalMap`.
Expand Down Expand Up @@ -781,7 +786,7 @@ pub struct VacantEntry<'a, T, V, Ix> {

impl<'a, T, V, Ix> Entry<'a, T, V, Ix>
where
T: Ord + Clone,
T: Ord,
Ix: IndexType,
{
/// Ensures a value is in the entry by inserting the default if empty, and returns
Expand Down Expand Up @@ -835,8 +840,8 @@ pub struct Node<T, V, Ix> {

/// Interval of the node
interval: Option<Interval<T>>,
/// Max value of the sub-tree of the node
max: Option<T>,
/// The index that point to the node with the max value
max_index: Option<NodeIndex<Ix>>,
/// Value of the node
value: Option<V>,
}
Expand All @@ -857,15 +862,8 @@ where
self.interval.as_ref().unwrap()
}

fn max(&self) -> &T {
self.max.as_ref().unwrap()
}

fn max_owned(&self) -> T
where
T: Clone,
{
self.max().clone()
fn max_index(&self) -> NodeIndex<Ix> {
self.max_index.unwrap()
}

fn left(&self) -> NodeIndex<Ix> {
Expand Down Expand Up @@ -918,9 +916,9 @@ where
}
}

fn set_max(max: T) -> impl FnOnce(&mut Node<T, V, Ix>) {
fn set_max_index(max_index: NodeIndex<Ix>) -> impl FnOnce(&mut Node<T, V, Ix>) {
move |node: &mut Node<T, V, Ix>| {
let _ignore = node.max.replace(max);
let _ignore = node.max_index.replace(max_index);
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/utils/src/interval_map/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ impl<V> IntervalMap<i32, V> {
let l_max = self.check_max_inner(self.node_ref(x, Node::left));
let r_max = self.check_max_inner(self.node_ref(x, Node::right));
let max = self.node_ref(x, |x| x.interval().high.max(l_max).max(r_max));
assert_eq!(self.node_ref(x, Node::max_owned), max);
assert_eq!(self.max(x), Some(&max));
max
}

Expand Down

0 comments on commit 0cfa232

Please sign in to comment.