Skip to content

Commit

Permalink
use +score pseudo-column to order ANN results with instead of recomputing scores on the coordinator
Browse files Browse the repository at this point in the history
  • Loading branch information
jbellis committed Nov 25, 2024
1 parent 36bdaaf commit 208f88f
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 92 deletions.
11 changes: 11 additions & 0 deletions src/java/org/apache/cassandra/cql3/Ordering.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ public interface Expression
SingleRestriction toRestriction();

ColumnMetadata getColumn();

/**
 * Tells whether this ordering expression is a "scored" ordering.
 * <p>
 * The default is {@code false}; implementations that are scored override this to return {@code true}.
 * NOTE(review): "scored" appears to mean ordering by a computed score such as ANN similarity
 * rather than by a stored column value -- confirm against the implementing classes.
 *
 * @return {@code false} unless overridden
 */
default boolean isScored()
{
return false;
}
}

/**
Expand Down Expand Up @@ -118,6 +123,12 @@ public ColumnMetadata getColumn()
{
return column;
}

/**
 * Marks this expression as a scored ordering, overriding the {@code false} default.
 * NOTE(review): the enclosing class is not visible in this hunk; presumably it is the ANN
 * ordering expression -- confirm.
 *
 * @return always {@code true}
 */
@Override
public boolean isScored()
{
return true;
}
}

