Skip to content

Commit

Permalink
Simplify type signatures using TypeSignatureClass for mixed type fu…
Browse files Browse the repository at this point in the history
…nction signature (#13372)

* add type sig class

Signed-off-by: jayzhan211 <[email protected]>

* timestamp

Signed-off-by: jayzhan211 <[email protected]>

* date part

Signed-off-by: jayzhan211 <[email protected]>

* fmt

Signed-off-by: jayzhan211 <[email protected]>

* taplo format

Signed-off-by: jayzhan211 <[email protected]>

* tpch test

Signed-off-by: jayzhan211 <[email protected]>

* msrc issue

Signed-off-by: jayzhan211 <[email protected]>

* msrc issue

Signed-off-by: jayzhan211 <[email protected]>

* explicit hash

Signed-off-by: jayzhan211 <[email protected]>

* Enhance type coercion and function signatures

- Added logic to prevent unnecessary casting of string types in `native.rs`.
- Introduced `Comparable` variant in `TypeSignature` to define coercion rules for comparisons.
- Updated imports in `functions.rs` and `signature.rs` for better organization.
- Modified `date_part.rs` to improve handling of timestamp extraction and fixed query tests in `expr.slt`.
- Added `datafusion-macros` dependency in `Cargo.toml` and `Cargo.lock`.

These changes improve type handling and ensure more accurate function behavior in SQL expressions.

* fix comment

Signed-off-by: Jay Zhan <[email protected]>

* fix signature

Signed-off-by: Jay Zhan <[email protected]>

* fix test

Signed-off-by: Jay Zhan <[email protected]>

* Enhance type coercion for timestamps to allow implicit casting from strings. Update SQL logic tests to reflect changes in timestamp handling, including expected outputs for queries involving nanoseconds and seconds.

* Refactor type coercion logic for timestamps to improve readability and maintainability. Update the `TypeSignatureClass` documentation to clarify its purpose in function signatures, particularly regarding coercible types. This change enhances the handling of implicit casting from strings to timestamps.

* Fix SQL logic tests to correct query error handling for timestamp functions. Updated expected outputs for `date_part` and `extract` functions to reflect proper behavior with nanoseconds and seconds. This change improves the accuracy of test cases in the `expr.slt` file.

* Enhance timestamp handling in TypeSignature to support timezone specification. Updated the logic to include an additional DataType for timestamps with a timezone wildcard, improving flexibility in timestamp operations.

* Refactor date_part function: remove redundant imports and add missing not_impl_err import for better error handling

---------

Signed-off-by: jayzhan211 <[email protected]>
Signed-off-by: Jay Zhan <[email protected]>
  • Loading branch information
jayzhan211 authored Dec 14, 2024
1 parent 68ead28 commit 08d3b65
Show file tree
Hide file tree
Showing 8 changed files with 187 additions and 108 deletions.
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 27 additions & 0 deletions datafusion/common/src/types/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ impl LogicalType for NativeType {
(Self::FixedSizeBinary(size), _) => FixedSizeBinary(*size),
(Self::String, LargeBinary) => LargeUtf8,
(Self::String, BinaryView) => Utf8View,
// We don't cast to another kind of string type if the origin one is already a string type
(Self::String, Utf8 | LargeUtf8 | Utf8View) => origin.to_owned(),
(Self::String, data_type) if can_cast_types(data_type, &Utf8View) => Utf8View,
(Self::String, data_type) if can_cast_types(data_type, &LargeUtf8) => {
LargeUtf8
Expand Down Expand Up @@ -433,4 +435,29 @@ impl NativeType {
UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64
)
}

#[inline]
pub fn is_timestamp(&self) -> bool {
matches!(self, NativeType::Timestamp(_, _))
}

#[inline]
pub fn is_date(&self) -> bool {
matches!(self, NativeType::Date)
}

#[inline]
pub fn is_time(&self) -> bool {
matches!(self, NativeType::Time(_))
}

#[inline]
pub fn is_interval(&self) -> bool {
matches!(self, NativeType::Interval(_))
}

#[inline]
pub fn is_duration(&self) -> bool {
matches!(self, NativeType::Duration(_))
}
}
73 changes: 65 additions & 8 deletions datafusion/expr-common/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
//! Signature module contains foundational types that are used to represent signatures, types,
//! and return types of functions in DataFusion.
use std::fmt::Display;

use crate::type_coercion::aggregates::NUMERICS;
use arrow::datatypes::DataType;
use arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
use datafusion_common::types::{LogicalTypeRef, NativeType};
use itertools::Itertools;

Expand Down Expand Up @@ -112,7 +114,7 @@ pub enum TypeSignature {
/// For example, `Coercible(vec![logical_float64()])` accepts
/// arguments like `vec![DataType::Int32]` or `vec![DataType::Float32]`
/// since i32 and f32 can be casted to f64
Coercible(Vec<LogicalTypeRef>),
Coercible(Vec<TypeSignatureClass>),
/// The arguments will be coerced to a single type based on the comparison rules.
/// For example, i32 and i64 has coerced type Int64.
///
Expand Down Expand Up @@ -154,6 +156,33 @@ impl TypeSignature {
}
}

/// Represents the class of types that can be used in a function signature.
///
/// This is used to specify what types are valid for function arguments in a more flexible way than
/// just listing specific DataTypes. For example, TypeSignatureClass::Timestamp matches any timestamp
/// type regardless of timezone or precision.
///
/// Used primarily with TypeSignature::Coercible to define function signatures that can accept
/// arguments that can be coerced to a particular class of types.
#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Hash)]
pub enum TypeSignatureClass {
Timestamp,
Date,
Time,
Interval,
Duration,
Native(LogicalTypeRef),
// TODO:
// Numeric
// Integer
}

impl Display for TypeSignatureClass {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "TypeSignatureClass::{self:?}")
}
}

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum ArrayFunctionSignature {
/// Specialized Signature for ArrayAppend and similar functions
Expand All @@ -180,7 +209,7 @@ pub enum ArrayFunctionSignature {
MapArray,
}

impl std::fmt::Display for ArrayFunctionSignature {
impl Display for ArrayFunctionSignature {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ArrayFunctionSignature::ArrayAndElement => {
Expand Down Expand Up @@ -255,7 +284,7 @@ impl TypeSignature {
}

/// Helper function to join types with specified delimiter.
pub fn join_types<T: std::fmt::Display>(types: &[T], delimiter: &str) -> String {
pub fn join_types<T: Display>(types: &[T], delimiter: &str) -> String {
types
.iter()
.map(|t| t.to_string())
Expand Down Expand Up @@ -290,7 +319,30 @@ impl TypeSignature {
.collect(),
TypeSignature::Coercible(types) => types
.iter()
.map(|logical_type| get_data_types(logical_type.native()))
.map(|logical_type| match logical_type {
TypeSignatureClass::Native(l) => get_data_types(l.native()),
TypeSignatureClass::Timestamp => {
vec![
DataType::Timestamp(TimeUnit::Nanosecond, None),
DataType::Timestamp(
TimeUnit::Nanosecond,
Some(TIMEZONE_WILDCARD.into()),
),
]
}
TypeSignatureClass::Date => {
vec![DataType::Date64]
}
TypeSignatureClass::Time => {
vec![DataType::Time64(TimeUnit::Nanosecond)]
}
TypeSignatureClass::Interval => {
vec![DataType::Interval(IntervalUnit::DayTime)]
}
TypeSignatureClass::Duration => {
vec![DataType::Duration(TimeUnit::Nanosecond)]
}
})
.multi_cartesian_product()
.collect(),
TypeSignature::Variadic(types) => types
Expand Down Expand Up @@ -424,7 +476,10 @@ impl Signature {
}
}
/// Target coerce types in order
pub fn coercible(target_types: Vec<LogicalTypeRef>, volatility: Volatility) -> Self {
pub fn coercible(
target_types: Vec<TypeSignatureClass>,
volatility: Volatility,
) -> Self {
Self {
type_signature: TypeSignature::Coercible(target_types),
volatility,
Expand Down Expand Up @@ -618,8 +673,10 @@ mod tests {
]
);

let type_signature =
TypeSignature::Coercible(vec![logical_string(), logical_int64()]);
let type_signature = TypeSignature::Coercible(vec![
TypeSignatureClass::Native(logical_string()),
TypeSignatureClass::Native(logical_int64()),
]);
let possible_types = type_signature.get_possible_types();
assert_eq!(
possible_types,
Expand Down
82 changes: 58 additions & 24 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,18 @@ use arrow::{
datatypes::{DataType, TimeUnit},
};
use datafusion_common::{
exec_err, internal_datafusion_err, internal_err, plan_err,
exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err,
types::{LogicalType, NativeType},
utils::{coerced_fixed_size_list_to_list, list_ndims},
Result,
};
use datafusion_expr_common::{
signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD},
type_coercion::binary::{comparison_coercion_numeric, string_coercion},
signature::{
ArrayFunctionSignature, TypeSignatureClass, FIXED_SIZE_LIST_WILDCARD,
TIMEZONE_WILDCARD,
},
type_coercion::binary::comparison_coercion_numeric,
type_coercion::binary::string_coercion,
};
use std::sync::Arc;

Expand Down Expand Up @@ -568,35 +572,65 @@ fn get_valid_types(
// Make sure the corresponding test is covered
// If this function becomes COMPLEX, create another new signature!
fn can_coerce_to(
logical_type: &NativeType,
target_type: &NativeType,
) -> bool {
if logical_type == target_type {
return true;
}
current_type: &DataType,
target_type_class: &TypeSignatureClass,
) -> Result<DataType> {
let logical_type: NativeType = current_type.into();

if logical_type == &NativeType::Null {
return true;
}
match target_type_class {
TypeSignatureClass::Native(native_type) => {
let target_type = native_type.native();
if &logical_type == target_type {
return target_type.default_cast_for(current_type);
}

if target_type.is_integer() && logical_type.is_integer() {
return true;
}
if logical_type == NativeType::Null {
return target_type.default_cast_for(current_type);
}

if target_type.is_integer() && logical_type.is_integer() {
return target_type.default_cast_for(current_type);
}

false
internal_err!(
"Expect {} but received {}",
target_type_class,
current_type
)
}
// Not consistent with Postgres and DuckDB but to avoid regression we implicit cast string to timestamp
TypeSignatureClass::Timestamp
if logical_type == NativeType::String =>
{
Ok(DataType::Timestamp(TimeUnit::Nanosecond, None))
}
TypeSignatureClass::Timestamp if logical_type.is_timestamp() => {
Ok(current_type.to_owned())
}
TypeSignatureClass::Date if logical_type.is_date() => {
Ok(current_type.to_owned())
}
TypeSignatureClass::Time if logical_type.is_time() => {
Ok(current_type.to_owned())
}
TypeSignatureClass::Interval if logical_type.is_interval() => {
Ok(current_type.to_owned())
}
TypeSignatureClass::Duration if logical_type.is_duration() => {
Ok(current_type.to_owned())
}
_ => {
not_impl_err!("Got logical_type: {logical_type} with target_type_class: {target_type_class}")
}
}
}

let mut new_types = Vec::with_capacity(current_types.len());
for (current_type, target_type) in
for (current_type, target_type_class) in
current_types.iter().zip(target_types.iter())
{
let logical_type: NativeType = current_type.into();
let target_logical_type = target_type.native();
if can_coerce_to(&logical_type, target_logical_type) {
let target_type =
target_logical_type.default_cast_for(current_type)?;
new_types.push(target_type);
}
let target_type = can_coerce_to(current_type, target_type_class)?;
new_types.push(target_type);
}

vec![new_types]
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ datafusion-common = { workspace = true }
datafusion-doc = { workspace = true }
datafusion-execution = { workspace = true }
datafusion-expr = { workspace = true }
datafusion-expr-common = { workspace = true }
datafusion-macros = { workspace = true }
hashbrown = { workspace = true, optional = true }
hex = { version = "0.4", optional = true }
Expand Down
Loading

0 comments on commit 08d3b65

Please sign in to comment.