Skip to content

Commit

Permalink
snowflake: Fix handling of /~% in the stage name (apache#1009)
Browse files Browse the repository at this point in the history
  • Loading branch information
lustefaniak authored and serprex committed Nov 6, 2023
1 parent 47bd477 commit b256730
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 5 deletions.
48 changes: 43 additions & 5 deletions src/dialect/snowflake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub struct SnowflakeDialect;
impl Dialect for SnowflakeDialect {
// see https://docs.snowflake.com/en/sql-reference/identifiers-syntax.html
fn is_identifier_start(&self, ch: char) -> bool {
ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch == '_' || ch == '@' || ch == '%'
ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch == '_'
}

fn is_identifier_part(&self, ch: char) -> bool {
Expand All @@ -44,8 +44,6 @@ impl Dialect for SnowflakeDialect {
|| ch.is_ascii_digit()
|| ch == '$'
|| ch == '_'
|| ch == '/'
|| ch == '~'
}

fn supports_within_after_array_aggregation(&self) -> bool {
Expand Down Expand Up @@ -148,8 +146,48 @@ pub fn parse_create_stage(
})
}

pub fn parse_stage_name_identifier(parser: &mut Parser) -> Result<Ident, ParserError> {
let mut ident = String::new();
while let Some(next_token) = parser.next_token_no_skip() {
match &next_token.token {
Token::Whitespace(_) => break,
Token::Period => {
parser.prev_token();
break;
}
Token::AtSign => ident.push('@'),
Token::Tilde => ident.push('~'),
Token::Mod => ident.push('%'),
Token::Div => ident.push('/'),
Token::Word(w) => ident.push_str(&w.value),
_ => return parser.expected("stage name identifier", parser.peek_token()),
}
}
Ok(Ident::new(ident))
}

pub fn parse_snowflake_stage_name(parser: &mut Parser) -> Result<ObjectName, ParserError> {
match parser.next_token().token {
Token::AtSign => {
parser.prev_token();
let mut idents = vec![];
loop {
idents.push(parse_stage_name_identifier(parser)?);
if !parser.consume_token(&Token::Period) {
break;
}
}
Ok(ObjectName(idents))
}
_ => {
parser.prev_token();
Ok(parser.parse_object_name()?)
}
}
}

pub fn parse_copy_into(parser: &mut Parser) -> Result<Statement, ParserError> {
let into: ObjectName = parser.parse_object_name()?;
let into: ObjectName = parse_snowflake_stage_name(parser)?;
let mut files: Vec<String> = vec![];
let mut from_transformations: Option<Vec<StageLoadSelectItem>> = None;
let from_stage_alias;
Expand All @@ -165,7 +203,7 @@ pub fn parse_copy_into(parser: &mut Parser) -> Result<Statement, ParserError> {
from_transformations = parse_select_items_for_data_load(parser)?;

parser.expect_keyword(Keyword::FROM)?;
from_stage = parser.parse_object_name()?;
from_stage = parse_snowflake_stage_name(parser)?;
stage_params = parse_stage_params(parser)?;

// as
Expand Down
13 changes: 13 additions & 0 deletions src/tokenizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2001,6 +2001,19 @@ mod tests {
compare(expected, tokens);
}

#[test]
fn tokenize_snowflake_div() {
let sql = r#"field/1000"#;
let dialect = SnowflakeDialect {};
let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
let expected = vec![
Token::make_word(r#"field"#, None),
Token::Div,
Token::Number("1000".to_string(), false),
];
compare(expected, tokens);
}

#[test]
fn tokenize_quoted_identifier_with_no_escape() {
let sql = r#" "a "" b" "a """ "c """"" "#;
Expand Down
16 changes: 16 additions & 0 deletions tests/sqlparser_snowflake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ use test_utils::*;
#[macro_use]
mod test_utils;

#[cfg(test)]
use pretty_assertions::assert_eq;

#[test]
fn test_snowflake_create_table() {
let sql = "CREATE TABLE _my_$table (am00unt number)";
Expand Down Expand Up @@ -1118,3 +1121,16 @@ fn parse_subquery_function_argument() {
// the function.
snowflake().one_statement_parses_to("SELECT func(SELECT 1, 2)", "SELECT func((SELECT 1, 2))");
}

#[test]
fn parse_division_correctly() {
snowflake_and_generic().one_statement_parses_to(
"SELECT field/1000 FROM tbl1",
"SELECT field / 1000 FROM tbl1",
);

snowflake_and_generic().one_statement_parses_to(
"SELECT tbl1.field/tbl2.field FROM tbl1 JOIN tbl2 ON tbl1.id = tbl2.entity_id",
"SELECT tbl1.field / tbl2.field FROM tbl1 JOIN tbl2 ON tbl1.id = tbl2.entity_id",
);
}

0 comments on commit b256730

Please sign in to comment.