Skip to content

Commit

Permalink
Change Accumulator::evaluate and Accumulator::state to take `&mut…
Browse files Browse the repository at this point in the history
… self` (#8925)

* Change `Accumulator::evaluate` and `Accumulator::state` to take `&mut self`

* improve docs

* fix signature
  • Loading branch information
alamb authored Jan 24, 2024
1 parent 2b84877 commit 5d70c32
Show file tree
Hide file tree
Showing 36 changed files with 129 additions and 114 deletions.
4 changes: 2 additions & 2 deletions datafusion-examples/examples/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ impl Accumulator for GeometricMean {
// This function serializes our state to `ScalarValue`, which DataFusion uses
// to pass this state between execution stages.
// Note that this can be arbitrary data.
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![
ScalarValue::from(self.prod),
ScalarValue::from(self.n),
Expand All @@ -134,7 +134,7 @@ impl Accumulator for GeometricMean {

// DataFusion expects this function to return the final value of this aggregator.
// in this case, this is the formula of the geometric mean
fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
let value = self.prod.powf(1.0 / self.n as f64);
Ok(ScalarValue::from(value))
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion-examples/examples/simple_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ impl Accumulator for GeometricMean {
// This function serializes our state to `ScalarValue`, which DataFusion uses
// to pass this state between execution stages.
// Note that this can be arbitrary data.
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![
ScalarValue::from(self.prod),
ScalarValue::from(self.n),
Expand All @@ -81,7 +81,7 @@ impl Accumulator for GeometricMean {

// DataFusion expects this function to return the final value of this aggregator.
// in this case, this is the formula of the geometric mean
fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
let value = self.prod.powf(1.0 / self.n as f64);
Ok(ScalarValue::from(value))
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/src/datasource/statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,11 @@ pub(crate) fn get_col_stats(
) -> Vec<ColumnStatistics> {
(0..schema.fields().len())
.map(|i| {
let max_value = match &max_values[i] {
let max_value = match max_values.get_mut(i).unwrap() {
Some(max_value) => max_value.evaluate().ok(),
None => None,
};
let min_value = match &min_values[i] {
let min_value = match min_values.get_mut(i).unwrap() {
Some(min_value) => min_value.evaluate().ok(),
None => None,
};
Expand Down
12 changes: 6 additions & 6 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ impl TimeSum {
}

impl Accumulator for TimeSum {
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![self.evaluate()?])
}

Expand All @@ -478,7 +478,7 @@ impl Accumulator for TimeSum {
self.update_batch(states)
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
println!("Evaluating to {}", self.sum);
Ok(ScalarValue::TimestampNanosecond(Some(self.sum), None))
}
Expand Down Expand Up @@ -603,14 +603,14 @@ impl FirstSelector {
}

impl Accumulator for FirstSelector {
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
let state = self.to_state().into_iter().collect::<Vec<_>>();

Ok(state)
}

/// produce the output structure
fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
Ok(self.to_scalar())
}

Expand Down Expand Up @@ -689,15 +689,15 @@ impl Accumulator for TestGroupsAccumulator {
Ok(())
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
Ok(ScalarValue::from(self.result))
}

fn size(&self) -> usize {
std::mem::size_of::<u64>()
}

fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![ScalarValue::from(self.result)])
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ async fn udaf_as_window_func() -> Result<()> {
struct MyAccumulator;

impl Accumulator for MyAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
unimplemented!()
}

Expand All @@ -266,7 +266,7 @@ async fn udaf_as_window_func() -> Result<()> {
unimplemented!()
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
unimplemented!()
}

Expand Down
23 changes: 19 additions & 4 deletions datafusion/expr/src/accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,18 @@ pub trait Accumulator: Send + Sync + Debug {
/// running sum.
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;

/// Returns the final aggregate value.
/// Returns the final aggregate value, consuming the internal state.
///
/// For example, the `SUM` accumulator maintains a running sum,
/// and `evaluate` will produce that running sum as its output.
fn evaluate(&self) -> Result<ScalarValue>;
///
/// After this call, the accumulator's internal state should be
/// equivalent to when it was first created.
///
/// This function gets `&mut self` to allow for the accumulator to build
/// arrow compatible internal state that can be returned without copying
/// when possible (for example distinct strings)
fn evaluate(&mut self) -> Result<ScalarValue>;

/// Returns the allocated size required for this accumulator, in
/// bytes, including `Self`.
Expand All @@ -72,7 +79,15 @@ pub trait Accumulator: Send + Sync + Debug {
/// the `capacity` should be used not the `len`.
fn size(&self) -> usize;

/// Returns the intermediate state of the accumulator.
/// Returns the intermediate state of the accumulator, consuming the
/// intermediate state.
///
/// After this call, the accumulator's internal state should be
/// equivalent to when it was first created.
///
/// This function gets `&mut self` to allow for the accumulator to build
/// arrow compatible internal state that can be returned without copying
/// when possible (for example distinct strings).
///
/// Intermediate state is used for "multi-phase" grouping in
/// DataFusion, where an aggregate is computed in parallel with
Expand Down Expand Up @@ -129,7 +144,7 @@ pub trait Accumulator: Send + Sync + Debug {
/// Note that [`ScalarValue::List`] can be used to pass multiple
/// values if the number of intermediate values is not known at
/// planning time (e.g. for `MEDIAN`)
fn state(&self) -> Result<Vec<ScalarValue>>;
fn state(&mut self) -> Result<Vec<ScalarValue>>;

/// Updates the accumulator's state from an `Array` containing one
/// or more intermediate values.
Expand Down
4 changes: 2 additions & 2 deletions datafusion/physical-expr/src/aggregate/approx_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,12 +244,12 @@ macro_rules! default_accumulator_impl {
Ok(())
}

fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
let value = ScalarValue::from(&self.hll);
Ok(vec![value])
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
Ok(ScalarValue::UInt64(Some(self.hll.count() as u64)))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ impl ApproxPercentileAccumulator {
}

impl Accumulator for ApproxPercentileAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(self.digest.to_scalar_state().into_iter().collect())
}

Expand All @@ -389,7 +389,7 @@ impl Accumulator for ApproxPercentileAccumulator {
Ok(())
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
if self.digest.count() == 0.0 {
return exec_err!("aggregate function needs at least one non-null element");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ impl ApproxPercentileWithWeightAccumulator {
}

impl Accumulator for ApproxPercentileWithWeightAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
self.approx_percentile_cont_accumulator.state()
}

Expand All @@ -155,7 +155,7 @@ impl Accumulator for ApproxPercentileWithWeightAccumulator {
Ok(())
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
self.approx_percentile_cont_accumulator.evaluate()
}

Expand Down
4 changes: 2 additions & 2 deletions datafusion/physical-expr/src/aggregate/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,11 @@ impl Accumulator for ArrayAggAccumulator {
Ok(())
}

fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![self.evaluate()?])
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
// Transform Vec<ListArr> to ListArr

let element_arrays: Vec<&dyn Array> =
Expand Down
4 changes: 2 additions & 2 deletions datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ impl DistinctArrayAggAccumulator {
}

impl Accumulator for DistinctArrayAggAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![self.evaluate()?])
}

Expand Down Expand Up @@ -163,7 +163,7 @@ impl Accumulator for DistinctArrayAggAccumulator {
Ok(())
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
let values: Vec<ScalarValue> = self.values.iter().cloned().collect();
let arr = ScalarValue::new_list(&values, &self.datatype);
Ok(ScalarValue::List(arr))
Expand Down
4 changes: 2 additions & 2 deletions datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,13 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
Ok(())
}

fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
let mut result = vec![self.evaluate()?];
result.push(self.evaluate_orderings()?);
Ok(result)
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
let values = self.values.clone();
let array = if self.reverse {
ScalarValue::new_list_from_iter(values.into_iter().rev(), &self.datatypes[0])
Expand Down
8 changes: 4 additions & 4 deletions datafusion/physical-expr/src/aggregate/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ pub struct AvgAccumulator {
}

impl Accumulator for AvgAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![
ScalarValue::from(self.count),
ScalarValue::Float64(self.sum),
Expand Down Expand Up @@ -276,7 +276,7 @@ impl Accumulator for AvgAccumulator {
Ok(())
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
Ok(ScalarValue::Float64(
self.sum.map(|f| f / self.count as f64),
))
Expand Down Expand Up @@ -314,7 +314,7 @@ impl<T: DecimalType + ArrowNumericType> Debug for DecimalAvgAccumulator<T> {
}

impl<T: DecimalType + ArrowNumericType> Accumulator for DecimalAvgAccumulator<T> {
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![
ScalarValue::from(self.count),
ScalarValue::new_primitive::<T>(
Expand Down Expand Up @@ -356,7 +356,7 @@ impl<T: DecimalType + ArrowNumericType> Accumulator for DecimalAvgAccumulator<T>
Ok(())
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
let v = self
.sum
.map(|v| {
Expand Down
16 changes: 8 additions & 8 deletions datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,11 @@ where
self.update_batch(states)
}

fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![self.evaluate()?])
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
}

Expand Down Expand Up @@ -339,7 +339,7 @@ impl<T: ArrowNumericType> Accumulator for BitOrAccumulator<T>
where
T::Native: std::ops::BitOr<Output = T::Native>,
{
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![self.evaluate()?])
}

Expand All @@ -355,7 +355,7 @@ where
self.update_batch(states)
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
}

Expand Down Expand Up @@ -500,7 +500,7 @@ impl<T: ArrowNumericType> Accumulator for BitXorAccumulator<T>
where
T::Native: std::ops::BitXor<Output = T::Native>,
{
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![self.evaluate()?])
}

Expand All @@ -516,7 +516,7 @@ where
self.update_batch(states)
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
}

Expand Down Expand Up @@ -634,7 +634,7 @@ impl<T: ArrowNumericType> Accumulator for DistinctBitXorAccumulator<T>
where
T::Native: std::ops::BitXor<Output = T::Native> + std::hash::Hash + Eq,
{
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
// 1. Stores aggregate state in `ScalarValue::List`
// 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set
let state_out = {
Expand Down Expand Up @@ -679,7 +679,7 @@ where
Ok(())
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
let mut acc = T::Native::usize_as(0);
for distinct_value in self.values.iter() {
acc = acc ^ *distinct_value;
Expand Down
Loading

0 comments on commit 5d70c32

Please sign in to comment.