Commit

Add integration test, implement conversion into Bernoulli distribution for Eq and NotEq
Fly-Style committed Jan 26, 2025
1 parent a1bbfce commit d9c2d83
Showing 9 changed files with 415 additions and 65 deletions.
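
The new test_evaluate_statistics_bernoulli below asserts that comparing two uniformly distributed expressions with Eq or NotEq yields a Bernoulli distribution whose parameter reflects how much the two input ranges overlap. The following standalone sketch reproduces the asserted 0.2 / 0.8 values using plain f64 bounds instead of DataFusion's Interval and ScalarValue types; the eq_probability helper is illustrative only and is not part of this commit.

/// Illustrative sketch: estimate P(x == y) for x ~ Uniform(a) and y ~ Uniform(b)
/// as width(a ∩ b) / width(combined range of a and b). This reproduces the
/// 0.2 / 0.8 parameters asserted in the new test; it is not DataFusion code.
fn eq_probability(a: (f64, f64), b: (f64, f64)) -> f64 {
    let overlap = (a.1.min(b.1) - a.0.max(b.0)).max(0.0); // width of a ∩ b
    let hull = a.1.max(b.1) - a.0.min(b.0);               // width of the combined range
    if hull == 0.0 { 0.0 } else { overlap / hull }
}

fn main() {
    let a = (0.0, 6.0);  // interval of column "a" in the test
    let b = (4.0, 10.0); // interval of column "b" in the test
    let p_eq = eq_probability(a, b);
    assert!((p_eq - 0.2).abs() < 1e-12);       // Eq    -> Bernoulli { p: 0.2 }
    assert!((1.0 - p_eq - 0.8).abs() < 1e-12); // NotEq -> Bernoulli { p: 0.8 }
}
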
4 changes: 2 additions & 2 deletions datafusion/expr-common/src/interval_arithmetic.rs
@@ -1149,7 +1149,7 @@ fn next_value_helper<const INC: bool>(value: ScalarValue) -> ScalarValue {

/// Returns the greater of the given interval bounds. Assumes that a `NULL`
/// value represents `NEG_INF`.
fn max_of_bounds(first: &ScalarValue, second: &ScalarValue) -> ScalarValue {
pub fn max_of_bounds(first: &ScalarValue, second: &ScalarValue) -> ScalarValue {
if !first.is_null() && (second.is_null() || first >= second) {
first.clone()
} else {
@@ -1159,7 +1159,7 @@ fn max_of_bounds(first: &ScalarValue, second: &ScalarValue) -> ScalarValue {

/// Returns the lesser of the given interval bounds. Assumes that a `NULL`
/// value represents `INF`.
fn min_of_bounds(first: &ScalarValue, second: &ScalarValue) -> ScalarValue {
pub fn min_of_bounds(first: &ScalarValue, second: &ScalarValue) -> ScalarValue {
if !first.is_null() && (second.is_null() || first <= second) {
first.clone()
} else {
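
A usage sketch for the two bound helpers made public above, illustrating the documented NULL-as-infinity convention. The import paths are assumed from the crate layout and are not confirmed by this diff.

// Sketch under assumed import paths; behavior follows the function bodies shown above.
use datafusion_common::ScalarValue;
use datafusion_expr_common::interval_arithmetic::{max_of_bounds, min_of_bounds};

fn main() {
    let unbounded = ScalarValue::Float64(None); // a NULL bound
    let three = ScalarValue::Float64(Some(3.0));

    // max_of_bounds reads NULL as NEG_INF, so the concrete bound wins.
    assert_eq!(max_of_bounds(&unbounded, &three), three);

    // min_of_bounds reads NULL as INF, so the concrete bound wins again.
    assert_eq!(min_of_bounds(&three, &unbounded), three);
}
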
3 changes: 0 additions & 3 deletions datafusion/physical-expr-common/Cargo.toml
@@ -28,9 +28,6 @@ license = { workspace = true }
authors = { workspace = true }
rust-version = { workspace = true }

[features]
stats_v2 = []

[lints]
workspace = true

1 change: 0 additions & 1 deletion datafusion/physical-expr-common/src/lib.rs
@@ -28,7 +28,6 @@ pub mod binary_view_map;
pub mod datum;
pub mod physical_expr;
pub mod sort_expr;
// #[cfg(feature = "stats_v2")]
pub mod stats;
pub mod tree_node;
pub mod utils;
1 change: 0 additions & 1 deletion datafusion/physical-expr-common/src/stats.rs
@@ -236,7 +236,6 @@ impl StatisticsV2 {
}

#[cfg(test)]
// #[cfg(all(test, feature = "stats_v2"))]
mod tests {
use crate::stats::StatisticsV2;
use arrow::datatypes::DataType;
3 changes: 0 additions & 3 deletions datafusion/physical-expr/Cargo.toml
@@ -31,9 +31,6 @@ rust-version = { workspace = true }
[lints]
workspace = true

[features]
stats_v2 = []

[lib]
name = "datafusion_physical_expr"
path = "src/lib.rs"
214 changes: 176 additions & 38 deletions datafusion/physical-expr/src/expressions/binary.rs
@@ -519,9 +519,13 @@ impl PhysicalExpr for BinaryExpr {
);
}

println!(
"evaluate_statistics: {:?} {:?} {:?}",
left_stat, self.op, right_stat
);

// TODO, to think: maybe, we can separate also Unknown + Unknown
// just for clarity and better reader understanding, how exactly
// mean, median, variance and range are computed, if it is possible.
// just for clarity and better reader understanding.
match &self.op {
Operator::Plus | Operator::Minus | Operator::Multiply => {
match (left_stat, right_stat) {
@@ -977,6 +981,8 @@ mod tests {
Ok(())
}

//region tests

// runs an end-to-end test of physical type coercion:
// 1. construct a record batch with two columns of type A and B
// (*_ARRAY is the Rust Arrow array type, and *_TYPE is the DataType of the elements)
@@ -4523,10 +4529,11 @@
)
.unwrap();
}
//endregion

//region evaluate_statistics and propagate_statistics test

fn binary_expr(
pub fn binary_expr(
left: Arc<dyn PhysicalExpr>,
op: Operator,
right: Arc<dyn PhysicalExpr>,
@@ -4541,7 +4548,6 @@
)
}

// TODO: generate the test case(s) with the macros
/// Test for Uniform-Uniform, Unknown-Uniform, Uniform-Unknown and Unknown-Unknown evaluation.
#[test]
fn test_evaluate_statistics_combination_of_range_holders() -> Result<()> {
@@ -4629,7 +4635,57 @@
}

#[test]
fn test_propagate_statistics_uniform_uniform_arithmetic() -> Result<()> {
fn test_evaluate_statistics_bernoulli() -> Result<()> {
let schema = &Schema::new(vec![
Field::new("a", DataType::Float64, false),
Field::new("b", DataType::Float64, false),
]);
let a: Arc<dyn PhysicalExpr> = Arc::new(Column::new("a", 0));
let b: Arc<dyn PhysicalExpr> = Arc::new(Column::new("b", 1));
let eq: Arc<dyn PhysicalExpr> = Arc::new(binary_expr(
Arc::clone(&a),
Operator::Eq,
Arc::clone(&b),
schema,
)?);
let neq: Arc<dyn PhysicalExpr> = Arc::new(binary_expr(
Arc::clone(&a),
Operator::NotEq,
Arc::clone(&b),
schema,
)?);

let left_stat = &Uniform {
interval: Interval::make(Some(0.0), Some(6.0))?,
};
let right_stat = &Uniform {
interval: Interval::make(Some(4.0), Some(10.0))?,
};

// Intervals: [0, 6] and [4, 10].
// Their intersection is [4, 6], which covers 2/10 of the combined range [0, 10],
// so the probability of the values being equal is 20%, i.e. 0.2 as the Bernoulli parameter.
assert_eq!(
eq.evaluate_statistics(&[left_stat, right_stat])?,
Bernoulli {
p: ScalarValue::Float64(Some(0.2))
}
);

// Consequently, the probability of the values NOT being equal is the complement,
// 80%, i.e. 0.8 as the Bernoulli parameter.
assert_eq!(
neq.evaluate_statistics(&[left_stat, right_stat])?,
Bernoulli {
p: ScalarValue::Float64(Some(0.8))
}
);

Ok(())
}

#[test]
fn test_propagate_statistics_combination_of_range_holders_arithmetic() -> Result<()> {
let schema = &Schema::new(vec![Field::new("a", DataType::Float64, false)]);
let a: Arc<dyn PhysicalExpr> = Arc::new(Column::new("a", 0));
let b = lit(ScalarValue::Float64(Some(12.0)));
@@ -4638,17 +4694,54 @@
let right_interval = Interval::make(Some(12.0), Some(36.0))?;

let parent = Uniform {
interval: Interval::make(Some(-432.), Some(432.))?
interval: Interval::make(Some(-432.), Some(432.))?,
};
let children = &vec![
Uniform {
interval: left_interval.clone(),
},
Uniform {
interval: right_interval.clone(),
},
let children: Vec<Vec<StatisticsV2>> = vec![
vec![
Uniform {
interval: left_interval.clone(),
},
Uniform {
interval: right_interval.clone(),
},
],
vec![
Unknown {
mean: Some(ScalarValue::Float64(Some(6.))),
median: Some(ScalarValue::Float64(Some(6.))),
variance: None,
range: left_interval.clone(),
},
Uniform {
interval: right_interval.clone(),
},
],
vec![
Uniform {
interval: left_interval.clone(),
},
Unknown {
mean: Some(ScalarValue::Float64(Some(12.))),
median: Some(ScalarValue::Float64(Some(12.))),
variance: None,
range: right_interval.clone(),
},
],
vec![
Unknown {
mean: Some(ScalarValue::Float64(Some(6.))),
median: Some(ScalarValue::Float64(Some(6.))),
variance: None,
range: left_interval.clone(),
},
Unknown {
mean: Some(ScalarValue::Float64(Some(12.))),
median: Some(ScalarValue::Float64(Some(12.))),
variance: None,
range: right_interval.clone(),
},
],
];
let ref_view: Vec<&StatisticsV2> = children.iter().collect();

let ops = vec![
Operator::Plus,
@@ -4657,53 +4750,98 @@
Operator::Divide,
];

for op in ops {
let expr = binary_expr(Arc::clone(&a), op, Arc::clone(&b), schema)?;
assert_eq!(
expr.propagate_statistics(&parent, &ref_view)?,
Some(vec![
new_unknown_from_interval(&left_interval)?,
new_unknown_from_interval(&right_interval)?
])
);
for child_view in children {
let ref_view: Vec<&StatisticsV2> = child_view.iter().collect();
for op in &ops {
let expr = binary_expr(Arc::clone(&a), *op, Arc::clone(&b), schema)?;
assert_eq!(
expr.propagate_statistics(&parent, &ref_view)?,
Some(vec![
new_unknown_from_interval(&left_interval)?,
new_unknown_from_interval(&right_interval)?
])
);
}
}
Ok(())
}

#[test]
fn test_propagate_statistics_uniform_uniform_comparison() -> Result<()> {
fn test_propagate_statistics_combination_of_range_holders_comparison() -> Result<()> {
let schema = &Schema::new(vec![Field::new("a", DataType::Float64, false)]);
let a: Arc<dyn PhysicalExpr> = Arc::new(Column::new("a", 0));
let b = lit(ScalarValue::Float64(Some(12.0)));

let left_interval = Interval::make(Some(0.0), Some(12.0))?;
let right_interval = Interval::make(Some(6.0), Some(18.0))?;

for parent_interval in [Interval::CERTAINLY_TRUE, Interval::CERTAINLY_FALSE] {
for parent_interval in [
Interval::CERTAINLY_TRUE, /*, Interval::CERTAINLY_FALSE*/
] {
let parent = Uniform {
interval: Interval::CERTAINLY_TRUE
interval: parent_interval,
};
let children = &vec![
Uniform {
interval: left_interval.clone(),
},
Uniform {
interval: right_interval.clone(),
},
let children: Vec<Vec<StatisticsV2>> = vec![
vec![
Uniform {
interval: left_interval.clone(),
},
Uniform {
interval: right_interval.clone(),
},
],
vec![
Unknown {
mean: Some(ScalarValue::Float64(Some(6.))),
median: Some(ScalarValue::Float64(Some(6.))),
variance: None,
range: left_interval.clone(),
},
Uniform {
interval: right_interval.clone(),
},
],
vec![
Uniform {
interval: left_interval.clone(),
},
Unknown {
mean: Some(ScalarValue::Float64(Some(12.))),
median: Some(ScalarValue::Float64(Some(12.))),
variance: None,
range: right_interval.clone(),
},
],
vec![
Unknown {
mean: Some(ScalarValue::Float64(Some(6.))),
median: Some(ScalarValue::Float64(Some(6.))),
variance: None,
range: left_interval.clone(),
},
Unknown {
mean: Some(ScalarValue::Float64(Some(12.))),
median: Some(ScalarValue::Float64(Some(12.))),
variance: None,
range: right_interval.clone(),
},
],
];
let ref_view: Vec<&StatisticsV2> = children.iter().collect();

let ops = vec![
Operator::Eq,
Operator::Gt,
Operator::GtEq,
Operator::Lt,
Operator::LtEq
Operator::LtEq,
];

for op in ops {
let expr = binary_expr(Arc::clone(&a), op, Arc::clone(&b), schema)?;
assert!(expr.propagate_statistics(&parent, &ref_view)?.is_some());
for child_view in children {
let ref_view: Vec<&StatisticsV2> = child_view.iter().collect();
for op in &ops {
let expr = binary_expr(Arc::clone(&a), *op, Arc::clone(&b), schema)?;
assert!(expr.propagate_statistics(&parent, &ref_view)?.is_some());
}
}
}
Ok(())
