Skip to content

Commit

Permalink
Add filter, write and delete documents in Weaviate (#270)
Browse files Browse the repository at this point in the history
* Add filter, write and delete documents in Weaviate

* Fix linting

* Fix typo
  • Loading branch information
silvanocerza authored Jan 26, 2024
1 parent 4ebedd4 commit d7a66db
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@

from haystack.core.serialization import default_from_dict, default_to_dict
from haystack.dataclasses.document import Document
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
from haystack.document_stores.types.policy import DuplicatePolicy

import weaviate
from weaviate.auth import AuthCredentials
from weaviate.config import Config, ConnectionConfig
from weaviate.embedded import EmbeddedOptions
from weaviate.util import generate_uuid5

# A plain numeric value: either an int or a float.
Number = Union[int, float]
# Either a single Number or a pair of Numbers.
# NOTE(review): presumably a (connect, read) timeout pair as accepted by the Weaviate client - confirm.
TimeoutType = Union[Tuple[Number, Number], Number]
Expand Down Expand Up @@ -239,15 +241,145 @@ def _to_document(self, data: Dict[str, Any]) -> Document:

return Document.from_dict(data)

def _query(self, properties: List[str], batch_size: int, cursor=None):
    """
    Fetches a single page of objects from the configured Weaviate collection.

    :param properties: Names of the properties to retrieve for each object.
    :param batch_size: Maximum number of objects to return in this page.
    :param cursor: Value handed to `with_after()` to resume after a previously
        returned object. When falsy the first page is fetched.
    :returns: The list found under `data.Get.<collection name>` in the response.
    :raises DocumentStoreError: If the response contains an `errors` entry.
    """
    class_name = self._collection_settings["class"]

    builder = self._client.query.get(class_name, properties)
    # Ask for the object id and its vector on top of the regular properties.
    builder = builder.with_additional(["id vector"])
    builder = builder.with_limit(batch_size)
    if cursor:
        # Resume right after the object identified by the cursor.
        builder = builder.with_after(cursor)
    response = builder.do()

    if "errors" not in response:
        return response["data"]["Get"][class_name]

    details = "\n".join(error["message"] for error in response.get("errors", {}))
    msg = f"Failed to query documents in Weaviate. Errors:\n{details}"
    raise DocumentStoreError(msg)

def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:  # noqa: ARG002
    """
    Returns all the Documents stored in the collection.

    The collection is read page by page through `_query`, using the id of the
    last object of each page as the cursor for the next one.

    :param filters: Accepted for interface compatibility but not applied:
        every Document in the collection is returned.
    :returns: The Documents converted with `_to_document`, in the order the
        pages were returned.
    :raises DocumentStoreError: Propagated from `_query` when Weaviate reports errors.
    """
    collection = self._client.schema.get(self._collection_settings["class"])
    properties = [prop["name"] for prop in collection.get("properties", [])]

    result: List[Document] = []
    cursor = None
    # An empty page is falsy and ends the loop.
    while batch := self._query(properties, 100, cursor):
        # Take the cursor (id of the last returned object) before we convert the
        # batch to Documents, as the conversion manipulates the dictionaries and
        # might lose that information.
        cursor = batch[-1]["_additional"]["id"]
        result.extend(self._to_document(obj) for obj in batch)
    return result

def _batch_write(self, documents: List[Document]) -> int:
    """
    Writes documents to Weaviate in batches.
    Documents with the same id will be overwritten.
    Raises in case of errors.

    :param documents: The Documents to write.
    :returns: The number of Documents passed in. Weaviate returns no status for
        Documents that already exist, so all of them are assumed to be written.
    :raises ValueError: If any item of `documents` is not a Document. Nothing is
        queued or written in that case.
    :raises DocumentStoreError: If Weaviate reports a failure for any object.
    """
    # Validate everything up front: failing in the middle of the loop would leave
    # the previous Documents queued in the client batch, ready to be flushed by
    # an unrelated later call.
    for doc in documents:
        if not isinstance(doc, Document):
            msg = f"Expected a Document, got '{type(doc)}' instead."
            raise ValueError(msg)

    batch = self._client.batch
    class_name = self._collection_settings["class"]

    statuses = []
    for doc in documents:
        if batch.num_objects() == batch.recommended_num_objects:
            # Batch is full, let's create the objects
            statuses.extend(batch.create_objects())
        batch.add_data_object(
            uuid=generate_uuid5(doc.id),
            data_object=self._to_data_object(doc),
            class_name=class_name,
            vector=doc.embedding,
        )
    # Write remaining documents
    statuses.extend(batch.create_objects())

    # Gather the error messages of every object that failed to be written
    errors = [
        error["message"]
        for status in statuses
        if status.get("result", {}).get("status") == "FAILED"
        for error in status["result"]["errors"]["error"]
    ]
    if errors:
        msg = "\n".join(errors)
        msg = f"Failed to write documents in Weaviate. Errors:\n{msg}"
        raise DocumentStoreError(msg)

    # If the document already exists we get no status message back from Weaviate.
    # So we assume that all Documents were written.
    return len(documents)

def _write(self, documents: List[Document], policy: DuplicatePolicy) -> int:
    """
    Writes documents to Weaviate using the specified policy.
    This doesn't use the batch API, so it's slower than _batch_write.
    If policy is set to SKIP it will skip any document that already exists.
    If policy is set to FAIL it will raise an exception if any of the documents already exists.

    :param documents: The Documents to write.
    :param policy: How to handle Documents that already exist in the store.
    :returns: The number of Documents actually created.
    :raises ValueError: If any item of `documents` is not a Document. Nothing is
        written in that case.
    :raises DuplicateDocumentError: With the FAIL policy, after all Documents have
        been processed, if at least one of them already existed.
    """
    # Validate everything before the first write so an invalid item can't leave
    # the store partially updated.
    for doc in documents:
        if not isinstance(doc, Document):
            msg = f"Expected a Document, got '{type(doc)}' instead."
            raise ValueError(msg)

    class_name = self._collection_settings["class"]
    written = 0
    duplicate_errors_ids = []
    for doc in documents:
        uuid = generate_uuid5(doc.id)

        if policy == DuplicatePolicy.SKIP and self._client.data_object.exists(
            uuid=uuid,
            class_name=class_name,
        ):
            # This Document already exists, we skip it
            continue

        try:
            self._client.data_object.create(
                uuid=uuid,
                data_object=self._to_data_object(doc),
                class_name=class_name,
                vector=doc.embedding,
            )
        except weaviate.exceptions.ObjectAlreadyExistsException:
            # Duplicates are only reported with the FAIL policy, otherwise they
            # are silently left untouched.
            if policy == DuplicatePolicy.FAIL:
                duplicate_errors_ids.append(doc.id)
        else:
            written += 1

    if duplicate_errors_ids:
        msg = f"IDs '{', '.join(duplicate_errors_ids)}' already exist in the document store."
        raise DuplicateDocumentError(msg)
    return written

def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int:
    """
    Writes documents to Weaviate honouring the given duplicate policy.

    The NONE and OVERWRITE policies go through the batch API (`_batch_write`) and
    are the fastest option, so OVERWRITE is the recommended policy for Weaviate.
    The batch API gives no indication of whether a Document already existed, so
    it can neither report duplicates for the FAIL policy nor skip them for the
    SKIP policy: those policies are handled one Document at a time by `_write`.

    :param documents: The Documents to write.
    :param policy: How to handle Documents that already exist in the store.
    :returns: The count returned by the writer that handled the call.
    """
    batch_policies = (DuplicatePolicy.NONE, DuplicatePolicy.OVERWRITE)
    if policy not in batch_policies:
        return self._write(documents, policy)
    return self._batch_write(documents)

def delete_documents(self, document_ids: List[str]) -> None:
    """
    Deletes the Documents with the given ids from the collection.

    The ids are converted with `generate_uuid5`, the same conversion used when
    the Documents are written, and all matching objects are removed with a
    single batch delete.

    :param document_ids: Ids of the Documents to delete. An empty list is a no-op.
    """
    if not document_ids:
        # Nothing to delete: don't send a request whose filter has no values.
        return

    self._client.batch.delete_objects(
        class_name=self._collection_settings["class"],
        where={
            "path": ["id"],
            "operator": "ContainsAny",
            "valueTextArray": [generate_uuid5(doc_id) for doc_id in document_ids],
        },
    )
31 changes: 29 additions & 2 deletions integrations/weaviate/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from haystack.dataclasses.byte_stream import ByteStream
from haystack.dataclasses.document import Document
from haystack.testing.document_store import CountDocumentsTest
from haystack.testing.document_store import CountDocumentsTest, DeleteDocumentsTest, WriteDocumentsTest
from haystack_integrations.document_stores.weaviate.document_store import (
DOCUMENT_COLLECTION_PROPERTIES,
WeaviateDocumentStore,
Expand All @@ -20,7 +20,7 @@
)


class TestWeaviateDocumentStore(CountDocumentsTest):
class TestWeaviateDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest):
@pytest.fixture
def document_store(self, request) -> WeaviateDocumentStore:
# Use a different index for each test so we can run them in parallel
Expand Down Expand Up @@ -256,3 +256,30 @@ def test_to_document(self, document_store, test_files_path):
assert doc.embedding == [1, 2, 3]
assert doc.score is None
assert doc.meta == {"key": "value"}

