From 83de399bef99361a68ab315f714ac16334c3d8e0 Mon Sep 17 00:00:00 2001 From: Zhenchi Date: Tue, 19 Dec 2023 16:14:37 +0800 Subject: [PATCH] feat(inverted_index.create): add external sorter (#2950) * feat(inverted_index.create): add read/write for external intermediate files Signed-off-by: Zhenchi * chore: MAGIC_CODEC_V1 -> CODEC_V1_MAGIC Signed-off-by: Zhenchi * chore: polish comments Signed-off-by: Zhenchi * chore: fix typos intermedia -> intermediate Signed-off-by: Zhenchi * fix: typos Signed-off-by: Zhenchi * feat(inverted_index.create): add external sorter Signed-off-by: Zhenchi * chore: fix typos intermedia -> intermediate Signed-off-by: Zhenchi * chore: polish comments Signed-off-by: Zhenchi * chore: polish comments Signed-off-by: Zhenchi * refactor: drop the stream as early as possible to avoid recursive calls to poll Signed-off-by: Zhenchi * refactor: project merge sorted stream Signed-off-by: Zhenchi * feat: add total_row_count to SortOutput Signed-off-by: Zhenchi * feat: remove change of format Signed-off-by: Zhenchi * refactor: rename segment null bitmap Signed-off-by: Zhenchi * refactor: test type alias Signed-off-by: Zhenchi * feat: allow `memory_usage_threshold` to be None to turn off dumping Signed-off-by: Zhenchi * feat: change segment_row_count type to NonZeroUsize Signed-off-by: Zhenchi * refactor: accept BytesRef instead Signed-off-by: Zhenchi * feat: add `push_n` to adapt mito2 Signed-off-by: Zhenchi * chore: add k-way merge TODO Signed-off-by: Zhenchi * refactor: more sorter cases Signed-off-by: Zhenchi * refactor: make the merge tree balance Signed-off-by: Zhenchi * Update src/index/src/inverted_index/create/sort/external_sort.rs Co-authored-by: Yingwen * chore: address comments Signed-off-by: Zhenchi * chore: stable feature Signed-off-by: Zhenchi --------- Signed-off-by: Zhenchi Co-authored-by: Yingwen --- Cargo.lock | 3 + src/index/Cargo.toml | 3 + src/index/src/inverted_index.rs | 1 + src/index/src/inverted_index/create/sort.rs | 39 +- .../create/sort/external_provider.rs | 39 ++ .../create/sort/external_sort.rs | 437 ++++++++++++++++++ .../create/sort/merge_stream.rs | 174 +++++++ 7 files changed, 693 insertions(+), 3 deletions(-) create mode 100644 src/index/src/inverted_index/create/sort/external_provider.rs create mode 100644 src/index/src/inverted_index/create/sort/external_sort.rs create mode 100644 src/index/src/inverted_index/create/sort/merge_stream.rs diff --git a/Cargo.lock b/Cargo.lock index 5c4a03af734a..97a509658e09 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3980,11 +3980,14 @@ dependencies = [ "common-base", "common-error", "common-macro", + "common-telemetry", "fst", "futures", "greptime-proto", "mockall", + "pin-project", "prost 0.12.2", + "rand", "regex", "regex-automata 0.1.10", "snafu", diff --git a/src/index/Cargo.toml b/src/index/Cargo.toml index bd5b560ce854..0835da45d003 100644 --- a/src/index/Cargo.toml +++ b/src/index/Cargo.toml @@ -12,15 +12,18 @@ bytes.workspace = true common-base.workspace = true common-error.workspace = true common-macro.workspace = true +common-telemetry.workspace = true fst.workspace = true futures.workspace = true greptime-proto.workspace = true mockall.workspace = true +pin-project.workspace = true prost.workspace = true regex-automata.workspace = true regex.workspace = true snafu.workspace = true [dev-dependencies] +rand.workspace = true tokio-util.workspace = true tokio.workspace = true diff --git a/src/index/src/inverted_index.rs b/src/index/src/inverted_index.rs index a793d1a25238..7a34bae21381 100644 --- a/src/index/src/inverted_index.rs +++ b/src/index/src/inverted_index.rs @@ -19,3 +19,4 @@ pub mod search; pub type FstMap = fst::Map>; pub type Bytes = Vec; +pub type BytesRef<'a> = &'a [u8]; diff --git a/src/index/src/inverted_index/create/sort.rs b/src/index/src/inverted_index/create/sort.rs index 2331ed8dcb81..53a70fc7b5c0 100644 --- a/src/index/src/inverted_index/create/sort.rs +++ b/src/index/src/inverted_index/create/sort.rs @@ -12,13 +12,46 @@ // See the License for the specific language governing permissions and // limitations under the License. +mod external_provider; +mod external_sort; +mod intermediate_rw; +mod merge_stream; + +use async_trait::async_trait; use common_base::BitVec; use futures::Stream; use crate::inverted_index::error::Result; -use crate::inverted_index::Bytes; - -mod intermediate_rw; +use crate::inverted_index::{Bytes, BytesRef}; /// A stream of sorted values along with their associated bitmap pub type SortedStream = Box> + Send + Unpin>; + +/// Output of a sorting operation, encapsulating a bitmap for null values and a stream of sorted items +pub struct SortOutput { + /// Bitmap indicating which segments have null values + pub segment_null_bitmap: BitVec, + + /// Stream of sorted items + pub sorted_stream: SortedStream, + + /// Total number of rows in the sorted data + pub total_row_count: usize, +} + +/// Handles data sorting, supporting incremental input and retrieval of sorted output +#[async_trait] +pub trait Sorter: Send { + /// Inputs a non-null or null value into the sorter. + /// Should be equivalent to calling `push_n` with n = 1 + async fn push(&mut self, value: Option>) -> Result<()> { + self.push_n(value, 1).await + } + + /// Pushing n identical non-null or null values into the sorter. + /// Should be equivalent to calling `push` n times + async fn push_n(&mut self, value: Option>, n: usize) -> Result<()>; + + /// Completes the sorting process and returns the sorted data + async fn output(&mut self) -> Result; +} diff --git a/src/index/src/inverted_index/create/sort/external_provider.rs b/src/index/src/inverted_index/create/sort/external_provider.rs new file mode 100644 index 000000000000..a86f3e06aad4 --- /dev/null +++ b/src/index/src/inverted_index/create/sort/external_provider.rs @@ -0,0 +1,39 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use futures::{AsyncRead, AsyncWrite}; + +use crate::inverted_index::error::Result; + +/// Trait for managing intermediate files during external sorting for a particular index. +#[mockall::automock] +#[async_trait] +pub trait ExternalTempFileProvider: Send + Sync { + /// Creates and opens a new intermediate file associated with a specific index for writing. + /// The implementation should ensure that the file does not already exist. + /// + /// - `index_name`: the name of the index for which the file will be associated + /// - `file_id`: a unique identifier for the new file + async fn create( + &self, + index_name: &str, + file_id: &str, + ) -> Result>; + + /// Retrieves all intermediate files associated with a specific index for an external sorting operation. + /// + /// `index_name`: the name of the index to retrieve intermediate files for + async fn read_all(&self, index_name: &str) -> Result>>; +} diff --git a/src/index/src/inverted_index/create/sort/external_sort.rs b/src/index/src/inverted_index/create/sort/external_sort.rs new file mode 100644 index 000000000000..2e530f3e45e4 --- /dev/null +++ b/src/index/src/inverted_index/create/sort/external_sort.rs @@ -0,0 +1,437 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::{BTreeMap, VecDeque}; +use std::mem; +use std::num::NonZeroUsize; +use std::ops::RangeInclusive; +use std::sync::Arc; + +use async_trait::async_trait; +use common_base::BitVec; +use common_telemetry::logging; +use futures::stream; + +use crate::inverted_index::create::sort::external_provider::ExternalTempFileProvider; +use crate::inverted_index::create::sort::intermediate_rw::{ + IntermediateReader, IntermediateWriter, +}; +use crate::inverted_index::create::sort::merge_stream::MergeSortedStream; +use crate::inverted_index::create::sort::{SortOutput, SortedStream, Sorter}; +use crate::inverted_index::error::Result; +use crate::inverted_index::{Bytes, BytesRef}; + +/// `ExternalSorter` manages the sorting of data using both in-memory structures and external files. +/// It dumps data to external files when the in-memory buffer crosses a certain memory threshold. +pub struct ExternalSorter { + /// The index name associated with the sorting operation + index_name: String, + + /// Manages creation and access to external temporary files + temp_file_provider: Arc, + + /// Bitmap indicating which segments have null values + segment_null_bitmap: BitVec, + + /// In-memory buffer to hold values and their corresponding bitmaps until memory threshold is exceeded + values_buffer: BTreeMap, + + /// Count of all rows ingested so far + total_row_count: usize, + + /// The number of rows per group for bitmap indexing which determines how rows are + /// batched for indexing. It is used to determine which segment a row belongs to. + segment_row_count: NonZeroUsize, + + /// Tracks memory usage of the buffer + current_memory_usage: usize, + + /// The memory usage threshold at which the buffer should be dumped to an external file. + /// `None` indicates that the buffer should never be dumped. + memory_usage_threshold: Option, +} + +#[async_trait] +impl Sorter for ExternalSorter { + /// Pushes n identical values into the sorter, adding them to the in-memory buffer and dumping + /// the buffer to an external file if necessary + async fn push_n(&mut self, value: Option>, n: usize) -> Result<()> { + if n == 0 { + return Ok(()); + } + + let segment_index_range = self.segment_index_range(n); + self.total_row_count += n; + + if let Some(value) = value { + let memory_diff = self.push_not_null(value, segment_index_range); + self.may_dump_buffer(memory_diff).await + } else { + set_bits(&mut self.segment_null_bitmap, segment_index_range); + Ok(()) + } + } + + /// Finalizes the sorting operation, merging data from both in-memory buffer and external files + /// into a sorted stream + async fn output(&mut self) -> Result { + let readers = self.temp_file_provider.read_all(&self.index_name).await?; + + // TODO(zhongzc): k-way merge instead of 2-way merge + + let mut tree_nodes: VecDeque = VecDeque::with_capacity(readers.len() + 1); + tree_nodes.push_back(Box::new(stream::iter( + mem::take(&mut self.values_buffer).into_iter().map(Ok), + ))); + for reader in readers { + tree_nodes.push_back(IntermediateReader::new(reader).into_stream().await?); + } + + while tree_nodes.len() >= 2 { + // every turn, the length of tree_nodes will be reduced by 1 until only one stream left + let stream1 = tree_nodes.pop_front().unwrap(); + let stream2 = tree_nodes.pop_front().unwrap(); + let merged_stream = MergeSortedStream::merge(stream1, stream2); + tree_nodes.push_back(merged_stream); + } + + Ok(SortOutput { + segment_null_bitmap: mem::take(&mut self.segment_null_bitmap), + sorted_stream: tree_nodes.pop_front().unwrap(), + total_row_count: self.total_row_count, + }) + } +} + +impl ExternalSorter { + /// Constructs a new `ExternalSorter` + pub fn new( + index_name: String, + temp_file_provider: Arc, + segment_row_count: NonZeroUsize, + memory_usage_threshold: Option, + ) -> Self { + Self { + index_name, + temp_file_provider, + + segment_null_bitmap: BitVec::new(), + values_buffer: BTreeMap::new(), + + total_row_count: 0, + segment_row_count, + + current_memory_usage: 0, + memory_usage_threshold, + } + } + + /// Pushes the non-null values to the values buffer and sets the bits within + /// the specified range in the given BitVec to true. + /// Returns the memory usage difference of the buffer after the operation. + fn push_not_null( + &mut self, + value: BytesRef<'_>, + segment_index_range: RangeInclusive, + ) -> usize { + match self.values_buffer.get_mut(value) { + Some(bitmap) => { + let old_len = bitmap.as_raw_slice().len(); + set_bits(bitmap, segment_index_range); + + bitmap.as_raw_slice().len() - old_len + } + None => { + let mut bitmap = BitVec::default(); + set_bits(&mut bitmap, segment_index_range); + + let mem_diff = bitmap.as_raw_slice().len() + value.len(); + self.values_buffer.insert(value.to_vec(), bitmap); + + mem_diff + } + } + } + + /// Checks if the in-memory buffer exceeds the threshold and offloads it to external storage if necessary + async fn may_dump_buffer(&mut self, memory_diff: usize) -> Result<()> { + self.current_memory_usage += memory_diff; + if self.memory_usage_threshold.is_none() + || self.current_memory_usage < self.memory_usage_threshold.unwrap() + { + return Ok(()); + } + + let file_id = &format!("{:012}", self.total_row_count); + let index_name = &self.index_name; + let writer = self.temp_file_provider.create(index_name, file_id).await?; + + let memory_usage = self.current_memory_usage; + let values = mem::take(&mut self.values_buffer); + self.current_memory_usage = 0; + + let entries = values.len(); + IntermediateWriter::new(writer).write_all(values).await.inspect(|_| + logging::debug!("Dumped {entries} entries ({memory_usage} bytes) to intermediate file {file_id} for index {index_name}") + ).inspect_err(|e| + logging::error!("Failed to dump {entries} entries to intermediate file {file_id} for index {index_name}. Error: {e}") + ) + } + + /// Determines the segment index range for the row index range + /// `[self.total_row_count, self.total_row_count + n - 1]` + fn segment_index_range(&self, n: usize) -> RangeInclusive { + let start = self.segment_index(self.total_row_count); + let end = self.segment_index(self.total_row_count + n - 1); + start..=end + } + + /// Determines the segment index for the given row index + fn segment_index(&self, row_index: usize) -> usize { + row_index / self.segment_row_count + } +} + +/// Sets the bits within the specified range in the given `BitVec` to true +fn set_bits(bitmap: &mut BitVec, index_range: RangeInclusive) { + if *index_range.end() >= bitmap.len() { + bitmap.resize(index_range.end() + 1, false); + } + for index in index_range { + bitmap.set(index, true); + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::iter; + use std::sync::Mutex; + + use futures::{AsyncRead, StreamExt}; + use rand::Rng; + use tokio::io::duplex; + use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; + + use super::*; + use crate::inverted_index::create::sort::external_provider::MockExternalTempFileProvider; + + async fn test_external_sorter( + memory_usage_threshold: Option, + segment_row_count: usize, + row_count: usize, + batch_push: bool, + ) { + let mut mock_provider = MockExternalTempFileProvider::new(); + + let mock_files: Arc>>> = + Arc::new(Mutex::new(HashMap::new())); + + mock_provider.expect_create().returning({ + let files = Arc::clone(&mock_files); + move |index_name, file_id| { + assert_eq!(index_name, "test"); + let mut files = files.lock().unwrap(); + let (writer, reader) = duplex(8 * 1024); + files.insert(file_id.to_string(), Box::new(reader.compat())); + Ok(Box::new(writer.compat_write())) + } + }); + + mock_provider.expect_read_all().returning({ + let files = Arc::clone(&mock_files); + move |index_name| { + assert_eq!(index_name, "test"); + let mut files = files.lock().unwrap(); + Ok(files.drain().map(|f| f.1).collect::>()) + } + }); + + let mut sorter = ExternalSorter::new( + "test".to_owned(), + Arc::new(mock_provider), + NonZeroUsize::new(segment_row_count).unwrap(), + memory_usage_threshold, + ); + + let mut sorted_result = if batch_push { + let (dic_values, sorted_result) = + dictionary_values_and_sorted_result(row_count, segment_row_count); + + for (value, n) in dic_values { + sorter.push_n(value.as_deref(), n).await.unwrap(); + } + + sorted_result + } else { + let (mock_values, sorted_result) = + shuffle_values_and_sorted_result(row_count, segment_row_count); + + for value in mock_values { + sorter.push(value.as_deref()).await.unwrap(); + } + + sorted_result + }; + + let SortOutput { + segment_null_bitmap, + mut sorted_stream, + total_row_count, + } = sorter.output().await.unwrap(); + assert_eq!(total_row_count, row_count); + let n = sorted_result.remove(&None); + assert_eq!( + segment_null_bitmap.iter_ones().collect::>(), + n.unwrap_or_default() + ); + for (value, offsets) in sorted_result { + let item = sorted_stream.next().await.unwrap().unwrap(); + assert_eq!(item.0, value.unwrap()); + assert_eq!(item.1.iter_ones().collect::>(), offsets); + } + } + + #[tokio::test] + async fn test_external_sorter_pure_in_memory() { + let memory_usage_threshold = None; + let total_row_count_cases = vec![0, 100, 1000, 10000]; + let segment_row_count_cases = vec![1, 10, 100, 1000]; + let batch_push_cases = vec![false, true]; + + for total_row_count in total_row_count_cases { + for segment_row_count in &segment_row_count_cases { + for batch_push in &batch_push_cases { + test_external_sorter( + memory_usage_threshold, + *segment_row_count, + total_row_count, + *batch_push, + ) + .await; + } + } + } + } + + #[tokio::test] + async fn test_external_sorter_pure_external() { + let memory_usage_threshold = Some(0); + let total_row_count_cases = vec![0, 100, 1000, 10000]; + let segment_row_count_cases = vec![1, 10, 100, 1000]; + let batch_push_cases = vec![false, true]; + + for total_row_count in total_row_count_cases { + for segment_row_count in &segment_row_count_cases { + for batch_push in &batch_push_cases { + test_external_sorter( + memory_usage_threshold, + *segment_row_count, + total_row_count, + *batch_push, + ) + .await; + } + } + } + } + + #[tokio::test] + async fn test_external_sorter_mixed() { + let memory_usage_threshold = Some(1024); + let total_row_count_cases = vec![0, 100, 1000, 10000]; + let segment_row_count_cases = vec![1, 10, 100, 1000]; + let batch_push_cases = vec![false, true]; + + for total_row_count in total_row_count_cases { + for segment_row_count in &segment_row_count_cases { + for batch_push in &batch_push_cases { + test_external_sorter( + memory_usage_threshold, + *segment_row_count, + total_row_count, + *batch_push, + ) + .await; + } + } + } + } + + fn random_option_bytes(size: usize) -> Option> { + let mut rng = rand::thread_rng(); + + if rng.gen() { + let mut buffer = vec![0u8; size]; + rng.fill(&mut buffer[..]); + Some(buffer) + } else { + None + } + } + + type Values = Vec>; + type DictionaryValues = Vec<(Option, usize)>; + type ValueSegIds = BTreeMap, Vec>; + + fn shuffle_values_and_sorted_result( + row_count: usize, + segment_row_count: usize, + ) -> (Values, ValueSegIds) { + let mock_values = iter::repeat_with(|| random_option_bytes(100)) + .take(row_count) + .collect::>(); + + let sorted_result = sorted_result(&mock_values, segment_row_count); + (mock_values, sorted_result) + } + + fn dictionary_values_and_sorted_result( + row_count: usize, + segment_row_count: usize, + ) -> (DictionaryValues, ValueSegIds) { + let mut n = row_count; + let mut rng = rand::thread_rng(); + let mut dic_values = Vec::new(); + + while n > 0 { + let size = rng.gen_range(1..=n); + let value = random_option_bytes(100); + dic_values.push((value, size)); + n -= size; + } + + let mock_values = dic_values + .iter() + .flat_map(|(value, size)| iter::repeat(value.clone()).take(*size)) + .collect::>(); + + let sorted_result = sorted_result(&mock_values, segment_row_count); + (dic_values, sorted_result) + } + + fn sorted_result(values: &Values, segment_row_count: usize) -> ValueSegIds { + let mut sorted_result = BTreeMap::new(); + for (row_index, value) in values.iter().enumerate() { + let to_add_segment_index = row_index / segment_row_count; + let indices = sorted_result.entry(value.clone()).or_insert_with(Vec::new); + + if indices.last() != Some(&to_add_segment_index) { + indices.push(to_add_segment_index); + } + } + + sorted_result + } +} diff --git a/src/index/src/inverted_index/create/sort/merge_stream.rs b/src/index/src/inverted_index/create/sort/merge_stream.rs new file mode 100644 index 000000000000..84debecb8ada --- /dev/null +++ b/src/index/src/inverted_index/create/sort/merge_stream.rs @@ -0,0 +1,174 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::cmp::Ordering; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use common_base::BitVec; +use futures::{ready, Stream, StreamExt}; +use pin_project::pin_project; + +use crate::inverted_index::create::sort::SortedStream; +use crate::inverted_index::error::Result; +use crate::inverted_index::Bytes; + +/// A [`Stream`] implementation that merges two sorted streams into a single sorted stream +#[pin_project] +pub struct MergeSortedStream { + stream1: Option, + peek1: Option<(Bytes, BitVec)>, + + stream2: Option, + peek2: Option<(Bytes, BitVec)>, +} + +impl MergeSortedStream { + /// Creates a new `MergeSortedStream` that will return elements from `stream1` and `stream2` + /// in sorted order, merging duplicate items by unioning their bitmaps + pub fn merge(stream1: SortedStream, stream2: SortedStream) -> SortedStream { + Box::new(MergeSortedStream { + stream1: Some(stream1), + peek1: None, + + stream2: Some(stream2), + peek2: None, + }) + } +} + +impl Stream for MergeSortedStream { + type Item = Result<(Bytes, BitVec)>; + + /// Polls both streams and returns the next item from the stream that has the smaller next item. + /// If both streams have the same next item, the bitmaps are unioned together. + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + if let (None, Some(stream1)) = (&this.peek1, this.stream1.as_mut()) { + match ready!(stream1.poll_next_unpin(cx)) { + Some(item) => *this.peek1 = Some(item?), + None => *this.stream1 = None, // `stream1` is exhausted, don't poll it next time + } + } + + if let (None, Some(stream2)) = (&this.peek2, this.stream2.as_mut()) { + match ready!(stream2.poll_next_unpin(cx)) { + Some(item) => *this.peek2 = Some(item?), + None => *this.stream2 = None, // `stream2` is exhausted, don't poll it next time + } + } + + Poll::Ready(match (this.peek1.take(), this.peek2.take()) { + (Some((v1, b1)), Some((v2, b2))) => match v1.cmp(&v2) { + Ordering::Less => { + *this.peek2 = Some((v2, b2)); // Preserve the rest of `stream2` + Some(Ok((v1, b1))) + } + Ordering::Greater => { + *this.peek1 = Some((v1, b1)); // Preserve the rest of `stream1` + Some(Ok((v2, b2))) + } + Ordering::Equal => Some(Ok((v1, merge_bitmaps(b1, b2)))), + }, + (None, Some(item)) | (Some(item), None) => Some(Ok(item)), + (None, None) => None, + }) + } +} + +/// Merges two bitmaps by bit-wise OR'ing them together, preserving all bits from both +fn merge_bitmaps(bitmap1: BitVec, bitmap2: BitVec) -> BitVec { + // make sure longer bitmap is on the left to avoid truncation + #[allow(clippy::if_same_then_else)] + if bitmap1.len() > bitmap2.len() { + bitmap1 | bitmap2 + } else { + bitmap2 | bitmap1 + } +} + +#[cfg(test)] +mod tests { + use futures::stream; + + use super::*; + use crate::inverted_index::error::Error; + + fn sorted_stream_from_vec(vec: Vec<(Bytes, BitVec)>) -> SortedStream { + Box::new(stream::iter(vec.into_iter().map(Ok::<_, Error>))) + } + + #[tokio::test] + async fn test_merge_sorted_stream_non_overlapping() { + let stream1 = sorted_stream_from_vec(vec![ + (Bytes::from("apple"), BitVec::from_slice(&[0b10101010])), + (Bytes::from("orange"), BitVec::from_slice(&[0b01010101])), + ]); + let stream2 = sorted_stream_from_vec(vec![ + (Bytes::from("banana"), BitVec::from_slice(&[0b10101010])), + (Bytes::from("peach"), BitVec::from_slice(&[0b01010101])), + ]); + + let mut merged_stream = MergeSortedStream::merge(stream1, stream2); + + let item = merged_stream.next().await.unwrap().unwrap(); + assert_eq!(item.0, Bytes::from("apple")); + assert_eq!(item.1, BitVec::from_slice(&[0b10101010])); + let item = merged_stream.next().await.unwrap().unwrap(); + assert_eq!(item.0, Bytes::from("banana")); + assert_eq!(item.1, BitVec::from_slice(&[0b10101010])); + let item = merged_stream.next().await.unwrap().unwrap(); + assert_eq!(item.0, Bytes::from("orange")); + assert_eq!(item.1, BitVec::from_slice(&[0b01010101])); + let item = merged_stream.next().await.unwrap().unwrap(); + assert_eq!(item.0, Bytes::from("peach")); + assert_eq!(item.1, BitVec::from_slice(&[0b01010101])); + assert!(merged_stream.next().await.is_none()); + } + + #[tokio::test] + async fn test_merge_sorted_stream_overlapping() { + let stream1 = sorted_stream_from_vec(vec![ + (Bytes::from("apple"), BitVec::from_slice(&[0b10101010])), + (Bytes::from("orange"), BitVec::from_slice(&[0b10101010])), + ]); + let stream2 = sorted_stream_from_vec(vec![ + (Bytes::from("apple"), BitVec::from_slice(&[0b01010101])), + (Bytes::from("peach"), BitVec::from_slice(&[0b01010101])), + ]); + + let mut merged_stream = MergeSortedStream::merge(stream1, stream2); + + let item = merged_stream.next().await.unwrap().unwrap(); + assert_eq!(item.0, Bytes::from("apple")); + assert_eq!(item.1, BitVec::from_slice(&[0b11111111])); + let item = merged_stream.next().await.unwrap().unwrap(); + assert_eq!(item.0, Bytes::from("orange")); + assert_eq!(item.1, BitVec::from_slice(&[0b10101010])); + let item = merged_stream.next().await.unwrap().unwrap(); + assert_eq!(item.0, Bytes::from("peach")); + assert_eq!(item.1, BitVec::from_slice(&[0b01010101])); + assert!(merged_stream.next().await.is_none()); + } + + #[tokio::test] + async fn test_merge_sorted_stream_empty_streams() { + let stream1 = sorted_stream_from_vec(vec![]); + let stream2 = sorted_stream_from_vec(vec![]); + + let mut merged_stream = MergeSortedStream::merge(stream1, stream2); + assert!(merged_stream.next().await.is_none()); + } +}