Skip to content

Commit

Permalink
docs: Mongo atlas (#534)
Browse files Browse the repository at this point in the history
* Improve docs

* Pylint, update tests

* Apply suggestions from code review

Co-authored-by: Tobias Wochinger <[email protected]>

* apply suggestions from review

* make embedding_retrieval an internal method

---------

Co-authored-by: Stefano Fiorucci <[email protected]>
Co-authored-by: Tobias Wochinger <[email protected]>
  • Loading branch information
3 people authored Mar 6, 2024
1 parent cb28fee commit 5321d3a
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,28 @@ class MongoDBAtlasEmbeddingRetriever:
"""
Retrieves documents from the MongoDBAtlasDocumentStore by embedding similarity.
Needs to be connected to the MongoDBAtlasDocumentStore.
The similarity depends on the vector_search_index used in the MongoDBAtlasDocumentStore and the metric chosen
during the creation of the index (i.e. cosine, dot product, or euclidean). See MongoDBAtlasDocumentStore for more
information.
Usage example:
```python
import numpy as np
from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore
from haystack_integrations.components.retrievers.mongodb_atlas import MongoDBAtlasEmbeddingRetriever
store = MongoDBAtlasDocumentStore(database_name="haystack_integration_test",
collection_name="test_embeddings_collection",
vector_search_index="cosine_index")
retriever = MongoDBAtlasEmbeddingRetriever(document_store=store)
results = retriever.run(query_embedding=np.random.random(768).tolist())
print(results["documents"])
```
The example above retrieves the 10 most similar documents to a random query embedding from the
MongoDBAtlasDocumentStore. Note that the dimension of the query_embedding must match the dimension of the embeddings
stored in the MongoDBAtlasDocumentStore.
"""

def __init__(
Expand All @@ -29,6 +50,8 @@ def __init__(
:param document_store: An instance of MongoDBAtlasDocumentStore.
:param filters: Filters applied to the retrieved Documents.
:param top_k: Maximum number of Documents to return.
:raises ValueError: If `document_store` is not an instance of `MongoDBAtlasDocumentStore`.
"""
if not isinstance(document_store, MongoDBAtlasDocumentStore):
msg = "document_store must be an instance of MongoDBAtlasDocumentStore"
Expand All @@ -40,7 +63,10 @@ def __init__(

def to_dict(self) -> Dict[str, Any]:
"""
Serializes this component into a dictionary.
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
self,
Expand All @@ -52,10 +78,12 @@ def to_dict(self) -> Dict[str, Any]:
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasEmbeddingRetriever":
"""
Deserializes a dictionary created with `MongoDBAtlasEmbeddingRetriever.to_dict()` into a
`MongoDBAtlasEmbeddingRetriever` instance.
Deserializes the component from a dictionary.
:param data: the dictionary returned by `MongoDBAtlasEmbeddingRetriever.to_dict()`
:param data:
Dictionary to deserialize from.
:returns:
Deserialized component.
"""
data["init_parameters"]["document_store"] = MongoDBAtlasDocumentStore.from_dict(
data["init_parameters"]["document_store"]
Expand All @@ -70,18 +98,19 @@ def run(
top_k: Optional[int] = None,
) -> Dict[str, List[Document]]:
"""
Retrieve documents from the MongoDBAtlasDocumentStore, based on their embeddings.
Retrieve documents from the MongoDBAtlasDocumentStore by their similarity to the provided query embedding.
:param query_embedding: Embedding of the query.
:param filters: Filters applied to the retrieved Documents. Overrides the value specified at initialization.
:param top_k: Maximum number of Documents to return. Overrides the value specified at initialization.
:return: List of Documents similar to `query_embedding`.
:returns: A dictionary with the following keys:
- `documents`: List of Documents most similar to the given `query_embedding`
"""
filters = filters or self.filters
top_k = top_k or self.top_k

docs = self.document_store.embedding_retrieval(
query_embedding_np=query_embedding,
docs = self.document_store._embedding_retrieval(
query_embedding=query_embedding,
filters=filters,
top_k=top_k,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,39 @@


class MongoDBAtlasDocumentStore:
"""
MongoDBAtlasDocumentStore is a DocumentStore implementation that uses [MongoDB Atlas](https://www.mongodb.com/atlas/database),
a service that is easy to deploy, operate, and scale.
To connect to MongoDB Atlas, you need to provide a connection string in the format:
"mongodb+srv://{mongo_atlas_username}:{mongo_atlas_password}@{mongo_atlas_host}/?{mongo_atlas_params_string}".
This connection string can be obtained on the MongoDB Atlas Dashboard by clicking on the `CONNECT` button, selecting
Python as the driver, and copying the connection string. The connection string can be provided as an environment
variable `MONGO_CONNECTION_STRING` or directly as a parameter to the `MongoDBAtlasDocumentStore` constructor.
After providing the connection string, you'll need to specify the `database_name` and `collection_name` to use.
Most likely you'll create these via the MongoDB Atlas web UI, but you can also create them via the MongoDB
Python driver. Creating databases and collections is beyond the scope of MongoDBAtlasDocumentStore. The primary
purpose of this document store is to read and write documents to an existing collection.
The last parameter users need to provide is `vector_search_index`, the index used for vector search operations. This index
can support a chosen metric (i.e. cosine, dot product, or euclidean) and can be created in the Atlas web UI.
For more details on MongoDB Atlas, see the official
MongoDB Atlas [documentation](https://www.mongodb.com/docs/atlas/getting-started/).
Usage example:
```python
from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore
store = MongoDBAtlasDocumentStore(database_name="your_existing_db",
collection_name="your_existing_collection",
vector_search_index="your_existing_index")
print(store.count_documents())
```
"""

def __init__(
self,
*,
Expand All @@ -30,8 +63,6 @@ def __init__(
"""
Creates a new MongoDBAtlasDocumentStore instance.
This Document Store uses MongoDB Atlas as a backend (https://www.mongodb.com/docs/atlas/getting-started/).
:param mongo_connection_string: MongoDB Atlas connection string in the format:
"mongodb+srv://{mongo_atlas_username}:{mongo_atlas_password}@{mongo_atlas_host}/?{mongo_atlas_params_string}".
This can be obtained on the MongoDB Atlas Dashboard by clicking on the `CONNECT` button.
Expand All @@ -41,7 +72,10 @@ def __init__(
this collection needs to have a vector search index set up on the `embedding` field.
:param vector_search_index: The name of the vector search index to use for vector search operations.
Create a vector search index in the Atlas web UI and pass its name in the init params of MongoDBAtlasDocumentStore. \
See https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#std-label-avs-create-index
For more details refer to MongoDB
Atlas [documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#std-label-avs-create-index)
:raises ValueError: If the collection name contains invalid characters.
"""
if collection_name and not bool(re.match(r"^[a-zA-Z0-9\-_]+$", collection_name)):
msg = f'Invalid collection name: "{collection_name}". It can only contain letters, numbers, -, or _.'
Expand All @@ -66,7 +100,10 @@ def __init__(

def to_dict(self) -> Dict[str, Any]:
"""
Utility function that serializes this Document Store's configuration into a dictionary.
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
self,
Expand All @@ -79,14 +116,21 @@ def to_dict(self) -> Dict[str, Any]:
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasDocumentStore":
"""
Utility function that deserializes this Document Store's configuration from a dictionary.
Deserializes the component from a dictionary.
:param data:
Dictionary to deserialize from.
:returns:
Deserialized component.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["mongo_connection_string"])
return default_from_dict(cls, data)

def count_documents(self) -> int:
"""
Returns how many documents are present in the document store.
:returns: The number of documents in the document store.
"""
return self.collection.count_documents({})

Expand All @@ -95,10 +139,10 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
Returns the documents that match the filters provided.
For a detailed specification of the filters,
refer to the [documentation](https://docs.haystack.deepset.ai/v2.0/docs/metadata-filtering).
refer to the Haystack [documentation](https://docs.haystack.deepset.ai/v2.0/docs/metadata-filtering).
:param filters: The filters to apply. It returns only the documents that match the filters.
:return: A list of Documents that match the given filters.
:returns: A list of Documents that match the given filters.
"""
mongo_filters = haystack_filters_to_mongo(filters)
documents = list(self.collection.find(mongo_filters))
Expand All @@ -108,13 +152,14 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc

def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int:
"""
Writes documents into to PgvectorDocumentStore.
Writes documents into the MongoDB Atlas collection.
:param documents: A list of Documents to write to the document store.
:param policy: The duplicate policy to use when writing documents.
:raises DuplicateDocumentError: If a document with the same id already exists in the document store
:raises DuplicateDocumentError: If a document with the same ID already exists in the document store
and the policy is set to DuplicatePolicy.FAIL (or not specified).
:return: The number of documents written to the document store.
:raises ValueError: If the documents are not of type Document.
:returns: The number of documents written to the document store.
"""

if len(documents) > 0:
Expand Down Expand Up @@ -156,7 +201,7 @@ def delete_documents(self, document_ids: List[str]) -> None:
return
self.collection.delete_many(filter={"id": {"$in": document_ids}})

def embedding_retrieval(
def _embedding_retrieval(
self,
query_embedding: List[float],
filters: Optional[Dict[str, Any]] = None,
Expand All @@ -168,6 +213,9 @@ def embedding_retrieval(
:param query_embedding: Embedding of the query
:param filters: Optional filters.
:param top_k: How many documents to return.
:returns: A list of Documents that are most similar to the given `query_embedding`
:raises ValueError: If `query_embedding` is empty.
:raises DocumentStoreError: If the retrieval of documents from MongoDB Atlas fails.
"""
if not query_embedding:
msg = "Query embedding must not be empty"
Expand Down Expand Up @@ -203,15 +251,15 @@ def embedding_retrieval(
msg = f"Retrieval of documents from MongoDB Atlas failed: {e}"
raise DocumentStoreError(msg) from e

documents = [self.mongo_doc_to_haystack_doc(doc) for doc in documents]
documents = [self._mongo_doc_to_haystack_doc(doc) for doc in documents]
return documents

def mongo_doc_to_haystack_doc(self, mongo_doc: Dict[str, Any]) -> Document:
def _mongo_doc_to_haystack_doc(self, mongo_doc: Dict[str, Any]) -> Document:
"""
Converts the dictionary coming out of MongoDB into a Haystack document
:param mongo_doc: A dictionary representing a document as stored in MongoDB
:return: A Haystack Document object
:returns: A Haystack Document object
"""
mongo_doc.pop("_id", None)
return Document.from_dict(mongo_doc)
10 changes: 5 additions & 5 deletions integrations/mongodb_atlas/tests/test_embedding_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_embedding_retrieval_cosine_similarity(self):
vector_search_index="cosine_index",
)
query_embedding = [0.1] * 768
results = document_store.embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={})
results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={})
assert len(results) == 2
assert results[0].content == "Document A"
assert results[1].content == "Document B"
Expand All @@ -34,7 +34,7 @@ def test_embedding_retrieval_dot_product(self):
vector_search_index="dotProduct_index",
)
query_embedding = [0.1] * 768
results = document_store.embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={})
results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={})
assert len(results) == 2
assert results[0].content == "Document A"
assert results[1].content == "Document B"
Expand All @@ -47,7 +47,7 @@ def test_embedding_retrieval_euclidean(self):
vector_search_index="euclidean_index",
)
query_embedding = [0.1] * 768
results = document_store.embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={})
results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={})
assert len(results) == 2
assert results[0].content == "Document C"
assert results[1].content == "Document B"
Expand All @@ -61,7 +61,7 @@ def test_empty_query_embedding(self):
)
query_embedding: List[float] = []
with pytest.raises(ValueError):
document_store.embedding_retrieval(query_embedding=query_embedding)
document_store._embedding_retrieval(query_embedding=query_embedding)

def test_query_embedding_wrong_dimension(self):
document_store = MongoDBAtlasDocumentStore(
Expand All @@ -71,4 +71,4 @@ def test_query_embedding_wrong_dimension(self):
)
query_embedding = [0.1] * 4
with pytest.raises(DocumentStoreError):
document_store.embedding_retrieval(query_embedding=query_embedding)
document_store._embedding_retrieval(query_embedding=query_embedding)
4 changes: 2 additions & 2 deletions integrations/mongodb_atlas/tests/test_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,11 @@ def test_from_dict(self):
def test_run(self):
mock_store = Mock(spec=MongoDBAtlasDocumentStore)
doc = Document(content="Test doc", embedding=[0.1, 0.2])
mock_store.embedding_retrieval.return_value = [doc]
mock_store._embedding_retrieval.return_value = [doc]

retriever = MongoDBAtlasEmbeddingRetriever(document_store=mock_store)
res = retriever.run(query_embedding=[0.3, 0.5])

mock_store.embedding_retrieval.assert_called_once_with(query_embedding_np=[0.3, 0.5], filters={}, top_k=10)
mock_store._embedding_retrieval.assert_called_once_with(query_embedding=[0.3, 0.5], filters={}, top_k=10)

assert res == {"documents": [doc]}

0 comments on commit 5321d3a

Please sign in to comment.