Skip to content

Commit

Permalink
feat(substrait): use IntervalCompound instead of interval-month-day-n…
Browse files Browse the repository at this point in the history
…ano UDT
  • Loading branch information
Blizzara committed Oct 28, 2024
1 parent 146f16a commit f1e5176
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 118 deletions.
57 changes: 48 additions & 9 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,18 @@ use crate::variation_const::{
DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF,
DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF,
DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF,
INTERVAL_MONTH_DAY_NANO_TYPE_NAME, LARGE_CONTAINER_TYPE_VARIATION_REF,
UNSIGNED_INTEGER_TYPE_VARIATION_REF, VIEW_CONTAINER_TYPE_VARIATION_REF,
LARGE_CONTAINER_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF,
VIEW_CONTAINER_TYPE_VARIATION_REF,
};
#[allow(deprecated)]
use crate::variation_const::{
INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_REF,
INTERVAL_YEAR_MONTH_TYPE_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF,
TIMESTAMP_MILLI_TYPE_VARIATION_REF, TIMESTAMP_NANO_TYPE_VARIATION_REF,
TIMESTAMP_SECOND_TYPE_VARIATION_REF,
INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_NAME,
INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_REF,
TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF,
TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF,
};
use datafusion::arrow::array::{new_empty_array, AsArray};
use datafusion::arrow::temporal_conversions::NANOSECONDS;
use datafusion::common::scalar::ScalarStructBuilder;
use datafusion::dataframe::DataFrame;
use datafusion::logical_expr::expr::InList;
Expand All @@ -71,10 +72,10 @@ use datafusion::{
use std::collections::HashSet;
use std::sync::Arc;
use substrait::proto::exchange_rel::ExchangeKind;
use substrait::proto::expression::literal::interval_day_to_second::PrecisionMode;
use substrait::proto::expression::literal::user_defined::Val;
use substrait::proto::expression::literal::{
IntervalDayToSecond, IntervalYearToMonth, UserDefined,
interval_day_to_second, IntervalCompound, IntervalDayToSecond, IntervalYearToMonth,
UserDefined,
};
use substrait::proto::expression::subquery::SubqueryType;
use substrait::proto::expression::{self, FieldReference, Literal, ScalarFunction};
Expand Down Expand Up @@ -1831,8 +1832,13 @@ fn from_substrait_type(
Ok(DataType::Interval(IntervalUnit::YearMonth))
}
r#type::Kind::IntervalDay(_) => Ok(DataType::Interval(IntervalUnit::DayTime)),
r#type::Kind::IntervalCompound(_) => {
Ok(DataType::Interval(IntervalUnit::MonthDayNano))
}
r#type::Kind::UserDefined(u) => {
// Kept for backwards compatibility, use IntervalCompound instead
if let Some(name) = extensions.types.get(&u.type_reference) {
#[allow(deprecated)]
match name.as_ref() {
INTERVAL_MONTH_DAY_NANO_TYPE_NAME => Ok(DataType::Interval(IntervalUnit::MonthDayNano)),
_ => not_impl_err!(
Expand All @@ -1842,7 +1848,7 @@ fn from_substrait_type(
),
}
} else {
// Kept for backwards compatibility, new plans should include the extension instead
// Kept for backwards compatibility, use IntervalCompound instead
#[allow(deprecated)]
match u.type_reference {
// Kept for backwards compatibility, use IntervalYear instead
Expand Down Expand Up @@ -2275,6 +2281,7 @@ fn from_substrait_literal(
subseconds,
precision_mode,
})) => {
use interval_day_to_second::PrecisionMode;
// DF only supports millisecond precision, so for any more granular type we lose precision
let milliseconds = match precision_mode {
Some(PrecisionMode::Microseconds(ms)) => ms / 1000,
Expand All @@ -2299,6 +2306,35 @@ fn from_substrait_literal(
Some(LiteralType::IntervalYearToMonth(IntervalYearToMonth { years, months })) => {
ScalarValue::new_interval_ym(*years, *months)
}
Some(LiteralType::IntervalCompound(IntervalCompound {
interval_year_to_month,
interval_day_to_second,
})) => match (interval_year_to_month, interval_day_to_second) {
(
Some(IntervalYearToMonth { years, months }),
Some(IntervalDayToSecond {
days,
seconds,
subseconds,
precision_mode:
Some(interval_day_to_second::PrecisionMode::Precision(p)),
}),
) => {
if (*p < 0 || *p > 9) {
return plan_err!(
"Unsupported Substrait interval day to second precision: {}",
p
);
}
let nanos = *subseconds * i64::pow(10, (*p - 9) as u32);
ScalarValue::new_interval_mdn(
*years * 12 + months,
*days,
*seconds as i64 * NANOSECONDS + nanos,
)
}
_ => return plan_err!("Substrait compound interval missing components"),
},
Some(LiteralType::FixedChar(c)) => ScalarValue::Utf8(Some(c.clone())),
Some(LiteralType::UserDefined(user_defined)) => {
// Helper function to prevent duplicating this code - can be inlined once the non-extension path is removed
Expand Down Expand Up @@ -2329,6 +2365,8 @@ fn from_substrait_literal(

if let Some(name) = extensions.types.get(&user_defined.type_reference) {
match name.as_ref() {
// Kept for backwards compatibility - new plans should use IntervalCompound instead
#[allow(deprecated)]
INTERVAL_MONTH_DAY_NANO_TYPE_NAME => {
interval_month_day_nano(user_defined)?
}
Expand Down Expand Up @@ -2379,6 +2417,7 @@ fn from_substrait_literal(
milliseconds,
}))
}
// Kept for backwards compatibility, use IntervalCompound instead
INTERVAL_MONTH_DAY_NANO_TYPE_REF => {
interval_month_day_nano(user_defined)?
}
Expand Down
115 changes: 24 additions & 91 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ use crate::variation_const::{
UNSIGNED_INTEGER_TYPE_VARIATION_REF, VIEW_CONTAINER_TYPE_VARIATION_REF,
};
use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait};
use datafusion::arrow::temporal_conversions::NANOSECONDS;
use datafusion::common::{
exec_err, internal_err, not_impl_err, plan_err, substrait_datafusion_err,
substrait_err, DFSchemaRef, ToDFSchema,
Expand All @@ -58,8 +59,8 @@ use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields};
use substrait::proto::expression::literal::interval_day_to_second::PrecisionMode;
use substrait::proto::expression::literal::map::KeyValue;
use substrait::proto::expression::literal::{
user_defined, IntervalDayToSecond, IntervalYearToMonth, List, Map,
PrecisionTimestamp, Struct, UserDefined,
IntervalCompound, IntervalDayToSecond, IntervalYearToMonth, List, Map,
PrecisionTimestamp, Struct,
};
use substrait::proto::expression::subquery::InPredicate;
use substrait::proto::expression::window_function::BoundsType;
Expand Down Expand Up @@ -1489,16 +1490,14 @@ fn to_substrait_type(
})),
}),
IntervalUnit::MonthDayNano => {
// Substrait doesn't currently support this type, so we represent it as a UDT
Ok(substrait::proto::Type {
kind: Some(r#type::Kind::UserDefined(r#type::UserDefined {
type_reference: extensions.register_type(
INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string(),
),
type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
nullability,
type_parameters: vec![],
})),
kind: Some(r#type::Kind::IntervalCompound(
r#type::IntervalCompound {
type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
nullability,
precision: 9, // nanos
},
)),
})
}
}
Expand Down Expand Up @@ -1892,23 +1891,21 @@ fn to_substrait_literal(
}),
DEFAULT_TYPE_VARIATION_REF,
),
ScalarValue::IntervalMonthDayNano(Some(i)) => {
// IntervalMonthDayNano is internally represented as a 128-bit integer, containing
// months (32bit), days (32bit), and nanoseconds (64bit)
let bytes = i.to_byte_slice();
(
LiteralType::UserDefined(UserDefined {
type_reference: extensions
.register_type(INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string()),
type_parameters: vec![],
val: Some(user_defined::Val::Value(ProtoAny {
type_url: INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string(),
value: bytes.to_vec().into(),
})),
ScalarValue::IntervalMonthDayNano(Some(i)) => (
LiteralType::IntervalCompound(IntervalCompound {
interval_year_to_month: Some(IntervalYearToMonth {
years: i.months / 12,
months: i.months % 12,
}),
DEFAULT_TYPE_VARIATION_REF,
)
}
interval_day_to_second: Some(IntervalDayToSecond {
days: i.days,
seconds: (i.nanoseconds / NANOSECONDS) as i32,
subseconds: i.nanoseconds % NANOSECONDS,
precision_mode: Some(PrecisionMode::Precision(9)), // nanoseconds
}),
}),
DEFAULT_TYPE_VARIATION_REF,
),
ScalarValue::IntervalDayTime(Some(i)) => (
LiteralType::IntervalDayToSecond(IntervalDayToSecond {
days: i.days,
Expand Down Expand Up @@ -2310,39 +2307,6 @@ mod test {
Ok(())
}

#[test]
fn custom_type_literal_extensions() -> Result<()> {
let mut extensions = Extensions::default();
// IntervalMonthDayNano is represented as a custom type in Substrait
let scalar = ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano::new(
17, 25, 1234567890,
)));
let substrait_literal = to_substrait_literal(&scalar, &mut extensions)?;
let roundtrip_scalar =
from_substrait_literal_without_names(&substrait_literal, &extensions)?;
assert_eq!(scalar, roundtrip_scalar);

assert_eq!(
extensions,
Extensions {
functions: HashMap::new(),
types: HashMap::from([(
0,
INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string()
)]),
type_variations: HashMap::new(),
}
);

// Check we fail if we don't propagate extensions
assert!(from_substrait_literal_without_names(
&substrait_literal,
&Extensions::default()
)
.is_err());
Ok(())
}

#[test]
fn round_trip_types() -> Result<()> {
round_trip_type(DataType::Boolean)?;
Expand Down Expand Up @@ -2424,37 +2388,6 @@ mod test {
Ok(())
}

#[test]
fn custom_type_extensions() -> Result<()> {
let mut extensions = Extensions::default();
// IntervalMonthDayNano is represented as a custom type in Substrait
let dt = DataType::Interval(IntervalUnit::MonthDayNano);

let substrait = to_substrait_type(&dt, true, &mut extensions)?;
let roundtrip_dt = from_substrait_type_without_names(&substrait, &extensions)?;
assert_eq!(dt, roundtrip_dt);

assert_eq!(
extensions,
Extensions {
functions: HashMap::new(),
types: HashMap::from([(
0,
INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string()
)]),
type_variations: HashMap::new(),
}
);

// Check we fail if we don't propagate extensions
assert!(
from_substrait_type_without_names(&substrait, &Extensions::default())
.is_err()
);

Ok(())
}

#[test]
fn named_struct_names() -> Result<()> {
let mut extensions = Extensions::default();
Expand Down
6 changes: 5 additions & 1 deletion datafusion/substrait/src/variation_const.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,16 @@ pub const INTERVAL_DAY_TIME_TYPE_REF: u32 = 2;
/// [`ScalarValue::IntervalMonthDayNano`]: datafusion::common::ScalarValue::IntervalMonthDayNano
#[deprecated(
since = "41.0.0",
note = "Use Substrait `UserDefinedType` with name `INTERVAL_MONTH_DAY_NANO_TYPE_NAME` instead"
note = "Use Substrait `IntervalCompund` type instead"
)]
pub const INTERVAL_MONTH_DAY_NANO_TYPE_REF: u32 = 3;

/// For [`DataType::Interval`] with [`IntervalUnit::MonthDayNano`].
///
/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
/// [`IntervalUnit::MonthDayNano`]: datafusion::arrow::datatypes::IntervalUnit::MonthDayNano
#[deprecated(
since = "42.1.0",
note = "Use Substrait `IntervalCompund` type instead"
)]
pub const INTERVAL_MONTH_DAY_NANO_TYPE_NAME: &str = "interval-month-day-nano";
17 changes: 0 additions & 17 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,23 +230,6 @@ async fn select_with_reused_functions() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn roundtrip_udt_extensions() -> Result<()> {
let ctx = create_context().await?;
let proto =
roundtrip_with_ctx("SELECT INTERVAL '1 YEAR 1 DAY 1 SECOND' FROM data", ctx)
.await?;
let expected_type = SimpleExtensionDeclaration {
mapping_type: Some(MappingType::ExtensionType(ExtensionType {
extension_uri_reference: u32::MAX,
type_anchor: 0,
name: "interval-month-day-nano".to_string(),
})),
};
assert_eq!(proto.extensions, vec![expected_type]);
Ok(())
}

#[tokio::test]
async fn select_with_filter_date() -> Result<()> {
roundtrip("SELECT * FROM data WHERE c > CAST('2020-01-01' AS DATE)").await
Expand Down

0 comments on commit f1e5176

Please sign in to comment.