Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: MongoDBAtlasEmbeddingRetriever #427

Merged
merged 34 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
7c80cbc
initial implementation
ZanSara Feb 15, 2024
bf4c119
Merge branch 'main' into mongodbatlas-emb-retriever
ZanSara Feb 16, 2024
13c15a4
vector index seems non functional
ZanSara Feb 16, 2024
cf93be1
tests are green for docstore
ZanSara Feb 21, 2024
f0fe7d4
tests green
ZanSara Feb 21, 2024
badb73e
no parallel tests
ZanSara Feb 21, 2024
ced84f0
lint
ZanSara Feb 21, 2024
1e7893f
use different collections for write tests
ZanSara Feb 21, 2024
420b515
fix tests
ZanSara Feb 21, 2024
bee21ca
black
ZanSara Feb 21, 2024
8c1a004
docstring
ZanSara Feb 21, 2024
8869ec7
add doc fields
ZanSara Feb 21, 2024
9b4b728
black
ZanSara Feb 21, 2024
ca8631c
Update integrations/mongodb_atlas/src/haystack_integrations/component…
ZanSara Feb 21, 2024
670bd09
Update integrations/mongodb_atlas/src/haystack_integrations/component…
ZanSara Feb 21, 2024
a6b6580
Update integrations/mongodb_atlas/src/haystack_integrations/component…
ZanSara Feb 21, 2024
b433b76
Update integrations/mongodb_atlas/src/haystack_integrations/component…
ZanSara Feb 21, 2024
20865f9
Update integrations/mongodb_atlas/src/haystack_integrations/component…
ZanSara Feb 21, 2024
517586b
Update integrations/mongodb_atlas/src/haystack_integrations/document_…
ZanSara Feb 21, 2024
47a7914
Update integrations/mongodb_atlas/src/haystack_integrations/document_…
ZanSara Feb 21, 2024
b8011c7
Update integrations/mongodb_atlas/src/haystack_integrations/document_…
ZanSara Feb 21, 2024
f6eea73
Update integrations/mongodb_atlas/src/haystack_integrations/document_…
ZanSara Feb 21, 2024
f154ed7
black
ZanSara Feb 21, 2024
fa16e76
docstring
ZanSara Feb 21, 2024
b80c4f8
black
ZanSara Feb 21, 2024
ba06c2a
mypy
ZanSara Feb 21, 2024
ac425a3
rename
ZanSara Feb 22, 2024
5b1c574
deserialization
ZanSara Feb 22, 2024
407bcf4
unused import
ZanSara Feb 22, 2024
5c11b3f
Update integrations/mongodb_atlas/src/haystack_integrations/component…
ZanSara Feb 23, 2024
aa60477
change import
ZanSara Feb 23, 2024
2126c6e
Merge branch 'mongodbatlas-emb-retriever' of github.com:deepset-ai/ha…
ZanSara Feb 23, 2024
7896536
ruff
ZanSara Feb 23, 2024
7d03dff
remove numpy conversion
ZanSara Feb 23, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from haystack_integrations.components.retrievers.mongodb_atlas.embedding_retriever import MongoDBAtlasEmbeddingRetriever

