Skip to content

Commit

Permalink
add new testcases and remove filter
Browse files Browse the repository at this point in the history
  • Loading branch information
alperkaya committed Oct 17, 2024
1 parent fb68881 commit e749bdd
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import Document
from haystack.document_stores.types import FilterPolicy
from haystack.document_stores.types.filter_policy import apply_filter_policy

from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore

Expand All @@ -16,20 +14,14 @@ def __init__(
*,
document_store: MongoDBAtlasDocumentStore,
search_path: Union[str, List[str]] = "content",
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
):
"""
Create the MongoDBAtlasFullTextRetriever component.
:param document_store: An instance of MongoDBAtlasDocumentStore.
:param search_path: Field(s) to search within, e.g., "content" or ["content", "title"].
:param filters: Filters applied to the retrieved Documents. Make sure that the fields used in the filters are
included in the configuration of the `vector_search_index`. The configuration must be done manually
in the Web UI of MongoDB Atlas.
:param top_k: Maximum number of Documents to return.
:param filter_policy: Policy to determine how filters are applied.
:raises ValueError: If `document_store` is not an instance of `MongoDBAtlasDocumentStore`.
"""

Expand All @@ -38,12 +30,8 @@ def __init__(
raise ValueError(msg)

self.document_store = document_store
self.filters = filters or {}
self.top_k = top_k
self.search_path = search_path
self.filter_policy = (
filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy)
)

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -54,9 +42,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
filters=self.filters,
top_k=self.top_k,
filter_policy=self.filter_policy.value,
document_store=self.document_store.to_dict(),
)

Expand All @@ -73,34 +59,23 @@ def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasFullTextRetriever":
data["init_parameters"]["document_store"] = MongoDBAtlasDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)
# Pipelines serialized with old versions of the component might not
# have the filter_policy field.
if filter_policy := data["init_parameters"].get("filter_policy"):
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy)
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(
self,
query: str,
filters: Optional[Dict[str, Any]] = None,
top_k: Optional[int] = None,
) -> Dict[str, List[Document]]:
"""
Retrieve documents from the MongoDBAtlasDocumentStore, based on the provided query.
:param query: Text query.
:param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on
the `filter_policy` chosen at retriever initialization. See init method docstring for more
details.
:param top_k: Maximum number of Documents to return. Overrides the value specified at initialization.
:returns: A dictionary with the following keys:
- `documents`: List of Documents most similar to the given `query`
"""
filters = apply_filter_policy(self.filter_policy, self.filters, filters)
top_k = top_k or self.top_k

