From fa383728d481fba5bd6ec9c03f94fe25a1d441be Mon Sep 17 00:00:00 2001 From: Jesse Bakker Date: Fri, 1 Dec 2023 12:09:48 +0100 Subject: [PATCH] Correctness fixes for bugs exposed by PG JDBC --- dozer-api/src/rest/api_generator.rs | 8 +- dozer-api/src/sql/datafusion/mod.rs | 143 +++++++++++++++------------- dozer-api/src/sql/pgwire.rs | 53 ++++++----- 3 files changed, 109 insertions(+), 95 deletions(-) diff --git a/dozer-api/src/rest/api_generator.rs b/dozer-api/src/rest/api_generator.rs index 20ede7d52c..02d3118634 100644 --- a/dozer-api/src/rest/api_generator.rs +++ b/dozer-api/src/rest/api_generator.rs @@ -16,7 +16,7 @@ use openapiv3::OpenAPI; use crate::api_helper::{get_record, get_records, get_records_count}; use crate::generator::oapi::generator::OpenApiGenerator; use crate::sql::datafusion::json::record_batches_to_json_rows; -use crate::sql::datafusion::SQLExecutor; +use crate::sql::datafusion::{PlannedStatement, SQLExecutor}; use crate::CacheEndpoint; use crate::{auth::Access, errors::ApiError}; use dozer_types::grpc_types::health::health_check_response::ServingStatus; @@ -186,16 +186,16 @@ pub async fn sql( sql: extractor::SQLQueryExtractor, ) -> Result { let query = sql.0 .0; - let plan = sql_executor + let planned = sql_executor .parse(&query) .await .map_err(ApiError::SQLQueryFailed)?; - if plan.len() > 1 { + if planned.len() > 1 { return Err(ApiError::SQLQueryFailed(plan_datafusion_err!( "More than one query supplied" ))); } - let Some(Some(plan)) = plan.first().cloned() else { + let Some(PlannedStatement::Query(plan)) = planned.first().cloned() else { // This was a transaction statement, which doesn't require a result return Ok(HttpResponse::Ok().json(json!({}))); }; diff --git a/dozer-api/src/sql/datafusion/mod.rs b/dozer-api/src/sql/datafusion/mod.rs index ac35d5e539..02cb63bd0b 100644 --- a/dozer-api/src/sql/datafusion/mod.rs +++ b/dozer-api/src/sql/datafusion/mod.rs @@ -54,10 +54,16 @@ use crate::CacheEndpoint; use predicate_pushdown::{predicate_pushdown, supports_predicates_pushdown}; -pub struct SQLExecutor { +pub(crate) struct SQLExecutor { ctx: Arc, } +#[derive(Clone)] +pub(crate) enum PlannedStatement { + Query(LogicalPlan), + Statement(&'static str), +} + struct ContextResolver { tables: HashMap>, state: Arc, @@ -548,15 +554,19 @@ impl SQLExecutor { async fn parse_statement( &self, mut statement: Statement, - ) -> Result, DataFusionError> { + ) -> Result { let rewrite = if let Statement::Statement(ref stmt) = statement { match stmt.as_ref() { - ast::Statement::StartTransaction { .. } - | ast::Statement::Commit { .. } - | ast::Statement::Rollback { .. } - | ast::Statement::SetVariable { .. } => { + ast::Statement::StartTransaction { .. } => { + return Ok(PlannedStatement::Statement("BEGIN")) + } + ast::Statement::Commit { .. } => return Ok(PlannedStatement::Statement("COMMIT")), + ast::Statement::Rollback { .. } => { + return Ok(PlannedStatement::Statement("ROLLBACK")) + } + ast::Statement::SetVariable { .. } => { // dbg!(stmt); - return Ok(None); + return Ok(PlannedStatement::Statement("SET")); } ast::Statement::ShowVariable { variable } => { let variable = object_name_to_string(variable); @@ -587,14 +597,13 @@ impl SQLExecutor { ContextResolver::try_new_for_statement(Arc::new(state), &statement).await?; let planner = SqlToRel::new(&context_provider); let plan = planner.statement_to_plan(statement)?; - let options = SQLOptions::new() - .with_allow_ddl(false) - .with_allow_dml(false); + // Some BI tools use temporary tables. Let them + let options = SQLOptions::new().with_allow_ddl(true).with_allow_dml(true); options.verify_plan(&plan)?; - Ok(Some(plan)) + Ok(PlannedStatement::Query(plan)) } - pub async fn parse(&self, mut sql: &str) -> Result>, DataFusionError> { + pub async fn parse(&self, mut sql: &str) -> Result, DataFusionError> { println!("@@ query: {sql}"); if sql .to_ascii_lowercase() @@ -812,6 +821,7 @@ impl VarProvider for SystemVariables { return Ok(ScalarValue::Utf8(Some("read committed".into()))) } "@@standard_conforming_strings" => return Ok(ScalarValue::Utf8(Some("on".into()))), + "@@lc_collate" => return Ok(ScalarValue::Utf8(Some("en_US.utf8".into()))), _ => (), } } @@ -823,7 +833,7 @@ impl VarProvider for SystemVariables { fn get_type(&self, var_names: &[String]) -> Option { if var_names.len() == 1 { match var_names[0].as_str() { - "@@transaction_isolation" | "@@standard_conforming_strings" => { + "@@transaction_isolation" | "@@standard_conforming_strings" | "@@lc_collate" => { return Some(DataType::Utf8) } _ => (), @@ -856,35 +866,39 @@ fn normalize_ident(id: ast::Ident) -> String { fn sql_ast_rewrites(statement: &mut ast::Statement) { rewrite_sum(statement); rewrite_format_type(statement); - rewirte_eq_any(statement); + rewrite_eq_any(statement); + rewrite_cast_to_regclass(statement); +} + +fn rewrite_cast_to_regclass(statement: &mut ast::Statement) { + ast::visit_expressions_mut(statement, |cast: &mut ast::Expr| { + if let ast::Expr::Cast { + data_type: ast::DataType::Regclass, + .. + } = cast + { + *cast = ast::Expr::Value(ast::Value::Number("0".to_owned(), false)); + } + ControlFlow::<()>::Continue(()) + }); } // SQL AST rewirte for SUM('1') to SUM(1) fn rewrite_sum(statement: &mut ast::Statement) { ast::visit_expressions_mut(statement, |expr: &mut ast::Expr| { - match expr { - ast::Expr::Function(Function { name, args, .. }) => { - let name = &name.0; - if name.len() == 1 && name[0].value.eq_ignore_ascii_case("sum") { - if args.len() == 1 { - let arg = &mut args[0]; - match arg { - FunctionArg::Unnamed(FunctionArgExpr::Expr(ast::Expr::Value( - value, - ))) => { - if let ast::Value::SingleQuotedString(literal) = value { - if literal.parse::().is_ok() { - *value = ast::Value::Number(literal.clone(), false); - return ControlFlow::<()>::Break(()); - } - } - } - _ => (), + if let ast::Expr::Function(Function { name, args, .. }) = expr { + let name = &name.0; + if name.len() == 1 && name[0].value.eq_ignore_ascii_case("sum") && args.len() == 1 { + let arg = &mut args[0]; + if let FunctionArg::Unnamed(FunctionArgExpr::Expr(ast::Expr::Value(value))) = arg { + if let ast::Value::SingleQuotedString(literal) = value { + if literal.parse::().is_ok() { + *value = ast::Value::Number(literal.clone(), false); + return ControlFlow::<()>::Break(()); } } } } - _ => (), }; ControlFlow::<()>::Continue(()) }); @@ -893,51 +907,44 @@ fn rewrite_sum(statement: &mut ast::Statement) { // SQL AST rewirte for format_type(arg) to format_typname((SELECT typname FROM pg_type WHERE oid = arg)) fn rewrite_format_type(statement: &mut ast::Statement) { ast::visit_expressions_mut(statement, |expr: &mut ast::Expr| { - match expr { - ast::Expr::Function(Function { name, args, .. }) => { - if name - .0 - .last() - .unwrap() - .value - .eq_ignore_ascii_case("format_type") - { - if args.len() >= 1 { - let arg = &args[0]; - let sql_expr = format!( - "format_typname((SELECT typname FROM pg_type WHERE oid = {arg}))" - ); - let result = try_parse_sql_expr(&sql_expr); - if let Ok(new_expr) = result { - *expr = new_expr; - } - return ControlFlow::<()>::Break(()); - } + if let ast::Expr::Function(Function { name, args, .. }) = expr { + if name + .0 + .last() + .unwrap() + .value + .eq_ignore_ascii_case("format_type") + && args.len() == 1 + { + let arg = &args[0]; + let sql_expr = + format!("format_typname((SELECT typname FROM pg_type WHERE oid = {arg}))"); + let result = try_parse_sql_expr(&sql_expr); + if let Ok(new_expr) = result { + *expr = new_expr; } + return ControlFlow::<()>::Break(()); } - _ => (), }; ControlFlow::<()>::Continue(()) }); } // SQL AST rewirte for left = ANY(right) to left in (right) -fn rewirte_eq_any(statement: &mut ast::Statement) { +fn rewrite_eq_any(statement: &mut ast::Statement) { ast::visit_expressions_mut(statement, |expr: &mut ast::Expr| { - match expr { - ast::Expr::AnyOp { - left, - compare_op: ast::BinaryOperator::Eq, - right, - } => { - let sql_expr = format!("{left} in ({right})"); - let result = try_parse_sql_expr(&sql_expr); - if let Ok(new_expr) = result { - *expr = new_expr; - } - return ControlFlow::<()>::Break(()); + if let ast::Expr::AnyOp { + left, + compare_op: ast::BinaryOperator::Eq, + right, + } = expr + { + let sql_expr = format!("{left} in ({right})"); + let result = try_parse_sql_expr(&sql_expr); + if let Ok(new_expr) = result { + *expr = new_expr; } - _ => (), + return ControlFlow::<()>::Break(()); }; ControlFlow::<()>::Continue(()) }); diff --git a/dozer-api/src/sql/pgwire.rs b/dozer-api/src/sql/pgwire.rs index 85167e02e2..7f9b35fa14 100644 --- a/dozer-api/src/sql/pgwire.rs +++ b/dozer-api/src/sql/pgwire.rs @@ -44,6 +44,7 @@ use crate::shutdown::ShutdownReceiver; use crate::sql::datafusion::SQLExecutor; use crate::CacheEndpoint; +use super::datafusion::PlannedStatement; use super::util::Iso8601Duration; pub struct PgWireServer { @@ -122,7 +123,7 @@ impl PgWireServer { struct QueryProcessor { sql_executor: Arc, - portal_store: Arc>>, + portal_store: Arc>>, } impl QueryProcessor { @@ -223,10 +224,15 @@ impl SimpleQueryHandler for QueryProcessor { .await .map_err(|e| PgWireError::UserError(Box::new(generic_error_info(e.to_string()))))?; + if queries.is_empty() { + return Ok(vec![Response::EmptyQuery]); + } try_join_all(queries.into_iter().map(|q| async { match q { - Some(plan) => self.execute(plan).await, - None => Ok(Response::Execution(Tag::new_for_execution("", None))), + super::datafusion::PlannedStatement::Query(plan) => self.execute(plan).await, + super::datafusion::PlannedStatement::Statement(name) => { + Ok(Response::Execution(Tag::new_for_execution(name, None))) + } } })) .await @@ -235,7 +241,7 @@ impl SimpleQueryHandler for QueryProcessor { #[async_trait] impl QueryParser for SQLExecutor { - type Statement = Option; + type Statement = Option; async fn parse_sql(&self, sql: &str, _types: &[Type]) -> PgWireResult { let mut plans = self @@ -247,18 +253,16 @@ impl QueryParser for SQLExecutor { "Multiple statements found".to_owned(), )))); } else if plans.is_empty() { - return Err(PgWireError::UserError(Box::new(generic_error_info( - "No statement found".to_owned(), - )))); + Ok(None) } else { - Ok(plans.remove(0)) + Ok(Some(plans.remove(0))) } } } #[async_trait] impl ExtendedQueryHandler for QueryProcessor { - type Statement = Option; + type Statement = Option; type PortalStore = MemPortalStore; type QueryParser = SQLExecutor; @@ -279,8 +283,14 @@ impl ExtendedQueryHandler for QueryProcessor { where C: ClientInfo + Unpin + Send + Sync, { - let Some(query) = portal.statement().statement() else { - return Ok(Response::Execution(Tag::new_for_execution("", None))); + let query = match portal.statement().statement() { + Some(PlannedStatement::Query(query)) => query, + Some(PlannedStatement::Statement(name)) => { + return Ok(Response::Execution(Tag::new_for_execution(name, None))) + } + None => { + return Ok(Response::EmptyQuery); + } }; let _params = query.get_parameter_types().map_err(|e| { PgWireError::UserError(Box::new(ErrorInfo::new( @@ -308,7 +318,7 @@ impl ExtendedQueryHandler for QueryProcessor { { match target { StatementOrPortal::Statement(stmt) => { - let Some(df_stmt) = stmt.statement() else { + let Some(PlannedStatement::Query(query)) = stmt.statement() else { return Ok(DescribeResponse::no_data()); }; let unknown_type = Type::new( @@ -317,7 +327,7 @@ impl ExtendedQueryHandler for QueryProcessor { postgres_types::Kind::Pseudo, "pg_catalog".to_owned(), ); - let types = df_stmt.get_parameter_types().map_err(|e| { + let types = query.get_parameter_types().map_err(|e| { PgWireError::UserError(Box::new(ErrorInfo::new( "FATAL".to_owned(), "XXX01".to_owned(), @@ -337,17 +347,14 @@ impl ExtendedQueryHandler for QueryProcessor { .map_or_else(|| unknown_type.clone(), map_data_type) }) .collect(); - return Ok(DescribeResponse::new( - Some(pg_types), - self.pg_schema(df_stmt), - )); - } - StatementOrPortal::Portal(portal) => { - let Some(stmt) = portal.statement().statement() else { - return Ok(DescribeResponse::no_data()); - }; - return Ok(DescribeResponse::new(None, self.pg_schema(stmt))); + return Ok(DescribeResponse::new(Some(pg_types), self.pg_schema(query))); } + StatementOrPortal::Portal(portal) => match portal.statement().statement() { + Some(PlannedStatement::Query(query)) => { + Ok(DescribeResponse::new(None, self.pg_schema(query))) + } + Some(PlannedStatement::Statement(_)) | None => Ok(DescribeResponse::no_data()), + }, } } }