From 810ad842e42ec816cca5c177308284624c0d24a0 Mon Sep 17 00:00:00 2001 From: Corentin Date: Thu, 7 Mar 2024 10:17:52 +0100 Subject: [PATCH] feat(Qdrant): allow payload indexing + on disk vectors (#553) * feat(qdrant): allow payload indexing * fix(qdrant): fix tests * fix(qdrant): lint fix * fix(qdrant): black formatting * fix(qdrant): rename `payload_field_to_index` to `payload_fields_to_index` --- .../document_stores/qdrant/document_store.py | 32 ++++++++++++++++--- .../qdrant/tests/test_dict_converters.py | 6 ++++ integrations/qdrant/tests/test_retriever.py | 2 ++ 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py index 645db88ae..dc22673fa 100644 --- a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py +++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py @@ -63,6 +63,7 @@ def __init__( path: Optional[str] = None, index: str = "Document", embedding_dim: int = 768, + on_disk: bool = False, # noqa: FBT001, FBT002 content_field: str = "content", name_field: str = "name", embedding_field: str = "embedding", @@ -84,6 +85,7 @@ def __init__( metadata: Optional[dict] = None, write_batch_size: int = 100, scroll_size: int = 10_000, + payload_fields_to_index: Optional[List[dict]] = None, ): super().__init__() @@ -130,11 +132,13 @@ def __init__( self.init_from = init_from self.wait_result_from_api = wait_result_from_api self.recreate_index = recreate_index + self.payload_fields_to_index = payload_fields_to_index # Make sure the collection is properly set up - self._set_up_collection(index, embedding_dim, recreate_index, similarity) + self._set_up_collection(index, embedding_dim, recreate_index, similarity, on_disk, payload_fields_to_index) self.embedding_dim = embedding_dim + self.on_disk = on_disk self.content_field = content_field self.name_field = name_field self.embedding_field = embedding_field @@ -334,19 +338,36 @@ def _get_distance(self, similarity: str) -> rest.Distance: ) raise QdrantStoreError(msg) from ke + def _create_payload_index(self, collection_name: str, payload_fields_to_index: Optional[List[dict]] = None): + """ + Create payload index for the collection if payload_fields_to_index is provided + See: https://qdrant.tech/documentation/concepts/indexing/#payload-index + """ + if payload_fields_to_index is not None: + for payload_index in payload_fields_to_index: + self.client.create_payload_index( + collection_name=collection_name, + field_name=payload_index["field_name"], + field_schema=payload_index["field_schema"], + ) + def _set_up_collection( self, collection_name: str, embedding_dim: int, recreate_collection: bool, # noqa: FBT001 similarity: str, + on_disk: bool = False, # noqa: FBT001, FBT002 + payload_fields_to_index: Optional[List[dict]] = None, ): distance = self._get_distance(similarity) if recreate_collection: # There is no need to verify the current configuration of that # collection. It might be just recreated again. - self._recreate_collection(collection_name, distance, embedding_dim) + self._recreate_collection(collection_name, distance, embedding_dim, on_disk) + # Create Payload index if payload_fields_to_index is provided + self._create_payload_index(collection_name, payload_fields_to_index) return try: @@ -360,7 +381,9 @@ def _set_up_collection( # Qdrant local raises ValueError if the collection is not found, but # with the remote server UnexpectedResponse / RpcError is raised. # Until that's unified, we need to catch both. - self._recreate_collection(collection_name, distance, embedding_dim) + self._recreate_collection(collection_name, distance, embedding_dim, on_disk) + # Create Payload index if payload_fields_to_index is provided + self._create_payload_index(collection_name, payload_fields_to_index) return current_distance = collection_info.config.params.vectors.distance @@ -384,11 +407,12 @@ def _set_up_collection( ) raise ValueError(msg) - def _recreate_collection(self, collection_name: str, distance, embedding_dim: int): + def _recreate_collection(self, collection_name: str, distance, embedding_dim: int, on_disk: bool): # noqa: FBT001 self.client.recreate_collection( collection_name=collection_name, vectors_config=rest.VectorParams( size=embedding_dim, + on_disk=on_disk, distance=distance, ), shard_number=self.shard_number, diff --git a/integrations/qdrant/tests/test_dict_converters.py b/integrations/qdrant/tests/test_dict_converters.py index 18940fbbf..3da64743a 100644 --- a/integrations/qdrant/tests/test_dict_converters.py +++ b/integrations/qdrant/tests/test_dict_converters.py @@ -21,6 +21,7 @@ def test_to_dict(): "path": None, "index": "test", "embedding_dim": 768, + "on_disk": False, "content_field": "content", "name_field": "name", "embedding_field": "embedding", @@ -42,6 +43,7 @@ def test_to_dict(): "metadata": {}, "write_batch_size": 100, "scroll_size": 10000, + "payload_fields_to_index": None, }, } @@ -57,6 +59,7 @@ def test_from_dict(): "location": ":memory:", "index": "test", "embedding_dim": 768, + "on_disk": False, "content_field": "content", "name_field": "name", "embedding_field": "embedding", @@ -72,6 +75,7 @@ def test_from_dict(): "metadata": {}, "write_batch_size": 1000, "scroll_size": 10000, + "payload_fields_to_index": None, }, } ) @@ -82,6 +86,7 @@ def test_from_dict(): document_store.content_field == "content", document_store.name_field == "name", document_store.embedding_field == "embedding", + document_store.on_disk is False, document_store.similarity == "cosine", document_store.return_embedding is False, document_store.progress_bar, @@ -101,5 +106,6 @@ def test_from_dict(): document_store.write_batch_size == 1000, document_store.scroll_size == 10000, document_store.api_key == Secret.from_env_var("ENV_VAR", strict=False), + document_store.payload_fields_to_index is None, ] ) diff --git a/integrations/qdrant/tests/test_retriever.py b/integrations/qdrant/tests/test_retriever.py index 7521642ff..41d9b3088 100644 --- a/integrations/qdrant/tests/test_retriever.py +++ b/integrations/qdrant/tests/test_retriever.py @@ -41,6 +41,7 @@ def test_to_dict(self): "path": None, "index": "test", "embedding_dim": 768, + "on_disk": False, "content_field": "content", "name_field": "name", "embedding_field": "embedding", @@ -62,6 +63,7 @@ def test_to_dict(self): "metadata": {}, "write_batch_size": 100, "scroll_size": 10000, + "payload_fields_to_index": None, }, }, "filters": None,