def test_write_documents(self, document_store):
    """
    Writing with the default policy overwrites a Document that is already stored.
    """
    document = Document(content="test doc")

    written = document_store.write_documents([document])
    assert written == 1
    assert document_store.count_documents() == 1

    # Write the very same Document object again, this time with different content:
    # the store must still hold a single Document.
    document.content = "test doc 2"
    written = document_store.write_documents([document])
    assert written == 1
    assert document_store.count_documents() == 1

def test_write_documents_with_blob_data(self, document_store, test_files_path):
    """
    A Document carrying binary blob data can be written to the store.
    """
    image_path = test_files_path / "robot1.jpg"
    blob = ByteStream.from_file_path(image_path, mime_type="image/jpeg")
    document = Document(content="test doc", blob=blob)

    written = document_store.write_documents([document])
    assert written == 1

def test_filter_documents_with_blob_data(self, document_store, test_files_path):
    """
    Blob data written with a Document is returned unchanged by filter_documents().
    """
    image_path = test_files_path / "robot1.jpg"
    blob = ByteStream.from_file_path(image_path, mime_type="image/jpeg")
    document = Document(content="test doc", blob=blob)
    written = document_store.write_documents([document])
    assert written == 1

    retrieved = document_store.filter_documents()

    assert len(retrieved) == 1
    assert retrieved[0].blob == blob

0 comments on commit d7a66db

Please sign in to comment.