Skip to content

Commit

Permalink
Fixing plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
Federico-PizarroBejarano committed May 10, 2023
1 parent ffa5c87 commit a7500bb
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions experiments/mpsc/plotting_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
'quadrotor_3D': 0.06615
}

met = MetricExtractor()
met.verbose = False


def load_one_experiment(system, task, algo, mpsc_cost_horizon):
'''Loads the results of every MPSC cost function for a specific experiment.
Expand Down Expand Up @@ -261,10 +264,10 @@ def extract_rmse(results_data, certified=True):
rmse (list): The list of RMSEs for all experiments.
'''
if certified:
met = MetricExtractor(results_data['cert_results'])
met.data = results_data['cert_results']
rmse = np.asarray(met.get_episode_rmse())
else:
met = MetricExtractor(results_data['uncert_results'])
met.data = results_data['uncert_results']
rmse = np.asarray(met.get_episode_rmse())
return rmse

Expand Down Expand Up @@ -298,10 +301,10 @@ def extract_constraint_violations(results_data, certified=True):
num_violations (list): The list of number of constraint violations for all experiments.
'''
if certified:
met = MetricExtractor(results_data['cert_results'])
met.data = results_data['cert_results']
num_violations = np.asarray(met.get_episode_constraint_violation_steps())
else:
met = MetricExtractor(results_data['uncert_results'])
met.data = results_data['uncert_results']
num_violations = np.asarray(met.get_episode_constraint_violation_steps())

return num_violations
Expand Down Expand Up @@ -400,7 +403,7 @@ def plot_trajectories(config, X_GOAL, uncert_results, cert_results):
uncert_results (dict): The results of the uncertified experiment.
cert_results (dict): The results of the certified experiment.
'''
met = MetricExtractor(cert_results)
met.data = cert_results
print('Total Certified Violations:', np.asarray(met.get_episode_constraint_violation_steps()).sum())

if config.task == Environment.QUADROTOR:
Expand All @@ -410,7 +413,7 @@ def plot_trajectories(config, X_GOAL, uncert_results, cert_results):

for exp in range(len(uncert_results['obs'])):
specific_results = {key: [cert_results[key][exp]] for key in cert_results.keys()}
met = MetricExtractor(specific_results)
met.data = specific_results
print(f'Total Certified Violations ({exp}):', np.asarray(met.get_episode_constraint_violation_steps()).sum())
mpsc_results = cert_results['safety_filter_data'][exp]
corrections = mpsc_results['correction'][0] * 10.0 > np.linalg.norm(cert_results['current_clipped_action'][exp] - U_EQs[system], axis=1)
Expand Down Expand Up @@ -447,7 +450,7 @@ def plot_trajectories(config, X_GOAL, uncert_results, cert_results):

if config.task_config.task == Task.TRAJ_TRACKING and config.task == Environment.CARTPOLE:
_, ax2 = plt.subplots()
ax2.plot(np.linspace(0, 20, cert_results['obs'][exp].shape[0])[1:], X_GOAL[:, 0], 'g--', label='Reference')
ax2.plot(np.linspace(0, 20, cert_results['obs'][exp].shape[0]), X_GOAL[:, 0], 'g--', label='Reference')
ax2.plot(np.linspace(0, 20, uncert_results['obs'][exp].shape[0]), uncert_results['obs'][exp][:, 0], 'r--', label='Uncertified')
ax2.plot(np.linspace(0, 20, cert_results['obs'][exp].shape[0]), cert_results['obs'][exp][:, 0], '.-', label='Certified')
ax2.plot(np.linspace(0, 20, cert_results['obs'][exp].shape[0])[corrections], cert_results['obs'][exp][corrections, 0], 'r.', label='Modified')
Expand Down

0 comments on commit a7500bb

Please sign in to comment.