Skip to content

Commit

Permalink
Console to return last processed message (#4279)
Browse files Browse the repository at this point in the history
* Console to return last processed (#4277)

* Preserve input generator type

* Add tests

* format

---------

Co-authored-by: Jack Gerrits <[email protected]>
Co-authored-by: Eric Zhu <[email protected]>
  • Loading branch information
3 people authored Nov 20, 2024
1 parent 611666c commit 8593b7d
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import sys
import time
from typing import AsyncGenerator, List
from typing import AsyncGenerator, List, Optional, TypeVar, cast

from autogen_core.components import Image
from autogen_core.components.models import RequestUsage
Expand All @@ -18,23 +18,32 @@ def _is_output_a_tty() -> bool:
return sys.stdout.isatty()


T = TypeVar("T", bound=TaskResult | Response)


async def Console(
stream: AsyncGenerator[AgentMessage | TaskResult, None] | AsyncGenerator[AgentMessage | Response, None],
stream: AsyncGenerator[AgentMessage | T, None],
*,
no_inline_images: bool = False,
) -> None:
"""Consume the stream from :meth:`~autogen_agentchat.base.Team.run_stream`
) -> T:
"""
Consume the stream from :meth:`~autogen_agentchat.base.Team.run_stream`
or :meth:`~autogen_agentchat.base.ChatAgent.on_messages_stream`
and print the messages to the console.
print the messages to the console and return the last processed TaskResult or Response.
Args:
stream (AsyncGenerator[AgentMessage | TaskResult, None] | AsyncGenerator[AgentMessage | Response, None]): Stream to render
no_inline_images (bool, optional): If terminal is iTerm2 will render images inline. Use this to disable this behavior. Defaults to False.
"""
Returns:
last_processed: The last processed TaskResult or Response.
"""
render_image_iterm = _is_running_in_iterm() and _is_output_a_tty() and not no_inline_images
start_time = time.time()
total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)

last_processed: Optional[T] = None

async for message in stream:
if isinstance(message, TaskResult):
duration = time.time() - start_time
Expand All @@ -47,6 +56,9 @@ async def Console(
f"Duration: {duration:.2f} seconds\n"
)
sys.stdout.write(output)
# mypy ignore
last_processed = message # type: ignore

elif isinstance(message, Response):
duration = time.time() - start_time

Expand All @@ -71,14 +83,24 @@ async def Console(
f"Duration: {duration:.2f} seconds\n"
)
sys.stdout.write(output)
# mypy ignore
last_processed = message # type: ignore

else:
# Cast required for mypy to be happy
message = cast(AgentMessage, message) # type: ignore
output = f"{'-' * 10} {message.source} {'-' * 10}\n{_message_to_str(message, render_image_iterm=render_image_iterm)}\n"
if message.models_usage:
output += f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]\n"
total_usage.completion_tokens += message.models_usage.completion_tokens
total_usage.prompt_tokens += message.models_usage.prompt_tokens
sys.stdout.write(output)

if last_processed is None:
raise ValueError("No TaskResult or Response was processed.")

return last_processed


# iTerm2 image rendering protocol: https://iterm2.com/documentation-images.html
def _image_to_iterm(image: Image) -> str:
Expand Down
41 changes: 40 additions & 1 deletion python/packages/autogen-agentchat/tests/test_group_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
ToolCallMessage,
ToolCallResultMessage,
)
from autogen_agentchat.task import HandoffTermination, MaxMessageTermination, TextMentionTermination
from autogen_agentchat.task import Console, HandoffTermination, MaxMessageTermination, TextMentionTermination
from autogen_agentchat.teams import (
RoundRobinGroupChat,
SelectorGroupChat,
Expand Down Expand Up @@ -315,6 +315,14 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
assert message == result.messages[index]
index += 1

# Test Console.
tool_use_agent._model_context.clear() # pyright: ignore
mock.reset()
index = 0
await team.reset()
result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
assert result2 == result


@pytest.mark.asyncio
async def test_round_robin_group_chat_with_resume_and_reset() -> None:
Expand Down Expand Up @@ -476,6 +484,14 @@ async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
assert message == result.messages[index]
index += 1

# Test Console.
mock.reset()
agent1._count = 0 # pyright: ignore
index = 0
await team.reset()
result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
assert result2 == result


@pytest.mark.asyncio
async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch) -> None:
Expand Down Expand Up @@ -528,6 +544,14 @@ async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch)
assert message == result.messages[index]
index += 1

# Test Console.
mock.reset()
agent1._count = 0 # pyright: ignore
index = 0
await team.reset()
result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
assert result2 == result


@pytest.mark.asyncio
async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pytest.MonkeyPatch) -> None:
Expand Down Expand Up @@ -595,6 +619,13 @@ async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pyte
assert message == result.messages[index]
index += 1

# Test Console.
mock.reset()
index = 0
await team.reset()
result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
assert result2 == result


@pytest.mark.asyncio
async def test_selector_group_chat_custom_selector(monkeypatch: pytest.MonkeyPatch) -> None:
Expand Down Expand Up @@ -792,6 +823,14 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
assert message == result.messages[index]
index += 1

# Test Console
agent1._model_context.clear() # pyright: ignore
mock.reset()
index = 0
await team.reset()
result2 = await Console(team.run_stream(task="task"))
assert result2 == result


@pytest.mark.asyncio
async def test_swarm_pause_and_resume() -> None:
Expand Down

0 comments on commit 8593b7d

Please sign in to comment.