Getting NL_MPSF to work
Federico-PizarroBejarano committed Nov 4, 2024
1 parent 36afdbb commit 6294b73
Showing 13 changed files with 165 additions and 218 deletions.
@@ -9,7 +9,7 @@ algo_config:
gae_lambda: 0.92
clip_param: 0.2
target_kl: 1.0e-2
- entropy_coef: 0.01
+ entropy_coef: 0.005

# optim args
opt_epochs: 20
Expand All @@ -24,43 +24,9 @@ algo_config:
eval_batch_size: 10

# misc
- log_interval: 6600
- save_interval: 660000
+ log_interval: 39600
+ save_interval: 396000
num_checkpoints: 0
- eval_interval: 6600
+ eval_interval: 39600
eval_save_best: True
tensorboard: False

- #algo_config:
- #  # model args
- #  hidden_dim: 128
- #  activation: "relu"
- #
- #  # loss args
- #  gamma: 0.98
- #  use_gae: True # or False
- #  gae_lambda: 0.92
- #  use_clipped_value: False # or True
- #  clip_param: 0.1
- #  target_kl: 1.0e-5
- #  entropy_coef: 0.003
- #
- #  # optim args
- #  opt_epochs: 25
- #  mini_batch_size: 256
- #  actor_lr: 7.2e-5
- #  critic_lr: 0.0266
- #
- #  # runner args
- #  max_env_steps: 216000
- #  rollout_batch_size: 5
- #  rollout_steps: 660
- #  eval_batch_size: 10
- #
- #  # misc
- #  log_interval: 6600
- #  save_interval: 660000
- #  num_checkpoints: 0
- #  eval_interval: 6600
- #  eval_save_best: True
- #  tensorboard: False
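
The entropy_coef change above (0.01 to 0.005) halves the exploration bonus in the PPO objective. As a generic illustration of where that coefficient enters (standard clipped-surrogate PPO, not this repository's exact loss code):

    import numpy as np

    ratio = np.array([1.1, 0.9])        # new/old action probability ratios (dummy values)
    advantage = np.array([0.5, -0.2])   # advantage estimates (dummy values)
    clip_param, entropy_coef = 0.2, 0.005
    policy_entropy = 1.4                # mean policy entropy (dummy value)

    surrogate = np.minimum(ratio * advantage,
                           np.clip(ratio, 1 - clip_param, 1 + clip_param) * advantage)
    policy_loss = -np.mean(surrogate) - entropy_coef * policy_entropy

A smaller coefficient reduces how strongly the policy is pushed toward exploration.
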
@@ -61,8 +61,8 @@ task_config:
obs_goal_horizon: 1

# RL Reward
- rew_state_weight: [10., .0, 10., .0, .0, 0.0]
- rew_act_weight: [.0, .0]
+ rew_state_weight: [10, 0.1, 10, 0.1, 0.1, 0.001]
+ rew_act_weight: [0.1, 0.1]
rew_exponential: True

disturbances:
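
The reward weights above switch from ignoring velocities and inputs to lightly penalizing them. For context, a minimal sketch of an exponential tracking reward built from these weights (assumed form; the exact rl_reward cost is defined in safe-control-gym's environment code):

    import numpy as np

    rew_state_weight = np.array([10, 0.1, 10, 0.1, 0.1, 0.001])
    rew_act_weight = np.array([0.1, 0.1])

    def rl_reward(state, goal, action, exponential=True):
        # weighted quadratic tracking cost, optionally squashed through exp(-cost)
        cost = np.sum(rew_state_weight * (state - goal) ** 2) + np.sum(rew_act_weight * action ** 2)
        return np.exp(-cost) if exponential else -cost

    print(rl_reward(np.zeros(6), np.full(6, 0.1), np.array([0.05, 0.0])))

With rew_exponential: True the reward stays in (0, 1], so the small velocity and action weights shape the optimum without dominating the position terms.
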

experiments/mpsc/config_overrides/ppo_quadrotor_2D_attitude.yaml (16 changes: 8 additions & 8 deletions)
@@ -10,7 +10,7 @@ algo_config:
gae_lambda: 0.92
clip_param: 0.2
target_kl: 1.0e-2
- entropy_coef: 0.01
+ entropy_coef: 0.005

# optim args
opt_epochs: 20
@@ -25,15 +25,15 @@ algo_config:
eval_batch_size: 10

# misc
- log_interval: 6600
- save_interval: 0
+ log_interval: 39600
+ save_interval: 396000
num_checkpoints: 0
- eval_interval: 6600
+ eval_interval: 39600
eval_save_best: True
tensorboard: False

# safety filter
- filter_train_actions: True
- penalize_sf_diff: True
- sf_penalty: 75
- use_safe_reset: True
+ filter_train_actions: False
+ penalize_sf_diff: False
+ sf_penalty: 1
+ use_safe_reset: False
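
These four keys switch the PPO training run from filter-in-the-loop training back to plain training. A rough sketch of how penalize_sf_diff and sf_penalty are presumably used when the filter is active (illustrative logic only; the shaping in this repository may differ):

    import numpy as np

    def shape_reward(reward, proposed_action, certified_action, penalize_sf_diff=True, sf_penalty=1.0):
        # subtract a penalty proportional to how far the safety filter moved the action
        if penalize_sf_diff:
            reward -= sf_penalty * float(np.linalg.norm(certified_action - proposed_action))
        return reward

    print(shape_reward(1.0, np.array([0.30, 0.00]), np.array([0.25, 0.02])))

filter_train_actions presumably decides whether the certified action is the one applied during training rollouts, and use_safe_reset whether episodes restart from states the filter can certify.
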
@@ -41,7 +41,6 @@ task_config:
low: -0.05
high: 0.05


task: traj_tracking
task_info:
trajectory_type: figure8
@@ -53,21 +52,14 @@ task_config:
inertial_prop:
M: 0.033
Iyy: 1.4e-05
- beta_1: 18.11
- beta_2: 3.68
- beta_3: 0.0
- alpha_1: -140.8
- alpha_2: -13.4
- alpha_3: 124.8
- pitch_bias: 0.0 # in radian

episode_len_sec: 11
cost: rl_reward
obs_goal_horizon: 1

# RL Reward
- rew_state_weight: [10., .0, 10., .0, .0, 0.0]
- rew_act_weight: [.0, .0]
+ rew_state_weight: [10, 0.1, 10, 0.1, 0.1, 0.001]
+ rew_act_weight: [0.1, 0.1]
rew_exponential: True

# disturbances:
Binary file not shown.

experiments/mpsc/mpsc_experiment.py (11 changes: 2 additions & 9 deletions)
@@ -7,7 +7,6 @@
import matplotlib.pyplot as plt
import numpy as np

- from safe_control_gym.envs.benchmark_env import Cost
from safe_control_gym.experiments.base_experiment import BaseExperiment, MetricExtractor
from safe_control_gym.safety_filters.mpsc.mpsc_utils import Cost_Function
from safe_control_gym.utils.configuration import ConfigFactory
@@ -31,12 +30,6 @@ def run(plot=False, training=False, model='ppo'):
config.algo_config['training'] = False
config.task_config['done_on_violation'] = False
config.task_config['randomized_init'] = False
- if config.algo in ['ppo', 'sac']:
-     config.task_config['cost'] = Cost.RL_REWARD
-     config.task_config['normalized_rl_action_space'] = True
- else:
-     config.task_config['cost'] = Cost.QUADRATIC
-     config.task_config['normalized_rl_action_space'] = False

system = 'quadrotor_2D_attitude'

@@ -168,5 +161,5 @@ def run_multiple_models(plot, all_models):


if __name__ == '__main__':
- # run(training=False)
- run_multiple_models(plot=False, all_models=['True', 'False'])
+ # run(plot=True, training=False, model='ppo')
+ run_multiple_models(plot=False, all_models=['none', 'mpsf'])
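
The label strings passed here are the same keys plotting_results.py uses to index its results (presumably 'none' for a policy trained without the filter and 'mpsf' for one trained with it). A self-contained stub of the presumed flow (only run()'s signature is taken from the file above; the return values and storage format are assumptions):

    def run(plot=False, training=False, model='ppo'):
        # stub standing in for the real experiment entry point
        return {'obs': [], 'action': []}, {'rmse': 0.0}

    def run_multiple_models(plot, all_models):
        all_results = {}
        for label in all_models:
            trajs_data, metrics = run(plot=plot, training=False, model=label)
            all_results[label] = {'cert_results': trajs_data, 'metrics': metrics}
        return all_results

    print(run_multiple_models(plot=False, all_models=['none', 'mpsf']).keys())
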

experiments/mpsc/mpsc_experiment.sh (4 changes: 3 additions & 1 deletion)
@@ -5,7 +5,9 @@ TASK='tracking'
ALGO='ppo'

SAFETY_FILTER='nl_mpsc'
- MPSC_COST='precomputed_cost'
+ # SAFETY_FILTER='mpsc_acados'
+ # MPSC_COST='precomputed_cost'
+ MPSC_COST='one_step_cost'
MPSC_COST_HORIZON=25
DECAY_FACTOR=0.9

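
Switching MPSC_COST from 'precomputed_cost' to 'one_step_cost' changes what the safety filter minimizes when it corrects an action. A rough numpy sketch of the two cost shapes, inferred from the names and the MPSC_COST_HORIZON/DECAY_FACTOR settings (not the library's exact objectives):

    import numpy as np

    def one_step_cost(proposed_action, safe_input_plan):
        # penalize only the immediate deviation from the learning-based action
        return float(np.sum((safe_input_plan[0] - proposed_action) ** 2))

    def precomputed_cost(reference_inputs, safe_input_plan, decay_factor=0.9):
        # penalize deviation from a precomputed input sequence over the horizon,
        # with geometrically decaying weights
        weights = decay_factor ** np.arange(len(reference_inputs))
        return float(np.sum(weights * np.sum((safe_input_plan - reference_inputs) ** 2, axis=1)))

    plan = np.full((25, 2), 0.01)   # horizon of 25 steps, 2 inputs per step (dummy data)
    print(one_step_cost(np.array([0.05, 0.0]), plan), precomputed_cost(np.zeros((25, 2)), plan))
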

experiments/mpsc/plotting_results.py (19 changes: 8 additions & 11 deletions)
@@ -376,8 +376,8 @@ def benchmark_plot(system, task, algo):
X_GOAL = all_results['none']['X_GOAL']

uncert = all_results['none']['uncert_results']
- mpsf = all_results['True']['cert_results']
- none = all_results['False']['cert_results']
+ mpsf = all_results['mpsf']['cert_results']
+ none = all_results['none']['cert_results']
for i in [1, 5]:
print('Uncert')
met.data = uncert
@@ -424,14 +424,11 @@ def calculate_state_violations(data, i):


if __name__ == '__main__':
- ordered_models = ['mpsf', 'True', 'ppo', 'none', 'False']
+ ordered_models = ['mpsf', 'none']

colors = {
'mpsf': 'royalblue',
- 'True': 'cornflowerblue',
- 'ppo': 'forestgreen',
'none': 'plum',
- 'False': 'violet',
}

def extract_rate_of_change_of_inputs(results_data, certified=True):
@@ -481,20 +478,20 @@ def extract_length_uncert(results_data, certified=False):
task_name = sys.argv[2]
algo_name = sys.argv[3]

- benchmark_plot(system_name, task_name, algo_name)
- # plot_all_logs(system_name, task_name, algo_name)
+ # benchmark_plot(system_name, task_name, algo_name)
+ plot_all_logs(system_name, task_name, algo_name)
# plot_model_comparisons(system_name, task_name, algo_name, extract_magnitude_of_corrections)
# plot_model_comparisons(system_name, task_name, algo_name, extract_max_correction)
- # plot_model_comparisons(system_name, task_name, algo_name, extract_roc_cert)
+ plot_model_comparisons(system_name, task_name, algo_name, extract_roc_cert)
# plot_model_comparisons(system_name, task_name, algo_name, extract_roc_uncert)
# plot_model_comparisons(system_name, task_name, algo_name, extract_rmse_cert)
# plot_model_comparisons(system_name, task_name, algo_name, extract_rmse_uncert)
- # plot_model_comparisons(system_name, task_name, algo_name, extract_constraint_violations_cert)
+ plot_model_comparisons(system_name, task_name, algo_name, extract_constraint_violations_cert)
# plot_model_comparisons(system_name, task_name, algo_name, extract_constraint_violations_uncert)
# plot_model_comparisons(system_name, task_name, algo_name, extract_number_of_corrections)
# plot_model_comparisons(system_name, task_name, algo_name, extract_length_cert)
# plot_model_comparisons(system_name, task_name, algo_name, extract_length_uncert)
- # plot_model_comparisons(system_name, task_name, algo_name, extract_reward_cert)
+ plot_model_comparisons(system_name, task_name, algo_name, extract_reward_cert)
# plot_model_comparisons(system_name, task_name, algo_name, extract_reward_uncert)
# plot_model_comparisons(system_name, task_name, algo_name, extract_failed_cert)
# plot_model_comparisons(system_name, task_name, algo_name, extract_failed_uncert)
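
The newly enabled comparisons (extract_roc_cert, extract_constraint_violations_cert, extract_reward_cert) plot certified-run metrics across the two models. As an example of the kind of quantity extract_roc_cert reports, here is a self-contained rate-of-change-of-inputs metric (assumed definition; the repository's extractor may scale or aggregate differently, and ctrl_freq is a placeholder):

    import numpy as np

    def rate_of_change(actions, ctrl_freq=60.0):
        # mean magnitude of the change in the applied input between consecutive steps
        actions = np.asarray(actions)
        return float(np.mean(np.linalg.norm(np.diff(actions, axis=0), axis=1)) * ctrl_freq)

    print(rate_of_change([[0.00, 0.00], [0.10, -0.05], [0.15, -0.02]]))
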