Skip to content

Commit

Permalink
standard-tests: split tool calling test (#20803)
Browse files Browse the repository at this point in the history
just making it a bit easier to grok
  • Loading branch information
efriis authored and hinthornw committed Apr 26, 2024
1 parent 537f862 commit 415cbfa
Showing 1 changed file with 70 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -117,69 +117,88 @@ async def test_abatch(
assert isinstance(result.content, str)
assert len(result.content) > 0

def test_tool_message_histories(
def test_tool_message_histories_string_content(
self,
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
chat_model_has_tool_calling: bool,
) -> None:
"""Test that message histories are compatible across providers."""
"""
Test that message histories are compatible with string tool contents
(e.g. OpenAI).
"""
if not chat_model_has_tool_calling:
pytest.skip("Test requires tool calling.")
model = chat_model_class(**chat_model_params)
model_with_tools = model.bind_tools([my_adder_tool])
function_name = "my_adder_tool"
function_args = {"a": "1", "b": "2"}

human_message = HumanMessage(content="What is 1 + 2")
tool_message = ToolMessage(
name=function_name,
content=json.dumps({"result": 3}),
tool_call_id="abc123",
)

# String content (e.g., OpenAI)
string_content_msg = AIMessage(
content="",
tool_calls=[
{
"name": function_name,
"args": function_args,
"id": "abc123",
},
],
)
messages = [
human_message,
string_content_msg,
tool_message,
messages_string_content = [
HumanMessage(content="What is 1 + 2"),
# string content (e.g. OpenAI)
AIMessage(
content="",
tool_calls=[
{
"name": function_name,
"args": function_args,
"id": "abc123",
},
],
),
ToolMessage(
name=function_name,
content=json.dumps({"result": 3}),
tool_call_id="abc123",
),
]
result = model_with_tools.invoke(messages)
assert isinstance(result, AIMessage)
result_string_content = model_with_tools.invoke(messages_string_content)
assert isinstance(result_string_content, AIMessage)

# List content (e.g., Anthropic)
list_content_msg = AIMessage(
content=[
{"type": "text", "text": "some text"},
{
"type": "tool_use",
"id": "abc123",
"name": function_name,
"input": function_args,
},
],
tool_calls=[
{
"name": function_name,
"args": function_args,
"id": "abc123",
},
],
)
messages = [
human_message,
list_content_msg,
tool_message,
def test_tool_message_histories_list_content(
self,
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
chat_model_has_tool_calling: bool,
) -> None:
"""
Test that message histories are compatible with list tool contents
(e.g. Anthropic).
"""
if not chat_model_has_tool_calling:
pytest.skip("Test requires tool calling.")
model = chat_model_class(**chat_model_params)
model_with_tools = model.bind_tools([my_adder_tool])
function_name = "my_adder_tool"
function_args = {"a": 1, "b": 2}

messages_list_content = [
HumanMessage(content="What is 1 + 2"),
# List content (e.g., Anthropic)
AIMessage(
content=[
{"type": "text", "text": "some text"},
{
"type": "tool_use",
"id": "abc123",
"name": function_name,
"input": function_args,
},
],
tool_calls=[
{
"name": function_name,
"args": function_args,
"id": "abc123",
},
],
),
ToolMessage(
name=function_name,
content=json.dumps({"result": 3}),
tool_call_id="abc123",
),
]
result = model_with_tools.invoke(messages)
assert isinstance(result, AIMessage)
result_list_content = model_with_tools.invoke(messages_list_content)
assert isinstance(result_list_content, AIMessage)

0 comments on commit 415cbfa

Please sign in to comment.