Skip to content

Commit

Permalink
feat: Using FT.AGGREGATE for all queries. Resolves #17
Browse files Browse the repository at this point in the history
  • Loading branch information
Julien Ruaux committed Mar 28, 2023
1 parent ae873bf commit 498184e
Show file tree
Hide file tree
Showing 11 changed files with 147 additions and 290 deletions.
7 changes: 5 additions & 2 deletions src/main/java/com/redis/trino/RediSearchBuiltinField.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package com.redis.trino;

import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static io.trino.spi.type.RealType.REAL;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static java.util.Arrays.stream;
import static java.util.function.Function.identity;
Expand All @@ -16,7 +15,7 @@

enum RediSearchBuiltinField {

ID("_id", VARCHAR, Field.Type.TAG), SCORE("_score", REAL, Field.Type.NUMERIC);
KEY("__key", VARCHAR, Field.Type.TAG);

private static final Map<String, RediSearchBuiltinField> COLUMNS_BY_NAME = stream(values())
.collect(toImmutableMap(RediSearchBuiltinField::getName, identity()));
Expand Down Expand Up @@ -58,4 +57,8 @@ public ColumnMetadata getMetadata() {
public RediSearchColumnHandle getColumnHandle() {
return new RediSearchColumnHandle(name, type, fieldType, true, false);
}

public static boolean isKeyColumn(String columnName) {
return KEY.name.equals(columnName);
}
}
24 changes: 10 additions & 14 deletions src/main/java/com/redis/trino/RediSearchMetadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
import com.google.common.collect.ImmutableMap;
import com.redis.lettucemod.search.Field;
import com.redis.lettucemod.search.querybuilder.Values;
import com.redis.trino.RediSearchTableHandle.Type;

import io.airlift.log.Logger;
import io.airlift.slice.Slice;
Expand Down Expand Up @@ -250,7 +249,7 @@ public Optional<ConnectorOutputMetadata> finishInsert(ConnectorSession session,
@Override
public RediSearchColumnHandle getDeleteRowIdColumnHandle(ConnectorSession session,
ConnectorTableHandle tableHandle) {
return RediSearchBuiltinField.ID.getColumnHandle();
return RediSearchBuiltinField.KEY.getColumnHandle();
}

@Override
Expand All @@ -268,16 +267,16 @@ public void finishDelete(ConnectorSession session, ConnectorTableHandle tableHan
@Override
public RediSearchColumnHandle getUpdateRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle,
List<ColumnHandle> updatedColumns) {
return RediSearchBuiltinField.ID.getColumnHandle();
return RediSearchBuiltinField.KEY.getColumnHandle();
}

@Override
public RediSearchTableHandle beginUpdate(ConnectorSession session, ConnectorTableHandle tableHandle,
List<ColumnHandle> updatedColumns, RetryMode retryMode) {
checkRetry(retryMode);
RediSearchTableHandle table = (RediSearchTableHandle) tableHandle;
return new RediSearchTableHandle(table.getType(), table.getSchemaTableName(), table.getConstraint(),
table.getLimit(), table.getTermAggregations(), table.getMetricAggregations(), table.getWildcards(),
return new RediSearchTableHandle(table.getSchemaTableName(), table.getConstraint(), table.getLimit(),
table.getTermAggregations(), table.getMetricAggregations(), table.getWildcards(),
updatedColumns.stream().map(RediSearchColumnHandle.class::cast).collect(toImmutableList()));
}

Expand Down Expand Up @@ -306,11 +305,9 @@ public Optional<LimitApplicationResult<ConnectorTableHandle>> applyLimit(Connect
return Optional.empty();
}

return Optional.of(new LimitApplicationResult<>(
new RediSearchTableHandle(handle.getType(), handle.getSchemaTableName(), handle.getConstraint(),
OptionalLong.of(limit), handle.getTermAggregations(), handle.getMetricAggregations(),
handle.getWildcards(), handle.getUpdatedColumns()),
true, false));
return Optional.of(new LimitApplicationResult<>(new RediSearchTableHandle(handle.getSchemaTableName(),
handle.getConstraint(), OptionalLong.of(limit), handle.getTermAggregations(),
handle.getMetricAggregations(), handle.getWildcards(), handle.getUpdatedColumns()), true, false));
}

@Override
Expand Down Expand Up @@ -372,7 +369,7 @@ public Optional<ConstraintApplicationResult<ConnectorTableHandle>> applyFilter(C
return Optional.empty();
}

handle = new RediSearchTableHandle(handle.getType(), handle.getSchemaTableName(), newDomain, handle.getLimit(),
handle = new RediSearchTableHandle(handle.getSchemaTableName(), newDomain, handle.getLimit(),
handle.getTermAggregations(), handle.getMetricAggregations(), newWildcards, handle.getUpdatedColumns());

return Optional.of(new ConstraintApplicationResult<>(handle, TupleDomain.withColumnDomains(unsupported),
Expand Down Expand Up @@ -498,9 +495,8 @@ public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggrega
if (aggregationList.isEmpty()) {
return Optional.empty();
}
RediSearchTableHandle tableHandle = new RediSearchTableHandle(Type.AGGREGATE, table.getSchemaTableName(),
table.getConstraint(), table.getLimit(), terms.build(), aggregationList, table.getWildcards(),
table.getUpdatedColumns());
RediSearchTableHandle tableHandle = new RediSearchTableHandle(table.getSchemaTableName(), table.getConstraint(),
table.getLimit(), terms.build(), aggregationList, table.getWildcards(), table.getUpdatedColumns());
return Optional.of(new AggregationApplicationResult<>(tableHandle, projections.build(),
resultAssignments.build(), Map.of(), false));
}
Expand Down
37 changes: 20 additions & 17 deletions src/main/java/com/redis/trino/RediSearchPageSink.java
Original file line number Diff line number Diff line change
Expand Up @@ -100,26 +100,29 @@ public CompletableFuture<?> appendPage(Page page) {
String prefix = prefix().orElse(schemaTableName.getTableName() + KEY_SEPARATOR);
StatefulRedisModulesConnection<String, String> connection = session.getConnection();
connection.setAutoFlushCommands(false);
RedisModulesAsyncCommands<String, String> commands = connection.async();
List<RedisFuture<?>> futures = new ArrayList<>();
for (int position = 0; position < page.getPositionCount(); position++) {
String key = prefix + factory.create().toString();
Map<String, String> map = new HashMap<>();
for (int channel = 0; channel < page.getChannelCount(); channel++) {
RediSearchColumnHandle column = columns.get(channel);
Block block = page.getBlock(channel);
if (block.isNull(position)) {
continue;
try {
RedisModulesAsyncCommands<String, String> commands = connection.async();
List<RedisFuture<?>> futures = new ArrayList<>();
for (int position = 0; position < page.getPositionCount(); position++) {
String key = prefix + factory.create().toString();
Map<String, String> map = new HashMap<>();
for (int channel = 0; channel < page.getChannelCount(); channel++) {
RediSearchColumnHandle column = columns.get(channel);
Block block = page.getBlock(channel);
if (block.isNull(position)) {
continue;
}
String value = value(column.getType(), block, position);
map.put(column.getName(), value);
}
String value = value(column.getType(), block, position);
map.put(column.getName(), value);
RedisFuture<Long> future = commands.hset(key, map);
futures.add(future);
}
RedisFuture<Long> future = commands.hset(key, map);
futures.add(future);
connection.flushCommands();
LettuceFutures.awaitAll(connection.getTimeout(), futures.toArray(new RedisFuture[0]));
} finally {
connection.setAutoFlushCommands(true);
}
connection.flushCommands();
LettuceFutures.awaitAll(connection.getTimeout(), futures.toArray(new RedisFuture[0]));
connection.setAutoFlushCommands(true);
return NOT_BLOCKED;
}

Expand Down
140 changes: 93 additions & 47 deletions src/main/java/com/redis/trino/RediSearchPageSource.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@
import com.fasterxml.jackson.core.JsonGenerator;
import com.redis.lettucemod.api.StatefulRedisModulesConnection;
import com.redis.lettucemod.api.async.RedisModulesAsyncCommands;
import com.redis.lettucemod.search.Document;
import com.redis.lettucemod.search.AggregateWithCursorResults;

import io.airlift.log.Logger;
import io.airlift.slice.Slice;
import io.airlift.slice.SliceOutput;
import io.lettuce.core.LettuceFutures;
Expand All @@ -57,28 +58,29 @@

public class RediSearchPageSource implements UpdatablePageSource {

private static final Logger log = Logger.get(RediSearchPageSource.class);

private static final int ROWS_PER_REQUEST = 1024;

private final RediSearchPageSourceResultWriter writer = new RediSearchPageSourceResultWriter();
private final RediSearchSession session;
private final RediSearchTableHandle table;
private final Iterator<Document<String, String>> cursor;
private final String[] columnNames;
private final List<Type> columnTypes;
private final PageBuilder pageBuilder;

private Document<String, String> currentDoc;
private final CursorIterator iterator;
private Map<String, Object> currentDoc;
private long count;
private boolean finished;

private final PageBuilder pageBuilder;

public RediSearchPageSource(RediSearchSession session, RediSearchTableHandle table,
List<RediSearchColumnHandle> columns) {
this.session = session;
this.table = table;
this.columnNames = columns.stream().map(RediSearchColumnHandle::getName).toArray(String[]::new);
this.columnTypes = columns.stream().map(RediSearchColumnHandle::getType)
.collect(Collectors.toList());
this.cursor = session.search(table, columnNames).iterator();
this.iterator = new CursorIterator(session, table, columnNames);
this.columnTypes = columns.stream().map(RediSearchColumnHandle::getType).collect(Collectors.toList());
this.currentDoc = null;
this.pageBuilder = new PageBuilder(columnTypes);
}
Expand Down Expand Up @@ -108,26 +110,24 @@ public Page getNextPage() {
verify(pageBuilder.isEmpty());
count = 0;
for (int i = 0; i < ROWS_PER_REQUEST; i++) {
if (!cursor.hasNext()) {
if (!iterator.hasNext()) {
finished = true;
break;
}
currentDoc = cursor.next();
currentDoc = iterator.next();
count++;

pageBuilder.declarePosition();
for (int column = 0; column < columnTypes.size(); column++) {
BlockBuilder output = pageBuilder.getBlockBuilder(column);
String columnName = columnNames[column];
String value = currentValue(columnName);
Object value = currentValue(columnNames[column]);
if (value == null) {
output.appendNull();
} else {
writer.appendTo(columnTypes.get(column), value, output);
writer.appendTo(columnTypes.get(column), value.toString(), output);
}
}
}

Page page = pageBuilder.build();
pageBuilder.reset();
return page;
Expand All @@ -149,42 +149,33 @@ public void updateRows(Page page, List<Integer> columnValueAndRowIdChannels) {
columnValueAndRowIdChannels.size() - 1);
StatefulRedisModulesConnection<String, String> connection = session.getConnection();
connection.setAutoFlushCommands(false);
RedisModulesAsyncCommands<String, String> commands = connection.async();
List<RedisFuture<?>> futures = new ArrayList<>();
for (int position = 0; position < page.getPositionCount(); position++) {
Block rowIdBlock = page.getBlock(rowIdChannel);
if (rowIdBlock.isNull(position)) {
continue;
}
String key = VarcharType.VARCHAR.getSlice(rowIdBlock, position).toStringUtf8();
Map<String, String> map = new HashMap<>();
for (int channel = 0; channel < columnChannelMapping.size(); channel++) {
RediSearchColumnHandle column = table.getUpdatedColumns().get(columnChannelMapping.get(channel));
Block block = page.getBlock(channel);
if (block.isNull(position)) {
try {
RedisModulesAsyncCommands<String, String> commands = connection.async();
List<RedisFuture<?>> futures = new ArrayList<>();
for (int position = 0; position < page.getPositionCount(); position++) {
Block rowIdBlock = page.getBlock(rowIdChannel);
if (rowIdBlock.isNull(position)) {
continue;
}
String value = RediSearchPageSink.value(column.getType(), block, position);
map.put(column.getName(), value);
}
RedisFuture<Long> future = commands.hset(key, map);
futures.add(future);
}
connection.flushCommands();
LettuceFutures.awaitAll(connection.getTimeout(), futures.toArray(new RedisFuture[0]));
connection.setAutoFlushCommands(true);
}

private String currentValue(String columnName) {
if (RediSearchBuiltinField.isBuiltinColumn(columnName)) {
if (RediSearchBuiltinField.ID.getName().equals(columnName)) {
return currentDoc.getId();
}
if (RediSearchBuiltinField.SCORE.getName().equals(columnName)) {
return String.valueOf(currentDoc.getScore());
String key = VarcharType.VARCHAR.getSlice(rowIdBlock, position).toStringUtf8();
Map<String, String> map = new HashMap<>();
for (int channel = 0; channel < columnChannelMapping.size(); channel++) {
RediSearchColumnHandle column = table.getUpdatedColumns().get(columnChannelMapping.get(channel));
Block block = page.getBlock(channel);
if (block.isNull(position)) {
continue;
}
String value = RediSearchPageSink.value(column.getType(), block, position);
map.put(column.getName(), value);
}
RedisFuture<Long> future = commands.hset(key, map);
futures.add(future);
}
connection.flushCommands();
LettuceFutures.awaitAll(connection.getTimeout(), futures.toArray(new RedisFuture[0]));
} finally {
connection.setAutoFlushCommands(true);
}
return currentDoc.get(columnName);
}

@Override
Expand All @@ -194,12 +185,67 @@ public CompletableFuture<Collection<Slice>> finish() {
return future;
}

private Object currentValue(String columnName) {
if (RediSearchBuiltinField.isKeyColumn(columnName)) {
return currentDoc.get(RediSearchBuiltinField.KEY.getName());
}
return currentDoc.get(columnName);
}

public static JsonGenerator createJsonGenerator(JsonFactory factory, SliceOutput output) throws IOException {
return factory.createGenerator((OutputStream) output);
}

@Override
public void close() {
// nothing to do
try {
iterator.close();
} catch (Exception e) {
log.error(e, "Could not close cursor iterator");
}
}

private static class CursorIterator implements Iterator<Map<String, Object>>, AutoCloseable {

private final RediSearchSession session;
private final RediSearchTableHandle table;
private Iterator<Map<String, Object>> iterator;
private long cursor;

public CursorIterator(RediSearchSession session, RediSearchTableHandle table, String[] columnNames) {
this.session = session;
this.table = table;
read(session.aggregate(table, columnNames));
}

private void read(AggregateWithCursorResults<String> results) {
this.iterator = results.iterator();
this.cursor = results.getCursor();
}

@Override
public boolean hasNext() {
while (!iterator.hasNext()) {
if (cursor == 0) {
return false;
}
read(session.cursorRead(table, cursor));
}
return true;
}

@Override
public Map<String, Object> next() {
return iterator.next();
}

@Override
public void close() throws Exception {
if (cursor == 0) {
return;
}
session.cursorDelete(table, cursor);
}

}
}
Loading

0 comments on commit 498184e

Please sign in to comment.