Skip to content

Commit

Permalink
feat: add substrait support for Interval types and literals (#10646)
Browse files Browse the repository at this point in the history
* feat: support interval types

Signed-off-by: Ruihang Xia <[email protected]>

* impl literals

Signed-off-by: Ruihang Xia <[email protected]>

* fix deadlink in doc

Signed-off-by: Ruihang Xia <[email protected]>

---------

Signed-off-by: Ruihang Xia <[email protected]>
  • Loading branch information
waynexia authored May 26, 2024
1 parent 52c4f3c commit 6167ce9
Show file tree
Hide file tree
Showing 4 changed files with 273 additions and 10 deletions.
76 changes: 72 additions & 4 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
// under the License.

use async_recursion::async_recursion;
use datafusion::arrow::datatypes::{DataType, Field, TimeUnit};
use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
use datafusion::common::{
not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef,
};
Expand All @@ -39,6 +39,7 @@ use datafusion::{
scalar::ScalarValue,
};
use substrait::proto::exchange_rel::ExchangeKind;
use substrait::proto::expression::literal::user_defined::Val;
use substrait::proto::expression::subquery::SubqueryType;
use substrait::proto::expression::{FieldReference, Literal, ScalarFunction};
use substrait::proto::{
Expand Down Expand Up @@ -71,9 +72,10 @@ use std::sync::Arc;

use crate::variation_const::{
DATE_32_TYPE_REF, DATE_64_TYPE_REF, DECIMAL_128_TYPE_REF, DECIMAL_256_TYPE_REF,
DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, LARGE_CONTAINER_TYPE_REF,
TIMESTAMP_MICRO_TYPE_REF, TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF,
TIMESTAMP_SECOND_TYPE_REF, UNSIGNED_INTEGER_TYPE_REF,
DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, INTERVAL_DAY_TIME_TYPE_REF,
INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_REF,
LARGE_CONTAINER_TYPE_REF, TIMESTAMP_MICRO_TYPE_REF, TIMESTAMP_MILLI_TYPE_REF,
TIMESTAMP_NANO_TYPE_REF, TIMESTAMP_SECOND_TYPE_REF, UNSIGNED_INTEGER_TYPE_REF,
};

enum ScalarFunctionType {
Expand Down Expand Up @@ -1162,6 +1164,24 @@ pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataTyp
"Unsupported Substrait type variation {v} of type {s_kind:?}"
),
},
r#type::Kind::UserDefined(u) => {
match u.type_reference {
INTERVAL_YEAR_MONTH_TYPE_REF => {
Ok(DataType::Interval(IntervalUnit::YearMonth))
}
INTERVAL_DAY_TIME_TYPE_REF => {
Ok(DataType::Interval(IntervalUnit::DayTime))
}
INTERVAL_MONTH_DAY_NANO_TYPE_REF => {
Ok(DataType::Interval(IntervalUnit::MonthDayNano))
}
_ => not_impl_err!(
"Unsupported Substrait user defined type with ref {} and variation {}",
u.type_reference,
u.type_variation_reference
),
}
},
r#type::Kind::Struct(s) => {
let mut fields = vec![];
for (i, f) in s.types.iter().enumerate() {
Expand Down Expand Up @@ -1387,6 +1407,54 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
builder.build()?
}
Some(LiteralType::Null(ntype)) => from_substrait_null(ntype)?,
Some(LiteralType::UserDefined(user_defined)) => {
match user_defined.type_reference {
INTERVAL_YEAR_MONTH_TYPE_REF => {
let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else {
return substrait_err!("Interval year month value is empty");
};
let value_slice: [u8; 4] =
raw_val.value.clone().try_into().map_err(|_| {
substrait_datafusion_err!(
"Failed to parse interval year month value"
)
})?;
ScalarValue::IntervalYearMonth(Some(i32::from_le_bytes(value_slice)))
}
INTERVAL_DAY_TIME_TYPE_REF => {
let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else {
return substrait_err!("Interval day time value is empty");
};
let value_slice: [u8; 8] =
raw_val.value.clone().try_into().map_err(|_| {
substrait_datafusion_err!(
"Failed to parse interval day time value"
)
})?;
ScalarValue::IntervalDayTime(Some(i64::from_le_bytes(value_slice)))
}
INTERVAL_MONTH_DAY_NANO_TYPE_REF => {
let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else {
return substrait_err!("Interval month day nano value is empty");
};
let value_slice: [u8; 16] =
raw_val.value.clone().try_into().map_err(|_| {
substrait_datafusion_err!(
"Failed to parse interval month day nano value"
)
})?;
ScalarValue::IntervalMonthDayNano(Some(i128::from_le_bytes(
value_slice,
)))
}
_ => {
return not_impl_err!(
"Unsupported Substrait user defined type with ref {}",
user_defined.type_reference
)
}
}
}
_ => return not_impl_err!("Unsupported literal_type: {:?}", lit.literal_type),
};

