diff --git a/setup.py b/setup.py index df546c3..613e10f 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ # Package metadata name = "suql" -version = "1.1.7a5" +version = "1.1.7a6" description = "Structured and Unstructured Query Language (SUQL) Python API" author = "Shicheng Liu" author_email = "shicheng@cs.stanford.edu" diff --git a/src/suql/faiss_embedding.py b/src/suql/faiss_embedding.py index d0a11fe..291b2c2 100644 --- a/src/suql/faiss_embedding.py +++ b/src/suql/faiss_embedding.py @@ -285,11 +285,12 @@ def dot_product(self, id_list, query, top, individual_id_list=[]): for sublist in map(lambda x: self.id2document[x], individual_id_list) for item in sublist ] - embedding_indices = [ + # remove potential duplicates here + embedding_indices = list(dict.fromkeys([ item for sublist in map(lambda x: self.document2embedding[x], document_indices) for item in sublist - ] + ])) query_embedding = embed_query(query) @@ -301,8 +302,8 @@ def dot_product(self, id_list, query, top, individual_id_list=[]): params=faiss.SearchParametersIVF(sel=sel), ) else: - if top > self.embeddings.ntotal: - top = self.embeddings.ntotal + if top > min(self.embeddings.ntotal, len(embedding_indices)): + top = min(self.embeddings.ntotal, len(embedding_indices)) D, I = self.embeddings.search( query_embedding, top, params=faiss.SearchParametersIVF(sel=sel) ) diff --git a/src/suql/sql_free_text_support/execute_free_text_sql.py b/src/suql/sql_free_text_support/execute_free_text_sql.py index 48614b5..82e474f 100644 --- a/src/suql/sql_free_text_support/execute_free_text_sql.py +++ b/src/suql/sql_free_text_support/execute_free_text_sql.py @@ -619,7 +619,7 @@ def _retrieve_and_verify( elif len(node.fromClause) == 1 and isinstance(node.fromClause[0], JoinExpr): single_table = False id_list = {} - for arg in [node.fromClause[0].larg, node.fromClause[0].rarg]: + for arg in _extract_recursive_joins(node.fromClause[0]): if isinstance(arg, RangeVar): table_name = arg.relname id_field_name = table_w_ids[table_name] @@ -723,10 +723,13 @@ def _retrieve_and_verify( enforce_ordering=True if node.sortClause is not None else False, ) else: - id_res = [] + id_res = set() for each_res in parsed_result: if _verify_single_res(each_res, field_query_list, llm_model_name): - id_res.append(each_res[0]) + if isinstance(each_res[0], list): + id_res.update(each_res[0]) + else: + id_res.add(each_res[0]) end_time = time.time() logging.info("retrieve + verification time {}s".format(end_time - start_time)) @@ -1133,8 +1136,31 @@ def visit_ColumnRef(self, ancestors: Ancestor, node: ColumnRef): # the same field appears twice, this means that the original syntax is problematic break res = (String(sval=f"{table_name}^{node.fields[0].sval}"),) - node.fields = res + + # do not replace if None, b/c this should be an aliased field + if res is not None: + node.fields = res +def _extract_recursive_joins( + fromClause: JoinExpr +): + """ + A FROM clause of a SelectStmt could have multiple joins. + This functions searilizes the joins and returns them as a list. + """ + res = [] + if isinstance(fromClause.larg, RangeVar): + res.append(fromClause.larg) + if isinstance(fromClause.rarg, RangeVar): + res.append(fromClause.rarg) + + if isinstance(fromClause.larg, JoinExpr): + res.extend(_extract_recursive_joins(fromClause.larg)) + if isinstance(fromClause.rarg, JoinExpr): + res.extend(_extract_recursive_joins(fromClause.rarg)) + + return res + def _execute_structural_sql( original_node: SelectStmt, @@ -1157,7 +1183,7 @@ def _execute_structural_sql( elif len(node.fromClause) == 1 and isinstance(node.fromClause[0], JoinExpr): all_projection_fields = [] table_column_mapping = {} - for table in [node.fromClause[0].larg, node.fromClause[0].rarg]: + for table in _extract_recursive_joins(node.fromClause[0]): # find out what columns this table has _, columns = execute_sql_with_column_info( RawStream()(SelectStmt(fromClause=(table,), targetList=(ResTarget(val=ColumnRef(fields=(A_Star(),))),))), @@ -1189,6 +1215,8 @@ def _execute_structural_sql( table_column_mapping=table_column_mapping ) replace_original_target_visitor(original_node.targetList) + if original_node.groupClause is not None: + replace_original_target_visitor(original_node.groupClause) if original_node.sortClause is not None: replace_original_target_visitor(original_node.sortClause) # next, there are tuple joins (self joins) @@ -1234,8 +1262,10 @@ def _execute_structural_sql( node.limitOffset = None # change predicates node.whereClause = predicate + # reset other unecessary clauses node.groupClause = None node.havingClause = None + node.sortClause = None # only queries that involve only structural parts can be executed assert _if_all_structural(node)