__all__ = ["MongoDBAtlasEmbeddingRetriever"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional

from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import Document
from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore


@component
class MongoDBAtlasEmbeddingRetriever:
"""
Retrieves documents from the MongoDBAtlasDocumentStore by embedding similarity.

Needs to be connected to the MongoDBAtlasDocumentStore.
"""

def __init__(
self,
*,
document_store: MongoDBAtlasDocumentStore,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
):
"""
Create the MongoDBAtlasDocumentStore component.

:param document_store: An instance of MongoDBAtlasDocumentStore.
:param filters: Filters applied to the retrieved Documents.
:param top_k: Maximum number of Documents to return.
"""
if not isinstance(document_store, MongoDBAtlasDocumentStore):
msg = "document_store must be an instance of MongoDBAtlasDocumentStore"
raise ValueError(msg)

self.document_store = document_store
self.filters = filters or {}
self.top_k = top_k

def to_dict(self) -> Dict[str, Any]:
"""
Serializes this component into a dictionary.
"""
return default_to_dict(
ZanSara marked this conversation as resolved.
Show resolved Hide resolved
self,
filters=self.filters,
top_k=self.top_k,
document_store=self.document_store.to_dict(),
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasEmbeddingRetriever":
ZanSara marked this conversation as resolved.
Show resolved Hide resolved
"""
Deserializes a dictionary created with `MongoDBAtlasEmbeddingRetriever.to_dict()` into a
`MongoDBAtlasEmbeddingRetriever` instance.

:param data: the dictionary returned by `MongoDBAtlasEmbeddingRetriever.to_dict()`
"""
data["init_parameters"]["document_store"] = MongoDBAtlasDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(
self,
query_embedding: List[float],
filters: Optional[Dict[str, Any]] = None,
ZanSara marked this conversation as resolved.
Show resolved Hide resolved
top_k: Optional[int] = None,
) -> Dict[str, List[Document]]:
"""
Retrieve documents from the MongoDBAtlasDocumentStore, based on their embeddings.

:param query_embedding: Embedding of the query.
:param filters: Filters applied to the retrieved Documents. Overrides the value specified at initialization.
:param top_k: Maximum number of Documents to return. Overrides the value specified at initialization.
:return: List of Documents similar to `query_embedding`.
"""
filters = filters or self.filters
top_k = top_k or self.top_k

docs = self.document_store.embedding_retrieval(
query_embedding_np=query_embedding,
filters=filters,
top_k=top_k,
)
return {"documents": docs}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from haystack import default_from_dict, default_to_dict
from haystack.dataclasses.document import Document
from haystack.document_stores.errors import DuplicateDocumentError
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
from haystack.document_stores.types import DuplicatePolicy
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack_integrations.document_stores.mongodb_atlas.filters import haystack_filters_to_mongo
Expand All @@ -25,7 +25,7 @@ def __init__(
mongo_connection_string: Secret = Secret.from_env_var("MONGO_CONNECTION_STRING"), # noqa: B008
database_name: str,
collection_name: str,
recreate_collection: bool = False,
silvanocerza marked this conversation as resolved.
Show resolved Hide resolved
vector_search_index: str,
):
"""
Creates a new MongoDBAtlasDocumentStore instance.
Expand All @@ -37,8 +37,11 @@ def __init__(
This can be obtained on the MongoDB Atlas Dashboard by clicking on the `CONNECT` button.
This value will be read automatically from the env var "MONGO_CONNECTION_STRING".
:param database_name: Name of the database to use.
:param collection_name: Name of the collection to use.
:param recreate_collection: Whether to recreate the collection when initializing the document store.
:param collection_name: Name of the collection to use. To use this document store for embedding retrieval,
this collection needs to have a vector search index set up on the `embedding` field.
:param vector_search_index: The name of the vector search index to use for vector search operations.
Create a vector_search_index in the Atlas web UI and specify the init params of MongoDBAtlasDocumentStore. \
See https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#std-label-avs-create-index
"""
if collection_name and not bool(re.match(r"^[a-zA-Z0-9\-_]+$", collection_name)):
msg = f'Invalid collection name: "{collection_name}". It can only contain letters, numbers, -, or _.'
Expand All @@ -49,21 +52,16 @@ def __init__(

self.database_name = database_name
self.collection_name = collection_name
self.recreate_collection = recreate_collection
self.vector_search_index = vector_search_index

self.connection: MongoClient = MongoClient(
resolved_connection_string, driver=DriverInfo(name="MongoDBAtlasHaystackIntegration")
)
database = self.connection[self.database_name]

if self.recreate_collection and self.collection_name in database.list_collection_names():
database[self.collection_name].drop()

# Implicitly create the collection if it doesn't exist
if collection_name not in database.list_collection_names():
database.create_collection(self.collection_name)
database[self.collection_name].create_index("id", unique=True)

msg = f"Collection '{collection_name}' does not exist in database '{database_name}'."
raise ValueError(msg)
self.collection = database[self.collection_name]

def to_dict(self) -> Dict[str, Any]:
Expand All @@ -75,7 +73,7 @@ def to_dict(self) -> Dict[str, Any]:
mongo_connection_string=self.mongo_connection_string.to_dict(),
database_name=self.database_name,
collection_name=self.collection_name,
recreate_collection=self.recreate_collection,
vector_search_index=self.vector_search_index,
)

@classmethod
Expand Down Expand Up @@ -157,3 +155,63 @@ def delete_documents(self, document_ids: List[str]) -> None:
if not document_ids:
return
self.collection.delete_many(filter={"id": {"$in": document_ids}})

def embedding_retrieval(
self,
query_embedding: List[float],
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
) -> List[Document]:
"""
Find the documents that are most similar to the provided `query_embedding` by using a vector similarity metric.

:param query_embedding: Embedding of the query
:param filters: Optional filters.
:param top_k: How many documents to return.
"""
if not query_embedding:
msg = "Query embedding must not be empty"
raise ValueError(msg)

filters = haystack_filters_to_mongo(filters)
pipeline = [
{
"$vectorSearch": {
"index": self.vector_search_index,
"path": "embedding",
"queryVector": query_embedding,
"numCandidates": 100,
"limit": top_k,
# "filter": filters,
}
},
{
"$project": {
"_id": 0,
"content": 1,
"dataframe": 1,
"blob": 1,
"meta": 1,
"embedding": 1,
"score": {"$meta": "vectorSearchScore"},
}
},
]
try:
documents = list(self.collection.aggregate(pipeline))
except Exception as e:
msg = f"Retrieval of documents from MongoDB Atlas failed: {e}"
raise DocumentStoreError(msg) from e

documents = [self.mongo_doc_to_haystack_doc(doc) for doc in documents]
return documents

def mongo_doc_to_haystack_doc(self, mongo_doc: Dict[str, Any]) -> Document:
"""
Converts the dictionary coming out of MongoDB into a Haystack document

:param mongo_doc: A dictionary representing a document as stored in MongoDB
:return: A Haystack Document object
"""
mongo_doc.pop("_id", None)
return Document.from_dict(mongo_doc)
56 changes: 32 additions & 24 deletions integrations/mongodb_atlas/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#
# SPDX-License-Identifier: Apache-2.0
import os
from unittest.mock import patch
from uuid import uuid4

import pytest
Expand All @@ -13,24 +12,38 @@
from haystack.utils import Secret
from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore
from pandas import DataFrame
from pymongo import MongoClient # type: ignore
from pymongo.driver_info import DriverInfo # type: ignore


@pytest.fixture
def document_store(request):
def document_store():
database_name = "haystack_integration_test"
collection_name = "test_collection_" + str(uuid4())

connection: MongoClient = MongoClient(
os.environ["MONGO_CONNECTION_STRING"], driver=DriverInfo(name="MongoDBAtlasHaystackIntegration")
)
database = connection[database_name]
if collection_name in database.list_collection_names():
database[collection_name].drop()
database.create_collection(collection_name)
database[collection_name].create_index("id", unique=True)

store = MongoDBAtlasDocumentStore(
database_name="haystack_integration_test",
collection_name=request.node.name + str(uuid4()),
database_name=database_name,
collection_name=collection_name,
vector_search_index="cosine_index",
)
yield store
store.collection.drop()
database[collection_name].drop()


@pytest.mark.skipif(
"MONGO_CONNECTION_STRING" not in os.environ,
reason="No MongoDB Atlas connection string provided",
)
class TestDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest):

def test_write_documents(self, document_store: MongoDBAtlasDocumentStore):
docs = [Document(content="some text")]
assert document_store.write_documents(docs) == 1
Expand All @@ -51,13 +64,10 @@ def test_write_dataframe(self, document_store: MongoDBAtlasDocumentStore):
retrieved_docs = document_store.filter_documents()
assert retrieved_docs == docs

@patch("haystack_integrations.document_stores.mongodb_atlas.document_store.MongoClient")
def test_to_dict(self, _):
document_store = MongoDBAtlasDocumentStore(
database_name="database_name",
collection_name="collection_name",
)
assert document_store.to_dict() == {
def test_to_dict(self, document_store):
serialized_store = document_store.to_dict()
assert serialized_store["init_parameters"].pop("collection_name").startswith("test_collection_")
assert serialized_store == {
"type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore",
"init_parameters": {
"mongo_connection_string": {
Expand All @@ -67,14 +77,12 @@ def test_to_dict(self, _):
"strict": True,
"type": "env_var",
},
"database_name": "database_name",
"collection_name": "collection_name",
"recreate_collection": False,
"database_name": "haystack_integration_test",
"vector_search_index": "cosine_index",
},
}

@patch("haystack_integrations.document_stores.mongodb_atlas.document_store.MongoClient")
def test_from_dict(self, _):
def test_from_dict(self):
docstore = MongoDBAtlasDocumentStore.from_dict(
{
"type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore",
Expand All @@ -86,13 +94,13 @@ def test_from_dict(self, _):
"strict": True,
"type": "env_var",
},
"database_name": "database_name",
"collection_name": "collection_name",
"recreate_collection": True,
"database_name": "haystack_integration_test",
"collection_name": "test_embeddings_collection",
"vector_search_index": "cosine_index",
},
}
)
assert docstore.mongo_connection_string == Secret.from_env_var("MONGO_CONNECTION_STRING")
assert docstore.database_name == "database_name"
assert docstore.collection_name == "collection_name"
assert docstore.recreate_collection
assert docstore.database_name == "haystack_integration_test"
assert docstore.collection_name == "test_embeddings_collection"
assert docstore.vector_search_index == "cosine_index"
Loading