Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove reranking and force rerankK=limit in diskann (sstable) searches. memtables unaffected #1472

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion build.xml
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,7 @@
<dependency groupId="org.apache.lucene" artifactId="lucene-core" version="9.8.0-5ea8bb4f21" />
<dependency groupId="org.apache.lucene" artifactId="lucene-analysis-common" version="9.8.0-5ea8bb4f21" />
<dependency groupId="org.apache.lucene" artifactId="lucene-backward-codecs" version="9.8.0-5ea8bb4f21" />
<dependency groupId="io.github.jbellis" artifactId="jvector" version="3.0.6" />
<dependency groupId="io.github.jbellis" artifactId="jvector" version="3.0.7-13de754a" />
<dependency groupId="com.bpodgursky" artifactId="jbool_expressions" version="1.14" scope="test"/>

<dependency groupId="com.carrotsearch.randomizedtesting" artifactId="randomizedtesting-runner" version="2.1.2" scope="test">
Expand Down
6 changes: 6 additions & 0 deletions src/java/org/apache/cassandra/cache/ChunkCache.java
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,12 @@ public CachingRebufferer(ChunkReader file)
alignmentMask = -chunkSize;
}

@Override
public boolean supportsConcurrentRebuffer()
{
return source.supportsReadingChunksConcurrently();
}

@Override
public BufferHolder rebuffer(long position)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,13 @@ public enum CassandraRelevantProperties
* Do not wait for gossip to be enabled before starting stabilisation period. This is required especially for tests
* which do not enable gossip at all.
*/
CLUSTER_VERSION_PROVIDER_SKIP_WAIT_FOR_GOSSIP("cassandra.test.cluster_version_provider.skip_wait_for_gossip");
CLUSTER_VERSION_PROVIDER_SKIP_WAIT_FOR_GOSSIP("cassandra.test.cluster_version_provider.skip_wait_for_gossip"),


/**
* (Experimental) Thread pool size for RandomAccessReader "vectored" reads.
*/
RAR_VECTORED_READS_THREAD_POOL_SIZE("cassandra.rar.vectored_reads_thread_pool_size");

CassandraRelevantProperties(String key, String defaultVal)
{
Expand Down
28 changes: 28 additions & 0 deletions src/java/org/apache/cassandra/db/ColumnFamilyStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -3242,6 +3242,34 @@ public ViewFragment(List<SSTableReader> sstables, Iterable<Memtable> memtables)
this.sstables = sstables;
this.memtables = memtables;
}

public void sortSSTablesByMaxTimestampDescending()
{
sstables.sort(SSTableReader.maxTimestampDescending);
}
}

// A RefViewFragment that is sorted by max timestamp descending, which makes this object safe to reuse and to
// share across threads.
public static class SortedRefViewFragment extends RefViewFragment
{
private SortedRefViewFragment(RefViewFragment view)
{
// Copy sstable list to an immutable list to ensure there is not a future change that modifies the list
super(List.copyOf(view.sstables), view.memtables, view.refs);
}

// sstables are expected to be pre-sorted
@Override
public void sortSSTablesByMaxTimestampDescending()
{
}

public static SortedRefViewFragment sortThenCreateFrom(RefViewFragment view)
{
view.sortSSTablesByMaxTimestampDescending();
return new SortedRefViewFragment(view);
}
}

