Skip to content

Commit

Permalink
refactor: use a builder-like API
Browse files Browse the repository at this point in the history
  • Loading branch information
wjones127 committed Sep 12, 2023
1 parent bffb137 commit e4427a3
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 50 deletions.
56 changes: 26 additions & 30 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ use crate::simplify_expressions::guarantees::GuaranteeRewriter;
/// This structure handles the API for expression simplification.
///
/// Build one with `new`, optionally attach column guarantees with
/// `with_guarantees`, and then simplify expressions with it.
pub struct ExprSimplifier<S> {
/// Simplification context supplied to `new`. The rewriters borrow it and
/// read its execution properties.
info: S,
/// Guarantees about the values of columns, stored as
/// `(column expression, interval of known possible values)` pairs.
/// Empty by default; replaced wholesale by `with_guarantees` and handed to
/// the `GuaranteeRewriter` during simplification.
guarantees: Vec<(Expr, NullableInterval)>,
}

pub const THRESHOLD_INLINE_INLIST: usize = 3;
Expand All @@ -61,7 +63,10 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
///
/// [`SimplifyContext`]: crate::simplify_expressions::context::SimplifyContext
pub fn new(info: S) -> Self {
Self { info }
Self {
info,
guarantees: vec![],
}
}

/// Simplifies this [`Expr`] as much as possible, evaluating
Expand Down Expand Up @@ -125,6 +130,7 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
let mut simplifier = Simplifier::new(&self.info);
let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?;
let mut or_in_list_simplifier = OrInListSimplifier::new();
let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees);

// TODO iterate until no changes are made during rewrite
// (evaluating constants can enable new simplifications and
Expand All @@ -133,6 +139,7 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
expr.rewrite(&mut const_evaluator)?
.rewrite(&mut simplifier)?
.rewrite(&mut or_in_list_simplifier)?
.rewrite(&mut guarantee_rewriter)?
// run both passes twice to try and minimize simplifications that we missed
.rewrite(&mut const_evaluator)?
.rewrite(&mut simplifier)
Expand All @@ -154,14 +161,14 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
expr.rewrite(&mut expr_rewrite)
}

/// Input guarantees and simplify the expression.
/// Input guarantees about the values of columns.
///
/// The guarantees can simplify expressions. For example, if a column `x` is
/// guaranteed to be `3`, then the expression `x > 1` can be replaced by the
/// literal `true`.
///
/// The guarantees are provided as an iterator of `(Expr, NullableInterval)`
/// pairs, where the [Expr] is a column reference and the [NullableInterval]
/// The guarantees are provided as a `Vec<(Expr, NullableInterval)>`,
/// where the [Expr] is a column reference and the [NullableInterval]
/// is an interval representing the known possible values of that column.
///
/// ```rust
Expand All @@ -184,7 +191,6 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
/// let props = ExecutionProps::new();
/// let context = SimplifyContext::new(&props)
/// .with_schema(schema);
/// let simplifier = ExprSimplifier::new(context);
///
/// // Expression: (x >= 3) AND (y + 2 < 10) AND (z > 5)
/// let expr_x = col("x").gt_eq(lit(3_i64));
Expand All @@ -203,24 +209,15 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
/// // y = 3
/// (col("y"), NullableInterval::from(&ScalarValue::UInt32(Some(3)))),
/// ];
/// let output = simplifier.simplify_with_guarantees(expr, &guarantees).unwrap();
/// let simplifier = ExprSimplifier::new(context).with_guarantees(guarantees);
/// let output = simplifier.simplify(expr).unwrap();
/// // Expression becomes: true AND true AND (z > 5), which simplifies to
/// // z > 5.
/// assert_eq!(output, expr_z);
/// ```
pub fn simplify_with_guarantees<'a>(
&self,
expr: Expr,
guarantees: impl IntoIterator<Item = &'a (Expr, NullableInterval)>,
) -> Result<Expr> {
// Do a simplification pass in case it reveals places where a guarantee
// could be applied.
let expr = self.simplify(expr)?;
let mut rewriter = GuaranteeRewriter::new(guarantees);
let expr = expr.rewrite(&mut rewriter)?;
// Simplify after guarantees are applied, since constant folding should
// now be able to fold more expressions.
self.simplify(expr)
pub fn with_guarantees(mut self, guarantees: Vec<(Expr, NullableInterval)>) -> Self {
self.guarantees = guarantees;
self
}
}

