Skip to content

Commit

Permalink
langgraph[patch]: format messages in state (#2199)
Browse files Browse the repository at this point in the history
Add `format` flag to `add_messages` which allows you to specify if the
contents of messages in state should be formatted in a particular way.
PR only adds support for OpenAI style contents. Helpful if you're using
different models at different nodes and want a unified messages format
to interact with when you manually update messages.
  • Loading branch information
baskaryan authored Dec 18, 2024
1 parent 22fa673 commit 4f1bf4f
Show file tree
Hide file tree
Showing 4 changed files with 246 additions and 10 deletions.
126 changes: 124 additions & 2 deletions libs/langgraph/langgraph/graph/message.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
import uuid
from typing import Annotated, TypedDict, Union, cast
import warnings
from functools import partial
from typing import (
Annotated,
Any,
Callable,
Literal,
Optional,
Sequence,
TypedDict,
Union,
cast,
)

from langchain_core.messages import (
AnyMessage,
BaseMessage,
BaseMessageChunk,
MessageLikeRepresentation,
RemoveMessage,
Expand All @@ -15,7 +28,32 @@
Messages = Union[list[MessageLikeRepresentation], MessageLikeRepresentation]


def add_messages(left: Messages, right: Messages) -> Messages:
def _add_messages_wrapper(func: Callable) -> Callable[[Messages, Messages], Messages]:
def _add_messages(
left: Optional[Messages] = None, right: Optional[Messages] = None, **kwargs: Any
) -> Union[Messages, Callable[[Messages, Messages], Messages]]:
if left is not None and right is not None:
return func(left, right, **kwargs)
elif left is not None or right is not None:
msg = (
f"Must specify non-null arguments for both 'left' and 'right'. Only "
f"received: '{'left' if left else 'right'}'."
)
raise ValueError(msg)
else:
return partial(func, **kwargs)

_add_messages.__doc__ = func.__doc__
return cast(Callable[[Messages, Messages], Messages], _add_messages)


@_add_messages_wrapper
def add_messages(
left: Messages,
right: Messages,
*,
format: Optional[Literal["langchain-openai"]] = None,
) -> Messages:
"""Merges two lists of messages, updating existing messages by ID.
By default, this ensures the state is "append-only", unless the
Expand All @@ -25,6 +63,14 @@ def add_messages(left: Messages, right: Messages) -> Messages:
left: The base list of messages.
right: The list of messages (or single message) to merge
into the base list.
format: The format to return messages in. If None then messages will be
returned as is. If 'langchain-openai' then messages will be returned as
BaseMessage objects with their contents formatted to match OpenAI message
format, meaning contents can be string, 'text' blocks, or 'image_url' blocks
and tool responses are returned as their own ToolMessages.
**REQUIREMENT**: Must have ``langchain-core>=0.3.11`` installed to use this
feature.
Returns:
A new list of messages with the messages from `right` merged into `left`.
Expand Down Expand Up @@ -58,8 +104,59 @@ def add_messages(left: Messages, right: Messages) -> Messages:
>>> graph = builder.compile()
>>> graph.invoke({})
{'messages': [AIMessage(content='Hello', id=...)]}
>>> from typing import Annotated
>>> from typing_extensions import TypedDict
>>> from langgraph.graph import StateGraph, add_messages
>>>
>>> class State(TypedDict):
... messages: Annotated[list, add_messages(format='langchain-openai')]
...
>>> def chatbot_node(state: State) -> list:
... return {"messages": [
... {
... "role": "user",
... "content": [
... {
... "type": "text",
... "text": "Here's an image:",
... "cache_control": {"type": "ephemeral"},
... },
... {
... "type": "image",
... "source": {
... "type": "base64",
... "media_type": "image/jpeg",
... "data": "1234",
... },
... },
... ]
... },
... ]}
>>> builder = StateGraph(State)
>>> builder.add_node("chatbot", chatbot_node)
>>> builder.set_entry_point("chatbot")
>>> builder.set_finish_point("chatbot")
>>> graph = builder.compile()
>>> graph.invoke({"messages": []})
{
'messages': [
HumanMessage(
content=[
{"type": "text", "text": "Here's an image:"},
{
"type": "image_url",
"image_url": {"url": ""},
},
],
),
]
}
```
..versionchanged:: 0.2.61
Support for 'format="langchain-openai"' flag added.
"""
# coerce to list
if not isinstance(left, list):
Expand Down Expand Up @@ -100,6 +197,15 @@ def add_messages(left: Messages, right: Messages) -> Messages:

merged.append(m)
merged = [m for m in merged if m.id not in ids_to_remove]

if format == "langchain-openai":
merged = _format_messages(merged)
elif format:
msg = f"Unrecognized {format=}. Expected one of 'langchain-openai', None."
raise ValueError(msg)
else:
pass

return merged


Expand Down Expand Up @@ -156,3 +262,19 @@ def __init__(self) -> None:

class MessagesState(TypedDict):
messages: Annotated[list[AnyMessage], add_messages]


def _format_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
try:
from langchain_core.messages import convert_to_openai_messages
except ImportError:
msg = (
"Must have langchain-core>=0.3.11 installed to use automatic message "
"formatting (format='langchain-openai'). Please update your langchain-core "
"version or remove the 'format' flag. Returning un-formatted "
"messages."
)
warnings.warn(msg)
return list(messages)
else:
return convert_to_messages(convert_to_openai_messages(messages))
8 changes: 6 additions & 2 deletions libs/langgraph/langgraph/graph/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,8 +961,12 @@ def _is_field_binop(typ: Type[Any]) -> Optional[BinaryOperatorAggregate]:
if len(meta) >= 1 and callable(meta[-1]):
sig = signature(meta[-1])
params = list(sig.parameters.values())
if len(params) == 2 and all(
p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) for p in params
if (
sum(
p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
for p in params
)
== 2
):
return BinaryOperatorAggregate(typ, meta[-1])
else:
Expand Down
12 changes: 6 additions & 6 deletions libs/langgraph/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

110 changes: 110 additions & 0 deletions libs/langgraph/tests/test_messages_state.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
from typing import Annotated
from uuid import UUID

import langchain_core
import pytest
from langchain_core.messages import (
AIMessage,
AnyMessage,
HumanMessage,
RemoveMessage,
SystemMessage,
ToolMessage,
)
from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import TypedDict

from langgraph.graph import add_messages
from langgraph.graph.message import MessagesState
from langgraph.graph.state import END, START, StateGraph
from tests.conftest import IS_LANGCHAIN_CORE_030_OR_GREATER
from tests.messages import _AnyIdHumanMessage

_, CORE_MINOR, CORE_PATCH = (int(v) for v in langchain_core.__version__.split("."))


def test_add_single_message():
left = [HumanMessage(content="Hello", id="1")]
Expand Down Expand Up @@ -178,3 +183,108 @@ def foo(state):
_AnyIdHumanMessage(content="foo"),
]
}


