Skip to content

Commit

Permalink
Remove some Expr clones in EliminateCrossJoin (3%-5% faster planning) (#10430)
Browse files Browse the repository at this point in the history

* Remove some Expr clones in `EliminateCrossJoin`

* Apply suggestions from code review

Co-authored-by: comphead <[email protected]>

* fix

---------

Co-authored-by: comphead <[email protected]>
  • Loading branch information
alamb and comphead authored May 11, 2024
1 parent 6d413a4 commit 1eff714
Show file tree
Hide file tree
Showing 3 changed files with 291 additions and 73 deletions.
123 changes: 50 additions & 73 deletions datafusion/optimizer/src/eliminate_cross_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
// under the License.

//! [`EliminateCrossJoin`] converts `CROSS JOIN` to `INNER JOIN` if join predicates are available.
use std::collections::HashSet;
use std::sync::Arc;

use crate::{utils, OptimizerConfig, OptimizerRule};

use crate::join_key_set::JoinKeySet;
use datafusion_common::{plan_err, Result};
use datafusion_expr::expr::{BinaryExpr, Expr};
use datafusion_expr::logical_plan::{
Expand Down Expand Up @@ -55,7 +55,7 @@ impl OptimizerRule for EliminateCrossJoin {
plan: &LogicalPlan,
config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
let mut possible_join_keys: Vec<(Expr, Expr)> = vec![];
let mut possible_join_keys = JoinKeySet::new();
let mut all_inputs: Vec<LogicalPlan> = vec![];
let parent_predicate = match plan {
LogicalPlan::Filter(filter) => {
Expand All @@ -76,7 +76,7 @@ impl OptimizerRule for EliminateCrossJoin {
extract_possible_join_keys(
&filter.predicate,
&mut possible_join_keys,
)?;
);
Some(&filter.predicate)
}
_ => {
Expand All @@ -101,7 +101,7 @@ impl OptimizerRule for EliminateCrossJoin {
};

// Join keys are handled locally:
let mut all_join_keys = HashSet::<(Expr, Expr)>::new();
let mut all_join_keys = JoinKeySet::new();
let mut left = all_inputs.remove(0);
while !all_inputs.is_empty() {
left = find_inner_join(
Expand Down Expand Up @@ -131,7 +131,7 @@ impl OptimizerRule for EliminateCrossJoin {
.map(|f| Some(LogicalPlan::Filter(f)))
} else {
// Remove join expressions from filter:
match remove_join_expressions(predicate, &all_join_keys)? {
match remove_join_expressions(predicate.clone(), &all_join_keys) {
Some(filter_expr) => Filter::try_new(filter_expr, Arc::new(left))
.map(|f| Some(LogicalPlan::Filter(f))),
_ => Ok(Some(left)),
Expand All @@ -150,7 +150,7 @@ impl OptimizerRule for EliminateCrossJoin {
/// Returns a boolean indicating whether the flattening was successful.
fn try_flatten_join_inputs(
plan: &LogicalPlan,
possible_join_keys: &mut Vec<(Expr, Expr)>,
possible_join_keys: &mut JoinKeySet,
all_inputs: &mut Vec<LogicalPlan>,
) -> Result<bool> {
let children = match plan {
Expand All @@ -160,7 +160,7 @@ fn try_flatten_join_inputs(
// issue: https://github.com/apache/datafusion/issues/4844
return Ok(false);
}
possible_join_keys.extend(join.on.clone());
possible_join_keys.insert_all(join.on.iter());
vec![&join.left, &join.right]
}
LogicalPlan::CrossJoin(join) => {
Expand Down Expand Up @@ -204,8 +204,8 @@ fn try_flatten_join_inputs(
fn find_inner_join(
left_input: &LogicalPlan,
rights: &mut Vec<LogicalPlan>,
possible_join_keys: &[(Expr, Expr)],
all_join_keys: &mut HashSet<(Expr, Expr)>,
possible_join_keys: &JoinKeySet,
all_join_keys: &mut JoinKeySet,
) -> Result<LogicalPlan> {
for (i, right_input) in rights.iter().enumerate() {
let mut join_keys = vec![];
Expand All @@ -228,7 +228,7 @@ fn find_inner_join(

// Found one or more matching join keys
if !join_keys.is_empty() {
all_join_keys.extend(join_keys.clone());
all_join_keys.insert_all(join_keys.iter());
let right_input = rights.remove(i);
let join_schema = Arc::new(build_join_schema(
left_input.schema(),
Expand Down Expand Up @@ -265,90 +265,67 @@ fn find_inner_join(
}))
}

/// Collects into `accum` the join-key pairs of `vec1` that also occur in `vec2`.
///
/// Two pairs match when they are equal component-wise, or equal after swapping
/// the two sides of one of them. The matching `vec1` pair is cloned and pushed
/// once per matching `vec2` entry; existing contents of `accum` are kept.
fn intersect(
accum: &mut Vec<(Expr, Expr)>,
vec1: &[(Expr, Expr)],
vec2: &[(Expr, Expr)],
) {
// Nothing can match when either side is empty.
if !(vec1.is_empty() || vec2.is_empty()) {
for x1 in vec1.iter() {
for x2 in vec2.iter() {
// Same order, or `x1` reversed relative to `x2`.
if x1.0 == x2.0 && x1.1 == x2.1 || x1.1 == x2.0 && x1.0 == x2.1 {
accum.push((x1.0.clone(), x1.1.clone()));
}
}
}
}
}

/// Extract join keys from a WHERE clause
fn extract_possible_join_keys(expr: &Expr, accum: &mut Vec<(Expr, Expr)>) -> Result<()> {
fn extract_possible_join_keys(expr: &Expr, join_keys: &mut JoinKeySet) {
if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr {
match op {
Operator::Eq => {
// Ensure that we don't add the same Join keys multiple times
if !(accum.contains(&(*left.clone(), *right.clone()))
|| accum.contains(&(*right.clone(), *left.clone())))
{
accum.push((*left.clone(), *right.clone()));
}
// insert handles ensuring we don't add the same Join keys multiple times
join_keys.insert(left, right);
}
Operator::And => {
extract_possible_join_keys(left, accum)?;
extract_possible_join_keys(right, accum)?
extract_possible_join_keys(left, join_keys);
extract_possible_join_keys(right, join_keys)
}
// Fix for issue #78: join predicates inside an OR expr are also pulled up properly.
Operator::Or => {
let mut left_join_keys = vec![];
let mut right_join_keys = vec![];
let mut left_join_keys = JoinKeySet::new();
let mut right_join_keys = JoinKeySet::new();

extract_possible_join_keys(left, &mut left_join_keys)?;
extract_possible_join_keys(right, &mut right_join_keys)?;
extract_possible_join_keys(left, &mut left_join_keys);
extract_possible_join_keys(right, &mut right_join_keys);

intersect(accum, &left_join_keys, &right_join_keys)
join_keys.insert_intersection(left_join_keys, right_join_keys)
}
_ => (),
};
}
Ok(())
}

/// Remove join expressions from a filter expression
/// Returns Some() when there are remaining predicates in filter_expr
/// Returns None otherwise
fn remove_join_expressions(
expr: &Expr,
join_keys: &HashSet<(Expr, Expr)>,
) -> Result<Option<Expr>> {
///
/// # Returns
/// * `Some()` when there are remaining predicates in filter_expr
/// * `None` otherwise
fn remove_join_expressions(expr: Expr, join_keys: &JoinKeySet) -> Option<Expr> {
match expr {
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
match op {
Operator::Eq => {
if join_keys.contains(&(*left.clone(), *right.clone()))
|| join_keys.contains(&(*right.clone(), *left.clone()))
{
Ok(None)
} else {
Ok(Some(expr.clone()))
}
}
// Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
Operator::And | Operator::Or => {
let l = remove_join_expressions(left, join_keys)?;
let r = remove_join_expressions(right, join_keys)?;
match (l, r) {
(Some(ll), Some(rr)) => Ok(Some(Expr::BinaryExpr(
BinaryExpr::new(Box::new(ll), *op, Box::new(rr)),
))),
(Some(ll), _) => Ok(Some(ll)),
(_, Some(rr)) => Ok(Some(rr)),
_ => Ok(None),
}
}
_ => Ok(Some(expr.clone())),
Expr::BinaryExpr(BinaryExpr {
left,
op: Operator::Eq,
right,
}) if join_keys.contains(&left, &right) => {
// was a join key, so remove it
None
}
// Fix for issue #78: join predicates inside an OR expr are also pulled up properly.
Expr::BinaryExpr(BinaryExpr { left, op, right })
if matches!(op, Operator::And | Operator::Or) =>
{
let l = remove_join_expressions(*left, join_keys);
let r = remove_join_expressions(*right, join_keys);
match (l, r) {
(Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new(
Box::new(ll),
op,
Box::new(rr),
))),
(Some(ll), _) => Some(ll),
(_, Some(rr)) => Some(rr),
_ => None,
}
}
_ => Ok(Some(expr.clone())),

_ => Some(expr),
}
}

Expand Down
Loading

0 comments on commit 1eff714

Please sign in to comment.