From 70ffb2fd3df7fbc744587ce3299ffb52c8ca6cbd Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 20 Jan 2024 19:13:06 -0500 Subject: [PATCH 1/3] Change `Accumulator::evaluate` and `Accumulator::state` to take `&mut self` --- datafusion-examples/examples/advanced_udaf.rs | 4 ++-- datafusion-examples/examples/simple_udaf.rs | 4 ++-- datafusion/core/src/datasource/statistics.rs | 4 ++-- .../user_defined/user_defined_aggregates.rs | 8 ++++---- .../user_defined_scalar_functions.rs | 4 ++-- datafusion/expr/src/accumulator.rs | 12 +++++++++--- .../src/aggregate/approx_distinct.rs | 4 ++-- .../src/aggregate/approx_percentile_cont.rs | 4 ++-- .../approx_percentile_cont_with_weight.rs | 4 ++-- .../physical-expr/src/aggregate/array_agg.rs | 4 ++-- .../src/aggregate/array_agg_distinct.rs | 4 ++-- .../src/aggregate/array_agg_ordered.rs | 4 ++-- .../physical-expr/src/aggregate/average.rs | 8 ++++---- .../src/aggregate/bit_and_or_xor.rs | 16 ++++++++-------- .../physical-expr/src/aggregate/bool_and_or.rs | 8 ++++---- .../physical-expr/src/aggregate/correlation.rs | 6 +++--- datafusion/physical-expr/src/aggregate/count.rs | 4 ++-- .../src/aggregate/count_distinct.rs | 12 ++++++------ .../physical-expr/src/aggregate/covariance.rs | 6 +++--- .../physical-expr/src/aggregate/first_last.rs | 8 ++++---- .../src/aggregate/groups_accumulator/adapter.rs | 4 ++-- datafusion/physical-expr/src/aggregate/median.rs | 7 +++---- .../physical-expr/src/aggregate/min_max.rs | 16 ++++++++-------- .../physical-expr/src/aggregate/nth_value.rs | 4 ++-- datafusion/physical-expr/src/aggregate/regr.rs | 4 ++-- datafusion/physical-expr/src/aggregate/stddev.rs | 6 +++--- .../physical-expr/src/aggregate/string_agg.rs | 4 ++-- datafusion/physical-expr/src/aggregate/sum.rs | 8 ++++---- .../physical-expr/src/aggregate/sum_distinct.rs | 4 ++-- datafusion/physical-expr/src/aggregate/utils.rs | 2 +- .../physical-expr/src/aggregate/variance.rs | 6 +++--- datafusion/physical-plan/src/aggregates/mod.rs | 6 +++--- .../physical-plan/src/aggregates/no_grouping.rs | 13 +++++++------ .../proto/tests/cases/roundtrip_logical_plan.rs | 8 ++++---- .../proto/tests/cases/roundtrip_physical_plan.rs | 4 ++-- .../tests/cases/roundtrip_logical_plan.rs | 4 ++-- 36 files changed, 117 insertions(+), 111 deletions(-) diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 8d5314bfbea5..928090eec82c 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -104,7 +104,7 @@ impl Accumulator for GeometricMean { // This function serializes our state to `ScalarValue`, which DataFusion uses // to pass this state between execution stages. // Note that this can be arbitrary data. - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![ ScalarValue::from(self.prod), ScalarValue::from(self.n), @@ -113,7 +113,7 @@ impl Accumulator for GeometricMean { // DataFusion expects this function to return the final value of this aggregator. // in this case, this is the formula of the geometric mean - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { let value = self.prod.powf(1.0 / self.n as f64); Ok(ScalarValue::from(value)) } diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs index 2c797f221b2c..0996a67245a8 100644 --- a/datafusion-examples/examples/simple_udaf.rs +++ b/datafusion-examples/examples/simple_udaf.rs @@ -72,7 +72,7 @@ impl Accumulator for GeometricMean { // This function serializes our state to `ScalarValue`, which DataFusion uses // to pass this state between execution stages. // Note that this can be arbitrary data. - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![ ScalarValue::from(self.prod), ScalarValue::from(self.n), @@ -81,7 +81,7 @@ impl Accumulator for GeometricMean { // DataFusion expects this function to return the final value of this aggregator. // in this case, this is the formula of the geometric mean - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { let value = self.prod.powf(1.0 / self.n as f64); Ok(ScalarValue::from(value)) } diff --git a/datafusion/core/src/datasource/statistics.rs b/datafusion/core/src/datasource/statistics.rs index 695e139517cf..73896f8eb7c1 100644 --- a/datafusion/core/src/datasource/statistics.rs +++ b/datafusion/core/src/datasource/statistics.rs @@ -194,11 +194,11 @@ pub(crate) fn get_col_stats( ) -> Vec { (0..schema.fields().len()) .map(|i| { - let max_value = match &max_values[i] { + let max_value = match max_values.get_mut(i).unwrap() { Some(max_value) => max_value.evaluate().ok(), None => None, }; - let min_value = match &min_values[i] { + let min_value = match min_values.get_mut(i).unwrap() { Some(min_value) => min_value.evaluate().ok(), None => None, }; diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 5882718acefd..6ec0d7dbe502 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -435,7 +435,7 @@ impl TimeSum { } impl Accumulator for TimeSum { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![self.evaluate()?]) } @@ -457,7 +457,7 @@ impl Accumulator for TimeSum { self.update_batch(states) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { println!("Evaluating to {}", self.sum); Ok(ScalarValue::TimestampNanosecond(Some(self.sum), None)) } @@ -582,14 +582,14 @@ impl FirstSelector { } impl Accumulator for FirstSelector { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { let state = self.to_state().into_iter().collect::>(); Ok(state) } /// produce the output structure - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { Ok(self.to_scalar()) } diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index fe88ea6cf115..627dbbd26b41 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -248,7 +248,7 @@ async fn udaf_as_window_func() -> Result<()> { struct MyAccumulator; impl Accumulator for MyAccumulator { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { unimplemented!() } @@ -260,7 +260,7 @@ async fn udaf_as_window_func() -> Result<()> { unimplemented!() } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { unimplemented!() } diff --git a/datafusion/expr/src/accumulator.rs b/datafusion/expr/src/accumulator.rs index 32de88b3d99f..6801991ec3ea 100644 --- a/datafusion/expr/src/accumulator.rs +++ b/datafusion/expr/src/accumulator.rs @@ -56,11 +56,17 @@ pub trait Accumulator: Send + Sync + Debug { /// running sum. fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>; - /// Returns the final aggregate value. + /// Returns the final aggregate value and resets internal state. /// /// For example, the `SUM` accumulator maintains a running sum, /// and `evaluate` will produce that running sum as its output. - fn evaluate(&self) -> Result; + /// + /// After this call, the accumulator's internal state should be + /// equivalent to when it was first created. + /// + /// This function gets a `mut` accumulator to allow for the accumulator to + /// use an arrow compatible internal state when possible. + fn evaluate(&mut self) -> Result; /// Returns the allocated size required for this accumulator, in /// bytes, including `Self`. @@ -129,7 +135,7 @@ pub trait Accumulator: Send + Sync + Debug { /// Note that [`ScalarValue::List`] can be used to pass multiple /// values if the number of intermediate values is not known at /// planning time (e.g. for `MEDIAN`) - fn state(&self) -> Result>; + fn state(&mut self) -> Result>; /// Updates the accumulator's state from an `Array` containing one /// or more intermediate values. diff --git a/datafusion/physical-expr/src/aggregate/approx_distinct.rs b/datafusion/physical-expr/src/aggregate/approx_distinct.rs index b79a5611c334..66e1310695ad 100644 --- a/datafusion/physical-expr/src/aggregate/approx_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/approx_distinct.rs @@ -244,12 +244,12 @@ macro_rules! default_accumulator_impl { Ok(()) } - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { let value = ScalarValue::from(&self.hll); Ok(vec![value]) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { Ok(ScalarValue::UInt64(Some(self.hll.count() as u64))) } diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs index 15c0fb3ace4d..b3de7b0b4d36 100644 --- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs +++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs @@ -377,7 +377,7 @@ impl ApproxPercentileAccumulator { } impl Accumulator for ApproxPercentileAccumulator { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(self.digest.to_scalar_state().into_iter().collect()) } @@ -389,7 +389,7 @@ impl Accumulator for ApproxPercentileAccumulator { Ok(()) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { if self.digest.count() == 0.0 { return exec_err!("aggregate function needs at least one non-null element"); } diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs index ee5ef7228f4b..3fa715a59238 100644 --- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs +++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs @@ -129,7 +129,7 @@ impl ApproxPercentileWithWeightAccumulator { } impl Accumulator for ApproxPercentileWithWeightAccumulator { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { self.approx_percentile_cont_accumulator.state() } @@ -155,7 +155,7 @@ impl Accumulator for ApproxPercentileWithWeightAccumulator { Ok(()) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { self.approx_percentile_cont_accumulator.evaluate() } diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index 91d5c867d312..5dc29f834feb 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -153,11 +153,11 @@ impl Accumulator for ArrayAggAccumulator { Ok(()) } - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![self.evaluate()?]) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { // Transform Vec to ListArr let element_arrays: Vec<&dyn Array> = diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index 2d263a42e0ff..a58856e398e3 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -129,7 +129,7 @@ impl DistinctArrayAggAccumulator { } impl Accumulator for DistinctArrayAggAccumulator { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![self.evaluate()?]) } @@ -163,7 +163,7 @@ impl Accumulator for DistinctArrayAggAccumulator { Ok(()) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { let values: Vec = self.values.iter().cloned().collect(); let arr = ScalarValue::new_list(&values, &self.datatype); Ok(ScalarValue::List(arr)) diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs index 34f8d20628dc..5263fa83a6eb 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs @@ -279,13 +279,13 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { Ok(()) } - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { let mut result = vec![self.evaluate()?]; result.push(self.evaluate_orderings()?); Ok(result) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { let values = self.values.clone(); let array = if self.reverse { ScalarValue::new_list_from_iter(values.into_iter().rev(), &self.datatypes[0]) diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index 91f2fb952dce..539ad28590de 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -239,7 +239,7 @@ pub struct AvgAccumulator { } impl Accumulator for AvgAccumulator { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![ ScalarValue::from(self.count), ScalarValue::Float64(self.sum), @@ -277,7 +277,7 @@ impl Accumulator for AvgAccumulator { Ok(()) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { Ok(ScalarValue::Float64( self.sum.map(|f| f / self.count as f64), )) @@ -315,7 +315,7 @@ impl Debug for DecimalAvgAccumulator { } impl Accumulator for DecimalAvgAccumulator { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![ ScalarValue::from(self.count), ScalarValue::new_primitive::( @@ -357,7 +357,7 @@ impl Accumulator for DecimalAvgAccumulator Ok(()) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { let v = self .sum .map(|v| { diff --git a/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs b/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs index 6c97d620616a..caa3c0f537d4 100644 --- a/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs +++ b/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs @@ -190,11 +190,11 @@ where self.update_batch(states) } - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![self.evaluate()?]) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { ScalarValue::new_primitive::(self.value, &T::DATA_TYPE) } @@ -339,7 +339,7 @@ impl Accumulator for BitOrAccumulator where T::Native: std::ops::BitOr, { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![self.evaluate()?]) } @@ -355,7 +355,7 @@ where self.update_batch(states) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { ScalarValue::new_primitive::(self.value, &T::DATA_TYPE) } @@ -500,7 +500,7 @@ impl Accumulator for BitXorAccumulator where T::Native: std::ops::BitXor, { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![self.evaluate()?]) } @@ -516,7 +516,7 @@ where self.update_batch(states) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { ScalarValue::new_primitive::(self.value, &T::DATA_TYPE) } @@ -634,7 +634,7 @@ impl Accumulator for DistinctBitXorAccumulator where T::Native: std::ops::BitXor + std::hash::Hash + Eq, { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { // 1. Stores aggregate state in `ScalarValue::List` // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set let state_out = { @@ -679,7 +679,7 @@ where Ok(()) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { let mut acc = T::Native::usize_as(0); for distinct_value in self.values.iter() { acc = acc ^ *distinct_value; diff --git a/datafusion/physical-expr/src/aggregate/bool_and_or.rs b/datafusion/physical-expr/src/aggregate/bool_and_or.rs index 9757d314b6aa..2faba53e1935 100644 --- a/datafusion/physical-expr/src/aggregate/bool_and_or.rs +++ b/datafusion/physical-expr/src/aggregate/bool_and_or.rs @@ -191,11 +191,11 @@ impl Accumulator for BoolAndAccumulator { self.update_batch(states) } - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![ScalarValue::Boolean(self.acc)]) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { Ok(ScalarValue::Boolean(self.acc)) } @@ -309,7 +309,7 @@ struct BoolOrAccumulator { } impl Accumulator for BoolOrAccumulator { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![ScalarValue::Boolean(self.acc)]) } @@ -328,7 +328,7 @@ impl Accumulator for BoolOrAccumulator { self.update_batch(states) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { Ok(ScalarValue::Boolean(self.acc)) } diff --git a/datafusion/physical-expr/src/aggregate/correlation.rs b/datafusion/physical-expr/src/aggregate/correlation.rs index 61f2db5c8ef9..4dca1e4a881e 100644 --- a/datafusion/physical-expr/src/aggregate/correlation.rs +++ b/datafusion/physical-expr/src/aggregate/correlation.rs @@ -149,7 +149,7 @@ impl CorrelationAccumulator { } impl Accumulator for CorrelationAccumulator { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![ ScalarValue::from(self.covar.get_count()), ScalarValue::from(self.covar.get_mean1()), @@ -215,7 +215,7 @@ impl Accumulator for CorrelationAccumulator { Ok(()) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { let covar = self.covar.evaluate()?; let stddev1 = self.stddev1.evaluate()?; let stddev2 = self.stddev2.evaluate()?; @@ -519,7 +519,7 @@ mod tests { .collect::>>()?; accum1.update_batch(&values1)?; accum2.update_batch(&values2)?; - let state2 = get_accum_scalar_values_as_arrays(accum2.as_ref())?; + let state2 = get_accum_scalar_values_as_arrays(accum2.as_mut())?; accum1.merge_batch(&state2)?; accum1.evaluate() } diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs index 8e9ae5cea36b..34e2ea817a39 100644 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ b/datafusion/physical-expr/src/aggregate/count.rs @@ -295,7 +295,7 @@ impl CountAccumulator { } impl Accumulator for CountAccumulator { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![ScalarValue::Int64(Some(self.count))]) } @@ -320,7 +320,7 @@ impl Accumulator for CountAccumulator { Ok(()) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { Ok(ScalarValue::Int64(Some(self.count))) } diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index 021c33fb94a7..ef1a248d5f82 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -212,7 +212,7 @@ impl DistinctCountAccumulator { } impl Accumulator for DistinctCountAccumulator { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { let scalars = self.values.iter().cloned().collect::>(); let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type); Ok(vec![ScalarValue::List(arr)]) @@ -249,7 +249,7 @@ impl Accumulator for DistinctCountAccumulator { Ok(()) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { Ok(ScalarValue::Int64(Some(self.values.len() as i64))) } @@ -288,7 +288,7 @@ where T: ArrowPrimitiveType + Send + Debug, T::Native: Eq + Hash, { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { let arr = Arc::new(PrimitiveArray::::from_iter_values( self.values.iter().cloned(), )) as ArrayRef; @@ -331,7 +331,7 @@ where }) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { Ok(ScalarValue::Int64(Some(self.values.len() as i64))) } @@ -374,7 +374,7 @@ impl Accumulator for FloatDistinctCountAccumulator where T: ArrowPrimitiveType + Send + Debug, { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { let arr = Arc::new(PrimitiveArray::::from_iter_values( self.values.iter().map(|v| v.0), )) as ArrayRef; @@ -418,7 +418,7 @@ where }) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { Ok(ScalarValue::Int64(Some(self.values.len() as i64))) } diff --git a/datafusion/physical-expr/src/aggregate/covariance.rs b/datafusion/physical-expr/src/aggregate/covariance.rs index 0f838eb6fa1c..45f9926975d3 100644 --- a/datafusion/physical-expr/src/aggregate/covariance.rs +++ b/datafusion/physical-expr/src/aggregate/covariance.rs @@ -260,7 +260,7 @@ impl CovarianceAccumulator { } impl Accumulator for CovarianceAccumulator { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![ ScalarValue::from(self.count), ScalarValue::from(self.mean1), @@ -381,7 +381,7 @@ impl Accumulator for CovarianceAccumulator { Ok(()) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { let count = match self.stats_type { StatsType::Population => self.count, StatsType::Sample => { @@ -768,7 +768,7 @@ mod tests { .collect::>>()?; accum1.update_batch(&values1)?; accum2.update_batch(&values2)?; - let state2 = get_accum_scalar_values_as_arrays(accum2.as_ref())?; + let state2 = get_accum_scalar_values_as_arrays(accum2.as_mut())?; accum1.merge_batch(&state2)?; accum1.evaluate() } diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index 4afa8d0dd5ec..d2bf48551f0d 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -270,7 +270,7 @@ impl FirstValueAccumulator { } impl Accumulator for FirstValueAccumulator { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { let mut result = vec![self.first.clone()]; result.extend(self.orderings.iter().cloned()); result.push(ScalarValue::Boolean(Some(self.is_set))); @@ -336,7 +336,7 @@ impl Accumulator for FirstValueAccumulator { Ok(()) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { Ok(self.first.clone()) } @@ -586,7 +586,7 @@ impl LastValueAccumulator { } impl Accumulator for LastValueAccumulator { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { let mut result = vec![self.last.clone()]; result.extend(self.orderings.clone()); result.push(ScalarValue::Boolean(Some(self.is_set))); @@ -655,7 +655,7 @@ impl Accumulator for LastValueAccumulator { Ok(()) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { Ok(self.last.clone()) } diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs index c6fd17a69b39..9d08b5b57241 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs @@ -272,7 +272,7 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { let results: Vec = states .into_iter() - .map(|state| { + .map(|mut state| { self.free_allocation(state.size()); state.accumulator.evaluate() }) @@ -293,7 +293,7 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { // which we need to form into columns let mut results: Vec> = vec![]; - for state in states { + for mut state in states { self.free_allocation(state.size()); let accumulator_state = state.accumulator.state()?; results.resize_with(accumulator_state.len(), Vec::new); diff --git a/datafusion/physical-expr/src/aggregate/median.rs b/datafusion/physical-expr/src/aggregate/median.rs index 691b1c1752f4..94cc5c7fb76a 100644 --- a/datafusion/physical-expr/src/aggregate/median.rs +++ b/datafusion/physical-expr/src/aggregate/median.rs @@ -145,7 +145,7 @@ impl std::fmt::Debug for MedianAccumulator { } impl Accumulator for MedianAccumulator { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { let all_values = self .all_values .iter() @@ -171,9 +171,8 @@ impl Accumulator for MedianAccumulator { Ok(()) } - fn evaluate(&self) -> Result { - // TODO: evaluate could pass &mut self - let mut d = self.all_values.clone(); + fn evaluate(&mut self) -> Result { + let mut d = std::mem::take(&mut self.all_values); let cmp = |x: &T::Native, y: &T::Native| x.compare(*y); let len = d.len(); diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index 7e3ef2a2abab..14bf69064f28 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -764,11 +764,11 @@ impl Accumulator for MaxAccumulator { self.update_batch(states) } - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![self.max.clone()]) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { Ok(self.max.clone()) } @@ -820,11 +820,11 @@ impl Accumulator for SlidingMaxAccumulator { self.update_batch(states) } - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![self.max.clone()]) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { Ok(self.max.clone()) } @@ -1016,7 +1016,7 @@ impl MinAccumulator { } impl Accumulator for MinAccumulator { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![self.min.clone()]) } @@ -1031,7 +1031,7 @@ impl Accumulator for MinAccumulator { self.update_batch(states) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { Ok(self.min.clone()) } @@ -1058,7 +1058,7 @@ impl SlidingMinAccumulator { } impl Accumulator for SlidingMinAccumulator { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![self.min.clone()]) } @@ -1092,7 +1092,7 @@ impl Accumulator for SlidingMinAccumulator { self.update_batch(states) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { Ok(self.min.clone()) } diff --git a/datafusion/physical-expr/src/aggregate/nth_value.rs b/datafusion/physical-expr/src/aggregate/nth_value.rs index 5a1ca90b7f5e..26a125485826 100644 --- a/datafusion/physical-expr/src/aggregate/nth_value.rs +++ b/datafusion/physical-expr/src/aggregate/nth_value.rs @@ -302,7 +302,7 @@ impl Accumulator for NthValueAccumulator { Ok(()) } - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { let mut result = vec![self.evaluate_values()]; if !self.ordering_req.is_empty() { result.push(self.evaluate_orderings()); @@ -310,7 +310,7 @@ impl Accumulator for NthValueAccumulator { Ok(result) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { let n_required = self.n.unsigned_abs() as usize; let from_start = self.n > 0; let nth_value_idx = if from_start { diff --git a/datafusion/physical-expr/src/aggregate/regr.rs b/datafusion/physical-expr/src/aggregate/regr.rs index 6922cb131cac..36e7b7c9b3e4 100644 --- a/datafusion/physical-expr/src/aggregate/regr.rs +++ b/datafusion/physical-expr/src/aggregate/regr.rs @@ -251,7 +251,7 @@ impl RegrAccumulator { } impl Accumulator for RegrAccumulator { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![ ScalarValue::from(self.count), ScalarValue::from(self.mean_x), @@ -418,7 +418,7 @@ impl Accumulator for RegrAccumulator { Ok(()) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { let cov_pop_x_y = self.algo_const / self.count as f64; let var_pop_x = self.m2_x / self.count as f64; let var_pop_y = self.m2_y / self.count as f64; diff --git a/datafusion/physical-expr/src/aggregate/stddev.rs b/datafusion/physical-expr/src/aggregate/stddev.rs index 64e19ef502c7..dcc2b0e69c02 100644 --- a/datafusion/physical-expr/src/aggregate/stddev.rs +++ b/datafusion/physical-expr/src/aggregate/stddev.rs @@ -200,7 +200,7 @@ impl StddevAccumulator { } impl Accumulator for StddevAccumulator { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![ ScalarValue::from(self.variance.get_count()), ScalarValue::from(self.variance.get_mean()), @@ -220,7 +220,7 @@ impl Accumulator for StddevAccumulator { self.variance.merge_batch(states) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { let variance = self.variance.evaluate()?; match variance { ScalarValue::Float64(e) => { @@ -459,7 +459,7 @@ mod tests { .collect::>>()?; accum1.update_batch(&values1)?; accum2.update_batch(&values2)?; - let state2 = get_accum_scalar_values_as_arrays(accum2.as_ref())?; + let state2 = get_accum_scalar_values_as_arrays(accum2.as_mut())?; accum1.merge_batch(&state2)?; accum1.evaluate() } diff --git a/datafusion/physical-expr/src/aggregate/string_agg.rs b/datafusion/physical-expr/src/aggregate/string_agg.rs index 7adc736932ad..7a1da6d62246 100644 --- a/datafusion/physical-expr/src/aggregate/string_agg.rs +++ b/datafusion/physical-expr/src/aggregate/string_agg.rs @@ -153,11 +153,11 @@ impl Accumulator for StringAggAccumulator { Ok(()) } - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![self.evaluate()?]) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { Ok(ScalarValue::LargeUtf8(self.values.clone())) } diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index 03f666cc4e5d..968773568bd3 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -191,7 +191,7 @@ impl SumAccumulator { } impl Accumulator for SumAccumulator { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![self.evaluate()?]) } @@ -208,7 +208,7 @@ impl Accumulator for SumAccumulator { self.update_batch(states) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { ScalarValue::new_primitive::(self.sum, &self.data_type) } @@ -243,7 +243,7 @@ impl SlidingSumAccumulator { } impl Accumulator for SlidingSumAccumulator { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![self.evaluate()?, self.count.into()]) } @@ -267,7 +267,7 @@ impl Accumulator for SlidingSumAccumulator { Ok(()) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { let v = (self.count != 0).then_some(self.sum); ScalarValue::new_primitive::(v, &self.data_type) } diff --git a/datafusion/physical-expr/src/aggregate/sum_distinct.rs b/datafusion/physical-expr/src/aggregate/sum_distinct.rs index 6dbb39224629..4c0f94b3a2bb 100644 --- a/datafusion/physical-expr/src/aggregate/sum_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/sum_distinct.rs @@ -140,7 +140,7 @@ impl DistinctSumAccumulator { } impl Accumulator for DistinctSumAccumulator { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { // 1. Stores aggregate state in `ScalarValue::List` // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set let state_out = { @@ -186,7 +186,7 @@ impl Accumulator for DistinctSumAccumulator { Ok(()) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { let mut acc = T::Native::usize_as(0); for distinct_value in self.values.iter() { acc = acc.add_wrapping(distinct_value.0) diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs index 6dd586bfb8ce..60d59c16be5f 100644 --- a/datafusion/physical-expr/src/aggregate/utils.rs +++ b/datafusion/physical-expr/src/aggregate/utils.rs @@ -35,7 +35,7 @@ use datafusion_expr::Accumulator; /// Convert scalar values from an accumulator into arrays. pub fn get_accum_scalar_values_as_arrays( - accum: &dyn Accumulator, + accum: &mut dyn Accumulator, ) -> Result> { accum .state()? diff --git a/datafusion/physical-expr/src/aggregate/variance.rs b/datafusion/physical-expr/src/aggregate/variance.rs index d82c5ad5626f..94d7be4265d7 100644 --- a/datafusion/physical-expr/src/aggregate/variance.rs +++ b/datafusion/physical-expr/src/aggregate/variance.rs @@ -231,7 +231,7 @@ impl VarianceAccumulator { } impl Accumulator for VarianceAccumulator { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![ ScalarValue::from(self.count), ScalarValue::from(self.mean), @@ -302,7 +302,7 @@ impl Accumulator for VarianceAccumulator { Ok(()) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { let count = match self.stats_type { StatsType::Population => self.count, StatsType::Sample => { @@ -533,7 +533,7 @@ mod tests { .collect::>>()?; accum1.update_batch(&values1)?; accum2.update_batch(&values2)?; - let state2 = get_accum_scalar_values_as_arrays(accum2.as_ref())?; + let state2 = get_accum_scalar_values_as_arrays(accum2.as_mut())?; accum1.merge_batch(&state2)?; accum1.evaluate() } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index d3ae0d5ce01f..2d7a8cccc481 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1087,14 +1087,14 @@ fn create_accumulators( /// returns a vector of ArrayRefs, where each entry corresponds to either the /// final value (mode = Final, FinalPartitioned and Single) or states (mode = Partial) fn finalize_aggregation( - accumulators: &[AccumulatorItem], + accumulators: &mut [AccumulatorItem], mode: &AggregateMode, ) -> Result> { match mode { AggregateMode::Partial => { // Build the vector of states accumulators - .iter() + .iter_mut() .map(|accumulator| { accumulator.state().and_then(|e| { e.iter() @@ -1111,7 +1111,7 @@ fn finalize_aggregation( | AggregateMode::SinglePartitioned => { // Merge the state to the final value accumulators - .iter() + .iter_mut() .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array())) .collect() } diff --git a/datafusion/physical-plan/src/aggregates/no_grouping.rs b/datafusion/physical-plan/src/aggregates/no_grouping.rs index 90eb488a2ead..5ec95bd79942 100644 --- a/datafusion/physical-plan/src/aggregates/no_grouping.rs +++ b/datafusion/physical-plan/src/aggregates/no_grouping.rs @@ -137,12 +137,13 @@ impl AggregateStream { None => { this.finished = true; let timer = this.baseline_metrics.elapsed_compute().timer(); - let result = finalize_aggregation(&this.accumulators, &this.mode) - .and_then(|columns| { - RecordBatch::try_new(this.schema.clone(), columns) - .map_err(Into::into) - }) - .record_output(&this.baseline_metrics); + let result = + finalize_aggregation(&mut this.accumulators, &this.mode) + .and_then(|columns| { + RecordBatch::try_new(this.schema.clone(), columns) + .map_err(Into::into) + }) + .record_output(&this.baseline_metrics); timer.done(); diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 2d38cfd400ad..0fbc00174045 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1578,7 +1578,7 @@ fn roundtrip_aggregate_udf() { struct Dummy {} impl Accumulator for Dummy { - fn state(&self) -> datafusion::error::Result> { + fn state(&mut self) -> datafusion::error::Result> { Ok(vec![]) } @@ -1593,7 +1593,7 @@ fn roundtrip_aggregate_udf() { Ok(()) } - fn evaluate(&self) -> datafusion::error::Result { + fn evaluate(&mut self) -> datafusion::error::Result { Ok(ScalarValue::Float64(None)) } @@ -1762,7 +1762,7 @@ fn roundtrip_window() { struct DummyAggr {} impl Accumulator for DummyAggr { - fn state(&self) -> datafusion::error::Result> { + fn state(&mut self) -> datafusion::error::Result> { Ok(vec![]) } @@ -1777,7 +1777,7 @@ fn roundtrip_window() { Ok(()) } - fn evaluate(&self) -> datafusion::error::Result { + fn evaluate(&mut self) -> datafusion::error::Result { Ok(ScalarValue::Float64(None)) } diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 3a13dc887f0c..72d774a76145 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -353,7 +353,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { #[derive(Debug)] struct Example; impl Accumulator for Example { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![ScalarValue::Int64(Some(0))]) } @@ -365,7 +365,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { Ok(()) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { Ok(ScalarValue::Int64(Some(0))) } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index d7327caee43d..79cf76de5985 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -712,7 +712,7 @@ async fn roundtrip_aggregate_udf() -> Result<()> { struct Dummy {} impl Accumulator for Dummy { - fn state(&self) -> datafusion::error::Result> { + fn state(&mut self) -> datafusion::error::Result> { Ok(vec![]) } @@ -727,7 +727,7 @@ async fn roundtrip_aggregate_udf() -> Result<()> { Ok(()) } - fn evaluate(&self) -> datafusion::error::Result { + fn evaluate(&mut self) -> datafusion::error::Result { Ok(ScalarValue::Float64(None)) } From 8f9850792afaa8c288a3ff4f429f765bcb6c7c94 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 21 Jan 2024 07:11:06 -0500 Subject: [PATCH 2/3] improve docs --- datafusion/expr/src/accumulator.rs | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/datafusion/expr/src/accumulator.rs b/datafusion/expr/src/accumulator.rs index 6801991ec3ea..523e4e21a640 100644 --- a/datafusion/expr/src/accumulator.rs +++ b/datafusion/expr/src/accumulator.rs @@ -56,7 +56,7 @@ pub trait Accumulator: Send + Sync + Debug { /// running sum. fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>; - /// Returns the final aggregate value and resets internal state. + /// Returns the final aggregate value, consuming the internal state. /// /// For example, the `SUM` accumulator maintains a running sum, /// and `evaluate` will produce that running sum as its output. @@ -64,8 +64,9 @@ pub trait Accumulator: Send + Sync + Debug { /// After this call, the accumulator's internal state should be /// equivalent to when it was first created. /// - /// This function gets a `mut` accumulator to allow for the accumulator to - /// use an arrow compatible internal state when possible. + /// This function gets `&mut self` to allow for the accumulator to build + /// arrow compatible internal state that can be returned without copying + /// when possible (for example distinct strings) fn evaluate(&mut self) -> Result; /// Returns the allocated size required for this accumulator, in @@ -78,7 +79,15 @@ pub trait Accumulator: Send + Sync + Debug { /// the `capacity` should be used not the `len`. fn size(&self) -> usize; - /// Returns the intermediate state of the accumulator. + /// Returns the intermediate state of the accumulator, consuming the + /// intermediate state. + /// + /// After this call, the accumulator's internal state should be + /// equivalent to when it was first created. + /// + /// This function gets `&mut self` to allow for the accumulator to build + /// arrow compatible internal state that can be returned without copying + /// when possible (for example distinct strings). /// /// Intermediate state is used for "multi-phase" grouping in /// DataFusion, where an aggregate is computed in parallel with From aadde16a98dea02c3465c99eb7ed804825178ef0 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 22 Jan 2024 15:03:01 -0500 Subject: [PATCH 3/3] fix signature --- datafusion/core/tests/user_defined/user_defined_aggregates.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index fc3070bbd066..5dbac0322fc0 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -689,7 +689,7 @@ impl Accumulator for TestGroupsAccumulator { Ok(()) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { Ok(ScalarValue::from(self.result)) } @@ -697,7 +697,7 @@ impl Accumulator for TestGroupsAccumulator { std::mem::size_of::() } - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![ScalarValue::from(self.result)]) }