Expand Down Expand Up @@ -2752,16 +2749,15 @@ mod tests {

fn simplify_with_guarantee(
expr: Expr,
guarantees: &[(Expr, NullableInterval)],
guarantees: Vec<(Expr, NullableInterval)>,
) -> Expr {
let schema = expr_test_schema();
let execution_props = ExecutionProps::new();
let simplifier = ExprSimplifier::new(
let mut simplifier = ExprSimplifier::new(
SimplifyContext::new(&execution_props).with_schema(schema),
);
simplifier
.simplify_with_guarantees(expr, guarantees)
.unwrap()
)
.with_guarantees(guarantees);
simplifier.simplify(expr).unwrap()
}

fn expr_test_schema() -> DFSchemaRef {
Expand Down Expand Up @@ -3246,7 +3242,7 @@ mod tests {
(col("c1"), NullableInterval::from(&ScalarValue::Utf8(None))),
];

let output = simplify_with_guarantee(expr.clone(), &guarantees);
let output = simplify_with_guarantee(expr.clone(), guarantees);
assert_eq!(output, lit_bool_null());

// All guaranteed false
Expand All @@ -3266,7 +3262,7 @@ mod tests {
NullableInterval::from(&ScalarValue::Utf8(Some("a".to_string()))),
),
];
let output = simplify_with_guarantee(expr.clone(), &guarantees);
let output = simplify_with_guarantee(expr.clone(), guarantees);
assert_eq!(output, lit(false));

// Guaranteed false or null -> no change.
Expand All @@ -3290,7 +3286,7 @@ mod tests {
},
),
];
let output = simplify_with_guarantee(expr.clone(), &guarantees);
let output = simplify_with_guarantee(expr.clone(), guarantees);
assert_eq!(&output, &expr_x);

// Sufficient true guarantees
Expand All @@ -3304,15 +3300,15 @@ mod tests {
NullableInterval::from(&ScalarValue::UInt32(Some(3))),
),
];
let output = simplify_with_guarantee(expr.clone(), &guarantees);
let output = simplify_with_guarantee(expr.clone(), guarantees);
assert_eq!(output, lit(true));

// Only partially simplify
let guarantees = vec![(
col("c4"),
NullableInterval::from(&ScalarValue::UInt32(Some(3))),
)];
let output = simplify_with_guarantee(expr.clone(), &guarantees);
let output = simplify_with_guarantee(expr.clone(), guarantees);
assert_eq!(&output, &expr_x);
}
}
20 changes: 12 additions & 8 deletions datafusion/optimizer/src/simplify_expressions/guarantees.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ use datafusion_physical_expr::intervals::{Interval, IntervalBound, NullableInter

/// Rewrite expressions to incorporate guarantees.
pub(crate) struct GuaranteeRewriter<'a> {
intervals: HashMap<&'a Expr, &'a NullableInterval>,
guarantees: HashMap<&'a Expr, &'a NullableInterval>,
}

