Skip to content

Commit

Permalink
Add overflow detection to IntervalYearMonthOperators
Browse files Browse the repository at this point in the history
  • Loading branch information
tdcmeehan authored and Pramod Satya committed Nov 24, 2024
1 parent 499fd5a commit df7cdaa
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.type.AbstractIntType;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.BlockIndex;
import com.facebook.presto.spi.function.BlockPosition;
import com.facebook.presto.spi.function.IsNull;
Expand All @@ -42,6 +43,7 @@
import static com.facebook.presto.common.function.OperatorType.NEGATION;
import static com.facebook.presto.common.function.OperatorType.NOT_EQUAL;
import static com.facebook.presto.common.function.OperatorType.SUBTRACT;
import static com.facebook.presto.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE;
import static com.facebook.presto.type.IntervalYearMonthType.INTERVAL_YEAR_MONTH;
import static io.airlift.slice.Slices.utf8Slice;
import static java.lang.Math.toIntExact;
Expand All @@ -56,49 +58,81 @@ private IntervalYearMonthOperators()
@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH)
public static long add(@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long left, @SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long right)
{
return left + right;
try {
return Math.addExact((int) left, (int) right);
}
catch (ArithmeticException e) {
throw new PrestoException(NUMERIC_VALUE_OUT_OF_RANGE, "Overflow adding interval year-month values: " + left + " + " + right);
}
}

@ScalarOperator(SUBTRACT)
@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH)
public static long subtract(@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long left, @SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long right)
{
return left - right;
try {
return Math.subtractExact((int) left, (int) right);
}
catch (ArithmeticException e) {
throw new PrestoException(NUMERIC_VALUE_OUT_OF_RANGE, "Overflow subtracting interval year-month values: " + left + " - " + right);
}
}

@ScalarOperator(MULTIPLY)
@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH)
public static long multiplyByBigint(@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long left, @SqlType(StandardTypes.BIGINT) long right)
public static long multiplyByInteger(@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long left, @SqlType(StandardTypes.INTEGER) long right)
{
return left * right;
try {
return Math.multiplyExact((int) left, (int) right);
}
catch (ArithmeticException e) {
throw new PrestoException(NUMERIC_VALUE_OUT_OF_RANGE, "Overflow multiplying interval year-month value by bigint: " + left + " * " + right);
}
}

@ScalarOperator(MULTIPLY)
@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH)
public static long multiplyByDouble(@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long left, @SqlType(StandardTypes.DOUBLE) double right)
{
return (long) (left * right);
long result = (long) (left * right);
if (result < Integer.MIN_VALUE || result > Integer.MAX_VALUE) {
throw new PrestoException(NUMERIC_VALUE_OUT_OF_RANGE, "Overflow multiplying interval year-month value by double: " + left + " * " + right);
}
return result;
}

@ScalarOperator(MULTIPLY)
@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH)
public static long bigintMultiply(@SqlType(StandardTypes.BIGINT) long left, @SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long right)
public static long integerMultiply(@SqlType(StandardTypes.INTEGER) long left, @SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long right)
{
return left * right;
try {
return Math.multiplyExact((int) left, (int) right);
}
catch (ArithmeticException e) {
throw new PrestoException(NUMERIC_VALUE_OUT_OF_RANGE, "Overflow multiplying bigint by interval year-month value: " + left + " * " + right);
}
}

@ScalarOperator(MULTIPLY)
@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH)
public static long doubleMultiply(@SqlType(StandardTypes.DOUBLE) double left, @SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long right)
{
return (long) (left * right);
long result = (long) (left * right);
if (result < Integer.MIN_VALUE || result > Integer.MAX_VALUE) {
throw new PrestoException(NUMERIC_VALUE_OUT_OF_RANGE, "Overflow multiplying double by interval year-month value: " + left + " * " + right);
}
return result;
}

@ScalarOperator(DIVIDE)
@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH)
public static long divideByDouble(@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long left, @SqlType(StandardTypes.DOUBLE) double right)
{
return (long) (left / right);
long result = (long) (left / right);
if (result < Integer.MIN_VALUE || result > Integer.MAX_VALUE) {
throw new PrestoException(NUMERIC_VALUE_OUT_OF_RANGE, "Overflow dividing interval year-month value by double: " + left + " / " + right);
}
return result;
}

