Skip to content

Commit

Permalink
Introduce Bernoulli distribution to be used as result of comparisons …
Browse files Browse the repository at this point in the history
…and inequations distribution combinations
  • Loading branch information
Fly-Style committed Jan 17, 2025
1 parent d52af46 commit 81b756d
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 61 deletions.
121 changes: 96 additions & 25 deletions datafusion/physical-expr-common/src/stats.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::stats::StatisticsV2::{Exponential, Gaussian, Uniform, Unknown};
use crate::stats::StatisticsV2::{Bernoulli, Exponential, Gaussian, Uniform, Unknown};
use arrow::datatypes::DataType;
use datafusion_common::ScalarValue;
use datafusion_expr_common::interval_arithmetic::Interval;
Expand All @@ -18,6 +18,9 @@ pub enum StatisticsV2 {
mean: ScalarValue,
variance: ScalarValue,
},
Bernoulli {
p: ScalarValue,
},
Unknown {
mean: Option<ScalarValue>,
median: Option<ScalarValue>,
Expand All @@ -26,17 +29,6 @@ pub enum StatisticsV2 {
},
}

impl Default for StatisticsV2 {
    /// Builds the default statistics: an [`Unknown`] distribution with no
    /// mean, median or variance, whose range is the unbounded interval
    /// created for `DataType::Null`.
    fn default() -> Self {
        // Constructed up front so the struct literal below stays compact;
        // the `unwrap` panics if the unbounded interval cannot be created.
        let range = Interval::make_unbounded(&DataType::Null).unwrap();
        Unknown {
            mean: None,
            median: None,
            variance: None,
            range,
        }
    }
}

