
Commit

better saving and plotting code
qfettes committed Jun 20, 2018
1 parent a55aada commit 0d8d683
Showing 4 changed files with 75 additions and 17 deletions.
54 changes: 54 additions & 0 deletions agents/BaseAgent.py
@@ -0,0 +1,54 @@
import numpy as np
import pickle
import os.path

import torch
import torch.optim as optim


class BaseAgent(object):
    def __init__(self):
        self.model = None
        self.target_model = None
        self.optimizer = None
        self.losses = []
        self.rewards = []
        self.sigma_parameter_mag = []

    def save_w(self):
        torch.save(self.model.state_dict(), './saved_agents/model.dump')
        torch.save(self.optimizer.state_dict(), './saved_agents/optim.dump')

    def load_w(self):
        fname_model = "./saved_agents/model.dump"
        fname_optim = "./saved_agents/optim.dump"

        if os.path.isfile(fname_model):
            self.model.load_state_dict(torch.load(fname_model))
            self.target_model.load_state_dict(self.model.state_dict())

        if os.path.isfile(fname_optim):
            self.optimizer.load_state_dict(torch.load(fname_optim))

    def save_replay(self):
        pickle.dump(self.memory, open('./saved_agents/exp_replay_agent.dump', 'wb'))

    def load_replay(self):
        fname = './saved_agents/exp_replay_agent.dump'
        if os.path.isfile(fname):
            self.memory = pickle.load(open(fname, 'rb'))

    def save_sigma_param_magnitudes(self):
        tmp = []
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                if 'sigma' in name:
                    tmp += param.data.cpu().numpy().ravel().tolist()
        if tmp:
            self.sigma_parameter_mag.append(np.mean(np.abs(np.array(tmp))))

    def save_loss(self, loss):
        self.losses.append(loss)

    def save_reward(self, reward):
        self.rewards.append(reward)
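
The helpers above assume a ./saved_agents/ directory (this commit adds one) and that a subclass sets self.model, self.target_model, self.optimizer, and self.memory before saving. Below is a hypothetical usage sketch, not part of this commit; the ToyAgent class and its stand-in network, optimizer, and replay buffer are illustrative only.

# Hypothetical usage of the BaseAgent persistence helpers (stand-in components).
import os

import torch.nn as nn
import torch.optim as optim

from agents.BaseAgent import BaseAgent


class ToyAgent(BaseAgent):
    def __init__(self):
        super(ToyAgent, self).__init__()
        self.model = nn.Linear(4, 2)          # stand-in policy network
        self.target_model = nn.Linear(4, 2)   # stand-in target network
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
        self.memory = []                      # stand-in replay buffer


if __name__ == '__main__':
    os.makedirs('./saved_agents', exist_ok=True)  # save_w/save_replay expect this directory
    agent = ToyAgent()
    agent.save_reward(1.0)                        # appends to agent.rewards
    agent.save_loss(0.5)                          # appends to agent.losses
    agent.save_w()                                # writes model.dump and optim.dump
    agent.load_w()                                # restores weights, syncs the target network
    agent.save_replay()                           # pickles agent.memory
    agent.load_replay()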
6 changes: 4 additions & 2 deletions agents/DQN.py
@@ -3,12 +3,13 @@
import torch
import torch.optim as optim

from agents.BaseAgent import BaseAgent
from networks.networks import DQN
from utils.ReplayMemory import ExperienceReplayMemory, PrioritizedReplayMemory

from timeit import default_timer as timer

class Model(object):
class Model(BaseAgent):
    def __init__(self, static_policy=False, env=None, config=None):
        super(Model, self).__init__()
        self.device = config.device
@@ -142,7 +143,8 @@ def update(self, s, a, r, s_, frame=0):
        self.optimizer.step()

        self.update_target_model()
        return loss.item()
        self.save_loss(loss.item())
        self.save_sigma_param_magnitudes()

    def get_action(self, s, eps=0.1):
        with torch.no_grad():
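
save_sigma_param_magnitudes() only records parameters whose names contain 'sigma', i.e. the learnable noise scales of NoisyNet-style layers. A minimal sketch of such a layer is shown below, assuming the independent-Gaussian-noise variant; the class and parameter names are illustrative and may differ from this repository's networks.

# Illustrative NoisyNet-style layer: weight_sigma/bias_sigma contain 'sigma'
# in their names, so BaseAgent.save_sigma_param_magnitudes() tracks them.
import torch
import torch.nn as nn
import torch.nn.functional as F


class NoisySigmaLinear(nn.Module):
    def __init__(self, in_features, out_features, sigma_init=0.017):
        super(NoisySigmaLinear, self).__init__()
        self.weight_mu = nn.Parameter(torch.empty(out_features, in_features).uniform_(-0.1, 0.1))
        self.weight_sigma = nn.Parameter(torch.full((out_features, in_features), sigma_init))
        self.bias_mu = nn.Parameter(torch.zeros(out_features))
        self.bias_sigma = nn.Parameter(torch.full((out_features,), sigma_init))

    def forward(self, x):
        # Sample fresh Gaussian noise on every forward pass.
        weight = self.weight_mu + self.weight_sigma * torch.randn_like(self.weight_sigma)
        bias = self.bias_mu + self.bias_sigma * torch.randn_like(self.bias_sigma)
        return F.linear(x, weight, bias)

As the sigma values shrink during training, the injected noise (and hence exploration) decreases, which is what the new plot panel visualizes.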
32 changes: 17 additions & 15 deletions devel.py
@@ -61,17 +61,22 @@
config.SEQUENCE_LENGTH=8


def plot(frame_idx, rewards, losses, elapsed_time):
    #clear_output(True)
    '''plt.figure(figsize=(20,5))
def plot(frame_idx, rewards, losses, sigma, elapsed_time):
    clear_output(True)
    plt.figure(figsize=(20,5))
    plt.subplot(131)
    plt.title('frame %s. reward: %s. time: %s' % (frame_idx, np.mean(rewards[-10:]), elapsed_time))
    plt.plot(rewards)
    plt.subplot(132)
    plt.title('loss')
    plt.plot(losses)
    plt.show()'''
    print('frame %s. reward: %s. time: %s' % (frame_idx, np.mean(rewards[-10:]), elapsed_time))
    if losses:
        plt.subplot(132)
        plt.title('loss')
        plt.plot(losses)
    if sigma:
        plt.subplot(133)
        plt.title('noisy param magnitude')
        plt.plot(sigma)
    plt.show()
    #print('frame %s. reward: %s. time: %s' % (frame_idx, np.mean(rewards[-10:]), elapsed_time))


if __name__=='__main__':
@@ -85,8 +90,6 @@ def plot(frame_idx, rewards, losses, elapsed_time):
    #env = wrappers.Monitor(env, 'Delete', force=True)
    model = Model(env=env, config=config)

    losses = []
    all_rewards = []
    episode_reward = 0

    observation = env.reset()
@@ -98,20 +101,19 @@ def plot(frame_idx, rewards, losses, elapsed_time):
        observation, reward, done, _ = env.step(action)
        observation = None if done else observation

        loss = model.update(prev_observation, action, reward, observation, frame_idx)
        model.update(prev_observation, action, reward, observation, frame_idx)
        episode_reward += reward

        if done:
            model.finish_nstep()
            model.reset_hx()
            observation = env.reset()
            all_rewards.append(episode_reward)
            model.save_reward(episode_reward)
            episode_reward = 0

        if loss is not None:
            losses.append(loss)

        if frame_idx % 10000 == 0:
            plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))
            plot(frame_idx, model.rewards, model.losses, model.sigma_parameter_mag, timedelta(seconds=int(timer()-start)))

    model.save_w()
    env.close()
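
The updated plot() calls clear_output() and plt.show(), so it assumes an IPython/notebook-style environment. Below is a hypothetical headless variant, not part of this commit, that writes the same three panels to disk instead; the function name and output path are assumptions.

# Hypothetical headless plotting helper: same panels as plot(), saved to a PNG.
import matplotlib
matplotlib.use('Agg')  # non-interactive backend, safe without a display
import matplotlib.pyplot as plt
import numpy as np


def plot_to_file(frame_idx, rewards, losses, sigma, elapsed_time, path='progress.png'):
    plt.figure(figsize=(20, 5))
    plt.subplot(131)
    plt.title('frame %s. reward: %s. time: %s' % (frame_idx, np.mean(rewards[-10:]), elapsed_time))
    plt.plot(rewards)
    if losses:
        plt.subplot(132)
        plt.title('loss')
        plt.plot(losses)
    if sigma:
        plt.subplot(133)
        plt.title('noisy param magnitude')
        plt.plot(sigma)
    plt.savefig(path)
    plt.close()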
Empty file added saved_agents/__init__.py
Empty file.
