From 12cdc11a05a76a16d39aed3d813dea92d4f1849f Mon Sep 17 00:00:00 2001 From: Corentin Date: Fri, 12 Apr 2024 13:02:08 +0200 Subject: [PATCH] feat(Qdrant): start to work on sparse vector integration (#578) * feat(Qdrant): start working on sparse vector integration * Progress towards Sparse vector support with Fastembed * __init__.py * merge batch results for hybrid request * feat(Qdrant): missing comma * feat(Qdrant): making some test progress * feat(Qdrant): all current tests are fixed * feat(Qdrant): linting * feat(Qdrant): working sparse retriever hooray * feat(Qdrant): fix hybrid retriever * feat(Qdrant): modify PR for haystack 2.1.0 with proper sparse vectors * feat(Qdrant): fix lint * test w Haystack main * fix deps * Update integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py Co-authored-by: Anush * feat(Qdrant): remove hybrid & old code, constant for vector field names * Update integrations/qdrant/src/haystack_integrations/document_stores/qdrant/converters.py Co-authored-by: Stefano Fiorucci * Update integrations/qdrant/src/haystack_integrations/document_stores/qdrant/converters.py Co-authored-by: Stefano Fiorucci * feat(Qdrant): reverting pop change, changing Dict to SparseEmbedding type * feat(Qdrant): fix lint * feat(Qdrant): remove old todo * simplify documents_to_batch * feat(Qdrant): SparseEmbedding instead of Dict * feat(Qdrant): introducing `use_sparse_embeddings` parameter for document store to make sparse embeddings a non-breaking change. 
Need more testing * feat(Qdrant): `use_sparse_embeddings` true by default + bugfix * feat(Qdrant): `use_sparse_embeddings` true by default + bugfix * feat(Qdrant): `use_sparse_embeddings` true by default + bugfix * feat(Qdrant): bugfix * Update integrations/qdrant/src/haystack_integrations/document_stores/qdrant/converters.py Co-authored-by: Anush * Update integrations/qdrant/src/haystack_integrations/document_stores/qdrant/converters.py Co-authored-by: Anush * Update integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py Co-authored-by: Anush * Update integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py Co-authored-by: Anush * Update integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py Co-authored-by: Anush * Revert "Update integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py" This reverts commit f7cf65ec7ea9b2bf2a360c41096cc5770b114f82. * feat(Qdrant): fixing test * Update integrations/qdrant/src/haystack_integrations/document_stores/qdrant/converters.py Co-authored-by: Anush * feat(Qdrant): fixing creation * feat(Qdrant): fixing creation * little fixes * make changes nonbreaking * refactoring --------- Co-authored-by: anakin87 Co-authored-by: Anush --- integrations/qdrant/pyproject.toml | 4 +- .../components/retrievers/qdrant/__init__.py | 4 +- .../components/retrievers/qdrant/retriever.py | 125 +++++- .../document_stores/qdrant/converters.py | 110 +++--- .../document_stores/qdrant/document_store.py | 188 +++++++-- .../document_stores/qdrant/filters.py | 366 +++++++++--------- integrations/qdrant/tests/test_converters.py | 57 +-- .../qdrant/tests/test_dict_converters.py | 3 + integrations/qdrant/tests/test_retriever.py | 160 +++++++- 9 files changed, 704 insertions(+), 313 deletions(-) diff --git a/integrations/qdrant/pyproject.toml b/integrations/qdrant/pyproject.toml index 29be8da0f..a566de955 100644 --- 
a/integrations/qdrant/pyproject.toml +++ b/integrations/qdrant/pyproject.toml @@ -24,7 +24,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "qdrant-client"] +dependencies = ["haystack-ai>=2.0.1", "qdrant-client"] [project.urls] Source = "https://github.com/deepset-ai/haystack-core-integrations" @@ -103,6 +103,8 @@ ignore = [ "B027", # Allow boolean positional values in function calls, like `dict.get(... True)` "FBT003", + # Allow boolean arguments in function definition + "FBT001", "FBT002", # Ignore checks for possible passwords "S105", "S106", diff --git a/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/__init__.py b/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/__init__.py index 41b59e42d..58be4211a 100644 --- a/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/__init__.py +++ b/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/__init__.py @@ -2,6 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -from .retriever import QdrantEmbeddingRetriever +from .retriever import QdrantEmbeddingRetriever, QdrantSparseRetriever -__all__ = ("QdrantEmbeddingRetriever",) +__all__ = ("QdrantEmbeddingRetriever", "QdrantSparseRetriever") diff --git a/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py b/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py index cd53ccd7b..0b7bfa1a4 100644 --- a/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py +++ b/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py @@ -1,13 +1,14 @@ from typing import Any, Dict, List, Optional from haystack import Document, component, default_from_dict, default_to_dict +from haystack.dataclasses.sparse_embedding import SparseEmbedding from 
haystack_integrations.document_stores.qdrant import QdrantDocumentStore @component class QdrantEmbeddingRetriever: """ - A component for retrieving documents from an QdrantDocumentStore. + A component for retrieving documents from an QdrantDocumentStore using dense vectors. Usage example: ```python @@ -32,8 +33,8 @@ def __init__( document_store: QdrantDocumentStore, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, - scale_score: bool = True, # noqa: FBT001, FBT002 - return_embedding: bool = False, # noqa: FBT001, FBT002 + scale_score: bool = True, + return_embedding: bool = False, ): """ Create a QdrantEmbeddingRetriever component. @@ -120,3 +121,121 @@ def run( ) return {"documents": docs} + + +@component +class QdrantSparseRetriever: + """ + A component for retrieving documents from an QdrantDocumentStore using sparse vectors. + + Usage example: + ```python + from haystack_integrations.components.retrievers.qdrant import QdrantSparseRetriever + from haystack_integrations.document_stores.qdrant import QdrantDocumentStore + from haystack.dataclasses.sparse_embedding import SparseEmbedding + + document_store = QdrantDocumentStore( + ":memory:", + recreate_index=True, + return_embedding=True, + wait_result_from_api=True, + ) + retriever = QdrantSparseRetriever(document_store=document_store) + sparse_embedding = SparseEmbedding(indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33]) + retriever.run(query_sparse_embedding=sparse_embedding) + ``` + """ + + def __init__( + self, + document_store: QdrantDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + scale_score: bool = True, + return_embedding: bool = False, + ): + """ + Create a QdrantSparseRetriever component. + + :param document_store: An instance of QdrantDocumentStore. + :param filters: A dictionary with filters to narrow down the search space. Default is None. + :param top_k: The maximum number of documents to retrieve. Default is 10. 
+ :param scale_score: Whether to scale the scores of the retrieved documents or not. Default is True. + :param return_embedding: Whether to return the sparse embedding of the retrieved Documents. Default is False. + + :raises ValueError: If 'document_store' is not an instance of QdrantDocumentStore. + """ + + if not isinstance(document_store, QdrantDocumentStore): + msg = "document_store must be an instance of QdrantDocumentStore" + raise ValueError(msg) + + self._document_store = document_store + self._filters = filters + self._top_k = top_k + self._scale_score = scale_score + self._return_embedding = return_embedding + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + d = default_to_dict( + self, + document_store=self._document_store, + filters=self._filters, + top_k=self._top_k, + scale_score=self._scale_score, + return_embedding=self._return_embedding, + ) + d["init_parameters"]["document_store"] = self._document_store.to_dict() + + return d + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "QdrantEmbeddingRetriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ + document_store = QdrantDocumentStore.from_dict(data["init_parameters"]["document_store"]) + data["init_parameters"]["document_store"] = document_store + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run( + self, + query_sparse_embedding: SparseEmbedding, + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None, + scale_score: Optional[bool] = None, + return_embedding: Optional[bool] = None, + ): + """ + Run the Sparse Embedding Retriever on the given input data. + + :param query_sparse_embedding: Sparse Embedding of the query. + :param filters: A dictionary with filters to narrow down the search space. 
+ :param top_k: The maximum number of documents to return. + :param scale_score: Whether to scale the scores of the retrieved documents or not. + :param return_embedding: Whether to return the embedding of the retrieved Documents. + :returns: + The retrieved documents. + + """ + docs = self._document_store.query_by_sparse( + query_sparse_embedding=query_sparse_embedding, + filters=filters or self._filters, + top_k=top_k or self._top_k, + scale_score=scale_score or self._scale_score, + return_embedding=return_embedding or self._return_embedding, + ) + + return {"documents": docs} diff --git a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/converters.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/converters.py index 439fd605b..96bd4f37a 100644 --- a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/converters.py +++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/converters.py @@ -7,64 +7,74 @@ logger = logging.getLogger(__name__) +DENSE_VECTORS_NAME = "text-dense" +SPARSE_VECTORS_NAME = "text-sparse" -class HaystackToQdrant: - """A converter from Haystack to Qdrant types.""" - UUID_NAMESPACE = uuid.UUID("3896d314-1e95-4a3a-b45a-945f9f0b541d") +UUID_NAMESPACE = uuid.UUID("3896d314-1e95-4a3a-b45a-945f9f0b541d") - def documents_to_batch( - self, - documents: List[Document], - *, - embedding_field: str, - ) -> List[rest.PointStruct]: - points = [] - for document in documents: - payload = document.to_dict(flatten=False) + +def convert_haystack_documents_to_qdrant_points( + documents: List[Document], + *, + embedding_field: str, + use_sparse_embeddings: bool, +) -> List[rest.PointStruct]: + points = [] + for document in documents: + payload = document.to_dict(flatten=False) + if use_sparse_embeddings: + vector = {} + + dense_vector = payload.pop(embedding_field, None) + if dense_vector is not None: + vector[DENSE_VECTORS_NAME] = dense_vector + + sparse_vector = 
payload.pop("sparse_embedding", None) + if sparse_vector is not None: + sparse_vector_instance = rest.SparseVector(**sparse_vector) + vector[SPARSE_VECTORS_NAME] = sparse_vector_instance + + else: vector = payload.pop(embedding_field) or {} - _id = self.convert_id(payload.get("id")) - - # TODO: remove as soon as we introduce the support for sparse embeddings in Qdrant - if "sparse_embedding" in payload: - sparse_embedding = payload.pop("sparse_embedding", None) - if sparse_embedding: - logger.warning( - "Document %s has the `sparse_embedding` field set," - "but storing sparse embeddings in Qdrant is not currently supported." - "The `sparse_embedding` field will be ignored.", - payload["id"], - ) - - point = rest.PointStruct( - payload=payload, - vector=vector, - id=_id, - ) - points.append(point) - return points - - def convert_id(self, _id: str) -> str: - """ - Converts any string into a UUID-like format in a deterministic way. - - Qdrant does not accept any string as an id, so an internal id has to be - generated for each point. This is a deterministic way of doing so. - """ - return uuid.uuid5(self.UUID_NAMESPACE, _id).hex + _id = convert_id(payload.get("id")) + + point = rest.PointStruct( + payload=payload, + vector=vector, + id=_id, + ) + points.append(point) + return points + + +def convert_id(_id: str) -> str: + """ + Converts any string into a UUID-like format in a deterministic way. + + Qdrant does not accept any string as an id, so an internal id has to be + generated for each point. This is a deterministic way of doing so. 
+ """ + return uuid.uuid5(UUID_NAMESPACE, _id).hex QdrantPoint = Union[rest.ScoredPoint, rest.Record] -class QdrantToHaystack: - def __init__(self, content_field: str, name_field: str, embedding_field: str): - self.content_field = content_field - self.name_field = name_field - self.embedding_field = embedding_field +def convert_qdrant_point_to_haystack_document(point: QdrantPoint, use_sparse_embeddings: bool) -> Document: + payload = {**point.payload} + payload["score"] = point.score if hasattr(point, "score") else None - def point_to_document(self, point: QdrantPoint) -> Document: - payload = {**point.payload} + if not use_sparse_embeddings: payload["embedding"] = point.vector if hasattr(point, "vector") else None - payload["score"] = point.score if hasattr(point, "score") else None - return Document.from_dict(payload) + elif hasattr(point, "vector") and point.vector is not None: + payload["embedding"] = point.vector.get(DENSE_VECTORS_NAME) + + if SPARSE_VECTORS_NAME in point.vector: + parse_vector_dict = { + "indices": point.vector[SPARSE_VECTORS_NAME].indices, + "values": point.vector[SPARSE_VECTORS_NAME].values, + } + payload["sparse_embedding"] = parse_vector_dict + + return Document.from_dict(payload) diff --git a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py index dc22673fa..8771a3515 100644 --- a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py +++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py @@ -8,17 +8,24 @@ from grpc import RpcError from haystack import default_from_dict, default_to_dict from haystack.dataclasses import Document +from haystack.dataclasses.sparse_embedding import SparseEmbedding from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy from 
haystack.utils import Secret, deserialize_secrets_inplace -from haystack.utils.filters import convert +from haystack.utils.filters import convert as convert_legacy_filters from qdrant_client import grpc from qdrant_client.http import models as rest from qdrant_client.http.exceptions import UnexpectedResponse from tqdm import tqdm -from .converters import HaystackToQdrant, QdrantToHaystack -from .filters import QdrantFilterConverter +from .converters import ( + DENSE_VECTORS_NAME, + SPARSE_VECTORS_NAME, + convert_haystack_documents_to_qdrant_points, + convert_id, + convert_qdrant_point_to_haystack_document, +) +from .filters import convert_filters_to_qdrant logger = logging.getLogger(__name__) @@ -54,7 +61,7 @@ def __init__( url: Optional[str] = None, port: int = 6333, grpc_port: int = 6334, - prefer_grpc: bool = False, # noqa: FBT001, FBT002 + prefer_grpc: bool = False, https: Optional[bool] = None, api_key: Optional[Secret] = None, prefix: Optional[str] = None, @@ -63,15 +70,16 @@ def __init__( path: Optional[str] = None, index: str = "Document", embedding_dim: int = 768, - on_disk: bool = False, # noqa: FBT001, FBT002 + on_disk: bool = False, content_field: str = "content", name_field: str = "name", embedding_field: str = "embedding", + use_sparse_embeddings: bool = False, similarity: str = "cosine", - return_embedding: bool = False, # noqa: FBT001, FBT002 - progress_bar: bool = True, # noqa: FBT001, FBT002 + return_embedding: bool = False, + progress_bar: bool = True, duplicate_documents: str = "overwrite", - recreate_index: bool = False, # noqa: FBT001, FBT002 + recreate_index: bool = False, shard_number: Optional[int] = None, replication_factor: Optional[int] = None, write_consistency_factor: Optional[int] = None, @@ -81,7 +89,7 @@ def __init__( wal_config: Optional[dict] = None, quantization_config: Optional[dict] = None, init_from: Optional[dict] = None, - wait_result_from_api: bool = True, # noqa: FBT001, FBT002 + wait_result_from_api: bool = True, 
metadata: Optional[dict] = None, write_batch_size: int = 100, scroll_size: int = 10_000, @@ -133,9 +141,12 @@ def __init__( self.wait_result_from_api = wait_result_from_api self.recreate_index = recreate_index self.payload_fields_to_index = payload_fields_to_index + self.use_sparse_embeddings = use_sparse_embeddings # Make sure the collection is properly set up - self._set_up_collection(index, embedding_dim, recreate_index, similarity, on_disk, payload_fields_to_index) + self._set_up_collection( + index, embedding_dim, recreate_index, similarity, use_sparse_embeddings, on_disk, payload_fields_to_index + ) self.embedding_dim = embedding_dim self.on_disk = on_disk @@ -147,13 +158,6 @@ def __init__( self.return_embedding = return_embedding self.progress_bar = progress_bar self.duplicate_documents = duplicate_documents - self.qdrant_filter_converter = QdrantFilterConverter() - self.haystack_to_qdrant_converter = HaystackToQdrant() - self.qdrant_to_haystack = QdrantToHaystack( - content_field, - name_field, - embedding_field, - ) self.write_batch_size = write_batch_size self.scroll_size = scroll_size @@ -178,7 +182,7 @@ def filter_documents( raise ValueError(msg) if filters and "operator" not in filters: - filters = convert(filters) + filters = convert_legacy_filters(filters) return list( self.get_documents_generator( filters, @@ -194,7 +198,7 @@ def write_documents( if not isinstance(doc, Document): msg = f"DocumentStore.write_documents() expects a list of Documents but got an element of {type(doc)}." 
raise ValueError(msg) - self._set_up_collection(self.index, self.embedding_dim, False, self.similarity) + self._set_up_collection(self.index, self.embedding_dim, False, self.similarity, self.use_sparse_embeddings) if len(documents) == 0: logger.warning("Calling QdrantDocumentStore.write_documents() with empty list") @@ -209,9 +213,10 @@ def write_documents( batched_documents = get_batches_from_generator(document_objects, self.write_batch_size) with tqdm(total=len(document_objects), disable=not self.progress_bar) as progress_bar: for document_batch in batched_documents: - batch = self.haystack_to_qdrant_converter.documents_to_batch( + batch = convert_haystack_documents_to_qdrant_points( document_batch, embedding_field=self.embedding_field, + use_sparse_embeddings=self.use_sparse_embeddings, ) self.client.upsert( @@ -224,7 +229,7 @@ def write_documents( return len(document_objects) def delete_documents(self, ids: List[str]): - ids = [self.haystack_to_qdrant_converter.convert_id(_id) for _id in ids] + ids = [convert_id(_id) for _id in ids] try: self.client.delete( collection_name=self.index, @@ -257,7 +262,7 @@ def get_documents_generator( filters: Optional[Dict[str, Any]] = None, ) -> Generator[Document, None, None]: index = self.index - qdrant_filters = self.qdrant_filter_converter.convert(filters) + qdrant_filters = convert_filters_to_qdrant(filters) next_offset = None stop_scrolling = False @@ -275,7 +280,9 @@ def get_documents_generator( ) for record in records: - yield self.qdrant_to_haystack.point_to_document(record) + yield convert_qdrant_point_to_haystack_document( + record, use_sparse_embeddings=self.use_sparse_embeddings + ) def get_documents_by_id( self, @@ -286,7 +293,7 @@ def get_documents_by_id( documents: List[Document] = [] - ids = [self.haystack_to_qdrant_converter.convert_id(_id) for _id in ids] + ids = [convert_id(_id) for _id in ids] records = self.client.retrieve( collection_name=index, ids=ids, @@ -295,28 +302,77 @@ def get_documents_by_id( ) 
for record in records: - documents.append(self.qdrant_to_haystack.point_to_document(record)) + documents.append( + convert_qdrant_point_to_haystack_document(record, use_sparse_embeddings=self.use_sparse_embeddings) + ) return documents + def query_by_sparse( + self, + query_sparse_embedding: SparseEmbedding, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + scale_score: bool = True, + return_embedding: bool = False, + ) -> List[Document]: + if not self.use_sparse_embeddings: + message = ( + "You are trying to query using sparse embeddings, but the Document Store " + "was initialized with `use_sparse_embeddings=False`. " + ) + raise QdrantStoreError(message) + + qdrant_filters = convert_filters_to_qdrant(filters) + query_indices = query_sparse_embedding.indices + query_values = query_sparse_embedding.values + points = self.client.search( + collection_name=self.index, + query_vector=rest.NamedSparseVector( + name=SPARSE_VECTORS_NAME, + vector=rest.SparseVector( + indices=query_indices, + values=query_values, + ), + ), + query_filter=qdrant_filters, + limit=top_k, + with_vectors=return_embedding, + ) + results = [ + convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=self.use_sparse_embeddings) + for point in points + ] + if scale_score: + for document in results: + score = document.score + score = float(1 / (1 + np.exp(-score / 100))) + document.score = score + return results + def query_by_embedding( self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: int = 10, - scale_score: bool = True, # noqa: FBT001, FBT002 - return_embedding: bool = False, # noqa: FBT001, FBT002 + scale_score: bool = True, + return_embedding: bool = False, ) -> List[Document]: - qdrant_filters = self.qdrant_filter_converter.convert(filters) + qdrant_filters = convert_filters_to_qdrant(filters) points = self.client.search( collection_name=self.index, - query_vector=query_embedding, + query_vector=rest.NamedVector( + 
name=DENSE_VECTORS_NAME if self.use_sparse_embeddings else "", + vector=query_embedding, + ), query_filter=qdrant_filters, limit=top_k, with_vectors=return_embedding, ) - - results = [self.qdrant_to_haystack.point_to_document(point) for point in points] + results = [ + convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=self.use_sparse_embeddings) + for point in points + ] if scale_score: for document in results: score = document.score @@ -355,9 +411,10 @@ def _set_up_collection( self, collection_name: str, embedding_dim: int, - recreate_collection: bool, # noqa: FBT001 + recreate_collection: bool, similarity: str, - on_disk: bool = False, # noqa: FBT001, FBT002 + use_sparse_embeddings: bool, + on_disk: bool = False, payload_fields_to_index: Optional[List[dict]] = None, ): distance = self._get_distance(similarity) @@ -365,7 +422,7 @@ def _set_up_collection( if recreate_collection: # There is no need to verify the current configuration of that # collection. It might be just recreated again. - self._recreate_collection(collection_name, distance, embedding_dim, on_disk) + self._recreate_collection(collection_name, distance, embedding_dim, on_disk, use_sparse_embeddings) # Create Payload index if payload_fields_to_index is provided self._create_payload_index(collection_name, payload_fields_to_index) return @@ -381,13 +438,39 @@ def _set_up_collection( # Qdrant local raises ValueError if the collection is not found, but # with the remote server UnexpectedResponse / RpcError is raised. # Until that's unified, we need to catch both. 
- self._recreate_collection(collection_name, distance, embedding_dim, on_disk) + self._recreate_collection(collection_name, distance, embedding_dim, on_disk, use_sparse_embeddings) # Create Payload index if payload_fields_to_index is provided self._create_payload_index(collection_name, payload_fields_to_index) return - current_distance = collection_info.config.params.vectors.distance - current_vector_size = collection_info.config.params.vectors.size + has_named_vectors = ( + isinstance(collection_info.config.params.vectors, dict) + and DENSE_VECTORS_NAME in collection_info.config.params.vectors + ) + + if self.use_sparse_embeddings and not has_named_vectors: + msg = ( + f"Collection '{collection_name}' already exists in Qdrant, " + f"but it has been originally created without sparse embedding vectors. " + f"If you want to use that collection, you can set `use_sparse_embeddings=False`. " + f"To use sparse embeddings, you need to recreate the collection or migrate the existing one." + ) + raise QdrantStoreError(msg) + + elif not self.use_sparse_embeddings and has_named_vectors: + msg = ( + f"Collection '{collection_name}' already exists in Qdrant, " + f"but it has been originally created with sparse embedding vectors." + f"If you want to use that collection, please set `use_sparse_embeddings=True`." 
+ ) + raise QdrantStoreError(msg) + + if self.use_sparse_embeddings: + current_distance = collection_info.config.params.vectors[DENSE_VECTORS_NAME].distance + current_vector_size = collection_info.config.params.vectors[DENSE_VECTORS_NAME].size + else: + current_distance = collection_info.config.params.vectors.distance + current_vector_size = collection_info.config.params.vectors.size if current_distance != distance: msg = ( @@ -407,14 +490,33 @@ def _set_up_collection( ) raise ValueError(msg) - def _recreate_collection(self, collection_name: str, distance, embedding_dim: int, on_disk: bool): # noqa: FBT001 + def _recreate_collection( + self, + collection_name: str, + distance, + embedding_dim: int, + on_disk: bool, + use_sparse_embeddings: bool, + ): + # dense vectors configuration + vectors_config = rest.VectorParams(size=embedding_dim, on_disk=on_disk, distance=distance) + + if use_sparse_embeddings: + # in this case, we need to define named vectors + vectors_config = {DENSE_VECTORS_NAME: vectors_config} + + sparse_vectors_config = { + SPARSE_VECTORS_NAME: rest.SparseVectorParams( + index=rest.SparseIndexParams( + on_disk=on_disk, + ) + ), + } + self.client.recreate_collection( collection_name=collection_name, - vectors_config=rest.VectorParams( - size=embedding_dim, - on_disk=on_disk, - distance=distance, - ), + vectors_config=vectors_config, + sparse_vectors_config=sparse_vectors_config if use_sparse_embeddings else None, shard_number=self.shard_number, replication_factor=self.replication_factor, write_consistency_factor=self.write_consistency_factor, diff --git a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py index 72a74a8b1..c4387b1e5 100644 --- a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py +++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py @@ -4,226 +4,230 @@ from 
haystack.utils.filters import COMPARISON_OPERATORS, LOGICAL_OPERATORS, FilterError from qdrant_client.http import models -from .converters import HaystackToQdrant +from .converters import convert_id COMPARISON_OPERATORS = COMPARISON_OPERATORS.keys() LOGICAL_OPERATORS = LOGICAL_OPERATORS.keys() -class QdrantFilterConverter: +def convert_filters_to_qdrant( + filter_term: Optional[Union[List[dict], dict]] = None, +) -> Optional[models.Filter]: """Converts Haystack filters to the format used by Qdrant.""" - def __init__(self): - self.haystack_to_qdrant_converter = HaystackToQdrant() + if not filter_term: + return None - def convert( - self, - filter_term: Optional[Union[List[dict], dict]] = None, - ) -> Optional[models.Filter]: - if not filter_term: - return None + must_clauses, should_clauses, must_not_clauses = [], [], [] - must_clauses, should_clauses, must_not_clauses = [], [], [] + if isinstance(filter_term, dict): + filter_term = [filter_term] - if isinstance(filter_term, dict): - filter_term = [filter_term] + for item in filter_term: + operator = item.get("operator") + if operator is None: + msg = "Operator not found in filters" + raise FilterError(msg) - for item in filter_term: - operator = item.get("operator") - if operator is None: - msg = "Operator not found in filters" - raise FilterError(msg) + if operator in LOGICAL_OPERATORS and "conditions" not in item: + msg = f"'conditions' not found for '{operator}'" + raise FilterError(msg) - if operator in LOGICAL_OPERATORS and "conditions" not in item: - msg = f"'conditions' not found for '{operator}'" + if operator == "AND": + must_clauses.append(convert_filters_to_qdrant(item.get("conditions", []))) + elif operator == "OR": + should_clauses.append(convert_filters_to_qdrant(item.get("conditions", []))) + elif operator == "NOT": + must_not_clauses.append(convert_filters_to_qdrant(item.get("conditions", []))) + elif operator in COMPARISON_OPERATORS: + field = item.get("field") + value = item.get("value") + if 
field is None or value is None: + msg = f"'field' or 'value' not found for '{operator}'" raise FilterError(msg) - if operator == "AND": - must_clauses.append(self.convert(item.get("conditions", []))) - elif operator == "OR": - should_clauses.append(self.convert(item.get("conditions", []))) - elif operator == "NOT": - must_not_clauses.append(self.convert(item.get("conditions", []))) - elif operator in COMPARISON_OPERATORS: - field = item.get("field") - value = item.get("value") - if field is None or value is None: - msg = f"'field' or 'value' not found for '{operator}'" - raise FilterError(msg) - - must_clauses.extend( - self._parse_comparison_operation(comparison_operation=operator, key=field, value=value) - ) - else: - msg = f"Unknown operator {operator} used in filters" - raise FilterError(msg) + must_clauses.extend(_parse_comparison_operation(comparison_operation=operator, key=field, value=value)) + else: + msg = f"Unknown operator {operator} used in filters" + raise FilterError(msg) - payload_filter = models.Filter( - must=must_clauses or None, - should=should_clauses or None, - must_not=must_not_clauses or None, - ) + payload_filter = models.Filter( + must=must_clauses or None, + should=should_clauses or None, + must_not=must_not_clauses or None, + ) - filter_result = self._squeeze_filter(payload_filter) + filter_result = _squeeze_filter(payload_filter) - return filter_result + return filter_result - def _parse_comparison_operation( - self, comparison_operation: str, key: str, value: Union[dict, List, str, float] - ) -> List[models.Condition]: - conditions: List[models.Condition] = [] - condition_builder_mapping = { - "==": self._build_eq_condition, - "in": self._build_in_condition, - "!=": self._build_ne_condition, - "not in": self._build_nin_condition, - ">": self._build_gt_condition, - ">=": self._build_gte_condition, - "<": self._build_lt_condition, - "<=": self._build_lte_condition, - } +def _parse_comparison_operation( + comparison_operation: str, key: 
str, value: Union[dict, List, str, float] +) -> List[models.Condition]: + conditions: List[models.Condition] = [] - condition_builder = condition_builder_mapping.get(comparison_operation) + condition_builder_mapping = { + "==": _build_eq_condition, + "in": _build_in_condition, + "!=": _build_ne_condition, + "not in": _build_nin_condition, + ">": _build_gt_condition, + ">=": _build_gte_condition, + "<": _build_lt_condition, + "<=": _build_lte_condition, + } - if condition_builder is None: - msg = f"Unknown operator {comparison_operation} used in filters" - raise ValueError(msg) + condition_builder = condition_builder_mapping.get(comparison_operation) - conditions.append(condition_builder(key, value)) + if condition_builder is None: + msg = f"Unknown operator {comparison_operation} used in filters" + raise ValueError(msg) - return conditions + conditions.append(condition_builder(key, value)) - def _build_eq_condition(self, key: str, value: models.ValueVariants) -> models.Condition: - if isinstance(value, str) and " " in value: - models.FieldCondition(key=key, match=models.MatchText(text=value)) - return models.FieldCondition(key=key, match=models.MatchValue(value=value)) + return conditions - def _build_in_condition(self, key: str, value: List[models.ValueVariants]) -> models.Condition: - if not isinstance(value, list): - msg = f"Value {value} is not a list" - raise FilterError(msg) - return models.Filter( - should=[ - ( - models.FieldCondition(key=key, match=models.MatchText(text=item)) - if isinstance(item, str) and " " not in item - else models.FieldCondition(key=key, match=models.MatchValue(value=item)) - ) - for item in value - ] - ) - - def _build_ne_condition(self, key: str, value: models.ValueVariants) -> models.Condition: - return models.Filter( - must_not=[ - ( - models.FieldCondition(key=key, match=models.MatchText(text=value)) - if isinstance(value, str) and " " not in value - else models.FieldCondition(key=key, match=models.MatchValue(value=value)) - ) - 
] - ) - - def _build_nin_condition(self, key: str, value: List[models.ValueVariants]) -> models.Condition: - if not isinstance(value, list): - msg = f"Value {value} is not a list" - raise FilterError(msg) - return models.Filter( - must_not=[ - ( - models.FieldCondition(key=key, match=models.MatchText(text=item)) - if isinstance(item, str) and " " in item - else models.FieldCondition(key=key, match=models.MatchValue(value=item)) - ) - for item in value - ] - ) - - def _build_lt_condition(self, key: str, value: Union[str, float, int]) -> models.Condition: - if isinstance(value, str) and is_datetime_string(value): - return models.FieldCondition(key=key, range=models.DatetimeRange(lt=value)) - - if isinstance(value, (int, float)): - return models.FieldCondition(key=key, range=models.Range(lt=value)) - - msg = f"Value {value} is not an int or float or datetime string" - raise FilterError(msg) - def _build_lte_condition(self, key: str, value: Union[str, float, int]) -> models.Condition: - if isinstance(value, str) and is_datetime_string(value): - return models.FieldCondition(key=key, range=models.DatetimeRange(lte=value)) +def _build_eq_condition(key: str, value: models.ValueVariants) -> models.Condition: + if isinstance(value, str) and " " in value: + return models.FieldCondition(key=key, match=models.MatchText(text=value)) + return models.FieldCondition(key=key, match=models.MatchValue(value=value)) - if isinstance(value, (int, float)): - return models.FieldCondition(key=key, range=models.Range(lte=value)) - msg = f"Value {value} is not an int or float or datetime string" +def _build_in_condition(key: str, value: List[models.ValueVariants]) -> models.Condition: + if not isinstance(value, list): + msg = f"Value {value} is not a list" + raise FilterError(msg) + return models.Filter( + should=[ + ( + models.FieldCondition(key=key, match=models.MatchText(text=item)) + if isinstance(item, str) and " " in item + else models.FieldCondition(key=key, 
match=models.MatchValue(value=item)) + ) + for item in value + ] + ) + + +def _build_ne_condition(key: str, value: models.ValueVariants) -> models.Condition: + return models.Filter( + must_not=[ + ( + models.FieldCondition(key=key, match=models.MatchText(text=value)) + if isinstance(value, str) and " " in value + else models.FieldCondition(key=key, match=models.MatchValue(value=value)) + ) + ] + ) + + +def _build_nin_condition(key: str, value: List[models.ValueVariants]) -> models.Condition: + if not isinstance(value, list): + msg = f"Value {value} is not a list" raise FilterError(msg) + return models.Filter( + must_not=[ + ( + models.FieldCondition(key=key, match=models.MatchText(text=item)) + if isinstance(item, str) and " " in item + else models.FieldCondition(key=key, match=models.MatchValue(value=item)) + ) + for item in value + ] + ) - def _build_gt_condition(self, key: str, value: Union[str, float, int]) -> models.Condition: - if isinstance(value, str) and is_datetime_string(value): - return models.FieldCondition(key=key, range=models.DatetimeRange(gt=value)) - if isinstance(value, (int, float)): - return models.FieldCondition(key=key, range=models.Range(gt=value)) +def _build_lt_condition(key: str, value: Union[str, float, int]) -> models.Condition: + if isinstance(value, str) and is_datetime_string(value): + return models.FieldCondition(key=key, range=models.DatetimeRange(lt=value)) - msg = f"Value {value} is not an int or float or datetime string" - raise FilterError(msg) + if isinstance(value, (int, float)): + return models.FieldCondition(key=key, range=models.Range(lt=value)) - def _build_gte_condition(self, key: str, value: Union[str, float, int]) -> models.Condition: - if isinstance(value, str) and is_datetime_string(value): - return models.FieldCondition(key=key, range=models.DatetimeRange(gte=value)) + msg = f"Value {value} is not an int or float or datetime string" + raise FilterError(msg) - if isinstance(value, (int, float)): - return 
models.FieldCondition(key=key, range=models.Range(gte=value)) - msg = f"Value {value} is not an int or float or datetime string" - raise FilterError(msg) +def _build_lte_condition(key: str, value: Union[str, float, int]) -> models.Condition: + if isinstance(value, str) and is_datetime_string(value): + return models.FieldCondition(key=key, range=models.DatetimeRange(lte=value)) + + if isinstance(value, (int, float)): + return models.FieldCondition(key=key, range=models.Range(lte=value)) + + msg = f"Value {value} is not an int or float or datetime string" + raise FilterError(msg) + - def _build_has_id_condition(self, id_values: List[models.ExtendedPointId]) -> models.HasIdCondition: - return models.HasIdCondition( - has_id=[ - # Ids are converted into their internal representation - self.haystack_to_qdrant_converter.convert_id(item) - for item in id_values - ] - ) - - def _squeeze_filter(self, payload_filter: models.Filter) -> models.Filter: - """ - Simplify given payload filter, if the nested structure might be unnested. - That happens if there is a single clause in that filter. - :param payload_filter: - :returns: - """ - filter_parts = { - "must": payload_filter.must, - "should": payload_filter.should, - "must_not": payload_filter.must_not, - } - - total_clauses = sum(len(x) for x in filter_parts.values() if x is not None) - if total_clauses == 0 or total_clauses > 1: - return payload_filter - - # Payload filter has just a single clause provided (either must, should - # or must_not). If that single clause is also of a models.Filter type, - # then it might be returned instead. - for part_name, filter_part in filter_parts.items(): - if not filter_part: - continue - - subfilter = filter_part[0] - if not isinstance(subfilter, models.Filter): - # The inner statement is a simple condition like models.FieldCondition - # so it cannot be simplified. 
- continue - - if subfilter.must: - return models.Filter(**{part_name: subfilter.must}) +def _build_gt_condition(key: str, value: Union[str, float, int]) -> models.Condition: + if isinstance(value, str) and is_datetime_string(value): + return models.FieldCondition(key=key, range=models.DatetimeRange(gt=value)) + if isinstance(value, (int, float)): + return models.FieldCondition(key=key, range=models.Range(gt=value)) + + msg = f"Value {value} is not an int or float or datetime string" + raise FilterError(msg) + + +def _build_gte_condition(key: str, value: Union[str, float, int]) -> models.Condition: + if isinstance(value, str) and is_datetime_string(value): + return models.FieldCondition(key=key, range=models.DatetimeRange(gte=value)) + + if isinstance(value, (int, float)): + return models.FieldCondition(key=key, range=models.Range(gte=value)) + + msg = f"Value {value} is not an int or float or datetime string" + raise FilterError(msg) + + +def _build_has_id_condition(id_values: List[models.ExtendedPointId]) -> models.HasIdCondition: + return models.HasIdCondition( + has_id=[ + # Ids are converted into their internal representation + convert_id(item) + for item in id_values + ] + ) + + +def _squeeze_filter(payload_filter: models.Filter) -> models.Filter: + """ + Simplify given payload filter, if the nested structure might be unnested. + That happens if there is a single clause in that filter. + :param payload_filter: + :returns: + """ + filter_parts = { + "must": payload_filter.must, + "should": payload_filter.should, + "must_not": payload_filter.must_not, + } + + total_clauses = sum(len(x) for x in filter_parts.values() if x is not None) + if total_clauses == 0 or total_clauses > 1: return payload_filter + # Payload filter has just a single clause provided (either must, should + # or must_not). If that single clause is also of a models.Filter type, + # then it might be returned instead. 
+ for part_name, filter_part in filter_parts.items(): + if not filter_part: + continue + + subfilter = filter_part[0] + if not isinstance(subfilter, models.Filter): + # The inner statement is a simple condition like models.FieldCondition + # so it cannot be simplified. + continue + + if subfilter.must: + return models.Filter(**{part_name: subfilter.must}) + + return payload_filter + def is_datetime_string(value: str) -> bool: try: diff --git a/integrations/qdrant/tests/test_converters.py b/integrations/qdrant/tests/test_converters.py index 0c6c5676a..242c4cafe 100644 --- a/integrations/qdrant/tests/test_converters.py +++ b/integrations/qdrant/tests/test_converters.py @@ -1,36 +1,46 @@ import numpy as np -import pytest -from haystack_integrations.document_stores.qdrant.converters import HaystackToQdrant, QdrantToHaystack +from haystack_integrations.document_stores.qdrant.converters import ( + convert_id, + convert_qdrant_point_to_haystack_document, +) from qdrant_client.http import models as rest -CONTENT_FIELD = "content" -NAME_FIELD = "name" -EMBEDDING_FIELD = "vector" +def test_convert_id_is_deterministic(): + first_id = convert_id("test-id") + second_id = convert_id("test-id") + assert first_id == second_id -@pytest.fixture -def haystack_to_qdrant() -> HaystackToQdrant: - return HaystackToQdrant() +def test_point_to_document_reverts_proper_structure_from_record_with_sparse(): -@pytest.fixture -def qdrant_to_haystack() -> QdrantToHaystack: - return QdrantToHaystack( - content_field=CONTENT_FIELD, - name_field=NAME_FIELD, - embedding_field=EMBEDDING_FIELD, + point = rest.Record( + id="c7c62e8e-02b9-4ec6-9f88-46bd97b628b7", + payload={ + "id": "my-id", + "id_hash_keys": ["content"], + "content": "Lorem ipsum", + "content_type": "text", + "meta": { + "test_field": 1, + }, + }, + vector={ + "text-dense": [1.0, 0.0, 0.0, 0.0], + "text-sparse": {"indices": [7, 1024, 367], "values": [0.1, 0.98, 0.33]}, + }, ) + document = convert_qdrant_point_to_haystack_document(point, 
use_sparse_embeddings=True) + assert "my-id" == document.id + assert "Lorem ipsum" == document.content + assert "text" == document.content_type + assert {"indices": [7, 1024, 367], "values": [0.1, 0.98, 0.33]} == document.sparse_embedding.to_dict() + assert {"test_field": 1} == document.meta + assert 0.0 == np.sum(np.array([1.0, 0.0, 0.0, 0.0]) - document.embedding) -def test_convert_id_is_deterministic(haystack_to_qdrant: HaystackToQdrant): - first_id = haystack_to_qdrant.convert_id("test-id") - second_id = haystack_to_qdrant.convert_id("test-id") - assert first_id == second_id - +def test_point_to_document_reverts_proper_structure_from_record_without_sparse(): -def test_point_to_document_reverts_proper_structure_from_record( - qdrant_to_haystack: QdrantToHaystack, -): point = rest.Record( id="c7c62e8e-02b9-4ec6-9f88-46bd97b628b7", payload={ @@ -44,9 +54,10 @@ def test_point_to_document_reverts_proper_structure_from_record( }, vector=[1.0, 0.0, 0.0, 0.0], ) - document = qdrant_to_haystack.point_to_document(point) + document = convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=False) assert "my-id" == document.id assert "Lorem ipsum" == document.content assert "text" == document.content_type + assert document.sparse_embedding is None assert {"test_field": 1} == document.meta assert 0.0 == np.sum(np.array([1.0, 0.0, 0.0, 0.0]) - document.embedding) diff --git a/integrations/qdrant/tests/test_dict_converters.py b/integrations/qdrant/tests/test_dict_converters.py index 3da64743a..6c8e46710 100644 --- a/integrations/qdrant/tests/test_dict_converters.py +++ b/integrations/qdrant/tests/test_dict_converters.py @@ -25,6 +25,7 @@ def test_to_dict(): "content_field": "content", "name_field": "name", "embedding_field": "embedding", + "use_sparse_embeddings": False, "similarity": "cosine", "return_embedding": False, "progress_bar": True, @@ -63,6 +64,7 @@ def test_from_dict(): "content_field": "content", "name_field": "name", "embedding_field": "embedding", 
+ "use_sparse_embeddings": True, "similarity": "cosine", "return_embedding": False, "progress_bar": True, @@ -86,6 +88,7 @@ def test_from_dict(): document_store.content_field == "content", document_store.name_field == "name", document_store.embedding_field == "embedding", + document_store.use_sparse_embeddings is True, document_store.on_disk is False, document_store.similarity == "cosine", document_store.return_embedding is False, diff --git a/integrations/qdrant/tests/test_retriever.py b/integrations/qdrant/tests/test_retriever.py index 41d9b3088..96e748220 100644 --- a/integrations/qdrant/tests/test_retriever.py +++ b/integrations/qdrant/tests/test_retriever.py @@ -1,17 +1,21 @@ from typing import List -from haystack.dataclasses import Document +import numpy as np +from haystack.dataclasses import Document, SparseEmbedding from haystack.testing.document_store import ( FilterableDocsFixtureMixin, _random_embeddings, ) -from haystack_integrations.components.retrievers.qdrant import QdrantEmbeddingRetriever +from haystack_integrations.components.retrievers.qdrant import ( + QdrantEmbeddingRetriever, + QdrantSparseRetriever, +) from haystack_integrations.document_stores.qdrant import QdrantDocumentStore class TestQdrantRetriever(FilterableDocsFixtureMixin): def test_init_default(self): - document_store = QdrantDocumentStore(location=":memory:", index="test") + document_store = QdrantDocumentStore(location=":memory:", index="test", use_sparse_embeddings=False) retriever = QdrantEmbeddingRetriever(document_store=document_store) assert retriever._document_store == document_store assert retriever._filters is None @@ -19,7 +23,7 @@ def test_init_default(self): assert retriever._return_embedding is False def test_to_dict(self): - document_store = QdrantDocumentStore(location=":memory:", index="test") + document_store = QdrantDocumentStore(location=":memory:", index="test", use_sparse_embeddings=False) retriever = QdrantEmbeddingRetriever(document_store=document_store) res 
= retriever.to_dict() assert res == { @@ -45,6 +49,7 @@ def test_to_dict(self): "content_field": "content", "name_field": "name", "embedding_field": "embedding", + "use_sparse_embeddings": False, "similarity": "cosine", "return_embedding": False, "progress_bar": True, @@ -96,19 +101,154 @@ def test_from_dict(self): assert retriever._return_embedding is True def test_run(self, filterable_docs: List[Document]): - document_store = QdrantDocumentStore(location=":memory:", index="Boi") + document_store = QdrantDocumentStore(location=":memory:", index="Boi", use_sparse_embeddings=False) document_store.write_documents(filterable_docs) retriever = QdrantEmbeddingRetriever(document_store=document_store) - results: List[Document] = retriever.run(query_embedding=_random_embeddings(768)) + results: List[Document] = retriever.run(query_embedding=_random_embeddings(768))["documents"] + assert len(results) == 10 - assert len(results["documents"]) == 10 # type: ignore + results = retriever.run(query_embedding=_random_embeddings(768), top_k=5, return_embedding=False)["documents"] + assert len(results) == 5 - results = retriever.run(query_embedding=_random_embeddings(768), top_k=5, return_embedding=False) + for document in results: + assert document.embedding is None - assert len(results["documents"]) == 5 # type: ignore + def test_run_with_sparse_activated(self, filterable_docs: List[Document]): + document_store = QdrantDocumentStore(location=":memory:", index="Boi", use_sparse_embeddings=True) - for document in results["documents"]: # type: ignore + document_store.write_documents(filterable_docs) + + retriever = QdrantEmbeddingRetriever(document_store=document_store) + + results: List[Document] = retriever.run(query_embedding=_random_embeddings(768))["documents"] + + assert len(results) == 10 + + results = retriever.run(query_embedding=_random_embeddings(768), top_k=5, return_embedding=False)["documents"] + + assert len(results) == 5 + + for document in results: assert 
document.embedding is None + + +class TestQdrantSparseRetriever(FilterableDocsFixtureMixin): + def test_init_default(self): + document_store = QdrantDocumentStore(location=":memory:", index="test") + retriever = QdrantSparseRetriever(document_store=document_store) + assert retriever._document_store == document_store + assert retriever._filters is None + assert retriever._top_k == 10 + assert retriever._return_embedding is False + + def test_to_dict(self): + document_store = QdrantDocumentStore(location=":memory:", index="test") + retriever = QdrantSparseRetriever(document_store=document_store) + res = retriever.to_dict() + assert res == { + "type": "haystack_integrations.components.retrievers.qdrant.retriever.QdrantSparseRetriever", + "init_parameters": { + "document_store": { + "type": "haystack_integrations.document_stores.qdrant.document_store.QdrantDocumentStore", + "init_parameters": { + "location": ":memory:", + "url": None, + "port": 6333, + "grpc_port": 6334, + "prefer_grpc": False, + "https": None, + "api_key": None, + "prefix": None, + "timeout": None, + "host": None, + "path": None, + "index": "test", + "embedding_dim": 768, + "on_disk": False, + "content_field": "content", + "name_field": "name", + "embedding_field": "embedding", + "use_sparse_embeddings": False, + "similarity": "cosine", + "return_embedding": False, + "progress_bar": True, + "duplicate_documents": "overwrite", + "recreate_index": False, + "shard_number": None, + "replication_factor": None, + "write_consistency_factor": None, + "on_disk_payload": None, + "hnsw_config": None, + "optimizers_config": None, + "wal_config": None, + "quantization_config": None, + "init_from": None, + "wait_result_from_api": True, + "metadata": {}, + "write_batch_size": 100, + "scroll_size": 10000, + "payload_fields_to_index": None, + }, + }, + "filters": None, + "top_k": 10, + "scale_score": True, + "return_embedding": False, + }, + } + + def test_from_dict(self): + data = { + "type": 
"haystack_integrations.components.retrievers.qdrant.retriever.QdrantSparseRetriever", + "init_parameters": { + "document_store": { + "init_parameters": {"location": ":memory:", "index": "test"}, + "type": "haystack_integrations.document_stores.qdrant.document_store.QdrantDocumentStore", + }, + "filters": None, + "top_k": 5, + "scale_score": False, + "return_embedding": True, + }, + } + retriever = QdrantSparseRetriever.from_dict(data) + assert isinstance(retriever._document_store, QdrantDocumentStore) + assert retriever._document_store.index == "test" + assert retriever._filters is None + assert retriever._top_k == 5 + assert retriever._scale_score is False + assert retriever._return_embedding is True + + def _generate_mocked_sparse_embedding(self, n): + list_of_sparse_vectors = [] + for _ in range(n): + random_indice_length = np.random.randint(3, 15) + data = { + "indices": list(range(random_indice_length)), + "values": [np.random.random_sample() for _ in range(random_indice_length)], + } + list_of_sparse_vectors.append(data) + return list_of_sparse_vectors + + def test_run(self, filterable_docs: List[Document]): + document_store = QdrantDocumentStore(location=":memory:", index="Boi", use_sparse_embeddings=True) + + # Add fake sparse embedding to documents + for doc in filterable_docs: + doc.sparse_embedding = SparseEmbedding.from_dict(self._generate_mocked_sparse_embedding(1)[0]) + + document_store.write_documents(filterable_docs) + retriever = QdrantSparseRetriever(document_store=document_store) + sparse_embedding = SparseEmbedding(indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33]) + + results: List[Document] = retriever.run(query_sparse_embedding=sparse_embedding)["documents"] + assert len(results) == 10 + + results = retriever.run(query_sparse_embedding=sparse_embedding, top_k=5, return_embedding=True)["documents"] + assert len(results) == 5 + + for document in results: + assert document.sparse_embedding