Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle type coercion in signature for ApproxPercentileCont #12274

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 24 additions & 22 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2428,31 +2428,33 @@ mod tests {
let df: Vec<RecordBatch> = df.select(aggr_expr)?.collect().await?;

assert_batches_sorted_eq!(
["+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
[
"+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
"| first_value | last_val | approx_distinct | approx_median | median | max | min | c2 | c3 |",
"+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
"| | | | | | | | 1 | -85 |",
"| -85 | -101 | 14 | -12 | -101 | 83 | -101 | 4 | -54 |",
"| -85 | -101 | 17 | -25 | -101 | 83 | -101 | 5 | -31 |",
"| -85 | -12 | 10 | -32 | -12 | 83 | -85 | 3 | 13 |",
"| -85 | -25 | 3 | -56 | -25 | -25 | -85 | 1 | -5 |",
"| -85 | -31 | 18 | -29 | -31 | 83 | -101 | 5 | 36 |",
"| -85 | -38 | 16 | -25 | -38 | 83 | -101 | 4 | 65 |",
"| -85 | -43 | 7 | -43 | -43 | 83 | -85 | 2 | 45 |",
"| -85 | -48 | 6 | -35 | -48 | 83 | -85 | 2 | -43 |",
"| -85 | -5 | 4 | -37 | -5 | -5 | -85 | 1 | 83 |",
"| -85 | -54 | 15 | -17 | -54 | 83 | -101 | 4 | -38 |",
"| -85 | -56 | 2 | -70 | -56 | -56 | -85 | 1 | -25 |",
"| -85 | -72 | 9 | -43 | -72 | 83 | -85 | 3 | -12 |",
"| -85 | -85 | 1 | -85 | -85 | -85 | -85 | 1 | -56 |",
"| -85 | 13 | 11 | -17 | 13 | 83 | -85 | 3 | 14 |",
"| -85 | 13 | 11 | -25 | 13 | 83 | -85 | 3 | 13 |",
"| -85 | 14 | 12 | -12 | 14 | 83 | -85 | 3 | 17 |",
"| -85 | 17 | 13 | -11 | 17 | 83 | -85 | 4 | -101 |",
"| -85 | 45 | 8 | -34 | 45 | 83 | -85 | 3 | -72 |",
"| -85 | 65 | 17 | -17 | 65 | 83 | -101 | 5 | -101 |",
"| -85 | 83 | 5 | -25 | 83 | 83 | -85 | 2 | -48 |",
"+-------------+----------+-----------------+---------------+--------+-----+------+----+------+"],
"| -85 | -101 | 14 | -12.0 | -101 | 83 | -101 | 4 | -54 |",
"| -85 | -101 | 17 | -25.0 | -101 | 83 | -101 | 5 | -31 |",
"| -85 | -12 | 10 | -32.75 | -12 | 83 | -85 | 3 | 13 |",
"| -85 | -25 | 3 | -56.0 | -25 | -25 | -85 | 1 | -5 |",
"| -85 | -31 | 18 | -29.75 | -31 | 83 | -101 | 5 | 36 |",
"| -85 | -38 | 16 | -25.0 | -38 | 83 | -101 | 4 | 65 |",
"| -85 | -43 | 7 | -43.0 | -43 | 83 | -85 | 2 | 45 |",
"| -85 | -48 | 6 | -35.75 | -48 | 83 | -85 | 2 | -43 |",
"| -85 | -5 | 4 | -37.75 | -5 | -5 | -85 | 1 | 83 |",
"| -85 | -54 | 15 | -17.0 | -54 | 83 | -101 | 4 | -38 |",
"| -85 | -56 | 2 | -70.5 | -56 | -56 | -85 | 1 | -25 |",
"| -85 | -72 | 9 | -43.0 | -72 | 83 | -85 | 3 | -12 |",
"| -85 | -85 | 1 | -85.0 | -85 | -85 | -85 | 1 | -56 |",
"| -85 | 13 | 11 | -17.0 | 13 | 83 | -85 | 3 | 14 |",
"| -85 | 13 | 11 | -25.0 | 13 | 83 | -85 | 3 | 13 |",
"| -85 | 14 | 12 | -12.0 | 14 | 83 | -85 | 3 | 17 |",
"| -85 | 17 | 13 | -11.25 | 17 | 83 | -85 | 4 | -101 |",
"| -85 | 45 | 8 | -34.5 | 45 | 83 | -85 | 3 | -72 |",
"| -85 | 65 | 17 | -17.0 | 65 | 83 | -101 | 5 | -101 |",
"| -85 | 83 | 5 | -25.0 | 83 | 83 | -85 | 2 | -48 |",
"+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
],
&df
);

Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/tests/dataframe/dataframe_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ async fn test_fn_approx_median() -> Result<()> {
"+-----------------------+",
"| approx_median(test.b) |",
"+-----------------------+",
"| 10 |",
"| 10.0 |",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like a change in behavior: with this PR, the median now always returns a float, whereas before it returned the same type as its input.

This comment was marked as outdated.

Copy link
Contributor Author

@jayzhan211 jayzhan211 Sep 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I have changed the result to f64.

I think it is fine to return f64 for the median value. I checked DuckDB's behavior: it returns a double for integer input, although it returns a decimal for decimal input. Since we don't support decimal input for approx_median, there is no regression. We could support the decimal case later on.

"+-----------------------+",
];

Expand All @@ -366,7 +366,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {
"+---------------------------------------------+",
"| approx_percentile_cont(test.b,Float64(0.5)) |",
"+---------------------------------------------+",
"| 10 |",
"| 10.0 |",
"+---------------------------------------------+",
];

Expand All @@ -387,7 +387,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {
"+--------------------------------------+",
"| approx_percentile_cont(test.b,arg_2) |",
"+--------------------------------------+",
"| 10 |",
"| 10.0 |",
"+--------------------------------------+",
];
let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?;
Expand All @@ -400,7 +400,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {
"+------------------------------------------------------+",
"| approx_percentile_cont(test.b,Float64(0.5),Int32(2)) |",
"+------------------------------------------------------+",
"| 30 |",
"| 30.25 |",
"+------------------------------------------------------+",
];

Expand Down
16 changes: 8 additions & 8 deletions datafusion/functions-aggregate/src/approx_median.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@ use std::fmt::Debug;
use arrow::{datatypes::DataType, datatypes::Field};
use arrow_schema::DataType::{Float64, UInt64};

use datafusion_common::{not_impl_err, plan_err, Result};
use datafusion_common::{not_impl_err, Result};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};

