From 7783b394bb9c830ee1466bf729dd9544964f75f5 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 12 Jul 2024 13:21:50 -0600 Subject: [PATCH] chore: Move `cast` to `spark-expr` crate (#654) * refactor in preparation for moving cast to spark-expr crate * errors * move cast to spark-expr crate * machete * refactor errors * clean up imports --- native/Cargo.lock | 35 ++- native/Cargo.toml | 4 + native/core/Cargo.toml | 7 +- .../execution/datafusion/expressions/mod.rs | 2 +- .../execution/datafusion/expressions/utils.rs | 249 +----------------- native/core/src/execution/mod.rs | 2 +- native/spark-expr/Cargo.toml | 5 + .../expressions => spark-expr/src}/cast.rs | 119 +++++++-- native/spark-expr/src/lib.rs | 1 + native/utils/Cargo.toml | 5 + native/utils/src/lib.rs | 162 ++++++++++++ .../src/execution => utils/src}/timezone.rs | 0 12 files changed, 309 insertions(+), 282 deletions(-) rename native/{core/src/execution/datafusion/expressions => spark-expr/src}/cast.rs (94%) rename native/{core/src/execution => utils/src}/timezone.rs (100%) diff --git a/native/Cargo.lock b/native/Cargo.lock index 6bba0a8e4..f64b7b634 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -482,9 +482,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.0.106" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "066fce287b1d4eafef758e89e09d724a24808a9196fe9756b8ca90e86d0719a2" +checksum = "907d8581360765417f8f2e0e7d602733bbed60156b4465b7617243689ef9b83d" dependencies = [ "jobserver", "libc", @@ -588,18 +588,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.8" +version = "4.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84b3edb18336f4df585bc9aa31dd99c036dfa5dc5e9a2939a722a188f3a8970d" +checksum = "64acc1846d54c1fe936a78dc189c34e28d3f5afc348403f28ecf53660b9b8462" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.8" +version = "4.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1c09dd5ada6c6c78075d6fd0da3f90d8080651e2d6cc8eb2f1aaa4034ced708" +checksum = "6fb8393d67ba2e7bfaf28a23458e4e2b543cc73a99595511eb207fdb8aede942" dependencies = [ "anstyle", "clap_lex", @@ -862,7 +862,6 @@ dependencies = [ "brotli", "bytes", "chrono", - "chrono-tz 0.8.6", "crc32fast", "criterion", "datafusion", @@ -908,12 +907,17 @@ name = "datafusion-comet-spark-expr" version = "0.1.0" dependencies = [ "arrow", + "arrow-array", "arrow-schema", + "chrono", "datafusion", "datafusion-comet-utils", "datafusion-common", + "datafusion-expr", "datafusion-functions", "datafusion-physical-expr", + "num", + "regex", "thiserror", ] @@ -921,6 +925,11 @@ dependencies = [ name = "datafusion-comet-utils" version = "0.1.0" dependencies = [ + "arrow", + "arrow-array", + "arrow-schema", + "chrono", + "chrono-tz 0.8.6", "datafusion-physical-plan", ] @@ -2723,18 +2732,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.61" +version = "1.0.62" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" +checksum = "f2675633b1499176c2dff06b0856a27976a8f9d436737b4cf4f312d4d91d8bbb" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.61" +version = "1.0.62" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" +checksum = "d20468752b09f49e909e55a5d338caa8bedf615594e9d80bc4c565d30faf798c" dependencies = [ "proc-macro2", "quote", @@ -2938,9 +2947,9 @@ dependencies = [ [[package]] name = "uuid" -version = "1.9.1" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5de17fd2f7da591098415cff336e12965a28061ddace43b59cb3c430179c9439" +checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" dependencies = [ "getrandom", ] diff --git a/native/Cargo.toml b/native/Cargo.toml index 0b392833a..09865742b 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -48,6 +48,10 @@ datafusion-physical-expr-common = { git = "https://github.com/apache/datafusion. datafusion-physical-expr = { git = "https://github.com/apache/datafusion.git", rev = "40.0.0", default-features = false } datafusion-comet-spark-expr = { path = "spark-expr", version = "0.1.0" } datafusion-comet-utils = { path = "utils", version = "0.1.0" } +chrono = { version = "0.4", default-features = false, features = ["clock"] } +chrono-tz = { version = "0.8" } +num = "0.4" +regex = "1.9.6" thiserror = "1" [profile.release] diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index 50c1ce2b3..8e02324c6 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -59,13 +59,12 @@ flate2 = "1.0" lz4 = "1.24" zstd = "0.11" rand = "0.8" -num = "0.4" +num = { workspace = true } bytes = "1.5.0" tempfile = "3.8.0" ahash = { version = "0.8", default-features = false } itertools = "0.11.0" -chrono = { version = "0.4", default-features = false, features = ["clock"] } -chrono-tz = { version = "0.8" } +chrono = { workspace = true } paste = "1.0.14" datafusion-common = { workspace = true } datafusion = { workspace = true } @@ -74,7 +73,7 @@ datafusion-physical-expr-common = { workspace = true } datafusion-physical-expr = { workspace = true } unicode-segmentation = "^1.10.1" once_cell = "1.18.0" -regex = "1.9.6" +regex = { workspace = true } crc32fast = "1.3.2" simd-adler32 = "0.3.7" datafusion-comet-spark-expr = { workspace = true } diff --git a/native/core/src/execution/datafusion/expressions/mod.rs b/native/core/src/execution/datafusion/expressions/mod.rs index c61266cea..f6fb26b6a 100644 --- a/native/core/src/execution/datafusion/expressions/mod.rs +++ b/native/core/src/execution/datafusion/expressions/mod.rs @@ -18,7 +18,7 @@ //! Native DataFusion expressions pub mod bitwise_not; -pub mod cast; +pub use datafusion_comet_spark_expr::cast; pub mod checkoverflow; mod normalize_nan; pub mod scalar_funcs; diff --git a/native/core/src/execution/datafusion/expressions/utils.rs b/native/core/src/execution/datafusion/expressions/utils.rs index 6a7ec2e12..04e41e0ba 100644 --- a/native/core/src/execution/datafusion/expressions/utils.rs +++ b/native/core/src/execution/datafusion/expressions/utils.rs @@ -15,250 +15,5 @@ // specific language governing permissions and limitations // under the License. -use crate::execution::timezone::Tz; -use arrow::{ - array::{ - as_dictionary_array, as_primitive_array, Array, ArrayRef, GenericStringArray, - PrimitiveArray, - }, - compute::unary, - datatypes::{Int32Type, Int64Type, TimestampMicrosecondType}, - error::ArrowError, - temporal_conversions::as_datetime, -}; -use arrow_array::{cast::AsArray, types::ArrowPrimitiveType}; -use arrow_schema::DataType; -use chrono::{DateTime, Offset, TimeZone}; -use datafusion_common::cast::as_generic_string_array; -use num::integer::div_floor; -use std::sync::Arc; - -pub use datafusion_comet_utils::down_cast_any_ref; - -/// Preprocesses input arrays to add timezone information from Spark to Arrow array datatype or -/// to apply timezone offset. -// -// We consider the following cases: -// -// | --------------------- | ------------ | ----------------- | -------------------------------- | -// | Conversion | Input array | Timezone | Output array | -// | --------------------- | ------------ | ----------------- | -------------------------------- | -// | Timestamp -> | Array in UTC | Timezone of input | A timestamp with the timezone | -// | Utf8 or Date32 | | | offset applied and timezone | -// | | | | removed | -// | --------------------- | ------------ | ----------------- | -------------------------------- | -// | Timestamp -> | Array in UTC | Timezone of input | Same as input array | -// | Timestamp w/Timezone| | | | -// | --------------------- | ------------ | ----------------- | -------------------------------- | -// | Timestamp_ntz -> | Array in | Timezone of input | Same as input array | -// | Utf8 or Date32 | timezone | | | -// | | session local| | | -// | | timezone | | | -// | --------------------- | ------------ | ----------------- | -------------------------------- | -// | Timestamp_ntz -> | Array in | Timezone of input | Array in UTC and timezone | -// | Timestamp w/Timezone | session local| | specified in input | -// | | timezone | | | -// | --------------------- | ------------ | ----------------- | -------------------------------- | -// | Timestamp(_ntz) -> | | -// | Any other type | Not Supported | -// | --------------------- | ------------ | ----------------- | -------------------------------- | -// -pub fn array_with_timezone( - array: ArrayRef, - timezone: String, - to_type: Option<&DataType>, -) -> Result { - match array.data_type() { - DataType::Timestamp(_, None) => { - assert!(!timezone.is_empty()); - match to_type { - Some(DataType::Utf8) | Some(DataType::Date32) => Ok(array), - Some(DataType::Timestamp(_, Some(_))) => { - timestamp_ntz_to_timestamp(array, timezone.as_str(), Some(timezone.as_str())) - } - _ => { - // Not supported - panic!( - "Cannot convert from {:?} to {:?}", - array.data_type(), - to_type.unwrap() - ) - } - } - } - DataType::Timestamp(_, Some(_)) => { - assert!(!timezone.is_empty()); - let array = as_primitive_array::(&array); - let array_with_timezone = array.clone().with_timezone(timezone.clone()); - let array = Arc::new(array_with_timezone) as ArrayRef; - match to_type { - Some(DataType::Utf8) | Some(DataType::Date32) => { - pre_timestamp_cast(array, timezone) - } - _ => Ok(array), - } - } - DataType::Dictionary(_, value_type) - if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) => - { - let dict = as_dictionary_array::(&array); - let array = as_primitive_array::(dict.values()); - let array_with_timezone = - array_with_timezone(Arc::new(array.clone()) as ArrayRef, timezone, to_type)?; - let dict = dict.with_values(array_with_timezone); - Ok(Arc::new(dict)) - } - _ => Ok(array), - } -} - -fn datetime_cast_err(value: i64) -> ArrowError { - ArrowError::CastError(format!( - "Cannot convert TimestampMicrosecondType {value} to datetime. Comet only supports dates between Jan 1, 262145 BCE and Dec 31, 262143 CE", - )) -} - -/// Takes in a Timestamp(Microsecond, None) array and a timezone id, and returns -/// a Timestamp(Microsecond, Some<_>) array. -/// The understanding is that the input array has time in the timezone specified in the second -/// argument. -/// Parameters: -/// array - input array of timestamp without timezone -/// tz - timezone of the values in the input array -/// to_timezone - timezone to change the input values to -fn timestamp_ntz_to_timestamp( - array: ArrayRef, - tz: &str, - to_timezone: Option<&str>, -) -> Result { - assert!(!tz.is_empty()); - match array.data_type() { - DataType::Timestamp(_, None) => { - let array = as_primitive_array::(&array); - let tz: Tz = tz.parse()?; - let array: PrimitiveArray = array.try_unary(|value| { - as_datetime::(value) - .ok_or_else(|| datetime_cast_err(value)) - .map(|local_datetime| { - let datetime: DateTime = - tz.from_local_datetime(&local_datetime).unwrap(); - datetime.timestamp_micros() - }) - })?; - let array_with_tz = if let Some(to_tz) = to_timezone { - array.with_timezone(to_tz) - } else { - array - }; - Ok(Arc::new(array_with_tz)) - } - _ => Ok(array), - } -} - -const MICROS_PER_SECOND: i64 = 1000000; - -/// This takes for special pre-casting cases of Spark. E.g., Timestamp to String. -fn pre_timestamp_cast(array: ArrayRef, timezone: String) -> Result { - assert!(!timezone.is_empty()); - match array.data_type() { - DataType::Timestamp(_, _) => { - // Spark doesn't output timezone while casting timestamp to string, but arrow's cast - // kernel does if timezone exists. So we need to apply offset of timezone to array - // timestamp value and remove timezone from array datatype. - let array = as_primitive_array::(&array); - - let tz: Tz = timezone.parse()?; - let array: PrimitiveArray = array.try_unary(|value| { - as_datetime::(value) - .ok_or_else(|| datetime_cast_err(value)) - .map(|datetime| { - let offset = tz.offset_from_utc_datetime(&datetime).fix(); - let datetime = datetime + offset; - datetime.and_utc().timestamp_micros() - }) - })?; - - Ok(Arc::new(array)) - } - _ => Ok(array), - } -} - -/// This takes for special casting cases of Spark. E.g., Timestamp to Long. -/// This function runs as a post process of the DataFusion cast(). By the time it arrives here, -/// Dictionary arrays are already unpacked by the DataFusion cast() since Spark cannot specify -/// Dictionary as to_type. The from_type is taken before the DataFusion cast() runs in -/// expressions/cast.rs, so it can be still Dictionary. -pub(crate) fn spark_cast(array: ArrayRef, from_type: &DataType, to_type: &DataType) -> ArrayRef { - match (from_type, to_type) { - (DataType::Timestamp(_, _), DataType::Int64) => { - // See Spark's `Cast` expression - unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap() - } - (DataType::Dictionary(_, value_type), DataType::Int64) - if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) => - { - // See Spark's `Cast` expression - unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap() - } - (DataType::Timestamp(_, _), DataType::Utf8) => remove_trailing_zeroes(array), - (DataType::Dictionary(_, value_type), DataType::Utf8) - if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) => - { - remove_trailing_zeroes(array) - } - _ => array, - } -} - -/// A fork & modified version of Arrow's `unary_dyn` which is being deprecated -fn unary_dyn(array: &ArrayRef, op: F) -> Result -where - T: ArrowPrimitiveType, - F: Fn(T::Native) -> T::Native, -{ - if let Some(d) = array.as_any_dictionary_opt() { - let new_values = unary_dyn::(d.values(), op)?; - return Ok(Arc::new(d.with_values(Arc::new(new_values)))); - } - - match array.as_primitive_opt::() { - Some(a) if PrimitiveArray::::is_compatible(a.data_type()) => { - Ok(Arc::new(unary::( - array.as_any().downcast_ref::>().unwrap(), - op, - ))) - } - _ => Err(ArrowError::NotYetImplemented(format!( - "Cannot perform unary operation of type {} on array of type {}", - T::DATA_TYPE, - array.data_type() - ))), - } -} - -/// Remove any trailing zeroes in the string if they occur after in the fractional seconds, -/// to match Spark behavior -/// example: -/// "1970-01-01 05:29:59.900" => "1970-01-01 05:29:59.9" -/// "1970-01-01 05:29:59.990" => "1970-01-01 05:29:59.99" -/// "1970-01-01 05:29:59.999" => "1970-01-01 05:29:59.999" -/// "1970-01-01 05:30:00" => "1970-01-01 05:30:00" -/// "1970-01-01 05:30:00.001" => "1970-01-01 05:30:00.001" -fn remove_trailing_zeroes(array: ArrayRef) -> ArrayRef { - let string_array = as_generic_string_array::(&array).unwrap(); - let result = string_array - .iter() - .map(|s| s.map(trim_end)) - .collect::>(); - Arc::new(result) as ArrayRef -} - -fn trim_end(s: &str) -> &str { - if s.rfind('.').is_some() { - s.trim_end_matches('0') - } else { - s - } -} +// re-export for legacy reasons +pub use datafusion_comet_utils::{array_with_timezone, down_cast_any_ref}; diff --git a/native/core/src/execution/mod.rs b/native/core/src/execution/mod.rs index b3be83b5f..a13a1bc85 100644 --- a/native/core/src/execution/mod.rs +++ b/native/core/src/execution/mod.rs @@ -26,7 +26,7 @@ pub mod operators; pub mod serde; pub mod shuffle; pub(crate) mod sort; -mod timezone; +pub use datafusion_comet_utils::timezone; pub(crate) mod utils; mod memory_pool; diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index 4a9b94087..220417fe8 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -28,12 +28,17 @@ edition = { workspace = true } [dependencies] arrow = { workspace = true } +arrow-array = { workspace = true } arrow-schema = { workspace = true } +chrono = { workspace = true } datafusion = { workspace = true } datafusion-common = { workspace = true } datafusion-functions = { workspace = true } +datafusion-expr = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-comet-utils = { workspace = true } +num = { workspace = true } +regex = { workspace = true } thiserror = { workspace = true } [lib] diff --git a/native/core/src/execution/datafusion/expressions/cast.rs b/native/spark-expr/src/cast.rs similarity index 94% rename from native/core/src/execution/datafusion/expressions/cast.rs rename to native/spark-expr/src/cast.rs index 0b513e776..b9cf2790b 100644 --- a/native/core/src/execution/datafusion/expressions/cast.rs +++ b/native/spark-expr/src/cast.rs @@ -24,36 +24,45 @@ use std::{ }; use arrow::{ - compute::{cast_with_options, CastOptions}, + array::{ + cast::AsArray, + types::{Date32Type, Int16Type, Int32Type, Int8Type}, + Array, ArrayRef, BooleanArray, Decimal128Array, Float32Array, Float64Array, + GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array, OffsetSizeTrait, + PrimitiveArray, + }, + compute::{cast_with_options, unary, CastOptions}, datatypes::{ - ArrowPrimitiveType, Decimal128Type, DecimalType, Float32Type, Float64Type, + ArrowPrimitiveType, Decimal128Type, DecimalType, Float32Type, Float64Type, Int64Type, TimestampMicrosecondType, }, + error::ArrowError, record_batch::RecordBatch, util::display::FormatOptions, }; -use arrow_array::{ - types::{Date32Type, Int16Type, Int32Type, Int64Type, Int8Type}, - Array, ArrayRef, BooleanArray, Decimal128Array, Float32Array, Float64Array, GenericStringArray, - Int16Array, Int32Array, Int64Array, Int8Array, OffsetSizeTrait, PrimitiveArray, -}; use arrow_schema::{DataType, Schema}; -use chrono::{NaiveDate, NaiveDateTime, TimeZone, Timelike}; -use datafusion::logical_expr::ColumnarValue; -use datafusion_comet_spark_expr::{SparkError, SparkResult}; -use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue}; + +use datafusion_common::{ + cast::as_generic_string_array, internal_err, Result as DataFusionResult, ScalarValue, +}; +use datafusion_expr::ColumnarValue; use datafusion_physical_expr::PhysicalExpr; -use num::{cast::AsPrimitive, traits::CheckedNeg, CheckedSub, Integer, Num, ToPrimitive}; -use regex::Regex; -use crate::execution::datafusion::expressions::utils::{ - array_with_timezone, down_cast_any_ref, spark_cast, +use chrono::{NaiveDate, NaiveDateTime, TimeZone, Timelike}; +use num::{ + cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num, + ToPrimitive, }; +use regex::Regex; + +use datafusion_comet_utils::{array_with_timezone, down_cast_any_ref}; -use super::EvalMode; +use crate::{EvalMode, SparkError, SparkResult}; static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f"); +const MICROS_PER_SECOND: i64 = 1000000; + static CAST_OPTIONS: CastOptions = CastOptions { safe: true, format_options: FormatOptions::new() @@ -1633,6 +1642,84 @@ fn date_parser(date_str: &str, eval_mode: EvalMode) -> SparkResult> } } +/// This takes for special casting cases of Spark. E.g., Timestamp to Long. +/// This function runs as a post process of the DataFusion cast(). By the time it arrives here, +/// Dictionary arrays are already unpacked by the DataFusion cast() since Spark cannot specify +/// Dictionary as to_type. The from_type is taken before the DataFusion cast() runs in +/// expressions/cast.rs, so it can be still Dictionary. +fn spark_cast(array: ArrayRef, from_type: &DataType, to_type: &DataType) -> ArrayRef { + match (from_type, to_type) { + (DataType::Timestamp(_, _), DataType::Int64) => { + // See Spark's `Cast` expression + unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap() + } + (DataType::Dictionary(_, value_type), DataType::Int64) + if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) => + { + // See Spark's `Cast` expression + unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap() + } + (DataType::Timestamp(_, _), DataType::Utf8) => remove_trailing_zeroes(array), + (DataType::Dictionary(_, value_type), DataType::Utf8) + if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) => + { + remove_trailing_zeroes(array) + } + _ => array, + } +} + +/// A fork & modified version of Arrow's `unary_dyn` which is being deprecated +fn unary_dyn(array: &ArrayRef, op: F) -> Result +where + T: ArrowPrimitiveType, + F: Fn(T::Native) -> T::Native, +{ + if let Some(d) = array.as_any_dictionary_opt() { + let new_values = unary_dyn::(d.values(), op)?; + return Ok(Arc::new(d.with_values(Arc::new(new_values)))); + } + + match array.as_primitive_opt::() { + Some(a) if PrimitiveArray::::is_compatible(a.data_type()) => { + Ok(Arc::new(unary::( + array.as_any().downcast_ref::>().unwrap(), + op, + ))) + } + _ => Err(ArrowError::NotYetImplemented(format!( + "Cannot perform unary operation of type {} on array of type {}", + T::DATA_TYPE, + array.data_type() + ))), + } +} + +/// Remove any trailing zeroes in the string if they occur after in the fractional seconds, +/// to match Spark behavior +/// example: +/// "1970-01-01 05:29:59.900" => "1970-01-01 05:29:59.9" +/// "1970-01-01 05:29:59.990" => "1970-01-01 05:29:59.99" +/// "1970-01-01 05:29:59.999" => "1970-01-01 05:29:59.999" +/// "1970-01-01 05:30:00" => "1970-01-01 05:30:00" +/// "1970-01-01 05:30:00.001" => "1970-01-01 05:30:00.001" +fn remove_trailing_zeroes(array: ArrayRef) -> ArrayRef { + let string_array = as_generic_string_array::(&array).unwrap(); + let result = string_array + .iter() + .map(|s| s.map(trim_end)) + .collect::>(); + Arc::new(result) as ArrayRef +} + +fn trim_end(s: &str) -> &str { + if s.rfind('.').is_some() { + s.trim_end_matches('0') + } else { + s + } +} + #[cfg(test)] mod tests { use arrow::datatypes::TimestampMicrosecondType; diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index 57da56f9a..93c7f249e 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -16,6 +16,7 @@ // under the License. mod abs; +pub mod cast; mod error; mod if_expr; diff --git a/native/utils/Cargo.toml b/native/utils/Cargo.toml index 05ddd3488..f9ae47433 100644 --- a/native/utils/Cargo.toml +++ b/native/utils/Cargo.toml @@ -27,6 +27,11 @@ license = { workspace = true } edition = { workspace = true } [dependencies] +arrow = { workspace = true } +arrow-array = { workspace = true } +arrow-schema = { workspace = true } +chrono = { workspace = true } +chrono-tz = { workspace = true } datafusion-physical-plan = { workspace = true } [lib] diff --git a/native/utils/src/lib.rs b/native/utils/src/lib.rs index 54ff55b46..4600abfaf 100644 --- a/native/utils/src/lib.rs +++ b/native/utils/src/lib.rs @@ -15,9 +15,23 @@ // specific language governing permissions and limitations // under the License. +use arrow_array::{ + cast::as_primitive_array, + types::{Int32Type, TimestampMicrosecondType}, +}; +use arrow_schema::{ArrowError, DataType}; use std::any::Any; use std::sync::Arc; +use arrow::{ + array::{as_dictionary_array, Array, ArrayRef, PrimitiveArray}, + temporal_conversions::as_datetime, +}; +use chrono::{DateTime, Offset, TimeZone}; +use timezone::Tz; + +pub mod timezone; + use datafusion_physical_plan::PhysicalExpr; /// A utility function from DataFusion. It is not exposed by DataFusion. @@ -34,3 +48,151 @@ pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any { any } } + +/// Preprocesses input arrays to add timezone information from Spark to Arrow array datatype or +/// to apply timezone offset. +// +// We consider the following cases: +// +// | --------------------- | ------------ | ----------------- | -------------------------------- | +// | Conversion | Input array | Timezone | Output array | +// | --------------------- | ------------ | ----------------- | -------------------------------- | +// | Timestamp -> | Array in UTC | Timezone of input | A timestamp with the timezone | +// | Utf8 or Date32 | | | offset applied and timezone | +// | | | | removed | +// | --------------------- | ------------ | ----------------- | -------------------------------- | +// | Timestamp -> | Array in UTC | Timezone of input | Same as input array | +// | Timestamp w/Timezone| | | | +// | --------------------- | ------------ | ----------------- | -------------------------------- | +// | Timestamp_ntz -> | Array in | Timezone of input | Same as input array | +// | Utf8 or Date32 | timezone | | | +// | | session local| | | +// | | timezone | | | +// | --------------------- | ------------ | ----------------- | -------------------------------- | +// | Timestamp_ntz -> | Array in | Timezone of input | Array in UTC and timezone | +// | Timestamp w/Timezone | session local| | specified in input | +// | | timezone | | | +// | --------------------- | ------------ | ----------------- | -------------------------------- | +// | Timestamp(_ntz) -> | | +// | Any other type | Not Supported | +// | --------------------- | ------------ | ----------------- | -------------------------------- | +// +pub fn array_with_timezone( + array: ArrayRef, + timezone: String, + to_type: Option<&DataType>, +) -> Result { + match array.data_type() { + DataType::Timestamp(_, None) => { + assert!(!timezone.is_empty()); + match to_type { + Some(DataType::Utf8) | Some(DataType::Date32) => Ok(array), + Some(DataType::Timestamp(_, Some(_))) => { + timestamp_ntz_to_timestamp(array, timezone.as_str(), Some(timezone.as_str())) + } + _ => { + // Not supported + panic!( + "Cannot convert from {:?} to {:?}", + array.data_type(), + to_type.unwrap() + ) + } + } + } + DataType::Timestamp(_, Some(_)) => { + assert!(!timezone.is_empty()); + let array = as_primitive_array::(&array); + let array_with_timezone = array.clone().with_timezone(timezone.clone()); + let array = Arc::new(array_with_timezone) as ArrayRef; + match to_type { + Some(DataType::Utf8) | Some(DataType::Date32) => { + pre_timestamp_cast(array, timezone) + } + _ => Ok(array), + } + } + DataType::Dictionary(_, value_type) + if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) => + { + let dict = as_dictionary_array::(&array); + let array = as_primitive_array::(dict.values()); + let array_with_timezone = + array_with_timezone(Arc::new(array.clone()) as ArrayRef, timezone, to_type)?; + let dict = dict.with_values(array_with_timezone); + Ok(Arc::new(dict)) + } + _ => Ok(array), + } +} + +fn datetime_cast_err(value: i64) -> ArrowError { + ArrowError::CastError(format!( + "Cannot convert TimestampMicrosecondType {value} to datetime. Comet only supports dates between Jan 1, 262145 BCE and Dec 31, 262143 CE", + )) +} + +/// Takes in a Timestamp(Microsecond, None) array and a timezone id, and returns +/// a Timestamp(Microsecond, Some<_>) array. +/// The understanding is that the input array has time in the timezone specified in the second +/// argument. +/// Parameters: +/// array - input array of timestamp without timezone +/// tz - timezone of the values in the input array +/// to_timezone - timezone to change the input values to +fn timestamp_ntz_to_timestamp( + array: ArrayRef, + tz: &str, + to_timezone: Option<&str>, +) -> Result { + assert!(!tz.is_empty()); + match array.data_type() { + DataType::Timestamp(_, None) => { + let array = as_primitive_array::(&array); + let tz: Tz = tz.parse()?; + let array: PrimitiveArray = array.try_unary(|value| { + as_datetime::(value) + .ok_or_else(|| datetime_cast_err(value)) + .map(|local_datetime| { + let datetime: DateTime = + tz.from_local_datetime(&local_datetime).unwrap(); + datetime.timestamp_micros() + }) + })?; + let array_with_tz = if let Some(to_tz) = to_timezone { + array.with_timezone(to_tz) + } else { + array + }; + Ok(Arc::new(array_with_tz)) + } + _ => Ok(array), + } +} + +/// This takes for special pre-casting cases of Spark. E.g., Timestamp to String. +fn pre_timestamp_cast(array: ArrayRef, timezone: String) -> Result { + assert!(!timezone.is_empty()); + match array.data_type() { + DataType::Timestamp(_, _) => { + // Spark doesn't output timezone while casting timestamp to string, but arrow's cast + // kernel does if timezone exists. So we need to apply offset of timezone to array + // timestamp value and remove timezone from array datatype. + let array = as_primitive_array::(&array); + + let tz: Tz = timezone.parse()?; + let array: PrimitiveArray = array.try_unary(|value| { + as_datetime::(value) + .ok_or_else(|| datetime_cast_err(value)) + .map(|datetime| { + let offset = tz.offset_from_utc_datetime(&datetime).fix(); + let datetime = datetime + offset; + datetime.and_utc().timestamp_micros() + }) + })?; + + Ok(Arc::new(array)) + } + _ => Ok(array), + } +} diff --git a/native/core/src/execution/timezone.rs b/native/utils/src/timezone.rs similarity index 100% rename from native/core/src/execution/timezone.rs rename to native/utils/src/timezone.rs