Commit

fix: Produce buffered null join row only if all joined rows are failed on join filter in SMJ full join (#12090)

* Add test case

* fix

* Update test

* fix

* Remove duplicate

* fix clippy

* fix clippy again

* For review
viirya authored Aug 23, 2024
1 parent b8b76bc commit 7806393
Showing 2 changed files with 53 additions and 22 deletions.
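
In short: for a sort-merge FULL join with a join filter, a buffered row should be emitted with nulls on the streamed side only when every streamed row it was paired with on the join keys fails the filter. Before this fix, a buffered row that matched at least one streamed row could still be emitted a second time as a null-joined row (see the `NULL NULL Alice 2` result removed from the sqllogictest below). The following is a minimal, self-contained sketch of the corrected bookkeeping, not the DataFusion code itself; the pairing data and the `main` driver are made up for illustration.

```rust
// Sketch only (hypothetical pairing data): why the "failed" flag has to be
// AND-ed across all joined rows instead of tracked as a plain set.
use std::collections::{HashMap, HashSet};

fn main() {
    // One buffered row (index 7) is joined on the keys with two streamed rows:
    // the first pairing passes the join filter, the second one fails it.
    let pairings = [(7_u64, true), (7_u64, false)];

    // Removed bookkeeping: insert the index when a pairing fails, remove it
    // when a pairing passes. The final state follows the *last* pairing seen.
    let mut join_filter_failed_idxs: HashSet<u64> = HashSet::new();
    for &(idx, passed) in &pairings {
        if !passed {
            join_filter_failed_idxs.insert(idx);
        } else {
            join_filter_failed_idxs.remove(&idx);
        }
    }
    // The row matched once, yet it is still flagged, so the FULL join emits it
    // a second time as a null-joined row (the duplicate this commit removes).
    assert!(join_filter_failed_idxs.contains(&7));

    // New bookkeeping: remember whether *all* pairings seen so far failed.
    // A single passing pairing flips the flag to false for good.
    let mut join_filter_failed_map: HashMap<u64, bool> = HashMap::new();
    for &(idx, passed) in &pairings {
        let all_failed = *join_filter_failed_map.get(&idx).unwrap_or(&true);
        join_filter_failed_map.insert(idx, all_failed && !passed);
    }
    assert_eq!(join_filter_failed_map.get(&7), Some(&false));
}
```

The diff below applies this change inside `BufferedBatch` and `SMJStream`: the `join_filter_failed_idxs: HashSet<u64>` field becomes `join_filter_failed_map: HashMap<u64, bool>`, and null-joined rows are produced only for entries that are still `true` when the buffered batch is dequeued.
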
48 changes: 27 additions & 21 deletions datafusion/physical-plan/src/joins/sort_merge_join.rs
@@ -22,7 +22,7 @@
use std::any::Any;
use std::cmp::Ordering;
- use std::collections::VecDeque;
+ use std::collections::{HashMap, VecDeque};
use std::fmt::Formatter;
use std::fs::File;
use std::io::BufReader;
@@ -595,8 +595,10 @@ struct BufferedBatch {
/// Size estimation used for reserving / releasing memory
pub size_estimation: usize,
/// The indices of buffered batch that failed the join filter.
+ /// This is a map between buffered row index and a boolean value indicating whether all joined rows
+ /// of the buffered row failed the join filter.
/// When dequeuing the buffered batch, we need to produce null joined rows for these indices.
- pub join_filter_failed_idxs: HashSet<u64>,
+ pub join_filter_failed_map: HashMap<u64, bool>,
/// Current buffered batch number of rows. Equal to batch.num_rows()
/// but if batch is spilled to disk this property is preferable
/// and less expensive
@@ -637,7 +639,7 @@ impl BufferedBatch {
join_arrays,
null_joined: vec![],
size_estimation,
- join_filter_failed_idxs: HashSet::new(),
+ join_filter_failed_map: HashMap::new(),
num_rows,
spill_file: None,
}
@@ -1229,11 +1231,19 @@ impl SMJStream {
}
buffered_batch.null_joined.clear();

- // For buffered rows which are joined with streamed side but doesn't satisfy the join filter
+ // For a buffered row that is joined with streamed side rows but where all joined rows
+ // don't satisfy the join filter
if output_not_matched_filter {
+ let not_matched_buffered_indices = buffered_batch
+     .join_filter_failed_map
+     .iter()
+     .filter_map(|(idx, failed)| if *failed { Some(*idx) } else { None })
+     .collect::<Vec<_>>();

let buffered_indices = UInt64Array::from_iter_values(
-     buffered_batch.join_filter_failed_idxs.iter().copied(),
+     not_matched_buffered_indices.iter().copied(),
);

if let Some(record_batch) = produce_buffered_null_batch(
&self.schema,
&self.streamed_schema,
@@ -1242,7 +1252,7 @@
)? {
self.output_record_batches.push(record_batch);
}
- buffered_batch.join_filter_failed_idxs.clear();
+ buffered_batch.join_filter_failed_map.clear();
}
}
Ok(())
@@ -1459,24 +1469,20 @@ impl SMJStream {
// If it is joined with streamed side, but doesn't match the join filter,
// we need to output it with nulls as streamed side.
if matches!(self.join_type, JoinType::Full) {
+ let buffered_batch = &mut self.buffered_data.batches
+     [chunk.buffered_batch_idx.unwrap()];

for i in 0..pre_mask.len() {
- let buffered_batch = &mut self.buffered_data.batches
-     [chunk.buffered_batch_idx.unwrap()];
let buffered_index = buffered_indices.value(i);

- if !pre_mask.value(i) {
-     // For a buffered row that is joined with streamed side but doesn't satisfy the join filter,
-     buffered_batch
-         .join_filter_failed_idxs
-         .insert(buffered_index);
- } else if buffered_batch
-     .join_filter_failed_idxs
-     .contains(&buffered_index)
- {
-     buffered_batch
-         .join_filter_failed_idxs
-         .remove(&buffered_index);
- }
+ buffered_batch.join_filter_failed_map.insert(
+     buffered_index,
+     *buffered_batch
+         .join_filter_failed_map
+         .get(&buffered_index)
+         .unwrap_or(&true)
+         && !pre_mask.value(i),
+ );
}
}
}
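
The hunk above that builds `not_matched_buffered_indices` is the dequeue-time half of the fix: only buffered indices whose map entry is still `true` (meaning every joined row failed the filter) are turned into the index array handed to `produce_buffered_null_batch`. Below is a small sketch of that step in isolation; it assumes the `arrow` crate and a bare `HashMap` instead of the surrounding `BufferedBatch` machinery, and the map contents are made up.

```rust
// Sketch only: mirrors the dequeue step from the diff above, outside SMJStream.
use std::collections::HashMap;

use arrow::array::UInt64Array;

/// Collect the buffered-row indices whose joined rows all failed the join
/// filter, in the index-array form used to build the null-joined output batch.
fn not_matched_buffered_indices(join_filter_failed_map: &HashMap<u64, bool>) -> UInt64Array {
    let mut idxs: Vec<u64> = join_filter_failed_map
        .iter()
        .filter_map(|(idx, failed)| if *failed { Some(*idx) } else { None })
        .collect();
    // A HashMap iterates in arbitrary order; sort only so this sketch is deterministic.
    idxs.sort_unstable();
    UInt64Array::from_iter_values(idxs)
}

fn main() {
    // Buffered row 0 matched at least one streamed row; rows 1 and 2 never did.
    let map = HashMap::from([(0_u64, false), (1, true), (2, true)]);
    let indices = not_matched_buffered_indices(&map);
    assert_eq!(indices.iter().flatten().collect::<Vec<_>>(), vec![1, 2]);
}
```

After the null-joined rows are produced, the map is cleared (`join_filter_failed_map.clear()`), just as the old set was, so the bookkeeping starts fresh for the next buffered batch.
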
27 changes: 26 additions & 1 deletion datafusion/sqllogictest/test_files/sort_merge_join.slt
@@ -134,7 +134,6 @@ Alice 100 NULL NULL
Alice 50 Alice 2
Bob 1 NULL NULL
NULL NULL Alice 1
- NULL NULL Alice 2

query TITI rowsort
SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t1.b > t2.b + 50
@@ -624,6 +623,32 @@ NULL NULL 7 9
NULL NULL 8 10
NULL NULL 9 11

+ query IIII
+ select * from (
+ with t as (
+ select id_a id_a_1, id_a % 5 id_a_2 from (select unnest(make_array(5, 6, 7, 8, 9, 0, 1, 2, 3, 4)) id_a)
+ ), t1 as (
+ select id_b % 10 id_b_1, id_b + 2 id_b_2 from (select unnest(make_array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)) id_b)
+ )
+ select * from t full join t1 on t.id_a_2 = t1.id_b_1 and t.id_a_1 > t1.id_b_2
+ ) order by 1, 2, 3, 4
+ ----
+ 0 0 NULL NULL
+ 1 1 NULL NULL
+ 2 2 NULL NULL
+ 3 3 NULL NULL
+ 4 4 NULL NULL
+ 5 0 0 2
+ 6 1 1 3
+ 7 2 2 4
+ 8 3 3 5
+ 9 4 4 6
+ NULL NULL 5 7
+ NULL NULL 6 8
+ NULL NULL 7 9
+ NULL NULL 8 10
+ NULL NULL 9 11

# return sql params back to default values
statement ok
set datafusion.optimizer.prefer_hash_join = true;