Expand Down
125 changes: 122 additions & 3 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use std::collections::HashMap;
use std::ops::Deref;
use std::sync::Arc;

use datafusion::arrow::datatypes::IntervalUnit;
use datafusion::logical_expr::{
CrossJoin, Distinct, Like, Partitioning, WindowFrameUnits,
};
Expand All @@ -43,9 +44,12 @@ use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Opera
use datafusion::prelude::Expr;
use prost_types::Any as ProtoAny;
use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields};
use substrait::proto::expression::literal::user_defined::Val;
use substrait::proto::expression::literal::UserDefined;
use substrait::proto::expression::literal::{List, Struct};
use substrait::proto::expression::subquery::InPredicate;
use substrait::proto::expression::window_function::BoundsType;
use substrait::proto::r#type::{parameter, Parameter};
use substrait::proto::{CrossRel, ExchangeRel};
use substrait::{
proto::{
Expand Down Expand Up @@ -84,9 +88,12 @@ use substrait::{

use crate::variation_const::{
DATE_32_TYPE_REF, DATE_64_TYPE_REF, DECIMAL_128_TYPE_REF, DECIMAL_256_TYPE_REF,
DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, LARGE_CONTAINER_TYPE_REF,
TIMESTAMP_MICRO_TYPE_REF, TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF,
TIMESTAMP_SECOND_TYPE_REF, UNSIGNED_INTEGER_TYPE_REF,
DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, INTERVAL_DAY_TIME_TYPE_REF,
INTERVAL_DAY_TIME_TYPE_URL, INTERVAL_MONTH_DAY_NANO_TYPE_REF,
INTERVAL_MONTH_DAY_NANO_TYPE_URL, INTERVAL_YEAR_MONTH_TYPE_REF,
INTERVAL_YEAR_MONTH_TYPE_URL, LARGE_CONTAINER_TYPE_REF, TIMESTAMP_MICRO_TYPE_REF,
TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF, TIMESTAMP_SECOND_TYPE_REF,
UNSIGNED_INTEGER_TYPE_REF,
};

/// Convert DataFusion LogicalPlan to Substrait Plan
Expand Down Expand Up @@ -1398,6 +1405,49 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result<substrait::proto::
nullability,
})),
}),
DataType::Interval(interval_unit) => {
// define two type parameters for convenience
let i32_param = Parameter {
parameter: Some(parameter::Parameter::DataType(substrait::proto::Type {
kind: Some(r#type::Kind::I32(r#type::I32 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
})),
};
let i64_param = Parameter {
parameter: Some(parameter::Parameter::DataType(substrait::proto::Type {
kind: Some(r#type::Kind::I64(r#type::I64 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
})),
};

let (type_parameters, type_reference) = match interval_unit {
IntervalUnit::YearMonth => {
let type_parameters = vec![i32_param];
(type_parameters, INTERVAL_YEAR_MONTH_TYPE_REF)
}
IntervalUnit::DayTime => {
let type_parameters = vec![i64_param];
(type_parameters, INTERVAL_DAY_TIME_TYPE_REF)
}
IntervalUnit::MonthDayNano => {
// use 2 `i64` as `i128`
let type_parameters = vec![i64_param.clone(), i64_param];
(type_parameters, INTERVAL_MONTH_DAY_NANO_TYPE_REF)
}
};
Ok(substrait::proto::Type {
kind: Some(r#type::Kind::UserDefined(r#type::UserDefined {
type_reference,
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
type_parameters,
})),
})
}
DataType::Binary => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::Binary(r#type::Binary {
type_variation_reference: DEFAULT_CONTAINER_TYPE_REF,
Expand Down Expand Up @@ -1735,6 +1785,75 @@ fn to_substrait_literal(value: &ScalarValue) -> Result<Literal> {
}
ScalarValue::Date32(Some(d)) => (LiteralType::Date(*d), DATE_32_TYPE_REF),
// Date64 literal is not supported in Substrait
ScalarValue::IntervalYearMonth(Some(i)) => {
let bytes = i.to_le_bytes();
(
LiteralType::UserDefined(UserDefined {
type_reference: INTERVAL_YEAR_MONTH_TYPE_REF,
type_parameters: vec![Parameter {
parameter: Some(parameter::Parameter::DataType(
substrait::proto::Type {
kind: Some(r#type::Kind::I32(r#type::I32 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: r#type::Nullability::Required as i32,
})),
},
)),
}],
val: Some(Val::Value(ProtoAny {
type_url: INTERVAL_YEAR_MONTH_TYPE_URL.to_string(),
value: bytes.to_vec(),
})),
}),
INTERVAL_YEAR_MONTH_TYPE_REF,
)
}
ScalarValue::IntervalMonthDayNano(Some(i)) => {
// treat `i128` as two contiguous `i64`
let bytes = i.to_le_bytes();
let i64_param = Parameter {
parameter: Some(parameter::Parameter::DataType(substrait::proto::Type {
kind: Some(r#type::Kind::I64(r#type::I64 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: r#type::Nullability::Required as i32,
})),
})),
};
(
LiteralType::UserDefined(UserDefined {
type_reference: INTERVAL_MONTH_DAY_NANO_TYPE_REF,
type_parameters: vec![i64_param.clone(), i64_param],
val: Some(Val::Value(ProtoAny {
type_url: INTERVAL_MONTH_DAY_NANO_TYPE_URL.to_string(),
value: bytes.to_vec(),
})),
}),
INTERVAL_MONTH_DAY_NANO_TYPE_REF,
)
}
ScalarValue::IntervalDayTime(Some(i)) => {
let bytes = i.to_le_bytes();
(
LiteralType::UserDefined(UserDefined {
type_reference: INTERVAL_DAY_TIME_TYPE_REF,
type_parameters: vec![Parameter {
parameter: Some(parameter::Parameter::DataType(
substrait::proto::Type {
kind: Some(r#type::Kind::I64(r#type::I64 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: r#type::Nullability::Required as i32,
})),
},
)),
}],
val: Some(Val::Value(ProtoAny {
type_url: INTERVAL_DAY_TIME_TYPE_URL.to_string(),
value: bytes.to_vec(),
})),
}),
INTERVAL_DAY_TIME_TYPE_REF,
)
}
ScalarValue::Binary(Some(b)) => {
(LiteralType::Binary(b.clone()), DEFAULT_CONTAINER_TYPE_REF)
}
Expand Down
56 changes: 56 additions & 0 deletions datafusion/substrait/src/variation_const.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
//! - Default type reference is 0. It is used when the actual type is the same with the original type.
//! - Extended variant type references start from 1, and ususlly increase by 1.
// For type variations
pub const DEFAULT_TYPE_REF: u32 = 0;
pub const UNSIGNED_INTEGER_TYPE_REF: u32 = 1;
pub const TIMESTAMP_SECOND_TYPE_REF: u32 = 0;
Expand All @@ -37,3 +38,58 @@ pub const DEFAULT_CONTAINER_TYPE_REF: u32 = 0;
pub const LARGE_CONTAINER_TYPE_REF: u32 = 1;
pub const DECIMAL_128_TYPE_REF: u32 = 0;
pub const DECIMAL_256_TYPE_REF: u32 = 1;

