
Commit

fixed plotting bug
qfettes committed Jun 20, 2018
1 parent 618d8de commit 9e9a89b
Showing 11 changed files with 13 additions and 35 deletions.
6 changes: 2 additions & 4 deletions 01.DQN.ipynb
@@ -41,7 +41,7 @@
 "from utils.wrappers import make_atari, wrap_deepmind, wrap_pytorch\n",
 "\n",
 "from utils.hyperparameters import Config\n",
-"from agents import BaseAgent"
+"from agents.BaseAgent import BaseAgent"
 ]
 },
 {
@@ -366,8 +366,6 @@
 "env = wrap_pytorch(env)\n",
 "model = Model(env=env, config=config)\n",
 "\n",
-"losses = []\n",
-"all_rewards = []\n",
 "episode_reward = 0\n",
 "\n",
 "observation = env.reset()\n",
@@ -389,7 +387,7 @@
 " model.save_reward(episode_reward)\n",
 " episode_reward = 0\n",
 " \n",
-" if np.mean(all_rewards[-10:]) > 19:\n",
+" if np.mean(model.rewards[-10:]) > 19:\n",
 " plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
 " break\n",
 "\n",
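The same fix repeats in each notebook below: the local losses and all_rewards lists go away, the early-stopping check reads the reward history the agent records for itself via model.save_reward(...), and BaseAgent is imported from its own agents.BaseAgent module. A minimal sketch of the pattern as it reads after this commit, assuming a BaseAgent that simply appends each episode return to an internal rewards list (the real class lives in agents/BaseAgent.py and is not part of this diff; the returns below are stand-in values):

import numpy as np

class BaseAgent(object):
    # Hypothetical stand-in for agents.BaseAgent: only the reward bookkeeping
    # this commit relies on (save_reward / rewards) is sketched here.
    def __init__(self):
        self.rewards = []  # one entry per finished episode

    def save_reward(self, reward):
        self.rewards.append(reward)

# Training-loop excerpt mirroring the notebooks after this commit: the
# early-stop check averages model.rewards instead of a local all_rewards list.
model = BaseAgent()
for episode_reward in [18.0, 20.0, 21.0]:  # stand-in returns; the notebooks get these from the env
    model.save_reward(episode_reward)
    if np.mean(model.rewards[-10:]) > 19:  # mean return over the last 10 episodes
        print('stopping: mean of last 10 episode rewards exceeds 19')
        break

Keeping the reward history on the agent means the training loops no longer have to thread separate all_rewards and losses lists through the notebook cells.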
6 changes: 2 additions & 4 deletions 02.NStep_DQN.ipynb
@@ -39,7 +39,7 @@
 "from utils.ReplayMemory import ExperienceReplayMemory\n",
 "\n",
 "from utils.hyperparameters import Config\n",
-"from agents import BaseAgent"
+"from agents.BaseAgent import BaseAgent"
 ]
 },
 {
@@ -321,8 +321,6 @@
 "#env = wrappers.Monitor(env, 'Delete', force=True)\n",
 "model = Model(env=env, config=config)\n",
 "\n",
-"losses = []\n",
-"all_rewards = []\n",
 "episode_reward = 0\n",
 "\n",
 "observation = env.reset()\n",
@@ -344,7 +342,7 @@
 " model.save_reward(episode_reward)\n",
 " episode_reward = 0\n",
 " \n",
-" if np.mean(all_rewards[-10:]) > 19:\n",
+" if np.mean(model.rewards[-10:]) > 19:\n",
 " plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
 " break\n",
 "\n",
4 changes: 1 addition & 3 deletions 03.Double_DQN.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,6 @@
"env = wrap_pytorch(env)\n",
"model = Model(env=env, config=config)\n",
"\n",
"losses = []\n",
"all_rewards = []\n",
"episode_reward = 0\n",
"\n",
"observation = env.reset()\n",
Expand All @@ -194,7 +192,7 @@
" model.save_reward(episode_reward)\n",
" episode_reward = 0\n",
" \n",
" if np.mean(all_rewards[-10:]) > 19:\n",
" if np.mean(model.rewards[-10:]) > 19:\n",
" plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
" break\n",
"\n",
Expand Down
4 changes: 1 addition & 3 deletions 04.Dueling_DQN.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,6 @@
"env = wrap_pytorch(env)\n",
"model = Model(env=env, config=config)\n",
"\n",
"losses = []\n",
"all_rewards = []\n",
"episode_reward = 0\n",
"\n",
"observation = env.reset()\n",
Expand All @@ -244,7 +242,7 @@
" model.save_reward(episode_reward)\n",
" episode_reward = 0\n",
" \n",
" if np.mean(all_rewards[-10:]) > 19:\n",
" if np.mean(model.rewards[-10:]) > 19:\n",
" plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
" break\n",
"\n",
Expand Down
4 changes: 1 addition & 3 deletions 05.DQN-NoisyNets.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,6 @@
"env = wrap_pytorch(env)\n",
"model = Model(env=env, config=config)\n",
"\n",
"losses = []\n",
"all_rewards = []\n",
"episode_reward = 0\n",
"\n",
"observation = env.reset()\n",
Expand All @@ -315,7 +313,7 @@
" model.save_reward(episode_reward)\n",
" episode_reward = 0\n",
" \n",
" if np.mean(all_rewards[-10:]) > 19:\n",
" if np.mean(model.rewards[-10:]) > 19:\n",
" plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
" break\n",
"\n",
Expand Down
4 changes: 1 addition & 3 deletions 06.DQN_PriorityReplay.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,6 @@
"env = wrap_pytorch(env)\n",
"model = Model(env=env, config=config)\n",
"\n",
"losses = []\n",
"all_rewards = []\n",
"episode_reward = 0\n",
"\n",
"observation = env.reset()\n",
Expand All @@ -295,7 +293,7 @@
" model.save_reward(episode_reward)\n",
" episode_reward = 0\n",
" \n",
" if np.mean(all_rewards[-10:]) > 19:\n",
" if np.mean(model.rewards[-10:]) > 19:\n",
" plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
" break\n",
"\n",
Expand Down
4 changes: 1 addition & 3 deletions 07.Categorical-DQN.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,6 @@
"env = wrap_pytorch(env)\n",
"model = Model(env=env, config=config)\n",
"\n",
"losses = []\n",
"all_rewards = []\n",
"episode_reward = 0\n",
"\n",
"observation = env.reset()\n",
Expand All @@ -299,7 +297,7 @@
" model.save_reward(episode_reward)\n",
" episode_reward = 0\n",
" \n",
" if np.mean(all_rewards[-10:]) > 19:\n",
" if np.mean(model.rewards[-10:]) > 19:\n",
" plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
" break\n",
"\n",
Expand Down
4 changes: 1 addition & 3 deletions 08.Rainbow.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,6 @@
"env = wrap_pytorch(env)\n",
"model = Model(env=env, config=config)\n",
"\n",
"losses = []\n",
"all_rewards = []\n",
"episode_reward = 0\n",
"\n",
"observation = env.reset()\n",
Expand All @@ -322,7 +320,7 @@
" model.save_reward(episode_reward)\n",
" episode_reward = 0\n",
" \n",
" if np.mean(all_rewards[-10:]) > 19:\n",
" if np.mean(model.rewards[-10:]) > 19:\n",
" plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
" break\n",
"\n",
Expand Down
4 changes: 1 addition & 3 deletions 09.QuantileRegression-DQN.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,6 @@
"env = wrap_pytorch(env)\n",
"model = Model(env=env, config=config)\n",
"\n",
"losses = []\n",
"all_rewards = []\n",
"episode_reward = 0\n",
"\n",
"observation = env.reset()\n",
Expand All @@ -286,7 +284,7 @@
" model.save_reward(episode_reward)\n",
" episode_reward = 0\n",
" \n",
" if np.mean(all_rewards[-10:]) > 19:\n",
" if np.mean(model.rewards[-10:]) > 19:\n",
" plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
" break\n",
"\n",
Expand Down
4 changes: 1 addition & 3 deletions 10.Quantile-Rainbow.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,6 @@
"env = wrap_pytorch(env)\n",
"model = Model(env=env, config=config)\n",
"\n",
"losses = []\n",
"all_rewards = []\n",
"episode_reward = 0\n",
"\n",
"observation = env.reset()\n",
Expand All @@ -303,7 +301,7 @@
" model.save_reward(episode_reward)\n",
" episode_reward = 0\n",
" \n",
" if np.mean(all_rewards[-10:]) > 19:\n",
" if np.mean(model.rewards[-10:]) > 19:\n",
" plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
" break\n",
"\n",
Expand Down
4 changes: 1 addition & 3 deletions 11.DRQN.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,6 @@
"#env = gym.make('CartPole-v1')\n",
"model = Model(env=env, config=config)\n",
"\n",
"losses = []\n",
"all_rewards = []\n",
"episode_reward = 0\n",
"\n",
"observation = env.reset()\n",
Expand All @@ -376,7 +374,7 @@
" model.save_reward(episode_reward)\n",
" episode_reward = 0\n",
" \n",
" if np.mean(all_rewards[-10:]) > 19:\n",
" if np.mean(model.rewards[-10:]) > 19:\n",
" plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
" break\n",
"\n",
Expand Down
