Skip to content

Commit

Permalink
refactor: simplify code of eliminate_cross_join.rs (#7561)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwener authored Sep 16, 2023
1 parent 61ed374 commit 93f78b2
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 79 deletions.
34 changes: 10 additions & 24 deletions datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -959,31 +959,17 @@ pub fn find_valid_equijoin_key_pair(
return Ok(None);
}

let l_is_left =
check_all_columns_from_schema(&left_using_columns, left_schema.clone())?;
let r_is_right =
check_all_columns_from_schema(&right_using_columns, right_schema.clone())?;

let r_is_left_and_l_is_right = || {
let result =
check_all_columns_from_schema(&right_using_columns, left_schema.clone())?
&& check_all_columns_from_schema(
&left_using_columns,
right_schema.clone(),
)?;

Result::<_>::Ok(result)
};

let join_key_pair = match (l_is_left, r_is_right) {
(true, true) => Some((left_key.clone(), right_key.clone())),
(_, _) if r_is_left_and_l_is_right()? => {
Some((right_key.clone(), left_key.clone()))
}
_ => None,
};
if check_all_columns_from_schema(&left_using_columns, left_schema.clone())?
&& check_all_columns_from_schema(&right_using_columns, right_schema.clone())?
{
return Ok(Some((left_key.clone(), right_key.clone())));
} else if check_all_columns_from_schema(&right_using_columns, left_schema)?
&& check_all_columns_from_schema(&left_using_columns, right_schema)?
{
return Ok(Some((right_key.clone(), left_key.clone())));
}

Ok(join_key_pair)
Ok(None)
}

/// Creates a detailed error message for a function with wrong signature.
Expand Down
93 changes: 38 additions & 55 deletions datafusion/optimizer/src/eliminate_cross_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use datafusion_expr::logical_plan::{
CrossJoin, Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
};
use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
use datafusion_expr::{and, build_join_schema, or, ExprSchemable, Operator};
use datafusion_expr::{build_join_schema, ExprSchemable, Operator};

#[derive(Default)]
pub struct EliminateCrossJoin;
Expand Down Expand Up @@ -61,14 +61,11 @@ impl OptimizerRule for EliminateCrossJoin {
let mut possible_join_keys: Vec<(Expr, Expr)> = vec![];
let mut all_inputs: Vec<LogicalPlan> = vec![];
let did_flat_successfully = match &input {
LogicalPlan::Join(join) if (join.join_type == JoinType::Inner) => {
try_flatten_join_inputs(
&input,
&mut possible_join_keys,
&mut all_inputs,
)?
}
LogicalPlan::CrossJoin(_) => try_flatten_join_inputs(
LogicalPlan::Join(Join {
join_type: JoinType::Inner,
..
})
| LogicalPlan::CrossJoin(_) => try_flatten_join_inputs(
&input,
&mut possible_join_keys,
&mut all_inputs,
Expand Down Expand Up @@ -164,16 +161,11 @@ fn try_flatten_join_inputs(

for child in children.iter() {
match *child {
LogicalPlan::Join(left_join) => {
if left_join.join_type == JoinType::Inner {
if !try_flatten_join_inputs(child, possible_join_keys, all_inputs)? {
return Ok(false);
}
} else {
all_inputs.push((*child).clone());
}
}
LogicalPlan::CrossJoin(_) => {
LogicalPlan::Join(Join {
join_type: JoinType::Inner,
..
})
| LogicalPlan::CrossJoin(_) => {
if !try_flatten_join_inputs(child, possible_join_keys, all_inputs)? {
return Ok(false);
}
Expand Down Expand Up @@ -202,13 +194,10 @@ fn find_inner_join(
)?;

// Save join keys
match key_pair {
Some((valid_l, valid_r)) => {
if can_hash(&valid_l.get_type(left_input.schema())?) {
join_keys.push((valid_l, valid_r));
}
if let Some((valid_l, valid_r)) = key_pair {
if can_hash(&valid_l.get_type(left_input.schema())?) {
join_keys.push((valid_l, valid_r));
}
_ => continue,
}
}

Expand Down Expand Up @@ -303,39 +292,33 @@ fn remove_join_expressions(
join_keys: &HashSet<(Expr, Expr)>,
) -> Result<Option<Expr>> {
match expr {
Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
Operator::Eq => {
if join_keys.contains(&(*left.clone(), *right.clone()))
|| join_keys.contains(&(*right.clone(), *left.clone()))
{
Ok(None)
} else {
Ok(Some(expr.clone()))
}
}
Operator::And => {
let l = remove_join_expressions(left, join_keys)?;
let r = remove_join_expressions(right, join_keys)?;
match (l, r) {
(Some(ll), Some(rr)) => Ok(Some(and(ll, rr))),
(Some(ll), _) => Ok(Some(ll)),
(_, Some(rr)) => Ok(Some(rr)),
_ => Ok(None),
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
match op {
Operator::Eq => {
if join_keys.contains(&(*left.clone(), *right.clone()))
|| join_keys.contains(&(*right.clone(), *left.clone()))
{
Ok(None)
} else {
Ok(Some(expr.clone()))
}
}
}
// Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
Operator::Or => {
let l = remove_join_expressions(left, join_keys)?;
let r = remove_join_expressions(right, join_keys)?;
match (l, r) {
(Some(ll), Some(rr)) => Ok(Some(or(ll, rr))),
(Some(ll), _) => Ok(Some(ll)),
(_, Some(rr)) => Ok(Some(rr)),
_ => Ok(None),
// Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
Operator::And | Operator::Or => {
let l = remove_join_expressions(left, join_keys)?;
let r = remove_join_expressions(right, join_keys)?;
match (l, r) {
(Some(ll), Some(rr)) => Ok(Some(Expr::BinaryExpr(
BinaryExpr::new(Box::new(ll), *op, Box::new(rr)),
))),
(Some(ll), _) => Ok(Some(ll)),
(_, Some(rr)) => Ok(Some(rr)),
_ => Ok(None),
}
}
_ => Ok(Some(expr.clone())),
}
_ => Ok(Some(expr.clone())),
},
}
_ => Ok(Some(expr.clone())),
}
}
Expand Down

0 comments on commit 93f78b2

Please sign in to comment.