From e1c00423840fae9c9a71313b1e63e260de356838 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Sat, 13 Jul 2024 17:34:45 +0800 Subject: [PATCH] fix: make sure JOIN ON expression is boolean type (#11423) * fix: make sure JOIN ON expression is boolean type * Applied to DataFrame * Update datafusion/optimizer/src/analyzer/type_coercion.rs Co-authored-by: Andrew Lamb --------- Co-authored-by: Andrew Lamb --- datafusion/core/src/dataframe/mod.rs | 31 +++++++++++++++++-- .../optimizer/src/analyzer/type_coercion.rs | 17 +++++++++- datafusion/sqllogictest/test_files/join.slt | 12 ++++++- 3 files changed, 55 insertions(+), 5 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 05a08a6378930..c55b7c752765d 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -896,9 +896,8 @@ impl DataFrame { join_type: JoinType, on_exprs: impl IntoIterator<Item = Expr>, ) -> Result<DataFrame> { - let expr = on_exprs.into_iter().reduce(Expr::and); let plan = LogicalPlanBuilder::from(self.plan) - .join_on(right.plan, join_type, expr)? + .join_on(right.plan, join_type, on_exprs)? 
.build()?; Ok(DataFrame { session_state: self.session_state, @@ -1694,7 +1693,7 @@ mod tests { use crate::test_util::{register_aggregate_csv, test_table, test_table_with_name}; use arrow::array::{self, Int32Array}; - use datafusion_common::{Constraint, Constraints}; + use datafusion_common::{Constraint, Constraints, ScalarValue}; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::{ array_agg, cast, create_udf, expr, lit, BuiltInWindowFunction, @@ -2555,6 +2554,32 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_on_filter_datatype() -> Result<()> { + let left = test_table_with_name("a").await?.select_columns(&["c1"])?; + let right = test_table_with_name("b").await?.select_columns(&["c1"])?; + + // JOIN ON untyped NULL + let join = left.clone().join_on( + right.clone(), + JoinType::Inner, + Some(Expr::Literal(ScalarValue::Null)), + )?; + let expected_plan = "CrossJoin:\ + \n TableScan: a projection=[c1], full_filters=[Boolean(NULL)]\ + \n TableScan: b projection=[c1]"; + assert_eq!(expected_plan, format!("{:?}", join.into_optimized_plan()?)); + + // JOIN ON expression must be boolean type + let join = left.join_on(right, JoinType::Inner, Some(lit("TRUE")))?; + let expected = join.into_optimized_plan().unwrap_err(); + assert_eq!( + expected.strip_backtrace(), + "type_coercion\ncaused by\nError during planning: Join condition must be boolean type, but got Utf8" + ); + Ok(()) + } + #[tokio::test] async fn join_ambiguous_filter() -> Result<()> { let left = test_table_with_name("a") diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 3cab474df84e0..80a8c864e4311 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -127,7 +127,7 @@ impl<'a> TypeCoercionRewriter<'a> { Self { schema } } - /// Coerce join equality expressions + /// Coerce join equality expressions and join filter /// /// Joins must be treated 
specially as their equality expressions are stored /// as a parallel list of left and right expressions, rather than a single @@ -151,9 +151,24 @@ impl<'a> TypeCoercionRewriter<'a> { }) .collect::<Result<Vec<_>>>()?; + // Join filter must be boolean + join.filter = join + .filter + .map(|expr| self.coerce_join_filter(expr)) + .transpose()?; + Ok(LogicalPlan::Join(join)) } + fn coerce_join_filter(&self, expr: Expr) -> Result<Expr> { + let expr_type = expr.get_type(self.schema)?; + match expr_type { + DataType::Boolean => Ok(expr), + DataType::Null => expr.cast_to(&DataType::Boolean, self.schema), + other => plan_err!("Join condition must be boolean type, but got {other:?}"), + } + } + fn coerce_binary_op( &self, left: Expr, diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index 12cb8b3985c76..efebba1779cf7 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -988,7 +988,6 @@ statement ok DROP TABLE department -# Test issue: https://github.com/apache/datafusion/issues/11269 statement ok CREATE TABLE t1 (v0 BIGINT) AS VALUES (-503661263); @@ -998,11 +997,22 @@ CREATE TABLE t2 (v0 DOUBLE) AS VALUES (-1.663563947387); statement ok CREATE TABLE t3 (v0 DOUBLE) AS VALUES (0.05112015193508901); +# Test issue: https://github.com/apache/datafusion/issues/11269 query RR SELECT t3.v0, t2.v0 FROM t1,t2,t3 WHERE t3.v0 >= t1.v0; ---- 0.051120151935 -1.663563947387 +# Test issue: https://github.com/apache/datafusion/issues/11414 +query IRR +SELECT * FROM t1 INNER JOIN t2 ON NULL RIGHT JOIN t3 ON TRUE; +---- +NULL NULL 0.051120151935 + +# ON expression must be boolean type +query error DataFusion error: type_coercion\ncaused by\nError during planning: Join condition must be boolean type, but got Utf8 +SELECT * FROM t1 INNER JOIN t2 ON 'TRUE' + statement ok DROP TABLE t1;