diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py new file mode 100644 index 000000000..3372638b5 --- /dev/null +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py @@ -0,0 +1,94 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, Dict, List, Literal, Optional + +from haystack import component, default_from_dict, default_to_dict +from haystack.dataclasses import Document +from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore +from haystack_integrations.document_stores.mongodb_atlas.document_store import METRIC_TYPES + + +@component +class MongoDBAtlasEmbeddingRetriever: + """ + Retrieves documents from the MongoDBAtlasDocumentStore by embedding similarity. + + Needs to be connected to the MongoDBAtlasDocumentStore. + """ + + def __init__( + self, + *, + document_store: MongoDBAtlasDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + similarity: Literal["euclidean", "cosine", "dotProduct"] = "cosine", + ): + """ + Create the MongoDBAtlasDocumentStore component. + + :param document_store: An instance of MongoDBAtlasDocumentStore. + :param filters: Filters applied to the retrieved Documents. Defaults to None. + :param top_k: Maximum number of Documents to return, defaults to 10. + :param similarity: The similarity function to use when searching for similar embeddings. + Defaults to "cosine". Valid values are "cosine", "euclidean", "dotProduct". + """ + if not isinstance(document_store, MongoDBAtlasDocumentStore): + msg = "document_store must be an instance of MongoDBAtlasDocumentStore" + raise ValueError(msg) + + if similarity not in METRIC_TYPES: + msg = f"vector_function must be one of {METRIC_TYPES}" + raise ValueError(msg) + + self.document_store = document_store + self.filters = filters or {} + self.top_k = top_k + self.similarity = similarity + + def to_dict(self) -> Dict[str, Any]: + return default_to_dict( + self, + filters=self.filters, + top_k=self.top_k, + similarity=self.similarity, + document_store=self.document_store.to_dict(), + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasEmbeddingRetriever": + data["init_parameters"]["document_store"] = default_from_dict( + MongoDBAtlasDocumentStore, data["init_parameters"]["document_store"] + ) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run( + self, + query_embedding: List[float], + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None, + similarity: Optional[Literal["euclidean", "cosine", "dotProduct"]] = None, + ): + """ + Retrieve documents from the MongoDBAtlasDocumentStore, based on their embeddings. + + :param query_embedding: Embedding of the query. + :param filters: Filters applied to the retrieved Documents. + :param top_k: Maximum number of Documents to return. + :param similarity: The similarity function to use when searching for similar embeddings. + Defaults to the value provided in the constructor. Valid values are "cosine", "euclidean", "dotProduct". + :return: List of Documents similar to `query_embedding`. + """ + filters = filters or self.filters + top_k = top_k or self.top_k + similarity = similarity or self.similarity + + docs = self.document_store.embedding_retrieval( + query_embedding=query_embedding, + filters=filters, + top_k=top_k, + similarity=similarity, + ) + return {"documents": docs} diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index 68c02522e..2ecf4308d 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -5,6 +5,7 @@ import re from typing import Any, Dict, List, Optional, Union +import numpy as np from haystack import default_from_dict, default_to_dict from haystack.dataclasses.document import Document from haystack.document_stores.errors import DuplicateDocumentError @@ -18,6 +19,9 @@ logger = logging.getLogger(__name__) +METRIC_TYPES = ["euclidean", "cosine", "dotProduct"] + + class MongoDBAtlasDocumentStore: def __init__( self, @@ -25,6 +29,7 @@ def __init__( mongo_connection_string: Secret = Secret.from_env_var("MONGO_CONNECTION_STRING"), # noqa: B008 database_name: str, collection_name: str, + vector_search_index: str, recreate_collection: bool = False, ): """ @@ -38,7 +43,11 @@ def __init__( This value will be read automatically from the env var "MONGO_CONNECTION_STRING". :param database_name: Name of the database to use. :param collection_name: Name of the collection to use. - :param recreate_collection: Whether to recreate the collection when initializing the document store. + :param vector_search_index: The name of the vector search index to use for vector search operations. + Create a vector_search_index in the Atlas web UI and specify the init params of MongoDBAtlasDocumentStore. \ + See https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#std-label-avs-create-index + :param recreate_collection: Whether to recreate the collection when initializing the document store. Defaults + to False. """ if collection_name and not bool(re.match(r"^[a-zA-Z0-9\-_]+$", collection_name)): msg = f'Invalid collection name: "{collection_name}". It can only contain letters, numbers, -, or _.' @@ -49,6 +58,7 @@ def __init__( self.database_name = database_name self.collection_name = collection_name + self.vector_search_index = vector_search_index self.recreate_collection = recreate_collection self.connection: MongoClient = MongoClient( @@ -75,9 +85,10 @@ def to_dict(self) -> Dict[str, Any]: mongo_connection_string=self.mongo_connection_string.to_dict(), database_name=self.database_name, collection_name=self.collection_name, + vector_search_index=self.vector_search_index, recreate_collection=self.recreate_collection, ) - + @classmethod def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasDocumentStore": """ @@ -159,3 +170,83 @@ def delete_documents(self, document_ids: List[str]) -> None: if not document_ids: return self.collection.delete_many(filter={"id": {"$in": document_ids}}) + + def embedding_retrieval( + self, + query_embedding: np.ndarray, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + similarity: str = "cosine", + scale_score: bool = True, + ) -> List[Document]: + """ + Find the documents that are most similar to the provided `query_emb` by using a vector similarity metric. + + :param query_emb: Embedding of the query + :param filters: optional filters (see get_all_documents for description). + :param top_k: How many documents to return. + :param similarity: The similarity function to use. Currently supported: `dotProduct`, `cosine` and `euclidean`. + :param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]). + If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a + different value range will be scaled to a range of [0,1], where 1 means extremely relevant. + Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. + """ + if similarity not in METRIC_TYPES: + raise ValueError( + "MongoDB Atlas currently supports 'dotProduct', 'cosine' and 'euclidean' similarity metrics. \ + Please set 'similarity' to one of the above." + ) + + query_embedding = np.array(query_embedding).astype(np.float32) + + if similarity == "cosine": + self.normalize_embedding(query_embedding) + + pipeline = [ + { + "$vectorSearch": { + "index": self.vector_search_index, + "queryVector": query_embedding.tolist(), + "path": "embedding", + "numCandidates": 100, + "limit": top_k, + } + } + ] + + filters = haystack_filters_to_mongo(filters) + if filters is not None: + pipeline.append({"$match": filters}) + + pipeline.append({"$set": {"score": {"$meta": "vectorSearchScore"}}}) + documents = list(self.collection.aggregate(pipeline)) + + if scale_score: + for doc in documents: + doc["score"] = self.scale_to_unit_interval(doc["score"], similarity) + + documents = [self.mongo_doc_to_haystack_doc(doc) for doc in documents] + return documents + + def mongo_doc_to_haystack_doc(mongo_doc: Dict[str, Any]) -> Document: + """ + Converts the dictionary coming out of MongoDB into a Haystack document + + :param mongo_doc: A dictionary representing a document as stored in MongoDB + :return: A Haystack Document object + """ + mongo_doc.pop("_id", None) + return Document.from_dict(mongo_doc) + + def normalize_embedding(self, emb: np.ndarray) -> None: + """ + Performs L2 normalization of a 1D embeddings vector **inplace**. + """ + norm = np.sqrt(emb.dot(emb)) # faster than np.linalg.norm() + if norm != 0.0: + emb /= norm + + def scale_to_unit_interval(self, score: float, similarity: Optional[str]) -> float: + if similarity == "cosine": + return (score + 1) / 2 + return float(1 / (1 + np.exp(-score / 100))) diff --git a/integrations/mongodb_atlas/tests/test_document_store.py b/integrations/mongodb_atlas/tests/test_document_store.py index 07aeabc0b..274f90a95 100644 --- a/integrations/mongodb_atlas/tests/test_document_store.py +++ b/integrations/mongodb_atlas/tests/test_document_store.py @@ -20,9 +20,10 @@ def document_store(request): store = MongoDBAtlasDocumentStore( database_name="haystack_integration_test", collection_name=request.node.name + str(uuid4()), + vector_search_index="vector_search_index", + recreate_collection=True, ) - yield store - store.collection.drop() + return store @pytest.mark.skipif( @@ -56,6 +57,7 @@ def test_to_dict(self, _): document_store = MongoDBAtlasDocumentStore( database_name="database_name", collection_name="collection_name", + vector_search_index="vector_search_index", ) assert document_store.to_dict() == { "type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore", @@ -70,6 +72,7 @@ def test_to_dict(self, _): "database_name": "database_name", "collection_name": "collection_name", "recreate_collection": False, + "vector_search_index": "vector_search_index", }, } @@ -88,6 +91,7 @@ def test_from_dict(self, _): }, "database_name": "database_name", "collection_name": "collection_name", + "vector_search_index": "vector_search_index", "recreate_collection": True, }, } @@ -95,4 +99,5 @@ def test_from_dict(self, _): assert docstore.mongo_connection_string == Secret.from_env_var("MONGO_CONNECTION_STRING") assert docstore.database_name == "database_name" assert docstore.collection_name == "collection_name" + assert docstore.vector_search_index == "vector_search_index" assert docstore.recreate_collection diff --git a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py new file mode 100644 index 000000000..50db809ba --- /dev/null +++ b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py @@ -0,0 +1,106 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from typing import List +from uuid import uuid4 +import os + +import pytest +from haystack.dataclasses.document import Document +from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore +from numpy.random import rand + + +@pytest.fixture +def document_store(request): + store = MongoDBAtlasDocumentStore( + database_name="haystack_integration_test", + collection_name=request.node.name + str(uuid4()), + vector_search_index="vector_search_index", + recreate_collection=True, + ) + return store + + +@pytest.mark.skipif( + "MONGO_CONNECTION_STRING" not in os.environ, + reason="No MongoDB Atlas connection string provided", +) +class TestEmbeddingRetrieval: + + def test_embedding_retrieval_cosine_similarity(self, document_store: MongoDBAtlasDocumentStore): + query_embedding = [0.1] * 768 + most_similar_embedding = [0.8] * 768 + second_best_embedding = [0.8] * 700 + [0.1] * 3 + [0.2] * 65 + another_embedding = rand(768).tolist() + + docs = [ + Document(content="Most similar document (cosine sim)", embedding=most_similar_embedding), + Document(content="2nd best document (cosine sim)", embedding=second_best_embedding), + Document(content="Not very similar document (cosine sim)", embedding=another_embedding), + ] + + document_store.write_documents(docs) + + results = document_store.embedding_retrieval( + query_embedding=query_embedding, top_k=2, filters={}, similarity="cosine" + ) + assert len(results) == 2 + assert results[0].content == "Most similar document (cosine sim)" + assert results[1].content == "2nd best document (cosine sim)" + assert results[0].score > results[1].score + + def test_embedding_retrieval_dot_product(self, document_store: MongoDBAtlasDocumentStore): + query_embedding = [0.1] * 768 + most_similar_embedding = [0.8] * 768 + second_best_embedding = [0.8] * 700 + [0.1] * 3 + [0.2] * 65 + another_embedding = rand(768).tolist() + + docs = [ + Document(content="Most similar document (dot product)", embedding=most_similar_embedding), + Document(content="2nd best document (dot product)", embedding=second_best_embedding), + Document(content="Not very similar document (dot product)", embedding=another_embedding), + ] + + document_store.write_documents(docs) + + results = document_store.embedding_retrieval( + query_embedding=query_embedding, top_k=2, filters={}, similarity="dotProduct" + ) + assert len(results) == 2 + assert results[0].content == "Most similar document (dot product)" + assert results[1].content == "2nd best document (dot product)" + assert results[0].score > results[1].score + + + def test_embedding_retrieval_euclidean(self, document_store: MongoDBAtlasDocumentStore): + query_embedding = [0.1] * 768 + most_similar_embedding = [0.8] * 768 + second_best_embedding = [0.8] * 700 + [0.1] * 3 + [0.2] * 65 + another_embedding = rand(768).tolist() + + docs = [ + Document(content="Most similar document (euclidean)", embedding=most_similar_embedding), + Document(content="2nd best document (euclidean)", embedding=second_best_embedding), + Document(content="Not very similar document (euclidean)", embedding=another_embedding), + ] + + document_store.write_documents(docs) + + results = document_store.embedding_retrieval( + query_embedding=query_embedding, top_k=2, filters={}, similarity="euclidean" + ) + assert len(results) == 2 + assert results[0].content == "Most similar document (euclidean)" + assert results[1].content == "2nd best document (euclidean)" + assert results[0].score > results[1].score + + def test_empty_query_embedding(self, document_store: MongoDBAtlasDocumentStore): + query_embedding: List[float] = [] + with pytest.raises(ValueError): + document_store.embedding_retrieval(query_embedding=query_embedding) + + def test_query_embedding_wrong_dimension(self, document_store: MongoDBAtlasDocumentStore): + query_embedding = [0.1] * 4 + with pytest.raises(ValueError): + document_store.embedding_retrieval(query_embedding=query_embedding)