diff --git a/.github/workflows/pr_check.yml b/.github/workflows/pr_check.yml new file mode 100644 index 0000000..6eee752 --- /dev/null +++ b/.github/workflows/pr_check.yml @@ -0,0 +1,70 @@ +name: PR Checks + +on: + push: + branches: [main] + pull_request: + types: [opened, synchronize, reopened] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Install Rust toolchain + uses: actions-rs/toolchain@v1 + with: + toolchain: stable + componets: clippy + override: true + + - name: Make script executable + run: chmod +x ci/scripts/check-trailing-spaces.sh + + - name: Trailing spaces check + run: ci/scripts/check-trailing-spaces.sh + + - name: Audit + run: cargo audit + + - name: Format + run: cargo fmt --all -- --check + + - name: Install cargo-hakari + run: cargo install cargo-hakari + + - name: Workspace hack check + run: cargo hakari generate --diff && cargo hakari manage-deps --dry-run && cargo hakari verify + + - name: Clippy + run: cargo clippy --all-targets --all-features -- -D warnings + + - name: Test + run: cargo test --verbose + + commit: + name: Commit Message Validation + if: github.event_name == 'pull_request' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - run: git show-ref + - uses: actions-rs/install@v0.1 + with: + crate: git-cz + version: latest + - name: Validate commit messages + run: git-cz check ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }} + + spell-check: + name: Spell Check + if: github.event_name == 'pull_request' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Check Spelling + uses: crate-ci/typos@v1.23.1 diff --git a/Cargo.lock b/Cargo.lock index 78c1808..493b595 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -205,6 +205,8 @@ version = "0.1.0" dependencies = [ "criterion", "rand", + "serde", + "serde_json", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 8a89c0e..78b398b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,14 +8,20 @@ license = "Apache-2.0" keywords = ["Interval Tree", "Augmented Tree", "Red-Black Tree"] [dependencies] +serde = { version = "1.0", default-features = false, features = [ + "derive", + "std", +], optional = true } [dev-dependencies] criterion = "0.5.1" rand = "0.8.5" +serde_json = "1.0" [features] default = [] -interval_tree_find_overlap_ordered = [] +graphviz = [] +serde = ["dep:serde"] [[bench]] name = "bench" diff --git a/README.md b/README.md index 0fdb9b5..a7f1cc6 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ The implementation of the interval tree in interval_map references "Introduction To safely and efficiently handle insertion and deletion operations in Rust, `interval_map` innovatively **uses arrays to simulate pointers** for managing the parent-child references in the red-black tree. This approach also ensures that interval_map has the `Send` and `Unpin` traits, allowing it to be safely transferred between threads and to maintain a fixed memory location during asynchronous operations. `interval_map` implements an `IntervalMap` struct: -- It accepts `Interval` as the key, where `T` can be any type that implements `Ord+Clone` trait. Therefore, intervals such as $[1, 2)$ and $["aaa", "bbb")$ are allowed +- It accepts `Interval` as the key, where `T` can be any type that implements `Ord` trait. Therefore, intervals such as $[1, 2)$ and $["aaa", "bbb")$ are allowed - The value can be of any type `interval_map` supports `insert`, `delete`, and `iter` fns. Traversal is performed in the order of `Interval` . For instance, with intervals of type `Interval`: @@ -22,15 +22,16 @@ Currently, `interval_map` only supports half-open intervals, i.e., $[...,...)$. The benchmark was conducted on a platform with `AMD R7 7840H + DDR5 5600MHz`. The result are as follows: 1. Only insert - | insert | 100 | 1000 | 10, 000 | 100, 000 | - | --------------- | --------- | --------- | --------- | --------- | - | Time per insert | 5.4168 µs | 80.518 µs | 2.2823 ms | 36.528 ms | + | insert | 100 | 1000 | 10, 000 | 100, 000 | + | ---------- | --------- | --------- | --------- | --------- | + | Total time | 5.4168 µs | 80.518 µs | 2.2823 ms | 36.528 ms | 2. Insert N and remove N - | insert_and_remove | 100 | 1000 | 10, 000 | 100, 000 | - | ------------------ | --------- | --------- | --------- | --------- | - | Time per operation | 10.333 µs | 223.43 µs | 4.9358 ms | 81.634 ms | + | insert_and_remove | 100 | 1000 | 10, 000 | 100, 000 | + | ----------------- | --------- | --------- | --------- | --------- | + | Total time | 10.333 µs | 223.43 µs | 4.9358 ms | 81.634 ms | ## TODO -- [] Support for $(...,...)$, $[...,...]$ and $(...,...]$ interval types. -- [] Add more tests like [etcd](https://github.com/etcd-io/etcd/blob/main/pkg/adt/interval_tree_test.go) -- [] Add Point type for Interval +- [ ] ~~Support for $(...,...)$, $[...,...]$ and $(...,...]$ interval types.~~ There's no way to support these interval type without performance loss now. +- [ ] ~~Add Point type for Interval~~ To support Point type, it should also support $[...,...]$, so it couldn't be supported now, either. But you could write code like [examples/new_point](examples/new_point.rs). +- [x] Add more tests like [etcd](https://github.com/etcd-io/etcd/blob/main/pkg/adt/interval_tree_test.go). +- [x] Refine iter mod. \ No newline at end of file diff --git a/benches/bench.rs b/benches/bench.rs index f7f7435..c9f9de1 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -1,44 +1,25 @@ use criterion::{criterion_group, criterion_main, Bencher, Criterion}; use interval_map::{Interval, IntervalMap}; +use rand::{rngs::StdRng, Rng, SeedableRng}; use std::hint::black_box; -struct Rng { - state: u32, -} -impl Rng { - fn new() -> Self { - Self { state: 0x87654321 } - } - - fn gen_u32(&mut self) -> u32 { - self.state ^= self.state << 13; - self.state ^= self.state >> 17; - self.state ^= self.state << 5; - self.state - } - - fn gen_range_i32(&mut self, low: i32, high: i32) -> i32 { - let d = (high - low) as u32; - low + (self.gen_u32() % d) as i32 - } -} - struct IntervalGenerator { - rng: Rng, - limit: i32, + rng: StdRng, } impl IntervalGenerator { fn new() -> Self { - const LIMIT: i32 = 100000; + let seed: [u8; 32] = [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, + ]; Self { - rng: Rng::new(), - limit: LIMIT, + rng: StdRng::from_seed(seed), } } - fn next(&mut self) -> Interval { - let low = self.rng.gen_range_i32(0, self.limit - 1); - let high = self.rng.gen_range_i32(low + 1, self.limit); + fn next(&mut self) -> Interval { + let low = self.rng.gen(); + let high = self.rng.gen_range(low + 1..=u32::MAX); Interval::new(low, high) } } @@ -65,7 +46,7 @@ fn interval_map_insert_remove(count: usize, bench: &mut Bencher) { black_box(map.insert(i, ())); } for i in &intervals { - black_box(map.remove(&i)); + black_box(map.remove(i)); } }); } @@ -100,14 +81,68 @@ fn bench_interval_map_insert_remove(c: &mut Criterion) { }); } +// FilterIter helper fn +fn interval_map_filter_iter(count: usize, bench: &mut Bencher) { + let mut gen = IntervalGenerator::new(); + let intervals: Vec<_> = std::iter::repeat_with(|| gen.next()).take(count).collect(); + let mut map = IntervalMap::new(); + for i in intervals.clone() { + map.insert(i, ()); + } + bench.iter(|| { + for i in intervals.clone() { + black_box(map.filter_iter(&i).collect::>()); + } + }); +} + +// iter().filter() helper fn +fn interval_map_iter_filter(count: usize, bench: &mut Bencher) { + let mut gen = IntervalGenerator::new(); + let intervals: Vec<_> = std::iter::repeat_with(|| gen.next()).take(count).collect(); + let mut map = IntervalMap::new(); + for i in intervals.clone() { + map.insert(i, ()); + } + bench.iter(|| { + for i in intervals.clone() { + black_box(map.iter().filter(|v| v.0.overlap(&i)).collect::>()); + } + }); +} + +fn bench_interval_map_filter_iter(c: &mut Criterion) { + c.bench_function("bench_interval_map_filter_iter_100", |b| { + interval_map_filter_iter(100, b) + }); + c.bench_function("bench_interval_map_filter_iter_1000", |b| { + interval_map_filter_iter(1000, b) + }); +} + +fn bench_interval_map_iter_filter(c: &mut Criterion) { + c.bench_function("bench_interval_map_iter_filter_100", |b| { + interval_map_iter_filter(100, b) + }); + c.bench_function("bench_interval_map_iter_filter_1000", |b| { + interval_map_iter_filter(1000, b) + }); +} + fn criterion_config() -> Criterion { Criterion::default().configure_from_args().without_plots() } criterion_group! { - name = benches; + name = benches_basic_op; + config = criterion_config(); + targets = bench_interval_map_insert, bench_interval_map_insert_remove, +} + +criterion_group! { + name = benches_iter; config = criterion_config(); - targets = bench_interval_map_insert, bench_interval_map_insert_remove + targets = bench_interval_map_filter_iter, bench_interval_map_iter_filter } -criterion_main!(benches); +criterion_main!(benches_basic_op, benches_iter); diff --git a/ci/scripts/check-trailing-spaces.sh b/ci/scripts/check-trailing-spaces.sh new file mode 100644 index 0000000..8a788db --- /dev/null +++ b/ci/scripts/check-trailing-spaces.sh @@ -0,0 +1,74 @@ +#!/usr/bin/env bash + +# Script Author: https://github.com/stdrc +# Repo: https://github.com/risingwavelabs/risingwave/blame/main/scripts/check/check-trailing-spaces.sh + +# Exits as soon as any line fails. +set -euo pipefail + +self=$0 + +# Shell colors +RED='\033[0;31m' +BLUE='\033[0;34m' +GREEN='\033[0;32m' +ORANGE='\033[0;33m' +BOLD='\033[1m' +NONE='\033[0m' + +print_help() { + echo "Usage: $self [-f|--fix]" + echo + echo "Options:" + echo " -f, --fix Fix trailing spaces." + echo " -h, --help Show this help message and exit." +} + +fix=false +while [ $# -gt 0 ]; do + case $1 in + -f | --fix) + fix=true + ;; + -h | --help) + print_help + exit 0 + ;; + *) + echo -e "${RED}${BOLD}$self: invalid option \`$1\`\n${NONE}" + print_help + exit 1 + ;; + esac + shift +done + +temp_file=$(mktemp) + +echo -ne "${BLUE}" +git config --global --add safe.directory '*' +git grep -nIP --untracked '[[:space:]]+$' | tee $temp_file || true +echo -ne "${NONE}" + +bad_files=$(cat $temp_file | cut -f1 -d ':' | sort -u) +rm $temp_file + +if [ ! -z "$bad_files" ]; then + if [[ $fix == true ]]; then + for file in $bad_files; do + sed -i '' -e's/[[:space:]]*$//' "$file" + done + + echo + echo -e "${GREEN}${BOLD}All trailing spaces listed above have been cleaned.${NONE}" + exit 0 + else + echo + echo -e "${RED}${BOLD}Please clean all the trailing spaces listed above.${NONE}" + echo -e "${BOLD}You can run '$self --fix' for convenience.${NONE}" + exit 1 + fi +else + echo -e "${GREEN}${BOLD}No trailing spaces found.${NONE}" + exit 0 +fi \ No newline at end of file diff --git a/examples/new_point.rs b/examples/new_point.rs new file mode 100644 index 0000000..e872dd6 --- /dev/null +++ b/examples/new_point.rs @@ -0,0 +1,27 @@ +use interval_map::{Interval, IntervalMap}; + +trait Point { + fn new_point(x: T) -> Interval; +} + +impl Point for Interval { + fn new_point(x: u32) -> Self { + Interval::new(x, x + 1) + } +} + +fn main() { + let mut interval_map = IntervalMap::::new(); + interval_map.insert(Interval::new(3, 7), 20); + interval_map.insert(Interval::new(2, 6), 15); + + let tmp_point = Interval::new_point(5); + assert_eq!(tmp_point, Interval::new(5, 6)); + + interval_map.insert(tmp_point.clone(), 10); + assert_eq!(interval_map.get(&tmp_point).unwrap(), &10); + assert_eq!( + interval_map.find_all_overlap(&Interval::new_point(5)).len(), + 3 + ); +} diff --git a/examples/string_affine.rs b/examples/string_affine.rs new file mode 100644 index 0000000..5595b18 --- /dev/null +++ b/examples/string_affine.rs @@ -0,0 +1,68 @@ +use std::cmp; + +use interval_map::{Interval, IntervalMap}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum StringAffine { + /// String + String(String), + /// Unbounded + Unbounded, +} + +impl StringAffine { + pub fn new_key(s: &str) -> Self { + Self::String(s.to_string()) + } + + pub fn new_unbounded() -> Self { + Self::Unbounded + } +} + +impl PartialOrd for StringAffine { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for StringAffine { + fn cmp(&self, other: &Self) -> cmp::Ordering { + match (self, other) { + (StringAffine::String(x), StringAffine::String(y)) => x.cmp(y), + (StringAffine::String(_), StringAffine::Unbounded) => cmp::Ordering::Less, + (StringAffine::Unbounded, StringAffine::String(_)) => cmp::Ordering::Greater, + (StringAffine::Unbounded, StringAffine::Unbounded) => cmp::Ordering::Equal, + } + } +} + +trait Point { + fn new_point(x: T) -> Interval; +} + +impl Point for Interval { + fn new_point(x: StringAffine) -> Interval { + match x { + StringAffine::String(mut x_string) => { + let low = x_string.clone(); + x_string.push('\0'); + Interval::new( + StringAffine::new_key(&low), + StringAffine::new_key(&x_string), + ) + } + _ => panic!("new_point only receive StringAffine::String!"), + } + } +} + +fn main() { + let mut interval_map = IntervalMap::::new(); + interval_map.insert( + Interval::new(StringAffine::new_key("8"), StringAffine::Unbounded), + 123, + ); + assert!(interval_map.overlaps(&Interval::new_point(StringAffine::new_key("9")))); + assert!(!interval_map.overlaps(&Interval::new_point(StringAffine::new_key("7")))); +} diff --git a/src/entry.rs b/src/entry.rs index 731a55a..59139aa 100644 --- a/src/entry.rs +++ b/src/entry.rs @@ -5,7 +5,10 @@ use crate::node::Node; /// A view into a single entry in a map, which may either be vacant or occupied. #[derive(Debug)] -pub enum Entry<'a, T, V, Ix> { +pub enum Entry<'a, T, V, Ix> +where + T: Ord, +{ /// An occupied entry. Occupied(OccupiedEntry<'a, T, V, Ix>), /// A vacant entry. @@ -15,17 +18,23 @@ pub enum Entry<'a, T, V, Ix> { /// A view into an occupied entry in a `IntervalMap`. /// It is part of the [`Entry`] enum. #[derive(Debug)] -pub struct OccupiedEntry<'a, T, V, Ix> { +pub struct OccupiedEntry<'a, T, V, Ix> +where + T: Ord, +{ /// Reference to the map pub map_ref: &'a mut IntervalMap, /// The entry node - pub node: NodeIndex, + pub node_idx: NodeIndex, } /// A view into a vacant entry in a `IntervalMap`. /// It is part of the [`Entry`] enum. #[derive(Debug)] -pub struct VacantEntry<'a, T, V, Ix> { +pub struct VacantEntry<'a, T, V, Ix> +where + T: Ord, +{ /// Mutable reference to the map pub map_ref: &'a mut IntervalMap, /// The interval of this entry @@ -53,7 +62,7 @@ where #[inline] pub fn or_insert(self, default: V) -> &'a mut V { match self { - Entry::Occupied(entry) => entry.map_ref.node_mut(entry.node, Node::value_mut), + Entry::Occupied(entry) => entry.map_ref.node_mut(entry.node_idx, Node::value_mut), Entry::Vacant(entry) => { let entry_idx = NodeIndex::new(entry.map_ref.nodes.len()); let _ignore = entry.map_ref.insert(entry.interval, default); @@ -88,7 +97,7 @@ where { match self { Entry::Occupied(entry) => { - f(entry.map_ref.node_mut(entry.node, Node::value_mut)); + f(entry.map_ref.node_mut(entry.node_idx, Node::value_mut)); Self::Occupied(entry) } Entry::Vacant(entry) => Self::Vacant(entry), diff --git a/src/index.rs b/src/index.rs index 657f955..c5d156c 100644 --- a/src/index.rs +++ b/src/index.rs @@ -1,30 +1,49 @@ use std::fmt; use std::hash::Hash; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + pub type DefaultIx = u32; -pub unsafe trait IndexType: Copy + Default + Hash + Ord + fmt::Debug + 'static { +pub trait IndexType: Copy + Default + Hash + Ord + fmt::Debug + 'static { + const SENTINEL: Self; fn new(x: usize) -> Self; fn index(&self) -> usize; fn max() -> Self; + fn is_sentinel(&self) -> bool { + *self == Self::SENTINEL + } } -unsafe impl IndexType for u32 { - #[inline(always)] - fn new(x: usize) -> Self { - x as u32 - } - #[inline(always)] - fn index(&self) -> usize { - *self as usize - } - #[inline(always)] - fn max() -> Self { - ::std::u32::MAX - } +macro_rules! impl_index { + ($type:ident) => { + impl IndexType for $type { + const SENTINEL: Self = 0; + + #[inline(always)] + fn new(x: usize) -> Self { + x as $type + } + #[inline(always)] + fn index(&self) -> usize { + *self as usize + } + #[inline(always)] + fn max() -> Self { + ::std::$type::MAX + } + } + }; } +impl_index!(u8); +impl_index!(u16); +impl_index!(u32); +impl_index!(u64); + /// Node identifier. +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Copy, Clone, Default, PartialEq, PartialOrd, Eq, Ord, Hash)] pub struct NodeIndex(Ix); @@ -34,24 +53,30 @@ impl NodeIndex { NodeIndex(IndexType::new(x)) } - #[inline] - pub fn index(self) -> usize { - self.0.index() - } - #[inline] pub fn end() -> Self { NodeIndex(IndexType::max()) } + + pub fn incre(&self) -> Self { + NodeIndex::new(self.index().wrapping_add(1)) + } } -unsafe impl IndexType for NodeIndex { +impl IndexType for NodeIndex { + const SENTINEL: Self = NodeIndex(Ix::SENTINEL); + + #[inline] fn index(&self) -> usize { self.0.index() } + + #[inline] fn new(x: usize) -> Self { NodeIndex::new(x) } + + #[inline] fn max() -> Self { NodeIndex(::max()) } diff --git a/src/interval.rs b/src/interval.rs index e7d78bc..f775c7f 100644 --- a/src/interval.rs +++ b/src/interval.rs @@ -8,9 +8,14 @@ //! //! Currently, `interval_map` only supports half-open intervals, i.e., [...,...). +use std::fmt::{Display, Formatter}; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + /// The interval stored in `IntervalMap` represents [low, high) -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[non_exhaustive] +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Interval { /// Low value pub low: T, @@ -35,6 +40,24 @@ impl Interval { pub fn overlap(&self, other: &Self) -> bool { self.high > other.low && other.high > self.low } + + /// Checks if self contains other interval + /// e.g. [1,10) contains [1,8) + #[inline] + pub fn contain(&self, other: &Self) -> bool { + self.low <= other.low && self.high > other.high + } + + /// Checks if self contains a point + pub fn contain_point(&self, p: T) -> bool { + self.low <= p && self.high > p + } +} + +impl Display for Interval { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "[{},{})", self.low, self.high) + } } /// Reference type of `Interval` @@ -64,6 +87,21 @@ impl<'a, T: Ord> IntervalRef<'a, T> { } } +#[cfg(feature = "serde")] +impl Serialize for Interval { + fn serialize(&self, serializer: S) -> Result { + (&self.low, &self.high).serialize(serializer) + } +} + +#[cfg(feature = "serde")] +impl<'de, T: Deserialize<'de> + Ord> Deserialize<'de> for Interval { + fn deserialize>(deserializer: D) -> Result { + let (low, high) = <(T, T)>::deserialize(deserializer)?; + Ok(Interval::new(low, high)) + } +} + #[cfg(test)] mod test { use super::*; @@ -73,4 +111,42 @@ mod test { fn invalid_range_should_panic() { let _interval = Interval::new(3, 1); } + + #[test] + fn test_interval_clone() { + let interval1 = Interval::new(1, 10); + let interval2 = interval1.clone(); + assert_eq!(interval1, interval2); + } + + #[test] + fn test_interval_compare() { + let interval1 = Interval::new(1, 10); + let interval2 = Interval::new(5, 15); + assert!(interval1 < interval2); + assert!(interval2 > interval1); + assert_eq!(interval1, Interval::new(1, 10)); + assert_ne!(interval1, interval2); + } + + #[test] + fn test_interval_hash() { + let interval1 = Interval::new(1, 10); + let interval2 = Interval::new(1, 10); + let interval3 = Interval::new(5, 15); + let mut hashset = std::collections::HashSet::new(); + hashset.insert(interval1); + hashset.insert(interval2); + hashset.insert(interval3); + assert_eq!(hashset.len(), 2); + } + + #[cfg(feature = "serde")] + #[test] + fn test_interval_serialize_deserialize() { + let interval = Interval::new(1, 10); + let serialized = serde_json::to_string(&interval).unwrap(); + let deserialized: Interval = serde_json::from_str(&serialized).unwrap(); + assert_eq!(interval, deserialized); + } } diff --git a/src/intervalmap.rs b/src/intervalmap.rs index 51dcb24..7c2a13e 100644 --- a/src/intervalmap.rs +++ b/src/intervalmap.rs @@ -1,12 +1,29 @@ use crate::entry::{Entry, OccupiedEntry, VacantEntry}; -use crate::index::{DefaultIx, IndexType, NodeIndex}; +use crate::index::{self, DefaultIx, IndexType, NodeIndex}; use crate::interval::{Interval, IntervalRef}; +use crate::iter::{FilterIter, IntoIter, Iter, UnsortedIter}; use crate::node::{Color, Node}; + use std::collections::VecDeque; +use std::fmt::Debug; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +#[cfg(feature = "graphviz")] +use std::fmt::Display; +#[cfg(feature = "graphviz")] +use std::fs::OpenOptions; +#[cfg(feature = "graphviz")] +use std::io::Write; /// An interval-value map, which support operations on dynamic sets of intervals. +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug)] -pub struct IntervalMap { +pub struct IntervalMap +where + T: Ord, +{ /// Vector that stores nodes pub(crate) nodes: Vec>, /// Root of the interval tree @@ -24,16 +41,16 @@ where #[inline] #[must_use] pub fn with_capacity(capacity: usize) -> Self { - let mut nodes = vec![Self::new_sentinel()]; + let mut nodes = vec![Node::new_sentinel()]; nodes.reserve(capacity); IntervalMap { nodes, - root: Self::sentinel(), + root: NodeIndex::SENTINEL, len: 0, } } - /// Insert an interval-value pair into the map. + /// insert an interval-value pair into the map. /// If the interval exists, overwrite and return the previous value. /// /// # Panics @@ -52,10 +69,11 @@ where #[inline] pub fn insert(&mut self, interval: Interval, value: V) -> Option { let node_idx = NodeIndex::new(self.nodes.len()); - let node = Self::new_node(interval, value, node_idx); + let node = Node::new(interval, value, node_idx); // check for max capacity, except if we use usize assert!( - ::max().index() == !0 || NodeIndex::end() != node_idx, + ::max().index() == !0 + || as index::IndexType>::max() != node_idx, "Reached maximum number of nodes" ); self.nodes.push(node); @@ -101,19 +119,22 @@ where /// map.insert(Interval::new(1, 3), ()); /// map.insert(Interval::new(6, 7), ()); /// map.insert(Interval::new(9, 11), ()); - /// assert!(map.overlap(&Interval::new(2, 5))); - /// assert!(map.overlap(&Interval::new(1, 17))); - /// assert!(!map.overlap(&Interval::new(3, 6))); - /// assert!(!map.overlap(&Interval::new(11, 23))); + /// assert!(map.overlaps(&Interval::new(2, 5))); + /// assert!(map.overlaps(&Interval::new(1, 17))); + /// assert!(!map.overlaps(&Interval::new(3, 6))); + /// assert!(!map.overlaps(&Interval::new(11, 23))); /// ``` #[inline] - pub fn overlap(&self, interval: &Interval) -> bool { + pub fn overlaps(&self, interval: &Interval) -> bool { let node_idx = self.search(interval); !self.node_ref(node_idx, Node::is_sentinel) } /// Find all intervals in the map that overlaps with the given interval. /// + /// # Note + /// This method's returned data is unordered. To get ordered data, please use `find_all_overlap_ordered`. + /// /// # Example /// ```rust /// use interval_map::{Interval, IntervalMap}; @@ -129,10 +150,45 @@ where /// ``` #[inline] pub fn find_all_overlap(&self, interval: &Interval) -> Vec<(&Interval, &V)> { + if self.node_ref(self.root, Node::is_sentinel) { + return Vec::new(); + } + if self.len() > 20 { + self.find_all_overlap_inner(self.root, interval) + } else { + self.unsorted_iter() + .filter(|v| v.0.overlap(interval)) + .collect() + } + } + + /// Find all intervals in the map that overlaps with the given interval. + /// + /// # Note + /// This method's returned data is ordered. Generally, it's much slower than `find_all_overlap`. + /// + /// # Example + /// ```rust + /// use interval_map::{Interval, IntervalMap}; + /// + /// let mut map = IntervalMap::new(); + /// map.insert(Interval::new(1, 3), ()); + /// map.insert(Interval::new(2, 4), ()); + /// map.insert(Interval::new(6, 7), ()); + /// map.insert(Interval::new(7, 11), ()); + /// assert_eq!(map.find_all_overlap(&Interval::new(2, 7)).len(), 3); + /// map.remove(&Interval::new(1, 3)); + /// assert_eq!(map.find_all_overlap(&Interval::new(2, 7)).len(), 2); + /// ``` + #[inline] + pub fn find_all_overlap_ordered<'a>( + &'a self, + interval: &'a Interval, + ) -> Vec<(&Interval, &V)> { if self.node_ref(self.root, Node::is_sentinel) { Vec::new() } else { - self.find_all_overlap_inner_unordered(self.root, interval) + self.filter_iter(interval).collect() } } @@ -176,10 +232,71 @@ where #[inline] #[must_use] pub fn iter(&self) -> Iter<'_, T, V, Ix> { - Iter { - map_ref: self, - stack: None, - } + Iter::new(self) + } + + /// Get an into iterator over the entries of the map, sorted by key. + // #[inline] + // #[must_use] + // pub fn into_iter(self) -> IntoIter { + // IntoIter::new(self) + // } + + /// Get an iterator over the entries of the map, unsorted. + #[inline] + pub fn unsorted_iter(&self) -> UnsortedIter { + UnsortedIter::new(self) + } + + /// Get an iterator over the entries that overlap the `query`, sorted by key. + /// + /// # Panics + /// + /// The method panics when `query` contains a value that cannot be compared. + #[inline] + pub fn filter_iter<'a, 'b: 'a>(&'a self, query: &'b Interval) -> FilterIter { + FilterIter::new(self, query) + } + + /// Return true if the interval tree's key cover the entire given interval. + /// + /// # Example + /// ```rust + /// use interval_map::{Interval, IntervalMap}; + /// + /// let mut map = IntervalMap::new(); + /// map.insert(Interval::new(3, 5), 0); + /// map.insert(Interval::new(5, 8), 1); + /// map.insert(Interval::new(9, 12), 1); + /// assert!(map.contains(&Interval::new(4, 6))); + /// assert!(!map.contains(&Interval::new(7, 10))); + /// ``` + #[inline] + pub fn contains(&self, interval: &Interval) -> bool { + let mut max_end: Option<&T> = None; + let mut min_begin: Option<&T> = None; + + let mut continuous = true; + self.filter_iter(interval).find(|v| { + if min_begin.is_none() { + min_begin = Some(&v.0.low); + max_end = Some(&v.0.high); + return false; + } + if max_end.map(|mv| mv < &v.0.low).unwrap() { + continuous = false; + return true; + } + if max_end.map(|mv| mv < &v.0.high).unwrap() { + max_end = Some(&v.0.high); + } + false + }); + + continuous + && min_begin.is_some() + && max_end.map(|mv| mv >= &interval.high).unwrap() + && min_begin.map(|mv| mv <= &interval.low).unwrap() } /// Get the given key's corresponding entry in the map for in-place manipulation. @@ -199,9 +316,9 @@ where #[inline] pub fn entry(&mut self, interval: Interval) -> Entry<'_, T, V, Ix> { match self.search_exact(&interval) { - Some(node) => Entry::Occupied(OccupiedEntry { + Some(node_idx) => Entry::Occupied(OccupiedEntry { map_ref: self, - node, + node_idx, }), None => Entry::Vacant(VacantEntry { map_ref: self, @@ -214,8 +331,8 @@ where #[inline] pub fn clear(&mut self) { self.nodes.clear(); - self.nodes.push(Self::new_sentinel()); - self.root = Self::sentinel(); + self.nodes.push(Node::new_sentinel()); + self.root = NodeIndex::SENTINEL; self.len = 0; } @@ -234,6 +351,20 @@ where } } +impl IntoIterator for IntervalMap +where + T: Ord, + Ix: IndexType, +{ + type Item = (Interval, V); + + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + IntoIter::new(self) + } +} + impl IntervalMap where T: Ord, @@ -242,11 +373,7 @@ where #[inline] #[must_use] pub fn new() -> Self { - Self { - nodes: vec![Self::new_sentinel()], - root: Self::sentinel(), - len: 0, - } + Self::with_capacity(0) } } @@ -265,46 +392,9 @@ where T: Ord, Ix: IndexType, { - /// Create a new sentinel node - fn new_sentinel() -> Node { - Node { - interval: None, - value: None, - max_index: None, - left: None, - right: None, - parent: None, - color: Color::Black, - } - } - - /// Create a new tree node - fn new_node(interval: Interval, value: V, index: NodeIndex) -> Node { - Node { - max_index: Some(index), - interval: Some(interval), - value: Some(value), - left: Some(Self::sentinel()), - right: Some(Self::sentinel()), - parent: Some(Self::sentinel()), - color: Color::Red, - } - } - - /// Get the sentinel node index - fn sentinel() -> NodeIndex { - NodeIndex::new(0) - } -} - -impl IntervalMap -where - T: Ord, - Ix: IndexType, -{ - /// Insert a node into the tree. + /// insert a node into the tree. fn insert_inner(&mut self, z: NodeIndex) -> Option { - let mut y = Self::sentinel(); + let mut y = NodeIndex::SENTINEL; let mut x = self.root; while !self.node_ref(x, Node::is_sentinel) { @@ -380,34 +470,10 @@ where self.len = self.len.wrapping_sub(1); } - /// Find all intervals in the map that overlaps with the given interval. - #[cfg(interval_tree_find_overlap_ordered)] - fn find_all_overlap_inner( - &self, - x: NodeIndex, - interval: &Interval, - ) -> Vec<(&Interval, &V)> { - let mut list = vec![]; - if self.node_ref(x, Node::interval).overlap(interval) { - list.push(self.node_ref(x, |nx| (nx.interval(), nx.value()))); - } - if self.max(self.node_ref(x, Node::left)) >= Some(&interval.low) { - list.extend(self.find_all_overlap_inner(self.node_ref(x, Node::left), interval)); - } - if self - .max(self.node_ref(x, Node::right)) - .map(|rmax| IntervalRef::new(&self.node_ref(x, Node::interval).low, rmax)) - .is_some_and(|i| i.overlap(interval)) - { - list.extend(self.find_all_overlap_inner(self.node_ref(x, Node::right), interval)); - } - list - } - /// Find all intervals in the map that overlaps with the given interval. /// /// The result is unordered because of breadth-first search to save stack size - fn find_all_overlap_inner_unordered( + fn find_all_overlap_inner( &self, x: NodeIndex, interval: &Interval, @@ -680,13 +746,13 @@ where } /// Check if a node is a left child of its parent. - fn is_left_child(&self, node: NodeIndex) -> bool { - self.parent_ref(node, Node::left) == node + fn is_left_child(&self, node_idx: NodeIndex) -> bool { + self.parent_ref(node_idx, Node::left) == node_idx } /// Check if a node is a right child of its parent. - fn is_right_child(&self, node: NodeIndex) -> bool { - self.parent_ref(node, Node::right) == node + fn is_right_child(&self, node_idx: NodeIndex) -> bool { + self.parent_ref(node_idx, Node::right) == node_idx } /// Update nodes indices after remove @@ -719,423 +785,792 @@ where } } +#[cfg(feature = "graphviz")] +impl IntervalMap +where + T: Ord + Copy + Display, + V: Display, + Ix: IndexType, +{ + /// writes dot file to `filename`. `T` and `V` should implement `Display`. + pub fn draw(&self, filename: &str) -> std::io::Result<()> { + let mut file = OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(filename)?; + writeln!(file, "digraph {{")?; + // begin at 1, because 0 is sentinel node + for i in 1..self.nodes.len() { + self.nodes[i].draw(i, &mut file)?; + } + writeln!(file, "}}") + } +} + +#[cfg(feature = "graphviz")] +impl IntervalMap +where + T: Ord + Copy + Display, + Ix: IndexType, +{ + /// Writes dot file to `filename` without values. `T` should implement `Display`. + pub fn draw_without_value(&self, filename: &str) -> std::io::Result<()> { + let mut file = OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(filename)?; + writeln!(file, "digraph {{")?; + // begin at 1, because 0 is sentinel node + for i in 1..self.nodes.len() { + self.nodes[i].draw_without_value(i, &mut file)?; + } + writeln!(file, "}}") + } +} + // Convenient methods for reference or mutate current/parent/left/right node impl<'a, T, V, Ix> IntervalMap where + T: Ord, Ix: IndexType, { - fn node_ref(&'a self, node: NodeIndex, op: F) -> R + pub(crate) fn node_ref(&'a self, node_idx: NodeIndex, op: F) -> R where R: 'a, F: FnOnce(&'a Node) -> R, { - op(&self.nodes[node.index()]) + op(&self.nodes[node_idx.index()]) } - pub(crate) fn node_mut(&'a mut self, node: NodeIndex, op: F) -> R + pub(crate) fn node_mut(&'a mut self, node_idx: NodeIndex, op: F) -> R where R: 'a, F: FnOnce(&'a mut Node) -> R, { - op(&mut self.nodes[node.index()]) + op(&mut self.nodes[node_idx.index()]) } - fn left_ref(&'a self, node: NodeIndex, op: F) -> R + pub(crate) fn left_ref(&'a self, node_idx: NodeIndex, op: F) -> R where R: 'a, F: FnOnce(&'a Node) -> R, { - let idx = self.nodes[node.index()].left().index(); + let idx = self.nodes[node_idx.index()].left().index(); op(&self.nodes[idx]) } - fn right_ref(&'a self, node: NodeIndex, op: F) -> R + pub(crate) fn right_ref(&'a self, node_idx: NodeIndex, op: F) -> R where R: 'a, F: FnOnce(&'a Node) -> R, { - let idx = self.nodes[node.index()].right().index(); + let idx = self.nodes[node_idx.index()].right().index(); op(&self.nodes[idx]) } - fn parent_ref(&'a self, node: NodeIndex, op: F) -> R + fn parent_ref(&'a self, node_idx: NodeIndex, op: F) -> R where R: 'a, F: FnOnce(&'a Node) -> R, { - let idx = self.nodes[node.index()].parent().index(); + let idx = self.nodes[node_idx.index()].parent().index(); op(&self.nodes[idx]) } - fn grand_parent_ref(&'a self, node: NodeIndex, op: F) -> R + fn grand_parent_ref(&'a self, node_idx: NodeIndex, op: F) -> R where R: 'a, F: FnOnce(&'a Node) -> R, { - let parent_idx = self.nodes[node.index()].parent().index(); + let parent_idx = self.nodes[node_idx.index()].parent().index(); let grand_parent_idx = self.nodes[parent_idx].parent().index(); op(&self.nodes[grand_parent_idx]) } - fn left_mut(&'a mut self, node: NodeIndex, op: F) -> R + fn left_mut(&'a mut self, node_idx: NodeIndex, op: F) -> R where R: 'a, F: FnOnce(&'a mut Node) -> R, { - let idx = self.nodes[node.index()].left().index(); + let idx = self.nodes[node_idx.index()].left().index(); op(&mut self.nodes[idx]) } - fn right_mut(&'a mut self, node: NodeIndex, op: F) -> R + fn right_mut(&'a mut self, node_idx: NodeIndex, op: F) -> R where R: 'a, F: FnOnce(&'a mut Node) -> R, { - let idx = self.nodes[node.index()].right().index(); + let idx = self.nodes[node_idx.index()].right().index(); op(&mut self.nodes[idx]) } - fn parent_mut(&'a mut self, node: NodeIndex, op: F) -> R + fn parent_mut(&'a mut self, node_idx: NodeIndex, op: F) -> R where R: 'a, F: FnOnce(&'a mut Node) -> R, { - let idx = self.nodes[node.index()].parent().index(); + let idx = self.nodes[node_idx.index()].parent().index(); op(&mut self.nodes[idx]) } - fn grand_parent_mut(&'a mut self, node: NodeIndex, op: F) -> R + fn grand_parent_mut(&'a mut self, node_idx: NodeIndex, op: F) -> R where R: 'a, F: FnOnce(&'a mut Node) -> R, { - let parent_idx = self.nodes[node.index()].parent().index(); + let parent_idx = self.nodes[node_idx.index()].parent().index(); let grand_parent_idx = self.nodes[parent_idx].parent().index(); op(&mut self.nodes[grand_parent_idx]) } - fn max(&self, node: NodeIndex) -> Option<&T> { - let max_index = self.nodes[node.index()].max_index?.index(); + pub(crate) fn max(&self, node_idx: NodeIndex) -> Option<&T> { + let max_index = self.nodes[node_idx.index()].max_index?.index(); self.nodes[max_index].interval.as_ref().map(|i| &i.high) } } -/// An iterator over the entries of a `IntervalMap`. -#[derive(Debug)] -pub struct Iter<'a, T, V, Ix> { - /// Reference to the map - map_ref: &'a IntervalMap, - /// Stack for iteration - stack: Option>>, +#[cfg(test)] +#[derive(Debug, PartialEq, Eq)] +pub struct VisitedInterval { + key: Interval, + left: Option>, + right: Option>, + color: Color, + depth: i32, } -impl Iter<'_, T, V, Ix> -where - Ix: IndexType, -{ - /// Initializes the stack - fn init_stack(&mut self) { - self.stack = Some(Self::left_link(self.map_ref, self.map_ref.root)); - } - - /// Pushes a link of nodes on the left to stack. - fn left_link(map_ref: &IntervalMap, mut x: NodeIndex) -> Vec> { - let mut nodes = vec![]; - while !map_ref.node_ref(x, Node::is_sentinel) { - nodes.push(x); - x = map_ref.node_ref(x, Node::left); +#[cfg(test)] +impl VisitedInterval { + pub fn new( + key: Interval, + left: Option>, + right: Option>, + color: Color, + depth: i32, + ) -> Self { + Self { + key, + left, + right, + color, + depth, } - nodes } } -impl<'a, T, V, Ix> Iterator for Iter<'a, T, V, Ix> +#[cfg(test)] +impl IntervalMap where + T: Ord + Clone, Ix: IndexType, { - type Item = (&'a Interval, &'a V); - - #[inline] - fn next(&mut self) -> Option { - if self.stack.is_none() { - self.init_stack(); - } - let stack = self.stack.as_mut().unwrap(); - if stack.is_empty() { - return None; + fn visit_level(&self) -> Vec> { + let mut res: Vec> = Vec::new(); + let mut queue = VecDeque::new(); + queue.push_back(self.root); + let mut depth = 0; + while !queue.is_empty() { + for _ in 0..queue.len() { + let p = queue.pop_front().unwrap(); + let node = &self.nodes[p.index()]; + let p_left_node = &self.nodes[node.left().index()]; + let p_right_node = &self.nodes[node.right().index()]; + + res.push(VisitedInterval { + key: node.interval.clone().unwrap(), + left: p_left_node.interval.clone(), + right: p_right_node.interval.clone(), + color: node.color(), + depth, + }); + if !p_left_node.is_sentinel() { + queue.push_back(node.left()) + } + if !p_right_node.is_sentinel() { + queue.push_back(node.right()) + } + } + depth += 1; } - let x = stack.pop().unwrap(); - stack.extend(Self::left_link( - self.map_ref, - self.map_ref.node_ref(x, Node::right), - )); - Some(self.map_ref.node_ref(x, |xn| (xn.interval(), xn.value()))) + res } } #[cfg(test)] mod test { - use std::collections::HashSet; - - use rand::{rngs::StdRng, Rng, SeedableRng}; - use super::*; - struct IntervalGenerator { - rng: StdRng, - unique: HashSet>, - limit: i32, - } - - impl IntervalGenerator { - fn new(seed: [u8; 32]) -> Self { - const LIMIT: i32 = 1000; - Self { - rng: SeedableRng::from_seed(seed), - unique: HashSet::new(), - limit: LIMIT, - } - } - - fn next(&mut self) -> Interval { - let low = self.rng.gen_range(0..self.limit - 1); - let high = self.rng.gen_range((low + 1)..self.limit); - Interval::new(low, high) - } - - fn next_unique(&mut self) -> Interval { - let mut interval = self.next(); - while self.unique.contains(&interval) { - interval = self.next(); - } - self.unique.insert(interval.clone()); - interval - } - - fn next_with_range(&mut self, range: i32) -> Interval { - let low = self.rng.gen_range(0..self.limit - 1); - let high = self - .rng - .gen_range((low + 1)..self.limit.min(low + 1 + range)); - Interval::new(low, high) - } - } - - impl IntervalMap { - fn check_max(&self) { - let _ignore = self.check_max_inner(self.root); - } - - fn check_max_inner(&self, x: NodeIndex) -> i32 { - if self.node_ref(x, Node::is_sentinel) { - return 0; - } - let l_max = self.check_max_inner(self.node_ref(x, Node::left)); - let r_max = self.check_max_inner(self.node_ref(x, Node::right)); - let max = self.node_ref(x, |x| x.interval().high.max(l_max).max(r_max)); - assert_eq!(self.max(x), Some(&max)); - max - } - - /// 1. Every node is either red or black. - /// 2. The root is black. - /// 3. Every leaf (NIL) is black. - /// 4. If a node is red, then both its children are black. - /// 5. For each node, all simple paths from the node to descendant leaves contain the - /// same number of black nodes. - fn check_rb_properties(&self) { - assert!(matches!( - self.node_ref(self.root, Node::color), - Color::Black - )); - self.check_children_color(self.root); - self.check_black_height(self.root); - } - - fn check_children_color(&self, x: NodeIndex) { - if self.node_ref(x, Node::is_sentinel) { - return; - } - self.check_children_color(self.node_ref(x, Node::left)); - self.check_children_color(self.node_ref(x, Node::right)); - if self.node_ref(x, Node::is_red) { - assert!(matches!(self.left_ref(x, Node::color), Color::Black)); - assert!(matches!(self.right_ref(x, Node::color), Color::Black)); - } - } - - fn check_black_height(&self, x: NodeIndex) -> usize { - if self.node_ref(x, Node::is_sentinel) { - return 0; - } - let lefth = self.check_black_height(self.node_ref(x, Node::left)); - let righth = self.check_black_height(self.node_ref(x, Node::right)); - assert_eq!(lefth, righth); - if self.node_ref(x, Node::is_black) { - return lefth + 1; - } - lefth - } + #[test] + fn test_interval_tree_insert() { + let mut map = IntervalMap::new(); + map.insert(Interval::new(16, 21), 30); + map.insert(Interval::new(8, 9), 23); + map.insert(Interval::new(0, 3), 3); + map.insert(Interval::new(5, 8), 10); + map.insert(Interval::new(6, 10), 10); + map.insert(Interval::new(15, 23), 23); + map.insert(Interval::new(17, 19), 20); + map.insert(Interval::new(25, 30), 30); + map.insert(Interval::new(26, 27), 26); + map.insert(Interval::new(19, 20), 20); + + let expected = vec![ + VisitedInterval::new( + Interval::new(16, 21), + Some(Interval::new(8, 9)), + Some(Interval::new(25, 30)), + Color::Black, + 0, + ), + VisitedInterval::new( + Interval::new(8, 9), + Some(Interval::new(5, 8)), + Some(Interval::new(15, 23)), + Color::Red, + 1, + ), + VisitedInterval::new( + Interval::new(25, 30), + Some(Interval::new(17, 19)), + Some(Interval::new(26, 27)), + Color::Red, + 1, + ), + VisitedInterval::new( + Interval::new(5, 8), + Some(Interval::new(0, 3)), + Some(Interval::new(6, 10)), + Color::Black, + 2, + ), + VisitedInterval::new(Interval::new(15, 23), None, None, Color::Black, 2), + VisitedInterval::new( + Interval::new(17, 19), + None, + Some(Interval::new(19, 20)), + Color::Black, + 2, + ), + VisitedInterval::new(Interval::new(26, 27), None, None, Color::Black, 2), + VisitedInterval::new(Interval::new(0, 3), None, None, Color::Red, 3), + VisitedInterval::new(Interval::new(6, 10), None, None, Color::Red, 3), + VisitedInterval::new(Interval::new(19, 20), None, None, Color::Red, 3), + ]; + + let res = map.visit_level(); + assert_eq!(res, expected); } - fn with_map_and_generator(test_fn: impl Fn(IntervalMap, IntervalGenerator)) { - let seeds = vec![[0; 32], [1; 32], [2; 32]]; - for seed in seeds { - let gen = IntervalGenerator::new(seed); - let map = IntervalMap::new(); - test_fn(map, gen); - } + #[test] + fn test_interval_tree_self_balanced() { + let mut map = IntervalMap::new(); + map.insert(Interval::new(0, 1), 0); + map.insert(Interval::new(1, 2), 0); + map.insert(Interval::new(3, 4), 0); + map.insert(Interval::new(5, 6), 0); + map.insert(Interval::new(7, 8), 0); + map.insert(Interval::new(8, 9), 0); + + let expected = vec![ + VisitedInterval::new( + Interval::new(1, 2), + Some(Interval::new(0, 1)), + Some(Interval::new(5, 6)), + Color::Black, + 0, + ), + VisitedInterval::new(Interval::new(0, 1), None, None, Color::Black, 1), + VisitedInterval::new( + Interval::new(5, 6), + Some(Interval::new(3, 4)), + Some(Interval::new(7, 8)), + Color::Red, + 1, + ), + VisitedInterval::new(Interval::new(3, 4), None, None, Color::Black, 2), + VisitedInterval::new( + Interval::new(7, 8), + None, + Some(Interval::new(8, 9)), + Color::Black, + 2, + ), + VisitedInterval::new(Interval::new(8, 9), None, None, Color::Red, 3), + ]; + + let res = map.visit_level(); + assert_eq!(res, expected); } #[test] - fn red_black_tree_properties_is_satisfied() { - with_map_and_generator(|mut map, mut gen| { - let intervals: Vec<_> = std::iter::repeat_with(|| gen.next_unique()) - .take(1000) - .collect(); - for i in intervals.clone() { - let _ignore = map.insert(i, ()); - } - map.check_rb_properties(); - }); + fn test_interval_tree_delete() { + let mut map = IntervalMap::new(); + map.insert(Interval::new(510, 511), 0); + map.insert(Interval::new(82, 83), 0); + map.insert(Interval::new(830, 831), 0); + map.insert(Interval::new(11, 12), 0); + map.insert(Interval::new(383, 384), 0); + map.insert(Interval::new(647, 648), 0); + map.insert(Interval::new(899, 900), 0); + map.insert(Interval::new(261, 262), 0); + map.insert(Interval::new(410, 411), 0); + map.insert(Interval::new(514, 515), 0); + map.insert(Interval::new(815, 816), 0); + map.insert(Interval::new(888, 889), 0); + map.insert(Interval::new(972, 973), 0); + map.insert(Interval::new(238, 239), 0); + map.insert(Interval::new(292, 293), 0); + map.insert(Interval::new(953, 954), 0); + + let expected_before_delete = vec![ + VisitedInterval::new( + Interval::new(510, 511), + Some(Interval::new(82, 83)), + Some(Interval::new(830, 831)), + Color::Black, + 0, + ), + VisitedInterval::new( + Interval::new(82, 83), + Some(Interval::new(11, 12)), + Some(Interval::new(383, 384)), + Color::Black, + 1, + ), + VisitedInterval::new( + Interval::new(830, 831), + Some(Interval::new(647, 648)), + Some(Interval::new(899, 900)), + Color::Black, + 1, + ), + VisitedInterval::new(Interval::new(11, 12), None, None, Color::Black, 2), + VisitedInterval::new( + Interval::new(383, 384), + Some(Interval::new(261, 262)), + Some(Interval::new(410, 411)), + Color::Red, + 2, + ), + VisitedInterval::new( + Interval::new(647, 648), + Some(Interval::new(514, 515)), + Some(Interval::new(815, 816)), + Color::Black, + 2, + ), + VisitedInterval::new( + Interval::new(899, 900), + Some(Interval::new(888, 889)), + Some(Interval::new(972, 973)), + Color::Red, + 2, + ), + VisitedInterval::new( + Interval::new(261, 262), + Some(Interval::new(238, 239)), + Some(Interval::new(292, 293)), + Color::Black, + 3, + ), + VisitedInterval::new(Interval::new(410, 411), None, None, Color::Black, 3), + VisitedInterval::new(Interval::new(514, 515), None, None, Color::Red, 3), + VisitedInterval::new(Interval::new(815, 816), None, None, Color::Red, 3), + VisitedInterval::new(Interval::new(888, 889), None, None, Color::Black, 3), + VisitedInterval::new( + Interval::new(972, 973), + Some(Interval::new(953, 954)), + None, + Color::Black, + 3, + ), + VisitedInterval::new(Interval::new(238, 239), None, None, Color::Red, 4), + VisitedInterval::new(Interval::new(292, 293), None, None, Color::Red, 4), + VisitedInterval::new(Interval::new(953, 954), None, None, Color::Red, 4), + ]; + + let res = map.visit_level(); + assert_eq!(res, expected_before_delete); + + // delete the node "514" + let range514 = Interval::new(514, 515); + let deleted = map.remove(&range514); + assert!(deleted.is_some()); + + let expected_after_delete514 = vec![ + VisitedInterval::new( + Interval::new(510, 511), + Some(Interval::new(82, 83)), + Some(Interval::new(830, 831)), + Color::Black, + 0, + ), + VisitedInterval::new( + Interval::new(82, 83), + Some(Interval::new(11, 12)), + Some(Interval::new(383, 384)), + Color::Black, + 1, + ), + VisitedInterval::new( + Interval::new(830, 831), + Some(Interval::new(647, 648)), + Some(Interval::new(899, 900)), + Color::Black, + 1, + ), + VisitedInterval::new(Interval::new(11, 12), None, None, Color::Black, 2), + VisitedInterval::new( + Interval::new(383, 384), + Some(Interval::new(261, 262)), + Some(Interval::new(410, 411)), + Color::Red, + 2, + ), + VisitedInterval::new( + Interval::new(647, 648), + None, + Some(Interval::new(815, 816)), + Color::Black, + 2, + ), + VisitedInterval::new( + Interval::new(899, 900), + Some(Interval::new(888, 889)), + Some(Interval::new(972, 973)), + Color::Red, + 2, + ), + VisitedInterval::new( + Interval::new(261, 262), + Some(Interval::new(238, 239)), + Some(Interval::new(292, 293)), + Color::Black, + 3, + ), + VisitedInterval::new(Interval::new(410, 411), None, None, Color::Black, 3), + VisitedInterval::new(Interval::new(815, 816), None, None, Color::Red, 3), + VisitedInterval::new(Interval::new(888, 889), None, None, Color::Black, 3), + VisitedInterval::new( + Interval::new(972, 973), + Some(Interval::new(953, 954)), + None, + Color::Black, + 3, + ), + VisitedInterval::new(Interval::new(238, 239), None, None, Color::Red, 4), + VisitedInterval::new(Interval::new(292, 293), None, None, Color::Red, 4), + VisitedInterval::new(Interval::new(953, 954), None, None, Color::Red, 4), + ]; + + let res = map.visit_level(); + assert_eq!(res, expected_after_delete514); + + // delete the node "11" + let range11 = Interval::new(11, 12); + let deleted = map.remove(&range11); + assert!(deleted.is_some()); + + let expected_after_delete11 = vec![ + VisitedInterval::new( + Interval::new(510, 511), + Some(Interval::new(383, 384)), + Some(Interval::new(830, 831)), + Color::Black, + 0, + ), + VisitedInterval::new( + Interval::new(383, 384), + Some(Interval::new(261, 262)), + Some(Interval::new(410, 411)), + Color::Black, + 1, + ), + VisitedInterval::new( + Interval::new(830, 831), + Some(Interval::new(647, 648)), + Some(Interval::new(899, 900)), + Color::Black, + 1, + ), + VisitedInterval::new( + Interval::new(261, 262), + Some(Interval::new(82, 83)), + Some(Interval::new(292, 293)), + Color::Red, + 2, + ), + VisitedInterval::new(Interval::new(410, 411), None, None, Color::Black, 2), + VisitedInterval::new( + Interval::new(647, 648), + None, + Some(Interval::new(815, 816)), + Color::Black, + 2, + ), + VisitedInterval::new( + Interval::new(899, 900), + Some(Interval::new(888, 889)), + Some(Interval::new(972, 973)), + Color::Red, + 2, + ), + VisitedInterval::new( + Interval::new(82, 83), + None, + Some(Interval::new(238, 239)), + Color::Black, + 3, + ), + VisitedInterval::new(Interval::new(292, 293), None, None, Color::Black, 3), + VisitedInterval::new(Interval::new(815, 816), None, None, Color::Red, 3), + VisitedInterval::new(Interval::new(888, 889), None, None, Color::Black, 3), + VisitedInterval::new( + Interval::new(972, 973), + Some(Interval::new(953, 954)), + None, + Color::Black, + 3, + ), + VisitedInterval::new(Interval::new(238, 239), None, None, Color::Red, 4), + VisitedInterval::new(Interval::new(953, 954), None, None, Color::Red, 4), + ]; + + let res = map.visit_level(); + assert_eq!(res, expected_after_delete11); } - #[test] - fn map_len_will_update() { - with_map_and_generator(|mut map, mut gen| { - let intervals: Vec<_> = std::iter::repeat_with(|| gen.next_unique()) - .take(100) - .collect(); - for i in intervals.clone() { - let _ignore = map.insert(i, ()); - } - assert_eq!(map.len(), 100); - for i in intervals { - let _ignore = map.remove(&i); + impl Interval { + fn new_point(x: &str) -> Interval { + let mut hx = x.to_owned(); + hx.push('\0'); + Interval { + low: x.to_owned(), + high: hx, } - assert_eq!(map.len(), 0); - }); + } } #[test] - fn check_overlap_is_ok() { - with_map_and_generator(|mut map, mut gen| { - let intervals: Vec<_> = std::iter::repeat_with(|| gen.next_with_range(10)) - .take(100) - .collect(); - for i in intervals.clone() { - let _ignore = map.insert(i, ()); - } - let to_check: Vec<_> = std::iter::repeat_with(|| gen.next_with_range(10)) - .take(1000) - .collect(); - let expects: Vec<_> = to_check - .iter() - .map(|ci| intervals.iter().any(|i| ci.overlap(i))) - .collect(); - - for (ci, expect) in to_check.into_iter().zip(expects.into_iter()) { - assert_eq!(map.overlap(&ci), expect); - } - }); + fn test_interval_tree_intersects() { + let mut map = IntervalMap::::new(); + map.insert(Interval::new(String::from("1"), String::from("3")), 123); + + assert!(!map.overlaps(&Interval::new_point("0")), "contains 0"); + assert!(map.overlaps(&Interval::new_point("1")), "missing 1"); + assert!(map.overlaps(&Interval::new_point("11")), "missing 11"); + assert!(map.overlaps(&Interval::new_point("2")), "missing 2"); + assert!(!map.overlaps(&Interval::new_point("3")), "contains 3"); } #[test] - fn check_max_is_ok() { - with_map_and_generator(|mut map, mut gen| { - let intervals: Vec<_> = std::iter::repeat_with(|| gen.next_unique()) - .take(1000) - .collect(); - for i in intervals.clone() { - let _ignore = map.insert(i, ()); - map.check_max(); - } - assert_eq!(map.len(), 1000); - for i in intervals { - let _ignore = map.remove(&i); - map.check_max(); - } - }); + fn test_interval_tree_find_all_overlap() { + let mut map = IntervalMap::::new(); + map.insert(Interval::new(String::from("0"), String::from("1")), 123); + map.insert(Interval::new(String::from("0"), String::from("2")), 456); + map.insert(Interval::new(String::from("5"), String::from("6")), 789); + map.insert(Interval::new(String::from("6"), String::from("8")), 999); + map.insert(Interval::new(String::from("0"), String::from("3")), 0); + + let tmp = map.node_ref(map.node_ref(map.root, Node::max_index), Node::interval); + assert_eq!(tmp, &Interval::new(String::from("6"), String::from("8"))); + + assert_eq!(map.find_all_overlap(&Interval::new_point("0")).len(), 3); + assert_eq!(map.find_all_overlap(&Interval::new_point("1")).len(), 2); + assert_eq!(map.find_all_overlap(&Interval::new_point("2")).len(), 1); + assert_eq!(map.find_all_overlap(&Interval::new_point("3")).len(), 0); + assert_eq!(map.find_all_overlap(&Interval::new_point("5")).len(), 1); + assert_eq!(map.find_all_overlap(&Interval::new_point("55")).len(), 1); + assert_eq!(map.find_all_overlap(&Interval::new_point("6")).len(), 1); } + type TestCaseBFn = dyn Fn(&(&Interval, &())) -> bool; + struct TestCaseB { + f: Box, + wcount: i32, + } #[test] - fn remove_non_exist_interval_will_do_nothing() { - with_map_and_generator(|mut map, mut gen| { - let intervals: Vec<_> = std::iter::repeat_with(|| gen.next_unique()) - .take(1000) - .collect(); - for i in intervals { - let _ignore = map.insert(i, ()); - } - assert_eq!(map.len(), 1000); - let to_remove: Vec<_> = std::iter::repeat_with(|| gen.next_unique()) - .take(1000) - .collect(); - for i in to_remove { - let _ignore = map.remove(&i); - } - assert_eq!(map.len(), 1000); - }); + fn test_interval_tree_visit_exit() { + let ivls = vec![ + Interval::new(1, 10), + Interval::new(2, 5), + Interval::new(3, 6), + Interval::new(4, 8), + ]; + let ivl_range = Interval::new(0, 100); + + let tests = [ + TestCaseB { + f: Box::new(|_| false), + wcount: 1, + }, + TestCaseB { + f: Box::new({ + let ivls = ivls.clone(); + move |v| v.0.low <= ivls[0].low + }), + wcount: 2, + }, + TestCaseB { + f: Box::new({ + let ivls = ivls.clone(); + move |v| v.0.low < ivls[2].low + }), + wcount: 3, + }, + TestCaseB { + f: Box::new(|_| true), + wcount: 4, + }, + ]; + + for (i, tt) in tests.iter().enumerate() { + let mut map = IntervalMap::new(); + ivls.iter().for_each(|v| { + map.insert(v.clone(), ()); + }); + let mut count = 0; + map.filter_iter(&ivl_range).find(|v| { + count += 1; + !(tt.f)(v) + }); + assert_eq!(count, tt.wcount, "#{}: error", i); + } } - #[test] - fn find_all_overlap_is_ok() { - with_map_and_generator(|mut map, mut gen| { - let intervals: Vec<_> = std::iter::repeat_with(|| gen.next_unique()) - .take(1000) - .collect(); - for i in intervals.clone() { - let _ignore = map.insert(i, ()); - } - let to_find: Vec<_> = std::iter::repeat_with(|| gen.next()).take(1000).collect(); + struct TestCaseC { + ivls: Vec>, + chk_ivl: Interval, - let expects: Vec> = to_find - .iter() - .map(|ti| intervals.iter().filter(|i| ti.overlap(i)).collect()) - .collect(); - - for (ti, mut expect) in to_find.into_iter().zip(expects.into_iter()) { - let mut result = map.find_all_overlap(&ti); - expect.sort_unstable(); - result.sort_unstable(); - assert_eq!(expect.len(), result.len()); - for (e, r) in expect.into_iter().zip(result.into_iter()) { - assert_eq!(e, r.0); - } - } - }); + w_contains: bool, } - #[test] - fn iterate_through_map_is_sorted() { - with_map_and_generator(|mut map, mut gen| { - let mut intervals: Vec<_> = std::iter::repeat_with(|| gen.next_unique()) - .enumerate() - .take(1000) - .collect(); - for (v, i) in intervals.clone() { - let _ignore = map.insert(i, v); - } - intervals.sort_unstable_by(|a, b| a.1.cmp(&b.1)); - - for ((ei, ev), (v, i)) in map.iter().zip(intervals.iter()) { - assert_eq!(ei, i); - assert_eq!(ev, v); - } - }); + fn test_interval_tree_contains() { + let tests = [ + TestCaseC { + ivls: vec![Interval::new(1, 10)], + chk_ivl: Interval::new(0, 100), + + w_contains: false, + }, + TestCaseC { + ivls: vec![Interval::new(1, 10)], + chk_ivl: Interval::new(1, 10), + + w_contains: true, + }, + TestCaseC { + ivls: vec![Interval::new(1, 10)], + chk_ivl: Interval::new(2, 8), + + w_contains: true, + }, + TestCaseC { + ivls: vec![Interval::new(1, 5), Interval::new(6, 10)], + chk_ivl: Interval::new(1, 10), + + w_contains: false, + }, + TestCaseC { + ivls: vec![Interval::new(1, 5), Interval::new(3, 10)], + chk_ivl: Interval::new(1, 10), + + w_contains: true, + }, + TestCaseC { + ivls: vec![ + Interval::new(1, 4), + Interval::new(4, 7), + Interval::new(3, 10), + ], + chk_ivl: Interval::new(1, 10), + + w_contains: true, + }, + TestCaseC { + ivls: vec![], + chk_ivl: Interval::new(1, 10), + + w_contains: false, + }, + ]; + for (i, tt) in tests.iter().enumerate() { + let mut map = IntervalMap::new(); + tt.ivls.iter().for_each(|v| { + map.insert(v.clone(), ()); + }); + assert_eq!(map.contains(&tt.chk_ivl), tt.w_contains, "#{}: error", i); + } } + struct TestCaseA { + ivls: Vec>, + visit_range: Interval, + } #[test] - fn interval_map_clear_is_ok() { - let mut map = IntervalMap::new(); - map.insert(Interval::new(1, 3), 1); - map.insert(Interval::new(2, 4), 2); - map.insert(Interval::new(6, 7), 3); - assert_eq!(map.len(), 3); - map.clear(); - assert_eq!(map.len(), 0); - assert!(map.is_empty()); - assert_eq!(map.nodes.len(), 1); - assert!(map.nodes[0].is_sentinel()); + fn test_interval_tree_sorted_visit() { + let tests = [ + TestCaseA { + ivls: vec![ + Interval::new(1, 10), + Interval::new(2, 5), + Interval::new(3, 6), + ], + visit_range: Interval::new(0, 100), + }, + TestCaseA { + ivls: vec![ + Interval::new(1, 10), + Interval::new(10, 12), + Interval::new(3, 6), + ], + visit_range: Interval::new(0, 100), + }, + TestCaseA { + ivls: vec![ + Interval::new(2, 3), + Interval::new(3, 4), + Interval::new(6, 7), + Interval::new(5, 6), + ], + visit_range: Interval::new(0, 100), + }, + TestCaseA { + ivls: vec![ + Interval::new(2, 3), + Interval::new(2, 4), + Interval::new(3, 7), + Interval::new(2, 5), + Interval::new(3, 8), + Interval::new(3, 5), + ], + visit_range: Interval::new(0, 100), + }, + ]; + for (i, tt) in tests.iter().enumerate() { + let mut map = IntervalMap::new(); + tt.ivls.iter().for_each(|v| { + map.insert(v.clone(), ()); + }); + let mut last = tt.ivls[0].low; + let count = map + .iter() + .filter(|v| v.0.overlap(&tt.visit_range)) + .fold(0, |acc, v| { + assert!( + last <= v.0.low, + "#{}: expected less than {}, got interval {:?}", + i, + last, + v.0 + ); + last = v.0.low; + acc + 1 + }); + assert_eq!(count, tt.ivls.len(), "#{}: did not cover all intervals.", i); + } } } diff --git a/src/iter.rs b/src/iter.rs new file mode 100644 index 0000000..19b8915 --- /dev/null +++ b/src/iter.rs @@ -0,0 +1,248 @@ +use std::fmt::Debug; + +use crate::index::{IndexType, NodeIndex}; +use crate::interval::Interval; +use crate::intervalmap::IntervalMap; +use crate::node::Node; + +/// Pushes a link of nodes on the left to stack. +fn left_link(map_ref: &IntervalMap, mut x: NodeIndex) -> Vec> +where + T: Ord, + Ix: IndexType, +{ + let mut nodes = vec![]; + while !map_ref.node_ref(x, Node::is_sentinel) { + nodes.push(x); + x = map_ref.node_ref(x, Node::left); + } + nodes +} + +/// An iterator over the entries of a `IntervalMap`. +#[derive(Debug)] +pub struct Iter<'a, T, V, Ix> +where + T: Ord, +{ + /// Reference to the map + pub(crate) map_ref: &'a IntervalMap, + /// Stack for iteration + pub(crate) stack: Vec>, +} + +impl<'a, T, V, Ix> Iter<'a, T, V, Ix> +where + T: Ord, + Ix: IndexType, +{ + pub fn new(map_ref: &'a IntervalMap) -> Self { + Iter { + map_ref, + stack: left_link(map_ref, map_ref.root), + } + } +} + +impl<'a, T, V, Ix> Iterator for Iter<'a, T, V, Ix> +where + T: Ord, + Ix: IndexType, +{ + type Item = (&'a Interval, &'a V); + + #[inline] + fn next(&mut self) -> Option { + if self.stack.is_empty() { + return None; + } + let x = self.stack.pop().unwrap(); + self.stack.extend(left_link( + self.map_ref, + self.map_ref.node_ref(x, Node::right), + )); + Some(self.map_ref.node_ref(x, |xn| (xn.interval(), xn.value()))) + } +} + +/// An into iterator over the entries of a `IntervalMap`. +#[derive(Debug)] +pub struct IntoIter +where + T: Ord, +{ + interval_map: IntervalMap, + /// Stack for iteration + pub(crate) stack: Vec>, +} + +impl IntoIter +where + T: Ord, + Ix: IndexType, +{ + pub fn new(interval_map: IntervalMap) -> Self { + let mut temp = IntoIter { + interval_map, + stack: vec![], + }; + temp.stack = left_link(&temp.interval_map, temp.interval_map.root); + temp + } +} + +impl Iterator for IntoIter +where + T: Ord, + Ix: IndexType, +{ + type Item = (Interval, V); + + #[inline] + fn next(&mut self) -> Option { + if self.stack.is_empty() { + return None; + } + let x = self.stack.pop().unwrap(); + self.stack.extend(left_link( + &self.interval_map, + self.interval_map.node_ref(x, Node::right), + )); + let res = &mut self.interval_map.nodes[x.index()]; + Some((res.interval.take().unwrap(), res.value.take().unwrap())) + } +} + +/// An unsorted iterator over the entries of a `IntervalMap`. +#[derive(Debug)] +pub struct UnsortedIter<'a, T, V, Ix> +where + T: Ord, +{ + map_ref: &'a IntervalMap, + /// Stack for iteration + pub(crate) cur: NodeIndex, +} + +impl<'a, T, V, Ix> UnsortedIter<'a, T, V, Ix> +where + T: Ord, + Ix: IndexType, +{ + pub fn new(map_ref: &'a IntervalMap) -> Self { + UnsortedIter { + map_ref, + cur: NodeIndex::SENTINEL, + } + } +} + +impl<'a, T, V, Ix> Iterator for UnsortedIter<'a, T, V, Ix> +where + T: Ord, + Ix: IndexType, +{ + type Item = (&'a Interval, &'a V); + + #[inline] + fn next(&mut self) -> Option { + if self.map_ref.is_empty() + || self.cur.index() >= self.map_ref.len() + || self.cur.index() == ::max().index() + { + return None; + } + self.cur = self.cur.incre(); + Some( + self.map_ref + .node_ref(self.cur, |xn| (xn.interval(), xn.value())), + ) + } +} + +/// A filter iterator over the entries of a `IntervalMap`.It's equal to `iter().filter()` +/// but faster than the latter. +#[derive(Debug)] +pub struct FilterIter<'a, T, V, Ix> +where + T: Ord, +{ + /// Reference to the map + pub(crate) map_ref: &'a IntervalMap, + /// Stack for iteration + pub(crate) stack: Vec>, + /// Filter criteria + pub(crate) query: &'a Interval, +} + +fn left_link_with_query( + map_ref: &IntervalMap, + mut x: NodeIndex, + query: &Interval, +) -> Vec> +where + T: Ord, + Ix: IndexType, +{ + let mut stack: Vec> = vec![]; + if map_ref.max(x).is_some_and(|v| v <= &query.low) { + return stack; + } + while map_ref.node_ref(x, Node::sentinel).is_some() { + if map_ref.node_ref(x, Node::interval).low < query.high { + stack.push(x); + } + if map_ref.max(map_ref.node_ref(x, Node::left)) <= Some(&query.low) { + break; + } + x = map_ref.node_ref(x, Node::left); + } + stack +} + +impl<'a, T, V, Ix> FilterIter<'a, T, V, Ix> +where + T: Ord, + Ix: IndexType, +{ + pub fn new(map_ref: &'a IntervalMap, query: &'a Interval) -> Self { + FilterIter { + map_ref, + stack: left_link_with_query(map_ref, map_ref.root, query), + query, + } + } +} + +impl<'a, T, V, Ix> Iterator for FilterIter<'a, T, V, Ix> +where + T: Ord, + Ix: IndexType, +{ + type Item = (&'a Interval, &'a V); + + #[inline] + fn next(&mut self) -> Option { + if self.stack.is_empty() { + return None; + } + let mut x = self.stack.pop().unwrap(); + while !self.map_ref.node_ref(x, Node::interval).overlap(self.query) { + self.stack.extend(left_link_with_query( + self.map_ref, + self.map_ref.node_ref(x, Node::right), + self.query, + )); + if self.stack.is_empty() { + return None; + } + x = self.stack.pop().unwrap(); + } + self.stack.extend(left_link_with_query( + self.map_ref, + self.map_ref.node_ref(x, Node::right), + self.query, + )); + Some(self.map_ref.node_ref(x, |xn| (xn.interval(), xn.value()))) + } +} diff --git a/src/lib.rs b/src/lib.rs index 13e5141..b0235a9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,8 +25,13 @@ mod entry; mod index; mod interval; mod intervalmap; +mod iter; mod node; +#[cfg(test)] +mod tests; + pub use entry::{Entry, OccupiedEntry, VacantEntry}; pub use interval::Interval; -pub use intervalmap::{IntervalMap, Iter}; +pub use intervalmap::IntervalMap; +pub use iter::Iter; diff --git a/src/node.rs b/src/node.rs index 63cbb03..38f1d51 100644 --- a/src/node.rs +++ b/src/node.rs @@ -1,10 +1,19 @@ +use crate::index::{IndexType, NodeIndex}; use crate::interval::Interval; -use crate::index::{IndexType, NodeIndex}; +#[cfg(feature = "graphviz")] +use std::fmt::Display; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] /// Node of the interval tree #[derive(Debug)] -pub struct Node { +pub struct Node +where + T: Ord, +{ /// Left children pub left: Option>, /// Right children @@ -22,103 +31,261 @@ pub struct Node { pub value: Option, } -// Convenient getter/setter methods impl Node where - Ix: IndexType, + T: Ord, { - pub fn color(&self) -> Color { - self.color - } - pub fn interval(&self) -> &Interval { self.interval.as_ref().unwrap() } - pub fn max_index(&self) -> NodeIndex { + pub fn value(&self) -> &V { + self.value.as_ref().unwrap() + } + pub fn value_mut(&mut self) -> &mut V { + self.value.as_mut().unwrap() + } + pub fn take_value(&mut self) -> V { + self.value.take().unwrap() + } + + pub fn set_value(value: V) -> impl FnOnce(&mut Node) -> V { + move |node: &mut Node| node.value.replace(value).unwrap() + } +} + +// Convenient getter/setter methods +impl Node +where + T: Ord, + Ix: IndexType, +{ + pub(crate) fn max_index(&self) -> NodeIndex { self.max_index.unwrap() } - pub fn left(&self) -> NodeIndex { - self.left.unwrap() + pub(crate) fn set_max_index(max_index: NodeIndex) -> impl FnOnce(&mut Node) { + move |node: &mut Node| { + let _ignore = node.max_index.replace(max_index); + } } - pub fn right(&self) -> NodeIndex { - self.right.unwrap() + pub(crate) fn left(&self) -> NodeIndex { + self.left.unwrap() } - pub fn parent(&self) -> NodeIndex { - self.parent.unwrap() + pub(crate) fn set_left(left: NodeIndex) -> impl FnOnce(&mut Node) { + move |node: &mut Node| { + let _ignore = node.left.replace(left); + } } - pub fn is_sentinel(&self) -> bool { - self.interval.is_none() + pub(crate) fn right(&self) -> NodeIndex { + self.right.unwrap() } - pub fn sentinel(&self) -> Option<&Self> { - self.interval.is_some().then_some(self) + pub(crate) fn set_right(right: NodeIndex) -> impl FnOnce(&mut Node) { + move |node: &mut Node| { + let _ignore = node.right.replace(right); + } } - pub fn is_black(&self) -> bool { - matches!(self.color, Color::Black) + pub(crate) fn parent(&self) -> NodeIndex { + self.parent.unwrap() } - pub fn is_red(&self) -> bool { - matches!(self.color, Color::Red) + pub(crate) fn set_parent(parent: NodeIndex) -> impl FnOnce(&mut Node) { + move |node: &mut Node| { + let _ignore = node.parent.replace(parent); + } + } + pub(crate) fn is_sentinel(&self) -> bool { + self.interval.is_none() } - pub fn value(&self) -> &V { - self.value.as_ref().unwrap() + pub(crate) fn sentinel(&self) -> Option<&Self> { + self.interval.is_some().then_some(self) } - pub fn value_mut(&mut self) -> &mut V { - self.value.as_mut().unwrap() + pub(crate) fn color(&self) -> Color { + self.color } - pub fn take_value(&mut self) -> V { - self.value.take().unwrap() + pub(crate) fn is_black(&self) -> bool { + matches!(self.color, Color::Black) } - pub fn set_value(value: V) -> impl FnOnce(&mut Node) -> V { - move |node: &mut Node| node.value.replace(value).unwrap() + pub(crate) fn is_red(&self) -> bool { + matches!(self.color, Color::Red) } - pub fn set_color(color: Color) -> impl FnOnce(&mut Node) { + pub(crate) fn set_color(color: Color) -> impl FnOnce(&mut Node) { move |node: &mut Node| { node.color = color; } } +} - pub fn set_max_index(max_index: NodeIndex) -> impl FnOnce(&mut Node) { - move |node: &mut Node| { - let _ignore = node.max_index.replace(max_index); +#[cfg(feature = "graphviz")] +impl Node +where + T: Ord + Display, + V: Display, + Ix: IndexType, +{ + pub(crate) fn draw( + &self, + index: usize, + mut writer: W, + ) -> std::io::Result<()> { + writeln!( + writer, + " {} [label=\"i={}\\n{}: {}\\n\", fillcolor={}, style=filled]", + index, + index, + self.interval.as_ref().unwrap(), + self.value.as_ref().unwrap(), + if self.is_red() { "salmon" } else { "grey65" } + )?; + if !self.left.unwrap().is_sentinel() { + writeln!( + writer, + " {} -> {} [label=\"L\"]", + index, + self.left.unwrap().index() + )?; } + if !self.right.unwrap().is_sentinel() { + writeln!( + writer, + " {} -> {} [label=\"R\"]", + index, + self.right.unwrap().index() + )?; + } + Ok(()) } +} - pub fn set_left(left: NodeIndex) -> impl FnOnce(&mut Node) { - move |node: &mut Node| { - let _ignore = node.left.replace(left); +#[cfg(feature = "graphviz")] +impl Node +where + T: Display + Ord, + Ix: IndexType, +{ + pub(crate) fn draw_without_value( + &self, + index: usize, + mut writer: W, + ) -> std::io::Result<()> { + writeln!( + writer, + " {} [label=\"i={}: {}\", fillcolor={}, style=filled]", + index, + index, + self.interval.as_ref().unwrap(), + if self.is_red() { "salmon" } else { "grey65" } + )?; + if !self.left.unwrap().is_sentinel() { + writeln!( + writer, + " {} -> {} [label=\"L\"]", + index, + self.left.unwrap().index() + )?; + } + if !self.right.unwrap().is_sentinel() { + writeln!( + writer, + " {} -> {} [label=\"R\"]", + index, + self.right.unwrap().index() + )?; } + Ok(()) } +} - pub fn set_right(right: NodeIndex) -> impl FnOnce(&mut Node) { - move |node: &mut Node| { - let _ignore = node.right.replace(right); +impl Node +where + T: Ord, + Ix: IndexType, +{ + pub fn new(interval: Interval, value: V, index: NodeIndex) -> Self { + Node { + interval: Some(interval), + value: Some(value), + max_index: Some(index), + left: Some(NodeIndex::SENTINEL), + right: Some(NodeIndex::SENTINEL), + parent: Some(NodeIndex::SENTINEL), + color: Color::Red, } } - pub fn set_parent(parent: NodeIndex) -> impl FnOnce(&mut Node) { - move |node: &mut Node| { - let _ignore = node.parent.replace(parent); + pub fn new_sentinel() -> Self { + Node { + interval: None, + value: None, + max_index: None, + left: None, + right: None, + parent: None, + color: Color::Black, } } } /// The color of the node -#[derive(Debug, Clone, Copy)] -pub enum Color { +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum Color { /// Red node Red, /// Black node Black, } + +#[cfg(feature = "serde")] +#[cfg(test)] +mod tests { + use super::*; + use serde_json::{json, Value}; + + #[test] + fn test_node_serialize_deserialize() { + let node = Node:: { + left: Some(NodeIndex::new(0)), + right: Some(NodeIndex::new(1)), + parent: Some(NodeIndex::new(2)), + color: Color::Red, + interval: Some(Interval::new(10, 20)), + max_index: Some(NodeIndex::new(3)), + value: Some(42), + }; + + // Serialize the node to JSON + let serialized = serde_json::to_string(&node).unwrap(); + let expected = json!({ + "left": 0, + "right": 1, + "parent": 2, + "color": "Red", + "interval": [10,20], + "max_index": 3, + "value": 42 + }); + let actual: Value = serde_json::from_str(&serialized).unwrap(); + assert_eq!(expected, actual); + + // Deserialize the node from JSON + let deserialized: Node = serde_json::from_str(&serialized).unwrap(); + assert_eq!(node.left, deserialized.left); + assert_eq!(node.right, deserialized.right); + assert_eq!(node.parent, deserialized.parent); + assert_eq!(node.color, deserialized.color); + assert_eq!(node.interval, deserialized.interval); + assert_eq!(node.max_index, deserialized.max_index); + assert_eq!(node.value, deserialized.value); + } +} diff --git a/src/tests.rs b/src/tests.rs new file mode 100644 index 0000000..50e6013 --- /dev/null +++ b/src/tests.rs @@ -0,0 +1,451 @@ +use std::collections::HashSet; + +use index::NodeIndex; +use node::{Color, Node}; +use rand::{rngs::StdRng, Rng, SeedableRng}; + +use super::*; + +struct IntervalGenerator { + rng: StdRng, + unique: HashSet>, + limit: i32, +} + +impl IntervalGenerator { + fn new(seed: [u8; 32]) -> Self { + const LIMIT: i32 = 1000; + Self { + rng: SeedableRng::from_seed(seed), + unique: HashSet::new(), + limit: LIMIT, + } + } + + fn next(&mut self) -> Interval { + let low = self.rng.gen_range(0..self.limit - 1); + let high = self.rng.gen_range((low + 1)..self.limit); + Interval::new(low, high) + } + + fn next_unique(&mut self) -> Interval { + let mut interval = self.next(); + while self.unique.contains(&interval) { + interval = self.next(); + } + self.unique.insert(interval.clone()); + interval + } + + fn next_with_range(&mut self, range: i32) -> Interval { + let low = self.rng.gen_range(0..self.limit - 1); + let high = self + .rng + .gen_range((low + 1)..self.limit.min(low + 1 + range)); + Interval::new(low, high) + } +} + +impl IntervalMap { + fn check_max(&self) { + let _ignore = self.check_max_inner(self.root); + } + + fn check_max_inner(&self, x: NodeIndex) -> i32 { + if self.node_ref(x, Node::is_sentinel) { + return 0; + } + let l_max = self.check_max_inner(self.node_ref(x, Node::left)); + let r_max = self.check_max_inner(self.node_ref(x, Node::right)); + let max = self.node_ref(x, |x| x.interval().high.max(l_max).max(r_max)); + assert_eq!(self.max(x), Some(&max)); + max + } + + /// 1. Every node is either red or black. + /// 2. The root is black. + /// 3. Every leaf (NIL) is black. + /// 4. If a node is red, then both its children are black. + /// 5. For each node, all simple paths from the node to descendant leaves contain the + /// same number of black nodes. + fn check_rb_properties(&self) { + assert!(matches!( + self.node_ref(self.root, Node::color), + Color::Black + )); + self.check_children_color(self.root); + self.check_black_height(self.root); + } + + fn check_children_color(&self, x: NodeIndex) { + if self.node_ref(x, Node::is_sentinel) { + return; + } + self.check_children_color(self.node_ref(x, Node::left)); + self.check_children_color(self.node_ref(x, Node::right)); + if self.node_ref(x, Node::is_red) { + assert!(matches!(self.left_ref(x, Node::color), Color::Black)); + assert!(matches!(self.right_ref(x, Node::color), Color::Black)); + } + } + + fn check_black_height(&self, x: NodeIndex) -> usize { + if self.node_ref(x, Node::is_sentinel) { + return 0; + } + let lefth = self.check_black_height(self.node_ref(x, Node::left)); + let righth = self.check_black_height(self.node_ref(x, Node::right)); + assert_eq!(lefth, righth); + if self.node_ref(x, Node::is_black) { + return lefth + 1; + } + lefth + } +} + +fn with_map_and_generator(test_fn: impl Fn(IntervalMap, IntervalGenerator)) { + let seeds = vec![[0; 32], [1; 32], [2; 32]]; + for seed in seeds { + let gen = IntervalGenerator::new(seed); + let map = IntervalMap::new(); + test_fn(map, gen); + } +} + +#[test] +fn red_black_tree_properties_is_satisfied() { + with_map_and_generator(|mut map, mut gen| { + let intervals: Vec<_> = std::iter::repeat_with(|| gen.next_unique()) + .take(1000) + .collect(); + for i in intervals.clone() { + let _ignore = map.insert(i, ()); + } + map.check_rb_properties(); + }); +} + +#[test] +fn map_len_will_update() { + with_map_and_generator(|mut map, mut gen| { + let intervals: Vec<_> = std::iter::repeat_with(|| gen.next_unique()) + .take(100) + .collect(); + for i in intervals.clone() { + let _ignore = map.insert(i, ()); + } + assert_eq!(map.len(), 100); + for i in intervals { + let _ignore = map.remove(&i); + } + assert_eq!(map.len(), 0); + }); +} + +#[test] +fn check_overlap_is_ok() { + with_map_and_generator(|mut map, mut gen| { + let intervals: Vec<_> = std::iter::repeat_with(|| gen.next_with_range(10)) + .take(100) + .collect(); + for i in intervals.clone() { + let _ignore = map.insert(i, ()); + } + let to_check: Vec<_> = std::iter::repeat_with(|| gen.next_with_range(10)) + .take(1000) + .collect(); + let expects: Vec<_> = to_check + .iter() + .map(|ci| intervals.iter().any(|i| ci.overlap(i))) + .collect(); + + for (ci, expect) in to_check.into_iter().zip(expects.into_iter()) { + assert_eq!(map.overlaps(&ci), expect); + } + }); +} + +#[test] +fn check_max_is_ok() { + with_map_and_generator(|mut map, mut gen| { + let intervals: Vec<_> = std::iter::repeat_with(|| gen.next_unique()) + .take(1000) + .collect(); + for i in intervals.clone() { + let _ignore = map.insert(i, ()); + map.check_max(); + } + assert_eq!(map.len(), 1000); + for i in intervals { + let _ignore = map.remove(&i); + map.check_max(); + } + }); +} + +#[test] +fn remove_non_exist_interval_will_do_nothing() { + with_map_and_generator(|mut map, mut gen| { + let intervals: Vec<_> = std::iter::repeat_with(|| gen.next_unique()) + .take(1000) + .collect(); + for i in intervals { + let _ignore = map.insert(i, ()); + } + assert_eq!(map.len(), 1000); + let to_remove: Vec<_> = std::iter::repeat_with(|| gen.next_unique()) + .take(1000) + .collect(); + for i in to_remove { + let _ignore = map.remove(&i); + } + assert_eq!(map.len(), 1000); + }); +} + +#[test] +fn find_all_overlap_is_ok() { + with_map_and_generator(|mut map, mut gen| { + let intervals: Vec<_> = std::iter::repeat_with(|| gen.next_unique()) + .take(1000) + .collect(); + for i in intervals.clone() { + let _ignore = map.insert(i, ()); + } + let to_find: Vec<_> = std::iter::repeat_with(|| gen.next()).take(1000).collect(); + + let expects: Vec> = to_find + .iter() + .map(|ti| intervals.iter().filter(|i| ti.overlap(i)).collect()) + .collect(); + + for (ti, mut expect) in to_find.into_iter().zip(expects.into_iter()) { + let mut result = map.find_all_overlap(&ti); + expect.sort_unstable(); + result.sort_unstable(); + assert_eq!(expect.len(), result.len()); + for (e, r) in expect.into_iter().zip(result.into_iter()) { + assert_eq!(e, r.0); + } + } + }); +} + +#[test] +fn iterate_through_map_is_sorted() { + with_map_and_generator(|mut map, mut gen| { + let mut intervals: Vec<_> = std::iter::repeat_with(|| gen.next_unique()) + .enumerate() + .take(1000) + .collect(); + for (v, i) in intervals.clone() { + let _ignore = map.insert(i, v); + } + intervals.sort_unstable_by(|a, b| a.1.cmp(&b.1)); + + for ((ei, ev), (v, i)) in map.iter().zip(intervals.iter()) { + assert_eq!(ei, i); + assert_eq!(ev, v); + } + }); +} + +#[test] +fn interval_map_clear_is_ok() { + let mut map = IntervalMap::new(); + map.insert(Interval::new(1, 3), 1); + map.insert(Interval::new(2, 4), 2); + map.insert(Interval::new(6, 7), 3); + assert_eq!(map.len(), 3); + map.clear(); + assert_eq!(map.len(), 0); + assert!(map.is_empty()); + assert_eq!(map.nodes.len(), 1); + assert!(map.nodes[0].is_sentinel()); +} + +#[cfg(test)] +struct TestCaseFilterIter { + query: Interval, + expected: Vec>, +} + +#[test] +fn interval_map_filter_iter_is_ok() { + let tests = [ + TestCaseFilterIter { + query: Interval::new(50, 51), + expected: vec![Interval::new(6, 99)], + }, + TestCaseFilterIter { + query: Interval::new(23, 26), + expected: vec![Interval::new(6, 99), Interval::new(25, 30)], + }, + TestCaseFilterIter { + query: Interval::new(23, 30), + expected: vec![ + Interval::new(6, 99), + Interval::new(25, 30), + Interval::new(26, 27), + ], + }, + TestCaseFilterIter { + query: Interval::new(6, 17), + expected: vec![ + Interval::new(0, 23), + Interval::new(6, 99), + Interval::new(8, 9), + Interval::new(15, 23), + Interval::new(16, 21), + ], + }, + ]; + + let mut map = IntervalMap::new(); + map.insert(Interval::new(16, 21), 30); + map.insert(Interval::new(8, 9), 23); + map.insert(Interval::new(0, 23), 3); + map.insert(Interval::new(5, 6), 10); + map.insert(Interval::new(6, 99), 10); + map.insert(Interval::new(15, 23), 23); + map.insert(Interval::new(17, 19), 20); + map.insert(Interval::new(25, 30), 30); + map.insert(Interval::new(26, 27), 26); + map.insert(Interval::new(19, 20), 20); + + for (i, tt) in tests.iter().enumerate() { + let v: Vec<_> = map.filter_iter(&tt.query).map(|v| v.0.clone()).collect(); + assert_eq!(v, tt.expected, "#{}: error", i); + } +} + +#[cfg(feature = "graphviz")] +#[test] +fn interval_map_draw_is_ok() { + let mut map = IntervalMap::new(); + map.insert(Interval::new(16, 21), 30); + map.insert(Interval::new(8, 9), 23); + map.insert(Interval::new(0, 23), 3); + map.insert(Interval::new(5, 6), 10); + map.insert(Interval::new(6, 99), 10); + map.insert(Interval::new(15, 23), 23); + map.insert(Interval::new(17, 19), 20); + map.insert(Interval::new(25, 30), 30); + map.insert(Interval::new(26, 27), 26); + map.insert(Interval::new(19, 20), 20); + + let _ = map.draw("./test.dot"); + + let _ = map.draw_without_value("./test.dot"); +} + +#[cfg(feature = "serde")] +#[test] +fn test_serde_interval_map() { + use serde_json::{json, Value}; + + let mut interval_map = IntervalMap::::new(); + interval_map.insert(Interval::new(1, 5), 10); + interval_map.insert(Interval::new(3, 7), 20); + interval_map.insert(Interval::new(2, 6), 15); + + // Serialize the interval map to JSON + let serialized = serde_json::to_string(&interval_map).unwrap(); + let expected = json!({ + "nodes": [ + // sentinel node + { + "left": null, + "right": null, + "parent": null, + "color": "Black", + "interval": null, + "max_index": null, + "value": null + }, + { + "left": 0, + "right": 0, + "parent": 3, + "color": "Red", + "interval": [1,5], + "max_index": 1, + "value": 10 + }, + { + "left": 0, + "right": 0, + "parent": 3, + "color": "Red", + "interval": [3,7], + "max_index": 2, + "value": 20 + }, + { + "left": 1, + "right": 2, + "parent": 0, + "color": "Black", + "interval": [2,6], + "max_index": 2, + "value": 15 + } + ], + "root": 3, + "len": 3 + }); + let actual: Value = serde_json::from_str(&serialized).unwrap(); + assert_eq!(expected, actual); + + // Deserialize the interval map from JSON + let deserialized: IntervalMap = serde_json::from_str(&serialized).unwrap(); + let dv: Vec<_> = deserialized.iter().collect(); + let ev: Vec<_> = interval_map.iter().collect(); + + assert_eq!(ev, dv); +} + +impl Interval { + fn new_point(x: u32) -> Self { + Interval { + low: x, + high: x + 1, + } + } +} + +#[test] +fn test_insert_point() { + let mut interval_map = IntervalMap::::new(); + interval_map.insert(Interval::new_point(5), 10); + interval_map.insert(Interval::new(3, 7), 20); + interval_map.insert(Interval::new(2, 6), 15); + + assert_eq!(interval_map.get(&Interval::new_point(5)).unwrap(), &10); + assert_eq!( + interval_map.find_all_overlap(&Interval::new_point(5)).len(), + 3 + ); +} + +#[test] +fn check_filter_iter_equal_to_iter_filter() { + with_map_and_generator(|mut map, mut gen| { + let intervals: Vec<_> = std::iter::repeat_with(|| gen.next_unique()) + .take(1000) + .collect(); + for i in intervals.clone() { + let _ignore = map.insert(i, ()); + } + let mut map = IntervalMap::new(); + for i in intervals.clone() { + map.insert(i, ()); + } + + for i in intervals { + let filter_iter_res: Vec<_> = map.filter_iter(&i).collect(); + let iter_filter_res: Vec<_> = map.iter().filter(|v| v.0.overlap(&i)).collect(); + assert_eq!(filter_iter_res, iter_filter_res); + } + }); +}