From d5ba26b3f18f37d7a22ffc198c25019c9014f3fb Mon Sep 17 00:00:00 2001 From: Shicheng Liu Date: Thu, 9 May 2024 00:28:56 +0000 Subject: [PATCH] de-dup repeated IDs in `faiss_embedding` --- src/suql/faiss_embedding.py | 9 +++++---- src/suql/sql_free_text_support/execute_free_text_sql.py | 7 +++++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/suql/faiss_embedding.py b/src/suql/faiss_embedding.py index d0a11fe..291b2c2 100644 --- a/src/suql/faiss_embedding.py +++ b/src/suql/faiss_embedding.py @@ -285,11 +285,12 @@ def dot_product(self, id_list, query, top, individual_id_list=[]): for sublist in map(lambda x: self.id2document[x], individual_id_list) for item in sublist ] - embedding_indices = [ + # remove potential duplicates here + embedding_indices = list(dict.fromkeys([ item for sublist in map(lambda x: self.document2embedding[x], document_indices) for item in sublist - ] + ])) query_embedding = embed_query(query) @@ -301,8 +302,8 @@ def dot_product(self, id_list, query, top, individual_id_list=[]): params=faiss.SearchParametersIVF(sel=sel), ) else: - if top > self.embeddings.ntotal: - top = self.embeddings.ntotal + if top > min(self.embeddings.ntotal, len(embedding_indices)): + top = min(self.embeddings.ntotal, len(embedding_indices)) D, I = self.embeddings.search( query_embedding, top, params=faiss.SearchParametersIVF(sel=sel) ) diff --git a/src/suql/sql_free_text_support/execute_free_text_sql.py b/src/suql/sql_free_text_support/execute_free_text_sql.py index 8d34408..82e474f 100644 --- a/src/suql/sql_free_text_support/execute_free_text_sql.py +++ b/src/suql/sql_free_text_support/execute_free_text_sql.py @@ -723,10 +723,13 @@ def _retrieve_and_verify( enforce_ordering=True if node.sortClause is not None else False, ) else: - id_res = [] + id_res = set() for each_res in parsed_result: if _verify_single_res(each_res, field_query_list, llm_model_name): - id_res.append(each_res[0]) + if isinstance(each_res[0], list): + id_res.update(each_res[0]) + else: + id_res.add(each_res[0]) end_time = time.time() logging.info("retrieve + verification time {}s".format(end_time - start_time))