// For custom types
/// For [`DataType::Interval`] with [`IntervalUnit::YearMonth`].
///
/// An `i32` for elapsed whole months. See also [`ScalarValue::IntervalYearMonth`]
/// for the literal definition in DataFusion.
///
/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
/// [`IntervalUnit::YearMonth`]: datafusion::arrow::datatypes::IntervalUnit::YearMonth
/// [`ScalarValue::IntervalYearMonth`]: datafusion::common::ScalarValue::IntervalYearMonth
pub const INTERVAL_YEAR_MONTH_TYPE_REF: u32 = 1;

/// For [`DataType::Interval`] with [`IntervalUnit::DayTime`].
///
/// An `i64` as:
/// - days: `i32`
/// - milliseconds: `i32`
///
/// See also [`ScalarValue::IntervalDayTime`] for the literal definition in DataFusion.
///
/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
/// [`IntervalUnit::DayTime`]: datafusion::arrow::datatypes::IntervalUnit::DayTime
/// [`ScalarValue::IntervalDayTime`]: datafusion::common::ScalarValue::IntervalDayTime
pub const INTERVAL_DAY_TIME_TYPE_REF: u32 = 2;

/// For [`DataType::Interval`] with [`IntervalUnit::MonthDayNano`].
///
/// An `i128` as:
/// - months: `i32`
/// - days: `i32`
/// - nanoseconds: `i64`
///
/// See also [`ScalarValue::IntervalMonthDayNano`] for the literal definition in DataFusion.
///
/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
/// [`IntervalUnit::MonthDayNano`]: datafusion::arrow::datatypes::IntervalUnit::MonthDayNano
/// [`ScalarValue::IntervalMonthDayNano`]: datafusion::common::ScalarValue::IntervalMonthDayNano
pub const INTERVAL_MONTH_DAY_NANO_TYPE_REF: u32 = 3;

