Skip to content

Commit

Permalink
Rename Environment.infos to Environment.request_infos
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcCote committed Nov 21, 2023
1 parent 3bb5df6 commit e72ed2a
Show file tree
Hide file tree
Showing 20 changed files with 133 additions and 136 deletions.
73 changes: 36 additions & 37 deletions notebooks/Building a simple agent.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@
" @property\n",
" def infos_to_request(self) -> textworld.EnvInfos:\n",
" return textworld.EnvInfos(admissible_commands=True)\n",
" \n",
"\n",
" def act(self, obs: str, score: int, done: bool, infos: Mapping[str, Any]) -> str:\n",
" return self.rng.choice(infos[\"admissible_commands\"])\n"
]
Expand Down Expand Up @@ -213,11 +213,11 @@
"\n",
" infos_to_request = agent.infos_to_request\n",
" infos_to_request.max_score = True # Needed to normalize the scores.\n",
" \n",
"\n",
" gamefiles = [path]\n",
" if os.path.isdir(path):\n",
" gamefiles = glob(os.path.join(path, \"*.z8\"))\n",
" \n",
"\n",
" env_id = textworld.gym.register_games(gamefiles,\n",
" request_infos=infos_to_request,\n",
" max_episode_steps=max_step)\n",
Expand All @@ -227,7 +227,7 @@
" print(os.path.dirname(path), end=\"\")\n",
" else:\n",
" print(os.path.basename(path), end=\"\")\n",
" \n",
"\n",
" # Collect some statistics: nb_steps, final reward.\n",
" avg_moves, avg_scores, avg_norm_scores = [], [], []\n",
" for no_episode in range(nb_episodes):\n",
Expand All @@ -240,9 +240,9 @@
" command = agent.act(obs, score, done, infos)\n",
" obs, score, done, infos = env.step(command)\n",
" nb_moves += 1\n",
" \n",
"\n",
" agent.act(obs, score, done, infos) # Let the agent know the game is done.\n",
" \n",
"\n",
" if verbose:\n",
" print(\".\", end=\"\")\n",
" avg_moves.append(nb_moves)\n",
Expand All @@ -256,8 +256,7 @@
" print(msg.format(np.mean(avg_moves), np.mean(avg_norm_scores), 1))\n",
" else:\n",
" msg = \" \\tavg. steps: {:5.1f}; avg. score: {:4.1f} / {}.\"\n",
" print(msg.format(np.mean(avg_moves), np.mean(avg_scores), infos[\"max_score\"]))\n",
" "
" print(msg.format(np.mean(avg_moves), np.mean(avg_scores), infos[\"max_score\"]))\n"
]
},
{
Expand Down Expand Up @@ -389,45 +388,45 @@
" UPDATE_FREQUENCY = 10\n",
" LOG_FREQUENCY = 1000\n",
" GAMMA = 0.9\n",
" \n",
"\n",
" def __init__(self) -> None:\n",
" self._initialized = False\n",
" self._epsiode_has_started = False\n",
" self.id2word = [\"<PAD>\", \"<UNK>\"]\n",
" self.word2id = {w: i for i, w in enumerate(self.id2word)}\n",
" \n",
"\n",
" self.model = CommandScorer(input_size=self.MAX_VOCAB_SIZE, hidden_size=128)\n",
" self.optimizer = optim.Adam(self.model.parameters(), 0.00003)\n",
" \n",
"\n",
" self.mode = \"test\"\n",
" \n",
"\n",
" def train(self):\n",
" self.mode = \"train\"\n",
" self.stats = {\"max\": defaultdict(list), \"mean\": defaultdict(list)}\n",
" self.transitions = []\n",
" self.model.reset_hidden(1)\n",
" self.last_score = 0\n",
" self.no_train_step = 0\n",
" \n",
"\n",
" def test(self):\n",
" self.mode = \"test\"\n",
" self.model.reset_hidden(1)\n",
" \n",
"\n",
" @property\n",
" def infos_to_request(self) -> EnvInfos:\n",
" return EnvInfos(description=True, inventory=True, admissible_commands=True,\n",
" won=True, lost=True)\n",
" \n",
"\n",
" def _get_word_id(self, word):\n",
" if word not in self.word2id:\n",
" if len(self.word2id) >= self.MAX_VOCAB_SIZE:\n",
" return self.word2id[\"<UNK>\"]\n",
" \n",
"\n",
" self.id2word.append(word)\n",
" self.word2id[word] = len(self.word2id)\n",
" \n",
"\n",
" return self.word2id[word]\n",
" \n",
"\n",
" def _tokenize(self, text):\n",
" # Simple tokenizer: strip out all non-alphabetic characters.\n",
" text = re.sub(\"[^a-zA-Z0-9\\- ]\", \" \", text)\n",
Expand All @@ -445,7 +444,7 @@
" padded_tensor = torch.from_numpy(padded).type(torch.long).to(device)\n",
" padded_tensor = padded_tensor.permute(1, 0) # Batch x Seq => Seq x Batch\n",
" return padded_tensor\n",
" \n",
"\n",
" def _discount_rewards(self, last_values):\n",
" returns, advantages = [], []\n",
" R = last_values.data\n",
Expand All @@ -455,48 +454,48 @@
" adv = R - values\n",
" returns.append(R)\n",
" advantages.append(adv)\n",
" \n",
"\n",
" return returns[::-1], advantages[::-1]\n",
"\n",
" def act(self, obs: str, score: int, done: bool, infos: Mapping[str, Any]) -> Optional[str]:\n",
" \n",
"\n",
" # Build agent's observation: feedback + look + inventory.\n",
" input_ = \"{}\\n{}\\n{}\".format(obs, infos[\"description\"], infos[\"inventory\"])\n",
" \n",
"\n",
" # Tokenize and pad the input and the commands to chose from.\n",
" input_tensor = self._process([input_])\n",
" commands_tensor = self._process(infos[\"admissible_commands\"])\n",
" \n",
"\n",
" # Get our next action and value prediction.\n",
" outputs, indexes, values = self.model(input_tensor, commands_tensor)\n",
" action = infos[\"admissible_commands\"][indexes[0]]\n",
" \n",
"\n",
" if self.mode == \"test\":\n",
" if done:\n",
" self.model.reset_hidden(1)\n",
" return action\n",
" \n",
"\n",
" self.no_train_step += 1\n",
" \n",
"\n",
" if self.transitions:\n",
" reward = score - self.last_score # Reward is the gain/loss in score.\n",
" self.last_score = score\n",
" if infos[\"won\"]:\n",
" reward += 100\n",
" if infos[\"lost\"]:\n",
" reward -= 100\n",
" \n",
"\n",
" self.transitions[-1][0] = reward # Update reward information.\n",
" \n",
"\n",
" self.stats[\"max\"][\"score\"].append(score)\n",
" if self.no_train_step % self.UPDATE_FREQUENCY == 0:\n",
" # Update model\n",
" returns, advantages = self._discount_rewards(values)\n",
" \n",
"\n",
" loss = 0\n",
" for transition, ret, advantage in zip(self.transitions, returns, advantages):\n",
" reward, indexes_, outputs_, values_ = transition\n",
" \n",
"\n",
" advantage = advantage.detach() # Block gradients flow here.\n",
" probs = F.softmax(outputs_, dim=2)\n",
" log_probs = torch.log(probs)\n",
Expand All @@ -505,35 +504,35 @@
" value_loss = (.5 * (values_ - ret) ** 2.).sum()\n",
" entropy = (-probs * log_probs).sum()\n",
" loss += policy_loss + 0.5 * value_loss - 0.1 * entropy\n",
" \n",
"\n",
" self.stats[\"mean\"][\"reward\"].append(reward)\n",
" self.stats[\"mean\"][\"policy\"].append(policy_loss.item())\n",
" self.stats[\"mean\"][\"value\"].append(value_loss.item())\n",
" self.stats[\"mean\"][\"entropy\"].append(entropy.item())\n",
" self.stats[\"mean\"][\"confidence\"].append(torch.exp(log_action_probs).item())\n",
" \n",
"\n",
" if self.no_train_step % self.LOG_FREQUENCY == 0:\n",
" msg = \"{:6d}. \".format(self.no_train_step)\n",
" msg += \" \".join(\"{}: {: 3.3f}\".format(k, np.mean(v)) for k, v in self.stats[\"mean\"].items())\n",
" msg += \" \" + \" \".join(\"{}: {:2d}\".format(k, np.max(v)) for k, v in self.stats[\"max\"].items())\n",
" msg += \" vocab: {:3d}\".format(len(self.id2word))\n",
" print(msg)\n",
" self.stats = {\"max\": defaultdict(list), \"mean\": defaultdict(list)}\n",
" \n",
"\n",
" loss.backward()\n",
" nn.utils.clip_grad_norm_(self.model.parameters(), 40)\n",
" self.optimizer.step()\n",
" self.optimizer.zero_grad()\n",
" \n",
"\n",
" self.transitions = []\n",
" self.model.reset_hidden(1)\n",
" else:\n",
" # Keep information about transitions for Truncated Backpropagation Through Time.\n",
" self.transitions.append([None, indexes, outputs, values]) # Reward will be set on the next call\n",
" \n",
"\n",
" if done:\n",
" self.last_score = 0 # Will be starting a new episode. Reset the last score.\n",
" \n",
"\n",
" return action"
]
},
Expand Down Expand Up @@ -990,7 +989,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down
6 changes: 3 additions & 3 deletions notebooks/Playing text-based games with TextWorld.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@
"outputs": [],
"source": [
"# We are now ready to start the game.\n",
"env = textworld.start('./zork1.z5', infos=infos)"
"env = textworld.start('./zork1.z5', request_infos=infos)"
]
},
{
Expand Down Expand Up @@ -381,7 +381,7 @@
" env.render()\n",
" command = input(\"> \")\n",
" game_state, reward, done = env.step(command)\n",
" \n",
"\n",
" env.render() # Final message.\n",
"except KeyboardInterrupt:\n",
" pass # Quit the game.\n",
Expand Down Expand Up @@ -706,7 +706,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion scripts/check_generated_games.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def main():
for i, game in enumerate(args.games, start=1):
print("{}. Testing {} ...".format(i, game))
env = textworld.start(game)
env.infos.admissible_commands = True
env.request_infos.admissible_commands = True
agent.reset(env)
game_state = env.reset()

Expand Down
2 changes: 1 addition & 1 deletion scripts/tw-view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ if __name__ == "__main__":
args = build_parser().parse_args()

gamefile = os.path.splitext(args.game)[0] + ".json"
env = textworld.start(gamefile, infos=EnvInfos(facts=True))
env = textworld.start(gamefile, request_infos=EnvInfos(facts=True))
state = env.reset()

show_graph(state.facts, renderer="browser")
2 changes: 0 additions & 2 deletions tests/test_textworld.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,6 @@ def test_playing_generated_games():
# Play the game using RandomAgent and make sure we can always finish the
# game by following the winning policy.
env = textworld.start(game_file)
env.infos.policy_commands = True
env.infos.game = True

agent = textworld.agents.RandomCommandAgent()
agent.reset(env)
Expand Down
24 changes: 12 additions & 12 deletions textworld/agents/human.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,32 +37,32 @@ def reset(self, env):
env.display_command_during_render = False

if self.autocompletion:
env.infos.admissible_commands = True
env.request_infos.admissible_commands = True

if self.oracle:
env.infos.policy_commands = True
env.infos.intermediate_reward = True
env.request_infos.policy_commands = True
env.request_infos.intermediate_reward = True

def act(self, game_state, reward, done):
if (self.oracle and game_state.policy_commands and not done):
if (self.oracle and game_state["policy_commands"] and not done):
text = '[{score}/{max_score}|({intermediate_score}): {policy}]\n'.format(
score=game_state.score,
max_score=game_state.max_score,
intermediate_score=game_state.intermediate_reward,
policy=" > ".join(game_state.policy_commands)
score=game_state["score"],
max_score=game_state["max_score"],
intermediate_score=game_state["intermediate_reward"],
policy=" > ".join(game_state["policy_commands"])
)
print("Oracle: {}\n".format(text))

if prompt_toolkit_available:
actions_completer = None
if self.autocompletion and game_state.admissible_commands:
actions_completer = WordCompleter(game_state.admissible_commands,
if self.autocompletion and game_state["admissible_commands"]:
actions_completer = WordCompleter(game_state["admissible_commands"],
ignore_case=True, sentence=True)
action = prompt('> ', completer=actions_completer,
history=self._history, enable_history_search=True)
else:
if self.autocompletion and game_state.admissible_commands:
print("Available actions: {}\n".format(game_state.admissible_commands))
if self.autocompletion and game_state["admissible_commands"]:
print("Available actions: {}\n".format(game_state["admissible_commands"]))

action = input('> ')

Expand Down
2 changes: 1 addition & 1 deletion textworld/agents/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, seed=1234):
self.rng = np.random.RandomState(self.seed)

def reset(self, env):
env.infos.admissible_commands = True
env.request_infos.admissible_commands = True
env.display_command_during_render = True

def act(self, game_state, reward, done):
Expand Down
10 changes: 5 additions & 5 deletions textworld/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,14 +182,14 @@ class Environment:
You pick up the TextWorld style key from the ground.
"""

def __init__(self, infos: Optional[EnvInfos] = None) -> None:
def __init__(self, request_infos: Optional[EnvInfos] = None) -> None:
"""
Arguments:
infos: Information to be included in the game state. By
default, only the game's narrative is included.
request_infos: Information to be included in the game state. By
default, only the game's narrative is included.
"""
self.state = GameState()
self.infos = infos or EnvInfos()
self.request_infos = request_infos or EnvInfos()

def load(self, path: str) -> None:
""" Loads a new text-based game.
Expand Down Expand Up @@ -423,5 +423,5 @@ class EnvInfoMissingError(NameError):

def __init__(self, requester, info):
msg = ("The info '{info}' requested by `{requester}` is missing."
" Make sure it is enabled like so `Environment(infos=EnvInfos(`{info}`=True))`.")
" Make sure it is enabled like so `Environment(request_infos=EnvInfos(`{info}`=True))`.")
super().__init__(msg.format(info=info, requester=requester))
4 changes: 2 additions & 2 deletions textworld/envs/tests/test_tw.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def setUpClass(cls):

cls.game = testing.build_game(cls.options)
cls.game.save(cls.gamefile)
cls.infos = EnvInfos(
cls.request_infos = EnvInfos(
facts=True,
policy_commands=True,
admissible_commands=True,
Expand All @@ -40,7 +40,7 @@ def tearDownClass(cls):
shutil.rmtree(cls.tmpdir)

def setUp(self):
self.env = TextWorldEnv(self.infos)
self.env = TextWorldEnv(self.request_infos)
self.env.load(self.gamefile)

def test_feedback(self):
Expand Down
Loading

0 comments on commit e72ed2a

Please sign in to comment.