-
Notifications
You must be signed in to change notification settings - Fork 13
/
q_learn_utils.py
173 lines (146 loc) · 6.42 KB
/
q_learn_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import torch
import torch.nn as nn
import numpy as np
import cv2
import gym
import gym.spaces
import collections
class DQN(nn.Module):
    """Convolutional Q-network producing one output value per action.

    Three convolution layers (each followed by ReLU) feed a fully connected
    head with a 512-unit hidden layer.

    Args:
        input_shape: shape of a single observation without the batch
            dimension; ``input_shape[0]`` is the number of input channels.
        n_actions: width of the final linear layer.
    """

    def __init__(self, input_shape, n_actions):
        super(DQN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )
        # The flattened conv output size depends on the input resolution, so
        # it is measured rather than hard-coded.
        conv_out_size = self._get_conv_out(input_shape)
        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

    def _get_conv_out(self, shape):
        """Return the number of features the conv stack emits for one input of `shape`."""
        # Only the output size is needed: skip autograd bookkeeping for the probe.
        with torch.no_grad():
            probe = self.conv(torch.zeros(1, *shape))
        return int(probe.numel())

    def forward(self, x):
        """Run a batch `x` (batch dimension first) through conv stack and head.

        Returns:
            Tensor with one row per batch element and `n_actions` columns.
        """
        # flatten(start_dim=1) keeps the batch axis and also handles an empty batch,
        # where view(batch, -1) cannot infer the trailing dimension.
        conv_out = self.conv(x).flatten(start_dim=1)
        return self.fc(conv_out)
#---------------------------------------------------------
class MaxAndSkipEnv(gym.Wrapper):
    """Repeat each action `skip` times and return the max over the last two frames."""

    def __init__(self, env=None, skip=4):
        """Return only every `skip`-th frame

        Args:
            env: environment to wrap.
            skip: number of times each action is repeated per `step` call;
                must be at least 1.

        Raises:
            ValueError: if `skip` is smaller than 1.
        """
        super(MaxAndSkipEnv, self).__init__(env)
        # With skip < 1 the step loop would never run and `step` would have
        # no observation, done flag or info to return.
        if skip < 1:
            raise ValueError("skip must be >= 1, got {!r}".format(skip))
        # Holds the two most recent raw observations; `step` takes their
        # element-wise maximum.
        self._obs_buffer = collections.deque(maxlen=2)
        self._skip = skip

    def step(self, action):
        """Apply `action` up to `skip` times, stopping early when the episode ends.

        Returns:
            Tuple ``(max_frame, total_reward, done, info)`` where `max_frame`
            is the element-wise max of the buffered observations,
            `total_reward` is the sum of the rewards collected, and `done` and
            `info` come from the last inner step taken.
        """
        total_reward = 0.0
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            self._obs_buffer.append(obs)
            total_reward += reward
            if done:
                break
        # max pool across the stacked observation frames
        max_frame = np.max(np.stack(self._obs_buffer), axis=0)
        return max_frame, total_reward, done, info

    def reset(self):
        """Clear past frame buffer and init. to first obs. from inner env."""
        self._obs_buffer.clear()
        obs = self.env.reset()
        self._obs_buffer.append(obs)
        return obs
#---------------------------------------------------------
class FireResetEnv(gym.Wrapper):
    """Take actions 1 and 2 right after every reset so the game actually starts."""

    def __init__(self, env=None):
        """For environments where the user needs to press FIRE for the game to start."""
        super(FireResetEnv, self).__init__(env)
        # Action index 1 must be FIRE, and action index 2 (used in reset) must exist.
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def step(self, action):
        """Forward `action` to the wrapped environment unchanged."""
        return self.env.step(action)

    def reset(self):
        """Reset the wrapped env, then take action 1 followed by action 2.

        If either priming step ends the episode the environment is reset
        again.

        Returns:
            The observation matching the environment's current state: the
            result of the second priming step, or of the final reset when
            that step ended the episode.
        """
        self.env.reset()
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset()
        obs, _, done, _ = self.env.step(2)
        if done:
            # Return the fresh reset observation rather than the terminal
            # frame of the episode that just ended.
            obs = self.env.reset()
        return obs
