
Commit 546203b

Super DQN agent and a small change to the env
1 parent cb83517 commit 546203b

9 files changed (+3276, −1 lines)

Buffer_module.py (+163)
@@ -0,0 +1,163 @@
# -*- coding: utf-8 -*-
# Buffer_priority.py
# author: yangrui
# description:
# created: 2019-11-01T11:52:44.496Z+08:00
# last-modified: 2019-11-01T11:52:44.496Z+08:00


import numpy as np


class SumTree(object):
    data_pointer = 0

    def __init__(self, capacity):
        self.capacity = capacity  # number of priority values (leaves)
        self.tree = np.zeros(2 * capacity - 1)
        # [--------------Parent nodes-------------][-------leaves to record priority-------]
        #             size: capacity - 1                       size: capacity
        self.data = np.zeros(capacity, dtype=object)  # all transitions, stored as objects (effectively pointers)
        # [--------------data frame-------------]
        #             size: capacity

    def add(self, p, data):
        tree_idx = self.data_pointer + self.capacity - 1  # position of the corresponding leaf in the tree
        self.data[self.data_pointer] = data  # update data_frame
        self.update(tree_idx, p)  # update tree_frame

        self.data_pointer += 1
        if self.data_pointer >= self.capacity:  # overwrite oldest entries when capacity is exceeded
            self.data_pointer = 0

    def update(self, tree_idx, p):
        change = p - self.tree[tree_idx]
        self.tree[tree_idx] = p
        # then propagate the change up through the tree
        while tree_idx != 0:  # this loop is faster than the recursion in the reference code
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def get_leaf(self, v):
        """
        Tree structure and array storage:

        Tree index:
             0         -> storing priority sum
            / \
           1   2
          / \ / \
         3  4 5  6     -> storing priority for transitions

        Array type for storing:
        [0, 1, 2, 3, 4, 5, 6]
        """
        parent_idx = 0
        while True:  # the while loop is faster than the method in the reference code
            cl_idx = 2 * parent_idx + 1  # this node's left and right children
            cr_idx = cl_idx + 1
            if cl_idx >= len(self.tree):  # reached the bottom, end the search
                leaf_idx = parent_idx
                break
            else:  # downward search, always descend toward a higher-priority node
                if v <= self.tree[cl_idx]:
                    parent_idx = cl_idx
                else:
                    v -= self.tree[cl_idx]
                    parent_idx = cr_idx

        data_idx = leaf_idx - self.capacity + 1
        return leaf_idx, self.tree[leaf_idx], self.data[data_idx]

    @property
    def total_p(self):
        return self.tree[0]  # the root


class Buffer_PER(object):  # stored as ( s, a, r, s_ ) in SumTree
    epsilon = 0.01  # small amount to avoid zero priority
    alpha = 0.6  # [0~1] converts the magnitude of the TD error to a priority
    beta = 0.4  # importance-sampling exponent, annealed from its initial value to 1
    beta_increment_per_sampling = 0.001
    abs_err_upper = 1.  # clipped abs error

    def __init__(self, capacity):
        self.tree = SumTree(capacity)

    def store(self, transition):
        max_p = np.max(self.tree.tree[-self.tree.capacity:])
        if max_p == 0:
            max_p = self.abs_err_upper
        self.tree.add(max_p, transition)  # assign the current max priority to the new transition

    def sample(self, n):
        b_idx, b_memory, ISWeights = np.empty((n,), dtype=np.int32), np.empty((n, self.tree.data[0].size)), np.empty((n, 1))
        pri_seg = self.tree.total_p / n  # priority segment
        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])  # max = 1

        min_prob = np.min(self.tree.tree[-self.tree.capacity:]) / self.tree.total_p  # for later ISWeights calculation
        for i in range(n):
            a, b = pri_seg * i, pri_seg * (i + 1)
            v = np.random.uniform(a, b)
            idx, p, data = self.tree.get_leaf(v)
            prob = p / self.tree.total_p
            ISWeights[i, 0] = np.power(prob / min_prob, -self.beta)
            b_idx[i], b_memory[i, :] = idx, data

        return b_idx, b_memory, ISWeights

    def batch_update(self, tree_idx, abs_errors):
        abs_errors += self.epsilon  # avoid zero priority
        clipped_errors = np.minimum(abs_errors, self.abs_err_upper)
        ps = np.power(clipped_errors, self.alpha)
        for ti, p in zip(tree_idx, ps):
            self.tree.update(ti, p)


# top-level buffer class
class Buffer():
    def __init__(self, n_features, buffer_type='', capacity=int(1e4)):
        self.memory_size = int(capacity)  # cast to int so it can be used as an array dimension
        self.n_features = n_features
        self.type = buffer_type
        self.memory_counter = 0

        if self.type == 'priority':
            self.memory = Buffer_PER(capacity=self.memory_size)
        else:
            self.memory = np.zeros((self.memory_size, n_features * 2 + 2))

    def store(self, transition):
        self.memory_counter += 1

        if self.type == 'priority':
            self.memory.store(transition)
        else:
            index = self.memory_counter % self.memory_size
            self.memory[index, :] = transition

    def sample(self, batch_size):
        info = None
        if self.type == 'priority':
            tree_idx, batch_memory, ISWeights = self.memory.sample(batch_size)
            info = (tree_idx, ISWeights)
        else:
            sample_index = np.random.choice(self.memory_size, size=batch_size)  # assumes the buffer has already been filled
            batch_memory = self.memory[sample_index, :]

        return batch_memory, info

    def update(self, tree_idx, td_errors):
        assert self.type == 'priority'
        self.memory.batch_update(tree_idx, td_errors)
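
For reference, a minimal usage sketch of the Buffer wrapper defined above; the feature size, capacity, and random transitions are illustrative assumptions, not part of this commit:

import numpy as np
from Buffer_module import Buffer

n_features = 16  # e.g. a flattened 4x4 board
buf = Buffer(n_features, buffer_type='priority', capacity=8)

# fill every SumTree leaf with a real transition before sampling
for _ in range(8):
    s, s_ = np.random.rand(n_features), np.random.rand(n_features)
    a, r = 1, 0.5
    buf.store(np.hstack((s, [a, r], s_)))  # layout: (s, a, r, s_)

batch, info = buf.sample(4)                  # batch has shape (4, n_features * 2 + 2)
tree_idx, ISWeights = info                   # info is only populated for the priority buffer
buf.update(tree_idx, np.random.rand(4, 1))   # feed back |TD errors| as new priorities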

NN_module.py (+72)
@@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
# NN.py
# author: yangrui
# description:
# created: 2019-10-30T16:32:31.081Z+08:00
# last-modified: 2019-10-30T16:32:31.081Z+08:00


import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# CNN network
class CNN_Net(nn.Module):
    def __init__(self, input_len, output_num, conv_size=(32, 64), fc_size=(1024, 128), out_softmax=False):
        super(CNN_Net, self).__init__()
        self.input_len = input_len
        self.output_num = output_num
        self.out_softmax = out_softmax

        self.conv1 = nn.Sequential(
            nn.Conv2d(1, conv_size[0], kernel_size=3, stride=1, padding=1),
            # nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(conv_size[0], conv_size[1], kernel_size=3, stride=1, padding=1),
            # nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.fc1 = nn.Linear(conv_size[1] * self.input_len * self.input_len, fc_size[0])
        self.fc2 = nn.Linear(fc_size[0], fc_size[1])
        self.head = nn.Linear(fc_size[1], self.output_num)

    def forward(self, x):
        x = x.reshape(-1, 1, self.input_len, self.input_len)
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        output = self.head(x)
        if self.out_softmax:
            output = F.softmax(output, dim=1)  # value-function estimation should not use a softmax
        return output


# fully-connected network
class FC_Net(nn.Module):
    def __init__(self, input_num, output_num, fc_size=(1024, 128), out_softmax=False):
        super(FC_Net, self).__init__()
        self.input_num = input_num
        self.output_num = output_num
        self.out_softmax = out_softmax

        self.fc1 = nn.Linear(self.input_num, fc_size[0])
        self.fc2 = nn.Linear(fc_size[0], fc_size[1])
        self.head = nn.Linear(fc_size[1], self.output_num)

    def forward(self, x):
        x = x.reshape(-1, self.input_num)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        output = self.head(x)
        if self.out_softmax:
            output = F.softmax(output, dim=1)  # value-function estimation should not use a softmax
        return output

dqn_agent.py (+137)
@@ -0,0 +1,137 @@
# -*- coding: utf-8 -*-
# dqn_agent.py
# author: yangrui
# description:
# created: 2019-10-12T11:07:45.524Z+08:00
# last-modified: 2019-10-12T11:07:45.524Z+08:00

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy
import utils
from NN_module import CNN_Net, FC_Net
from Buffer_module import Buffer


class DQN():
    batch_size = 128
    lr = 1e-4
    epsilon = 0.15
    memory_capacity = int(1e4)
    gamma = 0.99
    q_network_iteration = 200
    save_path = "./save/"
    soft_update_theta = 0.1
    clip_norm_max = 1
    train_interval = 5
    conv_size = (32, 64)  # number of filters per conv layer
    fc_size = (512, 128)

    def __init__(self, num_state, num_action, enable_double=False, enable_priority=True):
        super(DQN, self).__init__()
        self.num_state = num_state
        self.num_action = num_action
        self.state_len = int(np.sqrt(self.num_state))
        self.enable_double = enable_double
        self.enable_priority = enable_priority

        self.eval_net, self.target_net = CNN_Net(self.state_len, num_action, self.conv_size, self.fc_size), CNN_Net(self.state_len, num_action, self.conv_size, self.fc_size)
        # self.eval_net, self.target_net = FC_Net(self.num_state, self.num_action), FC_Net(self.num_state, self.num_action)

        self.learn_step_counter = 0
        # only use the prioritized buffer when enable_priority is set
        self.buffer = Buffer(self.num_state, 'priority' if self.enable_priority else '', self.memory_capacity)
        # self.memory = np.zeros((self.memory_capacity, num_state * 2 + 2))
        self.initial_epsilon = self.epsilon
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=self.lr)

    def select_action(self, state, random=False, deterministic=False):
        state = torch.unsqueeze(torch.FloatTensor(state), 0)
        if (not random and np.random.random() > self.epsilon) or deterministic:  # greedy policy
            action_value = self.eval_net.forward(state)
            action = torch.max(action_value.reshape(-1, self.num_action), 1)[1].data.numpy()
        else:  # random policy
            action = np.random.randint(0, self.num_action)
        return action

    def store_transition(self, state, action, reward, next_state):
        state = state.reshape(-1)
        next_state = next_state.reshape(-1)

        transition = np.hstack((state, [action, reward], next_state))
        self.buffer.store(transition)
        # index = self.memory_counter % self.memory_capacity
        # self.memory[index, :] = transition
        # self.memory_counter += 1

    def update(self):
        # soft-update the target network parameters
        if self.learn_step_counter % self.q_network_iteration == 0 and self.learn_step_counter:
            for p_e, p_t in zip(self.eval_net.parameters(), self.target_net.parameters()):
                p_t.data = self.soft_update_theta * p_e.data + (1 - self.soft_update_theta) * p_t.data

        self.learn_step_counter += 1

        # sample a batch from memory
        if self.enable_priority:
            batch_memory, (tree_idx, ISWeights) = self.buffer.sample(self.batch_size)
        else:
            batch_memory, _ = self.buffer.sample(self.batch_size)

        batch_state = torch.FloatTensor(batch_memory[:, :self.num_state])
        batch_action = torch.LongTensor(batch_memory[:, self.num_state: self.num_state + 1].astype(int))
        batch_reward = torch.FloatTensor(batch_memory[:, self.num_state + 1: self.num_state + 2])
        batch_next_state = torch.FloatTensor(batch_memory[:, -self.num_state:])

        # q_eval
        q_eval_total = self.eval_net(batch_state)
        q_eval = q_eval_total.gather(1, batch_action)
        q_next = self.target_net(batch_next_state).detach()

        if self.enable_double:
            # Double DQN: pick the greedy action with the online net on the *next* state,
            # then evaluate that action with the target net
            q_eval_argmax = self.eval_net(batch_next_state).detach().max(1)[1].view(self.batch_size, 1)
            q_max = q_next.gather(1, q_eval_argmax).view(self.batch_size, 1)
        else:
            q_max = q_next.max(1)[0].view(self.batch_size, 1)

        q_target = batch_reward + self.gamma * q_max

        if self.enable_priority:
            abs_errors = (q_target - q_eval.data).abs()
            self.buffer.update(tree_idx, abs_errors)
            # loss = (torch.FloatTensor(ISWeights) * (q_target - q_eval).pow(2)).mean()
            loss = (q_target - q_eval).pow(2).mean()  # dropping the ISWeights might work better??

            # print(ISWeights)
            # print(loss)
            # import pdb; pdb.set_trace()
        else:
            loss = F.mse_loss(q_eval, q_target)

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.eval_net.parameters(), self.clip_norm_max)
        self.optimizer.step()

        return loss

    def save(self, path=None, name='dqn_net.pkl'):
        path = self.save_path if not path else path
        utils.check_path_exist(path)
        torch.save(self.eval_net.state_dict(), path + name)

    def load(self, path=None, name='dqn_net.pkl'):
        path = self.save_path if not path else path
        self.eval_net.load_state_dict(torch.load(path + name))

    def epsilon_decay(self, episode, total_episode):
        self.epsilon = self.initial_epsilon * (1 - episode / total_episode)
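
A sketch of how the agent above might be driven against a gym-style environment. The environment interface (reset/step over a flattened 4x4 board) and the episode count are assumptions for illustration, not part of this commit:

import numpy as np
from dqn_agent import DQN

def train(env, total_episode=1000):
    # assumes 2048-like observations: 16 state features, 4 actions
    agent = DQN(num_state=16, num_action=4, enable_double=True, enable_priority=True)
    for episode in range(total_episode):
        state = env.reset()
        done, step = False, 0
        while not done:
            action = int(agent.select_action(state))
            next_state, reward, done, _ = env.step(action)
            agent.store_transition(state, action, reward, next_state)
            # train every few environment steps once enough transitions are stored
            if step % agent.train_interval == 0 and agent.buffer.memory_counter > agent.batch_size:
                agent.update()
            state, step = next_state, step + 1
        agent.epsilon_decay(episode, total_episode)
    agent.save()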

gym_2048.py (+1 −1)
@@ -57,7 +57,7 @@ def __init__(self):
         self.set_illegal_move_reward(0.)
         self.set_max_tile(None)

-        self.max_illegal = 50  # max number of illegal actions
+        self.max_illegal = 10  # max number of illegal actions
         self.num_illegal = 0

         # Initialise seed
