Skip to content

Commit

Permalink
Refactor result distribution computation during the statistics evalua…
Browse files Browse the repository at this point in the history
…tion phase; add compute_range function
  • Loading branch information
Fly-Style committed Jan 16, 2025
1 parent bf6d01c commit add838c
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 61 deletions.
66 changes: 25 additions & 41 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ use datafusion_expr::type_coercion::binary::get_result_type;
use datafusion_expr::{ColumnarValue, Operator};
use datafusion_physical_expr_common::datum::{apply, apply_cmp, apply_cmp_for_nested};
use datafusion_physical_expr_common::stats::StatisticsV2;
use datafusion_physical_expr_common::stats::StatisticsV2::{
Gaussian, Uniform,
};
use datafusion_physical_expr_common::stats::StatisticsV2::{Gaussian, Uniform};
use kernels::{
bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar,
bitwise_shift_left_dyn, bitwise_shift_left_dyn_scalar, bitwise_shift_right_dyn,
Expand Down Expand Up @@ -513,22 +511,8 @@ impl PhysicalExpr for BinaryExpr {
}

match &self.op {
Operator::Plus | Operator::Minus | Operator::Multiply | Operator::Divide => {
Operator::Plus | Operator::Minus | Operator::Multiply => {
match (left_stat, right_stat) {
(Uniform { interval: left_interval}, Uniform { interval: right_interval, }) => {
new_unknown_from_binary_expr(
&self.op,
left_stat,
right_stat,
apply_operator(&self.op, left_interval, right_interval)?,
)
},
(Uniform {..}, _) | (_, Uniform {..}) => new_unknown_from_binary_expr(
&self.op,
left_stat,
right_stat,
Interval::make_unbounded(&left_stat.data_type().unwrap())?,
),
(Gaussian { mean: left_mean, variance: left_v, ..},
Gaussian { mean: right_mean, variance: right_v}, ) => {
if self.op.eq(&Operator::Plus) {
Expand All @@ -542,22 +526,27 @@ impl PhysicalExpr for BinaryExpr {
variance: left_v.add(right_v)?,
})
} else {
new_unknown_from_binary_expr(
&self.op,
left_stat,
right_stat,
Interval::make_unbounded(&left_stat.data_type().unwrap())?,
)
new_unknown_from_binary_expr(&self.op, left_stat, right_stat)
}
}
(_, _) => new_unknown_from_binary_expr(
&self.op,
left_stat,
right_stat,
Interval::make_unbounded(&left_stat.data_type().unwrap())?,
)
},
(_, _) => new_unknown_from_binary_expr(&self.op, left_stat, right_stat)
}
}
},
Operator::Divide => new_unknown_from_binary_expr(&self.op, left_stat, right_stat),
Operator::Eq => {
match (left_stat, right_stat) {
(Uniform { .. }, Uniform { .. }) => {
let intersection = left_stat.range().unwrap()
.intersect(right_stat.range().unwrap())?;
if let Some(interval) = intersection {
Ok(Uniform { interval })
} else {
Ok(Uniform { interval: Interval::CERTAINLY_FALSE })
}
},
(_, _) => new_unknown_from_binary_expr(&self.op, left_stat, right_stat)
}
},
_ => internal_err!("BinaryExpr requires exactly 2 children")
}
}
Expand Down Expand Up @@ -4502,17 +4491,12 @@ mod tests {
for op in ops {
let expr = binary_expr(Arc::clone(&a), op, Arc::clone(&b), schema)?;
// TODO: to think, if maybe we want to handcraft the expected value...
let expected = new_unknown_from_binary_expr(
&op,
ref_view[0],
ref_view[1],
apply_operator(&op, &left_interval, &right_interval)?
)?;
assert_eq!(expr.evaluate_statistics(&ref_view)?, expected);
assert_eq!(
expr.evaluate_statistics(&ref_view)?,
new_unknown_from_binary_expr(&op, ref_view[0], ref_view[1])?
);
}

Ok(())
}

