ad-icrl

Non-official implementation of the paper "In-context Reinforcement Learning with Algorithm Distillation" (Laskin et al., 2022).
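In short, Algorithm Distillation (AD) trains a causal transformer on the learning histories of a source RL algorithm, so that at evaluation time the transformer keeps improving its policy purely in-context, with no gradient updates. This repository follows the paper's pipeline in three stages: (1) train PPO agents on the Dark Room environment and record their learning histories, (2) train a transformer to predict actions from long cross-episodic contexts sampled from those histories, and (3) evaluate the trained model on held-out goals.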

Installation 🧑‍🔧

git clone https://github.com/cinemere/ad-icrl.git
cd ad-icrl
python -m venv venv
source venv/bin/activate
pip install -r requirements/requirements.txt

Environment Variables

Set the PYTHONPATH environment variable, which is required to run the project:

export PYTHONPATH=.

You can also optionally set:

DEVICE
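For example (assuming the code reads DEVICE as a torch device string, which is not verified here):

export DEVICE=cuda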

Repository structure

.
├── notebooks/
│   ├── colab.ipynb                        # run training in colab
│   ├── explore_dataset.ipynb              # overview the dataset for training
│   ├── config_paper.py                    # hyperparams from paper
│   └── test_dark_room.py                  # playground for dark room env
├── README.md
├── requirements/
│   ├── requirements_colab.txt
│   └── requirements.txt
├── saved_data/                            # directory for data saved during execution
│   ├── goals_9.txt
│   ├── learning_history/
│   ├── logs/
│   └── permutations_9.txt                 # permutation of goals for train-test-split
├── scripts/ 
│   ├── collect_data.sh                    # run stage 1: dataset collection
│   ├── eval.sh                            # run stage 3: evaluate the trained model
│   └── train_ad.sh                        # run stage 2: train algorithm distillation
├── src/
│   ├── check_in_context/                  # stage 3:
│   │   └── eval_one_model.py              #    evaluate a given model
│   ├── data/
│   │   ├── __init__.py                    # stage 1:
│   │   ├── env.py                         #    env config 
│   │   ├── generate_goals.py              #    goals setup
│   │   └── ppo.py                         #    collect dataset with ppo
│   └── dt/                                # stage 2:
│       ├── eval.py                        #    rollout eval
│       ├── model.py                       #    AD model
│       ├── schedule.py                    #    lr scheduler
│       ├── seq_dataset.py                 #    dataloader 
│       ├── train.py                       #    train script & config
│       └── utils.py
├── static/                                # images for report  
├── runs/                                  # tensorboard dir
└── wandb/                                 # wandb dir  

Quick start 🏃

0.0 Set up wandb

To enable wandb logging, create a wandb account and run the following once:

wandb init

To disable wandb logging (for debugging or any other reason), run:

wandb disabled

To turn wandb logging back on, run:

wandb enabled

0.1 Generate goals and permutations

Run the script below, or use the provided file saved_data/permutations_9.txt:

python src/data/generate_goals.py
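Judging by the repository tree, the script writes saved_data/goals_9.txt and saved_data/permutations_9.txt (the permutation defines the train-test split of goals).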

1. Train PPO agents to collect the dataset

Alternatively, load pre-collected trajectories from Google Drive via the link.

Run the script to create the trajectories:

chmod +x ./scripts/collect_data.sh
./scripts/collect_data.sh

The script runs src/data/ppo.py. It logs to wandb by default; pass the --no-track flag to disable logging.
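For example, to launch a single collection run without wandb logging:

python src/data/ppo.py --no-track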

Monitor the learning process in TensorBoard:

tensorboard --logdir saved_data/logs/

Use notebooks/explore_dataset.ipynb to compute statistics on the collected trajectories.
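For a quick check outside the notebook, here is a minimal sketch that prints per-episode returns. It assumes each learning history is an .npz file with flat per-step rewards and dones arrays; the actual format and filenames are whatever src/data/ppo.py writes, so treat the names below as hypothetical.

# Sanity check on one collected learning history. The .npz layout
# (file name and the "rewards"/"dones" fields) is an assumption --
# adjust it to match what src/data/ppo.py actually saves.
import numpy as np

data = np.load("saved_data/learning_history/goal_00.npz")  # hypothetical filename
episode_returns, ret = [], 0.0
for reward, done in zip(data["rewards"], data["dones"]):
    ret += float(reward)
    if done:
        episode_returns.append(ret)
        ret = 0.0
# PPO should have learned, so returns ought to trend upward over the history.
print(episode_returns[:5], "...", episode_returns[-5:])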

2. Train AD

Alternatively, load the trained models from Google Drive via the link.

python src/dt/train.py

You can also run the bash script scripts/train_ad.sh.

Train AD with reward predictor:

python src/dt/train.py --config.add-reward_head
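For orientation, below is a minimal sketch of the AD training objective. It is not the repository's actual code (see src/dt/model.py and src/dt/train.py for that), and all sizes, names, and the token layout are illustrative assumptions: a causal transformer consumes long cross-episodic (state, previous action, previous reward) sequences sampled from the PPO learning histories and is trained with cross-entropy to predict the action taken at each step.

# Minimal Algorithm Distillation sketch (illustrative only; not src/dt/model.py).
# Sizes assume a 9x9 Dark Room grid (81 states, 5 actions) -- adjust as needed.
import torch
import torch.nn as nn

class TinyAD(nn.Module):
    def __init__(self, n_states=81, n_actions=5, d_model=64, n_layers=2, max_len=256):
        super().__init__()
        self.state_emb = nn.Embedding(n_states, d_model)
        self.action_emb = nn.Embedding(n_actions, d_model)
        self.reward_emb = nn.Linear(1, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
        self.backbone = nn.TransformerEncoder(layer, n_layers)
        self.action_head = nn.Linear(d_model, n_actions)

    def forward(self, states, prev_actions, prev_rewards):
        # One token per timestep: emb(s_t) + emb(a_{t-1}) + emb(r_{t-1}) + position.
        T = states.shape[1]
        pos = torch.arange(T, device=states.device)
        x = (self.state_emb(states) + self.action_emb(prev_actions)
             + self.reward_emb(prev_rewards.unsqueeze(-1)) + self.pos_emb(pos))
        causal = nn.Transformer.generate_square_subsequent_mask(T)
        h = self.backbone(x, mask=causal)
        return self.action_head(h)  # logits for the action taken at each step

model = TinyAD()
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)

# One fake batch of cross-episodic context; real batches come from the
# PPO learning histories (cf. src/dt/seq_dataset.py).
B, T = 8, 256
states = torch.randint(0, 81, (B, T))
prev_actions = torch.randint(0, 5, (B, T))
prev_rewards = torch.rand(B, T)
target_actions = torch.randint(0, 5, (B, T))

loss = nn.functional.cross_entropy(
    model(states, prev_actions, prev_rewards).reshape(-1, 5),
    target_actions.reshape(-1))
loss.backward()
opt.step()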

3. Evaluate trained model

Pass the directory containing the model and its YAML config to the evaluation script to produce evaluation PNGs.

python3 src/check_in_context/eval_one_model.py \
--model-dir /path/to/dir/with/model/and/config
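Conceptually, evaluation is an in-context rollout: the context grows with every transition and the model must improve without any gradient updates. Below is a minimal sketch under the same assumptions as above; the env interface and the greedy action choice are placeholders, and the real loop is src/check_in_context/eval_one_model.py.

# In-context evaluation sketch (illustrative; not eval_one_model.py).
# `env` is assumed to expose reset() -> state and step(a) -> (state, reward, done).
import torch

@torch.no_grad()
def rollout_in_context(model, env, max_len=256, n_steps=1000):
    states, prev_actions, prev_rewards = [env.reset()], [0], [0.0]
    episode_returns, ep_ret = [], 0.0
    for _ in range(n_steps):
        s = torch.tensor(states[-max_len:]).unsqueeze(0)
        a = torch.tensor(prev_actions[-max_len:]).unsqueeze(0)
        r = torch.tensor(prev_rewards[-max_len:]).unsqueeze(0)
        action = model(s, a, r)[0, -1].argmax().item()  # greedy; sampling also works
        state, reward, done = env.step(action)
        ep_ret += reward
        states.append(state)
        prev_actions.append(action)
        prev_rewards.append(reward)
        if done:
            episode_returns.append(ep_ret)
            ep_ret = 0.0
            states[-1] = env.reset()  # keep the context running across episodes
    return episode_returns  # should trend upward if in-context RL works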

Results and experiments 🖥️

The results and experiments can be found in this wandb report.

Acknowledgements ⭐ 🌟 ⭐

The code is based on the following implementations:
