Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: method tools #28160

Closed
wants to merge 12 commits into from
7 changes: 7 additions & 0 deletions libs/core/langchain_core/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,10 @@ def run(

if config_param := _get_runnable_config_param(self._run):
tool_kwargs[config_param] = config

if "self" in tool_kwargs:
tool_kwargs["outer_self"] = tool_kwargs.pop("self")

response = context.run(self._run, *tool_args, **tool_kwargs)
if self.response_format == "content_and_artifact":
if not isinstance(response, tuple) or len(response) != 2:
Expand Down Expand Up @@ -767,6 +771,9 @@ async def arun(
if config_param := _get_runnable_config_param(func_to_check):
tool_kwargs[config_param] = config

if "self" in tool_kwargs:
tool_kwargs["outer_self"] = tool_kwargs.pop("self")

coro = context.run(self._arun, *tool_args, **tool_kwargs)
if asyncio_accepts_context():
response = await asyncio.create_task(coro, context=context) # type: ignore
Expand Down
49 changes: 43 additions & 6 deletions libs/core/langchain_core/tools/convert.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
import inspect
from typing import Any, Callable, Literal, Optional, Union, get_type_hints, overload
from typing import (
Any,
Callable,
Literal,
Optional,
Union,
get_type_hints,
overload,
)

from pydantic import BaseModel, Field, create_model

Expand Down Expand Up @@ -73,8 +81,8 @@ def tool(
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
) -> Union[
BaseTool,
Callable[[Union[Callable, Runnable]], BaseTool],
Union[BaseTool, property],
ethanglide marked this conversation as resolved.
Show resolved Hide resolved
Callable[[Union[Callable, Runnable]], Union[BaseTool, property]],
]:
"""Make tools out of functions, can be used with or without arguments.

Expand Down Expand Up @@ -102,6 +110,9 @@ def tool(
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
whether to raise ValueError on invalid Google Style docstrings.
Defaults to True.
is_method: Whether the tool is a method. This allows the tool to be used
without manually passing the `self` argument.
Defaults to False.

Returns:
The tool.
Expand Down Expand Up @@ -202,7 +213,7 @@ def invalid_docstring_3(bar: str, baz: int) -> str:

def _create_tool_factory(
tool_name: str,
) -> Callable[[Union[Callable, Runnable]], BaseTool]:
) -> Callable[[Union[Callable, Runnable]], Union[BaseTool, property]]:
"""Create a decorator that takes a callable and returns a tool.

Args:
Expand All @@ -212,7 +223,9 @@ def _create_tool_factory(
A function that takes a callable or Runnable and returns a tool.
"""

def _tool_factory(dec_func: Union[Callable, Runnable]) -> BaseTool:
def _tool_factory(
dec_func: Union[Callable, Runnable],
) -> Union[BaseTool, property]:
if isinstance(dec_func, Runnable):
runnable = dec_func

Expand Down Expand Up @@ -246,6 +259,28 @@ def invoke_wrapper(
description = None

if infer_schema or args_schema is not None:
if (
not isinstance(dec_func, Runnable)
and "self" in inspect.signature(dec_func).parameters
):

def method_tool(self: Callable) -> StructuredTool:
return StructuredTool.from_function(
func,
coroutine,
name=tool_name,
description=description,
return_direct=return_direct,
args_schema=schema,
infer_schema=infer_schema,
response_format=response_format,
parse_docstring=parse_docstring,
error_on_invalid_docstring=error_on_invalid_docstring,
outer_self=self,
)

return property(method_tool)

return StructuredTool.from_function(
func,
coroutine,
Expand Down Expand Up @@ -325,7 +360,9 @@ def invoke_wrapper(
# @tool(parse_docstring=True)
# def my_tool():
# pass
def _partial(func: Union[Callable, Runnable]) -> BaseTool:
def _partial(
func: Union[Callable, Runnable],
) -> Union[BaseTool, property]:
"""Partial function that takes a callable and returns a tool."""
name_ = func.get_name() if isinstance(func, Runnable) else func.__name__
tool_factory = _create_tool_factory(name_)
Expand Down
47 changes: 46 additions & 1 deletion libs/core/langchain_core/tools/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Literal,
Optional,
Union,
cast,
)

from pydantic import BaseModel, Field, SkipValidation
Expand Down Expand Up @@ -41,8 +42,41 @@ class StructuredTool(BaseTool):
"""The function to run when the tool is called."""
coroutine: Optional[Callable[..., Awaitable[Any]]] = None
"""The asynchronous version of the function."""
outer_self: Optional[Any] = None
"""The outer self of the tool for methods."""

# --- Runnable ---
def _add_outer_self(
self, input: Union[str, dict, ToolCall]
) -> Union[dict, ToolCall]:
"""Add outer self into arguments for method tools."""

# If input is a string, then it is the first argument
if isinstance(input, str):
args = {"self": self.outer_self}
for x in self.args: # loop should only happen once
args[x] = input
return args

# ToolCall
if "type" in input and input["type"] == "tool_call":
input["args"]["self"] = self.outer_self
return input

# Dict
new_input = cast(dict, input) # to avoid mypy error
new_input["self"] = self.outer_self
return new_input

def invoke(
self,
input: Union[str, dict, ToolCall],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
if self.outer_self is not None:
input = self._add_outer_self(input)
return super().invoke(input, config, **kwargs)

# TODO: Is this needed?
async def ainvoke(
Expand All @@ -55,14 +89,19 @@ async def ainvoke(
# If the tool does not implement async, fall back to default implementation
return await run_in_executor(config, self.invoke, input, config, **kwargs)

if self.outer_self is not None:
input = self._add_outer_self(input)
return await super().ainvoke(input, config, **kwargs)

# --- Tool ---

@property
def args(self) -> dict:
"""The tool's input arguments."""
return self.args_schema.model_json_schema()["properties"]
properties = self.args_schema.model_json_schema()["properties"]
if self.outer_self is not None:
properties.pop("self")
return properties

def _run(
self,
Expand All @@ -77,6 +116,8 @@ def _run(
kwargs["callbacks"] = run_manager.get_child()
if config_param := _get_runnable_config_param(self.func):
kwargs[config_param] = config
if "outer_self" in kwargs:
kwargs["self"] = kwargs.pop("outer_self")
return self.func(*args, **kwargs)
msg = "StructuredTool does not support sync invocation."
raise NotImplementedError(msg)
Expand All @@ -94,6 +135,8 @@ async def _arun(
kwargs["callbacks"] = run_manager.get_child()
if config_param := _get_runnable_config_param(self.coroutine):
kwargs[config_param] = config
if "outer_self" in kwargs:
kwargs["self"] = kwargs.pop("outer_self")
return await self.coroutine(*args, **kwargs)

# If self.coroutine is None, then this will delegate to the default
Expand All @@ -116,6 +159,7 @@ def from_function(
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = False,
outer_self: Optional[Any] = None,
**kwargs: Any,
) -> StructuredTool:
"""Create tool from a given function.
Expand Down Expand Up @@ -203,6 +247,7 @@ def add(a: int, b: int) -> int:
description=description_,
return_direct=return_direct,
response_format=response_format,
outer_self=outer_self,
**kwargs,
)

Expand Down
91 changes: 91 additions & 0 deletions libs/core/tests/unit_tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2110,3 +2110,94 @@ def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str:
return foo.value

assert injected_tool.invoke({"x": 5, "foo": Foo()}) == "bar" # type: ignore


def test_method_tool_self_ref() -> None:
"""Test that a method tool can reference self."""

class A:
def __init__(self, c: int):
self.c = c

@tool
def foo(self, a: int, b: int) -> int:
"""Add two numbers to c."""
return a + b + self.c

a = A(10)
assert a.foo.invoke({"a": 1, "b": 2}) == 13


def test_method_tool_args() -> None:
"""Test that a method tool's args do not include self."""

class A:
def __init__(self, c: int):
self.c = c

@tool
def foo(self, a: int, b: int) -> int:
"""Add two numbers to c."""
return a + b + self.c

a = A(10)
assert "self" not in a.foo.args


async def test_method_tool_async() -> None:
"""Test that a method tool can be async."""

class A:
def __init__(self, c: int):
self.c = c

@tool
async def foo(self, a: int, b: int) -> int:
"""Add two numbers to c."""
return a + b + self.c

a = A(10)
async_response = await a.foo.ainvoke({"a": 1, "b": 2})
assert async_response == 13


def test_method_tool_string_invoke() -> None:
"""Test that a method tool can be invoked with a string."""

class A:
def __init__(self, a: str):
self.a = a

@tool
def foo(self, b: str) -> str:
"""Concatenate a and b."""
return self.a + b

a = A("a")
assert a.foo.invoke("b") == "ab"


def test_method_tool_toolcall_invoke() -> None:
"""Test that a method tool can be invoked with a ToolCall."""

class A:
def __init__(self, c: int):
self.c = c

@tool
def foo(self, a: int, b: int) -> int:
"""Add two numbers to c."""
return a + b + self.c

a = A(10)

tool_call = {
"name": a.foo.name,
"args": {"a": 1, "b": 2},
"id": "123",
"type": "tool_call",
}

tool_message = a.foo.invoke(tool_call)

assert int(tool_message.content) == 13
Loading