Skip to content

Commit

Permalink
Support message trimming on single messages (#27729)
Browse files Browse the repository at this point in the history
Permit trimming message lists of length 1
  • Loading branch information
hinthornw authored Oct 30, 2024
1 parent 5111063 commit 5a2cfb4
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 2 deletions.
5 changes: 3 additions & 2 deletions libs/core/langchain_core/messages/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,13 +1210,14 @@ def _first_max_tokens(
] = None,
) -> list[BaseMessage]:
messages = list(messages)
if not messages:
return messages
idx = 0
for i in range(len(messages)):
if token_counter(messages[:-i] if i else messages) <= max_tokens:
idx = len(messages) - i
break

if idx < len(messages) - 1 and partial_strategy:
if partial_strategy and (idx < len(messages) - 1 or idx == 0):
included_partial = False
if isinstance(messages[idx].content, list):
excluded = messages[idx].model_copy(deep=True)
Expand Down
36 changes: 36 additions & 0 deletions libs/core/tests/unit_tests/messages/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,42 @@ def test_trim_messages_last_40_include_system_allow_partial_start_on_human() ->
assert _MESSAGES_TO_TRIM == _MESSAGES_TO_TRIM_COPY


def test_trim_messages_allow_partial_one_message() -> None:
expected = [
HumanMessage("Th", id="third"),
]

actual = trim_messages(
[HumanMessage("This is a funky text.", id="third")],
max_tokens=2,
token_counter=lambda messages: sum(len(m.content) for m in messages),
text_splitter=lambda x: list(x),
strategy="first",
allow_partial=True,
)

assert actual == expected
assert _MESSAGES_TO_TRIM == _MESSAGES_TO_TRIM_COPY


def test_trim_messages_last_allow_partial_one_message() -> None:
expected = [
HumanMessage("t.", id="third"),
]

actual = trim_messages(
[HumanMessage("This is a funky text.", id="third")],
max_tokens=2,
token_counter=lambda messages: sum(len(m.content) for m in messages),
text_splitter=lambda x: list(x),
strategy="last",
allow_partial=True,
)

assert actual == expected
assert _MESSAGES_TO_TRIM == _MESSAGES_TO_TRIM_COPY


def test_trim_messages_allow_partial_text_splitter() -> None:
expected = [
HumanMessage("a 4 token text.", id="third"),
Expand Down

0 comments on commit 5a2cfb4

Please sign in to comment.