From 7c80cbc139dbf258f4059930a3286ec21e3d08d3 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Thu, 15 Feb 2024 17:54:01 +0100 Subject: [PATCH 01/32] initial implementation --- .../mongodb_atlas/embedding_retriever.py | 94 ++++++++++++++++ .../mongodb_atlas/document_store.py | 95 +++++++++++++++- .../tests/test_document_store.py | 9 +- .../tests/test_embedding_retrieval.py | 106 ++++++++++++++++++ 4 files changed, 300 insertions(+), 4 deletions(-) create mode 100644 integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py create mode 100644 integrations/mongodb_atlas/tests/test_embedding_retrieval.py diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py new file mode 100644 index 000000000..3372638b5 --- /dev/null +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py @@ -0,0 +1,94 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, Dict, List, Literal, Optional + +from haystack import component, default_from_dict, default_to_dict +from haystack.dataclasses import Document +from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore +from haystack_integrations.document_stores.mongodb_atlas.document_store import METRIC_TYPES + + +@component +class MongoDBAtlasEmbeddingRetriever: + """ + Retrieves documents from the MongoDBAtlasDocumentStore by embedding similarity. + + Needs to be connected to the MongoDBAtlasDocumentStore. + """ + + def __init__( + self, + *, + document_store: MongoDBAtlasDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + similarity: Literal["euclidean", "cosine", "dotProduct"] = "cosine", + ): + """ + Create the MongoDBAtlasDocumentStore component. + + :param document_store: An instance of MongoDBAtlasDocumentStore. + :param filters: Filters applied to the retrieved Documents. Defaults to None. + :param top_k: Maximum number of Documents to return, defaults to 10. + :param similarity: The similarity function to use when searching for similar embeddings. + Defaults to "cosine". Valid values are "cosine", "euclidean", "dotProduct". + """ + if not isinstance(document_store, MongoDBAtlasDocumentStore): + msg = "document_store must be an instance of MongoDBAtlasDocumentStore" + raise ValueError(msg) + + if similarity not in METRIC_TYPES: + msg = f"vector_function must be one of {METRIC_TYPES}" + raise ValueError(msg) + + self.document_store = document_store + self.filters = filters or {} + self.top_k = top_k + self.similarity = similarity + + def to_dict(self) -> Dict[str, Any]: + return default_to_dict( + self, + filters=self.filters, + top_k=self.top_k, + similarity=self.similarity, + document_store=self.document_store.to_dict(), + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasEmbeddingRetriever": + data["init_parameters"]["document_store"] = default_from_dict( + MongoDBAtlasDocumentStore, data["init_parameters"]["document_store"] + ) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run( + self, + query_embedding: List[float], + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None, + similarity: Optional[Literal["euclidean", "cosine", "dotProduct"]] = None, + ): + """ + Retrieve documents from the MongoDBAtlasDocumentStore, based on their embeddings. + + :param query_embedding: Embedding of the query. + :param filters: Filters applied to the retrieved Documents. + :param top_k: Maximum number of Documents to return. + :param similarity: The similarity function to use when searching for similar embeddings. + Defaults to the value provided in the constructor. Valid values are "cosine", "euclidean", "dotProduct". + :return: List of Documents similar to `query_embedding`. + """ + filters = filters or self.filters + top_k = top_k or self.top_k + similarity = similarity or self.similarity + + docs = self.document_store.embedding_retrieval( + query_embedding=query_embedding, + filters=filters, + top_k=top_k, + similarity=similarity, + ) + return {"documents": docs} diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index 68c02522e..2ecf4308d 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -5,6 +5,7 @@ import re from typing import Any, Dict, List, Optional, Union +import numpy as np from haystack import default_from_dict, default_to_dict from haystack.dataclasses.document import Document from haystack.document_stores.errors import DuplicateDocumentError @@ -18,6 +19,9 @@ logger = logging.getLogger(__name__) +METRIC_TYPES = ["euclidean", "cosine", "dotProduct"] + + class MongoDBAtlasDocumentStore: def __init__( self, @@ -25,6 +29,7 @@ def __init__( mongo_connection_string: Secret = Secret.from_env_var("MONGO_CONNECTION_STRING"), # noqa: B008 database_name: str, collection_name: str, + vector_search_index: str, recreate_collection: bool = False, ): """ @@ -38,7 +43,11 @@ def __init__( This value will be read automatically from the env var "MONGO_CONNECTION_STRING". :param database_name: Name of the database to use. :param collection_name: Name of the collection to use. - :param recreate_collection: Whether to recreate the collection when initializing the document store. + :param vector_search_index: The name of the vector search index to use for vector search operations. + Create a vector_search_index in the Atlas web UI and specify the init params of MongoDBAtlasDocumentStore. \ + See https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#std-label-avs-create-index + :param recreate_collection: Whether to recreate the collection when initializing the document store. Defaults + to False. """ if collection_name and not bool(re.match(r"^[a-zA-Z0-9\-_]+$", collection_name)): msg = f'Invalid collection name: "{collection_name}". It can only contain letters, numbers, -, or _.' @@ -49,6 +58,7 @@ def __init__( self.database_name = database_name self.collection_name = collection_name + self.vector_search_index = vector_search_index self.recreate_collection = recreate_collection self.connection: MongoClient = MongoClient( @@ -75,9 +85,10 @@ def to_dict(self) -> Dict[str, Any]: mongo_connection_string=self.mongo_connection_string.to_dict(), database_name=self.database_name, collection_name=self.collection_name, + vector_search_index=self.vector_search_index, recreate_collection=self.recreate_collection, ) - + @classmethod def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasDocumentStore": """ @@ -159,3 +170,83 @@ def delete_documents(self, document_ids: List[str]) -> None: if not document_ids: return self.collection.delete_many(filter={"id": {"$in": document_ids}}) + + def embedding_retrieval( + self, + query_embedding: np.ndarray, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + similarity: str = "cosine", + scale_score: bool = True, + ) -> List[Document]: + """ + Find the documents that are most similar to the provided `query_emb` by using a vector similarity metric. + + :param query_emb: Embedding of the query + :param filters: optional filters (see get_all_documents for description). + :param top_k: How many documents to return. + :param similarity: The similarity function to use. Currently supported: `dotProduct`, `cosine` and `euclidean`. + :param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]). + If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a + different value range will be scaled to a range of [0,1], where 1 means extremely relevant. + Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. + """ + if similarity not in METRIC_TYPES: + raise ValueError( + "MongoDB Atlas currently supports 'dotProduct', 'cosine' and 'euclidean' similarity metrics. \ + Please set 'similarity' to one of the above." + ) + + query_embedding = np.array(query_embedding).astype(np.float32) + + if similarity == "cosine": + self.normalize_embedding(query_embedding) + + pipeline = [ + { + "$vectorSearch": { + "index": self.vector_search_index, + "queryVector": query_embedding.tolist(), + "path": "embedding", + "numCandidates": 100, + "limit": top_k, + } + } + ] + + filters = haystack_filters_to_mongo(filters) + if filters is not None: + pipeline.append({"$match": filters}) + + pipeline.append({"$set": {"score": {"$meta": "vectorSearchScore"}}}) + documents = list(self.collection.aggregate(pipeline)) + + if scale_score: + for doc in documents: + doc["score"] = self.scale_to_unit_interval(doc["score"], similarity) + + documents = [self.mongo_doc_to_haystack_doc(doc) for doc in documents] + return documents + + def mongo_doc_to_haystack_doc(mongo_doc: Dict[str, Any]) -> Document: + """ + Converts the dictionary coming out of MongoDB into a Haystack document + + :param mongo_doc: A dictionary representing a document as stored in MongoDB + :return: A Haystack Document object + """ + mongo_doc.pop("_id", None) + return Document.from_dict(mongo_doc) + + def normalize_embedding(self, emb: np.ndarray) -> None: + """ + Performs L2 normalization of a 1D embeddings vector **inplace**. + """ + norm = np.sqrt(emb.dot(emb)) # faster than np.linalg.norm() + if norm != 0.0: + emb /= norm + + def scale_to_unit_interval(self, score: float, similarity: Optional[str]) -> float: + if similarity == "cosine": + return (score + 1) / 2 + return float(1 / (1 + np.exp(-score / 100))) diff --git a/integrations/mongodb_atlas/tests/test_document_store.py b/integrations/mongodb_atlas/tests/test_document_store.py index 07aeabc0b..274f90a95 100644 --- a/integrations/mongodb_atlas/tests/test_document_store.py +++ b/integrations/mongodb_atlas/tests/test_document_store.py @@ -20,9 +20,10 @@ def document_store(request): store = MongoDBAtlasDocumentStore( database_name="haystack_integration_test", collection_name=request.node.name + str(uuid4()), + vector_search_index="vector_search_index", + recreate_collection=True, ) - yield store - store.collection.drop() + return store @pytest.mark.skipif( @@ -56,6 +57,7 @@ def test_to_dict(self, _): document_store = MongoDBAtlasDocumentStore( database_name="database_name", collection_name="collection_name", + vector_search_index="vector_search_index", ) assert document_store.to_dict() == { "type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore", @@ -70,6 +72,7 @@ def test_to_dict(self, _): "database_name": "database_name", "collection_name": "collection_name", "recreate_collection": False, + "vector_search_index": "vector_search_index", }, } @@ -88,6 +91,7 @@ def test_from_dict(self, _): }, "database_name": "database_name", "collection_name": "collection_name", + "vector_search_index": "vector_search_index", "recreate_collection": True, }, } @@ -95,4 +99,5 @@ def test_from_dict(self, _): assert docstore.mongo_connection_string == Secret.from_env_var("MONGO_CONNECTION_STRING") assert docstore.database_name == "database_name" assert docstore.collection_name == "collection_name" + assert docstore.vector_search_index == "vector_search_index" assert docstore.recreate_collection diff --git a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py new file mode 100644 index 000000000..50db809ba --- /dev/null +++ b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py @@ -0,0 +1,106 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from typing import List +from uuid import uuid4 +import os + +import pytest +from haystack.dataclasses.document import Document +from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore +from numpy.random import rand + + +@pytest.fixture +def document_store(request): + store = MongoDBAtlasDocumentStore( + database_name="haystack_integration_test", + collection_name=request.node.name + str(uuid4()), + vector_search_index="vector_search_index", + recreate_collection=True, + ) + return store + + +@pytest.mark.skipif( + "MONGO_CONNECTION_STRING" not in os.environ, + reason="No MongoDB Atlas connection string provided", +) +class TestEmbeddingRetrieval: + + def test_embedding_retrieval_cosine_similarity(self, document_store: MongoDBAtlasDocumentStore): + query_embedding = [0.1] * 768 + most_similar_embedding = [0.8] * 768 + second_best_embedding = [0.8] * 700 + [0.1] * 3 + [0.2] * 65 + another_embedding = rand(768).tolist() + + docs = [ + Document(content="Most similar document (cosine sim)", embedding=most_similar_embedding), + Document(content="2nd best document (cosine sim)", embedding=second_best_embedding), + Document(content="Not very similar document (cosine sim)", embedding=another_embedding), + ] + + document_store.write_documents(docs) + + results = document_store.embedding_retrieval( + query_embedding=query_embedding, top_k=2, filters={}, similarity="cosine" + ) + assert len(results) == 2 + assert results[0].content == "Most similar document (cosine sim)" + assert results[1].content == "2nd best document (cosine sim)" + assert results[0].score > results[1].score + + def test_embedding_retrieval_dot_product(self, document_store: MongoDBAtlasDocumentStore): + query_embedding = [0.1] * 768 + most_similar_embedding = [0.8] * 768 + second_best_embedding = [0.8] * 700 + [0.1] * 3 + [0.2] * 65 + another_embedding = rand(768).tolist() + + docs = [ + Document(content="Most similar document (dot product)", embedding=most_similar_embedding), + Document(content="2nd best document (dot product)", embedding=second_best_embedding), + Document(content="Not very similar document (dot product)", embedding=another_embedding), + ] + + document_store.write_documents(docs) + + results = document_store.embedding_retrieval( + query_embedding=query_embedding, top_k=2, filters={}, similarity="dotProduct" + ) + assert len(results) == 2 + assert results[0].content == "Most similar document (dot product)" + assert results[1].content == "2nd best document (dot product)" + assert results[0].score > results[1].score + + + def test_embedding_retrieval_euclidean(self, document_store: MongoDBAtlasDocumentStore): + query_embedding = [0.1] * 768 + most_similar_embedding = [0.8] * 768 + second_best_embedding = [0.8] * 700 + [0.1] * 3 + [0.2] * 65 + another_embedding = rand(768).tolist() + + docs = [ + Document(content="Most similar document (euclidean)", embedding=most_similar_embedding), + Document(content="2nd best document (euclidean)", embedding=second_best_embedding), + Document(content="Not very similar document (euclidean)", embedding=another_embedding), + ] + + document_store.write_documents(docs) + + results = document_store.embedding_retrieval( + query_embedding=query_embedding, top_k=2, filters={}, similarity="euclidean" + ) + assert len(results) == 2 + assert results[0].content == "Most similar document (euclidean)" + assert results[1].content == "2nd best document (euclidean)" + assert results[0].score > results[1].score + + def test_empty_query_embedding(self, document_store: MongoDBAtlasDocumentStore): + query_embedding: List[float] = [] + with pytest.raises(ValueError): + document_store.embedding_retrieval(query_embedding=query_embedding) + + def test_query_embedding_wrong_dimension(self, document_store: MongoDBAtlasDocumentStore): + query_embedding = [0.1] * 4 + with pytest.raises(ValueError): + document_store.embedding_retrieval(query_embedding=query_embedding) From 13c15a4bbcdf5e54301fc3c6004f3fc6d807ce9c Mon Sep 17 00:00:00 2001 From: ZanSara Date: Fri, 16 Feb 2024 11:38:57 +0100 Subject: [PATCH 02/32] vector index seems non functional --- .../mongodb_atlas/document_store.py | 16 ++++++++++------ .../mongodb_atlas/tests/test_document_store.py | 6 +++--- .../tests/test_embedding_retrieval.py | 4 ++-- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index b0f4e80fc..316fa2bc9 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -73,6 +73,7 @@ def __init__( if collection_name not in database.list_collection_names(): database.create_collection(self.collection_name) database[self.collection_name].create_index("id", unique=True) + database[self.collection_name].create_index(self.vector_search_index) self.collection = database[self.collection_name] @@ -200,6 +201,7 @@ def embedding_retrieval( if similarity == "cosine": self.normalize_embedding(query_embedding) + filters = haystack_filters_to_mongo(filters) pipeline = [ { "$vectorSearch": { @@ -208,15 +210,17 @@ def embedding_retrieval( "path": "embedding", "numCandidates": 100, "limit": top_k, + #"filter": filters, + } + }, { + '$project': { + '_id': 0, + 'score': { + '$meta': 'vectorSearchScore' + } } } ] - - filters = haystack_filters_to_mongo(filters) - if filters is not None: - pipeline.append({"$match": filters}) - - pipeline.append({"$set": {"score": {"$meta": "vectorSearchScore"}}}) documents = list(self.collection.aggregate(pipeline)) if scale_score: diff --git a/integrations/mongodb_atlas/tests/test_document_store.py b/integrations/mongodb_atlas/tests/test_document_store.py index 274f90a95..2f1dcd5a4 100644 --- a/integrations/mongodb_atlas/tests/test_document_store.py +++ b/integrations/mongodb_atlas/tests/test_document_store.py @@ -16,11 +16,11 @@ @pytest.fixture -def document_store(request): +def document_store(): store = MongoDBAtlasDocumentStore( database_name="haystack_integration_test", - collection_name=request.node.name + str(uuid4()), - vector_search_index="vector_search_index", + collection_name="test_collection", + vector_search_index="vector_index", recreate_collection=True, ) return store diff --git a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py index 50db809ba..6b16768e0 100644 --- a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py +++ b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py @@ -15,8 +15,8 @@ def document_store(request): store = MongoDBAtlasDocumentStore( database_name="haystack_integration_test", - collection_name=request.node.name + str(uuid4()), - vector_search_index="vector_search_index", + collection_name="test_collection", + vector_search_index="vector_index", recreate_collection=True, ) return store From cf93be16dbcd13ba2d69a01e4d05e40105a0da9b Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 21 Feb 2024 11:55:20 +0100 Subject: [PATCH 03/32] tests are green for docstore --- .../mongodb_atlas/document_store.py | 32 +++--- .../tests/test_embedding_retrieval.py | 101 +++++++----------- 2 files changed, 53 insertions(+), 80 deletions(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index 316fa2bc9..5bea56471 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -8,7 +8,7 @@ import numpy as np from haystack import default_from_dict, default_to_dict from haystack.dataclasses.document import Document -from haystack.document_stores.errors import DuplicateDocumentError +from haystack.document_stores.errors import DuplicateDocumentError, DocumentStoreError from haystack.document_stores.types import DuplicatePolicy from haystack.utils import Secret, deserialize_secrets_inplace from haystack_integrations.document_stores.mongodb_atlas.filters import haystack_filters_to_mongo @@ -30,7 +30,6 @@ def __init__( database_name: str, collection_name: str, vector_search_index: str, - recreate_collection: bool = False, ): """ Creates a new MongoDBAtlasDocumentStore instance. @@ -42,7 +41,8 @@ def __init__( This can be obtained on the MongoDB Atlas Dashboard by clicking on the `CONNECT` button. This value will be read automatically from the env var "MONGO_CONNECTION_STRING". :param database_name: Name of the database to use. - :param collection_name: Name of the collection to use. + :param collection_name: Name of the collection to use. To use this document store for embedding retrieval, + this collection needs to have a vector search index set up on the `embedding` field. :param vector_search_index: The name of the vector search index to use for vector search operations. Create a vector_search_index in the Atlas web UI and specify the init params of MongoDBAtlasDocumentStore. \ See https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#std-label-avs-create-index @@ -59,22 +59,14 @@ def __init__( self.database_name = database_name self.collection_name = collection_name self.vector_search_index = vector_search_index - self.recreate_collection = recreate_collection self.connection: MongoClient = MongoClient( resolved_connection_string, driver=DriverInfo(name="MongoDBAtlasHaystackIntegration") ) database = self.connection[self.database_name] - if self.recreate_collection and self.collection_name in database.list_collection_names(): - database[self.collection_name].drop() - - # Implicitly create the collection if it doesn't exist if collection_name not in database.list_collection_names(): - database.create_collection(self.collection_name) - database[self.collection_name].create_index("id", unique=True) - database[self.collection_name].create_index(self.vector_search_index) - + raise ValueError(f"Collection '{collection_name}' does not exist in database '{database_name}'.") self.collection = database[self.collection_name] def to_dict(self) -> Dict[str, Any]: @@ -87,7 +79,6 @@ def to_dict(self) -> Dict[str, Any]: database_name=self.database_name, collection_name=self.collection_name, vector_search_index=self.vector_search_index, - recreate_collection=self.recreate_collection, ) @classmethod @@ -190,6 +181,9 @@ def embedding_retrieval( different value range will be scaled to a range of [0,1], where 1 means extremely relevant. Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. """ + if not query_embedding: + raise ValueError("Query embedding must not be empty") + if similarity not in METRIC_TYPES: raise ValueError( "MongoDB Atlas currently supports 'dotProduct', 'cosine' and 'euclidean' similarity metrics. \ @@ -206,8 +200,8 @@ def embedding_retrieval( { "$vectorSearch": { "index": self.vector_search_index, - "queryVector": query_embedding.tolist(), "path": "embedding", + "queryVector": query_embedding.tolist(), "numCandidates": 100, "limit": top_k, #"filter": filters, @@ -215,14 +209,18 @@ def embedding_retrieval( }, { '$project': { '_id': 0, + 'content': 1, 'score': { '$meta': 'vectorSearchScore' } } } ] - documents = list(self.collection.aggregate(pipeline)) - + try: + documents = list(self.collection.aggregate(pipeline)) + except Exception as e: + raise DocumentStoreError(f"Retrieval of documents from MongoDB Atlas failed: {e}") from e + if scale_score: for doc in documents: doc["score"] = self.scale_to_unit_interval(doc["score"], similarity) @@ -230,7 +228,7 @@ def embedding_retrieval( documents = [self.mongo_doc_to_haystack_doc(doc) for doc in documents] return documents - def mongo_doc_to_haystack_doc(mongo_doc: Dict[str, Any]) -> Document: + def mongo_doc_to_haystack_doc(self, mongo_doc: Dict[str, Any]) -> Document: """ Converts the dictionary coming out of MongoDB into a Haystack document diff --git a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py index 6b16768e0..c81976e01 100644 --- a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py +++ b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py @@ -2,25 +2,11 @@ # # SPDX-License-Identifier: Apache-2.0 from typing import List -from uuid import uuid4 import os import pytest -from haystack.dataclasses.document import Document +from haystack.document_stores.errors import DocumentStoreError from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore -from numpy.random import rand - - -@pytest.fixture -def document_store(request): - store = MongoDBAtlasDocumentStore( - database_name="haystack_integration_test", - collection_name="test_collection", - vector_search_index="vector_index", - recreate_collection=True, - ) - return store - @pytest.mark.skipif( "MONGO_CONNECTION_STRING" not in os.environ, @@ -28,79 +14,68 @@ def document_store(request): ) class TestEmbeddingRetrieval: - def test_embedding_retrieval_cosine_similarity(self, document_store: MongoDBAtlasDocumentStore): + def test_embedding_retrieval_cosine_similarity(self): + document_store = MongoDBAtlasDocumentStore( + database_name="haystack_integration_test", + collection_name="test_embeddings_collection", + vector_search_index="cosine_index", + ) query_embedding = [0.1] * 768 - most_similar_embedding = [0.8] * 768 - second_best_embedding = [0.8] * 700 + [0.1] * 3 + [0.2] * 65 - another_embedding = rand(768).tolist() - - docs = [ - Document(content="Most similar document (cosine sim)", embedding=most_similar_embedding), - Document(content="2nd best document (cosine sim)", embedding=second_best_embedding), - Document(content="Not very similar document (cosine sim)", embedding=another_embedding), - ] - - document_store.write_documents(docs) - results = document_store.embedding_retrieval( query_embedding=query_embedding, top_k=2, filters={}, similarity="cosine" ) assert len(results) == 2 - assert results[0].content == "Most similar document (cosine sim)" - assert results[1].content == "2nd best document (cosine sim)" + assert results[0].content == "Document A" + assert results[1].content == "Document B" assert results[0].score > results[1].score - def test_embedding_retrieval_dot_product(self, document_store: MongoDBAtlasDocumentStore): + def test_embedding_retrieval_dot_product(self): + document_store = MongoDBAtlasDocumentStore( + database_name="haystack_integration_test", + collection_name="test_embeddings_collection", + vector_search_index="dotProduct_index", + ) query_embedding = [0.1] * 768 - most_similar_embedding = [0.8] * 768 - second_best_embedding = [0.8] * 700 + [0.1] * 3 + [0.2] * 65 - another_embedding = rand(768).tolist() - - docs = [ - Document(content="Most similar document (dot product)", embedding=most_similar_embedding), - Document(content="2nd best document (dot product)", embedding=second_best_embedding), - Document(content="Not very similar document (dot product)", embedding=another_embedding), - ] - - document_store.write_documents(docs) - results = document_store.embedding_retrieval( query_embedding=query_embedding, top_k=2, filters={}, similarity="dotProduct" ) assert len(results) == 2 - assert results[0].content == "Most similar document (dot product)" - assert results[1].content == "2nd best document (dot product)" + assert results[0].content == "Document A" + assert results[1].content == "Document B" assert results[0].score > results[1].score - def test_embedding_retrieval_euclidean(self, document_store: MongoDBAtlasDocumentStore): + def test_embedding_retrieval_euclidean(self): + document_store = MongoDBAtlasDocumentStore( + database_name="haystack_integration_test", + collection_name="test_embeddings_collection", + vector_search_index="euclidean_index", + ) query_embedding = [0.1] * 768 - most_similar_embedding = [0.8] * 768 - second_best_embedding = [0.8] * 700 + [0.1] * 3 + [0.2] * 65 - another_embedding = rand(768).tolist() - - docs = [ - Document(content="Most similar document (euclidean)", embedding=most_similar_embedding), - Document(content="2nd best document (euclidean)", embedding=second_best_embedding), - Document(content="Not very similar document (euclidean)", embedding=another_embedding), - ] - - document_store.write_documents(docs) - results = document_store.embedding_retrieval( query_embedding=query_embedding, top_k=2, filters={}, similarity="euclidean" ) assert len(results) == 2 - assert results[0].content == "Most similar document (euclidean)" - assert results[1].content == "2nd best document (euclidean)" + assert results[0].content == "Document C" + assert results[1].content == "Document B" assert results[0].score > results[1].score - def test_empty_query_embedding(self, document_store: MongoDBAtlasDocumentStore): + def test_empty_query_embedding(self): + document_store = MongoDBAtlasDocumentStore( + database_name="haystack_integration_test", + collection_name="test_embeddings_collection", + vector_search_index="cosine_index", + ) query_embedding: List[float] = [] with pytest.raises(ValueError): document_store.embedding_retrieval(query_embedding=query_embedding) - def test_query_embedding_wrong_dimension(self, document_store: MongoDBAtlasDocumentStore): + def test_query_embedding_wrong_dimension(self): + document_store = MongoDBAtlasDocumentStore( + database_name="haystack_integration_test", + collection_name="test_embeddings_collection", + vector_search_index="cosine_index", + ) query_embedding = [0.1] * 4 - with pytest.raises(ValueError): + with pytest.raises(DocumentStoreError): document_store.embedding_retrieval(query_embedding=query_embedding) From f0fe7d485388d147754089015097efceef13e3da Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 21 Feb 2024 12:35:13 +0100 Subject: [PATCH 04/32] tests green --- .../retrievers/mongodb_atlas/__init__.py | 5 + .../mongodb_atlas/embedding_retriever.py | 16 +-- .../tests/test_document_store.py | 57 +++++----- .../tests/test_embedding_retrieval.py | 6 +- .../mongodb_atlas/tests/test_retriever.py | 103 ++++++++++++++++++ 5 files changed, 144 insertions(+), 43 deletions(-) create mode 100644 integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py create mode 100644 integrations/mongodb_atlas/tests/test_retriever.py diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py new file mode 100644 index 000000000..da96677fb --- /dev/null +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py @@ -0,0 +1,5 @@ +from haystack_integrations.components.retrievers.mongodb_atlas.embedding_retriever import MongoDBAtlasEmbeddingRetriever + +__all__ = [ + "MongoDBAtlasEmbeddingRetriever" +] \ No newline at end of file diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py index 3372638b5..4eb4d4298 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Literal, Optional from haystack import component, default_from_dict, default_to_dict +from haystack.utils import Secret, deserialize_secrets_inplace from haystack.dataclasses import Document from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore from haystack_integrations.document_stores.mongodb_atlas.document_store import METRIC_TYPES @@ -23,7 +24,6 @@ def __init__( document_store: MongoDBAtlasDocumentStore, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, - similarity: Literal["euclidean", "cosine", "dotProduct"] = "cosine", ): """ Create the MongoDBAtlasDocumentStore component. @@ -31,33 +31,26 @@ def __init__( :param document_store: An instance of MongoDBAtlasDocumentStore. :param filters: Filters applied to the retrieved Documents. Defaults to None. :param top_k: Maximum number of Documents to return, defaults to 10. - :param similarity: The similarity function to use when searching for similar embeddings. - Defaults to "cosine". Valid values are "cosine", "euclidean", "dotProduct". """ if not isinstance(document_store, MongoDBAtlasDocumentStore): msg = "document_store must be an instance of MongoDBAtlasDocumentStore" raise ValueError(msg) - if similarity not in METRIC_TYPES: - msg = f"vector_function must be one of {METRIC_TYPES}" - raise ValueError(msg) - self.document_store = document_store self.filters = filters or {} self.top_k = top_k - self.similarity = similarity def to_dict(self) -> Dict[str, Any]: return default_to_dict( self, filters=self.filters, top_k=self.top_k, - similarity=self.similarity, document_store=self.document_store.to_dict(), ) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasEmbeddingRetriever": + deserialize_secrets_inplace(data["init_parameters"]["document_store"]["init_parameters"], keys=["mongo_connection_string"]) data["init_parameters"]["document_store"] = default_from_dict( MongoDBAtlasDocumentStore, data["init_parameters"]["document_store"] ) @@ -69,7 +62,6 @@ def run( query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None, - similarity: Optional[Literal["euclidean", "cosine", "dotProduct"]] = None, ): """ Retrieve documents from the MongoDBAtlasDocumentStore, based on their embeddings. @@ -77,18 +69,14 @@ def run( :param query_embedding: Embedding of the query. :param filters: Filters applied to the retrieved Documents. :param top_k: Maximum number of Documents to return. - :param similarity: The similarity function to use when searching for similar embeddings. - Defaults to the value provided in the constructor. Valid values are "cosine", "euclidean", "dotProduct". :return: List of Documents similar to `query_embedding`. """ filters = filters or self.filters top_k = top_k or self.top_k - similarity = similarity or self.similarity docs = self.document_store.embedding_retrieval( query_embedding=query_embedding, filters=filters, top_k=top_k, - similarity=similarity, ) return {"documents": docs} diff --git a/integrations/mongodb_atlas/tests/test_document_store.py b/integrations/mongodb_atlas/tests/test_document_store.py index 2f1dcd5a4..7f8fb5d80 100644 --- a/integrations/mongodb_atlas/tests/test_document_store.py +++ b/integrations/mongodb_atlas/tests/test_document_store.py @@ -13,17 +13,32 @@ from haystack.utils import Secret from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore from pandas import DataFrame +from pymongo import MongoClient # type: ignore +from pymongo.driver_info import DriverInfo # type: ignore @pytest.fixture def document_store(): + database_name="haystack_integration_test" + collection_name="test_collection" + + connection: MongoClient = MongoClient( + os.environ["MONGO_CONNECTION_STRING"], driver=DriverInfo(name="MongoDBAtlasHaystackIntegration") + ) + database = connection[database_name] + if collection_name in database.list_collection_names(): + database[collection_name].drop() + database.create_collection(collection_name) + database[collection_name].create_index("id", unique=True) + store = MongoDBAtlasDocumentStore( - database_name="haystack_integration_test", - collection_name="test_collection", - vector_search_index="vector_index", - recreate_collection=True, + database_name=database_name, + collection_name=collection_name, + vector_search_index="cosine_index", ) - return store + yield store + database[collection_name].drop() + @pytest.mark.skipif( @@ -52,13 +67,7 @@ def test_write_dataframe(self, document_store: MongoDBAtlasDocumentStore): retrieved_docs = document_store.filter_documents() assert retrieved_docs == docs - @patch("haystack_integrations.document_stores.mongodb_atlas.document_store.MongoClient") - def test_to_dict(self, _): - document_store = MongoDBAtlasDocumentStore( - database_name="database_name", - collection_name="collection_name", - vector_search_index="vector_search_index", - ) + def test_to_dict(self, document_store): assert document_store.to_dict() == { "type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore", "init_parameters": { @@ -69,15 +78,13 @@ def test_to_dict(self, _): "strict": True, "type": "env_var", }, - "database_name": "database_name", - "collection_name": "collection_name", - "recreate_collection": False, - "vector_search_index": "vector_search_index", + "database_name": "haystack_integration_test", + "collection_name": "test_collection", + "vector_search_index": "cosine_index", }, } - @patch("haystack_integrations.document_stores.mongodb_atlas.document_store.MongoClient") - def test_from_dict(self, _): + def test_from_dict(self): docstore = MongoDBAtlasDocumentStore.from_dict( { "type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore", @@ -89,15 +96,13 @@ def test_from_dict(self, _): "strict": True, "type": "env_var", }, - "database_name": "database_name", - "collection_name": "collection_name", - "vector_search_index": "vector_search_index", - "recreate_collection": True, + "database_name": "haystack_integration_test", + "collection_name": "test_embeddings_collection", + "vector_search_index": "cosine_index", }, } ) assert docstore.mongo_connection_string == Secret.from_env_var("MONGO_CONNECTION_STRING") - assert docstore.database_name == "database_name" - assert docstore.collection_name == "collection_name" - assert docstore.vector_search_index == "vector_search_index" - assert docstore.recreate_collection + assert docstore.database_name == "haystack_integration_test" + assert docstore.collection_name == "test_embeddings_collection" + assert docstore.vector_search_index == "cosine_index" diff --git a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py index c81976e01..7b56b5641 100644 --- a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py +++ b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py @@ -22,7 +22,7 @@ def test_embedding_retrieval_cosine_similarity(self): ) query_embedding = [0.1] * 768 results = document_store.embedding_retrieval( - query_embedding=query_embedding, top_k=2, filters={}, similarity="cosine" + query_embedding=query_embedding, top_k=2, filters={} ) assert len(results) == 2 assert results[0].content == "Document A" @@ -37,7 +37,7 @@ def test_embedding_retrieval_dot_product(self): ) query_embedding = [0.1] * 768 results = document_store.embedding_retrieval( - query_embedding=query_embedding, top_k=2, filters={}, similarity="dotProduct" + query_embedding=query_embedding, top_k=2, filters={} ) assert len(results) == 2 assert results[0].content == "Document A" @@ -53,7 +53,7 @@ def test_embedding_retrieval_euclidean(self): ) query_embedding = [0.1] * 768 results = document_store.embedding_retrieval( - query_embedding=query_embedding, top_k=2, filters={}, similarity="euclidean" + query_embedding=query_embedding, top_k=2, filters={} ) assert len(results) == 2 assert results[0].content == "Document C" diff --git a/integrations/mongodb_atlas/tests/test_retriever.py b/integrations/mongodb_atlas/tests/test_retriever.py new file mode 100644 index 000000000..8f9e06ad1 --- /dev/null +++ b/integrations/mongodb_atlas/tests/test_retriever.py @@ -0,0 +1,103 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import pytest +from unittest.mock import Mock + +from haystack.dataclasses import Document +from haystack.utils.auth import EnvVarSecret +from haystack_integrations.components.retrievers.mongodb_atlas import MongoDBAtlasEmbeddingRetriever +from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore + + +@pytest.fixture +def document_store(): + store = MongoDBAtlasDocumentStore( + database_name="haystack_integration_test", + collection_name="test_embeddings_collection", + vector_search_index="cosine_index", + ) + return store + + +class TestRetriever: + def test_init_default(self, document_store: MongoDBAtlasDocumentStore): + retriever = MongoDBAtlasEmbeddingRetriever(document_store=document_store) + assert retriever.document_store == document_store + assert retriever.filters == {} + assert retriever.top_k == 10 + + def test_init(self, document_store: MongoDBAtlasDocumentStore): + retriever = MongoDBAtlasEmbeddingRetriever( + document_store=document_store, filters={"field": "value"}, top_k=5, + ) + assert retriever.document_store == document_store + assert retriever.filters == {"field": "value"} + assert retriever.top_k == 5 + + def test_to_dict(self, document_store: MongoDBAtlasDocumentStore): + retriever = MongoDBAtlasEmbeddingRetriever( + document_store=document_store, filters={"field": "value"}, top_k=5 + ) + res = retriever.to_dict() + t = "haystack_integrations.components.retrievers.mongodb_atlas.embedding_retriever.MongoDBAtlasEmbeddingRetriever" + assert res == { + "type": t, + "init_parameters": { + "document_store": { + "type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore", + "init_parameters": { + "mongo_connection_string": {"env_vars": ["MONGO_CONNECTION_STRING"], "strict": True, "type": "env_var"}, + "database_name": "haystack_integration_test", + "collection_name": "test_embeddings_collection", + "vector_search_index": "cosine_index", + }, + }, + "filters": {"field": "value"}, + "top_k": 5, + }, + } + + def test_from_dict(self): + t = "haystack_integrations.components.retrievers.mongodb_atlas.embedding_retriever.MongoDBAtlasEmbeddingRetriever" + data = { + "type": t, + "init_parameters": { + "document_store": { + "type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore", + "init_parameters": { + "mongo_connection_string": {"env_vars": ["MONGO_CONNECTION_STRING"], "strict": True, "type": "env_var"}, + "database_name": "haystack_integration_test", + "collection_name": "test_embeddings_collection", + "vector_search_index": "cosine_index", + }, + }, + "filters": {"field": "value"}, + "top_k": 5, + }, + } + + retriever = MongoDBAtlasEmbeddingRetriever.from_dict(data) + document_store = retriever.document_store + + assert isinstance(document_store, MongoDBAtlasDocumentStore) + assert isinstance(document_store.mongo_connection_string, EnvVarSecret) + assert document_store.database_name == "haystack_integration_test" + assert document_store.collection_name == "test_embeddings_collection" + assert document_store.vector_search_index == "cosine_index" + assert retriever.filters == {"field": "value"} + assert retriever.top_k == 5 + + def test_run(self): + mock_store = Mock(spec=MongoDBAtlasDocumentStore) + doc = Document(content="Test doc", embedding=[0.1, 0.2]) + mock_store.embedding_retrieval.return_value = [doc] + + retriever = MongoDBAtlasEmbeddingRetriever(document_store=mock_store) + res = retriever.run(query_embedding=[0.3, 0.5]) + + mock_store.embedding_retrieval.assert_called_once_with( + query_embedding=[0.3, 0.5], filters={}, top_k=10 + ) + + assert res == {"documents": [doc]} From badb73e65d14abba8245b6b29d43b8fc0be1f421 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 21 Feb 2024 12:38:55 +0100 Subject: [PATCH 05/32] no parallel tests --- integrations/pgvector/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/pgvector/pyproject.toml b/integrations/pgvector/pyproject.toml index 178d9f7e8..caf5fe305 100644 --- a/integrations/pgvector/pyproject.toml +++ b/integrations/pgvector/pyproject.toml @@ -162,7 +162,7 @@ ban-relative-imports = "parents" [tool.coverage.run] source = ["haystack_integrations"] branch = true -parallel = true +parallel = false [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] From ced84f07d787474acffd32259493cc35d199dd73 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 21 Feb 2024 12:51:11 +0100 Subject: [PATCH 06/32] lint --- .../retrievers/mongodb_atlas/__init__.py | 4 +- .../mongodb_atlas/embedding_retriever.py | 9 +-- .../mongodb_atlas/document_store.py | 66 ++++--------------- .../tests/test_document_store.py | 7 +- .../tests/test_embedding_retrieval.py | 18 ++--- .../mongodb_atlas/tests/test_retriever.py | 34 ++++++---- 6 files changed, 47 insertions(+), 91 deletions(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py index da96677fb..fed0a4c28 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py @@ -1,5 +1,3 @@ from haystack_integrations.components.retrievers.mongodb_atlas.embedding_retriever import MongoDBAtlasEmbeddingRetriever -__all__ = [ - "MongoDBAtlasEmbeddingRetriever" -] \ No newline at end of file +__all__ = ["MongoDBAtlasEmbeddingRetriever"] diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py index 4eb4d4298..4dc522905 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py @@ -1,13 +1,12 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Dict, List, Optional from haystack import component, default_from_dict, default_to_dict -from haystack.utils import Secret, deserialize_secrets_inplace from haystack.dataclasses import Document +from haystack.utils import deserialize_secrets_inplace from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore -from haystack_integrations.document_stores.mongodb_atlas.document_store import METRIC_TYPES @component @@ -50,7 +49,9 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasEmbeddingRetriever": - deserialize_secrets_inplace(data["init_parameters"]["document_store"]["init_parameters"], keys=["mongo_connection_string"]) + deserialize_secrets_inplace( + data["init_parameters"]["document_store"]["init_parameters"], keys=["mongo_connection_string"] + ) data["init_parameters"]["document_store"] = default_from_dict( MongoDBAtlasDocumentStore, data["init_parameters"]["document_store"] ) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index 5bea56471..978d22471 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -8,7 +8,7 @@ import numpy as np from haystack import default_from_dict, default_to_dict from haystack.dataclasses.document import Document -from haystack.document_stores.errors import DuplicateDocumentError, DocumentStoreError +from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy from haystack.utils import Secret, deserialize_secrets_inplace from haystack_integrations.document_stores.mongodb_atlas.filters import haystack_filters_to_mongo @@ -19,9 +19,6 @@ logger = logging.getLogger(__name__) -METRIC_TYPES = ["euclidean", "cosine", "dotProduct"] - - class MongoDBAtlasDocumentStore: def __init__( self, @@ -66,7 +63,8 @@ def __init__( database = self.connection[self.database_name] if collection_name not in database.list_collection_names(): - raise ValueError(f"Collection '{collection_name}' does not exist in database '{database_name}'.") + msg = f"Collection '{collection_name}' does not exist in database '{database_name}'." + raise ValueError(msg) self.collection = database[self.collection_name] def to_dict(self) -> Dict[str, Any]: @@ -80,7 +78,7 @@ def to_dict(self) -> Dict[str, Any]: collection_name=self.collection_name, vector_search_index=self.vector_search_index, ) - + @classmethod def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasDocumentStore": """ @@ -166,8 +164,6 @@ def embedding_retrieval( query_embedding: np.ndarray, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, - similarity: str = "cosine", - scale_score: bool = True, ) -> List[Document]: """ Find the documents that are most similar to the provided `query_emb` by using a vector similarity metric. @@ -175,25 +171,12 @@ def embedding_retrieval( :param query_emb: Embedding of the query :param filters: optional filters (see get_all_documents for description). :param top_k: How many documents to return. - :param similarity: The similarity function to use. Currently supported: `dotProduct`, `cosine` and `euclidean`. - :param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]). - If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a - different value range will be scaled to a range of [0,1], where 1 means extremely relevant. - Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. """ if not query_embedding: - raise ValueError("Query embedding must not be empty") - - if similarity not in METRIC_TYPES: - raise ValueError( - "MongoDB Atlas currently supports 'dotProduct', 'cosine' and 'euclidean' similarity metrics. \ - Please set 'similarity' to one of the above." - ) - - query_embedding = np.array(query_embedding).astype(np.float32) + msg = "Query embedding must not be empty" + raise ValueError(msg) - if similarity == "cosine": - self.normalize_embedding(query_embedding) + query_embedding = np.array(query_embedding).astype(np.float32) filters = haystack_filters_to_mongo(filters) pipeline = [ @@ -204,30 +187,20 @@ def embedding_retrieval( "queryVector": query_embedding.tolist(), "numCandidates": 100, "limit": top_k, - #"filter": filters, - } - }, { - '$project': { - '_id': 0, - 'content': 1, - 'score': { - '$meta': 'vectorSearchScore' - } + # "filter": filters, } - } + }, + {"$project": {"_id": 0, "content": 1, "score": {"$meta": "vectorSearchScore"}}}, ] try: documents = list(self.collection.aggregate(pipeline)) except Exception as e: - raise DocumentStoreError(f"Retrieval of documents from MongoDB Atlas failed: {e}") from e - - if scale_score: - for doc in documents: - doc["score"] = self.scale_to_unit_interval(doc["score"], similarity) + msg = f"Retrieval of documents from MongoDB Atlas failed: {e}" + raise DocumentStoreError(msg) from e documents = [self.mongo_doc_to_haystack_doc(doc) for doc in documents] return documents - + def mongo_doc_to_haystack_doc(self, mongo_doc: Dict[str, Any]) -> Document: """ Converts the dictionary coming out of MongoDB into a Haystack document @@ -237,16 +210,3 @@ def mongo_doc_to_haystack_doc(self, mongo_doc: Dict[str, Any]) -> Document: """ mongo_doc.pop("_id", None) return Document.from_dict(mongo_doc) - - def normalize_embedding(self, emb: np.ndarray) -> None: - """ - Performs L2 normalization of a 1D embeddings vector **inplace**. - """ - norm = np.sqrt(emb.dot(emb)) # faster than np.linalg.norm() - if norm != 0.0: - emb /= norm - - def scale_to_unit_interval(self, score: float, similarity: Optional[str]) -> float: - if similarity == "cosine": - return (score + 1) / 2 - return float(1 / (1 + np.exp(-score / 100))) diff --git a/integrations/mongodb_atlas/tests/test_document_store.py b/integrations/mongodb_atlas/tests/test_document_store.py index 7f8fb5d80..c4ce30472 100644 --- a/integrations/mongodb_atlas/tests/test_document_store.py +++ b/integrations/mongodb_atlas/tests/test_document_store.py @@ -2,8 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 import os -from unittest.mock import patch -from uuid import uuid4 import pytest from haystack.dataclasses.document import ByteStream, Document @@ -19,8 +17,8 @@ @pytest.fixture def document_store(): - database_name="haystack_integration_test" - collection_name="test_collection" + database_name = "haystack_integration_test" + collection_name = "test_collection" connection: MongoClient = MongoClient( os.environ["MONGO_CONNECTION_STRING"], driver=DriverInfo(name="MongoDBAtlasHaystackIntegration") @@ -40,7 +38,6 @@ def document_store(): database[collection_name].drop() - @pytest.mark.skipif( "MONGO_CONNECTION_STRING" not in os.environ, reason="No MongoDB Atlas connection string provided", diff --git a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py index 7b56b5641..ba2ad0baa 100644 --- a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py +++ b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py @@ -1,19 +1,20 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from typing import List import os +from typing import List import pytest from haystack.document_stores.errors import DocumentStoreError from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore + @pytest.mark.skipif( "MONGO_CONNECTION_STRING" not in os.environ, reason="No MongoDB Atlas connection string provided", ) class TestEmbeddingRetrieval: - + def test_embedding_retrieval_cosine_similarity(self): document_store = MongoDBAtlasDocumentStore( database_name="haystack_integration_test", @@ -21,9 +22,7 @@ def test_embedding_retrieval_cosine_similarity(self): vector_search_index="cosine_index", ) query_embedding = [0.1] * 768 - results = document_store.embedding_retrieval( - query_embedding=query_embedding, top_k=2, filters={} - ) + results = document_store.embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={}) assert len(results) == 2 assert results[0].content == "Document A" assert results[1].content == "Document B" @@ -36,15 +35,12 @@ def test_embedding_retrieval_dot_product(self): vector_search_index="dotProduct_index", ) query_embedding = [0.1] * 768 - results = document_store.embedding_retrieval( - query_embedding=query_embedding, top_k=2, filters={} - ) + results = document_store.embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={}) assert len(results) == 2 assert results[0].content == "Document A" assert results[1].content == "Document B" assert results[0].score > results[1].score - def test_embedding_retrieval_euclidean(self): document_store = MongoDBAtlasDocumentStore( database_name="haystack_integration_test", @@ -52,9 +48,7 @@ def test_embedding_retrieval_euclidean(self): vector_search_index="euclidean_index", ) query_embedding = [0.1] * 768 - results = document_store.embedding_retrieval( - query_embedding=query_embedding, top_k=2, filters={} - ) + results = document_store.embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={}) assert len(results) == 2 assert results[0].content == "Document C" assert results[1].content == "Document B" diff --git a/integrations/mongodb_atlas/tests/test_retriever.py b/integrations/mongodb_atlas/tests/test_retriever.py index 8f9e06ad1..9aadcae9a 100644 --- a/integrations/mongodb_atlas/tests/test_retriever.py +++ b/integrations/mongodb_atlas/tests/test_retriever.py @@ -1,9 +1,9 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -import pytest from unittest.mock import Mock +import pytest from haystack.dataclasses import Document from haystack.utils.auth import EnvVarSecret from haystack_integrations.components.retrievers.mongodb_atlas import MongoDBAtlasEmbeddingRetriever @@ -29,25 +29,29 @@ def test_init_default(self, document_store: MongoDBAtlasDocumentStore): def test_init(self, document_store: MongoDBAtlasDocumentStore): retriever = MongoDBAtlasEmbeddingRetriever( - document_store=document_store, filters={"field": "value"}, top_k=5, + document_store=document_store, + filters={"field": "value"}, + top_k=5, ) assert retriever.document_store == document_store assert retriever.filters == {"field": "value"} assert retriever.top_k == 5 def test_to_dict(self, document_store: MongoDBAtlasDocumentStore): - retriever = MongoDBAtlasEmbeddingRetriever( - document_store=document_store, filters={"field": "value"}, top_k=5 - ) + retriever = MongoDBAtlasEmbeddingRetriever(document_store=document_store, filters={"field": "value"}, top_k=5) res = retriever.to_dict() - t = "haystack_integrations.components.retrievers.mongodb_atlas.embedding_retriever.MongoDBAtlasEmbeddingRetriever" + t = "haystack_integrations.components.retrievers.mongodb_atlas.embedding_retriever.MongoDBAtlasEmbeddingRetriever" # noqa: E501 assert res == { "type": t, "init_parameters": { "document_store": { - "type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore", + "type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore", # noqa: E501 "init_parameters": { - "mongo_connection_string": {"env_vars": ["MONGO_CONNECTION_STRING"], "strict": True, "type": "env_var"}, + "mongo_connection_string": { + "env_vars": ["MONGO_CONNECTION_STRING"], + "strict": True, + "type": "env_var", + }, "database_name": "haystack_integration_test", "collection_name": "test_embeddings_collection", "vector_search_index": "cosine_index", @@ -59,14 +63,18 @@ def test_to_dict(self, document_store: MongoDBAtlasDocumentStore): } def test_from_dict(self): - t = "haystack_integrations.components.retrievers.mongodb_atlas.embedding_retriever.MongoDBAtlasEmbeddingRetriever" + t = "haystack_integrations.components.retrievers.mongodb_atlas.embedding_retriever.MongoDBAtlasEmbeddingRetriever" # noqa: E501 data = { "type": t, "init_parameters": { "document_store": { - "type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore", + "type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore", # noqa: E501 "init_parameters": { - "mongo_connection_string": {"env_vars": ["MONGO_CONNECTION_STRING"], "strict": True, "type": "env_var"}, + "mongo_connection_string": { + "env_vars": ["MONGO_CONNECTION_STRING"], + "strict": True, + "type": "env_var", + }, "database_name": "haystack_integration_test", "collection_name": "test_embeddings_collection", "vector_search_index": "cosine_index", @@ -96,8 +104,6 @@ def test_run(self): retriever = MongoDBAtlasEmbeddingRetriever(document_store=mock_store) res = retriever.run(query_embedding=[0.3, 0.5]) - mock_store.embedding_retrieval.assert_called_once_with( - query_embedding=[0.3, 0.5], filters={}, top_k=10 - ) + mock_store.embedding_retrieval.assert_called_once_with(query_embedding=[0.3, 0.5], filters={}, top_k=10) assert res == {"documents": [doc]} From 1e7893ff9b592c32dc8fe6e9c81785d299c669b6 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 21 Feb 2024 12:54:15 +0100 Subject: [PATCH 07/32] use different collections for write tests --- integrations/mongodb_atlas/tests/test_document_store.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/integrations/mongodb_atlas/tests/test_document_store.py b/integrations/mongodb_atlas/tests/test_document_store.py index c4ce30472..f44db9954 100644 --- a/integrations/mongodb_atlas/tests/test_document_store.py +++ b/integrations/mongodb_atlas/tests/test_document_store.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 import os +from uuid import uuid4 import pytest from haystack.dataclasses.document import ByteStream, Document @@ -18,7 +19,7 @@ @pytest.fixture def document_store(): database_name = "haystack_integration_test" - collection_name = "test_collection" + collection_name = "test_collection_"+str(uuid4()) connection: MongoClient = MongoClient( os.environ["MONGO_CONNECTION_STRING"], driver=DriverInfo(name="MongoDBAtlasHaystackIntegration") From 420b51517e8cc16f3e7a507128a35823bbdb88b0 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 21 Feb 2024 13:02:29 +0100 Subject: [PATCH 08/32] fix tests --- integrations/mongodb_atlas/tests/test_document_store.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/integrations/mongodb_atlas/tests/test_document_store.py b/integrations/mongodb_atlas/tests/test_document_store.py index f44db9954..f8ba392b1 100644 --- a/integrations/mongodb_atlas/tests/test_document_store.py +++ b/integrations/mongodb_atlas/tests/test_document_store.py @@ -66,7 +66,9 @@ def test_write_dataframe(self, document_store: MongoDBAtlasDocumentStore): assert retrieved_docs == docs def test_to_dict(self, document_store): - assert document_store.to_dict() == { + serialized_store = document_store.to_dict() + assert serialized_store["init_parameters"].pop("collection_name").startswith("test_collection_") + assert serialized_store == { "type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore", "init_parameters": { "mongo_connection_string": { @@ -77,7 +79,6 @@ def test_to_dict(self, document_store): "type": "env_var", }, "database_name": "haystack_integration_test", - "collection_name": "test_collection", "vector_search_index": "cosine_index", }, } From bee21ca58fac1fb52002f610f378521899f78f77 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 21 Feb 2024 14:13:50 +0100 Subject: [PATCH 09/32] black --- integrations/mongodb_atlas/tests/test_document_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/mongodb_atlas/tests/test_document_store.py b/integrations/mongodb_atlas/tests/test_document_store.py index f8ba392b1..40fe563fa 100644 --- a/integrations/mongodb_atlas/tests/test_document_store.py +++ b/integrations/mongodb_atlas/tests/test_document_store.py @@ -19,7 +19,7 @@ @pytest.fixture def document_store(): database_name = "haystack_integration_test" - collection_name = "test_collection_"+str(uuid4()) + collection_name = "test_collection_" + str(uuid4()) connection: MongoClient = MongoClient( os.environ["MONGO_CONNECTION_STRING"], driver=DriverInfo(name="MongoDBAtlasHaystackIntegration") From 8c1a0048a6035c426f8b78026bd596c929cabc1f Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 21 Feb 2024 14:14:51 +0100 Subject: [PATCH 10/32] docstring --- .../document_stores/mongodb_atlas/document_store.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index 978d22471..5d9016d27 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -43,8 +43,6 @@ def __init__( :param vector_search_index: The name of the vector search index to use for vector search operations. Create a vector_search_index in the Atlas web UI and specify the init params of MongoDBAtlasDocumentStore. \ See https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#std-label-avs-create-index - :param recreate_collection: Whether to recreate the collection when initializing the document store. Defaults - to False. """ if collection_name and not bool(re.match(r"^[a-zA-Z0-9\-_]+$", collection_name)): msg = f'Invalid collection name: "{collection_name}". It can only contain letters, numbers, -, or _.' From 8869ec7517b666d29140c60b5657709d87218f33 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 21 Feb 2024 14:17:29 +0100 Subject: [PATCH 11/32] add doc fields --- .../mongodb_atlas/document_store.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index 5d9016d27..d16b05408 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -188,7 +188,19 @@ def embedding_retrieval( # "filter": filters, } }, - {"$project": {"_id": 0, "content": 1, "score": {"$meta": "vectorSearchScore"}}}, + { + "$project": { + "_id": 0, + "content": 1, + "dataframe": 1, + "blob": 1, + "meta": 1, + "embedding": 1, + "score": { + "$meta": "vectorSearchScore" + } + } + }, ] try: documents = list(self.collection.aggregate(pipeline)) From 9b4b728c92bd5222c5bdd3d3b91ea2561d2a9870 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 21 Feb 2024 14:48:33 +0100 Subject: [PATCH 12/32] black --- .../document_stores/mongodb_atlas/document_store.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index d16b05408..3bb5b9e5e 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -190,15 +190,13 @@ def embedding_retrieval( }, { "$project": { - "_id": 0, - "content": 1, + "_id": 0, + "content": 1, "dataframe": 1, "blob": 1, "meta": 1, "embedding": 1, - "score": { - "$meta": "vectorSearchScore" - } + "score": {"$meta": "vectorSearchScore"}, } }, ] From ca8631c70fa455b9dfa6e7796ec888735abbe909 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 21 Feb 2024 17:34:29 +0100 Subject: [PATCH 13/32] Update integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py Co-authored-by: Madeesh Kannan --- .../retrievers/mongodb_atlas/embedding_retriever.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py index 4dc522905..a38f1ed88 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py @@ -28,8 +28,8 @@ def __init__( Create the MongoDBAtlasDocumentStore component. :param document_store: An instance of MongoDBAtlasDocumentStore. - :param filters: Filters applied to the retrieved Documents. Defaults to None. - :param top_k: Maximum number of Documents to return, defaults to 10. + :param filters: Filters applied to the retrieved Documents. + :param top_k: Maximum number of Documents to return. """ if not isinstance(document_store, MongoDBAtlasDocumentStore): msg = "document_store must be an instance of MongoDBAtlasDocumentStore" From 670bd09cbf05fc9394d640aea05a64e172427707 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 21 Feb 2024 17:37:05 +0100 Subject: [PATCH 14/32] Update integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py --- .../components/retrievers/mongodb_atlas/embedding_retriever.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py index a38f1ed88..f6f7ac433 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py @@ -63,7 +63,7 @@ def run( query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None, - ): + ) -> Dict[str, List[Document]]: """ Retrieve documents from the MongoDBAtlasDocumentStore, based on their embeddings. From a6b6580324e1f7eba6c002209223e5510f16469b Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 21 Feb 2024 17:38:09 +0100 Subject: [PATCH 15/32] Update integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py --- .../components/retrievers/mongodb_atlas/embedding_retriever.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py index f6f7ac433..7fc3258d8 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py @@ -40,6 +40,9 @@ def __init__( self.top_k = top_k def to_dict(self) -> Dict[str, Any]: + """ + Serializes this component into a dictionary. + """ return default_to_dict( self, filters=self.filters, From b433b763f38fc08bc17b0790dc5ba4df46a5babc Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 21 Feb 2024 17:40:21 +0100 Subject: [PATCH 16/32] Update integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py --- .../retrievers/mongodb_atlas/embedding_retriever.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py index 7fc3258d8..44e1185ed 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py @@ -52,6 +52,11 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasEmbeddingRetriever": + """ + Deserializes a dictionary created with `MongoDBAtlasEmbeddingRetriever. to_dict()` into a `MongoDBAtlasEmbeddingRetriever` instance. + + :param data: the dictionary returned by `MongoDBAtlasEmbeddingRetriever. to_dict()` + """ deserialize_secrets_inplace( data["init_parameters"]["document_store"]["init_parameters"], keys=["mongo_connection_string"] ) From 20865f946628cf020ba62591fc3dd3b1bba475b0 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 21 Feb 2024 17:42:28 +0100 Subject: [PATCH 17/32] Update integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py --- .../retrievers/mongodb_atlas/embedding_retriever.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py index 44e1185ed..b02b056b8 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py @@ -76,8 +76,8 @@ def run( Retrieve documents from the MongoDBAtlasDocumentStore, based on their embeddings. :param query_embedding: Embedding of the query. - :param filters: Filters applied to the retrieved Documents. - :param top_k: Maximum number of Documents to return. + :param filters: Filters applied to the retrieved Documents. Overrides the value specified at initialization. + :param top_k: Maximum number of Documents to return. Overrides the value specified at initialization. :return: List of Documents similar to `query_embedding`. """ filters = filters or self.filters From 517586b5775e161fbda7cf65a6f152f859919b4d Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 21 Feb 2024 17:43:07 +0100 Subject: [PATCH 18/32] Update integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py --- .../document_stores/mongodb_atlas/document_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index 3bb5b9e5e..bef2fe562 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -159,7 +159,7 @@ def delete_documents(self, document_ids: List[str]) -> None: def embedding_retrieval( self, - query_embedding: np.ndarray, + query_embedding: List[float] filters: Optional[Dict[str, Any]] = None, top_k: int = 10, ) -> List[Document]: From 47a79147c6689c817d92e6d6aa9322ae45a1a761 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 21 Feb 2024 17:44:08 +0100 Subject: [PATCH 19/32] Update integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py Co-authored-by: Madeesh Kannan --- .../document_stores/mongodb_atlas/document_store.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index bef2fe562..308b53c61 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -164,9 +164,9 @@ def embedding_retrieval( top_k: int = 10, ) -> List[Document]: """ - Find the documents that are most similar to the provided `query_emb` by using a vector similarity metric. + Find the documents that are most similar to the provided `query_embedding` by using a vector similarity metric. - :param query_emb: Embedding of the query + :param query_embedding: Embedding of the query :param filters: optional filters (see get_all_documents for description). :param top_k: How many documents to return. """ From b8011c75ed6fed66a7e57e39eda84efba36fc343 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 21 Feb 2024 17:44:48 +0100 Subject: [PATCH 20/32] Update integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py Co-authored-by: Madeesh Kannan --- .../document_stores/mongodb_atlas/document_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index 308b53c61..01bc49cee 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -167,7 +167,7 @@ def embedding_retrieval( Find the documents that are most similar to the provided `query_embedding` by using a vector similarity metric. :param query_embedding: Embedding of the query - :param filters: optional filters (see get_all_documents for description). + :param filters: Optional filters. :param top_k: How many documents to return. """ if not query_embedding: From f6eea73b184155ace2e9770ac6d2e2ce765537e7 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 21 Feb 2024 17:45:35 +0100 Subject: [PATCH 21/32] Update integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py --- .../document_stores/mongodb_atlas/document_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index 01bc49cee..a3ee8743b 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -159,7 +159,7 @@ def delete_documents(self, document_ids: List[str]) -> None: def embedding_retrieval( self, - query_embedding: List[float] + query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: int = 10, ) -> List[Document]: From f154ed72b84642fbaa706824fec1268dc1fe060c Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 21 Feb 2024 17:51:43 +0100 Subject: [PATCH 22/32] black --- .../components/retrievers/mongodb_atlas/embedding_retriever.py | 2 +- integrations/mongodb_atlas/tests/test_document_store.py | 1 - integrations/mongodb_atlas/tests/test_embedding_retrieval.py | 1 - 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py index b02b056b8..d39a8897f 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py @@ -54,7 +54,7 @@ def to_dict(self) -> Dict[str, Any]: def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasEmbeddingRetriever": """ Deserializes a dictionary created with `MongoDBAtlasEmbeddingRetriever. to_dict()` into a `MongoDBAtlasEmbeddingRetriever` instance. - + :param data: the dictionary returned by `MongoDBAtlasEmbeddingRetriever. to_dict()` """ deserialize_secrets_inplace( diff --git a/integrations/mongodb_atlas/tests/test_document_store.py b/integrations/mongodb_atlas/tests/test_document_store.py index 40fe563fa..39a4465c1 100644 --- a/integrations/mongodb_atlas/tests/test_document_store.py +++ b/integrations/mongodb_atlas/tests/test_document_store.py @@ -44,7 +44,6 @@ def document_store(): reason="No MongoDB Atlas connection string provided", ) class TestDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest): - def test_write_documents(self, document_store: MongoDBAtlasDocumentStore): docs = [Document(content="some text")] assert document_store.write_documents(docs) == 1 diff --git a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py index ba2ad0baa..aa7790bc7 100644 --- a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py +++ b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py @@ -14,7 +14,6 @@ reason="No MongoDB Atlas connection string provided", ) class TestEmbeddingRetrieval: - def test_embedding_retrieval_cosine_similarity(self): document_store = MongoDBAtlasDocumentStore( database_name="haystack_integration_test", From fa16e76022112eef1d70afbb7b595a6e3cb8ec87 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 21 Feb 2024 17:59:22 +0100 Subject: [PATCH 23/32] docstring --- .../components/retrievers/mongodb_atlas/embedding_retriever.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py index d39a8897f..a20db0d8b 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py @@ -53,7 +53,8 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasEmbeddingRetriever": """ - Deserializes a dictionary created with `MongoDBAtlasEmbeddingRetriever. to_dict()` into a `MongoDBAtlasEmbeddingRetriever` instance. + Deserializes a dictionary created with `MongoDBAtlasEmbeddingRetriever. to_dict()` into a + `MongoDBAtlasEmbeddingRetriever` instance. :param data: the dictionary returned by `MongoDBAtlasEmbeddingRetriever. to_dict()` """ From b80c4f85458cc163f0e26fa997dce02da0b78f1a Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 21 Feb 2024 18:01:06 +0100 Subject: [PATCH 24/32] black --- .../components/retrievers/mongodb_atlas/embedding_retriever.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py index a20db0d8b..f11162d68 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py @@ -53,7 +53,7 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasEmbeddingRetriever": """ - Deserializes a dictionary created with `MongoDBAtlasEmbeddingRetriever. to_dict()` into a + Deserializes a dictionary created with `MongoDBAtlasEmbeddingRetriever. to_dict()` into a `MongoDBAtlasEmbeddingRetriever` instance. :param data: the dictionary returned by `MongoDBAtlasEmbeddingRetriever. to_dict()` From ba06c2a6707a1b82499ab2b17b547654cd6c9eee Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 21 Feb 2024 18:08:35 +0100 Subject: [PATCH 25/32] mypy --- .../retrievers/mongodb_atlas/embedding_retriever.py | 2 +- .../document_stores/mongodb_atlas/document_store.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py index f11162d68..db3ed99ac 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py @@ -85,7 +85,7 @@ def run( top_k = top_k or self.top_k docs = self.document_store.embedding_retrieval( - query_embedding=query_embedding, + query_embedding_np=query_embedding, filters=filters, top_k=top_k, ) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index a3ee8743b..8151c4820 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -174,7 +174,7 @@ def embedding_retrieval( msg = "Query embedding must not be empty" raise ValueError(msg) - query_embedding = np.array(query_embedding).astype(np.float32) + query_embedding_np = np.array(query_embedding).astype(np.float32) filters = haystack_filters_to_mongo(filters) pipeline = [ @@ -182,7 +182,7 @@ def embedding_retrieval( "$vectorSearch": { "index": self.vector_search_index, "path": "embedding", - "queryVector": query_embedding.tolist(), + "queryVector": query_embedding_np.tolist(), "numCandidates": 100, "limit": top_k, # "filter": filters, From ac425a3a6cd9d67b7d68f8d4cffbf740241c0457 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Thu, 22 Feb 2024 14:09:33 +0100 Subject: [PATCH 26/32] rename --- integrations/mongodb_atlas/tests/test_retriever.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/mongodb_atlas/tests/test_retriever.py b/integrations/mongodb_atlas/tests/test_retriever.py index 9aadcae9a..e6af92f9e 100644 --- a/integrations/mongodb_atlas/tests/test_retriever.py +++ b/integrations/mongodb_atlas/tests/test_retriever.py @@ -104,6 +104,6 @@ def test_run(self): retriever = MongoDBAtlasEmbeddingRetriever(document_store=mock_store) res = retriever.run(query_embedding=[0.3, 0.5]) - mock_store.embedding_retrieval.assert_called_once_with(query_embedding=[0.3, 0.5], filters={}, top_k=10) + mock_store.embedding_retrieval.assert_called_once_with(query_embedding_np=[0.3, 0.5], filters={}, top_k=10) assert res == {"documents": [doc]} From 5b1c574f04f2db08fd553afa58ce36fb9538989b Mon Sep 17 00:00:00 2001 From: ZanSara Date: Thu, 22 Feb 2024 16:10:59 +0100 Subject: [PATCH 27/32] deserialization --- .../retrievers/mongodb_atlas/embedding_retriever.py | 7 ++----- integrations/mongodb_atlas/tests/test_retriever.py | 6 ++---- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py index db3ed99ac..a6a60126d 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py @@ -58,11 +58,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasEmbeddingRetriever": :param data: the dictionary returned by `MongoDBAtlasEmbeddingRetriever. to_dict()` """ - deserialize_secrets_inplace( - data["init_parameters"]["document_store"]["init_parameters"], keys=["mongo_connection_string"] - ) - data["init_parameters"]["document_store"] = default_from_dict( - MongoDBAtlasDocumentStore, data["init_parameters"]["document_store"] + data["init_parameters"]["document_store"] = MongoDBAtlasDocumentStore.from_dict( + data["init_parameters"]["document_store"] ) return default_from_dict(cls, data) diff --git a/integrations/mongodb_atlas/tests/test_retriever.py b/integrations/mongodb_atlas/tests/test_retriever.py index e6af92f9e..887d22b76 100644 --- a/integrations/mongodb_atlas/tests/test_retriever.py +++ b/integrations/mongodb_atlas/tests/test_retriever.py @@ -40,9 +40,8 @@ def test_init(self, document_store: MongoDBAtlasDocumentStore): def test_to_dict(self, document_store: MongoDBAtlasDocumentStore): retriever = MongoDBAtlasEmbeddingRetriever(document_store=document_store, filters={"field": "value"}, top_k=5) res = retriever.to_dict() - t = "haystack_integrations.components.retrievers.mongodb_atlas.embedding_retriever.MongoDBAtlasEmbeddingRetriever" # noqa: E501 assert res == { - "type": t, + "type": "haystack_integrations.components.retrievers.mongodb_atlas.embedding_retriever.MongoDBAtlasEmbeddingRetriever", # noqa: E501 "init_parameters": { "document_store": { "type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore", # noqa: E501 @@ -63,9 +62,8 @@ def test_to_dict(self, document_store: MongoDBAtlasDocumentStore): } def test_from_dict(self): - t = "haystack_integrations.components.retrievers.mongodb_atlas.embedding_retriever.MongoDBAtlasEmbeddingRetriever" # noqa: E501 data = { - "type": t, + "type": "haystack_integrations.components.retrievers.mongodb_atlas.embedding_retriever.MongoDBAtlasEmbeddingRetriever", # noqa: E501 "init_parameters": { "document_store": { "type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore", # noqa: E501 From 407bcf4934898bc91640d38058c5b00fc15b4050 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Thu, 22 Feb 2024 16:38:49 +0100 Subject: [PATCH 28/32] unused import --- .../components/retrievers/mongodb_atlas/embedding_retriever.py | 1 - 1 file changed, 1 deletion(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py index a6a60126d..e7bb13a0c 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py @@ -5,7 +5,6 @@ from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import Document -from haystack.utils import deserialize_secrets_inplace from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore From 5c11b3fce882876c00f9df4a70e2d43fe1e29e33 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Fri, 23 Feb 2024 12:07:16 +0100 Subject: [PATCH 29/32] Update integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> --- .../retrievers/mongodb_atlas/embedding_retriever.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py index e7bb13a0c..a4ef3b497 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py @@ -52,10 +52,10 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasEmbeddingRetriever": """ - Deserializes a dictionary created with `MongoDBAtlasEmbeddingRetriever. to_dict()` into a + Deserializes a dictionary created with `MongoDBAtlasEmbeddingRetriever.to_dict()` into a `MongoDBAtlasEmbeddingRetriever` instance. - :param data: the dictionary returned by `MongoDBAtlasEmbeddingRetriever. to_dict()` + :param data: the dictionary returned by `MongoDBAtlasEmbeddingRetriever.to_dict()` """ data["init_parameters"]["document_store"] = MongoDBAtlasDocumentStore.from_dict( data["init_parameters"]["document_store"] From aa604777e7147935bdcc2813b3681131a49d4e8c Mon Sep 17 00:00:00 2001 From: ZanSara Date: Fri, 23 Feb 2024 12:09:03 +0100 Subject: [PATCH 30/32] change import --- .../document_stores/mongodb_atlas/document_store.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index 8151c4820..7965b1a72 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -5,7 +5,7 @@ import re from typing import Any, Dict, List, Optional, Union -import numpy as np +from numpy import array, float32 from haystack import default_from_dict, default_to_dict from haystack.dataclasses.document import Document from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError @@ -174,7 +174,7 @@ def embedding_retrieval( msg = "Query embedding must not be empty" raise ValueError(msg) - query_embedding_np = np.array(query_embedding).astype(np.float32) + query_embedding_np = array(query_embedding).astype(float32) filters = haystack_filters_to_mongo(filters) pipeline = [ From 7896536b706e76aedbdd567c98a62865f03369b0 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Fri, 23 Feb 2024 12:12:45 +0100 Subject: [PATCH 31/32] ruff --- .../document_stores/mongodb_atlas/document_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index 7965b1a72..cf7bfa758 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -5,13 +5,13 @@ import re from typing import Any, Dict, List, Optional, Union -from numpy import array, float32 from haystack import default_from_dict, default_to_dict from haystack.dataclasses.document import Document from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy from haystack.utils import Secret, deserialize_secrets_inplace from haystack_integrations.document_stores.mongodb_atlas.filters import haystack_filters_to_mongo +from numpy import array, float32 from pymongo import InsertOne, MongoClient, ReplaceOne, UpdateOne # type: ignore from pymongo.driver_info import DriverInfo # type: ignore from pymongo.errors import BulkWriteError # type: ignore From 7d03dff79b2e92c0d84a768de6e80301995a337b Mon Sep 17 00:00:00 2001 From: ZanSara Date: Fri, 23 Feb 2024 12:43:02 +0100 Subject: [PATCH 32/32] remove numpy conversion --- .../document_stores/mongodb_atlas/document_store.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index cf7bfa758..e2f2534f5 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -11,7 +11,6 @@ from haystack.document_stores.types import DuplicatePolicy from haystack.utils import Secret, deserialize_secrets_inplace from haystack_integrations.document_stores.mongodb_atlas.filters import haystack_filters_to_mongo -from numpy import array, float32 from pymongo import InsertOne, MongoClient, ReplaceOne, UpdateOne # type: ignore from pymongo.driver_info import DriverInfo # type: ignore from pymongo.errors import BulkWriteError # type: ignore @@ -174,15 +173,13 @@ def embedding_retrieval( msg = "Query embedding must not be empty" raise ValueError(msg) - query_embedding_np = array(query_embedding).astype(float32) - filters = haystack_filters_to_mongo(filters) pipeline = [ { "$vectorSearch": { "index": self.vector_search_index, "path": "embedding", - "queryVector": query_embedding_np.tolist(), + "queryVector": query_embedding, "numCandidates": 100, "limit": top_k, # "filter": filters,