Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Console to return last processed message #4279

Merged
merged 6 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import sys
import time
from typing import AsyncGenerator, List
from typing import AsyncGenerator, List, Optional, TypeVar, cast

from autogen_core.components import Image
from autogen_core.components.models import RequestUsage
Expand All @@ -18,23 +18,32 @@ def _is_output_a_tty() -> bool:
return sys.stdout.isatty()


T = TypeVar("T", bound=TaskResult | Response)


async def Console(
stream: AsyncGenerator[AgentMessage | TaskResult, None] | AsyncGenerator[AgentMessage | Response, None],
stream: AsyncGenerator[AgentMessage | T, None],
*,
no_inline_images: bool = False,
) -> None:
"""Consume the stream from :meth:`~autogen_agentchat.base.Team.run_stream`
) -> T:
"""
Consume the stream from :meth:`~autogen_agentchat.base.Team.run_stream`
or :meth:`~autogen_agentchat.base.ChatAgent.on_messages_stream`
and print the messages to the console.
print the messages to the console and return the last processed TaskResult or Response.

Args:
stream (AsyncGenerator[AgentMessage | TaskResult, None] | AsyncGenerator[AgentMessage | Response, None]): Stream to render
no_inline_images (bool, optional): If terminal is iTerm2 will render images inline. Use this to disable this behavior. Defaults to False.
"""

Returns:
last_processed: The last processed TaskResult or Response.
"""
render_image_iterm = _is_running_in_iterm() and _is_output_a_tty() and not no_inline_images
start_time = time.time()
total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)

last_processed: Optional[T] = None

async for message in stream:
if isinstance(message, TaskResult):
duration = time.time() - start_time
Expand All @@ -47,6 +56,9 @@ async def Console(
f"Duration: {duration:.2f} seconds\n"
)
sys.stdout.write(output)
# mypy ignore
last_processed = message # type: ignore

elif isinstance(message, Response):
duration = time.time() - start_time

Expand All @@ -71,14 +83,24 @@ async def Console(
f"Duration: {duration:.2f} seconds\n"
)
sys.stdout.write(output)
# mypy ignore
last_processed = message # type: ignore

else:
# Cast required for mypy to be happy
message = cast(AgentMessage, message) # type: ignore
output = f"{'-' * 10} {message.source} {'-' * 10}\n{_message_to_str(message, render_image_iterm=render_image_iterm)}\n"
if message.models_usage:
output += f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]\n"
total_usage.completion_tokens += message.models_usage.completion_tokens
total_usage.prompt_tokens += message.models_usage.prompt_tokens
sys.stdout.write(output)

if last_processed is None:
raise ValueError("No TaskResult or Response was processed.")

return last_processed


# iTerm2 image rendering protocol: https://iterm2.com/documentation-images.html
def _image_to_iterm(image: Image) -> str:
Expand Down
41 changes: 40 additions & 1 deletion python/packages/autogen-agentchat/tests/test_group_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
ToolCallMessage,
ToolCallResultMessage,
)
from autogen_agentchat.task import HandoffTermination, MaxMessageTermination, TextMentionTermination
from autogen_agentchat.task import Console, HandoffTermination, MaxMessageTermination, TextMentionTermination
from autogen_agentchat.teams import (
RoundRobinGroupChat,
SelectorGroupChat,
Expand Down Expand Up @@ -315,6 +315,14 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
assert message == result.messages[index]
index += 1

# Test Console.
tool_use_agent._model_context.clear() # pyright: ignore
mock.reset()
index = 0
await team.reset()
result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
assert result2 == result


@pytest.mark.asyncio
async def test_round_robin_group_chat_with_resume_and_reset() -> None:
Expand Down Expand Up @@ -476,6 +484,14 @@ async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
assert message == result.messages[index]
index += 1

# Test Console.
mock.reset()
agent1._count = 0 # pyright: ignore
index = 0
await team.reset()
result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
assert result2 == result


@pytest.mark.asyncio
async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch) -> None:
Expand Down Expand Up @@ -528,6 +544,14 @@ async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch)
assert message == result.messages[index]
index += 1

# Test Console.
mock.reset()
agent1._count = 0 # pyright: ignore
index = 0
await team.reset()
result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
assert result2 == result


@pytest.mark.asyncio
async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pytest.MonkeyPatch) -> None:
Expand Down Expand Up @@ -595,6 +619,13 @@ async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pyte
assert message == result.messages[index]
index += 1

# Test Console.
mock.reset()
index = 0
await team.reset()
result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
assert result2 == result


@pytest.mark.asyncio
async def test_selector_group_chat_custom_selector(monkeypatch: pytest.MonkeyPatch) -> None:
Expand Down Expand Up @@ -792,6 +823,14 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
assert message == result.messages[index]
index += 1

# Test Console
agent1._model_context.clear() # pyright: ignore
mock.reset()
index = 0
await team.reset()
result2 = await Console(team.run_stream(task="task"))
assert result2 == result


@pytest.mark.asyncio
async def test_swarm_pause_and_resume() -> None:
Expand Down
Loading