-
Notifications
You must be signed in to change notification settings - Fork 16
/
train.py
38 lines (33 loc) · 1.3 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
from docopt import docopt
from trainer import PPOTrainer
from yaml_parser import YamlParser
def _select_device(force_cpu):
    """Choose the torch device for training and set the matching default tensor type.

    Args:
        force_cpu: If true, train on the CPU even when CUDA is available.

    Returns:
        ``torch.device("cuda")`` when CUDA is available and not overridden,
        otherwise ``torch.device("cpu")``.
    """
    # NOTE: torch.set_default_tensor_type is deprecated since PyTorch 2.1 in favour of
    # torch.set_default_dtype / torch.set_default_device; kept here for older releases.
    if force_cpu:
        torch.set_default_tensor_type("torch.FloatTensor")
        return torch.device("cpu")
    if torch.cuda.is_available():
        # Tensors created without an explicit device then land on the same device
        # that is handed to the trainer.
        torch.set_default_tensor_type("torch.cuda.FloatTensor")
        return torch.device("cuda")
    return torch.device("cpu")


def main():
    """Parse the command line, pick a device and run PPO training.

    Command line options:
        --config  Path of the yaml config; parsed by ``YamlParser`` and handed to the trainer.
        --run-id  Tag forwarded to ``PPOTrainer`` as ``run_id`` (names the tensorboard summary).
        --cpu     Force training on the CPU even if CUDA is available.
    """
    # Command line arguments via docopt.
    # docopt is whitespace sensitive: section bodies must be indented, sections must be
    # separated by an empty line, and every option needs at least two spaces before its
    # description -- otherwise the [default: ...] values are silently ignored.
    _USAGE = """
    Usage:
        train.py [options]
        train.py --help

    Options:
        --config=<path>            Path to the yaml config file [default: ./configs/cartpole.yaml]
        --run-id=<path>            Specifies the tag for saving the tensorboard summary [default: run].
        --cpu                      Force training on CPU [default: False]
    """
    options = docopt(_USAGE)
    run_id = options["--run-id"]
    cpu = options["--cpu"]

    # Parse the yaml config file. The result is a dictionary, which is passed to the trainer.
    config = YamlParser(options["--config"]).get_config()

    device = _select_device(cpu)

    # Initialize the PPO trainer and commence training.
    trainer = PPOTrainer(config, run_id=run_id, device=device)
    try:
        trainer.run_training()
    finally:
        # Release the trainer's resources even when training raises or is interrupted
        # (e.g. Ctrl+C); previously close() was skipped on any exception.
        trainer.close()
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()