Convert Correlation to UDAF (#11064)
* init

Signed-off-by: Kevin Su <[email protected]>

* test

Signed-off-by: Kevin Su <[email protected]>

* test

Signed-off-by: Kevin Su <[email protected]>

* test

Signed-off-by: Kevin Su <[email protected]>

* remove files

Signed-off-by: Kevin Su <[email protected]>

---------

Signed-off-by: Kevin Su <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
pingsutw and alamb authored Jun 23, 2024
1 parent 8aad208 commit d32747d
Showing 18 changed files with 240 additions and 1,073 deletions.
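
With this change, `corr` is exposed as a user-defined aggregate function (UDAF) instead of a built-in `AggregateFunction` variant. Below is a minimal usage sketch of the new expression API; the `data.csv` file, the column names `y` and `x`, and the tokio runtime are hypothetical, and exact re-export paths may vary by DataFusion version.

use datafusion::error::Result;
use datafusion::prelude::{col, CsvReadOptions, SessionContext};
// `corr(y, x)` is generated by `make_udaf_expr_and_func!` in correlation.rs
use datafusion_functions_aggregate::expr_fn::corr;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // The UDAF is registered by default (see `all_default_aggregate_functions`),
    // so `SELECT corr(y, x) FROM t` should also work through SQL.
    let df = ctx
        .read_csv("data.csv", CsvReadOptions::new())
        .await?
        .aggregate(vec![], vec![corr(col("y"), col("x"))])?;
    df.show().await?;
    Ok(())
}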
11 changes: 0 additions & 11 deletions datafusion/expr/src/aggregate_function.rs
@@ -41,8 +41,6 @@ pub enum AggregateFunction {
ArrayAgg,
/// N'th value in a group according to some ordering
NthValue,
/// Correlation
Correlation,
/// Grouping
Grouping,
}
@@ -55,7 +53,6 @@ impl AggregateFunction {
Max => "MAX",
ArrayAgg => "ARRAY_AGG",
NthValue => "NTH_VALUE",
Correlation => "CORR",
Grouping => "GROUPING",
}
}
@@ -76,8 +73,6 @@ impl FromStr for AggregateFunction {
"min" => AggregateFunction::Min,
"array_agg" => AggregateFunction::ArrayAgg,
"nth_value" => AggregateFunction::NthValue,
// statistical
"corr" => AggregateFunction::Correlation,
// other
"grouping" => AggregateFunction::Grouping,
_ => {
@@ -115,9 +110,6 @@ impl AggregateFunction {
// The coerced_data_types is the same as input_types.
Ok(coerced_data_types[0].clone())
}
AggregateFunction::Correlation => {
correlation_return_type(&coerced_data_types[0])
}
AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new(
"item",
coerced_data_types[0].clone(),
@@ -150,9 +142,6 @@ impl AggregateFunction {
Signature::uniform(1, valid, Volatility::Immutable)
}
AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable),
AggregateFunction::Correlation => {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
}
}
}
11 changes: 0 additions & 11 deletions datafusion/expr/src/type_coercion/aggregates.rs
@@ -91,7 +91,6 @@ pub fn coerce_types(
input_types: &[DataType],
signature: &Signature,
) -> Result<Vec<DataType>> {
use DataType::*;
// Validate input_types matches (at least one of) the func signature.
check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?;

@@ -102,16 +101,6 @@
// unpack the dictionary to get the value
get_min_max_result_type(input_types)
}
AggregateFunction::Correlation => {
if !is_correlation_support_arg_type(&input_types[0]) {
return plan_err!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun,
input_types[0]
);
}
Ok(vec![Float64, Float64])
}
AggregateFunction::NthValue => Ok(input_types.to_vec()),
AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
}
225 changes: 225 additions & 0 deletions datafusion/functions-aggregate/src/correlation.rs
@@ -0,0 +1,225 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`Correlation`]: correlation sample aggregations.
use std::any::Any;
use std::fmt::Debug;

use arrow::compute::{and, filter, is_not_null};
use arrow::{
array::ArrayRef,
datatypes::{DataType, Field},
};

use crate::covariance::CovarianceAccumulator;
use crate::stddev::StddevAccumulator;
use datafusion_common::{plan_err, Result, ScalarValue};
use datafusion_expr::{
function::{AccumulatorArgs, StateFieldsArgs},
type_coercion::aggregates::NUMERICS,
utils::format_state_name,
Accumulator, AggregateUDFImpl, Signature, Volatility,
};
use datafusion_physical_expr_common::aggregate::stats::StatsType;

make_udaf_expr_and_func!(
Correlation,
corr,
y x,
"Correlation between two numeric values.",
corr_udaf
);

#[derive(Debug)]
pub struct Correlation {
signature: Signature,
}

impl Default for Correlation {
fn default() -> Self {
Self::new()
}
}

impl Correlation {
/// Create a new CORR aggregate function
pub fn new() -> Self {
Self {
signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
}
}
}

impl AggregateUDFImpl for Correlation {
/// Return a reference to Any that can be used for downcasting
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"corr"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if !arg_types[0].is_numeric() {
return plan_err!("Correlation requires numeric input types");
}

Ok(DataType::Float64)
}

fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(CorrelationAccumulator::try_new()?))
}

fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
let name = args.name;
Ok(vec![
Field::new(format_state_name(name, "count"), DataType::UInt64, true),
Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
Field::new(format_state_name(name, "m2_1"), DataType::Float64, true),
Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
Field::new(format_state_name(name, "m2_2"), DataType::Float64, true),
Field::new(
format_state_name(name, "algo_const"),
DataType::Float64,
true,
),
])
}
}

/// An accumulator to compute correlation
#[derive(Debug)]
pub struct CorrelationAccumulator {
covar: CovarianceAccumulator,
stddev1: StddevAccumulator,
stddev2: StddevAccumulator,
}

impl CorrelationAccumulator {
/// Creates a new `CorrelationAccumulator`
pub fn try_new() -> Result<Self> {
Ok(Self {
covar: CovarianceAccumulator::try_new(StatsType::Population)?,
stddev1: StddevAccumulator::try_new(StatsType::Population)?,
stddev2: StddevAccumulator::try_new(StatsType::Population)?,
})
}
}

impl Accumulator for CorrelationAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
// TODO: null input skipping logic duplicated across Correlation
// and its children accumulators.
// This could be simplified by splitting up input filtering and
// calculation logic in children accumulators, and calling only
// calculation part from Correlation
let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
let values1 = filter(&values[0], &mask)?;
let values2 = filter(&values[1], &mask)?;

vec![values1, values2]
} else {
values.to_vec()
};

self.covar.update_batch(&values)?;
self.stddev1.update_batch(&values[0..1])?;
self.stddev2.update_batch(&values[1..2])?;
Ok(())
}

fn evaluate(&mut self) -> Result<ScalarValue> {
let covar = self.covar.evaluate()?;
let stddev1 = self.stddev1.evaluate()?;
let stddev2 = self.stddev2.evaluate()?;

if let ScalarValue::Float64(Some(c)) = covar {
if let ScalarValue::Float64(Some(s1)) = stddev1 {
if let ScalarValue::Float64(Some(s2)) = stddev2 {
if s1 == 0_f64 || s2 == 0_f64 {
return Ok(ScalarValue::Float64(Some(0_f64)));
} else {
return Ok(ScalarValue::Float64(Some(c / s1 / s2)));
}
}
}
}

Ok(ScalarValue::Float64(None))
}

fn size(&self) -> usize {
std::mem::size_of_val(self) - std::mem::size_of_val(&self.covar)
+ self.covar.size()
- std::mem::size_of_val(&self.stddev1)
+ self.stddev1.size()
- std::mem::size_of_val(&self.stddev2)
+ self.stddev2.size()
}

fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![
ScalarValue::from(self.covar.get_count()),
ScalarValue::from(self.covar.get_mean1()),
ScalarValue::from(self.stddev1.get_m2()),
ScalarValue::from(self.covar.get_mean2()),
ScalarValue::from(self.stddev2.get_m2()),
ScalarValue::from(self.covar.get_algo_const()),
])
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
let states_c = [
states[0].clone(),
states[1].clone(),
states[3].clone(),
states[5].clone(),
];
let states_s1 = [states[0].clone(), states[1].clone(), states[2].clone()];
let states_s2 = [states[0].clone(), states[3].clone(), states[4].clone()];

self.covar.merge_batch(&states_c)?;
self.stddev1.merge_batch(&states_s1)?;
self.stddev2.merge_batch(&states_s2)?;
Ok(())
}

fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
let values1 = filter(&values[0], &mask)?;
let values2 = filter(&values[1], &mask)?;

vec![values1, values2]
} else {
values.to_vec()
};

self.covar.retract_batch(&values)?;
self.stddev1.retract_batch(&values[0..1])?;
self.stddev2.retract_batch(&values[1..2])?;
Ok(())
}
}
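
For reference, `CorrelationAccumulator::evaluate` above combines its child accumulators into the population Pearson correlation coefficient, returning 0 when either standard deviation is zero and NULL when any component is NULL:

$$ \mathrm{corr}(x, y) = \frac{\mathrm{cov}_{\mathrm{pop}}(x, y)}{\sigma_{\mathrm{pop}}(x)\,\sigma_{\mathrm{pop}}(y)} $$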
8 changes: 6 additions & 2 deletions datafusion/functions-aggregate/src/lib.rs
@@ -56,6 +56,7 @@
pub mod macros;

pub mod approx_distinct;
pub mod correlation;
pub mod count;
pub mod covariance;
pub mod first_last;
@@ -73,6 +74,7 @@ pub mod average;
pub mod bit_and_or_xor;
pub mod bool_and_or;
pub mod string_agg;

use crate::approx_percentile_cont::approx_percentile_cont_udaf;
use crate::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight_udaf;
use datafusion_common::Result;
@@ -93,6 +95,7 @@ pub mod expr_fn {
pub use super::bit_and_or_xor::bit_xor;
pub use super::bool_and_or::bool_and;
pub use super::bool_and_or::bool_or;
pub use super::correlation::corr;
pub use super::count::count;
pub use super::count::count_distinct;
pub use super::covariance::covar_pop;
@@ -122,8 +125,9 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
first_last::first_value_udaf(),
first_last::last_value_udaf(),
covariance::covar_samp_udaf(),
sum::sum_udaf(),
covariance::covar_pop_udaf(),
correlation::corr_udaf(),
sum::sum_udaf(),
median::median_udaf(),
count::count_udaf(),
regr::regr_slope_udaf(),
@@ -179,7 +183,7 @@ mod tests {
let mut names = HashSet::new();
for func in all_default_aggregate_functions() {
// TODO: remove this
// These functions are in intermidiate migration state, skip them
// These functions are in intermediate migration state, skip them
if func.name().to_lowercase() == "count" {
continue;
}
11 changes: 0 additions & 11 deletions datafusion/physical-expr/src/aggregate/build_in.rs
@@ -107,17 +107,6 @@ pub fn create_aggregate_expr(
name,
data_type,
)),
(AggregateFunction::Correlation, false) => {
Arc::new(expressions::Correlation::new(
input_phy_exprs[0].clone(),
input_phy_exprs[1].clone(),
name,
data_type,
))
}
(AggregateFunction::Correlation, true) => {
return not_impl_err!("CORR(DISTINCT) aggregations are not available");
}
(AggregateFunction::NthValue, _) => {
let expr = &input_phy_exprs[0];
let Some(n) = input_phy_exprs[1]
