Hikaru Shindo, Quentin Delfosse, Devendra Singh Dhami, Kristian Kersting
We propose a framework that jointly learns symbolic and neural policies for reinforcement learning.
Install nsfr and nudge.
Training script:
python train_blenderl.py --env-name seaquest --joint-training --num-steps 128 --num-envs 5 --gamma 0.99
- --joint-training: train neural and logic modules jointly
- --num-steps: the number of steps for policy rollout
- --num-envs: the number of environments used to train agents
- --gamma: the discount factor for future rewards
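To make the role of --gamma concrete, here is a minimal sketch of how a discount factor weights future rewards in a rollout. It is illustrative only and is not taken from train_blenderl.py; the function name discounted_return and the reward values are hypothetical.

```python
def discounted_return(rewards, gamma=0.99):
    """Return sum_t gamma**t * rewards[t], accumulated from the last step backwards."""
    total = 0.0
    for reward in reversed(rewards):
        # Each step back in time discounts everything that comes after it once more.
        total = reward + gamma * total
    return total


if __name__ == "__main__":
    # Hypothetical rewards collected over a four-step rollout.
    rollout_rewards = [1.0, 0.0, 0.0, 1.0]
    print(discounted_return(rollout_rewards, gamma=0.99))
```

A gamma close to 1 makes the agent value distant rewards almost as much as immediate ones, while a smaller gamma favors short-term reward.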
Download the trained agents:
wget https://hessenbox.tu-darmstadt.de/dl/fiCNznPuWkALH8JaCJWHeeAV/models.zip
unzip models.zip
rm models.zip
Play script:
python play_gui.py --env-name kangaroo --agent-path models/kangaroo_demo
python play_gui.py --env-name seaquest --agent-path models/seaquest_demo
Note that a checkpoint is required to run the play script.
The hyperparameters are configured inside in/config/default.yaml, which is loaded by default. You can specify a different configuration by providing the corresponding YAML file path as an argument, e.g., python train.py in/config/my_config.yaml. A description of all hyperparameters can be found in train.py.
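The default-versus-override behavior described above can be sketched as follows. This is an assumption about the general pattern, not the actual logic in train.py; the function name resolve_config_path is hypothetical, and only the default path comes from the text.

```python
import sys
from pathlib import Path

# Default configuration path named in the text above.
DEFAULT_CONFIG = Path("in/config/default.yaml")


def resolve_config_path(argv):
    """Return the config path given as first CLI argument, else the default."""
    # argv[0] is the script name; argv[1], if present, is the YAML file path.
    if len(argv) > 1:
        return Path(argv[1])
    return DEFAULT_CONFIG


if __name__ == "__main__":
    print(resolve_config_path(sys.argv))
```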
Inside in/envs/[env_name]/logic/[ruleset_name]/, you find the logic rules that are used as a starting point for training. You can change them or create new rule sets. The rule set to use is specified with the hyperparameter rules.
If you want to use NUDGE within other projects, you can install NUDGE locally as follows:
- Inside nsfr/, run python setup.py develop
- Inside nudge/, run python setup.py develop
- Install packages by
pip install -r requirements.txt
- PyG and torch-scatter for neumann: Install the PyG and torch-scatter packages for the neumann reasoner. See the installation guide. Their versions should be consistent with each other, e.g.
pip install torch==1.12.0+cu116 torchvision==0.13.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html
pip install torch_geometric
pip install pyg_lib torch_scatter torch_sparse -f https://data.pyg.org/whl/torch-1.12.0+cu116.html
You add a new environment inside in/envs/[new_env_name]/. There, you need to define a NudgeEnv class that wraps the original environment in order to do
- logic state extraction: translates raw env states into logic representations
- valuation: Each relation (like closeby) has a corresponding valuation function which maps the (logic) game state to a probability that the relation is true. Each valuation function is defined as a simple Python function. The function's name must match the name of the corresponding relation.
- action mapping: action-predicates predicted by the agent need to be mapped to the actual env actions
See the freeway env to see how it is done.
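As a rough illustration of the valuation and action-mapping steps, the sketch below defines a closeby valuation function and a lookup from action predicates to environment actions. Everything here is hypothetical: the argument layout, the scale parameter, the exponential decay, and the action indices are assumptions chosen for the example, and the real functions in the repository may operate on batched tensors instead of plain floats.

```python
import math


def closeby(player_xy, obj_xy, scale=10.0):
    """Soft truth value in (0, 1] for the relation 'closeby'.

    The function name matches the relation name, as the text requires.
    The exponential decay over Euclidean distance is an assumed choice.
    """
    distance = math.dist(player_xy, obj_xy)
    return math.exp(-distance / scale)


# Hypothetical mapping from action predicates to env action indices.
ACTION_PREDICATE_TO_ENV_ACTION = {"noop": 0, "up": 1, "down": 2}


def map_action(predicate_name):
    """Translate an action predicate predicted by the agent into an env action."""
    return ACTION_PREDICATE_TO_ENV_ACTION[predicate_name]


if __name__ == "__main__":
    print(closeby((0.0, 0.0), (3.0, 4.0)))  # distance 5 with scale 10
    print(map_action("up"))
```

Returning a value between 0 and 1 rather than a hard boolean keeps the relation usable as a probability, which is what the valuation step above asks for.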