From ddd9b691af164681f371f3e1f6a5e84c5a7bca2c Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Tue, 19 Nov 2024 10:56:32 -0600 Subject: [PATCH 01/29] remove unnecessary generification of IndexColumnComparator --- .../cql3/statements/SelectStatement.java | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java index 3f4a2e9e55e6..542116950a3a 100644 --- a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java +++ b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java @@ -25,7 +25,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; -import com.google.common.base.Preconditions; import com.google.common.math.IntMath; import org.apache.cassandra.cql3.Ordering; @@ -1102,7 +1101,8 @@ private boolean needsToSkipUserLimit() private boolean needsPostQueryOrdering() { // We need post-query ordering only for queries with IN on the partition key and an ORDER BY or index restriction reordering - return restrictions.keyIsInRelation() && !parameters.orderings.isEmpty() || needIndexOrdering(); + return (restrictions.keyIsInRelation() && !parameters.orderings.isEmpty()) + || needIndexOrdering(); } private boolean needIndexOrdering() @@ -1125,8 +1125,8 @@ public SortedRowsBuilder sortedRowsBuilder(int limit, int offset, QueryOptions o } else if (orderingComparator instanceof IndexColumnComparator) { - SingleRestriction restriction = ((IndexColumnComparator) orderingComparator).restriction; - int columnIndex = ((IndexColumnComparator) orderingComparator).columnIndex; + SingleRestriction restriction = ((IndexColumnComparator) orderingComparator).restriction; + int columnIndex = ((IndexColumnComparator) orderingComparator).columnIndex; Index index = restriction.findSupportingIndex(IndexRegistry.obtain(table)); assert index != null; @@ -1455,20 +1455,18 @@ private ColumnComparator> getOrderingComparator(Selection selec Map orderingColumns) throws InvalidRequestException { - for (Map.Entry e : orderingColumns.entrySet()) + if (orderingColumns.values().stream().anyMatch(o -> o.expression.hasNonClusteredOrdering())) { - if (e.getValue().expression.hasNonClusteredOrdering()) - { - Preconditions.checkState(orderingColumns.size() == 1); - return new IndexColumnComparator<>(e.getValue().expression.toRestriction(), selection.getOrderingIndex(e.getKey())); - } + assert orderingColumns.size() == 1 : orderingColumns.keySet(); + var e = orderingColumns.entrySet().iterator().next(); + return new IndexColumnComparator(e.getValue().expression.toRestriction(), selection.getOrderingIndex(e.getKey())); } - + if (!restrictions.keyIsInRelation()) return null; - List idToSort = new ArrayList(orderingColumns.size()); - List> sorters = new ArrayList>(orderingColumns.size()); + List idToSort = new ArrayList<>(orderingColumns.size()); + List> sorters = new ArrayList<>(orderingColumns.size()); for (ColumnMetadata orderingColumn : orderingColumns.keySet()) { @@ -1671,7 +1669,7 @@ public int compare(List a, List b) } } - private static class IndexColumnComparator extends ColumnComparator> + private static class IndexColumnComparator extends ColumnComparator> { private final SingleRestriction restriction; private final int columnIndex; From 9f1b794d031eea679012cebd2002e2c3cae15446 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Tue, 19 Nov 2024 11:42:06 -0600 Subject: [PATCH 02/29] Simplify the ordering logic by making IndexColumnComparator only responsible for ANN index queries. Other global orderings will be represented by a SingleColumnComparator with clustered=true instead. --- .../cql3/statements/SelectStatement.java | 103 ++++++++++-------- .../org/apache/cassandra/index/Index.java | 15 --- .../index/sai/StorageAttachedIndex.java | 12 -- .../operations/SelectOrderByTest.java | 2 +- 4 files changed, 59 insertions(+), 73 deletions(-) diff --git a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java index 542116950a3a..51ae29cb48cd 100644 --- a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java +++ b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java @@ -1095,19 +1095,14 @@ void processPartition(RowIterator partition, QueryOptions options, ResultSetBuil private boolean needsToSkipUserLimit() { // if post query ordering is required, and it's not ordered by an index - return needsPostQueryOrdering() && !needIndexOrdering(); + return needsPostQueryOrdering() && (orderingComparator != null && orderingComparator.isClustered()); } private boolean needsPostQueryOrdering() { // We need post-query ordering only for queries with IN on the partition key and an ORDER BY or index restriction reordering return (restrictions.keyIsInRelation() && !parameters.orderings.isEmpty()) - || needIndexOrdering(); - } - - private boolean needIndexOrdering() - { - return orderingComparator != null && orderingComparator.indexOrdering(); + || orderingComparator != null; } /** @@ -1115,10 +1110,6 @@ private boolean needIndexOrdering() */ public SortedRowsBuilder sortedRowsBuilder(int limit, int offset, QueryOptions options) { - assert (orderingComparator != null) == needsPostQueryOrdering() - : String.format("orderingComparator: %s, needsPostQueryOrdering: %s", - orderingComparator, needsPostQueryOrdering()); - if (orderingComparator == null) { return SortedRowsBuilder.create(limit, offset); @@ -1131,12 +1122,6 @@ else if (orderingComparator instanceof IndexColumnComparator) Index index = restriction.findSupportingIndex(IndexRegistry.obtain(table)); assert index != null; - if (restriction instanceof SingleColumnRestriction.OrderRestriction) - { - var comparator = index.postQueryComparator(restriction, columnIndex, options); - return SortedRowsBuilder.create(limit, offset, comparator); - } - Index.Scorer scorer = index.postQueryScorer(restriction, columnIndex, options); return SortedRowsBuilder.create(limit, offset, scorer); } @@ -1222,7 +1207,7 @@ public SelectStatement prepare(boolean forView, UnaryOperator keyspaceMa if (!orderingColumns.isEmpty()) { assert !forView; - verifyOrderingIsAllowed(restrictions, orderingColumns); + verifyOrderingIsAllowed(table, restrictions, orderingColumns); orderingComparator = getOrderingComparator(selection, restrictions, orderingColumns); isReversed = isReversed(table, orderingColumns, restrictions); if (isReversed && orderingComparator != null) @@ -1360,12 +1345,28 @@ private Term prepareLimit(VariableSpecifications boundNames, Term.Raw limit, return prepLimit; } - private static void verifyOrderingIsAllowed(StatementRestrictions restrictions, Map orderingColumns) throws InvalidRequestException + private static void verifyOrderingIsAllowed(TableMetadata table, StatementRestrictions restrictions, Map orderingColumns) throws InvalidRequestException { if (orderingColumns.values().stream().anyMatch(o -> o.expression.hasNonClusteredOrdering())) return; + checkFalse(restrictions.usesSecondaryIndexing(), "ORDER BY with 2ndary indexes is not supported."); checkFalse(restrictions.isKeyRange(), "ORDER BY is only supported when the partition key is restricted by an EQ or an IN."); + + // check that clustering columns are valid + int i = 0; + for (var entry : orderingColumns.entrySet()) + { + ColumnMetadata def = entry.getKey(); + checkTrue(def.isClusteringColumn(), + "Order by is currently only supported on indexed columns and the clustered columns of the PRIMARY KEY, got %s", def.name); + while (i != def.position()) + { + checkTrue(restrictions.isColumnRestrictedByEq(table.clusteringColumns().get(i++)), + "Ordering by clustered columns must follow the declared order in the PRIMARY KEY"); + } + i++; + } } private static void validateDistinctSelection(TableMetadata metadata, @@ -1459,7 +1460,12 @@ private ColumnComparator> getOrderingComparator(Selection selec { assert orderingColumns.size() == 1 : orderingColumns.keySet(); var e = orderingColumns.entrySet().iterator().next(); - return new IndexColumnComparator(e.getValue().expression.toRestriction(), selection.getOrderingIndex(e.getKey())); + var column = e.getKey(); + var ordering = e.getValue(); + if (ordering.expression instanceof Ordering.Ann) + return new IndexColumnComparator(ordering.expression.toRestriction(), selection.getOrderingIndex(column)); + else + return new SingleColumnComparator(selection.getOrderingIndex(column), column.type, false); } if (!restrictions.keyIsInRelation()) @@ -1480,33 +1486,21 @@ private ColumnComparator> getOrderingComparator(Selection selec private boolean isReversed(TableMetadata table, Map orderingColumns, StatementRestrictions restrictions) throws InvalidRequestException { - // Nonclustered ordering handles descending logic in a different way - if (orderingColumns.values().stream().anyMatch(o -> o.expression.hasNonClusteredOrdering())) - return false; - - Boolean[] reversedMap = new Boolean[table.clusteringColumns().size()]; - int i = 0; + Boolean[] clusteredMap = new Boolean[table.clusteringColumns().size()]; for (var entry : orderingColumns.entrySet()) { ColumnMetadata def = entry.getKey(); Ordering ordering = entry.getValue(); boolean reversed = ordering.direction == Ordering.Direction.DESC; - - // VSTODO move this to verifyOrderingIsAllowed? - checkTrue(def.isClusteringColumn(), - "Order by is currently only supported on the clustered columns of the PRIMARY KEY, got %s", def.name); - while (i != def.position()) - { - checkTrue(restrictions.isColumnRestrictedByEq(table.clusteringColumns().get(i++)), - "Order by currently only supports the ordering of columns following their declared order in the PRIMARY KEY"); - } - i++; - reversedMap[def.position()] = (reversed != def.isReversedType()); + if (def.position() == ColumnMetadata.NO_POSITION) + return reversed; + else + clusteredMap[def.position()] = (reversed != def.isReversedType()); } - // Check that all boolean in reversedMap, if set, agrees + // Check that all boolean in clusteredMap, if set, agrees Boolean isReversed = null; - for (Boolean b : reversedMap) + for (Boolean b : clusteredMap) { // Column on which order is specified can be in any order if (b == null) @@ -1626,11 +1620,11 @@ public ColumnComparator reverse() } /** - * @return true if ordering is performed by index + * @return true if ordering is performed by classic collation columns */ - public boolean indexOrdering() + public boolean isClustered() { - return false; + return true; } } @@ -1648,6 +1642,12 @@ public int compare(T o1, T o2) { return wrapped.compare(o2, o1); } + + @Override + public boolean isClustered() + { + return wrapped.isClustered(); + } } /** * Used in orderResults(...) method when single 'ORDER BY' condition where given @@ -1656,17 +1656,30 @@ private static class SingleColumnComparator extends ColumnComparator comparator; + private final boolean clustered; - public SingleColumnComparator(int columnIndex, Comparator orderer) + public SingleColumnComparator(int columnIndex, Comparator orderer, boolean clustered) { index = columnIndex; comparator = orderer; + this.clustered = clustered; + } + + public SingleColumnComparator(int columnIndex, Comparator orderer) + { + this(columnIndex, orderer, true); } public int compare(List a, List b) { return compare(comparator, a.get(index), b.get(index)); } + + @Override + public boolean isClustered() + { + return clustered; + } } private static class IndexColumnComparator extends ColumnComparator> @@ -1682,9 +1695,9 @@ public IndexColumnComparator(SingleRestriction restriction, int columnIndex) } @Override - public boolean indexOrdering() + public boolean isClustered() { - return true; + return false; } @Override diff --git a/src/java/org/apache/cassandra/index/Index.java b/src/java/org/apache/cassandra/index/Index.java index 7839df669495..46024d2e6adf 100644 --- a/src/java/org/apache/cassandra/index/Index.java +++ b/src/java/org/apache/cassandra/index/Index.java @@ -23,7 +23,6 @@ import java.nio.ByteBuffer; import java.util.Collection; import java.util.Collections; -import java.util.Comparator; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -471,20 +470,6 @@ interface Analyzer */ public RowFilter getPostIndexQueryFilter(RowFilter filter); - /** - * Returns a {@link Comparator} of CQL result rows, so they can be ordered by the - * coordinator before sending them to client. - * - * @param restriction restriction that requires current index - * @param columnIndex idx of the indexed column in returned row - * @param options query options - * @return a comparator of rows - */ - default Comparator> postQueryComparator(Restriction restriction, int columnIndex, QueryOptions options) - { - throw new NotImplementedException(); - } - /** * Returns a {@link Scorer} to give a similarity/proximity score to CQL result rows, so they can be ordered by the * coordinator before sending them to client. diff --git a/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java b/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java index d974da34ccde..6b9e2582d3ef 100644 --- a/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java +++ b/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java @@ -680,18 +680,6 @@ public RowFilter getPostIndexQueryFilter(RowFilter filter) throw new UnsupportedOperationException(); } - @Override - public Comparator> postQueryComparator(Restriction restriction, int columnIndex, QueryOptions options) - { - assert restriction instanceof SingleColumnRestriction.OrderRestriction; - - SingleColumnRestriction.OrderRestriction orderRestriction = (SingleColumnRestriction.OrderRestriction) restriction; - var typeComparator = orderRestriction.getDirection() == Operator.ORDER_BY_DESC - ? indexContext.getValidator().reversed() - : indexContext.getValidator(); - return (a, b) -> typeComparator.compare(a.get(columnIndex), b.get(columnIndex)); - } - @Override public Scorer postQueryScorer(Restriction restriction, int columnIndex, QueryOptions options) { diff --git a/test/unit/org/apache/cassandra/cql3/validation/operations/SelectOrderByTest.java b/test/unit/org/apache/cassandra/cql3/validation/operations/SelectOrderByTest.java index dac86c76e5d8..34f7d606ce55 100644 --- a/test/unit/org/apache/cassandra/cql3/validation/operations/SelectOrderByTest.java +++ b/test/unit/org/apache/cassandra/cql3/validation/operations/SelectOrderByTest.java @@ -659,7 +659,7 @@ public void testAllowSkippingEqualityAndSingleValueInRestrictedClusteringColumns assertInvalidMessage("Cannot combine clustering column ordering with non-clustering column ordering", "SELECT * FROM %s WHERE a=? ORDER BY b ASC, c ASC, d ASC", 0); - String errorMsg = "Order by currently only supports the ordering of columns following their declared order in the PRIMARY KEY"; + String errorMsg = "Ordering by clustered columns must follow the declared order in the PRIMARY KEY"; assertRows(execute("SELECT * FROM %s WHERE a=? AND b=? ORDER BY c", 0, 0), row(0, 0, 0, 0), From 99861083af9975e2a431e98a61def7ddc1f050e1 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Tue, 19 Nov 2024 12:48:05 -0600 Subject: [PATCH 03/29] CNDB-11725 use +score pseudo-column to order ANN results with instead of recomputing scores on the coordinator --- .../org/apache/cassandra/cql3/Ordering.java | 11 ++ .../restrictions/StatementRestrictions.java | 2 +- .../cql3/statements/SelectStatement.java | 117 ++++++++---------- .../cassandra/db/RegularAndStaticColumns.java | 2 +- .../plan/StorageAttachedIndexSearcher.java | 80 ++++++++---- .../index/sai/utils/PrimaryKeyWithScore.java | 2 +- .../cassandra/schema/TableMetadata.java | 5 +- .../index/sai/StorageAttachedIndexTest.java | 2 +- 8 files changed, 129 insertions(+), 92 deletions(-) diff --git a/src/java/org/apache/cassandra/cql3/Ordering.java b/src/java/org/apache/cassandra/cql3/Ordering.java index 81aa94a076cd..f5760f38510e 100644 --- a/src/java/org/apache/cassandra/cql3/Ordering.java +++ b/src/java/org/apache/cassandra/cql3/Ordering.java @@ -48,6 +48,11 @@ public interface Expression SingleRestriction toRestriction(); ColumnMetadata getColumn(); + + default boolean isScored() + { + return false; + } } /** @@ -118,6 +123,12 @@ public ColumnMetadata getColumn() { return column; } + + @Override + public boolean isScored() + { + return true; + } } public enum Direction diff --git a/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java b/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java index a9dbf41e8559..d28e16fa399f 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java @@ -682,7 +682,7 @@ else if (indexOrderings.size() == 1) if (orderings.size() > 1) throw new InvalidRequestException("Cannot combine clustering column ordering with non-clustering column ordering"); Ordering ordering = indexOrderings.get(0); - if (ordering.direction != Ordering.Direction.ASC && ordering.expression instanceof Ordering.Ann) + if (ordering.direction != Ordering.Direction.ASC && ordering.expression.isScored()) throw new InvalidRequestException("Descending ANN ordering is not supported"); if (!ENABLE_SAI_GENERAL_ORDER_BY && ordering.expression instanceof Ordering.SingleColumn) throw new InvalidRequestException("SAI based ORDER BY on non-vector column is not supported"); diff --git a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java index 51ae29cb48cd..3ab85e5adf6d 100644 --- a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java +++ b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java @@ -38,10 +38,9 @@ import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.cql3.restrictions.ExternalRestriction; import org.apache.cassandra.cql3.restrictions.Restrictions; -import org.apache.cassandra.cql3.restrictions.SingleRestriction; import org.apache.cassandra.cql3.selection.SortedRowsBuilder; +import org.apache.cassandra.db.marshal.FloatType; import org.apache.cassandra.guardrails.Guardrails; -import org.apache.cassandra.index.Index; import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.Schema; import org.apache.cassandra.schema.TableMetadata; @@ -994,7 +993,7 @@ private ResultSet process(PartitionIterator partitions, { ProtocolVersion protocolVersion = options.getProtocolVersion(); GroupMaker groupMaker = aggregationSpec == null ? null : aggregationSpec.newGroupMaker(); - SortedRowsBuilder rows = sortedRowsBuilder(userLimit, userOffset == NO_OFFSET ? 0 : userOffset, options); + SortedRowsBuilder rows = sortedRowsBuilder(userLimit, userOffset == NO_OFFSET ? 0 : userOffset); ResultSetBuilder result = new ResultSetBuilder(protocolVersion, getResultMetadata(), selectors, groupMaker, rows); while (partitions.hasNext()) @@ -1094,7 +1093,8 @@ void processPartition(RowIterator partition, QueryOptions options, ResultSetBuil private boolean needsToSkipUserLimit() { - // if post query ordering is required, and it's not ordered by an index + // if we're querying by `pk IN (...)` and ordering by clustered columns, replicas don't sort + // before applying LIMIT return needsPostQueryOrdering() && (orderingComparator != null && orderingComparator.isClustered()); } @@ -1108,27 +1108,12 @@ private boolean needsPostQueryOrdering() /** * Orders results when multiple keys are selected (using IN) */ - public SortedRowsBuilder sortedRowsBuilder(int limit, int offset, QueryOptions options) + public SortedRowsBuilder sortedRowsBuilder(int limit, int offset) { if (orderingComparator == null) - { return SortedRowsBuilder.create(limit, offset); - } - else if (orderingComparator instanceof IndexColumnComparator) - { - SingleRestriction restriction = ((IndexColumnComparator) orderingComparator).restriction; - int columnIndex = ((IndexColumnComparator) orderingComparator).columnIndex; - Index index = restriction.findSupportingIndex(IndexRegistry.obtain(table)); - assert index != null; - - Index.Scorer scorer = index.postQueryScorer(restriction, columnIndex, options); - return SortedRowsBuilder.create(limit, offset, scorer); - } - else - { - return SortedRowsBuilder.create(limit, offset, orderingComparator); - } + return SortedRowsBuilder.create(limit, offset, orderingComparator); } public static class RawStatement extends QualifiedStatement @@ -1172,6 +1157,9 @@ public SelectStatement prepare(boolean forView, UnaryOperator keyspaceMa List selectables = RawSelector.toSelectables(selectClause, table); boolean containsOnlyStaticColumns = selectOnlyStaticColumns(table, selectables); + // Besides actual restrictions (where clauses), prepareRestrictions will include pseudo-restrictions + // on indexed columns to allow pushing ORDER BY into the index; see StatementRestrictions::addOrderingRestrictions. + // Therefore, we don't want to convert the Ordering column into a +score column until after that. List orderings = getOrderings(table); StatementRestrictions restrictions = prepareRestrictions( table, bindVariables, orderings, containsOnlyStaticColumns, forView); @@ -1179,6 +1167,11 @@ public SelectStatement prepare(boolean forView, UnaryOperator keyspaceMa // If we order post-query, the sorted column needs to be in the ResultSet for sorting, // even if we don't ultimately ship them to the client (CASSANDRA-4911). Map orderingColumns = getOrderingColumns(orderings); + // +score column for ANN/BM25 + var scoreOrdering = getScoreOrdering(orderings); + assert scoreOrdering == null || orderingColumns.isEmpty() : "can't have both scored ordering and column ordering"; + if (scoreOrdering != null) + orderingColumns = scoreOrdering; Set resultSetOrderingColumns = getResultSetOrdering(restrictions, orderingColumns); Selection selection = prepareSelection(table, @@ -1209,7 +1202,7 @@ public SelectStatement prepare(boolean forView, UnaryOperator keyspaceMa assert !forView; verifyOrderingIsAllowed(table, restrictions, orderingColumns); orderingComparator = getOrderingComparator(selection, restrictions, orderingColumns); - isReversed = isReversed(table, orderingColumns, restrictions); + isReversed = isReversed(table, orderingColumns); if (isReversed && orderingComparator != null) orderingComparator = orderingComparator.reverse(); } @@ -1232,6 +1225,27 @@ public SelectStatement prepare(boolean forView, UnaryOperator keyspaceMa prepareLimit(bindVariables, offset, ks, offsetReceiver())); } + private Map getScoreOrdering(List orderings) + { + if (orderings.isEmpty()) + return null; + + var expr = orderings.get(0).expression; + if (!expr.isScored()) + return null; + + // Create synthetic score column + // Use the original column's table metadata but create new identifier and type + ColumnMetadata sourceColumn = expr.getColumn(); + var cm = new ColumnMetadata(sourceColumn.ksName, + sourceColumn.cfName, + new ColumnIdentifier("+score", true), + FloatType.instance, + ColumnMetadata.NO_POSITION, + ColumnMetadata.Kind.REGULAR); + return Map.of(cm, orderings.get(0)); + } + private Set getResultSetOrdering(StatementRestrictions restrictions, Map orderingColumns) { if (restrictions.keyIsInRelation() || orderingColumns.values().stream().anyMatch(o -> o.expression.hasNonClusteredOrdering())) @@ -1292,13 +1306,14 @@ private Map getOrderingColumns(List ordering if (orderings.isEmpty()) return Collections.emptyMap(); - Map orderingColumns = new LinkedHashMap<>(); - for (Ordering ordering : orderings) - { - ColumnMetadata column = ordering.expression.getColumn(); - orderingColumns.put(column, ordering); - } - return orderingColumns; + return orderings.stream() + .filter(ordering -> !ordering.expression.isScored()) + .collect(Collectors.toMap(ordering -> ordering.expression.getColumn(), + ordering -> ordering, + (a, b) -> { + throw new IllegalStateException("Duplicate keys"); + }, + LinkedHashMap::new)); } private List getOrderings(TableMetadata table) @@ -1461,13 +1476,9 @@ private ColumnComparator> getOrderingComparator(Selection selec assert orderingColumns.size() == 1 : orderingColumns.keySet(); var e = orderingColumns.entrySet().iterator().next(); var column = e.getKey(); - var ordering = e.getValue(); - if (ordering.expression instanceof Ordering.Ann) - return new IndexColumnComparator(ordering.expression.toRestriction(), selection.getOrderingIndex(column)); - else - return new SingleColumnComparator(selection.getOrderingIndex(column), column.type, false); + return new SingleColumnComparator(selection.getOrderingIndex(column), column.type, false); } - + if (!restrictions.keyIsInRelation()) return null; @@ -1484,18 +1495,21 @@ private ColumnComparator> getOrderingComparator(Selection selec : new CompositeComparator(sorters, idToSort); } - private boolean isReversed(TableMetadata table, Map orderingColumns, StatementRestrictions restrictions) throws InvalidRequestException + private boolean isReversed(TableMetadata table, Map orderingColumns) throws InvalidRequestException { Boolean[] clusteredMap = new Boolean[table.clusteringColumns().size()]; for (var entry : orderingColumns.entrySet()) { ColumnMetadata def = entry.getKey(); Ordering ordering = entry.getValue(); - boolean reversed = ordering.direction == Ordering.Direction.DESC; + // We defined ANN OF to be ASC ordering, as in, "order by near-ness". But since score goves from + // 0 (worst) to 1 (closest), we need to reverse the ordering for the comparator when we're sorting + // by synthetic +score column. + boolean cqlReversed = ordering.direction == Ordering.Direction.DESC; if (def.position() == ColumnMetadata.NO_POSITION) - return reversed; + return ordering.expression.isScored() || cqlReversed; else - clusteredMap[def.position()] = (reversed != def.isReversedType()); + clusteredMap[def.position()] = (cqlReversed != def.isReversedType()); } // Check that all boolean in clusteredMap, if set, agrees @@ -1682,31 +1696,6 @@ public boolean isClustered() } } - private static class IndexColumnComparator extends ColumnComparator> - { - private final SingleRestriction restriction; - private final int columnIndex; - - // VSTODO maybe cache in prepared statement - public IndexColumnComparator(SingleRestriction restriction, int columnIndex) - { - this.restriction = restriction; - this.columnIndex = columnIndex; - } - - @Override - public boolean isClustered() - { - return false; - } - - @Override - public int compare(List o1, List o2) - { - throw new UnsupportedOperationException(); - } - } - /** * Used in orderResults(...) method when multiple 'ORDER BY' conditions where given */ diff --git a/src/java/org/apache/cassandra/db/RegularAndStaticColumns.java b/src/java/org/apache/cassandra/db/RegularAndStaticColumns.java index b6da183d013f..05c5251c34e8 100644 --- a/src/java/org/apache/cassandra/db/RegularAndStaticColumns.java +++ b/src/java/org/apache/cassandra/db/RegularAndStaticColumns.java @@ -197,7 +197,7 @@ public Builder addAll(RegularAndStaticColumns columns) public RegularAndStaticColumns build() { - return new RegularAndStaticColumns(staticColumns == null ? Columns.NONE : Columns.from(staticColumns), + return new RegularAndStaticColumns(staticColumns == null ? Columns.NONE : Columns.from(staticColumns), regularColumns == null ? Columns.NONE : Columns.from(regularColumns)); } } diff --git a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java index ad0ab70f2e9f..774f06745110 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java @@ -18,17 +18,16 @@ package org.apache.cassandra.index.sai.plan; +import java.util.ArrayList; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.function.Supplier; import java.util.stream.Collectors; - import javax.annotation.Nonnull; import javax.annotation.Nullable; import com.google.common.base.Preconditions; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -38,8 +37,12 @@ import org.apache.cassandra.db.PartitionPosition; import org.apache.cassandra.db.ReadCommand; import org.apache.cassandra.db.ReadExecutionController; +import org.apache.cassandra.db.marshal.FloatType; import org.apache.cassandra.db.partitions.UnfilteredPartitionIterator; import org.apache.cassandra.db.rows.AbstractUnfilteredRowIterator; +import org.apache.cassandra.db.rows.BTreeRow; +import org.apache.cassandra.db.rows.BufferCell; +import org.apache.cassandra.db.rows.ColumnData; import org.apache.cassandra.db.rows.Row; import org.apache.cassandra.db.rows.Unfiltered; import org.apache.cassandra.db.rows.UnfilteredRowIterator; @@ -53,13 +56,16 @@ import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; import org.apache.cassandra.index.sai.metrics.TableQueryMetrics; import org.apache.cassandra.index.sai.utils.PrimaryKey; -import org.apache.cassandra.index.sai.utils.RangeUtil; +import org.apache.cassandra.index.sai.utils.PrimaryKeyWithScore; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey; +import org.apache.cassandra.index.sai.utils.RangeUtil; import org.apache.cassandra.io.util.FileUtils; +import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.TableMetadata; import org.apache.cassandra.utils.AbstractIterator; import org.apache.cassandra.utils.CloseableIterator; import org.apache.cassandra.utils.FBUtilities; +import org.apache.cassandra.utils.btree.BTree; public class StorageAttachedIndexSearcher implements Index.Searcher { @@ -103,8 +109,11 @@ public UnfilteredPartitionIterator search(ReadExecutionController executionContr { assert !(keysIterator instanceof KeyRangeIterator); var scoredKeysIterator = (CloseableIterator) keysIterator; - var result = new ScoreOrderedResultRetriever(scoredKeysIterator, filterTree, controller, - executionController, queryContext); + var result = new ScoreOrderedResultRetriever(scoredKeysIterator, + filterTree, + controller, + executionController, + queryContext); return (UnfilteredPartitionIterator) new TopKProcessor(command).filter(result); } else @@ -558,14 +567,25 @@ public boolean shouldInclude(PrimaryKeyWithSortKey key, Row row) return true; } + @Override + public TableMetadata metadata() + { + return controller.metadata(); + } + + public void close() + { + FileUtils.closeQuietly(scoredPrimaryKeyIterator); + controller.finish(); + } + public static class PrimaryKeyIterator extends AbstractUnfilteredRowIterator { private boolean consumed = false; private final Unfiltered row; public final PrimaryKeyWithSortKey primaryKeyWithSortKey; - public PrimaryKeyIterator(PrimaryKeyWithSortKey key, UnfilteredRowIterator partition, Row staticRow, Unfiltered content) - { + public PrimaryKeyIterator(PrimaryKeyWithSortKey key, UnfilteredRowIterator partition, Row staticRow, Unfiltered content) { super(partition.metadata(), partition.partitionKey(), partition.partitionLevelDeletion(), @@ -574,31 +594,45 @@ public PrimaryKeyIterator(PrimaryKeyWithSortKey key, UnfilteredRowIterator parti partition.isReverseOrder(), partition.stats()); - row = content; - primaryKeyWithSortKey = key; + this.primaryKeyWithSortKey = key; + + if (!content.isRow() || !(key instanceof PrimaryKeyWithScore)) + { + this.row = content; + return; + } + + var tm = metadata(); + var scoreColumn = ColumnMetadata.regularColumn(tm.keyspace, tm.name, "+score", FloatType.instance); + + // clone the original Row + Row originalRow = (Row) content; + ArrayList columnData = new ArrayList<>(originalRow.columnCount() + 1); + columnData.addAll(originalRow.columnData()); + + // inject +score as a new column + var pkWithScore = (PrimaryKeyWithScore) key; + columnData.add(BufferCell.live(scoreColumn, + FBUtilities.nowInSeconds(), + FloatType.instance.decompose(pkWithScore.indexScore))); + + this.row = BTreeRow.create(originalRow.clustering(), + originalRow.primaryKeyLivenessInfo(), + originalRow.deletion(), + BTree.builder(ColumnData.comparator) + .auto(true) + .addAll(columnData) + .build()); } @Override - protected Unfiltered computeNext() - { + protected Unfiltered computeNext() { if (consumed) return endOfData(); consumed = true; return row; } } - - @Override - public TableMetadata metadata() - { - return controller.metadata(); - } - - public void close() - { - FileUtils.closeQuietly(scoredPrimaryKeyIterator); - controller.finish(); - } } private static UnfilteredRowIterator applyIndexFilter(UnfilteredRowIterator partition, FilterTree tree, QueryContext queryContext) diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java index 8b7b4acba9c6..61e63d43c7fa 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java +++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java @@ -28,7 +28,7 @@ */ public class PrimaryKeyWithScore extends PrimaryKeyWithSortKey { - private final float indexScore; + public final float indexScore; public PrimaryKeyWithScore(IndexContext context, Object source, PrimaryKey primaryKey, float indexScore) { diff --git a/src/java/org/apache/cassandra/schema/TableMetadata.java b/src/java/org/apache/cassandra/schema/TableMetadata.java index ba5e1db8d84d..6442326c9a2d 100644 --- a/src/java/org/apache/cassandra/schema/TableMetadata.java +++ b/src/java/org/apache/cassandra/schema/TableMetadata.java @@ -167,6 +167,9 @@ protected TableMetadata(Builder builder) name = builder.name; id = builder.id; + // FIXME + builder.addColumn(new ColumnMetadata(keyspace, name, new ColumnIdentifier("+score", true), FloatType.instance, ColumnMetadata.NO_POSITION, ColumnMetadata.Kind.REGULAR)); + partitioner = builder.partitioner; kind = builder.kind; params = builder.params.build(); @@ -1123,7 +1126,7 @@ public Builder addStaticColumn(ColumnIdentifier name, AbstractType type) public Builder addColumn(ColumnMetadata column) { if (columns.containsKey(column.name.bytes)) - throw new IllegalArgumentException(); + return this; // FIXME switch (column.kind) { diff --git a/test/unit/org/apache/cassandra/index/sai/StorageAttachedIndexTest.java b/test/unit/org/apache/cassandra/index/sai/StorageAttachedIndexTest.java index a32b6347f6b2..f475e466e28c 100644 --- a/test/unit/org/apache/cassandra/index/sai/StorageAttachedIndexTest.java +++ b/test/unit/org/apache/cassandra/index/sai/StorageAttachedIndexTest.java @@ -148,7 +148,7 @@ public void testOrderResults() { SelectStatement selectStatementInstance = (SelectStatement) QueryProcessor.prepareInternal("SELECT key, value FROM " + KEYSPACE + '.' + TABLE).statement; - SortedRowsBuilder builder = selectStatementInstance.sortedRowsBuilder(Integer.MAX_VALUE, 0, queryOptions); + SortedRowsBuilder builder = selectStatementInstance.sortedRowsBuilder(Integer.MAX_VALUE, 0); rows.forEach(builder::add); List> sortedRows = builder.build(); From e0ea87253a6911542295569741fd9f36bd3c67de Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Tue, 19 Nov 2024 15:39:18 -0600 Subject: [PATCH 04/29] CNDB-11725 add SYNTHETIC ColumnMetadata.Kind to represent the score column --- .../cql3/selection/ColumnFilterFactory.java | 16 ++- .../cassandra/cql3/selection/Selection.java | 23 ++++- .../cql3/statements/SelectStatement.java | 18 ++-- src/java/org/apache/cassandra/db/Columns.java | 99 ++++++++++++++++--- .../org/apache/cassandra/db/ReadCommand.java | 20 +++- .../cassandra/db/RegularAndStaticColumns.java | 2 +- .../cassandra/db/filter/ColumnFilter.java | 25 +++-- .../plan/StorageAttachedIndexSearcher.java | 11 ++- .../cassandra/schema/ColumnMetadata.java | 35 +++++-- .../cassandra/schema/TableMetadata.java | 6 +- .../index/sai/StorageAttachedIndexTest.java | 8 -- 11 files changed, 196 insertions(+), 67 deletions(-) diff --git a/src/java/org/apache/cassandra/cql3/selection/ColumnFilterFactory.java b/src/java/org/apache/cassandra/cql3/selection/ColumnFilterFactory.java index 63fa0520101e..00225cca4108 100644 --- a/src/java/org/apache/cassandra/cql3/selection/ColumnFilterFactory.java +++ b/src/java/org/apache/cassandra/cql3/selection/ColumnFilterFactory.java @@ -38,9 +38,21 @@ abstract class ColumnFilterFactory */ abstract ColumnFilter newInstance(List selectors); - public static ColumnFilterFactory wildcard(TableMetadata table) + public static ColumnFilterFactory wildcard(TableMetadata table, Set orderingColumns) { - return new PrecomputedColumnFilter(ColumnFilter.all(table)); + ColumnFilter cf; + if (orderingColumns.isEmpty()) + { + cf = ColumnFilter.all(table); + } + else + { + ColumnFilter.Builder builder = ColumnFilter.selectionBuilder(); + builder.addAll(table.regularAndStaticColumns()); + builder.addAll(orderingColumns); + cf = builder.build(); + } + return new PrecomputedColumnFilter(cf); } public static ColumnFilterFactory fromColumns(TableMetadata table, diff --git a/src/java/org/apache/cassandra/cql3/selection/Selection.java b/src/java/org/apache/cassandra/cql3/selection/Selection.java index 02aae61dd5ff..12d8aa014e19 100644 --- a/src/java/org/apache/cassandra/cql3/selection/Selection.java +++ b/src/java/org/apache/cassandra/cql3/selection/Selection.java @@ -43,10 +43,24 @@ public abstract class Selection private static final Predicate STATIC_COLUMN_FILTER = (column) -> column.isStatic(); private final TableMetadata table; + + // Full list of columns needed for processing the query, including selected columns, ordering columns, + // and columns needed for restrictions. Wildcard columns are fully materialized here. + // + // This also includes synthetic columns, because unlike all the other not-physical-columns selectables, they are + // computed on the replica instead of the coordinator and so, like physical columns, they need to be sent back + // as part of the result. private final List columns; + + // maps ColumnSpecifications (columns, function calls, aliases) to the columns backing them private final SelectionColumnMapping columnMapping; + + // metadata matching the ColumnSpcifications protected final ResultSet.ResultMetadata metadata; + + // creates a ColumnFilter that breaks columns into `queried` and `fetched` protected final ColumnFilterFactory columnFilterFactory; + protected final boolean isJson; // Columns used to order the result set for JSON queries with post ordering. @@ -126,10 +140,15 @@ public ResultSet.ResultMetadata getResultMetadata() } public static Selection wildcard(TableMetadata table, boolean isJson, boolean returnStaticContentOnPartitionWithNoRows) + { + return wildcard(table, Collections.emptySet(), isJson, returnStaticContentOnPartitionWithNoRows); + } + + public static Selection wildcard(TableMetadata table, Set orderingColumns, boolean isJson, boolean returnStaticContentOnPartitionWithNoRows) { List all = new ArrayList<>(table.columns().size()); Iterators.addAll(all, table.allColumnsInSelectOrder()); - return new SimpleSelection(table, all, Collections.emptySet(), true, isJson, returnStaticContentOnPartitionWithNoRows); + return new SimpleSelection(table, all, orderingColumns, true, isJson, returnStaticContentOnPartitionWithNoRows); } public static Selection wildcardWithGroupBy(TableMetadata table, @@ -400,7 +419,7 @@ public SimpleSelection(TableMetadata table, selectedColumns, orderingColumns, SelectionColumnMapping.simpleMapping(selectedColumns), - isWildcard ? ColumnFilterFactory.wildcard(table) + isWildcard ? ColumnFilterFactory.wildcard(table, orderingColumns) : ColumnFilterFactory.fromColumns(table, selectedColumns, orderingColumns, Collections.emptySet(), returnStaticContentOnPartitionWithNoRows), isWildcard, isJson); diff --git a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java index 3ab85e5adf6d..eb1b5f9fa0ba 100644 --- a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java +++ b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java @@ -1080,12 +1080,16 @@ void processPartition(RowIterator partition, QueryOptions options, ResultSetBuil case CLUSTERING: result.add(row.clustering().bufferAt(def.position())); break; + case SYNTHETIC: + // treat as REGULAR case REGULAR: result.add(row.getColumnData(def), nowInSec); break; case STATIC: result.add(staticRow.getColumnData(def), nowInSec); break; + default: + throw new AssertionError(); } } } @@ -1159,7 +1163,7 @@ public SelectStatement prepare(boolean forView, UnaryOperator keyspaceMa // Besides actual restrictions (where clauses), prepareRestrictions will include pseudo-restrictions // on indexed columns to allow pushing ORDER BY into the index; see StatementRestrictions::addOrderingRestrictions. - // Therefore, we don't want to convert the Ordering column into a +score column until after that. + // Therefore, we don't want to convert an ANN Ordering column into a +score column until after that. List orderings = getOrderings(table); StatementRestrictions restrictions = prepareRestrictions( table, bindVariables, orderings, containsOnlyStaticColumns, forView); @@ -1235,14 +1239,8 @@ private Map getScoreOrdering(List orderings) return null; // Create synthetic score column - // Use the original column's table metadata but create new identifier and type ColumnMetadata sourceColumn = expr.getColumn(); - var cm = new ColumnMetadata(sourceColumn.ksName, - sourceColumn.cfName, - new ColumnIdentifier("+score", true), - FloatType.instance, - ColumnMetadata.NO_POSITION, - ColumnMetadata.Kind.REGULAR); + var cm = ColumnMetadata.syntheticColumn(sourceColumn.ksName, sourceColumn.cfName, ColumnMetadata.SYNTHETIC_SCORE_ID, FloatType.instance); return Map.of(cm, orderings.get(0)); } @@ -1264,7 +1262,7 @@ private Selection prepareSelection(TableMetadata table, if (selectables.isEmpty()) // wildcard query { return hasGroupBy ? Selection.wildcardWithGroupBy(table, boundNames, parameters.isJson, restrictions.returnStaticContentOnPartitionWithNoRows()) - : Selection.wildcard(table, parameters.isJson, restrictions.returnStaticContentOnPartitionWithNoRows()); + : Selection.wildcard(table, resultSetOrderingColumns, parameters.isJson, restrictions.returnStaticContentOnPartitionWithNoRows()); } return Selection.fromSelectors(table, @@ -1502,7 +1500,7 @@ private boolean isReversed(TableMetadata table, Map or { ColumnMetadata def = entry.getKey(); Ordering ordering = entry.getValue(); - // We defined ANN OF to be ASC ordering, as in, "order by near-ness". But since score goves from + // We defined ANN OF to be ASC ordering, as in, "order by near-ness". But since score goes from // 0 (worst) to 1 (closest), we need to reverse the ordering for the comparator when we're sorting // by synthetic +score column. boolean cqlReversed = ordering.direction == Ordering.Direction.DESC; diff --git a/src/java/org/apache/cassandra/db/Columns.java b/src/java/org/apache/cassandra/db/Columns.java index 7ce9bd68cf15..cef03393a1db 100644 --- a/src/java/org/apache/cassandra/db/Columns.java +++ b/src/java/org/apache/cassandra/db/Columns.java @@ -28,6 +28,7 @@ import net.nicoulaj.compilecommand.annotations.DontInline; import org.apache.cassandra.cql3.ColumnIdentifier; +import org.apache.cassandra.db.marshal.AbstractType; import org.apache.cassandra.db.marshal.SetType; import org.apache.cassandra.db.marshal.UTF8Type; import org.apache.cassandra.db.rows.ColumnData; @@ -36,6 +37,7 @@ import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.TableMetadata; +import org.apache.cassandra.serializers.AbstractTypeSerializer; import org.apache.cassandra.utils.ByteBufferUtil; import org.apache.cassandra.utils.ObjectSizes; import org.apache.cassandra.utils.SearchIterator; @@ -334,6 +336,11 @@ public Iterator simpleColumns() return BTree.iterator(columns, 0, complexIdx - 1, BTree.Dir.ASC); } + public Iterator simpleColumnsDesc() + { + return BTree.iterator(columns, 0, complexIdx - 1, BTree.Dir.DESC); + } + /** * Iterator over the complex columns of this object. * @@ -459,42 +466,112 @@ public String toString() public static class Serializer { + AbstractTypeSerializer typeSerializer = new AbstractTypeSerializer(); + public void serialize(Columns columns, DataOutputPlus out) throws IOException { - out.writeUnsignedVInt(columns.size()); + int regularCount = 0; + int syntheticCount = 0; + + // Count regular and synthetic columns + for (ColumnMetadata column : columns) + { + if (column.isSynthetic()) + syntheticCount++; + else + regularCount++; + } + + // Jam the two counts into a single value to avoid massive backwards compatibility issues + long packedCount = getPackedCount(syntheticCount, regularCount); + out.writeUnsignedVInt(packedCount); + + // First pass - write regular columns for (ColumnMetadata column : columns) - ByteBufferUtil.writeWithVIntLength(column.name.bytes, out); + { + if (!column.isSynthetic()) + ByteBufferUtil.writeWithVIntLength(column.name.bytes, out); + } + + // Second pass - write synthetic columns with their full metadata + for (ColumnMetadata column : columns) + { + if (column.isSynthetic()) + { + ByteBufferUtil.writeWithVIntLength(column.name.bytes, out); + typeSerializer.serialize(column.type, out); + } + } + } + + private static long getPackedCount(int syntheticCount, int regularCount) + { + // Left shift of 20 gives us over 1M regular columns, and up to 4 synthetic columns + // before overflowing to a 4th byte. + return ((long) syntheticCount << 20) | regularCount; } public long serializedSize(Columns columns) { - long size = TypeSizes.sizeofUnsignedVInt(columns.size()); + int regularCount = 0; + int syntheticCount = 0; + long size = 0; + + // Count and calculate sizes for (ColumnMetadata column : columns) - size += ByteBufferUtil.serializedSizeWithVIntLength(column.name.bytes); - return size; + { + if (column.isSynthetic()) + { + syntheticCount++; + size += ByteBufferUtil.serializedSizeWithVIntLength(column.name.bytes); + size += typeSerializer.serializedSize(column.type); + } + else + { + regularCount++; + size += ByteBufferUtil.serializedSizeWithVIntLength(column.name.bytes); + } + } + + return TypeSizes.sizeofUnsignedVInt(getPackedCount(syntheticCount, regularCount)) + + size; } public Columns deserialize(DataInputPlus in, TableMetadata metadata) throws IOException { - int length = (int)in.readUnsignedVInt(); try (BTree.FastBuilder builder = BTree.fastBuilder()) { - for (int i = 0; i < length; i++) + long packedCount = in.readUnsignedVInt() ; + int regularCount = (int) (packedCount & 0xFFFFF); + int syntheticCount = (int) (packedCount >> 20); + + // First pass - regular columns + for (int i = 0; i < regularCount; i++) { ByteBuffer name = ByteBufferUtil.readWithVIntLength(in); ColumnMetadata column = metadata.getColumn(name); if (column == null) { - // If we don't find the definition, it could be we have data for a dropped column, and we shouldn't - // fail deserialization because of that. So we grab a "fake" ColumnMetadata that ensure proper - // deserialization. The column will be ignore later on anyway. + // If we don't find the definition, it could be we have data for a dropped column column = metadata.getDroppedColumn(name); - if (column == null) throw new RuntimeException("Unknown column " + UTF8Type.instance.getString(name) + " during deserialization"); } builder.add(column); } + + // Second pass - synthetic columns + for (int i = 0; i < syntheticCount; i++) + { + ByteBuffer name = ByteBufferUtil.readWithVIntLength(in); + AbstractType type = typeSerializer.deserialize(in); + + if (!name.equals(ColumnMetadata.SYNTHETIC_SCORE_ID.bytes)) + throw new IllegalStateException("Unknown synthetic column " + UTF8Type.instance.getString(name)); + + ColumnMetadata column = ColumnMetadata.syntheticColumn(metadata.keyspace, metadata.name, ColumnMetadata.SYNTHETIC_SCORE_ID, type); + builder.add(column); + } return new Columns(builder.build()); } } diff --git a/src/java/org/apache/cassandra/db/ReadCommand.java b/src/java/org/apache/cassandra/db/ReadCommand.java index d872147c0d0f..a289987a1adc 100644 --- a/src/java/org/apache/cassandra/db/ReadCommand.java +++ b/src/java/org/apache/cassandra/db/ReadCommand.java @@ -72,6 +72,7 @@ import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessageFlag; import org.apache.cassandra.net.Verb; +import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.IndexMetadata; import org.apache.cassandra.schema.Schema; import org.apache.cassandra.schema.SchemaConstants; @@ -413,9 +414,9 @@ public UnfilteredPartitionIterator executeLocally(ReadExecutionController execut } Context context = Context.from(this); - UnfilteredPartitionIterator iterator = (null == searcher) ? Transformation.apply(queryStorage(cfs, executionController), new TrackingRowIterator(context)) - : Transformation.apply(searchStorage(searcher, executionController), new TrackingRowIterator(context)); - + var storageTarget = (null == searcher) ? queryStorage(cfs, executionController) + : searchStorage(searcher, executionController); + UnfilteredPartitionIterator iterator = Transformation.apply(storageTarget, new TrackingRowIterator(context)); iterator = RTBoundValidator.validate(iterator, Stage.MERGED, false); try @@ -1047,6 +1048,19 @@ public ReadCommand deserialize(DataInputPlus in, int version) throws IOException TableMetadata metadata = schema.getExistingTableMetadata(TableId.deserialize(in)); int nowInSec = in.readInt(); ColumnFilter columnFilter = ColumnFilter.serializer.deserialize(in, version, metadata); + + // add synthetic columns to the tablemetadata so we can serialize them in our response + var tmb = metadata.unbuild(); + for (var it = columnFilter.fetchedColumns().regulars.simpleColumnsDesc(); it.hasNext(); ) + { + var c = it.next(); + // synthetic columns sort last, so when we hit the first non-synthetic, we're done + if (!c.isSynthetic()) + break; + tmb.addColumn(ColumnMetadata.syntheticColumn(c.ksName, c.cfName, c.name, c.type)); + } + metadata = tmb.build(); + RowFilter rowFilter = RowFilter.serializer.deserialize(in, version, metadata); DataLimits limits = DataLimits.serializer.deserialize(in, version, metadata.comparator); Index.QueryPlan indexQueryPlan = null; diff --git a/src/java/org/apache/cassandra/db/RegularAndStaticColumns.java b/src/java/org/apache/cassandra/db/RegularAndStaticColumns.java index 05c5251c34e8..55533eda0e97 100644 --- a/src/java/org/apache/cassandra/db/RegularAndStaticColumns.java +++ b/src/java/org/apache/cassandra/db/RegularAndStaticColumns.java @@ -163,7 +163,7 @@ public Builder add(ColumnMetadata c) } else { - assert c.isRegular(); + assert c.isRegular() || c.isSynthetic(); if (regularColumns == null) regularColumns = BTree.builder(naturalOrder()); regularColumns.add(c); diff --git a/src/java/org/apache/cassandra/db/filter/ColumnFilter.java b/src/java/org/apache/cassandra/db/filter/ColumnFilter.java index d9a1b9d4e51a..2d4b240270a8 100644 --- a/src/java/org/apache/cassandra/db/filter/ColumnFilter.java +++ b/src/java/org/apache/cassandra/db/filter/ColumnFilter.java @@ -103,7 +103,8 @@ boolean fetchesAllColumns(boolean isStatic) @Override RegularAndStaticColumns getFetchedColumns(TableMetadata metadata, RegularAndStaticColumns queried) { - return metadata.regularAndStaticColumns(); + var merged = queried.regulars.mergeTo(metadata.regularColumns()); + return new RegularAndStaticColumns(metadata.staticColumns(), merged); } }, @@ -124,7 +125,8 @@ boolean fetchesAllColumns(boolean isStatic) @Override RegularAndStaticColumns getFetchedColumns(TableMetadata metadata, RegularAndStaticColumns queried) { - return new RegularAndStaticColumns(queried.statics, metadata.regularColumns()); + var merged = queried.regulars.mergeTo(metadata.regularColumns()); + return new RegularAndStaticColumns(queried.statics, merged); } }, @@ -295,14 +297,16 @@ public static ColumnFilter selection(TableMetadata metadata, } /** - * The columns that needs to be fetched internally for this filter. + * The columns that needs to be fetched internally. See FetchingStrategy for why this is + * always a superset of the queried columns. * * @return the columns to fetch for this filter. */ public abstract RegularAndStaticColumns fetchedColumns(); /** - * The columns actually queried by the user. + * The columns needed to process the query, including selected columns, ordering columns, + * restriction (predicate) columns, and synthetic columns. *

* Note that this is in general not all the columns that are fetched internally (see {@link #fetchedColumns}). */ @@ -619,9 +623,7 @@ private SortedSetMultimap buildSubSelectio */ public static class WildCardColumnFilter extends ColumnFilter { - /** - * The queried and fetched columns. - */ + // for wildcards, there is no distinction between fetched and queried because queried is already "everything" private final RegularAndStaticColumns fetchedAndQueried; /** @@ -739,14 +741,9 @@ public static class SelectionColumnFilter extends ColumnFilter { public final FetchingStrategy fetchingStrategy; - /** - * The selected columns - */ + // Materializes the columns required to implement queriedColumns() and fetchedColumns(), + // see the comments to superclass's methods private final RegularAndStaticColumns queried; - - /** - * The columns that need to be fetched to be able - */ private final RegularAndStaticColumns fetched; private final SortedSetMultimap subSelections; // can be null diff --git a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java index 774f06745110..7dd620be9ebe 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java @@ -585,7 +585,8 @@ public static class PrimaryKeyIterator extends AbstractUnfilteredRowIterator private final Unfiltered row; public final PrimaryKeyWithSortKey primaryKeyWithSortKey; - public PrimaryKeyIterator(PrimaryKeyWithSortKey key, UnfilteredRowIterator partition, Row staticRow, Unfiltered content) { + public PrimaryKeyIterator(PrimaryKeyWithSortKey key, UnfilteredRowIterator partition, Row staticRow, Unfiltered content) + { super(partition.metadata(), partition.partitionKey(), partition.partitionLevelDeletion(), @@ -602,15 +603,14 @@ public PrimaryKeyIterator(PrimaryKeyWithSortKey key, UnfilteredRowIterator parti return; } - var tm = metadata(); - var scoreColumn = ColumnMetadata.regularColumn(tm.keyspace, tm.name, "+score", FloatType.instance); - // clone the original Row Row originalRow = (Row) content; ArrayList columnData = new ArrayList<>(originalRow.columnCount() + 1); columnData.addAll(originalRow.columnData()); // inject +score as a new column + var tm = metadata(); + var scoreColumn = ColumnMetadata.syntheticColumn(tm.keyspace, tm.name, ColumnMetadata.SYNTHETIC_SCORE_ID, FloatType.instance); var pkWithScore = (PrimaryKeyWithScore) key; columnData.add(BufferCell.live(scoreColumn, FBUtilities.nowInSeconds(), @@ -626,7 +626,8 @@ public PrimaryKeyIterator(PrimaryKeyWithSortKey key, UnfilteredRowIterator parti } @Override - protected Unfiltered computeNext() { + protected Unfiltered computeNext() + { if (consumed) return endOfData(); consumed = true; diff --git a/src/java/org/apache/cassandra/schema/ColumnMetadata.java b/src/java/org/apache/cassandra/schema/ColumnMetadata.java index 38784e72acb5..2a59ee3ebf58 100644 --- a/src/java/org/apache/cassandra/schema/ColumnMetadata.java +++ b/src/java/org/apache/cassandra/schema/ColumnMetadata.java @@ -71,9 +71,9 @@ public enum ClusteringOrder /** * The type of CQL3 column this definition represents. - * There is 4 main type of CQL3 columns: those parts of the partition key, - * those parts of the clustering columns and amongst the others, regular and - * static ones. + * There are 5 types of columns: those parts of the partition key, + * those parts of the clustering columns and amongst the others, regular, + * static, and synthetic ones. * * IMPORTANT: this enum is serialized as toString() and deserialized by calling * Kind.valueOf(), so do not override toString() or rename existing values. @@ -84,15 +84,17 @@ public enum Kind PARTITION_KEY, CLUSTERING, REGULAR, - STATIC; + STATIC, + SYNTHETIC; public boolean isPrimaryKeyKind() { return this == PARTITION_KEY || this == CLUSTERING; } - } + public static final ColumnIdentifier SYNTHETIC_SCORE_ID = ColumnIdentifier.getInterned("+:!score", true); + /** * Whether this is a dropped column. */ @@ -121,10 +123,17 @@ public boolean isPrimaryKeyKind() */ private final long comparisonOrder; + /** + * Bit layout (from most to least significant): + * - Bits 61-63: Kind ordinal (3 bits, supporting up to 8 Kind values) + * - Bit 60: isComplex flag + * - Bits 48-59: position (12 bits, see assert) + * - Bits 0-47: name.prefixComparison (shifted right by 16) + */ private static long comparisonOrder(Kind kind, boolean isComplex, long position, ColumnIdentifier name) { assert position >= 0 && position < 1 << 12; - return (((long) kind.ordinal()) << 61) + return (((long) kind.ordinal()) << 61) | (isComplex ? 1L << 60 : 0) | (position << 48) | (name.prefixComparison >>> 16); @@ -170,6 +179,14 @@ public static ColumnMetadata staticColumn(String keyspace, String table, String return new ColumnMetadata(keyspace, table, ColumnIdentifier.getInterned(name, true), type, NO_POSITION, Kind.STATIC); } + /** + * Creates a new synthetic column metadata instance. + */ + public static ColumnMetadata syntheticColumn(String keyspace, String table, ColumnIdentifier id, AbstractType type) + { + return new ColumnMetadata(keyspace, table, id, type, NO_POSITION, Kind.SYNTHETIC); + } + /** * Rebuild the metadata for a dropped column from its recorded data. * @@ -225,6 +242,7 @@ public ColumnMetadata(String ksName, this.kind = kind; this.position = position; this.cellPathComparator = makeCellPathComparator(kind, type); + assert kind != Kind.SYNTHETIC || cellPathComparator == null; this.cellComparator = cellPathComparator == null ? ColumnData.comparator : new Comparator>() { @Override @@ -593,6 +611,11 @@ public boolean isCounterColumn() return type.isCounter(); } + public boolean isSynthetic() + { + return kind == Kind.SYNTHETIC; + } + public Selector.Factory newSelectorFactory(TableMetadata table, AbstractType expectedType, List defs, VariableSpecifications boundNames) throws InvalidRequestException { return SimpleSelector.newFactory(this, addAndGetIndex(this, defs)); diff --git a/src/java/org/apache/cassandra/schema/TableMetadata.java b/src/java/org/apache/cassandra/schema/TableMetadata.java index 6442326c9a2d..8885b85fdd43 100644 --- a/src/java/org/apache/cassandra/schema/TableMetadata.java +++ b/src/java/org/apache/cassandra/schema/TableMetadata.java @@ -167,9 +167,6 @@ protected TableMetadata(Builder builder) name = builder.name; id = builder.id; - // FIXME - builder.addColumn(new ColumnMetadata(keyspace, name, new ColumnIdentifier("+score", true), FloatType.instance, ColumnMetadata.NO_POSITION, ColumnMetadata.Kind.REGULAR)); - partitioner = builder.partitioner; kind = builder.kind; params = builder.params.build(); @@ -1125,8 +1122,7 @@ public Builder addStaticColumn(ColumnIdentifier name, AbstractType type) public Builder addColumn(ColumnMetadata column) { - if (columns.containsKey(column.name.bytes)) - return this; // FIXME + assert !columns.containsKey(column.name.bytes) : column.name + " is already present"; switch (column.kind) { diff --git a/test/unit/org/apache/cassandra/index/sai/StorageAttachedIndexTest.java b/test/unit/org/apache/cassandra/index/sai/StorageAttachedIndexTest.java index f475e466e28c..7538d65a4d88 100644 --- a/test/unit/org/apache/cassandra/index/sai/StorageAttachedIndexTest.java +++ b/test/unit/org/apache/cassandra/index/sai/StorageAttachedIndexTest.java @@ -135,14 +135,6 @@ private static ByteBuffer floatVectorToByteBuffer(CQLTester.Vector vector @Test public void testOrderResults() { - QueryOptions queryOptions = QueryOptions.create(ConsistencyLevel.ONE, - byteBufferList, - false, - PageSize.inRows(1), - null, - null, - ProtocolVersion.CURRENT, - KEYSPACE); List> rows = new ArrayList<>(); rows.add(byteBufferList); From 237cec4d1768efe5ea10c671fd779d47ca3d027a Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Fri, 6 Dec 2024 16:30:00 -0600 Subject: [PATCH 05/29] implement BM25 --- src/antlr/Lexer.g | 1 + src/antlr/Parser.g | 15 +- .../cassandra/cql3/GeoDistanceRelation.java | 6 + .../cassandra/cql3/MultiColumnRelation.java | 6 + .../org/apache/cassandra/cql3/Operator.java | 14 + .../org/apache/cassandra/cql3/Ordering.java | 64 ++++ .../org/apache/cassandra/cql3/Relation.java | 7 + .../cassandra/cql3/SingleColumnRelation.java | 8 + .../apache/cassandra/cql3/TokenRelation.java | 8 +- .../restrictions/SingleColumnRestriction.java | 68 +++++ .../cql3/statements/SelectStatement.java | 54 +++- src/java/org/apache/cassandra/db/Columns.java | 8 + .../apache/cassandra/db/filter/RowFilter.java | 1 + .../db/rows/UnfilteredSerializer.java | 5 + .../cassandra/index/sai/IndexContext.java | 10 +- .../cassandra/index/sai/QueryContext.java | 2 +- .../index/sai/StorageAttachedIndex.java | 2 +- .../index/sai/disk/v1/IndexSearcher.java | 23 +- .../sai/disk/v1/InvertedIndexSearcher.java | 127 +++++++- .../v1/postings/IntersectingPostingList.java | 134 +++++++++ ...ngList.java => ReorderingPostingList.java} | 10 +- .../sai/disk/v2/V2VectorIndexSearcher.java | 16 +- .../sai/disk/vector/VectorMemtableIndex.java | 4 +- .../index/sai/memory/TrieMemoryIndex.java | 7 + .../index/sai/memory/TrieMemtableIndex.java | 163 +++++++++-- .../cassandra/index/sai/plan/Expression.java | 2 + .../cassandra/index/sai/plan/Operation.java | 3 +- .../cassandra/index/sai/plan/Orderer.java | 48 ++- .../apache/cassandra/index/sai/plan/Plan.java | 110 +++++-- .../index/sai/plan/QueryController.java | 1 + .../plan/StorageAttachedIndexSearcher.java | 73 +++-- .../index/sai/plan/TopKProcessor.java | 48 ++- .../cassandra/index/sai/utils/BM25Utils.java | 178 ++++++++++++ .../utils/PrimaryKeyWithByteComparable.java | 10 +- .../index/sai/utils/PrimaryKeyWithScore.java | 10 +- .../sai/utils/PrimaryKeyWithSortKey.java | 11 +- .../index/sai/utils/RowIdWithScore.java | 2 +- .../test/sai/BM25DistributedTest.java | 121 ++++++++ .../index/sai/StorageAttachedIndexTest.java | 10 +- .../cassandra/index/sai/cql/BM25Test.java | 274 ++++++++++++++++++ .../postings/IntersectingPostingListTest.java | 213 ++++++++++++++ ...st.java => ReorderingPostingListTest.java} | 8 +- .../sai/memory/VectorMemtableIndexTest.java | 4 +- 43 files changed, 1726 insertions(+), 163 deletions(-) create mode 100644 src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java rename src/java/org/apache/cassandra/index/sai/disk/v1/postings/{VectorPostingList.java => ReorderingPostingList.java} (83%) create mode 100644 src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java create mode 100644 test/distributed/org/apache/cassandra/distributed/test/sai/BM25DistributedTest.java create mode 100644 test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java create mode 100644 test/unit/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingListTest.java rename test/unit/org/apache/cassandra/index/sai/disk/v1/postings/{VectorPostingListTest.java => ReorderingPostingListTest.java} (91%) diff --git a/src/antlr/Lexer.g b/src/antlr/Lexer.g index 1e30ad18edd2..100541caf913 100644 --- a/src/antlr/Lexer.g +++ b/src/antlr/Lexer.g @@ -227,6 +227,7 @@ K_DROPPED: D R O P P E D; K_COLUMN: C O L U M N; K_RECORD: R E C O R D; K_ANN_OF: A N N WS+ O F; +K_BM25_OF: 'BM25' WS+ 'OF'; // Case-insensitive alpha characters fragment A: ('a'|'A'); diff --git a/src/antlr/Parser.g b/src/antlr/Parser.g index 7fb85a9aeeea..43c4fe67c96f 100644 --- a/src/antlr/Parser.g +++ b/src/antlr/Parser.g @@ -457,14 +457,18 @@ customIndexExpression [WhereClause.Builder clause] ; orderByClause[List orderings] - @init{ + @init { Ordering.Direction direction = Ordering.Direction.ASC; + Ordering.Raw.Expression expr = null; } - : c=cident (K_ANN_OF t=term)? (K_ASC | K_DESC { direction = Ordering.Direction.DESC; })? + : c=cident + ( K_ANN_OF t=term { expr = new Ordering.Raw.Ann(c, t); } + | K_BM25_OF t=term { expr = new Ordering.Raw.Bm25(c, t); } + )? + (K_ASC | K_DESC { direction = Ordering.Direction.DESC; })? { - Ordering.Raw.Expression expr = (t == null) - ? new Ordering.Raw.SingleColumn(c) - : new Ordering.Raw.Ann(c, t); + if (expr == null) + expr = new Ordering.Raw.SingleColumn(c); orderings.add(new Ordering.Raw(expr, direction)); } ; @@ -1967,6 +1971,7 @@ basic_unreserved_keyword returns [String str] | K_COLUMN | K_RECORD | K_ANN_OF + | K_BM25_OF | K_OFFSET ) { $str = $k.text; } ; diff --git a/src/java/org/apache/cassandra/cql3/GeoDistanceRelation.java b/src/java/org/apache/cassandra/cql3/GeoDistanceRelation.java index 3d5fb2eeeef7..640e7e600686 100644 --- a/src/java/org/apache/cassandra/cql3/GeoDistanceRelation.java +++ b/src/java/org/apache/cassandra/cql3/GeoDistanceRelation.java @@ -141,6 +141,12 @@ protected Restriction newAnnRestriction(TableMetadata table, VariableSpecificati throw invalidRequest("%s cannot be used with the GEO_DISTANCE function", operator()); } + @Override + protected Restriction newBm25Restriction(TableMetadata table, VariableSpecifications boundNames) + { + throw invalidRequest("%s cannot be used with the GEO_DISTANCE function", operator()); + } + @Override protected Restriction newAnalyzerMatchesRestriction(TableMetadata table, VariableSpecifications boundNames) { diff --git a/src/java/org/apache/cassandra/cql3/MultiColumnRelation.java b/src/java/org/apache/cassandra/cql3/MultiColumnRelation.java index fcd505fcf5df..f56d76e2ced9 100644 --- a/src/java/org/apache/cassandra/cql3/MultiColumnRelation.java +++ b/src/java/org/apache/cassandra/cql3/MultiColumnRelation.java @@ -250,6 +250,12 @@ protected Restriction newAnnRestriction(TableMetadata table, VariableSpecificati throw invalidRequest("%s cannot be used for multi-column relations", operator()); } + @Override + protected Restriction newBm25Restriction(TableMetadata table, VariableSpecifications boundNames) + { + throw invalidRequest("%s cannot be used for multi-column relations", operator()); + } + @Override protected Restriction newAnalyzerMatchesRestriction(TableMetadata table, VariableSpecifications boundNames) { diff --git a/src/java/org/apache/cassandra/cql3/Operator.java b/src/java/org/apache/cassandra/cql3/Operator.java index 41b7985ffc4d..d4bf3ab93d3c 100644 --- a/src/java/org/apache/cassandra/cql3/Operator.java +++ b/src/java/org/apache/cassandra/cql3/Operator.java @@ -373,6 +373,20 @@ public boolean isSatisfiedBy(AbstractType type, ByteBuffer leftOperand, ByteB return !LIKE.isSatisfiedBy(type, leftOperand, rightOperand, analyzer); } }, + BM25(25) + { + @Override + public String toString() + { + return "BM25"; + } + + @Override + public boolean isSatisfiedBy(AbstractType type, ByteBuffer leftOperand, ByteBuffer rightOperand, @Nullable Index.Analyzer analyzer) + { + throw new UnsupportedOperationException(); + } + }, /** * An operator that only performs matching against analyzed columns. diff --git a/src/java/org/apache/cassandra/cql3/Ordering.java b/src/java/org/apache/cassandra/cql3/Ordering.java index f5760f38510e..2dd817818a76 100644 --- a/src/java/org/apache/cassandra/cql3/Ordering.java +++ b/src/java/org/apache/cassandra/cql3/Ordering.java @@ -20,6 +20,7 @@ import org.apache.cassandra.cql3.restrictions.SingleColumnRestriction; import org.apache.cassandra.cql3.restrictions.SingleRestriction; +import org.apache.cassandra.cql3.statements.SelectStatement; import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.TableMetadata; @@ -124,6 +125,48 @@ public ColumnMetadata getColumn() return column; } + @Override + public boolean isScored() + { + return SelectStatement.ANN_USE_SYNTHETIC_SCORE; + } + } + + /** + * An expression used in BM25 ordering. + * ORDER BY column BM25 OF value + */ + public static class Bm25 implements Expression + { + final ColumnMetadata column; + final Term queryValue; + final Direction direction; + + public Bm25(ColumnMetadata column, Term queryValue, Direction direction) + { + this.column = column; + this.queryValue = queryValue; + this.direction = direction; + } + + @Override + public boolean hasNonClusteredOrdering() + { + return true; + } + + @Override + public SingleRestriction toRestriction() + { + return new SingleColumnRestriction.Bm25Restriction(column, queryValue); + } + + @Override + public ColumnMetadata getColumn() + { + return column; + } + @Override public boolean isScored() { @@ -201,6 +244,27 @@ public Ordering.Expression bind(TableMetadata table, VariableSpecifications boun return new Ordering.Ann(column, value, direction); } } + + public static class Bm25 implements Expression + { + final ColumnIdentifier columnId; + final Term.Raw queryValue; + + Bm25(ColumnIdentifier column, Term.Raw queryValue) + { + this.columnId = column; + this.queryValue = queryValue; + } + + @Override + public Ordering.Expression bind(TableMetadata table, VariableSpecifications boundNames, Direction direction) + { + ColumnMetadata column = table.getExistingColumn(columnId); + Term value = queryValue.prepare(table.keyspace, column); + value.collectMarkerSpecification(boundNames); + return new Ordering.Bm25(column, value, direction); + } + } } } diff --git a/src/java/org/apache/cassandra/cql3/Relation.java b/src/java/org/apache/cassandra/cql3/Relation.java index 5cca2d257323..42cf3c9c8287 100644 --- a/src/java/org/apache/cassandra/cql3/Relation.java +++ b/src/java/org/apache/cassandra/cql3/Relation.java @@ -202,6 +202,8 @@ public final Restriction toRestriction(TableMetadata table, VariableSpecificatio return newLikeRestriction(table, boundNames, relationType); case ANN: return newAnnRestriction(table, boundNames); + case BM25: + return newBm25Restriction(table, boundNames); case ANALYZER_MATCHES: return newAnalyzerMatchesRestriction(table, boundNames); default: throw invalidRequest("Unsupported \"!=\" relation: %s", this); @@ -296,6 +298,11 @@ protected abstract Restriction newSliceRestriction(TableMetadata table, */ protected abstract Restriction newAnnRestriction(TableMetadata table, VariableSpecifications boundNames); + /** + * Creates a new BM25 restriction instance. + */ + protected abstract Restriction newBm25Restriction(TableMetadata table, VariableSpecifications boundNames); + /** * Creates a new Analyzer Matches restriction instance. */ diff --git a/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java b/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java index ec66ad70b529..5a93e6c8e422 100644 --- a/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java +++ b/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java @@ -333,6 +333,14 @@ protected Restriction newAnnRestriction(TableMetadata table, VariableSpecificati return new SingleColumnRestriction.AnnRestriction(columnDef, term); } + @Override + protected Restriction newBm25Restriction(TableMetadata table, VariableSpecifications boundNames) + { + ColumnMetadata columnDef = table.getExistingColumn(entity); + Term term = toTerm(toReceivers(columnDef), value, table.keyspace, boundNames); + return new SingleColumnRestriction.AnnRestriction(columnDef, term); + } + @Override protected Restriction newAnalyzerMatchesRestriction(TableMetadata table, VariableSpecifications boundNames) { diff --git a/src/java/org/apache/cassandra/cql3/TokenRelation.java b/src/java/org/apache/cassandra/cql3/TokenRelation.java index a3ca586eee76..ca849dc82a30 100644 --- a/src/java/org/apache/cassandra/cql3/TokenRelation.java +++ b/src/java/org/apache/cassandra/cql3/TokenRelation.java @@ -138,7 +138,13 @@ protected Restriction newLikeRestriction(TableMetadata table, VariableSpecificat @Override protected Restriction newAnnRestriction(TableMetadata table, VariableSpecifications boundNames) { - throw invalidRequest("%s cannot be used for toekn relations", operator()); + throw invalidRequest("%s cannot be used for token relations", operator()); + } + + @Override + protected Restriction newBm25Restriction(TableMetadata table, VariableSpecifications boundNames) + { + throw invalidRequest("%s cannot be used for token relations", operator()); } @Override diff --git a/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java b/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java index 264353a34928..9d52f2bfe43c 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java @@ -1191,6 +1191,74 @@ public boolean isBoundedAnn() } } + public static final class Bm25Restriction extends SingleColumnRestriction + { + private final Term value; + + public Bm25Restriction(ColumnMetadata columnDef, Term value) + { + super(columnDef); + this.value = value; + } + + public ByteBuffer value(QueryOptions options) + { + return value.bindAndGet(options); + } + + @Override + public void addFunctionsTo(List functions) + { + value.addFunctionsTo(functions); + } + + @Override + MultiColumnRestriction toMultiColumnRestriction() + { + throw new UnsupportedOperationException(); + } + + @Override + public void addToRowFilter(RowFilter.Builder filter, + IndexRegistry indexRegistry, + QueryOptions options) + { + filter.add(columnDef, Operator.BM25, value.bindAndGet(options)); + } + + @Override + public MultiClusteringBuilder appendTo(MultiClusteringBuilder builder, QueryOptions options) + { + throw new UnsupportedOperationException(); + } + + @Override + public String toString() + { + return String.format("BM25(%s)", value); + } + + @Override + public SingleRestriction doMergeWith(SingleRestriction otherRestriction) + { + if (otherRestriction.isIndexBasedOrdering()) + throw invalidRequest("%s cannot be restricted by multiple BM25 restrictions", columnDef.name); + throw invalidRequest("%s cannot be restricted by both BM25 and %s", columnDef.name, otherRestriction.toString()); + } + + @Override + protected boolean isSupportedBy(Index index) + { + return index.supportsExpression(columnDef, Operator.BM25); + } + + @Override + public boolean isIndexBasedOrdering() + { + return true; + } + } + /** * A Bounded ANN Restriction is one that uses a similarity score as the limiting factor for ANN instead of a number * of results. diff --git a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java index eb1b5f9fa0ba..e7d5cf5874a6 100644 --- a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java +++ b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java @@ -41,6 +41,7 @@ import org.apache.cassandra.cql3.selection.SortedRowsBuilder; import org.apache.cassandra.db.marshal.FloatType; import org.apache.cassandra.guardrails.Guardrails; +import org.apache.cassandra.index.Index; import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.Schema; import org.apache.cassandra.schema.TableMetadata; @@ -100,6 +101,10 @@ */ public class SelectStatement implements CQLStatement.SingleKeyspaceCqlStatement { + // TODO remove this when we no longer need to downgrade to replicas that don't know about synthetic columns + // (And don't forget to remove the related hacks in Columns.Serializer.encodeBitmap and UnfilteredSerializer.serializeRowBody) + public static final boolean ANN_USE_SYNTHETIC_SCORE = Boolean.parseBoolean(System.getProperty("cassandra.sai.ann_use_synthetic_score", "false")); + private static final Logger logger = LoggerFactory.getLogger(SelectStatement.class); private static final NoSpamLogger noSpamLogger = NoSpamLogger.getLogger(SelectStatement.logger, 1, TimeUnit.MINUTES); public static final String TOPK_CONSISTENCY_LEVEL_ERROR = "Top-K queries can only be run with consistency level ONE/LOCAL_ONE. Consistency level %s was used."; @@ -993,7 +998,7 @@ private ResultSet process(PartitionIterator partitions, { ProtocolVersion protocolVersion = options.getProtocolVersion(); GroupMaker groupMaker = aggregationSpec == null ? null : aggregationSpec.newGroupMaker(); - SortedRowsBuilder rows = sortedRowsBuilder(userLimit, userOffset == NO_OFFSET ? 0 : userOffset); + SortedRowsBuilder rows = sortedRowsBuilder(userLimit, userOffset == NO_OFFSET ? 0 : userOffset, options); ResultSetBuilder result = new ResultSetBuilder(protocolVersion, getResultMetadata(), selectors, groupMaker, rows); while (partitions.hasNext()) @@ -1112,11 +1117,24 @@ private boolean needsPostQueryOrdering() /** * Orders results when multiple keys are selected (using IN) */ - public SortedRowsBuilder sortedRowsBuilder(int limit, int offset) + public SortedRowsBuilder sortedRowsBuilder(int limit, int offset, QueryOptions options) { if (orderingComparator == null) return SortedRowsBuilder.create(limit, offset); + if (orderingComparator instanceof IndexColumnComparator) + { + SingleRestriction restriction = ((IndexColumnComparator) orderingComparator).restriction; + int columnIndex = ((IndexColumnComparator) orderingComparator).columnIndex; + + Index index = restriction.findSupportingIndex(IndexRegistry.obtain(table)); + assert index != null; + + Index.Scorer scorer = index.postQueryScorer(restriction, columnIndex, options); + return SortedRowsBuilder.create(limit, offset, scorer); + } + + // else return SortedRowsBuilder.create(limit, offset, orderingComparator); } @@ -1474,7 +1492,11 @@ private ColumnComparator> getOrderingComparator(Selection selec assert orderingColumns.size() == 1 : orderingColumns.keySet(); var e = orderingColumns.entrySet().iterator().next(); var column = e.getKey(); - return new SingleColumnComparator(selection.getOrderingIndex(column), column.type, false); + var ordering = e.getValue(); + if (ordering.expression instanceof Ordering.Ann && !ANN_USE_SYNTHETIC_SCORE) + return new IndexColumnComparator(ordering.expression.toRestriction(), selection.getOrderingIndex(column)); + else + return new SingleColumnComparator(selection.getOrderingIndex(column), column.type, false); } if (!restrictions.keyIsInRelation()) @@ -1694,6 +1716,32 @@ public boolean isClustered() } } + // see usage in sortedRowsBuilder + private static class IndexColumnComparator extends ColumnComparator> + { + private final SingleRestriction restriction; + private final int columnIndex; + + // VSTODO maybe cache in prepared statement + public IndexColumnComparator(SingleRestriction restriction, int columnIndex) + { + this.restriction = restriction; + this.columnIndex = columnIndex; + } + + @Override + public boolean isClustered() + { + return false; + } + + @Override + public int compare(List o1, List o2) + { + throw new UnsupportedOperationException(); + } + } + /** * Used in orderResults(...) method when multiple 'ORDER BY' conditions where given */ diff --git a/src/java/org/apache/cassandra/db/Columns.java b/src/java/org/apache/cassandra/db/Columns.java index cef03393a1db..becd23508d1d 100644 --- a/src/java/org/apache/cassandra/db/Columns.java +++ b/src/java/org/apache/cassandra/db/Columns.java @@ -716,7 +716,15 @@ private static long encodeBitmap(Collection columns, Columns sup for (ColumnMetadata column : columns) { if (iter.next(column) == null) + { + // We can't know for sure whether to add the synthetic score column because WildcardColumnFilter + // just says "yes" to everything; instead, we just skip it here. + // TODO remove this with SelectStatement.ANN_USE_SYNTHETIC_SCORE. + if (column.isSynthetic()) + continue; + throw new IllegalStateException(columns + " is not a subset of " + superset); + } int currentIndex = iter.indexOfCurrent(); int count = currentIndex - expectIndex; diff --git a/src/java/org/apache/cassandra/db/filter/RowFilter.java b/src/java/org/apache/cassandra/db/filter/RowFilter.java index 675199e4f12d..b4f6493cbd20 100644 --- a/src/java/org/apache/cassandra/db/filter/RowFilter.java +++ b/src/java/org/apache/cassandra/db/filter/RowFilter.java @@ -1106,6 +1106,7 @@ public boolean isSatisfiedBy(TableMetadata metadata, DecoratedKey partitionKey, case LIKE_MATCHES: case ANALYZER_MATCHES: case ANN: + case BM25: { assert !column.isComplex() : "Only CONTAINS and CONTAINS_KEY are supported for 'complex' types"; ByteBuffer foundValue = getValue(metadata, partitionKey, row); diff --git a/src/java/org/apache/cassandra/db/rows/UnfilteredSerializer.java b/src/java/org/apache/cassandra/db/rows/UnfilteredSerializer.java index a38305e1dd71..c94e3470bb33 100644 --- a/src/java/org/apache/cassandra/db/rows/UnfilteredSerializer.java +++ b/src/java/org/apache/cassandra/db/rows/UnfilteredSerializer.java @@ -242,6 +242,11 @@ private void serializeRowBody(Row row, int flags, SerializationHelper helper, Da // with. So we use the ColumnMetadata from the "header" which is "current". Also see #11810 for what // happens if we don't do that. ColumnMetadata column = si.next(cd.column()); + // We can't know for sure whether to add the synthetic score column because WildcardColumnFilter + // just says "yes" to everything; instead, we just skip it here. + // TODO remove this with SelectStatement.ANN_USE_SYNTHETIC_SCORE. + if (column == null) + return; assert column != null : cd.column.toString(); try diff --git a/src/java/org/apache/cassandra/index/sai/IndexContext.java b/src/java/org/apache/cassandra/index/sai/IndexContext.java index 94b8cd1afed1..890ecd2cc7aa 100644 --- a/src/java/org/apache/cassandra/index/sai/IndexContext.java +++ b/src/java/org/apache/cassandra/index/sai/IndexContext.java @@ -689,8 +689,8 @@ public boolean supports(Operator op) { if (op.isLike() || op == Operator.LIKE) return false; // Analyzed columns store the indexed result, so we are unable to compute raw equality. - // The only supported operator is ANALYZER_MATCHES. - if (op == Operator.ANALYZER_MATCHES) return isAnalyzed; + // The only supported operators are ANALYZER_MATCHES and BM25. + if (op == Operator.ANALYZER_MATCHES || op == Operator.BM25) return isAnalyzed; // If the column is analyzed and the operator is EQ, we need to check if the analyzer supports it. if (op == Operator.EQ && isAnalyzed && !analyzerFactory.supportsEquals()) @@ -714,7 +714,6 @@ public boolean supports(Operator op) || column.type instanceof IntegerType); // Currently truncates to 20 bytes Expression.Op operator = Expression.Op.valueOf(op); - if (isNonFrozenCollection()) { if (indexType == IndexTarget.Type.KEYS) @@ -726,17 +725,12 @@ public boolean supports(Operator op) return indexType == IndexTarget.Type.KEYS_AND_VALUES && (operator == Expression.Op.EQ || operator == Expression.Op.NOT_EQ || operator == Expression.Op.RANGE); } - if (indexType == IndexTarget.Type.FULL) return operator == Expression.Op.EQ; - AbstractType validator = getValidator(); - if (operator == Expression.Op.IN) return true; - if (operator != Expression.Op.EQ && EQ_ONLY_TYPES.contains(validator)) return false; - // RANGE only applicable to non-literal indexes return (operator != null) && !(TypeUtil.isLiteral(validator) && operator == Expression.Op.RANGE); } diff --git a/src/java/org/apache/cassandra/index/sai/QueryContext.java b/src/java/org/apache/cassandra/index/sai/QueryContext.java index 0bb2a96922ea..7d766f77dfd1 100644 --- a/src/java/org/apache/cassandra/index/sai/QueryContext.java +++ b/src/java/org/apache/cassandra/index/sai/QueryContext.java @@ -35,7 +35,7 @@ @NotThreadSafe public class QueryContext { - private static final boolean DISABLE_TIMEOUT = Boolean.getBoolean("cassandra.sai.test.disable.timeout"); + private static final boolean DISABLE_TIMEOUT = true; // Boolean.getBoolean("cassandra.sai.test.disable.timeout"); protected final long queryStartTimeNanos; diff --git a/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java b/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java index 6b9e2582d3ef..74482d97a7c9 100644 --- a/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java +++ b/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java @@ -683,7 +683,7 @@ public RowFilter getPostIndexQueryFilter(RowFilter filter) @Override public Scorer postQueryScorer(Restriction restriction, int columnIndex, QueryOptions options) { - // For now, only support ANN + // TODO remove this with SelectStatement.ANN_USE_SYNTHETIC_SCORE. assert restriction instanceof SingleColumnRestriction.AnnRestriction; Preconditions.checkState(indexContext.isVector()); diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java index 52988a685cd5..5fb246183921 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java @@ -68,9 +68,7 @@ public abstract class IndexSearcher implements Closeable, SegmentOrdering protected final SegmentMetadata metadata; protected final IndexContext indexContext; - private static final SSTableReadsListener NOOP_LISTENER = new SSTableReadsListener() {}; - - private final ColumnFilter columnFilter; + protected final ColumnFilter columnFilter; protected IndexSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory, PerIndexFiles perIndexFiles, @@ -90,30 +88,37 @@ protected IndexSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory, public abstract long indexFileCacheSize(); /** - * Search on-disk index synchronously + * Search on-disk index synchronously. Used for WHERE clause predicates, including BOUNDED_ANN. * * @param expression to filter on disk index * @param keyRange key range specific in read command, used by ANN index * @param queryContext to track per sstable cache and per query metrics * @param defer create the iterator in a deferred state - * @param limit the num of rows to returned, used by ANN index + * @param limit the initial num of rows to returned, used by ANN index. More rows may be requested if filtering throws away more than expected! * @return {@link KeyRangeIterator} that matches given expression */ public abstract KeyRangeIterator search(Expression expression, AbstractBounds keyRange, QueryContext queryContext, boolean defer, int limit) throws IOException; /** - * Order the on-disk index synchronously and produce an iterator in score order + * Order the rows by the giving Orderer. Used for ORDER BY clause when + * (1) the WHERE predicate is either a partition restriction or a range restriction on the index, + * (2) there is no WHERE predicate, or + * (3) the planner determines it is better to post-filter the ordered results by the predicate. * * @param orderer the object containing the ordering logic * @param slice optional predicate to get a slice of the index * @param keyRange key range specific in read command, used by ANN index * @param queryContext to track per sstable cache and per query metrics - * @param limit the num of rows to returned, used by ANN index + * @param limit the initial num of rows to returned, used by ANN index. More rows may be requested if filtering throws away more than expected! * @return an iterator of {@link PrimaryKeyWithSortKey} in score order */ public abstract CloseableIterator orderBy(Orderer orderer, Expression slice, AbstractBounds keyRange, QueryContext queryContext, int limit) throws IOException; - + /** + * Order the rows by the giving Orderer. Used for ORDER BY clause when the WHERE predicates + * have been applied first, yielding a list of primary keys. Again, `limit` is a planner hint for ANN to determine + * the initial number of results returned, not a maximum. + */ @Override public CloseableIterator orderResultsBy(SSTableReader reader, QueryContext context, List keys, Orderer orderer, int limit) throws IOException { @@ -124,7 +129,7 @@ public CloseableIterator orderResultsBy(SSTableReader rea { var slices = Slices.with(indexContext.comparator(), Slice.make(key.clustering())); // TODO if we end up needing to read the row still, is it better to store offset and use reader.unfilteredAt? - try (var iter = reader.iterator(key.partitionKey(), slices, columnFilter, false, NOOP_LISTENER)) + try (var iter = reader.iterator(key.partitionKey(), slices, columnFilter, false, SSTableReadsListener.NOOP_LISTENER)) { if (iter.hasNext()) { diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java index 15aaa349e1c8..295f5fe9356d 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java @@ -19,14 +19,26 @@ package org.apache.cassandra.index.sai.disk.v1; import java.io.IOException; +import java.io.UncheckedIOException; import java.lang.invoke.MethodHandles; +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; import com.google.common.base.MoreObjects; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.cassandra.cql3.Operator; import org.apache.cassandra.db.PartitionPosition; +import org.apache.cassandra.db.Slice; +import org.apache.cassandra.db.Slices; +import org.apache.cassandra.db.rows.Cell; +import org.apache.cassandra.db.rows.Row; import org.apache.cassandra.dht.AbstractBounds; import org.apache.cassandra.index.sai.IndexContext; import org.apache.cassandra.index.sai.QueryContext; @@ -35,31 +47,41 @@ import org.apache.cassandra.index.sai.disk.TermsIterator; import org.apache.cassandra.index.sai.disk.format.IndexComponentType; import org.apache.cassandra.index.sai.disk.format.Version; +import org.apache.cassandra.index.sai.disk.v1.postings.IntersectingPostingList; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; import org.apache.cassandra.index.sai.metrics.MulticastQueryEventListeners; import org.apache.cassandra.index.sai.metrics.QueryEventListener; import org.apache.cassandra.index.sai.plan.Expression; import org.apache.cassandra.index.sai.plan.Orderer; +import org.apache.cassandra.index.sai.utils.BM25Utils; +import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey; import org.apache.cassandra.index.sai.utils.RowIdWithByteComparable; import org.apache.cassandra.index.sai.utils.SAICodecUtils; -import org.apache.cassandra.index.sai.utils.SegmentOrdering; +import org.apache.cassandra.io.sstable.format.SSTableReader; +import org.apache.cassandra.io.sstable.format.SSTableReadsListener; import org.apache.cassandra.io.util.FileUtils; import org.apache.cassandra.utils.AbstractIterator; import org.apache.cassandra.utils.CloseableIterator; import org.apache.cassandra.utils.bytecomparable.ByteComparable; +import static org.apache.cassandra.index.sai.disk.PostingList.END_OF_STREAM; + /** * Executes {@link Expression}s against the trie-based terms dictionary for an individual index segment. */ -public class InvertedIndexSearcher extends IndexSearcher implements SegmentOrdering +public class InvertedIndexSearcher extends IndexSearcher { private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); private final TermsReader reader; private final QueryEventListener.TrieIndexEventListener perColumnEventListener; private final Version version; + private static final float K1 = 1.2f; // BM25 term frequency saturation parameter + private static final float B = 0.75f; // BM25 length normalization parameter + private final boolean filterRangeResults; + private final SSTableReader sstable; protected InvertedIndexSearcher(SSTableContext sstableContext, PerIndexFiles perIndexFiles, @@ -69,6 +91,7 @@ protected InvertedIndexSearcher(SSTableContext sstableContext, boolean filterRangeResults) throws IOException { super(sstableContext.primaryKeyMapFactory(), perIndexFiles, segmentMetadata, indexContext); + this.sstable = sstableContext.sstable; long root = metadata.getIndexRoot(IndexComponentType.TERMS_DATA); assert root >= 0; @@ -129,11 +152,105 @@ else if (exp.getOp() == Expression.Op.RANGE) throw new IllegalArgumentException(indexContext.logMessage("Unsupported expression: " + exp)); } + private Cell readColumn(SSTableReader sstable, PrimaryKey primaryKey) + { + var dk = primaryKey.partitionKey(); + var slices = Slices.with(indexContext.comparator(), Slice.make(primaryKey.clustering())); + try (var rowIterator = sstable.iterator(dk, slices, columnFilter, false, SSTableReadsListener.NOOP_LISTENER)) + { + var unfiltered = rowIterator.next(); + assert unfiltered.isRow() : unfiltered; + Row row = (Row) unfiltered; + return row.getCell(indexContext.getDefinition()); + } + } + @Override public CloseableIterator orderBy(Orderer orderer, Expression slice, AbstractBounds keyRange, QueryContext queryContext, int limit) throws IOException { - var iter = new RowIdWithTermsIterator(reader.allTerms(orderer.isAscending())); - return toMetaSortedIterator(iter, queryContext); + if (!orderer.isBM25()) + { + var iter = new RowIdWithTermsIterator(reader.allTerms(orderer.isAscending())); + return toMetaSortedIterator(iter, queryContext); + } + + // find documents that match each term + var queryTerms = orderer.extractQueryTerms(); + var postingLists = queryTerms.stream() + .collect(Collectors.toMap(Function.identity(), term -> + { + var encodedTerm = version.onDiskFormat().encodeForTrie(term, indexContext.getValidator()); + var listener = MulticastQueryEventListeners.of(queryContext, perColumnEventListener); + return reader.exactMatch(encodedTerm, listener, queryContext); + })); + // extract the match count for each + var documentFrequencies = postingLists.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> (long) e.getValue().size())); + + try (var pkm = primaryKeyMapFactory.newPerSSTablePrimaryKeyMap(); + var merged = IntersectingPostingList.intersect(List.copyOf(postingLists.values()))) + { + var it = new AbstractIterator() { + @Override + protected PrimaryKey computeNext() + { + try + { + int rowId = merged.nextPosting(); + if (rowId == PostingList.END_OF_STREAM) + return endOfData(); + return pkm.primaryKeyFromRowId(rowId); + } + catch (IOException e) + { + throw new UncheckedIOException(e); + } + } + }; + return bm25Internal(it, queryTerms, documentFrequencies); + } + } + + private CloseableIterator bm25Internal(Iterator keyIterator, + List queryTerms, + Map documentFrequencies) + { + var docStats = new BM25Utils.DocStats(documentFrequencies, sstable.getTotalRows()); + return BM25Utils.computeScores(keyIterator, + queryTerms, + docStats, + indexContext, + sstable.descriptor.id, + pk -> readColumn(sstable, pk)); + } + + @Override + public CloseableIterator orderResultsBy(SSTableReader reader, QueryContext queryContext, List keys, Orderer orderer, int limit) throws IOException + { + if (!orderer.isBM25()) + return super.orderResultsBy(reader, queryContext, keys, orderer, limit); + + var queryTerms = orderer.extractQueryTerms(); + var documentFrequencies = new HashMap(); + boolean hasHistograms = metadata.version.onDiskFormat().indexFeatureSet().hasTermsHistogram(); + for (ByteBuffer term : queryTerms) + { + long matches; + if (hasHistograms) + { + matches = metadata.estimateNumRowsMatching(new Expression(indexContext).add(Operator.ANALYZER_MATCHES, term)); + } + else + { + // Without histograms, need to do an actual index scan + var encodedTerm = version.onDiskFormat().encodeForTrie(term, indexContext.getValidator()); + var listener = MulticastQueryEventListeners.of(queryContext, perColumnEventListener); + var postingList = this.reader.exactMatch(encodedTerm, listener, queryContext); + matches = postingList.size(); + FileUtils.closeQuietly(postingList); + } + documentFrequencies.put(term, matches); + } + return bm25Internal(keys.iterator(), queryTerms, documentFrequencies); } @Override @@ -172,7 +289,7 @@ protected RowIdWithByteComparable computeNext() while (true) { long nextPosting = currentPostingList.nextPosting(); - if (nextPosting != PostingList.END_OF_STREAM) + if (nextPosting != END_OF_STREAM) return new RowIdWithByteComparable(Math.toIntExact(nextPosting), currentTerm); if (!source.hasNext()) diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java new file mode 100644 index 000000000000..61fa01ff848b --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.disk.v1.postings; + +import java.io.IOException; +import java.util.List; +import javax.annotation.concurrent.NotThreadSafe; + +import org.apache.cassandra.index.sai.disk.PostingList; + +import static java.lang.Math.max; + +/** + * Performs intersection operations on multiple PostingLists, returning only postings + * that appear in all inputs. + */ +@NotThreadSafe +public class IntersectingPostingList implements PostingList +{ + private final List postingLists; + private final int size; + + // currentPostings state is effectively local to findNextIntersection, but we keep it + // around as a field to avoid repeated allocations there + private final int[] currentRowIds; + + private IntersectingPostingList(List postingLists) + { + assert !postingLists.isEmpty(); + this.postingLists = postingLists; + this.size = postingLists.stream() + .mapToInt(PostingList::size) + .min() + .orElse(0); + this.currentRowIds = new int[postingLists.size()]; + } + + public static PostingList intersect(List postingLists) + { + if (postingLists.size() == 1) + return postingLists.get(0); + + if (postingLists.stream().anyMatch(PostingList::isEmpty)) + return PostingList.EMPTY; + + return new IntersectingPostingList(postingLists); + } + + @Override + public int nextPosting() throws IOException + { + return findNextIntersection(Integer.MIN_VALUE, false); + } + + @Override + public int advance(int targetRowID) throws IOException + { + assert targetRowID >= 0 : targetRowID; + return findNextIntersection(targetRowID, true); + } + + private int findNextIntersection(int targetRowID, boolean isAdvance) throws IOException + { + // Initialize currentRowIds from the underlying posting lists + for (int i = 0; i < postingLists.size(); i++) + { + currentRowIds[i] = isAdvance + ? postingLists.get(i).advance(targetRowID) + : postingLists.get(i).nextPosting(); + + if (currentRowIds[i] == END_OF_STREAM) + return END_OF_STREAM; + } + + while (true) + { + // Find the maximum row ID among all posting lists + int maxRowId = targetRowID; + for (int rowId : currentRowIds) + maxRowId = max(maxRowId, rowId); + + // Advance any posting list that's behind the maximum + boolean allMatch = true; + for (int i = 0; i < postingLists.size(); i++) + { + if (currentRowIds[i] < maxRowId) + { + currentRowIds[i] = postingLists.get(i).advance(maxRowId); + if (currentRowIds[i] == END_OF_STREAM) + return END_OF_STREAM; + allMatch = false; + } + } + + // If all posting lists have the same row ID, we've found an intersection + if (allMatch) + return maxRowId; + + // Otherwise, continue searching with the new maximum as target + targetRowID = maxRowId; + } + } + + @Override + public int size() + { + return size; + } + + @Override + public void close() throws IOException + { + for (PostingList list : postingLists) + list.close(); + } +} + + diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingList.java b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingList.java similarity index 83% rename from src/java/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingList.java rename to src/java/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingList.java index 52155bf6ed59..f43c7c4e8dce 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingList.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingList.java @@ -19,31 +19,29 @@ package org.apache.cassandra.index.sai.disk.v1.postings; import java.io.IOException; +import java.util.function.ToIntFunction; import org.apache.cassandra.index.sai.disk.PostingList; -import org.apache.cassandra.index.sai.utils.RowIdWithMeta; import org.apache.cassandra.utils.CloseableIterator; import org.apache.lucene.util.LongHeap; /** * A posting list for ANN search results. Transforms results from similarity order to rowId order. */ -public class VectorPostingList implements PostingList +public class ReorderingPostingList implements PostingList { private final LongHeap segmentRowIds; private final int size; - public VectorPostingList(CloseableIterator source) + public ReorderingPostingList(CloseableIterator source, ToIntFunction rowIdTransformer) { - // TODO find int specific data structure? segmentRowIds = new LongHeap(32); int n = 0; - // Once the source is consumed, we have to close it. try (source) { while (source.hasNext()) { - segmentRowIds.push(source.next().getSegmentRowId()); + segmentRowIds.push(rowIdTransformer.applyAsInt(source.next())); n++; } } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java index 62748e3332c4..b6c93b368420 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java @@ -49,7 +49,7 @@ import org.apache.cassandra.index.sai.disk.v1.IndexSearcher; import org.apache.cassandra.index.sai.disk.v1.PerIndexFiles; import org.apache.cassandra.index.sai.disk.v1.SegmentMetadata; -import org.apache.cassandra.index.sai.disk.v1.postings.VectorPostingList; +import org.apache.cassandra.index.sai.disk.v1.postings.ReorderingPostingList; import org.apache.cassandra.index.sai.disk.v5.V5VectorPostingsWriter; import org.apache.cassandra.index.sai.disk.vector.BruteForceRowIdIterator; import org.apache.cassandra.index.sai.disk.vector.CassandraDiskAnn; @@ -64,8 +64,8 @@ import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey; import org.apache.cassandra.index.sai.utils.RangeUtil; +import org.apache.cassandra.index.sai.utils.RowIdWithMeta; import org.apache.cassandra.index.sai.utils.RowIdWithScore; -import org.apache.cassandra.index.sai.utils.SegmentOrdering; import org.apache.cassandra.io.sstable.format.SSTableReader; import org.apache.cassandra.metrics.LinearFit; import org.apache.cassandra.metrics.PairedSlidingWindowReservoir; @@ -81,7 +81,7 @@ /** * Executes ann search against the graph for an individual index segment. */ -public class V2VectorIndexSearcher extends IndexSearcher implements SegmentOrdering +public class V2VectorIndexSearcher extends IndexSearcher { private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); private static final VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport(); @@ -151,7 +151,7 @@ private PostingList searchPosting(QueryContext context, Expression exp, Abstract // this is a thresholded query, so pass graph.size() as top k to get all results satisfying the threshold var result = searchInternal(keyRange, context, queryVector, graph.size(), graph.size(), exp.getEuclideanSearchThreshold()); - return new VectorPostingList(result); + return new ReorderingPostingList(result, RowIdWithMeta::getSegmentRowId); } @Override @@ -160,11 +160,11 @@ public CloseableIterator orderBy(Orderer orderer, Express if (logger.isTraceEnabled()) logger.trace(indexContext.logMessage("Searching on expression '{}'..."), orderer); - if (orderer.vector == null) + if (orderer.getVectorTerm() == null) throw new IllegalArgumentException(indexContext.logMessage("Unsupported expression during ANN index query: " + orderer)); int rerankK = indexContext.getIndexWriterConfig().getSourceModel().rerankKFor(limit, graph.getCompression()); - var queryVector = vts.createFloatVector(orderer.vector); + var queryVector = vts.createFloatVector(orderer.getVectorTerm()); var result = searchInternal(keyRange, context, queryVector, limit, rerankK, 0); return toMetaSortedIterator(result, context); @@ -485,14 +485,14 @@ public CloseableIterator orderResultsBy(SSTableReader rea if (cost.shouldUseBruteForce()) { // brute force using the in-memory compressed vectors to cut down the number of results returned - var queryVector = vts.createFloatVector(orderer.vector); + var queryVector = vts.createFloatVector(orderer.getVectorTerm()); return toMetaSortedIterator(this.orderByBruteForce(queryVector, segmentOrdinalPairs, limit, rerankK), context); } // Create bits from the mapping var bits = bitSetForSearch(); segmentOrdinalPairs.forEachRightInt(bits::set); // else ask the index to perform a search limited to the bits we created - var queryVector = vts.createFloatVector(orderer.vector); + var queryVector = vts.createFloatVector(orderer.getVectorTerm()); var results = graph.search(queryVector, limit, rerankK, 0, bits, context, cost::updateStatistics); return toMetaSortedIterator(results, context); } diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java index 25231748149f..cd2f48cda384 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java +++ b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java @@ -222,7 +222,7 @@ public List> orderBy(QueryContext conte assert slice == null : "ANN does not support index slicing"; assert orderer.isANN() : "Only ANN is supported for vector search, received " + orderer.operator; - var qv = vts.createFloatVector(orderer.vector); + var qv = vts.createFloatVector(orderer.getVectorTerm()); return List.of(searchInternal(context, qv, keyRange, limit, 0)); } @@ -310,7 +310,7 @@ public CloseableIterator orderResultsBy(QueryContext cont relevantOrdinals.size(), keys.size(), maxBruteForceRows, graph.size(), limit); // convert the expression value to query vector - var qv = vts.createFloatVector(orderer.vector); + var qv = vts.createFloatVector(orderer.getVectorTerm()); // brute force path if (keysInGraph.size() <= maxBruteForceRows) { diff --git a/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java b/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java index 57517d9192d0..3c7f945a4194 100644 --- a/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java +++ b/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java @@ -662,6 +662,13 @@ public void close() throws IOException } } + /** + * Iterator that provides ordered access to all indexed terms and their associated primary keys + * in the TrieMemoryIndex. For each term in the index, yields PrimaryKeyWithSortKey objects that + * combine a primary key with its associated term. + *

+ * A more verbose name could be KeysMatchingTermsByTermIterator. + */ private class AllTermsIterator extends AbstractIterator { private final Iterator> iterator; diff --git a/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java b/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java index d01715023dbc..63710a28c290 100644 --- a/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java +++ b/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; +import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Objects; @@ -32,22 +33,30 @@ import com.google.common.base.Preconditions; import com.google.common.util.concurrent.Runnables; +import org.apache.cassandra.cql3.Operator; import org.apache.cassandra.db.Clustering; +import org.apache.cassandra.db.DataRange; import org.apache.cassandra.db.DecoratedKey; import org.apache.cassandra.db.PartitionPosition; +import org.apache.cassandra.db.RegularAndStaticColumns; +import org.apache.cassandra.db.filter.ColumnFilter; import org.apache.cassandra.db.marshal.AbstractType; import org.apache.cassandra.db.memtable.Memtable; import org.apache.cassandra.db.memtable.ShardBoundaries; import org.apache.cassandra.db.memtable.TrieMemtable; +import org.apache.cassandra.db.rows.Row; import org.apache.cassandra.dht.AbstractBounds; +import org.apache.cassandra.dht.Bounds; import org.apache.cassandra.index.sai.IndexContext; import org.apache.cassandra.index.sai.QueryContext; import org.apache.cassandra.index.sai.disk.format.Version; -import org.apache.cassandra.index.sai.iterators.KeyRangeLazyIterator; import org.apache.cassandra.index.sai.iterators.KeyRangeConcatIterator; +import org.apache.cassandra.index.sai.iterators.KeyRangeIntersectionIterator; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; +import org.apache.cassandra.index.sai.iterators.KeyRangeLazyIterator; import org.apache.cassandra.index.sai.plan.Expression; import org.apache.cassandra.index.sai.plan.Orderer; +import org.apache.cassandra.index.sai.utils.BM25Utils; import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithByteComparable; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey; @@ -237,15 +246,45 @@ public List> orderBy(QueryContext query int startShard = boundaries.getShardForToken(keyRange.left.getToken()); int endShard = keyRange.right.isMinimum() ? boundaries.shardCount() - 1 : boundaries.getShardForToken(keyRange.right.getToken()); - var iterators = new ArrayList>(endShard - startShard + 1); - - for (int shard = startShard; shard <= endShard; ++shard) + if (!orderer.isBM25()) { - assert rangeIndexes[shard] != null; - iterators.add(rangeIndexes[shard].orderBy(orderer, slice)); + var iterators = new ArrayList>(endShard - startShard + 1); + for (int shard = startShard; shard <= endShard; ++shard) + { + assert rangeIndexes[shard] != null; + iterators.add(rangeIndexes[shard].orderBy(orderer, slice)); + } + return iterators; } - return iterators; + // BM25 + var queryTerms = orderer.extractQueryTerms(); + + // Intersect iterators to find documents containing all terms + var termIterators = keyIteratorsPerTerm(queryContext, keyRange, queryTerms); + var intersectedIterator = KeyRangeIntersectionIterator.builder(termIterators).build(); + + // Compute BM25 scores + var docStats = computeDocumentFrequencies(queryContext, queryTerms); + return List.of(BM25Utils.computeScores(intersectedIterator, + queryTerms, + docStats, + indexContext, + memtable, + this::getCellForKey)); + } + + private List keyIteratorsPerTerm(QueryContext queryContext, AbstractBounds keyRange, List queryTerms) + { + List termIterators = new ArrayList<>(queryTerms.size()); + for (ByteBuffer term : queryTerms) + { + Expression expr = new Expression(indexContext); + expr.add(Operator.ANALYZER_MATCHES, term); + KeyRangeIterator iterator = search(queryContext, expr, keyRange, Integer.MAX_VALUE); + termIterators.add(iterator); + } + return termIterators; } @Override @@ -257,32 +296,98 @@ public long estimateMatchingRowsCount(Expression expression, AbstractBounds orderResultsBy(QueryContext context, List keys, Orderer orderer, int limit) + public CloseableIterator orderResultsBy(QueryContext queryContext, List keys, Orderer orderer, int limit) { if (keys.isEmpty()) return CloseableIterator.emptyIterator(); - return SortingIterator.createCloseable( - orderer.getComparator(), - keys, - key -> + + if (!orderer.isBM25()) + { + return SortingIterator.createCloseable( + orderer.getComparator(), + keys, + key -> + { + var partition = memtable.getPartition(key.partitionKey()); + if (partition == null) + return null; + var row = partition.getRow(key.clustering()); + if (row == null) + return null; + var cell = row.getCell(indexContext.getDefinition()); + if (cell == null) + return null; + + // We do two kinds of encoding... it'd be great to make this more straight forward, but this is what + // we have for now. I leave it to the reader to inspect the two methods to see the nuanced differences. + var encoding = encode(TypeUtil.encode(cell.buffer(), validator)); + return new PrimaryKeyWithByteComparable(indexContext, memtable, key, encoding); + }, + Runnables.doNothing() + ); + } + + // BM25 + var queryTerms = orderer.extractQueryTerms(); + var docStats = computeDocumentFrequencies(queryContext, queryTerms); + return BM25Utils.computeScores(keys.iterator(), + queryTerms, + docStats, + indexContext, + memtable, + this::getCellForKey); + } + + /** + * Count document frequencies for each term using brute force + */ + private BM25Utils.DocStats computeDocumentFrequencies(QueryContext queryContext, List queryTerms) + { + var termIterators = keyIteratorsPerTerm(queryContext, Bounds.unbounded(indexContext.getPartitioner()), queryTerms); + var documentFrequencies = new HashMap(); + for (int i = 0; i < queryTerms.size(); i++) + { + // KeyRangeIterator.getMaxKeys is not accurate enough, we have to count them + long keys = 0; + for (var it = termIterators.get(i); it.hasNext(); it.next()) + keys++; + documentFrequencies.put(queryTerms.get(i), keys); + } + long docCount = 0; + + try (var it = memtable.makePartitionIterator(ColumnFilter.selection(RegularAndStaticColumns.of(indexContext.getDefinition())), + DataRange.allData(memtable.metadata().partitioner))) + { + while (it.hasNext()) { - var partition = memtable.getPartition(key.partitionKey()); - if (partition == null) - return null; - var row = partition.getRow(key.clustering()); - if (row == null) - return null; - var cell = row.getCell(indexContext.getDefinition()); - if (cell == null) - return null; - - // We do two kinds of encoding... it'd be great to make this more straight forward, but this is what - // we have for now. I leave it to the reader to inspect the two methods to see the nuanced differences. - var encoding = encode(TypeUtil.encode(cell.buffer(), validator)); - return new PrimaryKeyWithByteComparable(indexContext, memtable, key, encoding); - }, - Runnables.doNothing() - ); + var partitions = it.next(); + while (partitions.hasNext()) + { + var unfiltered = partitions.next(); + if (!unfiltered.isRow()) + continue; + var row = (Row) unfiltered; + var cell = row.getCell(indexContext.getDefinition()); + if (cell == null) + continue; + + docCount++; + } + } + } + return new BM25Utils.DocStats(documentFrequencies, docCount); + } + + @Nullable + private org.apache.cassandra.db.rows.Cell getCellForKey(PrimaryKey key) + { + var partition = memtable.getPartition(key.partitionKey()); + if (partition == null) + return null; + var row = partition.getRow(key.clustering()); + if (row == null) + return null; + return row.getCell(indexContext.getDefinition()); } private ByteComparable encode(ByteBuffer input) diff --git a/src/java/org/apache/cassandra/index/sai/plan/Expression.java b/src/java/org/apache/cassandra/index/sai/plan/Expression.java index a1bd9acddc4b..aac8e240829b 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Expression.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Expression.java @@ -101,6 +101,7 @@ public static Op valueOf(Operator operator) return IN; case ANN: + case BM25: case ORDER_BY_ASC: case ORDER_BY_DESC: return ORDER_BY; @@ -250,6 +251,7 @@ public Expression add(Operator op, ByteBuffer value) boundedAnnEuclideanDistanceThreshold = GeoUtil.amplifiedEuclideanSimilarityThreshold(lower.value.vector, searchRadiusMeters); break; case ANN: + case BM25: case ORDER_BY_ASC: case ORDER_BY_DESC: // If we alread have an operation on the column, we don't need to set the ORDER_BY op because diff --git a/src/java/org/apache/cassandra/index/sai/plan/Operation.java b/src/java/org/apache/cassandra/index/sai/plan/Operation.java index abc9735ce510..71cd781def91 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Operation.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Operation.java @@ -93,7 +93,7 @@ protected static ListMultimap analyzeGroup(QueryCont analyzer.reset(e.getIndexValue()); // EQ/LIKE_*/NOT_EQ can have multiple expressions e.g. text = "Hello World", - // becomes text = "Hello" OR text = "World" because "space" is always interpreted as a split point (by analyzer), + // becomes text = "Hello" AND text = "World" because "space" is always interpreted as a split point (by analyzer), // CONTAINS/CONTAINS_KEY are always treated as multiple expressions since they currently only targetting // collections, NOT_EQ is made an independent expression only in case of pre-existing multiple EQ expressions, or // if there is no EQ operations and NOT_EQ is met or a single NOT_EQ expression present, @@ -102,6 +102,7 @@ protected static ListMultimap analyzeGroup(QueryCont boolean isMultiExpression = columnIsMultiExpression.getOrDefault(e.column(), Boolean.FALSE); switch (e.operator()) { + // case BM25: leave it at the default of `false` case EQ: // EQ operator will always be a multiple expression because it is being used by map entries isMultiExpression = indexContext.isNonFrozenCollection(); diff --git a/src/java/org/apache/cassandra/index/sai/plan/Orderer.java b/src/java/org/apache/cassandra/index/sai/plan/Orderer.java index e23c83ed91e2..6193e0a37110 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Orderer.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Orderer.java @@ -19,6 +19,7 @@ package org.apache.cassandra.index.sai.plan; import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; import java.util.EnumSet; @@ -41,12 +42,14 @@ public class Orderer { // The list of operators that are valid for order by clauses. static final EnumSet ORDER_BY_OPERATORS = EnumSet.of(Operator.ANN, + Operator.BM25, Operator.ORDER_BY_ASC, Operator.ORDER_BY_DESC); public final IndexContext context; public final Operator operator; - public final float[] vector; + public final ByteBuffer term; + private float[] vector; /** * Create an orderer for the given index context, operator, and term. @@ -59,7 +62,7 @@ public Orderer(IndexContext context, Operator operator, ByteBuffer term) this.context = context; assert ORDER_BY_OPERATORS.contains(operator) : "Invalid operator for order by clause " + operator; this.operator = operator; - this.vector = context.getValidator().isVector() ? TypeUtil.decomposeVector(context.getValidator(), term) : null; + this.term = term; } public String getIndexName() @@ -75,8 +78,8 @@ public boolean isAscending() public Comparator getComparator() { - // ANN's PrimaryKeyWithSortKey is always descending, so we use the natural order for the priority queue - return isAscending() || isANN() ? Comparator.naturalOrder() : Comparator.reverseOrder(); + // ANN/BM25's PrimaryKeyWithSortKey is always descending, so we use the natural order for the priority queue + return (isAscending() || isANN() || isBM25()) ? Comparator.naturalOrder() : Comparator.reverseOrder(); } public boolean isLiteral() @@ -89,6 +92,11 @@ public boolean isANN() return operator == Operator.ANN; } + public boolean isBM25() + { + return operator == Operator.BM25; + } + @Nullable public static Orderer from(SecondaryIndexManager indexManager, RowFilter filter) { @@ -110,8 +118,34 @@ public static boolean isFilterExpressionOrderer(RowFilter.Expression expression) public String toString() { String direction = isAscending() ? "ASC" : "DESC"; - return isANN() - ? context.getColumnName() + " ANN OF " + Arrays.toString(vector) + ' ' + direction - : context.getColumnName() + ' ' + direction; + if (isANN()) + return context.getColumnName() + " ANN OF " + Arrays.toString(getVectorTerm()) + ' ' + direction; + if (isBM25()) + return context.getColumnName() + " BM25 OF " + TypeUtil.getString(term, context.getValidator()) + ' ' + direction; + return context.getColumnName() + ' ' + direction; + } + + public float[] getVectorTerm() + { + if (vector == null) + vector = TypeUtil.decomposeVector(context.getValidator(), term); + return vector; + } + + public ArrayList extractQueryTerms() + { + var queryAnalyzer = context.getQueryAnalyzerFactory().create(); + // Split query into terms + var queryTerms = new ArrayList(); + queryAnalyzer.reset(term); + try + { + queryAnalyzer.forEachRemaining(queryTerms::add); + } + finally + { + queryAnalyzer.end(); + } + return queryTerms; } } diff --git a/src/java/org/apache/cassandra/index/sai/plan/Plan.java b/src/java/org/apache/cassandra/index/sai/plan/Plan.java index 2bb379a16f14..46baa4f59a87 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Plan.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Plan.java @@ -1243,10 +1243,12 @@ protected double estimateSelectivity() @Override protected KeysIterationCost estimateCost() { - return ordering.isANN() - ? estimateAnnSortCost() - : estimateGlobalSortCost(); - + if (ordering.isANN()) + return estimateAnnSortCost(); + else if (ordering.isBM25()) + return estimateBm25SortCost(); + else + return estimateGlobalSortCost(); } private KeysIterationCost estimateAnnSortCost() @@ -1263,6 +1265,21 @@ private KeysIterationCost estimateAnnSortCost() return new KeysIterationCost(expectedKeys, initCost, searchCost); } + private KeysIterationCost estimateBm25SortCost() + { + double expectedKeys = access.expectedAccessCount(source.expectedKeys()); + + int termCount = ordering.extractQueryTerms().size(); + // all of the cost for BM25 is up front since the index doesn't give us the information we need + // to return results in order, in isolation. The big cost is reading the indexed cells out of + // the sstables. + // VSTODO if we had stats on cell size _per column_ we could usefully include ROW_BYTE_COST + double initCost = source.fullCost() + + source.expectedKeys() * (hrs(ROW_CELL_COST) + ROW_CELL_COST) + + termCount * BM25_SCORE_COST; + return new KeysIterationCost(expectedKeys, initCost, 0); + } + private KeysIterationCost estimateGlobalSortCost() { return new KeysIterationCost(source.expectedKeys(), @@ -1296,36 +1313,30 @@ protected KeysSort withAccess(Access access) } /** - * Returns all keys in ANN order. - * Contrary to {@link KeysSort}, there is no input node here and the output is generated lazily. + * Base class for index scans that return results in a computed order (ANN, BM25) + * rather than the natural index order. */ - final static class AnnIndexScan extends Leaf + abstract static class ComputedOrderIndexScan extends Leaf { final Orderer ordering; - protected AnnIndexScan(Factory factory, int id, Access access, Orderer ordering) + protected ComputedOrderIndexScan(Factory factory, int id, Access access, Orderer ordering) { super(factory, id, access); this.ordering = ordering; } + @Nullable @Override - protected KeysIterationCost estimateCost() + protected Orderer ordering() { - double expectedKeys = access.expectedAccessCount(factory.tableMetrics.rows); - int expectedKeysInt = Math.max(1, (int) Math.ceil(expectedKeys)); - double searchCost = factory.costEstimator.estimateAnnSearchCost(ordering, - expectedKeysInt, - factory.tableMetrics.rows); - double initCost = 0; // negligible - return new KeysIterationCost(expectedKeys, initCost, searchCost); + return ordering; } - @Nullable @Override - protected Orderer ordering() + protected double estimateSelectivity() { - return ordering; + return 1.0; } @Override @@ -1334,6 +1345,30 @@ protected Iterator execute(Executor executor) int softLimit = max(1, round((float) access.expectedAccessCount(factory.tableMetrics.rows))); return executor.getTopKRows((Expression) null, softLimit); } + } + + /** + * Returns all keys in ANN order. + * Contrary to {@link KeysSort}, there is no input node here and the output is generated lazily. + */ + final static class AnnIndexScan extends ComputedOrderIndexScan + { + protected AnnIndexScan(Factory factory, int id, Access access, Orderer ordering) + { + super(factory, id, access, ordering); + } + + @Override + protected KeysIterationCost estimateCost() + { + double expectedKeys = access.expectedAccessCount(factory.tableMetrics.rows); + int expectedKeysInt = Math.max(1, (int) Math.ceil(expectedKeys)); + double searchCost = factory.costEstimator.estimateAnnSearchCost(ordering, + expectedKeysInt, + factory.tableMetrics.rows); + double initCost = 0; // negligible + return new KeysIterationCost(expectedKeys, initCost, searchCost); + } @Override protected KeysIteration withAccess(Access access) @@ -1342,11 +1377,39 @@ protected KeysIteration withAccess(Access access) ? this : new AnnIndexScan(factory, id, access, ordering); } + } + /** + * Returns all keys in BM25 order. + * Like AnnIndexScan, this generates results lazily without an input node. + */ + final static class Bm25IndexScan extends ComputedOrderIndexScan + { + protected Bm25IndexScan(Factory factory, int id, Access access, Orderer ordering) + { + super(factory, id, access, ordering); + } + + @Nonnull @Override - protected double estimateSelectivity() + protected KeysIterationCost estimateCost() { - return 1.0; + double expectedKeys = access.expectedAccessCount(factory.tableMetrics.rows); + int expectedKeysInt = Math.max(1, (int) Math.ceil(expectedKeys)); + + int termCount = ordering.extractQueryTerms().size(); + double initCost = expectedKeysInt * (hrs(ROW_CELL_COST) + ROW_CELL_COST) + + termCount * BM25_SCORE_COST; + + return new KeysIterationCost(expectedKeys, initCost, 0); + } + + @Override + protected KeysIteration withAccess(Access access) + { + return Objects.equals(access, this.access) + ? this + : new Bm25IndexScan(factory, id, access, ordering); } } @@ -1664,6 +1727,8 @@ private KeysIteration indexScan(Expression predicate, long matchingKeysCount, Or if (ordering != null) if (ordering.isANN()) return new AnnIndexScan(this, id, defaultAccess, ordering); + else if (ordering.isBM25()) + return new Bm25IndexScan(this, id, defaultAccess, ordering); else if (ordering.isLiteral()) return new LiteralIndexScan(this, id, predicate, matchingKeysCount, defaultAccess, ordering); else @@ -1911,6 +1976,9 @@ public static class CostCoefficients /** Additional cost added to row fetch cost per each serialized byte of the row */ public final static double ROW_BYTE_COST = 0.005; + + /** Cost to perform BM25 scoring, per query term */ + public final static double BM25_SCORE_COST = 0.5; } /** Convenience builder for building intersection and union nodes */ diff --git a/src/java/org/apache/cassandra/index/sai/plan/QueryController.java b/src/java/org/apache/cassandra/index/sai/plan/QueryController.java index 9a8ee4ecff00..fbe4430efc6c 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/QueryController.java +++ b/src/java/org/apache/cassandra/index/sai/plan/QueryController.java @@ -809,6 +809,7 @@ private long estimateMatchingRowCount(Expression predicate) switch (predicate.getOp()) { case EQ: + case MATCH: case CONTAINS_KEY: case CONTAINS_VALUE: case NOT_EQ: diff --git a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java index 7b196212a95f..bfde594fadfc 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java @@ -110,19 +110,17 @@ public UnfilteredPartitionIterator search(ReadExecutionController executionContr // Can't check for `command.isTopK()` because the planner could optimize sorting out Orderer ordering = plan.ordering(); - if (ordering != null) - { - assert !(keysIterator instanceof KeyRangeIterator); - var scoredKeysIterator = (CloseableIterator) keysIterator; - var result = new ScoreOrderedResultRetriever(scoredKeysIterator, filterTree, controller, - executionController, queryContext, command.limits().count()); - return (UnfilteredPartitionIterator) new TopKProcessor(command).filter(result); - } - else + if (ordering == null) { assert keysIterator instanceof KeyRangeIterator; return new ResultRetriever((KeyRangeIterator) keysIterator, filterTree, controller, executionController, queryContext); } + + assert !(keysIterator instanceof KeyRangeIterator); + var scoredKeysIterator = (CloseableIterator) keysIterator; + var result = new ScoreOrderedResultRetriever(scoredKeysIterator, filterTree, controller, + executionController, queryContext, command.limits().count()); + return (UnfilteredPartitionIterator) new TopKProcessor(command).filter(result); } catch (Throwable t) { @@ -501,26 +499,27 @@ public UnfilteredRowIterator computeNext() */ private void fillPendingRows() { + // Group PKs by source sstable/memtable + var groupedKeys = new HashMap>(); // We always want to get at least 1. int rowsToRetrieve = Math.max(1, softLimit - returnedRowCount); - var keys = new HashMap>(); // We want to get the first unique `rowsToRetrieve` keys to materialize // Don't pass the priority queue here because it is more efficient to add keys in bulk - fillKeys(keys, rowsToRetrieve, null); + fillKeys(groupedKeys, rowsToRetrieve, null); // Sort the primary keys by PrK order, just in case that helps with cache and disk efficiency - var primaryKeyPriorityQueue = new PriorityQueue<>(keys.keySet()); + var primaryKeyPriorityQueue = new PriorityQueue<>(groupedKeys.keySet()); - while (!keys.isEmpty()) + while (!groupedKeys.isEmpty()) { - var primaryKey = primaryKeyPriorityQueue.poll(); - var primaryKeyWithSortKeys = keys.remove(primaryKey); - var partitionIterator = readAndValidatePartition(primaryKey, primaryKeyWithSortKeys); + var pk = primaryKeyPriorityQueue.poll(); + var sourceKeys = groupedKeys.remove(pk); + var partitionIterator = readAndValidatePartition(pk, sourceKeys); if (partitionIterator != null) pendingRows.add(partitionIterator); else // The current primaryKey did not produce a partition iterator. We know the caller will need // `rowsToRetrieve` rows, so we get the next unique key and add it to the queue. - fillKeys(keys, 1, primaryKeyPriorityQueue); + fillKeys(groupedKeys, 1, primaryKeyPriorityQueue); } } @@ -528,21 +527,21 @@ private void fillPendingRows() * Fills the keys map with the next `count` unique primary keys that are in the keys produced by calling * {@link #nextSelectedKeyInRange()}. We map PrimaryKey to List because the same * primary key can be in the result set multiple times, but with different source tables. - * @param keys the map to fill + * @param groupedKeys the map to fill * @param count the number of unique PrimaryKeys to consume from the iterator * @param primaryKeyPriorityQueue the priority queue to add new keys to. If the queue is null, we do not add * keys to the queue. */ - private void fillKeys(Map> keys, int count, PriorityQueue primaryKeyPriorityQueue) + private void fillKeys(Map> groupedKeys, int count, PriorityQueue primaryKeyPriorityQueue) { - int initialSize = keys.size(); - while (keys.size() - initialSize < count) + int initialSize = groupedKeys.size(); + while (groupedKeys.size() - initialSize < count) { var primaryKeyWithSortKey = nextSelectedKeyInRange(); if (primaryKeyWithSortKey == null) return; var nextPrimaryKey = primaryKeyWithSortKey.primaryKey(); - var accumulator = keys.computeIfAbsent(nextPrimaryKey, k -> new ArrayList<>()); + var accumulator = groupedKeys.computeIfAbsent(nextPrimaryKey, k -> new ArrayList<>()); if (primaryKeyPriorityQueue != null && accumulator.isEmpty()) primaryKeyPriorityQueue.add(nextPrimaryKey); accumulator.add(primaryKeyWithSortKey); @@ -582,15 +581,29 @@ private boolean isInRange(DecoratedKey key) return null; } - public UnfilteredRowIterator readAndValidatePartition(PrimaryKey key, List primaryKeys) + /** + * Reads and validates a partition for a given primary key against its sources. + *

+ * @param pk The primary key of the partition to read and validate + * @param sourceKeys A list of PrimaryKeyWithSortKey objects associated with the primary key. + * Multiple sort keys can exist for the same primary key when data comes from different + * sstables or memtables. + * + * @return An UnfilteredRowIterator containing the validated partition data, or null if: + * - The key has already been processed + * - The partition does not pass index filters + * - The partition contains no valid rows + * - The row data does not match the index metadata for any of the provided primary keys + */ + public UnfilteredRowIterator readAndValidatePartition(PrimaryKey pk, List sourceKeys) { // If we've already processed the key, we can skip it. Because the score ordered iterator does not // deduplicate rows, we could see dupes if a row is in the ordering index multiple times. This happens // in the case of dupes and of overwrites. - if (processedKeys.contains(key)) + if (processedKeys.contains(pk)) return null; - try (UnfilteredRowIterator partition = controller.getPartition(key, view, executionController)) + try (UnfilteredRowIterator partition = controller.getPartition(pk, view, executionController)) { queryContext.addPartitionsRead(1); queryContext.checkpoint(); @@ -599,7 +612,7 @@ public UnfilteredRowIterator readAndValidatePartition(PrimaryKey key, List queryVector; + private final ColumnMetadata scoreColumn; private final int limit; @@ -103,16 +107,17 @@ public TopKProcessor(ReadCommand command) { this.command = command; - Pair annIndexAndExpression = findTopKIndexContext(); - Preconditions.checkNotNull(annIndexAndExpression); + Pair indexAndExpression = findTopKIndexContext(); + Preconditions.checkNotNull(indexAndExpression); - this.indexContext = annIndexAndExpression.left; - this.expression = annIndexAndExpression.right; + this.indexContext = indexAndExpression.left; + this.expression = indexAndExpression.right; if (expression.operator() == Operator.ANN) this.queryVector = vts.createFloatVector(TypeUtil.decomposeVector(indexContext, expression.getIndexValue().duplicate())); else this.queryVector = null; this.limit = command.limits().count(); + this.scoreColumn = ColumnMetadata.syntheticColumn(indexContext.getKeyspace(), indexContext.getTable(), ColumnMetadata.SYNTHETIC_SCORE_ID, FloatType.instance); } /** @@ -158,8 +163,11 @@ private , P extends BaseParti { // priority queue ordered by score in descending order Comparator> comparator; - if (queryVector != null) + // TODO does this work for complex expressions? + if (expression.operator() == Operator.ANN || expression.operator() == Operator.BM25) + { comparator = Comparator.comparing((Triple t) -> (Float) t.getRight()).reversed(); + } else { comparator = Comparator.comparing(t -> (ByteBuffer) t.getRight(), indexContext.getValidator()); @@ -187,7 +195,7 @@ private , P extends BaseParti executor.maybeExecuteImmediately(() -> { try (var partitionRowIterator = pIter.commandToIterator(command.left(), command.right())) { - future.complete(partitionRowIterator == null ? null : processPartition(partitionRowIterator)); + future.complete(partitionRowIterator == null ? null : processScoredPartition(partitionRowIterator)); } catch (Throwable t) { @@ -235,9 +243,9 @@ private , P extends BaseParti // have to close to move to the next partition, otherwise hasNext() fails try (var partitionRowIterator = partitions.next()) { - if (queryVector != null) + if (expression.operator() == Operator.ANN || expression.operator() == Operator.BM25) { - PartitionResults pr = processPartition(partitionRowIterator); + PartitionResults pr = processScoredPartition(partitionRowIterator); topK.addAll(pr.rows); for (var uf: pr.tombstones) addUnfiltered(unfilteredByPartition, pr.partitionInfo, uf); @@ -250,7 +258,6 @@ private , P extends BaseParti topK.add(Triple.of(PartitionInfo.create(partitionRowIterator), row, row.getCell(expression.column()).buffer())); } } - } } } @@ -286,7 +293,7 @@ void addRow(Triple triple) { /** * Processes a single partition, calculating scores for rows and extracting tombstones. */ - private PartitionResults processPartition(BaseRowIterator partitionRowIterator) { + private PartitionResults processScoredPartition(BaseRowIterator partitionRowIterator) { // Compute key and static row score once per partition DecoratedKey key = partitionRowIterator.partitionKey(); Row staticRow = partitionRowIterator.staticRow(); @@ -352,6 +359,14 @@ private float getScoreForRow(DecoratedKey key, Row row) if ((column.isClusteringColumn() || column.isRegular()) && row.isStatic()) return 0; + var scoreData = row.getColumnData(scoreColumn); + if (scoreData != null) + { + var cell = (Cell) scoreData; + return FloatType.instance.compose(cell.buffer()); + } + + // TODO remove this once we enable the scored path for vector queries ByteBuffer value = indexContext.getValueOf(key, row, FBUtilities.nowInSeconds()); if (value != null) { @@ -368,21 +383,24 @@ private Pair findTopKIndexContext() for (RowFilter.Expression expression : command.rowFilter().getExpressions()) { - StorageAttachedIndex sai = findVectorIndexFor(cfs.indexManager, expression); + StorageAttachedIndex sai = findOrderingIndexFor(cfs.indexManager, expression); if (sai != null) - { return Pair.create(sai.getIndexContext(), expression); - } } return null; } @Nullable - private StorageAttachedIndex findVectorIndexFor(SecondaryIndexManager sim, RowFilter.Expression e) + private StorageAttachedIndex findOrderingIndexFor(SecondaryIndexManager sim, RowFilter.Expression e) { - if (e.operator() != Operator.ANN && e.operator() != Operator.ORDER_BY_ASC && e.operator() != Operator.ORDER_BY_DESC) + if (e.operator() != Operator.ANN + && e.operator() != Operator.BM25 + && e.operator() != Operator.ORDER_BY_ASC + && e.operator() != Operator.ORDER_BY_DESC) + { return null; + } Optional index = sim.getBestIndexFor(e); return (StorageAttachedIndex) index.filter(i -> i instanceof StorageAttachedIndex).orElse(null); diff --git a/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java b/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java new file mode 100644 index 000000000000..044997e6508b --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.utils; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import org.apache.cassandra.db.memtable.Memtable; +import org.apache.cassandra.db.rows.Cell; +import org.apache.cassandra.index.sai.IndexContext; +import org.apache.cassandra.index.sai.analyzer.AbstractAnalyzer; +import org.apache.cassandra.io.sstable.SSTableId; +import org.apache.cassandra.utils.CloseableIterator; + +public class BM25Utils +{ + private static final float K1 = 1.2f; // BM25 term frequency saturation parameter + private static final float B = 0.75f; // BM25 length normalization parameter + + /** + * Term frequencies across all documents. Each document is only counted once. + */ + public static class DocStats + { + // Map of term -> count of docs containing that term + private final Map frequencies; + // total number of docs in the index + private final long docCount; + + public DocStats(Map frequencies, long docCount) + { + this.frequencies = frequencies; + this.docCount = docCount; + } + } + + /** + * Term frequencies within a single document. All instances of a term are counted. + */ + public static class DocTF + { + private final PrimaryKey pk; + private final Map frequencies; + private final float termCount; + + private DocTF(PrimaryKey pk, float termCount, Map frequencies) + { + this.pk = pk; + this.frequencies = frequencies; + this.termCount = termCount; + } + + public int getTermFrequency(ByteBuffer term) + { + return frequencies.getOrDefault(term, 0); + } + + public static DocTF createFromDocument(PrimaryKey pk, + Cell cell, + AbstractAnalyzer docAnalyzer, + Collection queryTerms) + { + float count = 0; + Map frequencies = new HashMap<>(); + + docAnalyzer.reset(cell.buffer()); + try + { + while (docAnalyzer.hasNext()) + { + ByteBuffer term = docAnalyzer.next(); + count++; + if (queryTerms.contains(term)) + frequencies.merge(term, 1, Integer::sum); + } + } + finally + { + docAnalyzer.end(); + } + + return new DocTF(pk, count, frequencies); + } + } + + @FunctionalInterface + public interface CellReader + { + Cell readCell(PrimaryKey pk); + } + + public static CloseableIterator computeScores(Iterator keyIterator, + List queryTerms, + DocStats docStats, + IndexContext indexContext, + Object source, + CellReader cellReader) + { + var docAnalyzer = indexContext.getAnalyzerFactory().create(); + + // data structures for document stats and frequencies + ArrayList documents = new ArrayList<>(); + double totalTermCount = 0; + + // Compute TF within each document + while (keyIterator.hasNext()) + { + var pk = keyIterator.next(); + var cell = cellReader.readCell(pk); + if (cell == null) + continue; + var tf = DocTF.createFromDocument(pk, cell, docAnalyzer, queryTerms); + + // sstable index will only send documents that contain all query terms to this method, + // but memtable is not indexed and will send all documents, so we have to skip documents + // that don't contain all query terms here to preserve consistency with sstable behavior + if (tf.frequencies.size() != queryTerms.size()) + continue; + + documents.add(tf); + + totalTermCount += tf.termCount; + } + + // Calculate average document length + double avgDocLength = documents.size() > 0 ? totalTermCount / documents.size() : 0.0; + + // Calculate BM25 scores + var scoredDocs = new ArrayList(documents.size()); + for (var doc : documents) + { + double score = 0.0; + for (var queryTerm : queryTerms) + { + int tf = doc.getTermFrequency(queryTerm); + Long df = docStats.frequencies.get(queryTerm); + double normalizedTf = tf / (tf + K1 * (1 - B + B * doc.termCount / avgDocLength)); + double idf = Math.log(1 + (docStats.docCount - df + 0.5) / (df + 0.5)); + double deltaScore = normalizedTf * idf; + assert deltaScore >= 0 : String.format("BM25 score for tf=%d, df=%d, totalDocs=%d is %f", tf, df, docStats.docCount, deltaScore); + score += deltaScore; + } + if (source instanceof Memtable) + scoredDocs.add(new PrimaryKeyWithScore(indexContext, (Memtable) source, doc.pk, (float) score)); + else if (source instanceof SSTableId) + scoredDocs.add(new PrimaryKeyWithScore(indexContext, (SSTableId) source, doc.pk, (float) score)); + else + throw new IllegalArgumentException("Invalid source " + source.getClass()); + } + + // sort by score + scoredDocs.sort(Comparator.comparingDouble((PrimaryKeyWithScore pkws) -> pkws.indexScore).reversed()); + + return (CloseableIterator) (CloseableIterator) CloseableIterator.wrap(scoredDocs.iterator()); + } +} diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java index 837c032b7952..949eb21d282c 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java +++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java @@ -21,7 +21,9 @@ import java.nio.ByteBuffer; import java.util.Arrays; +import org.apache.cassandra.db.memtable.Memtable; import org.apache.cassandra.index.sai.IndexContext; +import org.apache.cassandra.io.sstable.SSTableId; import org.apache.cassandra.utils.bytecomparable.ByteComparable; import org.apache.cassandra.utils.bytecomparable.ByteSource; import org.apache.cassandra.utils.bytecomparable.ByteSourceInverse; @@ -34,7 +36,13 @@ public class PrimaryKeyWithByteComparable extends PrimaryKeyWithSortKey { private final ByteComparable byteComparable; - public PrimaryKeyWithByteComparable(IndexContext context, Object sourceTable, PrimaryKey primaryKey, ByteComparable byteComparable) + public PrimaryKeyWithByteComparable(IndexContext context, Memtable sourceTable, PrimaryKey primaryKey, ByteComparable byteComparable) + { + super(context, sourceTable, primaryKey); + this.byteComparable = byteComparable; + } + + public PrimaryKeyWithByteComparable(IndexContext context, SSTableId sourceTable, PrimaryKey primaryKey, ByteComparable byteComparable) { super(context, sourceTable, primaryKey); this.byteComparable = byteComparable; diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java index 61e63d43c7fa..a10d6e82549a 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java +++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java @@ -20,7 +20,9 @@ import java.nio.ByteBuffer; +import org.apache.cassandra.db.memtable.Memtable; import org.apache.cassandra.index.sai.IndexContext; +import org.apache.cassandra.io.sstable.SSTableId; /** * A {@link PrimaryKey} that includes a score from a source index. @@ -30,7 +32,13 @@ public class PrimaryKeyWithScore extends PrimaryKeyWithSortKey { public final float indexScore; - public PrimaryKeyWithScore(IndexContext context, Object source, PrimaryKey primaryKey, float indexScore) + public PrimaryKeyWithScore(IndexContext context, Memtable source, PrimaryKey primaryKey, float indexScore) + { + super(context, source, primaryKey); + this.indexScore = indexScore; + } + + public PrimaryKeyWithScore(IndexContext context, SSTableId source, PrimaryKey primaryKey, float indexScore) { super(context, source, primaryKey); this.indexScore = indexScore; diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java index 2e79b0402124..8a171fca4dbe 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java +++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java @@ -22,9 +22,11 @@ import org.apache.cassandra.db.Clustering; import org.apache.cassandra.db.DecoratedKey; +import org.apache.cassandra.db.memtable.Memtable; import org.apache.cassandra.db.rows.Row; import org.apache.cassandra.dht.Token; import org.apache.cassandra.index.sai.IndexContext; +import org.apache.cassandra.io.sstable.SSTableId; import org.apache.cassandra.utils.bytecomparable.ByteComparable; import org.apache.cassandra.utils.bytecomparable.ByteSource; @@ -41,7 +43,14 @@ public abstract class PrimaryKeyWithSortKey implements PrimaryKey // Either a Memtable reference or an SSTableId reference private final Object sourceTable; - protected PrimaryKeyWithSortKey(IndexContext context, Object sourceTable, PrimaryKey primaryKey) + protected PrimaryKeyWithSortKey(IndexContext context, Memtable sourceTable, PrimaryKey primaryKey) + { + this.context = context; + this.sourceTable = sourceTable; + this.primaryKey = primaryKey; + } + + protected PrimaryKeyWithSortKey(IndexContext context, SSTableId sourceTable, PrimaryKey primaryKey) { this.context = context; this.sourceTable = sourceTable; diff --git a/src/java/org/apache/cassandra/index/sai/utils/RowIdWithScore.java b/src/java/org/apache/cassandra/index/sai/utils/RowIdWithScore.java index b4b1f129567f..c6ed4708e0c8 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/RowIdWithScore.java +++ b/src/java/org/apache/cassandra/index/sai/utils/RowIdWithScore.java @@ -26,7 +26,7 @@ */ public class RowIdWithScore extends RowIdWithMeta { - private final float score; + public final float score; public RowIdWithScore(int segmentRowId, float score) { diff --git a/test/distributed/org/apache/cassandra/distributed/test/sai/BM25DistributedTest.java b/test/distributed/org/apache/cassandra/distributed/test/sai/BM25DistributedTest.java new file mode 100644 index 000000000000..8feb61f36b83 --- /dev/null +++ b/test/distributed/org/apache/cassandra/distributed/test/sai/BM25DistributedTest.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.distributed.test.sai; + +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.distributed.Cluster; +import org.apache.cassandra.distributed.api.ConsistencyLevel; +import org.apache.cassandra.distributed.test.TestBaseImpl; + +import static org.apache.cassandra.distributed.api.Feature.GOSSIP; +import static org.apache.cassandra.distributed.api.Feature.NETWORK; +import static org.assertj.core.api.Assertions.assertThat; + +public class BM25DistributedTest extends TestBaseImpl +{ + private static final String CREATE_KEYSPACE = "CREATE KEYSPACE %%s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': %d}"; + private static final String CREATE_TABLE = "CREATE TABLE %s (k int PRIMARY KEY, v text)"; + private static final String CREATE_INDEX = "CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex' WITH OPTIONS = {'index_analyzer': '{\"tokenizer\" : {\"name\" : \"standard\"}, \"filters\" : [{\"name\" : \"porterstem\"}]}'}"; + + // To get consistent results from BM25 we need to know which docs are evaluated, the easiest way + // to do that is to put all the docs on every replica + private static final int NUM_NODES = 3; + private static final int RF = 3; + + private static Cluster cluster; + private static String table; + + private static final AtomicInteger seq = new AtomicInteger(); + + @BeforeClass + public static void setupCluster() throws Exception + { + cluster = Cluster.build(NUM_NODES) + .withTokenCount(1) + .withDataDirCount(1) + .withConfig(config -> config.with(GOSSIP).with(NETWORK)) + .start(); + + cluster.schemaChange(withKeyspace(String.format(CREATE_KEYSPACE, RF))); + } + + @AfterClass + public static void closeCluster() + { + if (cluster != null) + cluster.close(); + } + + @Before + public void before() + { + table = "table_" + seq.getAndIncrement(); + cluster.schemaChange(formatQuery(CREATE_TABLE)); + cluster.schemaChange(formatQuery(CREATE_INDEX)); + SAIUtil.waitForIndexQueryable(cluster, KEYSPACE); + } + + @Test + public void testTermFrequencyOrdering() + { + // Insert documents with varying frequencies of the term "apple" + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'apple apple')"); + execute("INSERT INTO %s (k, v) VALUES (3, 'apple apple apple')"); + + // Query memtable index + assertBM25Ordering(); + + // Flush and query on-disk index + cluster.forEach(n -> n.flush(KEYSPACE)); + assertBM25Ordering(); + } + + private void assertBM25Ordering() + { + Object[][] result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertThat(result).hasNumberOfRows(3); + + // Results should be ordered by term frequency (highest to lowest) + assertThat((Integer) result[0][0]).isEqualTo(3); // 3 occurrences + assertThat((Integer) result[1][0]).isEqualTo(2); // 2 occurrences + assertThat((Integer) result[2][0]).isEqualTo(1); // 1 occurrence + } + + private static Object[][] execute(String query) + { + return execute(query, ConsistencyLevel.QUORUM); + } + + private static Object[][] execute(String query, ConsistencyLevel consistencyLevel) + { + return cluster.coordinator(1).execute(formatQuery(query), consistencyLevel); + } + + private static String formatQuery(String query) + { + return String.format(query, KEYSPACE + '.' + table); + } +} diff --git a/test/unit/org/apache/cassandra/index/sai/StorageAttachedIndexTest.java b/test/unit/org/apache/cassandra/index/sai/StorageAttachedIndexTest.java index 7538d65a4d88..a32b6347f6b2 100644 --- a/test/unit/org/apache/cassandra/index/sai/StorageAttachedIndexTest.java +++ b/test/unit/org/apache/cassandra/index/sai/StorageAttachedIndexTest.java @@ -135,12 +135,20 @@ private static ByteBuffer floatVectorToByteBuffer(CQLTester.Vector vector @Test public void testOrderResults() { + QueryOptions queryOptions = QueryOptions.create(ConsistencyLevel.ONE, + byteBufferList, + false, + PageSize.inRows(1), + null, + null, + ProtocolVersion.CURRENT, + KEYSPACE); List> rows = new ArrayList<>(); rows.add(byteBufferList); SelectStatement selectStatementInstance = (SelectStatement) QueryProcessor.prepareInternal("SELECT key, value FROM " + KEYSPACE + '.' + TABLE).statement; - SortedRowsBuilder builder = selectStatementInstance.sortedRowsBuilder(Integer.MAX_VALUE, 0); + SortedRowsBuilder builder = selectStatementInstance.sortedRowsBuilder(Integer.MAX_VALUE, 0, queryOptions); rows.forEach(builder::add); List> sortedRows = builder.build(); diff --git a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java new file mode 100644 index 000000000000..c8dbb2ae2528 --- /dev/null +++ b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.cql; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import org.junit.Test; + +import org.apache.cassandra.index.sai.SAITester; +import org.apache.cassandra.index.sai.disk.v1.InvertedIndexSearcher; +import org.apache.cassandra.index.sai.plan.QueryController; +import org.apache.cassandra.inject.ActionBuilder; +import org.apache.cassandra.inject.Expression; +import org.apache.cassandra.inject.Injections; +import org.apache.cassandra.inject.InvokePointBuilder; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; +import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; +import static org.junit.Assert.assertEquals; + +public class BM25Test extends SAITester +{ + @Test + public void testTermFrequencyOrdering() throws Throwable + { + createSimpleTable(); + + // Insert documents with varying frequencies of the term "apple" + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'apple apple')"); + execute("INSERT INTO %s (k, v) VALUES (3, 'apple apple apple')"); + + beforeAndAfterFlush(() -> + { + // Results should be ordered by term frequency (highest to lowest) + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, + row(3), // 3 occurrences + row(2), // 2 occurrences + row(1)); // 1 occurrence + }); + } + + @Test + public void testDocumentLength() throws Throwable + { + createSimpleTable(); + // Create documents with same term frequency but different lengths + execute("INSERT INTO %s (k, v) VALUES (1, 'test test')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'test test other words here to make it longer')"); + execute("INSERT INTO %s (k, v) VALUES (3, 'test test extremely long document with many additional words to significantly increase the document length while maintaining the same term frequency for our target term')"); + + beforeAndAfterFlush(() -> + { + // Documents with same term frequency should be ordered by length (shorter first) + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'test' LIMIT 3"); + assertRows(result, + row(1), + row(2), + row(3)); + }); + } + + @Test + public void testMultiTermQueryScoring() throws Throwable + { + createSimpleTable(); + // Two terms, but "apple" appears in fewer documents + execute("INSERT INTO %s (k, v) VALUES (1, 'apple banana')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'apple apple banana')"); + execute("INSERT INTO %s (k, v) VALUES (3, 'apple banana banana')"); + execute("INSERT INTO %s (k, v) VALUES (4, 'apple apple banana banana')"); + execute("INSERT INTO %s (k, v) VALUES (5, 'banana banana')"); + + beforeAndAfterFlush(() -> + { + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple banana' LIMIT 4"); + assertRows(result, + row(2), // Highest frequency of most important term + row(4), // More mentions of both terms + row(1), // One of each term + row(3)); // Low frequency of most important term + }); + } + + @Test + public void testIrrelevantRowsScoring() throws Throwable + { + createSimpleTable(); + // Insert pizza reviews with varying relevance to "crispy crust" + execute("INSERT INTO %s (k, v) VALUES (1, 'The pizza had a crispy crust and was delicious')"); // Basic mention + execute("INSERT INTO %s (k, v) VALUES (2, 'Very crispy crispy crust, perfectly cooked')"); // Emphasized crispy + execute("INSERT INTO %s (k, v) VALUES (3, 'The crust crust crust was okay, nothing special')"); // Only crust mentions + execute("INSERT INTO %s (k, v) VALUES (4, 'Super crispy crispy crust crust, best pizza ever!')"); // Most mentions of both + execute("INSERT INTO %s (k, v) VALUES (5, 'The toppings were good but the pizza was soggy')"); // Irrelevant review + + beforeAndAfterFlush(() -> + { + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'crispy crust' LIMIT 5"); + assertRows(result, + row(4), // Highest frequency of both terms + row(2), // High frequency of 'crispy', one 'crust' + row(1)); // One mention of each term + // Rows 4 and 5 do not contain all terms + }); + } + + private void createSimpleTable() + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, v text)"); + analyzeIndex(); + } + + private String analyzeIndex() + { + return createIndex("CREATE CUSTOM INDEX ON %s(v) " + + "USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' " + + "WITH OPTIONS = {" + + "'index_analyzer': '{" + + "\"tokenizer\" : {\"name\" : \"standard\"}, " + + "\"filters\" : [{\"name\" : \"porterstem\"}]" + + "}'}" + ); + } + + @Test + public void testWithPredicate() throws Throwable + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, p int, v text)"); + analyzeIndex(); + execute("CREATE CUSTOM INDEX ON %s(p) USING 'StorageAttachedIndex'"); + + // Insert documents with varying frequencies of the term "apple" + execute("INSERT INTO %s (k, p, v) VALUES (1, 5, 'apple')"); + execute("INSERT INTO %s (k, p, v) VALUES (2, 5, 'apple apple')"); + execute("INSERT INTO %s (k, p, v) VALUES (3, 5, 'apple apple apple')"); + execute("INSERT INTO %s (k, p, v) VALUES (4, 6, 'apple apple apple')"); + execute("INSERT INTO %s (k, p, v) VALUES (5, 7, 'apple apple apple')"); + + beforeAndAfterFlush(() -> + { + // Results should be ordered by term frequency (highest to lowest) + var result = execute("SELECT k FROM %s WHERE p = 5 ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, + row(3), // 3 occurrences + row(2), // 2 occurrences + row(1)); // 1 occurrence + }); + } + + @Test + public void testWidePartition() throws Throwable + { + createTable("CREATE TABLE %s (k1 int, k2 int, v text, PRIMARY KEY (k1, k2))"); + analyzeIndex(); + + // Insert documents with varying frequencies of the term "apple" + execute("INSERT INTO %s (k1, k2, v) VALUES (0, 1, 'apple')"); + execute("INSERT INTO %s (k1, k2, v) VALUES (0, 2, 'apple apple')"); + execute("INSERT INTO %s (k1, k2, v) VALUES (0, 3, 'apple apple apple')"); + + beforeAndAfterFlush(() -> + { + // Results should be ordered by term frequency (highest to lowest) + var result = execute("SELECT k2 FROM %s ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, + row(3), // 3 occurrences + row(2), // 2 occurrences + row(1)); // 1 occurrence + }); + } + + @Test + public void testWidePartitionWithPkPredicate() throws Throwable + { + createTable("CREATE TABLE %s (k1 int, k2 int, v text, PRIMARY KEY (k1, k2))"); + analyzeIndex(); + + // Insert documents with varying frequencies of the term "apple" + execute("INSERT INTO %s (k1, k2, v) VALUES (0, 1, 'apple')"); + execute("INSERT INTO %s (k1, k2, v) VALUES (0, 2, 'apple apple')"); + execute("INSERT INTO %s (k1, k2, v) VALUES (0, 3, 'apple apple apple')"); + execute("INSERT INTO %s (k1, k2, v) VALUES (1, 3, 'apple apple apple')"); + execute("INSERT INTO %s (k1, k2, v) VALUES (2, 3, 'apple apple apple')"); + + beforeAndAfterFlush(() -> + { + // Results should be ordered by term frequency (highest to lowest) + var result = execute("SELECT k2 FROM %s WHERE k1 = 0 ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, + row(3), // 3 occurrences + row(2), // 2 occurrences + row(1)); // 1 occurrence + }); + } + + @Test + public void testWidePartitionWithPredicate() throws Throwable + { + createTable("CREATE TABLE %s (k1 int, k2 int, p int, v text, PRIMARY KEY (k1, k2))"); + analyzeIndex(); + execute("CREATE CUSTOM INDEX ON %s(p) USING 'StorageAttachedIndex'"); + + // Insert documents with varying frequencies of the term "apple" + execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 1, 5, 'apple')"); + execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 2, 5, 'apple apple')"); + execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 3, 5, 'apple apple apple')"); + execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 4, 6, 'apple apple apple')"); + execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 5, 7, 'apple apple apple')"); + + beforeAndAfterFlush(() -> + { + // Results should be ordered by term frequency (highest to lowest) + var result = execute("SELECT k2 FROM %s WHERE p = 5 ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, + row(3), // 3 occurrences + row(2), // 2 occurrences + row(1)); // 1 occurrence + }); + } + + @Test + public void testWithPredicateSearchThenOrder() throws Throwable + { + QueryController.QUERY_OPT_LEVEL = 0; + testWithPredicate(); + } + + @Test + public void testWidePartitionWithPredicateOrderThenSearch() throws Throwable + { + QueryController.QUERY_OPT_LEVEL = 1; + testWidePartitionWithPredicate(); + } + + @Test + public void testQueryWithNulls() throws Throwable + { + createSimpleTable(); + + execute("INSERT INTO %s (k, v) VALUES (0, null)"); + execute("INSERT INTO %s (k, v) VALUES (1, 'test document')"); + beforeAndAfterFlush(() -> + { + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'test' LIMIT 1"); + assertRows(result, row(1)); + }); + } + + @Test + public void testQueryEmptyTable() + { + createSimpleTable(); + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'test' LIMIT 1"); + assertThat(result).hasSize(0); + } +} diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingListTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingListTest.java new file mode 100644 index 000000000000..40f6f16ef970 --- /dev/null +++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingListTest.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.disk.v1.postings; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.IntStream; + +import org.junit.Test; + +import org.apache.cassandra.index.sai.disk.PostingList; +import org.apache.cassandra.index.sai.postings.IntArrayPostingList; +import org.apache.cassandra.index.sai.utils.SaiRandomizedTest; + +public class IntersectingPostingListTest extends SaiRandomizedTest +{ + @Test + public void shouldIntersectOverlappingPostingLists() throws IOException + { + var lists = listOfLists(new IntArrayPostingList(new int[]{ 1, 4, 6, 8 }), + new IntArrayPostingList(new int[]{ 2, 4, 6, 9 }), + new IntArrayPostingList(new int[]{ 4, 6, 7 })); + + final PostingList intersected = IntersectingPostingList.intersect(lists); + + assertPostingListEquals(new IntArrayPostingList(new int[]{ 4, 6 }), intersected); + } + + @Test + public void shouldIntersectDisjointPostingLists() throws IOException + { + var lists = listOfLists(new IntArrayPostingList(new int[]{ 1, 3, 5 }), + new IntArrayPostingList(new int[]{ 2, 4, 6 })); + + final PostingList intersected = IntersectingPostingList.intersect(lists); + + assertPostingListEquals(new IntArrayPostingList(new int[]{}), intersected); + } + + @Test + public void shouldIntersectSinglePostingList() throws IOException + { + var lists = listOfLists(new IntArrayPostingList(new int[]{ 1, 4, 6 })); + + final PostingList intersected = IntersectingPostingList.intersect(lists); + + assertPostingListEquals(new IntArrayPostingList(new int[]{ 1, 4, 6 }), intersected); + } + + @Test + public void shouldIntersectIdenticalPostingLists() throws IOException + { + var lists = listOfLists(new IntArrayPostingList(new int[]{ 1, 2, 3 }), + new IntArrayPostingList(new int[]{ 1, 2, 3 })); + + final PostingList intersected = IntersectingPostingList.intersect(lists); + + assertPostingListEquals(new IntArrayPostingList(new int[]{ 1, 2, 3 }), intersected); + } + + @Test + public void shouldAdvanceAllIntersectedLists() throws IOException + { + var lists = listOfLists(new IntArrayPostingList(new int[]{ 1, 3, 5, 7, 9 }), + new IntArrayPostingList(new int[]{ 2, 3, 5, 7, 8 }), + new IntArrayPostingList(new int[]{ 3, 5, 7, 10 })); + + final PostingList intersected = IntersectingPostingList.intersect(lists); + final PostingList expected = new IntArrayPostingList(new int[]{ 3, 5, 7 }); + + assertEquals(expected.advance(5), + intersected.advance(5)); + + assertPostingListEquals(expected, intersected); + } + + @Test + public void shouldHandleEmptyList() throws IOException + { + var lists = listOfLists(new IntArrayPostingList(new int[]{}), + new IntArrayPostingList(new int[]{ 1, 2, 3 })); + + final PostingList intersected = IntersectingPostingList.intersect(lists); + + assertEquals(PostingList.END_OF_STREAM, intersected.advance(1)); + } + + @Test + public void shouldInterleaveNextAndAdvance() throws IOException + { + var lists = listOfLists(new IntArrayPostingList(new int[]{ 1, 3, 5, 7, 9 }), + new IntArrayPostingList(new int[]{ 1, 3, 5, 7, 9 }), + new IntArrayPostingList(new int[]{ 1, 3, 5, 7, 9 })); + + final PostingList intersected = IntersectingPostingList.intersect(lists); + + assertEquals(1, intersected.nextPosting()); + assertEquals(5, intersected.advance(5)); + assertEquals(7, intersected.nextPosting()); + assertEquals(9, intersected.advance(9)); + } + + @Test + public void shouldInterleaveNextAndAdvanceOnRandom() throws IOException + { + for (int i = 0; i < 1000; ++i) + { + testAdvancingOnRandom(); + } + } + + private void testAdvancingOnRandom() throws IOException + { + final int postingsCount = nextInt(1, 50_000); + final int postingListCount = nextInt(2, 10); + + // Generate base postings that will be present in all lists + final AtomicInteger rowId = new AtomicInteger(); + final int[] commonPostings = IntStream.generate(() -> rowId.addAndGet(nextInt(1, 10))) + .limit(postingsCount / 4) // Fewer common elements + .toArray(); + + var splitPostingLists = new ArrayList(); + for (int i = 0; i < postingListCount; i++) + { + // Combine common postings with some unique ones for each list + final int[] uniquePostings = IntStream.generate(() -> rowId.addAndGet(nextInt(1, 10))) + .limit(postingsCount) + .toArray(); + int[] combined = IntStream.concat(IntStream.of(commonPostings), + IntStream.of(uniquePostings)) + .distinct() + .sorted() + .toArray(); + splitPostingLists.add(new IntArrayPostingList(combined)); + } + + final PostingList intersected = IntersectingPostingList.intersect(splitPostingLists); + final PostingList expected = new IntArrayPostingList(commonPostings); + + final List actions = new ArrayList<>(); + for (int idx = 0; idx < commonPostings.length; idx++) + { + if (nextInt(0, 8) == 0) + { + actions.add((postingList) -> { + try + { + return postingList.nextPosting(); + } + catch (IOException e) + { + fail(e.getMessage()); + throw new RuntimeException(e); + } + }); + } + else + { + final int skips = nextInt(0, 5); + idx = Math.min(idx + skips, commonPostings.length - 1); + final int rowID = commonPostings[idx]; + actions.add((postingList) -> { + try + { + return postingList.advance(rowID); + } + catch (IOException e) + { + fail(e.getMessage()); + throw new RuntimeException(e); + } + }); + } + } + + for (PostingListAdvance action : actions) + { + assertEquals(action.advance(expected), action.advance(intersected)); + } + } + + private ArrayList listOfLists(PostingList... postingLists) + { + var L = new ArrayList(); + Collections.addAll(L, postingLists); + return L; + } + + private interface PostingListAdvance + { + long advance(PostingList list) throws IOException; + } +} diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingListTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingListTest.java similarity index 91% rename from test/unit/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingListTest.java rename to test/unit/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingListTest.java index 3135db3c7748..10b2b6a33f47 100644 --- a/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingListTest.java +++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingListTest.java @@ -30,14 +30,14 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -public class VectorPostingListTest +public class ReorderingPostingListTest { @Test public void ensureEmptySourceBehavesCorrectly() throws Throwable { var source = new TestIterator(CloseableIterator.emptyIterator()); - try (var postingList = new VectorPostingList(source)) + try (var postingList = new ReorderingPostingList(source, RowIdWithScore::getSegmentRowId)) { // Even an empty source should be closed assertTrue(source.isClosed); @@ -55,7 +55,7 @@ public void ensureIteratorIsConsumedClosedAndReordered() throws Throwable new RowIdWithScore(4, 4), }).iterator()); - try (var postingList = new VectorPostingList(source)) + try (var postingList = new ReorderingPostingList(source, RowIdWithScore::getSegmentRowId)) { // The posting list is eagerly consumed, so it should be closed before // we close postingList @@ -80,7 +80,7 @@ public void ensureAdvanceWorksCorrectly() throws Throwable new RowIdWithScore(2, 2), }).iterator()); - try (var postingList = new VectorPostingList(source)) + try (var postingList = new ReorderingPostingList(source, RowIdWithScore::getSegmentRowId)) { assertEquals(3, postingList.advance(3)); assertEquals(PostingList.END_OF_STREAM, postingList.advance(4)); diff --git a/test/unit/org/apache/cassandra/index/sai/memory/VectorMemtableIndexTest.java b/test/unit/org/apache/cassandra/index/sai/memory/VectorMemtableIndexTest.java index 75e8c83f30be..2197c00fe231 100644 --- a/test/unit/org/apache/cassandra/index/sai/memory/VectorMemtableIndexTest.java +++ b/test/unit/org/apache/cassandra/index/sai/memory/VectorMemtableIndexTest.java @@ -148,7 +148,7 @@ private void validate(List keys) { IntStream.range(0, 1_000).parallel().forEach(i -> { - var orderer = generateRandomOrderer(); + var orderer = randomVectorOrderer(); AbstractBounds keyRange = generateRandomBounds(keys); // compute keys in range of the bounds Set keysInRange = keys.stream().filter(keyRange::contains) @@ -197,7 +197,7 @@ public void indexIteratorTest() // VSTODO } - private Orderer generateRandomOrderer() + private Orderer randomVectorOrderer() { return new Orderer(indexContext, Operator.ANN, randomVectorSerialized()); } From 56a6e0f0109ee5981a59ae791506f365347e87d0 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Fri, 6 Dec 2024 16:22:32 -0600 Subject: [PATCH 06/29] re-disallow DESC with ORDER BY ANN --- .../cassandra/cql3/restrictions/StatementRestrictions.java | 3 ++- .../apache/cassandra/cql3/statements/SelectStatement.java | 7 +++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java b/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java index d28e16fa399f..7e92d3273077 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java @@ -682,7 +682,8 @@ else if (indexOrderings.size() == 1) if (orderings.size() > 1) throw new InvalidRequestException("Cannot combine clustering column ordering with non-clustering column ordering"); Ordering ordering = indexOrderings.get(0); - if (ordering.direction != Ordering.Direction.ASC && ordering.expression.isScored()) + // TODO remove the instanceof with SelectStatement.ANN_USE_SYNTHETIC_SCORE. + if (ordering.direction != Ordering.Direction.ASC && (ordering.expression.isScored() || ordering.expression instanceof Ordering.Ann)) throw new InvalidRequestException("Descending ANN ordering is not supported"); if (!ENABLE_SAI_GENERAL_ORDER_BY && ordering.expression instanceof Ordering.SingleColumn) throw new InvalidRequestException("SAI based ORDER BY on non-vector column is not supported"); diff --git a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java index e7d5cf5874a6..96e992fe876f 100644 --- a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java +++ b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java @@ -101,8 +101,11 @@ */ public class SelectStatement implements CQLStatement.SingleKeyspaceCqlStatement { - // TODO remove this when we no longer need to downgrade to replicas that don't know about synthetic columns - // (And don't forget to remove the related hacks in Columns.Serializer.encodeBitmap and UnfilteredSerializer.serializeRowBody) + // TODO remove this when we no longer need to downgrade to replicas that don't know about synthetic columns, + // and the related code in + // - Columns.Serializer.encodeBitmap + // - UnfilteredSerializer.serializeRowBody) + // - StatementRestrictions.addOrderingRestrictions public static final boolean ANN_USE_SYNTHETIC_SCORE = Boolean.parseBoolean(System.getProperty("cassandra.sai.ann_use_synthetic_score", "false")); private static final Logger logger = LoggerFactory.getLogger(SelectStatement.class); From 30b6545d417e84465d2475dd04499d401489c142 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Mon, 9 Dec 2024 10:14:50 -0600 Subject: [PATCH 07/29] cleanup and comments --- src/java/org/apache/cassandra/cql3/Operator.java | 2 +- .../cassandra/cql3/SingleColumnRelation.java | 2 +- .../restrictions/SingleColumnRestriction.java | 2 -- .../cql3/statements/SelectStatement.java | 15 +++++++-------- .../apache/cassandra/index/sai/QueryContext.java | 2 +- .../index/sai/disk/v1/InvertedIndexSearcher.java | 5 ++--- .../disk/v1/postings/IntersectingPostingList.java | 2 +- .../index/sai/disk/v2/V2VectorIndexSearcher.java | 2 +- .../index/sai/memory/TrieMemtableIndex.java | 1 + .../org/apache/cassandra/index/sai/plan/Plan.java | 8 ++++---- .../sai/plan/StorageAttachedIndexSearcher.java | 4 ++-- .../cassandra/index/sai/plan/TopKProcessor.java | 2 -- .../apache/cassandra/index/sai/cql/BM25Test.java | 10 ---------- 13 files changed, 21 insertions(+), 36 deletions(-) diff --git a/src/java/org/apache/cassandra/cql3/Operator.java b/src/java/org/apache/cassandra/cql3/Operator.java index d4bf3ab93d3c..8a8494661cb4 100644 --- a/src/java/org/apache/cassandra/cql3/Operator.java +++ b/src/java/org/apache/cassandra/cql3/Operator.java @@ -259,7 +259,7 @@ public String toString() @Override public boolean isSatisfiedBy(AbstractType type, ByteBuffer leftOperand, ByteBuffer rightOperand, @Nullable Index.Analyzer analyzer) { - return true; + throw new UnsupportedOperationException(); } }, NOT_IN(16) diff --git a/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java b/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java index 5a93e6c8e422..a15251bdefa7 100644 --- a/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java +++ b/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java @@ -338,7 +338,7 @@ protected Restriction newBm25Restriction(TableMetadata table, VariableSpecificat { ColumnMetadata columnDef = table.getExistingColumn(entity); Term term = toTerm(toReceivers(columnDef), value, table.keyspace, boundNames); - return new SingleColumnRestriction.AnnRestriction(columnDef, term); + return new SingleColumnRestriction.Bm25Restriction(columnDef, term); } @Override diff --git a/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java b/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java index 9d52f2bfe43c..1cde8cfc00e4 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java @@ -1241,8 +1241,6 @@ public String toString() @Override public SingleRestriction doMergeWith(SingleRestriction otherRestriction) { - if (otherRestriction.isIndexBasedOrdering()) - throw invalidRequest("%s cannot be restricted by multiple BM25 restrictions", columnDef.name); throw invalidRequest("%s cannot be restricted by both BM25 and %s", columnDef.name, otherRestriction.toString()); } diff --git a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java index 96e992fe876f..847d058cf26b 100644 --- a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java +++ b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java @@ -1125,10 +1125,10 @@ public SortedRowsBuilder sortedRowsBuilder(int limit, int offset, QueryOptions o if (orderingComparator == null) return SortedRowsBuilder.create(limit, offset); - if (orderingComparator instanceof IndexColumnComparator) + if (orderingComparator instanceof VectorColumnComparator) { - SingleRestriction restriction = ((IndexColumnComparator) orderingComparator).restriction; - int columnIndex = ((IndexColumnComparator) orderingComparator).columnIndex; + SingleRestriction restriction = ((VectorColumnComparator) orderingComparator).restriction; + int columnIndex = ((VectorColumnComparator) orderingComparator).columnIndex; Index index = restriction.findSupportingIndex(IndexRegistry.obtain(table)); assert index != null; @@ -1497,7 +1497,7 @@ private ColumnComparator> getOrderingComparator(Selection selec var column = e.getKey(); var ordering = e.getValue(); if (ordering.expression instanceof Ordering.Ann && !ANN_USE_SYNTHETIC_SCORE) - return new IndexColumnComparator(ordering.expression.toRestriction(), selection.getOrderingIndex(column)); + return new VectorColumnComparator(ordering.expression.toRestriction(), selection.getOrderingIndex(column)); else return new SingleColumnComparator(selection.getOrderingIndex(column), column.type, false); } @@ -1719,14 +1719,13 @@ public boolean isClustered() } } - // see usage in sortedRowsBuilder - private static class IndexColumnComparator extends ColumnComparator> + // placeholder for postQueryScorer call; see usage in sortedRowsBuilder + private static class VectorColumnComparator extends ColumnComparator> { private final SingleRestriction restriction; private final int columnIndex; - // VSTODO maybe cache in prepared statement - public IndexColumnComparator(SingleRestriction restriction, int columnIndex) + public VectorColumnComparator(SingleRestriction restriction, int columnIndex) { this.restriction = restriction; this.columnIndex = columnIndex; diff --git a/src/java/org/apache/cassandra/index/sai/QueryContext.java b/src/java/org/apache/cassandra/index/sai/QueryContext.java index 7d766f77dfd1..0bb2a96922ea 100644 --- a/src/java/org/apache/cassandra/index/sai/QueryContext.java +++ b/src/java/org/apache/cassandra/index/sai/QueryContext.java @@ -35,7 +35,7 @@ @NotThreadSafe public class QueryContext { - private static final boolean DISABLE_TIMEOUT = true; // Boolean.getBoolean("cassandra.sai.test.disable.timeout"); + private static final boolean DISABLE_TIMEOUT = Boolean.getBoolean("cassandra.sai.test.disable.timeout"); protected final long queryStartTimeNanos; diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java index 295f5fe9356d..b8353a3b5d80 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java @@ -77,9 +77,6 @@ public class InvertedIndexSearcher extends IndexSearcher private final TermsReader reader; private final QueryEventListener.TrieIndexEventListener perColumnEventListener; private final Version version; - private static final float K1 = 1.2f; // BM25 term frequency saturation parameter - private static final float B = 0.75f; // BM25 length normalization parameter - private final boolean filterRangeResults; private final SSTableReader sstable; @@ -189,6 +186,7 @@ public CloseableIterator orderBy(Orderer orderer, Express try (var pkm = primaryKeyMapFactory.newPerSSTablePrimaryKeyMap(); var merged = IntersectingPostingList.intersect(List.copyOf(postingLists.values()))) { + // construct an Iterator() from our intersected postings var it = new AbstractIterator() { @Override protected PrimaryKey computeNext() @@ -230,6 +228,7 @@ public CloseableIterator orderResultsBy(SSTableReader rea return super.orderResultsBy(reader, queryContext, keys, orderer, limit); var queryTerms = orderer.extractQueryTerms(); + // compute documentFrequencies from either histogram or an index search var documentFrequencies = new HashMap(); boolean hasHistograms = metadata.version.onDiskFormat().indexFeatureSet().hasTermsHistogram(); for (ByteBuffer term : queryTerms) diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java index 61fa01ff848b..b6240bfce7b1 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java @@ -36,7 +36,7 @@ public class IntersectingPostingList implements PostingList private final List postingLists; private final int size; - // currentPostings state is effectively local to findNextIntersection, but we keep it + // currentRowIds state is effectively local to findNextIntersection, but we keep it // around as a field to avoid repeated allocations there private final int[] currentRowIds; diff --git a/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java index b6c93b368420..09e80338753a 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java @@ -160,7 +160,7 @@ public CloseableIterator orderBy(Orderer orderer, Express if (logger.isTraceEnabled()) logger.trace(indexContext.logMessage("Searching on expression '{}'..."), orderer); - if (orderer.getVectorTerm() == null) + if (!orderer.isANN()) throw new IllegalArgumentException(indexContext.logMessage("Unsupported expression during ANN index query: " + orderer)); int rerankK = indexContext.getIndexWriterConfig().getSourceModel().rerankKFor(limit, graph.getCompression()); diff --git a/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java b/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java index 63710a28c290..22b392350860 100644 --- a/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java +++ b/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java @@ -355,6 +355,7 @@ private BM25Utils.DocStats computeDocumentFrequencies(QueryContext queryContext, } long docCount = 0; + // count all documents in the queried column try (var it = memtable.makePartitionIterator(ColumnFilter.selection(RegularAndStaticColumns.of(indexContext.getDefinition())), DataRange.allData(memtable.metadata().partitioner))) { diff --git a/src/java/org/apache/cassandra/index/sai/plan/Plan.java b/src/java/org/apache/cassandra/index/sai/plan/Plan.java index 46baa4f59a87..4cf730bd1d53 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Plan.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Plan.java @@ -1316,11 +1316,11 @@ protected KeysSort withAccess(Access access) * Base class for index scans that return results in a computed order (ANN, BM25) * rather than the natural index order. */ - abstract static class ComputedOrderIndexScan extends Leaf + abstract static class ScoredIndexScan extends Leaf { final Orderer ordering; - protected ComputedOrderIndexScan(Factory factory, int id, Access access, Orderer ordering) + protected ScoredIndexScan(Factory factory, int id, Access access, Orderer ordering) { super(factory, id, access); this.ordering = ordering; @@ -1351,7 +1351,7 @@ protected Iterator execute(Executor executor) * Returns all keys in ANN order. * Contrary to {@link KeysSort}, there is no input node here and the output is generated lazily. */ - final static class AnnIndexScan extends ComputedOrderIndexScan + final static class AnnIndexScan extends ScoredIndexScan { protected AnnIndexScan(Factory factory, int id, Access access, Orderer ordering) { @@ -1383,7 +1383,7 @@ protected KeysIteration withAccess(Access access) * Returns all keys in BM25 order. * Like AnnIndexScan, this generates results lazily without an input node. */ - final static class Bm25IndexScan extends ComputedOrderIndexScan + final static class Bm25IndexScan extends ScoredIndexScan { protected Bm25IndexScan(Factory factory, int id, Access access, Orderer ordering) { diff --git a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java index bfde594fadfc..d9f0dfc013b1 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java @@ -629,8 +629,8 @@ public UnfilteredRowIterator readAndValidatePartition(PrimaryKey pk, List, P extends BaseParti { // priority queue ordered by score in descending order Comparator> comparator; - // TODO does this work for complex expressions? if (expression.operator() == Operator.ANN || expression.operator() == Operator.BM25) { comparator = Comparator.comparing((Triple t) -> (Float) t.getRight()).reversed(); diff --git a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java index c8dbb2ae2528..f38613150a64 100644 --- a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java +++ b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java @@ -18,22 +18,12 @@ package org.apache.cassandra.index.sai.cql; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; - import org.junit.Test; import org.apache.cassandra.index.sai.SAITester; -import org.apache.cassandra.index.sai.disk.v1.InvertedIndexSearcher; import org.apache.cassandra.index.sai.plan.QueryController; -import org.apache.cassandra.inject.ActionBuilder; -import org.apache.cassandra.inject.Expression; -import org.apache.cassandra.inject.Injections; -import org.apache.cassandra.inject.InvokePointBuilder; -import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; -import static org.junit.Assert.assertEquals; public class BM25Test extends SAITester { From 6c9a0e6cbfdae0c40d1e90d655d4d0f1a671cb9b Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Wed, 11 Dec 2024 09:00:37 -0600 Subject: [PATCH 08/29] address review notes --- .../org/apache/cassandra/cql3/Operator.java | 29 ++++++++++--------- .../index/sai/disk/v1/IndexSearcher.java | 4 +-- .../sai/disk/v1/InvertedIndexSearcher.java | 4 +-- .../v1/postings/IntersectingPostingList.java | 27 +++++++++++++++-- .../index/sai/memory/TrieMemtableIndex.java | 4 +-- .../cassandra/index/sai/plan/Orderer.java | 8 +++-- .../apache/cassandra/index/sai/plan/Plan.java | 4 +-- .../cassandra/index/sai/utils/BM25Utils.java | 5 ++-- 8 files changed, 56 insertions(+), 29 deletions(-) diff --git a/src/java/org/apache/cassandra/cql3/Operator.java b/src/java/org/apache/cassandra/cql3/Operator.java index 8a8494661cb4..cd80020c1257 100644 --- a/src/java/org/apache/cassandra/cql3/Operator.java +++ b/src/java/org/apache/cassandra/cql3/Operator.java @@ -373,20 +373,6 @@ public boolean isSatisfiedBy(AbstractType type, ByteBuffer leftOperand, ByteB return !LIKE.isSatisfiedBy(type, leftOperand, rightOperand, analyzer); } }, - BM25(25) - { - @Override - public String toString() - { - return "BM25"; - } - - @Override - public boolean isSatisfiedBy(AbstractType type, ByteBuffer leftOperand, ByteBuffer rightOperand, @Nullable Index.Analyzer analyzer) - { - throw new UnsupportedOperationException(); - } - }, /** * An operator that only performs matching against analyzed columns. @@ -429,6 +415,7 @@ private boolean hasToken(AbstractType type, List tokens, ByteBuff return false; } }, + /** * An operator that performs a distance bounded approximate nearest neighbor search against a vector column such * that all result vectors are within a given distance of the query vector. The notable difference between this @@ -473,6 +460,20 @@ public String toString() return "DESC"; } + @Override + public boolean isSatisfiedBy(AbstractType type, ByteBuffer leftOperand, ByteBuffer rightOperand, @Nullable Index.Analyzer analyzer) + { + throw new UnsupportedOperationException(); + } + }, + BM25(104) + { + @Override + public String toString() + { + return "BM25"; + } + @Override public boolean isSatisfiedBy(AbstractType type, ByteBuffer leftOperand, ByteBuffer rightOperand, @Nullable Index.Analyzer analyzer) { diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java index 5fb246183921..14b182c7ee56 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java @@ -100,7 +100,7 @@ protected IndexSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory, public abstract KeyRangeIterator search(Expression expression, AbstractBounds keyRange, QueryContext queryContext, boolean defer, int limit) throws IOException; /** - * Order the rows by the giving Orderer. Used for ORDER BY clause when + * Order the rows by the given Orderer. Used for ORDER BY clause when * (1) the WHERE predicate is either a partition restriction or a range restriction on the index, * (2) there is no WHERE predicate, or * (3) the planner determines it is better to post-filter the ordered results by the predicate. @@ -115,7 +115,7 @@ protected IndexSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory, public abstract CloseableIterator orderBy(Orderer orderer, Expression slice, AbstractBounds keyRange, QueryContext queryContext, int limit) throws IOException; /** - * Order the rows by the giving Orderer. Used for ORDER BY clause when the WHERE predicates + * Order the rows by the given Orderer. Used for ORDER BY clause when the WHERE predicates * have been applied first, yielding a list of primary keys. Again, `limit` is a planner hint for ANN to determine * the initial number of results returned, not a maximum. */ diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java index b8353a3b5d80..8b7d14c42755 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java @@ -172,7 +172,7 @@ public CloseableIterator orderBy(Orderer orderer, Express } // find documents that match each term - var queryTerms = orderer.extractQueryTerms(); + var queryTerms = orderer.getQueryTerms(); var postingLists = queryTerms.stream() .collect(Collectors.toMap(Function.identity(), term -> { @@ -227,7 +227,7 @@ public CloseableIterator orderResultsBy(SSTableReader rea if (!orderer.isBM25()) return super.orderResultsBy(reader, queryContext, keys, orderer, limit); - var queryTerms = orderer.extractQueryTerms(); + var queryTerms = orderer.getQueryTerms(); // compute documentFrequencies from either histogram or an index search var documentFrequencies = new HashMap(); boolean hasHistograms = metadata.version.onDiskFormat().indexFeatureSet().hasTermsHistogram(); diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java index b6240bfce7b1..4cd762fd23c6 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java @@ -23,6 +23,7 @@ import javax.annotation.concurrent.NotThreadSafe; import org.apache.cassandra.index.sai.disk.PostingList; +import org.apache.cassandra.io.util.FileUtils; import static java.lang.Math.max; @@ -51,13 +52,16 @@ private IntersectingPostingList(List postingLists) this.currentRowIds = new int[postingLists.size()]; } + /** + * @return the intersection of the provided posting lists + */ public static PostingList intersect(List postingLists) { if (postingLists.size() == 1) return postingLists.get(0); if (postingLists.stream().anyMatch(PostingList::isEmpty)) - return PostingList.EMPTY; + return new EmptyIntersectingList(postingLists); return new IntersectingPostingList(postingLists); } @@ -124,10 +128,27 @@ public int size() } @Override - public void close() throws IOException + public void close() { for (PostingList list : postingLists) - list.close(); + FileUtils.closeQuietly(list); + } + + private static class EmptyIntersectingList extends EmptyPostingList + { + private final List lists; + + public EmptyIntersectingList(List postingLists) + { + this.lists = postingLists; + } + + @Override + public void close() + { + for (PostingList list : lists) + FileUtils.closeQuietly(list); + } } } diff --git a/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java b/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java index 22b392350860..4fa839532d28 100644 --- a/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java +++ b/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java @@ -258,7 +258,7 @@ public List> orderBy(QueryContext query } // BM25 - var queryTerms = orderer.extractQueryTerms(); + var queryTerms = orderer.getQueryTerms(); // Intersect iterators to find documents containing all terms var termIterators = keyIteratorsPerTerm(queryContext, keyRange, queryTerms); @@ -328,7 +328,7 @@ public CloseableIterator orderResultsBy(QueryContext quer } // BM25 - var queryTerms = orderer.extractQueryTerms(); + var queryTerms = orderer.getQueryTerms(); var docStats = computeDocumentFrequencies(queryContext, queryTerms); return BM25Utils.computeScores(keys.iterator(), queryTerms, diff --git a/src/java/org/apache/cassandra/index/sai/plan/Orderer.java b/src/java/org/apache/cassandra/index/sai/plan/Orderer.java index 6193e0a37110..6cb62ca06607 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Orderer.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Orderer.java @@ -50,6 +50,7 @@ public class Orderer public final Operator operator; public final ByteBuffer term; private float[] vector; + private ArrayList queryTerms; /** * Create an orderer for the given index context, operator, and term. @@ -132,11 +133,14 @@ public float[] getVectorTerm() return vector; } - public ArrayList extractQueryTerms() + public ArrayList getQueryTerms() { + if (queryTerms != null) + return queryTerms; + var queryAnalyzer = context.getQueryAnalyzerFactory().create(); // Split query into terms - var queryTerms = new ArrayList(); + queryTerms = new ArrayList(); queryAnalyzer.reset(term); try { diff --git a/src/java/org/apache/cassandra/index/sai/plan/Plan.java b/src/java/org/apache/cassandra/index/sai/plan/Plan.java index 4cf730bd1d53..941608328779 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Plan.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Plan.java @@ -1269,7 +1269,7 @@ private KeysIterationCost estimateBm25SortCost() { double expectedKeys = access.expectedAccessCount(source.expectedKeys()); - int termCount = ordering.extractQueryTerms().size(); + int termCount = ordering.getQueryTerms().size(); // all of the cost for BM25 is up front since the index doesn't give us the information we need // to return results in order, in isolation. The big cost is reading the indexed cells out of // the sstables. @@ -1397,7 +1397,7 @@ protected KeysIterationCost estimateCost() double expectedKeys = access.expectedAccessCount(factory.tableMetrics.rows); int expectedKeysInt = Math.max(1, (int) Math.ceil(expectedKeys)); - int termCount = ordering.extractQueryTerms().size(); + int termCount = ordering.getQueryTerms().size(); double initCost = expectedKeysInt * (hrs(ROW_CELL_COST) + ROW_CELL_COST) + termCount * BM25_SCORE_COST; diff --git a/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java b/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java index 044997e6508b..f1c51e52923b 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java +++ b/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java @@ -21,6 +21,7 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.Comparator; import java.util.HashMap; import java.util.Iterator; @@ -170,8 +171,8 @@ else if (source instanceof SSTableId) throw new IllegalArgumentException("Invalid source " + source.getClass()); } - // sort by score - scoredDocs.sort(Comparator.comparingDouble((PrimaryKeyWithScore pkws) -> pkws.indexScore).reversed()); + // sort by score (PKWS implements Comparator correctly for us) + Collections.sort(scoredDocs); return (CloseableIterator) (CloseableIterator) CloseableIterator.wrap(scoredDocs.iterator()); } From e107fcccb2a0018bf93dbbe9372f40f91be56fdb Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Wed, 11 Dec 2024 09:02:16 -0600 Subject: [PATCH 09/29] remove unused `limit` parameter from IndexSearcher::search --- .../cassandra/index/sai/disk/v1/IndexSearcher.java | 3 +-- .../index/sai/disk/v1/InvertedIndexSearcher.java | 2 +- .../index/sai/disk/v1/KDTreeIndexSearcher.java | 2 +- .../apache/cassandra/index/sai/disk/v1/Segment.java | 2 +- .../index/sai/disk/v2/V2VectorIndexSearcher.java | 6 +++--- .../index/sai/disk/v1/InvertedIndexSearcherTest.java | 10 +++++----- .../index/sai/disk/v1/KDTreeIndexSearcherTest.java | 12 ++++++------ 7 files changed, 18 insertions(+), 19 deletions(-) diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java index 14b182c7ee56..16d14bdd8ac3 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java @@ -94,10 +94,9 @@ protected IndexSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory, * @param keyRange key range specific in read command, used by ANN index * @param queryContext to track per sstable cache and per query metrics * @param defer create the iterator in a deferred state - * @param limit the initial num of rows to returned, used by ANN index. More rows may be requested if filtering throws away more than expected! * @return {@link KeyRangeIterator} that matches given expression */ - public abstract KeyRangeIterator search(Expression expression, AbstractBounds keyRange, QueryContext queryContext, boolean defer, int limit) throws IOException; + public abstract KeyRangeIterator search(Expression expression, AbstractBounds keyRange, QueryContext queryContext, boolean defer) throws IOException; /** * Order the rows by the given Orderer. Used for ORDER BY clause when diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java index 8b7d14c42755..774ffd3f2dbf 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java @@ -120,7 +120,7 @@ public long indexFileCacheSize() } @SuppressWarnings("resource") - public KeyRangeIterator search(Expression exp, AbstractBounds keyRange, QueryContext context, boolean defer, int limit) throws IOException + public KeyRangeIterator search(Expression exp, AbstractBounds keyRange, QueryContext context, boolean defer) throws IOException { PostingList postingList = searchPosting(exp, context); return toPrimaryKeyIterator(postingList, context); diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcher.java index 8a2aa354bd47..5911f2351014 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcher.java @@ -84,7 +84,7 @@ public long indexFileCacheSize() } @Override - public KeyRangeIterator search(Expression exp, AbstractBounds keyRange, QueryContext context, boolean defer, int limit) throws IOException + public KeyRangeIterator search(Expression exp, AbstractBounds keyRange, QueryContext context, boolean defer) throws IOException { PostingList postingList = searchPosting(exp, context); return toPrimaryKeyIterator(postingList, context); diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java b/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java index 87fe0a998e5f..3d52c944d726 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java @@ -144,7 +144,7 @@ public long indexFileCacheSize() */ public KeyRangeIterator search(Expression expression, AbstractBounds keyRange, QueryContext context, boolean defer, int limit) throws IOException { - return index.search(expression, keyRange, context, defer, limit); + return index.search(expression, keyRange, context, defer); } /** diff --git a/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java index 09e80338753a..21771ac94e1f 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java @@ -133,13 +133,13 @@ public ProductQuantization getPQ() } @Override - public KeyRangeIterator search(Expression exp, AbstractBounds keyRange, QueryContext context, boolean defer, int limit) throws IOException + public KeyRangeIterator search(Expression exp, AbstractBounds keyRange, QueryContext context, boolean defer) throws IOException { - PostingList results = searchPosting(context, exp, keyRange, limit); + PostingList results = searchPosting(context, exp, keyRange); return toPrimaryKeyIterator(results, context); } - private PostingList searchPosting(QueryContext context, Expression exp, AbstractBounds keyRange, int limit) throws IOException + private PostingList searchPosting(QueryContext context, Expression exp, AbstractBounds keyRange) throws IOException { if (logger.isTraceEnabled()) logger.trace(indexContext.logMessage("Searching on expression '{}'..."), exp); diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java index 47565eb28739..899c03a95ccc 100644 --- a/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java +++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java @@ -106,7 +106,7 @@ private void doTestEqQueriesAgainstStringIndex(Version version) throws Exception for (int t = 0; t < numTerms; ++t) { try (KeyRangeIterator results = searcher.search(new Expression(indexContext) - .add(Operator.EQ, termsEnum.get(t).originalTermBytes), null, new QueryContext(), false, LIMIT)) + .add(Operator.EQ, termsEnum.get(t).originalTermBytes), null, new QueryContext(), false)) { assertTrue(results.hasNext()); @@ -121,7 +121,7 @@ private void doTestEqQueriesAgainstStringIndex(Version version) throws Exception } try (KeyRangeIterator results = searcher.search(new Expression(indexContext) - .add(Operator.EQ, termsEnum.get(t).originalTermBytes), null, new QueryContext(), false, LIMIT)) + .add(Operator.EQ, termsEnum.get(t).originalTermBytes), null, new QueryContext(), false)) { assertTrue(results.hasNext()); @@ -143,12 +143,12 @@ private void doTestEqQueriesAgainstStringIndex(Version version) throws Exception // try searching for terms that weren't indexed final String tooLongTerm = randomSimpleString(10, 12); KeyRangeIterator results = searcher.search(new Expression(indexContext) - .add(Operator.EQ, UTF8Type.instance.decompose(tooLongTerm)), null, new QueryContext(), false, LIMIT); + .add(Operator.EQ, UTF8Type.instance.decompose(tooLongTerm)), null, new QueryContext(), false); assertFalse(results.hasNext()); final String tooShortTerm = randomSimpleString(1, 2); results = searcher.search(new Expression(indexContext) - .add(Operator.EQ, UTF8Type.instance.decompose(tooShortTerm)), null, new QueryContext(), false, LIMIT); + .add(Operator.EQ, UTF8Type.instance.decompose(tooShortTerm)), null, new QueryContext(), false); assertFalse(results.hasNext()); } } @@ -162,7 +162,7 @@ public void testUnsupportedOperator() throws Exception try (IndexSearcher searcher = buildIndexAndOpenSearcher(numTerms, numPostings, termsEnum)) { searcher.search(new Expression(indexContext) - .add(Operator.NEQ, UTF8Type.instance.decompose("a")), null, new QueryContext(), false, LIMIT); + .add(Operator.NEQ, UTF8Type.instance.decompose("a")), null, new QueryContext(), false); fail("Expect IllegalArgumentException thrown, but didn't"); } diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcherTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcherTest.java index 8f34b9a808aa..65242e319a37 100644 --- a/test/unit/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcherTest.java +++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcherTest.java @@ -151,7 +151,7 @@ public void testUnsupportedOperator() throws Exception {{ operation = Op.NOT_EQ; lower = upper = new Bound(ShortType.instance.decompose((short) 0), Int32Type.instance, true); - }}, null, new QueryContext(), false, LIMIT); + }}, null, new QueryContext(), false); fail("Expect IllegalArgumentException thrown, but didn't"); } @@ -169,7 +169,7 @@ private void testEqQueries(final IndexSearcher indexSearcher, {{ operation = Op.EQ; lower = upper = new Bound(rawType.decompose(rawValueProducer.apply(EQ_TEST_LOWER_BOUND_INCLUSIVE)), encodedType, true); - }}, null, new QueryContext(), false, LIMIT)) + }}, null, new QueryContext(), false)) { assertTrue(results.hasNext()); @@ -180,7 +180,7 @@ private void testEqQueries(final IndexSearcher indexSearcher, {{ operation = Op.EQ; lower = upper = new Bound(rawType.decompose(rawValueProducer.apply(EQ_TEST_UPPER_BOUND_EXCLUSIVE)), encodedType, true); - }}, null, new QueryContext(), false, LIMIT)) + }}, null, new QueryContext(), false)) { assertFalse(results.hasNext()); indexSearcher.close(); @@ -206,7 +206,7 @@ private void testRangeQueries(final IndexSearcher indexSearch lower = new Bound(rawType.decompose(rawValueProducer.apply((short)2)), encodedType, false); upper = new Bound(rawType.decompose(rawValueProducer.apply((short)7)), encodedType, true); - }}, null, new QueryContext(), false, LIMIT)) + }}, null, new QueryContext(), false)) { assertTrue(results.hasNext()); @@ -218,7 +218,7 @@ private void testRangeQueries(final IndexSearcher indexSearch {{ operation = Op.RANGE; lower = new Bound(rawType.decompose(rawValueProducer.apply(RANGE_TEST_UPPER_BOUND_EXCLUSIVE)), encodedType, true); - }}, null, new QueryContext(), false, LIMIT)) + }}, null, new QueryContext(), false)) { assertFalse(results.hasNext()); } @@ -227,7 +227,7 @@ private void testRangeQueries(final IndexSearcher indexSearch {{ operation = Op.RANGE; upper = new Bound(rawType.decompose(rawValueProducer.apply(RANGE_TEST_LOWER_BOUND_INCLUSIVE)), encodedType, false); - }}, null, new QueryContext(), false, LIMIT)) + }}, null, new QueryContext(), false)) { assertFalse(results.hasNext()); indexSearcher.close(); From cfe204adb00f92a7a1da0352f9f858573a824681 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Wed, 11 Dec 2024 09:40:53 -0600 Subject: [PATCH 10/29] eliminate currentRowIds --- .../v1/postings/IntersectingPostingList.java | 60 +++++++------------ 1 file changed, 22 insertions(+), 38 deletions(-) diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java index 4cd762fd23c6..7d1924c301ae 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java @@ -25,8 +25,6 @@ import org.apache.cassandra.index.sai.disk.PostingList; import org.apache.cassandra.io.util.FileUtils; -import static java.lang.Math.max; - /** * Performs intersection operations on multiple PostingLists, returning only postings * that appear in all inputs. @@ -37,10 +35,6 @@ public class IntersectingPostingList implements PostingList private final List postingLists; private final int size; - // currentRowIds state is effectively local to findNextIntersection, but we keep it - // around as a field to avoid repeated allocations there - private final int[] currentRowIds; - private IntersectingPostingList(List postingLists) { assert !postingLists.isEmpty(); @@ -49,7 +43,6 @@ private IntersectingPostingList(List postingLists) .mapToInt(PostingList::size) .min() .orElse(0); - this.currentRowIds = new int[postingLists.size()]; } /** @@ -81,44 +74,35 @@ public int advance(int targetRowID) throws IOException private int findNextIntersection(int targetRowID, boolean isAdvance) throws IOException { - // Initialize currentRowIds from the underlying posting lists + int maxRowId = targetRowID; + int maxRowIdIndex = -1; + + // Scan through all posting lists looking for a common row ID for (int i = 0; i < postingLists.size(); i++) { - currentRowIds[i] = isAdvance - ? postingLists.get(i).advance(targetRowID) - : postingLists.get(i).nextPosting(); - - if (currentRowIds[i] == END_OF_STREAM) + // don't advance the sublist in which we found our current max + if (i == maxRowIdIndex) + continue; + + // Advance this sublist to the current max, special casing the first one as needed + PostingList list = postingLists.get(i); + int rowId = (isAdvance || maxRowIdIndex >= 0) + ? list.advance(maxRowId) + : list.nextPosting(); + if (rowId == END_OF_STREAM) return END_OF_STREAM; - } - while (true) - { - // Find the maximum row ID among all posting lists - int maxRowId = targetRowID; - for (int rowId : currentRowIds) - maxRowId = max(maxRowId, rowId); - - // Advance any posting list that's behind the maximum - boolean allMatch = true; - for (int i = 0; i < postingLists.size(); i++) + // Update maxRowId + index if we find a larger value, or this was the first sublist evaluated + if (rowId > maxRowId || maxRowIdIndex < 0) { - if (currentRowIds[i] < maxRowId) - { - currentRowIds[i] = postingLists.get(i).advance(maxRowId); - if (currentRowIds[i] == END_OF_STREAM) - return END_OF_STREAM; - allMatch = false; - } + maxRowId = rowId; + maxRowIdIndex = i; + i = -1; // restart the scan with new maxRowId } - - // If all posting lists have the same row ID, we've found an intersection - if (allMatch) - return maxRowId; - - // Otherwise, continue searching with the new maximum as target - targetRowID = maxRowId; } + + // Once we complete a full scan without finding a larger rowId, we've found an intersection + return maxRowId; } @Override From cef71e30c10ce8da82a3374bf5749d3cf9f22422 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Wed, 11 Dec 2024 14:31:50 -0600 Subject: [PATCH 11/29] handle eq in an analyzed index by transforming it into a Match restriction in SingleColumnRelation.newEQRestriction. This eliminates the need for skipMerge and special cases in doMergeWith, and moves the issuing of warnings next to the place where the transformation occurs instead of doing it much later in RowFilterValidator (which is no longer needed) --- .../cassandra/cql3/SingleColumnRelation.java | 13 ++- .../ClusteringColumnRestrictions.java | 4 +- .../PartitionKeySingleRestrictionSet.java | 2 +- .../cql3/restrictions/RestrictionSet.java | 36 +++--- .../restrictions/SingleColumnRestriction.java | 24 +--- .../cql3/restrictions/SingleRestriction.java | 12 -- .../restrictions/StatementRestrictions.java | 4 +- .../org/apache/cassandra/db/ReadCommand.java | 2 - .../apache/cassandra/index/IndexRegistry.java | 25 ++--- .../cassandra/index/RowFilterValidator.java | 103 ------------------ .../index/SecondaryIndexManager.java | 6 - .../apache/cassandra/service/ClientWarn.java | 19 +++- 12 files changed, 62 insertions(+), 188 deletions(-) delete mode 100644 src/java/org/apache/cassandra/index/RowFilterValidator.java diff --git a/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java b/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java index a15251bdefa7..89ae1d38b8a3 100644 --- a/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java +++ b/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java @@ -23,6 +23,8 @@ import java.util.Objects; import org.apache.cassandra.db.marshal.VectorType; +import org.apache.cassandra.index.IndexRegistry; +import org.apache.cassandra.index.sai.analyzer.AnalyzerEqOperatorSupport; import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.TableMetadata; import org.apache.cassandra.cql3.Term.Raw; @@ -33,6 +35,7 @@ import org.apache.cassandra.db.marshal.ListType; import org.apache.cassandra.db.marshal.MapType; import org.apache.cassandra.exceptions.InvalidRequestException; +import org.apache.cassandra.service.ClientWarn; import static org.apache.cassandra.cql3.statements.RequestValidations.checkFalse; import static org.apache.cassandra.cql3.statements.RequestValidations.checkTrue; @@ -191,7 +194,15 @@ protected Restriction newEQRestriction(TableMetadata table, VariableSpecificatio if (mapKey == null) { Term term = toTerm(toReceivers(columnDef), value, table.keyspace, boundNames); - return new SingleColumnRestriction.EQRestriction(columnDef, term); + var analyzedIndex = IndexRegistry.obtain(table).supportsAnalyzedEq(columnDef); + if (analyzedIndex == null) + return new SingleColumnRestriction.EQRestriction(columnDef, term); + + ClientWarn.instance.warn(String.format(AnalyzerEqOperatorSupport.EQ_RESTRICTION_ON_ANALYZED_WARNING, + columnDef.toString(), + analyzedIndex.getIndexMetadata().name), + columnDef); + return new SingleColumnRestriction.AnalyzerMatchesRestriction(columnDef, term); } List receivers = toReceivers(columnDef); Term entryKey = toTerm(Collections.singletonList(receivers.get(0)), mapKey, table.keyspace, boundNames); diff --git a/src/java/org/apache/cassandra/cql3/restrictions/ClusteringColumnRestrictions.java b/src/java/org/apache/cassandra/cql3/restrictions/ClusteringColumnRestrictions.java index 092cc93cc7b7..0b33dfb16f69 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/ClusteringColumnRestrictions.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/ClusteringColumnRestrictions.java @@ -189,7 +189,7 @@ public ClusteringColumnRestrictions.Builder addRestriction(Restriction restricti SingleRestriction lastRestriction = restrictions.lastRestriction(); ColumnMetadata lastRestrictionStart = lastRestriction.getFirstColumn(); ColumnMetadata newRestrictionStart = newRestriction.getFirstColumn(); - restrictions.addRestriction(newRestriction, isDisjunction, indexRegistry); + restrictions.addRestriction(newRestriction, isDisjunction); checkFalse(lastRestriction.isSlice() && newRestrictionStart.position() > lastRestrictionStart.position(), "Clustering column \"%s\" cannot be restricted (preceding column \"%s\" is restricted by a non-EQ relation)", @@ -203,7 +203,7 @@ public ClusteringColumnRestrictions.Builder addRestriction(Restriction restricti } else { - restrictions.addRestriction(newRestriction, isDisjunction, indexRegistry); + restrictions.addRestriction(newRestriction, isDisjunction); } return this; diff --git a/src/java/org/apache/cassandra/cql3/restrictions/PartitionKeySingleRestrictionSet.java b/src/java/org/apache/cassandra/cql3/restrictions/PartitionKeySingleRestrictionSet.java index 50e6ea481f09..faddd9d9ff06 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/PartitionKeySingleRestrictionSet.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/PartitionKeySingleRestrictionSet.java @@ -189,7 +189,7 @@ public PartitionKeyRestrictions build(IndexRegistry indexRegistry, boolean isDis if (restriction.isOnToken()) return buildWithTokens(restrictionSet, i, indexRegistry); - restrictionSet.addRestriction((SingleRestriction) restriction, isDisjunction, indexRegistry); + restrictionSet.addRestriction((SingleRestriction) restriction, isDisjunction); } return buildPartitionKeyRestrictions(restrictionSet); diff --git a/src/java/org/apache/cassandra/cql3/restrictions/RestrictionSet.java b/src/java/org/apache/cassandra/cql3/restrictions/RestrictionSet.java index 92736c0d298a..f9dd22756fe0 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/RestrictionSet.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/RestrictionSet.java @@ -413,45 +413,35 @@ private Builder() { } - public void addRestriction(SingleRestriction restriction, boolean isDisjunction, IndexRegistry indexRegistry) + public void addRestriction(SingleRestriction restriction, boolean isDisjunction) { List columnDefs = restriction.getColumnDefs(); if (isDisjunction) { // If this restriction is part of a disjunction query then we don't want - // to merge the restrictions (if that is possible), we just add the - // restriction to the set of restrictions for the column. + // to merge the restrictions, we just add the new restriction addRestrictionForColumns(columnDefs, restriction, null); } else { - // In some special cases such as EQ in analyzed index we need to skip merging the restriction, - // so we can send multiple EQ restrictions to the index. - if (restriction.skipMerge(indexRegistry)) + // ANDed together restrictions against the same columns should be merged. + Set existingRestrictions = getRestrictions(newRestrictions, columnDefs); + // Trivial case of no existing restrictions + if (existingRestrictions.isEmpty()) { addRestrictionForColumns(columnDefs, restriction, null); return; } + // Since we merge new restrictions into the existing ones at each pass, there should only be + // at most one existing restriction across the same columnDefs + assert existingRestrictions.size() == 1 : existingRestrictions; - // If this restriction isn't part of a disjunction then we need to get - // the set of existing restrictions for the column and merge them with the - // new restriction - Set existingRestrictions = getRestrictions(newRestrictions, columnDefs); - - SingleRestriction merged = restriction; - Set replacedRestrictions = new HashSet<>(); - - for (SingleRestriction existing : existingRestrictions) - { - if (!existing.skipMerge(indexRegistry)) - { - merged = existing.mergeWith(merged); - replacedRestrictions.add(existing); - } - } + // Perform the merge + SingleRestriction existing = existingRestrictions.iterator().next(); + var merged = existing.mergeWith(restriction); - addRestrictionForColumns(merged.getColumnDefs(), merged, replacedRestrictions); + addRestrictionForColumns(merged.getColumnDefs(), merged, Set.of(existing)); } } diff --git a/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java b/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java index 1cde8cfc00e4..ced7f9d31cdb 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java @@ -198,24 +198,6 @@ public String toString() return String.format("EQ(%s)", term); } - @Override - public boolean skipMerge(IndexRegistry indexRegistry) - { - // We should skip merging this EQ if there is an analyzed index for this column that supports EQ, - // so there can be multiple EQs for the same column. - - if (indexRegistry == null) - return false; - - for (Index index : indexRegistry.listIndexes()) - { - if (index.supportsExpression(columnDef, Operator.ANALYZER_MATCHES) && - index.supportsExpression(columnDef, Operator.EQ)) - return true; - } - return false; - } - @Override public SingleRestriction doMergeWith(SingleRestriction otherRestriction) { @@ -1402,10 +1384,12 @@ public String toString() @Override public SingleRestriction doMergeWith(SingleRestriction otherRestriction) { - if (!(otherRestriction.isAnalyzerMatches())) + if (!otherRestriction.isAnalyzerMatches()) throw invalidRequest(CANNOT_BE_MERGED_ERROR, columnDef.name); - List otherValues = ((AnalyzerMatchesRestriction) otherRestriction).getValues(); + List otherValues = otherRestriction instanceof AnalyzerMatchesRestriction + ? ((AnalyzerMatchesRestriction) otherRestriction).getValues() + : List.of(((EQRestriction) otherRestriction).term); List newValues = new ArrayList<>(values.size() + otherValues.size()); newValues.addAll(values); newValues.addAll(otherValues); diff --git a/src/java/org/apache/cassandra/cql3/restrictions/SingleRestriction.java b/src/java/org/apache/cassandra/cql3/restrictions/SingleRestriction.java index 595451f812de..207e1786a114 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/SingleRestriction.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/SingleRestriction.java @@ -20,7 +20,6 @@ import org.apache.cassandra.cql3.QueryOptions; import org.apache.cassandra.cql3.statements.Bound; import org.apache.cassandra.db.MultiClusteringBuilder; -import org.apache.cassandra.index.IndexRegistry; /** * A single restriction/clause on one or multiple column. @@ -97,17 +96,6 @@ public default boolean isInclusive(Bound b) return true; } - /** - * Checks if this restriction shouldn't be merged with other restrictions. - * - * @param indexRegistry the index registry - * @return {@code true} if this shouldn't be merged with other restrictions - */ - default boolean skipMerge(IndexRegistry indexRegistry) - { - return false; - } - /** * Merges this restriction with the specified one. * diff --git a/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java b/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java index 7e92d3273077..4d1b4bb68a09 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java @@ -442,7 +442,7 @@ else if (def.isClusteringColumn() && nestingLevel == 0) } else { - nonPrimaryKeyRestrictionSet.addRestriction((SingleRestriction) restriction, element.isDisjunction(), indexRegistry); + nonPrimaryKeyRestrictionSet.addRestriction((SingleRestriction) restriction, element.isDisjunction()); } } } @@ -699,7 +699,7 @@ else if (indexOrderings.size() == 1) throw new InvalidRequestException(String.format(NON_CLUSTER_ORDERING_REQUIRES_INDEX_MESSAGE, restriction.getFirstColumn())); } - receiver.addRestriction(restriction, false, indexRegistry); + receiver.addRestriction(restriction, false); } } diff --git a/src/java/org/apache/cassandra/db/ReadCommand.java b/src/java/org/apache/cassandra/db/ReadCommand.java index a289987a1adc..c8b38695e004 100644 --- a/src/java/org/apache/cassandra/db/ReadCommand.java +++ b/src/java/org/apache/cassandra/db/ReadCommand.java @@ -383,8 +383,6 @@ static Index.QueryPlan findIndexQueryPlan(TableMetadata table, RowFilter rowFilt @Override public void maybeValidateIndexes() { - IndexRegistry.obtain(metadata()).validate(rowFilter()); - if (null != indexQueryPlan) indexQueryPlan.validate(this); } diff --git a/src/java/org/apache/cassandra/index/IndexRegistry.java b/src/java/org/apache/cassandra/index/IndexRegistry.java index cc6b0d103ea9..ec41a068d9ac 100644 --- a/src/java/org/apache/cassandra/index/IndexRegistry.java +++ b/src/java/org/apache/cassandra/index/IndexRegistry.java @@ -102,12 +102,6 @@ public Optional getBestIndexFor(RowFilter.Expression expression) public void validate(PartitionUpdate update) { } - - @Override - public void validate(RowFilter filter) - { - // no-op since it's an empty registry - } }; /** @@ -295,12 +289,6 @@ public Optional getBestIndexFor(RowFilter.Expression expression) public void validate(PartitionUpdate update) { } - - @Override - public void validate(RowFilter filter) - { - // no-op since it's an empty registry - } }; default void registerIndex(Index index) @@ -341,8 +329,6 @@ default Optional getAnalyzerFor(ColumnMetadata column, Operator */ void validate(PartitionUpdate update); - void validate(RowFilter filter); - /** * Returns the {@code IndexRegistry} associated to the specified table. * @@ -356,4 +342,15 @@ public static IndexRegistry obtain(TableMetadata table) return table.isVirtual() ? EMPTY : Keyspace.openAndGetStore(table).indexManager; } + + default Index supportsAnalyzedEq(ColumnMetadata cm) + { + for (Index index : listIndexes()) + { + if (index.supportsExpression(cm, Operator.ANALYZER_MATCHES) && + index.supportsExpression(cm, Operator.EQ)) + return index; + } + return null; + } } diff --git a/src/java/org/apache/cassandra/index/RowFilterValidator.java b/src/java/org/apache/cassandra/index/RowFilterValidator.java deleted file mode 100644 index fb70fbfc1452..000000000000 --- a/src/java/org/apache/cassandra/index/RowFilterValidator.java +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Copyright DataStax, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.index; - -import java.util.HashSet; -import java.util.Set; -import java.util.StringJoiner; - -import org.apache.cassandra.cql3.Operator; -import org.apache.cassandra.db.filter.RowFilter; -import org.apache.cassandra.index.sai.analyzer.AnalyzerEqOperatorSupport; -import org.apache.cassandra.schema.ColumnMetadata; -import org.apache.cassandra.service.ClientWarn; - -/** - * Class for validating the index-related aspects of a {@link RowFilter}, without considering what index is actually used. - *

- * It will emit a client warning when a query has EQ restrictions on columns having an analyzed index. - */ -class RowFilterValidator -{ - private final Iterable allIndexes; - - private Set columns; - private Set indexes; - - private RowFilterValidator(Iterable allIndexes) - { - this.allIndexes = allIndexes; - } - - private void addEqRestriction(ColumnMetadata column) - { - for (Index index : allIndexes) - { - if (index.supportsExpression(column, Operator.EQ) && - index.supportsExpression(column, Operator.ANALYZER_MATCHES)) - { - if (columns == null) - columns = new HashSet<>(); - columns.add(column); - - if (indexes == null) - indexes = new HashSet<>(); - indexes.add(index); - } - } - } - - private void validate() - { - if (columns == null || indexes == null) - return; - - StringJoiner columnNames = new StringJoiner(", "); - StringJoiner indexNames = new StringJoiner(", "); - columns.forEach(column -> columnNames.add(column.name.toString())); - indexes.forEach(index -> indexNames.add(index.getIndexMetadata().name)); - - ClientWarn.instance.warn(String.format(AnalyzerEqOperatorSupport.EQ_RESTRICTION_ON_ANALYZED_WARNING, columnNames, indexNames)); - } - - /** - * Emits a client warning if the filter contains EQ restrictions on columns having an analyzed index. - * - * @param filter the filter to validate - * @param indexes the existing indexes - */ - public static void validate(RowFilter filter, Iterable indexes) - { - RowFilterValidator validator = new RowFilterValidator(indexes); - validate(filter.root(), validator); - validator.validate(); - } - - private static void validate(RowFilter.FilterElement element, RowFilterValidator validator) - { - for (RowFilter.Expression expression : element.expressions()) - { - if (expression.operator() == Operator.EQ) - validator.addEqRestriction(expression.column()); - } - - for (RowFilter.FilterElement child : element.children()) - { - validate(child, validator); - } - } -} diff --git a/src/java/org/apache/cassandra/index/SecondaryIndexManager.java b/src/java/org/apache/cassandra/index/SecondaryIndexManager.java index 9954d1b3ee1c..21d633b30a39 100644 --- a/src/java/org/apache/cassandra/index/SecondaryIndexManager.java +++ b/src/java/org/apache/cassandra/index/SecondaryIndexManager.java @@ -1277,12 +1277,6 @@ public void validate(PartitionUpdate update) throws InvalidRequestException index.validate(update); } - @Override - public void validate(RowFilter filter) - { - RowFilterValidator.validate(filter, indexes.values()); - } - /* * IndexRegistry methods */ diff --git a/src/java/org/apache/cassandra/service/ClientWarn.java b/src/java/org/apache/cassandra/service/ClientWarn.java index 5a6a878681e1..38570a06d2b8 100644 --- a/src/java/org/apache/cassandra/service/ClientWarn.java +++ b/src/java/org/apache/cassandra/service/ClientWarn.java @@ -18,7 +18,9 @@ package org.apache.cassandra.service; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Set; import io.netty.util.concurrent.FastThreadLocal; import org.apache.cassandra.concurrent.ExecutorLocal; @@ -45,10 +47,18 @@ public void set(State value) } public void warn(String text) + { + warn(text, null); + } + + /** + * Issue the given warning if this is the first time `key` is seen. + */ + public void warn(String text, Object key) { State state = warnLocal.get(); if (state != null) - state.add(text); + state.add(text, key); } public void captureWarnings() @@ -72,11 +82,16 @@ public void resetWarnings() public static class State { private final List warnings = new ArrayList<>(); + private final Set keysAdded = new HashSet<>(); - private void add(String warning) + private void add(String warning, Object key) { if (warnings.size() < FBUtilities.MAX_UNSIGNED_SHORT) + { + if (key != null && !keysAdded.add(key)) + return; warnings.add(maybeTruncate(warning)); + } } private static String maybeTruncate(String warning) From 3d17e2f63f64f434cb819d1db214ae595bd6c721 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Wed, 11 Dec 2024 15:25:30 -0600 Subject: [PATCH 12/29] add testMatchingAllowed and make it work via shouldMerge --- .../cql3/restrictions/RestrictionSet.java | 25 ++++++++++--------- .../restrictions/SingleColumnRestriction.java | 11 ++++++++ .../cql3/restrictions/SingleRestriction.java | 12 +++++++++ .../cassandra/index/sai/cql/BM25Test.java | 16 ++++++++++++ 4 files changed, 52 insertions(+), 12 deletions(-) diff --git a/src/java/org/apache/cassandra/cql3/restrictions/RestrictionSet.java b/src/java/org/apache/cassandra/cql3/restrictions/RestrictionSet.java index f9dd22756fe0..9d459e1da247 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/RestrictionSet.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/RestrictionSet.java @@ -427,21 +427,22 @@ public void addRestriction(SingleRestriction restriction, boolean isDisjunction) { // ANDed together restrictions against the same columns should be merged. Set existingRestrictions = getRestrictions(newRestrictions, columnDefs); - // Trivial case of no existing restrictions - if (existingRestrictions.isEmpty()) + + // merge the new restriction into an existing one. note that there is only ever a single + // restriction (per column), UNLESS one is ORDER BY BM25 and the other is MATCH. + for (var existing : existingRestrictions) { - addRestrictionForColumns(columnDefs, restriction, null); - return; + // shouldMerge exists for the BM25/MATCH case + if (existing.shouldMerge(restriction)) + { + var merged = existing.mergeWith(restriction); + addRestrictionForColumns(merged.getColumnDefs(), merged, Set.of(existing)); + return; + } } - // Since we merge new restrictions into the existing ones at each pass, there should only be - // at most one existing restriction across the same columnDefs - assert existingRestrictions.size() == 1 : existingRestrictions; - // Perform the merge - SingleRestriction existing = existingRestrictions.iterator().next(); - var merged = existing.mergeWith(restriction); - - addRestrictionForColumns(merged.getColumnDefs(), merged, Set.of(existing)); + // no existing restrictions that we should merge the new one with, add a new one + addRestrictionForColumns(columnDefs, restriction, null); } } diff --git a/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java b/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java index ced7f9d31cdb..5d055fd7fb89 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java @@ -1237,6 +1237,17 @@ public boolean isIndexBasedOrdering() { return true; } + + @Override + public boolean shouldMerge(SingleRestriction other) + { + // we don't want to merge MATCH restrictions with ORDER BY BM25 + // so shouldMerge = false for that scenario, and true for others + // (because even though we can't meaningfully merge with others, we want doMergeWith to be called to throw) + // + // (Note that because ORDER BY is processed before WHERE, we only need this check in the BM25 class) + return !other.isAnalyzerMatches(); + } } /** diff --git a/src/java/org/apache/cassandra/cql3/restrictions/SingleRestriction.java b/src/java/org/apache/cassandra/cql3/restrictions/SingleRestriction.java index 207e1786a114..bdd80badc0ae 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/SingleRestriction.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/SingleRestriction.java @@ -129,4 +129,16 @@ public default MultiClusteringBuilder appendBoundTo(MultiClusteringBuilder build { return appendTo(builder, options); } + + /** + * @return true if the other restriction should be merged with this one. + * This is NOT for preventing illegal combinations of restrictions, e.g. + * a=1 AND a=2; that is handled by mergeWith. Instead, this is for the case + * where we want two completely different semantics against the same column. + * Currently the only such case is BM25 with MATCH. + */ + default boolean shouldMerge(SingleRestriction other) + { + return true; + } } diff --git a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java index f38613150a64..3f220527d648 100644 --- a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java +++ b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java @@ -27,6 +27,22 @@ public class BM25Test extends SAITester { + @Test + public void testMatchingAllowed() throws Throwable + { + // match operator should be allowed with BM25 on the same column + // (seems obvious but exercises a corner case in the internal RestrictionSet processing) + createSimpleTable(); + + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + + beforeAndAfterFlush(() -> + { + var result = execute("SELECT k FROM %s WHERE v : 'apple' ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, row(1)); + }); + } + @Test public void testTermFrequencyOrdering() throws Throwable { From 907a2eed33c7ecb2206e5c4f2a9d7f652a1c14a4 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Thu, 12 Dec 2024 08:31:01 -0600 Subject: [PATCH 13/29] disambiguate the BM25 error message when the index isn't analyzed --- .../cql3/restrictions/StatementRestrictions.java | 9 +++++++-- .../apache/cassandra/index/sai/cql/BM25Test.java | 15 +++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java b/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java index 4d1b4bb68a09..b4bf34dbfa54 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java @@ -81,6 +81,7 @@ public class StatementRestrictions "Restriction on partition key column %s must not be nested under OR operator"; public static final String GEO_DISTANCE_REQUIRES_INDEX_MESSAGE = "GEO_DISTANCE requires the vector column to be indexed"; + public static final String BM25_ORDERING_REQUIRES_ANALYZED_INDEX_MESSAGE = "BM25 ordering on column %s requires an analyzed index"; public static final String NON_CLUSTER_ORDERING_REQUIRES_INDEX_MESSAGE = "Ordering on non-clustering column %s requires the column to be indexed"; public static final String NON_CLUSTER_ORDERING_REQUIRES_ALL_RESTRICTED_NON_PARTITION_KEY_COLUMNS_INDEXED_MESSAGE = "Ordering on non-clustering column requires each restricted column to be indexed except for fully-specified partition keys"; @@ -696,8 +697,12 @@ else if (indexOrderings.size() == 1) throw new InvalidRequestException(String.format("SAI based ordering on column %s of type %s is not supported", restriction.getFirstColumn(), restriction.getFirstColumn().type.asCQL3Type())); - throw new InvalidRequestException(String.format(NON_CLUSTER_ORDERING_REQUIRES_INDEX_MESSAGE, - restriction.getFirstColumn())); + if (ordering.expression instanceof Ordering.Bm25) + throw new InvalidRequestException(String.format(BM25_ORDERING_REQUIRES_ANALYZED_INDEX_MESSAGE, + restriction.getFirstColumn())); + else + throw new InvalidRequestException(String.format(NON_CLUSTER_ORDERING_REQUIRES_INDEX_MESSAGE, + restriction.getFirstColumn())); } receiver.addRestriction(restriction, false); } diff --git a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java index 3f220527d648..44f5c76982f7 100644 --- a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java +++ b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java @@ -20,9 +20,11 @@ import org.junit.Test; +import org.apache.cassandra.exceptions.InvalidRequestException; import org.apache.cassandra.index.sai.SAITester; import org.apache.cassandra.index.sai.plan.QueryController; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; public class BM25Test extends SAITester @@ -43,6 +45,19 @@ public void testMatchingAllowed() throws Throwable }); } + @Test + public void testTwoIndexes() throws Throwable + { + // create un-analyzed index + createTable("CREATE TABLE %s (k int PRIMARY KEY, v text)"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'"); + + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + assertThatThrownBy(() -> execute("SELECT k FROM %s WHERE v : 'apple' ORDER BY v BM25 OF 'apple' LIMIT 3")) + .isInstanceOf(InvalidRequestException.class) + .hasMessage("BM25 ordering on column v requires an analyzed index"); + } + @Test public void testTermFrequencyOrdering() throws Throwable { From 3315a1200d5b74b9e3cdc56364307dcfea91288c Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Thu, 12 Dec 2024 08:57:37 -0600 Subject: [PATCH 14/29] validateOptions treats analyzed and un-analyzed indexes as distinct, testTwoIndexes passes --- .../index/sai/StorageAttachedIndex.java | 30 +++++++++++-------- .../cassandra/index/sai/cql/BM25Test.java | 9 +++++- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java b/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java index 74482d97a7c9..691f0fd99308 100644 --- a/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java +++ b/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java @@ -297,20 +297,26 @@ public static Map validateOptions(Map options, T throw new InvalidRequestException("Failed to retrieve target column for: " + targetColumn); } - // In order to support different index target on non-frozen map, ie. KEYS, VALUE, ENTRIES, we need to put index - // name as part of index file name instead of column name. We only need to check that the target is different - // between indexes. This will only allow indexes in the same column with a different IndexTarget.Type. - // - // Note that: "metadata.indexes" already includes current index - if (metadata.indexes.stream().filter(index -> index.getIndexClassName().equals(StorageAttachedIndex.class.getName())) - .map(index -> TargetParser.parse(metadata, index.options.get(IndexTarget.TARGET_OPTION_NAME))) - .filter(Objects::nonNull).filter(t -> t.equals(target)).count() > 1) - { - throw new InvalidRequestException("Cannot create more than one storage-attached index on the same column: " + target.left); - } + // Check for duplicate indexes considering both target and analyzer configuration + boolean isAnalyzed = AbstractAnalyzer.isAnalyzed(options); + long duplicateCount = metadata.indexes.stream() + .filter(index -> index.getIndexClassName().equals(StorageAttachedIndex.class.getName())) + .filter(index -> { + // Indexes on the same column with different target (KEYS, VALUES, ENTRIES) + // are allowed on non-frozen Maps + var existingTarget = TargetParser.parse(metadata, index.options.get(IndexTarget.TARGET_OPTION_NAME)); + if (existingTarget == null || !existingTarget.equals(target)) + return false; + // Also allow different indexes if one is analyzed and the other isn't + return isAnalyzed == AbstractAnalyzer.isAnalyzed(index.options); + }) + .count(); + // >1 because "metadata.indexes" already includes current index + if (duplicateCount > 1) + throw new InvalidRequestException(String.format("Cannot create duplicate storage-attached index on column: %s", target.left)); // Analyzer is not supported against PK columns - if (AbstractAnalyzer.isAnalyzed(options)) + if (isAnalyzed) { for (ColumnMetadata column : metadata.primaryKeyColumns()) { diff --git a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java index 44f5c76982f7..cd6457f162fe 100644 --- a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java +++ b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java @@ -51,11 +51,18 @@ public void testTwoIndexes() throws Throwable // create un-analyzed index createTable("CREATE TABLE %s (k int PRIMARY KEY, v text)"); createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'"); - execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + + // BM25 should fail with only an equality index assertThatThrownBy(() -> execute("SELECT k FROM %s WHERE v : 'apple' ORDER BY v BM25 OF 'apple' LIMIT 3")) .isInstanceOf(InvalidRequestException.class) .hasMessage("BM25 ordering on column v requires an analyzed index"); + + // create analyzed index + analyzeIndex(); + // BM25 query should work now + var result = execute("SELECT k FROM %s WHERE v : 'apple' ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, row(1)); } @Test From c0de416a1cb4c3daaf378dbed5877aefc1e2e620 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Thu, 12 Dec 2024 09:16:49 -0600 Subject: [PATCH 15/29] detect and reject ambiguous equality predicates; testAmbiguousPredicates passes --- .../cassandra/cql3/SingleColumnRelation.java | 27 +++- .../apache/cassandra/index/IndexRegistry.java | 90 ++++++++++- .../analyzer/AnalyzerEqOperatorSupport.java | 9 +- .../org/apache/cassandra/cql3/CQLTester.java | 5 + .../cassandra/index/sai/cql/BM25Test.java | 146 +++++++++++++++--- 5 files changed, 242 insertions(+), 35 deletions(-) diff --git a/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java b/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java index 89ae1d38b8a3..4dc6aaa78a7c 100644 --- a/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java +++ b/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Objects; +import java.util.stream.Collectors; import org.apache.cassandra.db.marshal.VectorType; import org.apache.cassandra.index.IndexRegistry; @@ -194,15 +195,27 @@ protected Restriction newEQRestriction(TableMetadata table, VariableSpecificatio if (mapKey == null) { Term term = toTerm(toReceivers(columnDef), value, table.keyspace, boundNames); - var analyzedIndex = IndexRegistry.obtain(table).supportsAnalyzedEq(columnDef); - if (analyzedIndex == null) + // Leave the restriction as EQ if no analyzed index in backwards compatibility mode is present + var ebi = IndexRegistry.obtain(table).getEqBehavior(columnDef); + if (ebi.behavior == IndexRegistry.EqBehavior.EQ) return new SingleColumnRestriction.EQRestriction(columnDef, term); - ClientWarn.instance.warn(String.format(AnalyzerEqOperatorSupport.EQ_RESTRICTION_ON_ANALYZED_WARNING, - columnDef.toString(), - analyzedIndex.getIndexMetadata().name), - columnDef); - return new SingleColumnRestriction.AnalyzerMatchesRestriction(columnDef, term); + // the index is configured to transform EQ into MATCH for backwards compatibility + if (ebi.behavior == IndexRegistry.EqBehavior.MATCH) + { + ClientWarn.instance.warn(String.format(AnalyzerEqOperatorSupport.EQ_RESTRICTION_ON_ANALYZED_WARNING, + columnDef.toString(), + ebi.matchIndex.getIndexMetadata().name), + columnDef); + return new SingleColumnRestriction.AnalyzerMatchesRestriction(columnDef, term); + } + + // multiple indexes support EQ, this is unsupported + assert ebi.behavior == IndexRegistry.EqBehavior.AMBIGUOUS; + throw invalidRequest(AnalyzerEqOperatorSupport.EQ_AMBIGUOUS_ERROR, + columnDef.toString(), + ebi.matchIndex.getIndexMetadata().name, + ebi.eqIndex.getIndexMetadata().name); } List receivers = toReceivers(columnDef); Term entryKey = toTerm(Collections.singletonList(receivers.get(0)), mapKey, table.keyspace, boundNames); diff --git a/src/java/org/apache/cassandra/index/IndexRegistry.java b/src/java/org/apache/cassandra/index/IndexRegistry.java index ec41a068d9ac..0ba923d9b3c3 100644 --- a/src/java/org/apache/cassandra/index/IndexRegistry.java +++ b/src/java/org/apache/cassandra/index/IndexRegistry.java @@ -343,14 +343,94 @@ public static IndexRegistry obtain(TableMetadata table) return table.isVirtual() ? EMPTY : Keyspace.openAndGetStore(table).indexManager; } - default Index supportsAnalyzedEq(ColumnMetadata cm) + enum EqBehavior { + EQ, + MATCH, + AMBIGUOUS + } + + class EqBehaviorIndexes + { + public EqBehavior behavior; + public final Index eqIndex; + public final Index matchIndex; + + private EqBehaviorIndexes(Index eqIndex, Index matchIndex, EqBehavior behavior) + { + this.eqIndex = eqIndex; + this.matchIndex = matchIndex; + this.behavior = behavior; + } + + public static EqBehaviorIndexes eq(Index eqIndex) + { + return new EqBehaviorIndexes(eqIndex, null, EqBehavior.EQ); + } + + public static EqBehaviorIndexes match(Index eqAndMatchIndex) + { + return new EqBehaviorIndexes(eqAndMatchIndex, eqAndMatchIndex, EqBehavior.MATCH); + } + + public static EqBehaviorIndexes ambiguous(Index firstEqIndex, Index secondEqIndex) + { + return new EqBehaviorIndexes(firstEqIndex, secondEqIndex, EqBehavior.AMBIGUOUS); + } + } + + /** + * @return - EQ if a single index supports EQ + * - MATCHES if a single index supports both + * - AMBIGUOUS if multiple indexes support EQ + */ + default EqBehaviorIndexes getEqBehavior(ColumnMetadata cm) + { + // scan the indexes for MATCHES and EQ support + Index matchesIndex = null; + Index eqIndex = null; for (Index index : listIndexes()) { - if (index.supportsExpression(cm, Operator.ANALYZER_MATCHES) && - index.supportsExpression(cm, Operator.EQ)) - return index; + if (index.supportsExpression(cm, Operator.EQ)) + { + if (eqIndex == null) + { + eqIndex = index; + continue; + } + + // If we find a second EQ index, return AMBIGUOUS, taking care to assign the eqIndex and matchIndex correctly + if (index.supportsExpression(cm, Operator.ANALYZER_MATCHES)) + { + matchesIndex = index; + } + else + { + assert eqIndex.supportsExpression(cm, Operator.ANALYZER_MATCHES); + matchesIndex = eqIndex; + eqIndex = index; + } + return EqBehaviorIndexes.ambiguous(eqIndex, matchesIndex); + } + + if (index.supportsExpression(cm, Operator.ANALYZER_MATCHES)) + { + // should only ever have one + assert matchesIndex == null; + matchesIndex = index; + } } - return null; + + // If we didn't find any indexes that support EQ or MATCHES, return EQ + if (eqIndex == null && matchesIndex == null) + return EqBehaviorIndexes.eq(null); + + // If the same index supports both EQ and MATCHES, promote to MATCHES + if (eqIndex == matchesIndex) + return EqBehaviorIndexes.match(matchesIndex); + + // Otherwise we either have no MATCHES index, or it's distinct from the EQ index. + // In both cases we want to return EQ + return EqBehaviorIndexes.eq(eqIndex); } } diff --git a/src/java/org/apache/cassandra/index/sai/analyzer/AnalyzerEqOperatorSupport.java b/src/java/org/apache/cassandra/index/sai/analyzer/AnalyzerEqOperatorSupport.java index 116bc7f62832..30408c9b986f 100644 --- a/src/java/org/apache/cassandra/index/sai/analyzer/AnalyzerEqOperatorSupport.java +++ b/src/java/org/apache/cassandra/index/sai/analyzer/AnalyzerEqOperatorSupport.java @@ -49,7 +49,7 @@ public class AnalyzerEqOperatorSupport OPTION, Arrays.toString(Value.values())); public static final String EQ_RESTRICTION_ON_ANALYZED_WARNING = - String.format("Columns [%%s] are restricted by '=' and have analyzed indexes [%%s] able to process those restrictions. " + + String.format("Column [%%s] is restricted by '=' and has an analyzed index [%%s] able to process those restrictions. " + "Analyzed indexes might process '=' restrictions in a way that is inconsistent with non-indexed queries. " + "While '=' is still supported on analyzed indexes for backwards compatibility, " + "it is recommended to use the ':' operator instead to prevent the ambiguity. " + @@ -58,6 +58,13 @@ public class AnalyzerEqOperatorSupport "please use '%s':'%s' in the index options.", OPTION, Value.UNSUPPORTED.toString().toLowerCase()); + public static final String EQ_AMBIGUOUS_ERROR = + String.format("Column [%%s] equality predicate is ambiguous. It has both an analyzed index [%%s] configured with '%s':'%s', " + + "and an un-analyzed index [%%s]. " + + "To avoid ambiguity, drop the analyzed index and recreate it with option '%s':'%s'.", + OPTION, Value.MATCH.toString().toLowerCase(), OPTION, Value.UNSUPPORTED.toString().toLowerCase()); + + public static final String LWT_CONDITION_ON_ANALYZED_WARNING = "Index analyzers not applied to LWT conditions on columns [%s]."; diff --git a/test/unit/org/apache/cassandra/cql3/CQLTester.java b/test/unit/org/apache/cassandra/cql3/CQLTester.java index 5a6ec34b02ef..a28a76b95524 100644 --- a/test/unit/org/apache/cassandra/cql3/CQLTester.java +++ b/test/unit/org/apache/cassandra/cql3/CQLTester.java @@ -803,6 +803,11 @@ protected String currentIndex() return indexes.get(indexes.size() - 1); } + protected String getIndex(int i) + { + return indexes.get(i); + } + protected Collection currentTables() { if (tables == null || tables.isEmpty()) diff --git a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java index cd6457f162fe..69d51a19be44 100644 --- a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java +++ b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java @@ -20,33 +20,16 @@ import org.junit.Test; -import org.apache.cassandra.exceptions.InvalidRequestException; import org.apache.cassandra.index.sai.SAITester; import org.apache.cassandra.index.sai.plan.QueryController; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.apache.cassandra.index.sai.analyzer.AnalyzerEqOperatorSupport.EQ_AMBIGUOUS_ERROR; import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; public class BM25Test extends SAITester { @Test - public void testMatchingAllowed() throws Throwable - { - // match operator should be allowed with BM25 on the same column - // (seems obvious but exercises a corner case in the internal RestrictionSet processing) - createSimpleTable(); - - execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); - - beforeAndAfterFlush(() -> - { - var result = execute("SELECT k FROM %s WHERE v : 'apple' ORDER BY v BM25 OF 'apple' LIMIT 3"); - assertRows(result, row(1)); - }); - } - - @Test - public void testTwoIndexes() throws Throwable + public void testTwoIndexes() { // create un-analyzed index createTable("CREATE TABLE %s (k int PRIMARY KEY, v text)"); @@ -54,9 +37,8 @@ public void testTwoIndexes() throws Throwable execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); // BM25 should fail with only an equality index - assertThatThrownBy(() -> execute("SELECT k FROM %s WHERE v : 'apple' ORDER BY v BM25 OF 'apple' LIMIT 3")) - .isInstanceOf(InvalidRequestException.class) - .hasMessage("BM25 ordering on column v requires an analyzed index"); + assertInvalidMessage("BM25 ordering on column v requires an analyzed index", + "SELECT k FROM %s WHERE v : 'apple' ORDER BY v BM25 OF 'apple' LIMIT 3"); // create analyzed index analyzeIndex(); @@ -65,6 +47,126 @@ public void testTwoIndexes() throws Throwable assertRows(result, row(1)); } + @Test + public void testTwoIndexesAmbiguousPredicate() throws Throwable + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, v text)"); + + // Create analyzed and un-analyzed indexes + analyzeIndex(); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'"); + + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'apple juice')"); + execute("INSERT INTO %s (k, v) VALUES (3, 'orange juice')"); + + // equality predicate is ambiguous (both analyzed and un-analyzed indexes could support it) so it should + // be rejected + beforeAndAfterFlush(() -> { + // Single predicate + assertInvalidMessage(String.format(EQ_AMBIGUOUS_ERROR, "v", getIndex(0), getIndex(1)), + "SELECT k FROM %s WHERE v = 'apple'"); + + // AND + assertInvalidMessage(String.format(EQ_AMBIGUOUS_ERROR, "v", getIndex(0), getIndex(1)), + "SELECT k FROM %s WHERE v = 'apple' AND v : 'juice'"); + + // OR + assertInvalidMessage(String.format(EQ_AMBIGUOUS_ERROR, "v", getIndex(0), getIndex(1)), + "SELECT k FROM %s WHERE v = 'apple' OR v : 'juice'"); + }); + } + + @Test + public void testTwoIndexesWithEqualsUnsupported() throws Throwable + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, v text)"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'"); + // analyzed index with equals_behavior:unsupported option + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' " + + "WITH OPTIONS = { 'equals_behaviour_when_analyzed': 'unsupported', " + + "'index_analyzer':'{\"tokenizer\":{\"name\":\"standard\"},\"filters\":[{\"name\":\"porterstem\"}]}' }"); + + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'apple juice')"); + + beforeAndAfterFlush(() -> { + // combining two EQ predicates is not allowed + assertInvalid("SELECT k FROM %s WHERE v = 'apple' AND v = 'juice'"); + + // combining EQ and MATCH predicates is also not allowed (when we're not converting EQ to MATCH) + assertInvalid("SELECT k FROM %s WHERE v = 'apple' AND v : 'apple'"); + + // combining two MATCH predicates is fine + assertRows(execute("SELECT k FROM %s WHERE v : 'apple' AND v : 'juice'"), + row(2)); + + // = operator should use un-analyzed index since equals is unsupported in analyzed index + assertRows(execute("SELECT k FROM %s WHERE v = 'apple'"), + row(1)); + + // : operator should use analyzed index + assertRows(execute("SELECT k FROM %s WHERE v : 'apple'"), + row(1), row(2)); + }); + } + + @Test + public void testComplexQueriesWithMultipleIndexes() throws Throwable + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, v1 text, v2 text, v3 int)"); + + // Create mix of analyzed, unanalyzed, and non-text indexes + createIndex("CREATE CUSTOM INDEX ON %s(v1) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'"); + createIndex("CREATE CUSTOM INDEX ON %s(v2) " + + "USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' " + + "WITH OPTIONS = {" + + "'index_analyzer': '{" + + "\"tokenizer\" : {\"name\" : \"standard\"}, " + + "\"filters\" : [{\"name\" : \"porterstem\"}]" + + "}'" + + "}"); + createIndex("CREATE CUSTOM INDEX ON %s(v3) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'"); + + execute("INSERT INTO %s (k, v1, v2, v3) VALUES (1, 'apple', 'orange juice', 5)"); + execute("INSERT INTO %s (k, v1, v2, v3) VALUES (2, 'apple juice', 'apple', 10)"); + execute("INSERT INTO %s (k, v1, v2, v3) VALUES (3, 'banana', 'grape juice', 5)"); + + beforeAndAfterFlush(() -> { + // Complex query mixing different types of indexes and operators + assertRows(execute("SELECT k FROM %s WHERE v1 = 'apple' AND v2 : 'juice' AND v3 = 5"), + row(1)); + + // Mix of AND and OR conditions across different index types + assertRows(execute("SELECT k FROM %s WHERE v3 = 5 AND (v1 = 'apple' OR v2 : 'apple')"), + row(1)); + + // Multi-term analyzed query + assertRows(execute("SELECT k FROM %s WHERE v2 : 'orange juice'"), + row(1)); + + // Range query with text match + assertRows(execute("SELECT k FROM %s WHERE v3 >= 5 AND v2 : 'juice'"), + row(1), row(3)); + }); + } + + @Test + public void testMatchingAllowed() throws Throwable + { + // match operator should be allowed with BM25 on the same column + // (seems obvious but exercises a corner case in the internal RestrictionSet processing) + createSimpleTable(); + + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + + beforeAndAfterFlush(() -> + { + var result = execute("SELECT k FROM %s WHERE v : 'apple' ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, row(1)); + }); + } + @Test public void testTermFrequencyOrdering() throws Throwable { From 3967e7c917457a1c683348b33eb1b9169ff74e7f Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Thu, 12 Dec 2024 17:10:40 -0600 Subject: [PATCH 16/29] don't inject +score unless coordinator requests it; this is a cleaner approach than ignoring it when serialization fails later --- .../cql3/statements/SelectStatement.java | 3 +- src/java/org/apache/cassandra/db/Columns.java | 8 ----- .../cassandra/db/filter/ColumnFilter.java | 15 ++++++++++ .../db/rows/UnfilteredSerializer.java | 5 ---- .../index/sai/plan/QueryController.java | 5 ++++ .../plan/StorageAttachedIndexSearcher.java | 29 ++++++++++++++----- .../index/sai/plan/TopKProcessor.java | 2 +- 7 files changed, 43 insertions(+), 24 deletions(-) diff --git a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java index 847d058cf26b..810dd605577f 100644 --- a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java +++ b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java @@ -103,9 +103,8 @@ public class SelectStatement implements CQLStatement.SingleKeyspaceCqlStatement { // TODO remove this when we no longer need to downgrade to replicas that don't know about synthetic columns, // and the related code in - // - Columns.Serializer.encodeBitmap - // - UnfilteredSerializer.serializeRowBody) // - StatementRestrictions.addOrderingRestrictions + // - StorageAttachedIndexSearcher.PrimaryKeyIterator constructor public static final boolean ANN_USE_SYNTHETIC_SCORE = Boolean.parseBoolean(System.getProperty("cassandra.sai.ann_use_synthetic_score", "false")); private static final Logger logger = LoggerFactory.getLogger(SelectStatement.class); diff --git a/src/java/org/apache/cassandra/db/Columns.java b/src/java/org/apache/cassandra/db/Columns.java index becd23508d1d..cef03393a1db 100644 --- a/src/java/org/apache/cassandra/db/Columns.java +++ b/src/java/org/apache/cassandra/db/Columns.java @@ -716,15 +716,7 @@ private static long encodeBitmap(Collection columns, Columns sup for (ColumnMetadata column : columns) { if (iter.next(column) == null) - { - // We can't know for sure whether to add the synthetic score column because WildcardColumnFilter - // just says "yes" to everything; instead, we just skip it here. - // TODO remove this with SelectStatement.ANN_USE_SYNTHETIC_SCORE. - if (column.isSynthetic()) - continue; - throw new IllegalStateException(columns + " is not a subset of " + superset); - } int currentIndex = iter.indexOfCurrent(); int count = currentIndex - expectIndex; diff --git a/src/java/org/apache/cassandra/db/filter/ColumnFilter.java b/src/java/org/apache/cassandra/db/filter/ColumnFilter.java index 2d4b240270a8..644e6d661a61 100644 --- a/src/java/org/apache/cassandra/db/filter/ColumnFilter.java +++ b/src/java/org/apache/cassandra/db/filter/ColumnFilter.java @@ -75,6 +75,9 @@ public abstract class ColumnFilter public static final Serializer serializer = new Serializer(); + // TODO remove this with ANN_USE_SYNTHETIC_SCORE + public abstract boolean fetchesExplicitly(ColumnMetadata column); + /** * The fetching strategy for the different queries. */ @@ -669,6 +672,12 @@ public boolean fetches(ColumnMetadata column) return true; } + @Override + public boolean fetchesExplicitly(ColumnMetadata column) + { + return false; + } + @Override public boolean fetchedColumnIsQueried(ColumnMetadata column) { @@ -817,6 +826,12 @@ public boolean fetches(ColumnMetadata column) return fetchingStrategy.fetchesAllColumns(column.isStatic()) || fetched.contains(column); } + @Override + public boolean fetchesExplicitly(ColumnMetadata column) + { + return fetched.contains(column); + } + /** * Whether the provided complex cell (identified by its column and path), which is assumed to be _fetched_ by * this filter, is also _queried_ by the user. diff --git a/src/java/org/apache/cassandra/db/rows/UnfilteredSerializer.java b/src/java/org/apache/cassandra/db/rows/UnfilteredSerializer.java index c94e3470bb33..a38305e1dd71 100644 --- a/src/java/org/apache/cassandra/db/rows/UnfilteredSerializer.java +++ b/src/java/org/apache/cassandra/db/rows/UnfilteredSerializer.java @@ -242,11 +242,6 @@ private void serializeRowBody(Row row, int flags, SerializationHelper helper, Da // with. So we use the ColumnMetadata from the "header" which is "current". Also see #11810 for what // happens if we don't do that. ColumnMetadata column = si.next(cd.column()); - // We can't know for sure whether to add the synthetic score column because WildcardColumnFilter - // just says "yes" to everything; instead, we just skip it here. - // TODO remove this with SelectStatement.ANN_USE_SYNTHETIC_SCORE. - if (column == null) - return; assert column != null : cd.column.toString(); try diff --git a/src/java/org/apache/cassandra/index/sai/plan/QueryController.java b/src/java/org/apache/cassandra/index/sai/plan/QueryController.java index fbe4430efc6c..7cd69a2e036b 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/QueryController.java +++ b/src/java/org/apache/cassandra/index/sai/plan/QueryController.java @@ -195,6 +195,11 @@ public TableMetadata metadata() return command.metadata(); } + public ReadCommand command() + { + return command; + } + RowFilter.FilterElement filterOperation() { // NOTE: we cannot remove the order by filter expression here yet because it is used in the FilterTree class diff --git a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java index d9f0dfc013b1..c1d971db5d43 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java @@ -42,6 +42,7 @@ import org.apache.cassandra.db.PartitionPosition; import org.apache.cassandra.db.ReadCommand; import org.apache.cassandra.db.ReadExecutionController; +import org.apache.cassandra.db.filter.ColumnFilter; import org.apache.cassandra.db.marshal.FloatType; import org.apache.cassandra.db.partitions.UnfilteredPartitionIterator; import org.apache.cassandra.db.rows.AbstractUnfilteredRowIterator; @@ -636,7 +637,7 @@ public UnfilteredRowIterator readAndValidatePartition(PrimaryKey pk, List primaryKeysWithScore) + public PrimaryKeyIterator(UnfilteredRowIterator partition, Row staticRow, Unfiltered content, List primaryKeysWithScore, ReadCommand command) { super(partition.metadata(), partition.partitionKey(), @@ -668,7 +669,24 @@ public PrimaryKeyIterator(UnfilteredRowIterator partition, Row staticRow, Unfilt partition.stats()); assert !primaryKeysWithScore.isEmpty(); - if (!content.isRow() || !(primaryKeysWithScore.get(0) instanceof PrimaryKeyWithScore)) + var isScoredRow = primaryKeysWithScore.get(0) instanceof PrimaryKeyWithScore; + if (!content.isRow() || !isScoredRow) + { + this.row = content; + return; + } + + // When +score is added on the coordinator side, it's represented as a PrecomputedColumnFilter + // even in a 'SELECT *' because WCF is not capable of representing synthetic columns. + // This can be simplified when we remove ANN_USE_SYNTHETIC_SCORE + var tm = metadata(); + var scoreColumn = ColumnMetadata.syntheticColumn(tm.keyspace, + tm.name, + ColumnMetadata.SYNTHETIC_SCORE_ID, + FloatType.instance); + var isScoreFetched = !(command.columnFilter() instanceof ColumnFilter.WildCardColumnFilter) + && command.columnFilter().fetchesExplicitly(scoreColumn); + if (!isScoreFetched) { this.row = content; return; @@ -680,11 +698,6 @@ public PrimaryKeyIterator(UnfilteredRowIterator partition, Row staticRow, Unfilt columnData.addAll(originalRow.columnData()); // inject +score as a new column - var tm = metadata(); - var scoreColumn = ColumnMetadata.syntheticColumn(tm.keyspace, - tm.name, - ColumnMetadata.SYNTHETIC_SCORE_ID, - FloatType.instance); var pkWithScore = (PrimaryKeyWithScore) primaryKeysWithScore.get(0); columnData.add(BufferCell.live(scoreColumn, FBUtilities.nowInSeconds(), diff --git a/src/java/org/apache/cassandra/index/sai/plan/TopKProcessor.java b/src/java/org/apache/cassandra/index/sai/plan/TopKProcessor.java index c52fe19ae3d8..775a92aedf9f 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/TopKProcessor.java +++ b/src/java/org/apache/cassandra/index/sai/plan/TopKProcessor.java @@ -364,7 +364,7 @@ private float getScoreForRow(DecoratedKey key, Row row) return FloatType.instance.compose(cell.buffer()); } - // TODO remove this once we enable the scored path for vector queries + // TODO remove this once we enable ANN_USE_SYNTHETIC_SCORE ByteBuffer value = indexContext.getValueOf(key, row, FBUtilities.nowInSeconds()); if (value != null) { From 893b87b3a22d4e2fe5d32167d8655f4116e876be Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Fri, 13 Dec 2024 14:35:52 -0600 Subject: [PATCH 17/29] fix getEqBahavior, this is most of the test failures --- .../apache/cassandra/index/IndexRegistry.java | 64 +++++++------------ 1 file changed, 22 insertions(+), 42 deletions(-) diff --git a/src/java/org/apache/cassandra/index/IndexRegistry.java b/src/java/org/apache/cassandra/index/IndexRegistry.java index 0ba923d9b3c3..2417e0c8e521 100644 --- a/src/java/org/apache/cassandra/index/IndexRegistry.java +++ b/src/java/org/apache/cassandra/index/IndexRegistry.java @@ -21,6 +21,7 @@ package org.apache.cassandra.index; import java.util.Collection; +import java.util.HashSet; import java.util.Collections; import java.util.Optional; import java.util.Set; @@ -380,57 +381,36 @@ public static EqBehaviorIndexes ambiguous(Index firstEqIndex, Index secondEqInde } /** - * @return - EQ if a single index supports EQ - * - MATCHES if a single index supports both - * - AMBIGUOUS if multiple indexes support EQ + * @return + * - AMBIGUOUS if an index supports EQ and a different one supports EQ and ANALYZER_MATCHES + * - MATCHES if an index supports both EQ and ANALYZER_MATCHES + * - otherwise EQ */ default EqBehaviorIndexes getEqBehavior(ColumnMetadata cm) { - // scan the indexes for MATCHES and EQ support - Index matchesIndex = null; - Index eqIndex = null; + Index eqOnlyIndex = null; + Index bothIndex = null; + for (Index index : listIndexes()) { - if (index.supportsExpression(cm, Operator.EQ)) - { - if (eqIndex == null) - { - eqIndex = index; - continue; - } - - // If we find a second EQ index, return AMBIGUOUS, taking care to assign the eqIndex and matchIndex correctly - if (index.supportsExpression(cm, Operator.ANALYZER_MATCHES)) - { - matchesIndex = index; - } - else - { - assert eqIndex.supportsExpression(cm, Operator.ANALYZER_MATCHES); - matchesIndex = eqIndex; - eqIndex = index; - } - return EqBehaviorIndexes.ambiguous(eqIndex, matchesIndex); - } + boolean supportsEq = index.supportsExpression(cm, Operator.EQ); + boolean supportsMatches = index.supportsExpression(cm, Operator.ANALYZER_MATCHES); - if (index.supportsExpression(cm, Operator.ANALYZER_MATCHES)) - { - // should only ever have one - assert matchesIndex == null; - matchesIndex = index; - } + if (supportsEq && supportsMatches) + bothIndex = index; + else if (supportsEq) + eqOnlyIndex = index; } - // If we didn't find any indexes that support EQ or MATCHES, return EQ - if (eqIndex == null && matchesIndex == null) - return EqBehaviorIndexes.eq(null); + // If we have one index supporting only EQ and another supporting both, return AMBIGUOUS + if (eqOnlyIndex != null && bothIndex != null) + return EqBehaviorIndexes.ambiguous(eqOnlyIndex, bothIndex); - // If the same index supports both EQ and MATCHES, promote to MATCHES - if (eqIndex == matchesIndex) - return EqBehaviorIndexes.match(matchesIndex); + // If we have an index supporting both EQ and MATCHES, return MATCHES + if (bothIndex != null) + return EqBehaviorIndexes.match(bothIndex); - // Otherwise we either have no MATCHES index, or it's distinct from the EQ index. - // In both cases we want to return EQ - return EqBehaviorIndexes.eq(eqIndex); + // Otherwise return EQ + return EqBehaviorIndexes.eq(eqOnlyIndex == null ? bothIndex : eqOnlyIndex); } } From c1eaa63abda00f0662febcb042b3ae5dfec805a4 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Fri, 13 Dec 2024 09:27:48 -0600 Subject: [PATCH 18/29] LongBM25Test --- .../cassandra/index/sai/LongBM25Test.java | 248 ++++++++++++++++++ 1 file changed, 248 insertions(+) create mode 100644 test/burn/org/apache/cassandra/index/sai/LongBM25Test.java diff --git a/test/burn/org/apache/cassandra/index/sai/LongBM25Test.java b/test/burn/org/apache/cassandra/index/sai/LongBM25Test.java new file mode 100644 index 000000000000..e1ff956ed7fd --- /dev/null +++ b/test/burn/org/apache/cassandra/index/sai/LongBM25Test.java @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import org.junit.Before; +import org.junit.Test; + +import org.slf4j.Logger; + +import org.apache.cassandra.db.memtable.TrieMemtable; + +import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; + +public class LongBM25Test extends SAITester +{ + private static final Logger logger = org.slf4j.LoggerFactory.getLogger(LongBM25Test.class); + + private static final List documentLines = new ArrayList<>(); + + static + { + try + { + var cl = LongBM25Test.class.getClassLoader(); + var resourceDir = cl.getResource("bm25"); + if (resourceDir == null) + throw new RuntimeException("Could not find resource directory test/resources/bm25/"); + + var dirPath = java.nio.file.Paths.get(resourceDir.toURI()); + try (var files = java.nio.file.Files.list(dirPath)) + { + files.forEach(file -> { + try (var lines = java.nio.file.Files.lines(file)) + { + lines.map(String::trim) + .filter(line -> !line.isEmpty()) + .forEach(documentLines::add); + } + catch (IOException e) + { + throw new RuntimeException("Failed to read file: " + file, e); + } + }); + } + if (documentLines.isEmpty()) + { + throw new RuntimeException("No document lines loaded from test/resources/bm25/"); + } + } + catch (IOException | URISyntaxException e) + { + throw new RuntimeException("Failed to load test documents", e); + } + } + + KeySet keysInserted = new KeySet(); + private final int threadCount = 12; + + @Before + public void setup() throws Throwable + { + // we don't get loaded until after TM, so we can't affect the very first memtable, + // but this will affect all subsequent ones + TrieMemtable.SHARD_COUNT = 4 * threadCount; + } + + @FunctionalInterface + private interface Op + { + void run(int i) throws Throwable; + } + + public void testConcurrentOps(Op op) throws ExecutionException, InterruptedException + { + createTable("CREATE TABLE %s (key int primary key, value text)"); + // Create analyzed index following BM25Test pattern + createIndex("CREATE CUSTOM INDEX ON %s(value) " + + "USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' " + + "WITH OPTIONS = {" + + "'index_analyzer': '{" + + "\"tokenizer\" : {\"name\" : \"standard\"}, " + + "\"filters\" : [{\"name\" : \"porterstem\"}]" + + "}'}" + ); + + AtomicInteger counter = new AtomicInteger(); + long start = System.currentTimeMillis(); + var fjp = new ForkJoinPool(threadCount); + var keys = IntStream.range(0, 10_000_000).boxed().collect(Collectors.toList()); + Collections.shuffle(keys); + var task = fjp.submit(() -> keys.stream().parallel().forEach(i -> + { + wrappedOp(op, i); + if (counter.incrementAndGet() % 10_000 == 0) + { + var elapsed = System.currentTimeMillis() - start; + logger.info("{} ops in {}ms = {} ops/s", counter.get(), elapsed, counter.get() * 1000.0 / elapsed); + } + if (ThreadLocalRandom.current().nextDouble() < 0.001) + flush(); + })); + fjp.shutdown(); + task.get(); // re-throw + } + + private static void wrappedOp(Op op, Integer i) + { + try + { + op.run(i); + } + catch (Throwable e) + { + throw new RuntimeException(e); + } + } + + private static String randomDocument() + { + var R = ThreadLocalRandom.current(); + int numLines = R.nextInt(5, 51); // 5 to 50 lines inclusive + var selectedLines = new ArrayList(); + + for (int i = 0; i < numLines; i++) + { + selectedLines.add(randomLine(R)); + } + + return String.join("\n", selectedLines); + } + + private static String randomLine(ThreadLocalRandom R) + { + return documentLines.get(R.nextInt(documentLines.size())); + } + + @Test + public void testConcurrentReadsWritesDeletes() throws ExecutionException, InterruptedException + { + testConcurrentOps(i -> { + var R = ThreadLocalRandom.current(); + if (R.nextDouble() < 0.2 || keysInserted.isEmpty()) + { + var doc = randomDocument(); + execute("INSERT INTO %s (key, value) VALUES (?, ?)", i, doc); + keysInserted.add(i); + } + else if (R.nextDouble() < 0.1) + { + var key = keysInserted.getRandom(); + execute("DELETE FROM %s WHERE key = ?", key); + } + else + { + var line = randomLine(R); + execute("SELECT * FROM %s ORDER BY value BM25 OF ? LIMIT ?", line, R.nextInt(1, 100)); + } + }); + } + + @Test + public void testConcurrentReadsWrites() throws ExecutionException, InterruptedException + { + testConcurrentOps(i -> { + var R = ThreadLocalRandom.current(); + if (R.nextDouble() < 0.1 || keysInserted.isEmpty()) + { + var doc = randomDocument(); + execute("INSERT INTO %s (key, value) VALUES (?, ?)", i, doc); + keysInserted.add(i); + } + else + { + var line = randomLine(R); + execute("SELECT * FROM %s ORDER BY value BM25 OF ? LIMIT ?", line, R.nextInt(1, 100)); + } + }); + } + + @Test + public void testConcurrentWrites() throws ExecutionException, InterruptedException + { + testConcurrentOps(i -> { + var doc = randomDocument(); + execute("INSERT INTO %s (key, value) VALUES (?, ?)", i, doc); + }); + } + + private static class KeySet + { + private final Map keys = new ConcurrentHashMap<>(); + private final AtomicInteger ordinal = new AtomicInteger(); + + public void add(int key) + { + var i = ordinal.getAndIncrement(); + keys.put(i, key); + } + + public int getRandom() + { + if (isEmpty()) + throw new IllegalStateException(); + var i = ThreadLocalRandom.current().nextInt(ordinal.get()); + // in case there is race with add(key), retry another random + return keys.containsKey(i) ? keys.get(i) : getRandom(); + } + + public boolean isEmpty() + { + return keys.isEmpty(); + } + } +} From ddbfd16fedb66aff885f19fc38905a558b340ed4 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Fri, 13 Dec 2024 10:50:15 -0600 Subject: [PATCH 19/29] misc bugfixes related to zero matches for a term --- .../index/sai/disk/v1/InvertedIndexSearcher.java | 13 +++++++------ .../v1/postings/IntersectingPostingList.java | 3 +++ .../apache/cassandra/index/sai/plan/Orderer.java | 11 +++++++---- .../cassandra/index/sai/utils/BM25Utils.java | 16 ++++++++++------ 4 files changed, 27 insertions(+), 16 deletions(-) diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java index 774ffd3f2dbf..9a56dd3bdd63 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java @@ -174,12 +174,13 @@ public CloseableIterator orderBy(Orderer orderer, Express // find documents that match each term var queryTerms = orderer.getQueryTerms(); var postingLists = queryTerms.stream() - .collect(Collectors.toMap(Function.identity(), term -> - { - var encodedTerm = version.onDiskFormat().encodeForTrie(term, indexContext.getValidator()); - var listener = MulticastQueryEventListeners.of(queryContext, perColumnEventListener); - return reader.exactMatch(encodedTerm, listener, queryContext); - })); + .collect(Collectors.toMap(Function.identity(), term -> + { + var encodedTerm = version.onDiskFormat().encodeForTrie(term, indexContext.getValidator()); + var listener = MulticastQueryEventListeners.of(queryContext, perColumnEventListener); + var postings = reader.exactMatch(encodedTerm, listener, queryContext); + return postings == null ? PostingList.EMPTY : postings; + })); // extract the match count for each var documentFrequencies = postingLists.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> (long) e.getValue().size())); diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java index 7d1924c301ae..5cefcba1e59d 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java @@ -50,6 +50,9 @@ private IntersectingPostingList(List postingLists) */ public static PostingList intersect(List postingLists) { + if (postingLists.isEmpty()) + return PostingList.EMPTY; + if (postingLists.size() == 1) return postingLists.get(0); diff --git a/src/java/org/apache/cassandra/index/sai/plan/Orderer.java b/src/java/org/apache/cassandra/index/sai/plan/Orderer.java index 6cb62ca06607..af202bae5f5b 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Orderer.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Orderer.java @@ -23,6 +23,8 @@ import java.util.Arrays; import java.util.Comparator; import java.util.EnumSet; +import java.util.HashSet; +import java.util.List; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -50,7 +52,7 @@ public class Orderer public final Operator operator; public final ByteBuffer term; private float[] vector; - private ArrayList queryTerms; + private List queryTerms; /** * Create an orderer for the given index context, operator, and term. @@ -133,23 +135,24 @@ public float[] getVectorTerm() return vector; } - public ArrayList getQueryTerms() + public List getQueryTerms() { if (queryTerms != null) return queryTerms; var queryAnalyzer = context.getQueryAnalyzerFactory().create(); // Split query into terms - queryTerms = new ArrayList(); + var uniqueTerms = new HashSet(); queryAnalyzer.reset(term); try { - queryAnalyzer.forEachRemaining(queryTerms::add); + queryAnalyzer.forEachRemaining(uniqueTerms::add); } finally { queryAnalyzer.end(); } + queryTerms = new ArrayList<>(uniqueTerms); return queryTerms; } } diff --git a/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java b/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java index f1c51e52923b..b02b575a477d 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java +++ b/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java @@ -22,7 +22,6 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.Comparator; import java.util.HashMap; import java.util.Iterator; import java.util.List; @@ -64,9 +63,9 @@ public static class DocTF { private final PrimaryKey pk; private final Map frequencies; - private final float termCount; + private final int termCount; - private DocTF(PrimaryKey pk, float termCount, Map frequencies) + private DocTF(PrimaryKey pk, int termCount, Map frequencies) { this.pk = pk; this.frequencies = frequencies; @@ -83,7 +82,7 @@ public static DocTF createFromDocument(PrimaryKey pk, AbstractAnalyzer docAnalyzer, Collection queryTerms) { - float count = 0; + int count = 0; Map frequencies = new HashMap<>(); docAnalyzer.reset(cell.buffer()); @@ -146,7 +145,7 @@ public static CloseableIterator computeScores(Iterator 0 ? totalTermCount / documents.size() : 0.0; + double avgDocLength = !documents.isEmpty() ? totalTermCount / documents.size() : 0.0; // Calculate BM25 scores var scoredDocs = new ArrayList(documents.size()); @@ -157,10 +156,15 @@ public static CloseableIterator computeScores(Iterator docStats.docCount) + throw new AssertionError(String.format("df=%d, totalDocs=%d", df, docStats.docCount)); + double normalizedTf = tf / (tf + K1 * (1 - B + B * doc.termCount / avgDocLength)); double idf = Math.log(1 + (docStats.docCount - df + 0.5) / (df + 0.5)); double deltaScore = normalizedTf * idf; - assert deltaScore >= 0 : String.format("BM25 score for tf=%d, df=%d, totalDocs=%d is %f", tf, df, docStats.docCount, deltaScore); + assert deltaScore >= 0 : String.format("BM25 score for tf=%d, df=%d, tc=%d, totalDocs=%d is %f", + tf, df, doc.termCount, docStats.docCount, deltaScore); score += deltaScore; } if (source instanceof Memtable) From 7ff2374f3fc7d870856c3fc2a29f8930668c2d91 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Fri, 13 Dec 2024 11:07:32 -0600 Subject: [PATCH 20/29] ramIndexer deduplicates (term, row) pairs --- .../index/sai/disk/RAMStringIndexer.java | 17 +++++++++++++---- .../index/sai/disk/v1/SegmentBuilder.java | 1 + 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/java/org/apache/cassandra/index/sai/disk/RAMStringIndexer.java b/src/java/org/apache/cassandra/index/sai/disk/RAMStringIndexer.java index 8b6f07d43597..f2035bd9631a 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/RAMStringIndexer.java +++ b/src/java/org/apache/cassandra/index/sai/disk/RAMStringIndexer.java @@ -140,27 +140,36 @@ private ByteComparable asByteComparable(byte[] bytes, int offset, int length) }; } + /** + * @return bytes allocated. may be zero if the (term, row) pair is a duplicate + */ public long add(BytesRef term, int segmentRowId) { long startBytes = estimatedBytesUsed(); int termID = termsHash.add(term); + boolean firstOccurrence = termID >= 0; - if (termID >= 0) + if (firstOccurrence) { - // firs time seeing this term, create the term's first slice ! + // first time seeing this term, create the term's first slice ! slices.createNewSlice(termID); } else { termID = (-termID) - 1; + // compaction should call this method only with increasing segmentRowIds + assert segmentRowId >= lastSegmentRowID[termID]; + // Skip if we've already recorded seen this segmentRowId for this term + if (segmentRowId == lastSegmentRowID[termID]) + return 0; } if (termID >= lastSegmentRowID.length - 1) - { lastSegmentRowID = ArrayUtil.grow(lastSegmentRowID, termID + 1); - } int delta = segmentRowId - lastSegmentRowID[termID]; + // sanity check that we're advancing the row id, i.e. no duplicate entries. + assert firstOccurrence || delta > 0; lastSegmentRowID[termID] = segmentRowId; diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentBuilder.java b/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentBuilder.java index 504df79c924a..77c89c1d466a 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentBuilder.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentBuilder.java @@ -208,6 +208,7 @@ protected long addInternal(ByteBuffer term, int segmentRowId) var encodedTerm = components.onDiskFormat().encodeForTrie(term, termComparator); var bytes = ByteSourceInverse.readBytes(encodedTerm.asComparableBytes(byteComparableVersion)); var bytesRef = new BytesRef(bytes); + // ramIndexer is responsible for merging duplicate (term, row) pairs return ramIndexer.add(bytesRef, segmentRowId); } From cfa1157505716ec59c604e453259e08e94875f94 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Fri, 13 Dec 2024 11:59:11 -0600 Subject: [PATCH 21/29] need to use compareUnsigned once we have more than 4 KINDs --- src/java/org/apache/cassandra/schema/ColumnMetadata.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/java/org/apache/cassandra/schema/ColumnMetadata.java b/src/java/org/apache/cassandra/schema/ColumnMetadata.java index 2a59ee3ebf58..372deb354012 100644 --- a/src/java/org/apache/cassandra/schema/ColumnMetadata.java +++ b/src/java/org/apache/cassandra/schema/ColumnMetadata.java @@ -479,7 +479,7 @@ public int compareTo(ColumnMetadata other) return 0; if (comparisonOrder != other.comparisonOrder) - return Long.compare(comparisonOrder, other.comparisonOrder); + return Long.compareUnsigned(comparisonOrder, other.comparisonOrder); return this.name.compareTo(other.name); } From 1620d3e7acdf27564d43d940cd3e5b10aec9649c Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Fri, 13 Dec 2024 11:59:19 -0600 Subject: [PATCH 22/29] simplify --- .../cassandra/index/sai/plan/StorageAttachedIndexSearcher.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java index c1d971db5d43..cb54764ec7a5 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java @@ -684,8 +684,7 @@ public PrimaryKeyIterator(UnfilteredRowIterator partition, Row staticRow, Unfilt tm.name, ColumnMetadata.SYNTHETIC_SCORE_ID, FloatType.instance); - var isScoreFetched = !(command.columnFilter() instanceof ColumnFilter.WildCardColumnFilter) - && command.columnFilter().fetchesExplicitly(scoreColumn); + var isScoreFetched = command.columnFilter().fetchesExplicitly(scoreColumn); if (!isScoreFetched) { this.row = content; From 73f35dff0ded6a1f2dee933ca43f739cbd111df9 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Fri, 13 Dec 2024 14:15:09 -0600 Subject: [PATCH 23/29] make SYNTHETIC the first Column Kind instead of the last. This avoids breaking the assumption in BTreeRow that complex regular/static columns sort last --- src/java/org/apache/cassandra/db/Columns.java | 47 +++++++++---------- .../org/apache/cassandra/db/ReadCommand.java | 4 +- .../cassandra/schema/ColumnMetadata.java | 6 ++- 3 files changed, 27 insertions(+), 30 deletions(-) diff --git a/src/java/org/apache/cassandra/db/Columns.java b/src/java/org/apache/cassandra/db/Columns.java index cef03393a1db..f85a9b803141 100644 --- a/src/java/org/apache/cassandra/db/Columns.java +++ b/src/java/org/apache/cassandra/db/Columns.java @@ -336,11 +336,6 @@ public Iterator simpleColumns() return BTree.iterator(columns, 0, complexIdx - 1, BTree.Dir.ASC); } - public Iterator simpleColumnsDesc() - { - return BTree.iterator(columns, 0, complexIdx - 1, BTree.Dir.DESC); - } - /** * Iterator over the complex columns of this object. * @@ -486,21 +481,21 @@ public void serialize(Columns columns, DataOutputPlus out) throws IOException long packedCount = getPackedCount(syntheticCount, regularCount); out.writeUnsignedVInt(packedCount); - // First pass - write regular columns + // First pass - write synthetic columns with their full metadata for (ColumnMetadata column : columns) { - if (!column.isSynthetic()) + if (column.isSynthetic()) + { ByteBufferUtil.writeWithVIntLength(column.name.bytes, out); + typeSerializer.serialize(column.type, out); + } } - // Second pass - write synthetic columns with their full metadata + // Second pass - write regular columns for (ColumnMetadata column : columns) { - if (column.isSynthetic()) - { + if (!column.isSynthetic()) ByteBufferUtil.writeWithVIntLength(column.name.bytes, out); - typeSerializer.serialize(column.type, out); - } } } @@ -545,7 +540,20 @@ public Columns deserialize(DataInputPlus in, TableMetadata metadata) throws IOEx int regularCount = (int) (packedCount & 0xFFFFF); int syntheticCount = (int) (packedCount >> 20); - // First pass - regular columns + // First pass - synthetic columns + for (int i = 0; i < syntheticCount; i++) + { + ByteBuffer name = ByteBufferUtil.readWithVIntLength(in); + AbstractType type = typeSerializer.deserialize(in); + + if (!name.equals(ColumnMetadata.SYNTHETIC_SCORE_ID.bytes)) + throw new IllegalStateException("Unknown synthetic column " + UTF8Type.instance.getString(name)); + + ColumnMetadata column = ColumnMetadata.syntheticColumn(metadata.keyspace, metadata.name, ColumnMetadata.SYNTHETIC_SCORE_ID, type); + builder.add(column); + } + + // Second pass - regular columns for (int i = 0; i < regularCount; i++) { ByteBuffer name = ByteBufferUtil.readWithVIntLength(in); @@ -559,19 +567,6 @@ public Columns deserialize(DataInputPlus in, TableMetadata metadata) throws IOEx } builder.add(column); } - - // Second pass - synthetic columns - for (int i = 0; i < syntheticCount; i++) - { - ByteBuffer name = ByteBufferUtil.readWithVIntLength(in); - AbstractType type = typeSerializer.deserialize(in); - - if (!name.equals(ColumnMetadata.SYNTHETIC_SCORE_ID.bytes)) - throw new IllegalStateException("Unknown synthetic column " + UTF8Type.instance.getString(name)); - - ColumnMetadata column = ColumnMetadata.syntheticColumn(metadata.keyspace, metadata.name, ColumnMetadata.SYNTHETIC_SCORE_ID, type); - builder.add(column); - } return new Columns(builder.build()); } } diff --git a/src/java/org/apache/cassandra/db/ReadCommand.java b/src/java/org/apache/cassandra/db/ReadCommand.java index c8b38695e004..a8e88b14094d 100644 --- a/src/java/org/apache/cassandra/db/ReadCommand.java +++ b/src/java/org/apache/cassandra/db/ReadCommand.java @@ -1049,10 +1049,10 @@ public ReadCommand deserialize(DataInputPlus in, int version) throws IOException // add synthetic columns to the tablemetadata so we can serialize them in our response var tmb = metadata.unbuild(); - for (var it = columnFilter.fetchedColumns().regulars.simpleColumnsDesc(); it.hasNext(); ) + for (var it = columnFilter.fetchedColumns().regulars.simpleColumns(); it.hasNext(); ) { var c = it.next(); - // synthetic columns sort last, so when we hit the first non-synthetic, we're done + // synthetic columns sort first, so when we hit the first non-synthetic, we're done if (!c.isSynthetic()) break; tmb.addColumn(ColumnMetadata.syntheticColumn(c.ksName, c.cfName, c.name, c.type)); diff --git a/src/java/org/apache/cassandra/schema/ColumnMetadata.java b/src/java/org/apache/cassandra/schema/ColumnMetadata.java index 372deb354012..008813564e57 100644 --- a/src/java/org/apache/cassandra/schema/ColumnMetadata.java +++ b/src/java/org/apache/cassandra/schema/ColumnMetadata.java @@ -81,11 +81,13 @@ public enum ClusteringOrder public enum Kind { // NOTE: if adding a new type, must modify comparisonOrder + SYNTHETIC, PARTITION_KEY, CLUSTERING, REGULAR, - STATIC, - SYNTHETIC; + STATIC; + // it is not possible to add new Kinds after Synthetic without invasive changes to BTreeRow, which + // assumes that complex regulr/static columns are the last ones public boolean isPrimaryKeyKind() { From 3ad8ae2f966017c21a9afd9bd7ffef8b51ab3c68 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Fri, 13 Dec 2024 17:43:17 -0600 Subject: [PATCH 24/29] fix tests --- .../apache/cassandra/index/IndexRegistry.java | 2 +- .../sai/cql/MultipleColumnIndexTest.java | 26 ++++++++++++++++--- .../index/sai/cql/NativeIndexDDLTest.java | 2 +- 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/java/org/apache/cassandra/index/IndexRegistry.java b/src/java/org/apache/cassandra/index/IndexRegistry.java index 2417e0c8e521..2d7ac81a0e8c 100644 --- a/src/java/org/apache/cassandra/index/IndexRegistry.java +++ b/src/java/org/apache/cassandra/index/IndexRegistry.java @@ -382,7 +382,7 @@ public static EqBehaviorIndexes ambiguous(Index firstEqIndex, Index secondEqInde /** * @return - * - AMBIGUOUS if an index supports EQ and a different one supports EQ and ANALYZER_MATCHES + * - AMBIGUOUS if an index supports EQ and a different one supports both EQ and ANALYZER_MATCHES * - MATCHES if an index supports both EQ and ANALYZER_MATCHES * - otherwise EQ */ diff --git a/test/unit/org/apache/cassandra/index/sai/cql/MultipleColumnIndexTest.java b/test/unit/org/apache/cassandra/index/sai/cql/MultipleColumnIndexTest.java index 5364717a8906..8d952dc1f4f0 100644 --- a/test/unit/org/apache/cassandra/index/sai/cql/MultipleColumnIndexTest.java +++ b/test/unit/org/apache/cassandra/index/sai/cql/MultipleColumnIndexTest.java @@ -40,12 +40,30 @@ public void canCreateMultipleMapIndexesOnSameColumn() throws Throwable } @Test - public void cannotHaveMultipleLiteralIndexesWithDifferentOptions() throws Throwable + public void canHaveAnalyzedAndUnanalyzedIndexesOnSameColumn() throws Throwable { - createTable("CREATE TABLE %s (pk int, ck int, value text, PRIMARY KEY(pk, ck))"); + createTable("CREATE TABLE %s (pk int, value text, PRIMARY KEY(pk))"); createIndex("CREATE CUSTOM INDEX ON %s(value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'case_sensitive' : true }"); - assertThatThrownBy(() -> createIndex("CREATE CUSTOM INDEX ON %s(value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'case_sensitive' : false }")) - .isInstanceOf(InvalidRequestException.class); + createIndex("CREATE CUSTOM INDEX ON %s(value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'case_sensitive' : false, 'equals_behaviour_when_analyzed': 'unsupported' }"); + + execute("INSERT INTO %s (pk, value) VALUES (?, ?)", 1, "a"); + execute("INSERT INTO %s (pk, value) VALUES (?, ?)", 2, "A"); + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT pk FROM %s WHERE value = 'a'"), + row(1)); + assertRows(execute("SELECT pk FROM %s WHERE value : 'a'"), + row(1), + row(2)); + }); + } + + @Test + public void cannotHaveMultipleAnalyzingIndexesOnSameColumn() throws Throwable + { + createTable("CREATE TABLE %s (pk int, ck int, value text, PRIMARY KEY(pk, ck))"); + createIndex("CREATE CUSTOM INDEX ON %s(value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'case_sensitive' : false }"); + assertThatThrownBy(() -> createIndex("CREATE CUSTOM INDEX ON %s(value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'normalize' : true }")) + .isInstanceOf(InvalidRequestException.class); } @Test diff --git a/test/unit/org/apache/cassandra/index/sai/cql/NativeIndexDDLTest.java b/test/unit/org/apache/cassandra/index/sai/cql/NativeIndexDDLTest.java index 12b597e85ca8..8d4ac97b2a2a 100644 --- a/test/unit/org/apache/cassandra/index/sai/cql/NativeIndexDDLTest.java +++ b/test/unit/org/apache/cassandra/index/sai/cql/NativeIndexDDLTest.java @@ -581,7 +581,7 @@ public void shouldFailCreationMultipleIndexesOnSimpleColumn() // different name, different option, same target. assertThatThrownBy(() -> executeNet("CREATE CUSTOM INDEX ON %s(v1) USING 'StorageAttachedIndex' WITH OPTIONS = { 'case_sensitive' : true }")) .isInstanceOf(InvalidQueryException.class) - .hasMessageContaining("Cannot create more than one storage-attached index on the same column: v1" ); + .hasMessageContaining("Cannot create duplicate storage-attached index on column: v1" ); ResultSet rows = executeNet("SELECT id FROM %s WHERE v1 = '1'"); assertEquals(1, rows.all().size()); From dbbc67824733e3c36a96e0f6c3b8b34f19331c0d Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Mon, 16 Dec 2024 12:31:27 -0600 Subject: [PATCH 25/29] DRY refactor --- .../cassandra/cql3/restrictions/MultiColumnRestriction.java | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/java/org/apache/cassandra/cql3/restrictions/MultiColumnRestriction.java b/src/java/org/apache/cassandra/cql3/restrictions/MultiColumnRestriction.java index b08925fafe6f..30df843651db 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/MultiColumnRestriction.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/MultiColumnRestriction.java @@ -127,10 +127,7 @@ protected final String getColumnsInCommons(Restriction otherRestriction) @Override public final boolean hasSupportingIndex(IndexRegistry indexRegistry) { - for (Index index : indexRegistry.listIndexes()) - if (isSupportingIndex(index)) - return true; - return false; + return findSupportingIndex(indexRegistry) != null; } @Override From 1b57d55618c3e84de84f3cdf40b4f8b3716d36fe Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Mon, 16 Dec 2024 11:42:06 -0600 Subject: [PATCH 26/29] add tests for unknown query terms, duplicate query terms, no query terms add validation and reject queries with no analyzed terms --- .../restrictions/SingleColumnRestriction.java | 17 ++++++-- .../v1/postings/IntersectingPostingList.java | 6 +-- .../cassandra/index/sai/LongBM25Test.java | 23 +++++----- .../cassandra/index/sai/cql/BM25Test.java | 42 +++++++++++++++++++ 4 files changed, 70 insertions(+), 18 deletions(-) diff --git a/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java b/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java index 5d055fd7fb89..137ee3ab1a54 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java @@ -24,14 +24,18 @@ import java.util.List; import java.util.Map; -import org.apache.cassandra.db.filter.RowFilter; -import org.apache.cassandra.schema.ColumnMetadata; -import org.apache.cassandra.cql3.*; +import org.apache.cassandra.cql3.MarkerOrTerms; +import org.apache.cassandra.cql3.Operator; +import org.apache.cassandra.cql3.QueryOptions; +import org.apache.cassandra.cql3.Term; +import org.apache.cassandra.cql3.Terms; import org.apache.cassandra.cql3.functions.Function; import org.apache.cassandra.cql3.statements.Bound; import org.apache.cassandra.db.MultiClusteringBuilder; +import org.apache.cassandra.db.filter.RowFilter; import org.apache.cassandra.index.Index; import org.apache.cassandra.index.IndexRegistry; +import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.serializers.ListSerializer; import org.apache.cassandra.transport.ProtocolVersion; import org.apache.cassandra.utils.ByteBufferUtil; @@ -1205,7 +1209,12 @@ public void addToRowFilter(RowFilter.Builder filter, IndexRegistry indexRegistry, QueryOptions options) { - filter.add(columnDef, Operator.BM25, value.bindAndGet(options)); + var index = findSupportingIndex(indexRegistry); + var valueBytes = value.bindAndGet(options); + var terms = index.getAnalyzer().get().analyze(valueBytes); + if (terms.isEmpty()) + throw invalidRequest("BM25 query must contain at least one term (perhaps your analyzer is discarding tokens you didn't expect)"); + filter.add(columnDef, Operator.BM25, valueBytes); } @Override diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java index 5cefcba1e59d..8c3bfe3ee6df 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java @@ -37,7 +37,8 @@ public class IntersectingPostingList implements PostingList private IntersectingPostingList(List postingLists) { - assert !postingLists.isEmpty(); + if (postingLists.isEmpty()) + throw new AssertionError(); this.postingLists = postingLists; this.size = postingLists.stream() .mapToInt(PostingList::size) @@ -50,9 +51,6 @@ private IntersectingPostingList(List postingLists) */ public static PostingList intersect(List postingLists) { - if (postingLists.isEmpty()) - return PostingList.EMPTY; - if (postingLists.size() == 1) return postingLists.get(0); diff --git a/test/burn/org/apache/cassandra/index/sai/LongBM25Test.java b/test/burn/org/apache/cassandra/index/sai/LongBM25Test.java index e1ff956ed7fd..49b5b5118540 100644 --- a/test/burn/org/apache/cassandra/index/sai/LongBM25Test.java +++ b/test/burn/org/apache/cassandra/index/sai/LongBM25Test.java @@ -20,10 +20,6 @@ import java.io.IOException; import java.net.URISyntaxException; -import java.nio.charset.StandardCharsets; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -38,13 +34,10 @@ import org.junit.Before; import org.junit.Test; - import org.slf4j.Logger; import org.apache.cassandra.db.memtable.TrieMemtable; -import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; - public class LongBM25Test extends SAITester { private static final Logger logger = org.slf4j.LoggerFactory.getLogger(LongBM25Test.class); @@ -157,7 +150,7 @@ private static String randomDocument() for (int i = 0; i < numLines; i++) { - selectedLines.add(randomLine(R)); + selectedLines.add(randomQuery(R)); } return String.join("\n", selectedLines); @@ -186,12 +179,22 @@ else if (R.nextDouble() < 0.1) } else { - var line = randomLine(R); + var line = randomQuery(R); execute("SELECT * FROM %s ORDER BY value BM25 OF ? LIMIT ?", line, R.nextInt(1, 100)); } }); } + private static String randomQuery(ThreadLocalRandom R) + { + while (true) + { + var line = randomLine(R); + if (line.chars().anyMatch(Character::isAlphabetic)) + return line; + } + } + @Test public void testConcurrentReadsWrites() throws ExecutionException, InterruptedException { @@ -205,7 +208,7 @@ public void testConcurrentReadsWrites() throws ExecutionException, InterruptedEx } else { - var line = randomLine(R); + var line = randomQuery(R); execute("SELECT * FROM %s ORDER BY value BM25 OF ? LIMIT ?", line, R.nextInt(1, 100)); } }); diff --git a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java index 69d51a19be44..4723a507f4b3 100644 --- a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java +++ b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java @@ -167,6 +167,48 @@ public void testMatchingAllowed() throws Throwable }); } + @Test + public void testUnknownQueryTerm() throws Throwable + { + createSimpleTable(); + + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + + beforeAndAfterFlush(() -> + { + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'orange' LIMIT 1"); + assertEmpty(result); + }); + } + + @Test + public void testDuplicateQueryTerm() throws Throwable + { + createSimpleTable(); + + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + + beforeAndAfterFlush(() -> + { + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple apple' LIMIT 1"); + assertRows(result, row(1)); + }); + } + + @Test + public void testEmptyQuery() throws Throwable + { + createSimpleTable(); + + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + + beforeAndAfterFlush(() -> + { + assertInvalidMessage("BM25 query must contain at least one term (perhaps your analyzer is discarding tokens you didn't expect)", + "SELECT k FROM %s ORDER BY v BM25 OF '+' LIMIT 1"); + }); + } + @Test public void testTermFrequencyOrdering() throws Throwable { From 0b5ce5c99d28e23a325a1a2b5eb5910ce23835df Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Tue, 17 Dec 2024 07:54:17 -0600 Subject: [PATCH 27/29] // since doc frequencies can be an estimate from the index histogram, which does not have bounded error, // cap frequencies to total rows so that the IDF term doesn't turn negative --- .../cassandra/index/sai/disk/v1/InvertedIndexSearcher.java | 7 ++++++- .../org/apache/cassandra/index/sai/utils/BM25Utils.java | 3 +-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java index 9a56dd3bdd63..fb1bb3ac455f 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java @@ -213,7 +213,12 @@ private CloseableIterator bm25Internal(Iterator queryTerms, Map documentFrequencies) { - var docStats = new BM25Utils.DocStats(documentFrequencies, sstable.getTotalRows()); + var totalRows = sstable.getTotalRows(); + // since doc frequencies can be an estimate from the index histogram, which does not have bounded error, + // cap frequencies to total rows so that the IDF term doesn't turn negative + var cappedFrequencies = documentFrequencies.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> Math.min(e.getValue(), totalRows))); + var docStats = new BM25Utils.DocStats(cappedFrequencies, totalRows); return BM25Utils.computeScores(keyIterator, queryTerms, docStats, diff --git a/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java b/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java index b02b575a477d..cc4fa1f66b04 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java +++ b/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java @@ -157,8 +157,7 @@ public static CloseableIterator computeScores(Iterator docStats.docCount) - throw new AssertionError(String.format("df=%d, totalDocs=%d", df, docStats.docCount)); + assert df <= docStats.docCount : String.format("df=%d, totalDocs=%d", df, docStats.docCount); double normalizedTf = tf / (tf + K1 * (1 - B + B * doc.termCount / avgDocLength)); double idf = Math.log(1 + (docStats.docCount - df + 0.5) / (df + 0.5)); From f3f7a15b36912fd8ba5318b6ac002f40562d81cc Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Tue, 17 Dec 2024 07:54:44 -0600 Subject: [PATCH 28/29] parameterize version to test with/without histograms --- .../cassandra/index/sai/cql/BM25Test.java | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java index 4723a507f4b3..dfd7b58d4125 100644 --- a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java +++ b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java @@ -18,16 +18,41 @@ package org.apache.cassandra.index.sai.cql; +import java.util.Collection; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import org.apache.cassandra.index.sai.SAITester; +import org.apache.cassandra.index.sai.SAIUtil; +import org.apache.cassandra.index.sai.disk.format.Version; import org.apache.cassandra.index.sai.plan.QueryController; import static org.apache.cassandra.index.sai.analyzer.AnalyzerEqOperatorSupport.EQ_AMBIGUOUS_ERROR; import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; +@RunWith(Parameterized.class) public class BM25Test extends SAITester { + @Parameterized.Parameter + public Version version; + + @Parameterized.Parameters(name = "{0}") + public static Collection data() + { + return Stream.of(Version.EB).map(v -> new Object[]{ v}).collect(Collectors.toList()); + } + + @Before + public void setup() throws Throwable + { + SAIUtil.setLatestVersion(version); + } + @Test public void testTwoIndexes() { From d83e18d9eeb8e9f544760daa357054a740d337d3 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Tue, 17 Dec 2024 13:57:47 -0600 Subject: [PATCH 29/29] actually parameterize both versions --- test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java index dfd7b58d4125..ef9d09424ef5 100644 --- a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java +++ b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java @@ -44,7 +44,7 @@ public class BM25Test extends SAITester @Parameterized.Parameters(name = "{0}") public static Collection data() { - return Stream.of(Version.EB).map(v -> new Object[]{ v}).collect(Collectors.toList()); + return Stream.of(Version.DC, Version.EB).map(v -> new Object[]{ v}).collect(Collectors.toList()); } @Before