Skip to content

Commit

Permalink
Laying the groundwork for eventual jacobian calculations and improving saving the random env state variable
Browse files Browse the repository at this point in the history
  • Loading branch information
Federico-PizarroBejarano committed Nov 27, 2024
1 parent 86e8345 commit 56ed7f4
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 44 deletions.
44 changes: 21 additions & 23 deletions experiments/mpsc/plotting_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,9 +628,7 @@ def plot_all_logs(system, task, algo):
all_results[model].append(load_from_logs(f'./models/rl_models/{model}/logs/'))

for key in all_results[ordered_models[0]][0].keys():
if key == 'stat_eval/ep_return':
plot_log(key, all_results)
if key == 'stat/constraint_violation':
if key in ['stat/ep_return', 'stat_eval/ep_return', 'stat/constraint_violation']:
plot_log(key, all_results)


Expand Down Expand Up @@ -666,7 +664,7 @@ def plot_log(key, all_results):


if __name__ == '__main__':
ordered_models = [model for model in os.listdir('./models/rl_models/') if 'curriculum' in model]
ordered_models = sorted([model for model in os.listdir('./models/rl_models/') if 'curriculum' in model and '20' not in model])
colors = plt.cm.viridis(np.linspace(0, 1, len(ordered_models)))

def extract_rate_of_change_of_inputs(results_data, certified=True):
Expand Down Expand Up @@ -724,24 +722,24 @@ def extract_length_uncert(results_data, certified=False):

