diff --git a/datafusion/expr-common/src/groups_accumulator.rs b/datafusion/expr-common/src/groups_accumulator.rs index e66b27d073d1c..1c9b45c6ed9fc 100644 --- a/datafusion/expr-common/src/groups_accumulator.rs +++ b/datafusion/expr-common/src/groups_accumulator.rs @@ -31,6 +31,8 @@ pub enum EmitTo { /// For example, if `n=10`, group_index `0, 1, ... 9` are emitted /// and group indexes '`10, 11, 12, ...` become `0, 1, 2, ...`. First(usize), + /// Emit all groups managed by blocks + CurrentBlock(bool), } impl EmitTo { @@ -52,6 +54,7 @@ impl EmitTo { std::mem::swap(v, &mut t); t } + EmitTo::CurrentBlock(_) => unimplemented!(), } } } @@ -143,6 +146,12 @@ pub trait GroupsAccumulator: Send { /// [`Accumulator::state`]: crate::accumulator::Accumulator::state fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>>; + /// Returns `true` if blocked emission is supported. + /// Blocked emission makes it possible to avoid splitting results in aggregation. + fn supports_blocked_emission(&self) -> bool { + false + } + /// Merges intermediate state (the output from [`Self::state`]) /// into this accumulator's current state. 
/// diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index be7ac934d7bcc..165275f1cabd6 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -22,7 +22,7 @@ use bytes_view::GroupValuesBytesView; use datafusion_common::Result; pub(crate) mod primitive; -use datafusion_expr::EmitTo; +use datafusion_expr::{groups_accumulator::GroupIndicesType, EmitTo}; use primitive::GroupValuesPrimitive; mod row; @@ -36,7 +36,12 @@ use datafusion_physical_expr::binary_map::OutputType; /// An interning store for group keys pub trait GroupValues: Send { /// Calculates the `groups` for each input row of `cols` - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()>; + fn intern( + &mut self, + cols: &[ArrayRef], + groups: &mut Vec<usize>, + group_type: GroupIndicesType, + ) -> Result<()>; /// Returns the number of bytes used by this [`GroupValues`] fn size(&self) -> usize; @@ -52,6 +57,12 @@ pub trait GroupValues: Send { /// Clear the contents and shrink the capacity to the size of the batch (free up memory usage) fn clear_shrink(&mut self, batch: &RecordBatch); + + /// Returns `true` if blocked emission is supported. + /// Blocked emission makes it possible to avoid splitting results in aggregation. 
+ fn supports_blocked_emission(&self) -> bool { + false + } } pub fn new_group_values(schema: SchemaRef) -> Result<Box<dyn GroupValues>> { diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index ed3d6d49f9f3d..c4c9121d05128 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -62,6 +62,7 @@ pub(crate) enum ExecutionState { /// When producing output, the remaining rows to output are stored /// here and are sliced off as needed in batch_size chunks ProducingOutput(RecordBatch), + ProducingBlocks(Option<usize>), /// Produce intermediate aggregate state for each input row without /// aggregation. /// @@ -387,6 +388,10 @@ pub(crate) struct GroupedHashAggregateStream { /// Optional probe for skipping data aggregation, if supported by /// current stream. skip_aggregation_probe: Option<SkipAggregationProbe>, + + enable_blocked_group_states: bool, + + block_size: usize, } impl GroupedHashAggregateStream { @@ -676,6 +681,43 @@ impl Stream for GroupedHashAggregateStream { ))); } + ExecutionState::ProducingBlocks(blocks) => { + if let Some(blk) = blocks { + if blk > 0 { + self.exec_state = ExecutionState::ProducingBlocks(Some(*blk - 1)); + } else { + self.exec_state = if self.input_done { + ExecutionState::Done + } else if self.should_skip_aggregation() { + ExecutionState::SkippingAggregation + } else { + ExecutionState::ReadingInput + }; + continue; + } + } + + let emit_result = self.emit(EmitTo::CurrentBlock(true), false); + if emit_result.is_err() { + return Poll::Ready(Some(emit_result)); + } + + let emit_batch = emit_result.unwrap(); + if emit_batch.num_rows() == 0 { + self.exec_state = if self.input_done { + ExecutionState::Done + } else if self.should_skip_aggregation() { + ExecutionState::SkippingAggregation + } else { + ExecutionState::ReadingInput + }; + } + + return Poll::Ready(Some(Ok( + emit_batch.record_output(&self.baseline_metrics) + ))); + } + ExecutionState::Done => { 
// release the memory reservation since sending back output batch itself needs // some memory reservation, so make some room for it. @@ -900,10 +942,15 @@ impl GroupedHashAggregateStream { && matches!(self.group_ordering, GroupOrdering::None) && matches!(self.mode, AggregateMode::Partial) && self.update_memory_reservation().is_err() - { - let n = self.group_values.len() / self.batch_size * self.batch_size; - let batch = self.emit(EmitTo::First(n), false)?; - self.exec_state = ExecutionState::ProducingOutput(batch); + { + if !self.enable_blocked_group_states { + let n = self.group_values.len() / self.batch_size * self.batch_size; + let batch = self.emit(EmitTo::First(n), false)?; + self.exec_state = ExecutionState::ProducingOutput(batch); + } else { + let blocks = self.group_values.len() / self.block_size; + self.exec_state = ExecutionState::ProducingBlocks(Some(blocks)); + } } Ok(()) } @@ -961,8 +1008,12 @@ impl GroupedHashAggregateStream { let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); let timer = elapsed_compute.timer(); self.exec_state = if self.spill_state.spills.is_empty() { - let batch = self.emit(EmitTo::All, false)?; - ExecutionState::ProducingOutput(batch) + if !self.enable_blocked_group_states { + let batch = self.emit(EmitTo::All, false)?; + ExecutionState::ProducingOutput(batch) + } else { + ExecutionState::ProducingBlocks(None) + } } else { // If spill files exist, stream-merge them. 
self.update_merged_stream()?; @@ -994,8 +1045,12 @@ impl GroupedHashAggregateStream { fn switch_to_skip_aggregation(&mut self) -> Result<()> { if let Some(probe) = self.skip_aggregation_probe.as_mut() { if probe.should_skip() { - let batch = self.emit(EmitTo::All, false)?; - self.exec_state = ExecutionState::ProducingOutput(batch); + if !self.enable_blocked_group_states { + let batch = self.emit(EmitTo::All, false)?; + self.exec_state = ExecutionState::ProducingOutput(batch); + } else { + self.exec_state = ExecutionState::ProducingBlocks(None); + } } }