Expand Down Expand Up @@ -63,7 +62,7 @@ impl ApproxMedian {
/// Create a new APPROX_MEDIAN aggregate function
pub fn new() -> Self {
Self {
signature: Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable),
signature: Signature::user_defined(Volatility::Immutable),
}
}
}
Expand Down Expand Up @@ -97,11 +96,8 @@ impl AggregateUDFImpl for ApproxMedian {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if !arg_types[0].is_numeric() {
return plan_err!("ApproxMedian requires numeric input types");
}
Ok(arg_types[0].clone())
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
Expand All @@ -116,4 +112,8 @@ impl AggregateUDFImpl for ApproxMedian {
acc_args.exprs[0].data_type(acc_args.schema)?,
)))
}

fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
Ok(vec![DataType::Float64])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of curiosity, why this instead of just defining the signature as DataType::Float64? As far as I know, DataFusion already tries to coerce inputs to the signature.

Copy link
Contributor Author

@jayzhan211 jayzhan211 Sep 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you are referring to coerced_from? I want to deprecate that function: the downside of having a single source of truth for coercion is that it makes updating the coercion rules unpredictable, and it is easy to introduce bugs without noticing (it is also hard to get full test coverage). The new approach is to handle coercion via the signature.

I think I could change the signature to Coercible(vec![Float64]) #12275

}
}
173 changes: 36 additions & 137 deletions datafusion/functions-aggregate/src/approx_percentile_cont.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,22 @@ use std::any::Any;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;

use arrow::array::{Array, RecordBatch};
use arrow::array::{Array, AsArray, RecordBatch};
use arrow::compute::{filter, is_not_null};
use arrow::{
array::{
ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
},
datatypes::DataType,
};
use arrow::datatypes::Float64Type;
use arrow::{array::ArrayRef, datatypes::DataType};
use arrow_schema::{Field, Schema};

use datafusion_common::{
downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err,
DataFusionError, Result, ScalarValue,
exec_err, internal_err, not_impl_datafusion_err, not_impl_err, plan_err, Result,
ScalarValue,
};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
Accumulator, AggregateUDFImpl, ColumnarValue, Expr, Signature, TypeSignature,
Volatility,
};
use datafusion_functions_aggregate_common::tdigest::{
TDigest, TryIntoF64, DEFAULT_MAX_SIZE,
Accumulator, AggregateUDFImpl, ColumnarValue, Expr, Signature, Volatility,
};
use datafusion_functions_aggregate_common::tdigest::{TDigest, DEFAULT_MAX_SIZE};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

