langgraph[patch]: format messages in state #2199

Merged
16 commits merged on Dec 18, 2024
125 changes: 123 additions & 2 deletions libs/langgraph/langgraph/graph/message.py
@@ -1,8 +1,21 @@
import uuid

GitHub Actions / benchmark: Benchmark results

WARNING: the benchmark result may be unstable
* the standard deviation (6.27 ms) is 11% of the mean (54.8 ms)
Try to rerun the benchmark with more runs, values and/or loops. Run 'python -m pyperf system tune' to reduce system jitter. Use pyperf stats, pyperf dump and pyperf hist to analyze results. Use --quiet to hide these warnings.

fanout_to_subgraph_10x: Mean +- std dev: 54.8 ms +- 6.3 ms
fanout_to_subgraph_10x_sync: Mean +- std dev: 46.5 ms +- 3.1 ms
fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 76.8 ms +- 1.6 ms
fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 84.8 ms +- 0.7 ms
fanout_to_subgraph_100x: Mean +- std dev: 469 ms +- 9 ms
fanout_to_subgraph_100x_sync: Mean +- std dev: 429 ms +- 5 ms
fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 799 ms +- 46 ms
fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 833 ms +- 16 ms
react_agent_10x: Mean +- std dev: 29.0 ms +- 0.6 ms
react_agent_10x_sync: Mean +- std dev: 22.4 ms +- 1.6 ms
react_agent_10x_checkpoint: Mean +- std dev: 47.6 ms +- 3.2 ms
react_agent_10x_checkpoint_sync: Mean +- std dev: 37.5 ms +- 3.2 ms
react_agent_100x: Mean +- std dev: 328 ms +- 14 ms
react_agent_100x_sync: Mean +- std dev: 262 ms +- 12 ms
react_agent_100x_checkpoint: Mean +- std dev: 919 ms +- 7 ms
react_agent_100x_checkpoint_sync: Mean +- std dev: 821 ms +- 6 ms
wide_state_25x300: Mean +- std dev: 18.3 ms +- 0.3 ms
wide_state_25x300_sync: Mean +- std dev: 10.8 ms +- 0.1 ms
wide_state_25x300_checkpoint: Mean +- std dev: 271 ms +- 3 ms
wide_state_25x300_checkpoint_sync: Mean +- std dev: 261 ms +- 4 ms
wide_state_15x600: Mean +- std dev: 21.2 ms +- 0.4 ms
wide_state_15x600_sync: Mean +- std dev: 12.5 ms +- 0.1 ms
wide_state_15x600_checkpoint: Mean +- std dev: 470 ms +- 7 ms
wide_state_15x600_checkpoint_sync: Mean +- std dev: 467 ms +- 15 ms
wide_state_9x1200: Mean +- std dev: 21.2 ms +- 0.4 ms
wide_state_9x1200_sync: Mean +- std dev: 12.5 ms +- 0.1 ms
wide_state_9x1200_checkpoint: Mean +- std dev: 305 ms +- 4 ms
wide_state_9x1200_checkpoint_sync: Mean +- std dev: 301 ms +- 13 ms

GitHub Actions / benchmark: Comparison against main

| Benchmark                                | main    | changes               |
|------------------------------------------|---------|-----------------------|
| fanout_to_subgraph_100x                  | 491 ms  | 469 ms: 1.05x faster  |
| fanout_to_subgraph_100x_checkpoint       | 835 ms  | 799 ms: 1.04x faster  |
| react_agent_100x                         | 335 ms  | 328 ms: 1.02x faster  |
| react_agent_100x_checkpoint_sync         | 836 ms  | 821 ms: 1.02x faster  |
| fanout_to_subgraph_100x_checkpoint_sync  | 847 ms  | 833 ms: 1.02x faster  |
| fanout_to_subgraph_10x_checkpoint_sync   | 86.1 ms | 84.8 ms: 1.01x faster |
| wide_state_25x300                        | 18.5 ms | 18.3 ms: 1.01x faster |
| wide_state_25x300_sync                   | 10.9 ms | 10.8 ms: 1.01x faster |
| wide_state_15x600_checkpoint             | 474 ms  | 470 ms: 1.01x faster  |
| wide_state_9x1200_checkpoint             | 307 ms  | 305 ms: 1.01x faster  |
| wide_state_15x600_sync                   | 12.6 ms | 12.5 ms: 1.01x faster |
| wide_state_25x300_checkpoint             | 273 ms  | 271 ms: 1.01x faster  |
| wide_state_15x600                        | 21.4 ms | 21.2 ms: 1.01x faster |
| wide_state_9x1200                        | 21.3 ms | 21.2 ms: 1.00x faster |
| wide_state_9x1200_sync                   | 12.6 ms | 12.5 ms: 1.00x faster |
| fanout_to_subgraph_10x                   | 48.5 ms | 54.8 ms: 1.13x slower |
| Geometric mean                           | (ref)   | 1.01x faster          |

Benchmarks hidden because not significant (12): react_agent_10x_checkpoint_sync, react_agent_10x_checkpoint, react_agent_100x_sync, wide_state_9x1200_checkpoint_sync, react_agent_10x_sync, fanout_to_subgraph_10x_checkpoint, wide_state_15x600_checkpoint_sync, fanout_to_subgraph_100x_sync, wide_state_25x300_checkpoint_sync, react_agent_10x, fanout_to_subgraph_10x_sync, react_agent_100x_checkpoint
from typing import Annotated, TypedDict, Union, cast
import warnings
from functools import partial
from typing import (
Annotated,
Any,
Callable,
Literal,
Optional,
Sequence,
TypedDict,
Union,
cast,
)

from langchain_core.messages import (
AnyMessage,
BaseMessage,
BaseMessageChunk,
MessageLikeRepresentation,
RemoveMessage,
@@ -15,7 +28,29 @@
Messages = Union[list[MessageLikeRepresentation], MessageLikeRepresentation]


def add_messages(left: Messages, right: Messages) -> Messages:
def _add_messages_wrapper(func: Callable) -> Callable[[Messages, Messages], Messages]:
def _add_messages(
left: Optional[Messages] = None, right: Optional[Messages] = None, **kwargs: Any
) -> Union[Messages, Callable[[Messages, Messages], Messages]]:
if left is not None and right is not None:
return func(left, right, **kwargs)
elif left is not None or right is not None:
msg = ""
raise ValueError(msg)
else:
return partial(func, **kwargs)

_add_messages.__doc__ = func.__doc__
return cast(Callable[[Messages, Messages], Messages], _add_messages)
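
The wrapper above lets `add_messages` serve double duty: called with two message lists it behaves as an ordinary reducer, while called with only keyword arguments it returns a pre-configured reducer via `functools.partial`. A minimal sketch of the two call patterns this enables (the message contents are illustrative):

```python
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.graph.message import add_messages

# Used directly, it is an ordinary two-argument reducer.
merged = add_messages([HumanMessage("hi")], [AIMessage("hello!")])

# Called with keyword arguments only, it returns a configured reducer
# (a functools.partial), which is what the
# Annotated[list, add_messages(format="langchain-openai")] pattern in the
# docstring below relies on.
openai_reducer = add_messages(format="langchain-openai")
merged_openai = openai_reducer([HumanMessage("hi")], [AIMessage("hello!")])
```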


@_add_messages_wrapper
def add_messages(
left: Messages,
right: Messages,
*,
format: Optional[Literal["langchain-openai"]] = None,
) -> Messages:
"""Merges two lists of messages, updating existing messages by ID.

By default, this ensures the state is "append-only", unless the
@@ -25,6 +60,14 @@
left: The base list of messages.
right: The list of messages (or single message) to merge
into the base list.
format: The format to return messages in. If None, messages are returned
as is. If 'langchain-openai', messages are returned as BaseMessage objects
with their contents formatted to match the OpenAI message format: contents
may be a string, 'text' blocks, or 'image_url' blocks, and tool responses
are returned as their own ToolMessages.

**REQUIREMENT**: Must have ``langchain-core>=0.3.11`` installed to use this
feature.

Returns:
A new list of messages with the messages from `right` merged into `left`.
@@ -58,8 +101,59 @@
>>> graph = builder.compile()
>>> graph.invoke({})
{'messages': [AIMessage(content='Hello', id=...)]}

>>> from typing import Annotated
>>> from typing_extensions import TypedDict
>>> from langgraph.graph import StateGraph, add_messages
>>>
>>> class State(TypedDict):
... messages: Annotated[list, add_messages(format='langchain-openai')]
...
>>> def chatbot_node(state: State) -> list:
... return {"messages": [
... {
... "role": "user",
... "content": [
... {
... "type": "text",
... "text": "Here's an image:",
... "cache_control": {"type": "ephemeral"},
... },
... {
... "type": "image",
... "source": {
... "type": "base64",
... "media_type": "image/jpeg",
... "data": "1234",
... },
... },
... ]
... },
... ]}
>>> builder = StateGraph(State)
>>> builder.add_node("chatbot", chatbot_node)
>>> builder.set_entry_point("chatbot")
>>> builder.set_finish_point("chatbot")
>>> graph = builder.compile()
>>> graph.invoke({"messages": []})
{
'messages': [
HumanMessage(
content=[
{"type": "text", "text": "Here's an image:"},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,1234"},
},
],
),
]
}
```

.. versionchanged:: 0.2.40

Support for 'format="langchain-openai"' flag added.
"""
# coerce to list
if not isinstance(left, list):
@@ -100,6 +194,15 @@

merged.append(m)
merged = [m for m in merged if m.id not in ids_to_remove]

if format == "langchain-openai":
merged = _format_messages(merged)
elif format:
msg = f"Unrecognized {format=}. Expected one of 'openai', None."
raise ValueError(msg)
else:
pass

return merged
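
For reference, a small sketch of the merge semantics described in the docstring (the IDs are illustrative): messages with new IDs are appended, a message with an existing ID replaces it, and a RemoveMessage deletes it.

```python
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage
from langgraph.graph.message import add_messages

base = [HumanMessage("hi", id="1"), AIMessage("hello!", id="2")]

# New ID -> appended.
print(add_messages(base, [HumanMessage("how are you?", id="3")]))

# Existing ID -> the message with that ID is updated.
print(add_messages(base, [AIMessage("hello there!", id="2")]))

# RemoveMessage with an existing ID -> the message is dropped.
print(add_messages(base, [RemoveMessage(id="2")]))
```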


@@ -156,3 +259,21 @@

class MessagesState(TypedDict):
messages: Annotated[list[AnyMessage], add_messages]


def _format_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
try:
from langchain_core.messages import ( # type: ignore[attr-defined]
convert_to_openai_messages,
)
except ImportError:
msg = (
"Must have langchain-core>=0.3.11 installed to use automatic message "
"formatting (format='langchain-openai'). Please update your langchain-core "
"version or remove the 'format' flag. Returning un-formatted "
"messages."
)
warnings.warn(msg)
return list(messages)
else:
return convert_to_messages(convert_to_openai_messages(messages))
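
`_format_messages` defers the import so that older langchain-core versions degrade to a warning instead of an error. A standalone sketch of the conversion it performs, assuming langchain-core>=0.3.11 is installed (the message content is illustrative):

```python
from langchain_core.messages import (
    HumanMessage,
    convert_to_messages,
    convert_to_openai_messages,
)

# Anthropic-style content blocks in...
raw = [
    HumanMessage(
        content=[
            {
                "type": "image",
                "source": {"type": "base64", "media_type": "image/jpeg", "data": "1234"},
            }
        ]
    )
]

# ...OpenAI-style dicts out, then coerced back into BaseMessage objects,
# which is what add_messages(format="langchain-openai") stores in state.
openai_dicts = convert_to_openai_messages(raw)
formatted = convert_to_messages(openai_dicts)
# formatted[0].content ==
#   [{"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,1234"}}]
```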
8 changes: 6 additions & 2 deletions libs/langgraph/langgraph/graph/state.py
@@ -762,8 +762,12 @@ def _is_field_binop(typ: Type[Any]) -> Optional[BinaryOperatorAggregate]:
if len(meta) >= 1 and callable(meta[-1]):
sig = signature(meta[0])
params = list(sig.parameters.values())
if len(params) == 2 and all(
p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) for p in params
if (
sum(
p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
for p in params
)
== 2
):
return BinaryOperatorAggregate(typ, meta[0])
else:
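
Context for the change above: after wrapping, `add_messages` (and the configured reducer returned by `add_messages(format=...)`) no longer has exactly two parameters, so `_is_field_binop` now counts only the positional ones when deciding whether an Annotated metadata callable is a binary reducer. A rough sketch with a hypothetical reducer:

```python
from inspect import signature

def reducer(left=None, right=None, **kwargs):  # hypothetical stand-in for the wrapped add_messages
    return (left or []) + (right or [])

params = list(signature(reducer).parameters.values())
n_positional = sum(
    p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) for p in params
)

print(len(params) == 2)   # False: **kwargs makes it three parameters, so the old check rejects it
print(n_positional == 2)  # True: the relaxed check still recognizes a two-argument reducer
```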
110 changes: 110 additions & 0 deletions libs/langgraph/tests/test_messages_state.py
@@ -1,23 +1,28 @@
from typing import Annotated
from uuid import UUID

import langchain_core
import pytest
from langchain_core.messages import (
AIMessage,
AnyMessage,
HumanMessage,
RemoveMessage,
SystemMessage,
ToolMessage,
)
from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import TypedDict

from langgraph.graph import add_messages
from langgraph.graph.message import MessagesState
from langgraph.graph.state import END, START, StateGraph
from tests.conftest import IS_LANGCHAIN_CORE_030_OR_GREATER
from tests.messages import _AnyIdHumanMessage

_, CORE_MAJOR, CORE_MINOR = (int(v) for v in langchain_core.__version__.split("."))  # e.g. "0.3.11" -> (3, 11)


def test_add_single_message():
left = [HumanMessage(content="Hello", id="1")]
@@ -178,3 +183,108 @@ def foo(state):
_AnyIdHumanMessage(content="foo"),
]
}


@pytest.mark.skipif(
condition=CORE_MAJOR < 3 or CORE_MINOR < 11,
reason="Requires langchain_core>=0.3.11.",
)
def test_messages_state_format_openai():
class State(TypedDict):
messages: Annotated[list[AnyMessage], add_messages(format="langchain-openai")]

def foo(state):
messages = [
HumanMessage(
[Review comment from a collaborator: should these just be dicts instead?]

content=[
{
"type": "text",
"text": "Here's an image:",
"cache_control": {"type": "ephemeral"},
},
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": "1234",
},
},
]
),
AIMessage(
content=[
{
"type": "tool_use",
"name": "foo",
"input": {"bar": "baz"},
"id": "1",
}
]
),
HumanMessage(
content=[
{
"type": "tool_result",
"tool_use_id": "1",
"is_error": False,
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": "1234",
},
},
],
}
]
),
]
return {"messages": messages}

expected = [
HumanMessage(content="meow"),
HumanMessage(
content=[
{"type": "text", "text": "Here's an image:"},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,1234"},
},
],
),
AIMessage(
content="",
tool_calls=[
{
"name": "foo",
"type": "tool_calls",
"args": {"bar": "baz"},
"id": "1",
}
],
),
ToolMessage(
content=[
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,1234"},
}
],
tool_call_id="1",
),
]

graph = StateGraph(State)
graph.add_edge(START, "foo")
graph.add_edge("foo", END)
graph.add_node(foo)

app = graph.compile()

result = app.invoke({"messages": [("user", "meow")]})
for m in result["messages"]:
m.id = None
assert result == {"messages": expected}