diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml new file mode 100644 index 0000000..2c54980 --- /dev/null +++ b/.github/workflows/pre-commit.yml @@ -0,0 +1,16 @@ +# ref: https://github.com/pre-commit-ci-demo/demo/blob/main/.github/workflows/pre-commit.yml +name: pre-commit + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + - uses: pre-commit/action@v2.0.0 diff --git a/.gitignore b/.gitignore index 9a65382..71b7e08 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,23 @@ +data/ +pretrained_models/ -outputs/ +# Byte-compiled / optimized / DLL files __pycache__/ +*.py[cod] +*$py.class + +# Distribution / packaging +build/ +dist/ +*.egg-info/ + +# data +data/ + +# logs +logs/ +outputs/ +# env +.env +.autoenv diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..0ccead7 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,78 @@ +default_language_version: + python: python3 + +ci: + autofix_prs: true + autoupdate_commit_msg: "[pre-commit.ci] pre-commit suggestions" + autoupdate_schedule: quarterly + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + # list of supported hooks: https://pre-commit.com/hooks.html + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-case-conflict + - id: debug-statements + - id: detect-private-key + - id: check-added-large-files + args: ["--maxkb=500", "--enforce-all"] + exclude: | + (?x)^( + )$ + + - repo: https://github.com/asottile/pyupgrade + rev: v3.3.1 + hooks: + - id: pyupgrade + args: [--py37-plus] + name: Upgrade code + + # python formatting + - repo: https://github.com/psf/black + rev: 23.1.0 + hooks: + - id: black + name: Format code + args: ["--line-length=120"] + + - repo: https://github.com/hadialqattan/pycln + rev: v2.1.3 # Possible 
releases: https://github.com/hadialqattan/pycln/releases + hooks: + - id: pycln + args: [--all] + + # ref: https://github.com/microsoft/vscode-isort + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + name: isort (python) + args: [--profile, "black"] + + # python docstring formatting + - repo: https://github.com/myint/docformatter + rev: v1.5.1 + hooks: + - id: docformatter + args: [--in-place, --wrap-summaries, "99", --wrap-descriptions, "92"] + + # yaml formatting + - repo: https://github.com/pre-commit/mirrors-prettier + rev: v3.0.0-alpha.4 + hooks: + - id: prettier + types: [yaml] + + # markdown formatting + - repo: https://github.com/executablebooks/mdformat + rev: 0.7.16 + hooks: + - id: mdformat + additional_dependencies: + - mdformat-gfm + #- mdformat-black + - mdformat_frontmatter + exclude: CHANGELOG.md diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 0000000..4bbc1e9 --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,20 @@ +{ + // See https://go.microsoft.com/fwlink/?LinkId=827846 to learn about workspace recommendations. + // Extension identifier format: ${publisher}.${name}. Example: vscode.csharp + // List of extensions which should be recommended for users of this workspace. + "recommendations": [ + "ms-python.python", + "ms-azuretools.vscode-docker", + "ms-vscode-remote.remote-containers", + "ms-vscode-remote.remote-ssh", + "ms-vscode-remote.remote-ssh-edit", + "ms-vscode-remote.vscode-remote-extensionpack", + "redhat.vscode-yaml", + "yzhang.markdown-all-in-one", + "TrungNgo.autoflake", + "njpwerner.autodocstring", + "jbockle.jbockle-format-files" + ], + // List of extensions recommended by VS Code that should not be recommended for users of this workspace. 
+ "unwantedRecommendations": [] +} diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..efb6f1d --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,39 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Current File", + "type": "python", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "justMyCode": true + }, + { + "name": "dmc_pretrain", + "type": "python", + "request": "launch", + "program": "src/dmc_pretrain.py", + "console": "integratedTerminal", + "justMyCode": true, + "args": [ + "base=configs/pretrain.yaml", + "trainer.num_nodes=1", + "trainer.devices=1", + ] + }, + { + "name": "dmc_downstream", + "type": "python", + "request": "launch", + "program": "src/dmc_downstream.py", + "console": "integratedTerminal", + "justMyCode": true, + "args": [ + "base=configs/downstream.yaml", + "trainer.num_nodes=1", + "trainer.devices=1", + ] + }, + ] +} diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..f6fb23b --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,29 @@ +{ + "editor.defaultFormatter": "ms-python.black-formatter", + "editor.formatOnPaste": true, + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.organizeImports": true + }, + "python.analysis.typeCheckingMode": "basic", + "python.formatting.provider": "black", + "python.formatting.blackArgs": [ + "--line-length", + "120" + ], + "python.linting.enabled": true, + "python.linting.pylintEnabled": false, + "python.linting.flake8Enabled": true, + "python.linting.flake8Args": [ + "--max-line-length=120", + ], + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true, + "isort.args": [ + "--profile", + "black" + ], +} diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..f9ba8cf --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,9 @@ +# Microsoft Open 
Source Code of Conduct + +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). + +Resources: + +- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) +- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) +- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..e744752 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,47 @@ + MIT License + + Copyright (c) Microsoft Corporation. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE + + PACT borrows from the following external code: + + - minGPT, located at /src/models/modules/minGPT.py + + """ + The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + """ diff --git a/README.md b/README.md index 9748c79..6b474de 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,18 @@ -# SMART: Self-supervised Multi-task pretrAining with contRol Transformers +# SMART: Self-supervised Multi-task pretrAining with contRol Transformers +This is the official codebase for the ICLR 2023 spotlight paper [SMART: Self-supervised Multi-task pretrAining with contRol Transformers](https://openreview.net/forum?id=9piH3Hg8QEf). 
+If you use this code in an academic context, please use the following citation: -This is the official codebase for the ICLR 2023 spotlight paper "SMART: Self-supervised Multi-task pretrAining with contRol Transformers". Pretrained models can be downloaded [here](https://link-url-here.org). Dataset can be downloaded [here](https://link-url-here.org). +``` +@inproceedings{ +sun2023smart, +title={{SMART}: Self-supervised Multi-task pretrAining with contRol Transformers}, +author={Yanchao Sun and Shuang Ma and Ratnesh Madaan and Rogerio Bonatti and Furong Huang and Ashish Kapoor}, +booktitle={International Conference on Learning Representations}, +year={2023}, +url={https://openreview.net/forum?id=9piH3Hg8QEf} +} +``` ## Setting up @@ -14,7 +25,7 @@ This is the official codebase for the ICLR 2023 spotlight paper "SMART: Self-sup # activate conda conda activate smart - bash src/scripts/dmc_setup.sh + bash scripts/dmc_setup.sh # install this repo (smart) $ pip install -e . @@ -23,62 +34,147 @@ This is the official codebase for the ICLR 2023 spotlight paper "SMART: Self-sup - Using docker ``` - # dmc specific - docker pull PUBLIC_DOCKER_IMAGE + # build image + docker build \ + -f Dockerfile_base_azureml_dmc \ + --build-arg BASE_IMAGE=openmpi4.1.0-cuda11.3-cudnn8-ubuntu20.04:latest \ + -t smart:latest . # run image - docker run -it -d --gpus=all --name=rl_pretrain_dmc_1 -v HOST_PATH:CONTAINER_PATH commondockerimages.azurecr.io/atari_pretrain:latest-azureml-dmc + docker run -it -d --gpus=all --name=rl_pretrain_dmc_1 -v HOST_PATH:CONTAINER_PATH smart:latest # setup the repo (run inside the container) pip install -e . 
``` -## Preparing the dataset +## Downloading data and pre-trained models from Azure + +- Install azcopy + + ``` + wget https://aka.ms/downloadazcopy-v10-linux + tar -xvf downloadazcopy-v10-linux + sudo cp ./azcopy_linux_amd64_*/azcopy /usr/bin/ + rm -rf *azcopy* + ``` + +- Downloading the full dataset (1.18TiB) + + ``` + # download to data/ directory + azcopy copy 'https://smartrelease.blob.core.windows.net/smartrelease/data/dmc_ae' 'data' --recursive + ``` + +- Downloading a subset of the full dataset + + ``` + # download to data/ directory + azcopy copy 'https://smartrelease.blob.core.windows.net/smartrelease/data/dmc_ae/TYPE_DOMAIN_TASK' 'data' --recursive + ``` + + where + + - `TYPE`: `randcollect`, `fullcollect` + (Note: `fullcollect` datasets are ~10x larger than `randcollect` datasets) + + - `DOMAIN_TASK`: `cartpole_balance`, `cartpole_swingup`, `cheetah_run`, `finger_spin`, `hopper_hop`, `hopper_stand`, `pendulum_swingup`, `walker_run`, `walker_stand`, or `walker_walk` (See Table 2 in the paper) + + Example: + + ``` + # download to data/ directory (~ 9.7 GB each) + azcopy copy 'https://smartrelease.blob.core.windows.net/smartrelease/data/dmc_ae/randcollect_walker_walk' 'data' --recursive + azcopy copy 'https://smartrelease.blob.core.windows.net/smartrelease/data/dmc_ae/randcollect_cheetah_run' 'data' --recursive + ``` + +- Downloading the pretrained models -Download dataset to PATH_TO_DATASET, or collect data following this instruction. + ``` + # download to pretrained_models/ directory (236.34 MiB) + azcopy copy 'https://smartrelease.blob.core.windows.net/smartrelease/pretrained_models' '.' --recursive + ``` ## Running the code -**Pretraining on multiple domains and tasks** (selection of pretraining tasks can be specified in the config file as shown below): +### Testing on a small subset of the full dataset + +Let us run the code on the aforementioned small subset of `randcollect_walker_walk` and `randcollect_cheetah_run`. 
+ ``` -## pretrain with offline data collected by exploratory policies -python src/dmc_multidomain_train.py \ - --epochs 10 --num_steps 80000 --train_replay_id 5 --model_type naive \ - --multi_config configs/train_configs/multipretrain_source_v1.json \ - --output_dir ./outputs/pretrain_explore/ \ - --data_dir_prefix PATH_TO_DATASET - -## pretrain with offline data collected by random policies -python src/dmc_multidomain_train.py \ - --epochs 10 --num_steps 80000 --train_replay_id 5 --model_type naive \ - --multi_config configs/train_configs/multipretrain_source_v1.json --source_data_type rand \ - --output_dir ./outputs/pretrain_random/ \ - --data_dir_prefix PATH_TO_DATASET +python src/dmc_pretrain.py base=configs/pretrain.yaml \ + epochs=10 \ + data.num_steps=80000 \ + domain_and_task.source_data_type=rand \ + data.train_replay_id=1 \ + data.data_dir_prefix=data \ + model.model_type=naive \ + domain_and_task.source_envs="{'walker': ['walk'], 'cheetah': ['run']}" \ + output_dir=./outputs/pretrain_explore_subset ``` -You can also download our pretrained models as reported in the paper in this [here](https://link-url-here.org). 
+### Pretraining on multiple domains and tasks + +The set of pretraining tasks can be specified in the config file as shown below: + +- Pretrain with offline data collected by exploratory policies + +``` +python src/dmc_pretrain.py base=configs/pretrain.yaml \ + epochs=10 \ + data.num_steps=80000 \ + data.train_replay_id=5 \ + data.data_dir_prefix=data \ + model.model_type=naive \ + domain_and_task.source_data_type=full \ + domain_and_task.source_envs="{'walker': ['walk'], 'cheetah': ['run']}" \ + output_dir=./outputs/pretrain_explore +``` + +- Pretrain with offline data collected by random policies + +``` +python src/dmc_pretrain.py base=configs/pretrain.yaml \ + epochs=10 \ + data.num_steps=80000 \ + data.train_replay_id=5 \ + data.data_dir_prefix=data \ + model.model_type=naive \ + domain_and_task.source_data_type=rand \ + domain_and_task.source_envs="{'walker': ['walk'], 'cheetah': ['run']}" \ + output_dir=./outputs/pretrain_random +``` +### Using a pretrained model and finetuning the policy on a specific downstream task +You can also download our pretrained models as reported in the paper, using the `azcopy` command in the previous section. 
-The command below **loads the pretrained model and finetunes the policy on a specific downstream task**: ``` -## (example) set the downstream domain and task +## set the downstream domain and task DOMAIN=cheetah TASK=run ## behavior cloning as the learning algorithm -python src/dmc_train.py \ - --epochs 30 --num_steps 1000000 --domain ${DOMAIN} --task ${TASK} \ - --model_type naive --no_load_action \ - --load_model_from ./outputs/pretrain_explore/checkpoints/last.ckpt \ - --output_dir ./outputs/${DOMAIN}_${TASK}_bc/ +python src/dmc_downstream.py base=configs/downstream.yaml \ + epochs=30 \ + data.num_steps=1000000 \ + domain_and_task.domain=${DOMAIN} \ + domain_and_task.task=${TASK} \ + model.model_type=naive \ + no_load_action=True \ + load_model_from=./outputs/pretrain_explore/checkpoints/last.ckpt \ + output_dir=./outputs/${DOMAIN}_${TASK}_bc/ ## RTG-conditioned learning as the learning algorithm -python src/dmc_train.py \ - --epochs 30 --num_steps 1000000 --domain ${DOMAIN} --task ${TASK} \ - --model_type reward_conditioned --rand_select --no_load_action \ - --load_model_from ./outputs/pretrain_explore/checkpoints/last.ckpt \ - --output_dir ./outputs/${DOMAIN}_${TASK}_rtg/ +python src/dmc_downstream.py base=configs/downstream.yaml \ + epochs=30 \ + data.num_steps=1000000 \ + domain_and_task.domain=${DOMAIN} \ + domain_and_task.task=${TASK} \ + model.model_type=reward_conditioned \ + data.rand_select=False \ + no_load_action=True \ + load_model_from=./outputs/pretrain_explore/checkpoints/last.ckpt \ + output_dir=./outputs/${DOMAIN}_${TASK}_rtg/ ``` -Note that if *--load_model_from* is not specified, the model is trained from scratch. \ No newline at end of file +Note that if `load_model_from` is not specified, the model is trained from scratch. 
diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..8a110e1 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,41 @@ + + +## Security + +Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). + +If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. + +## Reporting Security Issues + +**Please do not report security vulnerabilities through public GitHub issues.** + +Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). + +If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). + +You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). + +Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: + +- Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 
+- Full paths of source file(s) related to the manifestation of the issue +- The location of the affected source code (tag/branch/commit or direct URL) +- Any special configuration required to reproduce the issue +- Step-by-step instructions to reproduce the issue +- Proof-of-concept or exploit code (if possible) +- Impact of the issue, including how an attacker might exploit the issue + +This information will help us triage your report more quickly. + +If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. + +## Preferred Languages + +We prefer all communications to be in English. + +## Policy + +Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). + + diff --git a/SUPPORT.md b/SUPPORT.md new file mode 100644 index 0000000..0aaa49c --- /dev/null +++ b/SUPPORT.md @@ -0,0 +1,11 @@ +# Support + +## How to file issues and get help + +This project uses GitHub Issues to track bugs, feature requests, or any other questions. +Please search the existing issues before filing new issues to avoid duplicates. +For new issues, file your bug or feature request as a new Issue. + +## Microsoft Support Policy + +Support for this project is limited to the resources listed above. 
diff --git a/cgmanifest.json b/cgmanifest.json new file mode 100644 index 0000000..9ef1db7 --- /dev/null +++ b/cgmanifest.json @@ -0,0 +1,14 @@ +{ + "Registrations": [ + { + "component": { + "type": "git", + "git": { + "repositoryUrl": "https://github.com/denisyarats/dmc2gym", + "commitHash": "06f7e335d988b17145947be9f6a76f557d0efe81" + } + } + } + ], + "Version": 1 +} diff --git a/configs/downstream.yaml b/configs/downstream.yaml new file mode 100644 index 0000000..fa3daef --- /dev/null +++ b/configs/downstream.yaml @@ -0,0 +1,113 @@ +seed: 123 + +output_dir: ./outputs +exp_name: +load_model_from: +no_load_action: False +no_strict: False +no_action_head: False +stat_file: stat.csv +timestep: 10000 +epochs: 50 +eval_epochs: 50 + +trainer: + default_root_dir: ${output_dir} + + num_nodes: 1 + devices: 1 + accelerator: gpu + strategy: ddp + + min_epochs: 1 + max_epochs: ${epochs} + + enable_progress_bar: true + + # debugging + fast_dev_run: false + enable_checkpointing: True + +logger: + tensorboard: + _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger + save_dir: ${output_dir}/tb_logs/exp_name + name: null + version: null + log_graph: False + default_hp_metric: True + prefix: "" + +callbacks: + checkpoint: + _target_: pytorch_lightning.callbacks.ModelCheckpoint + dirpath: "${output_dir}/checkpoints_bc/" + monitor: "val/interactive_reward" # name of the logged metric which determines when model is improving + mode: "max" # "max" means higher metric value is better, can be also "min" + save_top_k: 5 # save k best models (determined by above metric) + filename: "best_reward_model" # best_model + +domain_and_task: + timesteps: 250 + domain: "walker" + task: "walk" + +data: + # _target_ can be datamodules.dmc_datamodule.DMCDataModule or datamodules.dmc_datamodule.DMCBCDataModule + domain: ${domain_and_task.domain} + task: ${domain_and_task.task} + data_dir_prefix: ./data + context_length: 30 + num_buffers: 50 + num_steps: 500000 + trajectories_per_buffer: 10 
# Number of trajectories to sample from each of the buffers + stack_size: 4 + batch_size: 256 + num_workers: 1 + train_replay_id: 1 + val_replay_id: 2 + + # used for DMCBCDataModule only (used when model_type is "naive") + select_rate: 0 + rand_select: False + +# DTModel +model: + _target_: models.ct_module.CTLitModule + + domain: ${domain_and_task.domain} + task: ${domain_and_task.task} + agent_type: gpt + model_type: reward_conditioned # choices=["reward_conditioned", "naive"] + timestep: ${timestep} + n_embd: 256 + lr: 6e-4 + context_length: 30 + weight_decay: 0.1 + betas: [0.9, 0.95] + epochs: ${epochs} + eval_epochs: ${eval_epochs} + seed: ${seed} + + # set these to false in downstream learning (unless using as an auxiliary task) + unsupervise: False + forward: False + inverse: False + reward: False + rand_inverse: False + freeze_encoder: False + rand_attn_only: False + + ## whether to use random mask hindsight control + rand_mask_size: -1 # mask size for action, -1 is to set masks by curriculum + mask_obs_size: -1 # mask size for observations, -1 is to set masks by curriculum + + # weights + forward_weight: 1.0 + + # layers and network configs + n_layer: 8 + n_head: 8 + bc_layers: 1 + pred_layers: 1 + rtg_layers: 1 diff --git a/configs/pretrain.yaml b/configs/pretrain.yaml new file mode 100644 index 0000000..3dcf994 --- /dev/null +++ b/configs/pretrain.yaml @@ -0,0 +1,122 @@ +seed: 123 + +output_dir: ./outputs +exp_name: +load_model_from: +no_load_action: False +no_strict: False +no_action_head: False +stat_file: stat.csv +timestep: 10000 +epochs: 50 + +trainer: + default_root_dir: ${output_dir} + + num_nodes: 1 + devices: 1 + accelerator: gpu + strategy: ddp + + min_epochs: 1 + max_epochs: ${epochs} + + enable_progress_bar: true + + # debugging + fast_dev_run: false + enable_checkpointing: True + +logger: + tensorboard: + _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger + save_dir: ${output_dir}/tb_logs/exp_name + name: null + version: null 
+ log_graph: False + default_hp_metric: True + prefix: "" + +callbacks: + checkpoint: + _target_: pytorch_lightning.callbacks.ModelCheckpoint + dirpath: "${output_dir}/checkpoints/" + monitor: "val/avg_loss" # name of the logged metric which determines when model is improving + mode: "min" # "max" means higher metric value is better, can be also "min" + save_top_k: 5 # save k best models (determined by above metric) + save_last: True # additionally always save model from last epoch + verbose: False + filename: "checkpoint_{epoch:02d}" # best_model + +domain_and_task: + source_data_type: rand # choices: rand, full, mix + timesteps: 250 + source_envs: + walker: ["walk"] + cheetah: ["run"] + # walker: ["stand", "run"] + # cartpole: ["swingup"] + # hopper: ["hop"] + # finger: ["spin"] + +data: + _target_: datamodules.dmc_datamodule.DMCMultiDomainDataModule + + source_envs: ${domain_and_task.source_envs} + data_dir_prefix: ./data + context_length: 30 + num_buffers: 50 + num_steps: 80000 + trajectories_per_buffer: 10 # Number of trajectories to sample from each of the buffers + stack_size: 4 + batch_size: 256 + num_workers: 1 + train_replay_id: 1 + val_replay_id: 2 + select_rate: 0 + seed: ${seed} + biased_multi: False + source_data_type: ${domain_and_task.source_data_type} + +# DTModel +model: + _target_: models.multitask_ct_module.MultiTaskCTLitModule + + source_envs: ${domain_and_task.source_envs} + agent_type: gpt + model_type: reward_conditioned # choices=["reward_conditioned", "naive"] + timestep: ${timestep} + n_embd: 256 + lr: 6e-4 + freeze_encoder: False + context_length: 30 + betas: [0.9, 0.95] + weight_decay: 0.1 + epochs: ${epochs} + + ## whether to use supervision + unsupervise: True + + ## whether to use forward prediction + forward: True + + ## whether to use inverse prediction + inverse: True + + ## whether to use random mask hindsight control + rand_inverse: True + rand_mask_size: -1 # mask size for action, -1 is to set masks by curriculum + 
mask_obs_size: -1 # mask size for observations, -1 is to set masks by curriculum + + # weights + forward_weight: 1.0 + + # layers and network configs + n_layer: 8 + n_head: 8 + bc_layers: 1 + pred_layers: 1 + rtg_layers: 1 + + ## additional options that are not used in the original method + reward: False # whether to predict reward diff --git a/configs/pretrain_configs/multipretrain_source_test.json b/configs/pretrain_configs/multipretrain_source_test.json deleted file mode 100644 index fbdcea7..0000000 --- a/configs/pretrain_configs/multipretrain_source_test.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "walker": ["walk"], - "cheetah": ["run"] -} diff --git a/configs/pretrain_configs/multipretrain_source_v1.json b/configs/pretrain_configs/multipretrain_source_v1.json deleted file mode 100644 index 1bb68c9..0000000 --- a/configs/pretrain_configs/multipretrain_source_v1.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "walker": ["stand", "run"], - "cheetah": ["run"], - "cartpole": ["swingup"], - "hopper": ["hop"] -} diff --git a/configs/pretrain_configs/multipretrain_source_v2.json b/configs/pretrain_configs/multipretrain_source_v2.json deleted file mode 100644 index d6addf6..0000000 --- a/configs/pretrain_configs/multipretrain_source_v2.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "walker": ["run"], - "hopper": ["stand"], - "finger": ["spin"], - "swimmer": ["swimmer15"], - "fish": ["swim"] -} diff --git a/src/scripts/dmc_setup.sh b/scripts/dmc_setup.sh similarity index 90% rename from src/scripts/dmc_setup.sh rename to scripts/dmc_setup.sh index 076b64f..f6bbf02 100644 --- a/src/scripts/dmc_setup.sh +++ b/scripts/dmc_setup.sh @@ -4,4 +4,4 @@ pip install -e . 
pip install protobuf==3.20.1 pip list pip install gym==0.25.2 -cd ../../../ \ No newline at end of file +cd ../../../ diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..5692801 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,21 @@ +[metadata] +name = smart +version = 0.0.1 +author = Yanchao Sun, Shuang Ma, Ratnesh Madaan, Rogerio Bonatti, Furong Huang, Ashish Kapoor +author_email = ycs@umd.edu, shuama@microsoft.com, ramadaan@microsoft.com +description = Code associated with SMART: Self-supervised Multi-task pretrAining with contRol Transformers (ICLR 2023) +long_description = file: README.md +url = https://github.com/microsoft/smart +classifiers = + Programming Language :: Python :: 3 + License :: OSI Approved :: MIT License + Operating System :: OS Independent + +[options] +package_dir = + = . +packages = find: +python_requires = >=3.6 + +[options.packages.find] +where = . diff --git a/setup.py b/setup.py index 6068493..adeb577 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + from setuptools import setup setup() diff --git a/src/__init__.py b/src/__init__.py index e69de29..9a04545 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
diff --git a/src/arguments.py b/src/arguments.py deleted file mode 100644 index cfea045..0000000 --- a/src/arguments.py +++ /dev/null @@ -1,43 +0,0 @@ -import argparse - -parser = argparse.ArgumentParser() - -## program and path -parser.add_argument("--data_dir_prefix", type=str, default="./data/") -parser.add_argument("--output_dir", type=str, default="./outputs/") -parser.add_argument("--exp_name", type=str, default="") -parser.add_argument("--load_model_from", type=str, default=None) -parser.add_argument("--no_load_action", default=False, action="store_true") -parser.add_argument("--no_strict", default=False, action="store_true") -parser.add_argument("--no_action_head", default=False, action="store_true") - -## trainer (for pytorch_lightning) -parser.add_argument("--seed", type=int, default=123) -parser.add_argument("--accelerator", type=str, default="gpu") -parser.add_argument("--devices", type=int, default=1, help="how many GPUs to use") -parser.add_argument("--nodes", type=int, default=1) -parser.add_argument("--epochs", type=int, default=10) -parser.add_argument("--save_k", type=int, default=5, help="how many checkpoints to save in finetuning") - -## evaluate -parser.add_argument("--eval_epochs", type=int, default=50) - -# dmc -parser.add_argument("--domain", type=str, default="cheetah") -parser.add_argument("--task", type=str, default="run") -parser.add_argument("--multi_config", type=str, default=None) - -## data -parser.add_argument("--source_data_type", type=str, default="full", choices=["full", "rand", "mix"]) -parser.add_argument("--context_length", type=int, default=30) -parser.add_argument("--batch_size", type=int, default=256) -parser.add_argument("--num_steps", type=int, default=500000) -parser.add_argument("--select_rate", type=float, default=0.1) -parser.add_argument("--train_replay_id", type=int, default=2) -parser.add_argument("--val_replay_id", type=int, default=5) -parser.add_argument("--timestep", type=int, default=10000) 
-parser.add_argument("--rand_select", default=False, action="store_true") -parser.add_argument("--biased_multi", default=False, action="store_true") - - - diff --git a/src/datamodules/__init__.py b/src/datamodules/__init__.py index e69de29..9a04545 100644 --- a/src/datamodules/__init__.py +++ b/src/datamodules/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. diff --git a/src/datamodules/dmc_datamodule.py b/src/datamodules/dmc_datamodule.py index c3491ed..4eb7740 100644 --- a/src/datamodules/dmc_datamodule.py +++ b/src/datamodules/dmc_datamodule.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + import os from typing import Optional @@ -14,7 +17,6 @@ def create_selected_dataset( domain_name, task_name, data_dir_prefix, num_steps, replay_id=1, stack_size=4, select_rate=0.1, rand_select=False ): - obss, actions, returns, done_idxs, rtgs, timesteps, step_returns = create_dataset( domain_name, task_name, data_dir_prefix, num_steps, replay_id, stack_size ) @@ -73,7 +75,6 @@ def create_selected_dataset( def create_dataset(domain_name, task_name, data_dir_prefix, num_steps, replay_id=1, stack_size=4): - env = dmc2gym.make( domain_name=domain_name, task_name=task_name, @@ -269,6 +270,8 @@ def __init__( num_workers=1, train_replay_id=1, val_replay_id=2, + select_rate=0, # not used. only to match signature with DMCBCDataModule + rand_select=False, # not used. 
only to match signature with DMCBCDataModule ): super().__init__() @@ -280,7 +283,7 @@ def __init__( if self.hparams.domain not in self.hparams.data_dir_prefix: self.hparams.data_dir_prefix = os.path.join( - self.hparams.data_dir_prefix, "fullcollect_" + self.hparams.domain + "_" + self.hparams.task + self.hparams.data_dir_prefix, "randcollect_" + self.hparams.domain + "_" + self.hparams.task ) @property @@ -511,8 +514,10 @@ def val_dataloader(self): def test_dataloader(self): return None + class DMCMultiDomainDataModule(LightningDataModule): - """Example of LightningDataModule for pretraining in multiple domains and multiple tasks on DMC.""" + """Example of LightningDataModule for pretraining in multiple domains and multiple tasks on + DMC.""" def __init__( self, @@ -528,8 +533,8 @@ def __init__( train_replay_id=1, val_replay_id=2, select_rate=0.1, - seed=42, - dataset_types=["fullcollect"], + seed=123, + source_data_type=["full"], biased_multi=False, ): super().__init__() @@ -537,7 +542,7 @@ def __init__( # this line allows to access init params with 'self.hparams' attribute self.save_hyperparameters(logger=False) - self.domains = self.hparams.source_envs.keys() + self.domains = source_envs.keys() self.action_dims = [get_min_action_dmc(domain) for domain in self.domains] print("domains", self.domains) print("action dims", self.action_dims) @@ -545,6 +550,14 @@ def __init__( self.train_dataset: Optional[Dataset] = None self.val_dataset: Optional[Dataset] = None + if self.hparams.source_data_type == "full": + self.hparams.dataset_types = ["fullcollect"] + elif self.hparams.source_data_type == "rand": + self.hparams.dataset_types = ["randcollect"] + elif self.hparams.source_data_type == "mix": + self.hparams.dataset_types = ["fullcollect", "randcollect"] + self.hparams.num_steps = self.hparams.num_steps // 2 + @property def context_length(self) -> int: return self.context_length diff --git a/src/datamodules/dmc_replay_buffer.py b/src/datamodules/dmc_replay_buffer.py 
index 6b0adc0..5145cff 100644 --- a/src/datamodules/dmc_replay_buffer.py +++ b/src/datamodules/dmc_replay_buffer.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + import os import numpy as np diff --git a/src/datamodules/dummy_datamodule.py b/src/datamodules/dummy_datamodule.py index 24a9a19..a37ed18 100644 --- a/src/datamodules/dummy_datamodule.py +++ b/src/datamodules/dummy_datamodule.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + import torch from torch.utils.data import Dataset diff --git a/src/dmc_downstream.py b/src/dmc_downstream.py index e7891cb..3e21a6a 100644 --- a/src/dmc_downstream.py +++ b/src/dmc_downstream.py @@ -1,72 +1,39 @@ -import csv +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + import os -import time -import pytorch_lightning as pl -from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.loggers import TensorBoardLogger +from omegaconf import OmegaConf -from arguments import parser from datamodules.dmc_datamodule import DMCBCDataModule, DMCDataModule from datamodules.dummy_datamodule import RandomDataset -from models.ct_module import CTLitModule - +from utils import pl_utils -def main(args): - # set seed for reproducibility, although the trainer does not allow deterministic for this implementation - pl.seed_everything(args.seed, workers=True) +def main(cfg): + cfg_dict = OmegaConf.to_container(cfg, resolve=True) - bc = (args.model_type == "naive") - # init data module - if bc: - dmc_data = DMCBCDataModule.from_argparse_args(args) + if cfg.model.model_type == "naive": + # bc = True + dmc_data = DMCBCDataModule(**cfg.data) else: - dmc_data = DMCDataModule.from_argparse_args(args) + dmc_data = DMCDataModule(**cfg.data) - # init training module - dict_args = vars(args) - - model = CTLitModule(**dict_args) - if args.load_model_from: + model = pl_utils.instantiate_class(cfg["model"]) + if 
cfg.load_model_from: model.load_my_checkpoint( - args.load_model_from, - no_action=args.no_load_action, - strict=not args.no_strict, - no_action_head=args.no_action_head, + cfg.load_model_from, + no_action=cfg.no_load_action, + strict=not cfg.no_strict, + no_action_head=cfg.no_action_head, ) - print("loaded model from", args.load_model_from) + print("loaded model from", cfg.load_model_from) # init root dir - os.makedirs(args.output_dir, exist_ok=True) - print("output dir", args.output_dir) - - # checkpoint saving metrics - checkpoint_callback = ModelCheckpoint( - dirpath=os.path.join(args.output_dir, "checkpoints_bc"), - filename="best_reward_model", - mode="max", - save_top_k=args.save_k, - monitor="val/interactive_reward", - save_last=True, - ) + os.makedirs(cfg.output_dir, exist_ok=True) + print("output dir", cfg.output_dir) - logger = TensorBoardLogger(os.path.join(args.output_dir, "tb_logs"), name="train") - - # init trainer - trainer = pl.Trainer( - accelerator=args.accelerator, - devices=args.devices, - num_nodes=args.nodes, - default_root_dir=args.output_dir, - min_epochs=1, - max_epochs=args.epochs, - callbacks=[checkpoint_callback], - strategy="ddp", - # strategy='ddp_find_unused_parameters_false', - fast_dev_run=False, - logger=logger, - ) + trainer = pl_utils.instantiate_trainer(cfg_dict) trainer.fit(model, datamodule=dmc_data) @@ -76,9 +43,15 @@ def main(args): if __name__ == "__main__": - parser = DMCDataModule.add_argparse_args(parser) - parser = CTLitModule.add_model_specific_args(parser) - - args = parser.parse_args() + cfg = OmegaConf.from_cli() + + if "base" in cfg: + basecfg = OmegaConf.load(cfg.base) + del cfg.base + cfg = OmegaConf.merge(basecfg, cfg) + print(OmegaConf.to_yaml(cfg)) + main(cfg) + else: + raise SystemExit("Base configuration file not specified! 
Exiting.") - main(args) + main(cfg) diff --git a/src/dmc_pretrain.py b/src/dmc_pretrain.py index afec325..03b4dae 100644 --- a/src/dmc_pretrain.py +++ b/src/dmc_pretrain.py @@ -1,77 +1,31 @@ -import json +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + import os -import pytorch_lightning as pl -from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.loggers import TensorBoardLogger +from omegaconf import OmegaConf -from arguments import parser -from datamodules.dmc_datamodule import DMCMultiDomainDataModule from datamodules.dummy_datamodule import RandomDataset from models.multitask_ct_module import MultiTaskCTLitModule +from utils import pl_utils -def main(args): - - # set seed for reproducibility, although the trainer does not allow deterministic for this implementation - pl.seed_everything(args.seed, workers=True) - - with open(args.multi_config) as jsonfile: - args.source_envs = json.load(jsonfile) +def main(cfg): + trainer = pl_utils.instantiate_trainer(cfg) + datamodule = pl_utils.instantiate_class(cfg["data"]) - if args.source_data_type == "full": - args.dataset_types = ["fullcollect"] - elif args.source_data_type == "rand": - args.dataset_types = ["randcollect"] - elif args.source_data_type == "mix": - args.dataset_types = ["fullcollect", "randcollect"] - args.num_steps = args.num_steps // 2 - args.biased_multi = True - - # init data module - dmc_data = DMCMultiDomainDataModule.from_argparse_args(args) - - # init training module - dict_args = vars(args) - if args.load_model_from: - model = MultiTaskCTLitModule.load_from_checkpoint(args.load_model_from, **dict_args) - print("loaded model from", args.load_model_from) + if cfg["load_model_from"] is not None: + model = MultiTaskCTLitModule.load_from_checkpoint(cfg.load_model_from, **cfg.model) + print("loaded model from", cfg["load_model_from"]) else: - model = MultiTaskCTLitModule(**dict_args) + model = pl_utils.instantiate_class(cfg["model"]) # init root dir 
- os.makedirs(args.output_dir, exist_ok=True) - print("output dir", args.output_dir) - - # checkpoint saving metrics - checkpoint_callback = ModelCheckpoint( - dirpath=os.path.join(args.output_dir, "checkpoints"), - # filename="best_model", - filename="checkpoint_{epoch:02d}", - mode="min", - save_top_k=args.save_k, - monitor="val/avg_loss", - save_last=True, - ) - - logger = TensorBoardLogger(os.path.join(args.output_dir, "tb_logs"), name="train") - - # init trainer - trainer = pl.Trainer( - accelerator=args.accelerator, - devices=args.devices, - num_nodes=args.nodes, - default_root_dir=args.output_dir, - min_epochs=1, - max_epochs=args.epochs, - callbacks=[checkpoint_callback], - strategy="ddp", - fast_dev_run=False, - logger=logger, - ) + os.makedirs(cfg["output_dir"], exist_ok=True) + print("output dir", cfg["output_dir"]) # start training - trainer.fit(model, datamodule=dmc_data) + trainer.fit(model, datamodule) # testing dummy = RandomDataset() @@ -79,10 +33,16 @@ def main(args): if __name__ == "__main__": + cfg = OmegaConf.from_cli() + + if "base" in cfg: + basecfg = OmegaConf.load(cfg.base) + del cfg.base + cfg = OmegaConf.merge(basecfg, cfg) + cfg = OmegaConf.to_container(cfg, resolve=True) + print(OmegaConf.to_yaml(cfg)) + main(cfg) + else: + raise SystemExit("Base configuration file not specified! Exiting.") - parser = DMCMultiDomainDataModule.add_argparse_args(parser) - parser = MultiTaskCTLitModule.add_model_specific_args(parser) - - args = parser.parse_args() - - main(args) + main(cfg) diff --git a/src/envs/README.md b/src/envs/README.md new file mode 100644 index 0000000..c64b91b --- /dev/null +++ b/src/envs/README.md @@ -0,0 +1,3 @@ +## Credits: + +This folder contains a modified version of [dmc2gym](https://github.com/denisyarats/dmc2gym) by [Denis Yarats](https://github.com/denisyarats). 
diff --git a/src/models/__init__.py b/src/models/__init__.py index e69de29..9a04545 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. diff --git a/src/models/components/__init__.py b/src/models/components/__init__.py index e69de29..9a04545 100644 --- a/src/models/components/__init__.py +++ b/src/models/components/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. diff --git a/src/models/components/mingpt.py b/src/models/components/mingpt.py index 94b086d..cac9ff3 100644 --- a/src/models/components/mingpt.py +++ b/src/models/components/mingpt.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + """The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: @@ -78,7 +81,6 @@ def build_mlp( init_method=None, bias=True, ): - layers = [] in_size = input_size for _ in range(n_layers - 1): @@ -448,7 +450,7 @@ def forward( if pred_rand_inverse: # randomly mask past actions and predict them rand_mask_idx = np.random.choice(actions.shape[1], rand_mask_size, replace=False) - masked_token = token_embeddings.clone() + masked_token = token_embeddings.clone() for j in range(rand_mask_size): masked_token[:, 1 + 2 * rand_mask_idx[j], :] = -1 @@ -488,7 +490,6 @@ def forward( return logits, losses def get_embeddings(self, states, actions, timesteps): - if actions is not None and actions.shape[1] == 0: actions = None is_testing = (actions is None) or 
(actions.shape[1] != states.shape[1]) @@ -541,8 +542,8 @@ def configure_optimizers(self, hparams): """Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. - See examples here: - https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers + See examples here: https://pytorch- + lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers """ # separate out all parameters to those that will and won't experience regularizing weight decay decay = set() diff --git a/src/models/ct_module.py b/src/models/ct_module.py index a8d136d..c0c410b 100644 --- a/src/models/ct_module.py +++ b/src/models/ct_module.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + from collections import OrderedDict from typing import Any, List @@ -7,24 +10,42 @@ from torch.nn import functional as F from models.components.mingpt import GPT, GPTConfig -from models.utils import ( - get_exp_return_dmc, - get_min_action_dmc, - top_k_logits, -) +from models.utils import get_exp_return_dmc, get_min_action_dmc, top_k_logits class CTLitModule(LightningModule): - """LightningModule for Control Transformer. 
- """ + """LightningModule for Control Transformer.""" def __init__( self, - domain, - task, - betas=(0.9, 0.95), - weight_decay=0.1, - **kwargs, + agent_type: str, + model_type: str, + domain: str, + task: str, + n_embd: int, + lr: float, + unsupervise: bool, + forward: bool, + inverse: bool, + reward: bool, + rand_inverse: bool, + freeze_encoder: bool, + rand_attn_only: bool, + rand_mask_size: bool, + mask_obs_size: bool, + forward_weight: float, + n_layer: int, + n_head: int, + rtg_layers: int, + bc_layers: int, + pred_layers: int, + context_length: int, + epochs: int, + timestep: int, + weight_decay: float, + betas: List[float], + eval_epochs: int, + seed: int, ): super().__init__() @@ -55,41 +76,6 @@ def __init__( else: assert "agent type not supported" - @staticmethod - def add_model_specific_args(parent_parser): - parser = parent_parser.add_argument_group("CTModel") - - parser.add_argument("--agent_type", type=str, default="gpt") - parser.add_argument( - "--model_type", type=str, default="reward_conditioned", choices=["reward_conditioned", "naive"], - help="the reward_conditioned option learns the RTG-conditioned policy, while the naive option learns the behavior cloning policy" - ) - parser.add_argument("--n_embd", type=int, default=256) - parser.add_argument("--lr", type=float, default=6e-4) - - ## set false in downstream learning, unless using as auxiliary tasks - parser.add_argument("--unsupervise", default=False, action="store_true") - parser.add_argument("--forward", default=False, action="store_true") - parser.add_argument("--inverse", default=False, action="store_true") - parser.add_argument("--reward", default=False, action="store_true") - parser.add_argument("--rand_inverse", default=False, action="store_true") - parser.add_argument("--freeze_encoder", default=False, action="store_true") - parser.add_argument("--rand_attn_only", default=False, action="store_true") - parser.add_argument("--rand_mask_size", type=int, default=-1) - 
parser.add_argument("--mask_obs_size", type=int, default=-1) - - # weights - parser.add_argument("--forward_weight", type=float, default=1.0) - - # layers and network configs - parser.add_argument("--n_layer", type=int, default=8) - parser.add_argument("--n_head", type=int, default=8) - parser.add_argument("--rtg_layers", type=int, default=1) - parser.add_argument("--bc_layers", type=int, default=1) - parser.add_argument("--pred_layers", type=int, default=1) - - return parent_parser - def load_my_checkpoint(self, path, no_action=False, strict=True, no_action_head=False): m = torch.load(path)["state_dict"] model_dict = self.state_dict() @@ -299,7 +285,18 @@ def get_return_dmc(self, epochs): return eval_return, std_return @torch.no_grad() - def sample(self, x, steps, cont_action=True, temperature=1.0, sample=False, top_k=None, actions=None, rtgs=None, timesteps=None): + def sample( + self, + x, + steps, + cont_action=True, + temperature=1.0, + sample=False, + top_k=None, + actions=None, + rtgs=None, + timesteps=None, + ): """take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in the sequence, feeding the predictions back into the model each time. diff --git a/src/models/multitask_ct_module.py b/src/models/multitask_ct_module.py index 8cb69d9..aa991c0 100644 --- a/src/models/multitask_ct_module.py +++ b/src/models/multitask_ct_module.py @@ -1,30 +1,45 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ from collections import OrderedDict -from typing import Any, List +from typing import Any, Dict, List -import numpy as np import torch from pytorch_lightning import LightningModule -from torch.nn import functional as F from models.components.mingpt import GPT, GPTConfig -from models.utils import ( - get_exp_return, - get_exp_return_dmc, - get_min_action, - get_min_action_dmc, - top_k_logits, -) +from models.utils import get_min_action_dmc class MultiTaskCTLitModule(LightningModule): - """LightningModule for multi-task control transformer. - """ + """LightningModule for multi-task control transformer.""" def __init__( - self, - betas=(0.9, 0.95), - weight_decay=0.1, - **kwargs + self, + epochs: int, + agent_type: str, + model_type: str, + timestep: int, + n_embd: int, + lr: float, + forward: bool, + inverse: bool, + reward: bool, + rand_inverse: bool, + unsupervise: bool, + rand_mask_size: int, + freeze_encoder: bool, + context_length: int, + betas: List[float], + weight_decay: float, + n_layer: int, + n_head: int, + mask_obs_size: int, + pred_layers: int, + bc_layers: int, + rtg_layers: int, + forward_weight: float, + source_envs: Dict[str, List[str]], ): super().__init__() @@ -39,7 +54,7 @@ def __init__( for domain in self.domains: vocab_size += get_min_action_dmc(domain) else: - vocab_size = get_min_action_dmc(self.hparams.domain) + vocab_size = get_min_action_dmc(self.hparams.domain) if self.hparams.agent_type == "gpt": block_size = self.hparams.context_length * 2 @@ -63,55 +78,6 @@ def __init__( else: assert "agent type not supported" - @staticmethod - def add_model_specific_args(parent_parser): - parser = parent_parser.add_argument_group("CTModel") - - parser.add_argument("--agent_type", type=str, default="gpt") - parser.add_argument( - "--model_type", type=str, default="reward_conditioned", choices=["reward_conditioned", "naive"] - ) - parser.add_argument("--n_embd", type=int, default=256) - parser.add_argument("--lr", type=float, default=6e-4) - - ## 
whether to use supervision - parser.add_argument('--unsupervise', action='store_true') - parser.add_argument('--no-unsupervise', dest='unsupervise', action='store_false') - parser.set_defaults(unsupervise=True) - - ## whether to use whether to use forward prediction - parser.add_argument('--forward', action='store_true') - parser.add_argument('--no-forward', dest='forward', action='store_false') - parser.set_defaults(forward=True) - - ## whether to use whether to use inverse prediction - parser.add_argument('--inverse', action='store_true') - parser.add_argument('--no-inverse', dest='inverse', action='store_false') - parser.set_defaults(inverse=True) - - ## whether to use whether to use random mask hindsight control - parser.add_argument('--rand_inverse', action='store_true') - parser.add_argument('--no-rand_inverse', dest='rand_inverse', action='store_false') - parser.set_defaults(rand_inverse=True) - - parser.add_argument("--rand_mask_size", type=int, default=-1, help="mask size for action, -1 is to set masks by curriculum") - parser.add_argument("--mask_obs_size", type=int, default=-1, help="mask size for observations, -1 is to set masks by curriculum") - - # weights - parser.add_argument("--forward_weight", type=float, default=1.0) - - # layers and network configs - parser.add_argument("--n_layer", type=int, default=8) - parser.add_argument("--n_head", type=int, default=8) - parser.add_argument("--rtg_layers", type=int, default=1) - parser.add_argument("--bc_layers", type=int, default=1) - parser.add_argument("--pred_layers", type=int, default=1) - - ## additional options that are not used in the original method - parser.add_argument("--reward", default=False, action="store_true", help="whether to predict reward") - - return parent_parser - def training_step(self, batch: Any, batch_idx: int): obs, actions, rtg, ts, rewards, task_ids = batch targets = None if self.hparams.unsupervise else actions diff --git a/src/models/utils.py b/src/models/utils.py index 
a2157d7..7c3f821 100644 --- a/src/models/utils.py +++ b/src/models/utils.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + """The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: @@ -8,9 +11,7 @@ """ import random -from collections import deque -import gym import numpy as np import torch @@ -69,4 +70,3 @@ def get_exp_return_dmc(domain, task): return game_to_returns[domain + "_" + task] else: raise NotImplementedError() - diff --git a/src/utils/pl_utils.py b/src/utils/pl_utils.py new file mode 100644 index 0000000..0b46b82 --- /dev/null +++ b/src/utils/pl_utils.py @@ -0,0 +1,67 @@ +from importlib import import_module +from typing import Any, Dict, List + +import pytorch_lightning as pl +from pytorch_lightning import Callback, Trainer +from pytorch_lightning.loggers import LightningLoggerBase + + +def instantiate_class(init: Dict[str, Any]) -> Any: + """Instantiates a class with the given args and init. + + Args: + todo + Returns: + The instantiated class object. 
+ """ + kwargs = {k: init[k] for k in set(list(init.keys())) - {"_target_"}} + + class_module, class_name = init["_target_"].rsplit(".", 1) + module = import_module(class_module, package=class_name) + args_class = getattr(module, class_name) + return args_class(**kwargs) + + +def instantiate_callbacks(callbacks_cfg: dict) -> List[Callback]: + """Instantiates callbacks from config.""" + callbacks: List[Callback] = [] + + if not callbacks_cfg: + return callbacks + + if not isinstance(callbacks_cfg, dict): + raise TypeError("Callbacks config must be a DictConfig!") + + for _, cb_conf in callbacks_cfg.items(): + if isinstance(cb_conf, dict) and "_target_" in cb_conf: + callbacks.append(instantiate_class(cb_conf)) + + return callbacks + + +def instantiate_loggers(logger_cfg: dict) -> List[LightningLoggerBase]: + """Instantiates loggers from config.""" + logger: List[LightningLoggerBase] = [] + + if not logger_cfg: + return logger + + if not isinstance(logger_cfg, dict): + raise TypeError("Logger config must be a Dict!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, dict) and "_target_" in lg_conf: + logger.append(instantiate_class(lg_conf)) + + return logger + + +def instantiate_trainer(cfg: dict): + if cfg.get("seed", None): + pl.seed_everything(cfg["seed"], workers=True) + + callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks")) + logger: List[LightningLoggerBase] = instantiate_loggers(cfg.get("logger")) + trainer: Trainer = Trainer(**cfg["trainer"], callbacks=callbacks, logger=logger) + + return trainer