@pytest.mark.skipif(
condition=not ((CORE_MINOR == 3 and CORE_PATCH >= 11) or CORE_MINOR > 3),
reason="Requires langchain_core>=0.3.11.",
)
def test_messages_state_format_openai():
class State(TypedDict):
messages: Annotated[list[AnyMessage], add_messages(format="langchain-openai")]

def foo(state):
messages = [
HumanMessage(
content=[
{
"type": "text",
"text": "Here's an image:",
"cache_control": {"type": "ephemeral"},
},
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": "1234",
},
},
]
),
AIMessage(
content=[
{
"type": "tool_use",
"name": "foo",
"input": {"bar": "baz"},
"id": "1",
}
]
),
HumanMessage(
content=[
{
"type": "tool_result",
"tool_use_id": "1",
"is_error": False,
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": "1234",
},
},
],
}
]
),
]
return {"messages": messages}

expected = [
HumanMessage(content="meow"),
HumanMessage(
content=[
{"type": "text", "text": "Here's an image:"},
{
"type": "image_url",
"image_url": {"url": ""},
},
],
),
AIMessage(
content="",
tool_calls=[
{
"name": "foo",
"type": "tool_calls",
"args": {"bar": "baz"},
"id": "1",
}
],
),
ToolMessage(
content=[
{
"type": "image_url",
"image_url": {"url": ""},
}
],
tool_call_id="1",
),
]

graph = StateGraph(State)
graph.add_edge(START, "foo")
graph.add_edge("foo", END)
graph.add_node(foo)

app = graph.compile()

result = app.invoke({"messages": [("user", "meow")]})
for m in result["messages"]:
m.id = None
assert result == {"messages": expected}

0 comments on commit 4f1bf4f

Please sign in to comment.