
Commit

Add timezone checker for expressions
winningsix committed Nov 17, 2023
1 parent c475d72 commit 20238cd
Showing 4 changed files with 66 additions and 6 deletions.
@@ -169,6 +169,7 @@ object GpuCSVScan {
}

if (types.contains(TimestampType)) {
meta.checkTimeZoneId(parsedOptions.zoneId)
GpuTextBasedDateUtils.tagCudfFormat(meta,
GpuCsvUtils.timestampFormatInRead(parsedOptions), parseString = true)
}
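
For context, a hedged usage sketch follows (an editorial illustration, not part of the diff; the spark session variable, schema, path, and zone are made up): with a non-UTC session zone, a CSV read that includes a timestamp column is now tagged by checkTimeZoneId and the scan falls back to the CPU.

    // Illustrative only: the timestamp column in the schema triggers the new check.
    spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")
    val df = spark.read
      .schema("id INT, event_time TIMESTAMP")
      .option("timestampFormat", "yyyy-MM-dd HH:mm:ss")
      .csv("/tmp/events.csv") // hypothetical path
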
58 changes: 54 additions & 4 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
@@ -16,11 +16,13 @@

package com.nvidia.spark.rapids

import java.time.ZoneId

import scala.collection.mutable

import com.nvidia.spark.rapids.shims.{DistributionUtil, SparkShimImpl}

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, ComplexTypeMergingExpression, Expression, QuaternaryExpression, String2TrimExpression, TernaryExpression, UnaryExpression, WindowExpression, WindowFunction}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, ComplexTypeMergingExpression, Expression, QuaternaryExpression, String2TrimExpression, TernaryExpression, TimeZoneAwareExpression, UnaryExpression, UTCTimestamp, WindowExpression, WindowFunction}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, ImperativeAggregate, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
@@ -32,7 +34,8 @@ import org.apache.spark.sql.execution.command.{DataWritingCommand, RunnableComma
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.python.AggregateInPandasExec
import org.apache.spark.sql.rapids.aggregate.{CpuToGpuAggregateBufferConverter, GpuToCpuAggregateBufferConverter}
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType, TimestampType}


trait DataFromReplacementRule {
val operationName: String
@@ -383,6 +386,19 @@ abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE](
}
}

def checkTimeZoneId(sessionZoneId: ZoneId): Unit = {
// Both the Spark session time zone and the JVM default time zone must be UTC.
if (!TypeChecks.isTimestampsSupported(sessionZoneId)) {
willNotWorkOnGpu("Only UTC zone id is supported. " +
s"Actual session local zone id: $sessionZoneId")
}

val defaultZoneId = ZoneId.systemDefault()
if (!TypeChecks.isTimestampsSupported(defaultZoneId)) {
willNotWorkOnGpu(s"Only UTC zone id is supported. Actual default zone id: $defaultZoneId")
}
}
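
As a rough standalone sketch of the rule this method applies (an editorial aside, not part of the commit; it assumes GpuOverrides.UTC_TIMEZONE_ID is the normalized UTC zone id), only zone ids that normalize to UTC pass, and both the session zone and the JVM default zone are validated:

    import java.time.ZoneId

    // Zone ids whose rules are a fixed +00:00 offset normalize to UTC and pass.
    val utc = ZoneId.of("UTC").normalized()
    def isUtcNormalized(zoneId: ZoneId): Boolean = zoneId.normalized() == utc

    assert(isUtcNormalized(ZoneId.of("Etc/UTC")))              // passes
    assert(isUtcNormalized(ZoneId.of("Z")))                    // passes
    assert(!isUtcNormalized(ZoneId.of("America/Los_Angeles"))) // marked not supported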

/**
* Create a string representation of this in append.
* @param strBuilder where to place the string representation.
@@ -1056,18 +1072,52 @@ abstract class BaseExprMeta[INPUT <: Expression](

val isFoldableNonLitAllowed: Boolean = false

// Default is false as a conservative approach to allow timezone-related expressions to be converted to GPU
lazy val isTimezoneSupported: Boolean = false
// Whether timezone is supported for those expressions that need to be checked.
// TODO: use TimezoneDB Utils to tell whether timezone is supported
val isTimezoneSupported: Boolean = false

//+------------------------+-------------------+-----------------------------------------+
//| Value | needTimezoneCheck | isTimezoneSupported |
//+------------------------+-------------------+-----------------------------------------+
//| TimezoneAwareExpression| True | False by default, True when implemented |
//| UTCTimestamp | True | False by default, True when implemented |
//| Others | False | N/A (will not be checked) |
//+------------------------+-------------------+-----------------------------------------+
lazy val needTimezoneCheck: Boolean = {
wrapped match {
case _: TimeZoneAwareExpression | _: UTCTimestamp => true
case _ => false
}
}
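
To make the table above concrete, here is a small standalone sketch (an editorial illustration, not part of the commit) applying the same classification to raw Catalyst expressions: Hour mixes in TimeZoneAwareExpression and therefore needs the check, while a plain Literal does not.

    import org.apache.spark.sql.catalyst.expressions.{Expression, Hour, Literal, TimeZoneAwareExpression, UTCTimestamp}
    import org.apache.spark.sql.types.TimestampType

    // Same match as needTimezoneCheck, applied outside the meta wrapper.
    def needsTimezoneCheck(e: Expression): Boolean = e match {
      case _: TimeZoneAwareExpression | _: UTCTimestamp => true
      case _ => false
    }

    assert(needsTimezoneCheck(Hour(Literal.create(0L, TimestampType)))) // timezone-aware
    assert(!needsTimezoneCheck(Literal(1)))                             // no check needed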

final override def tagSelfForGpu(): Unit = {
if (wrapped.foldable && !GpuOverrides.isLit(wrapped) && !isFoldableNonLitAllowed) {
willNotWorkOnGpu(s"Cannot run on GPU. Is ConstantFolding excluded? Expression " +
s"$wrapped is foldable and operates on non literals")
}
rule.getChecks.foreach(_.tag(this))
if (needTimezoneCheck && !isTimezoneSupported) checkTimestampType(dataType, this)
tagExprForGpu()
}

/**
* Check whether the data type contains TimestampType and whether the timezone is supported
*/
def checkTimestampType(dataType: DataType, meta: RapidsMeta[_, _, _]): Unit = {
dataType match {
case TimestampType if !TypeChecks.isUTCTimezone() =>
meta.willNotWorkOnGpu(TypeChecks.timezoneNotSupportedString(dataType))
case ArrayType(elementType, _) =>
checkTimestampType(elementType, meta)
case MapType(keyType, valueType, _) =>
checkTimestampType(keyType, meta)
checkTimestampType(valueType, meta)
case StructType(fields) =>
fields.foreach(field => checkTimestampType(field.dataType, meta))
case _ => // do nothing
}
}
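
For intuition about the recursion (a standalone sketch, not part of the commit), the walk flags a TimestampType nested at any depth inside arrays, maps, or structs; a plain containment test over Spark types looks like this:

    import org.apache.spark.sql.types._

    // Mirrors the recursion in checkTimestampType, minus the meta plumbing.
    def containsTimestamp(dt: DataType): Boolean = dt match {
      case TimestampType => true
      case ArrayType(elementType, _) => containsTimestamp(elementType)
      case MapType(keyType, valueType, _) =>
        containsTimestamp(keyType) || containsTimestamp(valueType)
      case StructType(fields) => fields.exists(f => containsTimestamp(f.dataType))
      case _ => false
    }

    // array<struct<ts: timestamp, s: string>> contains a timestamp deep inside.
    val nested = ArrayType(StructType(Seq(
      StructField("ts", TimestampType),
      StructField("s", StringType))))
    assert(containsTimestamp(nested))
    assert(!containsTimestamp(MapType(StringType, IntegerType)))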

/**
* Called to verify that this expression will work on the GPU. For most expressions without
* extra checks all of the checks should have already been done.
12 changes: 10 additions & 2 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
@@ -17,6 +17,7 @@
package com.nvidia.spark.rapids

import java.io.{File, FileOutputStream}
import java.time.ZoneId

import ai.rapids.cudf.DType
import com.nvidia.spark.rapids.shims.{CastCheckShims, GpuTypeShims, TypeSigUtil}
@@ -789,13 +790,20 @@ abstract class TypeChecks[RET] {

object TypeChecks {

// TODO: move this to Timezone DB
def isTimestampsSupported(timezoneId: ZoneId): Boolean = {
timezoneId.normalized() == GpuOverrides.UTC_TIMEZONE_ID
}

def isUTCTimezone(): Boolean = {
val zoneId = DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone)
zoneId.normalized() == GpuOverrides.UTC_TIMEZONE_ID
}

def isTimezoneSensitiveType(dataType: DataType): Boolean = {
dataType == TimestampType
def timezoneNotSupportedString(dataType: DataType): String = {
s"$dataType is not supported with timezone settings: (JVM:" +
s" ${ZoneId.systemDefault()}, session: ${SQLConf.get.sessionLocalTimeZone})." +
s" Set both of the timezones to UTC to enable $dataType support"
}
}
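
A small usage sketch of the new helpers (editorial aside, not part of the diff; it assumes an active SparkSession so SQLConf.get resolves, and that the code runs where TypeChecks is visible, e.g. inside the com.nvidia.spark.rapids package): the checks consult the session zone and the JVM default zone, and emit the message above when either is not UTC.

    import java.time.ZoneId

    import org.apache.spark.sql.catalyst.util.DateTimeUtils
    import org.apache.spark.sql.internal.SQLConf
    import org.apache.spark.sql.types.TimestampType

    // Session zone comes from spark.sql.session.timeZone; the JVM default zone
    // from user.timezone (e.g. -Duser.timezone=UTC on driver and executors).
    val sessionZone = DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone)
    val jvmZone = ZoneId.systemDefault()

    if (!TypeChecks.isTimestampsSupported(sessionZone) ||
        !TypeChecks.isTimestampsSupported(jvmZone)) {
      println(TypeChecks.timezoneNotSupportedString(TimestampType))
    }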

@@ -143,6 +143,7 @@ object GpuJsonScan {
}

if (types.contains(TimestampType)) {
meta.checkTimeZoneId(parsedOptions.zoneId)
GpuTextBasedDateUtils.tagCudfFormat(meta,
GpuJsonUtils.timestampFormatInRead(parsedOptions), parseString = true)
}
