diff --git a/autogen/agentchat/contrib/vectordb/qdrant.py b/autogen/agentchat/contrib/vectordb/qdrant.py index d9c4ee1d2e5..2c5194a9f73 100644 --- a/autogen/agentchat/contrib/vectordb/qdrant.py +++ b/autogen/agentchat/contrib/vectordb/qdrant.py @@ -129,7 +129,7 @@ def create_collection(self, collection_name: str, overwrite: bool = False, get_o elif not get_or_create: raise ValueError(f"Collection {collection_name} already exists.") - def get_collection(self, collection_name: str = None): + def get_collection(self, collection_name: Optional[str] = None): """ Get the collection from the vector database. @@ -231,8 +231,8 @@ def retrieve_docs( """ embeddings = self.embedding_function(queries) requests = [ - models.SearchRequest( - vector=embedding, + models.QueryRequest( + query=embedding, limit=n_results, score_threshold=distance_threshold, with_payload=True, @@ -241,8 +241,8 @@ def retrieve_docs( for embedding in embeddings ] - batch_results = self.client.search_batch(collection_name, requests) - return [self._scored_points_to_documents(results) for results in batch_results] + batch_results = self.client.query_batch_points(collection_name, requests) + return [self._scored_points_to_documents(results.points) for results in batch_results] def get_docs_by_ids( self, ids: List[ItemID] = None, collection_name: str = None, include=True, **kwargs