From aebae9965d07221b7ee2b57d2a217c48ab0137ed Mon Sep 17 00:00:00 2001
From: PSeitz
Date: Mon, 21 Oct 2024 18:29:17 +0800
Subject: [PATCH] add RegexPhraseQuery (#2516)

* add RegexPhraseQuery

RegexPhraseQuery supports phrase queries with regexes; wildcards are
supported by rewriting them to regexes.
E.g. a query with wildcards: "b* b* wolf" matches "big bad wolf".
Slop is supported as well: "b* wolf"~2 matches "big bad wolf".

Regex queries may match a lot of terms, and we still need to keep track of
which terms hit in order to load their positions. The phrase query
algorithm groups terms by their frequency in the union, so that whole
groups can be filtered out early.

This PR comes with some new data structures:

SimpleUnion - A union docset for a list of docsets. It doesn't do any
caching and is therefore well suited for docsets that skip a lot
(phrase search, and intersections in general).

LoadedPostings - Like SegmentPostings, but all docs and positions are
loaded in memory. SegmentPostings uses 1840 bytes per instance with its
caches, which is equivalent to 460 docids. LoadedPostings is used for
terms which have fewer than 100 docs; it exists purely to reduce memory
consumption.

BitSetPostingUnion - Creates a `Postings` that uses the bitset for docid
hits and the docsets for positions. The BitSet is the precalculated union
of the docsets. In the RegexPhraseQuery there is a size limit of 512
docsets per BitSetPostingUnion before a new one is created.

Renamed Union to BufferedUnionScorer.

Added proptests to test the different union types.

* cleanup

* use Box instead of Vec

* use RefCell instead of term_freq(&mut)

* remove wildcard mode

* move RefCell to outer

* clippy
---
 src/postings/loaded_postings.rs               | 155 ++++
 src/postings/mod.rs                           |   2 +
 src/postings/postings.rs                      |  19 +-
 src/postings/segment_postings.rs              |  12 +-
 src/query/automaton_weight.rs                 |  13 +
 src/query/boolean_query/block_wand.rs         |   4 +-
 src/query/boolean_query/boolean_weight.rs     |  19 +-
 src/query/mod.rs                              |   3 +-
 src/query/phrase_query/mod.rs                 |   8 +-
 src/query/phrase_query/phrase_weight.rs       |  29 +-
 src/query/phrase_query/regex_phrase_query.rs  | 172 +++++++
 src/query/phrase_query/regex_phrase_weight.rs | 475 ++++++++++++++++++
 src/query/union/bitset_union.rs               |  89 ++++
 .../{union.rs => union/buffered_union.rs}     | 214 +-------
 src/query/union/mod.rs                        | 303 +++++++++++
 src/query/union/simple_union.rs               | 112 +++++
 16 files changed, 1380 insertions(+), 249 deletions(-)
 create mode 100644 src/postings/loaded_postings.rs
 create mode 100644 src/query/phrase_query/regex_phrase_query.rs
 create mode 100644 src/query/phrase_query/regex_phrase_weight.rs
 create mode 100644 src/query/union/bitset_union.rs
 rename src/query/{union.rs => union/buffered_union.rs} (50%)
 create mode 100644 src/query/union/mod.rs
 create mode 100644 src/query/union/simple_union.rs

diff --git a/src/postings/loaded_postings.rs b/src/postings/loaded_postings.rs
new file mode 100644
index 0000000000..7212be4ac8
--- /dev/null
+++ b/src/postings/loaded_postings.rs
@@ -0,0 +1,155 @@
+use crate::docset::{DocSet, TERMINATED};
+use crate::postings::{Postings, SegmentPostings};
+use crate::DocId;
+
+/// `LoadedPostings` is a `DocSet` and `Postings` implementation.
+/// It is used to represent the postings of a term in memory.
+/// It is suitable if there are few documents for a term.
+///
+/// It exists mainly to reduce memory usage.
+/// `SegmentPostings` uses 1840 bytes per instance due to its caches.
+/// If you need to keep many terms around with few docs, it's cheaper to load all the
+/// postings in memory.
+///
+/// This is relevant for `RegexPhraseQuery`, which may have a lot of
+/// terms.
+/// E.g. 100_000 terms would need 184MB due to SegmentPostings.
+pub struct LoadedPostings {
+    doc_ids: Box<[DocId]>,
+    position_offsets: Box<[u32]>,
+    positions: Box<[u32]>,
+    cursor: usize,
+}
+
+impl LoadedPostings {
+    /// Creates a new `LoadedPostings` from a `SegmentPostings`.
+    ///
+    /// It will also preload positions, if positions are available in the SegmentPostings.
+    pub fn load(segment_postings: &mut SegmentPostings) -> LoadedPostings {
+        let num_docs = segment_postings.doc_freq() as usize;
+        let mut doc_ids = Vec::with_capacity(num_docs);
+        let mut positions = Vec::with_capacity(num_docs);
+        let mut position_offsets = Vec::with_capacity(num_docs);
+        while segment_postings.doc() != TERMINATED {
+            position_offsets.push(positions.len() as u32);
+            doc_ids.push(segment_postings.doc());
+            segment_postings.append_positions_with_offset(0, &mut positions);
+            segment_postings.advance();
+        }
+        position_offsets.push(positions.len() as u32);
+        LoadedPostings {
+            doc_ids: doc_ids.into_boxed_slice(),
+            positions: positions.into_boxed_slice(),
+            position_offsets: position_offsets.into_boxed_slice(),
+            cursor: 0,
+        }
+    }
+}
+
+#[cfg(test)]
+impl From<(Vec<DocId>, Vec<Vec<u32>>)> for LoadedPostings {
+    fn from(doc_ids_and_positions: (Vec<DocId>, Vec<Vec<u32>>)) -> LoadedPostings {
+        let mut position_offsets = Vec::new();
+        let mut all_positions = Vec::new();
+        let (doc_ids, docid_positions) = doc_ids_and_positions;
+        for positions in docid_positions {
+            position_offsets.push(all_positions.len() as u32);
+            all_positions.extend_from_slice(&positions);
+        }
+        position_offsets.push(all_positions.len() as u32);
+        LoadedPostings {
+            doc_ids: doc_ids.into_boxed_slice(),
+            positions: all_positions.into_boxed_slice(),
+            position_offsets: position_offsets.into_boxed_slice(),
+            cursor: 0,
+        }
+    }
+}
+
+impl DocSet for LoadedPostings {
+    fn advance(&mut self) -> DocId {
+        self.cursor += 1;
+        if self.cursor >= self.doc_ids.len() {
+            self.cursor = self.doc_ids.len();
+            return TERMINATED;
+        }
+        self.doc()
+    }
+
+    fn doc(&self) -> DocId {
+        if self.cursor >= self.doc_ids.len() {
+            return TERMINATED;
+        }
+        self.doc_ids[self.cursor]
+    }
+
+    fn size_hint(&self) -> u32 {
+        self.doc_ids.len() as u32
+    }
+}
+
+impl Postings for LoadedPostings {
+    fn term_freq(&self) -> u32 {
+        let start = self.position_offsets[self.cursor] as usize;
+        let end = self.position_offsets[self.cursor + 1] as usize;
+        (end - start) as u32
+    }
+
+    fn append_positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>) {
+        let start = self.position_offsets[self.cursor] as usize;
+        let end = self.position_offsets[self.cursor + 1] as usize;
+        for pos in &self.positions[start..end] {
+            output.push(*pos + offset);
+        }
+    }
+}
+
+#[cfg(test)]
+pub mod tests {
+
+    use super::*;
+
+    #[test]
+    pub fn test_vec_postings() {
+        let doc_ids: Vec<DocId> = (0u32..1024u32).map(|e| e * 3).collect();
+        let mut postings = LoadedPostings::from((doc_ids, vec![]));
+        assert_eq!(postings.doc(), 0u32);
+        assert_eq!(postings.advance(), 3u32);
+        assert_eq!(postings.doc(), 3u32);
+        assert_eq!(postings.seek(14u32), 15u32);
+        assert_eq!(postings.doc(), 15u32);
+        assert_eq!(postings.seek(300u32), 300u32);
+        assert_eq!(postings.doc(), 300u32);
+        assert_eq!(postings.seek(6000u32), TERMINATED);
+    }
+
+    #[test]
+    pub fn test_vec_postings2() {
+        let doc_ids: Vec<DocId> = (0u32..1024u32).map(|e| e * 3).collect();
+        let mut positions = Vec::new();
+        positions.resize(1024, Vec::new());
+        positions[0] = vec![1u32, 2u32, 3u32];
+        positions[1] = vec![30u32];
+        positions[2] = vec![10u32];
+        positions[4] = vec![50u32];
+        let mut postings = LoadedPostings::from((doc_ids, positions));
+
+        let load = |postings: &mut LoadedPostings| {
+            let mut loaded_positions = Vec::new();
+            postings.positions(loaded_positions.as_mut());
+            loaded_positions
+        };
+        assert_eq!(postings.doc(), 0u32);
+        assert_eq!(load(&mut postings), vec![1u32, 2u32, 3u32]);
+
+        assert_eq!(postings.advance(), 3u32);
+        assert_eq!(postings.doc(), 3u32);
+
+        assert_eq!(load(&mut postings), vec![30u32]);
+
+        assert_eq!(postings.seek(14u32), 15u32);
+        assert_eq!(postings.doc(), 15u32);
+        assert_eq!(postings.seek(300u32), 300u32);
+        assert_eq!(postings.doc(), 300u32);
+        assert_eq!(postings.seek(6000u32), TERMINATED);
+    }
+}
diff --git a/src/postings/mod.rs b/src/postings/mod.rs
index 5fd90032df..7060916bd1 100644
--- a/src/postings/mod.rs
+++ b/src/postings/mod.rs
@@ -8,6 +8,7 @@ mod block_segment_postings;
 pub(crate) mod compression;
 mod indexing_context;
 mod json_postings_writer;
+mod loaded_postings;
 mod per_field_postings_writer;
 mod postings;
 mod postings_writer;
@@ -17,6 +18,7 @@ mod serializer;
 mod skip;
 mod term_info;
 
+pub(crate) use loaded_postings::LoadedPostings;
 pub(crate) use stacker::compute_table_memory_size;
 
 pub use self::block_segment_postings::BlockSegmentPostings;
diff --git a/src/postings/postings.rs b/src/postings/postings.rs
index 682e61393f..8606f00a99 100644
--- a/src/postings/postings.rs
+++ b/src/postings/postings.rs
@@ -17,7 +17,14 @@ pub trait Postings: DocSet + 'static {
     /// Returns the positions offsetted with a given value.
     /// It is not necessary to clear the `output` before calling this method.
     /// The output vector will be resized to the `term_freq`.
-    fn positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>);
+    fn positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>) {
+        output.clear();
+        self.append_positions_with_offset(offset, output);
+    }
+
+    /// Returns the positions offsetted with a given value.
+    /// Data will be appended to the output.
+    fn append_positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>);
 
     /// Returns the positions of the term in the given document.
     /// The output vector will be resized to the `term_freq`.
@@ -25,3 +32,13 @@ pub trait Postings: DocSet + 'static {
         self.positions_with_offset(0u32, output);
     }
 }
+
+impl Postings for Box<dyn Postings> {
+    fn term_freq(&self) -> u32 {
+        (**self).term_freq()
+    }
+
+    fn append_positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>) {
+        (**self).append_positions_with_offset(offset, output);
+    }
+}
diff --git a/src/postings/segment_postings.rs b/src/postings/segment_postings.rs
index 3d91cf2ee2..51194a3569 100644
--- a/src/postings/segment_postings.rs
+++ b/src/postings/segment_postings.rs
@@ -237,8 +237,9 @@ impl Postings for SegmentPostings {
         self.block_cursor.freq(self.cur)
     }
 
-    fn positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>) {
+    fn append_positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>) {
         let term_freq = self.term_freq();
+        let prev_len = output.len();
         if let Some(position_reader) = self.position_reader.as_mut() {
             debug_assert!(
                 !self.block_cursor.freqs().is_empty(),
@@ -249,15 +250,14 @@
                 .iter()
                 .cloned()
                 .sum::<u32>() as u64);
+            // TODO: instead of zeroing the output, we could use MaybeUninit or similar.
+ output.resize(prev_len + term_freq as usize, 0u32); + position_reader.read(read_offset, &mut output[prev_len..]); let mut cum = offset; - for output_mut in output.iter_mut() { + for output_mut in output[prev_len..].iter_mut() { cum += *output_mut; *output_mut = cum; } - } else { - output.clear(); } } } diff --git a/src/query/automaton_weight.rs b/src/query/automaton_weight.rs index ef675864bd..5f1053fb67 100644 --- a/src/query/automaton_weight.rs +++ b/src/query/automaton_weight.rs @@ -6,6 +6,7 @@ use tantivy_fst::Automaton; use super::phrase_prefix_query::prefix_end; use crate::index::SegmentReader; +use crate::postings::TermInfo; use crate::query::{BitSetDocSet, ConstScorer, Explanation, Scorer, Weight}; use crate::schema::{Field, IndexRecordOption}; use crate::termdict::{TermDictionary, TermStreamer}; @@ -64,6 +65,18 @@ where term_stream_builder.into_stream() } + + /// Returns the term infos that match the automaton + pub fn get_match_term_infos(&self, reader: &SegmentReader) -> crate::Result> { + let inverted_index = reader.inverted_index(self.field)?; + let term_dict = inverted_index.terms(); + let mut term_stream = self.automaton_stream(term_dict)?; + let mut term_infos = Vec::new(); + while term_stream.advance() { + term_infos.push(term_stream.value().clone()); + } + Ok(term_infos) + } } impl Weight for AutomatonWeight diff --git a/src/query/boolean_query/block_wand.rs b/src/query/boolean_query/block_wand.rs index ad9a8b2ba5..59e22caa14 100644 --- a/src/query/boolean_query/block_wand.rs +++ b/src/query/boolean_query/block_wand.rs @@ -308,7 +308,7 @@ mod tests { use crate::query::score_combiner::SumCombiner; use crate::query::term_query::TermScorer; - use crate::query::{Bm25Weight, Scorer, Union}; + use crate::query::{Bm25Weight, BufferedUnionScorer, Scorer}; use crate::{DocId, DocSet, Score, TERMINATED}; struct Float(Score); @@ -371,7 +371,7 @@ mod tests { fn compute_checkpoints_manual(term_scorers: Vec, n: usize) -> Vec<(DocId, Score)> { let mut heap: BinaryHeap = BinaryHeap::with_capacity(n); let mut checkpoints: Vec<(DocId, Score)> = Vec::new(); - let mut scorer = Union::build(term_scorers, SumCombiner::default); + let mut scorer = BufferedUnionScorer::build(term_scorers, SumCombiner::default); let mut limit = Score::MIN; loop { diff --git a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs index c0a5e2c37f..7b617866fe 100644 --- a/src/query/boolean_query/boolean_weight.rs +++ b/src/query/boolean_query/boolean_weight.rs @@ -9,8 +9,8 @@ use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner}; use crate::query::term_query::TermScorer; use crate::query::weight::{for_each_docset_buffered, for_each_pruning_scorer, for_each_scorer}; use crate::query::{ - intersect_scorers, EmptyScorer, Exclude, Explanation, Occur, RequiredOptionalScorer, Scorer, - Union, Weight, + intersect_scorers, BufferedUnionScorer, EmptyScorer, Exclude, Explanation, Occur, + RequiredOptionalScorer, Scorer, Weight, }; use crate::{DocId, Score}; @@ -65,14 +65,17 @@ where // Block wand is only available if we read frequencies. 
return SpecializedScorer::TermUnion(scorers); } else { - return SpecializedScorer::Other(Box::new(Union::build( + return SpecializedScorer::Other(Box::new(BufferedUnionScorer::build( scorers, score_combiner_fn, ))); } } } - SpecializedScorer::Other(Box::new(Union::build(scorers, score_combiner_fn))) + SpecializedScorer::Other(Box::new(BufferedUnionScorer::build( + scorers, + score_combiner_fn, + ))) } fn into_box_scorer( @@ -81,7 +84,7 @@ fn into_box_scorer( ) -> Box { match scorer { SpecializedScorer::TermUnion(term_scorers) => { - let union_scorer = Union::build(term_scorers, score_combiner_fn); + let union_scorer = BufferedUnionScorer::build(term_scorers, score_combiner_fn); Box::new(union_scorer) } SpecializedScorer::Other(scorer) => scorer, @@ -296,7 +299,8 @@ impl Weight for BooleanWeight { - let mut union_scorer = Union::build(term_scorers, &self.score_combiner_fn); + let mut union_scorer = + BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn); for_each_scorer(&mut union_scorer, callback); } SpecializedScorer::Other(mut scorer) => { @@ -316,7 +320,8 @@ impl Weight for BooleanWeight { - let mut union_scorer = Union::build(term_scorers, &self.score_combiner_fn); + let mut union_scorer = + BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn); for_each_docset_buffered(&mut union_scorer, &mut buffer, callback); } SpecializedScorer::Other(mut scorer) => { diff --git a/src/query/mod.rs b/src/query/mod.rs index 5e99354ffe..23e64f1894 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -51,6 +51,7 @@ pub use self::fuzzy_query::FuzzyTermQuery; pub use self::intersection::{intersect_scorers, Intersection}; pub use self::more_like_this::{MoreLikeThisQuery, MoreLikeThisQueryBuilder}; pub use self::phrase_prefix_query::PhrasePrefixQuery; +pub use self::phrase_query::regex_phrase_query::{wildcard_query_to_regex_str, RegexPhraseQuery}; pub use self::phrase_query::PhraseQuery; pub use self::query::{EnableScoring, Query, QueryClone}; pub use self::query_parser::{QueryParser, QueryParserError}; @@ -61,7 +62,7 @@ pub use self::score_combiner::{DisjunctionMaxCombiner, ScoreCombiner, SumCombine pub use self::scorer::Scorer; pub use self::set_query::TermSetQuery; pub use self::term_query::TermQuery; -pub use self::union::Union; +pub use self::union::BufferedUnionScorer; #[cfg(test)] pub use self::vec_docset::VecDocSet; pub use self::weight::Weight; diff --git a/src/query/phrase_query/mod.rs b/src/query/phrase_query/mod.rs index 7b8d3e0074..f37c39b153 100644 --- a/src/query/phrase_query/mod.rs +++ b/src/query/phrase_query/mod.rs @@ -1,6 +1,8 @@ mod phrase_query; mod phrase_scorer; mod phrase_weight; +pub mod regex_phrase_query; +mod regex_phrase_weight; pub use self::phrase_query::PhraseQuery; pub(crate) use self::phrase_scorer::intersection_count; @@ -19,15 +21,15 @@ pub mod tests { use crate::schema::{Schema, Term, TEXT}; use crate::{assert_nearly_equals, DocAddress, DocId, IndexWriter, TERMINATED}; - pub fn create_index(texts: &[&'static str]) -> crate::Result { + pub fn create_index>(texts: &[S]) -> crate::Result { let mut schema_builder = Schema::builder(); let text_field = schema_builder.add_text_field("text", TEXT); let schema = schema_builder.build(); let index = Index::create_in_ram(schema); { let mut index_writer: IndexWriter = index.writer_for_tests()?; - for &text in texts { - let doc = doc!(text_field=>text); + for text in texts { + let doc = doc!(text_field=>text.as_ref()); index_writer.add_document(doc)?; } index_writer.commit()?; diff --git 
a/src/query/phrase_query/phrase_weight.rs b/src/query/phrase_query/phrase_weight.rs
index 6e97bca7f3..4118f79f6c 100644
--- a/src/query/phrase_query/phrase_weight.rs
+++ b/src/query/phrase_query/phrase_weight.rs
@@ -50,27 +50,14 @@ impl PhraseWeight {
             .map(|similarity_weight| similarity_weight.boost_by(boost));
         let fieldnorm_reader = self.fieldnorm_reader(reader)?;
         let mut term_postings_list = Vec::new();
-        if reader.has_deletes() {
-            for &(offset, ref term) in &self.phrase_terms {
-                if let Some(postings) = reader
-                    .inverted_index(term.field())?
-                    .read_postings(term, IndexRecordOption::WithFreqsAndPositions)?
-                {
-                    term_postings_list.push((offset, postings));
-                } else {
-                    return Ok(None);
-                }
-            }
-        } else {
-            for &(offset, ref term) in &self.phrase_terms {
-                if let Some(postings) = reader
-                    .inverted_index(term.field())?
-                    .read_postings_no_deletes(term, IndexRecordOption::WithFreqsAndPositions)?
-                {
-                    term_postings_list.push((offset, postings));
-                } else {
-                    return Ok(None);
-                }
+        for &(offset, ref term) in &self.phrase_terms {
+            if let Some(postings) = reader
+                .inverted_index(term.field())?
+                .read_postings(term, IndexRecordOption::WithFreqsAndPositions)?
+            {
+                term_postings_list.push((offset, postings));
+            } else {
+                return Ok(None);
             }
         }
         Ok(Some(PhraseScorer::new(
diff --git a/src/query/phrase_query/regex_phrase_query.rs b/src/query/phrase_query/regex_phrase_query.rs
new file mode 100644
index 0000000000..27096fcf16
--- /dev/null
+++ b/src/query/phrase_query/regex_phrase_query.rs
@@ -0,0 +1,172 @@
+use super::regex_phrase_weight::RegexPhraseWeight;
+use crate::query::bm25::Bm25Weight;
+use crate::query::{EnableScoring, Query, Weight};
+use crate::schema::{Field, IndexRecordOption, Term, Type};
+
+/// `RegexPhraseQuery` matches a specific sequence of regex queries.
+///
+/// For instance, the phrase query for `"pa.* time"` will match
+/// the sentence:
+///
+/// **Alan just got a part time job.**
+///
+/// On the other hand, it will not match the sentence:
+///
+/// **This is my favorite part of the job.**
+///
+/// [Slop](RegexPhraseQuery::set_slop) allows leniency in term proximity
+/// for some performance trade-off.
+///
+/// Using a `RegexPhraseQuery` on a field requires positions
+/// to be indexed for this field.
+#[derive(Clone, Debug)]
+pub struct RegexPhraseQuery {
+    field: Field,
+    phrase_terms: Vec<(usize, String)>,
+    slop: u32,
+    max_expansions: u32,
+}
+
+/// Transforms a wildcard query into a regex string.
+///
+/// `AB*CD`, for example, is converted to `AB.*CD`.
+///
+/// All other chars are regex escaped.
+pub fn wildcard_query_to_regex_str(term: &str) -> String {
+    regex::escape(term).replace(r"\*", ".*")
+}
+
+impl RegexPhraseQuery {
+    /// Creates a new `RegexPhraseQuery` given a list of terms.
+    ///
+    /// There must be at least two terms, and all terms
+    /// must belong to the same field.
+    ///
+    /// The offset of each term will be its index in the `Vec`.
+    pub fn new(field: Field, terms: Vec<String>) -> RegexPhraseQuery {
+        let terms_with_offset = terms.into_iter().enumerate().collect();
+        RegexPhraseQuery::new_with_offset(field, terms_with_offset)
+    }
+
+    /// Creates a new `RegexPhraseQuery` given a list of terms and their offsets.
+    ///
+    /// Can be used to provide a custom offset for each term.
+    pub fn new_with_offset(field: Field, terms: Vec<(usize, String)>) -> RegexPhraseQuery {
+        RegexPhraseQuery::new_with_offset_and_slop(field, terms, 0)
+    }
+
+    /// Creates a new `RegexPhraseQuery` given a list of terms, their offsets and a slop.
+    pub fn new_with_offset_and_slop(
+        field: Field,
+        mut terms: Vec<(usize, String)>,
+        slop: u32,
+    ) -> RegexPhraseQuery {
+        assert!(
+            terms.len() > 1,
+            "A phrase query is required to have strictly more than one term."
+        );
+        terms.sort_by_key(|&(offset, _)| offset);
+        RegexPhraseQuery {
+            field,
+            phrase_terms: terms,
+            slop,
+            max_expansions: 1 << 14,
+        }
+    }
+
+    /// Slop allowed for the phrase.
+    ///
+    /// The query will match if its terms are separated by `slop` terms at most.
+    /// The slop can be considered a budget between all terms.
+    /// E.g. "A B C" with slop 1 allows "A X B C", "A B X C", but not "A X B X C".
+    ///
+    /// Transposition costs 2, e.g. "A B" with slop 1 will not match "B A", but it
+    /// would with slop 2.
+    /// Transposition is not a special case: in the example above A is moved 1 position and B is
+    /// moved 1 position, so the slop is 2.
+    ///
+    /// As a result slop works in both directions, so the order of the terms may change as long
+    /// as they respect the slop.
+    ///
+    /// By default the slop is 0, meaning query terms need to be adjacent.
+    pub fn set_slop(&mut self, value: u32) {
+        self.slop = value;
+    }
+
+    /// Sets the maximum number of term expansions, summed over all regex terms.
+    /// Once the limit is exceeded, an error is returned.
+    pub fn set_max_expansions(&mut self, value: u32) {
+        self.max_expansions = value;
+    }
+
+    /// The [`Field`] this `RegexPhraseQuery` is targeting.
+    pub fn field(&self) -> Field {
+        self.field
+    }
+
+    /// `Term`s in the phrase without the associated offsets.
+    pub fn phrase_terms(&self) -> Vec<Term> {
+        self.phrase_terms
+            .iter()
+            .map(|(_, term)| Term::from_field_text(self.field, term))
+            .collect::<Vec<Term>>()
+    }
+
+    /// Returns the [`RegexPhraseWeight`] for the given phrase query given a specific `searcher`.
+    ///
+    /// This function is the same as [`Query::weight()`] except it returns
+    /// a specialized type [`RegexPhraseWeight`] instead of a Boxed trait.
+    pub(crate) fn regex_phrase_weight(
+        &self,
+        enable_scoring: EnableScoring<'_>,
+    ) -> crate::Result<RegexPhraseWeight> {
+        let schema = enable_scoring.schema();
+        let field_type = schema.get_field_entry(self.field).field_type().value_type();
+        if field_type != Type::Str {
+            return Err(crate::TantivyError::SchemaError(format!(
+                "RegexPhraseQuery can only be used with a field of type text currently, but got \
+                 {:?}",
+                field_type
+            )));
+        }
+
+        let field_entry = schema.get_field_entry(self.field);
+        let has_positions = field_entry
+            .field_type()
+            .get_index_record_option()
+            .map(IndexRecordOption::has_positions)
+            .unwrap_or(false);
+        if !has_positions {
+            let field_name = field_entry.name();
+            return Err(crate::TantivyError::SchemaError(format!(
+                "Applied phrase query on field {field_name:?}, which does not have positions \
+                 indexed"
+            )));
+        }
+        let terms = self.phrase_terms();
+        let bm25_weight_opt = match enable_scoring {
+            EnableScoring::Enabled {
+                statistics_provider,
+                ..
+            } => Some(Bm25Weight::for_terms(statistics_provider, &terms)?),
+            EnableScoring::Disabled { .. } => None,
+        };
+        let weight = RegexPhraseWeight::new(
+            self.field,
+            self.phrase_terms.clone(),
+            bm25_weight_opt,
+            self.max_expansions,
+            self.slop,
+        );
+        Ok(weight)
+    }
+}
+
+impl Query for RegexPhraseQuery {
+    /// Create the weight associated with a query.
+    ///
+    /// See [`Weight`].
+    fn weight(&self, enable_scoring: EnableScoring<'_>) -> crate::Result<Box<dyn Weight>> {
+        let phrase_weight = self.regex_phrase_weight(enable_scoring)?;
+        Ok(Box::new(phrase_weight))
+    }
+}
diff --git a/src/query/phrase_query/regex_phrase_weight.rs b/src/query/phrase_query/regex_phrase_weight.rs
new file mode 100644
index 0000000000..53959c6440
--- /dev/null
+++ b/src/query/phrase_query/regex_phrase_weight.rs
@@ -0,0 +1,475 @@
+use std::sync::Arc;
+
+use common::BitSet;
+use tantivy_fst::Regex;
+
+use super::PhraseScorer;
+use crate::fieldnorm::FieldNormReader;
+use crate::index::SegmentReader;
+use crate::postings::{LoadedPostings, Postings, SegmentPostings, TermInfo};
+use crate::query::bm25::Bm25Weight;
+use crate::query::explanation::does_not_match;
+use crate::query::union::{BitSetPostingUnion, SimpleUnion};
+use crate::query::{AutomatonWeight, BitSetDocSet, EmptyScorer, Explanation, Scorer, Weight};
+use crate::schema::{Field, IndexRecordOption};
+use crate::{DocId, DocSet, InvertedIndexReader, Score};
+
+type UnionType = SimpleUnion<Box<dyn Postings>>;
+
+/// The `RegexPhraseWeight` is the weight associated to a regex phrase query.
+/// See `RegexPhraseWeight::get_union_from_term_infos` for some design decisions.
+pub struct RegexPhraseWeight {
+    field: Field,
+    phrase_terms: Vec<(usize, String)>,
+    similarity_weight_opt: Option<Bm25Weight>,
+    slop: u32,
+    max_expansions: u32,
+}
+
+impl RegexPhraseWeight {
+    /// Creates a new phrase weight.
+    /// If `similarity_weight_opt` is None, then scoring is disabled.
+    pub fn new(
+        field: Field,
+        phrase_terms: Vec<(usize, String)>,
+        similarity_weight_opt: Option<Bm25Weight>,
+        max_expansions: u32,
+        slop: u32,
+    ) -> RegexPhraseWeight {
+        RegexPhraseWeight {
+            field,
+            phrase_terms,
+            similarity_weight_opt,
+            slop,
+            max_expansions,
+        }
+    }
+
+    fn fieldnorm_reader(&self, reader: &SegmentReader) -> crate::Result<FieldNormReader> {
+        if self.similarity_weight_opt.is_some() {
+            if let Some(fieldnorm_reader) = reader.fieldnorms_readers().get_field(self.field)? {
+                return Ok(fieldnorm_reader);
+            }
+        }
+        Ok(FieldNormReader::constant(reader.max_doc(), 1))
+    }
+
+    pub(crate) fn phrase_scorer(
+        &self,
+        reader: &SegmentReader,
+        boost: Score,
+    ) -> crate::Result<Option<PhraseScorer<UnionType>>> {
+        let similarity_weight_opt = self
+            .similarity_weight_opt
+            .as_ref()
+            .map(|similarity_weight| similarity_weight.boost_by(boost));
+        let fieldnorm_reader = self.fieldnorm_reader(reader)?;
+        let mut posting_lists = Vec::new();
+        let inverted_index = reader.inverted_index(self.field)?;
+        let mut num_terms = 0;
+        for &(offset, ref term) in &self.phrase_terms {
+            let regex = Regex::new(term)
+                .map_err(|e| crate::TantivyError::InvalidArgument(format!("Invalid regex: {e}")))?;
+
+            let automaton: AutomatonWeight<Regex> =
+                AutomatonWeight::new(self.field, Arc::new(regex));
+            let term_infos = automaton.get_match_term_infos(reader)?;
+            // If term_infos is empty, the phrase cannot match any documents.
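+            // Returning `Ok(None)` bubbles up to `Weight::scorer`, which then
+            // falls back to an `EmptyScorer` for this segment (see the `Weight`
+            // impl below).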
+            if term_infos.is_empty() {
+                return Ok(None);
+            }
+            num_terms += term_infos.len();
+            if num_terms > self.max_expansions as usize {
+                return Err(crate::TantivyError::InvalidArgument(format!(
+                    "Phrase query exceeded max expansions {}",
+                    num_terms
+                )));
+            }
+            let union = Self::get_union_from_term_infos(&term_infos, reader, &inverted_index)?;
+
+            posting_lists.push((offset, union));
+        }
+
+        Ok(Some(PhraseScorer::new(
+            posting_lists,
+            similarity_weight_opt,
+            fieldnorm_reader,
+            self.slop,
+        )))
+    }
+
+    /// Add all docs of the term to the docset
+    fn add_to_bitset(
+        inverted_index: &InvertedIndexReader,
+        term_info: &TermInfo,
+        doc_bitset: &mut BitSet,
+    ) -> crate::Result<()> {
+        let mut block_segment_postings = inverted_index
+            .read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic)?;
+        loop {
+            let docs = block_segment_postings.docs();
+            if docs.is_empty() {
+                break;
+            }
+            for &doc in docs {
+                doc_bitset.insert(doc);
+            }
+            block_segment_postings.advance();
+        }
+        Ok(())
+    }
+
+    /// This function generates a union of document sets from multiple term infos
+    /// (`TermInfo`).
+    ///
+    /// It uses bucketing based on term frequency to optimize query performance and memory usage.
+    /// The terms are divided into buckets based on their document frequency (the number of
+    /// documents they appear in).
+    ///
+    /// ### Bucketing Strategy:
+    /// Once a bucket contains 512 terms, it is moved to the end of the list and replaced
+    /// with a new empty bucket.
+    ///
+    /// - **Sparse Term Buckets**: Terms with document frequency `< 100`.
+    ///
+    ///   Each sparse bucket contains:
+    ///   - A `BitSet` to efficiently track which document IDs are present in the bucket, which
+    ///     is used to drive the `DocSet`.
+    ///   - A `Vec<LoadedPostings>` to store the postings for each term in that bucket.
+    ///
+    /// - **Other Term Buckets**:
+    ///   - **Bucket 0**: Terms appearing in less than `0.1%` of documents.
+    ///   - **Bucket 1**: Terms appearing in `0.1%` to `1%` of documents.
+    ///   - **Bucket 2**: Terms appearing in `1%` to `10%` of documents.
+    ///   - **Bucket 3**: Terms appearing in more than `10%` of documents.
+    ///
+    ///   Each bucket contains:
+    ///   - A `BitSet` to efficiently track which document IDs are present in the bucket.
+    ///   - A `Vec<SegmentPostings>` to store the postings for each term in that bucket.
+    ///
+    /// ### Design Choices:
+    /// The main cost for an _unbucketed_ regex phrase query with a medium/high number of terms
+    /// is the `append_positions_with_offset` from `Postings`.
+    /// We don't know which docsets hit, so we need to scan all of them to check if they contain
+    /// the docid.
+    /// The bucketing strategy groups less common DocSets together, so we can rule out a
+    /// whole docset group in many cases.
+    ///
+    /// E.g. consider the phrase "th* world".
+    /// It contains the term "the", which may occur in almost all documents.
+    /// It may also contain 10_000s of very rare terms like "theologian".
+    ///
+    /// For very low-frequency terms (sparse terms), we use `LoadedPostings` and aggregate
+    /// their document IDs into a `BitSet`, which is more memory-efficient than using
+    /// `SegmentPostings`. E.g. 100_000 terms with SegmentPostings would consume 184MB.
+    /// `SegmentPostings` uses memory equivalent to 460 docids. The 100 docs limit should be
+    /// fine as long as a term doesn't have too many positions per doc.
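+    ///
+    /// As a worked example of the thresholds (illustrative numbers): in a segment
+    /// with 1_000_000 documents, a term goes to a sparse bucket below 100 docs,
+    /// to bucket 0 below 1_000 docs (0.1%), bucket 1 below 10_000 (1%), bucket 2
+    /// below 100_000 (10%), and bucket 3 above that.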
+    ///
+    /// ### Future Optimization:
+    /// A larger performance improvement would be an additional vertical partitioning of the
+    /// doc id space into blocks of u16::MAX docids, where we mark which docset ords have
+    /// values in each block.
+    /// E.g. in an index with 5 million documents this would reduce the number of
+    /// docsets to scan to around 1/20 in the sparse term bucket, where the terms only have a
+    /// few docs. For higher-cardinality buckets this is irrelevant, as they have values in
+    /// most blocks.
+    ///
+    /// Use Roaring Bitmaps for sparse terms. The full bitvec is the main memory consumer
+    /// currently.
+    pub(crate) fn get_union_from_term_infos(
+        term_infos: &[TermInfo],
+        reader: &SegmentReader,
+        inverted_index: &InvertedIndexReader,
+    ) -> crate::Result<UnionType> {
+        let max_doc = reader.max_doc();
+
+        // Buckets for sparse terms
+        let mut sparse_buckets: Vec<(BitSet, Vec<LoadedPostings>)> =
+            vec![(BitSet::with_max_value(max_doc), Vec::new())];
+
+        // Buckets for other terms based on document frequency percentages:
+        // - Bucket 0: Terms appearing in less than 0.1% of documents
+        // - Bucket 1: Terms appearing in 0.1% to 1% of documents
+        // - Bucket 2: Terms appearing in 1% to 10% of documents
+        // - Bucket 3: Terms appearing in more than 10% of documents
+        let mut buckets: Vec<(BitSet, Vec<SegmentPostings>)> = (0..4)
+            .map(|_| (BitSet::with_max_value(max_doc), Vec::new()))
+            .collect();
+
+        const SPARSE_TERM_DOC_THRESHOLD: u32 = 100;
+
+        for term_info in term_infos {
+            let mut term_posting = inverted_index
+                .read_postings_from_terminfo(term_info, IndexRecordOption::WithFreqsAndPositions)?;
+            let num_docs = term_posting.doc_freq();
+
+            if num_docs < SPARSE_TERM_DOC_THRESHOLD {
+                let current_bucket = &mut sparse_buckets[0];
+                Self::add_to_bitset(inverted_index, term_info, &mut current_bucket.0)?;
+                let docset = LoadedPostings::load(&mut term_posting);
+                current_bucket.1.push(docset);
+
+                // Move the bucket to the end if the term limit is reached
+                if current_bucket.1.len() == 512 {
+                    sparse_buckets.push((BitSet::with_max_value(max_doc), Vec::new()));
+                    let end_index = sparse_buckets.len() - 1;
+                    sparse_buckets.swap(0, end_index);
+                }
+            } else {
+                // Calculate the percentage of documents the term appears in
+                let doc_freq_percentage = (num_docs as f32) / (max_doc as f32) * 100.0;
+
+                // Determine the appropriate bucket based on percentage thresholds
+                let bucket_index = if doc_freq_percentage < 0.1 {
+                    0
+                } else if doc_freq_percentage < 1.0 {
+                    1
+                } else if doc_freq_percentage < 10.0 {
+                    2
+                } else {
+                    3
+                };
+                let bucket = &mut buckets[bucket_index];
+
+                // Add term postings to the appropriate bucket
+                Self::add_to_bitset(inverted_index, term_info, &mut bucket.0)?;
+                bucket.1.push(term_posting);
+
+                // Move the bucket to the end if the term limit is reached
+                if bucket.1.len() == 512 {
+                    buckets.push((BitSet::with_max_value(max_doc), Vec::new()));
+                    let end_index = buckets.len() - 1;
+                    buckets.swap(bucket_index, end_index);
+                }
+            }
+        }
+
+        // Build unions for sparse term buckets
+        let sparse_term_docsets: Vec<_> = sparse_buckets
+            .into_iter()
+            .filter(|(_, postings)| !postings.is_empty())
+            .map(|(bitset, postings)| {
+                BitSetPostingUnion::build(postings, BitSetDocSet::from(bitset))
+            })
+            .collect();
+        let sparse_term_unions = SimpleUnion::build(sparse_term_docsets);
+
+        // Build unions for other term buckets
+        let bitset_unions_per_bucket: Vec<_> = buckets
+            .into_iter()
+            .filter(|(_, postings)| !postings.is_empty())
+            .map(|(bitset, postings)| {
+                BitSetPostingUnion::build(postings, BitSetDocSet::from(bitset))
+            })
+            .collect();
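+
+        // Build the non-sparse union, then erase both halves behind
+        // `Box<dyn Postings>` so a single `SimpleUnion` drives this phrase
+        // position.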
let other_union = SimpleUnion::build(bitset_unions_per_bucket); + + let union: SimpleUnion> = + SimpleUnion::build(vec![Box::new(sparse_term_unions), Box::new(other_union)]); + + // Return a union of sparse term unions and other term unions + Ok(union) + } +} + +impl Weight for RegexPhraseWeight { + fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result> { + if let Some(scorer) = self.phrase_scorer(reader, boost)? { + Ok(Box::new(scorer)) + } else { + Ok(Box::new(EmptyScorer)) + } + } + + fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result { + let scorer_opt = self.phrase_scorer(reader, 1.0)?; + if scorer_opt.is_none() { + return Err(does_not_match(doc)); + } + let mut scorer = scorer_opt.unwrap(); + if scorer.seek(doc) != doc { + return Err(does_not_match(doc)); + } + let fieldnorm_reader = self.fieldnorm_reader(reader)?; + let fieldnorm_id = fieldnorm_reader.fieldnorm_id(doc); + let phrase_count = scorer.phrase_count(); + let mut explanation = Explanation::new("Phrase Scorer", scorer.score()); + if let Some(similarity_weight) = self.similarity_weight_opt.as_ref() { + explanation.add_detail(similarity_weight.explain(fieldnorm_id, phrase_count)); + } + Ok(explanation) + } +} + +#[cfg(test)] +mod tests { + use proptest::prelude::*; + use rand::seq::SliceRandom; + + use super::super::tests::create_index; + use crate::docset::TERMINATED; + use crate::query::{wildcard_query_to_regex_str, EnableScoring, RegexPhraseQuery}; + use crate::DocSet; + + proptest! { + #![proptest_config(ProptestConfig::with_cases(50))] + #[test] + fn test_phrase_regex_with_random_strings(mut random_strings in proptest::collection::vec("[c-z ]{0,10}", 1..100), num_occurrences in 1..150_usize) { + let mut rng = rand::thread_rng(); + + // Insert "aaa ccc" the specified number of times into the list + for _ in 0..num_occurrences { + random_strings.push("aaa ccc".to_string()); + } + // Shuffle the list, which now contains random strings and the inserted "aaa ccc" + random_strings.shuffle(&mut rng); + + // Compute the positions of "aaa ccc" after the shuffle + let aaa_ccc_positions: Vec = random_strings + .iter() + .enumerate() + .filter_map(|(idx, s)| if s == "aaa ccc" { Some(idx) } else { None }) + .collect(); + + // Create the index with random strings and the fixed string "aaa ccc" + let index = create_index(&random_strings.iter().map(AsRef::as_ref).collect::>())?; + let schema = index.schema(); + let text_field = schema.get_field("text").unwrap(); + let searcher = index.reader()?.searcher(); + + let phrase_query = RegexPhraseQuery::new(text_field, vec![wildcard_query_to_regex_str("a*"), wildcard_query_to_regex_str("c*")]); + + let enable_scoring = EnableScoring::enabled_from_searcher(&searcher); + let phrase_weight = phrase_query.regex_phrase_weight(enable_scoring).unwrap(); + let mut phrase_scorer = phrase_weight + .phrase_scorer(searcher.segment_reader(0u32), 1.0)? 
+ .unwrap(); + + // Check if the scorer returns the correct document positions for "aaa ccc" + for expected_doc in aaa_ccc_positions { + prop_assert_eq!(phrase_scorer.doc(), expected_doc as u32); + prop_assert_eq!(phrase_scorer.phrase_count(), 1); + phrase_scorer.advance(); + } + prop_assert_eq!(phrase_scorer.advance(), TERMINATED); + } + } + + #[test] + pub fn test_phrase_count() -> crate::Result<()> { + let index = create_index(&["a c", "a a b d a b c", " a b"])?; + let schema = index.schema(); + let text_field = schema.get_field("text").unwrap(); + let searcher = index.reader()?.searcher(); + let phrase_query = RegexPhraseQuery::new(text_field, vec!["a".into(), "b".into()]); + let enable_scoring = EnableScoring::enabled_from_searcher(&searcher); + let phrase_weight = phrase_query.regex_phrase_weight(enable_scoring).unwrap(); + let mut phrase_scorer = phrase_weight + .phrase_scorer(searcher.segment_reader(0u32), 1.0)? + .unwrap(); + assert_eq!(phrase_scorer.doc(), 1); + assert_eq!(phrase_scorer.phrase_count(), 2); + assert_eq!(phrase_scorer.advance(), 2); + assert_eq!(phrase_scorer.doc(), 2); + assert_eq!(phrase_scorer.phrase_count(), 1); + assert_eq!(phrase_scorer.advance(), TERMINATED); + Ok(()) + } + + #[test] + pub fn test_phrase_wildcard() -> crate::Result<()> { + let index = create_index(&["a c", "a aa b d ad b c", " ac b", "bac b"])?; + let schema = index.schema(); + let text_field = schema.get_field("text").unwrap(); + let searcher = index.reader()?.searcher(); + let phrase_query = RegexPhraseQuery::new(text_field, vec!["a.*".into(), "b".into()]); + let enable_scoring = EnableScoring::enabled_from_searcher(&searcher); + let phrase_weight = phrase_query.regex_phrase_weight(enable_scoring).unwrap(); + let mut phrase_scorer = phrase_weight + .phrase_scorer(searcher.segment_reader(0u32), 1.0)? + .unwrap(); + assert_eq!(phrase_scorer.doc(), 1); + assert_eq!(phrase_scorer.phrase_count(), 2); + assert_eq!(phrase_scorer.advance(), 2); + assert_eq!(phrase_scorer.doc(), 2); + assert_eq!(phrase_scorer.phrase_count(), 1); + assert_eq!(phrase_scorer.advance(), TERMINATED); + + Ok(()) + } + + #[test] + pub fn test_phrase_regex() -> crate::Result<()> { + let index = create_index(&["ba b", "a aa b d ad b c", "bac b"])?; + let schema = index.schema(); + let text_field = schema.get_field("text").unwrap(); + let searcher = index.reader()?.searcher(); + let phrase_query = RegexPhraseQuery::new(text_field, vec!["b?a.*".into(), "b".into()]); + let enable_scoring = EnableScoring::enabled_from_searcher(&searcher); + let phrase_weight = phrase_query.regex_phrase_weight(enable_scoring).unwrap(); + let mut phrase_scorer = phrase_weight + .phrase_scorer(searcher.segment_reader(0u32), 1.0)? 
+ .unwrap(); + assert_eq!(phrase_scorer.doc(), 0); + assert_eq!(phrase_scorer.phrase_count(), 1); + assert_eq!(phrase_scorer.advance(), 1); + assert_eq!(phrase_scorer.phrase_count(), 2); + assert_eq!(phrase_scorer.advance(), 2); + assert_eq!(phrase_scorer.doc(), 2); + assert_eq!(phrase_scorer.phrase_count(), 1); + assert_eq!(phrase_scorer.advance(), TERMINATED); + + Ok(()) + } + + #[test] + pub fn test_phrase_regex_with_slop() -> crate::Result<()> { + let index = create_index(&["aaa bbb ccc ___ abc ddd bbb ccc"])?; + let schema = index.schema(); + let text_field = schema.get_field("text").unwrap(); + let searcher = index.reader()?.searcher(); + let mut phrase_query = RegexPhraseQuery::new(text_field, vec!["a.*".into(), "c.*".into()]); + phrase_query.set_slop(1); + let enable_scoring = EnableScoring::enabled_from_searcher(&searcher); + let phrase_weight = phrase_query.regex_phrase_weight(enable_scoring).unwrap(); + let mut phrase_scorer = phrase_weight + .phrase_scorer(searcher.segment_reader(0u32), 1.0)? + .unwrap(); + assert_eq!(phrase_scorer.doc(), 0); + assert_eq!(phrase_scorer.phrase_count(), 1); + assert_eq!(phrase_scorer.advance(), TERMINATED); + + phrase_query.set_slop(2); + let enable_scoring = EnableScoring::enabled_from_searcher(&searcher); + let phrase_weight = phrase_query.regex_phrase_weight(enable_scoring).unwrap(); + let mut phrase_scorer = phrase_weight + .phrase_scorer(searcher.segment_reader(0u32), 1.0)? + .unwrap(); + assert_eq!(phrase_scorer.doc(), 0); + assert_eq!(phrase_scorer.phrase_count(), 2); + assert_eq!(phrase_scorer.advance(), TERMINATED); + + Ok(()) + } + + #[test] + pub fn test_phrase_regex_double_wildcard() -> crate::Result<()> { + let index = create_index(&["baaab bccccb"])?; + let schema = index.schema(); + let text_field = schema.get_field("text").unwrap(); + let searcher = index.reader()?.searcher(); + let phrase_query = RegexPhraseQuery::new( + text_field, + vec![ + wildcard_query_to_regex_str("*a*"), + wildcard_query_to_regex_str("*c*"), + ], + ); + let enable_scoring = EnableScoring::enabled_from_searcher(&searcher); + let phrase_weight = phrase_query.regex_phrase_weight(enable_scoring).unwrap(); + let mut phrase_scorer = phrase_weight + .phrase_scorer(searcher.segment_reader(0u32), 1.0)? + .unwrap(); + assert_eq!(phrase_scorer.doc(), 0); + assert_eq!(phrase_scorer.phrase_count(), 1); + assert_eq!(phrase_scorer.advance(), TERMINATED); + Ok(()) + } +} diff --git a/src/query/union/bitset_union.rs b/src/query/union/bitset_union.rs new file mode 100644 index 0000000000..8af1703ee0 --- /dev/null +++ b/src/query/union/bitset_union.rs @@ -0,0 +1,89 @@ +use std::cell::RefCell; + +use crate::docset::DocSet; +use crate::postings::Postings; +use crate::query::BitSetDocSet; +use crate::DocId; + +/// Creates a `Posting` that uses the bitset for hits and the docsets for PostingLists. +/// +/// It is used for the regex phrase query, where we need the union of a large amount of +/// terms, but need to keep the docsets for the postings. 
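+///
+/// The bitset must be exactly the union of the docsets' doc ids: iteration
+/// (`advance`/`seek`) is driven by the bitset alone, and the individual docsets
+/// are only seeked lazily when `term_freq` or positions are requested.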
+pub struct BitSetPostingUnion<TDocSet> {
+    /// The docsets are required to load positions
+    ///
+    /// RefCell because we mutate in term_freq
+    docsets: RefCell<Vec<TDocSet>>,
+    /// The already unionized BitSet of the docsets
+    bitset: BitSetDocSet,
+}
+
+impl<TDocSet: DocSet> BitSetPostingUnion<TDocSet> {
+    pub(crate) fn build(
+        docsets: Vec<TDocSet>,
+        bitset: BitSetDocSet,
+    ) -> BitSetPostingUnion<TDocSet> {
+        BitSetPostingUnion {
+            docsets: RefCell::new(docsets),
+            bitset,
+        }
+    }
+}
+
+impl<TDocSet: Postings> Postings for BitSetPostingUnion<TDocSet> {
+    fn term_freq(&self) -> u32 {
+        let curr_doc = self.bitset.doc();
+        let mut term_freq = 0;
+        let mut docsets = self.docsets.borrow_mut();
+        for docset in docsets.iter_mut() {
+            if docset.doc() < curr_doc {
+                docset.seek(curr_doc);
+            }
+            if docset.doc() == curr_doc {
+                term_freq += docset.term_freq();
+            }
+        }
+        term_freq
+    }
+
+    fn append_positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>) {
+        let curr_doc = self.bitset.doc();
+        let mut docsets = self.docsets.borrow_mut();
+        for docset in docsets.iter_mut() {
+            if docset.doc() < curr_doc {
+                docset.seek(curr_doc);
+            }
+            if docset.doc() == curr_doc {
+                docset.append_positions_with_offset(offset, output);
+            }
+        }
+        debug_assert!(
+            !output.is_empty(),
+            "this method should only be called if positions are available"
+        );
+        output.sort_unstable();
+        output.dedup();
+    }
+}
+
+impl<TDocSet: DocSet> DocSet for BitSetPostingUnion<TDocSet> {
+    fn advance(&mut self) -> DocId {
+        self.bitset.advance()
+    }
+
+    fn seek(&mut self, target: DocId) -> DocId {
+        self.bitset.seek(target)
+    }
+
+    fn doc(&self) -> DocId {
+        self.bitset.doc()
+    }
+
+    fn size_hint(&self) -> u32 {
+        self.bitset.size_hint()
+    }
+
+    fn count_including_deleted(&mut self) -> u32 {
+        self.bitset.count_including_deleted()
+    }
+}
diff --git a/src/query/union.rs b/src/query/union/buffered_union.rs
similarity index 50%
rename from src/query/union.rs
rename to src/query/union/buffered_union.rs
index b1f23156a2..5fc946ee11 100644
--- a/src/query/union.rs
+++ b/src/query/union/buffered_union.rs
@@ -26,7 +26,7 @@ where P: FnMut(&mut T) -> bool {
 }
 
 /// Creates a `DocSet` that iterate through the union of two or more `DocSet`s.
-pub struct Union { +pub struct BufferedUnionScorer { docsets: Vec, bitsets: Box<[TinySet; HORIZON_NUM_TINYBITSETS]>, scores: Box<[TScoreCombiner; HORIZON as usize]>, @@ -61,16 +61,16 @@ fn refill( }); } -impl Union { +impl BufferedUnionScorer { pub(crate) fn build( docsets: Vec, score_combiner_fn: impl FnOnce() -> TScoreCombiner, - ) -> Union { + ) -> BufferedUnionScorer { let non_empty_docsets: Vec = docsets .into_iter() .filter(|docset| docset.doc() != TERMINATED) .collect(); - let mut union = Union { + let mut union = BufferedUnionScorer { docsets: non_empty_docsets, bitsets: Box::new([TinySet::empty(); HORIZON_NUM_TINYBITSETS]), scores: Box::new([score_combiner_fn(); HORIZON as usize]), @@ -121,7 +121,7 @@ impl Union DocSet for Union +impl DocSet for BufferedUnionScorer where TScorer: Scorer, TScoreCombiner: ScoreCombiner, @@ -230,7 +230,7 @@ where } } -impl Scorer for Union +impl Scorer for BufferedUnionScorer where TScoreCombiner: ScoreCombiner, TScorer: Scorer, @@ -239,205 +239,3 @@ where self.score } } - -#[cfg(test)] -mod tests { - - use std::collections::BTreeSet; - - use super::{Union, HORIZON}; - use crate::docset::{DocSet, TERMINATED}; - use crate::postings::tests::test_skip_against_unoptimized; - use crate::query::score_combiner::DoNothingCombiner; - use crate::query::{ConstScorer, VecDocSet}; - use crate::{tests, DocId}; - - fn aux_test_union(vals: Vec>) { - let mut val_set: BTreeSet = BTreeSet::new(); - for vs in &vals { - for &v in vs { - val_set.insert(v); - } - } - let union_vals: Vec = val_set.into_iter().collect(); - let mut union_expected = VecDocSet::from(union_vals); - let make_union = || { - Union::build( - vals.iter() - .cloned() - .map(VecDocSet::from) - .map(|docset| ConstScorer::new(docset, 1.0)) - .collect::>>(), - DoNothingCombiner::default, - ) - }; - let mut union: Union<_, DoNothingCombiner> = make_union(); - let mut count = 0; - while union.doc() != TERMINATED { - assert_eq!(union_expected.doc(), union.doc()); - assert_eq!(union_expected.advance(), union.advance()); - count += 1; - } - assert_eq!(union_expected.advance(), TERMINATED); - assert_eq!(count, make_union().count_including_deleted()); - } - - #[test] - fn test_union() { - aux_test_union(vec![ - vec![1, 3333, 100000000u32], - vec![1, 2, 100000000u32], - vec![1, 2, 100000000u32], - vec![], - ]); - aux_test_union(vec![ - vec![1, 3333, 100000000u32], - vec![1, 2, 100000000u32], - vec![1, 2, 100000000u32], - vec![], - ]); - aux_test_union(vec![ - tests::sample_with_seed(100_000, 0.01, 1), - tests::sample_with_seed(100_000, 0.05, 2), - tests::sample_with_seed(100_000, 0.001, 3), - ]); - } - - fn test_aux_union_skip(docs_list: &[Vec], skip_targets: Vec) { - let mut btree_set = BTreeSet::new(); - for docs in docs_list { - btree_set.extend(docs.iter().cloned()); - } - let docset_factory = || { - let res: Box = Box::new(Union::build( - docs_list - .iter() - .cloned() - .map(VecDocSet::from) - .map(|docset| ConstScorer::new(docset, 1.0)) - .collect::>(), - DoNothingCombiner::default, - )); - res - }; - let mut docset = docset_factory(); - for el in btree_set { - assert_eq!(el, docset.doc()); - docset.advance(); - } - assert_eq!(docset.doc(), TERMINATED); - test_skip_against_unoptimized(docset_factory, skip_targets); - } - - #[test] - fn test_union_skip_corner_case() { - test_aux_union_skip(&[vec![165132, 167382], vec![25029, 25091]], vec![25029]); - } - - #[test] - fn test_union_skip_corner_case2() { - test_aux_union_skip( - &[vec![1u32, 1u32 + HORIZON], vec![2u32, 1000u32, 10_000u32]], - vec![0u32, 1u32, 
2u32, 3u32, 1u32 + HORIZON, 2u32 + HORIZON], - ); - } - - #[test] - fn test_union_skip_corner_case3() { - let mut docset = Union::build( - vec![ - ConstScorer::from(VecDocSet::from(vec![0u32, 5u32])), - ConstScorer::from(VecDocSet::from(vec![1u32, 4u32])), - ], - DoNothingCombiner::default, - ); - assert_eq!(docset.doc(), 0u32); - assert_eq!(docset.seek(0u32), 0u32); - assert_eq!(docset.seek(0u32), 0u32); - assert_eq!(docset.doc(), 0u32) - } - - #[test] - fn test_union_skip_random() { - test_aux_union_skip( - &[ - vec![1, 2, 3, 7], - vec![1, 3, 9, 10000], - vec![1, 3, 8, 9, 100], - ], - vec![1, 2, 3, 5, 6, 7, 8, 100], - ); - test_aux_union_skip( - &[ - tests::sample_with_seed(100_000, 0.001, 1), - tests::sample_with_seed(100_000, 0.002, 2), - tests::sample_with_seed(100_000, 0.005, 3), - ], - tests::sample_with_seed(100_000, 0.01, 4), - ); - } - - #[test] - fn test_union_skip_specific() { - test_aux_union_skip( - &[ - vec![1, 2, 3, 7], - vec![1, 3, 9, 10000], - vec![1, 3, 8, 9, 100], - ], - vec![1, 2, 3, 7, 8, 9, 99, 100, 101, 500, 20000], - ); - } -} - -#[cfg(all(test, feature = "unstable"))] -mod bench { - - use test::Bencher; - - use crate::query::score_combiner::DoNothingCombiner; - use crate::query::{ConstScorer, Union, VecDocSet}; - use crate::{tests, DocId, DocSet, TERMINATED}; - - #[bench] - fn bench_union_3_high(bench: &mut Bencher) { - let union_docset: Vec> = vec![ - tests::sample_with_seed(100_000, 0.1, 0), - tests::sample_with_seed(100_000, 0.2, 1), - ]; - bench.iter(|| { - let mut v = Union::build( - union_docset - .iter() - .map(|doc_ids| VecDocSet::from(doc_ids.clone())) - .map(|docset| ConstScorer::new(docset, 1.0)) - .collect::>(), - DoNothingCombiner::default, - ); - while v.doc() != TERMINATED { - v.advance(); - } - }); - } - #[bench] - fn bench_union_3_low(bench: &mut Bencher) { - let union_docset: Vec> = vec![ - tests::sample_with_seed(100_000, 0.01, 0), - tests::sample_with_seed(100_000, 0.05, 1), - tests::sample_with_seed(100_000, 0.001, 2), - ]; - bench.iter(|| { - let mut v = Union::build( - union_docset - .iter() - .map(|doc_ids| VecDocSet::from(doc_ids.clone())) - .map(|docset| ConstScorer::new(docset, 1.0)) - .collect::>(), - DoNothingCombiner::default, - ); - while v.doc() != TERMINATED { - v.advance(); - } - }); - } -} diff --git a/src/query/union/mod.rs b/src/query/union/mod.rs new file mode 100644 index 0000000000..84153e272f --- /dev/null +++ b/src/query/union/mod.rs @@ -0,0 +1,303 @@ +mod bitset_union; +mod buffered_union; +mod simple_union; + +pub use bitset_union::BitSetPostingUnion; +pub use buffered_union::BufferedUnionScorer; +pub use simple_union::SimpleUnion; + +#[cfg(test)] +mod tests { + + use std::collections::BTreeSet; + + use common::BitSet; + + use super::{SimpleUnion, *}; + use crate::docset::{DocSet, TERMINATED}; + use crate::postings::tests::test_skip_against_unoptimized; + use crate::query::score_combiner::DoNothingCombiner; + use crate::query::union::bitset_union::BitSetPostingUnion; + use crate::query::{BitSetDocSet, ConstScorer, VecDocSet}; + use crate::{tests, DocId}; + + fn vec_doc_set_from_docs_list( + docs_list: &[Vec], + ) -> impl Iterator + '_ { + docs_list.iter().cloned().map(VecDocSet::from) + } + fn union_from_docs_list(docs_list: &[Vec]) -> Box { + Box::new(BufferedUnionScorer::build( + vec_doc_set_from_docs_list(docs_list) + .map(|docset| ConstScorer::new(docset, 1.0)) + .collect::>>(), + DoNothingCombiner::default, + )) + } + + fn posting_list_union_from_docs_list(docs_list: &[Vec]) -> Box { + 
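+        // The bitset handed to `BitSetPostingUnion` is the precomputed union of
+        // the docsets' doc ids, mirroring the invariant `RegexPhraseWeight`
+        // maintains when it builds one.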
Box::new(BitSetPostingUnion::build( + vec_doc_set_from_docs_list(docs_list).collect::>(), + bitset_from_docs_list(docs_list), + )) + } + fn simple_union_from_docs_list(docs_list: &[Vec]) -> Box { + Box::new(SimpleUnion::build( + vec_doc_set_from_docs_list(docs_list).collect::>(), + )) + } + fn bitset_from_docs_list(docs_list: &[Vec]) -> BitSetDocSet { + let max_doc = docs_list + .iter() + .flat_map(|docs| docs.iter().copied()) + .max() + .unwrap_or(0); + let mut doc_bitset = BitSet::with_max_value(max_doc + 1); + for docs in docs_list { + for &doc in docs { + doc_bitset.insert(doc); + } + } + BitSetDocSet::from(doc_bitset) + } + fn aux_test_union(docs_list: &[Vec]) { + for constructor in [ + posting_list_union_from_docs_list, + simple_union_from_docs_list, + union_from_docs_list, + ] { + aux_test_union_with_constructor(constructor, docs_list); + } + } + fn aux_test_union_with_constructor(constructor: F, docs_list: &[Vec]) + where F: Fn(&[Vec]) -> Box { + let mut val_set: BTreeSet = BTreeSet::new(); + for vs in docs_list { + for &v in vs { + val_set.insert(v); + } + } + let union_vals: Vec = val_set.into_iter().collect(); + let mut union_expected = VecDocSet::from(union_vals); + let make_union = || constructor(docs_list); + let mut union = make_union(); + let mut count = 0; + while union.doc() != TERMINATED { + assert_eq!(union_expected.doc(), union.doc()); + assert_eq!(union_expected.advance(), union.advance()); + count += 1; + } + assert_eq!(union_expected.advance(), TERMINATED); + assert_eq!(count, make_union().count_including_deleted()); + } + + use proptest::prelude::*; + + proptest! { + #[test] + fn test_union_is_same(vecs in prop::collection::vec( + prop::collection::vec(0u32..100, 1..10) + .prop_map(|mut inner| { + inner.sort_unstable(); + inner.dedup(); + inner + }), + 1..10 + ), + seek_docids in prop::collection::vec(0u32..100, 0..10).prop_map(|mut inner| { + inner.sort_unstable(); + inner + })) { + test_docid_with_skip(&vecs, &seek_docids); + } + } + + fn test_docid_with_skip(vecs: &[Vec], skip_targets: &[DocId]) { + let mut union1 = posting_list_union_from_docs_list(vecs); + let mut union2 = simple_union_from_docs_list(vecs); + let mut union3 = union_from_docs_list(vecs); + + // Check initial sequential advance + while union1.doc() != TERMINATED { + assert_eq!(union1.doc(), union2.doc()); + assert_eq!(union1.doc(), union3.doc()); + assert_eq!(union1.advance(), union2.advance()); + assert_eq!(union1.doc(), union3.advance()); + } + + // Reset and test seek functionality + let mut union1 = posting_list_union_from_docs_list(vecs); + let mut union2 = simple_union_from_docs_list(vecs); + let mut union3 = union_from_docs_list(vecs); + + for &seek_docid in skip_targets { + union1.seek(seek_docid); + union2.seek(seek_docid); + union3.seek(seek_docid); + + // Verify that all unions have the same document after seeking + assert_eq!(union3.doc(), union1.doc()); + assert_eq!(union3.doc(), union2.doc()); + } + } + + #[test] + fn test_union() { + aux_test_union(&[ + vec![1, 3333, 100000000u32], + vec![1, 2, 100000000u32], + vec![1, 2, 100000000u32], + vec![], + ]); + aux_test_union(&[ + vec![1, 3333, 100000000u32], + vec![1, 2, 100000000u32], + vec![1, 2, 100000000u32], + vec![], + ]); + aux_test_union(&[ + tests::sample_with_seed(100_000, 0.01, 1), + tests::sample_with_seed(100_000, 0.05, 2), + tests::sample_with_seed(100_000, 0.001, 3), + ]); + } + + fn test_aux_union_skip(docs_list: &[Vec], skip_targets: Vec) { + for constructor in [ + posting_list_union_from_docs_list, + 
simple_union_from_docs_list, + union_from_docs_list, + ] { + test_aux_union_skip_with_constructor(constructor, docs_list, skip_targets.clone()); + } + } + fn test_aux_union_skip_with_constructor( + constructor: F, + docs_list: &[Vec], + skip_targets: Vec, + ) where + F: Fn(&[Vec]) -> Box, + { + let mut btree_set = BTreeSet::new(); + for docs in docs_list { + btree_set.extend(docs.iter().cloned()); + } + let docset_factory = || { + let res: Box = constructor(docs_list); + res + }; + let mut docset = constructor(docs_list); + for el in btree_set { + assert_eq!(el, docset.doc()); + docset.advance(); + } + assert_eq!(docset.doc(), TERMINATED); + test_skip_against_unoptimized(docset_factory, skip_targets); + } + + #[test] + fn test_union_skip_corner_case() { + test_aux_union_skip(&[vec![165132, 167382], vec![25029, 25091]], vec![25029]); + } + + #[test] + fn test_union_skip_corner_case2() { + test_aux_union_skip( + &[vec![1u32, 1u32 + 100], vec![2u32, 1000u32, 10_000u32]], + vec![0u32, 1u32, 2u32, 3u32, 1u32 + 100, 2u32 + 100], + ); + } + + #[test] + fn test_union_skip_corner_case3() { + let mut docset = posting_list_union_from_docs_list(&[vec![0u32, 5u32], vec![1u32, 4u32]]); + assert_eq!(docset.doc(), 0u32); + assert_eq!(docset.seek(0u32), 0u32); + assert_eq!(docset.seek(0u32), 0u32); + assert_eq!(docset.doc(), 0u32) + } + + #[test] + fn test_union_skip_random() { + test_aux_union_skip( + &[ + vec![1, 2, 3, 7], + vec![1, 3, 9, 10000], + vec![1, 3, 8, 9, 100], + ], + vec![1, 2, 3, 5, 6, 7, 8, 100], + ); + test_aux_union_skip( + &[ + tests::sample_with_seed(100_000, 0.001, 1), + tests::sample_with_seed(100_000, 0.002, 2), + tests::sample_with_seed(100_000, 0.005, 3), + ], + tests::sample_with_seed(100_000, 0.01, 4), + ); + } + + #[test] + fn test_union_skip_specific() { + test_aux_union_skip( + &[ + vec![1, 2, 3, 7], + vec![1, 3, 9, 10000], + vec![1, 3, 8, 9, 100], + ], + vec![1, 2, 3, 7, 8, 9, 99, 100, 101, 500, 20000], + ); + } +} + +#[cfg(all(test, feature = "unstable"))] +mod bench { + + use test::Bencher; + + use crate::query::score_combiner::DoNothingCombiner; + use crate::query::{BufferedUnionScorer, ConstScorer, VecDocSet}; + use crate::{tests, DocId, DocSet, TERMINATED}; + + #[bench] + fn bench_union_3_high(bench: &mut Bencher) { + let union_docset: Vec> = vec![ + tests::sample_with_seed(100_000, 0.1, 0), + tests::sample_with_seed(100_000, 0.2, 1), + ]; + bench.iter(|| { + let mut v = BufferedUnionScorer::build( + union_docset + .iter() + .map(|doc_ids| VecDocSet::from(doc_ids.clone())) + .map(|docset| ConstScorer::new(docset, 1.0)) + .collect::>(), + DoNothingCombiner::default, + ); + while v.doc() != TERMINATED { + v.advance(); + } + }); + } + #[bench] + fn bench_union_3_low(bench: &mut Bencher) { + let union_docset: Vec> = vec![ + tests::sample_with_seed(100_000, 0.01, 0), + tests::sample_with_seed(100_000, 0.05, 1), + tests::sample_with_seed(100_000, 0.001, 2), + ]; + bench.iter(|| { + let mut v = BufferedUnionScorer::build( + union_docset + .iter() + .map(|doc_ids| VecDocSet::from(doc_ids.clone())) + .map(|docset| ConstScorer::new(docset, 1.0)) + .collect::>(), + DoNothingCombiner::default, + ); + while v.doc() != TERMINATED { + v.advance(); + } + }); + } +} diff --git a/src/query/union/simple_union.rs b/src/query/union/simple_union.rs new file mode 100644 index 0000000000..041d4c90e1 --- /dev/null +++ b/src/query/union/simple_union.rs @@ -0,0 +1,112 @@ +use crate::docset::{DocSet, TERMINATED}; +use crate::postings::Postings; +use crate::DocId; + +/// A `SimpleUnion` is a `DocSet` 
that is the union of multiple `DocSet`s.
+/// Unlike `BufferedUnionScorer`, it doesn't do any horizon precomputation.
+///
+/// For that reason `SimpleUnion` is a good choice for queries that skip a lot.
+pub struct SimpleUnion<TDocSet> {
+    docsets: Vec<TDocSet>,
+    doc: DocId,
+}
+
+impl<TDocSet: DocSet> SimpleUnion<TDocSet> {
+    pub(crate) fn build(mut docsets: Vec<TDocSet>) -> SimpleUnion<TDocSet> {
+        docsets.retain(|docset| docset.doc() != TERMINATED);
+        let mut docset = SimpleUnion { docsets, doc: 0 };
+
+        docset.initialize_first_doc_id();
+
+        docset
+    }
+
+    fn initialize_first_doc_id(&mut self) {
+        let mut next_doc = TERMINATED;
+
+        for docset in &self.docsets {
+            next_doc = next_doc.min(docset.doc());
+        }
+        self.doc = next_doc;
+    }
+
+    fn advance_to_next(&mut self) -> DocId {
+        let mut next_doc = TERMINATED;
+
+        for docset in &mut self.docsets {
+            if docset.doc() <= self.doc {
+                docset.advance();
+            }
+            next_doc = next_doc.min(docset.doc());
+        }
+        self.doc = next_doc;
+        self.doc
+    }
+}
+
+impl<TDocSet: Postings> Postings for SimpleUnion<TDocSet> {
+    fn term_freq(&self) -> u32 {
+        let mut term_freq = 0;
+        for docset in &self.docsets {
+            let doc = docset.doc();
+            if doc == self.doc {
+                term_freq += docset.term_freq();
+            }
+        }
+        term_freq
+    }
+
+    fn append_positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>) {
+        for docset in &mut self.docsets {
+            let doc = docset.doc();
+            if doc == self.doc {
+                docset.append_positions_with_offset(offset, output);
+            }
+        }
+        output.sort_unstable();
+        output.dedup();
+    }
+}
+
+impl<TDocSet: DocSet> DocSet for SimpleUnion<TDocSet> {
+    fn advance(&mut self) -> DocId {
+        self.advance_to_next();
+        self.doc
+    }
+
+    fn seek(&mut self, target: DocId) -> DocId {
+        self.doc = TERMINATED;
+        for docset in &mut self.docsets {
+            if docset.doc() < target {
+                docset.seek(target);
+            }
+            if docset.doc() < self.doc {
+                self.doc = docset.doc();
+            }
+        }
+        self.doc
+    }
+
+    fn doc(&self) -> DocId {
+        self.doc
+    }
+
+    fn size_hint(&self) -> u32 {
+        self.docsets
+            .iter()
+            .map(|docset| docset.size_hint())
+            .max()
+            .unwrap_or(0u32)
+    }
+
+    fn count_including_deleted(&mut self) -> u32 {
+        if self.doc == TERMINATED {
+            return 0u32;
+        }
+        let mut count = 1u32;
+        while self.advance_to_next() != TERMINATED {
+            count += 1;
+        }
+        count
+    }
+}
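
Usage sketch (illustrative, not part of the patch): running a wildcard phrase
over a positions-indexed text field. `text_field` and `searcher` are assumed to
be set up as in the tests above; `Count` and `Searcher::search` are existing
tantivy APIs.

    use tantivy::collector::Count;
    use tantivy::query::{wildcard_query_to_regex_str, RegexPhraseQuery};

    // "b* wolf" as a regex phrase: wildcards are rewritten to regexes first,
    // so "b*" becomes the regex "b.*".
    let mut query = RegexPhraseQuery::new(
        text_field,
        vec![
            wildcard_query_to_regex_str("b*"),
            "wolf".to_string(),
        ],
    );
    // "b* wolf"~2: a slop budget of 2 also matches "big bad wolf".
    query.set_slop(2);
    let num_hits = searcher.search(&query, &Count)?;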