Skip to content

Commit

Permalink
Merge pull request #26 from stanford-oval/fix/17
Browse files Browse the repository at this point in the history
Fix/#17
  • Loading branch information
george1459 authored May 9, 2024
2 parents 2de9c11 + ec62e84 commit d118e6c
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 10 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Package metadata
name = "suql"
version = "1.1.7a5"
version = "1.1.7a6"
description = "Structured and Unstructured Query Language (SUQL) Python API"
author = "Shicheng Liu"
author_email = "[email protected]"
Expand Down
9 changes: 5 additions & 4 deletions src/suql/faiss_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,11 +285,12 @@ def dot_product(self, id_list, query, top, individual_id_list=[]):
for sublist in map(lambda x: self.id2document[x], individual_id_list)
for item in sublist
]
embedding_indices = [
# remove potential duplicates here
embedding_indices = list(dict.fromkeys([
item
for sublist in map(lambda x: self.document2embedding[x], document_indices)
for item in sublist
]
]))

query_embedding = embed_query(query)

Expand All @@ -301,8 +302,8 @@ def dot_product(self, id_list, query, top, individual_id_list=[]):
params=faiss.SearchParametersIVF(sel=sel),
)
else:
if top > self.embeddings.ntotal:
top = self.embeddings.ntotal
if top > min(self.embeddings.ntotal, len(embedding_indices)):
top = min(self.embeddings.ntotal, len(embedding_indices))
D, I = self.embeddings.search(
query_embedding, top, params=faiss.SearchParametersIVF(sel=sel)
)
Expand Down
40 changes: 35 additions & 5 deletions src/suql/sql_free_text_support/execute_free_text_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ def _retrieve_and_verify(
elif len(node.fromClause) == 1 and isinstance(node.fromClause[0], JoinExpr):
single_table = False
id_list = {}
for arg in [node.fromClause[0].larg, node.fromClause[0].rarg]:
for arg in _extract_recursive_joins(node.fromClause[0]):
if isinstance(arg, RangeVar):
table_name = arg.relname
id_field_name = table_w_ids[table_name]
Expand Down Expand Up @@ -723,10 +723,13 @@ def _retrieve_and_verify(
enforce_ordering=True if node.sortClause is not None else False,
)
else:
id_res = []
id_res = set()
for each_res in parsed_result:
if _verify_single_res(each_res, field_query_list, llm_model_name):
id_res.append(each_res[0])
if isinstance(each_res[0], list):
id_res.update(each_res[0])
else:
id_res.add(each_res[0])

end_time = time.time()
logging.info("retrieve + verification time {}s".format(end_time - start_time))
Expand Down Expand Up @@ -1133,8 +1136,31 @@ def visit_ColumnRef(self, ancestors: Ancestor, node: ColumnRef):
# the same field appears twice, this means that the original syntax is problematic
break
res = (String(sval=f"{table_name}^{node.fields[0].sval}"),)
node.fields = res

# do not replace if None, b/c this should be an aliased field
if res is not None:
node.fields = res

def _extract_recursive_joins(
    fromClause: JoinExpr
):
    """
    Flatten a (possibly nested) join expression into a flat list of tables.

    The FROM clause of a SelectStmt may chain several joins, in which case
    one or both operands of a `JoinExpr` are themselves `JoinExpr` nodes.
    This helper walks that tree and collects every `RangeVar` it contains.

    Ordering of the result: the `RangeVar` operands sitting directly on
    this join come first (left, then right), followed by the tables found
    inside a nested left join, then those found inside a nested right join.
    Operands that are neither a `RangeVar` nor a `JoinExpr` are skipped.
    """
    operands = (fromClause.larg, fromClause.rarg)

    # tables that are immediate operands of this join
    tables = [operand for operand in operands if isinstance(operand, RangeVar)]

    # tables reachable through nested joins, left side before right side
    for operand in operands:
        if isinstance(operand, JoinExpr):
            tables.extend(_extract_recursive_joins(operand))

    return tables


def _execute_structural_sql(
original_node: SelectStmt,
Expand All @@ -1157,7 +1183,7 @@ def _execute_structural_sql(
elif len(node.fromClause) == 1 and isinstance(node.fromClause[0], JoinExpr):
all_projection_fields = []
table_column_mapping = {}
for table in [node.fromClause[0].larg, node.fromClause[0].rarg]:
for table in _extract_recursive_joins(node.fromClause[0]):
# find out what columns this table has
_, columns = execute_sql_with_column_info(
RawStream()(SelectStmt(fromClause=(table,), targetList=(ResTarget(val=ColumnRef(fields=(A_Star(),))),))),
Expand Down Expand Up @@ -1189,6 +1215,8 @@ def _execute_structural_sql(
table_column_mapping=table_column_mapping
)
replace_original_target_visitor(original_node.targetList)
if original_node.groupClause is not None:
replace_original_target_visitor(original_node.groupClause)
if original_node.sortClause is not None:
replace_original_target_visitor(original_node.sortClause)
# next, there are tuple joins (self joins)
Expand Down Expand Up @@ -1234,8 +1262,10 @@ def _execute_structural_sql(
node.limitOffset = None
# change predicates
node.whereClause = predicate
# reset other unnecessary clauses
node.groupClause = None
node.havingClause = None
node.sortClause = None

# only queries that involve only structural parts can be executed
assert _if_all_structural(node)
Expand Down

0 comments on commit d118e6c

Please sign in to comment.