From d89da8ef383bc82014c86728fefbe3cd97a3982d Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Tue, 20 Aug 2024 15:41:32 +0200 Subject: [PATCH] use IntervalCompound instead of interval-month-day-nano UDT --- .../substrait/src/logical_plan/consumer.rs | 59 +++++++-- .../substrait/src/logical_plan/producer.rs | 119 ++++-------------- datafusion/substrait/src/variation_const.rs | 6 +- .../tests/cases/roundtrip_logical_plan.rs | 17 --- 4 files changed, 79 insertions(+), 122 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index b1b510f1792de..3374f5552cf27 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -41,15 +41,14 @@ use crate::variation_const::{ DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, - INTERVAL_MONTH_DAY_NANO_TYPE_NAME, LARGE_CONTAINER_TYPE_VARIATION_REF, - UNSIGNED_INTEGER_TYPE_VARIATION_REF, + LARGE_CONTAINER_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, }; #[allow(deprecated)] use crate::variation_const::{ - INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_REF, - INTERVAL_YEAR_MONTH_TYPE_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF, - TIMESTAMP_MILLI_TYPE_VARIATION_REF, TIMESTAMP_NANO_TYPE_VARIATION_REF, - TIMESTAMP_SECOND_TYPE_VARIATION_REF, + INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_NAME, + INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_REF, + TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF, + TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF, }; use datafusion::arrow::array::{new_empty_array, AsArray}; use datafusion::common::scalar::ScalarStructBuilder; @@ -69,10 +68,10 @@ use datafusion::{ use std::collections::{HashMap, HashSet}; use std::sync::Arc; use substrait::proto::exchange_rel::ExchangeKind; -use substrait::proto::expression::literal::interval_day_to_second::PrecisionMode; use substrait::proto::expression::literal::user_defined::Val; use substrait::proto::expression::literal::{ - IntervalDayToSecond, IntervalYearToMonth, UserDefined, + interval_day_to_second, IntervalCompound, IntervalDayToSecond, IntervalYearToMonth, + UserDefined, }; use substrait::proto::expression::subquery::SubqueryType; use substrait::proto::expression::{self, FieldReference, Literal, ScalarFunction}; @@ -1505,8 +1504,13 @@ fn from_substrait_type( Ok(DataType::Interval(IntervalUnit::YearMonth)) } r#type::Kind::IntervalDay(_) => Ok(DataType::Interval(IntervalUnit::DayTime)), + r#type::Kind::IntervalCompound(_) => { + Ok(DataType::Interval(IntervalUnit::MonthDayNano)) + } r#type::Kind::UserDefined(u) => { + // Kept for backwards compatibility, use IntervalCompound instead if let Some(name) = extensions.types.get(&u.type_reference) { + #[allow(deprecated)] match name.as_ref() { INTERVAL_MONTH_DAY_NANO_TYPE_NAME => Ok(DataType::Interval(IntervalUnit::MonthDayNano)), _ => not_impl_err!( @@ -1516,7 +1520,7 @@ fn from_substrait_type( ), } } else { - // Kept for backwards compatibility, new plans should include the extension instead + // Kept for backwards compatibility, use IntervalCompound instead #[allow(deprecated)] match u.type_reference { // Kept for backwards compatibility, use IntervalYear instead @@ -1946,6 +1950,7 @@ fn from_substrait_literal( subseconds, precision_mode, })) => { + use interval_day_to_second::PrecisionMode; // DF only supports millisecond precision, so for any more granular type we lose precision let milliseconds = match precision_mode { Some(PrecisionMode::Microseconds(ms)) => ms / 1000, @@ -1965,6 +1970,39 @@ fn from_substrait_literal( Some(LiteralType::IntervalYearToMonth(IntervalYearToMonth { years, months })) => { ScalarValue::new_interval_ym(*years, *months) } + Some(LiteralType::IntervalCompound(IntervalCompound { + interval_year_to_month, + interval_day_to_second, + })) => match (interval_year_to_month, interval_day_to_second) { + ( + Some(IntervalYearToMonth { years, months }), + Some(IntervalDayToSecond { + days, + seconds, + subseconds, + precision_mode: + Some(interval_day_to_second::PrecisionMode::Precision(p)), + }), + ) => { + let nanos = match p { + 0 => *subseconds * 1_000_000_000, + 3 => *subseconds * 1_000_000, + 6 => *subseconds * 1_000, + 9 => *subseconds, + _ => { + return not_impl_err!( + "Unsupported Substrait interval day to second precision mode" + ) + } + }; + ScalarValue::new_interval_mdn( + *years * 12 + months, + *days, + *seconds as i64 * 1_000_000_000 + nanos, + ) + } + _ => return not_impl_err!("Unsupported Substrait compound interval literal"), + }, Some(LiteralType::FixedChar(c)) => ScalarValue::Utf8(Some(c.clone())), Some(LiteralType::UserDefined(user_defined)) => { // Helper function to prevent duplicating this code - can be inlined once the non-extension path is removed @@ -1995,6 +2033,8 @@ fn from_substrait_literal( if let Some(name) = extensions.types.get(&user_defined.type_reference) { match name.as_ref() { + // Kept for backwards compatibility - new plans should use IntervalCompound instead + #[allow(deprecated)] INTERVAL_MONTH_DAY_NANO_TYPE_NAME => { interval_month_day_nano(user_defined)? } @@ -2045,6 +2085,7 @@ fn from_substrait_literal( milliseconds, })) } + // Kept for backwards compatibility, use IntervalCompound instead INTERVAL_MONTH_DAY_NANO_TYPE_REF => { interval_month_day_nano(user_defined)? } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 72b6760be29c1..bc6f55641a0af 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -19,7 +19,6 @@ use itertools::Itertools; use std::ops::Deref; use std::sync::Arc; -use arrow_buffer::ToByteSlice; use datafusion::arrow::datatypes::IntervalUnit; use datafusion::logical_expr::{ CrossJoin, Distinct, Like, Partitioning, WindowFrameUnits, @@ -37,8 +36,7 @@ use crate::variation_const::{ DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, - INTERVAL_MONTH_DAY_NANO_TYPE_NAME, LARGE_CONTAINER_TYPE_VARIATION_REF, - UNSIGNED_INTEGER_TYPE_VARIATION_REF, + LARGE_CONTAINER_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, }; use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait}; use datafusion::common::{ @@ -56,8 +54,8 @@ use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; use substrait::proto::expression::literal::interval_day_to_second::PrecisionMode; use substrait::proto::expression::literal::map::KeyValue; use substrait::proto::expression::literal::{ - user_defined, IntervalDayToSecond, IntervalYearToMonth, List, Map, - PrecisionTimestamp, Struct, UserDefined, + IntervalCompound, IntervalDayToSecond, IntervalYearToMonth, List, Map, + PrecisionTimestamp, Struct, }; use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; @@ -1429,16 +1427,14 @@ fn to_substrait_type( })), }), IntervalUnit::MonthDayNano => { - // Substrait doesn't currently support this type, so we represent it as a UDT Ok(substrait::proto::Type { - kind: Some(r#type::Kind::UserDefined(r#type::UserDefined { - type_reference: extensions.register_type( - INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string(), - ), - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - type_parameters: vec![], - })), + kind: Some(r#type::Kind::IntervalCompound( + r#type::IntervalCompound { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + precision: 9, // nanos + }, + )), }) } } @@ -1880,23 +1876,21 @@ fn to_substrait_literal( }), DEFAULT_TYPE_VARIATION_REF, ), - ScalarValue::IntervalMonthDayNano(Some(i)) => { - // IntervalMonthDayNano is internally represented as a 128-bit integer, containing - // months (32bit), days (32bit), and nanoseconds (64bit) - let bytes = i.to_byte_slice(); - ( - LiteralType::UserDefined(UserDefined { - type_reference: extensions - .register_type(INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string()), - type_parameters: vec![], - val: Some(user_defined::Val::Value(ProtoAny { - type_url: INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string(), - value: bytes.to_vec().into(), - })), + ScalarValue::IntervalMonthDayNano(Some(i)) => ( + LiteralType::IntervalCompound(IntervalCompound { + interval_year_to_month: Some(IntervalYearToMonth { + years: 0, + months: i.months, }), - DEFAULT_TYPE_VARIATION_REF, - ) - } + interval_day_to_second: Some(IntervalDayToSecond { + days: i.days, + seconds: 0, + subseconds: i.nanoseconds, + precision_mode: Some(PrecisionMode::Precision(9)), // nanoseconds + }), + }), + DEFAULT_TYPE_VARIATION_REF, + ), ScalarValue::IntervalDayTime(Some(i)) => ( LiteralType::IntervalDayToSecond(IntervalDayToSecond { days: i.days, @@ -2161,7 +2155,6 @@ mod test { }; use datafusion::arrow::datatypes::Field; use datafusion::common::scalar::ScalarStructBuilder; - use std::collections::HashMap; #[test] fn round_trip_literals() -> Result<()> { @@ -2292,39 +2285,6 @@ mod test { Ok(()) } - #[test] - fn custom_type_literal_extensions() -> Result<()> { - let mut extensions = Extensions::default(); - // IntervalMonthDayNano is represented as a custom type in Substrait - let scalar = ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano::new( - 17, 25, 1234567890, - ))); - let substrait_literal = to_substrait_literal(&scalar, &mut extensions)?; - let roundtrip_scalar = - from_substrait_literal_without_names(&substrait_literal, &extensions)?; - assert_eq!(scalar, roundtrip_scalar); - - assert_eq!( - extensions, - Extensions { - functions: HashMap::new(), - types: HashMap::from([( - 0, - INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string() - )]), - type_variations: HashMap::new(), - } - ); - - // Check we fail if we don't propagate extensions - assert!(from_substrait_literal_without_names( - &substrait_literal, - &Extensions::default() - ) - .is_err()); - Ok(()) - } - #[test] fn round_trip_types() -> Result<()> { round_trip_type(DataType::Boolean)?; @@ -2403,35 +2363,4 @@ mod test { assert_eq!(dt, roundtrip_dt); Ok(()) } - - #[test] - fn custom_type_extensions() -> Result<()> { - let mut extensions = Extensions::default(); - // IntervalMonthDayNano is represented as a custom type in Substrait - let dt = DataType::Interval(IntervalUnit::MonthDayNano); - - let substrait = to_substrait_type(&dt, true, &mut extensions)?; - let roundtrip_dt = from_substrait_type_without_names(&substrait, &extensions)?; - assert_eq!(dt, roundtrip_dt); - - assert_eq!( - extensions, - Extensions { - functions: HashMap::new(), - types: HashMap::from([( - 0, - INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string() - )]), - type_variations: HashMap::new(), - } - ); - - // Check we fail if we don't propagate extensions - assert!( - from_substrait_type_without_names(&substrait, &Extensions::default()) - .is_err() - ); - - Ok(()) - } } diff --git a/datafusion/substrait/src/variation_const.rs b/datafusion/substrait/src/variation_const.rs index 1525da7645096..af254e99beed5 100644 --- a/datafusion/substrait/src/variation_const.rs +++ b/datafusion/substrait/src/variation_const.rs @@ -95,7 +95,7 @@ pub const INTERVAL_DAY_TIME_TYPE_REF: u32 = 2; /// [`ScalarValue::IntervalMonthDayNano`]: datafusion::common::ScalarValue::IntervalMonthDayNano #[deprecated( since = "41.0.0", - note = "Use Substrait `UserDefinedType` with name `INTERVAL_MONTH_DAY_NANO_TYPE_NAME` instead" + note = "Use Substrait `IntervalCompund` type instead" )] pub const INTERVAL_MONTH_DAY_NANO_TYPE_REF: u32 = 3; @@ -103,4 +103,8 @@ pub const INTERVAL_MONTH_DAY_NANO_TYPE_REF: u32 = 3; /// /// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval /// [`IntervalUnit::MonthDayNano`]: datafusion::arrow::datatypes::IntervalUnit::MonthDayNano +#[deprecated( + since = "42.0.0", + note = "Use Substrait `IntervalCompund` type instead" +)] pub const INTERVAL_MONTH_DAY_NANO_TYPE_NAME: &str = "interval-month-day-nano"; diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 083a589fce267..1374887e12b8f 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -204,23 +204,6 @@ async fn select_with_reused_functions() -> Result<()> { Ok(()) } -#[tokio::test] -async fn roundtrip_udt_extensions() -> Result<()> { - let ctx = create_context().await?; - let proto = - roundtrip_with_ctx("SELECT INTERVAL '1 YEAR 1 DAY 1 SECOND' FROM data", ctx) - .await?; - let expected_type = SimpleExtensionDeclaration { - mapping_type: Some(MappingType::ExtensionType(ExtensionType { - extension_uri_reference: u32::MAX, - type_anchor: 0, - name: "interval-month-day-nano".to_string(), - })), - }; - assert_eq!(proto.extensions, vec![expected_type]); - Ok(()) -} - #[tokio::test] async fn select_with_filter_date() -> Result<()> { roundtrip("SELECT * FROM data WHERE c > CAST('2020-01-01' AS DATE)").await