From b75c3c77a5fc03450b23275a34c9f68d81c8868a Mon Sep 17 00:00:00 2001 From: jonahgao Date: Fri, 1 Nov 2024 21:16:14 +0800 Subject: [PATCH 1/7] feat: basic support for executing prepared statements --- datafusion/core/src/execution/context/mod.rs | 28 +++++++++++++++--- .../core/src/execution/session_state.rs | 29 ++++++++++++++++++- 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 333f83c673cc..47d0ff5000ca 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -42,7 +42,8 @@ use crate::{ logical_expr::{ CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction, CreateMemoryTable, CreateView, DropCatalogSchema, DropFunction, DropTable, - DropView, LogicalPlan, LogicalPlanBuilder, SetVariable, TableType, UNNAMED_TABLE, + DropView, Execute, LogicalPlan, LogicalPlanBuilder, Prepare, SetVariable, + TableType, UNNAMED_TABLE, }, physical_expr::PhysicalExpr, physical_plan::ExecutionPlan, @@ -54,9 +55,9 @@ use arrow::record_batch::RecordBatch; use arrow_schema::Schema; use datafusion_common::{ config::{ConfigExtension, TableOptions}, - exec_err, not_impl_err, plan_datafusion_err, plan_err, + exec_datafusion_err, exec_err, not_impl_err, plan_datafusion_err, plan_err, tree_node::{TreeNodeRecursion, TreeNodeVisitor}, - DFSchema, SchemaReference, TableReference, + DFSchema, ScalarValue, SchemaReference, TableReference, }; use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ @@ -687,7 +688,26 @@ impl SessionContext { LogicalPlan::Statement(Statement::SetVariable(stmt)) => { self.set_variable(stmt).await } - + LogicalPlan::Prepare(Prepare { name, input, .. }) => { + self.state.write().store_prepared(name, input)?; + self.return_empty_dataframe() + } + LogicalPlan::Execute(Execute { + name, parameters, .. 
+ }) => { + let plan = self.state.read().get_prepared(&name).ok_or_else(|| { + exec_datafusion_err!("Prepared statement '{}' not exists", name) + })?; + let values: Vec<ScalarValue> = parameters + .into_iter() + .map(|e| match e { + Expr::Literal(scalar) => Ok(scalar), + _ => exec_err!("Invalid parameter type"), + }) + .collect::<Result<_>>()?; + let plan = plan.as_ref().clone().with_param_values(values)?; + Ok(DataFrame::new(self.state(), plan)) + } plan => Ok(DataFrame::new(self.state(), plan)), } } diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index d50c912dd2fd..46410e87812d 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -40,7 +40,7 @@ use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; use datafusion_common::file_options::file_type::FileType; use datafusion_common::tree_node::TreeNode; use datafusion_common::{ - config_err, not_impl_err, plan_datafusion_err, DFSchema, DataFusionError, + config_err, exec_err, not_impl_err, plan_datafusion_err, DFSchema, DataFusionError, ResolvedTableReference, TableReference, }; use datafusion_execution::config::SessionConfig; @@ -171,6 +171,9 @@ pub struct SessionState { /// It will be invoked on `CREATE FUNCTION` statements. /// thus, changing dialect o PostgreSql is required function_factory: Option<Arc<dyn FunctionFactory>>, + /// Cache logical plans of prepared statements for later execution. + /// Key is the prepared statement name. 
+ prepared_plans: HashMap<String, Arc<LogicalPlan>>, } impl Debug for SessionState { @@ -197,6 +200,7 @@ impl Debug for SessionState { .field("scalar_functions", &self.scalar_functions) .field("aggregate_functions", &self.aggregate_functions) .field("window_functions", &self.window_functions) + .field("prepared_plans", &self.prepared_plans) .finish() } } @@ -906,6 +910,28 @@ impl SessionState { let udtf = self.table_functions.remove(name); Ok(udtf.map(|x| x.function().clone())) } + + /// Store the logical plan of a prepared statement. + pub(crate) fn store_prepared( + &mut self, + name: String, + plan: Arc<LogicalPlan>, + ) -> datafusion_common::Result<()> { + match self.prepared_plans.entry(name) { + Entry::Vacant(e) => { + e.insert(plan); + Ok(()) + } + Entry::Occupied(e) => { + exec_err!("Prepared statement with name '{}' already exists", e.key()) + } + } + } + + /// Get the logical plan for the prepared statement named `name`. + pub(crate) fn get_prepared(&self, name: &str) -> Option<Arc<LogicalPlan>> { + self.prepared_plans.get(name).map(Arc::clone) + } } /// A builder to be used for building [`SessionState`]'s. 
Defaults will @@ -1327,6 +1353,7 @@ impl SessionStateBuilder { table_factories: table_factories.unwrap_or_default(), runtime_env, function_factory, + prepared_plans: HashMap::new(), }; if let Some(file_formats) = file_formats { From 317dd84c4fb8a62161cacd24528c34d53d7436e9 Mon Sep 17 00:00:00 2001 From: jonahgao Date: Sun, 3 Nov 2024 19:31:24 +0800 Subject: [PATCH 2/7] Improve execute_prepared --- datafusion/core/src/execution/context/mod.rs | 70 ++++++++++++++----- .../core/src/execution/session_state.rs | 19 +++-- datafusion/core/tests/sql/select.rs | 2 +- 3 files changed, 66 insertions(+), 25 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 47d0ff5000ca..3c16ee0483de 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -57,7 +57,7 @@ use datafusion_common::{ config::{ConfigExtension, TableOptions}, exec_datafusion_err, exec_err, not_impl_err, plan_datafusion_err, plan_err, tree_node::{TreeNodeRecursion, TreeNodeVisitor}, - DFSchema, ScalarValue, SchemaReference, TableReference, + DFSchema, ParamValues, ScalarValue, SchemaReference, TableReference, }; use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ @@ -688,26 +688,15 @@ impl SessionContext { LogicalPlan::Statement(Statement::SetVariable(stmt)) => { self.set_variable(stmt).await } - LogicalPlan::Prepare(Prepare { name, input, .. }) => { - self.state.write().store_prepared(name, input)?; - self.return_empty_dataframe() - } - LogicalPlan::Execute(Execute { - name, parameters, .. 
+ LogicalPlan::Prepare(Prepare { + name, + input, + data_types, }) => { - let plan = self.state.read().get_prepared(&name).ok_or_else(|| { - exec_datafusion_err!("Prepared statement '{}' not exists", name) - })?; - let values: Vec<ScalarValue> = parameters - .into_iter() - .map(|e| match e { - Expr::Literal(scalar) => Ok(scalar), - _ => exec_err!("Invalid parameter type"), - }) - .collect::<Result<_>>()?; - let plan = plan.as_ref().clone().with_param_values(values)?; - Ok(DataFrame::new(self.state(), plan)) + self.state.write().store_prepared(name, data_types, input)?; + self.return_empty_dataframe() } + LogicalPlan::Execute(execute) => self.execute_prepared(execute), plan => Ok(DataFrame::new(self.state(), plan)), } } @@ -1108,6 +1097,49 @@ impl SessionContext { } } + fn execute_prepared(&self, execute: Execute) -> Result<DataFrame> { + let Execute { + name, parameters, .. + } = execute; + let prepared = self.state.read().get_prepared(&name).ok_or_else(|| { + exec_datafusion_err!("Prepared statement '{}' not exists", name) + })?; + + // Only allow literals as parameters for now. + let mut params: Vec<ScalarValue> = parameters + .into_iter() + .map(|e| match e { + Expr::Literal(scalar) => Ok(scalar), + _ => not_impl_err!("Unsupported data type for parameter: {}", e), + }) + .collect::<Result<_>>()?; + + // If the prepared statement provides data types, cast the params to those types. + if !prepared.data_types.is_empty() { + if params.len() != prepared.data_types.len() { + return exec_err!( + "Prepared statement '{}' expects {} parameters, but {} provided", + name, + prepared.data_types.len(), + params.len() + ); + } + params = params + .into_iter() + .zip(prepared.data_types.iter()) + .map(|(e, dt)| e.cast_to(dt)) + .collect::<Result<_>>()?; + } + + let params = ParamValues::List(params); + let plan = prepared + .plan + .as_ref() + .clone() + .replace_params_with_values(&params)?; + Ok(DataFrame::new(self.state(), plan)) + } + /// Registers a variable provider within this context. 
pub fn register_variable( &self, diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 46410e87812d..3ea7622f81a3 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -173,7 +173,7 @@ pub struct SessionState { function_factory: Option<Arc<dyn FunctionFactory>>, /// Cache logical plans of prepared statements for later execution. /// Key is the prepared statement name. - prepared_plans: HashMap<String, Arc<LogicalPlan>>, + prepared_plans: HashMap<String, Arc<PreparedPlan>>, } impl Debug for SessionState { @@ -911,15 +911,16 @@ impl SessionState { Ok(udtf.map(|x| x.function().clone())) } - /// Store the logical plan of a prepared statement. + /// Store the logical plan and parameter types of a prepared statement. pub(crate) fn store_prepared( &mut self, name: String, + data_types: Vec<DataType>, plan: Arc<LogicalPlan>, ) -> datafusion_common::Result<()> { match self.prepared_plans.entry(name) { Entry::Vacant(e) => { - e.insert(plan); + e.insert(Arc::new(PreparedPlan { data_types, plan })); Ok(()) } Entry::Occupied(e) => { @@ -928,8 +929,8 @@ impl SessionState { } } - /// Get the logical plan for the prepared statement named `name`. - pub(crate) fn get_prepared(&self, name: &str) -> Option<Arc<LogicalPlan>> { + /// Get the prepared plan with the given name. 
+ pub(crate) fn get_prepared(&self, name: &str) -> Option<Arc<PreparedPlan>> { self.prepared_plans.get(name).map(Arc::clone) } } @@ -1903,6 +1904,14 @@ impl<'a> SimplifyInfo for SessionSimplifyProvider<'a> { } } +#[derive(Debug)] +pub(crate) struct PreparedPlan { + /// Data types of the parameters + pub(crate) data_types: Vec<DataType>, + /// The prepared logical plan + pub(crate) plan: Arc<LogicalPlan>, +} + #[cfg(test)] mod tests { use super::{SessionContextProvider, SessionStateBuilder}; diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index dd660512f346..3ee7c688f614 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -108,7 +108,7 @@ async fn test_prepare_statement() -> Result<()> { // sql to statement then to prepare logical plan with parameters // c1 defined as UINT32, c2 defined as UInt64 but the params are Int32 and Float64 let dataframe = - ctx.sql("PREPARE my_plan(INT, DOUBLE) AS SELECT c1, c2 FROM test WHERE c1 > $2 AND c1 < $1").await?; + ctx.state().create_logical_plan("PREPARE my_plan(INT, DOUBLE) AS SELECT c1, c2 FROM test WHERE c1 > $2 AND c1 < $1").await?; // prepare logical plan to logical plan without parameters let param_values = vec![ScalarValue::Int32(Some(3)), ScalarValue::Float64(Some(0.0))]; From 01f5b6973f51fb2c5da66667f9d6e484e155700a Mon Sep 17 00:00:00 2001 From: jonahgao Date: Mon, 4 Nov 2024 14:14:31 +0800 Subject: [PATCH 3/7] Fix tests --- datafusion/core/src/execution/context/mod.rs | 11 +- .../core/src/execution/session_state.rs | 2 +- datafusion/core/tests/sql/select.rs | 30 +--- datafusion/expr/src/logical_plan/plan.rs | 16 ++ .../sqllogictest/test_files/prepare.slt | 137 ++++++++++++++---- 5 files changed, 140 insertions(+), 56 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 3c16ee0483de..53830d8a4a6d 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ 
-693,6 +693,15 @@ impl SessionContext { input, data_types, }) => { + // The number of parameters must match the specified data types length. + let param_names = input.get_parameter_names()?; + if param_names.len() != data_types.len() { + return plan_err!( + "Prepare specifies {} data types but query has {} parameters", + data_types.len(), + param_names.len() + ); + } self.state.write().store_prepared(name, data_types, input)?; self.return_empty_dataframe() } @@ -1102,7 +1111,7 @@ impl SessionContext { name, parameters, .. } = execute; let prepared = self.state.read().get_prepared(&name).ok_or_else(|| { - exec_datafusion_err!("Prepared statement '{}' not exists", name) + exec_datafusion_err!("Prepared statement '{}' does not exist", name) })?; // Only allow literals as parameters for now. diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 3ea7622f81a3..9baf5862edce 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -924,7 +924,7 @@ impl SessionState { Ok(()) } Entry::Occupied(e) => { - exec_err!("Prepared statement with name '{}' already exists", e.key()) + exec_err!("Prepared statement '{}' already exists", e.key()) } } } diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index 3ee7c688f614..2e815303e3ce 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -57,7 +57,6 @@ async fn test_named_query_parameters() -> Result<()> { let ctx = create_ctx_with_partition(&tmp_dir, partition_count).await?; // sql to statement then to logical plan with parameters - // c1 defined as UINT32, c2 defined as UInt64 let results = ctx .sql("SELECT c1, c2 FROM test WHERE c1 > $coo AND c1 < $foo") .await? 
@@ -106,9 +105,9 @@ async fn test_prepare_statement() -> Result<()> { let ctx = create_ctx_with_partition(&tmp_dir, partition_count).await?; // sql to statement then to prepare logical plan with parameters - // c1 defined as UINT32, c2 defined as UInt64 but the params are Int32 and Float64 - let dataframe = - ctx.state().create_logical_plan("PREPARE my_plan(INT, DOUBLE) AS SELECT c1, c2 FROM test WHERE c1 > $2 AND c1 < $1").await?; + let dataframe = ctx + .sql("SELECT c1, c2 FROM test WHERE c1 > $2 AND c1 < $1") + .await?; // prepare logical plan to logical plan without parameters let param_values = vec![ScalarValue::Int32(Some(3)), ScalarValue::Float64(Some(0.0))]; @@ -156,7 +155,7 @@ async fn prepared_statement_type_coercion() -> Result<()> { ("unsigned", Arc::new(unsigned_ints) as ArrayRef), ])?; ctx.register_batch("test", batch)?; - let results = ctx.sql("PREPARE my_plan(BIGINT, INT, TEXT) AS SELECT signed, unsigned FROM test WHERE $1 >= signed AND signed <= $2 AND unsigned = $3") + let results = ctx.sql("SELECT signed, unsigned FROM test WHERE $1 >= signed AND signed <= $2 AND unsigned = $3") .await? .with_param_values(vec![ ScalarValue::from(1_i64), @@ -176,27 +175,6 @@ async fn prepared_statement_type_coercion() -> Result<()> { Ok(()) } -#[tokio::test] -async fn prepared_statement_invalid_types() -> Result<()> { - let ctx = SessionContext::new(); - let signed_ints: Int32Array = vec![-1, 0, 1].into(); - let unsigned_ints: UInt64Array = vec![1, 2, 3].into(); - let batch = RecordBatch::try_from_iter(vec![ - ("signed", Arc::new(signed_ints) as ArrayRef), - ("unsigned", Arc::new(unsigned_ints) as ArrayRef), - ])?; - ctx.register_batch("test", batch)?; - let results = ctx - .sql("PREPARE my_plan(INT) AS SELECT signed FROM test WHERE signed = $1") - .await? 
- .with_param_values(vec![ScalarValue::from("1")]); - assert_eq!( - results.unwrap_err().strip_backtrace(), - "Error during planning: Expected parameter of type Int32, got Utf8 at index 0" - ); - Ok(()) -} - #[tokio::test] async fn test_parameter_type_coercion() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 191a42e38e3a..71dc303cdbe8 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1437,6 +1437,22 @@ impl LogicalPlan { .map(|res| res.data) } + /// Walk the logical plan, find any `Placeholder` tokens, and return a set of their names. + pub fn get_parameter_names(&self) -> Result<HashSet<String>> { + let mut param_names = HashSet::new(); + self.apply_with_subqueries(|plan| { + plan.apply_expressions(|expr| { + expr.apply(|expr| { + if let Expr::Placeholder(Placeholder { id, .. }) = expr { + param_names.insert(id.clone()); + } + Ok(TreeNodeRecursion::Continue) + }) + }) + }) + .map(|_| param_names) + } + /// Walk the logical plan, find any `Placeholder` tokens, and return a map of their IDs and DataTypes pub fn get_parameter_types( &self, diff --git a/datafusion/sqllogictest/test_files/prepare.slt b/datafusion/sqllogictest/test_files/prepare.slt index e306ec7767c7..be8dadc3f0a0 100644 --- a/datafusion/sqllogictest/test_files/prepare.slt +++ b/datafusion/sqllogictest/test_files/prepare.slt @@ -30,56 +30,141 @@ select * from person; # Error due to syntax and semantic violation # Syntax error: no name specified after the keyword prepare -statement error +statement error DataFusion error: SQL error: ParserError PREPARE AS SELECT id, age FROM person WHERE age = $foo; # param following a non-number, $foo, not supported -statement error +statement error Invalid placeholder, not a number: \$foo PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $foo; # not specify table hence cannot specify columns -statement error 
+statement error Schema error: No field named id PREPARE my_plan(INT) AS SELECT id + $1; # not specify data types for all params -statement error +statement error Prepare specifies 1 data types but query has 2 parameters PREPARE my_plan(INT) AS SELECT 1 + $1 + $2; +# sepecify too many data types for params +statement error Prepare specifies 2 data types but query has 1 parameters +PREPARE my_plan(INT, INT) AS SELECT 1 + $1; + # cannot use IS param -statement error +statement error SQL error: ParserError PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age is $1; +# TODO: allow prepare without specifying data types +statement error Placeholder type could not be resolved +PREPARE my_plan AS SELECT $1; + # ####################### -# TODO: all the errors below should work ok after we store the prepare logical plan somewhere -statement error +# Test prepare and execute statements + +# execute a non-existing plan +statement error Prepared statement \'my_plan\' does not exist +EXECUTE my_plan('Foo', 'Bar'); + +statement ok PREPARE my_plan(STRING, STRING) AS SELECT * FROM (VALUES(1, $1), (2, $2)) AS t (num, letter); -statement error +query IT +EXECUTE my_plan('Foo', 'Bar'); +---- +1 Foo +2 Bar + +# duplicate prepare statement +statement error Prepared statement \'my_plan\' already exists +PREPARE my_plan(STRING, STRING) AS SELECT * FROM (VALUES(1, $1), (2, $2)) AS t (num, letter); + +statement error Prepare specifies 1 data types but query has 0 parameters PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10; -statement error -PREPARE my_plan AS SELECT id, age FROM person WHERE age = 10; +# prepare statement has no params +statement ok +PREPARE my_plan2 AS SELECT id, age FROM person WHERE age = 20; + +query II +EXECUTE my_plan2; +---- +1 20 -statement error -PREPARE my_plan(INT) AS SELECT $1; +statement ok +PREPARE my_plan3(INT) AS SELECT $1; -statement error -PREPARE my_plan(INT) AS SELECT 1 + $1; +query I +EXECUTE my_plan3(10); +---- +10 -statement 
error -PREPARE my_plan(INT, DOUBLE) AS SELECT 1 + $1 + $2; +statement ok +PREPARE my_plan4(INT) AS SELECT 1 + $1; -statement error -PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $1; +query I +EXECUTE my_plan4(10); +---- +11 -statement error -PREPARE my_plan(INT, STRING, DOUBLE, INT, DOUBLE, STRING) AS SELECT id, age, $6 FROM person WHERE age IN ($1, $4) AND salary > $3 and salary < $5 OR first_name < $2"; +statement ok +PREPARE my_plan5(INT, DOUBLE) AS SELECT 1 + $1 + $2; -statement error -PREPARE my_plan(INT, DOUBLE, DOUBLE, DOUBLE) AS SELECT id, SUM(age) FROM person WHERE salary > $2 GROUP BY id HAVING sum(age) < $1 AND SUM(age) > 10 OR SUM(age) in ($3, $4); +query R +EXECUTE my_plan5(10, 20.5); +---- +31.5 -statement error -PREPARE my_plan(STRING, STRING) AS SELECT * FROM (VALUES(1, $1), (2, $2)) AS t (num, letter); +statement ok +PREPARE my_plan6(INT) AS SELECT id, age FROM person WHERE age = $1; + +query II +EXECUTE my_plan6(20); +---- +1 20 + +# EXECUTE param is a different type but compatible +query II +EXECUTE my_plan6('20'); +---- +1 20 + +query II +EXECUTE my_plan6(20.0); +---- +1 20 + +# invalid execute param +statement error Cast error: Cannot cast string 'foo' to value of Int32 type +EXECUTE my_plan6('foo'); + +statement ok +PREPARE my_plan7(INT, STRING, DOUBLE, INT, DOUBLE, STRING) + AS +SELECT id, age, $6 FROM person WHERE age IN ($1, $4) AND salary > $3 and salary < $5 OR first_name < $2; + +query IIT +EXECUTE my_plan7(10, 'jane', 99999.45, 20, 200000.45, 'foo'); +---- +1 20 foo + +statement ok +PREPARE my_plan8(INT, DOUBLE, DOUBLE, DOUBLE) + AS +SELECT id, SUM(age) FROM person WHERE salary > $2 GROUP BY id + HAVING sum(age) < $1 AND SUM(age) > 10 OR SUM(age) in ($3, $4); + +query II +EXECUTE my_plan8(100000, 99999.45, 100000.45, 200000.45); +---- +1 20 + +statement ok +PREPARE my_plan9(STRING, STRING) AS SELECT * FROM (VALUES(1, $1), (2, $2)) AS t (num, letter); + +query IT +EXECUTE my_plan9('Foo', 'Bar'); +---- +1 Foo +2 Bar # 
test creating logical plan for EXECUTE statements query TT @@ -94,7 +179,3 @@ logical_plan Execute: my_plan params=[Int64(21), Utf8("Foo")] query error DataFusion error: Schema error: No field named a\. EXPLAIN EXECUTE my_plan(a); - -# TODO: support EXECUTE queries -query error DataFusion error: This feature is not implemented: Unsupported logical plan: Execute -EXECUTE my_plan; From c81c1b49ffacb5435e9c8d3f846279b82744de0c Mon Sep 17 00:00:00 2001 From: jonahgao Date: Mon, 4 Nov 2024 14:26:59 +0800 Subject: [PATCH 4/7] Update doc --- datafusion/core/src/execution/context/mod.rs | 21 ++++++++++++------- .../core/src/execution/session_state.rs | 2 +- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 53830d8a4a6d..dd3c96aa7c7d 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -694,14 +694,21 @@ impl SessionContext { data_types, }) => { // The number of parameters must match the specified data types length. - let param_names = input.get_parameter_names()?; - if param_names.len() != data_types.len() { - return plan_err!( - "Prepare specifies {} data types but query has {} parameters", - data_types.len(), - param_names.len() - ); + if !data_types.is_empty() { + let param_names = input.get_parameter_names()?; + if param_names.len() != data_types.len() { + return plan_err!( + "Prepare specifies {} data types but query has {} parameters", + data_types.len(), + param_names.len() + ); + } } + // Store the unoptimized plan into the session state. Although storing the + // optimized plan or the physical plan would be more efficient, doing so is + // not currently feasible. This is because `now()` would be optimized to a + // constant value, causing each EXECUTE to yield the same result, which is + // incorrect behavior. 
self.state.write().store_prepared(name, data_types, input)?; self.return_empty_dataframe() } diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 9baf5862edce..ecb59f7b03b7 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -911,7 +911,7 @@ impl SessionState { Ok(udtf.map(|x| x.function().clone())) } - /// Store the logical plan and parameter types of a prepared statement. + /// Store the logical plan and the parameter types of a prepared statement. pub(crate) fn store_prepared( &mut self, name: String, From aad77446f54c89cdaaa4f61909a3e0230b8cc326 Mon Sep 17 00:00:00 2001 From: jonahgao Date: Mon, 4 Nov 2024 14:54:22 +0800 Subject: [PATCH 5/7] Add test --- datafusion/core/src/execution/context/mod.rs | 2 +- datafusion/sqllogictest/test_files/prepare.slt | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index dd3c96aa7c7d..d47d3ffb2b9e 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1126,7 +1126,7 @@ impl SessionContext { .into_iter() .map(|e| match e { Expr::Literal(scalar) => Ok(scalar), - _ => not_impl_err!("Unsupported data type for parameter: {}", e), + _ => not_impl_err!("Unsupported parameter type: {}", e), }) .collect::>()?; diff --git a/datafusion/sqllogictest/test_files/prepare.slt b/datafusion/sqllogictest/test_files/prepare.slt index be8dadc3f0a0..f6260e48f65a 100644 --- a/datafusion/sqllogictest/test_files/prepare.slt +++ b/datafusion/sqllogictest/test_files/prepare.slt @@ -136,6 +136,10 @@ EXECUTE my_plan6(20.0); statement error Cast error: Cannot cast string 'foo' to value of Int32 type EXECUTE my_plan6('foo'); +# TODO: support non-literal expressions +statement error Unsupported parameter type +EXECUTE my_plan6(10 + 20); + statement ok PREPARE 
my_plan7(INT, STRING, DOUBLE, INT, DOUBLE, STRING) AS From b9a9674255f0de1643317c991f708dda8a4eb878 Mon Sep 17 00:00:00 2001 From: jonahgao Date: Mon, 4 Nov 2024 15:15:55 +0800 Subject: [PATCH 6/7] Add issue test --- .../sqllogictest/test_files/prepare.slt | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/datafusion/sqllogictest/test_files/prepare.slt b/datafusion/sqllogictest/test_files/prepare.slt index f6260e48f65a..493290a75f17 100644 --- a/datafusion/sqllogictest/test_files/prepare.slt +++ b/datafusion/sqllogictest/test_files/prepare.slt @@ -170,6 +170,36 @@ EXECUTE my_plan9('Foo', 'Bar'); 1 Foo 2 Bar + +# Test issue: https://github.com/apache/datafusion/issues/12294 +# prepare argument is in the LIMIT clause +statement ok +CREATE TABLE test(id INT, run_id TEXT) AS VALUES(1, 'foo'), (1, 'foo'), (3, 'bar'); + +statement ok +PREPARE get_N_rand_ints_from_last_run(INT) AS +SELECT id +FROM + "test" +WHERE run_id = 'foo' +ORDER BY random() +LIMIT $1 + +query I +EXECUTE get_N_rand_ints_from_last_run(1); +---- +1 + +query I +EXECUTE get_N_rand_ints_from_last_run(2); +---- +1 +1 + +statement ok +DROP TABLE test; + + # test creating logical plan for EXECUTE statements query TT EXPLAIN EXECUTE my_plan; From 472976639ac5a7ebe6914205f4026c3ee440d099 Mon Sep 17 00:00:00 2001 From: jonahgao Date: Wed, 6 Nov 2024 10:40:52 +0800 Subject: [PATCH 7/7] Respect allow_statements option --- datafusion/core/src/execution/context/mod.rs | 8 +++++++ datafusion/core/tests/sql/sql_api.rs | 24 ++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index d47d3ffb2b9e..7868a7f9e59c 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1773,6 +1773,14 @@ impl<'n, 'a> TreeNodeVisitor<'n> for BadPlanVisitor<'a> { LogicalPlan::Statement(stmt) if !self.options.allow_statements => { plan_err!("Statement not 
supported: {}", stmt.name()) } + // TODO: Implement PREPARE as a LogicalPlan::Statement + LogicalPlan::Prepare(_) if !self.options.allow_statements => { + plan_err!("Statement not supported: PREPARE") + } + // TODO: Implement EXECUTE as a LogicalPlan::Statement + LogicalPlan::Execute(_) if !self.options.allow_statements => { + plan_err!("Statement not supported: EXECUTE") + } _ => Ok(TreeNodeRecursion::Continue), } } diff --git a/datafusion/core/tests/sql/sql_api.rs b/datafusion/core/tests/sql/sql_api.rs index 48f4a66b65dc..b2ffefa43708 100644 --- a/datafusion/core/tests/sql/sql_api.rs +++ b/datafusion/core/tests/sql/sql_api.rs @@ -113,6 +113,30 @@ async fn unsupported_statement_returns_error() { ctx.sql_with_options(sql, options).await.unwrap(); } +// Disallow PREPARE and EXECUTE statements if `allow_statements` is false +#[tokio::test] +async fn disable_prepare_and_execute_statement() { + let ctx = SessionContext::new(); + + let prepare_sql = "PREPARE plan(INT) AS SELECT $1"; + let execute_sql = "EXECUTE plan(1)"; + let options = SQLOptions::new().with_allow_statements(false); + let df = ctx.sql_with_options(prepare_sql, options).await; + assert_eq!( + df.unwrap_err().strip_backtrace(), + "Error during planning: Statement not supported: PREPARE" + ); + let df = ctx.sql_with_options(execute_sql, options).await; + assert_eq!( + df.unwrap_err().strip_backtrace(), + "Error during planning: Statement not supported: EXECUTE" + ); + + let options = options.with_allow_statements(true); + ctx.sql_with_options(prepare_sql, options).await.unwrap(); + ctx.sql_with_options(execute_sql, options).await.unwrap(); +} + #[tokio::test] async fn empty_statement_returns_error() { let ctx = SessionContext::new();