Skip to content

Commit

Permalink
Fix DistinctCount for timestamps with time zone (#10043) (#10105)
Browse files Browse the repository at this point in the history
* Fix DistinctCount for timestamps with time zone

Preserve the original data type in the aggregation state

* Add tests for decimal count distinct

Co-authored-by: Georgi Krastev <[email protected]>
  • Loading branch information
alamb and joroKr21 authored Apr 17, 2024
1 parent ec4a6da commit 9b0f4e2
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 26 deletions.
42 changes: 24 additions & 18 deletions datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,14 @@ impl AggregateExpr for DistinctCount {
UInt16 => Box::new(PrimitiveDistinctCountAccumulator::<UInt16Type>::new()),
UInt32 => Box::new(PrimitiveDistinctCountAccumulator::<UInt32Type>::new()),
UInt64 => Box::new(PrimitiveDistinctCountAccumulator::<UInt64Type>::new()),
Decimal128(_, _) => {
Box::new(PrimitiveDistinctCountAccumulator::<Decimal128Type>::new())
}
Decimal256(_, _) => {
Box::new(PrimitiveDistinctCountAccumulator::<Decimal256Type>::new())
}
dt @ Decimal128(_, _) => Box::new(
PrimitiveDistinctCountAccumulator::<Decimal128Type>::new()
.with_data_type(dt.clone()),
),
dt @ Decimal256(_, _) => Box::new(
PrimitiveDistinctCountAccumulator::<Decimal256Type>::new()
.with_data_type(dt.clone()),
),

Date32 => Box::new(PrimitiveDistinctCountAccumulator::<Date32Type>::new()),
Date64 => Box::new(PrimitiveDistinctCountAccumulator::<Date64Type>::new()),
Expand All @@ -130,18 +132,22 @@ impl AggregateExpr for DistinctCount {
Time64(Nanosecond) => {
Box::new(PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new())
}
Timestamp(Microsecond, _) => Box::new(PrimitiveDistinctCountAccumulator::<
TimestampMicrosecondType,
>::new()),
Timestamp(Millisecond, _) => Box::new(PrimitiveDistinctCountAccumulator::<
TimestampMillisecondType,
>::new()),
Timestamp(Nanosecond, _) => Box::new(PrimitiveDistinctCountAccumulator::<
TimestampNanosecondType,
>::new()),
Timestamp(Second, _) => {
Box::new(PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new())
}
dt @ Timestamp(Microsecond, _) => Box::new(
PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new()
.with_data_type(dt.clone()),
),
dt @ Timestamp(Millisecond, _) => Box::new(
PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new()
.with_data_type(dt.clone()),
),
dt @ Timestamp(Nanosecond, _) => Box::new(
PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new()
.with_data_type(dt.clone()),
),
dt @ Timestamp(Second, _) => Box::new(
PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new()
.with_data_type(dt.clone()),
),

Float16 => Box::new(FloatDistinctCountAccumulator::<Float16Type>::new()),
Float32 => Box::new(FloatDistinctCountAccumulator::<Float32Type>::new()),
Expand Down
15 changes: 12 additions & 3 deletions datafusion/physical-expr/src/aggregate/count_distinct/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use ahash::RandomState;
use arrow::array::ArrayRef;
use arrow_array::types::ArrowPrimitiveType;
use arrow_array::PrimitiveArray;
use arrow_schema::DataType;

use datafusion_common::cast::{as_list_array, as_primitive_array};
use datafusion_common::utils::array_into_list_array;
Expand All @@ -45,6 +46,7 @@ where
T::Native: Eq + Hash,
{
values: HashSet<T::Native, RandomState>,
data_type: DataType,
}

impl<T> PrimitiveDistinctCountAccumulator<T>
Expand All @@ -55,8 +57,14 @@ where
pub(super) fn new() -> Self {
Self {
values: HashSet::default(),
data_type: T::DATA_TYPE,
}
}

pub(super) fn with_data_type(mut self, data_type: DataType) -> Self {
self.data_type = data_type;
self
}
}

impl<T> Accumulator for PrimitiveDistinctCountAccumulator<T>
Expand All @@ -65,9 +73,10 @@ where
T::Native: Eq + Hash,
{
fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
self.values.iter().cloned(),
)) as ArrayRef;
let arr = Arc::new(
PrimitiveArray::<T>::from_iter_values(self.values.iter().cloned())
.with_data_type(self.data_type.clone()),
);
let list = Arc::new(array_into_list_array(arr));
Ok(vec![ScalarValue::List(list)])
}
Expand Down
37 changes: 32 additions & 5 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1876,18 +1876,22 @@ select
arrow_cast(column1, 'Timestamp(Microsecond, None)') as micros,
arrow_cast(column1, 'Timestamp(Millisecond, None)') as millis,
arrow_cast(column1, 'Timestamp(Second, None)') as secs,
arrow_cast(column1, 'Timestamp(Nanosecond, Some("UTC"))') as nanos_utc,
arrow_cast(column1, 'Timestamp(Microsecond, Some("UTC"))') as micros_utc,
arrow_cast(column1, 'Timestamp(Millisecond, Some("UTC"))') as millis_utc,
arrow_cast(column1, 'Timestamp(Second, Some("UTC"))') as secs_utc,
column2 as names,
column3 as tag
from t_source;

