Commit
removed noisy nets from vanilla quantile regression implementation
changed naming convention of notebooks
implemented rainbow with quantile regression
Showing 15 changed files with 3,750 additions and 6 deletions.
@@ -0,0 +1,365 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Rainbow with Quantile Regression"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from IPython.display import clear_output\n",
    "from matplotlib import pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "from timeit import default_timer as timer\n",
    "from datetime import timedelta\n",
    "import math\n",
    "\n",
    "from utils.wrappers import *\n",
    "from agents.DQN import Model as DQN_Agent\n",
    "from utils.ReplayMemory import PrioritizedReplayMemory\n",
    "from networks.layers import NoisyLinear"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "#Nstep controls\n",
    "N_STEPS=3\n",
    "\n",
    "#noisy nets variables\n",
    "SIGMA_INIT=0.5\n",
    "\n",
    "#misc agent variables\n",
    "GAMMA=0.99\n",
    "LR=1e-4\n",
    "\n",
    "#memory\n",
    "TARGET_NET_UPDATE_FREQ = 1000\n",
    "EXP_REPLAY_SIZE = 100000\n",
    "BATCH_SIZE = 32\n",
    "PRIORITY_ALPHA=0.6\n",
    "PRIORITY_BETA_START=0.4\n",
    "PRIORITY_BETA_FRAMES = 100000\n",
    "\n",
    "#Learning control variables\n",
    "LEARN_START = 10000\n",
    "MAX_FRAMES=700000\n",
    "\n",
    "#Quantile Regression Parameters\n",
    "QUANTILES=51"
   ]
  },
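  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check of what `QUANTILES` controls: with $N = 51$, the distributional head below represents each action's return by 51 quantile estimates at the fixed midpoint fractions $\\hat{\\tau}_i = \\frac{2i+1}{2N}$ (built as `cumulative_density` in the agent), i.e. roughly $0.0098, 0.0294, \\ldots, 0.9902$."
   ]
  },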
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DuelingQRDQN(nn.Module):\n",
    "    def __init__(self, input_shape, num_outputs, sigma_init=0.5, quantiles=51):\n",
    "        super(DuelingQRDQN, self).__init__()\n",
    "\n",
    "        self.input_shape = input_shape\n",
    "        self.num_actions = num_outputs\n",
    "        self.quantiles=quantiles\n",
    "\n",
    "        self.conv1 = nn.Conv2d(self.input_shape[0], 32, kernel_size=8, stride=4)\n",
    "        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)\n",
    "        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)\n",
    "\n",
    "        # noisy dueling heads: advantage outputs num_actions*quantiles values, value outputs 1*quantiles\n",
    "        self.adv1 = NoisyLinear(self.feature_size(), 512, sigma_init)\n",
    "        self.adv2 = NoisyLinear(512, self.num_actions*self.quantiles, sigma_init)\n",
    "\n",
    "        self.val1 = NoisyLinear(self.feature_size(), 512, sigma_init)\n",
    "        self.val2 = NoisyLinear(512, 1*self.quantiles, sigma_init)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.conv1(x))\n",
    "        x = F.relu(self.conv2(x))\n",
    "        x = F.relu(self.conv3(x))\n",
    "        x = x.view(x.size(0), -1)\n",
    "\n",
    "        adv = F.relu(self.adv1(x))\n",
    "        adv = self.adv2(adv).view(-1, self.num_actions, self.quantiles)\n",
    "\n",
    "        val = F.relu(self.val1(x))\n",
    "        val = self.val2(val).view(-1, 1, self.quantiles)\n",
    "\n",
    "        # combine the streams per quantile with mean-centered advantages\n",
    "        final = val + adv - adv.mean(dim=1).view(-1, 1, self.quantiles)\n",
    "\n",
    "        return final\n",
    "\n",
    "    def sample_noise(self):\n",
    "        self.adv1.sample_noise()\n",
    "        self.adv2.sample_noise()\n",
    "        self.val1.sample_noise()\n",
    "        self.val2.sample_noise()\n",
    "\n",
    "    def feature_size(self):\n",
    "        return self.conv3(self.conv2(self.conv1(torch.zeros(1, *self.input_shape)))).view(1, -1).size(1)"
   ]
  },
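  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dueling aggregation above is applied independently at every quantile index $j$: $Z_j(s, a) = V_j(s) + A_j(s, a) - \\frac{1}{|\\mathcal{A}|} \\sum_{a'} A_j(s, a')$, i.e. the usual Rainbow dueling decomposition carried over to each quantile of the output distribution."
   ]
  },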
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model(DQN_Agent):\n",
    "    def __init__(self, static_policy=False, env=None):\n",
    "        self.gamma=GAMMA\n",
    "        self.lr = LR\n",
    "        self.target_net_update_freq = TARGET_NET_UPDATE_FREQ\n",
    "        self.experience_replay_size = EXP_REPLAY_SIZE\n",
    "        self.batch_size = BATCH_SIZE\n",
    "        self.learn_start = LEARN_START\n",
    "        self.sigma_init=SIGMA_INIT\n",
    "        self.priority_beta_start = PRIORITY_BETA_START\n",
    "        self.priority_beta_frames = PRIORITY_BETA_FRAMES\n",
    "        self.priority_alpha = PRIORITY_ALPHA\n",
    "        self.num_quantiles = QUANTILES\n",
    "        # quantile midpoints tau_hat_i = (2i+1)/(2N)\n",
    "        self.cumulative_density = torch.tensor((2 * np.arange(self.num_quantiles) + 1) / (2.0 * self.num_quantiles), device=device, dtype=torch.float)\n",
    "        self.quantile_weight = 1.0 / self.num_quantiles\n",
    "\n",
    "        self.static_policy=static_policy\n",
    "        self.num_feats = env.observation_space.shape\n",
    "        self.num_actions = env.action_space.n\n",
    "        self.env = env\n",
    "\n",
    "        self.declare_networks()\n",
    "\n",
    "        self.target_model.load_state_dict(self.model.state_dict())\n",
    "        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)\n",
    "\n",
    "        self.model = self.model.to(device)\n",
    "        self.target_model.to(device)\n",
    "\n",
    "        if self.static_policy:\n",
    "            self.model.eval()\n",
    "            self.target_model.eval()\n",
    "        else:\n",
    "            self.model.train()\n",
    "            self.target_model.train()\n",
    "\n",
    "        self.update_count = 0\n",
    "\n",
    "        self.declare_memory()\n",
    "\n",
    "        self.nsteps = N_STEPS\n",
    "        self.nstep_buffer = []\n",
    "\n",
    "    def declare_networks(self):\n",
    "        self.model = DuelingQRDQN(self.num_feats, self.num_actions, sigma_init=self.sigma_init, quantiles=self.num_quantiles)\n",
    "        self.target_model = DuelingQRDQN(self.num_feats, self.num_actions, sigma_init=self.sigma_init, quantiles=self.num_quantiles)\n",
    "\n",
    "    def declare_memory(self):\n",
    "        self.memory = PrioritizedReplayMemory(self.experience_replay_size, self.priority_alpha, self.priority_beta_start, self.priority_beta_frames)\n",
    "\n",
    "    def next_distribution(self, batch_vars):\n",
    "        batch_state, batch_action, batch_reward, non_final_next_states, non_final_mask, empty_next_state_values, indices, weights = batch_vars\n",
    "\n",
    "        with torch.no_grad():\n",
    "            quantiles_next = torch.zeros((self.batch_size, self.num_quantiles), device=device, dtype=torch.float)\n",
    "            if not empty_next_state_values:\n",
    "                self.target_model.sample_noise()\n",
    "                # double DQN: the online net picks the next action, the target net evaluates its quantiles\n",
    "                max_next_action = self.get_max_next_state_action(non_final_next_states)\n",
    "                quantiles_next[non_final_mask] = self.target_model(non_final_next_states).gather(1, max_next_action).squeeze(dim=1)\n",
    "\n",
    "            quantiles_next = batch_reward + (self.gamma * quantiles_next)\n",
    "\n",
    "        return quantiles_next\n",
    "\n",
    "    def compute_loss(self, batch_vars):\n",
    "        batch_state, batch_action, batch_reward, non_final_next_states, non_final_mask, empty_next_state_values, indices, weights = batch_vars\n",
    "\n",
    "        batch_action = batch_action.unsqueeze(dim=-1).expand(-1, -1, self.num_quantiles)\n",
    "\n",
    "        self.model.sample_noise()\n",
    "        quantiles = self.model(batch_state)\n",
    "        quantiles = quantiles.gather(1, batch_action).squeeze(1)\n",
    "\n",
    "        quantiles_next = self.next_distribution(batch_vars)\n",
    "\n",
    "        # pairwise TD errors between every target quantile and every predicted quantile\n",
    "        diff = quantiles_next.t().unsqueeze(-1) - quantiles.unsqueeze(0)\n",
    "\n",
    "        # quantile Huber loss: Huber term weighted by |tau_hat - 1{diff < 0}|\n",
    "        loss = self.huber(diff) * torch.abs(self.cumulative_density.view(1, -1) - (diff < 0).to(torch.float))\n",
    "        loss = loss.transpose(0, 1)\n",
    "        self.memory.update_priorities(indices, loss.detach().mean(1).sum(-1).abs().cpu().numpy().tolist())\n",
    "        loss = loss * weights.view(self.batch_size, 1, 1)\n",
    "        loss = loss.mean(1).sum(-1).mean()\n",
    "\n",
    "        return loss\n",
    "\n",
    "    def get_action(self, s):\n",
    "        with torch.no_grad():\n",
    "            X = torch.tensor([s], device=device, dtype=torch.float)\n",
    "            self.model.sample_noise()\n",
    "            # act greedily w.r.t. the mean of each action's quantile distribution\n",
    "            a = (self.model(X) * self.quantile_weight).sum(dim=2).max(dim=1)[1]\n",
    "            return a.item()\n",
    "\n",
    "    def get_max_next_state_action(self, next_states):\n",
    "        next_dist = self.model(next_states) * self.quantile_weight\n",
    "        return next_dist.sum(dim=2).max(1)[1].view(next_states.size(0), 1, 1).expand(-1, -1, self.num_quantiles)"
   ]
  },
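  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`compute_loss` above is the quantile regression Huber loss: for pairwise TD errors $u_{ij} = r + \\gamma \\theta_j(s', a^*) - \\theta_i(s, a)$ it forms $\\rho_{\\hat{\\tau}_i}(u_{ij}) = \\left| \\hat{\\tau}_i - \\mathbf{1}\\{u_{ij} < 0\\} \\right| \\, L(u_{ij})$, where $L$ is the elementwise Huber term (`self.huber`, inherited from the base `DQN_Agent`), then averages over target quantiles $j$, sums over predicted quantiles $i$, and applies the prioritized-replay importance weights before the final batch mean. The per-sample values (before weighting) are also used to update the replay priorities."
   ]
  },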
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot(frame_idx, rewards, losses, elapsed_time):\n",
    "    clear_output(True)\n",
    "    plt.figure(figsize=(20,5))\n",
    "    plt.subplot(131)\n",
    "    plt.title('frame %s. reward: %s. time: %s' % (frame_idx, np.mean(rewards[-10:]), elapsed_time))\n",
    "    plt.plot(rewards)\n",
    "    plt.subplot(132)\n",
    "    plt.title('loss')\n",
    "    plt.plot(losses)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "start=timer()\n",
    "\n",
    "env_id = \"PongNoFrameskip-v4\"\n",
    "env = make_atari(env_id)\n",
    "env = wrap_deepmind(env, frame_stack=False)\n",
    "env = wrap_pytorch(env)\n",
    "model = Model(env=env)\n",
    "\n",
    "losses = []\n",
    "all_rewards = []\n",
    "episode_reward = 0\n",
    "\n",
    "observation = env.reset()\n",
    "for frame_idx in range(1, MAX_FRAMES + 1):\n",
    "    action = model.get_action(observation)\n",
    "    prev_observation=observation\n",
    "    observation, reward, done, _ = env.step(action)\n",
    "    observation = None if done else observation\n",
    "\n",
    "    loss = model.update(prev_observation, action, reward, observation, frame_idx)\n",
    "    episode_reward += reward\n",
    "\n",
    "    if done:\n",
    "        model.finish_nstep()\n",
    "        observation = env.reset()\n",
    "        all_rewards.append(episode_reward)\n",
    "        episode_reward = 0\n",
    "\n",
    "        if np.mean(all_rewards[-10:]) > 20:\n",
    "            plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
    "            break\n",
    "\n",
    "    if loss is not None:\n",
    "        losses.append(loss)\n",
    "\n",
    "    if frame_idx % 10000 == 0:\n",
    "        plot(frame_idx, all_rewards, losses, timedelta(seconds=int(timer()-start)))\n",
    "\n",
    "\n",
    "env.close()"
   ]
  },
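  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal evaluation sketch, assuming the trained `model` and the same wrapper stack as above (`eval_env` is just an illustrative name): it replays a few episodes with `get_action` and prints the scores. Note the noisy layers are still resampled on every call, so the policy is not fully deterministic."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rough evaluation sketch; assumes the same wrappers used for training above.\n",
    "eval_env = make_atari(env_id)\n",
    "eval_env = wrap_deepmind(eval_env, frame_stack=False)\n",
    "eval_env = wrap_pytorch(eval_env)\n",
    "\n",
    "for ep in range(3):\n",
    "    obs, ep_reward, done = eval_env.reset(), 0.0, False\n",
    "    while not done:\n",
    "        action = model.get_action(obs)\n",
    "        obs, reward, done, _ = eval_env.step(action)\n",
    "        ep_reward += reward\n",
    "    print('episode %d reward: %.1f' % (ep, ep_reward))\n",
    "\n",
    "eval_env.close()"
   ]
  },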
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}