diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 90fb4a88149c..6df89624fc51 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -20,6 +20,7 @@ mod dfschema; mod error; mod functional_dependencies; mod join_type; +mod param_value; #[cfg(feature = "pyarrow")] mod pyarrow; mod schema_reference; @@ -59,6 +60,7 @@ pub use functional_dependencies::{ Constraints, Dependency, FunctionalDependence, FunctionalDependencies, }; pub use join_type::{JoinConstraint, JoinSide, JoinType}; +pub use param_value::ParamValues; pub use scalar::{ScalarType, ScalarValue}; pub use schema_reference::{OwnedSchemaReference, SchemaReference}; pub use stats::{ColumnStatistics, Statistics}; diff --git a/datafusion/common/src/param_value.rs b/datafusion/common/src/param_value.rs new file mode 100644 index 000000000000..253c312b66d5 --- /dev/null +++ b/datafusion/common/src/param_value.rs @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::error::{_internal_err, _plan_err}; +use crate::{DataFusionError, Result, ScalarValue}; +use arrow_schema::DataType; +use std::collections::HashMap; + +/// The parameter value corresponding to the placeholder +#[derive(Debug, Clone)] +pub enum ParamValues { + /// for positional query parameters, like select * from test where a > $1 and b = $2 + LIST(Vec), + /// for named query parameters, like select * from test where a > $foo and b = $goo + MAP(HashMap), +} + +impl ParamValues { + /// Verify parameter list length and type + pub fn verify(&self, expect: &Vec) -> Result<()> { + match self { + ParamValues::LIST(list) => { + // Verify if the number of params matches the number of values + if expect.len() != list.len() { + return _plan_err!( + "Expected {} parameters, got {}", + expect.len(), + list.len() + ); + } + + // Verify if the types of the params matches the types of the values + let iter = expect.iter().zip(list.iter()); + for (i, (param_type, value)) in iter.enumerate() { + if *param_type != value.data_type() { + return _plan_err!( + "Expected parameter of type {:?}, got {:?} at index {}", + param_type, + value.data_type(), + i + ); + } + } + Ok(()) + } + ParamValues::MAP(_) => { + // If it is a named query, variables can be reused, + // but the lengths are not necessarily equal + Ok(()) + } + } + } + + pub fn get_placeholders_with_values( + &self, + id: &String, + data_type: &Option, + ) -> Result { + match self { + ParamValues::LIST(list) => { + if id.is_empty() || id == "$0" { + return _plan_err!("Empty placeholder id"); + } + // convert id (in format $1, $2, ..) to idx (0, 1, ..) + let idx = id[1..].parse::().map_err(|e| { + DataFusionError::Internal(format!( + "Failed to parse placeholder id: {e}" + )) + })? - 1; + // value at the idx-th position in param_values should be the value for the placeholder + let value = list.get(idx).ok_or_else(|| { + DataFusionError::Internal(format!( + "No value found for placeholder with id {id}" + )) + })?; + // check if the data type of the value matches the data type of the placeholder + if Some(value.data_type()) != *data_type { + return _internal_err!( + "Placeholder value type mismatch: expected {:?}, got {:?}", + data_type, + value.data_type() + ); + } + Ok(value.clone()) + } + ParamValues::MAP(map) => { + // convert name (in format $a, $b, ..) to mapped values (a, b, ..) + let name = &id[1..]; + // value at the name position in param_values should be the value for the placeholder + let value = map.get(name).ok_or_else(|| { + DataFusionError::Internal(format!( + "No value found for placeholder with name {id}" + )) + })?; + // check if the data type of the value matches the data type of the placeholder + if Some(value.data_type()) != *data_type { + return _internal_err!( + "Placeholder value type mismatch: expected {:?}, got {:?}", + data_type, + value.data_type() + ); + } + Ok(value.clone()) + } + } + } +} + +impl From> for ParamValues { + fn from(value: Vec) -> Self { + Self::LIST(value) + } +} + +impl From> for ParamValues +where + K: Into, +{ + fn from(value: Vec<(K, ScalarValue)>) -> Self { + let value: HashMap = + value.into_iter().map(|(k, v)| (k.into(), v)).collect(); + Self::MAP(value) + } +} + +impl From> for ParamValues +where + K: Into, +{ + fn from(value: HashMap) -> Self { + let value: HashMap = + value.into_iter().map(|(k, v)| (k.into(), v)).collect(); + Self::MAP(value) + } +} diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 89e82fa952bb..52b5157b7313 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -32,11 +32,12 @@ use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ - DataFusionError, FileType, FileTypeWriterOptions, SchemaError, UnnestOptions, + DataFusionError, FileType, FileTypeWriterOptions, ParamValues, SchemaError, + UnnestOptions, }; use datafusion_expr::dml::CopyOptions; -use datafusion_common::{Column, DFSchema, ScalarValue}; +use datafusion_common::{Column, DFSchema}; use datafusion_expr::{ avg, count, is_null, max, median, min, stddev, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, @@ -1227,11 +1228,32 @@ impl DataFrame { /// ], /// &results /// ); + /// // Note you can also provide named parameters + /// let results = ctx + /// .sql("SELECT a FROM example WHERE b = $my_param") + /// .await? + /// // replace $my_param with value 2 + /// // Note you can also use a HashMap as well + /// .with_param_values(vec![ + /// ("my_param", ScalarValue::from(2i64)) + /// ])? + /// .collect() + /// .await?; + /// assert_batches_eq!( + /// &[ + /// "+---+", + /// "| a |", + /// "+---+", + /// "| 1 |", + /// "+---+", + /// ], + /// &results + /// ); /// # Ok(()) /// # } /// ``` - pub fn with_param_values(self, param_values: Vec) -> Result { - let plan = self.plan.with_param_values(param_values)?; + pub fn with_param_values(self, query_values: impl Into) -> Result { + let plan = self.plan.with_param_values(query_values)?; Ok(Self::new(self.session_state, plan)) } diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index 63f3e979305a..cbdea9d72948 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -525,6 +525,53 @@ async fn test_prepare_statement() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_named_query_parameters() -> Result<()> { + let tmp_dir = TempDir::new()?; + let partition_count = 4; + let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?; + + // sql to statement then to logical plan with parameters + // c1 defined as UINT32, c2 defined as UInt64 + let results = ctx + .sql("SELECT c1, c2 FROM test WHERE c1 > $coo AND c1 < $foo") + .await? + .with_param_values(vec![ + ("foo", ScalarValue::UInt32(Some(3))), + ("coo", ScalarValue::UInt32(Some(0))), + ])? + .collect() + .await?; + let expected = vec![ + "+----+----+", + "| c1 | c2 |", + "+----+----+", + "| 1 | 1 |", + "| 1 | 2 |", + "| 1 | 3 |", + "| 1 | 4 |", + "| 1 | 5 |", + "| 1 | 6 |", + "| 1 | 7 |", + "| 1 | 8 |", + "| 1 | 9 |", + "| 1 | 10 |", + "| 2 | 1 |", + "| 2 | 2 |", + "| 2 | 3 |", + "| 2 | 4 |", + "| 2 | 5 |", + "| 2 | 6 |", + "| 2 | 7 |", + "| 2 | 8 |", + "| 2 | 9 |", + "| 2 | 10 |", + "+----+----+", + ]; + assert_batches_sorted_eq!(expected, &results); + Ok(()) +} + #[tokio::test] async fn parallel_query_with_filter() -> Result<()> { let tmp_dir = TempDir::new()?; diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index ee9b0ad6f967..6fa400454dff 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -671,7 +671,7 @@ impl InSubquery { } } -/// Placeholder, representing bind parameter values such as `$1`. +/// Placeholder, representing bind parameter values such as `$1` or `$name`. /// /// The type of these parameters is inferred using [`Expr::infer_placeholder_types`] /// or can be specified directly using `PREPARE` statements. diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 9bb47c7da058..fc8590294fe9 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -48,7 +48,7 @@ use datafusion_common::tree_node::{ use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, FunctionalDependencies, - OwnedTableReference, Result, ScalarValue, UnnestOptions, + OwnedTableReference, ParamValues, Result, UnnestOptions, }; // backwards compatibility pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; @@ -993,32 +993,12 @@ impl LogicalPlan { /// ``` pub fn with_param_values( self, - param_values: Vec, + param_values: impl Into, ) -> Result { + let param_values = param_values.into(); match self { LogicalPlan::Prepare(prepare_lp) => { - // Verify if the number of params matches the number of values - if prepare_lp.data_types.len() != param_values.len() { - return plan_err!( - "Expected {} parameters, got {}", - prepare_lp.data_types.len(), - param_values.len() - ); - } - - // Verify if the types of the params matches the types of the values - let iter = prepare_lp.data_types.iter().zip(param_values.iter()); - for (i, (param_type, value)) in iter.enumerate() { - if *param_type != value.data_type() { - return plan_err!( - "Expected parameter of type {:?}, got {:?} at index {}", - param_type, - value.data_type(), - i - ); - } - } - + param_values.verify(&prepare_lp.data_types)?; let input_plan = prepare_lp.input; input_plan.replace_params_with_values(¶m_values) } @@ -1182,7 +1162,7 @@ impl LogicalPlan { /// See [`Self::with_param_values`] for examples and usage pub fn replace_params_with_values( &self, - param_values: &[ScalarValue], + param_values: &ParamValues, ) -> Result { let new_exprs = self .expressions() @@ -1239,36 +1219,15 @@ impl LogicalPlan { /// corresponding values provided in the params_values fn replace_placeholders_with_values( expr: Expr, - param_values: &[ScalarValue], + param_values: &ParamValues, ) -> Result { expr.transform(&|expr| { match &expr { Expr::Placeholder(Placeholder { id, data_type }) => { - if id.is_empty() || id == "$0" { - return plan_err!("Empty placeholder id"); - } - // convert id (in format $1, $2, ..) to idx (0, 1, ..) - let idx = id[1..].parse::().map_err(|e| { - DataFusionError::Internal(format!( - "Failed to parse placeholder id: {e}" - )) - })? - 1; - // value at the idx-th position in param_values should be the value for the placeholder - let value = param_values.get(idx).ok_or_else(|| { - DataFusionError::Internal(format!( - "No value found for placeholder with id {id}" - )) - })?; - // check if the data type of the value matches the data type of the placeholder - if Some(value.data_type()) != *data_type { - return internal_err!( - "Placeholder value type mismatch: expected {:?}, got {:?}", - data_type, - value.data_type() - ); - } + let value = + param_values.get_placeholders_with_values(id, data_type)?; // Replace the placeholder with the value - Ok(Transformed::Yes(Expr::Literal(value.clone()))) + Ok(Transformed::Yes(Expr::Literal(value))) } Expr::ScalarSubquery(qry) => { let subquery = @@ -2580,7 +2539,7 @@ mod tests { use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::tree_node::TreeNodeVisitor; - use datafusion_common::{not_impl_err, DFSchema, TableReference}; + use datafusion_common::{not_impl_err, DFSchema, ScalarValue, TableReference}; use std::collections::HashMap; fn employee_schema() -> Schema { @@ -3028,7 +2987,8 @@ digraph { .build() .unwrap(); - plan.replace_params_with_values(&[42i32.into()]) + let param_values = vec![ScalarValue::Int32(Some(42))]; + plan.replace_params_with_values(¶m_values.clone().into()) .expect_err("unexpectedly succeeded to replace an invalid placeholder"); // test $0 placeholder @@ -3041,7 +3001,7 @@ digraph { .build() .unwrap(); - plan.replace_params_with_values(&[42i32.into()]) + plan.replace_params_with_values(¶m_values.into()) .expect_err("unexpectedly succeeded to replace an invalid placeholder"); } diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index a3f29da488ba..708f7c60011a 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -108,7 +108,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } Ok(index) => index - 1, Err(_) => { - return plan_err!("Invalid placeholder, not a number: {param}"); + return if param_data_types.is_empty() { + Ok(Expr::Placeholder(Placeholder::new(param, None))) + } else { + // when PREPARE Statement, param_data_types length is always 0 + plan_err!("Invalid placeholder, not a number: {param}") + }; } }; // Check if the placeholder is in the parameter list diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index d5b06bcf815f..83bdb954b134 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -22,11 +22,11 @@ use std::{sync::Arc, vec}; use arrow_schema::*; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; -use datafusion_common::plan_err; use datafusion_common::{ assert_contains, config::ConfigOptions, DataFusionError, Result, ScalarValue, TableReference, }; +use datafusion_common::{plan_err, ParamValues}; use datafusion_expr::{ logical_plan::{LogicalPlan, Prepare}, AggregateUDF, ScalarUDF, TableSource, WindowUDF, @@ -471,6 +471,10 @@ Dml: op=[Insert Into] table=[test_decimal] "INSERT INTO person (id, first_name, last_name) VALUES ($2, $4, $6)", "Error during planning: Placeholder type could not be resolved" )] +#[case::placeholder_type_unresolved( + "INSERT INTO person (id, first_name, last_name) VALUES ($id, $first_name, $last_name)", + "Error during planning: Can't parse placeholder: $id" +)] #[test] fn test_insert_schema_errors(#[case] sql: &str, #[case] error: &str) { let err = logical_plan(sql).unwrap_err(); @@ -2674,7 +2678,7 @@ fn prepare_stmt_quick_test( fn prepare_stmt_replace_params_quick_test( plan: LogicalPlan, - param_values: Vec, + param_values: impl Into, expected_plan: &str, ) -> LogicalPlan { // replace params @@ -3726,7 +3730,7 @@ fn test_prepare_statement_to_plan_no_param() { /////////////////// // replace params with values - let param_values = vec![]; + let param_values: Vec = vec![]; let expected_plan = "Projection: person.id, person.age\ \n Filter: person.age = Int64(10)\ \n TableScan: person"; @@ -3740,7 +3744,7 @@ fn test_prepare_statement_to_plan_one_param_no_value_panic() { let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; let plan = logical_plan(sql).unwrap(); // declare 1 param but provide 0 - let param_values = vec![]; + let param_values: Vec = vec![]; assert_eq!( plan.with_param_values(param_values) .unwrap_err() @@ -3853,7 +3857,7 @@ Projection: person.id, orders.order_id assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; + let param_values = vec![ScalarValue::Int32(Some(10))].into(); let expected_plan = r#" Projection: person.id, orders.order_id Inner Join: Filter: person.id = orders.customer_id AND person.age = Int32(10) @@ -3885,7 +3889,7 @@ Projection: person.id, person.age assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; + let param_values = vec![ScalarValue::Int32(Some(10))].into(); let expected_plan = r#" Projection: person.id, person.age Filter: person.age = Int32(10) @@ -3919,7 +3923,8 @@ Projection: person.id, person.age assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))]; + let param_values = + vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))].into(); let expected_plan = r#" Projection: person.id, person.age Filter: person.age BETWEEN Int32(10) AND Int32(30) @@ -3955,7 +3960,7 @@ Projection: person.id, person.age assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::UInt32(Some(10))]; + let param_values = vec![ScalarValue::UInt32(Some(10))].into(); let expected_plan = r#" Projection: person.id, person.age Filter: person.age = () @@ -3995,7 +4000,8 @@ Dml: op=[Update] table=[person] assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))]; + let param_values = + vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))].into(); let expected_plan = r#" Dml: op=[Update] table=[person] Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, Int32(42) AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 @@ -4034,7 +4040,8 @@ fn test_prepare_statement_insert_infer() { ScalarValue::UInt32(Some(1)), ScalarValue::Utf8(Some("Alan".to_string())), ScalarValue::Utf8(Some("Turing".to_string())), - ]; + ] + .into(); let expected_plan = "Dml: op=[Insert Into] table=[person]\ \n Projection: column1 AS id, column2 AS first_name, column3 AS last_name, \ CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, \