Skip to content

Commit

Permalink
pr feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
wjones127 committed Sep 12, 2023
1 parent e4427a3 commit f4e8680
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 52 deletions.
26 changes: 12 additions & 14 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ use crate::simplify_expressions::guarantees::GuaranteeRewriter;
/// This structure handles API for expression simplification
pub struct ExprSimplifier<S> {
info: S,
///
/// Guarantees about the values of columns. This is provided by the user
/// in [ExprSimplifier::with_guarantees()].
guarantees: Vec<(Expr, NullableInterval)>,
}

Expand Down Expand Up @@ -207,7 +208,7 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
/// }
/// ),
/// // y = 3
/// (col("y"), NullableInterval::from(&ScalarValue::UInt32(Some(3)))),
/// (col("y"), NullableInterval::from(ScalarValue::UInt32(Some(3)))),
/// ];
/// let simplifier = ExprSimplifier::new(context).with_guarantees(guarantees);
/// let output = simplifier.simplify(expr).unwrap();
Expand Down Expand Up @@ -2753,7 +2754,7 @@ mod tests {
) -> Expr {
let schema = expr_test_schema();
let execution_props = ExecutionProps::new();
let mut simplifier = ExprSimplifier::new(
let simplifier = ExprSimplifier::new(
SimplifyContext::new(&execution_props).with_schema(schema),
)
.with_guarantees(guarantees);
Expand Down Expand Up @@ -3234,12 +3235,9 @@ mod tests {

// All guaranteed null
let guarantees = vec![
(col("c3"), NullableInterval::from(&ScalarValue::Int64(None))),
(
col("c4"),
NullableInterval::from(&ScalarValue::UInt32(None)),
),
(col("c1"), NullableInterval::from(&ScalarValue::Utf8(None))),
(col("c3"), NullableInterval::from(ScalarValue::Int64(None))),
(col("c4"), NullableInterval::from(ScalarValue::UInt32(None))),
(col("c1"), NullableInterval::from(ScalarValue::Utf8(None))),
];

let output = simplify_with_guarantee(expr.clone(), guarantees);
Expand All @@ -3255,11 +3253,11 @@ mod tests {
),
(
col("c4"),
NullableInterval::from(&ScalarValue::UInt32(Some(9))),
NullableInterval::from(ScalarValue::UInt32(Some(9))),
),
(
col("c1"),
NullableInterval::from(&ScalarValue::Utf8(Some("a".to_string()))),
NullableInterval::from(ScalarValue::Utf8(Some("a".to_string()))),
),
];
let output = simplify_with_guarantee(expr.clone(), guarantees);
Expand Down Expand Up @@ -3293,11 +3291,11 @@ mod tests {
let guarantees = vec![
(
col("c3"),
NullableInterval::from(&ScalarValue::Int64(Some(9))),
NullableInterval::from(ScalarValue::Int64(Some(9))),
),
(
col("c4"),
NullableInterval::from(&ScalarValue::UInt32(Some(3))),
NullableInterval::from(ScalarValue::UInt32(Some(3))),
),
];
let output = simplify_with_guarantee(expr.clone(), guarantees);
Expand All @@ -3306,7 +3304,7 @@ mod tests {
// Only partially simplify
let guarantees = vec![(
col("c4"),
NullableInterval::from(&ScalarValue::UInt32(Some(3))),
NullableInterval::from(ScalarValue::UInt32(Some(3))),
)];
let output = simplify_with_guarantee(expr.clone(), guarantees);
assert_eq!(&output, &expr_x);
Expand Down
52 changes: 27 additions & 25 deletions datafusion/optimizer/src/simplify_expressions/guarantees.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,22 @@

//! Simplifier implementation for [ExprSimplifier::simplify_with_guarantees()][crate::simplify_expressions::expr_simplifier::ExprSimplifier::simplify_with_guarantees].
use datafusion_common::{tree_node::TreeNodeRewriter, DataFusionError, Result};
use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr, Operator};
use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr};
use std::collections::HashMap;

use datafusion_physical_expr::intervals::{Interval, IntervalBound, NullableInterval};

/// Rewrite expressions to incorporate guarantees.
///
/// Guarantees are a mapping from an expression (which currently is always a
/// column reference) to a [NullableInterval]. The interval represents the known
/// possible values of the column. Using these known values, expressions are
/// rewritten so they can be simplified using [ConstEvaluator] and [Simplifier].
///
/// For example, if we know that a column is not null and has values in the
/// range [1, 10), we can rewrite `x IS NULL` to `false` or `x < 10` to `true`.
///
/// See a full example in [ExprSimplifier::with_guarantees()].
pub(crate) struct GuaranteeRewriter<'a> {
guarantees: HashMap<&'a Expr, &'a NullableInterval>,
}
Expand Down Expand Up @@ -89,17 +99,9 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> {
}

Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
// Check if this is a comparison
match op {
Operator::Eq
| Operator::NotEq
| Operator::Lt
| Operator::LtEq
| Operator::Gt
| Operator::GtEq
| Operator::IsDistinctFrom
| Operator::IsNotDistinctFrom => {}
_ => return Ok(expr),
// We only support comparisons for now
if !op.is_comparison_operator() {
return Ok(expr);
};

// Check if this is a comparison between a column and literal
Expand All @@ -117,7 +119,8 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> {
};

if let Some(col_interval) = self.guarantees.get(col.as_ref()) {
let result = col_interval.apply_operator(&op, &value.into())?;
let result =
col_interval.apply_operator(&op, &value.clone().into())?;
if result.is_certainly_true() {
Ok(lit(true))
} else if result.is_certainly_false() {
Expand Down Expand Up @@ -154,16 +157,14 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> {
.iter()
.filter_map(|expr| {
if let Expr::Literal(item) = expr {
match interval.contains(&NullableInterval::from(item)) {
match interval
.contains(&NullableInterval::from(item.clone()))
{
// If we know for certain the value isn't in the column's interval,
// we can skip checking it.
Ok(NullableInterval::NotNull { values })
if values == Interval::CERTAINLY_FALSE =>
{
None
}
Err(err) => Some(Err(err)),
_ => Some(Ok(expr.clone())),
Ok(interval) if interval.is_certainly_false() => None,
Ok(_) => Some(Ok(expr.clone())),
Err(e) => Some(Err(e)),
}
} else {
Some(Ok(expr.clone()))
Expand Down Expand Up @@ -192,7 +193,7 @@ mod tests {

use arrow::datatypes::DataType;
use datafusion_common::{tree_node::TreeNode, ScalarValue};
use datafusion_expr::{col, lit};
use datafusion_expr::{col, lit, Operator};

#[test]
fn test_null_handling() {
Expand Down Expand Up @@ -270,8 +271,10 @@ mod tests {
(col("x").eq(lit(0)), false),
(col("x").not_eq(lit(0)), true),
(col("x").between(lit(2), lit(5)), true),
(col("x").between(lit(2), lit(3)), true),
(col("x").between(lit(5), lit(10)), false),
(col("x").not_between(lit(2), lit(5)), false),
(col("x").not_between(lit(2), lit(3)), false),
(col("x").not_between(lit(5), lit(10)), true),
(
Expr::BinaryExpr(BinaryExpr {
Expand Down Expand Up @@ -451,9 +454,8 @@ mod tests {
ScalarValue::Decimal128(Some(1000), 19, 2),
];

for scalar in &scalars {
let guarantees = vec![(col("x"), NullableInterval::from(scalar))];
dbg!(&guarantees);
for scalar in scalars {
let guarantees = vec![(col("x"), NullableInterval::from(scalar.clone()))];
let mut rewriter = GuaranteeRewriter::new(guarantees.iter());

let output = col("x").rewrite(&mut rewriter).unwrap();
Expand Down
27 changes: 14 additions & 13 deletions datafusion/physical-expr/src/intervals/interval_aritmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,7 @@ fn calculate_cardinality_based_on_bounds(
/// # Examples
///
/// ```
/// use arrow::datatypes::DataType;
/// use datafusion_physical_expr::intervals::{Interval, NullableInterval};
/// use datafusion_common::ScalarValue;
///
Expand All @@ -741,10 +742,10 @@ fn calculate_cardinality_based_on_bounds(
/// };
///
/// // {NULL}
/// NullableInterval::Null;
/// NullableInterval::Null { datatype: DataType::Int32 };
///
/// // {4}
/// NullableInterval::from(&ScalarValue::Int32(Some(4)));
/// NullableInterval::from(ScalarValue::Int32(Some(4)));
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NullableInterval {
Expand Down Expand Up @@ -780,9 +781,9 @@ impl Display for NullableInterval {
}
}

impl From<&ScalarValue> for NullableInterval {
impl From<ScalarValue> for NullableInterval {
/// Create an interval that represents a single value.
fn from(value: &ScalarValue) -> Self {
fn from(value: ScalarValue) -> Self {
if value.is_null() {
Self::Null {
datatype: value.get_datatype(),
Expand All @@ -791,7 +792,7 @@ impl From<&ScalarValue> for NullableInterval {
Self::NotNull {
values: Interval::new(
IntervalBound::new(value.clone(), false),
IntervalBound::new(value.clone(), false),
IntervalBound::new(value, false),
),
}
}
Expand Down Expand Up @@ -859,18 +860,18 @@ impl NullableInterval {
/// use datafusion_physical_expr::intervals::{Interval, NullableInterval};
///
/// // 4 > 3 -> true
/// let lhs = NullableInterval::from(&ScalarValue::Int32(Some(4)));
/// let rhs = NullableInterval::from(&ScalarValue::Int32(Some(3)));
/// let lhs = NullableInterval::from(ScalarValue::Int32(Some(4)));
/// let rhs = NullableInterval::from(ScalarValue::Int32(Some(3)));
/// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap();
/// assert_eq!(result, NullableInterval::from(&ScalarValue::Boolean(Some(true))));
/// assert_eq!(result, NullableInterval::from(ScalarValue::Boolean(Some(true))));
///
/// // [1, 3) > NULL -> NULL
/// let lhs = NullableInterval::NotNull {
/// values: Interval::make(Some(1), Some(3), (false, true)),
/// };
/// let rhs = NullableInterval::from(&ScalarValue::Int32(None));
/// let rhs = NullableInterval::from(ScalarValue::Int32(None));
/// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap();
/// assert_eq!(result.single_value(), Some(ScalarValue::Null));
/// assert_eq!(result.single_value(), Some(ScalarValue::Boolean(None)));
///
/// // [1, 3] > [2, 4] -> [false, true]
/// let lhs = NullableInterval::NotNull {
Expand Down Expand Up @@ -969,11 +970,11 @@ impl NullableInterval {
/// use datafusion_common::ScalarValue;
/// use datafusion_physical_expr::intervals::{Interval, NullableInterval};
///
/// let interval = NullableInterval::from(&ScalarValue::Int32(Some(4)));
/// let interval = NullableInterval::from(ScalarValue::Int32(Some(4)));
/// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(Some(4))));
///
/// let interval = NullableInterval::from(&ScalarValue::Int32(None));
/// assert_eq!(interval.single_value(), Some(ScalarValue::Null));
/// let interval = NullableInterval::from(ScalarValue::Int32(None));
/// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(None)));
///
/// let interval = NullableInterval::MaybeNull {
/// values: Interval::make(Some(1), Some(4), (false, true)),
Expand Down

0 comments on commit f4e8680

Please sign in to comment.