Skip to content

Commit

Permalink
Merging in latest changes from benchmark_mpsf which has nl_mpsc, changes to mpsc_acados, and some config changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Federico-PizarroBejarano committed Dec 6, 2024
1 parent ebf05ce commit 6e06675
Show file tree
Hide file tree
Showing 13 changed files with 1,188 additions and 106 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
algo: ppo
algo_config:
# model args
hidden_dim: 128
activation: tanh
hidden_dim: 64
activation: relu

# loss args
gamma: 0.98
Expand All @@ -19,7 +19,7 @@ algo_config:
critic_lr: 0.001

# runner args
max_env_steps: 2640000
max_env_steps: 660000
rollout_batch_size: 1
rollout_steps: 660
eval_batch_size: 10
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
task_config:
seed: 1337
info_in_reset: True
ctrl_freq: 60
pyb_freq: 60
Expand Down Expand Up @@ -36,13 +35,10 @@ task_config:
# RL Reward
rew_state_weight: [10, 0.1, 10, 0.1, 0.1, 0.001]
rew_act_weight: [0.1, 0.1]
info_mse_metric_state_weight: [1, 0, 1, 0, 0, 0]
rew_exponential: True

constraints:
- constraint_form: default_constraint
constrained_variable: state
upper_bounds: [ 0.9, 2, 1.45, 2, 0.75, 3]
lower_bounds: [-0.9, -2, 0.55, -2, -0.75, -3]
- constraint_form: default_constraint
constrained_variable: input

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ sf_config:
warmstart: True
integration_algo: rk4
use_terminal_set: False
max_w: 0.002

# Cost function
cost_function: one_step_cost
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
safety_filter: nl_mpsc
sf_config:
# LQR controller parameters
q_mpc: [18, 0.1, 18, 0.5, 0.5, 0.0001]
r_mpc: [3., 3.]

# MPC Parameters
use_acados: True
horizon: 25
warmstart: True
integration_algo: rk4
use_terminal_set: False

# Prior info
prior_info:
prior_prop: null
randomize_prior_prop: False
prior_prop_rand_info: null

# Learning disturbance bounds
n_samples: 6000

# Cost function
cost_function: one_step_cost
mpsc_cost_horizon: 5
decay_factor: 0.85

# Softening
soften_constraints: True
slack_cost: 250
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
algo: ppo
algo_config:
# model args
hidden_dim: 128
activation: tanh
hidden_dim: 64
activation: relu

# loss args
gamma: 0.98
Expand All @@ -19,7 +19,7 @@ algo_config:
critic_lr: 0.001

# runner args
max_env_steps: 2640000
max_env_steps: 660000
rollout_batch_size: 1
rollout_steps: 660
eval_batch_size: 10
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ task_config:
# RL Reward
rew_state_weight: [10, 0.1, 10, 0.1, 0.1, 0.001]
rew_act_weight: [0.1, 0.1]
info_mse_metric_state_weight: [1, 0, 1, 0, 0, 0]
rew_exponential: True

constraints:
Expand Down
Binary file not shown.
6 changes: 3 additions & 3 deletions experiments/mpsc/mpsc_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ def run_multiple_models(plot, all_models):
all_cert_results[key].append(cert_results[key][0])

met = MetricExtractor()
uncert_metrics = met.compute_metrics(data=all_uncert_results)
cert_metrics = met.compute_metrics(data=all_cert_results)
uncert_metrics = met.compute_metrics(data=all_uncert_results, max_steps=660)
cert_metrics = met.compute_metrics(data=all_cert_results, max_steps=66)

all_results = {'uncert_results': all_uncert_results,
'uncert_metrics': uncert_metrics,
Expand All @@ -170,4 +170,4 @@ def run_multiple_models(plot, all_models):

if __name__ == '__main__':
# run(plot=True, training=False, model='none')
run_multiple_models(plot=True, all_models=['mpsf'])
run_multiple_models(plot=True, all_models=['mpsf7'])
1 change: 1 addition & 0 deletions experiments/mpsc/mpsc_experiment.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ SYS='quadrotor_2D_attitude'
TASK='tracking'
ALGO='ppo'

# SAFETY_FILTER='nl_mpsc'
SAFETY_FILTER='mpsc_acados'
# MPSC_COST='one_step_cost'
MPSC_COST='precomputed_cost'
Expand Down
17 changes: 9 additions & 8 deletions safe_control_gym/controllers/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,15 @@ def run(self,
action = self.select_action(obs=obs, info=info)

# Adding safety filter
success = False
physical_action = env.denormalize_action(action)
unextended_obs = np.squeeze(true_obs)[:env.symbolic.nx]
certified_action, success = self.safety_filter.certify_action(unextended_obs, physical_action, info)
if success:
action = env.normalize_action(certified_action)
else:
self.safety_filter.ocp_solver.reset()
if self.safety_filter is not None:
success = False
physical_action = env.denormalize_action(action)
unextended_obs = np.squeeze(true_obs)[:env.symbolic.nx]
certified_action, success = self.safety_filter.certify_action(unextended_obs, physical_action, info)
if success:
action = env.normalize_action(certified_action)
else:
self.safety_filter.ocp_solver.reset()

action = np.atleast_2d(np.squeeze([action]))
obs, rew, done, info = env.step(action)
Expand Down
4 changes: 4 additions & 0 deletions safe_control_gym/safety_filters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
entry_point='safe_control_gym.safety_filters.mpsc.linear_mpsc:LINEAR_MPSC',
config_entry_point='safe_control_gym.safety_filters.mpsc:mpsc.yaml')

register(idx='nl_mpsc',
entry_point='safe_control_gym.safety_filters.mpsc.nl_mpsc:NL_MPSC',
config_entry_point='safe_control_gym.safety_filters.mpsc:mpsc.yaml')

register(idx='mpsc_acados',
entry_point='safe_control_gym.safety_filters.mpsc.mpsc_acados:MPSC_ACADOS',
config_entry_point='safe_control_gym.safety_filters.mpsc:mpsc.yaml')
Expand Down
Loading

0 comments on commit 6e06675

Please sign in to comment.