From d84919900b04ca4615d1c92bf49fd4bbb73c476d Mon Sep 17 00:00:00 2001 From: Julien Ruaux Date: Mon, 2 Jan 2023 11:43:55 +0800 Subject: [PATCH] feat: Added support for UPDATE statements --- .gitignore | 1 + .../com/redis/trino/RediSearchMetadata.java | 56 +++++++++++++------ .../com/redis/trino/RediSearchPageSink.java | 21 ++++--- .../trino/RediSearchPageSinkProvider.java | 15 +++-- .../com/redis/trino/RediSearchPageSource.java | 50 +++++++++++++++-- .../redis/trino/RediSearchQueryBuilder.java | 4 +- .../com/redis/trino/RediSearchSession.java | 4 +- .../redis/trino/RediSearchTableHandle.java | 19 +++++-- .../TestRediSearchConnectorSmokeTest.java | 41 +++++++++++--- 9 files changed, 162 insertions(+), 49 deletions(-) diff --git a/.gitignore b/.gitignore index aaa3cc4..c911afe 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ build/ !**/src/test/**/build/ hs_err*.log test-output/ +*.dylib ### STS ### .checkstyle diff --git a/src/main/java/com/redis/trino/RediSearchMetadata.java b/src/main/java/com/redis/trino/RediSearchMetadata.java index 53d356e..92f6b92 100644 --- a/src/main/java/com/redis/trino/RediSearchMetadata.java +++ b/src/main/java/com/redis/trino/RediSearchMetadata.java @@ -26,11 +26,10 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.base.Verify.verifyNotNull; +import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.airlift.slice.SliceUtf8.getCodePointAt; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; -import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; -import static io.trino.spi.connector.RetryMode.NO_RETRIES; import static io.trino.spi.expression.StandardFunctions.LIKE_FUNCTION_NAME; import static java.util.Objects.requireNonNull; @@ -205,9 +204,7 @@ public void dropColumn(ConnectorSession session, ConnectorTableHandle tableHandl @Override public ConnectorOutputTableHandle beginCreateTable(ConnectorSession session, ConnectorTableMetadata tableMetadata, Optional layout, RetryMode retryMode) { - if (retryMode != RetryMode.NO_RETRIES) { - throw new TrinoException(StandardErrorCode.NOT_SUPPORTED, "This connector does not support query retries"); - } + checkRetry(retryMode); List columns = buildColumnHandles(tableMetadata); rediSearchSession.createTable(tableMetadata.getTable(), columns); @@ -218,6 +215,12 @@ public ConnectorOutputTableHandle beginCreateTable(ConnectorSession session, Con columns.stream().filter(c -> !c.isHidden()).collect(Collectors.toList())); } + private void checkRetry(RetryMode retryMode) { + if (retryMode != RetryMode.NO_RETRIES) { + throw new TrinoException(StandardErrorCode.NOT_SUPPORTED, "This connector does not support retries"); + } + } + @Override public Optional finishCreateTable(ConnectorSession session, ConnectorOutputTableHandle tableHandle, Collection fragments, @@ -229,9 +232,7 @@ public Optional finishCreateTable(ConnectorSession sess @Override public ConnectorInsertTableHandle beginInsert(ConnectorSession session, ConnectorTableHandle tableHandle, List insertedColumns, RetryMode retryMode) { - if (retryMode != RetryMode.NO_RETRIES) { - throw new TrinoException(StandardErrorCode.NOT_SUPPORTED, "This connector does not support query retries"); - } + checkRetry(retryMode); RediSearchTableHandle table = (RediSearchTableHandle) tableHandle; List columns = rediSearchSession.getTable(table.getSchemaTableName()).getColumns(); @@ -255,14 +256,34 @@ public RediSearchColumnHandle getDeleteRowIdColumnHandle(ConnectorSession sessio @Override public RediSearchTableHandle beginDelete(ConnectorSession session, ConnectorTableHandle tableHandle, RetryMode retryMode) { - if (retryMode != NO_RETRIES) { - throw new TrinoException(NOT_SUPPORTED, "This connector does not support query retries"); - } + checkRetry(retryMode); return (RediSearchTableHandle) tableHandle; } @Override public void finishDelete(ConnectorSession session, ConnectorTableHandle tableHandle, Collection fragments) { + // Do nothing + } + + @Override + public RediSearchColumnHandle getUpdateRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle, + List updatedColumns) { + return RediSearchBuiltinField.ID.getColumnHandle(); + } + + @Override + public RediSearchTableHandle beginUpdate(ConnectorSession session, ConnectorTableHandle tableHandle, + List updatedColumns, RetryMode retryMode) { + checkRetry(retryMode); + RediSearchTableHandle table = (RediSearchTableHandle) tableHandle; + return new RediSearchTableHandle(table.getType(), table.getSchemaTableName(), table.getConstraint(), + table.getLimit(), table.getTermAggregations(), table.getMetricAggregations(), table.getWildcards(), + updatedColumns.stream().map(RediSearchColumnHandle.class::cast).collect(toImmutableList())); + } + + @Override + public void finishUpdate(ConnectorSession session, ConnectorTableHandle tableHandle, Collection fragments) { + // Do nothing } @Override @@ -285,9 +306,11 @@ public Optional> applyLimit(Connect return Optional.empty(); } - return Optional.of(new LimitApplicationResult<>(new RediSearchTableHandle(handle.getType(), - handle.getSchemaTableName(), handle.getConstraint(), OptionalLong.of(limit), - handle.getTermAggregations(), handle.getMetricAggregations(), handle.getWildcards()), true, false)); + return Optional.of(new LimitApplicationResult<>( + new RediSearchTableHandle(handle.getType(), handle.getSchemaTableName(), handle.getConstraint(), + OptionalLong.of(limit), handle.getTermAggregations(), handle.getMetricAggregations(), + handle.getWildcards(), handle.getUpdatedColumns()), + true, false)); } @Override @@ -350,7 +373,7 @@ public Optional> applyFilter(C } handle = new RediSearchTableHandle(handle.getType(), handle.getSchemaTableName(), newDomain, handle.getLimit(), - handle.getTermAggregations(), handle.getMetricAggregations(), newWildcards); + handle.getTermAggregations(), handle.getMetricAggregations(), newWildcards, handle.getUpdatedColumns()); return Optional.of(new ConstraintApplicationResult<>(handle, TupleDomain.withColumnDomains(unsupported), newExpression, false)); @@ -476,7 +499,8 @@ public Optional> applyAggrega return Optional.empty(); } RediSearchTableHandle tableHandle = new RediSearchTableHandle(Type.AGGREGATE, table.getSchemaTableName(), - table.getConstraint(), table.getLimit(), terms.build(), aggregationList, table.getWildcards()); + table.getConstraint(), table.getLimit(), terms.build(), aggregationList, table.getWildcards(), + table.getUpdatedColumns()); return Optional.of(new AggregationApplicationResult<>(tableHandle, projections.build(), resultAssignments.build(), Map.of(), false)); } diff --git a/src/main/java/com/redis/trino/RediSearchPageSink.java b/src/main/java/com/redis/trino/RediSearchPageSink.java index dd6e069..e11d963 100644 --- a/src/main/java/com/redis/trino/RediSearchPageSink.java +++ b/src/main/java/com/redis/trino/RediSearchPageSink.java @@ -52,6 +52,7 @@ import com.google.common.primitives.Shorts; import com.google.common.primitives.SignedBytes; import com.redis.lettucemod.api.StatefulRedisModulesConnection; +import com.redis.lettucemod.api.async.RedisModulesAsyncCommands; import com.redis.lettucemod.search.CreateOptions; import com.redis.lettucemod.search.CreateOptions.DataType; import com.redis.lettucemod.search.IndexInfo; @@ -82,6 +83,7 @@ public class RediSearchPageSink implements ConnectorPageSink { + private static final String KEY_SEPARATOR = ":"; private final RediSearchSession session; private final SchemaTableName schemaTableName; private final List columns; @@ -96,22 +98,24 @@ public RediSearchPageSink(RediSearchSession rediSearchSession, SchemaTableName s @Override public CompletableFuture appendPage(Page page) { - String prefix = prefix().orElse(schemaTableName.getTableName()); + String prefix = prefix().orElse(schemaTableName.getTableName() + KEY_SEPARATOR); StatefulRedisModulesConnection connection = session.getConnection(); connection.setAutoFlushCommands(false); + RedisModulesAsyncCommands commands = connection.async(); List> futures = new ArrayList<>(); for (int position = 0; position < page.getPositionCount(); position++) { + String key = prefix + factory.create().toString(); Map map = new HashMap<>(); - String key = prefix + ":" + factory.create().toString(); for (int channel = 0; channel < page.getChannelCount(); channel++) { RediSearchColumnHandle column = columns.get(channel); Block block = page.getBlock(channel); if (block.isNull(position)) { continue; } - map.put(column.getName(), getObjectValue(columns.get(channel).getType(), block, position)); + String value = value(column.getType(), block, position); + map.put(column.getName(), value); } - RedisFuture future = connection.async().hset(key, map); + RedisFuture future = commands.hset(key, map); futures.add(future); } connection.flushCommands(); @@ -136,16 +140,16 @@ private Optional prefix() { if (prefix.equals("*")) { return Optional.empty(); } - if (prefix.endsWith(":")) { - return Optional.of(prefix.substring(0, prefix.length() - 1)); + if (prefix.endsWith(KEY_SEPARATOR)) { + return Optional.of(prefix); } - return Optional.of(prefix); + return Optional.of(prefix + KEY_SEPARATOR); } catch (Exception e) { return Optional.empty(); } } - private String getObjectValue(Type type, Block block, int position) { + public static String value(Type type, Block block, int position) { if (type.equals(BooleanType.BOOLEAN)) { return String.valueOf(type.getBoolean(block, position)); } @@ -205,5 +209,6 @@ public CompletableFuture> finish() { @Override public void abort() { + // Do nothing } } diff --git a/src/main/java/com/redis/trino/RediSearchPageSinkProvider.java b/src/main/java/com/redis/trino/RediSearchPageSinkProvider.java index 2fc6e03..73f8318 100644 --- a/src/main/java/com/redis/trino/RediSearchPageSinkProvider.java +++ b/src/main/java/com/redis/trino/RediSearchPageSinkProvider.java @@ -23,6 +23,8 @@ */ package com.redis.trino; +import java.util.List; + import javax.inject.Inject; import io.trino.spi.connector.ConnectorInsertTableHandle; @@ -32,27 +34,32 @@ import io.trino.spi.connector.ConnectorPageSinkProvider; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTransactionHandle; +import io.trino.spi.connector.SchemaTableName; public class RediSearchPageSinkProvider implements ConnectorPageSinkProvider { - private final RediSearchSession rediSearchSession; + private final RediSearchSession session; @Inject public RediSearchPageSinkProvider(RediSearchSession rediSearchSession) { - this.rediSearchSession = rediSearchSession; + this.session = rediSearchSession; } @Override public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorOutputTableHandle outputTableHandle, ConnectorPageSinkId pageSinkId) { RediSearchOutputTableHandle handle = (RediSearchOutputTableHandle) outputTableHandle; - return new RediSearchPageSink(rediSearchSession, handle.getSchemaTableName(), handle.getColumns()); + return pageSink(handle.getSchemaTableName(), handle.getColumns()); } @Override public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorInsertTableHandle insertTableHandle, ConnectorPageSinkId pageSinkId) { RediSearchInsertTableHandle handle = (RediSearchInsertTableHandle) insertTableHandle; - return new RediSearchPageSink(rediSearchSession, handle.getSchemaTableName(), handle.getColumns()); + return pageSink(handle.getSchemaTableName(), handle.getColumns()); + } + + private RediSearchPageSink pageSink(SchemaTableName schemaTableName, List columns) { + return new RediSearchPageSink(session, schemaTableName, columns); } } diff --git a/src/main/java/com/redis/trino/RediSearchPageSource.java b/src/main/java/com/redis/trino/RediSearchPageSource.java index a7a8f5c..a68fcc8 100644 --- a/src/main/java/com/redis/trino/RediSearchPageSource.java +++ b/src/main/java/com/redis/trino/RediSearchPageSource.java @@ -30,23 +30,30 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.Iterator; import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; import com.fasterxml.jackson.core.JsonFactory; import com.fasterxml.jackson.core.JsonGenerator; +import com.redis.lettucemod.api.StatefulRedisModulesConnection; +import com.redis.lettucemod.api.async.RedisModulesAsyncCommands; import com.redis.lettucemod.search.Document; import io.airlift.slice.Slice; import io.airlift.slice.SliceOutput; +import io.lettuce.core.LettuceFutures; +import io.lettuce.core.RedisFuture; import io.trino.spi.Page; import io.trino.spi.PageBuilder; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.connector.UpdatablePageSource; import io.trino.spi.type.Type; +import io.trino.spi.type.VarcharType; public class RediSearchPageSource implements UpdatablePageSource { @@ -54,6 +61,7 @@ public class RediSearchPageSource implements UpdatablePageSource { private final RediSearchPageSourceResultWriter writer = new RediSearchPageSourceResultWriter(); private final RediSearchSession session; + private final RediSearchTableHandle table; private final Iterator> cursor; private final String[] columnNames; private final List columnTypes; @@ -66,9 +74,10 @@ public class RediSearchPageSource implements UpdatablePageSource { public RediSearchPageSource(RediSearchSession session, RediSearchTableHandle table, List columns) { this.session = session; + this.table = table; this.columnNames = columns.stream().map(RediSearchColumnHandle::getName).toArray(String[]::new); this.columnTypes = columns.stream().map(RediSearchColumnHandle::getType) - .collect(Collectors.toUnmodifiableList()); + .collect(Collectors.toList()); this.cursor = session.search(table, columnNames).iterator(); this.currentDoc = null; this.pageBuilder = new PageBuilder(columnTypes); @@ -127,14 +136,45 @@ public Page getNextPage() { @Override public void deleteRows(Block rowIds) { List docIds = new ArrayList<>(rowIds.getPositionCount()); - for (int i = 0; i < rowIds.getPositionCount(); i++) { - int len = rowIds.getSliceLength(i); - Slice slice = rowIds.getSlice(i, 0, len); - docIds.add(slice.toStringUtf8()); + for (int position = 0; position < rowIds.getPositionCount(); position++) { + docIds.add(VarcharType.VARCHAR.getSlice(rowIds, position).toStringUtf8()); } session.deleteDocs(docIds); } + @Override + public void updateRows(Page page, List columnValueAndRowIdChannels) { + int rowIdChannel = columnValueAndRowIdChannels.get(columnValueAndRowIdChannels.size() - 1); + List columnChannelMapping = columnValueAndRowIdChannels.subList(0, + columnValueAndRowIdChannels.size() - 1); + StatefulRedisModulesConnection connection = session.getConnection(); + connection.setAutoFlushCommands(false); + RedisModulesAsyncCommands commands = connection.async(); + List> futures = new ArrayList<>(); + for (int position = 0; position < page.getPositionCount(); position++) { + Block rowIdBlock = page.getBlock(rowIdChannel); + if (rowIdBlock.isNull(position)) { + continue; + } + String key = VarcharType.VARCHAR.getSlice(rowIdBlock, position).toStringUtf8(); + Map map = new HashMap<>(); + for (int channel = 0; channel < columnChannelMapping.size(); channel++) { + RediSearchColumnHandle column = table.getUpdatedColumns().get(columnChannelMapping.get(channel)); + Block block = page.getBlock(channel); + if (block.isNull(position)) { + continue; + } + String value = RediSearchPageSink.value(column.getType(), block, position); + map.put(column.getName(), value); + } + RedisFuture future = commands.hset(key, map); + futures.add(future); + } + connection.flushCommands(); + LettuceFutures.awaitAll(connection.getTimeout(), futures.toArray(new RedisFuture[0])); + connection.setAutoFlushCommands(true); + } + private String currentValue(String columnName) { if (RediSearchBuiltinField.isBuiltinColumn(columnName)) { if (RediSearchBuiltinField.ID.getName().equals(columnName)) { diff --git a/src/main/java/com/redis/trino/RediSearchQueryBuilder.java b/src/main/java/com/redis/trino/RediSearchQueryBuilder.java index b9c6dd4..75608e7 100644 --- a/src/main/java/com/redis/trino/RediSearchQueryBuilder.java +++ b/src/main/java/com/redis/trino/RediSearchQueryBuilder.java @@ -242,9 +242,9 @@ public Optional group(RediSearchTableHandle table) { List aggregates = table.getMetricAggregations(); List groupFields = new ArrayList<>(); if (terms != null && !terms.isEmpty()) { - groupFields = terms.stream().map(RediSearchAggregationTerm::getTerm).collect(Collectors.toUnmodifiableList()); + groupFields = terms.stream().map(RediSearchAggregationTerm::getTerm).collect(Collectors.toList()); } - List reducers = aggregates.stream().map(this::reducer).collect(Collectors.toUnmodifiableList()); + List reducers = aggregates.stream().map(this::reducer).collect(Collectors.toList()); if (reducers.isEmpty()) { return Optional.empty(); } diff --git a/src/main/java/com/redis/trino/RediSearchSession.java b/src/main/java/com/redis/trino/RediSearchSession.java index 5d8f1fb..5d2a8c6 100644 --- a/src/main/java/com/redis/trino/RediSearchSession.java +++ b/src/main/java/com/redis/trino/RediSearchSession.java @@ -185,7 +185,7 @@ public void createTable(SchemaTableName schemaTableName, List> fields = columns.stream().filter(c -> !c.getName().equals("_id")) - .map(c -> buildField(c.getName(), c.getType())).collect(Collectors.toUnmodifiableList()); + .map(c -> buildField(c.getName(), c.getType())).collect(Collectors.toList()); CreateOptions.Builder options = CreateOptions.builder(); options.prefix(index + ":"); connection.sync().ftCreate(index, options.build(), fields.toArray(Field[]::new)); @@ -308,7 +308,7 @@ public AggregateWithCursorResults aggregate(RediSearchTableHandle table) AggregateWithCursorResults results = connection.sync().ftAggregate(aggregation.getIndex(), aggregation.getQuery(), aggregation.getCursorOptions(), aggregation.getOptions()); List> groupBys = aggregation.getOptions().getOperations().stream() - .filter(o -> o.getType() == AggregateOperation.Type.GROUP).collect(Collectors.toUnmodifiableList()); + .filter(o -> o.getType() == AggregateOperation.Type.GROUP).collect(Collectors.toList()); if (results.isEmpty() && !groupBys.isEmpty()) { Group groupBy = (Group) groupBys.get(0); Optional as = groupBy.getReducers()[0].getAs(); diff --git a/src/main/java/com/redis/trino/RediSearchTableHandle.java b/src/main/java/com/redis/trino/RediSearchTableHandle.java index 5cba87a..6479ccb 100644 --- a/src/main/java/com/redis/trino/RediSearchTableHandle.java +++ b/src/main/java/com/redis/trino/RediSearchTableHandle.java @@ -33,6 +33,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorTableHandle; @@ -53,10 +54,12 @@ public enum Type { private final List aggregationTerms; private final List aggregations; private final Map wildcards; + // UPDATE only + private final List updatedColumns; public RediSearchTableHandle(Type type, SchemaTableName schemaTableName) { this(type, schemaTableName, TupleDomain.all(), OptionalLong.empty(), Collections.emptyList(), - Collections.emptyList(), Map.of()); + Collections.emptyList(), Map.of(), Collections.emptyList()); } @JsonCreator @@ -65,7 +68,8 @@ public RediSearchTableHandle(@JsonProperty("type") Type type, @JsonProperty("constraint") TupleDomain constraint, @JsonProperty("limit") OptionalLong limit, @JsonProperty("aggTerms") List termAggregations, @JsonProperty("aggregates") List metricAggregations, - @JsonProperty("wildcards") Map wildcards) { + @JsonProperty("wildcards") Map wildcards, + @JsonProperty("updatedColumns") List updatedColumns) { this.type = requireNonNull(type, "type is null"); this.schemaTableName = requireNonNull(schemaTableName, "schemaTableName is null"); this.constraint = requireNonNull(constraint, "constraint is null"); @@ -73,6 +77,7 @@ public RediSearchTableHandle(@JsonProperty("type") Type type, this.aggregationTerms = requireNonNull(termAggregations, "aggTerms is null"); this.aggregations = requireNonNull(metricAggregations, "aggregates is null"); this.wildcards = requireNonNull(wildcards, "wildcards is null"); + this.updatedColumns = ImmutableList.copyOf(requireNonNull(updatedColumns, "updatedColumns is null")); } @JsonProperty @@ -110,9 +115,14 @@ public Map getWildcards() { return wildcards; } + @JsonProperty + public List getUpdatedColumns() { + return updatedColumns; + } + @Override public int hashCode() { - return Objects.hash(schemaTableName, constraint, limit); + return Objects.hash(schemaTableName, constraint, limit, updatedColumns); } @Override @@ -125,7 +135,8 @@ public boolean equals(Object obj) { } RediSearchTableHandle other = (RediSearchTableHandle) obj; return Objects.equals(this.schemaTableName, other.schemaTableName) - && Objects.equals(this.constraint, other.constraint) && Objects.equals(this.limit, other.limit); + && Objects.equals(this.constraint, other.constraint) && Objects.equals(this.limit, other.limit) + && Objects.equals(updatedColumns, other.updatedColumns); } @Override diff --git a/src/test/java/com/redis/trino/TestRediSearchConnectorSmokeTest.java b/src/test/java/com/redis/trino/TestRediSearchConnectorSmokeTest.java index 301190a..734673d 100644 --- a/src/test/java/com/redis/trino/TestRediSearchConnectorSmokeTest.java +++ b/src/test/java/com/redis/trino/TestRediSearchConnectorSmokeTest.java @@ -1,5 +1,7 @@ package com.redis.trino; +import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_INSERT; +import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_UPDATE; import static io.trino.tpch.TpchTable.CUSTOMER; import static io.trino.tpch.TpchTable.NATION; import static io.trino.tpch.TpchTable.ORDERS; @@ -29,6 +31,7 @@ import io.trino.testing.BaseConnectorSmokeTest; import io.trino.testing.QueryRunner; import io.trino.testing.TestingConnectorBehavior; +import io.trino.testing.sql.TestTable; public class TestRediSearchConnectorSmokeTest extends BaseConnectorSmokeTest { @@ -49,22 +52,18 @@ private void populateBeers() throws IOException { protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) { switch (connectorBehavior) { case SUPPORTS_CREATE_SCHEMA: - return false; - case SUPPORTS_CREATE_VIEW: return false; case SUPPORTS_CREATE_TABLE: return true; - case SUPPORTS_RENAME_TABLE: - return false; - case SUPPORTS_ARRAY: return false; case SUPPORTS_DROP_COLUMN: case SUPPORTS_RENAME_COLUMN: + case SUPPORTS_RENAME_TABLE: return false; case SUPPORTS_COMMENT_ON_TABLE: @@ -78,14 +77,13 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) { return false; case SUPPORTS_DELETE: + case SUPPORTS_INSERT: + case SUPPORTS_UPDATE: return true; case SUPPORTS_RENAME_TABLE_ACROSS_SCHEMAS: return false; - case SUPPORTS_INSERT: - return true; - case SUPPORTS_LIMIT_PUSHDOWN: return true; @@ -222,4 +220,31 @@ public void testInPredicateNumeric() { assertQuery("SELECT name, regionkey FROM nation WHERE regionKey in (1, 2, 3)"); } + @SuppressWarnings("resource") + @Test + public void testUpdate() { + if (!hasBehavior(SUPPORTS_UPDATE)) { + // Note this change is a no-op, if actually run + assertQueryFails("UPDATE nation SET nationkey = nationkey + regionkey WHERE regionkey < 1", + "This connector does not support updates"); + return; + } + + if (!hasBehavior(SUPPORTS_INSERT)) { + throw new AssertionError("Cannot test UPDATE without INSERT"); + } + + try (TestTable table = new TestTable(getQueryRunner()::execute, "test_update_", + getCreateTableDefaultDefinition())) { + assertUpdate("INSERT INTO " + table.getName() + " (a, b) SELECT regionkey, regionkey * 2.5 FROM region", + "SELECT count(*) FROM region"); + assertThat(query("SELECT CAST(a AS bigint), b FROM " + table.getName())) + .matches(expectedValues("(0, 0.0), (1, 2.5), (2, 5.0), (3, 7.5), (4, 10.0)")); + + assertUpdate("UPDATE " + table.getName() + " SET b = b + 1.2 WHERE a % 2 = 0", 3); + assertThat(query("SELECT CAST(a AS bigint), b FROM " + table.getName())) + .matches(expectedValues("(0, 1.2), (1, 2.5), (2, 6.2), (3, 7.5), (4, 11.2)")); + } + } + }