From 4544422c360175e42919fc7ffec8660d4a252c50 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Mon, 2 Dec 2024 12:00:26 -0600 Subject: [PATCH] implement BM25 --- .../index/sai/disk/v1/IndexSearcher.java | 23 +- .../sai/disk/v1/InvertedIndexSearcher.java | 177 ++++++------ .../v1/postings/IntersectingPostingList.java | 134 ++++++++++ .../index/sai/memory/TrieMemoryIndex.java | 7 + .../index/sai/memory/TrieMemtableIndex.java | 163 ++++++++++-- .../cassandra/index/sai/plan/Orderer.java | 20 +- .../apache/cassandra/index/sai/plan/Plan.java | 110 ++++++-- .../index/sai/plan/QueryController.java | 1 + .../plan/StorageAttachedIndexSearcher.java | 57 ++-- .../index/sai/plan/TopKProcessor.java | 27 +- .../cassandra/index/sai/utils/BM25Utils.java | 178 +++++++++++++ .../utils/PrimaryKeyWithByteComparable.java | 10 +- .../index/sai/utils/PrimaryKeyWithScore.java | 10 +- .../sai/utils/PrimaryKeyWithSortKey.java | 11 +- .../test/sai/BM25DistributedTest.java | 121 +++++++++ .../cassandra/index/sai/cql/BM25Test.java | 251 ++++++++++++++++-- .../postings/IntersectingPostingListTest.java | 213 +++++++++++++++ 17 files changed, 1313 insertions(+), 200 deletions(-) create mode 100644 src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java create mode 100644 src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java create mode 100644 test/distributed/org/apache/cassandra/distributed/test/sai/BM25DistributedTest.java create mode 100644 test/unit/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingListTest.java diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java index 52988a685cd5..5fb246183921 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java @@ -68,9 +68,7 @@ public abstract class IndexSearcher implements Closeable, SegmentOrdering protected final SegmentMetadata metadata; protected final IndexContext indexContext; - private static final SSTableReadsListener NOOP_LISTENER = new SSTableReadsListener() {}; - - private final ColumnFilter columnFilter; + protected final ColumnFilter columnFilter; protected IndexSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory, PerIndexFiles perIndexFiles, @@ -90,30 +88,37 @@ protected IndexSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory, public abstract long indexFileCacheSize(); /** - * Search on-disk index synchronously + * Search on-disk index synchronously. Used for WHERE clause predicates, including BOUNDED_ANN. * * @param expression to filter on disk index * @param keyRange key range specific in read command, used by ANN index * @param queryContext to track per sstable cache and per query metrics * @param defer create the iterator in a deferred state - * @param limit the num of rows to returned, used by ANN index + * @param limit the initial num of rows to returned, used by ANN index. More rows may be requested if filtering throws away more than expected! * @return {@link KeyRangeIterator} that matches given expression */ public abstract KeyRangeIterator search(Expression expression, AbstractBounds keyRange, QueryContext queryContext, boolean defer, int limit) throws IOException; /** - * Order the on-disk index synchronously and produce an iterator in score order + * Order the rows by the giving Orderer. 
Used for ORDER BY clause when + * (1) the WHERE predicate is either a partition restriction or a range restriction on the index, + * (2) there is no WHERE predicate, or + * (3) the planner determines it is better to post-filter the ordered results by the predicate. * * @param orderer the object containing the ordering logic * @param slice optional predicate to get a slice of the index * @param keyRange key range specific in read command, used by ANN index * @param queryContext to track per sstable cache and per query metrics - * @param limit the num of rows to returned, used by ANN index + * @param limit the initial num of rows to returned, used by ANN index. More rows may be requested if filtering throws away more than expected! * @return an iterator of {@link PrimaryKeyWithSortKey} in score order */ public abstract CloseableIterator orderBy(Orderer orderer, Expression slice, AbstractBounds keyRange, QueryContext queryContext, int limit) throws IOException; - + /** + * Order the rows by the giving Orderer. Used for ORDER BY clause when the WHERE predicates + * have been applied first, yielding a list of primary keys. Again, `limit` is a planner hint for ANN to determine + * the initial number of results returned, not a maximum. + */ @Override public CloseableIterator orderResultsBy(SSTableReader reader, QueryContext context, List keys, Orderer orderer, int limit) throws IOException { @@ -124,7 +129,7 @@ public CloseableIterator orderResultsBy(SSTableReader rea { var slices = Slices.with(indexContext.comparator(), Slice.make(key.clustering())); // TODO if we end up needing to read the row still, is it better to store offset and use reader.unfilteredAt? - try (var iter = reader.iterator(key.partitionKey(), slices, columnFilter, false, NOOP_LISTENER)) + try (var iter = reader.iterator(key.partitionKey(), slices, columnFilter, false, SSTableReadsListener.NOOP_LISTENER)) { if (iter.hasNext()) { diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java index 96dc54b63130..1bcccd7c5cb5 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java @@ -19,42 +19,49 @@ package org.apache.cassandra.index.sai.disk.v1; import java.io.IOException; +import java.io.UncheckedIOException; import java.lang.invoke.MethodHandles; import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.Collection; +import java.util.Comparator; import java.util.HashMap; +import java.util.Iterator; +import java.util.List; import java.util.Map; -import java.util.PriorityQueue; +import java.util.function.Function; import java.util.stream.Collectors; import com.google.common.base.MoreObjects; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.cassandra.cql3.Operator; import org.apache.cassandra.db.PartitionPosition; -import org.apache.cassandra.db.RegularAndStaticColumns; +import org.apache.cassandra.db.Slice; import org.apache.cassandra.db.Slices; -import org.apache.cassandra.db.filter.ColumnFilter; import org.apache.cassandra.db.rows.Cell; import org.apache.cassandra.db.rows.Row; import org.apache.cassandra.dht.AbstractBounds; import org.apache.cassandra.index.sai.IndexContext; import org.apache.cassandra.index.sai.QueryContext; import org.apache.cassandra.index.sai.SSTableContext; +import org.apache.cassandra.index.sai.analyzer.AbstractAnalyzer; import 
org.apache.cassandra.index.sai.disk.PostingList; import org.apache.cassandra.index.sai.disk.TermsIterator; import org.apache.cassandra.index.sai.disk.format.IndexComponentType; import org.apache.cassandra.index.sai.disk.format.Version; -import org.apache.cassandra.index.sai.disk.v1.postings.MergePostingList; +import org.apache.cassandra.index.sai.disk.v1.postings.IntersectingPostingList; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; import org.apache.cassandra.index.sai.metrics.MulticastQueryEventListeners; import org.apache.cassandra.index.sai.metrics.QueryEventListener; import org.apache.cassandra.index.sai.plan.Expression; import org.apache.cassandra.index.sai.plan.Orderer; +import org.apache.cassandra.index.sai.utils.BM25Utils; import org.apache.cassandra.index.sai.utils.PrimaryKey; +import org.apache.cassandra.index.sai.utils.PrimaryKeyWithScore; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey; import org.apache.cassandra.index.sai.utils.RowIdWithByteComparable; -import org.apache.cassandra.index.sai.utils.RowIdWithScore; import org.apache.cassandra.index.sai.utils.SAICodecUtils; import org.apache.cassandra.io.sstable.format.SSTableReader; import org.apache.cassandra.io.sstable.format.SSTableReadsListener; @@ -75,23 +82,12 @@ public class InvertedIndexSearcher extends IndexSearcher private final TermsReader reader; private final QueryEventListener.TrieIndexEventListener perColumnEventListener; private final Version version; - private static final int ROW_COUNT = 1_000_000; // TODO: Replace with actual count private static final float K1 = 1.2f; // BM25 term frequency saturation parameter private static final float B = 0.75f; // BM25 length normalization parameter private final boolean filterRangeResults; private final SSTableReader sstable; - private static class DocumentStats { - final Map termFrequencies = new HashMap<>(); - float length; - final int rowId; - - DocumentStats(int rowId) { - this.rowId = rowId; - } - } - protected InvertedIndexSearcher(SSTableContext sstableContext, PerIndexFiles perIndexFiles, SegmentMetadata segmentMetadata, @@ -161,19 +157,16 @@ else if (exp.getOp() == Expression.Op.RANGE) throw new IllegalArgumentException(indexContext.logMessage("Unsupported expression: " + exp)); } - private Cell readColumn(SSTableReader sstable, PrimaryKey primaryKey) { - var decoratedKey = primaryKey.partitionKey(); - - // Create a ColumnFilter to select only the specific column - var column = indexContext.getDefinition(); - var columnFilter = ColumnFilter.selection(RegularAndStaticColumns.of(column)); - - // TODO Slices.ALL is not correct when there are multiple rows per partition - try (var rowIterator = sstable.iterator(decoratedKey, Slices.ALL, columnFilter, false, SSTableReadsListener.NOOP_LISTENER)) { + private Cell readColumn(SSTableReader sstable, PrimaryKey primaryKey) + { + var dk = primaryKey.partitionKey(); + var slices = Slices.with(indexContext.comparator(), Slice.make(primaryKey.clustering())); + try (var rowIterator = sstable.iterator(dk, slices, columnFilter, false, SSTableReadsListener.NOOP_LISTENER)) + { var unfiltered = rowIterator.next(); assert unfiltered.isRow() : unfiltered; Row row = (Row) unfiltered; - return row.getCell(column); + return row.getCell(indexContext.getDefinition()); } } @@ -186,83 +179,83 @@ public CloseableIterator orderBy(Orderer orderer, Express return toMetaSortedIterator(iter, queryContext); } - var queryAnalyzer = indexContext.getQueryAnalyzerFactory().create(); - var docAnalyzer = 
indexContext.getAnalyzerFactory().create(); + // find documents that match each term + var queryTerms = orderer.extractQueryTerms(); + var postingLists = queryTerms.stream() + .collect(Collectors.toMap(Function.identity(), term -> + { + var encodedTerm = version.onDiskFormat().encodeForTrie(term, indexContext.getValidator()); + var listener = MulticastQueryEventListeners.of(queryContext, perColumnEventListener); + return reader.exactMatch(encodedTerm, listener, queryContext); + })); + // extract the match count for each + var documentFrequencies = postingLists.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> (long) e.getValue().size())); - // Split query into terms - var queryTerms = new ArrayList(); - queryAnalyzer.reset(orderer.term); - try - { - queryAnalyzer.forEachRemaining(queryTerms::add); - } - finally - { - queryAnalyzer.end(); - } - - // find documents that match all terms - var postingLists = queryTerms.stream().map(term -> - { - var encodedTerm = version.onDiskFormat().encodeForTrie(term, indexContext.getValidator()); - var listener = MulticastQueryEventListeners.of(queryContext, perColumnEventListener); - return reader.exactMatch(encodedTerm, listener, queryContext); - }).collect(Collectors.toList()); - - Map documents = new HashMap<>(); try (var pkm = primaryKeyMapFactory.newPerSSTablePrimaryKeyMap(); - var merged = MergePostingList.merge(postingLists)) + var merged = IntersectingPostingList.intersect(List.copyOf(postingLists.values()))) { - // First pass: collect term frequencies and document lengths - int rowId; - while ((rowId = merged.nextPosting()) != END_OF_STREAM) { - var pk = pkm.primaryKeyFromRowId(rowId); - var cell = readColumn(sstable, pk); - - var stats = new DocumentStats(rowId); - docAnalyzer.reset(cell.buffer()); - try + var it = new AbstractIterator() { + @Override + protected PrimaryKey computeNext() { - while (docAnalyzer.hasNext()) { - var docTerm = docAnalyzer.next(); - if (queryTerms.stream().anyMatch(q -> q.equals(docTerm))) { - stats.termFrequencies.merge(docTerm, 1, Integer::sum); - } - stats.length++; + try + { + int rowId = merged.nextPosting(); + if (rowId == PostingList.END_OF_STREAM) + return endOfData(); + return pkm.primaryKeyFromRowId(rowId); + } + catch (IOException e) + { + throw new UncheckedIOException(e); } } - finally - { - docAnalyzer.end(); - } - documents.put(rowId, stats); - } - - // Ccompute average document length - double avgDocLength = documents.values().stream().mapToDouble(d -> d.length).average().orElse(0.0); + }; + return bm25Internal(it, queryTerms, documentFrequencies); + } + } - // Second pass: calculate BM25 scores - var scoredDocs = new PriorityQueue((a, b) -> Double.compare(b.score, a.score)); + private CloseableIterator bm25Internal(Iterator keyIterator, + List queryTerms, + Map documentFrequencies) + { + var docStats = new BM25Utils.DocStats(documentFrequencies, sstable.getTotalRows()); + return BM25Utils.computeScores(keyIterator, + queryTerms, + docStats, + indexContext, + sstable.descriptor.id, + pk -> readColumn(sstable, pk)); + } - for (var doc : documents.values()) { - double score = 0.0f; - for (var queryTerm : queryTerms) { - int tf = doc.termFrequencies.get(queryTerm); - double normalizedTf = tf / (tf + K1 * (1 - B + B * doc.length / avgDocLength)); - // TODO this is broken, need counts for bare term matches for each term - double idf = Math.log(1 + (ROW_COUNT - documents.size() + 0.5) / (documents.size() + 0.5)); - score += normalizedTf * idf; - } + @Override + public CloseableIterator 
orderResultsBy(SSTableReader reader, QueryContext queryContext, List keys, Orderer orderer, int limit) throws IOException + { + if (!orderer.isBM25()) + return super.orderResultsBy(reader, queryContext, keys, orderer, limit); - scoredDocs.add(new RowIdWithScore(doc.rowId, (float) score)); - if (scoredDocs.size() > limit) { - scoredDocs.poll(); // Remove lowest scoring doc - } + var queryTerms = orderer.extractQueryTerms(); + var documentFrequencies = new HashMap(); + boolean hasHistograms = metadata.version.onDiskFormat().indexFeatureSet().hasTermsHistogram(); + for (ByteBuffer term : queryTerms) + { + long matches; + if (hasHistograms) + { + matches = metadata.estimateNumRowsMatching(new Expression(indexContext).add(Operator.ANALYZER_MATCHES, term)); } - - // Convert scored results back to posting list - return toMetaSortedIterator(CloseableIterator.wrap(scoredDocs.iterator()), queryContext); + else + { + // Without histograms, need to do an actual index scan + var encodedTerm = version.onDiskFormat().encodeForTrie(term, indexContext.getValidator()); + var listener = MulticastQueryEventListeners.of(queryContext, perColumnEventListener); + var postingList = this.reader.exactMatch(encodedTerm, listener, queryContext); + matches = postingList.size(); + FileUtils.closeQuietly(postingList); + } + documentFrequencies.put(term, matches); } + return bm25Internal(keys.iterator(), queryTerms, documentFrequencies); } @Override diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java new file mode 100644 index 000000000000..61fa01ff848b --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.disk.v1.postings; + +import java.io.IOException; +import java.util.List; +import javax.annotation.concurrent.NotThreadSafe; + +import org.apache.cassandra.index.sai.disk.PostingList; + +import static java.lang.Math.max; + +/** + * Performs intersection operations on multiple PostingLists, returning only postings + * that appear in all inputs. 
+ */ +@NotThreadSafe +public class IntersectingPostingList implements PostingList +{ + private final List postingLists; + private final int size; + + // currentPostings state is effectively local to findNextIntersection, but we keep it + // around as a field to avoid repeated allocations there + private final int[] currentRowIds; + + private IntersectingPostingList(List postingLists) + { + assert !postingLists.isEmpty(); + this.postingLists = postingLists; + this.size = postingLists.stream() + .mapToInt(PostingList::size) + .min() + .orElse(0); + this.currentRowIds = new int[postingLists.size()]; + } + + public static PostingList intersect(List postingLists) + { + if (postingLists.size() == 1) + return postingLists.get(0); + + if (postingLists.stream().anyMatch(PostingList::isEmpty)) + return PostingList.EMPTY; + + return new IntersectingPostingList(postingLists); + } + + @Override + public int nextPosting() throws IOException + { + return findNextIntersection(Integer.MIN_VALUE, false); + } + + @Override + public int advance(int targetRowID) throws IOException + { + assert targetRowID >= 0 : targetRowID; + return findNextIntersection(targetRowID, true); + } + + private int findNextIntersection(int targetRowID, boolean isAdvance) throws IOException + { + // Initialize currentRowIds from the underlying posting lists + for (int i = 0; i < postingLists.size(); i++) + { + currentRowIds[i] = isAdvance + ? postingLists.get(i).advance(targetRowID) + : postingLists.get(i).nextPosting(); + + if (currentRowIds[i] == END_OF_STREAM) + return END_OF_STREAM; + } + + while (true) + { + // Find the maximum row ID among all posting lists + int maxRowId = targetRowID; + for (int rowId : currentRowIds) + maxRowId = max(maxRowId, rowId); + + // Advance any posting list that's behind the maximum + boolean allMatch = true; + for (int i = 0; i < postingLists.size(); i++) + { + if (currentRowIds[i] < maxRowId) + { + currentRowIds[i] = postingLists.get(i).advance(maxRowId); + if (currentRowIds[i] == END_OF_STREAM) + return END_OF_STREAM; + allMatch = false; + } + } + + // If all posting lists have the same row ID, we've found an intersection + if (allMatch) + return maxRowId; + + // Otherwise, continue searching with the new maximum as target + targetRowID = maxRowId; + } + } + + @Override + public int size() + { + return size; + } + + @Override + public void close() throws IOException + { + for (PostingList list : postingLists) + list.close(); + } +} + + diff --git a/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java b/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java index 57517d9192d0..3c7f945a4194 100644 --- a/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java +++ b/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java @@ -662,6 +662,13 @@ public void close() throws IOException } } + /** + * Iterator that provides ordered access to all indexed terms and their associated primary keys + * in the TrieMemoryIndex. For each term in the index, yields PrimaryKeyWithSortKey objects that + * combine a primary key with its associated term. + *

+ * A more verbose name could be KeysMatchingTermsByTermIterator. + */ private class AllTermsIterator extends AbstractIterator { private final Iterator> iterator; diff --git a/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java b/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java index d01715023dbc..63710a28c290 100644 --- a/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java +++ b/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; +import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Objects; @@ -32,22 +33,30 @@ import com.google.common.base.Preconditions; import com.google.common.util.concurrent.Runnables; +import org.apache.cassandra.cql3.Operator; import org.apache.cassandra.db.Clustering; +import org.apache.cassandra.db.DataRange; import org.apache.cassandra.db.DecoratedKey; import org.apache.cassandra.db.PartitionPosition; +import org.apache.cassandra.db.RegularAndStaticColumns; +import org.apache.cassandra.db.filter.ColumnFilter; import org.apache.cassandra.db.marshal.AbstractType; import org.apache.cassandra.db.memtable.Memtable; import org.apache.cassandra.db.memtable.ShardBoundaries; import org.apache.cassandra.db.memtable.TrieMemtable; +import org.apache.cassandra.db.rows.Row; import org.apache.cassandra.dht.AbstractBounds; +import org.apache.cassandra.dht.Bounds; import org.apache.cassandra.index.sai.IndexContext; import org.apache.cassandra.index.sai.QueryContext; import org.apache.cassandra.index.sai.disk.format.Version; -import org.apache.cassandra.index.sai.iterators.KeyRangeLazyIterator; import org.apache.cassandra.index.sai.iterators.KeyRangeConcatIterator; +import org.apache.cassandra.index.sai.iterators.KeyRangeIntersectionIterator; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; +import org.apache.cassandra.index.sai.iterators.KeyRangeLazyIterator; import org.apache.cassandra.index.sai.plan.Expression; import org.apache.cassandra.index.sai.plan.Orderer; +import org.apache.cassandra.index.sai.utils.BM25Utils; import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithByteComparable; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey; @@ -237,15 +246,45 @@ public List> orderBy(QueryContext query int startShard = boundaries.getShardForToken(keyRange.left.getToken()); int endShard = keyRange.right.isMinimum() ? 
boundaries.shardCount() - 1 : boundaries.getShardForToken(keyRange.right.getToken()); - var iterators = new ArrayList>(endShard - startShard + 1); - - for (int shard = startShard; shard <= endShard; ++shard) + if (!orderer.isBM25()) { - assert rangeIndexes[shard] != null; - iterators.add(rangeIndexes[shard].orderBy(orderer, slice)); + var iterators = new ArrayList>(endShard - startShard + 1); + for (int shard = startShard; shard <= endShard; ++shard) + { + assert rangeIndexes[shard] != null; + iterators.add(rangeIndexes[shard].orderBy(orderer, slice)); + } + return iterators; } - return iterators; + // BM25 + var queryTerms = orderer.extractQueryTerms(); + + // Intersect iterators to find documents containing all terms + var termIterators = keyIteratorsPerTerm(queryContext, keyRange, queryTerms); + var intersectedIterator = KeyRangeIntersectionIterator.builder(termIterators).build(); + + // Compute BM25 scores + var docStats = computeDocumentFrequencies(queryContext, queryTerms); + return List.of(BM25Utils.computeScores(intersectedIterator, + queryTerms, + docStats, + indexContext, + memtable, + this::getCellForKey)); + } + + private List keyIteratorsPerTerm(QueryContext queryContext, AbstractBounds keyRange, List queryTerms) + { + List termIterators = new ArrayList<>(queryTerms.size()); + for (ByteBuffer term : queryTerms) + { + Expression expr = new Expression(indexContext); + expr.add(Operator.ANALYZER_MATCHES, term); + KeyRangeIterator iterator = search(queryContext, expr, keyRange, Integer.MAX_VALUE); + termIterators.add(iterator); + } + return termIterators; } @Override @@ -257,32 +296,98 @@ public long estimateMatchingRowsCount(Expression expression, AbstractBounds orderResultsBy(QueryContext context, List keys, Orderer orderer, int limit) + public CloseableIterator orderResultsBy(QueryContext queryContext, List keys, Orderer orderer, int limit) { if (keys.isEmpty()) return CloseableIterator.emptyIterator(); - return SortingIterator.createCloseable( - orderer.getComparator(), - keys, - key -> + + if (!orderer.isBM25()) + { + return SortingIterator.createCloseable( + orderer.getComparator(), + keys, + key -> + { + var partition = memtable.getPartition(key.partitionKey()); + if (partition == null) + return null; + var row = partition.getRow(key.clustering()); + if (row == null) + return null; + var cell = row.getCell(indexContext.getDefinition()); + if (cell == null) + return null; + + // We do two kinds of encoding... it'd be great to make this more straight forward, but this is what + // we have for now. I leave it to the reader to inspect the two methods to see the nuanced differences. 
+ var encoding = encode(TypeUtil.encode(cell.buffer(), validator)); + return new PrimaryKeyWithByteComparable(indexContext, memtable, key, encoding); + }, + Runnables.doNothing() + ); + } + + // BM25 + var queryTerms = orderer.extractQueryTerms(); + var docStats = computeDocumentFrequencies(queryContext, queryTerms); + return BM25Utils.computeScores(keys.iterator(), + queryTerms, + docStats, + indexContext, + memtable, + this::getCellForKey); + } + + /** + * Count document frequencies for each term using brute force + */ + private BM25Utils.DocStats computeDocumentFrequencies(QueryContext queryContext, List queryTerms) + { + var termIterators = keyIteratorsPerTerm(queryContext, Bounds.unbounded(indexContext.getPartitioner()), queryTerms); + var documentFrequencies = new HashMap(); + for (int i = 0; i < queryTerms.size(); i++) + { + // KeyRangeIterator.getMaxKeys is not accurate enough, we have to count them + long keys = 0; + for (var it = termIterators.get(i); it.hasNext(); it.next()) + keys++; + documentFrequencies.put(queryTerms.get(i), keys); + } + long docCount = 0; + + try (var it = memtable.makePartitionIterator(ColumnFilter.selection(RegularAndStaticColumns.of(indexContext.getDefinition())), + DataRange.allData(memtable.metadata().partitioner))) + { + while (it.hasNext()) { - var partition = memtable.getPartition(key.partitionKey()); - if (partition == null) - return null; - var row = partition.getRow(key.clustering()); - if (row == null) - return null; - var cell = row.getCell(indexContext.getDefinition()); - if (cell == null) - return null; - - // We do two kinds of encoding... it'd be great to make this more straight forward, but this is what - // we have for now. I leave it to the reader to inspect the two methods to see the nuanced differences. - var encoding = encode(TypeUtil.encode(cell.buffer(), validator)); - return new PrimaryKeyWithByteComparable(indexContext, memtable, key, encoding); - }, - Runnables.doNothing() - ); + var partitions = it.next(); + while (partitions.hasNext()) + { + var unfiltered = partitions.next(); + if (!unfiltered.isRow()) + continue; + var row = (Row) unfiltered; + var cell = row.getCell(indexContext.getDefinition()); + if (cell == null) + continue; + + docCount++; + } + } + } + return new BM25Utils.DocStats(documentFrequencies, docCount); + } + + @Nullable + private org.apache.cassandra.db.rows.Cell getCellForKey(PrimaryKey key) + { + var partition = memtable.getPartition(key.partitionKey()); + if (partition == null) + return null; + var row = partition.getRow(key.clustering()); + if (row == null) + return null; + return row.getCell(indexContext.getDefinition()); } private ByteComparable encode(ByteBuffer input) diff --git a/src/java/org/apache/cassandra/index/sai/plan/Orderer.java b/src/java/org/apache/cassandra/index/sai/plan/Orderer.java index 1686a0bceffd..6193e0a37110 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Orderer.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Orderer.java @@ -19,6 +19,7 @@ package org.apache.cassandra.index.sai.plan; import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; import java.util.EnumSet; @@ -77,7 +78,7 @@ public boolean isAscending() public Comparator getComparator() { - // ANN's PrimaryKeyWithSortKey is always descending, so we use the natural order for the priority queue + // ANN/BM25's PrimaryKeyWithSortKey is always descending, so we use the natural order for the priority queue return (isAscending() || isANN() || isBM25()) ? 
Comparator.naturalOrder() : Comparator.reverseOrder(); } @@ -130,4 +131,21 @@ public float[] getVectorTerm() vector = TypeUtil.decomposeVector(context.getValidator(), term); return vector; } + + public ArrayList extractQueryTerms() + { + var queryAnalyzer = context.getQueryAnalyzerFactory().create(); + // Split query into terms + var queryTerms = new ArrayList(); + queryAnalyzer.reset(term); + try + { + queryAnalyzer.forEachRemaining(queryTerms::add); + } + finally + { + queryAnalyzer.end(); + } + return queryTerms; + } } diff --git a/src/java/org/apache/cassandra/index/sai/plan/Plan.java b/src/java/org/apache/cassandra/index/sai/plan/Plan.java index 2bb379a16f14..46baa4f59a87 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Plan.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Plan.java @@ -1243,10 +1243,12 @@ protected double estimateSelectivity() @Override protected KeysIterationCost estimateCost() { - return ordering.isANN() - ? estimateAnnSortCost() - : estimateGlobalSortCost(); - + if (ordering.isANN()) + return estimateAnnSortCost(); + else if (ordering.isBM25()) + return estimateBm25SortCost(); + else + return estimateGlobalSortCost(); } private KeysIterationCost estimateAnnSortCost() @@ -1263,6 +1265,21 @@ private KeysIterationCost estimateAnnSortCost() return new KeysIterationCost(expectedKeys, initCost, searchCost); } + private KeysIterationCost estimateBm25SortCost() + { + double expectedKeys = access.expectedAccessCount(source.expectedKeys()); + + int termCount = ordering.extractQueryTerms().size(); + // all of the cost for BM25 is up front since the index doesn't give us the information we need + // to return results in order, in isolation. The big cost is reading the indexed cells out of + // the sstables. + // VSTODO if we had stats on cell size _per column_ we could usefully include ROW_BYTE_COST + double initCost = source.fullCost() + + source.expectedKeys() * (hrs(ROW_CELL_COST) + ROW_CELL_COST) + + termCount * BM25_SCORE_COST; + return new KeysIterationCost(expectedKeys, initCost, 0); + } + private KeysIterationCost estimateGlobalSortCost() { return new KeysIterationCost(source.expectedKeys(), @@ -1296,36 +1313,30 @@ protected KeysSort withAccess(Access access) } /** - * Returns all keys in ANN order. - * Contrary to {@link KeysSort}, there is no input node here and the output is generated lazily. + * Base class for index scans that return results in a computed order (ANN, BM25) + * rather than the natural index order. 
*/ - final static class AnnIndexScan extends Leaf + abstract static class ComputedOrderIndexScan extends Leaf { final Orderer ordering; - protected AnnIndexScan(Factory factory, int id, Access access, Orderer ordering) + protected ComputedOrderIndexScan(Factory factory, int id, Access access, Orderer ordering) { super(factory, id, access); this.ordering = ordering; } + @Nullable @Override - protected KeysIterationCost estimateCost() + protected Orderer ordering() { - double expectedKeys = access.expectedAccessCount(factory.tableMetrics.rows); - int expectedKeysInt = Math.max(1, (int) Math.ceil(expectedKeys)); - double searchCost = factory.costEstimator.estimateAnnSearchCost(ordering, - expectedKeysInt, - factory.tableMetrics.rows); - double initCost = 0; // negligible - return new KeysIterationCost(expectedKeys, initCost, searchCost); + return ordering; } - @Nullable @Override - protected Orderer ordering() + protected double estimateSelectivity() { - return ordering; + return 1.0; } @Override @@ -1334,6 +1345,30 @@ protected Iterator execute(Executor executor) int softLimit = max(1, round((float) access.expectedAccessCount(factory.tableMetrics.rows))); return executor.getTopKRows((Expression) null, softLimit); } + } + + /** + * Returns all keys in ANN order. + * Contrary to {@link KeysSort}, there is no input node here and the output is generated lazily. + */ + final static class AnnIndexScan extends ComputedOrderIndexScan + { + protected AnnIndexScan(Factory factory, int id, Access access, Orderer ordering) + { + super(factory, id, access, ordering); + } + + @Override + protected KeysIterationCost estimateCost() + { + double expectedKeys = access.expectedAccessCount(factory.tableMetrics.rows); + int expectedKeysInt = Math.max(1, (int) Math.ceil(expectedKeys)); + double searchCost = factory.costEstimator.estimateAnnSearchCost(ordering, + expectedKeysInt, + factory.tableMetrics.rows); + double initCost = 0; // negligible + return new KeysIterationCost(expectedKeys, initCost, searchCost); + } @Override protected KeysIteration withAccess(Access access) @@ -1342,11 +1377,39 @@ protected KeysIteration withAccess(Access access) ? this : new AnnIndexScan(factory, id, access, ordering); } + } + /** + * Returns all keys in BM25 order. + * Like AnnIndexScan, this generates results lazily without an input node. + */ + final static class Bm25IndexScan extends ComputedOrderIndexScan + { + protected Bm25IndexScan(Factory factory, int id, Access access, Orderer ordering) + { + super(factory, id, access, ordering); + } + + @Nonnull @Override - protected double estimateSelectivity() + protected KeysIterationCost estimateCost() { - return 1.0; + double expectedKeys = access.expectedAccessCount(factory.tableMetrics.rows); + int expectedKeysInt = Math.max(1, (int) Math.ceil(expectedKeys)); + + int termCount = ordering.extractQueryTerms().size(); + double initCost = expectedKeysInt * (hrs(ROW_CELL_COST) + ROW_CELL_COST) + + termCount * BM25_SCORE_COST; + + return new KeysIterationCost(expectedKeys, initCost, 0); + } + + @Override + protected KeysIteration withAccess(Access access) + { + return Objects.equals(access, this.access) + ? 
this + : new Bm25IndexScan(factory, id, access, ordering); } } @@ -1664,6 +1727,8 @@ private KeysIteration indexScan(Expression predicate, long matchingKeysCount, Or if (ordering != null) if (ordering.isANN()) return new AnnIndexScan(this, id, defaultAccess, ordering); + else if (ordering.isBM25()) + return new Bm25IndexScan(this, id, defaultAccess, ordering); else if (ordering.isLiteral()) return new LiteralIndexScan(this, id, predicate, matchingKeysCount, defaultAccess, ordering); else @@ -1911,6 +1976,9 @@ public static class CostCoefficients /** Additional cost added to row fetch cost per each serialized byte of the row */ public final static double ROW_BYTE_COST = 0.005; + + /** Cost to perform BM25 scoring, per query term */ + public final static double BM25_SCORE_COST = 0.5; } /** Convenience builder for building intersection and union nodes */ diff --git a/src/java/org/apache/cassandra/index/sai/plan/QueryController.java b/src/java/org/apache/cassandra/index/sai/plan/QueryController.java index 9a8ee4ecff00..fbe4430efc6c 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/QueryController.java +++ b/src/java/org/apache/cassandra/index/sai/plan/QueryController.java @@ -809,6 +809,7 @@ private long estimateMatchingRowCount(Expression predicate) switch (predicate.getOp()) { case EQ: + case MATCH: case CONTAINS_KEY: case CONTAINS_VALUE: case NOT_EQ: diff --git a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java index 45015341d5dd..bfde594fadfc 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java @@ -499,26 +499,27 @@ public UnfilteredRowIterator computeNext() */ private void fillPendingRows() { + // Group PKs by source sstable/memtable + var groupedKeys = new HashMap>(); // We always want to get at least 1. int rowsToRetrieve = Math.max(1, softLimit - returnedRowCount); - var keys = new HashMap>(); // We want to get the first unique `rowsToRetrieve` keys to materialize // Don't pass the priority queue here because it is more efficient to add keys in bulk - fillKeys(keys, rowsToRetrieve, null); + fillKeys(groupedKeys, rowsToRetrieve, null); // Sort the primary keys by PrK order, just in case that helps with cache and disk efficiency - var primaryKeyPriorityQueue = new PriorityQueue<>(keys.keySet()); + var primaryKeyPriorityQueue = new PriorityQueue<>(groupedKeys.keySet()); - while (!keys.isEmpty()) + while (!groupedKeys.isEmpty()) { - var primaryKey = primaryKeyPriorityQueue.poll(); - var primaryKeyWithSortKeys = keys.remove(primaryKey); - var partitionIterator = readAndValidatePartition(primaryKey, primaryKeyWithSortKeys); + var pk = primaryKeyPriorityQueue.poll(); + var sourceKeys = groupedKeys.remove(pk); + var partitionIterator = readAndValidatePartition(pk, sourceKeys); if (partitionIterator != null) pendingRows.add(partitionIterator); else // The current primaryKey did not produce a partition iterator. We know the caller will need // `rowsToRetrieve` rows, so we get the next unique key and add it to the queue. - fillKeys(keys, 1, primaryKeyPriorityQueue); + fillKeys(groupedKeys, 1, primaryKeyPriorityQueue); } } @@ -526,21 +527,21 @@ private void fillPendingRows() * Fills the keys map with the next `count` unique primary keys that are in the keys produced by calling * {@link #nextSelectedKeyInRange()}. 
We map PrimaryKey to List because the same * primary key can be in the result set multiple times, but with different source tables. - * @param keys the map to fill + * @param groupedKeys the map to fill * @param count the number of unique PrimaryKeys to consume from the iterator * @param primaryKeyPriorityQueue the priority queue to add new keys to. If the queue is null, we do not add * keys to the queue. */ - private void fillKeys(Map> keys, int count, PriorityQueue primaryKeyPriorityQueue) + private void fillKeys(Map> groupedKeys, int count, PriorityQueue primaryKeyPriorityQueue) { - int initialSize = keys.size(); - while (keys.size() - initialSize < count) + int initialSize = groupedKeys.size(); + while (groupedKeys.size() - initialSize < count) { var primaryKeyWithSortKey = nextSelectedKeyInRange(); if (primaryKeyWithSortKey == null) return; var nextPrimaryKey = primaryKeyWithSortKey.primaryKey(); - var accumulator = keys.computeIfAbsent(nextPrimaryKey, k -> new ArrayList<>()); + var accumulator = groupedKeys.computeIfAbsent(nextPrimaryKey, k -> new ArrayList<>()); if (primaryKeyPriorityQueue != null && accumulator.isEmpty()) primaryKeyPriorityQueue.add(nextPrimaryKey); accumulator.add(primaryKeyWithSortKey); @@ -580,15 +581,29 @@ private boolean isInRange(DecoratedKey key) return null; } - public UnfilteredRowIterator readAndValidatePartition(PrimaryKey key, List primaryKeys) + /** + * Reads and validates a partition for a given primary key against its sources. + *

+ * @param pk The primary key of the partition to read and validate + * @param sourceKeys A list of PrimaryKeyWithSortKey objects associated with the primary key. + * Multiple sort keys can exist for the same primary key when data comes from different + * sstables or memtables. + * + * @return An UnfilteredRowIterator containing the validated partition data, or null if: + * - The key has already been processed + * - The partition does not pass index filters + * - The partition contains no valid rows + * - The row data does not match the index metadata for any of the provided primary keys + */ + public UnfilteredRowIterator readAndValidatePartition(PrimaryKey pk, List sourceKeys) { // If we've already processed the key, we can skip it. Because the score ordered iterator does not // deduplicate rows, we could see dupes if a row is in the ordering index multiple times. This happens // in the case of dupes and of overwrites. - if (processedKeys.contains(key)) + if (processedKeys.contains(pk)) return null; - try (UnfilteredRowIterator partition = controller.getPartition(key, view, executionController)) + try (UnfilteredRowIterator partition = controller.getPartition(pk, view, executionController)) { queryContext.addPartitionsRead(1); queryContext.checkpoint(); @@ -597,7 +612,7 @@ public UnfilteredRowIterator readAndValidatePartition(PrimaryKey key, List queryVector; + private final ColumnMetadata scoreColumn; private final int limit; @@ -113,6 +117,7 @@ public TopKProcessor(ReadCommand command) else this.queryVector = null; this.limit = command.limits().count(); + this.scoreColumn = ColumnMetadata.syntheticColumn(indexContext.getKeyspace(), indexContext.getTable(), ColumnMetadata.SYNTHETIC_SCORE_ID, FloatType.instance); } /** @@ -158,8 +163,11 @@ private , P extends BaseParti { // priority queue ordered by score in descending order Comparator> comparator; - if (queryVector != null) + // TODO does this work for complex expressions? + if (expression.operator() == Operator.ANN || expression.operator() == Operator.BM25) + { comparator = Comparator.comparing((Triple t) -> (Float) t.getRight()).reversed(); + } else { comparator = Comparator.comparing(t -> (ByteBuffer) t.getRight(), indexContext.getValidator()); @@ -187,7 +195,7 @@ private , P extends BaseParti executor.maybeExecuteImmediately(() -> { try (var partitionRowIterator = pIter.commandToIterator(command.left(), command.right())) { - future.complete(partitionRowIterator == null ? null : processPartition(partitionRowIterator)); + future.complete(partitionRowIterator == null ? null : processScoredPartition(partitionRowIterator)); } catch (Throwable t) { @@ -235,9 +243,9 @@ private , P extends BaseParti // have to close to move to the next partition, otherwise hasNext() fails try (var partitionRowIterator = partitions.next()) { - if (queryVector != null) + if (expression.operator() == Operator.ANN || expression.operator() == Operator.BM25) { - PartitionResults pr = processPartition(partitionRowIterator); + PartitionResults pr = processScoredPartition(partitionRowIterator); topK.addAll(pr.rows); for (var uf: pr.tombstones) addUnfiltered(unfilteredByPartition, pr.partitionInfo, uf); @@ -250,7 +258,6 @@ private , P extends BaseParti topK.add(Triple.of(PartitionInfo.create(partitionRowIterator), row, row.getCell(expression.column()).buffer())); } } - } } } @@ -286,7 +293,7 @@ void addRow(Triple triple) { /** * Processes a single partition, calculating scores for rows and extracting tombstones. 
*/ - private PartitionResults processPartition(BaseRowIterator partitionRowIterator) { + private PartitionResults processScoredPartition(BaseRowIterator partitionRowIterator) { // Compute key and static row score once per partition DecoratedKey key = partitionRowIterator.partitionKey(); Row staticRow = partitionRowIterator.staticRow(); @@ -352,6 +359,14 @@ private float getScoreForRow(DecoratedKey key, Row row) if ((column.isClusteringColumn() || column.isRegular()) && row.isStatic()) return 0; + var scoreData = row.getColumnData(scoreColumn); + if (scoreData != null) + { + var cell = (Cell) scoreData; + return FloatType.instance.compose(cell.buffer()); + } + + // TODO remove this once we enable the scored path for vector queries ByteBuffer value = indexContext.getValueOf(key, row, FBUtilities.nowInSeconds()); if (value != null) { diff --git a/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java b/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java new file mode 100644 index 000000000000..044997e6508b --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.utils; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import org.apache.cassandra.db.memtable.Memtable; +import org.apache.cassandra.db.rows.Cell; +import org.apache.cassandra.index.sai.IndexContext; +import org.apache.cassandra.index.sai.analyzer.AbstractAnalyzer; +import org.apache.cassandra.io.sstable.SSTableId; +import org.apache.cassandra.utils.CloseableIterator; + +public class BM25Utils +{ + private static final float K1 = 1.2f; // BM25 term frequency saturation parameter + private static final float B = 0.75f; // BM25 length normalization parameter + + /** + * Term frequencies across all documents. Each document is only counted once. + */ + public static class DocStats + { + // Map of term -> count of docs containing that term + private final Map frequencies; + // total number of docs in the index + private final long docCount; + + public DocStats(Map frequencies, long docCount) + { + this.frequencies = frequencies; + this.docCount = docCount; + } + } + + /** + * Term frequencies within a single document. All instances of a term are counted. 
+ */ + public static class DocTF + { + private final PrimaryKey pk; + private final Map frequencies; + private final float termCount; + + private DocTF(PrimaryKey pk, float termCount, Map frequencies) + { + this.pk = pk; + this.frequencies = frequencies; + this.termCount = termCount; + } + + public int getTermFrequency(ByteBuffer term) + { + return frequencies.getOrDefault(term, 0); + } + + public static DocTF createFromDocument(PrimaryKey pk, + Cell cell, + AbstractAnalyzer docAnalyzer, + Collection queryTerms) + { + float count = 0; + Map frequencies = new HashMap<>(); + + docAnalyzer.reset(cell.buffer()); + try + { + while (docAnalyzer.hasNext()) + { + ByteBuffer term = docAnalyzer.next(); + count++; + if (queryTerms.contains(term)) + frequencies.merge(term, 1, Integer::sum); + } + } + finally + { + docAnalyzer.end(); + } + + return new DocTF(pk, count, frequencies); + } + } + + @FunctionalInterface + public interface CellReader + { + Cell readCell(PrimaryKey pk); + } + + public static CloseableIterator computeScores(Iterator keyIterator, + List queryTerms, + DocStats docStats, + IndexContext indexContext, + Object source, + CellReader cellReader) + { + var docAnalyzer = indexContext.getAnalyzerFactory().create(); + + // data structures for document stats and frequencies + ArrayList documents = new ArrayList<>(); + double totalTermCount = 0; + + // Compute TF within each document + while (keyIterator.hasNext()) + { + var pk = keyIterator.next(); + var cell = cellReader.readCell(pk); + if (cell == null) + continue; + var tf = DocTF.createFromDocument(pk, cell, docAnalyzer, queryTerms); + + // sstable index will only send documents that contain all query terms to this method, + // but memtable is not indexed and will send all documents, so we have to skip documents + // that don't contain all query terms here to preserve consistency with sstable behavior + if (tf.frequencies.size() != queryTerms.size()) + continue; + + documents.add(tf); + + totalTermCount += tf.termCount; + } + + // Calculate average document length + double avgDocLength = documents.size() > 0 ? 
totalTermCount / documents.size() : 0.0; + + // Calculate BM25 scores + var scoredDocs = new ArrayList(documents.size()); + for (var doc : documents) + { + double score = 0.0; + for (var queryTerm : queryTerms) + { + int tf = doc.getTermFrequency(queryTerm); + Long df = docStats.frequencies.get(queryTerm); + double normalizedTf = tf / (tf + K1 * (1 - B + B * doc.termCount / avgDocLength)); + double idf = Math.log(1 + (docStats.docCount - df + 0.5) / (df + 0.5)); + double deltaScore = normalizedTf * idf; + assert deltaScore >= 0 : String.format("BM25 score for tf=%d, df=%d, totalDocs=%d is %f", tf, df, docStats.docCount, deltaScore); + score += deltaScore; + } + if (source instanceof Memtable) + scoredDocs.add(new PrimaryKeyWithScore(indexContext, (Memtable) source, doc.pk, (float) score)); + else if (source instanceof SSTableId) + scoredDocs.add(new PrimaryKeyWithScore(indexContext, (SSTableId) source, doc.pk, (float) score)); + else + throw new IllegalArgumentException("Invalid source " + source.getClass()); + } + + // sort by score + scoredDocs.sort(Comparator.comparingDouble((PrimaryKeyWithScore pkws) -> pkws.indexScore).reversed()); + + return (CloseableIterator) (CloseableIterator) CloseableIterator.wrap(scoredDocs.iterator()); + } +} diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java index 837c032b7952..949eb21d282c 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java +++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java @@ -21,7 +21,9 @@ import java.nio.ByteBuffer; import java.util.Arrays; +import org.apache.cassandra.db.memtable.Memtable; import org.apache.cassandra.index.sai.IndexContext; +import org.apache.cassandra.io.sstable.SSTableId; import org.apache.cassandra.utils.bytecomparable.ByteComparable; import org.apache.cassandra.utils.bytecomparable.ByteSource; import org.apache.cassandra.utils.bytecomparable.ByteSourceInverse; @@ -34,7 +36,13 @@ public class PrimaryKeyWithByteComparable extends PrimaryKeyWithSortKey { private final ByteComparable byteComparable; - public PrimaryKeyWithByteComparable(IndexContext context, Object sourceTable, PrimaryKey primaryKey, ByteComparable byteComparable) + public PrimaryKeyWithByteComparable(IndexContext context, Memtable sourceTable, PrimaryKey primaryKey, ByteComparable byteComparable) + { + super(context, sourceTable, primaryKey); + this.byteComparable = byteComparable; + } + + public PrimaryKeyWithByteComparable(IndexContext context, SSTableId sourceTable, PrimaryKey primaryKey, ByteComparable byteComparable) { super(context, sourceTable, primaryKey); this.byteComparable = byteComparable; diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java index 61e63d43c7fa..a10d6e82549a 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java +++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java @@ -20,7 +20,9 @@ import java.nio.ByteBuffer; +import org.apache.cassandra.db.memtable.Memtable; import org.apache.cassandra.index.sai.IndexContext; +import org.apache.cassandra.io.sstable.SSTableId; /** * A {@link PrimaryKey} that includes a score from a source index. 
@@ -30,7 +32,13 @@ public class PrimaryKeyWithScore extends PrimaryKeyWithSortKey { public final float indexScore; - public PrimaryKeyWithScore(IndexContext context, Object source, PrimaryKey primaryKey, float indexScore) + public PrimaryKeyWithScore(IndexContext context, Memtable source, PrimaryKey primaryKey, float indexScore) + { + super(context, source, primaryKey); + this.indexScore = indexScore; + } + + public PrimaryKeyWithScore(IndexContext context, SSTableId source, PrimaryKey primaryKey, float indexScore) { super(context, source, primaryKey); this.indexScore = indexScore; diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java index 2e79b0402124..8a171fca4dbe 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java +++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java @@ -22,9 +22,11 @@ import org.apache.cassandra.db.Clustering; import org.apache.cassandra.db.DecoratedKey; +import org.apache.cassandra.db.memtable.Memtable; import org.apache.cassandra.db.rows.Row; import org.apache.cassandra.dht.Token; import org.apache.cassandra.index.sai.IndexContext; +import org.apache.cassandra.io.sstable.SSTableId; import org.apache.cassandra.utils.bytecomparable.ByteComparable; import org.apache.cassandra.utils.bytecomparable.ByteSource; @@ -41,7 +43,14 @@ public abstract class PrimaryKeyWithSortKey implements PrimaryKey // Either a Memtable reference or an SSTableId reference private final Object sourceTable; - protected PrimaryKeyWithSortKey(IndexContext context, Object sourceTable, PrimaryKey primaryKey) + protected PrimaryKeyWithSortKey(IndexContext context, Memtable sourceTable, PrimaryKey primaryKey) + { + this.context = context; + this.sourceTable = sourceTable; + this.primaryKey = primaryKey; + } + + protected PrimaryKeyWithSortKey(IndexContext context, SSTableId sourceTable, PrimaryKey primaryKey) { this.context = context; this.sourceTable = sourceTable; diff --git a/test/distributed/org/apache/cassandra/distributed/test/sai/BM25DistributedTest.java b/test/distributed/org/apache/cassandra/distributed/test/sai/BM25DistributedTest.java new file mode 100644 index 000000000000..8feb61f36b83 --- /dev/null +++ b/test/distributed/org/apache/cassandra/distributed/test/sai/BM25DistributedTest.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.distributed.test.sai; + +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.distributed.Cluster; +import org.apache.cassandra.distributed.api.ConsistencyLevel; +import org.apache.cassandra.distributed.test.TestBaseImpl; + +import static org.apache.cassandra.distributed.api.Feature.GOSSIP; +import static org.apache.cassandra.distributed.api.Feature.NETWORK; +import static org.assertj.core.api.Assertions.assertThat; + +public class BM25DistributedTest extends TestBaseImpl +{ + private static final String CREATE_KEYSPACE = "CREATE KEYSPACE %%s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': %d}"; + private static final String CREATE_TABLE = "CREATE TABLE %s (k int PRIMARY KEY, v text)"; + private static final String CREATE_INDEX = "CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex' WITH OPTIONS = {'index_analyzer': '{\"tokenizer\" : {\"name\" : \"standard\"}, \"filters\" : [{\"name\" : \"porterstem\"}]}'}"; + + // To get consistent results from BM25 we need to know which docs are evaluated, the easiest way + // to do that is to put all the docs on every replica + private static final int NUM_NODES = 3; + private static final int RF = 3; + + private static Cluster cluster; + private static String table; + + private static final AtomicInteger seq = new AtomicInteger(); + + @BeforeClass + public static void setupCluster() throws Exception + { + cluster = Cluster.build(NUM_NODES) + .withTokenCount(1) + .withDataDirCount(1) + .withConfig(config -> config.with(GOSSIP).with(NETWORK)) + .start(); + + cluster.schemaChange(withKeyspace(String.format(CREATE_KEYSPACE, RF))); + } + + @AfterClass + public static void closeCluster() + { + if (cluster != null) + cluster.close(); + } + + @Before + public void before() + { + table = "table_" + seq.getAndIncrement(); + cluster.schemaChange(formatQuery(CREATE_TABLE)); + cluster.schemaChange(formatQuery(CREATE_INDEX)); + SAIUtil.waitForIndexQueryable(cluster, KEYSPACE); + } + + @Test + public void testTermFrequencyOrdering() + { + // Insert documents with varying frequencies of the term "apple" + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'apple apple')"); + execute("INSERT INTO %s (k, v) VALUES (3, 'apple apple apple')"); + + // Query memtable index + assertBM25Ordering(); + + // Flush and query on-disk index + cluster.forEach(n -> n.flush(KEYSPACE)); + assertBM25Ordering(); + } + + private void assertBM25Ordering() + { + Object[][] result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertThat(result).hasNumberOfRows(3); + + // Results should be ordered by term frequency (highest to lowest) + assertThat((Integer) result[0][0]).isEqualTo(3); // 3 occurrences + assertThat((Integer) result[1][0]).isEqualTo(2); // 2 occurrences + assertThat((Integer) result[2][0]).isEqualTo(1); // 1 occurrence + } + + private static Object[][] execute(String query) + { + return execute(query, ConsistencyLevel.QUORUM); + } + + private static Object[][] execute(String query, ConsistencyLevel consistencyLevel) + { + return cluster.coordinator(1).execute(formatQuery(query), consistencyLevel); + } + + private static String formatQuery(String query) + { + return String.format(query, KEYSPACE + '.' 
+ table); + } +} diff --git a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java index 1023dcbefa79..c8dbb2ae2528 100644 --- a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java +++ b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java @@ -18,42 +18,257 @@ package org.apache.cassandra.index.sai.cql; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + import org.junit.Test; -import org.apache.cassandra.cql3.UntypedResultSet; import org.apache.cassandra.index.sai.SAITester; +import org.apache.cassandra.index.sai.disk.v1.InvertedIndexSearcher; +import org.apache.cassandra.index.sai.plan.QueryController; +import org.apache.cassandra.inject.ActionBuilder; +import org.apache.cassandra.inject.Expression; +import org.apache.cassandra.inject.Injections; +import org.apache.cassandra.inject.InvokePointBuilder; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; +import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; +import static org.junit.Assert.assertEquals; public class BM25Test extends SAITester { @Test - public void testTermFrequencyOrdering() + public void testTermFrequencyOrdering() throws Throwable { - createAnalyzedTable(); + createSimpleTable(); // Insert documents with varying frequencies of the term "apple" execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); execute("INSERT INTO %s (k, v) VALUES (2, 'apple apple')"); execute("INSERT INTO %s (k, v) VALUES (3, 'apple apple apple')"); - flush(); - // Results should be ordered by term frequency (highest to lowest) - var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple' LIMIT 3"); - assertRows(result, - row(3), // 3 occurrences - row(2), // 2 occurrences - row(1)); // 1 occurrence + beforeAndAfterFlush(() -> + { + // Results should be ordered by term frequency (highest to lowest) + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, + row(3), // 3 occurrences + row(2), // 2 occurrences + row(1)); // 1 occurrence + }); + } + + @Test + public void testDocumentLength() throws Throwable + { + createSimpleTable(); + // Create documents with same term frequency but different lengths + execute("INSERT INTO %s (k, v) VALUES (1, 'test test')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'test test other words here to make it longer')"); + execute("INSERT INTO %s (k, v) VALUES (3, 'test test extremely long document with many additional words to significantly increase the document length while maintaining the same term frequency for our target term')"); + + beforeAndAfterFlush(() -> + { + // Documents with same term frequency should be ordered by length (shorter first) + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'test' LIMIT 3"); + assertRows(result, + row(1), + row(2), + row(3)); + }); + } + + @Test + public void testMultiTermQueryScoring() throws Throwable + { + createSimpleTable(); + // Two terms, but "apple" appears in fewer documents + execute("INSERT INTO %s (k, v) VALUES (1, 'apple banana')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'apple apple banana')"); + execute("INSERT INTO %s (k, v) VALUES (3, 'apple banana banana')"); + execute("INSERT INTO %s (k, v) VALUES (4, 'apple apple banana banana')"); + execute("INSERT INTO %s (k, v) VALUES (5, 'banana banana')"); + + beforeAndAfterFlush(() -> + { + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple banana' LIMIT 4"); + 
assertRows(result,
+                       row(2),  // Two 'apple' (the rarer, higher-weighted term) in the shortest matching document
+                       row(4),  // Two of each term, but a longer document
+                       row(1),  // One of each term
+                       row(3)); // Only one 'apple', the higher-weighted term
+        });
+    }
+
+    @Test
+    public void testIrrelevantRowsScoring() throws Throwable
+    {
+        createSimpleTable();
+        // Insert pizza reviews with varying relevance to "crispy crust"
+        execute("INSERT INTO %s (k, v) VALUES (1, 'The pizza had a crispy crust and was delicious')"); // Basic mention
+        execute("INSERT INTO %s (k, v) VALUES (2, 'Very crispy crispy crust, perfectly cooked')"); // Emphasized crispy
+        execute("INSERT INTO %s (k, v) VALUES (3, 'The crust crust crust was okay, nothing special')"); // Only crust mentions
+        execute("INSERT INTO %s (k, v) VALUES (4, 'Super crispy crispy crust crust, best pizza ever!')"); // Most mentions of both
+        execute("INSERT INTO %s (k, v) VALUES (5, 'The toppings were good but the pizza was soggy')"); // Irrelevant review
+
+        beforeAndAfterFlush(() ->
+        {
+            var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'crispy crust' LIMIT 5");
+            assertRows(result,
+                       row(4),  // Highest frequency of both terms
+                       row(2),  // High frequency of 'crispy', one 'crust'
+                       row(1)); // One mention of each term
+            // Rows 3 and 5 are not returned because they do not contain every query term
+        });
+    }
+
+    private void createSimpleTable()
+    {
+        createTable("CREATE TABLE %s (k int PRIMARY KEY, v text)");
+        analyzeIndex();
+    }
+
+    private String analyzeIndex()
+    {
+        return createIndex("CREATE CUSTOM INDEX ON %s(v) " +
+                           "USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' " +
+                           "WITH OPTIONS = {" +
+                           "'index_analyzer': '{" +
+                           "\"tokenizer\" : {\"name\" : \"standard\"}, " +
+                           "\"filters\" : [{\"name\" : \"porterstem\"}]" +
+                           "}'}"
+        );
+    }
+
+    @Test
+    public void testWithPredicate() throws Throwable
+    {
+        createTable("CREATE TABLE %s (k int PRIMARY KEY, p int, v text)");
+        analyzeIndex();
+        execute("CREATE CUSTOM INDEX ON %s(p) USING 'StorageAttachedIndex'");
+
+        // Insert documents with varying frequencies of the term "apple"
+        execute("INSERT INTO %s (k, p, v) VALUES (1, 5, 'apple')");
+        execute("INSERT INTO %s (k, p, v) VALUES (2, 5, 'apple apple')");
+        execute("INSERT INTO %s (k, p, v) VALUES (3, 5, 'apple apple apple')");
+        execute("INSERT INTO %s (k, p, v) VALUES (4, 6, 'apple apple apple')");
+        execute("INSERT INTO %s (k, p, v) VALUES (5, 7, 'apple apple apple')");
+
+        beforeAndAfterFlush(() ->
+        {
+            // Results should be ordered by term frequency (highest to lowest)
+            var result = execute("SELECT k FROM %s WHERE p = 5 ORDER BY v BM25 OF 'apple' LIMIT 3");
+            assertRows(result,
+                       row(3), // 3 occurrences
+                       row(2), // 2 occurrences
+                       row(1)); // 1 occurrence
+        });
+    }
+
+    @Test
+    public void testWidePartition() throws Throwable
+    {
+        createTable("CREATE TABLE %s (k1 int, k2 int, v text, PRIMARY KEY (k1, k2))");
+        analyzeIndex();
+
+        // Insert documents with varying frequencies of the term "apple"
+        execute("INSERT INTO %s (k1, k2, v) VALUES (0, 1, 'apple')");
+        execute("INSERT INTO %s (k1, k2, v) VALUES (0, 2, 'apple apple')");
+        execute("INSERT INTO %s (k1, k2, v) VALUES (0, 3, 'apple apple apple')");
+
+        beforeAndAfterFlush(() ->
+        {
+            // Results should be ordered by term frequency (highest to lowest)
+            var result = 
execute("SELECT k2 FROM %s ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, + row(3), // 3 occurrences + row(2), // 2 occurrences + row(1)); // 1 occurrence + }); + } + + @Test + public void testWidePartitionWithPkPredicate() throws Throwable + { + createTable("CREATE TABLE %s (k1 int, k2 int, v text, PRIMARY KEY (k1, k2))"); + analyzeIndex(); + + // Insert documents with varying frequencies of the term "apple" + execute("INSERT INTO %s (k1, k2, v) VALUES (0, 1, 'apple')"); + execute("INSERT INTO %s (k1, k2, v) VALUES (0, 2, 'apple apple')"); + execute("INSERT INTO %s (k1, k2, v) VALUES (0, 3, 'apple apple apple')"); + execute("INSERT INTO %s (k1, k2, v) VALUES (1, 3, 'apple apple apple')"); + execute("INSERT INTO %s (k1, k2, v) VALUES (2, 3, 'apple apple apple')"); + + beforeAndAfterFlush(() -> + { + // Results should be ordered by term frequency (highest to lowest) + var result = execute("SELECT k2 FROM %s WHERE k1 = 0 ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, + row(3), // 3 occurrences + row(2), // 2 occurrences + row(1)); // 1 occurrence + }); + } + + @Test + public void testWidePartitionWithPredicate() throws Throwable + { + createTable("CREATE TABLE %s (k1 int, k2 int, p int, v text, PRIMARY KEY (k1, k2))"); + analyzeIndex(); + execute("CREATE CUSTOM INDEX ON %s(p) USING 'StorageAttachedIndex'"); + + // Insert documents with varying frequencies of the term "apple" + execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 1, 5, 'apple')"); + execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 2, 5, 'apple apple')"); + execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 3, 5, 'apple apple apple')"); + execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 4, 6, 'apple apple apple')"); + execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 5, 7, 'apple apple apple')"); + + beforeAndAfterFlush(() -> + { + // Results should be ordered by term frequency (highest to lowest) + var result = execute("SELECT k2 FROM %s WHERE p = 5 ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, + row(3), // 3 occurrences + row(2), // 2 occurrences + row(1)); // 1 occurrence + }); + } + + @Test + public void testWithPredicateSearchThenOrder() throws Throwable + { + QueryController.QUERY_OPT_LEVEL = 0; + testWithPredicate(); + } + + @Test + public void testWidePartitionWithPredicateOrderThenSearch() throws Throwable + { + QueryController.QUERY_OPT_LEVEL = 1; + testWidePartitionWithPredicate(); + } + + @Test + public void testQueryWithNulls() throws Throwable + { + createSimpleTable(); + + execute("INSERT INTO %s (k, v) VALUES (0, null)"); + execute("INSERT INTO %s (k, v) VALUES (1, 'test document')"); + beforeAndAfterFlush(() -> + { + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'test' LIMIT 1"); + assertRows(result, row(1)); + }); + } + + @Test + public void testQueryEmptyTable() + { + createSimpleTable(); + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'test' LIMIT 1"); + assertThat(result).hasSize(0); + } } diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingListTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingListTest.java new file mode 100644 index 000000000000..40f6f16ef970 --- /dev/null +++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingListTest.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.disk.v1.postings; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.IntStream; + +import org.junit.Test; + +import org.apache.cassandra.index.sai.disk.PostingList; +import org.apache.cassandra.index.sai.postings.IntArrayPostingList; +import org.apache.cassandra.index.sai.utils.SaiRandomizedTest; + +public class IntersectingPostingListTest extends SaiRandomizedTest +{ + @Test + public void shouldIntersectOverlappingPostingLists() throws IOException + { + var lists = listOfLists(new IntArrayPostingList(new int[]{ 1, 4, 6, 8 }), + new IntArrayPostingList(new int[]{ 2, 4, 6, 9 }), + new IntArrayPostingList(new int[]{ 4, 6, 7 })); + + final PostingList intersected = IntersectingPostingList.intersect(lists); + + assertPostingListEquals(new IntArrayPostingList(new int[]{ 4, 6 }), intersected); + } + + @Test + public void shouldIntersectDisjointPostingLists() throws IOException + { + var lists = listOfLists(new IntArrayPostingList(new int[]{ 1, 3, 5 }), + new IntArrayPostingList(new int[]{ 2, 4, 6 })); + + final PostingList intersected = IntersectingPostingList.intersect(lists); + + assertPostingListEquals(new IntArrayPostingList(new int[]{}), intersected); + } + + @Test + public void shouldIntersectSinglePostingList() throws IOException + { + var lists = listOfLists(new IntArrayPostingList(new int[]{ 1, 4, 6 })); + + final PostingList intersected = IntersectingPostingList.intersect(lists); + + assertPostingListEquals(new IntArrayPostingList(new int[]{ 1, 4, 6 }), intersected); + } + + @Test + public void shouldIntersectIdenticalPostingLists() throws IOException + { + var lists = listOfLists(new IntArrayPostingList(new int[]{ 1, 2, 3 }), + new IntArrayPostingList(new int[]{ 1, 2, 3 })); + + final PostingList intersected = IntersectingPostingList.intersect(lists); + + assertPostingListEquals(new IntArrayPostingList(new int[]{ 1, 2, 3 }), intersected); + } + + @Test + public void shouldAdvanceAllIntersectedLists() throws IOException + { + var lists = listOfLists(new IntArrayPostingList(new int[]{ 1, 3, 5, 7, 9 }), + new IntArrayPostingList(new int[]{ 2, 3, 5, 7, 8 }), + new IntArrayPostingList(new int[]{ 3, 5, 7, 10 })); + + final PostingList intersected = IntersectingPostingList.intersect(lists); + final PostingList expected = new IntArrayPostingList(new int[]{ 3, 5, 7 }); + + assertEquals(expected.advance(5), + intersected.advance(5)); + + assertPostingListEquals(expected, intersected); + } + + @Test + public void shouldHandleEmptyList() throws IOException + { + var lists = listOfLists(new IntArrayPostingList(new int[]{}), + new IntArrayPostingList(new int[]{ 1, 2, 3 })); + + final PostingList intersected = 
IntersectingPostingList.intersect(lists);
+
+        assertEquals(PostingList.END_OF_STREAM, intersected.advance(1));
+    }
+
+    @Test
+    public void shouldInterleaveNextAndAdvance() throws IOException
+    {
+        var lists = listOfLists(new IntArrayPostingList(new int[]{ 1, 3, 5, 7, 9 }),
+                                new IntArrayPostingList(new int[]{ 1, 3, 5, 7, 9 }),
+                                new IntArrayPostingList(new int[]{ 1, 3, 5, 7, 9 }));
+
+        final PostingList intersected = IntersectingPostingList.intersect(lists);
+
+        assertEquals(1, intersected.nextPosting());
+        assertEquals(5, intersected.advance(5));
+        assertEquals(7, intersected.nextPosting());
+        assertEquals(9, intersected.advance(9));
+    }
+
+    @Test
+    public void shouldInterleaveNextAndAdvanceOnRandom() throws IOException
+    {
+        for (int i = 0; i < 1000; ++i)
+        {
+            testAdvancingOnRandom();
+        }
+    }
+
+    private void testAdvancingOnRandom() throws IOException
+    {
+        final int postingsCount = nextInt(1, 50_000);
+        final int postingListCount = nextInt(2, 10);
+
+        // Generate base postings that will be present in all lists
+        final AtomicInteger rowId = new AtomicInteger();
+        final int[] commonPostings = IntStream.generate(() -> rowId.addAndGet(nextInt(1, 10)))
+                                              .limit(postingsCount / 4) // Fewer common elements
+                                              .toArray();
+
+        var splitPostingLists = new ArrayList<PostingList>();
+        for (int i = 0; i < postingListCount; i++)
+        {
+            // Combine common postings with some unique ones for each list
+            final int[] uniquePostings = IntStream.generate(() -> rowId.addAndGet(nextInt(1, 10)))
+                                                  .limit(postingsCount)
+                                                  .toArray();
+            int[] combined = IntStream.concat(IntStream.of(commonPostings),
+                                              IntStream.of(uniquePostings))
+                                      .distinct()
+                                      .sorted()
+                                      .toArray();
+            splitPostingLists.add(new IntArrayPostingList(combined));
+        }
+
+        final PostingList intersected = IntersectingPostingList.intersect(splitPostingLists);
+        final PostingList expected = new IntArrayPostingList(commonPostings);
+
+        final List<PostingListAdvance> actions = new ArrayList<>();
+        for (int idx = 0; idx < commonPostings.length; idx++)
+        {
+            if (nextInt(0, 8) == 0)
+            {
+                actions.add((postingList) -> {
+                    try
+                    {
+                        return postingList.nextPosting();
+                    }
+                    catch (IOException e)
+                    {
+                        fail(e.getMessage());
+                        throw new RuntimeException(e);
+                    }
+                });
+            }
+            else
+            {
+                final int skips = nextInt(0, 5);
+                idx = Math.min(idx + skips, commonPostings.length - 1);
+                final int rowID = commonPostings[idx];
+                actions.add((postingList) -> {
+                    try
+                    {
+                        return postingList.advance(rowID);
+                    }
+                    catch (IOException e)
+                    {
+                        fail(e.getMessage());
+                        throw new RuntimeException(e);
+                    }
+                });
+            }
+        }
+
+        for (PostingListAdvance action : actions)
+        {
+            assertEquals(action.advance(expected), action.advance(intersected));
+        }
+    }
+
+    private ArrayList<PostingList> listOfLists(PostingList... postingLists)
+    {
+        var L = new ArrayList<PostingList>();
+        Collections.addAll(L, postingLists);
+        return L;
+    }
+
+    private interface PostingListAdvance
+    {
+        long advance(PostingList list) throws IOException;
+    }
+}
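
Reviewer note, not part of the patch: BM25Utils.java is added by this change but its body is
elided from the excerpt above. For readers unfamiliar with the scoring, the sketch below shows
the classic Okapi BM25 formula that a utility like this typically computes. The class and method
names, and the k1/b defaults, are illustrative assumptions, not the patch's actual API.

    // Minimal Okapi BM25 sketch (hypothetical names; not BM25Utils' API)
    public final class Bm25Sketch
    {
        private static final double K1 = 1.2;  // term-frequency saturation (common default)
        private static final double B = 0.75;  // document-length normalization (common default)

        // Score contribution of one query term for one document; a document's total
        // score is the sum of this value over all query terms it contains.
        static double scoreTerm(int termFreq, int docLen, double avgDocLen, long docCount, long docFreq)
        {
            // Rarer terms get a larger IDF weight; the +1 keeps the log positive
            double idf = Math.log(1.0 + (docCount - docFreq + 0.5) / (docFreq + 0.5));
            // Term frequency saturates via k1; b scales the penalty for long documents
            double norm = termFreq + K1 * (1.0 - B + B * docLen / avgDocLen);
            return idf * (K1 + 1.0) * termFreq / norm;
        }
    }

If the implementation follows this shape, it is consistent with the orderings the tests assert:
more occurrences of a term rank higher (testTermFrequencyOrdering), shorter documents win at equal
term frequency (testDocumentLength), and the rarer term carries more weight in multi-term queries
(testMultiTermQueryScoring).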