Skip to content

Commit

Permalink
Rewrite all classes
Browse files Browse the repository at this point in the history
Signed-off-by: Nghia Truong <[email protected]>
  • Loading branch information
ttnghia committed Dec 12, 2024
1 parent 1f47082 commit e56824e
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1826,19 +1826,15 @@ object GpuOverrides extends Logging {
ExprChecks.binaryProject(TypeSig.DATE, TypeSig.DATE,
("date", TypeSig.DATE, TypeSig.DATE),
("format", TypeSig.STRING, TypeSig.STRING)),
(a, conf, p, r) => new BinaryExprMeta[TruncDate](a, conf, p, r) {
override def convertToGpu(date: Expression, format: Expression): GpuExpression =
GpuTruncDate(date, format)
}),
(a, conf, p, r) => new TruncDateExprMeta(a, conf, p, r)
),
expr[TruncTimestamp](
"Truncate the timestamp to the unit specified by the given string format",
ExprChecks.binaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP,
("format", TypeSig.STRING, TypeSig.STRING),
("date", TypeSig.TIMESTAMP, TypeSig.TIMESTAMP)),
(a, conf, p, r) => new BinaryExprMeta[TruncTimestamp](a, conf, p, r) {
override def convertToGpu(format: Expression, timestamp: Expression): GpuExpression =
GpuTruncTimestamp(format, timestamp, a.timeZoneId)
}),
(a, conf, p, r) => new TruncTimestampExprMeta(a, conf, p, r)
),
expr[Pmod](
"Pmod",
// Decimal support disabled https://github.com/NVIDIA/spark-rapids/issues/7553
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.jni.{DateTimeUtils, GpuTimeZoneDB}
import com.nvidia.spark.rapids.shims.{NullIntolerantShim, ShimBinaryExpression, ShimExpression}

import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, FromUnixTime, FromUTCTimestamp, ImplicitCastInputTypes, MonthsBetween, TimeZoneAwareExpression, ToUTCTimestamp}
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, FromUnixTime, FromUTCTimestamp, ImplicitCastInputTypes, MonthsBetween, TimeZoneAwareExpression, ToUTCTimestamp, TruncDate, TruncTimestamp}
import org.apache.spark.sql.catalyst.util.DateTimeConstants
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -1527,46 +1527,88 @@ case class GpuLastDay(startDate: Expression)
input.getBase.lastDayOfMonth()
}

