From 97b26f18230758eb96b19669bd058eb677336f3a Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Fri, 17 Mar 2023 17:37:21 -0400 Subject: [PATCH] All timestamp UDFs should accept nanoseconds / Date32 / Date64 (#257) * All timestamp UDFs should accept nanoseconds * Date32 and Date64 should be allowed everywhere timestamps are * Allow casting to timestamp in duckdb and postgres * More Date32/Date64 support * Cast to timestamp for other dialects --- .../src/udfs/datetime/date_add_tz.rs | 31 ++++++++++++--- .../src/udfs/datetime/date_part_tz.rs | 29 ++++++++++++-- .../src/udfs/datetime/date_trunc_tz.rs | 26 +++++++++++-- .../src/udfs/datetime/format_timestamp.rs | 2 + .../src/udfs/datetime/from_utc_timestamp.rs | 2 + .../src/udfs/datetime/to_utc_timestamp.rs | 4 ++ .../udfs/datetime/utc_timestamp_to_epoch.rs | 12 +++++- .../src/udfs/datetime/utc_timestamp_to_str.rs | 2 + .../date_time/date_format.rs | 11 +++++- .../builtin_functions/date_time/date_parts.rs | 12 +++++- .../builtin_functions/date_time/time.rs | 2 +- vegafusion-sql/src/dialect/mod.rs | 38 ++++++++++++++++++- 12 files changed, 151 insertions(+), 20 deletions(-) diff --git a/vegafusion-datafusion-udfs/src/udfs/datetime/date_add_tz.rs b/vegafusion-datafusion-udfs/src/udfs/datetime/date_add_tz.rs index 14f54bcd6..390c292c0 100644 --- a/vegafusion-datafusion-udfs/src/udfs/datetime/date_add_tz.rs +++ b/vegafusion-datafusion-udfs/src/udfs/datetime/date_add_tz.rs @@ -1,4 +1,5 @@ use std::sync::Arc; +use vegafusion_common::datafusion_expr::TypeSignature; use vegafusion_common::{ arrow::datatypes::{DataType, TimeUnit}, datafusion_expr::{ @@ -15,12 +16,32 @@ fn make_date_add_tz_udf() -> ScalarUDF { let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(DataType::Timestamp(TimeUnit::Millisecond, None)))); - let signature = Signature::exact( + let signature = Signature::one_of( vec![ - DataType::Utf8, - DataType::Int32, - DataType::Timestamp(TimeUnit::Millisecond, None), - DataType::Utf8, + 
TypeSignature::Exact(vec![ + DataType::Utf8, + DataType::Int32, + DataType::Date32, + DataType::Utf8, + ]), + TypeSignature::Exact(vec![ + DataType::Utf8, + DataType::Int32, + DataType::Date64, + DataType::Utf8, + ]), + TypeSignature::Exact(vec![ + DataType::Utf8, + DataType::Int32, + DataType::Timestamp(TimeUnit::Millisecond, None), + DataType::Utf8, + ]), + TypeSignature::Exact(vec![ + DataType::Utf8, + DataType::Int32, + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Utf8, + ]), ], Volatility::Immutable, ); diff --git a/vegafusion-datafusion-udfs/src/udfs/datetime/date_part_tz.rs b/vegafusion-datafusion-udfs/src/udfs/datetime/date_part_tz.rs index 38f191537..d358ae35a 100644 --- a/vegafusion-datafusion-udfs/src/udfs/datetime/date_part_tz.rs +++ b/vegafusion-datafusion-udfs/src/udfs/datetime/date_part_tz.rs @@ -1,7 +1,9 @@ use crate::udfs::datetime::from_utc_timestamp::from_utc_timestamp; +use crate::udfs::datetime::to_utc_timestamp::to_timestamp_ms; use datafusion_physical_expr::datetime_expressions; use std::str::FromStr; use std::sync::Arc; +use vegafusion_common::datafusion_expr::TypeSignature; use vegafusion_common::{ arrow::datatypes::{DataType, TimeUnit}, datafusion_common::DataFusionError, @@ -19,6 +21,8 @@ fn make_date_part_tz_udf() -> ScalarUDF { ColumnarValue::Scalar(scalar) => scalar.to_array(), }; + let timestamp_array = to_timestamp_ms(&timestamp_array)?; + // [2] timezone string let tz_str = if let ColumnarValue::Scalar(default_input_tz) = &args[2] { default_input_tz.to_string() @@ -48,11 +52,28 @@ fn make_date_part_tz_udf() -> ScalarUDF { let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(DataType::Float64))); - let signature = Signature::exact( + let signature = Signature::one_of( vec![ - DataType::Utf8, // part - DataType::Timestamp(TimeUnit::Millisecond, None), - DataType::Utf8, // timezone + TypeSignature::Exact(vec![ + DataType::Utf8, // part + DataType::Date32, + DataType::Utf8, // timezone + ]), + 
TypeSignature::Exact(vec![ + DataType::Utf8, // part + DataType::Date64, + DataType::Utf8, // timezone + ]), + TypeSignature::Exact(vec![ + DataType::Utf8, // part + DataType::Timestamp(TimeUnit::Millisecond, None), + DataType::Utf8, // timezone + ]), + TypeSignature::Exact(vec![ + DataType::Utf8, // part + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Utf8, // timezone + ]), ], Volatility::Immutable, ); diff --git a/vegafusion-datafusion-udfs/src/udfs/datetime/date_trunc_tz.rs b/vegafusion-datafusion-udfs/src/udfs/datetime/date_trunc_tz.rs index 097966d72..3314ce270 100644 --- a/vegafusion-datafusion-udfs/src/udfs/datetime/date_trunc_tz.rs +++ b/vegafusion-datafusion-udfs/src/udfs/datetime/date_trunc_tz.rs @@ -1,4 +1,5 @@ use std::sync::Arc; +use vegafusion_common::datafusion_expr::TypeSignature; use vegafusion_common::{ arrow::datatypes::{DataType, TimeUnit}, datafusion_expr::{ @@ -15,11 +16,28 @@ fn make_date_trunc_tz_udf() -> ScalarUDF { let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(DataType::Timestamp(TimeUnit::Millisecond, None)))); - let signature = Signature::exact( + let signature = Signature::one_of( vec![ - DataType::Utf8, // part - DataType::Timestamp(TimeUnit::Millisecond, None), - DataType::Utf8, // timezone + TypeSignature::Exact(vec![ + DataType::Utf8, // part + DataType::Date32, + DataType::Utf8, // timezone + ]), + TypeSignature::Exact(vec![ + DataType::Utf8, // part + DataType::Date64, + DataType::Utf8, // timezone + ]), + TypeSignature::Exact(vec![ + DataType::Utf8, // part + DataType::Timestamp(TimeUnit::Millisecond, None), + DataType::Utf8, // timezone + ]), + TypeSignature::Exact(vec![ + DataType::Utf8, // part + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Utf8, // timezone + ]), ], Volatility::Immutable, ); diff --git a/vegafusion-datafusion-udfs/src/udfs/datetime/format_timestamp.rs b/vegafusion-datafusion-udfs/src/udfs/datetime/format_timestamp.rs index bc7791704..35228436b 100644 --- 
a/vegafusion-datafusion-udfs/src/udfs/datetime/format_timestamp.rs +++ b/vegafusion-datafusion-udfs/src/udfs/datetime/format_timestamp.rs @@ -73,6 +73,8 @@ fn make_format_timestamp_udf() -> ScalarUDF { let signature: Signature = Signature::one_of( vec![ + TypeSignature::Exact(vec![DataType::Date32, DataType::Utf8]), + TypeSignature::Exact(vec![DataType::Date64, DataType::Utf8]), TypeSignature::Exact(vec![ DataType::Timestamp(TimeUnit::Millisecond, None), DataType::Utf8, diff --git a/vegafusion-datafusion-udfs/src/udfs/datetime/from_utc_timestamp.rs b/vegafusion-datafusion-udfs/src/udfs/datetime/from_utc_timestamp.rs index 9381c1d7b..190abe336 100644 --- a/vegafusion-datafusion-udfs/src/udfs/datetime/from_utc_timestamp.rs +++ b/vegafusion-datafusion-udfs/src/udfs/datetime/from_utc_timestamp.rs @@ -53,6 +53,8 @@ fn make_from_utc_timestamp() -> ScalarUDF { let signature: Signature = Signature::one_of( vec![ + TypeSignature::Exact(vec![DataType::Date32, DataType::Utf8]), + TypeSignature::Exact(vec![DataType::Date64, DataType::Utf8]), TypeSignature::Exact(vec![ DataType::Timestamp(TimeUnit::Millisecond, None), DataType::Utf8, diff --git a/vegafusion-datafusion-udfs/src/udfs/datetime/to_utc_timestamp.rs b/vegafusion-datafusion-udfs/src/udfs/datetime/to_utc_timestamp.rs index 9c584c450..f665893c2 100644 --- a/vegafusion-datafusion-udfs/src/udfs/datetime/to_utc_timestamp.rs +++ b/vegafusion-datafusion-udfs/src/udfs/datetime/to_utc_timestamp.rs @@ -115,6 +115,10 @@ pub fn to_timestamp_ms(array: &ArrayRef) -> Result { )?) 
} } + DataType::Date32 => Ok(cast( + array, + &DataType::Timestamp(TimeUnit::Millisecond, None), + )?), DataType::Date64 => Ok(cast( array, &DataType::Timestamp(TimeUnit::Millisecond, None), diff --git a/vegafusion-datafusion-udfs/src/udfs/datetime/utc_timestamp_to_epoch.rs b/vegafusion-datafusion-udfs/src/udfs/datetime/utc_timestamp_to_epoch.rs index 42b39aaa7..01ef27608 100644 --- a/vegafusion-datafusion-udfs/src/udfs/datetime/utc_timestamp_to_epoch.rs +++ b/vegafusion-datafusion-udfs/src/udfs/datetime/utc_timestamp_to_epoch.rs @@ -1,4 +1,6 @@ +use crate::udfs::datetime::to_utc_timestamp::to_timestamp_ms; use std::sync::Arc; +use vegafusion_common::datafusion_expr::TypeSignature; use vegafusion_common::{ arrow::{ compute::cast, @@ -18,6 +20,7 @@ fn make_utc_timestamp_to_epoch_ms_udf() -> ScalarUDF { ColumnarValue::Array(array) => array.clone(), ColumnarValue::Scalar(scalar) => scalar.to_array(), }; + let data_array = to_timestamp_ms(&data_array)?; // cast timestamp millis to Int64 let result_array = cast(&data_array, &DataType::Int64)?; @@ -31,8 +34,13 @@ fn make_utc_timestamp_to_epoch_ms_udf() -> ScalarUDF { }); let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(DataType::Int64))); - let signature: Signature = Signature::exact( - vec![DataType::Timestamp(TimeUnit::Millisecond, None)], + let signature = Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Date32]), + TypeSignature::Exact(vec![DataType::Date64]), + TypeSignature::Exact(vec![DataType::Timestamp(TimeUnit::Millisecond, None)]), + TypeSignature::Exact(vec![DataType::Timestamp(TimeUnit::Nanosecond, None)]), + ], Volatility::Immutable, ); diff --git a/vegafusion-datafusion-udfs/src/udfs/datetime/utc_timestamp_to_str.rs b/vegafusion-datafusion-udfs/src/udfs/datetime/utc_timestamp_to_str.rs index 627b1c846..12c4f6ec3 100644 --- a/vegafusion-datafusion-udfs/src/udfs/datetime/utc_timestamp_to_str.rs +++ b/vegafusion-datafusion-udfs/src/udfs/datetime/utc_timestamp_to_str.rs @@ 
-80,6 +80,8 @@ fn make_utc_timestamp_to_str_udf() -> ScalarUDF { let signature = Signature::one_of( vec![ + TypeSignature::Exact(vec![DataType::Date32, DataType::Utf8]), + TypeSignature::Exact(vec![DataType::Date64, DataType::Utf8]), TypeSignature::Exact(vec![ DataType::Timestamp(TimeUnit::Millisecond, None), DataType::Utf8, diff --git a/vegafusion-runtime/src/expression/compiler/builtin_functions/date_time/date_format.rs b/vegafusion-runtime/src/expression/compiler/builtin_functions/date_time/date_format.rs index 3df4b1203..d0fb8ef26 100644 --- a/vegafusion-runtime/src/expression/compiler/builtin_functions/date_time/date_format.rs +++ b/vegafusion-runtime/src/expression/compiler/builtin_functions/date_time/date_format.rs @@ -1,9 +1,10 @@ use crate::task_graph::timezone::RuntimeTzConfig; -use datafusion_expr::{lit, Expr, ExprSchemable}; +use datafusion_expr::{expr, lit, Expr, ExprSchemable}; use std::sync::Arc; use vegafusion_common::arrow::datatypes::DataType; use vegafusion_common::datafusion_common::{DFSchema, ScalarValue}; use vegafusion_common::datatypes::{cast_to, is_numeric_datatype}; +use vegafusion_core::arrow::datatypes::TimeUnit; use vegafusion_core::error::{Result, VegaFusionError}; use vegafusion_datafusion_udfs::udfs::datetime::epoch_to_utc_timestamp::EPOCH_MS_TO_UTC_TIMESTAMP_UDF; use vegafusion_datafusion_udfs::udfs::datetime::format_timestamp::FORMAT_TIMESTAMP_UDF; @@ -91,6 +92,14 @@ pub fn utc_format_fn( fn to_timestamptz_expr(arg: &Expr, schema: &DFSchema, default_input_tz: &str) -> Result<Expr> { Ok(match arg.get_type(schema)? 
{ + DataType::Date32 => Expr::Cast(expr::Cast { + expr: Box::new(arg.clone()), + data_type: DataType::Timestamp(TimeUnit::Millisecond, None), + }), + DataType::Date64 => Expr::Cast(expr::Cast { + expr: Box::new(arg.clone()), + data_type: DataType::Timestamp(TimeUnit::Millisecond, None), + }), DataType::Timestamp(_, _) => arg.clone(), DataType::Utf8 => Expr::ScalarUDF { fun: Arc::new((*STR_TO_UTC_TIMESTAMP_UDF).clone()), diff --git a/vegafusion-runtime/src/expression/compiler/builtin_functions/date_time/date_parts.rs b/vegafusion-runtime/src/expression/compiler/builtin_functions/date_time/date_parts.rs index ac085d8c6..d5e5226f0 100644 --- a/vegafusion-runtime/src/expression/compiler/builtin_functions/date_time/date_parts.rs +++ b/vegafusion-runtime/src/expression/compiler/builtin_functions/date_time/date_parts.rs @@ -1,8 +1,8 @@ use crate::expression::compiler::call::TzTransformFn; use crate::task_graph::timezone::RuntimeTzConfig; -use datafusion_expr::{floor, lit, Expr, ExprSchemable}; +use datafusion_expr::{expr, floor, lit, Expr, ExprSchemable}; use std::sync::Arc; -use vegafusion_common::arrow::datatypes::DataType; +use vegafusion_common::arrow::datatypes::{DataType, TimeUnit}; use vegafusion_common::datafusion_common::DFSchema; use vegafusion_common::datatypes::{cast_to, is_numeric_datatype}; use vegafusion_core::error::{Result, VegaFusionError}; @@ -64,6 +64,14 @@ fn extract_timestamp_arg( ) -> Result<Expr> { if let Some(arg) = args.get(0) { Ok(match arg.get_type(schema)? 
{ + DataType::Date32 => Expr::Cast(expr::Cast { + expr: Box::new(arg.clone()), + data_type: DataType::Timestamp(TimeUnit::Millisecond, None), + }), + DataType::Date64 => Expr::Cast(expr::Cast { + expr: Box::new(arg.clone()), + data_type: DataType::Timestamp(TimeUnit::Millisecond, None), + }), DataType::Timestamp(_, _) => arg.clone(), DataType::Utf8 => Expr::ScalarUDF { fun: Arc::new((*STR_TO_UTC_TIMESTAMP_UDF).clone()), diff --git a/vegafusion-runtime/src/expression/compiler/builtin_functions/date_time/time.rs b/vegafusion-runtime/src/expression/compiler/builtin_functions/date_time/time.rs index c774b8837..13a707768 100644 --- a/vegafusion-runtime/src/expression/compiler/builtin_functions/date_time/time.rs +++ b/vegafusion-runtime/src/expression/compiler/builtin_functions/date_time/time.rs @@ -22,7 +22,7 @@ pub fn time_fn(tz_config: &RuntimeTzConfig, args: &[Expr], schema: &DFSchema) -> // Dispatch handling on data type let expr = match arg.get_type(schema)? { - DataType::Timestamp(_, _) => Expr::ScalarUDF { + DataType::Timestamp(_, _) | DataType::Date32 | DataType::Date64 => Expr::ScalarUDF { fun: Arc::new((*UTC_TIMESTAMP_TO_EPOCH_MS).clone()), args: vec![arg.clone()], }, diff --git a/vegafusion-sql/src/dialect/mod.rs b/vegafusion-sql/src/dialect/mod.rs index a075ed287..4654bbff7 100644 --- a/vegafusion-sql/src/dialect/mod.rs +++ b/vegafusion-sql/src/dialect/mod.rs @@ -41,7 +41,7 @@ use crate::dialect::transforms::utc_timestamp_to_str::{ UtcTimestampToStrDuckDBTransformer, UtcTimestampToStrPostgresTransformer, UtcTimestampToStrSnowflakeTransformer, }; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::scalar::ScalarValue; use datafusion_common::DFSchema; use datafusion_expr::{lit, ExprSchemable}; @@ -329,6 +329,10 @@ impl Dialect { (DataType::Float32, SqlDataType::Double), (DataType::Float64, SqlDataType::Double), (DataType::Utf8, SqlDataType::Varchar(None)), + ( + DataType::Timestamp(TimeUnit::Millisecond, None), 
+ SqlDataType::Timestamp(None, TimezoneInfo::None), + ), ] .into_iter() .collect(), @@ -450,6 +454,10 @@ impl Dialect { (DataType::Float32, float64dtype.clone()), (DataType::Float64, float64dtype), (DataType::Utf8, SqlDataType::String), + ( + DataType::Timestamp(TimeUnit::Millisecond, None), + SqlDataType::Timestamp(None, TimezoneInfo::None), + ), ] .into_iter() .collect(), @@ -544,6 +552,10 @@ impl Dialect { (DataType::Float32, SqlDataType::Float(None)), (DataType::Float64, SqlDataType::Double), (DataType::Utf8, SqlDataType::Varchar(None)), + ( + DataType::Timestamp(TimeUnit::Millisecond, None), + SqlDataType::Timestamp(None, TimezoneInfo::None), + ), ] .into_iter() .collect(), @@ -681,6 +693,10 @@ impl Dialect { (DataType::Float32, SqlDataType::Float(None)), (DataType::Float64, SqlDataType::Double), (DataType::Utf8, SqlDataType::String), + ( + DataType::Timestamp(TimeUnit::Millisecond, None), + SqlDataType::Timestamp(None, TimezoneInfo::None), + ), ] .into_iter() .collect(), @@ -820,6 +836,10 @@ impl Dialect { (DataType::Float32, SqlDataType::Float(None)), (DataType::Float64, SqlDataType::Double), (DataType::Utf8, SqlDataType::String), + ( + DataType::Timestamp(TimeUnit::Millisecond, None), + SqlDataType::Timestamp(None, TimezoneInfo::None), + ), ] .into_iter() .collect(), @@ -969,6 +989,10 @@ impl Dialect { (DataType::Float32, SqlDataType::Float(None)), (DataType::Float64, SqlDataType::Double), (DataType::Utf8, SqlDataType::Varchar(None)), + ( + DataType::Timestamp(TimeUnit::Millisecond, None), + SqlDataType::Timestamp(None, TimezoneInfo::None), + ), ] .into_iter() .collect(), @@ -1213,6 +1237,10 @@ impl Dialect { (DataType::Float32, SqlDataType::Real), (DataType::Float64, SqlDataType::DoublePrecision), (DataType::Utf8, SqlDataType::Text), + ( + DataType::Timestamp(TimeUnit::Millisecond, None), + SqlDataType::Timestamp(None, TimezoneInfo::None), + ), ] .into_iter() .collect(), @@ -1339,6 +1367,10 @@ impl Dialect { (DataType::Float32, SqlDataType::Real), 
(DataType::Float64, SqlDataType::DoublePrecision), (DataType::Utf8, SqlDataType::Text), + ( + DataType::Timestamp(TimeUnit::Millisecond, None), + SqlDataType::Timestamp(None, TimezoneInfo::None), + ), ] .into_iter() .collect(), @@ -1475,6 +1507,10 @@ impl Dialect { (DataType::Float32, SqlDataType::Float(None)), (DataType::Float64, SqlDataType::Double), (DataType::Utf8, SqlDataType::Varchar(None)), + ( + DataType::Timestamp(TimeUnit::Millisecond, None), + SqlDataType::Timestamp(None, TimezoneInfo::None), + ), ] .into_iter() .collect(),