Skip to content

Commit

Permalink
add back old path of fetching vector from replicas and re-scoring on …
Browse files Browse the repository at this point in the history
…coordinator
  • Loading branch information
jbellis committed Dec 6, 2024
1 parent 4544422 commit 395711e
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 12 deletions.
3 changes: 2 additions & 1 deletion src/java/org/apache/cassandra/cql3/Ordering.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import org.apache.cassandra.cql3.restrictions.SingleColumnRestriction;
import org.apache.cassandra.cql3.restrictions.SingleRestriction;
import org.apache.cassandra.cql3.statements.SelectStatement;
import org.apache.cassandra.schema.ColumnMetadata;
import org.apache.cassandra.schema.TableMetadata;

Expand Down Expand Up @@ -127,7 +128,7 @@ public ColumnMetadata getColumn()
@Override
public boolean isScored()
{
return true;
return SelectStatement.ANN_USE_SYNTHETIC_SCORE;
}
}

Expand Down
54 changes: 51 additions & 3 deletions src/java/org/apache/cassandra/cql3/statements/SelectStatement.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.apache.cassandra.cql3.selection.SortedRowsBuilder;
import org.apache.cassandra.db.marshal.FloatType;
import org.apache.cassandra.guardrails.Guardrails;
import org.apache.cassandra.index.Index;
import org.apache.cassandra.schema.ColumnMetadata;
import org.apache.cassandra.schema.Schema;
import org.apache.cassandra.schema.TableMetadata;
Expand Down Expand Up @@ -100,6 +101,10 @@
*/
public class SelectStatement implements CQLStatement.SingleKeyspaceCqlStatement
{
// TODO remove this when we no longer need to downgrade to replicas that don't know about synthetic columns
// (And don't forget to remove the related hacks in Columns.Serializer.encodeBitmap and UnfilteredSerializer.serializeRowBody)
public static final boolean ANN_USE_SYNTHETIC_SCORE = Boolean.parseBoolean(System.getProperty("cassandra.sai.ann_use_synthetic_score", "false"));

private static final Logger logger = LoggerFactory.getLogger(SelectStatement.class);
private static final NoSpamLogger noSpamLogger = NoSpamLogger.getLogger(SelectStatement.logger, 1, TimeUnit.MINUTES);
public static final String TOPK_CONSISTENCY_LEVEL_ERROR = "Top-K queries can only be run with consistency level ONE/LOCAL_ONE. Consistency level %s was used.";
Expand Down Expand Up @@ -993,7 +998,7 @@ private ResultSet process(PartitionIterator partitions,
{
ProtocolVersion protocolVersion = options.getProtocolVersion();
GroupMaker groupMaker = aggregationSpec == null ? null : aggregationSpec.newGroupMaker();
SortedRowsBuilder rows = sortedRowsBuilder(userLimit, userOffset == NO_OFFSET ? 0 : userOffset);
SortedRowsBuilder rows = sortedRowsBuilder(userLimit, userOffset == NO_OFFSET ? 0 : userOffset, options);
ResultSetBuilder result = new ResultSetBuilder(protocolVersion, getResultMetadata(), selectors, groupMaker, rows);

while (partitions.hasNext())
Expand Down Expand Up @@ -1112,11 +1117,24 @@ private boolean needsPostQueryOrdering()
/**
* Orders results when multiple keys are selected (using IN)
*/
public SortedRowsBuilder sortedRowsBuilder(int limit, int offset)
public SortedRowsBuilder sortedRowsBuilder(int limit, int offset, QueryOptions options)
{
if (orderingComparator == null)
return SortedRowsBuilder.create(limit, offset);

if (orderingComparator instanceof IndexColumnComparator)
{
SingleRestriction restriction = ((IndexColumnComparator) orderingComparator).restriction;
int columnIndex = ((IndexColumnComparator) orderingComparator).columnIndex;

Index index = restriction.findSupportingIndex(IndexRegistry.obtain(table));
assert index != null;

Index.Scorer scorer = index.postQueryScorer(restriction, columnIndex, options);
return SortedRowsBuilder.create(limit, offset, scorer);
}

// else
return SortedRowsBuilder.create(limit, offset, orderingComparator);
}

Expand Down Expand Up @@ -1474,7 +1492,11 @@ private ColumnComparator<List<ByteBuffer>> getOrderingComparator(Selection selec
assert orderingColumns.size() == 1 : orderingColumns.keySet();
var e = orderingColumns.entrySet().iterator().next();
var column = e.getKey();
return new SingleColumnComparator(selection.getOrderingIndex(column), column.type, false);
var ordering = e.getValue();
if (ordering.expression instanceof Ordering.Ann && !ANN_USE_SYNTHETIC_SCORE)
return new IndexColumnComparator(ordering.expression.toRestriction(), selection.getOrderingIndex(column));
else
return new SingleColumnComparator(selection.getOrderingIndex(column), column.type, false);
}

if (!restrictions.keyIsInRelation())
Expand Down Expand Up @@ -1694,6 +1716,32 @@ public boolean isClustered()
}
}

// see usage in sortedRowsBuilder
private static class IndexColumnComparator extends ColumnComparator<List<ByteBuffer>>
{
private final SingleRestriction restriction;
private final int columnIndex;

// VSTODO maybe cache in prepared statement
public IndexColumnComparator(SingleRestriction restriction, int columnIndex)
{
this.restriction = restriction;
this.columnIndex = columnIndex;
}

@Override
public boolean isClustered()
{
return false;
}

@Override
public int compare(List<ByteBuffer> o1, List<ByteBuffer> o2)
{
throw new UnsupportedOperationException();
}
}

/**
* Used in orderResults(...) method when multiple 'ORDER BY' conditions where given
*/
Expand Down
8 changes: 8 additions & 0 deletions src/java/org/apache/cassandra/db/Columns.java
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,15 @@ private static long encodeBitmap(Collection<ColumnMetadata> columns, Columns sup
for (ColumnMetadata column : columns)
{
if (iter.next(column) == null)
{
// We can't know for sure whether to add the synthetic score column because WildcardColumnFilter
// just says "yes" to everything; instead, we just skip it here.
// TODO remove this with SelectStatement.ANN_USE_SYNTHETIC_SCORE.
if (column.isSynthetic())
continue;

throw new IllegalStateException(columns + " is not a subset of " + superset);
}

int currentIndex = iter.indexOfCurrent();
int count = currentIndex - expectIndex;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,11 @@ private void serializeRowBody(Row row, int flags, SerializationHelper helper, Da
// with. So we use the ColumnMetadata from the "header" which is "current". Also see #11810 for what
// happens if we don't do that.
ColumnMetadata column = si.next(cd.column());
// We can't know for sure whether to add the synthetic score column because WildcardColumnFilter
// just says "yes" to everything; instead, we just skip it here.
// TODO remove this with SelectStatement.ANN_USE_SYNTHETIC_SCORE.
if (column == null)
return;
assert column != null : cd.column.toString();

try
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -683,8 +683,8 @@ public RowFilter getPostIndexQueryFilter(RowFilter filter)
@Override
public Scorer postQueryScorer(Restriction restriction, int columnIndex, QueryOptions options)
{
// For now, only support ANN
assert restriction instanceof SingleColumnRestriction.AnnRestriction || restriction instanceof SingleColumnRestriction.Bm25Restriction;
// TODO remove this with SelectStatement.ANN_USE_SYNTHETIC_SCORE.
assert restriction instanceof SingleColumnRestriction.AnnRestriction;

Preconditions.checkState(indexContext.isVector());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@
import java.io.UncheckedIOException;
import java.lang.invoke.MethodHandles;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
Expand All @@ -46,7 +43,6 @@
import org.apache.cassandra.index.sai.IndexContext;
import org.apache.cassandra.index.sai.QueryContext;
import org.apache.cassandra.index.sai.SSTableContext;
import org.apache.cassandra.index.sai.analyzer.AbstractAnalyzer;
import org.apache.cassandra.index.sai.disk.PostingList;
import org.apache.cassandra.index.sai.disk.TermsIterator;
import org.apache.cassandra.index.sai.disk.format.IndexComponentType;
Expand All @@ -59,7 +55,6 @@
import org.apache.cassandra.index.sai.plan.Orderer;
import org.apache.cassandra.index.sai.utils.BM25Utils;
import org.apache.cassandra.index.sai.utils.PrimaryKey;
import org.apache.cassandra.index.sai.utils.PrimaryKeyWithScore;
import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey;
import org.apache.cassandra.index.sai.utils.RowIdWithByteComparable;
import org.apache.cassandra.index.sai.utils.SAICodecUtils;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,20 @@ private static ByteBuffer floatVectorToByteBuffer(CQLTester.Vector<Float> vector

@Test
public void testOrderResults() {
QueryOptions queryOptions = QueryOptions.create(ConsistencyLevel.ONE,
byteBufferList,
false,
PageSize.inRows(1),
null,
null,
ProtocolVersion.CURRENT,
KEYSPACE);
List<List<ByteBuffer>> rows = new ArrayList<>();
rows.add(byteBufferList);

SelectStatement selectStatementInstance = (SelectStatement) QueryProcessor.prepareInternal("SELECT key, value FROM " + KEYSPACE + '.' + TABLE).statement;

SortedRowsBuilder builder = selectStatementInstance.sortedRowsBuilder(Integer.MAX_VALUE, 0);
SortedRowsBuilder builder = selectStatementInstance.sortedRowsBuilder(Integer.MAX_VALUE, 0, queryOptions);
rows.forEach(builder::add);
List<List<ByteBuffer>> sortedRows = builder.build();

Expand Down

0 comments on commit 395711e

Please sign in to comment.