From 64c5337b955475fea525e0bfa070765f780455f0 Mon Sep 17 00:00:00 2001 From: Federico-PizarroBejarano Date: Tue, 19 Nov 2024 10:28:28 -0500 Subject: [PATCH] Adding basic id to PPO instances to allow for parallel training --- safe_control_gym/controllers/ppo/ppo.py | 3 ++- .../mpsc/mpsc_cost_function/precomputed_cost.py | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/safe_control_gym/controllers/ppo/ppo.py b/safe_control_gym/controllers/ppo/ppo.py index 7095f3429..b6b8e6108 100644 --- a/safe_control_gym/controllers/ppo/ppo.py +++ b/safe_control_gym/controllers/ppo/ppo.py @@ -90,6 +90,7 @@ def __init__(self, # Adding safety filter self.safety_filter = None + self.instance_idx = int(np.random.rand() * 1000000) def reset(self): '''Do initializations for training or evaluation.''' @@ -307,7 +308,7 @@ def train_step(self): info = self.info start = time.time() if self.safety_filter is not None and self.preserve_random_state is True: - self.save('./temp-data/saved_controller_prev.npy', save_only_random_seed=True) + self.save(f'./temp-data/saved_controller_prev_{self.instance_idx}.npy', save_only_random_seed=True) for _ in range(self.rollout_steps): with torch.no_grad(): action, v, logp = self.agent.ac.step(torch.FloatTensor(obs).to(self.device)) diff --git a/safe_control_gym/safety_filters/mpsc/mpsc_cost_function/precomputed_cost.py b/safe_control_gym/safety_filters/mpsc/mpsc_cost_function/precomputed_cost.py index 6a7c2dce2..22cb4c362 100644 --- a/safe_control_gym/safety_filters/mpsc/mpsc_cost_function/precomputed_cost.py +++ b/safe_control_gym/safety_filters/mpsc/mpsc_cost_function/precomputed_cost.py @@ -101,8 +101,8 @@ def calculate_unsafe_path(self, obs, uncertified_action, iteration): self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_curr.npy') self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_prev.npy') elif isinstance(self.uncertified_controller, PPO) and self.uncertified_controller.curr_training is True and self.uncertified_controller.preserve_random_state: - self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_curr.npy', save_only_random_seed=True) - self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_prev.npy', load_only_random_seed=True) + self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_curr_{self.uncertified_controller.instance_idx}.npy', save_only_random_seed=True) + self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_prev_{self.uncertified_controller.instance_idx}.npy', load_only_random_seed=True) for h in range(self.mpsc_cost_horizon): next_step = min(iteration + h, self.env.X_GOAL.shape[0] - 1) @@ -136,7 +136,7 @@ def calculate_unsafe_path(self, obs, uncertified_action, iteration): self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_curr.npy') self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_prev.npy') elif isinstance(self.uncertified_controller, PPO) and self.uncertified_controller.curr_training is True and self.uncertified_controller.preserve_random_state is True: - self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_curr.npy', load_only_random_seed=True) - self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_prev.npy', save_only_random_seed=True) + self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_curr_{self.uncertified_controller.instance_idx}.npy', load_only_random_seed=True) + self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_prev_{self.uncertified_controller.instance_idx}.npy', save_only_random_seed=True) return v_L