From e288828915a03ef48cc6fab42f019365bd36dec3 Mon Sep 17 00:00:00 2001 From: Michael Levin Date: Thu, 12 Dec 2024 00:17:11 -0800 Subject: [PATCH 1/2] Support binary temporal arithmetic with integers --- .../expr-common/src/type_coercion/binary.rs | 81 +++++++++++++++++++ datafusion/expr/src/expr_schema.rs | 60 ++++++++++++++ .../expr/src/type_coercion/functions.rs | 32 ++++++++ 3 files changed, 173 insertions(+) diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index 7a6e9841e22c..9359ac0b2b2b 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -186,6 +186,12 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result } else if let Some(numeric) = mathematics_numerical_coercion(lhs, rhs) { // Numeric arithmetic, e.g. Int32 + Int32 Ok(Signature::uniform(numeric)) + } else if let Some((new_lhs, new_rhs, ret)) = resolve_ints_to_intervals(lhs, rhs) { + Ok(Signature { + lhs: new_lhs, + rhs: new_rhs, + ret, + }) } else { plan_err!( "Cannot coerce arithmetic expression {lhs} {op} {rhs} to valid types" @@ -1449,6 +1455,22 @@ fn null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { } } +/// Resolves integer types to interval types for temporal arithmetic +fn resolve_ints_to_intervals( + lhs: &DataType, + rhs: &DataType, +) -> Option<(DataType, DataType, DataType)> { + use arrow::datatypes::DataType::*; + use arrow::datatypes::IntervalUnit::*; + + match (lhs, rhs) { + // Handle integer + interval cases + (Int32 | Int64, _) => Some((Interval(DayTime), rhs.clone(), rhs.clone())), + (_, Int32 | Int64) => Some((lhs.clone(), Interval(DayTime), lhs.clone())), + _ => None, + } +} + #[cfg(test)] mod tests { use super::*; @@ -1607,6 +1629,18 @@ mod tests { }}; } + /// Test coercion rules for assymetric binary operators + /// + /// Applies coercion rules for `$LHS_TYPE $OP $RHS_TYPE` and asserts that the + /// the result types are `$RESULT_TYPE1` and `$RESULT_TYPE2` respectively + macro_rules! test_coercion_assymetric_binary_rule { + ($LHS_TYPE:expr, $RHS_TYPE:expr, $OP:expr, $RESULT_TYPE1:expr, $RESULT_TYPE2:expr) => {{ + let (lhs, rhs) = get_input_types(&$LHS_TYPE, &$OP, &$RHS_TYPE)?; + assert_eq!(lhs, $RESULT_TYPE1); + assert_eq!(rhs, $RESULT_TYPE2); + }}; + } + /// Test coercion rules for like /// /// Applies coercion rules for both @@ -1837,6 +1871,8 @@ mod tests { #[test] fn test_type_coercion_arithmetic() -> Result<()> { + use arrow::datatypes::IntervalUnit; + // integer test_coercion_binary_rule!( DataType::Int32, @@ -1869,6 +1905,51 @@ mod tests { Operator::Multiply, DataType::Float64 ); + + // Test integer to interval coercion for temporal arithmetic + test_coercion_assymetric_binary_rule!( + DataType::Int32, + DataType::Timestamp(TimeUnit::Nanosecond, None), + Operator::Plus, + DataType::Interval(IntervalUnit::DayTime), + DataType::Timestamp(TimeUnit::Nanosecond, None) + ); + test_coercion_assymetric_binary_rule!( + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Int64, + Operator::Plus, + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Interval(IntervalUnit::DayTime) + ); + test_coercion_assymetric_binary_rule!( + DataType::Int32, + DataType::Date32, + Operator::Plus, + DataType::Interval(IntervalUnit::DayTime), + DataType::Date32 + ); + test_coercion_assymetric_binary_rule!( + DataType::Date32, + DataType::Int64, + Operator::Plus, + DataType::Date32, + DataType::Interval(IntervalUnit::DayTime) + ); + test_coercion_assymetric_binary_rule!( + DataType::Int32, + DataType::Date64, + Operator::Plus, + DataType::Interval(IntervalUnit::DayTime), + DataType::Date64 + ); + test_coercion_assymetric_binary_rule!( + DataType::Date64, + DataType::Int64, + Operator::Plus, + DataType::Date64, + DataType::Interval(IntervalUnit::DayTime) + ); + // TODO add other data type Ok(()) } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 3317deafbd6c..3507c00ebfee 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -437,6 +437,8 @@ impl ExprSchemable for Expr { /// This function errors when it is impossible to cast the /// expression to the target [arrow::datatypes::DataType]. fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result { + use datafusion_common::ScalarValue; + let this_type = self.get_type(schema)?; if this_type == *cast_to_type { return Ok(self); @@ -453,6 +455,26 @@ impl ExprSchemable for Expr { } _ => Ok(Expr::Cast(Cast::new(Box::new(self), cast_to_type.clone()))), } + } else if matches!( + (&this_type, cast_to_type), + (DataType::Int32 | DataType::Int64, DataType::Interval(_)) + ) { + // Convert integer (days) to the corresponding DayTime interval + match self { + Expr::Literal(ScalarValue::Int32(Some(days))) => { + Ok(Expr::Literal(ScalarValue::IntervalDayTime(Some( + arrow_buffer::IntervalDayTime::new(days, 0), + )))) + } + Expr::Literal(ScalarValue::Int64(Some(days))) => { + Ok(Expr::Literal(ScalarValue::IntervalDayTime(Some( + arrow_buffer::IntervalDayTime::new(days as i32, 0), + )))) + } + _ => plan_err!( + "Cannot automatically convert {this_type:?} to {cast_to_type:?}" + ), + } } else { plan_err!("Cannot automatically convert {this_type:?} to {cast_to_type:?}") } @@ -761,4 +783,42 @@ mod tests { Ok((self.data_type(col)?, self.nullable(col)?)) } } + + #[test] + fn test_cast_int_to_interval() -> Result<()> { + use arrow::datatypes::IntervalUnit; + + let schema = MockExprSchema::new().with_data_type(DataType::Int32); + + // Test casting Int32 literal to Interval + let expr = lit(ScalarValue::Int32(Some(5))); + let result = expr.cast_to(&DataType::Interval(IntervalUnit::DayTime), &schema)?; + assert_eq!( + result, + Expr::Literal(ScalarValue::IntervalDayTime(Some( + arrow_buffer::IntervalDayTime::new(5, 0) + ))) + ); + + // Test casting Int64 literal to Interval + let expr = lit(ScalarValue::Int64(Some(7))); + let result = expr.cast_to(&DataType::Interval(IntervalUnit::DayTime), &schema)?; + assert_eq!( + result, + Expr::Literal(ScalarValue::IntervalDayTime(Some( + arrow_buffer::IntervalDayTime::new(7, 0) + ))) + ); + + // Test that non-literal expressions cannot be cast from int to interval + let expr = col("foo") + lit(1); + let err = expr + .cast_to(&DataType::Interval(IntervalUnit::DayTime), &schema) + .unwrap_err(); + assert!(err + .to_string() + .contains("Cannot automatically convert Int32 to Interval(DayTime)")); + + Ok(()) + } } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 9d15d9693992..ecc4a3dcaf00 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -880,6 +880,9 @@ fn coerced_from<'a>( (Timestamp(_, Some(_)), Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8) => { Some(type_into.clone()) } + // Support Date?? + Int?? + (Date32, Int32 | Int64) | (Int32 | Int64, Date32) => Some(Date32), + (Date64, Int32 | Int64) | (Int32 | Int64, Date64) => Some(Date64), _ => None, } } @@ -1072,4 +1075,33 @@ mod tests { Some(type_into.clone()) ); } + + #[test] + fn test_date_coercion_return_values() { + let test_cases = vec![ + // Date32 cases - should return Date32 when coercion is possible + (DataType::Date32, DataType::Int32, Some(DataType::Date32)), + (DataType::Date32, DataType::Int64, Some(DataType::Date32)), + (DataType::Int32, DataType::Date32, Some(DataType::Date32)), + (DataType::Int64, DataType::Date32, Some(DataType::Date32)), + // Date64 cases - should return Date64 when coercion is possible + (DataType::Date64, DataType::Int32, Some(DataType::Date64)), + (DataType::Date64, DataType::Int64, Some(DataType::Date64)), + (DataType::Int32, DataType::Date64, Some(DataType::Date64)), + (DataType::Int64, DataType::Date64, Some(DataType::Date64)), + // Negative cases - should return None when coercion is not possible + (DataType::Date32, DataType::Int16, None), + (DataType::Date64, DataType::Int16, None), + (DataType::Int16, DataType::Date32, None), + (DataType::Int16, DataType::Date64, None), + ]; + + for (type_into, type_from, expected) in test_cases { + assert_eq!( + coerced_from(&type_into, &type_from), + expected, + "Coercion from {type_from:?} to {type_into:?} should return {expected:?}" + ); + } + } } From 1d604769f415057934827f62c6f7245433de7221 Mon Sep 17 00:00:00 2001 From: Michael Levin Date: Thu, 12 Dec 2024 00:56:47 -0800 Subject: [PATCH 2/2] PR feedback: cleaning up tests; adding type restriction --- .../expr-common/src/type_coercion/binary.rs | 37 +++++++------------ 1 file changed, 14 insertions(+), 23 deletions(-) diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index 9359ac0b2b2b..0f774286fac9 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -1464,9 +1464,13 @@ fn resolve_ints_to_intervals( use arrow::datatypes::IntervalUnit::*; match (lhs, rhs) { - // Handle integer + interval cases - (Int32 | Int64, _) => Some((Interval(DayTime), rhs.clone(), rhs.clone())), - (_, Int32 | Int64) => Some((lhs.clone(), Interval(DayTime), lhs.clone())), + // Handle integer + temporal types cases + (Int32 | Int64, rhs) if rhs.is_temporal() => { + Some((Interval(DayTime), rhs.clone(), rhs.clone())) + } + (lhs, Int32 | Int64) if lhs.is_temporal() => { + Some((lhs.clone(), Interval(DayTime), lhs.clone())) + } _ => None, } } @@ -1907,20 +1911,7 @@ mod tests { ); // Test integer to interval coercion for temporal arithmetic - test_coercion_assymetric_binary_rule!( - DataType::Int32, - DataType::Timestamp(TimeUnit::Nanosecond, None), - Operator::Plus, - DataType::Interval(IntervalUnit::DayTime), - DataType::Timestamp(TimeUnit::Nanosecond, None) - ); - test_coercion_assymetric_binary_rule!( - DataType::Timestamp(TimeUnit::Nanosecond, None), - DataType::Int64, - Operator::Plus, - DataType::Timestamp(TimeUnit::Nanosecond, None), - DataType::Interval(IntervalUnit::DayTime) - ); + // (Using Date32 only since the logic is invariant wrt the temporal type) test_coercion_assymetric_binary_rule!( DataType::Int32, DataType::Date32, @@ -1930,23 +1921,23 @@ mod tests { ); test_coercion_assymetric_binary_rule!( DataType::Date32, - DataType::Int64, + DataType::Int32, Operator::Plus, DataType::Date32, DataType::Interval(IntervalUnit::DayTime) ); test_coercion_assymetric_binary_rule!( - DataType::Int32, - DataType::Date64, + DataType::Int64, + DataType::Date32, Operator::Plus, DataType::Interval(IntervalUnit::DayTime), - DataType::Date64 + DataType::Date32 ); test_coercion_assymetric_binary_rule!( - DataType::Date64, + DataType::Date32, DataType::Int64, Operator::Plus, - DataType::Date64, + DataType::Date32, DataType::Interval(IntervalUnit::DayTime) );