diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py index 7fe24ab20..3d658c316 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py @@ -7,12 +7,14 @@ from haystack.core.serialization import default_from_dict, default_to_dict from haystack.dataclasses.document import Document +from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types.policy import DuplicatePolicy import weaviate from weaviate.auth import AuthCredentials from weaviate.config import Config, ConnectionConfig from weaviate.embedded import EmbeddedOptions +from weaviate.util import generate_uuid5 Number = Union[int, float] TimeoutType = Union[Tuple[Number, Number], Number] @@ -239,15 +241,145 @@ def _to_document(self, data: Dict[str, Any]) -> Document: return Document.from_dict(data) + def _query(self, properties: List[str], batch_size: int, cursor=None): + collection_name = self._collection_settings["class"] + query = ( + self._client.query.get( + collection_name, + properties, + ) + .with_additional(["id vector"]) + .with_limit(batch_size) + ) + + if cursor: + # Fetch the next set of results + result = query.with_after(cursor).do() + else: + # Fetch the first set of results + result = query.do() + + if "errors" in result: + errors = [e["message"] for e in result.get("errors", {})] + msg = "\n".join(errors) + msg = f"Failed to query documents in Weaviate. 
Errors:\n{msg}" + raise DocumentStoreError(msg) + + return result["data"]["Get"][collection_name] + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: # noqa: ARG002 - return [] + properties = self._client.schema.get(self._collection_settings["class"]).get("properties", []) + properties = [prop["name"] for prop in properties] - def write_documents( - self, - documents: List[Document], # noqa: ARG002 - policy: DuplicatePolicy = DuplicatePolicy.NONE, # noqa: ARG002 - ) -> int: - return 0 + result = [] + + cursor = None + while batch := self._query(properties, 100, cursor): + # Take the cursor before we convert the batch to Documents as we manipulate + # the batch dictionary and might lose that information. + cursor = batch[-1]["_additional"]["id"] + + for doc in batch: + result.append(self._to_document(doc)) + # Move the cursor to the last returned uuid + return result + + def _batch_write(self, documents: List[Document]) -> int: + """ + Writes documents to Weaviate in batches. + Documents with the same id will be overwritten. + Raises in case of errors. + """ + statuses = [] + for doc in documents: + if not isinstance(doc, Document): + msg = f"Expected a Document, got '{type(doc)}' instead." 
+ raise ValueError(msg) + if self._client.batch.num_objects() == self._client.batch.recommended_num_objects: + # Batch is full, let's create the objects + statuses.extend(self._client.batch.create_objects()) + self._client.batch.add_data_object( + uuid=generate_uuid5(doc.id), + data_object=self._to_data_object(doc), + class_name=self._collection_settings["class"], + vector=doc.embedding, + ) + # Write remaining documents + statuses.extend(self._client.batch.create_objects()) + + errors = [] + # Gather errors and number of written documents + for status in statuses: + result_status = status.get("result", {}).get("status") + if result_status == "FAILED": + errors.extend([e["message"] for e in status["result"]["errors"]["error"]]) + + if errors: + msg = "\n".join(errors) + msg = f"Failed to write documents in Weaviate. Errors:\n{msg}" + raise DocumentStoreError(msg) + + # If the document already exists we get no status message back from Weaviate. + # So we assume that all Documents were written. + return len(documents) + + def _write(self, documents: List[Document], policy: DuplicatePolicy) -> int: + """ + Writes documents to Weaviate using the specified policy. + This doesn't use the batch API, so it's slower than _batch_write. + If policy is set to SKIP it will skip any document that already exists. + If policy is set to FAIL it will raise an exception if any of the documents already exist. + """ + written = 0 + duplicate_errors_ids = [] + for doc in documents: + if not isinstance(doc, Document): + msg = f"Expected a Document, got '{type(doc)}' instead." 
+ raise ValueError(msg) - def delete_documents(self, document_ids: List[str]) -> None: # noqa: ARG002 - return + if policy == DuplicatePolicy.SKIP and self._client.data_object.exists( + uuid=generate_uuid5(doc.id), + class_name=self._collection_settings["class"], + ): + # This Document already exists, we skip it + continue + + try: + self._client.data_object.create( + uuid=generate_uuid5(doc.id), + data_object=self._to_data_object(doc), + class_name=self._collection_settings["class"], + vector=doc.embedding, + ) + written += 1 + except weaviate.exceptions.ObjectAlreadyExistsException: + if policy == DuplicatePolicy.FAIL: + duplicate_errors_ids.append(doc.id) + if duplicate_errors_ids: + msg = f"IDs '{', '.join(duplicate_errors_ids)}' already exist in the document store." + raise DuplicateDocumentError(msg) + return written + + def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int: + """ + Writes documents to Weaviate using the specified policy. + We recommend using an OVERWRITE policy as it's faster than other policies for Weaviate since it uses + the batch API. + We can't use the batch API for other policies as it doesn't return any information whether the document + already exists or not. That prevents us from returning errors when using the FAIL policy or skipping a + Document when using the SKIP policy. 
+ """ + if policy in [DuplicatePolicy.NONE, DuplicatePolicy.OVERWRITE]: + return self._batch_write(documents) + + return self._write(documents, policy) + + def delete_documents(self, document_ids: List[str]) -> None: + self._client.batch.delete_objects( + class_name=self._collection_settings["class"], + where={ + "path": ["id"], + "operator": "ContainsAny", + "valueTextArray": [generate_uuid5(doc_id) for doc_id in document_ids], + }, + ) diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py index e988eb297..0682282f3 100644 --- a/integrations/weaviate/tests/test_document_store.py +++ b/integrations/weaviate/tests/test_document_store.py @@ -4,7 +4,7 @@ import pytest from haystack.dataclasses.byte_stream import ByteStream from haystack.dataclasses.document import Document -from haystack.testing.document_store import CountDocumentsTest +from haystack.testing.document_store import CountDocumentsTest, DeleteDocumentsTest, WriteDocumentsTest from haystack_integrations.document_stores.weaviate.document_store import ( DOCUMENT_COLLECTION_PROPERTIES, WeaviateDocumentStore, @@ -20,7 +20,7 @@ ) -class TestWeaviateDocumentStore(CountDocumentsTest): +class TestWeaviateDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest): @pytest.fixture def document_store(self, request) -> WeaviateDocumentStore: # Use a different index for each test so we can run them in parallel @@ -256,3 +256,30 @@ def test_to_document(self, document_store, test_files_path): assert doc.embedding == [1, 2, 3] assert doc.score is None assert doc.meta == {"key": "value"} + + def test_write_documents(self, document_store): + """ + Test write_documents() with default policy overwrites existing documents. 
+ """ + doc = Document(content="test doc") + assert document_store.write_documents([doc]) == 1 + assert document_store.count_documents() == 1 + + doc.content = "test doc 2" + assert document_store.write_documents([doc]) == 1 + assert document_store.count_documents() == 1 + + def test_write_documents_with_blob_data(self, document_store, test_files_path): + image = ByteStream.from_file_path(test_files_path / "robot1.jpg", mime_type="image/jpeg") + doc = Document(content="test doc", blob=image) + assert document_store.write_documents([doc]) == 1 + + def test_filter_documents_with_blob_data(self, document_store, test_files_path): + image = ByteStream.from_file_path(test_files_path / "robot1.jpg", mime_type="image/jpeg") + doc = Document(content="test doc", blob=image) + assert document_store.write_documents([doc]) == 1 + + docs = document_store.filter_documents() + + assert len(docs) == 1 + assert docs[0].blob == image