From f71f5ecaed50e3fc1b9a361bf6b3df807a908efe Mon Sep 17 00:00:00 2001 From: Shicheng Liu Date: Wed, 8 May 2024 20:50:21 +0000 Subject: [PATCH] Use `_extract_recursive_joins` to handle multiple joins --- .../execute_free_text_sql.py | 33 +++++++++++++++++-- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/src/suql/sql_free_text_support/execute_free_text_sql.py b/src/suql/sql_free_text_support/execute_free_text_sql.py index 48614b5..8d34408 100644 --- a/src/suql/sql_free_text_support/execute_free_text_sql.py +++ b/src/suql/sql_free_text_support/execute_free_text_sql.py @@ -619,7 +619,7 @@ def _retrieve_and_verify( elif len(node.fromClause) == 1 and isinstance(node.fromClause[0], JoinExpr): single_table = False id_list = {} - for arg in [node.fromClause[0].larg, node.fromClause[0].rarg]: + for arg in _extract_recursive_joins(node.fromClause[0]): if isinstance(arg, RangeVar): table_name = arg.relname id_field_name = table_w_ids[table_name] @@ -1133,8 +1133,31 @@ def visit_ColumnRef(self, ancestors: Ancestor, node: ColumnRef): # the same field appears twice, this means that the original syntax is problematic break res = (String(sval=f"{table_name}^{node.fields[0].sval}"),) - node.fields = res + + # do not replace if None, b/c this should be an aliased field + if res is not None: + node.fields = res +def _extract_recursive_joins( + fromClause: JoinExpr +): + """ + A FROM clause of a SelectStmt could have multiple joins. + This function serializes the joins and returns them as a list. 
+ """ + res = [] + if isinstance(fromClause.larg, RangeVar): + res.append(fromClause.larg) + if isinstance(fromClause.rarg, RangeVar): + res.append(fromClause.rarg) + + if isinstance(fromClause.larg, JoinExpr): + res.extend(_extract_recursive_joins(fromClause.larg)) + if isinstance(fromClause.rarg, JoinExpr): + res.extend(_extract_recursive_joins(fromClause.rarg)) + + return res + def _execute_structural_sql( original_node: SelectStmt, @@ -1157,7 +1180,7 @@ def _execute_structural_sql( elif len(node.fromClause) == 1 and isinstance(node.fromClause[0], JoinExpr): all_projection_fields = [] table_column_mapping = {} - for table in [node.fromClause[0].larg, node.fromClause[0].rarg]: + for table in _extract_recursive_joins(node.fromClause[0]): # find out what columns this table has _, columns = execute_sql_with_column_info( RawStream()(SelectStmt(fromClause=(table,), targetList=(ResTarget(val=ColumnRef(fields=(A_Star(),))),))), @@ -1189,6 +1212,8 @@ def _execute_structural_sql( table_column_mapping=table_column_mapping ) replace_original_target_visitor(original_node.targetList) + if original_node.groupClause is not None: + replace_original_target_visitor(original_node.groupClause) if original_node.sortClause is not None: replace_original_target_visitor(original_node.sortClause) # next, there are tuple joins (self joins) @@ -1234,8 +1259,10 @@ def _execute_structural_sql( node.limitOffset = None # change predicates node.whereClause = predicate + # reset other unnecessary clauses node.groupClause = None node.havingClause = None + node.sortClause = None # only queries that involve only structural parts can be executed assert _if_all_structural(node)