dqn_agent.py
"""
DQN Agent for Vector Observation Learning
Example Developed By:
Michael Richardson, 2018
Project for Udacity Danaodgree in Deep Reinforcement Learning (DRL)
Code expanded and adapted from code examples provided by Udacity DRL Team, 2018.
"""
# Import Required Packages
import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
from collections import namedtuple, deque
from model import QNetwork
from replay_memory import ReplayBuffer
# Determine if CPU or GPU computation should be used
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
"""
##################################################
Agent Class
Defines DQN Agent Methods
Agent interacts with and learns from an environment.
"""
class Agent():
    """
    Initialize Agent, including:
        DQN Hyperparameters
        Local and Target State-Action Policy Networks
        Replay Memory Buffer from the ReplayBuffer class (defined in replay_memory.py)
    """

    def __init__(self, state_size, action_size, dqn_type='DQN', replay_memory_size=1e5, batch_size=64, gamma=0.99,
                 learning_rate=1e-3, target_tau=2e-3, update_rate=4, seed=0):
        """
        DQN Agent Parameters
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            dqn_type (string): either 'DQN' for vanilla DQN learning (default) or 'DDQN' for double DQN
            replay_memory_size (int): size of the replay memory buffer (typically 5e4 to 5e6)
            batch_size (int): size of the memory batch used for model updates (typically 32, 64, or 128)
            gamma (float): discount factor applied to future rewards (typically 0.95 to 0.995)
            learning_rate (float): rate of model learning (typically 1e-4 to 1e-3)
            target_tau (float): interpolation factor for the soft update of the target network
            update_rate (int): number of environment steps between learning updates
            seed (int): random seed for initialization
        """
        self.dqn_type = dqn_type
        self.state_size = state_size
        self.action_size = action_size
        self.buffer_size = int(replay_memory_size)
        self.batch_size = batch_size
        self.gamma = gamma
        self.learn_rate = learning_rate
        self.tau = target_tau
        self.update_rate = update_rate
        self.seed = random.seed(seed)

        """
        # DQN Agent Q-Network
        # For DQN training, two neural network models are employed:
        # (a) A local network that is updated every (step % update_rate == 0)
        # (b) A target network, with weights updated toward the local network at a slower (target_tau) rate.
        # The slower modulation of the target network weights operates to stabilize learning.
        """
        self.network = QNetwork(state_size, action_size, seed).to(device)
        self.target_network = QNetwork(state_size, action_size, seed).to(device)
        self.optimizer = optim.Adam(self.network.parameters(), lr=self.learn_rate)

        # Replay memory
        self.memory = ReplayBuffer(action_size, self.buffer_size, self.batch_size, seed)

        # Initialize time step (for updating every update_rate steps)
        self.t_step = 0
    ########################################################
    # STEP() method
    #
    def step(self, state, action, reward, next_state, done):
        # Save experience in replay memory
        self.memory.add(state, action, reward, next_state, done)

        # Learn every update_rate time steps
        self.t_step = (self.t_step + 1) % self.update_rate
        if self.t_step == 0:
            # If enough samples are available in memory, get a random subset and learn
            if len(self.memory) > self.batch_size:
                experiences = self.memory.sample()
                self.learn(experiences, self.gamma)
    ########################################################
    # ACT() method
    #
    def act(self, state, eps=0.0):
        """Returns actions for the given state as per the current policy.

        Params
        ======
            state (array_like): current state
            eps (float): epsilon, for epsilon-greedy action selection
        """
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        self.network.eval()
        with torch.no_grad():
            action_values = self.network(state)
        self.network.train()

        # Epsilon-greedy action selection
        if random.random() > eps:
            return np.argmax(action_values.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_size))
    ########################################################
    # LEARN() method
    # Update value parameters using a given batch of experience tuples.
    def learn(self, experiences, gamma):
        """
        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences

        # Get Q values for the chosen actions (s, a) from the local model network
        Qsa = self.network(states).gather(1, actions)

        if self.dqn_type == 'DDQN':
            # Double DQN
            # ************************
            # Select the greedy actions for s' with the local network,
            # then evaluate those actions with the target network
            Qsa_prime_actions = self.network(next_states).detach().max(1)[1].unsqueeze(1)
            Qsa_prime_targets = self.target_network(next_states).detach().gather(1, Qsa_prime_actions)
        else:
            # Regular (vanilla) DQN
            # ************************
            # Get max Q values for (s', a') from the target model
            Qsa_prime_target_values = self.target_network(next_states).detach()
            Qsa_prime_targets = Qsa_prime_target_values.max(1)[0].unsqueeze(1)

        # Compute Q targets for the current states
        Qsa_targets = rewards + (gamma * Qsa_prime_targets * (1 - dones))

        # Compute loss (error)
        loss = F.mse_loss(Qsa, Qsa_targets)

        # Minimize the loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # ------------------- update target network ------------------- #
        self.soft_update(self.network, self.target_network, self.tau)
    ########################################################
    # SOFT_UPDATE() method
    #
    def soft_update(self, local_model, target_model, tau):
        """Soft update model parameters:
            θ_target = τ*θ_local + (1 - τ)*θ_target

        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
            tau (float): interpolation parameter
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)
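

########################################################
# Usage sketch (illustrative, not part of the original project):
# a minimal epsilon-greedy training loop around the Agent class above.
# RandomWalkEnv is a hypothetical stand-in environment with vector
# observations, used only so the loop is self-contained; the episode
# count, toy reward, and epsilon schedule are assumptions, not values
# taken from this repository.
if __name__ == "__main__":

    class RandomWalkEnv:
        """Toy stand-in environment: 4-dimensional state, 2 actions."""
        def __init__(self, state_size=4, action_size=2, max_steps=50):
            self.state_size = state_size
            self.action_size = action_size
            self.max_steps = max_steps

        def reset(self):
            self.t = 0
            self.state = np.random.randn(self.state_size).astype(np.float32)
            return self.state

        def step(self, action):
            self.t += 1
            self.state = self.state + 0.1 * np.random.randn(self.state_size).astype(np.float32)
            reward = 1.0 if action == 0 else 0.0      # arbitrary toy reward
            done = self.t >= self.max_steps
            return self.state, reward, done

    env = RandomWalkEnv()
    agent = Agent(state_size=env.state_size, action_size=env.action_size, dqn_type='DDQN')

    eps, eps_end, eps_decay = 1.0, 0.01, 0.995        # assumed epsilon-greedy schedule
    for episode in range(10):
        state = env.reset()
        score = 0.0
        while True:
            action = agent.act(state, eps)                           # epsilon-greedy action
            next_state, reward, done = env.step(action)
            agent.step(state, action, reward, next_state, done)      # store and (periodically) learn
            state = next_state
            score += reward
            if done:
                break
        eps = max(eps_end, eps_decay * eps)                          # decay exploration over time
        print("Episode {:d}  score: {:.1f}  eps: {:.3f}".format(episode, score, eps))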