docs = self.document_store._fulltext_retrieval(
query=query, filters=filters, top_k=top_k, search_path=self.search_path
)
docs = self.document_store._fulltext_retrieval(query=query, top_k=top_k, search_path=self.search_path)
return {"documents": docs}
Original file line number Diff line number Diff line change
Expand Up @@ -230,15 +230,13 @@ def _fulltext_retrieval(
self,
query: str,
search_path: Union[str, List[str]] = "content",
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
) -> List[Document]:
"""
Find the documents that are exact match provided `query`.
:param query: The text to search in the document store.
:param search_path: Field(s) to search within, e.g., "content" or ["content", "title"].
:param filters: Optional filters.
:param top_k: How many documents to return.
:returns: A list of Documents matching the full-text search query.
:raises ValueError: If `query` is empty.
Expand All @@ -248,8 +246,6 @@ def _fulltext_retrieval(
msg = "query must not be empty"
raise ValueError(msg)

filters = _normalize_filters(filters) if filters else {}

pipeline = [
{
"$search": {
Expand All @@ -260,7 +256,6 @@ def _fulltext_retrieval(
},
}
},
{"$match": filters if filters else {}},
{"$limit": top_k},
{"$project": {"_id": 0, "content": 1, "meta": 1, "score": {"$meta": "searchScore"}}},
]
Expand Down
64 changes: 3 additions & 61 deletions integrations/mongodb_atlas/tests/test_full_text_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import pytest
from haystack.dataclasses import Document
from haystack.document_stores.types import FilterPolicy
from haystack.utils.auth import EnvVarSecret

from haystack_integrations.components.retrievers.mongodb_atlas import MongoDBAtlasFullTextRetriever
Expand All @@ -27,40 +26,9 @@ def test_init_default(self):
mock_store = Mock(spec=MongoDBAtlasDocumentStore)
retriever = MongoDBAtlasFullTextRetriever(document_store=mock_store)
assert retriever.document_store == mock_store
assert retriever.filters == {}
assert retriever.top_k == 10
assert retriever.filter_policy == FilterPolicy.REPLACE

retriever = MongoDBAtlasFullTextRetriever(document_store=mock_store, filter_policy="merge")
assert retriever.filter_policy == FilterPolicy.MERGE

with pytest.raises(ValueError):
MongoDBAtlasFullTextRetriever(document_store=mock_store, filter_policy="wrong_policy")

def test_init(self):
mock_store = Mock(spec=MongoDBAtlasDocumentStore)
retriever = MongoDBAtlasFullTextRetriever(
document_store=mock_store,
filters={"field": "meta.some_field", "operator": "==", "value": "SomeValue"},
top_k=5,
)
assert retriever.document_store == mock_store
assert retriever.filters == {"field": "meta.some_field", "operator": "==", "value": "SomeValue"}
assert retriever.top_k == 5
assert retriever.filter_policy == FilterPolicy.REPLACE

def test_init_filter_policy_merge(self):
mock_store = Mock(spec=MongoDBAtlasDocumentStore)
retriever = MongoDBAtlasFullTextRetriever(
document_store=mock_store,
filters={"field": "meta.some_field", "operator": "==", "value": "SomeValue"},
top_k=5,
filter_policy=FilterPolicy.MERGE,
)
assert retriever.document_store == mock_store
assert retriever.filters == {"field": "meta.some_field", "operator": "==", "value": "SomeValue"}
assert retriever.top_k == 5
assert retriever.filter_policy == FilterPolicy.MERGE
retriever = MongoDBAtlasFullTextRetriever(document_store=mock_store)

def test_to_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client appears unused but is required
monkeypatch.setenv("MONGO_CONNECTION_STRING", "test_conn_str")
Expand All @@ -71,7 +39,7 @@ def test_to_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client a
vector_search_index="default",
)

retriever = MongoDBAtlasFullTextRetriever(document_store=document_store, filters={"field": "value"}, top_k=5)
retriever = MongoDBAtlasFullTextRetriever(document_store=document_store, top_k=5)
res = retriever.to_dict()
assert res == {
"type": "haystack_integrations.components.retrievers.mongodb_atlas.fulltext_retriever.MongoDBAtlasFullTextRetriever", # noqa: E501
Expand All @@ -89,9 +57,7 @@ def test_to_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client a
"vector_search_index": "default",
},
},
"filters": {"field": "value"},
"top_k": 5,
"filter_policy": "replace",
},
}

Expand All @@ -114,9 +80,7 @@ def test_from_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client
"vector_search_index": "default",
},
},
"filters": {"field": "value"},
"top_k": 5,
"filter_policy": "replace",
},
}

Expand All @@ -128,9 +92,7 @@ def test_from_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client
assert document_store.database_name == "haystack_integration_test"
assert document_store.collection_name == "test_collection"
assert document_store.vector_search_index == "default"
assert retriever.filters == {"field": "value"}
assert retriever.top_k == 5
assert retriever.filter_policy == FilterPolicy.REPLACE

def test_run(self):
mock_store = Mock(spec=MongoDBAtlasDocumentStore)
Expand All @@ -140,26 +102,6 @@ def test_run(self):
retriever = MongoDBAtlasFullTextRetriever(document_store=mock_store, search_path="desc")
res = retriever.run(query="text")

mock_store._fulltext_retrieval.assert_called_once_with(query="text", filters={}, top_k=10, search_path="desc")

assert res == {"documents": [doc]}

def test_run_merge_policy_filter(self):
mock_store = Mock(spec=MongoDBAtlasDocumentStore)
doc = Document(content="Test doc")
mock_store._fulltext_retrieval.return_value = [doc]

retriever = MongoDBAtlasFullTextRetriever(
document_store=mock_store,
filters={"field": "meta.some_field", "operator": "==", "value": "SomeValue"},
filter_policy=FilterPolicy.MERGE,
)
res = retriever.run(query="text", filters={"field": "meta.some_field", "operator": "==", "value": "Test"})
mock_store._fulltext_retrieval.assert_called_once_with(
query="text",
filters={"field": "meta.some_field", "operator": "==", "value": "Test"},
top_k=10,
search_path="content",
)
mock_store._fulltext_retrieval.assert_called_once_with(query="text", top_k=10, search_path="desc")

assert res == {"documents": [doc]}
100 changes: 100 additions & 0 deletions integrations/mongodb_atlas/tests/test_fulltext_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import os

import pytest

from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore


@pytest.mark.skipif(
"MONGO_CONNECTION_STRING" not in os.environ,
reason="No MongoDB Atlas connection string provided",
)
@pytest.mark.integration
class TestEmbeddingRetrieval:
def test_basic_fulltext_retrieval(self):
document_store = MongoDBAtlasDocumentStore(
database_name="haystack_integration_test",
collection_name="test_fulltext_collection",
vector_search_index="default",
)
query = "crime"
results = document_store._fulltext_retrieval(query=query)
assert len(results) == 1

def test_fulltext_retrieval_custom_path(self):
document_store = MongoDBAtlasDocumentStore(
database_name="haystack_integration_test",
collection_name="test_fulltext_collection",
vector_search_index="default",
)
query = "Godfather"
path = "title"
results = document_store._fulltext_retrieval(query=query, search_path=path)
assert len(results) == 1

def test_fulltext_retrieval_multi_paths_and_top_k(self):
document_store = MongoDBAtlasDocumentStore(
database_name="haystack_integration_test",
collection_name="test_fulltext_collection",
vector_search_index="default",
)
query = "movie"
paths = ["title", "content"]
results = document_store._fulltext_retrieval(query=query, search_path=paths)
assert len(results) == 2

results = document_store._fulltext_retrieval(query=query, search_path=paths, top_k=1)
assert len(results) == 1


"""
[
{
"title": "The Matrix",
"content": "A hacker discovers that his reality is a simulation in this movie.",
"meta": {
"author": "Wachowskis",
"city": "San Francisco"
}
},
{
"title": "Inception",
"content": "A thief who steals corporate secrets through the use of dream-sharing technology.",
"meta": {
"author": "Christopher Nolan",
"city": "Los Angeles"
}
},
{
"title": "Interstellar",
"content": "A team of explorers travel through a wormhole in space in an attempt
to ensure humanity's survival.",
"meta": {
"author": "Christopher Nolan",
"city": "Houston"
}
},
{
"title": "The Dark Knight",
"content": "When the menace known as the Joker emerges from his mysterious past,
he wreaks havoc on Gotham.",
"meta": {
"author": "Christopher Nolan",
"city": "Gotham"
}
},
{
"title": "The Godfather Movie",
"content": "The aging patriarch of an organized crime dynasty transfers
control of his empire to his reluctant son.",
"meta": {
"author": "Mario Puzo",
"city": "New York"
}
}
]
"""

0 comments on commit e749bdd

Please sign in to comment.