Functioning, efficient development DRQN implementation
qfettes committed Jun 18, 2018
1 parent 3fcc68e commit d120c1c
Showing 7 changed files with 105 additions and 71 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -13,6 +13,7 @@ Relevant Papers:
8. Rainbow: Combining Improvements in Deep Reinforcement Learning [[Publication]](https://arxiv.org/abs/1710.02298)[[code]](https://github.com/qfettes/DeepRL-Tutorials/blob/master/8.Rainbow.ipynb)
9. Distributional Reinforcement Learning with Quantile Regression [[Publication]](https://arxiv.org/abs/1710.10044)[[code]](https://github.com/qfettes/DeepRL-Tutorials/blob/master/9.QuantileRegression-DQN.ipynb)
10. Rainbow with Quantile Regression [[code]](https://github.com/qfettes/DeepRL-Tutorials/blob/master/10.Quantile-Rainbow.ipynb)
11. Deep Recurrent Q-Learning for Partially Observable MDPs [[Publication]](https://arxiv.org/abs/1507.06527)[[code]](https://github.com/qfettes/DeepRL-Tutorials/blob/master/11.DRQN.ipynb)


Requirements:
5 changes: 4 additions & 1 deletion agents/DQN.py
@@ -173,4 +173,7 @@ def finish_nstep(self):

def huber(self, x):
cond = (x.abs() < 1.0).to(torch.float)
return 0.5 * x.pow(2) * cond + (x.abs() - 0.5) * (1 - cond)
return 0.5 * x.pow(2) * cond + (x.abs() - 0.5) * (1 - cond)

def reset_hx(self):
pass
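
The new reset_hx() is deliberately a no-op on the base DQN agent: the training loop can call it unconditionally at every episode boundary, and only recurrent agents override it to clear their hidden state. A minimal sketch of that hook pattern, assuming nothing beyond what the diff shows (ToyDQN/ToyDRQN are illustrative names, not classes from this repo):

# Illustrative sketch of the reset_hx() hook; not code from agents/DQN.py.
class ToyDQN:
    def reset_hx(self):
        pass  # feed-forward agent: nothing to reset

class ToyDRQN(ToyDQN):
    def __init__(self):
        self.action_hx = None  # stands in for the GRU hidden state

    def reset_hx(self):
        self.action_hx = None  # recurrent agent: drop state carried between actions

for agent in (ToyDQN(), ToyDRQN()):
    # ... run an episode ...
    agent.reset_hx()  # safe to call on either agent, exactly as devel.py does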
61 changes: 45 additions & 16 deletions agents/DRQN.py
@@ -3,8 +3,7 @@
import torch

from agents.DQN import Model as DQN_Agent
from networks.network_bodies import RecurrentSimpleBody
from networks.networks import DQN
from networks.networks import DRQN
from utils.ReplayMemory import RecurrentExperienceReplayMemory
from utils.hyperparameters import SEQUENCE_LENGTH, device

@@ -14,49 +13,79 @@ def __init__(self, static_policy=False, env=None):

super(Model, self).__init__(static_policy, env)

self.seq = [np.zeros(self.num_feats) for j in range(self.sequence_length)]
self.reset_hx()


def declare_networks(self):
self.model = DQN(self.env.observation_space.shape, self.env.action_space.n, noisy=self.noisy, sigma_init=self.sigma_init, body=RecurrentSimpleBody)
self.target_model = DQN(self.env.observation_space.shape, self.env.action_space.n, noisy=self.noisy, sigma_init=self.sigma_init, body=RecurrentSimpleBody)
self.model = DRQN(self.env.observation_space.shape, self.env.action_space.n, noisy=self.noisy, sigma_init=self.sigma_init)
self.target_model = DRQN(self.env.observation_space.shape, self.env.action_space.n, noisy=self.noisy, sigma_init=self.sigma_init)

def declare_memory(self):
self.memory = RecurrentExperienceReplayMemory(self.experience_replay_size, self.sequence_length)

def prep_minibatch(self):
transitions, indices, weights = self.memory.sample(self.batch_size)

transitions = [trans for seq in transitions for trans in seq] #flatten to prepare
batch_state, batch_action, batch_reward, batch_next_state = zip(*transitions)

batch_state = torch.cat(batch_state).view(self.batch_size, self.sequence_length, -1).transpose(0, 1)
batch_action = torch.cat(batch_action).view(self.batch_size, self.sequence_length, -1).transpose(0, 1)[self.sequence_length-1]
batch_reward = torch.cat(batch_reward).view(self.batch_size, self.sequence_length, -1).transpose(0, 1)[self.sequence_length-1]
batch_state = torch.tensor(batch_state, device=device, dtype=torch.float).view(self.batch_size, self.sequence_length, -1)
batch_action = torch.tensor(batch_action, device=device, dtype=torch.long).view(self.batch_size, self.sequence_length, -1)[:,self.sequence_length-1,:]
batch_reward = torch.tensor(batch_reward, device=device, dtype=torch.float).view(self.batch_size, self.sequence_length, -1)[:,self.sequence_length-1,:]
#get set of next states for end of each sequence
batch_next_state = next_states = tuple([batch_next_state[i] for i in range(len(batch_next_state)) if (i+1)%(self.sequence_length)==0])
batch_next_state = tuple([batch_next_state[i] for i in range(len(batch_next_state)) if (i+1)%(self.sequence_length)==0])

non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch_next_state)), device=device, dtype=torch.uint8)
try: #sometimes all next states are false, especially with nstep returns
non_final_next_states = torch.cat([s for s in batch_next_state if s is not None])
non_final_next_states = torch.tensor([s for s in batch_next_state if s is not None], device=device, dtype=torch.float).unsqueeze(dim=1)
empty_next_state_values = False
except:
empty_next_state_values = True

non_final_next_states = torch.cat((batch_state[1:, non_final_mask, :], non_final_next_states.unsqueeze(dim=0)))
non_final_next_states = torch.cat([batch_state[non_final_mask, 1:, :], non_final_next_states], dim=1)

return batch_state, batch_action, batch_reward, non_final_next_states, non_final_mask, empty_next_state_values, indices, weights

def compute_loss(self, batch_vars):
batch_state, batch_action, batch_reward, non_final_next_states, non_final_mask, empty_next_state_values, indices, weights = batch_vars

#estimate
self.model.sample_noise()
current_q_values, hx = self.model(batch_state)
hx = hx[non_final_mask]
current_q_values = current_q_values.gather(1, batch_action)

#target
with torch.no_grad():
max_next_q_values = torch.zeros(self.batch_size, device=device, dtype=torch.float).unsqueeze(dim=1)
if not empty_next_state_values:
max_next_action = self.get_max_next_state_action(non_final_next_states, hx)
self.target_model.sample_noise()
max_next, _ = self.target_model(non_final_next_states, hx)
max_next_q_values[non_final_mask] = max_next.gather(1, max_next_action)
expected_q_values = batch_reward + ((self.gamma**self.nsteps)*max_next_q_values)

diff = (expected_q_values - current_q_values)
loss = self.huber(diff)
loss = loss.mean()

return loss

def get_action(self, s, eps=0.1):
with torch.no_grad():
self.seq.pop(0)
self.seq.append(s)
if np.random.random() >= eps or self.static_policy or self.noisy:
X = torch.tensor(self.seq, device=device, dtype=torch.float)
X = torch.tensor([[s]], device=device, dtype=torch.float)
self.model.sample_noise()
a = self.model(X).max(1)[1].view(1, 1)
a, self.action_hx = self.model(X, self.action_hx)
a = a.max(1)[1]
return a.item()
else:
return np.random.randint(0, self.num_actions)

def get_max_next_state_action(self, next_states, hx):
max_next, _ = self.target_model(next_states, hx)
return max_next.max(dim=1)[1].view(-1, 1)

def reset_hx(self):
self.action_hx = self.model.init_hidden(1)


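The reworked prep_minibatch() treats each sample as a whole sequence: the flat transitions are viewed as (batch, sequence, features), only the last step of each sequence supplies the action and reward being trained on, and the next-state batch is the same window shifted forward by one step. A small sketch of that reshaping with made-up sizes; it mirrors the indexing above but is not the agent code itself:

# Shape bookkeeping behind the new prep_minibatch(), with illustrative sizes.
import torch

batch_size, seq_len, num_feats = 2, 3, 4

# The recurrent replay memory hands back batch_size * seq_len flat transitions.
flat_states = [torch.randn(num_feats) for _ in range(batch_size * seq_len)]
batch_state = torch.stack(flat_states).view(batch_size, seq_len, -1)          # (2, 3, 4)

# Only the final step of each sequence is used for the TD target,
# hence the [:, self.sequence_length - 1, :] indexing above.
flat_actions = torch.randint(0, 2, (batch_size * seq_len, 1))
batch_action = flat_actions.view(batch_size, seq_len, -1)[:, seq_len - 1, :]  # (2, 1)

# Next-state sequences: the current window shifted left by one step, with the
# stored next state appended, matching the torch.cat(..., dim=1) above.
next_state_tail = torch.randn(batch_size, 1, num_feats)
next_state_seq = torch.cat([batch_state[:, 1:, :], next_state_tail], dim=1)   # (2, 3, 4)
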
9 changes: 5 additions & 4 deletions devel.py
@@ -10,20 +10,20 @@

from utils.wrappers import *
from utils.hyperparameters import *
from agents.Categorical_DQN import Model
from agents.DRQN import Model


def plot(frame_idx, rewards, losses, elapsed_time):
#clear_output(True)
plt.figure(figsize=(20,5))
'''plt.figure(figsize=(20,5))
plt.subplot(131)
plt.title('frame %s. reward: %s. time: %s' % (frame_idx, np.mean(rewards[-10:]), elapsed_time))
plt.plot(rewards)
plt.subplot(132)
plt.title('loss')
plt.plot(losses)
plt.show()
#print('frame %s. reward: %s. time: %s' % (frame_idx, np.mean(rewards[-10:]), elapsed_time))
plt.show()'''
print('frame %s. reward: %s. time: %s' % (frame_idx, np.mean(rewards[-10:]), elapsed_time))


if __name__=='__main__':
@@ -55,6 +55,7 @@ def plot(frame_idx, rewards, losses, elapsed_time):

if done:
model.finish_nstep()
model.reset_hx()
observation = env.reset()
all_rewards.append(episode_reward)
episode_reward = 0
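devel.py now drives the recurrent agent, and the only structural change to the loop is clearing the hidden state when an episode ends. A trimmed sketch of that boundary handling; 'CartPole-v0' is an assumed environment and the per-step learning update is elided:

# Sketch of the episode-boundary handling around model.reset_hx();
# the environment choice is an assumption and the gradient update is elided.
import gym
from agents.DRQN import Model

env = gym.make('CartPole-v0')
model = Model(env=env)

observation = env.reset()
for frame_idx in range(1, 1001):
    action = model.get_action(observation, eps=0.1)
    observation, reward, done, _ = env.step(action)
    # ... n-step bookkeeping / learning update elided ...
    if done:
        model.finish_nstep()   # flush any partially built n-step returns
        model.reset_hx()       # re-initialize action_hx so GRU state never leaks across episodes
        observation = env.reset()
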
42 changes: 1 addition & 41 deletions networks/network_bodies.py
@@ -52,44 +52,4 @@ def feature_size(self):

def sample_noise(self):
if self.noisy:
self.fc1.sample_noise()

class RecurrentSimpleBody(nn.Module):
def __init__(self, input_shape, num_actions, noisy=False, sigma_init=0.5, lstm_size=128, bidirectional=False):
super(RecurrentSimpleBody, self).__init__()

self.input_shape = input_shape
self.num_actions = num_actions
self.noisy=noisy
self.lstm_size = lstm_size
self.bidirectional = bidirectional

self.num_directions = 2 if self.bidirectional else 1

self.fc1 = nn.Linear(input_shape[0], 128) if not self.noisy else NoisyLinear(input_shape[0], 128, sigma_init)
self.gru = nn.GRUCell(128, self.lstm_size)

def forward(self, x):
batch_size = x.size(0)

hidden = self.init_hidden(batch_size)
feats = self.fc1(x[0])
hidden = self.gru(feats)
print(hidden.shape)
for i in range(1, batch_size):
feats = self.fc1(x[i])
hidden = self.gru(feats, hidden)

return hidden

def feature_size(self):
#return self.fc1(torch.zeros(1, *self.input_shape)).view(1, -1).size(1)
return self.lstm_size

def sample_noise(self):
if self.noisy:
self.fc1.sample_noise()

def init_hidden(self, batch_size):
#return (torch.zeros(self.lstm_layers*self.num_directions, batch_size, self.lstm_size, device=device, dtype=torch.float), torch.zeros(self.lstm_layers*self.num_directions, batch_size, self.lstm_size, device=device, dtype=torch.float))
return torch.zeros(4, self.lstm_size, device=device, dtype=torch.float)
self.fc1.sample_noise()
42 changes: 41 additions & 1 deletion networks/networks.py
@@ -4,6 +4,7 @@

from networks.layers import NoisyLinear
from networks.network_bodies import SimpleBody, AtariBody
from utils.hyperparameters import device

class DQN(nn.Module):
def __init__(self, input_shape, num_actions, noisy=False, sigma_init=0.5, body=SimpleBody):
@@ -203,4 +204,43 @@ def sample_noise(self):
self.adv1.sample_noise()
self.adv2.sample_noise()
self.val1.sample_noise()
self.val2.sample_noise()
self.val2.sample_noise()


########Recurrent Architectures#########

class DRQN(nn.Module):
def __init__(self, input_shape, num_actions, noisy=False, sigma_init=0.5, body=SimpleBody, gru_size=512, bidirectional=False):
super(DRQN, self).__init__()

self.input_shape = input_shape
self.num_actions = num_actions
self.noisy=noisy
self.gru_size = gru_size
self.bidirectional = bidirectional
self.num_directions = 2 if self.bidirectional else 1

self.body = body(input_shape, num_actions, noisy, sigma_init)
self.gru = nn.GRUCell(self.body.feature_size(), self.gru_size)
self.fc2 = nn.Linear(self.gru_size, self.num_actions) if not self.noisy else NoisyLinear(self.gru_size, self.num_actions, sigma_init)

def forward(self, x, hx=None):
batch_size = x.size(0)
x = x.transpose(0, 1)

hidden = self.init_hidden(batch_size) if hx is None else hx
for i in range(x.size(0)):
feats = self.body(x[i]).view(batch_size, -1)
hidden = self.gru(feats, hidden)

x = self.fc2(hidden)

return x, hidden

def sample_noise(self):
if self.noisy:
self.body.sample_noise()
self.fc2.sample_noise()

def init_hidden(self, batch_size):
return torch.zeros(batch_size, self.gru_size, device=device, dtype=torch.float)
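
The DRQN head expects a batch of whole sequences: it runs the body on each time step, unrolls a GRUCell over the sequence, and reads Q-values off the final hidden state. A usage sketch under the assumption of a CartPole-like observation (4 features, 2 actions) and the defaults shown above (SimpleBody, gru_size=512):

# Usage sketch of networks.DRQN; the sizes are illustrative assumptions.
import torch
from networks.networks import DRQN
from utils.hyperparameters import device

net = DRQN(input_shape=(4,), num_actions=2).to(device)

batch_size, seq_len = 8, 10
x = torch.randn(batch_size, seq_len, 4, device=device)    # (batch, time, features)

q_values, hidden = net(x)          # hx defaults to init_hidden(batch_size)
print(q_values.shape)              # torch.Size([8, 2])   - Q-values from the last step only
print(hidden.shape)                # torch.Size([8, 512]) - final GRUCell state

# The returned hidden state can be fed back in, which is how the agent acts
# one observation at a time while still benefiting from recurrence.
next_step = torch.randn(batch_size, 1, 4, device=device)  # sequence of length 1
q_values, hidden = net(next_step, hidden)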
16 changes: 8 additions & 8 deletions utils/ReplayMemory.py
@@ -163,6 +163,7 @@ def update_priorities(self, idxes, priorities):
self._max_priority = max(self._max_priority, (priority+1e-5))


#TODO: Needs fix for sampling across episodes
class RecurrentExperienceReplayMemory:
def __init__(self, capacity, sequence_length=10):
self.capacity = capacity
@@ -179,14 +180,13 @@ def sample(self, batch_size):
begin = [x-self.seq_length for x in finish]
samp = []
for start, end in zip(begin, finish):
filler = []
if start+1 < 0: # sampling near beginning of buffer
for i in range(-1*(start+1)):
filler += (np.zeros_like(self.memory[0][0]), 0, 0, np.zeros_like(self.memory[0][3]))
final = filler+self.memory[max(start+1,0):end+1]
samp += final

return samp
final = self.memory[max(start+1,0):end+1]
while(len(final)<self.seq_length):
final = [(np.zeros_like(self.memory[0][0]), 0, 0, np.zeros_like(self.memory[0][3]))] + final
samp+=(final)

#returns flattened version
return samp, None, None

def __len__(self):
return len(self.memory)
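The rewritten sample() guarantees fixed-length sequences: when a sampled window would start before index 0, it is left-padded with zero transitions until it reaches seq_length, and the method now returns (transitions, None, None) so its return shape lines up with how prep_minibatch() unpacks indices and weights. A standalone sketch of the padding logic using a plain list in place of the real buffer:

# Standalone sketch of the left-padding in the rewritten sample().
import numpy as np

seq_length = 4
# Pretend the buffer holds only two (s, a, r, s') transitions so far.
memory = [(np.ones(3) * i, i, float(i), np.ones(3) * (i + 1)) for i in range(2)]

end = 1                       # index of the last transition in the sampled window
start = end - seq_length      # may fall before the start of the buffer

final = memory[max(start + 1, 0):end + 1]
while len(final) < seq_length:
    # Pad on the left with zero states and zero action/reward, as sample() does.
    final = [(np.zeros_like(memory[0][0]), 0, 0, np.zeros_like(memory[0][3]))] + final

print(len(final))             # 4 - every sampled sequence has exactly seq_length entries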
