-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathRGBState.py
68 lines (59 loc) · 2 KB
/
RGBState.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import time
import copy
import time
import constants
import HyperNode
import globales
import Features
class RGBState(object):
    """A search node wrapping an emulator snapshot plus screen-derived features.

    Two states compare equal when their ``features`` compare equal, so nodes
    reached through different action sequences but yielding equal features
    are treated as the same state.

    Attributes (all set in ``__init__``):
        ale_state: Emulator snapshot (from ``env.ale.cloneState()`` in
            ``get_successor_states``); ``None`` until assigned.
        features: Usually a ``Features.Features`` built by ``set_features``,
            or whatever was passed to the constructor.
        best_reward_below: Starts at ``-inf``; successors created by
            ``get_successor_states`` start with it equal to their ``reward``.
            Presumably updated by search code elsewhere — not visible here.
        reward: Accumulated reward along the path to this node (starts at 0).
        terminal: Terminal flag returned by the ``env.step`` call that
            produced this node (``False`` by default).
    """

    def __init__(self, ale_state=None, features=None):
        self.ale_state = ale_state
        self.features = features
        self.best_reward_below = float("-inf")
        self.reward = 0
        self.terminal = False

    def __eq__(self, other):
        """Return True when ``other`` carries equal ``features``.

        Objects without a ``features`` attribute yield ``NotImplemented`` so
        Python falls back to its default comparison instead of raising
        ``AttributeError``. Any object that has ``features`` is still compared
        by duck typing, as before.
        """
        try:
            other_features = other.features
        except AttributeError:
            return NotImplemented
        return self.features == other_features

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__`` (it falls back to
        # identity), so define it explicitly; harmless under Python 3.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # NOTE(review): hashing ``str(features)`` is only consistent with
        # ``__eq__`` if equal Features objects also have equal ``str()``
        # forms — confirm in the Features class.
        return hash(str(self.process_screen()))

    def process_screen(self):
        """Return the value used for hashing this state (its ``features``)."""
        return self.features

    def get_basic_features(self, rgb, split_x=None, split_y=None):
        """Return the distinct ``(tile_x, tile_y, color)`` triples in ``rgb``.

        ``rgb`` is indexed as ``rgb[x][y][channel]``. Each pixel contributes
        the tuple of its first three channel values, keyed by the tile it
        falls into, ``(x // split_x, y // split_y)``.

        Args:
            rgb: Screen indexable as ``rgb[x][y][c]`` (assumes an HxWx3-style
                array or nested lists — TODO confirm with the env).
            split_x: Tile size along the first index; defaults to
                ``constants.SPLIT_X``.
            split_y: Tile size along the second index; defaults to
                ``constants.SPLIT_Y``.

        Returns:
            List of unique triples in arbitrary (set) order; ``[]`` for an
            empty screen.
        """
        # Hoist the tile sizes out of the per-pixel loop.
        tile_w = constants.SPLIT_X if split_x is None else split_x
        tile_h = constants.SPLIT_Y if split_y is None else split_y
        found = set()
        for x, row in enumerate(rgb):
            # Floor division: the original Python 2 ``/`` floored for ints,
            # whereas Python 3 ``/`` would yield floats and give every pixel
            # its own "tile".
            tx = x // tile_w
            for y, px in enumerate(row):
                found.add((tx, y // tile_h, (px[0], px[1], px[2])))
        return list(found)

    def get_b_pros_features(self, rgb, fs):
        """Placeholder for B-PROS pairwise features; currently returns ``[]``.

        Args:
            rgb: The screen (unused for now).
            fs: Basic features already computed for ``rgb`` (unused for now).
        """
        # TODO: not implemented. The previous body sorted a copy of ``fs``
        # and built an unused list of sets before returning ``[]``; that
        # wasted work has been removed.
        return []

    def get_b_prost_features(self, rgb):
        """Placeholder for B-PROST features; currently returns ``[]``."""
        return []

    def set_features(self, rgb):
        """Compute every feature group for ``rgb`` and store the result.

        Concatenates basic, B-PROS and B-PROST features (in that order) and
        wraps them in ``Features.Features``, assigned to ``self.features``.
        """
        features = list(self.get_basic_features(rgb))
        features += self.get_b_pros_features(rgb, features)
        features += self.get_b_prost_features(rgb)
        self.features = Features.Features(features)

    def get_successor_states(self, env):
        """Return a list of nodes reachable from this node. [Fig. 3.8]

        For every action index ``0 .. len(env._action_set) - 1`` the
        emulator is rewound to ``self.ale_state`` and stepped once. The
        resulting screen features, emulator snapshot, accumulated reward and
        terminal flag are stored in a new ``RGBState``.

        Args:
            env: Environment exposing ``_action_set``, ``ale.restoreState``,
                ``ale.cloneState`` and a ``step(action)`` returning a 4-tuple
                — presumably a gym Atari env; confirm against callers.

        Returns:
            List of ``(action_index, RGBState)`` pairs in action order.

        Note:
            ``env`` is left in the state reached by the last action tried.
        """
        successors = []
        for action in range(len(env._action_set)):
            child = RGBState()
            # Rewind so each action branches from this node rather than from
            # the previously expanded sibling.
            env.ale.restoreState(self.ale_state)
            screen, step_reward, terminal, _info = env.step(action)
            child.set_features(screen)
            child.ale_state = env.ale.cloneState()
            child.reward = self.reward + step_reward
            child.best_reward_below = child.reward
            child.terminal = terminal
            successors.append((action, child))
        return successors