diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 1e1c5d5424b0..47ffe0b1c66b 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -338,6 +338,12 @@ config_namespace! { /// if the source of statistics is accurate. /// We plan to make this the default in the future. pub use_row_number_estimates_to_optimize_partitioning: bool, default = false + + /// Should DataFusion enforce batch size in joins or not. By default, + /// DataFusion will not enforce batch size in joins. Enforcing batch size + /// in joins can reduce memory usage when joining large + /// tables with a highly-selective join filter, but is also slightly slower. + pub enforce_batch_size_in_joins: bool, default = false } } @@ -1222,16 +1228,18 @@ impl ConfigField for TableOptions { fn set(&mut self, key: &str, value: &str) -> Result<()> { // Extensions are handled in the public `ConfigOptions::set` let (key, rem) = key.split_once('.').unwrap_or((key, "")); - let Some(format) = &self.current_format else { - return _config_err!("Specify a format for TableOptions"); - }; match key { - "format" => match format { - #[cfg(feature = "parquet")] - ConfigFileType::PARQUET => self.parquet.set(rem, value), - ConfigFileType::CSV => self.csv.set(rem, value), - ConfigFileType::JSON => self.json.set(rem, value), - }, + "format" => { + let Some(format) = &self.current_format else { + return _config_err!("Specify a format for TableOptions"); + }; + match format { + #[cfg(feature = "parquet")] + ConfigFileType::PARQUET => self.parquet.set(rem, value), + ConfigFileType::CSV => self.csv.set(rem, value), + ConfigFileType::JSON => self.json.set(rem, value), + } + } _ => _config_err!("Config value \"{key}\" not found on TableOptions"), } } diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index cede75d21ca4..53646dc5b468 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -432,6 +432,20 @@ impl SessionConfig { self } + /// Enables or disables the enforcement of batch size in joins + pub fn with_enforce_batch_size_in_joins( + mut self, + enforce_batch_size_in_joins: bool, + ) -> Self { + self.options.execution.enforce_batch_size_in_joins = enforce_batch_size_in_joins; + self + } + + /// Returns true if the joins will be enforced to output batches of the configured size + pub fn enforce_batch_size_in_joins(&self) -> bool { + self.options.execution.enforce_batch_size_in_joins + } + /// Convert configuration options to name-value pairs with values /// converted to strings. /// diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index a70645f3d6c0..8f2bef56da76 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -19,7 +19,8 @@ //! and producing batches in parallel for the right partitions use super::utils::{ - adjust_right_output_partitioning, BuildProbeJoinMetrics, OnceAsync, OnceFut, + adjust_right_output_partitioning, BatchSplitter, BatchTransformer, + BuildProbeJoinMetrics, NoopBatchTransformer, OnceAsync, OnceFut, StatefulStreamResult, }; use crate::coalesce_partitions::CoalescePartitionsExec; @@ -86,6 +87,7 @@ impl CrossJoinExec { let schema = Arc::new(Schema::new(all_columns).with_metadata(metadata)); let cache = Self::compute_properties(&left, &right, Arc::clone(&schema)); + CrossJoinExec { left, right, @@ -246,6 +248,10 @@ impl ExecutionPlan for CrossJoinExec { let reservation = MemoryConsumer::new("CrossJoinExec").register(context.memory_pool()); + let batch_size = context.session_config().batch_size(); + let enforce_batch_size_in_joins = + context.session_config().enforce_batch_size_in_joins(); + let left_fut = self.left_fut.once(|| { load_left_input( Arc::clone(&self.left), @@ -255,15 +261,29 @@ impl ExecutionPlan for CrossJoinExec { ) }); - Ok(Box::pin(CrossJoinStream { - schema: Arc::clone(&self.schema), - left_fut, - right: stream, - left_index: 0, - join_metrics, - state: CrossJoinStreamState::WaitBuildSide, - left_data: RecordBatch::new_empty(self.left().schema()), - })) + if enforce_batch_size_in_joins { + Ok(Box::pin(CrossJoinStream { + schema: Arc::clone(&self.schema), + left_fut, + right: stream, + left_index: 0, + join_metrics, + state: CrossJoinStreamState::WaitBuildSide, + left_data: RecordBatch::new_empty(self.left().schema()), + batch_transformer: BatchSplitter::new(batch_size), + })) + } else { + Ok(Box::pin(CrossJoinStream { + schema: Arc::clone(&self.schema), + left_fut, + right: stream, + left_index: 0, + join_metrics, + state: CrossJoinStreamState::WaitBuildSide, + left_data: RecordBatch::new_empty(self.left().schema()), + batch_transformer: NoopBatchTransformer::new(), + })) + } } fn statistics(&self) -> Result { @@ -319,7 +339,7 @@ fn stats_cartesian_product( } /// A stream that issues [RecordBatch]es as they arrive from the right of the join. -struct CrossJoinStream { +struct CrossJoinStream { /// Input schema schema: Arc, /// Future for data from left side @@ -334,9 +354,11 @@ struct CrossJoinStream { state: CrossJoinStreamState, /// Left data left_data: RecordBatch, + /// Batch transformer + batch_transformer: T, } -impl RecordBatchStream for CrossJoinStream { +impl RecordBatchStream for CrossJoinStream { fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } @@ -390,7 +412,7 @@ fn build_batch( } #[async_trait] -impl Stream for CrossJoinStream { +impl Stream for CrossJoinStream { type Item = Result; fn poll_next( @@ -401,7 +423,7 @@ impl Stream for CrossJoinStream { } } -impl CrossJoinStream { +impl CrossJoinStream { /// Separate implementation function that unpins the [`CrossJoinStream`] so /// that partial borrows work correctly fn poll_next_impl( @@ -470,21 +492,33 @@ impl CrossJoinStream { fn build_batches(&mut self) -> Result>> { let right_batch = self.state.try_as_record_batch()?; if self.left_index < self.left_data.num_rows() { - let join_timer = self.join_metrics.join_time.timer(); - let result = - build_batch(self.left_index, right_batch, &self.left_data, &self.schema); - join_timer.done(); - - if let Ok(ref batch) = result { - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); + match self.batch_transformer.next() { + None => { + let join_timer = self.join_metrics.join_time.timer(); + let result = build_batch( + self.left_index, + right_batch, + &self.left_data, + &self.schema, + ); + join_timer.done(); + + self.batch_transformer.set_batch(result?); + } + Some((batch, last)) => { + if last { + self.left_index += 1; + } + + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + return Ok(StatefulStreamResult::Ready(Some(batch))); + } } - self.left_index += 1; - result.map(|r| StatefulStreamResult::Ready(Some(r))) } else { self.state = CrossJoinStreamState::FetchProbeBatch; - Ok(StatefulStreamResult::Continue) } + Ok(StatefulStreamResult::Continue) } } diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 74a45a7e4761..3b730c01291c 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -1438,7 +1438,7 @@ impl HashJoinStream { index_alignment_range_start..index_alignment_range_end, self.join_type, self.right_side_ordered, - ); + )?; let result = build_batch_from_indices( &self.schema, diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 6068e7526316..358ff02473a6 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -25,7 +25,10 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::task::Poll; -use super::utils::{asymmetric_join_output_partitioning, need_produce_result_in_final}; +use super::utils::{ + asymmetric_join_output_partitioning, need_produce_result_in_final, BatchSplitter, + BatchTransformer, NoopBatchTransformer, StatefulStreamResult, +}; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::joins::utils::{ adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, @@ -35,8 +38,8 @@ use crate::joins::utils::{ }; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::{ - execution_mode_from_children, DisplayAs, DisplayFormatType, Distribution, - ExecutionMode, ExecutionPlan, ExecutionPlanProperties, PlanProperties, + execution_mode_from_children, handle_state, DisplayAs, DisplayFormatType, + Distribution, ExecutionMode, ExecutionPlan, ExecutionPlanProperties, PlanProperties, RecordBatchStream, SendableRecordBatchStream, }; @@ -45,7 +48,9 @@ use arrow::compute::concat_batches; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow::util::bit_util; -use datafusion_common::{exec_datafusion_err, JoinSide, Result, Statistics}; +use datafusion_common::{ + exec_datafusion_err, internal_err, JoinSide, Result, Statistics, +}; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; use datafusion_expr::JoinType; @@ -230,10 +235,11 @@ impl NestedLoopJoinExec { asymmetric_join_output_partitioning(left, right, &join_type); // Determine execution mode: - let mut mode = execution_mode_from_children([left, right]); - if mode.is_unbounded() { - mode = ExecutionMode::PipelineBreaking; - } + let mode = if left.execution_mode().is_unbounded() { + ExecutionMode::PipelineBreaking + } else { + execution_mode_from_children([left, right]) + }; PlanProperties::new(eq_properties, output_partitioning, mode) } @@ -345,6 +351,10 @@ impl ExecutionPlan for NestedLoopJoinExec { ) }); + let batch_size = context.session_config().batch_size(); + let enforce_batch_size_in_joins = + context.session_config().enforce_batch_size_in_joins(); + let outer_table = self.right.execute(partition, context)?; let indices_cache = (UInt64Array::new_null(0), UInt32Array::new_null(0)); @@ -352,18 +362,38 @@ impl ExecutionPlan for NestedLoopJoinExec { // Right side has an order and it is maintained during operation. let right_side_ordered = self.maintains_input_order()[1] && self.right.output_ordering().is_some(); - Ok(Box::pin(NestedLoopJoinStream { - schema: Arc::clone(&self.schema), - filter: self.filter.clone(), - join_type: self.join_type, - outer_table, - inner_table, - is_exhausted: false, - column_indices: self.column_indices.clone(), - join_metrics, - indices_cache, - right_side_ordered, - })) + + if enforce_batch_size_in_joins { + Ok(Box::pin(NestedLoopJoinStream { + schema: Arc::clone(&self.schema), + filter: self.filter.clone(), + join_type: self.join_type, + outer_table, + inner_table, + column_indices: self.column_indices.clone(), + join_metrics, + indices_cache, + right_side_ordered, + state: NestedLoopJoinStreamState::WaitBuildSide, + batch_transformer: BatchSplitter::new(batch_size), + left_data: None, + })) + } else { + Ok(Box::pin(NestedLoopJoinStream { + schema: Arc::clone(&self.schema), + filter: self.filter.clone(), + join_type: self.join_type, + outer_table, + inner_table, + column_indices: self.column_indices.clone(), + join_metrics, + indices_cache, + right_side_ordered, + state: NestedLoopJoinStreamState::WaitBuildSide, + batch_transformer: NoopBatchTransformer::new(), + left_data: None, + })) + } } fn metrics(&self) -> Option { @@ -442,8 +472,37 @@ async fn collect_left_input( )) } +/// This enumeration represents various states of the nested loop join algorithm. +#[derive(Debug, Clone)] +enum NestedLoopJoinStreamState { + /// The initial state, indicating that build-side data not collected yet + WaitBuildSide, + /// Indicates that build-side has been collected, and stream is ready for + /// fetching probe-side + FetchProbeBatch, + /// Indicates that a non-empty batch has been fetched from probe-side, and + /// is ready to be processed + ProcessProbeBatch(RecordBatch), + /// Indicates that probe-side has been fully processed + ExhaustedProbeSide, + /// Indicates that NestedLoopJoinStream execution is completed + Completed, +} + +impl NestedLoopJoinStreamState { + /// Tries to extract a `ProcessProbeBatchState` from the + /// `NestedLoopJoinStreamState` enum. Returns an error if state is not + /// `ProcessProbeBatchState`. + fn try_as_process_probe_batch(&mut self) -> Result<&RecordBatch> { + match self { + NestedLoopJoinStreamState::ProcessProbeBatch(state) => Ok(state), + _ => internal_err!("Expected join stream in ProcessProbeBatch state"), + } + } +} + /// A stream that issues [RecordBatch]es as they arrive from the right of the join. -struct NestedLoopJoinStream { +struct NestedLoopJoinStream { /// Input schema schema: Arc, /// join filter @@ -454,8 +513,6 @@ struct NestedLoopJoinStream { outer_table: SendableRecordBatchStream, /// the inner table data of the nested loop join inner_table: OnceFut, - /// There is nothing to process anymore and left side is processed in case of full join - is_exhausted: bool, /// Information of index and left / right placement of columns column_indices: Vec, // TODO: support null aware equal @@ -466,6 +523,12 @@ struct NestedLoopJoinStream { indices_cache: (UInt64Array, UInt32Array), /// Whether the right side is ordered right_side_ordered: bool, + /// Current state of the stream + state: NestedLoopJoinStreamState, + /// Transforms the output batch before returning. + batch_transformer: T, + /// Result of the left data future + left_data: Option>, } /// Creates a Cartesian product of two input batches, preserving the order of the right batch, @@ -544,107 +607,164 @@ fn build_join_indices( } } -impl NestedLoopJoinStream { +impl NestedLoopJoinStream { fn poll_next_impl( &mut self, cx: &mut std::task::Context<'_>, ) -> Poll>> { - // all left row + loop { + return match self.state { + NestedLoopJoinStreamState::WaitBuildSide => { + handle_state!(ready!(self.collect_build_side(cx))) + } + NestedLoopJoinStreamState::FetchProbeBatch => { + handle_state!(ready!(self.fetch_probe_batch(cx))) + } + NestedLoopJoinStreamState::ProcessProbeBatch(_) => { + handle_state!(self.process_probe_batch()) + } + NestedLoopJoinStreamState::ExhaustedProbeSide => { + handle_state!(self.process_unmatched_build_batch()) + } + NestedLoopJoinStreamState::Completed => Poll::Ready(None), + }; + } + } + + fn collect_build_side( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { let build_timer = self.join_metrics.build_time.timer(); - let left_data = match ready!(self.inner_table.get_shared(cx)) { - Ok(data) => data, - Err(e) => return Poll::Ready(Some(Err(e))), - }; + // build hash table from left (build) side, if not yet done + self.left_data = Some(ready!(self.inner_table.get_shared(cx))?); build_timer.done(); - // Get or initialize visited_left_side bitmap if required by join type + self.state = NestedLoopJoinStreamState::FetchProbeBatch; + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + /// Fetches next batch from probe-side + /// + /// If a non-empty batch has been fetched, updates state to + /// `ProcessProbeBatchState`, otherwise updates state to `ExhaustedProbeSide`. + fn fetch_probe_batch( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + match ready!(self.outer_table.poll_next_unpin(cx)) { + None => { + self.state = NestedLoopJoinStreamState::ExhaustedProbeSide; + } + Some(Ok(right_batch)) => { + self.state = NestedLoopJoinStreamState::ProcessProbeBatch(right_batch); + } + Some(Err(err)) => return Poll::Ready(Err(err)), + }; + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + /// Joins current probe batch with build-side data and produces batch with + /// matched output, updates state to `FetchProbeBatch`. + fn process_probe_batch( + &mut self, + ) -> Result>> { + let Some(left_data) = self.left_data.clone() else { + return internal_err!( + "Expected left_data to be Some in ProcessProbeBatch state" + ); + }; let visited_left_side = left_data.bitmap(); + let batch = self.state.try_as_process_probe_batch()?; + + match self.batch_transformer.next() { + None => { + // Setting up timer & updating input metrics + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + let timer = self.join_metrics.join_time.timer(); + + let result = join_left_and_right_batch( + left_data.batch(), + batch, + self.join_type, + self.filter.as_ref(), + &self.column_indices, + &self.schema, + visited_left_side, + &mut self.indices_cache, + self.right_side_ordered, + ); + timer.done(); + + self.batch_transformer.set_batch(result?); + Ok(StatefulStreamResult::Continue) + } + Some((batch, last)) => { + if last { + self.state = NestedLoopJoinStreamState::FetchProbeBatch; + } - // Check is_exhausted before polling the outer_table, such that when the outer table - // does not support `FusedStream`, Self will not poll it again - if self.is_exhausted { - return Poll::Ready(None); + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + Ok(StatefulStreamResult::Ready(Some(batch))) + } } + } - self.outer_table - .poll_next_unpin(cx) - .map(|maybe_batch| match maybe_batch { - Some(Ok(right_batch)) => { - // Setting up timer & updating input metrics - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(right_batch.num_rows()); - let timer = self.join_metrics.join_time.timer(); - - let result = join_left_and_right_batch( - left_data.batch(), - &right_batch, - self.join_type, - self.filter.as_ref(), - &self.column_indices, - &self.schema, - visited_left_side, - &mut self.indices_cache, - self.right_side_ordered, - ); - - // Recording time & updating output metrics - if let Ok(batch) = &result { - timer.done(); - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } - - Some(result) - } - Some(err) => Some(err), - None => { - if need_produce_result_in_final(self.join_type) { - // At this stage `visited_left_side` won't be updated, so it's - // safe to report about probe completion. - // - // Setting `is_exhausted` / returning None will prevent from - // multiple calls of `report_probe_completed()` - if !left_data.report_probe_completed() { - self.is_exhausted = true; - return None; - }; - - // Only setting up timer, input is exhausted - let timer = self.join_metrics.join_time.timer(); - // use the global left bitmap to produce the left indices and right indices - let (left_side, right_side) = - get_final_indices_from_shared_bitmap( - visited_left_side, - self.join_type, - ); - let empty_right_batch = - RecordBatch::new_empty(self.outer_table.schema()); - // use the left and right indices to produce the batch result - let result = build_batch_from_indices( - &self.schema, - left_data.batch(), - &empty_right_batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - self.is_exhausted = true; - - // Recording time & updating output metrics - if let Ok(batch) = &result { - timer.done(); - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } - - Some(result) - } else { - // end of the join loop - None - } - } - }) + /// Processes unmatched build-side rows for certain join types and produces + /// output batch, updates state to `Completed`. + fn process_unmatched_build_batch( + &mut self, + ) -> Result>> { + let Some(left_data) = self.left_data.clone() else { + return internal_err!( + "Expected left_data to be Some in ExhaustedProbeSide state" + ); + }; + let visited_left_side = left_data.bitmap(); + if need_produce_result_in_final(self.join_type) { + // At this stage `visited_left_side` won't be updated, so it's + // safe to report about probe completion. + // + // Setting `is_exhausted` / returning None will prevent from + // multiple calls of `report_probe_completed()` + if !left_data.report_probe_completed() { + self.state = NestedLoopJoinStreamState::Completed; + return Ok(StatefulStreamResult::Ready(None)); + }; + + // Only setting up timer, input is exhausted + let timer = self.join_metrics.join_time.timer(); + // use the global left bitmap to produce the left indices and right indices + let (left_side, right_side) = + get_final_indices_from_shared_bitmap(visited_left_side, self.join_type); + let empty_right_batch = RecordBatch::new_empty(self.outer_table.schema()); + // use the left and right indices to produce the batch result + let result = build_batch_from_indices( + &self.schema, + left_data.batch(), + &empty_right_batch, + &left_side, + &right_side, + &self.column_indices, + JoinSide::Left, + ); + self.state = NestedLoopJoinStreamState::Completed; + + // Recording time + if result.is_ok() { + timer.done(); + } + + Ok(StatefulStreamResult::Ready(Some(result?))) + } else { + // end of the join loop + self.state = NestedLoopJoinStreamState::Completed; + Ok(StatefulStreamResult::Ready(None)) + } } } @@ -684,7 +804,7 @@ fn join_left_and_right_batch( 0..right_batch.num_rows(), join_type, right_side_ordered, - ); + )?; build_batch_from_indices( schema, @@ -705,7 +825,7 @@ fn get_final_indices_from_shared_bitmap( get_final_indices_from_bit_map(&bitmap, join_type) } -impl Stream for NestedLoopJoinStream { +impl Stream for NestedLoopJoinStream { type Item = Result; fn poll_next( @@ -716,14 +836,14 @@ impl Stream for NestedLoopJoinStream { } } -impl RecordBatchStream for NestedLoopJoinStream { +impl RecordBatchStream for NestedLoopJoinStream { fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } } #[cfg(test)] -mod tests { +pub(crate) mod tests { use super::*; use crate::{ common, expressions::Column, memory::MemoryExec, repartition::RepartitionExec, @@ -850,7 +970,7 @@ mod tests { JoinFilter::new(filter_expression, column_indices, intermediate_schema) } - async fn multi_partitioned_join_collect( + pub(crate) async fn multi_partitioned_join_collect( left: Arc, right: Arc, join_type: &JoinType, diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index ba9384aef1a6..bddd152341da 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -31,8 +31,7 @@ use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder}; use arrow_schema::{Schema, SchemaRef}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ - arrow_datafusion_err, plan_datafusion_err, DataFusionError, JoinSide, Result, - ScalarValue, + arrow_datafusion_err, DataFusionError, JoinSide, Result, ScalarValue, }; use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::expressions::Column; @@ -369,34 +368,40 @@ impl SortedFilterExpr { filter_expr: Arc, filter_schema: &Schema, ) -> Result { - let dt = &filter_expr.data_type(filter_schema)?; + let dt = filter_expr.data_type(filter_schema)?; Ok(Self { origin_sorted_expr, filter_expr, - interval: Interval::make_unbounded(dt)?, + interval: Interval::make_unbounded(&dt)?, node_index: 0, }) } + /// Get origin expr information pub fn origin_sorted_expr(&self) -> &PhysicalSortExpr { &self.origin_sorted_expr } + /// Get filter expr information pub fn filter_expr(&self) -> &Arc { &self.filter_expr } + /// Get interval information pub fn interval(&self) -> &Interval { &self.interval } + /// Sets interval pub fn set_interval(&mut self, interval: Interval) { self.interval = interval; } + /// Node index in ExprIntervalGraph pub fn node_index(&self) -> usize { self.node_index } + /// Node index setter in ExprIntervalGraph pub fn set_node_index(&mut self, node_index: usize) { self.node_index = node_index; @@ -409,41 +414,45 @@ impl SortedFilterExpr { /// on the first or the last value of the expression in `build_input_buffer` /// and `probe_batch`. /// -/// # Arguments +/// # Parameters /// /// * `build_input_buffer` - The [RecordBatch] on the build side of the join. /// * `build_sorted_filter_expr` - Build side [SortedFilterExpr] to update. /// * `probe_batch` - The `RecordBatch` on the probe side of the join. /// * `probe_sorted_filter_expr` - Probe side `SortedFilterExpr` to update. /// -/// ### Note -/// ```text +/// ## Note /// -/// Interval arithmetic is used to calculate viable join ranges for build-side -/// pruning. This is done by first creating an interval for join filter values in -/// the build side of the join, which spans [-∞, FV] or [FV, ∞] depending on the -/// ordering (descending/ascending) of the filter expression. Here, FV denotes the -/// first value on the build side. This range is then compared with the probe side -/// interval, which either spans [-∞, LV] or [LV, ∞] depending on the ordering -/// (ascending/descending) of the probe side. Here, LV denotes the last value on -/// the probe side. +/// Utilizing interval arithmetic, this function computes feasible join intervals +/// on the pruning side by evaluating the prospective value ranges that might +/// emerge in subsequent data batches from the enforcer side. This is done by +/// first creating an interval for join filter values in the pruning side of the +/// join, which spans `[-∞, FV]` or `[FV, ∞]` depending on the ordering (descending/ +/// ascending) of the filter expression. Here, `FV` denotes the first value on the +/// pruning side. This range is then compared with the enforcer side interval, +/// which either spans `[-∞, LV]` or `[LV, ∞]` depending on the ordering (ascending/ +/// descending) of the probe side. Here, `LV` denotes the last value on the enforcer +/// side. /// /// As a concrete example, consider the following query: /// +/// ```text /// SELECT * FROM left_table, right_table /// WHERE /// left_key = right_key AND /// a > b - 3 AND /// a < b + 10 +/// ``` /// -/// where columns "a" and "b" come from tables "left_table" and "right_table", +/// where columns `a` and `b` come from tables `left_table` and `right_table`, /// respectively. When a new `RecordBatch` arrives at the right side, the -/// condition a > b - 3 will possibly indicate a prunable range for the left +/// condition `a > b - 3` will possibly indicate a prunable range for the left /// side. Conversely, when a new `RecordBatch` arrives at the left side, the -/// condition a < b + 10 will possibly indicate prunability for the right side. -/// Let’s inspect what happens when a new RecordBatch` arrives at the right +/// condition `a < b + 10` will possibly indicate prunability for the right side. +/// Let’s inspect what happens when a new `RecordBatch` arrives at the right /// side (i.e. when the left side is the build side): /// +/// ```text /// Build Probe /// +-------+ +-------+ /// | a | z | | b | y | @@ -456,13 +465,13 @@ impl SortedFilterExpr { /// |+--|--+| |+--|--+| /// | 7 | 1 | | 6 | 3 | /// +-------+ +-------+ +/// ``` /// /// In this case, the interval representing viable (i.e. joinable) values for -/// column "a" is [1, ∞], and the interval representing possible future values -/// for column "b" is [6, ∞]. With these intervals at hand, we next calculate +/// column `a` is `[1, ∞]`, and the interval representing possible future values +/// for column `b` is `[6, ∞]`. With these intervals at hand, we next calculate /// intervals for the whole filter expression and propagate join constraint by /// traversing the expression graph. -/// ``` pub fn calculate_filter_expr_intervals( build_input_buffer: &RecordBatch, build_sorted_filter_expr: &mut SortedFilterExpr, @@ -710,13 +719,21 @@ fn update_sorted_exprs_with_node_indices( } } -/// Prepares and sorts expressions based on a given filter, left and right execution plans, and sort expressions. +/// Prepares and sorts expressions based on a given filter, left and right schemas, +/// and sort expressions. /// -/// # Arguments +/// This function prepares sorted filter expressions for both the left and right +/// sides of a join operation. It first builds the filter order for each side +/// based on the provided `ExecutionPlan`. If both sides have valid sorted filter +/// expressions, the function then constructs an expression interval graph and +/// updates the sorted expressions with node indices. The final sorted filter +/// expressions for both sides are then returned. +/// +/// # Parameters /// /// * `filter` - The join filter to base the sorting on. -/// * `left` - The left execution plan. -/// * `right` - The right execution plan. +/// * `left` - The `ExecutionPlan` for the left side of the join. +/// * `right` - The `ExecutionPlan` for the right side of the join. /// * `left_sort_exprs` - The expressions to sort on the left side. /// * `right_sort_exprs` - The expressions to sort on the right side. /// @@ -730,9 +747,11 @@ pub fn prepare_sorted_exprs( left_sort_exprs: &[PhysicalSortExpr], right_sort_exprs: &[PhysicalSortExpr], ) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> { - // Build the filter order for the left side - let err = || plan_datafusion_err!("Filter does not include the child order"); + let err = || { + datafusion_common::plan_datafusion_err!("Filter does not include the child order") + }; + // Build the filter order for the left side: let left_temp_sorted_filter_expr = build_filter_input_order( JoinSide::Left, filter, @@ -741,7 +760,7 @@ pub fn prepare_sorted_exprs( )? .ok_or_else(err)?; - // Build the filter order for the right side + // Build the filter order for the right side: let right_temp_sorted_filter_expr = build_filter_input_order( JoinSide::Right, filter, @@ -952,15 +971,15 @@ pub mod tests { let filter_expr = complicated_filter(&intermediate_schema)?; let column_indices = vec![ ColumnIndex { - index: 0, + index: left_schema.index_of("la1")?, side: JoinSide::Left, }, ColumnIndex { - index: 4, + index: left_schema.index_of("la2")?, side: JoinSide::Left, }, ColumnIndex { - index: 0, + index: right_schema.index_of("ra1")?, side: JoinSide::Right, }, ]; diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index ac718a95e9f4..70ada3892aea 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -32,7 +32,6 @@ use std::task::{Context, Poll}; use std::vec; use crate::common::SharedMemoryReservation; -use crate::handle_state; use crate::joins::hash_join::{equal_rows_arr, update_hash}; use crate::joins::stream_join_utils::{ calculate_filter_expr_intervals, combine_two_batches, @@ -42,8 +41,9 @@ use crate::joins::stream_join_utils::{ }; use crate::joins::utils::{ apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, - check_join_is_valid, symmetric_join_output_partitioning, ColumnIndex, JoinFilter, - JoinHashMapType, JoinOn, JoinOnRef, StatefulStreamResult, + check_join_is_valid, symmetric_join_output_partitioning, BatchSplitter, + BatchTransformer, ColumnIndex, JoinFilter, JoinHashMapType, JoinOn, JoinOnRef, + NoopBatchTransformer, StatefulStreamResult, }; use crate::{ execution_mode_from_children, @@ -465,23 +465,27 @@ impl ExecutionPlan for SymmetricHashJoinExec { consider using RepartitionExec" ); } - // If `filter_state` and `filter` are both present, then calculate sorted filter expressions - // for both sides, and build an expression graph. - let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = - match (&self.left_sort_exprs, &self.right_sort_exprs, &self.filter) { - (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => { - let (left, right, graph) = prepare_sorted_exprs( - filter, - &self.left, - &self.right, - left_sort_exprs, - right_sort_exprs, - )?; - (Some(left), Some(right), Some(graph)) - } - // If `filter_state` or `filter` is not present, then return None for all three values: - _ => (None, None, None), - }; + // If `filter_state` and `filter` are both present, then calculate sorted + // filter expressions for both sides, and build an expression graph. + let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = match ( + self.left_sort_exprs(), + self.right_sort_exprs(), + &self.filter, + ) { + (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => { + let (left, right, graph) = prepare_sorted_exprs( + filter, + &self.left, + &self.right, + left_sort_exprs, + right_sort_exprs, + )?; + (Some(left), Some(right), Some(graph)) + } + // If `filter_state` or `filter` is not present, then return None + // for all three values: + _ => (None, None, None), + }; let (on_left, on_right) = self.on.iter().cloned().unzip(); @@ -494,6 +498,10 @@ impl ExecutionPlan for SymmetricHashJoinExec { let right_stream = self.right.execute(partition, Arc::clone(&context))?; + let batch_size = context.session_config().batch_size(); + let enforce_batch_size_in_joins = + context.session_config().enforce_batch_size_in_joins(); + let reservation = Arc::new(Mutex::new( MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]")) .register(context.memory_pool()), @@ -502,29 +510,52 @@ impl ExecutionPlan for SymmetricHashJoinExec { reservation.lock().try_grow(g.size())?; } - Ok(Box::pin(SymmetricHashJoinStream { - left_stream, - right_stream, - schema: self.schema(), - filter: self.filter.clone(), - join_type: self.join_type, - random_state: self.random_state.clone(), - left: left_side_joiner, - right: right_side_joiner, - column_indices: self.column_indices.clone(), - metrics: StreamJoinMetrics::new(partition, &self.metrics), - graph, - left_sorted_filter_expr, - right_sorted_filter_expr, - null_equals_null: self.null_equals_null, - state: SHJStreamState::PullRight, - reservation, - })) + if enforce_batch_size_in_joins { + Ok(Box::pin(SymmetricHashJoinStream { + left_stream, + right_stream, + schema: self.schema(), + filter: self.filter.clone(), + join_type: self.join_type, + random_state: self.random_state.clone(), + left: left_side_joiner, + right: right_side_joiner, + column_indices: self.column_indices.clone(), + metrics: StreamJoinMetrics::new(partition, &self.metrics), + graph, + left_sorted_filter_expr, + right_sorted_filter_expr, + null_equals_null: self.null_equals_null, + state: SHJStreamState::PullRight, + reservation, + batch_transformer: BatchSplitter::new(batch_size), + })) + } else { + Ok(Box::pin(SymmetricHashJoinStream { + left_stream, + right_stream, + schema: self.schema(), + filter: self.filter.clone(), + join_type: self.join_type, + random_state: self.random_state.clone(), + left: left_side_joiner, + right: right_side_joiner, + column_indices: self.column_indices.clone(), + metrics: StreamJoinMetrics::new(partition, &self.metrics), + graph, + left_sorted_filter_expr, + right_sorted_filter_expr, + null_equals_null: self.null_equals_null, + state: SHJStreamState::PullRight, + reservation, + batch_transformer: NoopBatchTransformer::new(), + })) + } } } /// A stream that issues [RecordBatch]es as they arrive from the right of the join. -struct SymmetricHashJoinStream { +struct SymmetricHashJoinStream { /// Input streams left_stream: SendableRecordBatchStream, right_stream: SendableRecordBatchStream, @@ -556,15 +587,19 @@ struct SymmetricHashJoinStream { reservation: SharedMemoryReservation, /// State machine for input execution state: SHJStreamState, + /// Transforms the output batch before returning. + batch_transformer: T, } -impl RecordBatchStream for SymmetricHashJoinStream { +impl RecordBatchStream + for SymmetricHashJoinStream +{ fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } } -impl Stream for SymmetricHashJoinStream { +impl Stream for SymmetricHashJoinStream { type Item = Result; fn poll_next( @@ -1140,7 +1175,7 @@ impl OneSideHashJoiner { /// - Transition to `BothExhausted { final_result: true }`: /// - Occurs in `prepare_for_final_results_after_exhaustion` when both streams are /// exhausted, indicating completion of processing and availability of final results. -impl SymmetricHashJoinStream { +impl SymmetricHashJoinStream { /// Implements the main polling logic for the join stream. /// /// This method continuously checks the state of the join stream and @@ -1159,26 +1194,45 @@ impl SymmetricHashJoinStream { cx: &mut Context<'_>, ) -> Poll>> { loop { - return match self.state() { - SHJStreamState::PullRight => { - handle_state!(ready!(self.fetch_next_from_right_stream(cx))) - } - SHJStreamState::PullLeft => { - handle_state!(ready!(self.fetch_next_from_left_stream(cx))) + match self.batch_transformer.next() { + None => { + let result = match self.state() { + SHJStreamState::PullRight => { + ready!(self.fetch_next_from_right_stream(cx)) + } + SHJStreamState::PullLeft => { + ready!(self.fetch_next_from_left_stream(cx)) + } + SHJStreamState::RightExhausted => { + ready!(self.handle_right_stream_end(cx)) + } + SHJStreamState::LeftExhausted => { + ready!(self.handle_left_stream_end(cx)) + } + SHJStreamState::BothExhausted { + final_result: false, + } => self.prepare_for_final_results_after_exhaustion(), + SHJStreamState::BothExhausted { final_result: true } => { + return Poll::Ready(None); + } + }; + + match result? { + StatefulStreamResult::Ready(None) => { + return Poll::Ready(None); + } + StatefulStreamResult::Ready(Some(batch)) => { + self.batch_transformer.set_batch(batch); + } + _ => {} + } } - SHJStreamState::RightExhausted => { - handle_state!(ready!(self.handle_right_stream_end(cx))) - } - SHJStreamState::LeftExhausted => { - handle_state!(ready!(self.handle_left_stream_end(cx))) - } - SHJStreamState::BothExhausted { - final_result: false, - } => { - handle_state!(self.prepare_for_final_results_after_exhaustion()) + Some((batch, _)) => { + self.metrics.output_batches.add(1); + self.metrics.output_rows.add(batch.num_rows()); + return Poll::Ready(Some(Ok(batch))); } - SHJStreamState::BothExhausted { final_result: true } => Poll::Ready(None), - }; + } } } /// Asynchronously pulls the next batch from the right stream. @@ -1384,11 +1438,8 @@ impl SymmetricHashJoinStream { // Combine the left and right results: let result = combine_two_batches(&self.schema, left_result, right_result)?; - // Update the metrics and return the result: - if let Some(batch) = &result { - // Update the metrics: - self.metrics.output_batches.add(1); - self.metrics.output_rows.add(batch.num_rows()); + // Return the result: + if result.is_some() { return Ok(StatefulStreamResult::Ready(result)); } Ok(StatefulStreamResult::Continue) @@ -1523,11 +1574,6 @@ impl SymmetricHashJoinStream { let capacity = self.size(); self.metrics.stream_memory_usage.set(capacity); self.reservation.lock().try_resize(capacity)?; - // Update the metrics if we have a batch; otherwise, continue the loop. - if let Some(batch) = &result { - self.metrics.output_batches.add(1); - self.metrics.output_rows.add(batch.num_rows()); - } Ok(result) } } @@ -1716,15 +1762,15 @@ mod tests { let filter_expr = complicated_filter(&intermediate_schema)?; let column_indices = vec![ ColumnIndex { - index: 0, + index: left_schema.index_of("la1")?, side: JoinSide::Left, }, ColumnIndex { - index: 4, + index: left_schema.index_of("la2")?, side: JoinSide::Left, }, ColumnIndex { - index: 0, + index: right_schema.index_of("ra1")?, side: JoinSide::Right, }, ]; @@ -1771,10 +1817,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -1825,10 +1868,7 @@ mod tests { let (left, right) = create_memory_table(left_partition, right_partition, vec![], vec![])?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -1877,10 +1917,7 @@ mod tests { let (left, right) = create_memory_table(left_partition, right_partition, vec![], vec![])?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; experiment(left, right, None, join_type, on, task_ctx).await?; Ok(()) } @@ -1926,10 +1963,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -1987,10 +2021,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -2048,10 +2079,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -2111,10 +2139,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -2170,10 +2195,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("0", DataType::Int32, true), @@ -2237,10 +2259,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("0", DataType::Int32, true), @@ -2296,10 +2315,7 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let left_sorted = vec![PhysicalSortExpr { expr: col("lt1", left_schema)?, options: SortOptions { @@ -2380,10 +2396,7 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let left_sorted = vec![PhysicalSortExpr { expr: col("li1", left_schema)?, options: SortOptions { @@ -2473,10 +2486,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Float64, true), diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 89f3feaf07be..c520e4271416 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -546,15 +546,16 @@ pub struct ColumnIndex { pub side: JoinSide, } -/// Filter applied before join output +/// Filter applied before join output. Fields are crate-public to allow +/// downstream implementations to experiment with custom joins. #[derive(Debug, Clone)] pub struct JoinFilter { /// Filter expression - expression: Arc, + pub(crate) expression: Arc, /// Column indices required to construct intermediate batch for filtering - column_indices: Vec, + pub(crate) column_indices: Vec, /// Physical schema of intermediate batch - schema: Schema, + pub(crate) schema: Schema, } impl JoinFilter { @@ -1280,15 +1281,15 @@ pub(crate) fn adjust_indices_by_join_type( adjust_range: Range, join_type: JoinType, preserve_order_for_right: bool, -) -> (UInt64Array, UInt32Array) { +) -> Result<(UInt64Array, UInt32Array)> { match join_type { JoinType::Inner => { // matched - (left_indices, right_indices) + Ok((left_indices, right_indices)) } JoinType::Left => { // matched - (left_indices, right_indices) + Ok((left_indices, right_indices)) // unmatched left row will be produced in the end of loop, and it has been set in the left visited bitmap } JoinType::Right => { @@ -1307,22 +1308,22 @@ pub(crate) fn adjust_indices_by_join_type( // need to remove the duplicated record in the right side let right_indices = get_semi_indices(adjust_range, &right_indices); // the left_indices will not be used later for the `right semi` join - (left_indices, right_indices) + Ok((left_indices, right_indices)) } JoinType::RightAnti => { // need to remove the duplicated record in the right side // get the anti index for the right side let right_indices = get_anti_indices(adjust_range, &right_indices); // the left_indices will not be used later for the `right anti` join - (left_indices, right_indices) + Ok((left_indices, right_indices)) } JoinType::LeftSemi | JoinType::LeftAnti => { // matched or unmatched left row will be produced in the end of loop // When visit the right batch, we can output the matched left row and don't need to wait the end of loop - ( + Ok(( UInt64Array::from_iter_values(vec![]), UInt32Array::from_iter_values(vec![]), - ) + )) } } } @@ -1347,27 +1348,64 @@ pub(crate) fn append_right_indices( right_indices: UInt32Array, adjust_range: Range, preserve_order_for_right: bool, -) -> (UInt64Array, UInt32Array) { +) -> Result<(UInt64Array, UInt32Array)> { if preserve_order_for_right { - append_probe_indices_in_order(left_indices, right_indices, adjust_range) + Ok(append_probe_indices_in_order( + left_indices, + right_indices, + adjust_range, + )) } else { let right_unmatched_indices = get_anti_indices(adjust_range, &right_indices); if right_unmatched_indices.is_empty() { - (left_indices, right_indices) + Ok((left_indices, right_indices)) } else { - let unmatched_size = right_unmatched_indices.len(); + // `into_builder()` can fail here when there is nothing to be filtered and + // left_indices or right_indices has the same reference to the cached indices. + // In that case, we use a slower alternative. + // the new left indices: left_indices + null array + let mut new_left_indices_builder = + left_indices.into_builder().unwrap_or_else(|left_indices| { + let mut builder = UInt64Builder::with_capacity( + left_indices.len() + right_unmatched_indices.len(), + ); + debug_assert_eq!( + left_indices.null_count(), + 0, + "expected left indices to have no nulls" + ); + builder.append_slice(left_indices.values()); + builder + }); + new_left_indices_builder.append_nulls(right_unmatched_indices.len()); + let new_left_indices = UInt64Array::from(new_left_indices_builder.finish()); + // the new right indices: right_indices + right_unmatched_indices - let new_left_indices = left_indices - .iter() - .chain(std::iter::repeat(None).take(unmatched_size)) - .collect(); - let new_right_indices = right_indices - .iter() - .chain(right_unmatched_indices.iter()) - .collect(); - (new_left_indices, new_right_indices) + let mut new_right_indices_builder = right_indices + .into_builder() + .unwrap_or_else(|right_indices| { + let mut builder = UInt32Builder::with_capacity( + right_indices.len() + right_unmatched_indices.len(), + ); + debug_assert_eq!( + right_indices.null_count(), + 0, + "expected right indices to have no nulls" + ); + builder.append_slice(right_indices.values()); + builder + }); + debug_assert_eq!( + right_unmatched_indices.null_count(), + 0, + "expected right unmatched indices to have no nulls" + ); + new_right_indices_builder.append_slice(right_unmatched_indices.values()); + let new_right_indices = UInt32Array::from(new_right_indices_builder.finish()); + + Ok((new_left_indices, new_right_indices)) } } } @@ -1635,6 +1673,91 @@ pub(crate) fn asymmetric_join_output_partitioning( } } +/// Trait for incrementally generating Join output. +/// +/// This trait is used to limit some join outputs +/// so it does not produce single large batches +pub(crate) trait BatchTransformer: Debug + Clone { + /// Sets the next `RecordBatch` to be processed. + fn set_batch(&mut self, batch: RecordBatch); + + /// Retrieves the next `RecordBatch` from the transformer. + /// Returns `None` if all batches have been produced. + /// The boolean flag indicates whether the batch is the last one. + fn next(&mut self) -> Option<(RecordBatch, bool)>; +} + +#[derive(Debug, Clone)] +/// A batch transformer that does nothing. +pub(crate) struct NoopBatchTransformer { + /// RecordBatch to be processed + batch: Option, +} + +impl NoopBatchTransformer { + pub fn new() -> Self { + Self { batch: None } + } +} + +impl BatchTransformer for NoopBatchTransformer { + fn set_batch(&mut self, batch: RecordBatch) { + self.batch = Some(batch); + } + + fn next(&mut self) -> Option<(RecordBatch, bool)> { + self.batch.take().map(|batch| (batch, true)) + } +} + +#[derive(Debug, Clone)] +/// Splits large batches into smaller batches with a maximum number of rows. +pub(crate) struct BatchSplitter { + /// RecordBatch to be split + batch: Option, + /// Maximum number of rows in a split batch + batch_size: usize, + /// Current row index + row_index: usize, +} + +impl BatchSplitter { + /// Creates a new `BatchSplitter` with the specified batch size. + pub(crate) fn new(batch_size: usize) -> Self { + Self { + batch: None, + batch_size, + row_index: 0, + } + } +} + +impl BatchTransformer for BatchSplitter { + fn set_batch(&mut self, batch: RecordBatch) { + self.batch = Some(batch); + self.row_index = 0; + } + + fn next(&mut self) -> Option<(RecordBatch, bool)> { + let Some(batch) = &self.batch else { + return None; + }; + + let remaining_rows = batch.num_rows() - self.row_index; + let rows_to_slice = remaining_rows.min(self.batch_size); + let sliced_batch = batch.slice(self.row_index, rows_to_slice); + self.row_index += rows_to_slice; + + let mut last = false; + if self.row_index >= batch.num_rows() { + self.batch = None; + last = true; + } + + Some((sliced_batch, last)) + } +} + #[cfg(test)] mod tests { use std::pin::Pin; @@ -1643,11 +1766,13 @@ mod tests { use arrow::datatypes::{DataType, Fields}; use arrow::error::{ArrowError, Result as ArrowResult}; + use arrow_array::Int32Array; use arrow_schema::SortOptions; - use datafusion_common::stats::Precision::{Absent, Exact, Inexact}; use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue}; + use rstest::rstest; + fn check( left: &[Column], right: &[Column], @@ -2554,4 +2679,49 @@ mod tests { Ok(()) } + + fn create_test_batch(num_rows: usize) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let data = Arc::new(Int32Array::from_iter_values(0..num_rows as i32)); + RecordBatch::try_new(schema, vec![data]).unwrap() + } + + fn assert_split_batches( + batches: Vec<(RecordBatch, bool)>, + batch_size: usize, + num_rows: usize, + ) { + let mut row_count = 0; + for (batch, last) in batches.into_iter() { + assert_eq!(batch.num_rows(), (num_rows - row_count).min(batch_size)); + let column = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..batch.num_rows() { + assert_eq!(column.value(i), i as i32 + row_count as i32); + } + row_count += batch.num_rows(); + assert_eq!(last, row_count == num_rows); + } + } + + #[rstest] + #[test] + fn test_batch_splitter( + #[values(1, 3, 11)] batch_size: usize, + #[values(1, 6, 50)] num_rows: usize, + ) { + let mut splitter = BatchSplitter::new(batch_size); + splitter.set_batch(create_test_batch(num_rows)); + + let mut batches = Vec::with_capacity(num_rows.div_ceil(batch_size)); + while let Some(batch) = splitter.next() { + batches.push(batch); + } + + assert!(splitter.next().is_none()); + assert_split_batches(batches, batch_size, num_rows); + } } diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 7acdf25b6596..57bf029a63c1 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -173,6 +173,7 @@ datafusion.execution.batch_size 8192 datafusion.execution.coalesce_batches true datafusion.execution.collect_statistics false datafusion.execution.enable_recursive_ctes true +datafusion.execution.enforce_batch_size_in_joins false datafusion.execution.keep_partition_by_columns false datafusion.execution.listing_table_ignore_subdirectory true datafusion.execution.max_buffered_batches_per_output_file 2 @@ -263,6 +264,7 @@ datafusion.execution.batch_size 8192 Default batch size while creating new batch datafusion.execution.coalesce_batches true When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting datafusion.execution.collect_statistics false Should DataFusion collect statistics after listing files datafusion.execution.enable_recursive_ctes true Should DataFusion support recursive CTEs +datafusion.execution.enforce_batch_size_in_joins false Should DataFusion enforce batch size in joins or not. By default, DataFusion will not enforce batch size in joins. Enforcing batch size in joins can reduce memory usage when joining large tables with a highly-selective join filter, but is also slightly slower. datafusion.execution.keep_partition_by_columns false Should DataFusion keep the columns used for partition_by in the output RecordBatches datafusion.execution.listing_table_ignore_subdirectory true Should sub directories be ignored when scanning directories for data files. Defaults to true (ignores subdirectories), consistent with Hive. Note that this setting does not affect reading partitioned tables (e.g. `/table/year=2021/month=01/data.parquet`). datafusion.execution.max_buffered_batches_per_output_file 2 This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index f34d148f092f..c61a7b673334 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -91,6 +91,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.skip_partial_aggregation_probe_ratio_threshold | 0.8 | Aggregation ratio (number of distinct groups / number of input rows) threshold for skipping partial aggregation. If the value is greater then partial aggregation will skip aggregation for further input | | datafusion.execution.skip_partial_aggregation_probe_rows_threshold | 100000 | Number of input rows partial aggregation partition should process, before aggregation ratio check and trying to switch to skipping aggregation mode | | datafusion.execution.use_row_number_estimates_to_optimize_partitioning | false | Should DataFusion use row number estimates at the input to decide whether increasing parallelism is beneficial or not. By default, only exact row numbers (not estimates) are used for this decision. Setting this flag to `true` will likely produce better plans. if the source of statistics is accurate. We plan to make this the default in the future. | +| datafusion.execution.enforce_batch_size_in_joins | false | Should DataFusion enforce batch size in joins or not. By default, DataFusion will not enforce batch size in joins. Enforcing batch size in joins can reduce memory usage when joining large tables with a highly-selective join filter, but is also slightly slower. | | datafusion.optimizer.enable_distinct_aggregation_soft_limit | true | When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. | | datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores | | datafusion.optimizer.enable_topk_aggregation | true | When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible |