Skip to content

Commit

Permalink
Optimize CASE expression for "expr or expr" usage. (#13953)
Browse files Browse the repository at this point in the history
* Apply optimization for ExprOrExpr.

* Implement optimization similar to existing code.

* Add sqllogictest.
  • Loading branch information
aweltsch authored Jan 4, 2025
1 parent 39a69f5 commit 0f4b8b1
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 0 deletions.
84 changes: 84 additions & 0 deletions datafusion/physical-expr/src/expressions/case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ enum EvalMethod {
/// are literal values
/// CASE WHEN condition THEN literal ELSE literal END
ScalarOrScalar,
/// This is a specialization for a specific use case where we can take a fast path
/// if there is just one when/then pair and both the `then` and `else` are expressions
///
/// CASE WHEN condition THEN expression ELSE expression END
ExpressionOrExpression,
}

/// The CASE expression is similar to a series of nested if/else and there are two forms that
Expand Down Expand Up @@ -149,6 +154,8 @@ impl CaseExpr {
&& else_expr.as_ref().unwrap().as_any().is::<Literal>()
{
EvalMethod::ScalarOrScalar
} else if when_then_expr.len() == 1 && else_expr.is_some() {
EvalMethod::ExpressionOrExpression
} else {
EvalMethod::NoExpression
};
Expand Down Expand Up @@ -394,6 +401,43 @@ impl CaseExpr {

Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?))
}

/// Fast path for `CASE WHEN condition THEN expression ELSE expression END`.
///
/// The single WHEN predicate is evaluated once for the whole batch. The THEN
/// expression is then evaluated with the predicate as its selection, the ELSE
/// expression with the negated predicate as its selection, and both results
/// are merged into one output array.
///
/// # Errors
/// Returns an error if the return type cannot be determined, if the WHEN
/// expression does not produce a `BooleanArray`, or if evaluating any of the
/// involved expressions fails.
///
/// # Panics
/// Panics if `else_expr` is `None`. This method is only dispatched for
/// `EvalMethod::ExpressionOrExpression`, which is chosen only when an ELSE
/// expression is present.
fn expr_or_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
    // Fetch the schema once; it is needed both for the return type and for
    // building the cast around the ELSE expression.
    let schema = batch.schema();
    let return_type = self.data_type(&schema)?;

    // evaluate when condition on batch
    let when_value = self.when_then_expr[0].0.evaluate(batch)?;
    let when_value = when_value.into_array(batch.num_rows())?;
    let when_value = as_boolean_array(&when_value).map_err(|e| {
        DataFusionError::Context(
            "WHEN expression did not return a BooleanArray".to_string(),
            Box::new(e),
        )
    })?;

    // Treat 'NULL' as false value. Only build a new mask when nulls are
    // actually present; otherwise borrow the evaluated array as-is.
    let when_value = match when_value.null_count() {
        0 => Cow::Borrowed(when_value),
        _ => Cow::Owned(prep_null_mask_filter(when_value)),
    };

    let then_value = self.when_then_expr[0]
        .1
        .evaluate_selection(batch, &when_value)?
        .into_array(batch.num_rows())?;

    // evaluate else expression on the values not covered by when_value
    let remainder = not(&when_value)?;
    let e = self
        .else_expr
        .as_ref()
        .expect("ExpressionOrExpression is only selected when an ELSE expression is present");
    // keep `else_expr`'s data type and return type consistent; if the cast
    // cannot be constructed, fall back to the unmodified ELSE expression.
    // `return_type` is not used afterwards, so it is moved rather than cloned.
    let expr = try_cast(Arc::clone(e), &schema, return_type)
        .unwrap_or_else(|_| Arc::clone(e));
    let else_ = expr
        .evaluate_selection(batch, &remainder)?
        .into_array(batch.num_rows())?;

    // `remainder` is the mask here: ELSE values where it is true, THEN values
    // everywhere else.
    Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?))
}
}

impl PhysicalExpr for CaseExpr {
Expand Down Expand Up @@ -457,6 +501,7 @@ impl PhysicalExpr for CaseExpr {
self.case_column_or_null(batch)
}
EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch),
EvalMethod::ExpressionOrExpression => self.expr_or_expr(batch),
}
}

Expand Down Expand Up @@ -1174,6 +1219,45 @@ mod tests {
Ok(())
}

/// Verifies that a CASE with exactly one WHEN/THEN pair and a non-literal
/// ELSE branch is planned as `EvalMethod::ExpressionOrExpression` and that
/// evaluating it yields the expected `Int32` values.
#[test]
fn test_expr_or_expr_specialization() -> Result<()> {
    let batch = case_test_batch1()?;
    let schema = batch.schema();

    // CASE WHEN a <= 2 THEN a + 1 ELSE a - 1 END
    let predicate = binary(col("a", &schema)?, Operator::LtEq, lit(2i32), &schema)?;
    let then_branch = binary(col("a", &schema)?, Operator::Plus, lit(1i32), &schema)?;
    let else_branch = binary(col("a", &schema)?, Operator::Minus, lit(1i32), &schema)?;

    let case_expr =
        CaseExpr::try_new(None, vec![(predicate, then_branch)], Some(else_branch))?;

    // Both branches are non-literal expressions, so the specialized
    // evaluation method must have been selected.
    assert!(matches!(
        case_expr.eval_method,
        EvalMethod::ExpressionOrExpression
    ));

    let evaluated = case_expr.evaluate(&batch)?;
    let array = evaluated
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_int32_array(&array).expect("failed to downcast to Int32Array");

    let expected = Int32Array::from(vec![Some(2), Some(1), None, Some(4)]);

    assert_eq!(&expected, actual);
    Ok(())
}

/// Test helper: wraps a [`Column`] built from `name` and `index` in an `Arc`
/// so it can be used wherever an `Arc<dyn PhysicalExpr>` is expected.
fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
    let column = Column::new(name, index);
    Arc::new(column)
}
Expand Down
11 changes: 11 additions & 0 deletions datafusion/sqllogictest/test_files/case.slt
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,14 @@ query I
SELECT CASE arrow_cast([1,2,3], 'FixedSizeList(3, Int64)') WHEN arrow_cast([1,2,3], 'FixedSizeList(3, Int64)') THEN 1 ELSE 0 END;
----
1

# CASE WHEN with single predicate and two non-trivial branches (expr or expr usage)
query I
SELECT CASE WHEN a < 5 THEN a + b ELSE b - NVL(a, 0) END FROM foo
----
3
7
1
NULL
NULL
7

0 comments on commit 0f4b8b1

Please sign in to comment.