diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index 90be576a884e..1eef1b718ba6 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -56,6 +56,7 @@ strum = { version = "0.26.1", features = ["derive"] } ctor = { workspace = true } datafusion-functions = { workspace = true, default-features = true } datafusion-functions-aggregate = { workspace = true } +datafusion-functions-nested = { workspace = true } datafusion-functions-window = { workspace = true } env_logger = { workspace = true } paste = "^1.0" diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 8864c97bb1ff..1d0327fadbe4 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use datafusion_expr::expr::Unnest; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ self, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, ObjectName, @@ -466,7 +467,7 @@ impl Unparser<'_> { Ok(ast::Expr::Value(ast::Value::Placeholder(p.id.to_string()))) } Expr::OuterReferenceColumn(_, col) => self.col_to_sql(col), - Expr::Unnest(_) => not_impl_err!("Unsupported Expr conversion: {expr:?}"), + Expr::Unnest(unnest) => self.unnest_to_sql(unnest), } } @@ -1340,6 +1341,29 @@ impl Unparser<'_> { } } + /// Converts an UNNEST operation to an AST expression by wrapping it as a function call, + /// since there is no direct representation for UNNEST in the AST. 
+ fn unnest_to_sql(&self, unnest: &Unnest) -> Result<ast::Expr> { + let args = self.function_args_to_sql(std::slice::from_ref(&unnest.expr))?; + + Ok(ast::Expr::Function(Function { + name: ast::ObjectName(vec![Ident { + value: "UNNEST".to_string(), + quote_style: None, + }]), + args: ast::FunctionArguments::List(ast::FunctionArgumentList { + duplicate_treatment: None, + args, + clauses: vec![], + }), + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + parameters: ast::FunctionArguments::None, + })) + } + fn arrow_dtype_to_ast_dtype(&self, data_type: &DataType) -> Result<ast::DataType> { match data_type { DataType::Null => { @@ -1855,6 +1879,15 @@ mod tests { }), r#"CAST(a AS DECIMAL(12,0))"#, ), + ( + Expr::Unnest(Unnest { + expr: Box::new(Expr::Column(Column { + relation: Some(TableReference::partial("schema", "table")), + name: "array_col".to_string(), + })), + }), + r#"UNNEST("schema"."table".array_col)"#, + ), ]; for (expr, expected) in tests { diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 695027374fa0..7c9054656b94 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -26,8 +26,9 @@ use super::{ subquery_alias_inner_query_and_columns, TableAliasRewriter, }, utils::{ - find_agg_node_within_select, find_window_nodes_within_select, - unproject_sort_expr, unproject_window_exprs, + find_agg_node_within_select, find_unnest_node_within_select, + find_window_nodes_within_select, unproject_sort_expr, unproject_unnest_expr, + unproject_window_exprs, }, Unparser, }; @@ -173,15 +174,24 @@ impl Unparser<'_> { p: &Projection, select: &mut SelectBuilder, ) -> Result<()> { + let mut exprs = p.expr.clone(); + + // If an Unnest node is found within the select, find and unproject the unnest column + if let Some(unnest) = find_unnest_node_within_select(plan) { + exprs = exprs + .into_iter() + .map(|e| unproject_unnest_expr(e, unnest)) + .collect::<Result<Vec<_>>>()?; + }; + match (
find_agg_node_within_select(plan, true), find_window_nodes_within_select(plan, None, true), ) { (Some(agg), window) => { let window_option = window.as_deref(); - let items = p - .expr - .iter() + let items = exprs + .into_iter() .map(|proj_expr| { let unproj = unproject_agg_exprs(proj_expr, agg, window_option)?; self.select_item_to_sql(&unproj) @@ -198,9 +208,8 @@ impl Unparser<'_> { )); } (None, Some(window)) => { - let items = p - .expr - .iter() + let items = exprs + .into_iter() .map(|proj_expr| { let unproj = unproject_window_exprs(proj_expr, &window)?; self.select_item_to_sql(&unproj) @@ -210,8 +219,7 @@ impl Unparser<'_> { select.projection(items); } _ => { - let items = p - .expr + let items = exprs .iter() .map(|e| self.select_item_to_sql(e)) .collect::<Result<Vec<_>>>()?; @@ -318,7 +326,8 @@ impl Unparser<'_> { if let Some(agg) = find_agg_node_within_select(plan, select.already_projected()) { - let unprojected = unproject_agg_exprs(&filter.predicate, agg, None)?; + let unprojected = + unproject_agg_exprs(filter.predicate.clone(), agg, None)?; let filter_expr = self.expr_to_sql(&unprojected)?; select.having(Some(filter_expr)); } else { @@ -596,6 +605,28 @@ impl Unparser<'_> { Ok(()) } LogicalPlan::Extension(_) => not_impl_err!("Unsupported operator: {plan:?}"), + LogicalPlan::Unnest(unnest) => { + if !unnest.struct_type_columns.is_empty() { + return internal_err!( + "Struct type columns are not currently supported in UNNEST: {:?}", + unnest.struct_type_columns + ); + } + + // In the case of UNNEST, the Unnest node is followed by a duplicate Projection node that we should skip. + // Otherwise, there will be a duplicate SELECT clause.
+ // | Projection: table.col1, UNNEST(table.col2) + // | Unnest: UNNEST(table.col2) + // | Projection: table.col1, table.col2 AS UNNEST(table.col2) + // | Filter: table.col3 = Int64(3) + // | TableScan: table projection=None + if let LogicalPlan::Projection(p) = unnest.input.as_ref() { + // continue with projection input + self.select_to_sql_recursively(&p.input, query, select, relation) + } else { + internal_err!("Unnest input is not a Projection: {unnest:?}") + } + } _ => not_impl_err!("Unsupported operator: {plan:?}"), } } diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 5e3a3aa600b6..d3d1bf351384 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -23,8 +23,8 @@ use datafusion_common::{ Column, Result, ScalarValue, }; use datafusion_expr::{ - utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Projection, SortExpr, - Window, + expr, utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Projection, + SortExpr, Unnest, Window, }; use sqlparser::ast; @@ -62,6 +62,28 @@ pub(crate) fn find_agg_node_within_select( } } +/// Recursively searches children of [LogicalPlan] to find an Unnest node, if one exists +pub(crate) fn find_unnest_node_within_select(plan: &LogicalPlan) -> Option<&Unnest> { + // Note that none of the nodes that have a corresponding Unnest node can have more + // than 1 input node. E.g. Projection / Filter always have 1 input node. + let input = plan.inputs(); + let input = if input.len() > 1 { + return None; + } else { + input.first()? + }; + + if let LogicalPlan::Unnest(unnest) = input { + Some(unnest) + } else if let LogicalPlan::TableScan(_) = input { + None + } else if let LogicalPlan::Projection(_) = input { + None + } else { + find_unnest_node_within_select(input) + } +} + /// Recursively searches children of [LogicalPlan] to find Window nodes if exist /// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
/// If Window node is not found prior to this or at all before reaching the end @@ -104,18 +126,46 @@ pub(crate) fn find_window_nodes_within_select<'a>( } } +/// Recursively identify Column expressions and transform them into the appropriate unnest expression +/// +/// For example, if expr contains the column expr "unnest_placeholder(make_array(Int64(1),Int64(2),Int64(2),Int64(5),NULL),depth=1)" +/// it will be transformed into an actual unnest expression UNNEST([1, 2, 2, 5, NULL]) +pub(crate) fn unproject_unnest_expr(expr: Expr, unnest: &Unnest) -> Result<Expr> { + expr.transform(|sub_expr| { + if let Expr::Column(col_ref) = &sub_expr { + // Check if the column is among the columns to run unnest on. + // Currently, only List/Array columns (defined in `list_type_columns`) are supported for unnesting. + if unnest.list_type_columns.iter().any(|e| e.1.output_column.name == col_ref.name) { + if let Ok(idx) = unnest.schema.index_of_column(col_ref) { + if let LogicalPlan::Projection(Projection { expr, .. }) = unnest.input.as_ref() { + if let Some(unprojected_expr) = expr.get(idx) { + let unnest_expr = Expr::Unnest(expr::Unnest::new(unprojected_expr.clone())); + return Ok(Transformed::yes(unnest_expr)); + } + } + } + return internal_err!( + "Tried to unproject unnest expr for column '{}' that was not found in the provided Unnest!", &col_ref.name + ); + } + } + + Ok(Transformed::no(sub_expr)) + + }).map(|e| e.data) +} + /// Recursively identify all Column expressions and transform them into the appropriate /// aggregate expression contained in agg. /// /// For example, if expr contains the column expr "COUNT(*)" it will be transformed /// into an actual aggregate expression COUNT(*) as identified in the aggregate node.
pub(crate) fn unproject_agg_exprs( - expr: &Expr, + expr: Expr, agg: &Aggregate, windows: Option<&[&Window]>, ) -> Result<Expr> { - expr.clone() - .transform(|sub_expr| { + expr.transform(|sub_expr| { if let Expr::Column(c) = sub_expr { if let Some(unprojected_expr) = find_agg_expr(agg, &c)? { Ok(Transformed::yes(unprojected_expr.clone())) @@ -123,7 +173,7 @@ pub(crate) fn unproject_agg_exprs( windows.and_then(|w| find_window_expr(w, &c.name).cloned()) { // Window function can contain an aggregation columns, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected - return Ok(Transformed::yes(unproject_agg_exprs(&unprojected_expr, agg, None)?)); + return Ok(Transformed::yes(unproject_agg_exprs(unprojected_expr, agg, None)?)); } else { internal_err!( "Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name @@ -141,20 +191,19 @@ pub(crate) fn unproject_agg_exprs( /// /// For example, if expr contains the column expr "COUNT(*) PARTITION BY id" it will be transformed /// into an actual window expression as identified in the window node.
-pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result<Expr> { - expr.clone() - .transform(|sub_expr| { - if let Expr::Column(c) = sub_expr { - if let Some(unproj) = find_window_expr(windows, &c.name) { - Ok(Transformed::yes(unproj.clone())) - } else { - Ok(Transformed::no(Expr::Column(c))) - } +pub(crate) fn unproject_window_exprs(expr: Expr, windows: &[&Window]) -> Result<Expr> { + expr.transform(|sub_expr| { + if let Expr::Column(c) = sub_expr { + if let Some(unproj) = find_window_expr(windows, &c.name) { + Ok(Transformed::yes(unproj.clone())) } else { - Ok(Transformed::no(sub_expr)) + Ok(Transformed::no(Expr::Column(c))) } - }) - .map(|e| e.data) + } else { + Ok(Transformed::no(sub_expr)) + } + }) + .map(|e| e.data) } fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Result<Option<&'a Expr>> { @@ -218,7 +267,7 @@ pub(crate) fn unproject_sort_expr( // In case of aggregation there could be columns containing aggregation functions we need to unproject if let Some(agg) = agg { if agg.schema.is_column_from_schema(col_ref) { - let new_expr = unproject_agg_exprs(&sort_expr.expr, agg, None)?; + let new_expr = unproject_agg_exprs(sort_expr.expr, agg, None)?; sort_expr.expr = new_expr; return Ok(sort_expr); } diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 8e25c1c5b1cd..37082fa64a06 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -24,6 +24,7 @@ use datafusion_expr::test::function_stub::{count_udaf, max_udaf, min_udaf, sum_u use datafusion_expr::{col, lit, table_scan, wildcard, LogicalPlanBuilder}; use datafusion_functions::unicode; use datafusion_functions_aggregate::grouping::grouping_udaf; +use datafusion_functions_nested::make_array::make_array_udf; use datafusion_functions_window::rank::rank_udwf; use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_sql::unparser::dialect::{ @@ -711,7 +712,8 @@ where
.with_aggregate_function(max_udaf()) .with_aggregate_function(grouping_udaf()) .with_window_function(rank_udwf()) - .with_scalar_function(Arc::new(unicode::substr().as_ref().clone())), + .with_scalar_function(Arc::new(unicode::substr().as_ref().clone())) + .with_scalar_function(make_array_udf()), }; let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); @@ -1069,3 +1071,18 @@ FROM person GROUP BY person.id, person.first_name"#.replace("\n", " ").as_str(), ); } + +#[test] +fn test_unnest_to_sql() { + sql_round_trip( + GenericDialect {}, + r#"SELECT unnest(array_col) as u1, struct_col, array_col FROM unnest_table WHERE array_col != NULL ORDER BY struct_col, array_col"#, + r#"SELECT UNNEST(unnest_table.array_col) AS u1, unnest_table.struct_col, unnest_table.array_col FROM unnest_table WHERE (unnest_table.array_col <> NULL) ORDER BY unnest_table.struct_col ASC NULLS LAST, unnest_table.array_col ASC NULLS LAST"#, + ); + + sql_round_trip( + GenericDialect {}, + r#"SELECT unnest(make_array(1, 2, 2, 5, NULL)) as u1"#, + r#"SELECT UNNEST(make_array(1, 2, 2, 5, NULL)) AS u1"#, + ); +}