public enum Direction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,7 @@ else if (indexOrderings.size() == 1)
if (orderings.size() > 1)
throw new InvalidRequestException("Cannot combine clustering column ordering with non-clustering column ordering");
Ordering ordering = indexOrderings.get(0);
if (ordering.direction != Ordering.Direction.ASC && ordering.expression instanceof Ordering.Ann)
if (ordering.direction != Ordering.Direction.ASC && ordering.expression.isScored())
throw new InvalidRequestException("Descending ANN ordering is not supported");
if (!ENABLE_SAI_GENERAL_ORDER_BY && ordering.expression instanceof Ordering.SingleColumn)
throw new InvalidRequestException("SAI based ORDER BY on non-vector column is not supported");
Expand Down
117 changes: 53 additions & 64 deletions src/java/org/apache/cassandra/cql3/statements/SelectStatement.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,9 @@
import org.apache.cassandra.config.DatabaseDescriptor;
import org.apache.cassandra.cql3.restrictions.ExternalRestriction;
import org.apache.cassandra.cql3.restrictions.Restrictions;
import org.apache.cassandra.cql3.restrictions.SingleRestriction;
import org.apache.cassandra.cql3.selection.SortedRowsBuilder;
import org.apache.cassandra.db.marshal.FloatType;
import org.apache.cassandra.guardrails.Guardrails;
import org.apache.cassandra.index.Index;
import org.apache.cassandra.schema.ColumnMetadata;
import org.apache.cassandra.schema.Schema;
import org.apache.cassandra.schema.TableMetadata;
Expand Down Expand Up @@ -994,7 +993,7 @@ private ResultSet process(PartitionIterator partitions,
{
ProtocolVersion protocolVersion = options.getProtocolVersion();
GroupMaker groupMaker = aggregationSpec == null ? null : aggregationSpec.newGroupMaker();
SortedRowsBuilder rows = sortedRowsBuilder(userLimit, userOffset == NO_OFFSET ? 0 : userOffset, options);
SortedRowsBuilder rows = sortedRowsBuilder(userLimit, userOffset == NO_OFFSET ? 0 : userOffset);
ResultSetBuilder result = new ResultSetBuilder(protocolVersion, getResultMetadata(), selectors, groupMaker, rows);

while (partitions.hasNext())
Expand Down Expand Up @@ -1094,7 +1093,8 @@ void processPartition(RowIterator partition, QueryOptions options, ResultSetBuil

private boolean needsToSkipUserLimit()
{
// if post query ordering is required, and it's not ordered by an index
// if we're querying by `pk IN (...)` and ordering by clustered columns, replicas don't sort
// before applying LIMIT
return needsPostQueryOrdering() && (orderingComparator != null && orderingComparator.isClustered());
}

Expand All @@ -1108,27 +1108,12 @@ private boolean needsPostQueryOrdering()
/**
* Orders results when multiple keys are selected (using IN)
*/
public SortedRowsBuilder sortedRowsBuilder(int limit, int offset, QueryOptions options)
public SortedRowsBuilder sortedRowsBuilder(int limit, int offset)
{
if (orderingComparator == null)
{
return SortedRowsBuilder.create(limit, offset);
}
else if (orderingComparator instanceof IndexColumnComparator)
{
SingleRestriction restriction = ((IndexColumnComparator) orderingComparator).restriction;
int columnIndex = ((IndexColumnComparator) orderingComparator).columnIndex;

Index index = restriction.findSupportingIndex(IndexRegistry.obtain(table));
assert index != null;

Index.Scorer scorer = index.postQueryScorer(restriction, columnIndex, options);
return SortedRowsBuilder.create(limit, offset, scorer);
}
else
{
return SortedRowsBuilder.create(limit, offset, orderingComparator);
}
return SortedRowsBuilder.create(limit, offset, orderingComparator);
}

public static class RawStatement extends QualifiedStatement<SelectStatement>
Expand Down Expand Up @@ -1172,13 +1157,21 @@ public SelectStatement prepare(boolean forView, UnaryOperator<String> keyspaceMa
List<Selectable> selectables = RawSelector.toSelectables(selectClause, table);
boolean containsOnlyStaticColumns = selectOnlyStaticColumns(table, selectables);

// Besides actual restrictions (where clauses), prepareRestrictions will include pseudo-restrictions
// on indexed columns to allow pushing ORDER BY into the index; see StatementRestrictions::addOrderingRestrictions.
// Therefore, we don't want to convert the Ordering column into a +score column until after that.
List<Ordering> orderings = getOrderings(table);
StatementRestrictions restrictions = prepareRestrictions(
table, bindVariables, orderings, containsOnlyStaticColumns, forView);

// If we order post-query, the sorted columns need to be in the ResultSet for sorting,
// even if we don't ultimately ship them to the client (CASSANDRA-4911).
Map<ColumnMetadata, Ordering> orderingColumns = getOrderingColumns(orderings);
// +score column for ANN/BM25
var scoreOrdering = getScoreOrdering(orderings);
assert scoreOrdering == null || orderingColumns.isEmpty() : "can't have both scored ordering and column ordering";
if (scoreOrdering != null)
orderingColumns = scoreOrdering;
Set<ColumnMetadata> resultSetOrderingColumns = getResultSetOrdering(restrictions, orderingColumns);

Selection selection = prepareSelection(table,
Expand Down Expand Up @@ -1209,7 +1202,7 @@ public SelectStatement prepare(boolean forView, UnaryOperator<String> keyspaceMa
assert !forView;
verifyOrderingIsAllowed(table, restrictions, orderingColumns);
orderingComparator = getOrderingComparator(selection, restrictions, orderingColumns);
isReversed = isReversed(table, orderingColumns, restrictions);
isReversed = isReversed(table, orderingColumns);
if (isReversed && orderingComparator != null)
orderingComparator = orderingComparator.reverse();
}
Expand All @@ -1232,6 +1225,27 @@ public SelectStatement prepare(boolean forView, UnaryOperator<String> keyspaceMa
prepareLimit(bindVariables, offset, ks, offsetReceiver()));
}

/**
 * Builds the ordering map used when the query is ordered by a score rather than by a column.
 * <p>
 * Only the first ordering is inspected. If its expression is scored, a synthetic float-typed
 * {@code +score} column is created and mapped to that ordering.
 *
 * @param orderings the orderings of the statement, possibly empty
 * @return a single-entry map from the synthetic score column to the first ordering, or
 *         {@code null} if there are no orderings or the first one is not scored
 */
private Map<ColumnMetadata, Ordering> getScoreOrdering(List<Ordering> orderings)
{
    if (orderings.isEmpty())
        return null;

    Ordering leading = orderings.get(0);
    if (!leading.expression.isScored())
        return null;

    // The synthetic column reuses the keyspace and table names of the ordered column,
    // but carries its own identifier and a float type.
    ColumnMetadata base = leading.expression.getColumn();
    ColumnIdentifier scoreName = new ColumnIdentifier("+score", true);
    ColumnMetadata scoreColumn = new ColumnMetadata(base.ksName,
                                                    base.cfName,
                                                    scoreName,
                                                    FloatType.instance,
                                                    ColumnMetadata.NO_POSITION,
                                                    ColumnMetadata.Kind.REGULAR);
    return Map.of(scoreColumn, leading);
}

private Set<ColumnMetadata> getResultSetOrdering(StatementRestrictions restrictions, Map<ColumnMetadata, Ordering> orderingColumns)
{
if (restrictions.keyIsInRelation() || orderingColumns.values().stream().anyMatch(o -> o.expression.hasNonClusteredOrdering()))
Expand Down Expand Up @@ -1292,13 +1306,14 @@ private Map<ColumnMetadata, Ordering> getOrderingColumns(List<Ordering> ordering
if (orderings.isEmpty())
return Collections.emptyMap();

Map<ColumnMetadata, Ordering> orderingColumns = new LinkedHashMap<>();
for (Ordering ordering : orderings)
{
ColumnMetadata column = ordering.expression.getColumn();
orderingColumns.put(column, ordering);
}
return orderingColumns;
return orderings.stream()
.filter(ordering -> !ordering.expression.isScored())
.collect(Collectors.toMap(ordering -> ordering.expression.getColumn(),
ordering -> ordering,
(a, b) -> {
throw new IllegalStateException("Duplicate keys");
},
LinkedHashMap::new));
}

private List<Ordering> getOrderings(TableMetadata table)
Expand Down Expand Up @@ -1461,13 +1476,9 @@ private ColumnComparator<List<ByteBuffer>> getOrderingComparator(Selection selec
assert orderingColumns.size() == 1 : orderingColumns.keySet();
var e = orderingColumns.entrySet().iterator().next();
var column = e.getKey();
var ordering = e.getValue();
if (ordering.expression instanceof Ordering.Ann)
return new IndexColumnComparator(ordering.expression.toRestriction(), selection.getOrderingIndex(column));
else
return new SingleColumnComparator(selection.getOrderingIndex(column), column.type, false);
return new SingleColumnComparator(selection.getOrderingIndex(column), column.type, false);
}

if (!restrictions.keyIsInRelation())
return null;

Expand All @@ -1484,18 +1495,21 @@ private ColumnComparator<List<ByteBuffer>> getOrderingComparator(Selection selec
: new CompositeComparator(sorters, idToSort);
}

private boolean isReversed(TableMetadata table, Map<ColumnMetadata, Ordering> orderingColumns, StatementRestrictions restrictions) throws InvalidRequestException
private boolean isReversed(TableMetadata table, Map<ColumnMetadata, Ordering> orderingColumns) throws InvalidRequestException
{
Boolean[] clusteredMap = new Boolean[table.clusteringColumns().size()];
for (var entry : orderingColumns.entrySet())
{
ColumnMetadata def = entry.getKey();
Ordering ordering = entry.getValue();
boolean reversed = ordering.direction == Ordering.Direction.DESC;
// We defined ANN OF to be ASC ordering, as in, "order by near-ness". But since score goes from
// 0 (worst) to 1 (closest), we need to reverse the ordering for the comparator when we're sorting
// by synthetic +score column.
boolean cqlReversed = ordering.direction == Ordering.Direction.DESC;
if (def.position() == ColumnMetadata.NO_POSITION)
return reversed;
return ordering.expression.isScored() || cqlReversed;
else
clusteredMap[def.position()] = (reversed != def.isReversedType());
clusteredMap[def.position()] = (cqlReversed != def.isReversedType());
}

// Check that all boolean in clusteredMap, if set, agrees
Expand Down Expand Up @@ -1682,31 +1696,6 @@ public boolean isClustered()
}
}

/**
 * Comparator placeholder for orderings that are resolved through an index.
 * <p>
 * It only carries the restriction and the position of the ordered column; it cannot compare
 * rows itself. {@code sortedRowsBuilder} reads both fields to look up the supporting index and
 * obtain a post-query scorer from it.
 */
private static class IndexColumnComparator extends ColumnComparator<List<ByteBuffer>>
{
// Restriction used to find the supporting index and to build its scorer.
private final SingleRestriction restriction;
// Index of the ordered column, as returned by Selection.getOrderingIndex(column).
private final int columnIndex;

// VSTODO maybe cache in prepared statement
public IndexColumnComparator(SingleRestriction restriction, int columnIndex)
{
this.restriction = restriction;
this.columnIndex = columnIndex;
}

/**
 * @return always {@code false}
 */
@Override
public boolean isClustered()
{
return false;
}

/**
 * Not supported: this class never compares rows directly.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public int compare(List<ByteBuffer> o1, List<ByteBuffer> o2)
{
throw new UnsupportedOperationException();
}
}

/**
* Used in orderResults(...) method when multiple 'ORDER BY' conditions were given
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ public Builder addAll(RegularAndStaticColumns columns)

public RegularAndStaticColumns build()
{
return new RegularAndStaticColumns(staticColumns == null ? Columns.NONE : Columns.from(staticColumns),
return new RegularAndStaticColumns(staticColumns == null ? Columns.NONE : Columns.from(staticColumns),
regularColumns == null ? Columns.NONE : Columns.from(regularColumns));
}
}
Expand Down
Loading

0 comments on commit 208f88f

Please sign in to comment.