Train.py
53 lines (39 loc) · 1.99 KB
import argparse

from GeneralDQN import DQN
from TennisEnv import TennisEnv
from TennisQNet import TennisLeaderQNetwork
from TennisQNet import TennisDealerQNetwork

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--leader", default=None, help="The leader policy network (.pt) (optional)")
    parser.add_argument("--leader_target", default=None, help="The leader target network (.pt) (optional)")
    parser.add_argument("--dealer", default=None, help="The dealer policy network (.pt) (optional)")
    parser.add_argument("--dealer_target", default=None, help="The dealer target network (.pt) (optional)")
    parser.add_argument("--trainee", default="leader", help="\"dealer\" or \"leader\"")
    parser.add_argument("--save_name", default="best", help="Path to save checkpoints to")
    args = parser.parse_args()
    assert args.trainee in ["leader", "dealer"]

    # Grid search over learning rates and epsilon-decay schedules
    learning_rates = [1e-3, 5e-4, 1e-4, 5e-5, 1e-5]
    epsilon_decays = [1e5, 5e5, 1e6, 5e6]
    for lr in learning_rates:
        for eps_decay in epsilon_decays:
            hyperparameters = {"LR": lr, "EPS_DECAY": eps_decay}
            # Leader
            leader = DQN(TennisLeaderQNetwork, hyperparameters)
            leader.load_model(args.leader, args.leader_target)
            # Dealer
            dealer = DQN(TennisDealerQNetwork, hyperparameters)
            dealer.load_model(args.dealer, args.dealer_target)
            # Environment
            environment = TennisEnv(leader, dealer, rewarded_player=args.trainee)
            # Train just one player at a time
            if args.trainee == "leader":
                trainee_obj = leader
            else:
                trainee_obj = dealer
            # Train the model
            trainee_obj.train(environment, save_name=args.save_name)
            # Validate
            avg_reward = trainee_obj.validate(environment)
            # Print results
            print(f"Learning rate: {lr}, Epsilon decay: {eps_decay}, Average reward: {avg_reward}")
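A hypothetical invocation for a leader-training sweep, resuming from existing checkpoints; the flags match the argparse definitions above, but the file names leader.pt, leader_target.pt, and the save name leader_sweep are placeholders, not files shipped with the repository:

python Train.py --leader leader.pt --leader_target leader_target.pt --trainee leader --save_name leader_sweep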