From 8e730e40934082f6cd274afa60a77f3df457e044 Mon Sep 17 00:00:00 2001 From: comphead Date: Sat, 4 May 2024 09:21:06 -0700 Subject: [PATCH] Fix: Sort Merge Join Left Semi crashes --- .../src/joins/sort_merge_join.rs | 35 ++---- .../test_files/sort_merge_join.slt | 117 ++++++++++++++++++ 2 files changed, 124 insertions(+), 28 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index d98ee83d5d700..da047b317aa96 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -819,8 +819,6 @@ impl SMJStream { self.streamed_state = StreamedState::Exhausted; } Poll::Ready(Some(batch)) => { - //dbg!(&batch); - //println!("{:#?}", &batch); if batch.num_rows() > 0 { self.freeze_streamed()?; self.join_metrics.input_batches.add(1); @@ -1057,7 +1055,6 @@ impl SMJStream { Some(self.buffered_data.scanning_batch_idx) }; - //dbg!(self.buffered_data.scanning_idx()); self.streamed_batch .append_output_pair(scanning_batch_idx, None); self.output_size += 1; @@ -1129,12 +1126,8 @@ impl SMJStream { // Produces and stages record batch for all output indices found // for current streamed batch and clears staged output indices. fn freeze_streamed(&mut self) -> Result<()> { - //dbg!(&self.streamed_batch.batch); - for chunk in self.streamed_batch.output_indices.iter_mut() { let streamed_indices = chunk.streamed_indices.finish(); - dbg!(&streamed_indices); - //let streamed_indices = PrimitiveArray::::try_new(vec![0, 1].into(), None)?; if streamed_indices.is_empty() { continue; @@ -1149,8 +1142,6 @@ impl SMJStream { .collect::, ArrowError>>()?; let buffered_indices: UInt64Array = chunk.buffered_indices.finish(); - dbg!(&buffered_indices); - dbg!(&self.join_type); let mut buffered_columns = if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { @@ -1178,15 +1169,17 @@ impl SMJStream { let filter_columns = if chunk.buffered_batch_idx.is_some() { if matches!(self.join_type, JoinType::Right) { get_filter_column(&self.filter, &buffered_columns, &streamed_columns) - } else if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { + } else if matches!( + self.join_type, + JoinType::LeftSemi | JoinType::LeftAnti + ) { let buffered_columns = self.buffered_data.batches [chunk.buffered_batch_idx.unwrap()] .batch .columns() - .iter() - .map(|column| take(column, &buffered_indices, None)) - .collect::, ArrowError>>()?; - //dbg!(&buffered_columns); + .iter() + .map(|column| take(column, &buffered_indices, None)) + .collect::, ArrowError>>()?; get_filter_column(&self.filter, &streamed_columns, &buffered_columns) } else { get_filter_column(&self.filter, &streamed_columns, &buffered_columns) @@ -1196,10 +1189,6 @@ impl SMJStream { vec![] }; - dbg!(&streamed_columns); - dbg!(&buffered_columns); - dbg!(&filter_columns); - let columns = if matches!(self.join_type, JoinType::Right) { buffered_columns.extend(streamed_columns.clone()); buffered_columns @@ -1220,16 +1209,11 @@ impl SMJStream { filter_columns, )?; - dbg!(&filter_batch); - dbg!(&f.expression()); - - let filter_result = f .expression() .evaluate(&filter_batch)? .into_array(filter_batch.num_rows())?; - dbg!(&filter_result); // The selection mask of the filter let mask = datafusion_common::cast::as_boolean_array(&filter_result)?; @@ -1237,8 +1221,6 @@ impl SMJStream { let filtered_batch = compute::filter_record_batch(&output_batch, mask)?; - dbg!(&filtered_batch); - self.output_record_batches.push(filtered_batch); // For outer joins, we need to push the null joined rows to the output. @@ -1395,9 +1377,6 @@ fn get_filter_column( .map(|i| buffered_columns[i.index].clone()) .collect::>(); -// dbg!(&left_columns); -// dbg!(&right_columns); - filter_columns.extend(left_columns); filter_columns.extend(right_columns); } diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index 0361d4e5af7fb..6b423a933338a 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -263,5 +263,122 @@ DROP TABLE t1; statement ok DROP TABLE t2; + +# LEFTSEMI join tests + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b = t1.b) +) order by 1, 2 +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b) +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select null a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b) +) order by 1, 2; +---- +11 13 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b) + select t1.* from t1 where exists (select 1 from t1 t2 where t2.a = t1.a and t2.b = t1.b) +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b) + select t1.* from t1 where exists (select 1 from t1 t2 where t2.a = t1.a and t2.b != t1.b) +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select null a, 12 b union all + select 11 a, 13 b) + select t1.* from t1 where exists (select 1 from t1 t2 where t2.a = t1.a and t2.b != t1.b) +) order by 1, 2; +---- + +# This gives a wrong result for now +#query II +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b union all +# select 11 a, 13 b), +#t2 as ( +# select 11 a, 12 b union all +# select 11 a, 14 b +# ) +#select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b) +#) order by 1, 2; +#---- + +#LEFTANTI tests +# returns no rows instead of correct result +#query III +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b, 1 c union all +# select 11 a, 13 b, 2 c), +#t2 as ( +# select 11 a, 12 b, 3 c union all +# select 11 a, 14 b, 4 c +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +#) order by 1, 2; +---- +11 12 1 +11 13 2 + statement ok set datafusion.optimizer.prefer_hash_join = true;