Skip to content

Commit

Permalink
initial implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
ZanSara committed Feb 15, 2024
1 parent 14c3de8 commit 7c80cbc
Show file tree
Hide file tree
Showing 4 changed files with 300 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Literal, Optional

from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import Document
from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore
from haystack_integrations.document_stores.mongodb_atlas.document_store import METRIC_TYPES


@component
class MongoDBAtlasEmbeddingRetriever:
    """
    Retrieves documents from the MongoDBAtlasDocumentStore by embedding similarity.

    Needs to be connected to the MongoDBAtlasDocumentStore.
    """

    def __init__(
        self,
        *,
        document_store: MongoDBAtlasDocumentStore,
        filters: Optional[Dict[str, Any]] = None,
        top_k: int = 10,
        similarity: Literal["euclidean", "cosine", "dotProduct"] = "cosine",
    ):
        """
        Create the MongoDBAtlasEmbeddingRetriever component.

        :param document_store: An instance of MongoDBAtlasDocumentStore.
        :param filters: Filters applied to the retrieved Documents. Defaults to None.
        :param top_k: Maximum number of Documents to return, defaults to 10.
        :param similarity: The similarity function to use when searching for similar embeddings.
            Defaults to "cosine". Valid values are "cosine", "euclidean", "dotProduct".
        :raises ValueError: If `document_store` is not a MongoDBAtlasDocumentStore or `similarity`
            is not one of METRIC_TYPES.
        """
        if not isinstance(document_store, MongoDBAtlasDocumentStore):
            msg = "document_store must be an instance of MongoDBAtlasDocumentStore"
            raise ValueError(msg)

        if similarity not in METRIC_TYPES:
            # Name the parameter the caller actually passed.
            msg = f"similarity must be one of {METRIC_TYPES}"
            raise ValueError(msg)

        self.document_store = document_store
        self.filters = filters or {}
        self.top_k = top_k
        self.similarity = similarity

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component, including its document store, to a dictionary.

        :return: Dictionary with the init parameters; `document_store` holds the store's own `to_dict()` output.
        """
        return default_to_dict(
            self,
            filters=self.filters,
            top_k=self.top_k,
            similarity=self.similarity,
            document_store=self.document_store.to_dict(),
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasEmbeddingRetriever":
        """
        Deserialize a component from a dictionary produced by `to_dict`.

        The serialized document store inside `data["init_parameters"]` is replaced in place
        by the reconstructed store instance.

        :param data: Dictionary to deserialize from.
        :return: The deserialized retriever.
        """
        # Delegate to the store's own `from_dict` so that any custom deserialization it performs
        # (its `to_dict` serializes the connection string separately) is not bypassed.
        data["init_parameters"]["document_store"] = MongoDBAtlasDocumentStore.from_dict(
            data["init_parameters"]["document_store"]
        )
        return default_from_dict(cls, data)

    @component.output_types(documents=List[Document])
    def run(
        self,
        query_embedding: List[float],
        filters: Optional[Dict[str, Any]] = None,
        top_k: Optional[int] = None,
        similarity: Optional[Literal["euclidean", "cosine", "dotProduct"]] = None,
    ):
        """
        Retrieve documents from the MongoDBAtlasDocumentStore, based on their embeddings.

        Any of `filters`, `top_k` and `similarity` that is falsy (None, an empty dict, 0)
        falls back to the value provided in the constructor.

        :param query_embedding: Embedding of the query.
        :param filters: Filters applied to the retrieved Documents.
        :param top_k: Maximum number of Documents to return.
        :param similarity: The similarity function to use when searching for similar embeddings.
            Defaults to the value provided in the constructor. Valid values are "cosine", "euclidean", "dotProduct".
        :return: Dictionary whose "documents" key holds the Documents similar to `query_embedding`.
        """
        filters = filters or self.filters
        top_k = top_k or self.top_k
        similarity = similarity or self.similarity

        docs = self.document_store.embedding_retrieval(
            query_embedding=query_embedding,
            filters=filters,
            top_k=top_k,
            similarity=similarity,
        )
        return {"documents": docs}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
from typing import Any, Dict, List, Optional, Union

import numpy as np
from haystack import default_from_dict, default_to_dict
from haystack.dataclasses.document import Document
from haystack.document_stores.errors import DuplicateDocumentError
Expand All @@ -18,13 +19,17 @@
logger = logging.getLogger(__name__)


# Names of the similarity metrics accepted by the `similarity` argument of `embedding_retrieval`.
METRIC_TYPES = ["euclidean", "cosine", "dotProduct"]


class MongoDBAtlasDocumentStore:
def __init__(
self,
*,
mongo_connection_string: Secret = Secret.from_env_var("MONGO_CONNECTION_STRING"), # noqa: B008
database_name: str,
collection_name: str,
vector_search_index: str,
recreate_collection: bool = False,
):
"""
Expand All @@ -38,7 +43,11 @@ def __init__(
This value will be read automatically from the env var "MONGO_CONNECTION_STRING".
:param database_name: Name of the database to use.
:param collection_name: Name of the collection to use.
:param recreate_collection: Whether to recreate the collection when initializing the document store.
:param vector_search_index: The name of the vector search index to use for vector search operations.
Create a vector_search_index in the Atlas web UI and specify the init params of MongoDBAtlasDocumentStore. \
See https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#std-label-avs-create-index
:param recreate_collection: Whether to recreate the collection when initializing the document store. Defaults
to False.
"""
if collection_name and not bool(re.match(r"^[a-zA-Z0-9\-_]+$", collection_name)):
msg = f'Invalid collection name: "{collection_name}". It can only contain letters, numbers, -, or _.'
Expand All @@ -49,6 +58,7 @@ def __init__(

self.database_name = database_name
self.collection_name = collection_name
self.vector_search_index = vector_search_index
self.recreate_collection = recreate_collection

self.connection: MongoClient = MongoClient(
Expand All @@ -75,9 +85,10 @@ def to_dict(self) -> Dict[str, Any]:
mongo_connection_string=self.mongo_connection_string.to_dict(),
database_name=self.database_name,
collection_name=self.collection_name,
vector_search_index=self.vector_search_index,
recreate_collection=self.recreate_collection,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasDocumentStore":
"""
Expand Down Expand Up @@ -159,3 +170,83 @@ def delete_documents(self, document_ids: List[str]) -> None:
if not document_ids:
return
self.collection.delete_many(filter={"id": {"$in": document_ids}})

def embedding_retrieval(
    self,
    query_embedding: np.ndarray,
    filters: Optional[Dict[str, Any]] = None,
    top_k: int = 10,
    similarity: str = "cosine",
    scale_score: bool = True,
) -> List[Document]:
    """
    Find the documents that are most similar to the provided `query_embedding` by using a vector similarity metric.

    The search runs as a `$vectorSearch` aggregation stage against `self.vector_search_index`.
    `similarity` is not sent to MongoDB: it is only used here to validate the argument, to
    L2-normalize the query for "cosine", and to pick the score scaling.
    Filters are applied as a `$match` stage after the vector search, so fewer than `top_k`
    documents may be returned when filters are given.

    :param query_embedding: Embedding of the query; a non-empty 1D sequence or array of numbers.
    :param filters: optional filters (see get_all_documents for description).
    :param top_k: How many documents to return.
    :param similarity: The similarity function to use. Currently supported: `dotProduct`, `cosine` and `euclidean`.
    :param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]).
        If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a
        different value range will be scaled to a range of [0,1], where 1 means extremely relevant.
        Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
    :raises ValueError: If `similarity` is not supported, or `query_embedding` is empty or not 1D.
    :return: The matching documents, each carrying its `score`.
    """
    if similarity not in METRIC_TYPES:
        raise ValueError(
            "MongoDB Atlas currently supports 'dotProduct', 'cosine' and 'euclidean' similarity metrics. "
            "Please set 'similarity' to one of the above."
        )

    query_embedding = np.array(query_embedding).astype(np.float32)

    # Reject unusable queries locally instead of sending them to the database.
    if query_embedding.ndim != 1 or query_embedding.size == 0:
        msg = "query_embedding must be a non-empty one-dimensional vector"
        raise ValueError(msg)

    if similarity == "cosine":
        # In-place L2 normalization of the query vector.
        self.normalize_embedding(query_embedding)

    pipeline: List[Dict[str, Any]] = [
        {
            "$vectorSearch": {
                "index": self.vector_search_index,
                "queryVector": query_embedding.tolist(),
                "path": "embedding",
                # Atlas requires limit <= numCandidates, so widen the candidate pool
                # when more than the default 100 results are requested.
                "numCandidates": max(100, top_k),
                "limit": top_k,
            }
        }
    ]

    filters = haystack_filters_to_mongo(filters)
    if filters is not None:
        pipeline.append({"$match": filters})

    # Expose the search score as a regular field of each returned document.
    pipeline.append({"$set": {"score": {"$meta": "vectorSearchScore"}}})
    documents = list(self.collection.aggregate(pipeline))

    if scale_score:
        # NOTE(review): Atlas documents vectorSearchScore as already normalized to [0, 1];
        # confirm that this additional scaling is intended.
        for doc in documents:
            doc["score"] = self.scale_to_unit_interval(doc["score"], similarity)

    return [self.mongo_doc_to_haystack_doc(doc) for doc in documents]

@staticmethod
def mongo_doc_to_haystack_doc(mongo_doc: Dict[str, Any]) -> Document:
    """
    Converts the dictionary coming out of MongoDB into a Haystack document

    Declared as a static method because it takes no `self`, yet is invoked as
    `self.mongo_doc_to_haystack_doc(doc)`; without the decorator that call would pass
    the instance as `mongo_doc` and fail with a TypeError.

    :param mongo_doc: A dictionary representing a document as stored in MongoDB
    :return: A Haystack Document object
    """
    # Drop MongoDB's internal "_id" key on a copy so the caller's dictionary is left untouched.
    fields = {key: value for key, value in mongo_doc.items() if key != "_id"}
    return Document.from_dict(fields)

def normalize_embedding(self, emb: np.ndarray) -> None:
    """
    Scale a 1D embedding vector to unit L2 length, modifying `emb` **in place**.

    A vector whose norm is zero is left unchanged. Nothing is returned.

    :param emb: The embedding vector to normalize.
    """
    # sqrt of the self dot product; cheaper than np.linalg.norm()
    magnitude = np.sqrt(np.dot(emb, emb))
    if magnitude == 0.0:
        return
    emb /= magnitude

def scale_to_unit_interval(self, score: float, similarity: Optional[str]) -> float:
    """
    Map a raw similarity score to a scaled value.

    For "cosine" the score is mapped linearly with `(score + 1) / 2`; for any other
    `similarity` value a logistic function of `score / 100` is applied.

    :param score: The raw score to scale.
    :param similarity: Name of the similarity metric the score was computed with.
    :return: The scaled score.
    """
    if similarity != "cosine":
        logistic = 1 / (1 + np.exp(-score / 100))
        return float(logistic)
    return (score + 1) / 2
9 changes: 7 additions & 2 deletions integrations/mongodb_atlas/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ def document_store(request):
store = MongoDBAtlasDocumentStore(
database_name="haystack_integration_test",
collection_name=request.node.name + str(uuid4()),
vector_search_index="vector_search_index",
recreate_collection=True,
)
yield store
store.collection.drop()
return store


@pytest.mark.skipif(
Expand Down Expand Up @@ -56,6 +57,7 @@ def test_to_dict(self, _):
document_store = MongoDBAtlasDocumentStore(
database_name="database_name",
collection_name="collection_name",
vector_search_index="vector_search_index",
)
assert document_store.to_dict() == {
"type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore",
Expand All @@ -70,6 +72,7 @@ def test_to_dict(self, _):
"database_name": "database_name",
"collection_name": "collection_name",
"recreate_collection": False,
"vector_search_index": "vector_search_index",
},
}

Expand All @@ -88,11 +91,13 @@ def test_from_dict(self, _):
},
"database_name": "database_name",
"collection_name": "collection_name",
"vector_search_index": "vector_search_index",
"recreate_collection": True,
},
}
)
assert docstore.mongo_connection_string == Secret.from_env_var("MONGO_CONNECTION_STRING")
assert docstore.database_name == "database_name"
assert docstore.collection_name == "collection_name"
assert docstore.vector_search_index == "vector_search_index"
assert docstore.recreate_collection
106 changes: 106 additions & 0 deletions integrations/mongodb_atlas/tests/test_embedding_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import List
from uuid import uuid4
import os

