Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

langgraph: check if model passed as runnable binding with tools in create_react_agent #1647

Merged
merged 12 commits into from
Sep 30, 2024
45 changes: 43 additions & 2 deletions libs/langgraph/langgraph/prebuilt/chat_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
BaseMessage,
SystemMessage,
)
from langchain_core.runnables import Runnable, RunnableConfig, RunnableLambda
from langchain_core.runnables import (
Runnable,
RunnableBinding,
RunnableConfig,
RunnableLambda,
)
from langchain_core.tools import BaseTool

from langgraph._api.deprecation import deprecated_parameter
Expand Down Expand Up @@ -127,6 +132,40 @@ def _get_model_preprocessing_runnable(
return _get_state_modifier_runnable(state_modifier)


def _should_bind_tools(model: LanguageModelLike, tools: Sequence[BaseTool]) -> bool:
if not isinstance(model, RunnableBinding):
return False

if "tools" not in model.kwargs:
return False

bound_tools = model.kwargs["tools"]
if len(tools) != len(bound_tools):
raise ValueError(
"Number of tools in the model.bind_tools() and tools passed to create_react_agent must match"
)

tool_names = set(tool.name for tool in tools)
bound_tool_names = set()
for bound_tool in bound_tools:
# OpenAI-style tool
if bound_tool.get("type") == "function":
bound_tool_name = bound_tool["function"]["name"]
# Anthropic-style tool
elif bound_tool.get("name"):
bound_tool_name = bound_tool["name"]
else:
# unknown tool type so we'll ignore it
continue

bound_tool_names.add(bound_tool_name)

if missing_tools := tool_names - bound_tool_names:
raise ValueError(f"Missing tools '{missing_tools}' in the model.bind_tools()")

return True


@deprecated_parameter("messages_modifier", "0.1.9", "state_modifier", removal="0.3.0")
def create_react_agent(
model: LanguageModelLike,
Expand Down Expand Up @@ -426,7 +465,9 @@ class Agent,Tools otherClass
else:
tool_classes = tools
tool_node = ToolNode(tool_classes)
vbarda marked this conversation as resolved.
Show resolved Hide resolved
model = model.bind_tools(tool_classes)

if _should_bind_tools(model, tool_classes):
model = model.bind_tools(tool_classes)

# Define the function that determines whether to continue or not
def should_continue(state: AgentState):
Expand Down
Loading