From 56ed7f421cc3349185cec83f6ec0f0cc09e907d5 Mon Sep 17 00:00:00 2001 From: Federico-PizarroBejarano Date: Wed, 27 Nov 2024 15:04:44 -0500 Subject: [PATCH] Laying the groundwork for eventual jacobian calculations and improving saving the random env state variable --- experiments/mpsc/plotting_results.py | 44 +++++++++---------- experiments/mpsc/train_all_models.sh | 2 +- safe_control_gym/controllers/ppo/ppo.py | 18 ++++---- .../experiments/base_experiment.py | 2 +- safe_control_gym/safety_filters/mpsc/mpsc.py | 11 ++--- .../mpsc_cost_function/precomputed_cost.py | 8 ++-- 6 files changed, 41 insertions(+), 44 deletions(-) diff --git a/experiments/mpsc/plotting_results.py b/experiments/mpsc/plotting_results.py index f5f1f6d44..d4be3c3fb 100644 --- a/experiments/mpsc/plotting_results.py +++ b/experiments/mpsc/plotting_results.py @@ -628,9 +628,7 @@ def plot_all_logs(system, task, algo): all_results[model].append(load_from_logs(f'./models/rl_models/{model}/logs/')) for key in all_results[ordered_models[0]][0].keys(): - if key == 'stat_eval/ep_return': - plot_log(key, all_results) - if key == 'stat/constraint_violation': + if key in ['stat/ep_return', 'stat_eval/ep_return', 'stat/constraint_violation']: plot_log(key, all_results) @@ -666,7 +664,7 @@ def plot_log(key, all_results): if __name__ == '__main__': - ordered_models = [model for model in os.listdir('./models/rl_models/') if 'curriculum' in model] + ordered_models = sorted([model for model in os.listdir('./models/rl_models/') if 'curriculum' in model and '20' not in model]) colors = plt.cm.viridis(np.linspace(0, 1, len(ordered_models))) def extract_rate_of_change_of_inputs(results_data, certified=True): @@ -724,24 +722,24 @@ def extract_length_uncert(results_data, certified=False): plot_all_logs(system_name, task_name, algo_name) plot_step_time(system_name, task_name, algo_name) - plot_model_comparisons(system_name, task_name, algo_name, extract_magnitude_of_corrections) - plot_model_comparisons(system_name, task_name, algo_name, extract_percent_magnitude_of_corrections) - plot_model_comparisons(system_name, task_name, algo_name, extract_max_correction) - plot_model_comparisons(system_name, task_name, algo_name, extract_percent_max_correction) + # plot_model_comparisons(system_name, task_name, algo_name, extract_magnitude_of_corrections) + # plot_model_comparisons(system_name, task_name, algo_name, extract_percent_magnitude_of_corrections) + # plot_model_comparisons(system_name, task_name, algo_name, extract_max_correction) + # plot_model_comparisons(system_name, task_name, algo_name, extract_percent_max_correction) plot_model_comparisons(system_name, task_name, algo_name, extract_roc_cert) - plot_model_comparisons(system_name, task_name, algo_name, extract_roc_uncert) - plot_model_comparisons(system_name, task_name, algo_name, extract_rmse_cert) - plot_model_comparisons(system_name, task_name, algo_name, extract_rmse_uncert) - plot_model_comparisons(system_name, task_name, algo_name, extract_constraint_violations_cert) - plot_model_comparisons(system_name, task_name, algo_name, extract_constraint_violations_uncert) - plot_model_comparisons(system_name, task_name, algo_name, extract_number_of_corrections) - plot_model_comparisons(system_name, task_name, algo_name, extract_length_cert) - plot_model_comparisons(system_name, task_name, algo_name, extract_length_uncert) + # plot_model_comparisons(system_name, task_name, algo_name, extract_roc_uncert) + # plot_model_comparisons(system_name, task_name, algo_name, extract_rmse_cert) + # plot_model_comparisons(system_name, task_name, algo_name, extract_rmse_uncert) + # plot_model_comparisons(system_name, task_name, algo_name, extract_constraint_violations_cert) + # plot_model_comparisons(system_name, task_name, algo_name, extract_constraint_violations_uncert) + # plot_model_comparisons(system_name, task_name, algo_name, extract_number_of_corrections) + # plot_model_comparisons(system_name, task_name, algo_name, extract_length_cert) + # plot_model_comparisons(system_name, task_name, algo_name, extract_length_uncert) plot_model_comparisons(system_name, task_name, algo_name, extract_reward_cert) - plot_model_comparisons(system_name, task_name, algo_name, extract_reward_uncert) - plot_model_comparisons(system_name, task_name, algo_name, extract_failed_cert) - plot_model_comparisons(system_name, task_name, algo_name, extract_failed_uncert) - plot_model_comparisons(system_name, task_name, algo_name, extract_feasible_iterations) - if task_name == 'stab': - plot_model_comparisons(system_name, task_name, algo_name, extract_final_dist_cert) - plot_model_comparisons(system_name, task_name, algo_name, extract_final_dist_uncert) + # plot_model_comparisons(system_name, task_name, algo_name, extract_reward_uncert) + # plot_model_comparisons(system_name, task_name, algo_name, extract_failed_cert) + # plot_model_comparisons(system_name, task_name, algo_name, extract_failed_uncert) + # plot_model_comparisons(system_name, task_name, algo_name, extract_feasible_iterations) + # if task_name == 'stab': + # plot_model_comparisons(system_name, task_name, algo_name, extract_final_dist_cert) + # plot_model_comparisons(system_name, task_name, algo_name, extract_final_dist_uncert) diff --git a/experiments/mpsc/train_all_models.sh b/experiments/mpsc/train_all_models.sh index f699c7cbf..f5263c538 100755 --- a/experiments/mpsc/train_all_models.sh +++ b/experiments/mpsc/train_all_models.sh @@ -1,5 +1,5 @@ #!/bin/bash -sbatch train_model.sbatch False 1 1 +sbatch train_model.sbatch False 1 1 False for MPSC_COST_HORIZON in 2 5 10 20; do for DECAY_FACTOR in 0.25 0.5 0.75 1; do # Ignore precomputed differences diff --git a/safe_control_gym/controllers/ppo/ppo.py b/safe_control_gym/controllers/ppo/ppo.py index 3ab94448f..c0d356db1 100644 --- a/safe_control_gym/controllers/ppo/ppo.py +++ b/safe_control_gym/controllers/ppo/ppo.py @@ -93,6 +93,7 @@ def __init__(self, def reset(self): '''Do initializations for training or evaluation.''' + self.env_state = {} if self.training: # set up stats tracking self.env.add_tracker('constraint_violation', 0) @@ -124,10 +125,7 @@ def save(self, ): '''Saves model params and experiment state to checkpoint path.''' if save_only_random_seed is True: - exp_state = { - 'env_random_state': self.env.get_env_random_state() - } - torch.save(exp_state, path) + self.env_state[path] = self.env.get_env_random_state() return path_dir = os.path.dirname(path) os.makedirs(path_dir, exist_ok=True) @@ -150,10 +148,10 @@ def load(self, load_only_random_seed=False, ): '''Restores model and experiment given checkpoint path.''' - state = torch.load(path) if load_only_random_seed is True: - self.env.set_env_random_state(state['env_random_state']) + self.env.set_env_random_state(self.env_state[path]) return + state = torch.load(path) # Restore policy. self.agent.load_state_dict(state['agent']) self.obs_normalizer.load_state_dict(state['obs_normalizer']) @@ -257,7 +255,7 @@ def run(self, success = False physical_action = env.denormalize_action(action) unextended_obs = np.squeeze(true_obs)[:env.symbolic.nx] - certified_action, success = self.safety_filter.certify_action(unextended_obs, physical_action, info) + certified_action, success, jacobian = self.safety_filter.certify_action(unextended_obs, physical_action, info) if success: action = env.normalize_action(certified_action) else: @@ -307,7 +305,7 @@ def train_step(self): info = self.info start = time.time() if self.safety_filter is not None and self.preserve_random_state is True: - self.save(f'./temp-data/saved_controller_prev_{self.model_name}.npy', save_only_random_seed=True) + self.save('prev', save_only_random_seed=True) for _ in range(self.rollout_steps): with torch.no_grad(): action, v, logp = self.agent.ac.step(torch.FloatTensor(obs).to(self.device)) @@ -318,7 +316,7 @@ def train_step(self): if self.safety_filter is not None and (self.filter_train_actions is True or self.penalize_sf_diff is True): physical_action = self.env.envs[0].denormalize_action(action) unextended_obs = np.squeeze(true_obs)[:self.env.envs[0].symbolic.nx] - certified_action, success = self.safety_filter.certify_action(unextended_obs, physical_action, info) + certified_action, success, jacobian = self.safety_filter.certify_action(unextended_obs, physical_action, info) if success and self.filter_train_actions is True: action = self.env.envs[0].normalize_action(certified_action) else: @@ -455,7 +453,7 @@ def env_reset(self, env, use_safe_reset): info['current_step'] = 1 unextended_obs = np.squeeze(obs)[:self.env.envs[0].symbolic.nx] self.safety_filter.reset_before_run() - _, success = self.safety_filter.certify_action(unextended_obs, action, info) + _, success, _ = self.safety_filter.certify_action(unextended_obs, action, info) if not success: self.safety_filter.ocp_solver.reset() diff --git a/safe_control_gym/experiments/base_experiment.py b/safe_control_gym/experiments/base_experiment.py index 14ff4e43b..5005e8791 100644 --- a/safe_control_gym/experiments/base_experiment.py +++ b/safe_control_gym/experiments/base_experiment.py @@ -160,7 +160,7 @@ def _select_action(self, obs, info): if self.safety_filter is not None: physical_action = self.env.denormalize_action(action) unextended_obs = obs[:self.env.symbolic.nx] - certified_action, _ = self.safety_filter.certify_action(unextended_obs, physical_action, info) + certified_action, _, _ = self.safety_filter.certify_action(unextended_obs, physical_action, info) action = self.env.normalize_action(certified_action) return action diff --git a/safe_control_gym/safety_filters/mpsc/mpsc.py b/safe_control_gym/safety_filters/mpsc/mpsc.py index ee7c8521b..4fd159267 100644 --- a/safe_control_gym/safety_filters/mpsc/mpsc.py +++ b/safe_control_gym/safety_filters/mpsc/mpsc.py @@ -162,9 +162,10 @@ def solve_optimization(self, if self.use_acados: action, feasible = self.solve_acados_optimization(obs, uncertified_action, iteration) + jacobian = None else: - action, feasible = self.solve_casadi_optimization(obs, uncertified_action, iteration) - return action, feasible + action, feasible, jacobian = self.solve_casadi_optimization(obs, uncertified_action, iteration) + return action, feasible, jacobian def solve_casadi_optimization(self, obs, @@ -227,7 +228,7 @@ def solve_casadi_optimization(self, print(e) feasible = False action = None - return action, feasible + return action, feasible, None def solve_acados_optimization(self, obs, @@ -301,7 +302,7 @@ def certify_action(self, self.before_optimization(current_state) iteration = self.extract_step(info) - action, feasible = self.solve_optimization(current_state, uncertified_action, iteration) + action, feasible, jacobian = self.solve_optimization(current_state, uncertified_action, iteration) self.results_dict['feasible'].append(feasible) if feasible: @@ -334,7 +335,7 @@ def certify_action(self, self.results_dict['certified_action'].append(certified_action) self.results_dict['correction'].append(np.linalg.norm(certified_action - uncertified_action)) - return certified_action, success + return certified_action, success, jacobian def setup_results_dict(self): '''Setup the results dictionary to store run information.''' diff --git a/safe_control_gym/safety_filters/mpsc/mpsc_cost_function/precomputed_cost.py b/safe_control_gym/safety_filters/mpsc/mpsc_cost_function/precomputed_cost.py index e8ff9a3b9..35b8efc27 100644 --- a/safe_control_gym/safety_filters/mpsc/mpsc_cost_function/precomputed_cost.py +++ b/safe_control_gym/safety_filters/mpsc/mpsc_cost_function/precomputed_cost.py @@ -101,8 +101,8 @@ def calculate_unsafe_path(self, obs, uncertified_action, iteration): self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_curr.npy') self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_prev.npy') elif isinstance(self.uncertified_controller, PPO) and self.uncertified_controller.curr_training is True and self.uncertified_controller.preserve_random_state: - self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_curr_{self.uncertified_controller.model_name}.npy', save_only_random_seed=True) - self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_prev_{self.uncertified_controller.model_name}.npy', load_only_random_seed=True) + self.uncertified_controller.save('curr', save_only_random_seed=True) + self.uncertified_controller.load('prev', load_only_random_seed=True) for h in range(self.mpsc_cost_horizon): next_step = min(iteration + h, self.env.X_GOAL.shape[0] - 1) @@ -136,7 +136,7 @@ def calculate_unsafe_path(self, obs, uncertified_action, iteration): self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_curr.npy') self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_prev.npy') elif isinstance(self.uncertified_controller, PPO) and self.uncertified_controller.curr_training is True and self.uncertified_controller.preserve_random_state is True: - self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_curr_{self.uncertified_controller.model_name}.npy', load_only_random_seed=True) - self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_prev_{self.uncertified_controller.model_name}.npy', save_only_random_seed=True) + self.uncertified_controller.load('curr', load_only_random_seed=True) + self.uncertified_controller.save('prev', save_only_random_seed=True) return v_L