"""
File: replay_buffer.py
Author: Nathaniel Hamilton
Description: This class implements a replay buffer where the relevant information of past experiences is stored and can
be sampled from.
Usage: Import the entire class file to use all options.
Remaining Tasks:
* Make it
"""
import numpy as np
import math
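# Example (illustrative only): how this buffer might be wired into a training
# loop. The `env`, `agent`, `num_steps`, and `batch_size` names below are
# hypothetical placeholders, not part of this module; the sketch is kept in
# comments so importing the file has no side effects.
#
#     buffer = ReplayBuffer(capacity=10000)
#     state = env.reset()
#     for step in range(num_steps):
#         action = agent.act(state)
#         next_state, reward, done, _ = env.step(action)
#         buffer.add_memory(state, action, reward, done, next_state)
#         state = env.reset() if done else next_state
#         if step > batch_size:
#             batch = buffer.sample_batch(batch_size)
#             agent.update(*batch)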
class ReplayBuffer(object):
    def __init__(self, capacity, prev_buffer=None):
        """
        A simple replay buffer. Memories are appended until the buffer is full,
        after which the oldest memories are overwritten first.

        :param capacity:    (int) the maximum number of memories the buffer can hold
        :param prev_buffer: (ReplayBuffer) optional previous buffer whose most recent
                                memories are carried over into this one
        """
        # Initialize the replay buffer to empty if no previous buffer is supplied
        if prev_buffer is None:
            self.states = []
            self.actions = []
            self.rewards = []
            self.dones = []
            self.next_states = []
            self.position = 0
            self.capacity = int(capacity)

        # Otherwise, carry over the most recent memories from the previous buffer
        else:
            fill = int(min(capacity, prev_buffer.capacity))
            self.states = prev_buffer.states[-fill:]
            self.actions = prev_buffer.actions[-fill:]
            self.rewards = prev_buffer.rewards[-fill:]
            self.dones = prev_buffer.dones[-fill:]
            self.next_states = prev_buffer.next_states[-fill:]
            # The next write position follows the carried-over memories
            self.position = int(len(self.rewards) % capacity)
            self.capacity = int(capacity)
    def add_memory(self, state, action, reward, done, next_state):
        """
        Store a single transition in the buffer, overwriting the oldest memory
        once the buffer is full.

        :param state:      the observed state
        :param action:     the action taken in that state
        :param reward:     the reward received for taking the action
        :param done:       whether or not the episode ended after this transition
        :param next_state: the state observed after taking the action
        :return: None
        """
        # If the buffer is not full, append the new memory
        if len(self.rewards) < self.capacity:
            self.states.append(state)
            self.actions.append(action)
            self.rewards.append(reward)
            self.dones.append(done)
            self.next_states.append(next_state)

        # Otherwise, replace old memories
        else:
            self.states[self.position] = state
            self.actions[self.position] = action
            self.rewards[self.position] = reward
            self.dones[self.position] = done
            self.next_states[self.position] = next_state

        # Increment the position, wrapping around to the start of the buffer
        self.position = int((self.position + 1) % self.capacity)
    def sample_batch(self, batch_length):
        """
        Sample a batch of memories from the buffer. The stored indices are
        shuffled and a contiguous window of the shuffled indices is taken, so the
        returned batch is a random selection of stored memories.

        Assumes the buffer is full.

        :param batch_length: (int) the number of memories to sample
        :return: batch_states, batch_actions, batch_rewards, batch_dones, batch_next_states
        """
        # Convert the lists to arrays for fancy indexing
        states = np.asarray(self.states)
        actions = np.asarray(self.actions)
        rewards = np.asarray(self.rewards)
        dones = np.asarray(self.dones)
        next_states = np.asarray(self.next_states)

        # Randomize the order of the buffer to eliminate correlation between
        # sequentially stored memories
        capacity = len(self.rewards)
        indices = np.arange(capacity)
        np.random.shuffle(indices)

        # Collect the sample from the middle of the randomized indices
        batch_start = int(max((math.floor(capacity / 2.0) - math.floor(batch_length / 2.0)), 0))
        batch_end = int(min((math.floor(capacity / 2.0) + math.ceil(batch_length / 2.0)), capacity))
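        # Worked example (illustrative numbers, not taken from the code above):
        # with a full buffer of capacity 100 and batch_length 32,
        # batch_start = 50 - 16 = 34 and batch_end = 50 + 16 = 66,
        # so 32 of the shuffled indices are kept.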
        batch_indices = np.asarray(indices[batch_start:batch_end:1], dtype=int)
        # np.random.shuffle(batch_indices)  # Shuffle the selected indices again to further randomize
        batch_states = states[batch_indices]
        batch_actions = actions[batch_indices]
        batch_rewards = rewards[batch_indices]
        batch_dones = dones[batch_indices]
        batch_next_states = next_states[batch_indices]

        return batch_states, batch_actions, batch_rewards, batch_dones, batch_next_states
if __name__ == '__main__':
    replay_buffer = ReplayBuffer(100)
    s = np.asarray([0, 0, 0, 0])
    a = np.asarray([0, 0])
    r = 0
    d = 0
    for i in range(18):
        replay_buffer.add_memory(s, a, r, d, s)
        r += 1

    b_states, b_actions, b_rewards, b_dones, b_next_states = replay_buffer.sample_batch(19)
    print(b_rewards)

    new_buffer = ReplayBuffer(5, replay_buffer)
    new_buffer.add_memory(s, a, r, d, s)
    b_states, b_actions, b_rewards, b_dones, b_next_states = new_buffer.sample_batch(5)
    print('New Sample')
    print(b_rewards)