Implementation of the paper "In-context Reinforcement Learning with Algorithm Distillation"
```shell
git clone ...
cd ad-icrl
python -m venv venv
source venv/bin/activate
pip install -r requirements/requirements.txt
```
Set the `PYTHONPATH` environment variable, which is required to run the project:

```shell
export PYTHONPATH=.
```
You can also set the optional variable `DEVICE` to select the compute device.
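How `DEVICE` is consumed depends on the code; a minimal sketch of reading it with a CPU fallback might look like this (the default value and helper name are assumptions, not the project's actual code):

```python
import os

def resolve_device() -> str:
    """Return the requested compute device, defaulting to CPU when DEVICE is unset."""
    return os.environ.get("DEVICE", "cpu")
```

For example, `export DEVICE=cuda:0` before launching would route tensors to the first GPU under this convention.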
```
.
├── notebooks/
│   ├── colab.ipynb            # run training in Colab
│   ├── explore_dataset.ipynb  # overview of the training dataset
│   ├── config_paper.py        # hyperparameters from the paper
│   └── test_dark_room.py      # playground for the Dark Room env
├── README.md
├── requirements/
│   ├── requirements_colab.txt
│   └── requirements.txt
├── saved_data/                # data saved during execution
│   ├── goals_9.txt
│   ├── learning_history/
│   ├── logs/
│   └── permutations_9.txt     # permutation of goals for the train/test split
├── scripts/
│   ├── collect_data.sh        # run stage 1: dataset collection
│   ├── eval.sh                # run stage 3: evaluate the trained model
│   └── train_ad.sh            # run stage 2: train algorithm distillation
├── src/
│   ├── check_in_context/      # stage 3:
│   │   └── eval_one_model.py  #   evaluate a given model
│   ├── data/                  # stage 1:
│   │   ├── __init__.py
│   │   ├── env.py             #   env config
│   │   ├── generate_goals.py  #   goals setup
│   │   └── ppo.py             #   collect the dataset with PPO
│   └── dt/                    # stage 2:
│       ├── eval.py            #   rollout eval
│       ├── model.py           #   AD model
│       ├── schedule.py        #   lr scheduler
│       ├── seq_dataset.py     #   dataloader
│       ├── train.py           #   train script & config
│       └── utils.py
├── static/                    # images for the report
├── runs/                      # TensorBoard dir
└── wandb/                     # wandb dir
```
To enable wandb logging, create a wandb profile and run the following once:

```shell
wandb init
```

To disable wandb logging (for debugging or any other reason), run:

```shell
wandb disabled
```

To turn logging back on, run:

```shell
wandb enabled
```
Generate the goals:

```shell
python src/data/generate_goals.py
```

or use the provided file `saved_data/permutations_9.txt`.
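In the Dark Room environment every grid cell is a possible goal, and the saved permutation makes the train/test split reproducible. A sketch of how goals and such a permutation could be produced (the 9x9 grid size is inferred from the `*_9.txt` filenames; the split ratio and seed are assumptions):

```python
import random

def generate_goals(size: int = 9) -> list[tuple[int, int]]:
    # One goal per cell of a size x size Dark Room grid.
    return [(x, y) for x in range(size) for y in range(size)]

def split_goals(goals, train_frac: float = 0.8, seed: int = 0):
    # Shuffle with a fixed seed so the permutation is reproducible,
    # then cut into train and test goal sets.
    perm = goals.copy()
    random.Random(seed).shuffle(perm)
    cut = int(len(perm) * train_frac)
    return perm[:cut], perm[cut:]

goals = generate_goals()
train, test = split_goals(goals)
print(len(goals), len(train), len(test))  # 81 64 17
```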
Run the script to create the trajectories:

```shell
chmod +x ./scripts/collect_data.sh
./scripts/collect_data.sh
```

or load the trajectories from gdrive via link.

The script uses `src/data/ppo.py`, which logs to wandb; pass the `--no-track` flag to disable logging.
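Stage 1 records the full learning history of each PPO run, i.e. every transition in training order rather than only final-policy rollouts, so that a window over the history can later show policy improvement. A schematic sketch of that logging loop (the agent and environment here are random stand-ins, not the project's PPO classes; 5 actions is an assumption for Dark Room):

```python
import random

def collect_learning_history(num_episodes: int, episode_len: int, seed: int = 0):
    """Log an (obs, action, reward) triple for every step, in training order."""
    rng = random.Random(seed)
    history = []
    for _ in range(num_episodes):
        for _ in range(episode_len):
            obs = rng.randrange(81)      # stand-in observation id (9x9 grid cell)
            action = rng.randrange(5)    # stand-in policy sample
            reward = float(action == 0)  # stand-in reward signal
            history.append((obs, action, reward))
    return history

history = collect_learning_history(num_episodes=3, episode_len=4)
print(len(history))  # 12
```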
Observe the learning process in TensorBoard:

```shell
tensorboard --logdir saved_data/logs/
```
Use notebooks/explore_dataset.ipynb to get statistics about the collected trajectories.
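One statistic worth checking is the average return per segment of the learning history: it should increase over the history if the source algorithm actually improved. A sketch assuming the history is reduced to a list of per-episode returns (this file layout is hypothetical, not the notebook's actual code):

```python
def returns_by_segment(episode_returns, num_segments: int = 4):
    """Mean return in each consecutive segment of the learning history."""
    seg = len(episode_returns) // num_segments
    return [
        sum(episode_returns[i * seg:(i + 1) * seg]) / seg
        for i in range(num_segments)
    ]

# A toy improving history: later segments should show higher means.
print(returns_by_segment([0, 1, 2, 3, 4, 5, 6, 7]))  # [0.5, 2.5, 4.5, 6.5]
```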
Train the model:

```shell
python src/dt/train.py
```

or load trained models from gdrive via link. You can also run the bash script `scripts/train_ad.sh`.

To additionally train the reward head, pass:

```shell
python src/dt/train.py --config.add-reward_head
```
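Stage 2 trains a causal transformer on fixed-length context windows sliced from the flattened learning histories, so a single window can span across episodes and capture policy improvement. A minimal slicing sketch (the token layout and window length are assumptions, not the project's `seq_dataset.py`):

```python
def make_windows(history, context_len: int):
    """Slide a fixed-length window over a flat (obs, action, reward) history."""
    return [
        history[i:i + context_len]
        for i in range(len(history) - context_len + 1)
    ]

history = [(s, s % 5, 0.0) for s in range(10)]  # toy (obs, action, reward) triples
windows = make_windows(history, context_len=4)
print(len(windows), len(windows[0]))  # 7 4
```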
Pass the directory containing the model and its YAML config to the evaluation script to produce evaluation PNGs:

```shell
python3 src/check_in_context/eval_one_model.py \
    --model-dir /path/to/dir/with/model/and/config
```
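During evaluation the trained model's weights stay frozen; any improvement across episodes comes purely from the growing in-context history of its own transitions. A schematic rollout loop (the `env` and `model` interfaces are hypothetical placeholders, not the script's actual classes):

```python
def in_context_rollout(env, model, num_episodes: int, context_len: int):
    """Roll out a frozen model, feeding its own transitions back as context."""
    context = []  # recent (obs, action, reward) transitions
    returns = []
    for _ in range(num_episodes):
        obs, total, done = env.reset(), 0.0, False
        while not done:
            action = model.act(context, obs)      # inference only, no weight updates
            next_obs, reward, done = env.step(action)
            context.append((obs, action, reward))
            context = context[-context_len:]      # keep a fixed-size window
            total += reward
            obs = next_obs
        returns.append(total)
    return returns
```

Note that the context persists across episode boundaries, which is what lets the model refine its behaviour from one episode to the next.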
The results and experiments can be found in this wandb report.
The code is based on the following implementations: