diff --git a/src/main/java/com/redis/trino/RediSearchBuiltinField.java b/src/main/java/com/redis/trino/RediSearchBuiltinField.java index 87a468c..aceb5f7 100644 --- a/src/main/java/com/redis/trino/RediSearchBuiltinField.java +++ b/src/main/java/com/redis/trino/RediSearchBuiltinField.java @@ -1,7 +1,6 @@ package com.redis.trino; import static com.google.common.collect.ImmutableMap.toImmutableMap; -import static io.trino.spi.type.RealType.REAL; import static io.trino.spi.type.VarcharType.VARCHAR; import static java.util.Arrays.stream; import static java.util.function.Function.identity; @@ -16,7 +15,7 @@ enum RediSearchBuiltinField { - ID("_id", VARCHAR, Field.Type.TAG), SCORE("_score", REAL, Field.Type.NUMERIC); + KEY("__key", VARCHAR, Field.Type.TAG); private static final Map COLUMNS_BY_NAME = stream(values()) .collect(toImmutableMap(RediSearchBuiltinField::getName, identity())); @@ -58,4 +57,8 @@ public ColumnMetadata getMetadata() { public RediSearchColumnHandle getColumnHandle() { return new RediSearchColumnHandle(name, type, fieldType, true, false); } + + public static boolean isKeyColumn(String columnName) { + return KEY.name.equals(columnName); + } } diff --git a/src/main/java/com/redis/trino/RediSearchMetadata.java b/src/main/java/com/redis/trino/RediSearchMetadata.java index 92f6b92..64afd11 100644 --- a/src/main/java/com/redis/trino/RediSearchMetadata.java +++ b/src/main/java/com/redis/trino/RediSearchMetadata.java @@ -49,7 +49,6 @@ import com.google.common.collect.ImmutableMap; import com.redis.lettucemod.search.Field; import com.redis.lettucemod.search.querybuilder.Values; -import com.redis.trino.RediSearchTableHandle.Type; import io.airlift.log.Logger; import io.airlift.slice.Slice; @@ -250,7 +249,7 @@ public Optional finishInsert(ConnectorSession session, @Override public RediSearchColumnHandle getDeleteRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle) { - return RediSearchBuiltinField.ID.getColumnHandle(); + return RediSearchBuiltinField.KEY.getColumnHandle(); } @Override @@ -268,7 +267,7 @@ public void finishDelete(ConnectorSession session, ConnectorTableHandle tableHan @Override public RediSearchColumnHandle getUpdateRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle, List updatedColumns) { - return RediSearchBuiltinField.ID.getColumnHandle(); + return RediSearchBuiltinField.KEY.getColumnHandle(); } @Override @@ -276,8 +275,8 @@ public RediSearchTableHandle beginUpdate(ConnectorSession session, ConnectorTabl List updatedColumns, RetryMode retryMode) { checkRetry(retryMode); RediSearchTableHandle table = (RediSearchTableHandle) tableHandle; - return new RediSearchTableHandle(table.getType(), table.getSchemaTableName(), table.getConstraint(), - table.getLimit(), table.getTermAggregations(), table.getMetricAggregations(), table.getWildcards(), + return new RediSearchTableHandle(table.getSchemaTableName(), table.getConstraint(), table.getLimit(), + table.getTermAggregations(), table.getMetricAggregations(), table.getWildcards(), updatedColumns.stream().map(RediSearchColumnHandle.class::cast).collect(toImmutableList())); } @@ -306,11 +305,9 @@ public Optional> applyLimit(Connect return Optional.empty(); } - return Optional.of(new LimitApplicationResult<>( - new RediSearchTableHandle(handle.getType(), handle.getSchemaTableName(), handle.getConstraint(), - OptionalLong.of(limit), handle.getTermAggregations(), handle.getMetricAggregations(), - handle.getWildcards(), handle.getUpdatedColumns()), - true, false)); + return Optional.of(new LimitApplicationResult<>(new RediSearchTableHandle(handle.getSchemaTableName(), + handle.getConstraint(), OptionalLong.of(limit), handle.getTermAggregations(), + handle.getMetricAggregations(), handle.getWildcards(), handle.getUpdatedColumns()), true, false)); } @Override @@ -372,7 +369,7 @@ public Optional> applyFilter(C return Optional.empty(); } - handle = new RediSearchTableHandle(handle.getType(), handle.getSchemaTableName(), newDomain, handle.getLimit(), + handle = new RediSearchTableHandle(handle.getSchemaTableName(), newDomain, handle.getLimit(), handle.getTermAggregations(), handle.getMetricAggregations(), newWildcards, handle.getUpdatedColumns()); return Optional.of(new ConstraintApplicationResult<>(handle, TupleDomain.withColumnDomains(unsupported), @@ -498,9 +495,8 @@ public Optional> applyAggrega if (aggregationList.isEmpty()) { return Optional.empty(); } - RediSearchTableHandle tableHandle = new RediSearchTableHandle(Type.AGGREGATE, table.getSchemaTableName(), - table.getConstraint(), table.getLimit(), terms.build(), aggregationList, table.getWildcards(), - table.getUpdatedColumns()); + RediSearchTableHandle tableHandle = new RediSearchTableHandle(table.getSchemaTableName(), table.getConstraint(), + table.getLimit(), terms.build(), aggregationList, table.getWildcards(), table.getUpdatedColumns()); return Optional.of(new AggregationApplicationResult<>(tableHandle, projections.build(), resultAssignments.build(), Map.of(), false)); } diff --git a/src/main/java/com/redis/trino/RediSearchPageSink.java b/src/main/java/com/redis/trino/RediSearchPageSink.java index a32b1ae..8723ecc 100644 --- a/src/main/java/com/redis/trino/RediSearchPageSink.java +++ b/src/main/java/com/redis/trino/RediSearchPageSink.java @@ -100,26 +100,29 @@ public CompletableFuture appendPage(Page page) { String prefix = prefix().orElse(schemaTableName.getTableName() + KEY_SEPARATOR); StatefulRedisModulesConnection connection = session.getConnection(); connection.setAutoFlushCommands(false); - RedisModulesAsyncCommands commands = connection.async(); - List> futures = new ArrayList<>(); - for (int position = 0; position < page.getPositionCount(); position++) { - String key = prefix + factory.create().toString(); - Map map = new HashMap<>(); - for (int channel = 0; channel < page.getChannelCount(); channel++) { - RediSearchColumnHandle column = columns.get(channel); - Block block = page.getBlock(channel); - if (block.isNull(position)) { - continue; + try { + RedisModulesAsyncCommands commands = connection.async(); + List> futures = new ArrayList<>(); + for (int position = 0; position < page.getPositionCount(); position++) { + String key = prefix + factory.create().toString(); + Map map = new HashMap<>(); + for (int channel = 0; channel < page.getChannelCount(); channel++) { + RediSearchColumnHandle column = columns.get(channel); + Block block = page.getBlock(channel); + if (block.isNull(position)) { + continue; + } + String value = value(column.getType(), block, position); + map.put(column.getName(), value); } - String value = value(column.getType(), block, position); - map.put(column.getName(), value); + RedisFuture future = commands.hset(key, map); + futures.add(future); } - RedisFuture future = commands.hset(key, map); - futures.add(future); + connection.flushCommands(); + LettuceFutures.awaitAll(connection.getTimeout(), futures.toArray(new RedisFuture[0])); + } finally { + connection.setAutoFlushCommands(true); } - connection.flushCommands(); - LettuceFutures.awaitAll(connection.getTimeout(), futures.toArray(new RedisFuture[0])); - connection.setAutoFlushCommands(true); return NOT_BLOCKED; } diff --git a/src/main/java/com/redis/trino/RediSearchPageSource.java b/src/main/java/com/redis/trino/RediSearchPageSource.java index a68fcc8..4dc5ee7 100644 --- a/src/main/java/com/redis/trino/RediSearchPageSource.java +++ b/src/main/java/com/redis/trino/RediSearchPageSource.java @@ -41,8 +41,9 @@ import com.fasterxml.jackson.core.JsonGenerator; import com.redis.lettucemod.api.StatefulRedisModulesConnection; import com.redis.lettucemod.api.async.RedisModulesAsyncCommands; -import com.redis.lettucemod.search.Document; +import com.redis.lettucemod.search.AggregateWithCursorResults; +import io.airlift.log.Logger; import io.airlift.slice.Slice; import io.airlift.slice.SliceOutput; import io.lettuce.core.LettuceFutures; @@ -57,28 +58,29 @@ public class RediSearchPageSource implements UpdatablePageSource { + private static final Logger log = Logger.get(RediSearchPageSource.class); + private static final int ROWS_PER_REQUEST = 1024; private final RediSearchPageSourceResultWriter writer = new RediSearchPageSourceResultWriter(); private final RediSearchSession session; private final RediSearchTableHandle table; - private final Iterator> cursor; private final String[] columnNames; private final List columnTypes; - private final PageBuilder pageBuilder; - - private Document currentDoc; + private final CursorIterator iterator; + private Map currentDoc; private long count; private boolean finished; + private final PageBuilder pageBuilder; + public RediSearchPageSource(RediSearchSession session, RediSearchTableHandle table, List columns) { this.session = session; this.table = table; this.columnNames = columns.stream().map(RediSearchColumnHandle::getName).toArray(String[]::new); - this.columnTypes = columns.stream().map(RediSearchColumnHandle::getType) - .collect(Collectors.toList()); - this.cursor = session.search(table, columnNames).iterator(); + this.iterator = new CursorIterator(session, table, columnNames); + this.columnTypes = columns.stream().map(RediSearchColumnHandle::getType).collect(Collectors.toList()); this.currentDoc = null; this.pageBuilder = new PageBuilder(columnTypes); } @@ -108,26 +110,24 @@ public Page getNextPage() { verify(pageBuilder.isEmpty()); count = 0; for (int i = 0; i < ROWS_PER_REQUEST; i++) { - if (!cursor.hasNext()) { + if (!iterator.hasNext()) { finished = true; break; } - currentDoc = cursor.next(); + currentDoc = iterator.next(); count++; pageBuilder.declarePosition(); for (int column = 0; column < columnTypes.size(); column++) { BlockBuilder output = pageBuilder.getBlockBuilder(column); - String columnName = columnNames[column]; - String value = currentValue(columnName); + Object value = currentValue(columnNames[column]); if (value == null) { output.appendNull(); } else { - writer.appendTo(columnTypes.get(column), value, output); + writer.appendTo(columnTypes.get(column), value.toString(), output); } } } - Page page = pageBuilder.build(); pageBuilder.reset(); return page; @@ -149,42 +149,33 @@ public void updateRows(Page page, List columnValueAndRowIdChannels) { columnValueAndRowIdChannels.size() - 1); StatefulRedisModulesConnection connection = session.getConnection(); connection.setAutoFlushCommands(false); - RedisModulesAsyncCommands commands = connection.async(); - List> futures = new ArrayList<>(); - for (int position = 0; position < page.getPositionCount(); position++) { - Block rowIdBlock = page.getBlock(rowIdChannel); - if (rowIdBlock.isNull(position)) { - continue; - } - String key = VarcharType.VARCHAR.getSlice(rowIdBlock, position).toStringUtf8(); - Map map = new HashMap<>(); - for (int channel = 0; channel < columnChannelMapping.size(); channel++) { - RediSearchColumnHandle column = table.getUpdatedColumns().get(columnChannelMapping.get(channel)); - Block block = page.getBlock(channel); - if (block.isNull(position)) { + try { + RedisModulesAsyncCommands commands = connection.async(); + List> futures = new ArrayList<>(); + for (int position = 0; position < page.getPositionCount(); position++) { + Block rowIdBlock = page.getBlock(rowIdChannel); + if (rowIdBlock.isNull(position)) { continue; } - String value = RediSearchPageSink.value(column.getType(), block, position); - map.put(column.getName(), value); - } - RedisFuture future = commands.hset(key, map); - futures.add(future); - } - connection.flushCommands(); - LettuceFutures.awaitAll(connection.getTimeout(), futures.toArray(new RedisFuture[0])); - connection.setAutoFlushCommands(true); - } - - private String currentValue(String columnName) { - if (RediSearchBuiltinField.isBuiltinColumn(columnName)) { - if (RediSearchBuiltinField.ID.getName().equals(columnName)) { - return currentDoc.getId(); - } - if (RediSearchBuiltinField.SCORE.getName().equals(columnName)) { - return String.valueOf(currentDoc.getScore()); + String key = VarcharType.VARCHAR.getSlice(rowIdBlock, position).toStringUtf8(); + Map map = new HashMap<>(); + for (int channel = 0; channel < columnChannelMapping.size(); channel++) { + RediSearchColumnHandle column = table.getUpdatedColumns().get(columnChannelMapping.get(channel)); + Block block = page.getBlock(channel); + if (block.isNull(position)) { + continue; + } + String value = RediSearchPageSink.value(column.getType(), block, position); + map.put(column.getName(), value); + } + RedisFuture future = commands.hset(key, map); + futures.add(future); } + connection.flushCommands(); + LettuceFutures.awaitAll(connection.getTimeout(), futures.toArray(new RedisFuture[0])); + } finally { + connection.setAutoFlushCommands(true); } - return currentDoc.get(columnName); } @Override @@ -194,12 +185,67 @@ public CompletableFuture> finish() { return future; } + private Object currentValue(String columnName) { + if (RediSearchBuiltinField.isKeyColumn(columnName)) { + return currentDoc.get(RediSearchBuiltinField.KEY.getName()); + } + return currentDoc.get(columnName); + } + public static JsonGenerator createJsonGenerator(JsonFactory factory, SliceOutput output) throws IOException { return factory.createGenerator((OutputStream) output); } @Override public void close() { - // nothing to do + try { + iterator.close(); + } catch (Exception e) { + log.error(e, "Could not close cursor iterator"); + } + } + + private static class CursorIterator implements Iterator>, AutoCloseable { + + private final RediSearchSession session; + private final RediSearchTableHandle table; + private Iterator> iterator; + private long cursor; + + public CursorIterator(RediSearchSession session, RediSearchTableHandle table, String[] columnNames) { + this.session = session; + this.table = table; + read(session.aggregate(table, columnNames)); + } + + private void read(AggregateWithCursorResults results) { + this.iterator = results.iterator(); + this.cursor = results.getCursor(); + } + + @Override + public boolean hasNext() { + while (!iterator.hasNext()) { + if (cursor == 0) { + return false; + } + read(session.cursorRead(table, cursor)); + } + return true; + } + + @Override + public Map next() { + return iterator.next(); + } + + @Override + public void close() throws Exception { + if (cursor == 0) { + return; + } + session.cursorDelete(table, cursor); + } + } } diff --git a/src/main/java/com/redis/trino/RediSearchPageSourceAggregate.java b/src/main/java/com/redis/trino/RediSearchPageSourceAggregate.java deleted file mode 100644 index d0d0114..0000000 --- a/src/main/java/com/redis/trino/RediSearchPageSourceAggregate.java +++ /dev/null @@ -1,176 +0,0 @@ -/* - * MIT License - * - * Copyright (c) 2022, Redis Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -package com.redis.trino; - -import static com.google.common.base.Verify.verify; - -import java.io.IOException; -import java.io.OutputStream; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - -import com.fasterxml.jackson.core.JsonFactory; -import com.fasterxml.jackson.core.JsonGenerator; -import com.redis.lettucemod.search.AggregateWithCursorResults; - -import io.airlift.log.Logger; -import io.airlift.slice.SliceOutput; -import io.trino.spi.Page; -import io.trino.spi.PageBuilder; -import io.trino.spi.block.BlockBuilder; -import io.trino.spi.connector.ConnectorPageSource; -import io.trino.spi.type.Type; - -public class RediSearchPageSourceAggregate implements ConnectorPageSource { - - private static final Logger log = Logger.get(RediSearchPageSourceAggregate.class); - - private static final int ROWS_PER_REQUEST = 1024; - - private final RediSearchPageSourceResultWriter writer = new RediSearchPageSourceResultWriter(); - private final List columnNames; - private final List columnTypes; - private final CursorIterator iterator; - private Map currentDoc; - private long count; - private boolean finished; - - private final PageBuilder pageBuilder; - - public RediSearchPageSourceAggregate(RediSearchSession rediSearchSession, RediSearchTableHandle tableHandle, - List columns) { - this.iterator = new CursorIterator(rediSearchSession, tableHandle); - this.columnNames = columns.stream().map(RediSearchColumnHandle::getName).collect(Collectors.toList()); - this.columnTypes = columns.stream().map(RediSearchColumnHandle::getType).collect(Collectors.toList()); - this.currentDoc = null; - this.pageBuilder = new PageBuilder(columnTypes); - } - - @Override - public long getCompletedBytes() { - return count; - } - - @Override - public long getReadTimeNanos() { - return 0; - } - - @Override - public boolean isFinished() { - return finished; - } - - @Override - public long getMemoryUsage() { - return 0L; - } - - @Override - public Page getNextPage() { - verify(pageBuilder.isEmpty()); - count = 0; - for (int i = 0; i < ROWS_PER_REQUEST; i++) { - if (!iterator.hasNext()) { - finished = true; - break; - } - currentDoc = iterator.next(); - count++; - - pageBuilder.declarePosition(); - for (int column = 0; column < columnTypes.size(); column++) { - BlockBuilder output = pageBuilder.getBlockBuilder(column); - Object value = currentDoc.get(columnNames.get(column)); - if (value == null) { - output.appendNull(); - } else { - writer.appendTo(columnTypes.get(column), value.toString(), output); - } - } - } - Page page = pageBuilder.build(); - pageBuilder.reset(); - return page; - } - - public static JsonGenerator createJsonGenerator(JsonFactory factory, SliceOutput output) throws IOException { - return factory.createGenerator((OutputStream) output); - } - - @Override - public void close() { - try { - iterator.close(); - } catch (Exception e) { - log.error(e, "Could not close cursor iterator"); - } - } - - private static class CursorIterator implements Iterator>, AutoCloseable { - - private final RediSearchSession session; - private final RediSearchTableHandle tableHandle; - private Iterator> iterator; - private long cursor; - - public CursorIterator(RediSearchSession session, RediSearchTableHandle tableHandle) { - this.session = session; - this.tableHandle = tableHandle; - read(session.aggregate(tableHandle)); - } - - private void read(AggregateWithCursorResults results) { - this.iterator = results.iterator(); - this.cursor = results.getCursor(); - } - - @Override - public boolean hasNext() { - while (!iterator.hasNext()) { - if (cursor == 0) { - return false; - } - read(session.cursorRead(tableHandle, cursor)); - } - return true; - } - - @Override - public Map next() { - return iterator.next(); - } - - @Override - public void close() throws Exception { - if (cursor == 0) { - return; - } - session.cursorDelete(tableHandle, cursor); - } - - } -} diff --git a/src/main/java/com/redis/trino/RediSearchPageSourceProvider.java b/src/main/java/com/redis/trino/RediSearchPageSourceProvider.java index 0c1e056..40d4ae9 100644 --- a/src/main/java/com/redis/trino/RediSearchPageSourceProvider.java +++ b/src/main/java/com/redis/trino/RediSearchPageSourceProvider.java @@ -23,8 +23,13 @@ */ package com.redis.trino; +import static java.util.Objects.requireNonNull; + +import java.util.List; + +import javax.inject.Inject; + import com.google.common.collect.ImmutableList; -import com.redis.trino.RediSearchTableHandle.Type; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorPageSource; @@ -35,12 +40,6 @@ import io.trino.spi.connector.ConnectorTransactionHandle; import io.trino.spi.connector.DynamicFilter; -import javax.inject.Inject; - -import java.util.List; - -import static java.util.Objects.requireNonNull; - public class RediSearchPageSourceProvider implements ConnectorPageSourceProvider { private final RediSearchSession rediSearchSession; @@ -53,15 +52,11 @@ public RediSearchPageSourceProvider(RediSearchSession rediSearchSession) { public ConnectorPageSource createPageSource(ConnectorTransactionHandle transaction, ConnectorSession session, ConnectorSplit split, ConnectorTableHandle table, List columns, DynamicFilter dynamicFilter) { RediSearchTableHandle tableHandle = (RediSearchTableHandle) table; - ImmutableList.Builder handles = ImmutableList.builder(); for (ColumnHandle handle : requireNonNull(columns, "columns is null")) { handles.add((RediSearchColumnHandle) handle); } ImmutableList columnHandles = handles.build(); - if (tableHandle.getType() == Type.AGGREGATE) { - return new RediSearchPageSourceAggregate(rediSearchSession, tableHandle, columnHandles); - } return new RediSearchPageSource(rediSearchSession, tableHandle, columnHandles); } } diff --git a/src/main/java/com/redis/trino/RediSearchSession.java b/src/main/java/com/redis/trino/RediSearchSession.java index df65b05..e824efe 100644 --- a/src/main/java/com/redis/trino/RediSearchSession.java +++ b/src/main/java/com/redis/trino/RediSearchSession.java @@ -190,7 +190,7 @@ public RediSearchTable getTable(SchemaTableName tableName) throws TableNotFoundE public void createTable(SchemaTableName schemaTableName, List columns) { String index = index(schemaTableName); if (!connection.sync().ftList().contains(index)) { - List> fields = columns.stream().filter(c -> !c.getName().equals("_id")) + List> fields = columns.stream().filter(c -> !RediSearchBuiltinField.isKeyColumn(c.getName())) .map(c -> buildField(c.getName(), c.getType())).collect(Collectors.toList()); CreateOptions.Builder options = CreateOptions.builder(); options.prefix(index + ":"); @@ -260,9 +260,7 @@ private RediSearchTable loadTableSchema(SchemaTableName schemaTableName) throws fields.add(docField); } } - RediSearchTableHandle tableHandle = new RediSearchTableHandle(RediSearchTableHandle.Type.SEARCH, - schemaTableName); - return new RediSearchTable(tableHandle, columns.build(), indexInfo); + return new RediSearchTable(new RediSearchTableHandle(schemaTableName), columns.build(), indexInfo); } private Optional indexInfo(String index) { @@ -308,8 +306,8 @@ public SearchResults search(RediSearchTableHandle tableHandle, S return connection.sync().ftSearch(search.getIndex(), search.getQuery(), search.getOptions()); } - public AggregateWithCursorResults aggregate(RediSearchTableHandle table) { - Aggregation aggregation = translator.aggregate(table); + public AggregateWithCursorResults aggregate(RediSearchTableHandle table, String[] columnNames) { + Aggregation aggregation = translator.aggregate(table, columnNames); log.info("Running %s", aggregation); String index = aggregation.getIndex(); String query = aggregation.getQuery(); diff --git a/src/main/java/com/redis/trino/RediSearchTableHandle.java b/src/main/java/com/redis/trino/RediSearchTableHandle.java index 6479ccb..01a3080 100644 --- a/src/main/java/com/redis/trino/RediSearchTableHandle.java +++ b/src/main/java/com/redis/trino/RediSearchTableHandle.java @@ -46,7 +46,6 @@ public enum Type { SEARCH, AGGREGATE } - private final Type type; private final SchemaTableName schemaTableName; private final TupleDomain constraint; private final OptionalLong limit; @@ -57,20 +56,18 @@ public enum Type { // UPDATE only private final List updatedColumns; - public RediSearchTableHandle(Type type, SchemaTableName schemaTableName) { - this(type, schemaTableName, TupleDomain.all(), OptionalLong.empty(), Collections.emptyList(), - Collections.emptyList(), Map.of(), Collections.emptyList()); + public RediSearchTableHandle(SchemaTableName schemaTableName) { + this(schemaTableName, TupleDomain.all(), OptionalLong.empty(), Collections.emptyList(), Collections.emptyList(), + Map.of(), Collections.emptyList()); } @JsonCreator - public RediSearchTableHandle(@JsonProperty("type") Type type, - @JsonProperty("schemaTableName") SchemaTableName schemaTableName, + public RediSearchTableHandle(@JsonProperty("schemaTableName") SchemaTableName schemaTableName, @JsonProperty("constraint") TupleDomain constraint, @JsonProperty("limit") OptionalLong limit, @JsonProperty("aggTerms") List termAggregations, @JsonProperty("aggregates") List metricAggregations, @JsonProperty("wildcards") Map wildcards, @JsonProperty("updatedColumns") List updatedColumns) { - this.type = requireNonNull(type, "type is null"); this.schemaTableName = requireNonNull(schemaTableName, "schemaTableName is null"); this.constraint = requireNonNull(constraint, "constraint is null"); this.limit = requireNonNull(limit, "limit is null"); @@ -80,11 +77,6 @@ public RediSearchTableHandle(@JsonProperty("type") Type type, this.updatedColumns = ImmutableList.copyOf(requireNonNull(updatedColumns, "updatedColumns is null")); } - @JsonProperty - public Type getType() { - return type; - } - @JsonProperty public SchemaTableName getSchemaTableName() { return schemaTableName; diff --git a/src/main/java/com/redis/trino/RediSearchTranslator.java b/src/main/java/com/redis/trino/RediSearchTranslator.java index b024041..2a87439 100644 --- a/src/main/java/com/redis/trino/RediSearchTranslator.java +++ b/src/main/java/com/redis/trino/RediSearchTranslator.java @@ -216,10 +216,12 @@ public Search search(RediSearchTableHandle table, String[] columnNames) { return Search.builder().index(index).query(query).options(options.build()).build(); } - public Aggregation aggregate(RediSearchTableHandle table) { + public Aggregation aggregate(RediSearchTableHandle table, String[] columnNames) { String index = index(table); String query = queryBuilder.buildQuery(table.getConstraint(), table.getWildcards()); AggregateOptions.Builder builder = AggregateOptions.builder(); + builder.load(RediSearchBuiltinField.KEY.getName()); + builder.loads(columnNames); queryBuilder.group(table).ifPresent(builder::operation); builder.operation(Limit.offset(0).num(limit(table))); AggregateOptions options = builder.build(); diff --git a/src/test/java/com/redis/trino/TestRediSearchConnectorSmokeTest.java b/src/test/java/com/redis/trino/TestRediSearchConnectorSmokeTest.java index e35411d..ce0787c 100644 --- a/src/test/java/com/redis/trino/TestRediSearchConnectorSmokeTest.java +++ b/src/test/java/com/redis/trino/TestRediSearchConnectorSmokeTest.java @@ -113,7 +113,7 @@ protected void assertQuery(String sql) { public void testRediSearchFields() throws IOException, InterruptedException { populateBeers(); getQueryRunner().execute("select id, last_mod from beers"); - getQueryRunner().execute("select _id, _score from beers"); + getQueryRunner().execute("select __key from beers"); } @SuppressWarnings("unchecked") diff --git a/src/test/java/com/redis/trino/TestRediSearchTableHandle.java b/src/test/java/com/redis/trino/TestRediSearchTableHandle.java index bd15282..d5adc21 100644 --- a/src/test/java/com/redis/trino/TestRediSearchTableHandle.java +++ b/src/test/java/com/redis/trino/TestRediSearchTableHandle.java @@ -4,8 +4,6 @@ import org.testng.annotations.Test; -import com.redis.trino.RediSearchTableHandle.Type; - import io.airlift.json.JsonCodec; import io.trino.spi.connector.SchemaTableName; @@ -14,7 +12,7 @@ public class TestRediSearchTableHandle { @Test public void testRoundTrip() { - RediSearchTableHandle expected = new RediSearchTableHandle(Type.SEARCH, new SchemaTableName("schema", "table")); + RediSearchTableHandle expected = new RediSearchTableHandle(new SchemaTableName("schema", "table")); String json = codec.toJson(expected); RediSearchTableHandle actual = codec.fromJson(json);