Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support Mongodb full text search #1140

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from haystack_integrations.components.retrievers.mongodb_atlas.embedding_retriever import MongoDBAtlasEmbeddingRetriever
from haystack_integrations.components.retrievers.mongodb_atlas.fulltext_retriever import MongoDBAtlasFullTextRetriever

__all__ = ["MongoDBAtlasEmbeddingRetriever"]
__all__ = ["MongoDBAtlasEmbeddingRetriever", "MongoDBAtlasFullTextRetriever"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from typing import Any, Dict, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import Document

from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore


@component
class MongoDBAtlasFullTextRetriever:

def __init__(
self,
*,
document_store: MongoDBAtlasDocumentStore,
search_path: Union[str, List[str]] = "content",
top_k: int = 10,
):
"""
Create the MongoDBAtlasFullTextRetriever component.

:param document_store: An instance of MongoDBAtlasDocumentStore.
:param search_path: Field(s) to search within, e.g., "content" or ["content", "title"].
:param top_k: Maximum number of Documents to return.
:raises ValueError: If `document_store` is not an instance of `MongoDBAtlasDocumentStore`.
"""

if not isinstance(document_store, MongoDBAtlasDocumentStore):
msg = "document_store must be an instance of MongoDBAtlasDocumentStore"
raise ValueError(msg)

self.document_store = document_store
self.top_k = top_k
self.search_path = search_path

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.

:returns:
Dictionary with serialized data.
"""
return default_to_dict(
self,
top_k=self.top_k,
document_store=self.document_store.to_dict(),
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasFullTextRetriever":
"""
Deserializes the component from a dictionary.

:param data:
Dictionary to deserialize from.
:returns:
Deserialized component.
"""
data["init_parameters"]["document_store"] = MongoDBAtlasDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(
self,
query: str,
top_k: Optional[int] = None,
) -> Dict[str, List[Document]]:
"""
Retrieve documents from the MongoDBAtlasDocumentStore, based on the provided query.

:param query: Text query.
:param top_k: Maximum number of Documents to return. Overrides the value specified at initialization.
:returns: A dictionary with the following keys:
- `documents`: List of Documents most similar to the given `query`
"""
top_k = top_k or self.top_k

docs = self.document_store._fulltext_retrieval(query=query, top_k=top_k, search_path=self.search_path)
return {"documents": docs}
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,48 @@ def delete_documents(self, document_ids: List[str]) -> None:
return
self.collection.delete_many(filter={"id": {"$in": document_ids}})

def _fulltext_retrieval(
self,
query: str,
search_path: Union[str, List[str]] = "content",
top_k: int = 10,
) -> List[Document]:
"""
Find the documents that are exact match provided `query`.

:param query: The text to search in the document store.
:param search_path: Field(s) to search within, e.g., "content" or ["content", "title"].
:param top_k: How many documents to return.
:returns: A list of Documents matching the full-text search query.
:raises ValueError: If `query` is empty.
:raises DocumentStoreError: If the retrieval of documents from MongoDB Atlas fails.
"""
if not query:
msg = "query must not be empty"
raise ValueError(msg)

pipeline = [
{
"$search": {
"index": self.vector_search_index,
"text": {
"query": query,
"path": search_path,
},
}
},
{"$limit": top_k},
{"$project": {"_id": 0, "content": 1, "meta": 1, "score": {"$meta": "searchScore"}}},
]
try:
documents = list(self.collection.aggregate(pipeline))
except Exception as e:
msg = f"Retrieval of documents from MongoDB Atlas failed: {e}"
raise DocumentStoreError(msg) from e

documents = [self._mongo_doc_to_haystack_doc(doc) for doc in documents]
return documents

def _embedding_retrieval(
self,
query_embedding: List[float],
Expand Down
107 changes: 107 additions & 0 deletions integrations/mongodb_atlas/tests/test_full_text_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from unittest.mock import MagicMock, Mock, patch

import pytest
from haystack.dataclasses import Document
from haystack.utils.auth import EnvVarSecret

from haystack_integrations.components.retrievers.mongodb_atlas import MongoDBAtlasFullTextRetriever
from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore


class TestFullTextRetriever:
@pytest.fixture
def mock_client(self):
with patch(
"haystack_integrations.document_stores.mongodb_atlas.document_store.MongoClient"
) as mock_mongo_client:
mock_connection = MagicMock()
mock_database = MagicMock()
mock_collection_names = MagicMock(return_value=["test_collection"])
mock_database.list_collection_names = mock_collection_names
mock_connection.__getitem__.return_value = mock_database
mock_mongo_client.return_value = mock_connection
yield mock_mongo_client

def test_init_default(self):
mock_store = Mock(spec=MongoDBAtlasDocumentStore)
retriever = MongoDBAtlasFullTextRetriever(document_store=mock_store)
assert retriever.document_store == mock_store
assert retriever.top_k == 10

retriever = MongoDBAtlasFullTextRetriever(document_store=mock_store)

def test_to_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client appears unused but is required
monkeypatch.setenv("MONGO_CONNECTION_STRING", "test_conn_str")

document_store = MongoDBAtlasDocumentStore(
database_name="haystack_integration_test",
collection_name="test_collection",
vector_search_index="default",
)

retriever = MongoDBAtlasFullTextRetriever(document_store=document_store, top_k=5)
res = retriever.to_dict()
assert res == {
"type": "haystack_integrations.components.retrievers.mongodb_atlas.fulltext_retriever.MongoDBAtlasFullTextRetriever", # noqa: E501
"init_parameters": {
"document_store": {
"type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore", # noqa: E501
"init_parameters": {
"mongo_connection_string": {
"env_vars": ["MONGO_CONNECTION_STRING"],
"strict": True,
"type": "env_var",
},
"database_name": "haystack_integration_test",
"collection_name": "test_collection",
"vector_search_index": "default",
},
},
"top_k": 5,
},
}

def test_from_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client appears unused but is required
monkeypatch.setenv("MONGO_CONNECTION_STRING", "test_conn_str")

data = {
"type": "haystack_integrations.components.retrievers.mongodb_atlas.fulltext_retriever.MongoDBAtlasFullTextRetriever", # noqa: E501
"init_parameters": {
"document_store": {
"type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore", # noqa: E501
"init_parameters": {
"mongo_connection_string": {
"env_vars": ["MONGO_CONNECTION_STRING"],
"strict": True,
"type": "env_var",
},
"database_name": "haystack_integration_test",
"collection_name": "test_collection",
"vector_search_index": "default",
},
},
"top_k": 5,
},
}

retriever = MongoDBAtlasFullTextRetriever.from_dict(data)
document_store = retriever.document_store

assert isinstance(document_store, MongoDBAtlasDocumentStore)
assert isinstance(document_store.mongo_connection_string, EnvVarSecret)
assert document_store.database_name == "haystack_integration_test"
assert document_store.collection_name == "test_collection"
assert document_store.vector_search_index == "default"
assert retriever.top_k == 5

def test_run(self):
mock_store = Mock(spec=MongoDBAtlasDocumentStore)
doc = Document(content="Test doc")
mock_store._fulltext_retrieval.return_value = [doc]

retriever = MongoDBAtlasFullTextRetriever(document_store=mock_store, search_path="desc")
res = retriever.run(query="text")

mock_store._fulltext_retrieval.assert_called_once_with(query="text", top_k=10, search_path="desc")

assert res == {"documents": [doc]}
100 changes: 100 additions & 0 deletions integrations/mongodb_atlas/tests/test_fulltext_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import os

import pytest

from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore


@pytest.mark.skipif(
"MONGO_CONNECTION_STRING" not in os.environ,
reason="No MongoDB Atlas connection string provided",
)
@pytest.mark.integration
class TestEmbeddingRetrieval:
def test_basic_fulltext_retrieval(self):
document_store = MongoDBAtlasDocumentStore(
database_name="haystack_integration_test",
collection_name="test_fulltext_collection",
vector_search_index="default",
)
query = "crime"
results = document_store._fulltext_retrieval(query=query)
assert len(results) == 1

def test_fulltext_retrieval_custom_path(self):
document_store = MongoDBAtlasDocumentStore(
database_name="haystack_integration_test",
collection_name="test_fulltext_collection",
vector_search_index="default",
)
query = "Godfather"
path = "title"
results = document_store._fulltext_retrieval(query=query, search_path=path)
assert len(results) == 1

def test_fulltext_retrieval_multi_paths_and_top_k(self):
document_store = MongoDBAtlasDocumentStore(
database_name="haystack_integration_test",
collection_name="test_fulltext_collection",
vector_search_index="default",
)
query = "movie"
paths = ["title", "content"]
results = document_store._fulltext_retrieval(query=query, search_path=paths)
assert len(results) == 2

results = document_store._fulltext_retrieval(query=query, search_path=paths, top_k=1)
assert len(results) == 1


"""
[
{
"title": "The Matrix",
"content": "A hacker discovers that his reality is a simulation in this movie.",
"meta": {
"author": "Wachowskis",
"city": "San Francisco"
}
},
{
"title": "Inception",
"content": "A thief who steals corporate secrets through the use of dream-sharing technology.",
"meta": {
"author": "Christopher Nolan",
"city": "Los Angeles"
}
},
{
"title": "Interstellar",
"content": "A team of explorers travel through a wormhole in space in an attempt
to ensure humanity's survival.",
"meta": {
"author": "Christopher Nolan",
"city": "Houston"
}
},
{
"title": "The Dark Knight",
"content": "When the menace known as the Joker emerges from his mysterious past,
he wreaks havoc on Gotham.",
"meta": {
"author": "Christopher Nolan",
"city": "Gotham"
}
},
{
"title": "The Godfather Movie",
"content": "The aging patriarch of an organized crime dynasty transfers
control of his empire to his reluctant son.",
"meta": {
"author": "Mario Puzo",
"city": "New York"
}
}
]

"""
Loading