Skip to content

Commit

Permalink
Fix(snowflake): refactor location paths (#2668)
Browse files (browse the repository at this point in the history)
* Fix(snowflake): refactor location paths

* Rename connected parser

* Fixup

* Fixup

* PR feedback
  • Loading branch information
georgesittas authored Dec 14, 2023
1 parent 2027841 commit 2ae0deb
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 27 deletions.
47 changes: 22 additions & 25 deletions sqlglot/dialects/snowflake.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -371,30 +371,27 @@ def _parse_lateral(self) -> t.Optional[exp.Lateral]:

def _parse_table_parts(self, schema: bool = False) -> exp.Table:
# https://docs.snowflake.com/en/user-guide/querying-stage
table: t.Optional[exp.Expression] = None
if self._match_text_seq("@"):
table_name = "@"
while self._curr:
self._advance()
table_name += self._prev.text
if not self._match_set(self.STAGED_FILE_SINGLE_TOKENS, advance=False):
break
while self._match_set(self.STAGED_FILE_SINGLE_TOKENS):
table_name += self._prev.text

table = exp.var(table_name)
if self._match_text_seq("@", advance=False):
table: t.Optional[exp.Expression] = self._parse_location_path()
elif self._match(TokenType.STRING, advance=False):
table = self._parse_string()
else:
table = None

if table:
file_format = None
pattern = None

if self._match_text_seq("(", "FILE_FORMAT", "=>"):
file_format = self._parse_string() or super()._parse_table_parts()
if self._match_text_seq(",", "PATTERN", "=>"):
self._match(TokenType.L_PAREN)
while self._curr and not self._match(TokenType.R_PAREN):
if self._match_text_seq("FILE_FORMAT", "=>"):
file_format = self._parse_string() or super()._parse_table_parts()
elif self._match_text_seq("PATTERN", "=>"):
pattern = self._parse_string()
self._match_r_paren()
else:
break

self._match(TokenType.COMMA)

return self.expression(exp.Table, this=table, format=file_format, pattern=pattern)

Expand Down Expand Up @@ -438,17 +435,17 @@ def _parse_alter_table_swap(self) -> exp.SwapTable:

def _parse_location(self) -> exp.LocationProperty:
self._match(TokenType.EQ)
return self.expression(exp.LocationProperty, this=self._parse_location_path())

parts = [self._parse_var(any_token=True)]
def _parse_location_path(self) -> exp.Var:
parts = [self._advance_any(ignore_reserved=True)]

while self._match(TokenType.SLASH):
if self._is_connected():
parts.append(self._parse_var(any_token=True))
else:
parts.append(exp.Var(this=""))
return self.expression(
exp.LocationProperty, this=exp.var("/".join(str(p) for p in parts))
)
# We avoid consuming a comma token because external tables like @foo and @bar
# can be joined in a query with a comma separator.
while self._is_connected() and not self._match(TokenType.COMMA, advance=False):
parts.append(self._advance_any(ignore_reserved=True))

return exp.var("".join(part.text for part in parts if part))

class Tokenizer(tokens.Tokenizer):
STRING_ESCAPES = ["\\", "'"]
Expand Down
4 changes: 2 additions & 2 deletions sqlglot/parser.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4910,8 +4910,8 @@ def _parse_var(
return self.expression(exp.Var, this=self._prev.text)
return self._parse_placeholder()

def _advance_any(self) -> t.Optional[Token]:
if self._curr and self._curr.token_type not in self.RESERVED_TOKENS:
def _advance_any(self, ignore_reserved: bool = False) -> t.Optional[Token]:
if self._curr and (ignore_reserved or self._curr.token_type not in self.RESERVED_TOKENS):
self._advance()
return self._prev
return None
Expand Down
4 changes: 4 additions & 0 deletions tests/dialects/test_snowflake.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -626,6 +626,10 @@ def test_staged_files(self):
"SELECT * FROM @mystage t (c1)",
"SELECT * FROM @mystage AS t(c1)",
)
self.validate_identity(
"SELECT * FROM @foo/bar (PATTERN => 'test', FILE_FORMAT => ds_sandbox.test.my_csv_format) AS bla",
"SELECT * FROM @foo/bar (FILE_FORMAT => ds_sandbox.test.my_csv_format, PATTERN => 'test') AS bla",
)

def test_sample(self):
self.validate_identity("SELECT * FROM testtable TABLESAMPLE BERNOULLI (20.3)")
Expand Down

0 comments on commit 2ae0deb

Please sign in to comment.