Skip to content

Commit

Permalink
Add value_from_statisics to AggregateUDFImpl, remove special case f…
Browse files Browse the repository at this point in the history
…or min/max/count aggregate statistics (#12296)

* Removes min/max/count comparison based on name in aggregate statistics

* Abstracting away value from statistics

* Removing imports

* Introduced StatisticsArgs

* Fixed docs
  • Loading branch information
edmondop authored Sep 30, 2024
1 parent ddb4fac commit 0242767
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 170 deletions.
2 changes: 1 addition & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ pub use logical_plan::*;
pub use partition_evaluator::PartitionEvaluator;
pub use sqlparser;
pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF};
pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs};
pub use udf::{ScalarUDF, ScalarUDFImpl};
pub use udwf::{ReversedUDWF, WindowUDF, WindowUDFImpl};
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
Expand Down
27 changes: 26 additions & 1 deletion datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ use std::vec;

use arrow::datatypes::{DataType, Field};

use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue};
use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue, Statistics};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

use crate::expr::AggregateFunction;
use crate::function::{
Expand Down Expand Up @@ -94,6 +95,19 @@ impl fmt::Display for AggregateUDF {
}
}

pub struct StatisticsArgs<'a> {
pub statistics: &'a Statistics,
pub return_type: &'a DataType,
/// Whether the aggregate function is distinct.
///
/// ```sql
/// SELECT COUNT(DISTINCT column1) FROM t;
/// ```
pub is_distinct: bool,
/// The physical expression of arguments the aggregate function takes.
pub exprs: &'a [Arc<dyn PhysicalExpr>],
}

impl AggregateUDF {
/// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object
///
Expand Down Expand Up @@ -244,6 +258,13 @@ impl AggregateUDF {
self.inner.is_descending()
}

pub fn value_from_stats(
&self,
statistics_args: &StatisticsArgs,
) -> Option<ScalarValue> {
self.inner.value_from_stats(statistics_args)
}

/// See [`AggregateUDFImpl::default_value`] for more details.
pub fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
self.inner.default_value(data_type)
Expand Down Expand Up @@ -556,6 +577,10 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
fn is_descending(&self) -> Option<bool> {
None
}
// Return the value of the current UDF from the statistics
fn value_from_stats(&self, _statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
None
}

/// Returns default value of the function given the input is all `null`.
///
Expand Down
35 changes: 34 additions & 1 deletion datafusion/functions-aggregate/src/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
// under the License.

use ahash::RandomState;
use datafusion_common::stats::Precision;
use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
use datafusion_physical_expr::expressions;
use std::collections::HashSet;
use std::ops::BitAnd;
use std::{fmt::Debug, sync::Arc};
Expand Down Expand Up @@ -46,14 +48,15 @@ use datafusion_expr::{
function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
EmitTo, GroupsAccumulator, Signature, Volatility,
};
use datafusion_expr::{Expr, ReversedUDAF, TypeSignature};
use datafusion_expr::{Expr, ReversedUDAF, StatisticsArgs, TypeSignature};
use datafusion_functions_aggregate_common::aggregate::count_distinct::{
BytesDistinctCountAccumulator, FloatDistinctCountAccumulator,
PrimitiveDistinctCountAccumulator,
};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices;
use datafusion_physical_expr_common::binary_map::OutputType;

use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
make_udaf_expr_and_func!(
Count,
count,
Expand Down Expand Up @@ -291,6 +294,36 @@ impl AggregateUDFImpl for Count {
fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
Ok(ScalarValue::Int64(Some(0)))
}

fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
if statistics_args.is_distinct {
return None;
}
if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows {
if statistics_args.exprs.len() == 1 {
// TODO optimize with exprs other than Column
if let Some(col_expr) = statistics_args.exprs[0]
.as_any()
.downcast_ref::<expressions::Column>()
{
let current_val = &statistics_args.statistics.column_statistics
[col_expr.index()]
.null_count;
if let &Precision::Exact(val) = current_val {
return Some(ScalarValue::Int64(Some((num_rows - val) as i64)));
}
} else if let Some(lit_expr) = statistics_args.exprs[0]
.as_any()
.downcast_ref::<expressions::Literal>()
{
if lit_expr.value() == &COUNT_STAR_EXPANSION {
return Some(ScalarValue::Int64(Some(num_rows as i64)));
}
}
}
}
None
}
}

