Skip to content

Commit

Permalink
anthropic[patch]: allow multiple sys not at start (#27725)
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan authored Oct 30, 2024
1 parent 1ed3cd2 commit 6691202
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 6 deletions.
15 changes: 9 additions & 6 deletions libs/partners/anthropic/langchain_anthropic/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,10 @@ def _merge_messages(
]
)
last = merged[-1] if merged else None
if isinstance(last, HumanMessage) and isinstance(curr, HumanMessage):
if any(
all(isinstance(m, c) for m in (curr, last))
for c in (SystemMessage, HumanMessage)
):
if isinstance(last.content, str):
new_content: List = [{"type": "text", "text": last.content}]
else:
Expand All @@ -148,7 +151,7 @@ def _merge_messages(
new_content.append({"type": "text", "text": curr.content})
else:
new_content.extend(curr.content)
merged[-1] = curr.model_copy(update={"content": new_content}, deep=False)
merged[-1] = curr.model_copy(update={"content": new_content})
else:
merged.append(curr)
return merged
Expand All @@ -174,14 +177,14 @@ def _format_messages(
merged_messages = _merge_messages(messages)
for i, message in enumerate(merged_messages):
if message.type == "system":
if i != 0:
raise ValueError("System message must be at beginning of message list.")
if isinstance(message.content, list):
if system is not None:
raise ValueError("Received multiple non-consecutive system messages.")
elif isinstance(message.content, list):
system = [
(
block
if isinstance(block, dict)
else {"type": "text", "text": "block"}
else {"type": "text", "text": block}
)
for block in message.content
]
Expand Down
22 changes: 22 additions & 0 deletions libs/partners/anthropic/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,28 @@ def test__format_messages_with_cache_control() -> None:
assert expected_messages == actual_messages


def test__format_messages_with_multiple_system() -> None:
messages = [
HumanMessage("baz"),
SystemMessage("bar"),
SystemMessage("baz"),
SystemMessage(
[
{"type": "text", "text": "foo", "cache_control": {"type": "ephemeral"}},
]
),
]
expected_system = [
{"type": "text", "text": "bar"},
{"type": "text", "text": "baz"},
{"type": "text", "text": "foo", "cache_control": {"type": "ephemeral"}},
]
expected_messages = [{"role": "user", "content": "baz"}]
actual_system, actual_messages = _format_messages(messages)
assert expected_system == actual_system
assert expected_messages == actual_messages


def test_anthropic_api_key_is_secret_string() -> None:
"""Test that the API key is stored as a SecretStr."""
chat_model = ChatAnthropic( # type: ignore[call-arg, call-arg]
Expand Down

0 comments on commit 6691202

Please sign in to comment.