refactor: Generate GroupByHash output in multiple RecordBatches
Signed-off-by: guojidan <[email protected]>
guojidan committed Mar 27, 2024
1 parent 56c735c commit e29b627
Showing 1 changed file with 42 additions and 14 deletions.
56 changes: 42 additions & 14 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -60,7 +60,7 @@ pub(crate) enum ExecutionState {
     ReadingInput,
     /// When producing output, the remaining rows to output are stored
     /// here and are sliced off as needed in batch_size chunks
-    ProducingOutput(RecordBatch),
+    ProducingOutput(Vec<RecordBatch>),
     Done,
 }
@@ -454,7 +454,9 @@ impl Stream for GroupedHashAggregateStream {

                         if let Some(to_emit) = self.group_ordering.emit_to() {
                             let batch = extract_ok!(self.emit(to_emit, false));
-                            self.exec_state = ExecutionState::ProducingOutput(batch);
+                            let batches = self.split_batch(batch)?;
+                            self.exec_state =
+                                ExecutionState::ProducingOutput(batches);
                             timer.done();
                             // make sure the exec_state just set is not overwritten below
                             break 'reading_input;
@@ -475,25 +477,24 @@
                     }
                 }

-                ExecutionState::ProducingOutput(batch) => {
-                    // slice off a part of the batch, if needed
+                ExecutionState::ProducingOutput(batches) => {
+                    // batches is not empty
+                    let len = batches.len();
+                    assert!(len > 0);
                     let output_batch;
-                    let size = self.batch_size;
-                    (self.exec_state, output_batch) = if batch.num_rows() <= size {
+                    (self.exec_state, output_batch) = if len == 1 {
                         (
                             if self.input_done {
                                 ExecutionState::Done
                             } else {
                                 ExecutionState::ReadingInput
                             },
-                            batch.clone(),
+                            batches[0].clone(),
                         )
                     } else {
-                        // output first batch_size rows
-                        let size = self.batch_size;
-                        let num_remaining = batch.num_rows() - size;
-                        let remaining = batch.slice(size, num_remaining);
-                        let output = batch.slice(0, size);
+                        // output first batches element
+                        let remaining = batches[1..].to_vec();
+                        let output = batches[0].clone();
                         (ExecutionState::ProducingOutput(remaining), output)
                     };
                     return Poll::Ready(Some(Ok(
@@ -728,7 +729,8 @@ impl GroupedHashAggregateStream {
         {
             let n = self.group_values.len() / self.batch_size * self.batch_size;
             let batch = self.emit(EmitTo::First(n), false)?;
-            self.exec_state = ExecutionState::ProducingOutput(batch);
+            let batches = self.split_batch(batch)?;
+            self.exec_state = ExecutionState::ProducingOutput(batches);
         }
         Ok(())
     }
@@ -787,7 +789,8 @@ impl GroupedHashAggregateStream {
         let timer = elapsed_compute.timer();
         self.exec_state = if self.spill_state.spills.is_empty() {
             let batch = self.emit(EmitTo::All, false)?;
-            ExecutionState::ProducingOutput(batch)
+            let batches = self.split_batch(batch)?;
+            ExecutionState::ProducingOutput(batches)
         } else {
             // If spill files exist, stream-merge them.
             self.update_merged_stream()?;
@@ -796,4 +799,29 @@
         timer.done();
         Ok(())
     }
+
+    fn split_batch(&self, batch: RecordBatch) -> Result<Vec<RecordBatch>> {
+        let mut batches = vec![];
+        let mut length = self.batch_size;
+        let len = batch.num_rows();
+        if len == 0 {
+            return Ok(vec![batch]);
+        }
+        for offset in (0..len).step_by(length) {
+            if offset + length > len {
+                length = len - offset;
+            }
+            let slice_columns = batch
+                .columns()
+                .iter()
+                .map(|array| {
+                    let sliced_array = array.slice(offset, length);
+                    sliced_array.to_owned()
+                })
+                .collect();
+            batches.push(RecordBatch::try_new(batch.schema().clone(), slice_columns)?);
+        }
+
+        Ok(batches)
+    }
 }
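For reference, a minimal standalone sketch of the chunking technique the new split_batch helper implements, assuming only the arrow crate (the schema, the values, and the batch size of 4 below are illustrative, not taken from the commit). Where the committed helper rebuilds each chunk with RecordBatch::try_new over per-column slices, this sketch swaps in RecordBatch::slice, which performs the same per-column, zero-copy slicing in a single call:

use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

/// Split `batch` into chunks of at most `batch_size` rows.
/// `RecordBatch::slice` is zero-copy: each chunk shares the parent's buffers.
fn split_batch(batch: RecordBatch, batch_size: usize) -> Vec<RecordBatch> {
    if batch.num_rows() == 0 {
        return vec![batch];
    }
    (0..batch.num_rows())
        .step_by(batch_size)
        .map(|offset| {
            let length = batch_size.min(batch.num_rows() - offset);
            batch.slice(offset, length)
        })
        .collect()
}

fn main() {
    let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int64, false)]));
    let column: ArrayRef = Arc::new(Int64Array::from((0i64..10).collect::<Vec<i64>>()));
    let batch = RecordBatch::try_new(schema, vec![column]).unwrap();

    // 10 rows with a batch size of 4 yield chunks of 4, 4, and 2 rows.
    let chunks = split_batch(batch, 4);
    assert_eq!(
        chunks.iter().map(RecordBatch::num_rows).collect::<Vec<_>>(),
        vec![4, 4, 2]
    );
}

Either formulation keeps the memory cost of producing output bounded: the chunks reference the emitted batch's buffers rather than copying row data, and the stream then hands out one chunk per poll.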
