Skip to content

Commit

Permalink
de-dup repeated IDs in faiss_embedding
Browse files (browse the repository at this point in the history)
  • Loading branch information
george1459 committed May 9, 2024
1 parent f71f5ec commit d5ba26b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
9 changes: 5 additions & 4 deletions src/suql/faiss_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,11 +285,12 @@ def dot_product(self, id_list, query, top, individual_id_list=[]):
  for sublist in map(lambda x: self.id2document[x], individual_id_list)
  for item in sublist
  ]
- embedding_indices = [
+ # remove potential duplicates here
+ embedding_indices = list(dict.fromkeys([
  item
  for sublist in map(lambda x: self.document2embedding[x], document_indices)
  for item in sublist
- ]
+ ]))

query_embedding = embed_query(query)

Expand All @@ -301,8 +302,8 @@ def dot_product(self, id_list, query, top, individual_id_list=[]):
  params=faiss.SearchParametersIVF(sel=sel),
  )
  else:
- if top > self.embeddings.ntotal:
- top = self.embeddings.ntotal
+ if top > min(self.embeddings.ntotal, len(embedding_indices)):
+ top = min(self.embeddings.ntotal, len(embedding_indices))
  D, I = self.embeddings.search(
  query_embedding, top, params=faiss.SearchParametersIVF(sel=sel)
  )
Expand Down
7 changes: 5 additions & 2 deletions src/suql/sql_free_text_support/execute_free_text_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,10 +723,13 @@ def _retrieve_and_verify(
  enforce_ordering=True if node.sortClause is not None else False,
  )
  else:
- id_res = []
+ id_res = set()
  for each_res in parsed_result:
  if _verify_single_res(each_res, field_query_list, llm_model_name):
- id_res.append(each_res[0])
+ if isinstance(each_res[0], list):
+ id_res.update(each_res[0])
+ else:
+ id_res.add(each_res[0])

  end_time = time.time()
  logging.info("retrieve + verification time {}s".format(end_time - start_time))
Expand Down

0 comments on commit d5ba26b

Please sign in to comment.