Skip to content

Commit

Permalink
Add 330 shim and fix failures in test_basic_json_read
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed Dec 12, 2023
1 parent 01dd4bf commit 70011f8
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 25 deletions.
20 changes: 3 additions & 17 deletions integration_tests/src/main/python/json_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,18 +306,6 @@ def do_read(spark):

@approximate_float
@pytest.mark.parametrize('filename', [
'boolean.json',
pytest.param('boolean_invalid.json', marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/4779')),
'ints.json',
pytest.param('ints_invalid.json', marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/4793')),
'nan_and_inf.json',
pytest.param('nan_and_inf_strings.json', marks=pytest.mark.skipif(is_before_spark_330(), reason='https://issues.apache.org/jira/browse/SPARK-38060 fixed in Spark 3.3.0')),
'nan_and_inf_invalid.json',
'floats.json',
'floats_leading_zeros.json',
'floats_invalid.json',
pytest.param('floats_edge_cases.json', marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/4647')),
'decimals.json',
'dates.json',
'dates_invalid.json',
])
Expand Down Expand Up @@ -427,7 +415,7 @@ def test_json_read_valid_dates(std_input_path, filename, schema, read_func, ansi
'[1-3]{1,2}/[1-3]{1,2}/[1-9]{4}',
])
@pytest.mark.parametrize('schema', [StructType([StructField('value', DateType())])])
@pytest.mark.parametrize('date_format', ['', 'yyyy-MM-dd'] if is_before_spark_340 else json_supported_date_formats)
@pytest.mark.parametrize('date_format', ['', 'yyyy-MM-dd'] if is_before_spark_330 else json_supported_date_formats)
@pytest.mark.parametrize('ansi_enabled', [True, False])
@pytest.mark.parametrize('allow_numeric_leading_zeros', [True, False])
def test_json_read_generated_dates(spark_tmp_table_factory, spark_tmp_path, date_gen_pattern, schema, date_format, \
Expand All @@ -451,16 +439,14 @@ def test_json_read_generated_dates(spark_tmp_table_factory, spark_tmp_path, date
f = read_json_df(path, schema, spark_tmp_table_factory, options)
assert_gpu_and_cpu_are_equal_collect(f, conf = updated_conf)

## TODO fallback tests for unsupported date formats prior to spark 340

@approximate_float
@pytest.mark.parametrize('filename', [
'dates_invalid.json',
])
@pytest.mark.parametrize('schema', [_date_schema])
@pytest.mark.parametrize('read_func', [read_json_df, read_json_sql])
@pytest.mark.parametrize('ansi_enabled', ["true", "false"])
@pytest.mark.parametrize('date_format', ['', 'yyyy-MM-dd'] if is_before_spark_340 else json_supported_date_formats)
@pytest.mark.parametrize('date_format', ['', 'yyyy-MM-dd'] if is_before_spark_330 else json_supported_date_formats)
@pytest.mark.parametrize('time_parser_policy', [
pytest.param('LEGACY', marks=pytest.mark.allow_non_gpu('FileSourceScanExec')),
pytest.param('CORRECTED', marks=pytest.mark.allow_non_gpu(*not_utc_json_scan_allow)),
Expand Down Expand Up @@ -695,7 +681,7 @@ def test_from_json_struct_date_fallback_legacy(date_gen, date_format):
conf={"spark.rapids.sql.expression.JsonToStructs": True,
'spark.sql.legacy.timeParserPolicy': 'LEGACY'})

@pytest.mark.skipif(is_spark_340_or_later(), reason="We only support custom formats with Spark 340+")
@pytest.mark.skipif(is_spark_330_or_later(), reason="We only support custom formats with Spark 330+")
@allow_non_gpu('ProjectExec')
@pytest.mark.parametrize('date_gen', ["\"[1-8]{1}[0-9]{3}-[0-3]{1,2}-[0-3]{1,2}\""])
@pytest.mark.parametrize('date_format', [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,6 @@
{"spark": "322"}
{"spark": "323"}
{"spark": "324"}
{"spark": "330"}
{"spark": "330cdh"}
{"spark": "330db"}
{"spark": "331"}
{"spark": "332"}
{"spark": "332cdh"}
{"spark": "332db"}
{"spark": "333"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*** spark-rapids-shim-json-lines
{"spark": "330"}
{"spark": "330cdh"}
{"spark": "330db"}
{"spark": "331"}
{"spark": "332"}
{"spark": "332cdh"}
{"spark": "332db"}
{"spark": "333"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

import ai.rapids.cudf.{ColumnVector, DType, Scalar}
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.{DateUtils, GpuCast, GpuOverrides, RapidsMeta}

import org.apache.spark.sql.catalyst.json.GpuJsonUtils
import org.apache.spark.sql.rapids.ExceptionTimeParserPolicy

/**
 * Spark 3.3.0+ shim with GPU implementations of the date/timestamp string casts used when
 * reading JSON (both the `from_json` expression path and the JSON scan path).
 *
 * NOTE(review): per the shim-json-lines header this version is compiled for Spark 330–333
 * variants; the 340 behavior presumably lives in a sibling shim — confirm.
 */
object GpuJsonToStructsShim {

  /**
   * Tag-time check of whether the `dateFormat` read option of a `from_json` expression is
   * supported on the GPU. Intentionally a no-op here: no format is rejected, so this shim
   * presumably supports every format that reaches it — confirm against the GpuOverrides checks.
   */
  def tagDateFormatSupport(meta: RapidsMeta[_, _, _], dateFormat: Option[String]): Unit = {
  }

  /**
   * Cast a string column containing JSON date values to a DATE column, honoring the
   * `dateFormat` read option when present.
   *
   * - No `dateFormat` option: legacy behavior — strip surrounding spaces, then delegate to
   *   the generic string-to-date cast.
   * - `dateFormat` present: parse with that pattern. `failOnInvalid` is hard-coded to false
   *   because `from_json` does not respect the EXCEPTION time-parser policy (invalid values
   *   become nulls rather than errors).
   */
  def castJsonStringToDate(input: ColumnVector, options: Map[String, String]): ColumnVector = {
    GpuJsonUtils.optionalDateFormatInRead(options) match {
      case None =>
        // legacy behavior
        withResource(Scalar.fromString(" ")) { space =>
          withResource(input.strip(space)) { trimmed =>
            GpuCast.castStringToDate(trimmed)
          }
        }
      case Some(f) =>
        // from_json does not respect EXCEPTION policy
        jsonStringToDate(input, f, failOnInvalid = false)
    }
  }

  /**
   * Tag-time check of whether the `dateFormat` option is supported for the JSON scan path.
   * Intentionally a no-op here, mirroring [[tagDateFormatSupport]] — presumably all formats
   * reaching this shim are supported; confirm against the scan-side overrides.
   */
  def tagDateFormatSupportFromScan(meta: RapidsMeta[_, _, _], dateFormat: Option[String]): Unit = {
  }

  /**
   * Cast a string column to DATE for the JSON scan path.
   *
   * - No `dateFormat`: legacy behavior — strip whitespace, then cast with ANSI disabled
   *   (invalid entries become null).
   * - `dateFormat` present: parse with that pattern; invalid values only raise when the caller
   *   requested it AND the session's time parser policy is EXCEPTION.
   *
   * NOTE(review): the `dt` parameter is unused in this shim — presumably kept so the signature
   * matches other shim versions; verify before removing.
   */
  def castJsonStringToDateFromScan(input: ColumnVector, dt: DType, dateFormat: Option[String],
      failOnInvalid: Boolean): ColumnVector = {
    dateFormat match {
      case None =>
        // legacy behavior
        withResource(input.strip()) { trimmed =>
          GpuCast.castStringToDateAnsi(trimmed, ansiMode = false)
        }
      case Some(f) =>
        jsonStringToDate(input, f, failOnInvalid &&
          GpuOverrides.getTimeParserPolicy == ExceptionTimeParserPolicy)
    }
  }

  /**
   * Shared date-parsing helper: translate the Spark date pattern into (1) an anchored
   * validation regex (yyyy/MM/dd become digit classes) and (2) a cudf strftime-style format,
   * then parse the stripped input; non-matching rows become null (or raise when
   * `failOnInvalid` is true).
   *
   * NOTE(review): only the yyyy/MM/dd tokens are rewritten into the regex — presumably the
   * tag-time checks restrict patterns to those tokens plus literal separators; confirm.
   */
  private def jsonStringToDate(input: ColumnVector, dateFormatPattern: String,
      failOnInvalid: Boolean): ColumnVector = {
    val regexRoot = dateFormatPattern
      .replace("yyyy", raw"\d{4}")
      .replace("MM", raw"\d{2}")
      .replace("dd", raw"\d{2}")
    val cudfFormat = DateUtils.toStrf(dateFormatPattern, parseString = true)
    withResource(input.strip()) { input =>
      GpuCast.convertDateOrNull(input, "^" + regexRoot + "$", cudfFormat, failOnInvalid)
    }
  }

  /**
   * Cast a string column containing JSON timestamp values to a TIMESTAMP column for the
   * `from_json` path.
   *
   * - No `timestampFormat` option: legacy behavior — strip spaces and use the generic
   *   string-to-timestamp cast with ANSI disabled (from_json ignores ANSI mode).
   * - The one explicitly supported pattern is matched literally; anything else should have
   *   been rejected at tag time, so reaching the default case is a bug.
   *
   * NOTE(review): for the supported pattern the cudf output format is only "%Y-%m-%d"
   * even though the regex accepts a full time component — confirm whether dropping the
   * time-of-day portion is intentional here.
   */
  def castJsonStringToTimestamp(input: ColumnVector,
      options: Map[String, String]): ColumnVector = {
    options.get("timestampFormat") match {
      case None =>
        // legacy behavior
        withResource(Scalar.fromString(" ")) { space =>
          withResource(input.strip(space)) { trimmed =>
            // from_json doesn't respect ansi mode
            GpuCast.castStringToTimestamp(trimmed, ansiMode = false)
          }
        }
      case Some("yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX]") =>
        GpuCast.convertTimestampOrNull(input,
          "^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}(\\.[0-9]{1,6})?Z?$", "%Y-%m-%d")
      case other =>
        // should be unreachable due to GpuOverrides checks
        throw new IllegalStateException(s"Unsupported timestampFormat $other")
    }
  }
}

0 comments on commit 70011f8

Please sign in to comment.