From e29b6278d34c043dfe1cdd32a59341c8ee5531af Mon Sep 17 00:00:00 2001
From: guojidan <1948535941@qq.com>
Date: Wed, 27 Mar 2024 08:58:27 +0000
Subject: [PATCH] refactor: Generate GroupByHash output in multiple
 RecordBatches

Signed-off-by: guojidan <1948535941@qq.com>
---
 .../physical-plan/src/aggregates/row_hash.rs | 56 ++++++++++++++-----
 1 file changed, 42 insertions(+), 14 deletions(-)

diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs
index 45d408bb4de7..259b7163bfef 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -60,7 +60,7 @@ pub(crate) enum ExecutionState {
     ReadingInput,
     /// When producing output, the remaining rows to output are stored
     /// here and are sliced off as needed in batch_size chunks
-    ProducingOutput(RecordBatch),
+    ProducingOutput(Vec<RecordBatch>),
 
     Done,
 }
@@ -454,7 +454,9 @@ impl Stream for GroupedHashAggregateStream {
 
                             if let Some(to_emit) = self.group_ordering.emit_to() {
                                 let batch = extract_ok!(self.emit(to_emit, false));
-                                self.exec_state = ExecutionState::ProducingOutput(batch);
+                                let batches = extract_ok!(self.split_batch(batch));
+                                self.exec_state =
+                                    ExecutionState::ProducingOutput(batches);
                                 timer.done();
                                 // make sure the exec_state just set is not overwritten below
                                 break 'reading_input;
@@ -475,25 +477,24 @@ impl Stream for GroupedHashAggregateStream {
                     }
                 }
 
-                ExecutionState::ProducingOutput(batch) => {
-                    // slice off a part of the batch, if needed
+                ExecutionState::ProducingOutput(batches) => {
+                    // batches is never empty here
+                    let len = batches.len();
+                    assert!(len > 0);
                     let output_batch;
-                    let size = self.batch_size;
-                    (self.exec_state, output_batch) = if batch.num_rows() <= size {
+                    (self.exec_state, output_batch) = if len == 1 {
                         (
                             if self.input_done {
                                 ExecutionState::Done
                             } else {
                                 ExecutionState::ReadingInput
                             },
-                            batch.clone(),
+                            batches[0].clone(),
                         )
                     } else {
-                        // output first batch_size rows
-                        let size = self.batch_size;
-                        let num_remaining = batch.num_rows() - size;
-                        let remaining = batch.slice(size, num_remaining);
-                        let output = batch.slice(0, size);
+                        // output the first batch; keep the rest as remaining work
+                        let remaining = batches[1..].to_vec();
+                        let output = batches[0].clone();
                         (ExecutionState::ProducingOutput(remaining), output)
                     };
                     return Poll::Ready(Some(Ok(
@@ -728,7 +729,8 @@ impl GroupedHashAggregateStream {
         {
             let n = self.group_values.len() / self.batch_size * self.batch_size;
             let batch = self.emit(EmitTo::First(n), false)?;
-            self.exec_state = ExecutionState::ProducingOutput(batch);
+            let batches = self.split_batch(batch)?;
+            self.exec_state = ExecutionState::ProducingOutput(batches);
         }
         Ok(())
     }
@@ -787,7 +789,8 @@ impl GroupedHashAggregateStream {
         let timer = elapsed_compute.timer();
         self.exec_state = if self.spill_state.spills.is_empty() {
             let batch = self.emit(EmitTo::All, false)?;
-            ExecutionState::ProducingOutput(batch)
+            let batches = self.split_batch(batch)?;
+            ExecutionState::ProducingOutput(batches)
         } else {
             // If spill files exist, stream-merge them.
             self.update_merged_stream()?;
@@ -796,4 +799,29 @@ impl GroupedHashAggregateStream {
         timer.done();
         Ok(())
     }
+
+    fn split_batch(&self, batch: RecordBatch) -> Result<Vec<RecordBatch>> {
+        let mut batches = vec![];
+        let mut length = self.batch_size;
+        let len = batch.num_rows();
+        if len == 0 {
+            return Ok(vec![batch]);
+        }
+        for offset in (0..len).step_by(length) {
+            if offset + length > len {
+                length = len - offset;
+            }
+            let slice_columns = batch
+                .columns()
+                .iter()
+                .map(|array| {
+                    let sliced_array = array.slice(offset, length);
+                    sliced_array.to_owned()
+                })
+                .collect();
+            batches.push(RecordBatch::try_new(batch.schema().clone(), slice_columns)?);
+        }
+
+        Ok(batches)
+    }
 }
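
Note (illustration only, not part of the patch): split_batch above builds each
chunk by slicing every column individually and reassembling a RecordBatch with
RecordBatch::try_new; arrow's zero-copy RecordBatch::slice produces the same
result in one call, and the old ProducingOutput arm already used it. Below is a
minimal, self-contained sketch of the chunking strategy under that assumption;
the helper name split_by_size and the example schema are hypothetical.

    use std::sync::Arc;

    use arrow::array::{ArrayRef, Int32Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;

    /// Split `batch` into zero-copy chunks of at most `batch_size` rows.
    fn split_by_size(batch: &RecordBatch, batch_size: usize) -> Vec<RecordBatch> {
        if batch.num_rows() == 0 {
            // mirror split_batch: pass an empty batch through unchanged
            return vec![batch.clone()];
        }
        (0..batch.num_rows())
            .step_by(batch_size)
            .map(|offset| {
                // the final chunk may be shorter than batch_size
                let length = batch_size.min(batch.num_rows() - offset);
                batch.slice(offset, length)
            })
            .collect()
    }

    fn main() {
        let schema = Arc::new(Schema::new(vec![Field::new("g", DataType::Int32, false)]));
        let col: ArrayRef = Arc::new(Int32Array::from((0..10).collect::<Vec<i32>>()));
        let batch = RecordBatch::try_new(schema, vec![col]).unwrap();

        // 10 rows with batch_size = 4 split into chunks of 4, 4, and 2 rows
        let sizes: Vec<usize> =
            split_by_size(&batch, 4).iter().map(|b| b.num_rows()).collect();
        assert_eq!(sizes, vec![4, 4, 2]);
    }

Either way the split is zero-copy (slicing only adjusts offsets into the
underlying Arrow buffers), so splitting eagerly when the hash table emits costs
no data copies; the trade-off is holding all chunk handles in the Vec while the
ProducingOutput state drains them one batch per poll.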