Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: improved method tools #28695

Closed
wants to merge 9 commits into from
1 change: 1 addition & 0 deletions libs/core/langchain_core/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from langchain_core.tools.base import (
create_schema_from_function as create_schema_from_function,
)
from langchain_core.tools.convert import MethodTool as MethodTool
ethanglide marked this conversation as resolved.
Show resolved Hide resolved
from langchain_core.tools.convert import (
convert_runnable_to_tool as convert_runnable_to_tool,
)
Expand Down
20 changes: 20 additions & 0 deletions libs/core/langchain_core/tools/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def tool(
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
outer_instance: Optional[Any] = None,
) -> Callable[[Union[Callable, Runnable]], BaseTool]: ...


Expand All @@ -33,6 +34,7 @@ def tool(
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
outer_instance: Optional[Any] = None,
) -> BaseTool: ...


Expand All @@ -46,6 +48,7 @@ def tool(
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
outer_instance: Optional[Any] = None,
) -> BaseTool: ...


Expand All @@ -59,6 +62,7 @@ def tool(
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
outer_instance: Optional[Any] = None,
) -> Callable[[Union[Callable, Runnable]], BaseTool]: ...


Expand All @@ -72,6 +76,7 @@ def tool(
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
outer_instance: Optional[Any] = None,
) -> Union[
BaseTool,
Callable[[Union[Callable, Runnable]], BaseTool],
Expand Down Expand Up @@ -102,6 +107,8 @@ def tool(
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
whether to raise ValueError on invalid Google Style docstrings.
Defaults to True.
outer_instance: For method tools, the instance of the class that the tool is
being created. Defaults to None.

Returns:
The tool.
Expand Down Expand Up @@ -257,6 +264,7 @@ def invoke_wrapper(
response_format=response_format,
parse_docstring=parse_docstring,
error_on_invalid_docstring=error_on_invalid_docstring,
outer_instance=outer_instance,
)
# If someone doesn't want a schema applied, we must treat it as
# a simple string->string function
Expand Down Expand Up @@ -421,3 +429,15 @@ def invoke_wrapper(callbacks: Optional[Callbacks] = None, **kwargs: Any) -> Any:
description=description,
args_schema=args_schema,
)


class MethodTool:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we make this a function instead of a class?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see - this maybe solves the issue of typing. going to suggest some ideas - not sure if they work

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The class is needed for classmethod support. I have created a wrapper function methodtool that simply returns MethodTool(func). How's that?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is it needed for classmethod support?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mentioned in another reply:

With the way the __get__ descriptor works owner is for the class itself and instance is for the instance of the class.

Keep in mind that in your other solution (a function that returns a property), you are still returning a descriptor class. That class, property. works for our regular method tool case, but it will not work for class methods. Notice in the MethodTool class that it will pass in owner (class itself) for outer_instance with classmethods, and instance (class instance) for regular methods. Since the property class is not enough for our use, I had to create a custom descriptor that did exactly what we wanted it to.

"Why could we not do classmethod(property(func)) for classmethods?" you may ask. That is because classmethod no longer wraps descriptors as of Python 3.13.

If, for syntax reasons, you would like the decorator to be a function that is something that I have added to the code. The function is simply a wrapper for the MethodTool class since that overloaded __get__ method is needed, but the MethodTool class never needs to be exposed to client code.

"""A descriptor that converts a method into a tool."""

def __init__(self, func: Union[Callable, classmethod]) -> None:
self.func = func

def __get__(self, instance: Any, owner: Any) -> BaseTool:
if isinstance(self.func, classmethod):
return tool(self.func.__func__, outer_instance=owner)
return tool(self.func, outer_instance=instance)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class MethodTool:
"""A descriptor that converts a method into a tool."""
def __init__(self, func: Union[Callable, classmethod]) -> None:
self.func = func
def __get__(self, instance: Any, owner: Any) -> BaseTool:
if isinstance(self.func, classmethod):
return tool(self.func.__func__, outer_instance=owner)
return tool(self.func, outer_instance=instance)
def toolmethod(f: callable) -> property:
"""A decorator that converts a method into a tool."""
def inner_f(outer_instance) -> BaseTool:
tool_f = functools.partial(f, outer_instance) # this substitutes in for self
return tool(tool_f)
return property(inner_f)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this leaves out what was being handled by owner vs instance for outer_instance - what is that doing for us?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was to allow the support for classmethods. With the way the __get__ descriptor works owner is for the class itself and instance is for the instance of the class.

20 changes: 19 additions & 1 deletion libs/core/langchain_core/tools/structured.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import inspect
import textwrap
from collections.abc import Awaitable
from inspect import signature
Expand Down Expand Up @@ -41,6 +42,9 @@ class StructuredTool(BaseTool):
"""The function to run when the tool is called."""
coroutine: Optional[Callable[..., Awaitable[Any]]] = None
"""The asynchronous version of the function."""
outer_instance: Optional[Any] = None
ethanglide marked this conversation as resolved.
Show resolved Hide resolved
"""For methods/classmethods, the 'self' or 'cls' instance to use
when calling the function."""

# --- Runnable ---

Expand Down Expand Up @@ -77,6 +81,8 @@ def _run(
kwargs["callbacks"] = run_manager.get_child()
if config_param := _get_runnable_config_param(self.func):
kwargs[config_param] = config
if self.outer_instance is not None:
args = (self.outer_instance, *args)
return self.func(*args, **kwargs)
msg = "StructuredTool does not support sync invocation."
raise NotImplementedError(msg)
Expand All @@ -94,6 +100,8 @@ async def _arun(
kwargs["callbacks"] = run_manager.get_child()
if config_param := _get_runnable_config_param(self.coroutine):
kwargs[config_param] = config
if self.outer_instance is not None:
args = (self.outer_instance, *args)
return await self.coroutine(*args, **kwargs)

# If self.coroutine is None, then this will delegate to the default
Expand All @@ -116,6 +124,7 @@ def from_function(
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = False,
outer_instance: Optional[Any] = None,
**kwargs: Any,
) -> StructuredTool:
"""Create tool from a given function.
Expand Down Expand Up @@ -172,13 +181,21 @@ def add(a: int, b: int) -> int:
raise ValueError(msg)
name = name or source_function.__name__
if args_schema is None and infer_schema:
filter_args = _filter_schema_args(source_function)

# if outer_instance is provided, add the first argument to the filter_args
if outer_instance is not None:
filter_args.append(
list(inspect.signature(source_function).parameters.values())[0].name
)

# schema name is appended within function
args_schema = create_schema_from_function(
name,
source_function,
parse_docstring=parse_docstring,
error_on_invalid_docstring=error_on_invalid_docstring,
filter_args=_filter_schema_args(source_function),
filter_args=filter_args,
)
description_ = description
if description is None and not parse_docstring:
Expand All @@ -203,6 +220,7 @@ def add(a: int, b: int) -> int:
description=description_,
return_direct=return_direct,
response_format=response_format,
outer_instance=outer_instance,
**kwargs,
)

Expand Down
122 changes: 122 additions & 0 deletions libs/core/tests/unit_tests/test_tools.py
ethanglide marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
)
from langchain_core.tools import (
BaseTool,
MethodTool,
StructuredTool,
Tool,
ToolException,
Expand Down Expand Up @@ -2174,3 +2175,124 @@ def foo(x: int) -> Bar:
assert foo.invoke(
{"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"}
) == Bar(x=0)


def test_method_tool_self_ref() -> None:
"""Test that a method tool can reference self."""

class A:
def __init__(self, c: int):
self.c = c

@MethodTool
ethanglide marked this conversation as resolved.
Show resolved Hide resolved
def foo(self, a: int, b: int) -> int:
"""Add two numbers to c."""
return a + b + self.c

a = A(10)
assert a.foo.invoke({"a": 1, "b": 2}) == 13
ethanglide marked this conversation as resolved.
Show resolved Hide resolved
ethanglide marked this conversation as resolved.
Show resolved Hide resolved


def test_method_tool_args() -> None:
"""Test that a method tool's args do not include self."""

class A:
def __init__(self, c: int):
self.c = c

@MethodTool
def foo(self, a: int, b: int) -> int:
"""Add two numbers to c."""
return a + b + self.c

a = A(10)
assert "self" not in a.foo.args


async def test_method_tool_async() -> None:
"""Test that a method tool can be async."""

class A:
def __init__(self, c: int):
self.c = c

@MethodTool
async def foo(self, a: int, b: int) -> int:
"""Add two numbers to c."""
return a + b + self.c

a = A(10)
async_response = await a.foo.ainvoke({"a": 1, "b": 2})
assert async_response == 13


def test_method_tool_string_invoke() -> None:
"""Test that a method tool can be invoked with a string."""

class A:
def __init__(self, a: str):
self.a = a

@MethodTool
def foo(self, b: str) -> str:
"""Concatenate a and b."""
return self.a + b

a = A("a")
assert a.foo.invoke("b") == "ab"


def test_method_tool_toolcall_invoke() -> None:
"""Test that a method tool can be invoked with a ToolCall."""

class A:
def __init__(self, c: int):
self.c = c

@MethodTool
def foo(self, a: int, b: int) -> int:
"""Add two numbers to c."""
return a + b + self.c

a = A(10)

tool_call = {
"name": a.foo.name,
"args": {"a": 1, "b": 2},
"id": "123",
"type": "tool_call",
}

tool_message = a.foo.invoke(tool_call)

assert int(tool_message.content) == 13


def test_method_tool_classmethod() -> None:
"""Test that a method tool can be a classmethod."""

class A:
c = 5

@MethodTool
@classmethod
def foo(cls, a: int, b: int) -> int:
"""Add two numbers to c."""
return a + b + cls.c

assert A.foo.invoke({"a": 1, "b": 2}) == 8


def test_method_tool_classmethod_args() -> None:
"""Test that a classmethod tool's args do not include cls."""

class A:
c = 5

@MethodTool
@classmethod
def foo(cls, a: int, b: int) -> int:
"""Add two numbers to c."""
return a + b + cls.c

assert "cls" not in A.foo.args
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert "cls" not in A.foo.args
assert "cls" not in A.foo.args
class A:
...
class TestStandardAUnit(BaseToolUnitTests):
...
class TestStandardAIntegration(BaseToolIntegrationTests):
...

would also be good to implement some of these as standard tests - it tests a few of the exhaustive cases: https://python.langchain.com/docs/contributing/how_to/integrations/standard_tests/

sorry this is dragging on a bit! appreciate your help!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems cool, I've never heard of this before and would like to give it a try since it could make the tests more concise. But I am having trouble putting them into the test_tools.py file. The issue comes from being unable to import the ToolsUnitTests class.

Loading