Skip to content

Commit

Permalink
save progress on 330 shim
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed Dec 12, 2023
1 parent 70011f8 commit c65b29f
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ object GpuJsonScan {

val hasTimestamps = TrampolineUtil.dataTypeExistsRecursively(dt, _.isInstanceOf[TimestampType])
if (hasTimestamps) {
GpuJsonToStructsShim.tagTimestampFormatSupport(meta,
GpuJsonUtils.optionalTimestampFormatInRead(parsedOptions))


GpuJsonUtils.optionalTimestampFormatInRead(parsedOptions) match {
case None | Some("yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX]") =>
// this is fine
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ object GpuJsonToStructsShim {
def tagDateFormatSupport(meta: RapidsMeta[_, _, _], dateFormat: Option[String]): Unit = {
dateFormat match {
case None | Some("yyyy-MM-dd") =>
// this is fine
case dateFormat =>
meta.willNotWorkOnGpu(s"GpuJsonToStructs unsupported dateFormat $dateFormat")
}
Expand Down Expand Up @@ -74,6 +73,9 @@ object GpuJsonToStructsShim {
}
}

/**
 * Tagging hook for the `timestampFormat` option.
 *
 * The body is empty in this shim: no `timestampFormat` value causes `meta` to be
 * modified, so every format (or the absence of one) is accepted as-is.
 *
 * @param meta            plan metadata that a tagging hook could update; not touched here
 * @param timestampFormat the configured timestamp format, if any; not inspected here
 */
def tagTimestampFormatSupport(
    meta: RapidsMeta[_, _, _],
    timestampFormat: Option[String]): Unit = ()

def castJsonStringToTimestamp(input: ColumnVector,
options: Map[String, String]): ColumnVector = {
withResource(Scalar.fromString(" ")) { space =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import ai.rapids.cudf.{ColumnVector, DType, Scalar}
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.{DateUtils, GpuCast, GpuOverrides, RapidsMeta}

import org.apache.spark.sql.catalyst.json.GpuJsonUtils
//import org.apache.spark.sql.catalyst.json.GpuJsonUtils
import org.apache.spark.sql.rapids.ExceptionTimeParserPolicy

object GpuJsonToStructsShim {
Expand All @@ -38,66 +38,54 @@ object GpuJsonToStructsShim {
}

def castJsonStringToDate(input: ColumnVector, options: Map[String, String]): ColumnVector = {
GpuJsonUtils.optionalDateFormatInRead(options) match {
case None =>
// legacy behavior
withResource(Scalar.fromString(" ")) { space =>
withResource(input.strip(space)) { trimmed =>
GpuCast.castStringToDate(trimmed)
}
}
case Some(f) =>
// from_json does not respect EXCEPTION policy
jsonStringToDate(input, f, failOnInvalid = false)
// dateFormat is ignored in Spark 3.3
withResource(Scalar.fromString(" ")) { space =>
withResource(input.strip(space)) { trimmed =>
GpuCast.castStringToDate(trimmed)
}
}
}

/**
 * Tagging hook for the `dateFormat` option on the scan path.
 *
 * The body is empty in this shim: `meta` is never modified, regardless of the
 * `dateFormat` that was supplied.
 *
 * @param meta       plan metadata that a tagging hook could update; not touched here
 * @param dateFormat the configured date format, if any; not inspected here
 */
def tagDateFormatSupportFromScan(
    meta: RapidsMeta[_, _, _],
    dateFormat: Option[String]): Unit = ()

def castJsonStringToDateFromScan(input: ColumnVector, dt: DType, dateFormat: Option[String],
failOnInvalid: Boolean): ColumnVector = {
failOnInvalid: Boolean): ColumnVector = {
dateFormat match {
case None =>
// legacy behavior
withResource(input.strip()) { trimmed =>
GpuCast.castStringToDateAnsi(trimmed, ansiMode = false)
}
case Some(f) =>
jsonStringToDate(input, f, failOnInvalid &&
GpuOverrides.getTimeParserPolicy == ExceptionTimeParserPolicy)
withResource(input.strip()) { trimmed =>
jsonStringToDate(trimmed, f, failOnInvalid &&
GpuOverrides.getTimeParserPolicy == ExceptionTimeParserPolicy)
}
}
}

private def jsonStringToDate(input: ColumnVector, dateFormatPattern: String,
failOnInvalid: Boolean): ColumnVector = {
val regexRoot = dateFormatPattern
.replace("yyyy", raw"\d{4}")
.replace("MM", raw"\d{2}")
.replace("dd", raw"\d{2}")
.replace("MM", raw"\d{1,2}")
.replace("dd", raw"\d{1,2}")
val cudfFormat = DateUtils.toStrf(dateFormatPattern, parseString = true)
withResource(input.strip()) { input =>
GpuCast.convertDateOrNull(input, "^" + regexRoot + "$", cudfFormat, failOnInvalid)
}
GpuCast.convertDateOrNull(input, "^" + regexRoot + "$", cudfFormat, failOnInvalid)
}

/**
 * Tagging hook for the `timestampFormat` option.
 *
 * This shim's version does nothing: it neither reads `timestampFormat` nor updates `meta`.
 * NOTE(review): the surrounding code in this shim appears to apply the legacy timestamp
 * cast unconditionally, which would explain the absent check -- confirm against the file.
 *
 * @param meta            plan metadata that a tagging hook could update; not touched here
 * @param timestampFormat the configured timestamp format, if any; not inspected here
 */
def tagTimestampFormatSupport(
    meta: RapidsMeta[_, _, _],
    timestampFormat: Option[String]): Unit = ()

def castJsonStringToTimestamp(input: ColumnVector,
options: Map[String, String]): ColumnVector = {
options.get("timestampFormat") match {
case None =>
// legacy behavior
withResource(Scalar.fromString(" ")) { space =>
withResource(input.strip(space)) { trimmed =>
// from_json doesn't respect ansi mode
GpuCast.castStringToTimestamp(trimmed, ansiMode = false)
}
}
case Some("yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX]") =>
GpuCast.convertTimestampOrNull(input,
"^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}(\\.[0-9]{1,6})?Z?$", "%Y-%m-%d")
case other =>
// should be unreachable due to GpuOverrides checks
throw new IllegalStateException(s"Unsupported timestampFormat $other")
// legacy behavior
withResource(Scalar.fromString(" ")) { space =>
withResource(input.strip(space)) { trimmed =>
// from_json doesn't respect ansi mode
GpuCast.castStringToTimestamp(trimmed, ansiMode = false)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ object GpuJsonToStructsShim {
}
}


private def jsonStringToDate(input: ColumnVector, dateFormatPattern: String,
failOnInvalid: Boolean): ColumnVector = {
val regexRoot = dateFormatPattern
Expand All @@ -76,6 +77,16 @@ object GpuJsonToStructsShim {
GpuCast.convertDateOrNull(input, "^" + regexRoot + "$", cudfFormat, failOnInvalid)
}

/**
 * Flags `meta` as unable to run on the GPU when an unsupported `timestampFormat` is set.
 *
 * Two inputs are accepted without tagging: no format at all (`None`), and the exact
 * pattern `yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX]`. Any other value results in a call to
 * `meta.willNotWorkOnGpu` with a message that embeds the `Option` as rendered by
 * string interpolation (i.e. `Some(<format>)`).
 *
 * @param meta            plan metadata that receives the "will not work" reason
 * @param timestampFormat the configured timestamp format, if any
 */
def tagTimestampFormatSupport(
    meta: RapidsMeta[_, _, _],
    timestampFormat: Option[String]): Unit = {
  // The single explicit pattern that is accepted.
  val supportedFormat = "yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX]"
  val isUnsupported = timestampFormat.exists(_ != supportedFormat)
  if (isUnsupported) {
    meta.willNotWorkOnGpu(s"Unsupported timestampFormat ${timestampFormat}")
  }
}

def castJsonStringToTimestamp(input: ColumnVector,
options: Map[String, String]): ColumnVector = {
options.get("timestampFormat") match {
Expand Down

0 comments on commit c65b29f

Please sign in to comment.