-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain.py
44 lines (37 loc) · 1.58 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from ray import tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import pretty_print
# Helper that turns a plain settings dict into an RLlib PPOConfig.
def get_config(custom_config):
    """Build a :class:`PPOConfig` from the values in *custom_config*.

    Expected keys: ``env_name``, ``framework``, ``train_batch_size``,
    ``policies``, ``policy_mapping_fn``, ``policies_to_train``,
    ``num_gpus``, ``num_learners``, ``num_env_runners``,
    ``num_envs_per_env_runner``, ``batch_mode``,
    ``rollout_fragment_length``.

    Returns the configured (not yet built) PPOConfig object.
    """
    cfg = PPOConfig()
    # Each setter returns the config, so reassign step by step.
    cfg = cfg.environment(custom_config["env_name"])
    cfg = cfg.framework(custom_config["framework"])
    cfg = cfg.training(train_batch_size=custom_config["train_batch_size"])
    cfg = cfg.multi_agent(
        policies=custom_config["policies"],
        policy_mapping_fn=custom_config["policy_mapping_fn"],
        policies_to_train=custom_config["policies_to_train"],
    )
    cfg = cfg.resources(num_gpus=custom_config["num_gpus"])
    cfg = cfg.learners(num_learners=custom_config["num_learners"])
    cfg = cfg.env_runners(
        num_env_runners=custom_config["num_env_runners"],
        num_envs_per_env_runner=custom_config["num_envs_per_env_runner"],
        batch_mode=custom_config["batch_mode"],
        rollout_fragment_length=custom_config["rollout_fragment_length"],
    )
    return cfg
# Training the model
def train_model(config, num_iterations=5, result_path='results.txt', checkpoint_dir='checkpoints'):
    """Register the environment, build a PPO algorithm, and train it.

    Args:
        config: dict consumed by ``get_config``; must also provide
            ``env_name`` and ``env_creator`` for environment registration.
        num_iterations: how many times ``algo.train()`` is called.
        result_path: text file receiving the pretty-printed result of
            every iteration of this run.
        checkpoint_dir: directory passed to the first ``algo.save()``;
            subsequent saves reuse the path Ray reports back.
    """
    tune.register_env(config["env_name"], config['env_creator'])
    algo_config = get_config(config)
    algo = algo_config.build()
    # Truncate once so the file holds only this run's results.
    open(result_path, 'w').close()
    for i in range(num_iterations):
        result = algo.train()
        print(pretty_print(result))
        # BUG FIX: the original opened with mode 'w' inside the loop,
        # truncating the file each iteration so only the last result
        # survived. Append instead, keeping every iteration's output.
        with open(result_path, 'a') as file:
            file.write(pretty_print(result))
        checkpoint_dir = algo.save(checkpoint_dir).checkpoint.path
        print(f"Checkpoint saved in directory {checkpoint_dir}")