Skip to content

Commit

Permalink
Remove logical cross join in planning (#12985)
Browse files Browse the repository at this point in the history
* Remove logical cross join in planning

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* Implement some more substrait pieces

* Update datafusion/core/src/physical_planner.rs

Co-authored-by: Oleks V <[email protected]>

* Remove incorrect comment

---------

Co-authored-by: Oleks V <[email protected]>
  • Loading branch information
Dandandan and comphead authored Oct 18, 2024
1 parent 10af8a7 commit 34bd823
Show file tree
Hide file tree
Showing 18 changed files with 117 additions and 110 deletions.
22 changes: 13 additions & 9 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ use datafusion_expr::expr::{
use datafusion_expr::expr_rewriter::unnormalize_cols;
use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
use datafusion_expr::{
DescribeTable, DmlStatement, Extension, Filter, RecursiveQuery, SortExpr,
DescribeTable, DmlStatement, Extension, Filter, JoinType, RecursiveQuery, SortExpr,
StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp,
};
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
Expand Down Expand Up @@ -1045,14 +1045,18 @@ impl DefaultPhysicalPlanner {
session_state.config_options().optimizer.prefer_hash_join;

let join: Arc<dyn ExecutionPlan> = if join_on.is_empty() {
// there is no equal join condition, use the nested loop join
// TODO optimize the plan, and use the config of `target_partitions` and `repartition_joins`
Arc::new(NestedLoopJoinExec::try_new(
physical_left,
physical_right,
join_filter,
join_type,
)?)
if join_filter.is_none() && matches!(join_type, JoinType::Inner) {
// use a cross join when there are no join conditions and no join filter is set
Arc::new(CrossJoinExec::new(physical_left, physical_right))
} else {
// there is no equal join condition, use the nested loop join
Arc::new(NestedLoopJoinExec::try_new(
physical_left,
physical_right,
join_filter,
join_type,
)?)
}
} else if session_state.config().target_partitions() > 1
&& session_state.config().repartition_joins()
&& !prefer_hash_join
Expand Down
11 changes: 8 additions & 3 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ use crate::expr_rewriter::{
rewrite_sort_cols_by_aggs,
};
use crate::logical_plan::{
Aggregate, Analyze, CrossJoin, Distinct, DistinctOn, EmptyRelation, Explain, Filter,
Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare,
Aggregate, Analyze, Distinct, DistinctOn, EmptyRelation, Explain, Filter, Join,
JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare,
Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values,
Window,
};
Expand Down Expand Up @@ -950,9 +950,14 @@ impl LogicalPlanBuilder {
pub fn cross_join(self, right: LogicalPlan) -> Result<Self> {
let join_schema =
build_join_schema(self.plan.schema(), right.schema(), &JoinType::Inner)?;
Ok(Self::new(LogicalPlan::CrossJoin(CrossJoin {
Ok(Self::new(LogicalPlan::Join(Join {
left: self.plan,
right: Arc::new(right),
on: vec![],
filter: None,
join_type: JoinType::Inner,
join_constraint: JoinConstraint::On,
null_equals_null: false,
schema: DFSchemaRef::new(join_schema),
})))
}
Expand Down
6 changes: 6 additions & 0 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ pub enum LogicalPlan {
Join(Join),
/// Apply Cross Join to two logical plans.
/// This is used to implement SQL `CROSS JOIN`
/// Deprecated: use [LogicalPlan::Join] with an empty `on` and no filter instead
CrossJoin(CrossJoin),
/// Repartitions the input based on a partitioning scheme. This is
/// used to add parallelism and is sometimes referred to as an
Expand Down Expand Up @@ -1873,6 +1874,11 @@ impl LogicalPlan {
.as_ref()
.map(|expr| format!(" Filter: {expr}"))
.unwrap_or_else(|| "".to_string());
let join_type = if filter.is_none() && keys.is_empty() && matches!(join_type, JoinType::Inner) {
"Cross".to_string()
} else {
join_type.to_string()
};
match join_constraint {
JoinConstraint::On => {
write!(
Expand Down
25 changes: 15 additions & 10 deletions datafusion/optimizer/src/eliminate_cross_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{internal_err, Result};
use datafusion_expr::expr::{BinaryExpr, Expr};
use datafusion_expr::logical_plan::{
CrossJoin, Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
};
use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
use datafusion_expr::{build_join_schema, ExprSchemable, Operator};
Expand All @@ -51,7 +51,7 @@ impl EliminateCrossJoin {
/// Looks like this:
/// ```text
/// Filter(a.x = b.y AND b.xx = 100)
/// CrossJoin
/// Cross Join
/// TableScan a
/// TableScan b
/// ```
Expand Down Expand Up @@ -351,10 +351,15 @@ fn find_inner_join(
&JoinType::Inner,
)?);

Ok(LogicalPlan::CrossJoin(CrossJoin {
Ok(LogicalPlan::Join(Join {
left: Arc::new(left_input),
right: Arc::new(right),
schema: join_schema,
on: vec![],
filter: None,
join_type: JoinType::Inner,
join_constraint: JoinConstraint::On,
null_equals_null: false,
}))
}

Expand Down Expand Up @@ -513,7 +518,7 @@ mod tests {

let expected = vec![
"Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
];
Expand Down Expand Up @@ -601,7 +606,7 @@ mod tests {

let expected = vec![
"Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
];
Expand All @@ -627,7 +632,7 @@ mod tests {

let expected = vec![
"Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
];
Expand Down Expand Up @@ -843,7 +848,7 @@ mod tests {

let expected = vec![
"Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
Expand Down Expand Up @@ -924,7 +929,7 @@ mod tests {
" TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
" Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]",
];
Expand Down Expand Up @@ -999,7 +1004,7 @@ mod tests {
"Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
" Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
Expand Down Expand Up @@ -1238,7 +1243,7 @@ mod tests {

let expected = vec![
"Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
];
Expand Down
26 changes: 1 addition & 25 deletions datafusion/optimizer/src/eliminate_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use datafusion_common::{Result, ScalarValue};
use datafusion_expr::JoinType::Inner;
use datafusion_expr::{
logical_plan::{EmptyRelation, LogicalPlan},
CrossJoin, Expr,
Expr,
};

/// Eliminates joins when join condition is false.
Expand Down Expand Up @@ -54,13 +54,6 @@ impl OptimizerRule for EliminateJoin {
match plan {
LogicalPlan::Join(join) if join.join_type == Inner && join.on.is_empty() => {
match join.filter {
Some(Expr::Literal(ScalarValue::Boolean(Some(true)))) => {
Ok(Transformed::yes(LogicalPlan::CrossJoin(CrossJoin {
left: join.left,
right: join.right,
schema: join.schema,
})))
}
Some(Expr::Literal(ScalarValue::Boolean(Some(false)))) => Ok(
Transformed::yes(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
Expand Down Expand Up @@ -105,21 +98,4 @@ mod tests {
let expected = "EmptyRelation";
assert_optimized_plan_equal(plan, expected)
}

#[test]
fn join_on_true() -> Result<()> {
let plan = LogicalPlanBuilder::empty(false)
.join_on(
LogicalPlanBuilder::empty(false).build()?,
Inner,
Some(lit(true)),
)?
.build()?;

let expected = "\
CrossJoin:\
\n EmptyRelation\
\n EmptyRelation";
assert_optimized_plan_equal(plan, expected)
}
}
4 changes: 2 additions & 2 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1727,7 +1727,7 @@ mod tests {
.build()?;

let expected = "Projection: test.a, test1.d\
\n CrossJoin:\
\n Cross Join: \
\n Projection: test.a, test.b, test.c\
\n TableScan: test, full_filters=[test.a = Int32(1)]\
\n Projection: test1.d, test1.e, test1.f\
Expand All @@ -1754,7 +1754,7 @@ mod tests {
.build()?;

let expected = "Projection: test.a, test1.a\
\n CrossJoin:\
\n Cross Join: \
\n Projection: test.a, test.b, test.c\
\n TableScan: test, full_filters=[test.a = Int32(1)]\
\n Projection: test1.a, test1.b, test1.c\
Expand Down
7 changes: 3 additions & 4 deletions datafusion/optimizer/src/push_down_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,9 @@ fn push_down_join(mut join: Join, limit: usize) -> Transformed<Join> {

let (left_limit, right_limit) = if is_no_join_condition(&join) {
match join.join_type {
Left | Right | Full => (Some(limit), Some(limit)),
Left | Right | Full | Inner => (Some(limit), Some(limit)),
LeftAnti | LeftSemi => (Some(limit), None),
RightAnti | RightSemi => (None, Some(limit)),
Inner => (None, None),
}
} else {
match join.join_type {
Expand Down Expand Up @@ -1116,7 +1115,7 @@ mod test {
.build()?;

let expected = "Limit: skip=0, fetch=1000\
\n CrossJoin:\
\n Cross Join: \
\n Limit: skip=0, fetch=1000\
\n TableScan: test, fetch=1000\
\n Limit: skip=0, fetch=1000\
Expand All @@ -1136,7 +1135,7 @@ mod test {
.build()?;

let expected = "Limit: skip=1000, fetch=1000\
\n CrossJoin:\
\n Cross Join: \
\n Limit: skip=0, fetch=2000\
\n TableScan: test, fetch=2000\
\n Limit: skip=0, fetch=2000\
Expand Down
4 changes: 3 additions & 1 deletion datafusion/sql/src/relation/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
.build()
}
}
JoinConstraint::None => not_impl_err!("NONE constraint is not supported"),
JoinConstraint::None => LogicalPlanBuilder::from(left)
.join_on(right, join_type, [])?
.build(),
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ fn roundtrip_crossjoin() -> Result<()> {
.unwrap();

let expected = "Projection: j1.j1_id, j2.j2_string\
\n Inner Join: Filter: Boolean(true)\
\n Cross Join: \
\n TableScan: j1\
\n TableScan: j2";

Expand Down
Loading

0 comments on commit 34bd823

Please sign in to comment.