diff --git a/docs/additional-functionality/advanced_configs.md b/docs/additional-functionality/advanced_configs.md index 089a9cb9b27..40acd48329c 100644 --- a/docs/additional-functionality/advanced_configs.md +++ b/docs/additional-functionality/advanced_configs.md @@ -301,6 +301,7 @@ Name | SQL Function(s) | Description | Default Value | Notes spark.rapids.sql.expression.NthValue|`nth_value`|nth window operator|true|None| spark.rapids.sql.expression.OctetLength|`octet_length`|The byte length of string data|true|None| spark.rapids.sql.expression.Or|`or`|Logical OR|true|None| +spark.rapids.sql.expression.ParseUrl|`parse_url`|Extracts a part from a URL|true|None| spark.rapids.sql.expression.PercentRank|`percent_rank`|Window function that returns the percent rank value within the aggregation window|true|None| spark.rapids.sql.expression.Pmod|`pmod`|Pmod|true|None| spark.rapids.sql.expression.PosExplode|`posexplode_outer`, `posexplode`|Given an input array produces a sequence of rows for each value in the array|true|None| diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 414a53c56ac..fab825a9c0f 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -10734,6 +10734,95 @@ are limited. +ParseUrl +`parse_url` +Extracts a part from a URL +None +project +url + + + + + + + + + +S + + + + + + + + + + +partToExtract + + + + + + + + + +PS
only support partToExtract=PROTOCOL;
Literal value only
+ + + + + + + + + + +key + + + + + + + + + +PS
Literal value only
+ + + + + + + + + + +result + + + + + + + + + +S + + + + + + + + + + PercentRank `percent_rank` Window function that returns the percent rank value within the aggregation window @@ -10849,6 +10938,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + PosExplode `posexplode_outer`, `posexplode` Given an input array produces a sequence of rows for each value in the array @@ -11028,32 +11143,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - PreciseTimestampConversion Expression used internally to convert the TimestampType to Long and back without losing precision, i.e. in microseconds. Used in time windowing @@ -11324,6 +11413,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Quarter `quarter` Returns the quarter of the year for date, in the range 1 to 4 @@ -11439,32 +11554,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - RaiseError `raise_error` Throw an exception @@ -11695,6 +11784,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + RegExpExtractAll `regexp_extract_all` Extract all strings matching a regular expression corresponding to the regex group index @@ -11894,32 +12009,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Remainder `%`, `mod` Remainder or modulo @@ -12082,6 +12171,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Rint `rint` Rounds up a double value to the nearest double equal to an integer @@ -12266,32 +12381,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - ScalaUDF User Defined Function, the UDF can choose to implement a RAPIDS accelerated interface to get better performance. @@ -12522,6 +12611,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + ShiftLeft `shiftleft` Bitwise shift left (<<) @@ -12658,32 +12773,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - ShiftRightUnsigned `shiftrightunsigned` Bitwise unsigned shift right (>>>) @@ -12889,6 +12978,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Sinh `sinh` Hyperbolic sine @@ -13026,32 +13141,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - SortArray `sort_array` Returns a sorted array with the input array and the ascending / descending order @@ -13261,6 +13350,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Sqrt `sqrt` Square root @@ -13419,32 +13534,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - StartsWith Starts with @@ -13670,6 +13759,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + StringLocate `position`, `locate` Substring search operator @@ -13848,32 +13963,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - StringRepeat `repeat` StringRepeat operator that repeats the given strings with numbers of times given by repeatTimes @@ -14031,6 +14120,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + StringSplit `split` Splits `str` around occurrences that match `regex` @@ -14209,32 +14324,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - StringTranslate `translate` StringTranslate operator @@ -14392,6 +14481,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + StringTrimLeft `ltrim` StringTrimLeft operator @@ -14575,32 +14690,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Substring `substr`, `substring` Substring operator @@ -14779,6 +14868,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Subtract `-` Subtraction @@ -15001,32 +15116,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Tanh `tanh` Hyperbolic tangent @@ -15185,6 +15274,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + ToDegrees `degrees` Converts radians to degrees @@ -15415,32 +15530,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - TransformValues `transform_values` Transform values in a map using a transform function @@ -15599,6 +15688,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + UnaryPositive `positive` A numeric value with a + in front of it @@ -15809,32 +15924,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - UnscaledValue Convert a Decimal to an unscaled long value for some aggregation optimizations @@ -15976,6 +16065,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + WindowExpression Calculates a return value for every input row of a table based on a group (or "window") of rows @@ -16206,32 +16321,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - AggregateExpression Aggregate expression @@ -16428,6 +16517,32 @@ are limited. S +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + ApproximatePercentile `percentile_approx`, `approx_percentile` Approximate percentile @@ -16602,32 +16717,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Average `avg`, `mean` Average aggregate operator @@ -16894,6 +16983,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + CollectSet `collect_set` Collect a set of unique elements, not supported in reduction @@ -17027,32 +17142,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Count `count` Count aggregate operator @@ -17319,6 +17408,32 @@ are limited. NS +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Last `last`, `last_value` last aggregate operator @@ -17452,32 +17567,6 @@ are limited. NS -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Max `max` Max aggregate operator @@ -17744,6 +17833,32 @@ are limited. NS +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Percentile `percentile` Aggregation computing exact percentile @@ -17918,32 +18033,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - PivotFirst PivotFirst operator @@ -18209,6 +18298,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + StddevSamp `stddev_samp`, `std`, `stddev` Aggregation computing sample standard deviation @@ -18342,32 +18457,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Sum `sum` Sum aggregate operator @@ -18634,6 +18723,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + VarianceSamp `var_samp`, `variance` Aggregation computing sample variance @@ -18767,32 +18882,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - NormalizeNaNAndZero Normalize NaN and zero diff --git a/integration_tests/src/main/python/url_test.py b/integration_tests/src/main/python/url_test.py new file mode 100644 index 00000000000..ba51170108d --- /dev/null +++ b/integration_tests/src/main/python/url_test.py @@ -0,0 +1,167 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error, assert_gpu_fallback_collect +from data_gen import * +from marks import * +from pyspark.sql.types import * +import pyspark.sql.functions as f +from spark_session import is_before_spark_340 + +# regex to generate limit length urls with HOST, PATH, QUERY, REF, PROTOCOL, FILE, AUTHORITY, USERINFO +url_pattern = r'((http|https|ftp)://)(([a-zA-Z][a-zA-Z0-9]{0,2}\.){0,3}([a-zA-Z][a-zA-Z0-9]{0,2})\.([a-zA-Z][a-zA-Z0-9]{0,2}))' \ + r'(:[0-9]{1,3}){0,1}(/[a-zA-Z0-9]{1,3}){0,3}(\?[a-zA-Z0-9]{1,3}=[a-zA-Z0-9]{1,3}){0,1}(#([a-zA-Z0-9]{1,3})){0,1}' + +url_pattern_with_key = r'((http|https|ftp|file)://)(([a-z]{1,3}\.){0,3}([a-z]{1,3})\.([a-z]{1,3}))' \ + r'(:[0-9]{1,3}){0,1}(/[a-z]{1,3}){0,3}(\?key=[a-z]{1,3}){0,1}(#([a-z]{1,3})){0,1}' + +edge_cases = [ + "http://foo.com/blah_blah", + "http://foo.com/blah_blah/", + "http://foo.com/blah_blah_(wikipedia)", + "http://foo.com/blah_blah_(wikipedia)_(again)", + "http://www.example.com/wpstyle/?p=364", + "https://www.example.com/foo/?bar=baz&inga=42&quux", + "http://✪df.ws/123", + "http://userid:password@example.com:8080", + "http://userid:password@example.com:8080/", + "http://userid:password@example.com", + "http://userid:password@example.com/", + "http://142.42.1.1/", + "http://142.42.1.1:8080/", + "http://➡.ws/䨹", + "http://⌘.ws", + "http://⌘.ws/", + "http://foo.com/blah_(wikipedia)#cite-1", + "http://foo.com/blah_(wikipedia)_blah#cite-1", + "http://foo.com/unicode_(✪)_in_parens", + "http://foo.com/(something)?after=parens", + "http://☺.damowmow.com/", + "http://code.google.com/events/#&product=browser", + "http://j.mp", + "ftp://foo.bar/baz", + r"http://foo.bar/?q=Test%20URL-encoded%20stuff", + "http://مثال.إختبار", + "http://例子.测试", + "http://उदाहरण.परीक्षा", + "http://-.~_!$&'()*+,;=:%40:80%2f::::::@example.com", + "http://1337.net", + "http://a.b-c.de", + "http://223.255.255.254", + "https://foo_bar.example.com/", + "http:# ", + "http://.", + "http://..", + "http://../", + "http://?", + "http://??", + "http://??/", + "http://#", + "http://##", + "http://##/", + "http://foo.bar?q=Spaces should be encoded", + "# ", + "//a", + "///a", + "/# ", + "http:///a", + "foo.com", + "rdar://1234", + "h://test", + "http:// shouldfail.com", + ":// should fail", + "http://foo.bar/foo(bar)baz quux", + "ftps://foo.bar/", + "http://-error-.invalid/", + "http://a.b--c.de/", + "http://-a.b.co", + "http://a.b-.co", + "http://0.0.0.0", + "http://10.1.1.0", + "http://10.1.1.255", + "http://224.1.1.1", + "http://1.1.1.1.1", + "http://123.123.123", + "http://3628126748", + "http://.www.foo.bar/", + "http://www.foo.bar./", + "http://.www.foo.bar./", + "http://10.1.1.1", + "http://10.1.1.254", + "http://userinfo@spark.apache.org/path?query=1#Ref", + r"https://use%20r:pas%20s@example.com/dir%20/pa%20th.HTML?query=x%20y&q2=2#Ref%20two", + "http://user:pass@host", + "http://user:pass@host/", + "http://user:pass@host/?#", + "http://user:pass@host/file;param?query;p2", + "inva lid://user:pass@host/file;param?query;p2", + "http://[1:2:3:4:5:6:7:8]", + "http://[1::]", + "http://[1:2:3:4:5:6:7::]", + "http://[1::8]", + "http://[1:2:3:4:5:6::8]", + "http://[1:2:3:4:5:6::8]", + "http://[1::7:8]", + "http://[1:2:3:4:5::7:8]", + "http://[1:2:3:4:5::8]", + "http://[1::6:7:8]", + "http://[1:2:3:4::6:7:8]", + "http://[1:2:3:4::8]", + "http://[1::5:6:7:8]", + "http://[1:2:3::5:6:7:8]", + "http://[1:2:3::8]", + "http://[1::4:5:6:7:8]", + "http://[1:2::4:5:6:7:8]", + "http://[1:2::8]", + "http://[1::3:4:5:6:7:8]", + "http://[1::3:4:5:6:7:8]", + "http://[1::8]", + "http://[::2:3:4:5:6:7:8]", + "http://[::2:3:4:5:6:7:8]", + "http://[::8]", + "http://[::]", + "http://[fe80::7:8%eth0]", + "http://[fe80::7:8%1]", + "http://[::255.255.255.255]", + "http://[::ffff:255.255.255.255]", + "http://[::ffff:0:255.255.255.255]", + "http://[2001:db8:3:4::192.0.2.33]", + "http://[64:ff9b::192.0.2.33]" +] + +edge_cases_gen = SetValuesGen(StringType(), edge_cases) + +url_gen = StringGen(url_pattern) + +@pytest.mark.parametrize('data_gen', [url_gen, edge_cases_gen], ids=idfn) +def test_parse_url_protocol(data_gen): + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, data_gen).selectExpr( + "a", + "parse_url(a, 'PROTOCOL')" + )) + +unsupported_parts = ['HOST', 'PATH', 'QUERY', 'REF', 'FILE', 'AUTHORITY', 'USERINFO'] + +@allow_non_gpu('ProjectExec', 'ParseUrl') +@pytest.mark.parametrize('part', unsupported_parts, ids=idfn) +def test_parse_url_host_fallback(part): + assert_gpu_fallback_collect( + lambda spark : unary_op_df(spark, url_gen).selectExpr( + "a", + "parse_url(a, '" + part + "')" + ), + 'ParseUrl') diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 74430ae8e90..6cb22f59885 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -3241,6 +3241,34 @@ object GpuOverrides extends Logging { ParamCheck("regexp", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), ParamCheck("idx", TypeSig.lit(TypeEnum.INT), TypeSig.INT))), (a, conf, p, r) => new GpuRegExpExtractAllMeta(a, conf, p, r)), + expr[ParseUrl]( + "Extracts a part from a URL", + ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, + Seq(ParamCheck("url", TypeSig.STRING, TypeSig.STRING), + ParamCheck("partToExtract", TypeSig.lit(TypeEnum.STRING).withPsNote( + TypeEnum.STRING, "only support partToExtract=PROTOCOL"), TypeSig.STRING)), + // Should really be an OptionalParam + Some(RepeatingParamCheck("key", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), + (a, conf, p, r) => new ExprMeta[ParseUrl](a, conf, p, r) { + override def tagExprForGpu(): Unit = { + if (a.failOnError) { + willNotWorkOnGpu("Fail on error is not supported on GPU when parsing urls.") + } + + extractStringLit(a.children(1)).map(_.toUpperCase) match { + case Some(GpuParseUrl.PROTOCOL) => + case Some(other) => + willNotWorkOnGpu(s"Part to extract $other is not supported on GPU") + case None => + // Should never get here, but just in case + willNotWorkOnGpu("GPU only supports a literal for the part to extract") + } + } + + override def convertToGpu(): GpuExpression = { + GpuParseUrl(childExprs.map(_.convertToGpu())) + } + }), expr[Length]( "String character length or binary byte length", ExprChecks.unaryProject(TypeSig.INT, TypeSig.INT, diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuParseUrl.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuParseUrl.scala new file mode 100644 index 00000000000..586814c38e7 --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuParseUrl.scala @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids + +import ai.rapids.cudf.ColumnVector +import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.Arm._ +import com.nvidia.spark.rapids.RapidsPluginImplicits._ +import com.nvidia.spark.rapids.jni.ParseURI +import com.nvidia.spark.rapids.shims.ShimExpression + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.unsafe.types.UTF8String + +object GpuParseUrl { + val HOST = "HOST" + val PATH = "PATH" + val QUERY = "QUERY" + val REF = "REF" + val PROTOCOL = "PROTOCOL" + val FILE = "FILE" + val AUTHORITY = "AUTHORITY" + val USERINFO = "USERINFO" +} + +case class GpuParseUrl(children: Seq[Expression]) + extends GpuExpression with ShimExpression with ExpectsInputTypes { + + override def nullable: Boolean = true + override def inputTypes: Seq[DataType] = Seq.fill(children.size)(StringType) + override def dataType: DataType = StringType + override def prettyName: String = "parse_url" + + import GpuParseUrl._ + + def doColumnar(url: GpuColumnVector, partToExtract: GpuScalar): ColumnVector = { + val part = partToExtract.getValue.asInstanceOf[UTF8String].toString + part match { + case PROTOCOL => + ParseURI.parseURIProtocol(url.getBase) + case HOST | PATH | QUERY | REF | FILE | AUTHORITY | USERINFO => + throw new UnsupportedOperationException(s"$this is not supported partToExtract=$part") + case _ => + throw new IllegalArgumentException(s"Invalid partToExtract: $partToExtract") + } + } + + def doColumnar(url: GpuColumnVector, partToExtract: GpuScalar, key: GpuScalar): ColumnVector = { + val part = partToExtract.getValue.asInstanceOf[UTF8String].toString + if (part != QUERY) { + // return a null columnvector + return ColumnVector.fromStrings(null, null) + } + throw new UnsupportedOperationException(s"$this only supports partToExtract = PROTOCOL") + } + + override def columnarEval(batch: ColumnarBatch): GpuColumnVector = { + if (children.size == 2) { + val Seq(url, partToExtract) = children + withResourceIfAllowed(url.columnarEval(batch)) { urls => + withResourceIfAllowed(partToExtract.columnarEvalAny(batch)) { parts => + parts match { + case partScalar: GpuScalar => + GpuColumnVector.from(doColumnar(urls, partScalar), dataType) + case _ => + throw new + UnsupportedOperationException(s"Cannot columnar evaluate expression: $this") + } + } + } + } else { + // 3-arg, i.e. QUERY with key + assert(children.size == 3) + val Seq(url, partToExtract, key) = children + withResourceIfAllowed(url.columnarEval(batch)) { urls => + withResourceIfAllowed(partToExtract.columnarEvalAny(batch)) { parts => + withResourceIfAllowed(key.columnarEvalAny(batch)) { keys => + (urls, parts, keys) match { + case (urlCv: GpuColumnVector, partScalar: GpuScalar, keyScalar: GpuScalar) => + GpuColumnVector.from(doColumnar(urlCv, partScalar, keyScalar), dataType) + case _ => + throw new + UnsupportedOperationException(s"Cannot columnar evaluate expression: $this") + } + } + } + } + } + } +} diff --git a/tools/generated_files/operatorsScore.csv b/tools/generated_files/operatorsScore.csv index 17c80f60bfb..4c7248bf975 100644 --- a/tools/generated_files/operatorsScore.csv +++ b/tools/generated_files/operatorsScore.csv @@ -181,6 +181,7 @@ Not,4 NthValue,4 OctetLength,4 Or,4 +ParseUrl,4 PercentRank,4 Percentile,4 PivotFirst,4 diff --git a/tools/generated_files/supportedExprs.csv b/tools/generated_files/supportedExprs.csv index bff8dc7359a..b55f893f40e 100644 --- a/tools/generated_files/supportedExprs.csv +++ b/tools/generated_files/supportedExprs.csv @@ -382,6 +382,10 @@ Or,S,`or`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA, Or,S,`or`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA Or,S,`or`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA Or,S,`or`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +ParseUrl,S,`parse_url`,None,project,url,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA +ParseUrl,S,`parse_url`,None,project,partToExtract,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA +ParseUrl,S,`parse_url`,None,project,key,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA +ParseUrl,S,`parse_url`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA PercentRank,S,`percent_rank`,None,window,ordering,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NS,NS,NS PercentRank,S,`percent_rank`,None,window,result,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA Pmod,S,`pmod`,None,project,lhs,NA,S,S,S,S,S,S,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA