From 8235c95f0733ca179053cb5c775a86cc47b753f4 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 20 Jul 2023 13:11:22 +0800 Subject: [PATCH 01/27] WIP: Support parse_url Signed-off-by: Haoyang Li --- .../advanced_configs.md | 1 + docs/supported_ops.md | 765 ++++++++++-------- integration_tests/src/main/python/url_test.py | 121 +++ .../nvidia/spark/rapids/GpuOverrides.scala | 13 + .../spark/sql/rapids/urlFunctions.scala | 211 +++++ .../sql/rapids/shims/RapidsErrorUtils.scala | 10 + .../sql/rapids/shims/RapidsErrorUtils.scala | 12 + .../sql/rapids/shims/RapidsErrorUtils.scala | 12 + .../sql/rapids/shims/RapidsErrorUtils.scala | 12 + .../sql/rapids/shims/RapidsErrorUtils.scala | 12 + .../sql/rapids/shims/RapidsErrorUtils.scala | 14 + .../spark/rapids/UrlFunctionsSuite.scala | 246 ++++++ tools/generated_files/operatorsScore.csv | 1 + tools/generated_files/supportedExprs.csv | 4 + 14 files changed, 1096 insertions(+), 338 deletions(-) create mode 100644 integration_tests/src/main/python/url_test.py create mode 100644 sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala create mode 100644 tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala diff --git a/docs/additional-functionality/advanced_configs.md b/docs/additional-functionality/advanced_configs.md index 39e20f1d1d2..8fb3e91f65f 100644 --- a/docs/additional-functionality/advanced_configs.md +++ b/docs/additional-functionality/advanced_configs.md @@ -299,6 +299,7 @@ Name | SQL Function(s) | Description | Default Value | Notes spark.rapids.sql.expression.NthValue|`nth_value`|nth window operator|true|None| spark.rapids.sql.expression.OctetLength|`octet_length`|The byte length of string data|true|None| spark.rapids.sql.expression.Or|`or`|Logical OR|true|None| +spark.rapids.sql.expression.ParseUrl|`parse_url`|Extracts a part from a URL|true|None| spark.rapids.sql.expression.PercentRank|`percent_rank`|Window function that returns the percent rank value within the aggregation window|true|None| spark.rapids.sql.expression.Pmod|`pmod`|Pmod|true|None| spark.rapids.sql.expression.PosExplode|`posexplode_outer`, `posexplode`|Given an input array produces a sequence of rows for each value in the array|true|None| diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 1784c11892a..953597c8016 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -10530,6 +10530,95 @@ are limited. +ParseUrl +`parse_url` +Extracts a part from a URL +None +project +url + + + + + + + + + +S + + + + + + + + + + +partToExtract + + + + + + + + + +PS
Literal value only
+ + + + + + + + + + +key + + + + + + + + + +PS
Literal value only
+ + + + + + + + + + +result + + + + + + + + + +S + + + + + + + + + + PercentRank `percent_rank` Window function that returns the percent rank value within the aggregation window @@ -10645,6 +10734,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + PosExplode `posexplode_outer`, `posexplode` Given an input array produces a sequence of rows for each value in the array @@ -10824,32 +10939,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - PreciseTimestampConversion Expression used internally to convert the TimestampType to Long and back without losing precision, i.e. in microseconds. Used in time windowing @@ -11120,6 +11209,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Quarter `quarter` Returns the quarter of the year for date, in the range 1 to 4 @@ -11235,32 +11350,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - RaiseError `raise_error` Throw an exception @@ -11491,6 +11580,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + RegExpExtractAll `regexp_extract_all` Extract all strings matching a regular expression corresponding to the regex group index @@ -11690,32 +11805,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Remainder `%`, `mod` Remainder or modulo @@ -11878,6 +11967,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Rint `rint` Rounds up a double value to the nearest double equal to an integer @@ -12062,32 +12177,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - ScalaUDF User Defined Function, the UDF can choose to implement a RAPIDS accelerated interface to get better performance. @@ -12318,6 +12407,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + ShiftLeft `shiftleft` Bitwise shift left (<<) @@ -12454,32 +12569,6 @@ are limited. 
-Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - ShiftRightUnsigned `shiftrightunsigned` Bitwise unsigned shift right (>>>) @@ -12685,6 +12774,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Sinh `sinh` Hyperbolic sine @@ -12822,32 +12937,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - SortArray `sort_array` Returns a sorted array with the input array and the ascending / descending order @@ -13057,6 +13146,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Sqrt `sqrt` Square root @@ -13215,32 +13330,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - StringInstr `instr` Instr string operator @@ -13487,6 +13576,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + StringRPad `rpad` Pad a string on the right @@ -13576,32 +13691,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - StringRepeat `repeat` StringRepeat operator that repeats the given strings with numbers of times given by repeatTimes @@ -13848,6 +13937,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + StringToMap `str_to_map` Creates a map after splitting the input string into pairs of key-value strings @@ -13937,32 +14052,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - StringTranslate `translate` StringTranslate operator @@ -14256,6 +14345,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Substring `substr`, `substring` Substring operator @@ -14345,32 +14460,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - SubstringIndex `substring_index` substring_index operator @@ -14682,6 +14771,32 @@ are limited. 
+Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Tanh `tanh` Hyperbolic tangent @@ -14772,32 +14887,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - TimeAdd Adds interval to timestamp @@ -15096,6 +15185,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + TransformValues `transform_values` Transform values in a map using a transform function @@ -15164,32 +15279,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - UnaryMinus `negative` Negate a numeric value @@ -15490,6 +15579,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + UnscaledValue Convert a Decimal to an unscaled long value for some aggregation optimizations @@ -15537,32 +15652,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Upper `upper`, `ucase` String uppercase operator diff --git a/integration_tests/src/main/python/url_test.py b/integration_tests/src/main/python/url_test.py new file mode 100644 index 00000000000..ac7b56f35a4 --- /dev/null +++ b/integration_tests/src/main/python/url_test.py @@ -0,0 +1,121 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
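+
+# Overview: these tests compare GPU parse_url results against CPU Spark for
+# every supported partToExtract value (HOST, PATH, QUERY, REF, PROTOCOL, FILE,
+# AUTHORITY, USERINFO), for the three-argument QUERY-with-key form, and for
+# the error paths. A string drawn from url_pattern below has the shape
+# protocol://host[:port][/path...][?k=v][#ref], for example (illustrative
+# only): 'http://ab.cd.ef:80/a/b?k1=v2#r3', so each extractable part is
+# exercised by the generated data.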
+ +import pytest + +from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error +from data_gen import * +from marks import * +from pyspark.sql.types import * +import pyspark.sql.functions as f +from spark_session import is_before_spark_320 + +# regex to generate limit length urls with HOST, PATH, QUERY, REF, PROTOCOL, FILE, AUTHORITY, USERINFO +url_pattern = r'((http|https|ftp)://)(([a-zA-Z][a-zA-Z0-9]{0,2}\.){0,3}([a-zA-Z][a-zA-Z0-9]{0,2})\.([a-zA-Z][a-zA-Z0-9]{0,2}))' \ + r'(:[0-9]{1,3}){0,1}(/[a-zA-Z0-9]{1,3}){0,3}(\?[a-zA-Z0-9]{1,3}=[a-zA-Z0-9]{1,3}){0,1}(#([a-zA-Z0-9]{1,3})){0,1}' + +url_pattern_with_key = r'((http|https|ftp|file)://)(([a-z]{1,3}\.){0,3}([a-z]{1,3})\.([a-z]{1,3}))' \ + r'(:[0-9]{1,3}){0,1}(/[a-z]{1,3}){0,3}(\?key=[a-z]{1,3}){0,1}(#([a-z]{1,3})){0,1}' + +url_gen = StringGen(url_pattern) + +def test_parse_url_host(): + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, url_gen).selectExpr( + "a", + "parse_url(a, 'HOST')" + )) + +def test_parse_url_path(): + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, url_gen).selectExpr( + "a", + "parse_url(a, 'PATH')" + )) + +def test_parse_url_query(): + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, url_gen).selectExpr( + "a", + "parse_url(a, 'QUERY')" + )) + +def test_parse_url_ref(): + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, url_gen).selectExpr( + "a", + "parse_url(a, 'REF')" + )) + +def test_parse_url_protocol(): + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, url_gen).selectExpr( + "a", + "parse_url(a, 'PROTOCOL')" + )) + +def test_parse_url_file(): + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, url_gen).selectExpr( + "a", + "parse_url(a, 'FILE')" + )) + +def test_parse_url_authority(): + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, url_gen).selectExpr( + "a", + "parse_url(a, 'AUTHORITY')" + )) + +def test_parse_url_userinfo(): + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, url_gen).selectExpr( + "a", + "parse_url(a, 'USERINFO')" + )) + +def test_parse_url_with_no_query_key(): + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, url_gen, length=100).selectExpr( + "a", + "parse_url(a, 'HOST', '')", + "parse_url(a, 'PATH', '')", + "parse_url(a, 'REF', '')", + "parse_url(a, 'PROTOCOL', '')", + "parse_url(a, 'FILE', '')", + "parse_url(a, 'AUTHORITY', '')", + "parse_url(a, 'USERINFO', '')" + )) + +def test_parse_url_with_query_key(): + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, StringGen(url_pattern_with_key)).selectExpr( + "a", + "parse_url(a, 'QUERY', 'key')" + )) + +def test_parse_url_invalid_failonerror(): + assert_gpu_and_cpu_error( + lambda spark : unary_op_df(spark, StringGen()).selectExpr( + "a","parse_url(a, 'USERINFO')").collect(), + conf={'spark.sql.ansi.enabled': 'true'}, + error_message='IllegalArgumentException' if is_before_spark_320() else 'URISyntaxException') + +def test_parse_url_too_many_args(): + assert_gpu_and_cpu_error( + lambda spark : unary_op_df(spark, StringGen()).selectExpr( + "a","parse_url(a, 'USERINFO', 'key', 'value')").collect(), + conf={}, + error_message='parse_url function requires two or three arguments') \ No newline at end of file diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 
387f5f21645..a724afb45b8 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -3125,6 +3125,19 @@ object GpuOverrides extends Logging { ParamCheck("regexp", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), ParamCheck("idx", TypeSig.lit(TypeEnum.INT), TypeSig.INT))), (a, conf, p, r) => new GpuRegExpExtractAllMeta(a, conf, p, r)), + expr[ParseUrl]( + "Extracts a part from a URL", + ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, + Seq(ParamCheck("url", TypeSig.STRING, TypeSig.STRING), + ParamCheck("partToExtract", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING)), + // Should really be an OptionalParam + Some(RepeatingParamCheck("key", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), + (a, conf, p, r) => new ExprMeta[ParseUrl](a, conf, p, r) { + val failOnError = SQLConf.get.ansiEnabled + override def convertToGpu(): GpuExpression = { + GpuParseUrl(childExprs.map(_.convertToGpu()), failOnError) + } + }), expr[Length]( "String character length or binary byte length", ExprChecks.unaryProject(TypeSig.INT, TypeSig.INT, diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala new file mode 100644 index 00000000000..d1dbcd3a086 --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala @@ -0,0 +1,211 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.rapids + +import java.net.URISyntaxException + +import ai.rapids.cudf.{ColumnVector, DType, RegexProgram, Scalar, Table} +import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.Arm._ +import com.nvidia.spark.rapids.RapidsPluginImplicits._ +import com.nvidia.spark.rapids.shims.ShimExpression + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.rapids.shims.RapidsErrorUtils +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.unsafe.types.UTF8String + +object GpuParseUrl { + private val HOST = "HOST" + private val PATH = "PATH" + private val QUERY = "QUERY" + private val REF = "REF" + private val PROTOCOL = "PROTOCOL" + private val FILE = "FILE" + private val AUTHORITY = "AUTHORITY" + private val USERINFO = "USERINFO" + private val REGEXPREFIX = """(&|^|\?)""" + private val REGEXSUBFIX = "=([^&]*)" +} + +case class GpuParseUrl(children: Seq[Expression], + failOnErrorOverride: Boolean = SQLConf.get.ansiEnabled) + extends GpuExpression with ShimExpression with ExpectsInputTypes { + + def this(children: Seq[Expression]) = this(children, SQLConf.get.ansiEnabled) + + override def nullable: Boolean = true + override def inputTypes: Seq[DataType] = Seq.fill(children.size)(StringType) + override def dataType: DataType = StringType + override def prettyName: String = "parse_url" + + import GpuParseUrl._ + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.size > 3 || children.size < 2) { + RapidsErrorUtils.parseUrlWrongNumArgs(children.size) + } else { + super[ExpectsInputTypes].checkInputDataTypes() + } + } + + private def getPattern(key: UTF8String): RegexProgram = { + val regex = REGEXPREFIX + key.toString + REGEXSUBFIX + new RegexProgram(regex) + } + + private def reValid(url: ColumnVector): ColumnVector = { + // TODO: Validite the url + val regex = """^[^ ]*$""" + val prog = new RegexProgram(regex) + withResource(url.matchesRe(prog)) { isMatch => + if (failOnErrorOverride) { + withResource(isMatch.all()) { allMatch => + if (!allMatch.getBoolean) { + val invalidUrl = UTF8String.fromString(url.toString()) + val exception = new URISyntaxException("", "") + throw RapidsErrorUtils.invalidUrlException(invalidUrl, exception) + } + } + } + withResource(Scalar.fromNull(DType.STRING)) { nullScalar => + isMatch.ifElse(url, nullScalar) + } + } + } + + private def reMatch(url: ColumnVector, partToExtract: String): ColumnVector = { + val regex = """^(([^:/?#]+):)(//((([^:]*:?[^\@]*)\@)?(\[[0-9A-Za-z%.:]*\]|[^/?#:]*)""" + + """(:[0-9]+)?))?(([^?#]*)(\?([^#]*))?)(#(.*))?""" + val prog = new RegexProgram(regex) + withResource(url.extractRe(prog)) { table: Table => + partToExtract match { + case HOST => table.getColumn(6).incRefCount() + case PATH => table.getColumn(9).incRefCount() + case QUERY => table.getColumn(10).incRefCount() + case REF => table.getColumn(12).incRefCount() + case PROTOCOL => table.getColumn(1).incRefCount() + case FILE => table.getColumn(8).incRefCount() + case AUTHORITY => table.getColumn(3).incRefCount() + case USERINFO => table.getColumn(5).incRefCount() + case _ => throw new IllegalArgumentException(s"Invalid partToExtract: $partToExtract") + } + } + } + + private def emptyToNulls(cv: ColumnVector): ColumnVector = { + withResource(ColumnVector.fromStrings("")) { empty => + withResource(ColumnVector.fromStrings(null)) { 
nulls => + cv.findAndReplaceAll(empty, nulls) + } + } + } + + // private def isHost(cv: ColumnVector): ColumnVector = { + // // TODO: Valid if it is a valid host name, including ipv4, ipv6 and hostname + // cv + // } + + def doColumnar(numRows: Int, url: GpuScalar, partToExtract: GpuScalar): ColumnVector = { + withResource(GpuColumnVector.from(url, numRows, StringType)) { urlCol => + doColumnar(urlCol, partToExtract) + } + } + + def doColumnar(url: GpuColumnVector, partToExtract: GpuScalar): ColumnVector = { + val valid = reValid(url.getBase) + val part = partToExtract.getValue.asInstanceOf[UTF8String].toString + val matched = withResource(valid) { _ => + reMatch(valid, part) + } + if (part == HOST) { + // withResource(matched) { _ => + // isHost(matched) + // } + withResource(matched) { _ => + emptyToNulls(matched) + } + } else if (part == QUERY || part == REF) { + val resWithNulls = withResource(matched) { _ => + emptyToNulls(matched) + } + withResource(resWithNulls) { _ => + resWithNulls.substring(1) + } + } else if (part == PATH || part == FILE) { + matched + } else { + withResource(matched) { _ => + emptyToNulls(matched) + } + } + } + + def doColumnar(url: GpuColumnVector, partToExtract: GpuScalar, key: GpuScalar): ColumnVector = { + val query = partToExtract.getValue.asInstanceOf[UTF8String].toString + if (query != QUERY) { + // return a null columnvector + return ColumnVector.fromStrings(null, null) + } + val matched = reMatch(url.getBase, query) + val keyStr = key.getValue.asInstanceOf[UTF8String] + val queryValue = withResource(matched) { _ => + withResource(matched.extractRe(getPattern(keyStr))) { table: Table => + table.getColumn(1).incRefCount() + } + } + withResource(queryValue) { _ => + emptyToNulls(queryValue) + } + } + + override def columnarEval(batch: ColumnarBatch): Any = { + if (children.size == 2) { + val Seq(url, partToExtract) = children + withResourceIfAllowed(url.columnarEval(batch)) { val0 => + withResourceIfAllowed(partToExtract.columnarEval(batch)) { val1 => + (val0, val1) match { + case (v0: GpuColumnVector, v1: GpuScalar) => + GpuColumnVector.from(doColumnar(v0, v1), dataType) + case _ => + throw new UnsupportedOperationException(s"Cannot columnar evaluate expression: $this") + } + } + } + } else { + // 3-arg, i.e. 
QUERY with key + assert(children.size == 3) + val Seq(url, partToExtract, key) = children + withResourceIfAllowed(url.columnarEval(batch)) { val0 => + withResourceIfAllowed(partToExtract.columnarEval(batch)) { val1 => + withResourceIfAllowed(key.columnarEval(batch)) { val2 => + (val0, val1, val2) match { + case (v0: GpuColumnVector, v1: GpuScalar, v2: GpuScalar) => + GpuColumnVector.from(doColumnar(v0, v1, v2), dataType) + case _ => + throw new + UnsupportedOperationException(s"Cannot columnar evaluate expression: $this") + } + } + } + } + } + } +} \ No newline at end of file diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index f23229e0956..9259f0d3e90 100644 --- a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -23,8 +23,10 @@ package org.apache.spark.sql.rapids.shims import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.types.{DataType, Decimal, DecimalType} +import org.apache.spark.unsafe.types.UTF8String object RapidsErrorUtils { def invalidArrayIndexError(index: Int, numElements: Int, @@ -81,4 +83,12 @@ object RapidsErrorUtils { def tableIdentifierExistsError(tableIdentifier: TableIdentifier): Throwable = { throw new AnalysisException(s"$tableIdentifier already exists.") } + + def parseUrlWrongNumArgs(actual: Int): TypeCheckResult = { + TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments") + } + + def invalidUrlException(url: UFT8String, e: Throwable): Throwable = { + new IllegalArgumentException(s"Find an invaild url string ${url.toString}", e) + } } diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index b301397255a..6a569d98248 100644 --- a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -24,10 +24,14 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.shims +import java.net.URISyntaxException + import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.types.{DataType, Decimal, DecimalType} +import org.apache.spark.unsafe.types.UTF8String object RapidsErrorUtils { def invalidArrayIndexError(index: Int, numElements: Int, @@ -85,4 +89,12 @@ object RapidsErrorUtils { def tableIdentifierExistsError(tableIdentifier: TableIdentifier): Throwable = { QueryCompilationErrors.tableIdentifierExistsError(tableIdentifier) } + + def parseUrlWrongNumArgs(actual: Int): TypeCheckResult = { + TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments") + } + + def invalidUrlException(url: UTF8String, e: URISyntaxException): Throwable = { + QueryExecutionErrors.invalidUrlError(url, e) + } } diff --git 
a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index 6fa5b8350a5..b5997e4e6db 100644 --- a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -19,10 +19,14 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.shims +import java.net.URISyntaxException + import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.types.{DataType, Decimal, DecimalType} +import org.apache.spark.unsafe.types.UTF8String object RapidsErrorUtils { def invalidArrayIndexError(index: Int, numElements: Int, @@ -83,4 +87,12 @@ object RapidsErrorUtils { def tableIdentifierExistsError(tableIdentifier: TableIdentifier): Throwable = { QueryCompilationErrors.tableIdentifierExistsError(tableIdentifier) } + + def parseUrlWrongNumArgs(actual: Int): TypeCheckResult = { + TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments") + } + + def invalidUrlException(url: UTF8String, e: URISyntaxException): Throwable = { + QueryExecutionErrors.invalidUrlError(url, e) + } } diff --git a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index 4b81f540e40..a0e827150d5 100644 --- a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -23,11 +23,15 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.shims +import java.net.URISyntaxException + import org.apache.spark.SparkDateTimeException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, Decimal, DecimalType} +import org.apache.spark.unsafe.types.UTF8String object RapidsErrorUtils extends RapidsErrorUtilsFor330plus { @@ -79,4 +83,12 @@ object RapidsErrorUtils extends RapidsErrorUtilsFor330plus { def sqlArrayIndexNotStartAtOneError(): RuntimeException = { new ArrayIndexOutOfBoundsException("SQL array indices start at 1") } + + def parseUrlWrongNumArgs(actual: Int): TypeCheckResult = { + TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments") + } + + def invalidUrlException(url: UTF8String, e: URISyntaxException): Throwable = { + QueryExecutionErrors.invalidUrlError(url, e) + } } diff --git a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index 3585910993d..43be246548a 100644 --- a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -20,11 +20,15 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.shims +import 
java.net.URISyntaxException + import org.apache.spark.SparkDateTimeException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, Decimal, DecimalType} +import org.apache.spark.unsafe.types.UTF8String object RapidsErrorUtils extends RapidsErrorUtilsFor330plus { @@ -87,4 +91,12 @@ object RapidsErrorUtils extends RapidsErrorUtilsFor330plus { override def intervalDivByZeroError(origin: Origin): ArithmeticException = { QueryExecutionErrors.intervalDividedByZeroError(origin.context) } + + def parseUrlWrongNumArgs(actual: Int): TypeCheckResult = { + TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments") + } + + def invalidUrlException(url: UTF8String, e: URISyntaxException): Throwable = { + QueryExecutionErrors.invalidUrlError(url, e) + } } diff --git a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index f0b74c1c276..6dde4e0a8f9 100644 --- a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -20,11 +20,15 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.shims +import java.net.URISyntaxException + import org.apache.spark.SparkDateTimeException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, Decimal, DecimalType} +import org.apache.spark.unsafe.types.UTF8String object RapidsErrorUtils extends RapidsErrorUtilsFor330plus { @@ -87,4 +91,14 @@ object RapidsErrorUtils extends RapidsErrorUtilsFor330plus { override def intervalDivByZeroError(origin: Origin): ArithmeticException = { QueryExecutionErrors.intervalDividedByZeroError(origin.context) } + + def parseUrlWrongNumArgs(actual: Int): Throwable = { + throw QueryCompilationErrors.wrongNumArgsError( + "parse_url", Seq("[2, 3]"), actualNumber + ) + } + + def invalidUrlException(url: UTF8String, e: URISyntaxException): Throwable = { + QueryExecutionErrors.invalidUrlError(url, e) + } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala new file mode 100644 index 00000000000..7e97d266fb5 --- /dev/null +++ b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala @@ -0,0 +1,246 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids + +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.unsafe.types.UTF8String + +class UrlFunctionsSuite extends SparkQueryCompareTestSuite { + def validUrlEdgeCasesDf(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + // [In search of the perfect URL validation regex](https://mathiasbynens.be/demo/url-regex) + Seq[String]( + "http://foo.com/blah_blah", + "http://foo.com/blah_blah/", + "http://foo.com/blah_blah_(wikipedia)", + "http://foo.com/blah_blah_(wikipedia)_(again)", + "http://www.example.com/wpstyle/?p=364", + "https://www.example.com/foo/?bar=baz&inga=42&quux", + // "http://✪df.ws/123", + "http://userid:password@example.com:8080", + "http://userid:password@example.com:8080/", + "http://userid:password@example.com", + "http://userid:password@example.com/", + "http://142.42.1.1/", + "http://142.42.1.1:8080/", + // "http://➡.ws/䨹", + // "http://⌘.ws", + // "http://⌘.ws/", + "http://foo.com/blah_(wikipedia)#cite-1", + "http://foo.com/blah_(wikipedia)_blah#cite-1", + // "http://foo.com/unicode_(✪)_in_parens", + "http://foo.com/(something)?after=parens", + // "http://☺.damowmow.com/", + "http://code.google.com/events/#&product=browser", + "http://j.mp", + "ftp://foo.bar/baz", + "http://foo.bar/?q=Test%20URL-encoded%20stuff", + // "http://مثال.إختبار", + // "http://例子.测试", + // "http://उदाहरण.परीक्षा", + "http://-.~_!$&'()*+,;=:%40:80%2f::::::@example.com", + "http://1337.net", + "http://a.b-c.de", + "http://223.255.255.254", + // "https://foo_bar.example.com/", + // "http://", + // "http://.", + // "http://..", + // "http://../", + "http://?", + "http://??", + "http://??/", + "http://#", + // "http://##", + // "http://##/", + "http://foo.bar?q=Spaces should be encoded", + "//", + // "//a", + // "///a", + // "///", + "http:///a", + // "foo.com", + "rdar://1234", + "h://test", + "http:// shouldfail.com", + ":// should fail", + "http://foo.bar/foo(bar)baz quux", + "ftps://foo.bar/", + // "http://-error-.invalid/", + "http://a.b--c.de/", + // "http://-a.b.co", + // "http://a.b-.co", + "http://0.0.0.0", + "http://10.1.1.0", + "http://10.1.1.255", + "http://224.1.1.1", + // "http://1.1.1.1.1", + // "http://123.123.123", + "http://3628126748", + // "http://.www.foo.bar/", + "http://www.foo.bar./", + // "http://.www.foo.bar./", + "http://10.1.1.1", + "http://10.1.1.254" + ).toDF("urls") + } + + def urlCasesFromSpark(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + Seq[String]( + "http://userinfo@spark.apache.org/path?query=1#Ref", + "https://use%20r:pas%20s@example.com/dir%20/pa%20th.HTML?query=x%20y&q2=2#Ref%20two", + "http://user:pass@host", + "http://user:pass@host/", + "http://user:pass@host/?#", + "http://user:pass@host/file;param?query;p2" + ).toDF("urls") + } + + def urlCasesFromSparkInvalid(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + Seq[String]( + "inva lid://user:pass@host/file;param?query;p2" + ).toDF("urls") + } + + def urlCasesFromJavaUriLib(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + Seq[String]( + "ftp://ftp.is.co.za/rfc/rfc1808.txt", + "http://www.math.uio.no/faq/compression-faq/part1.html", + "telnet://melvyl.ucop.edu/", + "http://www.w3.org/Addressing/", + "ftp://ds.internic.net/rfc/", + "http://www.ics.uci.edu/pub/ietf/uri/historical.html#WARNING", + "http://www.ics.uci.edu/pub/ietf/uri/#Related", + "http://[FEDC:BA98:7654:3210:FEDC:BA98:7654:3210]:80/index.html", + 
"http://[FEDC:BA98:7654:3210:FEDC:BA98:7654:10%12]:80/index.html", + "http://[1080:0:0:0:8:800:200C:417A]/index.html", + "http://[1080:0:0:0:8:800:200C:417A%1]/index.html", + "http://[3ffe:2a00:100:7031::1]", + "http://[1080::8:800:200C:417A]/foo", + "http://[::192.9.5.5]/ipng", + "http://[::192.9.5.5%interface]/ipng", + "http://[::FFFF:129.144.52.38]:80/index.html", + "http://[2010:836B:4179::836B:4179]", + "http://[FF01::101]", + "http://[::1]", + "http://[::]", + "http://[::%hme0]", + "http://[0:0:0:0:0:0:13.1.68.3]", + "http://[0:0:0:0:0:FFFF:129.144.52.38]", + "http://[0:0:0:0:0:FFFF:129.144.52.38%33]", + "http://[0:0:0:0:0:ffff:1.2.3.4]", + "http://[::13.1.68.3]" + ).toDF("urls") + } + + def urlWithQueryKey(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + Seq[String]( + "http://foo.com/blah_blah?foo=bar&baz=blah#vertical-bar" + ).toDF("urls") + } + + def utf8UrlCases(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + Seq[String]( + "http://✪df.ws/123", + "http://➡.ws/䨹", + "http://⌘.ws", + "http://⌘.ws/", + "http://foo.com/unicode_(✪)_in_parens", + "http://☺.damowmow.com/", + "http://مثال.إختبار", + "http://例子.测试", + "http://उदाहरण.परीक्षा" + ).map(UTF8String.fromString(_).toString()).toDF("urls") + } + + def unsupportedUrlCases(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + Seq[String]( + "https://foo_bar.example.com/", + "http://", + "http://.", + "http://..", + "http://../", + "http://##", + "http://##/", + "//a", + "///a", + "///", + "foo.com", + "http://-error-.invalid/", + "http://-a.b.co", + "http://a.b-.co", + "http://1.1.1.1.1", + "http://123.123.123", + "http://.www.foo.bar/", + "http://.www.foo.bar./" + ).toDF("urls") + } + + def parseUrls(frame: DataFrame): DataFrame = { + frame.selectExpr( + "urls", + "parse_url(urls, 'HOST') as HOST", + "parse_url(urls, 'PATH') as PATH", + "parse_url(urls, 'QUERY') as QUERY", + "parse_url(urls, 'REF') as REF", + "parse_url(urls, 'PROTOCOL') as PROTOCOL", + "parse_url(urls, 'FILE') as FILE", + "parse_url(urls, 'AUTHORITY') as AUTHORITY", + "parse_url(urls, 'USERINFO') as USERINFO") + } + + // def disableGpuRegex(): SparkConf = { + // new SparkConf() + // .set("spark.rapids.sql.regexp.enabled", "false") + // } + + testSparkResultsAreEqual("Test parse_url edge cases from internet", validUrlEdgeCasesDf) { + parseUrls + } + + testSparkResultsAreEqual("Test parse_url cases from Spark", urlCasesFromSpark) { + parseUrls + } + + testSparkResultsAreEqual("Test parse_url invalid cases from Spark", urlCasesFromSparkInvalid) { + parseUrls + } + + testSparkResultsAreEqual("Test parse_url cases from java URI library", urlCasesFromJavaUriLib) { + parseUrls + } + + // testSparkResultsAreEqual("Test parse_url utf-8 cases", utf8UrlCases) { + // parseUrls + // } + + // testSparkResultsAreEqual("Test parse_url unsupport cases", unsupportedUrlCases) { + // parseUrls + // } + + testSparkResultsAreEqual("Test parse_url with query and key", urlWithQueryKey) { + frame => frame.selectExpr( + "urls", + "parse_url(urls, 'QUERY', 'foo') as QUERY") + } +} \ No newline at end of file diff --git a/tools/generated_files/operatorsScore.csv b/tools/generated_files/operatorsScore.csv index 532ec2d9b02..6fe3cb47a17 100644 --- a/tools/generated_files/operatorsScore.csv +++ b/tools/generated_files/operatorsScore.csv @@ -178,6 +178,7 @@ Not,4 NthValue,4 OctetLength,4 Or,4 +ParseUrl,4 PercentRank,4 PivotFirst,4 Pmod,4 diff --git a/tools/generated_files/supportedExprs.csv 
b/tools/generated_files/supportedExprs.csv index 391e4c199bd..4f01e3b7379 100644 --- a/tools/generated_files/supportedExprs.csv +++ b/tools/generated_files/supportedExprs.csv @@ -373,6 +373,10 @@ Or,S,`or`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA, Or,S,`or`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA Or,S,`or`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA Or,S,`or`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +ParseUrl,S,`parse_url`,None,project,url,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA +ParseUrl,S,`parse_url`,None,project,partToExtract,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA +ParseUrl,S,`parse_url`,None,project,key,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA +ParseUrl,S,`parse_url`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA PercentRank,S,`percent_rank`,None,window,ordering,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NS,NS,NS PercentRank,S,`percent_rank`,None,window,result,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA Pmod,S,`pmod`,None,project,lhs,NA,S,S,S,S,S,S,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA From 729fe35a5067fad89d5b872d4aebd598e9e829dc Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 20 Jul 2023 20:56:50 +0800 Subject: [PATCH 02/27] fix build failures --- .../scala/org/apache/spark/sql/rapids/urlFunctions.scala | 8 +++++--- .../apache/spark/sql/rapids/shims/RapidsErrorUtils.scala | 6 +++--- .../apache/spark/sql/rapids/shims/RapidsErrorUtils.scala | 4 ++-- .../apache/spark/sql/rapids/shims/RapidsErrorUtils.scala | 4 ++-- .../apache/spark/sql/rapids/shims/RapidsErrorUtils.scala | 4 ++-- .../apache/spark/sql/rapids/shims/RapidsErrorUtils.scala | 4 ++-- .../apache/spark/sql/rapids/shims/RapidsErrorUtils.scala | 7 ++++--- 7 files changed, 20 insertions(+), 17 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala index d1dbcd3a086..73b280aadd5 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala @@ -60,10 +60,12 @@ case class GpuParseUrl(children: Seq[Expression], override def checkInputDataTypes(): TypeCheckResult = { if (children.size > 3 || children.size < 2) { - RapidsErrorUtils.parseUrlWrongNumArgs(children.size) - } else { - super[ExpectsInputTypes].checkInputDataTypes() + RapidsErrorUtils.parseUrlWrongNumArgs(children.size) match { + case res: Some[TypeCheckResult] => return res.get + case _ => // error message has been thrown + } } + super[ExpectsInputTypes].checkInputDataTypes() } private def getPattern(key: UTF8String): RegexProgram = { diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index 9259f0d3e90..5d74bbac86a 100644 --- a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -84,11 +84,11 @@ object RapidsErrorUtils { throw new AnalysisException(s"$tableIdentifier already exists.") } - def parseUrlWrongNumArgs(actual: Int): TypeCheckResult = { - TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments") + def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = { + 
Some(TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments")) } - def invalidUrlException(url: UFT8String, e: Throwable): Throwable = { + def invalidUrlException(url: UTF8String, e: Throwable): Throwable = { new IllegalArgumentException(s"Find an invaild url string ${url.toString}", e) } } diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index 6a569d98248..115d4e93ba8 100644 --- a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -90,8 +90,8 @@ object RapidsErrorUtils { QueryCompilationErrors.tableIdentifierExistsError(tableIdentifier) } - def parseUrlWrongNumArgs(actual: Int): TypeCheckResult = { - TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments") + def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = { + Some(TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments")) } def invalidUrlException(url: UTF8String, e: URISyntaxException): Throwable = { diff --git a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index b5997e4e6db..b7b95016047 100644 --- a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -88,8 +88,8 @@ object RapidsErrorUtils { QueryCompilationErrors.tableIdentifierExistsError(tableIdentifier) } - def parseUrlWrongNumArgs(actual: Int): TypeCheckResult = { - TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments") + def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = { + Some(TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments")) } def invalidUrlException(url: UTF8String, e: URISyntaxException): Throwable = { diff --git a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index a0e827150d5..841b3a96078 100644 --- a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -84,8 +84,8 @@ object RapidsErrorUtils extends RapidsErrorUtilsFor330plus { new ArrayIndexOutOfBoundsException("SQL array indices start at 1") } - def parseUrlWrongNumArgs(actual: Int): TypeCheckResult = { - TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments") + def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = { + Some(TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments")) } def invalidUrlException(url: UTF8String, e: URISyntaxException): Throwable = { diff --git a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index 43be246548a..aaaecc86256 100644 --- a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ 
b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -92,8 +92,8 @@ object RapidsErrorUtils extends RapidsErrorUtilsFor330plus { QueryExecutionErrors.intervalDividedByZeroError(origin.context) } - def parseUrlWrongNumArgs(actual: Int): TypeCheckResult = { - TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments") + def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = { + Some(TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments")) } def invalidUrlException(url: UTF8String, e: URISyntaxException): Throwable = { diff --git a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index 6dde4e0a8f9..032fed254ef 100644 --- a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -25,7 +25,7 @@ import java.net.URISyntaxException import org.apache.spark.SparkDateTimeException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext} -import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, Decimal, DecimalType} import org.apache.spark.unsafe.types.UTF8String @@ -92,10 +92,11 @@ object RapidsErrorUtils extends RapidsErrorUtilsFor330plus { QueryExecutionErrors.intervalDividedByZeroError(origin.context) } - def parseUrlWrongNumArgs(actual: Int): Throwable = { + def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = { throw QueryCompilationErrors.wrongNumArgsError( - "parse_url", Seq("[2, 3]"), actualNumber + "parse_url", Seq("[2, 3]"), actual ) + None } def invalidUrlException(url: UTF8String, e: URISyntaxException): Throwable = { From 85c3284a7b295b8f0557e33d97ac948e101d6ba6 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 3 Aug 2023 09:27:24 +0800 Subject: [PATCH 03/27] regex refactor --- .../src/main/python/regexp_test.py | 28 +++++++- .../nvidia/spark/rapids/GpuOverrides.scala | 2 +- .../spark/sql/rapids/urlFunctions.scala | 68 ++++++++++++------- .../spark/rapids/UrlFunctionsSuite.scala | 42 ++++++++---- 4 files changed, 98 insertions(+), 42 deletions(-) diff --git a/integration_tests/src/main/python/regexp_test.py b/integration_tests/src/main/python/regexp_test.py index 1bc204699ea..2c6747af593 100644 --- a/integration_tests/src/main/python/regexp_test.py +++ b/integration_tests/src/main/python/regexp_test.py @@ -436,7 +436,7 @@ def test_regexp_replace_character_set_negated(): 'regexp_replace(a, "[^\n]", "1")'), conf=_regexp_conf) -def test_regexp_extract(): +def test_regexp_extract_good(): gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}/?[abcd]{1,3}') assert_gpu_and_cpu_are_equal_collect( lambda spark: unary_op_df(spark, gen).selectExpr( @@ -451,6 +451,32 @@ def test_regexp_extract(): 'regexp_extract(a, "^([a-d]*)([0-9]*)(\\\\/[a-d]*)", 3)', 'regexp_extract(a, "^([a-d]*)([0-9]*)(\\\\/[a-d]*)$", 3)'), conf=_regexp_conf) + +def test_regexp_extract_utf8(): + gen = mk_str_gen('[你我他✪]{1,3}[0-9]{1,3}/?[一二三四]{1,3}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp_extract(a, "([0-9]+)", 1)', + 'regexp_extract(a, 
"([0-9])([abcd]+)", 1)', + 'regexp_extract(a, "([0-9])([abcd]+)", 2)', + 'regexp_extract(a, "^([你我他✪]*)([0-9]*)([一二三四]*)$", 1)', + 'regexp_extract(a, "^([你我他✪]*)([0-9]*)([一二三四]*)$", 2)', + 'regexp_extract(a, "^([你我他✪]*)([0-9]*)([一二三四]*)$", 3)', + 'regexp_extract(a, "^([a-d]*)([0-9]*)\\\\/([a-d]*)", 3)', + 'regexp_extract(a, "^([a-d]*)([0-9]*)\\\\/([a-d]*)$", 3)', + 'regexp_extract(a, "^([a-d]*)([0-9]*)(\\\\/[a-d]*)", 3)', + 'regexp_extract(a, "^([a-d]*)([0-9]*)(\\\\/[a-d]*)$", 3)'), + conf=_regexp_conf) + +# ^([^:/?#]+):(//)?(([^:]*:?[^\x40]*)\x40)?(\[[0-9A-Za-z%.:]*\]|[^/#:?]*)(:[0-9]+)?([^?#]*)(\?[^#]*)?(#.*)?$ + +def test_regexp_extract_url(): + gen = mk_str_gen('[你我他✪]{1,3}[0-9]{1,3}/?[一二三四]{1,3}') + query = r"""regexp_extract(a, r'^([^:/?#]+):(//)?(([^:]*:?[^\x40]*)\x40)?([0-9A-Za-z%.:\[\]]*|[^/#:?]*)(:[0-9]+)?$', 1)""" + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + query), + conf=_regexp_conf) def test_regexp_extract_no_match(): gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}[abcd]{1,3}') diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 5705e38f8c9..2a25298fa84 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -3129,7 +3129,7 @@ object GpuOverrides extends Logging { // Should really be an OptionalParam Some(RepeatingParamCheck("key", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), (a, conf, p, r) => new ExprMeta[ParseUrl](a, conf, p, r) { - val failOnError = SQLConf.get.ansiEnabled + val failOnError = a.failOnError override def convertToGpu(): GpuExpression = { GpuParseUrl(childExprs.map(_.convertToGpu()), failOnError) } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala index 73b280aadd5..8fd85f2382f 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala @@ -41,7 +41,7 @@ object GpuParseUrl { private val FILE = "FILE" private val AUTHORITY = "AUTHORITY" private val USERINFO = "USERINFO" - private val REGEXPREFIX = """(&|^|\?)""" + private val REGEXPREFIX = """(&|^)""" private val REGEXSUBFIX = "=([^&]*)" } @@ -69,6 +69,8 @@ case class GpuParseUrl(children: Seq[Expression], } private def getPattern(key: UTF8String): RegexProgram = { + // SPARK-44500: in spark, the key is treated as a regex. + // In plugin we quote the key to be sure that we treat it as a literal value. 
val regex = REGEXPREFIX + key.toString + REGEXSUBFIX new RegexProgram(regex) } @@ -94,19 +96,26 @@ case class GpuParseUrl(children: Seq[Expression], } private def reMatch(url: ColumnVector, partToExtract: String): ColumnVector = { - val regex = """^(([^:/?#]+):)(//((([^:]*:?[^\@]*)\@)?(\[[0-9A-Za-z%.:]*\]|[^/?#:]*)""" + - """(:[0-9]+)?))?(([^?#]*)(\?([^#]*))?)(#(.*))?""" + // scalastyle:off line.size.limit + // val regex = """(([^:/?#]+):)(//((([^:]*:?[^\@]*)\@)?(\[[0-9A-Za-z%.:]*\]|[^/#:?]*)(:[0-9]+)?))?(([^?#]*)(\?([^#]*))?)(#(.*))?""" + // 0 0 1 2 2 3 3 1 4 45 5 6 6 + val regex = """^(?:([^:/?#]+):(?://((?:([^@]+)@)(\[[0-9A-Za-z%.:]*\]|[^/#:?]*)(?::[0-9]+)?))?([^?#]*)(\?[^#]*)?(#.*)?)$""" + // scalastyle:on val prog = new RegexProgram(regex) withResource(url.extractRe(prog)) { table: Table => partToExtract match { - case HOST => table.getColumn(6).incRefCount() - case PATH => table.getColumn(9).incRefCount() - case QUERY => table.getColumn(10).incRefCount() - case REF => table.getColumn(12).incRefCount() - case PROTOCOL => table.getColumn(1).incRefCount() - case FILE => table.getColumn(8).incRefCount() - case AUTHORITY => table.getColumn(3).incRefCount() - case USERINFO => table.getColumn(5).incRefCount() + case HOST => table.getColumn(3).incRefCount() + case PATH => table.getColumn(4).incRefCount() + case QUERY => table.getColumn(5).incRefCount() + case REF => table.getColumn(6).incRefCount() + case PROTOCOL => table.getColumn(0).incRefCount() + case FILE => { + val path = table.getColumn(4) + val query = table.getColumn(5) + ColumnVector.stringConcatenate(Array(path, query)) + } + case AUTHORITY => table.getColumn(1).incRefCount() + case USERINFO => table.getColumn(2).incRefCount() case _ => throw new IllegalArgumentException(s"Invalid partToExtract: $partToExtract") } } @@ -120,10 +129,17 @@ case class GpuParseUrl(children: Seq[Expression], } } - // private def isHost(cv: ColumnVector): ColumnVector = { - // // TODO: Valid if it is a valid host name, including ipv4, ipv6 and hostname - // cv - // } + private def unsetInvalidHost(cv: ColumnVector): ColumnVector = { + // TODO: Valid if it is a valid host name, including ipv4, ipv6 and hostname + // current it only exclude utf8 string + val regex = """^([?[a-zA-Z0-9\-.]+|\[[0-9A-Za-z%.:]*\])$""" + val prog = new RegexProgram(regex) + withResource(cv.matchesRe(prog)) { isMatch => + withResource(Scalar.fromNull(DType.STRING)) { nullScalar => + isMatch.ifElse(cv, nullScalar) + } + } + } def doColumnar(numRows: Int, url: GpuScalar, partToExtract: GpuScalar): ColumnVector = { withResource(GpuColumnVector.from(url, numRows, StringType)) { urlCol => @@ -138,11 +154,11 @@ case class GpuParseUrl(children: Seq[Expression], reMatch(valid, part) } if (part == HOST) { - // withResource(matched) { _ => - // isHost(matched) - // } - withResource(matched) { _ => - emptyToNulls(matched) + val valided = withResource(matched) { _ => + unsetInvalidHost(matched) + } + withResource(valided) { _ => + emptyToNulls(valided) } } else if (part == QUERY || part == REF) { val resWithNulls = withResource(matched) { _ => @@ -161,15 +177,17 @@ case class GpuParseUrl(children: Seq[Expression], } def doColumnar(url: GpuColumnVector, partToExtract: GpuScalar, key: GpuScalar): ColumnVector = { - val query = partToExtract.getValue.asInstanceOf[UTF8String].toString - if (query != QUERY) { + val part = partToExtract.getValue.asInstanceOf[UTF8String].toString + if (part != QUERY) { // return a null columnvector return ColumnVector.fromStrings(null, null) } - val matched = 
reMatch(url.getBase, query) + val querys = withResource(reMatch(url.getBase, QUERY)) { matched => + matched.substring(1) + } val keyStr = key.getValue.asInstanceOf[UTF8String] - val queryValue = withResource(matched) { _ => - withResource(matched.extractRe(getPattern(keyStr))) { table: Table => + val queryValue = withResource(querys) { _ => + withResource(querys.extractRe(getPattern(keyStr))) { table: Table => table.getColumn(1).incRefCount() } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala index 7e97d266fb5..351e4d2fdc2 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala @@ -49,9 +49,9 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { "http://j.mp", "ftp://foo.bar/baz", "http://foo.bar/?q=Test%20URL-encoded%20stuff", - // "http://مثال.إختبار", - // "http://例子.测试", - // "http://उदाहरण.परीक्षा", + "http://مثال.إختبار", + "http://例子.测试", + "http://उदाहरण.परीक्षा", "http://-.~_!$&'()*+,;=:%40:80%2f::::::@example.com", "http://1337.net", "http://a.b-c.de", @@ -160,15 +160,27 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { def utf8UrlCases(session: SparkSession): DataFrame = { import session.sqlContext.implicits._ Seq[String]( - "http://✪df.ws/123", - "http://➡.ws/䨹", - "http://⌘.ws", - "http://⌘.ws/", - "http://foo.com/unicode_(✪)_in_parens", - "http://☺.damowmow.com/", - "http://مثال.إختبار", - "http://例子.测试", - "http://उदाहरण.परीक्षा" + "http://user✪info@sp✪ark.apa✪che.org/pa✪th?que✪ry=1#R✪ef", + "http://@✪df.ws/123", + "http://@➡.ws/䨹", + "http://@⌘.ws", + "http://@⌘.ws/", + "http://@foo.com/unicode_(✪)_in_parens", + "http://@☺.damowmow.com/", + "http://@xxx☺.damowmow.com/", + "http://@مثال.إختبار/index.html?query=1#Ref", + "http://@例子.测试/index.html?query=1#Ref", + "http://@उदाहरण.परीक्षा/index.html?query=1#Ref" + // "http://user✪info@✪df.ws/123", + // "http://user✪info@➡.ws/䨹", + // "http://user✪info@⌘.ws", + // "http://user✪info@⌘.ws/", + // "http://user✪info@foo.com/unicode_(✪)_in_parens", + // "http://user✪info@☺.damowmow.com/", + // "http://user✪info@xxx☺.damowmow.com/", + // "http://user✪info@مثال.إختبار/index.html?query=1#Ref", + // "http://user✪info@例子.测试/index.html?query=1#Ref", + // "http://user✪info@उदाहरण.परीक्षा/index.html?query=1#Ref" ).map(UTF8String.fromString(_).toString()).toDF("urls") } @@ -230,9 +242,9 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { parseUrls } - // testSparkResultsAreEqual("Test parse_url utf-8 cases", utf8UrlCases) { - // parseUrls - // } + testSparkResultsAreEqual("Test parse_url utf-8 cases", utf8UrlCases) { + parseUrls + } // testSparkResultsAreEqual("Test parse_url unsupport cases", unsupportedUrlCases) { // parseUrls From 4166362ed0c906f760369e771bafe2dad3b2d137 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 3 Aug 2023 16:27:56 +0800 Subject: [PATCH 04/27] Separate regexes and UTF-8 special characters support --- .../spark/sql/rapids/urlFunctions.scala | 85 +++++++++++++------ .../spark/rapids/UrlFunctionsSuite.scala | 51 ++--------- 2 files changed, 67 insertions(+), 69 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala index 8fd85f2382f..615609e5bb1 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala +++ 
b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala @@ -43,6 +43,18 @@ object GpuParseUrl { private val USERINFO = "USERINFO" private val REGEXPREFIX = """(&|^)""" private val REGEXSUBFIX = "=([^&]*)" + // scalastyle:off line.size.limit + // a 0 0 b 1 c d 2 2 d 3 3ce e 1b 4 45 5 6 6 a + // val regex = """^(?:(?:[^:/?#]+):(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]*\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(?:#.*)?)$""" + private val HOST_REGEX = """^(?:(?:[^:/?#]+):(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(\[[0-9A-Za-z%.:]*\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(?:#.*)?)$""" + private val PATH_REGEX = """^(?:(?:[^:/?#]+):(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]*\]|[^/#:?]*))(?::[0-9]+)?))?([^?#]*)(?:\?[^#]*)?(?:#.*)?)$""" + private val QUERY_REGEX = """^(?:(?:[^:/?#]+):(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]*\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(\?[^#]*)?(?:#.*)?)$""" + private val REF_REGEX = """^(?:(?:[^:/?#]+):(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]*\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(#.*)?)$""" + private val PROTOCOL_REGEX = """^(?:([^:/?#]+):(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]*\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(?:#.*)?)$""" + private val FILE_REGEX = """^(?:(?:[^:/?#]+):(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]*\]|[^/#:?]*))(?::[0-9]+)?))?((?:[^?#]*)(?:\?[^#]*)?)(?:#.*)?)$""" + private val AUTHORITY_REGEX = """^(?:(?:[^:/?#]+):(?://((?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]*\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(?:#.*)?)$""" + private val USERINFO_REGEX = """^(?:(?:[^:/?#]+):(?://(?:(?:(?:([^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]*\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(?:#.*)?)$""" + // scalastyle:on } case class GpuParseUrl(children: Seq[Expression], @@ -96,31 +108,50 @@ case class GpuParseUrl(children: Seq[Expression], } private def reMatch(url: ColumnVector, partToExtract: String): ColumnVector = { - // scalastyle:off line.size.limit - // val regex = """(([^:/?#]+):)(//((([^:]*:?[^\@]*)\@)?(\[[0-9A-Za-z%.:]*\]|[^/#:?]*)(:[0-9]+)?))?(([^?#]*)(\?([^#]*))?)(#(.*))?""" - // 0 0 1 2 2 3 3 1 4 45 5 6 6 - val regex = """^(?:([^:/?#]+):(?://((?:([^@]+)@)(\[[0-9A-Za-z%.:]*\]|[^/#:?]*)(?::[0-9]+)?))?([^?#]*)(\?[^#]*)?(#.*)?)$""" - // scalastyle:on + val regex = partToExtract match { + case HOST => HOST_REGEX + case PATH => PATH_REGEX + case QUERY => QUERY_REGEX + case REF => REF_REGEX + case PROTOCOL => PROTOCOL_REGEX + case FILE => FILE_REGEX + case AUTHORITY => AUTHORITY_REGEX + case USERINFO => USERINFO_REGEX + case _ => throw new IllegalArgumentException(s"Invalid partToExtract: $partToExtract") + } val prog = new RegexProgram(regex) withResource(url.extractRe(prog)) { table: Table => - partToExtract match { - case HOST => table.getColumn(3).incRefCount() - case PATH => table.getColumn(4).incRefCount() - case QUERY => table.getColumn(5).incRefCount() - case REF => table.getColumn(6).incRefCount() - case PROTOCOL => table.getColumn(0).incRefCount() - case FILE => { - val path = table.getColumn(4) - val query = table.getColumn(5) - ColumnVector.stringConcatenate(Array(path, query)) - } - case AUTHORITY => table.getColumn(1).incRefCount() - case USERINFO => table.getColumn(2).incRefCount() - case _ => throw new IllegalArgumentException(s"Invalid partToExtract: $partToExtract") - } - } + table.getColumn(0).incRefCount() + } } + // private def reMatch_old(url: ColumnVector, partToExtract: String): ColumnVector = { + // 
scalastyle:off line.size.limit + // // val regex = """(([^:/?#]+):)(//((([^:]*:?[^\@]*)\@)?(\[[0-9A-Za-z%.:]*\]|[^/#:?]*)(:[0-9]+)?))?(([^?#]*)(\?([^#]*))?)(#(.*))?""" + // // a 0 0 b 1c d 2 2 d 3 3ce e 1b 4 45 5 6 6 a + // // val regex = """^(?:([^:/?#]+):(?://((?:(?:([^@]*)@)?(\[[0-9A-Za-z%.:]*\]|[^/#:?]*))(?::[0-9]+)?))?([^?#]*)(\?[^#]*)?(#.*)?)$""" + // val regex = """^(?:(?:[^:/?#]+):(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]*\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(?:#.*)?)$""" + // scalastyle:on + // val prog = new RegexProgram(regex) + // withResource(url.extractRe(prog)) { table: Table => + // partToExtract match { + // case HOST => table.getColumn(3).incRefCount() + // case PATH => table.getColumn(4).incRefCount() + // case QUERY => table.getColumn(5).incRefCount() + // case REF => table.getColumn(6).incRefCount() + // case PROTOCOL => table.getColumn(0).incRefCount() + // case FILE => { + // val path = table.getColumn(4) + // val query = table.getColumn(5) + // ColumnVector.stringConcatenate(Array(path, query)) + // } + // case AUTHORITY => table.getColumn(1).incRefCount() + // case USERINFO => table.getColumn(2).incRefCount() + // case _ => throw new IllegalArgumentException(s"Invalid partToExtract: $partToExtract") + // } + // } + // } + private def emptyToNulls(cv: ColumnVector): ColumnVector = { withResource(ColumnVector.fromStrings("")) { empty => withResource(ColumnVector.fromStrings(null)) { nulls => @@ -196,11 +227,11 @@ case class GpuParseUrl(children: Seq[Expression], } } - override def columnarEval(batch: ColumnarBatch): Any = { + override def columnarEval(batch: ColumnarBatch): GpuColumnVector = { if (children.size == 2) { val Seq(url, partToExtract) = children - withResourceIfAllowed(url.columnarEval(batch)) { val0 => - withResourceIfAllowed(partToExtract.columnarEval(batch)) { val1 => + withResourceIfAllowed(url.columnarEvalAny(batch)) { val0 => + withResourceIfAllowed(partToExtract.columnarEvalAny(batch)) { val1 => (val0, val1) match { case (v0: GpuColumnVector, v1: GpuScalar) => GpuColumnVector.from(doColumnar(v0, v1), dataType) @@ -213,9 +244,9 @@ case class GpuParseUrl(children: Seq[Expression], // 3-arg, i.e. 
QUERY with key assert(children.size == 3) val Seq(url, partToExtract, key) = children - withResourceIfAllowed(url.columnarEval(batch)) { val0 => - withResourceIfAllowed(partToExtract.columnarEval(batch)) { val1 => - withResourceIfAllowed(key.columnarEval(batch)) { val2 => + withResourceIfAllowed(url.columnarEvalAny(batch)) { val0 => + withResourceIfAllowed(partToExtract.columnarEvalAny(batch)) { val1 => + withResourceIfAllowed(key.columnarEvalAny(batch)) { val2 => (val0, val1, val2) match { case (v0: GpuColumnVector, v1: GpuScalar, v2: GpuScalar) => GpuColumnVector.from(doColumnar(v0, v1, v2), dataType) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala index 351e4d2fdc2..e9f53a8fa79 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala @@ -17,7 +17,6 @@ package com.nvidia.spark.rapids import org.apache.spark.sql.{DataFrame, SparkSession} -import org.apache.spark.unsafe.types.UTF8String class UrlFunctionsSuite extends SparkQueryCompareTestSuite { def validUrlEdgeCasesDf(session: SparkSession): DataFrame = { @@ -30,21 +29,21 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { "http://foo.com/blah_blah_(wikipedia)_(again)", "http://www.example.com/wpstyle/?p=364", "https://www.example.com/foo/?bar=baz&inga=42&quux", - // "http://✪df.ws/123", + "http://✪df.ws/123", "http://userid:password@example.com:8080", "http://userid:password@example.com:8080/", "http://userid:password@example.com", "http://userid:password@example.com/", "http://142.42.1.1/", "http://142.42.1.1:8080/", - // "http://➡.ws/䨹", - // "http://⌘.ws", - // "http://⌘.ws/", + "http://➡.ws/䨹", + "http://⌘.ws", + "http://⌘.ws/", "http://foo.com/blah_(wikipedia)#cite-1", "http://foo.com/blah_(wikipedia)_blah#cite-1", - // "http://foo.com/unicode_(✪)_in_parens", + "http://foo.com/unicode_(✪)_in_parens", "http://foo.com/(something)?after=parens", - // "http://☺.damowmow.com/", + "http://☺.damowmow.com/", "http://code.google.com/events/#&product=browser", "http://j.mp", "ftp://foo.bar/baz", @@ -56,7 +55,7 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { "http://1337.net", "http://a.b-c.de", "http://223.255.255.254", - // "https://foo_bar.example.com/", + "https://foo_bar.example.com/", // "http://", // "http://.", // "http://..", @@ -157,37 +156,9 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { ).toDF("urls") } - def utf8UrlCases(session: SparkSession): DataFrame = { - import session.sqlContext.implicits._ - Seq[String]( - "http://user✪info@sp✪ark.apa✪che.org/pa✪th?que✪ry=1#R✪ef", - "http://@✪df.ws/123", - "http://@➡.ws/䨹", - "http://@⌘.ws", - "http://@⌘.ws/", - "http://@foo.com/unicode_(✪)_in_parens", - "http://@☺.damowmow.com/", - "http://@xxx☺.damowmow.com/", - "http://@مثال.إختبار/index.html?query=1#Ref", - "http://@例子.测试/index.html?query=1#Ref", - "http://@उदाहरण.परीक्षा/index.html?query=1#Ref" - // "http://user✪info@✪df.ws/123", - // "http://user✪info@➡.ws/䨹", - // "http://user✪info@⌘.ws", - // "http://user✪info@⌘.ws/", - // "http://user✪info@foo.com/unicode_(✪)_in_parens", - // "http://user✪info@☺.damowmow.com/", - // "http://user✪info@xxx☺.damowmow.com/", - // "http://user✪info@مثال.إختبار/index.html?query=1#Ref", - // "http://user✪info@例子.测试/index.html?query=1#Ref", - // "http://user✪info@उदाहरण.परीक्षा/index.html?query=1#Ref" - ).map(UTF8String.fromString(_).toString()).toDF("urls") 
- } - def unsupportedUrlCases(session: SparkSession): DataFrame = { import session.sqlContext.implicits._ Seq[String]( - "https://foo_bar.example.com/", "http://", "http://.", "http://..", @@ -242,14 +213,10 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { parseUrls } - testSparkResultsAreEqual("Test parse_url utf-8 cases", utf8UrlCases) { - parseUrls + testSparkResultsAreEqual("Test parse_url unsupport cases", unsupportedUrlCases) { + parseUrls } - // testSparkResultsAreEqual("Test parse_url unsupport cases", unsupportedUrlCases) { - // parseUrls - // } - testSparkResultsAreEqual("Test parse_url with query and key", urlWithQueryKey) { frame => frame.selectExpr( "urls", From 43acceb6ad6138b8cb230ac9d8f951e4af498cce Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 3 Aug 2023 18:17:28 +0800 Subject: [PATCH 05/27] hostname validation --- .../nvidia/spark/rapids/UrlFunctionsSuite.scala | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala index e9f53a8fa79..446ba906aff 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala @@ -79,10 +79,10 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { ":// should fail", "http://foo.bar/foo(bar)baz quux", "ftps://foo.bar/", - // "http://-error-.invalid/", + "http://-error-.invalid/", "http://a.b--c.de/", - // "http://-a.b.co", - // "http://a.b-.co", + "http://-a.b.co", + "http://a.b-.co", "http://0.0.0.0", "http://10.1.1.0", "http://10.1.1.255", @@ -90,9 +90,9 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { // "http://1.1.1.1.1", // "http://123.123.123", "http://3628126748", - // "http://.www.foo.bar/", + "http://.www.foo.bar/", "http://www.foo.bar./", - // "http://.www.foo.bar./", + "http://.www.foo.bar./", "http://10.1.1.1", "http://10.1.1.254" ).toDF("urls") @@ -169,13 +169,8 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { "///a", "///", "foo.com", - "http://-error-.invalid/", - "http://-a.b.co", - "http://a.b-.co", "http://1.1.1.1.1", - "http://123.123.123", - "http://.www.foo.bar/", - "http://.www.foo.bar./" + "http://123.123.123" ).toDF("urls") } From 64d8373bcfed0d8ef1c6c5bf504f36600e6ba1ed Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 3 Aug 2023 18:18:00 +0800 Subject: [PATCH 06/27] hostname validation --- .../src/main/python/regexp_test.py | 28 +------------ .../spark/sql/rapids/urlFunctions.scala | 42 ++++++------------- 2 files changed, 13 insertions(+), 57 deletions(-) diff --git a/integration_tests/src/main/python/regexp_test.py b/integration_tests/src/main/python/regexp_test.py index 6962046d2b8..ecc3634abcb 100644 --- a/integration_tests/src/main/python/regexp_test.py +++ b/integration_tests/src/main/python/regexp_test.py @@ -436,7 +436,7 @@ def test_regexp_replace_character_set_negated(): 'regexp_replace(a, "[^\n]", "1")'), conf=_regexp_conf) -def test_regexp_extract_good(): +def test_regexp_extract(): gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}/?[abcd]{1,3}') assert_gpu_and_cpu_are_equal_collect( lambda spark: unary_op_df(spark, gen).selectExpr( @@ -451,32 +451,6 @@ def test_regexp_extract_good(): 'regexp_extract(a, "^([a-d]*)([0-9]*)(\\\\/[a-d]*)", 3)', 'regexp_extract(a, "^([a-d]*)([0-9]*)(\\\\/[a-d]*)$", 3)'), conf=_regexp_conf) - -def test_regexp_extract_utf8(): - gen = 
mk_str_gen('[你我他✪]{1,3}[0-9]{1,3}/?[一二三四]{1,3}') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'regexp_extract(a, "([0-9]+)", 1)', - 'regexp_extract(a, "([0-9])([abcd]+)", 1)', - 'regexp_extract(a, "([0-9])([abcd]+)", 2)', - 'regexp_extract(a, "^([你我他✪]*)([0-9]*)([一二三四]*)$", 1)', - 'regexp_extract(a, "^([你我他✪]*)([0-9]*)([一二三四]*)$", 2)', - 'regexp_extract(a, "^([你我他✪]*)([0-9]*)([一二三四]*)$", 3)', - 'regexp_extract(a, "^([a-d]*)([0-9]*)\\\\/([a-d]*)", 3)', - 'regexp_extract(a, "^([a-d]*)([0-9]*)\\\\/([a-d]*)$", 3)', - 'regexp_extract(a, "^([a-d]*)([0-9]*)(\\\\/[a-d]*)", 3)', - 'regexp_extract(a, "^([a-d]*)([0-9]*)(\\\\/[a-d]*)$", 3)'), - conf=_regexp_conf) - -# ^([^:/?#]+):(//)?(([^:]*:?[^\x40]*)\x40)?(\[[0-9A-Za-z%.:]*\]|[^/#:?]*)(:[0-9]+)?([^?#]*)(\?[^#]*)?(#.*)?$ - -def test_regexp_extract_url(): - gen = mk_str_gen('[你我他✪]{1,3}[0-9]{1,3}/?[一二三四]{1,3}') - query = r"""regexp_extract(a, r'^([^:/?#]+):(//)?(([^:]*:?[^\x40]*)\x40)?([0-9A-Za-z%.:\[\]]*|[^/#:?]*)(:[0-9]+)?$', 1)""" - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - query), - conf=_regexp_conf) def test_regexp_extract_no_match(): gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}[abcd]{1,3}') diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala index 615609e5bb1..d26d88dd72d 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala @@ -125,33 +125,6 @@ case class GpuParseUrl(children: Seq[Expression], } } - // private def reMatch_old(url: ColumnVector, partToExtract: String): ColumnVector = { - // scalastyle:off line.size.limit - // // val regex = """(([^:/?#]+):)(//((([^:]*:?[^\@]*)\@)?(\[[0-9A-Za-z%.:]*\]|[^/#:?]*)(:[0-9]+)?))?(([^?#]*)(\?([^#]*))?)(#(.*))?""" - // // a 0 0 b 1c d 2 2 d 3 3ce e 1b 4 45 5 6 6 a - // // val regex = """^(?:([^:/?#]+):(?://((?:(?:([^@]*)@)?(\[[0-9A-Za-z%.:]*\]|[^/#:?]*))(?::[0-9]+)?))?([^?#]*)(\?[^#]*)?(#.*)?)$""" - // val regex = """^(?:(?:[^:/?#]+):(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]*\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(?:#.*)?)$""" - // scalastyle:on - // val prog = new RegexProgram(regex) - // withResource(url.extractRe(prog)) { table: Table => - // partToExtract match { - // case HOST => table.getColumn(3).incRefCount() - // case PATH => table.getColumn(4).incRefCount() - // case QUERY => table.getColumn(5).incRefCount() - // case REF => table.getColumn(6).incRefCount() - // case PROTOCOL => table.getColumn(0).incRefCount() - // case FILE => { - // val path = table.getColumn(4) - // val query = table.getColumn(5) - // ColumnVector.stringConcatenate(Array(path, query)) - // } - // case AUTHORITY => table.getColumn(1).incRefCount() - // case USERINFO => table.getColumn(2).incRefCount() - // case _ => throw new IllegalArgumentException(s"Invalid partToExtract: $partToExtract") - // } - // } - // } - private def emptyToNulls(cv: ColumnVector): ColumnVector = { withResource(ColumnVector.fromStrings("")) { empty => withResource(ColumnVector.fromStrings(null)) { nulls => @@ -161,9 +134,18 @@ case class GpuParseUrl(children: Seq[Expression], } private def unsetInvalidHost(cv: ColumnVector): ColumnVector = { - // TODO: Valid if it is a valid host name, including ipv4, ipv6 and hostname - // current it only exclude utf8 string - val regex = 
"""^([?[a-zA-Z0-9\-.]+|\[[0-9A-Za-z%.:]*\])$""" + // scalastyle:off line.size.limit + // HostName parsing: + // hostname = domainlabel [ "." ] | 1*( domainlabel "." ) toplabel [ "." ] + // domainlabel = alphanum | alphanum *( alphanum | "-" ) alphanum + // toplabel = alpha | alpha *( alphanum | "-" ) alphanum + val hostname_regex = """^(((([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])|(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)+([a-zA-Z]|[a-zA-Z][a-zA-Z0-9\-]*[a-zA-Z]))\.?)|""" + // TODO: ipv6_regex + val ipv6_regex = """\[[0-9A-Za-z%.:]*\]|""" + // TODO: ipv4_regex + val ipv4_regex = """([0-9.]*))$""" + // scalastyle:on + val regex = hostname_regex + ipv6_regex + ipv4_regex val prog = new RegexProgram(regex) withResource(cv.matchesRe(prog)) { isMatch => withResource(Scalar.fromNull(DType.STRING)) { nullScalar => From e6a45d312a5f83abd8f4a69de8c02eb339ab7080 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 4 Aug 2023 11:30:37 +0800 Subject: [PATCH 07/27] ipv4 validation --- .../spark/sql/rapids/urlFunctions.scala | 12 +++---- .../spark/rapids/UrlFunctionsSuite.scala | 34 ++++++++++++------- 2 files changed, 28 insertions(+), 18 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala index d26d88dd72d..84aa341385e 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala @@ -135,17 +135,17 @@ case class GpuParseUrl(children: Seq[Expression], private def unsetInvalidHost(cv: ColumnVector): ColumnVector = { // scalastyle:off line.size.limit - // HostName parsing: + // HostName parsing followed rules in java URI lib: // hostname = domainlabel [ "." ] | 1*( domainlabel "." ) toplabel [ "." 
] // domainlabel = alphanum | alphanum *( alphanum | "-" ) alphanum // toplabel = alpha | alpha *( alphanum | "-" ) alphanum - val hostname_regex = """^(((([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])|(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)+([a-zA-Z]|[a-zA-Z][a-zA-Z0-9\-]*[a-zA-Z]))\.?)|""" - // TODO: ipv6_regex - val ipv6_regex = """\[[0-9A-Za-z%.:]*\]|""" + val hostname_regex = """((([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])|(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)+([a-zA-Z]|[a-zA-Z][a-zA-Z0-9\-]*[a-zA-Z]))\.?)""" // TODO: ipv4_regex - val ipv4_regex = """([0-9.]*))$""" + val ipv4_regex = """(((25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])\.){3}(25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9]))""" + // TODO: ipv6_regex + val ipv6_regex = """(\[[0-9A-Za-z%\.:]*\])""" // scalastyle:on - val regex = hostname_regex + ipv6_regex + ipv4_regex + val regex = "^(" + hostname_regex + "|" + ipv4_regex + "|" + ipv6_regex + ")$" val prog = new RegexProgram(regex) withResource(cv.matchesRe(prog)) { isMatch => withResource(Scalar.fromNull(DType.STRING)) { nullScalar => diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala index 446ba906aff..0fbf637927a 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala @@ -57,9 +57,9 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { "http://223.255.255.254", "https://foo_bar.example.com/", // "http://", - // "http://.", - // "http://..", - // "http://../", + "http://.", + "http://..", + "http://../", "http://?", "http://??", "http://??/", @@ -87,11 +87,11 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { "http://10.1.1.0", "http://10.1.1.255", "http://224.1.1.1", - // "http://1.1.1.1.1", - // "http://123.123.123", + "http://1.1.1.1.1", + "http://123.123.123", "http://3628126748", "http://.www.foo.bar/", - "http://www.foo.bar./", + "http://www.foo.bar./", "http://.www.foo.bar./", "http://10.1.1.1", "http://10.1.1.254" @@ -160,17 +160,27 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { import session.sqlContext.implicits._ Seq[String]( "http://", - "http://.", - "http://..", - "http://../", + // "http://.", + // "http://..", + // "http://../", "http://##", "http://##/", "//a", "///a", "///", - "foo.com", - "http://1.1.1.1.1", - "http://123.123.123" + "foo.com" + // "http://www.foo.bar.", + // "http://1.1.1.1.1", + // "http://123.123.123", + // "http://223.255.255.254", + // "http://142.42.1.1/", + // "http://142.42.1.1:8080/", + // "http://0.0.0.0", + // "http://10.1.1.0", + // "http://10.1.1.255", + // "http://224.1.1.1", + // "http://10.1.1.1", + // "http://10.1.1.254" ).toDF("urls") } From 8c4dc7a206ee2d2c709c8700ace2cf88c8d804d5 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 4 Aug 2023 13:46:38 +0800 Subject: [PATCH 08/27] verify --- docs/supported_ops.md | 338 +++++++++++++++++++++++------------------- 1 file changed, 182 insertions(+), 156 deletions(-) diff --git a/docs/supported_ops.md b/docs/supported_ops.md index dd4a7d755bc..60d9aa1c5d9 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -15976,6 +15976,32 @@ are limited. 
+Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + AggregateExpression Aggregate expression @@ -16172,32 +16198,6 @@ are limited. S -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - ApproximatePercentile `percentile_approx`, `approx_percentile` Approximate percentile @@ -16372,6 +16372,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Average `avg`, `mean` Average aggregate operator @@ -16638,32 +16664,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - CollectSet `collect_set` Collect a set of unique elements, not supported in reduction @@ -16797,6 +16797,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Count `count` Count aggregate operator @@ -17063,32 +17089,6 @@ are limited. NS -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Last `last`, `last_value` last aggregate operator @@ -17222,6 +17222,32 @@ are limited. NS +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Max `max` Max aggregate operator @@ -17488,32 +17514,6 @@ are limited. NS -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - PivotFirst PivotFirst operator @@ -17646,6 +17646,32 @@ are limited. NS +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + StddevPop `stddev_pop` Aggregation computing population standard deviation @@ -17912,32 +17938,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Sum `sum` Sum aggregate operator @@ -18071,6 +18071,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + VariancePop `var_pop` Aggregation computing population variance @@ -18337,32 +18363,6 @@ are limited. 
-Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - NormalizeNaNAndZero Normalize NaN and zero @@ -18436,6 +18436,32 @@ are limited. NS +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + HiveGenericUDF Hive Generic UDF, the UDF can choose to implement a RAPIDS accelerated interface to get better performance From fee5a3d0c17d70a0fa88a94a7f6862de32cf897c Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 4 Aug 2023 18:57:53 +0800 Subject: [PATCH 09/27] wip ipv6 and SPARK-44500 --- .../spark/sql/rapids/urlFunctions.scala | 20 +++++++++++--- .../spark/rapids/UrlFunctionsSuite.scala | 26 ++++++++++++++++++- 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala index 84aa341385e..be38e3aa2f5 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala @@ -41,8 +41,8 @@ object GpuParseUrl { private val FILE = "FILE" private val AUTHORITY = "AUTHORITY" private val USERINFO = "USERINFO" - private val REGEXPREFIX = """(&|^)""" - private val REGEXSUBFIX = "=([^&]*)" + private val REGEXPREFIX = """(&|^)(""" + private val REGEXSUBFIX = "=)([^&]*)" // scalastyle:off line.size.limit // a 0 0 b 1 c d 2 2 d 3 3ce e 1b 4 45 5 6 6 a // val regex = """^(?:(?:[^:/?#]+):(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]*\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(?:#.*)?)$""" @@ -80,10 +80,16 @@ case class GpuParseUrl(children: Seq[Expression], super[ExpectsInputTypes].checkInputDataTypes() } + private def escapeRegex(str: String): String = { + // Escape all regex special characters. It is a workaround for /Q and /E not working + // in cudf regex, can use Pattern.quote(str) instead after they are supported. + str.replaceAll("""[\^$.⎮?*+(){}-]""", "\\\\$0") + } + private def getPattern(key: UTF8String): RegexProgram = { // SPARK-44500: in spark, the key is treated as a regex. // In plugin we quote the key to be sure that we treat it as a literal value. 
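    // e.g. a key of "a.c" is escaped to "a\.c", giving the pattern
    // "(&|^)(a\.c=)([^&]*)", so the key matches only literally rather than as
    // a regex (exercised by the urlWithRegexLikeQuery test added in this
    // patch); the value is now capture group 3 (extractRe column index 2).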
- val regex = REGEXPREFIX + key.toString + REGEXSUBFIX + val regex = REGEXPREFIX + escapeRegex(key.toString) + REGEXSUBFIX new RegexProgram(regex) } @@ -143,6 +149,12 @@ case class GpuParseUrl(children: Seq[Expression], // TODO: ipv4_regex val ipv4_regex = """(((25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])\.){3}(25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9]))""" // TODO: ipv6_regex + // IPv6address = hexseq [ ":" IPv4address ] + // | hexseq [ "::" [ hexpost ] ] + // | "::" [ hexpost ] + // hexpost = hexseq | hexseq ":" IPv4address | IPv4address + // hexseq = hex4 *( ":" hex4) + // hex4 = 1*4HEXDIG val ipv6_regex = """(\[[0-9A-Za-z%\.:]*\])""" // scalastyle:on val regex = "^(" + hostname_regex + "|" + ipv4_regex + "|" + ipv6_regex + ")$" @@ -201,7 +213,7 @@ case class GpuParseUrl(children: Seq[Expression], val keyStr = key.getValue.asInstanceOf[UTF8String] val queryValue = withResource(querys) { _ => withResource(querys.extractRe(getPattern(keyStr))) { table: Table => - table.getColumn(1).incRefCount() + table.getColumn(2).incRefCount() } } withResource(queryValue) { _ => diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala index 0fbf637927a..2642d35773d 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala @@ -156,6 +156,14 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { ).toDF("urls") } + def urlWithRegexLikeQuery(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + Seq[String]( + "http://foo/bar?abc=BAD&a.c=GOOD", + "http://foo/bar?a.c=GOOD&abc=BAD" + ).toDF("urls") + } + def unsupportedUrlCases(session: SparkSession): DataFrame = { import session.sqlContext.implicits._ Seq[String]( @@ -225,6 +233,22 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { testSparkResultsAreEqual("Test parse_url with query and key", urlWithQueryKey) { frame => frame.selectExpr( "urls", - "parse_url(urls, 'QUERY', 'foo') as QUERY") + "parse_url(urls, 'QUERY', 'foo') as QUERY", + "parse_url(urls, 'QUERY', 'baz') as QUERY") + } + + test("Test parse_url with regex like query") { + withGpuSparkSession(spark => { + val frame = urlWithRegexLikeQuery(spark) + val result = frame.selectExpr( + "urls", + "parse_url(urls, 'QUERY', 'a.c') as QUERY") + import spark.implicits._ + val expected = Seq( + ("http://foo/bar?abc=BAD&a.c=GOOD", "GOOD"), + ("http://foo/bar?a.c=GOOD&abc=BAD", "GOOD") + ).toDF("urls", "QUERY") + assert(result.collect().deep == expected.collect().deep) + }) } } \ No newline at end of file From e81d8a3bade57b7c087ba67a3bc0296c571510ec Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Mon, 7 Aug 2023 10:30:43 +0800 Subject: [PATCH 10/27] optional protocol and ref validation --- .../spark/sql/rapids/urlFunctions.scala | 21 +++++----- .../spark/rapids/UrlFunctionsSuite.scala | 41 ++++--------------- 2 files changed, 19 insertions(+), 43 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala index be38e3aa2f5..8a493225170 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala @@ -44,16 +44,17 @@ object GpuParseUrl { private val REGEXPREFIX = """(&|^)(""" private val REGEXSUBFIX = "=)([^&]*)" // scalastyle:off 
line.size.limit - // a 0 0 b 1 c d 2 2 d 3 3ce e 1b 4 45 5 6 6 a - // val regex = """^(?:(?:[^:/?#]+):(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]*\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(?:#.*)?)$""" - private val HOST_REGEX = """^(?:(?:[^:/?#]+):(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(\[[0-9A-Za-z%.:]*\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(?:#.*)?)$""" - private val PATH_REGEX = """^(?:(?:[^:/?#]+):(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]*\]|[^/#:?]*))(?::[0-9]+)?))?([^?#]*)(?:\?[^#]*)?(?:#.*)?)$""" - private val QUERY_REGEX = """^(?:(?:[^:/?#]+):(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]*\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(\?[^#]*)?(?:#.*)?)$""" - private val REF_REGEX = """^(?:(?:[^:/?#]+):(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]*\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(#.*)?)$""" - private val PROTOCOL_REGEX = """^(?:([^:/?#]+):(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]*\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(?:#.*)?)$""" - private val FILE_REGEX = """^(?:(?:[^:/?#]+):(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]*\]|[^/#:?]*))(?::[0-9]+)?))?((?:[^?#]*)(?:\?[^#]*)?)(?:#.*)?)$""" - private val AUTHORITY_REGEX = """^(?:(?:[^:/?#]+):(?://((?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]*\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(?:#.*)?)$""" - private val USERINFO_REGEX = """^(?:(?:[^:/?#]+):(?://(?:(?:(?:([^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]*\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(?:#.*)?)$""" + // val regex = """^(?:(?:([^:/?#]+):)?(?://((?:(?:([^:]*:?[^\@]*)@)?(\[[0-9A-Za-z%.:]*\]|[^/#:?]*))(?::[0-9]+)?))?([^?#]*)(\?[^#]*)?(#.*)?)$""" + // a 0 0 b 1 c d 2 2 d 3 3ce e 1b 4 45 5 6 6 a + // val regex = """^(?:(?:(?:[^:/?#]+):)?(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]+\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(?:#.*)?)$""" + private val HOST_REGEX = """^(?:(?:(?:[^:/?#]+):)?(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(\[[0-9A-Za-z%.:]+\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(?:#[a-zA-Z0-9\-_.!~*'();/?:@&=+$,[\]%]*)?)$""" + private val PATH_REGEX = """^(?:(?:(?:[^:/?#]+):)?(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]+\]|[^/#:?]*))(?::[0-9]+)?))?([^?#]*)(?:\?[^#]*)?(?:#[a-zA-Z0-9\-_.!~*'();/?:@&=+$,[\]%]*)?)$""" + private val QUERY_REGEX = """^(?:(?:(?:[^:/?#]+):)?(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]+\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(\?[^#]*)?(?:#[a-zA-Z0-9\-_.!~*'();/?:@&=+$,[\]%]*)?)$""" + private val REF_REGEX = """^(?:(?:(?:[^:/?#]+):)?(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]+\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(#[a-zA-Z0-9\-_.!~*'();/?:@&=+$,[\]%]*)?)$""" + private val PROTOCOL_REGEX = """^(?:(?:([^:/?#]+):)?(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]+\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(?:#[a-zA-Z0-9\-_.!~*'();/?:@&=+$,[\]%]*)?)$""" + private val FILE_REGEX = """^(?:(?:(?:[^:/?#]+):)?(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]+\]|[^/#:?]*))(?::[0-9]+)?))?((?:[^?#]*)(?:\?[^#]*)?)(?:#[a-zA-Z0-9\-_.!~*'();/?:@&=+$,[\]%]*)?)$""" + private val AUTHORITY_REGEX = """^(?:(?:(?:[^:/?#]+):)?(?://((?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]+\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(?:#[a-zA-Z0-9\-_.!~*'();/?:@&=+$,[\]%]*)?)$""" + private val USERINFO_REGEX = """^(?:(?:(?:[^:/?#]+):)?(?://(?:(?:(?:([^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]+\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(?:#[a-zA-Z0-9\-_.!~*'();/?:@&=+$,[\]%]*)?)$""" // scalastyle:on } diff 
--git a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala index 2642d35773d..85e4cb0efc7 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala @@ -64,15 +64,15 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { "http://??", "http://??/", "http://#", - // "http://##", - // "http://##/", + "http://##", + "http://##/", "http://foo.bar?q=Spaces should be encoded", - "//", - // "//a", - // "///a", - // "///", + // "//", + "//a", + "///a", + "///", "http:///a", - // "foo.com", + "foo.com", "rdar://1234", "h://test", "http:// shouldfail.com", @@ -168,27 +168,7 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { import session.sqlContext.implicits._ Seq[String]( "http://", - // "http://.", - // "http://..", - // "http://../", - "http://##", - "http://##/", - "//a", - "///a", - "///", - "foo.com" - // "http://www.foo.bar.", - // "http://1.1.1.1.1", - // "http://123.123.123", - // "http://223.255.255.254", - // "http://142.42.1.1/", - // "http://142.42.1.1:8080/", - // "http://0.0.0.0", - // "http://10.1.1.0", - // "http://10.1.1.255", - // "http://224.1.1.1", - // "http://10.1.1.1", - // "http://10.1.1.254" + "//" ).toDF("urls") } @@ -205,11 +185,6 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { "parse_url(urls, 'USERINFO') as USERINFO") } - // def disableGpuRegex(): SparkConf = { - // new SparkConf() - // .set("spark.rapids.sql.regexp.enabled", "false") - // } - testSparkResultsAreEqual("Test parse_url edge cases from internet", validUrlEdgeCasesDf) { parseUrls } From 93a9342a0296f9372a0cc3f16accf444e8cd4cdf Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 8 Aug 2023 11:45:27 +0800 Subject: [PATCH 11/27] IPV6 VALIDATION --- .../spark/sql/rapids/urlFunctions.scala | 73 +++++++++++++++---- .../spark/rapids/UrlFunctionsSuite.scala | 42 +++++++++++ 2 files changed, 100 insertions(+), 15 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala index 8a493225170..70161c5d4d5 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala @@ -82,9 +82,10 @@ case class GpuParseUrl(children: Seq[Expression], } private def escapeRegex(str: String): String = { - // Escape all regex special characters. It is a workaround for /Q and /E not working + // Escape all regex special characters in \^$.⎮?*+(){}-[] + // It is a workaround for /Q and /E not working // in cudf regex, can use Pattern.quote(str) instead after they are supported. - str.replaceAll("""[\^$.⎮?*+(){}-]""", "\\\\$0") + str.replaceAll("""[\^$.|?*+()\[\]-]""", "\\$0") } private def getPattern(key: UTF8String): RegexProgram = { @@ -146,25 +147,67 @@ case class GpuParseUrl(children: Seq[Expression], // hostname = domainlabel [ "." ] | 1*( domainlabel "." ) toplabel [ "." 
] // domainlabel = alphanum | alphanum *( alphanum | "-" ) alphanum // toplabel = alpha | alpha *( alphanum | "-" ) alphanum - val hostname_regex = """((([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])|(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)+([a-zA-Z]|[a-zA-Z][a-zA-Z0-9\-]*[a-zA-Z]))\.?)""" - // TODO: ipv4_regex - val ipv4_regex = """(((25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])\.){3}(25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9]))""" - // TODO: ipv6_regex - // IPv6address = hexseq [ ":" IPv4address ] - // | hexseq [ "::" [ hexpost ] ] - // | "::" [ hexpost ] - // hexpost = hexseq | hexseq ":" IPv4address | IPv4address - // hexseq = hex4 *( ":" hex4) - // hex4 = 1*4HEXDIG - val ipv6_regex = """(\[[0-9A-Za-z%\.:]*\])""" + val hostnameRegex = """((([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])|(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)+([a-zA-Z]|[a-zA-Z][a-zA-Z0-9\-]*[a-zA-Z]))\.?)""" + // ipv4_regex + val ipv4Regex = """(((25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])\.){3}(25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9]))""" + val simpleIpv6Regex = """(\[[0-9A-Za-z%.:]*\])""" // scalastyle:on - val regex = "^(" + hostname_regex + "|" + ipv4_regex + "|" + ipv6_regex + ")$" + val regex = "^(" + hostnameRegex + "|" + ipv4Regex + "|" + simpleIpv6Regex + ")$" val prog = new RegexProgram(regex) - withResource(cv.matchesRe(prog)) { isMatch => + val HostnameIpv4Res = withResource(cv.matchesRe(prog)) { isMatch => withResource(Scalar.fromNull(DType.STRING)) { nullScalar => isMatch.ifElse(cv, nullScalar) } } + // match the simple ipv6 address, valid ipv6 only when necessary + val simpleIpv6Prog = new RegexProgram(simpleIpv6Regex) + withResource(cv.matchesRe(simpleIpv6Prog)) { isMatch => + val anyIpv6 = withResource(isMatch.any()) { a => + a.isValid && a.getBoolean + } + if (anyIpv6) { + withResource(HostnameIpv4Res) { _ => + unsetInvalidIpv6Host(HostnameIpv4Res, isMatch) + } + } else { + HostnameIpv4Res + } + } + } + + private def unsetInvalidIpv6Host(cv: ColumnVector, simpleMatched: ColumnVector): ColumnVector = { + // scalastyle:off line.size.limit + // regex basically copied from https://stackoverflow.com/questions/53497/regular-expression-that-matches-valid-ipv6-addresses + // spilt the ipv6 regex into 8 parts to avoid the regex size limit + val ipv6Regex1 = """(([0-9a-fA-F]{1,4}:){7,7}[0-9a-fA-F]{1,4})""" // 1:2:3:4:5:6:7:8 + val ipv6Regex2 = """(([0-9a-fA-F]{1,4}:){1,7}:)""" // 1:: 1:2:3:4:5:6:7:: + val ipv6Regex3 = """(([0-9a-fA-F]{1,4}:){1,6}:[0-9a-fA-F]{1,4})""" // 1::8 1:2:3:4:5:6::8 1:2:3:4:5:6::8 + val ipv6Regex4 = """(([0-9a-fA-F]{1,4}:){1,5}(:[0-9a-fA-F]{1,4}){1,2})""" // 1::7:8 1:2:3:4:5::7:8 1:2:3:4:5::8 + val ipv6Regex5 = """(([0-9a-fA-F]{1,4}:){1,4}(:[0-9a-fA-F]{1,4}){1,3})""" // 1::6:7:8 1:2:3:4::6:7:8 1:2:3:4::8 + val ipv6Regex6 = """(([0-9a-fA-F]{1,4}:){1,3}(:[0-9a-fA-F]{1,4}){1,4})""" // 1::5:6:7:8 1:2:3::5:6:7:8 1:2:3::8 + val ipv6Regex7 = """(([0-9a-fA-F]{1,4}:){1,2}(:[0-9a-fA-F]{1,4}){1,5})""" // 1::4:5:6:7:8 1:2::4:5:6:7:8 1:2::8 + val ipv6Regex8 = """([0-9a-fA-F]{1,4}:((:[0-9a-fA-F]{1,4}){1,6}))""" // 1::3:4:5:6:7:8 1::3:4:5:6:7:8 1::8 + val ipv6Regex9 = """(:((:[0-9a-fA-F]{1,4}){1,7}|:))""" // ::2:3:4:5:6:7:8 ::2:3:4:5:6:7:8 ::8 :: + val ipv6Regex10 = """fe80:((:([0-9a-fA-F]{1,4})?){1,4})?%[0-9a-zA-Z]+|""" // fe80::7:8%eth0 fe80::7:8%1 (link-local IPv6 addresses with zone index) + val ipv6Regex11 = """::(ffff(:0{1,4})?:)?((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.){3}(25[0-5]|(2[0-4]|1?[0-9])?[0-9])""" + // ::255.255.255.255 
::ffff:255.255.255.255 ::ffff:0:255.255.255.255 (IPv4-mapped IPv6 addresses and IPv4-translated addresses) + val ipv6Regex12 = """([0-9a-fA-F]{1,4}:){1,4}:((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.){3}(25[0-5]|(2[0-4]|1?[0-9])?[0-9])""" + // 2001:db8:3:4::192.0.2.33 64:ff9b::192.0.2.33 (IPv4-Embedded IPv6 Address) + // scalastyle:on + val regex = "^" + ipv6Regex1 + "|" + ipv6Regex2 + "|" + ipv6Regex3 + "|" + ipv6Regex4 + "|" + + ipv6Regex5 + ipv6Regex6 + "|" + ipv6Regex7 + "|" + ipv6Regex8 + "|" + ipv6Regex9 + + ipv6Regex10 + ipv6Regex11 + ipv6Regex12 + "$" + + val prog = new RegexProgram(regex) + + val invalidIpv6 = withResource(cv.matchesRe(prog)) { matched => + matched.not() + } + withResource(invalidIpv6) { _ => + withResource(Scalar.fromNull(DType.STRING)) { nullScalar => + invalidIpv6.ifElse(cv, nullScalar) + } + } } def doColumnar(numRows: Int, url: GpuScalar, partToExtract: GpuScalar): ColumnVector = { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala index 85e4cb0efc7..1fb80a44395 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala @@ -164,6 +164,44 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { ).toDF("urls") } + def urlIpv6Host(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + Seq[String]( + "http://[1:2:3:4:5:6:7:8]", + "http://[1::]", + "http://[1:2:3:4:5:6:7::]", + "http://[1::8]", + "http://[1:2:3:4:5:6::8]", + "http://[1:2:3:4:5:6::8]", + "http://[1::7:8]", + "http://[1:2:3:4:5::7:8]", + "http://[1:2:3:4:5::8]", + "http://[1::6:7:8]", + "http://[1:2:3:4::6:7:8]", + "http://[1:2:3:4::8]", + "http://[1::5:6:7:8]", + "http://[1:2:3::5:6:7:8]", + "http://[1:2:3::8]", + "http://[1::4:5:6:7:8]", + "http://[1:2::4:5:6:7:8]", + "http://[1:2::8]", + "http://[1::3:4:5:6:7:8]", + "http://[1::3:4:5:6:7:8]", + "http://[1::8]", + "http://[::2:3:4:5:6:7:8]", + "http://[::2:3:4:5:6:7:8]", + "http://[::8]", + "http://[::]", + "http://[fe80::7:8%eth0]", + "http://[fe80::7:8%1]", + "http://[::255.255.255.255]", + "http://[::ffff:255.255.255.255]", + "http://[::ffff:0:255.255.255.255]", + "http://[2001:db8:3:4::192.0.2.33]", + "http://[64:ff9b::192.0.2.33]" + ).toDF("urls") + } + def unsupportedUrlCases(session: SparkSession): DataFrame = { import session.sqlContext.implicits._ Seq[String]( @@ -201,6 +239,10 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { parseUrls } + testSparkResultsAreEqual("Test parse_url ipv6 host", urlIpv6Host) { + parseUrls + } + testSparkResultsAreEqual("Test parse_url unsupport cases", unsupportedUrlCases) { parseUrls } From 1ad665f6dbc2e594ef24cbc759889c74d7223fcf Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 8 Aug 2023 18:02:23 +0800 Subject: [PATCH 12/27] clean up --- integration_tests/src/main/python/url_test.py | 9 +------ .../nvidia/spark/rapids/GpuOverrides.scala | 7 +++++ .../spark/sql/rapids/urlFunctions.scala | 27 +++++++------------ .../sql/rapids/shims/RapidsErrorUtils.scala | 5 ---- .../sql/rapids/shims/RapidsErrorUtils.scala | 7 ----- .../sql/rapids/shims/RapidsErrorUtils.scala | 7 ----- .../sql/rapids/shims/RapidsErrorUtils.scala | 7 ----- .../sql/rapids/shims/RapidsErrorUtils.scala | 2 -- .../sql/rapids/shims/RapidsErrorUtils.scala | 7 ----- .../spark/rapids/UrlFunctionsSuite.scala | 23 +++++++++------- 10 files changed, 31 insertions(+), 70 deletions(-) diff --git 
a/integration_tests/src/main/python/url_test.py b/integration_tests/src/main/python/url_test.py index ac7b56f35a4..85277922486 100644 --- a/integration_tests/src/main/python/url_test.py +++ b/integration_tests/src/main/python/url_test.py @@ -32,7 +32,7 @@ def test_parse_url_host(): assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, url_gen).selectExpr( + lambda spark : unary_op_df(spark, url_gen, length=10).selectExpr( "a", "parse_url(a, 'HOST')" )) @@ -105,13 +105,6 @@ def test_parse_url_with_query_key(): "a", "parse_url(a, 'QUERY', 'key')" )) - -def test_parse_url_invalid_failonerror(): - assert_gpu_and_cpu_error( - lambda spark : unary_op_df(spark, StringGen()).selectExpr( - "a","parse_url(a, 'USERINFO')").collect(), - conf={'spark.sql.ansi.enabled': 'true'}, - error_message='IllegalArgumentException' if is_before_spark_320() else 'URISyntaxException') def test_parse_url_too_many_args(): assert_gpu_and_cpu_error( diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 27b5cdb5b30..4b5cd160892 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -3146,6 +3146,13 @@ object GpuOverrides extends Logging { Some(RepeatingParamCheck("key", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), (a, conf, p, r) => new ExprMeta[ParseUrl](a, conf, p, r) { val failOnError = a.failOnError + + override def tagExprForGpu(): Unit = { + if (failOnError) { + willNotWorkOnGpu("Fail on error is not supported on GPU when parsing urls.") + } + } + override def convertToGpu(): GpuExpression = { GpuParseUrl(childExprs.map(_.convertToGpu()), failOnError) } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala index 70161c5d4d5..81311a93b1c 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala @@ -16,8 +16,6 @@ package org.apache.spark.sql.rapids -import java.net.URISyntaxException - import ai.rapids.cudf.{ColumnVector, DType, RegexProgram, Scalar, Table} import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.Arm._ @@ -96,22 +94,13 @@ case class GpuParseUrl(children: Seq[Expression], } private def reValid(url: ColumnVector): ColumnVector = { - // TODO: Validite the url - val regex = """^[^ ]*$""" - val prog = new RegexProgram(regex) - withResource(url.matchesRe(prog)) { isMatch => - if (failOnErrorOverride) { - withResource(isMatch.all()) { allMatch => - if (!allMatch.getBoolean) { - val invalidUrl = UTF8String.fromString(url.toString()) - val exception = new URISyntaxException("", "") - throw RapidsErrorUtils.invalidUrlException(invalidUrl, exception) - } + // Simply check if urls contain spaces for now, most validations will be done when extracting. + withResource(Scalar.fromString(" ")) { blank => + withResource(url.stringContains(blank)) { isMatch => + withResource(Scalar.fromNull(DType.STRING)) { nullScalar => + isMatch.ifElse(nullScalar, url) } } - withResource(Scalar.fromNull(DType.STRING)) { nullScalar => - isMatch.ifElse(url, nullScalar) - } } } @@ -147,6 +136,10 @@ case class GpuParseUrl(children: Seq[Expression], // hostname = domainlabel [ "." ] | 1*( domainlabel "." ) toplabel [ "." 
] // domainlabel = alphanum | alphanum *( alphanum | "-" ) alphanum // toplabel = alpha | alpha *( alphanum | "-" ) alphanum + + // Note: Spark allow an empty authority component only when it's followed by a non-empty path, + // query component, or fragment component. But in plugin, parse_url just simply allow empty + // authority component without checking if it is followed something or not. val hostnameRegex = """((([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])|(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)+([a-zA-Z]|[a-zA-Z][a-zA-Z0-9\-]*[a-zA-Z]))\.?)""" // ipv4_regex val ipv4Regex = """(((25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])\.){3}(25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9]))""" @@ -159,7 +152,7 @@ case class GpuParseUrl(children: Seq[Expression], isMatch.ifElse(cv, nullScalar) } } - // match the simple ipv6 address, valid ipv6 only when necessary + // match the simple ipv6 address, valid ipv6 only when necessary cause the regex is very long val simpleIpv6Prog = new RegexProgram(simpleIpv6Regex) withResource(cv.matchesRe(simpleIpv6Prog)) { isMatch => val anyIpv6 = withResource(isMatch.any()) { a => diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index 5d74bbac86a..2084336cced 100644 --- a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.types.{DataType, Decimal, DecimalType} -import org.apache.spark.unsafe.types.UTF8String object RapidsErrorUtils { def invalidArrayIndexError(index: Int, numElements: Int, @@ -87,8 +86,4 @@ object RapidsErrorUtils { def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = { Some(TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments")) } - - def invalidUrlException(url: UTF8String, e: Throwable): Throwable = { - new IllegalArgumentException(s"Find an invaild url string ${url.toString}", e) - } } diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index 115d4e93ba8..c6122375cf2 100644 --- a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -24,14 +24,11 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.shims -import java.net.URISyntaxException - import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.types.{DataType, Decimal, DecimalType} -import org.apache.spark.unsafe.types.UTF8String object RapidsErrorUtils { def invalidArrayIndexError(index: Int, numElements: Int, @@ -93,8 +90,4 @@ object RapidsErrorUtils { def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = { Some(TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments")) } 
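  // Note: the invalidUrlException helpers removed from the shims in this commit
  // are dead code after this change: with failOnError (ANSI mode) the expression
  // now falls back to the CPU (see tagExprForGpu in GpuOverrides above), so the
  // GPU implementation never needs to raise this error itself.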
- - def invalidUrlException(url: UTF8String, e: URISyntaxException): Throwable = { - QueryExecutionErrors.invalidUrlError(url, e) - } } diff --git a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index b7b95016047..fed5d0c4f6c 100644 --- a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -19,14 +19,11 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.shims -import java.net.URISyntaxException - import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.types.{DataType, Decimal, DecimalType} -import org.apache.spark.unsafe.types.UTF8String object RapidsErrorUtils { def invalidArrayIndexError(index: Int, numElements: Int, @@ -91,8 +88,4 @@ object RapidsErrorUtils { def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = { Some(TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments")) } - - def invalidUrlException(url: UTF8String, e: URISyntaxException): Throwable = { - QueryExecutionErrors.invalidUrlError(url, e) - } } diff --git a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index 841b3a96078..e428a3377d5 100644 --- a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -23,15 +23,12 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.shims -import java.net.URISyntaxException - import org.apache.spark.SparkDateTimeException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, Decimal, DecimalType} -import org.apache.spark.unsafe.types.UTF8String object RapidsErrorUtils extends RapidsErrorUtilsFor330plus { @@ -87,8 +84,4 @@ object RapidsErrorUtils extends RapidsErrorUtilsFor330plus { def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = { Some(TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments")) } - - def invalidUrlException(url: UTF8String, e: URISyntaxException): Throwable = { - QueryExecutionErrors.invalidUrlError(url, e) - } } diff --git a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index aaaecc86256..9f0a365f6dc 100644 --- a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -20,8 +20,6 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.shims -import java.net.URISyntaxException - import org.apache.spark.SparkDateTimeException import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext} diff --git a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index 032fed254ef..cca1493bcc1 100644 --- a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -20,15 +20,12 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.shims -import java.net.URISyntaxException - import org.apache.spark.SparkDateTimeException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, Decimal, DecimalType} -import org.apache.spark.unsafe.types.UTF8String object RapidsErrorUtils extends RapidsErrorUtilsFor330plus { @@ -98,8 +95,4 @@ object RapidsErrorUtils extends RapidsErrorUtilsFor330plus { ) None } - - def invalidUrlException(url: UTF8String, e: URISyntaxException): Throwable = { - QueryExecutionErrors.invalidUrlError(url, e) - } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala index 1fb80a44395..f8c37f3bf97 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala @@ -202,13 +202,16 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { ).toDF("urls") } - def unsupportedUrlCases(session: SparkSession): DataFrame = { - import session.sqlContext.implicits._ - Seq[String]( - "http://", - "//" - ).toDF("urls") - } + // def unsupportedUrlCases(session: SparkSession): DataFrame = { + // // Spark allow an empty authority component only when it's followed by a non-empty path, + // // query component, or fragment component. But in plugin, parse_url just simply allow + // // empty authority component without checking if it is followed something or not. 
+  //   import session.sqlContext.implicits._
+  //   Seq[String](
+  //     "http://",
+  //     "//"
+  //   ).toDF("urls")
+  // }
 
   def parseUrls(frame: DataFrame): DataFrame = {
     frame.selectExpr(
@@ -243,9 +246,9 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite {
     parseUrls
   }
 
-  testSparkResultsAreEqual("Test parse_url unsupport cases", unsupportedUrlCases) {
-    parseUrls
-  }
+  // testSparkResultsAreEqual("Test parse_url unsupport cases", unsupportedUrlCases) {
+  //   parseUrls
+  // }
 
   testSparkResultsAreEqual("Test parse_url with query and key", urlWithQueryKey) {
     frame => frame.selectExpr(

From 3edb9298fd392f0e12104c3422101a7bbed9667c Mon Sep 17 00:00:00 2001
From: Haoyang Li
Date: Wed, 9 Aug 2023 17:42:37 +0800
Subject: [PATCH 13/27] Fix ipv6 validation, it is still wip

---
 docs/compatibility.md                         | 10 +++
 .../spark/sql/rapids/urlFunctions.scala       | 77 ++++++++++---------
 .../spark/rapids/UrlFunctionsSuite.scala      | 50 ++++--------
 3 files changed, 63 insertions(+), 74 deletions(-)

diff --git a/docs/compatibility.md b/docs/compatibility.md
index 0d2e11633de..37353350f8a 100644
--- a/docs/compatibility.md
+++ b/docs/compatibility.md
@@ -456,6 +456,16 @@ Spark stores timestamps internally relative to the JVM time zone. Converting an
 between time zones is not currently supported on the GPU. Therefore operations involving timestamps
 will only be GPU-accelerated if the time zone used by the JVM is UTC.
 
+## URL parsing
+
+In Spark, `parse_url` is based on Java's URI library, while the implementation in the RAPIDS Accelerator is based on regex extraction. Therefore, the results may differ in some edge cases.
+
+These are the known cases where running on the GPU will produce different results to the CPU:
+
+- Spark allows an empty authority component only when it is followed by a non-empty path,
+  query component, or fragment component. The plugin, however, simply allows an empty
+  authority component without checking whether anything follows it. So `parse_url('http://', 'HOST')` will return `null` in Spark, but `""` in the plugin.
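The empty-authority divergence called out in the bullet above is easy to check interactively. The following spark-shell snippet is a minimal sketch rather than part of the patch: it assumes the plugin jars are on the classpath and that `spark` is the active SparkSession, and it uses the real `spark.rapids.sql.enabled` config to flip between the two engines:

    // CPU: Spark's java.net.URI-based parser rejects "http://" because the
    // empty authority is not followed by a path, query, or fragment.
    spark.conf.set("spark.rapids.sql.enabled", "false")
    spark.sql("SELECT parse_url('http://', 'HOST')").show()  // NULL

    // GPU: the plugin's regex accepts the empty authority, so the host
    // comes back as an empty string instead of null.
    spark.conf.set("spark.rapids.sql.enabled", "true")
    spark.sql("SELECT parse_url('http://', 'HOST')").show()  // ""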
+ ## Windowing ### Window Functions diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala index 81311a93b1c..2061c2e1149 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala @@ -42,9 +42,6 @@ object GpuParseUrl { private val REGEXPREFIX = """(&|^)(""" private val REGEXSUBFIX = "=)([^&]*)" // scalastyle:off line.size.limit - // val regex = """^(?:(?:([^:/?#]+):)?(?://((?:(?:([^:]*:?[^\@]*)@)?(\[[0-9A-Za-z%.:]*\]|[^/#:?]*))(?::[0-9]+)?))?([^?#]*)(\?[^#]*)?(#.*)?)$""" - // a 0 0 b 1 c d 2 2 d 3 3ce e 1b 4 45 5 6 6 a - // val regex = """^(?:(?:(?:[^:/?#]+):)?(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]+\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(?:#.*)?)$""" private val HOST_REGEX = """^(?:(?:(?:[^:/?#]+):)?(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(\[[0-9A-Za-z%.:]+\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(?:#[a-zA-Z0-9\-_.!~*'();/?:@&=+$,[\]%]*)?)$""" private val PATH_REGEX = """^(?:(?:(?:[^:/?#]+):)?(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]+\]|[^/#:?]*))(?::[0-9]+)?))?([^?#]*)(?:\?[^#]*)?(?:#[a-zA-Z0-9\-_.!~*'();/?:@&=+$,[\]%]*)?)$""" private val QUERY_REGEX = """^(?:(?:(?:[^:/?#]+):)?(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]+\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(\?[^#]*)?(?:#[a-zA-Z0-9\-_.!~*'();/?:@&=+$,[\]%]*)?)$""" @@ -79,17 +76,8 @@ case class GpuParseUrl(children: Seq[Expression], super[ExpectsInputTypes].checkInputDataTypes() } - private def escapeRegex(str: String): String = { - // Escape all regex special characters in \^$.⎮?*+(){}-[] - // It is a workaround for /Q and /E not working - // in cudf regex, can use Pattern.quote(str) instead after they are supported. - str.replaceAll("""[\^$.|?*+()\[\]-]""", "\\$0") - } - private def getPattern(key: UTF8String): RegexProgram = { - // SPARK-44500: in spark, the key is treated as a regex. - // In plugin we quote the key to be sure that we treat it as a literal value. - val regex = REGEXPREFIX + escapeRegex(key.toString) + REGEXSUBFIX + val regex = REGEXPREFIX + key.toString + REGEXSUBFIX new RegexProgram(regex) } @@ -136,14 +124,9 @@ case class GpuParseUrl(children: Seq[Expression], // hostname = domainlabel [ "." ] | 1*( domainlabel "." ) toplabel [ "." ] // domainlabel = alphanum | alphanum *( alphanum | "-" ) alphanum // toplabel = alpha | alpha *( alphanum | "-" ) alphanum - - // Note: Spark allow an empty authority component only when it's followed by a non-empty path, - // query component, or fragment component. But in plugin, parse_url just simply allow empty - // authority component without checking if it is followed something or not. 
val hostnameRegex = """((([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])|(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)+([a-zA-Z]|[a-zA-Z][a-zA-Z0-9\-]*[a-zA-Z]))\.?)""" - // ipv4_regex val ipv4Regex = """(((25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])\.){3}(25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9]))""" - val simpleIpv6Regex = """(\[[0-9A-Za-z%.:]*\])""" + val simpleIpv6Regex = """(\[[0-9A-Za-z%.:]+])""" // scalastyle:on val regex = "^(" + hostnameRegex + "|" + ipv4Regex + "|" + simpleIpv6Regex + ")$" val prog = new RegexProgram(regex) @@ -172,33 +155,51 @@ case class GpuParseUrl(children: Seq[Expression], // scalastyle:off line.size.limit // regex basically copied from https://stackoverflow.com/questions/53497/regular-expression-that-matches-valid-ipv6-addresses // spilt the ipv6 regex into 8 parts to avoid the regex size limit - val ipv6Regex1 = """(([0-9a-fA-F]{1,4}:){7,7}[0-9a-fA-F]{1,4})""" // 1:2:3:4:5:6:7:8 - val ipv6Regex2 = """(([0-9a-fA-F]{1,4}:){1,7}:)""" // 1:: 1:2:3:4:5:6:7:: - val ipv6Regex3 = """(([0-9a-fA-F]{1,4}:){1,6}:[0-9a-fA-F]{1,4})""" // 1::8 1:2:3:4:5:6::8 1:2:3:4:5:6::8 - val ipv6Regex4 = """(([0-9a-fA-F]{1,4}:){1,5}(:[0-9a-fA-F]{1,4}){1,2})""" // 1::7:8 1:2:3:4:5::7:8 1:2:3:4:5::8 - val ipv6Regex5 = """(([0-9a-fA-F]{1,4}:){1,4}(:[0-9a-fA-F]{1,4}){1,3})""" // 1::6:7:8 1:2:3:4::6:7:8 1:2:3:4::8 - val ipv6Regex6 = """(([0-9a-fA-F]{1,4}:){1,3}(:[0-9a-fA-F]{1,4}){1,4})""" // 1::5:6:7:8 1:2:3::5:6:7:8 1:2:3::8 - val ipv6Regex7 = """(([0-9a-fA-F]{1,4}:){1,2}(:[0-9a-fA-F]{1,4}){1,5})""" // 1::4:5:6:7:8 1:2::4:5:6:7:8 1:2::8 - val ipv6Regex8 = """([0-9a-fA-F]{1,4}:((:[0-9a-fA-F]{1,4}){1,6}))""" // 1::3:4:5:6:7:8 1::3:4:5:6:7:8 1::8 - val ipv6Regex9 = """(:((:[0-9a-fA-F]{1,4}){1,7}|:))""" // ::2:3:4:5:6:7:8 ::2:3:4:5:6:7:8 ::8 :: - val ipv6Regex10 = """fe80:((:([0-9a-fA-F]{1,4})?){1,4})?%[0-9a-zA-Z]+|""" // fe80::7:8%eth0 fe80::7:8%1 (link-local IPv6 addresses with zone index) - val ipv6Regex11 = """::(ffff(:0{1,4})?:)?((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.){3}(25[0-5]|(2[0-4]|1?[0-9])?[0-9])""" + val ipv6Regex1 = """([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?""" + // 1:2:3:4:5:6:7:8 + val ipv6Regex2 = """([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?:""" + // 1:: 1:2:3:4:5:6:7:: + val ipv6Regex3 = """(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)""" + // 1::8 1:2:3:4:5:6::8 1:2:3:4:5:6::8 + val ipv6Regex4 = 
"""(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)""" + // 1::7:8 1:2:3:4:5::7:8 1:2:3:4:5::8 + val ipv6Regex5 = """(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)""" + // 1::6:7:8 1:2:3:4::6:7:8 1:2:3:4::8 + val ipv6Regex6 = """(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)""" + // 1::5:6:7:8 1:2:3::5:6:7:8 1:2:3::8 + val ipv6Regex7 = """(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)""" + // 1::4:5:6:7:8 1:2::4:5:6:7:8 1:2::8 + val ipv6Regex8 = """([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:((:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?))""" + // 1::3:4:5:6:7:8 1::3:4:5:6:7:8 1::8 + val ipv6Regex9 = """(:((:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?|:))""" + // ::2:3:4:5:6:7:8 ::2:3:4:5:6:7:8 ::8 :: + val ipv6Regex10 = """(fe80:((:([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)(:([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)?(:([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)?(:([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)?)?%[0-9a-zA-Z]+)""" + // fe80::7:8%eth0 fe80::7:8%1 (link-local IPv6 addresses with zone index) + val ipv6Regex11 = """(::(ffff(:00?0?0?)?:)?((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)?((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)?(25[0-5]|(2[0-4]|1?[0-9])?[0-9]))""" // ::255.255.255.255 ::ffff:255.255.255.255 ::ffff:0:255.255.255.255 (IPv4-mapped IPv6 addresses and IPv4-translated addresses) - val ipv6Regex12 = """([0-9a-fA-F]{1,4}:){1,4}:((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.){3}(25[0-5]|(2[0-4]|1?[0-9])?[0-9])""" + val ipv6Regex12 = 
"""([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?:((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)(25[0-5]|(2[0-4]|1?[0-9])?[0-9])""" // 2001:db8:3:4::192.0.2.33 64:ff9b::192.0.2.33 (IPv4-Embedded IPv6 Address) // scalastyle:on - val regex = "^" + ipv6Regex1 + "|" + ipv6Regex2 + "|" + ipv6Regex3 + "|" + ipv6Regex4 + "|" + - ipv6Regex5 + ipv6Regex6 + "|" + ipv6Regex7 + "|" + ipv6Regex8 + "|" + ipv6Regex9 + - ipv6Regex10 + ipv6Regex11 + ipv6Regex12 + "$" - + val regex = """^\[(""" + ipv6Regex1 + "|" + ipv6Regex2 + "|" + ipv6Regex3 + "|" + ipv6Regex4 + "|" + + ipv6Regex5 + "|" + ipv6Regex6 + "|" + ipv6Regex7 + "|" + ipv6Regex8 + "|" + ipv6Regex9 + "|" + + ipv6Regex10 + "|" + ipv6Regex11 + "|" + ipv6Regex12 + """)]$""" + // ^\[((([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:){7,7}[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)|(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:){1,7}:))]$ val prog = new RegexProgram(regex) + GpuColumnVector.debug("cv", cv) + GpuColumnVector.debug("simpleMatched", simpleMatched) + val invalidIpv6 = withResource(cv.matchesRe(prog)) { matched => - matched.not() + withResource(matched.not()) { invalid => + GpuColumnVector.debug("invalid", invalid) + simpleMatched.and(invalid) + } } withResource(invalidIpv6) { _ => withResource(Scalar.fromNull(DType.STRING)) { nullScalar => - invalidIpv6.ifElse(cv, nullScalar) + val x = invalidIpv6.ifElse(nullScalar, cv) + GpuColumnVector.debug("unsetinvalidIpv6", x) + x } } } @@ -210,8 +211,8 @@ case class GpuParseUrl(children: Seq[Expression], } def doColumnar(url: GpuColumnVector, partToExtract: GpuScalar): ColumnVector = { - val valid = reValid(url.getBase) val part = partToExtract.getValue.asInstanceOf[UTF8String].toString + val valid = reValid(url.getBase) val matched = withResource(valid) { _ => reMatch(valid, part) } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala index f8c37f3bf97..d4cf3f69935 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala @@ -156,17 +156,10 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { ).toDF("urls") } - def urlWithRegexLikeQuery(session: SparkSession): DataFrame = { - import session.sqlContext.implicits._ - Seq[String]( - "http://foo/bar?abc=BAD&a.c=GOOD", - "http://foo/bar?a.c=GOOD&abc=BAD" - ).toDF("urls") - } - def urlIpv6Host(session: SparkSession): DataFrame = { import session.sqlContext.implicits._ Seq[String]( + "http://[1:2:3:4:5:6:7:8:9:10]", "http://[1:2:3:4:5:6:7:8]", "http://[1::]", "http://[1:2:3:4:5:6:7::]", @@ -202,16 +195,16 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { ).toDF("urls") } - // def unsupportedUrlCases(session: SparkSession): DataFrame = { - // // Spark allow an empty authority component only when it's followed by a non-empty path, - // // query component, or fragment component. But in plugin, parse_url just simply allow - // // empty authority component without checking if it is followed something or not. 
-  //   import session.sqlContext.implicits._
-  //   Seq[String](
-  //     "http://",
-  //     "//"
-  //   ).toDF("urls")
-  // }
+  def unsupportedUrlCases(session: SparkSession): DataFrame = {
+    // Spark allow an empty authority component only when it's followed by a non-empty path,
+    // query component, or fragment component. But in plugin, parse_url just simply allow
+    // empty authority component without checking if it is followed something or not.
+    import session.sqlContext.implicits._
+    Seq[String](
+      "http://",
+      "//"
+    ).toDF("urls")
+  }
 
   def parseUrls(frame: DataFrame): DataFrame = {
     frame.selectExpr(
@@ -246,9 +239,9 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite {
     parseUrls
   }
 
-  // testSparkResultsAreEqual("Test parse_url unsupport cases", unsupportedUrlCases) {
-  //   parseUrls
-  // }
+  testSparkResultsAreEqual("Test parse_url unsupport cases", unsupportedUrlCases) {
+    parseUrls
+  }
 
   testSparkResultsAreEqual("Test parse_url with query and key", urlWithQueryKey) {
     frame => frame.selectExpr(
       "urls",
       "parse_url(urls, 'QUERY', 'foo') as QUERY",
       "parse_url(urls, 'QUERY', 'baz') as QUERY")
   }
-
-  test("Test parse_url with regex like query") {
-    withGpuSparkSession(spark => {
-      val frame = urlWithRegexLikeQuery(spark)
-      val result = frame.selectExpr(
-        "urls",
-        "parse_url(urls, 'QUERY', 'a.c') as QUERY")
-      import spark.implicits._
-      val expected = Seq(
-        ("http://foo/bar?abc=BAD&a.c=GOOD", "GOOD"),
-        ("http://foo/bar?a.c=GOOD&abc=BAD", "GOOD")
-      ).toDF("urls", "QUERY")
-      assert(result.collect().deep == expected.collect().deep)
-    })
-  }
 }
\ No newline at end of file

From daa61ea8ab89c480b77185c5ddfb4c063d8f2c51 Mon Sep 17 00:00:00 2001
From: Haoyang Li
Date: Wed, 9 Aug 2023 19:28:41 +0800
Subject: [PATCH 14/27] Fix ipv6 validation and some clean up

---
 docs/compatibility.md                         |  8 +-
 integration_tests/src/main/python/url_test.py |  2 +-
 .../spark/sql/rapids/urlFunctions.scala       | 83 +++++++++----------
 .../spark/rapids/UrlFunctionsSuite.scala      | 16 ----
 4 files changed, 44 insertions(+), 65 deletions(-)

diff --git a/docs/compatibility.md b/docs/compatibility.md
index 37353350f8a..5c2d581cfd1 100644
--- a/docs/compatibility.md
+++ b/docs/compatibility.md
@@ -463,8 +463,12 @@ In Spark, `parse_url` is based on Java's URI library, while the implementation in
 
 These are the known cases where running on the GPU will produce different results to the CPU:
 
 - Spark allows an empty authority component only when it is followed by a non-empty path,
-  query component, or fragment component. The plugin, however, simply allows an empty
-  authority component without checking whether anything follows it. So `parse_url('http://', 'HOST')` will return `null` in Spark, but `""` in the plugin.
+  query component, or fragment component. The plugin, however, simply allows an empty
+  authority component without checking whether anything follows it. So `parse_url('http://', 'HOST')` will
+  return `null` in Spark, but `""` in the plugin.
+- If the input url has an invalid IPv6 address, Spark will return `null` for all components, but the plugin will parse the other
+  components except `HOST` as normal.
So `http://userinfo@[1:2:3:4:5:6:7:8:9:10]/path?query=1#Ref`'s result will be + `[null,/path,query=1,Ref,http,/path?query=1,userinfo@[1:2:3:4:5:6:7:8:9:10],userinfo]` ## Windowing diff --git a/integration_tests/src/main/python/url_test.py b/integration_tests/src/main/python/url_test.py index 85277922486..289b5f8371b 100644 --- a/integration_tests/src/main/python/url_test.py +++ b/integration_tests/src/main/python/url_test.py @@ -111,4 +111,4 @@ def test_parse_url_too_many_args(): lambda spark : unary_op_df(spark, StringGen()).selectExpr( "a","parse_url(a, 'USERINFO', 'key', 'value')").collect(), conf={}, - error_message='parse_url function requires two or three arguments') \ No newline at end of file + error_message='parse_url function requires two or three arguments') diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala index 2061c2e1149..2ebb62c83af 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala @@ -50,6 +50,40 @@ object GpuParseUrl { private val FILE_REGEX = """^(?:(?:(?:[^:/?#]+):)?(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]+\]|[^/#:?]*))(?::[0-9]+)?))?((?:[^?#]*)(?:\?[^#]*)?)(?:#[a-zA-Z0-9\-_.!~*'();/?:@&=+$,[\]%]*)?)$""" private val AUTHORITY_REGEX = """^(?:(?:(?:[^:/?#]+):)?(?://((?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]+\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(?:#[a-zA-Z0-9\-_.!~*'();/?:@&=+$,[\]%]*)?)$""" private val USERINFO_REGEX = """^(?:(?:(?:[^:/?#]+):)?(?://(?:(?:(?:([^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]+\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(?:#[a-zA-Z0-9\-_.!~*'();/?:@&=+$,[\]%]*)?)$""" + // HostName parsing followed rules in java URI lib: + // hostname = domainlabel [ "." ] | 1*( domainlabel "." ) toplabel [ "." 
] + // domainlabel = alphanum | alphanum *( alphanum | "-" ) alphanum + // toplabel = alpha | alpha *( alphanum | "-" ) alphanum + val hostnameRegex = """((([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])|(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)+([a-zA-Z]|[a-zA-Z][a-zA-Z0-9\-]*[a-zA-Z]))\.?)""" + val ipv4Regex = """(((25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])\.){3}(25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9]))""" + val simpleIpv6Regex = """(\[[0-9A-Za-z%.:]+])""" + // regex basically copied from https://stackoverflow.com/questions/53497/regular-expression-that-matches-valid-ipv6-addresses + val ipv6Regex1 = """([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?""" + // 1:2:3:4:5:6:7:8 + val ipv6Regex2 = """([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?:""" + // 1:: 1:2:3:4:5:6:7:: + val ipv6Regex3 = """(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)""" + // 1::8 1:2:3:4:5:6::8 1:2:3:4:5:6::8 + val ipv6Regex4 = """(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)""" + // 1::7:8 1:2:3:4:5::7:8 1:2:3:4:5::8 + val ipv6Regex5 = """(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)""" + // 1::6:7:8 1:2:3:4::6:7:8 1:2:3:4::8 + val ipv6Regex6 = """(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)""" + // 1::5:6:7:8 1:2:3::5:6:7:8 1:2:3::8 + val ipv6Regex7 = """(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)""" + // 
1::4:5:6:7:8 1:2::4:5:6:7:8 1:2::8 + val ipv6Regex8 = """([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:((:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?))""" + // 1::3:4:5:6:7:8 1::3:4:5:6:7:8 1::8 + val ipv6Regex9 = """(:((:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?|:))""" + // ::2:3:4:5:6:7:8 ::2:3:4:5:6:7:8 ::8 :: + val ipv6Regex10 = """(fe80:((:([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)(:([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)?(:([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)?(:([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)?)?%[0-9a-zA-Z]+)""" + // fe80::7:8%eth0 fe80::7:8%1 (link-local IPv6 addresses with zone index) + val ipv6Regex11 = """(::((ffff|FFFF)(:00?0?0?)?:)?((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)?((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)?(25[0-5]|(2[0-4]|1?[0-9])?[0-9]))""" + // ::255.255.255.255 ::ffff:255.255.255.255 ::ffff:0:255.255.255.255 (IPv4-mapped IPv6 addresses and IPv4-translated addresses) + val ipv6Regex12 = """([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?:((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)(25[0-5]|(2[0-4]|1?[0-9])?[0-9])""" + // 2001:db8:3:4::192.0.2.33 64:ff9b::192.0.2.33 (IPv4-Embedded IPv6 Address) + val ipv6Regex13 = """(0:0:0:0:0:(0|FFFF|ffff):((25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])\.){3}(25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9]))""" + // 0:0:0:0:0:0:13.1.68.3 // scalastyle:on } @@ -119,15 +153,6 @@ case class GpuParseUrl(children: Seq[Expression], } private def unsetInvalidHost(cv: ColumnVector): ColumnVector = { - // scalastyle:off line.size.limit - // HostName parsing followed rules in java URI lib: - // hostname = domainlabel [ "." ] | 1*( domainlabel "." ) toplabel [ "." 
] - // domainlabel = alphanum | alphanum *( alphanum | "-" ) alphanum - // toplabel = alpha | alpha *( alphanum | "-" ) alphanum - val hostnameRegex = """((([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])|(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)+([a-zA-Z]|[a-zA-Z][a-zA-Z0-9\-]*[a-zA-Z]))\.?)""" - val ipv4Regex = """(((25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])\.){3}(25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9]))""" - val simpleIpv6Regex = """(\[[0-9A-Za-z%.:]+])""" - // scalastyle:on val regex = "^(" + hostnameRegex + "|" + ipv4Regex + "|" + simpleIpv6Regex + ")$" val prog = new RegexProgram(regex) val HostnameIpv4Res = withResource(cv.matchesRe(prog)) { isMatch => @@ -152,54 +177,20 @@ case class GpuParseUrl(children: Seq[Expression], } private def unsetInvalidIpv6Host(cv: ColumnVector, simpleMatched: ColumnVector): ColumnVector = { - // scalastyle:off line.size.limit - // regex basically copied from https://stackoverflow.com/questions/53497/regular-expression-that-matches-valid-ipv6-addresses - // spilt the ipv6 regex into 8 parts to avoid the regex size limit - val ipv6Regex1 = """([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?""" - // 1:2:3:4:5:6:7:8 - val ipv6Regex2 = """([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?:""" - // 1:: 1:2:3:4:5:6:7:: - val ipv6Regex3 = """(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)""" - // 1::8 1:2:3:4:5:6::8 1:2:3:4:5:6::8 - val ipv6Regex4 = """(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)""" - // 1::7:8 1:2:3:4:5::7:8 1:2:3:4:5::8 - val ipv6Regex5 = """(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)""" - // 1::6:7:8 1:2:3:4::6:7:8 1:2:3:4::8 - val ipv6Regex6 = 
"""(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)""" - // 1::5:6:7:8 1:2:3::5:6:7:8 1:2:3::8 - val ipv6Regex7 = """(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)""" - // 1::4:5:6:7:8 1:2::4:5:6:7:8 1:2::8 - val ipv6Regex8 = """([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:((:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?))""" - // 1::3:4:5:6:7:8 1::3:4:5:6:7:8 1::8 - val ipv6Regex9 = """(:((:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?|:))""" - // ::2:3:4:5:6:7:8 ::2:3:4:5:6:7:8 ::8 :: - val ipv6Regex10 = """(fe80:((:([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)(:([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)?(:([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)?(:([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)?)?%[0-9a-zA-Z]+)""" - // fe80::7:8%eth0 fe80::7:8%1 (link-local IPv6 addresses with zone index) - val ipv6Regex11 = """(::(ffff(:00?0?0?)?:)?((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)?((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)?(25[0-5]|(2[0-4]|1?[0-9])?[0-9]))""" - // ::255.255.255.255 ::ffff:255.255.255.255 ::ffff:0:255.255.255.255 (IPv4-mapped IPv6 addresses and IPv4-translated addresses) - val ipv6Regex12 = """([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?:((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)(25[0-5]|(2[0-4]|1?[0-9])?[0-9])""" - // 2001:db8:3:4::192.0.2.33 64:ff9b::192.0.2.33 (IPv4-Embedded IPv6 Address) - // scalastyle:on val regex = """^\[(""" + ipv6Regex1 + "|" + ipv6Regex2 + "|" + ipv6Regex3 + "|" + ipv6Regex4 + "|" + ipv6Regex5 + "|" + ipv6Regex6 + "|" + ipv6Regex7 + "|" + ipv6Regex8 + "|" + ipv6Regex9 + "|" + - ipv6Regex10 + "|" + ipv6Regex11 + "|" + ipv6Regex12 + """)]$""" - // ^\[((([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:){7,7}[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)|(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:){1,7}:))]$ + ipv6Regex10 + "|" + ipv6Regex11 + "|" + ipv6Regex12 + "|" + ipv6Regex13 + """)(%[a-zA-Z0-9]*)?]$""" + val prog = new RegexProgram(regex) - GpuColumnVector.debug("cv", cv) - GpuColumnVector.debug("simpleMatched", simpleMatched) - val invalidIpv6 = withResource(cv.matchesRe(prog)) { matched => 
withResource(matched.not()) { invalid => - GpuColumnVector.debug("invalid", invalid) simpleMatched.and(invalid) } } withResource(invalidIpv6) { _ => withResource(Scalar.fromNull(DType.STRING)) { nullScalar => - val x = invalidIpv6.ifElse(nullScalar, cv) - GpuColumnVector.debug("unsetinvalidIpv6", x) - x + invalidIpv6.ifElse(nullScalar, cv) } } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala index d4cf3f69935..4a7ee5ea838 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala @@ -159,7 +159,6 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { def urlIpv6Host(session: SparkSession): DataFrame = { import session.sqlContext.implicits._ Seq[String]( - "http://[1:2:3:4:5:6:7:8:9:10]", "http://[1:2:3:4:5:6:7:8]", "http://[1::]", "http://[1:2:3:4:5:6:7::]", @@ -195,17 +194,6 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { ).toDF("urls") } - def unsupportedUrlCases(session: SparkSession): DataFrame = { - // Spark allow an empty authority component only when it's followed by a non-empty path, - // query component, or fragment component. But in plugin, parse_url just simply allow - // empty authority component without checking if it is followed something or not. - import session.sqlContext.implicits._ - Seq[String]( - "http://", - "//" - ).toDF("urls") - } - def parseUrls(frame: DataFrame): DataFrame = { frame.selectExpr( "urls", @@ -239,10 +227,6 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { parseUrls } - testSparkResultsAreEqual("Test parse_url unsupport cases", unsupportedUrlCases) { - parseUrls - } - testSparkResultsAreEqual("Test parse_url with query and key", urlWithQueryKey) { frame => frame.selectExpr( "urls", From b3abaf66342de4262ec38e374310e53854028884 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 19 Oct 2023 14:13:58 +0800 Subject: [PATCH 15/27] Use parse_url kernel for PROTOCOL parsing Signed-off-by: Haoyang Li --- integration_tests/src/main/python/url_test.py | 56 ------ .../nvidia/spark/rapids/GpuOverrides.scala | 3 +- .../spark/sql/rapids/urlFunctions.scala | 178 +----------------- .../spark/rapids/UrlFunctionsSuite.scala | 175 ++++++----------- 4 files changed, 70 insertions(+), 342 deletions(-) diff --git a/integration_tests/src/main/python/url_test.py b/integration_tests/src/main/python/url_test.py index 289b5f8371b..32a838d7121 100644 --- a/integration_tests/src/main/python/url_test.py +++ b/integration_tests/src/main/python/url_test.py @@ -29,34 +29,6 @@ r'(:[0-9]{1,3}){0,1}(/[a-z]{1,3}){0,3}(\?key=[a-z]{1,3}){0,1}(#([a-z]{1,3})){0,1}' url_gen = StringGen(url_pattern) - -def test_parse_url_host(): - assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, url_gen, length=10).selectExpr( - "a", - "parse_url(a, 'HOST')" - )) - -def test_parse_url_path(): - assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, url_gen).selectExpr( - "a", - "parse_url(a, 'PATH')" - )) - -def test_parse_url_query(): - assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, url_gen).selectExpr( - "a", - "parse_url(a, 'QUERY')" - )) - -def test_parse_url_ref(): - assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, url_gen).selectExpr( - "a", - "parse_url(a, 'REF')" - )) def test_parse_url_protocol(): assert_gpu_and_cpu_are_equal_collect( @@ -65,27 +37,6 @@ def 
test_parse_url_protocol(): "parse_url(a, 'PROTOCOL')" )) -def test_parse_url_file(): - assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, url_gen).selectExpr( - "a", - "parse_url(a, 'FILE')" - )) - -def test_parse_url_authority(): - assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, url_gen).selectExpr( - "a", - "parse_url(a, 'AUTHORITY')" - )) - -def test_parse_url_userinfo(): - assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, url_gen).selectExpr( - "a", - "parse_url(a, 'USERINFO')" - )) - def test_parse_url_with_no_query_key(): assert_gpu_and_cpu_are_equal_collect( lambda spark : unary_op_df(spark, url_gen, length=100).selectExpr( @@ -99,13 +50,6 @@ def test_parse_url_with_no_query_key(): "parse_url(a, 'USERINFO', '')" )) -def test_parse_url_with_query_key(): - assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, StringGen(url_pattern_with_key)).selectExpr( - "a", - "parse_url(a, 'QUERY', 'key')" - )) - def test_parse_url_too_many_args(): assert_gpu_and_cpu_error( lambda spark : unary_op_df(spark, StringGen()).selectExpr( diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 8bbec860a61..2bc555f42d5 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -3231,7 +3231,8 @@ object GpuOverrides extends Logging { "Extracts a part from a URL", ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, Seq(ParamCheck("url", TypeSig.STRING, TypeSig.STRING), - ParamCheck("partToExtract", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING)), + ParamCheck("partToExtract", TypeSig.lit(TypeEnum.STRING).withPsNote( + TypeEnum.STRING, "only support partToExtract=PROTOCOL"), TypeSig.STRING)), // Should really be an OptionalParam Some(RepeatingParamCheck("key", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), (a, conf, p, r) => new ExprMeta[ParseUrl](a, conf, p, r) { diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala index 2ebb62c83af..a35a294dab3 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala @@ -16,10 +16,11 @@ package org.apache.spark.sql.rapids -import ai.rapids.cudf.{ColumnVector, DType, RegexProgram, Scalar, Table} +import ai.rapids.cudf.ColumnVector import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.Arm._ import com.nvidia.spark.rapids.RapidsPluginImplicits._ +import com.nvidia.spark.rapids.jni.ParseURI import com.nvidia.spark.rapids.shims.ShimExpression import org.apache.spark.sql.catalyst.analysis.TypeCheckResult @@ -39,52 +40,6 @@ object GpuParseUrl { private val FILE = "FILE" private val AUTHORITY = "AUTHORITY" private val USERINFO = "USERINFO" - private val REGEXPREFIX = """(&|^)(""" - private val REGEXSUBFIX = "=)([^&]*)" - // scalastyle:off line.size.limit - private val HOST_REGEX = """^(?:(?:(?:[^:/?#]+):)?(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(\[[0-9A-Za-z%.:]+\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(?:#[a-zA-Z0-9\-_.!~*'();/?:@&=+$,[\]%]*)?)$""" - private val PATH_REGEX = """^(?:(?:(?:[^:/?#]+):)?(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]+\]|[^/#:?]*))(?::[0-9]+)?))?([^?#]*)(?:\?[^#]*)?(?:#[a-zA-Z0-9\-_.!~*'();/?:@&=+$,[\]%]*)?)$""" - private val 
QUERY_REGEX = """^(?:(?:(?:[^:/?#]+):)?(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]+\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(\?[^#]*)?(?:#[a-zA-Z0-9\-_.!~*'();/?:@&=+$,[\]%]*)?)$""" - private val REF_REGEX = """^(?:(?:(?:[^:/?#]+):)?(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]+\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(#[a-zA-Z0-9\-_.!~*'();/?:@&=+$,[\]%]*)?)$""" - private val PROTOCOL_REGEX = """^(?:(?:([^:/?#]+):)?(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]+\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(?:#[a-zA-Z0-9\-_.!~*'();/?:@&=+$,[\]%]*)?)$""" - private val FILE_REGEX = """^(?:(?:(?:[^:/?#]+):)?(?://(?:(?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]+\]|[^/#:?]*))(?::[0-9]+)?))?((?:[^?#]*)(?:\?[^#]*)?)(?:#[a-zA-Z0-9\-_.!~*'();/?:@&=+$,[\]%]*)?)$""" - private val AUTHORITY_REGEX = """^(?:(?:(?:[^:/?#]+):)?(?://((?:(?:(?:[^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]+\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(?:#[a-zA-Z0-9\-_.!~*'();/?:@&=+$,[\]%]*)?)$""" - private val USERINFO_REGEX = """^(?:(?:(?:[^:/?#]+):)?(?://(?:(?:(?:([^:]*:?[^\@]*)@)?(?:\[[0-9A-Za-z%.:]+\]|[^/#:?]*))(?::[0-9]+)?))?(?:[^?#]*)(?:\?[^#]*)?(?:#[a-zA-Z0-9\-_.!~*'();/?:@&=+$,[\]%]*)?)$""" - // HostName parsing followed rules in java URI lib: - // hostname = domainlabel [ "." ] | 1*( domainlabel "." ) toplabel [ "." ] - // domainlabel = alphanum | alphanum *( alphanum | "-" ) alphanum - // toplabel = alpha | alpha *( alphanum | "-" ) alphanum - val hostnameRegex = """((([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])|(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)+([a-zA-Z]|[a-zA-Z][a-zA-Z0-9\-]*[a-zA-Z]))\.?)""" - val ipv4Regex = """(((25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])\.){3}(25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9]))""" - val simpleIpv6Regex = """(\[[0-9A-Za-z%.:]+])""" - // regex basically copied from https://stackoverflow.com/questions/53497/regular-expression-that-matches-valid-ipv6-addresses - val ipv6Regex1 = """([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?""" - // 1:2:3:4:5:6:7:8 - val ipv6Regex2 = """([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?:""" - // 1:: 1:2:3:4:5:6:7:: - val ipv6Regex3 = """(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)""" - // 1::8 1:2:3:4:5:6::8 1:2:3:4:5:6::8 - val ipv6Regex4 = 
"""(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)""" - // 1::7:8 1:2:3:4:5::7:8 1:2:3:4:5::8 - val ipv6Regex5 = """(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)""" - // 1::6:7:8 1:2:3:4::6:7:8 1:2:3:4::8 - val ipv6Regex6 = """(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)""" - // 1::5:6:7:8 1:2:3::5:6:7:8 1:2:3::8 - val ipv6Regex7 = """(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)""" - // 1::4:5:6:7:8 1:2::4:5:6:7:8 1:2::8 - val ipv6Regex8 = """([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:((:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?))""" - // 1::3:4:5:6:7:8 1::3:4:5:6:7:8 1::8 - val ipv6Regex9 = """(:((:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?|:))""" - // ::2:3:4:5:6:7:8 ::2:3:4:5:6:7:8 ::8 :: - val ipv6Regex10 = """(fe80:((:([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)(:([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)?(:([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)?(:([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)?)?%[0-9a-zA-Z]+)""" - // fe80::7:8%eth0 fe80::7:8%1 (link-local IPv6 addresses with zone index) - val ipv6Regex11 = """(::((ffff|FFFF)(:00?0?0?)?:)?((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)?((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)?(25[0-5]|(2[0-4]|1?[0-9])?[0-9]))""" - // ::255.255.255.255 ::ffff:255.255.255.255 ::ffff:0:255.255.255.255 (IPv4-mapped IPv6 addresses and IPv4-translated addresses) - val ipv6Regex12 = 
"""([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?:((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)(25[0-5]|(2[0-4]|1?[0-9])?[0-9])""" - // 2001:db8:3:4::192.0.2.33 64:ff9b::192.0.2.33 (IPv4-Embedded IPv6 Address) - val ipv6Regex13 = """(0:0:0:0:0:(0|FFFF|ffff):((25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])\.){3}(25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9]))""" - // 0:0:0:0:0:0:13.1.68.3 - // scalastyle:on } case class GpuParseUrl(children: Seq[Expression], @@ -110,91 +65,6 @@ case class GpuParseUrl(children: Seq[Expression], super[ExpectsInputTypes].checkInputDataTypes() } - private def getPattern(key: UTF8String): RegexProgram = { - val regex = REGEXPREFIX + key.toString + REGEXSUBFIX - new RegexProgram(regex) - } - - private def reValid(url: ColumnVector): ColumnVector = { - // Simply check if urls contain spaces for now, most validations will be done when extracting. - withResource(Scalar.fromString(" ")) { blank => - withResource(url.stringContains(blank)) { isMatch => - withResource(Scalar.fromNull(DType.STRING)) { nullScalar => - isMatch.ifElse(nullScalar, url) - } - } - } - } - - private def reMatch(url: ColumnVector, partToExtract: String): ColumnVector = { - val regex = partToExtract match { - case HOST => HOST_REGEX - case PATH => PATH_REGEX - case QUERY => QUERY_REGEX - case REF => REF_REGEX - case PROTOCOL => PROTOCOL_REGEX - case FILE => FILE_REGEX - case AUTHORITY => AUTHORITY_REGEX - case USERINFO => USERINFO_REGEX - case _ => throw new IllegalArgumentException(s"Invalid partToExtract: $partToExtract") - } - val prog = new RegexProgram(regex) - withResource(url.extractRe(prog)) { table: Table => - table.getColumn(0).incRefCount() - } - } - - private def emptyToNulls(cv: ColumnVector): ColumnVector = { - withResource(ColumnVector.fromStrings("")) { empty => - withResource(ColumnVector.fromStrings(null)) { nulls => - cv.findAndReplaceAll(empty, nulls) - } - } - } - - private def unsetInvalidHost(cv: ColumnVector): ColumnVector = { - val regex = "^(" + hostnameRegex + "|" + ipv4Regex + "|" + simpleIpv6Regex + ")$" - val prog = new RegexProgram(regex) - val HostnameIpv4Res = withResource(cv.matchesRe(prog)) { isMatch => - withResource(Scalar.fromNull(DType.STRING)) { nullScalar => - isMatch.ifElse(cv, nullScalar) - } - } - // match the simple ipv6 address, valid ipv6 only when necessary cause the regex is very long - val simpleIpv6Prog = new RegexProgram(simpleIpv6Regex) - withResource(cv.matchesRe(simpleIpv6Prog)) { isMatch => - val anyIpv6 = withResource(isMatch.any()) { a => - a.isValid && a.getBoolean - } - if (anyIpv6) { - withResource(HostnameIpv4Res) { _ => - unsetInvalidIpv6Host(HostnameIpv4Res, isMatch) - } - } else { - HostnameIpv4Res - } - } - } - - private def unsetInvalidIpv6Host(cv: ColumnVector, simpleMatched: ColumnVector): ColumnVector = { - val regex = """^\[(""" + ipv6Regex1 + "|" + ipv6Regex2 + "|" + ipv6Regex3 + "|" + ipv6Regex4 + "|" + - ipv6Regex5 + "|" + ipv6Regex6 + "|" + ipv6Regex7 + "|" + ipv6Regex8 + "|" + ipv6Regex9 + "|" + - ipv6Regex10 + "|" + ipv6Regex11 + "|" + ipv6Regex12 + "|" + ipv6Regex13 + """)(%[a-zA-Z0-9]*)?]$""" - - val prog = new RegexProgram(regex) - - val invalidIpv6 = withResource(cv.matchesRe(prog)) { matched => - withResource(matched.not()) { invalid => - simpleMatched.and(invalid) - } - } - 
withResource(invalidIpv6) { _ => - withResource(Scalar.fromNull(DType.STRING)) { nullScalar => - invalidIpv6.ifElse(nullScalar, cv) - } - } - } - def doColumnar(numRows: Int, url: GpuScalar, partToExtract: GpuScalar): ColumnVector = { withResource(GpuColumnVector.from(url, numRows, StringType)) { urlCol => doColumnar(urlCol, partToExtract) @@ -203,30 +73,13 @@ case class GpuParseUrl(children: Seq[Expression], def doColumnar(url: GpuColumnVector, partToExtract: GpuScalar): ColumnVector = { val part = partToExtract.getValue.asInstanceOf[UTF8String].toString - val valid = reValid(url.getBase) - val matched = withResource(valid) { _ => - reMatch(valid, part) - } - if (part == HOST) { - val valided = withResource(matched) { _ => - unsetInvalidHost(matched) - } - withResource(valided) { _ => - emptyToNulls(valided) - } - } else if (part == QUERY || part == REF) { - val resWithNulls = withResource(matched) { _ => - emptyToNulls(matched) - } - withResource(resWithNulls) { _ => - resWithNulls.substring(1) - } - } else if (part == PATH || part == FILE) { - matched - } else { - withResource(matched) { _ => - emptyToNulls(matched) - } + part match { + case PROTOCOL => + ParseURI.parseURIProtocol(url.getBase) + case HOST | PATH | QUERY | REF | FILE | AUTHORITY | USERINFO => + throw new UnsupportedOperationException(s"$this is not supported partToExtract=$part") + case _ => + throw new IllegalArgumentException(s"Invalid partToExtract: $partToExtract") } } @@ -236,18 +89,7 @@ case class GpuParseUrl(children: Seq[Expression], // return a null columnvector return ColumnVector.fromStrings(null, null) } - val querys = withResource(reMatch(url.getBase, QUERY)) { matched => - matched.substring(1) - } - val keyStr = key.getValue.asInstanceOf[UTF8String] - val queryValue = withResource(querys) { _ => - withResource(querys.extractRe(getPattern(keyStr))) { table: Table => - table.getColumn(2).incRefCount() - } - } - withResource(queryValue) { _ => - emptyToNulls(queryValue) - } + throw new UnsupportedOperationException(s"$this only supports partToExtract = PROTOCOL") } override def columnarEval(batch: ColumnarBatch): GpuColumnVector = { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala index 4a7ee5ea838..e2de2fdc482 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala @@ -29,29 +29,29 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { "http://foo.com/blah_blah_(wikipedia)_(again)", "http://www.example.com/wpstyle/?p=364", "https://www.example.com/foo/?bar=baz&inga=42&quux", - "http://✪df.ws/123", + // "http://✪df.ws/123", "http://userid:password@example.com:8080", "http://userid:password@example.com:8080/", "http://userid:password@example.com", "http://userid:password@example.com/", "http://142.42.1.1/", "http://142.42.1.1:8080/", - "http://➡.ws/䨹", - "http://⌘.ws", - "http://⌘.ws/", + // "http://➡.ws/䨹", + // "http://⌘.ws", + // "http://⌘.ws/", "http://foo.com/blah_(wikipedia)#cite-1", "http://foo.com/blah_(wikipedia)_blah#cite-1", - "http://foo.com/unicode_(✪)_in_parens", + // "http://foo.com/unicode_(✪)_in_parens", "http://foo.com/(something)?after=parens", - "http://☺.damowmow.com/", + // "http://☺.damowmow.com/", "http://code.google.com/events/#&product=browser", "http://j.mp", "ftp://foo.bar/baz", "http://foo.bar/?q=Test%20URL-encoded%20stuff", - "http://مثال.إختبار", - "http://例子.测试", - 
"http://उदाहरण.परीक्षा", - "http://-.~_!$&'()*+,;=:%40:80%2f::::::@example.com", + // "http://مثال.إختبار", + // "http://例子.测试", + // "http://उदाहरण.परीक्षा", + // "http://-.~_!$&'()*+,;=:%40:80%2f::::::@example.com", "http://1337.net", "http://a.b-c.de", "http://223.255.255.254", @@ -64,10 +64,10 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { "http://??", "http://??/", "http://#", - "http://##", - "http://##/", + // "http://##", + // "http://##/", "http://foo.bar?q=Spaces should be encoded", - // "//", + "//", "//a", "///a", "///", @@ -117,120 +117,61 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { ).toDF("urls") } - def urlCasesFromJavaUriLib(session: SparkSession): DataFrame = { - import session.sqlContext.implicits._ - Seq[String]( - "ftp://ftp.is.co.za/rfc/rfc1808.txt", - "http://www.math.uio.no/faq/compression-faq/part1.html", - "telnet://melvyl.ucop.edu/", - "http://www.w3.org/Addressing/", - "ftp://ds.internic.net/rfc/", - "http://www.ics.uci.edu/pub/ietf/uri/historical.html#WARNING", - "http://www.ics.uci.edu/pub/ietf/uri/#Related", - "http://[FEDC:BA98:7654:3210:FEDC:BA98:7654:3210]:80/index.html", - "http://[FEDC:BA98:7654:3210:FEDC:BA98:7654:10%12]:80/index.html", - "http://[1080:0:0:0:8:800:200C:417A]/index.html", - "http://[1080:0:0:0:8:800:200C:417A%1]/index.html", - "http://[3ffe:2a00:100:7031::1]", - "http://[1080::8:800:200C:417A]/foo", - "http://[::192.9.5.5]/ipng", - "http://[::192.9.5.5%interface]/ipng", - "http://[::FFFF:129.144.52.38]:80/index.html", - "http://[2010:836B:4179::836B:4179]", - "http://[FF01::101]", - "http://[::1]", - "http://[::]", - "http://[::%hme0]", - "http://[0:0:0:0:0:0:13.1.68.3]", - "http://[0:0:0:0:0:FFFF:129.144.52.38]", - "http://[0:0:0:0:0:FFFF:129.144.52.38%33]", - "http://[0:0:0:0:0:ffff:1.2.3.4]", - "http://[::13.1.68.3]" - ).toDF("urls") - } + // def urlIpv6Host(session: SparkSession): DataFrame = { + // import session.sqlContext.implicits._ + // Seq[String]( + // "http://[1:2:3:4:5:6:7:8]", + // "http://[1::]", + // "http://[1:2:3:4:5:6:7::]", + // "http://[1::8]", + // "http://[1:2:3:4:5:6::8]", + // "http://[1:2:3:4:5:6::8]", + // "http://[1::7:8]", + // "http://[1:2:3:4:5::7:8]", + // "http://[1:2:3:4:5::8]", + // "http://[1::6:7:8]", + // "http://[1:2:3:4::6:7:8]", + // "http://[1:2:3:4::8]", + // "http://[1::5:6:7:8]", + // "http://[1:2:3::5:6:7:8]", + // "http://[1:2:3::8]", + // "http://[1::4:5:6:7:8]", + // "http://[1:2::4:5:6:7:8]", + // "http://[1:2::8]", + // "http://[1::3:4:5:6:7:8]", + // "http://[1::3:4:5:6:7:8]", + // "http://[1::8]", + // "http://[::2:3:4:5:6:7:8]", + // "http://[::2:3:4:5:6:7:8]", + // "http://[::8]", + // "http://[::]", + // "http://[fe80::7:8%eth0]", + // "http://[fe80::7:8%1]", + // "http://[::255.255.255.255]", + // "http://[::ffff:255.255.255.255]", + // "http://[::ffff:0:255.255.255.255]", + // "http://[2001:db8:3:4::192.0.2.33]", + // "http://[64:ff9b::192.0.2.33]" + // ).toDF("urls") + // } - def urlWithQueryKey(session: SparkSession): DataFrame = { - import session.sqlContext.implicits._ - Seq[String]( - "http://foo.com/blah_blah?foo=bar&baz=blah#vertical-bar" - ).toDF("urls") - } - - def urlIpv6Host(session: SparkSession): DataFrame = { - import session.sqlContext.implicits._ - Seq[String]( - "http://[1:2:3:4:5:6:7:8]", - "http://[1::]", - "http://[1:2:3:4:5:6:7::]", - "http://[1::8]", - "http://[1:2:3:4:5:6::8]", - "http://[1:2:3:4:5:6::8]", - "http://[1::7:8]", - "http://[1:2:3:4:5::7:8]", - "http://[1:2:3:4:5::8]", - "http://[1::6:7:8]", - 
"http://[1:2:3:4::6:7:8]", - "http://[1:2:3:4::8]", - "http://[1::5:6:7:8]", - "http://[1:2:3::5:6:7:8]", - "http://[1:2:3::8]", - "http://[1::4:5:6:7:8]", - "http://[1:2::4:5:6:7:8]", - "http://[1:2::8]", - "http://[1::3:4:5:6:7:8]", - "http://[1::3:4:5:6:7:8]", - "http://[1::8]", - "http://[::2:3:4:5:6:7:8]", - "http://[::2:3:4:5:6:7:8]", - "http://[::8]", - "http://[::]", - "http://[fe80::7:8%eth0]", - "http://[fe80::7:8%1]", - "http://[::255.255.255.255]", - "http://[::ffff:255.255.255.255]", - "http://[::ffff:0:255.255.255.255]", - "http://[2001:db8:3:4::192.0.2.33]", - "http://[64:ff9b::192.0.2.33]" - ).toDF("urls") - } - - def parseUrls(frame: DataFrame): DataFrame = { - frame.selectExpr( - "urls", - "parse_url(urls, 'HOST') as HOST", - "parse_url(urls, 'PATH') as PATH", - "parse_url(urls, 'QUERY') as QUERY", - "parse_url(urls, 'REF') as REF", - "parse_url(urls, 'PROTOCOL') as PROTOCOL", - "parse_url(urls, 'FILE') as FILE", - "parse_url(urls, 'AUTHORITY') as AUTHORITY", - "parse_url(urls, 'USERINFO') as USERINFO") + def parseUrlProtocol(frame: DataFrame): DataFrame = { + frame.selectExpr("urls", "parse_url(urls, 'PROTOCOL')") } testSparkResultsAreEqual("Test parse_url edge cases from internet", validUrlEdgeCasesDf) { - parseUrls + parseUrlProtocol } testSparkResultsAreEqual("Test parse_url cases from Spark", urlCasesFromSpark) { - parseUrls + parseUrlProtocol } testSparkResultsAreEqual("Test parse_url invalid cases from Spark", urlCasesFromSparkInvalid) { - parseUrls + parseUrlProtocol } - testSparkResultsAreEqual("Test parse_url cases from java URI library", urlCasesFromJavaUriLib) { - parseUrls - } - - testSparkResultsAreEqual("Test parse_url ipv6 host", urlIpv6Host) { - parseUrls - } - - testSparkResultsAreEqual("Test parse_url with query and key", urlWithQueryKey) { - frame => frame.selectExpr( - "urls", - "parse_url(urls, 'QUERY', 'foo') as QUERY", - "parse_url(urls, 'QUERY', 'baz') as QUERY") - } + // testSparkResultsAreEqual("Test parse_url ipv6 host", urlIpv6Host) { + // parseUrlProtocol + // } } \ No newline at end of file From 592c642de47f14b41224dba0d2d4f02b67bb70eb Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 19 Oct 2023 14:21:23 +0800 Subject: [PATCH 16/27] verify Signed-off-by: Haoyang Li --- docs/supported_ops.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 254fe8ef876..4446a0c7869 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -10770,7 +10770,7 @@ are limited. -PS
Literal value only
+PS
only support partToExtract=PROTOCOL;
Literal value only



From 9db1b2a80d1cf6b4ea24b7d4e3df2eeb02f28269 Mon Sep 17 00:00:00 2001
From: Haoyang Li
Date: Thu, 19 Oct 2023 15:03:36 +0800
Subject: [PATCH 17/27] edit compatibility and update IT

Signed-off-by: Haoyang Li

---
 docs/compatibility.md                         | 19 ++++++++-----------
 integration_tests/src/main/python/url_test.py |  7 +++++--
 2 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/docs/compatibility.md b/docs/compatibility.md
index fd7825be5f2..343a0281577 100644
--- a/docs/compatibility.md
+++ b/docs/compatibility.md
@@ -453,17 +453,14 @@ will only be GPU-accelerated if the time zone used by the JVM is UTC.
 
 ## URL parsing
 
-In Spark, parse_url is based on java's URI library, while the implementation in the RAPIDS Accelerator is based on regex extraction. Therefore, the results may be different in some edge cases.
-
-These are the known cases where running on the GPU will produce different results to the CPU:
-
-- Spark allow an empty authority component only when it's followed by a non-empty path,
-  query component, or fragment component. But in plugin, parse_url just simply allow empty
-  authority component without checking if it is followed something or not. So `parse_url('http://', 'HOST')` will
-  return `null` in Spark, but return `""` in plugin.
-- If input url has a invalid Ipv6 address, Spark will return `null` for all components, but plugin will parse other
-  components except `HOST` as normal. So `http://userinfo@[1:2:3:4:5:6:7:8:9:10]/path?query=1#Ref`'s result will be
-  `[null,/path,query=1,Ref,http,/path?query=1,userinfo@[1:2:3:4:5:6:7:8:9:10],userinfo]`
+`parse_url` can produce different results on the GPU compared to the CPU.
+
+Known issues for PROTOCOL parsing:
+- If a URL contains UTF-8 special characters, the PROTOCOL result on the GPU will be null.
+- If a URL contains an IPv6 host, the GPU will return null for PROTOCOL.
+- For some invalid edge cases the GPU will still try to parse the PROTOCOL instead of returning
+  null, for example URLs that contain multiple '#' characters in the REF (http://##) or an empty
+  authority component followed by an empty path (http://). 
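+
+The following is only a sketch of the first two notes; the expected results simply restate
+the documented behavior, and both URLs appear in the integration tests in this patch:
+
+```sql
+SELECT parse_url('http://✪df.ws/123', 'PROTOCOL');
+-- CPU: http, GPU: NULL (UTF-8 special character in the URL)
+SELECT parse_url('http://[1::8]', 'PROTOCOL');
+-- CPU: http, GPU: NULL (IPv6 host)
+```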
## Windowing diff --git a/integration_tests/src/main/python/url_test.py b/integration_tests/src/main/python/url_test.py index 32a838d7121..34ce24bf2cc 100644 --- a/integration_tests/src/main/python/url_test.py +++ b/integration_tests/src/main/python/url_test.py @@ -19,7 +19,7 @@ from marks import * from pyspark.sql.types import * import pyspark.sql.functions as f -from spark_session import is_before_spark_320 +from spark_session import is_before_spark_340 # regex to generate limit length urls with HOST, PATH, QUERY, REF, PROTOCOL, FILE, AUTHORITY, USERINFO url_pattern = r'((http|https|ftp)://)(([a-zA-Z][a-zA-Z0-9]{0,2}\.){0,3}([a-zA-Z][a-zA-Z0-9]{0,2})\.([a-zA-Z][a-zA-Z0-9]{0,2}))' \ @@ -51,8 +51,11 @@ def test_parse_url_with_no_query_key(): )) def test_parse_url_too_many_args(): + error_message = 'parse_url function requires two or three arguments' \ + if is_before_spark_340() else \ + '[WRONG_NUM_ARGS.WITHOUT_SUGGESTION] The `parse_url` requires [2, 3] parameters' assert_gpu_and_cpu_error( lambda spark : unary_op_df(spark, StringGen()).selectExpr( "a","parse_url(a, 'USERINFO', 'key', 'value')").collect(), conf={}, - error_message='parse_url function requires two or three arguments') + error_message=error_message) From d09f06df4b9029f4983b2442a1f2c4c82d9e18c9 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 20 Oct 2023 14:17:06 +0800 Subject: [PATCH 18/27] update integration tests Signed-off-by: Haoyang Li --- docs/supported_ops.md | 2 +- integration_tests/src/main/python/url_test.py | 197 ++++++++++++++++-- .../nvidia/spark/rapids/GpuOverrides.scala | 10 +- .../spark/rapids/UrlFunctionsSuite.scala | 177 ---------------- 4 files changed, 194 insertions(+), 192 deletions(-) delete mode 100644 tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 4446a0c7869..254fe8ef876 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -10770,7 +10770,7 @@ are limited. -PS
only support partToExtract=PROTOCOL;
Literal value only
+PS
Literal value only




diff --git a/integration_tests/src/main/python/url_test.py b/integration_tests/src/main/python/url_test.py
index 34ce24bf2cc..6bb427a4c30 100644
--- a/integration_tests/src/main/python/url_test.py
+++ b/integration_tests/src/main/python/url_test.py
@@ -14,7 +14,7 @@
 
 import pytest
 
-from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error
+from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error, assert_gpu_fallback_collect
 from data_gen import *
 from marks import *
 from pyspark.sql.types import *
@@ -28,6 +28,122 @@
 url_pattern_with_key = r'((http|https|ftp|file)://)(([a-z]{1,3}\.){0,3}([a-z]{1,3})\.([a-z]{1,3}))' \
     r'(:[0-9]{1,3}){0,1}(/[a-z]{1,3}){0,3}(\?key=[a-z]{1,3}){0,1}(#([a-z]{1,3})){0,1}'
 
+edge_cases = [
+    "http://foo.com/blah_blah",
+    "http://foo.com/blah_blah/",
+    "http://foo.com/blah_blah_(wikipedia)",
+    "http://foo.com/blah_blah_(wikipedia)_(again)",
+    "http://www.example.com/wpstyle/?p=364",
+    "https://www.example.com/foo/?bar=baz&inga=42&quux",
+    # "http://✪df.ws/123",
+    "http://userid:password@example.com:8080",
+    "http://userid:password@example.com:8080/",
+    "http://userid:password@example.com",
+    "http://userid:password@example.com/",
+    "http://142.42.1.1/",
+    "http://142.42.1.1:8080/",
+    # "http://➡.ws/䨹",
+    # "http://⌘.ws",
+    # "http://⌘.ws/",
+    "http://foo.com/blah_(wikipedia)#cite-1",
+    "http://foo.com/blah_(wikipedia)_blah#cite-1",
+    # "http://foo.com/unicode_(✪)_in_parens",
+    "http://foo.com/(something)?after=parens",
+    # "http://☺.damowmow.com/",
+    "http://code.google.com/events/#&product=browser",
+    "http://j.mp",
+    "ftp://foo.bar/baz",
+    r"http://foo.bar/?q=Test%20URL-encoded%20stuff",
+    # "http://مثال.إختبار",
+    # "http://例子.测试",
+    # "http://उदाहरण.परीक्षा",
+    # "http://-.~_!$&'()*+,;=:%40:80%2f::::::@example.com",
+    "http://1337.net",
+    "http://a.b-c.de",
+    "http://223.255.255.254",
+    "https://foo_bar.example.com/",
+    # "http://",
+    "http://.",
+    "http://..",
+    "http://../",
+    "http://?",
+    "http://??",
+    "http://??/",
+    "http://#",
+    # "http://##",
+    # "http://##/",
+    "http://foo.bar?q=Spaces should be encoded",
+    "//",
+    "//a",
+    "///a",
+    "///",
+    "http:///a",
+    "foo.com",
+    "rdar://1234",
+    "h://test",
+    "http:// shouldfail.com",
+    ":// should fail",
+    "http://foo.bar/foo(bar)baz quux",
+    "ftps://foo.bar/",
+    "http://-error-.invalid/",
+    "http://a.b--c.de/",
+    "http://-a.b.co",
+    "http://a.b-.co",
+    "http://0.0.0.0",
+    "http://10.1.1.0",
+    "http://10.1.1.255",
+    "http://224.1.1.1",
+    "http://1.1.1.1.1",
+    "http://123.123.123",
+    "http://3628126748",
+    "http://.www.foo.bar/",
+    "http://www.foo.bar./",
+    "http://.www.foo.bar./",
+    "http://10.1.1.1",
+    "http://10.1.1.254",
+    "http://userinfo@spark.apache.org/path?query=1#Ref",
+    r"https://use%20r:pas%20s@example.com/dir%20/pa%20th.HTML?query=x%20y&q2=2#Ref%20two",
+    "http://user:pass@host",
+    "http://user:pass@host/",
+    "http://user:pass@host/?#",
+    "http://user:pass@host/file;param?query;p2",
+    "inva lid://user:pass@host/file;param?query;p2",
+    # "http://[1:2:3:4:5:6:7:8]",
+    # "http://[1::]",
+    # "http://[1:2:3:4:5:6:7::]",
+    # "http://[1::8]",
+    # "http://[1:2:3:4:5:6::8]",
+    # "http://[1:2:3:4:5:6::8]",
+    # "http://[1::7:8]",
+    # "http://[1:2:3:4:5::7:8]",
+    # "http://[1:2:3:4:5::8]",
+    # "http://[1::6:7:8]",
+    # "http://[1:2:3:4::6:7:8]",
+    # "http://[1:2:3:4::8]",
+    # "http://[1::5:6:7:8]",
+    # "http://[1:2:3::5:6:7:8]",
+    # "http://[1:2:3::8]",
+    # "http://[1::4:5:6:7:8]",
+    # "http://[1:2::4:5:6:7:8]",
+    # "http://[1:2::8]",
+    # 
"http://[1::3:4:5:6:7:8]", + # "http://[1::3:4:5:6:7:8]", + # "http://[1::8]", + # "http://[::2:3:4:5:6:7:8]", + # "http://[::2:3:4:5:6:7:8]", + # "http://[::8]", + # "http://[::]", + # "http://[fe80::7:8%eth0]", + # "http://[fe80::7:8%1]", + # "http://[::255.255.255.255]", + # "http://[::ffff:255.255.255.255]", + # "http://[::ffff:0:255.255.255.255]", + # "http://[2001:db8:3:4::192.0.2.33]", + # "http://[64:ff9b::192.0.2.33]" +] + +edge_cases_gen = SetValuesGen(StringType(), edge_cases) + url_gen = StringGen(url_pattern) def test_parse_url_protocol(): @@ -37,18 +153,75 @@ def test_parse_url_protocol(): "parse_url(a, 'PROTOCOL')" )) -def test_parse_url_with_no_query_key(): +def test_parse_url_protocol_edge_cases(): assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, url_gen, length=100).selectExpr( - "a", - "parse_url(a, 'HOST', '')", - "parse_url(a, 'PATH', '')", - "parse_url(a, 'REF', '')", - "parse_url(a, 'PROTOCOL', '')", - "parse_url(a, 'FILE', '')", - "parse_url(a, 'AUTHORITY', '')", - "parse_url(a, 'USERINFO', '')" + lambda spark : unary_op_df(spark, edge_cases_gen).selectExpr( + "a", + "parse_url(a, 'PROTOCOL')" )) + +@allow_non_gpu('ProjectExec', 'ParseUrl') +def test_parse_url_host_fallback(): + assert_gpu_fallback_collect( + lambda spark : unary_op_df(spark, url_gen).selectExpr( + "a", + "parse_url(a, 'HOST')" + ), + 'ParseUrl') + +@allow_non_gpu('ProjectExec', 'ParseUrl') +def test_parse_url_path_fallback(): + assert_gpu_fallback_collect( + lambda spark : unary_op_df(spark, url_gen).selectExpr( + "a", + "parse_url(a, 'PATH')" + ), + 'ParseUrl') + +@allow_non_gpu('ProjectExec', 'ParseUrl') +def test_parse_url_query_fallback(): + assert_gpu_fallback_collect( + lambda spark : unary_op_df(spark, url_gen).selectExpr( + "a", + "parse_url(a, 'QUERY')" + ), + 'ParseUrl') + +@allow_non_gpu('ProjectExec', 'ParseUrl') +def test_parse_url_ref_fallback(): + assert_gpu_fallback_collect( + lambda spark : unary_op_df(spark, url_gen).selectExpr( + "a", + "parse_url(a, 'REF')" + ), + 'ParseUrl') + +@allow_non_gpu('ProjectExec', 'ParseUrl') +def test_parse_url_file_fallback(): + assert_gpu_fallback_collect( + lambda spark : unary_op_df(spark, url_gen).selectExpr( + "a", + "parse_url(a, 'FILE')" + ), + 'ParseUrl') + +@allow_non_gpu('ProjectExec', 'ParseUrl') +def test_parse_url_authority_fallback(): + assert_gpu_fallback_collect( + lambda spark : unary_op_df(spark, url_gen).selectExpr( + "a", + "parse_url(a, 'AUTHORITY')" + ), + 'ParseUrl') + +@allow_non_gpu('ProjectExec', 'ParseUrl') +def test_parse_url_userinfo_fallback(): + assert_gpu_fallback_collect( + lambda spark : unary_op_df(spark, url_gen).selectExpr( + "a", + "parse_url(a, 'USERINFO')" + ), + 'ParseUrl') def test_parse_url_too_many_args(): error_message = 'parse_url function requires two or three arguments' \ @@ -56,6 +229,6 @@ def test_parse_url_too_many_args(): '[WRONG_NUM_ARGS.WITHOUT_SUGGESTION] The `parse_url` requires [2, 3] parameters' assert_gpu_and_cpu_error( lambda spark : unary_op_df(spark, StringGen()).selectExpr( - "a","parse_url(a, 'USERINFO', 'key', 'value')").collect(), + "a","parse_url(a, 'PROTOCOL', 'key', 'value')").collect(), conf={}, error_message=error_message) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 2bc555f42d5..6c6080920fb 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ 
b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -3231,8 +3231,7 @@ object GpuOverrides extends Logging { "Extracts a part from a URL", ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, Seq(ParamCheck("url", TypeSig.STRING, TypeSig.STRING), - ParamCheck("partToExtract", TypeSig.lit(TypeEnum.STRING).withPsNote( - TypeEnum.STRING, "only support partToExtract=PROTOCOL"), TypeSig.STRING)), + ParamCheck("partToExtract", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING)), // Should really be an OptionalParam Some(RepeatingParamCheck("key", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), (a, conf, p, r) => new ExprMeta[ParseUrl](a, conf, p, r) { @@ -3242,6 +3241,13 @@ object GpuOverrides extends Logging { if (failOnError) { willNotWorkOnGpu("Fail on error is not supported on GPU when parsing urls.") } + val partToExtract = childExprs(1).convertToGpu() + .asInstanceOf[GpuLiteral].value.asInstanceOf[UTF8String].toString + partToExtract.toUpperCase match { + case "PROTOCOL" => + case _ => + willNotWorkOnGpu(s"Part to extract $partToExtract is not supported on GPU") + } } override def convertToGpu(): GpuExpression = { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala deleted file mode 100644 index e2de2fdc482..00000000000 --- a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala +++ /dev/null @@ -1,177 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.nvidia.spark.rapids - -import org.apache.spark.sql.{DataFrame, SparkSession} - -class UrlFunctionsSuite extends SparkQueryCompareTestSuite { - def validUrlEdgeCasesDf(session: SparkSession): DataFrame = { - import session.sqlContext.implicits._ - // [In search of the perfect URL validation regex](https://mathiasbynens.be/demo/url-regex) - Seq[String]( - "http://foo.com/blah_blah", - "http://foo.com/blah_blah/", - "http://foo.com/blah_blah_(wikipedia)", - "http://foo.com/blah_blah_(wikipedia)_(again)", - "http://www.example.com/wpstyle/?p=364", - "https://www.example.com/foo/?bar=baz&inga=42&quux", - // "http://✪df.ws/123", - "http://userid:password@example.com:8080", - "http://userid:password@example.com:8080/", - "http://userid:password@example.com", - "http://userid:password@example.com/", - "http://142.42.1.1/", - "http://142.42.1.1:8080/", - // "http://➡.ws/䨹", - // "http://⌘.ws", - // "http://⌘.ws/", - "http://foo.com/blah_(wikipedia)#cite-1", - "http://foo.com/blah_(wikipedia)_blah#cite-1", - // "http://foo.com/unicode_(✪)_in_parens", - "http://foo.com/(something)?after=parens", - // "http://☺.damowmow.com/", - "http://code.google.com/events/#&product=browser", - "http://j.mp", - "ftp://foo.bar/baz", - "http://foo.bar/?q=Test%20URL-encoded%20stuff", - // "http://مثال.إختبار", - // "http://例子.测试", - // "http://उदाहरण.परीक्षा", - // "http://-.~_!$&'()*+,;=:%40:80%2f::::::@example.com", - "http://1337.net", - "http://a.b-c.de", - "http://223.255.255.254", - "https://foo_bar.example.com/", - // "http://", - "http://.", - "http://..", - "http://../", - "http://?", - "http://??", - "http://??/", - "http://#", - // "http://##", - // "http://##/", - "http://foo.bar?q=Spaces should be encoded", - "//", - "//a", - "///a", - "///", - "http:///a", - "foo.com", - "rdar://1234", - "h://test", - "http:// shouldfail.com", - ":// should fail", - "http://foo.bar/foo(bar)baz quux", - "ftps://foo.bar/", - "http://-error-.invalid/", - "http://a.b--c.de/", - "http://-a.b.co", - "http://a.b-.co", - "http://0.0.0.0", - "http://10.1.1.0", - "http://10.1.1.255", - "http://224.1.1.1", - "http://1.1.1.1.1", - "http://123.123.123", - "http://3628126748", - "http://.www.foo.bar/", - "http://www.foo.bar./", - "http://.www.foo.bar./", - "http://10.1.1.1", - "http://10.1.1.254" - ).toDF("urls") - } - - def urlCasesFromSpark(session: SparkSession): DataFrame = { - import session.sqlContext.implicits._ - Seq[String]( - "http://userinfo@spark.apache.org/path?query=1#Ref", - "https://use%20r:pas%20s@example.com/dir%20/pa%20th.HTML?query=x%20y&q2=2#Ref%20two", - "http://user:pass@host", - "http://user:pass@host/", - "http://user:pass@host/?#", - "http://user:pass@host/file;param?query;p2" - ).toDF("urls") - } - - def urlCasesFromSparkInvalid(session: SparkSession): DataFrame = { - import session.sqlContext.implicits._ - Seq[String]( - "inva lid://user:pass@host/file;param?query;p2" - ).toDF("urls") - } - - // def urlIpv6Host(session: SparkSession): DataFrame = { - // import session.sqlContext.implicits._ - // Seq[String]( - // "http://[1:2:3:4:5:6:7:8]", - // "http://[1::]", - // "http://[1:2:3:4:5:6:7::]", - // "http://[1::8]", - // "http://[1:2:3:4:5:6::8]", - // "http://[1:2:3:4:5:6::8]", - // "http://[1::7:8]", - // "http://[1:2:3:4:5::7:8]", - // "http://[1:2:3:4:5::8]", - // "http://[1::6:7:8]", - // "http://[1:2:3:4::6:7:8]", - // "http://[1:2:3:4::8]", - // "http://[1::5:6:7:8]", - // "http://[1:2:3::5:6:7:8]", - // "http://[1:2:3::8]", - // "http://[1::4:5:6:7:8]", - // 
"http://[1:2::4:5:6:7:8]", - // "http://[1:2::8]", - // "http://[1::3:4:5:6:7:8]", - // "http://[1::3:4:5:6:7:8]", - // "http://[1::8]", - // "http://[::2:3:4:5:6:7:8]", - // "http://[::2:3:4:5:6:7:8]", - // "http://[::8]", - // "http://[::]", - // "http://[fe80::7:8%eth0]", - // "http://[fe80::7:8%1]", - // "http://[::255.255.255.255]", - // "http://[::ffff:255.255.255.255]", - // "http://[::ffff:0:255.255.255.255]", - // "http://[2001:db8:3:4::192.0.2.33]", - // "http://[64:ff9b::192.0.2.33]" - // ).toDF("urls") - // } - - def parseUrlProtocol(frame: DataFrame): DataFrame = { - frame.selectExpr("urls", "parse_url(urls, 'PROTOCOL')") - } - - testSparkResultsAreEqual("Test parse_url edge cases from internet", validUrlEdgeCasesDf) { - parseUrlProtocol - } - - testSparkResultsAreEqual("Test parse_url cases from Spark", urlCasesFromSpark) { - parseUrlProtocol - } - - testSparkResultsAreEqual("Test parse_url invalid cases from Spark", urlCasesFromSparkInvalid) { - parseUrlProtocol - } - - // testSparkResultsAreEqual("Test parse_url ipv6 host", urlIpv6Host) { - // parseUrlProtocol - // } -} \ No newline at end of file From 3b71c4d9d852a6ef6d6259dd950fcfa4c5f6e27d Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 24 Oct 2023 13:55:30 +0800 Subject: [PATCH 19/27] address comments Signed-off-by: Haoyang Li --- integration_tests/src/main/python/url_test.py | 71 ++----------------- .../nvidia/spark/rapids/GpuOverrides.scala | 23 +++--- .../spark/sql/rapids/urlFunctions.scala | 57 +++++++-------- 3 files changed, 43 insertions(+), 108 deletions(-) diff --git a/integration_tests/src/main/python/url_test.py b/integration_tests/src/main/python/url_test.py index 6bb427a4c30..68e3ecd376a 100644 --- a/integration_tests/src/main/python/url_test.py +++ b/integration_tests/src/main/python/url_test.py @@ -146,80 +146,23 @@ url_gen = StringGen(url_pattern) -def test_parse_url_protocol(): +@pytest.mark.parametrize('data_gen', [url_gen, edge_cases_gen], ids=idfn) +def test_parse_url_protocol(data_gen): assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, url_gen).selectExpr( - "a", - "parse_url(a, 'PROTOCOL')" - )) - -def test_parse_url_protocol_edge_cases(): - assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, edge_cases_gen).selectExpr( + lambda spark : unary_op_df(spark, data_gen).selectExpr( "a", "parse_url(a, 'PROTOCOL')" )) - -@allow_non_gpu('ProjectExec', 'ParseUrl') -def test_parse_url_host_fallback(): - assert_gpu_fallback_collect( - lambda spark : unary_op_df(spark, url_gen).selectExpr( - "a", - "parse_url(a, 'HOST')" - ), - 'ParseUrl') -@allow_non_gpu('ProjectExec', 'ParseUrl') -def test_parse_url_path_fallback(): - assert_gpu_fallback_collect( - lambda spark : unary_op_df(spark, url_gen).selectExpr( - "a", - "parse_url(a, 'PATH')" - ), - 'ParseUrl') - -@allow_non_gpu('ProjectExec', 'ParseUrl') -def test_parse_url_query_fallback(): - assert_gpu_fallback_collect( - lambda spark : unary_op_df(spark, url_gen).selectExpr( - "a", - "parse_url(a, 'QUERY')" - ), - 'ParseUrl') - -@allow_non_gpu('ProjectExec', 'ParseUrl') -def test_parse_url_ref_fallback(): - assert_gpu_fallback_collect( - lambda spark : unary_op_df(spark, url_gen).selectExpr( - "a", - "parse_url(a, 'REF')" - ), - 'ParseUrl') - -@allow_non_gpu('ProjectExec', 'ParseUrl') -def test_parse_url_file_fallback(): - assert_gpu_fallback_collect( - lambda spark : unary_op_df(spark, url_gen).selectExpr( - "a", - "parse_url(a, 'FILE')" - ), - 'ParseUrl') +unsupported_part = ['HOST', 'PATH', 'QUERY', 'REF', 
'FILE', 'AUTHORITY', 'USERINFO'] @allow_non_gpu('ProjectExec', 'ParseUrl') -def test_parse_url_authority_fallback(): - assert_gpu_fallback_collect( - lambda spark : unary_op_df(spark, url_gen).selectExpr( - "a", - "parse_url(a, 'AUTHORITY')" - ), - 'ParseUrl') - -@allow_non_gpu('ProjectExec', 'ParseUrl') -def test_parse_url_userinfo_fallback(): +@pytest.mark.parametrize('part', unsupported_part, ids=idfn) +def test_parse_url_host_fallback(part): assert_gpu_fallback_collect( lambda spark : unary_op_df(spark, url_gen).selectExpr( "a", - "parse_url(a, 'USERINFO')" + "parse_url(a, '" + part + "')" ), 'ParseUrl') diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 6c6080920fb..86981477e15 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -3231,27 +3231,28 @@ object GpuOverrides extends Logging { "Extracts a part from a URL", ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, Seq(ParamCheck("url", TypeSig.STRING, TypeSig.STRING), - ParamCheck("partToExtract", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING)), + ParamCheck("partToExtract", TypeSig.lit(TypeEnum.STRING).withPsNote( + TypeEnum.STRING, "only support partToExtract=PROTOCOL"), TypeSig.STRING)), // Should really be an OptionalParam Some(RepeatingParamCheck("key", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), (a, conf, p, r) => new ExprMeta[ParseUrl](a, conf, p, r) { - val failOnError = a.failOnError - override def tagExprForGpu(): Unit = { - if (failOnError) { + if (a.failOnError) { willNotWorkOnGpu("Fail on error is not supported on GPU when parsing urls.") } - val partToExtract = childExprs(1).convertToGpu() - .asInstanceOf[GpuLiteral].value.asInstanceOf[UTF8String].toString - partToExtract.toUpperCase match { - case "PROTOCOL" => - case _ => - willNotWorkOnGpu(s"Part to extract $partToExtract is not supported on GPU") + + extractStringLit(childExprs(1).convertToCpu()).map(_.toUpperCase) match { + case Some(GpuParseUrl.PROTOCOL) => + case Some(other) => + willNotWorkOnGpu(s"Part to extract $other is not supported on GPU") + case None => + // Should never get here, but just in case + willNotWorkOnGpu("GPU only supports a literal for the part to extract") } } override def convertToGpu(): GpuExpression = { - GpuParseUrl(childExprs.map(_.convertToGpu()), failOnError) + GpuParseUrl(childExprs.map(_.convertToGpu())) } }), expr[Length]( diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala index a35a294dab3..89f2990a106 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala @@ -25,29 +25,25 @@ import com.nvidia.spark.rapids.shims.ShimExpression import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.rapids.shims.RapidsErrorUtils import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.unsafe.types.UTF8String object GpuParseUrl { - private val HOST = "HOST" - private val PATH = "PATH" - private val QUERY = "QUERY" - private val REF = "REF" - private val PROTOCOL = "PROTOCOL" - private val FILE = "FILE" - private val AUTHORITY = 
"AUTHORITY" - private val USERINFO = "USERINFO" + val HOST = "HOST" + val PATH = "PATH" + val QUERY = "QUERY" + val REF = "REF" + val PROTOCOL = "PROTOCOL" + val FILE = "FILE" + val AUTHORITY = "AUTHORITY" + val USERINFO = "USERINFO" } -case class GpuParseUrl(children: Seq[Expression], - failOnErrorOverride: Boolean = SQLConf.get.ansiEnabled) +case class GpuParseUrl(children: Seq[Expression]) extends GpuExpression with ShimExpression with ExpectsInputTypes { - def this(children: Seq[Expression]) = this(children, SQLConf.get.ansiEnabled) - override def nullable: Boolean = true override def inputTypes: Seq[DataType] = Seq.fill(children.size)(StringType) override def dataType: DataType = StringType @@ -55,7 +51,7 @@ case class GpuParseUrl(children: Seq[Expression], import GpuParseUrl._ - override def checkInputDataTypes(): TypeCheckResult = { + def checkInputDataTypesUseless(): TypeCheckResult = { if (children.size > 3 || children.size < 2) { RapidsErrorUtils.parseUrlWrongNumArgs(children.size) match { case res: Some[TypeCheckResult] => return res.get @@ -65,12 +61,6 @@ case class GpuParseUrl(children: Seq[Expression], super[ExpectsInputTypes].checkInputDataTypes() } - def doColumnar(numRows: Int, url: GpuScalar, partToExtract: GpuScalar): ColumnVector = { - withResource(GpuColumnVector.from(url, numRows, StringType)) { urlCol => - doColumnar(urlCol, partToExtract) - } - } - def doColumnar(url: GpuColumnVector, partToExtract: GpuScalar): ColumnVector = { val part = partToExtract.getValue.asInstanceOf[UTF8String].toString part match { @@ -95,13 +85,14 @@ case class GpuParseUrl(children: Seq[Expression], override def columnarEval(batch: ColumnarBatch): GpuColumnVector = { if (children.size == 2) { val Seq(url, partToExtract) = children - withResourceIfAllowed(url.columnarEvalAny(batch)) { val0 => - withResourceIfAllowed(partToExtract.columnarEvalAny(batch)) { val1 => - (val0, val1) match { - case (v0: GpuColumnVector, v1: GpuScalar) => - GpuColumnVector.from(doColumnar(v0, v1), dataType) + withResourceIfAllowed(url.columnarEval(batch)) { urls => + withResourceIfAllowed(partToExtract.columnarEvalAny(batch)) { parts => + (urls, parts) match { + case (urlCv: GpuColumnVector, partScalar: GpuScalar) => + GpuColumnVector.from(doColumnar(urlCv, partScalar), dataType) case _ => - throw new UnsupportedOperationException(s"Cannot columnar evaluate expression: $this") + throw new + UnsupportedOperationException(s"Cannot columnar evaluate expression: $this") } } } @@ -109,12 +100,12 @@ case class GpuParseUrl(children: Seq[Expression], // 3-arg, i.e. 
QUERY with key assert(children.size == 3) val Seq(url, partToExtract, key) = children - withResourceIfAllowed(url.columnarEvalAny(batch)) { val0 => - withResourceIfAllowed(partToExtract.columnarEvalAny(batch)) { val1 => - withResourceIfAllowed(key.columnarEvalAny(batch)) { val2 => - (val0, val1, val2) match { - case (v0: GpuColumnVector, v1: GpuScalar, v2: GpuScalar) => - GpuColumnVector.from(doColumnar(v0, v1, v2), dataType) + withResourceIfAllowed(url.columnarEval(batch)) { urls => + withResourceIfAllowed(partToExtract.columnarEvalAny(batch)) { parts => + withResourceIfAllowed(key.columnarEvalAny(batch)) { keys => + (urls, parts, keys) match { + case (urlCv: GpuColumnVector, partScalar: GpuScalar, keyScalar: GpuScalar) => + GpuColumnVector.from(doColumnar(urlCv, partScalar, keyScalar), dataType) case _ => throw new UnsupportedOperationException(s"Cannot columnar evaluate expression: $this") @@ -124,4 +115,4 @@ case class GpuParseUrl(children: Seq[Expression], } } } -} \ No newline at end of file +} From 46527f37f218196ae22370d7bbe57cb8e813a784 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 24 Oct 2023 14:13:32 +0800 Subject: [PATCH 20/27] remove unnecessary error handling Signed-off-by: Haoyang Li --- docs/supported_ops.md | 2 +- integration_tests/src/main/python/url_test.py | 10 ---------- .../org/apache/spark/sql/rapids/urlFunctions.scala | 12 ------------ .../spark/sql/rapids/shims/RapidsErrorUtils.scala | 5 ----- .../spark/sql/rapids/shims/RapidsErrorUtils.scala | 5 ----- .../spark/sql/rapids/shims/RapidsErrorUtils.scala | 5 ----- .../spark/sql/rapids/shims/RapidsErrorUtils.scala | 5 ----- .../spark/sql/rapids/shims/RapidsErrorUtils.scala | 9 --------- .../spark/sql/rapids/shims/RapidsErrorUtils.scala | 8 -------- 9 files changed, 1 insertion(+), 60 deletions(-) diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 254fe8ef876..4446a0c7869 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -10770,7 +10770,7 @@ are limited. -PS
Literal value only
+PS
only support partToExtract=PROTOCOL;
Literal value only
diff --git a/integration_tests/src/main/python/url_test.py b/integration_tests/src/main/python/url_test.py index 68e3ecd376a..5cc289f7771 100644 --- a/integration_tests/src/main/python/url_test.py +++ b/integration_tests/src/main/python/url_test.py @@ -165,13 +165,3 @@ def test_parse_url_host_fallback(part): "parse_url(a, '" + part + "')" ), 'ParseUrl') - -def test_parse_url_too_many_args(): - error_message = 'parse_url function requires two or three arguments' \ - if is_before_spark_340() else \ - '[WRONG_NUM_ARGS.WITHOUT_SUGGESTION] The `parse_url` requires [2, 3] parameters' - assert_gpu_and_cpu_error( - lambda spark : unary_op_df(spark, StringGen()).selectExpr( - "a","parse_url(a, 'PROTOCOL', 'key', 'value')").collect(), - conf={}, - error_message=error_message) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala index 89f2990a106..5512e0b9ce6 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala @@ -23,9 +23,7 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.jni.ParseURI import com.nvidia.spark.rapids.shims.ShimExpression -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.rapids.shims.RapidsErrorUtils import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.unsafe.types.UTF8String @@ -50,16 +48,6 @@ case class GpuParseUrl(children: Seq[Expression]) override def prettyName: String = "parse_url" import GpuParseUrl._ - - def checkInputDataTypesUseless(): TypeCheckResult = { - if (children.size > 3 || children.size < 2) { - RapidsErrorUtils.parseUrlWrongNumArgs(children.size) match { - case res: Some[TypeCheckResult] => return res.get - case _ => // error message has been thrown - } - } - super[ExpectsInputTypes].checkInputDataTypes() - } def doColumnar(url: GpuColumnVector, partToExtract: GpuScalar): ColumnVector = { val part = partToExtract.getValue.asInstanceOf[UTF8String].toString diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index 2084336cced..f23229e0956 100644 --- a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -23,7 +23,6 @@ package org.apache.spark.sql.rapids.shims import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.types.{DataType, Decimal, DecimalType} @@ -82,8 +81,4 @@ object RapidsErrorUtils { def tableIdentifierExistsError(tableIdentifier: TableIdentifier): Throwable = { throw new AnalysisException(s"$tableIdentifier already exists.") } - - def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = { - Some(TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments")) - } } diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala 
b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index c6122375cf2..b301397255a 100644 --- a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -25,7 +25,6 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.shims import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.types.{DataType, Decimal, DecimalType} @@ -86,8 +85,4 @@ object RapidsErrorUtils { def tableIdentifierExistsError(tableIdentifier: TableIdentifier): Throwable = { QueryCompilationErrors.tableIdentifierExistsError(tableIdentifier) } - - def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = { - Some(TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments")) - } } diff --git a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index fed5d0c4f6c..6fa5b8350a5 100644 --- a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -20,7 +20,6 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.shims import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.types.{DataType, Decimal, DecimalType} @@ -84,8 +83,4 @@ object RapidsErrorUtils { def tableIdentifierExistsError(tableIdentifier: TableIdentifier): Throwable = { QueryCompilationErrors.tableIdentifierExistsError(tableIdentifier) } - - def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = { - Some(TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments")) - } } diff --git a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index c48ff240640..581e5dc9c2e 100644 --- a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -25,7 +25,6 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.shims import org.apache.spark.SparkDateTimeException -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf @@ -81,8 +80,4 @@ object RapidsErrorUtils extends RapidsErrorUtilsFor330plus { def sqlArrayIndexNotStartAtOneError(): RuntimeException = { new ArrayIndexOutOfBoundsException("SQL array indices start at 1") } - - def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = { - Some(TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments")) - } } diff --git 
a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index 9f0a365f6dc..2bd162cd0f9 100644 --- a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -21,7 +21,6 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.shims import org.apache.spark.SparkDateTimeException -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf @@ -89,12 +88,4 @@ object RapidsErrorUtils extends RapidsErrorUtilsFor330plus { override def intervalDivByZeroError(origin: Origin): ArithmeticException = { QueryExecutionErrors.intervalDividedByZeroError(origin.context) } - - def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = { - Some(TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments")) - } - - def invalidUrlException(url: UTF8String, e: URISyntaxException): Throwable = { - QueryExecutionErrors.invalidUrlError(url, e) - } } diff --git a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index a056488e403..e9116801699 100644 --- a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -22,7 +22,6 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.shims import org.apache.spark.SparkDateTimeException -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf @@ -89,11 +88,4 @@ object RapidsErrorUtils extends RapidsErrorUtilsFor330plus { override def intervalDivByZeroError(origin: Origin): ArithmeticException = { QueryExecutionErrors.intervalDividedByZeroError(origin.context) } - - def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = { - throw QueryCompilationErrors.wrongNumArgsError( - "parse_url", Seq("[2, 3]"), actual - ) - None - } } From 6161fa4871c3f9b8fc7709caab2de77079aaed4c Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 24 Oct 2023 14:27:05 +0800 Subject: [PATCH 21/27] clean up Signed-off-by: Haoyang Li --- .../org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala | 1 - .../org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index 2bd162cd0f9..3585910993d 100644 --- a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext} import org.apache.spark.sql.errors.QueryExecutionErrors import 
org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, Decimal, DecimalType} -import org.apache.spark.unsafe.types.UTF8String object RapidsErrorUtils extends RapidsErrorUtilsFor330plus { diff --git a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index e9116801699..45b64307254 100644 --- a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -23,7 +23,7 @@ package org.apache.spark.sql.rapids.shims import org.apache.spark.SparkDateTimeException import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext} -import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} +import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, Decimal, DecimalType} From 8f4990c2ff53f953996e796b06d63893f0b75dd2 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 16 Nov 2023 11:00:44 +0800 Subject: [PATCH 22/27] Revert scala tests temporarily for easier testing Signed-off-by: Haoyang Li --- .../spark/rapids/UrlFunctionsSuite.scala | 176 ++++++++++++++++++ 1 file changed, 176 insertions(+) create mode 100644 tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala new file mode 100644 index 00000000000..f4a5740fa8f --- /dev/null +++ b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala @@ -0,0 +1,176 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids + +import org.apache.spark.sql.{DataFrame, SparkSession} + +class UrlFunctionsSuite extends SparkQueryCompareTestSuite { + def validUrlEdgeCasesDf(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + // [In search of the perfect URL validation regex](https://mathiasbynens.be/demo/url-regex) + Seq[String]( + "http://foo.com/blah_blah", + "http://foo.com/blah_blah/", + "http://foo.com/blah_blah_(wikipedia)", + "http://foo.com/blah_blah_(wikipedia)_(again)", + "http://www.example.com/wpstyle/?p=364", + "https://www.example.com/foo/?bar=baz&inga=42&quux", + "http://✪df.ws/123", + "http://userid:password@example.com:8080", + "http://userid:password@example.com:8080/", + "http://userid:password@example.com", + "http://userid:password@example.com/", + "http://142.42.1.1/", + "http://142.42.1.1:8080/", + "http://➡.ws/䨹", + "http://⌘.ws", + "http://⌘.ws/", + "http://foo.com/blah_(wikipedia)#cite-1", + "http://foo.com/blah_(wikipedia)_blah#cite-1", + "http://foo.com/unicode_(✪)_in_parens", + "http://foo.com/(something)?after=parens", + "http://☺.damowmow.com/", + "http://code.google.com/events/#&product=browser", + "http://j.mp", + "ftp://foo.bar/baz", + "http://foo.bar/?q=Test%20URL-encoded%20stuff", + "http://مثال.إختبار", + "http://例子.测试", + "http://उदाहरण.परीक्षा", + "http://-.~_!$&'()*+,;=:%40:80%2f::::::@example.com", + "http://1337.net", + "http://a.b-c.de", + "http://223.255.255.254", + "https://foo_bar.example.com/", + "http://", + "http://.", + "http://..", + "http://../", + // "http://?", + // "http://??", + // "http://??/", + // "http://#", + "http://##", + "http://##/", + "http://foo.bar?q=Spaces should be encoded", + "//", + "//a", + "///a", + "///", + "http:///a", + "foo.com", + "rdar://1234", + "h://test", + "http:// shouldfail.com", + ":// should fail", + "http://foo.bar/foo(bar)baz quux", + "ftps://foo.bar/", + "http://-error-.invalid/", + "http://a.b--c.de/", + "http://-a.b.co", + "http://a.b-.co", + "http://0.0.0.0", + "http://10.1.1.0", + "http://10.1.1.255", + "http://224.1.1.1", + "http://1.1.1.1.1", + "http://123.123.123", + "http://3628126748", + "http://.www.foo.bar/", + "http://www.foo.bar./", + "http://.www.foo.bar./", + "http://10.1.1.1", + "http://10.1.1.254" + ).toDF("urls") + } + + def urlCasesFromSpark(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + Seq[String]( + "http://userinfo@spark.apache.org/path?query=1#Ref", + "https://use%20r:pas%20s@example.com/dir%20/pa%20th.HTML?query=x%20y&q2=2#Ref%20two", + "http://user:pass@host", + "http://user:pass@host/", + "http://user:pass@host/?#", + // "http://user:pass@host/file;param?query;p2" + ).toDF("urls") + } + + def urlCasesFromSparkInvalid(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + Seq[String]( + "inva lid://user:pass@host/file;param?query;p2" + ).toDF("urls") + } + + def urlIpv6Host(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + Seq[String]( + "http://[1:2:3:4:5:6:7:8]", + "http://[1::]", + // "http://[1:2:3:4:5:6:7::]", + "http://[1::8]", + "http://[1:2:3:4:5:6::8]", + "http://[1:2:3:4:5:6::8]", + "http://[1::7:8]", + "http://[1:2:3:4:5::7:8]", + "http://[1:2:3:4:5::8]", + "http://[1::6:7:8]", + "http://[1:2:3:4::6:7:8]", + "http://[1:2:3:4::8]", + "http://[1::5:6:7:8]", + "http://[1:2:3::5:6:7:8]", + "http://[1:2:3::8]", + "http://[1::4:5:6:7:8]", + "http://[1:2::4:5:6:7:8]", + "http://[1:2::8]", + "http://[1::3:4:5:6:7:8]", + 
"http://[1::3:4:5:6:7:8]", + "http://[1::8]", + // "http://[::2:3:4:5:6:7:8]", + "http://[::8]", + "http://[::]", + // "http://[fe80::7:8%eth0]", + // "http://[fe80::7:8%1]", + "http://[::255.255.255.255]", + "http://[::ffff:255.255.255.255]", + "http://[::ffff:0:255.255.255.255]", + "http://[2001:db8:3:4::192.0.2.33]", + "http://[64:ff9b::192.0.2.33]" + ).toDF("urls") + } + + def parseUrlProtocol(frame: DataFrame): DataFrame = { + frame.selectExpr("urls", "parse_url(urls, 'PROTOCOL')") + } + + testSparkResultsAreEqual("Test parse_url edge cases from internet", validUrlEdgeCasesDf) { + parseUrlProtocol + } + + testSparkResultsAreEqual("Test parse_url cases from Spark", urlCasesFromSpark) { + parseUrlProtocol + } + + testSparkResultsAreEqual("Test parse_url invalid cases from Spark", urlCasesFromSparkInvalid) { + parseUrlProtocol + } + + testSparkResultsAreEqual("Test parse_url ipv6 host", urlIpv6Host) { + parseUrlProtocol + } +} \ No newline at end of file From 337637603807bb065f79e081ec1d170cde8738a5 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 16 Nov 2023 18:16:46 +0800 Subject: [PATCH 23/27] Fix two nits Signed-off-by: Haoyang Li --- .../main/scala/com/nvidia/spark/rapids/GpuOverrides.scala | 2 +- .../scala/org/apache/spark/sql/rapids/urlFunctions.scala | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index b5be6cc391a..5dfc951c0bd 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -3241,7 +3241,7 @@ object GpuOverrides extends Logging { willNotWorkOnGpu("Fail on error is not supported on GPU when parsing urls.") } - extractStringLit(childExprs(1).convertToCpu()).map(_.toUpperCase) match { + extractStringLit(a.children(1)).map(_.toUpperCase) match { case Some(GpuParseUrl.PROTOCOL) => case Some(other) => willNotWorkOnGpu(s"Part to extract $other is not supported on GPU") diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala index 5512e0b9ce6..586814c38e7 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala @@ -75,9 +75,9 @@ case class GpuParseUrl(children: Seq[Expression]) val Seq(url, partToExtract) = children withResourceIfAllowed(url.columnarEval(batch)) { urls => withResourceIfAllowed(partToExtract.columnarEvalAny(batch)) { parts => - (urls, parts) match { - case (urlCv: GpuColumnVector, partScalar: GpuScalar) => - GpuColumnVector.from(doColumnar(urlCv, partScalar), dataType) + parts match { + case partScalar: GpuScalar => + GpuColumnVector.from(doColumnar(urls, partScalar), dataType) case _ => throw new UnsupportedOperationException(s"Cannot columnar evaluate expression: $this") From 4e98888c89b93977b2238b17e112cc31b6358ac7 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 22 Nov 2023 09:49:48 +0800 Subject: [PATCH 24/27] Updated results Signed-off-by: Haoyang Li --- integration_tests/src/main/python/url_test.py | 90 +++++++++---------- .../spark/rapids/UrlFunctionsSuite.scala | 18 ++-- 2 files changed, 54 insertions(+), 54 deletions(-) diff --git a/integration_tests/src/main/python/url_test.py b/integration_tests/src/main/python/url_test.py index 5cc289f7771..b8859257722 100644 --- 
a/integration_tests/src/main/python/url_test.py
+++ b/integration_tests/src/main/python/url_test.py
@@ -35,34 +35,34 @@
     "http://foo.com/blah_blah_(wikipedia)_(again)",
     "http://www.example.com/wpstyle/?p=364",
     "https://www.example.com/foo/?bar=baz&inga=42&quux",
-    # "http://✪df.ws/123",
+    "http://✪df.ws/123",
     "http://userid:password@example.com:8080",
     "http://userid:password@example.com:8080/",
     "http://userid:password@example.com",
     "http://userid:password@example.com/",
     "http://142.42.1.1/",
     "http://142.42.1.1:8080/",
-    # "http://➡.ws/䨹",
-    # "http://⌘.ws",
-    # "http://⌘.ws/",
+    "http://➡.ws/䨹",
+    "http://⌘.ws",
+    "http://⌘.ws/",
     "http://foo.com/blah_(wikipedia)#cite-1",
     "http://foo.com/blah_(wikipedia)_blah#cite-1",
-    # "http://foo.com/unicode_(✪)_in_parens",
+    "http://foo.com/unicode_(✪)_in_parens",
     "http://foo.com/(something)?after=parens",
-    # "http://☺.damowmow.com/",
+    "http://☺.damowmow.com/",
     "http://code.google.com/events/#&product=browser",
     "http://j.mp",
     "ftp://foo.bar/baz",
     r"http://foo.bar/?q=Test%20URL-encoded%20stuff",
-    # "http://مثال.إختبار",
-    # "http://例子.测试",
-    # "http://उदाहरण.परीक्षा",
-    # "http://-.~_!$&'()*+,;=:%40:80%2f::::::@example.com",
+    "http://مثال.إختبار",
+    "http://例子.测试",
+    "http://उदाहरण.परीक्षा",
+    "http://-.~_!$&'()*+,;=:%40:80%2f::::::@example.com",
     "http://1337.net",
    "http://a.b-c.de",
     "http://223.255.255.254",
     "https://foo_bar.example.com/",
-    # "http://",
+    "http://",
     "http://.",
     "http://..",
     "http://../",
@@ -70,8 +70,8 @@
     "http://??",
     "http://??/",
     "http://#",
-    # "http://##",
-    # "http://##/",
+    "http://##",
+    "http://##/",
     "http://foo.bar?q=Spaces should be encoded",
     "//",
     "//a",
@@ -108,38 +108,38 @@
     "http://user:pass@host/?#",
     "http://user:pass@host/file;param?query;p2",
     "inva lid://user:pass@host/file;param?query;p2",
-    # "http://[1:2:3:4:5:6:7:8]",
-    # "http://[1::]",
-    # "http://[1:2:3:4:5:6:7::]",
-    # "http://[1::8]",
-    # "http://[1:2:3:4:5:6::8]",
-    # "http://[1:2:3:4:5:6::8]",
-    # "http://[1::7:8]",
-    # "http://[1:2:3:4:5::7:8]",
-    # "http://[1:2:3:4:5::8]",
-    # "http://[1::6:7:8]",
-    # "http://[1:2:3:4::6:7:8]",
-    # "http://[1:2:3:4::8]",
-    # "http://[1::5:6:7:8]",
-    # "http://[1:2:3::5:6:7:8]",
-    # "http://[1:2:3::8]",
-    # "http://[1::4:5:6:7:8]",
-    # "http://[1:2::4:5:6:7:8]",
-    # "http://[1:2::8]",
-    # "http://[1::3:4:5:6:7:8]",
-    # "http://[1::3:4:5:6:7:8]",
-    # "http://[1::8]",
-    # "http://[::2:3:4:5:6:7:8]",
-    # "http://[::2:3:4:5:6:7:8]",
-    # "http://[::8]",
-    # "http://[::]",
-    # "http://[fe80::7:8%eth0]",
-    # "http://[fe80::7:8%1]",
-    # "http://[::255.255.255.255]",
-    # "http://[::ffff:255.255.255.255]",
-    # "http://[::ffff:0:255.255.255.255]",
-    # "http://[2001:db8:3:4::192.0.2.33]",
-    # "http://[64:ff9b::192.0.2.33]"
+    "http://[1:2:3:4:5:6:7:8]",
+    "http://[1::]",
+    "http://[1:2:3:4:5:6:7::]",
+    "http://[1::8]",
+    "http://[1:2:3:4:5:6::8]",
+    "http://[1:2:3:4:5:6::8]",
+    "http://[1::7:8]",
+    "http://[1:2:3:4:5::7:8]",
+    "http://[1:2:3:4:5::8]",
+    "http://[1::6:7:8]",
+    "http://[1:2:3:4::6:7:8]",
+    "http://[1:2:3:4::8]",
+    "http://[1::5:6:7:8]",
+    "http://[1:2:3::5:6:7:8]",
+    "http://[1:2:3::8]",
+    "http://[1::4:5:6:7:8]",
+    "http://[1:2::4:5:6:7:8]",
+    "http://[1:2::8]",
+    "http://[1::3:4:5:6:7:8]",
+    "http://[1::3:4:5:6:7:8]",
+    "http://[1::8]",
+    "http://[::2:3:4:5:6:7:8]",
+    "http://[::2:3:4:5:6:7:8]",
+    "http://[::8]",
+    "http://[::]",
+    "http://[fe80::7:8%eth0]",
+    "http://[fe80::7:8%1]",
+    "http://[::255.255.255.255]",
+    "http://[::ffff:255.255.255.255]",
+    "http://[::ffff:0:255.255.255.255]",
+    
"http://[2001:db8:3:4::192.0.2.33]", + "http://[64:ff9b::192.0.2.33]" ] edge_cases_gen = SetValuesGen(StringType(), edge_cases) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala index f4a5740fa8f..020eed919a0 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala @@ -60,10 +60,10 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { "http://.", "http://..", "http://../", - // "http://?", - // "http://??", - // "http://??/", - // "http://#", + // "http://?", + "http://??", + "http://??/", + // "http://#", "http://##", "http://##/", "http://foo.bar?q=Spaces should be encoded", @@ -106,7 +106,7 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { "http://user:pass@host", "http://user:pass@host/", "http://user:pass@host/?#", - // "http://user:pass@host/file;param?query;p2" + "http://user:pass@host/file;param?query;p2" ).toDF("urls") } @@ -122,7 +122,7 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { Seq[String]( "http://[1:2:3:4:5:6:7:8]", "http://[1::]", - // "http://[1:2:3:4:5:6:7::]", + "http://[1:2:3:4:5:6:7::]", "http://[1::8]", "http://[1:2:3:4:5:6::8]", "http://[1:2:3:4:5:6::8]", @@ -141,11 +141,11 @@ class UrlFunctionsSuite extends SparkQueryCompareTestSuite { "http://[1::3:4:5:6:7:8]", "http://[1::3:4:5:6:7:8]", "http://[1::8]", - // "http://[::2:3:4:5:6:7:8]", + "http://[::2:3:4:5:6:7:8]", "http://[::8]", "http://[::]", - // "http://[fe80::7:8%eth0]", - // "http://[fe80::7:8%1]", + "http://[fe80::7:8%eth0]", + "http://[fe80::7:8%1]", "http://[::255.255.255.255]", "http://[::ffff:255.255.255.255]", "http://[::ffff:0:255.255.255.255]", From 6d916c4da77d1695d2e9fb916461cac688861466 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 22 Nov 2023 15:05:42 +0800 Subject: [PATCH 25/27] clean up Signed-off-by: Haoyang Li --- docs/compatibility.md | 11 -- .../spark/rapids/UrlFunctionsSuite.scala | 176 ------------------ 2 files changed, 187 deletions(-) delete mode 100644 tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala diff --git a/docs/compatibility.md b/docs/compatibility.md index 091cafe228f..ac90d309fe1 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -467,17 +467,6 @@ Spark stores timestamps internally relative to the JVM time zone. Converting an between time zones is not currently supported on the GPU. Therefore operations involving timestamps will only be GPU-accelerated if the time zone used by the JVM is UTC. -## URL parsing - -`parse_url` can produce different results on the GPU compared to the CPU. - -Known issues for PROTOCOL parsing: -- If urls containing utf-8 special characters, PROTOCOL results on GPU will be null. -- If urls containing ipv6 host, GPU will return null for PROTOCOL. -- GPU will still try to parse the PROTOCOL instead of returning null for some edge invalid cases, - such as urls containing multiple '#' in REF (http://##) or empty authority component followed by - a empty path (http://). - ## Windowing ### Window Functions diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala deleted file mode 100644 index 020eed919a0..00000000000 --- a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala +++ /dev/null @@ -1,176 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.nvidia.spark.rapids - -import org.apache.spark.sql.{DataFrame, SparkSession} - -class UrlFunctionsSuite extends SparkQueryCompareTestSuite { - def validUrlEdgeCasesDf(session: SparkSession): DataFrame = { - import session.sqlContext.implicits._ - // [In search of the perfect URL validation regex](https://mathiasbynens.be/demo/url-regex) - Seq[String]( - "http://foo.com/blah_blah", - "http://foo.com/blah_blah/", - "http://foo.com/blah_blah_(wikipedia)", - "http://foo.com/blah_blah_(wikipedia)_(again)", - "http://www.example.com/wpstyle/?p=364", - "https://www.example.com/foo/?bar=baz&inga=42&quux", - "http://✪df.ws/123", - "http://userid:password@example.com:8080", - "http://userid:password@example.com:8080/", - "http://userid:password@example.com", - "http://userid:password@example.com/", - "http://142.42.1.1/", - "http://142.42.1.1:8080/", - "http://➡.ws/䨹", - "http://⌘.ws", - "http://⌘.ws/", - "http://foo.com/blah_(wikipedia)#cite-1", - "http://foo.com/blah_(wikipedia)_blah#cite-1", - "http://foo.com/unicode_(✪)_in_parens", - "http://foo.com/(something)?after=parens", - "http://☺.damowmow.com/", - "http://code.google.com/events/#&product=browser", - "http://j.mp", - "ftp://foo.bar/baz", - "http://foo.bar/?q=Test%20URL-encoded%20stuff", - "http://مثال.إختبار", - "http://例子.测试", - "http://उदाहरण.परीक्षा", - "http://-.~_!$&'()*+,;=:%40:80%2f::::::@example.com", - "http://1337.net", - "http://a.b-c.de", - "http://223.255.255.254", - "https://foo_bar.example.com/", - "http://", - "http://.", - "http://..", - "http://../", - // "http://?", - "http://??", - "http://??/", - // "http://#", - "http://##", - "http://##/", - "http://foo.bar?q=Spaces should be encoded", - "//", - "//a", - "///a", - "///", - "http:///a", - "foo.com", - "rdar://1234", - "h://test", - "http:// shouldfail.com", - ":// should fail", - "http://foo.bar/foo(bar)baz quux", - "ftps://foo.bar/", - "http://-error-.invalid/", - "http://a.b--c.de/", - "http://-a.b.co", - "http://a.b-.co", - "http://0.0.0.0", - "http://10.1.1.0", - "http://10.1.1.255", - "http://224.1.1.1", - "http://1.1.1.1.1", - "http://123.123.123", - "http://3628126748", - "http://.www.foo.bar/", - "http://www.foo.bar./", - "http://.www.foo.bar./", - "http://10.1.1.1", - "http://10.1.1.254" - ).toDF("urls") - } - - def urlCasesFromSpark(session: SparkSession): DataFrame = { - import session.sqlContext.implicits._ - Seq[String]( - "http://userinfo@spark.apache.org/path?query=1#Ref", - "https://use%20r:pas%20s@example.com/dir%20/pa%20th.HTML?query=x%20y&q2=2#Ref%20two", - "http://user:pass@host", - "http://user:pass@host/", - "http://user:pass@host/?#", - "http://user:pass@host/file;param?query;p2" - ).toDF("urls") - } - - def urlCasesFromSparkInvalid(session: SparkSession): DataFrame = { - import session.sqlContext.implicits._ - Seq[String]( - "inva lid://user:pass@host/file;param?query;p2" - ).toDF("urls") - } - - def urlIpv6Host(session: SparkSession): DataFrame = { - import 
session.sqlContext.implicits._ - Seq[String]( - "http://[1:2:3:4:5:6:7:8]", - "http://[1::]", - "http://[1:2:3:4:5:6:7::]", - "http://[1::8]", - "http://[1:2:3:4:5:6::8]", - "http://[1:2:3:4:5:6::8]", - "http://[1::7:8]", - "http://[1:2:3:4:5::7:8]", - "http://[1:2:3:4:5::8]", - "http://[1::6:7:8]", - "http://[1:2:3:4::6:7:8]", - "http://[1:2:3:4::8]", - "http://[1::5:6:7:8]", - "http://[1:2:3::5:6:7:8]", - "http://[1:2:3::8]", - "http://[1::4:5:6:7:8]", - "http://[1:2::4:5:6:7:8]", - "http://[1:2::8]", - "http://[1::3:4:5:6:7:8]", - "http://[1::3:4:5:6:7:8]", - "http://[1::8]", - "http://[::2:3:4:5:6:7:8]", - "http://[::8]", - "http://[::]", - "http://[fe80::7:8%eth0]", - "http://[fe80::7:8%1]", - "http://[::255.255.255.255]", - "http://[::ffff:255.255.255.255]", - "http://[::ffff:0:255.255.255.255]", - "http://[2001:db8:3:4::192.0.2.33]", - "http://[64:ff9b::192.0.2.33]" - ).toDF("urls") - } - - def parseUrlProtocol(frame: DataFrame): DataFrame = { - frame.selectExpr("urls", "parse_url(urls, 'PROTOCOL')") - } - - testSparkResultsAreEqual("Test parse_url edge cases from internet", validUrlEdgeCasesDf) { - parseUrlProtocol - } - - testSparkResultsAreEqual("Test parse_url cases from Spark", urlCasesFromSpark) { - parseUrlProtocol - } - - testSparkResultsAreEqual("Test parse_url invalid cases from Spark", urlCasesFromSparkInvalid) { - parseUrlProtocol - } - - testSparkResultsAreEqual("Test parse_url ipv6 host", urlIpv6Host) { - parseUrlProtocol - } -} \ No newline at end of file From 1b3609090c0dfc5efca40b9ef526f9e259dd612d Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 28 Nov 2023 13:31:45 +0800 Subject: [PATCH 26/27] rename urlFunctions to GpuParseUrl Signed-off-by: Haoyang Li --- integration_tests/src/main/python/url_test.py | 4 ++-- .../sql/rapids/{urlFunctions.scala => GpuParseUrl.scala} | 0 2 files changed, 2 insertions(+), 2 deletions(-) rename sql-plugin/src/main/scala/org/apache/spark/sql/rapids/{urlFunctions.scala => GpuParseUrl.scala} (100%) diff --git a/integration_tests/src/main/python/url_test.py b/integration_tests/src/main/python/url_test.py index b8859257722..ba51170108d 100644 --- a/integration_tests/src/main/python/url_test.py +++ b/integration_tests/src/main/python/url_test.py @@ -154,10 +154,10 @@ def test_parse_url_protocol(data_gen): "parse_url(a, 'PROTOCOL')" )) -unsupported_part = ['HOST', 'PATH', 'QUERY', 'REF', 'FILE', 'AUTHORITY', 'USERINFO'] +unsupported_parts = ['HOST', 'PATH', 'QUERY', 'REF', 'FILE', 'AUTHORITY', 'USERINFO'] @allow_non_gpu('ProjectExec', 'ParseUrl') -@pytest.mark.parametrize('part', unsupported_part, ids=idfn) +@pytest.mark.parametrize('part', unsupported_parts, ids=idfn) def test_parse_url_host_fallback(part): assert_gpu_fallback_collect( lambda spark : unary_op_df(spark, url_gen).selectExpr( diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuParseUrl.scala similarity index 100% rename from sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala rename to sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuParseUrl.scala From 3ace1245d324d3c677bc4e9c148f61652ac44394 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 6 Dec 2023 16:17:43 +0800 Subject: [PATCH 27/27] verify Signed-off-by: Haoyang Li --- docs/supported_ops.md | 702 ++++++++++++++++++++---------------------- 1 file changed, 338 insertions(+), 364 deletions(-) diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 5d7bc3365bc..fab825a9c0f 100644 --- 
a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -13759,6 +13759,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + StringLocate `position`, `locate` Substring search operator @@ -13848,32 +13874,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - StringRPad `rpad` Pad a string on the right @@ -14120,6 +14120,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + StringSplit `split` Splits `str` around occurrences that match `regex` @@ -14209,32 +14235,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - StringToMap `str_to_map` Creates a map after splitting the input string into pairs of key-value strings @@ -14481,6 +14481,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + StringTrimLeft `ltrim` StringTrimLeft operator @@ -14664,58 +14690,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Substring `substr`, `substring` Substring operator @@ -14894,6 +14868,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Subtract `-` Subtraction @@ -15116,32 +15116,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Tanh `tanh` Hyperbolic tangent @@ -15300,6 +15274,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + ToDegrees `degrees` Converts radians to degrees @@ -15530,32 +15530,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - TransformValues `transform_values` Transform values in a map using a transform function @@ -15714,6 +15688,32 @@ are limited. 
+Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + UnaryPositive `positive` A numeric value with a + in front of it @@ -15924,32 +15924,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - UnscaledValue Convert a Decimal to an unscaled long value for some aggregation optimizations @@ -16091,6 +16065,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + WindowExpression Calculates a return value for every input row of a table based on a group (or "window") of rows @@ -16321,32 +16321,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - AggregateExpression Aggregate expression @@ -16543,6 +16517,32 @@ are limited. S +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + ApproximatePercentile `percentile_approx`, `approx_percentile` Approximate percentile @@ -16717,32 +16717,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Average `avg`, `mean` Average aggregate operator @@ -17009,6 +16983,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + CollectSet `collect_set` Collect a set of unique elements, not supported in reduction @@ -17142,32 +17142,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Count `count` Count aggregate operator @@ -17434,6 +17408,32 @@ are limited. NS +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Last `last`, `last_value` last aggregate operator @@ -17567,32 +17567,6 @@ are limited. NS -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Max `max` Max aggregate operator @@ -17859,6 +17833,32 @@ are limited. NS +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Percentile `percentile` Aggregation computing exact percentile @@ -18165,32 +18165,6 @@ are limited. 
NS -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - StddevPop `stddev_pop` Aggregation computing population standard deviation @@ -18324,6 +18298,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + StddevSamp `stddev_samp`, `std`, `stddev` Aggregation computing sample standard deviation @@ -18590,32 +18590,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - VariancePop `var_pop` Aggregation computing population variance @@ -18749,6 +18723,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + VarianceSamp `var_samp`, `variance` Aggregation computing sample variance @@ -18955,32 +18955,6 @@ are limited. NS -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - HiveGenericUDF Hive Generic UDF, the UDF can choose to implement a RAPIDS accelerated interface to get better performance
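
---

Note for reviewers (not part of the patch series): the changes above GPU-enable only the PROTOCOL part of `parse_url`, while every other part (HOST, PATH, QUERY, REF, FILE, AUTHORITY, USERINFO) falls back to the CPU, as exercised by `test_parse_url_host_fallback`. The following standalone PySpark sketch shows the CPU `parse_url` semantics the suites compare the GPU results against; the session setup and the sample URLs are illustrative assumptions drawn from the test data, not code from the plugin.

```python
# Minimal sketch of the CPU reference behavior for parse_url(url, 'PROTOCOL'),
# the only part this series runs on the GPU.
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("parse-url-protocol-demo").getOrCreate()

urls = [
    ("http://userid:password@example.com:8080/",),  # plain authority
    ("http://[1:2:3:4:5:6:7:8]",),                  # IPv6 host
    ("http://##",),                                 # invalid edge case
    ("ftp://foo.bar/baz",),
]
df = spark.createDataFrame(urls, ["url"])

# On the CPU, Spark returns the scheme (e.g. 'http', 'ftp'), or NULL when the
# URL does not parse; the suites above assert the GPU output matches this.
df.selectExpr("url", "parse_url(url, 'PROTOCOL') AS protocol").show(truncate=False)
```

Run on a plain CPU session this produces the reference output; with the plugin enabled, only the PROTOCOL expression is expected to stay on the GPU, while requests for the other parts trigger the `ProjectExec`/`ParseUrl` fallback shown in the integration test.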