diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index 7fcd719539ec7..a9f9b22fafda4 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -68,6 +68,7 @@ rand = { workspace = true } tokio = { workspace = true } [dev-dependencies] +criterion = { version = "0.5", features = ["async_futures"] } datafusion-functions-aggregate = { workspace = true } rstest = { workspace = true } rstest_reuse = "0.7.0" @@ -76,3 +77,7 @@ tokio = { workspace = true, features = [ "fs", "parking_lot", ] } + +[[bench]] +harness = false +name = "spm" diff --git a/datafusion/physical-plan/benches/spm.rs b/datafusion/physical-plan/benches/spm.rs new file mode 100644 index 0000000000000..9cc703f5f7265 --- /dev/null +++ b/datafusion/physical-plan/benches/spm.rs @@ -0,0 +1,145 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::record_batch::RecordBatch; +use arrow_array::{ArrayRef, Int32Array, Int64Array, StringArray}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::expressions::col; +use datafusion_physical_expr::PhysicalSortExpr; +use datafusion_physical_plan::memory::MemoryExec; +use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use datafusion_physical_plan::{collect, ExecutionPlan}; + +use criterion::async_executor::FuturesExecutor; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +fn generate_spm_for_round_robin_tie_breaker( + has_same_value: bool, + enable_round_robin_repartition: bool, + batch_count: usize, + partition_count: usize, +) -> SortPreservingMergeExec { + let row_size = 256; + let rb = if has_same_value { + let a: ArrayRef = Arc::new(Int32Array::from(vec![1; row_size])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"); row_size])); + let c: ArrayRef = Arc::new(Int64Array::from_iter(vec![0; row_size])); + RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap() + } else { + let v = (0i32..row_size as i32).collect::>(); + let a: ArrayRef = Arc::new(Int32Array::from(v)); + + // Use alphanumeric characters + let charset: Vec = + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + .chars() + .collect(); + + let mut strings = Vec::new(); + for i in 0..256 { + let mut s = String::new(); + s.push(charset[i % charset.len()]); + s.push(charset[(i / charset.len()) % charset.len()]); + strings.push(Some(s)); + } + + let b: ArrayRef = Arc::new(StringArray::from_iter(strings)); + + let v = (0i64..row_size as i64).collect::>(); + let c: ArrayRef = Arc::new(Int64Array::from_iter(v)); + RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap() + }; + + let rbs = (0..batch_count).map(|_| rb.clone()).collect::>(); + let partitiones = vec![rbs.clone(); partition_count]; + + let schema = rb.schema(); + let sort = vec![ + PhysicalSortExpr { + expr: col("b", &schema).unwrap(), + options: Default::default(), + }, + PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: Default::default(), + }, + ]; + + let exec = MemoryExec::try_new(&partitiones, schema, None).unwrap(); + SortPreservingMergeExec::new(sort, Arc::new(exec)) + .with_round_robin_repartition(enable_round_robin_repartition) +} + +fn run_bench( + c: &mut Criterion, + has_same_value: bool, + enable_round_robin_repartition: bool, + batch_count: usize, + partition_count: usize, + description: &str, +) { + let task_ctx = TaskContext::default(); + let task_ctx = Arc::new(task_ctx); + + let spm = Arc::new(generate_spm_for_round_robin_tie_breaker( + has_same_value, + enable_round_robin_repartition, + batch_count, + partition_count, + )) as Arc; + + c.bench_function(description, |b| { + b.to_async(FuturesExecutor) + .iter(|| black_box(collect(Arc::clone(&spm), Arc::clone(&task_ctx)))) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + let params = [ + (true, false, "low_card_without_tiebreaker"), // low cardinality, no tie breaker + (true, true, "low_card_with_tiebreaker"), // low cardinality, with tie breaker + (false, false, "high_card_without_tiebreaker"), // high cardinality, no tie breaker + (false, true, "high_card_with_tiebreaker"), // high cardinality, with tie breaker + ]; + + let batch_counts = [1, 25, 625]; + let partition_counts = [2, 8, 32]; + + for &(has_same_value, enable_round_robin_repartition, cardinality_label) in ¶ms { + for &batch_count in &batch_counts { + for &partition_count in &partition_counts { + let description = format!( + "{}_batch_count_{}_partition_count_{}", + cardinality_label, batch_count, partition_count + ); + run_bench( + c, + has_same_value, + enable_round_robin_repartition, + batch_count, + partition_count, + &description, + ); + } + } + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/physical-plan/src/sorts/cursor.rs b/datafusion/physical-plan/src/sorts/cursor.rs index df90c97faf68e..133d736c1467b 100644 --- a/datafusion/physical-plan/src/sorts/cursor.rs +++ b/datafusion/physical-plan/src/sorts/cursor.rs @@ -38,6 +38,10 @@ pub trait CursorValues { /// Returns true if `l[l_idx] == r[r_idx]` fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool; + /// Returns true if `row[idx] == row[idx - 1]` + /// Given `idx` should be greater than 0 + fn eq_to_previous(cursor: &Self, idx: usize) -> bool; + /// Returns comparison of `l[l_idx]` and `r[r_idx]` fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering; } @@ -95,6 +99,16 @@ impl Cursor { self.offset += 1; t } + + pub fn is_eq_to_prev_one(&self, prev_cursor: Option<&Cursor>) -> bool { + if self.offset > 0 { + self.is_eq_to_prev_row() + } else if let Some(prev_cursor) = prev_cursor { + self.is_eq_to_prev_row_in_prev_batch(prev_cursor) + } else { + false + } + } } impl PartialEq for Cursor { @@ -103,6 +117,22 @@ impl PartialEq for Cursor { } } +impl Cursor { + fn is_eq_to_prev_row(&self) -> bool { + T::eq_to_previous(&self.values, self.offset) + } + + fn is_eq_to_prev_row_in_prev_batch(&self, other: &Self) -> bool { + assert_eq!(self.offset, 0); + T::eq( + &self.values, + self.offset, + &other.values, + other.values.len() - 1, + ) + } +} + impl Eq for Cursor {} impl PartialOrd for Cursor { @@ -156,6 +186,11 @@ impl CursorValues for RowValues { l.rows.row(l_idx) == r.rows.row(r_idx) } + fn eq_to_previous(cursor: &Self, idx: usize) -> bool { + assert!(idx > 0); + cursor.rows.row(idx) == cursor.rows.row(idx - 1) + } + fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering { l.rows.row(l_idx).cmp(&r.rows.row(r_idx)) } @@ -188,6 +223,11 @@ impl CursorValues for PrimitiveValues { l.0[l_idx].is_eq(r.0[r_idx]) } + fn eq_to_previous(cursor: &Self, idx: usize) -> bool { + assert!(idx > 0); + cursor.0[idx].is_eq(cursor.0[idx - 1]) + } + fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering { l.0[l_idx].compare(r.0[r_idx]) } @@ -219,6 +259,11 @@ impl CursorValues for ByteArrayValues { l.value(l_idx) == r.value(r_idx) } + fn eq_to_previous(cursor: &Self, idx: usize) -> bool { + assert!(idx > 0); + cursor.value(idx) == cursor.value(idx - 1) + } + fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering { l.value(l_idx).cmp(r.value(r_idx)) } @@ -284,6 +329,15 @@ impl CursorValues for ArrayValues { } } + fn eq_to_previous(cursor: &Self, idx: usize) -> bool { + assert!(idx > 0); + match (cursor.is_null(idx), cursor.is_null(idx - 1)) { + (true, true) => true, + (false, false) => T::eq(&cursor.values, idx, &cursor.values, idx - 1), + _ => false, + } + } + fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering { match (l.is_null(l_idx), r.is_null(r_idx)) { (true, true) => Ordering::Equal, diff --git a/datafusion/physical-plan/src/sorts/merge.rs b/datafusion/physical-plan/src/sorts/merge.rs index e0644e3d99e55..458c1c29c0cf5 100644 --- a/datafusion/physical-plan/src/sorts/merge.rs +++ b/datafusion/physical-plan/src/sorts/merge.rs @@ -97,6 +97,40 @@ pub(crate) struct SortPreservingMergeStream { /// Cursors for each input partition. `None` means the input is exhausted cursors: Vec>>, + /// Configuration parameter to enable round-robin selection of tied winners of loser tree. + /// + /// To address the issue of unbalanced polling between partitions due to tie-breakers being based + /// on partition index, especially in cases of low cardinality, we are making changes to the winner + /// selection mechanism. Previously, partitions with smaller indices were consistently chosen as the winners, + /// leading to an uneven distribution of polling. This caused upstream operator buffers for the other partitions + /// to grow excessively, as they continued receiving data without consuming it. + /// + /// For example, an upstream operator like a repartition execution would keep sending data to certain partitions, + /// but those partitions wouldn't consume the data if they weren't selected as winners. This resulted in inefficient buffer usage. + /// + /// To resolve this, we are modifying the tie-breaking logic. Instead of always choosing the partition with the smallest index, + /// we now select the partition that has the fewest poll counts for the same value. + /// This ensures that multiple partitions with the same value are chosen equally, distributing the polling load in a round-robin fashion. + /// This approach balances the workload more effectively across partitions and avoids excessive buffer growth. + enable_round_robin_tie_breaker: bool, + + /// Flag indicating whether we are in the mode of round-robin + /// tie breaker for the loser tree winners. + round_robin_tie_breaker_mode: bool, + + /// Total number of polls returning the same value, as per partition. + /// We select the one that has less poll counts for tie-breaker in loser tree. + num_of_polled_with_same_value: Vec, + + /// To keep track of reset counts + poll_reset_epochs: Vec, + + /// Current reset count + current_reset_epoch: usize, + + /// Stores the previous value of each partitions for tracking the poll counts on the same value. + prev_cursors: Vec>>, + /// Optional number of rows to fetch fetch: Option, @@ -118,6 +152,7 @@ impl SortPreservingMergeStream { batch_size: usize, fetch: Option, reservation: MemoryReservation, + enable_round_robin_tie_breaker: bool, ) -> Self { let stream_count = streams.partitions(); @@ -127,12 +162,18 @@ impl SortPreservingMergeStream { metrics, aborted: false, cursors: (0..stream_count).map(|_| None).collect(), + prev_cursors: (0..stream_count).map(|_| None).collect(), + round_robin_tie_breaker_mode: false, + num_of_polled_with_same_value: vec![0; stream_count], + current_reset_epoch: 0, + poll_reset_epochs: vec![0; stream_count], loser_tree: vec![], loser_tree_adjusted: false, batch_size, fetch, produced: 0, uninitiated_partitions: (0..stream_count).collect(), + enable_round_robin_tie_breaker, } } @@ -218,7 +259,7 @@ impl SortPreservingMergeStream { } let stream_idx = self.loser_tree[0]; - if self.advance(stream_idx) { + if self.advance_cursors(stream_idx) { self.loser_tree_adjusted = false; self.in_progress.push_row(stream_idx); @@ -236,27 +277,53 @@ impl SortPreservingMergeStream { } } + /// For the given partition, updates the poll count. If the current value is the same + /// of the previous value, it increases the count by 1; otherwise, it is reset as 0. + fn update_poll_count_on_the_same_value(&mut self, partition_idx: usize) { + let cursor = &mut self.cursors[partition_idx]; + + // Check if the current partition's poll count is logically "reset" + if self.poll_reset_epochs[partition_idx] != self.current_reset_epoch { + self.poll_reset_epochs[partition_idx] = self.current_reset_epoch; + self.num_of_polled_with_same_value[partition_idx] = 0; + } + + if let Some(c) = cursor.as_mut() { + // Compare with the last row in the previous batch + let prev_cursor = &self.prev_cursors[partition_idx]; + if c.is_eq_to_prev_one(prev_cursor.as_ref()) { + self.num_of_polled_with_same_value[partition_idx] += 1; + } else { + self.num_of_polled_with_same_value[partition_idx] = 0; + } + } + } + fn fetch_reached(&mut self) -> bool { self.fetch .map(|fetch| self.produced + self.in_progress.len() >= fetch) .unwrap_or(false) } - fn advance(&mut self, stream_idx: usize) -> bool { - let slot = &mut self.cursors[stream_idx]; - match slot.as_mut() { - Some(c) => { - c.advance(); - if c.is_finished() { - *slot = None; - } - true + /// Advances the actual cursor. If it reaches its end, update the + /// previous cursor with it. + /// + /// If the given partition is not exhausted, the function returns `true`. + fn advance_cursors(&mut self, stream_idx: usize) -> bool { + if let Some(cursor) = &mut self.cursors[stream_idx] { + let _ = cursor.advance(); + if cursor.is_finished() { + // Take the current cursor, leaving `None` in its place + self.prev_cursors[stream_idx] = self.cursors[stream_idx].take(); } - None => false, + true + } else { + false } } - /// Returns `true` if the cursor at index `a` is greater than at index `b` + /// Returns `true` if the cursor at index `a` is greater than at index `b`. + /// In an equality case, it compares the partition indices given. #[inline] fn is_gt(&self, a: usize, b: usize) -> bool { match (&self.cursors[a], &self.cursors[b]) { @@ -266,6 +333,19 @@ impl SortPreservingMergeStream { } } + #[inline] + fn is_poll_count_gt(&self, a: usize, b: usize) -> bool { + let poll_a = self.num_of_polled_with_same_value[a]; + let poll_b = self.num_of_polled_with_same_value[b]; + poll_a.cmp(&poll_b).then_with(|| a.cmp(&b)).is_gt() + } + + #[inline] + fn update_winner(&mut self, cmp_node: usize, winner: &mut usize, challenger: usize) { + self.loser_tree[cmp_node] = *winner; + *winner = challenger; + } + /// Find the leaf node index in the loser tree for the given cursor index /// /// Note that this is not necessarily a leaf node in the tree, but it can @@ -327,16 +407,101 @@ impl SortPreservingMergeStream { self.loser_tree_adjusted = true; } - /// Attempts to update the loser tree, following winner replacement, if possible + /// Resets the poll count by incrementing the reset epoch. + fn reset_poll_counts(&mut self) { + self.current_reset_epoch += 1; + } + + /// Handles tie-breaking logic during the adjustment of the loser tree. + /// + /// When comparing elements from multiple partitions in the `update_loser_tree` process, a tie can occur + /// between the current winner and a challenger. This function is invoked when such a tie needs to be + /// resolved according to the round-robin tie-breaker mode. + /// + /// If round-robin tie-breaking is not active, it is enabled, and the poll counts for all elements are reset. + /// The function then compares the poll counts of the current winner and the challenger: + /// - If the winner remains at the top after the final comparison, it increments the winner's poll count. + /// - If the challenger has a lower poll count than the current winner, the challenger becomes the new winner. + /// - If the poll counts are equal but the challenger's index is smaller, the challenger is preferred. + /// + /// # Parameters + /// - `cmp_node`: The index of the comparison node in the loser tree where the tie-breaking is happening. + /// - `winner`: A mutable reference to the current winner, which may be updated based on the tie-breaking result. + /// - `challenger`: The index of the challenger being compared against the winner. + /// + /// This function ensures fair selection among elements with equal values when tie-breaking mode is enabled, + /// aiming to balance the polling across different partitions. + #[inline] + fn handle_tie(&mut self, cmp_node: usize, winner: &mut usize, challenger: usize) { + if !self.round_robin_tie_breaker_mode { + self.round_robin_tie_breaker_mode = true; + // Reset poll count for tie-breaker + self.reset_poll_counts(); + } + // Update poll count if the winner survives in the final match + if *winner == self.loser_tree[0] { + self.update_poll_count_on_the_same_value(*winner); + if self.is_poll_count_gt(*winner, challenger) { + self.update_winner(cmp_node, winner, challenger); + } + } else if challenger < *winner { + // If the winner doesn’t survive in the final match, it indicates that the original winner + // has moved up in value, so the challenger now becomes the new winner. + // This also means that we’re in a new round of the tie breaker, + // and the polls count is outdated (though not yet cleaned up). + // + // By the time we reach this code, both the new winner and the current challenger + // have the same value, and neither has an updated polls count. + // Therefore, we simply select the one with the smaller index. + self.update_winner(cmp_node, winner, challenger); + } + } + + /// Updates the loser tree to reflect the new winner after the previous winner is consumed. + /// This function adjusts the tree by comparing the current winner with challengers from + /// other partitions. + /// + /// If `enable_round_robin_tie_breaker` is true and a tie occurs at the final level, the + /// tie-breaker logic will be applied to ensure fair selection among equal elements. fn update_loser_tree(&mut self) { + // Start with the current winner let mut winner = self.loser_tree[0]; - // Replace overall winner by walking tree of losers + + // Find the leaf node index of the winner in the loser tree. let mut cmp_node = self.lt_leaf_node_index(winner); + + // Traverse up the tree to adjust comparisons until reaching the root. while cmp_node != 0 { let challenger = self.loser_tree[cmp_node]; - if self.is_gt(winner, challenger) { - self.loser_tree[cmp_node] = winner; - winner = challenger; + // If round-robin tie-breaker is enabled and we're at the final comparison (cmp_node == 1) + if self.enable_round_robin_tie_breaker && cmp_node == 1 { + match (&self.cursors[winner], &self.cursors[challenger]) { + (Some(ac), Some(bc)) => { + let ord = ac.cmp(bc); + if ord.is_eq() { + self.handle_tie(cmp_node, &mut winner, challenger); + } else { + // Ends of tie breaker + self.round_robin_tie_breaker_mode = false; + if ord.is_gt() { + self.update_winner(cmp_node, &mut winner, challenger); + } + } + } + (None, _) => { + // Challenger wins, update winner + // Ends of tie breaker + self.round_robin_tie_breaker_mode = false; + self.update_winner(cmp_node, &mut winner, challenger); + } + (_, None) => { + // Winner wins again + // Ends of tie breaker + self.round_robin_tie_breaker_mode = false; + } + } + } else if self.is_gt(winner, challenger) { + self.update_winner(cmp_node, &mut winner, challenger); } cmp_node = self.lt_parent_node_index(cmp_node); } diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs index 31a4ed61cf9e8..f17161306c7ae 100644 --- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs +++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs @@ -82,6 +82,8 @@ pub struct SortPreservingMergeExec { fetch: Option, /// Cache holding plan properties like equivalences, output partitioning etc. cache: PlanProperties, + /// Configuration parameter to enable round-robin selection of tied winners of loser tree. + enable_round_robin_repartition: bool, } impl SortPreservingMergeExec { @@ -94,14 +96,25 @@ impl SortPreservingMergeExec { metrics: ExecutionPlanMetricsSet::new(), fetch: None, cache, + enable_round_robin_repartition: true, } } + /// Sets the number of rows to fetch pub fn with_fetch(mut self, fetch: Option) -> Self { self.fetch = fetch; self } + /// Sets the selection strategy of tied winners of the loser tree algorithm + pub fn with_round_robin_repartition( + mut self, + enable_round_robin_repartition: bool, + ) -> Self { + self.enable_round_robin_repartition = enable_round_robin_repartition; + self + } + /// Input schema pub fn input(&self) -> &Arc { &self.input @@ -182,6 +195,7 @@ impl ExecutionPlan for SortPreservingMergeExec { metrics: self.metrics.clone(), fetch: limit, cache: self.cache.clone(), + enable_round_robin_repartition: true, })) } @@ -281,6 +295,7 @@ impl ExecutionPlan for SortPreservingMergeExec { .with_batch_size(context.session_config().batch_size()) .with_fetch(self.fetch) .with_reservation(reservation) + .with_round_robin_tie_breaker(self.enable_round_robin_repartition) .build()?; debug!("Got stream result from SortPreservingMergeStream::new_from_receivers"); @@ -312,10 +327,12 @@ mod tests { use std::time::Duration; use super::*; + use crate::coalesce_batches::CoalesceBatchesExec; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::expressions::col; use crate::memory::MemoryExec; use crate::metrics::{MetricValue, Timestamp}; + use crate::repartition::RepartitionExec; use crate::sorts::sort::SortExec; use crate::stream::RecordBatchReceiverStream; use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; @@ -326,10 +343,12 @@ mod tests { use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; + use arrow_array::Int64Array; use arrow_schema::SchemaRef; use datafusion_common::{assert_batches_eq, assert_contains, DataFusionError}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::config::SessionConfig; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_execution::RecordBatchStream; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::EquivalenceProperties; @@ -338,6 +357,80 @@ mod tests { use futures::{FutureExt, Stream, StreamExt}; use tokio::time::timeout; + // The number in the function is highly related to the memory limit we are testing + // any change of the constant should be aware of + fn generate_task_ctx_for_round_robin_tie_breaker() -> Result> { + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(20_000_000, 1.0) + .build_arc()?; + let config = SessionConfig::new(); + let task_ctx = TaskContext::default() + .with_runtime(runtime) + .with_session_config(config); + Ok(Arc::new(task_ctx)) + } + // The number in the function is highly related to the memory limit we are testing, + // any change of the constant should be aware of + fn generate_spm_for_round_robin_tie_breaker( + enable_round_robin_repartition: bool, + ) -> Result> { + let target_batch_size = 12500; + let row_size = 12500; + let a: ArrayRef = Arc::new(Int32Array::from(vec![1; row_size])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"); row_size])); + let c: ArrayRef = Arc::new(Int64Array::from_iter(vec![0; row_size])); + let rb = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); + + let rbs = (0..1024).map(|_| rb.clone()).collect::>(); + + let schema = rb.schema(); + let sort = vec![ + PhysicalSortExpr { + expr: col("b", &schema).unwrap(), + options: Default::default(), + }, + PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: Default::default(), + }, + ]; + + let exec = MemoryExec::try_new(&[rbs], schema, None).unwrap(); + let repartition_exec = + RepartitionExec::try_new(Arc::new(exec), Partitioning::RoundRobinBatch(2))?; + let coalesce_batches_exec = + CoalesceBatchesExec::new(Arc::new(repartition_exec), target_batch_size); + let spm = SortPreservingMergeExec::new(sort, Arc::new(coalesce_batches_exec)) + .with_round_robin_repartition(enable_round_robin_repartition); + Ok(Arc::new(spm)) + } + + /// This test verifies that memory usage stays within limits when the tie breaker is enabled. + /// Any errors here could indicate unintended changes in tie breaker logic. + /// + /// Note: If you adjust constants in this test, ensure that memory usage differs + /// based on whether the tie breaker is enabled or disabled. + #[tokio::test(flavor = "multi_thread")] + async fn test_round_robin_tie_breaker_success() -> Result<()> { + let task_ctx = generate_task_ctx_for_round_robin_tie_breaker()?; + let spm = generate_spm_for_round_robin_tie_breaker(true)?; + let _collected = collect(spm, task_ctx).await.unwrap(); + Ok(()) + } + + /// This test verifies that memory usage stays within limits when the tie breaker is enabled. + /// Any errors here could indicate unintended changes in tie breaker logic. + /// + /// Note: If you adjust constants in this test, ensure that memory usage differs + /// based on whether the tie breaker is enabled or disabled. + #[tokio::test(flavor = "multi_thread")] + async fn test_round_robin_tie_breaker_fail() -> Result<()> { + let task_ctx = generate_task_ctx_for_round_robin_tie_breaker()?; + let spm = generate_spm_for_round_robin_tie_breaker(false)?; + let _err = collect(spm, task_ctx).await.unwrap_err(); + Ok(()) + } + #[tokio::test] async fn test_merge_interleave() { let task_ctx = Arc::new(TaskContext::default()); diff --git a/datafusion/physical-plan/src/sorts/streaming_merge.rs b/datafusion/physical-plan/src/sorts/streaming_merge.rs index ad640d8e8470d..e8330a7cabc03 100644 --- a/datafusion/physical-plan/src/sorts/streaming_merge.rs +++ b/datafusion/physical-plan/src/sorts/streaming_merge.rs @@ -36,7 +36,7 @@ macro_rules! primitive_merge_helper { } macro_rules! merge_helper { - ($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident, $fetch:ident, $reservation:ident) => {{ + ($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident, $fetch:ident, $reservation:ident, $enable_round_robin_tie_breaker:ident) => {{ let streams = FieldCursorStream::<$t>::new($sort, $streams); return Ok(Box::pin(SortPreservingMergeStream::new( Box::new(streams), @@ -45,6 +45,7 @@ macro_rules! merge_helper { $batch_size, $fetch, $reservation, + $enable_round_robin_tie_breaker, ))); }}; } @@ -58,11 +59,15 @@ pub struct StreamingMergeBuilder<'a> { batch_size: Option, fetch: Option, reservation: Option, + enable_round_robin_tie_breaker: bool, } impl<'a> StreamingMergeBuilder<'a> { pub fn new() -> Self { - Self::default() + Self { + enable_round_robin_tie_breaker: true, + ..Default::default() + } } pub fn with_streams(mut self, streams: Vec) -> Self { @@ -100,6 +105,14 @@ impl<'a> StreamingMergeBuilder<'a> { self } + pub fn with_round_robin_tie_breaker( + mut self, + enable_round_robin_tie_breaker: bool, + ) -> Self { + self.enable_round_robin_tie_breaker = enable_round_robin_tie_breaker; + self + } + pub fn build(self) -> Result { let Self { streams, @@ -109,6 +122,7 @@ impl<'a> StreamingMergeBuilder<'a> { reservation, fetch, expressions, + enable_round_robin_tie_breaker, } = self; // Early return if streams or expressions are empty @@ -141,11 +155,11 @@ impl<'a> StreamingMergeBuilder<'a> { let sort = expressions[0].clone(); let data_type = sort.expr.data_type(schema.as_ref())?; downcast_primitive! { - data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch, reservation), - DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, fetch, reservation) - DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, fetch, reservation) - DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation) - DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation) + data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker), + DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker) + DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker) + DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker) + DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker) _ => {} } } @@ -163,6 +177,7 @@ impl<'a> StreamingMergeBuilder<'a> { batch_size, fetch, reservation, + enable_round_robin_tie_breaker, ))) } }