case class GpuTruncDate(date: Expression, format: Expression)
extends GpuBinaryExpression with ImplicitCastInputTypes {
/**
 * Common base for the GPU date/timestamp truncation expressions.
 *
 * The four `truncate` overloads cover every column/scalar combination of the
 * datetime and format operands and delegate the actual work to
 * `DateTimeUtils.truncate`.
 *
 * @param fmtStr the format operand as a plain string when it was extracted as a
 *               literal ahead of time; `None` otherwise. The scalar-format
 *               overloads use this value instead of reading the format scalar.
 */
abstract class GpuTruncDateTime(fmtStr: Option[String]) extends GpuBinaryExpression
    with ImplicitCastInputTypes with Serializable {
  override def nullable: Boolean = true

  /**
   * Resolves the format to use when the format operand is a scalar.
   *
   * @param fmtVal the format scalar; only its validity is inspected.
   * @return `Some(format)` when a non-null literal string is available, or `None`
   *         when the format scalar is NULL, in which case the result must be NULL.
   * @throws IllegalArgumentException if the scalar holds a value but no literal
   *                                  string was extracted for it.
   */
  private def scalarFormat(fmtVal: GpuScalar): Option[String] = {
    // Option(_) also guards against a `Some(null)` coming from the literal extraction.
    fmtStr.flatMap(Option(_)) match {
      case resolved @ Some(_) => resolved
      // A NULL format cannot name a truncation unit: the answer is NULL, not a failure.
      case None if !fmtVal.isValid => None
      case None => throw new IllegalArgumentException("Invalid format string.")
    }
  }

  /** Truncates each row of `datetimeCol` using the format in the same row of `fmtCol`. */
  protected def truncate(datetimeCol: GpuColumnVector, fmtCol: GpuColumnVector): ColumnVector = {
    DateTimeUtils.truncate(datetimeCol.getBase, fmtCol.getBase)
  }

  /**
   * Truncates a scalar datetime against a column of formats. The scalar is first
   * materialised as a single-row column, which is closed once the call returns.
   */
  protected def truncate(datetimeVal: GpuScalar, formatCol: GpuColumnVector): ColumnVector = {
    withResource(ColumnVector.fromScalar(datetimeVal.getBase, 1)) { datetimeCol =>
      DateTimeUtils.truncate(datetimeCol, formatCol.getBase)
    }
  }

  /**
   * Truncates every row of `datetimeCol` with a single scalar format.
   *
   * @return the truncated column, or an all-null column with the same number of
   *         rows when the format scalar is NULL.
   */
  protected def truncate(datetimeCol: GpuColumnVector, fmtVal: GpuScalar): ColumnVector = {
    scalarFormat(fmtVal) match {
      case Some(fmt) => DateTimeUtils.truncate(datetimeCol.getBase, fmt)
      case None =>
        GpuColumnVector.columnVectorFromNull(datetimeCol.getRowCount.toInt, dataType)
    }
  }

  /**
   * Truncates a scalar datetime with a scalar format and expands the result to
   * `numRows` rows.
   *
   * @return a column of `numRows` identical values, or an all-null column of
   *         `numRows` rows when the format scalar is NULL.
   */
  protected def truncate(numRows: Int, datetimeVal: GpuScalar, fmtVal: GpuScalar): ColumnVector = {
    scalarFormat(fmtVal) match {
      case Some(fmt) =>
        withResource(ColumnVector.fromScalar(datetimeVal.getBase, 1)) { datetimeCol =>
          val truncated = DateTimeUtils.truncate(datetimeCol, fmt)
          if (numRows == 1) {
            // The single-row result is already the answer; ownership passes to the caller.
            truncated
          } else {
            // Broadcast the single truncated value, then release the intermediate column.
            withResource(truncated) { _ =>
              withResource(truncated.getScalarElement(0)) { truncatedScalar =>
                ColumnVector.fromScalar(truncatedScalar, numRows)
              }
            }
          }
        }
      case None => GpuColumnVector.columnVectorFromNull(numRows, dataType)
    }
  }
}

case class GpuTruncDate(date: Expression, fmt: Expression, fmtStr: Option[String])
extends GpuTruncDateTime(fmtStr) {
override def left: Expression = date
override def right: Expression = format

override def right: Expression = fmt

override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)

override def dataType: DataType = DateType

override def prettyName: String = "trunc"

override def doColumnar(dateCol: GpuColumnVector, formatCol: GpuColumnVector): ColumnVector = {
DateTimeUtils.truncate(dateCol.getBase, formatCol.getBase)
override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector = {
truncate(lhs, rhs)
}

override def doColumnar(dateVal: GpuScalar, formatCol: GpuColumnVector): ColumnVector = {
DateTimeUtils.truncate(dateVal.getBase, formatCol.getBase)
override def doColumnar(lhs: GpuScalar, rhs: GpuColumnVector): ColumnVector = {
truncate(lhs, rhs)
}

override def doColumnar(dateCol: GpuColumnVector, formatVal: GpuScalar): ColumnVector = {
DateTimeUtils.truncate(dateCol.getBase, formatVal.getBase)
override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = {
truncate(lhs, rhs)
}

override def doColumnar(numRows: Int, dateVal: GpuScalar, formatVal: GpuScalar): ColumnVector = {
withResource(DateTimeUtils.truncate(dateVal.getBase, formatVal.getBase)) { truncated =>
ColumnVector.fromScalar(truncated, numRows)
}
override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = {
truncate(numRows, lhs, rhs)
}
}

case class GpuTruncTimestamp(format: Expression, timestamp: Expression,
timeZoneId: Option[String] = None)
extends GpuBinaryExpression with ImplicitCastInputTypes with TimeZoneAwareExpression {
case class GpuTruncTimestamp(fmt: Expression, timestamp: Expression, timeZoneId: Option[String],
fmtStr: Option[String])
extends GpuTruncDateTime(fmtStr) with TimeZoneAwareExpression {

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
copy(timeZoneId = Option(timeZoneId))
}

override def left: Expression = format
override def left: Expression = fmt

override def right: Expression = timestamp

override def inputTypes: Seq[AbstractDataType] = Seq(StringType, TimestampType)
Expand All @@ -1575,21 +1617,50 @@ case class GpuTruncTimestamp(format: Expression, timestamp: Expression,

override def prettyName: String = "date_trunc"

override def doColumnar(formatCol: GpuColumnVector, tsCol: GpuColumnVector): ColumnVector = {
DateTimeUtils.truncate(tsCol.getBase, formatCol.getBase)
// Since the input order of this class is opposite compared to the `GpuTruncDate` class,
// we need to switch `lhs` and `rhs` in the `doColumnar` methods below.

override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector = {
truncate(rhs, lhs)
}

override def doColumnar(formatVal: GpuScalar, tsCol: GpuColumnVector): ColumnVector = {
DateTimeUtils.truncate(tsCol.getBase, formatVal.getBase)
override def doColumnar(lhs: GpuScalar, rhs: GpuColumnVector): ColumnVector = {
truncate(rhs, lhs)
}

override def doColumnar(formatCol: GpuColumnVector, tsVal: GpuScalar): ColumnVector = {
DateTimeUtils.truncate(tsVal.getBase, formatCol.getBase)
override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = {
truncate(rhs, lhs)
}

override def doColumnar(numRows: Int, formatVal: GpuScalar, tsVal: GpuScalar): ColumnVector = {
withResource(DateTimeUtils.truncate(tsVal.getBase, formatVal.getBase)) { truncated =>
ColumnVector.fromScalar(truncated, numRows)
}
override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = {
truncate(numRows, rhs, lhs)
}
}

/**
 * Expression meta that replaces a Spark `TruncDate` with `GpuTruncDate`.
 *
 * The format operand is run through `extractStringLit` once, when the meta is
 * constructed, and the outcome is handed to the GPU expression together with the
 * converted children so that it can be processed on the CPU later on.
 */
class TruncDateExprMeta(expr: TruncDate,
    override val conf: RapidsConf,
    override val parent: Option[RapidsMeta[_, _, _]],
    rule: DataFromReplacementRule)
  extends BinaryExprMeta[TruncDate](expr, conf, parent, rule) {

  // Result of the literal extraction on the format operand.
  // NOTE(review): presumably empty when the format is not a string literal - confirm.
  private val literalFormat = extractStringLit(expr.format)

  /** Builds the GPU expression from the converted `date` and `format` children. */
  override def convertToGpu(date: Expression, format: Expression): GpuExpression =
    GpuTruncDate(date, format, literalFormat)
}

/**
 * Expression meta that replaces a Spark `TruncTimestamp` with `GpuTruncTimestamp`.
 *
 * The format operand is run through `extractStringLit` once, when the meta is
 * constructed. The outcome is forwarded to the GPU expression along with the
 * converted children and the time zone id of the original expression.
 */
class TruncTimestampExprMeta(expr: TruncTimestamp,
    override val conf: RapidsConf,
    override val parent: Option[RapidsMeta[_, _, _]],
    rule: DataFromReplacementRule)
  extends BinaryExprMeta[TruncTimestamp](expr, conf, parent, rule) {

  // Result of the literal extraction on the format operand.
  // NOTE(review): presumably empty when the format is not a string literal - confirm.
  private val literalFormat = extractStringLit(expr.format)

  /** Builds the GPU expression; note the format child comes first for this expression. */
  override def convertToGpu(format: Expression, timestamp: Expression): GpuExpression =
    GpuTruncTimestamp(format, timestamp, expr.timeZoneId, literalFormat)
}

0 comments on commit e56824e

Please sign in to comment.