diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 333f83c673cc..7868a7f9e59c 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -42,7 +42,8 @@ use crate::{ logical_expr::{ CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction, CreateMemoryTable, CreateView, DropCatalogSchema, DropFunction, DropTable, - DropView, LogicalPlan, LogicalPlanBuilder, SetVariable, TableType, UNNAMED_TABLE, + DropView, Execute, LogicalPlan, LogicalPlanBuilder, Prepare, SetVariable, + TableType, UNNAMED_TABLE, }, physical_expr::PhysicalExpr, physical_plan::ExecutionPlan, @@ -54,9 +55,9 @@ use arrow::record_batch::RecordBatch; use arrow_schema::Schema; use datafusion_common::{ config::{ConfigExtension, TableOptions}, - exec_err, not_impl_err, plan_datafusion_err, plan_err, + exec_datafusion_err, exec_err, not_impl_err, plan_datafusion_err, plan_err, tree_node::{TreeNodeRecursion, TreeNodeVisitor}, - DFSchema, SchemaReference, TableReference, + DFSchema, ParamValues, ScalarValue, SchemaReference, TableReference, }; use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ @@ -687,7 +688,31 @@ impl SessionContext { LogicalPlan::Statement(Statement::SetVariable(stmt)) => { self.set_variable(stmt).await } - + LogicalPlan::Prepare(Prepare { + name, + input, + data_types, + }) => { + // The number of parameters must match the specified data types length. + if !data_types.is_empty() { + let param_names = input.get_parameter_names()?; + if param_names.len() != data_types.len() { + return plan_err!( + "Prepare specifies {} data types but query has {} parameters", + data_types.len(), + param_names.len() + ); + } + } + // Store the unoptimized plan into the session state. Although storing the + // optimized plan or the physical plan would be more efficient, doing so is + // not currently feasible. 
This is because `now()` would be optimized to a + // constant value, causing each EXECUTE to yield the same result, which is + // incorrect behavior. + self.state.write().store_prepared(name, data_types, input)?; + self.return_empty_dataframe() + } + LogicalPlan::Execute(execute) => self.execute_prepared(execute), plan => Ok(DataFrame::new(self.state(), plan)), } } @@ -1088,6 +1113,49 @@ impl SessionContext { } } + fn execute_prepared(&self, execute: Execute) -> Result { + let Execute { + name, parameters, .. + } = execute; + let prepared = self.state.read().get_prepared(&name).ok_or_else(|| { + exec_datafusion_err!("Prepared statement '{}' does not exist", name) + })?; + + // Only allow literals as parameters for now. + let mut params: Vec = parameters + .into_iter() + .map(|e| match e { + Expr::Literal(scalar) => Ok(scalar), + _ => not_impl_err!("Unsupported parameter type: {}", e), + }) + .collect::>()?; + + // If the prepared statement provides data types, cast the params to those types. + if !prepared.data_types.is_empty() { + if params.len() != prepared.data_types.len() { + return exec_err!( + "Prepared statement '{}' expects {} parameters, but {} provided", + name, + prepared.data_types.len(), + params.len() + ); + } + params = params + .into_iter() + .zip(prepared.data_types.iter()) + .map(|(e, dt)| e.cast_to(dt)) + .collect::>()?; + } + + let params = ParamValues::List(params); + let plan = prepared + .plan + .as_ref() + .clone() + .replace_params_with_values(&params)?; + Ok(DataFrame::new(self.state(), plan)) + } + /// Registers a variable provider within this context.
pub fn register_variable( &self, @@ -1705,6 +1773,14 @@ impl<'n, 'a> TreeNodeVisitor<'n> for BadPlanVisitor<'a> { LogicalPlan::Statement(stmt) if !self.options.allow_statements => { plan_err!("Statement not supported: {}", stmt.name()) } + // TODO: Implement PREPARE as a LogicalPlan::Statement + LogicalPlan::Prepare(_) if !self.options.allow_statements => { + plan_err!("Statement not supported: PREPARE") + } + // TODO: Implement EXECUTE as a LogicalPlan::Statement + LogicalPlan::Execute(_) if !self.options.allow_statements => { + plan_err!("Statement not supported: EXECUTE") + } _ => Ok(TreeNodeRecursion::Continue), } } diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index d50c912dd2fd..ecb59f7b03b7 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -40,7 +40,7 @@ use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; use datafusion_common::file_options::file_type::FileType; use datafusion_common::tree_node::TreeNode; use datafusion_common::{ - config_err, not_impl_err, plan_datafusion_err, DFSchema, DataFusionError, + config_err, exec_err, not_impl_err, plan_datafusion_err, DFSchema, DataFusionError, ResolvedTableReference, TableReference, }; use datafusion_execution::config::SessionConfig; @@ -171,6 +171,9 @@ pub struct SessionState { /// It will be invoked on `CREATE FUNCTION` statements. /// thus, changing dialect o PostgreSql is required function_factory: Option>, + /// Cache logical plans of prepared statements for later execution. + /// Key is the prepared statement name. 
+ prepared_plans: HashMap>, } impl Debug for SessionState { @@ -197,6 +200,7 @@ impl Debug for SessionState { .field("scalar_functions", &self.scalar_functions) .field("aggregate_functions", &self.aggregate_functions) .field("window_functions", &self.window_functions) + .field("prepared_plans", &self.prepared_plans) .finish() } } @@ -906,6 +910,29 @@ impl SessionState { let udtf = self.table_functions.remove(name); Ok(udtf.map(|x| x.function().clone())) } + + /// Store the logical plan and the parameter types of a prepared statement. + pub(crate) fn store_prepared( + &mut self, + name: String, + data_types: Vec, + plan: Arc, + ) -> datafusion_common::Result<()> { + match self.prepared_plans.entry(name) { + Entry::Vacant(e) => { + e.insert(Arc::new(PreparedPlan { data_types, plan })); + Ok(()) + } + Entry::Occupied(e) => { + exec_err!("Prepared statement '{}' already exists", e.key()) + } + } + } + + /// Get the prepared plan with the given name. + pub(crate) fn get_prepared(&self, name: &str) -> Option> { + self.prepared_plans.get(name).map(Arc::clone) + } } /// A builder to be used for building [`SessionState`]'s. 
Defaults will @@ -1327,6 +1354,7 @@ impl SessionStateBuilder { table_factories: table_factories.unwrap_or_default(), runtime_env, function_factory, + prepared_plans: HashMap::new(), }; if let Some(file_formats) = file_formats { @@ -1876,6 +1904,14 @@ impl<'a> SimplifyInfo for SessionSimplifyProvider<'a> { } } +#[derive(Debug)] +pub(crate) struct PreparedPlan { + /// Data types of the parameters + pub(crate) data_types: Vec, + /// The prepared logical plan + pub(crate) plan: Arc, +} + #[cfg(test)] mod tests { use super::{SessionContextProvider, SessionStateBuilder}; diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index dd660512f346..2e815303e3ce 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -57,7 +57,6 @@ async fn test_named_query_parameters() -> Result<()> { let ctx = create_ctx_with_partition(&tmp_dir, partition_count).await?; // sql to statement then to logical plan with parameters - // c1 defined as UINT32, c2 defined as UInt64 let results = ctx .sql("SELECT c1, c2 FROM test WHERE c1 > $coo AND c1 < $foo") .await? 
@@ -106,9 +105,9 @@ async fn test_prepare_statement() -> Result<()> { let ctx = create_ctx_with_partition(&tmp_dir, partition_count).await?; // sql to statement then to prepare logical plan with parameters - // c1 defined as UINT32, c2 defined as UInt64 but the params are Int32 and Float64 - let dataframe = - ctx.sql("PREPARE my_plan(INT, DOUBLE) AS SELECT c1, c2 FROM test WHERE c1 > $2 AND c1 < $1").await?; + let dataframe = ctx + .sql("SELECT c1, c2 FROM test WHERE c1 > $2 AND c1 < $1") + .await?; // prepare logical plan to logical plan without parameters let param_values = vec![ScalarValue::Int32(Some(3)), ScalarValue::Float64(Some(0.0))]; @@ -156,7 +155,7 @@ async fn prepared_statement_type_coercion() -> Result<()> { ("unsigned", Arc::new(unsigned_ints) as ArrayRef), ])?; ctx.register_batch("test", batch)?; - let results = ctx.sql("PREPARE my_plan(BIGINT, INT, TEXT) AS SELECT signed, unsigned FROM test WHERE $1 >= signed AND signed <= $2 AND unsigned = $3") + let results = ctx.sql("SELECT signed, unsigned FROM test WHERE $1 >= signed AND signed <= $2 AND unsigned = $3") .await? .with_param_values(vec![ ScalarValue::from(1_i64), @@ -176,27 +175,6 @@ async fn prepared_statement_type_coercion() -> Result<()> { Ok(()) } -#[tokio::test] -async fn prepared_statement_invalid_types() -> Result<()> { - let ctx = SessionContext::new(); - let signed_ints: Int32Array = vec![-1, 0, 1].into(); - let unsigned_ints: UInt64Array = vec![1, 2, 3].into(); - let batch = RecordBatch::try_from_iter(vec![ - ("signed", Arc::new(signed_ints) as ArrayRef), - ("unsigned", Arc::new(unsigned_ints) as ArrayRef), - ])?; - ctx.register_batch("test", batch)?; - let results = ctx - .sql("PREPARE my_plan(INT) AS SELECT signed FROM test WHERE signed = $1") - .await? 
- .with_param_values(vec![ScalarValue::from("1")]); - assert_eq!( - results.unwrap_err().strip_backtrace(), - "Error during planning: Expected parameter of type Int32, got Utf8 at index 0" - ); - Ok(()) -} - #[tokio::test] async fn test_parameter_type_coercion() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/core/tests/sql/sql_api.rs b/datafusion/core/tests/sql/sql_api.rs index 48f4a66b65dc..b2ffefa43708 100644 --- a/datafusion/core/tests/sql/sql_api.rs +++ b/datafusion/core/tests/sql/sql_api.rs @@ -113,6 +113,30 @@ async fn unsupported_statement_returns_error() { ctx.sql_with_options(sql, options).await.unwrap(); } +// Disallow PREPARE and EXECUTE statements if `allow_statements` is false +#[tokio::test] +async fn disable_prepare_and_execute_statement() { + let ctx = SessionContext::new(); + + let prepare_sql = "PREPARE plan(INT) AS SELECT $1"; + let execute_sql = "EXECUTE plan(1)"; + let options = SQLOptions::new().with_allow_statements(false); + let df = ctx.sql_with_options(prepare_sql, options).await; + assert_eq!( + df.unwrap_err().strip_backtrace(), + "Error during planning: Statement not supported: PREPARE" + ); + let df = ctx.sql_with_options(execute_sql, options).await; + assert_eq!( + df.unwrap_err().strip_backtrace(), + "Error during planning: Statement not supported: EXECUTE" + ); + + let options = options.with_allow_statements(true); + ctx.sql_with_options(prepare_sql, options).await.unwrap(); + ctx.sql_with_options(execute_sql, options).await.unwrap(); +} + #[tokio::test] async fn empty_statement_returns_error() { let ctx = SessionContext::new(); diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index ea8fca3ec9d6..db309d9b5232 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1440,6 +1440,22 @@ impl LogicalPlan { .map(|res| res.data) } + /// Walk the logical plan, find any `Placeholder` tokens, and return a set of their 
names. + pub fn get_parameter_names(&self) -> Result> { + let mut param_names = HashSet::new(); + self.apply_with_subqueries(|plan| { + plan.apply_expressions(|expr| { + expr.apply(|expr| { + if let Expr::Placeholder(Placeholder { id, .. }) = expr { + param_names.insert(id.clone()); + } + Ok(TreeNodeRecursion::Continue) + }) + }) + }) + .map(|_| param_names) + } + /// Walk the logical plan, find any `Placeholder` tokens, and return a map of their IDs and DataTypes pub fn get_parameter_types( &self, diff --git a/datafusion/sqllogictest/test_files/prepare.slt b/datafusion/sqllogictest/test_files/prepare.slt index 91b925efa26c..b0c67af9e14f 100644 --- a/datafusion/sqllogictest/test_files/prepare.slt +++ b/datafusion/sqllogictest/test_files/prepare.slt @@ -30,56 +30,175 @@ select * from person; # Error due to syntax and semantic violation # Syntax error: no name specified after the keyword prepare -statement error +statement error DataFusion error: SQL error: ParserError PREPARE AS SELECT id, age FROM person WHERE age = $foo; # param following a non-number, $foo, not supported -statement error +statement error Invalid placeholder, not a number: \$foo PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $foo; # not specify table hence cannot specify columns -statement error +statement error Schema error: No field named id PREPARE my_plan(INT) AS SELECT id + $1; # not specify data types for all params -statement error +statement error Prepare specifies 1 data types but query has 2 parameters PREPARE my_plan(INT) AS SELECT 1 + $1 + $2; +# specify too many data types for params +statement error Prepare specifies 2 data types but query has 1 parameters +PREPARE my_plan(INT, INT) AS SELECT 1 + $1; + # cannot use IS param -statement error +statement error SQL error: ParserError PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age is $1; +# TODO: allow prepare without specifying data types +statement error Placeholder type could not be resolved +PREPARE
my_plan AS SELECT $1; + # ####################### -# TODO: all the errors below should work ok after we store the prepare logical plan somewhere -statement error +# Test prepare and execute statements + +# execute a non-existing plan +statement error Prepared statement \'my_plan\' does not exist +EXECUTE my_plan('Foo', 'Bar'); + +statement ok PREPARE my_plan(STRING, STRING) AS SELECT * FROM (VALUES(1, $1), (2, $2)) AS t (num, letter); -statement error +query IT +EXECUTE my_plan('Foo', 'Bar'); +---- +1 Foo +2 Bar + +# duplicate prepare statement +statement error Prepared statement \'my_plan\' already exists +PREPARE my_plan(STRING, STRING) AS SELECT * FROM (VALUES(1, $1), (2, $2)) AS t (num, letter); + +statement error Prepare specifies 1 data types but query has 0 parameters PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10; -statement error -PREPARE my_plan AS SELECT id, age FROM person WHERE age = 10; +# prepare statement has no params +statement ok +PREPARE my_plan2 AS SELECT id, age FROM person WHERE age = 20; + +query II +EXECUTE my_plan2; +---- +1 20 + +statement ok +PREPARE my_plan3(INT) AS SELECT $1; -statement error -PREPARE my_plan(INT) AS SELECT $1; +query I +EXECUTE my_plan3(10); +---- +10 -statement error -PREPARE my_plan(INT) AS SELECT 1 + $1; +statement ok +PREPARE my_plan4(INT) AS SELECT 1 + $1; -statement error -PREPARE my_plan(INT, DOUBLE) AS SELECT 1 + $1 + $2; +query I +EXECUTE my_plan4(10); +---- +11 -statement error -PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $1; +statement ok +PREPARE my_plan5(INT, DOUBLE) AS SELECT 1 + $1 + $2; -statement error -PREPARE my_plan(INT, STRING, DOUBLE, INT, DOUBLE, STRING) AS SELECT id, age, $6 FROM person WHERE age IN ($1, $4) AND salary > $3 and salary < $5 OR first_name < $2"; +query R +EXECUTE my_plan5(10, 20.5); +---- +31.5 -statement error -PREPARE my_plan(INT, DOUBLE, DOUBLE, DOUBLE) AS SELECT id, SUM(age) FROM person WHERE salary > $2 GROUP BY id HAVING sum(age) < $1 
AND SUM(age) > 10 OR SUM(age) in ($3, $4); +statement ok +PREPARE my_plan6(INT) AS SELECT id, age FROM person WHERE age = $1; + +query II +EXECUTE my_plan6(20); +---- +1 20 + +# EXECUTE param is a different type but compatible +query II +EXECUTE my_plan6('20'); +---- +1 20 + +query II +EXECUTE my_plan6(20.0); +---- +1 20 + +# invalid execute param +statement error Cast error: Cannot cast string 'foo' to value of Int32 type +EXECUTE my_plan6('foo'); + +# TODO: support non-literal expressions +statement error Unsupported parameter type +EXECUTE my_plan6(10 + 20); + +statement ok +PREPARE my_plan7(INT, STRING, DOUBLE, INT, DOUBLE, STRING) + AS +SELECT id, age, $6 FROM person WHERE age IN ($1, $4) AND salary > $3 and salary < $5 OR first_name < $2; + +query IIT +EXECUTE my_plan7(10, 'jane', 99999.45, 20, 200000.45, 'foo'); +---- +1 20 foo + +statement ok +PREPARE my_plan8(INT, DOUBLE, DOUBLE, DOUBLE) + AS +SELECT id, SUM(age) FROM person WHERE salary > $2 GROUP BY id + HAVING sum(age) < $1 AND SUM(age) > 10 OR SUM(age) in ($3, $4); + +query II +EXECUTE my_plan8(100000, 99999.45, 100000.45, 200000.45); +---- +1 20 + +statement ok +PREPARE my_plan9(STRING, STRING) AS SELECT * FROM (VALUES(1, $1), (2, $2)) AS t (num, letter); + +query IT +EXECUTE my_plan9('Foo', 'Bar'); +---- +1 Foo +2 Bar + + +# Test issue: https://github.com/apache/datafusion/issues/12294 +# prepare argument is in the LIMIT clause +statement ok +CREATE TABLE test(id INT, run_id TEXT) AS VALUES(1, 'foo'), (1, 'foo'), (3, 'bar'); + +statement ok +PREPARE get_N_rand_ints_from_last_run(INT) AS +SELECT id +FROM + "test" +WHERE run_id = 'foo' +ORDER BY random() +LIMIT $1 + +query I +EXECUTE get_N_rand_ints_from_last_run(1); +---- +1 + +query I +EXECUTE get_N_rand_ints_from_last_run(2); +---- +1 +1 + +statement ok +DROP TABLE test; -statement error -PREPARE my_plan(STRING, STRING) AS SELECT * FROM (VALUES(1, $1), (2, $2)) AS t (num, letter); # test creating logical plan for EXECUTE statements query TT @@ -96,7 
+215,3 @@ physical_plan_error This feature is not implemented: Unsupported logical plan: E query error DataFusion error: Schema error: No field named a\. EXPLAIN EXECUTE my_plan(a); - -# TODO: support EXECUTE queries -query error DataFusion error: This feature is not implemented: Unsupported logical plan: Execute -EXECUTE my_plan;