Merge SortMergeJoin filtered batches into larger batches #14160

Merged 4 commits on Jan 22, 2025

Changes from all commits
datafusion/physical-plan/src/joins/sort_merge_join.rs: 135 changes (95 additions, 40 deletions)
@@ -792,7 +792,10 @@ struct SortMergeJoinStream {
/// optional join filter
pub filter: Option<JoinFilter>,
/// Staging output array builders
pub output_record_batches: JoinedRecordBatches,
pub staging_output_record_batches: JoinedRecordBatches,
/// Output buffer. Currently used by filtering as it requires double buffering
/// to avoid small/empty batches. Non-filtered join outputs directly from `staging_output_record_batches.batches`
pub output: RecordBatch,
/// Staging output size, including output batches and staging joined results.
/// Increased when we put rows into buffer and decreased after we actually output batches.
/// Used to trigger output when sufficient rows are ready
@@ -1053,13 +1056,35 @@ impl Stream for SortMergeJoinStream {
{
self.freeze_all()?;

if !self.output_record_batches.batches.is_empty()
// If the join is filtered and there are joined tuples waiting
// to be filtered
if !self
.staging_output_record_batches
.batches
.is_empty()
{
// Apply filter on joined tuples and get filtered batch
let out_filtered_batch =
self.filter_joined_batch()?;
return Poll::Ready(Some(Ok(
out_filtered_batch,
)));

// Append filtered batch to the output buffer
self.output = concat_batches(
&self.schema(),
vec![&self.output, &out_filtered_batch],
)?;

// Send to output if the output buffer surpassed the `batch_size`
if self.output.num_rows() >= self.batch_size {
let record_batch = std::mem::replace(
&mut self.output,
RecordBatch::new_empty(
out_filtered_batch.schema(),
),
);
return Poll::Ready(Some(Ok(
record_batch,
)));
}
}
}
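
To make the buffering logic above easier to follow outside the diff, here is a small self-contained sketch of the same pattern, assuming the arrow `concat_batches` API and a hypothetical `OutputBuffer` type that is not part of DataFusion: small filtered batches are concatenated into one buffer, and a batch is handed out only once it reaches `batch_size`.

use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::compute::concat_batches;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;

/// Illustrative buffer: small filtered batches are concatenated until the
/// configured batch size is reached, then one larger batch is emitted.
struct OutputBuffer {
    schema: SchemaRef,
    output: RecordBatch,
    batch_size: usize,
}

impl OutputBuffer {
    fn new(schema: SchemaRef, batch_size: usize) -> Self {
        Self {
            output: RecordBatch::new_empty(Arc::clone(&schema)),
            schema,
            batch_size,
        }
    }

    /// Append a (possibly tiny) filtered batch; return a batch only once
    /// enough rows have accumulated.
    fn push(&mut self, filtered: RecordBatch) -> Result<Option<RecordBatch>, ArrowError> {
        self.output = concat_batches(&self.schema, [&self.output, &filtered])?;
        if self.output.num_rows() >= self.batch_size {
            let ready = std::mem::replace(
                &mut self.output,
                RecordBatch::new_empty(Arc::clone(&self.schema)),
            );
            Ok(Some(ready))
        } else {
            Ok(None)
        }
    }
}

fn main() -> Result<(), ArrowError> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let mut buf = OutputBuffer::new(Arc::clone(&schema), 4);
    for chunk in [vec![1], vec![2, 3], vec![4, 5]] {
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(chunk)) as ArrayRef],
        )?;
        if let Some(ready) = buf.push(batch)? {
            // Emitted once, after the third chunk, with 5 rows.
            println!("emitting {} rows", ready.num_rows());
        }
    }
    Ok(())
}

The sketch mirrors the double buffering in the hunk above: without it, every poll of a filtered join could emit a tiny or empty batch.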

@@ -1116,7 +1141,7 @@ impl Stream for SortMergeJoinStream {
}
} else {
self.freeze_all()?;
if !self.output_record_batches.batches.is_empty() {
if !self.staging_output_record_batches.batches.is_empty() {
let record_batch = self.output_record_batch_and_reset()?;
// For a non-filtered join, output whenever the target output batch size
// is hit. For a filtered join, output happens in a later phase
@@ -1146,7 +1171,8 @@ impl Stream for SortMergeJoinStream {
SortMergeJoinState::Exhausted => {
self.freeze_all()?;

if !self.output_record_batches.batches.is_empty() {
// if there is still something not processed
if !self.staging_output_record_batches.batches.is_empty() {
if self.filter.is_some()
&& matches!(
self.join_type,
@@ -1159,12 +1185,20 @@
| JoinType::LeftMark
)
{
let out = self.filter_joined_batch()?;
return Poll::Ready(Some(Ok(out)));
let record_batch = self.filter_joined_batch()?;
return Poll::Ready(Some(Ok(record_batch)));
} else {
let record_batch = self.output_record_batch_and_reset()?;
return Poll::Ready(Some(Ok(record_batch)));
}
} else if self.output.num_rows() > 0 {
// if processed but still not outputted because it didn't hit batch size before
let schema = self.output.schema();
let record_batch = std::mem::replace(
&mut self.output,
RecordBatch::new_empty(schema),
);
return Poll::Ready(Some(Ok(record_batch)));
} else {
return Poll::Ready(None);
}
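
The drain order in the `Exhausted` arm can be summarized by the small sketch below (a simplified, hypothetical `drain` helper, not the actual stream code): staged batches are emitted first, then any partially filled output buffer that never hit `batch_size`, and only then is end of stream reported.

use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

fn drain(staged: &mut Vec<RecordBatch>, output: &mut RecordBatch) -> Option<RecordBatch> {
    if let Some(batch) = staged.pop() {
        // In the real code this is where staged batches are filtered/concatenated.
        return Some(batch);
    }
    if output.num_rows() > 0 {
        // Processed rows that never reached `batch_size`: emit them now.
        let schema = output.schema();
        return Some(std::mem::replace(output, RecordBatch::new_empty(schema)));
    }
    None // nothing left: end of stream
}

fn main() {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let mut staged: Vec<RecordBatch> = vec![];
    let mut output = RecordBatch::new_empty(schema);
    assert!(drain(&mut staged, &mut output).is_none());
}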
@@ -1197,7 +1231,7 @@ impl SortMergeJoinStream {
state: SortMergeJoinState::Init,
sort_options,
null_equals_null,
schema,
schema: Arc::clone(&schema),
streamed_schema: Arc::clone(&streamed_schema),
buffered_schema,
streamed,
@@ -1212,12 +1246,13 @@ impl SortMergeJoinStream {
on_streamed,
on_buffered,
filter,
output_record_batches: JoinedRecordBatches {
staging_output_record_batches: JoinedRecordBatches {
batches: vec![],
filter_mask: BooleanBuilder::new(),
row_indices: UInt64Builder::new(),
batch_ids: vec![],
},
output: RecordBatch::new_empty(schema),
output_size: 0,
batch_size,
join_type,
@@ -1607,17 +1642,20 @@ impl SortMergeJoinStream {
buffered_batch,
)? {
let num_rows = record_batch.num_rows();
self.output_record_batches
self.staging_output_record_batches
.filter_mask
.append_nulls(num_rows);
self.output_record_batches
self.staging_output_record_batches
.row_indices
.append_nulls(num_rows);
self.output_record_batches
.batch_ids
.resize(self.output_record_batches.batch_ids.len() + num_rows, 0);
self.staging_output_record_batches.batch_ids.resize(
self.staging_output_record_batches.batch_ids.len() + num_rows,
0,
);

self.output_record_batches.batches.push(record_batch);
self.staging_output_record_batches
.batches
.push(record_batch);
}
buffered_batch.null_joined.clear();
}
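
The bookkeeping in this hunk can be illustrated with a minimal sketch, assuming a hypothetical `StagingMetadata` struct rather than the DataFusion `JoinedRecordBatches` type: every staged row gets one entry in the filter-mask and row-index builders, and rows produced for unmatched buffered data get nulls so the later filtering phase passes them through unchanged.

use arrow::array::{BooleanBuilder, UInt64Builder};

struct StagingMetadata {
    filter_mask: BooleanBuilder,
    row_indices: UInt64Builder,
    batch_ids: Vec<usize>,
}

impl StagingMetadata {
    fn new() -> Self {
        Self {
            filter_mask: BooleanBuilder::new(),
            row_indices: UInt64Builder::new(),
            batch_ids: vec![],
        }
    }

    /// Rows that never went through the join filter (e.g. null-joined buffered
    /// rows): mask and indices are null, batch id defaults to 0.
    fn append_unmatched(&mut self, num_rows: usize) {
        self.filter_mask.append_nulls(num_rows);
        self.row_indices.append_nulls(num_rows);
        self.batch_ids.resize(self.batch_ids.len() + num_rows, 0);
    }
}

fn main() {
    let mut staging = StagingMetadata::new();
    staging.append_unmatched(3);
    let mask = staging.filter_mask.finish();
    assert_eq!(mask.null_count(), 3); // all three entries are null
}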
@@ -1651,16 +1689,19 @@ impl SortMergeJoinStream {
)? {
let num_rows = record_batch.num_rows();

self.output_record_batches
self.staging_output_record_batches
.filter_mask
.append_nulls(num_rows);
self.output_record_batches
self.staging_output_record_batches
.row_indices
.append_nulls(num_rows);
self.output_record_batches
.batch_ids
.resize(self.output_record_batches.batch_ids.len() + num_rows, 0);
self.output_record_batches.batches.push(record_batch);
self.staging_output_record_batches.batch_ids.resize(
self.staging_output_record_batches.batch_ids.len() + num_rows,
0,
);
self.staging_output_record_batches
.batches
.push(record_batch);
}
buffered_batch.join_filter_not_matched_map.clear();

@@ -1794,20 +1835,29 @@ impl SortMergeJoinStream {
| JoinType::LeftMark
| JoinType::Full
) {
self.output_record_batches.batches.push(output_batch);
self.staging_output_record_batches
.batches
.push(output_batch);
} else {
let filtered_batch = filter_record_batch(&output_batch, &mask)?;
self.output_record_batches.batches.push(filtered_batch);
self.staging_output_record_batches
.batches
.push(filtered_batch);
}

if !matches!(self.join_type, JoinType::Full) {
self.output_record_batches.filter_mask.extend(&mask);
self.staging_output_record_batches.filter_mask.extend(&mask);
} else {
self.output_record_batches.filter_mask.extend(pre_mask);
self.staging_output_record_batches
.filter_mask
.extend(pre_mask);
}
self.output_record_batches.row_indices.extend(&left_indices);
self.output_record_batches.batch_ids.resize(
self.output_record_batches.batch_ids.len() + left_indices.len(),
self.staging_output_record_batches
.row_indices
.extend(&left_indices);
self.staging_output_record_batches.batch_ids.resize(
self.staging_output_record_batches.batch_ids.len()
+ left_indices.len(),
self.streamed_batch_counter.load(Relaxed),
);

@@ -1839,10 +1889,14 @@ impl SortMergeJoinStream {
}
}
} else {
self.output_record_batches.batches.push(output_batch);
self.staging_output_record_batches
.batches
.push(output_batch);
}
} else {
self.output_record_batches.batches.push(output_batch);
self.staging_output_record_batches
.batches
.push(output_batch);
}
}

Expand All @@ -1853,7 +1907,7 @@ impl SortMergeJoinStream {

fn output_record_batch_and_reset(&mut self) -> Result<RecordBatch> {
let record_batch =
concat_batches(&self.schema, &self.output_record_batches.batches)?;
concat_batches(&self.schema, &self.staging_output_record_batches.batches)?;
self.join_metrics.output_batches.add(1);
self.join_metrics.output_rows.add(record_batch.num_rows());
// If join filter exists, `self.output_size` is not accurate as we don't know the exact
@@ -1879,16 +1933,17 @@
| JoinType::Full
))
{
self.output_record_batches.batches.clear();
self.staging_output_record_batches.batches.clear();
}
Ok(record_batch)
}

fn filter_joined_batch(&mut self) -> Result<RecordBatch> {
let record_batch = self.output_record_batch_and_reset()?;
let mut out_indices = self.output_record_batches.row_indices.finish();
let mut out_mask = self.output_record_batches.filter_mask.finish();
let mut batch_ids = &self.output_record_batches.batch_ids;
let record_batch =
concat_batches(&self.schema, &self.staging_output_record_batches.batches)?;
let mut out_indices = self.staging_output_record_batches.row_indices.finish();
let mut out_mask = self.staging_output_record_batches.filter_mask.finish();
let mut batch_ids = &self.staging_output_record_batches.batch_ids;
let default_batch_ids = vec![0; record_batch.num_rows()];

// If only nulls come in and the indices sizes don't match the expected record batch count
@@ -1903,7 +1958,7 @@ impl SortMergeJoinStream {
}

if out_mask.is_empty() {
self.output_record_batches.batches.clear();
self.staging_output_record_batches.batches.clear();
return Ok(record_batch);
}

@@ -2046,7 +2101,7 @@ impl SortMergeJoinStream {
)?;
}

self.output_record_batches.clear();
self.staging_output_record_batches.clear();

Ok(filtered_record_batch)
}
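
For readers unfamiliar with the deferred filtering that `filter_joined_batch` performs, the standalone sketch below shows the general pattern under simplified assumptions (plain arrow APIs, no per-join-type mask correction, illustrative data): joined rows are staged unfiltered together with a boolean mask, then the mask is applied once over the concatenated staging batches.

use std::sync::Arc;

use arrow::array::{ArrayRef, BooleanArray, Int32Array};
use arrow::compute::{concat_batches, filter_record_batch};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;

fn main() -> Result<(), ArrowError> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

    // Two small staged batches of joined (but not yet filtered) rows.
    let staged = vec![
        RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef],
        )?,
        RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(vec![3, 4, 5])) as ArrayRef],
        )?,
    ];

    // One filter mask entry per staged row, accumulated while joining.
    let mask = BooleanArray::from(vec![true, false, true, true, false]);

    // Concatenate the staging batches and apply the mask in a single pass.
    let joined = concat_batches(&schema, staged.iter())?;
    let filtered = filter_record_batch(&joined, &mask)?;
    assert_eq!(filtered.num_rows(), 3); // values 1, 3 and 4 survive
    Ok(())
}

Applying the mask once over the concatenated batches is what allows the real code to also merge the filtered result into the `output` buffer instead of emitting many small batches.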