langgraph: add structured output to create_react_agent #2848

Open · wants to merge 4 commits into base: main
105 changes: 99 additions & 6 deletions libs/langgraph/langgraph/prebuilt/chat_agent_executor.py
@@ -1,4 +1,13 @@
from typing import Callable, Literal, Optional, Sequence, Type, TypeVar, Union, cast
from typing import (
    Callable,
    Literal,
    Optional,
    Sequence,
    Type,
    TypeVar,
    Union,
    cast,
)

from langchain_core.language_models import BaseChatModel, LanguageModelLike
from langchain_core.messages import AIMessage, BaseMessage, SystemMessage, ToolMessage
@@ -8,6 +17,7 @@
    RunnableConfig,
)
from langchain_core.tools import BaseTool
from pydantic import BaseModel
from typing_extensions import Annotated, TypedDict

from langgraph._api.deprecation import deprecated_parameter
@@ -36,6 +46,8 @@

    remaining_steps: RemainingSteps

    structured_response: Union[dict, BaseModel]


StateSchema = TypeVar("StateSchema", bound=AgentState)
StateSchemaType = Type[StateSchema]
@@ -162,6 +174,19 @@
return False


def _get_model(model: LanguageModelLike) -> BaseChatModel:

Consider _unwrap_chat_model or _extract_chat_model or something similar that reflects the purpose of the function.

"""Get the underlying model from a RunnableBinding or return the model itself."""
if isinstance(model, RunnableBinding):
model = model.bound

if not isinstance(model, BaseChatModel):
raise TypeError(
f"Expected `model` to be a ChatModel or RunnableBinding (e.g. model.bind_tools(...)), got {type(model)}"
)

return model


def _validate_chat_history(
    messages: Sequence[BaseMessage],
) -> None:
@@ -201,6 +226,7 @@
    state_schema: Optional[StateSchemaType] = None,
    messages_modifier: Optional[MessagesModifier] = None,
    state_modifier: Optional[StateModifier] = None,
    response_format: Optional[Union[dict, type[BaseModel]]] = None,
Contributor

is there a unit test for this?

    checkpointer: Optional[Checkpointer] = None,
    store: Optional[BaseStore] = None,
    interrupt_before: Optional[list[str]] = None,
@@ -236,6 +262,22 @@
- str: This is converted to a SystemMessage and added to the beginning of the list of messages in state["messages"].
- Callable: This function should take in full graph state and the output is then passed to the language model.
- Runnable: This runnable should take in full graph state and the output is then passed to the language model.
response_format: An optional schema for the final agent output.

If provided, output will be formatted to match the given schema and returned in the 'structured_response' state key.
If not provided, `structured_response` will not be present in the output state.
Can be passed in as:

- an OpenAI function/tool schema,
- a JSON Schema,
- a TypedDict class,
- or a Pydantic class.

!!! Important
`response_format` requires the model to support `.with_structured_output`

!!! Note
The graph will make a separate call to the LLM to generate the structured response after the agent loop is finished.
Contributor

Does it make sense to make a note that other strategies are possible? (We could point to a guide on how to customize, or can also just not worry about it.)

checkpointer: An optional checkpoint saver object. This is used for persisting
the state of the graph (e.g., as chat memory) for a single thread (e.g., a single conversation).
store: An optional store object. This is used for persisting data
@@ -527,9 +569,11 @@
"""

    if state_schema is not None:
        if missing_keys := {"messages", "is_last_step"} - set(
            state_schema.__annotations__
        ):
        required_keys = {"messages", "remaining_steps"}
        if response_format is not None:
            required_keys.add("structured_response")
Contributor

did we decide on structured_response vs. parsed? I'm OK with current


        if missing_keys := required_keys - set(state_schema.__annotations__):
            raise ValueError(f"Missing required key(s) {missing_keys} in state_schema")

    if isinstance(tools, ToolExecutor):
@@ -629,11 +673,44 @@
        # We return a list, because this will get added to the existing list
        return {"messages": [response]}

    def generate_structured_response(
        state: AgentState, config: RunnableConfig
    ) -> AgentState:
        model_with_structured_output = _get_model(model).with_structured_output(
            cast(Union[dict, type[BaseModel]], response_format)

nit: Starting with PEP 604 one can use the shortcut dict | type[BaseModel].

By the way, you may be able to avoid this cast by adding a defensive check:

assert response_format is not None, "Internal error: calling generate_structured_response when response_format is None"

The | operator is only supported in Python 3.10+, never mind the suggestion.

        )
        # NOTE: we exclude the last message because there is enough information
        # for the LLM to generate the structured response
        response = model_with_structured_output.invoke(state["messages"][:-1], config)

Is there a downside in sending the full history? I am worried that this could cause bugs in the future.

return {"structured_response": response}

async def agenerate_structured_response(
state: AgentState, config: RunnableConfig
) -> AgentState:
model_with_structured_output = _get_model(model).with_structured_output(
cast(Union[dict, type[BaseModel]], response_format)
)
# NOTE: we exclude the last message because there is enough information
# for the LLM to generate the structured response
response = await model_with_structured_output.ainvoke(
state["messages"][:-1], config
)
return {"structured_response": response}

    if not tool_calling_enabled:
        # Define a new graph
        workflow = StateGraph(state_schema or AgentState)
        workflow.add_node("agent", RunnableCallable(call_model, acall_model))
        workflow.set_entry_point("agent")
        if response_format is not None:
            workflow.add_node(
                "generate_structured_response",
Contributor

could also use a shorter name like structure_response -- or have ChatGPT brainstorm a 1-2 word name

                RunnableCallable(
                    generate_structured_response, agenerate_structured_response
                ),
            )
            workflow.add_edge("agent", "generate_structured_response")

        return workflow.compile(
            checkpointer=checkpointer,
            store=store,
@@ -643,12 +720,14 @@
        )

    # Define the function that determines whether to continue or not
    def should_continue(state: AgentState) -> Literal["tools", "__end__"]:

Why not extend the Literal? This provides extra safety and readability.

    def should_continue(state: AgentState) -> str:
        messages = state["messages"]
        last_message = messages[-1]
        # If there is no function call, then we finish
        if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
            return "__end__"
            return (
                "__end__" if response_format is None else "generate_structured_response"
            )
        # Otherwise if there is, we continue
        else:
            return "tools"
@@ -664,13 +743,27 @@
    # This means that this node is the first one called
    workflow.set_entry_point("agent")

    # Add a structured output node if response_format is provided
    if response_format is not None:
        workflow.add_node(
            "generate_structured_response",
            RunnableCallable(
                generate_structured_response, agenerate_structured_response
            ),
        )
        workflow.add_edge("generate_structured_response", "__end__")

Consider using the END constant. This is both safer and helps with future refactoring.

        should_continue_destinations = ["tools", "generate_structured_response"]
    else:
        should_continue_destinations = ["tools", "__end__"]
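A sketch of the edge written with the constant suggested in the review above; END is exported by langgraph.graph (not part of this diff):

from langgraph.graph import END

workflow.add_edge("generate_structured_response", END)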

    # We now add a conditional edge
    workflow.add_conditional_edges(
        # First, we define the start node. We use `agent`.
        # This means these are the edges taken after the `agent` node is called.
        "agent",
        # Next, we pass in the function that will determine which node is called next.
        should_continue,
        path_map=should_continue_destinations,
    )

    # If any of the tools are configured to return_directly after running,
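For orientation, a minimal usage sketch of the new response_format parameter; the chat model, tool, and schema below are illustrative assumptions and are not part of this PR:

from langchain_core.tools import tool
from langchain_openai import ChatOpenAI  # assumed provider; any model supporting .with_structured_output works
from pydantic import BaseModel

from langgraph.prebuilt import create_react_agent


class WeatherResponse(BaseModel):
    conditions: str
    temperature_celsius: float


@tool
def get_weather(city: str) -> str:
    """Return a canned weather report for the given city."""
    return f"It is always sunny in {city}, around 22 degrees Celsius."


agent = create_react_agent(
    ChatOpenAI(model="gpt-4o-mini"),
    tools=[get_weather],
    response_format=WeatherResponse,
)

result = agent.invoke({"messages": [("user", "What is the weather in Paris?")]})
structured = result["structured_response"]  # WeatherResponse instance

After the agent loop finishes, the graph makes one extra call to the model via .with_structured_output to populate structured_response, as the docstring note above describes.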
4 changes: 2 additions & 2 deletions libs/langgraph/tests/__snapshots__/test_large_cases.ambr
@@ -2360,10 +2360,10 @@
'''
# ---
# name: test_prebuilt_tool_chat
'{"$defs": {"BaseMessage": {"additionalProperties": true, "description": "Base abstract message class.\\n\\nMessages are the inputs and outputs of ChatModels.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content", "type"], "title": "BaseMessage", "type": "object"}}, "properties": {"messages": {"items": {"$ref": "#/$defs/BaseMessage"}, "title": "Messages", "type": "array"}}, "required": ["messages"], "title": "LangGraphInput", "type": "object"}'
'{"$defs": {"BaseMessage": {"additionalProperties": true, "description": "Base abstract message class.\\n\\nMessages are the inputs and outputs of ChatModels.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content", "type"], "title": "BaseMessage", "type": "object"}, "BaseModel": {"properties": {}, "title": "BaseModel", "type": "object"}}, "properties": {"messages": {"items": {"$ref": "#/$defs/BaseMessage"}, "title": "Messages", "type": "array"}, "structured_response": {"anyOf": [{"type": "object"}, {"$ref": "#/$defs/BaseModel"}], "title": "Structured Response"}}, "required": ["messages", "structured_response"], "title": "LangGraphInput", "type": "object"}'
# ---
# name: test_prebuilt_tool_chat.1
'{"$defs": {"BaseMessage": {"additionalProperties": true, "description": "Base abstract message class.\\n\\nMessages are the inputs and outputs of ChatModels.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content", "type"], "title": "BaseMessage", "type": "object"}}, "properties": {"messages": {"items": {"$ref": "#/$defs/BaseMessage"}, "title": "Messages", "type": "array"}}, "required": ["messages"], "title": "LangGraphOutput", "type": "object"}'
'{"$defs": {"BaseMessage": {"additionalProperties": true, "description": "Base abstract message class.\\n\\nMessages are the inputs and outputs of ChatModels.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content", "type"], "title": "BaseMessage", "type": "object"}, "BaseModel": {"properties": {}, "title": "BaseModel", "type": "object"}}, "properties": {"messages": {"items": {"$ref": "#/$defs/BaseMessage"}, "title": "Messages", "type": "array"}, "structured_response": {"anyOf": [{"type": "object"}, {"$ref": "#/$defs/BaseModel"}], "title": "Structured Response"}}, "required": ["messages", "structured_response"], "title": "LangGraphOutput", "type": "object"}'
# ---
# name: test_prebuilt_tool_chat.2
'''