-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
50 lines (43 loc) · 1.44 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
'''''''''
@file: main.py
@author: MRL Liu
@time: 2021/4/25 14:17
@env: Python,Numpy
@desc: Maze项目的启动器,负责切换不同的方格环境
@ref:
@blog: https://blog.csdn.net/qq_41959920
'''''''''
import time
from datetime import timedelta
from env.line_env import Line_Env
from env.maze_env import Maze_Env
from brain.brain import Brain
from trainer import Trainer
def run_line(max_episodes=10):
    """Train a Q-learning agent on the line environment, report timing, and plot.

    Builds a ``Line_Env``, a ``Brain`` sized from the environment's
    ``n_features`` and ``n_actions``, and a ``Trainer`` wrapping both.
    It then runs Q-learning, prints the elapsed training time (the same
    report ``run_maze`` prints), and draws the trainer's monitoring plot.

    Args:
        max_episodes: Number of episodes passed to
            ``Trainer.train_q_learning``. Defaults to 10, the value this
            launcher always used.
    """
    env = Line_Env()                               # environment
    agent = Brain(env.n_features, env.n_actions)   # agent sized to the env
    trainer = Trainer(env, agent)                  # drives the training loop

    # NOTE(review): an alternative `trainer.train_sarsa(max_episodes=...)`
    # call was previously left commented out here; confirm it still exists
    # before switching algorithms.
    # perf_counter is monotonic, so the measurement is immune to clock changes.
    start = time.perf_counter()
    trainer.train_q_learning(max_episodes=max_episodes)
    elapsed = time.perf_counter() - start
    print("本次训练总共花费的Time: " + str(timedelta(seconds=int(round(elapsed)))))

    # Plot the statistics the trainer collected during training.
    trainer.draw_plot()
def run_maze(max_episodes=15):
    """Train a Q-learning agent on the maze environment, report timing, plot, and show it.

    Builds a ``Maze_Env``, a ``Brain`` sized from the environment's
    ``n_features`` and ``n_actions``, and a ``Trainer`` wrapping both.
    It then:

    1. runs Q-learning;
    2. prints the elapsed training time;
    3. draws the trainer's monitoring plot;
    4. enters the environment's main loop to display the window.

    Args:
        max_episodes: Number of episodes passed to
            ``Trainer.train_q_learning``. Defaults to 15, the value this
            launcher always used.
    """
    env = Maze_Env()                               # environment
    agent = Brain(env.n_features, env.n_actions)   # agent sized to the env
    trainer = Trainer(env, agent)                  # drives the training loop

    # NOTE(review): a `trainer.train_dqn(max_episodes=...)` alternative was
    # previously left commented out here; confirm it exists before using it.
    # perf_counter is monotonic and high-resolution, unlike time.time(), so
    # the measured duration cannot be skewed by system clock adjustments.
    start = time.perf_counter()
    trainer.train_q_learning(max_episodes=max_episodes)
    elapsed = time.perf_counter() - start
    print("本次训练总共花费的Time: " + str(timedelta(seconds=int(round(elapsed)))))

    # Plot the statistics the trainer collected during training.
    trainer.draw_plot()
    # Enter the environment's main loop so the window is displayed.
    env.mainloop()
if __name__ == '__main__':
    # Select the environment from the command line instead of toggling
    # commented-out calls. With no argument the maze demo runs, as before:
    #   python main.py          -> run_maze()
    #   python main.py line     -> run_line()
    import argparse

    _DEMOS = {"maze": run_maze, "line": run_line}
    _parser = argparse.ArgumentParser(
        description="Launch a grid-environment training demo.")
    _parser.add_argument(
        "env",
        nargs="?",
        choices=sorted(_DEMOS),
        default="maze",
        help="environment to train on (default: maze)",
    )
    _DEMOS[_parser.parse_args().env]()