Skip to content

Commit

Permalink
[MINOR]: Add new test for filter pushdown into cross join (#8648)
Browse files Browse the repository at this point in the history
* Initial commit

* Minor changes

* Simplifications

* Update UDF example

* Address review

---------

Co-authored-by: Mehmet Ozan Kabak <[email protected]>
  • Loading branch information
mustafasrepo and ozankabak authored Dec 25, 2023
1 parent 18c7566 commit ec8fd44
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 18 deletions.
1 change: 1 addition & 0 deletions datafusion/optimizer/src/eliminate_cross_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ impl EliminateCrossJoin {
/// 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);'
/// 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z)
/// or (a.x = b.y and b.xx = 200 and a.z=c.z);'
/// 'select ... from a, b where a.x > b.y'
/// For above queries, the join predicate is available in filters and they are moved to
/// join nodes appropriately
/// This fix helps to improve the performance of TPCH Q19. issue#78
Expand Down
12 changes: 9 additions & 3 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -965,11 +965,11 @@ impl PushDownFilter {
}
}

/// Convert cross join to join by pushing down filter predicate to the join condition
/// Converts the given cross join to an inner join with an empty equality
/// predicate and an empty filter condition.
fn convert_cross_join_to_inner_join(cross_join: CrossJoin) -> Result<Join> {
let CrossJoin { left, right, .. } = cross_join;
let join_schema = build_join_schema(left.schema(), right.schema(), &JoinType::Inner)?;
// predicate is given
Ok(Join {
left,
right,
Expand All @@ -982,7 +982,8 @@ fn convert_cross_join_to_inner_join(cross_join: CrossJoin) -> Result<Join> {
})
}

/// Converts the inner join with empty equality predicate and empty filter condition to the cross join
/// Converts the given inner join with an empty equality predicate and an
/// empty filter condition to a cross join.
fn convert_to_cross_join_if_beneficial(plan: LogicalPlan) -> Result<LogicalPlan> {
if let LogicalPlan::Join(join) = &plan {
// Can be converted back to cross join
Expand All @@ -991,6 +992,11 @@ fn convert_to_cross_join_if_beneficial(plan: LogicalPlan) -> Result<LogicalPlan>
.cross_join(join.right.as_ref().clone())?
.build();
}
} else if let LogicalPlan::Filter(filter) = &plan {
let new_input =
convert_to_cross_join_if_beneficial(filter.input.as_ref().clone())?;
return Filter::try_new(filter.predicate.clone(), Arc::new(new_input))
.map(LogicalPlan::Filter);
}
Ok(plan)
}
Expand Down
61 changes: 46 additions & 15 deletions datafusion/sqllogictest/src/test_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,33 @@
// specific language governing permissions and limitations
// under the License.

use async_trait::async_trait;
use std::collections::HashMap;
use std::fs::File;
use std::io::Write;
use std::path::Path;
use std::sync::Arc;

use arrow::array::{
ArrayRef, BinaryArray, Float64Array, Int32Array, LargeBinaryArray, LargeStringArray,
StringArray, TimestampNanosecondArray,
};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
use arrow::record_batch::RecordBatch;
use datafusion::execution::context::SessionState;
use datafusion::logical_expr::Expr;
use datafusion::logical_expr::{create_udf, Expr, ScalarUDF, Volatility};
use datafusion::physical_expr::functions::make_scalar_function;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionConfig;
use datafusion::{
arrow::{
array::{
BinaryArray, Float64Array, Int32Array, LargeBinaryArray, LargeStringArray,
StringArray, TimestampNanosecondArray,
},
datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit},
record_batch::RecordBatch,
},
catalog::{schema::MemorySchemaProvider, CatalogProvider, MemoryCatalogProvider},
datasource::{MemTable, TableProvider, TableType},
prelude::{CsvReadOptions, SessionContext},
};
use datafusion_common::cast::as_float64_array;
use datafusion_common::DataFusionError;

use async_trait::async_trait;
use log::info;
use std::collections::HashMap;
use std::fs::File;
use std::io::Write;
use std::path::Path;
use std::sync::Arc;
use tempfile::TempDir;

/// Context for running tests
Expand Down Expand Up @@ -102,6 +104,8 @@ impl TestContext {
}
"joins.slt" => {
info!("Registering partition table tables");
let example_udf = create_example_udf();
test_ctx.ctx.register_udf(example_udf);
register_partition_table(&mut test_ctx).await;
}
"metadata.slt" => {
Expand Down Expand Up @@ -348,3 +352,30 @@ pub async fn register_metadata_tables(ctx: &SessionContext) {

ctx.register_batch("table_with_metadata", batch).unwrap();
}

/// Create a UDF function named "example". See the `sample_udf.rs` example
/// file for an explanation of the API.
fn create_example_udf() -> ScalarUDF {
let adder = make_scalar_function(|args: &[ArrayRef]| {
let lhs = as_float64_array(&args[0]).expect("cast failed");
let rhs = as_float64_array(&args[1]).expect("cast failed");
let array = lhs
.iter()
.zip(rhs.iter())
.map(|(lhs, rhs)| match (lhs, rhs) {
(Some(lhs), Some(rhs)) => Some(lhs + rhs),
_ => None,
})
.collect::<Float64Array>();
Ok(Arc::new(array) as ArrayRef)
});
create_udf(
"example",
// Expects two f64 values:
vec![DataType::Float64, DataType::Float64],
// Returns an f64 value:
Arc::new(DataType::Float64),
Volatility::Immutable,
adder,
)
}
22 changes: 22 additions & 0 deletions datafusion/sqllogictest/test_files/joins.slt
Original file line number Diff line number Diff line change
Expand Up @@ -3483,6 +3483,28 @@ NestedLoopJoinExec: join_type=Inner, filter=a@0 > a@1
----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true
--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true

# Currently datafusion cannot pushdown filter conditions with scalar UDF into
# cross join.
query TT
EXPLAIN SELECT *
FROM annotated_data as t1, annotated_data as t2
WHERE EXAMPLE(t1.a, t2.a) > 3
----
logical_plan
Filter: example(CAST(t1.a AS Float64), CAST(t2.a AS Float64)) > Float64(3)
--CrossJoin:
----SubqueryAlias: t1
------TableScan: annotated_data projection=[a0, a, b, c, d]
----SubqueryAlias: t2
------TableScan: annotated_data projection=[a0, a, b, c, d]
physical_plan
CoalesceBatchesExec: target_batch_size=2
--FilterExec: example(CAST(a@1 AS Float64), CAST(a@6 AS Float64)) > 3
----CrossJoinExec
------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true
------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true

####
# Config teardown
####
Expand Down

0 comments on commit ec8fd44

Please sign in to comment.