//endregion evaluate_statistics test
}
97 changes: 77 additions & 20 deletions datafusion/physical-expr/src/utils/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::utils::{build_dag, ExprTreeNode};
use arrow::datatypes::{DataType, Schema};
use datafusion_common::ScalarValue;
use datafusion_common::ScalarValue::Float64;
use datafusion_expr_common::interval_arithmetic::Interval;
use datafusion_expr_common::interval_arithmetic::{apply_operator, Interval};
use datafusion_expr_common::operator::Operator;
use datafusion_physical_expr_common::stats::StatisticsV2;
use datafusion_physical_expr_common::stats::StatisticsV2::{Exponential, Uniform, Unknown};
Expand Down Expand Up @@ -192,14 +192,13 @@ pub fn new_unknown_with_range(range: Interval) -> StatisticsV2 {
/// Builds an `Unknown` statistic describing the result of the binary
/// expression `left op right`.
///
/// The mean, median and variance are obtained from `compute_mean`,
/// `compute_median` and `compute_variance` respectively; any error from
/// those helpers is propagated to the caller.
///
/// NOTE(review): this is a rendered diff, so both variants of the signature
/// and of the `range` field are visible below. Presumably the `range: Interval`
/// parameter is the removed form and `compute_range(op, left, right)` is the
/// added one — confirm against the commit.
pub fn new_unknown_from_binary_expr(
op: &Operator,
left: &StatisticsV2,
right: &StatisticsV2,
range: Interval
right: &StatisticsV2
) -> datafusion_common::Result<StatisticsV2> {
Ok(Unknown {
mean: compute_mean(op, left, right)?,
median: compute_median(op, left, right)?,
variance: compute_variance(op, left, right)?,
range
range: compute_range(op, left, right)?,
})
}

Expand Down Expand Up @@ -333,48 +332,106 @@ pub fn compute_variance(
}
}

/// Computes the range of the result of applying `op` to two operands
/// described by `left_stat` and `right_stat`.
///
/// A range is derived only when both operands are `Uniform` or `Unknown`,
/// i.e. when each of them carries an interval:
///
/// - `Plus`, `Minus`, `Multiply`, `Gt`, `GtEq`, `Lt`, `LtEq` delegate to
///   `apply_operator` on the two operand intervals.
/// - `Eq` yields the intersection of the operand intervals. When they do not
///   intersect, an unbounded interval of the left operand's data type is
///   returned (`Float64` if that type is unavailable).
/// - `NotEq` yields `Interval::CERTAINLY_FALSE`.
/// - Every other operator yields an unbounded `Float64` interval.
///
/// Any other combination of operand statistics yields an unbounded `Float64`
/// interval.
///
/// # Errors
///
/// Propagates errors returned by `apply_operator`, `Interval::intersect`
/// and `Interval::make_unbounded`.
pub fn compute_range(
    op: &Operator,
    left_stat: &StatisticsV2,
    right_stat: &StatisticsV2,
) -> datafusion_common::Result<Interval> {
    match (left_stat, right_stat) {
        (Uniform { interval: l }, Uniform { interval: r })
        | (Uniform { interval: l }, Unknown { range: r, .. })
        | (Unknown { range: l, .. }, Uniform { interval: r })
        | (Unknown { range: l, .. }, Unknown { range: r, .. }) => match op {
            Operator::Plus
            | Operator::Minus
            | Operator::Multiply
            | Operator::Gt
            | Operator::GtEq
            | Operator::Lt
            | Operator::LtEq => apply_operator(op, l, r),
            // The operand intervals are already bound by the pattern above,
            // so intersect them directly instead of re-fetching them through
            // `range().unwrap()`.
            Operator::Eq => match l.intersect(r)? {
                Some(intersection) => Ok(intersection),
                None => {
                    let data_type =
                        left_stat.data_type().unwrap_or(DataType::Float64);
                    Interval::make_unbounded(&data_type)
                }
            },
            // NOTE(review): the result is constant regardless of the operand
            // intervals — confirm this is the intended range for `NotEq`.
            Operator::NotEq => Ok(Interval::CERTAINLY_FALSE),
            _ => Interval::make_unbounded(&DataType::Float64),
        },
        (_, _) => Interval::make_unbounded(&DataType::Float64),
    }
}

