Skip to content

Commit

Permalink
apacheGH-39577: [C++] Fix tail-word access cross buffer boundary in `…
Browse files Browse the repository at this point in the history
…CompareBinaryColumnToRow` (apache#39606)

### Rationale for this change

Default buffer alignment (64b) doesn't guarantee the safety of tail-word access in  `KeyCompare::CompareBinaryColumnToRow`. Comment apache#39577 (comment) is a concrete example.

### What changes are included in this PR?

Make `KeyCompare::CompareBinaryColumnToRow` tail-word safe.

### Are these changes tested?

UT included.

### Are there any user-facing changes?

No.

* Closes: apache#39577

Authored-by: zanmato1984 <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
  • Loading branch information
zanmato1984 authored Jan 16, 2024
1 parent 980e7d7 commit cd3321b
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 6 deletions.
3 changes: 2 additions & 1 deletion cpp/src/arrow/compute/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ add_arrow_test(internals_test
kernel_test.cc
light_array_test.cc
registry_test.cc
key_hash_test.cc)
key_hash_test.cc
row/compare_test.cc)

add_arrow_compute_test(expression_test SOURCES expression_test.cc)

Expand Down
11 changes: 6 additions & 5 deletions cpp/src/arrow/compute/row/compare_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,7 @@ void KeyCompare::CompareBinaryColumnToRow(uint32_t offset_within_row,
// Non-zero length guarantees no underflow
int32_t num_loops_less_one =
static_cast<int32_t>(bit_util::CeilDiv(length, 8)) - 1;

uint64_t tail_mask = ~0ULL >> (64 - 8 * (length - num_loops_less_one * 8));
int32_t num_tail_bytes = length - num_loops_less_one * 8;

const uint64_t* key_left_ptr =
reinterpret_cast<const uint64_t*>(left_base + irow_left * length);
Expand All @@ -224,9 +223,11 @@ void KeyCompare::CompareBinaryColumnToRow(uint32_t offset_within_row,
uint64_t key_right = key_right_ptr[i];
result_or |= key_left ^ key_right;
}
uint64_t key_left = util::SafeLoad(key_left_ptr + i);
uint64_t key_right = key_right_ptr[i];
result_or |= tail_mask & (key_left ^ key_right);
uint64_t key_left = 0;
memcpy(&key_left, key_left_ptr + i, num_tail_bytes);
uint64_t key_right = 0;
memcpy(&key_right, key_right_ptr + i, num_tail_bytes);
result_or |= key_left ^ key_right;
return result_or == 0 ? 0xff : 0;
});
}
Expand Down
110 changes: 110 additions & 0 deletions cpp/src/arrow/compute/row/compare_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <numeric>

#include "arrow/compute/row/compare_internal.h"
#include "arrow/testing/gtest_util.h"

namespace arrow {
namespace compute {

using arrow::bit_util::BytesForBits;
using arrow::internal::CpuInfo;
using arrow::util::MiniBatch;
using arrow::util::TempVectorStack;

// Specialized case for GH-39577.
TEST(KeyCompare, CompareColumnsToRowsCuriousFSB) {
int fsb_length = 9;
MemoryPool* pool = default_memory_pool();
TempVectorStack stack;
ASSERT_OK(stack.Init(pool, 8 * MiniBatch::kMiniBatchLength * sizeof(uint64_t)));

int num_rows = 7;
auto column_right = ArrayFromJSON(fixed_size_binary(fsb_length), R"([
"000000000",
"111111111",
"222222222",
"333333333",
"444444444",
"555555555",
"666666666"])");
ExecBatch batch_right({column_right}, num_rows);

std::vector<KeyColumnMetadata> column_metadatas_right;
ASSERT_OK(ColumnMetadatasFromExecBatch(batch_right, &column_metadatas_right));

RowTableMetadata table_metadata_right;
table_metadata_right.FromColumnMetadataVector(column_metadatas_right, sizeof(uint64_t),
sizeof(uint64_t));

std::vector<KeyColumnArray> column_arrays_right;
ASSERT_OK(ColumnArraysFromExecBatch(batch_right, &column_arrays_right));

RowTableImpl row_table;
ASSERT_OK(row_table.Init(pool, table_metadata_right));

RowTableEncoder row_encoder;
row_encoder.Init(column_metadatas_right, sizeof(uint64_t), sizeof(uint64_t));
row_encoder.PrepareEncodeSelected(0, num_rows, column_arrays_right);

std::vector<uint16_t> row_ids_right(num_rows);
std::iota(row_ids_right.begin(), row_ids_right.end(), 0);
ASSERT_OK(row_encoder.EncodeSelected(&row_table, num_rows, row_ids_right.data()));

auto column_left = ArrayFromJSON(fixed_size_binary(fsb_length), R"([
"000000000",
"111111111",
"222222222",
"333333333",
"444444444",
"555555555",
"777777777"])");
ExecBatch batch_left({column_left}, num_rows);
std::vector<KeyColumnArray> column_arrays_left;
ASSERT_OK(ColumnArraysFromExecBatch(batch_left, &column_arrays_left));

std::vector<uint32_t> row_ids_left(num_rows);
std::iota(row_ids_left.begin(), row_ids_left.end(), 0);

LightContext ctx{CpuInfo::GetInstance()->hardware_flags(), &stack};

{
uint32_t num_rows_no_match;
std::vector<uint16_t> row_ids_out(num_rows);
KeyCompare::CompareColumnsToRows(num_rows, NULLPTR, row_ids_left.data(), &ctx,
&num_rows_no_match, row_ids_out.data(),
column_arrays_left, row_table, true, NULLPTR);
ASSERT_EQ(num_rows_no_match, 1);
ASSERT_EQ(row_ids_out[0], 6);
}

{
std::vector<uint8_t> match_bitvector(BytesForBits(num_rows));
KeyCompare::CompareColumnsToRows(num_rows, NULLPTR, row_ids_left.data(), &ctx,
NULLPTR, NULLPTR, column_arrays_left, row_table,
true, match_bitvector.data());
for (int i = 0; i < num_rows; ++i) {
SCOPED_TRACE(i);
ASSERT_EQ(arrow::bit_util::GetBit(match_bitvector.data(), i), i != 6);
}
}
}

} // namespace compute
} // namespace arrow

0 comments on commit cd3321b

Please sign in to comment.