From 0b90a8a5c2635e08e80995954271fd06a256ac96 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marko=20Milenkovi=C4=87?=
Date: Sat, 15 Oct 2022 12:16:43 +0100
Subject: [PATCH] Generate hash aggregation output in smaller record batches (#3461)

* change how final aggregation row group is created

...

this change prevents cloning of the whole state, which would double the
memory needed for aggregation.

this PR relates to apache/arrow-datafusion#1570

* Fix clippy issues

* read batch size from `session_config`
---
 .../core/src/physical_plan/aggregates/mod.rs  |   2 +
 .../src/physical_plan/aggregates/row_hash.rs  | 121 +++++++++++-------
 2 files changed, 74 insertions(+), 49 deletions(-)

diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs
index aaf03b5c2a80..2c4a9b26c399 100644
--- a/datafusion/core/src/physical_plan/aggregates/mod.rs
+++ b/datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -298,6 +298,7 @@ impl ExecutionPlan for AggregateExec {
         partition: usize,
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
+        let batch_size = context.session_config().batch_size();
         let input = self.input.execute(partition, context)?;
         let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
 
@@ -318,6 +319,7 @@ impl ExecutionPlan for AggregateExec {
                 self.aggr_expr.clone(),
                 input,
                 baseline_metrics,
+                batch_size,
             )?))
         } else {
             Ok(Box::pin(GroupedHashAggregateStream::new(
diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
index d1c61cd287f9..aefc6571b068 100644
--- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs
+++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
@@ -85,7 +85,11 @@ pub(crate) struct GroupedHashAggregateStreamV2 {
 
     baseline_metrics: BaselineMetrics,
     random_state: RandomState,
-    finished: bool,
+    /// size to be used for resulting RecordBatches
+    batch_size: usize,
+    /// if the result is chunked into batches,
+    /// last offset is preserved for continuation.
+    row_group_skip_position: usize,
 }
 
 fn aggr_state_schema(aggr_expr: &[Arc<dyn AggregateExpr>]) -> Result<SchemaRef> {
@@ -105,6 +109,7 @@ impl GroupedHashAggregateStreamV2 {
         aggr_expr: Vec<Arc<dyn AggregateExpr>>,
         input: SendableRecordBatchStream,
         baseline_metrics: BaselineMetrics,
+        batch_size: usize,
     ) -> Result<Self> {
         let timer = baseline_metrics.elapsed_compute().timer();
 
@@ -135,7 +140,8 @@ impl GroupedHashAggregateStreamV2 {
             aggregate_expressions,
             aggr_state: Default::default(),
             random_state: Default::default(),
-            finished: false,
+            batch_size,
+            row_group_skip_position: 0,
         })
     }
 }
@@ -148,56 +154,62 @@ impl Stream for GroupedHashAggregateStreamV2 {
         cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
         let this = &mut *self;
-        if this.finished {
-            return Poll::Ready(None);
-        }
 
         let elapsed_compute = this.baseline_metrics.elapsed_compute();
 
         loop {
-            let result = match ready!(this.input.poll_next_unpin(cx)) {
-                Some(Ok(batch)) => {
-                    let timer = elapsed_compute.timer();
-                    let result = group_aggregate_batch(
-                        &this.mode,
-                        &this.random_state,
-                        &this.group_by,
-                        &mut this.accumulators,
-                        &this.group_schema,
-                        this.aggr_layout.clone(),
-                        batch,
-                        &mut this.aggr_state,
-                        &this.aggregate_expressions,
-                    );
-
-                    timer.done();
-
-                    match result {
-                        Ok(_) => continue,
-                        Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
+            let result: ArrowResult<Option<RecordBatch>> =
+                match ready!(this.input.poll_next_unpin(cx)) {
+                    Some(Ok(batch)) => {
+                        let timer = elapsed_compute.timer();
+                        let result = group_aggregate_batch(
+                            &this.mode,
+                            &this.random_state,
+                            &this.group_by,
+                            &mut this.accumulators,
+                            &this.group_schema,
+                            this.aggr_layout.clone(),
+                            batch,
+                            &mut this.aggr_state,
+                            &this.aggregate_expressions,
+                        );
+
+                        timer.done();
+
+                        match result {
+                            Ok(_) => continue,
+                            Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
+                        }
                     }
+                    Some(Err(e)) => Err(e),
+                    None => {
+                        let timer = this.baseline_metrics.elapsed_compute().timer();
+                        let result = create_batch_from_map(
+                            &this.mode,
+                            &this.group_schema,
+                            &this.aggr_schema,
+                            this.batch_size,
+                            this.row_group_skip_position,
+                            &mut this.aggr_state,
+                            &mut this.accumulators,
+                            &this.schema,
+                        );
+
+                        timer.done();
+                        result
+                    }
+                };
+
+            this.row_group_skip_position += this.batch_size;
+            match result {
+                Ok(Some(result)) => {
+                    return Poll::Ready(Some(Ok(
+                        result.record_output(&this.baseline_metrics)
+                    )))
                 }
-                Some(Err(e)) => Err(e),
-                None => {
-                    this.finished = true;
-                    let timer = this.baseline_metrics.elapsed_compute().timer();
-                    let result = create_batch_from_map(
-                        &this.mode,
-                        &this.group_schema,
-                        &this.aggr_schema,
-                        &mut this.aggr_state,
-                        &mut this.accumulators,
-                        &this.schema,
-                    )
-                    .record_output(&this.baseline_metrics);
-
-                    timer.done();
-                    result
-                }
-            };
-
-            this.finished = true;
-            return Poll::Ready(Some(result));
+                Ok(None) => return Poll::Ready(None),
+                Err(error) => return Poll::Ready(Some(Err(error))),
+            }
         }
     }
 }
@@ -419,16 +431,25 @@ fn create_group_rows(arrays: Vec<ArrayRef>, schema: &Schema) -> Vec<Vec<u8>> {
 }
 
 /// Create a RecordBatch with all group keys and accumulator' states or values.
+#[allow(clippy::too_many_arguments)]
 fn create_batch_from_map(
     mode: &AggregateMode,
     group_schema: &Schema,
     aggr_schema: &Schema,
+    batch_size: usize,
+    skip_items: usize,
     aggr_state: &mut AggregationState,
     accumulators: &mut [AccumulatorItemV2],
     output_schema: &Schema,
-) -> ArrowResult<RecordBatch> {
+) -> ArrowResult<Option<RecordBatch>> {
+    if skip_items > aggr_state.group_states.len() {
+        return Ok(None);
+    }
+
     if aggr_state.group_states.is_empty() {
-        return Ok(RecordBatch::new_empty(Arc::new(output_schema.to_owned())));
+        return Ok(Some(RecordBatch::new_empty(Arc::new(
+            output_schema.to_owned(),
+        ))));
     }
 
     let mut state_accessor = RowAccessor::new(aggr_schema, RowType::WordAligned);
@@ -436,6 +457,8 @@ fn create_batch_from_map(
     let (group_buffers, mut state_buffers): (Vec<_>, Vec<_>) = aggr_state
         .group_states
         .iter()
+        .skip(skip_items)
+        .take(batch_size)
         .map(|gs| (gs.group_by_values.clone(), gs.aggregation_buffer.clone()))
         .unzip();
 
@@ -471,7 +494,7 @@ fn create_batch_from_map(
         .map(|(col, desired_field)| cast(col, desired_field.data_type()))
         .collect::<ArrowResult<Vec<ArrayRef>>>()?;
 
-    RecordBatch::try_new(Arc::new(output_schema.to_owned()), columns)
+    RecordBatch::try_new(Arc::new(output_schema.to_owned()), columns).map(Some)
 }
 
 fn read_as_batch(rows: &[Vec<u8>], schema: &Schema, row_type: RowType) -> Vec<ArrayRef> {
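
The heart of this change is the paging of the final aggregation state: `create_batch_from_map` now takes a `skip_items` offset and a `batch_size`, reads one window of group states per call via `skip`/`take`, and returns `Ok(None)` once the offset has moved past the end, while the stream bumps `row_group_skip_position` by `batch_size` after each output chunk. Below is a minimal, self-contained Rust sketch of that pattern, for illustration only; `GroupState`, `emit_chunk`, and the sample data are names invented here and are not part of the DataFusion codebase.

// Sketch of the skip/take paging used above (assumed simplification, not DataFusion code).
#[derive(Debug)]
struct GroupState {
    group_by_values: Vec<u8>,
    aggregation_buffer: Vec<u8>,
}

/// Return the next window of group states, or `None` once `skip` has moved
/// past the end (mirrors `create_batch_from_map` returning `Ok(None)`).
fn emit_chunk(
    states: &[GroupState],
    skip: usize,
    batch_size: usize,
) -> Option<Vec<(&[u8], &[u8])>> {
    if skip > states.len() {
        return None;
    }
    Some(
        states
            .iter()
            .skip(skip)            // resume where the previous chunk stopped
            .take(batch_size)      // emit at most one batch worth of groups
            .map(|gs| (gs.group_by_values.as_slice(), gs.aggregation_buffer.as_slice()))
            .collect(),
    )
}

fn main() {
    // Ten fake group states standing in for the accumulated hash state.
    let states: Vec<GroupState> = (0u8..10)
        .map(|i| GroupState {
            group_by_values: vec![i],
            aggregation_buffer: vec![i, i],
        })
        .collect();

    let batch_size = 4;
    let mut skip = 0; // plays the role of `row_group_skip_position`
    while let Some(chunk) = emit_chunk(&states, skip, batch_size) {
        println!("emitting {} rows starting at offset {}", chunk.len(), skip);
        skip += batch_size; // advanced after every emitted chunk, as in the stream
    }
}

Keeping only the offset between polls is what avoids materializing (or cloning) the whole state into one oversized output, which is the memory saving described in the commit message.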