Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change Accumulator::evaluate and Accumulator::state to take &mut self #8925

Merged
merged 5 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions datafusion-examples/examples/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ impl Accumulator for GeometricMean {
// This function serializes our state to `ScalarValue`, which DataFusion uses
// to pass this state between execution stages.
// Note that this can be arbitrary data.
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an example of what DataFusion users have to do if this PR is merged -- they have to change the signatures to take mut

Ok(vec![
ScalarValue::from(self.prod),
ScalarValue::from(self.n),
Expand All @@ -134,7 +134,7 @@ impl Accumulator for GeometricMean {

// DataFusion expects this function to return the final value of this aggregator.
// in this case, this is the formula of the geometric mean
fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
let value = self.prod.powf(1.0 / self.n as f64);
Ok(ScalarValue::from(value))
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion-examples/examples/simple_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ impl Accumulator for GeometricMean {
// This function serializes our state to `ScalarValue`, which DataFusion uses
// to pass this state between execution stages.
// Note that this can be arbitrary data.
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![
ScalarValue::from(self.prod),
ScalarValue::from(self.n),
Expand All @@ -81,7 +81,7 @@ impl Accumulator for GeometricMean {

// DataFusion expects this function to return the final value of this aggregator.
// in this case, this is the formula of the geometric mean
fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
let value = self.prod.powf(1.0 / self.n as f64);
Ok(ScalarValue::from(value))
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/src/datasource/statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,11 @@ pub(crate) fn get_col_stats(
) -> Vec<ColumnStatistics> {
(0..schema.fields().len())
.map(|i| {
let max_value = match &max_values[i] {
let max_value = match max_values.get_mut(i).unwrap() {
Some(max_value) => max_value.evaluate().ok(),
None => None,
};
let min_value = match &min_values[i] {
let min_value = match min_values.get_mut(i).unwrap() {
Some(min_value) => min_value.evaluate().ok(),
None => None,
};
Expand Down
12 changes: 6 additions & 6 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ impl TimeSum {
}

impl Accumulator for TimeSum {
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![self.evaluate()?])
}

Expand All @@ -478,7 +478,7 @@ impl Accumulator for TimeSum {
self.update_batch(states)
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
println!("Evaluating to {}", self.sum);
Ok(ScalarValue::TimestampNanosecond(Some(self.sum), None))
}
Expand Down Expand Up @@ -603,14 +603,14 @@ impl FirstSelector {
}

impl Accumulator for FirstSelector {
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
let state = self.to_state().into_iter().collect::<Vec<_>>();

Ok(state)
}

/// produce the output structure
fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
Ok(self.to_scalar())
}

Expand Down Expand Up @@ -689,15 +689,15 @@ impl Accumulator for TestGroupsAccumulator {
Ok(())
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
Ok(ScalarValue::from(self.result))
}

fn size(&self) -> usize {
std::mem::size_of::<u64>()
}

fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![ScalarValue::from(self.result)])
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ async fn udaf_as_window_func() -> Result<()> {
struct MyAccumulator;

impl Accumulator for MyAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
unimplemented!()
}

Expand All @@ -266,7 +266,7 @@ async fn udaf_as_window_func() -> Result<()> {
unimplemented!()
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
unimplemented!()
}

Expand Down
23 changes: 19 additions & 4 deletions datafusion/expr/src/accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,18 @@ pub trait Accumulator: Send + Sync + Debug {
/// running sum.
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;

/// Returns the final aggregate value.
/// Returns the final aggregate value, consuming the internal state.
///
/// For example, the `SUM` accumulator maintains a running sum,
/// and `evaluate` will produce that running sum as its output.
fn evaluate(&self) -> Result<ScalarValue>;
///
/// After this call, the accumulator's internal state should be
/// equivalent to when it was first created.
///
/// This function gets `&mut self` to allow for the accumulator to build
/// arrow compatible internal state that can be returned without copying
/// when possible (for example distinct strings)
fn evaluate(&mut self) -> Result<ScalarValue>;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the core change in this PR -- make evaluate and state get a &mut reference so the Accumulator can avoid copies

This is technically possible today (with "inner mutability", for example a Mutex) but that is both not documented and quite unobvious


/// Returns the allocated size required for this accumulator, in
/// bytes, including `Self`.
Expand All @@ -72,7 +79,15 @@ pub trait Accumulator: Send + Sync + Debug {
/// the `capacity` should be used not the `len`.
fn size(&self) -> usize;

/// Returns the intermediate state of the accumulator.
/// Returns the intermediate state of the accumulator, consuming the
/// intermediate state.
///
/// After this call, the accumulator's internal state should be
/// equivalent to when it was first created.
///
/// This function gets `&mut self` to allow for the accumulator to build
/// arrow compatible internal state that can be returned without copying
/// when possible (for example distinct strings).
///
/// Intermediate state is used for "multi-phase" grouping in
/// DataFusion, where an aggregate is computed in parallel with
Expand Down Expand Up @@ -129,7 +144,7 @@ pub trait Accumulator: Send + Sync + Debug {
/// Note that [`ScalarValue::List`] can be used to pass multiple
/// values if the number of intermediate values is not known at
/// planning time (e.g. for `MEDIAN`)
fn state(&self) -> Result<Vec<ScalarValue>>;
fn state(&mut self) -> Result<Vec<ScalarValue>>;

/// Updates the accumulator's state from an `Array` containing one
/// or more intermediate values.
Expand Down
4 changes: 2 additions & 2 deletions datafusion/physical-expr/src/aggregate/approx_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,12 +244,12 @@ macro_rules! default_accumulator_impl {
Ok(())
}

fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
let value = ScalarValue::from(&self.hll);
Ok(vec![value])
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
Ok(ScalarValue::UInt64(Some(self.hll.count() as u64)))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ impl ApproxPercentileAccumulator {
}

impl Accumulator for ApproxPercentileAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(self.digest.to_scalar_state().into_iter().collect())
}

Expand All @@ -389,7 +389,7 @@ impl Accumulator for ApproxPercentileAccumulator {
Ok(())
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
if self.digest.count() == 0.0 {
return exec_err!("aggregate function needs at least one non-null element");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ impl ApproxPercentileWithWeightAccumulator {
}

impl Accumulator for ApproxPercentileWithWeightAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
self.approx_percentile_cont_accumulator.state()
}

Expand All @@ -155,7 +155,7 @@ impl Accumulator for ApproxPercentileWithWeightAccumulator {
Ok(())
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
self.approx_percentile_cont_accumulator.evaluate()
}

Expand Down
4 changes: 2 additions & 2 deletions datafusion/physical-expr/src/aggregate/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,11 @@ impl Accumulator for ArrayAggAccumulator {
Ok(())
}

fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![self.evaluate()?])
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
// Transform Vec<ListArr> to ListArr

let element_arrays: Vec<&dyn Array> =
Expand Down
4 changes: 2 additions & 2 deletions datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ impl DistinctArrayAggAccumulator {
}

impl Accumulator for DistinctArrayAggAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![self.evaluate()?])
}

Expand Down Expand Up @@ -163,7 +163,7 @@ impl Accumulator for DistinctArrayAggAccumulator {
Ok(())
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
let values: Vec<ScalarValue> = self.values.iter().cloned().collect();
let arr = ScalarValue::new_list(&values, &self.datatype);
Ok(ScalarValue::List(arr))
Expand Down
4 changes: 2 additions & 2 deletions datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,13 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
Ok(())
}

fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
let mut result = vec![self.evaluate()?];
result.push(self.evaluate_orderings()?);
Ok(result)
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
let values = self.values.clone();
let array = if self.reverse {
ScalarValue::new_list_from_iter(values.into_iter().rev(), &self.datatypes[0])
Expand Down
8 changes: 4 additions & 4 deletions datafusion/physical-expr/src/aggregate/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ pub struct AvgAccumulator {
}

impl Accumulator for AvgAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![
ScalarValue::from(self.count),
ScalarValue::Float64(self.sum),
Expand Down Expand Up @@ -276,7 +276,7 @@ impl Accumulator for AvgAccumulator {
Ok(())
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
Ok(ScalarValue::Float64(
self.sum.map(|f| f / self.count as f64),
))
Expand Down Expand Up @@ -314,7 +314,7 @@ impl<T: DecimalType + ArrowNumericType> Debug for DecimalAvgAccumulator<T> {
}

impl<T: DecimalType + ArrowNumericType> Accumulator for DecimalAvgAccumulator<T> {
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![
ScalarValue::from(self.count),
ScalarValue::new_primitive::<T>(
Expand Down Expand Up @@ -356,7 +356,7 @@ impl<T: DecimalType + ArrowNumericType> Accumulator for DecimalAvgAccumulator<T>
Ok(())
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
let v = self
.sum
.map(|v| {
Expand Down
16 changes: 8 additions & 8 deletions datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,11 @@ where
self.update_batch(states)
}

fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![self.evaluate()?])
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
}

Expand Down Expand Up @@ -339,7 +339,7 @@ impl<T: ArrowNumericType> Accumulator for BitOrAccumulator<T>
where
T::Native: std::ops::BitOr<Output = T::Native>,
{
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![self.evaluate()?])
}

Expand All @@ -355,7 +355,7 @@ where
self.update_batch(states)
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
}

Expand Down Expand Up @@ -500,7 +500,7 @@ impl<T: ArrowNumericType> Accumulator for BitXorAccumulator<T>
where
T::Native: std::ops::BitXor<Output = T::Native>,
{
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![self.evaluate()?])
}

Expand All @@ -516,7 +516,7 @@ where
self.update_batch(states)
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
}

Expand Down Expand Up @@ -634,7 +634,7 @@ impl<T: ArrowNumericType> Accumulator for DistinctBitXorAccumulator<T>
where
T::Native: std::ops::BitXor<Output = T::Native> + std::hash::Hash + Eq,
{
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
// 1. Stores aggregate state in `ScalarValue::List`
// 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set
let state_out = {
Expand Down Expand Up @@ -679,7 +679,7 @@ where
Ok(())
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
let mut acc = T::Native::usize_as(0);
for distinct_value in self.values.iter() {
acc = acc ^ *distinct_value;
Expand Down
Loading