implement BM25
jbellis committed Dec 6, 2024
1 parent 087b14f commit 4544422
Showing 17 changed files with 1,313 additions and 200 deletions.
23 changes: 14 additions & 9 deletions src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java
@@ -68,9 +68,7 @@ public abstract class IndexSearcher implements Closeable, SegmentOrdering
protected final SegmentMetadata metadata;
protected final IndexContext indexContext;

private static final SSTableReadsListener NOOP_LISTENER = new SSTableReadsListener() {};

private final ColumnFilter columnFilter;
protected final ColumnFilter columnFilter;

protected IndexSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory,
PerIndexFiles perIndexFiles,
@@ -90,30 +88,37 @@ protected IndexSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory,
public abstract long indexFileCacheSize();

/**
* Search on-disk index synchronously
* Search on-disk index synchronously. Used for WHERE clause predicates, including BOUNDED_ANN.
*
* @param expression to filter on disk index
* @param keyRange key range specified in the read command, used by ANN index
* @param queryContext to track per sstable cache and per query metrics
* @param defer create the iterator in a deferred state
* @param limit the number of rows to return, used by ANN index
* @param limit the initial number of rows to return, used by ANN index. More rows may be requested if filtering throws away more than expected!
* @return {@link KeyRangeIterator} that matches given expression
*/
public abstract KeyRangeIterator search(Expression expression, AbstractBounds<PartitionPosition> keyRange, QueryContext queryContext, boolean defer, int limit) throws IOException;

/**
* Order the on-disk index synchronously and produce an iterator in score order
* Order the rows by the given Orderer. Used for the ORDER BY clause when
* (1) the WHERE predicate is either a partition restriction or a range restriction on the index,
* (2) there is no WHERE predicate, or
* (3) the planner determines it is better to post-filter the ordered results by the predicate.
*
* @param orderer the object containing the ordering logic
* @param slice optional predicate to get a slice of the index
* @param keyRange key range specified in the read command, used by ANN index
* @param queryContext to track per sstable cache and per query metrics
* @param limit the number of rows to return, used by ANN index
* @param limit the initial number of rows to return, used by ANN index. More rows may be requested if filtering throws away more than expected!
* @return an iterator of {@link PrimaryKeyWithSortKey} in score order
*/
public abstract CloseableIterator<PrimaryKeyWithSortKey> orderBy(Orderer orderer, Expression slice, AbstractBounds<PartitionPosition> keyRange, QueryContext queryContext, int limit) throws IOException;


/**
* Order the rows by the given Orderer. Used for the ORDER BY clause when the WHERE predicates
* have been applied first, yielding a list of primary keys. Again, `limit` is a planner hint for ANN to determine
* the initial number of results returned, not a maximum.
*/
@Override
public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(SSTableReader reader, QueryContext context, List<PrimaryKey> keys, Orderer orderer, int limit) throws IOException
{
@@ -124,7 +129,7 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(SSTableReader rea
{
var slices = Slices.with(indexContext.comparator(), Slice.make(key.clustering()));
// TODO if we end up needing to read the row still, is it better to store offset and use reader.unfilteredAt?
try (var iter = reader.iterator(key.partitionKey(), slices, columnFilter, false, NOOP_LISTENER))
try (var iter = reader.iterator(key.partitionKey(), slices, columnFilter, false, SSTableReadsListener.NOOP_LISTENER))
{
if (iter.hasNext())
{
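
For orientation before the next file: a hypothetical caller-side sketch (not code from this commit; the method name, parameters such as matchingKeys, and the surrounding planner logic are assumptions) showing how the two ordering entry points declared above are intended to be combined, with limit acting only as an initial hint.

// Hypothetical sketch, not part of this commit: illustrates the two ordering paths the
// javadoc above describes. Names like orderForQuery and matchingKeys are assumptions.
static CloseableIterator<PrimaryKeyWithSortKey> orderForQuery(IndexSearcher searcher,
                                                              SSTableReader sstable,
                                                              QueryContext queryContext,
                                                              Orderer orderer,
                                                              Expression slice,
                                                              AbstractBounds<PartitionPosition> keyRange,
                                                              List<PrimaryKey> matchingKeys, // null when WHERE was not applied first
                                                              int limit) throws IOException
{
    // limit is an initial hint for ANN/BM25, not a maximum; callers may pull more rows
    // if post-filtering discards more results than expected.
    if (matchingKeys != null)
        // WHERE predicates already yielded concrete keys: order only those
        return searcher.orderResultsBy(sstable, queryContext, matchingKeys, orderer, limit);

    // otherwise order the on-disk index directly (optionally sliced) and post-filter later
    return searcher.orderBy(orderer, slice, keyRange, queryContext, limit);
}
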
src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java
@@ -19,42 +19,49 @@
package org.apache.cassandra.index.sai.disk.v1;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.invoke.MethodHandles;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.function.Function;
import java.util.stream.Collectors;

import com.google.common.base.MoreObjects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.cassandra.cql3.Operator;
import org.apache.cassandra.db.PartitionPosition;
import org.apache.cassandra.db.RegularAndStaticColumns;
import org.apache.cassandra.db.Slice;
import org.apache.cassandra.db.Slices;
import org.apache.cassandra.db.filter.ColumnFilter;
import org.apache.cassandra.db.rows.Cell;
import org.apache.cassandra.db.rows.Row;
import org.apache.cassandra.dht.AbstractBounds;
import org.apache.cassandra.index.sai.IndexContext;
import org.apache.cassandra.index.sai.QueryContext;
import org.apache.cassandra.index.sai.SSTableContext;
import org.apache.cassandra.index.sai.analyzer.AbstractAnalyzer;
import org.apache.cassandra.index.sai.disk.PostingList;
import org.apache.cassandra.index.sai.disk.TermsIterator;
import org.apache.cassandra.index.sai.disk.format.IndexComponentType;
import org.apache.cassandra.index.sai.disk.format.Version;
import org.apache.cassandra.index.sai.disk.v1.postings.MergePostingList;
import org.apache.cassandra.index.sai.disk.v1.postings.IntersectingPostingList;
import org.apache.cassandra.index.sai.iterators.KeyRangeIterator;
import org.apache.cassandra.index.sai.metrics.MulticastQueryEventListeners;
import org.apache.cassandra.index.sai.metrics.QueryEventListener;
import org.apache.cassandra.index.sai.plan.Expression;
import org.apache.cassandra.index.sai.plan.Orderer;
import org.apache.cassandra.index.sai.utils.BM25Utils;
import org.apache.cassandra.index.sai.utils.PrimaryKey;
import org.apache.cassandra.index.sai.utils.PrimaryKeyWithScore;
import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey;
import org.apache.cassandra.index.sai.utils.RowIdWithByteComparable;
import org.apache.cassandra.index.sai.utils.RowIdWithScore;
import org.apache.cassandra.index.sai.utils.SAICodecUtils;
import org.apache.cassandra.io.sstable.format.SSTableReader;
import org.apache.cassandra.io.sstable.format.SSTableReadsListener;
@@ -75,23 +82,12 @@ public class InvertedIndexSearcher extends IndexSearcher
private final TermsReader reader;
private final QueryEventListener.TrieIndexEventListener perColumnEventListener;
private final Version version;
private static final int ROW_COUNT = 1_000_000; // TODO: Replace with actual count
private static final float K1 = 1.2f; // BM25 term frequency saturation parameter
private static final float B = 0.75f; // BM25 length normalization parameter

private final boolean filterRangeResults;
private final SSTableReader sstable;

private static class DocumentStats {
final Map<ByteBuffer, Integer> termFrequencies = new HashMap<>();
float length;
final int rowId;

DocumentStats(int rowId) {
this.rowId = rowId;
}
}

protected InvertedIndexSearcher(SSTableContext sstableContext,
PerIndexFiles perIndexFiles,
SegmentMetadata segmentMetadata,
@@ -161,19 +157,16 @@ else if (exp.getOp() == Expression.Op.RANGE)
throw new IllegalArgumentException(indexContext.logMessage("Unsupported expression: " + exp));
}

private Cell<?> readColumn(SSTableReader sstable, PrimaryKey primaryKey) {
var decoratedKey = primaryKey.partitionKey();

// Create a ColumnFilter to select only the specific column
var column = indexContext.getDefinition();
var columnFilter = ColumnFilter.selection(RegularAndStaticColumns.of(column));

// TODO Slices.ALL is not correct when there are multiple rows per partition
try (var rowIterator = sstable.iterator(decoratedKey, Slices.ALL, columnFilter, false, SSTableReadsListener.NOOP_LISTENER)) {
private Cell<?> readColumn(SSTableReader sstable, PrimaryKey primaryKey)
{
var dk = primaryKey.partitionKey();
var slices = Slices.with(indexContext.comparator(), Slice.make(primaryKey.clustering()));
try (var rowIterator = sstable.iterator(dk, slices, columnFilter, false, SSTableReadsListener.NOOP_LISTENER))
{
var unfiltered = rowIterator.next();
assert unfiltered.isRow() : unfiltered;
Row row = (Row) unfiltered;
return row.getCell(column);
return row.getCell(indexContext.getDefinition());
}
}
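
The cell returned by readColumn supplies the document content that BM25 term-frequency counting runs over. As a point of reference, a minimal sketch of that counting loop, modeled on the inline implementation this commit removes further down (the docAnalyzer.reset(cell.buffer()) loop); cell, indexContext and queryTerms are assumed to be in scope.

// Sketch of per-document term-frequency counting, modeled on the inline loop removed by
// this commit; not the actual BM25Utils implementation.
Map<ByteBuffer, Integer> termFrequencies = new HashMap<>();
int docLength = 0;
AbstractAnalyzer docAnalyzer = indexContext.getAnalyzerFactory().create();
docAnalyzer.reset(cell.buffer());
try
{
    while (docAnalyzer.hasNext())
    {
        ByteBuffer docTerm = docAnalyzer.next();
        if (queryTerms.contains(docTerm))
            termFrequencies.merge(docTerm, 1, Integer::sum);
        docLength++;
    }
}
finally
{
    docAnalyzer.end();
}
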

@@ -186,83 +179,83 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderBy(Orderer orderer, Express
return toMetaSortedIterator(iter, queryContext);
}

var queryAnalyzer = indexContext.getQueryAnalyzerFactory().create();
var docAnalyzer = indexContext.getAnalyzerFactory().create();
// find documents that match each term
var queryTerms = orderer.extractQueryTerms();
var postingLists = queryTerms.stream()
.collect(Collectors.toMap(Function.identity(), term ->
{
var encodedTerm = version.onDiskFormat().encodeForTrie(term, indexContext.getValidator());
var listener = MulticastQueryEventListeners.of(queryContext, perColumnEventListener);
return reader.exactMatch(encodedTerm, listener, queryContext);
}));
// extract the match count for each term
var documentFrequencies = postingLists.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> (long) e.getValue().size()));

// Split query into terms
var queryTerms = new ArrayList<ByteBuffer>();
queryAnalyzer.reset(orderer.term);
try
{
queryAnalyzer.forEachRemaining(queryTerms::add);
}
finally
{
queryAnalyzer.end();
}

// find documents that match all terms
var postingLists = queryTerms.stream().map(term ->
{
var encodedTerm = version.onDiskFormat().encodeForTrie(term, indexContext.getValidator());
var listener = MulticastQueryEventListeners.of(queryContext, perColumnEventListener);
return reader.exactMatch(encodedTerm, listener, queryContext);
}).collect(Collectors.toList());

Map<Integer, DocumentStats> documents = new HashMap<>();
try (var pkm = primaryKeyMapFactory.newPerSSTablePrimaryKeyMap();
var merged = MergePostingList.merge(postingLists))
var merged = IntersectingPostingList.intersect(List.copyOf(postingLists.values())))
{
// First pass: collect term frequencies and document lengths
int rowId;
while ((rowId = merged.nextPosting()) != END_OF_STREAM) {
var pk = pkm.primaryKeyFromRowId(rowId);
var cell = readColumn(sstable, pk);

var stats = new DocumentStats(rowId);
docAnalyzer.reset(cell.buffer());
try
var it = new AbstractIterator<PrimaryKey>() {
@Override
protected PrimaryKey computeNext()
{
while (docAnalyzer.hasNext()) {
var docTerm = docAnalyzer.next();
if (queryTerms.stream().anyMatch(q -> q.equals(docTerm))) {
stats.termFrequencies.merge(docTerm, 1, Integer::sum);
}
stats.length++;
try
{
int rowId = merged.nextPosting();
if (rowId == PostingList.END_OF_STREAM)
return endOfData();
return pkm.primaryKeyFromRowId(rowId);
}
catch (IOException e)
{
throw new UncheckedIOException(e);
}
}
finally
{
docAnalyzer.end();
}
documents.put(rowId, stats);
}

// Compute average document length
double avgDocLength = documents.values().stream().mapToDouble(d -> d.length).average().orElse(0.0);
};
return bm25Internal(it, queryTerms, documentFrequencies);
}
}

// Second pass: calculate BM25 scores
var scoredDocs = new PriorityQueue<RowIdWithScore>((a, b) -> Double.compare(b.score, a.score));
private CloseableIterator<PrimaryKeyWithSortKey> bm25Internal(Iterator<PrimaryKey> keyIterator,
List<ByteBuffer> queryTerms,
Map<ByteBuffer, Long> documentFrequencies)
{
var docStats = new BM25Utils.DocStats(documentFrequencies, sstable.getTotalRows());
return BM25Utils.computeScores(keyIterator,
queryTerms,
docStats,
indexContext,
sstable.descriptor.id,
pk -> readColumn(sstable, pk));
}

for (var doc : documents.values()) {
double score = 0.0f;
for (var queryTerm : queryTerms) {
int tf = doc.termFrequencies.get(queryTerm);
double normalizedTf = tf / (tf + K1 * (1 - B + B * doc.length / avgDocLength));
// TODO this is broken, need counts for bare term matches for each term
double idf = Math.log(1 + (ROW_COUNT - documents.size() + 0.5) / (documents.size() + 0.5));
score += normalizedTf * idf;
}
@Override
public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(SSTableReader reader, QueryContext queryContext, List<PrimaryKey> keys, Orderer orderer, int limit) throws IOException
{
if (!orderer.isBM25())
return super.orderResultsBy(reader, queryContext, keys, orderer, limit);

scoredDocs.add(new RowIdWithScore(doc.rowId, (float) score));
if (scoredDocs.size() > limit) {
scoredDocs.poll(); // Remove lowest scoring doc
}
var queryTerms = orderer.extractQueryTerms();
var documentFrequencies = new HashMap<ByteBuffer, Long>();
boolean hasHistograms = metadata.version.onDiskFormat().indexFeatureSet().hasTermsHistogram();
for (ByteBuffer term : queryTerms)
{
long matches;
if (hasHistograms)
{
matches = metadata.estimateNumRowsMatching(new Expression(indexContext).add(Operator.ANALYZER_MATCHES, term));
}

// Convert scored results back to posting list
return toMetaSortedIterator(CloseableIterator.wrap(scoredDocs.iterator()), queryContext);
else
{
// Without histograms, need to do an actual index scan
var encodedTerm = version.onDiskFormat().encodeForTrie(term, indexContext.getValidator());
var listener = MulticastQueryEventListeners.of(queryContext, perColumnEventListener);
var postingList = this.reader.exactMatch(encodedTerm, listener, queryContext);
matches = postingList.size();
FileUtils.closeQuietly(postingList);
}
documentFrequencies.put(term, matches);
}
return bm25Internal(keys.iterator(), queryTerms, documentFrequencies);
}

@Override
(diff truncated; remaining changed files not shown)
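
For reference, the inline scoring removed by this commit followed the standard Okapi BM25 shape with k1 = 1.2 and b = 0.75; the new BM25Utils path is presumably equivalent but with correct per-term document frequencies (the old TODO noted its single-count estimate was broken). A self-contained sketch of that formula, with class and parameter names chosen here purely for illustration:

import java.util.Map;

// Standalone sketch of the Okapi BM25 score as written in the inline code this commit
// removes (K1 = 1.2, B = 0.75). Not the actual BM25Utils implementation.
final class Bm25Sketch
{
    static final float K1 = 1.2f;  // term-frequency saturation parameter
    static final float B = 0.75f;  // length-normalization parameter

    static double score(Map<String, Integer> termFrequencies,   // tf of each query term in this document
                        Map<String, Long> documentFrequencies,  // number of documents containing each term
                        double docLength,
                        double avgDocLength,
                        long totalDocs)
    {
        double score = 0.0;
        for (Map.Entry<String, Integer> e : termFrequencies.entrySet())
        {
            int tf = e.getValue();
            double normalizedTf = tf / (tf + K1 * (1 - B + B * docLength / avgDocLength));
            long df = documentFrequencies.getOrDefault(e.getKey(), 1L);
            double idf = Math.log(1 + (totalDocs - df + 0.5) / (df + 0.5));
            score += normalizedTf * idf;
        }
        return score;
    }
}
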
