-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsql_parse.py
130 lines (105 loc) · 4.32 KB
/
sql_parse.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# -*- coding: utf-8 -*-
# pylint: disable=C,R,W
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import logging
import sqlparse
from sqlparse.sql import Identifier, IdentifierList
from sqlparse.tokens import Keyword, Name
# Set-operation keywords. When one of these is met while scanning a FROM
# clause, the scanner stops expecting table names until the next
# table-introducing keyword appears.
RESULT_OPERATIONS = {
    'UNION',
    'INTERSECT',
    'EXCEPT',
}
# Keywords after which the following identifier(s) are treated as table
# names. Matching is done by substring against the upper-cased keyword text.
PRECEDES_TABLE_NAME = {
    'FROM',
    'JOIN',
    'DESC',
    'DESCRIBE',
    'WITH',
}
class SupersetQuery(object):
    """Wrap a SQL string and collect the table names it references.

    The statement is parsed with ``sqlparse`` on construction. Identifiers
    that follow a table-introducing keyword (see ``PRECEDES_TABLE_NAME``)
    are recorded as tables; names recorded as aliases of subselects are
    subtracted from that set afterwards.
    """

    def __init__(self, sql_statement):
        """Parse ``sql_statement`` and extract the referenced table names.

        :param sql_statement: string, the SQL text to analyse
        """
        self.sql = sql_statement
        self._table_names = set()
        self._alias_names = set()
        # TODO: multistatement support

        # Lazy logging arguments: the message is only rendered when the
        # INFO level is actually enabled.
        logging.info('Parsing with sqlparse statement %s', self.sql)
        self._parsed = sqlparse.parse(self.sql)
        for statement in self._parsed:
            self.__extract_from_token(statement)
        self._table_names = self._table_names - self._alias_names

    @property
    def tables(self):
        """Return the set of table names found in the statement."""
        return self._table_names

    def is_select(self):
        """Return True when the first parsed statement is a SELECT.

        Returns False when nothing was parsed (for example an empty SQL
        string) instead of failing on the missing first statement.
        """
        if not self._parsed:
            return False
        return self._parsed[0].get_type() == 'SELECT'

    def stripped(self):
        """Return the SQL without surrounding whitespace and semicolons."""
        return self.sql.strip(' \t\n;')

    @staticmethod
    def __precedes_table_name(token_value):
        """Tell whether a keyword introduces table names.

        Substring matching is used so that compound keywords containing one
        of the entries (e.g. ``LEFT OUTER JOIN``) are recognised as well.

        :param token_value: string, upper-cased keyword text
        """
        return any(keyword in token_value for keyword in PRECEDES_TABLE_NAME)

    @staticmethod
    def __get_full_name(identifier):
        """Return ``schema.table`` for dotted identifiers, else the name.

        :param identifier: parsed identifier exposing ``tokens`` and
            ``get_real_name()``
        """
        tokens = identifier.tokens
        # Require three tokens (name, dot, name) before indexing tokens[2].
        if len(tokens) > 2 and tokens[1].value == '.':
            return '{}.{}'.format(tokens[0].value, tokens[2].value)
        return identifier.get_real_name()

    @staticmethod
    def __is_result_operation(keyword):
        """Tell whether ``keyword`` contains a set operation.

        :param keyword: string, keyword text in any letter case
        """
        upper_keyword = keyword.upper()
        return any(
            operation in upper_keyword for operation in RESULT_OPERATIONS)

    @staticmethod
    def __is_identifier(token):
        """Tell whether ``token`` is an identifier or an identifier list."""
        return isinstance(token, (IdentifierList, Identifier))

    def __process_identifier(self, identifier):
        """Record ``identifier`` as a table, or descend into a subselect.

        An identifier whose text contains ``(`` is treated as a subselect:
        its alias is remembered and its tokens are scanned recursively.
        Anything else is recorded as a table name.
        """
        # exclude subselects
        if '(' not in '{}'.format(identifier):
            full_name = self.__get_full_name(identifier)
            # NOTE(review): get_real_name() appears to return None when
            # sqlparse cannot determine a name; keep None out of the set.
            if full_name is not None:
                self._table_names.add(full_name)
            return

        # store aliases
        if hasattr(identifier, 'get_alias'):
            alias = identifier.get_alias()
            # An identifier without an alias must not add None to the set.
            if alias is not None:
                self._alias_names.add(alias)
        if hasattr(identifier, 'tokens'):
            # some aliases are not parsed properly
            if identifier.tokens[0].ttype == Name:
                self._alias_names.add(identifier.tokens[0].value)
        self.__extract_from_token(identifier)

    def as_create_table(self, table_name, overwrite=False):
        """Reformats the query into the create table as query.

        The stripped SQL is wrapped unconditionally; this method does not
        check that it is a single SELECT statement, so callers should
        verify that themselves (e.g. with ``is_select()``).

        NOTE(review): ``table_name`` is interpolated into the SQL text
        without quoting or escaping - confirm callers pass trusted names.

        :param table_name: string, will contain the results of the
            query execution
        :param overwrite: boolean, table table_name will be dropped if true
        :return: string, create table as query
        """
        template = 'CREATE TABLE {table_name} AS \n{sql}'
        if overwrite:
            template = 'DROP TABLE IF EXISTS {table_name};\n' + template
        # Explicit keyword arguments instead of **locals(), so the values
        # used by the template are visible at the call.
        return template.format(table_name=table_name, sql=self.stripped())

    def __extract_from_token(self, token):
        """Walk ``token``'s children and record the tables they reference.

        Tokens without a ``tokens`` attribute are ignored.
        """
        if not hasattr(token, 'tokens'):
            return

        # True while scanning the tokens that follow a table-introducing
        # keyword such as FROM or JOIN.
        table_name_preceding_token = False

        for item in token.tokens:
            # Non-identifier groups are always searched recursively.
            if item.is_group and not self.__is_identifier(item):
                self.__extract_from_token(item)

            if item.ttype in Keyword:
                if self.__precedes_table_name(item.value.upper()):
                    table_name_preceding_token = True
                    continue

            if not table_name_preceding_token:
                continue

            if item.ttype in Keyword:
                if self.__is_result_operation(item.value):
                    table_name_preceding_token = False
                    continue
                # FROM clause is over
                break

            if isinstance(item, Identifier):
                self.__process_identifier(item)

            if isinstance(item, IdentifierList):
                # A distinct loop name avoids shadowing the ``token``
                # parameter.
                for member in item.tokens:
                    if self.__is_identifier(member):
                        self.__process_identifier(member)