From ead07b33a3b14a00c8e15a55dfcff5988a24b8ba Mon Sep 17 00:00:00 2001
From: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Date: Wed, 14 Feb 2024 14:48:24 +0100
Subject: [PATCH] Add WeaviateBM25Retriever (#410)

---
 .../retrievers/weaviate/__init__.py           |   3 +
 .../retrievers/weaviate/bm25_retriever.py     |  73 +++++++++++++
 .../weaviate/document_store.py                |  32 ++++++
 .../weaviate/tests/test_bm25_retriever.py     | 102 ++++++++++++++++++
 .../weaviate/tests/test_document_store.py     |  67 ++++++++++++
 5 files changed, 277 insertions(+)
 create mode 100644 integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/__init__.py
 create mode 100644 integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/bm25_retriever.py
 create mode 100644 integrations/weaviate/tests/test_bm25_retriever.py

diff --git a/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/__init__.py b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/__init__.py
new file mode 100644
index 000000000..cf596f0cb
--- /dev/null
+++ b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/__init__.py
@@ -0,0 +1,3 @@
+from .bm25_retriever import WeaviateBM25Retriever
+
+__all__ = ["WeaviateBM25Retriever"]
diff --git a/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/bm25_retriever.py b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/bm25_retriever.py
new file mode 100644
index 000000000..6c27378cf
--- /dev/null
+++ b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/bm25_retriever.py
@@ -0,0 +1,73 @@
+from typing import Any, Dict, List, Optional
+
+from haystack import Document, component, default_from_dict, default_to_dict
+from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore
+
+
+@component
+class WeaviateBM25Retriever:
+    """
+    Retriever that uses BM25 to find the most promising documents for a given query.
+    """
+
+    def __init__(
+        self,
+        *,
+        document_store: WeaviateDocumentStore,
+        filters: Optional[Dict[str, Any]] = None,
+        top_k: int = 10,
+    ):
+        """
+        Create a new instance of WeaviateBM25Retriever.
+
+        :param document_store: Instance of WeaviateDocumentStore that will be associated with this retriever.
+        :param filters: Custom filters applied when running the retriever, defaults to None
+        :param top_k: Maximum number of documents to return, defaults to 10
+        """
+        self._document_store = document_store
+        self._filters = filters or {}
+        self._top_k = top_k
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serialize this retriever, including its document store, to a dictionary.
+
+        :returns: Dictionary built by ``default_to_dict`` from the init parameters.
+        """
+        return default_to_dict(
+            self,
+            filters=self._filters,
+            top_k=self._top_k,
+            document_store=self._document_store.to_dict(),
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "WeaviateBM25Retriever":
+        """
+        Create a retriever from a dictionary produced by ``to_dict``.
+
+        The nested document store entry of ``data`` is replaced in place with a WeaviateDocumentStore instance.
+
+        :param data: Serialized retriever.
+        :returns: The deserialized retriever.
+        """
+        data["init_parameters"]["document_store"] = WeaviateDocumentStore.from_dict(
+            data["init_parameters"]["document_store"]
+        )
+        return default_from_dict(cls, data)
+
+    @component.output_types(documents=List[Document])
+    def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None):
+        """
+        Retrieve the documents that best match ``query`` using BM25.
+
+        :param query: Text to search for.
+        :param filters: Filters for this call; a falsy value falls back to the filters given at init.
+        :param top_k: Maximum number of documents for this call; a falsy value falls back to the init value.
+        :returns: Dictionary with the retrieved documents under the ``documents`` key.
+        """
+        filters = filters or self._filters
+        top_k = top_k or self._top_k
+        documents = self._document_store._bm25_retrieval(query=query, filters=filters, top_k=top_k)
+        # Return a dict keyed by the name declared in `output_types`, not the bare list.
+        return {"documents": documents}
diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py
index b7aba3716..90391fb0f 100644
--- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py
+++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py
@@ -419,3 +419,35 @@ def delete_documents(self, document_ids: List[str]) -> None:
                 "valueTextArray": [generate_uuid5(doc_id) for doc_id in document_ids],
             },
         )
+
+    def _bm25_retrieval(
+        self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None
+    ) -> List[Document]:
+        """
+        Run a BM25 keyword search for `query` over the `content` property of the collection.
+
+        :param query: Text to search for.
+        :param filters: Optional filters, converted with `convert_filters`; skipped when falsy.
+        :param top_k: Optional limit on the number of results; no limit is set when falsy.
+        :returns: The matching objects converted to Documents.
+        """
+        collection_name = self._collection_settings["class"]
+        # Ask for every property listed in the collection schema.
+        schema_properties = self._client.schema.get(collection_name).get("properties", [])
+        properties = [prop["name"] for prop in schema_properties]
+
+        query_builder = (
+            self._client.query.get(collection_name, properties=properties)
+            .with_bm25(query=query, properties=["content"])
+            .with_additional(["vector"])
+        )
+
+        if filters:
+            query_builder = query_builder.with_where(convert_filters(filters))
+
+        if top_k:
+            query_builder = query_builder.with_limit(top_k)
+
+        result = query_builder.do()
+
+        return [self._to_document(doc) for doc in result["data"]["Get"][collection_name]]
diff --git a/integrations/weaviate/tests/test_bm25_retriever.py b/integrations/weaviate/tests/test_bm25_retriever.py
new file mode 100644
index 000000000..83f90735b
--- /dev/null
+++ b/integrations/weaviate/tests/test_bm25_retriever.py
@@ -0,0 +1,102 @@
+from unittest.mock import Mock, patch
+
+from haystack_integrations.components.retrievers.weaviate import WeaviateBM25Retriever
+from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore
+
+
+def test_init_default():
+    mock_document_store = Mock(spec=WeaviateDocumentStore)
+    retriever = WeaviateBM25Retriever(document_store=mock_document_store)
+    assert retriever._document_store == mock_document_store
+    assert retriever._filters == {}
+    assert retriever._top_k == 10
+
+
+@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate")
+def test_to_dict(_mock_weaviate):
+    document_store = WeaviateDocumentStore()
+    retriever = WeaviateBM25Retriever(document_store=document_store)
+    assert retriever.to_dict() == {
+        "type": "haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateBM25Retriever",
+        "init_parameters": {
+            "filters": {},
+            "top_k": 10,
+            "document_store": {
+                "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore",
+                "init_parameters": {
+                    "url": None,
+                    "collection_settings": {
+                        "class": "Default",
+                        "invertedIndexConfig": {"indexNullState": True},
+                        "properties": [
+                            {"name": "_original_id", "dataType": ["text"]},
+                            {"name": "content", "dataType": ["text"]},
+                            {"name": "dataframe", "dataType": ["text"]},
+                            {"name": "blob_data", "dataType": ["blob"]},
+                            {"name": "blob_mime_type", "dataType": ["text"]},
+                            {"name": "score", "dataType": ["number"]},
+                        ],
+                    },
+                    "auth_client_secret": None,
+                    "timeout_config": (10, 60),
+                    "proxies": None,
+                    "trust_env": False,
+                    "additional_headers": None,
+                    "startup_period": 5,
+                    "embedded_options": None,
+                    "additional_config": None,
+                },
+            },
+        },
+    }
+
+
+@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate")
+def test_from_dict(_mock_weaviate):
+    retriever = WeaviateBM25Retriever.from_dict(
+        {
+            "type": "haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateBM25Retriever",
+            "init_parameters": {
+                "filters": {},
+                "top_k": 10,
+                "document_store": {
+                    "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore",
+                    "init_parameters": {
+                        "url": None,
+                        "collection_settings": {
+                            "class": "Default",
+                            "invertedIndexConfig": {"indexNullState": True},
+                            "properties": [
+                                {"name": "_original_id", "dataType": ["text"]},
+                                {"name": "content", "dataType": ["text"]},
+                                {"name": "dataframe", "dataType": ["text"]},
+                                {"name": "blob_data", "dataType": ["blob"]},
+                                {"name": "blob_mime_type", "dataType": ["text"]},
+                                {"name": "score", "dataType": ["number"]},
+                            ],
+                        },
+                        "auth_client_secret": None,
+                        "timeout_config": (10, 60),
+                        "proxies": None,
+                        "trust_env": False,
+                        "additional_headers": None,
+                        "startup_period": 5,
+                        "embedded_options": None,
+                        "additional_config": None,
+                    },
+                },
+            },
+        }
+    )
+    assert retriever._document_store
+    assert retriever._filters == {}
+    assert retriever._top_k == 10
+
+
+@patch("haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateDocumentStore")
+def test_run(mock_document_store):  # run() must forward query, filters and top_k to the store unchanged
+    retriever = WeaviateBM25Retriever(document_store=mock_document_store)
+    query = "some query"
+    filters = {"field": "content", "operator": "==", "value": "Some text"}
+    retriever.run(query=query, filters=filters, top_k=5)
+    mock_document_store._bm25_retrieval.assert_called_once_with(query=query, filters=filters, top_k=5)
diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py
index 2322a9484..13cc92258 100644
--- a/integrations/weaviate/tests/test_document_store.py
+++ b/integrations/weaviate/tests/test_document_store.py
@@ -481,3 +481,70 @@ def test_comparison_less_than_equal_with_iso_date(self, document_store, filterab
     @pytest.mark.skip(reason="Weaviate for some reason is not returning what we expect")
     def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs):
         return super().test_comparison_not_equal_with_dataframe(document_store, filterable_docs)
+
+    def test_bm25_retrieval(self, document_store):  # default call: no filters, no top_k
+        document_store.write_documents(
+            [
+                Document(content="Haskell is a functional programming language"),
+                Document(content="Lisp is a functional programming language"),
+                Document(content="Exilir is a functional programming language"),
+                Document(content="F# is a functional programming language"),
+                Document(content="C# is a functional programming language"),
+                Document(content="C++ is an object oriented programming language"),
+                Document(content="Dart is an object oriented programming language"),
+                Document(content="Go is an object oriented programming language"),
+                Document(content="Python is a object oriented programming language"),
+                Document(content="Ruby is a object oriented programming language"),
+                Document(content="PHP is a object oriented programming language"),
+            ]
+        )
+        result = document_store._bm25_retrieval("functional Haskell")
+        assert len(result) == 5  # five of the eleven documents written above contain "functional"
+        assert "functional" in result[0].content
+        assert "functional" in result[1].content
+        assert "functional" in result[2].content
+        assert "functional" in result[3].content
+        assert "functional" in result[4].content
+
+    def test_bm25_retrieval_with_filters(self, document_store):  # same corpus, narrowed by a content filter
+        document_store.write_documents(
+            [
+                Document(content="Haskell is a functional programming language"),
+                Document(content="Lisp is a functional programming language"),
+                Document(content="Exilir is a functional programming language"),
+                Document(content="F# is a functional programming language"),
+                Document(content="C# is a functional programming language"),
+                Document(content="C++ is an object oriented programming language"),
+                Document(content="Dart is an object oriented programming language"),
+                Document(content="Go is an object oriented programming language"),
+                Document(content="Python is a object oriented programming language"),
+                Document(content="Ruby is a object oriented programming language"),
+                Document(content="PHP is a object oriented programming language"),
+            ]
+        )
+        filters = {"field": "content", "operator": "==", "value": "Haskell"}
+        result = document_store._bm25_retrieval("functional Haskell", filters=filters)
+        assert len(result) == 1  # NOTE(review): "==" on content presumably matches per token - confirm
+        assert "Haskell is a functional programming language" == result[0].content
+
+    def test_bm25_retrieval_with_topk(self, document_store):  # same corpus, result count capped by top_k
+        document_store.write_documents(
+            [
+                Document(content="Haskell is a functional programming language"),
+                Document(content="Lisp is a functional programming language"),
+                Document(content="Exilir is a functional programming language"),
+                Document(content="F# is a functional programming language"),
+                Document(content="C# is a functional programming language"),
+                Document(content="C++ is an object oriented programming language"),
+                Document(content="Dart is an object oriented programming language"),
+                Document(content="Go is an object oriented programming language"),
+                Document(content="Python is a object oriented programming language"),
+                Document(content="Ruby is a object oriented programming language"),
+                Document(content="PHP is a object oriented programming language"),
+            ]
+        )
+        result = document_store._bm25_retrieval("functional Haskell", top_k=3)
+        assert len(result) == 3  # limited by top_k
+        assert "functional" in result[0].content
+        assert "functional" in result[1].content
+        assert "functional" in result[2].content