diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index 58bc7bb90a88..e4a7eb049e9e 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -19,7 +19,7 @@ use arrow::record_batch::RecordBatch; use arrow_array::types::{ - Date32Type, Date64Type, Time32MillisecondType, Time32SecondType, + Date32Type, Date64Type, Decimal128Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; @@ -170,6 +170,9 @@ pub(crate) fn new_group_values( TimeUnit::Microsecond => downcast_helper!(TimestampMicrosecondType, d), TimeUnit::Nanosecond => downcast_helper!(TimestampNanosecondType, d), }, + DataType::Decimal128(_, _) => { + downcast_helper!(Decimal128Type, d); + } DataType::Utf8 => { return Ok(Box::new(GroupValuesByes::::new(OutputType::Utf8))); } diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs index 89041eb0f04e..333eb6bbcbe8 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs @@ -31,8 +31,8 @@ use crate::aggregates::group_values::GroupValues; use ahash::RandomState; use arrow::compute::cast; use arrow::datatypes::{ - BinaryViewType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, - Int32Type, Int64Type, Int8Type, StringViewType, Time32MillisecondType, + BinaryViewType, Date32Type, Date64Type, Decimal128Type, Float32Type, Float64Type, + Int16Type, Int32Type, Int64Type, Int8Type, StringViewType, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, @@ -1008,6 +1008,14 @@ impl GroupValues for GroupValuesColumn { ) } }, + &DataType::Decimal128(_, _) => { + instantiate_primitive! { + v, + nullable, + Decimal128Type, + data_type + } + } &DataType::Utf8 => { let b = ByteGroupValueBuilder::::new(OutputType::Utf8); v.push(Box::new(b) as _) @@ -1214,6 +1222,7 @@ fn supported_type(data_type: &DataType) -> bool { | DataType::UInt64 | DataType::Float32 | DataType::Float64 + | DataType::Decimal128(_, _) | DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs index 4686a78f24b0..4ceeb634bad2 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs @@ -200,10 +200,10 @@ impl GroupColumn let first_n_nulls = if NULLABLE { self.nulls.take_n(n) } else { None }; - Arc::new(PrimitiveArray::::new( - ScalarBuffer::from(first_n), - first_n_nulls, - )) + Arc::new( + PrimitiveArray::::new(ScalarBuffer::from(first_n), first_n_nulls) + .with_data_type(self.data_type.clone()), + ) } } diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs index 05214ec10d68..85cd2e79b936 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs @@ -208,6 +208,7 @@ where build_primitive(split, null_group) } }; + Ok(vec![Arc::new(array.with_data_type(self.data_type.clone()))]) } diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index 4acf519c5de4..df7e21c2da44 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -5499,3 +5499,42 @@ SELECT GROUP BY ts, text ---- foo 2024-01-01T08:00:00+08:00 + +# Test multi group by int + Decimal128 +statement ok +create table source as values +(1, '123.45'), +(1, '123.45'), +(2, '678.90'), +(2, '1011.12'), +(3, '1314.15'), +(3, '1314.15'), +(2, '1011.12'), +(null, null), +(null, '123.45'), +(null, null), +(null, '123.45'), +(2, '678.90'), +(2, '678.90'), +(1, null) +; + +statement ok +create view t as select column1 as a, arrow_cast(column2, 'Decimal128(10, 2)') as b from source; + +query IRI +select a, b, count(*) from t group by a, b order by a, b; +---- +1 123.45 2 +1 NULL 1 +2 678.9 3 +2 1011.12 2 +3 1314.15 2 +NULL 123.45 2 +NULL NULL 2 + +statement ok +drop view t + +statement ok +drop table source;