Skip to content

Commit

Permalink
fix chatroom bug
Browse files Browse the repository at this point in the history
  • Loading branch information
pan-x-c committed Oct 14, 2024
1 parent df680e8 commit ca1085a
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions tests/environment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def reply(self, x: Msg = None) -> Msg:
self.event_list.append(event)
return Msg(name=self.name, content="", role="assistant")
else:
history = self.room.get_history(self.agent_id)
history = self.room.get_history(self.name)
msg = Msg(name=self.name, content=len(history), role="assistant")
self.room.speak(msg)
return msg
Expand Down Expand Up @@ -408,7 +408,7 @@ def __call__(self, env: Env, event: Event) -> None:
self.assertEqual(master.event_list[-1].name, "speak")
self.assertEqual(master.event_list[-1].args["message"], r1)
self.assertEqual(master.event_list[-2].name, "get_history")
self.assertEqual(master.event_list[-2].args["agent_id"], a1.agent_id)
self.assertEqual(master.event_list[-2].args["agent_name"], a1.name)
self.assertEqual(r1.content, 0)

a2 = AgentWithChatRoom("a2")
Expand All @@ -419,12 +419,12 @@ def __call__(self, env: Env, event: Event) -> None:
self.assertEqual(master.event_list[-1].name, "speak")
self.assertEqual(master.event_list[-1].args["message"], r2)
self.assertEqual(master.event_list[-2].name, "get_history")
self.assertEqual(master.event_list[-2].args["agent_id"], a2.agent_id)
self.assertEqual(master.event_list[-2].args["agent_name"], a2.name)
self.assertEqual(r2.content, 0)

# test history_idx
self.assertEqual(r[a1.agent_id].history_idx, 0)
self.assertEqual(r[a2.agent_id].history_idx, 1)
self.assertEqual(r[a1.name].history_idx, 0)
self.assertEqual(r[a2.name].history_idx, 1)


class AgentWithMutableEnv(AgentBase):
Expand Down Expand Up @@ -564,28 +564,28 @@ def __call__(self, env: Env, event: Event) -> None:
self.assertEqual(event.args["message"].content, r1.content)
event = master.get_event(-2)
self.assertEqual(event.name, "get_history")
self.assertEqual(event.args["agent_id"], a1.agent_id)
self.assertEqual(event.args["agent_name"], a1.name)

# test mix of rpc agent and local agent
a2 = AgentWithChatRoom("a2")
a2.join(r)
event = master.get_event(-1)
self.assertEqual(event.name, "join")
self.assertEqual(event.args["agent"].agent_id, a2.agent_id)
self.assertEqual(event.args["agent"].name, a2.name)
r2 = a2(Msg(name="user", role="user", content="hello"))
self.assertEqual(r2.content, 0)
self.assertEqual(master.get_event(-1).name, "speak")
self.assertEqual(master.get_event(-1).args["message"], r2)
self.assertEqual(master.get_event(-2).name, "get_history")

# test rpc type
ra1 = r[a1.agent_id].agent
ra1 = r[a1.name].agent
self.assertTrue(isinstance(ra1, RpcObject))
self.assertEqual(ra1.agent_id, a1.agent_id)
rr = a1.chatroom()
self.assertTrue(isinstance(rr, RpcObject))
self.assertEqual(r._oid, rr._oid) # pylint: disable=W0212

# test history_idx
self.assertEqual(r[a1.agent_id].history_idx, 0)
self.assertEqual(r[a2.agent_id].history_idx, 1)
self.assertEqual(r[a1.name].history_idx, 0)
self.assertEqual(r[a2.name].history_idx, 1)

0 comments on commit ca1085a

Please sign in to comment.