From a7500bb1a05c135d5253107b8504e86b0489b8f7 Mon Sep 17 00:00:00 2001 From: Federico-PizarroBejarano Date: Wed, 10 May 2023 11:27:40 -0400 Subject: [PATCH] Fixing plotting --- experiments/mpsc/plotting_results.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/experiments/mpsc/plotting_results.py b/experiments/mpsc/plotting_results.py index efc2225e5..e20c6f578 100644 --- a/experiments/mpsc/plotting_results.py +++ b/experiments/mpsc/plotting_results.py @@ -27,6 +27,9 @@ 'quadrotor_3D': 0.06615 } +met = MetricExtractor() +met.verbose = False + def load_one_experiment(system, task, algo, mpsc_cost_horizon): '''Loads the results of every MPSC cost function for a specific experiment. @@ -261,10 +264,10 @@ def extract_rmse(results_data, certified=True): rmse (list): The list of RMSEs for all experiments. ''' if certified: - met = MetricExtractor(results_data['cert_results']) + met.data = results_data['cert_results'] rmse = np.asarray(met.get_episode_rmse()) else: - met = MetricExtractor(results_data['uncert_results']) + met.data = results_data['uncert_results'] rmse = np.asarray(met.get_episode_rmse()) return rmse @@ -298,10 +301,10 @@ def extract_constraint_violations(results_data, certified=True): num_violations (list): The list of number of constraint violations for all experiments. ''' if certified: - met = MetricExtractor(results_data['cert_results']) + met.data = results_data['cert_results'] num_violations = np.asarray(met.get_episode_constraint_violation_steps()) else: - met = MetricExtractor(results_data['uncert_results']) + met.data = results_data['uncert_results'] num_violations = np.asarray(met.get_episode_constraint_violation_steps()) return num_violations @@ -400,7 +403,7 @@ def plot_trajectories(config, X_GOAL, uncert_results, cert_results): uncert_results (dict): The results of the uncertified experiment. cert_results (dict): The results of the certified experiment. ''' - met = MetricExtractor(cert_results) + met.data = cert_results print('Total Certified Violations:', np.asarray(met.get_episode_constraint_violation_steps()).sum()) if config.task == Environment.QUADROTOR: @@ -410,7 +413,7 @@ def plot_trajectories(config, X_GOAL, uncert_results, cert_results): for exp in range(len(uncert_results['obs'])): specific_results = {key: [cert_results[key][exp]] for key in cert_results.keys()} - met = MetricExtractor(specific_results) + met.data = specific_results print(f'Total Certified Violations ({exp}):', np.asarray(met.get_episode_constraint_violation_steps()).sum()) mpsc_results = cert_results['safety_filter_data'][exp] corrections = mpsc_results['correction'][0] * 10.0 > np.linalg.norm(cert_results['current_clipped_action'][exp] - U_EQs[system], axis=1) @@ -447,7 +450,7 @@ def plot_trajectories(config, X_GOAL, uncert_results, cert_results): if config.task_config.task == Task.TRAJ_TRACKING and config.task == Environment.CARTPOLE: _, ax2 = plt.subplots() - ax2.plot(np.linspace(0, 20, cert_results['obs'][exp].shape[0])[1:], X_GOAL[:, 0], 'g--', label='Reference') + ax2.plot(np.linspace(0, 20, cert_results['obs'][exp].shape[0]), X_GOAL[:, 0], 'g--', label='Reference') ax2.plot(np.linspace(0, 20, uncert_results['obs'][exp].shape[0]), uncert_results['obs'][exp][:, 0], 'r--', label='Uncertified') ax2.plot(np.linspace(0, 20, cert_results['obs'][exp].shape[0]), cert_results['obs'][exp][:, 0], '.-', label='Certified') ax2.plot(np.linspace(0, 20, cert_results['obs'][exp].shape[0])[corrections], cert_results['obs'][exp][corrections, 0], 'r.', label='Modified')