From ea482ffd46df0844db018a2f7b922fc3413c53e8 Mon Sep 17 00:00:00 2001
From: alperkaya
Date: Wed, 16 Oct 2024 15:40:40 +0200
Subject: [PATCH] initial version

---
Review notes: the line structure of this patch was restored. `to_dict` now
serializes `search_path`, and the docstrings were corrected. Hunk sizes and
the diffstat are updated accordingly; the post-image blob ids on the `index`
lines were not recomputed.

 .../retrievers/mongodb_atlas/__init__.py      |   3 +-
 .../mongodb_atlas/fulltext_retriever.py       | 110 ++++++++++++++++++
 .../mongodb_atlas/document_store.py           |  49 ++++++++
 3 files changed, 161 insertions(+), 1 deletion(-)
 create mode 100644 integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/fulltext_retriever.py

diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py
index fed0a4c28..b551eade8 100644
--- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py
+++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py
@@ -1,3 +1,4 @@
 from haystack_integrations.components.retrievers.mongodb_atlas.embedding_retriever import MongoDBAtlasEmbeddingRetriever
+from haystack_integrations.components.retrievers.mongodb_atlas.fulltext_retriever import MongoDBAtlasFullTextRetriever
 
-__all__ = ["MongoDBAtlasEmbeddingRetriever"]
+__all__ = ["MongoDBAtlasEmbeddingRetriever", "MongoDBAtlasFullTextRetriever"]
diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/fulltext_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/fulltext_retriever.py
new file mode 100644
index 000000000..373185f37
--- /dev/null
+++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/fulltext_retriever.py
@@ -0,0 +1,110 @@
+from typing import Any, Dict, List, Optional, Union
+
+from haystack import component, default_from_dict, default_to_dict
+from haystack.dataclasses import Document
+from haystack.document_stores.types import FilterPolicy
+from haystack.document_stores.types.filter_policy import apply_filter_policy
+
+from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore
+
+
+@component
+class MongoDBAtlasFullTextRetriever:
+    """
+    Retrieves documents from a MongoDBAtlasDocumentStore by running a full-text search on a text query.
+    """
+
+    def __init__(
+        self,
+        *,
+        document_store: MongoDBAtlasDocumentStore,
+        search_path: Union[str, List[str]] = "content",
+        filters: Optional[Dict[str, Any]] = None,
+        top_k: int = 10,
+        filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
+    ):
+        """
+        Create the MongoDBAtlasFullTextRetriever component.
+
+        :param document_store: An instance of MongoDBAtlasDocumentStore.
+        :param search_path: Field(s) to search within, e.g., "content" or ["content", "title"].
+        :param filters: Filters applied to the retrieved Documents. They are applied in a `$match` stage
+            that runs after the `$search` stage.
+        :param top_k: Maximum number of Documents to return.
+        :param filter_policy: Policy to determine how filters are applied.
+        :raises ValueError: If `document_store` is not an instance of `MongoDBAtlasDocumentStore`.
+        """
+
+        if not isinstance(document_store, MongoDBAtlasDocumentStore):
+            msg = "document_store must be an instance of MongoDBAtlasDocumentStore"
+            raise ValueError(msg)
+
+        self.document_store = document_store
+        self.filters = filters or {}
+        self.top_k = top_k
+        self.search_path = search_path
+        self.filter_policy = (
+            filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy)
+        )
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serializes the component to a dictionary.
+
+        :returns:
+            Dictionary with serialized data.
+        """
+        return default_to_dict(
+            self,
+            search_path=self.search_path,
+            filters=self.filters,
+            top_k=self.top_k,
+            filter_policy=self.filter_policy.value,
+            document_store=self.document_store.to_dict(),
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasFullTextRetriever":
+        """
+        Deserializes the component from a dictionary.
+
+        :param data:
+            Dictionary to deserialize from.
+        :returns:
+            Deserialized component.
+        """
+        data["init_parameters"]["document_store"] = MongoDBAtlasDocumentStore.from_dict(
+            data["init_parameters"]["document_store"]
+        )
+        # Pipelines serialized with old versions of the component might not
+        # have the filter_policy field.
+        if filter_policy := data["init_parameters"].get("filter_policy"):
+            data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy)
+        return default_from_dict(cls, data)
+
+    @component.output_types(documents=List[Document])
+    def run(
+        self,
+        query: str,
+        filters: Optional[Dict[str, Any]] = None,
+        top_k: Optional[int] = None,
+    ) -> Dict[str, List[Document]]:
+        """
+        Retrieve documents from the MongoDBAtlasDocumentStore, based on the provided query.
+
+        :param query: Text query.
+        :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on
+            the `filter_policy` chosen at retriever initialization. See init method docstring for more
+            details.
+        :param top_k: Maximum number of Documents to return. Overrides the value specified at initialization.
+        :returns: A dictionary with the following keys:
+            - `documents`: List of Documents matching the full-text search `query`
+        :raises ValueError: If `query` is empty.
+        """
+        filters = apply_filter_policy(self.filter_policy, self.filters, filters)
+        top_k = top_k or self.top_k
+
+        docs = self.document_store._fulltext_retrieval(
+            query=query, filters=filters, top_k=top_k, search_path=self.search_path
+        )
+        return {"documents": docs}
diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py
index 79caa15f8..3a4a240b6 100644
--- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py
+++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py
@@ -226,6 +226,55 @@ def delete_documents(self, document_ids: List[str]) -> None:
             return
         self.collection.delete_many(filter={"id": {"$in": document_ids}})
 
+    def _fulltext_retrieval(
+        self,
+        query: str,
+        search_path: Union[str, List[str]] = "content",
+        filters: Optional[Dict[str, Any]] = None,
+        top_k: int = 10,
+    ) -> List[Document]:
+        """
+        Find the documents that match the provided `query` using a MongoDB Atlas `$search` text query.
+
+        :param query: The text to search in the document store.
+        :param search_path: Field(s) to search within, e.g., "content" or ["content", "title"].
+        :param filters: Optional filters, applied in a `$match` stage after the search.
+        :param top_k: How many documents to return.
+        :returns: A list of Documents matching the full-text search query.
+        :raises ValueError: If `query` is empty.
+        :raises DocumentStoreError: If the retrieval of documents from MongoDB Atlas fails.
+        """
+        if not query:
+            msg = "query must not be empty"
+            raise ValueError(msg)
+
+        filters = _normalize_filters(filters) if filters else {}
+
+        pipeline = [
+            {
+                "$search": {
+                    # NOTE(review): this reuses the index name configured for vector search; confirm that
+                    # it refers to an index that can serve `$search` text queries.
+                    "index": self.vector_search_index,
+                    "text": {
+                        "query": query,
+                        "path": search_path,
+                    },
+                }
+            },
+            # `filters` is already an empty dict when no filters are given, which matches every document.
+            {"$match": filters},
+            {"$limit": top_k},
+            {"$project": {"_id": 0, "content": 1, "meta": 1, "score": {"$meta": "searchScore"}}},
+        ]
+        try:
+            documents = list(self.collection.aggregate(pipeline))
+        except Exception as e:
+            msg = f"Retrieval of documents from MongoDB Atlas failed: {e}"
+            raise DocumentStoreError(msg) from e
+
+        return [self._mongo_doc_to_haystack_doc(doc) for doc in documents]
+
     def _embedding_retrieval(
         self,
         query_embedding: List[float],