Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

langgraph: check if model passed as runnable binding with tools in create_react_agent #1647

Merged
merged 12 commits into from
Sep 30, 2024
56 changes: 49 additions & 7 deletions libs/langgraph/langgraph/prebuilt/chat_agent_executor.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from typing import Callable, Literal, Optional, Sequence, Type, TypeVar, Union
from typing import Callable, Literal, Optional, Sequence, Type, TypeVar, Union, cast

Check notice on line 1 in libs/langgraph/langgraph/prebuilt/chat_agent_executor.py

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

......................................... fanout_to_subgraph_10x: Mean +- std dev: 59.3 ms +- 1.6 ms ......................................... WARNING: the benchmark result may be unstable * the standard deviation (6.32 ms) is 11% of the mean (56.6 ms) Try to rerun the benchmark with more runs, values and/or loops. Run 'python -m pyperf system tune' command to reduce the system jitter. Use pyperf stats, pyperf dump and pyperf hist to analyze results. Use --quiet option to hide these warnings. fanout_to_subgraph_10x_sync: Mean +- std dev: 56.6 ms +- 6.3 ms ......................................... fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 77.2 ms +- 0.9 ms ......................................... fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 81.8 ms +- 0.6 ms ......................................... fanout_to_subgraph_100x: Mean +- std dev: 554 ms +- 13 ms ......................................... fanout_to_subgraph_100x_sync: Mean +- std dev: 504 ms +- 4 ms ......................................... fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 749 ms +- 17 ms ......................................... fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 788 ms +- 6 ms ......................................... react_agent_10x: Mean +- std dev: 41.3 ms +- 3.1 ms ......................................... react_agent_10x_sync: Mean +- std dev: 29.8 ms +- 0.2 ms ......................................... react_agent_10x_checkpoint: Mean +- std dev: 53.0 ms +- 1.2 ms ......................................... react_agent_10x_checkpoint_sync: Mean +- std dev: 43.3 ms +- 3.2 ms ......................................... react_agent_100x: Mean +- std dev: 414 ms +- 8 ms ......................................... react_agent_100x_sync: Mean +- std dev: 331 ms +- 3 ms ......................................... react_agent_100x_checkpoint: Mean +- std dev: 923 ms +- 12 ms ......................................... react_agent_100x_checkpoint_sync: Mean +- std dev: 822 ms +- 7 ms ......................................... wide_state_25x300: Mean +- std dev: 20.5 ms +- 0.3 ms ......................................... wide_state_25x300_sync: Mean +- std dev: 12.8 ms +- 0.1 ms ......................................... wide_state_25x300_checkpoint: Mean +- std dev: 238 ms +- 8 ms ......................................... wide_state_25x300_checkpoint_sync: Mean +- std dev: 237 ms +- 13 ms ......................................... wide_state_15x600: Mean +- std dev: 23.7 ms +- 0.3 ms ......................................... wide_state_15x600_sync: Mean +- std dev: 14.8 ms +- 0.1 ms ......................................... wide_state_15x600_checkpoint: Mean +- std dev: 417 ms +- 15 ms ......................................... wide_state_15x600_checkpoint_sync: Mean +- std dev: 419 ms +- 17 ms ......................................... wide_state_9x1200: Mean +- std dev: 23.7 ms +- 0.5 ms ......................................... wide_state_9x1200_sync: Mean +- std dev: 14.8 ms +- 0.2 ms ......................................... wide_state_9x1200_checkpoint: Mean +- std dev: 269 ms +- 8 ms ......................................... wide_state_9x1200_checkpoint_sync: Mean +- std dev: 269 ms +- 15 ms

Check notice on line 1 in libs/langgraph/langgraph/prebuilt/chat_agent_executor.py

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+----------------------------------------+---------+-----------------------+ | Benchmark | main | changes | +========================================+=========+=======================+ | fanout_to_subgraph_100x_checkpoint | 783 ms | 749 ms: 1.05x faster | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x | 571 ms | 554 ms: 1.03x faster | +----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint_sync | 836 ms | 822 ms: 1.02x faster | +----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint | 928 ms | 923 ms: 1.01x faster | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint_sync | 81.5 ms | 81.8 ms: 1.00x slower | +----------------------------------------+---------+-----------------------+ | wide_state_9x1200_sync | 14.8 ms | 14.8 ms: 1.00x slower | +----------------------------------------+---------+-----------------------+ | wide_state_15x600 | 23.6 ms | 23.7 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | wide_state_25x300 | 20.4 ms | 20.5 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint | 76.6 ms | 77.2 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | wide_state_9x1200 | 23.4 ms | 23.7 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | Geometric mean | (ref) | 1.00x faster | +----------------------------------------+---------+-----------------------+ Benchmark hidden because not significant (18): wide_state_25x300_checkpoint_sync, fanout_to_subgraph_10x_sync, react_agent_10x, wide_state_9x1200_checkpoint, wide_state_25x300_checkpoint, wide_state_15x600_checkpoint, wide_state_15x600_checkpoint_sync, wide_state_25x300_sync, react_agent_10x_sync, fanout_to_subgraph_100x_checkpoint_sync, react_agent_100x_sync, fanout_to_subgraph_10x, wide_state_15x600_sync, fanout_to_subgraph_100x_sync, wide_state_9x1200_checkpoint_sync, react_agent_10x_checkpoint, react_agent_100x, react_agent_10x_checkpoint_sync

