Skip to content

Commit

Permalink
Make sure batch env are sent list of commands when calling step funct…
Browse files Browse the repository at this point in the history
…ion.
  • Loading branch information
MarcCote committed Feb 5, 2024
1 parent e170c6e commit feefbd8
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 1 deletion.
7 changes: 6 additions & 1 deletion textworld/envs/batch/batch_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,10 @@ def step(self, actions: List[str]) -> Tuple[List[str], int, bool, Dict[str, List
done: Whether the game is over or not.
infos: Information requested when creating the environments.
"""
results = []
assert isinstance(actions, (list, tuple)), "Expected a list of actions."
assert len(actions) == len(self.envs), "Expected one action per environment."

results = []
for i, (env, action) in enumerate(zip(self.envs, actions)):
if self.last[i] is not None and self.last[i][2]: # Game has ended on the last step.
obs, reward, done, infos = self.last[i] # Copy last state over.
Expand Down Expand Up @@ -252,6 +254,9 @@ def step(self, actions):
done: Whether the game is over or not.
infos: Information requested when creating the environments.
"""
assert isinstance(actions, (list, tuple)), "Expected a list of actions."
assert len(actions) == len(self.envs), "Expected one action per environment."

results = []
for i, (env, action) in enumerate(zip(self.envs, actions)):
if self.last[i] is not None and self.last[i][2]: # Game has ended on the last step.
Expand Down
2 changes: 2 additions & 0 deletions textworld/gym/envs/textworld_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ def step(self, commands) -> Tuple[List[str], List[float], List[bool], Dict[str,
* dones: whether each game in the batch is finished or not;
* infos: additional information as requested for each game in the batch.
"""
assert isinstance(commands, (list, tuple)), "Expected a list of commands."

self.last_commands = commands
self.obs, scores, dones, infos = self.batch_env.step(self.last_commands)
return self.obs, scores, dones, infos
Expand Down
16 changes: 16 additions & 0 deletions textworld/gym/tests/test_textworld_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,14 @@ def test_batch_sync(self):
for values in infos.values():
assert len(values) == batch_size

# Sending single command should raise assertion.
with pytest.raises(AssertionError):
obs, scores, dones, infos = env.step("wait")

# Sending not engough commands should raise assertion.
with pytest.raises(AssertionError):
obs, scores, dones, infos = env.step(["wait"] * (batch_size - 1))

for cmds in zip(*infos.get("extra.walkthrough")):
obs, scores, dones, infos = env.step(cmds)

Expand Down Expand Up @@ -165,6 +173,14 @@ def test_batch_async(self):
for values in infos.values():
assert len(values) == batch_size

# Sending single command should raise assertion.
with pytest.raises(AssertionError):
obs, scores, dones, infos = env.step("wait")

# Sending not engough commands should raise assertion.
with pytest.raises(AssertionError):
obs, scores, dones, infos = env.step(["wait"] * (batch_size - 1))

for cmds in zip(*infos.get("extra.walkthrough")):
obs, scores, dones, infos = env.step(cmds)

Expand Down

0 comments on commit feefbd8

Please sign in to comment.