From 2caee42837c80301642bbe8ba0239d913af34dc3 Mon Sep 17 00:00:00 2001 From: feathercyc Date: Fri, 26 Jul 2024 14:31:04 +0800 Subject: [PATCH] chore: change some fn's and variable's name Signed-off-by: feathercyc --- Cargo.lock | 2 + Cargo.toml | 8 +- benches/bench.rs | 2 +- src/entry.rs | 6 +- src/index.rs | 84 ++++--- src/interval.rs | 83 ++++++- src/intervalmap.rs | 550 +++++++++------------------------------------ src/lib.rs | 2 +- src/node.rs | 254 +++++++++++++++++---- 9 files changed, 472 insertions(+), 519 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 78c1808..493b595 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -205,6 +205,8 @@ version = "0.1.0" dependencies = [ "criterion", "rand", + "serde", + "serde_json", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 8a89c0e..78b398b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,14 +8,20 @@ license = "Apache-2.0" keywords = ["Interval Tree", "Augmented Tree", "Red-Black Tree"] [dependencies] +serde = { version = "1.0", default-features = false, features = [ + "derive", + "std", +], optional = true } [dev-dependencies] criterion = "0.5.1" rand = "0.8.5" +serde_json = "1.0" [features] default = [] -interval_tree_find_overlap_ordered = [] +graphviz = [] +serde = ["dep:serde"] [[bench]] name = "bench" diff --git a/benches/bench.rs b/benches/bench.rs index f7f7435..904e254 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -65,7 +65,7 @@ fn interval_map_insert_remove(count: usize, bench: &mut Bencher) { black_box(map.insert(i, ())); } for i in &intervals { - black_box(map.remove(&i)); + black_box(map.remove(i)); } }); } diff --git a/src/entry.rs b/src/entry.rs index 731a55a..09cf4de 100644 --- a/src/entry.rs +++ b/src/entry.rs @@ -19,7 +19,7 @@ pub struct OccupiedEntry<'a, T, V, Ix> { /// Reference to the map pub map_ref: &'a mut IntervalMap, /// The entry node - pub node: NodeIndex, + pub node_idx: NodeIndex, } /// A view into a vacant entry in a `IntervalMap`. @@ -53,7 +53,7 @@ where #[inline] pub fn or_insert(self, default: V) -> &'a mut V { match self { - Entry::Occupied(entry) => entry.map_ref.node_mut(entry.node, Node::value_mut), + Entry::Occupied(entry) => entry.map_ref.node_mut(entry.node_idx, Node::value_mut), Entry::Vacant(entry) => { let entry_idx = NodeIndex::new(entry.map_ref.nodes.len()); let _ignore = entry.map_ref.insert(entry.interval, default); @@ -88,7 +88,7 @@ where { match self { Entry::Occupied(entry) => { - f(entry.map_ref.node_mut(entry.node, Node::value_mut)); + f(entry.map_ref.node_mut(entry.node_idx, Node::value_mut)); Self::Occupied(entry) } Entry::Vacant(entry) => Self::Vacant(entry), diff --git a/src/index.rs b/src/index.rs index 657f955..15f6b79 100644 --- a/src/index.rs +++ b/src/index.rs @@ -1,30 +1,60 @@ use std::fmt; use std::hash::Hash; -pub type DefaultIx = u32; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// Trait for the unsigned integer type used for node indices. +pub trait IndexType: Copy + Default + Hash + Ord + fmt::Debug + 'static { + const SENTINEL: Self; -pub unsafe trait IndexType: Copy + Default + Hash + Ord + fmt::Debug + 'static { + /// Convert x from usize to the corresponding type + /// # Notice + /// Using u8 and u16 types may cause numerical overflow. Please check the numerical range before using fn new(x: usize) -> Self; + + /// Convert self to usize fn index(&self) -> usize; + + /// Return Self::MAX fn max() -> Self; -} -unsafe impl IndexType for u32 { - #[inline(always)] - fn new(x: usize) -> Self { - x as u32 - } - #[inline(always)] - fn index(&self) -> usize { - *self as usize - } - #[inline(always)] - fn max() -> Self { - ::std::u32::MAX + /// Check if self is Self::SENTINEL + fn is_sentinel(&self) -> bool { + *self == Self::SENTINEL } } +macro_rules! impl_index { + ($type:ident) => { + impl IndexType for $type { + const SENTINEL: Self = 0; + + #[inline(always)] + fn new(x: usize) -> Self { + x as $type + } + #[inline(always)] + fn index(&self) -> usize { + *self as usize + } + #[inline(always)] + fn max() -> Self { + Self::MAX + } + } + }; +} + +impl_index!(u8); +impl_index!(u16); +impl_index!(u32); +impl_index!(u64); + +pub type DefaultIx = u32; + /// Node identifier. +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Copy, Clone, Default, PartialEq, PartialOrd, Eq, Ord, Hash)] pub struct NodeIndex(Ix); @@ -34,26 +64,30 @@ impl NodeIndex { NodeIndex(IndexType::new(x)) } - #[inline] - pub fn index(self) -> usize { - self.0.index() - } - - #[inline] - pub fn end() -> Self { - NodeIndex(IndexType::max()) + pub fn inc(&self) -> Self { + if self.index() == ::max().index() { + panic!("Index will overflow!") + } + NodeIndex::new(self.index() + 1) } } -unsafe impl IndexType for NodeIndex { +impl IndexType for NodeIndex { + const SENTINEL: Self = NodeIndex(Ix::SENTINEL); + + #[inline] fn index(&self) -> usize { self.0.index() } + + #[inline] fn new(x: usize) -> Self { NodeIndex::new(x) } + + #[inline] fn max() -> Self { - NodeIndex(::max()) + NodeIndex(IndexType::max()) } } diff --git a/src/interval.rs b/src/interval.rs index e7d78bc..0fe12a7 100644 --- a/src/interval.rs +++ b/src/interval.rs @@ -8,9 +8,14 @@ //! //! Currently, `interval_map` only supports half-open intervals, i.e., [...,...). +use std::fmt::{Display, Formatter}; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + /// The interval stored in `IntervalMap` represents [low, high) -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[non_exhaustive] +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Interval { /// Low value pub low: T, @@ -32,9 +37,27 @@ impl Interval { /// Checks if self overlaps with other interval #[inline] - pub fn overlap(&self, other: &Self) -> bool { + pub fn overlaps(&self, other: &Self) -> bool { self.high > other.low && other.high > self.low } + + /// Checks if self contains other interval + /// e.g. [1,10) contains [1,8) + #[inline] + pub fn contains(&self, other: &Self) -> bool { + self.low <= other.low && self.high > other.high + } + + /// Checks if self contains a point + pub fn contains_point(&self, p: T) -> bool { + self.low <= p && self.high > p + } +} + +impl Display for Interval { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "[{},{})", self.low, self.high) + } } /// Reference type of `Interval` @@ -64,6 +87,24 @@ impl<'a, T: Ord> IntervalRef<'a, T> { } } +#[cfg(feature = "serde")] +impl Serialize for Interval { + fn serialize(&self, serializer: S) -> Result { + (&self.low, &self.high).serialize(serializer) + } +} + +#[cfg(feature = "serde")] +impl<'de, T> Deserialize<'de> for Interval +where + T: Deserialize<'de>, +{ + fn deserialize>(deserializer: D) -> Result { + let (low, high) = <(T, T)>::deserialize(deserializer)?; + Ok(Interval { low, high }) + } +} + #[cfg(test)] mod test { use super::*; @@ -73,4 +114,42 @@ mod test { fn invalid_range_should_panic() { let _interval = Interval::new(3, 1); } + + #[test] + fn test_interval_clone() { + let interval1 = Interval::new(1, 10); + let interval2 = interval1.clone(); + assert_eq!(interval1, interval2); + } + + #[test] + fn test_interval_compare() { + let interval1 = Interval::new(1, 10); + let interval2 = Interval::new(5, 15); + assert!(interval1 < interval2); + assert!(interval2 > interval1); + assert_eq!(interval1, Interval::new(1, 10)); + assert_ne!(interval1, interval2); + } + + #[test] + fn test_interval_hash() { + let interval1 = Interval::new(1, 10); + let interval2 = Interval::new(1, 10); + let interval3 = Interval::new(5, 15); + let mut hashset = std::collections::HashSet::new(); + hashset.insert(interval1); + hashset.insert(interval2); + hashset.insert(interval3); + assert_eq!(hashset.len(), 2); + } + + #[cfg(feature = "serde")] + #[test] + fn test_interval_serialize_deserialize() { + let interval = Interval::new(1, 10); + let serialized = serde_json::to_string(&interval).unwrap(); + let deserialized: Interval = serde_json::from_str(&serialized).unwrap(); + assert_eq!(interval, deserialized); + } } diff --git a/src/intervalmap.rs b/src/intervalmap.rs index 51dcb24..1d7fd20 100644 --- a/src/intervalmap.rs +++ b/src/intervalmap.rs @@ -2,9 +2,22 @@ use crate::entry::{Entry, OccupiedEntry, VacantEntry}; use crate::index::{DefaultIx, IndexType, NodeIndex}; use crate::interval::{Interval, IntervalRef}; use crate::node::{Color, Node}; + use std::collections::VecDeque; +use std::fmt::Debug; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +#[cfg(feature = "graphviz")] +use std::fmt::Display; +#[cfg(feature = "graphviz")] +use std::fs::OpenOptions; +#[cfg(feature = "graphviz")] +use std::io::Write; /// An interval-value map, which support operations on dynamic sets of intervals. +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug)] pub struct IntervalMap { /// Vector that stores nodes @@ -24,11 +37,11 @@ where #[inline] #[must_use] pub fn with_capacity(capacity: usize) -> Self { - let mut nodes = vec![Self::new_sentinel()]; + let mut nodes = vec![Node::new_sentinel()]; nodes.reserve(capacity); IntervalMap { nodes, - root: Self::sentinel(), + root: NodeIndex::SENTINEL, len: 0, } } @@ -52,10 +65,11 @@ where #[inline] pub fn insert(&mut self, interval: Interval, value: V) -> Option { let node_idx = NodeIndex::new(self.nodes.len()); - let node = Self::new_node(interval, value, node_idx); + let node = Node::new(interval, value, node_idx); // check for max capacity, except if we use usize assert!( - ::max().index() == !0 || NodeIndex::end() != node_idx, + ::max().index() == !0 + || as IndexType>::max() != node_idx, "Reached maximum number of nodes" ); self.nodes.push(node); @@ -101,19 +115,22 @@ where /// map.insert(Interval::new(1, 3), ()); /// map.insert(Interval::new(6, 7), ()); /// map.insert(Interval::new(9, 11), ()); - /// assert!(map.overlap(&Interval::new(2, 5))); - /// assert!(map.overlap(&Interval::new(1, 17))); - /// assert!(!map.overlap(&Interval::new(3, 6))); - /// assert!(!map.overlap(&Interval::new(11, 23))); + /// assert!(map.overlaps(&Interval::new(2, 5))); + /// assert!(map.overlaps(&Interval::new(1, 17))); + /// assert!(!map.overlaps(&Interval::new(3, 6))); + /// assert!(!map.overlaps(&Interval::new(11, 23))); /// ``` #[inline] - pub fn overlap(&self, interval: &Interval) -> bool { + pub fn overlaps(&self, interval: &Interval) -> bool { let node_idx = self.search(interval); !self.node_ref(node_idx, Node::is_sentinel) } /// Find all intervals in the map that overlaps with the given interval. /// + /// # Note + /// This method's returned data is unordered. To get ordered data, please use `find_all_overlap_ordered`. + /// /// # Example /// ```rust /// use interval_map::{Interval, IntervalMap}; @@ -132,7 +149,7 @@ where if self.node_ref(self.root, Node::is_sentinel) { Vec::new() } else { - self.find_all_overlap_inner_unordered(self.root, interval) + self.find_all_overlap_inner(self.root, interval) } } @@ -172,16 +189,6 @@ where .map(|idx| self.node_mut(idx, Node::value_mut)) } - /// Get an iterator over the entries of the map, sorted by key. - #[inline] - #[must_use] - pub fn iter(&self) -> Iter<'_, T, V, Ix> { - Iter { - map_ref: self, - stack: None, - } - } - /// Get the given key's corresponding entry in the map for in-place manipulation. /// /// # Example @@ -199,9 +206,9 @@ where #[inline] pub fn entry(&mut self, interval: Interval) -> Entry<'_, T, V, Ix> { match self.search_exact(&interval) { - Some(node) => Entry::Occupied(OccupiedEntry { + Some(node_idx) => Entry::Occupied(OccupiedEntry { map_ref: self, - node, + node_idx, }), None => Entry::Vacant(VacantEntry { map_ref: self, @@ -214,8 +221,8 @@ where #[inline] pub fn clear(&mut self) { self.nodes.clear(); - self.nodes.push(Self::new_sentinel()); - self.root = Self::sentinel(); + self.nodes.push(Node::new_sentinel()); + self.root = NodeIndex::SENTINEL; self.len = 0; } @@ -242,11 +249,7 @@ where #[inline] #[must_use] pub fn new() -> Self { - Self { - nodes: vec![Self::new_sentinel()], - root: Self::sentinel(), - len: 0, - } + Self::with_capacity(0) } } @@ -265,46 +268,9 @@ where T: Ord, Ix: IndexType, { - /// Create a new sentinel node - fn new_sentinel() -> Node { - Node { - interval: None, - value: None, - max_index: None, - left: None, - right: None, - parent: None, - color: Color::Black, - } - } - - /// Create a new tree node - fn new_node(interval: Interval, value: V, index: NodeIndex) -> Node { - Node { - max_index: Some(index), - interval: Some(interval), - value: Some(value), - left: Some(Self::sentinel()), - right: Some(Self::sentinel()), - parent: Some(Self::sentinel()), - color: Color::Red, - } - } - - /// Get the sentinel node index - fn sentinel() -> NodeIndex { - NodeIndex::new(0) - } -} - -impl IntervalMap -where - T: Ord, - Ix: IndexType, -{ - /// Insert a node into the tree. + /// insert a node into the tree. fn insert_inner(&mut self, z: NodeIndex) -> Option { - let mut y = Self::sentinel(); + let mut y = NodeIndex::SENTINEL; let mut x = self.root; while !self.node_ref(x, Node::is_sentinel) { @@ -380,34 +346,10 @@ where self.len = self.len.wrapping_sub(1); } - /// Find all intervals in the map that overlaps with the given interval. - #[cfg(interval_tree_find_overlap_ordered)] - fn find_all_overlap_inner( - &self, - x: NodeIndex, - interval: &Interval, - ) -> Vec<(&Interval, &V)> { - let mut list = vec![]; - if self.node_ref(x, Node::interval).overlap(interval) { - list.push(self.node_ref(x, |nx| (nx.interval(), nx.value()))); - } - if self.max(self.node_ref(x, Node::left)) >= Some(&interval.low) { - list.extend(self.find_all_overlap_inner(self.node_ref(x, Node::left), interval)); - } - if self - .max(self.node_ref(x, Node::right)) - .map(|rmax| IntervalRef::new(&self.node_ref(x, Node::interval).low, rmax)) - .is_some_and(|i| i.overlap(interval)) - { - list.extend(self.find_all_overlap_inner(self.node_ref(x, Node::right), interval)); - } - list - } - /// Find all intervals in the map that overlaps with the given interval. /// /// The result is unordered because of breadth-first search to save stack size - fn find_all_overlap_inner_unordered( + fn find_all_overlap_inner( &self, x: NodeIndex, interval: &Interval, @@ -416,7 +358,7 @@ where let mut queue = VecDeque::new(); queue.push_back(x); while let Some(p) = queue.pop_front() { - if self.node_ref(p, Node::interval).overlap(interval) { + if self.node_ref(p, Node::interval).overlaps(interval) { list.push(self.node_ref(p, |np| (np.interval(), np.value()))); } let p_left = self.node_ref(p, Node::left); @@ -442,7 +384,7 @@ where while self .node_ref(x, Node::sentinel) .map(Node::interval) - .is_some_and(|xi| !xi.overlap(interval)) + .is_some_and(|xi| !xi.overlaps(interval)) { if self.max(self.node_ref(x, Node::left)) > Some(&interval.low) { x = self.node_ref(x, Node::left); @@ -680,13 +622,13 @@ where } /// Check if a node is a left child of its parent. - fn is_left_child(&self, node: NodeIndex) -> bool { - self.parent_ref(node, Node::left) == node + fn is_left_child(&self, node_idx: NodeIndex) -> bool { + self.parent_ref(node_idx, Node::left) == node_idx } /// Check if a node is a right child of its parent. - fn is_right_child(&self, node: NodeIndex) -> bool { - self.parent_ref(node, Node::right) == node + fn is_right_child(&self, node_idx: NodeIndex) -> bool { + self.parent_ref(node_idx, Node::right) == node_idx } /// Update nodes indices after remove @@ -719,423 +661,149 @@ where } } +#[cfg(feature = "graphviz")] +impl IntervalMap +where + T: Ord + Copy + Display, + V: Display, + Ix: IndexType, +{ + /// writes dot file to `filename`. `T` and `V` should implement `Display`. + pub fn draw(&self, filename: &str) -> std::io::Result<()> { + let mut file = OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(filename)?; + writeln!(file, "digraph {{")?; + // begin at 1, because 0 is sentinel node + for i in 1..self.nodes.len() { + self.nodes[i].draw(i, &mut file)?; + } + writeln!(file, "}}") + } +} + +#[cfg(feature = "graphviz")] +impl IntervalMap +where + T: Ord + Copy + Display, + Ix: IndexType, +{ + /// Writes dot file to `filename` without values. `T` should implement `Display`. + pub fn draw_without_value(&self, filename: &str) -> std::io::Result<()> { + let mut file = OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(filename)?; + writeln!(file, "digraph {{")?; + // begin at 1, because 0 is sentinel node + for i in 1..self.nodes.len() { + self.nodes[i].draw_without_value(i, &mut file)?; + } + writeln!(file, "}}") + } +} + // Convenient methods for reference or mutate current/parent/left/right node impl<'a, T, V, Ix> IntervalMap where + T: Ord, Ix: IndexType, { - fn node_ref(&'a self, node: NodeIndex, op: F) -> R + pub(crate) fn node_ref(&'a self, node_idx: NodeIndex, op: F) -> R where R: 'a, F: FnOnce(&'a Node) -> R, { - op(&self.nodes[node.index()]) + op(&self.nodes[node_idx.index()]) } - pub(crate) fn node_mut(&'a mut self, node: NodeIndex, op: F) -> R + pub(crate) fn node_mut(&'a mut self, node_idx: NodeIndex, op: F) -> R where R: 'a, F: FnOnce(&'a mut Node) -> R, { - op(&mut self.nodes[node.index()]) + op(&mut self.nodes[node_idx.index()]) } - fn left_ref(&'a self, node: NodeIndex, op: F) -> R + pub(crate) fn left_ref(&'a self, node_idx: NodeIndex, op: F) -> R where R: 'a, F: FnOnce(&'a Node) -> R, { - let idx = self.nodes[node.index()].left().index(); + let idx = self.nodes[node_idx.index()].left().index(); op(&self.nodes[idx]) } - fn right_ref(&'a self, node: NodeIndex, op: F) -> R + pub(crate) fn right_ref(&'a self, node_idx: NodeIndex, op: F) -> R where R: 'a, F: FnOnce(&'a Node) -> R, { - let idx = self.nodes[node.index()].right().index(); + let idx = self.nodes[node_idx.index()].right().index(); op(&self.nodes[idx]) } - fn parent_ref(&'a self, node: NodeIndex, op: F) -> R + fn parent_ref(&'a self, node_idx: NodeIndex, op: F) -> R where R: 'a, F: FnOnce(&'a Node) -> R, { - let idx = self.nodes[node.index()].parent().index(); + let idx = self.nodes[node_idx.index()].parent().index(); op(&self.nodes[idx]) } - fn grand_parent_ref(&'a self, node: NodeIndex, op: F) -> R + fn grand_parent_ref(&'a self, node_idx: NodeIndex, op: F) -> R where R: 'a, F: FnOnce(&'a Node) -> R, { - let parent_idx = self.nodes[node.index()].parent().index(); + let parent_idx = self.nodes[node_idx.index()].parent().index(); let grand_parent_idx = self.nodes[parent_idx].parent().index(); op(&self.nodes[grand_parent_idx]) } - fn left_mut(&'a mut self, node: NodeIndex, op: F) -> R + fn left_mut(&'a mut self, node_idx: NodeIndex, op: F) -> R where R: 'a, F: FnOnce(&'a mut Node) -> R, { - let idx = self.nodes[node.index()].left().index(); + let idx = self.nodes[node_idx.index()].left().index(); op(&mut self.nodes[idx]) } - fn right_mut(&'a mut self, node: NodeIndex, op: F) -> R + fn right_mut(&'a mut self, node_idx: NodeIndex, op: F) -> R where R: 'a, F: FnOnce(&'a mut Node) -> R, { - let idx = self.nodes[node.index()].right().index(); + let idx = self.nodes[node_idx.index()].right().index(); op(&mut self.nodes[idx]) } - fn parent_mut(&'a mut self, node: NodeIndex, op: F) -> R + fn parent_mut(&'a mut self, node_idx: NodeIndex, op: F) -> R where R: 'a, F: FnOnce(&'a mut Node) -> R, { - let idx = self.nodes[node.index()].parent().index(); + let idx = self.nodes[node_idx.index()].parent().index(); op(&mut self.nodes[idx]) } - fn grand_parent_mut(&'a mut self, node: NodeIndex, op: F) -> R + fn grand_parent_mut(&'a mut self, node_idx: NodeIndex, op: F) -> R where R: 'a, F: FnOnce(&'a mut Node) -> R, { - let parent_idx = self.nodes[node.index()].parent().index(); + let parent_idx = self.nodes[node_idx.index()].parent().index(); let grand_parent_idx = self.nodes[parent_idx].parent().index(); op(&mut self.nodes[grand_parent_idx]) } - fn max(&self, node: NodeIndex) -> Option<&T> { - let max_index = self.nodes[node.index()].max_index?.index(); + pub(crate) fn max(&self, node_idx: NodeIndex) -> Option<&T> { + let max_index = self.nodes[node_idx.index()].max_index?.index(); self.nodes[max_index].interval.as_ref().map(|i| &i.high) } } - -/// An iterator over the entries of a `IntervalMap`. -#[derive(Debug)] -pub struct Iter<'a, T, V, Ix> { - /// Reference to the map - map_ref: &'a IntervalMap, - /// Stack for iteration - stack: Option>>, -} - -impl Iter<'_, T, V, Ix> -where - Ix: IndexType, -{ - /// Initializes the stack - fn init_stack(&mut self) { - self.stack = Some(Self::left_link(self.map_ref, self.map_ref.root)); - } - - /// Pushes a link of nodes on the left to stack. - fn left_link(map_ref: &IntervalMap, mut x: NodeIndex) -> Vec> { - let mut nodes = vec![]; - while !map_ref.node_ref(x, Node::is_sentinel) { - nodes.push(x); - x = map_ref.node_ref(x, Node::left); - } - nodes - } -} - -impl<'a, T, V, Ix> Iterator for Iter<'a, T, V, Ix> -where - Ix: IndexType, -{ - type Item = (&'a Interval, &'a V); - - #[inline] - fn next(&mut self) -> Option { - if self.stack.is_none() { - self.init_stack(); - } - let stack = self.stack.as_mut().unwrap(); - if stack.is_empty() { - return None; - } - let x = stack.pop().unwrap(); - stack.extend(Self::left_link( - self.map_ref, - self.map_ref.node_ref(x, Node::right), - )); - Some(self.map_ref.node_ref(x, |xn| (xn.interval(), xn.value()))) - } -} - -#[cfg(test)] -mod test { - use std::collections::HashSet; - - use rand::{rngs::StdRng, Rng, SeedableRng}; - - use super::*; - - struct IntervalGenerator { - rng: StdRng, - unique: HashSet>, - limit: i32, - } - - impl IntervalGenerator { - fn new(seed: [u8; 32]) -> Self { - const LIMIT: i32 = 1000; - Self { - rng: SeedableRng::from_seed(seed), - unique: HashSet::new(), - limit: LIMIT, - } - } - - fn next(&mut self) -> Interval { - let low = self.rng.gen_range(0..self.limit - 1); - let high = self.rng.gen_range((low + 1)..self.limit); - Interval::new(low, high) - } - - fn next_unique(&mut self) -> Interval { - let mut interval = self.next(); - while self.unique.contains(&interval) { - interval = self.next(); - } - self.unique.insert(interval.clone()); - interval - } - - fn next_with_range(&mut self, range: i32) -> Interval { - let low = self.rng.gen_range(0..self.limit - 1); - let high = self - .rng - .gen_range((low + 1)..self.limit.min(low + 1 + range)); - Interval::new(low, high) - } - } - - impl IntervalMap { - fn check_max(&self) { - let _ignore = self.check_max_inner(self.root); - } - - fn check_max_inner(&self, x: NodeIndex) -> i32 { - if self.node_ref(x, Node::is_sentinel) { - return 0; - } - let l_max = self.check_max_inner(self.node_ref(x, Node::left)); - let r_max = self.check_max_inner(self.node_ref(x, Node::right)); - let max = self.node_ref(x, |x| x.interval().high.max(l_max).max(r_max)); - assert_eq!(self.max(x), Some(&max)); - max - } - - /// 1. Every node is either red or black. - /// 2. The root is black. - /// 3. Every leaf (NIL) is black. - /// 4. If a node is red, then both its children are black. - /// 5. For each node, all simple paths from the node to descendant leaves contain the - /// same number of black nodes. - fn check_rb_properties(&self) { - assert!(matches!( - self.node_ref(self.root, Node::color), - Color::Black - )); - self.check_children_color(self.root); - self.check_black_height(self.root); - } - - fn check_children_color(&self, x: NodeIndex) { - if self.node_ref(x, Node::is_sentinel) { - return; - } - self.check_children_color(self.node_ref(x, Node::left)); - self.check_children_color(self.node_ref(x, Node::right)); - if self.node_ref(x, Node::is_red) { - assert!(matches!(self.left_ref(x, Node::color), Color::Black)); - assert!(matches!(self.right_ref(x, Node::color), Color::Black)); - } - } - - fn check_black_height(&self, x: NodeIndex) -> usize { - if self.node_ref(x, Node::is_sentinel) { - return 0; - } - let lefth = self.check_black_height(self.node_ref(x, Node::left)); - let righth = self.check_black_height(self.node_ref(x, Node::right)); - assert_eq!(lefth, righth); - if self.node_ref(x, Node::is_black) { - return lefth + 1; - } - lefth - } - } - - fn with_map_and_generator(test_fn: impl Fn(IntervalMap, IntervalGenerator)) { - let seeds = vec![[0; 32], [1; 32], [2; 32]]; - for seed in seeds { - let gen = IntervalGenerator::new(seed); - let map = IntervalMap::new(); - test_fn(map, gen); - } - } - - #[test] - fn red_black_tree_properties_is_satisfied() { - with_map_and_generator(|mut map, mut gen| { - let intervals: Vec<_> = std::iter::repeat_with(|| gen.next_unique()) - .take(1000) - .collect(); - for i in intervals.clone() { - let _ignore = map.insert(i, ()); - } - map.check_rb_properties(); - }); - } - - #[test] - fn map_len_will_update() { - with_map_and_generator(|mut map, mut gen| { - let intervals: Vec<_> = std::iter::repeat_with(|| gen.next_unique()) - .take(100) - .collect(); - for i in intervals.clone() { - let _ignore = map.insert(i, ()); - } - assert_eq!(map.len(), 100); - for i in intervals { - let _ignore = map.remove(&i); - } - assert_eq!(map.len(), 0); - }); - } - - #[test] - fn check_overlap_is_ok() { - with_map_and_generator(|mut map, mut gen| { - let intervals: Vec<_> = std::iter::repeat_with(|| gen.next_with_range(10)) - .take(100) - .collect(); - for i in intervals.clone() { - let _ignore = map.insert(i, ()); - } - let to_check: Vec<_> = std::iter::repeat_with(|| gen.next_with_range(10)) - .take(1000) - .collect(); - let expects: Vec<_> = to_check - .iter() - .map(|ci| intervals.iter().any(|i| ci.overlap(i))) - .collect(); - - for (ci, expect) in to_check.into_iter().zip(expects.into_iter()) { - assert_eq!(map.overlap(&ci), expect); - } - }); - } - - #[test] - fn check_max_is_ok() { - with_map_and_generator(|mut map, mut gen| { - let intervals: Vec<_> = std::iter::repeat_with(|| gen.next_unique()) - .take(1000) - .collect(); - for i in intervals.clone() { - let _ignore = map.insert(i, ()); - map.check_max(); - } - assert_eq!(map.len(), 1000); - for i in intervals { - let _ignore = map.remove(&i); - map.check_max(); - } - }); - } - - #[test] - fn remove_non_exist_interval_will_do_nothing() { - with_map_and_generator(|mut map, mut gen| { - let intervals: Vec<_> = std::iter::repeat_with(|| gen.next_unique()) - .take(1000) - .collect(); - for i in intervals { - let _ignore = map.insert(i, ()); - } - assert_eq!(map.len(), 1000); - let to_remove: Vec<_> = std::iter::repeat_with(|| gen.next_unique()) - .take(1000) - .collect(); - for i in to_remove { - let _ignore = map.remove(&i); - } - assert_eq!(map.len(), 1000); - }); - } - - #[test] - fn find_all_overlap_is_ok() { - with_map_and_generator(|mut map, mut gen| { - let intervals: Vec<_> = std::iter::repeat_with(|| gen.next_unique()) - .take(1000) - .collect(); - for i in intervals.clone() { - let _ignore = map.insert(i, ()); - } - let to_find: Vec<_> = std::iter::repeat_with(|| gen.next()).take(1000).collect(); - - let expects: Vec> = to_find - .iter() - .map(|ti| intervals.iter().filter(|i| ti.overlap(i)).collect()) - .collect(); - - for (ti, mut expect) in to_find.into_iter().zip(expects.into_iter()) { - let mut result = map.find_all_overlap(&ti); - expect.sort_unstable(); - result.sort_unstable(); - assert_eq!(expect.len(), result.len()); - for (e, r) in expect.into_iter().zip(result.into_iter()) { - assert_eq!(e, r.0); - } - } - }); - } - - #[test] - fn iterate_through_map_is_sorted() { - with_map_and_generator(|mut map, mut gen| { - let mut intervals: Vec<_> = std::iter::repeat_with(|| gen.next_unique()) - .enumerate() - .take(1000) - .collect(); - for (v, i) in intervals.clone() { - let _ignore = map.insert(i, v); - } - intervals.sort_unstable_by(|a, b| a.1.cmp(&b.1)); - - for ((ei, ev), (v, i)) in map.iter().zip(intervals.iter()) { - assert_eq!(ei, i); - assert_eq!(ev, v); - } - }); - } - - #[test] - fn interval_map_clear_is_ok() { - let mut map = IntervalMap::new(); - map.insert(Interval::new(1, 3), 1); - map.insert(Interval::new(2, 4), 2); - map.insert(Interval::new(6, 7), 3); - assert_eq!(map.len(), 3); - map.clear(); - assert_eq!(map.len(), 0); - assert!(map.is_empty()); - assert_eq!(map.nodes.len(), 1); - assert!(map.nodes[0].is_sentinel()); - } -} diff --git a/src/lib.rs b/src/lib.rs index 13e5141..446178d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,4 +29,4 @@ mod node; pub use entry::{Entry, OccupiedEntry, VacantEntry}; pub use interval::Interval; -pub use intervalmap::{IntervalMap, Iter}; +pub use intervalmap::IntervalMap; diff --git a/src/node.rs b/src/node.rs index 63cbb03..263df69 100644 --- a/src/node.rs +++ b/src/node.rs @@ -1,7 +1,13 @@ +use crate::index::{IndexType, NodeIndex}; use crate::interval::Interval; -use crate::index::{IndexType, NodeIndex}; +#[cfg(feature = "graphviz")] +use std::fmt::Display; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] /// Node of the interval tree #[derive(Debug)] pub struct Node { @@ -22,103 +28,261 @@ pub struct Node { pub value: Option, } -// Convenient getter/setter methods impl Node where - Ix: IndexType, + T: Ord, { - pub fn color(&self) -> Color { - self.color - } - pub fn interval(&self) -> &Interval { self.interval.as_ref().unwrap() } - pub fn max_index(&self) -> NodeIndex { + pub fn value(&self) -> &V { + self.value.as_ref().unwrap() + } + pub fn value_mut(&mut self) -> &mut V { + self.value.as_mut().unwrap() + } + pub fn take_value(&mut self) -> V { + self.value.take().unwrap() + } + + pub fn set_value(value: V) -> impl FnOnce(&mut Node) -> V { + move |node: &mut Node| node.value.replace(value).unwrap() + } +} + +// Convenient getter/setter methods +impl Node +where + T: Ord, + Ix: IndexType, +{ + pub(crate) fn max_index(&self) -> NodeIndex { self.max_index.unwrap() } - pub fn left(&self) -> NodeIndex { - self.left.unwrap() + pub(crate) fn set_max_index(max_index: NodeIndex) -> impl FnOnce(&mut Node) { + move |node: &mut Node| { + let _ignore = node.max_index.replace(max_index); + } } - pub fn right(&self) -> NodeIndex { - self.right.unwrap() + pub(crate) fn left(&self) -> NodeIndex { + self.left.unwrap() } - pub fn parent(&self) -> NodeIndex { - self.parent.unwrap() + pub(crate) fn set_left(left: NodeIndex) -> impl FnOnce(&mut Node) { + move |node: &mut Node| { + let _ignore = node.left.replace(left); + } } - pub fn is_sentinel(&self) -> bool { - self.interval.is_none() + pub(crate) fn right(&self) -> NodeIndex { + self.right.unwrap() } - pub fn sentinel(&self) -> Option<&Self> { - self.interval.is_some().then_some(self) + pub(crate) fn set_right(right: NodeIndex) -> impl FnOnce(&mut Node) { + move |node: &mut Node| { + let _ignore = node.right.replace(right); + } } - pub fn is_black(&self) -> bool { - matches!(self.color, Color::Black) + pub(crate) fn parent(&self) -> NodeIndex { + self.parent.unwrap() } - pub fn is_red(&self) -> bool { - matches!(self.color, Color::Red) + pub(crate) fn set_parent(parent: NodeIndex) -> impl FnOnce(&mut Node) { + move |node: &mut Node| { + let _ignore = node.parent.replace(parent); + } + } + pub(crate) fn is_sentinel(&self) -> bool { + self.interval.is_none() } - pub fn value(&self) -> &V { - self.value.as_ref().unwrap() + pub(crate) fn sentinel(&self) -> Option<&Self> { + self.interval.is_some().then_some(self) } - pub fn value_mut(&mut self) -> &mut V { - self.value.as_mut().unwrap() + pub(crate) fn color(&self) -> Color { + self.color } - pub fn take_value(&mut self) -> V { - self.value.take().unwrap() + pub(crate) fn is_black(&self) -> bool { + matches!(self.color, Color::Black) } - pub fn set_value(value: V) -> impl FnOnce(&mut Node) -> V { - move |node: &mut Node| node.value.replace(value).unwrap() + pub(crate) fn is_red(&self) -> bool { + matches!(self.color, Color::Red) } - pub fn set_color(color: Color) -> impl FnOnce(&mut Node) { + pub(crate) fn set_color(color: Color) -> impl FnOnce(&mut Node) { move |node: &mut Node| { node.color = color; } } +} - pub fn set_max_index(max_index: NodeIndex) -> impl FnOnce(&mut Node) { - move |node: &mut Node| { - let _ignore = node.max_index.replace(max_index); +#[cfg(feature = "graphviz")] +impl Node +where + T: Ord + Display, + V: Display, + Ix: IndexType, +{ + pub(crate) fn draw( + &self, + index: usize, + mut writer: W, + ) -> std::io::Result<()> { + writeln!( + writer, + " {} [label=\"i={}\\n{}: {}\\n\", fillcolor={}, style=filled]", + index, + index, + self.interval.as_ref().unwrap(), + self.value.as_ref().unwrap(), + if self.is_red() { "salmon" } else { "grey65" } + )?; + if !self.left.unwrap().is_sentinel() { + writeln!( + writer, + " {} -> {} [label=\"L\"]", + index, + self.left.unwrap().index() + )?; + } + if !self.right.unwrap().is_sentinel() { + writeln!( + writer, + " {} -> {} [label=\"R\"]", + index, + self.right.unwrap().index() + )?; } + Ok(()) } +} - pub fn set_left(left: NodeIndex) -> impl FnOnce(&mut Node) { - move |node: &mut Node| { - let _ignore = node.left.replace(left); +#[cfg(feature = "graphviz")] +impl Node +where + T: Display + Ord, + Ix: IndexType, +{ + pub(crate) fn draw_without_value( + &self, + index: usize, + mut writer: W, + ) -> std::io::Result<()> { + writeln!( + writer, + " {} [label=\"i={}: {}\", fillcolor={}, style=filled]", + index, + index, + self.interval.as_ref().unwrap(), + if self.is_red() { "salmon" } else { "grey65" } + )?; + if !self.left.unwrap().is_sentinel() { + writeln!( + writer, + " {} -> {} [label=\"L\"]", + index, + self.left.unwrap().index() + )?; + } + if !self.right.unwrap().is_sentinel() { + writeln!( + writer, + " {} -> {} [label=\"R\"]", + index, + self.right.unwrap().index() + )?; } + Ok(()) } +} - pub fn set_right(right: NodeIndex) -> impl FnOnce(&mut Node) { - move |node: &mut Node| { - let _ignore = node.right.replace(right); +impl Node +where + T: Ord, + Ix: IndexType, +{ + pub fn new(interval: Interval, value: V, index: NodeIndex) -> Self { + Node { + interval: Some(interval), + value: Some(value), + max_index: Some(index), + left: Some(NodeIndex::SENTINEL), + right: Some(NodeIndex::SENTINEL), + parent: Some(NodeIndex::SENTINEL), + color: Color::Red, } } - pub fn set_parent(parent: NodeIndex) -> impl FnOnce(&mut Node) { - move |node: &mut Node| { - let _ignore = node.parent.replace(parent); + pub fn new_sentinel() -> Self { + Node { + interval: None, + value: None, + max_index: None, + left: None, + right: None, + parent: None, + color: Color::Black, } } } /// The color of the node -#[derive(Debug, Clone, Copy)] -pub enum Color { +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum Color { /// Red node Red, /// Black node Black, } + +#[cfg(feature = "serde")] +#[cfg(test)] +mod tests { + use super::*; + use serde_json::{json, Value}; + + #[test] + fn test_node_serialize_deserialize() { + let node = Node:: { + left: Some(NodeIndex::new(0)), + right: Some(NodeIndex::new(1)), + parent: Some(NodeIndex::new(2)), + color: Color::Red, + interval: Some(Interval::new(10, 20)), + max_index: Some(NodeIndex::new(3)), + value: Some(42), + }; + + // Serialize the node to JSON + let serialized = serde_json::to_string(&node).unwrap(); + let expected = json!({ + "left": 0, + "right": 1, + "parent": 2, + "color": "Red", + "interval": [10,20], + "max_index": 3, + "value": 42 + }); + let actual: Value = serde_json::from_str(&serialized).unwrap(); + assert_eq!(expected, actual); + + // Deserialize the node from JSON + let deserialized: Node = serde_json::from_str(&serialized).unwrap(); + assert_eq!(node.left, deserialized.left); + assert_eq!(node.right, deserialized.right); + assert_eq!(node.parent, deserialized.parent); + assert_eq!(node.color, deserialized.color); + assert_eq!(node.interval, deserialized.interval); + assert_eq!(node.max_index, deserialized.max_index); + assert_eq!(node.value, deserialized.value); + } +}