From ddc2274aeac1da8e8a189c8d0128ae81d5e84875 Mon Sep 17 00:00:00 2001 From: Erick Friis Date: Tue, 23 Apr 2024 13:59:45 -0700 Subject: [PATCH] standard-tests: split tool calling test (#20803) just making it a bit easier to grok --- .../integration_tests/chat_models.py | 121 ++++++++++-------- 1 file changed, 70 insertions(+), 51 deletions(-) diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py index 734283f7298d6..15a92d7133f1e 100644 --- a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py +++ b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py @@ -117,13 +117,16 @@ async def test_abatch( assert isinstance(result.content, str) assert len(result.content) > 0 - def test_tool_message_histories( + def test_tool_message_histories_string_content( self, chat_model_class: Type[BaseChatModel], chat_model_params: dict, chat_model_has_tool_calling: bool, ) -> None: - """Test that message histories are compatible across providers.""" + """ + Test that message histories are compatible with string tool contents + (e.g. OpenAI). + """ if not chat_model_has_tool_calling: pytest.skip("Test requires tool calling.") model = chat_model_class(**chat_model_params) @@ -131,55 +134,71 @@ def test_tool_message_histories( function_name = "my_adder_tool" function_args = {"a": "1", "b": "2"} - human_message = HumanMessage(content="What is 1 + 2") - tool_message = ToolMessage( - name=function_name, - content=json.dumps({"result": 3}), - tool_call_id="abc123", - ) - - # String content (e.g., OpenAI) - string_content_msg = AIMessage( - content="", - tool_calls=[ - { - "name": function_name, - "args": function_args, - "id": "abc123", - }, - ], - ) - messages = [ - human_message, - string_content_msg, - tool_message, + messages_string_content = [ + HumanMessage(content="What is 1 + 2"), + # string content (e.g. OpenAI) + AIMessage( + content="", + tool_calls=[ + { + "name": function_name, + "args": function_args, + "id": "abc123", + }, + ], + ), + ToolMessage( + name=function_name, + content=json.dumps({"result": 3}), + tool_call_id="abc123", + ), ] - result = model_with_tools.invoke(messages) - assert isinstance(result, AIMessage) + result_string_content = model_with_tools.invoke(messages_string_content) + assert isinstance(result_string_content, AIMessage) - # List content (e.g., Anthropic) - list_content_msg = AIMessage( - content=[ - {"type": "text", "text": "some text"}, - { - "type": "tool_use", - "id": "abc123", - "name": function_name, - "input": function_args, - }, - ], - tool_calls=[ - { - "name": function_name, - "args": function_args, - "id": "abc123", - }, - ], - ) - messages = [ - human_message, - list_content_msg, - tool_message, + def test_tool_message_histories_list_content( + self, + chat_model_class: Type[BaseChatModel], + chat_model_params: dict, + chat_model_has_tool_calling: bool, + ) -> None: + """ + Test that message histories are compatible with list tool contents + (e.g. Anthropic). + """ + if not chat_model_has_tool_calling: + pytest.skip("Test requires tool calling.") + model = chat_model_class(**chat_model_params) + model_with_tools = model.bind_tools([my_adder_tool]) + function_name = "my_adder_tool" + function_args = {"a": 1, "b": 2} + + messages_list_content = [ + HumanMessage(content="What is 1 + 2"), + # List content (e.g., Anthropic) + AIMessage( + content=[ + {"type": "text", "text": "some text"}, + { + "type": "tool_use", + "id": "abc123", + "name": function_name, + "input": function_args, + }, + ], + tool_calls=[ + { + "name": function_name, + "args": function_args, + "id": "abc123", + }, + ], + ), + ToolMessage( + name=function_name, + content=json.dumps({"result": 3}), + tool_call_id="abc123", + ), ] - result = model_with_tools.invoke(messages) - assert isinstance(result, AIMessage) + result_list_content = model_with_tools.invoke(messages_list_content) + assert isinstance(result_list_content, AIMessage)