From 4eb994076ec423e424f813f72413c5ddcbd7dbda Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Wed, 23 Oct 2024 18:00:39 +0800 Subject: [PATCH] fix in chatroom --- .../chatroom_with_assistant_example.py | 7 +- .../environments/chatroom/envs/chatroom.py | 91 +++++++++++++------ 2 files changed, 66 insertions(+), 32 deletions(-) diff --git a/examples/environments/chatroom/chatroom_with_assistant_example.py b/examples/environments/chatroom/chatroom_with_assistant_example.py index e1fc2096c..cd7a4bff8 100644 --- a/examples/environments/chatroom/chatroom_with_assistant_example.py +++ b/examples/environments/chatroom/chatroom_with_assistant_example.py @@ -176,9 +176,10 @@ def main(args: argparse.Namespace) -> None: # Setup the persona of Carol carol = ChatRoomAgent( name="Carol", - sys_prompt=r"""You are Carol, and now you need to interview Bob. """ - r"""Just ask him where he is from, which school he graduated from, """ - r"""his profession, and his hobbies.""", + sys_prompt="""You are Carol, and now you need to interview Bob. """ + """Just ask him where he is from, which school he graduated from, """ + """his profession, and his hobbies. You'd better only ask one """ + """question at a time.""", model_config_name=YOUR_MODEL_CONFIGURATION_NAME, to_dist=args.use_dist, ) diff --git a/examples/environments/chatroom/envs/chatroom.py b/examples/environments/chatroom/envs/chatroom.py index 9cfe22947..f3d1ab570 100644 --- a/examples/environments/chatroom/envs/chatroom.py +++ b/examples/environments/chatroom/envs/chatroom.py @@ -39,6 +39,22 @@ """ +def format_messages(msgs: Union[Msg, List[Msg]]) -> list[dict]: + """Format the messages""" + messages = [] + if isinstance(msgs, Msg): + msgs = [msgs] + for msg in msgs: + messages.append( + { + "role": msg.role, + "name": msg.name, + "content": str(msg.content), + }, + ) + return messages + + class ChatRoomMember(BasicEnv): """A member of chatroom.""" @@ -196,8 +212,9 @@ def describe(self, agent_name: str, **kwargs: Any) -> str: members_profile.append(f"{name}: {sys_prompt}") members_profile_str = "\n\n".join(members_profile) if hasattr(self, "model"): + sys_prompt = self.children[agent_name].agent.sys_prompt desc_prompt = ( - f"""{self.children[agent_name].agent.sys_prompt}\n""" + # f"""{self.children[agent_name].agent.sys_prompt}\n""" f"""You are participating in a chatroom.\n\n""" f"""======= CHATROOM MEMBERS' PROFILE BEGIN ========\n""" f"""{members_profile_str}""" @@ -205,8 +222,11 @@ def describe(self, agent_name: str, **kwargs: Any) -> str: f"""Please describe the group members in one sentence """ f"""from {agent_name}'s perspective.""" ) - prompt = self.model.format( - Msg(name="system", role="system", content=desc_prompt), + prompt = format_messages( + [ + Msg(name="system", role="system", content=sys_prompt), + Msg(name="user", role="user", content=desc_prompt), + ], ) logger.debug(prompt) response = self.model(prompt) @@ -217,11 +237,11 @@ def describe(self, agent_name: str, **kwargs: Any) -> str: self.member_description[agent_name] = desc ann += f"\n{self.member_description[agent_name]}\n\n" ann += ( - r"""Please generate a suitable response in this work group based""" - r""" on the following chat history. When you need to mention """ - r"""someone, you can use @ to remind them. You only need to """ - rf"""output {agent_name}'s possible replies, without giving """ - r"""anyone else's replies or continuing the conversation.""" + """Please generate a suitable response in this work group based""" + """ on the following chat history. When you need to mention """ + """someone, you can use @ to remind them. You only need to """ + f"""output {agent_name}'s possible replies, without giving """ + """anyone else's replies or continuing the conversation.""" ) history = "\n\n".join( [ @@ -414,15 +434,17 @@ def _want_to_speak(self, hint: str) -> bool: f"{self.sys_prompt}\n\nYou are participating in a chatroom.\n" + hint ) - prompt = self.model.format( - Msg(name="system", role="system", content=hint), - Msg( - name="user", - role="user", - content="Based on the CHATROOM." - " Do you want to or need to speak in the chatroom now?\n" - "Return yes or no.", - ), + prompt = format_messages( + [ + Msg(name="system", role="system", content=hint), + Msg( + name="user", + role="user", + content="Based on the CHATROOM." + " Do you want to or need to speak in the chatroom now?\n" + "Return yes or no.", + ), + ], ) logger.debug(prompt) response = self.model( @@ -471,16 +493,20 @@ def reply(self, x: Msg = None) -> Msg: ) else: return Msg(name="assistant", role="assistant", content="") - system_hint = ( - f"{self.sys_prompt}\n\nYou are participating in a chatroom.\n" + user_hint = ( + # f"{self.sys_prompt}\n\n" + f"You are participating in a chatroom.\n" f"\n{room_info}\n{reply_hint}" ) - prompt = self.model.format( - Msg( - name="system", - role="system", - content=system_hint, - ), + prompt = format_messages( + [ + Msg( + name="system", + role="system", + content=self.sys_prompt, + ), + Msg(name="user", role="user", content=user_hint), + ], ) prompt[-1]["content"] = prompt[-1]["content"].strip() logger.debug(prompt) @@ -552,13 +578,20 @@ def reply(self, x: Msg = None) -> Msg: f"\n{self.name}:" ) system_hint = ( - f"{self.sys_prompt}\n\nYou are participating in a chatroom.\n" + # f"{self.sys_prompt}\n\n" + f"You are participating in a chatroom.\n" f"\n{room_info}\n{reply_hint}" ) - msg_hint = Msg(name=self.name, content=system_hint, role="system") - prompt = self.model.format( - msg_hint, + prompt = format_messages( + [ + Msg( + name=self.name, + content=self.sys_prompt, + role="system", + ), + Msg(name="user", content=system_hint, role="user"), + ], ) logger.debug(prompt) response = self.model(