From 4f1bf4fa7a6e0f2c208e824d467686151cf161a7 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Wed, 18 Dec 2024 09:19:04 -0800 Subject: [PATCH] langgraph[patch]: format messages in state (#2199) Add `format` flag to `add_messages` which allows you to specify if the contents of messages in state should be formatted in a particular way. PR only adds support for OpenAI style contents. Helpful if you're using different models at different nodes and want a unified messages format to interact with when you manually update messages. --- libs/langgraph/langgraph/graph/message.py | 126 +++++++++++++++++++- libs/langgraph/langgraph/graph/state.py | 8 +- libs/langgraph/poetry.lock | 12 +- libs/langgraph/tests/test_messages_state.py | 110 +++++++++++++++++ 4 files changed, 246 insertions(+), 10 deletions(-) diff --git a/libs/langgraph/langgraph/graph/message.py b/libs/langgraph/langgraph/graph/message.py index 6575bd10c..e63ebee83 100644 --- a/libs/langgraph/langgraph/graph/message.py +++ b/libs/langgraph/langgraph/graph/message.py @@ -1,8 +1,21 @@ import uuid -from typing import Annotated, TypedDict, Union, cast +import warnings +from functools import partial +from typing import ( + Annotated, + Any, + Callable, + Literal, + Optional, + Sequence, + TypedDict, + Union, + cast, +) from langchain_core.messages import ( AnyMessage, + BaseMessage, BaseMessageChunk, MessageLikeRepresentation, RemoveMessage, @@ -15,7 +28,32 @@ Messages = Union[list[MessageLikeRepresentation], MessageLikeRepresentation] -def add_messages(left: Messages, right: Messages) -> Messages: +def _add_messages_wrapper(func: Callable) -> Callable[[Messages, Messages], Messages]: + def _add_messages( + left: Optional[Messages] = None, right: Optional[Messages] = None, **kwargs: Any + ) -> Union[Messages, Callable[[Messages, Messages], Messages]]: + if left is not None and right is not None: + return func(left, right, **kwargs) + elif left is not None or right is not None: + msg = ( + f"Must specify non-null arguments for both 'left' and 'right'. Only " + f"received: '{'left' if left else 'right'}'." + ) + raise ValueError(msg) + else: + return partial(func, **kwargs) + + _add_messages.__doc__ = func.__doc__ + return cast(Callable[[Messages, Messages], Messages], _add_messages) + + +@_add_messages_wrapper +def add_messages( + left: Messages, + right: Messages, + *, + format: Optional[Literal["langchain-openai"]] = None, +) -> Messages: """Merges two lists of messages, updating existing messages by ID. By default, this ensures the state is "append-only", unless the @@ -25,6 +63,14 @@ def add_messages(left: Messages, right: Messages) -> Messages: left: The base list of messages. right: The list of messages (or single message) to merge into the base list. + format: The format to return messages in. If None then messages will be + returned as is. If 'langchain-openai' then messages will be returned as + BaseMessage objects with their contents formatted to match OpenAI message + format, meaning contents can be string, 'text' blocks, or 'image_url' blocks + and tool responses are returned as their own ToolMessages. + + **REQUIREMENT**: Must have ``langchain-core>=0.3.11`` installed to use this + feature. Returns: A new list of messages with the messages from `right` merged into `left`. @@ -58,8 +104,59 @@ def add_messages(left: Messages, right: Messages) -> Messages: >>> graph = builder.compile() >>> graph.invoke({}) {'messages': [AIMessage(content='Hello', id=...)]} + + >>> from typing import Annotated + >>> from typing_extensions import TypedDict + >>> from langgraph.graph import StateGraph, add_messages + >>> + >>> class State(TypedDict): + ... messages: Annotated[list, add_messages(format='langchain-openai')] + ... + >>> def chatbot_node(state: State) -> list: + ... return {"messages": [ + ... { + ... "role": "user", + ... "content": [ + ... { + ... "type": "text", + ... "text": "Here's an image:", + ... "cache_control": {"type": "ephemeral"}, + ... }, + ... { + ... "type": "image", + ... "source": { + ... "type": "base64", + ... "media_type": "image/jpeg", + ... "data": "1234", + ... }, + ... }, + ... ] + ... }, + ... ]} + >>> builder = StateGraph(State) + >>> builder.add_node("chatbot", chatbot_node) + >>> builder.set_entry_point("chatbot") + >>> builder.set_finish_point("chatbot") + >>> graph = builder.compile() + >>> graph.invoke({"messages": []}) + { + 'messages': [ + HumanMessage( + content=[ + {"type": "text", "text": "Here's an image:"}, + { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,1234"}, + }, + ], + ), + ] + } ``` + ..versionchanged:: 0.2.61 + + Support for 'format="langchain-openai"' flag added. """ # coerce to list if not isinstance(left, list): @@ -100,6 +197,15 @@ def add_messages(left: Messages, right: Messages) -> Messages: merged.append(m) merged = [m for m in merged if m.id not in ids_to_remove] + + if format == "langchain-openai": + merged = _format_messages(merged) + elif format: + msg = f"Unrecognized {format=}. Expected one of 'langchain-openai', None." + raise ValueError(msg) + else: + pass + return merged @@ -156,3 +262,19 @@ def __init__(self) -> None: class MessagesState(TypedDict): messages: Annotated[list[AnyMessage], add_messages] + + +def _format_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]: + try: + from langchain_core.messages import convert_to_openai_messages + except ImportError: + msg = ( + "Must have langchain-core>=0.3.11 installed to use automatic message " + "formatting (format='langchain-openai'). Please update your langchain-core " + "version or remove the 'format' flag. Returning un-formatted " + "messages." + ) + warnings.warn(msg) + return list(messages) + else: + return convert_to_messages(convert_to_openai_messages(messages)) diff --git a/libs/langgraph/langgraph/graph/state.py b/libs/langgraph/langgraph/graph/state.py index 7a5614f91..e412db2d5 100644 --- a/libs/langgraph/langgraph/graph/state.py +++ b/libs/langgraph/langgraph/graph/state.py @@ -961,8 +961,12 @@ def _is_field_binop(typ: Type[Any]) -> Optional[BinaryOperatorAggregate]: if len(meta) >= 1 and callable(meta[-1]): sig = signature(meta[-1]) params = list(sig.parameters.values()) - if len(params) == 2 and all( - p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) for p in params + if ( + sum( + p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) + for p in params + ) + == 2 ): return BinaryOperatorAggregate(typ, meta[-1]) else: diff --git a/libs/langgraph/poetry.lock b/libs/langgraph/poetry.lock index bdb2a4da6..deace7b65 100644 --- a/libs/langgraph/poetry.lock +++ b/libs/langgraph/poetry.lock @@ -1325,18 +1325,18 @@ files = [ [[package]] name = "langchain-core" -version = "0.3.23" +version = "0.3.25" description = "Building applications with LLMs through composability" optional = false python-versions = "<4.0,>=3.9" files = [ - {file = "langchain_core-0.3.23-py3-none-any.whl", hash = "sha256:550c0b996990830fa6515a71a1192a8a0343367999afc36d4ede14222941e420"}, - {file = "langchain_core-0.3.23.tar.gz", hash = "sha256:f9e175e3b82063cc3b160c2ca2b155832e1c6f915312e1204828f97d4aabf6e1"}, + {file = "langchain_core-0.3.25-py3-none-any.whl", hash = "sha256:e10581c6c74ba16bdc6fdf16b00cced2aa447cc4024ed19746a1232918edde38"}, + {file = "langchain_core-0.3.25.tar.gz", hash = "sha256:fdb8df41e5cdd928c0c2551ebbde1cea770ee3c64598395367ad77ddf9acbae7"}, ] [package.dependencies] jsonpatch = ">=1.33,<2.0" -langsmith = ">=0.1.125,<0.2.0" +langsmith = ">=0.1.125,<0.3" packaging = ">=23.2,<25" pydantic = [ {version = ">=2.5.2,<3.0.0", markers = "python_full_version < \"3.12.4\""}, @@ -1348,7 +1348,7 @@ typing-extensions = ">=4.7" [[package]] name = "langgraph-checkpoint" -version = "2.0.8" +version = "2.0.9" description = "Library with base interfaces for LangGraph checkpoint savers." optional = false python-versions = "^3.9.0,<4.0" @@ -1418,7 +1418,7 @@ url = "../checkpoint-sqlite" [[package]] name = "langgraph-sdk" -version = "0.1.43" +version = "0.1.47" description = "SDK for interacting with LangGraph API" optional = false python-versions = "^3.9.0,<4.0" diff --git a/libs/langgraph/tests/test_messages_state.py b/libs/langgraph/tests/test_messages_state.py index ff8d064d6..787774baf 100644 --- a/libs/langgraph/tests/test_messages_state.py +++ b/libs/langgraph/tests/test_messages_state.py @@ -1,6 +1,7 @@ from typing import Annotated from uuid import UUID +import langchain_core import pytest from langchain_core.messages import ( AIMessage, @@ -8,9 +9,11 @@ HumanMessage, RemoveMessage, SystemMessage, + ToolMessage, ) from pydantic import BaseModel from pydantic.v1 import BaseModel as BaseModelV1 +from typing_extensions import TypedDict from langgraph.graph import add_messages from langgraph.graph.message import MessagesState @@ -18,6 +21,8 @@ from tests.conftest import IS_LANGCHAIN_CORE_030_OR_GREATER from tests.messages import _AnyIdHumanMessage +_, CORE_MINOR, CORE_PATCH = (int(v) for v in langchain_core.__version__.split(".")) + def test_add_single_message(): left = [HumanMessage(content="Hello", id="1")] @@ -178,3 +183,108 @@ def foo(state): _AnyIdHumanMessage(content="foo"), ] } + + +@pytest.mark.skipif( + condition=not ((CORE_MINOR == 3 and CORE_PATCH >= 11) or CORE_MINOR > 3), + reason="Requires langchain_core>=0.3.11.", +) +def test_messages_state_format_openai(): + class State(TypedDict): + messages: Annotated[list[AnyMessage], add_messages(format="langchain-openai")] + + def foo(state): + messages = [ + HumanMessage( + content=[ + { + "type": "text", + "text": "Here's an image:", + "cache_control": {"type": "ephemeral"}, + }, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": "1234", + }, + }, + ] + ), + AIMessage( + content=[ + { + "type": "tool_use", + "name": "foo", + "input": {"bar": "baz"}, + "id": "1", + } + ] + ), + HumanMessage( + content=[ + { + "type": "tool_result", + "tool_use_id": "1", + "is_error": False, + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": "1234", + }, + }, + ], + } + ] + ), + ] + return {"messages": messages} + + expected = [ + HumanMessage(content="meow"), + HumanMessage( + content=[ + {"type": "text", "text": "Here's an image:"}, + { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,1234"}, + }, + ], + ), + AIMessage( + content="", + tool_calls=[ + { + "name": "foo", + "type": "tool_calls", + "args": {"bar": "baz"}, + "id": "1", + } + ], + ), + ToolMessage( + content=[ + { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,1234"}, + } + ], + tool_call_id="1", + ), + ] + + graph = StateGraph(State) + graph.add_edge(START, "foo") + graph.add_edge("foo", END) + graph.add_node(foo) + + app = graph.compile() + + result = app.invoke({"messages": [("user", "meow")]}) + for m in result["messages"]: + m.id = None + assert result == {"messages": expected}