# memory.py (forked from thainv0212/my_game_bot)
import numpy as np
from tree import SumTree
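

# NOTE: SumTree lives in tree.py, which is not shown in this file. From the way PERMemory
# uses it below, it is assumed to be the usual binary sum tree for prioritized replay:
# a flat array whose last `capacity` entries are the leaf priorities, with the root
# holding their sum. A rough sketch of that assumed interface, for reference only:
#
#     class SumTree:
#         def __init__(self, capacity):
#             self.capacity = capacity
#             self.tree = np.zeros(2 * capacity - 1)          # internal nodes + leaves
#             self.data = np.zeros(capacity, dtype=object)    # the stored experiences
#
#         @property
#         def total_priority(self):
#             return self.tree[0]                             # root = sum of all leaf priorities
#
#         def add(self, priority, data): ...                  # write data at the next free leaf, set its priority
#         def update(self, tree_index, priority): ...         # change a leaf priority, propagate the delta upwards
#         def get_leaf(self, value): ...                      # descend the tree; returns (tree_index, priority, data)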


class NormalMemory:
    """A simple fixed-size ring-buffer replay memory sampled uniformly."""

    def __init__(self, buffer_size=100000, observation_space_shape=(286,), multi_reward=False):
        self.buffer_size = buffer_size
        total_shape = np.prod(observation_space_shape)
        self.state_mem = np.zeros((self.buffer_size, total_shape), dtype=np.float32)
        self.action_mem = np.zeros(self.buffer_size, dtype=np.int32)
        if not multi_reward:
            self.reward_mem = np.zeros(self.buffer_size, dtype=np.float32)
        else:
            # Two reward components per transition when multi_reward is enabled
            self.reward_mem = np.zeros((self.buffer_size, 2), dtype=np.float32)
        self.next_state_mem = np.zeros((self.buffer_size, total_shape), dtype=np.float32)
        self.done_mem = np.zeros(self.buffer_size, dtype=bool)
        self.pointer = 0

    def add_exp(self, state, action, reward, next_state, done):
        # Overwrite the oldest entry once the buffer is full
        idx = self.pointer % self.buffer_size
        self.state_mem[idx] = np.squeeze(state).transpose().flatten()
        self.action_mem[idx] = action
        self.reward_mem[idx] = reward
        self.next_state_mem[idx] = np.squeeze(next_state).transpose().flatten()
        # Store the inverted done flag (1 while the episode continues), i.e. a non-terminal mask
        self.done_mem[idx] = 1 - int(done)
        self.pointer += 1

    def sample_exp(self, batch_size=128):
        # Sample uniformly, without replacement, from the filled part of the buffer
        max_mem = min(self.pointer, self.buffer_size)
        batch = np.random.choice(max_mem, batch_size, replace=False)
        states = self.state_mem[batch]
        actions = self.action_mem[batch]
        rewards = self.reward_mem[batch]
        next_states = self.next_state_mem[batch]
        dones = self.done_mem[batch]
        return states, actions, rewards, next_states, dones
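

# Minimal usage sketch for NormalMemory (illustrative only; the shapes and batch size
# below are assumptions, not values taken from the rest of the project):
#
#     mem = NormalMemory(buffer_size=10000, observation_space_shape=(286,))
#     mem.add_exp(state, action, reward, next_state, done)    # once per environment step
#     states, actions, rewards, next_states, dones = mem.sample_exp(batch_size=128)
#     # note: `dones` holds 1 - done, i.e. it is 1 for non-terminal transitions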


class PERMemory(object):  # transitions stored as (s, a, r, s_) in a SumTree
    """
    This SumTree code is a modified version of the original code from:
    https://github.com/jaara/AI-blog/blob/master/Seaquest-DDQN-PER.py
    """
    PER_e = 0.01  # Small constant added to priorities so no experience ever has zero probability of being sampled
    PER_a = 0.6   # Trade-off between sampling only high-priority experiences (a=1) and sampling uniformly (a=0)
    PER_b = 0.4   # Importance-sampling exponent, annealed from this initial value towards 1
    PER_b_increment_per_sampling = 0.001
    absolute_error_upper = 1.  # Absolute TD errors are clipped to this value

    def __init__(self, capacity):
        """
        The tree is a sum tree that holds the priority scores at its leaves,
        plus a data array for the experiences themselves.
        We do not use a deque, because then every experience would shift index
        at each timestep; instead we use a plain array and overwrite entries
        once the memory is full.
        """
        self.tree = SumTree(capacity)

    def store(self, experience):
        """
        Store a new experience in the tree.
        Each new experience gets the current maximum priority (it is refined
        later, once the experience is used to train the DDQN).
        """
        # Find the max priority among the leaves
        max_priority = np.max(self.tree.tree[-self.tree.capacity:])
        # If the max priority is 0 we cannot use it, since the experience would
        # never have a chance of being selected, so fall back to a minimum priority
        if max_priority == 0:
            max_priority = self.absolute_error_upper
        self.tree.add(max_priority, experience)  # new experiences get the max priority

    def sample(self, n):
        """
        - To sample a minibatch of size n, the range [0, priority_total] is divided into n segments.
        - A value is sampled uniformly from each segment.
        - The sum tree is searched for the leaf whose priority range contains that value,
          and the corresponding experience is retrieved.
        - Finally, the importance-sampling (IS) weight of each minibatch element is computed.
        """
        # Containers for the minibatch, the tree indices and the IS weights
        memory_b = []
        b_idx, b_ISWeights = np.empty((n,), dtype=np.int32), np.empty((n, 1), dtype=np.float32)
        # Priority segment: as explained in the paper, divide the range [0, p_total] into n segments
        priority_segment = self.tree.total_priority / n
        # Anneal PER_b towards 1 each time a new minibatch is sampled
        self.PER_b = np.min([1., self.PER_b + self.PER_b_increment_per_sampling])
        # Compute the maximum IS weight, used to normalise the weights
        p_min = np.min(self.tree.tree[-self.tree.capacity:]) / self.tree.total_priority
        max_weight = (p_min * n) ** (-self.PER_b)
        if np.isinf(max_weight):
            max_weight = 1
        for i in range(n):
            # Sample a value uniformly from the i-th segment
            a, b = priority_segment * i, priority_segment * (i + 1)
            value = np.random.uniform(a, b)
            # Retrieve the experience whose priority range contains this value
            index, priority, data = self.tree.get_leaf(value)
            # P(j): sampling probability of this leaf
            sampling_probabilities = priority / self.tree.total_priority
            # IS weight: w_i = (1/N * 1/P(i))**b / max w_i == (N * P(i))**-b / max w_i
            b_ISWeights[i, 0] = np.power(n * sampling_probabilities, -self.PER_b) / max_weight
            b_idx[i] = index
            experience = [data]
            memory_b.append(experience)
        return b_idx, memory_b, b_ISWeights

    def batch_update(self, tree_idx, abs_errors):
        """Update the priorities of the sampled leaves from their new absolute TD errors."""
        abs_errors += self.PER_e  # avoid zero priority
        clipped_errors = np.minimum(abs_errors, self.absolute_error_upper)
        ps = np.power(clipped_errors, self.PER_a)
        for ti, p in zip(tree_idx, ps):
            self.tree.update(ti, p)
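

# A minimal end-to-end sketch of how PERMemory is typically driven from a DQN/DDQN
# training loop. This block is illustrative only: the transition layout, state size,
# batch size and the way TD errors are obtained are assumptions, not part of this module.
if __name__ == "__main__":
    per = PERMemory(capacity=2048)

    # Fill the tree with dummy transitions (state, action, reward, next_state, done)
    for step in range(256):
        transition = (np.random.rand(286).astype(np.float32),   # state (assumed 286-dim, matching NormalMemory's default)
                      np.random.randint(0, 10),                  # action
                      np.random.rand(),                          # reward
                      np.random.rand(286).astype(np.float32),    # next state
                      False)                                     # done
        per.store(transition)

    # Sample a prioritized minibatch together with its tree indices and IS weights
    tree_idx, minibatch, is_weights = per.sample(32)

    # After the learning step, the new absolute TD errors are fed back so the
    # sampled leaves get updated priorities (random numbers stand in for them here)
    fake_abs_td_errors = np.random.rand(32)
    per.batch_update(tree_idx, fake_abs_td_errors)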