plot_all_logs(system_name, task_name, algo_name)
plot_step_time(system_name, task_name, algo_name)
plot_model_comparisons(system_name, task_name, algo_name, extract_magnitude_of_corrections)
plot_model_comparisons(system_name, task_name, algo_name, extract_percent_magnitude_of_corrections)
plot_model_comparisons(system_name, task_name, algo_name, extract_max_correction)
plot_model_comparisons(system_name, task_name, algo_name, extract_percent_max_correction)
# plot_model_comparisons(system_name, task_name, algo_name, extract_magnitude_of_corrections)
# plot_model_comparisons(system_name, task_name, algo_name, extract_percent_magnitude_of_corrections)
# plot_model_comparisons(system_name, task_name, algo_name, extract_max_correction)
# plot_model_comparisons(system_name, task_name, algo_name, extract_percent_max_correction)
plot_model_comparisons(system_name, task_name, algo_name, extract_roc_cert)
plot_model_comparisons(system_name, task_name, algo_name, extract_roc_uncert)
plot_model_comparisons(system_name, task_name, algo_name, extract_rmse_cert)
plot_model_comparisons(system_name, task_name, algo_name, extract_rmse_uncert)
plot_model_comparisons(system_name, task_name, algo_name, extract_constraint_violations_cert)
plot_model_comparisons(system_name, task_name, algo_name, extract_constraint_violations_uncert)
plot_model_comparisons(system_name, task_name, algo_name, extract_number_of_corrections)
plot_model_comparisons(system_name, task_name, algo_name, extract_length_cert)
plot_model_comparisons(system_name, task_name, algo_name, extract_length_uncert)
# plot_model_comparisons(system_name, task_name, algo_name, extract_roc_uncert)
# plot_model_comparisons(system_name, task_name, algo_name, extract_rmse_cert)
# plot_model_comparisons(system_name, task_name, algo_name, extract_rmse_uncert)
# plot_model_comparisons(system_name, task_name, algo_name, extract_constraint_violations_cert)
# plot_model_comparisons(system_name, task_name, algo_name, extract_constraint_violations_uncert)
# plot_model_comparisons(system_name, task_name, algo_name, extract_number_of_corrections)
# plot_model_comparisons(system_name, task_name, algo_name, extract_length_cert)
# plot_model_comparisons(system_name, task_name, algo_name, extract_length_uncert)
plot_model_comparisons(system_name, task_name, algo_name, extract_reward_cert)
plot_model_comparisons(system_name, task_name, algo_name, extract_reward_uncert)
plot_model_comparisons(system_name, task_name, algo_name, extract_failed_cert)
plot_model_comparisons(system_name, task_name, algo_name, extract_failed_uncert)
plot_model_comparisons(system_name, task_name, algo_name, extract_feasible_iterations)
if task_name == 'stab':
plot_model_comparisons(system_name, task_name, algo_name, extract_final_dist_cert)
plot_model_comparisons(system_name, task_name, algo_name, extract_final_dist_uncert)
# plot_model_comparisons(system_name, task_name, algo_name, extract_reward_uncert)
# plot_model_comparisons(system_name, task_name, algo_name, extract_failed_cert)
# plot_model_comparisons(system_name, task_name, algo_name, extract_failed_uncert)
# plot_model_comparisons(system_name, task_name, algo_name, extract_feasible_iterations)
# if task_name == 'stab':
# plot_model_comparisons(system_name, task_name, algo_name, extract_final_dist_cert)
# plot_model_comparisons(system_name, task_name, algo_name, extract_final_dist_uncert)
2 changes: 1 addition & 1 deletion experiments/mpsc/train_all_models.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/bash
sbatch train_model.sbatch False 1 1
sbatch train_model.sbatch False 1 1 False
for MPSC_COST_HORIZON in 2 5 10 20; do
for DECAY_FACTOR in 0.25 0.5 0.75 1; do
# Ignore precomputed differences
Expand Down
18 changes: 8 additions & 10 deletions safe_control_gym/controllers/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def __init__(self,

def reset(self):
'''Do initializations for training or evaluation.'''
self.env_state = {}
if self.training:
# set up stats tracking
self.env.add_tracker('constraint_violation', 0)
Expand Down Expand Up @@ -124,10 +125,7 @@ def save(self,
):
'''Saves model params and experiment state to checkpoint path.'''
if save_only_random_seed is True:
exp_state = {
'env_random_state': self.env.get_env_random_state()
}
torch.save(exp_state, path)
self.env_state[path] = self.env.get_env_random_state()
return
path_dir = os.path.dirname(path)
os.makedirs(path_dir, exist_ok=True)
Expand All @@ -150,10 +148,10 @@ def load(self,
load_only_random_seed=False,
):
'''Restores model and experiment given checkpoint path.'''
state = torch.load(path)
if load_only_random_seed is True:
self.env.set_env_random_state(state['env_random_state'])
self.env.set_env_random_state(self.env_state[path])
return
state = torch.load(path)
# Restore policy.
self.agent.load_state_dict(state['agent'])
self.obs_normalizer.load_state_dict(state['obs_normalizer'])
Expand Down Expand Up @@ -257,7 +255,7 @@ def run(self,
success = False
physical_action = env.denormalize_action(action)
unextended_obs = np.squeeze(true_obs)[:env.symbolic.nx]
certified_action, success = self.safety_filter.certify_action(unextended_obs, physical_action, info)
certified_action, success, jacobian = self.safety_filter.certify_action(unextended_obs, physical_action, info)
if success:
action = env.normalize_action(certified_action)
else:
Expand Down Expand Up @@ -307,7 +305,7 @@ def train_step(self):
info = self.info
start = time.time()
if self.safety_filter is not None and self.preserve_random_state is True:
self.save(f'./temp-data/saved_controller_prev_{self.model_name}.npy', save_only_random_seed=True)
self.save('prev', save_only_random_seed=True)
for _ in range(self.rollout_steps):
with torch.no_grad():
action, v, logp = self.agent.ac.step(torch.FloatTensor(obs).to(self.device))
Expand All @@ -318,7 +316,7 @@ def train_step(self):
if self.safety_filter is not None and (self.filter_train_actions is True or self.penalize_sf_diff is True):
physical_action = self.env.envs[0].denormalize_action(action)
unextended_obs = np.squeeze(true_obs)[:self.env.envs[0].symbolic.nx]
certified_action, success = self.safety_filter.certify_action(unextended_obs, physical_action, info)
certified_action, success, jacobian = self.safety_filter.certify_action(unextended_obs, physical_action, info)
if success and self.filter_train_actions is True:
action = self.env.envs[0].normalize_action(certified_action)
else:
Expand Down Expand Up @@ -455,7 +453,7 @@ def env_reset(self, env, use_safe_reset):
info['current_step'] = 1
unextended_obs = np.squeeze(obs)[:self.env.envs[0].symbolic.nx]
self.safety_filter.reset_before_run()
_, success = self.safety_filter.certify_action(unextended_obs, action, info)
_, success, _ = self.safety_filter.certify_action(unextended_obs, action, info)
if not success:
self.safety_filter.ocp_solver.reset()

Expand Down
2 changes: 1 addition & 1 deletion safe_control_gym/experiments/base_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def _select_action(self, obs, info):
if self.safety_filter is not None:
physical_action = self.env.denormalize_action(action)
unextended_obs = obs[:self.env.symbolic.nx]
certified_action, _ = self.safety_filter.certify_action(unextended_obs, physical_action, info)
certified_action, _, _ = self.safety_filter.certify_action(unextended_obs, physical_action, info)
action = self.env.normalize_action(certified_action)

return action
Expand Down
11 changes: 6 additions & 5 deletions safe_control_gym/safety_filters/mpsc/mpsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,10 @@ def solve_optimization(self,

if self.use_acados:
action, feasible = self.solve_acados_optimization(obs, uncertified_action, iteration)
jacobian = None
else:
action, feasible = self.solve_casadi_optimization(obs, uncertified_action, iteration)
return action, feasible
action, feasible, jacobian = self.solve_casadi_optimization(obs, uncertified_action, iteration)
return action, feasible, jacobian

def solve_casadi_optimization(self,
obs,
Expand Down Expand Up @@ -227,7 +228,7 @@ def solve_casadi_optimization(self,
print(e)
feasible = False
action = None
return action, feasible
return action, feasible, None

def solve_acados_optimization(self,
obs,
Expand Down Expand Up @@ -301,7 +302,7 @@ def certify_action(self,

self.before_optimization(current_state)
iteration = self.extract_step(info)
action, feasible = self.solve_optimization(current_state, uncertified_action, iteration)
action, feasible, jacobian = self.solve_optimization(current_state, uncertified_action, iteration)
self.results_dict['feasible'].append(feasible)

if feasible:
Expand Down Expand Up @@ -334,7 +335,7 @@ def certify_action(self,
self.results_dict['certified_action'].append(certified_action)
self.results_dict['correction'].append(np.linalg.norm(certified_action - uncertified_action))

return certified_action, success
return certified_action, success, jacobian

def setup_results_dict(self):
'''Setup the results dictionary to store run information.'''
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ def calculate_unsafe_path(self, obs, uncertified_action, iteration):
self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_curr.npy')
self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_prev.npy')
elif isinstance(self.uncertified_controller, PPO) and self.uncertified_controller.curr_training is True and self.uncertified_controller.preserve_random_state:
self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_curr_{self.uncertified_controller.model_name}.npy', save_only_random_seed=True)
self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_prev_{self.uncertified_controller.model_name}.npy', load_only_random_seed=True)
self.uncertified_controller.save('curr', save_only_random_seed=True)
self.uncertified_controller.load('prev', load_only_random_seed=True)

for h in range(self.mpsc_cost_horizon):
next_step = min(iteration + h, self.env.X_GOAL.shape[0] - 1)
Expand Down Expand Up @@ -136,7 +136,7 @@ def calculate_unsafe_path(self, obs, uncertified_action, iteration):
self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_curr.npy')
self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_prev.npy')
elif isinstance(self.uncertified_controller, PPO) and self.uncertified_controller.curr_training is True and self.uncertified_controller.preserve_random_state is True:
self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_curr_{self.uncertified_controller.model_name}.npy', load_only_random_seed=True)
self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_prev_{self.uncertified_controller.model_name}.npy', save_only_random_seed=True)
self.uncertified_controller.load('curr', load_only_random_seed=True)
self.uncertified_controller.save('prev', save_only_random_seed=True)

return v_L

0 comments on commit 56ed7f4

Please sign in to comment.