Skip to content

Commit

Permalink
code review
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda committed Jan 10, 2025
1 parent b2e50c8 commit 570c561
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 21 deletions.
54 changes: 35 additions & 19 deletions libs/langgraph/langgraph/prebuilt/chat_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from langgraph._api.deprecation import deprecated_parameter
from langgraph.errors import ErrorCode, create_error_message
from langgraph.graph import StateGraph
from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.message import add_messages
from langgraph.managed import IsLastStep, RemainingSteps
Expand All @@ -32,11 +32,14 @@
from langgraph.types import Checkpointer
from langgraph.utils.runnable import RunnableCallable

StructuredResponse = Union[dict, BaseModel]
StructuredResponseSchema = Union[dict, type[BaseModel]]


# We create the AgentState that we will pass around
# This simply involves a list of messages
# We want steps to return messages to append to the list
# So we annotate the messages attribute with operator.add
# So we annotate the messages attribute with `add_messages` reducer
class AgentState(TypedDict):
"""The state of the agent."""

Expand All @@ -46,7 +49,7 @@ class AgentState(TypedDict):

remaining_steps: RemainingSteps

structured_response: Union[dict, BaseModel]
structured_response: StructuredResponse


StateSchema = TypeVar("StateSchema", bound=AgentState)
Expand Down Expand Up @@ -226,7 +229,9 @@ def create_react_agent(
state_schema: Optional[StateSchemaType] = None,
messages_modifier: Optional[MessagesModifier] = None,
state_modifier: Optional[StateModifier] = None,
response_format: Optional[Union[dict, type[BaseModel]]] = None,
response_format: Optional[
Union[StructuredResponseSchema, tuple[str, StructuredResponseSchema]]
] = None,
checkpointer: Optional[Checkpointer] = None,
store: Optional[BaseStore] = None,
interrupt_before: Optional[list[str]] = None,
Expand Down Expand Up @@ -272,12 +277,15 @@ def create_react_agent(
- a JSON Schema,
- a TypedDict class,
- a Pydantic class,
- or a tuple (prompt, schema), where schema is one of the above.
The prompt will be used together with the model that is being used to generate the structured response.
!!! Important
`response_format` requires the model to support `.with_structured_output`
!!! Note
The graph will make a separate call to the LLM to generate the structured response after the agent loop is finished.
This is not the only strategy to get structured responses, see more options in [this guide](https://langchain-ai.github.io/langgraph/how-tos/react-agent-structured-output/).
checkpointer: An optional checkpoint saver object. This is used for persisting
the state of the graph (e.g., as chat memory) for a single thread (e.g., a single conversation).
store: An optional store object. This is used for persisting data
Expand Down Expand Up @@ -680,25 +688,35 @@ async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
def generate_structured_response(
    state: AgentState, config: RunnableConfig
) -> AgentState:
    """Make one extra LLM call to produce the structured response.

    Runs after the agent loop has finished. Supports `response_format`
    being either a bare schema or a `(prompt, schema)` tuple; in the tuple
    form the prompt is prepended to the conversation as a system message.

    Args:
        state: Current agent state; only `state["messages"]` is read.
        config: Runnable config forwarded to the structured-output model.

    Returns:
        A state update containing only the `structured_response` key.
    """
    # NOTE: we exclude the last message because there is enough information
    # for the LLM to generate the structured response
    messages = state["messages"][:-1]
    structured_response_schema = response_format
    if isinstance(response_format, tuple):
        # (prompt, schema) form: unpack and steer the model with the prompt.
        system_prompt, structured_response_schema = response_format
        messages = [SystemMessage(content=system_prompt)] + list(messages)

    model_with_structured_output = _get_model(model).with_structured_output(
        cast(StructuredResponseSchema, structured_response_schema)
    )
    response = model_with_structured_output.invoke(messages, config)
    return {"structured_response": response}

async def agenerate_structured_response(
    state: AgentState, config: RunnableConfig
) -> AgentState:
    """Async variant: make one extra LLM call for the structured response.

    Runs after the agent loop has finished. Supports `response_format`
    being either a bare schema or a `(prompt, schema)` tuple; in the tuple
    form the prompt is prepended to the conversation as a system message.

    Args:
        state: Current agent state; only `state["messages"]` is read.
        config: Runnable config forwarded to the structured-output model.

    Returns:
        A state update containing only the `structured_response` key.
    """
    # NOTE: we exclude the last message because there is enough information
    # for the LLM to generate the structured response
    messages = state["messages"][:-1]
    structured_response_schema = response_format
    if isinstance(response_format, tuple):
        # (prompt, schema) form: unpack and steer the model with the prompt.
        system_prompt, structured_response_schema = response_format
        messages = [SystemMessage(content=system_prompt)] + list(messages)

    model_with_structured_output = _get_model(model).with_structured_output(
        cast(StructuredResponseSchema, structured_response_schema)
    )
    response = await model_with_structured_output.ainvoke(messages, config)
    return {"structured_response": response}

if not tool_calling_enabled:
Expand Down Expand Up @@ -729,9 +747,7 @@ def should_continue(state: AgentState) -> str:
last_message = messages[-1]
# If there is no function call, then we finish
if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
return (
"__end__" if response_format is None else "generate_structured_response"
)
return END if response_format is None else "generate_structured_response"
# Otherwise if there is, we continue
else:
return "tools"
Expand All @@ -755,10 +771,10 @@ def should_continue(state: AgentState) -> str:
generate_structured_response, agenerate_structured_response
),
)
workflow.add_edge("generate_structured_response", "__end__")
workflow.add_edge("generate_structured_response", END)
should_continue_destinations = ["tools", "generate_structured_response"]
else:
should_continue_destinations = ["tools", "__end__"]
should_continue_destinations = ["tools", END]

# We now add a conditional edge
workflow.add_conditional_edges(
Expand All @@ -775,7 +791,7 @@ def route_tool_responses(state: AgentState) -> Literal["agent", "__end__"]:
if not isinstance(m, ToolMessage):
break
if m.name in should_return_direct:
return "__end__"
return END
return "agent"

if should_return_direct:
Expand Down
41 changes: 39 additions & 2 deletions libs/langgraph/tests/test_prebuilt.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from langchain_core.runnables import Runnable, RunnableLambda
from langchain_core.tools import BaseTool, ToolException
from langchain_core.tools import tool as dec_tool
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel, Field, ValidationError
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import ValidationError as ValidationErrorV1
from typing_extensions import TypedDict
Expand All @@ -47,7 +47,11 @@
create_react_agent,
tools_condition,
)
from langgraph.prebuilt.chat_agent_executor import AgentState, _validate_chat_history
from langgraph.prebuilt.chat_agent_executor import (
AgentState,
StructuredResponse,
_validate_chat_history,
)
from langgraph.prebuilt.tool_node import (
TOOL_CALL_ERROR_TEMPLATE,
InjectedState,
Expand All @@ -71,6 +75,7 @@

class FakeToolCallingModel(BaseChatModel):
tool_calls: Optional[list[list[ToolCall]]] = None
structured_response: Optional[StructuredResponse] = None
index: int = 0
tool_style: Literal["openai", "anthropic"] = "openai"

Expand Down Expand Up @@ -98,6 +103,14 @@ def _generate(
def _llm_type(self) -> str:
return "fake-tool-call-model"

def with_structured_output(
    self, schema: Type[BaseModel]
) -> Runnable[LanguageModelInput, StructuredResponse]:
    """Return a runnable that always yields the pre-configured structured response.

    The `schema` argument is accepted for interface compatibility but ignored;
    raises ValueError when no canned response was configured on the fake model.
    """
    if self.structured_response is None:
        raise ValueError("Structured response is not set")

    # Ignore the input entirely — the fake model replays the canned response.
    return RunnableLambda(lambda _input: self.structured_response)

def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
Expand Down Expand Up @@ -511,6 +524,30 @@ def handler(e: Union[str, int]):
_infer_handled_types(handler)


def test_react_agent_with_structured_response() -> None:
    """The agent surfaces the model's structured output under `structured_response`.

    Covers both accepted shapes of `response_format`: a bare schema and a
    `(prompt, schema)` tuple.
    """

    class WeatherResponse(BaseModel):
        temperature: float = Field(description="The temperature in fahrenheit")

    def get_weather():
        """Get the weather"""
        return "The weather is sunny and 75°F."

    # Turn 1: a single tool call; turn 2: no calls, ending the agent loop.
    scripted_tool_calls = [[{"args": {}, "id": "1", "name": "get_weather"}], []]
    canned_response = WeatherResponse(temperature=75)
    llm = FakeToolCallingModel(
        tool_calls=scripted_tool_calls, structured_response=canned_response
    )

    for fmt in (WeatherResponse, ("Meow", WeatherResponse)):
        graph = create_react_agent(llm, [get_weather], response_format=fmt)
        result = graph.invoke({"messages": [HumanMessage("What's the weather?")]})
        assert result["structured_response"] == canned_response
        # human + AI tool call + tool result + final AI message
        assert len(result["messages"]) == 4
        assert result["messages"][-2].content == "The weather is sunny and 75°F."


# tools for testing Too
def tool1(some_val: int, some_other_val: str) -> str:
"""Tool 1 docstring."""
Expand Down

0 comments on commit 570c561

Please sign in to comment.