From fd092e0579da5121ac75130f6b2e92da47034308 Mon Sep 17 00:00:00 2001 From: Oleks V Date: Tue, 12 Nov 2024 22:22:32 -0800 Subject: [PATCH] Move filtered SMJ Full filtered join out of `join_partial` phase (#13369) * Move filtered SMJ Full filtered join out of `join_partial` phase * Move filtered SMJ Full filtered join out of `join_partial` phase * Move filtered SMJ Full filtered join out of `join_partial` phase --- datafusion/core/tests/fuzz_cases/join_fuzz.rs | 43 ++- .../src/joins/sort_merge_join.rs | 336 ++++++++++++------ .../test_files/sort_merge_join.slt | 33 +- 3 files changed, 254 insertions(+), 158 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index d7a3460e4987..cf1742a30e66 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -41,6 +41,7 @@ use datafusion::physical_plan::joins::{ }; use datafusion::physical_plan::memory::MemoryExec; +use crate::fuzz_cases::join_fuzz::JoinTestType::{HjSmj, NljHj}; use datafusion::prelude::{SessionConfig, SessionContext}; use test_utils::stagger_batch_with_seed; @@ -96,7 +97,7 @@ async fn test_inner_join_1k_filtered() { JoinType::Inner, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -108,7 +109,7 @@ async fn test_inner_join_1k() { JoinType::Inner, None, ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -120,7 +121,7 @@ async fn test_left_join_1k() { JoinType::Left, None, ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -132,7 +133,7 @@ async fn test_left_join_1k_filtered() { JoinType::Left, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -144,7 +145,7 @@ async fn test_right_join_1k() { JoinType::Right, None, ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -156,7 +157,7 @@ async fn test_right_join_1k_filtered() { JoinType::Right, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -168,13 +169,11 @@ async fn test_full_join_1k() { JoinType::Full, None, ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } #[tokio::test] -// flaky for HjSmj case -// https://github.com/apache/datafusion/issues/12359 async fn test_full_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -182,7 +181,7 @@ async fn test_full_join_1k_filtered() { JoinType::Full, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::NljHj], false) + .run_test(&[NljHj, HjSmj], false) .await } @@ -194,7 +193,7 @@ async fn test_semi_join_1k() { JoinType::LeftSemi, None, ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -206,7 +205,7 @@ async fn test_semi_join_1k_filtered() { JoinType::LeftSemi, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -218,7 +217,7 @@ async fn test_anti_join_1k() { JoinType::LeftAnti, None, ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -230,7 +229,7 @@ async fn test_anti_join_1k_filtered() { JoinType::LeftAnti, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -242,7 +241,7 @@ async fn test_left_mark_join_1k() { JoinType::LeftMark, None, ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -254,7 +253,7 @@ async fn test_left_mark_join_1k_filtered() { JoinType::LeftMark, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -512,8 +511,8 @@ impl JoinFuzzTestCase { nlj_formatted_sorted.sort_unstable(); if debug - && ((join_tests.contains(&JoinTestType::NljHj) && nlj_rows != hj_rows) - || (join_tests.contains(&JoinTestType::HjSmj) && smj_rows != hj_rows)) + && ((join_tests.contains(&NljHj) && nlj_rows != hj_rows) + || (join_tests.contains(&HjSmj) && smj_rows != hj_rows)) { let fuzz_debug = "fuzz_test_debug"; std::fs::remove_dir_all(fuzz_debug).unwrap_or(()); @@ -533,7 +532,7 @@ impl JoinFuzzTestCase { "input2", ); - if join_tests.contains(&JoinTestType::NljHj) && nlj_rows != hj_rows { + if join_tests.contains(&NljHj) && nlj_rows != hj_rows { println!("=============== HashJoinExec =================="); hj_formatted_sorted.iter().for_each(|s| println!("{}", s)); println!("=============== NestedLoopJoinExec =================="); @@ -551,7 +550,7 @@ impl JoinFuzzTestCase { ); } - if join_tests.contains(&JoinTestType::HjSmj) && smj_rows != hj_rows { + if join_tests.contains(&HjSmj) && smj_rows != hj_rows { println!("=============== HashJoinExec =================="); hj_formatted_sorted.iter().for_each(|s| println!("{}", s)); println!("=============== SortMergeJoinExec =================="); @@ -570,7 +569,7 @@ impl JoinFuzzTestCase { } } - if join_tests.contains(&JoinTestType::NljHj) { + if join_tests.contains(&NljHj) { let err_msg_rowcnt = format!("NestedLoopJoinExec and HashJoinExec produced different row counts, batch_size: {}", batch_size); assert_eq!(nlj_rows, hj_rows, "{}", err_msg_rowcnt.as_str()); @@ -591,7 +590,7 @@ impl JoinFuzzTestCase { } } - if join_tests.contains(&JoinTestType::HjSmj) { + if join_tests.contains(&HjSmj) { let err_msg_row_cnt = format!("HashJoinExec and SortMergeJoinExec produced different row counts, batch_size: {}", &batch_size); assert_eq!(hj_rows, smj_rows, "{}", err_msg_row_cnt.as_str()); diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 1eb6ea632923..9307caf1c6ad 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -852,6 +852,54 @@ fn get_corrected_filter_mask( corrected_mask.extend(vec![Some(true); null_matched]); Some(corrected_mask.finish()) } + JoinType::Full => { + let mut mask: Vec> = vec![Some(true); row_indices_length]; + let mut last_true_idx = 0; + let mut first_row_idx = 0; + let mut seen_false = false; + + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + let val = filter_mask.value(i); + let is_null = filter_mask.is_null(i); + + if val { + if !seen_true { + last_true_idx = i; + } + seen_true = true; + } + + if is_null || val { + mask[i] = Some(true); + } else if !is_null && !val && (seen_true || seen_false) { + mask[i] = None; + } else { + mask[i] = Some(false); + } + + if !is_null && !val { + seen_false = true; + } + + if last_index { + if seen_true { + #[allow(clippy::needless_range_loop)] + for j in first_row_idx..last_true_idx { + mask[j] = None; + } + } + + seen_true = false; + seen_false = false; + last_true_idx = 0; + first_row_idx = i + 1; + } + } + + Some(BooleanArray::from(mask)) + } // Only outer joins needs to keep track of processed rows and apply corrected filter mask _ => None, } @@ -887,6 +935,7 @@ impl Stream for SMJStream { | JoinType::LeftMark | JoinType::Right | JoinType::LeftAnti + | JoinType::Full ) { self.freeze_all()?; @@ -969,6 +1018,7 @@ impl Stream for SMJStream { | JoinType::Right | JoinType::LeftAnti | JoinType::LeftMark + | JoinType::Full ) { continue; @@ -990,6 +1040,7 @@ impl Stream for SMJStream { | JoinType::LeftSemi | JoinType::Right | JoinType::LeftAnti + | JoinType::Full | JoinType::LeftMark ) { @@ -1171,9 +1222,10 @@ impl SMJStream { // If the head batch is fully processed, dequeue it and produce output of it. if head_batch.range.end == head_batch.num_rows { self.freeze_dequeuing_buffered()?; - if let Some(buffered_batch) = + if let Some(mut buffered_batch) = self.buffered_data.batches.pop_front() { + self.produce_buffered_not_matched(&mut buffered_batch)?; self.free_reservation(buffered_batch)?; } } else { @@ -1401,8 +1453,8 @@ impl SMJStream { } fn freeze_all(&mut self) -> Result<()> { + self.freeze_buffered(self.buffered_data.batches.len())?; self.freeze_streamed()?; - self.freeze_buffered(self.buffered_data.batches.len(), false)?; Ok(()) } @@ -1413,7 +1465,7 @@ impl SMJStream { fn freeze_dequeuing_buffered(&mut self) -> Result<()> { self.freeze_streamed()?; // Only freeze and produce the first batch in buffered_data as the batch is fully processed - self.freeze_buffered(1, true)?; + self.freeze_buffered(1)?; Ok(()) } @@ -1422,13 +1474,7 @@ impl SMJStream { // // Applicable only in case of Full join. // - // If `output_not_matched_filter` is true, this will also produce record batches - // for buffered rows which are joined with streamed side but don't match join filter. - fn freeze_buffered( - &mut self, - batch_count: usize, - output_not_matched_filter: bool, - ) -> Result<()> { + fn freeze_buffered(&mut self, batch_count: usize) -> Result<()> { if !matches!(self.join_type, JoinType::Full) { return Ok(()); } @@ -1442,34 +1488,66 @@ impl SMJStream { &buffered_indices, buffered_batch, )? { + let num_rows = record_batch.num_rows(); + self.output_record_batches + .filter_mask + .extend(&BooleanArray::from(vec![None; num_rows])); + self.output_record_batches + .row_indices + .extend(&UInt64Array::from(vec![None; num_rows])); + self.output_record_batches + .batch_ids + .extend(vec![0; num_rows]); + self.output_record_batches.batches.push(record_batch); } buffered_batch.null_joined.clear(); + } + Ok(()) + } - // For buffered row which is joined with streamed side rows but all joined rows - // don't satisfy the join filter - if output_not_matched_filter { - let not_matched_buffered_indices = buffered_batch - .join_filter_failed_map - .iter() - .filter_map(|(idx, failed)| if *failed { Some(*idx) } else { None }) - .collect::>(); + fn produce_buffered_not_matched( + &mut self, + buffered_batch: &mut BufferedBatch, + ) -> Result<()> { + if !matches!(self.join_type, JoinType::Full) { + return Ok(()); + } - let buffered_indices = UInt64Array::from_iter_values( - not_matched_buffered_indices.iter().copied(), - ); + // For buffered row which is joined with streamed side rows but all joined rows + // don't satisfy the join filter + let not_matched_buffered_indices = buffered_batch + .join_filter_failed_map + .iter() + .filter_map(|(idx, failed)| if *failed { Some(*idx) } else { None }) + .collect::>(); - if let Some(record_batch) = produce_buffered_null_batch( - &self.schema, - &self.streamed_schema, - &buffered_indices, - buffered_batch, - )? { - self.output_record_batches.batches.push(record_batch); - } - buffered_batch.join_filter_failed_map.clear(); - } + let buffered_indices = + UInt64Array::from_iter_values(not_matched_buffered_indices.iter().copied()); + + if let Some(record_batch) = produce_buffered_null_batch( + &self.schema, + &self.streamed_schema, + &buffered_indices, + buffered_batch, + )? { + //print_batches(&[record_batch.clone()]); + let num_rows = record_batch.num_rows(); + + self.output_record_batches + .filter_mask + .extend(&BooleanArray::from(vec![None; num_rows])); + self.output_record_batches + .row_indices + .extend(&UInt64Array::from(vec![None; num_rows])); + self.output_record_batches + .batch_ids + .extend(vec![0; num_rows]); + self.output_record_batches.batches.push(record_batch); } + //dbg!(&buffered_batch.join_filter_failed_map); + buffered_batch.join_filter_failed_map.clear(); + Ok(()) } @@ -1514,8 +1592,6 @@ impl SMJStream { ) }; - let streamed_columns_length = streamed_columns.len(); - // Prepare the columns we apply join filter on later. // Only for joined rows between streamed and buffered. let filter_columns = if chunk.buffered_batch_idx.is_some() { @@ -1587,6 +1663,7 @@ impl SMJStream { | JoinType::Right | JoinType::LeftAnti | JoinType::LeftMark + | JoinType::Full ) { self.output_record_batches .batches @@ -1596,7 +1673,11 @@ impl SMJStream { self.output_record_batches.batches.push(filtered_batch); } - self.output_record_batches.filter_mask.extend(&mask); + if !matches!(self.join_type, JoinType::Full) { + self.output_record_batches.filter_mask.extend(&mask); + } else { + self.output_record_batches.filter_mask.extend(pre_mask); + } self.output_record_batches .row_indices .extend(&streamed_indices); @@ -1610,83 +1691,26 @@ impl SMJStream { // I.e., if all rows joined from a streamed row are failed with the join filter, // we need to join it with nulls as buffered side. if matches!(self.join_type, JoinType::Full) { - // We need to get the mask for row indices that the joined rows are failed - // on the join filter. I.e., for a row in streamed side, if all joined rows - // between it and all buffered rows are failed on the join filter, we need to - // output it with null columns from buffered side. For the mask here, it - // behaves like LeftAnti join. - let not_mask = if mask.null_count() > 0 { - // If the mask contains nulls, we need to use `prep_null_mask_filter` to - // handle the nulls in the mask as false to produce rows where the mask - // was null itself. - compute::not(&compute::prep_null_mask_filter(&mask))? - } else { - compute::not(&mask)? - }; + let buffered_batch = &mut self.buffered_data.batches + [chunk.buffered_batch_idx.unwrap()]; - let null_joined_batch = - filter_record_batch(&output_batch, ¬_mask)?; - - let buffered_columns = self - .buffered_schema - .fields() - .iter() - .map(|f| { - new_null_array( - f.data_type(), - null_joined_batch.num_rows(), - ) - }) - .collect::>(); - - let columns = { - let mut streamed_columns = null_joined_batch - .columns() - .iter() - .take(streamed_columns_length) - .cloned() - .collect::>(); - - streamed_columns.extend(buffered_columns); - streamed_columns - }; - - // Push the streamed/buffered batch joined nulls to the output - let null_joined_streamed_batch = - RecordBatch::try_new(Arc::clone(&self.schema), columns)?; - - self.output_record_batches - .batches - .push(null_joined_streamed_batch); - - // For full join, we also need to output the null joined rows from the buffered side. - // Usually this is done by `freeze_buffered`. However, if a buffered row is joined with - // streamed side, it won't be outputted by `freeze_buffered`. - // We need to check if a buffered row is joined with streamed side and output. - // If it is joined with streamed side, but doesn't match the join filter, - // we need to output it with nulls as streamed side. - if matches!(self.join_type, JoinType::Full) { - let buffered_batch = &mut self.buffered_data.batches - [chunk.buffered_batch_idx.unwrap()]; - - for i in 0..pre_mask.len() { - // If the buffered row is not joined with streamed side, - // skip it. - if buffered_indices.is_null(i) { - continue; - } + for i in 0..pre_mask.len() { + // If the buffered row is not joined with streamed side, + // skip it. + if buffered_indices.is_null(i) { + continue; + } - let buffered_index = buffered_indices.value(i); + let buffered_index = buffered_indices.value(i); - buffered_batch.join_filter_failed_map.insert( - buffered_index, - *buffered_batch - .join_filter_failed_map - .get(&buffered_index) - .unwrap_or(&true) - && !pre_mask.value(i), - ); - } + buffered_batch.join_filter_failed_map.insert( + buffered_index, + *buffered_batch + .join_filter_failed_map + .get(&buffered_index) + .unwrap_or(&true) + && !pre_mask.value(i), + ); } } } else { @@ -1726,6 +1750,7 @@ impl SMJStream { | JoinType::Right | JoinType::LeftAnti | JoinType::LeftMark + | JoinType::Full )) { self.output_record_batches.batches.clear(); @@ -1735,12 +1760,28 @@ impl SMJStream { fn filter_joined_batch(&mut self) -> Result { let record_batch = self.output_record_batch_and_reset()?; - let out_indices = self.output_record_batches.row_indices.finish(); - let out_mask = self.output_record_batches.filter_mask.finish(); + let mut out_indices = self.output_record_batches.row_indices.finish(); + let mut out_mask = self.output_record_batches.filter_mask.finish(); + let mut batch_ids = &self.output_record_batches.batch_ids; + let default_batch_ids = vec![0; record_batch.num_rows()]; + + if out_indices.null_count() == out_indices.len() + && out_indices.len() != record_batch.num_rows() + { + out_mask = BooleanArray::from(vec![None; record_batch.num_rows()]); + out_indices = UInt64Array::from(vec![None; record_batch.num_rows()]); + batch_ids = &default_batch_ids; + } + + if out_mask.is_empty() { + self.output_record_batches.batches.clear(); + return Ok(record_batch); + } + let maybe_corrected_mask = get_corrected_filter_mask( self.join_type, &out_indices, - &self.output_record_batches.batch_ids, + batch_ids, &out_mask, record_batch.num_rows(), ); @@ -1753,8 +1794,8 @@ impl SMJStream { let mut filtered_record_batch = filter_record_batch(&record_batch, corrected_mask)?; - let buffered_columns_length = self.buffered_schema.fields.len(); - let streamed_columns_length = self.streamed_schema.fields.len(); + let left_columns_length = self.streamed_schema.fields.len(); + let right_columns_length = self.buffered_schema.fields.len(); if matches!( self.join_type, @@ -1773,18 +1814,17 @@ impl SMJStream { let streamed_columns = null_joined_batch .columns() .iter() - .skip(buffered_columns_length) + .skip(left_columns_length) .cloned() .collect::>(); buffered_columns.extend(streamed_columns); buffered_columns } else { - // Left join or full outer join let mut streamed_columns = null_joined_batch .columns() .iter() - .take(streamed_columns_length) + .take(right_columns_length) .cloned() .collect::>(); @@ -1801,15 +1841,75 @@ impl SMJStream { &[filtered_record_batch, null_joined_streamed_batch], )?; } else if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { - let output_column_indices = (0..streamed_columns_length).collect::>(); + let output_column_indices = (0..left_columns_length).collect::>(); filtered_record_batch = filtered_record_batch.project(&output_column_indices)?; + } else if matches!(self.join_type, JoinType::Full) + && corrected_mask.false_count() > 0 + { + // Find rows which joined by key but Filter predicate evaluated as false + let joined_filter_not_matched_mask = compute::not(corrected_mask)?; + let joined_filter_not_matched_batch = + filter_record_batch(&record_batch, &joined_filter_not_matched_mask)?; + + // Add left unmatched rows adding the right side as nulls + let right_null_columns = self + .buffered_schema + .fields() + .iter() + .map(|f| { + new_null_array( + f.data_type(), + joined_filter_not_matched_batch.num_rows(), + ) + }) + .collect::>(); + + let mut result_joined = joined_filter_not_matched_batch + .columns() + .iter() + .take(left_columns_length) + .cloned() + .collect::>(); + + result_joined.extend(right_null_columns); + + let left_null_joined_batch = + RecordBatch::try_new(Arc::clone(&self.schema), result_joined)?; + + // Add right unmatched rows adding the left side as nulls + let mut result_joined = self + .streamed_schema + .fields() + .iter() + .map(|f| { + new_null_array( + f.data_type(), + joined_filter_not_matched_batch.num_rows(), + ) + }) + .collect::>(); + + let right_data = joined_filter_not_matched_batch + .columns() + .iter() + .skip(left_columns_length) + .cloned() + .collect::>(); + + result_joined.extend(right_data); + + filtered_record_batch = concat_batches( + &self.schema, + &[filtered_record_batch, left_null_joined_batch], + )?; } self.output_record_batches.batches.clear(); - self.output_record_batches.batch_ids = vec![]; + self.output_record_batches.batch_ids.clear(); self.output_record_batches.filter_mask = BooleanBuilder::new(); self.output_record_batches.row_indices = UInt64Builder::new(); + Ok(filtered_record_batch) } } diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index f4cc888d6b8e..9a20e7987ff6 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -126,24 +126,21 @@ Alice 50 Alice 1 Alice 50 Alice 2 Bob 1 NULL NULL -# Uncomment when filtered FULL moved -# full join with join filter -#query TITI rowsort -#SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t2.b * 50 > t1.b -#---- -#Alice 100 NULL NULL -#Alice 50 Alice 2 -#Bob 1 NULL NULL -#NULL NULL Alice 1 - -# Uncomment when filtered FULL moved -#query TITI rowsort -#SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t1.b > t2.b + 50 -#---- -#Alice 100 Alice 1 -#Alice 100 Alice 2 -#Alice 50 NULL NULL -#Bob 1 NULL NULL +query TITI rowsort +SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t2.b * 50 > t1.b +---- +Alice 100 NULL NULL +Alice 50 Alice 2 +Bob 1 NULL NULL +NULL NULL Alice 1 + +query TITI rowsort +SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t1.b > t2.b + 50 +---- +Alice 100 Alice 1 +Alice 100 Alice 2 +Alice 50 NULL NULL +Bob 1 NULL NULL statement ok DROP TABLE t1;