Skip to content

Commit

Permalink
Support decimal32/64 in reader & vector kernels & tests
Browse files Browse the repository at this point in the history
  • Loading branch information
curioustien committed Jan 25, 2025
1 parent b0e0375 commit 83c8a02
Show file tree
Hide file tree
Showing 7 changed files with 358 additions and 115 deletions.
4 changes: 3 additions & 1 deletion cpp/src/arrow/compute/kernels/vector_hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,7 @@ KernelInit GetHashInit(Type::type type_id) {
case Type::DATE32:
case Type::TIME32:
case Type::INTERVAL_MONTHS:
case Type::DECIMAL32:
return HashInit<RegularHashKernel<UInt32Type, Action>>;
case Type::INT64:
case Type::UINT64:
Expand All @@ -564,6 +565,7 @@ KernelInit GetHashInit(Type::type type_id) {
case Type::TIMESTAMP:
case Type::DURATION:
case Type::INTERVAL_DAY_TIME:
case Type::DECIMAL64:
return HashInit<RegularHashKernel<UInt64Type, Action>>;
case Type::BINARY:
case Type::STRING:
Expand Down Expand Up @@ -707,7 +709,7 @@ void AddHashKernels(VectorFunction* func, VectorKernel base, OutputType out_ty)
DCHECK_OK(func->AddKernel(base));
}

for (auto t : {Type::DECIMAL128, Type::DECIMAL256}) {
for (auto t : {Type::DECIMAL32, Type::DECIMAL64, Type::DECIMAL128, Type::DECIMAL256}) {
base.init = GetHashInit<Action>(t);
base.signature = KernelSignature::Make({t}, out_ty);
DCHECK_OK(func->AddKernel(base));
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/arrow/compute/kernels/vector_selection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,8 @@ std::shared_ptr<VectorFunction> MakeIndicesNonZeroFunction(std::string name,
AddKernels(NumericTypes());
AddKernels({boolean()});

for (const auto& ty : {Type::DECIMAL128, Type::DECIMAL256}) {
for (const auto& ty :
{Type::DECIMAL32, Type::DECIMAL64, Type::DECIMAL128, Type::DECIMAL256}) {
kernel.signature = KernelSignature::Make({ty}, uint64());
DCHECK_OK(func->AddKernel(kernel));
}
Expand Down
170 changes: 110 additions & 60 deletions cpp/src/parquet/arrow/arrow_reader_writer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,14 @@ std::shared_ptr<const LogicalType> get_logical_type(const DataType& type) {
static_cast<const ::arrow::DictionaryType&>(type);
return get_logical_type(*dict_type.value_type());
}
case ArrowId::DECIMAL32: {
const auto& dec_type = static_cast<const ::arrow::Decimal32Type&>(type);
return LogicalType::Decimal(dec_type.precision(), dec_type.scale());
}
case ArrowId::DECIMAL64: {
const auto& dec_type = static_cast<const ::arrow::Decimal64Type&>(type);
return LogicalType::Decimal(dec_type.precision(), dec_type.scale());
}
case ArrowId::DECIMAL128: {
const auto& dec_type = static_cast<const ::arrow::Decimal128Type&>(type);
return LogicalType::Decimal(dec_type.precision(), dec_type.scale());
Expand All @@ -206,9 +214,11 @@ ParquetType::type get_physical_type(const DataType& type) {
case ArrowId::INT16:
case ArrowId::UINT32:
case ArrowId::INT32:
case ArrowId::DECIMAL32:
return ParquetType::INT32;
case ArrowId::UINT64:
case ArrowId::INT64:
case ArrowId::DECIMAL64:
return ParquetType::INT64;
case ArrowId::FLOAT:
return ParquetType::FLOAT;
Expand Down Expand Up @@ -533,6 +543,8 @@ static std::shared_ptr<GroupNode> MakeSimpleSchema(const DataType& type,
case ::arrow::Type::HALF_FLOAT:
byte_width = sizeof(::arrow::HalfFloatType::c_type);
break;
case ::arrow::Type::DECIMAL32:
case ::arrow::Type::DECIMAL64:
case ::arrow::Type::DECIMAL128:
case ::arrow::Type::DECIMAL256: {
const auto& decimal_type = static_cast<const DecimalType&>(values_type);
Expand All @@ -548,6 +560,8 @@ static std::shared_ptr<GroupNode> MakeSimpleSchema(const DataType& type,
case ::arrow::Type::HALF_FLOAT:
byte_width = sizeof(::arrow::HalfFloatType::c_type);
break;
case ::arrow::Type::DECIMAL32:
case ::arrow::Type::DECIMAL64:
case ::arrow::Type::DECIMAL128:
case ::arrow::Type::DECIMAL256: {
const auto& decimal_type = static_cast<const DecimalType&>(type);
Expand Down Expand Up @@ -783,34 +797,70 @@ class TestReadDecimals : public ParquetIOTestBase {
// The Decimal roundtrip tests always go through the FixedLenByteArray path;
// check the ByteArray case manually.

TEST_F(TestReadDecimals, Decimal128ByteArray) {
TEST_F(TestReadDecimals, Decimal32ByteArray) {
const std::vector<std::vector<uint8_t>> big_endian_decimals = {
// 123456
{1, 226, 64},
// 987654
{15, 18, 6},
// -123456
{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 254, 29, 192},
{255, 254, 29, 192},
};

auto expected =
ArrayFromJSON(::arrow::decimal128(6, 3), R"(["123.456", "987.654", "-123.456"])");
ArrayFromJSON(::arrow::decimal32(6, 3), R"(["123.456", "987.654", "-123.456"])");
CheckReadFromByteArrays(LogicalType::Decimal(6, 3), big_endian_decimals, *expected);
}

// Byte-array decimal case for a column annotated Decimal(16, 3): the inputs
// below are handed to CheckReadFromByteArrays together with the expected
// decimal64(16, 3) array. Each input is the unscaled integer in big-endian
// two's-complement form; with scale 3, 123456 corresponds to "123.456".
// NOTE(review): CheckReadFromByteArrays is defined outside this view;
// presumably it writes the values as a BYTE_ARRAY column and compares the
// array read back against `expected` -- confirm against the fixture.
TEST_F(TestReadDecimals, Decimal64ByteArray) {
const std::vector<std::vector<uint8_t>> big_endian_decimals = {
// 123456 (0x01E240, 3 bytes)
{1, 226, 64},
// 987654 (0x0F1206, 3 bytes)
{15, 18, 6},
// -123456 (0xFFFE1DC0, 4-byte encoding)
{255, 254, 29, 192},
// -123456 again, sign-extended to 8 bytes; the same value at a different
// encoded width must produce the same decimal
{255, 255, 255, 255, 255, 254, 29, 192},
};

// Both negative encodings map to the same expected element "-123.456".
auto expected = ArrayFromJSON(::arrow::decimal64(16, 3),
R"(["123.456", "987.654", "-123.456", "-123.456"])");
CheckReadFromByteArrays(LogicalType::Decimal(16, 3), big_endian_decimals, *expected);
}

// Byte-array decimal case for a column annotated Decimal(20, 3): the inputs
// below are handed to CheckReadFromByteArrays together with the expected
// decimal128(20, 3) array. Each input is the unscaled integer in big-endian
// two's-complement form; with scale 3, 123456 corresponds to "123.456".
// NOTE(review): precision 20 is presumably chosen so that the reader yields
// decimal128 rather than a narrower decimal type -- confirm against the
// reader's precision-to-type mapping.
TEST_F(TestReadDecimals, Decimal128ByteArray) {
const std::vector<std::vector<uint8_t>> big_endian_decimals = {
// 123456 (0x01E240, 3 bytes)
{1, 226, 64},
// 987654 (0x0F1206, 3 bytes)
{15, 18, 6},
// -123456 (0xFFFE1DC0, 4-byte encoding)
{255, 254, 29, 192},
// -123456 again, sign-extended to 16 bytes; the same value at a different
// encoded width must produce the same decimal
{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 254, 29, 192},
};

// Both negative encodings map to the same expected element "-123.456".
auto expected = ArrayFromJSON(::arrow::decimal128(20, 3),
R"(["123.456", "987.654", "-123.456", "-123.456"])");
CheckReadFromByteArrays(LogicalType::Decimal(20, 3), big_endian_decimals, *expected);
}

TEST_F(TestReadDecimals, Decimal256ByteArray) {
const std::vector<std::vector<uint8_t>> big_endian_decimals = {
// 123456
{1, 226, 64},
// 987654
{15, 18, 6},
// -123456
{255, 254, 29, 192},
// -123456
{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 254, 29, 192},
};

auto expected =
ArrayFromJSON(::arrow::decimal256(40, 3), R"(["123.456", "987.654", "-123.456"])");
auto expected = ArrayFromJSON(::arrow::decimal256(40, 3),
R"(["123.456", "987.654", "-123.456", "-123.456"])");
CheckReadFromByteArrays(LogicalType::Decimal(40, 3), big_endian_decimals, *expected);
}

Expand Down Expand Up @@ -858,9 +908,9 @@ typedef ::testing::Types<
::arrow::Int16Type, ::arrow::Int32Type, ::arrow::UInt64Type, ::arrow::Int64Type,
::arrow::Date32Type, ::arrow::FloatType, ::arrow::DoubleType, ::arrow::StringType,
::arrow::BinaryType, ::arrow::FixedSizeBinaryType, ::arrow::HalfFloatType,
Decimal128WithPrecisionAndScale<1>, Decimal128WithPrecisionAndScale<5>,
Decimal128WithPrecisionAndScale<10>, Decimal128WithPrecisionAndScale<19>,
Decimal128WithPrecisionAndScale<23>, Decimal128WithPrecisionAndScale<27>,
Decimal32WithPrecisionAndScale<1>, Decimal32WithPrecisionAndScale<5>,
Decimal64WithPrecisionAndScale<10>, Decimal64WithPrecisionAndScale<18>,
Decimal128WithPrecisionAndScale<19>, Decimal128WithPrecisionAndScale<27>,
Decimal128WithPrecisionAndScale<38>, Decimal256WithPrecisionAndScale<39>,
Decimal256WithPrecisionAndScale<56>, Decimal256WithPrecisionAndScale<76>>
TestTypes;
Expand Down Expand Up @@ -903,8 +953,9 @@ TYPED_TEST(TestParquetIO, SingleColumnTableRequiredWrite) {
std::shared_ptr<Table> table = MakeSimpleTable(values, false);

this->ResetSink();
ASSERT_OK_NO_THROW(WriteTable(*table, ::arrow::default_memory_pool(), this->sink_,
values->length(), default_writer_properties()));
ASSERT_OK_NO_THROW(WriteTable(
*table, ::arrow::default_memory_pool(), this->sink_, values->length(),
::parquet::WriterProperties::Builder().enable_store_decimal_as_integer()->build()));

std::shared_ptr<Table> out;
std::unique_ptr<FileReader> reader;
Expand Down Expand Up @@ -2944,7 +2995,7 @@ TEST(ArrowReadWrite, Decimal256) {
using ::arrow::Decimal256;
using ::arrow::field;

auto type = ::arrow::decimal256(8, 4);
auto type = ::arrow::decimal256(48, 4);

const char* json = R"(["1.0000", null, "-1.2345", "-1000.5678",
"-9999.9999", "9999.9999"])";
Expand All @@ -2958,7 +3009,7 @@ TEST(ArrowReadWrite, DecimalStats) {
using ::arrow::Decimal128;
using ::arrow::field;

auto type = ::arrow::decimal128(/*precision=*/8, /*scale=*/0);
auto type = ::arrow::decimal128(/*precision=*/28, /*scale=*/0);

const char* json = R"(["255", "128", null, "0", "1", "-127", "-128", "-129", "-255"])";
auto array = ::arrow::ArrayFromJSON(type, json);
Expand Down Expand Up @@ -3447,8 +3498,8 @@ TEST(ArrowReadWrite, NestedRequiredOuterOptional) {
types.push_back(::arrow::duration(::arrow::TimeUnit::MILLI));
types.push_back(::arrow::duration(::arrow::TimeUnit::MICRO));
types.push_back(::arrow::duration(::arrow::TimeUnit::NANO));
types.push_back(::arrow::decimal128(3, 2));
types.push_back(::arrow::decimal256(3, 2));
types.push_back(::arrow::decimal32(3, 2));
types.push_back(::arrow::decimal128(23, 2));
types.push_back(::arrow::fixed_size_binary(4));
// Note large variants of types appear to get converted back to regular on read
types.push_back(::arrow::dictionary(::arrow::int32(), ::arrow::binary()));
Expand Down Expand Up @@ -3500,9 +3551,8 @@ TEST(ArrowReadWrite, NestedRequiredOuterOptionalDecimal) {
ByteArray("\x0f\x12\x06"), // 987654
};
const std::vector<int32_t> int32_values = {123456, 987654};
const std::vector<int64_t> int64_values = {123456, 987654};

const auto inner_type = ::arrow::decimal128(6, 3);
const auto inner_type = ::arrow::decimal32(6, 3);
auto inner_field = ::arrow::field("inner", inner_type, /*nullable=*/false);
auto type = ::arrow::struct_({inner_field});
auto field = ::arrow::field("outer", type, /*nullable=*/true);
Expand All @@ -3512,7 +3562,7 @@ TEST(ArrowReadWrite, NestedRequiredOuterOptionalDecimal) {
::arrow::StructArray::Make({inner}, {inner_field}, null_bitmap));
auto table = ::arrow::Table::Make(::arrow::schema({field}), {array});

for (const auto& encoding : {Type::BYTE_ARRAY, Type::INT32, Type::INT64}) {
for (const auto& encoding : {Type::BYTE_ARRAY, Type::INT32}) {
// Manually write out file based on encoding type
ARROW_SCOPED_TRACE("Encoding decimals as ", encoding);
auto parquet_schema = GroupNode::Make(
Expand Down Expand Up @@ -3543,12 +3593,6 @@ TEST(ArrowReadWrite, NestedRequiredOuterOptionalDecimal) {
int32_values.data());
break;
}
case Type::INT64: {
auto typed_writer = checked_cast<Int64Writer*>(column_writer);
typed_writer->WriteBatch(4, def_levels.data(), /*rep_levels=*/nullptr,
int64_values.data());
break;
}
default:
FAIL() << "Invalid encoding";
return;
Expand All @@ -3562,11 +3606,11 @@ TEST(ArrowReadWrite, NestedRequiredOuterOptionalDecimal) {
}
}

TEST(ArrowReadWrite, Decimal256AsInt) {
TEST(ArrowReadWrite, Decimal32AsInt) {
using ::arrow::Decimal256;
using ::arrow::field;

auto type = ::arrow::decimal256(8, 4);
auto type = ::arrow::decimal32(8, 4);

const char* json = R"(["1.0000", null, "-1.2345", "-1000.5678",
"-9999.9999", "9999.9999"])";
Expand Down Expand Up @@ -4059,7 +4103,7 @@ TEST(TestArrowReaderAdHoc, WriteBatchedNestedNullableStringColumn) {
// ARROW-10493
std::vector<std::shared_ptr<::arrow::Field>> fields{
::arrow::field("s", ::arrow::utf8(), /*nullable=*/true),
::arrow::field("d", ::arrow::decimal128(4, 2), /*nullable=*/true),
::arrow::field("d", ::arrow::decimal32(4, 2), /*nullable=*/true),
::arrow::field("b", ::arrow::boolean(), /*nullable=*/true),
::arrow::field("i8", ::arrow::int8(), /*nullable=*/true),
::arrow::field("i64", ::arrow::int64(), /*nullable=*/true)};
Expand Down Expand Up @@ -4222,25 +4266,47 @@ TEST_P(TestArrowReaderAdHocSparkAndHvr, ReadDecimals) {

std::shared_ptr<Array> expected_array;

::arrow::Decimal128Builder builder(decimal_type, pool);

for (int32_t i = 0; i < expected_length; ++i) {
::arrow::Decimal128 value((i + 1) * 100);
ASSERT_OK(builder.Append(value));
if (decimal_type->id() == ::arrow::Decimal32Type::type_id) {
::arrow::Decimal32Builder builder(decimal_type, pool);
for (int32_t i = 0; i < expected_length; ++i) {
::arrow::Decimal32 value((i + 1) * 100);
ASSERT_OK(builder.Append(value));
}
ASSERT_OK(builder.Finish(&expected_array));
} else if (decimal_type->id() == ::arrow::Decimal64Type::type_id) {
::arrow::Decimal64Builder builder(decimal_type, pool);
for (int32_t i = 0; i < expected_length; ++i) {
::arrow::Decimal64 value((i + 1) * 100);
ASSERT_OK(builder.Append(value));
}
ASSERT_OK(builder.Finish(&expected_array));
} else if (decimal_type->id() == ::arrow::Decimal128Type::type_id) {
::arrow::Decimal128Builder builder(decimal_type, pool);
for (int32_t i = 0; i < expected_length; ++i) {
::arrow::Decimal128 value((i + 1) * 100);
ASSERT_OK(builder.Append(value));
}
ASSERT_OK(builder.Finish(&expected_array));
} else {
::arrow::Decimal256Builder builder(decimal_type, pool);
for (int32_t i = 0; i < expected_length; ++i) {
::arrow::Decimal256 value((i + 1) * 100);
ASSERT_OK(builder.Append(value));
}
ASSERT_OK(builder.Finish(&expected_array));
}
ASSERT_OK(builder.Finish(&expected_array));

AssertArraysEqual(*expected_array, *chunk);
}

INSTANTIATE_TEST_SUITE_P(
ReadDecimals, TestArrowReaderAdHocSparkAndHvr,
::testing::Values(
std::make_tuple("int32_decimal.parquet", ::arrow::decimal128(4, 2)),
std::make_tuple("int64_decimal.parquet", ::arrow::decimal128(10, 2)),
std::make_tuple("int32_decimal.parquet", ::arrow::decimal32(4, 2)),
std::make_tuple("int64_decimal.parquet", ::arrow::decimal64(10, 2)),
std::make_tuple("fixed_length_decimal.parquet", ::arrow::decimal128(25, 2)),
std::make_tuple("fixed_length_decimal_legacy.parquet",
::arrow::decimal128(13, 2)),
std::make_tuple("byte_array_decimal.parquet", ::arrow::decimal128(4, 2))));
std::make_tuple("fixed_length_decimal_legacy.parquet", ::arrow::decimal64(13, 2)),
std::make_tuple("byte_array_decimal.parquet", ::arrow::decimal32(4, 2))));

TEST(TestArrowReaderAdHoc, ReadFloat16Files) {
using ::arrow::util::Float16;
Expand Down Expand Up @@ -5162,33 +5228,17 @@ class TestIntegerAnnotateDecimalTypeParquetIO : public TestParquetIO<TestType> {
this->ReaderFromSink(&reader);
this->ReadSingleColumnFile(std::move(reader), &out);

// Reader always read values as DECIMAL128 type
ASSERT_EQ(out->type()->id(), ::arrow::Type::DECIMAL128);

if (values.type()->id() == ::arrow::Type::DECIMAL128) {
AssertArraysEqual(values, *out);
} else {
auto& expected_values = dynamic_cast<const ::arrow::Decimal256Array&>(values);
auto& read_values = dynamic_cast<const ::arrow::Decimal128Array&>(*out);
ASSERT_EQ(expected_values.length(), read_values.length());
ASSERT_EQ(expected_values.null_count(), read_values.null_count());
ASSERT_EQ(expected_values.length(), read_values.length());
for (int64_t i = 0; i < expected_values.length(); ++i) {
ASSERT_EQ(expected_values.IsNull(i), read_values.IsNull(i));
if (!expected_values.IsNull(i)) {
ASSERT_EQ(::arrow::Decimal256(expected_values.Value(i)).ToString(0),
::arrow::Decimal128(read_values.Value(i)).ToString(0));
}
}
}
ASSERT_EQ(out->type()->id(), TestType::type_id);
AssertArraysEqual(values, *out);
}
};

typedef ::testing::Types<
Decimal128WithPrecisionAndScale<1>, Decimal128WithPrecisionAndScale<5>,
Decimal128WithPrecisionAndScale<10>, Decimal128WithPrecisionAndScale<18>,
Decimal256WithPrecisionAndScale<1>, Decimal256WithPrecisionAndScale<5>,
Decimal256WithPrecisionAndScale<10>, Decimal256WithPrecisionAndScale<18>>
Decimal32WithPrecisionAndScale<1>, Decimal32WithPrecisionAndScale<5>,
Decimal64WithPrecisionAndScale<10>, Decimal64WithPrecisionAndScale<18>,
Decimal128WithPrecisionAndScale<19>, Decimal128WithPrecisionAndScale<27>,
Decimal128WithPrecisionAndScale<38>, Decimal256WithPrecisionAndScale<39>,
Decimal256WithPrecisionAndScale<56>, Decimal256WithPrecisionAndScale<76>>
DecimalTestTypes;

TYPED_TEST_SUITE(TestIntegerAnnotateDecimalTypeParquetIO, DecimalTestTypes);
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/parquet/arrow/arrow_schema_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ const auto TIMESTAMP_MS = ::arrow::timestamp(TimeUnit::MILLI);
const auto TIMESTAMP_US = ::arrow::timestamp(TimeUnit::MICRO);
const auto TIMESTAMP_NS = ::arrow::timestamp(TimeUnit::NANO);
const auto BINARY = ::arrow::binary();
const auto DECIMAL_8_4 = std::make_shared<::arrow::Decimal128Type>(8, 4);
const auto DECIMAL_8_4 = std::make_shared<::arrow::Decimal32Type>(8, 4);

class TestConvertParquetSchema : public ::testing::Test {
public:
Expand Down
Loading

0 comments on commit 83c8a02

Please sign in to comment.