Commit b3d65a8
Training bug in 01.DQN.ipynb fixed (mentioned in issue #7). Updated plotting code. Added MSE as an option for the loss function (now the default for DQN). New results for 01.DQN.ipynb. Retesting other notebooks coming soon.

qfettes committed Feb 11, 2019
1 parent 1996cca commit b3d65a8
Showing 9 changed files with 164 additions and 144 deletions.
169 changes: 81 additions & 88 deletions 01.DQN.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions agents/BaseAgent.py
@@ -19,6 +19,9 @@ def huber(self, x):
cond = (x.abs() < 1.0).float().detach()
return 0.5 * x.pow(2) * cond + (x.abs() - 0.5) * (1.0 - cond)

def MSE(self, x):
return 0.5 * x.pow(2)

def save_w(self):
torch.save(self.model.state_dict(), './saved_agents/model.dump')
torch.save(self.optimizer.state_dict(), './saved_agents/optim.dump')
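For context, a minimal standalone sketch (not part of the repo) of how the existing Huber helper and the new MSE helper behave on a tensor of TD errors; huber and mse below mirror BaseAgent.huber and BaseAgent.MSE:

import torch

def huber(x):
    # Quadratic for |x| < 1, linear for larger errors (same form as BaseAgent.huber).
    cond = (x.abs() < 1.0).float().detach()
    return 0.5 * x.pow(2) * cond + (x.abs() - 0.5) * (1.0 - cond)

def mse(x):
    # Plain half squared error (same form as the new BaseAgent.MSE).
    return 0.5 * x.pow(2)

td_error = torch.tensor([0.1, 0.5, 2.0, 10.0])
print(huber(td_error))  # ~[0.005, 0.125, 1.5, 9.5]  -- large errors penalized linearly
print(mse(td_error))    # ~[0.005, 0.125, 2.0, 50.0] -- large errors penalized quadratically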
11 changes: 6 additions & 5 deletions agents/DQN.py
@@ -24,6 +24,7 @@ def __init__(self, static_policy=False, env=None, config=None):
self.experience_replay_size = config.EXP_REPLAY_SIZE
self.batch_size = config.BATCH_SIZE
self.learn_start = config.LEARN_START
self.update_freq = config.UPDATE_FREQ
self.sigma_init= config.SIGMA_INIT
self.priority_beta_start = config.PRIORITY_BETA_START
self.priority_beta_frames = config.PRIORITY_BETA_FRAMES
@@ -58,8 +59,8 @@ def __init__(self, static_policy=False, env=None, config=None):
self.nstep_buffer = []

def declare_networks(self):
self.model = DQN(self.num_feats, self.num_actions, noisy=self.noisy, sigma_init=self.sigma_init, body=SimpleBody)
self.target_model = DQN(self.num_feats, self.num_actions, noisy=self.noisy, sigma_init=self.sigma_init, body=SimpleBody)
self.model = DQN(self.num_feats, self.num_actions, noisy=self.noisy, sigma_init=self.sigma_init, body=AtariBody)
self.target_model = DQN(self.num_feats, self.num_actions, noisy=self.noisy, sigma_init=self.sigma_init, body=AtariBody)

def declare_memory(self):
self.memory = ExperienceReplayMemory(self.experience_replay_size) if not self.priority_replay else PrioritizedReplayMemory(self.experience_replay_size, self.priority_alpha, self.priority_beta_start, self.priority_beta_frames)
@@ -116,9 +117,9 @@ def compute_loss(self, batch_vars):
diff = (expected_q_values - current_q_values)
if self.priority_replay:
self.memory.update_priorities(indices, diff.detach().squeeze().abs().cpu().numpy().tolist())
loss = self.huber(diff).squeeze() * weights
loss = self.MSE(diff).squeeze() * weights
else:
loss = self.huber(diff)
loss = self.MSE(diff)
loss = loss.mean()

return loss
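A brief illustrative sketch (not part of the diff) of what the prioritized-replay branch of compute_loss now computes: the squared TD error is scaled per sample by the importance-sampling weights from the replay buffer, while the absolute TD errors are sent back as the new priorities. The tensors below are made up:

import torch

diff = torch.tensor([[0.2], [-1.5], [3.0]])   # expected_q_values - current_q_values for a batch of 3
weights = torch.tensor([0.9, 0.6, 0.3])       # importance-sampling weights (illustrative values)

new_priorities = diff.detach().squeeze().abs()        # [0.2, 1.5, 3.0]
per_sample = (0.5 * diff.pow(2)).squeeze() * weights  # MSE(diff).squeeze() * weights
loss = per_sample.mean()
print(per_sample)  # ~[0.018, 0.675, 1.350]
print(loss)        # ~0.681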
@@ -129,7 +130,7 @@ def update(self, s, a, r, s_, frame=0):

self.append_to_replay(s, a, r, s_)

if frame < self.learn_start:
if frame < self.learn_start or frame % self.update_freq != 0:
return None

batch_vars = self.prep_minibatch()
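A minimal sketch of the gating that the new self.update_freq adds to update() above: transitions are still stored every frame, but a gradient step only happens once the replay buffer has warmed up (LEARN_START) and then only on every UPDATE_FREQ-th frame. The new default UPDATE_FREQ is 1, so the value 4 below is purely illustrative:

learn_start = 10000
update_freq = 4  # illustrative; the default added to hyperparameters.py is 1

def should_learn(frame):
    # Mirrors: if frame < self.learn_start or frame % self.update_freq != 0: return None
    return frame >= learn_start and frame % update_freq == 0

print([f for f in range(10000, 10010) if should_learn(f)])  # [10000, 10004, 10008]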
89 changes: 39 additions & 50 deletions dqn_devel.py
@@ -1,94 +1,80 @@
import gym
import numpy as np

from IPython.display import clear_output
import matplotlib
#matplotlib.use("agg")
from matplotlib import pyplot as plt
#%matplotlib inline

from timeit import default_timer as timer
from datetime import timedelta
import math
import glob

from utils.wrappers import *
from utils.hyperparameters import Config
from agents.DQN import Model
from utils.plot import plot_reward

config = Config()

#algorithm control
config.USE_NOISY_NETS=False
config.USE_PRIORITY_REPLAY=False
config.USE_NOISY_NETS = False
config.USE_PRIORITY_REPLAY = False

#Multi-step returns
config.N_STEPS = 1

#epsilon variables
config.epsilon_start = 1.0
config.epsilon_final = 0.01
config.epsilon_decay = 500
config.epsilon_start = 1.0
config.epsilon_final = 0.01
config.epsilon_decay = 30000
config.epsilon_by_frame = lambda frame_idx: config.epsilon_final + (config.epsilon_start - config.epsilon_final) * math.exp(-1. * frame_idx / config.epsilon_decay)

#misc agent variables
config.GAMMA=0.99
config.LR=1e-4
config.GAMMA = 0.99
config.LR = 1e-4

#memory
config.TARGET_NET_UPDATE_FREQ = 128
config.EXP_REPLAY_SIZE = 10000
config.BATCH_SIZE = 32
config.PRIORITY_ALPHA=0.6
config.PRIORITY_BETA_START=0.4
config.TARGET_NET_UPDATE_FREQ = 1000
config.EXP_REPLAY_SIZE = 100000
config.BATCH_SIZE = 32

config.PRIORITY_ALPHA = 0.6
config.PRIORITY_BETA_START = 0.4
config.PRIORITY_BETA_FRAMES = 100000

#Noisy Nets
config.SIGMA_INIT=0.5
config.SIGMA_INIT = 0.5

#Learning control variables
config.LEARN_START = config.BATCH_SIZE*2
config.MAX_FRAMES=100000
config.LEARN_START = 10000
config.MAX_FRAMES = 1000000
config.UPDATE_FREQ = 1

#Categorical Params
config.ATOMS = 51
config.V_MAX = 50
config.V_MIN = 0

#Quantile Regression Parameters
config.QUANTILES=21
config.QUANTILES = 21

#DRQN Parameters
config.SEQUENCE_LENGTH=8


def plot(frame_idx, rewards, losses, sigma, elapsed_time):
clear_output(True)
plt.figure(figsize=(20,5))
plt.subplot(131)
plt.title('frame %s. reward: %s. time: %s' % (frame_idx, np.mean(rewards[-10:]), elapsed_time))
plt.plot(rewards)
if losses:
plt.subplot(132)
plt.title('loss')
plt.plot(losses)
if sigma:
plt.subplot(133)
plt.title('noisy param magnitude')
plt.plot(sigma)
plt.show()
print('frame %s. reward: %s. time: %s' % (frame_idx, np.mean(rewards[-10:]), elapsed_time))

config.SEQUENCE_LENGTH = 8

if __name__=='__main__':
start=timer()

'''env_id = "PongNoFrameskip-v4"
log_dir = "/tmp/gym/"
try:
os.makedirs(log_dir)
except OSError:
files = glob.glob(os.path.join(log_dir, '*.monitor.csv'))
for f in files:
os.remove(f)

env_id = "PongNoFrameskip-v4"
env = make_atari(env_id)
env = wrap_deepmind(env, frame_stack=False)
env = wrap_pytorch(env)'''
env = gym.make('CartPole-v0')
#env = wrappers.Monitor(env, 'Delete', force=True)
model = Model(env=env, config=config)
env = bench.Monitor(env, os.path.join(log_dir, env_id))
env = wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=True)
env = WrapPyTorch(env)
model = Model(env=env, config=config)

episode_reward = 0

@@ -111,9 +97,12 @@ def plot(frame_idx, rewards, losses, sigma, elapsed_time):
model.save_reward(episode_reward)
episode_reward = 0


if frame_idx % 10000 == 0:
plot(frame_idx, model.rewards, model.losses, model.sigma_parameter_mag, timedelta(seconds=int(timer()-start)))
try:
print('frame %s. time: %s' % (frame_idx, timedelta(seconds=int(timer()-start))))
plot_reward(log_dir, env_id, 'DQN', config.MAX_FRAMES, bin_size=10, smooth=1, time=timedelta(seconds=int(timer()-start)), ipynb=False)
except IOError:
pass

model.save_w()
env.close()
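As a sanity check on the new exploration schedule (constants copied from the config above, values rounded), epsilon_decay = 30000 anneals epsilon roughly as follows over the 1,000,000-frame run:

import math

epsilon_start, epsilon_final, epsilon_decay = 1.0, 0.01, 30000
epsilon_by_frame = lambda frame_idx: epsilon_final + (epsilon_start - epsilon_final) * math.exp(-1. * frame_idx / epsilon_decay)

for f in (0, 30000, 100000, 1000000):
    print(f, round(epsilon_by_frame(f), 3))
# 0 -> 1.0, 30000 -> 0.374, 100000 -> 0.045, 1000000 -> 0.01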
Binary file added results.png
Binary file modified saved_agents/model.dump
Binary file modified saved_agents/optim.dump
1 change: 1 addition & 0 deletions utils/hyperparameters.py
@@ -51,6 +51,7 @@ def __init__(self):
#Learning control variables
self.LEARN_START = 10000
self.MAX_FRAMES=100000
self.UPDATE_FREQ = 1

#Categorical Params
self.ATOMS = 51
35 changes: 34 additions & 1 deletion utils/plot.py
@@ -132,4 +132,37 @@ def plot(folder, game, name, num_steps, bin_size=100, smooth=1):

plt.title(game)
plt.legend(loc=4)
plt.show()
plt.show()

def plot_reward(folder, game, name, num_steps, bin_size=10, smooth=1, time=None, save_filename='results.png', ipynb=False):
matplotlib.rcParams.update({'font.size': 20})
tx, ty = load_data(folder, smooth, bin_size)

if tx is None or ty is None:
return

fig = plt.figure(figsize=(20,5))
plt.plot(tx, ty, label="{}".format(name))

tick_fractions = np.array([0.1, 0.2, 0.4, 0.6, 0.8, 1.0])
ticks = tick_fractions * num_steps
tick_names = ["{:.0e}".format(tick) for tick in ticks]
plt.xticks(ticks, tick_names)
plt.xlim(0, num_steps * 1.01)

plt.xlabel('Number of Timesteps')
plt.ylabel('Rewards')

if time is not None:
plt.title(game + ' || Last 10: ' + str(np.round(np.mean(ty[-10:]))) + ' || Elapsed Time: ' + str(time))
else:
plt.title(game + ' || Last 10: ' + str(np.round(np.mean(ty[-10:]))))
plt.legend(loc=4)
if ipynb:
plt.show()
else:
plt.savefig(save_filename)
plt.clf()
plt.close()

return np.round(np.mean(ty[-10:]))
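For reference, a hedged usage sketch of the new plot_reward helper; the log directory and environment id below are placeholder values (in dqn_devel.py they come from the monitor setup at the top of __main__):

from datetime import timedelta
from utils.plot import plot_reward

log_dir = "/tmp/gym/"            # assumed: wherever bench.Monitor writes its *.monitor.csv files
env_id = "PongNoFrameskip-v4"    # assumed

# With ipynb=False the figure is saved to results.png instead of being shown interactively.
last10 = plot_reward(log_dir, env_id, 'DQN', 1000000,
                     bin_size=10, smooth=1,
                     time=timedelta(hours=1), ipynb=False)
print(last10)  # rounded mean of the last 10 reward points, or None if no monitor logs exist yet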