#[cfg(test)]
// #[cfg(all(test, feature = "stats_v2"))]
mod tests {
use crate::utils::stats::{compute_mean, compute_median, compute_variance};
use crate::utils::stats::{compute_mean, compute_median, compute_range, compute_variance};
use datafusion_common::ScalarValue;
use datafusion_common::ScalarValue::Float64;
use datafusion_expr_common::interval_arithmetic::Interval;
use datafusion_expr_common::operator::Operator::{Minus, Multiply, Plus};
use datafusion_physical_expr_common::stats::StatisticsV2::Uniform;
use datafusion_expr_common::interval_arithmetic::{apply_operator, Interval};
use datafusion_expr_common::operator::Operator::{Divide, Minus, Multiply, Eq, NotEq, Plus, Gt, GtEq, Lt, LtEq};
use datafusion_physical_expr_common::stats::StatisticsV2::{Uniform, Unknown};

type Actual = Option<ScalarValue>;
type Expected = Option<ScalarValue>;

// Expected test results were calculated in Wolfram Mathematica, by using
// *METHOD_NAME*[TransformedDistribution[x op y, {x ~ *DISTRIBUTION_X*[..], y ~ *DISTRIBUTION_Y*[..]}]]
/// Checks `compute_mean`, `compute_median` and `compute_variance` for two
/// `Uniform` operands over `[0, 12]` and `[12, 36]` against precomputed
/// expected values. The `Multiply` variance case is commented out, so it is
/// not exercised here.
#[test]
fn test_unknown_properties_uniform_uniform() {
fn test_unknown_properties_uniform_uniform() -> datafusion_common::Result<()> {
let stat_a = Uniform {
interval: Interval::make(Some(0.), Some(12.0)).unwrap()
interval: Interval::make(Some(0.), Some(12.0))?
};

let stat_b = Uniform {
interval: Interval::make(Some(12.0), Some(36.0)).unwrap()
interval: Interval::make(Some(12.0), Some(36.0))?
};

let test_data: Vec<(Actual, Expected)> = vec![
// mean
(compute_mean(&Plus, &stat_a, &stat_b).unwrap(), Some(Float64(Some(30.)))),
(compute_mean(&Minus, &stat_a, &stat_b).unwrap(), Some(Float64(Some(-18.)))),
(compute_mean(&Multiply, &stat_a, &stat_b).unwrap(), Some(Float64(Some(144.)))),
(compute_mean(&Plus, &stat_a, &stat_b)?, Some(Float64(Some(30.)))),
(compute_mean(&Minus, &stat_a, &stat_b)?, Some(Float64(Some(-18.)))),
(compute_mean(&Multiply, &stat_a, &stat_b)?, Some(Float64(Some(144.)))),

// median
(compute_median(&Plus, &stat_a, &stat_b).unwrap(), Some(Float64(Some(30.)))),
(compute_median(&Minus, &stat_a, &stat_b).unwrap(), Some(Float64(Some(-18.)))),
(compute_median(&Plus, &stat_a, &stat_b)?, Some(Float64(Some(30.)))),
(compute_median(&Minus, &stat_a, &stat_b)?, Some(Float64(Some(-18.)))),
// FYI: median of combined distributions for mul, div and mod ops doesn't exist.

// variance
(compute_variance(&Plus, &stat_a, &stat_b).unwrap(), Some(Float64(Some(60.)))),
(compute_variance(&Minus, &stat_a, &stat_b).unwrap(), Some(Float64(Some(60.)))),
// (compute_variance(&Operator::Multiply, &stat_a, &stat_b).unwrap(), Some(Float64(Some(9216.)))),
(compute_variance(&Plus, &stat_a, &stat_b)?, Some(Float64(Some(60.)))),
(compute_variance(&Minus, &stat_a, &stat_b)?, Some(Float64(Some(60.)))),
// (compute_variance(&Operator::Multiply, &stat_a, &stat_b), Some(Float64(Some(9216.)))),
];

for (actual, expected) in test_data {
assert_eq!(actual, expected);
}

Ok(())
}

/// Checks `compute_range` for the Uniform-Uniform, Unknown-Uniform,
/// Uniform-Unknown and Unknown-Unknown pairs, where a range is always present.
#[test]
fn test_compute_range_where_present() -> datafusion_common::Result<()> {
    // Use distinct, overlapping operands: identical intervals would hide
    // operand-order mistakes, and a divisor range that excludes zero keeps
    // the `Divide` expectation from holding by coincidence.
    let a = Interval::make(Some(0.), Some(12.0))?;
    let b = Interval::make(Some(6.0), Some(36.0))?;

    let unknown = |range: &Interval| Unknown {
        mean: None,
        median: None,
        variance: None,
        range: range.clone(),
    };

    for (stat_a, stat_b) in [
        (Uniform { interval: a.clone() }, Uniform { interval: b.clone() }),
        (unknown(&a), Uniform { interval: b.clone() }),
        (Uniform { interval: a.clone() }, unknown(&b)),
        (unknown(&a), unknown(&b)),
    ] {
        // Operators that `compute_range` forwards to `apply_operator`.
        for op in [Plus, Minus, Multiply, Gt, GtEq, Lt, LtEq] {
            assert_eq!(
                compute_range(&op, &stat_a, &stat_b)?,
                apply_operator(&op, &a, &b)?
            );
        }

        // `compute_range` has no dedicated `Divide` arm, so it falls back to
        // the unbounded `Float64` default rather than mirroring `apply_operator`.
        assert_eq!(
            compute_range(&Divide, &stat_a, &stat_b)?,
            Interval::make_unbounded(&arrow::datatypes::DataType::Float64)?
        );

        assert_eq!(
            compute_range(&Eq, &stat_a, &stat_b)?,
            a.intersect(&b)?.expect("operand ranges overlap")
        );
        assert_eq!(
            compute_range(&NotEq, &stat_a, &stat_b)?,
            Interval::CERTAINLY_FALSE
        );
    }

    Ok(())
}
}

0 comments on commit add838c

Please sign in to comment.