From 94064da472c4dda429b90528ea0548ce233cf8ee Mon Sep 17 00:00:00 2001 From: paubins2 <78225817+paubins2@users.noreply.github.com> Date: Thu, 26 Sep 2024 01:23:53 -0500 Subject: [PATCH] Embedding function should always return a list of a list of vectors (#3570) Co-authored-by: patrickaubin-abbott <78225817+patrickaubin-abbott@users.noreply.github.com> Co-authored-by: Li Jiang --- autogen/agentchat/contrib/vectordb/pgvectordb.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/autogen/agentchat/contrib/vectordb/pgvectordb.py b/autogen/agentchat/contrib/vectordb/pgvectordb.py index ac86802b672..6fce4a6db80 100644 --- a/autogen/agentchat/contrib/vectordb/pgvectordb.py +++ b/autogen/agentchat/contrib/vectordb/pgvectordb.py @@ -415,7 +415,8 @@ def query( cursor = self.client.cursor() results = [] for query_text in query_texts: - vector = self.embedding_function(query_text, convert_to_tensor=False).tolist() + vector = self.embedding_function(query_text) + if distance_type.lower() == "cosine": index_function = "<=>" elif distance_type.lower() == "euclidean": @@ -619,7 +620,7 @@ def __init__( if embedding_function: self.embedding_function = embedding_function else: - self.embedding_function = SentenceTransformer("all-MiniLM-L6-v2").encode + self.embedding_function = lambda s: SentenceTransformer("all-MiniLM-L6-v2").encode(s).tolist() self.metadata = metadata register_vector(self.client) self.active_collection = None