From 79d0d527c3887b09333d011b392ee15878586c83 Mon Sep 17 00:00:00 2001 From: Corentin Meyer Date: Fri, 22 Mar 2024 14:09:11 +0100 Subject: [PATCH] feat(Qdrant): SparseEmbedding instead of Dict --- .../components/retrievers/qdrant/retriever.py | 5 +++-- .../document_stores/qdrant/document_store.py | 5 ++--- integrations/qdrant/tests/test_retriever.py | 8 +++----- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py b/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py index 5d44107eb..12a67a3b7 100644 --- a/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py +++ b/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py @@ -132,6 +132,7 @@ class QdrantSparseRetriever: ```python from haystack_integrations.components.retrievers.qdrant import QdrantSparseRetriever from haystack_integrations.document_stores.qdrant import QdrantDocumentStore + from haystack.dataclasses.sparse_embedding import SparseEmbedding document_store = QdrantDocumentStore( ":memory:", @@ -140,8 +141,8 @@ class QdrantSparseRetriever: wait_result_from_api=True, ) retriever = QdrantSparseRetriever(document_store=document_store) - - retriever.run(query_sparse_embedding={"indices":[0, 1, 2, 3], "values":[0.1, 0.8, 0.05, 0.33]}) + sparse_embedding = SparseEmbedding(indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33]) + retriever.run(query_sparse_embedding=sparse_embedding) ``` """ diff --git a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py index d68903f06..f304a73d8 100644 --- a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py +++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py @@ -310,9 +310,8 @@ def query_by_sparse( return_embedding: bool = False, # noqa: FBT001, FBT002 ) -> List[Document]: qdrant_filters = self.qdrant_filter_converter.convert(filters) - - query_indices = query_sparse_embedding["indices"] - query_values = query_sparse_embedding["values"] + query_indices = query_sparse_embedding.indices + query_values = query_sparse_embedding.values points = self.client.search( collection_name=self.index, diff --git a/integrations/qdrant/tests/test_retriever.py b/integrations/qdrant/tests/test_retriever.py index fb4ac704e..6b24270b4 100644 --- a/integrations/qdrant/tests/test_retriever.py +++ b/integrations/qdrant/tests/test_retriever.py @@ -227,14 +227,12 @@ def test_run(self, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) retriever = QdrantSparseRetriever(document_store=document_store) - - results: List[Document] = retriever.run(query_sparse_embedding=self._generate_mocked_sparse_embedding(1)[0]) + sparse_embedding = SparseEmbedding(indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33]) + results: List[Document] = retriever.run(query_sparse_embedding=sparse_embedding) assert len(results["documents"]) == 10 # type: ignore - results = retriever.run( - query_sparse_embedding=self._generate_mocked_sparse_embedding(1)[0], top_k=5, return_embedding=False - ) + results = retriever.run(query_sparse_embedding=sparse_embedding, top_k=5, return_embedding=False) assert len(results["documents"]) == 5 # type: ignore