""" Parent class for DQNModels. Defines the interface for communication
with DQNAgent.
"""
import tensorflow as tf
import numpy as np
import cv2
class DQNModel(object):
    def __init__(
            self, env, learning_rate=2.5e-4, momentum=0.95, gamma=0.99, tau=0.01,
            huber_loss=True, soft_updates=True, steps_to_hard_update=10000,
            weights_to_load=None, train_dir='dqnrgb6', collect_summaries=True):
        """
        arguments:
        env -- OpenAI gym environment

        keyword arguments:
        learning_rate -- RMSProp learning rate
        momentum -- RMSProp momentum
        gamma -- Q-learning discount factor. default 0.99
        tau -- soft target update rate. default 0.01
        huber_loss -- use Huber loss. default True
        soft_updates -- use soft target updates. default True
        steps_to_hard_update -- number of steps between hard updates.
        default 10000
        weights_to_load -- path to pretrained weights. default None
        train_dir -- directory where summaries and variables are saved.
        default 'dqnrgb6'
        collect_summaries -- toggle summary collection. default True
        """
        self.num_actions = env.action_space.n
        self.gamma = gamma
        self.tau = tau
        self.soft_updates = soft_updates
        self.steps_to_hard_update = steps_to_hard_update
        self.total_steps = 0
        self.collect_summaries = collect_summaries
        self.huber_loss = huber_loss
        self.train_dir = train_dir

        with tf.variable_scope('inputs'):
            self.first_observation = tf.placeholder(
                tf.float32, shape=[None] + list(self.input_shape),
                name='first_observation')
            self.second_observation = tf.placeholder(
                tf.float32, shape=[None] + list(self.input_shape),
                name='second_observation')
            self.actions = tf.placeholder(
                tf.int32, shape=[None, 1], name='actions')
            self.rewards = tf.placeholder(
                tf.float32, shape=[None, 1], name='rewards')
            self.terminals_mask = tf.placeholder(
                tf.float32, shape=[None, 1], name='terminals_mask')
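            # Note on feed shapes: the per-sample placeholders (actions,
            # rewards, terminals_mask) are fed as column vectors of shape
            # (batch_size, 1); the observation placeholders use
            # [batch_size] + input_shape, where input_shape is supplied by
            # the subclass (see the reshaping done in train_net below).
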
        with tf.variable_scope('online_model'):
            self.online_model = self.build_net(
                self.first_observation, trainable=True)
        with tf.variable_scope('target_model'):
            self.target_model = self.build_net(
                self.second_observation, trainable=False)
        online_qs = self.online_model['outputs']
        target_qs = self.target_model['outputs']

        with tf.variable_scope('train_targets'):
            actions_mask = tf.one_hot(self.actions, self.num_actions,
                                      name='actions_mask')
            actions_mask = tf.reshape(actions_mask, (-1, self.num_actions))
            # Q(s_t, a_t) of the online network for the actions taken.
            masked_online_qs = tf.reduce_sum(actions_mask * online_qs, axis=1)
            masked_online_qs = tf.reshape(masked_online_qs, (-1, 1))
            # max_a Q_target(s_{t+1}, a), the bootstrap value.
            max_q_t_1 = tf.reduce_max(target_qs, axis=1)
            max_q_t_1 = tf.reshape(max_q_t_1, (-1, 1))
            # TD error: r_t + gamma * max_a Q_target(s_{t+1}, a) - Q(s_t, a_t).
            # The bootstrap term is masked out for terminal transitions so
            # that the target reduces to the reward alone.
            self.delta = (self.rewards
                          + self.terminals_mask * gamma * max_q_t_1
                          - masked_online_qs)
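
            # Worked example of the target above (numbers are illustrative
            # only): with gamma = 0.99, reward r_t = 1.0 and
            # max_a Q_target(s_{t+1}, a) = 2.0, a non-terminal transition
            # gives the target 1.0 + 0.99 * 2.0 = 2.98, while a terminal
            # transition masks out the bootstrap term and the target is
            # just r_t = 1.0.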

        with tf.variable_scope('main_loss'):
            if self.huber_loss:
                # Huber loss: quadratic for small TD errors, linear for large
                # ones, which bounds the gradient of outlier errors.
                per_sample_loss = tf.where(
                    tf.abs(self.delta) < 1.0,
                    0.5 * tf.square(self.delta),
                    tf.abs(self.delta) - 0.5)
            else:
                per_sample_loss = tf.square(self.delta)
            self.loss = tf.reduce_mean(per_sample_loss, name='loss')

        self.optimizer = tf.train.RMSPropOptimizer(learning_rate,
                                                   momentum=momentum,
                                                   epsilon=0.01)
        self.train = self.optimizer.minimize(self.loss)

        # The target model is updated towards the online model either in
        # soft steps every iteration or by copying the weights all at once
        # every steps_to_hard_update steps.
        with tf.variable_scope('target_updates'):
            if self.soft_updates:
                self.target_updates = self.get_soft_updates()
            else:
                self.target_updates = self.get_hard_updates()
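
        # When soft updates are used, each application of target_updates
        # moves the target weights a fraction tau towards the online weights:
        #     theta_target <- tau * theta_online + (1 - tau) * theta_target
        # Hard updates instead copy the online weights over wholesale.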

        self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=0)

        loss_summary = tf.summary.scalar('loss', self.loss)
        q_summary = tf.summary.scalar('max_q', tf.reduce_max(online_qs))
        # Stitch the stacked input frames side by side into one image so the
        # network input can be inspected in TensorBoard.
        num_imgs = self.window_size
        temp_img = tf.split(self.online_model['x'], num_imgs, axis=3)
        img = tf.concat(temp_img, 2)
        online_input_summary = tf.summary.image('online', img)
        self.summary_op = tf.summary.merge_all()

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.666)
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        self.summary_writer = tf.summary.FileWriter(train_dir, self.sess.graph)
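        # Summaries written here can be inspected with TensorBoard, e.g.
        #     tensorboard --logdir <train_dir>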

        if weights_to_load is None:
            init = tf.global_variables_initializer()
            self.sess.run(init)
        else:
            self.saver.restore(self.sess, weights_to_load)

    def get_hard_updates(self):
        """Builds operations for assigning online model weights to
        target model.
        """
        updates = []
        for key, value in self.online_model['shared_vars'].items():
            W = self.target_model['shared_vars'].get(key)
            update = W.assign(value)
            updates.append(update)
        return updates

    def get_soft_updates(self):
        """Builds operations for assigning online model weights
        incrementally to target model.
        """
        updates = []
        for key, value in self.online_model['shared_vars'].items():
            W = self.target_model['shared_vars'].get(key)
            update = W.assign(self.tau * value + (1. - self.tau) * W)
            updates.append(update)
        return updates

    def do_target_updates(self):
        self.sess.run(self.target_updates)

    def train_net(self, ob0, actions, rewards, ob1, terminals):
        """Perform one step of gradient descent training on a batch.
        Also update the target model if soft updates are enabled.
        """
        ob0s = self.reshape_input(ob0)
        ob1s = self.reshape_input(ob1)
        acs = np.reshape(actions, (-1, 1))
        res = np.reshape(rewards, (-1, 1))
        terms = np.reshape(terminals, (-1, 1))

        # Reward clipping to {-1, 0, 1} is currently disabled; rewards are
        # passed through unchanged.
        #out_res = np.zeros_like(res)
        #out_res[np.nonzero(res)] = 1. * np.sign(res[np.nonzero(res)])
        out_res = res

        # terminals is expected to be a boolean array; the mask is 1 for
        # non-terminal transitions and 0 for terminal ones.
        terminals_mask = np.invert(terms) * 1

        loss, _ = self.sess.run([self.loss, self.train], feed_dict={
            self.first_observation: ob0s,
            self.second_observation: ob1s,
            self.actions: acs,
            self.rewards: out_res,
            self.terminals_mask: terminals_mask})

        self.total_steps += 1
        if self.soft_updates:
            self.do_target_updates()
        else:
            if self.total_steps % self.steps_to_hard_update == 0:
                self.do_target_updates()

        if self.collect_summaries and self.total_steps % 100 == 0:
            summary_str = self.sess.run(self.summary_op, feed_dict={
                self.first_observation: ob0s,
                self.second_observation: ob1s,
                self.actions: acs,
                self.rewards: out_res,
                self.terminals_mask: terminals_mask})
            self.summary_writer.add_summary(summary_str, self.total_steps)

        if self.total_steps % 3000 == 0:
            # Checkpoints are written under train_dir, as documented in
            # __init__.
            self.saver.save(self.sess,
                            self.train_dir + '/model.ckpt',
                            global_step=self.total_steps)

    def get_q_value(self, observation):
        ob1s = self.reshape_input(observation)
        return self.infer_online_q(ob1s)
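
    # A minimal usage sketch (illustrative only -- `MyModel` and the
    # environment name are hypothetical; a concrete subclass has to supply
    # build_net, reshape_input, infer_online_q, input_shape and window_size):
    #
    #     import gym
    #     env = gym.make('Breakout-v0')
    #     model = MyModel(env, learning_rate=2.5e-4, soft_updates=True)
    #     q_values = model.get_q_value(observation)
    #     model.train_net(ob0, actions, rewards, ob1, terminals)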