diff --git a/datafusion/functions-nested/src/range.rs b/datafusion/functions-nested/src/range.rs index 5b7315719631..90cf8bcbd057 100644 --- a/datafusion/functions-nested/src/range.rs +++ b/datafusion/functions-nested/src/range.rs @@ -23,13 +23,12 @@ use arrow::datatypes::{DataType, Field}; use arrow_array::types::{Date32Type, IntervalMonthDayNanoType}; use arrow_array::NullArray; use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer}; -use arrow_schema::DataType::{Date32, Int64, Interval, List}; +use arrow_schema::DataType::*; use arrow_schema::IntervalUnit::MonthDayNano; use datafusion_common::cast::{as_date32_array, as_int64_array, as_interval_mdn_array}; use datafusion_common::{exec_err, not_impl_datafusion_err, Result}; -use datafusion_expr::{ - ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, Volatility, -}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use itertools::Itertools; use std::any::Any; use std::iter::from_fn; use std::sync::Arc; @@ -49,16 +48,7 @@ pub(super) struct Range { impl Range { pub fn new() -> Self { Self { - signature: Signature::one_of( - vec![ - TypeSignature::Exact(vec![Int64]), - TypeSignature::Exact(vec![Int64, Int64]), - TypeSignature::Exact(vec![Int64, Int64, Int64]), - TypeSignature::Exact(vec![Date32, Date32, Interval(MonthDayNano)]), - TypeSignature::Any(3), - ], - Volatility::Immutable, - ), + signature: Signature::user_defined(Volatility::Immutable), aliases: vec![], } } @@ -75,9 +65,34 @@ impl ScalarUDFImpl for Range { &self.signature } + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + arg_types + .iter() + .map(|arg_type| match arg_type { + Null => Ok(Null), + Int8 => Ok(Int64), + Int16 => Ok(Int64), + Int32 => Ok(Int64), + Int64 => Ok(Int64), + UInt8 => Ok(Int64), + UInt16 => Ok(Int64), + UInt32 => Ok(Int64), + UInt64 => Ok(Int64), + Timestamp(_, _) => Ok(Date32), + Date32 => Ok(Date32), + Date64 => Ok(Date32), + Utf8 => Ok(Date32), + LargeUtf8 => Ok(Date32), + Utf8View => Ok(Date32), + Interval(_) => Ok(Interval(MonthDayNano)), + _ => exec_err!("Unsupported DataType"), + }) + .try_collect() + } + fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types.iter().any(|t| t.eq(&DataType::Null)) { - Ok(DataType::Null) + if arg_types.iter().any(|t| t.is_null()) { + Ok(Null) } else { Ok(List(Arc::new(Field::new( "item", @@ -88,7 +103,7 @@ impl ScalarUDFImpl for Range { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - if args.iter().any(|arg| arg.data_type() == DataType::Null) { + if args.iter().any(|arg| arg.data_type().is_null()) { return Ok(ColumnarValue::Array(Arc::new(NullArray::new(1)))); } match args[0].data_type() { @@ -120,16 +135,7 @@ pub(super) struct GenSeries { impl GenSeries { pub fn new() -> Self { Self { - signature: Signature::one_of( - vec![ - TypeSignature::Exact(vec![Int64]), - TypeSignature::Exact(vec![Int64, Int64]), - TypeSignature::Exact(vec![Int64, Int64, Int64]), - TypeSignature::Exact(vec![Date32, Date32, Interval(MonthDayNano)]), - TypeSignature::Any(3), - ], - Volatility::Immutable, - ), + signature: Signature::user_defined(Volatility::Immutable), aliases: vec![], } } @@ -146,9 +152,34 @@ impl ScalarUDFImpl for GenSeries { &self.signature } + fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { + _arg_types + .iter() + .map(|arg_type| match arg_type { + Null => Ok(Null), + Int8 => Ok(Int64), + Int16 => Ok(Int64), + Int32 => Ok(Int64), + Int64 => Ok(Int64), + UInt8 => Ok(Int64), + UInt16 => Ok(Int64), + UInt32 => Ok(Int64), + UInt64 => Ok(Int64), + Timestamp(_, _) => Ok(Date32), + Date32 => Ok(Date32), + Date64 => Ok(Date32), + Utf8 => Ok(Date32), + LargeUtf8 => Ok(Date32), + Utf8View => Ok(Date32), + Interval(_) => Ok(Interval(MonthDayNano)), + _ => exec_err!("Unsupported DataType"), + }) + .try_collect() + } + fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types.iter().any(|t| t.eq(&DataType::Null)) { - Ok(DataType::Null) + if arg_types.iter().any(|t| t.is_null()) { + Ok(Null) } else { Ok(List(Arc::new(Field::new( "item", @@ -159,7 +190,7 @@ impl ScalarUDFImpl for GenSeries { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - if args.iter().any(|arg| arg.data_type() == DataType::Null) { + if args.iter().any(|arg| arg.data_type().is_null()) { return Ok(ColumnarValue::Array(Arc::new(NullArray::new(1)))); } match args[0].data_type() { @@ -167,7 +198,7 @@ impl ScalarUDFImpl for GenSeries { Date32 => make_scalar_function(|args| gen_range_date(args, true))(args), dt => { exec_err!( - "unsupported type for range. Expected Int64 or Date32, got: {}", + "unsupported type for gen_series. Expected Int64 or Date32, got: {}", dt ) } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index b97ecced57e3..249241a51aea 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -5804,7 +5804,7 @@ select generate_series(5), ---- [0, 1, 2, 3, 4, 5] [2, 3, 4, 5] [2, 5, 8] [1, 2, 3, 4, 5] [5, 4, 3, 2, 1] [10, 7, 4] [1992-09-01, 1992-10-01, 1992-11-01, 1992-12-01, 1993-01-01, 1993-02-01, 1993-03-01] [1993-02-01, 1993-01-31, 1993-01-30, 1993-01-29, 1993-01-28, 1993-01-27, 1993-01-26, 1993-01-25, 1993-01-24, 1993-01-23, 1993-01-22, 1993-01-21, 1993-01-20, 1993-01-19, 1993-01-18, 1993-01-17, 1993-01-16, 1993-01-15, 1993-01-14, 1993-01-13, 1993-01-12, 1993-01-11, 1993-01-10, 1993-01-09, 1993-01-08, 1993-01-07, 1993-01-06, 1993-01-05, 1993-01-04, 1993-01-03, 1993-01-02, 1993-01-01] [1989-04-01, 1990-04-01, 1991-04-01, 1992-04-01] -query error DataFusion error: Execution error: unsupported type for range. Expected Int64 or Date32, got: Timestamp\(Nanosecond, None\) +query error DataFusion error: Execution error: Cannot generate date range less than 1 day\. select generate_series('2021-01-01'::timestamp, '2021-01-02'::timestamp, INTERVAL '1' HOUR); ## should return NULL @@ -5936,11 +5936,12 @@ select generate_series(start, '1993-03-01'::date, INTERVAL '1 year') from date_t # https://github.com/apache/datafusion/issues/11922 -query error +query ? select generate_series(start, '1993-03-01', INTERVAL '1 year') from date_table; ---- -DataFusion error: Internal error: could not cast value to arrow_array::array::primitive_array::PrimitiveArray. -This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker +[1992-01-01, 1993-01-01] +[1993-02-01] +[1989-04-01, 1990-04-01, 1991-04-01, 1992-04-01] ## array_except