diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index dbf9cc88d784..839255cea5bc 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -800,6 +800,7 @@ dependencies = [ "arrow-schema", "datafusion-common", "datafusion-expr", + "log", "sqlparser", ] diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 8b20fc5d61f9..2c014068dff8 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -121,7 +121,8 @@ impl ExpressionVisitor for ApplicabilityVisitor<'_> { | Expr::Sort { .. } | Expr::WindowFunction { .. } | Expr::Wildcard - | Expr::QualifiedWildcard { .. } => { + | Expr::QualifiedWildcard { .. } + | Expr::Placeholder { .. } => { *self.is_applicable = false; Recursion::Stop(self) } diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 596c1888f30d..bbfa1b6e120f 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -344,6 +344,9 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { Expr::QualifiedWildcard { .. } => Err(DataFusionError::Internal( "Create physical name does not support qualified wildcard".to_string(), )), + Expr::Placeholder { .. } => Err(DataFusionError::Internal( + "Create physical name does not support placeholder".to_string(), + )), } } @@ -1031,6 +1034,14 @@ impl DefaultPhysicalPlanner { "Unsupported logical plan: CreateExternalTable".to_string(), )) } + LogicalPlan::Prepare(_) => { + // There is no default plan for "PREPARE" -- it must be + // handled at a higher level (so that the appropriate + // statement can be prepared) + Err(DataFusionError::Internal( + "Unsupported logical plan: Prepare".to_string(), + )) + } LogicalPlan::CreateCatalogSchema(_) => { // There is no default plan for "CREATE SCHEMA". // It must be handled at a higher level (so diff --git a/datafusion/core/tests/sqllogictests/test_files/prepare.slt b/datafusion/core/tests/sqllogictests/test_files/prepare.slt new file mode 100644 index 000000000000..948a2e3bc830 --- /dev/null +++ b/datafusion/core/tests/sqllogictests/test_files/prepare.slt @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## Prepare Statement Tests +########## + +statement ok +create table person (id int, first_name varchar, last_name varchar, age int, state varchar, salary double, birthday timestamp, "😀" int) as values (1, 'jane', 'smith', 20, 'MA', 100000.45, '2000-11-12T00:00:00'::timestamp, 99); + +query C rowsort +select * from person; +---- +1 jane smith 20 MA 100000.45 2000-11-12T00:00:00.000000000 99 + +# Error due to syntax and semantic violation + +# Syntax error: no name specified after the keyword prepare +statement error +PREPARE AS SELECT id, age FROM person WHERE age = $foo; + +# param following a non-number, $foo, not supported +statement error +PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $foo; + +# not specify table hence cannot specify columns +statement error +PREPARE my_plan(INT) AS SELECT id + $1; + +# not specify data types for all params +statement error +PREPARE my_plan(INT) AS SELECT 1 + $1 + $2; + +# cannot use IS param +statement error +PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age is $1; + +# ####################### +# TODO: all the errors below should work ok after we store the prepare logical plan somewhere +statement error +PREPARE my_plan(STRING, STRING) AS SELECT * FROM (VALUES(1, $1), (2, $2)) AS t (num, letter); + +statement error +PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10; + +statement error +PREPARE my_plan AS SELECT id, age FROM person WHERE age = 10; + +statement error +PREPARE my_plan(INT) AS SELECT $1; + +statement error +PREPARE my_plan(INT) AS SELECT 1 + $1; + +statement error +PREPARE my_plan(INT, DOUBLE) AS SELECT 1 + $1 + $2; + +statement error +PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $1; + +statement error +PREPARE my_plan(INT, STRING, DOUBLE, INT, DOUBLE, STRING) AS SELECT id, age, $6 FROM person WHERE age IN ($1, $4) AND salary > $3 and salary < $5 OR first_name < $2"; + +statement error +PREPARE my_plan(INT, DOUBLE, DOUBLE, DOUBLE) AS SELECT id, SUM(age) FROM person WHERE salary > $2 GROUP BY id HAVING sum(age) < $1 AND SUM(age) > 10 OR SUM(age) in ($3, $4); + +statement error +PREPARE my_plan(STRING, STRING) AS SELECT * FROM (VALUES(1, $1), (2, $2)) AS t (num, letter); + diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 8d0d5c18095a..fbc98cf01a20 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -244,6 +244,14 @@ pub enum Expr { /// List of grouping set expressions. Only valid in the context of an aggregate /// GROUP BY expression list GroupingSet(GroupingSet), + /// A place holder for parameters in a prepared statement + /// (e.g. `$foo` or `$1`) + Placeholder { + /// The identifier of the parameter (e.g, $1 or $foo) + id: String, + /// The type the parameter will be filled in with + data_type: DataType, + }, } /// Binary expression @@ -528,6 +536,7 @@ impl Expr { Expr::Literal(..) => "Literal", Expr::Negative(..) => "Negative", Expr::Not(..) => "Not", + Expr::Placeholder { .. } => "Placeholder", Expr::QualifiedWildcard { .. } => "QualifiedWildcard", Expr::ScalarFunction { .. } => "ScalarFunction", Expr::ScalarSubquery { .. } => "ScalarSubquery", @@ -980,6 +989,7 @@ impl fmt::Debug for Expr { ) } }, + Expr::Placeholder { id, .. } => write!(f, "{}", id), } } } @@ -1263,6 +1273,7 @@ fn create_name(e: &Expr) -> Result { Expr::QualifiedWildcard { .. } => Err(DataFusionError::Internal( "Create name does not support qualified wildcard".to_string(), )), + Expr::Placeholder { id, .. } => Ok((*id).to_string()), } } diff --git a/datafusion/expr/src/expr_rewriter.rs b/datafusion/expr/src/expr_rewriter.rs index fa7e00d0f5fa..b107d591769f 100644 --- a/datafusion/expr/src/expr_rewriter.rs +++ b/datafusion/expr/src/expr_rewriter.rs @@ -291,6 +291,7 @@ impl ExprRewritable for Expr { key, )) } + Expr::Placeholder { id, data_type } => Expr::Placeholder { id, data_type }, }; // now rewrite this expression itself diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 8424fa2aa2d1..ae516001bc07 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -127,6 +127,7 @@ impl ExprSchemable for Expr { Expr::Like { .. } | Expr::ILike { .. } | Expr::SimilarTo { .. } => { Ok(DataType::Boolean) } + Expr::Placeholder { data_type, .. } => Ok(data_type.clone()), Expr::Wildcard => Err(DataFusionError::Internal( "Wildcard expressions are not valid in a logical query plan".to_owned(), )), @@ -198,7 +199,8 @@ impl ExprSchemable for Expr { | Expr::IsNotTrue(_) | Expr::IsNotFalse(_) | Expr::IsNotUnknown(_) - | Expr::Exists { .. } => Ok(false), + | Expr::Exists { .. } + | Expr::Placeholder { .. } => Ok(true), Expr::InSubquery { expr, .. } => expr.nullable(input_schema), Expr::ScalarSubquery(subquery) => { Ok(subquery.subquery.schema().field(0).is_nullable()) diff --git a/datafusion/expr/src/expr_visitor.rs b/datafusion/expr/src/expr_visitor.rs index bd839f098fc3..b5c6c6802555 100644 --- a/datafusion/expr/src/expr_visitor.rs +++ b/datafusion/expr/src/expr_visitor.rs @@ -133,7 +133,8 @@ impl ExprVisitable for Expr { | Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::Wildcard - | Expr::QualifiedWildcard { .. } => Ok(visitor), + | Expr::QualifiedWildcard { .. } + | Expr::Placeholder { .. } => Ok(visitor), Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { let visitor = left.accept(visitor)?; right.accept(visitor) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index caeffa2def6f..28d3ccc919b6 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -26,9 +26,9 @@ use crate::{and, binary_expr, Operator}; use crate::{ logical_plan::{ Aggregate, Analyze, CrossJoin, Distinct, EmptyRelation, Explain, Filter, Join, - JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Projection, - Repartition, Sort, SubqueryAlias, TableScan, ToStringifiedPlan, Union, Values, - Window, + JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, + Projection, Repartition, Sort, SubqueryAlias, TableScan, ToStringifiedPlan, + Union, Values, Window, }, utils::{ can_hash, expand_qualified_wildcard, expand_wildcard, @@ -118,6 +118,8 @@ impl LogicalPlanBuilder { /// By default, it assigns the names column1, column2, etc. to the columns of a VALUES table. /// The column names are not specified by the SQL standard and different database systems do it differently, /// so it's usually better to override the default names with a table alias list. + /// + /// If the values include params/binders such as $1, $2, $3, etc, then the `param_data_types` should be provided. pub fn values(mut values: Vec>) -> Result { if values.is_empty() { return Err(DataFusionError::Plan("Values list cannot be empty".into())); @@ -279,6 +281,15 @@ impl LogicalPlanBuilder { )?))) } + /// Make a builder for a prepare logical plan from the builder's plan + pub fn prepare(&self, name: String, data_types: Vec) -> Result { + Ok(Self::from(LogicalPlan::Prepare(Prepare { + name, + data_types, + input: Arc::new(self.plan.clone()), + }))) + } + /// Limit the number of rows returned /// /// `skip` - Number of rows to skip before fetch any row. diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index 2cfe921e67b3..9d26d2a6554e 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -25,7 +25,7 @@ pub use plan::{ Aggregate, Analyze, CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateMemoryTable, CreateView, CrossJoin, Distinct, DropTable, DropView, EmptyRelation, Explain, Extension, Filter, Join, JoinConstraint, JoinType, Limit, - LogicalPlan, Partitioning, PlanType, PlanVisitor, Projection, Repartition, + LogicalPlan, Partitioning, PlanType, PlanVisitor, Prepare, Projection, Repartition, SetVariable, Sort, StringifiedPlan, Subquery, SubqueryAlias, TableScan, ToStringifiedPlan, Union, Values, Window, }; diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index e7fa9c39d90f..7f38e7dbb2ef 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -110,6 +110,8 @@ pub enum LogicalPlan { Distinct(Distinct), /// Set a Variable SetVariable(SetVariable), + /// Prepare a statement + Prepare(Prepare), } impl LogicalPlan { @@ -136,6 +138,7 @@ impl LogicalPlan { LogicalPlan::CreateExternalTable(CreateExternalTable { schema, .. }) => { schema } + LogicalPlan::Prepare(Prepare { input, .. }) => input.schema(), LogicalPlan::Explain(explain) => &explain.schema, LogicalPlan::Analyze(analyze) => &analyze.schema, LogicalPlan::Extension(extension) => extension.node.schema(), @@ -203,8 +206,9 @@ impl LogicalPlan { | LogicalPlan::Sort(Sort { input, .. }) | LogicalPlan::CreateMemoryTable(CreateMemoryTable { input, .. }) | LogicalPlan::CreateView(CreateView { input, .. }) - | LogicalPlan::Filter(Filter { input, .. }) => input.all_schemas(), - LogicalPlan::Distinct(Distinct { input, .. }) => input.all_schemas(), + | LogicalPlan::Filter(Filter { input, .. }) + | LogicalPlan::Distinct(Distinct { input, .. }) + | LogicalPlan::Prepare(Prepare { input, .. }) => input.all_schemas(), LogicalPlan::DropTable(_) | LogicalPlan::DropView(_) | LogicalPlan::SetVariable(_) => vec![], @@ -273,7 +277,8 @@ impl LogicalPlan { | LogicalPlan::Analyze(_) | LogicalPlan::Explain(_) | LogicalPlan::Union(_) - | LogicalPlan::Distinct(_) => { + | LogicalPlan::Distinct(_) + | LogicalPlan::Prepare(_) => { vec![] } } @@ -302,7 +307,8 @@ impl LogicalPlan { LogicalPlan::Explain(explain) => vec![&explain.plan], LogicalPlan::Analyze(analyze) => vec![&analyze.input], LogicalPlan::CreateMemoryTable(CreateMemoryTable { input, .. }) - | LogicalPlan::CreateView(CreateView { input, .. }) => { + | LogicalPlan::CreateView(CreateView { input, .. }) + | LogicalPlan::Prepare(Prepare { input, .. }) => { vec![input] } // plans without inputs @@ -450,9 +456,8 @@ impl LogicalPlan { input.accept(visitor)? } LogicalPlan::CreateMemoryTable(CreateMemoryTable { input, .. }) - | LogicalPlan::CreateView(CreateView { input, .. }) => { - input.accept(visitor)? - } + | LogicalPlan::CreateView(CreateView { input, .. }) + | LogicalPlan::Prepare(Prepare { input, .. }) => input.accept(visitor)?, LogicalPlan::Extension(extension) => { for input in extension.node.inputs() { if !input.accept(visitor)? { @@ -963,6 +968,11 @@ impl LogicalPlan { LogicalPlan::Analyze { .. } => write!(f, "Analyze"), LogicalPlan::Union(_) => write!(f, "Union"), LogicalPlan::Extension(e) => e.node.fmt_for_explain(f), + LogicalPlan::Prepare(Prepare { + name, data_types, .. + }) => { + write!(f, "Prepare: {:?} {:?} ", name, data_types) + } } } } @@ -1373,6 +1383,18 @@ pub struct CreateExternalTable { pub options: HashMap, } +/// Prepare a statement but do not execute it. Prepare statements can have 0 or more +/// `Expr::Placeholder` expressions that are filled in during execution +#[derive(Clone)] +pub struct Prepare { + /// The name of the statement + pub name: String, + /// Data types of the parameters ([`Expr::Placeholder`]) + pub data_types: Vec, + /// The logical plan of the statements + pub input: Arc, +} + /// Produces a relation with string representations of /// various parts of the plan #[derive(Clone)] diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index fcb0365b5325..88631cc6f07b 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -22,8 +22,8 @@ use crate::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion}; use crate::logical_plan::builder::build_join_schema; use crate::logical_plan::{ Aggregate, Analyze, CreateMemoryTable, CreateView, Distinct, Extension, Filter, Join, - Limit, Partitioning, Projection, Repartition, Sort, Subquery, SubqueryAlias, Union, - Values, Window, + Limit, Partitioning, Prepare, Projection, Repartition, Sort, Subquery, SubqueryAlias, + Union, Values, Window, }; use crate::{Cast, Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; use arrow::datatypes::{DataType, TimeUnit}; @@ -126,7 +126,8 @@ impl ExpressionVisitor for ColumnNameVisitor<'_> { | Expr::ScalarSubquery(_) | Expr::Wildcard | Expr::QualifiedWildcard { .. } - | Expr::GetIndexedField { .. } => {} + | Expr::GetIndexedField { .. } + | Expr::Placeholder { .. } => {} } Ok(Recursion::Continue(self)) } @@ -579,6 +580,13 @@ pub fn from_plan( Ok(plan.clone()) } + LogicalPlan::Prepare(Prepare { + name, data_types, .. + }) => Ok(LogicalPlan::Prepare(Prepare { + name: name.clone(), + data_types: data_types.clone(), + input: Arc::new(inputs[0].clone()), + })), LogicalPlan::EmptyRelation(_) | LogicalPlan::TableScan { .. } | LogicalPlan::CreateExternalTable(_) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 11ed5cdbebb0..482298e160e3 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -240,7 +240,8 @@ impl OptimizerRule for CommonSubexprEliminate { | LogicalPlan::DropView(_) | LogicalPlan::SetVariable(_) | LogicalPlan::Distinct(_) - | LogicalPlan::Extension(_) => { + | LogicalPlan::Extension(_) + | LogicalPlan::Prepare(_) => { // apply the optimization to all inputs of the plan utils::optimize_children(self, plan, optimizer_config) } diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index 3cedddc600d1..2d156d1ce398 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -391,7 +391,8 @@ fn optimize_plan( | LogicalPlan::SetVariable(_) | LogicalPlan::CrossJoin(_) | LogicalPlan::Distinct(_) - | LogicalPlan::Extension { .. } => { + | LogicalPlan::Extension { .. } + | LogicalPlan::Prepare(_) => { let expr = plan.expressions(); // collect all required columns by this plan exprlist_to_columns(&expr, &mut new_required_columns)?; diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index b32fc53dbaff..3a51099fe645 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -253,7 +253,8 @@ impl<'a> ConstEvaluator<'a> { | Expr::Sort { .. } | Expr::GroupingSet(_) | Expr::Wildcard - | Expr::QualifiedWildcard { .. } => false, + | Expr::QualifiedWildcard { .. } + | Expr::Placeholder { .. } => false, Expr::ScalarFunction { fun, .. } => Self::volatility_ok(fun.volatility()), Expr::ScalarUDF { fun, .. } => Self::volatility_ok(fun.signature.volatility), Expr::Literal(_) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index d5284ef5956d..97ba57a7e0da 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -71,6 +71,7 @@ message LogicalPlanNode { DistinctNode distinct = 23; ViewTableScanNode view_scan = 24; CustomTableScanNode custom_scan = 25; + PrepareNode prepare = 26; } } @@ -181,6 +182,12 @@ message CreateExternalTableNode { map options = 11; } +message PrepareNode { + string name = 1; + repeated ArrowType data_types = 2; + LogicalPlanNode input = 3; + } + message CreateCatalogSchemaNode { string schema_name = 1; bool if_not_exists = 2; @@ -345,9 +352,16 @@ message LogicalExprNode { ILikeNode ilike = 32; SimilarToNode similar_to = 33; + PlaceholderNode placeholder = 34; + } } +message PlaceholderNode { + string id = 1; + ArrowType data_type = 2; +} + message LogicalExprList { repeated LogicalExprNode expr = 1; } diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index 98ffbd240a3b..5d022695862d 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -19,7 +19,7 @@ use crate::protobuf::plan_type::PlanTypeEnum::{ FinalLogicalPlan, FinalPhysicalPlan, InitialLogicalPlan, InitialPhysicalPlan, OptimizedLogicalPlan, OptimizedPhysicalPlan, }; -use crate::protobuf::{self}; +use crate::protobuf::{self, PlaceholderNode}; use crate::protobuf::{ CubeNode, GroupingSetNode, OptimizedLogicalPlanType, OptimizedPhysicalPlanType, RollupNode, @@ -1223,6 +1223,16 @@ pub fn parse_expr( .collect::, Error>>()?, ))) } + ExprType::Placeholder(PlaceholderNode { id, data_type }) => match data_type { + None => { + let message = format!("Protobuf deserialization error: data type must be provided for the placeholder {}", id); + Err(proto_error(message)) + } + Some(data_type) => Ok(Expr::Placeholder { + id: id.clone(), + data_type: data_type.try_into()?, + }), + }, } } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 97b796257c0a..13236a935839 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -11040,6 +11040,9 @@ impl serde::Serialize for LogicalExprNode { logical_expr_node::ExprType::SimilarTo(v) => { struct_ser.serialize_field("similarTo", v)?; } + logical_expr_node::ExprType::Placeholder(v) => { + struct_ser.serialize_field("placeholder", v)?; + } } } struct_ser.end() @@ -11106,6 +11109,7 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { "ilike", "similar_to", "similarTo", + "placeholder", ]; #[allow(clippy::enum_variant_names)] @@ -11143,6 +11147,7 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { Like, Ilike, SimilarTo, + Placeholder, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -11197,6 +11202,7 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { "like" => Ok(GeneratedField::Like), "ilike" => Ok(GeneratedField::Ilike), "similarTo" | "similar_to" => Ok(GeneratedField::SimilarTo), + "placeholder" => Ok(GeneratedField::Placeholder), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -11447,6 +11453,13 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { return Err(serde::de::Error::duplicate_field("similarTo")); } expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::SimilarTo) +; + } + GeneratedField::Placeholder => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("placeholder")); + } + expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Placeholder) ; } } @@ -11655,6 +11668,9 @@ impl serde::Serialize for LogicalPlanNode { logical_plan_node::LogicalPlanType::CustomScan(v) => { struct_ser.serialize_field("customScan", v)?; } + logical_plan_node::LogicalPlanType::Prepare(v) => { + struct_ser.serialize_field("prepare", v)?; + } } } struct_ser.end() @@ -11701,6 +11717,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "viewScan", "custom_scan", "customScan", + "prepare", ]; #[allow(clippy::enum_variant_names)] @@ -11729,6 +11746,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { Distinct, ViewScan, CustomScan, + Prepare, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -11774,6 +11792,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "distinct" => Ok(GeneratedField::Distinct), "viewScan" | "view_scan" => Ok(GeneratedField::ViewScan), "customScan" | "custom_scan" => Ok(GeneratedField::CustomScan), + "prepare" => Ok(GeneratedField::Prepare), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -11962,6 +11981,13 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { return Err(serde::de::Error::duplicate_field("customScan")); } logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CustomScan) +; + } + GeneratedField::Prepare => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("prepare")); + } + logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Prepare) ; } } @@ -16419,6 +16445,115 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { deserializer.deserialize_struct("datafusion.PhysicalWindowExprNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for PlaceholderNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.id.is_empty() { + len += 1; + } + if self.data_type.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PlaceholderNode", len)?; + if !self.id.is_empty() { + struct_ser.serialize_field("id", &self.id)?; + } + if let Some(v) = self.data_type.as_ref() { + struct_ser.serialize_field("dataType", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PlaceholderNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "id", + "data_type", + "dataType", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Id, + DataType, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "id" => Ok(GeneratedField::Id), + "dataType" | "data_type" => Ok(GeneratedField::DataType), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PlaceholderNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PlaceholderNode") + } + + fn visit_map(self, mut map: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut id__ = None; + let mut data_type__ = None; + while let Some(k) = map.next_key()? { + match k { + GeneratedField::Id => { + if id__.is_some() { + return Err(serde::de::Error::duplicate_field("id")); + } + id__ = Some(map.next_value()?); + } + GeneratedField::DataType => { + if data_type__.is_some() { + return Err(serde::de::Error::duplicate_field("dataType")); + } + data_type__ = map.next_value()?; + } + } + } + Ok(PlaceholderNode { + id: id__.unwrap_or_default(), + data_type: data_type__, + }) + } + } + deserializer.deserialize_struct("datafusion.PlaceholderNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for PlanType { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -16580,6 +16715,132 @@ impl<'de> serde::Deserialize<'de> for PlanType { deserializer.deserialize_struct("datafusion.PlanType", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for PrepareNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.name.is_empty() { + len += 1; + } + if !self.data_types.is_empty() { + len += 1; + } + if self.input.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PrepareNode", len)?; + if !self.name.is_empty() { + struct_ser.serialize_field("name", &self.name)?; + } + if !self.data_types.is_empty() { + struct_ser.serialize_field("dataTypes", &self.data_types)?; + } + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PrepareNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "name", + "data_types", + "dataTypes", + "input", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Name, + DataTypes, + Input, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "name" => Ok(GeneratedField::Name), + "dataTypes" | "data_types" => Ok(GeneratedField::DataTypes), + "input" => Ok(GeneratedField::Input), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PrepareNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PrepareNode") + } + + fn visit_map(self, mut map: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut name__ = None; + let mut data_types__ = None; + let mut input__ = None; + while let Some(k) = map.next_key()? { + match k { + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); + } + name__ = Some(map.next_value()?); + } + GeneratedField::DataTypes => { + if data_types__.is_some() { + return Err(serde::de::Error::duplicate_field("dataTypes")); + } + data_types__ = Some(map.next_value()?); + } + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map.next_value()?; + } + } + } + Ok(PrepareNode { + name: name__.unwrap_or_default(), + data_types: data_types__.unwrap_or_default(), + input: input__, + }) + } + } + deserializer.deserialize_struct("datafusion.PrepareNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ProjectionColumns { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 6bfb6b96c32a..1405e1eba638 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -33,7 +33,7 @@ pub struct DfSchema { pub struct LogicalPlanNode { #[prost( oneof = "logical_plan_node::LogicalPlanType", - tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25" + tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26" )] pub logical_plan_type: ::core::option::Option, } @@ -89,6 +89,8 @@ pub mod logical_plan_node { ViewScan(::prost::alloc::boxed::Box), #[prost(message, tag = "25")] CustomScan(super::CustomTableScanNode), + #[prost(message, tag = "26")] + Prepare(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -275,6 +277,15 @@ pub struct CreateExternalTableNode { >, } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct PrepareNode { + #[prost(string, tag = "1")] + pub name: ::prost::alloc::string::String, + #[prost(message, repeated, tag = "2")] + pub data_types: ::prost::alloc::vec::Vec, + #[prost(message, optional, boxed, tag = "3")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct CreateCatalogSchemaNode { #[prost(string, tag = "1")] pub schema_name: ::prost::alloc::string::String, @@ -406,7 +417,7 @@ pub struct SubqueryAliasNode { pub struct LogicalExprNode { #[prost( oneof = "logical_expr_node::ExprType", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34" )] pub expr_type: ::core::option::Option, } @@ -488,9 +499,18 @@ pub mod logical_expr_node { Ilike(::prost::alloc::boxed::Box), #[prost(message, tag = "33")] SimilarTo(::prost::alloc::boxed::Box), + #[prost(message, tag = "34")] + Placeholder(super::PlaceholderNode), } } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct PlaceholderNode { + #[prost(string, tag = "1")] + pub id: ::prost::alloc::string::String, + #[prost(message, optional, tag = "2")] + pub data_type: ::core::option::Option, +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct LogicalExprList { #[prost(message, repeated, tag = "1")] pub expr: ::prost::alloc::vec::Vec, diff --git a/datafusion/proto/src/logical_plan.rs b/datafusion/proto/src/logical_plan.rs index f0b6d109507a..d9423f8de69e 100644 --- a/datafusion/proto/src/logical_plan.rs +++ b/datafusion/proto/src/logical_plan.rs @@ -25,7 +25,7 @@ use crate::{ }, to_proto, }; -use arrow::datatypes::{Schema, SchemaRef}; +use arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::datasource::TableProvider; use datafusion::{ datasource::{ @@ -39,7 +39,7 @@ use datafusion::{ prelude::SessionContext, }; use datafusion_common::{context, Column, DataFusionError, OwnedTableReference}; -use datafusion_expr::logical_plan::builder::project; +use datafusion_expr::logical_plan::{builder::project, Prepare}; use datafusion_expr::{ logical_plan::{ Aggregate, CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateView, @@ -816,6 +816,18 @@ impl AsLogicalPlan for LogicalPlanNode { )? .build() } + LogicalPlanType::Prepare(prepare) => { + let input: LogicalPlan = + into_logical_plan!(prepare.input, ctx, extension_codec)?; + let data_types: Vec = prepare + .data_types + .iter() + .map(DataType::try_from) + .collect::>()?; + LogicalPlanBuilder::from(input) + .prepare(prepare.name.clone(), data_types)? + .build() + } } } @@ -1377,6 +1389,28 @@ impl AsLogicalPlan for LogicalPlanNode { )), }) } + LogicalPlan::Prepare(Prepare { + name, + data_types, + input, + }) => { + let input = protobuf::LogicalPlanNode::try_from_logical_plan( + input, + extension_codec, + )?; + Ok(protobuf::LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::Prepare(Box::new( + protobuf::PrepareNode { + name: name.clone(), + data_types: data_types + .iter() + .map(|t| t.try_into()) + .collect::, _>>()?, + input: Some(Box::new(input)), + }, + ))), + }) + } LogicalPlan::CreateMemoryTable(_) => Err(proto_error( "LogicalPlan serde is not yet implemented for CreateMemoryTable", )), diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index 133e2f89d54e..4c280f7b0370 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -27,7 +27,7 @@ use crate::protobuf::{ OptimizedLogicalPlan, OptimizedPhysicalPlan, }, CubeNode, EmptyMessage, GroupingSetNode, LogicalExprList, OptimizedLogicalPlanType, - OptimizedPhysicalPlanType, RollupNode, + OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, }; use arrow::datatypes::{ DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, @@ -888,6 +888,9 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { .collect::, Self::Error>>()?, })), }, + Expr::Placeholder{ id, data_type } => Self { + expr_type: Some(ExprType::Placeholder(PlaceholderNode { id: id.clone(), data_type: Some(data_type.try_into()?) })), + }, Expr::QualifiedWildcard { .. } | Expr::TryCast { .. } => return Err(Error::General("Proto serialization error: Expr::QualifiedWildcard { .. } | Expr::TryCast { .. } not supported".to_string())), diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index decec707546c..5139bd2b7a6e 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -40,4 +40,5 @@ unicode_expressions = [] arrow-schema = "28.0.0" datafusion-common = { path = "../common", version = "15.0.0" } datafusion-expr = { path = "../expr", version = "15.0.0" } +log = "^0.4" sqlparser = "0.27" diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 8f0129c5cd5a..82d8c3834294 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -16,7 +16,7 @@ // under the License. //! SQL Query Planner (produces logical plan from SQL AST) - +use log::debug; use std::collections::{HashMap, HashSet}; use std::str::FromStr; use std::sync::Arc; @@ -46,7 +46,6 @@ use datafusion_expr::expr::{Between, BinaryExpr, Case, Cast, GroupingSet, Like}; use datafusion_expr::expr_rewriter::normalize_col; use datafusion_expr::expr_rewriter::normalize_col_with_schemas; use datafusion_expr::logical_plan::builder::project; -use datafusion_expr::logical_plan::Join as HashJoin; use datafusion_expr::logical_plan::JoinConstraint as HashJoinConstraint; use datafusion_expr::logical_plan::{ Analyze, CreateCatalog, CreateCatalogSchema, @@ -55,6 +54,7 @@ use datafusion_expr::logical_plan::{ Partitioning, PlanType, SetVariable, ToStringifiedPlan, }; use datafusion_expr::logical_plan::{Filter, Subquery}; +use datafusion_expr::logical_plan::{Join as HashJoin, Prepare}; use datafusion_expr::utils::{ can_hash, check_all_column_from_schema, expand_qualified_wildcard, expand_wildcard, expr_as_column_expr, expr_to_columns, find_aggregate_exprs, find_column_exprs, @@ -120,13 +120,23 @@ impl Default for PlannerContext { } impl PlannerContext { - /// Create a new PlannerContext + /// Create an empty PlannerContext pub fn new() -> Self { Self { prepare_param_data_types: vec![], ctes: HashMap::new(), } } + + /// Create a new PlannerContext with provided prepare_param_data_types + pub fn new_with_prepare_param_data_types( + prepare_param_data_types: Vec, + ) -> Self { + Self { + prepare_param_data_types, + ctes: HashMap::new(), + } + } } /// SQL query planner @@ -197,6 +207,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Generate a logical plan from an SQL statement pub fn sql_statement_to_plan(&self, statement: Statement) -> Result { + self.sql_statement_to_plan_with_context(statement, &mut PlannerContext::new()) + } + + /// Generate a logical plan from an SQL statement + pub fn sql_statement_to_plan_with_context( + &self, + statement: Statement, + planner_context: &mut PlannerContext, + ) -> Result { let sql = Some(statement.to_string()); match statement { Statement::Explain { @@ -207,9 +226,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { describe_alias: _, .. } => self.explain_statement_to_plan(verbose, analyze, *statement), - Statement::Query(query) => { - self.query_to_plan(*query, &mut PlannerContext::new()) - } + Statement::Query(query) => self.query_to_plan(*query, planner_context), Statement::ShowVariable { variable } => self.show_variable_to_plan(&variable), Statement::SetVariable { local, @@ -232,7 +249,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { && table_properties.is_empty() && with_options.is_empty() => { - let plan = self.query_to_plan(*query, &mut PlannerContext::new())?; + let plan = self.query_to_plan(*query, planner_context)?; let input_schema = plan.schema(); let plan = if !columns.is_empty() { @@ -323,7 +340,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } => { // We don't support cascade and purge for now. // nor do we support multiple object names - let name = match names.len() { 0 => Err(ParserError("Missing table name.".to_string()).into()), 1 => object_name_to_table_reference(names.pop().unwrap()), @@ -350,6 +366,32 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { )), } } + Statement::Prepare { + name, + data_types, + statement, + } => { + // Convert parser data types to DataFusion data types + let data_types: Vec = data_types + .into_iter() + .map(|t| self.convert_data_type(&t)) + .collect::>()?; + + // Create planner context with parameters + let mut planner_context = + PlannerContext::new_with_prepare_param_data_types(data_types.clone()); + + // Build logical plan for inner statement of the prepare statement + let plan = self.sql_statement_to_plan_with_context( + *statement, + &mut planner_context, + )?; + Ok(LogicalPlan::Prepare(Prepare { + name: name.to_string(), + data_types, + input: Arc::new(plan), + })) + } Statement::ShowTables { extended, @@ -483,7 +525,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SetExpr::Select(s) => { self.select_to_plan(*s, planner_context, alias, outer_query_schema) } - SetExpr::Values(v) => self.sql_values_to_plan(v), + SetExpr::Values(v) => { + self.sql_values_to_plan(v, &planner_context.prepare_param_data_types) + } SetExpr::SetOperation { op, left, @@ -1088,6 +1132,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // process `from` clause let plan = self.plan_from_tables(select.from, planner_context, outer_query_schema)?; + let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_)); // build from schema for unqualifier column ambiguous check // we should get only one field for unqualifier column from schema. @@ -1786,7 +1831,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - fn sql_values_to_plan(&self, values: SQLValues) -> Result { + fn sql_values_to_plan( + &self, + values: SQLValues, + param_data_types: &[DataType], + ) -> Result { // values should not be based on any other schema let schema = DFSchema::empty(); let values = values @@ -1803,6 +1852,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(Expr::Literal(ScalarValue::Null)) } SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)), + SQLExpr::Value(Value::Placeholder(param)) => { + Self::create_placeholder_expr(param, param_data_types) + } SQLExpr::UnaryOp { op, expr } => self.parse_sql_unary_op( op, *expr, @@ -1842,6 +1894,44 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { LogicalPlanBuilder::values(values)?.build() } + /// Create a placeholder expression + /// This is the same as Postgres's prepare statement syntax in which a placeholder starts with `$` sign and then + /// number 1, 2, ... etc. For example, `$1` is the first placeholder; $2 is the second one and so on. + fn create_placeholder_expr( + param: String, + param_data_types: &[DataType], + ) -> Result { + // Parse the placeholder as a number because it is the only support from sqlparser and postgres + let index = param[1..].parse::(); + let idx = match index { + Ok(index) => index - 1, + Err(_) => { + return Err(DataFusionError::Internal(format!( + "Invalid placeholder, not a number: {}", + param + ))) + } + }; + // Check if the placeholder is in the parameter list + if param_data_types.len() <= idx { + return Err(DataFusionError::Internal(format!( + "Placehoder {} does not exist in the parameter list: {:?}", + param, param_data_types + ))); + } + // Data type of the parameter + let param_type = param_data_types[idx].clone(); + debug!( + "type of param {} param_data_types[idx]: {:?}", + param, param_type + ); + + Ok(Expr::Placeholder { + id: param, + data_type: param_type, + }) + } + fn sql_expr_to_logical_expr( &self, sql: SQLExpr, @@ -1853,6 +1943,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Value(Value::SingleQuotedString(ref s) | Value::DoubleQuotedString(ref s)) => Ok(lit(s.clone())), SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)), SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Null)), + SQLExpr::Value(Value::Placeholder(param)) => Self::create_placeholder_expr(param, &planner_context.prepare_param_data_types), SQLExpr::Extract { field, expr } => Ok(Expr::ScalarFunction { fun: BuiltinScalarFunction::DatePart, args: vec![ @@ -5326,6 +5417,21 @@ mod tests { assert_eq!(format!("{:?}", plan), expected); } + fn prepare_stmt_quick_test( + sql: &str, + expected_plan: &str, + expected_data_types: &str, + ) { + let plan = logical_plan(sql).unwrap(); + // verify plan + assert_eq!(format!("{:?}", plan), expected_plan); + // verify data types + if let LogicalPlan::Prepare(Prepare { data_types, .. }) = plan { + let dt = format!("{:?}", data_types); + assert_eq!(dt, expected_data_types); + } + } + struct MockContextProvider {} impl ContextProvider for MockContextProvider { @@ -6125,6 +6231,199 @@ mod tests { quick_test(sql, expected); } + #[test] + #[should_panic( + expected = "value: Internal(\"Invalid placeholder, not a number: $foo\"" + )] + fn test_prepare_statement_to_plan_panic_param_format() { + // param is not number following the $ sign + // panic due to error returned from the parser + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $foo"; + + let expected_plan = "whatever"; + let expected_dt = "whatever"; + + prepare_stmt_quick_test(sql, expected_plan, expected_dt); + } + + #[test] + #[should_panic(expected = "value: SQL(ParserError(\"Expected AS, found: SELECT\"))")] + fn test_prepare_statement_to_plan_panic_prepare_wrong_syntax() { + // param is not number following the $ sign + // panic due to error returned from the parser + let sql = "PREPARE AS SELECT id, age FROM person WHERE age = $foo"; + + let expected_plan = "whatever"; + let expected_dt = "whatever"; + + prepare_stmt_quick_test(sql, expected_plan, expected_dt); + } + + #[test] + #[should_panic( + expected = "value: SchemaError(FieldNotFound { field: Column { relation: None, name: \"id\" }, valid_fields: Some([]) })" + )] + fn test_prepare_statement_to_plan_panic_no_relation_and_constant_param() { + let sql = "PREPARE my_plan(INT) AS SELECT id + $1"; + + let expected_plan = "whatever"; + let expected_dt = "whatever"; + + prepare_stmt_quick_test(sql, expected_plan, expected_dt); + } + + #[test] + #[should_panic( + expected = "value: Internal(\"Placehoder $2 does not exist in the parameter list: [Int32]\")" + )] + fn test_prepare_statement_to_plan_panic_no_data_types() { + // only provide 1 data type while using 2 params + let sql = "PREPARE my_plan(INT) AS SELECT 1 + $1 + $2"; + + let expected_plan = "whatever"; + let expected_dt = "whatever"; + + prepare_stmt_quick_test(sql, expected_plan, expected_dt); + } + + #[test] + #[should_panic( + expected = "value: SQL(ParserError(\"Expected [NOT] NULL or TRUE|FALSE or [NOT] DISTINCT FROM after IS, found: $1\"" + )] + fn test_prepare_statement_to_plan_panic_is_param() { + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age is $1"; + + let expected_plan = "whatever"; + let expected_dt = "whatever"; + + prepare_stmt_quick_test(sql, expected_plan, expected_dt); + } + + #[test] + fn test_prepare_statement_to_plan_no_param() { + // no embedded parameter but still declare it + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; + + let expected_plan = "Prepare: \"my_plan\" [Int32] \ + \n Projection: person.id, person.age\ + \n Filter: person.age = Int64(10)\ + \n TableScan: person"; + + let expected_dt = "[Int32]"; + + prepare_stmt_quick_test(sql, expected_plan, expected_dt); + + ///////////////////////// + // no embedded parameter and no declare it + let sql = "PREPARE my_plan AS SELECT id, age FROM person WHERE age = 10"; + + let expected_plan = "Prepare: \"my_plan\" [] \ + \n Projection: person.id, person.age\ + \n Filter: person.age = Int64(10)\ + \n TableScan: person"; + + let expected_dt = "[]"; + + prepare_stmt_quick_test(sql, expected_plan, expected_dt); + } + + #[test] + fn test_prepare_statement_to_plan_params_as_constants() { + let sql = "PREPARE my_plan(INT) AS SELECT $1"; + + let expected_plan = "Prepare: \"my_plan\" [Int32] \ + \n Projection: $1\n EmptyRelation"; + let expected_dt = "[Int32]"; + + prepare_stmt_quick_test(sql, expected_plan, expected_dt); + + ///////////////////////// + let sql = "PREPARE my_plan(INT) AS SELECT 1 + $1"; + + let expected_plan = "Prepare: \"my_plan\" [Int32] \ + \n Projection: Int64(1) + $1\n EmptyRelation"; + let expected_dt = "[Int32]"; + + prepare_stmt_quick_test(sql, expected_plan, expected_dt); + + ///////////////////////// + let sql = "PREPARE my_plan(INT, DOUBLE) AS SELECT 1 + $1 + $2"; + + let expected_plan = "Prepare: \"my_plan\" [Int32, Float64] \ + \n Projection: Int64(1) + $1 + $2\n EmptyRelation"; + let expected_dt = "[Int32, Float64]"; + + prepare_stmt_quick_test(sql, expected_plan, expected_dt); + } + + #[test] + fn test_prepare_statement_to_plan_one_param() { + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $1"; + + let expected_plan = "Prepare: \"my_plan\" [Int32] \ + \n Projection: person.id, person.age\ + \n Filter: person.age = $1\ + \n TableScan: person"; + + let expected_dt = "[Int32]"; + + prepare_stmt_quick_test(sql, expected_plan, expected_dt); + } + + #[test] + fn test_prepare_statement_to_plan_multi_params() { + let sql = "PREPARE my_plan(INT, STRING, DOUBLE, INT, DOUBLE, STRING) AS + SELECT id, age, $6 + FROM person + WHERE age IN ($1, $4) AND salary > $3 and salary < $5 OR first_name < $2"; + + let expected_plan = "Prepare: \"my_plan\" [Int32, Utf8, Float64, Int32, Float64, Utf8] \ + \n Projection: person.id, person.age, $6\ + \n Filter: person.age IN ([$1, $4]) AND person.salary > $3 AND person.salary < $5 OR person.first_name < $2\ + \n TableScan: person"; + + let expected_dt = "[Int32, Utf8, Float64, Int32, Float64, Utf8]"; + + prepare_stmt_quick_test(sql, expected_plan, expected_dt); + } + + #[test] + fn test_prepare_statement_to_plan_having() { + let sql = "PREPARE my_plan(INT, DOUBLE, DOUBLE, DOUBLE) AS + SELECT id, SUM(age) + FROM person \ + WHERE salary > $2 + GROUP BY id + HAVING sum(age) < $1 AND SUM(age) > 10 OR SUM(age) in ($3, $4)\ + "; + + let expected_plan = "Prepare: \"my_plan\" [Int32, Float64, Float64, Float64] \ + \n Projection: person.id, SUM(person.age)\ + \n Filter: SUM(person.age) < $1 AND SUM(person.age) > Int64(10) OR SUM(person.age) IN ([$3, $4])\ + \n Aggregate: groupBy=[[person.id]], aggr=[[SUM(person.age)]]\ + \n Filter: person.salary > $2\ + \n TableScan: person"; + + let expected_dt = "[Int32, Float64, Float64, Float64]"; + + prepare_stmt_quick_test(sql, expected_plan, expected_dt); + } + + #[test] + fn test_prepare_statement_to_plan_value_list() { + let sql = "PREPARE my_plan(STRING, STRING) AS SELECT * FROM (VALUES(1, $1), (2, $2)) AS t (num, letter);"; + + let expected_plan = "Prepare: \"my_plan\" [Utf8, Utf8] \ + \n Projection: num, letter\ + \n Projection: t.column1 AS num, t.column2 AS letter\ + \n SubqueryAlias: t\ + \n Values: (Int64(1), $1), (Int64(2), $2)"; + + let expected_dt = "[Utf8, Utf8]"; + + prepare_stmt_quick_test(sql, expected_plan, expected_dt); + } + #[test] fn test_table_alias() { let sql = "select * from (\ diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 6a0dd7c3f581..4b9ae3ae9e0c 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -411,6 +411,10 @@ where ))) } }, + Expr::Placeholder { id, data_type } => Ok(Expr::Placeholder { + id: id.clone(), + data_type: data_type.clone(), + }), }, } }