langgraph: add structured output to create_react_agent #2848
base: main
File: libs/langgraph/langgraph/prebuilt/chat_agent_executor.py

@@ -1,4 +1,13 @@
-from typing import Callable, Literal, Optional, Sequence, Type, TypeVar, Union, cast
+from typing import (
+    Callable,
+    Literal,
+    Optional,
+    Sequence,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+)
 
 from langchain_core.language_models import BaseChatModel, LanguageModelLike
 from langchain_core.messages import AIMessage, BaseMessage, SystemMessage, ToolMessage
@@ -8,6 +17,7 @@
     RunnableConfig,
 )
 from langchain_core.tools import BaseTool
+from pydantic import BaseModel
 from typing_extensions import Annotated, TypedDict
 
 from langgraph._api.deprecation import deprecated_parameter
@@ -36,6 +46,8 @@
 
     remaining_steps: RemainingSteps
 
+    structured_response: Union[dict, BaseModel]
+
 
 StateSchema = TypeVar("StateSchema", bound=AgentState)
 StateSchemaType = Type[StateSchema]
@@ -162,6 +174,19 @@
     return False
 
 
+def _get_model(model: LanguageModelLike) -> BaseChatModel:
+    """Get the underlying model from a RunnableBinding or return the model itself."""
+    if isinstance(model, RunnableBinding):
+        model = model.bound
+
+    if not isinstance(model, BaseChatModel):
+        raise TypeError(
+            f"Expected `model` to be a ChatModel or RunnableBinding (e.g. model.bind_tools(...)), got {type(model)}"
+        )
+
+    return model
+
+
 def _validate_chat_history(
     messages: Sequence[BaseMessage],
 ) -> None:
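For orientation, a hedged sketch of how this helper behaves (not part of the diff; ChatOpenAI stands in for any BaseChatModel):

    from langchain_openai import ChatOpenAI

    model = ChatOpenAI(model="gpt-4o-mini")
    bound = model.bind_tools([])  # bind_tools returns a RunnableBinding wrapping the model

    assert _get_model(model) is model  # a plain chat model passes straight through
    assert _get_model(bound) is model  # a RunnableBinding is unwrapped via `.bound`
    # anything else (e.g. a plain string) raises TypeError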
@@ -201,6 +226,7 @@
     state_schema: Optional[StateSchemaType] = None,
     messages_modifier: Optional[MessagesModifier] = None,
     state_modifier: Optional[StateModifier] = None,
+    response_format: Optional[Union[dict, type[BaseModel]]] = None,

[Review comment] is there a unit test for this?

     checkpointer: Optional[Checkpointer] = None,
     store: Optional[BaseStore] = None,
     interrupt_before: Optional[list[str]] = None,
@@ -236,6 +262,22 @@
             - str: This is converted to a SystemMessage and added to the beginning of the list of messages in state["messages"].
             - Callable: This function should take in full graph state and the output is then passed to the language model.
             - Runnable: This runnable should take in full graph state and the output is then passed to the language model.
+        response_format: An optional schema for the final agent output.
+
+            If provided, output will be formatted to match the given schema and returned in the 'structured_response' state key.
+            If not provided, `structured_response` will not be present in the output state.
+            Can be passed in as:
+
+                - an OpenAI function/tool schema,
+                - a JSON Schema,
+                - a TypedDict class,
+                - or a Pydantic class.
+
+            !!! Important
+                `response_format` requires the model to support `.with_structured_output`.
+
+            !!! Note
+                The graph will make a separate call to the LLM to generate the structured response after the agent loop is finished.

[Review comment] Does it make sense to make a note that other strategies are possible? (We could note to a guide on how to customize? Or can also just not worry about it.)

         checkpointer: An optional checkpoint saver object. This is used for persisting
             the state of the graph (e.g., as chat memory) for a single thread (e.g., a single conversation).
         store: An optional store object. This is used for persisting data
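To make the new parameter concrete, here is a hedged usage sketch; the model, tool, and schema names are illustrative and not part of this PR, and it assumes a chat model that implements `.with_structured_output` (e.g. ChatOpenAI):

    from langchain_openai import ChatOpenAI
    from pydantic import BaseModel

    from langgraph.prebuilt import create_react_agent

    class WeatherReport(BaseModel):
        city: str
        conditions: str

    def get_weather(city: str) -> str:
        """Return a canned weather string for the given city."""
        return f"It is sunny and 21C in {city}."

    agent = create_react_agent(
        ChatOpenAI(model="gpt-4o-mini"),
        tools=[get_weather],
        response_format=WeatherReport,
    )

    result = agent.invoke({"messages": [("user", "How is the weather in Paris?")]})
    result["structured_response"]  # -> WeatherReport(city="Paris", conditions="sunny")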
@@ -527,9 +569,11 @@
     """
 
     if state_schema is not None:
-        if missing_keys := {"messages", "is_last_step"} - set(
-            state_schema.__annotations__
-        ):
+        required_keys = {"messages", "remaining_steps"}
+        if response_format is not None:
+            required_keys.add("structured_response")

[Review comment] did we decide on structured_response vs. parsed? I'm OK with current.

+
+        if missing_keys := required_keys - set(state_schema.__annotations__):
             raise ValueError(f"Missing required key(s) {missing_keys} in state_schema")
 
     if isinstance(tools, ToolExecutor):
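Concretely, a custom state_schema now needs these keys. A hedged sketch of a minimal compliant schema (the class name `MyState` and the `user_id` field are hypothetical; the `RemainingSteps` import path is an assumption based on this module's own imports):

    from typing import Annotated, Sequence, Union

    from langchain_core.messages import BaseMessage
    from pydantic import BaseModel
    from typing_extensions import TypedDict

    from langgraph.graph.message import add_messages
    from langgraph.managed import RemainingSteps

    class MyState(TypedDict):
        messages: Annotated[Sequence[BaseMessage], add_messages]
        remaining_steps: RemainingSteps
        # required only when response_format is passed to create_react_agent
        structured_response: Union[dict, BaseModel]
        user_id: str  # hypothetical extra field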
@@ -629,11 +673,44 @@
         # We return a list, because this will get added to the existing list
         return {"messages": [response]}
 
+    def generate_structured_response(
+        state: AgentState, config: RunnableConfig
+    ) -> AgentState:
+        model_with_structured_output = _get_model(model).with_structured_output(
+            cast(Union[dict, type[BaseModel]], response_format)
+        )

[Review comment] nit: starting with PEP 604 one can use the `|` shorthand. By the way, you may get away without using this cast by adding a defensive check.

+        # NOTE: we exclude the last message because there is enough information
+        # for the LLM to generate the structured response
+        response = model_with_structured_output.invoke(state["messages"][:-1], config)

[Review comment] Is there a downside in sending the full history? I am worried that this could cause bugs in the future.

+        return {"structured_response": response}
+
+    async def agenerate_structured_response(
+        state: AgentState, config: RunnableConfig
+    ) -> AgentState:
+        model_with_structured_output = _get_model(model).with_structured_output(
+            cast(Union[dict, type[BaseModel]], response_format)
+        )
+        # NOTE: we exclude the last message because there is enough information
+        # for the LLM to generate the structured response
+        response = await model_with_structured_output.ainvoke(
+            state["messages"][:-1], config
+        )
+        return {"structured_response": response}
+
     if not tool_calling_enabled:
         # Define a new graph
         workflow = StateGraph(state_schema or AgentState)
         workflow.add_node("agent", RunnableCallable(call_model, acall_model))
         workflow.set_entry_point("agent")
+        if response_format is not None:
+            workflow.add_node(
+                "generate_structured_response",
+                RunnableCallable(
+                    generate_structured_response, agenerate_structured_response
+                ),
+            )

[Review comment] could also use a shorter name like …

+            workflow.add_edge("agent", "generate_structured_response")
+
         return workflow.compile(
             checkpointer=checkpointer,
             store=store,
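In effect, the new terminal node makes one extra structured-output call over the accumulated history, minus the final answer. A hedged sketch of the equivalent standalone call, reusing the illustrative `WeatherReport` schema from the earlier sketch:

    structured_model = _get_model(model).with_structured_output(WeatherReport)
    structured = structured_model.invoke(state["messages"][:-1])
    # `structured` is a WeatherReport instance (or a dict for dict schemas);
    # the node stores it under state["structured_response"]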
@@ -643,12 +720,14 @@
     )
 
     # Define the function that determines whether to continue or not
-    def should_continue(state: AgentState) -> Literal["tools", "__end__"]:
+    def should_continue(state: AgentState) -> str:

[Review comment] Why not extend the …

         messages = state["messages"]
         last_message = messages[-1]
         # If there is no function call, then we finish
         if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
-            return "__end__"
+            return (
+                "__end__" if response_format is None else "generate_structured_response"
+            )
         # Otherwise if there is, we continue
         else:
             return "tools"
@@ -664,13 +743,27 @@
     # This means that this node is the first one called
     workflow.set_entry_point("agent")
 
+    # Add a structured output node if response_format is provided
+    if response_format is not None:
+        workflow.add_node(
+            "generate_structured_response",
+            RunnableCallable(
+                generate_structured_response, agenerate_structured_response
+            ),
+        )
+        workflow.add_edge("generate_structured_response", "__end__")

[Review comment] Consider using the …

+        should_continue_destinations = ["tools", "generate_structured_response"]
+    else:
+        should_continue_destinations = ["tools", "__end__"]
+
     # We now add a conditional edge
     workflow.add_conditional_edges(
         # First, we define the start node. We use `agent`.
         # This means these are the edges taken after the `agent` node is called.
         "agent",
         # Next, we pass in the function that will determine which node is called next.
         should_continue,
+        path_map=should_continue_destinations,
     )
 
     # If any of the tools are configured to return_directly after running,
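Putting the edges together, the intended control flow of the compiled graph (a sketch derived from this diff, not tool output):

    agent --(tool calls)--> tools --> agent
    agent --(no tool calls, response_format set)--> generate_structured_response --> __end__
    agent --(no tool calls, no response_format)--> __end__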
[Review comment on `_get_model`] Consider `_unwrap_chat_model` or `_extract_chat_model` or something similar that reflects the purpose of the function.