Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize basic_string::find #5101

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions stl/inc/__msvc_string_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -660,12 +660,29 @@ template <class _Traits>
constexpr size_t _Traits_find_ch(_In_reads_(_Hay_size) const _Traits_ptr_t<_Traits> _Haystack, const size_t _Hay_size,
const size_t _Start_at, const _Traits_ch_t<_Traits> _Ch) noexcept {
// search [_Haystack, _Haystack + _Hay_size) for _Ch, at/after _Start_at
if (_Start_at < _Hay_size) {
const auto _Found_at = _Traits::find(_Haystack + _Start_at, _Hay_size - _Start_at, _Ch);
if (_Found_at) {
return static_cast<size_t>(_Found_at - _Haystack);
if (_Start_at >= _Hay_size) {
return static_cast<size_t>(-1); // (npos) no room for match
}

#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Is_implementation_handled_char_traits<_Traits>) {
if (!_STD _Is_constant_evaluated()) {
const auto _End = _Haystack + _Hay_size;
const auto _Ptr = _STD _Find_vectorized(_Haystack + _Start_at, _End, _Ch);

if (_Ptr != _End) {
return static_cast<size_t>(_Ptr - _Haystack);
} else {
return static_cast<size_t>(-1); // (npos) no match
}
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS

const auto _Found_at = _Traits::find(_Haystack + _Start_at, _Hay_size - _Start_at, _Ch);
if (_Found_at) {
return static_cast<size_t>(_Found_at - _Haystack);
}

return static_cast<size_t>(-1); // (npos) no match
}
Expand Down
53 changes: 36 additions & 17 deletions tests/std/tests/VSO_0000000_vector_algorithms/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,38 @@ void test_case_string_find_last_of(const basic_string<T>& input_haystack, const
assert(expected == actual);
}

template <class T>
void test_case_string_find_ch(const basic_string<T>& input_haystack, const T value) {
ptrdiff_t expected;

const auto expected_iter = last_known_good_find(input_haystack.begin(), input_haystack.end(), value);

if (expected_iter != input_haystack.end()) {
expected = expected_iter - input_haystack.begin();
} else {
expected = -1;
}

const auto actual = static_cast<ptrdiff_t>(input_haystack.find(value));
assert(expected == actual);
}

template <class T>
void test_case_string_rfind_ch(const basic_string<T>& input_haystack, const T value) {
ptrdiff_t expected;

const auto expected_iter = last_known_good_find_last(input_haystack.begin(), input_haystack.end(), value);

if (expected_iter != input_haystack.end()) {
expected = expected_iter - input_haystack.begin();
} else {
expected = -1;
}

const auto actual = static_cast<ptrdiff_t>(input_haystack.rfind(value));
assert(expected == actual);
}

template <class T>
void test_case_string_find_str(const basic_string<T>& input_haystack, const basic_string<T>& input_needle) {
ptrdiff_t expected;
Expand Down Expand Up @@ -1128,22 +1160,6 @@ void test_case_string_rfind_str(const basic_string<T>& input_haystack, const bas
assert(expected == actual);
}

template <class T>
void test_case_string_rfind_ch(const basic_string<T>& input_haystack, const T value) {
ptrdiff_t expected;

const auto expected_iter = last_known_good_find_last(input_haystack.begin(), input_haystack.end(), value);

if (expected_iter != input_haystack.end()) {
expected = expected_iter - input_haystack.begin();
} else {
expected = -1;
}

const auto actual = static_cast<ptrdiff_t>(input_haystack.rfind(value));
assert(expected == actual);
}

template <class T, class D>
void test_basic_string_dis(mt19937_64& gen, D& dis) {
basic_string<T> input_haystack;
Expand All @@ -1154,13 +1170,16 @@ void test_basic_string_dis(mt19937_64& gen, D& dis) {
temp.reserve(needleDataCount);

for (;;) {
const auto input_element = static_cast<T>(dis(gen));
test_case_string_find_ch(input_haystack, input_element);
test_case_string_rfind_ch(input_haystack, input_element);

input_needle.clear();

test_case_string_find_first_of(input_haystack, input_needle);
test_case_string_find_last_of(input_haystack, input_needle);
test_case_string_find_str(input_haystack, input_needle);
test_case_string_rfind_str(input_haystack, input_needle);
test_case_string_rfind_ch(input_haystack, static_cast<T>(dis(gen)));

for (size_t attempts = 0; attempts < needleDataCount; ++attempts) {
input_needle.push_back(static_cast<T>(dis(gen)));
Expand Down