Skip to content

Commit

Permalink
Add WeaviateBM25Retriever (#410)
Browse files Browse the repository at this point in the history
  • Loading branch information
silvanocerza authored Feb 14, 2024
1 parent 2f451e9 commit ead07b3
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .bm25_retriever import WeaviateBM25Retriever

__all__ = ["WeaviateBM25Retriever"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Any, Dict, List, Optional

from haystack import Document, component, default_from_dict, default_to_dict
from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore


@component
class WeaviateBM25Retriever:
"""
Retriever that uses BM25 to find the most promising documents for a given query.
"""

def __init__(
self,
*,
document_store: WeaviateDocumentStore,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
):
"""
Create a new instance of WeaviateBM25Retriever.
:param document_store: Instance of WeaviateDocumentStore that will be associated with this retriever.
:param filters: Custom filters applied when running the retriever, defaults to None
:param top_k: Maximum number of documents to return, defaults to 10
"""
self._document_store = document_store
self._filters = filters or {}
self._top_k = top_k

def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
filters=self._filters,
top_k=self._top_k,
document_store=self._document_store.to_dict(),
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "WeaviateBM25Retriever":
data["init_parameters"]["document_store"] = WeaviateDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None):
filters = filters or self._filters
top_k = top_k or self._top_k
return self._document_store._bm25_retrieval(query=query, filters=filters, top_k=top_k)
Original file line number Diff line number Diff line change
Expand Up @@ -419,3 +419,26 @@ def delete_documents(self, document_ids: List[str]) -> None:
"valueTextArray": [generate_uuid5(doc_id) for doc_id in document_ids],
},
)

def _bm25_retrieval(
self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None
) -> List[Document]:
collection_name = self._collection_settings["class"]
properties = self._client.schema.get(self._collection_settings["class"]).get("properties", [])
properties = [prop["name"] for prop in properties]

query_builder = (
self._client.query.get(collection_name, properties=properties)
.with_bm25(query=query, properties=["content"])
.with_additional(["vector"])
)

if filters:
query_builder = query_builder.with_where(convert_filters(filters))

if top_k:
query_builder = query_builder.with_limit(top_k)

result = query_builder.do()

return [self._to_document(doc) for doc in result["data"]["Get"][collection_name]]
102 changes: 102 additions & 0 deletions integrations/weaviate/tests/test_bm25_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from unittest.mock import Mock, patch

from haystack_integrations.components.retrievers.weaviate import WeaviateBM25Retriever
from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore


def test_init_default():
mock_document_store = Mock(spec=WeaviateDocumentStore)
retriever = WeaviateBM25Retriever(document_store=mock_document_store)
assert retriever._document_store == mock_document_store
assert retriever._filters == {}
assert retriever._top_k == 10


@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate")
def test_to_dict(_mock_weaviate):
document_store = WeaviateDocumentStore()
retriever = WeaviateBM25Retriever(document_store=document_store)
assert retriever.to_dict() == {
"type": "haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateBM25Retriever",
"init_parameters": {
"filters": {},
"top_k": 10,
"document_store": {
"type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore",
"init_parameters": {
"url": None,
"collection_settings": {
"class": "Default",
"invertedIndexConfig": {"indexNullState": True},
"properties": [
{"name": "_original_id", "dataType": ["text"]},
{"name": "content", "dataType": ["text"]},
{"name": "dataframe", "dataType": ["text"]},
{"name": "blob_data", "dataType": ["blob"]},
{"name": "blob_mime_type", "dataType": ["text"]},
{"name": "score", "dataType": ["number"]},
],
},
"auth_client_secret": None,
"timeout_config": (10, 60),
"proxies": None,
"trust_env": False,
"additional_headers": None,
"startup_period": 5,
"embedded_options": None,
"additional_config": None,
},
},
},
}


