Skip to content

Commit

Permalink
feat(function): Add Spark locate function (#8863)
Browse files Browse the repository at this point in the history
Summary:
A function that returns the position of the first occurrence of substring in
given string after the start position.

Doc: https://spark.apache.org/docs/latest/api/sql/index.html#locate
Spark implementation: https://github.com/apache/spark/blob/v3.5.1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala#L1420

Pull Request resolved: #8863

Differential Revision: D66203871

Pulled By: kagamiori

fbshipit-source-id: cf117699a795a19786bc1df546a1578daa9757b9
  • Loading branch information
rui-mo authored and facebook-github-bot committed Nov 20, 2024
1 parent 473902a commit c286451
Show file tree
Hide file tree
Showing 7 changed files with 213 additions and 52 deletions.
29 changes: 29 additions & 0 deletions velox/docs/functions/spark/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,35 @@ String Functions
SELECT levenshtein('kitten', 'sitting', 10); -- 3
SELECT levenshtein('kitten', 'sitting', 2); -- -1

.. spark:function:: locate(substring, string, start) -> integer
Returns the 1-based position of the first occurrence of ``substring`` in given ``string``
after position ``start``. The search is from the beginning of ``string`` to the end.
``start`` is the starting character position in ``string`` to search for the ``substring``.
``start`` is 1-based and must be at least 1 and at most the characters number of ``string``.
The following rules on special values are applied to follow Spark's implementation.
They are listed in order of priority:

Returns 0 if ``start`` is NULL. Returns NULL if ``substring`` or ``string`` is NULL.
Returns 0 if ``start`` is less than 1.
Returns 1 if ``substring`` is empty.
Returns 0 if ``start`` is greater than the characters number of ``string``.
Returns 0 if ``substring`` is not found in ``string``. ::

SELECT locate('aa', 'aaads', 1); -- 1
SELECT locate('aa', 'aaads', -1); -- 0
SELECT locate('aa', 'aaads', 2); -- 2
SELECT locate('aa', 'aaads', 6); -- 0
SELECT locate('aa', 'aaads', NULL); -- 0
SELECT locate('', 'aaads', 1); -- 1
SELECT locate('', 'aaads', 9); -- 1
SELECT locate('', 'aaads', -1); -- 0
SELECT locate('', '', 1); -- 1
SELECT locate('aa', '', 1); -- 0
SELECT locate(NULL, NULL, NULL); -- 0
SELECT locate(NULL, NULL, 1); -- NULL
SELECT locate('\u4FE1', '\u4FE1\u5FF5,\u4FE1\u7231,\u4FE1\u5E0C\u671B', 2); -- 4

.. spark:function:: lower(string) -> string
Returns string with all characters changed to lowercase. ::
Expand Down
28 changes: 13 additions & 15 deletions velox/functions/lib/string/StringImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,29 +197,27 @@ std::vector<int32_t> stringToCodePoints(const T& inputString) {
return codePoints;
}

/// Returns the starting position in characters of the Nth instance(counting
/// from the left if lpos==true and from the end otherwise) of the substring in
/// string. Positions start with 1. If not found, 0 is returned. If subString is
/// empty result is 1.
template <bool isAscii, bool lpos = true, typename T>
FOLLY_ALWAYS_INLINE int64_t
stringPosition(const T& string, const T& subString, int64_t instance = 0) {
/// Returns the starting position in characters of the Nth instance of the
/// substring in string. Positions start with 1. If not found, 0 is returned. If
/// subString is empty result is 1.
/// @tparam lpos If true, counting from the start of the string. Counting from
/// the end of the string otherwise.
/// @param instance The 1-based instance of the substring to find in string.
template <bool isAscii, bool lpos = true>
FOLLY_ALWAYS_INLINE int64_t stringPosition(
std::string_view string,
std::string_view subString,
int64_t instance) {
VELOX_USER_CHECK_GT(instance, 0, "'instance' must be a positive number");
if (subString.size() == 0) {
return 1;
}

int64_t byteIndex = -1;
if constexpr (lpos) {
byteIndex = findNthInstanceByteIndexFromStart(
std::string_view(string.data(), string.size()),
std::string_view(subString.data(), subString.size()),
instance);
byteIndex = findNthInstanceByteIndexFromStart(string, subString, instance);
} else {
byteIndex = findNthInstanceByteIndexFromEnd(
std::string_view(string.data(), string.size()),
std::string_view(subString.data(), subString.size()),
instance);
byteIndex = findNthInstanceByteIndexFromEnd(string, subString, instance);
}

if (byteIndex == -1) {
Expand Down
56 changes: 25 additions & 31 deletions velox/functions/lib/string/tests/StringImplTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -396,38 +396,38 @@ TEST_F(StringImplTest, stringToCodePoints) {
}

TEST_F(StringImplTest, overlappedStringPosition) {
auto testValidInputAsciiLpos = [](const std::string& string,
const std::string& substr,
auto testValidInputAsciiLpos = [](std::string_view string,
std::string_view substr,
const int64_t instance,
const int64_t expectedPosition) {
auto result = stringPosition</*isAscii*/ true, true>(
StringView(string), StringView(substr), instance);
auto result =
stringPosition</*isAscii*/ true, true>(string, substr, instance);
ASSERT_EQ(result, expectedPosition);
};
auto testValidInputAsciiRpos = [](const std::string& string,
const std::string& substr,
auto testValidInputAsciiRpos = [](std::string_view string,
std::string_view substr,
const int64_t instance,
const int64_t expectedPosition) {
auto result = stringPosition</*isAscii*/ true, false>(
StringView(string), StringView(substr), instance);
auto result =
stringPosition</*isAscii*/ true, false>(string, substr, instance);
ASSERT_EQ(result, expectedPosition);
};

auto testValidInputUnicodeLpos = [](const std::string& string,
const std::string& substr,
auto testValidInputUnicodeLpos = [](std::string_view string,
std::string_view substr,
const int64_t instance,
const int64_t expectedPosition) {
auto result = stringPosition</*isAscii*/ false, true>(
StringView(string), StringView(substr), instance);
auto result =
stringPosition</*isAscii*/ false, true>(string, substr, instance);
ASSERT_EQ(result, expectedPosition);
};

auto testValidInputUnicodeRpos = [](const std::string& string,
const std::string& substr,
auto testValidInputUnicodeRpos = [](std::string_view string,
std::string_view substr,
const int64_t instance,
const int64_t expectedPosition) {
auto result = stringPosition</*isAscii*/ false, false>(
StringView(string), StringView(substr), instance);
auto result =
stringPosition</*isAscii*/ false, false>(string, substr, instance);
ASSERT_EQ(result, expectedPosition);
};

Expand All @@ -445,31 +445,27 @@ TEST_F(StringImplTest, overlappedStringPosition) {
}

TEST_F(StringImplTest, stringPosition) {
auto testValidInputAscii = [](const std::string& string,
const std::string& substr,
auto testValidInputAscii = [](std::string_view string,
std::string_view substr,
const int64_t instance,
const int64_t expectedPosition) {
ASSERT_EQ(
stringPosition</*isAscii*/ true>(
StringView(string), StringView(substr), instance),
stringPosition</*isAscii*/ true>(string, substr, instance),
expectedPosition);
ASSERT_EQ(
stringPosition</*isAscii*/ false>(
StringView(string), StringView(substr), instance),
stringPosition</*isAscii*/ false>(string, substr, instance),
expectedPosition);
};

auto testValidInputUnicode = [](const std::string& string,
const std::string& substr,
auto testValidInputUnicode = [](std::string_view string,
std::string_view substr,
const int64_t instance,
const int64_t expectedPosition) {
ASSERT_EQ(
stringPosition</*isAscii*/ false>(
StringView(string), StringView(substr), instance),
stringPosition</*isAscii*/ false>(string, substr, instance),
expectedPosition);
ASSERT_EQ(
stringPosition</*isAscii*/ false>(
StringView(string), StringView(substr), instance),
stringPosition</*isAscii*/ false>(string, substr, instance),
expectedPosition);
};

Expand All @@ -494,9 +490,7 @@ TEST_F(StringImplTest, stringPosition) {
testValidInputUnicode("abc/xyz/foo/bar", "/", 4, 0L);

EXPECT_THROW(
stringPosition</*isAscii*/ false>(
StringView("foobar"), StringView("foobar"), 0),
VeloxUserError);
stringPosition</*isAscii*/ false>("foobar", "foobar", 0), VeloxUserError);
}

TEST_F(StringImplTest, replaceFirst) {
Expand Down
8 changes: 6 additions & 2 deletions velox/functions/prestosql/StringFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,9 @@ struct StrPosFunctionBase {
const arg_type<Varchar>& subString,
const arg_type<int64_t>& instance = 1) {
result = stringImpl::stringPosition<false /*isAscii*/, lpos>(
string, subString, instance);
std::string_view(string.data(), string.size()),
std::string_view(subString.data(), subString.size()),
instance);
}

FOLLY_ALWAYS_INLINE void callAscii(
Expand All @@ -421,7 +423,9 @@ struct StrPosFunctionBase {
const arg_type<Varchar>& subString,
const arg_type<int64_t>& instance = 1) {
result = stringImpl::stringPosition<true /*isAscii*/, lpos>(
string, subString, instance);
std::string_view(string.data(), string.size()),
std::string_view(subString.data(), subString.size()),
instance);
}
};

Expand Down
4 changes: 3 additions & 1 deletion velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,13 +303,15 @@ void registerFunctions(const std::string& prefix) {
registerCompareFunctions(prefix);
registerBitwiseFunctions(prefix);

// String sreach function
// String search function
registerFunction<StartsWithFunction, bool, Varchar, Varchar>(
{prefix + "startswith"});
registerFunction<EndsWithFunction, bool, Varchar, Varchar>(
{prefix + "endswith"});
registerFunction<ContainsFunction, bool, Varchar, Varchar>(
{prefix + "contains"});
registerFunction<LocateFunction, int32_t, Varchar, Varchar, int32_t>(
{prefix + "locate"});

registerFunction<TrimSpaceFunction, Varchar, Varchar>({prefix + "trim"});
registerFunction<TrimFunction, Varchar, Varchar, Varchar>({prefix + "trim"});
Expand Down
101 changes: 98 additions & 3 deletions velox/functions/sparksql/String.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,6 @@ struct StartsWithFunction {
result = false;
} else {
result = str1.substr(0, str2.length()) == str2;
;
}
return true;
}
Expand Down Expand Up @@ -293,6 +292,96 @@ struct EndsWithFunction {
}
};

/// locate function
/// locate(substring, string, start) -> integer
///
/// Returns the 1-based position of the first occurrence of 'substring' in
/// 'string' after the give 'start' position. The search is from the beginning
/// of 'string' to the end. 'start' is the starting character position in
/// 'string' to search for the 'substring'. 'start' is 1-based and must be at
/// least 1 and at most the characters number of 'string'.
///
/// The following rules on special values are applied to follow Spark's
/// implementation. They are listed in order of priority:
/// Returns 0 if 'start' is NULL. Returns NULL if 'substring' or 'string' is
/// NULL. Returns 0 if 'start' is less than 1. Returns 1 if 'substring' is
/// empty. Returns 0 if 'start' is greater than the characters number of
/// 'string'. Returns 0 if 'substring' is not found in 'string'.
template <typename T>
struct LocateFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE void callAscii(
out_type<int32_t>& result,
const arg_type<Varchar>& subString,
const arg_type<Varchar>& string,
const arg_type<int32_t>& start) {
if (start < 1) {
result = 0;
} else if (subString.empty()) {
result = 1;
} else if (start > string.size()) {
result = 0;
} else {
const auto position = stringImpl::stringPosition<true /*isAscii*/>(
std::string_view(
string.data() + start - 1, string.size() - start + 1),
std::string_view(subString.data(), subString.size()),
1 /*instance*/);
if (position) {
result = position + start - 1;
} else {
result = 0;
}
}
}

FOLLY_ALWAYS_INLINE bool callNullable(
out_type<int32_t>& result,
const arg_type<Varchar>* subString,
const arg_type<Varchar>* string,
const arg_type<int32_t>* start) {
if (start == nullptr) {
result = 0;
return true;
}
if (subString == nullptr || string == nullptr) {
return false;
}
if (*start < 1) {
result = 0;
return true;
}
if (subString->empty()) {
result = 1;
return true;
}
if (*start > stringImpl::length<false /*isAscii*/>(*string)) {
result = 0;
return true;
}

// Find the start byte index of the start character. For example, in the
// Unicode string "😋😋😋", each character occupies 4 bytes. When 'start' is
// 2, the 'startByteIndex' is 4 which specifies the start of the second
// character.
const auto startByteIndex = stringCore::cappedByteLengthUnicode(
string->data(), string->size(), *start - 1);

const auto position = stringImpl::stringPosition<false /*isAscii*/>(
std::string_view(
string->data() + startByteIndex, string->size() - startByteIndex),
std::string_view(subString->data(), subString->size()),
1 /*instance*/);
if (position) {
result = position + *start - 1;
} else {
result = 0;
}
return true;
}
};

/// Returns the substring from str before count occurrences of the delimiter
/// delim. If count is positive, everything to the left of the final delimiter
/// (counting from the left) is returned. If count is negative, everything to
Expand Down Expand Up @@ -321,9 +410,15 @@ struct SubstringIndexFunction {

int64_t index;
if (count > 0) {
index = stringImpl::stringPosition<true, true>(str, delim, count);
index = stringImpl::stringPosition<true, true>(
std::string_view(str.data(), str.size()),
std::string_view(delim.data(), delim.size()),
count);
} else {
index = stringImpl::stringPosition<true, false>(str, delim, -count);
index = stringImpl::stringPosition<true, false>(
std::string_view(str.data(), str.size()),
std::string_view(delim.data(), delim.size()),
-count);
}

// If 'delim' is not found or found fewer than 'count' times,
Expand Down
39 changes: 39 additions & 0 deletions velox/functions/sparksql/tests/StringTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,45 @@ TEST_F(StringTest, substring) {
EXPECT_EQ(substringWithLength("da\u6570\u636Eta", -3, 2), "\u636Et");
}

TEST_F(StringTest, locate) {
const auto locate = [&](const std::optional<std::string>& substr,
const std::optional<std::string>& str,
const std::optional<int32_t>& start) {
return evaluateOnce<int32_t>("locate(c0, c1, c2)", substr, str, start);
};

EXPECT_EQ(locate("aa", "aaads", 1), 1);
EXPECT_EQ(locate("aa", "aaads", 0), 0);
EXPECT_EQ(locate("aa", "aaads", 2), 2);
EXPECT_EQ(locate("aa", "aaads", 3), 0);
EXPECT_EQ(locate("aa", "aaads", -3), 0);
EXPECT_EQ(locate("de", "aaads", 1), 0);
EXPECT_EQ(locate("de", "aaads", 2), 0);
EXPECT_EQ(locate("abc", "abcdddabcabc", 6), 7);
EXPECT_EQ(locate("", "", 1), 1);
EXPECT_EQ(locate("", "", 3), 1);
EXPECT_EQ(locate("", "", -1), 0);
EXPECT_EQ(locate("", "aaads", 1), 1);
EXPECT_EQ(locate("", "aaads", 9), 1);
EXPECT_EQ(locate("", "aaads", -1), 0);
EXPECT_EQ(locate("aa", "", 1), 0);
EXPECT_EQ(locate("aa", "", 2), 0);
EXPECT_EQ(locate("zz", "aaads", std::nullopt), 0);
EXPECT_EQ(locate("aa", std::nullopt, 1), std::nullopt);
EXPECT_EQ(locate(std::nullopt, "aaads", 1), std::nullopt);
EXPECT_EQ(locate(std::nullopt, std::nullopt, -1), std::nullopt);
EXPECT_EQ(locate(std::nullopt, std::nullopt, std::nullopt), 0);

EXPECT_EQ(locate("", "\u4FE1\u5FF5,\u7231,\u5E0C\u671B", 10), 1);
EXPECT_EQ(locate("", "\u4FE1\u5FF5,\u7231,\u5E0C\u671B", -1), 0);
EXPECT_EQ(locate("\u7231", "\u4FE1\u5FF5,\u7231,\u5E0C\u671B", 1), 4);
EXPECT_EQ(locate("\u7231", "\u4FE1\u5FF5,\u7231,\u5E0C\u671B", 0), 0);
EXPECT_EQ(
locate("\u4FE1", "\u4FE1\u5FF5,\u4FE1\u7231,\u4FE1\u5E0C\u671B", 2), 4);
EXPECT_EQ(
locate("\u4FE1", "\u4FE1\u5FF5,\u4FE1\u7231,\u4FE1\u5E0C\u671B", 8), 0);
}

TEST_F(StringTest, substringIndex) {
const auto substringIndex =
[&](const std::string& str, const std::string& delim, int32_t count) {
Expand Down

0 comments on commit c286451

Please sign in to comment.