diff --git a/src/query/src/optimizer/windowed_sort.rs b/src/query/src/optimizer/windowed_sort.rs index 6d944a44c00a..638c3d9d3905 100644 --- a/src/query/src/optimizer/windowed_sort.rs +++ b/src/query/src/optimizer/windowed_sort.rs @@ -19,6 +19,7 @@ use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; +use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion::physical_plan::ExecutionPlan; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::Result as DataFusionResult; @@ -67,10 +68,12 @@ impl WindowedSortPhysicalRule { .transform_down(|plan| { if let Some(sort_exec) = plan.as_any().downcast_ref::() { // TODO: support multiple expr in windowed sort - if !sort_exec.preserve_partitioning() || sort_exec.expr().len() != 1 { + if sort_exec.expr().len() != 1 { return Ok(Transformed::no(plan)); } + let preserve_partitioning = sort_exec.preserve_partitioning(); + let Some(scanner_info) = fetch_partition_range(sort_exec.input().clone())? else { return Ok(Transformed::no(plan)); @@ -111,11 +114,23 @@ impl WindowedSortPhysicalRule { new_input, )?; - return Ok(Transformed { - data: Arc::new(windowed_sort_exec), - transformed: true, - tnr: datafusion_common::tree_node::TreeNodeRecursion::Stop, - }); + if !preserve_partitioning { + let order_preserving_merge = SortPreservingMergeExec::new( + sort_exec.expr().to_vec(), + Arc::new(windowed_sort_exec), + ); + return Ok(Transformed { + data: Arc::new(order_preserving_merge), + transformed: true, + tnr: datafusion_common::tree_node::TreeNodeRecursion::Stop, + }); + } else { + return Ok(Transformed { + data: Arc::new(windowed_sort_exec), + transformed: true, + tnr: datafusion_common::tree_node::TreeNodeRecursion::Stop, + }); + } } Ok(Transformed::no(plan)) @@ -126,6 +141,7 @@ impl WindowedSortPhysicalRule { } } +#[derive(Debug)] struct ScannerInfo { partition_ranges: Vec>, time_index: String, @@ -136,11 +152,11 @@ fn fetch_partition_range(input: Arc) -> DataFusionResult() - || plan.as_any().is::() || plan.as_any().is::() || plan.as_any().is::() || plan.as_any().is::() @@ -148,13 +164,19 @@ fn fetch_partition_range(input: Arc) -> DataFusionResult() { + is_batch_coalesced = true; + } + if let Some(region_scan_exec) = plan.as_any().downcast_ref::() { partition_ranges = Some(region_scan_exec.get_uncollapsed_partition_ranges()); time_index = Some(region_scan_exec.time_index()); tag_columns = Some(region_scan_exec.tag_columns()); // set distinguish_partition_ranges to true, this is an incorrect workaround - region_scan_exec.with_distinguish_partition_range(true); + if !is_batch_coalesced { + region_scan_exec.with_distinguish_partition_range(true); + } } Ok(Transformed::no(plan)) diff --git a/src/query/src/part_sort.rs b/src/query/src/part_sort.rs index 2828db202d9d..bcf82502057d 100644 --- a/src/query/src/part_sort.rs +++ b/src/query/src/part_sort.rs @@ -12,6 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Module for sorting input data within each [`PartitionRange`]. +//! +//! This module defines the [`PartSortExec`] execution plan, which sorts each +//! partition ([`PartitionRange`]) independently based on the provided physical +//! sort expressions. + use std::any::Any; use std::pin::Pin; use std::sync::Arc; @@ -36,7 +42,7 @@ use itertools::Itertools; use snafu::location; use store_api::region_engine::PartitionRange; -use crate::downcast_ts_array; +use crate::{array_iter_helper, downcast_ts_array}; /// Sort input within given PartitionRange /// @@ -193,6 +199,7 @@ struct PartSortStream { #[allow(dead_code)] // this is used under #[debug_assertions] partition: usize, cur_part_idx: usize, + evaluating_batch: Option, metrics: BaselineMetrics, } @@ -218,6 +225,7 @@ impl PartSortStream { partition_ranges, partition, cur_part_idx: 0, + evaluating_batch: None, metrics: BaselineMetrics::new(&sort.metrics, partition), } } @@ -288,9 +296,51 @@ impl PartSortStream { Ok(()) } + /// Try find data whose value exceeds the current partition range. + /// + /// Returns `None` if no such data is found, and `Some(idx)` where idx points to + /// the first data that exceeds the current partition range. + fn try_find_next_range( + &self, + sort_column: &ArrayRef, + ) -> datafusion_common::Result> { + if sort_column.len() == 0 { + return Ok(Some(0)); + } + + // check if the current partition index is out of range + if self.cur_part_idx >= self.partition_ranges.len() { + internal_err!( + "Partition index out of range: {} >= {}", + self.cur_part_idx, + self.partition_ranges.len() + )?; + } + let cur_range = self.partition_ranges[self.cur_part_idx]; + + let sort_column_iter = downcast_ts_array!( + sort_column.data_type() => (array_iter_helper, sort_column), + _ => internal_err!( + "Unsupported data type for sort column: {:?}", + sort_column.data_type() + )?, + ); + + for (idx, val) in sort_column_iter { + // ignore vacant time index data + if let Some(val) = val { + if val >= cur_range.end.value() || val < cur_range.start.value() { + return Ok(Some(idx)); + } + } + } + + Ok(None) + } + /// Sort and clear the buffer and return the sorted record batch /// - /// this function should return a empty record batch if the buffer is empty + /// this function will return a empty record batch if the buffer is empty fn sort_buffer(&mut self) -> datafusion_common::Result { if self.buffer.is_empty() { return Ok(DfRecordBatch::new_empty(self.schema.clone())); @@ -317,6 +367,9 @@ impl PartSortStream { Some(format!("Fail to sort to indices at {}", location!())), ) })?; + if indices.is_empty() { + return Ok(DfRecordBatch::new_empty(self.schema.clone())); + } self.check_in_range( &sort_column, @@ -374,11 +427,58 @@ impl PartSortStream { Ok(sorted) } + fn split_batch( + &mut self, + batch: DfRecordBatch, + ) -> datafusion_common::Result> { + if batch.num_rows() == 0 { + return Ok(None); + } + + let sort_column = self + .expression + .expr + .evaluate(&batch)? + .into_array(batch.num_rows())?; + + let next_range_idx = self.try_find_next_range(&sort_column)?; + let Some(idx) = next_range_idx else { + self.buffer.push(batch); + // keep polling input for next batch + return Ok(None); + }; + + let this_range = batch.slice(0, idx); + let remaining_range = batch.slice(idx, batch.num_rows() - idx); + if this_range.num_rows() != 0 { + self.buffer.push(this_range); + } + // mark end of current PartitionRange + let sorted_batch = self.sort_buffer(); + // step to next proper PartitionRange + self.cur_part_idx += 1; + let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx); + if self.try_find_next_range(&next_sort_column)?.is_some() { + // remaining batch still contains data that exceeds the current partition range + // register the remaining batch for next polling + self.evaluating_batch = Some(remaining_range); + } else { + // remaining batch is within the current partition range + // push to the buffer and continue polling + if remaining_range.num_rows() != 0 { + self.buffer.push(remaining_range); + } + } + + sorted_batch.map(|x| if x.num_rows() == 0 { None } else { Some(x) }) + } + pub fn poll_next_inner( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { loop { + // no more input, sort the buffer and return if self.input_complete { if self.buffer.is_empty() { return Poll::Ready(None); @@ -386,24 +486,30 @@ impl PartSortStream { return Poll::Ready(Some(self.sort_buffer())); } } + + // if there is a remaining batch being evaluated from last run, + // split on it instead of fetching new batch + if let Some(evaluating_batch) = self.evaluating_batch.take() + && evaluating_batch.num_rows() != 0 + { + if let Some(sorted_batch) = self.split_batch(evaluating_batch)? { + return Poll::Ready(Some(Ok(sorted_batch))); + } else { + continue; + } + } + + // fetch next batch from input let res = self.input.as_mut().poll_next(cx); match res { Poll::Ready(Some(Ok(batch))) => { - if batch.num_rows() == 0 { - // mark end of current PartitionRange - let sorted_batch = self.sort_buffer()?; - self.cur_part_idx += 1; - if sorted_batch.num_rows() == 0 { - // Current part is empty, continue polling next part. - continue; - } + if let Some(sorted_batch) = self.split_batch(batch)? { return Poll::Ready(Some(Ok(sorted_batch))); + } else { + continue; } - self.buffer.push(batch); - // keep polling until boundary(a empty RecordBatch) is reached - continue; } - // input stream end, sort the buffer and return + // input stream end, mark and continue Poll::Ready(None) => { self.input_complete = true; continue; @@ -484,14 +590,19 @@ mod test { let schema = Arc::new(schema); let mut input_ranged_data = vec![]; + let mut output_ranges = vec![]; let mut output_data = vec![]; // generate each input `PartitionRange` for part_id in 0..rng.usize(0..part_cnt_bound) { // generate each `PartitionRange`'s timestamp range let (start, end) = if descending { let end = bound_val - .map(|i| i.checked_sub(rng.i64(0..range_offset_bound)).expect("Bad luck, fuzzy test generate data that will overflow, change seed and try again")) - .unwrap_or_else(|| rng.i64(..)); + .map( + |i| i + .checked_sub(rng.i64(0..range_offset_bound)) + .expect("Bad luck, fuzzy test generate data that will overflow, change seed and try again") + ) + .unwrap_or_else(|| rng.i64(-100000000..100000000)); bound_val = Some(end); let start = end - rng.i64(1..range_size_bound); let start = Timestamp::new(start, unit.clone().into()); @@ -514,13 +625,15 @@ mod test { for _batch_idx in 0..rng.usize(1..batch_cnt_bound) { let cnt = rng.usize(0..batch_size_bound) + 1; let iter = 0..rng.usize(0..cnt); - let data_gen = iter + let mut data_gen = iter .map(|_| rng.i64(start.value()..end.value())) .collect_vec(); if data_gen.is_empty() { // current batch is empty, skip continue; } + // mito always sort on ASC order + data_gen.sort(); per_part_sort_data.extend(data_gen.clone()); let arr = new_ts_array(unit.clone(), data_gen.clone()); let batch = DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap(); @@ -535,15 +648,35 @@ mod test { }; input_ranged_data.push((range, batches)); - if descending { - per_part_sort_data.sort_by(|a, b| b.cmp(a)); - } else { - per_part_sort_data.sort(); - } + output_ranges.push(range); if per_part_sort_data.is_empty() { continue; } - output_data.push(per_part_sort_data); + output_data.extend_from_slice(&per_part_sort_data); + } + + // adjust output data with adjacent PartitionRanges + let mut output_data_iter = output_data.iter().peekable(); + let mut output_data = vec![]; + for range in output_ranges.clone() { + let mut cur_data = vec![]; + while let Some(val) = output_data_iter.peek() { + if **val < range.start.value() || **val >= range.end.value() { + break; + } + cur_data.push(*output_data_iter.next().unwrap()); + } + + if cur_data.is_empty() { + continue; + } + + if descending { + cur_data.sort_by(|a, b| b.cmp(a)); + } else { + cur_data.sort(); + } + output_data.push(cur_data); } let expected_output = output_data @@ -578,7 +711,7 @@ mod test { ((5, 10), vec![vec![5, 6], vec![7, 8]]), ], false, - vec![vec![1, 2, 3, 4, 5, 6, 7, 8, 9], vec![5, 6, 7, 8]], + vec![vec![1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9]], ), ( TimeUnit::Millisecond, @@ -620,6 +753,25 @@ mod test { true, vec![], ), + ( + TimeUnit::Millisecond, + vec![ + ( + (15, 20), + vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]], + ), + ((10, 15), vec![]), + ((5, 10), vec![]), + ((0, 10), vec![]), + ], + true, + vec![ + vec![19, 17, 15], + vec![12, 11, 10], + vec![9, 8, 7, 6, 5], + vec![4, 3, 2, 1], + ], + ), ]; for (identifier, (unit, input_ranged_data, descending, expected_output)) in @@ -664,10 +816,18 @@ mod test { }) .collect_vec(); - run_test(0, input_ranged_data, schema.clone(), opt, expected_output).await; + run_test( + identifier, + input_ranged_data, + schema.clone(), + opt, + expected_output, + ) + .await; } } + #[allow(clippy::print_stdout)] async fn run_test( case_id: usize, input_ranged_data: Vec<(PartitionRange, Vec)>, @@ -692,20 +852,36 @@ mod test { options: opt, }, None, - vec![ranges], + vec![ranges.clone()], Arc::new(mock_input), ); let exec_stream = exec.execute(0, Arc::new(TaskContext::default())).unwrap(); let real_output = exec_stream.map(|r| r.unwrap()).collect::>().await; - // a makeshift solution for compare large data if real_output != expected_output { + let mut first_diff = 0; + for (idx, (lhs, rhs)) in real_output.iter().zip(expected_output.iter()).enumerate() { + if lhs != rhs { + first_diff = idx; + break; + } + } + println!("first diff batch at {}", first_diff); + println!( + "ranges: {:?}", + ranges + .into_iter() + .map(|r| (r.start.to_chrono_datetime(), r.end.to_chrono_datetime())) + .enumerate() + .collect::>() + ); + let mut full_msg = String::new(); { let mut buf = Vec::with_capacity(10 * real_output.len()); - for batch in &real_output { + for batch in real_output.iter().skip(first_diff) { let mut rb_json: Vec = Vec::new(); let mut writer = ArrayWriter::new(&mut rb_json); writer.write(batch).unwrap(); @@ -714,12 +890,12 @@ mod test { buf.push(b','); } // TODO(discord9): better ways to print buf - let _buf = String::from_utf8_lossy(&buf); - full_msg += &format!("case_id:{case_id}, real_output"); + let buf = String::from_utf8_lossy(&buf); + full_msg += &format!("\ncase_id:{case_id}, real_output \n{buf}\n"); } { let mut buf = Vec::with_capacity(10 * real_output.len()); - for batch in &expected_output { + for batch in expected_output.iter().skip(first_diff) { let mut rb_json: Vec = Vec::new(); let mut writer = ArrayWriter::new(&mut rb_json); writer.write(batch).unwrap(); @@ -727,12 +903,16 @@ mod test { buf.append(&mut rb_json); buf.push(b','); } - let _buf = String::from_utf8_lossy(&buf); - full_msg += &format!("case_id:{case_id}, expected_output"); + let buf = String::from_utf8_lossy(&buf); + full_msg += &format!("case_id:{case_id}, expected_output \n{buf}"); } panic!( - "case_{} failed, opt: {:?}, full msg: {}", - case_id, opt, full_msg + "case_{} failed, opt: {:?},\n real output has {} batches, {} rows, expected has {} batches with {} rows\nfull msg: {}", + case_id, opt, + real_output.len(), + real_output.iter().map(|x|x.num_rows()).sum::(), + expected_output.len(), + expected_output.iter().map(|x|x.num_rows()).sum::(), full_msg ); } } diff --git a/src/query/src/window_sort.rs b/src/query/src/window_sort.rs index 435a255beb95..1d03eedc13b7 100644 --- a/src/query/src/window_sort.rs +++ b/src/query/src/window_sort.rs @@ -21,7 +21,7 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use arrow::array::{Array, ArrayRef, PrimitiveArray}; +use arrow::array::{Array, ArrayRef}; use arrow::compute::SortColumn; use arrow_schema::{DataType, SchemaRef, SortOptions}; use common_error::ext::{BoxedError, PlainError}; @@ -812,9 +812,16 @@ fn find_slice_from_range( Ok((start, end - start)) } +/// Get an iterator from a primitive array. +/// +/// Used with `downcast_ts_array`. The returned iter is wrapped with `.enumerate()`. +#[macro_export] macro_rules! array_iter_helper { ($t:ty, $unit:expr, $arr:expr) => {{ - let typed = $arr.as_any().downcast_ref::>().unwrap(); + let typed = $arr + .as_any() + .downcast_ref::>() + .unwrap(); let iter = typed.iter().enumerate(); Box::new(iter) as Box)>> }};