Skip to content

Commit

Permalink
Allow react.Tool to wrap methods (#1856)
Browse files Browse the repository at this point in the history
The big reason for this is to pass parameters out-of-band, e.g. a
user_id to ensure the LLM doesn't get the wrong data.

The unit test includes a usage, you can't use it as a decorator this
way, but it works.

The alternative, of course, is to have a very long function and have all
the tools be nested functions. It works, but can lead to some very long
functions. I prefer long classes over long functions.
  • Loading branch information
tkellogg authored Nov 25, 2024
1 parent ff6f5a8 commit 1b10e23
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
2 changes: 1 addition & 1 deletion dspy/predict/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class Tool:
def __init__(self, func: Callable, name: str = None, desc: str = None, args: dict[str, Any] = None):
annotations_func = func if inspect.isfunction(func) else func.__call__
annotations_func = func if inspect.isfunction(func) or inspect.ismethod(func) else func.__call__
self.func = func
self.name = name or getattr(func, '__name__', type(func).__name__)
self.desc = desc or getattr(func, '__doc__', None) or getattr(annotations_func, '__doc__', "")
Expand Down
27 changes: 26 additions & 1 deletion tests/predict/test_react.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import dspy
from dspy.utils.dummies import DummyLM, dummy_rm
from dspy.predict import react


# def test_example_no_tools():
Expand Down Expand Up @@ -121,4 +122,28 @@
# react = dspy.ReAct(ExampleSignature)

# assert react.react[0].signature.instructions is not None
# assert react.react[0].signature.instructions.startswith("You are going to generate output based on input.")
# assert react.react[0].signature.instructions.startswith("You are going to generate output based on input.")

def test_tool_from_function():
def foo(a: int, b: int) -> int:
"""Add two numbers."""
return a + b

tool = react.Tool(foo)
assert tool.name == "foo"
assert tool.desc == "Add two numbers."
assert tool.args == {"a": "int", "b": "int"}

def test_tool_from_class():
class Foo:
def __init__(self, user_id: str):
self.user_id = user_id

def foo(self, a: int, b: int) -> int:
"""Add two numbers."""
return a + b

tool = react.Tool(Foo("123").foo)
assert tool.name == "foo"
assert tool.desc == "Add two numbers."
assert tool.args == {"a": "int", "b": "int"}

0 comments on commit 1b10e23

Please sign in to comment.