Skip to content

Commit

Permalink
Merge pull request #347 from MarcCote/enh_possible_commands
Browse files Browse the repository at this point in the history
ENH: add new requestable attributes to EnvInfos
  • Loading branch information
MarcCote authored Sep 28, 2024
2 parents e62b90d + 78e4bc3 commit c1b959b
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 12 deletions.
16 changes: 13 additions & 3 deletions textworld/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ class EnvInfos:
'game',
'won', 'lost',
'score', 'moves', 'max_score', 'objective',
'entities', 'verbs', 'command_templates',
'admissible_commands', 'intermediate_reward',
'policy_commands',
'entities', 'typed_entities', 'verbs', 'command_templates',
'admissible_commands', 'possible_admissible_commands',
'possible_commands',
'intermediate_reward', 'policy_commands',
'extras']

def __init__(self, **kwargs):
Expand Down Expand Up @@ -71,6 +72,12 @@ def __init__(self, **kwargs):
#: bool: All commands relevant to the current state.
#: This information changes from one step to another.
self.admissible_commands = kwargs.get("admissible_commands", False)
#: bool: All possible commands regardless of the current state.
#: This information *doesn't* change from one step to another.
self.possible_admissible_commands = kwargs.get("possible_admissible_commands", False)
#: bool: All possible commands regardless of the current state and the arguments type.
#: This information *doesn't* change from one step to another.
self.possible_commands = kwargs.get("possible_commands", False)
#: bool: Sequence of commands leading to a winning state.
#: This information changes from one step to another.
self.policy_commands = kwargs.get("policy_commands", False)
Expand All @@ -92,6 +99,9 @@ def __init__(self, **kwargs):
#: bool: Names of all entities in the game.
#: This information *doesn't* change from one step to another.
self.entities = kwargs.get("entities", False)
#: bool: Names of all entities in the game and their type.
#: This information *doesn't* change from one step to another.
self.typed_entities = kwargs.get("typed_entities", False)
#: bool: Verbs understood by the the game.
#: This information *doesn't* change from one step to another.
self.verbs = kwargs.get("verbs", False)
Expand Down
2 changes: 2 additions & 0 deletions textworld/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ def _guess_backend(path):
return GitGlulxEnv
elif re.search(r"\.z[1-8]", path):
return JerichoEnv
elif path.endswith(".json"):
return TextWorldEnv
elif path.endswith(".tw-pddl"):
return PddlEnv

Expand Down
3 changes: 3 additions & 0 deletions textworld/envs/tw.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def _gather_infos(self):
self.state["command_templates"] = self._game.command_templates
self.state["verbs"] = self._game.verbs
self.state["entities"] = self._game.entity_names
self.state["typed_entities"] = self._game.objects_names_and_types
self.state["possible_commands"] = self._game.possible_commands
self.state["possible_admissible_commands"] = self._game.possible_admissible_commands
self.state["objective"] = self._game.objective
self.state["max_score"] = self._game.max_score

Expand Down
9 changes: 8 additions & 1 deletion textworld/envs/wrappers/tw_inform7.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,11 @@ def _wrap(self, env):
@classmethod
def compatible(cls, path: str) -> bool:
""" Check if path point to a TW Inform7 compatible game. """
return os.path.isfile(os.path.splitext(path)[0] + ".json")
basepath, ext = os.path.splitext(path)
if ext not in [".z8", ".ulx"]:
return False

return os.path.isfile(basepath + ".json")

def copy(self) -> "TWInform7":
""" Returns a copy this wrapper. """
Expand Down Expand Up @@ -362,6 +366,9 @@ def _gather_infos(self):
self.state["command_templates"] = self._game.command_templates
self.state["verbs"] = self._game.verbs
self.state["entities"] = self._game.entity_names
self.state["typed_entities"] = self._game.objects_names_and_types
self.state["possible_commands"] = self._game.possible_commands
self.state["possible_admissible_commands"] = self._game.possible_admissible_commands
self.state["objective"] = self._game.objective
self.state["max_score"] = self._game.max_score

Expand Down
46 changes: 38 additions & 8 deletions textworld/generator/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
# Licensed under the MIT license.


import re
import copy
import json
import textwrap

from typing import List, Dict, Optional, Mapping, Any, Iterable, Union, Tuple
from collections import OrderedDict
from functools import partial
from collections import OrderedDict, defaultdict
from functools import cached_property, partial
from itertools import product

import numpy as np
from numpy.random import RandomState
Expand Down Expand Up @@ -535,7 +537,7 @@ def max_score(self) -> float:

return sum(quest.reward for quest in self.quests if not quest.optional or quest.reward > 0)

@property
@cached_property
def command_templates(self) -> List[str]:
""" All command templates understood in this game. """
return sorted(set(cmd for cmd in self.kb.inform7_commands.values()))
Expand All @@ -544,12 +546,12 @@ def command_templates(self) -> List[str]:
def directions_names(self) -> List[str]:
return DIRECTIONS

@property
@cached_property
def objects_types(self) -> List[str]:
""" All types of objects in this game. """
return sorted(self.kb.types.types)

@property
@cached_property
def objects_names(self) -> List[str]:
""" The names of all relevant objects in this game. """
def _filter_unnamed_and_room_entities(e):
Expand All @@ -558,11 +560,11 @@ def _filter_unnamed_and_room_entities(e):
entities_infos = filter(_filter_unnamed_and_room_entities, self.infos.values())
return [info.name for info in entities_infos]

@property
@cached_property
def entity_names(self) -> List[str]:
return self.objects_names + self.directions_names

@property
@cached_property
def objects_names_and_types(self) -> List[str]:
""" The names of all non-player objects along with their type in this game. """
def _filter_unnamed_and_room_entities(e):
Expand All @@ -571,12 +573,40 @@ def _filter_unnamed_and_room_entities(e):
entities_infos = filter(_filter_unnamed_and_room_entities, self.infos.values())
return [(info.name, info.type) for info in entities_infos]

@property
@cached_property
def verbs(self) -> List[str]:
""" Verbs that should be recognized in this game. """
# Retrieve commands templates for every rule.
return sorted(set(cmd.split()[0] for cmd in self.command_templates))

@cached_property
def possible_commands(self) -> List[str]:
""" All possible commands when ignoring their arguments' type. """
action_templates = set(re.sub(r"{.*?}", "{}", a) for a in self.command_templates)
possible_commands = [template.format(*mapping)
for template in action_templates
for mapping in product(self.objects_names, repeat=template.count("{}"))]
return sorted(possible_commands)

@cached_property
def possible_admissible_commands(self) -> List[str]:
""" Superset of the admissible commands irrespective of the current state. """
type2names = defaultdict(list)
for name, type in self.objects_names_and_types:
type2names[f'{{{type}}}'].append(name)

templates = [re.sub(r"{.*?}", "{{{}}}", template).format(*mappings)
for template in self.command_templates
for mappings in product(*[[arg.strip('{}')] + self.kb.types.descendants(arg.strip('{}'))
for arg in re.findall(r"{.*?}", template)])]

templates = sorted(set(templates))
commands = [re.sub(r"{.*?}", "{}", template).format(*mappings)
for template in templates
for mappings in product(*[type2names[arg] for arg in re.findall(r"{.*?}", template)])]

return sorted(commands)

@property
def objective(self) -> str:
if self._objective is not None:
Expand Down

0 comments on commit c1b959b

Please sign in to comment.