diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 32cbb6d0aecb..69960acc57e1 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -175,6 +175,9 @@ pub enum ArrayFunctionSignature { /// The function takes a single argument that must be a List/LargeList/FixedSizeList /// or something that can be coerced to one of those types. Array, + /// A function takes a single argument that must be a List/LargeList/FixedSizeList + /// which gets coerced to List, with element type recursively coerced to List too if it is list-like. + RecursiveArray, /// Specialized Signature for MapArray /// The function takes a single argument that must be a MapArray MapArray, @@ -198,6 +201,9 @@ impl std::fmt::Display for ArrayFunctionSignature { ArrayFunctionSignature::Array => { write!(f, "array") } + ArrayFunctionSignature::RecursiveArray => { + write!(f, "recursive_array") + } ArrayFunctionSignature::MapArray => { write!(f, "map_array") } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 9d15d9693992..199f649c37e8 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -21,10 +21,11 @@ use arrow::{ compute::can_cast_types, datatypes::{DataType, TimeUnit}, }; +use datafusion_common::utils::coerced_fixed_size_list_to_list; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, plan_err, types::{LogicalType, NativeType}, - utils::{coerced_fixed_size_list_to_list, list_ndims}, + utils::list_ndims, Result, }; use datafusion_expr_common::{ @@ -414,7 +415,16 @@ fn get_valid_types( _ => Ok(vec![vec![]]), } } + fn array(array_type: &DataType) -> Option { + match array_type { + DataType::List(_) | DataType::LargeList(_) => Some(array_type.clone()), + DataType::FixedSizeList(field, _) => Some(DataType::List(Arc::clone(field))), + _ => None, + } + } + + fn recursive_array(array_type: &DataType) -> Option { match array_type { DataType::List(_) | DataType::LargeList(_) @@ -653,6 +663,13 @@ fn get_valid_types( array(¤t_types[0]) .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]]) } + ArrayFunctionSignature::RecursiveArray => { + if current_types.len() != 1 { + return Ok(vec![vec![]]); + } + recursive_array(¤t_types[0]) + .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]]) + } ArrayFunctionSignature::MapArray => { if current_types.len() != 1 { return Ok(vec![vec![]]); diff --git a/datafusion/functions-nested/src/extract.rs b/datafusion/functions-nested/src/extract.rs index fc35f0076330..f972597bbf84 100644 --- a/datafusion/functions-nested/src/extract.rs +++ b/datafusion/functions-nested/src/extract.rs @@ -993,3 +993,86 @@ where let data = mutable.freeze(); Ok(arrow::array::make_array(data)) } + +#[cfg(test)] +mod tests { + use super::array_element_udf; + use arrow_schema::{DataType, Field}; + use datafusion_common::{Column, DFSchema, ScalarValue}; + use datafusion_expr::expr::ScalarFunction; + use datafusion_expr::{cast, Expr, ExprSchemable}; + use std::collections::HashMap; + + // Regression test for https://github.com/apache/datafusion/issues/13755 + #[test] + fn test_array_element_return_type_fixed_size_list() { + let fixed_size_list_type = DataType::FixedSizeList( + Field::new("some_arbitrary_test_field", DataType::Int32, false).into(), + 13, + ); + let array_type = DataType::List( + Field::new_list_field(fixed_size_list_type.clone(), true).into(), + ); + let index_type = DataType::Int64; + + let schema = DFSchema::from_unqualified_fields( + vec![ + Field::new("my_array", array_type.clone(), false), + Field::new("my_index", index_type.clone(), false), + ] + .into(), + HashMap::default(), + ) + .unwrap(); + + let udf = array_element_udf(); + + // ScalarUDFImpl::return_type + assert_eq!( + udf.return_type(&[array_type.clone(), index_type.clone()]) + .unwrap(), + fixed_size_list_type + ); + + // ScalarUDFImpl::return_type_from_exprs with typed exprs + assert_eq!( + udf.return_type_from_exprs( + &[ + cast(Expr::Literal(ScalarValue::Null), array_type.clone()), + cast(Expr::Literal(ScalarValue::Null), index_type.clone()), + ], + &schema, + &[array_type.clone(), index_type.clone()] + ) + .unwrap(), + fixed_size_list_type + ); + + // ScalarUDFImpl::return_type_from_exprs with exprs not carrying type + assert_eq!( + udf.return_type_from_exprs( + &[ + Expr::Column(Column::new_unqualified("my_array")), + Expr::Column(Column::new_unqualified("my_index")), + ], + &schema, + &[array_type.clone(), index_type.clone()] + ) + .unwrap(), + fixed_size_list_type + ); + + // Via ExprSchemable::get_type (e.g. SimplifyInfo) + let udf_expr = Expr::ScalarFunction(ScalarFunction { + func: array_element_udf(), + args: vec![ + Expr::Column(Column::new_unqualified("my_array")), + Expr::Column(Column::new_unqualified("my_index")), + ], + }); + assert_eq!( + ExprSchemable::get_type(&udf_expr, &schema).unwrap(), + fixed_size_list_type + ); + } +} diff --git a/datafusion/functions-nested/src/flatten.rs b/datafusion/functions-nested/src/flatten.rs index 9d2cb8a3f667..7cb52ae4c5c9 100644 --- a/datafusion/functions-nested/src/flatten.rs +++ b/datafusion/functions-nested/src/flatten.rs @@ -28,7 +28,8 @@ use datafusion_common::cast::{ use datafusion_common::{exec_err, Result}; use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, + TypeSignature, Volatility, }; use std::any::Any; use std::sync::{Arc, OnceLock}; @@ -56,7 +57,13 @@ impl Default for Flatten { impl Flatten { pub fn new() -> Self { Self { - signature: Signature::array(Volatility::Immutable), + signature: Signature { + // TODO (https://github.com/apache/datafusion/issues/13757) flatten should be single-step, not recursive + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::RecursiveArray, + ), + volatility: Volatility::Immutable, + }, aliases: vec![], } }