Skip to content

Commit

Permalink
address review notes
Browse files Browse the repository at this point in the history
  • Loading branch information
jbellis committed Dec 11, 2024
1 parent 30b6545 commit 6c9a0e6
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 29 deletions.
29 changes: 15 additions & 14 deletions src/java/org/apache/cassandra/cql3/Operator.java
Original file line number Diff line number Diff line change
Expand Up @@ -373,20 +373,6 @@ public boolean isSatisfiedBy(AbstractType<?> type, ByteBuffer leftOperand, ByteB
return !LIKE.isSatisfiedBy(type, leftOperand, rightOperand, analyzer);
}
},
BM25(25)
{
@Override
public String toString()
{
return "BM25";
}

@Override
public boolean isSatisfiedBy(AbstractType<?> type, ByteBuffer leftOperand, ByteBuffer rightOperand, @Nullable Index.Analyzer analyzer)
{
throw new UnsupportedOperationException();
}
},

/**
* An operator that only performs matching against analyzed columns.
Expand Down Expand Up @@ -429,6 +415,7 @@ private boolean hasToken(AbstractType<?> type, List<ByteBuffer> tokens, ByteBuff
return false;
}
},

/**
* An operator that performs a distance bounded approximate nearest neighbor search against a vector column such
* that all result vectors are within a given distance of the query vector. The notable difference between this
Expand Down Expand Up @@ -473,6 +460,20 @@ public String toString()
return "DESC";
}

@Override
public boolean isSatisfiedBy(AbstractType<?> type, ByteBuffer leftOperand, ByteBuffer rightOperand, @Nullable Index.Analyzer analyzer)
{
throw new UnsupportedOperationException();
}
},
BM25(104)
{
@Override
public String toString()
{
return "BM25";
}

@Override
public boolean isSatisfiedBy(AbstractType<?> type, ByteBuffer leftOperand, ByteBuffer rightOperand, @Nullable Index.Analyzer analyzer)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ protected IndexSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory,
public abstract KeyRangeIterator search(Expression expression, AbstractBounds<PartitionPosition> keyRange, QueryContext queryContext, boolean defer, int limit) throws IOException;

/**
* Order the rows by the giving Orderer. Used for ORDER BY clause when
* Order the rows by the given Orderer. Used for ORDER BY clause when
* (1) the WHERE predicate is either a partition restriction or a range restriction on the index,
* (2) there is no WHERE predicate, or
* (3) the planner determines it is better to post-filter the ordered results by the predicate.
Expand All @@ -115,7 +115,7 @@ protected IndexSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory,
public abstract CloseableIterator<PrimaryKeyWithSortKey> orderBy(Orderer orderer, Expression slice, AbstractBounds<PartitionPosition> keyRange, QueryContext queryContext, int limit) throws IOException;

/**
* Order the rows by the giving Orderer. Used for ORDER BY clause when the WHERE predicates
* Order the rows by the given Orderer. Used for ORDER BY clause when the WHERE predicates
* have been applied first, yielding a list of primary keys. Again, `limit` is a planner hint for ANN to determine
* the initial number of results returned, not a maximum.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderBy(Orderer orderer, Express
}

// find documents that match each term
var queryTerms = orderer.extractQueryTerms();
var queryTerms = orderer.getQueryTerms();
var postingLists = queryTerms.stream()
.collect(Collectors.toMap(Function.identity(), term ->
{
Expand Down Expand Up @@ -227,7 +227,7 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(SSTableReader rea
if (!orderer.isBM25())
return super.orderResultsBy(reader, queryContext, keys, orderer, limit);

var queryTerms = orderer.extractQueryTerms();
var queryTerms = orderer.getQueryTerms();
// compute documentFrequencies from either histogram or an index search
var documentFrequencies = new HashMap<ByteBuffer, Long>();
boolean hasHistograms = metadata.version.onDiskFormat().indexFeatureSet().hasTermsHistogram();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import javax.annotation.concurrent.NotThreadSafe;

import org.apache.cassandra.index.sai.disk.PostingList;
import org.apache.cassandra.io.util.FileUtils;

import static java.lang.Math.max;

Expand Down Expand Up @@ -51,13 +52,16 @@ private IntersectingPostingList(List<PostingList> postingLists)
this.currentRowIds = new int[postingLists.size()];
}

/**
* @return the intersection of the provided posting lists
*/
public static PostingList intersect(List<PostingList> postingLists)
{
if (postingLists.size() == 1)
return postingLists.get(0);

if (postingLists.stream().anyMatch(PostingList::isEmpty))
return PostingList.EMPTY;
return new EmptyIntersectingList(postingLists);

return new IntersectingPostingList(postingLists);
}
Expand Down Expand Up @@ -124,10 +128,27 @@ public int size()
}