from langchain_core.language_models import BaseChatModel
from langchain_core.language_models import BaseChatModel, LanguageModelLike
from langchain_core.messages import AIMessage, BaseMessage, SystemMessage, ToolMessage
from langchain_core.runnables import Runnable, RunnableConfig, RunnableLambda
from langchain_core.runnables import (
Runnable,
RunnableBinding,
RunnableConfig,
RunnableLambda,
)
from langchain_core.tools import BaseTool
from typing_extensions import Annotated, TypedDict

Expand Down Expand Up @@ -115,9 +120,43 @@
return _get_state_modifier_runnable(state_modifier)


def _should_bind_tools(model: LanguageModelLike, tools: Sequence[BaseTool]) -> bool:
if not isinstance(model, RunnableBinding):
return False

if "tools" not in model.kwargs:
return False

bound_tools = model.kwargs["tools"]
if len(tools) != len(bound_tools):
raise ValueError(
"Number of tools in the model.bind_tools() and tools passed to create_react_agent must match"
)

tool_names = set(tool.name for tool in tools)
bound_tool_names = set()
for bound_tool in bound_tools:
# OpenAI-style tool
if bound_tool.get("type") == "function":
bound_tool_name = bound_tool["function"]["name"]
# Anthropic-style tool
elif bound_tool.get("name"):
bound_tool_name = bound_tool["name"]
else:
# unknown tool type so we'll ignore it
continue

bound_tool_names.add(bound_tool_name)

if missing_tools := tool_names - bound_tool_names:
raise ValueError(f"Missing tools '{missing_tools}' in the model.bind_tools()")

return True


@deprecated_parameter("messages_modifier", "0.1.9", "state_modifier", removal="0.3.0")
def create_react_agent(
model: BaseChatModel,
model: LanguageModelLike,
tools: Union[ToolExecutor, Sequence[BaseTool], ToolNode],
*,
state_schema: Optional[StateSchemaType] = None,
Expand Down Expand Up @@ -412,9 +451,12 @@
tool_classes = list(tools.tools_by_name.values())
tool_node = tools
else:
tool_classes = tools
tool_node = ToolNode(tool_classes)
model = model.bind_tools(tool_classes)
tool_node = ToolNode(tools)
# get the tool functions wrapped in a tool class from the ToolNode
tool_classes = list(tool_node.tools_by_name.values())

if _should_bind_tools(model, tool_classes):
model = cast(BaseChatModel, model).bind_tools(tool_classes)

# Define the function that determines whether to continue or not
def should_continue(state: AgentState) -> Literal["tools", "__end__"]:
Expand Down
80 changes: 79 additions & 1 deletion libs/langgraph/tests/test_prebuilt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Type,
Expand Down Expand Up @@ -49,6 +50,7 @@
class FakeToolCallingModel(BaseChatModel):
tool_calls: Optional[list[list[ToolCall]]] = None
index: int = 0
tool_style: Literal["openai", "anthropic"] = "openai"

def _generate(
self,
Expand Down Expand Up @@ -79,7 +81,31 @@ def bind_tools(
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
return self
tool_dicts = []
for tool in tools:
if not isinstance(tool, BaseTool):
raise TypeError(
"Only BaseTool is supported by FakeToolCallingModel.bind_tools"
)

# NOTE: this is a simplified tool spec for testing purposes only
if self.tool_style == "openai":
tool_dicts.append(
{
"type": "function",
"function": {
"name": tool.name,
},
}
)
elif self.tool_style == "anthropic":
tool_dicts.append(
{
"name": tool.name,
}
)

return self.bind(tools=tool_dicts)


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
Expand Down Expand Up @@ -242,6 +268,58 @@ def test_runnable_state_modifier():
assert response == expected_response


@pytest.mark.parametrize("tool_style", ["openai", "anthropic"])
def test_model_with_tools(tool_style: str):
model = FakeToolCallingModel(tool_style=tool_style)

@dec_tool
def tool1(some_val: int) -> str:
"""Tool 1 docstring."""
return f"Tool 1: {some_val}"

@dec_tool
def tool2(some_val: int) -> str:
"""Tool 2 docstring."""
return f"Tool 2: {some_val}"

# check valid agent constructor
agent = create_react_agent(model.bind_tools([tool1, tool2]), [tool1, tool2])
result = agent.nodes["tools"].invoke(
{
"messages": [
AIMessage(
"hi?",
tool_calls=[
{
"name": "tool1",
"args": {"some_val": 2},
"id": "some 1",
},
{
"name": "tool2",
"args": {"some_val": 2},
"id": "some 2",
},
],
)
]
}
)
tool_messages: ToolMessage = result["messages"][-2:]
for tool_message in tool_messages:
assert tool_message.type == "tool"
assert tool_message.content in {"Tool 1: 2", "Tool 2: 2"}
assert tool_message.tool_call_id in {"some 1", "some 2"}

# test mismatching tool lengths
with pytest.raises(ValueError):
create_react_agent(model.bind_tools([tool1]), [tool1, tool2])

# test missing bound tools
with pytest.raises(ValueError):
create_react_agent(model.bind_tools([tool1]), [tool2])


async def test_tool_node():
def tool1(some_val: int, some_other_val: str) -> str:
"""Tool 1 docstring."""
Expand Down
Loading