From 0d8d6832ba9cbc6d2e017597ffbee93309d0f904 Mon Sep 17 00:00:00 2001
From: quintin
Date: Wed, 20 Jun 2018 13:45:55 -0400
Subject: [PATCH] better saving and plotting code

---
 agents/BaseAgent.py      | 54 ++++++++++++++++++++++++++++++++++++++++
 agents/DQN.py            |  6 +++--
 devel.py                 | 32 +++++++++++++-----------
 saved_agents/__init__.py |  0
 4 files changed, 75 insertions(+), 17 deletions(-)
 create mode 100644 agents/BaseAgent.py
 create mode 100644 saved_agents/__init__.py

diff --git a/agents/BaseAgent.py b/agents/BaseAgent.py
new file mode 100644
index 0000000..06d6392
--- /dev/null
+++ b/agents/BaseAgent.py
@@ -0,0 +1,54 @@
+import numpy as np
+import pickle
+import os.path
+
+import torch
+import torch.optim as optim
+
+
+class BaseAgent(object):
+    def __init__(self):
+        self.model=None
+        self.target_model=None
+        self.optimizer = None
+        self.losses = []
+        self.rewards = []
+        self.sigma_parameter_mag=[]
+
+    def save_w(self):
+        torch.save(self.model.state_dict(), './saved_agents/model.dump')
+        torch.save(self.optimizer.state_dict(), './saved_agents/optim.dump')
+
+    def load_w(self):
+        fname_model = "./saved_agents/model.dump"
+        fname_optim = "./saved_agents/optim.dump"
+
+        if os.path.isfile(fname_model):
+            self.model.load_state_dict(torch.load(fname_model))
+            self.target_model.load_state_dict(self.model.state_dict())
+
+        if os.path.isfile(fname_optim):
+            self.optimizer.load_state_dict(torch.load(fname_optim))
+
+    def save_replay(self):
+        pickle.dump(self.memory, open('./saved_agents/exp_replay_agent.dump', 'wb'))
+
+    def load_replay(self):
+        fname = './saved_agents/exp_replay_agent.dump'
+        if os.path.isfile(fname):
+            self.memory = pickle.load(open(fname, 'rb'))
+
+    def save_sigma_param_magnitudes(self):
+        tmp = []
+        for name, param in self.model.named_parameters():
+            if param.requires_grad:
+                if 'sigma' in name:
+                    tmp += param.data.cpu().numpy().ravel().tolist()
+        if tmp:
+            self.sigma_parameter_mag.append(np.mean(np.abs(np.array(tmp))))
+
+    def save_loss(self, loss):
+        self.losses.append(loss)
+
+    def save_reward(self, reward):
+        self.rewards.append(reward)
\ No newline at end of file
diff --git a/agents/DQN.py b/agents/DQN.py
index 14ca79b..08a5596 100644
--- a/agents/DQN.py
+++ b/agents/DQN.py
@@ -3,12 +3,13 @@
 import torch
 import torch.optim as optim
 
+from agents.BaseAgent import BaseAgent
 from networks.networks import DQN
 from utils.ReplayMemory import ExperienceReplayMemory, PrioritizedReplayMemory
 
 from timeit import default_timer as timer
 
-class Model(object):
+class Model(BaseAgent):
     def __init__(self, static_policy=False, env=None, config=None):
         super(Model, self).__init__()
         self.device = config.device
@@ -142,7 +143,8 @@ def update(self, s, a, r, s_, frame=0):
         self.optimizer.step()
 
         self.update_target_model()
-        return loss.item()
+        self.save_loss(loss.item())
+        self.save_sigma_param_magnitudes()
 
     def get_action(self, s, eps=0.1):
         with torch.no_grad():
diff --git a/devel.py b/devel.py
index 31ffbc1..013a7a3 100644
--- a/devel.py
+++ b/devel.py
@@ -61,17 +61,22 @@ config.SEQUENCE_LENGTH=8
 
 
 
-def plot(frame_idx, rewards, losses, elapsed_time):
-    #clear_output(True)
-    '''plt.figure(figsize=(20,5))
+def plot(frame_idx, rewards, losses, sigma, elapsed_time):
+    clear_output(True)
+    plt.figure(figsize=(20,5))
     plt.subplot(131)
     plt.title('frame %s. reward: %s. time: %s' % (frame_idx, np.mean(rewards[-10:]), elapsed_time))
     plt.plot(rewards)
-    plt.subplot(132)
-    plt.title('loss')
-    plt.plot(losses)
-    plt.show()'''
-    print('frame %s. reward: %s. time: %s' % (frame_idx, np.mean(rewards[-10:]), elapsed_time))
+    if losses:
+        plt.subplot(132)
+        plt.title('loss')
+        plt.plot(losses)
+    if sigma:
+        plt.subplot(133)
+        plt.title('noisy param magnitude')
+        plt.plot(sigma)
+    plt.show()
+    #print('frame %s. reward: %s. time: %s' % (frame_idx, np.mean(rewards[-10:]), elapsed_time))
 
 
 if __name__=='__main__':
@@ -85,8 +90,6 @@ def plot(frame_idx, rewards, losses, elapsed_time):
     #env = wrappers.Monitor(env, 'Delete', force=True)
     model = Model(env=env, config=config)
 
-    losses = []
-    all_rewards = []
     episode_reward = 0
 
     observation = env.reset()
@@ -98,20 +101,19 @@ def plot(frame_idx, rewards, losses, elapsed_time):
         observation, reward, done, _ = env.step(action)
         observation = None if done else observation
 
-        loss = model.update(prev_observation, action, reward, observation, frame_idx)
+        model.update(prev_observation, action, reward, observation, frame_idx)
         episode_reward += reward
 
         if done:
             model.finish_nstep()
             model.reset_hx()
             observation = env.reset()
-            all_rewards.append(episode_reward)
+            model.save_reward(episode_reward)
            episode_reward = 0
-            if loss is not None:
-                losses.append(loss)
 
         if frame_idx % 10000 == 0:
-            plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))
+            plot(frame_idx, model.rewards, model.losses, model.sigma_parameter_mag, timedelta(seconds=int(timer()-start)))
+            model.save_w()
 
 
     env.close()
\ No newline at end of file
diff --git a/saved_agents/__init__.py b/saved_agents/__init__.py
new file mode 100644
index 0000000..e69de29