|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +# Buffer_priority.py |
| 3 | +# author: yangrui |
| 4 | +# description: |
| 5 | +# created: 2019-11-01T11:52:44.496Z+08:00 |
| 6 | +# last-modified: 2019-11-01T11:52:44.496Z+08:00 |
| 7 | + |
| 8 | + |
| 9 | +import numpy as np |
| 10 | + |
class SumTree(object):
    r"""Binary sum tree stored in a flat array, used for priority sampling.

    Tree structure and array storage (example with capacity 4):

    Tree index:
         0         -> storing priority sum
        / \
      1     2
     / \   / \
    3   4 5   6    -> storing priority for transitions

    Array type for storing:
    [0,1,2,3,4,5,6]

    ``tree`` holds ``capacity - 1`` parent nodes followed by ``capacity``
    leaves; every parent stores the sum of its two children.  ``data`` holds
    the transition object that belongs to each leaf, in the same order.
    """

    # Class-level default kept for backward compatibility; every instance
    # gets its own write position in __init__.
    data_pointer = 0

    def __init__(self, capacity):
        """Allocate an empty tree.

        Args:
            capacity: number of leaves (transitions).  Integral floats such
                as ``1e4`` are accepted and converted to ``int``.

        Raises:
            ValueError: if ``capacity`` is not a positive whole number.
        """
        size = int(capacity)
        if size != capacity or size < 1:
            raise ValueError(
                "capacity must be a positive integer, got %r" % (capacity,))
        self.capacity = size  # number of priority leaves
        self.data_pointer = 0  # next data slot / leaf to (over)write
        # [--------------parent nodes-------------][-------leaves holding priorities-------]
        #             size: capacity - 1                       size: capacity
        self.tree = np.zeros(2 * size - 1)
        # One slot per transition; dtype=object so each entry is a reference
        # to whatever object the caller stored.
        self.data = np.zeros(size, dtype=object)

    def add(self, p, data):
        """Store ``data`` with priority ``p`` at the current write position.

        The write position advances by one and wraps to 0 once it reaches
        ``capacity``, so the oldest entry is overwritten when the tree is full.
        """
        tree_idx = self.data_pointer + self.capacity - 1  # leaf index in the tree array
        self.data[self.data_pointer] = data
        self.update(tree_idx, p)

        self.data_pointer += 1
        if self.data_pointer >= self.capacity:  # wrap around when full
            self.data_pointer = 0

    def update(self, tree_idx, p):
        """Set the priority of leaf ``tree_idx`` to ``p`` and fix parent sums."""
        change = p - self.tree[tree_idx]
        self.tree[tree_idx] = p
        # Walk up to the root, adding the difference to every ancestor.
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def get_leaf(self, v):
        """Find the leaf whose cumulative-priority interval contains ``v``.

        Args:
            v: a value between 0 and ``total_p``.

        Returns:
            Tuple ``(leaf_idx, priority, data)`` where ``leaf_idx`` indexes
            ``self.tree`` and ``data`` is the object stored for that leaf.
        """
        node_count = len(self.tree)
        parent_idx = 0
        while True:
            left_idx = 2 * parent_idx + 1
            if left_idx >= node_count:  # no children: parent_idx is a leaf
                leaf_idx = parent_idx
                break
            left_sum = self.tree[left_idx]
            if v <= left_sum:
                parent_idx = left_idx
            else:
                # Skip the mass of the left subtree and descend to the right.
                v -= left_sum
                parent_idx = left_idx + 1

        data_idx = leaf_idx - self.capacity + 1
        return leaf_idx, self.tree[leaf_idx], self.data[data_idx]

    @property
    def total_p(self):
        """Sum of all leaf priorities (the value stored at the root)."""
        return self.tree[0]
