From e56824eecb1aec5c47703c1dd2da08665a137cce Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Thu, 12 Dec 2024 10:02:33 -0800 Subject: [PATCH] Rewrite all classes Signed-off-by: Nghia Truong --- .../nvidia/spark/rapids/GpuOverrides.scala | 12 +- .../sql/rapids/datetimeExpressions.scala | 127 ++++++++++++++---- 2 files changed, 103 insertions(+), 36 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index d24528731b3..0c7aa046a7c 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -1826,19 +1826,15 @@ object GpuOverrides extends Logging { ExprChecks.binaryProject(TypeSig.DATE, TypeSig.DATE, ("date", TypeSig.DATE, TypeSig.DATE), ("format", TypeSig.STRING, TypeSig.STRING)), - (a, conf, p, r) => new BinaryExprMeta[TruncDate](a, conf, p, r) { - override def convertToGpu(date: Expression, format: Expression): GpuExpression = - GpuTruncDate(date, format) - }), + (a, conf, p, r) => new TruncDateExprMeta(a, conf, p, r) + ), expr[TruncTimestamp]( "Truncate the timestamp to the unit specified by the given string format", ExprChecks.binaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP, ("format", TypeSig.STRING, TypeSig.STRING), ("date", TypeSig.TIMESTAMP, TypeSig.TIMESTAMP)), - (a, conf, p, r) => new BinaryExprMeta[TruncTimestamp](a, conf, p, r) { - override def convertToGpu(format: Expression, timestamp: Expression): GpuExpression = - GpuTruncTimestamp(format, timestamp, a.timeZoneId) - }), + (a, conf, p, r) => new TruncTimestampExprMeta(a, conf, p, r) + ), expr[Pmod]( "Pmod", // Decimal support disabled https://github.com/NVIDIA/spark-rapids/issues/7553 diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala index 16dc0069e6a..14ec6bbd090 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala @@ -30,7 +30,7 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.jni.{DateTimeUtils, GpuTimeZoneDB} import com.nvidia.spark.rapids.shims.{NullIntolerantShim, ShimBinaryExpression, ShimExpression} -import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, FromUnixTime, FromUTCTimestamp, ImplicitCastInputTypes, MonthsBetween, TimeZoneAwareExpression, ToUTCTimestamp} +import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, FromUnixTime, FromUTCTimestamp, ImplicitCastInputTypes, MonthsBetween, TimeZoneAwareExpression, ToUTCTimestamp, TruncDate, TruncTimestamp} import org.apache.spark.sql.catalyst.util.DateTimeConstants import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -1527,11 +1527,54 @@ case class GpuLastDay(startDate: Expression) input.getBase.lastDayOfMonth() } -case class GpuTruncDate(date: Expression, format: Expression) - extends GpuBinaryExpression with ImplicitCastInputTypes { +abstract class GpuTruncDateTime(fmtStr: Option[String]) extends GpuBinaryExpression + with ImplicitCastInputTypes with Serializable { + override def nullable: Boolean = true + + protected def truncate(datetimeCol: GpuColumnVector, fmtCol: GpuColumnVector): ColumnVector = { + DateTimeUtils.truncate(datetimeCol.getBase, fmtCol.getBase) + } + + protected def truncate(datetimeVal: GpuScalar, formatCol: GpuColumnVector): ColumnVector = { + withResource(ColumnVector.fromScalar(datetimeVal.getBase, 1)) { datetimeCol => + DateTimeUtils.truncate(datetimeCol, formatCol.getBase) + } + } + + protected def truncate(datetimeCol: GpuColumnVector, fmtVal: GpuScalar): ColumnVector = { + // fmtVal is unused, as it was extracted to `fmtStr` before. + fmtStr match { + case Some(fmt) => DateTimeUtils.truncate(datetimeCol.getBase, fmt) + case None => throw new IllegalArgumentException("Invalid format string.") + } + } + + protected def truncate(numRows: Int, datetimeVal: GpuScalar, fmtVal: GpuScalar): ColumnVector = { + // fmtVal is unused, as it was extracted to `fmtStr` before. + fmtStr match { + case Some(fmt) => + withResource(ColumnVector.fromScalar(datetimeVal.getBase, 1)) { datetimeCol => + val truncated = DateTimeUtils.truncate(datetimeCol, fmt) + if (numRows == 1) { + truncated + } else { + withResource(truncated) { _ => + withResource(truncated.getScalarElement(0)) { truncatedScalar => + ColumnVector.fromScalar(truncatedScalar, numRows) + } + } + } + } + case None => throw new IllegalArgumentException("Invalid format string.") + } + } +} +case class GpuTruncDate(date: Expression, fmt: Expression, fmtStr: Option[String]) + extends GpuTruncDateTime(fmtStr) { override def left: Expression = date - override def right: Expression = format + + override def right: Expression = fmt override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) @@ -1539,34 +1582,33 @@ case class GpuTruncDate(date: Expression, format: Expression) override def prettyName: String = "trunc" - override def doColumnar(dateCol: GpuColumnVector, formatCol: GpuColumnVector): ColumnVector = { - DateTimeUtils.truncate(dateCol.getBase, formatCol.getBase) + override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector = { + truncate(lhs, rhs) } - override def doColumnar(dateVal: GpuScalar, formatCol: GpuColumnVector): ColumnVector = { - DateTimeUtils.truncate(dateVal.getBase, formatCol.getBase) + override def doColumnar(lhs: GpuScalar, rhs: GpuColumnVector): ColumnVector = { + truncate(lhs, rhs) } - override def doColumnar(dateCol: GpuColumnVector, formatVal: GpuScalar): ColumnVector = { - DateTimeUtils.truncate(dateCol.getBase, formatVal.getBase) + override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = { + truncate(lhs, rhs) } - override def doColumnar(numRows: Int, dateVal: GpuScalar, formatVal: GpuScalar): ColumnVector = { - withResource(DateTimeUtils.truncate(dateVal.getBase, formatVal.getBase)) { truncated => - ColumnVector.fromScalar(truncated, numRows) - } + override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = { + truncate(numRows, lhs, rhs) } } -case class GpuTruncTimestamp(format: Expression, timestamp: Expression, - timeZoneId: Option[String] = None) - extends GpuBinaryExpression with ImplicitCastInputTypes with TimeZoneAwareExpression { +case class GpuTruncTimestamp(fmt: Expression, timestamp: Expression, timeZoneId: Option[String], + fmtStr: Option[String]) + extends GpuTruncDateTime(fmtStr) with TimeZoneAwareExpression { override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = { copy(timeZoneId = Option(timeZoneId)) } - override def left: Expression = format + override def left: Expression = fmt + override def right: Expression = timestamp override def inputTypes: Seq[AbstractDataType] = Seq(StringType, TimestampType) @@ -1575,21 +1617,50 @@ case class GpuTruncTimestamp(format: Expression, timestamp: Expression, override def prettyName: String = "date_trunc" - override def doColumnar(formatCol: GpuColumnVector, tsCol: GpuColumnVector): ColumnVector = { - DateTimeUtils.truncate(tsCol.getBase, formatCol.getBase) + // Since the input order of this class is opposite compared to the `GpuTruncDate` class, + // we need to switch `lhs` and `rhs` in the `doColumnar` methods below. + + override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector = { + truncate(rhs, lhs) } - override def doColumnar(formatVal: GpuScalar, tsCol: GpuColumnVector): ColumnVector = { - DateTimeUtils.truncate(tsCol.getBase, formatVal.getBase) + override def doColumnar(lhs: GpuScalar, rhs: GpuColumnVector): ColumnVector = { + truncate(rhs, lhs) } - override def doColumnar(formatCol: GpuColumnVector, tsVal: GpuScalar): ColumnVector = { - DateTimeUtils.truncate(tsVal.getBase, formatCol.getBase) + override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = { + truncate(rhs, lhs) } - override def doColumnar(numRows: Int, formatVal: GpuScalar, tsVal: GpuScalar): ColumnVector = { - withResource(DateTimeUtils.truncate(tsVal.getBase, formatVal.getBase)) { truncated => - ColumnVector.fromScalar(truncated, numRows) - } + override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = { + truncate(numRows, rhs, lhs) + } +} + +class TruncDateExprMeta(expr: TruncDate, + override val conf: RapidsConf, + override val parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule) + extends BinaryExprMeta[TruncDate](expr, conf, parent, rule) { + + // Store the format string as we need to process it on the CPU later on. + private val fmtStr = extractStringLit(expr.format) + + override def convertToGpu(date: Expression, format: Expression): GpuExpression = { + GpuTruncDate(date, format, fmtStr) + } +} + +class TruncTimestampExprMeta(expr: TruncTimestamp, + override val conf: RapidsConf, + override val parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule) + extends BinaryExprMeta[TruncTimestamp](expr, conf, parent, rule) { + + // Store the format string as we need to process it on the CPU later on. + private val fmtStr = extractStringLit(expr.format) + + override def convertToGpu(format: Expression, timestamp: Expression): GpuExpression = { + GpuTruncTimestamp(format, timestamp, expr.timeZoneId, fmtStr) } }