diff --git a/cpp/benchmarks/string/find.cpp b/cpp/benchmarks/string/find.cpp index a9c620e4bf0..f07e82286f2 100644 --- a/cpp/benchmarks/string/find.cpp +++ b/cpp/benchmarks/string/find.cpp @@ -73,6 +73,27 @@ static void bench_find_string(nvbench::state& state) } else if (api == "contains") { state.exec(nvbench::exec_tag::sync, [&](nvbench::launch& launch) { cudf::strings::contains(input, target); }); + } else if (api == "multi-contains") { + constexpr int iters = 10; + std::vector match_targets({" abc", + "W43", + "0987 5W43", + "123 abc", + "23 abc", + "3 abc", + "é", + "7 5W43", + "87 5W43", + "987 5W43"}); + auto multi_targets = std::vector{}; + for (int i = 0; i < iters; i++) { + multi_targets.emplace_back(match_targets[i % match_targets.size()]); + } + state.exec(nvbench::exec_tag::sync, [&](nvbench::launch& launch) { + cudf::test::strings_column_wrapper multi_targets_column(multi_targets.begin(), + multi_targets.end()); + cudf::strings::multi_contains(input, cudf::strings_column_view(multi_targets_column)); + }); } else if (api == "starts_with") { state.exec(nvbench::exec_tag::sync, [&](nvbench::launch& launch) { cudf::strings::starts_with(input, target); }); @@ -84,7 +105,8 @@ static void bench_find_string(nvbench::state& state) NVBENCH_BENCH(bench_find_string) .set_name("find_string") - .add_string_axis("api", {"find", "find_multi", "contains", "starts_with", "ends_with"}) + .add_string_axis("api", + {"find", "find_multi", "contains", "starts_with", "ends_with", "multi-contains"}) .add_int64_axis("row_width", {32, 64, 128, 256, 512, 1024}) .add_int64_axis("num_rows", {260'000, 1'953'000, 16'777'216}) .add_int64_axis("hit_rate", {20, 80}); // percentage diff --git a/cpp/include/cudf/strings/find.hpp b/cpp/include/cudf/strings/find.hpp index e024b116a71..9b57ed7a2f3 100644 --- a/cpp/include/cudf/strings/find.hpp +++ b/cpp/include/cudf/strings/find.hpp @@ -138,6 +138,39 @@ std::unique_ptr contains( rmm::cuda_stream_view stream = cudf::get_default_stream(), rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref()); +/** + * @brief Returns a table of columns of boolean values for each string where true indicates + * the target string was found within that string in the provided column. + * + * Each column in the result table corresponds to the result for the target string at the same + * ordinal. i.e. 0th column is the boolean-column result for the 0th target string, 1th for 1th, + * etc. + * + * If the target is not found for a string, false is returned for that entry in the output column. + * If the target is an empty string, true is returned for all non-null entries in the output column. + * + * Any null string entries return corresponding null entries in the output columns. + * e.g.: + * @code + * input: "a", "b", "c" + * targets: "a", "c" + * output is a table with two boolean columns: + * column_0: true, false, false + * column_1: false, false, true + * @endcode + * + * @param input Strings instance for this operation + * @param targets UTF-8 encoded strings to search for in each string in `input` + * @param stream CUDA stream used for device memory operations and kernel launches + * @param mr Device memory resource used to allocate the returned column's device memory + * @return New BOOL8 column + */ +std::unique_ptr multi_contains( + strings_column_view const& input, + strings_column_view const& targets, + rmm::cuda_stream_view stream = cudf::get_default_stream(), + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + /** * @brief Returns a column of boolean values for each string where true indicates * the corresponding target string was found within that string in the provided column. diff --git a/cpp/src/strings/search/find.cu b/cpp/src/strings/search/find.cu index 9bd1abb5542..6ca45dc2d53 100644 --- a/cpp/src/strings/search/find.cu +++ b/cpp/src/strings/search/find.cu @@ -13,13 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #include #include +#include +#include #include #include #include #include +#include +#include #include #include #include @@ -35,11 +38,16 @@ #include #include #include +#include #include +#include #include #include +#include #include +#include // For std::min + namespace cudf { namespace strings { namespace detail { @@ -414,6 +422,297 @@ std::unique_ptr contains_warp_parallel(strings_column_view const& input, return results; } +/** + * Each string uses a warp(32 threads) to handle all the targets. + * Each thread uses num_targets bools shared memory to store temp result for each lane. + */ +CUDF_KERNEL void multi_contains_warp_parallel_multi_scalars_fn( + column_device_view const d_strings, + column_device_view const d_targets, + cudf::device_span const d_target_first_bytes, + column_device_view const d_target_indexes_for_first_bytes, + cudf::device_span d_results) +{ + auto const num_targets = d_targets.size(); + auto const num_rows = d_strings.size(); + + auto const idx = cudf::detail::grid_1d::global_thread_id(); + auto const str_idx = idx / cudf::detail::warp_size; + if (str_idx >= num_rows) { return; } + + auto const lane_idx = idx % cudf::detail::warp_size; + if (d_strings.is_null(str_idx)) { return; } // bitmask will set result to null. + // get the string for this warp + auto const d_str = d_strings.element(str_idx); + + /** + * size of shared_bools = targets_size * block_size + * each thread uses targets_size bools + */ + extern __shared__ bool shared_bools[]; + + // initialize temp result: + // set true if target is empty, set false otherwise + for (int target_idx = 0; target_idx < num_targets; target_idx++) { + auto const d_target = d_targets.element(target_idx); + shared_bools[threadIdx.x * num_targets + target_idx] = d_target.size_bytes() == 0; + } + + for (size_type str_byte_idx = lane_idx; str_byte_idx < d_str.size_bytes(); + str_byte_idx += cudf::detail::warp_size) { + // 1. check the first chars using binary search on first char set + char c = *(d_str.data() + str_byte_idx); + auto first_byte_ptr = + thrust::lower_bound(thrust::seq, d_target_first_bytes.begin(), d_target_first_bytes.end(), c); + if (not(first_byte_ptr != d_target_first_bytes.end() && *first_byte_ptr == c)) { + // first char is not matched for all targets + // Note: first bytes does not work for empty target. + // For empty target, already set result as found + continue; + } + + // 2. check the 2nd chars + int first_char_index_in_list = first_byte_ptr - d_target_first_bytes.begin(); + // get possible targets + auto const possible_targets_list = + cudf::list_device_view{d_target_indexes_for_first_bytes, first_char_index_in_list}; + for (auto list_idx = 0; list_idx < possible_targets_list.size(); + ++list_idx) { // iterate possible targets + auto target_idx = possible_targets_list.element(list_idx); + int temp_result_idx = threadIdx.x * num_targets + target_idx; + if (!shared_bools[temp_result_idx]) { // not found before + auto const d_target = d_targets.element(target_idx); + if (d_str.size_bytes() - str_byte_idx >= d_target.size_bytes()) { + // first char already checked, only need to check the [2nd, end) chars if has. + bool found = true; + for (auto i = 1; i < d_target.size_bytes(); i++) { + if (*(d_str.data() + str_byte_idx + i) != *(d_target.data() + i)) { + found = false; + break; + } + } + if (found) { shared_bools[temp_result_idx] = true; } + } + } + } + } + + // wait all lanes are done in a warp + __syncwarp(); + + if (lane_idx == 0) { + for (int target_idx = 0; target_idx < num_targets; target_idx++) { + bool found = false; + for (int lane_idx = 0; lane_idx < cudf::detail::warp_size; lane_idx++) { + int temp_idx = (threadIdx.x + lane_idx) * num_targets + target_idx; + if (shared_bools[temp_idx]) { + found = true; + break; + } + } + d_results[target_idx][str_idx] = found; + } + } +} + +CUDF_KERNEL void multi_contains_using_indexes_fn( + column_device_view const d_strings, + column_device_view const d_targets, + cudf::device_span const d_target_first_bytes, + column_device_view const d_target_indexes_for_first_bytes, + cudf::device_span d_results) +{ + auto const str_idx = static_cast(cudf::detail::grid_1d::global_thread_id()); + auto const num_targets = d_targets.size(); + auto const num_rows = d_strings.size(); + if (str_idx >= num_rows) { return; } + if (d_strings.is_null(str_idx)) { return; } // bitmask will set result to null. + auto const d_str = d_strings.element(str_idx); + + // initialize temp result: + // set true if target is empty, set false otherwise + for (auto target_idx = 0; target_idx < num_targets; ++target_idx) { + auto const d_target = d_targets.element(target_idx); + d_results[target_idx][str_idx] = d_target.size_bytes() == 0; + } + + for (auto str_byte_idx = 0; str_byte_idx < d_str.size_bytes(); + ++str_byte_idx) { // iterate the start index in the string + + // 1. check the first chars using binary search on first char set + char c = *(d_str.data() + str_byte_idx); + auto first_byte_ptr = + thrust::lower_bound(thrust::seq, d_target_first_bytes.begin(), d_target_first_bytes.end(), c); + + if (not(first_byte_ptr != d_target_first_bytes.end() && *first_byte_ptr == c)) { + // first char is not matched for all targets + // Note: first bytes does not work for empty target. + // For empty target, already set result as found + continue; + } + + int first_char_index_in_list = first_byte_ptr - d_target_first_bytes.begin(); + // get possible targets + auto const possible_targets_list = + cudf::list_device_view{d_target_indexes_for_first_bytes, first_char_index_in_list}; + + for (auto list_idx = 0; list_idx < possible_targets_list.size(); + ++list_idx) { // iterate possible targets + auto target_idx = possible_targets_list.element(list_idx); + if (!d_results[target_idx][str_idx]) { // not found before + auto const d_target = d_targets.element(target_idx); + if (d_str.size_bytes() - str_byte_idx >= d_target.size_bytes()) { + // first char already checked, only need to check the [2nd, end) chars if has. + bool found = true; + for (auto i = 1; i < d_target.size_bytes(); i++) { + if (*(d_str.data() + str_byte_idx + i) != *(d_target.data() + i)) { + found = false; + break; + } + } + if (found) { d_results[target_idx][str_idx] = true; } + } + } + } + } +} + +/** + * Execute multi contains. + * First index the first char for all targets. + * Index the first char: + * collect first char for all targets and do uniq and sort, + * then index the targets for the first char. + * e.g.: + * targets: xa xb ac ad af + * first char set is: (a, x) + * index result is: + * { + * a: [2, 3, 4], // indexes for: ac ad af + * x: [0, 1] // indexes for: xa xb + * } + * when do searching: + * find (binary search) from `first char set` for a char in string: + * if char in string is not in ['a', 'x'], fast skip + * if char in string is 'x', then only need to try ["xa", "xb"] targets. + * if char in string is 'a', then only need to try ["ac", "ad", "af"] targets. + * + */ +std::vector> multi_contains(bool warp_parallel, + strings_column_view const& input, + strings_column_view const& targets, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + auto const num_targets = static_cast(targets.size()); + CUDF_EXPECTS(not targets.is_empty(), "Must specify at least one target string."); + + // 1. copy targets from device to host + auto const h_targets_child = cudf::detail::make_std_vector_sync( + cudf::device_span(targets.chars_begin(stream), targets.chars_size(stream)), stream); + + // Note: targets may be sliced, so should find the correct first offset + auto first_offset = targets.offset(); + auto const targets_offsets = targets.offsets(); + auto const h_targets_offsets = cudf::detail::make_std_vector_sync( + cudf::device_span{targets_offsets.data() + first_offset, + static_cast(targets.size() + 1)}, + stream); + + // 2. index the first characters for all targets + std::map> indexes; + for (auto i = 0; i < targets.size(); i++) { + auto target_begin_offset = h_targets_offsets[i]; + auto target_end_offset = h_targets_offsets[i + 1]; + if (target_end_offset - target_begin_offset > 0) { + char first_char = h_targets_child[target_begin_offset]; + auto not_exist = indexes.find(first_char) == indexes.end(); + if (not_exist) { indexes[first_char] = std::vector(); } + indexes[first_char].push_back(i); + } + } + thrust::host_vector h_first_bytes = {}; + thrust::host_vector h_offsets = {0}; + thrust::host_vector h_elements = {}; + for (const auto& pair : indexes) { + h_first_bytes.push_back(pair.first); + h_elements.insert(h_elements.end(), pair.second.begin(), pair.second.end()); + h_offsets.push_back(h_elements.size()); + } + + // 3. copy first char set and first char indexes to device + auto d_first_bytes = cudf::detail::make_device_uvector_async(h_first_bytes, stream, mr); + auto d_offsets = cudf::detail::make_device_uvector_async(h_offsets, stream, mr); + auto d_elements = cudf::detail::make_device_uvector_async(h_elements, stream, mr); + auto offsets_column = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + h_offsets.size(), + d_offsets.release(), + rmm::device_buffer{}, // null mask + 0 // null size + ); + auto element_column = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + h_elements.size(), + d_elements.release(), + rmm::device_buffer{}, // null mask + 0 // null size + ); + auto list_column = cudf::make_lists_column(h_first_bytes.size(), + std::move(offsets_column), + std::move(element_column), + 0, // null count + rmm::device_buffer{}, // null mask + stream, + mr); + auto d_list_column = column_device_view::create(list_column->view(), stream); + + // 4. Create output columns. + auto const results_iter = + thrust::make_transform_iterator(thrust::counting_iterator(0), [&](int i) { + return make_numeric_column(data_type{type_id::BOOL8}, + input.size(), + cudf::detail::copy_bitmask(input.parent(), stream, mr), + input.null_count(), + stream, + mr); + }); + auto results_list = + std::vector>(results_iter, results_iter + targets.size()); + auto device_results_list = [&] { + auto host_results_pointer_iter = + thrust::make_transform_iterator(results_list.begin(), [](auto const& results_column) { + return results_column->mutable_view().template data(); + }); + auto host_results_pointers = std::vector( + host_results_pointer_iter, host_results_pointer_iter + results_list.size()); + return cudf::detail::make_device_uvector_async(host_results_pointers, stream, mr); + }(); + + auto const d_strings = column_device_view::create(input.parent(), stream); + auto const d_targets = column_device_view::create(targets.parent(), stream); + + // 5. execute the kernel + constexpr int block_size = 256; + + if (warp_parallel) { + cudf::detail::grid_1d grid{input.size() * cudf::detail::warp_size, block_size}; + int shared_mem_size = block_size * targets.size(); + multi_contains_warp_parallel_multi_scalars_fn<<>>( + *d_strings, *d_targets, d_first_bytes, *d_list_column, device_results_list); + } else { + cudf::detail::grid_1d grid{input.size(), block_size}; + multi_contains_using_indexes_fn<<>>( + *d_strings, *d_targets, d_first_bytes, *d_list_column, device_results_list); + } + + return results_list; +} + /** * @brief Utility to return a bool column indicating the presence of * a given target string in a strings column. @@ -534,6 +833,16 @@ std::unique_ptr contains_fn(strings_column_view const& strings, return results; } +std::unique_ptr contains_small_strings_impl(strings_column_view const& input, + string_scalar const& target, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto pfn = [] __device__(string_view d_string, string_view d_target) { + return d_string.find(d_target) != string_view::npos; + }; + return contains_fn(input, target, pfn, stream, mr); +} } // namespace std::unique_ptr contains(strings_column_view const& input, @@ -548,10 +857,47 @@ std::unique_ptr contains(strings_column_view const& input, } // benchmark measurements showed this to be faster for smaller strings - auto pfn = [] __device__(string_view d_string, string_view d_target) { - return d_string.find(d_target) != string_view::npos; - }; - return contains_fn(input, target, pfn, stream, mr); + return contains_small_strings_impl(input, target, stream, mr); +} + +std::unique_ptr
multi_contains(strings_column_view const& input, + strings_column_view const& targets, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + CUDF_EXPECTS(not targets.has_nulls(), "Target strings cannot be null"); + auto result_columns = [&] { + if ((input.null_count() < input.size()) && + ((input.chars_size(stream) / input.size()) > AVG_CHAR_BYTES_THRESHOLD)) { + // Large strings. + // use warp parallel when the average string width is greater than the threshold + + static constexpr int target_group_size = 16; + if (targets.size() <= target_group_size) { + return multi_contains(/**warp parallel**/ true, input, targets, stream, mr); + } else { + // Too many targets will consume more shared memory, so split targets + std::vector> ret_columns; + size_type num_groups = (targets.size() + target_group_size - 1) / target_group_size; + for (size_type group_idx = 0; group_idx < num_groups; group_idx++) { + size_type start_target = group_idx * target_group_size; + size_type end_target = std::min(start_target + target_group_size, targets.size()); + auto target_goup = + cudf::detail::slice(targets.parent(), start_target, end_target, stream); + auto bool_columns = multi_contains( + /**warp parallel**/ true, input, strings_column_view(target_goup), stream, mr); + for (auto& c : bool_columns) { + ret_columns.push_back(std::move(c)); // take the ownership + } + } + return ret_columns; + } + } else { + // Small strings. Searching for multiple targets in one thread seems to work fastest. + return multi_contains(/**warp parallel**/ false, input, targets, stream, mr); + } + }(); + return std::make_unique
(std::move(result_columns)); } std::unique_ptr contains(strings_column_view const& strings, @@ -632,6 +978,15 @@ std::unique_ptr contains(strings_column_view const& strings, return detail::contains(strings, target, stream, mr); } +std::unique_ptr
multi_contains(strings_column_view const& strings, + strings_column_view const& targets, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + CUDF_FUNC_RANGE(); + return detail::multi_contains(strings, targets, stream, mr); +} + std::unique_ptr contains(strings_column_view const& strings, strings_column_view const& targets, rmm::cuda_stream_view stream, diff --git a/cpp/tests/strings/find_tests.cpp b/cpp/tests/strings/find_tests.cpp index 2da95ba5c27..52369d0755c 100644 --- a/cpp/tests/strings/find_tests.cpp +++ b/cpp/tests/strings/find_tests.cpp @@ -17,16 +17,14 @@ #include #include #include +#include -#include #include #include #include #include #include -#include - #include struct StringsFindTest : public cudf::test::BaseFixture {}; @@ -198,6 +196,159 @@ TEST_F(StringsFindTest, ContainsLongStrings) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected); } +TEST_F(StringsFindTest, MultiContains) +{ + constexpr int num_rows = 1024 + 1; + // replicate the following 9 rows: + std::vector s = { + "Héllo, there world and goodbye", + "quick brown fox jumped over the lazy brown dog; the fat cats jump in place without moving", + "the following code snippet demonstrates how to use search for values in an ordered range", + "it returns the last position where value could be inserted without violating the ordering", + "algorithms execution is parallelized as determined by an execution policy. t", + "he this is a continuation of previous row to make sure string boundaries are honored", + "abcdefghijklmnopqrstuvwxyz 0123456789 ABCDEFGHIJKLMNOPQRSTUVWXYZ !@#$%^&*()~", + "", + ""}; + + // replicate strings + auto string_itr = + cudf::detail::make_counting_transform_iterator(0, [&](auto i) { return s[i % s.size()]; }); + + // nulls: 8, 8 + 1 * 9, 8 + 2 * 9 ...... + auto string_v = cudf::detail::make_counting_transform_iterator( + 0, [&](auto i) { return (i + 1) % s.size() != 0; }); + + auto const strings = + cudf::test::strings_column_wrapper(string_itr, string_itr + num_rows, string_v); + auto strings_view = cudf::strings_column_view(strings); + std::vector match_targets({" the ", "a", "", "é"}); + cudf::test::strings_column_wrapper multi_targets_column(match_targets.begin(), + match_targets.end()); + auto results = + cudf::strings::multi_contains(strings_view, cudf::strings_column_view(multi_targets_column)); + + std::vector ret_0 = {0, 1, 0, 1, 0, 0, 0, 0, 0}; + std::vector ret_1 = {1, 1, 1, 1, 1, 1, 1, 0, 0}; + std::vector ret_2 = {1, 1, 1, 1, 1, 1, 1, 1, 0}; + std::vector ret_3 = {1, 0, 0, 0, 0, 0, 0, 0, 0}; + + auto make_bool_col_fn = [&string_v, &num_rows](std::vector bools) { + auto iter = cudf::detail::make_counting_transform_iterator( + 0, [&](auto i) { return bools[i % bools.size()]; }); + return cudf::test::fixed_width_column_wrapper(iter, iter + num_rows, string_v); + }; + + auto expected_0 = make_bool_col_fn(ret_0); + auto expected_1 = make_bool_col_fn(ret_1); + auto expected_2 = make_bool_col_fn(ret_2); + auto expected_3 = make_bool_col_fn(ret_3); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->get_column(0), expected_0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->get_column(1), expected_1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->get_column(2), expected_2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->get_column(3), expected_3); +} + +TEST_F(StringsFindTest, MultiContainsMoreTargets) +{ + auto const strings = cudf::test::strings_column_wrapper{ + "quick brown fox jumped over the lazy brown dog; the fat cats jump in place without moving " + "quick brown fox jumped", + "the following code snippet demonstrates how to use search for values in an ordered rangethe " + "following code snippet", + "thé it returns the last position where value could be inserted without violating ordering thé " + "it returns the last position"}; + auto strings_view = cudf::strings_column_view(strings); + std::vector targets({"lazy brown", "non-exist", ""}); + + std::vector> expects; + expects.push_back(cudf::test::fixed_width_column_wrapper({1, 0, 0})); + expects.push_back(cudf::test::fixed_width_column_wrapper({0, 0, 0})); + expects.push_back(cudf::test::fixed_width_column_wrapper({1, 1, 1})); + + std::vector match_targets; + int max_num_targets = 50; + + for (int num_targets = 1; num_targets < max_num_targets; num_targets++) { + match_targets.clear(); + for (int i = 0; i < num_targets; i++) { + match_targets.push_back(targets[i % targets.size()]); + } + + cudf::test::strings_column_wrapper multi_targets_column(match_targets.begin(), + match_targets.end()); + auto results = + cudf::strings::multi_contains(strings_view, cudf::strings_column_view(multi_targets_column)); + EXPECT_EQ(results->num_columns(), num_targets); + for (int i = 0; i < num_targets; i++) { + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->get_column(i), expects[i % expects.size()]); + } + } +} + +TEST_F(StringsFindTest, MultiContainsLongStrings) +{ + constexpr int num_rows = 1024 + 1; + // replicate the following 7 rows: + std::vector s = { + "quick brown fox jumped over the lazy brown dog; the fat cats jump in place without moving " + "quick brown fox jumped", + "the following code snippet demonstrates how to use search for values in an ordered rangethe " + "following code snippet", + "thé it returns the last position where value could be inserted without violating ordering thé " + "it returns the last position", + "algorithms execution is parallelized as determined by an execution policy. t algorithms " + "execution is parallelized as ", + "he this is a continuation of previous row to make sure string boundaries are honored he this " + "is a continuation of previous row", + "abcdefghijklmnopqrstuvwxyz 0123456789 ABCDEFGHIJKLMNOPQRSTUVWXYZ " + "!@#$%^&*()~abcdefghijklmnopqrstuvwxyz 0123456789 ABCDEFGHIJKL", + ""}; + + // replicate strings + auto string_itr = + cudf::detail::make_counting_transform_iterator(0, [&](auto i) { return s[i % s.size()]; }); + + // nulls: 6, 6 + 1 * 7, 6 + 2 * 7 ...... + auto string_v = cudf::detail::make_counting_transform_iterator( + 0, [&](auto i) { return (i + 1) % s.size() != 0; }); + + auto const strings = + cudf::test::strings_column_wrapper(string_itr, string_itr + num_rows, string_v); + + auto sv = cudf::strings_column_view(strings); + auto targets = cudf::test::strings_column_wrapper({" the ", "search", "", "string", "ox", "é "}); + auto results = cudf::strings::multi_contains(sv, cudf::strings_column_view(targets)); + + std::vector ret_0 = {1, 0, 1, 0, 0, 0, 0}; + std::vector ret_1 = {0, 1, 0, 0, 0, 0, 0}; + std::vector ret_2 = {1, 1, 1, 1, 1, 1, 0}; + std::vector ret_3 = {0, 0, 0, 0, 1, 0, 0}; + std::vector ret_4 = {1, 0, 0, 0, 0, 0, 0}; + std::vector ret_5 = {0, 0, 1, 0, 0, 0, 0}; + + auto make_bool_col_fn = [&string_v, &num_rows](std::vector bools) { + auto iter = cudf::detail::make_counting_transform_iterator( + 0, [&](auto i) { return bools[i % bools.size()]; }); + return cudf::test::fixed_width_column_wrapper(iter, iter + num_rows, string_v); + }; + + auto expected_0 = make_bool_col_fn(ret_0); + auto expected_1 = make_bool_col_fn(ret_1); + auto expected_2 = make_bool_col_fn(ret_2); + auto expected_3 = make_bool_col_fn(ret_3); + auto expected_4 = make_bool_col_fn(ret_4); + auto expected_5 = make_bool_col_fn(ret_5); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->get_column(0), expected_0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->get_column(1), expected_1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->get_column(2), expected_2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->get_column(3), expected_3); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->get_column(4), expected_4); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->get_column(5), expected_5); +} + TEST_F(StringsFindTest, StartsWith) { cudf::test::strings_column_wrapper strings({"Héllo", "thesé", "", "lease", "tést strings", ""}, diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 6bd4e06c47e..e113518229b 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -20,6 +20,7 @@ import java.util.*; import java.util.stream.IntStream; +import java.util.stream.Stream; import static ai.rapids.cudf.HostColumnVector.OFFSET_SIZE; @@ -3332,6 +3333,22 @@ public final ColumnVector stringContains(Scalar compString) { return new ColumnVector(stringContains(getNativeView(), compString.getScalarHandle())); } + private static long[] toPrimitive(Long[] longs) { + long[] ret = new long[longs.length]; + for (int i = 0; i < longs.length; ++i) { + ret[i] = longs[i]; + } + return ret; + } + + public final ColumnVector[] stringContains(ColumnView targets) { + assert type.equals(DType.STRING) : "column type must be a String"; + assert targets.getType().equals(DType.STRING) : "targets type must be a string"; + assert targets.getNullCount() == 0 : "targets must not be null"; + long[] resultPointers = stringContainsMulti(getNativeView(), targets.getNativeView()); + return Arrays.stream(resultPointers).mapToObj(ColumnVector::new).toArray(ColumnVector[]::new); + } + /** * Replaces values less than `lo` in `input` with `lo`, * and values greater than `hi` with `hi`. @@ -4437,6 +4454,11 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat */ private static native long stringContains(long cudfViewHandle, long compString) throws CudfException; + /** + * Check multiple target strings against the same input column. + */ + private static native long[] stringContainsMulti(long cudfViewHandle, long targets) throws CudfException; + /** * Native method for extracting results from a regex program pattern. Returns a table handle. * diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index 72f0ad19912..4d06be394d4 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -1522,6 +1522,26 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringContains(JNIEnv* en CATCH_STD(env, 0); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringContainsMulti(JNIEnv* env, + jobject j_object, + jlong j_view_handle, + jlong comp_strings) +{ + JNI_NULL_CHECK(env, j_view_handle, "column is null", 0); + JNI_NULL_CHECK(env, comp_strings, "targets is null", 0); + + try { + cudf::jni::auto_set_device(env); + auto* column_view = reinterpret_cast(j_view_handle); + auto* targets_view = reinterpret_cast(comp_strings); + auto const strings_column = cudf::strings_column_view(*column_view); + auto const targets_column = cudf::strings_column_view(*targets_view); + auto contains_results = cudf::strings::multi_contains(strings_column, targets_column); + return cudf::jni::convert_table_for_return(env, std::move(contains_results)); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_matchesRe(JNIEnv* env, jobject j_object, jlong j_view_handle, diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 708744569df..c3097ce99c4 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -3827,6 +3827,30 @@ void testStringOpsEmpty() { } } + @Test + void testStringContainsMulti() { + ColumnVector[] results = null; + try (ColumnVector haystack = ColumnVector.fromStrings("All the leaves are brown", + "And the sky is grey", + "I've been for a walk", + "On a winter's day", + null, + ""); + ColumnVector targets = ColumnVector.fromStrings("the", "a"); + ColumnVector expected0 = ColumnVector.fromBoxedBooleans(true, true, false, false, null, false); + ColumnVector expected1 = ColumnVector.fromBoxedBooleans(true, false, true, true, null, false)) { + results = haystack.stringContains(targets); + assertColumnsAreEqual(results[0], expected0); + assertColumnsAreEqual(results[1], expected1); + } finally { + if (results != null) { + for (ColumnVector c : results) { + c.close(); + } + } + } + } + @Test void testStringFindOperations() { try (ColumnVector testStrings = ColumnVector.fromStrings("", null, "abCD", "1a\"\u0100B1", "a\"\u0100B1", "1a\"\u0100B",