Updated plotting code and added model saving code
qfettes committed Jun 20, 2018
1 parent 0d8d683 commit 618d8de
Showing 11 changed files with 215 additions and 219 deletions.
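Note: the notebooks below now delegate reward/loss bookkeeping and weight saving to a shared BaseAgent class imported from agents. That class is not part of this diff; the following is only a minimal sketch of the interface the changed calls appear to assume (method names come from the calls in the diffs below; the internals, attribute names, and save path are assumptions, not code from this commit):

    import os
    import numpy as np
    import torch

    class BaseAgent(object):
        def __init__(self):
            self.model = None              # online network, set by the subclass
            self.optimizer = None          # optimizer, set by the subclass
            self.rewards = []              # one entry per finished episode
            self.losses = []               # one entry per call to save_loss
            self.sigma_parameter_mag = []  # mean |sigma| of noisy layers, if any

        def save_reward(self, reward):
            self.rewards.append(reward)

        def save_loss(self, loss):
            self.losses.append(loss)

        def save_sigma_param_magnitudes(self):
            # Only meaningful for NoisyNet agents; models without 'sigma'
            # parameters simply record nothing.
            with torch.no_grad():
                sigmas = [p.abs().mean().item()
                          for name, p in self.model.named_parameters()
                          if 'sigma' in name]
            if sigmas:
                self.sigma_parameter_mag.append(np.mean(sigmas))

        def save_w(self, path='./saved_agents'):
            # Persist network and optimizer weights (directory and file names
            # are assumptions for illustration).
            os.makedirs(path, exist_ok=True)
            torch.save(self.model.state_dict(), os.path.join(path, 'model.dump'))
            torch.save(self.optimizer.state_dict(), os.path.join(path, 'optim.dump'))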
34 changes: 20 additions & 14 deletions 01.DQN.ipynb
@@ -40,7 +40,8 @@
"import math\n",
"from utils.wrappers import make_atari, wrap_deepmind, wrap_pytorch\n",
"\n",
"from utils.hyperparameters import Config"
"from utils.hyperparameters import Config\n",
"from agents import BaseAgent"
]
},
{
@@ -164,7 +165,7 @@
"metadata": {},
"outputs": [],
"source": [
"class Model(object):\n",
"class Model(BaseAgent):\n",
" def __init__(self, static_policy=False, env=None, config=None):\n",
" super(Model, self).__init__()\n",
" self.device = config.device\n",
@@ -276,7 +277,8 @@
" self.optimizer.step()\n",
"\n",
" self.update_target_model()\n",
" return loss.item()\n",
" self.save_loss(loss.item())\n",
" self.save_sigma_param_magnitudes()\n",
"\n",
"\n",
" def get_action(self, s, eps=0.1):\n",
@@ -315,15 +317,20 @@
"metadata": {},
"outputs": [],
"source": [
"def plot(frame_idx, rewards, losses, elapsed_time):\n",
"def plot(frame_idx, rewards, losses, sigma, elapsed_time):\n",
" clear_output(True)\n",
" plt.figure(figsize=(20,5))\n",
" plt.subplot(131)\n",
" plt.title('frame %s. reward: %s. time: %s' % (frame_idx, np.mean(rewards[-10:]), elapsed_time))\n",
" plt.plot(rewards)\n",
" plt.subplot(132)\n",
" plt.title('loss')\n",
" plt.plot(losses)\n",
" if losses:\n",
" plt.subplot(132)\n",
" plt.title('loss')\n",
" plt.plot(losses)\n",
" if sigma:\n",
" plt.subplot(133)\n",
" plt.title('noisy param magnitude')\n",
" plt.plot(sigma)\n",
" plt.show()"
]
},
@@ -372,25 +379,24 @@
" observation, reward, done, _ = env.step(action)\n",
" observation = None if done else observation\n",
"\n",
" loss = model.update(prev_observation, action, reward, observation, frame_idx)\n",
" model.update(prev_observation, action, reward, observation, frame_idx)\n",
" episode_reward += reward\n",
"\n",
" if done:\n",
" model.finish_nstep()\n",
" model.reset_hx()\n",
" observation = env.reset()\n",
" all_rewards.append(episode_reward)\n",
" model.save_reward(episode_reward)\n",
" episode_reward = 0\n",
" \n",
" if np.mean(all_rewards[-10:]) > 19:\n",
" plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
" break\n",
"\n",
" if loss is not None:\n",
" losses.append(loss)\n",
"\n",
" if frame_idx % 10000 == 0:\n",
" plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
"\n",
" plot(frame_idx, model.rewards, model.losses, model.sigma_parameter_mag, timedelta(seconds=int(timer()-start)))\n",
"\n",
"model.save_w()\n",
"env.close()"
]
},
33 changes: 19 additions & 14 deletions 02.NStep_DQN.ipynb
@@ -38,7 +38,8 @@
"from networks.network_bodies import AtariBody\n",
"from utils.ReplayMemory import ExperienceReplayMemory\n",
"\n",
"from utils.hyperparameters import Config"
"from utils.hyperparameters import Config\n",
"from agents import BaseAgent"
]
},
{
@@ -94,7 +95,7 @@
"metadata": {},
"outputs": [],
"source": [
"class Model(object):\n",
"class Model(BaseAgent):\n",
" def __init__(self, static_policy=False, env=None, config=None):\n",
" super(Model, self).__init__()\n",
" self.device = config.device\n",
@@ -221,7 +222,8 @@
" self.optimizer.step()\n",
"\n",
" self.update_target_model()\n",
" return loss.item()\n",
" self.save_loss(loss.item())\n",
" self.save_sigma_param_magnitudes()\n",
"\n",
" def get_action(self, s, eps=0.1):\n",
" with torch.no_grad():\n",
@@ -266,15 +268,20 @@
"metadata": {},
"outputs": [],
"source": [
"def plot(frame_idx, rewards, losses, elapsed_time):\n",
"def plot(frame_idx, rewards, losses, sigma, elapsed_time):\n",
" clear_output(True)\n",
" plt.figure(figsize=(20,5))\n",
" plt.subplot(131)\n",
" plt.title('frame %s. reward: %s. time: %s' % (frame_idx, np.mean(rewards[-10:]), elapsed_time))\n",
" plt.plot(rewards)\n",
" plt.subplot(132)\n",
" plt.title('loss')\n",
" plt.plot(losses)\n",
" if losses:\n",
" plt.subplot(132)\n",
" plt.title('loss')\n",
" plt.plot(losses)\n",
" if sigma:\n",
" plt.subplot(133)\n",
" plt.title('noisy param magnitude')\n",
" plt.plot(sigma)\n",
" plt.show()"
]
},
@@ -327,26 +334,24 @@
" observation, reward, done, _ = env.step(action)\n",
" observation = None if done else observation\n",
"\n",
" loss = model.update(prev_observation, action, reward, observation, frame_idx)\n",
" model.update(prev_observation, action, reward, observation, frame_idx)\n",
" episode_reward += reward\n",
"\n",
" if done:\n",
" model.finish_nstep()\n",
" model.reset_hx()\n",
" observation = env.reset()\n",
" all_rewards.append(episode_reward)\n",
" model.save_reward(episode_reward)\n",
" episode_reward = 0\n",
" \n",
" if np.mean(all_rewards[-10:]) > 19:\n",
" plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
" break\n",
"\n",
" if loss is not None:\n",
" losses.append(loss)\n",
"\n",
" if frame_idx % 10000 == 0:\n",
" plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
"\n",
" plot(frame_idx, model.rewards, model.losses, model.sigma_parameter_mag, timedelta(seconds=int(timer()-start)))\n",
"\n",
"model.save_w()\n",
"env.close()"
]
},
26 changes: 15 additions & 11 deletions 03.Double_DQN.ipynb
@@ -120,15 +120,20 @@
"metadata": {},
"outputs": [],
"source": [
"def plot(frame_idx, rewards, losses, elapsed_time):\n",
"def plot(frame_idx, rewards, losses, sigma, elapsed_time):\n",
" clear_output(True)\n",
" plt.figure(figsize=(20,5))\n",
" plt.subplot(131)\n",
" plt.title('frame %s. reward: %s. time: %s' % (frame_idx, np.mean(rewards[-10:]), elapsed_time))\n",
" plt.plot(rewards)\n",
" plt.subplot(132)\n",
" plt.title('loss')\n",
" plt.plot(losses)\n",
" if losses:\n",
" plt.subplot(132)\n",
" plt.title('loss')\n",
" plt.plot(losses)\n",
" if sigma:\n",
" plt.subplot(133)\n",
" plt.title('noisy param magnitude')\n",
" plt.plot(sigma)\n",
" plt.show()"
]
},
@@ -179,25 +184,24 @@
" observation, reward, done, _ = env.step(action)\n",
" observation = None if done else observation\n",
"\n",
" loss = model.update(prev_observation, action, reward, observation, frame_idx)\n",
" model.update(prev_observation, action, reward, observation, frame_idx)\n",
" episode_reward += reward\n",
"\n",
" if done:\n",
" model.finish_nstep()\n",
" model.reset_hx()\n",
" observation = env.reset()\n",
" all_rewards.append(episode_reward)\n",
" model.save_reward(episode_reward)\n",
" episode_reward = 0\n",
" \n",
" if np.mean(all_rewards[-10:]) > 19:\n",
" plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
" break\n",
"\n",
" if loss is not None:\n",
" losses.append(loss)\n",
"\n",
" if frame_idx % 10000 == 0:\n",
" plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
"\n",
" plot(frame_idx, model.rewards, model.losses, model.sigma_parameter_mag, timedelta(seconds=int(timer()-start)))\n",
"\n",
"model.save_w()\n",
"env.close()"
]
},
25 changes: 14 additions & 11 deletions 04.Dueling_DQN.ipynb
@@ -170,15 +170,20 @@
"metadata": {},
"outputs": [],
"source": [
"def plot(frame_idx, rewards, losses, elapsed_time):\n",
"def plot(frame_idx, rewards, losses, sigma, elapsed_time):\n",
" clear_output(True)\n",
" plt.figure(figsize=(20,5))\n",
" plt.subplot(131)\n",
" plt.title('frame %s. reward: %s. time: %s' % (frame_idx, np.mean(rewards[-10:]), elapsed_time))\n",
" plt.plot(rewards)\n",
" plt.subplot(132)\n",
" plt.title('loss')\n",
" plt.plot(losses)\n",
" if losses:\n",
" plt.subplot(132)\n",
" plt.title('loss')\n",
" plt.plot(losses)\n",
" if sigma:\n",
" plt.subplot(133)\n",
" plt.title('noisy param magnitude')\n",
" plt.plot(sigma)\n",
" plt.show()"
]
},
@@ -229,26 +234,24 @@
" observation, reward, done, _ = env.step(action)\n",
" observation = None if done else observation\n",
"\n",
" loss = model.update(prev_observation, action, reward, observation, frame_idx)\n",
" model.update(prev_observation, action, reward, observation, frame_idx)\n",
" episode_reward += reward\n",
"\n",
" if done:\n",
" model.finish_nstep()\n",
" model.reset_hx()\n",
" observation = env.reset()\n",
" all_rewards.append(episode_reward)\n",
" model.save_reward(episode_reward)\n",
" episode_reward = 0\n",
" \n",
" if np.mean(all_rewards[-10:]) > 19:\n",
" plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
" break\n",
"\n",
" if loss is not None:\n",
" losses.append(loss)\n",
"\n",
" if frame_idx % 10000 == 0:\n",
" plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
"\n",
" plot(frame_idx, model.rewards, model.losses, model.sigma_parameter_mag, timedelta(seconds=int(timer()-start)))\n",
"\n",
"model.save_w()\n",
"env.close()"
]
},
25 changes: 14 additions & 11 deletions 05.DQN-NoisyNets.ipynb
@@ -243,15 +243,20 @@
"metadata": {},
"outputs": [],
"source": [
"def plot(frame_idx, rewards, losses, elapsed_time):\n",
"def plot(frame_idx, rewards, losses, sigma, elapsed_time):\n",
" clear_output(True)\n",
" plt.figure(figsize=(20,5))\n",
" plt.subplot(131)\n",
" plt.title('frame %s. reward: %s. time: %s' % (frame_idx, np.mean(rewards[-10:]), elapsed_time))\n",
" plt.plot(rewards)\n",
" plt.subplot(132)\n",
" plt.title('loss')\n",
" plt.plot(losses)\n",
" if losses:\n",
" plt.subplot(132)\n",
" plt.title('loss')\n",
" plt.plot(losses)\n",
" if sigma:\n",
" plt.subplot(133)\n",
" plt.title('noisy param magnitude')\n",
" plt.plot(sigma)\n",
" plt.show()"
]
},
@@ -300,26 +305,24 @@
" observation, reward, done, _ = env.step(action)\n",
" observation = None if done else observation\n",
"\n",
" loss = model.update(prev_observation, action, reward, observation, frame_idx)\n",
" model.update(prev_observation, action, reward, observation, frame_idx)\n",
" episode_reward += reward\n",
"\n",
" if done:\n",
" model.finish_nstep()\n",
" model.reset_hx()\n",
" observation = env.reset()\n",
" all_rewards.append(episode_reward)\n",
" model.save_reward(episode_reward)\n",
" episode_reward = 0\n",
" \n",
" if np.mean(all_rewards[-10:]) > 19:\n",
" plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
" break\n",
"\n",
" if loss is not None:\n",
" losses.append(loss)\n",
"\n",
" if frame_idx % 10000 == 0:\n",
" plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
"\n",
" plot(frame_idx, model.rewards, model.losses, model.sigma_parameter_mag, timedelta(seconds=int(timer()-start)))\n",
"\n",
"model.save_w()\n",
"env.close()"
]
},
25 changes: 14 additions & 11 deletions 06.DQN_PriorityReplay.ipynb
@@ -221,15 +221,20 @@
"metadata": {},
"outputs": [],
"source": [
"def plot(frame_idx, rewards, losses, elapsed_time):\n",
"def plot(frame_idx, rewards, losses, sigma, elapsed_time):\n",
" clear_output(True)\n",
" plt.figure(figsize=(20,5))\n",
" plt.subplot(131)\n",
" plt.title('frame %s. reward: %s. time: %s' % (frame_idx, np.mean(rewards[-10:]), elapsed_time))\n",
" plt.plot(rewards)\n",
" plt.subplot(132)\n",
" plt.title('loss')\n",
" plt.plot(losses)\n",
" if losses:\n",
" plt.subplot(132)\n",
" plt.title('loss')\n",
" plt.plot(losses)\n",
" if sigma:\n",
" plt.subplot(133)\n",
" plt.title('noisy param magnitude')\n",
" plt.plot(sigma)\n",
" plt.show()"
]
},
@@ -280,26 +285,24 @@
" observation, reward, done, _ = env.step(action)\n",
" observation = None if done else observation\n",
"\n",
" loss = model.update(prev_observation, action, reward, observation, frame_idx)\n",
" model.update(prev_observation, action, reward, observation, frame_idx)\n",
" episode_reward += reward\n",
"\n",
" if done:\n",
" model.finish_nstep()\n",
" model.reset_hx()\n",
" observation = env.reset()\n",
" all_rewards.append(episode_reward)\n",
" model.save_reward(episode_reward)\n",
" episode_reward = 0\n",
" \n",
" if np.mean(all_rewards[-10:]) > 19:\n",
" plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
" break\n",
"\n",
" if loss is not None:\n",
" losses.append(loss)\n",
"\n",
" if frame_idx % 10000 == 0:\n",
" plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
"\n",
" plot(frame_idx, model.rewards, model.losses, model.sigma_parameter_mag, timedelta(seconds=int(timer()-start)))\n",
"\n",
"model.save_w()\n",
"env.close()"
]
},
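For reference, restoring the weights written by model.save_w() for later evaluation could look roughly like the following sketch. The './saved_agents/model.dump' path and file name are assumptions (they mirror the BaseAgent sketch above, not this commit), and Model, env, and config are the objects already constructed in each notebook:

    import torch

    # Rebuild the agent exactly as in the notebook, then restore its weights.
    model = Model(env=env, config=config)
    state_dict = torch.load('./saved_agents/model.dump', map_location=config.device)
    model.model.load_state_dict(state_dict)  # model.model is the online network
    model.model.eval()                       # disable training-only behaviour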