This is the official codebase for the ICLR 2023 spotlight paper SMART: Self-supervised Multi-task pretrAining with contRol Transformers. If you use this code in an academic context, please use the following citation:
@inproceedings{sun2023smart,
  title={{SMART}: Self-supervised Multi-task pretrAining with contRol Transformers},
  author={Yanchao Sun and Shuang Ma and Ratnesh Madaan and Rogerio Bonatti and Furong Huang and Ashish Kapoor},
  booktitle={International Conference on Learning Representations},
  year={2023},
  url={https://openreview.net/forum?id=9piH3Hg8QEf}
}
- Using conda
# dmc specific
# create env
conda env create --file docker/environment.yml
# activate conda
conda activate smart
bash scripts/dmc_setup.sh
# install this repo (smart)
pip install -e .
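To confirm the environment works before downloading any data, a quick import check can help. This is only a sanity-check sketch, assuming scripts/dmc_setup.sh set up dm_control and MuJoCo; it is not an official installation step:

# sanity check: load a DeepMind Control task in the new environment
conda activate smart
python -c "from dm_control import suite; env = suite.load('walker', 'walk'); print(env.reset())"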
- Using docker
# build image
docker build \
    -f Dockerfile_base_azureml_dmc \
    --build-arg BASE_IMAGE=openmpi4.1.0-cuda11.3-cudnn8-ubuntu20.04:latest \
    -t smart:latest .
# run image
docker run -it -d --gpus=all --name=rl_pretrain_dmc_1 -v HOST_PATH:CONTAINER_PATH smart:latest
# setup the repo (run inside the container)
pip install -e .
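As a concrete illustration, a run command with hypothetical mount paths might look like the following; the container path is an assumption here, so adjust HOST_PATH and CONTAINER_PATH to your own setup:

# example: mount a local data/ folder into the container (hypothetical paths)
docker run -it -d --gpus=all --name=rl_pretrain_dmc_1 \
    -v "$(pwd)/data":/workspace/data \
    smart:latest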
- Install azcopy
wget https://aka.ms/downloadazcopy-v10-linux
tar -xvf downloadazcopy-v10-linux
sudo cp ./azcopy_linux_amd64_*/azcopy /usr/bin/
rm -rf *azcopy*
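You can then check that the binary is available on your PATH:

# verify the azcopy install
azcopy --version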
- Downloading the full dataset (1.18 TiB)
# download to data/ directory
azcopy copy 'https://smartrelease.blob.core.windows.net/smartrelease/data/dmc_ae' 'data' --recursive
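The full dataset is large, so it can be worth confirming there is enough free disk space before starting:

# check free space before starting the ~1.2 TiB download
df -h .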
- Downloading a subset of the full dataset
# download to data/ directory
azcopy copy 'https://smartrelease.blob.core.windows.net/smartrelease/data/dmc_ae/TYPE_DOMAIN_TASK' 'data' --recursive
where
- `TYPE`: `randcollect`, `fullcollect` (Note: `fullcollect` datasets are ~10x larger than `randcollect` datasets)
- `DOMAIN_TASK`: `cartpole_balance`, `cartpole_swingup`, `cheetah_run`, `finger_spin`, `hopper_hop`, `hopper_stand`, `pendulum_swingup`, `walker_run`, `walker_stand`, or `walker_walk` (see Table 2 in the paper)
Example:
# download to data/ directory (~9.7 GB each)
azcopy copy 'https://smartrelease.blob.core.windows.net/smartrelease/data/dmc_ae/randcollect_walker_walk' 'data' --recursive
azcopy copy 'https://smartrelease.blob.core.windows.net/smartrelease/data/dmc_ae/randcollect_cheetah_run' 'data' --recursive
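If you want several subsets at once, a small shell loop over the same URL pattern avoids repeating the command. This is just a convenience sketch, not a script from the repo:

# fetch several randcollect subsets in one go (uses the URL pattern above)
for DOMAIN_TASK in walker_walk cheetah_run hopper_stand; do
    azcopy copy "https://smartrelease.blob.core.windows.net/smartrelease/data/dmc_ae/randcollect_${DOMAIN_TASK}" 'data' --recursive
done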
- Downloading the pretrained models
# download to pretrained_models/ directory (236.34 MiB)
azcopy copy 'https://smartrelease.blob.core.windows.net/smartrelease/pretrained_models' '.' --recursive
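Once the copy finishes, you can inspect what was downloaded, for example:

# list the downloaded checkpoints and check their size
ls -R pretrained_models
du -sh pretrained_models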
Let us run the code on the aforementioned small subset of `randcollect_walker_walk` and `randcollect_cheetah_run`.
python src/dmc_pretrain.py base=configs/pretrain.yaml \
epochs=10 \
data.num_steps=80000 \
domain_and_task.source_data_type=rand \
data.train_replay_id=1 \
data.data_dir_prefix=data \
model.model_type=naive \
domain_and_task.source_envs="{'walker': ['walk'], 'cheetah': ['run']}" \
output_dir=./outputs/pretrain_explore_subset
The set of pretraining tasks can be specified in the config file or overridden on the command line via `domain_and_task.source_envs`, as shown below:
- Pretrain with offline data collected by exploratory policies
python src/dmc_pretrain.py base=configs/pretrain.yaml \
epochs=10 \
data.num_steps=80000 \
data.train_replay_id=5 \
data.data_dir_prefix=data \
model.model_type=naive \
domain_and_task.source_data_type=full \
domain_and_task.source_envs="{'walker': ['walk'], 'cheetah': ['run']}" \
output_dir=./outputs/pretrain_explore
- Pretrain with offline data collected by random policies
python src/dmc_pretrain.py base=configs/pretrain.yaml \
epochs=10 \
data.num_steps=80000 \
data.train_replay_id=5 \
data.data_dir_prefix=data \
model.model_type=naive \
domain_and_task.source_data_type=rand \
domain_and_task.source_envs="{'walker': ['walk'], 'cheetah': ['run']}" \
output_dir=./outputs/pretrain_random
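Both pretraining runs write checkpoints under their output_dir; the downstream commands below load ./outputs/pretrain_explore/checkpoints/last.ckpt via load_model_from, so a quick check that the file exists can save a failed downstream run:

# the downstream commands below load this checkpoint via load_model_from
ls ./outputs/pretrain_explore/checkpoints/last.ckpt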
You can also download our pretrained models as reported in the paper, using the `azcopy` command in the previous section.
## set the downstream domain and task
DOMAIN=cheetah
TASK=run
## behavior cloning as the learning algorithm
python src/dmc_downstream.py base=configs/downstream.yaml \
epochs=30 \
data.num_steps=1000000 \
domain_and_task.domain=${DOMAIN} \
domain_and_task.task=${TASK} \
model.model_type=naive \
no_load_action=True \
load_model_from=./outputs/pretrain_explore/checkpoints/last.ckpt \
output_dir=./outputs/${DOMAIN}_${TASK}_bc/
## RTG-conditioned learning as the learning algorithm
python src/dmc_downstream.py base=configs/downstream.yaml \
epochs=30 \
data.num_steps=1000000 \
domain_and_task.domain=${DOMAIN} \
domain_and_task.task=${TASK} \
model.model_type=reward_conditioned \
data.rand_select=True \
no_load_action=True \
load_model_from=./outputs/pretrain_explore/checkpoints/last.ckpt \
output_dir=./outputs/${DOMAIN}_${TASK}_rtg/
Note that if load_model_from is not specified, the model is trained from scratch.
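For example, a from-scratch behavior-cloning baseline is the same command as above with load_model_from (and the checkpoint-related no_load_action flag) omitted; the output_dir name here is only a suggestion to avoid overwriting the pretrained run:

## behavior cloning trained from scratch (no pretrained checkpoint loaded)
python src/dmc_downstream.py base=configs/downstream.yaml \
epochs=30 \
data.num_steps=1000000 \
domain_and_task.domain=${DOMAIN} \
domain_and_task.task=${TASK} \
model.model_type=naive \
output_dir=./outputs/${DOMAIN}_${TASK}_bc_scratch/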