Skip to content

Commit

Permalink
Merge remote-tracking branch 'apache/main' into feat/sqlite-strftime-…
Browse files Browse the repository at this point in the history
…support
  • Loading branch information
alamb committed Sep 3, 2024
2 parents da5f09d + e4a9424 commit dfbd5b6
Show file tree
Hide file tree
Showing 18 changed files with 568 additions and 152 deletions.
6 changes: 4 additions & 2 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2428,7 +2428,8 @@ mod tests {
let df: Vec<RecordBatch> = df.select(aggr_expr)?.collect().await?;

assert_batches_sorted_eq!(
["+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
[
"+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
"| first_value | last_val | approx_distinct | approx_median | median | max | min | c2 | c3 |",
"+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
"| | | | | | | | 1 | -85 |",
Expand All @@ -2452,7 +2453,8 @@ mod tests {
"| -85 | 45 | 8 | -34 | 45 | 83 | -85 | 3 | -72 |",
"| -85 | 65 | 17 | -17 | 65 | 83 | -101 | 5 | -101 |",
"| -85 | 83 | 5 | -25 | 83 | 83 | -85 | 2 | -48 |",
"+-------------+----------+-----------------+---------------+--------+-----+------+----+------+"],
"+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
],
&df
);

Expand Down
15 changes: 14 additions & 1 deletion datafusion/expr-common/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ pub enum TypeSignature {
Uniform(usize, Vec<DataType>),
/// Exact number of arguments of an exact type
Exact(Vec<DataType>),
/// The number of arguments that can be coerced to in order
/// For example, `Coercible(vec![DataType::Float64])` accepts
/// arguments like `vec![DataType::Int32]` or `vec![DataType::Float32]`
/// since i32 and f32 can be casted to f64
Coercible(Vec<DataType>),
/// Fixed number of arguments of arbitrary types
/// If a function takes 0 argument, its `TypeSignature` should be `Any(0)`
Any(usize),
Expand Down Expand Up @@ -188,7 +193,7 @@ impl TypeSignature {
TypeSignature::Numeric(num) => {
vec![format!("Numeric({})", num)]
}
TypeSignature::Exact(types) => {
TypeSignature::Exact(types) | TypeSignature::Coercible(types) => {
vec![Self::join_types(types, ", ")]
}
TypeSignature::Any(arg_count) => {
Expand Down Expand Up @@ -300,6 +305,14 @@ impl Signature {
volatility,
}
}
/// Target coerce types in order
pub fn coercible(target_types: Vec<DataType>, volatility: Volatility) -> Self {
Self {
type_signature: TypeSignature::Coercible(target_types),
volatility,
}
}

/// A specified number of arguments of any type
pub fn any(arg_count: usize, volatility: Volatility) -> Self {
Signature {
Expand Down
6 changes: 4 additions & 2 deletions datafusion/expr-common/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,11 @@ pub fn check_arg_count(
);
}
}
TypeSignature::UserDefined | TypeSignature::Numeric(_) => {
TypeSignature::UserDefined
| TypeSignature::Numeric(_)
| TypeSignature::Coercible(_) => {
// User-defined signature is validated in `coerce_types`
// Numreic signature is validated in `get_valid_types`
// Numeric and Coercible signature is validated in `get_valid_types`
}
_ => {
return internal_err!(
Expand Down
33 changes: 32 additions & 1 deletion datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,14 @@ fn try_coerce_types(
let mut valid_types = valid_types;

// Well-supported signature that returns exact valid types.
if !valid_types.is_empty() && matches!(type_signature, TypeSignature::UserDefined) {
if !valid_types.is_empty()
&& matches!(
type_signature,
TypeSignature::UserDefined
| TypeSignature::Numeric(_)
| TypeSignature::Coercible(_)
)
{
// exact valid types
assert_eq!(valid_types.len(), 1);
let valid_types = valid_types.swap_remove(0);
Expand Down Expand Up @@ -397,6 +404,30 @@ fn get_valid_types(

vec![vec![valid_type; *number]]
}
TypeSignature::Coercible(target_types) => {
if target_types.is_empty() {
return plan_err!(
"The signature expected at least one argument but received {}",
current_types.len()
);
}
if target_types.len() != current_types.len() {
return plan_err!(
"The signature expected {} arguments but received {}",
target_types.len(),
current_types.len()
);
}

for (data_type, target_type) in current_types.iter().zip(target_types.iter())
{
if !can_cast_types(data_type, target_type) {
return plan_err!("{data_type} is not coercible to {target_type}");
}
}

vec![target_types.to_owned()]
}
TypeSignature::Uniform(number, valid_types) => valid_types
.iter()
.map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
Expand Down
11 changes: 5 additions & 6 deletions datafusion/functions-aggregate/src/stddev.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ impl Stddev {
/// Create a new STDDEV aggregate function
pub fn new() -> Self {
Self {
signature: Signature::numeric(1, Volatility::Immutable),
signature: Signature::coercible(
vec![DataType::Float64],
Volatility::Immutable,
),
alias: vec!["stddev_samp".to_string()],
}
}
Expand All @@ -88,11 +91,7 @@ impl AggregateUDFImpl for Stddev {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if !arg_types[0].is_numeric() {
return plan_err!("Stddev requires numeric input types");
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}

Expand Down
11 changes: 5 additions & 6 deletions datafusion/functions-aggregate/src/variance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ impl VarianceSample {
pub fn new() -> Self {
Self {
aliases: vec![String::from("var_sample"), String::from("var_samp")],
signature: Signature::numeric(1, Volatility::Immutable),
signature: Signature::coercible(
vec![DataType::Float64],
Volatility::Immutable,
),
}
}
}
Expand All @@ -97,11 +100,7 @@ impl AggregateUDFImpl for VarianceSample {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if !arg_types[0].is_numeric() {
return plan_err!("Variance requires numeric input types");
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}

Expand Down
21 changes: 9 additions & 12 deletions datafusion/functions/src/core/named_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,20 +70,17 @@ fn named_struct_expr(args: &[ColumnarValue]) -> Result<ColumnarValue> {
}
}

let arrays = ColumnarValue::values_to_arrays(&values)?;

let fields = names
let fields: Fields = names
.into_iter()
.zip(arrays)
.map(|(name, value)| {
(
Arc::new(Field::new(name, value.data_type().clone(), true)),
value,
)
})
.collect::<Vec<_>>();
.zip(&values)
.map(|(name, value)| Arc::new(Field::new(name, value.data_type().clone(), true)))
.collect::<Vec<_>>()
.into();

let arrays = ColumnarValue::values_to_arrays(&values)?;

Ok(ColumnarValue::Array(Arc::new(StructArray::from(fields))))
let struct_array = StructArray::new(fields, arrays, None);
Ok(ColumnarValue::Array(Arc::new(struct_array)))
}

#[derive(Debug)]
Expand Down
68 changes: 12 additions & 56 deletions datafusion/functions/src/core/struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,30 +29,31 @@ fn array_struct(args: &[ArrayRef]) -> Result<ArrayRef> {
return exec_err!("struct requires at least one argument");
}

let vec: Vec<_> = args
let fields = args
.iter()
.enumerate()
.map(|(i, arg)| {
let field_name = format!("c{i}");
Ok((
Arc::new(Field::new(
field_name.as_str(),
arg.data_type().clone(),
true,
)),
Arc::clone(arg),
))
Ok(Arc::new(Field::new(
field_name.as_str(),
arg.data_type().clone(),
true,
)))
})
.collect::<Result<Vec<_>>>()?;
.collect::<Result<Vec<_>>>()?
.into();

Ok(Arc::new(StructArray::from(vec)))
let arrays = args.to_vec();

Ok(Arc::new(StructArray::new(fields, arrays, None)))
}

/// put values in a struct array.
fn struct_expr(args: &[ColumnarValue]) -> Result<ColumnarValue> {
let arrays = ColumnarValue::values_to_arrays(args)?;
Ok(ColumnarValue::Array(array_struct(arrays.as_slice())?))
}

#[derive(Debug)]
pub struct StructFunc {
signature: Signature,
Expand Down Expand Up @@ -97,48 +98,3 @@ impl ScalarUDFImpl for StructFunc {
struct_expr(args)
}
}

#[cfg(test)]
mod tests {
use super::*;
use arrow::array::Int64Array;
use datafusion_common::cast::as_struct_array;
use datafusion_common::ScalarValue;

#[test]
fn test_struct() {
// struct(1, 2, 3) = {"c0": 1, "c1": 2, "c2": 3}
let args = [
ColumnarValue::Scalar(ScalarValue::Int64(Some(1))),
ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
ColumnarValue::Scalar(ScalarValue::Int64(Some(3))),
];
let struc = struct_expr(&args)
.expect("failed to initialize function struct")
.into_array(1)
.expect("Failed to convert to array");
let result =
as_struct_array(&struc).expect("failed to initialize function struct");
assert_eq!(
&Int64Array::from(vec![1]),
Arc::clone(result.column_by_name("c0").unwrap())
.as_any()
.downcast_ref::<Int64Array>()
.unwrap()
);
assert_eq!(
&Int64Array::from(vec![2]),
Arc::clone(result.column_by_name("c1").unwrap())
.as_any()
.downcast_ref::<Int64Array>()
.unwrap()
);
assert_eq!(
&Int64Array::from(vec![3]),
Arc::clone(result.column_by_name("c2").unwrap())
.as_any()
.downcast_ref::<Int64Array>()
.unwrap()
);
}
}
Loading

0 comments on commit dfbd5b6

Please sign in to comment.