From 5911d182eef08afee4fbdef3da7642ee92d1314c Mon Sep 17 00:00:00 2001 From: Devin D'Angelo Date: Thu, 14 Mar 2024 06:10:01 -0400 Subject: [PATCH] feat: implement more expr_to_sql functionality (#9578) * more impls * fix tests * cargo update dfcli * fix custom_dialect test * add tests and feature flag * fix comment * remove chrono use arrow-array conversions * fix cargo lock again * fix count distinct * retry windows ci * retry windows ci again * add roundtrip tests * cargo fmt --- README.md | 1 + datafusion-cli/Cargo.lock | 1 + datafusion/sql/Cargo.toml | 4 +- datafusion/sql/src/lib.rs | 1 + datafusion/sql/src/unparser/expr.rs | 389 +++++++++++++++++++----- datafusion/sql/tests/sql_integration.rs | 24 +- 6 files changed, 338 insertions(+), 82 deletions(-) diff --git a/README.md b/README.md index e5ac9503be44..abd727672aca 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,7 @@ Default features: - `parquet`: support for reading the [Apache Parquet] format - `regex_expressions`: regular expression functions, such as `regexp_match` - `unicode_expressions`: Include unicode aware functions such as `character_length` +- `unparser` : enables support to reverse LogicalPlans back into SQL Optional features: diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index e0c7c4391b25..1c2514811c7d 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1366,6 +1366,7 @@ name = "datafusion-sql" version = "36.0.0" dependencies = [ "arrow", + "arrow-array", "arrow-schema", "datafusion-common", "datafusion-expr", diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index 7739058a5c9d..ca2c1a240c21 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -33,11 +33,13 @@ name = "datafusion_sql" path = "src/lib.rs" [features] -default = ["unicode_expressions"] +default = ["unicode_expressions", "unparser"] unicode_expressions = [] +unparser = [] [dependencies] arrow = { workspace = true } +arrow-array = { workspace = true } arrow-schema = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs index da66ee197adb..e8e07eebe22d 100644 --- a/datafusion/sql/src/lib.rs +++ b/datafusion/sql/src/lib.rs @@ -36,6 +36,7 @@ mod relation; mod select; mod set_expr; mod statement; +#[cfg(feature = "unparser")] pub mod unparser; pub mod utils; mod values; diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 2a9fdd47ad93..403a7c6193d0 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -15,12 +15,16 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::{not_impl_err, Column, Result, ScalarValue}; +use arrow_array::{Date32Array, Date64Array}; +use arrow_schema::DataType; +use datafusion_common::{ + internal_datafusion_err, not_impl_err, Column, Result, ScalarValue, +}; use datafusion_expr::{ - expr::{Alias, InList, ScalarFunction, WindowFunction}, + expr::{AggregateFunctionDefinition, Alias, InList, ScalarFunction, WindowFunction}, Between, BinaryExpr, Case, Cast, Expr, Like, Operator, }; -use sqlparser::ast; +use sqlparser::ast::{self, Function, FunctionArg, Ident}; use super::Unparser; @@ -36,7 +40,7 @@ use super::Unparser; /// let expr = col("a").gt(lit(4)); /// let sql = expr_to_sql(&expr).unwrap(); /// -/// assert_eq!(format!("{}", sql), "a > 4") +/// assert_eq!(format!("{}", sql), "(a > 4)") /// ``` pub fn expr_to_sql(expr: &Expr) -> Result { let unparser = Unparser::default(); @@ -70,7 +74,7 @@ impl Unparser<'_> { let r = self.expr_to_sql(right.as_ref())?; let op = self.op_to_sql(op)?; - Ok(self.binary_op_to_sql(l, r, op)) + Ok(ast::Expr::Nested(Box::new(self.binary_op_to_sql(l, r, op)))) } Expr::Case(Case { expr, @@ -79,10 +83,15 @@ impl Unparser<'_> { }) => { not_impl_err!("Unsupported expression: {expr:?}") } - Expr::Cast(Cast { expr, data_type: _ }) => { - not_impl_err!("Unsupported expression: {expr:?}") + Expr::Cast(Cast { expr, data_type }) => { + let inner_expr = self.expr_to_sql(expr)?; + Ok(ast::Expr::Cast { + expr: Box::new(inner_expr), + data_type: self.arrow_dtype_to_ast_dtype(data_type)?, + format: None, + }) } - Expr::Literal(value) => Ok(ast::Expr::Value(self.scalar_to_sql(value)?)), + Expr::Literal(value) => Ok(self.scalar_to_sql(value)?), Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql(expr), Expr::WindowFunction(WindowFunction { fun: _, @@ -103,6 +112,45 @@ impl Unparser<'_> { }) => { not_impl_err!("Unsupported expression: {expr:?}") } + Expr::AggregateFunction(agg) => { + let func_name = if let AggregateFunctionDefinition::BuiltIn(built_in) = + &agg.func_def + { + built_in.name() + } else { + return not_impl_err!( + "Only built in agg functions are supported, got {agg:?}" + ); + }; + + let args = agg + .args + .iter() + .map(|e| { + if matches!(e, Expr::Wildcard { qualifier: None }) { + Ok(FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard)) + } else { + self.expr_to_sql(e).map(|e| { + FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) + }) + } + }) + .collect::>>()?; + + Ok(ast::Expr::Function(Function { + name: ast::ObjectName(vec![Ident { + value: func_name.to_string(), + quote_style: None, + }]), + args, + filter: None, + null_treatment: None, + over: None, + distinct: agg.distinct, + special: false, + order_by: vec![], + })) + } _ => not_impl_err!("Unsupported expression: {expr:?}"), } } @@ -174,139 +222,265 @@ impl Unparser<'_> { } } - fn scalar_to_sql(&self, v: &ScalarValue) -> Result { + /// DataFusion ScalarValues sometimes require a ast::Expr to construct. + /// For example ScalarValue::Date32(d) corresponds to the ast::Expr CAST('datestr' as DATE) + fn scalar_to_sql(&self, v: &ScalarValue) -> Result { match v { - ScalarValue::Null => Ok(ast::Value::Null), - ScalarValue::Boolean(Some(b)) => Ok(ast::Value::Boolean(b.to_owned())), - ScalarValue::Boolean(None) => Ok(ast::Value::Null), - ScalarValue::Float32(Some(f)) => Ok(ast::Value::Number(f.to_string(), false)), - ScalarValue::Float32(None) => Ok(ast::Value::Null), - ScalarValue::Float64(Some(f)) => Ok(ast::Value::Number(f.to_string(), false)), - ScalarValue::Float64(None) => Ok(ast::Value::Null), + ScalarValue::Null => Ok(ast::Expr::Value(ast::Value::Null)), + ScalarValue::Boolean(Some(b)) => { + Ok(ast::Expr::Value(ast::Value::Boolean(b.to_owned()))) + } + ScalarValue::Boolean(None) => Ok(ast::Expr::Value(ast::Value::Null)), + ScalarValue::Float32(Some(f)) => { + Ok(ast::Expr::Value(ast::Value::Number(f.to_string(), false))) + } + ScalarValue::Float32(None) => Ok(ast::Expr::Value(ast::Value::Null)), + ScalarValue::Float64(Some(f)) => { + Ok(ast::Expr::Value(ast::Value::Number(f.to_string(), false))) + } + ScalarValue::Float64(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::Decimal128(Some(_), ..) => { not_impl_err!("Unsupported scalar: {v:?}") } - ScalarValue::Decimal128(None, ..) => Ok(ast::Value::Null), + ScalarValue::Decimal128(None, ..) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::Decimal256(Some(_), ..) => { not_impl_err!("Unsupported scalar: {v:?}") } - ScalarValue::Decimal256(None, ..) => Ok(ast::Value::Null), - ScalarValue::Int8(Some(i)) => Ok(ast::Value::Number(i.to_string(), false)), - ScalarValue::Int8(None) => Ok(ast::Value::Null), - ScalarValue::Int16(Some(i)) => Ok(ast::Value::Number(i.to_string(), false)), - ScalarValue::Int16(None) => Ok(ast::Value::Null), - ScalarValue::Int32(Some(i)) => Ok(ast::Value::Number(i.to_string(), false)), - ScalarValue::Int32(None) => Ok(ast::Value::Null), - ScalarValue::Int64(Some(i)) => Ok(ast::Value::Number(i.to_string(), false)), - ScalarValue::Int64(None) => Ok(ast::Value::Null), - ScalarValue::UInt8(Some(ui)) => Ok(ast::Value::Number(ui.to_string(), false)), - ScalarValue::UInt8(None) => Ok(ast::Value::Null), - ScalarValue::UInt16(Some(ui)) => { - Ok(ast::Value::Number(ui.to_string(), false)) + ScalarValue::Decimal256(None, ..) => Ok(ast::Expr::Value(ast::Value::Null)), + ScalarValue::Int8(Some(i)) => { + Ok(ast::Expr::Value(ast::Value::Number(i.to_string(), false))) } - ScalarValue::UInt16(None) => Ok(ast::Value::Null), - ScalarValue::UInt32(Some(ui)) => { - Ok(ast::Value::Number(ui.to_string(), false)) + ScalarValue::Int8(None) => Ok(ast::Expr::Value(ast::Value::Null)), + ScalarValue::Int16(Some(i)) => { + Ok(ast::Expr::Value(ast::Value::Number(i.to_string(), false))) } - ScalarValue::UInt32(None) => Ok(ast::Value::Null), - ScalarValue::UInt64(Some(ui)) => { - Ok(ast::Value::Number(ui.to_string(), false)) + ScalarValue::Int16(None) => Ok(ast::Expr::Value(ast::Value::Null)), + ScalarValue::Int32(Some(i)) => { + Ok(ast::Expr::Value(ast::Value::Number(i.to_string(), false))) + } + ScalarValue::Int32(None) => Ok(ast::Expr::Value(ast::Value::Null)), + ScalarValue::Int64(Some(i)) => { + Ok(ast::Expr::Value(ast::Value::Number(i.to_string(), false))) } - ScalarValue::UInt64(None) => Ok(ast::Value::Null), - ScalarValue::Utf8(Some(str)) => { - Ok(ast::Value::SingleQuotedString(str.to_string())) + ScalarValue::Int64(None) => Ok(ast::Expr::Value(ast::Value::Null)), + ScalarValue::UInt8(Some(ui)) => { + Ok(ast::Expr::Value(ast::Value::Number(ui.to_string(), false))) } - ScalarValue::Utf8(None) => Ok(ast::Value::Null), - ScalarValue::LargeUtf8(Some(str)) => { - Ok(ast::Value::SingleQuotedString(str.to_string())) + ScalarValue::UInt8(None) => Ok(ast::Expr::Value(ast::Value::Null)), + ScalarValue::UInt16(Some(ui)) => { + Ok(ast::Expr::Value(ast::Value::Number(ui.to_string(), false))) + } + ScalarValue::UInt16(None) => Ok(ast::Expr::Value(ast::Value::Null)), + ScalarValue::UInt32(Some(ui)) => { + Ok(ast::Expr::Value(ast::Value::Number(ui.to_string(), false))) } - ScalarValue::LargeUtf8(None) => Ok(ast::Value::Null), + ScalarValue::UInt32(None) => Ok(ast::Expr::Value(ast::Value::Null)), + ScalarValue::UInt64(Some(ui)) => { + Ok(ast::Expr::Value(ast::Value::Number(ui.to_string(), false))) + } + ScalarValue::UInt64(None) => Ok(ast::Expr::Value(ast::Value::Null)), + ScalarValue::Utf8(Some(str)) => Ok(ast::Expr::Value( + ast::Value::SingleQuotedString(str.to_string()), + )), + ScalarValue::Utf8(None) => Ok(ast::Expr::Value(ast::Value::Null)), + ScalarValue::LargeUtf8(Some(str)) => Ok(ast::Expr::Value( + ast::Value::SingleQuotedString(str.to_string()), + )), + ScalarValue::LargeUtf8(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::Binary(Some(_)) => not_impl_err!("Unsupported scalar: {v:?}"), - ScalarValue::Binary(None) => Ok(ast::Value::Null), + ScalarValue::Binary(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::FixedSizeBinary(..) => { not_impl_err!("Unsupported scalar: {v:?}") } ScalarValue::LargeBinary(Some(_)) => { not_impl_err!("Unsupported scalar: {v:?}") } - ScalarValue::LargeBinary(None) => Ok(ast::Value::Null), + ScalarValue::LargeBinary(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::FixedSizeList(_a) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::List(_a) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::LargeList(_a) => not_impl_err!("Unsupported scalar: {v:?}"), - ScalarValue::Date32(Some(_d)) => not_impl_err!("Unsupported scalar: {v:?}"), - ScalarValue::Date32(None) => Ok(ast::Value::Null), - ScalarValue::Date64(Some(_d)) => not_impl_err!("Unsupported scalar: {v:?}"), - ScalarValue::Date64(None) => Ok(ast::Value::Null), + ScalarValue::Date32(Some(_)) => { + let date = v + .to_array()? + .as_any() + .downcast_ref::() + .ok_or(internal_datafusion_err!( + "Unable to downcast to Date32 from Date32 scalar" + ))? + .value_as_date(0) + .ok_or(internal_datafusion_err!( + "Unable to convert Date32 to NaiveDate" + ))?; + + Ok(ast::Expr::Cast { + expr: Box::new(ast::Expr::Value(ast::Value::SingleQuotedString( + date.to_string(), + ))), + data_type: ast::DataType::Date, + format: None, + }) + } + ScalarValue::Date32(None) => Ok(ast::Expr::Value(ast::Value::Null)), + ScalarValue::Date64(Some(_)) => { + let datetime = v + .to_array()? + .as_any() + .downcast_ref::() + .ok_or(internal_datafusion_err!( + "Unable to downcast to Date64 from Date64 scalar" + ))? + .value_as_datetime(0) + .ok_or(internal_datafusion_err!( + "Unable to convert Date64 to NaiveDateTime" + ))?; + + Ok(ast::Expr::Cast { + expr: Box::new(ast::Expr::Value(ast::Value::SingleQuotedString( + datetime.to_string(), + ))), + data_type: ast::DataType::Datetime(None), + format: None, + }) + } + ScalarValue::Date64(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::Time32Second(Some(_t)) => { not_impl_err!("Unsupported scalar: {v:?}") } - ScalarValue::Time32Second(None) => Ok(ast::Value::Null), + ScalarValue::Time32Second(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::Time32Millisecond(Some(_t)) => { not_impl_err!("Unsupported scalar: {v:?}") } - ScalarValue::Time32Millisecond(None) => Ok(ast::Value::Null), + ScalarValue::Time32Millisecond(None) => { + Ok(ast::Expr::Value(ast::Value::Null)) + } ScalarValue::Time64Microsecond(Some(_t)) => { not_impl_err!("Unsupported scalar: {v:?}") } - ScalarValue::Time64Microsecond(None) => Ok(ast::Value::Null), + ScalarValue::Time64Microsecond(None) => { + Ok(ast::Expr::Value(ast::Value::Null)) + } ScalarValue::Time64Nanosecond(Some(_t)) => { not_impl_err!("Unsupported scalar: {v:?}") } - ScalarValue::Time64Nanosecond(None) => Ok(ast::Value::Null), + ScalarValue::Time64Nanosecond(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::TimestampSecond(Some(_ts), _) => { not_impl_err!("Unsupported scalar: {v:?}") } - ScalarValue::TimestampSecond(None, _) => Ok(ast::Value::Null), + ScalarValue::TimestampSecond(None, _) => { + Ok(ast::Expr::Value(ast::Value::Null)) + } ScalarValue::TimestampMillisecond(Some(_ts), _) => { not_impl_err!("Unsupported scalar: {v:?}") } - ScalarValue::TimestampMillisecond(None, _) => Ok(ast::Value::Null), + ScalarValue::TimestampMillisecond(None, _) => { + Ok(ast::Expr::Value(ast::Value::Null)) + } ScalarValue::TimestampMicrosecond(Some(_ts), _) => { not_impl_err!("Unsupported scalar: {v:?}") } - ScalarValue::TimestampMicrosecond(None, _) => Ok(ast::Value::Null), + ScalarValue::TimestampMicrosecond(None, _) => { + Ok(ast::Expr::Value(ast::Value::Null)) + } ScalarValue::TimestampNanosecond(Some(_ts), _) => { not_impl_err!("Unsupported scalar: {v:?}") } - ScalarValue::TimestampNanosecond(None, _) => Ok(ast::Value::Null), + ScalarValue::TimestampNanosecond(None, _) => { + Ok(ast::Expr::Value(ast::Value::Null)) + } ScalarValue::IntervalYearMonth(Some(_i)) => { not_impl_err!("Unsupported scalar: {v:?}") } - ScalarValue::IntervalYearMonth(None) => Ok(ast::Value::Null), + ScalarValue::IntervalYearMonth(None) => { + Ok(ast::Expr::Value(ast::Value::Null)) + } ScalarValue::IntervalDayTime(Some(_i)) => { not_impl_err!("Unsupported scalar: {v:?}") } - ScalarValue::IntervalDayTime(None) => Ok(ast::Value::Null), + ScalarValue::IntervalDayTime(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::IntervalMonthDayNano(Some(_i)) => { not_impl_err!("Unsupported scalar: {v:?}") } - ScalarValue::IntervalMonthDayNano(None) => Ok(ast::Value::Null), + ScalarValue::IntervalMonthDayNano(None) => { + Ok(ast::Expr::Value(ast::Value::Null)) + } ScalarValue::DurationSecond(Some(_d)) => { not_impl_err!("Unsupported scalar: {v:?}") } - ScalarValue::DurationSecond(None) => Ok(ast::Value::Null), + ScalarValue::DurationSecond(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::DurationMillisecond(Some(_d)) => { not_impl_err!("Unsupported scalar: {v:?}") } - ScalarValue::DurationMillisecond(None) => Ok(ast::Value::Null), + ScalarValue::DurationMillisecond(None) => { + Ok(ast::Expr::Value(ast::Value::Null)) + } ScalarValue::DurationMicrosecond(Some(_d)) => { not_impl_err!("Unsupported scalar: {v:?}") } - ScalarValue::DurationMicrosecond(None) => Ok(ast::Value::Null), + ScalarValue::DurationMicrosecond(None) => { + Ok(ast::Expr::Value(ast::Value::Null)) + } ScalarValue::DurationNanosecond(Some(_d)) => { not_impl_err!("Unsupported scalar: {v:?}") } - ScalarValue::DurationNanosecond(None) => Ok(ast::Value::Null), + ScalarValue::DurationNanosecond(None) => { + Ok(ast::Expr::Value(ast::Value::Null)) + } ScalarValue::Struct(_) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Dictionary(..) => not_impl_err!("Unsupported scalar: {v:?}"), } } + + fn arrow_dtype_to_ast_dtype(&self, data_type: &DataType) -> Result { + match data_type { + DataType::Null => { + not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + } + DataType::Boolean => Ok(ast::DataType::Bool), + DataType::Int8 => Ok(ast::DataType::TinyInt(None)), + DataType::Int16 => Ok(ast::DataType::SmallInt(None)), + DataType::Int32 => Ok(ast::DataType::Integer(None)), + DataType::Int64 => Ok(ast::DataType::BigInt(None)), + DataType::UInt8 => Ok(ast::DataType::UnsignedTinyInt(None)), + DataType::UInt16 => Ok(ast::DataType::UnsignedSmallInt(None)), + DataType::UInt32 => Ok(ast::DataType::UnsignedInteger(None)), + DataType::UInt64 => Ok(ast::DataType::UnsignedBigInt(None)), + DataType::Float16 => { + not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + } + DataType::Float32 => Ok(ast::DataType::Float(None)), + DataType::Float64 => Ok(ast::DataType::Double), + DataType::Timestamp(_, _) => { + not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + } + DataType::Date32 => Ok(ast::DataType::Date), + DataType::Date64 => Ok(ast::DataType::Datetime(None)), + DataType::Time32(_) => todo!(), + DataType::Time64(_) => todo!(), + DataType::Duration(_) => todo!(), + DataType::Interval(_) => todo!(), + DataType::Binary => todo!(), + DataType::FixedSizeBinary(_) => todo!(), + DataType::LargeBinary => todo!(), + DataType::Utf8 => Ok(ast::DataType::Varchar(None)), + DataType::LargeUtf8 => Ok(ast::DataType::Text), + DataType::List(_) => todo!(), + DataType::FixedSizeList(_, _) => todo!(), + DataType::LargeList(_) => todo!(), + DataType::Struct(_) => todo!(), + DataType::Union(_, _) => todo!(), + DataType::Dictionary(_, _) => todo!(), + DataType::Decimal128(_, _) => todo!(), + DataType::Decimal256(_, _) => todo!(), + DataType::Map(_, _) => todo!(), + DataType::RunEndEncoded(_, _) => todo!(), + } + } } #[cfg(test)] mod tests { use datafusion_common::TableReference; - use datafusion_expr::{col, lit}; + use datafusion_expr::{col, expr::AggregateFunction, lit}; use crate::unparser::dialect::CustomDialect; @@ -316,14 +490,81 @@ mod tests { #[test] fn expr_to_sql_ok() -> Result<()> { - let tests: Vec<(Expr, &str)> = vec![( - Expr::Column(Column { - relation: Some(TableReference::partial("a", "b")), - name: "c".to_string(), - }) - .gt(lit(4)), - r#"a.b.c > 4"#, - )]; + let tests: Vec<(Expr, &str)> = vec![ + ((col("a") + col("b")).gt(lit(4)), r#"((a + b) > 4)"#), + ( + Expr::Column(Column { + relation: Some(TableReference::partial("a", "b")), + name: "c".to_string(), + }) + .gt(lit(4)), + r#"(a.b.c > 4)"#, + ), + ( + Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: DataType::Date64, + }), + r#"CAST(a AS DATETIME)"#, + ), + ( + Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: DataType::UInt32, + }), + r#"CAST(a AS INTEGER UNSIGNED)"#, + ), + ( + Expr::Literal(ScalarValue::Date64(Some(0))), + r#"CAST('1970-01-01 00:00:00' AS DATETIME)"#, + ), + ( + Expr::Literal(ScalarValue::Date64(Some(10000))), + r#"CAST('1970-01-01 00:00:10' AS DATETIME)"#, + ), + ( + Expr::Literal(ScalarValue::Date64(Some(-10000))), + r#"CAST('1969-12-31 23:59:50' AS DATETIME)"#, + ), + ( + Expr::Literal(ScalarValue::Date32(Some(0))), + r#"CAST('1970-01-01' AS DATE)"#, + ), + ( + Expr::Literal(ScalarValue::Date32(Some(10))), + r#"CAST('1970-01-11' AS DATE)"#, + ), + ( + Expr::Literal(ScalarValue::Date32(Some(-1))), + r#"CAST('1969-12-31' AS DATE)"#, + ), + ( + Expr::AggregateFunction(AggregateFunction { + func_def: AggregateFunctionDefinition::BuiltIn( + datafusion_expr::AggregateFunction::Sum, + ), + args: vec![col("a")], + distinct: false, + filter: None, + order_by: None, + null_treatment: None, + }), + "SUM(a)", + ), + ( + Expr::AggregateFunction(AggregateFunction { + func_def: AggregateFunctionDefinition::BuiltIn( + datafusion_expr::AggregateFunction::Count, + ), + args: vec![Expr::Wildcard { qualifier: None }], + distinct: true, + filter: None, + order_by: None, + null_treatment: None, + }), + "COUNT(DISTINCT *)", + ), + ]; for (expr, expected) in tests { let ast = expr_to_sql(&expr)?; @@ -346,7 +587,7 @@ mod tests { let actual = format!("{}", ast); - let expected = r#"'a' > 4"#; + let expected = r#"('a' > 4)"#; assert_eq!(actual, expected); Ok(()) diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index fdf7ab8c3d28..a6ea22db9651 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -4493,8 +4493,18 @@ impl TableSource for EmptyTable { #[test] fn roundtrip_expr() { let tests: Vec<(TableReference, &str, &str)> = vec![ - (TableReference::bare("person"), "age > 35", "age > 35"), - (TableReference::bare("person"), "id = '10'", "id = '10'"), + (TableReference::bare("person"), "age > 35", "(age > 35)"), + (TableReference::bare("person"), "id = '10'", "(id = '10')"), + ( + TableReference::bare("person"), + "CAST(id AS VARCHAR)", + "CAST(id AS VARCHAR)", + ), + ( + TableReference::bare("person"), + "SUM((age * 2))", + "SUM((age * 2))", + ), ]; let roundtrip = |table, sql: &str| -> Result { @@ -4540,15 +4550,15 @@ fn roundtrip_statement() { ), ( "select ta.j1_id from j1 ta where ta.j1_id > 1;", - r#"SELECT ta.j1_id FROM j1 AS ta WHERE ta.j1_id > 1"#, + r#"SELECT ta.j1_id FROM j1 AS ta WHERE (ta.j1_id > 1)"#, ), ( - "select ta.j1_id, tb.j2_string from j1 ta join j2 tb on ta.j1_id = tb.j2_id;", - r#"SELECT ta.j1_id, tb.j2_string FROM j1 AS ta JOIN j2 AS tb ON ta.j1_id = tb.j2_id"#, + "select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id);", + r#"SELECT ta.j1_id, tb.j2_string FROM j1 AS ta JOIN j2 AS tb ON (ta.j1_id = tb.j2_id)"#, ), ( - "select ta.j1_id, tb.j2_string, tc.j3_string from j1 ta join j2 tb on ta.j1_id = tb.j2_id join j3 tc on ta.j1_id = tc.j3_id;", - r#"SELECT ta.j1_id, tb.j2_string, tc.j3_string FROM j1 AS ta JOIN j2 AS tb ON ta.j1_id = tb.j2_id JOIN j3 AS tc ON ta.j1_id = tc.j3_id"#, + "select ta.j1_id, tb.j2_string, tc.j3_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id) join j3 tc on (ta.j1_id = tc.j3_id);", + r#"SELECT ta.j1_id, tb.j2_string, tc.j3_string FROM j1 AS ta JOIN j2 AS tb ON (ta.j1_id = tb.j2_id) JOIN j3 AS tc ON (ta.j1_id = tc.j3_id)"#, ), ];