-
Notifications
You must be signed in to change notification settings - Fork 126
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
300 additions
and
4 deletions.
There are no files selected for viewing
94 changes: 94 additions & 0 deletions
94
...tlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from typing import Any, Dict, List, Literal, Optional | ||
|
||
from haystack import component, default_from_dict, default_to_dict | ||
from haystack.dataclasses import Document | ||
from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore | ||
from haystack_integrations.document_stores.mongodb_atlas.document_store import METRIC_TYPES | ||
|
||
|
||
@component | ||
class MongoDBAtlasEmbeddingRetriever: | ||
""" | ||
Retrieves documents from the MongoDBAtlasDocumentStore by embedding similarity. | ||
Needs to be connected to the MongoDBAtlasDocumentStore. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
*, | ||
document_store: MongoDBAtlasDocumentStore, | ||
filters: Optional[Dict[str, Any]] = None, | ||
top_k: int = 10, | ||
similarity: Literal["euclidean", "cosine", "dotProduct"] = "cosine", | ||
): | ||
""" | ||
Create the MongoDBAtlasDocumentStore component. | ||
:param document_store: An instance of MongoDBAtlasDocumentStore. | ||
:param filters: Filters applied to the retrieved Documents. Defaults to None. | ||
:param top_k: Maximum number of Documents to return, defaults to 10. | ||
:param similarity: The similarity function to use when searching for similar embeddings. | ||
Defaults to "cosine". Valid values are "cosine", "euclidean", "dotProduct". | ||
""" | ||
if not isinstance(document_store, MongoDBAtlasDocumentStore): | ||
msg = "document_store must be an instance of MongoDBAtlasDocumentStore" | ||
raise ValueError(msg) | ||
|
||
if similarity not in METRIC_TYPES: | ||
msg = f"vector_function must be one of {METRIC_TYPES}" | ||
raise ValueError(msg) | ||
|
||
self.document_store = document_store | ||
self.filters = filters or {} | ||
self.top_k = top_k | ||
self.similarity = similarity | ||
|
||
def to_dict(self) -> Dict[str, Any]: | ||
return default_to_dict( | ||
self, | ||
filters=self.filters, | ||
top_k=self.top_k, | ||
similarity=self.similarity, | ||
document_store=self.document_store.to_dict(), | ||
) | ||
|
||
@classmethod | ||
def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasEmbeddingRetriever": | ||
data["init_parameters"]["document_store"] = default_from_dict( | ||
MongoDBAtlasDocumentStore, data["init_parameters"]["document_store"] | ||
) | ||
return default_from_dict(cls, data) | ||
|
||
@component.output_types(documents=List[Document]) | ||
def run( | ||
self, | ||
query_embedding: List[float], | ||
filters: Optional[Dict[str, Any]] = None, | ||
top_k: Optional[int] = None, | ||
similarity: Optional[Literal["euclidean", "cosine", "dotProduct"]] = None, | ||
): | ||
""" | ||
Retrieve documents from the MongoDBAtlasDocumentStore, based on their embeddings. | ||
:param query_embedding: Embedding of the query. | ||
:param filters: Filters applied to the retrieved Documents. | ||
:param top_k: Maximum number of Documents to return. | ||
:param similarity: The similarity function to use when searching for similar embeddings. | ||
Defaults to the value provided in the constructor. Valid values are "cosine", "euclidean", "dotProduct". | ||
:return: List of Documents similar to `query_embedding`. | ||
""" | ||
filters = filters or self.filters | ||
top_k = top_k or self.top_k | ||
similarity = similarity or self.similarity | ||
|
||
docs = self.document_store.embedding_retrieval( | ||
query_embedding=query_embedding, | ||
filters=filters, | ||
top_k=top_k, | ||
similarity=similarity, | ||
) | ||
return {"documents": docs} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
106 changes: 106 additions & 0 deletions
106
integrations/mongodb_atlas/tests/test_embedding_retrieval.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from typing import List | ||
from uuid import uuid4 | ||
import os | ||
|
||
import pytest | ||
from haystack.dataclasses.document import Document | ||
from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore | ||
from numpy.random import rand | ||
|
||
|
||
@pytest.fixture | ||
def document_store(request): | ||
store = MongoDBAtlasDocumentStore( | ||
database_name="haystack_integration_test", | ||
collection_name=request.node.name + str(uuid4()), | ||
vector_search_index="vector_search_index", | ||
recreate_collection=True, | ||
) | ||
return store | ||
|
||
|
||
@pytest.mark.skipif( | ||
"MONGO_CONNECTION_STRING" not in os.environ, | ||
reason="No MongoDB Atlas connection string provided", | ||
) | ||
class TestEmbeddingRetrieval: | ||
|
||
def test_embedding_retrieval_cosine_similarity(self, document_store: MongoDBAtlasDocumentStore): | ||
query_embedding = [0.1] * 768 | ||
most_similar_embedding = [0.8] * 768 | ||
second_best_embedding = [0.8] * 700 + [0.1] * 3 + [0.2] * 65 | ||
another_embedding = rand(768).tolist() | ||
|
||
docs = [ | ||
Document(content="Most similar document (cosine sim)", embedding=most_similar_embedding), | ||
Document(content="2nd best document (cosine sim)", embedding=second_best_embedding), | ||
Document(content="Not very similar document (cosine sim)", embedding=another_embedding), | ||
] | ||
|
||
document_store.write_documents(docs) | ||
|
||
results = document_store.embedding_retrieval( | ||
query_embedding=query_embedding, top_k=2, filters={}, similarity="cosine" | ||
) | ||
assert len(results) == 2 | ||
assert results[0].content == "Most similar document (cosine sim)" | ||
assert results[1].content == "2nd best document (cosine sim)" | ||
assert results[0].score > results[1].score | ||
|
||
def test_embedding_retrieval_dot_product(self, document_store: MongoDBAtlasDocumentStore): | ||
query_embedding = [0.1] * 768 | ||
most_similar_embedding = [0.8] * 768 | ||
second_best_embedding = [0.8] * 700 + [0.1] * 3 + [0.2] * 65 | ||
another_embedding = rand(768).tolist() | ||
|
||
docs = [ | ||
Document(content="Most similar document (dot product)", embedding=most_similar_embedding), | ||
Document(content="2nd best document (dot product)", embedding=second_best_embedding), | ||
Document(content="Not very similar document (dot product)", embedding=another_embedding), | ||
] | ||
|
||
document_store.write_documents(docs) | ||
|
||
results = document_store.embedding_retrieval( | ||
query_embedding=query_embedding, top_k=2, filters={}, similarity="dotProduct" | ||
) | ||
assert len(results) == 2 | ||
assert results[0].content == "Most similar document (dot product)" | ||
assert results[1].content == "2nd best document (dot product)" | ||
assert results[0].score > results[1].score | ||
|
||
|
||
def test_embedding_retrieval_euclidean(self, document_store: MongoDBAtlasDocumentStore): | ||
query_embedding = [0.1] * 768 | ||
most_similar_embedding = [0.8] * 768 | ||
second_best_embedding = [0.8] * 700 + [0.1] * 3 + [0.2] * 65 | ||
another_embedding = rand(768).tolist() | ||
|
||
docs = [ | ||
Document(content="Most similar document (euclidean)", embedding=most_similar_embedding), | ||
Document(content="2nd best document (euclidean)", embedding=second_best_embedding), | ||
Document(content="Not very similar document (euclidean)", embedding=another_embedding), | ||
] | ||
|
||
document_store.write_documents(docs) | ||
|
||
results = document_store.embedding_retrieval( | ||
query_embedding=query_embedding, top_k=2, filters={}, similarity="euclidean" | ||
) | ||
assert len(results) == 2 | ||
assert results[0].content == "Most similar document (euclidean)" | ||
assert results[1].content == "2nd best document (euclidean)" | ||
assert results[0].score > results[1].score | ||
|
||
def test_empty_query_embedding(self, document_store: MongoDBAtlasDocumentStore): | ||
query_embedding: List[float] = [] | ||
with pytest.raises(ValueError): | ||
document_store.embedding_retrieval(query_embedding=query_embedding) | ||
|
||
def test_query_embedding_wrong_dimension(self, document_store: MongoDBAtlasDocumentStore): | ||
query_embedding = [0.1] * 4 | ||
with pytest.raises(ValueError): | ||
document_store.embedding_retrieval(query_embedding=query_embedding) |