diff --git a/libs/community/langchain_community/vectorstores/mongodb_atlas.py b/libs/community/langchain_community/vectorstores/mongodb_atlas.py
index 61c901940fa95..87fa45e711cdc 100644
--- a/libs/community/langchain_community/vectorstores/mongodb_atlas.py
+++ b/libs/community/langchain_community/vectorstores/mongodb_atlas.py
@@ -4,6 +4,7 @@
 from typing import (
     TYPE_CHECKING,
     Any,
+    Callable,
     Dict,
     Generator,
     Iterable,
@@ -60,6 +61,7 @@ def __init__(
         index_name: str = "default",
         text_key: str = "text",
         embedding_key: str = "embedding",
+        relevance_score_fn: str = "cosine",
     ):
         """
         Args:
@@ -70,17 +72,38 @@ def __init__(
             embedding_key: MongoDB field that will contain the embedding for
                 each document.
             index_name: Name of the Atlas Search index.
+            relevance_score_fn: The similarity score used for the index.
+                Must be one of "euclidean", "cosine" or "dotProduct".
         """
         self._collection = collection
         self._embedding = embedding
         self._index_name = index_name
         self._text_key = text_key
         self._embedding_key = embedding_key
+        self._relevance_score_fn = relevance_score_fn
 
     @property
     def embeddings(self) -> Embeddings:
         return self._embedding
 
+    def _select_relevance_score_fn(self) -> Callable[[float], float]:
+        """Return the relevance score function for the configured similarity.
+
+        Raises:
+            NotImplementedError: If ``relevance_score_fn`` is not one of
+                "euclidean", "dotProduct" or "cosine".
+        """
+        if self._relevance_score_fn == "euclidean":
+            return self._euclidean_relevance_score_fn
+        elif self._relevance_score_fn == "dotProduct":
+            return self._max_inner_product_relevance_score_fn
+        elif self._relevance_score_fn == "cosine":
+            return self._cosine_relevance_score_fn
+        else:
+            raise NotImplementedError(
+                f"No relevance score function for {self._relevance_score_fn}"
+            )
+
     @classmethod
     def from_connection_string(
         cls,
@@ -198,7 +221,6 @@ def _similarity_search_with_score(
     def similarity_search_with_score(
         self,
         query: str,
-        *,
         k: int = 4,
         pre_filter: Optional[Dict] = None,
         post_filter_pipeline: Optional[List[Dict]] = None,