From 6c9a0e6cbfdae0c40d1e90d655d4d0f1a671cb9b Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Wed, 11 Dec 2024 09:00:37 -0600 Subject: [PATCH] address review notes --- .../org/apache/cassandra/cql3/Operator.java | 29 ++++++++++--------- .../index/sai/disk/v1/IndexSearcher.java | 4 +-- .../sai/disk/v1/InvertedIndexSearcher.java | 4 +-- .../v1/postings/IntersectingPostingList.java | 27 +++++++++++++++-- .../index/sai/memory/TrieMemtableIndex.java | 4 +-- .../cassandra/index/sai/plan/Orderer.java | 8 +++-- .../apache/cassandra/index/sai/plan/Plan.java | 4 +-- .../cassandra/index/sai/utils/BM25Utils.java | 5 ++-- 8 files changed, 56 insertions(+), 29 deletions(-) diff --git a/src/java/org/apache/cassandra/cql3/Operator.java b/src/java/org/apache/cassandra/cql3/Operator.java index 8a8494661cb4..cd80020c1257 100644 --- a/src/java/org/apache/cassandra/cql3/Operator.java +++ b/src/java/org/apache/cassandra/cql3/Operator.java @@ -373,20 +373,6 @@ public boolean isSatisfiedBy(AbstractType type, ByteBuffer leftOperand, ByteB return !LIKE.isSatisfiedBy(type, leftOperand, rightOperand, analyzer); } }, - BM25(25) - { - @Override - public String toString() - { - return "BM25"; - } - - @Override - public boolean isSatisfiedBy(AbstractType type, ByteBuffer leftOperand, ByteBuffer rightOperand, @Nullable Index.Analyzer analyzer) - { - throw new UnsupportedOperationException(); - } - }, /** * An operator that only performs matching against analyzed columns. @@ -429,6 +415,7 @@ private boolean hasToken(AbstractType type, List tokens, ByteBuff return false; } }, + /** * An operator that performs a distance bounded approximate nearest neighbor search against a vector column such * that all result vectors are within a given distance of the query vector. 
The notable difference between this @@ -473,6 +460,20 @@ public String toString() return "DESC"; } + @Override + public boolean isSatisfiedBy(AbstractType type, ByteBuffer leftOperand, ByteBuffer rightOperand, @Nullable Index.Analyzer analyzer) + { + throw new UnsupportedOperationException(); + } + }, + BM25(104) + { + @Override + public String toString() + { + return "BM25"; + } + @Override public boolean isSatisfiedBy(AbstractType type, ByteBuffer leftOperand, ByteBuffer rightOperand, @Nullable Index.Analyzer analyzer) { diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java index 5fb246183921..14b182c7ee56 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java @@ -100,7 +100,7 @@ protected IndexSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory, public abstract KeyRangeIterator search(Expression expression, AbstractBounds keyRange, QueryContext queryContext, boolean defer, int limit) throws IOException; /** - * Order the rows by the giving Orderer. Used for ORDER BY clause when + * Order the rows by the given Orderer. Used for ORDER BY clause when * (1) the WHERE predicate is either a partition restriction or a range restriction on the index, * (2) there is no WHERE predicate, or * (3) the planner determines it is better to post-filter the ordered results by the predicate. @@ -115,7 +115,7 @@ protected IndexSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory, public abstract CloseableIterator orderBy(Orderer orderer, Expression slice, AbstractBounds keyRange, QueryContext queryContext, int limit) throws IOException; /** - * Order the rows by the giving Orderer. Used for ORDER BY clause when the WHERE predicates + * Order the rows by the given Orderer. Used for ORDER BY clause when the WHERE predicates * have been applied first, yielding a list of primary keys. 
Again, `limit` is a planner hint for ANN to determine * the initial number of results returned, not a maximum. */ diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java index b8353a3b5d80..8b7d14c42755 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java @@ -172,7 +172,7 @@ public CloseableIterator orderBy(Orderer orderer, Express } // find documents that match each term - var queryTerms = orderer.extractQueryTerms(); + var queryTerms = orderer.getQueryTerms(); var postingLists = queryTerms.stream() .collect(Collectors.toMap(Function.identity(), term -> { @@ -227,7 +227,7 @@ public CloseableIterator orderResultsBy(SSTableReader rea if (!orderer.isBM25()) return super.orderResultsBy(reader, queryContext, keys, orderer, limit); - var queryTerms = orderer.extractQueryTerms(); + var queryTerms = orderer.getQueryTerms(); // compute documentFrequencies from either histogram or an index search var documentFrequencies = new HashMap(); boolean hasHistograms = metadata.version.onDiskFormat().indexFeatureSet().hasTermsHistogram(); diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java index b6240bfce7b1..4cd762fd23c6 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java @@ -23,6 +23,7 @@ import javax.annotation.concurrent.NotThreadSafe; import org.apache.cassandra.index.sai.disk.PostingList; +import org.apache.cassandra.io.util.FileUtils; import static java.lang.Math.max; @@ -51,13 +52,16 @@ private IntersectingPostingList(List postingLists) this.currentRowIds = new 
int[postingLists.size()]; } + /** + * @return the intersection of the provided posting lists + */ public static PostingList intersect(List postingLists) { if (postingLists.size() == 1) return postingLists.get(0); if (postingLists.stream().anyMatch(PostingList::isEmpty)) - return PostingList.EMPTY; + return new EmptyIntersectingList(postingLists); return new IntersectingPostingList(postingLists); } @@ -124,10 +128,27 @@ public int size() } @Override - public void close() throws IOException + public void close() { for (PostingList list : postingLists) - list.close(); + FileUtils.closeQuietly(list); + } + + private static class EmptyIntersectingList extends EmptyPostingList + { + private final List lists; + + public EmptyIntersectingList(List postingLists) + { + this.lists = postingLists; + } + + @Override + public void close() + { + for (PostingList list : lists) + FileUtils.closeQuietly(list); + } } } diff --git a/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java b/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java index 22b392350860..4fa839532d28 100644 --- a/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java +++ b/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java @@ -258,7 +258,7 @@ public List> orderBy(QueryContext query } // BM25 - var queryTerms = orderer.extractQueryTerms(); + var queryTerms = orderer.getQueryTerms(); // Intersect iterators to find documents containing all terms var termIterators = keyIteratorsPerTerm(queryContext, keyRange, queryTerms); @@ -328,7 +328,7 @@ public CloseableIterator orderResultsBy(QueryContext quer } // BM25 - var queryTerms = orderer.extractQueryTerms(); + var queryTerms = orderer.getQueryTerms(); var docStats = computeDocumentFrequencies(queryContext, queryTerms); return BM25Utils.computeScores(keys.iterator(), queryTerms, diff --git a/src/java/org/apache/cassandra/index/sai/plan/Orderer.java 
b/src/java/org/apache/cassandra/index/sai/plan/Orderer.java index 6193e0a37110..6cb62ca06607 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Orderer.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Orderer.java @@ -50,6 +50,7 @@ public class Orderer public final Operator operator; public final ByteBuffer term; private float[] vector; + private ArrayList queryTerms; /** * Create an orderer for the given index context, operator, and term. @@ -132,11 +133,14 @@ public float[] getVectorTerm() return vector; } - public ArrayList extractQueryTerms() + public ArrayList getQueryTerms() { + if (queryTerms != null) + return queryTerms; + var queryAnalyzer = context.getQueryAnalyzerFactory().create(); // Split query into terms - var queryTerms = new ArrayList(); + queryTerms = new ArrayList(); queryAnalyzer.reset(term); try { diff --git a/src/java/org/apache/cassandra/index/sai/plan/Plan.java b/src/java/org/apache/cassandra/index/sai/plan/Plan.java index 4cf730bd1d53..941608328779 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Plan.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Plan.java @@ -1269,7 +1269,7 @@ private KeysIterationCost estimateBm25SortCost() { double expectedKeys = access.expectedAccessCount(source.expectedKeys()); - int termCount = ordering.extractQueryTerms().size(); + int termCount = ordering.getQueryTerms().size(); // all of the cost for BM25 is up front since the index doesn't give us the information we need // to return results in order, in isolation. The big cost is reading the indexed cells out of // the sstables. 
@@ -1397,7 +1397,7 @@ protected KeysIterationCost estimateCost() double expectedKeys = access.expectedAccessCount(factory.tableMetrics.rows); int expectedKeysInt = Math.max(1, (int) Math.ceil(expectedKeys)); - int termCount = ordering.extractQueryTerms().size(); + int termCount = ordering.getQueryTerms().size(); double initCost = expectedKeysInt * (hrs(ROW_CELL_COST) + ROW_CELL_COST) + termCount * BM25_SCORE_COST; diff --git a/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java b/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java index 044997e6508b..f1c51e52923b 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java +++ b/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java @@ -21,6 +21,7 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.Comparator; import java.util.HashMap; import java.util.Iterator; @@ -170,8 +171,8 @@ else if (source instanceof SSTableId) throw new IllegalArgumentException("Invalid source " + source.getClass()); } - // sort by score - scoredDocs.sort(Comparator.comparingDouble((PrimaryKeyWithScore pkws) -> pkws.indexScore).reversed()); + // sort by score (PKWS implements Comparable correctly for us) + Collections.sort(scoredDocs); return (CloseableIterator) (CloseableIterator) CloseableIterator.wrap(scoredDocs.iterator()); }