// For User Defined URLs
/// For [`DataType::Interval`] with [`IntervalUnit::YearMonth`].
///
/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
/// [`IntervalUnit::YearMonth`]: datafusion::arrow::datatypes::IntervalUnit::YearMonth
pub const INTERVAL_YEAR_MONTH_TYPE_URL: &str = "interval-year-month";
/// For [`DataType::Interval`] with [`IntervalUnit::DayTime`].
///
/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
/// [`IntervalUnit::DayTime`]: datafusion::arrow::datatypes::IntervalUnit::DayTime
pub const INTERVAL_DAY_TIME_TYPE_URL: &str = "interval-day-time";
/// For [`DataType::Interval`] with [`IntervalUnit::MonthDayNano`].
///
/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
/// [`IntervalUnit::MonthDayNano`]: datafusion::arrow::datatypes::IntervalUnit::MonthDayNano
pub const INTERVAL_MONTH_DAY_NANO_TYPE_URL: &str = "interval-month-day-nano";
26 changes: 23 additions & 3 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use datafusion_substrait::logical_plan::{
use std::hash::Hash;
use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit};
use datafusion::common::{not_impl_err, plan_err, DFSchema, DFSchemaRef};
use datafusion::error::Result;
use datafusion::execution::context::SessionState;
Expand Down Expand Up @@ -496,6 +496,24 @@ async fn roundtrip_arithmetic_ops() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn roundtrip_interval_literal() -> Result<()> {
roundtrip(
"SELECT g from data where g = arrow_cast(INTERVAL '1 YEAR', 'Interval(YearMonth)')",
)
.await?;
roundtrip(
"SELECT g from data where g = arrow_cast(INTERVAL '1 YEAR', 'Interval(DayTime)')",
)
.await?;
roundtrip(
"SELECT g from data where g = arrow_cast(INTERVAL '1 YEAR', 'Interval(MonthDayNano)')",
)
.await?;

Ok(())
}

#[tokio::test]
async fn roundtrip_like() -> Result<()> {
roundtrip("SELECT f FROM data WHERE f LIKE 'a%b'").await
Expand Down Expand Up @@ -1035,14 +1053,16 @@ async fn create_context() -> Result<SessionContext> {
.with_serializer_registry(Arc::new(MockSerializerRegistry));
let ctx = SessionContext::new_with_state(state);
let mut explicit_options = CsvReadOptions::new();
let schema = Schema::new(vec![
let fields = vec![
Field::new("a", DataType::Int64, true),
Field::new("b", DataType::Decimal128(5, 2), true),
Field::new("c", DataType::Date32, true),
Field::new("d", DataType::Boolean, true),
Field::new("e", DataType::UInt32, true),
Field::new("f", DataType::Utf8, true),
]);
Field::new("g", DataType::Interval(IntervalUnit::DayTime), true),
];
let schema = Schema::new(fields);
explicit_options.schema = Some(&schema);
ctx.register_csv("data", "tests/testdata/data.csv", explicit_options)
.await?;
Expand Down

0 comments on commit 6167ce9

Please sign in to comment.