#[derive(Debug)]
Expand Down
77 changes: 74 additions & 3 deletions datafusion/functions-aggregate/src/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// under the License.

//! [`Max`] and [`MaxAccumulator`] accumulator for the `max` function
//! [`Min`] and [`MinAccumulator`] accumulator for the `max` function
//! [`Min`] and [`MinAccumulator`] accumulator for the `min` function
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
Expand Down Expand Up @@ -49,10 +49,12 @@ use arrow::datatypes::{
UInt8Type,
};
use arrow_schema::IntervalUnit;
use datafusion_common::stats::Precision;
use datafusion_common::{
downcast_value, exec_err, internal_err, DataFusionError, Result,
downcast_value, exec_err, internal_err, ColumnStatistics, DataFusionError, Result,
};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
use datafusion_physical_expr::expressions;
use std::fmt::Debug;

use arrow::datatypes::i256;
Expand All @@ -63,10 +65,10 @@ use arrow::datatypes::{
};

use datafusion_common::ScalarValue;
use datafusion_expr::GroupsAccumulator;
use datafusion_expr::{
function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, Volatility,
};
use datafusion_expr::{GroupsAccumulator, StatisticsArgs};
use half::f16;
use std::ops::Deref;

Expand Down Expand Up @@ -147,6 +149,54 @@ macro_rules! instantiate_min_accumulator {
}};
}

trait FromColumnStatistics {
fn value_from_column_statistics(
&self,
stats: &ColumnStatistics,
) -> Option<ScalarValue>;

fn value_from_statistics(
&self,
statistics_args: &StatisticsArgs,
) -> Option<ScalarValue> {
if let Precision::Exact(num_rows) = &statistics_args.statistics.num_rows {
match *num_rows {
0 => return ScalarValue::try_from(statistics_args.return_type).ok(),
value if value > 0 => {
let col_stats = &statistics_args.statistics.column_statistics;
if statistics_args.exprs.len() == 1 {
// TODO optimize with exprs other than Column
if let Some(col_expr) = statistics_args.exprs[0]
.as_any()
.downcast_ref::<expressions::Column>()
{
return self.value_from_column_statistics(
&col_stats[col_expr.index()],
);
}
}
}
_ => {}
}
}
None
}
}

impl FromColumnStatistics for Max {
fn value_from_column_statistics(
&self,
col_stats: &ColumnStatistics,
) -> Option<ScalarValue> {
if let Precision::Exact(ref val) = col_stats.max_value {
if !val.is_null() {
return Some(val.clone());
}
}
None
}
}

impl AggregateUDFImpl for Max {
fn as_any(&self) -> &dyn std::any::Any {
self
Expand Down Expand Up @@ -272,6 +322,7 @@ impl AggregateUDFImpl for Max {
fn is_descending(&self) -> Option<bool> {
Some(true)
}

fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity {
datafusion_expr::utils::AggregateOrderSensitivity::Insensitive
}
Expand All @@ -282,6 +333,9 @@ impl AggregateUDFImpl for Max {
fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
datafusion_expr::ReversedUDAF::Identical
}
fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
self.value_from_statistics(statistics_args)
}
}

// Statically-typed version of min/max(array) -> ScalarValue for string types
Expand Down Expand Up @@ -926,6 +980,20 @@ impl Default for Min {
}
}

impl FromColumnStatistics for Min {
fn value_from_column_statistics(
&self,
col_stats: &ColumnStatistics,
) -> Option<ScalarValue> {
if let Precision::Exact(ref val) = col_stats.min_value {
if !val.is_null() {
return Some(val.clone());
}
}
None
}
}

impl AggregateUDFImpl for Min {
fn as_any(&self) -> &dyn std::any::Any {
self
Expand Down Expand Up @@ -1052,6 +1120,9 @@ impl AggregateUDFImpl for Min {
Some(false)
}

fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
self.value_from_statistics(statistics_args)
}
fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity {
datafusion_expr::utils::AggregateOrderSensitivity::Insensitive
}
Expand Down
Loading

0 comments on commit 0242767

Please sign in to comment.