feat: Added support for UPDATE statements
Julien Ruaux committed Jan 2, 2023
1 parent 4c1336b commit d849199
Showing 9 changed files with 162 additions and 49 deletions.
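At a glance, the commit wires SQL UPDATE support through the connector: the metadata layer advertises an update row-id column and carries the updated columns on the table handle, and the page source rewrites the matching Redis hashes. As a rough usage sketch (the `redisearch` catalog, `default` schema, `beers` table, and credentials are assumptions, not taken from this commit), an UPDATE can be issued through the standard Trino JDBC driver:

```java
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.Statement;

public class UpdateExample {
    public static void main(String[] args) throws Exception {
        // Assumed coordinator URL, catalog and schema; requires the trino-jdbc driver on the classpath.
        String url = "jdbc:trino://localhost:8080/redisearch/default";
        try (Connection connection = DriverManager.getConnection(url, "trino", null);
                Statement statement = connection.createStatement()) {
            // Each matching row is rewritten as an HSET on its underlying Redis hash.
            int updated = statement.executeUpdate(
                    "UPDATE beers SET style = 'IPA' WHERE abv > 0.08");
            System.out.println("Updated rows: " + updated);
        }
    }
}
```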
1 change: 1 addition & 0 deletions .gitignore
@@ -6,6 +6,7 @@ build/
!**/src/test/**/build/
hs_err*.log
test-output/
*.dylib

### STS ###
.checkstyle
56 changes: 40 additions & 16 deletions src/main/java/com/redis/trino/RediSearchMetadata.java
@@ -26,11 +26,10 @@
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.base.Verify.verifyNotNull;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.airlift.slice.SliceUtf8.getCodePointAt;
import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.trino.spi.connector.RetryMode.NO_RETRIES;
import static io.trino.spi.expression.StandardFunctions.LIKE_FUNCTION_NAME;
import static java.util.Objects.requireNonNull;

@@ -205,9 +204,7 @@ public void dropColumn(ConnectorSession session, ConnectorTableHandle tableHandl
@Override
public ConnectorOutputTableHandle beginCreateTable(ConnectorSession session, ConnectorTableMetadata tableMetadata,
Optional<ConnectorTableLayout> layout, RetryMode retryMode) {
- if (retryMode != RetryMode.NO_RETRIES) {
- throw new TrinoException(StandardErrorCode.NOT_SUPPORTED, "This connector does not support query retries");
- }
+ checkRetry(retryMode);
List<RediSearchColumnHandle> columns = buildColumnHandles(tableMetadata);

rediSearchSession.createTable(tableMetadata.getTable(), columns);
@@ -218,6 +215,12 @@ public ConnectorOutputTableHandle beginCreateTable(ConnectorSession session, Con
columns.stream().filter(c -> !c.isHidden()).collect(Collectors.toList()));
}

private void checkRetry(RetryMode retryMode) {
if (retryMode != RetryMode.NO_RETRIES) {
throw new TrinoException(StandardErrorCode.NOT_SUPPORTED, "This connector does not support retries");
}
}

@Override
public Optional<ConnectorOutputMetadata> finishCreateTable(ConnectorSession session,
ConnectorOutputTableHandle tableHandle, Collection<Slice> fragments,
@@ -229,9 +232,7 @@ public Optional<ConnectorOutputMetadata> finishCreateTable(ConnectorSession sess
@Override
public ConnectorInsertTableHandle beginInsert(ConnectorSession session, ConnectorTableHandle tableHandle,
List<ColumnHandle> insertedColumns, RetryMode retryMode) {
- if (retryMode != RetryMode.NO_RETRIES) {
- throw new TrinoException(StandardErrorCode.NOT_SUPPORTED, "This connector does not support query retries");
- }
+ checkRetry(retryMode);
RediSearchTableHandle table = (RediSearchTableHandle) tableHandle;
List<RediSearchColumnHandle> columns = rediSearchSession.getTable(table.getSchemaTableName()).getColumns();

@@ -255,14 +256,34 @@ public RediSearchColumnHandle getDeleteRowIdColumnHandle(ConnectorSession sessio
@Override
public RediSearchTableHandle beginDelete(ConnectorSession session, ConnectorTableHandle tableHandle,
RetryMode retryMode) {
- if (retryMode != NO_RETRIES) {
- throw new TrinoException(NOT_SUPPORTED, "This connector does not support query retries");
- }
+ checkRetry(retryMode);
return (RediSearchTableHandle) tableHandle;
}

@Override
public void finishDelete(ConnectorSession session, ConnectorTableHandle tableHandle, Collection<Slice> fragments) {
// Do nothing
}

@Override
public RediSearchColumnHandle getUpdateRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle,
List<ColumnHandle> updatedColumns) {
return RediSearchBuiltinField.ID.getColumnHandle();
}

@Override
public RediSearchTableHandle beginUpdate(ConnectorSession session, ConnectorTableHandle tableHandle,
List<ColumnHandle> updatedColumns, RetryMode retryMode) {
checkRetry(retryMode);
RediSearchTableHandle table = (RediSearchTableHandle) tableHandle;
return new RediSearchTableHandle(table.getType(), table.getSchemaTableName(), table.getConstraint(),
table.getLimit(), table.getTermAggregations(), table.getMetricAggregations(), table.getWildcards(),
updatedColumns.stream().map(RediSearchColumnHandle.class::cast).collect(toImmutableList()));
}

@Override
public void finishUpdate(ConnectorSession session, ConnectorTableHandle tableHandle, Collection<Slice> fragments) {
// Do nothing
}

@Override
@@ -285,9 +306,11 @@ public Optional<LimitApplicationResult<ConnectorTableHandle>> applyLimit(Connect
return Optional.empty();
}

- return Optional.of(new LimitApplicationResult<>(new RediSearchTableHandle(handle.getType(),
- handle.getSchemaTableName(), handle.getConstraint(), OptionalLong.of(limit),
- handle.getTermAggregations(), handle.getMetricAggregations(), handle.getWildcards()), true, false));
+ return Optional.of(new LimitApplicationResult<>(
+ new RediSearchTableHandle(handle.getType(), handle.getSchemaTableName(), handle.getConstraint(),
+ OptionalLong.of(limit), handle.getTermAggregations(), handle.getMetricAggregations(),
+ handle.getWildcards(), handle.getUpdatedColumns()),
+ true, false));
}

@Override
@@ -350,7 +373,7 @@ public Optional<ConstraintApplicationResult<ConnectorTableHandle>> applyFilter(C
}

handle = new RediSearchTableHandle(handle.getType(), handle.getSchemaTableName(), newDomain, handle.getLimit(),
- handle.getTermAggregations(), handle.getMetricAggregations(), newWildcards);
+ handle.getTermAggregations(), handle.getMetricAggregations(), newWildcards, handle.getUpdatedColumns());

return Optional.of(new ConstraintApplicationResult<>(handle, TupleDomain.withColumnDomains(unsupported),
newExpression, false));
@@ -476,7 +499,8 @@ public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggrega
return Optional.empty();
}
RediSearchTableHandle tableHandle = new RediSearchTableHandle(Type.AGGREGATE, table.getSchemaTableName(),
- table.getConstraint(), table.getLimit(), terms.build(), aggregationList, table.getWildcards());
+ table.getConstraint(), table.getLimit(), terms.build(), aggregationList, table.getWildcards(),
+ table.getUpdatedColumns());
return Optional.of(new AggregationApplicationResult<>(tableHandle, projections.build(),
resultAssignments.build(), Map.of(), false));
}
21 changes: 13 additions & 8 deletions src/main/java/com/redis/trino/RediSearchPageSink.java
@@ -52,6 +52,7 @@
import com.google.common.primitives.Shorts;
import com.google.common.primitives.SignedBytes;
import com.redis.lettucemod.api.StatefulRedisModulesConnection;
import com.redis.lettucemod.api.async.RedisModulesAsyncCommands;
import com.redis.lettucemod.search.CreateOptions;
import com.redis.lettucemod.search.CreateOptions.DataType;
import com.redis.lettucemod.search.IndexInfo;
@@ -82,6 +83,7 @@

public class RediSearchPageSink implements ConnectorPageSink {

private static final String KEY_SEPARATOR = ":";
private final RediSearchSession session;
private final SchemaTableName schemaTableName;
private final List<RediSearchColumnHandle> columns;
@@ -96,22 +98,24 @@ public RediSearchPageSink(RediSearchSession rediSearchSession, SchemaTableName s

@Override
public CompletableFuture<?> appendPage(Page page) {
- String prefix = prefix().orElse(schemaTableName.getTableName());
+ String prefix = prefix().orElse(schemaTableName.getTableName() + KEY_SEPARATOR);
StatefulRedisModulesConnection<String, String> connection = session.getConnection();
connection.setAutoFlushCommands(false);
RedisModulesAsyncCommands<String, String> commands = connection.async();
List<RedisFuture<?>> futures = new ArrayList<>();
for (int position = 0; position < page.getPositionCount(); position++) {
+ String key = prefix + factory.create().toString();
Map<String, String> map = new HashMap<>();
- String key = prefix + ":" + factory.create().toString();
for (int channel = 0; channel < page.getChannelCount(); channel++) {
RediSearchColumnHandle column = columns.get(channel);
Block block = page.getBlock(channel);
if (block.isNull(position)) {
continue;
}
- map.put(column.getName(), getObjectValue(columns.get(channel).getType(), block, position));
+ String value = value(column.getType(), block, position);
+ map.put(column.getName(), value);
}
- RedisFuture<Long> future = connection.async().hset(key, map);
+ RedisFuture<Long> future = commands.hset(key, map);
futures.add(future);
}
connection.flushCommands();
@@ -136,16 +140,16 @@ private Optional<String> prefix() {
if (prefix.equals("*")) {
return Optional.empty();
}
- if (prefix.endsWith(":")) {
- return Optional.of(prefix.substring(0, prefix.length() - 1));
+ if (prefix.endsWith(KEY_SEPARATOR)) {
+ return Optional.of(prefix);
}
- return Optional.of(prefix);
+ return Optional.of(prefix + KEY_SEPARATOR);
} catch (Exception e) {
return Optional.empty();
}
}

- private String getObjectValue(Type type, Block block, int position) {
+ public static String value(Type type, Block block, int position) {
if (type.equals(BooleanType.BOOLEAN)) {
return String.valueOf(type.getBoolean(block, position));
}
@@ -205,5 +209,6 @@ public CompletableFuture<Collection<Slice>> finish() {

@Override
public void abort() {
// Do nothing
}
}
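The insert path above and the new update path in RediSearchPageSource share the same Lettuce batching idiom: disable auto-flush, queue HSET commands asynchronously, flush once, then await all futures. A minimal standalone sketch of that idiom with plain Lettuce (the connection URI, keys, and field values are illustrative only):

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import io.lettuce.core.LettuceFutures;
import io.lettuce.core.RedisClient;
import io.lettuce.core.RedisFuture;
import io.lettuce.core.api.StatefulRedisConnection;
import io.lettuce.core.api.async.RedisAsyncCommands;

public class PipelinedHsetExample {
    public static void main(String[] args) {
        RedisClient client = RedisClient.create("redis://localhost:6379"); // assumed local Redis
        try (StatefulRedisConnection<String, String> connection = client.connect()) {
            connection.setAutoFlushCommands(false); // queue commands instead of sending them one by one
            RedisAsyncCommands<String, String> commands = connection.async();
            List<RedisFuture<?>> futures = new ArrayList<>();
            for (int i = 0; i < 3; i++) {
                futures.add(commands.hset("beers:" + i, Map.of("name", "beer-" + i, "abv", "0.05")));
            }
            connection.flushCommands(); // send the whole batch in one network round
            LettuceFutures.awaitAll(connection.getTimeout(), futures.toArray(new RedisFuture[0]));
            connection.setAutoFlushCommands(true);
        } finally {
            client.shutdown();
        }
    }
}
```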
15 changes: 11 additions & 4 deletions src/main/java/com/redis/trino/RediSearchPageSinkProvider.java
@@ -23,6 +23,8 @@
*/
package com.redis.trino;

import java.util.List;

import javax.inject.Inject;

import io.trino.spi.connector.ConnectorInsertTableHandle;
@@ -32,27 +34,32 @@
import io.trino.spi.connector.ConnectorPageSinkProvider;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.ConnectorTransactionHandle;
import io.trino.spi.connector.SchemaTableName;

public class RediSearchPageSinkProvider implements ConnectorPageSinkProvider {

- private final RediSearchSession rediSearchSession;
+ private final RediSearchSession session;

@Inject
public RediSearchPageSinkProvider(RediSearchSession rediSearchSession) {
- this.rediSearchSession = rediSearchSession;
+ this.session = rediSearchSession;
}

@Override
public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session,
ConnectorOutputTableHandle outputTableHandle, ConnectorPageSinkId pageSinkId) {
RediSearchOutputTableHandle handle = (RediSearchOutputTableHandle) outputTableHandle;
- return new RediSearchPageSink(rediSearchSession, handle.getSchemaTableName(), handle.getColumns());
+ return pageSink(handle.getSchemaTableName(), handle.getColumns());
}

@Override
public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session,
ConnectorInsertTableHandle insertTableHandle, ConnectorPageSinkId pageSinkId) {
RediSearchInsertTableHandle handle = (RediSearchInsertTableHandle) insertTableHandle;
- return new RediSearchPageSink(rediSearchSession, handle.getSchemaTableName(), handle.getColumns());
+ return pageSink(handle.getSchemaTableName(), handle.getColumns());
}

private RediSearchPageSink pageSink(SchemaTableName schemaTableName, List<RediSearchColumnHandle> columns) {
return new RediSearchPageSink(session, schemaTableName, columns);
}
}
50 changes: 45 additions & 5 deletions src/main/java/com/redis/trino/RediSearchPageSource.java
@@ -30,30 +30,38 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonGenerator;
import com.redis.lettucemod.api.StatefulRedisModulesConnection;
import com.redis.lettucemod.api.async.RedisModulesAsyncCommands;
import com.redis.lettucemod.search.Document;

import io.airlift.slice.Slice;
import io.airlift.slice.SliceOutput;
import io.lettuce.core.LettuceFutures;
import io.lettuce.core.RedisFuture;
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.connector.UpdatablePageSource;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;

public class RediSearchPageSource implements UpdatablePageSource {

private static final int ROWS_PER_REQUEST = 1024;

private final RediSearchPageSourceResultWriter writer = new RediSearchPageSourceResultWriter();
private final RediSearchSession session;
private final RediSearchTableHandle table;
private final Iterator<Document<String, String>> cursor;
private final String[] columnNames;
private final List<Type> columnTypes;
@@ -66,9 +74,10 @@ public class RediSearchPageSource implements UpdatablePageSource {
public RediSearchPageSource(RediSearchSession session, RediSearchTableHandle table,
List<RediSearchColumnHandle> columns) {
this.session = session;
this.table = table;
this.columnNames = columns.stream().map(RediSearchColumnHandle::getName).toArray(String[]::new);
this.columnTypes = columns.stream().map(RediSearchColumnHandle::getType)
- .collect(Collectors.toUnmodifiableList());
+ .collect(Collectors.toList());
this.cursor = session.search(table, columnNames).iterator();
this.currentDoc = null;
this.pageBuilder = new PageBuilder(columnTypes);
@@ -127,14 +136,45 @@ public Page getNextPage() {
@Override
public void deleteRows(Block rowIds) {
List<String> docIds = new ArrayList<>(rowIds.getPositionCount());
- for (int i = 0; i < rowIds.getPositionCount(); i++) {
- int len = rowIds.getSliceLength(i);
- Slice slice = rowIds.getSlice(i, 0, len);
- docIds.add(slice.toStringUtf8());
+ for (int position = 0; position < rowIds.getPositionCount(); position++) {
+ docIds.add(VarcharType.VARCHAR.getSlice(rowIds, position).toStringUtf8());
}
session.deleteDocs(docIds);
}

@Override
public void updateRows(Page page, List<Integer> columnValueAndRowIdChannels) {
int rowIdChannel = columnValueAndRowIdChannels.get(columnValueAndRowIdChannels.size() - 1);
List<Integer> columnChannelMapping = columnValueAndRowIdChannels.subList(0,
columnValueAndRowIdChannels.size() - 1);
StatefulRedisModulesConnection<String, String> connection = session.getConnection();
connection.setAutoFlushCommands(false);
RedisModulesAsyncCommands<String, String> commands = connection.async();
List<RedisFuture<?>> futures = new ArrayList<>();
for (int position = 0; position < page.getPositionCount(); position++) {
Block rowIdBlock = page.getBlock(rowIdChannel);
if (rowIdBlock.isNull(position)) {
continue;
}
String key = VarcharType.VARCHAR.getSlice(rowIdBlock, position).toStringUtf8();
Map<String, String> map = new HashMap<>();
for (int channel = 0; channel < columnChannelMapping.size(); channel++) {
RediSearchColumnHandle column = table.getUpdatedColumns().get(columnChannelMapping.get(channel));
Block block = page.getBlock(channel);
if (block.isNull(position)) {
continue;
}
String value = RediSearchPageSink.value(column.getType(), block, position);
map.put(column.getName(), value);
}
RedisFuture<Long> future = commands.hset(key, map);
futures.add(future);
}
connection.flushCommands();
LettuceFutures.awaitAll(connection.getTimeout(), futures.toArray(new RedisFuture[0]));
connection.setAutoFlushCommands(true);
}

private String currentValue(String columnName) {
if (RediSearchBuiltinField.isBuiltinColumn(columnName)) {
if (RediSearchBuiltinField.ID.getName().equals(columnName)) {
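In `updateRows`, the engine hands the connector a page whose last listed channel (per `columnValueAndRowIdChannels`) carries the row id, i.e. the Redis key, while the remaining entries map value channels onto the updated columns stored on the table handle. A small self-contained sketch of that channel bookkeeping, with made-up column names, values, and key:

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class UpdateChannelMappingExample {
    public static void main(String[] args) {
        // Pretend the UPDATE touches two columns; the engine appends the row-id channel last.
        List<String> updatedColumns = List.of("style", "abv");
        List<Integer> columnValueAndRowIdChannels = List.of(0, 1, 2);

        int rowIdChannel = columnValueAndRowIdChannels.get(columnValueAndRowIdChannels.size() - 1);
        List<Integer> columnChannelMapping = columnValueAndRowIdChannels.subList(0,
                columnValueAndRowIdChannels.size() - 1);

        // One fake row: channels 0 and 1 carry the new values, channel 2 carries the Redis key.
        String[] row = { "IPA", "0.09", "beers:01GNQ4ZXCV" };

        String key = row[rowIdChannel];
        Map<String, String> map = new LinkedHashMap<>();
        for (int channel = 0; channel < columnChannelMapping.size(); channel++) {
            map.put(updatedColumns.get(columnChannelMapping.get(channel)), row[channel]);
        }
        // The connector would now issue: HSET <key> <field value ...>
        System.out.println("HSET " + key + " " + map);
    }
}
```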
4 changes: 2 additions & 2 deletions src/main/java/com/redis/trino/RediSearchQueryBuilder.java
@@ -242,9 +242,9 @@ public Optional<Group> group(RediSearchTableHandle table) {
List<RediSearchAggregation> aggregates = table.getMetricAggregations();
List<String> groupFields = new ArrayList<>();
if (terms != null && !terms.isEmpty()) {
- groupFields = terms.stream().map(RediSearchAggregationTerm::getTerm).collect(Collectors.toUnmodifiableList());
+ groupFields = terms.stream().map(RediSearchAggregationTerm::getTerm).collect(Collectors.toList());
}
- List<Reducer> reducers = aggregates.stream().map(this::reducer).collect(Collectors.toUnmodifiableList());
+ List<Reducer> reducers = aggregates.stream().map(this::reducer).collect(Collectors.toList());
if (reducers.isEmpty()) {
return Optional.empty();
}
4 changes: 2 additions & 2 deletions src/main/java/com/redis/trino/RediSearchSession.java
@@ -185,7 +185,7 @@ public void createTable(SchemaTableName schemaTableName, List<RediSearchColumnHa
String index = index(schemaTableName);
if (!connection.sync().ftList().contains(index)) {
List<Field<String>> fields = columns.stream().filter(c -> !c.getName().equals("_id"))
- .map(c -> buildField(c.getName(), c.getType())).collect(Collectors.toUnmodifiableList());
+ .map(c -> buildField(c.getName(), c.getType())).collect(Collectors.toList());
CreateOptions.Builder<String, String> options = CreateOptions.<String, String>builder();
options.prefix(index + ":");
connection.sync().ftCreate(index, options.build(), fields.toArray(Field[]::new));
@@ -308,7 +308,7 @@ public AggregateWithCursorResults<String> aggregate(RediSearchTableHandle table)
AggregateWithCursorResults<String> results = connection.sync().ftAggregate(aggregation.getIndex(),
aggregation.getQuery(), aggregation.getCursorOptions(), aggregation.getOptions());
List<AggregateOperation<?, ?>> groupBys = aggregation.getOptions().getOperations().stream()
- .filter(o -> o.getType() == AggregateOperation.Type.GROUP).collect(Collectors.toUnmodifiableList());
+ .filter(o -> o.getType() == AggregateOperation.Type.GROUP).collect(Collectors.toList());
if (results.isEmpty() && !groupBys.isEmpty()) {
Group groupBy = (Group) groupBys.get(0);
Optional<String> as = groupBy.getReducers()[0].getAs();
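For context on the key-prefix handling: `createTable` registers the index with `options.prefix(index + ":")`, so only hashes whose keys start with that prefix are indexed, which is what the sink's `KEY_SEPARATOR` normalization guarantees. A rough standalone sketch of the same index creation with LettuceMod (the index name, field definitions, and the `Field.tag`/`Field.numeric` builders are assumptions, not taken from this diff):

```java
import com.redis.lettucemod.RedisModulesClient;
import com.redis.lettucemod.api.StatefulRedisModulesConnection;
import com.redis.lettucemod.search.CreateOptions;
import com.redis.lettucemod.search.Field;

public class CreateIndexExample {
    public static void main(String[] args) {
        RedisModulesClient client = RedisModulesClient.create("redis://localhost:6379"); // assumed local Redis Stack
        try (StatefulRedisModulesConnection<String, String> connection = client.connect()) {
            CreateOptions.Builder<String, String> options = CreateOptions.<String, String>builder();
            // Only hashes whose keys start with "beers:" end up in the index,
            // which is why the sink appends KEY_SEPARATOR to every key prefix.
            options.prefix("beers:");
            connection.sync().ftCreate("beers", options.build(),
                    Field.tag("style").build(), Field.numeric("abv").build());
        } finally {
            client.shutdown();
        }
    }
}
```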