From cd3321b28b0c9703e5d7105d6146c1270bbadd7f Mon Sep 17 00:00:00 2001 From: "Rossi(Ruoxi) Sun" Date: Wed, 17 Jan 2024 01:14:03 +0800 Subject: [PATCH] GH-39577: [C++] Fix tail-word access cross buffer boundary in `CompareBinaryColumnToRow` (#39606) ### Rationale for this change The default buffer alignment (64 bytes) doesn't guarantee the safety of tail-word access in `KeyCompare::CompareBinaryColumnToRow`. The comment at https://github.com/apache/arrow/issues/39577#issuecomment-1889090279 gives a concrete example. ### What changes are included in this PR? Make `KeyCompare::CompareBinaryColumnToRow` tail-word safe: instead of loading a full 8-byte tail word and masking it, copy only the remaining tail bytes into a zero-initialized word before comparing. ### Are these changes tested? Yes, a unit test is included. ### Are there any user-facing changes? No. * Closes: #39577 Authored-by: zanmato1984 Signed-off-by: Antoine Pitrou --- cpp/src/arrow/compute/CMakeLists.txt | 3 +- cpp/src/arrow/compute/row/compare_internal.cc | 11 +- cpp/src/arrow/compute/row/compare_test.cc | 110 ++++++++++++++++++ 3 files changed, 118 insertions(+), 6 deletions(-) create mode 100644 cpp/src/arrow/compute/row/compare_test.cc diff --git a/cpp/src/arrow/compute/CMakeLists.txt b/cpp/src/arrow/compute/CMakeLists.txt index 1134e0a98ae45..e14d78ff6e5ca 100644 --- a/cpp/src/arrow/compute/CMakeLists.txt +++ b/cpp/src/arrow/compute/CMakeLists.txt @@ -89,7 +89,8 @@ add_arrow_test(internals_test kernel_test.cc light_array_test.cc registry_test.cc - key_hash_test.cc) + key_hash_test.cc + row/compare_test.cc) add_arrow_compute_test(expression_test SOURCES expression_test.cc) diff --git a/cpp/src/arrow/compute/row/compare_internal.cc b/cpp/src/arrow/compute/row/compare_internal.cc index 7c402e7a2384d..078a8287c71c0 100644 --- a/cpp/src/arrow/compute/row/compare_internal.cc +++ b/cpp/src/arrow/compute/row/compare_internal.cc @@ -208,8 +208,7 @@ void KeyCompare::CompareBinaryColumnToRow(uint32_t offset_within_row, // Non-zero length guarantees no underflow int32_t num_loops_less_one = static_cast(bit_util::CeilDiv(length, 8)) - 1; - - uint64_t tail_mask = ~0ULL >> (64 - 8 * 
(length - num_loops_less_one * 8)); + int32_t num_tail_bytes = length - num_loops_less_one * 8; const uint64_t* key_left_ptr = reinterpret_cast(left_base + irow_left * length); @@ -224,9 +223,11 @@ void KeyCompare::CompareBinaryColumnToRow(uint32_t offset_within_row, uint64_t key_right = key_right_ptr[i]; result_or |= key_left ^ key_right; } - uint64_t key_left = util::SafeLoad(key_left_ptr + i); - uint64_t key_right = key_right_ptr[i]; - result_or |= tail_mask & (key_left ^ key_right); + uint64_t key_left = 0; + memcpy(&key_left, key_left_ptr + i, num_tail_bytes); + uint64_t key_right = 0; + memcpy(&key_right, key_right_ptr + i, num_tail_bytes); + result_or |= key_left ^ key_right; return result_or == 0 ? 0xff : 0; }); } diff --git a/cpp/src/arrow/compute/row/compare_test.cc b/cpp/src/arrow/compute/row/compare_test.cc new file mode 100644 index 0000000000000..1d8562cd56d3c --- /dev/null +++ b/cpp/src/arrow/compute/row/compare_test.cc @@ -0,0 +1,110 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include + +#include "arrow/compute/row/compare_internal.h" +#include "arrow/testing/gtest_util.h" + +namespace arrow { +namespace compute { + +using arrow::bit_util::BytesForBits; +using arrow::internal::CpuInfo; +using arrow::util::MiniBatch; +using arrow::util::TempVectorStack; + +// Specialized case for GH-39577: 9-byte FSB values; only row 6 must mismatch. +TEST(KeyCompare, CompareColumnsToRowsCuriousFSB) { + int fsb_length = 9; + MemoryPool* pool = default_memory_pool(); + TempVectorStack stack; + ASSERT_OK(stack.Init(pool, 8 * MiniBatch::kMiniBatchLength * sizeof(uint64_t))); + + int num_rows = 7; + auto column_right = ArrayFromJSON(fixed_size_binary(fsb_length), R"([ + "000000000", + "111111111", + "222222222", + "333333333", + "444444444", + "555555555", + "666666666"])"); + ExecBatch batch_right({column_right}, num_rows); + + std::vector column_metadatas_right; + ASSERT_OK(ColumnMetadatasFromExecBatch(batch_right, &column_metadatas_right)); + + RowTableMetadata table_metadata_right; + table_metadata_right.FromColumnMetadataVector(column_metadatas_right, sizeof(uint64_t), + sizeof(uint64_t)); + + std::vector column_arrays_right; + ASSERT_OK(ColumnArraysFromExecBatch(batch_right, &column_arrays_right)); + + RowTableImpl row_table; + ASSERT_OK(row_table.Init(pool, table_metadata_right)); + + RowTableEncoder row_encoder; + row_encoder.Init(column_metadatas_right, sizeof(uint64_t), sizeof(uint64_t)); + row_encoder.PrepareEncodeSelected(0, num_rows, column_arrays_right); + + std::vector row_ids_right(num_rows); + std::iota(row_ids_right.begin(), row_ids_right.end(), 0); + ASSERT_OK(row_encoder.EncodeSelected(&row_table, num_rows, row_ids_right.data())); + + auto column_left = ArrayFromJSON(fixed_size_binary(fsb_length), R"([ + "000000000", + "111111111", + "222222222", + "333333333", + "444444444", + "555555555", + "777777777"])"); + ExecBatch batch_left({column_left}, num_rows); + std::vector column_arrays_left; + ASSERT_OK(ColumnArraysFromExecBatch(batch_left, &column_arrays_left)); + + std::vector 
row_ids_left(num_rows); + std::iota(row_ids_left.begin(), row_ids_left.end(), 0); + + LightContext ctx{CpuInfo::GetInstance()->hardware_flags(), &stack}; + + { + uint32_t num_rows_no_match; + std::vector row_ids_out(num_rows); + KeyCompare::CompareColumnsToRows(num_rows, NULLPTR, row_ids_left.data(), &ctx, + &num_rows_no_match, row_ids_out.data(), + column_arrays_left, row_table, true, NULLPTR); + ASSERT_EQ(num_rows_no_match, 1); + ASSERT_EQ(row_ids_out[0], 6); + } + + { + std::vector match_bitvector(BytesForBits(num_rows)); + KeyCompare::CompareColumnsToRows(num_rows, NULLPTR, row_ids_left.data(), &ctx, + NULLPTR, NULLPTR, column_arrays_left, row_table, + true, match_bitvector.data()); + for (int i = 0; i < num_rows; ++i) { + SCOPED_TRACE(i); + ASSERT_EQ(arrow::bit_util::GetBit(match_bitvector.data(), i), i != 6); + } + } +} + +} // namespace compute +} // namespace arrow