# Demonstate the contents
query PPPPTT
query PPPPPPPPTT
select * from t;
----
2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 X
2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123 2011-12-13T11:13:10 Row 1 X
NULL NULL NULL NULL Row 2 Y
2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 Y
2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 X
2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123 2011-12-13T11:13:10 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 X
NULL NULL NULL NULL NULL NULL NULL NULL Row 2 Y
2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 Y


# aggregate_timestamps_sum
Expand Down Expand Up @@ -1933,6 +1937,17 @@ SELECT tag, max(nanos), max(micros), max(millis), max(secs) FROM t GROUP BY tag
X 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10
Y 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10

# aggregate_timestamps_count_distinct_with_tz
query IIII
SELECT count(DISTINCT nanos_utc), count(DISTINCT micros_utc), count(DISTINCT millis_utc), count(DISTINCT secs_utc) FROM t;
----
3 3 3 3

query TIIII
SELECT tag, count(DISTINCT nanos_utc), count(DISTINCT micros_utc), count(DISTINCT millis_utc), count(DISTINCT secs_utc) FROM t GROUP BY tag ORDER BY tag;
----
X 2 2 2 2
Y 1 1 1 1

# aggregate_timestamps_avg
statement error DataFusion error: Error during planning: No function matches the given name and argument types 'AVG\(Timestamp\(Nanosecond, None\)\)'\. You might need to add explicit type casts\.
Expand Down Expand Up @@ -2285,6 +2300,18 @@ select c2, avg(c1), arrow_typeof(avg(c1)) from d_table GROUP BY c2 ORDER BY c2
A 110.0045 Decimal128(14, 7)
B -100.0045 Decimal128(14, 7)

# aggregate_decimal_count_distinct
query I
select count(DISTINCT cast(c1 AS DECIMAL(10, 2))) from d_table
----
4

query TI
select c2, count(DISTINCT cast(c1 AS DECIMAL(10, 2))) from d_table GROUP BY c2 ORDER BY c2
----
A 2
B 2

# Use PostgresSQL dialect
statement ok
set datafusion.sql_parser.dialect = 'Postgres';
Expand Down
11 changes: 11 additions & 0 deletions datafusion/sqllogictest/test_files/decimal.slt
Original file line number Diff line number Diff line change
Expand Up @@ -720,5 +720,16 @@ select count(*),c1 from decimal256_simple group by c1 order by c1;
4 0.00004
5 0.00005

query I
select count(DISTINCT cast(c1 AS DECIMAL(42, 4))) from decimal256_simple;
----
2

query BI
select c4, count(DISTINCT cast(c1 AS DECIMAL(42, 4))) from decimal256_simple GROUP BY c4 ORDER BY c4;
----
false 2
true 2

statement ok
drop table decimal256_simple;

0 comments on commit 9b0f4e2

Please sign in to comment.