Skip to content

Commit

Permalink
fix: adds test to other hook types
Browse files Browse the repository at this point in the history
  • Loading branch information
robraux committed Sep 30, 2024
1 parent 4ddac48 commit d6bbc1f
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 4 deletions.
2 changes: 1 addition & 1 deletion autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2905,7 +2905,7 @@ async def a_process_last_received_message(self, messages: List[Dict]) -> List[Di
"""

# If any required condition is not met, return the original message list.
hook_list = self.hook_lists["process_last_received_message"]
hook_list = self.hook_lists["a_process_last_received_message"]
if len(hook_list) == 0:
return messages # No hooks registered.
if messages is None:
Expand Down
175 changes: 172 additions & 3 deletions test/agentchat/test_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sys
import time
import unittest
from typing import Any, Callable, Dict, Literal
from typing import Any, Callable, Dict, List, Literal
from unittest.mock import MagicMock

import pytest
Expand Down Expand Up @@ -1230,6 +1230,46 @@ def my_summary(sender, recipient, summary_args):
print(chat_res_play.summary)


def test_register_hook_async_sync():
agent = ConversableAgent("test_agent", llm_config=False)

# Synchronous hook for synchronous method
def sync_hook():
pass

agent.register_hook("process_all_messages_before_reply", sync_hook)
assert sync_hook in agent.hook_lists["process_all_messages_before_reply"]

# Asynchronous hook for asynchronous method
async def async_hook():
pass

agent.register_hook("a_process_all_messages_before_reply", async_hook)
assert async_hook in agent.hook_lists["a_process_all_messages_before_reply"]

# Synchronous hook for asynchronous method (should raise a warning)
with pytest.warns(
UserWarning, match="Hook 'sync_hook' is synchronous, but it's being registered in a asynchronous context"
):
agent.register_hook("a_process_all_messages_before_reply", sync_hook)
assert sync_hook in agent.hook_lists["a_process_all_messages_before_reply"]

# Asynchronous hook for synchronous method (should raise a warning)
with pytest.warns(
UserWarning, match="Hook 'async_hook' is asynchronous, but it's being registered in a synchronous context"
):
agent.register_hook("process_all_messages_before_reply", async_hook)
assert async_hook in agent.hook_lists["process_all_messages_before_reply"]

# Attempt to register the same hook twice (should raise an AssertionError)
with pytest.raises(AssertionError, match=r"<function.*sync_hook.*> is already registered as a hook"):
agent.register_hook("process_all_messages_before_reply", sync_hook)

# Attempt to register a hook for a non-existent method (should raise an AssertionError)
with pytest.raises(AssertionError, match="non_existent_method is not a hookable method"):
agent.register_hook("non_existent_method", sync_hook)


def test_process_before_send():
print_mock = unittest.mock.MagicMock()

Expand All @@ -1251,13 +1291,13 @@ def send_to_frontend(sender, message, recipient, silent):


@pytest.mark.asyncio
async def test_a_process_before_send_async():
async def test_a_process_before_send():
print_mock = unittest.mock.MagicMock()

# Updated to include sender parameter
async def a_send_to_frontend(sender, message, recipient, silent):
# Simulating an async operation with asyncio.sleep
await asyncio.sleep(1)
await asyncio.sleep(0.5)

assert sender.name == "dummy_agent_1", "Sender is not the expected agent"
if not silent:
Expand All @@ -1274,6 +1314,135 @@ async def a_send_to_frontend(sender, message, recipient, silent):
print_mock.assert_called_once_with(message="hello")


def test_process_last_received_message():

# Create a mock function to be used as a hook
def expand_message(message):
return message + " [Expanded]"

dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
dummy_agent_1.register_hook("process_last_received_message", expand_message)

# Normal message
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there"},
{"role": "user", "content": "How are you?"},
]

processed_messages = messages.copy()
dummy_agent_1.generate_reply(messages=processed_messages, sender=None)
assert processed_messages[-2]["content"] == "Hi there"
assert processed_messages[-1]["content"] == "How are you? [Expanded]"


@pytest.mark.asyncio
async def test_a_process_last_received_message():

# Create a mock function to be used as a hook
async def expand_message(message):
await asyncio.sleep(0.5)
return message + " [Expanded]"

dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
dummy_agent_1.register_hook("a_process_last_received_message", expand_message)

# Normal message
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there"},
{"role": "user", "content": "How are you?"},
]

processed_messages = messages.copy()
await dummy_agent_1.a_generate_reply(messages=processed_messages, sender=None)
assert processed_messages[-2]["content"] == "Hi there"
assert processed_messages[-1]["content"] == "How are you? [Expanded]"


def test_process_all_messages_before_reply():

messages = [
{"role": "user", "content": "hello"},
{"function_call": {"name": "add_num", "arguments": '{ "num_to_be_added": 5 }'}, "role": "assistant"},
]

def _transform_messages(transformed_messages: List[Dict]) -> List[Dict]:
# ensure we are looking at all messages
assert len(transformed_messages) == len(messages), "Message length does not match"

# deep copy to ensure hooks applied comprehensively
post_transformed_messages = copy.deepcopy(transformed_messages)

# directly modify the message content for the function call (additional value)
post_transformed_messages[1]["function_call"]["arguments"] = '{ "num_to_be_added": 6 }'

return post_transformed_messages

def add_num(num_to_be_added):
given_num = 10
return num_to_be_added + given_num

dummy_agent_2 = ConversableAgent(
name="user_proxy", llm_config=False, human_input_mode="TERMINATE", function_map={"add_num": add_num}
)

# Baseline check before hook is executed
assert (
dummy_agent_2.generate_reply(messages=messages, sender=None)["content"] == "15"
), "generate_reply not working when sender is None"

dummy_agent_2.register_hook("process_all_messages_before_reply", _transform_messages)

# Hook is applied, updating the message content for the function call
assert (
dummy_agent_2.generate_reply(messages=messages, sender=None)["content"] == "16"
), "generate_reply not working when sender is None"


@pytest.mark.asyncio
async def test_a_process_all_messages_before_reply():

messages = [
{"role": "user", "content": "hello"},
{"function_call": {"name": "add_num", "arguments": '{ "num_to_be_added": 5 }'}, "role": "assistant"},
]

async def a_transform_messages(transformed_messages: List[Dict]) -> List[Dict]:

# ensure we are looking at all messages
assert len(transformed_messages) == len(messages), "Message length does not match"

# Simulating an async operation with asyncio.sleep
await asyncio.sleep(0.5)

# deep copy to ensure hooks applied comprehensively
post_transformed_messages = copy.deepcopy(transformed_messages)

# directly modify the message content for the function call (additional value)
post_transformed_messages[1]["function_call"]["arguments"] = '{ "num_to_be_added": 6 }'

return post_transformed_messages

def add_num(num_to_be_added):
given_num = 10
return num_to_be_added + given_num

dummy_agent_2 = ConversableAgent(
name="user_proxy", llm_config=False, human_input_mode="TERMINATE", function_map={"add_num": add_num}
)

# Baseline check before hook is executed
response = await dummy_agent_2.a_generate_reply(messages=messages, sender=None)
assert response["content"] == "15", "generate_reply not working when sender is None"

dummy_agent_2.register_hook("a_process_all_messages_before_reply", a_transform_messages)

# Hook is applied, updating the message content for the function call
response = await dummy_agent_2.a_generate_reply(messages=messages, sender=None)
assert response["content"] == "16", "generate_reply not working when sender is None"


def test_messages_with_carryover():
agent1 = autogen.ConversableAgent(
"alice",
Expand Down

0 comments on commit d6bbc1f

Please sign in to comment.