create_func!(ApproxPercentileCont, approx_percentile_cont_udaf);
Expand Down Expand Up @@ -84,21 +75,8 @@ impl Default for ApproxPercentileCont {
impl ApproxPercentileCont {
/// Create a new [`ApproxPercentileCont`] aggregate function.
pub fn new() -> Self {
let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
// Accept any numeric value paired with a float64 percentile
for num in NUMERICS {
variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64]));
// Additionally accept an integer number of centroids for T-Digest
for int in INTEGERS {
variants.push(TypeSignature::Exact(vec![
num.clone(),
DataType::Float64,
int.clone(),
]))
}
}
Self {
signature: Signature::one_of(variants, Volatility::Immutable),
signature: Signature::user_defined(Volatility::Immutable),
}
}

Expand Down Expand Up @@ -156,15 +134,12 @@ fn get_scalar_value(expr: &Arc<dyn PhysicalExpr>) -> Result<ScalarValue> {
fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {
let percentile = match get_scalar_value(expr)
.map_err(|_| not_impl_datafusion_err!("Percentile value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? {
ScalarValue::Float32(Some(value)) => {
value as f64
}
ScalarValue::Float64(Some(value)) => {
value
}
sv => {
return not_impl_err!(
"Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})",
return internal_err!(
"Percentile value for 'APPROX_PERCENTILE_CONT' should be coerced to f64 (got data type {})",
sv.data_type()
)
}
Expand All @@ -182,17 +157,10 @@ fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {
fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> {
let max_size = match get_scalar_value(expr)
.map_err(|_| not_impl_datafusion_err!("Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? {
ScalarValue::UInt8(Some(q)) => q as usize,
ScalarValue::UInt16(Some(q)) => q as usize,
ScalarValue::UInt32(Some(q)) => q as usize,
ScalarValue::UInt64(Some(q)) => q as usize,
ScalarValue::Int32(Some(q)) if q > 0 => q as usize,
ScalarValue::Int64(Some(q)) if q > 0 => q as usize,
ScalarValue::Int16(Some(q)) if q > 0 => q as usize,
ScalarValue::Int8(Some(q)) if q > 0 => q as usize,
sv => {
return not_impl_err!(
"Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).",
"Tdigest max_size value for 'APPROX_PERCENTILE_CONT' should be coerced to u64 literal (got data type {}).",
sv.data_type()
)
},
Expand Down Expand Up @@ -257,16 +225,26 @@ impl AggregateUDFImpl for ApproxPercentileCont {
Ok(Box::new(self.create_accumulator(acc_args)?))
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if !arg_types[0].is_numeric() {
return plan_err!("approx_percentile_cont requires numeric input types");
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() == 3 {
// Since float is coercible to u64 in `can_cast_types`, we check whether it is integer
if !arg_types[2].is_integer() {
return exec_err!(
"3rd argument should be integer but got {}",
arg_types[2]
);
}
}
if arg_types.len() == 3 && !arg_types[2].is_integer() {
return plan_err!(
"approx_percentile_cont requires integer max_size input types"
);

if arg_types.len() == 2 {
Ok(vec![DataType::Float64; 2])
} else {
Ok(vec![DataType::Float64, DataType::Float64, DataType::UInt64])
}
Ok(arg_types[0].clone())
}
}

Expand Down Expand Up @@ -306,91 +284,8 @@ impl ApproxPercentileAccumulator {

// public for approx_percentile_cont_with_weight
pub fn convert_to_float(values: &ArrayRef) -> Result<Vec<f64>> {
match values.data_type() {
DataType::Float64 => {
let array = downcast_value!(values, Float64Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::Float32 => {
let array = downcast_value!(values, Float32Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::Int64 => {
let array = downcast_value!(values, Int64Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::Int32 => {
let array = downcast_value!(values, Int32Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::Int16 => {
let array = downcast_value!(values, Int16Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::Int8 => {
let array = downcast_value!(values, Int8Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::UInt64 => {
let array = downcast_value!(values, UInt64Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::UInt32 => {
let array = downcast_value!(values, UInt32Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::UInt16 => {
let array = downcast_value!(values, UInt16Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::UInt8 => {
let array = downcast_value!(values, UInt8Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
e => internal_err!(
"APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}"
),
}
let array = values.as_primitive::<Float64Type>();
Ok(array.values().as_ref().to_vec())
}
}

Expand All @@ -406,7 +301,11 @@ impl Accumulator for ApproxPercentileAccumulator {
values = filter(&values, &is_not_null(&values)?)?;
}
let sorted_values = &arrow::compute::sort(&values, None)?;
let sorted_values = ApproxPercentileAccumulator::convert_to_float(sorted_values)?;
let sorted_values = sorted_values
.as_primitive::<Float64Type>()
.values()
.as_ref()
.to_vec();
self.digest = self.digest.merge_sorted_f64(&sorted_values);
Ok(())
}
Expand Down
Loading