#---------------------------------------------------------
class ProcessFrame84(gym.ObservationWrapper):
    """Convert RGB frames to 84x84x1 grayscale uint8 images."""

    def __init__(self, env=None):
        super(ProcessFrame84, self).__init__(env)
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def observation(self, obs):
        """Return the resized, grayscale version of `obs`."""
        return ProcessFrame84.process(obs)

    @staticmethod
    def process(frame):
        """Convert one raw frame to an 84x84x1 uint8 grayscale image.

        Args:
            frame: array holding 210*160*3 or 250*160*3 elements.

        Returns:
            uint8 array of shape (84, 84, 1).

        Raises:
            AssertionError: if `frame` has any other number of elements.
        """
        # Interpret the flat data as an RGB image of a supported size.
        if frame.size == 210 * 160 * 3:
            img = np.reshape(frame, [210, 160, 3]).astype(np.float32)
        elif frame.size == 250 * 160 * 3:
            img = np.reshape(frame, [250, 160, 3]).astype(np.float32)
        else:
            # Raised explicitly (not via `assert`) so the check is not
            # stripped when Python runs with -O.
            raise AssertionError("Unknown resolution.")
        # Weighted sum of the three colour channels gives the gray image.
        img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114
        # cv2.resize takes (width, height): shrink to 84 wide by 110 tall.
        resized_screen = cv2.resize(img, (84, 110), interpolation=cv2.INTER_AREA)
        # Keep rows 18..101, dropping the top and bottom bands (scores and
        # boundary, not needed here).
        x_t = resized_screen[18:102, :]
        # Add a trailing single-channel axis.
        x_t = np.reshape(x_t, [84, 84, 1])
        return x_t.astype(np.uint8)
#---------------------------------------------------------
class ImageToPyTorch(gym.ObservationWrapper):
    """Observation wrapper that moves the last image axis to the front (channels first)."""

    def __init__(self, env):
        super(ImageToPyTorch, self).__init__(env)
        src_shape = self.observation_space.shape
        # Declared shape: last axis of the wrapped space first, then its
        # first two axes.
        channels_first = (src_shape[-1], src_shape[0], src_shape[1])
        self.observation_space = gym.spaces.Box(
            low=0.0,
            high=1.0,
            shape=channels_first,
            dtype=np.float32,
        )

    def observation(self, observation):
        """Return `observation` with axis 2 moved to position 0."""
        return np.moveaxis(observation, source=2, destination=0)
#---------------------------------------------------------
class BufferWrapper(gym.ObservationWrapper):
    """Stack the most recent observations along axis 0 in a rolling buffer.

    Args:
        env: environment to wrap.
        n_steps: how many times the wrapped space's bounds are repeated along
            axis 0 to form the stacked observation space.
        dtype: dtype of the stacked observation space and of the buffer.
    """

    def __init__(self, env, n_steps, dtype=np.float32):
        super(BufferWrapper, self).__init__(env)
        self.dtype = dtype
        old_space = env.observation_space
        self.observation_space = gym.spaces.Box(old_space.low.repeat(n_steps, axis=0),
                                                old_space.high.repeat(n_steps, axis=0), dtype=dtype)

    def reset(self):
        """Zero the buffer, reset the wrapped env and return the first stacked observation."""
        self.buffer = np.zeros_like(self.observation_space.low, dtype=self.dtype)
        return self.observation(self.env.reset())

    def observation(self, observation):
        """Push `observation` into the FIFO buffer and return a snapshot of it.

        The buffer is allocated in `reset`, so `reset` must be called before
        the first observation is processed.
        """
        self.buffer[:-1] = self.buffer[1:]  # out with the old
        self.buffer[-1] = observation  # in with the new
        # Return a copy: the internal buffer is overwritten in place on the
        # next call, which would silently change observations the caller
        # has kept.
        return self.buffer.copy()
#---------------------------------------------------------
class ScaledFloatFrame(gym.ObservationWrapper):
    """Observation wrapper that converts frames to float32 and divides them by 255."""

    def observation(self, obs):
        """Return `obs` as a float32 array scaled by 1/255."""
        frame = np.array(obs)
        as_float = frame.astype(np.float32)
        return as_float / 255.0
#=========================================================
def make_env(env_name):
    """Create the gym environment `env_name` wrapped in the preprocessing stack.

    Wrappers are applied innermost to outermost: MaxAndSkipEnv, FireResetEnv,
    ProcessFrame84, ImageToPyTorch, BufferWrapper (4 steps), ScaledFloatFrame.

    Returns:
        The fully wrapped environment.
    """
    base = gym.make(env_name)
    # Repeat actions and max-pool over the most recent raw frames.
    skipped = MaxAndSkipEnv(base)
    # Take the priming actions after each reset.
    primed = FireResetEnv(skipped)
    # Grayscale, 84x84x1 frames.
    gray = ProcessFrame84(primed)
    # Channels-first layout.
    channels_first = ImageToPyTorch(gray)
    # Stack the last four frames into one observation.
    stacked = BufferWrapper(channels_first, 4)
    # Scale pixel values by 1/255.
    return ScaledFloatFrame(stacked)