impl StatisticsV2 {
pub fn new_unknown() -> Self {
Unknown {
Expand All @@ -61,30 +53,34 @@ impl StatisticsV2 {
}
let zero = &ScalarValue::new_zero(&rate.data_type()).unwrap();
rate.gt(zero)
}
},
Gaussian { variance, .. } => {
if variance.is_null() {
return false;
}
let zero = &ScalarValue::new_zero(&variance.data_type()).unwrap();
variance.ge(zero)
}
},
Bernoulli {p} => {
p.ge(&ScalarValue::new_zero(&DataType::Float64).unwrap())
&& p.le(&ScalarValue::new_one(&DataType::Float64).unwrap())
},
Unknown {
mean,
median,
variance: std_dev,
variance,
range,
} => {
if let (Some(mn), Some(md)) = (mean, median) {
if mn.is_null() || md.is_null() {
return false;
}
range.contains_value(mn).unwrap() && range.contains_value(md).unwrap()
} else if let Some(dev) = std_dev {
if dev.is_null() {
} else if let Some(v) = variance {
if v.is_null() {
return false;
}
dev.gt(&ScalarValue::new_zero(&dev.data_type()).unwrap())
v.gt(&ScalarValue::new_zero(&v.data_type()).unwrap())
} else {
false
}
Expand All @@ -98,6 +94,7 @@ impl StatisticsV2 {
/// by addition of upper and lower bound and dividing the result by 2.
/// - [`Exponential`] distribution mean is calculable by formula: 1/λ. λ must be strictly positive.
/// - [`Gaussian`] distribution has it explicitly
/// - [`Bernoulli`] mean is `p`
/// - [`Unknown`] distribution _may_ have it explicitly
pub fn mean(&self) -> datafusion_common::Result<Option<ScalarValue>> {
if !self.is_valid() {
Expand All @@ -122,6 +119,7 @@ impl StatisticsV2 {
Ok(None)
}
Gaussian { mean, .. } => Ok(Some(mean.clone())),
Bernoulli { p } => Ok(Some(p.clone())),
Unknown { mean, .. } => Ok(mean.clone()),
}
}
Expand Down Expand Up @@ -155,6 +153,13 @@ impl StatisticsV2 {
Ok(None)
}
Gaussian { mean, .. } => Ok(Some(mean.clone())),
Bernoulli { p } => {
if p.gt(&ScalarValue::Float64(Some(0.5))) {
Ok(Some(ScalarValue::new_one(&DataType::Float64)?))
} else {
Ok(Some(ScalarValue::new_zero(&DataType::Float64)?))
}
},
Unknown { median, .. } => Ok(median.clone())
}
}
Expand All @@ -177,19 +182,20 @@ impl StatisticsV2 {
} else {
Ok(None)
}
}
},
Exponential { rate, .. } => {
let one = &ScalarValue::new_one(&rate.data_type())?;
let rate_squared = rate.mul(rate);
if rate_squared.is_err() {
return Ok(None);
}
if let Ok(variance) = one.div(rate_squared.unwrap()) {
let rate_squared = rate.mul(rate)?;
if let Ok(variance) = one.div(rate_squared) {
return Ok(Some(variance));
}
Ok(None)
}
},
Gaussian { variance, .. } => Ok(Some(variance.clone())),
Bernoulli { p} => {
let one = &ScalarValue::new_one(&DataType::Float64)?;
Ok(Some(one.sub_checked(p)?.mul_checked(p)?))
},
Unknown { variance, .. } => Ok(variance.clone())
}
}
Expand All @@ -200,6 +206,7 @@ impl StatisticsV2 {
pub fn range(&self) -> Option<&Interval> {
match &self {
Uniform { interval, .. } => Some(interval),
Bernoulli { .. } => Some(&Interval::UNCERTAIN),
Unknown { range, .. } => Some(range),
_ => None,
}
Expand All @@ -210,6 +217,7 @@ impl StatisticsV2 {
pub fn data_type(&self) -> Option<DataType> {
match &self {
Uniform { interval, .. } => Some(interval.data_type()),
Bernoulli { p } => Some(p.data_type()),
Unknown { range, .. } => Some(range.data_type()),
_ => None,
}
Expand Down Expand Up @@ -328,6 +336,39 @@ mod tests {
}
}

/// Checks `StatisticsV2::is_valid` for the `Bernoulli` variant: `p` is
/// accepted only when it lies in the closed range `[0, 1]`; a null `p` or a
/// value outside that range is rejected.
#[test]
fn bernoulli_stats_is_valid_test() {
    // Each entry pairs a distribution with the expected `is_valid()` result.
    let bernoulli_stats = vec![
        // A null probability is never valid.
        (
            StatisticsV2::Bernoulli { p: ScalarValue::Null },
            false,
        ),
        // Interior of the range.
        (
            StatisticsV2::Bernoulli { p: ScalarValue::Float64(Some(0.25)) },
            true,
        ),
        // Both boundaries are inclusive.
        (
            StatisticsV2::Bernoulli { p: ScalarValue::Float64(Some(0.)) },
            true,
        ),
        (
            StatisticsV2::Bernoulli { p: ScalarValue::Float64(Some(1.)) },
            true,
        ),
        // Values above one or below zero are rejected.
        (
            StatisticsV2::Bernoulli { p: ScalarValue::Float64(Some(10.)) },
            false,
        ),
        (
            StatisticsV2::Bernoulli { p: ScalarValue::Float64(Some(-10.)) },
            false,
        ),
    ];
    for (idx, (stats, expected)) in bernoulli_stats.into_iter().enumerate() {
        // Report the index so a failure points at the offending case.
        assert_eq!(stats.is_valid(), expected, "bernoulli case #{}", idx);
    }
}

#[test]
fn unknown_stats_is_valid_test() {
let unknown_stats = vec![
Expand Down Expand Up @@ -484,6 +525,17 @@ mod tests {
));
//endregion

// region bernoulli
stats.push((
StatisticsV2::Bernoulli {p: ScalarValue::Null},
None,
));
stats.push((
StatisticsV2::Bernoulli {p: ScalarValue::Float64(Some(0.5))},
Some(ScalarValue::Float64(Some(0.5))),
));
//endregion

//region unknown
stats.push((
StatisticsV2::Unknown {
Expand Down Expand Up @@ -561,6 +613,21 @@ mod tests {
},
Some(ScalarValue::Float64(Some(2.))),
),
(
StatisticsV2::Bernoulli { p: ScalarValue::Float64(Some(0.25)) },
Some(ScalarValue::Float64(Some(0.))),
),
(
StatisticsV2::Bernoulli { p: ScalarValue::Float64(Some(0.75)) },
Some(ScalarValue::Float64(Some(1.))),
),
(
StatisticsV2::Gaussian {
mean: ScalarValue::Float64(Some(2.)),
variance: ScalarValue::Float64(Some(1.)),
},
Some(ScalarValue::Float64(Some(2.))),
),
(
StatisticsV2::Unknown {
mean: None,
Expand Down Expand Up @@ -630,6 +697,10 @@ mod tests {
},
Some(ScalarValue::Float64(Some(1.))),
),
(
StatisticsV2::Bernoulli { p: ScalarValue::Float64(Some(0.5)) },
Some(ScalarValue::Float64(Some(0.25))),
),
(
StatisticsV2::Unknown {
mean: None,
Expand Down
39 changes: 25 additions & 14 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison};
use crate::PhysicalExpr;

use crate::expressions::binary::kernels::concat_elements_utf8view;
use crate::utils::stats::new_unknown_from_binary_expr;
use crate::utils::stats::{new_bernoulli_from_binary_expr, new_unknown_from_binary_expr};
use arrow::array::*;
use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene};
use arrow::compute::kernels::cmp::*;
Expand All @@ -41,7 +41,7 @@ use datafusion_expr::type_coercion::binary::get_result_type;
use datafusion_expr::{ColumnarValue, Operator};
use datafusion_physical_expr_common::datum::{apply, apply_cmp, apply_cmp_for_nested};
use datafusion_physical_expr_common::stats::StatisticsV2;
use datafusion_physical_expr_common::stats::StatisticsV2::{Gaussian, Uniform};
use datafusion_physical_expr_common::stats::StatisticsV2::{Bernoulli, Gaussian};
use kernels::{
bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar,
bitwise_shift_left_dyn, bitwise_shift_left_dyn_scalar, bitwise_shift_right_dyn,
Expand Down Expand Up @@ -503,13 +503,14 @@ impl PhysicalExpr for BinaryExpr {
let (left_stat, right_stat) = (children_stat[0], children_stat[1]);

// We can evaluate statistics only with numeric data types on stats.
// TODO: move the data type check to the higher levels.
// TODO: move the data type check to the higher levels, if possible.
if (left_stat.data_type().is_none() || right_stat.data_type().is_none())
|| !left_stat.data_type().unwrap().is_numeric()
|| !right_stat.data_type().unwrap().is_numeric() {
return Ok(StatisticsV2::new_unknown());
}


// TODO, to think: maybe, we can separate also Unknown + Unknown
// just for clarity and better reader understanding, how exactly
// mean, median, variance and range are computed, if it is possible.
Expand All @@ -536,22 +537,30 @@ impl PhysicalExpr for BinaryExpr {
}
},
Operator::Divide => new_unknown_from_binary_expr(&self.op, left_stat, right_stat),
Operator::Eq => {
Operator::Eq | Operator::NotEq | Operator::LtEq | Operator::GtEq
| Operator::Lt | Operator::Gt => {
new_bernoulli_from_binary_expr(&self.op, left_stat, right_stat)
},
Operator::And => {
match (left_stat, right_stat) {
(Uniform { .. }, Uniform { .. }) => {
let intersection = left_stat.range().unwrap()
.intersect(right_stat.range().unwrap())?;
if let Some(interval) = intersection {
Ok(Uniform { interval })
} else {
Ok(Uniform { interval: Interval::CERTAINLY_FALSE })
}
(Bernoulli { p: p_left }, Bernoulli { p: p_right }, ) => {
Ok(Bernoulli { p : p_left.mul_checked(p_right)? })
},
// TODO: complement with more cases
(_, _) => new_unknown_from_binary_expr(&self.op, left_stat, right_stat)
}
},
Operator::NotEq => new_unknown_from_binary_expr(&self.op, left_stat, right_stat),
// TODO: express gt/ge/lt/le operations
Operator::Or => {
    match (left_stat, right_stat) {
        (Bernoulli { p: p_left }, Bernoulli { p: p_right }) => {
            // P(A or B) = 1 - P(not A) * P(not B), treating the two events
            // as independent (the same assumption the `And` arm makes when
            // it multiplies the probabilities).
            // `1 - p_left * p_right` would instead be P(not (A and B)),
            // which e.g. yields 1 when both probabilities are 0.
            let one = ScalarValue::Float64(Some(1.));
            let p_neither = one
                .sub_checked(p_left)?
                .mul_checked(one.sub_checked(p_right)?)?;
            Ok(Bernoulli {
                p: one.sub_checked(p_neither)?,
            })
        },
        // TODO: complement with more cases
        (_, _) => new_unknown_from_binary_expr(&self.op, left_stat, right_stat)
    }
}
_ => new_unknown_from_binary_expr(&self.op, left_stat, right_stat)
}
}
Expand Down Expand Up @@ -812,6 +821,8 @@ mod tests {
use crate::expressions::{col, lit, try_cast, Column, Literal};
use datafusion_common::plan_datafusion_err;
use datafusion_expr::type_coercion::binary::get_input_types;
use datafusion_physical_expr_common::stats::StatisticsV2::Uniform;

// TODO: remove
//region tests
/// Performs a binary operation, applying any type coercion necessary
Expand Down
Loading

0 comments on commit 81b756d

Please sign in to comment.