From d475e6a5911424733ff87209d586716d6e0ce210 Mon Sep 17 00:00:00 2001 From: RussellLuo Date: Mon, 27 Jan 2025 16:32:00 +0800 Subject: [PATCH] Add tests for StructuredOutput --- coagent/agents/__init__.py | 7 +- tests/agents/test_messages.py | 119 ++++++++++++++++++++++++++++++++++ 2 files changed, 125 insertions(+), 1 deletion(-) create mode 100644 tests/agents/test_messages.py diff --git a/coagent/agents/__init__.py b/coagent/agents/__init__.py index fd246a9..ad21703 100644 --- a/coagent/agents/__init__.py +++ b/coagent/agents/__init__.py @@ -2,7 +2,12 @@ from .chat_agent import ChatAgent, confirm, submit, RunContext, tool from .dynamic_triage import DynamicTriage from .mcp_agent import MCPAgent -from .messages import ChatHistory, ChatMessage +from .messages import ( + ChatHistory, + ChatMessage, + StructuredOutput, + type_to_response_format_param, +) from .model_client import ModelClient from .parallel import Aggregator, AggregationResult, Parallel from .sequential import Sequential diff --git a/tests/agents/test_messages.py b/tests/agents/test_messages.py new file mode 100644 index 0000000..d7c0f93 --- /dev/null +++ b/tests/agents/test_messages.py @@ -0,0 +1,119 @@ +from pydantic import BaseModel, ValidationError +import pytest + +from coagent.agents import ChatHistory, ChatMessage, StructuredOutput + + +class FriendInfo(BaseModel): + name: str + age: int + is_available: bool + + +want_output_schema = { + "json_schema": { + "name": "FriendInfo", + "schema": { + "additionalProperties": False, + "properties": { + "age": { + "title": "Age", + "type": "integer", + }, + "is_available": { + "title": "Is Available", + "type": "boolean", + }, + "name": { + "title": "Name", + "type": "string", + }, + }, + "required": [ + "name", + "age", + "is_available", + ], + "title": "FriendInfo", + "type": "object", + }, + "strict": True, + }, + "type": "json_schema", +} + + +class TestStructuredOutput: + @pytest.mark.asyncio + async def test_chat_message(self): + # Test model_dump + output = StructuredOutput( + input=ChatMessage(role="user", content="I have a friend."), + output_type=FriendInfo, + ) + want_output_dict = { + "input": { + "__message_type__": "ChatMessage", + "content": "I have a friend.", + "role": "user", + }, + "output_schema": want_output_schema, + "output_type": None, + } + assert output.model_dump(exclude_defaults=True) == want_output_dict + + # Test model_validate + output2 = StructuredOutput.model_validate(want_output_dict) + assert isinstance(output2.input, ChatMessage) + assert output2.input.role == "user" + assert output2.input.content == "I have a friend." + + @pytest.mark.asyncio + async def test_chat_history(self): + # Test model_dump + output = StructuredOutput( + input=ChatHistory( + messages=[ChatMessage(role="user", content="I have a friend.")] + ), + output_type=FriendInfo, + ) + want_output_dict = { + "input": { + "__message_type__": "ChatHistory", + "messages": [ + { + "content": "I have a friend.", + "role": "user", + } + ], + }, + "output_schema": want_output_schema, + "output_type": None, + } + assert output.model_dump(exclude_defaults=True) == want_output_dict + + # Test model_validate + output2 = StructuredOutput.model_validate(want_output_dict) + assert isinstance(output2.input, ChatHistory) + assert output2.input.messages[0].role == "user" + assert output2.input.messages[0].content == "I have a friend." + + @pytest.mark.asyncio + async def test_invalid_input(self): + class InvalidInput(BaseModel): + pass + + with pytest.raises(ValidationError) as exc: + _ = StructuredOutput( + input=InvalidInput(), + output_type=FriendInfo, + ) + + exc_value = str(exc.value) + assert "2 validation errors for StructuredOutput" in exc_value + assert ( + "Input should be a valid dictionary or instance of ChatMessage" in exc_value + ) + assert ( + "Input should be a valid dictionary or instance of ChatHistory" in exc_value + )