Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

async user hook support added #3583

Merged
merged 2 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 102 additions & 5 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import re
import warnings
from collections import defaultdict
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union

from openai import BadRequestError

Expand Down Expand Up @@ -247,10 +247,13 @@ def __init__(

# Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration.
# New hookable methods should be added to this list as required to support new agent capabilities.
self.hook_lists: Dict[str, List[Callable]] = {
self.hook_lists: Dict[str, List[Union[Callable, Callable[..., Coroutine]]]] = {
"process_last_received_message": [],
"a_process_last_received_message": [],
"process_all_messages_before_reply": [],
"a_process_all_messages_before_reply": [],
"process_message_before_send": [],
"a_process_message_before_send": [],
}

def _validate_llm_config(self, llm_config):
Expand Down Expand Up @@ -680,11 +683,24 @@ def _process_message_before_send(
"""Process the message before sending it to the recipient."""
hook_list = self.hook_lists["process_message_before_send"]
for hook in hook_list:
if inspect.iscoroutinefunction(hook):
continue
message = hook(
sender=self, message=message, recipient=recipient, silent=ConversableAgent._is_silent(self, silent)
)
return message

async def _a_process_message_before_send(
self, message: Union[Dict, str], recipient: Agent, silent: bool
) -> Union[Dict, str]:
"""(async) Process the message before sending it to the recipient."""
hook_list = self.hook_lists["a_process_message_before_send"]
for hook in hook_list:
if not inspect.iscoroutinefunction(hook):
continue
message = await hook(sender=self, message=message, recipient=recipient, silent=silent)
return message

def send(
self,
message: Union[Dict, str],
Expand Down Expand Up @@ -774,7 +790,9 @@ async def a_send(
Raises:
ValueError: if the message can't be converted into a valid ChatCompletion message.
"""
message = self._process_message_before_send(message, recipient, ConversableAgent._is_silent(self, silent))
message = await self._a_process_message_before_send(
message, recipient, ConversableAgent._is_silent(self, silent)
)
# When the agent composes and sends the message, the role of the message is "assistant"
# unless it's "function".
valid = self._append_oai_message(message, "assistant", recipient, is_sending=True)
Expand Down Expand Up @@ -2104,11 +2122,11 @@ async def a_generate_reply(

# Call the hookable method that gives registered hooks a chance to process all messages.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_all_messages_before_reply(messages)
messages = await self.a_process_all_messages_before_reply(messages)

# Call the hookable method that gives registered hooks a chance to process the last message.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_last_received_message(messages)
messages = await self.a_process_last_received_message(messages)

for reply_func_tuple in self._reply_func_list:
reply_func = reply_func_tuple["reply_func"]
Expand Down Expand Up @@ -2786,6 +2804,19 @@ def register_hook(self, hookable_method: str, hook: Callable):
assert hookable_method in self.hook_lists, f"{hookable_method} is not a hookable method."
hook_list = self.hook_lists[hookable_method]
assert hook not in hook_list, f"{hook} is already registered as a hook."

# async hookable checks
expected_async = hookable_method.startswith("a_")
hook_is_async = inspect.iscoroutinefunction(hook)
if expected_async != hook_is_async:
context_type = "asynchronous" if expected_async else "synchronous"
warnings.warn(
f"Hook '{hook.__name__}' is {'asynchronous' if hook_is_async else 'synchronous'}, "
f"but it's being registered in a {context_type} context ('{hookable_method}'). "
"Ensure the hook matches the expected execution context.",
UserWarning,
)

hook_list.append(hook)

def process_all_messages_before_reply(self, messages: List[Dict]) -> List[Dict]:
Expand All @@ -2800,9 +2831,28 @@ def process_all_messages_before_reply(self, messages: List[Dict]) -> List[Dict]:
# Call each hook (in order of registration) to process the messages.
processed_messages = messages
for hook in hook_list:
if inspect.iscoroutinefunction(hook):
continue
processed_messages = hook(processed_messages)
return processed_messages

async def a_process_all_messages_before_reply(self, messages: List[Dict]) -> List[Dict]:
"""
Calls any registered capability hooks to process all messages, potentially modifying the messages.
"""
hook_list = self.hook_lists["a_process_all_messages_before_reply"]
# If no hooks are registered, or if there are no messages to process, return the original message list.
if len(hook_list) == 0 or messages is None:
return messages

# Call each hook (in order of registration) to process the messages.
processed_messages = messages
for hook in hook_list:
if not inspect.iscoroutinefunction(hook):
continue
processed_messages = await hook(processed_messages)
return processed_messages

def process_last_received_message(self, messages: List[Dict]) -> List[Dict]:
"""
Calls any registered capability hooks to use and potentially modify the text of the last message,
Expand Down Expand Up @@ -2836,6 +2886,8 @@ def process_last_received_message(self, messages: List[Dict]) -> List[Dict]:
# Call each hook (in order of registration) to process the user's message.
processed_user_content = user_content
for hook in hook_list:
if inspect.iscoroutinefunction(hook):
continue
processed_user_content = hook(processed_user_content)

if processed_user_content == user_content:
Expand All @@ -2846,6 +2898,51 @@ def process_last_received_message(self, messages: List[Dict]) -> List[Dict]:
messages[-1]["content"] = processed_user_content
return messages

async def a_process_last_received_message(self, messages: List[Dict]) -> List[Dict]:
"""
Calls any registered capability hooks to use and potentially modify the text of the last message,
as long as the last message is not a function call or exit command.
"""

# If any required condition is not met, return the original message list.
hook_list = self.hook_lists["a_process_last_received_message"]
if len(hook_list) == 0:
return messages # No hooks registered.
if messages is None:
return None # No message to process.
if len(messages) == 0:
return messages # No message to process.
last_message = messages[-1]
if "function_call" in last_message:
return messages # Last message is a function call.
if "context" in last_message:
return messages # Last message contains a context key.
if "content" not in last_message:
return messages # Last message has no content.

user_content = last_message["content"]
if not isinstance(user_content, str) and not isinstance(user_content, list):
# if the user_content is a string, it is for regular LLM
# if the user_content is a list, it should follow the multimodal LMM format.
return messages
if user_content == "exit":
return messages # Last message is an exit command.

# Call each hook (in order of registration) to process the user's message.
processed_user_content = user_content
for hook in hook_list:
if not inspect.iscoroutinefunction(hook):
continue
processed_user_content = await hook(processed_user_content)

if processed_user_content == user_content:
return messages # No hooks actually modified the user's message.

# Replace the last user message with the expanded one.
messages = messages.copy()
messages[-1]["content"] = processed_user_content
return messages

def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) -> None:
"""Print the usage summary."""
iostream = IOStream.get_default()
Expand Down
195 changes: 194 additions & 1 deletion test/agentchat/test_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sys
import time
import unittest
from typing import Any, Callable, Dict, Literal
from typing import Any, Callable, Dict, List, Literal
from unittest.mock import MagicMock

import pytest
Expand Down Expand Up @@ -1230,6 +1230,46 @@ def my_summary(sender, recipient, summary_args):
print(chat_res_play.summary)


def test_register_hook_async_sync():
agent = ConversableAgent("test_agent", llm_config=False)

# Synchronous hook for synchronous method
def sync_hook():
pass

agent.register_hook("process_all_messages_before_reply", sync_hook)
assert sync_hook in agent.hook_lists["process_all_messages_before_reply"]

# Asynchronous hook for asynchronous method
async def async_hook():
pass

agent.register_hook("a_process_all_messages_before_reply", async_hook)
assert async_hook in agent.hook_lists["a_process_all_messages_before_reply"]

# Synchronous hook for asynchronous method (should raise a warning)
with pytest.warns(
UserWarning, match="Hook 'sync_hook' is synchronous, but it's being registered in a asynchronous context"
):
agent.register_hook("a_process_all_messages_before_reply", sync_hook)
assert sync_hook in agent.hook_lists["a_process_all_messages_before_reply"]

# Asynchronous hook for synchronous method (should raise a warning)
with pytest.warns(
UserWarning, match="Hook 'async_hook' is asynchronous, but it's being registered in a synchronous context"
):
agent.register_hook("process_all_messages_before_reply", async_hook)
assert async_hook in agent.hook_lists["process_all_messages_before_reply"]

# Attempt to register the same hook twice (should raise an AssertionError)
with pytest.raises(AssertionError, match=r"<function.*sync_hook.*> is already registered as a hook"):
agent.register_hook("process_all_messages_before_reply", sync_hook)

# Attempt to register a hook for a non-existent method (should raise an AssertionError)
with pytest.raises(AssertionError, match="non_existent_method is not a hookable method"):
agent.register_hook("non_existent_method", sync_hook)


def test_process_before_send():
print_mock = unittest.mock.MagicMock()

Expand All @@ -1250,6 +1290,159 @@ def send_to_frontend(sender, message, recipient, silent):
print_mock.assert_called_once_with(message="hello")


@pytest.mark.asyncio
async def test_a_process_before_send():
print_mock = unittest.mock.MagicMock()

# Updated to include sender parameter
async def a_send_to_frontend(sender, message, recipient, silent):
# Simulating an async operation with asyncio.sleep
await asyncio.sleep(0.5)

assert sender.name == "dummy_agent_1", "Sender is not the expected agent"
if not silent:
print(f"Message sent from {sender.name} to {recipient.name}: {message}")
print_mock(message=message)
return message

dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
dummy_agent_2 = ConversableAgent(name="dummy_agent_2", llm_config=False, human_input_mode="NEVER")
dummy_agent_1.register_hook("a_process_message_before_send", a_send_to_frontend)
await dummy_agent_1.a_send("hello", dummy_agent_2)
print_mock.assert_called_once_with(message="hello")
dummy_agent_1.send("silent hello", dummy_agent_2, silent=True)
print_mock.assert_called_once_with(message="hello")


def test_process_last_received_message():

# Create a mock function to be used as a hook
def expand_message(message):
return message + " [Expanded]"

dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
dummy_agent_1.register_hook("process_last_received_message", expand_message)

# Normal message
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there"},
{"role": "user", "content": "How are you?"},
]

processed_messages = messages.copy()
dummy_agent_1.generate_reply(messages=processed_messages, sender=None)
assert processed_messages[-2]["content"] == "Hi there"
assert processed_messages[-1]["content"] == "How are you? [Expanded]"


@pytest.mark.asyncio
async def test_a_process_last_received_message():

# Create a mock function to be used as a hook
async def expand_message(message):
await asyncio.sleep(0.5)
return message + " [Expanded]"

dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
dummy_agent_1.register_hook("a_process_last_received_message", expand_message)

# Normal message
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there"},
{"role": "user", "content": "How are you?"},
]

processed_messages = messages.copy()
await dummy_agent_1.a_generate_reply(messages=processed_messages, sender=None)
assert processed_messages[-2]["content"] == "Hi there"
assert processed_messages[-1]["content"] == "How are you? [Expanded]"


def test_process_all_messages_before_reply():

messages = [
{"role": "user", "content": "hello"},
{"function_call": {"name": "add_num", "arguments": '{ "num_to_be_added": 5 }'}, "role": "assistant"},
]

def _transform_messages(transformed_messages: List[Dict]) -> List[Dict]:
# ensure we are looking at all messages
assert len(transformed_messages) == len(messages), "Message length does not match"

# deep copy to ensure hooks applied comprehensively
post_transformed_messages = copy.deepcopy(transformed_messages)

# directly modify the message content for the function call (additional value)
post_transformed_messages[1]["function_call"]["arguments"] = '{ "num_to_be_added": 6 }'

return post_transformed_messages

def add_num(num_to_be_added):
given_num = 10
return num_to_be_added + given_num

dummy_agent_2 = ConversableAgent(
name="user_proxy", llm_config=False, human_input_mode="TERMINATE", function_map={"add_num": add_num}
)

# Baseline check before hook is executed
assert (
dummy_agent_2.generate_reply(messages=messages, sender=None)["content"] == "15"
), "generate_reply not working when sender is None"

dummy_agent_2.register_hook("process_all_messages_before_reply", _transform_messages)

# Hook is applied, updating the message content for the function call
assert (
dummy_agent_2.generate_reply(messages=messages, sender=None)["content"] == "16"
), "generate_reply not working when sender is None"


@pytest.mark.asyncio
async def test_a_process_all_messages_before_reply():

messages = [
{"role": "user", "content": "hello"},
{"function_call": {"name": "add_num", "arguments": '{ "num_to_be_added": 5 }'}, "role": "assistant"},
]

async def a_transform_messages(transformed_messages: List[Dict]) -> List[Dict]:

# ensure we are looking at all messages
assert len(transformed_messages) == len(messages), "Message length does not match"

# Simulating an async operation with asyncio.sleep
await asyncio.sleep(0.5)

# deep copy to ensure hooks applied comprehensively
post_transformed_messages = copy.deepcopy(transformed_messages)

# directly modify the message content for the function call (additional value)
post_transformed_messages[1]["function_call"]["arguments"] = '{ "num_to_be_added": 6 }'

return post_transformed_messages

def add_num(num_to_be_added):
given_num = 10
return num_to_be_added + given_num

dummy_agent_2 = ConversableAgent(
name="user_proxy", llm_config=False, human_input_mode="TERMINATE", function_map={"add_num": add_num}
)

# Baseline check before hook is executed
response = await dummy_agent_2.a_generate_reply(messages=messages, sender=None)
assert response["content"] == "15", "generate_reply not working when sender is None"

dummy_agent_2.register_hook("a_process_all_messages_before_reply", a_transform_messages)

# Hook is applied, updating the message content for the function call
response = await dummy_agent_2.a_generate_reply(messages=messages, sender=None)
assert response["content"] == "16", "generate_reply not working when sender is None"


def test_messages_with_carryover():
agent1 = autogen.ConversableAgent(
"alice",
Expand Down
Loading