Skip to content

Commit dd728f4

Browse files
committed
add base agent
1 parent 0556392 commit dd728f4

File tree

2 files changed

+75
-41
lines changed

2 files changed

+75
-41
lines changed

base_agent.py

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# -*- coding: utf-8 -*-
2+
# base_agent.py
3+
# author: yangrui
4+
# description: base agent interface and a random baseline agent for the 2048 gym environment
5+
# created: 2019-09-29T15:01:38.383Z+08:00
6+
# last-modified: 2019-09-29T15:01:38.383Z+08:00
7+
8+
9+
from gym_2048 import Game2048Env
10+
import random
11+
12+
13+
class BaseAgent:
    """Interface shared by the agents in this module.

    Subclasses override :meth:`act` to map an observed state to an action.
    """

    def act(self, state):
        """Choose an action for ``state``.

        Args:
            state: The observation handed to the agent by the caller.

        Returns:
            The chosen action; its form is defined by the concrete subclass.

        Raises:
            NotImplementedError: Always, in this base class. The message
                names the concrete class so a missing override is easy to
                locate.
        """
        raise NotImplementedError(
            '{}.act() is not implemented'.format(type(self).__name__))
16+
17+
18+
class RandomAgent(BaseAgent):
    """Agent that ignores the state and picks an action uniformly at random.

    Args:
        n_actions: Size of the action space; actions are the integers
            ``0 .. n_actions - 1``. Defaults to 4, which reproduces the
            previously hard-coded ``random.randint(0, 3)``.

    Raises:
        ValueError: If ``n_actions`` is smaller than 1.
    """

    def __init__(self, n_actions=4):
        super().__init__()
        if n_actions < 1:
            raise ValueError(
                'n_actions must be >= 1, got {}'.format(n_actions))
        self.n_actions = n_actions

    def act(self, state):
        """Return a uniformly random action index; ``state`` is unused."""
        # Draws from the module-level generator so that callers who seed
        # ``random`` keep getting reproducible episodes.
        return random.randint(0, self.n_actions - 1)
21+
22+
23+
if __name__ == "__main__":
    import time
    import numpy as np

    # Number of episodes the evaluation averages over.
    NUM_EPISODES = 1000

    def run(ifrender=False):
        """Play one episode with a ``RandomAgent`` and report its statistics.

        Args:
            ifrender: When true, render the environment after the reset and
                after every step.

        Returns:
            A tuple ``(elapsed_seconds, highest, score, steps)`` where the
            last three values are read from the final ``info`` dict
            returned by the environment.
        """
        agent = RandomAgent()
        env = Game2048Env()
        # reset() hands back the same 4-tuple shape as step().
        state, _, done, info = env.reset()
        if ifrender:
            env.render()

        # perf_counter() is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments the way time.time() can.
        start = time.perf_counter()
        while True:
            action = agent.act(state)
            state, _, done, info = env.step(action)
            if ifrender:
                env.render()
            if done:
                print('\nfinished, info:{}'.format(info))
                break

        elapsed = time.perf_counter() - start
        print('episode time:{} s\n'.format(elapsed))
        return elapsed, info['highest'], info['score'], info['steps']

    # One (time, highest, score, steps) tuple per episode, transposed into
    # one sequence per statistic for averaging.
    results = [run() for _ in range(NUM_EPISODES)]
    time_lis, highest_lis, score_lis, steps_lis = zip(*results)

    print('eval result:\naverage episode time:{} s, average highest score:{}, average total score:{}, average steps:{}'.format(np.mean(time_lis), np.mean(highest_lis), np.mean(score_lis), np.mean(steps_lis)))
58+
59+
60+

gym_2048.py

+15-41
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,19 @@ def __init__(self):
5959
# Initialise seed
6060
self.seed()
6161

62-
# Reset ready for a game
63-
self.reset()
62+
# # Reset ready for a game
63+
# self.reset()
64+
65+
def _get_info(self, info=None):
66+
if not info:
67+
info = {}
68+
else:
69+
assert type(info) == dict, 'info should be of type dict!'
70+
71+
info['highest'] = self.highest()
72+
info['score'] = self.score
73+
info['steps'] = self.steps
74+
return info
6475

6576
def seed(self, seed=None):
6677
self.np_random, seed = seeding.np_random(seed)
@@ -103,10 +114,7 @@ def step(self, action):
103114
done = False
104115
reward = self.illegal_move_reward
105116

106-
#print("Am I done? {}".format(done))
107-
info['highest'] = self.highest()
108-
info['score'] = self.score
109-
info['steps'] = self.steps
117+
info = self._get_info(info)
110118

111119
# Return observation (board state), reward, done and info dict
112120
return self.Matrix, reward, done, info
@@ -120,7 +128,7 @@ def reset(self):
120128
self.add_tile()
121129
self.add_tile()
122130

123-
return self.Matrix
131+
return self.Matrix, 0, False, self._get_info()
124132

125133
def render(self, mode='human'):
126134
outfile = StringIO() if mode == 'ansi' else sys.stdout
@@ -282,40 +290,6 @@ def set_board(self, new_board):
282290
"""Retrieve the whole board, useful for testing."""
283291
self.Matrix = new_board
284292

285-
if __name__ == "__main__":
286-
import random
287-
import time
288-
import numpy as np
289-
290-
291-
def run():
292-
env = Game2048Env()
293-
env.render()
294-
start = time.time()
295-
while True:
296-
action = random.randint(0, 3)
297-
print('action: {}'.format(action))
298-
state, reward, done, info = env.step(action)
299-
env.render()
300-
if done:
301-
print('\nfinished, info:{}'.format(info))
302-
break
303-
304-
end = time.time()
305-
print('episode time:{} s\n'.format(end - start))
306-
return end - start, info['highest'], info['score'], info['steps']
307-
308-
time_lis, highest_lis, score_lis, steps_lis = [], [], [], []
309-
for i in range(100):
310-
t, highest, score, steps = run()
311-
time_lis.append(t)
312-
highest_lis.append(highest)
313-
score_lis.append(score)
314-
steps_lis.append(steps)
315-
316-
print('eval result:\naverage episode time:{} s, average highest score:{}, average total score:{}, average steps:{}'.format(np.mean(time_lis), np.mean(highest_lis), np.mean(score_lis), np.mean(steps_lis)))
317-
318-
319293

320294

321295

0 commit comments

Comments
 (0)