public static class RefViewFragment extends ViewFragment implements AutoCloseable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ private UnfilteredRowIterator queryMemtableAndDiskInternal(ColumnFamilyStore cfs
if (Tracing.traceSinglePartitions())
Tracing.trace("Acquiring sstable references");

view.sstables.sort(SSTableReader.maxTimestampDescending);
view.sortSSTablesByMaxTimestampDescending();
ClusteringIndexFilter filter = clusteringIndexFilter();
long minTimestamp = Long.MAX_VALUE;
long mostRecentPartitionTombstone = Long.MIN_VALUE;
Expand Down Expand Up @@ -673,7 +673,7 @@ private UnfilteredRowIterator queryMemtableAndDiskInternal(ColumnFamilyStore cfs
* In other words, iterating in descending maxTimestamp order allow to do our mostRecentPartitionTombstone
* elimination in one pass, and minimize the number of sstables for which we read a partition tombstone.
*/
view.sstables.sort(SSTableReader.maxTimestampDescending);
view.sortSSTablesByMaxTimestampDescending();
int nonIntersectingSSTables = 0;
int includedDueToTombstones = 0;

Expand Down Expand Up @@ -903,7 +903,7 @@ private UnfilteredRowIterator queryMemtableAndSSTablesInTimestampOrder(ColumnFam
}

/* add the SSTables on disk */
view.sstables.sort(SSTableReader.maxTimestampDescending);
view.sortSSTablesByMaxTimestampDescending();
// read sorted sstables
SSTableReadMetricsCollector metricsCollector = new SSTableReadMetricsCollector();
for (SSTableReader sstable : view.sstables)
Expand Down
7 changes: 6 additions & 1 deletion src/java/org/apache/cassandra/index/sai/QueryContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,14 @@ public FilterSortOrder filterSortOrder()
return filterSortOrder;
}

public long approximateRemainingTimeNs()
{
return executionQuotaNano - totalQueryTimeNs();
}

public void checkpoint()
{
if (totalQueryTimeNs() >= executionQuotaNano && !DISABLE_TIMEOUT)
if (approximateRemainingTimeNs() < 0 && !DISABLE_TIMEOUT)
{
addQueryTimeouts(1);
throw new AbortedOperationException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.function.Consumer;

import com.google.common.annotations.VisibleForTesting;
Expand Down Expand Up @@ -291,7 +293,7 @@ private CloseableIterator<RowIdWithScore> orderByBruteForce(CompressedVectors cv
VectorFloat<?> queryVector,
IntIntPairArray segmentOrdinalPairs,
int limit,
int rerankK) throws IOException
int rerankK)
{
var approximateScores = new SortingIterator.Builder<BruteForceRowIdIterator.RowWithApproximateScore>(segmentOrdinalPairs.size());
var similarityFunction = indexContext.getIndexWriterConfig().getSimilarityFunction();
Expand All @@ -302,8 +304,28 @@ private CloseableIterator<RowIdWithScore> orderByBruteForce(CompressedVectors cv
approximateScores.add(new BruteForceRowIdIterator.RowWithApproximateScore(segmentRowId, ordinal, score));
});
var approximateScoresQueue = approximateScores.build(BruteForceRowIdIterator.RowWithApproximateScore::compare);
var reranker = new CloseableReranker(similarityFunction, queryVector, graph.getView());
return new BruteForceRowIdIterator(approximateScoresQueue, reranker, limit, rerankK);
var transformed = new Iterator<RowIdWithScore>() {
int consumed = 0;

@Override
public boolean hasNext()
{
if (consumed >= limit)
return false;
return approximateScoresQueue.hasNext();
}

@Override
public RowIdWithScore next()
{
if (!hasNext())
throw new NoSuchElementException();
consumed++;
var approximated = approximateScoresQueue.next();
return new RowIdWithScore(approximated.rowId, approximated.appoximateScore);
}
};
return CloseableIterator.wrap(transformed);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ public class BruteForceRowIdIterator extends AbstractIterator<RowIdWithScore>
{
public static class RowWithApproximateScore
{
private final int rowId;
private final int ordinal;
private final float appoximateScore;
public final int rowId;
public final int ordinal;
public final float appoximateScore;

public RowWithApproximateScore(int rowId, int ordinal, float appoximateScore)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ public CloseableIterator<RowIdWithScore> search(VectorFloat<?> queryVector,
{
VectorValidation.validateIndexable(queryVector, similarityFunction);

rerankK = limit;
var graphAccessManager = searchers.get();
var searcher = graphAccessManager.get();
try
Expand All @@ -236,11 +237,11 @@ public CloseableIterator<RowIdWithScore> search(VectorFloat<?> queryVector,
if (features.contains(FeatureId.FUSED_ADC))
{
var asf = view.approximateScoreFunctionFor(queryVector, similarityFunction);
var rr = view.rerankerFor(queryVector, similarityFunction);
ssp = new SearchScoreProvider(asf, rr);
ssp = new SearchScoreProvider(asf, null);
}
else if (compressedVectors == null)
{
// no PQ, search with full-res vectors from disk
ssp = new SearchScoreProvider(view.rerankerFor(queryVector, similarityFunction));
}
else
Expand All @@ -251,8 +252,7 @@ else if (compressedVectors == null)
? VectorSimilarityFunction.COSINE
: similarityFunction;
var asf = compressedVectors.precomputedScoreFunctionFor(queryVector, sf);
var rr = view.rerankerFor(queryVector, similarityFunction);
ssp = new SearchScoreProvider(asf, rr);
ssp = new SearchScoreProvider(asf, null);
}
var result = searcher.search(ssp, limit, rerankK, threshold, context.getAnnRerankFloor(), ordinalsMap.ignoringDeleted(acceptBits));
if (V3OnDiskFormat.ENABLE_RERANK_FLOOR)
Expand Down
7 changes: 4 additions & 3 deletions src/java/org/apache/cassandra/index/sai/plan/QueryView.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,13 @@

public class QueryView implements AutoCloseable
{
final ColumnFamilyStore.RefViewFragment view;
// We use a SortedRefViewFragment because it can be safely shared across multiple threads when materializing rows.
final ColumnFamilyStore.SortedRefViewFragment view;
final Set<SSTableIndex> referencedIndexes;
final Set<MemtableIndex> memtableIndexes;
final IndexContext indexContext;

public QueryView(ColumnFamilyStore.RefViewFragment view,
public QueryView(ColumnFamilyStore.SortedRefViewFragment view,
Set<SSTableIndex> referencedIndexes,
Set<MemtableIndex> memtableIndexes,
IndexContext indexContext)
Expand Down Expand Up @@ -180,7 +181,7 @@ else if (MonotonicClock.approxTime.now() - failingSince > TimeUnit.MILLISECONDS.
// freeze referencedIndexes and memtableIndexes, so we can safely give access to them
// without risking something messes them up
// (this was added after KeyRangeTermIterator messed them up which led to a bug)
return new QueryView(refViewFragment,
return new QueryView(ColumnFamilyStore.SortedRefViewFragment.sortThenCreateFrom(refViewFragment),
Collections.unmodifiableSet(referencedIndexes),
Collections.unmodifiableSet(memtableIndexes),
indexContext);
Expand Down
Loading