Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: make sure JOIN ON expression is boolean type #11423

Merged
merged 3 commits into from
Jul 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -896,9 +896,8 @@ impl DataFrame {
join_type: JoinType,
on_exprs: impl IntoIterator<Item = Expr>,
) -> Result<DataFrame> {
let expr = on_exprs.into_iter().reduce(Expr::and);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's redundant because the reduce operation will be performed inside the join_on function.

let plan = LogicalPlanBuilder::from(self.plan)
.join_on(right.plan, join_type, expr)?
.join_on(right.plan, join_type, on_exprs)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
Expand Down Expand Up @@ -1694,7 +1693,7 @@ mod tests {
use crate::test_util::{register_aggregate_csv, test_table, test_table_with_name};

use arrow::array::{self, Int32Array};
use datafusion_common::{Constraint, Constraints};
use datafusion_common::{Constraint, Constraints, ScalarValue};
use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::{
array_agg, cast, create_udf, expr, lit, BuiltInWindowFunction,
Expand Down Expand Up @@ -2555,6 +2554,32 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn join_on_filter_datatype() -> Result<()> {
let left = test_table_with_name("a").await?.select_columns(&["c1"])?;
let right = test_table_with_name("b").await?.select_columns(&["c1"])?;

// JOIN ON untyped NULL
let join = left.clone().join_on(
right.clone(),
JoinType::Inner,
Some(Expr::Literal(ScalarValue::Null)),
)?;
let expected_plan = "CrossJoin:\
\n TableScan: a projection=[c1], full_filters=[Boolean(NULL)]\
\n TableScan: b projection=[c1]";
assert_eq!(expected_plan, format!("{:?}", join.into_optimized_plan()?));

// JOIN ON expression must be boolean type
let join = left.join_on(right, JoinType::Inner, Some(lit("TRUE")))?;
let expected = join.into_optimized_plan().unwrap_err();
assert_eq!(
expected.strip_backtrace(),
"type_coercion\ncaused by\nError during planning: Join condition must be boolean type, but got Utf8"
);
Ok(())
}

#[tokio::test]
async fn join_ambiguous_filter() -> Result<()> {
let left = test_table_with_name("a")
Expand Down
17 changes: 16 additions & 1 deletion datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ impl<'a> TypeCoercionRewriter<'a> {
Self { schema }
}

/// Coerce join equality expressions
/// Coerce join equality expressions and join filter
///
/// Joins must be treated specially as their equality expressions are stored
/// as a parallel list of left and right expressions, rather than a single
Expand All @@ -151,9 +151,24 @@ impl<'a> TypeCoercionRewriter<'a> {
})
.collect::<Result<Vec<_>>>()?;

// Join filter must be boolean
join.filter = join
.filter
.map(|expr| self.coerce_join_filter(expr))
.transpose()?;

Ok(LogicalPlan::Join(join))
}

fn coerce_join_filter(&self, expr: Expr) -> Result<Expr> {
let expr_type = expr.get_type(self.schema)?;
match expr_type {
DataType::Boolean => Ok(expr),
DataType::Null => expr.cast_to(&DataType::Boolean, self.schema),
other => plan_err!("Join condition must be boolean type, but got {other:?}"),
}
}

fn coerce_binary_op(
&self,
left: Expr,
Expand Down
12 changes: 11 additions & 1 deletion datafusion/sqllogictest/test_files/join.slt
Original file line number Diff line number Diff line change
Expand Up @@ -988,7 +988,6 @@ statement ok
DROP TABLE department


# Test issue: https://github.com/apache/datafusion/issues/11269
statement ok
CREATE TABLE t1 (v0 BIGINT) AS VALUES (-503661263);

Expand All @@ -998,11 +997,22 @@ CREATE TABLE t2 (v0 DOUBLE) AS VALUES (-1.663563947387);
statement ok
CREATE TABLE t3 (v0 DOUBLE) AS VALUES (0.05112015193508901);

# Test issue: https://github.com/apache/datafusion/issues/11269
query RR
SELECT t3.v0, t2.v0 FROM t1,t2,t3 WHERE t3.v0 >= t1.v0;
----
0.051120151935 -1.663563947387

# Test issue: https://github.com/apache/datafusion/issues/11414
query IRR
SELECT * FROM t1 INNER JOIN t2 ON NULL RIGHT JOIN t3 ON TRUE;
----
NULL NULL 0.051120151935

# ON expression must be boolean type
query error DataFusion error: type_coercion\ncaused by\nError during planning: Join condition must be boolean type, but got Utf8
SELECT * FROM t1 INNER JOIN t2 ON 'TRUE'
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This query is available in PostgreSQL, but not available in Spark.
If we plan to support it, we can directly cast the ON expression to boolean without checking its current type.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the error is clear and is a reasonable behavior. If it is important for someone's usecase, perhaps we can add support for it then

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the structure you have now implemented (in the analyzer) makes it pretty easy to support if desired


statement ok
DROP TABLE t1;

Expand Down