Skip to content

Commit

Permalink
Move filtered SMJ Full filtered join out of join_partial phase (#13369)
Browse files Browse the repository at this point in the history

* Move filtered SMJ Full filtered join out of `join_partial` phase

* Move filtered SMJ Full filtered join out of `join_partial` phase

* Move filtered SMJ Full filtered join out of `join_partial` phase
  • Loading branch information
comphead authored Nov 13, 2024
1 parent 5467a28 commit fd092e0
Show file tree
Hide file tree
Showing 3 changed files with 254 additions and 158 deletions.
43 changes: 21 additions & 22 deletions datafusion/core/tests/fuzz_cases/join_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ use datafusion::physical_plan::joins::{
};
use datafusion::physical_plan::memory::MemoryExec;

use crate::fuzz_cases::join_fuzz::JoinTestType::{HjSmj, NljHj};
use datafusion::prelude::{SessionConfig, SessionContext};
use test_utils::stagger_batch_with_seed;

Expand Down Expand Up @@ -96,7 +97,7 @@ async fn test_inner_join_1k_filtered() {
JoinType::Inner,
Some(Box::new(col_lt_col_filter)),
)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.run_test(&[HjSmj, NljHj], false)
.await
}

Expand All @@ -108,7 +109,7 @@ async fn test_inner_join_1k() {
JoinType::Inner,
None,
)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.run_test(&[HjSmj, NljHj], false)
.await
}

Expand All @@ -120,7 +121,7 @@ async fn test_left_join_1k() {
JoinType::Left,
None,
)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.run_test(&[HjSmj, NljHj], false)
.await
}

Expand All @@ -132,7 +133,7 @@ async fn test_left_join_1k_filtered() {
JoinType::Left,
Some(Box::new(col_lt_col_filter)),
)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.run_test(&[HjSmj, NljHj], false)
.await
}

Expand All @@ -144,7 +145,7 @@ async fn test_right_join_1k() {
JoinType::Right,
None,
)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.run_test(&[HjSmj, NljHj], false)
.await
}

Expand All @@ -156,7 +157,7 @@ async fn test_right_join_1k_filtered() {
JoinType::Right,
Some(Box::new(col_lt_col_filter)),
)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.run_test(&[HjSmj, NljHj], false)
.await
}

Expand All @@ -168,21 +169,19 @@ async fn test_full_join_1k() {
JoinType::Full,
None,
)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.run_test(&[HjSmj, NljHj], false)
.await
}

#[tokio::test]
// flaky for HjSmj case
// https://github.com/apache/datafusion/issues/12359
async fn test_full_join_1k_filtered() {
JoinFuzzTestCase::new(
make_staggered_batches(1000),
make_staggered_batches(1000),
JoinType::Full,
Some(Box::new(col_lt_col_filter)),
)
.run_test(&[JoinTestType::NljHj], false)
.run_test(&[NljHj, HjSmj], false)
.await
}

Expand All @@ -194,7 +193,7 @@ async fn test_semi_join_1k() {
JoinType::LeftSemi,
None,
)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.run_test(&[HjSmj, NljHj], false)
.await
}

Expand All @@ -206,7 +205,7 @@ async fn test_semi_join_1k_filtered() {
JoinType::LeftSemi,
Some(Box::new(col_lt_col_filter)),
)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.run_test(&[HjSmj, NljHj], false)
.await
}

Expand All @@ -218,7 +217,7 @@ async fn test_anti_join_1k() {
JoinType::LeftAnti,
None,
)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.run_test(&[HjSmj, NljHj], false)
.await
}

Expand All @@ -230,7 +229,7 @@ async fn test_anti_join_1k_filtered() {
JoinType::LeftAnti,
Some(Box::new(col_lt_col_filter)),
)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.run_test(&[HjSmj, NljHj], false)
.await
}

Expand All @@ -242,7 +241,7 @@ async fn test_left_mark_join_1k() {
JoinType::LeftMark,
None,
)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.run_test(&[HjSmj, NljHj], false)
.await
}

Expand All @@ -254,7 +253,7 @@ async fn test_left_mark_join_1k_filtered() {
JoinType::LeftMark,
Some(Box::new(col_lt_col_filter)),
)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.run_test(&[HjSmj, NljHj], false)
.await
}

Expand Down Expand Up @@ -512,8 +511,8 @@ impl JoinFuzzTestCase {
nlj_formatted_sorted.sort_unstable();

if debug
&& ((join_tests.contains(&JoinTestType::NljHj) && nlj_rows != hj_rows)
|| (join_tests.contains(&JoinTestType::HjSmj) && smj_rows != hj_rows))
&& ((join_tests.contains(&NljHj) && nlj_rows != hj_rows)
|| (join_tests.contains(&HjSmj) && smj_rows != hj_rows))
{
let fuzz_debug = "fuzz_test_debug";
std::fs::remove_dir_all(fuzz_debug).unwrap_or(());
Expand All @@ -533,7 +532,7 @@ impl JoinFuzzTestCase {
"input2",
);

if join_tests.contains(&JoinTestType::NljHj) && nlj_rows != hj_rows {
if join_tests.contains(&NljHj) && nlj_rows != hj_rows {
println!("=============== HashJoinExec ==================");
hj_formatted_sorted.iter().for_each(|s| println!("{}", s));
println!("=============== NestedLoopJoinExec ==================");
Expand All @@ -551,7 +550,7 @@ impl JoinFuzzTestCase {
);
}

if join_tests.contains(&JoinTestType::HjSmj) && smj_rows != hj_rows {
if join_tests.contains(&HjSmj) && smj_rows != hj_rows {
println!("=============== HashJoinExec ==================");
hj_formatted_sorted.iter().for_each(|s| println!("{}", s));
println!("=============== SortMergeJoinExec ==================");
Expand All @@ -570,7 +569,7 @@ impl JoinFuzzTestCase {
}
}

if join_tests.contains(&JoinTestType::NljHj) {
if join_tests.contains(&NljHj) {
let err_msg_rowcnt = format!("NestedLoopJoinExec and HashJoinExec produced different row counts, batch_size: {}", batch_size);
assert_eq!(nlj_rows, hj_rows, "{}", err_msg_rowcnt.as_str());

Expand All @@ -591,7 +590,7 @@ impl JoinFuzzTestCase {
}
}

if join_tests.contains(&JoinTestType::HjSmj) {
if join_tests.contains(&HjSmj) {
let err_msg_row_cnt = format!("HashJoinExec and SortMergeJoinExec produced different row counts, batch_size: {}", &batch_size);
assert_eq!(hj_rows, smj_rows, "{}", err_msg_row_cnt.as_str());

Expand Down
Loading

0 comments on commit fd092e0

Please sign in to comment.