From ca564ceefc5c1fb468266ae4da2bc27376b882d7 Mon Sep 17 00:00:00 2001 From: comphead Date: Mon, 29 Apr 2024 16:56:14 -0700 Subject: [PATCH 01/20] Fix: Sort Merge Join crashes on TPCH Q21 --- datafusion/physical-plan/src/joins/sort_merge_join.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index d4cf6864d7e4..f27aab2721b5 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -1161,6 +1161,12 @@ impl SMJStream { let filter_columns = if chunk.buffered_batch_idx.is_some() { if matches!(self.join_type, JoinType::Right) { get_filter_column(&self.filter, &buffered_columns, &streamed_columns) + } else if matches!(self.join_type, JoinType::LeftSemi) { + let buffered_columns = self.buffered_data.batches + [chunk.buffered_batch_idx.unwrap()] + .batch + .columns(); + get_filter_column(&self.filter, &streamed_columns, &buffered_columns) } else { get_filter_column(&self.filter, &streamed_columns, &buffered_columns) } @@ -1358,6 +1364,9 @@ fn get_filter_column( .map(|i| buffered_columns[i.index].clone()) .collect::>(); +// dbg!(&left_columns); +// dbg!(&right_columns); + filter_columns.extend(left_columns); filter_columns.extend(right_columns); } From fd21ccf8eb8afd4e499d55c73ad2d300af6a4dc0 Mon Sep 17 00:00:00 2001 From: comphead Date: Fri, 3 May 2024 17:11:32 -0700 Subject: [PATCH 02/20] Fix: Sort Merge Join crashes on TPCH Q21. LeftSemi --- .../src/joins/sort_merge_join.rs | 37 ++++++++++++++++++- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index f27aab2721b5..011d5ec26ddc 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -46,6 +46,7 @@ use arrow::array::*; use arrow::compute::{self, concat_batches, take, SortOptions}; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::ArrowError; +//use arrow_array::types::UInt64Type; use datafusion_common::{ internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result, }; @@ -818,12 +819,15 @@ impl SMJStream { self.streamed_state = StreamedState::Exhausted; } Poll::Ready(Some(batch)) => { + //dbg!(&batch); + //println!("{:#?}", &batch); if batch.num_rows() > 0 { self.freeze_streamed()?; self.join_metrics.input_batches.add(1); self.join_metrics.input_rows.add(batch.num_rows()); self.streamed_batch = StreamedBatch::new(batch, &self.on_streamed); + self.streamed_state = StreamedState::Ready; } } @@ -991,6 +995,7 @@ impl SMJStream { Ordering::Equal => { if matches!(self.join_type, JoinType::LeftSemi) { join_streamed = !self.streamed_joined; + join_buffered = true; } if matches!( self.join_type, @@ -1004,6 +1009,11 @@ impl SMJStream { if matches!(self.join_type, JoinType::Full) { join_buffered = !self.buffered_joined; }; + + if matches!(self.join_type, JoinType::LeftAnti) { + join_streamed = !self.streamed_joined; + join_buffered = true; + }; } } if !join_streamed && !join_buffered { @@ -1047,6 +1057,7 @@ impl SMJStream { Some(self.buffered_data.scanning_batch_idx) }; + //dbg!(self.buffered_data.scanning_idx()); self.streamed_batch .append_output_pair(scanning_batch_idx, None); self.output_size += 1; @@ -1118,8 +1129,12 @@ impl SMJStream { // Produces and stages record batch for all output indices found // for current streamed batch and clears staged output indices. fn freeze_streamed(&mut self) -> Result<()> { + //dbg!(&self.streamed_batch.batch); + for chunk in self.streamed_batch.output_indices.iter_mut() { let streamed_indices = chunk.streamed_indices.finish(); + dbg!(&streamed_indices); + //let streamed_indices = PrimitiveArray::::try_new(vec![0, 1].into(), None)?; if streamed_indices.is_empty() { continue; @@ -1134,6 +1149,8 @@ impl SMJStream { .collect::, ArrowError>>()?; let buffered_indices: UInt64Array = chunk.buffered_indices.finish(); + dbg!(&buffered_indices); + dbg!(&self.join_type); let mut buffered_columns = if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { @@ -1161,11 +1178,15 @@ impl SMJStream { let filter_columns = if chunk.buffered_batch_idx.is_some() { if matches!(self.join_type, JoinType::Right) { get_filter_column(&self.filter, &buffered_columns, &streamed_columns) - } else if matches!(self.join_type, JoinType::LeftSemi) { + } else if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { let buffered_columns = self.buffered_data.batches [chunk.buffered_batch_idx.unwrap()] .batch - .columns(); + .columns() + .iter() + .map(|column| take(column, &buffered_indices, None)) + .collect::, ArrowError>>()?; + //dbg!(&buffered_columns); get_filter_column(&self.filter, &streamed_columns, &buffered_columns) } else { get_filter_column(&self.filter, &streamed_columns, &buffered_columns) @@ -1175,6 +1196,10 @@ impl SMJStream { vec![] }; + dbg!(&streamed_columns); + dbg!(&buffered_columns); + dbg!(&filter_columns); + let columns = if matches!(self.join_type, JoinType::Right) { buffered_columns.extend(streamed_columns.clone()); buffered_columns @@ -1195,17 +1220,25 @@ impl SMJStream { filter_columns, )?; + dbg!(&filter_batch); + dbg!(&f.expression()); + + let filter_result = f .expression() .evaluate(&filter_batch)? .into_array(filter_batch.num_rows())?; + dbg!(&filter_result); // The selection mask of the filter let mask = datafusion_common::cast::as_boolean_array(&filter_result)?; // Push the filtered batch to the output let filtered_batch = compute::filter_record_batch(&output_batch, mask)?; + + dbg!(&filtered_batch); + self.output_record_batches.push(filtered_batch); // For outer joins, we need to push the null joined rows to the output. From 64d7e5c8614c6455dbb66a99f6c51ad57297b679 Mon Sep 17 00:00:00 2001 From: comphead Date: Sat, 4 May 2024 09:21:06 -0700 Subject: [PATCH 03/20] Fix: Sort Merge Join Left Semi crashes --- .../src/joins/sort_merge_join.rs | 35 ++---- .../test_files/sort_merge_join.slt | 117 ++++++++++++++++++ 2 files changed, 124 insertions(+), 28 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 011d5ec26ddc..20fa40569fd6 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -819,8 +819,6 @@ impl SMJStream { self.streamed_state = StreamedState::Exhausted; } Poll::Ready(Some(batch)) => { - //dbg!(&batch); - //println!("{:#?}", &batch); if batch.num_rows() > 0 { self.freeze_streamed()?; self.join_metrics.input_batches.add(1); @@ -1057,7 +1055,6 @@ impl SMJStream { Some(self.buffered_data.scanning_batch_idx) }; - //dbg!(self.buffered_data.scanning_idx()); self.streamed_batch .append_output_pair(scanning_batch_idx, None); self.output_size += 1; @@ -1129,12 +1126,8 @@ impl SMJStream { // Produces and stages record batch for all output indices found // for current streamed batch and clears staged output indices. fn freeze_streamed(&mut self) -> Result<()> { - //dbg!(&self.streamed_batch.batch); - for chunk in self.streamed_batch.output_indices.iter_mut() { let streamed_indices = chunk.streamed_indices.finish(); - dbg!(&streamed_indices); - //let streamed_indices = PrimitiveArray::::try_new(vec![0, 1].into(), None)?; if streamed_indices.is_empty() { continue; @@ -1149,8 +1142,6 @@ impl SMJStream { .collect::, ArrowError>>()?; let buffered_indices: UInt64Array = chunk.buffered_indices.finish(); - dbg!(&buffered_indices); - dbg!(&self.join_type); let mut buffered_columns = if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { @@ -1178,15 +1169,17 @@ impl SMJStream { let filter_columns = if chunk.buffered_batch_idx.is_some() { if matches!(self.join_type, JoinType::Right) { get_filter_column(&self.filter, &buffered_columns, &streamed_columns) - } else if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { + } else if matches!( + self.join_type, + JoinType::LeftSemi | JoinType::LeftAnti + ) { let buffered_columns = self.buffered_data.batches [chunk.buffered_batch_idx.unwrap()] .batch .columns() - .iter() - .map(|column| take(column, &buffered_indices, None)) - .collect::, ArrowError>>()?; - //dbg!(&buffered_columns); + .iter() + .map(|column| take(column, &buffered_indices, None)) + .collect::, ArrowError>>()?; get_filter_column(&self.filter, &streamed_columns, &buffered_columns) } else { get_filter_column(&self.filter, &streamed_columns, &buffered_columns) @@ -1196,10 +1189,6 @@ impl SMJStream { vec![] }; - dbg!(&streamed_columns); - dbg!(&buffered_columns); - dbg!(&filter_columns); - let columns = if matches!(self.join_type, JoinType::Right) { buffered_columns.extend(streamed_columns.clone()); buffered_columns @@ -1220,16 +1209,11 @@ impl SMJStream { filter_columns, )?; - dbg!(&filter_batch); - dbg!(&f.expression()); - - let filter_result = f .expression() .evaluate(&filter_batch)? .into_array(filter_batch.num_rows())?; - dbg!(&filter_result); // The selection mask of the filter let mask = datafusion_common::cast::as_boolean_array(&filter_result)?; @@ -1237,8 +1221,6 @@ impl SMJStream { let filtered_batch = compute::filter_record_batch(&output_batch, mask)?; - dbg!(&filtered_batch); - self.output_record_batches.push(filtered_batch); // For outer joins, we need to push the null joined rows to the output. @@ -1397,9 +1379,6 @@ fn get_filter_column( .map(|i| buffered_columns[i.index].clone()) .collect::>(); -// dbg!(&left_columns); -// dbg!(&right_columns); - filter_columns.extend(left_columns); filter_columns.extend(right_columns); } diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index 7b7e355fa2b5..2de029350434 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -280,5 +280,122 @@ SELECT * FROM ( ) ORDER BY 1, 2; ---- + +# LEFTSEMI join tests + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b = t1.b) +) order by 1, 2 +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b) +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select null a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b) +) order by 1, 2; +---- +11 13 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b) + select t1.* from t1 where exists (select 1 from t1 t2 where t2.a = t1.a and t2.b = t1.b) +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b) + select t1.* from t1 where exists (select 1 from t1 t2 where t2.a = t1.a and t2.b != t1.b) +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select null a, 12 b union all + select 11 a, 13 b) + select t1.* from t1 where exists (select 1 from t1 t2 where t2.a = t1.a and t2.b != t1.b) +) order by 1, 2; +---- + +# This gives a wrong result for now +#query II +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b union all +# select 11 a, 13 b), +#t2 as ( +# select 11 a, 12 b union all +# select 11 a, 14 b +# ) +#select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b) +#) order by 1, 2; +#---- + +#LEFTANTI tests +# returns no rows instead of correct result +#query III +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b, 1 c union all +# select 11 a, 13 b, 2 c), +#t2 as ( +# select 11 a, 12 b, 3 c union all +# select 11 a, 14 b, 4 c +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +#) order by 1, 2; +---- +11 12 1 +11 13 2 + statement ok set datafusion.optimizer.prefer_hash_join = true; From 8c6010ec37c73aabe179a96bc6bc124c44712cc3 Mon Sep 17 00:00:00 2001 From: comphead Date: Sat, 4 May 2024 09:53:59 -0700 Subject: [PATCH 04/20] Fix: Sort Merge Join Left Semi crashes --- datafusion/sqllogictest/test_files/sort_merge_join.slt | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index 2de029350434..eb3ba70d29de 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -362,7 +362,6 @@ t1 as ( select 11 a, 13 b) select t1.* from t1 where exists (select 1 from t1 t2 where t2.a = t1.a and t2.b != t1.b) ) order by 1, 2; ----- # This gives a wrong result for now #query II @@ -378,6 +377,8 @@ t1 as ( #select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b) #) order by 1, 2; #---- +#11 12 +#11 13 #LEFTANTI tests # returns no rows instead of correct result @@ -393,9 +394,9 @@ t1 as ( # ) #select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) #) order by 1, 2; ----- -11 12 1 -11 13 2 +#---- +#11 12 1 +#11 13 2 statement ok set datafusion.optimizer.prefer_hash_join = true; From f9e11332d31c1ef0817e022a92354a3a1bf5c5f7 Mon Sep 17 00:00:00 2001 From: comphead Date: Mon, 6 May 2024 21:17:17 -0700 Subject: [PATCH 05/20] Fix: Sort Merge Join Left Semi crashes --- .../src/joins/sort_merge_join.rs | 37 +++++++++++++++++-- .../test_files/sort_merge_join.slt | 30 +++++++-------- 2 files changed, 49 insertions(+), 18 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 20fa40569fd6..79bb2db75635 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -46,7 +46,6 @@ use arrow::array::*; use arrow::compute::{self, concat_batches, take, SortOptions}; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::ArrowError; -//use arrow_array::types::UInt64Type; use datafusion_common::{ internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result, }; @@ -1142,7 +1141,6 @@ impl SMJStream { .collect::, ArrowError>>()?; let buffered_indices: UInt64Array = chunk.buffered_indices.finish(); - let mut buffered_columns = if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { vec![] @@ -1215,7 +1213,40 @@ impl SMJStream { .into_array(filter_batch.num_rows())?; // The selection mask of the filter - let mask = datafusion_common::cast::as_boolean_array(&filter_result)?; + let mut mask = + datafusion_common::cast::as_boolean_array(&filter_result)?; + // for LeftSemi Join the filter mask should be calculated in its own way: + // if we find at least one matching row for specific streaming key/filter we dont need to check others for the same key/filter + let mut maybe_left_semi_mask: Option = None; + if matches!(self.join_type, JoinType::LeftSemi) { + // did we get a filter match for a streaming index + let mut seen_as_true: bool = false; + let streamed_indices_length = streamed_indices.len(); + let mut corrected_mask: Vec = + vec![false; streamed_indices_length]; + + #[allow(clippy::needless_range_loop)] + for i in 0..streamed_indices_length { + // if for a streaming index its a match first time, set it as true + if mask.value(i) && !seen_as_true { + seen_as_true = true; + corrected_mask[i] = true; + } + + // if switched to next streaming index(e.g from 0 to 1, or from 1 to 2), we reset seen_as_true flag + if i < streamed_indices_length - 1 + && streamed_indices.value(i) + != streamed_indices.value(i + 1) + { + seen_as_true = false; + } + } + maybe_left_semi_mask = Some(BooleanArray::from(corrected_mask)) + }; + + if let Some(ref left_semi_mask) = maybe_left_semi_mask { + mask = left_semi_mask; + } // Push the filtered batch to the output let filtered_batch = diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index eb3ba70d29de..81b7561beaa4 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -364,21 +364,21 @@ t1 as ( ) order by 1, 2; # This gives a wrong result for now -#query II -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b union all -# select 11 a, 13 b), -#t2 as ( -# select 11 a, 12 b union all -# select 11 a, 14 b -# ) -#select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b) -#) order by 1, 2; -#---- -#11 12 -#11 13 +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 14 b + ) +select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b) +) order by 1, 2; +---- +11 12 +11 13 #LEFTANTI tests # returns no rows instead of correct result From ed0035b66cf98ce0cfd73d154fe6a417fae3004e Mon Sep 17 00:00:00 2001 From: comphead Date: Tue, 7 May 2024 07:55:01 -0700 Subject: [PATCH 06/20] Fix: Sort Merge Join Left Semi crashes --- datafusion/physical-plan/src/joins/sort_merge_join.rs | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 79bb2db75635..ea03887d11b6 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -1006,11 +1006,6 @@ impl SMJStream { if matches!(self.join_type, JoinType::Full) { join_buffered = !self.buffered_joined; }; - - if matches!(self.join_type, JoinType::LeftAnti) { - join_streamed = !self.streamed_joined; - join_buffered = true; - }; } } if !join_streamed && !join_buffered { @@ -1167,10 +1162,7 @@ impl SMJStream { let filter_columns = if chunk.buffered_batch_idx.is_some() { if matches!(self.join_type, JoinType::Right) { get_filter_column(&self.filter, &buffered_columns, &streamed_columns) - } else if matches!( - self.join_type, - JoinType::LeftSemi | JoinType::LeftAnti - ) { + } else if matches!(self.join_type, JoinType::LeftSemi) { let buffered_columns = self.buffered_data.batches [chunk.buffered_batch_idx.unwrap()] .batch From 4c2c8f3f7c7ef8700a68a5bca7cd6214ec5b6d83 Mon Sep 17 00:00:00 2001 From: comphead Date: Tue, 7 May 2024 09:11:23 -0700 Subject: [PATCH 07/20] Fix: Sort Merge Join Left Semi crashes --- datafusion/physical-plan/src/joins/sort_merge_join.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index ea03887d11b6..1777931721b7 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -992,7 +992,9 @@ impl SMJStream { Ordering::Equal => { if matches!(self.join_type, JoinType::LeftSemi) { join_streamed = !self.streamed_joined; - join_buffered = true; + // if the join filter specified there can be references to buffered columns + // so its needed to join them + join_buffered = self.filter.is_some(); } if matches!( self.join_type, From 4052b0d5d836d608724748d700c973bdcd6dbfcf Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 8 May 2024 16:34:49 -0700 Subject: [PATCH 08/20] Fix: Sort Merge Join Left Semi crashes. Comments --- .../src/joins/sort_merge_join.rs | 44 ++++++++++++------- .../test_files/sort_merge_join.slt | 18 +++++++- 2 files changed, 45 insertions(+), 17 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 1777931721b7..659de71107ba 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -824,7 +824,6 @@ impl SMJStream { self.join_metrics.input_rows.add(batch.num_rows()); self.streamed_batch = StreamedBatch::new(batch, &self.on_streamed); - self.streamed_state = StreamedState::Ready; } } @@ -993,7 +992,7 @@ impl SMJStream { if matches!(self.join_type, JoinType::LeftSemi) { join_streamed = !self.streamed_joined; // if the join filter specified there can be references to buffered columns - // so its needed to join them + // so buffered columns are needed to access them join_buffered = self.filter.is_some(); } if matches!( @@ -1142,12 +1141,11 @@ impl SMJStream { if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { vec![] } else if let Some(buffered_idx) = chunk.buffered_batch_idx { - self.buffered_data.batches[buffered_idx] - .batch - .columns() - .iter() - .map(|column| take(column, &buffered_indices, None)) - .collect::, ArrowError>>()? + get_buffered_columns( + &self.buffered_data, + buffered_idx, + &buffered_indices, + )? } else { self.buffered_schema .fields() @@ -1165,13 +1163,13 @@ impl SMJStream { if matches!(self.join_type, JoinType::Right) { get_filter_column(&self.filter, &buffered_columns, &streamed_columns) } else if matches!(self.join_type, JoinType::LeftSemi) { - let buffered_columns = self.buffered_data.batches - [chunk.buffered_batch_idx.unwrap()] - .batch - .columns() - .iter() - .map(|column| take(column, &buffered_indices, None)) - .collect::, ArrowError>>()?; + // unwrap is safe here as we check is_some on top of if statement + let buffered_columns = get_buffered_columns( + &self.buffered_data, + chunk.buffered_batch_idx.unwrap(), + &buffered_indices, + )?; + get_filter_column(&self.filter, &streamed_columns, &buffered_columns) } else { get_filter_column(&self.filter, &streamed_columns, &buffered_columns) @@ -1245,7 +1243,6 @@ impl SMJStream { // Push the filtered batch to the output let filtered_batch = compute::filter_record_batch(&output_batch, mask)?; - self.output_record_batches.push(filtered_batch); // For outer joins, we need to push the null joined rows to the output. @@ -1411,6 +1408,21 @@ fn get_filter_column( filter_columns } +// Get buffered data sliece by specific batch index and for specified column indices only +#[inline(always)] +fn get_buffered_columns( + buffered_data: &BufferedData, + buffered_batch_idx: usize, + buffered_indices: &UInt64Array, +) -> Result, ArrowError> { + buffered_data.batches[buffered_batch_idx] + .batch + .columns() + .iter() + .map(|column| take(column, &buffered_indices, None)) + .collect::, ArrowError>>() +} + /// Buffered data contains all buffered batches with one unique join key #[derive(Debug, Default)] struct BufferedData { diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index 81b7561beaa4..97ebb5117767 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -363,7 +363,6 @@ t1 as ( select t1.* from t1 where exists (select 1 from t1 t2 where t2.a = t1.a and t2.b != t1.b) ) order by 1, 2; -# This gives a wrong result for now query II select * from ( with @@ -380,6 +379,23 @@ select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != 11 12 11 13 +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 12 b union all + select 11 a, 14 b + ) +select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b) +) order by 1, 2; +---- +11 12 +11 13 + #LEFTANTI tests # returns no rows instead of correct result #query III From 9da2c457981ea8f6c8da421bd5a789c183c54e81 Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 8 May 2024 18:18:07 -0700 Subject: [PATCH 09/20] Fix: Sort Merge Join Left Semi crashes. Add tests --- .../src/joins/sort_merge_join.rs | 137 ++++++++++++++---- 1 file changed, 107 insertions(+), 30 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 659de71107ba..e32bed7cef1c 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -1207,37 +1207,11 @@ impl SMJStream { // The selection mask of the filter let mut mask = datafusion_common::cast::as_boolean_array(&filter_result)?; - // for LeftSemi Join the filter mask should be calculated in its own way: - // if we find at least one matching row for specific streaming key/filter we dont need to check others for the same key/filter - let mut maybe_left_semi_mask: Option = None; - if matches!(self.join_type, JoinType::LeftSemi) { - // did we get a filter match for a streaming index - let mut seen_as_true: bool = false; - let streamed_indices_length = streamed_indices.len(); - let mut corrected_mask: Vec = - vec![false; streamed_indices_length]; - - #[allow(clippy::needless_range_loop)] - for i in 0..streamed_indices_length { - // if for a streaming index its a match first time, set it as true - if mask.value(i) && !seen_as_true { - seen_as_true = true; - corrected_mask[i] = true; - } - - // if switched to next streaming index(e.g from 0 to 1, or from 1 to 2), we reset seen_as_true flag - if i < streamed_indices_length - 1 - && streamed_indices.value(i) - != streamed_indices.value(i + 1) - { - seen_as_true = false; - } - } - maybe_left_semi_mask = Some(BooleanArray::from(corrected_mask)) - }; - if let Some(ref left_semi_mask) = maybe_left_semi_mask { - mask = left_semi_mask; + let maybe_filtered_join_mask: Option = + get_filtered_join_mask(self.join_type, streamed_indices, mask); + if let Some(ref filtered_join_mask) = maybe_filtered_join_mask { + mask = filtered_join_mask; } // Push the filtered batch to the output @@ -1423,6 +1397,42 @@ fn get_buffered_columns( .collect::, ArrowError>>() } +// Calculate join filter bit mask considering join type specifics +fn get_filtered_join_mask( + join_type: JoinType, + streamed_indices: UInt64Array, + mask: &BooleanArray, +) -> Option { + // for LeftSemi Join the filter mask should be calculated in its own way: + // if we find at least one matching row for specific streaming index + // we dont need to check any others for the same index + if matches!(join_type, JoinType::LeftSemi) { + // have we seen a filter match for a streaming index before + let mut seen_as_true: bool = false; + let streamed_indices_length = streamed_indices.len(); + let mut corrected_mask: Vec = vec![false; streamed_indices_length]; + + #[allow(clippy::needless_range_loop)] + for i in 0..streamed_indices_length { + // if for a streaming index its a match first time, set it as true + if mask.value(i) && !seen_as_true { + seen_as_true = true; + corrected_mask[i] = true; + } + + // if switched to next streaming index(e.g from 0 to 1, or from 1 to 2), we reset seen_as_true flag + if i < streamed_indices_length - 1 + && streamed_indices.value(i) != streamed_indices.value(i + 1) + { + seen_as_true = false; + } + } + Some(BooleanArray::from(corrected_mask)) + } else { + None + } +} + /// Buffered data contains all buffered batches with one unique join key #[derive(Debug, Default)] struct BufferedData { @@ -1669,10 +1679,13 @@ mod tests { use crate::test::build_table_i32; use crate::{common, ExecutionPlan}; + use crate::joins::sort_merge_join::get_filtered_join_mask; use arrow::array::{Date32Array, Date64Array, Int32Array}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; + use arrow_array::{BooleanArray, UInt64Array}; + use datafusion_common::JoinType::LeftSemi; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result, }; @@ -2699,6 +2712,70 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn left_semi_join_filtered_mask() -> Result<()> { + assert_eq!( + get_filtered_join_mask( + LeftSemi, + UInt64Array::from(vec![0, 0, 1, 1]), + &BooleanArray::from(vec![true, true, false, false]) + ), + Some(BooleanArray::from(vec![true, false, false, false])) + ); + + assert_eq!( + get_filtered_join_mask( + LeftSemi, + UInt64Array::from(vec![0, 1]), + &BooleanArray::from(vec![true, true]) + ), + Some(BooleanArray::from(vec![true, true])) + ); + + assert_eq!( + get_filtered_join_mask( + LeftSemi, + UInt64Array::from(vec![0, 1]), + &BooleanArray::from(vec![false, true]) + ), + Some(BooleanArray::from(vec![false, true])) + ); + + assert_eq!( + get_filtered_join_mask( + LeftSemi, + UInt64Array::from(vec![0, 1]), + &BooleanArray::from(vec![true, false]) + ), + Some(BooleanArray::from(vec![true, false])) + ); + + assert_eq!( + get_filtered_join_mask( + LeftSemi, + UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), + &BooleanArray::from(vec![false, true, true, true, true, true]) + ), + Some(BooleanArray::from(vec![ + false, true, false, true, false, false + ])) + ); + + assert_eq!( + get_filtered_join_mask( + LeftSemi, + UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), + &BooleanArray::from(vec![false, false, false, false, false, true]) + ), + Some(BooleanArray::from(vec![ + false, false, false, false, false, true + ])) + ); + + Ok(()) + } + /// Returns the column names on the schema fn columns(schema: &Schema) -> Vec { schema.fields().iter().map(|f| f.name().clone()).collect() From 9c71eefe8aaad21e995844bd660bcbe0018ca408 Mon Sep 17 00:00:00 2001 From: comphead Date: Mon, 13 May 2024 16:57:20 -0700 Subject: [PATCH 10/20] Fix: Sort Merge Join Left Semi crashes. Add tests --- .../src/joins/sort_merge_join.rs | 64 +++++++++++-------- 1 file changed, 37 insertions(+), 27 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index e32bed7cef1c..22303fe7099b 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -30,22 +30,12 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use crate::expressions::PhysicalSortExpr; -use crate::joins::utils::{ - build_join_schema, check_join_is_valid, estimate_join_statistics, - partitioned_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef, -}; -use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; -use crate::{ - execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution, - ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, - RecordBatchStream, SendableRecordBatchStream, Statistics, -}; - use arrow::array::*; use arrow::compute::{self, concat_batches, take, SortOptions}; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::ArrowError; +use futures::{Stream, StreamExt}; + use datafusion_common::{ internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result, }; @@ -54,7 +44,17 @@ use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::join_equivalence_properties; use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement}; -use futures::{Stream, StreamExt}; +use crate::expressions::PhysicalSortExpr; +use crate::joins::utils::{ + build_join_schema, check_join_is_valid, estimate_join_statistics, + partitioned_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef, +}; +use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; +use crate::{ + execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution, + ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, + RecordBatchStream, SendableRecordBatchStream, Statistics, +}; /// join execution plan executes partitions in parallel and combines them into a set of /// partitions. @@ -1398,6 +1398,11 @@ fn get_buffered_columns( } // Calculate join filter bit mask considering join type specifics +// `streamed_indices` - array of streamed datasource JOINED row indices +// `mask` - array booleans representing computed join filter expression eval result: +// true = the row index matches the join filter +// false = the row index doesn't match the join filter +// `streaned_indices` have the same length as `mask` fn get_filtered_join_mask( join_type: JoinType, streamed_indices: UInt64Array, @@ -1405,29 +1410,33 @@ fn get_filtered_join_mask( ) -> Option { // for LeftSemi Join the filter mask should be calculated in its own way: // if we find at least one matching row for specific streaming index - // we dont need to check any others for the same index + // we don't need to check any others for the same index if matches!(join_type, JoinType::LeftSemi) { // have we seen a filter match for a streaming index before let mut seen_as_true: bool = false; let streamed_indices_length = streamed_indices.len(); - let mut corrected_mask: Vec = vec![false; streamed_indices_length]; + let mut corrected_mask: BooleanBuilder = + BooleanBuilder::with_capacity(streamed_indices_length); #[allow(clippy::needless_range_loop)] for i in 0..streamed_indices_length { - // if for a streaming index its a match first time, set it as true + // LeftSemi respects only first true values for specific streaming index, + // others true values for the same index must be false if mask.value(i) && !seen_as_true { seen_as_true = true; - corrected_mask[i] = true; + corrected_mask.append_value(true); + } else { + corrected_mask.append_value(false); } - // if switched to next streaming index(e.g from 0 to 1, or from 1 to 2), we reset seen_as_true flag + // if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag if i < streamed_indices_length - 1 && streamed_indices.value(i) != streamed_indices.value(i + 1) { seen_as_true = false; } } - Some(BooleanArray::from(corrected_mask)) + Some(corrected_mask.finish()) } else { None } @@ -1672,19 +1681,12 @@ fn is_join_arrays_equal( mod tests { use std::sync::Arc; - use crate::expressions::Column; - use crate::joins::utils::JoinOn; - use crate::joins::SortMergeJoinExec; - use crate::memory::MemoryExec; - use crate::test::build_table_i32; - use crate::{common, ExecutionPlan}; - - use crate::joins::sort_merge_join::get_filtered_join_mask; use arrow::array::{Date32Array, Date64Array, Int32Array}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use arrow_array::{BooleanArray, UInt64Array}; + use datafusion_common::JoinType::LeftSemi; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result, @@ -1693,6 +1695,14 @@ mod tests { use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_execution::TaskContext; + use crate::expressions::Column; + use crate::joins::sort_merge_join::get_filtered_join_mask; + use crate::joins::utils::JoinOn; + use crate::joins::SortMergeJoinExec; + use crate::memory::MemoryExec; + use crate::test::build_table_i32; + use crate::{common, ExecutionPlan}; + fn build_table( a: (&str, &Vec), b: (&str, &Vec), From fe0bb60a2141d052e30095016a0984c3282ddb24 Mon Sep 17 00:00:00 2001 From: Oleks V Date: Mon, 13 May 2024 13:31:42 -0700 Subject: [PATCH 11/20] Update datafusion/physical-plan/src/joins/sort_merge_join.rs Co-authored-by: Andrew Lamb --- datafusion/physical-plan/src/joins/sort_merge_join.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 22303fe7099b..122d9ec2f621 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -1382,7 +1382,7 @@ fn get_filter_column( filter_columns } -// Get buffered data sliece by specific batch index and for specified column indices only +/// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]` #[inline(always)] fn get_buffered_columns( buffered_data: &BufferedData, From c0fd73eb8575e233cb98c637a37321e54a89ca74 Mon Sep 17 00:00:00 2001 From: comphead Date: Thu, 16 May 2024 17:44:48 -0700 Subject: [PATCH 12/20] Fix: Sort Merge Join Left Semi crashes. Add tests --- .../src/joins/sort_merge_join.rs | 52 ++++++++++++------ .../test_files/sort_merge_join.slt | 53 ++++++++++++------- 2 files changed, 71 insertions(+), 34 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 122d9ec2f621..d4f6d2bf4900 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -35,6 +35,7 @@ use arrow::compute::{self, concat_batches, take, SortOptions}; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::ArrowError; use futures::{Stream, StreamExt}; +use hashbrown::HashSet; use datafusion_common::{ internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result, @@ -491,6 +492,10 @@ struct StreamedBatch { pub output_indices: Vec, /// Index of currently scanned batch from buffered data pub buffered_batch_idx: Option, + /// Indices that found a match for the given join filter + /// Used for semi joins to keep track the streaming index which got a join filter match + /// and already emitted to the output. + pub join_filter_matched_idxs: HashSet, } impl StreamedBatch { @@ -502,6 +507,7 @@ impl StreamedBatch { join_arrays, output_indices: vec![], buffered_batch_idx: None, + join_filter_matched_idxs: HashSet::new(), } } @@ -512,6 +518,7 @@ impl StreamedBatch { join_arrays: vec![], output_indices: vec![], buffered_batch_idx: None, + join_filter_matched_idxs: HashSet::new(), } } @@ -989,11 +996,18 @@ impl SMJStream { } } Ordering::Equal => { - if matches!(self.join_type, JoinType::LeftSemi) { - join_streamed = !self.streamed_joined; + if matches!(self.join_type, JoinType::LeftSemi) && self.filter.is_some() { + join_streamed = !self + .streamed_batch + .join_filter_matched_idxs + .contains(&(self.streamed_batch.idx as u64)) + && !self.streamed_joined; // if the join filter specified there can be references to buffered columns // so buffered columns are needed to access them - join_buffered = self.filter.is_some(); + join_buffered = join_streamed; + } + if matches!(self.join_type, JoinType::LeftSemi) && self.filter.is_none() { + join_streamed = !self.streamed_joined; } if matches!( self.join_type, @@ -1208,10 +1222,13 @@ impl SMJStream { let mut mask = datafusion_common::cast::as_boolean_array(&filter_result)?; - let maybe_filtered_join_mask: Option = + let maybe_filtered_join_mask: Option<(BooleanArray, Vec)> = get_filtered_join_mask(self.join_type, streamed_indices, mask); if let Some(ref filtered_join_mask) = maybe_filtered_join_mask { - mask = filtered_join_mask; + mask = &filtered_join_mask.0; + self.streamed_batch + .join_filter_matched_idxs + .extend(&filtered_join_mask.1); } // Push the filtered batch to the output @@ -1402,12 +1419,12 @@ fn get_buffered_columns( // `mask` - array booleans representing computed join filter expression eval result: // true = the row index matches the join filter // false = the row index doesn't match the join filter -// `streaned_indices` have the same length as `mask` +// `streamed_indices` have the same length as `mask` fn get_filtered_join_mask( join_type: JoinType, streamed_indices: UInt64Array, mask: &BooleanArray, -) -> Option { +) -> Option<(BooleanArray, Vec)> { // for LeftSemi Join the filter mask should be calculated in its own way: // if we find at least one matching row for specific streaming index // we don't need to check any others for the same index @@ -1418,6 +1435,8 @@ fn get_filtered_join_mask( let mut corrected_mask: BooleanBuilder = BooleanBuilder::with_capacity(streamed_indices_length); + let mut filter_matched_indices: Vec = vec![]; + #[allow(clippy::needless_range_loop)] for i in 0..streamed_indices_length { // LeftSemi respects only first true values for specific streaming index, @@ -1425,6 +1444,7 @@ fn get_filtered_join_mask( if mask.value(i) && !seen_as_true { seen_as_true = true; corrected_mask.append_value(true); + filter_matched_indices.push(streamed_indices.value(i)); } else { corrected_mask.append_value(false); } @@ -1436,7 +1456,7 @@ fn get_filtered_join_mask( seen_as_true = false; } } - Some(corrected_mask.finish()) + Some((corrected_mask.finish(), filter_matched_indices)) } else { None } @@ -2731,7 +2751,7 @@ mod tests { UInt64Array::from(vec![0, 0, 1, 1]), &BooleanArray::from(vec![true, true, false, false]) ), - Some(BooleanArray::from(vec![true, false, false, false])) + Some((BooleanArray::from(vec![true, false, false, false]), vec![0])) ); assert_eq!( @@ -2740,7 +2760,7 @@ mod tests { UInt64Array::from(vec![0, 1]), &BooleanArray::from(vec![true, true]) ), - Some(BooleanArray::from(vec![true, true])) + Some((BooleanArray::from(vec![true, true]), vec![0, 1])) ); assert_eq!( @@ -2749,7 +2769,7 @@ mod tests { UInt64Array::from(vec![0, 1]), &BooleanArray::from(vec![false, true]) ), - Some(BooleanArray::from(vec![false, true])) + Some((BooleanArray::from(vec![false, true]), vec![1])) ); assert_eq!( @@ -2758,7 +2778,7 @@ mod tests { UInt64Array::from(vec![0, 1]), &BooleanArray::from(vec![true, false]) ), - Some(BooleanArray::from(vec![true, false])) + Some((BooleanArray::from(vec![true, false]), vec![0])) ); assert_eq!( @@ -2767,9 +2787,9 @@ mod tests { UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), &BooleanArray::from(vec![false, true, true, true, true, true]) ), - Some(BooleanArray::from(vec![ + Some((BooleanArray::from(vec![ false, true, false, true, false, false - ])) + ]), vec![0, 1])) ); assert_eq!( @@ -2778,9 +2798,9 @@ mod tests { UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), &BooleanArray::from(vec![false, false, false, false, false, true]) ), - Some(BooleanArray::from(vec![ + Some((BooleanArray::from(vec![ false, false, false, false, false, true - ])) + ]), vec![1])) ); Ok(()) diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index 97ebb5117767..3a27d9693d00 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -263,24 +263,6 @@ DROP TABLE t1; statement ok DROP TABLE t2; -# Set batch size to 1 for sort merge join to test scenario when data spread across multiple batches -statement ok -set datafusion.execution.batch_size = 1; - -query II -SELECT * FROM ( - WITH - t1 AS ( - SELECT 12 a, 12 b - ), - t2 AS ( - SELECT 12 a, 12 b - ) - SELECT t1.* FROM t1 JOIN t2 on t1.a = t2.b WHERE t1.a > t2.b -) ORDER BY 1, 2; ----- - - # LEFTSEMI join tests query II @@ -414,5 +396,40 @@ select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != #11 12 1 #11 13 2 +# Set batch size to 1 for sort merge join to test scenario when data spread across multiple batches +statement ok +set datafusion.execution.batch_size = 1; + +query II +SELECT * FROM ( + WITH + t1 AS ( + SELECT 12 a, 12 b + ), + t2 AS ( + SELECT 12 a, 12 b + ) + SELECT t1.* FROM t1 JOIN t2 on t1.a = t2.b WHERE t1.a > t2.b +) ORDER BY 1, 2; +---- + + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 12 b union all + select 11 a, 14 b + ) +select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b) +) order by 1, 2; +---- +11 12 +11 13 + statement ok set datafusion.optimizer.prefer_hash_join = true; From f993b3c8c44d0ddd9a4456dd39a928d4ddf096d6 Mon Sep 17 00:00:00 2001 From: comphead Date: Mon, 29 Apr 2024 16:56:14 -0700 Subject: [PATCH 13/20] Fix: Sort Merge Join crashes on TPCH Q21 --- datafusion/physical-plan/src/joins/sort_merge_join.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index d4f6d2bf4900..563edd78ffc4 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -1392,6 +1392,9 @@ fn get_filter_column( .map(|i| buffered_columns[i.index].clone()) .collect::>(); +// dbg!(&left_columns); +// dbg!(&right_columns); + filter_columns.extend(left_columns); filter_columns.extend(right_columns); } From 1354f837f2926b5620d35e796559cc807a94aaec Mon Sep 17 00:00:00 2001 From: comphead Date: Fri, 3 May 2024 17:11:32 -0700 Subject: [PATCH 14/20] Fix: Sort Merge Join crashes on TPCH Q21. LeftSemi --- .../src/joins/sort_merge_join.rs | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 563edd78ffc4..9885b290dc4c 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -825,12 +825,15 @@ impl SMJStream { self.streamed_state = StreamedState::Exhausted; } Poll::Ready(Some(batch)) => { + //dbg!(&batch); + //println!("{:#?}", &batch); if batch.num_rows() > 0 { self.freeze_streamed()?; self.join_metrics.input_batches.add(1); self.join_metrics.input_rows.add(batch.num_rows()); self.streamed_batch = StreamedBatch::new(batch, &self.on_streamed); + self.streamed_state = StreamedState::Ready; } } @@ -1008,6 +1011,7 @@ impl SMJStream { } if matches!(self.join_type, JoinType::LeftSemi) && self.filter.is_none() { join_streamed = !self.streamed_joined; + join_buffered = true; } if matches!( self.join_type, @@ -1021,6 +1025,11 @@ impl SMJStream { if matches!(self.join_type, JoinType::Full) { join_buffered = !self.buffered_joined; }; + + if matches!(self.join_type, JoinType::LeftAnti) { + join_streamed = !self.streamed_joined; + join_buffered = true; + }; } } if !join_streamed && !join_buffered { @@ -1064,6 +1073,7 @@ impl SMJStream { Some(self.buffered_data.scanning_batch_idx) }; + //dbg!(self.buffered_data.scanning_idx()); self.streamed_batch .append_output_pair(scanning_batch_idx, None); self.output_size += 1; @@ -1135,8 +1145,12 @@ impl SMJStream { // Produces and stages record batch for all output indices found // for current streamed batch and clears staged output indices. fn freeze_streamed(&mut self) -> Result<()> { + //dbg!(&self.streamed_batch.batch); + for chunk in self.streamed_batch.output_indices.iter_mut() { let streamed_indices = chunk.streamed_indices.finish(); + dbg!(&streamed_indices); + //let streamed_indices = PrimitiveArray::::try_new(vec![0, 1].into(), None)?; if streamed_indices.is_empty() { continue; @@ -1193,6 +1207,10 @@ impl SMJStream { vec![] }; + dbg!(&streamed_columns); + dbg!(&buffered_columns); + dbg!(&filter_columns); + let columns = if matches!(self.join_type, JoinType::Right) { buffered_columns.extend(streamed_columns.clone()); buffered_columns @@ -1213,11 +1231,16 @@ impl SMJStream { filter_columns, )?; + dbg!(&filter_batch); + dbg!(&f.expression()); + + let filter_result = f .expression() .evaluate(&filter_batch)? .into_array(filter_batch.num_rows())?; + dbg!(&filter_result); // The selection mask of the filter let mut mask = datafusion_common::cast::as_boolean_array(&filter_result)?; @@ -1234,6 +1257,9 @@ impl SMJStream { // Push the filtered batch to the output let filtered_batch = compute::filter_record_batch(&output_batch, mask)?; + + dbg!(&filtered_batch); + self.output_record_batches.push(filtered_batch); // For outer joins, we need to push the null joined rows to the output. From c129846ed4533fd5f053cbac1fe99404ca2d08cc Mon Sep 17 00:00:00 2001 From: comphead Date: Sat, 4 May 2024 09:21:06 -0700 Subject: [PATCH 15/20] Fix: Sort Merge Join Left Semi crashes --- .../src/joins/sort_merge_join.rs | 21 ------------------- 1 file changed, 21 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 9885b290dc4c..81060e859b8b 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -825,8 +825,6 @@ impl SMJStream { self.streamed_state = StreamedState::Exhausted; } Poll::Ready(Some(batch)) => { - //dbg!(&batch); - //println!("{:#?}", &batch); if batch.num_rows() > 0 { self.freeze_streamed()?; self.join_metrics.input_batches.add(1); @@ -1073,7 +1071,6 @@ impl SMJStream { Some(self.buffered_data.scanning_batch_idx) }; - //dbg!(self.buffered_data.scanning_idx()); self.streamed_batch .append_output_pair(scanning_batch_idx, None); self.output_size += 1; @@ -1145,12 +1142,8 @@ impl SMJStream { // Produces and stages record batch for all output indices found // for current streamed batch and clears staged output indices. fn freeze_streamed(&mut self) -> Result<()> { - //dbg!(&self.streamed_batch.batch); - for chunk in self.streamed_batch.output_indices.iter_mut() { let streamed_indices = chunk.streamed_indices.finish(); - dbg!(&streamed_indices); - //let streamed_indices = PrimitiveArray::::try_new(vec![0, 1].into(), None)?; if streamed_indices.is_empty() { continue; @@ -1207,10 +1200,6 @@ impl SMJStream { vec![] }; - dbg!(&streamed_columns); - dbg!(&buffered_columns); - dbg!(&filter_columns); - let columns = if matches!(self.join_type, JoinType::Right) { buffered_columns.extend(streamed_columns.clone()); buffered_columns @@ -1231,16 +1220,11 @@ impl SMJStream { filter_columns, )?; - dbg!(&filter_batch); - dbg!(&f.expression()); - - let filter_result = f .expression() .evaluate(&filter_batch)? .into_array(filter_batch.num_rows())?; - dbg!(&filter_result); // The selection mask of the filter let mut mask = datafusion_common::cast::as_boolean_array(&filter_result)?; @@ -1258,8 +1242,6 @@ impl SMJStream { let filtered_batch = compute::filter_record_batch(&output_batch, mask)?; - dbg!(&filtered_batch); - self.output_record_batches.push(filtered_batch); // For outer joins, we need to push the null joined rows to the output. @@ -1418,9 +1400,6 @@ fn get_filter_column( .map(|i| buffered_columns[i.index].clone()) .collect::>(); -// dbg!(&left_columns); -// dbg!(&right_columns); - filter_columns.extend(left_columns); filter_columns.extend(right_columns); } From 22c61fc3dd57e5f00b591c7453383726f6c95785 Mon Sep 17 00:00:00 2001 From: comphead Date: Tue, 7 May 2024 07:55:01 -0700 Subject: [PATCH 16/20] Fix: Sort Merge Join Left Semi crashes --- datafusion/physical-plan/src/joins/sort_merge_join.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 81060e859b8b..3b9504f9e758 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -1023,11 +1023,6 @@ impl SMJStream { if matches!(self.join_type, JoinType::Full) { join_buffered = !self.buffered_joined; }; - - if matches!(self.join_type, JoinType::LeftAnti) { - join_streamed = !self.streamed_joined; - join_buffered = true; - }; } } if !join_streamed && !join_buffered { From 30f28fe8d8a41445bd36af5bc683064b0158ad52 Mon Sep 17 00:00:00 2001 From: comphead Date: Tue, 7 May 2024 09:11:23 -0700 Subject: [PATCH 17/20] Fix: Sort Merge Join Left Semi crashes --- datafusion/physical-plan/src/joins/sort_merge_join.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 3b9504f9e758..d798bbaa20f0 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -1009,7 +1009,9 @@ impl SMJStream { } if matches!(self.join_type, JoinType::LeftSemi) && self.filter.is_none() { join_streamed = !self.streamed_joined; - join_buffered = true; + // if the join filter specified there can be references to buffered columns + // so its needed to join them + join_buffered = self.filter.is_some(); } if matches!( self.join_type, From 823f39646e51f85936ecd32240b363842e343fbe Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 8 May 2024 16:34:49 -0700 Subject: [PATCH 18/20] Fix: Sort Merge Join Left Semi crashes. Comments --- datafusion/physical-plan/src/joins/sort_merge_join.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index d798bbaa20f0..8440f3184d75 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -831,7 +831,6 @@ impl SMJStream { self.join_metrics.input_rows.add(batch.num_rows()); self.streamed_batch = StreamedBatch::new(batch, &self.on_streamed); - self.streamed_state = StreamedState::Ready; } } @@ -1010,7 +1009,7 @@ impl SMJStream { if matches!(self.join_type, JoinType::LeftSemi) && self.filter.is_none() { join_streamed = !self.streamed_joined; // if the join filter specified there can be references to buffered columns - // so its needed to join them + // so buffered columns are needed to access them join_buffered = self.filter.is_some(); } if matches!( @@ -1238,7 +1237,6 @@ impl SMJStream { // Push the filtered batch to the output let filtered_batch = compute::filter_record_batch(&output_batch, mask)?; - self.output_record_batches.push(filtered_batch); // For outer joins, we need to push the null joined rows to the output. From f0e60da4a09fc85450369a04155f5965708c218c Mon Sep 17 00:00:00 2001 From: comphead Date: Thu, 16 May 2024 17:56:35 -0700 Subject: [PATCH 19/20] fmt --- .../physical-plan/src/joins/sort_merge_join.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 8440f3184d75..66946f664a78 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -2790,9 +2790,10 @@ mod tests { UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), &BooleanArray::from(vec![false, true, true, true, true, true]) ), - Some((BooleanArray::from(vec![ - false, true, false, true, false, false - ]), vec![0, 1])) + Some(( + BooleanArray::from(vec![false, true, false, true, false, false]), + vec![0, 1] + )) ); assert_eq!( @@ -2801,9 +2802,10 @@ mod tests { UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), &BooleanArray::from(vec![false, false, false, false, false, true]) ), - Some((BooleanArray::from(vec![ - false, false, false, false, false, true - ]), vec![1])) + Some(( + BooleanArray::from(vec![false, false, false, false, false, true]), + vec![1] + )) ); Ok(()) From a06acaab70bfdaf27b144b146e57e3984e43c6c1 Mon Sep 17 00:00:00 2001 From: comphead Date: Mon, 20 May 2024 14:21:57 -0700 Subject: [PATCH 20/20] fmt --- .../src/joins/sort_merge_join.rs | 32 ++++++++++--------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 66946f664a78..1cc7bf4700d1 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -996,21 +996,23 @@ impl SMJStream { } } Ordering::Equal => { - if matches!(self.join_type, JoinType::LeftSemi) && self.filter.is_some() { - join_streamed = !self - .streamed_batch - .join_filter_matched_idxs - .contains(&(self.streamed_batch.idx as u64)) - && !self.streamed_joined; - // if the join filter specified there can be references to buffered columns - // so buffered columns are needed to access them - join_buffered = join_streamed; - } - if matches!(self.join_type, JoinType::LeftSemi) && self.filter.is_none() { - join_streamed = !self.streamed_joined; - // if the join filter specified there can be references to buffered columns - // so buffered columns are needed to access them - join_buffered = self.filter.is_some(); + if matches!(self.join_type, JoinType::LeftSemi) { + // if the join filter is specified then its needed to output the streamed index + // only if it has not been emitted before + // the `join_filter_matched_idxs` keeps track on if streamed index has a successful + // filter match and prevents the same index to go into output more than once + if self.filter.is_some() { + join_streamed = !self + .streamed_batch + .join_filter_matched_idxs + .contains(&(self.streamed_batch.idx as u64)) + && !self.streamed_joined; + // if the join filter specified there can be references to buffered columns + // so buffered columns are needed to access them + join_buffered = join_streamed; + } else { + join_streamed = !self.streamed_joined; + } } if matches!( self.join_type,