diff --git a/README.md b/README.md
index db1a1aa..9fcb044 100644
--- a/README.md
+++ b/README.md
@@ -13,6 +13,7 @@ Relevant Papers:
 8. Rainbow: Combining Improvements in Deep Reinforcement Learning [[Publication]](https://arxiv.org/abs/1710.02298)[[code]](https://github.com/qfettes/DeepRL-Tutorials/blob/master/8.Rainbow.ipynb)
 9. Distributional Reinforcement Learning with Quantile Regression [[Publication]](https://arxiv.org/abs/1710.10044)[[code]](https://github.com/qfettes/DeepRL-Tutorials/blob/master/9.QuantileRegression-DQN.ipynb)
 10. Rainbow with Quantile Regression [[code]](https://github.com/qfettes/DeepRL-Tutorials/blob/master/10.Quantile-Rainbow.ipynb)
+11. Deep Recurrent Q-Learning for Partially Observable MDPs [[Publication]](https://arxiv.org/abs/1507.06527)[[code]](https://github.com/qfettes/DeepRL-Tutorials/blob/master/11.DRQN.ipynb)
 
 Requirements:
 
diff --git a/agents/DQN.py b/agents/DQN.py
index 20c4a51..c9223d3 100644
--- a/agents/DQN.py
+++ b/agents/DQN.py
@@ -173,4 +173,7 @@ def finish_nstep(self):
 
     def huber(self, x):
         cond = (x.abs() < 1.0).to(torch.float)
-        return 0.5 * x.pow(2) * cond + (x.abs() - 0.5) * (1 - cond)
\ No newline at end of file
+        return 0.5 * x.pow(2) * cond + (x.abs() - 0.5) * (1 - cond)
+
+    def reset_hx(self):
+        pass
\ No newline at end of file
diff --git a/agents/DRQN.py b/agents/DRQN.py
index 064c3e7..1703f59 100644
--- a/agents/DRQN.py
+++ b/agents/DRQN.py
@@ -3,8 +3,7 @@
 import torch
 
 from agents.DQN import Model as DQN_Agent
-from networks.network_bodies import RecurrentSimpleBody
-from networks.networks import DQN
+from networks.networks import DRQN
 from utils.ReplayMemory import RecurrentExperienceReplayMemory
 from utils.hyperparameters import SEQUENCE_LENGTH, device
 
@@ -14,12 +13,12 @@ def __init__(self, static_policy=False, env=None):
 
         super(Model, self).__init__(static_policy, env)
 
-        self.seq = [np.zeros(self.num_feats) for j in range(self.sequence_length)]
+        self.reset_hx()
 
     def declare_networks(self):
-        self.model = DQN(self.env.observation_space.shape, self.env.action_space.n, noisy=self.noisy, sigma_init=self.sigma_init, body=RecurrentSimpleBody)
-        self.target_model = DQN(self.env.observation_space.shape, self.env.action_space.n, noisy=self.noisy, sigma_init=self.sigma_init, body=RecurrentSimpleBody)
+        self.model = DRQN(self.env.observation_space.shape, self.env.action_space.n, noisy=self.noisy, sigma_init=self.sigma_init)
+        self.target_model = DRQN(self.env.observation_space.shape, self.env.action_space.n, noisy=self.noisy, sigma_init=self.sigma_init)
 
     def declare_memory(self):
         self.memory = RecurrentExperienceReplayMemory(self.experience_replay_size, self.sequence_length)
 
@@ -27,36 +26,66 @@ def declare_memory(self):
 
     def prep_minibatch(self):
         transitions, indices, weights = self.memory.sample(self.batch_size)
-        transitions = [trans for seq in transitions for trans in seq] #flatten to prepare
 
         batch_state, batch_action, batch_reward, batch_next_state = zip(*transitions)
 
-        batch_state = torch.cat(batch_state).view(self.batch_size, self.sequence_length, -1).transpose(0, 1)
-        batch_action = torch.cat(batch_action).view(self.batch_size, self.sequence_length, -1).transpose(0, 1)[self.sequence_length-1]
-        batch_reward = torch.cat(batch_reward).view(self.batch_size, self.sequence_length, -1).transpose(0, 1)[self.sequence_length-1]
+        batch_state = torch.tensor(batch_state, device=device, dtype=torch.float).view(self.batch_size, self.sequence_length, -1)
+        batch_action = torch.tensor(batch_action, device=device, dtype=torch.long).view(self.batch_size, self.sequence_length, -1)[:,self.sequence_length-1,:]
+        batch_reward = torch.tensor(batch_reward, device=device, dtype=torch.float).view(self.batch_size, self.sequence_length, -1)[:,self.sequence_length-1,:]
 
         #get set of next states for end of each sequence
-        batch_next_state = next_states = tuple([batch_next_state[i] for i in range(len(batch_next_state)) if (i+1)%(self.sequence_length)==0])
+        batch_next_state = tuple([batch_next_state[i] for i in range(len(batch_next_state)) if (i+1)%(self.sequence_length)==0])
 
         non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch_next_state)), device=device, dtype=torch.uint8)
         try: #sometimes all next states are false, especially with nstep returns
-            non_final_next_states = torch.cat([s for s in batch_next_state if s is not None])
+            non_final_next_states = torch.tensor([s for s in batch_next_state if s is not None], device=device, dtype=torch.float).unsqueeze(dim=1)
            empty_next_state_values = False
         except:
             empty_next_state_values = True
 
-        non_final_next_states = torch.cat((batch_state[1:, non_final_mask, :], non_final_next_states.unsqueeze(dim=0)))
+        non_final_next_states = torch.cat([batch_state[non_final_mask, 1:, :], non_final_next_states], dim=1)
 
         return batch_state, batch_action, batch_reward, non_final_next_states, non_final_mask, empty_next_state_values, indices, weights
 
+    def compute_loss(self, batch_vars):
+        batch_state, batch_action, batch_reward, non_final_next_states, non_final_mask, empty_next_state_values, indices, weights = batch_vars
+
+        #estimate
+        self.model.sample_noise()
+        current_q_values, hx = self.model(batch_state)
+        hx = hx[non_final_mask]
+        current_q_values = current_q_values.gather(1, batch_action)
+
+        #target
+        with torch.no_grad():
+            max_next_q_values = torch.zeros(self.batch_size, device=device, dtype=torch.float).unsqueeze(dim=1)
+            if not empty_next_state_values:
+                max_next_action = self.get_max_next_state_action(non_final_next_states, hx)
+                self.target_model.sample_noise()
+                max_next, _ = self.target_model(non_final_next_states, hx)
+                max_next_q_values[non_final_mask] = max_next.gather(1, max_next_action)
+            expected_q_values = batch_reward + ((self.gamma**self.nsteps)*max_next_q_values)
+
+        diff = (expected_q_values - current_q_values)
+        loss = self.huber(diff)
+        loss = loss.mean()
+
+        return loss
+
     def get_action(self, s, eps=0.1):
         with torch.no_grad():
-            self.seq.pop(0)
-            self.seq.append(s)
             if np.random.random() >= eps or self.static_policy or self.noisy:
-                X = torch.tensor(self.seq, device=device, dtype=torch.float)
+                X = torch.tensor([[s]], device=device, dtype=torch.float)
                 self.model.sample_noise()
-                a = self.model(X).max(1)[1].view(1, 1)
+                a, self.action_hx = self.model(X, self.action_hx)
+                a = a.max(1)[1]
                 return a.item()
             else:
                 return np.random.randint(0, self.num_actions)
+
+    def get_max_next_state_action(self, next_states, hx):
+        max_next, _ = self.target_model(next_states, hx)
+        return max_next.max(dim=1)[1].view(-1, 1)
+
+    def reset_hx(self):
+        self.action_hx = self.model.init_hidden(1)
+
\ No newline at end of file
diff --git a/devel.py b/devel.py
index 2917549..dddc553 100644
--- a/devel.py
+++ b/devel.py
@@ -10,20 +10,20 @@
 from utils.wrappers import *
 from utils.hyperparameters import *
-from agents.Categorical_DQN import Model
+from agents.DRQN import Model
 
 
 def plot(frame_idx, rewards, losses, elapsed_time):
     #clear_output(True)
-    plt.figure(figsize=(20,5))
+    '''plt.figure(figsize=(20,5))
     plt.subplot(131)
     plt.title('frame %s. reward: %s. time: %s' % (frame_idx, np.mean(rewards[-10:]), elapsed_time))
     plt.plot(rewards)
     plt.subplot(132)
     plt.title('loss')
     plt.plot(losses)
-    plt.show()
-    #print('frame %s. reward: %s. time: %s' % (frame_idx, np.mean(rewards[-10:]), elapsed_time))
+    plt.show()'''
+    print('frame %s. reward: %s. time: %s' % (frame_idx, np.mean(rewards[-10:]), elapsed_time))
 
 
 if __name__=='__main__':
@@ -55,6 +55,7 @@ def plot(frame_idx, rewards, losses, elapsed_time):
 
         if done:
             model.finish_nstep()
+            model.reset_hx()
             observation = env.reset()
             all_rewards.append(episode_reward)
             episode_reward = 0
diff --git a/networks/network_bodies.py b/networks/network_bodies.py
index bbae156..fcf421a 100644
--- a/networks/network_bodies.py
+++ b/networks/network_bodies.py
@@ -52,44 +52,4 @@ def feature_size(self):
 
     def sample_noise(self):
         if self.noisy:
-            self.fc1.sample_noise()
-
-class RecurrentSimpleBody(nn.Module):
-    def __init__(self, input_shape, num_actions, noisy=False, sigma_init=0.5, lstm_size=128, bidirectional=False):
-        super(RecurrentSimpleBody, self).__init__()
-
-        self.input_shape = input_shape
-        self.num_actions = num_actions
-        self.noisy=noisy
-        self.lstm_size = lstm_size
-        self.bidirectional = bidirectional
-
-        self.num_directions = 2 if self.bidirectional else 1
-
-        self.fc1 = nn.Linear(input_shape[0], 128) if not self.noisy else NoisyLinear(input_shape[0], 128, sigma_init)
-        self.gru = nn.GRUCell(128, self.lstm_size)
-
-    def forward(self, x):
-        batch_size = x.size(0)
-
-        hidden = self.init_hidden(batch_size)
-        feats = self.fc1(x[0])
-        hidden = self.gru(feats)
-        print(hidden.shape)
-        for i in range(1, batch_size):
-            feats = self.fc1(x[i])
-            hidden = self.gru(feats, hidden)
-
-        return hidden
-
-    def feature_size(self):
-        #return self.fc1(torch.zeros(1, *self.input_shape)).view(1, -1).size(1)
-        return self.lstm_size
-
-    def sample_noise(self):
-        if self.noisy:
-            self.fc1.sample_noise()
-
-    def init_hidden(self, batch_size):
-        #return (torch.zeros(self.lstm_layers*self.num_directions, batch_size, self.lstm_size, device=device, dtype=torch.float), torch.zeros(self.lstm_layers*self.num_directions, batch_size, self.lstm_size, device=device, dtype=torch.float))
-        return torch.zeros(4, self.lstm_size, device=device, dtype=torch.float)
\ No newline at end of file
+            self.fc1.sample_noise()
\ No newline at end of file
diff --git a/networks/networks.py b/networks/networks.py
index 07a8300..39a3156 100644
--- a/networks/networks.py
+++ b/networks/networks.py
@@ -4,6 +4,7 @@
 
 from networks.layers import NoisyLinear
 from networks.network_bodies import SimpleBody, AtariBody
+from utils.hyperparameters import device
 
 class DQN(nn.Module):
     def __init__(self, input_shape, num_actions, noisy=False, sigma_init=0.5, body=SimpleBody):
@@ -203,4 +204,43 @@ def sample_noise(self):
         self.adv1.sample_noise()
         self.adv2.sample_noise()
         self.val1.sample_noise()
-        self.val2.sample_noise()
\ No newline at end of file
+        self.val2.sample_noise()
+
+
+########Recurrent Architectures#########
+
+class DRQN(nn.Module):
+    def __init__(self, input_shape, num_actions, noisy=False, sigma_init=0.5, body=SimpleBody, gru_size=512, bidirectional=False):
+        super(DRQN, self).__init__()
+
+        self.input_shape = input_shape
+        self.num_actions = num_actions
+        self.noisy=noisy
+        self.gru_size = gru_size
+        self.bidirectional = bidirectional
+        self.num_directions = 2 if self.bidirectional else 1
+
+        self.body = body(input_shape, num_actions, noisy, sigma_init)
+        self.gru = nn.GRUCell(self.body.feature_size(), self.gru_size)
+        self.fc2 = nn.Linear(self.gru_size, self.num_actions) if not self.noisy else NoisyLinear(self.gru_size, self.num_actions, sigma_init)
+
+    def forward(self, x, hx=None):
+        batch_size = x.size(0)
+        x = x.transpose(0, 1)
+
+        hidden = self.init_hidden(batch_size) if hx is None else hx
+        for i in range(x.size(0)):
+            feats = self.body(x[i]).view(batch_size, -1)
+            hidden = self.gru(feats, hidden)
+
+        x = self.fc2(hidden)
+
+        return x, hidden
+
+    def sample_noise(self):
+        if self.noisy:
+            self.body.sample_noise()
+            self.fc2.sample_noise()
+
+    def init_hidden(self, batch_size):
+        return torch.zeros(batch_size, self.gru_size, device=device, dtype=torch.float)
\ No newline at end of file
diff --git a/utils/ReplayMemory.py b/utils/ReplayMemory.py
index a7c7ff4..32e8c46 100644
--- a/utils/ReplayMemory.py
+++ b/utils/ReplayMemory.py
@@ -163,6 +163,7 @@ def update_priorities(self, idxes, priorities):
             self._max_priority = max(self._max_priority, (priority+1e-5))
 
 
+#TODO: Needs fix for sampling across episodes
 class RecurrentExperienceReplayMemory:
     def __init__(self, capacity, sequence_length=10):
         self.capacity = capacity
@@ -179,14 +180,13 @@ def sample(self, batch_size):
         begin = [x-self.seq_length for x in finish]
         samp = []
         for start, end in zip(begin, finish):
-            filler = []
-            if start+1 < 0: # sampling near beginning of buffer
-                for i in range(-1*(start+1)):
-                    filler += (np.zeros_like(self.memory[0][0]), 0, 0, np.zeros_like(self.memory[0][3]))
-            final = filler+self.memory[max(start+1,0):end+1]
-            samp += final
-
-        return samp
+            final = self.memory[max(start+1,0):end+1]
+            while(len(final)