diff --git a/libs/partners/mongodb/langchain_mongodb/cache.py b/libs/partners/mongodb/langchain_mongodb/cache.py index b7ace63a1e719..b55840efa9109 100644 --- a/libs/partners/mongodb/langchain_mongodb/cache.py +++ b/libs/partners/mongodb/langchain_mongodb/cache.py @@ -215,6 +215,7 @@ def __init__( embedding: Embeddings, collection_name: str = "default", database_name: str = "default", + index_name: str = "default", wait_until_ready: bool = False, **kwargs: Dict[str, Any], ): @@ -229,13 +230,20 @@ def __init__( Defaults to "default". database_name (str): MongoDB Database where to store texts. Defaults to "default". + index_name: Name of the Atlas Search index. + defaults to 'default' wait_until_ready (bool): Block until MongoDB Atlas finishes indexing the stored text. Hard timeout of 10 seconds. Defaults to False. """ client = _generate_mongo_client(connection_string) self.collection = client[database_name][collection_name] self._wait_until_ready = wait_until_ready - super().__init__(self.collection, embedding, **kwargs) # type: ignore + super().__init__( + collection=self.collection, + embedding=embedding, + index_name=index_name, + **kwargs, # type: ignore + ) def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: """Look up based on prompt and llm_string.""" diff --git a/libs/partners/mongodb/tests/integration_tests/test_cache.py b/libs/partners/mongodb/tests/integration_tests/test_cache.py index c7618c4ea2836..27e5846ac3d67 100644 --- a/libs/partners/mongodb/tests/integration_tests/test_cache.py +++ b/libs/partners/mongodb/tests/integration_tests/test_cache.py @@ -29,6 +29,7 @@ def llm_cache(cls: Any) -> BaseCache: connection_string=CONN_STRING, collection_name=COLLECTION, database_name=DATABASE, + index_name=INDEX_NAME, wait_until_ready=True, ) )