Skip to content

Commit

Permalink
initial commit for rlmpc
Browse files Browse the repository at this point in the history
  • Loading branch information
svsawant committed Nov 3, 2023
1 parent 83fae93 commit 8841562
Show file tree
Hide file tree
Showing 8 changed files with 621 additions and 2 deletions.
4 changes: 2 additions & 2 deletions examples/rl/rl_experiment.sh
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#!/bin/bash

# SYS='cartpole'
SYS='cartpole'
# SYS='quadrotor_2D'
SYS='quadrotor_3D'
# SYS='quadrotor_3D'

# TASK='stab'
TASK='track'
Expand Down
72 changes: 72 additions & 0 deletions examples/rlmpc/config_overrides/cartpole/cartpole_stab.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Task-configuration overrides for the cartpole stabilization example
# (the file handed to rlmpc_experiment.py through --overrides).
task_config:
seed: 42
info_in_reset: True
ctrl_freq: 15
# 750 / 15 = 50 -- presumably physics steps per control step; TODO confirm.
pyb_freq: 750
physics: pyb

# state initialization
# NOTE(review): init_x_dot (-1.5) and init_theta_dot (0.5) lie outside the
# +/-0.1 randomization ranges declared below; confirm which values are
# actually used when randomized_init is True.
init_state:
init_x: 0.1
init_x_dot: -1.5
init_theta: -0.175
init_theta_dot: 0.5
randomized_init: True
randomized_inertial_prop: False
normalized_rl_action_space: True

# Uniform sampling range for each initial-state entry.
init_state_randomization_info:
init_x:
distrib: "uniform"
low: -1
high: 1
init_x_dot:
distrib: "uniform"
low: -0.1
high: 0.1
init_theta:
distrib: "uniform"
low: -0.2
high: 0.2
init_theta_dot:
distrib: "uniform"
low: -0.1
high: 0.1

task: stabilization
task_info:
# Presumably [x, theta] of the target state -- TODO confirm the ordering.
stabilization_goal: [0.7, 0]
stabilization_goal_tolerance: 0.05

inertial_prop:
pole_length: 0.5
cart_mass: 1
pole_mass: 0.1

episode_len_sec: 10
cost: quadratic
obs_goal_horizon: 1

# RL Reward
# Four state weights, one per state entry; a single weight for the action.
rew_state_weight: [1, 1, 1, 1]
rew_act_weight: 0.1
rew_exponential: True

# constraints
# NOTE(review): the third bound (+/-0.18) is tighter than the init_theta
# randomization range (+/-0.2). If the third state entry is theta, some sampled
# initial states start outside the constraint -- confirm the state ordering.
constraints:
- constraint_form: default_constraint
constrained_variable: state
upper_bounds:
- 2
- 2
- 0.18
- 2
lower_bounds:
- -2
- -2
- -0.18
- -2
# No explicit bounds are given for the input constraint; presumably the
# environment's defaults apply -- TODO confirm.
- constraint_form: default_constraint
constrained_variable: input
done_on_out_of_bound: True
done_on_violation: False
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

52 changes: 52 additions & 0 deletions examples/rlmpc/rlmpc_experiment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
'''This script tests the RL-MPC implementation.'''

import os
import shutil
from functools import partial

import matplotlib.pyplot as plt
import numpy as np

from safe_control_gym.envs.benchmark_env import Cost, Environment, Task
from safe_control_gym.experiments.base_experiment import BaseExperiment
from safe_control_gym.utils.configuration import ConfigFactory
from safe_control_gym.utils.registration import make


def run(plot=True, training=False, n_episodes=1, n_steps=None, curr_path='.'):
    '''Main function to run RL-MPC experiments.

    Builds the environment and the controller from the merged configuration,
    runs an evaluation through `BaseExperiment`, and always releases the
    controller and the environment afterwards, even if the evaluation fails.

    Args:
        plot (bool): Whether to plot the results. Currently unused.
        training (bool): Whether to train the controller or load pre-trained values. Currently unused.
        n_episodes (int): The number of episodes to execute.
        n_steps (int): How many steps to run the experiment.
        curr_path (str): The current relative path to the experiment folder.
    '''

    # Create the configuration dictionary.
    fac = ConfigFactory()
    config = fac.merge()

    # Create an environment. The factory is kept separate because the
    # controller receives it and builds its own environment from it.
    env_func = partial(make,
                       config.task,
                       **config.task_config)
    env = env_func()

    try:
        # Setup controller.
        ctrl = make(config.algo,
                    env_func,
                    **config.algo_config,
                    output_dir=os.path.join(curr_path, 'temp'))
        try:
            # Evaluate the controller on the environment.
            experiment = BaseExperiment(env, ctrl)
            experiment.run_evaluation(n_episodes=n_episodes, n_steps=n_steps)
        finally:
            # Release controller resources whether or not the evaluation succeeded.
            ctrl.close()
    finally:
        # Release the evaluation environment (e.g. its simulator connection).
        env.close()



# Script entry point: run the experiment with its default arguments.
if __name__ == '__main__':
    run()
22 changes: 22 additions & 0 deletions examples/rlmpc/rlmpc_experiment.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#!/bin/bash

# System to run (uncomment exactly one).
SYS='cartpole'
# SYS='quadrotor_2D'

# Task to run (uncomment exactly one).
TASK='stab'
# TASK='track'

ALGO='qlearning_mpc'

# Any system other than 'cartpole' is passed to --task as 'quadrotor';
# the full SYS value is still used to pick the override file below.
if [ "$SYS" == 'cartpole' ]; then
    SYS_NAME=$SYS
else
    SYS_NAME='quadrotor'
fi

# Run the RL-MPC experiment with the selected controller and task overrides.
# Expansions are quoted so values containing spaces or glob characters are
# passed as single arguments.
python3 ./rlmpc_experiment.py \
    --task "${SYS_NAME}" \
    --algo "${ALGO}" \
    --overrides \
        "./config_overrides/${SYS}/${SYS}_${TASK}.yaml"
4 changes: 4 additions & 0 deletions safe_control_gym/controllers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,7 @@
# Register the controller available under the id 'rap'.
# entry_point looks like 'module.path:ClassName' and config_entry_point like
# 'package:config_file.yaml' -- presumably resolved lazily; TODO confirm.
register(idx='rap',
entry_point='safe_control_gym.controllers.rarl.rap:RAP',
config_entry_point='safe_control_gym.controllers.rarl:rap.yaml')

# Register the Q-learning MPC controller under the id 'qlearning_mpc'
# (the value callers pass as the algorithm name, e.g. via --algo).
register(idx='qlearning_mpc',
entry_point='safe_control_gym.controllers.mpc.qlearning_mpc:Qlearning_MPC',
config_entry_point='safe_control_gym.controllers.mpc:qlearning_mpc.yaml')
Loading

0 comments on commit 8841562

Please sign in to comment.