Skip to content

Commit

Permalink
Adding basic id to PPO instances to allow for parallel training
Browse files Browse the repository at this point in the history
  • Loading branch information
Federico-PizarroBejarano committed Nov 19, 2024
1 parent 43f0f44 commit 64c5337
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
3 changes: 2 additions & 1 deletion safe_control_gym/controllers/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __init__(self,

# Adding safety filter
self.safety_filter = None
self.instance_idx = int(np.random.rand() * 1000000)

def reset(self):
'''Do initializations for training or evaluation.'''
Expand Down Expand Up @@ -307,7 +308,7 @@ def train_step(self):
info = self.info
start = time.time()
if self.safety_filter is not None and self.preserve_random_state is True:
self.save('./temp-data/saved_controller_prev.npy', save_only_random_seed=True)
self.save(f'./temp-data/saved_controller_prev_{self.instance_idx}.npy', save_only_random_seed=True)
for _ in range(self.rollout_steps):
with torch.no_grad():
action, v, logp = self.agent.ac.step(torch.FloatTensor(obs).to(self.device))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ def calculate_unsafe_path(self, obs, uncertified_action, iteration):
self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_curr.npy')
self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_prev.npy')
elif isinstance(self.uncertified_controller, PPO) and self.uncertified_controller.curr_training is True and self.uncertified_controller.preserve_random_state:
self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_curr.npy', save_only_random_seed=True)
self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_prev.npy', load_only_random_seed=True)
self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_curr_{self.uncertified_controller.instance_idx}.npy', save_only_random_seed=True)
self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_prev_{self.uncertified_controller.instance_idx}.npy', load_only_random_seed=True)

for h in range(self.mpsc_cost_horizon):
next_step = min(iteration + h, self.env.X_GOAL.shape[0] - 1)
Expand Down Expand Up @@ -136,7 +136,7 @@ def calculate_unsafe_path(self, obs, uncertified_action, iteration):
self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_curr.npy')
self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_prev.npy')
elif isinstance(self.uncertified_controller, PPO) and self.uncertified_controller.curr_training is True and self.uncertified_controller.preserve_random_state is True:
self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_curr.npy', load_only_random_seed=True)
self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_prev.npy', save_only_random_seed=True)
self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_curr_{self.uncertified_controller.instance_idx}.npy', load_only_random_seed=True)
self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_prev_{self.uncertified_controller.instance_idx}.npy', save_only_random_seed=True)

return v_L

0 comments on commit 64c5337

Please sign in to comment.