From e041b6a37a9458dae73743a33ecb11a4f92fec5c Mon Sep 17 00:00:00 2001 From: Oleks V Date: Wed, 13 Nov 2024 12:42:11 -0800 Subject: [PATCH] Minor: SortMergeJoin small refactoring (#13398) * SortMergeJoin: small refactoring --- .../src/joins/sort_merge_join.rs | 186 +++++++++--------- 1 file changed, 96 insertions(+), 90 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 9307caf1c6ad..a01cd348f0c5 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -371,7 +371,7 @@ impl ExecutionPlan for SortMergeJoinExec { .register(context.memory_pool()); // create join stream - Ok(Box::pin(SMJStream::try_new( + Ok(Box::pin(SortMergeJoinStream::try_new( Arc::clone(&self.schema), self.sort_options.clone(), self.null_equals_null, @@ -461,7 +461,7 @@ impl SortMergeJoinMetrics { /// State of SMJ stream #[derive(Debug, PartialEq, Eq)] -enum SMJState { +enum SortMergeJoinState { /// Init joining with a new streamed row or a new buffered batches Init, /// Polling one streamed row or one buffered batch, or both @@ -597,11 +597,11 @@ struct BufferedBatch { pub null_joined: Vec, /// Size estimation used for reserving / releasing memory pub size_estimation: usize, - /// The indices of buffered batch that failed the join filter. - /// This is a map between buffered row index and a boolean value indicating whether all joined row - /// of the buffered row failed the join filter. + /// The indices of buffered batch that the join filter doesn't satisfy. + /// This is a map between right row index and a boolean value indicating whether all joined row + /// of the right row does not satisfy the filter . /// When dequeuing the buffered batch, we need to produce null joined rows for these indices. - pub join_filter_failed_map: HashMap, + pub join_filter_not_matched_map: HashMap, /// Current buffered batch number of rows. Equal to batch.num_rows() /// but if batch is spilled to disk this property is preferable /// and less expensive @@ -642,7 +642,7 @@ impl BufferedBatch { join_arrays, null_joined: vec![], size_estimation, - join_filter_failed_map: HashMap::new(), + join_filter_not_matched_map: HashMap::new(), num_rows, spill_file: None, } @@ -651,9 +651,9 @@ impl BufferedBatch { /// Sort-merge join stream that consumes streamed and buffered data stream /// and produces joined output -struct SMJStream { +struct SortMergeJoinStream { /// Current state of the stream - pub state: SMJState, + pub state: SortMergeJoinState, /// Output schema pub schema: SchemaRef, /// Sort options of join columns used to sort streamed and buffered data stream @@ -722,7 +722,15 @@ struct JoinedRecordBatches { pub batch_ids: Vec, } -impl RecordBatchStream for SMJStream { +impl JoinedRecordBatches { + fn clear(&mut self) { + self.batches.clear(); + self.batch_ids.clear(); + self.filter_mask = BooleanBuilder::new(); + self.row_indices = UInt64Builder::new(); + } +} +impl RecordBatchStream for SortMergeJoinStream { fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } @@ -865,6 +873,7 @@ fn get_corrected_filter_mask( let is_null = filter_mask.is_null(i); if val { + // memoize the first seen matched row if !seen_true { last_true_idx = i; } @@ -884,6 +893,8 @@ fn get_corrected_filter_mask( } if last_index { + // If the left row seen as true its needed to output it once + // To do that we mark all other matches for same row as null to avoid the output if seen_true { #[allow(clippy::needless_range_loop)] for j in first_row_idx..last_true_idx { @@ -905,7 +916,7 @@ fn get_corrected_filter_mask( } } -impl Stream for SMJStream { +impl Stream for SortMergeJoinStream { type Item = Result; fn poll_next( @@ -916,13 +927,13 @@ impl Stream for SMJStream { let _timer = join_time.timer(); loop { match &self.state { - SMJState::Init => { + SortMergeJoinState::Init => { let streamed_exhausted = self.streamed_state == StreamedState::Exhausted; let buffered_exhausted = self.buffered_state == BufferedState::Exhausted; self.state = if streamed_exhausted && buffered_exhausted { - SMJState::Exhausted + SortMergeJoinState::Exhausted } else { match self.current_ordering { Ordering::Less | Ordering::Equal => { @@ -961,10 +972,10 @@ impl Stream for SMJStream { } } } - SMJState::Polling + SortMergeJoinState::Polling }; } - SMJState::Polling => { + SortMergeJoinState::Polling => { if ![StreamedState::Exhausted, StreamedState::Ready] .contains(&self.streamed_state) { @@ -987,19 +998,19 @@ impl Stream for SMJStream { let buffered_exhausted = self.buffered_state == BufferedState::Exhausted; if streamed_exhausted && buffered_exhausted { - self.state = SMJState::Exhausted; + self.state = SortMergeJoinState::Exhausted; continue; } self.current_ordering = self.compare_streamed_buffered()?; - self.state = SMJState::JoinOutput; + self.state = SortMergeJoinState::JoinOutput; } - SMJState::JoinOutput => { + SortMergeJoinState::JoinOutput => { self.join_partial()?; if self.output_size < self.batch_size { if self.buffered_data.scanning_finished() { self.buffered_data.scanning_reset(); - self.state = SMJState::Init; + self.state = SortMergeJoinState::Init; } } else { self.freeze_all()?; @@ -1029,7 +1040,7 @@ impl Stream for SMJStream { return Poll::Pending; } } - SMJState::Exhausted => { + SortMergeJoinState::Exhausted => { self.freeze_all()?; if !self.output_record_batches.batches.is_empty() { @@ -1059,7 +1070,7 @@ impl Stream for SMJStream { } } -impl SMJStream { +impl SortMergeJoinStream { #[allow(clippy::too_many_arguments)] pub fn try_new( schema: SchemaRef, @@ -1079,7 +1090,7 @@ impl SMJStream { let streamed_schema = streamed.schema(); let buffered_schema = buffered.schema(); Ok(Self { - state: SMJState::Init, + state: SortMergeJoinState::Init, sort_options, null_equals_null, schema, @@ -1517,7 +1528,7 @@ impl SMJStream { // For buffered row which is joined with streamed side rows but all joined rows // don't satisfy the join filter let not_matched_buffered_indices = buffered_batch - .join_filter_failed_map + .join_filter_not_matched_map .iter() .filter_map(|(idx, failed)| if *failed { Some(*idx) } else { None }) .collect::>(); @@ -1531,7 +1542,6 @@ impl SMJStream { &buffered_indices, buffered_batch, )? { - //print_batches(&[record_batch.clone()]); let num_rows = record_batch.num_rows(); self.output_record_batches @@ -1545,8 +1555,7 @@ impl SMJStream { .extend(vec![0; num_rows]); self.output_record_batches.batches.push(record_batch); } - //dbg!(&buffered_batch.join_filter_failed_map); - buffered_batch.join_filter_failed_map.clear(); + buffered_batch.join_filter_not_matched_map.clear(); Ok(()) } @@ -1556,31 +1565,31 @@ impl SMJStream { fn freeze_streamed(&mut self) -> Result<()> { for chunk in self.streamed_batch.output_indices.iter_mut() { // The row indices of joined streamed batch - let streamed_indices = chunk.streamed_indices.finish(); + let left_indices = chunk.streamed_indices.finish(); - if streamed_indices.is_empty() { + if left_indices.is_empty() { continue; } - let mut streamed_columns = self + let mut left_columns = self .streamed_batch .batch .columns() .iter() - .map(|column| take(column, &streamed_indices, None)) + .map(|column| take(column, &left_indices, None)) .collect::, ArrowError>>()?; // The row indices of joined buffered batch - let buffered_indices: UInt64Array = chunk.buffered_indices.finish(); - let mut buffered_columns = if matches!(self.join_type, JoinType::LeftMark) { - vec![Arc::new(is_not_null(&buffered_indices)?) as ArrayRef] + let right_indices: UInt64Array = chunk.buffered_indices.finish(); + let mut right_columns = if matches!(self.join_type, JoinType::LeftMark) { + vec![Arc::new(is_not_null(&right_indices)?) as ArrayRef] } else if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { vec![] } else if let Some(buffered_idx) = chunk.buffered_batch_idx { - get_buffered_columns( + fetch_right_columns_by_idxs( &self.buffered_data, buffered_idx, - &buffered_indices, + &right_indices, )? } else { // If buffered batch none, meaning it is null joined batch. @@ -1588,29 +1597,30 @@ impl SMJStream { create_unmatched_columns( self.join_type, &self.buffered_schema, - buffered_indices.len(), + right_indices.len(), ) }; // Prepare the columns we apply join filter on later. // Only for joined rows between streamed and buffered. let filter_columns = if chunk.buffered_batch_idx.is_some() { - if matches!(self.join_type, JoinType::Right) { - get_filter_column(&self.filter, &buffered_columns, &streamed_columns) - } else if matches!( - self.join_type, - JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark - ) { - // unwrap is safe here as we check is_some on top of if statement - let buffered_columns = get_buffered_columns( - &self.buffered_data, - chunk.buffered_batch_idx.unwrap(), - &buffered_indices, - )?; + if !matches!(self.join_type, JoinType::Right) { + if matches!( + self.join_type, + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark + ) { + let right_cols = fetch_right_columns_by_idxs( + &self.buffered_data, + chunk.buffered_batch_idx.unwrap(), + &right_indices, + )?; - get_filter_column(&self.filter, &streamed_columns, &buffered_columns) + get_filter_column(&self.filter, &left_columns, &right_cols) + } else { + get_filter_column(&self.filter, &left_columns, &right_columns) + } } else { - get_filter_column(&self.filter, &streamed_columns, &buffered_columns) + get_filter_column(&self.filter, &right_columns, &left_columns) } } else { // This chunk is totally for null joined rows (outer join), we don't need to apply join filter. @@ -1618,12 +1628,12 @@ impl SMJStream { vec![] }; - let columns = if matches!(self.join_type, JoinType::Right) { - buffered_columns.extend(streamed_columns); - buffered_columns + let columns = if !matches!(self.join_type, JoinType::Right) { + left_columns.extend(right_columns); + left_columns } else { - streamed_columns.extend(buffered_columns); - streamed_columns + right_columns.extend(left_columns); + right_columns }; let output_batch = RecordBatch::try_new(Arc::clone(&self.schema), columns)?; @@ -1665,9 +1675,7 @@ impl SMJStream { | JoinType::LeftMark | JoinType::Full ) { - self.output_record_batches - .batches - .push(output_batch.clone()); + self.output_record_batches.batches.push(output_batch); } else { let filtered_batch = filter_record_batch(&output_batch, &mask)?; self.output_record_batches.batches.push(filtered_batch); @@ -1678,12 +1686,10 @@ impl SMJStream { } else { self.output_record_batches.filter_mask.extend(pre_mask); } - self.output_record_batches - .row_indices - .extend(&streamed_indices); + self.output_record_batches.row_indices.extend(&left_indices); self.output_record_batches.batch_ids.extend(vec![ - self.streamed_batch_counter.load(Relaxed); - streamed_indices.len() + self.streamed_batch_counter.load(Relaxed); + left_indices.len() ]); // For outer joins, we need to push the null joined rows to the output if @@ -1697,16 +1703,16 @@ impl SMJStream { for i in 0..pre_mask.len() { // If the buffered row is not joined with streamed side, // skip it. - if buffered_indices.is_null(i) { + if right_indices.is_null(i) { continue; } - let buffered_index = buffered_indices.value(i); + let buffered_index = right_indices.value(i); - buffered_batch.join_filter_failed_map.insert( + buffered_batch.join_filter_not_matched_map.insert( buffered_index, *buffered_batch - .join_filter_failed_map + .join_filter_not_matched_map .get(&buffered_index) .unwrap_or(&true) && !pre_mask.value(i), @@ -1765,6 +1771,9 @@ impl SMJStream { let mut batch_ids = &self.output_record_batches.batch_ids; let default_batch_ids = vec![0; record_batch.num_rows()]; + // If only nulls come in and indices sizes doesn't match with expected record batch count + // generate missing indices + // Happens for null joined batches for Full Join if out_indices.null_count() == out_indices.len() && out_indices.len() != record_batch.num_rows() { @@ -1804,32 +1813,32 @@ impl SMJStream { let null_mask = compute::not(corrected_mask)?; let null_joined_batch = filter_record_batch(&record_batch, &null_mask)?; - let mut buffered_columns = create_unmatched_columns( + let mut right_columns = create_unmatched_columns( self.join_type, &self.buffered_schema, null_joined_batch.num_rows(), ); - let columns = if matches!(self.join_type, JoinType::Right) { - let streamed_columns = null_joined_batch + let columns = if !matches!(self.join_type, JoinType::Right) { + let mut left_columns = null_joined_batch .columns() .iter() - .skip(left_columns_length) + .take(right_columns_length) .cloned() .collect::>(); - buffered_columns.extend(streamed_columns); - buffered_columns + left_columns.extend(right_columns); + left_columns } else { - let mut streamed_columns = null_joined_batch + let left_columns = null_joined_batch .columns() .iter() - .take(right_columns_length) + .skip(left_columns_length) .cloned() .collect::>(); - streamed_columns.extend(buffered_columns); - streamed_columns + right_columns.extend(left_columns); + right_columns }; // Push the streamed/buffered batch joined nulls to the output @@ -1905,10 +1914,7 @@ impl SMJStream { )?; } - self.output_record_batches.batches.clear(); - self.output_record_batches.batch_ids.clear(); - self.output_record_batches.filter_mask = BooleanBuilder::new(); - self.output_record_batches.row_indices = UInt64Builder::new(); + self.output_record_batches.clear(); Ok(filtered_record_batch) } @@ -1971,39 +1977,39 @@ fn produce_buffered_null_batch( } // Take buffered (right) columns - let buffered_columns = - get_buffered_columns_from_batch(buffered_batch, buffered_indices)?; + let right_columns = + fetch_right_columns_from_batch_by_idxs(buffered_batch, buffered_indices)?; // Create null streamed (left) columns - let mut streamed_columns = streamed_schema + let mut left_columns = streamed_schema .fields() .iter() .map(|f| new_null_array(f.data_type(), buffered_indices.len())) .collect::>(); - streamed_columns.extend(buffered_columns); + left_columns.extend(right_columns); Ok(Some(RecordBatch::try_new( Arc::clone(schema), - streamed_columns, + left_columns, )?)) } -/// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]` +/// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]` by specific column indices #[inline(always)] -fn get_buffered_columns( +fn fetch_right_columns_by_idxs( buffered_data: &BufferedData, buffered_batch_idx: usize, buffered_indices: &UInt64Array, ) -> Result> { - get_buffered_columns_from_batch( + fetch_right_columns_from_batch_by_idxs( &buffered_data.batches[buffered_batch_idx], buffered_indices, ) } #[inline(always)] -fn get_buffered_columns_from_batch( +fn fetch_right_columns_from_batch_by_idxs( buffered_batch: &BufferedBatch, buffered_indices: &UInt64Array, ) -> Result> {