Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

langgraph[patch]: format messages in state #2199

Merged
merged 16 commits into from
Dec 18, 2024
52 changes: 50 additions & 2 deletions libs/langgraph/langgraph/graph/message.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
import uuid

Check notice on line 1 in libs/langgraph/langgraph/graph/message.py

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

......................................... fanout_to_subgraph_10x: Mean +- std dev: 47.3 ms +- 1.0 ms ......................................... fanout_to_subgraph_10x_sync: Mean +- std dev: 43.1 ms +- 1.8 ms ......................................... fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 74.6 ms +- 2.0 ms ......................................... fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 82.0 ms +- 1.4 ms ......................................... fanout_to_subgraph_100x: Mean +- std dev: 475 ms +- 13 ms ......................................... fanout_to_subgraph_100x_sync: Mean +- std dev: 417 ms +- 8 ms ......................................... fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 788 ms +- 52 ms ......................................... fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 811 ms +- 21 ms ......................................... WARNING: the benchmark result may be unstable * the standard deviation (3.19 ms) is 11% of the mean (30.4 ms) Try to rerun the benchmark with more runs, values and/or loops. Run 'python -m pyperf system tune' command to reduce the system jitter. Use pyperf stats, pyperf dump and pyperf hist to analyze results. Use --quiet option to hide these warnings. react_agent_10x: Mean +- std dev: 30.4 ms +- 3.2 ms ......................................... react_agent_10x_sync: Mean +- std dev: 22.1 ms +- 1.6 ms ......................................... react_agent_10x_checkpoint: Mean +- std dev: 46.4 ms +- 3.4 ms ......................................... react_agent_10x_checkpoint_sync: Mean +- std dev: 36.6 ms +- 3.2 ms ......................................... react_agent_100x: Mean +- std dev: 309 ms +- 8 ms ......................................... react_agent_100x_sync: Mean +- std dev: 256 ms +- 12 ms ......................................... react_agent_100x_checkpoint: Mean +- std dev: 899 ms +- 12 ms ......................................... react_agent_100x_checkpoint_sync: Mean +- std dev: 803 ms +- 8 ms ......................................... wide_state_25x300: Mean +- std dev: 18.1 ms +- 0.5 ms ......................................... wide_state_25x300_sync: Mean +- std dev: 10.7 ms +- 0.1 ms ......................................... wide_state_25x300_checkpoint: Mean +- std dev: 270 ms +- 4 ms ......................................... wide_state_25x300_checkpoint_sync: Mean +- std dev: 265 ms +- 14 ms ......................................... wide_state_15x600: Mean +- std dev: 20.7 ms +- 0.6 ms ......................................... wide_state_15x600_sync: Mean +- std dev: 12.3 ms +- 0.2 ms ......................................... wide_state_15x600_checkpoint: Mean +- std dev: 465 ms +- 7 ms ......................................... wide_state_15x600_checkpoint_sync: Mean +- std dev: 462 ms +- 14 ms ......................................... wide_state_9x1200: Mean +- std dev: 21.0 ms +- 0.5 ms ......................................... wide_state_9x1200_sync: Mean +- std dev: 12.4 ms +- 0.2 ms ......................................... wide_state_9x1200_checkpoint: Mean +- std dev: 303 ms +- 5 ms ......................................... wide_state_9x1200_checkpoint_sync: Mean +- std dev: 302 ms +- 16 ms

Check notice on line 1 in libs/langgraph/langgraph/graph/message.py

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+-----------------------------------------+---------+-----------------------+ | Benchmark | main | changes | +=========================================+=========+=======================+ | react_agent_100x | 335 ms | 309 ms: 1.08x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_sync | 46.5 ms | 43.1 ms: 1.08x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_checkpoint | 835 ms | 788 ms: 1.06x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint_sync | 86.1 ms | 82.0 ms: 1.05x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_checkpoint_sync | 847 ms | 811 ms: 1.04x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint_sync | 836 ms | 803 ms: 1.04x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_10x_checkpoint_sync | 38.1 ms | 36.6 ms: 1.04x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_10x_checkpoint | 48.1 ms | 46.4 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint | 77.2 ms | 74.6 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x | 491 ms | 475 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600 | 21.4 ms | 20.7 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_sync | 264 ms | 256 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_sync | 430 ms | 417 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600_sync | 12.6 ms | 12.3 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x | 48.5 ms | 47.3 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300 | 18.5 ms | 18.1 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint | 919 ms | 899 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_10x_sync | 22.5 ms | 22.1 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300_sync | 10.9 ms | 10.7 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600_checkpoint | 474 ms | 465 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200 | 21.3 ms | 21.0 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600_checkpoint_sync | 470 ms | 462 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_sync | 12.6 ms | 12.4 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_checkpoint | 307 ms | 303 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300_checkpoint | 273 ms | 270 ms: 1.01x faster | +---------------------------------------
from typing import Annotated, TypedDict, Union, cast
from functools import partial
from typing import (
Annotated,
Any,
Callable,
Literal,
Optional,
Sequence,
TypedDict,
Union,
cast,
)

from langchain_core.messages import (
AnyMessage,
BaseMessage,
BaseMessageChunk,
MessageLikeRepresentation,
RemoveMessage,
convert_to_messages,
convert_to_openai_messages,
message_chunk_to_message,
)

Expand All @@ -15,7 +28,29 @@
Messages = Union[list[MessageLikeRepresentation], MessageLikeRepresentation]


def add_messages(left: Messages, right: Messages) -> Messages:
def _add_messages_wrapper(func: Callable) -> Callable[[Messages, Messages], Messages]:
def _add_messages(
left: Optional[Messages] = None, right: Optional[Messages] = None, **kwargs: Any
) -> Union[Messages, Callable[[Messages, Messages], Messages]]:
if left is not None and right is not None:
return func(left, right, **kwargs)
elif left is not None or right is not None:
msg = ""
baskaryan marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(msg)
else:
return partial(func, **kwargs)

_add_messages.__doc__ = func.__doc__
return cast(Callable[[Messages, Messages], Messages], _add_messages)


@_add_messages_wrapper
def add_messages(
left: Messages,
right: Messages,
*,
content_format: Optional[Literal["openai"]] = None,
) -> Messages:
"""Merges two lists of messages, updating existing messages by ID.

By default, this ensures the state is "append-only", unless the
Expand Down Expand Up @@ -100,6 +135,15 @@

merged.append(m)
merged = [m for m in merged if m.id not in ids_to_remove]

if content_format == "openai":
merged = _format_messages_content(merged)
elif content_format:
msg = f"Unrecognized {content_format=}. Expected one of 'openai', None."
raise ValueError(msg)
else:
pass

return merged


Expand Down Expand Up @@ -156,3 +200,7 @@

class MessagesState(TypedDict):
messages: Annotated[list[AnyMessage], add_messages]


def _format_messages_content(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
return convert_to_messages(convert_to_openai_messages(messages))
8 changes: 6 additions & 2 deletions libs/langgraph/langgraph/graph/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,8 +762,12 @@ def _is_field_binop(typ: Type[Any]) -> Optional[BinaryOperatorAggregate]:
if len(meta) >= 1 and callable(meta[-1]):
sig = signature(meta[0])
params = list(sig.parameters.values())
if len(params) == 2 and all(
p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) for p in params
if (
sum(
p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
for p in params
)
== 2
):
return BinaryOperatorAggregate(typ, meta[0])
else:
Expand Down
Loading
Loading