@Override
public void close() throws IOException
public void close()
{
for (PostingList list : postingLists)
list.close();
FileUtils.closeQuietly(list);
}

/**
 * An empty posting list that keeps a reference to the posting lists it was created from,
 * so that closing it also closes them. {@code intersect} returns this when at least one
 * of its input lists is empty.
 */
private static class EmptyIntersectingList extends EmptyPostingList
{
// the source posting lists; the caller's list is stored as-is, not copied
private final List<PostingList> lists;

/**
 * @param postingLists the posting lists to close when this list is closed
 */
public EmptyIntersectingList(List<PostingList> postingLists)
{
this.lists = postingLists;
}

/**
 * Closes every posting list passed to the constructor.
 */
@Override
public void close()
{
// NOTE(review): closeQuietly presumably suppresses per-list close failures so the
// remaining lists are still closed -- confirm against FileUtils
for (PostingList list : lists)
FileUtils.closeQuietly(list);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ public List<CloseableIterator<PrimaryKeyWithSortKey>> orderBy(QueryContext query
}

// BM25
var queryTerms = orderer.extractQueryTerms();
var queryTerms = orderer.getQueryTerms();

// Intersect iterators to find documents containing all terms
var termIterators = keyIteratorsPerTerm(queryContext, keyRange, queryTerms);
Expand Down Expand Up @@ -328,7 +328,7 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(QueryContext quer
}

// BM25
var queryTerms = orderer.extractQueryTerms();
var queryTerms = orderer.getQueryTerms();
var docStats = computeDocumentFrequencies(queryContext, queryTerms);
return BM25Utils.computeScores(keys.iterator(),
queryTerms,
Expand Down
8 changes: 6 additions & 2 deletions src/java/org/apache/cassandra/index/sai/plan/Orderer.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ public class Orderer
public final Operator operator;
public final ByteBuffer term;
private float[] vector;
private ArrayList<ByteBuffer> queryTerms;

/**
* Create an orderer for the given index context, operator, and term.
Expand Down Expand Up @@ -132,11 +133,14 @@ public float[] getVectorTerm()
return vector;
}

public ArrayList<ByteBuffer> extractQueryTerms()
public ArrayList<ByteBuffer> getQueryTerms()
{
if (queryTerms != null)
return queryTerms;

var queryAnalyzer = context.getQueryAnalyzerFactory().create();
// Split query into terms
var queryTerms = new ArrayList<ByteBuffer>();
queryTerms = new ArrayList<ByteBuffer>();
queryAnalyzer.reset(term);
try
{
Expand Down
4 changes: 2 additions & 2 deletions src/java/org/apache/cassandra/index/sai/plan/Plan.java
Original file line number Diff line number Diff line change
Expand Up @@ -1269,7 +1269,7 @@ private KeysIterationCost estimateBm25SortCost()
{
double expectedKeys = access.expectedAccessCount(source.expectedKeys());

int termCount = ordering.extractQueryTerms().size();
int termCount = ordering.getQueryTerms().size();
// all of the cost for BM25 is up front since the index doesn't give us the information we need
// to return results in order, in isolation. The big cost is reading the indexed cells out of
// the sstables.
Expand Down Expand Up @@ -1397,7 +1397,7 @@ protected KeysIterationCost estimateCost()
double expectedKeys = access.expectedAccessCount(factory.tableMetrics.rows);
int expectedKeysInt = Math.max(1, (int) Math.ceil(expectedKeys));

int termCount = ordering.extractQueryTerms().size();
int termCount = ordering.getQueryTerms().size();
double initCost = expectedKeysInt * (hrs(ROW_CELL_COST) + ROW_CELL_COST)
+ termCount * BM25_SCORE_COST;

Expand Down
5 changes: 3 additions & 2 deletions src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
Expand Down Expand Up @@ -170,8 +171,8 @@ else if (source instanceof SSTableId)
throw new IllegalArgumentException("Invalid source " + source.getClass());
}

// sort by score
scoredDocs.sort(Comparator.comparingDouble((PrimaryKeyWithScore pkws) -> pkws.indexScore).reversed());
// sort by score (PKWS implements Comparable correctly for us)
Collections.sort(scoredDocs);

return (CloseableIterator<PrimaryKeyWithSortKey>) (CloseableIterator) CloseableIterator.wrap(scoredDocs.iterator());
}
Expand Down

0 comments on commit 6c9a0e6

Please sign in to comment.