Skip to content

Commit

Permalink
Add WeaviateEmbeddingRetriever (#412)
Browse files Browse the repository at this point in the history
  • Loading branch information
silvanocerza authored Feb 14, 2024
1 parent ead07b3 commit be40dcb
Show file tree
Hide file tree
Showing 5 changed files with 312 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .bm25_retriever import WeaviateBM25Retriever
from .embedding_retriever import WeaviateEmbeddingRetriever

__all__ = ["WeaviateBM25Retriever"]
__all__ = ["WeaviateBM25Retriever", "WeaviateEmbeddingRetriever"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import Any, Dict, List, Optional

from haystack import Document, component, default_from_dict, default_to_dict
from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore


@component
class WeaviateEmbeddingRetriever:
"""
A retriever that uses Weaviate's vector search to find similar documents based on the embeddings of the query.
"""

def __init__(
self,
*,
document_store: WeaviateDocumentStore,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
distance: Optional[float] = None,
certainty: Optional[float] = None,
):
"""
Create a new instance of WeaviateEmbeddingRetriever.
Raises ValueError if both `distance` and `certainty` are provided.
See the official Weaviate documentation to learn more about the `distance` and `certainty` parameters:
https://weaviate.io/developers/weaviate/api/graphql/search-operators#variables
:param document_store: Instance of WeaviateDocumentStore that will be associated with this retriever.
:param filters: Custom filters applied when running the retriever, defaults to None
:param top_k: Maximum number of documents to return, defaults to 10
:param distance: The maximum allowed distance between Documents' embeddings, defaults to None
:param certainty: Normalized distance between the result item and the search vector, defaults to None
"""
if distance is not None and certainty is not None:
msg = "Can't use 'distance' and 'certainty' parameters together"
raise ValueError(msg)

self._document_store = document_store
self._filters = filters or {}
self._top_k = top_k
self._distance = distance
self._certainty = certainty

def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
filters=self._filters,
top_k=self._top_k,
distance=self._distance,
certainty=self._certainty,
document_store=self._document_store.to_dict(),
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "WeaviateEmbeddingRetriever":
data["init_parameters"]["document_store"] = WeaviateDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(
self,
query_embedding: List[float],
filters: Optional[Dict[str, Any]] = None,
top_k: Optional[int] = None,
distance: Optional[float] = None,
certainty: Optional[float] = None,
):
filters = filters or self._filters
top_k = top_k or self._top_k
distance = distance or self._distance
certainty = certainty or self._certainty
return self._document_store._embedding_retrieval(
query_embedding=query_embedding,
filters=filters,
top_k=top_k,
distance=distance,
certainty=certainty,
)
Original file line number Diff line number Diff line change
Expand Up @@ -442,3 +442,43 @@ def _bm25_retrieval(
result = query_builder.do()

return [self._to_document(doc) for doc in result["data"]["Get"][collection_name]]

def _embedding_retrieval(
self,
query_embedding: List[float],
filters: Optional[Dict[str, Any]] = None,
top_k: Optional[int] = None,
distance: Optional[float] = None,
certainty: Optional[float] = None,
) -> List[Document]:
if distance is not None and certainty is not None:
msg = "Can't use 'distance' and 'certainty' parameters together"
raise ValueError(msg)

collection_name = self._collection_settings["class"]
properties = self._client.schema.get(self._collection_settings["class"]).get("properties", [])
properties = [prop["name"] for prop in properties]

near_vector: Dict[str, Union[float, List[float]]] = {
"vector": query_embedding,
}
if distance is not None:
near_vector["distance"] = distance

if certainty is not None:
near_vector["certainty"] = certainty

query_builder = (
self._client.query.get(collection_name, properties=properties)
.with_near_vector(near_vector)
.with_additional(["vector"])
)

if filters:
query_builder = query_builder.with_where(convert_filters(filters))

if top_k:
query_builder = query_builder.with_limit(top_k)

result = query_builder.do()
return [self._to_document(doc) for doc in result["data"]["Get"][collection_name]]
71 changes: 71 additions & 0 deletions integrations/weaviate/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,3 +548,74 @@ def test_bm25_retrieval_with_topk(self, document_store):
assert "functional" in result[0].content
assert "functional" in result[1].content
assert "functional" in result[2].content

def test_embedding_retrieval(self, document_store):
document_store.write_documents(
[
Document(
content="Yet another document",
embedding=[0.00001, 0.00001, 0.00001, 0.00002],
),
Document(content="The document", embedding=[1.0, 1.0, 1.0, 1.0]),
Document(content="Another document", embedding=[0.8, 0.8, 0.8, 1.0]),
]
)
result = document_store._embedding_retrieval(query_embedding=[1.0, 1.0, 1.0, 1.0])
assert len(result) == 3
assert "The document" == result[0].content
assert "Another document" == result[1].content
assert "Yet another document" == result[2].content

def test_embedding_retrieval_with_filters(self, document_store):
document_store.write_documents(
[
Document(
content="Yet another document",
embedding=[0.00001, 0.00001, 0.00001, 0.00002],
),
Document(content="The document I want", embedding=[1.0, 1.0, 1.0, 1.0]),
Document(content="Another document", embedding=[0.8, 0.8, 0.8, 1.0]),
]
)
filters = {"field": "content", "operator": "==", "value": "The document I want"}
result = document_store._embedding_retrieval(query_embedding=[1.0, 1.0, 1.0, 1.0], filters=filters)
assert len(result) == 1
assert "The document I want" == result[0].content

def test_embedding_retrieval_with_topk(self, document_store):
docs = [
Document(content="The document", embedding=[1.0, 1.0, 1.0, 1.0]),
Document(content="Another document", embedding=[0.8, 0.8, 0.8, 1.0]),
Document(content="Yet another document", embedding=[0.00001, 0.00001, 0.00001, 0.00002]),
]
document_store.write_documents(docs)
results = document_store._embedding_retrieval(query_embedding=[1.0, 1.0, 1.0, 1.0], top_k=2)
assert len(results) == 2
assert results[0].content == "The document"
assert results[1].content == "Another document"

def test_embedding_retrieval_with_distance(self, document_store):
docs = [
Document(content="The document", embedding=[1.0, 1.0, 1.0, 1.0]),
Document(content="Another document", embedding=[0.8, 0.8, 0.8, 1.0]),
Document(content="Yet another document", embedding=[0.00001, 0.00001, 0.00001, 0.00002]),
]
document_store.write_documents(docs)
results = document_store._embedding_retrieval(query_embedding=[1.0, 1.0, 1.0, 1.0], distance=0.0)
assert len(results) == 1
assert results[0].content == "The document"

def test_embedding_retrieval_with_certainty(self, document_store):
docs = [
Document(content="The document", embedding=[1.0, 1.0, 1.0, 1.0]),
Document(content="Another document", embedding=[0.8, 0.8, 0.8, 1.0]),
Document(content="Yet another document", embedding=[0.00001, 0.00001, 0.00001, 0.00002]),
]
document_store.write_documents(docs)
results = document_store._embedding_retrieval(query_embedding=[0.8, 0.8, 0.8, 1.0], certainty=1.0)
assert len(results) == 1
assert results[0].content == "Another document"

def test_embedding_retrieval_with_distance_and_certainty(self, document_store):
with pytest.raises(ValueError):
document_store._embedding_retrieval(query_embedding=[], distance=0.1, certainty=0.1)
119 changes: 119 additions & 0 deletions integrations/weaviate/tests/test_embedding_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from unittest.mock import Mock, patch

import pytest
from haystack_integrations.components.retrievers.weaviate import WeaviateEmbeddingRetriever
from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore


def test_init_default():
mock_document_store = Mock(spec=WeaviateDocumentStore)
retriever = WeaviateEmbeddingRetriever(document_store=mock_document_store)
assert retriever._document_store == mock_document_store
assert retriever._filters == {}
assert retriever._top_k == 10
assert retriever._distance is None
assert retriever._certainty is None


def test_init_with_distance_and_certainty():
mock_document_store = Mock(spec=WeaviateDocumentStore)
with pytest.raises(ValueError):
WeaviateEmbeddingRetriever(document_store=mock_document_store, distance=0.1, certainty=0.8)


@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate")
def test_to_dict(_mock_weaviate):
document_store = WeaviateDocumentStore()
retriever = WeaviateEmbeddingRetriever(document_store=document_store)
assert retriever.to_dict() == {
"type": "haystack_integrations.components.retrievers.weaviate.embedding_retriever.WeaviateEmbeddingRetriever",
"init_parameters": {
"filters": {},
"top_k": 10,
"distance": None,
"certainty": None,
"document_store": {
"type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore",
"init_parameters": {
"url": None,
"collection_settings": {
"class": "Default",
"invertedIndexConfig": {"indexNullState": True},
"properties": [
{"name": "_original_id", "dataType": ["text"]},
{"name": "content", "dataType": ["text"]},
{"name": "dataframe", "dataType": ["text"]},
{"name": "blob_data", "dataType": ["blob"]},
{"name": "blob_mime_type", "dataType": ["text"]},
{"name": "score", "dataType": ["number"]},
],
},
"auth_client_secret": None,
"timeout_config": (10, 60),
"proxies": None,
"trust_env": False,
"additional_headers": None,
"startup_period": 5,
"embedded_options": None,
"additional_config": None,
},
},
},
}


@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate")
def test_from_dict(_mock_weaviate):
retriever = WeaviateEmbeddingRetriever.from_dict(
{
"type": "haystack_integrations.components.retrievers.weaviate.embedding_retriever.WeaviateEmbeddingRetriever", # noqa: E501
"init_parameters": {
"filters": {},
"top_k": 10,
"distance": None,
"certainty": None,
"document_store": {
"type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore",
"init_parameters": {
"url": None,
"collection_settings": {
"class": "Default",
"invertedIndexConfig": {"indexNullState": True},
"properties": [
{"name": "_original_id", "dataType": ["text"]},
{"name": "content", "dataType": ["text"]},
{"name": "dataframe", "dataType": ["text"]},
{"name": "blob_data", "dataType": ["blob"]},
{"name": "blob_mime_type", "dataType": ["text"]},
{"name": "score", "dataType": ["number"]},
],
},
"auth_client_secret": None,
"timeout_config": (10, 60),
"proxies": None,
"trust_env": False,
"additional_headers": None,
"startup_period": 5,
"embedded_options": None,
"additional_config": None,
},
},
},
}
)
assert retriever._document_store
assert retriever._filters == {}
assert retriever._top_k == 10
assert retriever._distance is None
assert retriever._certainty is None


@patch("haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateDocumentStore")
def test_run(mock_document_store):
retriever = WeaviateEmbeddingRetriever(document_store=mock_document_store)
query_embedding = [0.1, 0.1, 0.1, 0.1]
filters = {"field": "content", "operator": "==", "value": "Some text"}
retriever.run(query_embedding=query_embedding, filters=filters, top_k=5, distance=0.1, certainty=0.1)
mock_document_store._embedding_retrieval.assert_called_once_with(
query_embedding=query_embedding, filters=filters, top_k=5, distance=0.1, certainty=0.1
)

0 comments on commit be40dcb

Please sign in to comment.