diff --git a/src/main/cpp/src/parse_uri.cu b/src/main/cpp/src/parse_uri.cu index 54e79ab022..9c2a3bd9d5 100644 --- a/src/main/cpp/src/parse_uri.cu +++ b/src/main/cpp/src/parse_uri.cu @@ -39,15 +39,12 @@ namespace { // utility to validate a character is valid in a URI constexpr bool is_valid_character(char ch, bool alphanum_only) { - if (alphanum_only) { - if (ch >= '-' && ch <= '9' && ch != '/') return true; // 0-9 and .- - if (ch >= 'A' && ch <= 'Z') return true; // A-Z - if (ch >= 'a' && ch <= 'z') return true; // a-z - } else { - if (ch >= '!' && ch <= ';' && ch != '"') return true; // 0-9 and !#%&'()*+,-./ - if (ch >= '=' && ch <= 'Z' && ch != '>') return true; // A-Z and =?@ - if (ch >= '_' && ch <= 'z' && ch != '`') return true; // a-z and _ - } + return alphanum_only ? (ch >= '-' && ch <= '9' && ch != '/') || // 0-9 and .- + (ch >= 'A' && ch <= 'Z') || // A-Z + (ch >= 'a' && ch <= 'z') // a-z + : (ch >= '!' && ch <= ':' && ch != '"') || // 0-9 and !#%&'()*+,-./: + (ch >= '=' && ch <= ']' && ch != '>') || // A-Z and =?@[] + (ch >= '_' && ch <= 'z' && ch != '`'); // a-z and _ return false; } diff --git a/src/main/cpp/tests/parse_uri.cpp b/src/main/cpp/tests/parse_uri.cpp index 3ff14a6075..a182e5d426 100644 --- a/src/main/cpp/tests/parse_uri.cpp +++ b/src/main/cpp/tests/parse_uri.cpp @@ -91,5 +91,52 @@ TEST_F(ParseURIProtocolTests, SparkEdges) "https"}, {1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result->view(), expected); +} + +TEST_F(ParseURIProtocolTests, IP6) +{ + cudf::test::strings_column_wrapper col({ + "https://[fe80::]", + "https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]", + "https://[2001:0DB8:85A3:0000:0000:8A2E:0370:7334]", + "https://[2001:db8::1:0]", + "http://[2001:db8::2:1]", + "https://[::1]", + "https://[2001:db8:85a3:8d3:1319:8a2e:370:7348]:443", + }); + auto result = spark_rapids_jni::parse_uri_to_protocol(cudf::strings_column_view{col}); + + cudf::test::strings_column_wrapper expected({"https", "https", "https", "https", "http", "https", "https"}); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result->view(), expected); +} + +TEST_F(ParseURIProtocolTests, IP4) +{ + cudf::test::strings_column_wrapper col({ + "https://192.168.1.100/", + "https://192.168.1.100:8443/", + "https://192.168.1.100.5/", + "https://192.168.1/", + "https://280.100.1.1/", + }); + auto result = spark_rapids_jni::parse_uri_to_protocol(cudf::strings_column_view{col}); + + cudf::test::strings_column_wrapper expected({"https", "https", "https", "https", "https"}); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result->view(), expected); +} + +TEST_F(ParseURIProtocolTests, UTF8) +{ + cudf::test::strings_column_wrapper col({ + "https://nvidia.com/%4EV%49%44%49%41", + "http://%77%77%77.%4EV%49%44%49%41.com", + }); + auto result = spark_rapids_jni::parse_uri_to_protocol(cudf::strings_column_view{col}); + + cudf::test::strings_column_wrapper expected({"https", "http"}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result->view(), expected); } \ No newline at end of file diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java b/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java index 7289d110b2..3f962ffde1 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java @@ -25,8 +25,29 @@ import ai.rapids.cudf.ColumnVector; public class ParseURITest { + void buildExpectedAndRun(String[] testData) { + String[] expectedStrings = new String[testData.length]; + for (int i=0; i