From 1bd311ae7df42bb548a2069dd39fbc76e96e0e42 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Fri, 13 Dec 2024 14:52:54 +0100 Subject: [PATCH] Fix recursive flatten --- datafusion/expr-common/src/signature.rs | 6 ++++++ .../expr/src/type_coercion/functions.rs | 21 +++++++++++++++++++ datafusion/functions-nested/src/flatten.rs | 11 ++++++++-- 3 files changed, 36 insertions(+), 2 deletions(-) diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 32cbb6d0aecb..69960acc57e1 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -175,6 +175,9 @@ pub enum ArrayFunctionSignature { /// The function takes a single argument that must be a List/LargeList/FixedSizeList /// or something that can be coerced to one of those types. Array, + /// A function takes a single argument that must be a List/LargeList/FixedSizeList + /// which gets coerced to List, with element type recursively coerced to List too if it is list-like. + RecursiveArray, /// Specialized Signature for MapArray /// The function takes a single argument that must be a MapArray MapArray, @@ -198,6 +201,9 @@ impl std::fmt::Display for ArrayFunctionSignature { ArrayFunctionSignature::Array => { write!(f, "array") } + ArrayFunctionSignature::RecursiveArray => { + write!(f, "recursive_array") + } ArrayFunctionSignature::MapArray => { write!(f, "map_array") } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 9fbd46e37f7c..cbc1ce428c2f 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -21,6 +21,7 @@ use arrow::{ compute::can_cast_types, datatypes::{DataType, TimeUnit}, }; +use datafusion_common::utils::coerced_fixed_size_list_to_list; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, plan_err, types::{LogicalType, NativeType}, @@ -414,6 +415,7 @@ fn get_valid_types( _ => Ok(vec![vec![]]), } } + fn array(array_type: &DataType) -> Option { match array_type { DataType::List(_) => Some(array_type.clone()), @@ -424,6 +426,18 @@ fn get_valid_types( } } + fn recursive_array(array_type: &DataType) -> Option { + match array_type { + DataType::List(_) + | DataType::LargeList(_) + | DataType::FixedSizeList(_, _) => { + let array_type = coerced_fixed_size_list_to_list(array_type); + Some(array_type) + } + _ => None, + } + } + fn function_length_check(length: usize, expected_length: usize) -> Result<()> { if length < 1 { return plan_err!( @@ -651,6 +665,13 @@ fn get_valid_types( array(¤t_types[0]) .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]]) } + ArrayFunctionSignature::RecursiveArray => { + if current_types.len() != 1 { + return Ok(vec![vec![]]); + } + recursive_array(¤t_types[0]) + .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]]) + } ArrayFunctionSignature::MapArray => { if current_types.len() != 1 { return Ok(vec![vec![]]); diff --git a/datafusion/functions-nested/src/flatten.rs b/datafusion/functions-nested/src/flatten.rs index 9d2cb8a3f667..7cb52ae4c5c9 100644 --- a/datafusion/functions-nested/src/flatten.rs +++ b/datafusion/functions-nested/src/flatten.rs @@ -28,7 +28,8 @@ use datafusion_common::cast::{ use datafusion_common::{exec_err, Result}; use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, + TypeSignature, Volatility, }; use std::any::Any; use std::sync::{Arc, OnceLock}; @@ -56,7 +57,13 @@ impl Default for Flatten { impl Flatten { pub fn new() -> Self { Self { - signature: Signature::array(Volatility::Immutable), + signature: Signature { + // TODO (https://github.com/apache/datafusion/issues/13757) flatten should be single-step, not recursive + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::RecursiveArray, + ), + volatility: Volatility::Immutable, + }, aliases: vec![], } }