forked from julesripoll/dqn-pong-breakout
-
Notifications
You must be signed in to change notification settings - Fork 0
/
play_game.py
43 lines (33 loc) · 1.43 KB
/
play_game.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import argparse
import os
import gym
from utils.utils import play_game
from agents.dqn import DQNAgent
from agents.double_dqn import DoubleDQNAgent
# Directory that holds the zipped agent checkpoints.
WEIGHTS_DIR = './networks_weights/'

# Checkpoint file names (without the '.zip' suffix) used when the user does
# not pass --path_weights.
DEFAULT_WEIGHTS = {
    'dqn': 'DQN_lr=5e-4_g=0.999_bs=256_ed=0.9995_em=0.005_ms=500000_ts=2e6',
    'ddqn': 'DDQN_lr=5e-4_g=0.999_bs=256_ed=0.9995_em=0.005_ms=500000_ts=2e6',
}


def resolve_agent_kind(agent_type):
    """Map the raw ``--agent`` value to ``'dqn'`` or ``'ddqn'``.

    The comparison ignores case and surrounding whitespace so that the
    spelling advertised in the help text (``DQN``/``DDQN``) selects the
    expected agent. Any value other than ``dqn`` falls back to ``'ddqn'``,
    matching the script's historical default branch.
    """
    return 'dqn' if agent_type.strip().lower() == 'dqn' else 'ddqn'


def resolve_weights_path(agent_kind, agent_file=None):
    """Return the path of the ``.zip`` checkpoint to load.

    Args:
        agent_kind: ``'dqn'`` or ``'ddqn'``; only used to pick the default
            checkpoint when ``agent_file`` is empty.
        agent_file: Checkpoint name relative to ``WEIGHTS_DIR``, without the
            ``.zip`` suffix. ``None`` or ``''`` selects the default
            checkpoint for ``agent_kind``.
    """
    file_name = agent_file if agent_file else DEFAULT_WEIGHTS[agent_kind]
    return os.path.join(WEIGHTS_DIR, file_name + '.zip')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--agent', type=str, default='ddqn', help='Agent choice : DQN/DDQN')
    parser.add_argument('--path_weights', type=str, default=None, help='Path to agent weight')
    parser.add_argument('--gif_name', type=str, default='game.gif', help='GIF file name')
    args = parser.parse_args()

    agent_kind = resolve_agent_kind(args.agent)
    agent_file = args.path_weights
    gif_name = args.gif_name

    env = gym.make('MinAtar/Breakout-v1')
    env.reset()

    # The announcement is only printed when falling back to the bundled
    # checkpoints, as before.
    if not agent_file:
        if agent_kind == 'dqn':
            print("Load DQN weights")
        else:
            print("Load Double DQN weights")

    # Honour --agent for custom weights too; previously a user-supplied
    # checkpoint was always loaded through DQNAgent regardless of --agent.
    # NOTE(review): assumes DoubleDQNAgent.load accepts the same checkpoint
    # layout as DQNAgent.load -- confirm against the agents package.
    agent_class = DQNAgent if agent_kind == 'dqn' else DoubleDQNAgent
    agent = agent_class.load(resolve_weights_path(agent_kind, agent_file))

    # Make sure the output directory exists before the GIF path is handed
    # to play_game.
    gif_path = f"./games/{gif_name}"
    os.makedirs(os.path.dirname(gif_path), exist_ok=True)
    play_game(env, agent, path=gif_path)