From 20238cd93e32f6045e9dfb14c1e868a65f39d71f Mon Sep 17 00:00:00 2001
From: Ferdinand Xu
Date: Fri, 17 Nov 2023 21:06:31 +0800
Subject: [PATCH] Add timezone checker for expressions

---
 .../com/nvidia/spark/rapids/GpuCSVScan.scala  |  1 +
 .../com/nvidia/spark/rapids/RapidsMeta.scala  | 58 +++++++++++++++++--
 .../com/nvidia/spark/rapids/TypeChecks.scala  | 12 +++-
 .../catalyst/json/rapids/GpuJsonScan.scala    |  1 +
 4 files changed, 66 insertions(+), 6 deletions(-)

diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCSVScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCSVScan.scala
index fe5150cc224f..5969415730f5 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCSVScan.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCSVScan.scala
@@ -169,6 +169,7 @@ object GpuCSVScan {
     }
 
     if (types.contains(TimestampType)) {
+      meta.checkTimeZoneId(parsedOptions.zoneId)
       GpuTextBasedDateUtils.tagCudfFormat(meta,
         GpuCsvUtils.timestampFormatInRead(parsedOptions), parseString = true)
     }
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
index cfc24c22d71d..2ac0e54a3ad2 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
@@ -16,11 +16,13 @@
 package com.nvidia.spark.rapids
 
+import java.time.ZoneId
+
 import scala.collection.mutable
 
 import com.nvidia.spark.rapids.shims.{DistributionUtil, SparkShimImpl}
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, ComplexTypeMergingExpression, Expression, QuaternaryExpression, String2TrimExpression, TernaryExpression, UnaryExpression, WindowExpression, WindowFunction}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, ComplexTypeMergingExpression, Expression, QuaternaryExpression, String2TrimExpression, TernaryExpression, TimeZoneAwareExpression, UnaryExpression, UTCTimestamp, WindowExpression, WindowFunction}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, ImperativeAggregate, TypedImperativeAggregate}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
@@ -32,7 +34,8 @@ import org.apache.spark.sql.execution.command.{DataWritingCommand, RunnableComma
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.python.AggregateInPandasExec
 import org.apache.spark.sql.rapids.aggregate.{CpuToGpuAggregateBufferConverter, GpuToCpuAggregateBufferConverter}
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType, TimestampType}
+
 
 trait DataFromReplacementRule {
   val operationName: String
@@ -383,6 +386,19 @@ abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE](
     }
   }
 
+  def checkTimeZoneId(sessionZoneId: ZoneId): Unit = {
+    // Both the Spark session time zone and the JVM default time zone must be UTC.
+    if (!TypeChecks.isTimestampsSupported(sessionZoneId)) {
+      willNotWorkOnGpu("Only UTC zone id is supported. " +
+        s"Actual session local zone id: $sessionZoneId")
+    }
+
+    val defaultZoneId = ZoneId.systemDefault()
+    if (!TypeChecks.isTimestampsSupported(defaultZoneId)) {
+      willNotWorkOnGpu(s"Only UTC zone id is supported. " +
Actual default zone id: $defaultZoneId") + } + } + /** * Create a string representation of this in append. * @param strBuilder where to place the string representation. @@ -1056,8 +1072,23 @@ abstract class BaseExprMeta[INPUT <: Expression]( val isFoldableNonLitAllowed: Boolean = false - // Default false as conservative approach to allow timezone related expression converted to GPU - lazy val isTimezoneSupported: Boolean = false + // Whether timezone is supported for those expressions needs to be check. + // TODO: use TimezoneDB Utils to tell whether timezone is supported + val isTimezoneSupported: Boolean = false + + //+------------------------+-------------------+-----------------------------------------+ + //| Value | needTimezoneCheck | isTimezoneSupported | + //+------------------------+-------------------+-----------------------------------------+ + //| TimezoneAwareExpression| True | False by default, True when implemented | + //| UTCTimestamp | True | False by default, True when implemented | + //| Others | False | N/A (will not be checked) | + //+------------------------+-------------------+-----------------------------------------+ + lazy val needTimezoneCheck: Boolean = { + wrapped match { + case _: TimeZoneAwareExpression | _: UTCTimestamp => true + case _ => false + } + } final override def tagSelfForGpu(): Unit = { if (wrapped.foldable && !GpuOverrides.isLit(wrapped) && !isFoldableNonLitAllowed) { @@ -1065,9 +1096,28 @@ abstract class BaseExprMeta[INPUT <: Expression]( s"$wrapped is foldable and operates on non literals") } rule.getChecks.foreach(_.tag(this)) + if (needTimezoneCheck && !isTimezoneSupported) checkTimestampType(dataType, this) tagExprForGpu() } + /** + * Check whether contains timestamp type and whether timezone is supported + */ + def checkTimestampType(dataType: DataType, meta: RapidsMeta[_, _, _]): Unit = { + dataType match { + case TimestampType if !TypeChecks.isUTCTimezone() => + meta.willNotWorkOnGpu(TypeChecks.timezoneNotSupportedString(dataType)) + case ArrayType(elementType, _) => + checkTimestampType(elementType, meta) + case MapType(keyType, valueType, _) => + checkTimestampType(keyType, meta) + checkTimestampType(valueType, meta) + case StructType(fields) => + fields.foreach(field => checkTimestampType(field.dataType, meta)) + case _ => // do nothing + } + } + /** * Called to verify that this expression will work on the GPU. For most expressions without * extra checks all of the checks should have already been done. 
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
index 53283d4be1f0..530b5a711773 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
@@ -17,6 +17,7 @@
 package com.nvidia.spark.rapids
 
 import java.io.{File, FileOutputStream}
+import java.time.ZoneId
 
 import ai.rapids.cudf.DType
 import com.nvidia.spark.rapids.shims.{CastCheckShims, GpuTypeShims, TypeSigUtil}
@@ -789,13 +790,20 @@ abstract class TypeChecks[RET] {
 
 object TypeChecks {
 
+  // TODO: move this to Timezone DB
+  def isTimestampsSupported(timezoneId: ZoneId): Boolean = {
+    timezoneId.normalized() == GpuOverrides.UTC_TIMEZONE_ID
+  }
+
   def isUTCTimezone(): Boolean = {
     val zoneId = DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone)
     zoneId.normalized() == GpuOverrides.UTC_TIMEZONE_ID
   }
 
-  def isTimezoneSensitiveType(dataType: DataType): Boolean = {
-    dataType == TimestampType
+  def timezoneNotSupportedString(dataType: DataType): String = {
+    s"$dataType is not supported with timezone settings: (JVM:" +
+      s" ${ZoneId.systemDefault()}, session: ${SQLConf.get.sessionLocalTimeZone})." +
+      s" Set both of the timezones to UTC to enable $dataType support"
   }
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala
index 03dd860bd93b..94140b00e1b5 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala
@@ -143,6 +143,7 @@ object GpuJsonScan {
     }
 
     if (types.contains(TimestampType)) {
+      meta.checkTimeZoneId(parsedOptions.zoneId)
       GpuTextBasedDateUtils.tagCudfFormat(meta,
         GpuJsonUtils.timestampFormatInRead(parsedOptions), parseString = true)
     }
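
Illustration (not part of the patch): checkTimestampType recurses through nested Spark SQL types, so a TimestampType hidden inside an ArrayType, MapType, or StructType still forces the expression onto the CPU unless both the JVM default and the Spark session time zone normalize to UTC. The sketch below mirrors that recursion using only the public org.apache.spark.sql.types API; the object name TimezoneCheckSketch and the containsTimestamp helper are hypothetical and exist only for this example.

// Hypothetical, self-contained sketch of the recursive timestamp-type walk used by the patch.
import java.time.ZoneId
import org.apache.spark.sql.types._

object TimezoneCheckSketch {
  // True if the type holds a TimestampType anywhere in its nesting.
  def containsTimestamp(dt: DataType): Boolean = dt match {
    case TimestampType => true
    case ArrayType(elementType, _) => containsTimestamp(elementType)
    case MapType(keyType, valueType, _) =>
      containsTimestamp(keyType) || containsTimestamp(valueType)
    case StructType(fields) => fields.exists(f => containsTimestamp(f.dataType))
    case _ => false
  }

  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("events", ArrayType(StructType(Seq(StructField("ts", TimestampType)))))))
    val utc = ZoneId.of("UTC")
    // A nested timestamp is detected, so the expression would fall back to the CPU
    // unless both the JVM default and the Spark session time zone normalize to UTC.
    println(containsTimestamp(schema))                  // true
    println(ZoneId.systemDefault().normalized() == utc) // environment-dependent
  }
}

Keeping the recursion in BaseExprMeta, rather than in each expression meta, means any expression whose result type carries a timestamp is tagged consistently, and individual expression metas only need to flip isTimezoneSupported once a timezone-aware GPU implementation is available.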