Skip to content

Commit

Permalink
Correctness fixes for bugs exposed by PG JDBC
Browse files Browse the repository at this point in the history
  • Loading branch information
Jesse-Bakker committed Dec 4, 2023
1 parent 27d5f0a commit fa38372
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 95 deletions.
8 changes: 4 additions & 4 deletions dozer-api/src/rest/api_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use openapiv3::OpenAPI;
use crate::api_helper::{get_record, get_records, get_records_count};
use crate::generator::oapi::generator::OpenApiGenerator;
use crate::sql::datafusion::json::record_batches_to_json_rows;
use crate::sql::datafusion::SQLExecutor;
use crate::sql::datafusion::{PlannedStatement, SQLExecutor};
use crate::CacheEndpoint;
use crate::{auth::Access, errors::ApiError};
use dozer_types::grpc_types::health::health_check_response::ServingStatus;
Expand Down Expand Up @@ -186,16 +186,16 @@ pub async fn sql(
sql: extractor::SQLQueryExtractor,
) -> Result<actix_web::HttpResponse, crate::errors::ApiError> {
let query = sql.0 .0;
let plan = sql_executor
let planned = sql_executor
.parse(&query)
.await
.map_err(ApiError::SQLQueryFailed)?;
if plan.len() > 1 {
if planned.len() > 1 {
return Err(ApiError::SQLQueryFailed(plan_datafusion_err!(
"More than one query supplied"
)));
}
let Some(Some(plan)) = plan.first().cloned() else {
let Some(PlannedStatement::Query(plan)) = planned.first().cloned() else {
// This was a transaction statement, which doesn't require a result
return Ok(HttpResponse::Ok().json(json!({})));
};
Expand Down
143 changes: 75 additions & 68 deletions dozer-api/src/sql/datafusion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,16 @@ use crate::CacheEndpoint;

use predicate_pushdown::{predicate_pushdown, supports_predicates_pushdown};

pub struct SQLExecutor {
pub(crate) struct SQLExecutor {
ctx: Arc<SessionContext>,
}

#[derive(Clone)]
pub(crate) enum PlannedStatement {
Query(LogicalPlan),
Statement(&'static str),
}

struct ContextResolver {
tables: HashMap<String, Arc<dyn TableSource>>,
state: Arc<SessionState>,
Expand Down Expand Up @@ -548,15 +554,19 @@ impl SQLExecutor {
async fn parse_statement(
&self,
mut statement: Statement,
) -> Result<Option<LogicalPlan>, DataFusionError> {
) -> Result<PlannedStatement, DataFusionError> {
let rewrite = if let Statement::Statement(ref stmt) = statement {
match stmt.as_ref() {
ast::Statement::StartTransaction { .. }
| ast::Statement::Commit { .. }
| ast::Statement::Rollback { .. }
| ast::Statement::SetVariable { .. } => {
ast::Statement::StartTransaction { .. } => {
return Ok(PlannedStatement::Statement("BEGIN"))
}
ast::Statement::Commit { .. } => return Ok(PlannedStatement::Statement("COMMIT")),
ast::Statement::Rollback { .. } => {
return Ok(PlannedStatement::Statement("ROLLBACK"))
}
ast::Statement::SetVariable { .. } => {
// dbg!(stmt);
return Ok(None);
return Ok(PlannedStatement::Statement("SET"));
}
ast::Statement::ShowVariable { variable } => {
let variable = object_name_to_string(variable);
Expand Down Expand Up @@ -587,14 +597,13 @@ impl SQLExecutor {
ContextResolver::try_new_for_statement(Arc::new(state), &statement).await?;
let planner = SqlToRel::new(&context_provider);
let plan = planner.statement_to_plan(statement)?;
let options = SQLOptions::new()
.with_allow_ddl(false)
.with_allow_dml(false);
// Some BI tools use temporary tables. Let them
let options = SQLOptions::new().with_allow_ddl(true).with_allow_dml(true);
options.verify_plan(&plan)?;
Ok(Some(plan))
Ok(PlannedStatement::Query(plan))
}

pub async fn parse(&self, mut sql: &str) -> Result<Vec<Option<LogicalPlan>>, DataFusionError> {
pub async fn parse(&self, mut sql: &str) -> Result<Vec<PlannedStatement>, DataFusionError> {
println!("@@ query: {sql}");
if sql
.to_ascii_lowercase()
Expand Down Expand Up @@ -812,6 +821,7 @@ impl VarProvider for SystemVariables {
return Ok(ScalarValue::Utf8(Some("read committed".into())))
}
"@@standard_conforming_strings" => return Ok(ScalarValue::Utf8(Some("on".into()))),
"@@lc_collate" => return Ok(ScalarValue::Utf8(Some("en_US.utf8".into()))),
_ => (),
}
}
Expand All @@ -823,7 +833,7 @@ impl VarProvider for SystemVariables {
fn get_type(&self, var_names: &[String]) -> Option<DataType> {
if var_names.len() == 1 {
match var_names[0].as_str() {
"@@transaction_isolation" | "@@standard_conforming_strings" => {
"@@transaction_isolation" | "@@standard_conforming_strings" | "@@lc_collate" => {
return Some(DataType::Utf8)
}
_ => (),
Expand Down Expand Up @@ -856,35 +866,39 @@ fn normalize_ident(id: ast::Ident) -> String {
fn sql_ast_rewrites(statement: &mut ast::Statement) {
rewrite_sum(statement);
rewrite_format_type(statement);
rewirte_eq_any(statement);
rewrite_eq_any(statement);
rewrite_cast_to_regclass(statement);
}

fn rewrite_cast_to_regclass(statement: &mut ast::Statement) {
ast::visit_expressions_mut(statement, |cast: &mut ast::Expr| {
if let ast::Expr::Cast {
data_type: ast::DataType::Regclass,
..
} = cast
{
*cast = ast::Expr::Value(ast::Value::Number("0".to_owned(), false));
}
ControlFlow::<()>::Continue(())
});
}

// SQL AST rewirte for SUM('1') to SUM(1)
fn rewrite_sum(statement: &mut ast::Statement) {
ast::visit_expressions_mut(statement, |expr: &mut ast::Expr| {
match expr {
ast::Expr::Function(Function { name, args, .. }) => {
let name = &name.0;
if name.len() == 1 && name[0].value.eq_ignore_ascii_case("sum") {
if args.len() == 1 {
let arg = &mut args[0];
match arg {
FunctionArg::Unnamed(FunctionArgExpr::Expr(ast::Expr::Value(
value,
))) => {
if let ast::Value::SingleQuotedString(literal) = value {
if literal.parse::<i64>().is_ok() {
*value = ast::Value::Number(literal.clone(), false);
return ControlFlow::<()>::Break(());
}
}
}
_ => (),
if let ast::Expr::Function(Function { name, args, .. }) = expr {
let name = &name.0;
if name.len() == 1 && name[0].value.eq_ignore_ascii_case("sum") && args.len() == 1 {
let arg = &mut args[0];
if let FunctionArg::Unnamed(FunctionArgExpr::Expr(ast::Expr::Value(value))) = arg {
if let ast::Value::SingleQuotedString(literal) = value {
if literal.parse::<i64>().is_ok() {
*value = ast::Value::Number(literal.clone(), false);
return ControlFlow::<()>::Break(());
}
}
}
}
_ => (),
};
ControlFlow::<()>::Continue(())
});
Expand All @@ -893,51 +907,44 @@ fn rewrite_sum(statement: &mut ast::Statement) {
// SQL AST rewirte for format_type(arg) to format_typname((SELECT typname FROM pg_type WHERE oid = arg))
fn rewrite_format_type(statement: &mut ast::Statement) {
ast::visit_expressions_mut(statement, |expr: &mut ast::Expr| {
match expr {
ast::Expr::Function(Function { name, args, .. }) => {
if name
.0
.last()
.unwrap()
.value
.eq_ignore_ascii_case("format_type")
{
if args.len() >= 1 {
let arg = &args[0];
let sql_expr = format!(
"format_typname((SELECT typname FROM pg_type WHERE oid = {arg}))"
);
let result = try_parse_sql_expr(&sql_expr);
if let Ok(new_expr) = result {
*expr = new_expr;
}
return ControlFlow::<()>::Break(());
}
if let ast::Expr::Function(Function { name, args, .. }) = expr {
if name
.0
.last()
.unwrap()
.value
.eq_ignore_ascii_case("format_type")
&& args.len() == 1
{
let arg = &args[0];
let sql_expr =
format!("format_typname((SELECT typname FROM pg_type WHERE oid = {arg}))");
let result = try_parse_sql_expr(&sql_expr);
if let Ok(new_expr) = result {
*expr = new_expr;
}
return ControlFlow::<()>::Break(());
}
_ => (),
};
ControlFlow::<()>::Continue(())
});
}

// SQL AST rewirte for left = ANY(right) to left in (right)
fn rewirte_eq_any(statement: &mut ast::Statement) {
fn rewrite_eq_any(statement: &mut ast::Statement) {
ast::visit_expressions_mut(statement, |expr: &mut ast::Expr| {
match expr {
ast::Expr::AnyOp {
left,
compare_op: ast::BinaryOperator::Eq,
right,
} => {
let sql_expr = format!("{left} in ({right})");
let result = try_parse_sql_expr(&sql_expr);
if let Ok(new_expr) = result {
*expr = new_expr;
}
return ControlFlow::<()>::Break(());
if let ast::Expr::AnyOp {
left,
compare_op: ast::BinaryOperator::Eq,
right,
} = expr
{
let sql_expr = format!("{left} in ({right})");
let result = try_parse_sql_expr(&sql_expr);
if let Ok(new_expr) = result {
*expr = new_expr;
}
_ => (),
return ControlFlow::<()>::Break(());
};
ControlFlow::<()>::Continue(())
});
Expand Down
53 changes: 30 additions & 23 deletions dozer-api/src/sql/pgwire.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ use crate::shutdown::ShutdownReceiver;
use crate::sql::datafusion::SQLExecutor;
use crate::CacheEndpoint;

use super::datafusion::PlannedStatement;
use super::util::Iso8601Duration;

pub struct PgWireServer {
Expand Down Expand Up @@ -122,7 +123,7 @@ impl PgWireServer {

struct QueryProcessor {
sql_executor: Arc<SQLExecutor>,
portal_store: Arc<MemPortalStore<Option<LogicalPlan>>>,
portal_store: Arc<MemPortalStore<Option<PlannedStatement>>>,
}

impl QueryProcessor {
Expand Down Expand Up @@ -223,10 +224,15 @@ impl SimpleQueryHandler for QueryProcessor {
.await
.map_err(|e| PgWireError::UserError(Box::new(generic_error_info(e.to_string()))))?;

if queries.is_empty() {
return Ok(vec![Response::EmptyQuery]);
}
try_join_all(queries.into_iter().map(|q| async {
match q {
Some(plan) => self.execute(plan).await,
None => Ok(Response::Execution(Tag::new_for_execution("", None))),
super::datafusion::PlannedStatement::Query(plan) => self.execute(plan).await,
super::datafusion::PlannedStatement::Statement(name) => {
Ok(Response::Execution(Tag::new_for_execution(name, None)))
}
}
}))
.await
Expand All @@ -235,7 +241,7 @@ impl SimpleQueryHandler for QueryProcessor {

#[async_trait]
impl QueryParser for SQLExecutor {
type Statement = Option<LogicalPlan>;
type Statement = Option<PlannedStatement>;

async fn parse_sql(&self, sql: &str, _types: &[Type]) -> PgWireResult<Self::Statement> {
let mut plans = self
Expand All @@ -247,18 +253,16 @@ impl QueryParser for SQLExecutor {
"Multiple statements found".to_owned(),
))));
} else if plans.is_empty() {
return Err(PgWireError::UserError(Box::new(generic_error_info(
"No statement found".to_owned(),
))));
Ok(None)
} else {
Ok(plans.remove(0))
Ok(Some(plans.remove(0)))
}
}
}

#[async_trait]
impl ExtendedQueryHandler for QueryProcessor {
type Statement = Option<LogicalPlan>;
type Statement = Option<PlannedStatement>;
type PortalStore = MemPortalStore<Self::Statement>;
type QueryParser = SQLExecutor;

Expand All @@ -279,8 +283,14 @@ impl ExtendedQueryHandler for QueryProcessor {
where
C: ClientInfo + Unpin + Send + Sync,
{
let Some(query) = portal.statement().statement() else {
return Ok(Response::Execution(Tag::new_for_execution("", None)));
let query = match portal.statement().statement() {
Some(PlannedStatement::Query(query)) => query,
Some(PlannedStatement::Statement(name)) => {
return Ok(Response::Execution(Tag::new_for_execution(name, None)))
}
None => {
return Ok(Response::EmptyQuery);
}
};
let _params = query.get_parameter_types().map_err(|e| {
PgWireError::UserError(Box::new(ErrorInfo::new(
Expand Down Expand Up @@ -308,7 +318,7 @@ impl ExtendedQueryHandler for QueryProcessor {
{
match target {
StatementOrPortal::Statement(stmt) => {
let Some(df_stmt) = stmt.statement() else {
let Some(PlannedStatement::Query(query)) = stmt.statement() else {
return Ok(DescribeResponse::no_data());
};
let unknown_type = Type::new(
Expand All @@ -317,7 +327,7 @@ impl ExtendedQueryHandler for QueryProcessor {
postgres_types::Kind::Pseudo,
"pg_catalog".to_owned(),
);
let types = df_stmt.get_parameter_types().map_err(|e| {
let types = query.get_parameter_types().map_err(|e| {
PgWireError::UserError(Box::new(ErrorInfo::new(
"FATAL".to_owned(),
"XXX01".to_owned(),
Expand All @@ -337,17 +347,14 @@ impl ExtendedQueryHandler for QueryProcessor {
.map_or_else(|| unknown_type.clone(), map_data_type)
})
.collect();
return Ok(DescribeResponse::new(
Some(pg_types),
self.pg_schema(df_stmt),
));
}
StatementOrPortal::Portal(portal) => {
let Some(stmt) = portal.statement().statement() else {
return Ok(DescribeResponse::no_data());
};
return Ok(DescribeResponse::new(None, self.pg_schema(stmt)));
return Ok(DescribeResponse::new(Some(pg_types), self.pg_schema(query)));
}
StatementOrPortal::Portal(portal) => match portal.statement().statement() {
Some(PlannedStatement::Query(query)) => {
Ok(DescribeResponse::new(None, self.pg_schema(query)))
}
Some(PlannedStatement::Statement(_)) | None => Ok(DescribeResponse::no_data()),
},
}
}
}
Expand Down

0 comments on commit fa38372

Please sign in to comment.