diff --git a/datafusion/physical-expr-common/src/stats.rs b/datafusion/physical-expr-common/src/stats.rs index e575ed91a36b..254e1699e110 100644 --- a/datafusion/physical-expr-common/src/stats.rs +++ b/datafusion/physical-expr-common/src/stats.rs @@ -1,4 +1,4 @@ -use crate::stats::StatisticsV2::{Exponential, Gaussian, Uniform, Unknown}; +use crate::stats::StatisticsV2::{Bernoulli, Exponential, Gaussian, Uniform, Unknown}; use arrow::datatypes::DataType; use datafusion_common::ScalarValue; use datafusion_expr_common::interval_arithmetic::Interval; @@ -18,6 +18,9 @@ pub enum StatisticsV2 { mean: ScalarValue, variance: ScalarValue, }, + Bernoulli { + p: ScalarValue, + }, Unknown { mean: Option, median: Option, @@ -26,17 +29,6 @@ pub enum StatisticsV2 { }, } -impl Default for StatisticsV2 { - fn default() -> Self { - Unknown { - mean: None, - median: None, - variance: None, - range: Interval::make_unbounded(&DataType::Null).unwrap(), - } - } -} - impl StatisticsV2 { pub fn new_unknown() -> Self { Unknown { @@ -61,18 +53,22 @@ impl StatisticsV2 { } let zero = &ScalarValue::new_zero(&rate.data_type()).unwrap(); rate.gt(zero) - } + }, Gaussian { variance, .. } => { if variance.is_null() { return false; } let zero = &ScalarValue::new_zero(&variance.data_type()).unwrap(); variance.ge(zero) - } + }, + Bernoulli {p} => { + p.ge(&ScalarValue::new_zero(&DataType::Float64).unwrap()) + && p.le(&ScalarValue::new_one(&DataType::Float64).unwrap()) + }, Unknown { mean, median, - variance: std_dev, + variance, range, } => { if let (Some(mn), Some(md)) = (mean, median) { @@ -80,11 +76,11 @@ impl StatisticsV2 { return false; } range.contains_value(mn).unwrap() && range.contains_value(md).unwrap() - } else if let Some(dev) = std_dev { - if dev.is_null() { + } else if let Some(v) = variance { + if v.is_null() { return false; } - dev.gt(&ScalarValue::new_zero(&dev.data_type()).unwrap()) + v.gt(&ScalarValue::new_zero(&v.data_type()).unwrap()) } else { false } @@ -98,6 +94,7 @@ impl StatisticsV2 { /// by addition of upper and lower bound and dividing the result by 2. /// - [`Exponential`] distribution mean is calculable by formula: 1/λ. λ must be non-negative. /// - [`Gaussian`] distribution has it explicitly + /// - [`Bernoulli`] mean is `p` /// - [`Unknown`] distribution _may_ have it explicitly pub fn mean(&self) -> datafusion_common::Result> { if !self.is_valid() { @@ -122,6 +119,7 @@ impl StatisticsV2 { Ok(None) } Gaussian { mean, .. } => Ok(Some(mean.clone())), + Bernoulli { p } => Ok(Some(p.clone())), Unknown { mean, .. } => Ok(mean.clone()), } } @@ -155,6 +153,13 @@ impl StatisticsV2 { Ok(None) } Gaussian { mean, .. } => Ok(Some(mean.clone())), + Bernoulli { p } => { + if p.gt(&ScalarValue::Float64(Some(0.5))) { + Ok(Some(ScalarValue::new_one(&DataType::Float64)?)) + } else { + Ok(Some(ScalarValue::new_zero(&DataType::Float64)?)) + } + }, Unknown { median, .. } => Ok(median.clone()) } } @@ -177,19 +182,20 @@ impl StatisticsV2 { } else { Ok(None) } - } + }, Exponential { rate, .. } => { let one = &ScalarValue::new_one(&rate.data_type())?; - let rate_squared = rate.mul(rate); - if rate_squared.is_err() { - return Ok(None); - } - if let Ok(variance) = one.div(rate_squared.unwrap()) { + let rate_squared = rate.mul(rate)?; + if let Ok(variance) = one.div(rate_squared) { return Ok(Some(variance)); } Ok(None) - } + }, Gaussian { variance, .. } => Ok(Some(variance.clone())), + Bernoulli { p} => { + let one = &ScalarValue::new_one(&DataType::Float64)?; + Ok(Some(one.sub_checked(p)?.mul_checked(p)?)) + }, Unknown { variance, .. } => Ok(variance.clone()) } } @@ -200,6 +206,7 @@ impl StatisticsV2 { pub fn range(&self) -> Option<&Interval> { match &self { Uniform { interval, .. } => Some(interval), + Bernoulli { .. } => Some(&Interval::UNCERTAIN), Unknown { range, .. } => Some(range), _ => None, } @@ -210,6 +217,7 @@ impl StatisticsV2 { pub fn data_type(&self) -> Option { match &self { Uniform { interval, .. } => Some(interval.data_type()), + Bernoulli { p } => Some(p.data_type()), Unknown { range, .. } => Some(range.data_type()), _ => None, } @@ -328,6 +336,39 @@ mod tests { } } + #[test] + fn bernoulli_stats_is_valid_test() { + let gaussian_stats = vec![ + ( + StatisticsV2::Bernoulli { p: ScalarValue::Null }, + false, + ), + ( + StatisticsV2::Bernoulli { p: ScalarValue::Float64(Some(0.25)) }, + true, + ), + ( + StatisticsV2::Bernoulli { p: ScalarValue::Float64(Some(0.)) }, + true, + ), + ( + StatisticsV2::Bernoulli { p: ScalarValue::Float64(Some(1.)) }, + true, + ), + ( + StatisticsV2::Bernoulli { p: ScalarValue::Float64(Some(10.)) }, + false, + ), + ( + StatisticsV2::Bernoulli { p: ScalarValue::Float64(Some(-10.)) }, + false, + ), + ]; + for case in gaussian_stats { + assert_eq!(case.0.is_valid(), case.1); + } + } + #[test] fn unknown_stats_is_valid_test() { let unknown_stats = vec![ @@ -484,6 +525,17 @@ mod tests { )); //endregion + // region bernoulli + stats.push(( + StatisticsV2::Bernoulli {p: ScalarValue::Null}, + None, + )); + stats.push(( + StatisticsV2::Bernoulli {p: ScalarValue::Float64(Some(0.5))}, + Some(ScalarValue::Float64(Some(0.5))), + )); + //endregion + //region unknown stats.push(( StatisticsV2::Unknown { @@ -561,6 +613,21 @@ mod tests { }, Some(ScalarValue::Float64(Some(2.))), ), + ( + StatisticsV2::Bernoulli { p: ScalarValue::Float64(Some(0.25)) }, + Some(ScalarValue::Float64(Some(0.))), + ), + ( + StatisticsV2::Bernoulli { p: ScalarValue::Float64(Some(0.75)) }, + Some(ScalarValue::Float64(Some(1.))), + ), + ( + StatisticsV2::Gaussian { + mean: ScalarValue::Float64(Some(2.)), + variance: ScalarValue::Float64(Some(1.)), + }, + Some(ScalarValue::Float64(Some(2.))), + ), ( StatisticsV2::Unknown { mean: None, @@ -630,6 +697,10 @@ mod tests { }, Some(ScalarValue::Float64(Some(1.))), ), + ( + StatisticsV2::Bernoulli { p: ScalarValue::Float64(Some(0.5)) }, + Some(ScalarValue::Float64(Some(0.25))), + ), ( StatisticsV2::Unknown { mean: None, diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index f9896b94264d..77d0e81f44d2 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -24,7 +24,7 @@ use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; use crate::PhysicalExpr; use crate::expressions::binary::kernels::concat_elements_utf8view; -use crate::utils::stats::new_unknown_from_binary_expr; +use crate::utils::stats::{new_bernoulli_from_binary_expr, new_unknown_from_binary_expr}; use arrow::array::*; use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene}; use arrow::compute::kernels::cmp::*; @@ -41,7 +41,7 @@ use datafusion_expr::type_coercion::binary::get_result_type; use datafusion_expr::{ColumnarValue, Operator}; use datafusion_physical_expr_common::datum::{apply, apply_cmp, apply_cmp_for_nested}; use datafusion_physical_expr_common::stats::StatisticsV2; -use datafusion_physical_expr_common::stats::StatisticsV2::{Gaussian, Uniform}; +use datafusion_physical_expr_common::stats::StatisticsV2::{Bernoulli, Gaussian}; use kernels::{ bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar, bitwise_shift_left_dyn, bitwise_shift_left_dyn_scalar, bitwise_shift_right_dyn, @@ -503,13 +503,14 @@ impl PhysicalExpr for BinaryExpr { let (left_stat, right_stat) = (children_stat[0], children_stat[1]); // We can evaluate statistics only with numeric data types on stats. - // TODO: move the data type check to the the higher levels. + // TODO: move the data type check to the the higher levels, if possible. if (left_stat.data_type().is_none() || right_stat.data_type().is_none()) || !left_stat.data_type().unwrap().is_numeric() || !right_stat.data_type().unwrap().is_numeric() { return Ok(StatisticsV2::new_unknown()); } + // TODO, to think: maybe, we can separate also Unknown + Unknown // just for clarity and better reader understanding, how exactly // mean, median, variance and range are computed, if it is possible. @@ -536,22 +537,30 @@ impl PhysicalExpr for BinaryExpr { } }, Operator::Divide => new_unknown_from_binary_expr(&self.op, left_stat, right_stat), - Operator::Eq => { + Operator::Eq | Operator::NotEq | Operator::LtEq | Operator::GtEq + | Operator::Lt | Operator::Gt => { + new_bernoulli_from_binary_expr(&self.op, left_stat, right_stat) + }, + Operator::And => { match (left_stat, right_stat) { - (Uniform { .. }, Uniform { .. }) => { - let intersection = left_stat.range().unwrap() - .intersect(right_stat.range().unwrap())?; - if let Some(interval) = intersection { - Ok(Uniform { interval }) - } else { - Ok(Uniform { interval: Interval::CERTAINLY_FALSE }) - } + (Bernoulli { p: p_left }, Bernoulli { p: p_right }, ) => { + Ok(Bernoulli { p : p_left.mul_checked(p_right)? }) }, + // TODO: complement with more cases (_, _) => new_unknown_from_binary_expr(&self.op, left_stat, right_stat) } }, - Operator::NotEq => new_unknown_from_binary_expr(&self.op, left_stat, right_stat), - // TODO: express gt/ge/lt/le operations + Operator::Or => { + match (left_stat, right_stat) { + (Bernoulli { p: p_left }, Bernoulli { p: p_right }, ) => { + Ok(Bernoulli { + p : ScalarValue::Float64(Some(1.)).sub(p_left.mul_checked(p_right)?)? + }) + }, + // TODO: complement with more cases + (_, _) => new_unknown_from_binary_expr(&self.op, left_stat, right_stat) + } + } _ => new_unknown_from_binary_expr(&self.op, left_stat, right_stat) } } @@ -812,6 +821,8 @@ mod tests { use crate::expressions::{col, lit, try_cast, Column, Literal}; use datafusion_common::plan_datafusion_err; use datafusion_expr::type_coercion::binary::get_input_types; + use datafusion_physical_expr_common::stats::StatisticsV2::Uniform; + // TODO: remove //region tests /// Performs a binary operation, applying any type coercion necessary diff --git a/datafusion/physical-expr/src/utils/stats.rs b/datafusion/physical-expr/src/utils/stats.rs index adf7b05352ac..4bd90ec46e82 100644 --- a/datafusion/physical-expr/src/utils/stats.rs +++ b/datafusion/physical-expr/src/utils/stats.rs @@ -7,7 +7,7 @@ use datafusion_common::ScalarValue::Float64; use datafusion_expr_common::interval_arithmetic::{apply_operator, Interval}; use datafusion_expr_common::operator::Operator; use datafusion_physical_expr_common::stats::StatisticsV2; -use datafusion_physical_expr_common::stats::StatisticsV2::{Exponential, Uniform, Unknown}; +use datafusion_physical_expr_common::stats::StatisticsV2::{Bernoulli, Exponential, Uniform, Unknown}; use log::debug; use petgraph::adj::DefaultIx; use petgraph::prelude::Bfs; @@ -202,6 +202,18 @@ pub fn new_unknown_from_binary_expr( }) } +/// Creates a new [`Bernoulli`] distribution, and tries to compute the result probability. +/// TODO: implement properly, temporarily always returns 1. +pub fn new_bernoulli_from_binary_expr( + _op: &Operator, + _left: &StatisticsV2, + _right: &StatisticsV2 +) -> datafusion_common::Result { + Ok(Bernoulli { + p: Float64(Some(1.)) + }) +} + /// Computes a mean value for a given binary operator and two statistics. /// The result is calculated based on the operator type for any statistics kind. pub fn compute_mean( @@ -334,6 +346,9 @@ pub fn compute_variance( pub fn compute_range(op: &Operator, left_stat: &StatisticsV2, right_stat: &StatisticsV2) -> datafusion_common::Result { + if !left_stat.is_valid() || !right_stat.is_valid() { + return Interval::make_unbounded(&DataType::Float64); + } match (left_stat, right_stat) { (Uniform { interval: l }, Uniform { interval: r }) | (Uniform { interval: l }, Unknown { range: r, .. }) @@ -344,18 +359,6 @@ pub fn compute_range(op: &Operator, left_stat: &StatisticsV2, right_stat: &Stati | Operator::Gt | Operator::GtEq | Operator::Lt | Operator::LtEq => { apply_operator(op, l, r) }, - Operator::Eq => { - // Note: unwrap is legit, because Uniform & Unknown always have ranges - if let Some(intersection) = left_stat.range().unwrap() - .intersect(right_stat.range().unwrap())? { - Ok(intersection) - } else if let Some(data_type) = left_stat.data_type() { - Interval::make_unbounded(&data_type) - } else { - Interval::make_unbounded(&DataType::Float64) - } - }, - Operator::NotEq => Ok(Interval::CERTAINLY_FALSE), _ => Interval::make_unbounded(&DataType::Float64) } } @@ -370,7 +373,7 @@ mod tests { use datafusion_common::ScalarValue; use datafusion_common::ScalarValue::Float64; use datafusion_expr_common::interval_arithmetic::{apply_operator, Interval}; - use datafusion_expr_common::operator::Operator::{Divide, Minus, Multiply, Eq, NotEq, Plus, Gt, GtEq, Lt, LtEq}; + use datafusion_expr_common::operator::Operator::{Gt, GtEq, Lt, LtEq, Minus, Multiply, Plus}; use datafusion_physical_expr_common::stats::StatisticsV2::{Uniform, Unknown}; type Actual = Option; @@ -417,19 +420,24 @@ mod tests { fn test_compute_range_where_present() -> datafusion_common::Result<()> { let a = &Interval::make(Some(0.), Some(12.0))?; let b = &Interval::make(Some(0.), Some(12.0))?; + let _mean = Some(Float64(Some(6.0))); for (stat_a, stat_b) in [ (Uniform { interval: a.clone() }, Uniform { interval: b.clone() }), - (Unknown { mean: None, median: None, variance: None, range: a.clone() }, Uniform { interval: b.clone() }), - (Uniform { interval: a.clone() }, Unknown { mean: None, median: None, variance: None, range: b.clone() }), - (Unknown { mean: None, median: None, variance: None, range: a.clone() }, - Unknown { mean: None, median: None, variance: None, range: b.clone() })] { + (Unknown { mean: _mean.clone(), median: _mean.clone(), variance: None, range: a.clone() }, + Uniform { interval: b.clone() }), + (Uniform { interval: a.clone() }, + Unknown { mean: _mean.clone(), median: _mean.clone(), variance: None, range: b.clone() }), + (Unknown {mean: _mean.clone(), median: _mean.clone(), variance: None, range: a.clone() }, + Unknown { mean: _mean.clone(), median: _mean.clone(), variance: None, range: b.clone() })] { // range - for op in [Plus, Minus, Multiply, Divide, Gt, GtEq, Lt, LtEq] { - assert_eq!(compute_range(&op, &stat_a, &stat_b)?, apply_operator(&op, a, b)?); + for op in [Plus, Minus, Multiply, Gt, GtEq, Lt, LtEq] { + assert_eq!( + compute_range(&op, &stat_a, &stat_b)?, + apply_operator(&op, a, b)?, + "{}", format!("Failed for {:?} {op} {:?}", stat_a, stat_b), + ); } - assert_eq!(compute_range(&Eq, &stat_a, &stat_b)?, a.intersect(b)?.unwrap()); - assert_eq!(compute_range(&NotEq, &stat_a, &stat_b)?, Interval::CERTAINLY_FALSE); } Ok(())