Added support for aggregations
Julien Ruaux committed Mar 24, 2022
1 parent d32f9dd commit 09cbf73
Showing 15 changed files with 814 additions and 495 deletions.
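This commit lets the connector push the supported aggregate functions (max, min, avg, sum, count) and GROUP BY terms down to RediSearch instead of computing them in Trino. As a rough illustration only (not part of the commit), a query of the following shape could now be pushed down; the catalog, schema, table, and column names are made up, and the Trino JDBC driver is assumed to be on the classpath.

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.Statement;

public class AggregationPushdownExample {
    public static void main(String[] args) throws Exception {
        // Hypothetical coordinator URL, catalog and schema names.
        String url = "jdbc:trino://localhost:8080/redisearch/default";
        try (Connection connection = DriverManager.getConnection(url, "trino", null);
                Statement statement = connection.createStatement();
                // With this commit, the grouping term and the metric functions below
                // can be evaluated by RediSearch rather than by Trino workers.
                ResultSet resultSet = statement.executeQuery(
                        "SELECT brewery, count(*) AS beers, max(abv) AS strongest "
                                + "FROM beers GROUP BY brewery")) {
            while (resultSet.next()) {
                System.out.printf("%s: %d beers, max abv %s%n", resultSet.getString("brewery"),
                        resultSet.getLong("beers"), resultSet.getDouble("strongest"));
            }
        }
    }
}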
@@ -0,0 +1,117 @@
package com.redis.trino;

import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.RealType.REAL;
import static io.trino.spi.type.SmallintType.SMALLINT;
import static io.trino.spi.type.TinyintType.TINYINT;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;

import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.Type;

public class MetricAggregation {
    public static final String MAX = "max";
    public static final String MIN = "min";
    public static final String AVG = "avg";
    public static final String SUM = "sum";
    public static final String COUNT = "count";
    private static final List<String> SUPPORTED_AGGREGATION_FUNCTIONS = Arrays.asList(MAX, MIN, AVG, SUM, COUNT);
    private static final List<Type> NUMERIC_TYPES = Arrays.asList(REAL, DOUBLE, TINYINT, SMALLINT, INTEGER, BIGINT);
    private final String functionName;
    private final Type outputType;
    private final Optional<RediSearchColumnHandle> columnHandle;
    private final String alias;

    @JsonCreator
    public MetricAggregation(@JsonProperty("functionName") String functionName,
            @JsonProperty("outputType") Type outputType,
            @JsonProperty("columnHandle") Optional<RediSearchColumnHandle> columnHandle,
            @JsonProperty("alias") String alias) {
        this.functionName = functionName;
        this.outputType = outputType;
        this.columnHandle = columnHandle;
        this.alias = alias;
    }

    @JsonProperty
    public String getFunctionName() {
        return functionName;
    }

    @JsonProperty
    public Type getOutputType() {
        return outputType;
    }

    @JsonProperty
    public Optional<RediSearchColumnHandle> getColumnHandle() {
        return columnHandle;
    }

    @JsonProperty
    public String getAlias() {
        return alias;
    }

    public static boolean isNumericType(Type type) {
        return NUMERIC_TYPES.contains(type);
    }

    public static Optional<MetricAggregation> handleAggregation(AggregateFunction function,
            Map<String, ColumnHandle> assignments, String alias) {
        if (!SUPPORTED_AGGREGATION_FUNCTIONS.contains(function.getFunctionName())) {
            return Optional.empty();
        }
        // Check that:
        // 1. the function input can be found in the assignments,
        // 2. the type of the column being aggregated is numeric, and
        // 3. the ColumnHandle supports predicates (text is treated as VARCHAR, but
        // text cannot be treated as a term in Elasticsearch by default).
        Optional<RediSearchColumnHandle> parameterColumnHandle = function.getArguments().stream()
                .filter(Variable.class::isInstance).map(Variable.class::cast).map(Variable::getName)
                .filter(assignments::containsKey).findFirst().map(assignments::get)
                .map(RediSearchColumnHandle.class::cast)
                .filter(column -> MetricAggregation.isNumericType(column.getType()));
        // Only count() can accept an empty RediSearchColumnHandle.
        if (!COUNT.equals(function.getFunctionName()) && parameterColumnHandle.isEmpty()) {
            return Optional.empty();
        }
        return Optional.of(new MetricAggregation(function.getFunctionName(), function.getOutputType(),
                parameterColumnHandle, alias));
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        MetricAggregation that = (MetricAggregation) o;
        return Objects.equals(functionName, that.functionName) && Objects.equals(outputType, that.outputType)
                && Objects.equals(columnHandle, that.columnHandle) && Objects.equals(alias, that.alias);
    }

    @Override
    public int hashCode() {
        return Objects.hash(functionName, outputType, columnHandle, alias);
    }

    @Override
    public String toString() {
        return String.format("%s(%s)", functionName, columnHandle.map(RediSearchColumnHandle::getName).orElse(""));
    }
}
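For illustration only (not part of the commit), here is how a pushed-down avg over a numeric column could be represented with the class above. The column name and alias are invented, and the third RediSearchColumnHandle constructor argument (a boolean flag) is assumed to match the `false` passed in RediSearchMetadata.applyAggregation further down.

package com.redis.trino;

import static io.trino.spi.type.DoubleType.DOUBLE;

import java.util.Optional;

public class MetricAggregationExample {
    public static void main(String[] args) {
        // Hypothetical numeric column taken from a RediSearch index.
        RediSearchColumnHandle abv = new RediSearchColumnHandle("abv", DOUBLE, false);
        MetricAggregation avg = new MetricAggregation(MetricAggregation.AVG, DOUBLE, Optional.of(abv),
                "syntheticColumn0");
        System.out.println(MetricAggregation.isNumericType(abv.getType())); // true
        System.out.println(avg); // prints "avg(abv)", assuming the first constructor argument is the column name
    }
}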
@@ -1,22 +1,27 @@
package com.redis.trino;

import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;

import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.OptionalLong;
import java.util.concurrent.atomic.AtomicReference;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.redis.trino.RediSearchTableHandle.Type;

import io.airlift.log.Logger;
import io.airlift.slice.Slice;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.connector.AggregationApplicationResult;
import io.trino.spi.connector.Assignment;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ColumnMetadata;
import io.trino.spi.connector.ConnectorInsertTableHandle;
@@ -35,11 +40,17 @@
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.connector.SchemaTablePrefix;
import io.trino.spi.connector.TableNotFoundException;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Variable;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.statistics.ComputedStatistics;

public class RediSearchMetadata implements ConnectorMetadata {

    private static final Logger log = Logger.get(RediSearchMetadata.class);

    private static final String SYNTHETIC_COLUMN_NAME_PREFIX = "syntheticColumn";

    private final RediSearchSession rediSearchSession;
    private final String schemaName;
    private final AtomicReference<Runnable> rollbackAction = new AtomicReference<>();
@@ -199,12 +210,14 @@ public Optional<LimitApplicationResult<ConnectorTableHandle>> applyLimit(Connect
            return Optional.empty();
        }

        if (handle.getLimit().isPresent() && handle.getLimit().getAsInt() <= limit) {
        if (handle.getLimit().isPresent() && handle.getLimit().getAsLong() <= limit) {
            return Optional.empty();
        }

        return Optional.of(new LimitApplicationResult<>(new RediSearchTableHandle(handle.getSchemaTableName(),
                handle.getConstraint(), OptionalInt.of(toIntExact(limit))), true, false));
        return Optional.of(new LimitApplicationResult<>(
                new RediSearchTableHandle(handle.getType(), handle.getSchemaTableName(), handle.getConstraint(),
                        OptionalLong.of(limit), handle.getTermAggregations(), handle.getMetricAggregations()),
                true, false));
    }

    @Override
@@ -218,11 +231,57 @@ public Optional<ConstraintApplicationResult<ConnectorTableHandle>> applyFilter(C
            return Optional.empty();
        }

        handle = new RediSearchTableHandle(handle.getSchemaTableName(), newDomain, handle.getLimit());
        handle = new RediSearchTableHandle(handle.getType(), handle.getSchemaTableName(), newDomain, handle.getLimit(),
                handle.getTermAggregations(), handle.getMetricAggregations());

        return Optional.of(new ConstraintApplicationResult<>(handle, constraint.getSummary(), false));
    }

    @Override
    public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggregation(ConnectorSession session,
            ConnectorTableHandle handle, List<AggregateFunction> aggregates, Map<String, ColumnHandle> assignments,
            List<List<ColumnHandle>> groupingSets) {
        log.info("applyAggregation aggregates=%s groupingSets=%s", aggregates, groupingSets);
        RediSearchTableHandle table = (RediSearchTableHandle) handle;
        // Global aggregation is represented by [[]]
        verify(!groupingSets.isEmpty(), "No grouping sets provided");
        if (!table.getTermAggregations().isEmpty()) {
            return Optional.empty();
        }
        ImmutableList.Builder<ConnectorExpression> projections = ImmutableList.builder();
        ImmutableList.Builder<Assignment> resultAssignments = ImmutableList.builder();
        ImmutableList.Builder<MetricAggregation> metricAggregations = ImmutableList.builder();
        ImmutableList.Builder<TermAggregation> termAggregations = ImmutableList.builder();
        for (int i = 0; i < aggregates.size(); i++) {
            AggregateFunction function = aggregates.get(i);
            String colName = SYNTHETIC_COLUMN_NAME_PREFIX + i;
            Optional<MetricAggregation> metricAggregation = MetricAggregation.handleAggregation(function, assignments,
                    colName);
            if (metricAggregation.isEmpty()) {
                return Optional.empty();
            }
            RediSearchColumnHandle newColumn = new RediSearchColumnHandle(colName, function.getOutputType(), false);
            projections.add(new Variable(colName, function.getOutputType()));
            resultAssignments.add(new Assignment(colName, newColumn, function.getOutputType()));
            metricAggregations.add(metricAggregation.get());
        }
        for (ColumnHandle columnHandle : groupingSets.get(0)) {
            Optional<TermAggregation> termAggregation = TermAggregation.fromColumnHandle(columnHandle);
            if (termAggregation.isEmpty()) {
                return Optional.empty();
            }
            termAggregations.add(termAggregation.get());
        }
        ImmutableList<MetricAggregation> metrics = metricAggregations.build();
        if (metrics.isEmpty()) {
            return Optional.empty();
        }
        RediSearchTableHandle tableHandle = new RediSearchTableHandle(Type.AGGREGATE, table.getSchemaTableName(),
                table.getConstraint(), table.getLimit(), termAggregations.build(), metrics);
        return Optional.of(new AggregationApplicationResult<>(tableHandle, projections.build(),
                resultAssignments.build(), ImmutableMap.of(), false));
    }

    private void setRollback(Runnable action) {
        checkState(rollbackAction.compareAndSet(null, action), "rollback action is already set");
    }
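For illustration only (not part of the commit): for a global count(*) over a hypothetical beers table, applyAggregation above would hand back an AGGREGATE table handle shaped roughly as follows. The schema and table names are invented, and the constructor argument order is taken from the code above.

package com.redis.trino;

import static io.trino.spi.type.BigintType.BIGINT;

import java.util.List;
import java.util.Optional;
import java.util.OptionalLong;

import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.predicate.TupleDomain;

public class AggregateHandleExample {
    public static void main(String[] args) {
        // count(*) takes no input column, so the column handle is left empty.
        MetricAggregation count = new MetricAggregation(MetricAggregation.COUNT, BIGINT, Optional.empty(),
                "syntheticColumn0");
        // A global aggregation: no grouping terms, no pushed-down constraint or limit.
        RediSearchTableHandle aggregateHandle = new RediSearchTableHandle(RediSearchTableHandle.Type.AGGREGATE,
                new SchemaTableName("default", "beers"), TupleDomain.all(), OptionalLong.empty(), List.of(),
                List.of(count));
        System.out.println(aggregateHandle);
    }
}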
@@ -1,55 +1,29 @@
package com.redis.trino;

import static com.google.common.base.Verify.verify;
import static com.redis.trino.TypeUtils.isJsonType;
import static io.airlift.slice.Slices.utf8Slice;
import static io.trino.plugin.base.util.JsonTypeUtil.jsonParse;
import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.Chars.truncateToLengthAndTrimSpaces;
import static io.trino.spi.type.DateTimeEncoding.packDateTimeWithZone;
import static io.trino.spi.type.DateType.DATE;
import static io.trino.spi.type.Decimals.encodeScaledValue;
import static io.trino.spi.type.Decimals.encodeShortScaledValue;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.RealType.REAL;
import static io.trino.spi.type.SmallintType.SMALLINT;
import static io.trino.spi.type.TimeZoneKey.UTC_KEY;
import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS;
import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS;
import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_MILLISECOND;
import static io.trino.spi.type.TinyintType.TINYINT;
import static java.lang.Float.floatToIntBits;
import static java.util.stream.Collectors.toList;

import java.io.IOException;
import java.io.OutputStream;
import java.math.BigDecimal;
import java.time.LocalDate;
import java.time.format.DateTimeFormatter;
import java.util.Iterator;
import java.util.List;

import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonGenerator;
import com.google.common.primitives.SignedBytes;
import com.redis.lettucemod.search.Document;

import io.airlift.slice.Slice;
import io.airlift.slice.SliceOutput;
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.TrinoException;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.connector.ConnectorPageSource;
import io.trino.spi.type.CharType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;

public class RediSearchPageSource implements ConnectorPageSource {

    private static final int ROWS_PER_REQUEST = 1024;

    private final RediSearchPageSourceResultWriter writer = new RediSearchPageSourceResultWriter();
    private final Iterator<Document<String, String>> cursor;
    private final List<String> columnNames;
    private final List<Type> columnTypes;
@@ -63,7 +37,7 @@ public RediSearchPageSource(RediSearchSession rediSearchSession, RediSearchTable
            List<RediSearchColumnHandle> columns) {
        this.columnNames = columns.stream().map(RediSearchColumnHandle::getName).collect(toList());
        this.columnTypes = columns.stream().map(RediSearchColumnHandle::getType).collect(toList());
        this.cursor = rediSearchSession.execute(tableHandle).iterator();
        this.cursor = rediSearchSession.search(tableHandle).iterator();
        this.currentDoc = null;
        this.pageBuilder = new PageBuilder(columnTypes);
    }
@@ -103,7 +77,12 @@ public Page getNextPage() {
            pageBuilder.declarePosition();
            for (int column = 0; column < columnTypes.size(); column++) {
                BlockBuilder output = pageBuilder.getBlockBuilder(column);
                appendTo(columnTypes.get(column), currentDoc.get(columnNames.get(column)), output);
                String value = currentDoc.get(columnNames.get(column));
                if (value == null) {
                    output.appendNull();
                } else {
                    writer.appendTo(columnTypes.get(column), value, output);
                }
            }
        }

@@ -112,67 +91,12 @@ public Page getNextPage() {
        return page;
    }

    private void appendTo(Type type, String value, BlockBuilder output) {
        if (value == null) {
            output.appendNull();
            return;
        }
        Class<?> javaType = type.getJavaType();
        if (javaType == boolean.class) {
            type.writeBoolean(output, Boolean.parseBoolean(value));
        } else if (javaType == long.class) {
            if (type.equals(BIGINT)) {
                type.writeLong(output, Long.parseLong(value));
            } else if (type.equals(INTEGER)) {
                type.writeLong(output, Integer.parseInt(value));
            } else if (type.equals(SMALLINT)) {
                type.writeLong(output, Short.parseShort(value));
            } else if (type.equals(TINYINT)) {
                type.writeLong(output, SignedBytes.checkedCast(Long.parseLong(value)));
            } else if (type.equals(REAL)) {
                type.writeLong(output, floatToIntBits((Float.parseFloat(value))));
            } else if (type instanceof DecimalType) {
                type.writeLong(output, encodeShortScaledValue(new BigDecimal(value), ((DecimalType) type).getScale()));
            } else if (type.equals(DATE)) {
                type.writeLong(output, LocalDate.from(DateTimeFormatter.ISO_DATE.parse(value)).toEpochDay());
            } else if (type.equals(TIMESTAMP_MILLIS)) {
                type.writeLong(output, Long.parseLong(value) * MICROSECONDS_PER_MILLISECOND);
            } else if (type.equals(TIMESTAMP_TZ_MILLIS)) {
                type.writeLong(output, packDateTimeWithZone(Long.parseLong(value), UTC_KEY));
            } else {
                throw new TrinoException(GENERIC_INTERNAL_ERROR,
                        "Unhandled type for " + javaType.getSimpleName() + ":" + type.getTypeSignature());
            }
        } else if (javaType == double.class) {
            type.writeDouble(output, Double.parseDouble(value));
        } else if (javaType == Slice.class) {
            writeSlice(output, type, value);
        } else {
            throw new TrinoException(GENERIC_INTERNAL_ERROR,
                    "Unhandled type for " + javaType.getSimpleName() + ":" + type.getTypeSignature());
        }
    }

    private void writeSlice(BlockBuilder output, Type type, String value) {
        if (type instanceof VarcharType) {
            type.writeSlice(output, utf8Slice(value));
        } else if (type instanceof CharType) {
            type.writeSlice(output, truncateToLengthAndTrimSpaces(utf8Slice(value), ((CharType) type)));
        } else if (type instanceof DecimalType) {
            type.writeObject(output, encodeScaledValue(new BigDecimal(value), ((DecimalType) type).getScale()));
        } else if (isJsonType(type)) {
            type.writeSlice(output, jsonParse(utf8Slice(value)));
        } else {
            throw new TrinoException(GENERIC_INTERNAL_ERROR, "Unhandled type for Slice: " + type.getTypeSignature());
        }
    }

    public static JsonGenerator createJsonGenerator(JsonFactory factory, SliceOutput output) throws IOException {
        return factory.createGenerator((OutputStream) output);
    }

    @Override
    public void close() {

        // nothing to do
    }
}
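The conversion logic removed above is now reached through the new RediSearchPageSourceResultWriter (one of the changed files not shown on this page). Based only on the call site writer.appendTo(type, value, output), a minimal sketch of that class might look like this; the actual implementation in the commit may differ and would carry over the remaining branches of the removed appendTo/writeSlice methods.

package com.redis.trino;

import static io.airlift.slice.Slices.utf8Slice;
import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.SmallintType.SMALLINT;
import static io.trino.spi.type.TinyintType.TINYINT;

import io.trino.spi.TrinoException;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;

public class RediSearchPageSourceResultWriter {

    // Sketch only: convert the string value RediSearch returns into a Trino block position.
    public void appendTo(Type type, String value, BlockBuilder output) {
        if (type.equals(BIGINT) || type.equals(INTEGER) || type.equals(SMALLINT) || type.equals(TINYINT)) {
            type.writeLong(output, Long.parseLong(value));
        } else if (type.equals(DOUBLE)) {
            type.writeDouble(output, Double.parseDouble(value));
        } else if (type instanceof VarcharType) {
            type.writeSlice(output, utf8Slice(value));
        } else {
            // REAL, DATE, timestamps, decimals, CHAR and JSON would mirror the
            // appendTo/writeSlice branches removed from RediSearchPageSource above.
            throw new TrinoException(GENERIC_INTERNAL_ERROR, "Unhandled type: " + type.getTypeSignature());
        }
    }
}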