From 74fa0748f87cf71bb63d02292878e0ccdabadbf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Fri, 26 Apr 2024 20:05:27 +0000 Subject: [PATCH] fix cursor visit --- nexus/analyzer/src/lib.rs | 40 +++++++++++++++++++++------------ nexus/parser/src/lib.rs | 3 +-- nexus/peer-bigquery/src/ast.rs | 3 +-- nexus/peer-bigquery/src/lib.rs | 6 ++--- nexus/peer-snowflake/src/ast.rs | 3 +-- nexus/peer-snowflake/src/lib.rs | 11 +++------ nexus/server/src/main.rs | 5 +---- 7 files changed, 35 insertions(+), 36 deletions(-) diff --git a/nexus/analyzer/src/lib.rs b/nexus/analyzer/src/lib.rs index 9e3e5b7524..ec1ebf57cb 100644 --- a/nexus/analyzer/src/lib.rs +++ b/nexus/analyzer/src/lib.rs @@ -53,26 +53,38 @@ impl<'a> StatementAnalyzer for PeerExistanceAnalyzer<'a> { fn analyze(&self, statement: &Statement) -> anyhow::Result { let mut peers_touched: HashSet = HashSet::new(); + let mut analyze_name = |name: &str| { + let name = name.to_lowercase(); + if self.peers.contains_key(&name) { + peers_touched.insert(name); + } + }; - // This is necessary as visit relations was not visiting drop table's object names, - // causing DROP commands for Postgres peer being interpreted as - // catalog queries. + // Necessary as visit_relations fails to deeply visit some structures. visit_statements(statement, |stmt| { - if let &Statement::Drop { names, .. } = &stmt { - for name in names { - let peer_name = name.0[0].value.to_lowercase(); - if self.peers.contains_key(&peer_name) { - peers_touched.insert(peer_name); + match stmt { + Statement::Drop { names, .. } => { + for name in names { + analyze_name(&name.0[0].value); } } + Statement::Declare { stmts } => { + for stmt in stmts { + if let Some(ref query) = stmt.for_query { + visit_relations(query, |relation| { + analyze_name(&relation.0[0].value); + ControlFlow::<()>::Continue(()) + }); + } + } + } + _ => (), } ControlFlow::<()>::Continue(()) }); + visit_relations(statement, |relation| { - let peer_name = relation.0[0].value.to_lowercase(); - if self.peers.contains_key(&peer_name) { - peers_touched.insert(peer_name); - } + analyze_name(&relation.0[0].value); ControlFlow::<()>::Continue(()) }); @@ -476,7 +488,7 @@ impl StatementAnalyzer for PeerCursorAnalyzer { } | FetchDirection::Forward { limit: Some(ast::Value::Number(n, _)), - } => n.parse::(), + } => n.parse::()?, _ => { return Err(anyhow::anyhow!( "invalid fetch direction for cursor: {:?}", @@ -484,7 +496,7 @@ impl StatementAnalyzer for PeerCursorAnalyzer { )) } }; - Ok(Some(CursorEvent::Fetch(name.value.clone(), count?))) + Ok(Some(CursorEvent::Fetch(name.value.clone(), count))) } Statement::Close { cursor } => match cursor { ast::CloseCursor::All => Ok(Some(CursorEvent::CloseAll)), diff --git a/nexus/parser/src/lib.rs b/nexus/parser/src/lib.rs index 491693a4c5..4f5cad356a 100644 --- a/nexus/parser/src/lib.rs +++ b/nexus/parser/src/lib.rs @@ -59,8 +59,7 @@ impl NexusStatement { }); } - let peer_cursor: PeerCursorAnalyzer = Default::default(); - if let Ok(Some(cursor)) = peer_cursor.analyze(stmt) { + if let Ok(Some(cursor)) = PeerCursorAnalyzer.analyze(stmt) { return Ok(NexusStatement::PeerCursor { stmt: stmt.clone(), cursor, diff --git a/nexus/peer-bigquery/src/ast.rs b/nexus/peer-bigquery/src/ast.rs index 8429e0ebe1..15e5efe5a4 100644 --- a/nexus/peer-bigquery/src/ast.rs +++ b/nexus/peer-bigquery/src/ast.rs @@ -8,8 +8,7 @@ use sqlparser::ast::{ FunctionArgExpr, Ident, ObjectName, Query, SetExpr, SetOperator, SetQuantifier, TimezoneInfo, }; -#[derive(Default)] -pub struct BigqueryAst {} +pub struct BigqueryAst; impl BigqueryAst { pub fn is_timestamp_returning_function(&self, name: &str) -> bool { diff --git a/nexus/peer-bigquery/src/lib.rs b/nexus/peer-bigquery/src/lib.rs index 23216df671..7a3aca9646 100644 --- a/nexus/peer-bigquery/src/lib.rs +++ b/nexus/peer-bigquery/src/lib.rs @@ -97,8 +97,7 @@ impl QueryExecutor for BigQueryQueryExecutor { match stmt { Statement::Query(query) => { let mut query = query.clone(); - let bq_ast = ast::BigqueryAst::default(); - bq_ast + ast::BigqueryAst .rewrite(&self.dataset_id, &mut query) .context("unable to rewrite query") .map_err(|err| PgWireError::ApiError(err.into()))?; @@ -206,8 +205,7 @@ impl QueryExecutor for BigQueryQueryExecutor { match stmt { Statement::Query(query) => { let mut query = query.clone(); - let bq_ast = ast::BigqueryAst::default(); - bq_ast + ast::BigqueryAst .rewrite(&self.dataset_id, &mut query) .context("unable to rewrite query") .map_err(|err| PgWireError::ApiError(err.into()))?; diff --git a/nexus/peer-snowflake/src/ast.rs b/nexus/peer-snowflake/src/ast.rs index 3dddd577df..0934ec5592 100644 --- a/nexus/peer-snowflake/src/ast.rs +++ b/nexus/peer-snowflake/src/ast.rs @@ -5,8 +5,7 @@ use sqlparser::ast::{ FunctionArg, FunctionArgExpr, Ident, JsonOperator, ObjectName, Query, Statement, TimezoneInfo, }; -#[derive(Default)] -pub struct SnowflakeAst {} +pub struct SnowflakeAst; impl SnowflakeAst { pub fn rewrite(&self, query: &mut Query) -> anyhow::Result<()> { diff --git a/nexus/peer-snowflake/src/lib.rs b/nexus/peer-snowflake/src/lib.rs index fec5c074c6..af1eab1944 100644 --- a/nexus/peer-snowflake/src/lib.rs +++ b/nexus/peer-snowflake/src/lib.rs @@ -198,8 +198,7 @@ impl SnowflakeQueryExecutor { pub async fn query(&self, query: &Query) -> PgWireResult { let mut query = query.clone(); - let ast = ast::SnowflakeAst::default(); - let _ = ast.rewrite(&mut query); + let _ = ast::SnowflakeAst.rewrite(&mut query); let query_str: String = query.to_string(); info!("Processing SnowFlake query: {}", query_str); @@ -299,8 +298,7 @@ impl QueryExecutor for SnowflakeQueryExecutor { Statement::Query(query) => { let mut new_query = query.clone(); - let snowflake_ast = ast::SnowflakeAst::default(); - snowflake_ast + ast::SnowflakeAst .rewrite(&mut new_query) .context("unable to rewrite query") .map_err(|err| PgWireError::ApiError(err.into()))?; @@ -402,14 +400,11 @@ impl QueryExecutor for SnowflakeQueryExecutor { match stmt { Statement::Query(query) => { let mut new_query = query.clone(); - let sf_ast = ast::SnowflakeAst::default(); - sf_ast + ast::SnowflakeAst .rewrite(&mut new_query) .context("unable to rewrite query") .map_err(|err| PgWireError::ApiError(err.into()))?; - // new_query.limit = Some(Expr::Value(Value::Number("1".to_owned(), false))); - let result_set = self.query(&new_query).await?; let schema = SnowflakeSchema::from_result_set(&result_set); diff --git a/nexus/server/src/main.rs b/nexus/server/src/main.rs index e9e2597cfe..d7a0016b98 100644 --- a/nexus/server/src/main.rs +++ b/nexus/server/src/main.rs @@ -66,10 +66,7 @@ impl AuthSource for FixedPasswordAuthSource { let salt = rand::thread_rng().gen::<[u8; 4]>(); let password = &self.password; let hash_password = hash_md5_password(login_info.user().unwrap_or(""), password, &salt); - Ok(Password::new( - Some(salt.to_vec()), - hash_password.as_bytes().to_vec(), - )) + Ok(Password::new(Some(salt.to_vec()), Vec::from(hash_password))) } }