From 34e6f3ff72067af3265341bcea7983c106f15a74 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 15 Dec 2023 16:49:21 -0800 Subject: [PATCH] community[patch]: Implement similarity_score_threshold for MongoDB Vector Store (#14740) Adds the option for `similarity_score_threshold` when using `MongoDBAtlasVectorSearch` as a vector store retriever. Example use: ``` vector_search = MongoDBAtlasVectorSearch.from_documents(...) qa_retriever = vector_search.as_retriever( search_type="similarity_score_threshold", search_kwargs={ "score_threshold": 0.5, } ) qa = RetrievalQA.from_chain_type( llm=OpenAI(), chain_type="stuff", retriever=qa_retriever, ) docs = qa({"query": "..."}) ``` I've tested this feature locally, using a MongoDB Atlas Cluster with a vector search index. --- .../vectorstores/mongodb_atlas.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/libs/community/langchain_community/vectorstores/mongodb_atlas.py b/libs/community/langchain_community/vectorstores/mongodb_atlas.py index 61c901940fa95..87fa45e711cdc 100644 --- a/libs/community/langchain_community/vectorstores/mongodb_atlas.py +++ b/libs/community/langchain_community/vectorstores/mongodb_atlas.py @@ -4,6 +4,7 @@ from typing import ( TYPE_CHECKING, Any, + Callable, Dict, Generator, Iterable, @@ -60,6 +61,7 @@ def __init__( index_name: str = "default", text_key: str = "text", embedding_key: str = "embedding", + relevance_score_fn: str = "cosine", ): """ Args: @@ -70,17 +72,32 @@ def __init__( embedding_key: MongoDB field that will contain the embedding for each document. index_name: Name of the Atlas Search index. + relevance_score_fn: The similarity score used for the index. + Currently supported: "euclidean", "cosine", and "dotProduct". 
""" self._collection = collection self._embedding = embedding self._index_name = index_name self._text_key = text_key self._embedding_key = embedding_key + self._relevance_score_fn = relevance_score_fn @property def embeddings(self) -> Embeddings: return self._embedding + def _select_relevance_score_fn(self) -> Callable[[float], float]: + if self._relevance_score_fn == "euclidean": + return self._euclidean_relevance_score_fn + elif self._relevance_score_fn == "dotProduct": + return self._max_inner_product_relevance_score_fn + elif self._relevance_score_fn == "cosine": + return self._cosine_relevance_score_fn + else: + raise NotImplementedError( + f"No relevance score function for ${self._relevance_score_fn}" + ) + @classmethod def from_connection_string( cls, @@ -198,7 +215,6 @@ def _similarity_search_with_score( def similarity_search_with_score( self, query: str, - *, k: int = 4, pre_filter: Optional[Dict] = None, post_filter_pipeline: Optional[List[Dict]] = None,