@ScalarOperator(NEGATION)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public class TestIntervalYearMonth
extends AbstractTestFunctions
{
private static final int MAX_SHORT = Short.MAX_VALUE;
private static final long MAX_INT_PLUS_1 = Integer.MAX_VALUE + 1L;

@Test
public void testObject()
Expand Down Expand Up @@ -74,6 +75,7 @@ public void testInvalidLiteral()
assertInvalidFunction("INTERVAL '124-X' YEAR TO MONTH", "Invalid INTERVAL YEAR TO MONTH value: 124-X");
assertInvalidFunction("INTERVAL '124--30' YEAR TO MONTH", "Invalid INTERVAL YEAR TO MONTH value: 124--30");
assertInvalidFunction("INTERVAL '--124--30' YEAR TO MONTH", "Invalid INTERVAL YEAR TO MONTH value: --124--30");
assertInvalidFunction(format("INTERVAL '%s' MONTH", MAX_INT_PLUS_1), "Invalid INTERVAL MONTH value: " + MAX_INT_PLUS_1);
}

@Test
Expand All @@ -82,6 +84,7 @@ public void testAdd()
assertFunction("INTERVAL '3' MONTH + INTERVAL '3' MONTH", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(6));
assertFunction("INTERVAL '6' YEAR + INTERVAL '6' YEAR", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(12 * 12));
assertFunction("INTERVAL '3' MONTH + INTERVAL '6' YEAR", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth((6 * 12) + (3)));
assertNumericOverflow(format("INTERVAL '%s' MONTH + INTERVAL '1' MONTH", Integer.MAX_VALUE), format("Overflow adding interval year-month values: %s + 1", Integer.MAX_VALUE));
}

@Test
Expand All @@ -90,6 +93,7 @@ public void testSubtract()
assertFunction("INTERVAL '6' MONTH - INTERVAL '3' MONTH", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(3));
assertFunction("INTERVAL '9' YEAR - INTERVAL '6' YEAR", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(3 * 12));
assertFunction("INTERVAL '3' MONTH - INTERVAL '6' YEAR", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth((3) - (6 * 12)));
assertNumericOverflow(format("-INTERVAL '%s' MONTH - INTERVAL '2' MONTH", Integer.MAX_VALUE), format("Overflow subtracting interval year-month values: -%s - 2", Integer.MAX_VALUE));
}

@Test
Expand All @@ -104,6 +108,11 @@ public void testMultiply()
assertFunction("2 * INTERVAL '6' YEAR", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(12 * 12));
assertFunction("INTERVAL '1' YEAR * 2.5", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth((int) (2.5 * 12)));
assertFunction("2.5 * INTERVAL '1' YEAR", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth((int) (2.5 * 12)));

assertNumericOverflow(format("INTERVAL '%s' MONTH * 2", Integer.MAX_VALUE), format("Overflow multiplying interval year-month value by bigint: %s * 2", Integer.MAX_VALUE));
assertNumericOverflow(format("2 * INTERVAL '%s' MONTH", Integer.MAX_VALUE), format("Overflow multiplying bigint by interval year-month value: 2 * %s", Integer.MAX_VALUE));
assertNumericOverflow(format("INTERVAL '%s' MONTH * 2.0", Integer.MAX_VALUE), format("Overflow multiplying interval year-month value by double: %s * 2.0", Integer.MAX_VALUE));
assertNumericOverflow(format("DOUBLE '2' * INTERVAL '%s' MONTH", Integer.MAX_VALUE), format("Overflow multiplying double by interval year-month value: 2.0 * %s", Integer.MAX_VALUE));
}

@Test
Expand All @@ -114,6 +123,8 @@ public void testDivide()

assertFunction("INTERVAL '3' YEAR / 2", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(18));
assertFunction("INTERVAL '4' YEAR / 4.8", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(10));

assertNumericOverflow(format("INTERVAL '%s' MONTH / 0.5", Integer.MAX_VALUE), format("Overflow dividing interval year-month value by double: %s / 0.5", Integer.MAX_VALUE));
}

@Test
Expand Down

0 comments on commit df7cdaa

Please sign in to comment.