import pytest
from haystack.dataclasses.document import Document
from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore
from numpy.random import rand


@pytest.fixture
def document_store(request):
    """
    Provide a MongoDBAtlasDocumentStore backed by a uniquely named collection.

    The collection name combines the requesting test's name with a random UUID. The
    collection is dropped once the test finishes so test runs do not leave collections
    behind in the "haystack_integration_test" database.
    """
    store = MongoDBAtlasDocumentStore(
        database_name="haystack_integration_test",
        collection_name=request.node.name + str(uuid4()),
        vector_search_index="vector_search_index",
        recreate_collection=True,
    )
    yield store
    # Teardown: runs after the test that used the fixture has completed.
    store.collection.drop()


@pytest.mark.skipif(
    "MONGO_CONNECTION_STRING" not in os.environ,
    reason="No MongoDB Atlas connection string provided",
)
class TestEmbeddingRetrieval:
    """
    Integration tests for `MongoDBAtlasDocumentStore.embedding_retrieval`.

    The whole class is skipped when the MONGO_CONNECTION_STRING environment variable is not set.
    """

    def test_embedding_retrieval_cosine_similarity(self, document_store: MongoDBAtlasDocumentStore):
        """The two expected documents come back in order, best first, for similarity="cosine"."""
        query = [0.1] * 768
        closest = [0.8] * 768
        runner_up = [0.8] * 700 + [0.1] * 3 + [0.2] * 65
        unrelated = rand(768).tolist()

        document_store.write_documents(
            [
                Document(content="Most similar document (cosine sim)", embedding=closest),
                Document(content="2nd best document (cosine sim)", embedding=runner_up),
                Document(content="Not very similar document (cosine sim)", embedding=unrelated),
            ]
        )

        hits = document_store.embedding_retrieval(query_embedding=query, top_k=2, filters={}, similarity="cosine")
        assert len(hits) == 2
        best, second = hits
        assert best.content == "Most similar document (cosine sim)"
        assert second.content == "2nd best document (cosine sim)"
        assert best.score > second.score

    def test_embedding_retrieval_dot_product(self, document_store: MongoDBAtlasDocumentStore):
        """The two expected documents come back in order, best first, for similarity="dotProduct"."""
        query = [0.1] * 768
        closest = [0.8] * 768
        runner_up = [0.8] * 700 + [0.1] * 3 + [0.2] * 65
        unrelated = rand(768).tolist()

        document_store.write_documents(
            [
                Document(content="Most similar document (dot product)", embedding=closest),
                Document(content="2nd best document (dot product)", embedding=runner_up),
                Document(content="Not very similar document (dot product)", embedding=unrelated),
            ]
        )

        hits = document_store.embedding_retrieval(query_embedding=query, top_k=2, filters={}, similarity="dotProduct")
        assert len(hits) == 2
        best, second = hits
        assert best.content == "Most similar document (dot product)"
        assert second.content == "2nd best document (dot product)"
        assert best.score > second.score

    def test_embedding_retrieval_euclidean(self, document_store: MongoDBAtlasDocumentStore):
        """The two expected documents come back in order, best first, for similarity="euclidean"."""
        # NOTE(review): with these vectors the squared Euclidean distance to the query is
        # 376.32 for the "most similar" embedding and 343.65 for the "2nd best" one, so the
        # asserted order presumably depends on how the Atlas index is configured - confirm.
        query = [0.1] * 768
        closest = [0.8] * 768
        runner_up = [0.8] * 700 + [0.1] * 3 + [0.2] * 65
        unrelated = rand(768).tolist()

        document_store.write_documents(
            [
                Document(content="Most similar document (euclidean)", embedding=closest),
                Document(content="2nd best document (euclidean)", embedding=runner_up),
                Document(content="Not very similar document (euclidean)", embedding=unrelated),
            ]
        )

        hits = document_store.embedding_retrieval(query_embedding=query, top_k=2, filters={}, similarity="euclidean")
        assert len(hits) == 2
        best, second = hits
        assert best.content == "Most similar document (euclidean)"
        assert second.content == "2nd best document (euclidean)"
        assert best.score > second.score

    def test_empty_query_embedding(self, document_store: MongoDBAtlasDocumentStore):
        """An empty query embedding is expected to raise ValueError."""
        with pytest.raises(ValueError):
            document_store.embedding_retrieval(query_embedding=[])

    def test_query_embedding_wrong_dimension(self, document_store: MongoDBAtlasDocumentStore):
        """A 4-element query embedding is expected to raise ValueError."""
        too_short = [0.1] * 4
        with pytest.raises(ValueError):
            document_store.embedding_retrieval(query_embedding=too_short)

0 comments on commit 7c80cbc

Please sign in to comment.