impl<'a> GuaranteeRewriter<'a> {
pub fn new(
guarantees: impl IntoIterator<Item = &'a (Expr, NullableInterval)>,
) -> Self {
Self {
intervals: guarantees.into_iter().map(|(k, v)| (k, v)).collect(),
guarantees: guarantees.into_iter().map(|(k, v)| (k, v)).collect(),
}
}
}
Expand All @@ -41,13 +41,17 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> {
type N = Expr;

fn mutate(&mut self, expr: Expr) -> Result<Expr> {
if self.guarantees.is_empty() {
return Ok(expr);
}

match &expr {
Expr::IsNull(inner) => match self.intervals.get(inner.as_ref()) {
Expr::IsNull(inner) => match self.guarantees.get(inner.as_ref()) {
Some(NullableInterval::Null { .. }) => Ok(lit(true)),
Some(NullableInterval::NotNull { .. }) => Ok(lit(false)),
_ => Ok(expr),
},
Expr::IsNotNull(inner) => match self.intervals.get(inner.as_ref()) {
Expr::IsNotNull(inner) => match self.guarantees.get(inner.as_ref()) {
Some(NullableInterval::Null { .. }) => Ok(lit(false)),
Some(NullableInterval::NotNull { .. }) => Ok(lit(true)),
_ => Ok(expr),
Expand All @@ -59,7 +63,7 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> {
high,
}) => {
if let (Some(interval), Expr::Literal(low), Expr::Literal(high)) = (
self.intervals.get(inner.as_ref()),
self.guarantees.get(inner.as_ref()),
low.as_ref(),
high.as_ref(),
) {
Expand Down Expand Up @@ -112,7 +116,7 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> {
_ => return Ok(expr),
};

if let Some(col_interval) = self.intervals.get(col.as_ref()) {
if let Some(col_interval) = self.guarantees.get(col.as_ref()) {
let result = col_interval.apply_operator(&op, &value.into())?;
if result.is_certainly_true() {
Ok(lit(true))
Expand All @@ -128,7 +132,7 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> {

// Columns (if interval is collapsed to a single value)
Expr::Column(_) => {
if let Some(col_interval) = self.intervals.get(&expr) {
if let Some(col_interval) = self.guarantees.get(&expr) {
if let Some(value) = col_interval.single_value() {
Ok(lit(value))
} else {
Expand All @@ -144,7 +148,7 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> {
list,
negated,
}) => {
if let Some(interval) = self.intervals.get(inner.as_ref()) {
if let Some(interval) = self.guarantees.get(inner.as_ref()) {
// Can remove items from the list that don't match the guarantee
let new_list: Vec<Expr> = list
.iter()
Expand Down
34 changes: 22 additions & 12 deletions datafusion/physical-expr/src/intervals/interval_aritmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ fn calculate_cardinality_based_on_bounds(
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NullableInterval {
/// The value is always null in this interval
///
///
/// This is typed so it can be used in physical expressions, which don't do
/// type coercion.
Null { datatype: DataType },
Expand Down Expand Up @@ -784,7 +784,9 @@ impl From<&ScalarValue> for NullableInterval {
/// Create an interval that represents a single value.
fn from(value: &ScalarValue) -> Self {
if value.is_null() {
Self::Null { datatype: value.get_datatype() }
Self::Null {
datatype: value.get_datatype(),
}
} else {
Self::NotNull {
values: Interval::new(
Expand Down Expand Up @@ -835,7 +837,9 @@ impl NullableInterval {
/// Perform logical negation on a boolean nullable interval.
fn not(&self) -> Result<Self> {
match self {
Self::Null { datatype } => Ok(Self::Null { datatype: datatype.clone() }),
Self::Null { datatype } => Ok(Self::Null {
datatype: datatype.clone(),
}),
Self::MaybeNull { values } => Ok(Self::MaybeNull {
values: values.not()?,
}),
Expand Down Expand Up @@ -921,12 +925,14 @@ impl NullableInterval {
}
_ => Ok(Self::MaybeNull { values }),
}
} else if op.is_comparison_operator() {
Ok(Self::Null {
datatype: DataType::Boolean,
})
} else {
if op.is_comparison_operator() {
Ok(Self::Null { datatype: DataType::Boolean})
} else {
Ok(Self::Null { datatype: self.get_datatype()? })
}
Ok(Self::Null {
datatype: self.get_datatype()?,
})
}
}
}
Expand All @@ -947,7 +953,9 @@ impl NullableInterval {
_ => Ok(Self::MaybeNull { values }),
}
} else {
Ok(Self::Null { datatype: DataType::Boolean })
Ok(Self::Null {
datatype: DataType::Boolean,
})
}
}

Expand All @@ -974,10 +982,12 @@ impl NullableInterval {
/// ```
pub fn single_value(&self) -> Option<ScalarValue> {
match self {
Self::Null { datatype } => Some(ScalarValue::try_from(datatype).unwrap_or(ScalarValue::Null)),
Self::Null { datatype } => {
Some(ScalarValue::try_from(datatype).unwrap_or(ScalarValue::Null))
}
Self::MaybeNull { values } | Self::NotNull { values }
if values.lower.value == values.upper.value &&
!values.lower.is_unbounded() =>
if values.lower.value == values.upper.value
&& !values.lower.is_unbounded() =>
{
Some(values.lower.value.clone())
}
Expand Down

0 comments on commit e4427a3

Please sign in to comment.