@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate")
def test_from_dict(_mock_weaviate):
retriever = WeaviateBM25Retriever.from_dict(
{
"type": "haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateBM25Retriever",
"init_parameters": {
"filters": {},
"top_k": 10,
"document_store": {
"type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore",
"init_parameters": {
"url": None,
"collection_settings": {
"class": "Default",
"invertedIndexConfig": {"indexNullState": True},
"properties": [
{"name": "_original_id", "dataType": ["text"]},
{"name": "content", "dataType": ["text"]},
{"name": "dataframe", "dataType": ["text"]},
{"name": "blob_data", "dataType": ["blob"]},
{"name": "blob_mime_type", "dataType": ["text"]},
{"name": "score", "dataType": ["number"]},
],
},
"auth_client_secret": None,
"timeout_config": (10, 60),
"proxies": None,
"trust_env": False,
"additional_headers": None,
"startup_period": 5,
"embedded_options": None,
"additional_config": None,
},
},
},
}
)
assert retriever._document_store
assert retriever._filters == {}
assert retriever._top_k == 10


@patch("haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateDocumentStore")
def test_run(mock_document_store):
retriever = WeaviateBM25Retriever(document_store=mock_document_store)
query = "some query"
filters = {"field": "content", "operator": "==", "value": "Some text"}
retriever.run(query=query, filters=filters, top_k=5)
mock_document_store._bm25_retrieval.assert_called_once_with(query=query, filters=filters, top_k=5)
67 changes: 67 additions & 0 deletions integrations/weaviate/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,3 +481,70 @@ def test_comparison_less_than_equal_with_iso_date(self, document_store, filterab
@pytest.mark.skip(reason="Weaviate for some reason is not returning what we expect")
def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs):
return super().test_comparison_not_equal_with_dataframe(document_store, filterable_docs)

def test_bm25_retrieval(self, document_store):
document_store.write_documents(
[
Document(content="Haskell is a functional programming language"),
Document(content="Lisp is a functional programming language"),
Document(content="Exilir is a functional programming language"),
Document(content="F# is a functional programming language"),
Document(content="C# is a functional programming language"),
Document(content="C++ is an object oriented programming language"),
Document(content="Dart is an object oriented programming language"),
Document(content="Go is an object oriented programming language"),
Document(content="Python is a object oriented programming language"),
Document(content="Ruby is a object oriented programming language"),
Document(content="PHP is a object oriented programming language"),
]
)
result = document_store._bm25_retrieval("functional Haskell")
assert len(result) == 5
assert "functional" in result[0].content
assert "functional" in result[1].content
assert "functional" in result[2].content
assert "functional" in result[3].content
assert "functional" in result[4].content

def test_bm25_retrieval_with_filters(self, document_store):
document_store.write_documents(
[
Document(content="Haskell is a functional programming language"),
Document(content="Lisp is a functional programming language"),
Document(content="Exilir is a functional programming language"),
Document(content="F# is a functional programming language"),
Document(content="C# is a functional programming language"),
Document(content="C++ is an object oriented programming language"),
Document(content="Dart is an object oriented programming language"),
Document(content="Go is an object oriented programming language"),
Document(content="Python is a object oriented programming language"),
Document(content="Ruby is a object oriented programming language"),
Document(content="PHP is a object oriented programming language"),
]
)
filters = {"field": "content", "operator": "==", "value": "Haskell"}
result = document_store._bm25_retrieval("functional Haskell", filters=filters)
assert len(result) == 1
assert "Haskell is a functional programming language" == result[0].content

def test_bm25_retrieval_with_topk(self, document_store):
document_store.write_documents(
[
Document(content="Haskell is a functional programming language"),
Document(content="Lisp is a functional programming language"),
Document(content="Exilir is a functional programming language"),
Document(content="F# is a functional programming language"),
Document(content="C# is a functional programming language"),
Document(content="C++ is an object oriented programming language"),
Document(content="Dart is an object oriented programming language"),
Document(content="Go is an object oriented programming language"),
Document(content="Python is a object oriented programming language"),
Document(content="Ruby is a object oriented programming language"),
Document(content="PHP is a object oriented programming language"),
]
)
result = document_store._bm25_retrieval("functional Haskell", top_k=3)
assert len(result) == 3
assert "functional" in result[0].content
assert "functional" in result[1].content
assert "functional" in result[2].content

0 comments on commit ead07b3

Please sign in to comment.