Skip to content

Commit

Permalink
snowflake: Fix handling of @~% in the stage name
Browse files Browse the repository at this point in the history
  • Loading branch information
lustefaniak committed Oct 26, 2023
1 parent 2f437db commit 6e0588c
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 13 deletions.
4 changes: 1 addition & 3 deletions src/dialect/snowflake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub struct SnowflakeDialect;
impl Dialect for SnowflakeDialect {
// see https://docs.snowflake.com/en/sql-reference/identifiers-syntax.html
fn is_identifier_start(&self, ch: char) -> bool {
ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch == '_' || ch == '@' || ch == '%'
ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch == '_'
}

fn is_identifier_part(&self, ch: char) -> bool {
Expand All @@ -44,8 +44,6 @@ impl Dialect for SnowflakeDialect {
|| ch.is_ascii_digit()
|| ch == '$'
|| ch == '_'
|| ch == '/'
|| ch == '~'
}

fn supports_within_after_array_aggregation(&self) -> bool {
Expand Down
25 changes: 25 additions & 0 deletions src/tokenizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,18 @@ impl<'a> Tokenizer<'a> {
}
}
Some(' ') => Ok(Some(Token::AtSign)),
// Snowflake stage identifier, this should be consumed as multiple dot separated word tokens
Some(_) if dialect_of!(self is SnowflakeDialect) => {
let mut s = "@".to_string();
s.push_str(&peeking_take_while(chars, |ch| {
self.dialect.is_identifier_part(ch)
|| ch == '/'
|| ch == '~'
|| ch == '%'
|| ch == '.'
}));
Ok(Some(Token::make_word(&s, None)))
}
Some(sch) if self.dialect.is_identifier_start('@') => {
self.tokenize_identifier_or_keyword([ch, *sch], chars)
}
Expand Down Expand Up @@ -2001,6 +2013,19 @@ mod tests {
compare(expected, tokens);
}

#[test]
fn tokenize_snowflake_div() {
let sql = r#"field/1000"#;
let dialect = SnowflakeDialect {};
let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
let expected = vec![
Token::make_word(r#"field"#, None),
Token::Div,
Token::Number("1000".to_string(), false),
];
compare(expected, tokens);
}

#[test]
fn tokenize_quoted_identifier_with_no_escape() {
let sql = r#" "a "" b" "a """ "c """"" "#;
Expand Down
30 changes: 20 additions & 10 deletions tests/sqlparser_snowflake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ use test_utils::*;
#[macro_use]
mod test_utils;

#[cfg(test)]
use pretty_assertions::assert_eq;

#[test]
fn test_snowflake_create_table() {
let sql = "CREATE TABLE _my_$table (am00unt number)";
Expand Down Expand Up @@ -917,7 +920,7 @@ fn test_copy_into_with_transformations() {
} => {
assert_eq!(
from_stage,
ObjectName(vec![Ident::new("@schema"), Ident::new("general_finished")])
ObjectName(vec![Ident::new("@schema.general_finished")])
);
assert_eq!(
from_transformations.as_ref().unwrap()[0],
Expand Down Expand Up @@ -1024,15 +1027,9 @@ fn test_snowflake_stage_object_names() {
];
let mut allowed_object_names = vec![
ObjectName(vec![Ident::new("my_company"), Ident::new("emp_basic")]),
ObjectName(vec![Ident::new("@namespace"), Ident::new("%table_name")]),
ObjectName(vec![
Ident::new("@namespace"),
Ident::new("%table_name/path"),
]),
ObjectName(vec![
Ident::new("@namespace"),
Ident::new("stage_name/path"),
]),
ObjectName(vec![Ident::new("@namespace.%table_name")]),
ObjectName(vec![Ident::new("@namespace.%table_name/path")]),
ObjectName(vec![Ident::new("@namespace.stage_name/path")]),
ObjectName(vec![Ident::new("@~/path")]),
];

Expand Down Expand Up @@ -1118,3 +1115,16 @@ fn parse_subquery_function_argument() {
// the function.
snowflake().one_statement_parses_to("SELECT func(SELECT 1, 2)", "SELECT func((SELECT 1, 2))");
}

#[test]
fn parse_division_correctly() {
snowflake_and_generic().one_statement_parses_to(
"SELECT field/1000 FROM tbl1",
"SELECT field / 1000 FROM tbl1",
);

snowflake_and_generic().one_statement_parses_to(
"SELECT tbl1.field/tbl2.field FROM tbl1 JOIN tbl2 ON tbl1.id = tbl2.entity_id",
"SELECT tbl1.field / tbl2.field FROM tbl1 JOIN tbl2 ON tbl1.id = tbl2.entity_id",
);
}

0 comments on commit 6e0588c

Please sign in to comment.