From 17b211d7594189424e062421daf45c42d25e1008 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 1 Sep 2024 14:44:02 +0800 Subject: [PATCH 1/3] change to udf signature Signed-off-by: jayzhan211 --- .../functions-aggregate/src/approx_median.rs | 19 +- .../src/approx_percentile_cont.rs | 187 +++++------------- .../src/approx_percentile_cont_with_weight.rs | 86 ++++---- .../sqllogictest/test_files/aggregate.slt | 70 +++---- 4 files changed, 150 insertions(+), 212 deletions(-) diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index 7a7b12432544..b573c47413b4 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -20,12 +20,12 @@ use std::any::Any; use std::fmt::Debug; +use arrow::compute::can_cast_types; use arrow::{datatypes::DataType, datatypes::Field}; use arrow_schema::DataType::{Float64, UInt64}; -use datafusion_common::{not_impl_err, plan_err, Result}; +use datafusion_common::{exec_err, not_impl_err, plan_err, Result}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; @@ -63,7 +63,7 @@ impl ApproxMedian { /// Create a new APPROX_MEDIAN aggregate function pub fn new() -> Self { Self { - signature: Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable), + signature: Signature::user_defined(Volatility::Immutable), } } } @@ -116,4 +116,17 @@ impl AggregateUDFImpl for ApproxMedian { acc_args.exprs[0].data_type(acc_args.schema)?, ))) } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 1 { + return exec_err!("Expect to get single argument"); + } + + if arg_types[0].is_numeric() && !can_cast_types(&arg_types[0], &DataType::Float64) + { + return exec_err!("1st argument {} is not coercible to f64", arg_types[0]); + } + + Ok(vec![DataType::Float64]) + } } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 5578aebbf403..7ee6bf653ffa 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -19,31 +19,22 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; use std::sync::Arc; -use arrow::array::{Array, RecordBatch}; -use arrow::compute::{filter, is_not_null}; -use arrow::{ - array::{ - ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, - Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, - }, - datatypes::DataType, -}; +use arrow::array::{Array, AsArray, RecordBatch}; +use arrow::compute::{can_cast_types, filter, is_not_null}; +use arrow::datatypes::Float64Type; +use arrow::{array::ArrayRef, datatypes::DataType}; use arrow_schema::{Field, Schema}; use datafusion_common::{ - downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err, - DataFusionError, Result, ScalarValue, + exec_err, internal_err, not_impl_datafusion_err, not_impl_err, plan_err, Result, + ScalarValue, }; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, ColumnarValue, Expr, Signature, TypeSignature, - Volatility, -}; -use datafusion_functions_aggregate_common::tdigest::{ - TDigest, TryIntoF64, DEFAULT_MAX_SIZE, + Accumulator, AggregateUDFImpl, ColumnarValue, Expr, Signature, Volatility, }; +use datafusion_functions_aggregate_common::tdigest::{TDigest, DEFAULT_MAX_SIZE}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; create_func!(ApproxPercentileCont, approx_percentile_cont_udaf); @@ -84,21 +75,8 @@ impl Default for ApproxPercentileCont { impl ApproxPercentileCont { /// Create a new [`ApproxPercentileCont`] aggregate function. pub fn new() -> Self { - let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1)); - // Accept any numeric value paired with a float64 percentile - for num in NUMERICS { - variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64])); - // Additionally accept an integer number of centroids for T-Digest - for int in INTEGERS { - variants.push(TypeSignature::Exact(vec![ - num.clone(), - DataType::Float64, - int.clone(), - ])) - } - } Self { - signature: Signature::one_of(variants, Volatility::Immutable), + signature: Signature::user_defined(Volatility::Immutable), } } @@ -156,15 +134,12 @@ fn get_scalar_value(expr: &Arc) -> Result { fn validate_input_percentile_expr(expr: &Arc) -> Result { let percentile = match get_scalar_value(expr) .map_err(|_| not_impl_datafusion_err!("Percentile value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? { - ScalarValue::Float32(Some(value)) => { - value as f64 - } ScalarValue::Float64(Some(value)) => { value } sv => { - return not_impl_err!( - "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})", + return internal_err!( + "Percentile value for 'APPROX_PERCENTILE_CONT' should be coerced to f64 (got data type {})", sv.data_type() ) } @@ -182,17 +157,10 @@ fn validate_input_percentile_expr(expr: &Arc) -> Result { fn validate_input_max_size_expr(expr: &Arc) -> Result { let max_size = match get_scalar_value(expr) .map_err(|_| not_impl_datafusion_err!("Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? { - ScalarValue::UInt8(Some(q)) => q as usize, - ScalarValue::UInt16(Some(q)) => q as usize, - ScalarValue::UInt32(Some(q)) => q as usize, ScalarValue::UInt64(Some(q)) => q as usize, - ScalarValue::Int32(Some(q)) if q > 0 => q as usize, - ScalarValue::Int64(Some(q)) if q > 0 => q as usize, - ScalarValue::Int16(Some(q)) if q > 0 => q as usize, - ScalarValue::Int8(Some(q)) if q > 0 => q as usize, sv => { return not_impl_err!( - "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).", + "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' should be coerced to u64 literal (got data type {}).", sv.data_type() ) }, @@ -257,16 +225,38 @@ impl AggregateUDFImpl for ApproxPercentileCont { Ok(Box::new(self.create_accumulator(acc_args)?)) } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("approx_percentile_cont requires numeric input types"); + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 2 && arg_types.len() != 3 { + return exec_err!("Expect to get 2 or 3 args"); + } + + // Check `is_numeric` to filter out numeric string case + if !arg_types[0].is_numeric() + || !can_cast_types(&arg_types[0], &DataType::Float64) + { + return exec_err!("1st argument {} is not coercible to f64", arg_types[0]); + } + if !arg_types[1].is_numeric() + || !can_cast_types(&arg_types[1], &DataType::Float64) + { + return exec_err!("2nd argument {} is not coercible to f64", arg_types[1]); + } + if arg_types.len() == 3 + && (!arg_types[2].is_integer() + || !can_cast_types(&arg_types[2], &DataType::UInt64)) + { + return exec_err!("3rd argument {} is not coercible to u64", arg_types[2]); } - if arg_types.len() == 3 && !arg_types[2].is_integer() { - return plan_err!( - "approx_percentile_cont requires integer max_size input types" - ); + + if arg_types.len() == 2 { + Ok(vec![DataType::Float64; 2]) + } else { + Ok(vec![DataType::Float64, DataType::Float64, DataType::UInt64]) } - Ok(arg_types[0].clone()) } } @@ -306,91 +296,8 @@ impl ApproxPercentileAccumulator { // public for approx_percentile_cont_with_weight pub fn convert_to_float(values: &ArrayRef) -> Result> { - match values.data_type() { - DataType::Float64 => { - let array = downcast_value!(values, Float64Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - DataType::Float32 => { - let array = downcast_value!(values, Float32Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - DataType::Int64 => { - let array = downcast_value!(values, Int64Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - DataType::Int32 => { - let array = downcast_value!(values, Int32Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - DataType::Int16 => { - let array = downcast_value!(values, Int16Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - DataType::Int8 => { - let array = downcast_value!(values, Int8Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - DataType::UInt64 => { - let array = downcast_value!(values, UInt64Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - DataType::UInt32 => { - let array = downcast_value!(values, UInt32Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - DataType::UInt16 => { - let array = downcast_value!(values, UInt16Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - DataType::UInt8 => { - let array = downcast_value!(values, UInt8Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - e => internal_err!( - "APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}" - ), - } + let array = values.as_primitive::(); + Ok(array.values().as_ref().to_vec()) } } @@ -406,7 +313,11 @@ impl Accumulator for ApproxPercentileAccumulator { values = filter(&values, &is_not_null(&values)?)?; } let sorted_values = &arrow::compute::sort(&values, None)?; - let sorted_values = ApproxPercentileAccumulator::convert_to_float(sorted_values)?; + let sorted_values = sorted_values + .as_primitive::() + .values() + .as_ref() + .to_vec(); self.digest = self.digest.merge_sorted_f64(&sorted_values); Ok(()) } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index fee67ba1623d..dc206f3b0ee7 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -19,17 +19,19 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; use std::sync::Arc; +use arrow::array::AsArray; +use arrow::compute::can_cast_types; +use arrow::datatypes::Float64Type; use arrow::{ array::ArrayRef, datatypes::{DataType, Field}, }; -use datafusion_common::ScalarValue; +use datafusion_common::{exec_err, ScalarValue}; use datafusion_common::{not_impl_err, plan_err, Result}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::Volatility::Immutable; -use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, TypeSignature}; +use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature}; use datafusion_functions_aggregate_common::tdigest::{ Centroid, TDigest, DEFAULT_MAX_SIZE, }; @@ -68,20 +70,7 @@ impl ApproxPercentileContWithWeight { /// Create a new [`ApproxPercentileContWithWeight`] aggregate function. pub fn new() -> Self { Self { - signature: Signature::one_of( - // Accept any numeric value paired with a float64 percentile - NUMERICS - .iter() - .map(|t| { - TypeSignature::Exact(vec![ - t.clone(), - t.clone(), - DataType::Float64, - ]) - }) - .collect(), - Immutable, - ), + signature: Signature::user_defined(Immutable), approx_percentile_cont: ApproxPercentileCont::new(), } } @@ -100,21 +89,8 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!( - "approx_percentile_cont_with_weight requires numeric input types" - ); - } - if !arg_types[1].is_numeric() { - return plan_err!( - "approx_percentile_cont_with_weight requires numeric weight input types" - ); - } - if arg_types[2] != DataType::Float64 { - return plan_err!("approx_percentile_cont_with_weight requires float64 percentile input types"); - } - Ok(arg_types[0].clone()) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { @@ -151,6 +127,34 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight { fn state_fields(&self, args: StateFieldsArgs) -> Result> { self.approx_percentile_cont.state_fields(args) } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 2 && arg_types.len() != 3 { + return exec_err!("Expect to get 2 or 3 args"); + } + + // Check `is_numeric` to filter out numeric string case + if arg_types[0].is_numeric() + && (!arg_types[0].is_numeric() + || !can_cast_types(&arg_types[0], &DataType::Float64)) + { + return exec_err!("1st argument {} is not coercible to f64", arg_types[0]); + } + if arg_types[1].is_numeric() + && (!arg_types[1].is_numeric() + || !can_cast_types(&arg_types[1], &DataType::Float64)) + { + return exec_err!("2nd argument {} is not coercible to f64", arg_types[1]); + } + if arg_types.len() == 3 + && (!arg_types[2].is_numeric() + || !can_cast_types(&arg_types[2], &DataType::Float64)) + { + return exec_err!("3rd argument {} is not coercible to f64", arg_types[2]); + } + + Ok(vec![DataType::Float64; arg_types.len()]) + } } #[derive(Debug)] @@ -179,13 +183,23 @@ impl Accumulator for ApproxPercentileWithWeightAccumulator { weights.len(), "invalid number of values in means and weights" ); - let means_f64 = ApproxPercentileAccumulator::convert_to_float(means)?; - let weights_f64 = ApproxPercentileAccumulator::convert_to_float(weights)?; + + let means = means + .as_primitive::() + .values() + .as_ref() + .to_vec(); + let weights = weights + .as_primitive::() + .values() + .as_ref() + .to_vec(); + let mut digests: Vec = vec![]; - for (mean, weight) in means_f64.iter().zip(weights_f64.iter()) { + for (mean, weight) in means.into_iter().zip(weights.into_iter()) { digests.push(TDigest::new_with_centroid( DEFAULT_MAX_SIZE, - Centroid::new(*mean, *weight), + Centroid::new(mean, weight), )) } self.approx_percentile_cont_accumulator diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 45cb4d4615d7..6a477f2809cb 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -76,26 +76,26 @@ statement error DataFusion error: Schema error: Schema contains duplicate unqual SELECT approx_distinct(c9) count_c9, approx_distinct(cast(c9 as varchar)) count_c9_str FROM aggregate_test_100 # csv_query_approx_percentile_cont_with_weight -statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Utf8, Int8, Float64\] to the signature OneOf(.*) failed(.|\n)* +statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'c' to value of Float64 type SELECT approx_percentile_cont_with_weight(c1, c2, 0.95) FROM aggregate_test_100 -statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Int16, Utf8, Float64\] to the signature OneOf(.*) failed(.|\n)* +statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'c' to value of Float64 type SELECT approx_percentile_cont_with_weight(c3, c1, 0.95) FROM aggregate_test_100 -statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Int16, Int8, Utf8\] to the signature OneOf(.*) failed(.|\n)* +statement error 3rd argument Utf8 is not coercible to f64 SELECT approx_percentile_cont_with_weight(c3, c2, c1) FROM aggregate_test_100 # csv_query_approx_percentile_cont_with_histogram_bins -statement error DataFusion error: External error: This feature is not implemented: Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal \(got data type Int64\)\. +statement error DataFusion error: External error: Arrow error: Cast error: Can't cast value \-1000 to type UInt64 SELECT c1, approx_percentile_cont(c3, 0.95, -1000) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 -statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Int16, Float64, Utf8\] to the signature OneOf(.*) failed(.|\n)* +query error 3rd argument Utf8 is not coercible to u64 SELECT approx_percentile_cont(c3, 0.95, c1) FROM aggregate_test_100 -statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Int16, Float64, Float64\] to the signature OneOf(.*) failed(.|\n)* +query error 3rd argument Float64 is not coercible to u64 SELECT approx_percentile_cont(c3, 0.95, 111.1) FROM aggregate_test_100 -statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Float64, Float64, Float64\] to the signature OneOf(.*) failed(.|\n)* +query error 3rd argument Float64 is not coercible to u64 SELECT approx_percentile_cont(c12, 0.95, 111.1) FROM aggregate_test_100 statement error DataFusion error: This feature is not implemented: Percentile value for 'APPROX_PERCENTILE_CONT' must be a literal @@ -591,16 +591,16 @@ SELECT c2, var_samp(CASE WHEN c12 > 0.90 THEN c12 ELSE null END) FROM aggregate_ # csv_query_approx_median_1 -query I +query R SELECT approx_median(c2) FROM aggregate_test_100 ---- 3 # csv_query_approx_median_2 -query I +query R SELECT approx_median(c6) FROM aggregate_test_100 ---- -1146409980542786560 +1146409980542786600 # csv_query_approx_median_3 query R @@ -649,7 +649,7 @@ SELECT median(col_i8), median(distinct col_i8) FROM median_table -14 100 # approx_distinct_median_i8 -query I +query R SELECT approx_median(distinct col_i8) FROM median_table ---- 100 @@ -1328,13 +1328,13 @@ SELECT (ABS(1 - CAST(approx_percentile_cont(c11, 0.9) AS DOUBLE) / 0.834) < 0.05 true # percentile_cont_with_nulls -query I +query R SELECT APPROX_PERCENTILE_CONT(v, 0.5) FROM (VALUES (1), (2), (3), (NULL), (NULL), (NULL)) as t (v); ---- 2 # percentile_cont_with_nulls_only -query I +query R SELECT APPROX_PERCENTILE_CONT(v, 0.5) FROM (VALUES (CAST(NULL as INT))) as t (v); ---- NULL @@ -1596,43 +1596,43 @@ b 5 NULL 20135.4 b NULL NULL 7732.315789473684 # csv_query_approx_percentile_cont_with_weight -query TI +query TR SELECT c1, approx_percentile_cont(c3, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- -a 73 +a 73.55 b 68 -c 122 -d 124 -e 115 +c 122.5 +d 124.2 +e 115.6 # csv_query_approx_percentile_cont_with_weight (2) -query TI +query TR SELECT c1, approx_percentile_cont_with_weight(c3, 1, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- -a 73 +a 73.55 b 68 -c 122 -d 124 -e 115 +c 122.5 +d 124.2 +e 115.6 # csv_query_approx_percentile_cont_with_histogram_bins -query TI +query TR SELECT c1, approx_percentile_cont(c3, 0.95, 200) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- -a 73 +a 73.55 b 68 -c 122 -d 124 -e 115 +c 122.5 +d 124.2 +e 115.6 -query TI +query TR SELECT c1, approx_percentile_cont_with_weight(c3, c2, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- -a 74 +a 74.2625 b 68 c 123 -d 124 -e 115 +d 124.2 +e 115.866666666667 # csv_query_sum_crossjoin query TTI @@ -2864,10 +2864,10 @@ SELECT COUNT(DISTINCT c1) FROM test # TODO: aggregate_with_alias # test_approx_percentile_cont_decimal_support -query TI +query TR SELECT c1, approx_percentile_cont(c2, cast(0.85 as decimal(10,2))) apc FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- -a 4 +a 4.175 b 5 c 4 d 4 @@ -4288,7 +4288,7 @@ select median(a) from (select 1 as a where 1=0); ---- NULL -query I +query R select approx_median(a) from (select 1 as a where 1=0); ---- NULL From 3d59b67dcb5da7e1fa43d0662931bb14cd4b5f02 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 1 Sep 2024 15:35:16 +0800 Subject: [PATCH 2/3] simplify code Signed-off-by: jayzhan211 --- datafusion/core/src/dataframe/mod.rs | 46 ++++++++++--------- .../tests/dataframe/dataframe_functions.rs | 8 ++-- .../functions-aggregate/src/approx_median.rs | 21 ++------- .../src/approx_percentile_cont.rs | 30 ++++-------- .../src/approx_percentile_cont_with_weight.rs | 27 +---------- .../sqllogictest/test_files/aggregate.slt | 13 ++++-- 6 files changed, 51 insertions(+), 94 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 5dbeb535a546..1609df085d39 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -2428,31 +2428,33 @@ mod tests { let df: Vec = df.select(aggr_expr)?.collect().await?; assert_batches_sorted_eq!( - ["+-------------+----------+-----------------+---------------+--------+-----+------+----+------+", + [ + "+-------------+----------+-----------------+---------------+--------+-----+------+----+------+", "| first_value | last_val | approx_distinct | approx_median | median | max | min | c2 | c3 |", "+-------------+----------+-----------------+---------------+--------+-----+------+----+------+", "| | | | | | | | 1 | -85 |", - "| -85 | -101 | 14 | -12 | -101 | 83 | -101 | 4 | -54 |", - "| -85 | -101 | 17 | -25 | -101 | 83 | -101 | 5 | -31 |", - "| -85 | -12 | 10 | -32 | -12 | 83 | -85 | 3 | 13 |", - "| -85 | -25 | 3 | -56 | -25 | -25 | -85 | 1 | -5 |", - "| -85 | -31 | 18 | -29 | -31 | 83 | -101 | 5 | 36 |", - "| -85 | -38 | 16 | -25 | -38 | 83 | -101 | 4 | 65 |", - "| -85 | -43 | 7 | -43 | -43 | 83 | -85 | 2 | 45 |", - "| -85 | -48 | 6 | -35 | -48 | 83 | -85 | 2 | -43 |", - "| -85 | -5 | 4 | -37 | -5 | -5 | -85 | 1 | 83 |", - "| -85 | -54 | 15 | -17 | -54 | 83 | -101 | 4 | -38 |", - "| -85 | -56 | 2 | -70 | -56 | -56 | -85 | 1 | -25 |", - "| -85 | -72 | 9 | -43 | -72 | 83 | -85 | 3 | -12 |", - "| -85 | -85 | 1 | -85 | -85 | -85 | -85 | 1 | -56 |", - "| -85 | 13 | 11 | -17 | 13 | 83 | -85 | 3 | 14 |", - "| -85 | 13 | 11 | -25 | 13 | 83 | -85 | 3 | 13 |", - "| -85 | 14 | 12 | -12 | 14 | 83 | -85 | 3 | 17 |", - "| -85 | 17 | 13 | -11 | 17 | 83 | -85 | 4 | -101 |", - "| -85 | 45 | 8 | -34 | 45 | 83 | -85 | 3 | -72 |", - "| -85 | 65 | 17 | -17 | 65 | 83 | -101 | 5 | -101 |", - "| -85 | 83 | 5 | -25 | 83 | 83 | -85 | 2 | -48 |", - "+-------------+----------+-----------------+---------------+--------+-----+------+----+------+"], + "| -85 | -101 | 14 | -12.0 | -101 | 83 | -101 | 4 | -54 |", + "| -85 | -101 | 17 | -25.0 | -101 | 83 | -101 | 5 | -31 |", + "| -85 | -12 | 10 | -32.75 | -12 | 83 | -85 | 3 | 13 |", + "| -85 | -25 | 3 | -56.0 | -25 | -25 | -85 | 1 | -5 |", + "| -85 | -31 | 18 | -29.75 | -31 | 83 | -101 | 5 | 36 |", + "| -85 | -38 | 16 | -25.0 | -38 | 83 | -101 | 4 | 65 |", + "| -85 | -43 | 7 | -43.0 | -43 | 83 | -85 | 2 | 45 |", + "| -85 | -48 | 6 | -35.75 | -48 | 83 | -85 | 2 | -43 |", + "| -85 | -5 | 4 | -37.75 | -5 | -5 | -85 | 1 | 83 |", + "| -85 | -54 | 15 | -17.0 | -54 | 83 | -101 | 4 | -38 |", + "| -85 | -56 | 2 | -70.5 | -56 | -56 | -85 | 1 | -25 |", + "| -85 | -72 | 9 | -43.0 | -72 | 83 | -85 | 3 | -12 |", + "| -85 | -85 | 1 | -85.0 | -85 | -85 | -85 | 1 | -56 |", + "| -85 | 13 | 11 | -17.0 | 13 | 83 | -85 | 3 | 14 |", + "| -85 | 13 | 11 | -25.0 | 13 | 83 | -85 | 3 | 13 |", + "| -85 | 14 | 12 | -12.0 | 14 | 83 | -85 | 3 | 17 |", + "| -85 | 17 | 13 | -11.25 | 17 | 83 | -85 | 4 | -101 |", + "| -85 | 45 | 8 | -34.5 | 45 | 83 | -85 | 3 | -72 |", + "| -85 | 65 | 17 | -17.0 | 65 | 83 | -101 | 5 | -101 |", + "| -85 | 83 | 5 | -25.0 | 83 | 83 | -85 | 2 | -48 |", + "+-------------+----------+-----------------+---------------+--------+-----+------+----+------+", + ], &df ); diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index 1bd90fce839d..b6d1e222f57f 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -346,7 +346,7 @@ async fn test_fn_approx_median() -> Result<()> { "+-----------------------+", "| approx_median(test.b) |", "+-----------------------+", - "| 10 |", + "| 10.0 |", "+-----------------------+", ]; @@ -366,7 +366,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { "+---------------------------------------------+", "| approx_percentile_cont(test.b,Float64(0.5)) |", "+---------------------------------------------+", - "| 10 |", + "| 10.0 |", "+---------------------------------------------+", ]; @@ -387,7 +387,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { "+--------------------------------------+", "| approx_percentile_cont(test.b,arg_2) |", "+--------------------------------------+", - "| 10 |", + "| 10.0 |", "+--------------------------------------+", ]; let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; @@ -400,7 +400,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { "+------------------------------------------------------+", "| approx_percentile_cont(test.b,Float64(0.5),Int32(2)) |", "+------------------------------------------------------+", - "| 30 |", + "| 30.25 |", "+------------------------------------------------------+", ]; diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index b573c47413b4..37d8db1e4a18 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -20,11 +20,10 @@ use std::any::Any; use std::fmt::Debug; -use arrow::compute::can_cast_types; use arrow::{datatypes::DataType, datatypes::Field}; use arrow_schema::DataType::{Float64, UInt64}; -use datafusion_common::{exec_err, not_impl_err, plan_err, Result}; +use datafusion_common::{not_impl_err, Result}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; @@ -97,11 +96,8 @@ impl AggregateUDFImpl for ApproxMedian { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("ApproxMedian requires numeric input types"); - } - Ok(arg_types[0].clone()) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { @@ -117,16 +113,7 @@ impl AggregateUDFImpl for ApproxMedian { ))) } - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 1 { - return exec_err!("Expect to get single argument"); - } - - if arg_types[0].is_numeric() && !can_cast_types(&arg_types[0], &DataType::Float64) - { - return exec_err!("1st argument {} is not coercible to f64", arg_types[0]); - } - + fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { Ok(vec![DataType::Float64]) } } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 7ee6bf653ffa..c53b9024ea5c 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -20,7 +20,7 @@ use std::fmt::{Debug, Formatter}; use std::sync::Arc; use arrow::array::{Array, AsArray, RecordBatch}; -use arrow::compute::{can_cast_types, filter, is_not_null}; +use arrow::compute::{filter, is_not_null}; use arrow::datatypes::Float64Type; use arrow::{array::ArrayRef, datatypes::DataType}; use arrow_schema::{Field, Schema}; @@ -230,26 +230,14 @@ impl AggregateUDFImpl for ApproxPercentileCont { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 2 && arg_types.len() != 3 { - return exec_err!("Expect to get 2 or 3 args"); - } - - // Check `is_numeric` to filter out numeric string case - if !arg_types[0].is_numeric() - || !can_cast_types(&arg_types[0], &DataType::Float64) - { - return exec_err!("1st argument {} is not coercible to f64", arg_types[0]); - } - if !arg_types[1].is_numeric() - || !can_cast_types(&arg_types[1], &DataType::Float64) - { - return exec_err!("2nd argument {} is not coercible to f64", arg_types[1]); - } - if arg_types.len() == 3 - && (!arg_types[2].is_integer() - || !can_cast_types(&arg_types[2], &DataType::UInt64)) - { - return exec_err!("3rd argument {} is not coercible to u64", arg_types[2]); + if arg_types.len() == 3 { + // Since float is coercible to u64 in `can_cast_types`, we check whether it is integer + if !arg_types[2].is_integer() { + return exec_err!( + "3rd argument should be integer but got {}", + arg_types[2] + ); + } } if arg_types.len() == 2 { diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index dc206f3b0ee7..b8b5e3fa9dd9 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -20,14 +20,13 @@ use std::fmt::{Debug, Formatter}; use std::sync::Arc; use arrow::array::AsArray; -use arrow::compute::can_cast_types; use arrow::datatypes::Float64Type; use arrow::{ array::ArrayRef, datatypes::{DataType, Field}, }; -use datafusion_common::{exec_err, ScalarValue}; +use datafusion_common::ScalarValue; use datafusion_common::{not_impl_err, plan_err, Result}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::Volatility::Immutable; @@ -129,30 +128,6 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 2 && arg_types.len() != 3 { - return exec_err!("Expect to get 2 or 3 args"); - } - - // Check `is_numeric` to filter out numeric string case - if arg_types[0].is_numeric() - && (!arg_types[0].is_numeric() - || !can_cast_types(&arg_types[0], &DataType::Float64)) - { - return exec_err!("1st argument {} is not coercible to f64", arg_types[0]); - } - if arg_types[1].is_numeric() - && (!arg_types[1].is_numeric() - || !can_cast_types(&arg_types[1], &DataType::Float64)) - { - return exec_err!("2nd argument {} is not coercible to f64", arg_types[1]); - } - if arg_types.len() == 3 - && (!arg_types[2].is_numeric() - || !can_cast_types(&arg_types[2], &DataType::Float64)) - { - return exec_err!("3rd argument {} is not coercible to f64", arg_types[2]); - } - Ok(vec![DataType::Float64; arg_types.len()]) } } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 6a477f2809cb..14c020a30c71 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -82,20 +82,20 @@ SELECT approx_percentile_cont_with_weight(c1, c2, 0.95) FROM aggregate_test_100 statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'c' to value of Float64 type SELECT approx_percentile_cont_with_weight(c3, c1, 0.95) FROM aggregate_test_100 -statement error 3rd argument Utf8 is not coercible to f64 +statement error DataFusion error: This feature is not implemented: Percentile value for 'APPROX_PERCENTILE_CONT' must be a literal, got: CAST\(c1@0 AS Float64\) SELECT approx_percentile_cont_with_weight(c3, c2, c1) FROM aggregate_test_100 # csv_query_approx_percentile_cont_with_histogram_bins statement error DataFusion error: External error: Arrow error: Cast error: Can't cast value \-1000 to type UInt64 SELECT c1, approx_percentile_cont(c3, 0.95, -1000) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 -query error 3rd argument Utf8 is not coercible to u64 +query error 3rd argument should be integer but got Utf8 SELECT approx_percentile_cont(c3, 0.95, c1) FROM aggregate_test_100 -query error 3rd argument Float64 is not coercible to u64 +query error 3rd argument should be integer but got Float64 SELECT approx_percentile_cont(c3, 0.95, 111.1) FROM aggregate_test_100 -query error 3rd argument Float64 is not coercible to u64 +query error 3rd argument should be integer but got Float64 SELECT approx_percentile_cont(c12, 0.95, 111.1) FROM aggregate_test_100 statement error DataFusion error: This feature is not implemented: Percentile value for 'APPROX_PERCENTILE_CONT' must be a literal @@ -591,6 +591,11 @@ SELECT c2, var_samp(CASE WHEN c12 > 0.90 THEN c12 ELSE null END) FROM aggregate_ # csv_query_approx_median_1 +query R +select approx_median('1'); +---- +1 + query R SELECT approx_median(c2) FROM aggregate_test_100 ---- From 9223fe46330d32841695776373be45b695d35925 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Wed, 4 Sep 2024 19:46:17 +0800 Subject: [PATCH 3/3] rebase and use coericible Signed-off-by: jayzhan211 --- datafusion/core/src/dataframe/mod.rs | 40 +++++++++---------- .../functions-aggregate/src/approx_median.rs | 9 ++--- 2 files changed, 24 insertions(+), 25 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 2138bd1294b4..946d9ad0da22 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -2433,26 +2433,26 @@ mod tests { "| first_value | last_val | approx_distinct | approx_median | median | max | min | c2 | c3 |", "+-------------+----------+-----------------+---------------+--------+-----+------+----+------+", "| | | | | | | | 1 | -85 |", - "| -85 | -101 | 14 | -12 | -101 | 83 | -101 | 4 | -54 |", - "| -85 | -101 | 17 | -25 | -101 | 83 | -101 | 5 | -31 |", - "| -85 | -12 | 10 | -32 | -12 | 83 | -85 | 3 | 13 |", - "| -85 | -25 | 3 | -56 | -25 | -25 | -85 | 1 | -5 |", - "| -85 | -31 | 18 | -29 | -31 | 83 | -101 | 5 | 36 |", - "| -85 | -38 | 16 | -25 | -38 | 83 | -101 | 4 | 65 |", - "| -85 | -43 | 7 | -43 | -43 | 83 | -85 | 2 | 45 |", - "| -85 | -48 | 6 | -35 | -48 | 83 | -85 | 2 | -43 |", - "| -85 | -5 | 4 | -37 | -5 | -5 | -85 | 1 | 83 |", - "| -85 | -54 | 15 | -17 | -54 | 83 | -101 | 4 | -38 |", - "| -85 | -56 | 2 | -70 | -56 | -56 | -85 | 1 | -25 |", - "| -85 | -72 | 9 | -43 | -72 | 83 | -85 | 3 | -12 |", - "| -85 | -85 | 1 | -85 | -85 | -85 | -85 | 1 | -56 |", - "| -85 | 13 | 11 | -17 | 13 | 83 | -85 | 3 | 14 |", - "| -85 | 13 | 11 | -25 | 13 | 83 | -85 | 3 | 13 |", - "| -85 | 14 | 12 | -12 | 14 | 83 | -85 | 3 | 17 |", - "| -85 | 17 | 13 | -11 | 17 | 83 | -85 | 4 | -101 |", - "| -85 | 45 | 8 | -34 | 45 | 83 | -85 | 3 | -72 |", - "| -85 | 65 | 17 | -17 | 65 | 83 | -101 | 5 | -101 |", - "| -85 | 83 | 5 | -25 | 83 | 83 | -85 | 2 | -48 |", + "| -85 | -101 | 14 | -12.0 | -101 | 83 | -101 | 4 | -54 |", + "| -85 | -101 | 17 | -25.0 | -101 | 83 | -101 | 5 | -31 |", + "| -85 | -12 | 10 | -32.75 | -12 | 83 | -85 | 3 | 13 |", + "| -85 | -25 | 3 | -56.0 | -25 | -25 | -85 | 1 | -5 |", + "| -85 | -31 | 18 | -29.75 | -31 | 83 | -101 | 5 | 36 |", + "| -85 | -38 | 16 | -25.0 | -38 | 83 | -101 | 4 | 65 |", + "| -85 | -43 | 7 | -43.0 | -43 | 83 | -85 | 2 | 45 |", + "| -85 | -48 | 6 | -35.75 | -48 | 83 | -85 | 2 | -43 |", + "| -85 | -5 | 4 | -37.75 | -5 | -5 | -85 | 1 | 83 |", + "| -85 | -54 | 15 | -17.0 | -54 | 83 | -101 | 4 | -38 |", + "| -85 | -56 | 2 | -70.5 | -56 | -56 | -85 | 1 | -25 |", + "| -85 | -72 | 9 | -43.0 | -72 | 83 | -85 | 3 | -12 |", + "| -85 | -85 | 1 | -85.0 | -85 | -85 | -85 | 1 | -56 |", + "| -85 | 13 | 11 | -17.0 | 13 | 83 | -85 | 3 | 14 |", + "| -85 | 13 | 11 | -25.0 | 13 | 83 | -85 | 3 | 13 |", + "| -85 | 14 | 12 | -12.0 | 14 | 83 | -85 | 3 | 17 |", + "| -85 | 17 | 13 | -11.25 | 17 | 83 | -85 | 4 | -101 |", + "| -85 | 45 | 8 | -34.5 | 45 | 83 | -85 | 3 | -72 |", + "| -85 | 65 | 17 | -17.0 | 65 | 83 | -101 | 5 | -101 |", + "| -85 | 83 | 5 | -25.0 | 83 | 83 | -85 | 2 | -48 |", "+-------------+----------+-----------------+---------------+--------+-----+------+----+------+", ], &df diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index 37d8db1e4a18..e1bb6e6f06ca 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -62,7 +62,10 @@ impl ApproxMedian { /// Create a new APPROX_MEDIAN aggregate function pub fn new() -> Self { Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::coercible( + vec![DataType::Float64], + Volatility::Immutable, + ), } } } @@ -112,8 +115,4 @@ impl AggregateUDFImpl for ApproxMedian { acc_args.exprs[0].data_type(acc_args.schema)?, ))) } - - fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { - Ok(vec![DataType::Float64]) - } }