From e749bdd85b292b1f553c97ab097941e4ce40d7ec Mon Sep 17 00:00:00 2001 From: alperkaya Date: Thu, 17 Oct 2024 11:31:50 +0200 Subject: [PATCH] add new testcases and remove filter --- .../mongodb_atlas/fulltext_retriever.py | 27 +---- .../mongodb_atlas/document_store.py | 5 - .../tests/test_full_text_retriever.py | 64 +---------- .../tests/test_fulltext_retrieval.py | 100 ++++++++++++++++++ 4 files changed, 104 insertions(+), 92 deletions(-) create mode 100644 integrations/mongodb_atlas/tests/test_fulltext_retrieval.py diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/fulltext_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/fulltext_retriever.py index 373185f37..a98c8f1b4 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/fulltext_retriever.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/fulltext_retriever.py @@ -2,8 +2,6 @@ from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import Document -from haystack.document_stores.types import FilterPolicy -from haystack.document_stores.types.filter_policy import apply_filter_policy from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore @@ -16,20 +14,14 @@ def __init__( *, document_store: MongoDBAtlasDocumentStore, search_path: Union[str, List[str]] = "content", - filters: Optional[Dict[str, Any]] = None, top_k: int = 10, - filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, ): """ Create the MongoDBAtlasFullTextRetriever component. :param document_store: An instance of MongoDBAtlasDocumentStore. :param search_path: Field(s) to search within, e.g., "content" or ["content", "title"]. - :param filters: Filters applied to the retrieved Documents. Make sure that the fields used in the filters are - included in the configuration of the `vector_search_index`. The configuration must be done manually - in the Web UI of MongoDB Atlas. :param top_k: Maximum number of Documents to return. - :param filter_policy: Policy to determine how filters are applied. :raises ValueError: If `document_store` is not an instance of `MongoDBAtlasDocumentStore`. """ @@ -38,12 +30,8 @@ def __init__( raise ValueError(msg) self.document_store = document_store - self.filters = filters or {} self.top_k = top_k self.search_path = search_path - self.filter_policy = ( - filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) - ) def to_dict(self) -> Dict[str, Any]: """ @@ -54,9 +42,7 @@ def to_dict(self) -> Dict[str, Any]: """ return default_to_dict( self, - filters=self.filters, top_k=self.top_k, - filter_policy=self.filter_policy.value, document_store=self.document_store.to_dict(), ) @@ -73,34 +59,23 @@ def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasFullTextRetriever": data["init_parameters"]["document_store"] = MongoDBAtlasDocumentStore.from_dict( data["init_parameters"]["document_store"] ) - # Pipelines serialized with old versions of the component might not - # have the filter_policy field. - if filter_policy := data["init_parameters"].get("filter_policy"): - data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) def run( self, query: str, - filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None, ) -> Dict[str, List[Document]]: """ Retrieve documents from the MongoDBAtlasDocumentStore, based on the provided query. :param query: Text query. - :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on - the `filter_policy` chosen at retriever initialization. See init method docstring for more - details. :param top_k: Maximum number of Documents to return. Overrides the value specified at initialization. :returns: A dictionary with the following keys: - `documents`: List of Documents most similar to the given `query` """ - filters = apply_filter_policy(self.filter_policy, self.filters, filters) top_k = top_k or self.top_k - docs = self.document_store._fulltext_retrieval( - query=query, filters=filters, top_k=top_k, search_path=self.search_path - ) + docs = self.document_store._fulltext_retrieval(query=query, top_k=top_k, search_path=self.search_path) return {"documents": docs} diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index 3a4a240b6..080c15736 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -230,7 +230,6 @@ def _fulltext_retrieval( self, query: str, search_path: Union[str, List[str]] = "content", - filters: Optional[Dict[str, Any]] = None, top_k: int = 10, ) -> List[Document]: """ @@ -238,7 +237,6 @@ def _fulltext_retrieval( :param query: The text to search in the document store. :param search_path: Field(s) to search within, e.g., "content" or ["content", "title"]. - :param filters: Optional filters. :param top_k: How many documents to return. :returns: A list of Documents matching the full-text search query. :raises ValueError: If `query` is empty. @@ -248,8 +246,6 @@ def _fulltext_retrieval( msg = "query must not be empty" raise ValueError(msg) - filters = _normalize_filters(filters) if filters else {} - pipeline = [ { "$search": { @@ -260,7 +256,6 @@ def _fulltext_retrieval( }, } }, - {"$match": filters if filters else {}}, {"$limit": top_k}, {"$project": {"_id": 0, "content": 1, "meta": 1, "score": {"$meta": "searchScore"}}}, ] diff --git a/integrations/mongodb_atlas/tests/test_full_text_retriever.py b/integrations/mongodb_atlas/tests/test_full_text_retriever.py index b7b6c8710..41ab5e9c0 100644 --- a/integrations/mongodb_atlas/tests/test_full_text_retriever.py +++ b/integrations/mongodb_atlas/tests/test_full_text_retriever.py @@ -2,7 +2,6 @@ import pytest from haystack.dataclasses import Document -from haystack.document_stores.types import FilterPolicy from haystack.utils.auth import EnvVarSecret from haystack_integrations.components.retrievers.mongodb_atlas import MongoDBAtlasFullTextRetriever @@ -27,40 +26,9 @@ def test_init_default(self): mock_store = Mock(spec=MongoDBAtlasDocumentStore) retriever = MongoDBAtlasFullTextRetriever(document_store=mock_store) assert retriever.document_store == mock_store - assert retriever.filters == {} assert retriever.top_k == 10 - assert retriever.filter_policy == FilterPolicy.REPLACE - retriever = MongoDBAtlasFullTextRetriever(document_store=mock_store, filter_policy="merge") - assert retriever.filter_policy == FilterPolicy.MERGE - - with pytest.raises(ValueError): - MongoDBAtlasFullTextRetriever(document_store=mock_store, filter_policy="wrong_policy") - - def test_init(self): - mock_store = Mock(spec=MongoDBAtlasDocumentStore) - retriever = MongoDBAtlasFullTextRetriever( - document_store=mock_store, - filters={"field": "meta.some_field", "operator": "==", "value": "SomeValue"}, - top_k=5, - ) - assert retriever.document_store == mock_store - assert retriever.filters == {"field": "meta.some_field", "operator": "==", "value": "SomeValue"} - assert retriever.top_k == 5 - assert retriever.filter_policy == FilterPolicy.REPLACE - - def test_init_filter_policy_merge(self): - mock_store = Mock(spec=MongoDBAtlasDocumentStore) - retriever = MongoDBAtlasFullTextRetriever( - document_store=mock_store, - filters={"field": "meta.some_field", "operator": "==", "value": "SomeValue"}, - top_k=5, - filter_policy=FilterPolicy.MERGE, - ) - assert retriever.document_store == mock_store - assert retriever.filters == {"field": "meta.some_field", "operator": "==", "value": "SomeValue"} - assert retriever.top_k == 5 - assert retriever.filter_policy == FilterPolicy.MERGE + retriever = MongoDBAtlasFullTextRetriever(document_store=mock_store) def test_to_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client appears unused but is required monkeypatch.setenv("MONGO_CONNECTION_STRING", "test_conn_str") @@ -71,7 +39,7 @@ def test_to_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client a vector_search_index="default", ) - retriever = MongoDBAtlasFullTextRetriever(document_store=document_store, filters={"field": "value"}, top_k=5) + retriever = MongoDBAtlasFullTextRetriever(document_store=document_store, top_k=5) res = retriever.to_dict() assert res == { "type": "haystack_integrations.components.retrievers.mongodb_atlas.fulltext_retriever.MongoDBAtlasFullTextRetriever", # noqa: E501 @@ -89,9 +57,7 @@ def test_to_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client a "vector_search_index": "default", }, }, - "filters": {"field": "value"}, "top_k": 5, - "filter_policy": "replace", }, } @@ -114,9 +80,7 @@ def test_from_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client "vector_search_index": "default", }, }, - "filters": {"field": "value"}, "top_k": 5, - "filter_policy": "replace", }, } @@ -128,9 +92,7 @@ def test_from_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client assert document_store.database_name == "haystack_integration_test" assert document_store.collection_name == "test_collection" assert document_store.vector_search_index == "default" - assert retriever.filters == {"field": "value"} assert retriever.top_k == 5 - assert retriever.filter_policy == FilterPolicy.REPLACE def test_run(self): mock_store = Mock(spec=MongoDBAtlasDocumentStore) @@ -140,26 +102,6 @@ def test_run(self): retriever = MongoDBAtlasFullTextRetriever(document_store=mock_store, search_path="desc") res = retriever.run(query="text") - mock_store._fulltext_retrieval.assert_called_once_with(query="text", filters={}, top_k=10, search_path="desc") - - assert res == {"documents": [doc]} - - def test_run_merge_policy_filter(self): - mock_store = Mock(spec=MongoDBAtlasDocumentStore) - doc = Document(content="Test doc") - mock_store._fulltext_retrieval.return_value = [doc] - - retriever = MongoDBAtlasFullTextRetriever( - document_store=mock_store, - filters={"field": "meta.some_field", "operator": "==", "value": "SomeValue"}, - filter_policy=FilterPolicy.MERGE, - ) - res = retriever.run(query="text", filters={"field": "meta.some_field", "operator": "==", "value": "Test"}) - mock_store._fulltext_retrieval.assert_called_once_with( - query="text", - filters={"field": "meta.some_field", "operator": "==", "value": "Test"}, - top_k=10, - search_path="content", - ) + mock_store._fulltext_retrieval.assert_called_once_with(query="text", top_k=10, search_path="desc") assert res == {"documents": [doc]} diff --git a/integrations/mongodb_atlas/tests/test_fulltext_retrieval.py b/integrations/mongodb_atlas/tests/test_fulltext_retrieval.py new file mode 100644 index 000000000..757e16f46 --- /dev/null +++ b/integrations/mongodb_atlas/tests/test_fulltext_retrieval.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os + +import pytest + +from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore + + +@pytest.mark.skipif( + "MONGO_CONNECTION_STRING" not in os.environ, + reason="No MongoDB Atlas connection string provided", +) +@pytest.mark.integration +class TestEmbeddingRetrieval: + def test_basic_fulltext_retrieval(self): + document_store = MongoDBAtlasDocumentStore( + database_name="haystack_integration_test", + collection_name="test_fulltext_collection", + vector_search_index="default", + ) + query = "crime" + results = document_store._fulltext_retrieval(query=query) + assert len(results) == 1 + + def test_fulltext_retrieval_custom_path(self): + document_store = MongoDBAtlasDocumentStore( + database_name="haystack_integration_test", + collection_name="test_fulltext_collection", + vector_search_index="default", + ) + query = "Godfather" + path = "title" + results = document_store._fulltext_retrieval(query=query, search_path=path) + assert len(results) == 1 + + def test_fulltext_retrieval_multi_paths_and_top_k(self): + document_store = MongoDBAtlasDocumentStore( + database_name="haystack_integration_test", + collection_name="test_fulltext_collection", + vector_search_index="default", + ) + query = "movie" + paths = ["title", "content"] + results = document_store._fulltext_retrieval(query=query, search_path=paths) + assert len(results) == 2 + + results = document_store._fulltext_retrieval(query=query, search_path=paths, top_k=1) + assert len(results) == 1 + + +""" +[ + { + "title": "The Matrix", + "content": "A hacker discovers that his reality is a simulation in this movie.", + "meta": { + "author": "Wachowskis", + "city": "San Francisco" + } + }, + { + "title": "Inception", + "content": "A thief who steals corporate secrets through the use of dream-sharing technology.", + "meta": { + "author": "Christopher Nolan", + "city": "Los Angeles" + } + }, + { + "title": "Interstellar", + "content": "A team of explorers travel through a wormhole in space in an attempt + to ensure humanity's survival.", + "meta": { + "author": "Christopher Nolan", + "city": "Houston" + } + }, + { + "title": "The Dark Knight", + "content": "When the menace known as the Joker emerges from his mysterious past, + he wreaks havoc on Gotham.", + "meta": { + "author": "Christopher Nolan", + "city": "Gotham" + } + }, + { + "title": "The Godfather Movie", + "content": "The aging patriarch of an organized crime dynasty transfers + control of his empire to his reluctant son.", + "meta": { + "author": "Mario Puzo", + "city": "New York" + } + } +] + +"""