| 74 | + |
| 75 | + |
class Buffer_PER(object):  # stored as ( s, a, r, s_ ) in SumTree
    """Prioritized experience replay buffer backed by a ``SumTree``.

    New transitions are stored with the largest priority currently in the
    tree; priorities are later refreshed from errors via ``batch_update``.
    """

    epsilon = 0.01  # small amount to avoid zero priority
    alpha = 0.6  # [0~1] convert the importance of TD error to priority
    beta = 0.4  # importance-sampling, from initial value increasing to 1
    beta_increment_per_sampling = 0.001
    abs_err_upper = 1.  # clipped abs error

    def __init__(self, capacity):
        """Create the buffer; ``capacity`` is converted to ``int`` for the tree."""
        self.tree = SumTree(int(capacity))

    def store(self, transition):
        """Add ``transition`` with the current maximum leaf priority.

        When every leaf priority is 0, ``abs_err_upper`` is used instead so
        the new transition can be sampled.
        """
        max_p = np.max(self.tree.tree[-self.tree.capacity:])
        if max_p == 0:
            max_p = self.abs_err_upper
        self.tree.add(max_p, transition)  # set the max p for new p

    def sample(self, n):
        """Draw ``n`` transitions, one from each of ``n`` equal priority segments.

        Each call also increases ``beta`` by ``beta_increment_per_sampling``,
        capped at 1.

        Returns:
            Tuple ``(b_idx, b_memory, ISWeights)``: tree indices of the drawn
            leaves (int32, shape ``(n,)``), the drawn transitions (shape
            ``(n, transition_size)``) and importance-sampling weights (shape
            ``(n, 1)``).

        Raises:
            ValueError: if no priority has been stored yet.
        """
        total_p = self.tree.total_p
        if total_p <= 0:
            raise ValueError("cannot sample from an empty priority buffer")

        b_idx = np.empty((n,), dtype=np.int32)
        b_memory = np.empty((n, self.tree.data[0].size))
        ISWeights = np.empty((n, 1))
        pri_seg = total_p / n  # priority segment
        self.beta = min(1., self.beta + self.beta_increment_per_sampling)  # max = 1

        # Only leaves that actually hold a priority count towards the
        # minimum; unused leaves are 0 and would make min_prob 0, which
        # turns every weight below into 0 (after a division by zero).
        leaves = self.tree.tree[-self.tree.capacity:]
        min_prob = np.min(leaves[leaves > 0]) / total_p
        for i in range(n):
            a, b = pri_seg * i, pri_seg * (i + 1)
            v = np.random.uniform(a, b)
            idx, p, data = self.tree.get_leaf(v)
            prob = p / total_p
            ISWeights[i, 0] = np.power(prob / min_prob, -self.beta)
            b_idx[i], b_memory[i, :] = idx, data

        return b_idx, b_memory, ISWeights

    def batch_update(self, tree_idx, abs_errors):
        """Refresh the priorities of the leaves in ``tree_idx``.

        Priority is ``min(|error| + epsilon, abs_err_upper) ** alpha``.  The
        caller's ``abs_errors`` is not modified; signed errors and column
        vectors are accepted.
        """
        # Work on a flattened copy: take the magnitude and avoid 0.
        errors = np.abs(np.asarray(abs_errors, dtype=float)).ravel() + self.epsilon
        clipped_errors = np.minimum(errors, self.abs_err_upper)
        ps = np.power(clipped_errors, self.alpha)
        for ti, p in zip(tree_idx, ps):
            self.tree.update(ti, p)
| 114 | + |
| 115 | +# Top-level buffer class (wraps either uniform or prioritized replay) |
class Buffer():
    """Replay buffer with uniform or prioritized sampling.

    With ``buffer_type='priority'`` the transitions live in a ``Buffer_PER``;
    otherwise they live in a 2-D array with ``n_features * 2 + 2`` columns
    that is overwritten circularly.
    """

    def __init__(self, n_features, buffer_type='', capacity=1e4):
        """Create the buffer.

        Args:
            n_features: used to size each stored row (``n_features * 2 + 2``).
            buffer_type: ``'priority'`` selects prioritized replay; any other
                value selects the uniform circular buffer.
            capacity: maximum number of stored transitions.  Converted to
                ``int`` because the default ``1e4`` is a float, which NumPy
                rejects as an array dimension or index.
        """
        capacity = int(capacity)
        self.memory_size = capacity
        self.n_features = n_features
        self.type = buffer_type
        self.memory_counter = 0  # total number of transitions stored so far

        if self.type == 'priority':
            self.memory = Buffer_PER(capacity=capacity)
        else:
            self.memory = np.zeros((self.memory_size, n_features * 2 + 2))

    def store(self, transition):
        """Store one transition, overwriting the oldest one when full."""
        if self.type == 'priority':
            self.memory.store(transition)
        else:
            # Compute the slot before counting this transition so that
            # slot 0 is used first and unfilled rows are always at the end.
            index = self.memory_counter % self.memory_size
            self.memory[index, :] = transition

        self.memory_counter += 1

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions.

        Returns:
            Tuple ``(batch_memory, info)``.  ``info`` is
            ``(tree_idx, ISWeights)`` for the prioritized buffer and ``None``
            for the uniform buffer.
        """
        info = None
        if self.type == 'priority':
            tree_idx, batch_memory, ISWeights = self.memory.sample(batch_size)
            info = (tree_idx, ISWeights)
        else:
            # Only draw from rows that have been written.  If nothing has
            # been stored yet, fall back to the whole (all-zero) array, which
            # is what this method returned before, rather than raising.
            filled = min(self.memory_counter, self.memory_size)
            if filled == 0:
                filled = self.memory_size
            sample_index = np.random.choice(filled, size=batch_size)
            batch_memory = self.memory[sample_index, :]

        return batch_memory, info

    def update(self, tree_idx, td_errors):
        """Forward new errors to the prioritized buffer (priority type only)."""
        assert self.type == 'priority', "update() requires buffer_type='priority'"
        self.memory.batch_update(tree_idx, td_errors)
| 152 | + |
| 153 | + |
| 154 | + |
| 155 | + |
| 156 | + |
| 157 | + |
| 158 | + |
| 159 | + |
| 160 | + |
| 161 | + |
| 162 | + |
| 163 | + |
0 commit comments