Skip to content

Commit

Permalink
fix cursor visit
Browse files Browse the repository at this point in the history
  • Loading branch information
serprex committed Apr 26, 2024
1 parent 54a99c9 commit 74fa074
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 36 deletions.
40 changes: 26 additions & 14 deletions nexus/analyzer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,26 +53,38 @@ impl<'a> StatementAnalyzer for PeerExistanceAnalyzer<'a> {

fn analyze(&self, statement: &Statement) -> anyhow::Result<Self::Output> {
let mut peers_touched: HashSet<String> = HashSet::new();
let mut analyze_name = |name: &str| {
let name = name.to_lowercase();
if self.peers.contains_key(&name) {
peers_touched.insert(name);
}
};

// This is necessary as visit relations was not visiting drop table's object names,
// causing DROP commands for Postgres peer being interpreted as
// catalog queries.
// Necessary as visit_relations fails to deeply visit some structures.
visit_statements(statement, |stmt| {
if let &Statement::Drop { names, .. } = &stmt {
for name in names {
let peer_name = name.0[0].value.to_lowercase();
if self.peers.contains_key(&peer_name) {
peers_touched.insert(peer_name);
match stmt {
Statement::Drop { names, .. } => {
for name in names {
analyze_name(&name.0[0].value);
}
}
Statement::Declare { stmts } => {
for stmt in stmts {
if let Some(ref query) = stmt.for_query {
visit_relations(query, |relation| {
analyze_name(&relation.0[0].value);
ControlFlow::<()>::Continue(())
});
}
}
}
_ => (),
}
ControlFlow::<()>::Continue(())
});

visit_relations(statement, |relation| {
let peer_name = relation.0[0].value.to_lowercase();
if self.peers.contains_key(&peer_name) {
peers_touched.insert(peer_name);
}
analyze_name(&relation.0[0].value);
ControlFlow::<()>::Continue(())
});

Expand Down Expand Up @@ -476,15 +488,15 @@ impl StatementAnalyzer for PeerCursorAnalyzer {
}
| FetchDirection::Forward {
limit: Some(ast::Value::Number(n, _)),
} => n.parse::<usize>(),
} => n.parse::<usize>()?,
_ => {
return Err(anyhow::anyhow!(
"invalid fetch direction for cursor: {:?}",
direction
))
}
};
Ok(Some(CursorEvent::Fetch(name.value.clone(), count?)))
Ok(Some(CursorEvent::Fetch(name.value.clone(), count)))
}
Statement::Close { cursor } => match cursor {
ast::CloseCursor::All => Ok(Some(CursorEvent::CloseAll)),
Expand Down
3 changes: 1 addition & 2 deletions nexus/parser/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ impl NexusStatement {
});
}

let peer_cursor: PeerCursorAnalyzer = Default::default();
if let Ok(Some(cursor)) = peer_cursor.analyze(stmt) {
if let Ok(Some(cursor)) = PeerCursorAnalyzer.analyze(stmt) {
return Ok(NexusStatement::PeerCursor {
stmt: stmt.clone(),
cursor,
Expand Down
3 changes: 1 addition & 2 deletions nexus/peer-bigquery/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ use sqlparser::ast::{
FunctionArgExpr, Ident, ObjectName, Query, SetExpr, SetOperator, SetQuantifier, TimezoneInfo,
};

#[derive(Default)]
pub struct BigqueryAst {}
pub struct BigqueryAst;

impl BigqueryAst {
pub fn is_timestamp_returning_function(&self, name: &str) -> bool {
Expand Down
6 changes: 2 additions & 4 deletions nexus/peer-bigquery/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ impl QueryExecutor for BigQueryQueryExecutor {
match stmt {
Statement::Query(query) => {
let mut query = query.clone();
let bq_ast = ast::BigqueryAst::default();
bq_ast
ast::BigqueryAst
.rewrite(&self.dataset_id, &mut query)
.context("unable to rewrite query")
.map_err(|err| PgWireError::ApiError(err.into()))?;
Expand Down Expand Up @@ -206,8 +205,7 @@ impl QueryExecutor for BigQueryQueryExecutor {
match stmt {
Statement::Query(query) => {
let mut query = query.clone();
let bq_ast = ast::BigqueryAst::default();
bq_ast
ast::BigqueryAst
.rewrite(&self.dataset_id, &mut query)
.context("unable to rewrite query")
.map_err(|err| PgWireError::ApiError(err.into()))?;
Expand Down
3 changes: 1 addition & 2 deletions nexus/peer-snowflake/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ use sqlparser::ast::{
FunctionArg, FunctionArgExpr, Ident, JsonOperator, ObjectName, Query, Statement, TimezoneInfo,
};

#[derive(Default)]
pub struct SnowflakeAst {}
pub struct SnowflakeAst;

impl SnowflakeAst {
pub fn rewrite(&self, query: &mut Query) -> anyhow::Result<()> {
Expand Down
11 changes: 3 additions & 8 deletions nexus/peer-snowflake/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,7 @@ impl SnowflakeQueryExecutor {
pub async fn query(&self, query: &Query) -> PgWireResult<ResultSet> {
let mut query = query.clone();

let ast = ast::SnowflakeAst::default();
let _ = ast.rewrite(&mut query);
let _ = ast::SnowflakeAst.rewrite(&mut query);

let query_str: String = query.to_string();
info!("Processing SnowFlake query: {}", query_str);
Expand Down Expand Up @@ -299,8 +298,7 @@ impl QueryExecutor for SnowflakeQueryExecutor {
Statement::Query(query) => {
let mut new_query = query.clone();

let snowflake_ast = ast::SnowflakeAst::default();
snowflake_ast
ast::SnowflakeAst
.rewrite(&mut new_query)
.context("unable to rewrite query")
.map_err(|err| PgWireError::ApiError(err.into()))?;
Expand Down Expand Up @@ -402,14 +400,11 @@ impl QueryExecutor for SnowflakeQueryExecutor {
match stmt {
Statement::Query(query) => {
let mut new_query = query.clone();
let sf_ast = ast::SnowflakeAst::default();
sf_ast
ast::SnowflakeAst
.rewrite(&mut new_query)
.context("unable to rewrite query")
.map_err(|err| PgWireError::ApiError(err.into()))?;

// new_query.limit = Some(Expr::Value(Value::Number("1".to_owned(), false)));

let result_set = self.query(&new_query).await?;
let schema = SnowflakeSchema::from_result_set(&result_set);

Expand Down
5 changes: 1 addition & 4 deletions nexus/server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,7 @@ impl AuthSource for FixedPasswordAuthSource {
let salt = rand::thread_rng().gen::<[u8; 4]>();
let password = &self.password;
let hash_password = hash_md5_password(login_info.user().unwrap_or(""), password, &salt);
Ok(Password::new(
Some(salt.to_vec()),
hash_password.as_bytes().to_vec(),
))
Ok(Password::new(Some(salt.to_vec()), Vec::from(hash_password)))
}
}

Expand Down

0 comments on commit 74fa074

Please sign in to comment.