Skip to content

Commit

Permalink
community[patch]: Implement similarity_score_threshold for MongoDB Ve…
Browse files Browse the repository at this point in the history
…ctor Store (#14740)

Adds the option for `similarity_score_threshold` when using
`MongoDBAtlasVectorSearch` as a vector store retriever.

Example use:

```
vector_search = MongoDBAtlasVectorSearch.from_documents(...)

qa_retriever = vector_search.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={
        "score_threshold": 0.5,
    }
)

qa = RetrievalQA.from_chain_type(
	llm=OpenAI(), 
	chain_type="stuff", 
	retriever=qa_retriever,
)

docs = qa({"query": "..."})
```

I've tested this feature locally, using a MongoDB Atlas Cluster with a
vector search index.
  • Loading branch information
NoahStapp authored Dec 16, 2023
1 parent dcead81 commit 34e6f3f
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion libs/community/langchain_community/vectorstores/mongodb_atlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
Iterable,
Expand Down Expand Up @@ -60,6 +61,7 @@ def __init__(
index_name: str = "default",
text_key: str = "text",
embedding_key: str = "embedding",
relevance_score_fn: str = "cosine",
):
"""
Args:
Expand All @@ -70,17 +72,32 @@ def __init__(
embedding_key: MongoDB field that will contain the embedding for
each document.
index_name: Name of the Atlas Search index.
relevance_score_fn: The similarity score used for the index.
Currently supported: Euclidean, cosine, and dot product.
"""
self._collection = collection
self._embedding = embedding
self._index_name = index_name
self._text_key = text_key
self._embedding_key = embedding_key
self._relevance_score_fn = relevance_score_fn

@property
def embeddings(self) -> Embeddings:
return self._embedding

def _select_relevance_score_fn(self) -> Callable[[float], float]:
if self._relevance_score_fn == "euclidean":
return self._euclidean_relevance_score_fn
elif self._relevance_score_fn == "dotProduct":
return self._max_inner_product_relevance_score_fn
elif self._relevance_score_fn == "cosine":
return self._cosine_relevance_score_fn
else:
raise NotImplementedError(
f"No relevance score function for ${self._relevance_score_fn}"
)

@classmethod
def from_connection_string(
cls,
Expand Down Expand Up @@ -198,7 +215,6 @@ def _similarity_search_with_score(
def similarity_search_with_score(
self,
query: str,
*,
k: int = 4,
pre_filter: Optional[Dict] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
Expand Down

0 comments on commit 34e6f3f

Please sign in to comment.