Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core[fix] @tool typing #25856

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
33 changes: 24 additions & 9 deletions libs/core/langchain_core/tools/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def tool(
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
) -> Callable[[Union[Callable, Runnable]], BaseTool]: ...
) -> Callable[[Union[Callable, Runnable]], Union[StructuredTool, Tool]]: ...


@overload
Expand All @@ -33,7 +33,7 @@ def tool(
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
) -> BaseTool: ...
) -> Union[StructuredTool, Tool]: ...


@overload
Expand All @@ -46,7 +46,7 @@ def tool(
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
) -> BaseTool: ...
) -> Union[StructuredTool, Tool]: ...


@overload
Expand All @@ -55,11 +55,24 @@ def tool(
*,
return_direct: bool = False,
args_schema: Optional[type] = None,
infer_schema: bool = True,
infer_schema: Literal[True] = True,
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
) -> Callable[[Union[Callable, Runnable]], StructuredTool]: ...


@overload
def tool(
name_or_callable: str,
*,
return_direct: bool = False,
args_schema: Optional[type] = None,
infer_schema: Literal[False],
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
) -> Callable[[Union[Callable, Runnable]], BaseTool]: ...
) -> Callable[[Union[Callable, Runnable]], Tool]: ...


def tool(
Expand All @@ -74,7 +87,7 @@ def tool(
error_on_invalid_docstring: bool = True,
) -> Union[
BaseTool,
Callable[[Union[Callable, Runnable]], BaseTool],
Callable[[Union[Callable, Runnable]], Union[StructuredTool, Tool]],
]:
"""Make tools out of functions, can be used with or without arguments.

Expand Down Expand Up @@ -202,7 +215,7 @@ def invalid_docstring_3(bar: str, baz: int) -> str:

def _create_tool_factory(
tool_name: str,
) -> Callable[[Union[Callable, Runnable]], BaseTool]:
) -> Callable[[Union[Callable, Runnable]], Union[StructuredTool, Tool]]:
"""Create a decorator that takes a callable and returns a tool.

Args:
Expand All @@ -212,7 +225,9 @@ def _create_tool_factory(
A function that takes a callable or Runnable and returns a tool.
"""

def _tool_factory(dec_func: Union[Callable, Runnable]) -> BaseTool:
def _tool_factory(
dec_func: Union[Callable, Runnable],
) -> Union[StructuredTool, Tool]:
if isinstance(dec_func, Runnable):
runnable = dec_func

Expand Down Expand Up @@ -325,7 +340,7 @@ def invoke_wrapper(
# @tool(parse_docstring=True)
# def my_tool():
# pass
def _partial(func: Union[Callable, Runnable]) -> BaseTool:
def _partial(func: Union[Callable, Runnable]) -> Union[StructuredTool, Tool]:
"""Partial function that takes a callable and returns a tool."""
name_ = func.get_name() if isinstance(func, Runnable) else func.__name__
tool_factory = _create_tool_factory(name_)
Expand Down
11 changes: 11 additions & 0 deletions libs/core/tests/unit_tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,17 @@ async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str
raise NotImplementedError


def test_tool_typing() -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we add a test case for each tool overload variant?

"""Test typing annotations (checked with mypy)."""

@tool
def multiply(a: int, b: int) -> int:
"""multiply two ints"""
return a * b

_: StructuredTool = multiply


def test_structured_args() -> None:
"""Test functionality with structured arguments."""
structured_api = _MockStructuredTool()
Expand Down
Loading