From 076393fb2968374b862e087955600a38912624b1 Mon Sep 17 00:00:00 2001 From: Federico-PizarroBejarano Date: Fri, 6 Dec 2024 11:33:17 -0500 Subject: [PATCH] Minor cleanup --- .../mpsc_acados_quadrotor_2D_attitude.yaml | 2 +- safe_control_gym/controllers/ppo/ppo.py | 18 ++++++++++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/experiments/mpsc/config_overrides/mpsc_acados_quadrotor_2D_attitude.yaml b/experiments/mpsc/config_overrides/mpsc_acados_quadrotor_2D_attitude.yaml index b1be0a338..21146f069 100644 --- a/experiments/mpsc/config_overrides/mpsc_acados_quadrotor_2D_attitude.yaml +++ b/experiments/mpsc/config_overrides/mpsc_acados_quadrotor_2D_attitude.yaml @@ -1,4 +1,4 @@ -safety_filter: nl_mpsc +safety_filter: mpsc_acados sf_config: # LQR controller parameters q_mpc: [18, 0.1, 18, 0.5, 0.5, 0.0001] diff --git a/safe_control_gym/controllers/ppo/ppo.py b/safe_control_gym/controllers/ppo/ppo.py index d19e16ae5..4174a3f1c 100644 --- a/safe_control_gym/controllers/ppo/ppo.py +++ b/safe_control_gym/controllers/ppo/ppo.py @@ -45,6 +45,7 @@ def __init__(self, self.sf_penalty = 1 self.use_safe_reset = False super().__init__(env_func, training, checkpoint_path, output_dir, use_gpu, seed, **kwargs) + # Task. if self.training: # Training and testing. @@ -237,14 +238,15 @@ def run(self, action = self.select_action(obs=obs, info=info) # Adding safety filter - success = False - physical_action = env.denormalize_action(action) - unextended_obs = np.squeeze(true_obs)[:env.symbolic.nx] - certified_action, success = self.safety_filter.certify_action(unextended_obs, physical_action, info) - if success: - action = env.normalize_action(certified_action) - else: - self.safety_filter.ocp_solver.reset() + if self.safety_filter is not None: + success = False + physical_action = env.denormalize_action(action) + unextended_obs = np.squeeze(true_obs)[:env.symbolic.nx] + certified_action, success = self.safety_filter.certify_action(unextended_obs, physical_action, info) + if success: + action = env.normalize_action(certified_action) + else: + self.safety_filter.ocp_solver.reset() action = np.atleast_2d(np.squeeze([action])) obs, rew, done, info = env.step(action)