From be40dcb068902984b1d85cbbb62c87bc8bcee05a Mon Sep 17 00:00:00 2001
From: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Date: Wed, 14 Feb 2024 14:53:17 +0100
Subject: [PATCH] Add WeaviateEmbeddingRetriever (#412)

---
 .../retrievers/weaviate/__init__.py           |   3 +-
 .../weaviate/embedding_retriever.py           | 100 ++++++++++++++
 .../weaviate/document_store.py                |  40 ++++++
 .../weaviate/tests/test_document_store.py     |  71 ++++++++++
 .../tests/test_embedding_retriever.py         | 130 ++++++++++++++++++
 5 files changed, 343 insertions(+), 1 deletion(-)
 create mode 100644 integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/embedding_retriever.py
 create mode 100644 integrations/weaviate/tests/test_embedding_retriever.py

diff --git a/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/__init__.py b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/__init__.py
index cf596f0cb..34bfd0c7d 100644
--- a/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/__init__.py
+++ b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/__init__.py
@@ -1,3 +1,4 @@
 from .bm25_retriever import WeaviateBM25Retriever
+from .embedding_retriever import WeaviateEmbeddingRetriever
 
-__all__ = ["WeaviateBM25Retriever"]
+__all__ = ["WeaviateBM25Retriever", "WeaviateEmbeddingRetriever"]
diff --git a/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/embedding_retriever.py b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/embedding_retriever.py
new file mode 100644
index 000000000..b8a163b56
--- /dev/null
+++ b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/embedding_retriever.py
@@ -0,0 +1,100 @@
+from typing import Any, Dict, List, Optional
+
+from haystack import Document, component, default_from_dict, default_to_dict
+from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore
+
+
+@component
+class WeaviateEmbeddingRetriever:
+    """
+    A retriever that uses Weaviate's vector search to find similar documents based on the embeddings of the query.
+    """
+
+    def __init__(
+        self,
+        *,
+        document_store: WeaviateDocumentStore,
+        filters: Optional[Dict[str, Any]] = None,
+        top_k: int = 10,
+        distance: Optional[float] = None,
+        certainty: Optional[float] = None,
+    ):
+        """
+        Create a new instance of WeaviateEmbeddingRetriever.
+        Raises ValueError if both `distance` and `certainty` are provided.
+        See the official Weaviate documentation to learn more about the `distance` and `certainty` parameters:
+        https://weaviate.io/developers/weaviate/api/graphql/search-operators#variables
+
+        :param document_store: Instance of WeaviateDocumentStore that will be associated with this retriever.
+        :param filters: Custom filters applied when running the retriever, defaults to None
+        :param top_k: Maximum number of documents to return, defaults to 10
+        :param distance: The maximum allowed distance between Documents' embeddings, defaults to None
+        :param certainty: Normalized distance between the result item and the search vector, defaults to None
+        """
+        if distance is not None and certainty is not None:
+            msg = "Can't use 'distance' and 'certainty' parameters together"
+            raise ValueError(msg)
+
+        self._document_store = document_store
+        self._filters = filters or {}
+        self._top_k = top_k
+        self._distance = distance
+        self._certainty = certainty
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serialize this retriever, including its document store, to a dictionary.
+        """
+        return default_to_dict(
+            self,
+            filters=self._filters,
+            top_k=self._top_k,
+            distance=self._distance,
+            certainty=self._certainty,
+            document_store=self._document_store.to_dict(),
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "WeaviateEmbeddingRetriever":
+        """
+        Rebuild a retriever from its dictionary form; the nested document store is deserialized first.
+        """
+        data["init_parameters"]["document_store"] = WeaviateDocumentStore.from_dict(
+            data["init_parameters"]["document_store"]
+        )
+        return default_from_dict(cls, data)
+
+    @component.output_types(documents=List[Document])
+    def run(
+        self,
+        query_embedding: List[float],
+        filters: Optional[Dict[str, Any]] = None,
+        top_k: Optional[int] = None,
+        distance: Optional[float] = None,
+        certainty: Optional[float] = None,
+    ):
+        """
+        Retrieve documents from the document store using the given query embedding.
+        Arguments that are omitted fall back to the values passed at initialization.
+
+        :param query_embedding: Embedding of the query.
+        :param filters: Filters applied to the search, defaults to the filters set at initialization
+        :param top_k: Maximum number of documents to return, defaults to the value set at initialization
+        :param distance: The maximum allowed distance between Documents' embeddings
+        :param certainty: Normalized distance between the result item and the search vector
+        :return: A dictionary with the retrieved documents under the `documents` key.
+        """
+        filters = filters or self._filters
+        top_k = top_k or self._top_k
+        # Compare against None: 0.0 is a meaningful threshold and must not fall back to the init value.
+        distance = distance if distance is not None else self._distance
+        certainty = certainty if certainty is not None else self._certainty
+        documents = self._document_store._embedding_retrieval(
+            query_embedding=query_embedding,
+            filters=filters,
+            top_k=top_k,
+            distance=distance,
+            certainty=certainty,
+        )
+        # The output is a dict keyed by the name declared in `output_types` above.
+        return {"documents": documents}
diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py
index 90391fb0f..38f0b38cd 100644
--- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py
+++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py
@@ -442,3 +442,43 @@ def _bm25_retrieval(
         result = query_builder.do()
 
         return [self._to_document(doc) for doc in result["data"]["Get"][collection_name]]
+
+    def _embedding_retrieval(
+        self,
+        query_embedding: List[float],
+        filters: Optional[Dict[str, Any]] = None,
+        top_k: Optional[int] = None,
+        distance: Optional[float] = None,
+        certainty: Optional[float] = None,
+    ) -> List[Document]:
+        if distance is not None and certainty is not None:
+            msg = "Can't use 'distance' and 'certainty' parameters together"
+            raise ValueError(msg)
+
+        collection_name = self._collection_settings["class"]
+        properties = self._client.schema.get(self._collection_settings["class"]).get("properties", [])
+        properties = [prop["name"] for prop in properties]
+
+        near_vector: Dict[str, Union[float, List[float]]] = {
+            "vector": query_embedding,
+        }
+        if distance is not None:
+            near_vector["distance"] = distance
+
+        if certainty is not None:
+            near_vector["certainty"] = certainty
+
+        query_builder = (
+            self._client.query.get(collection_name, properties=properties)
+            .with_near_vector(near_vector)
+            .with_additional(["vector"])
+        )
+
+        if filters:
+            query_builder = query_builder.with_where(convert_filters(filters))
+
+        if top_k:
+            query_builder = query_builder.with_limit(top_k)
+
+        result = query_builder.do()
+        return [self._to_document(doc) for doc in result["data"]["Get"][collection_name]]
diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py
index 13cc92258..359af3670 100644
--- a/integrations/weaviate/tests/test_document_store.py
+++ b/integrations/weaviate/tests/test_document_store.py
@@ -548,3 +548,74 @@ def test_bm25_retrieval_with_topk(self, document_store):
         assert "functional" in result[0].content
         assert "functional" in result[1].content
         assert "functional" in result[2].content
+
+    def test_embedding_retrieval(self, document_store):
+        document_store.write_documents(
+            [
+                Document(
+                    content="Yet another document",
+                    embedding=[0.00001, 0.00001, 0.00001, 0.00002],
+                ),
+                Document(content="The document", embedding=[1.0, 1.0, 1.0, 1.0]),
+                Document(content="Another document", embedding=[0.8, 0.8, 0.8, 1.0]),
+            ]
+        )
+        result = document_store._embedding_retrieval(query_embedding=[1.0, 1.0, 1.0, 1.0])
+        assert len(result) == 3
+        assert "The document" == result[0].content
+        assert "Another document" == result[1].content
+        assert "Yet another document" == result[2].content
+
+    def test_embedding_retrieval_with_filters(self, document_store):
+        document_store.write_documents(
+            [
+                Document(
+                    content="Yet another document",
+                    embedding=[0.00001, 0.00001, 0.00001, 0.00002],
+                ),
+                Document(content="The document I want", embedding=[1.0, 1.0, 1.0, 1.0]),
+                Document(content="Another document", embedding=[0.8, 0.8, 0.8, 1.0]),
+            ]
+        )
+        filters = {"field": "content", "operator": "==", "value": "The document I want"}
+        result = document_store._embedding_retrieval(query_embedding=[1.0, 1.0, 1.0, 1.0], filters=filters)
+        assert len(result) == 1
+        assert "The document I want" == result[0].content
+
+    def test_embedding_retrieval_with_topk(self, document_store):
+        docs = [
+            Document(content="The document", embedding=[1.0, 1.0, 1.0, 1.0]),
+            Document(content="Another document", embedding=[0.8, 0.8, 0.8, 1.0]),
+            Document(content="Yet another document", embedding=[0.00001, 0.00001, 0.00001, 0.00002]),
+        ]
+        document_store.write_documents(docs)
+        results = document_store._embedding_retrieval(query_embedding=[1.0, 1.0, 1.0, 1.0], top_k=2)
+        assert len(results) == 2
+        assert results[0].content == "The document"
+        assert results[1].content == "Another document"
+
+    def test_embedding_retrieval_with_distance(self, document_store):
+        docs = [
+            Document(content="The document", embedding=[1.0, 1.0, 1.0, 1.0]),
+            Document(content="Another document", embedding=[0.8, 0.8, 0.8, 1.0]),
+            Document(content="Yet another document", embedding=[0.00001, 0.00001, 0.00001, 0.00002]),
+        ]
+        document_store.write_documents(docs)
+        results = document_store._embedding_retrieval(query_embedding=[1.0, 1.0, 1.0, 1.0], distance=0.0)
+        assert len(results) == 1
+        assert results[0].content == "The document"
+
+    def test_embedding_retrieval_with_certainty(self, document_store):
+        docs = [
+            Document(content="The document", embedding=[1.0, 1.0, 1.0, 1.0]),
+            Document(content="Another document", embedding=[0.8, 0.8, 0.8, 1.0]),
+            Document(content="Yet another document", embedding=[0.00001, 0.00001, 0.00001, 0.00002]),
+        ]
+        document_store.write_documents(docs)
+        results = document_store._embedding_retrieval(query_embedding=[0.8, 0.8, 0.8, 1.0], certainty=1.0)
+        assert len(results) == 1
+        assert results[0].content == "Another document"
+
+    def test_embedding_retrieval_with_distance_and_certainty(self, document_store):
+        with pytest.raises(ValueError):
+            document_store._embedding_retrieval(query_embedding=[], distance=0.1, certainty=0.1)
diff --git a/integrations/weaviate/tests/test_embedding_retriever.py b/integrations/weaviate/tests/test_embedding_retriever.py
new file mode 100644
index 000000000..7f07d8a24
--- /dev/null
+++ b/integrations/weaviate/tests/test_embedding_retriever.py
@@ -0,0 +1,130 @@
+from unittest.mock import Mock, patch
+
+import pytest
+from haystack_integrations.components.retrievers.weaviate import WeaviateEmbeddingRetriever
+from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore
+
+
+def test_init_default():
+    mock_document_store = Mock(spec=WeaviateDocumentStore)
+    retriever = WeaviateEmbeddingRetriever(document_store=mock_document_store)
+    assert retriever._document_store == mock_document_store
+    assert retriever._filters == {}
+    assert retriever._top_k == 10
+    assert retriever._distance is None
+    assert retriever._certainty is None
+
+
+def test_init_with_distance_and_certainty():
+    mock_document_store = Mock(spec=WeaviateDocumentStore)
+    with pytest.raises(ValueError):
+        WeaviateEmbeddingRetriever(document_store=mock_document_store, distance=0.1, certainty=0.8)
+
+
+@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate")
+def test_to_dict(_mock_weaviate):
+    document_store = WeaviateDocumentStore()
+    retriever = WeaviateEmbeddingRetriever(document_store=document_store)
+    assert retriever.to_dict() == {
+        "type": "haystack_integrations.components.retrievers.weaviate.embedding_retriever.WeaviateEmbeddingRetriever",
+        "init_parameters": {
+            "filters": {},
+            "top_k": 10,
+            "distance": None,
+            "certainty": None,
+            "document_store": {
+                "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore",
+                "init_parameters": {
+                    "url": None,
+                    "collection_settings": {
+                        "class": "Default",
+                        "invertedIndexConfig": {"indexNullState": True},
+                        "properties": [
+                            {"name": "_original_id", "dataType": ["text"]},
+                            {"name": "content", "dataType": ["text"]},
+                            {"name": "dataframe", "dataType": ["text"]},
+                            {"name": "blob_data", "dataType": ["blob"]},
+                            {"name": "blob_mime_type", "dataType": ["text"]},
+                            {"name": "score", "dataType": ["number"]},
+                        ],
+                    },
+                    "auth_client_secret": None,
+                    "timeout_config": (10, 60),
+                    "proxies": None,
+                    "trust_env": False,
+                    "additional_headers": None,
+                    "startup_period": 5,
+                    "embedded_options": None,
+                    "additional_config": None,
+                },
+            },
+        },
+    }
+
+
+@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate")
+def test_from_dict(_mock_weaviate):
+    retriever = WeaviateEmbeddingRetriever.from_dict(
+        {
+            "type": "haystack_integrations.components.retrievers.weaviate.embedding_retriever.WeaviateEmbeddingRetriever",  # noqa: E501
+            "init_parameters": {
+                "filters": {},
+                "top_k": 10,
+                "distance": None,
+                "certainty": None,
+                "document_store": {
+                    "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore",
+                    "init_parameters": {
+                        "url": None,
+                        "collection_settings": {
+                            "class": "Default",
+                            "invertedIndexConfig": {"indexNullState": True},
+                            "properties": [
+                                {"name": "_original_id", "dataType": ["text"]},
+                                {"name": "content", "dataType": ["text"]},
+                                {"name": "dataframe", "dataType": ["text"]},
+                                {"name": "blob_data", "dataType": ["blob"]},
+                                {"name": "blob_mime_type", "dataType": ["text"]},
+                                {"name": "score", "dataType": ["number"]},
+                            ],
+                        },
+                        "auth_client_secret": None,
+                        "timeout_config": (10, 60),
+                        "proxies": None,
+                        "trust_env": False,
+                        "additional_headers": None,
+                        "startup_period": 5,
+                        "embedded_options": None,
+                        "additional_config": None,
+                    },
+                },
+            },
+        }
+    )
+    assert retriever._document_store
+    assert retriever._filters == {}
+    assert retriever._top_k == 10
+    assert retriever._distance is None
+    assert retriever._certainty is None
+
+
+@patch("haystack_integrations.components.retrievers.weaviate.embedding_retriever.WeaviateDocumentStore")
+def test_run(mock_document_store):
+    retriever = WeaviateEmbeddingRetriever(document_store=mock_document_store)
+    query_embedding = [0.1, 0.1, 0.1, 0.1]
+    filters = {"field": "content", "operator": "==", "value": "Some text"}
+    result = retriever.run(query_embedding=query_embedding, filters=filters, top_k=5, distance=0.1, certainty=0.1)
+    mock_document_store._embedding_retrieval.assert_called_once_with(
+        query_embedding=query_embedding, filters=filters, top_k=5, distance=0.1, certainty=0.1
+    )
+    assert result == {"documents": mock_document_store._embedding_retrieval.return_value}
+
+
+@patch("haystack_integrations.components.retrievers.weaviate.embedding_retriever.WeaviateDocumentStore")
+def test_run_keeps_explicit_zero_distance(mock_document_store):
+    retriever = WeaviateEmbeddingRetriever(document_store=mock_document_store, distance=0.5)
+    query_embedding = [0.1, 0.1, 0.1, 0.1]
+    retriever.run(query_embedding=query_embedding, distance=0.0)
+    mock_document_store._embedding_retrieval.assert_called_once_with(
+        query_embedding=query_embedding, filters={}, top_k=10, distance=0.0, certainty=None
+    )