From ce705d2a795512e51e952721d2d981d729a9461d Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 11 Dec 2024 17:22:00 +0000 Subject: [PATCH 1/7] Sum statistics --- datafusion/common/src/stats.rs | 75 +++++++++++++--- datafusion/core/src/datasource/statistics.rs | 9 +- .../tests/custom_sources_cases/statistics.rs | 2 + datafusion/expr/src/udaf.rs | 1 + datafusion/functions-aggregate/src/sum.rs | 31 ++++++- .../physical-expr-common/src/physical_expr.rs | 7 +- .../physical-expr/src/expressions/binary.rs | 19 +++- .../physical-expr/src/expressions/cast.rs | 14 ++- .../physical-expr/src/expressions/column.rs | 6 +- .../physical-expr/src/expressions/literal.rs | 19 +++- datafusion/physical-plan/src/common.rs | 3 + datafusion/physical-plan/src/filter.rs | 2 + .../physical-plan/src/joins/cross_join.rs | 14 +++ datafusion/physical-plan/src/joins/utils.rs | 4 + datafusion/physical-plan/src/projection.rs | 31 +++---- datafusion/physical-plan/src/union.rs | 24 +++-- datafusion/physical-plan/src/values.rs | 1 + .../proto/datafusion_common.proto | 1 + datafusion/proto-common/src/from_proto/mod.rs | 5 ++ .../proto-common/src/generated/pbjson.rs | 18 ++++ .../proto-common/src/generated/prost.rs | 89 +++++++++---------- datafusion/proto-common/src/to_proto/mod.rs | 1 + 22 files changed, 284 insertions(+), 92 deletions(-) diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index d2ce965c5c49..3bb7070a6738 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -21,7 +21,7 @@ use std::fmt::{self, Debug, Display}; use crate::{Result, ScalarValue}; -use arrow_schema::{Schema, SchemaRef}; +use arrow_schema::{DataType, Schema, SchemaRef}; /// Represents a value with a degree of certainty. `Precision` is used to /// propagate information the precision of statistical values. 
@@ -170,24 +170,57 @@ impl Precision { pub fn add(&self, other: &Precision) -> Precision { match (self, other) { (Precision::Exact(a), Precision::Exact(b)) => { - if let Ok(result) = a.add(b) { - Precision::Exact(result) - } else { - Precision::Absent - } + a.add(b).map(Precision::Exact).unwrap_or(Precision::Absent) } (Precision::Inexact(a), Precision::Exact(b)) | (Precision::Exact(a), Precision::Inexact(b)) | (Precision::Inexact(a), Precision::Inexact(b)) => { - if let Ok(result) = a.add(b) { - Precision::Inexact(result) - } else { - Precision::Absent - } + a.add(b).map(Precision::Inexact).unwrap_or(Precision::Absent) } (_, _) => Precision::Absent, } } + + /// Calculates the difference of two (possibly inexact) [`ScalarValue`] values, + /// conservatively propagating exactness information. If one of the input + /// values is [`Precision::Absent`], the result is `Absent` too. + pub fn sub(&self, other: &Precision) -> Precision { + match (self, other) { + (Precision::Exact(a), Precision::Exact(b)) => { + a.sub(b).map(Precision::Exact).unwrap_or(Precision::Absent) + } + (Precision::Inexact(a), Precision::Exact(b)) + | (Precision::Exact(a), Precision::Inexact(b)) + | (Precision::Inexact(a), Precision::Inexact(b)) => { + a.sub(b).map(Precision::Inexact).unwrap_or(Precision::Absent) + } + (_, _) => Precision::Absent, + } + } + + /// Calculates the multiplication of two (possibly inexact) [`ScalarValue`] values, + /// conservatively propagating exactness information. If one of the input + /// values is [`Precision::Absent`], the result is `Absent` too. 
+ pub fn multiply(&self, other: &Precision) -> Precision { + match (self, other) { + (Precision::Exact(a), Precision::Exact(b)) => a.mul_checked(b) + .map(Precision::Exact) + .unwrap_or(Precision::Absent), + (Precision::Inexact(a), Precision::Exact(b)) + | (Precision::Exact(a), Precision::Inexact(b)) + | (Precision::Inexact(a), Precision::Inexact(b)) => a.mul_checked(b).map(Precision::Inexact).unwrap_or(Precision::Absent), + (_, _) => Precision::Absent, + } + } + + /// Casts the value to the given data type, propagating exactness information. + pub fn cast_to(&self, data_type: &DataType) -> Result> { + match self { + Precision::Exact(value) => value.cast_to(data_type).map(Precision::Exact), + Precision::Inexact(value) => value.cast_to(data_type).map(Precision::Inexact), + Precision::Absent => Ok(Precision::Absent), + } + } } impl Debug for Precision { @@ -210,6 +243,16 @@ impl Display for Precision { } } +impl From> for Precision { + fn from(value: Precision) -> Self { + match value { + Precision::Exact(v) => Precision::Exact(ScalarValue::UInt64(Some(v as u64))), + Precision::Inexact(v) => Precision::Inexact(ScalarValue::UInt64(Some(v as u64))), + Precision::Absent => Precision::Absent, + } + } +} + /// Statistics for a relation /// Fields are optional and can be inexact because the sources /// sometimes provide approximate estimates for performance reasons @@ -401,6 +444,11 @@ impl Display for Statistics { } else { s }; + let s = if cs.sum_value != Precision::Absent { + format!("{} Sum={}", s, cs.sum_value) + } else { + s + }; let s = if cs.null_count != Precision::Absent { format!("{} Null={}", s, cs.null_count) } else { @@ -436,6 +484,8 @@ pub struct ColumnStatistics { pub max_value: Precision, /// Minimum value of column pub min_value: Precision, + /// Sum value of a column + pub sum_value: Precision, /// Number of distinct values pub distinct_count: Precision, } @@ -458,6 +508,7 @@ impl ColumnStatistics { null_count: Precision::Absent, max_value: 
Precision::Absent, min_value: Precision::Absent, + sum_value: Precision::Absent, distinct_count: Precision::Absent, } } @@ -469,6 +520,7 @@ impl ColumnStatistics { self.null_count = self.null_count.to_inexact(); self.max_value = self.max_value.to_inexact(); self.min_value = self.min_value.to_inexact(); + self.sum_value = self.sum_value.to_inexact(); self.distinct_count = self.distinct_count.to_inexact(); self } @@ -646,6 +698,7 @@ mod tests { null_count: Precision::Exact(null_count), max_value: Precision::Exact(ScalarValue::Int64(Some(42))), min_value: Precision::Exact(ScalarValue::Int64(Some(64))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(4600))), distinct_count: Precision::Exact(100), } } diff --git a/datafusion/core/src/datasource/statistics.rs b/datafusion/core/src/datasource/statistics.rs index 201bbfd5c007..f81e7bb916de 100644 --- a/datafusion/core/src/datasource/statistics.rs +++ b/datafusion/core/src/datasource/statistics.rs @@ -73,9 +73,7 @@ pub async fn get_statistics_with_limit( for (index, file_column) in file_stats.column_statistics.clone().into_iter().enumerate() { - col_stats_set[index].null_count = file_column.null_count; - col_stats_set[index].max_value = file_column.max_value; - col_stats_set[index].min_value = file_column.min_value; + col_stats_set[index] = file_column; } // If the number of rows exceeds the limit, we can stop processing @@ -113,12 +111,14 @@ pub async fn get_statistics_with_limit( null_count: file_nc, max_value: file_max, min_value: file_min, + sum_value: file_sum, distinct_count: _, } = file_col_stats; col_stats.null_count = add_row_stats(*file_nc, col_stats.null_count); set_max_if_greater(file_max, &mut col_stats.max_value); - set_min_if_lesser(file_min, &mut col_stats.min_value) + set_min_if_lesser(file_min, &mut col_stats.min_value); + col_stats.sum_value = file_sum.add(&col_stats.sum_value); } // If the number of rows exceeds the limit, we can stop processing @@ -204,6 +204,7 @@ pub(crate) fn get_col_stats( 
null_count: null_counts[i], max_value: max_value.map(Precision::Exact).unwrap_or(Precision::Absent), min_value: min_value.map(Precision::Exact).unwrap_or(Precision::Absent), + sum_value: Precision::Absent, distinct_count: Precision::Absent, } }) diff --git a/datafusion/core/tests/custom_sources_cases/statistics.rs b/datafusion/core/tests/custom_sources_cases/statistics.rs index 41d182a3767b..b937b505bbda 100644 --- a/datafusion/core/tests/custom_sources_cases/statistics.rs +++ b/datafusion/core/tests/custom_sources_cases/statistics.rs @@ -200,12 +200,14 @@ fn fully_defined() -> (Statistics, Schema) { distinct_count: Precision::Exact(2), max_value: Precision::Exact(ScalarValue::Int32(Some(1023))), min_value: Precision::Exact(ScalarValue::Int32(Some(-24))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(10))), null_count: Precision::Exact(0), }, ColumnStatistics { distinct_count: Precision::Exact(13), max_value: Precision::Exact(ScalarValue::Int64(Some(5486))), min_value: Precision::Exact(ScalarValue::Int64(Some(-6783))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(10))), null_count: Precision::Exact(5), }, ], diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 56c9822495f8..4d8b19f09b02 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -96,6 +96,7 @@ impl fmt::Display for AggregateUDF { } /// Arguments passed to [`AggregateUDFImpl::value_from_stats`] +#[derive(Debug)] pub struct StatisticsArgs<'a> { /// The statistics of the aggregate input pub statistics: &'a Statistics, diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index 6c2854f6bc24..2f5ff35ad4ad 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -34,13 +34,11 @@ use arrow::datatypes::{ }; use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; +use 
datafusion_common::stats::Precision; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::utils::format_state_name; -use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF, - Signature, Volatility, -}; +use datafusion_expr::{Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF, Signature, StatisticsArgs, Volatility}; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; use datafusion_functions_aggregate_common::utils::Hashable; use datafusion_macros::user_doc; @@ -254,6 +252,31 @@ impl AggregateUDFImpl for Sum { fn documentation(&self) -> Option<&Documentation> { self.doc() } + + /// Answers `SUM` from statistics: a typed NULL for an exactly-empty input, otherwise the + /// argument's exact `sum_value` cast to the return type; `None` when neither applies. + fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option { + if let Precision::Exact(num_rows) = &statistics_args.statistics.num_rows { + match *num_rows { + 0 => return ScalarValue::try_from(statistics_args.return_type).ok(), + _ => { + if statistics_args.exprs.len() == 1 { + let sum_value = if let Precision::Exact(sum) = statistics_args.exprs[0] + .column_statistics(&statistics_args.statistics) + .ok() + ?.sum_value { + + sum.cast_to(statistics_args.return_type).ok() + } else { + None + }; + return sum_value; + } + } + } + } + None + } } /// This accumulator computes SUM incrementally diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index c2e892d63da0..4b025c5b79e0 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -26,7 +26,7 @@ use arrow::array::BooleanArray; use arrow::compute::filter_record_batch; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::{internal_err, not_impl_err, 
Result}; +use datafusion_common::{internal_err, not_impl_err, ColumnStatistics, Result, Statistics}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; use datafusion_expr_common::sort_properties::ExprProperties; @@ -149,6 +149,11 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash { fn get_properties(&self, _children: &[ExprProperties]) -> Result { Ok(ExprProperties::new_unknown()) } + + /// Return the column statistics of this expression given the statistics of the input + fn column_statistics(&self, _statistics: &Statistics) -> Result { + Ok(ColumnStatistics::new_unknown()) + } } /// [`PhysicalExpr`] can't be constrained by [`Eq`] directly because it must remain object diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index ae2bfe5b0bd4..7027e6d435bc 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -32,7 +32,7 @@ use arrow::compute::{cast, ilike, like, nilike, nlike}; use arrow::datatypes::*; use arrow_schema::ArrowError; use datafusion_common::cast::as_boolean_array; -use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_common::{internal_err, ColumnStatistics, Result, ScalarValue, Statistics}; use datafusion_expr::interval_arithmetic::{apply_operator, Interval}; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::binary::get_result_type; @@ -535,6 +535,23 @@ impl PhysicalExpr for BinaryExpr { _ => Ok(ExprProperties::new_unknown()), } } + + fn column_statistics(&self, statistics: &Statistics) -> Result { + let mut col_stats = ColumnStatistics::new_unknown(); + + // Propagate the sum statistic for numeric operators + if matches!(self.op, Operator::Plus | Operator::Minus) { + let left = self.left.column_statistics(statistics)?; + let right = 
self.right.column_statistics(statistics)?; + match self.op { + Operator::Plus => col_stats.sum_value = left.sum_value.add(&right.sum_value), + Operator::Minus => col_stats.sum_value = left.sum_value.sub(&right.sum_value), + _ => {} + } + } + + Ok(col_stats) + } } /// Casts dictionary array to result type for binary numerical operators. Such operators diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 7eda5fb4beaa..aa3973327da5 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -26,7 +26,7 @@ use arrow::compute::{can_cast_types, CastOptions}; use arrow::datatypes::{DataType, DataType::*, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; -use datafusion_common::{not_impl_err, Result}; +use datafusion_common::{not_impl_err, ColumnStatistics, Result, Statistics}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; use datafusion_expr_common::sort_properties::ExprProperties; @@ -194,6 +194,18 @@ impl PhysicalExpr for CastExpr { Ok(ExprProperties::new_unknown().with_range(unbounded)) } } + + fn column_statistics(&self, statistics: &Statistics) -> Result { + let child_stats = self.expr.column_statistics(statistics)?; + Ok(ColumnStatistics { + null_count: child_stats.null_count, + max_value: child_stats.max_value.cast_to(&self.cast_type)?, + min_value: child_stats.min_value.cast_to(&self.cast_type)?, + // The sum value may be a wider numeric type than the data value, it's not safe to cast. 
+ sum_value: child_stats.sum_value, + distinct_count: Default::default(), + }) + } } /// Return a PhysicalExpression representing `expr` casted to diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index 5f6932f6d725..1930785d4a3e 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -28,7 +28,7 @@ use arrow::{ }; use arrow_schema::SchemaRef; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{internal_err, plan_err, Result}; +use datafusion_common::{internal_err, plan_err, ColumnStatistics, Result, Statistics}; use datafusion_expr::ColumnarValue; /// Represents the column at a given index in a RecordBatch @@ -138,6 +138,10 @@ impl PhysicalExpr for Column { ) -> Result> { Ok(self) } + + fn column_statistics(&self, statistics: &Statistics) -> Result { + Ok(statistics.column_statistics[self.index()].clone()) + } } impl Column { diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index f0d02eb605b2..a7cad9a3e352 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -27,7 +27,8 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use datafusion_common::{Result, ScalarValue}; +use datafusion_common::{ColumnStatistics, Result, ScalarValue, Statistics}; +use datafusion_common::stats::Precision; use datafusion_expr::Expr; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; @@ -92,6 +93,22 @@ impl PhysicalExpr for Literal { range: Interval::try_new(self.value().clone(), self.value().clone())?, }) } + + fn column_statistics(&self, statistics: &Statistics) -> Result { + Ok(ColumnStatistics { + null_count: Precision::Exact(if self.value.is_null() { 1 } else { 0 }), + max_value: 
Precision::Exact(self.value.clone()), + min_value: Precision::Exact(self.value.clone()), + sum_value: if self.value.data_type().is_numeric() { + Precision::::from(statistics.num_rows).cast_to(&self.value.data_type()) + .map(|num_rows| Precision::Exact(self.value.clone()).multiply(&num_rows)) + .unwrap_or(Precision::Absent) + } else { + Precision::Absent + }, + distinct_count: Precision::Exact(1), + }) + } } /// Create a literal expression diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index aefb90d1d1b7..20a4e89dba94 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -333,12 +333,14 @@ mod tests { distinct_count: Precision::Absent, max_value: Precision::Absent, min_value: Precision::Absent, + sum_value: Precision::Absent, null_count: Precision::Exact(0), }, ColumnStatistics { distinct_count: Precision::Absent, max_value: Precision::Absent, min_value: Precision::Absent, + sum_value: Precision::Absent, null_count: Precision::Exact(0), }, ], @@ -371,6 +373,7 @@ mod tests { distinct_count: Precision::Absent, max_value: Precision::Absent, min_value: Precision::Absent, + sum_value: Precision::Absent, null_count: Precision::Exact(3), }], }; diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 07898e8d22d8..f15336f0edb1 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -414,6 +414,7 @@ fn collect_new_statistics( null_count: input_column_stats[idx].null_count.to_inexact(), max_value, min_value, + sum_value: Precision::Absent, distinct_count: distinct_count.to_inexact(), } }, @@ -1132,6 +1133,7 @@ mod tests { null_count: Precision::Absent, min_value: Precision::Inexact(ScalarValue::Int32(Some(5))), max_value: Precision::Inexact(ScalarValue::Int32(Some(10))), + sum_value: Precision::Absent, distinct_count: Precision::Absent, }], }; diff --git 
a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 8bf675e87362..f10c705f49b8 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -354,12 +354,14 @@ fn stats_cartesian_product( distinct_count: s.distinct_count, min_value: s.min_value, max_value: s.max_value, + sum_value: s.sum_value.multiply(&right_row_count.into()), }) .chain(right_col_stats.into_iter().map(|s| ColumnStatistics { null_count: s.null_count.multiply(&left_row_count), distinct_count: s.distinct_count, min_value: s.min_value, max_value: s.max_value, + sum_value: s.sum_value.multiply(&left_row_count.into()), })) .collect(); @@ -593,12 +595,14 @@ mod tests { distinct_count: Precision::Exact(5), max_value: Precision::Exact(ScalarValue::Int64(Some(21))), min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), null_count: Precision::Exact(0), }, ColumnStatistics { distinct_count: Precision::Exact(1), max_value: Precision::Exact(ScalarValue::from("x")), min_value: Precision::Exact(ScalarValue::from("a")), + sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), null_count: Precision::Exact(3), }, ], @@ -611,6 +615,7 @@ mod tests { distinct_count: Precision::Exact(3), max_value: Precision::Exact(ScalarValue::Int64(Some(12))), min_value: Precision::Exact(ScalarValue::Int64(Some(0))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(20))), null_count: Precision::Exact(2), }], }; @@ -625,18 +630,21 @@ mod tests { distinct_count: Precision::Exact(5), max_value: Precision::Exact(ScalarValue::Int64(Some(21))), min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(42 * 20))), null_count: Precision::Exact(0), }, ColumnStatistics { distinct_count: Precision::Exact(1), max_value: Precision::Exact(ScalarValue::from("x")), min_value: 
Precision::Exact(ScalarValue::from("a")), + sum_value: Precision::Absent, null_count: Precision::Exact(3 * right_row_count), }, ColumnStatistics { distinct_count: Precision::Exact(3), max_value: Precision::Exact(ScalarValue::Int64(Some(12))), min_value: Precision::Exact(ScalarValue::Int64(Some(0))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(42 * 20))), null_count: Precision::Exact(2 * left_row_count), }, ], @@ -657,12 +665,14 @@ mod tests { distinct_count: Precision::Exact(5), max_value: Precision::Exact(ScalarValue::Int64(Some(21))), min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(42 * 20))), null_count: Precision::Exact(0), }, ColumnStatistics { distinct_count: Precision::Exact(1), max_value: Precision::Exact(ScalarValue::from("x")), min_value: Precision::Exact(ScalarValue::from("a")), + sum_value: Precision::Absent, null_count: Precision::Exact(3), }, ], @@ -675,6 +685,7 @@ mod tests { distinct_count: Precision::Exact(3), max_value: Precision::Exact(ScalarValue::Int64(Some(12))), min_value: Precision::Exact(ScalarValue::Int64(Some(0))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), null_count: Precision::Exact(2), }], }; @@ -689,18 +700,21 @@ mod tests { distinct_count: Precision::Exact(5), max_value: Precision::Exact(ScalarValue::Int64(Some(21))), min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), null_count: Precision::Absent, // we don't know the row count on the right }, ColumnStatistics { distinct_count: Precision::Exact(1), max_value: Precision::Exact(ScalarValue::from("x")), min_value: Precision::Exact(ScalarValue::from("a")), + sum_value: Precision::Absent, null_count: Precision::Absent, // we don't know the row count on the right }, ColumnStatistics { distinct_count: Precision::Exact(3), max_value: Precision::Exact(ScalarValue::Int64(Some(12))), min_value: 
Precision::Exact(ScalarValue::Int64(Some(0))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), null_count: Precision::Exact(2 * left_row_count), }, ], diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 0366c9fa5e46..f56ff96702c6 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -1989,6 +1989,8 @@ mod tests { fn create_column_stats( min: Precision, max: Precision, + // FIXME(ngates): add these to test cases + // sum: Precision, distinct_count: Precision, null_count: Precision, ) -> ColumnStatistics { @@ -1996,6 +1998,8 @@ mod tests { distinct_count, min_value: min.map(ScalarValue::from), max_value: max.map(ScalarValue::from), + // sum_value: sum.map(ScalarValue::from), + sum_value: Absent, null_count, } } diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index c1d3f368366f..b46d5cfaa40e 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -32,7 +32,7 @@ use super::{ DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, }; -use crate::{ColumnStatistics, DisplayFormatType, ExecutionPlan, PhysicalExpr}; +use crate::{DisplayFormatType, ExecutionPlan, PhysicalExpr}; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::{RecordBatch, RecordBatchOptions}; @@ -224,11 +224,11 @@ impl ExecutionPlan for ProjectionExec { } fn statistics(&self) -> Result { - Ok(stats_projection( + stats_projection( self.input.statistics()?, self.expr.iter().map(|(e, _)| Arc::clone(e)), Arc::clone(&self.schema), - )) + ) } fn supports_limit_pushdown(&self) -> bool { @@ -262,18 +262,12 @@ fn stats_projection( mut stats: Statistics, exprs: impl Iterator>, schema: SchemaRef, -) -> Statistics { +) -> Result { let mut primitive_row_size = 0; let mut primitive_row_size_possible = true; let mut 
column_statistics = vec![]; for expr in exprs { - let col_stats = if let Some(col) = expr.as_any().downcast_ref::() { - stats.column_statistics[col.index()].clone() - } else { - // TODO stats: estimate more statistics from expressions - // (expressions should compute their statistics themselves) - ColumnStatistics::new_unknown() - }; + let col_stats = expr.column_statistics(&stats)?; column_statistics.push(col_stats); if let Ok(data_type) = expr.data_type(&schema) { if let Some(value) = data_type.primitive_width() { @@ -289,7 +283,7 @@ fn stats_projection( Precision::Exact(primitive_row_size).multiply(&stats.num_rows); } stats.column_statistics = column_statistics; - stats + Ok(stats) } impl ProjectionStream { @@ -359,7 +353,7 @@ mod tests { use crate::test; use arrow_schema::DataType; - use datafusion_common::ScalarValue; + use datafusion_common::{ScalarValue, ColumnStatistics}; #[tokio::test] async fn project_no_column() -> Result<()> { @@ -387,18 +381,21 @@ mod tests { distinct_count: Precision::Exact(5), max_value: Precision::Exact(ScalarValue::Int64(Some(21))), min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), null_count: Precision::Exact(0), }, ColumnStatistics { distinct_count: Precision::Exact(1), max_value: Precision::Exact(ScalarValue::from("x")), min_value: Precision::Exact(ScalarValue::from("a")), + sum_value: Precision::Absent, null_count: Precision::Exact(3), }, ColumnStatistics { distinct_count: Precision::Absent, max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))), min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))), + sum_value: Precision::Exact(ScalarValue::Float32(Some(5.5))), null_count: Precision::Absent, }, ], @@ -421,7 +418,7 @@ mod tests { Arc::new(Column::new("col0", 0)), ]; - let result = stats_projection(source, exprs.into_iter(), Arc::new(schema)); + let result = stats_projection(source, exprs.into_iter(), Arc::new(schema)).unwrap(); let expected = 
Statistics { num_rows: Precision::Exact(5), @@ -431,12 +428,14 @@ mod tests { distinct_count: Precision::Exact(1), max_value: Precision::Exact(ScalarValue::from("x")), min_value: Precision::Exact(ScalarValue::from("a")), + sum_value: Precision::Absent, null_count: Precision::Exact(3), }, ColumnStatistics { distinct_count: Precision::Exact(5), max_value: Precision::Exact(ScalarValue::Int64(Some(21))), min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), null_count: Precision::Exact(0), }, ], @@ -455,7 +454,7 @@ mod tests { Arc::new(Column::new("col0", 0)), ]; - let result = stats_projection(source, exprs.into_iter(), Arc::new(schema)); + let result = stats_projection(source, exprs.into_iter(), Arc::new(schema)).unwrap(); let expected = Statistics { num_rows: Precision::Exact(5), @@ -465,12 +464,14 @@ mod tests { distinct_count: Precision::Absent, max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))), min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))), + sum_value: Precision::Exact(ScalarValue::Float32(Some(5.5))), null_count: Precision::Absent, }, ColumnStatistics { distinct_count: Precision::Exact(5), max_value: Precision::Exact(ScalarValue::Int64(Some(21))), min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), null_count: Precision::Exact(0), }, ], diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index bd36753880eb..24c6f062c687 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -573,15 +573,16 @@ impl Stream for CombinedRecordBatchStream { } fn col_stats_union( - mut left: ColumnStatistics, + left: ColumnStatistics, right: ColumnStatistics, ) -> ColumnStatistics { - left.distinct_count = Precision::Absent; - left.min_value = left.min_value.min(&right.min_value); - left.max_value = left.max_value.max(&right.max_value); - 
left.null_count = left.null_count.add(&right.null_count); - - left + ColumnStatistics { + null_count: left.null_count.add(&right.null_count), + max_value: left.max_value.max(&right.max_value), + min_value: left.min_value.min(&right.min_value), + sum_value: left.sum_value.add(&right.sum_value), + distinct_count: Precision::Absent, + } } fn stats_union(mut left: Statistics, right: Statistics) -> Statistics { @@ -671,18 +672,21 @@ mod tests { distinct_count: Precision::Exact(5), max_value: Precision::Exact(ScalarValue::Int64(Some(21))), min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), null_count: Precision::Exact(0), }, ColumnStatistics { distinct_count: Precision::Exact(1), max_value: Precision::Exact(ScalarValue::from("x")), min_value: Precision::Exact(ScalarValue::from("a")), + sum_value: Precision::Absent, null_count: Precision::Exact(3), }, ColumnStatistics { distinct_count: Precision::Absent, max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))), min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))), + sum_value: Precision::Exact(ScalarValue::Float32(Some(42.0))), null_count: Precision::Absent, }, ], @@ -696,18 +700,21 @@ mod tests { distinct_count: Precision::Exact(3), max_value: Precision::Exact(ScalarValue::Int64(Some(34))), min_value: Precision::Exact(ScalarValue::Int64(Some(1))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), null_count: Precision::Exact(1), }, ColumnStatistics { distinct_count: Precision::Absent, max_value: Precision::Exact(ScalarValue::from("c")), min_value: Precision::Exact(ScalarValue::from("b")), + sum_value: Precision::Absent, null_count: Precision::Absent, }, ColumnStatistics { distinct_count: Precision::Absent, max_value: Precision::Absent, min_value: Precision::Absent, + sum_value: Precision::Absent, null_count: Precision::Absent, }, ], @@ -722,18 +729,21 @@ mod tests { distinct_count: Precision::Absent, max_value: 
Precision::Exact(ScalarValue::Int64(Some(34))), min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), null_count: Precision::Exact(1), }, ColumnStatistics { distinct_count: Precision::Absent, max_value: Precision::Exact(ScalarValue::from("x")), min_value: Precision::Exact(ScalarValue::from("a")), + sum_value: Precision::Absent, null_count: Precision::Absent, }, ColumnStatistics { distinct_count: Precision::Absent, max_value: Precision::Absent, min_value: Precision::Absent, + sum_value: Precision::Absent, null_count: Precision::Absent, }, ], diff --git a/datafusion/physical-plan/src/values.rs b/datafusion/physical-plan/src/values.rs index edadf98cb10c..152e48ed07bf 100644 --- a/datafusion/physical-plan/src/values.rs +++ b/datafusion/physical-plan/src/values.rs @@ -294,6 +294,7 @@ mod tests { distinct_count: Precision::Absent, max_value: Precision::Absent, min_value: Precision::Absent, + sum_value: Precision::Absent, },], } ); diff --git a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto index 2da8b6066742..6a68cd22d9fb 100644 --- a/datafusion/proto-common/proto/datafusion_common.proto +++ b/datafusion/proto-common/proto/datafusion_common.proto @@ -568,6 +568,7 @@ message Statistics { message ColumnStats { Precision min_value = 1; Precision max_value = 2; + Precision sum_value = 5; Precision null_count = 3; Precision distinct_count = 4; } diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index 14375c0590a4..4acec1ef54b3 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -694,6 +694,11 @@ impl From<&protobuf::ColumnStats> for ColumnStatistics { } else { Precision::Absent }, + sum_value: if let Some(sum) = &cs.sum_value { + sum.clone().into() + } else { + Precision::Absent + }, distinct_count: if let Some(dc) = 
&cs.distinct_count { dc.clone().into() } else { diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs index 6a75b14d35a8..4ec3703e1da5 100644 --- a/datafusion/proto-common/src/generated/pbjson.rs +++ b/datafusion/proto-common/src/generated/pbjson.rs @@ -985,6 +985,9 @@ impl serde::Serialize for ColumnStats { if self.max_value.is_some() { len += 1; } + if self.sum_value.is_some() { + len += 1; + } if self.null_count.is_some() { len += 1; } @@ -998,6 +1001,9 @@ impl serde::Serialize for ColumnStats { if let Some(v) = self.max_value.as_ref() { struct_ser.serialize_field("maxValue", v)?; } + if let Some(v) = self.sum_value.as_ref() { + struct_ser.serialize_field("sumValue", v)?; + } if let Some(v) = self.null_count.as_ref() { struct_ser.serialize_field("nullCount", v)?; } @@ -1018,6 +1024,8 @@ impl<'de> serde::Deserialize<'de> for ColumnStats { "minValue", "max_value", "maxValue", + "sum_value", + "sumValue", "null_count", "nullCount", "distinct_count", @@ -1028,6 +1036,7 @@ impl<'de> serde::Deserialize<'de> for ColumnStats { enum GeneratedField { MinValue, MaxValue, + SumValue, NullCount, DistinctCount, } @@ -1053,6 +1062,7 @@ impl<'de> serde::Deserialize<'de> for ColumnStats { match value { "minValue" | "min_value" => Ok(GeneratedField::MinValue), "maxValue" | "max_value" => Ok(GeneratedField::MaxValue), + "sumValue" | "sum_value" => Ok(GeneratedField::SumValue), "nullCount" | "null_count" => Ok(GeneratedField::NullCount), "distinctCount" | "distinct_count" => Ok(GeneratedField::DistinctCount), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), @@ -1076,6 +1086,7 @@ impl<'de> serde::Deserialize<'de> for ColumnStats { { let mut min_value__ = None; let mut max_value__ = None; + let mut sum_value__ = None; let mut null_count__ = None; let mut distinct_count__ = None; while let Some(k) = map_.next_key()? 
{ @@ -1092,6 +1103,12 @@ impl<'de> serde::Deserialize<'de> for ColumnStats { } max_value__ = map_.next_value()?; } + GeneratedField::SumValue => { + if sum_value__.is_some() { + return Err(serde::de::Error::duplicate_field("sumValue")); + } + sum_value__ = map_.next_value()?; + } GeneratedField::NullCount => { if null_count__.is_some() { return Err(serde::de::Error::duplicate_field("nullCount")); @@ -1109,6 +1126,7 @@ impl<'de> serde::Deserialize<'de> for ColumnStats { Ok(ColumnStats { min_value: min_value__, max_value: max_value__, + sum_value: sum_value__, null_count: null_count__, distinct_count: distinct_count__, }) diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs index 50a3cff5f568..6cc7652265c0 100644 --- a/datafusion/proto-common/src/generated/prost.rs +++ b/datafusion/proto-common/src/generated/prost.rs @@ -638,27 +638,33 @@ pub struct ParquetColumnSpecificOptions { #[derive(Clone, PartialEq, ::prost::Message)] pub struct ParquetColumnOptions { #[prost(oneof = "parquet_column_options::BloomFilterEnabledOpt", tags = "1")] - pub bloom_filter_enabled_opt: - ::core::option::Option, + pub bloom_filter_enabled_opt: ::core::option::Option< + parquet_column_options::BloomFilterEnabledOpt, + >, #[prost(oneof = "parquet_column_options::EncodingOpt", tags = "2")] pub encoding_opt: ::core::option::Option, #[prost(oneof = "parquet_column_options::DictionaryEnabledOpt", tags = "3")] - pub dictionary_enabled_opt: - ::core::option::Option, + pub dictionary_enabled_opt: ::core::option::Option< + parquet_column_options::DictionaryEnabledOpt, + >, #[prost(oneof = "parquet_column_options::CompressionOpt", tags = "4")] pub compression_opt: ::core::option::Option, #[prost(oneof = "parquet_column_options::StatisticsEnabledOpt", tags = "5")] - pub statistics_enabled_opt: - ::core::option::Option, + pub statistics_enabled_opt: ::core::option::Option< + parquet_column_options::StatisticsEnabledOpt, + >, #[prost(oneof = 
"parquet_column_options::BloomFilterFppOpt", tags = "6")] - pub bloom_filter_fpp_opt: - ::core::option::Option, + pub bloom_filter_fpp_opt: ::core::option::Option< + parquet_column_options::BloomFilterFppOpt, + >, #[prost(oneof = "parquet_column_options::BloomFilterNdvOpt", tags = "7")] - pub bloom_filter_ndv_opt: - ::core::option::Option, + pub bloom_filter_ndv_opt: ::core::option::Option< + parquet_column_options::BloomFilterNdvOpt, + >, #[prost(oneof = "parquet_column_options::MaxStatisticsSizeOpt", tags = "8")] - pub max_statistics_size_opt: - ::core::option::Option, + pub max_statistics_size_opt: ::core::option::Option< + parquet_column_options::MaxStatisticsSizeOpt, + >, } /// Nested message and enum types in `ParquetColumnOptions`. pub mod parquet_column_options { @@ -763,22 +769,27 @@ pub struct ParquetOptions { #[prost(string, tag = "16")] pub created_by: ::prost::alloc::string::String, #[prost(oneof = "parquet_options::MetadataSizeHintOpt", tags = "4")] - pub metadata_size_hint_opt: - ::core::option::Option, + pub metadata_size_hint_opt: ::core::option::Option< + parquet_options::MetadataSizeHintOpt, + >, #[prost(oneof = "parquet_options::CompressionOpt", tags = "10")] pub compression_opt: ::core::option::Option, #[prost(oneof = "parquet_options::DictionaryEnabledOpt", tags = "11")] - pub dictionary_enabled_opt: - ::core::option::Option, + pub dictionary_enabled_opt: ::core::option::Option< + parquet_options::DictionaryEnabledOpt, + >, #[prost(oneof = "parquet_options::StatisticsEnabledOpt", tags = "13")] - pub statistics_enabled_opt: - ::core::option::Option, + pub statistics_enabled_opt: ::core::option::Option< + parquet_options::StatisticsEnabledOpt, + >, #[prost(oneof = "parquet_options::MaxStatisticsSizeOpt", tags = "14")] - pub max_statistics_size_opt: - ::core::option::Option, + pub max_statistics_size_opt: ::core::option::Option< + parquet_options::MaxStatisticsSizeOpt, + >, #[prost(oneof = "parquet_options::ColumnIndexTruncateLengthOpt", tags = 
"17")] - pub column_index_truncate_length_opt: - ::core::option::Option, + pub column_index_truncate_length_opt: ::core::option::Option< + parquet_options::ColumnIndexTruncateLengthOpt, + >, #[prost(oneof = "parquet_options::EncodingOpt", tags = "19")] pub encoding_opt: ::core::option::Option, #[prost(oneof = "parquet_options::BloomFilterFppOpt", tags = "21")] @@ -856,14 +867,14 @@ pub struct ColumnStats { pub min_value: ::core::option::Option, #[prost(message, optional, tag = "2")] pub max_value: ::core::option::Option, + #[prost(message, optional, tag = "5")] + pub sum_value: ::core::option::Option, #[prost(message, optional, tag = "3")] pub null_count: ::core::option::Option, #[prost(message, optional, tag = "4")] pub distinct_count: ::core::option::Option, } -#[derive( - Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration, -)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum JoinType { Inner = 0, @@ -910,9 +921,7 @@ impl JoinType { } } } -#[derive( - Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration, -)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum JoinConstraint { On = 0, @@ -938,9 +947,7 @@ impl JoinConstraint { } } } -#[derive( - Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration, -)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum TimeUnit { Second = 0, @@ -972,9 +979,7 @@ impl TimeUnit { } } } -#[derive( - Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration, -)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum IntervalUnit { YearMonth = 0, @@ -1003,9 +1008,7 @@ impl IntervalUnit { } } } -#[derive( - Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration, -)] 
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum UnionMode { Sparse = 0, @@ -1031,9 +1034,7 @@ impl UnionMode { } } } -#[derive( - Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration, -)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum CompressionTypeVariant { Gzip = 0, @@ -1068,9 +1069,7 @@ impl CompressionTypeVariant { } } } -#[derive( - Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration, -)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum JoinSide { LeftSide = 0, @@ -1099,9 +1098,7 @@ impl JoinSide { } } } -#[derive( - Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration, -)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum PrecisionInfo { Exact = 0, diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index 1b9583516ced..8fa1c0dd2ed1 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -748,6 +748,7 @@ impl From<&ColumnStatistics> for protobuf::ColumnStats { protobuf::ColumnStats { min_value: Some(protobuf::Precision::from(&s.min_value)), max_value: Some(protobuf::Precision::from(&s.max_value)), + sum_value: Some(protobuf::Precision::from(&s.sum_value)), null_count: Some(protobuf::Precision::from(&s.null_count)), distinct_count: Some(protobuf::Precision::from(&s.distinct_count)), } From 692187ad1ca58630bd192c4ab64a9f860a905fff Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 11 Dec 2024 17:53:08 +0000 Subject: [PATCH 2/7] Fix tests --- datafusion/common/src/stats.rs | 26 ++++++--- datafusion/functions-aggregate/src/sum.rs | 29 +++++----- .../physical-expr-common/src/physical_expr.rs | 4 +- 
.../physical-expr/src/expressions/binary.rs | 12 +++- .../physical-expr/src/expressions/literal.rs | 9 ++- .../physical-plan/src/joins/cross_join.rs | 56 ++++++++++++++----- datafusion/physical-plan/src/joins/utils.rs | 3 - datafusion/physical-plan/src/projection.rs | 8 ++- datafusion/physical-plan/src/union.rs | 7 +-- 9 files changed, 100 insertions(+), 54 deletions(-) diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index 3bb7070a6738..3c7ed0af1a1b 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -174,9 +174,10 @@ impl Precision { } (Precision::Inexact(a), Precision::Exact(b)) | (Precision::Exact(a), Precision::Inexact(b)) - | (Precision::Inexact(a), Precision::Inexact(b)) => { - a.add(b).map(Precision::Inexact).unwrap_or(Precision::Absent) - } + | (Precision::Inexact(a), Precision::Inexact(b)) => a + .add(b) + .map(Precision::Inexact) + .unwrap_or(Precision::Absent), (_, _) => Precision::Absent, } } @@ -191,9 +192,10 @@ impl Precision { } (Precision::Inexact(a), Precision::Exact(b)) | (Precision::Exact(a), Precision::Inexact(b)) - | (Precision::Inexact(a), Precision::Inexact(b)) => { - a.add(b).map(Precision::Inexact).unwrap_or(Precision::Absent) - } + | (Precision::Inexact(a), Precision::Inexact(b)) => a + .add(b) + .map(Precision::Inexact) + .unwrap_or(Precision::Absent), (_, _) => Precision::Absent, } } @@ -203,12 +205,16 @@ impl Precision { /// values is [`Precision::Absent`], the result is `Absent` too. 
pub fn multiply(&self, other: &Precision) -> Precision { match (self, other) { - (Precision::Exact(a), Precision::Exact(b)) => a.mul_checked(b) + (Precision::Exact(a), Precision::Exact(b)) => a + .mul_checked(b) .map(Precision::Exact) .unwrap_or(Precision::Absent), (Precision::Inexact(a), Precision::Exact(b)) | (Precision::Exact(a), Precision::Inexact(b)) - | (Precision::Inexact(a), Precision::Inexact(b)) => a.mul_checked(b).map(Precision::Inexact).unwrap_or(Precision::Absent), + | (Precision::Inexact(a), Precision::Inexact(b)) => a + .mul_checked(b) + .map(Precision::Inexact) + .unwrap_or(Precision::Absent), (_, _) => Precision::Absent, } } @@ -247,7 +253,9 @@ impl From> for Precision { fn from(value: Precision) -> Self { match value { Precision::Exact(v) => Precision::Exact(ScalarValue::UInt64(Some(v as u64))), - Precision::Inexact(v) => Precision::Inexact(ScalarValue::UInt64(Some(v as u64))), + Precision::Inexact(v) => { + Precision::Inexact(ScalarValue::UInt64(Some(v as u64))) + } Precision::Absent => Precision::Absent, } } diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index 2f5ff35ad4ad..75423d302733 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -33,12 +33,15 @@ use arrow::datatypes::{ DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; use arrow::{array::ArrayRef, datatypes::Field}; -use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; use datafusion_common::stats::Precision; +use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::utils::format_state_name; -use datafusion_expr::{Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF, Signature, StatisticsArgs, Volatility}; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF, 
+ Signature, StatisticsArgs, Volatility, +}; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; use datafusion_functions_aggregate_common::utils::Hashable; use datafusion_macros::user_doc; @@ -254,23 +257,23 @@ impl AggregateUDFImpl for Sum { } fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option { - log::warn!("SUM STATS: {:#?}", statistics_args); + if statistics_args.is_distinct { + // Distinct sum cannot be inferred from statistics + return None; + } + if let Precision::Exact(num_rows) = &statistics_args.statistics.num_rows { match *num_rows { 0 => return ScalarValue::new_zero(statistics_args.return_type).ok(), _ => { if statistics_args.exprs.len() == 1 { - let sum_value = if let Precision::Exact(sum) = statistics_args.exprs[0] + if let Precision::Exact(sum) = statistics_args.exprs[0] .column_statistics(&statistics_args.statistics) - .ok() - ?.sum_value { - - sum.cast_to(statistics_args.return_type).ok() - } else { - None - }; - log::warn!("SUM STATS VALUE: {:?}", sum_value); - return sum_value; + .ok()? 
+ .sum_value + { + return sum.cast_to(statistics_args.return_type).ok(); + } } } } diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 4b025c5b79e0..3435a2345e65 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -26,7 +26,9 @@ use arrow::array::BooleanArray; use arrow::compute::filter_record_batch; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::{internal_err, not_impl_err, ColumnStatistics, Result, Statistics}; +use datafusion_common::{ + internal_err, not_impl_err, ColumnStatistics, Result, Statistics, +}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; use datafusion_expr_common::sort_properties::ExprProperties; diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 7027e6d435bc..35109c3803b2 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -32,7 +32,9 @@ use arrow::compute::{cast, ilike, like, nilike, nlike}; use arrow::datatypes::*; use arrow_schema::ArrowError; use datafusion_common::cast::as_boolean_array; -use datafusion_common::{internal_err, ColumnStatistics, Result, ScalarValue, Statistics}; +use datafusion_common::{ + internal_err, ColumnStatistics, Result, ScalarValue, Statistics, +}; use datafusion_expr::interval_arithmetic::{apply_operator, Interval}; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::binary::get_result_type; @@ -544,8 +546,12 @@ impl PhysicalExpr for BinaryExpr { let left = self.left.column_statistics(statistics)?; let right = self.right.column_statistics(statistics)?; match self.op { - Operator::Plus => col_stats.sum_value = left.sum_value.add(&right.sum_value), - Operator::Minus 
=> col_stats.sum_value = left.sum_value.sub(&right.sum_value), + Operator::Plus => { + col_stats.sum_value = left.sum_value.add(&right.sum_value) + } + Operator::Minus => { + col_stats.sum_value = left.sum_value.sub(&right.sum_value) + } _ => {} } } diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index a7cad9a3e352..16b3bb154403 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -27,8 +27,8 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use datafusion_common::{ColumnStatistics, Result, ScalarValue, Statistics}; use datafusion_common::stats::Precision; +use datafusion_common::{ColumnStatistics, Result, ScalarValue, Statistics}; use datafusion_expr::Expr; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; @@ -100,8 +100,11 @@ impl PhysicalExpr for Literal { max_value: Precision::Exact(self.value.clone()), min_value: Precision::Exact(self.value.clone()), sum_value: if self.value.data_type().is_numeric() { - Precision::::from(statistics.num_rows).cast_to(&self.value.data_type()) - .map(|num_rows| Precision::Exact(self.value.clone()).multiply(&num_rows)) + Precision::::from(statistics.num_rows) + .cast_to(&self.value.data_type()) + .map(|num_rows| { + Precision::Exact(self.value.clone()).multiply(&num_rows) + }) .unwrap_or(Precision::Absent) } else { Precision::Absent diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index f10c705f49b8..9ecd1a6e7f07 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -354,14 +354,36 @@ fn stats_cartesian_product( distinct_count: s.distinct_count, min_value: s.min_value, max_value: s.max_value, - sum_value: s.sum_value.multiply(&right_row_count.into()), + sum_value: s + 
.sum_value + .get_value() + // Cast the row count into the same type as any existing sum value + .and_then(|v| { + Precision::::from(right_row_count) + .cast_to(&v.data_type()) + .ok() + }) + .map(|row_count| s.sum_value.multiply(&row_count)) + .unwrap_or(Precision::Absent), }) - .chain(right_col_stats.into_iter().map(|s| ColumnStatistics { - null_count: s.null_count.multiply(&left_row_count), - distinct_count: s.distinct_count, - min_value: s.min_value, - max_value: s.max_value, - sum_value: s.sum_value.multiply(&left_row_count.into()), + .chain(right_col_stats.into_iter().map(|s| { + ColumnStatistics { + null_count: s.null_count.multiply(&left_row_count), + distinct_count: s.distinct_count, + min_value: s.min_value, + max_value: s.max_value, + sum_value: s + .sum_value + .get_value() + // Cast the row count into the same type as any existing sum value + .and_then(|v| { + Precision::::from(left_row_count) + .cast_to(&v.data_type()) + .ok() + }) + .map(|row_count| s.sum_value.multiply(&row_count)) + .unwrap_or(Precision::Absent), + } })) .collect(); @@ -602,7 +624,7 @@ mod tests { distinct_count: Precision::Exact(1), max_value: Precision::Exact(ScalarValue::from("x")), min_value: Precision::Exact(ScalarValue::from("a")), - sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), + sum_value: Precision::Absent, null_count: Precision::Exact(3), }, ], @@ -630,7 +652,9 @@ mod tests { distinct_count: Precision::Exact(5), max_value: Precision::Exact(ScalarValue::Int64(Some(21))), min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), - sum_value: Precision::Exact(ScalarValue::Int64(Some(42 * 20))), + sum_value: Precision::Exact(ScalarValue::Int64(Some( + 42 * right_row_count as i64, + ))), null_count: Precision::Exact(0), }, ColumnStatistics { @@ -644,7 +668,9 @@ mod tests { distinct_count: Precision::Exact(3), max_value: Precision::Exact(ScalarValue::Int64(Some(12))), min_value: Precision::Exact(ScalarValue::Int64(Some(0))), - sum_value: 
Precision::Exact(ScalarValue::Int64(Some(42 * 20))), + sum_value: Precision::Exact(ScalarValue::Int64(Some( + 20 * left_row_count as i64, + ))), null_count: Precision::Exact(2 * left_row_count), }, ], @@ -665,7 +691,7 @@ mod tests { distinct_count: Precision::Exact(5), max_value: Precision::Exact(ScalarValue::Int64(Some(21))), min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), - sum_value: Precision::Exact(ScalarValue::Int64(Some(42 * 20))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), null_count: Precision::Exact(0), }, ColumnStatistics { @@ -685,7 +711,7 @@ mod tests { distinct_count: Precision::Exact(3), max_value: Precision::Exact(ScalarValue::Int64(Some(12))), min_value: Precision::Exact(ScalarValue::Int64(Some(0))), - sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(20))), null_count: Precision::Exact(2), }], }; @@ -700,7 +726,7 @@ mod tests { distinct_count: Precision::Exact(5), max_value: Precision::Exact(ScalarValue::Int64(Some(21))), min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), - sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), + sum_value: Precision::Absent, // we don't know the row count on the right null_count: Precision::Absent, // we don't know the row count on the right }, ColumnStatistics { @@ -714,7 +740,9 @@ mod tests { distinct_count: Precision::Exact(3), max_value: Precision::Exact(ScalarValue::Int64(Some(12))), min_value: Precision::Exact(ScalarValue::Int64(Some(0))), - sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), + sum_value: Precision::Exact(ScalarValue::Int64(Some( + 20 * left_row_count as i64, + ))), null_count: Precision::Exact(2 * left_row_count), }, ], diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index f56ff96702c6..fca806c95b71 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -1989,8 +1989,6 @@ 
mod tests { fn create_column_stats( min: Precision, max: Precision, - // FIXME(ngates): add these to test cases - // sum: Precision, distinct_count: Precision, null_count: Precision, ) -> ColumnStatistics { @@ -1998,7 +1996,6 @@ mod tests { distinct_count, min_value: min.map(ScalarValue::from), max_value: max.map(ScalarValue::from), - // sum_value: sum.map(ScalarValue::from), sum_value: Absent, null_count, } diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index b46d5cfaa40e..d1e40f010bd3 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -353,7 +353,7 @@ mod tests { use crate::test; use arrow_schema::DataType; - use datafusion_common::{ScalarValue, ColumnStatistics}; + use datafusion_common::{ColumnStatistics, ScalarValue}; #[tokio::test] async fn project_no_column() -> Result<()> { @@ -418,7 +418,8 @@ mod tests { Arc::new(Column::new("col0", 0)), ]; - let result = stats_projection(source, exprs.into_iter(), Arc::new(schema)).unwrap(); + let result = + stats_projection(source, exprs.into_iter(), Arc::new(schema)).unwrap(); let expected = Statistics { num_rows: Precision::Exact(5), @@ -454,7 +455,8 @@ mod tests { Arc::new(Column::new("col0", 0)), ]; - let result = stats_projection(source, exprs.into_iter(), Arc::new(schema)).unwrap(); + let result = + stats_projection(source, exprs.into_iter(), Arc::new(schema)).unwrap(); let expected = Statistics { num_rows: Precision::Exact(5), diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 24c6f062c687..22e5baf95938 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -572,10 +572,7 @@ impl Stream for CombinedRecordBatchStream { } } -fn col_stats_union( - left: ColumnStatistics, - right: ColumnStatistics, -) -> ColumnStatistics { +fn col_stats_union(left: ColumnStatistics, right: ColumnStatistics) -> ColumnStatistics { 
ColumnStatistics { null_count: left.null_count.add(&right.null_count), max_value: left.max_value.max(&right.max_value), @@ -729,7 +726,7 @@ mod tests { distinct_count: Precision::Absent, max_value: Precision::Exact(ScalarValue::Int64(Some(34))), min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), - sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(84))), null_count: Precision::Exact(1), }, ColumnStatistics { From 3879ab154fd6363abfae68dcb9e9e6818b930667 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 11 Dec 2024 18:28:53 +0000 Subject: [PATCH 3/7] Fix tests --- datafusion/functions-aggregate/src/sum.rs | 40 +++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index 75423d302733..e4739e6022ca 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -483,3 +483,43 @@ impl Accumulator for DistinctSumAccumulator { size_of_val(self) + self.values.capacity() * size_of::() } } + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_common::{ColumnStatistics, Statistics}; + use datafusion_physical_expr::expressions::Column; + use std::sync::Arc; + + #[test] + fn sum() { + let agg = Box::new(Sum::new()); + let statistics = Statistics { + num_rows: Precision::Exact(5), + total_byte_size: Precision::Absent, + column_statistics: vec![ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + sum_value: Precision::Exact(ScalarValue::Int8(Some(10))), + distinct_count: Default::default(), + }], + }; + let mut statistics_args = StatisticsArgs { + statistics: &statistics, + return_type: &DataType::Int64, + is_distinct: false, + exprs: &[Arc::new(Column::new("a", 0))], + }; + + // Ensure that the sum statistic is used and cast to the return type. 
+ assert_eq!( + agg.value_from_stats(&statistics_args), + Some(ScalarValue::Int64(Some(10))) + ); + + // With a distinct aggregate, the sum statistic isn't helpful + statistics_args.is_distinct = true; + assert_eq!(agg.value_from_stats(&statistics_args), None); + } +} From caf6a487a4c478ca4ed44439617ae9c633a8330a Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 11 Dec 2024 18:58:27 +0000 Subject: [PATCH 4/7] Fix tests --- datafusion/functions-aggregate/src/average.rs | 73 ++++++++++++++++++- datafusion/functions-aggregate/src/count.rs | 21 +++--- 2 files changed, 81 insertions(+), 13 deletions(-) diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 18874f831e9d..c5da6a136146 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -34,7 +34,7 @@ use datafusion_expr::utils::format_state_name; use datafusion_expr::Volatility::Immutable; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator, - ReversedUDAF, Signature, + ReversedUDAF, Signature, StatisticsArgs, }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState; @@ -42,6 +42,8 @@ use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls: filtered_null_mask, set_nulls, }; +use crate::sum::Sum; +use datafusion_common::stats::Precision; use datafusion_functions_aggregate_common::utils::DecimalAverager; use datafusion_macros::user_doc; use log::debug; @@ -253,6 +255,35 @@ impl AggregateUDFImpl for Avg { coerce_avg_type(self.name(), arg_types) } + fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option { + if statistics_args.is_distinct { + return None; + } + + if let Precision::Exact(num_rows) = &statistics_args.statistics.num_rows { + match *num_rows { + 0 => return ScalarValue::new_zero(statistics_args.return_type).ok(), + _ => { + if statistics_args.exprs.len() == 1 { + if 
let Precision::Exact(sum) = statistics_args.exprs[0] + .column_statistics(&statistics_args.statistics) + .ok()? + .sum_value + { + let sum = sum.cast_to(statistics_args.return_type).ok()?; + let num_rows = + ScalarValue::from(statistics_args.statistics.num_rows) + .cast_to(statistics_args.return_type) + .ok()?; + return sum.div(&num_rows).ok(); + } + } + } + } + } + None + } + fn documentation(&self) -> Option<&Documentation> { self.doc() } @@ -606,3 +637,43 @@ where self.counts.capacity() * size_of::() + self.sums.capacity() * size_of::() } } + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_common::{ColumnStatistics, Statistics}; + use datafusion_physical_expr::expressions::Column; + use std::sync::Arc; + + #[test] + fn sum() { + let agg = Box::new(Avg::new()); + let statistics = Statistics { + num_rows: Precision::Exact(5), + total_byte_size: Precision::Absent, + column_statistics: vec![ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + sum_value: Precision::Exact(ScalarValue::Int8(Some(10))), + distinct_count: Default::default(), + }], + }; + let mut statistics_args = StatisticsArgs { + statistics: &statistics, + return_type: &DataType::Float64, + is_distinct: false, + exprs: &[Arc::new(Column::new("a", 0))], + }; + + // Ensure that the sum statistic is used and cast to the return type. 
+ assert_eq!( + agg.value_from_stats(&statistics_args), + Some(ScalarValue::Float64(Some(2.0))) + ); + + // With a distinct aggregate, the sum statistic isn't helpful + statistics_args.is_distinct = true; + assert_eq!(agg.value_from_stats(&statistics_args), None); + } +} diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index b4164c211c35..73136dd7a4f6 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -324,18 +324,7 @@ impl AggregateUDFImpl for Count { } if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows { if statistics_args.exprs.len() == 1 { - // TODO optimize with exprs other than Column - if let Some(col_expr) = statistics_args.exprs[0] - .as_any() - .downcast_ref::() - { - let current_val = &statistics_args.statistics.column_statistics - [col_expr.index()] - .null_count; - if let &Precision::Exact(val) = current_val { - return Some(ScalarValue::Int64(Some((num_rows - val) as i64))); - } - } else if let Some(lit_expr) = statistics_args.exprs[0] + if let Some(lit_expr) = statistics_args.exprs[0] .as_any() .downcast_ref::() { @@ -343,6 +332,14 @@ impl AggregateUDFImpl for Count { return Some(ScalarValue::Int64(Some(num_rows as i64))); } } + + let col_stats = statistics_args.exprs[0] + .column_statistics(&statistics_args.statistics) + .ok()?; + let current_val = col_stats.null_count; + if let &Precision::Exact(val) = current_val { + return Some(ScalarValue::Int64(Some((num_rows - val) as i64))); + } } } None From 05f480fbce0307fed0a22c489d76c5ccce4aab91 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 12 Dec 2024 11:02:16 +0000 Subject: [PATCH 5/7] Fix tests --- datafusion/functions-aggregate/src/average.rs | 5 ++--- datafusion/functions-aggregate/src/count.rs | 3 +-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs 
index c5da6a136146..ea22fdb02a49 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -42,7 +42,6 @@ use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls: filtered_null_mask, set_nulls, }; -use crate::sum::Sum; use datafusion_common::stats::Precision; use datafusion_functions_aggregate_common::utils::DecimalAverager; use datafusion_macros::user_doc; @@ -263,7 +262,7 @@ impl AggregateUDFImpl for Avg { if let Precision::Exact(num_rows) = &statistics_args.statistics.num_rows { match *num_rows { 0 => return ScalarValue::new_zero(statistics_args.return_type).ok(), - _ => { + num_rows => { if statistics_args.exprs.len() == 1 { if let Precision::Exact(sum) = statistics_args.exprs[0] .column_statistics(&statistics_args.statistics) @@ -272,7 +271,7 @@ impl AggregateUDFImpl for Avg { { let sum = sum.cast_to(statistics_args.return_type).ok()?; let num_rows = - ScalarValue::from(statistics_args.statistics.num_rows) + ScalarValue::from(num_rows as u64) .cast_to(statistics_args.return_type) .ok()?; return sum.div(&num_rows).ok(); diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 73136dd7a4f6..2865cc3f33c7 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -336,8 +336,7 @@ impl AggregateUDFImpl for Count { let col_stats = statistics_args.exprs[0] .column_statistics(&statistics_args.statistics) .ok()?; - let current_val = col_stats.null_count; - if let &Precision::Exact(val) = current_val { + if let Precision::Exact(val) = col_stats.null_count { return Some(ScalarValue::Int64(Some((num_rows - val) as i64))); } } From 531dde494b11303cce896f05243a302694b39160 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 12 Dec 2024 15:18:03 +0000 Subject: [PATCH 6/7] Fix cast column_statistics to not raise on bad cast --- datafusion/functions-aggregate/src/average.rs | 7 +++---- 
datafusion/physical-expr/src/expressions/cast.rs | 11 +++++++++-- .../proto/src/generated/datafusion_proto_common.rs | 2 ++ 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index ea22fdb02a49..08793ff6a5b8 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -270,10 +270,9 @@ impl AggregateUDFImpl for Avg { .sum_value { let sum = sum.cast_to(statistics_args.return_type).ok()?; - let num_rows = - ScalarValue::from(num_rows as u64) - .cast_to(statistics_args.return_type) - .ok()?; + let num_rows = ScalarValue::from(num_rows as u64) + .cast_to(statistics_args.return_type) + .ok()?; return sum.div(&num_rows).ok(); } } diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index aa3973327da5..0f0c7c37c16c 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -26,6 +26,7 @@ use arrow::compute::{can_cast_types, CastOptions}; use arrow::datatypes::{DataType, DataType::*, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; +use datafusion_common::stats::Precision; use datafusion_common::{not_impl_err, ColumnStatistics, Result, Statistics}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; @@ -199,8 +200,14 @@ impl PhysicalExpr for CastExpr { let child_stats = self.expr.column_statistics(statistics)?; Ok(ColumnStatistics { null_count: child_stats.null_count, - max_value: child_stats.max_value.cast_to(&self.cast_type)?, - min_value: child_stats.min_value.cast_to(&self.cast_type)?, + max_value: child_stats + .max_value + .cast_to(&self.cast_type) + .unwrap_or(Precision::Absent), + min_value: child_stats + .min_value + .cast_to(&self.cast_type) + 
.unwrap_or(Precision::Absent), // The sum value may be a wider numeric type than the data value, it's not safe to cast. sum_value: child_stats.sum_value, distinct_count: Default::default(), diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs index fa77d23a6ae6..6cc7652265c0 100644 --- a/datafusion/proto/src/generated/datafusion_proto_common.rs +++ b/datafusion/proto/src/generated/datafusion_proto_common.rs @@ -867,6 +867,8 @@ pub struct ColumnStats { pub min_value: ::core::option::Option, #[prost(message, optional, tag = "2")] pub max_value: ::core::option::Option, + #[prost(message, optional, tag = "5")] + pub sum_value: ::core::option::Option, #[prost(message, optional, tag = "3")] pub null_count: ::core::option::Option, #[prost(message, optional, tag = "4")] From 93beb1226a9c230e2d2828d58a06e5294223ad47 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 12 Dec 2024 15:27:25 +0000 Subject: [PATCH 7/7] Make clippy happy --- datafusion/functions-aggregate/src/average.rs | 2 +- datafusion/functions-aggregate/src/count.rs | 2 +- datafusion/functions-aggregate/src/sum.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 08793ff6a5b8..653ed0291093 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -265,7 +265,7 @@ impl AggregateUDFImpl for Avg { num_rows => { if statistics_args.exprs.len() == 1 { if let Precision::Exact(sum) = statistics_args.exprs[0] - .column_statistics(&statistics_args.statistics) + .column_statistics(statistics_args.statistics) .ok()? 
.sum_value { diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 2865cc3f33c7..b384f60a3a38 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -334,7 +334,7 @@ impl AggregateUDFImpl for Count { } let col_stats = statistics_args.exprs[0] - .column_statistics(&statistics_args.statistics) + .column_statistics(statistics_args.statistics) .ok()?; if let Precision::Exact(val) = col_stats.null_count { return Some(ScalarValue::Int64(Some((num_rows - val) as i64))); diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index e4739e6022ca..354a6b3318e9 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -268,7 +268,7 @@ impl AggregateUDFImpl for Sum { _ => { if statistics_args.exprs.len() == 1 { if let Precision::Exact(sum) = statistics_args.exprs[0] - .column_statistics(&statistics_args.statistics) + .column_statistics(statistics_args.statistics) .ok()? .sum_value {