Skip to content

Commit

Permalink
Use _extract_recursive_joins to handle multiple joins
Browse files Browse the repository at this point in the history
  • Loading branch information
george1459 committed May 8, 2024
1 parent 2de9c11 commit f71f5ec
Showing 1 changed file with 30 additions and 3 deletions.
33 changes: 30 additions & 3 deletions src/suql/sql_free_text_support/execute_free_text_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ def _retrieve_and_verify(
elif len(node.fromClause) == 1 and isinstance(node.fromClause[0], JoinExpr):
single_table = False
id_list = {}
for arg in [node.fromClause[0].larg, node.fromClause[0].rarg]:
for arg in _extract_recursive_joins(node.fromClause[0]):
if isinstance(arg, RangeVar):
table_name = arg.relname
id_field_name = table_w_ids[table_name]
Expand Down Expand Up @@ -1133,8 +1133,31 @@ def visit_ColumnRef(self, ancestors: Ancestor, node: ColumnRef):
# the same field appears twice, this means that the original syntax is problematic
break
res = (String(sval=f"{table_name}^{node.fields[0].sval}"),)
node.fields = res

# do not replace if None, b/c this should be an aliased field
if res is not None:
node.fields = res

def _extract_recursive_joins(
fromClause: JoinExpr
):
"""
A FROM clause of a SelectStmt could have multiple joins.
This functions searilizes the joins and returns them as a list.
"""
res = []
if isinstance(fromClause.larg, RangeVar):
res.append(fromClause.larg)
if isinstance(fromClause.rarg, RangeVar):
res.append(fromClause.rarg)

if isinstance(fromClause.larg, JoinExpr):
res.extend(_extract_recursive_joins(fromClause.larg))
if isinstance(fromClause.rarg, JoinExpr):
res.extend(_extract_recursive_joins(fromClause.rarg))

return res


def _execute_structural_sql(
original_node: SelectStmt,
Expand All @@ -1157,7 +1180,7 @@ def _execute_structural_sql(
elif len(node.fromClause) == 1 and isinstance(node.fromClause[0], JoinExpr):
all_projection_fields = []
table_column_mapping = {}
for table in [node.fromClause[0].larg, node.fromClause[0].rarg]:
for table in _extract_recursive_joins(node.fromClause[0]):
# find out what columns this table has
_, columns = execute_sql_with_column_info(
RawStream()(SelectStmt(fromClause=(table,), targetList=(ResTarget(val=ColumnRef(fields=(A_Star(),))),))),
Expand Down Expand Up @@ -1189,6 +1212,8 @@ def _execute_structural_sql(
table_column_mapping=table_column_mapping
)
replace_original_target_visitor(original_node.targetList)
if original_node.groupClause is not None:
replace_original_target_visitor(original_node.groupClause)
if original_node.sortClause is not None:
replace_original_target_visitor(original_node.sortClause)
# next, there are tuple joins (self joins)
Expand Down Expand Up @@ -1234,8 +1259,10 @@ def _execute_structural_sql(
node.limitOffset = None
# change predicates
node.whereClause = predicate
# reset other unecessary clauses
node.groupClause = None
node.havingClause = None
node.sortClause = None

# only queries that involve only structural parts can be executed
assert _if_all_structural(node)
Expand Down

0 comments on commit f71f5ec

Please sign in to comment.