This repository is an extended and modified version of the original Hybrid Inverse Reinforcement Learning implementation.
- Extended to support a broader range of RL environments, including Gym, MuJoCo, and Atari.
- Added KL divergence and reward-based evaluation to improve benchmarking across different settings.
- Refactored and optimized the codebase to enhance generalization and ease of use.
This repository is structured as follows:
├── garage                      # Package folder, where we park all our hybrids
│   ├── algorithms              # Model-free and model-based inverse RL implementations
│   │   ├── model_based_irl.py  # To run HyPER
│   │   └── model_free_irl.py   # To run HyPE, FILTER, MM, BC-Reg
│   ├── models                  # Learner and discriminator model architectures
│   │   ├── discriminator.py    # Single and ensemble implementation
│   │   ├── sac.py              # Used in Mujoco locomotion experiments
│   │   └── td3_bc.py           # Used in D4RL antmaze experiments
│   ├── mbrl                    # Fork of mbrl-lib used in model-based algorithms
│   └── utils                   # Buffers, wrappers, optimizers, and logging
│
├── config                      # Hydra config yamls
│   ├── algorithm               # Algorithm-specific configs
│   └── overrides               # Environment-specific configs
│
├── experts                     # Training and collecting expert demonstrations
│
├── figures                     # Comparison plots
│
└── tests                       # Unit tests and utility scripts
    └── utils                           # Utilities for testing data processing
        ├── convert_demos.py            # Converts expert demos into the required format
        └── test_fetch_transitions.py   # Unit tests for fetching transitions from datasets
Please run
conda create -n hyirl python=3.8.18 -y
conda activate hyirl
pip install -r requirements.txt
Then, assuming you are at the top level of this repo, run
export PYTHONPATH=${PYTHONPATH}:${PWD}
export D4RL_SUPPRESS_IMPORT_ERROR=1
or, on Windows (PowerShell), run
$env:PYTHONPATH = $env:PYTHONPATH + ";" + (Get-Location)
$env:D4RL_SUPPRESS_IMPORT_ERROR = "1"
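To quickly check that the repository is on your path, you can try importing the top-level package (this is only a sanity check and assumes the package imports cleanly; it may pull in heavy dependencies such as mujoco-py):
python -c "import garage; print('garage package found')"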
For all experiments, please activate the conda environment created in Installation.
conda activate hyirl
Choose one of the three options below to obtain demonstration data.
We provide the original datasets used for the mujoco locomotion environments in the paper. These can be acquired by running
python experts/download_datasets.py
which will download the corresponding datasets for all of Ant-v3, Hopper-v3, Humanoid-v3, and Walker2d-v3.
Since antmaze demonstrations are downloaded directly from D4RL, there is no need to train an expert beforehand. Please directly run collect_demos.py, which will download the dataset, run some additional processing, and save all relevant keys to the .npz file.
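For example, assuming the antmaze short names listed under the supported environments below:
python experts/collect_demos.py --env maze-diverse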
To train your own experts, please run the following script
python experts/train.py --env <env_name>
An expert will be trained for the desired environment, and a checkpoint will be saved in experts/<env_name>. Then, to use this checkpoint to collect new trajectories, please run the following
python experts/collect_demos.py --env <env_name>
Demonstrations will be saved as an .npz file containing the following entries: observations, actions, next_observations, rewards, terminals, timeouts, seed, qpos, and qvel.
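As a quick sanity check, you can inspect a saved demonstration file with NumPy. The snippet below is a minimal sketch; the path is only an example, so point it at wherever collect_demos.py wrote your .npz file:

import numpy as np

# Path is illustrative; substitute the actual file written by collect_demos.py.
demos = np.load("experts/Walker2d-v3/demos.npz", allow_pickle=True)

# List every stored key with the shape of its array.
for key in demos.files:
    print(key, demos[key].shape)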
To extend to more experiments, simply add the new environment to the list of arguments allowed in experts/train.py (a rough sketch is given below), then run the two scripts above. Policy optimizers for the expert can also be switched out easily, provided the same is done in experts/collect_demos.py when loading the checkpoint.
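As a rough illustration only (the actual structure of experts/train.py may differ), extending to a new environment typically amounts to adding its name to the allowed --env choices and mapping it to a Gym ID:

import argparse

# Hypothetical mapping from short names to Gym environment IDs;
# follow whatever convention experts/train.py already uses.
ENV_IDS = {
    "walker": "Walker2d-v3",
    "hopper": "Hopper-v3",
    "halfcheetah": "HalfCheetah-v3",  # newly added environment
}

parser = argparse.ArgumentParser()
parser.add_argument("--env", choices=sorted(ENV_IDS), required=True)
args = parser.parse_args()
env_id = ENV_IDS[args.env]  # pass this to the expert training loop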
If you already have your own demonstration dataset (e.g., collected from a different framework or custom environment), you can convert it into the required .npz format using our conversion script:
python tests/utils/convert_demos.py --input <your_demo_file> --output <converted_demo_file>
The input file should contain your demonstration data in a structured format (e.g., pickle, JSON, or numpy). The conversion script will process the data and output an .npz file compatible with the training and evaluation pipelines in this repository.
The resulting .npz file will include the following keys: observations, actions, next_observations, rewards, terminals, timeouts (optional), seed (optional), qpos (optional), and qvel (optional).
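If your data is already in memory as NumPy arrays, you can also skip the conversion script and write the .npz file directly. The sketch below uses placeholder shapes and zero-filled arrays; substitute your own trajectories, concatenated into flat arrays of matching length:

import numpy as np

num_steps, obs_dim, act_dim = 1000, 17, 6  # placeholder sizes

# Replace the zero arrays below with your actual demonstration data.
data = {
    "observations": np.zeros((num_steps, obs_dim), dtype=np.float32),
    "actions": np.zeros((num_steps, act_dim), dtype=np.float32),
    "next_observations": np.zeros((num_steps, obs_dim), dtype=np.float32),
    "rewards": np.zeros(num_steps, dtype=np.float32),
    "terminals": np.zeros(num_steps, dtype=bool),
    "timeouts": np.zeros(num_steps, dtype=bool),  # optional
}

# Save with the same key layout the pipelines expect.
np.savez("converted_demos.npz", **data)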
Tip
Make sure your input data includes at least observations, actions, and rewards. Additional information such as next_observations, terminals, and timeouts is recommended for compatibility with model-based algorithms.
Note
This repository currently does not support gymnasium versions of environments. We are working on updating our files to support newer versions of gym and gymnasium.
We found that pretraining the model for the model-based antmaze experiments and decreasing the model update frequency helped improve stability. We therefore also provide a script to pretrain the antmaze models, which can be adapted to other environments as well. The pretrained model checkpoints used in the paper can be found under garage/pretrained_models/<env_name>. To create your own pretrained model, run
python main.py algorithm=pretrain_antmaze overrides=model_based_antmaze_diverse
or
python main.py algorithm=pretrain_antmaze overrides=model_based_antmaze_play
For more details on the specific pretraining process, please see here.
To recreate the main plots seen in the paper, first run
cd garage
This repository organizes the configuration files for the various experiments using Hydra; they can be found in garage/config, and you only need to specify the algorithm and the environment to run it on. For example, to run HyPE on Walker2d-v3:
If MuJoCo cannot be located (e.g., mujoco-py raises import errors), you may also need to set the following environment variables first (shown here for a Unix shell):
export MUJOCO_PY_MUJOCO_PATH=~/.mujoco/mujoco210
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/.mujoco/mujoco210/bin
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/.mujoco/mujoco210/lib
python main.py algorithm=hype overrides=model_free_walker
or to run HyPER on Ant-v3:
python main.py algorithm=hyper overrides=model_based_ant
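Since the experiments are configured through Hydra, individual fields can also be overridden directly on the command line. The key name below (seed) is only an example and should be checked against the yaml files in garage/config before use:
python main.py algorithm=hyper overrides=model_based_ant seed=1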
This package supports training of the following algorithms:
- MM: A baseline moment-matching algorithm that uses an integral probability metric instead of the Jensen-Shannon divergence, as implemented in FastIRL.
- BC-Reg: MM with an added mean-squared error loss on the actor update.
- FILTER: IRL with resets to expert states.
- HyPE: Model-free IRL with policy updates on both learner and expert state-action pairs.
- HyPE+FILTER: HyPE with resets to expert states.
- HyPER: Model-based IRL variant of HyPE, building off LAMPS-MBRL.
on the following environments:
- Ant-v3 (ant)
- Hopper-v3 (hopper)
- Humanoid-v3 (humanoid)
- Walker2d-v3 (walker)
- antmaze-large-diverse-v2 (maze-diverse)
- antmaze-large-play-v2 (maze-play)
Tip
For a more detailed breakdown of the garage repository, please see here. For a specific breakdown of our implementations of model-free and model-based inverse reinforcement learning, please see here.
Tip
All configs are filled with the exact hyperparameters used in the paper. If you wish to adapt these algorithms to different environments or datasets, we provide a detailed list of the parameters we recommend tuning first here.
Results are saved in two locations. All config files, model checkpoints, and other detailed logs of a run are saved under garage/experiment_logs/<algorithm>/, and a copy of the final evaluation results is saved under garage/experiment_results/. To generate graphs for one environment, run
python plot.py --env <env_name>
You can also generate graphs for all environments that have saved results by running
python plot.py --all
This project is based on the original Hybrid Inverse Reinforcement Learning (HYPE) implementation. We have significantly modified and extended the codebase to support a broader range of RL environments (Gym, MuJoCo, Atari) and introduced KL divergence & reward-based evaluation as additional benchmarking metrics.
The original HyPER algorithm was built on LAMPS-MBRL, and HyPE was built on FastIRL. Our modifications focus on extending their applicability and improving performance analysis.
Additional components were borrowed from:
If you found this repository useful in your research, please consider citing the paper.
@misc{ren2024hybrid,
title={Hybrid Inverse Reinforcement Learning},
author={Juntao Ren and Gokul Swamy and Zhiwei Steven Wu and J. Andrew Bagnell and Sanjiban Choudhury},
year={2024},
eprint={2402.08848},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
For any questions regarding this repository or paper, please feel free to contact jlr429 [at] cornell [dot] edu or gswamy [at] cmu [dot] edu.
Below is a brief summary of the extra steps we took to get this project running on a server without sudo privileges. These steps can be appended to the README:
- Installed required libraries via Conda
  We used conda-forge to install packages normally installed via system-level commands. For instance:
  conda install -c conda-forge glew mesa xorg-libx11 xorg-libxi xorg-libxext patchelf
  This resolves errors like a missing <GL/glew.h> and makes the patchelf tool available (needed by mujoco-py).
- Set compiler/linker environment variables
  Because the compiler may not look into Conda's include/lib paths by default, we added them:
  export CPATH="$CONDA_PREFIX/include:$CPATH"
  export LIBRARY_PATH="$CONDA_PREFIX/lib:$LIBRARY_PATH"
  export LD_LIBRARY_PATH="$CONDA_PREFIX/lib:$LD_LIBRARY_PATH"
- Installed MuJoCo 2.1.0 manually
  Since mujoco-py only provides Python bindings, we downloaded mujoco210-linux-x86_64.tar.gz from MuJoCo.org, extracted it into ~/.mujoco/, and named the directory mujoco210. This is where mujoco-py expects to find the native libraries: ~/.mujoco/mujoco210/bin/libmujoco210.so
- Installed mujoco-py
  With all dependencies in place, we simply ran:
  pip install mujoco-py
  It compiled successfully within our Conda environment (a quick verification check is shown after this list).
- Ran the script headlessly
  - We occasionally saw a GLFWError: X11: The DISPLAY environment variable is missing. This can be ignored when running headless, or you can set export MUJOCO_GL=egl to switch to a headless EGL backend.
  - We also got a NumPy C-API warning (module compiled against API version 0x10 but this version is 0xe). It did not break anything, so we ignored it. If it causes issues, you can reinstall NumPy and rebuild related packages to match the correct C-API version.
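A quick way to confirm the mujoco-py build works (assuming gym's MuJoCo environments are installed in the same conda environment) is:
python -c "import mujoco_py, gym; gym.make('Walker2d-v3').reset(); print('mujoco-py OK')"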
After these adjustments, the program started successfully with:
python main.py algorithm=hype overrides=model_free_walker
and began training without further errors (see https://github.com/openai/mujoco-py/issues/627).