# Copyright (c) 2019 Robert Bosch GmbH
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
# ******************************************************************
# train_metabo_hm3.py
# Train MetaBO on the Hartmann-3 function
# The weights, stats, logs, and the learning curve are stored in metabo/log and can
# be evaluated using metabo/eval/evaluate.py
# ******************************************************************
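# Usage note (inferred from the repository layout, not part of the original header): running
# `python train_metabo_hm3.py` from the repository root starts training; all artifacts are
# written to metabo/log/MetaBO-HM3-v0/<timestamp>/, which is the folder expected by the
# evaluation script mentioned above.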
import os
import multiprocessing as mp
from datetime import datetime
from metabo.policies.policies import NeuralAF
from metabo.ppo.ppo import PPO
from metabo.ppo.plot_learning_curve_online import plot_learning_curve_online
from gym.envs.registration import register

rootdir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "metabo")

# specify environment
env_spec = {
    "env_id": "MetaBO-HM3-v0",
    "D": 3,
    "f_type": "HM3-var",
    "f_opts": {"bound_translation": 0.1,
               "bound_scaling": 0.1,
               "M": 50},
    "features": ["posterior_mean", "posterior_std", "timestep", "budget", "x"],
    "T": 30,
    "n_init_samples": 0,
    "pass_X_to_pi": False,
    # parameters were determined offline via type-2-ML on a GP with 100 datapoints
    "kernel_lengthscale": [0.716, 0.298, 0.186],
    "kernel_variance": 0.83,
    "noise_variance": 1.688e-11,
    "use_prior_mean_function": False,
    "local_af_opt": True,
    "N_MS": 2000,
    "N_LS": 2000,
    "k": 5,
    "reward_transformation": "neg_log10"
}
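
# Note on the settings above (a descriptive summary, not part of the original file): "T" is the
# per-episode optimization budget (it is also passed as max_episode_steps when the environment is
# registered below), and "features" lists the inputs presented to the neural acquisition function.
# "kernel_lengthscale", "kernel_variance", and "noise_variance" fix the GP surrogate; "N_MS",
# "N_LS", and "k" are assumed to control the multi-start/local-search optimization of the
# acquisition function that "local_af_opt" enables.
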
# specify PPO parameters
n_iterations = 2000
batch_size = 1200
n_workers = 10
arch_spec = 4 * [200]
ppo_spec = {
    "batch_size": batch_size,
    "max_steps": n_iterations * batch_size,
    "minibatch_size": batch_size // 20,
    "n_epochs": 4,
    "lr": 1e-4,
    "epsilon": 0.15,
    "value_coeff": 1.0,
    "ent_coeff": 0.01,
    "gamma": 0.98,
    "lambda": 0.98,
    "loss_type": "GAElam",
    "normalize_advs": True,
    "n_workers": n_workers,
    "env_id": env_spec["env_id"],
    "seed": 0,
    "env_seeds": list(range(n_workers)),
    "policy_options": {
        "activations": "relu",
        "arch_spec": arch_spec,
        "use_value_network": True,
        "t_idx": -2,
        "T_idx": -1,
        "arch_spec_value": arch_spec
    }
}
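
# Note on the PPO settings (a descriptive summary, not part of the original file): these follow
# standard PPO usage, with "epsilon" the clipping range, "gamma" the discount factor, "lambda" the
# GAE parameter (matching the "GAElam" loss type), and "value_coeff"/"ent_coeff" weighting the
# value-function and entropy terms. "t_idx"/"T_idx" are assumed to tell the policy where the
# timestep and budget features sit in the observation vector.
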
# register environment
register(
    id=env_spec["env_id"],
    entry_point="metabo.environment.metabo_gym:MetaBO",
    max_episode_steps=env_spec["T"],
    reward_threshold=None,
    kwargs=env_spec
)
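
# Optional sanity check (a sketch, not part of the original script; left commented out so the
# script's behavior is unchanged): once registered, the environment can be instantiated through
# gym to verify the spec before launching training.
# import gym
# check_env = gym.make(env_spec["env_id"])
# check_env.reset()
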
# log data and weights go here, use this folder for evaluation afterwards
logpath = os.path.join(rootdir, "log", env_spec["env_id"], datetime.strftime(datetime.now(), "%Y-%m-%d-%H-%M-%S"))
# set up policy
policy_fn = lambda observation_space, action_space, deterministic: NeuralAF(observation_space=observation_space,
                                                                            action_space=action_space,
                                                                            deterministic=deterministic,
                                                                            options=ppo_spec["policy_options"])

# do training
print("Training on {}.\nFind logs, weights, and learning curve at {}\n\n".format(env_spec["env_id"], logpath))
ppo = PPO(policy_fn=policy_fn, params=ppo_spec, logpath=logpath, save_interval=1)
# the learning curve is plotted online in a separate process
p = mp.Process(target=plot_learning_curve_online, kwargs={"logpath": logpath, "reload": True})
p.start()
ppo.train()
p.terminate()
plot_learning_curve_online(logpath=logpath, reload=False)