removed noisy nets from vanilla quantile regression implementation
changed naming convention of notebooks
implemented rainbow with quantile regression
qfettes committed Jun 15, 2018
1 parent 0287420 commit 1e07c90
Showing 15 changed files with 3,750 additions and 6 deletions.
420 changes: 420 additions & 0 deletions 01.DQN.ipynb

372 changes: 372 additions & 0 deletions 02.NStep_DQN.ipynb

266 changes: 266 additions & 0 deletions 03.Double_DQN.ipynb

317 changes: 317 additions & 0 deletions 04.Dueling_DQN.ipynb

386 changes: 386 additions & 0 deletions 05.DQN-NoisyNets.ipynb

370 changes: 370 additions & 0 deletions 06.DQN_PriorityReplay.ipynb

370 changes: 370 additions & 0 deletions 07.Categorical-DQN.ipynb

392 changes: 392 additions & 0 deletions 08.Rainbow.ipynb

373 changes: 373 additions & 0 deletions 09.QuantileRegression-DQN.ipynb

365 changes: 365 additions & 0 deletions 10.Rainbow-QuantileRegression.ipynb
@@ -0,0 +1,365 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Rainbow with Quantile Regression"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import gym\n",
"import numpy as np\n",
"\n",
"import torch\n",
"import torch.optim as optim\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"from IPython.display import clear_output\n",
"from matplotlib import pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"from timeit import default_timer as timer\n",
"from datetime import timedelta\n",
"import math\n",
"\n",
"from utils.wrappers import *\n",
"from agents.DQN import Model as DQN_Agent\n",
"from utils.ReplayMemory import PrioritizedReplayMemory\n",
"from networks.layers import NoisyLinear"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Hyperparameters"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"#Nstep controls\n",
"N_STEPS=3\n",
"\n",
"#epsilon variables\n",
"SIGMA_INIT=0.5\n",
"\n",
"#misc agent variables\n",
"GAMMA=0.99\n",
"LR=1e-4\n",
"\n",
"#memory\n",
"TARGET_NET_UPDATE_FREQ = 1000\n",
"EXP_REPLAY_SIZE = 100000\n",
"BATCH_SIZE = 32\n",
"PRIORITY_ALPHA=0.6\n",
"PRIORITY_BETA_START=0.4\n",
"PRIORITY_BETA_FRAMES = 100000\n",
"\n",
"#Learning control variables\n",
"LEARN_START = 10000\n",
"MAX_FRAMES=700000\n",
"\n",
"#Quantile Regression Parameters\n",
"QUANTILES=51"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Network"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class DuelingQRDQN(nn.Module):\n",
" def __init__(self, input_shape, num_outputs, sigma_init=0.5, quantiles=51):\n",
" super(DuelingQRDQN, self).__init__()\n",
" \n",
" self.input_shape = input_shape\n",
" self.num_actions = num_outputs\n",
" self.quantiles=quantiles\n",
"\n",
" self.conv1 = nn.Conv2d(self.input_shape[0], 32, kernel_size=8, stride=4)\n",
" self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)\n",
" self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)\n",
"\n",
" self.adv1 = NoisyLinear(self.feature_size(), 512, sigma_init)\n",
" self.adv2 = NoisyLinear(512, self.num_actions*self.quantiles, sigma_init)\n",
"\n",
" self.val1 = NoisyLinear(self.feature_size(), 512, sigma_init)\n",
" self.val2 = NoisyLinear(512, 1*self.quantiles, sigma_init)\n",
"\n",
" \n",
" def forward(self, x):\n",
" x = F.relu(self.conv1(x))\n",
" x = F.relu(self.conv2(x))\n",
" x = F.relu(self.conv3(x))\n",
" x = x.view(x.size(0), -1)\n",
"\n",
" adv = F.relu(self.adv1(x))\n",
" adv = self.adv2(adv).view(-1, self.num_actions, self.quantiles)\n",
"\n",
" val = F.relu(self.val1(x))\n",
" val = self.val2(val).view(-1, 1, self.quantiles)\n",
"\n",
" final = val + adv - adv.mean(dim=1).view(-1, 1, self.quantiles)\n",
"\n",
" return final\n",
" \n",
" def sample_noise(self):\n",
" self.adv1.sample_noise()\n",
" self.adv2.sample_noise()\n",
" self.val1.sample_noise()\n",
" self.val2.sample_noise()\n",
" \n",
" def feature_size(self):\n",
" return self.conv3(self.conv2(self.conv1(torch.zeros(1, *self.input_shape)))).view(1, -1).size(1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Agent"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"class Model(DQN_Agent):\n",
" def __init__(self, static_policy=False, env=None):\n",
" self.gamma=GAMMA\n",
" self.lr = LR\n",
" self.target_net_update_freq = TARGET_NET_UPDATE_FREQ\n",
" self.experience_replay_size = EXP_REPLAY_SIZE\n",
" self.batch_size = BATCH_SIZE\n",
" self.learn_start = LEARN_START\n",
" self.sigma_init=SIGMA_INIT\n",
" self.priority_beta_start = PRIORITY_BETA_START\n",
" self.priority_beta_frames = PRIORITY_BETA_FRAMES\n",
" self.priority_alpha = PRIORITY_ALPHA\n",
" self.num_quantiles = QUANTILES\n",
" self.cumulative_density = torch.tensor((2 * np.arange(self.num_quantiles) + 1) / (2.0 * self.num_quantiles), device=device, dtype=torch.float) \n",
" self.quantile_weight = 1.0 / self.num_quantiles\n",
"\n",
" self.static_policy=static_policy\n",
" self.num_feats = env.observation_space.shape\n",
" self.num_actions = env.action_space.n\n",
" self.env = env\n",
"\n",
" self.declare_networks()\n",
" \n",
" self.target_model.load_state_dict(self.model.state_dict())\n",
" self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)\n",
" \n",
" self.model = self.model.to(device)\n",
" self.target_model.to(device)\n",
"\n",
" if self.static_policy:\n",
" self.model.eval()\n",
" self.target_model.eval()\n",
" else:\n",
" self.model.train()\n",
" self.target_model.train()\n",
"\n",
" self.update_count = 0\n",
"\n",
" self.declare_memory()\n",
"\n",
" self.nsteps = N_STEPS\n",
" self.nstep_buffer = []\n",
" \n",
" def declare_networks(self):\n",
" self.model = DuelingQRDQN(self.num_feats, self.num_actions, sigma_init=self.sigma_init, quantiles=self.num_quantiles)\n",
" self.target_model = DuelingQRDQN(self.num_feats, self.num_actions, sigma_init=self.sigma_init, quantiles=self.num_quantiles)\n",
" \n",
" def declare_memory(self):\n",
" self.memory = PrioritizedReplayMemory(self.experience_replay_size, self.priority_alpha, self.priority_beta_start, self.priority_beta_frames)\n",
" \n",
" def next_distribution(self, batch_vars):\n",
" batch_state, batch_action, batch_reward, non_final_next_states, non_final_mask, empty_next_state_values, indices, weights = batch_vars\n",
"\n",
" with torch.no_grad():\n",
" quantiles_next = torch.zeros((self.batch_size, self.num_quantiles), device=device, dtype=torch.float)\n",
" if not empty_next_state_values:\n",
" self.target_model.sample_noise()\n",
" max_next_action = self.get_max_next_state_action(non_final_next_states)\n",
" quantiles_next[non_final_mask] = self.target_model(non_final_next_states).gather(1, max_next_action).squeeze(dim=1)\n",
"\n",
" quantiles_next = batch_reward + (self.gamma * quantiles_next)\n",
"\n",
" return quantiles_next\n",
" \n",
" def compute_loss(self, batch_vars):\n",
" batch_state, batch_action, batch_reward, non_final_next_states, non_final_mask, empty_next_state_values, indices, weights = batch_vars\n",
"\n",
" batch_action = batch_action.unsqueeze(dim=-1).expand(-1, -1, self.num_quantiles)\n",
"\n",
" self.model.sample_noise()\n",
" quantiles = self.model(batch_state)\n",
" quantiles = quantiles.gather(1, batch_action).squeeze(1)\n",
"\n",
" quantiles_next = self.next_distribution(batch_vars)\n",
" \n",
" diff = quantiles_next.t().unsqueeze(-1) - quantiles.unsqueeze(0)\n",
"\n",
" loss = self.huber(diff) * torch.abs(self.cumulative_density.view(1, -1) - (diff < 0).to(torch.float))\n",
" loss = loss.transpose(0, 1)\n",
" self.memory.update_priorities(indices, loss.detach().mean(1).sum(-1).abs().cpu().numpy().tolist())\n",
" loss = loss * weights.view(self.batch_size, 1, 1)\n",
" loss = loss.mean(1).sum(-1).mean()\n",
"\n",
" return loss\n",
" \n",
" def get_action(self, s):\n",
" with torch.no_grad():\n",
" X = torch.tensor([s], device=device, dtype=torch.float) \n",
" self.model.sample_noise()\n",
" a = (self.model(X) * self.quantile_weight).sum(dim=2).max(dim=1)[1]\n",
" return a.item()\n",
" \n",
" def get_max_next_state_action(self, next_states):\n",
" next_dist = self.model(next_states) * self.quantile_weight\n",
" return next_dist.sum(dim=2).max(1)[1].view(next_states.size(0), 1, 1).expand(-1, -1, self.num_quantiles)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plot Results"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def plot(frame_idx, rewards, losses, elapsed_time):\n",
" clear_output(True)\n",
" plt.figure(figsize=(20,5))\n",
" plt.subplot(131)\n",
" plt.title('frame %s. reward: %s. time: %s' % (frame_idx, np.mean(rewards[-10:]), elapsed_time))\n",
" plt.plot(rewards)\n",
" plt.subplot(132)\n",
" plt.title('loss')\n",
" plt.plot(losses)\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training Loop"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"start=timer()\n",
"\n",
"env_id = \"PongNoFrameskip-v4\"\n",
"env = make_atari(env_id)\n",
"env = wrap_deepmind(env, frame_stack=False)\n",
"env = wrap_pytorch(env)\n",
"model = Model(env=env)\n",
"\n",
"losses = []\n",
"all_rewards = []\n",
"episode_reward = 0\n",
"\n",
"observation = env.reset()\n",
"for frame_idx in range(1, MAX_FRAMES + 1):\n",
" action = model.get_action(observation)\n",
" prev_observation=observation\n",
" observation, reward, done, _ = env.step(action)\n",
" observation = None if done else observation\n",
"\n",
" loss = model.update(prev_observation, action, reward, observation, frame_idx)\n",
" episode_reward += reward\n",
"\n",
" if done:\n",
" model.finish_nstep()\n",
" observation = env.reset()\n",
" all_rewards.append(episode_reward)\n",
" episode_reward = 0\n",
" \n",
" if np.mean(all_rewards[-10:]) > 20:\n",
" plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
" break\n",
"\n",
" if loss is not None:\n",
" losses.append(loss)\n",
"\n",
" if frame_idx % 10000 == 0:\n",
" plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
"\n",
"\n",
"env.close()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
