diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/Accumulator.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/Accumulator.java index 8a126fb5603..97d2f933e77 100644 --- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/Accumulator.java +++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/Accumulator.java @@ -18,6 +18,7 @@ package org.apache.ignite.internal.sql.engine.exec.exp.agg; import java.util.List; +import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.type.RelDataType; import org.apache.ignite.internal.sql.engine.type.IgniteTypeFactory; @@ -55,6 +56,8 @@ public interface Accumulator { * * @param typeFactory Type factory. * @return A result type. + * @deprecated Use {@link AggregateCall#getType()} instead. */ + @Deprecated RelDataType returnType(IgniteTypeFactory typeFactory); } diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/Accumulators.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/Accumulators.java index ad639348891..a3ebecfeebf 100644 --- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/Accumulators.java +++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/Accumulators.java @@ -35,6 +35,8 @@ import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.ignite.internal.catalog.commands.CatalogUtils; +import org.apache.ignite.internal.sql.engine.exec.exp.IgniteSqlFunctions; import org.apache.ignite.internal.sql.engine.type.IgniteCustomType; import org.apache.ignite.internal.sql.engine.type.IgniteTypeFactory; import org.apache.ignite.internal.sql.engine.util.IgniteMath; @@ -98,10 +100,10 @@ private Supplier avgFactory(AggregateCall call) { case SMALLINT: case INTEGER: case BIGINT: - return () -> DecimalAvg.FACTORY.apply(0); + return () -> new DecimalAvg(call.type.getPrecision(), call.type.getScale()); case DECIMAL: // TODO: https://issues.apache.org/jira/browse/IGNITE-17373 Add support for interval types. - return () -> DecimalAvg.FACTORY.apply(call.type.getScale()); + return () -> new DecimalAvg(call.type.getPrecision(), call.type.getScale()); case DOUBLE: case REAL: case FLOAT: @@ -117,8 +119,9 @@ private Supplier avgFactory(AggregateCall call) { private Supplier sumFactory(AggregateCall call) { switch (call.type.getSqlTypeName()) { case BIGINT: + return () -> new Sum(new LongSumEmptyIsZero()); case DECIMAL: - return () -> new Sum(new DecimalSumEmptyIsZero()); + return () -> new Sum(new DecimalSumEmptyIsZero(call.type.getPrecision(), call.type.getScale())); case DOUBLE: case REAL: @@ -143,7 +146,7 @@ private Supplier sumEmptyIsZeroFactory(AggregateCall call) { // Used by REDUCE phase of COUNT aggregate. return LongSumEmptyIsZero.FACTORY; case DECIMAL: - return DecimalSumEmptyIsZero.FACTORY; + return () -> new DecimalSumEmptyIsZero(call.type.getPrecision(), call.type.getScale()); case DOUBLE: case REAL: @@ -333,11 +336,17 @@ public static class DecimalAvgState { private final int scale; + @Deprecated DecimalAvg(int scale) { this.precision = RelDataType.PRECISION_NOT_SPECIFIED; this.scale = scale; } + public DecimalAvg(int precision, int scale) { + this.precision = precision; + this.scale = scale; + } + /** {@inheritDoc} */ @Override public void add(AccumulatorsState state, Object... args) { @@ -431,7 +440,6 @@ public void end(AccumulatorsState state, AccumulatorsState result) { result.set(null); } } - } /** {@inheritDoc} */ @@ -625,7 +633,22 @@ public RelDataType returnType(IgniteTypeFactory typeFactory) { /** SUM(DECIMAL) accumulator. */ public static class DecimalSumEmptyIsZero implements Accumulator { - public static final Supplier FACTORY = DecimalSumEmptyIsZero::new; + public static final IntFunction FACTORY = DecimalSumEmptyIsZero::new; + + private final int precision; + + private final int scale; + + @Deprecated + private DecimalSumEmptyIsZero(int scale) { + this.precision = CatalogUtils.MAX_DECIMAL_PRECISION; + this.scale = scale; + } + + public DecimalSumEmptyIsZero(int precision, int scale) { + this.precision = precision; + this.scale = scale; + } /** {@inheritDoc} */ @Override @@ -650,7 +673,8 @@ public void end(AccumulatorsState state, AccumulatorsState result) { if (!state.hasValue()) { result.set(BigDecimal.ZERO); } else { - result.set(state.get()); + BigDecimal value = (BigDecimal) state.get(); + result.set(IgniteSqlFunctions.toBigDecimal(value, precision, scale)); } } @@ -663,7 +687,7 @@ public List argumentTypes(IgniteTypeFactory typeFactory) { /** {@inheritDoc} */ @Override public RelDataType returnType(IgniteTypeFactory typeFactory) { - return typeFactory.createTypeWithNullability(typeFactory.createSqlType(DECIMAL), false); + return typeFactory.createTypeWithNullability(typeFactory.createSqlType(DECIMAL, precision, scale), false); } } diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/AccumulatorsFactory.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/AccumulatorsFactory.java index 848c6d6a097..f8e2c2c0773 100644 --- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/AccumulatorsFactory.java +++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/AccumulatorsFactory.java @@ -195,7 +195,7 @@ private Accumulator accumulator() { Accumulator accumulator = accFactory.get(); inAdapter = createInAdapter(accumulator); - outAdapter = createOutAdapter(accumulator); + outAdapter = Function.identity(); return accumulator; } @@ -235,11 +235,9 @@ private Function createOutAdapter(Accumulator accumulator) { if (type == AggregateType.MAP) { return Function.identity(); } - - RelDataType inType = accumulator.returnType(ctx.getTypeFactory()); RelDataType outType = call.getType(); - return cast(inType, outType); + return cast(outType, outType); } private RelDataType nonNull(RelDataType type) { diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/AccumulatorsState.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/AccumulatorsState.java index d7fa4729080..ebd4f9487aa 100644 --- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/AccumulatorsState.java +++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/AccumulatorsState.java @@ -17,7 +17,11 @@ package org.apache.ignite.internal.sql.engine.exec.exp.agg; +import java.util.Arrays; import java.util.BitSet; +import java.util.List; +import java.util.Objects; +import org.apache.ignite.internal.tostring.S; import org.jetbrains.annotations.Nullable; /** @@ -36,6 +40,12 @@ public AccumulatorsState(int rowSize) { this.row = new Object[rowSize]; } + /** Creates a copy from the given state. */ + public AccumulatorsState(AccumulatorsState src) { + this.row = new Object[src.row.length]; + System.arraycopy(src.row, 0, this.row, 0, src.row.length); + } + /** Sets current field index. */ public void setIndex(int i) { this.index = i; @@ -61,4 +71,39 @@ public void set(@Nullable Object value) { public boolean hasValue() { return set.get(index); } + + /** The number of elements. */ + public int size() { + return row.length; + } + + /** Elements of this state as list. */ + public List toList() { + return Arrays.asList(row); + } + + /** {@inheritDoc} */ + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + AccumulatorsState state = (AccumulatorsState) o; + return Objects.deepEquals(row, state.row); + } + + /** {@inheritDoc} */ + @Override + public int hashCode() { + return Arrays.hashCode(row); + } + + /** {@inheritDoc} */ + @Override + public String toString() { + return S.toString(this); + } } diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/PlanUtils.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/PlanUtils.java index 90fe27e6146..d1b84c285d1 100644 --- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/PlanUtils.java +++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/PlanUtils.java @@ -30,8 +30,6 @@ import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.util.ImmutableBitSet; import org.apache.calcite.util.mapping.Mapping; -import org.apache.ignite.internal.sql.engine.exec.exp.agg.Accumulator; -import org.apache.ignite.internal.sql.engine.exec.exp.agg.Accumulators; import org.apache.ignite.internal.sql.engine.type.IgniteTypeFactory; /** @@ -80,7 +78,7 @@ public static RelDataType createSortAggRowType(ImmutableBitSet grpKeys, builder.add(fld); } - addAccumulatorFields(typeFactory, aggregateCalls, builder); + addAccumulatorFields(aggregateCalls, builder); return builder.build(); } @@ -115,32 +113,17 @@ public static RelDataType createHashAggRowType(List groupSets, builder.add(fld); } - addAccumulatorFields(typeFactory, aggregateCalls, builder); + addAccumulatorFields(aggregateCalls, builder); builder.add("_GROUP_ID", SqlTypeName.TINYINT); return builder.build(); } - private static void addAccumulatorFields(IgniteTypeFactory typeFactory, List aggregateCalls, Builder builder) { - Accumulators accumulators = new Accumulators(typeFactory); - + private static void addAccumulatorFields(List aggregateCalls, Builder builder) { for (int i = 0; i < aggregateCalls.size(); i++) { AggregateCall call = aggregateCalls.get(i); - Accumulator acc = accumulators.accumulatorFactory(call).get(); - RelDataType fieldType; - // For a decimal type Accumulator::returnType returns a type with default precision and scale, - // that can cause precision loss when a tuple is sent over the wire by an exchanger/outbox. - // Outbox uses its input type as wire format, so if a scale is 0, then the scale is lost - // (see Outbox::sendBatch -> RowHandler::toBinaryTuple -> BinaryTupleBuilder::appendDecimalNotNull). - if (call.getType().getSqlTypeName().allowsScale()) { - fieldType = call.type; - } else { - fieldType = acc.returnType(typeFactory); - } - String fieldName = "_ACC" + i; - - builder.add(fieldName, fieldType); + builder.add("_ACC" + i, call.type); } } } diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/TypeUtils.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/TypeUtils.java index 760e7c932b6..05ecadcb8c7 100644 --- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/TypeUtils.java +++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/TypeUtils.java @@ -341,6 +341,7 @@ public static ColumnType columnType(RelDataType type) { return ColumnType.FLOAT; case BINARY: case VARBINARY: + return ColumnType.BYTE_ARRAY; case ANY: if (type instanceof IgniteCustomType) { IgniteCustomType customType = (IgniteCustomType) type; @@ -652,14 +653,15 @@ public static RowSchema rowSchemaFromRelTypes(List types) { RowSchema.Builder fieldTypes = RowSchema.builder(); for (RelDataType relType : types) { - TypeSpec typeSpec = convertToTypeSpec(relType); + TypeSpec typeSpec = relational2rowSchemaType(relType); fieldTypes.addField(typeSpec); } return fieldTypes.build(); } - private static TypeSpec convertToTypeSpec(RelDataType type) { + /** Converts the given relational data type to its {@link TypeSpec runtime type}. */ + public static TypeSpec relational2rowSchemaType(RelDataType type) { boolean simpleType = type instanceof BasicSqlType; boolean nullable = type.isNullable(); @@ -689,7 +691,7 @@ private static TypeSpec convertToTypeSpec(RelDataType type) { List fields = new ArrayList<>(); for (RelDataTypeField field : type.getFieldList()) { - TypeSpec fieldTypeSpec = convertToTypeSpec(field.getType()); + TypeSpec fieldTypeSpec = relational2rowSchemaType(field.getType()); fields.add(fieldTypeSpec); } @@ -811,6 +813,20 @@ public static boolean typesRepresentTheSameColumnTypes(RelDataType lhs, RelDataT } } + /** + * Converts relational type to a natively supported type if it is possible. + * This method returns {@code null} when the given relational type is {@link SqlTypeName#NULL}. + */ + public static @Nullable NativeType relational2nativeType(RelDataType relDataType) { + TypeSpec typeSpec = relational2rowSchemaType(relDataType); + if (typeSpec == RowSchemaTypes.NULL) { + return null; + } else { + BaseTypeSpec baseTypeSpec = (BaseTypeSpec) typeSpec; + return baseTypeSpec.nativeType(); + } + } + private static boolean isCustomType(RelDataType type) { return type instanceof IgniteCustomType; } diff --git a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/SumAccumulatorTest.java b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/SumAccumulatorTest.java index a9a87b1cafc..01a7b49930a 100644 --- a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/SumAccumulatorTest.java +++ b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/SumAccumulatorTest.java @@ -54,8 +54,19 @@ private static Stream testArgs() { return Stream.of( Arguments.of(namedAccumulator(DoubleSumEmptyIsZero.FACTORY), 4.0d, new Object[]{3.0d, 1.0d}), Arguments.of(namedAccumulator(LongSumEmptyIsZero.FACTORY), 4L, new Object[]{3L, 1L}), - Arguments.of(namedAccumulator(DecimalSumEmptyIsZero.FACTORY), new BigDecimal("3.4"), - new Object[]{new BigDecimal("1.3"), new BigDecimal("2.1")}) + + Arguments.of(namedAccumulator(() -> DecimalSumEmptyIsZero.FACTORY.apply(1)), new BigDecimal("3.4"), + new Object[]{new BigDecimal("1.3"), new BigDecimal("2.1")}), + + Arguments.of(namedAccumulator(() -> DecimalSumEmptyIsZero.FACTORY.apply(1)), new BigDecimal("3.4"), + new Object[]{new BigDecimal("1.31"), new BigDecimal("2.13")}), + Arguments.of(namedAccumulator(() -> DecimalSumEmptyIsZero.FACTORY.apply(1)), new BigDecimal("3.5"), + new Object[]{new BigDecimal("1.32"), new BigDecimal("2.13")}), + + Arguments.of(namedAccumulator(() -> DecimalSumEmptyIsZero.FACTORY.apply(2)), new BigDecimal("3.44"), + new Object[]{new BigDecimal("1.31"), new BigDecimal("2.13")}), + Arguments.of(namedAccumulator(() -> DecimalSumEmptyIsZero.FACTORY.apply(2)), new BigDecimal("3.45"), + new Object[]{new BigDecimal("1.32"), new BigDecimal("2.13")}) ); } @@ -71,7 +82,7 @@ private static Stream emptyArgs() { return Stream.of( Arguments.of(namedAccumulator(DoubleSumEmptyIsZero.FACTORY)), Arguments.of(namedAccumulator(LongSumEmptyIsZero.FACTORY)), - Arguments.of(namedAccumulator(DecimalSumEmptyIsZero.FACTORY)) + Arguments.of(namedAccumulator(() -> DecimalSumEmptyIsZero.FACTORY.apply(1))) ); } @@ -81,6 +92,6 @@ private StatefulAccumulator newCall(Accumulator accumulator) { private static Named namedAccumulator(Supplier supplier) { Accumulator accumulator = supplier.get(); - return Named.of(accumulator.getClass().getName(), accumulator); + return Named.of(accumulator.getClass().getSimpleName(), accumulator); } } diff --git a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/SumIsZeroAccumulatorTest.java b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/SumIsZeroAccumulatorTest.java index f27db0c851c..b08a3a521bd 100644 --- a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/SumIsZeroAccumulatorTest.java +++ b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/SumIsZeroAccumulatorTest.java @@ -54,8 +54,19 @@ private static Stream testArgs() { return Stream.of( Arguments.of(namedAccumulator(DoubleSumEmptyIsZero.FACTORY), 4.0d, new Object[]{3.0d, 1.0d}), Arguments.of(namedAccumulator(LongSumEmptyIsZero.FACTORY), 4L, new Object[]{3L, 1L}), - Arguments.of(namedAccumulator(DecimalSumEmptyIsZero.FACTORY), new BigDecimal("3.4"), - new Object[]{new BigDecimal("1.3"), new BigDecimal("2.1")}) + + Arguments.of(namedAccumulator(() -> DecimalSumEmptyIsZero.FACTORY.apply(1)), new BigDecimal("3.4"), + new Object[]{new BigDecimal("1.3"), new BigDecimal("2.1")}), + + Arguments.of(namedAccumulator(() -> DecimalSumEmptyIsZero.FACTORY.apply(1)), new BigDecimal("3.4"), + new Object[]{new BigDecimal("1.31"), new BigDecimal("2.13")}), + Arguments.of(namedAccumulator(() -> DecimalSumEmptyIsZero.FACTORY.apply(1)), new BigDecimal("3.5"), + new Object[]{new BigDecimal("1.32"), new BigDecimal("2.13")}), + + Arguments.of(namedAccumulator(() -> DecimalSumEmptyIsZero.FACTORY.apply(2)), new BigDecimal("3.44"), + new Object[]{new BigDecimal("1.31"), new BigDecimal("2.13")}), + Arguments.of(namedAccumulator(() -> DecimalSumEmptyIsZero.FACTORY.apply(2)), new BigDecimal("3.45"), + new Object[]{new BigDecimal("1.32"), new BigDecimal("2.13")}) ); } @@ -71,7 +82,7 @@ private static Stream zeroArgs() { return Stream.of( Arguments.of(namedAccumulator(DoubleSumEmptyIsZero.FACTORY), 0.0d), Arguments.of(namedAccumulator(LongSumEmptyIsZero.FACTORY), 0L), - Arguments.of(namedAccumulator(DecimalSumEmptyIsZero.FACTORY), BigDecimal.ZERO) + Arguments.of(namedAccumulator(() -> DecimalSumEmptyIsZero.FACTORY.apply(1)), BigDecimal.ZERO) ); } @@ -81,6 +92,6 @@ private static StatefulAccumulator newCall(Accumulator sum) { private static Named namedAccumulator(Supplier supplier) { Accumulator accumulator = supplier.get(); - return Named.of(accumulator.getClass().getName(), accumulator); + return Named.of(accumulator.getClass().getSimpleName(), accumulator); } } diff --git a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/BaseAggregateTest.java b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/BaseAggregateTest.java index 8455bd6ecfc..b4ba83d2070 100644 --- a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/BaseAggregateTest.java +++ b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/BaseAggregateTest.java @@ -38,6 +38,7 @@ import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.util.ImmutableBitSet; import org.apache.calcite.util.ImmutableIntList; import org.apache.ignite.internal.sql.engine.exec.ExecutionContext; @@ -295,7 +296,7 @@ public void avg(TestAggregateType testAgg) { -1, null, RelCollations.EMPTY, - tf.createJavaType(int.class), + tf.createSqlType(SqlTypeName.INTEGER), null); List grpSets = List.of(ImmutableBitSet.of(0)); @@ -583,7 +584,7 @@ public void sumIntegerOverflow(TestAggregateType testAgg) { -1, null, RelCollations.EMPTY, - tf.createJavaType(Long.class), + tf.createSqlType(SqlTypeName.BIGINT), null); List grpSets = List.of(ImmutableBitSet.of(0)); @@ -632,7 +633,7 @@ public void sumLongOverflow(TestAggregateType testAgg) { -1, null, RelCollations.EMPTY, - tf.createJavaType(BigDecimal.class), + tf.createSqlType(SqlTypeName.DECIMAL), null); List grpSets = List.of(ImmutableBitSet.of(0)); diff --git a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/type/IgniteTypeSystemTest.java b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/type/IgniteTypeSystemTest.java index 47c469619ff..cbee1903363 100644 --- a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/type/IgniteTypeSystemTest.java +++ b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/type/IgniteTypeSystemTest.java @@ -17,22 +17,44 @@ package org.apache.ignite.internal.sql.engine.type; +import static java.util.UUID.randomUUID; import static org.apache.ignite.internal.sql.engine.type.IgniteTypeSystem.MIN_SCALE_OF_AVG_RESULT; import static org.apache.ignite.internal.sql.engine.util.TypeUtils.native2relationalType; import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import it.unimi.dsi.fastutil.longs.Long2ObjectMaps; +import java.util.List; import java.util.stream.Stream; +import org.apache.calcite.rel.RelCollations; +import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.type.SqlTypeUtil; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.ImmutableIntList; +import org.apache.ignite.internal.network.ClusterNodeImpl; +import org.apache.ignite.internal.sql.engine.exec.ExecutionContext; +import org.apache.ignite.internal.sql.engine.exec.QueryTaskExecutor; +import org.apache.ignite.internal.sql.engine.exec.exp.agg.AccumulatorWrapper; +import org.apache.ignite.internal.sql.engine.exec.exp.agg.AccumulatorsFactory; +import org.apache.ignite.internal.sql.engine.exec.exp.agg.AggregateType; +import org.apache.ignite.internal.sql.engine.exec.mapping.FragmentDescription; +import org.apache.ignite.internal.sql.engine.framework.TestBuilders; import org.apache.ignite.internal.sql.engine.util.Commons; import org.apache.ignite.internal.testframework.BaseIgniteAbstractTest; import org.apache.ignite.internal.type.NativeTypes; +import org.apache.ignite.network.NetworkAddress; import org.hamcrest.Matchers; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; /** Tests for {@link IgniteTypeSystem}. */ public class IgniteTypeSystemTest extends BaseIgniteAbstractTest { @@ -106,6 +128,16 @@ void deriveAvgType(RelDataType argument, RelDataType expected) { RelDataType actual = typeSystem.deriveAvgAggType(Commons.typeFactory(), argument); assertThat(actual, Matchers.equalTo(expected)); + + checkAccReturnType(SqlStdOperatorTable.AVG, argument, expected); + } + + @ParameterizedTest + @MethodSource("deriveAvgTypeArguments") + void deriveSumType(RelDataType argument, RelDataType expected) { + RelDataType actual = typeSystem.deriveSumType(Commons.typeFactory(), argument); + + checkAccReturnType(SqlStdOperatorTable.SUM, argument, actual); } private static Stream deriveAvgTypeArguments() { @@ -270,4 +302,58 @@ private static Stream deriveDivideDecimalArgs() { ) ); } + + private void checkAccReturnType( + SqlAggFunction aggFunction, + RelDataType argument, + RelDataType expected + ) { + IgniteTypeFactory typeFactory = new IgniteTypeFactory(typeSystem); + + AggregateCall aggregateCall = createAggregateCall(aggFunction, 0, expected); + + RelDataType inputRowType = new RelDataTypeFactory.Builder(typeFactory) + .add("F1", argument) + .build(); + + ExecutionContext ctx = TestBuilders.executionContext() + .queryId(randomUUID()) + .localNode(new ClusterNodeImpl(randomUUID(), "node-1", new NetworkAddress("localhost", 1234))) + .fragment(new FragmentDescription(1, true, Long2ObjectMaps.emptyMap(), null, null, null)) + .executor(Mockito.mock(QueryTaskExecutor.class)) + .build(); + + AccumulatorsFactory accumulatorsFactory = new AccumulatorsFactory<>( + ctx, + AggregateType.SINGLE, + List.of(aggregateCall), + inputRowType + ); + AccumulatorWrapper accumulatorWrappers = accumulatorsFactory.get().get(0); + + RelDataType accRetType = accumulatorWrappers.accumulator().returnType(typeFactory); + + assertTrue(SqlTypeUtil.equalSansNullability(accRetType, expected), + "Expected: " + expected.getFullTypeString() + + "\nActual:" + accRetType.getFullTypeString()); + } + + private static AggregateCall createAggregateCall( + SqlAggFunction aggFunction, + int arg, + RelDataType outputType + ) { + return AggregateCall.create( + aggFunction, + false, + false, + false, + List.of(), + ImmutableIntList.of(arg), + -1, + ImmutableBitSet.of(), + RelCollations.EMPTY, + outputType, + null); + } } diff --git a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/util/PlanUtilsTest.java b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/util/PlanUtilsTest.java index c35dbaa40b3..ff0cab1c2a6 100644 --- a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/util/PlanUtilsTest.java +++ b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/util/PlanUtilsTest.java @@ -29,8 +29,6 @@ import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.util.ImmutableBitSet; -import org.apache.ignite.internal.sql.engine.exec.exp.agg.Accumulator; -import org.apache.ignite.internal.sql.engine.exec.exp.agg.Accumulators; import org.apache.ignite.internal.sql.engine.type.IgniteTypeFactory; import org.junit.jupiter.api.Test; @@ -41,8 +39,6 @@ public class PlanUtilsTest { private final IgniteTypeFactory typeFactory = Commons.typeFactory(); - private final Accumulators accumulators = new Accumulators(typeFactory); - @Test public void testHashAggRowType() { RelDataType inputType = new RelDataTypeFactory.Builder(typeFactory) @@ -53,13 +49,12 @@ public void testHashAggRowType() { .build(); AggregateCall call1 = newCall(typeFactory.createSqlType(SqlTypeName.BIGINT)); - Accumulator acc1 = accumulators.accumulatorFactory(call1).get(); RelDataType expectedType = new RelDataTypeFactory.Builder(typeFactory) .add("f1", typeFactory.createSqlType(SqlTypeName.INTEGER)) .add("f2", typeFactory.createSqlType(SqlTypeName.VARCHAR)) .add("f4", typeFactory.createSqlType(SqlTypeName.VARBINARY)) - .add("_ACC0", acc1.returnType(typeFactory)) + .add("_ACC0", call1.getType()) .add("_GROUP_ID", typeFactory.createSqlType(SqlTypeName.TINYINT)) .build(); @@ -83,13 +78,12 @@ public void testSortAggRowType() { .build(); AggregateCall call1 = newCall(typeFactory.createSqlType(SqlTypeName.BIGINT)); - Accumulator acc1 = accumulators.accumulatorFactory(call1).get(); RelDataType expectedType = new RelDataTypeFactory.Builder(typeFactory) .add("f1", typeFactory.createSqlType(SqlTypeName.INTEGER)) .add("f2", typeFactory.createSqlType(SqlTypeName.VARCHAR)) .add("f4", typeFactory.createSqlType(SqlTypeName.VARBINARY)) - .add("_ACC0", acc1.returnType(typeFactory)) + .add("_ACC0", call1.getType()) .build(); ImmutableBitSet group = ImmutableBitSet.of(0, 1, 3);