diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py index 142272292f609..3f25e02fb2381 100644 --- a/libs/core/tests/unit_tests/messages/test_utils.py +++ b/libs/core/tests/unit_tests/messages/test_utils.py @@ -1,3 +1,4 @@ +import json from typing import Dict, List, Type import pytest @@ -12,6 +13,7 @@ ToolMessage, ) from langchain_core.messages.utils import ( + convert_to_messages, filter_messages, merge_message_runs, trim_messages, @@ -357,3 +359,176 @@ def dummy_token_counter(messages: List[BaseMessage]) -> int: class FakeTokenCountingModel(FakeChatModel): def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: return dummy_token_counter(messages) + + +def test_convert_to_messages() -> None: + message_like: List = [ + # BaseMessage + SystemMessage("1"), + HumanMessage([{"type": "image_url", "image_url": {"url": "2.1"}}], name="2.2"), + AIMessage( + [ + {"type": "text", "text": "3.1"}, + { + "type": "tool_use", + "id": "3.2", + "name": "3.3", + "input": {"3.4": "3.5"}, + }, + ] + ), + AIMessage( + [ + {"type": "text", "text": "4.1"}, + { + "type": "tool_use", + "id": "4.2", + "name": "4.3", + "input": {"4.4": "4.5"}, + }, + ], + tool_calls=[ + { + "name": "4.3", + "args": {"4.4": "4.5"}, + "id": "4.2", + "type": "tool_call", + } + ], + ), + ToolMessage("5.1", tool_call_id="5.2", name="5.3"), + # OpenAI dict + {"role": "system", "content": "6"}, + { + "role": "user", + "content": [{"type": "image_url", "image_url": {"url": "7.1"}}], + "name": "7.2", + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "8.1"}], + "tool_calls": [ + { + "type": "function", + "function": { + "arguments": json.dumps({"8.2": "8.3"}), + "name": "8.4", + }, + "id": "8.5", + } + ], + "name": "8.6", + }, + {"role": "tool", "content": "10.1", "tool_call_id": "10.2"}, + # Tuple/List + ("system", "11.1"), + ("human", [{"type": "image_url", "image_url": {"url": "12.1"}}]), + ( + "ai", + [ + {"type": "text", "text": "13.1"}, + { + "type": "tool_use", + "id": "13.2", + "name": "13.3", + "input": {"13.4": "13.5"}, + }, + ], + ), + # String + "14.1", + # LangChain dict + { + "role": "ai", + "content": [{"type": "text", "text": "15.1"}], + "tool_calls": [{"args": {"15.2": "15.3"}, "name": "15.4", "id": "15.5"}], + "name": "15.6", + }, + ] + expected = [ + SystemMessage(content="1"), + HumanMessage( + content=[{"type": "image_url", "image_url": {"url": "2.1"}}], name="2.2" + ), + AIMessage( + content=[ + {"type": "text", "text": "3.1"}, + { + "type": "tool_use", + "id": "3.2", + "name": "3.3", + "input": {"3.4": "3.5"}, + }, + ] + ), + AIMessage( + content=[ + {"type": "text", "text": "4.1"}, + { + "type": "tool_use", + "id": "4.2", + "name": "4.3", + "input": {"4.4": "4.5"}, + }, + ], + tool_calls=[ + { + "name": "4.3", + "args": {"4.4": "4.5"}, + "id": "4.2", + "type": "tool_call", + } + ], + ), + ToolMessage(content="5.1", name="5.3", tool_call_id="5.2"), + SystemMessage(content="6"), + HumanMessage( + content=[{"type": "image_url", "image_url": {"url": "7.1"}}], name="7.2" + ), + AIMessage( + content=[{"type": "text", "text": "8.1"}], + name="8.6", + tool_calls=[ + { + "name": "8.4", + "args": {"8.2": "8.3"}, + "id": "8.5", + "type": "tool_call", + } + ], + ), + ToolMessage(content="10.1", tool_call_id="10.2"), + SystemMessage(content="11.1"), + HumanMessage(content=[{"type": "image_url", "image_url": {"url": "12.1"}}]), + AIMessage( + content=[ + {"type": "text", "text": "13.1"}, + { + "type": "tool_use", + "id": "13.2", + "name": "13.3", + "input": {"13.4": "13.5"}, + }, + ] + ), + HumanMessage(content="14.1"), + AIMessage( + content=[{"type": "text", "text": "15.1"}], + name="15.6", + tool_calls=[ + { + "name": "15.4", + "args": {"15.2": "15.3"}, + "id": "15.5", + "type": "tool_call", + } + ], + ), + ] + actual = convert_to_messages(message_like) + assert expected == actual + + +@pytest.mark.xfail(reason="AI message does not support refusal key yet.") +def test_convert_to_messages_openai_refusal() -> None: + convert_to_messages([{"role": "assistant", "refusal": "9.1"}])