From 1f470821b0e36ec85760a3beab5a6577baa7a2c4 Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Wed, 11 Dec 2024 20:29:23 -0800 Subject: [PATCH] Adopt to JNI changes Signed-off-by: Nghia Truong --- .../sql/rapids/datetimeExpressions.scala | 59 ++++++++++--------- 1 file changed, 32 insertions(+), 27 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala index d72aa5e5053..16dc0069e6a 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala @@ -1527,33 +1527,9 @@ case class GpuLastDay(startDate: Expression) input.getBase.lastDayOfMonth() } -trait GpuTruncDateTime extends GpuBinaryExpression with ImplicitCastInputTypes { - override def doColumnar(lhs: GpuScalar, rhs: GpuColumnVector): ColumnVector = { - withResource(scalarToColumn(lhs)) { left => - doColumnar(left, rhs) - } - } - - override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = { - withResource(scalarToColumn(rhs)) { right => - doColumnar(lhs, right) - } - } - - override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = { - withResource(scalarToColumn(lhs, numRows)) { left => - withResource(scalarToColumn(rhs, numRows)) { right => - doColumnar(left, right) - } - } - } - - private def scalarToColumn(input: GpuScalar, numRows: Int = 1) : GpuColumnVector = { - GpuColumnVector.from(input, numRows, input.dataType) - } -} +case class GpuTruncDate(date: Expression, format: Expression) + extends GpuBinaryExpression with ImplicitCastInputTypes { -case class GpuTruncDate(date: Expression, format: Expression) extends GpuTruncDateTime { override def left: Expression = date override def right: Expression = format @@ -1566,11 +1542,26 @@ case class GpuTruncDate(date: Expression, format: Expression) extends GpuTruncDa override def doColumnar(dateCol: GpuColumnVector, formatCol: GpuColumnVector): ColumnVector = { DateTimeUtils.truncate(dateCol.getBase, formatCol.getBase) } + + override def doColumnar(dateVal: GpuScalar, formatCol: GpuColumnVector): ColumnVector = { + DateTimeUtils.truncate(dateVal.getBase, formatCol.getBase) + } + + override def doColumnar(dateCol: GpuColumnVector, formatVal: GpuScalar): ColumnVector = { + DateTimeUtils.truncate(dateCol.getBase, formatVal.getBase) + } + + override def doColumnar(numRows: Int, dateVal: GpuScalar, formatVal: GpuScalar): ColumnVector = { + withResource(DateTimeUtils.truncate(dateVal.getBase, formatVal.getBase)) { truncated => + ColumnVector.fromScalar(truncated, numRows) + } + } } case class GpuTruncTimestamp(format: Expression, timestamp: Expression, timeZoneId: Option[String] = None) - extends GpuTruncDateTime with TimeZoneAwareExpression { + extends GpuBinaryExpression with ImplicitCastInputTypes with TimeZoneAwareExpression { + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = { copy(timeZoneId = Option(timeZoneId)) } @@ -1587,4 +1578,18 @@ case class GpuTruncTimestamp(format: Expression, timestamp: Expression, override def doColumnar(formatCol: GpuColumnVector, tsCol: GpuColumnVector): ColumnVector = { DateTimeUtils.truncate(tsCol.getBase, formatCol.getBase) } + + override def doColumnar(formatVal: GpuScalar, tsCol: GpuColumnVector): ColumnVector = { + DateTimeUtils.truncate(tsCol.getBase, formatVal.getBase) + } + + override def doColumnar(formatCol: GpuColumnVector, tsVal: GpuScalar): ColumnVector = { + DateTimeUtils.truncate(tsVal.getBase, formatCol.getBase) + } + + override def doColumnar(numRows: Int, formatVal: GpuScalar, tsVal: GpuScalar): ColumnVector = { + withResource(DateTimeUtils.truncate(tsVal.getBase, formatVal.getBase)) { truncated => + ColumnVector.fromScalar(truncated, numRows) + } + } }