Skip to content

Commit

Permalink
fix in chatroom
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyushuo committed Oct 23, 2024
1 parent d29a4ac commit a50e58c
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 43 deletions.
9 changes: 7 additions & 2 deletions examples/environments/chatroom/chatroom_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,12 @@ def main(args: argparse.Namespace) -> None:
),
role="system",
)
r = ChatRoom(name="chat", announcement=ann, model_config_name=YOUR_MODEL_CONFIGURATION_NAME, to_dist=args.use_dist)
r = ChatRoom(
name="chat",
announcement=ann,
model_config_name=YOUR_MODEL_CONFIGURATION_NAME,
to_dist=args.use_dist,
)

# Setup the persona of Alice, Bob and Carol
alice = ChatRoomAgent( # Game Art Designer
Expand Down Expand Up @@ -90,7 +95,7 @@ def main(args: argparse.Namespace) -> None:
r.chat_freely(
delay=10,
interval=10,
max_round=3, # 10,
max_round=10,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ def main(args: argparse.Namespace) -> None:
name="Bob",
sys_prompt=r"""You are Bob's chat room assistant and he is """
r"""currently unable to reply to messages. Please generate a """
r"""suitable response based on the following chat history. """
r"""The content you reply to must be based on the chat history. """
r"""Please refuse to reply to questions that are beyond the scope """
r"""of the chat history.""",
r"""suitable response based on the following chat history without """
r"""reasoning. The content you reply to must be based on the chat """
r"""history. Please refuse to reply to questions that are beyond """
r"""the scope of the chat history.""",
model_config_name=YOUR_MODEL_CONFIGURATION_NAME,
to_dist=args.use_dist,
timeout=args.timeout,
Expand Down Expand Up @@ -178,9 +178,7 @@ def main(args: argparse.Namespace) -> None:
name="Carol",
sys_prompt=r"""You are Carol, and now you need to interview Bob. """
r"""Just ask him where he is from, which school he graduated from, """
r"""his profession, and his hobbies. At the end of the interview, """
r"""please output a reply containing Goodbye to indicate the end """
r"""of the conversation.""",
r"""his profession, and his hobbies.""",
model_config_name=YOUR_MODEL_CONFIGURATION_NAME,
to_dist=args.use_dist,
)
Expand Down
89 changes: 55 additions & 34 deletions examples/environments/chatroom/envs/chatroom.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,16 @@ def get_history(self, agent_name: str) -> List[Msg]:
history_idx = self.children[agent_name].history_idx
return deepcopy(self.history[history_idx:])

def get_history_length(self, agent_name: str) -> int:
"""Get the length of the history of the agent."""
if agent_name not in self.children:
return 0
if self.all_history:
history_idx = 0
else:
history_idx = self.children[agent_name].history_idx
return len(self.history) - history_idx

def describe(self, agent_name: str, **kwargs: Any) -> str:
"""Get the description of the chatroom."""
ann = (
Expand All @@ -183,18 +193,20 @@ def describe(self, agent_name: str, **kwargs: Any) -> str:
members_profile = []
for name, member in self.children.items():
sys_prompt = member.agent.sys_prompt
members_profile.append(f'{name}: {sys_prompt}')
members_profile.append(f"{name}: {sys_prompt}")
members_profile_str = "\n\n".join(members_profile)
if hasattr(self, 'model'):
if hasattr(self, "model"):
desc_prompt = (
f"""{self.children[agent_name].agent.sys_prompt}\nYou are participating in a chatroom.\n\n"""
f"""{self.children[agent_name].agent.sys_prompt}\n"""
f"""You are participating in a chatroom.\n\n"""
f"""======= CHATROOM MEMBERS' PROFILE BEGIN ========\n"""
f"""{members_profile_str}"""
f"""======= CHATROOM MEMBERS' PROFILE END ========\n"""
f"""Please describe the group members in one sentence from {agent_name}'s perspective."""
f"""Please describe the group members in one sentence """
f"""from {agent_name}'s perspective."""
)
prompt = self.model.format(
Msg(name="system", role="system", content=desc_prompt)
Msg(name="system", role="system", content=desc_prompt),
)
logger.debug(prompt)
response = self.model(prompt)
Expand All @@ -205,11 +217,11 @@ def describe(self, agent_name: str, **kwargs: Any) -> str:
self.member_description[agent_name] = desc
ann += f"\n{self.member_description[agent_name]}\n\n"
ann += (
r"""Please generate a suitable response in this work group based """
r"""on the following chat history. When you need to mention """
r"""someone, you can use @ to remind them. You only need to output """
rf"""{agent_name}'s possible replies, without giving anyone else's replies """
r"""or continuing the conversation."""
r"""Please generate a suitable response in this work group based"""
r""" on the following chat history. When you need to mention """
r"""someone, you can use @ to remind them. You only need to """
rf"""output {agent_name}'s possible replies, without giving """
r"""anyone else's replies or continuing the conversation."""
)
history = "\n\n".join(
[
Expand Down Expand Up @@ -293,7 +305,8 @@ def chat_freely(
) -> None:
"""Let all agents to chat freely without any preset order"""
tasks = []
agent_name_list = agent_name_list or list(self.children.keys())
if agent_name_list is None:
agent_name_list = list(self.children.keys())
for agent_name in agent_name_list:
task = threading.Thread(
target=self.children[agent_name].chat_freely,
Expand All @@ -314,7 +327,7 @@ def chat_in_sequence(self, agent_name_order: List[str] = None) -> None:
Args:
agent_name_order (`List[str]`): Order of speakers' names.
"""
agent_name_list = agent_name_list or list(self.children.keys())
agent_name_order = agent_name_order or list(self.children.keys())
for agent_name in agent_name_order:
self.children[agent_name].chat()

Expand Down Expand Up @@ -373,6 +386,7 @@ def add_mentioned_message(self, msg: Msg) -> None:
def join(self, room: ChatRoom) -> bool:
"""Join a room"""
self.room = room
self.room_history_length = self.room.get_history_length(self.name)
return room.join(self)

def _is_mentioned(self) -> bool:
Expand All @@ -390,21 +404,27 @@ def _generate_mentioned_prompt(self) -> Tuple[bool, str]:
for msg in self.mentioned_messages
],
)
self.mentioned_messages = []
return True, hint
return False, ""

def _want_to_speak(self, hint: str) -> bool:
"""Check whether the agent want to speak currently"""
hint = (
f"{self.sys_prompt}\n\nYou are participating in a chatroom.\n"
+ hint
)
prompt = self.model.format(
Msg(name="system", role="system", content=hint),
Msg(
name="user",
role="user",
content="Based on the CHATROOM."
" Do you want to speak in the chatroom now?\n"
"Speak yes or no.",
" Do you want to or need to speak in the chatroom now?\n"
"Return yes or no.",
),
)
logger.debug(prompt)
response = self.model(
prompt,
max_retries=3,
Expand All @@ -428,21 +448,27 @@ def speak(

def reply(self, x: Msg = None) -> Msg:
"""Generate reply to chat room"""
room_history_length = len(self.room.history)
room_history_length = self.room.get_history_length(self.name)
if room_history_length != self.room_history_length:
self.room_history_length = room_history_length
self.room_slient_count = 0
else:
self.room_slient_count += 1
room_info = self.room.describe(self.name)
reply_hint = ''
reply_hint = ""
mentioned, mentioned_hint = self._generate_mentioned_prompt()
if mentioned:
reply_hint = f'{mentioned_hint}\n{self.name}:'
reply_hint = f"{mentioned_hint}\n{self.name}:"
else:
# decide whether to speak
if self.room_history_length <= 3 or (self.room_slient_count <= 2 and self._want_to_speak(room_info)):
reply_hint = f"Please generate a response based on the CHATROOM.\n{self.name}:"
if self.room_history_length <= 3 or (
self.room_slient_count <= 2 and self._want_to_speak(room_info)
):
reply_hint = (
f"Please generate a response based on the"
f" CHATROOM. You need only generate response without "
f"reasoning.\n{self.name}:"
)
else:
return Msg(name="assistant", role="assistant", content="")
system_hint = (
Expand All @@ -454,7 +480,7 @@ def reply(self, x: Msg = None) -> Msg:
name="system",
role="system",
content=system_hint,
)
),
)
prompt[-1]["content"] = prompt[-1]["content"].strip()
logger.debug(prompt)
Expand All @@ -466,6 +492,7 @@ def reply(self, x: Msg = None) -> Msg:
msg = Msg(name=self.name, content=response, role="assistant")
if response:
self.speak(msg)
self.room_history_length = self.room.get_history_length(self.name)
return msg


Expand Down Expand Up @@ -510,35 +537,28 @@ def reply(self, x: Msg = None) -> Msg:
if content is not None: # user input
response = content
else: # assistant reply
room_history_length = len(self.room.history)
room_history_length = self.room.get_history_length(self.name)
if room_history_length == self.room_history_length:
return Msg(name="assistant", role="assistant", content="")
self.room_history_length = room_history_length
# msg_hint = self._generate_mentioned_prompt()
room_info = self.room.describe(self.name)
reply_hint = ''
reply_hint = ""
mentioned, mentioned_hint = self._generate_mentioned_prompt()
if mentioned:
reply_hint = f'{mentioned_hint}\n{self.name}:'
reply_hint = f"{mentioned_hint}\n{self.name}:"
else:
# decide whether to speak
if self.room_history_length <= 3 or (self.room_slient_count <= 2 and self._want_to_speak(room_info)):
reply_hint = f"Please generate a response based on the CHATROOM.\n{self.name}:"
else:
return Msg(name="assistant", role="assistant", content="")
reply_hint = (
f"Please generate a response based on the CHATROOM."
f"\n{self.name}:"
)
system_hint = (
f"{self.sys_prompt}\n\nYou are participating in a chatroom.\n"
f"\n{room_info}\n{reply_hint}"
)
msg_hint = Msg(name=self.name, content=system_hint, role="system")

self_msg = Msg(name=self.name, content="", role="assistant")

# history = self.room.get_history(self.agent_id)
prompt = self.model.format(
msg_hint,
# history,
self_msg,
)
logger.debug(prompt)
response = self.model(
Expand All @@ -550,4 +570,5 @@ def reply(self, x: Msg = None) -> Msg:
response = "[auto reply] " + response
msg = Msg(name=self.name, content=response, role="user")
self.speak(msg)
self.room_history_length = self.room.get_history_length(self.name)
return msg

0 comments on commit a50e58c

Please sign in to comment.