Skip to content

Commit

Permalink
fix in chatroom
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyushuo committed Oct 23, 2024
1 parent a50e58c commit 4eb9940
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,10 @@ def main(args: argparse.Namespace) -> None:
# Setup the persona of Carol
carol = ChatRoomAgent(
name="Carol",
sys_prompt=r"""You are Carol, and now you need to interview Bob. """
r"""Just ask him where he is from, which school he graduated from, """
r"""his profession, and his hobbies.""",
sys_prompt="""You are Carol, and now you need to interview Bob. """
"""Just ask him where he is from, which school he graduated from, """
"""his profession, and his hobbies. You'd better only ask one """
"""question at a time.""",
model_config_name=YOUR_MODEL_CONFIGURATION_NAME,
to_dist=args.use_dist,
)
Expand Down
91 changes: 62 additions & 29 deletions examples/environments/chatroom/envs/chatroom.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,22 @@
"""


def format_messages(msgs: Union[Msg, List[Msg]]) -> list[dict]:
"""Format the messages"""
messages = []
if isinstance(msgs, Msg):
msgs = [msgs]
for msg in msgs:
messages.append(
{
"role": msg.role,
"name": msg.name,
"content": str(msg.content),
},
)
return messages


class ChatRoomMember(BasicEnv):
"""A member of chatroom."""

Expand Down Expand Up @@ -196,17 +212,21 @@ def describe(self, agent_name: str, **kwargs: Any) -> str:
members_profile.append(f"{name}: {sys_prompt}")
members_profile_str = "\n\n".join(members_profile)
if hasattr(self, "model"):
sys_prompt = self.children[agent_name].agent.sys_prompt
desc_prompt = (
f"""{self.children[agent_name].agent.sys_prompt}\n"""
# f"""{self.children[agent_name].agent.sys_prompt}\n"""
f"""You are participating in a chatroom.\n\n"""
f"""======= CHATROOM MEMBERS' PROFILE BEGIN ========\n"""
f"""{members_profile_str}"""
f"""======= CHATROOM MEMBERS' PROFILE END ========\n"""
f"""Please describe the group members in one sentence """
f"""from {agent_name}'s perspective."""
)
prompt = self.model.format(
Msg(name="system", role="system", content=desc_prompt),
prompt = format_messages(
[
Msg(name="system", role="system", content=sys_prompt),
Msg(name="user", role="user", content=desc_prompt),
],
)
logger.debug(prompt)
response = self.model(prompt)
Expand All @@ -217,11 +237,11 @@ def describe(self, agent_name: str, **kwargs: Any) -> str:
self.member_description[agent_name] = desc
ann += f"\n{self.member_description[agent_name]}\n\n"
ann += (
r"""Please generate a suitable response in this work group based"""
r""" on the following chat history. When you need to mention """
r"""someone, you can use @ to remind them. You only need to """
rf"""output {agent_name}'s possible replies, without giving """
r"""anyone else's replies or continuing the conversation."""
"""Please generate a suitable response in this work group based"""
""" on the following chat history. When you need to mention """
"""someone, you can use @ to remind them. You only need to """
f"""output {agent_name}'s possible replies, without giving """
"""anyone else's replies or continuing the conversation."""
)
history = "\n\n".join(
[
Expand Down Expand Up @@ -414,15 +434,17 @@ def _want_to_speak(self, hint: str) -> bool:
f"{self.sys_prompt}\n\nYou are participating in a chatroom.\n"
+ hint
)
prompt = self.model.format(
Msg(name="system", role="system", content=hint),
Msg(
name="user",
role="user",
content="Based on the CHATROOM."
" Do you want to or need to speak in the chatroom now?\n"
"Return yes or no.",
),
prompt = format_messages(
[
Msg(name="system", role="system", content=hint),
Msg(
name="user",
role="user",
content="Based on the CHATROOM."
" Do you want to or need to speak in the chatroom now?\n"
"Return yes or no.",
),
],
)
logger.debug(prompt)
response = self.model(
Expand Down Expand Up @@ -471,16 +493,20 @@ def reply(self, x: Msg = None) -> Msg:
)
else:
return Msg(name="assistant", role="assistant", content="")
system_hint = (
f"{self.sys_prompt}\n\nYou are participating in a chatroom.\n"
user_hint = (
# f"{self.sys_prompt}\n\n"
f"You are participating in a chatroom.\n"
f"\n{room_info}\n{reply_hint}"
)
prompt = self.model.format(
Msg(
name="system",
role="system",
content=system_hint,
),
prompt = format_messages(
[
Msg(
name="system",
role="system",
content=self.sys_prompt,
),
Msg(name="user", role="user", content=user_hint),
],
)
prompt[-1]["content"] = prompt[-1]["content"].strip()
logger.debug(prompt)
Expand Down Expand Up @@ -552,13 +578,20 @@ def reply(self, x: Msg = None) -> Msg:
f"\n{self.name}:"
)
system_hint = (
f"{self.sys_prompt}\n\nYou are participating in a chatroom.\n"
# f"{self.sys_prompt}\n\n"
f"You are participating in a chatroom.\n"
f"\n{room_info}\n{reply_hint}"
)
msg_hint = Msg(name=self.name, content=system_hint, role="system")

prompt = self.model.format(
msg_hint,
prompt = format_messages(
[
Msg(
name=self.name,
content=self.sys_prompt,
role="system",
),
Msg(name="user", content=system_hint, role="user"),
],
)
logger.debug(prompt)
response = self.model(
Expand Down

0 comments on commit 4eb9940

Please sign in to comment.