diff --git a/libs/core/langchain_core/retrievers.py b/libs/core/langchain_core/retrievers.py index 7462569ddd30d..e1be4588081a2 100644 --- a/libs/core/langchain_core/retrievers.py +++ b/libs/core/langchain_core/retrievers.py @@ -27,7 +27,7 @@ from typing import TYPE_CHECKING, Any, Optional from pydantic import ConfigDict -from typing_extensions import TypedDict +from typing_extensions import Self, TypedDict from langchain_core._api import deprecated from langchain_core.documents import Document @@ -180,6 +180,18 @@ def __init_subclass__(cls, **kwargs: Any) -> None: cls._aget_relevant_documents = aswap # type: ignore[assignment] parameters = signature(cls._get_relevant_documents).parameters cls._new_arg_supported = parameters.get("run_manager") is not None + if ( + not cls._new_arg_supported + and cls._aget_relevant_documents == BaseRetriever._aget_relevant_documents + ): + # we need to tolerate no run_manager in _aget_relevant_documents signature + async def _aget_relevant_documents( + self: Self, query: str + ) -> list[Document]: + return await run_in_executor(None, self._get_relevant_documents, query) # type: ignore + + cls._aget_relevant_documents = _aget_relevant_documents # type: ignore[assignment] + # If a V1 retriever broke the interface and expects additional arguments cls._expects_other_args = ( len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0 diff --git a/libs/core/tests/unit_tests/test_retrievers.py b/libs/core/tests/unit_tests/test_retrievers.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/libs/standard-tests/tests/unit_tests/test_basic_retriever.py b/libs/standard-tests/tests/unit_tests/test_basic_retriever.py index af5d598c722f2..fb7999a09fb5c 100644 --- a/libs/standard-tests/tests/unit_tests/test_basic_retriever.py +++ b/libs/standard-tests/tests/unit_tests/test_basic_retriever.py @@ -1,6 +1,5 @@ from typing import Any, Type -from langchain_core.callbacks import CallbackManagerForRetrieverRun from langchain_core.documents import Document from langchain_core.retrievers import BaseRetriever @@ -11,9 +10,7 @@ class ParrotRetriever(BaseRetriever): parrot_name: str k: int = 3 - def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any - ) -> list[Document]: + def _get_relevant_documents(self, query: str, **kwargs: Any) -> list[Document]: k = kwargs.get("k", self.k) return [Document(page_content=f"{self.parrot_name} says: {query}")] * k