diff --git a/textworld/core.py b/textworld/core.py index a6350cbf..69ad26ff 100644 --- a/textworld/core.py +++ b/textworld/core.py @@ -25,9 +25,10 @@ class EnvInfos: 'game', 'won', 'lost', 'score', 'moves', 'max_score', 'objective', - 'entities', 'verbs', 'command_templates', - 'admissible_commands', 'intermediate_reward', - 'policy_commands', + 'entities', 'typed_entities', 'verbs', 'command_templates', + 'admissible_commands', 'possible_admissible_commands', + 'possible_commands', + 'intermediate_reward', 'policy_commands', 'extras'] def __init__(self, **kwargs): @@ -71,6 +72,12 @@ def __init__(self, **kwargs): #: bool: All commands relevant to the current state. #: This information changes from one step to another. self.admissible_commands = kwargs.get("admissible_commands", False) + #: bool: All possible commands regardless of the current state. + #: This information *doesn't* change from one step to another. + self.possible_admissible_commands = kwargs.get("possible_admissible_commands", False) + #: bool: All possible commands regardless of the current state and the arguments type. + #: This information *doesn't* change from one step to another. + self.possible_commands = kwargs.get("possible_commands", False) #: bool: Sequence of commands leading to a winning state. #: This information changes from one step to another. self.policy_commands = kwargs.get("policy_commands", False) @@ -92,6 +99,9 @@ def __init__(self, **kwargs): #: bool: Names of all entities in the game. #: This information *doesn't* change from one step to another. self.entities = kwargs.get("entities", False) + #: bool: Names of all entities in the game and their type. + #: This information *doesn't* change from one step to another. + self.typed_entities = kwargs.get("typed_entities", False) #: bool: Verbs understood by the the game. #: This information *doesn't* change from one step to another. self.verbs = kwargs.get("verbs", False) diff --git a/textworld/envs/__init__.py b/textworld/envs/__init__.py index f7bf0e76..cb39b91c 100644 --- a/textworld/envs/__init__.py +++ b/textworld/envs/__init__.py @@ -15,6 +15,8 @@ def _guess_backend(path): return GitGlulxEnv elif re.search(r"\.z[1-8]", path): return JerichoEnv + elif path.endswith(".json"): + return TextWorldEnv elif path.endswith(".tw-pddl"): return PddlEnv diff --git a/textworld/envs/tw.py b/textworld/envs/tw.py index ebcb5b5a..c2dbd68a 100644 --- a/textworld/envs/tw.py +++ b/textworld/envs/tw.py @@ -49,6 +49,9 @@ def _gather_infos(self): self.state["command_templates"] = self._game.command_templates self.state["verbs"] = self._game.verbs self.state["entities"] = self._game.entity_names + self.state["typed_entities"] = self._game.objects_names_and_types + self.state["possible_commands"] = self._game.possible_commands + self.state["possible_admissible_commands"] = self._game.possible_admissible_commands self.state["objective"] = self._game.objective self.state["max_score"] = self._game.max_score diff --git a/textworld/envs/wrappers/tw_inform7.py b/textworld/envs/wrappers/tw_inform7.py index 3fe6af56..0ad8e4c0 100644 --- a/textworld/envs/wrappers/tw_inform7.py +++ b/textworld/envs/wrappers/tw_inform7.py @@ -104,7 +104,11 @@ def _wrap(self, env): @classmethod def compatible(cls, path: str) -> bool: """ Check if path point to a TW Inform7 compatible game. """ - return os.path.isfile(os.path.splitext(path)[0] + ".json") + basepath, ext = os.path.splitext(path) + if ext not in [".z8", ".ulx"]: + return False + + return os.path.isfile(basepath + ".json") def copy(self) -> "TWInform7": """ Returns a copy this wrapper. """ @@ -362,6 +366,9 @@ def _gather_infos(self): self.state["command_templates"] = self._game.command_templates self.state["verbs"] = self._game.verbs self.state["entities"] = self._game.entity_names + self.state["typed_entities"] = self._game.objects_names_and_types + self.state["possible_commands"] = self._game.possible_commands + self.state["possible_admissible_commands"] = self._game.possible_admissible_commands self.state["objective"] = self._game.objective self.state["max_score"] = self._game.max_score diff --git a/textworld/generator/game.py b/textworld/generator/game.py index 845a9023..1096f137 100644 --- a/textworld/generator/game.py +++ b/textworld/generator/game.py @@ -2,13 +2,15 @@ # Licensed under the MIT license. +import re import copy import json import textwrap from typing import List, Dict, Optional, Mapping, Any, Iterable, Union, Tuple -from collections import OrderedDict -from functools import partial +from collections import OrderedDict, defaultdict +from functools import cached_property, partial +from itertools import product import numpy as np from numpy.random import RandomState @@ -535,7 +537,7 @@ def max_score(self) -> float: return sum(quest.reward for quest in self.quests if not quest.optional or quest.reward > 0) - @property + @cached_property def command_templates(self) -> List[str]: """ All command templates understood in this game. """ return sorted(set(cmd for cmd in self.kb.inform7_commands.values())) @@ -544,12 +546,12 @@ def command_templates(self) -> List[str]: def directions_names(self) -> List[str]: return DIRECTIONS - @property + @cached_property def objects_types(self) -> List[str]: """ All types of objects in this game. """ return sorted(self.kb.types.types) - @property + @cached_property def objects_names(self) -> List[str]: """ The names of all relevant objects in this game. """ def _filter_unnamed_and_room_entities(e): @@ -558,11 +560,11 @@ def _filter_unnamed_and_room_entities(e): entities_infos = filter(_filter_unnamed_and_room_entities, self.infos.values()) return [info.name for info in entities_infos] - @property + @cached_property def entity_names(self) -> List[str]: return self.objects_names + self.directions_names - @property + @cached_property def objects_names_and_types(self) -> List[str]: """ The names of all non-player objects along with their type in this game. """ def _filter_unnamed_and_room_entities(e): @@ -571,12 +573,40 @@ def _filter_unnamed_and_room_entities(e): entities_infos = filter(_filter_unnamed_and_room_entities, self.infos.values()) return [(info.name, info.type) for info in entities_infos] - @property + @cached_property def verbs(self) -> List[str]: """ Verbs that should be recognized in this game. """ # Retrieve commands templates for every rule. return sorted(set(cmd.split()[0] for cmd in self.command_templates)) + @cached_property + def possible_commands(self) -> List[str]: + """ All possible commands when ignoring their arguments' type. """ + action_templates = set(re.sub(r"{.*?}", "{}", a) for a in self.command_templates) + possible_commands = [template.format(*mapping) + for template in action_templates + for mapping in product(self.objects_names, repeat=template.count("{}"))] + return sorted(possible_commands) + + @cached_property + def possible_admissible_commands(self) -> List[str]: + """ Superset of the admissible commands irrespective of the current state. """ + type2names = defaultdict(list) + for name, type in self.objects_names_and_types: + type2names[f'{{{type}}}'].append(name) + + templates = [re.sub(r"{.*?}", "{{{}}}", template).format(*mappings) + for template in self.command_templates + for mappings in product(*[[arg.strip('{}')] + self.kb.types.descendants(arg.strip('{}')) + for arg in re.findall(r"{.*?}", template)])] + + templates = sorted(set(templates)) + commands = [re.sub(r"{.*?}", "{}", template).format(*mappings) + for template in templates + for mappings in product(*[type2names[arg] for arg in re.findall(r"{.*?}", template)])] + + return sorted(commands) + @property def objective(self) -> str: if self._objective is not None: