diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index a0a89fb3d14f..76f981986823 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1409,6 +1409,7 @@ dependencies = [ "datafusion-doc", "datafusion-execution", "datafusion-expr", + "datafusion-expr-common", "datafusion-macros", "hashbrown 0.14.5", "hex", diff --git a/datafusion/common/src/types/native.rs b/datafusion/common/src/types/native.rs index 7e326dc15bb2..c5f180a15035 100644 --- a/datafusion/common/src/types/native.rs +++ b/datafusion/common/src/types/native.rs @@ -245,6 +245,8 @@ impl LogicalType for NativeType { (Self::FixedSizeBinary(size), _) => FixedSizeBinary(*size), (Self::String, LargeBinary) => LargeUtf8, (Self::String, BinaryView) => Utf8View, + // We don't cast to another kind of string type if the origin one is already a string type + (Self::String, Utf8 | LargeUtf8 | Utf8View) => origin.to_owned(), (Self::String, data_type) if can_cast_types(data_type, &Utf8View) => Utf8View, (Self::String, data_type) if can_cast_types(data_type, &LargeUtf8) => { LargeUtf8 @@ -433,4 +435,29 @@ impl NativeType { UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 ) } + + #[inline] + pub fn is_timestamp(&self) -> bool { + matches!(self, NativeType::Timestamp(_, _)) + } + + #[inline] + pub fn is_date(&self) -> bool { + matches!(self, NativeType::Date) + } + + #[inline] + pub fn is_time(&self) -> bool { + matches!(self, NativeType::Time(_)) + } + + #[inline] + pub fn is_interval(&self) -> bool { + matches!(self, NativeType::Interval(_)) + } + + #[inline] + pub fn is_duration(&self) -> bool { + matches!(self, NativeType::Duration(_)) + } } diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 32cbb6d0aecb..148ddac73a57 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -18,8 +18,10 @@ //! Signature module contains foundational types that are used to represent signatures, types, //! and return types of functions in DataFusion. +use std::fmt::Display; + use crate::type_coercion::aggregates::NUMERICS; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; use datafusion_common::types::{LogicalTypeRef, NativeType}; use itertools::Itertools; @@ -112,7 +114,7 @@ pub enum TypeSignature { /// For example, `Coercible(vec![logical_float64()])` accepts /// arguments like `vec![DataType::Int32]` or `vec![DataType::Float32]` /// since i32 and f32 can be casted to f64 - Coercible(Vec), + Coercible(Vec), /// The arguments will be coerced to a single type based on the comparison rules. /// For example, i32 and i64 has coerced type Int64. /// @@ -154,6 +156,33 @@ impl TypeSignature { } } +/// Represents the class of types that can be used in a function signature. +/// +/// This is used to specify what types are valid for function arguments in a more flexible way than +/// just listing specific DataTypes. For example, TypeSignatureClass::Timestamp matches any timestamp +/// type regardless of timezone or precision. +/// +/// Used primarily with TypeSignature::Coercible to define function signatures that can accept +/// arguments that can be coerced to a particular class of types. +#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Hash)] +pub enum TypeSignatureClass { + Timestamp, + Date, + Time, + Interval, + Duration, + Native(LogicalTypeRef), + // TODO: + // Numeric + // Integer +} + +impl Display for TypeSignatureClass { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "TypeSignatureClass::{self:?}") + } +} + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum ArrayFunctionSignature { /// Specialized Signature for ArrayAppend and similar functions @@ -180,7 +209,7 @@ pub enum ArrayFunctionSignature { MapArray, } -impl std::fmt::Display for ArrayFunctionSignature { +impl Display for ArrayFunctionSignature { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { ArrayFunctionSignature::ArrayAndElement => { @@ -255,7 +284,7 @@ impl TypeSignature { } /// Helper function to join types with specified delimiter. - pub fn join_types(types: &[T], delimiter: &str) -> String { + pub fn join_types(types: &[T], delimiter: &str) -> String { types .iter() .map(|t| t.to_string()) @@ -290,7 +319,30 @@ impl TypeSignature { .collect(), TypeSignature::Coercible(types) => types .iter() - .map(|logical_type| get_data_types(logical_type.native())) + .map(|logical_type| match logical_type { + TypeSignatureClass::Native(l) => get_data_types(l.native()), + TypeSignatureClass::Timestamp => { + vec![ + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Timestamp( + TimeUnit::Nanosecond, + Some(TIMEZONE_WILDCARD.into()), + ), + ] + } + TypeSignatureClass::Date => { + vec![DataType::Date64] + } + TypeSignatureClass::Time => { + vec![DataType::Time64(TimeUnit::Nanosecond)] + } + TypeSignatureClass::Interval => { + vec![DataType::Interval(IntervalUnit::DayTime)] + } + TypeSignatureClass::Duration => { + vec![DataType::Duration(TimeUnit::Nanosecond)] + } + }) .multi_cartesian_product() .collect(), TypeSignature::Variadic(types) => types @@ -424,7 +476,10 @@ impl Signature { } } /// Target coerce types in order - pub fn coercible(target_types: Vec, volatility: Volatility) -> Self { + pub fn coercible( + target_types: Vec, + volatility: Volatility, + ) -> Self { Self { type_signature: TypeSignature::Coercible(target_types), volatility, @@ -618,8 +673,10 @@ mod tests { ] ); - let type_signature = - TypeSignature::Coercible(vec![logical_string(), logical_int64()]); + let type_signature = TypeSignature::Coercible(vec![ + TypeSignatureClass::Native(logical_string()), + TypeSignatureClass::Native(logical_int64()), + ]); let possible_types = type_signature.get_possible_types(); assert_eq!( possible_types, diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 9d15d9693992..b12489167b8f 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -22,14 +22,18 @@ use arrow::{ datatypes::{DataType, TimeUnit}, }; use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, plan_err, + exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, types::{LogicalType, NativeType}, utils::{coerced_fixed_size_list_to_list, list_ndims}, Result, }; use datafusion_expr_common::{ - signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD}, - type_coercion::binary::{comparison_coercion_numeric, string_coercion}, + signature::{ + ArrayFunctionSignature, TypeSignatureClass, FIXED_SIZE_LIST_WILDCARD, + TIMEZONE_WILDCARD, + }, + type_coercion::binary::comparison_coercion_numeric, + type_coercion::binary::string_coercion, }; use std::sync::Arc; @@ -568,35 +572,65 @@ fn get_valid_types( // Make sure the corresponding test is covered // If this function becomes COMPLEX, create another new signature! fn can_coerce_to( - logical_type: &NativeType, - target_type: &NativeType, - ) -> bool { - if logical_type == target_type { - return true; - } + current_type: &DataType, + target_type_class: &TypeSignatureClass, + ) -> Result { + let logical_type: NativeType = current_type.into(); - if logical_type == &NativeType::Null { - return true; - } + match target_type_class { + TypeSignatureClass::Native(native_type) => { + let target_type = native_type.native(); + if &logical_type == target_type { + return target_type.default_cast_for(current_type); + } - if target_type.is_integer() && logical_type.is_integer() { - return true; - } + if logical_type == NativeType::Null { + return target_type.default_cast_for(current_type); + } + + if target_type.is_integer() && logical_type.is_integer() { + return target_type.default_cast_for(current_type); + } - false + internal_err!( + "Expect {} but received {}", + target_type_class, + current_type + ) + } + // Not consistent with Postgres and DuckDB but to avoid regression we implicit cast string to timestamp + TypeSignatureClass::Timestamp + if logical_type == NativeType::String => + { + Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) + } + TypeSignatureClass::Timestamp if logical_type.is_timestamp() => { + Ok(current_type.to_owned()) + } + TypeSignatureClass::Date if logical_type.is_date() => { + Ok(current_type.to_owned()) + } + TypeSignatureClass::Time if logical_type.is_time() => { + Ok(current_type.to_owned()) + } + TypeSignatureClass::Interval if logical_type.is_interval() => { + Ok(current_type.to_owned()) + } + TypeSignatureClass::Duration if logical_type.is_duration() => { + Ok(current_type.to_owned()) + } + _ => { + not_impl_err!("Got logical_type: {logical_type} with target_type_class: {target_type_class}") + } + } } let mut new_types = Vec::with_capacity(current_types.len()); - for (current_type, target_type) in + for (current_type, target_type_class) in current_types.iter().zip(target_types.iter()) { - let logical_type: NativeType = current_type.into(); - let target_logical_type = target_type.native(); - if can_coerce_to(&logical_type, target_logical_type) { - let target_type = - target_logical_type.default_cast_for(current_type)?; - new_types.push(target_type); - } + let target_type = can_coerce_to(current_type, target_type_class)?; + new_types.push(target_type); } vec![new_types] diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 575e8484a92f..de72c7ee946b 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -75,6 +75,7 @@ datafusion-common = { workspace = true } datafusion-doc = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } +datafusion-expr-common = { workspace = true } datafusion-macros = { workspace = true } hashbrown = { workspace = true, optional = true } hex = { version = "0.4", optional = true } diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index 6d6adf2a344d..b43fcb6db706 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -23,27 +23,28 @@ use arrow::array::{Array, ArrayRef, Float64Array, Int32Array}; use arrow::compute::kernels::cast_utils::IntervalUnit; use arrow::compute::{binary, date_part, DatePart}; use arrow::datatypes::DataType::{ - Date32, Date64, Duration, Interval, Time32, Time64, Timestamp, Utf8, Utf8View, + Date32, Date64, Duration, Interval, Time32, Time64, Timestamp, }; -use arrow::datatypes::IntervalUnit::{DayTime, MonthDayNano, YearMonth}; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; use arrow::datatypes::{DataType, TimeUnit}; -use datafusion_common::cast::{ - as_date32_array, as_date64_array, as_int32_array, as_time32_millisecond_array, - as_time32_second_array, as_time64_microsecond_array, as_time64_nanosecond_array, - as_timestamp_microsecond_array, as_timestamp_millisecond_array, - as_timestamp_nanosecond_array, as_timestamp_second_array, -}; +use datafusion_common::not_impl_err; use datafusion_common::{ - exec_err, internal_err, not_impl_err, ExprSchema, Result, ScalarValue, + cast::{ + as_date32_array, as_date64_array, as_int32_array, as_time32_millisecond_array, + as_time32_second_array, as_time64_microsecond_array, as_time64_nanosecond_array, + as_timestamp_microsecond_array, as_timestamp_millisecond_array, + as_timestamp_nanosecond_array, as_timestamp_second_array, + }, + exec_err, internal_err, + types::logical_string, + ExprSchema, Result, ScalarValue, }; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; -use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, - TIMEZONE_WILDCARD, + scalar_doc_sections::DOC_SECTION_DATETIME, ColumnarValue, Documentation, Expr, + ScalarUDFImpl, Signature, TypeSignature, Volatility, }; +use datafusion_expr_common::signature::TypeSignatureClass; #[derive(Debug)] pub struct DatePartFunc { @@ -62,72 +63,26 @@ impl DatePartFunc { Self { signature: Signature::one_of( vec![ - Exact(vec![Utf8, Timestamp(Nanosecond, None)]), - Exact(vec![Utf8View, Timestamp(Nanosecond, None)]), - Exact(vec![ - Utf8, - Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), - ]), - Exact(vec![ - Utf8View, - Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), - ]), - Exact(vec![Utf8, Timestamp(Millisecond, None)]), - Exact(vec![Utf8View, Timestamp(Millisecond, None)]), - Exact(vec![ - Utf8, - Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())), - ]), - Exact(vec![ - Utf8View, - Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())), + TypeSignature::Coercible(vec![ + TypeSignatureClass::Native(logical_string()), + TypeSignatureClass::Timestamp, ]), - Exact(vec![Utf8, Timestamp(Microsecond, None)]), - Exact(vec![Utf8View, Timestamp(Microsecond, None)]), - Exact(vec![ - Utf8, - Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())), + TypeSignature::Coercible(vec![ + TypeSignatureClass::Native(logical_string()), + TypeSignatureClass::Date, ]), - Exact(vec![ - Utf8View, - Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())), + TypeSignature::Coercible(vec![ + TypeSignatureClass::Native(logical_string()), + TypeSignatureClass::Time, ]), - Exact(vec![Utf8, Timestamp(Second, None)]), - Exact(vec![Utf8View, Timestamp(Second, None)]), - Exact(vec![ - Utf8, - Timestamp(Second, Some(TIMEZONE_WILDCARD.into())), + TypeSignature::Coercible(vec![ + TypeSignatureClass::Native(logical_string()), + TypeSignatureClass::Interval, ]), - Exact(vec![ - Utf8View, - Timestamp(Second, Some(TIMEZONE_WILDCARD.into())), + TypeSignature::Coercible(vec![ + TypeSignatureClass::Native(logical_string()), + TypeSignatureClass::Duration, ]), - Exact(vec![Utf8, Date64]), - Exact(vec![Utf8View, Date64]), - Exact(vec![Utf8, Date32]), - Exact(vec![Utf8View, Date32]), - Exact(vec![Utf8, Time32(Second)]), - Exact(vec![Utf8View, Time32(Second)]), - Exact(vec![Utf8, Time32(Millisecond)]), - Exact(vec![Utf8View, Time32(Millisecond)]), - Exact(vec![Utf8, Time64(Microsecond)]), - Exact(vec![Utf8View, Time64(Microsecond)]), - Exact(vec![Utf8, Time64(Nanosecond)]), - Exact(vec![Utf8View, Time64(Nanosecond)]), - Exact(vec![Utf8, Interval(YearMonth)]), - Exact(vec![Utf8View, Interval(YearMonth)]), - Exact(vec![Utf8, Interval(DayTime)]), - Exact(vec![Utf8View, Interval(DayTime)]), - Exact(vec![Utf8, Interval(MonthDayNano)]), - Exact(vec![Utf8View, Interval(MonthDayNano)]), - Exact(vec![Utf8, Duration(Second)]), - Exact(vec![Utf8View, Duration(Second)]), - Exact(vec![Utf8, Duration(Millisecond)]), - Exact(vec![Utf8View, Duration(Millisecond)]), - Exact(vec![Utf8, Duration(Microsecond)]), - Exact(vec![Utf8View, Duration(Microsecond)]), - Exact(vec![Utf8, Duration(Nanosecond)]), - Exact(vec![Utf8View, Duration(Nanosecond)]), ], Volatility::Immutable, ), diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index d16508c6af5a..044b3549243b 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -32,6 +32,7 @@ use datafusion_common::{exec_err, Result}; use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr_common::signature::TypeSignatureClass; #[derive(Debug)] pub struct RepeatFunc { @@ -48,7 +49,10 @@ impl RepeatFunc { pub fn new() -> Self { Self { signature: Signature::coercible( - vec![logical_string(), logical_int64()], + vec![ + TypeSignatureClass::Native(logical_string()), + TypeSignatureClass::Native(logical_int64()), + ], Volatility::Immutable, ), } diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index 2306eda77d35..577c7ede9079 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -560,7 +560,7 @@ select repeat('-1.2', arrow_cast(3, 'Int32')); ---- -1.2-1.2-1.2 -query error DataFusion error: Error during planning: Error during planning: Failed to coerce arguments to satisfy a call to repeat function: coercion from \[Utf8, Float64\] to the signature +query error DataFusion error: Error during planning: Internal error: Expect TypeSignatureClass::Native\(LogicalType\(Native\(Int64\), Int64\)\) but received Float64 select repeat('-1.2', 3.2); query T