Generate hash aggregation output in smaller record batches (#3461)
* change how the final aggregation row group is created ...

this change prevents cloning the whole state, which doubled the memory
needed for aggregation.

this PR relates to #1570

* Fix clippy issues

* read batch size from `session_config`
milenkovicm authored Oct 15, 2022
1 parent 011bcf4 commit 0b90a8a
Showing 2 changed files with 74 additions and 49 deletions.
2 changes: 2 additions & 0 deletions datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -298,6 +298,7 @@ impl ExecutionPlan for AggregateExec {
         partition: usize,
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
+        let batch_size = context.session_config().batch_size();
         let input = self.input.execute(partition, context)?;

         let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
@@ -318,6 +319,7 @@ impl ExecutionPlan for AggregateExec {
                 self.aggr_expr.clone(),
                 input,
                 baseline_metrics,
+                batch_size,
             )?))
         } else {
             Ok(Box::pin(GroupedHashAggregateStream::new(
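The `batch_size` read above is the per-batch row limit from the session configuration. A minimal usage sketch, assuming the usual `SessionConfig::with_batch_size` builder re-exported in the DataFusion prelude (the table and query are illustrative, not part of this commit):

use datafusion::prelude::*;

fn main() {
    // With this change, the final hash aggregation emits RecordBatches of at
    // most `batch_size` rows instead of one batch holding the whole result.
    let config = SessionConfig::new().with_batch_size(1024);
    let _ctx = SessionContext::with_config(config);
}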
121 changes: 72 additions & 49 deletions datafusion/core/src/physical_plan/aggregates/row_hash.rs
@@ -85,7 +85,11 @@ pub(crate) struct GroupedHashAggregateStreamV2 {

     baseline_metrics: BaselineMetrics,
     random_state: RandomState,
-    finished: bool,
+    /// size to be used for resulting RecordBatches
+    batch_size: usize,
+    /// if the result is chunked into batches,
+    /// last offset is preserved for continuation.
+    row_group_skip_position: usize,
 }

 fn aggr_state_schema(aggr_expr: &[Arc<dyn AggregateExpr>]) -> Result<SchemaRef> {
@@ -105,6 +109,7 @@ impl GroupedHashAggregateStreamV2 {
         aggr_expr: Vec<Arc<dyn AggregateExpr>>,
         input: SendableRecordBatchStream,
         baseline_metrics: BaselineMetrics,
+        batch_size: usize,
     ) -> Result<Self> {
         let timer = baseline_metrics.elapsed_compute().timer();

@@ -135,7 +140,8 @@ impl GroupedHashAggregateStreamV2 {
             aggregate_expressions,
             aggr_state: Default::default(),
             random_state: Default::default(),
-            finished: false,
+            batch_size,
+            row_group_skip_position: 0,
         })
     }
 }
@@ -148,56 +154,62 @@ impl Stream for GroupedHashAggregateStreamV2 {
         cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
         let this = &mut *self;
-        if this.finished {
-            return Poll::Ready(None);
-        }

         let elapsed_compute = this.baseline_metrics.elapsed_compute();

         loop {
-            let result = match ready!(this.input.poll_next_unpin(cx)) {
-                Some(Ok(batch)) => {
-                    let timer = elapsed_compute.timer();
-                    let result = group_aggregate_batch(
-                        &this.mode,
-                        &this.random_state,
-                        &this.group_by,
-                        &mut this.accumulators,
-                        &this.group_schema,
-                        this.aggr_layout.clone(),
-                        batch,
-                        &mut this.aggr_state,
-                        &this.aggregate_expressions,
-                    );
-
-                    timer.done();
-
-                    match result {
-                        Ok(_) => continue,
-                        Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
-                    }
-                }
-                Some(Err(e)) => Err(e),
-                None => {
-                    this.finished = true;
-                    let timer = this.baseline_metrics.elapsed_compute().timer();
-                    let result = create_batch_from_map(
-                        &this.mode,
-                        &this.group_schema,
-                        &this.aggr_schema,
-                        &mut this.aggr_state,
-                        &mut this.accumulators,
-                        &this.schema,
-                    )
-                    .record_output(&this.baseline_metrics);
-
-                    timer.done();
-                    result
-                }
-            };
-
-            this.finished = true;
-            return Poll::Ready(Some(result));
+            let result: ArrowResult<Option<RecordBatch>> =
+                match ready!(this.input.poll_next_unpin(cx)) {
+                    Some(Ok(batch)) => {
+                        let timer = elapsed_compute.timer();
+                        let result = group_aggregate_batch(
+                            &this.mode,
+                            &this.random_state,
+                            &this.group_by,
+                            &mut this.accumulators,
+                            &this.group_schema,
+                            this.aggr_layout.clone(),
+                            batch,
+                            &mut this.aggr_state,
+                            &this.aggregate_expressions,
+                        );
+
+                        timer.done();
+
+                        match result {
+                            Ok(_) => continue,
+                            Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
+                        }
+                    }
+                    Some(Err(e)) => Err(e),
+                    None => {
+                        let timer = this.baseline_metrics.elapsed_compute().timer();
+                        let result = create_batch_from_map(
+                            &this.mode,
+                            &this.group_schema,
+                            &this.aggr_schema,
+                            this.batch_size,
+                            this.row_group_skip_position,
+                            &mut this.aggr_state,
+                            &mut this.accumulators,
+                            &this.schema,
+                        );
+
+                        timer.done();
+                        result
+                    }
+                };
+
+            this.row_group_skip_position += this.batch_size;
+            match result {
+                Ok(Some(result)) => {
+                    return Poll::Ready(Some(Ok(
+                        result.record_output(&this.baseline_metrics)
+                    )))
+                }
+                Ok(None) => return Poll::Ready(None),
+                Err(error) => return Poll::Ready(Some(Err(error))),
+            }
         }
     }
 }
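The reworked `poll_next` above first drains the input, then pages results out of the accumulated group state: each pass advances `row_group_skip_position` by `batch_size`, `Ok(Some(batch))` becomes an emitted batch, and `Ok(None)` ends the stream. A simplified, synchronous sketch of that driver pattern (helper names here are hypothetical, not the DataFusion API):

// Simplified sketch of the paging loop in poll_next (hypothetical names).
fn drain_in_chunks<T>(
    mut next_page: impl FnMut(usize, usize) -> Option<T>, // (skip, take) -> one output batch
    batch_size: usize,
) -> Vec<T> {
    let mut out = Vec::new();
    let mut skip = 0; // mirrors `row_group_skip_position`
    // Keep asking for the next page until the producer reports exhaustion,
    // which corresponds to `Ok(None)` => `Poll::Ready(None)` in the stream.
    while let Some(page) = next_page(skip, batch_size) {
        out.push(page);
        skip += batch_size; // mirrors `row_group_skip_position += batch_size`
    }
    out
}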
@@ -419,23 +431,34 @@ fn create_group_rows(arrays: Vec<ArrayRef>, schema: &Schema) -> Vec<Vec<u8>> {
 }

 /// Create a RecordBatch with all group keys and accumulator' states or values.
+#[allow(clippy::too_many_arguments)]
 fn create_batch_from_map(
     mode: &AggregateMode,
     group_schema: &Schema,
     aggr_schema: &Schema,
+    batch_size: usize,
+    skip_items: usize,
     aggr_state: &mut AggregationState,
     accumulators: &mut [AccumulatorItemV2],
     output_schema: &Schema,
-) -> ArrowResult<RecordBatch> {
+) -> ArrowResult<Option<RecordBatch>> {
+    if skip_items > aggr_state.group_states.len() {
+        return Ok(None);
+    }
+
     if aggr_state.group_states.is_empty() {
-        return Ok(RecordBatch::new_empty(Arc::new(output_schema.to_owned())));
+        return Ok(Some(RecordBatch::new_empty(Arc::new(
+            output_schema.to_owned(),
+        ))));
     }

     let mut state_accessor = RowAccessor::new(aggr_schema, RowType::WordAligned);

     let (group_buffers, mut state_buffers): (Vec<_>, Vec<_>) = aggr_state
         .group_states
         .iter()
+        .skip(skip_items)
+        .take(batch_size)
         .map(|gs| (gs.group_by_values.clone(), gs.aggregation_buffer.clone()))
         .unzip();

@@ -471,7 +494,7 @@ fn create_batch_from_map(
         .map(|(col, desired_field)| cast(col, desired_field.data_type()))
         .collect::<ArrowResult<Vec<_>>>()?;

-    RecordBatch::try_new(Arc::new(output_schema.to_owned()), columns)
+    RecordBatch::try_new(Arc::new(output_schema.to_owned()), columns).map(Some)
 }

 fn read_as_batch(rows: &[Vec<u8>], schema: &Schema, row_type: RowType) -> Vec<ArrayRef> {
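The `skip_items`/`batch_size` window in `create_batch_from_map` is what makes the output incremental: each call clones only one page of group states instead of the whole map, and a skip offset past the end signals completion. A stripped-down sketch of that pattern, using placeholder types rather than the real row-format state:

// Placeholder stand-in for the per-group state kept by the hash aggregation.
struct GroupState {
    group_by_values: Vec<u8>,
    aggregation_buffer: Vec<u8>,
}

// One page of output: only `batch_size` states are cloned per call; a skip
// offset past the end maps to `Ok(None)` in the real code, ending the stream.
fn page_of_states(
    group_states: &[GroupState],
    skip_items: usize,
    batch_size: usize,
) -> Option<Vec<(Vec<u8>, Vec<u8>)>> {
    if skip_items > group_states.len() {
        return None;
    }
    Some(
        group_states
            .iter()
            .skip(skip_items)
            .take(batch_size)
            .map(|gs| (gs.group_by_values.clone(), gs.aggregation_buffer.clone()))
            .collect(),
    )
}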
