Commit c5de7f9

Basic code for swingup on the training_rl_paper branch

Federico-PizarroBejarano committed Jan 14, 2025
1 parent 868e425 commit c5de7f9

Showing 8 changed files with 47 additions and 114 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -6,6 +6,8 @@ examples/mpsc/unsafe_rl_temp_data/
 #
 examples/pid/*data/
 #
+examples/rl/test_model*/
+#
 experiments/mpsc/temp-data/
 experiments/mpsc/models/rl_models/
 experiments/mpsc/results*/
46 changes: 15 additions & 31 deletions examples/rl/config_overrides/cartpole/cartpole_stab.yaml
@@ -8,34 +8,34 @@ task_config:
 
   # state initialization
   init_state:
-    init_x: 0.1
-    init_x_dot: -1.5
-    init_theta: -0.155
-    init_theta_dot: 0.75
+    init_x: 0.0
+    init_x_dot: 0.0
+    init_theta: 3.14
+    init_theta_dot: 0
   randomized_init: True
   randomized_inertial_prop: False
 
   init_state_randomization_info:
     init_x:
       distrib: 'uniform'
-      low: -2
-      high: 2
+      low: -0.25
+      high: 0.25
     init_x_dot:
       distrib: 'uniform'
-      low: -2
-      high: 2
+      low: -0.25
+      high: 0.25
     init_theta:
       distrib: 'uniform'
-      low: -0.16
-      high: 0.16
+      low: 3.0
+      high: 3.3
     init_theta_dot:
       distrib: 'uniform'
-      low: -1
-      high: 1
+      low: 0
+      high: 0
 
   task: stabilization
   task_info:
-    stabilization_goal: [0.7, 0]
+    stabilization_goal: [0.0, 0]
     stabilization_goal_tolerance: 0.0
 
   inertial_prop:
@@ -48,25 +48,9 @@ task_config:
   obs_goal_horizon: 0
 
   # RL Reward
-  rew_state_weight: [1, 1, 1, 1]
-  rew_act_weight: 0.1
+  rew_state_weight: [0.1, 0.1, 1, 0.1]
+  rew_act_weight: 0.001
   rew_exponential: True
 
-  # constraints
-  constraints:
-    - constraint_form: default_constraint
-      constrained_variable: state
-      upper_bounds:
-      - 2
-      - 2
-      - 0.16
-      - 1
-      lower_bounds:
-      - -2
-      - -2
-      - -0.16
-      - -1
-    - constraint_form: default_constraint
-      constrained_variable: input
   done_on_out_of_bound: True
   done_on_violation: False
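For reference, the `rew_state_weight`, `rew_act_weight`, and `rew_exponential` fields above parameterize a weighted quadratic cost on the state error and action, which is turned into a bounded reward by exponentiation. A minimal sketch of that convention (illustrative only; the function name and exact formula are assumptions, not the source code):

```python
import numpy as np

def swingup_reward(state, action, goal,
                   state_weight=(0.1, 0.1, 1.0, 0.1),  # rew_state_weight
                   act_weight=(0.001,),                # rew_act_weight
                   exponential=True):                  # rew_exponential
    '''Sketch of a weighted quadratic cost with optional exponential shaping.'''
    state_error = np.asarray(state, dtype=float) - np.asarray(goal, dtype=float)
    cost = np.sum(np.asarray(state_weight) * state_error**2)
    cost += np.sum(np.asarray(act_weight) * np.asarray(action, dtype=float)**2)
    # exp(-cost) keeps the reward in (0, 1], peaking when the state reaches the goal.
    return np.exp(-cost) if exponential else -cost

# Example: cart at x=0 with the pole upright and zero force gives the maximum reward of 1.0.
print(swingup_reward(state=[0, 0, 0, 0], action=[0], goal=[0, 0, 0, 0]))
```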
2 changes: 1 addition & 1 deletion examples/rl/config_overrides/cartpole/ppo_cartpole.yaml
@@ -25,7 +25,7 @@ algo_config:
   max_grad_norm: 0.5
 
   # runner args
-  max_env_steps: 300000
+  max_env_steps: 1000000
   num_workers: 1
   rollout_batch_size: 4
   rollout_steps: 150
6 changes: 3 additions & 3 deletions examples/rl/rl_experiment.py
@@ -12,7 +12,7 @@
 from safe_control_gym.utils.registration import make
 
 
-def run(gui=False, plot=True, n_episodes=1, n_steps=None, curr_path='.'):
+def run(gui=True, plot=True, n_episodes=1, n_steps=None, curr_path='.'):
     '''Main function to run RL experiments.
 
     Args:
@@ -32,7 +32,7 @@ def run(gui=False, plot=True, n_episodes=1, n_steps=None, curr_path='.'):
     fac = ConfigFactory()
     config = fac.merge()
 
-    task = 'stab' if config.task_config.task == Task.STABILIZATION else 'track'
+    # task = 'stab' if config.task_config.task == Task.STABILIZATION else 'track'
     if config.task == Environment.QUADROTOR:
         system = f'quadrotor_{str(config.task_config.quad_type)}D'
     else:
@@ -50,7 +50,7 @@ def run(gui=False, plot=True, n_episodes=1, n_steps=None, curr_path='.'):
                 output_dir=curr_path + '/temp')
 
     # Load state_dict from trained.
-    ctrl.load(f'{curr_path}/models/{config.algo}/{config.algo}_model_{system}_{task}.pt')
+    ctrl.load(f'{curr_path}/test_model/model_best.pt')
 
     # Remove temporary files and directories
    shutil.rmtree(f'{curr_path}/temp', ignore_errors=True)
20 changes: 3 additions & 17 deletions examples/rl/rl_experiment.sh
@@ -1,25 +1,11 @@
 #!/bin/bash
 
-# SYS='cartpole'
-# SYS='quadrotor_2D'
-SYS='quadrotor_3D'
-
-# TASK='stab'
-TASK='track'
-
+SYS='cartpole'
+TASK='stab'
 ALGO='ppo'
-# ALGO='sac'
-# ALGO='safe_explorer_ppo'
-
-if [ "$SYS" == 'cartpole' ]; then
-    SYS_NAME=$SYS
-else
-    SYS_NAME='quadrotor'
-fi
-
 # RL Experiment
 python3 ./rl_experiment.py \
-    --task ${SYS_NAME} \
+    --task ${SYS} \
     --algo ${ALGO} \
     --overrides \
        ./config_overrides/${SYS}/${SYS}_${TASK}.yaml \
50 changes: 3 additions & 47 deletions examples/rl/train_rl_model.sh
@@ -1,61 +1,17 @@
 #!/bin/bash
 
 SYS='cartpole'
-# SYS='quadrotor_2D'
-# SYS='quadrotor_3D'
-
 TASK='stab'
-# TASK='track'
-
 ALGO='ppo'
-# ALGO='sac'
-# ALGO='safe_explorer_ppo'
-
-if [ "$SYS" == 'cartpole' ]; then
-    SYS_NAME=$SYS
-else
-    SYS_NAME='quadrotor'
-fi
-
-# Removed the temporary data used to train the new unsafe model.
-rm -r -f ./unsafe_rl_temp_data/
-
-if [ "$ALGO" == 'safe_explorer_ppo' ]; then
-    # Pretrain the unsafe controller/agent.
-    python3 ../../safe_control_gym/experiments/train_rl_controller.py \
-        --algo ${ALGO} \
-        --task ${SYS_NAME} \
-        --overrides \
-            ./config_overrides/${SYS}/${ALGO}_${SYS}_pretrain.yaml \
-            ./config_overrides/${SYS}/${SYS}_${TASK}.yaml \
-        --output_dir ./unsafe_rl_temp_data/ \
-        --seed 2 \
-        --kv_overrides \
-            task_config.init_state=None
-
-    # Move the newly trained unsafe model.
-    mv ./unsafe_rl_temp_data/model_latest.pt ./models/${ALGO}/${ALGO}_pretrain_${SYS}_${TASK}.pt
-
-    # Removed the temporary data used to train the new unsafe model.
-    rm -r -f ./unsafe_rl_temp_data/
-fi
 
 # Train the unsafe controller/agent.
 python3 ../../safe_control_gym/experiments/train_rl_controller.py \
     --algo ${ALGO} \
-    --task ${SYS_NAME} \
+    --task ${SYS} \
     --overrides \
         ./config_overrides/${SYS}/${ALGO}_${SYS}.yaml \
        ./config_overrides/${SYS}/${SYS}_${TASK}.yaml \
-    --output_dir ./unsafe_rl_temp_data/ \
+    --output_dir ./test_model/ \
     --seed 2 \
     --kv_overrides \
-        task_config.init_state=None \
-        task_config.randomized_init=True \
-        algo_config.pretrained=./models/${ALGO}/${ALGO}_pretrain_${SYS}_${TASK}.pt
-
-# Move the newly trained unsafe model.
-mv ./unsafe_rl_temp_data/model_best.pt ./models/${ALGO}/${ALGO}_model_${SYS}_${TASK}.pt
-
-# Removed the temporary data used to train the new unsafe model.
-rm -r -f ./unsafe_rl_temp_data/
+        task_config.randomized_init=True
23 changes: 12 additions & 11 deletions safe_control_gym/controllers/ppo/ppo.py
@@ -237,14 +237,15 @@ def run(self,
             action = self.select_action(obs=obs, info=info)
 
             # Adding safety filter
-            success = False
-            physical_action = env.denormalize_action(action)
-            unextended_obs = np.squeeze(true_obs)[:env.symbolic.nx]
-            certified_action, success = self.safety_filter.certify_action(unextended_obs, physical_action, info)
-            if success:
-                action = env.normalize_action(certified_action)
-            else:
-                self.safety_filter.ocp_solver.reset()
+            if self.safety_filter is not None:
+                success = False
+                physical_action = env.denormalize_action(action)
+                unextended_obs = np.squeeze(true_obs)[:env.symbolic.nx]
+                certified_action, success = self.safety_filter.certify_action(unextended_obs, physical_action, info)
+                if success:
+                    action = env.normalize_action(certified_action)
+                elif self.safety_filter.use_acados:
+                    self.safety_filter.ocp_solver.reset()
 
             action = np.atleast_2d(np.squeeze([action]))
             obs, rew, done, info = env.step(action)
@@ -301,10 +302,10 @@ def train_step(self):
                 certified_action, success = self.safety_filter.certify_action(unextended_obs, physical_action, info)
                 if success and self.filter_train_actions is True:
                     action = self.env.envs[0].normalize_action(certified_action)
-                else:
+                elif self.safety_filter.use_acados:
                     self.safety_filter.ocp_solver.reset()
 
-            action = np.atleast_2d(np.squeeze([action]))
+            action = np.atleast_2d(np.squeeze([action])).reshape((self.rollout_batch_size, -1))
             next_obs, rew, done, info = self.env.step(action)
             if done[0] and self.use_safe_reset:
                 next_obs, info = self.env_reset(self.env, self.use_safe_reset)
@@ -436,7 +437,7 @@ def env_reset(self, env, use_safe_reset):
             unextended_obs = np.squeeze(obs)[:self.env.envs[0].symbolic.nx]
             self.safety_filter.reset_before_run()
             _, success = self.safety_filter.certify_action(unextended_obs, action, info)
-            if not success:
+            if not success and self.safety_filter.use_acados:
                 self.safety_filter.ocp_solver.reset()
 
         return obs, info
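The three hunks above repeat one pattern: only run the filter when one is attached, certify the denormalized action, keep the certified action on success, and on failure reset the acados OCP solver and fall back to the unfiltered action. A condensed, standalone sketch of that logic, reusing the attribute names from the diff but otherwise illustrative:

```python
import numpy as np

def filtered_action(action, true_obs, info, env, safety_filter):
    '''Sketch of the certify-or-fallback step used in run()/train_step().'''
    if safety_filter is None:
        # No safety filter attached: use the raw policy action.
        return action

    physical_action = env.denormalize_action(action)
    # The filter certifies the unextended (non goal-augmented) state.
    unextended_obs = np.squeeze(true_obs)[:env.symbolic.nx]
    certified_action, success = safety_filter.certify_action(unextended_obs, physical_action, info)

    if success:
        return env.normalize_action(certified_action)
    if safety_filter.use_acados:
        # acados-based filters need their OCP solver reset after a failed solve.
        safety_filter.ocp_solver.reset()
    return action
```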
12 changes: 8 additions & 4 deletions safe_control_gym/envs/gym_control/cartpole.py
@@ -428,7 +428,7 @@ def _setup_symbolic(self, prior_prop={}, **kwargs):
 
     def _set_action_space(self):
         '''Sets the action space of the environment.'''
-        self.action_scale = 10
+        self.action_scale = 20
         self.physical_action_bounds = (-1 * np.atleast_1d(self.action_scale), np.atleast_1d(self.action_scale))
         self.action_threshold = 1 if self.NORMALIZED_RL_ACTION_SPACE else self.action_scale
         self.action_space = spaces.Box(low=-self.action_threshold, high=self.action_threshold, shape=(1,))
@@ -438,12 +438,16 @@ def _set_action_space(self):
 
     def _set_observation_space(self):
         '''Sets the observation space of the environment.'''
-        # Angle at which to fail the episode.
-        self.theta_threshold_radians = 90 * math.pi / 180
         # NOTE: different value in PyBullet gym (0.4) and OpenAI gym (2.4).
         self.x_threshold = 2.4
+        self.x_dot_threshold = 20
+        self.theta_threshold_radians = 2 * math.pi  # Angle at which to fail the episode.
+        self.theta_dot_threshold = 20
         # Limit set to 2x: i.e. a failing observation is still within bounds.
-        obs_bound = np.array([self.x_threshold * 2, np.finfo(np.float32).max, self.theta_threshold_radians * 2, np.finfo(np.float32).max])
+        obs_bound = np.array([self.x_threshold * 2,
+                              self.x_dot_threshold,
+                              self.theta_threshold_radians * 2,
+                              self.theta_dot_threshold])
         self.state_space = spaces.Box(low=-obs_bound, high=obs_bound, dtype=np.float32)
 
         # Concatenate goal info for RL
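With the new thresholds, every entry of the cartpole state box is finite (the old code left the velocity entries at np.finfo(np.float32).max). A quick sketch of the bounds this produces, assuming the same 2x margin on position and angle as in the code above:

```python
import math
import numpy as np
from gymnasium import spaces  # plain gym's spaces.Box behaves the same way here

x_threshold = 2.4
x_dot_threshold = 20
theta_threshold_radians = 2 * math.pi  # a full revolution is allowed for swingup
theta_dot_threshold = 20

# Position and angle keep the 2x margin; velocities use the new finite caps.
obs_bound = np.array([x_threshold * 2,
                      x_dot_threshold,
                      theta_threshold_radians * 2,
                      theta_dot_threshold])
state_space = spaces.Box(low=-obs_bound, high=obs_bound, dtype=np.float32)
print(state_space.high)  # approx. [4.8, 20., 12.57, 20.]
```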
