diff --git a/libs/core/langchain_core/tools/convert.py b/libs/core/langchain_core/tools/convert.py index bb8b85f5558cc..9cabc9f30b6f4 100644 --- a/libs/core/langchain_core/tools/convert.py +++ b/libs/core/langchain_core/tools/convert.py @@ -19,7 +19,7 @@ def tool( response_format: Literal["content", "content_and_artifact"] = "content", parse_docstring: bool = False, error_on_invalid_docstring: bool = True, -) -> Callable[[Union[Callable, Runnable]], BaseTool]: ... +) -> Callable[[Union[Callable, Runnable]], Union[StructuredTool, Tool]]: ... @overload @@ -33,7 +33,7 @@ def tool( response_format: Literal["content", "content_and_artifact"] = "content", parse_docstring: bool = False, error_on_invalid_docstring: bool = True, -) -> BaseTool: ... +) -> Union[StructuredTool, Tool]: ... @overload @@ -46,7 +46,7 @@ def tool( response_format: Literal["content", "content_and_artifact"] = "content", parse_docstring: bool = False, error_on_invalid_docstring: bool = True, -) -> BaseTool: ... +) -> Union[StructuredTool, Tool]: ... @overload @@ -55,11 +55,24 @@ def tool( *, return_direct: bool = False, args_schema: Optional[type] = None, - infer_schema: bool = True, + infer_schema: Literal[True] = True, + response_format: Literal["content", "content_and_artifact"] = "content", + parse_docstring: bool = False, + error_on_invalid_docstring: bool = True, +) -> Callable[[Union[Callable, Runnable]], StructuredTool]: ... + + +@overload +def tool( + name_or_callable: str, + *, + return_direct: bool = False, + args_schema: Optional[type] = None, + infer_schema: Literal[False], response_format: Literal["content", "content_and_artifact"] = "content", parse_docstring: bool = False, error_on_invalid_docstring: bool = True, -) -> Callable[[Union[Callable, Runnable]], BaseTool]: ... +) -> Callable[[Union[Callable, Runnable]], Tool]: ... def tool( @@ -74,7 +87,7 @@ def tool( error_on_invalid_docstring: bool = True, ) -> Union[ BaseTool, - Callable[[Union[Callable, Runnable]], BaseTool], + Callable[[Union[Callable, Runnable]], Union[StructuredTool, Tool]], ]: """Make tools out of functions, can be used with or without arguments. @@ -202,7 +215,7 @@ def invalid_docstring_3(bar: str, baz: int) -> str: def _create_tool_factory( tool_name: str, - ) -> Callable[[Union[Callable, Runnable]], BaseTool]: + ) -> Callable[[Union[Callable, Runnable]], Union[StructuredTool, Tool]]: """Create a decorator that takes a callable and returns a tool. Args: @@ -212,7 +225,9 @@ def _create_tool_factory( A function that takes a callable or Runnable and returns a tool. """ - def _tool_factory(dec_func: Union[Callable, Runnable]) -> BaseTool: + def _tool_factory( + dec_func: Union[Callable, Runnable], + ) -> Union[StructuredTool, Tool]: if isinstance(dec_func, Runnable): runnable = dec_func @@ -325,7 +340,7 @@ def invoke_wrapper( # @tool(parse_docstring=True) # def my_tool(): # pass - def _partial(func: Union[Callable, Runnable]) -> BaseTool: + def _partial(func: Union[Callable, Runnable]) -> Union[StructuredTool, Tool]: """Partial function that takes a callable and returns a tool.""" name_ = func.get_name() if isinstance(func, Runnable) else func.__name__ tool_factory = _create_tool_factory(name_) diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index b331abea7da7a..d9b310080c335 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -101,6 +101,17 @@ async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str raise NotImplementedError +def test_tool_typing() -> None: + """Test typing annotations (checked with mypy).""" + + @tool + def multiply(a: int, b: int) -> int: + """multiply two ints""" + return a * b + + _: StructuredTool = multiply + + def test_structured_args() -> None: """Test functionality with structured arguments.""" structured_api = _MockStructuredTool()