diff --git a/integration_tests/src/main/python/url_test.py b/integration_tests/src/main/python/url_test.py
index 2b2106aa3fc..ba31126736a 100644
--- a/integration_tests/src/main/python/url_test.py
+++ b/integration_tests/src/main/python/url_test.py
@@ -29,6 +29,7 @@
     r'(:[0-9]{1,3}){0,1}(/[a-z]{1,3}){0,3}(\?key=[a-z]{1,3}){0,1}(#([a-z]{1,3})){0,1}'
 
 edge_cases = [
+    "userinfo@spark.apache.org/path?query=1#Ref",
     "http://foo.com/blah_blah",
     "http://foo.com/blah_blah/",
     "http://foo.com/blah_blah_(wikipedia)",
@@ -103,6 +104,7 @@
     "http://10.1.1.254",
     "http://userinfo@spark.apache.org/path?query=1#Ref",
     r"https://use%20r:pas%20s@example.com/dir%20/pa%20th.HTML?query=x%20y&q2=2#Ref%20two",
+    r"https://use%20r:pas%20s@example.com/dir%20/pa%20th.HTML?query=x%9Fy&q2=2#Ref%20two",
     "http://user:pass@host",
     "http://user:pass@host/",
     "http://user:pass@host/?#",
@@ -146,8 +148,10 @@
 
 url_gen = StringGen(url_pattern)
 
-supported_parts = ['PROTOCOL', 'HOST']
-unsupported_parts = ['PATH', 'QUERY', 'REF', 'FILE', 'AUTHORITY', 'USERINFO']
+supported_parts = ['PROTOCOL', 'HOST', 'QUERY']
+unsupported_parts = ['PATH', 'REF', 'FILE', 'AUTHORITY', 'USERINFO']
+supported_with_key_parts = ['PROTOCOL', 'HOST']
+unsupported_with_key_parts = ['QUERY', 'PATH', 'REF', 'FILE', 'AUTHORITY', 'USERINFO']
 
 @pytest.mark.parametrize('data_gen', [url_gen, edge_cases_gen], ids=idfn)
 @pytest.mark.parametrize('part', supported_parts, ids=idfn)
@@ -161,3 +165,17 @@ def test_parse_url_unsupported_fallback(part):
     assert_gpu_fallback_collect(
         lambda spark: unary_op_df(spark, url_gen).selectExpr("a", "parse_url(a, '" + part + "')"),
         'ParseUrl')
+
+@pytest.mark.parametrize('part', supported_with_key_parts, ids=idfn)
+def test_parse_url_with_key(part):
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark: unary_op_df(spark, url_gen).selectExpr("parse_url(a, '" + part + "', 'key')"))
+
+
+
+@allow_non_gpu('ProjectExec', 'ParseUrl')
+@pytest.mark.parametrize('part', unsupported_with_key_parts, ids=idfn)
+def test_parse_url_query_with_key_fallback(part):
+    assert_gpu_fallback_collect(
+        lambda spark: unary_op_df(spark, url_gen).selectExpr("parse_url(a, '" + part + "', 'key')"),
+        'ParseUrl')
\ No newline at end of file
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 5a52acbcad6..149fd1226b4 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -3263,6 +3263,8 @@ object GpuOverrides extends Logging {
          }
 
          extractStringLit(a.children(1)).map(_.toUpperCase) match {
+           case Some("QUERY") if (a.children.size == 3) =>
+             willNotWorkOnGpu("Part to extract QUERY with key is not supported on GPU")
            case Some(part) if GpuParseUrl.isSupportedPart(part) =>
            case Some(other) =>
              willNotWorkOnGpu(s"Part to extract $other is not supported on GPU")
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuParseUrl.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuParseUrl.scala
index 528f0ca3bdc..8b6769bf810 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuParseUrl.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuParseUrl.scala
@@ -40,7 +40,7 @@ object GpuParseUrl {
 
   def isSupportedPart(part: String): Boolean = {
     part match {
-      case PROTOCOL | HOST =>
+      case PROTOCOL | HOST | QUERY =>
         true
       case _ =>
         false
@@ -65,22 +65,24 @@ case class GpuParseUrl(children: Seq[Expression])
         ParseURI.parseURIProtocol(url.getBase)
       case HOST =>
         ParseURI.parseURIHost(url.getBase)
-      case PATH | QUERY | REF | FILE | AUTHORITY | USERINFO =>
+      case QUERY =>
+        ParseURI.parseURIQuery(url.getBase)
+      case PATH | REF | FILE | AUTHORITY | USERINFO =>
         throw new UnsupportedOperationException(s"$this is not supported partToExtract=$part. " +
-          s"Only PROTOCOL and HOST are supported")
+          s"Only PROTOCOL, HOST and QUERY without a key are supported")
       case _ =>
         throw new IllegalArgumentException(s"Invalid partToExtract: $partToExtract")
     }
   }
 
-  def doColumnar(url: GpuColumnVector, partToExtract: GpuScalar, key: GpuScalar): ColumnVector = {
+  def doColumnar(col: GpuColumnVector, partToExtract: GpuScalar, key: GpuScalar): ColumnVector = {
     val part = partToExtract.getValue.asInstanceOf[UTF8String].toString
     if (part != QUERY) {
       // return a null columnvector
-      return ColumnVector.fromStrings(null, null)
+      return GpuColumnVector.columnVectorFromNull(col.getRowCount.toInt, StringType)
     }
     throw new UnsupportedOperationException(s"$this is not supported partToExtract=$part. " +
-      s"Only PROTOCOL, HOST and QUERY without a key are supported")
+      s"Only PROTOCOL, HOST and QUERY without a key are supported")
  }
 
   override def columnarEval(batch: ColumnarBatch): GpuColumnVector = {
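
For context, a minimal PySpark sketch (not part of the patch) of the behavior the change targets: the two-argument QUERY form is newly supported on the GPU, while the three-argument form with a key falls back to the CPU. The SparkSession and DataFrame setup below are illustrative assumptions and presume the RAPIDS plugin is enabled.

# Hypothetical usage sketch; not part of the diff above.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [("http://userinfo@spark.apache.org/path?query=1#Ref",)], ["a"])

# Two-argument form: QUERY without a key, now eligible to run on the GPU (yields 'query=1').
df.selectExpr("parse_url(a, 'QUERY')").show(truncate=False)

# Three-argument form: QUERY with a key still falls back to the CPU (yields '1').
df.selectExpr("parse_url(a, 'QUERY', 'query')").show(truncate=False)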