diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index bcacc7dcae0f..d3a5d89fb079 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -792,7 +792,10 @@ struct SortMergeJoinStream { /// optional join filter pub filter: Option, /// Staging output array builders - pub output_record_batches: JoinedRecordBatches, + pub staging_output_record_batches: JoinedRecordBatches, + /// Output buffer. Currently used by filtering as it requires double buffering + /// to avoid small/empty batches. Non-filtered join outputs directly from `staging_output_record_batches.batches` + pub output: RecordBatch, /// Staging output size, including output batches and staging joined results. /// Increased when we put rows into buffer and decreased after we actually output batches. /// Used to trigger output when sufficient rows are ready @@ -1053,13 +1056,35 @@ impl Stream for SortMergeJoinStream { { self.freeze_all()?; - if !self.output_record_batches.batches.is_empty() + // If join is filtered and there is joined tuples waiting + // to be filtered + if !self + .staging_output_record_batches + .batches + .is_empty() { + // Apply filter on joined tuples and get filtered batch let out_filtered_batch = self.filter_joined_batch()?; - return Poll::Ready(Some(Ok( - out_filtered_batch, - ))); + + // Append filtered batch to the output buffer + self.output = concat_batches( + &self.schema(), + vec![&self.output, &out_filtered_batch], + )?; + + // Send to output if the output buffer surpassed the `batch_size` + if self.output.num_rows() >= self.batch_size { + let record_batch = std::mem::replace( + &mut self.output, + RecordBatch::new_empty( + out_filtered_batch.schema(), + ), + ); + return Poll::Ready(Some(Ok( + record_batch, + ))); + } } } @@ -1116,7 +1141,7 @@ impl Stream for SortMergeJoinStream { } } else { self.freeze_all()?; - if !self.output_record_batches.batches.is_empty() { + if !self.staging_output_record_batches.batches.is_empty() { let record_batch = self.output_record_batch_and_reset()?; // For non-filtered join output whenever the target output batch size // is hit. For filtered join its needed to output on later phase @@ -1146,7 +1171,8 @@ impl Stream for SortMergeJoinStream { SortMergeJoinState::Exhausted => { self.freeze_all()?; - if !self.output_record_batches.batches.is_empty() { + // if there is still something not processed + if !self.staging_output_record_batches.batches.is_empty() { if self.filter.is_some() && matches!( self.join_type, @@ -1159,12 +1185,20 @@ impl Stream for SortMergeJoinStream { | JoinType::LeftMark ) { - let out = self.filter_joined_batch()?; - return Poll::Ready(Some(Ok(out))); + let record_batch = self.filter_joined_batch()?; + return Poll::Ready(Some(Ok(record_batch))); } else { let record_batch = self.output_record_batch_and_reset()?; return Poll::Ready(Some(Ok(record_batch))); } + } else if self.output.num_rows() > 0 { + // if processed but still not outputted because it didn't hit batch size before + let schema = self.output.schema(); + let record_batch = std::mem::replace( + &mut self.output, + RecordBatch::new_empty(schema), + ); + return Poll::Ready(Some(Ok(record_batch))); } else { return Poll::Ready(None); } @@ -1197,7 +1231,7 @@ impl SortMergeJoinStream { state: SortMergeJoinState::Init, sort_options, null_equals_null, - schema, + schema: Arc::clone(&schema), streamed_schema: Arc::clone(&streamed_schema), buffered_schema, streamed, @@ -1212,12 +1246,13 @@ impl SortMergeJoinStream { on_streamed, on_buffered, filter, - output_record_batches: JoinedRecordBatches { + staging_output_record_batches: JoinedRecordBatches { batches: vec![], filter_mask: BooleanBuilder::new(), row_indices: UInt64Builder::new(), batch_ids: vec![], }, + output: RecordBatch::new_empty(schema), output_size: 0, batch_size, join_type, @@ -1607,17 +1642,20 @@ impl SortMergeJoinStream { buffered_batch, )? { let num_rows = record_batch.num_rows(); - self.output_record_batches + self.staging_output_record_batches .filter_mask .append_nulls(num_rows); - self.output_record_batches + self.staging_output_record_batches .row_indices .append_nulls(num_rows); - self.output_record_batches - .batch_ids - .resize(self.output_record_batches.batch_ids.len() + num_rows, 0); + self.staging_output_record_batches.batch_ids.resize( + self.staging_output_record_batches.batch_ids.len() + num_rows, + 0, + ); - self.output_record_batches.batches.push(record_batch); + self.staging_output_record_batches + .batches + .push(record_batch); } buffered_batch.null_joined.clear(); } @@ -1651,16 +1689,19 @@ impl SortMergeJoinStream { )? { let num_rows = record_batch.num_rows(); - self.output_record_batches + self.staging_output_record_batches .filter_mask .append_nulls(num_rows); - self.output_record_batches + self.staging_output_record_batches .row_indices .append_nulls(num_rows); - self.output_record_batches - .batch_ids - .resize(self.output_record_batches.batch_ids.len() + num_rows, 0); - self.output_record_batches.batches.push(record_batch); + self.staging_output_record_batches.batch_ids.resize( + self.staging_output_record_batches.batch_ids.len() + num_rows, + 0, + ); + self.staging_output_record_batches + .batches + .push(record_batch); } buffered_batch.join_filter_not_matched_map.clear(); @@ -1794,20 +1835,29 @@ impl SortMergeJoinStream { | JoinType::LeftMark | JoinType::Full ) { - self.output_record_batches.batches.push(output_batch); + self.staging_output_record_batches + .batches + .push(output_batch); } else { let filtered_batch = filter_record_batch(&output_batch, &mask)?; - self.output_record_batches.batches.push(filtered_batch); + self.staging_output_record_batches + .batches + .push(filtered_batch); } if !matches!(self.join_type, JoinType::Full) { - self.output_record_batches.filter_mask.extend(&mask); + self.staging_output_record_batches.filter_mask.extend(&mask); } else { - self.output_record_batches.filter_mask.extend(pre_mask); + self.staging_output_record_batches + .filter_mask + .extend(pre_mask); } - self.output_record_batches.row_indices.extend(&left_indices); - self.output_record_batches.batch_ids.resize( - self.output_record_batches.batch_ids.len() + left_indices.len(), + self.staging_output_record_batches + .row_indices + .extend(&left_indices); + self.staging_output_record_batches.batch_ids.resize( + self.staging_output_record_batches.batch_ids.len() + + left_indices.len(), self.streamed_batch_counter.load(Relaxed), ); @@ -1839,10 +1889,14 @@ impl SortMergeJoinStream { } } } else { - self.output_record_batches.batches.push(output_batch); + self.staging_output_record_batches + .batches + .push(output_batch); } } else { - self.output_record_batches.batches.push(output_batch); + self.staging_output_record_batches + .batches + .push(output_batch); } } @@ -1853,7 +1907,7 @@ impl SortMergeJoinStream { fn output_record_batch_and_reset(&mut self) -> Result { let record_batch = - concat_batches(&self.schema, &self.output_record_batches.batches)?; + concat_batches(&self.schema, &self.staging_output_record_batches.batches)?; self.join_metrics.output_batches.add(1); self.join_metrics.output_rows.add(record_batch.num_rows()); // If join filter exists, `self.output_size` is not accurate as we don't know the exact @@ -1879,16 +1933,17 @@ impl SortMergeJoinStream { | JoinType::Full )) { - self.output_record_batches.batches.clear(); + self.staging_output_record_batches.batches.clear(); } Ok(record_batch) } fn filter_joined_batch(&mut self) -> Result { - let record_batch = self.output_record_batch_and_reset()?; - let mut out_indices = self.output_record_batches.row_indices.finish(); - let mut out_mask = self.output_record_batches.filter_mask.finish(); - let mut batch_ids = &self.output_record_batches.batch_ids; + let record_batch = + concat_batches(&self.schema, &self.staging_output_record_batches.batches)?; + let mut out_indices = self.staging_output_record_batches.row_indices.finish(); + let mut out_mask = self.staging_output_record_batches.filter_mask.finish(); + let mut batch_ids = &self.staging_output_record_batches.batch_ids; let default_batch_ids = vec![0; record_batch.num_rows()]; // If only nulls come in and indices sizes doesn't match with expected record batch count @@ -1903,7 +1958,7 @@ impl SortMergeJoinStream { } if out_mask.is_empty() { - self.output_record_batches.batches.clear(); + self.staging_output_record_batches.batches.clear(); return Ok(record_batch); } @@ -2046,7 +2101,7 @@ impl SortMergeJoinStream { )?; } - self.output_record_batches.clear(); + self.staging_output_record_batches.clear(); Ok(filtered_record_batch) }