forked from NVlabs/stylegan2
-
Notifications
You must be signed in to change notification settings - Fork 33
/
run_training.py
executable file
·221 lines (177 loc) · 10.2 KB
/
run_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html
import argparse
import copy
import os
import sys
import dnnlib
from dnnlib import EasyDict
from metrics.metric_defaults import metric_defaults
#----------------------------------------------------------------------------
# Every value accepted by --config. The names are matched verbatim (and by
# substring for the 'config-e-G*-D*' variants) inside run().
_valid_configs = [
    # Table 1
    'config-a',  # Baseline StyleGAN
    'config-b',  # + Weight demodulation
    'config-c',  # + Lazy regularization
    'config-d',  # + Path length regularization
    'config-e',  # + No growing, new G & D arch.
    'config-f',  # + Large networks (default)

    # Table 2: config-e with each generator / discriminator architecture pairing.
    'config-e-Gorig-Dorig',
    'config-e-Gorig-Dresnet',
    'config-e-Gorig-Dskip',
    'config-e-Gresnet-Dorig',
    'config-e-Gresnet-Dresnet',
    'config-e-Gresnet-Dskip',
    'config-e-Gskip-Dorig',
    'config-e-Gskip-Dresnet',
    'config-e-Gskip-Dskip',
]
#----------------------------------------------------------------------------
def run(dataset, data_dir, result_dir, config_id, num_gpus, total_kimg, gamma, mirror_augment, mirror_augment_v, metrics, min_h, min_w, res_log2, lr, use_attention, resume_with_new_nets, glr, dlr, use_raw):
    """Assemble the training configuration and pass it to dnnlib.submit_run().

    Args:
        dataset: Dataset name; used as ``tfrecord_dir`` and in the run description.
        data_dir: Dataset root directory, forwarded to the training loop.
        result_dir: Root directory for run results (``run_dir_root``).
        config_id: Training config; must be one of ``_valid_configs``.
        num_gpus: Number of GPUs; must be 1, 2, 4 or 8.
        total_kimg: Training length in thousands of images.
        gamma: R1 regularization weight, or None to keep the config-dependent value.
        mirror_augment: Forwarded to the training loop as ``mirror_augment``.
        mirror_augment_v: Forwarded to the training loop as ``mirror_augment_v``.
        metrics: Iterable of keys into ``metric_defaults``.
        min_h: Stored as ``min_h`` on the G, D and dataset options.
        min_w: Stored as ``min_w`` on the G, D and dataset options.
        res_log2: Stored as ``res_log2`` on the G, D and dataset options.
        lr: Base learning rate for both G and D.
        use_attention: When true, sets ``use_attention`` on G and D and tags the run description.
        resume_with_new_nets: Forwarded to the training loop.
        glr: Base learning rate override for G, or None to use ``lr``.
        dlr: Base learning rate override for D, or None to use ``lr``.
        use_raw: Stored as ``use_raw`` on the dataset options.

    Raises:
        AssertionError: If ``num_gpus`` or ``config_id`` is not an accepted value.
        KeyError: If an entry of ``metrics`` is missing from ``metric_defaults``.
    """
    train     = EasyDict(run_func_name='training.training_loop.training_loop') # Options for training loop.
    G         = EasyDict(func_name='training.networks_stylegan2.G_main')       # Options for generator network.
    D         = EasyDict(func_name='training.networks_stylegan2.D_stylegan2')  # Options for discriminator network.
    G_opt     = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)                  # Options for generator optimizer.
    D_opt     = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)                  # Options for discriminator optimizer.
    G_loss    = EasyDict(func_name='training.loss.G_logistic_ns_pathreg')      # Options for generator loss.
    D_loss    = EasyDict(func_name='training.loss.D_logistic_r1')              # Options for discriminator loss.
    sched     = EasyDict()                                                     # Options for TrainingSchedule.
    grid      = EasyDict(size='8k', layout='random')                           # Options for setup_snapshot_image_grid().
    sc        = dnnlib.SubmitConfig()                                          # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}                                   # Options for tflib.init_tf().

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.mirror_augment_v = mirror_augment_v
    train.resume_with_new_nets = resume_with_new_nets
    train.image_snapshot_ticks = 1
    train.network_snapshot_ticks = 4

    sched.G_lrate_base = sched.D_lrate_base = lr
    # Compare against None rather than truthiness so that an explicit 0.0
    # (i.e. freezing one of the networks) is honoured instead of ignored.
    if glr is not None:
        sched.G_lrate_base = glr
    if dlr is not None:
        sched.D_lrate_base = dlr
    sched.minibatch_size_base = 32
    sched.minibatch_gpu_base = 4
    D_loss.gamma = 10
    metrics = [metric_defaults[x] for x in metrics]
    desc = 'stylegan2'

    desc += '-' + dataset
    dataset_args = EasyDict(tfrecord_dir=dataset)
    dataset_args.use_raw = use_raw
    G.min_h = D.min_h = dataset_args.min_h = min_h
    G.min_w = D.min_w = dataset_args.min_w = min_w
    G.res_log2 = D.res_log2 = dataset_args.res_log2 = res_log2

    if use_attention:
        desc += '-attention'
        G.use_attention = True
        D.use_attention = True

    assert num_gpus in [1, 2, 4, 8], 'num_gpus must be one of 1, 2, 4, 8'
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus

    assert config_id in _valid_configs, 'unknown config_id: %r' % (config_id,)
    desc += '-' + config_id

    # Configs A-E: Shrink networks to match original StyleGAN.
    if config_id != 'config-f':
        G.fmap_base = D.fmap_base = 8 << 10

    # Config E: Set gamma to 100 and override G & D architecture.
    if config_id.startswith('config-e'):
        D_loss.gamma = 100
        if 'Gorig'   in config_id: G.architecture = 'orig'
        if 'Gskip'   in config_id: G.architecture = 'skip' # (default)
        if 'Gresnet' in config_id: G.architecture = 'resnet'
        if 'Dorig'   in config_id: D.architecture = 'orig'
        if 'Dskip'   in config_id: D.architecture = 'skip'
        if 'Dresnet' in config_id: D.architecture = 'resnet' # (default)

    # Configs A-D: Enable progressive growing and switch to networks that support it.
    # NOTE: this replaces the base learning rates set above, so lr / glr / dlr
    # have no effect for these configs.
    if config_id in ['config-a', 'config-b', 'config-c', 'config-d']:
        sched.lod_initial_resolution = 8
        sched.G_lrate_base = sched.D_lrate_base = 0.001
        sched.G_lrate_dict = sched.D_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
        sched.minibatch_size_base = 32 # (default)
        sched.minibatch_size_dict = {8: 256, 16: 128, 32: 64, 64: 32}
        sched.minibatch_gpu_base = 4 # (default)
        sched.minibatch_gpu_dict = {8: 32, 16: 16, 32: 8, 64: 4}
        G.synthesis_func = 'G_synthesis_stylegan_revised'
        D.func_name = 'training.networks_stylegan2.D_stylegan'

    # Configs A-C: Disable path length regularization.
    if config_id in ['config-a', 'config-b', 'config-c']:
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns')

    # Configs A-B: Disable lazy regularization.
    if config_id in ['config-a', 'config-b']:
        train.lazy_regularization = False

    # Config A: Switch to original StyleGAN networks.
    # NOTE: G and D are replaced wholesale here, so every option stored on them
    # earlier (min_h, min_w, res_log2, use_attention, fmap_base, ...) is dropped.
    if config_id == 'config-a':
        G = EasyDict(func_name='training.networks_stylegan.G_style')
        D = EasyDict(func_name='training.networks_stylegan.D_basic')

    # An explicit gamma always wins over the config-dependent value.
    if gamma is not None:
        D_loss.gamma = gamma

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt, G_loss_args=G_loss, D_loss_args=D_loss)
    kwargs.update(dataset_args=dataset_args, sched_args=sched, grid_args=grid, metric_arg_list=metrics, tf_config=tf_config)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
#----------------------------------------------------------------------------
def _str_to_bool(v):
    """Convert a command-line token into a bool (for use as an argparse ``type``).

    Real booleans are returned unchanged. Strings are compared
    case-insensitively: 'yes', 'true', 't', 'y', '1' give True and
    'no', 'false', 'f', 'n', '0' give False.

    Raises:
        argparse.ArgumentTypeError: If the string is none of the above.
    """
    if isinstance(v, bool):
        return v

    token = v.lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def _parse_comma_sep(s):
    """Split a comma-separated option value into a list of items.

    Returns an empty list for None, the empty string, or the word 'none'
    (case-insensitive). Whitespace around each item is removed and empty
    items (from trailing or doubled commas) are dropped, so that
    'fid50k, ppl_wend,' yields ['fid50k', 'ppl_wend'] instead of entries
    that can never match a known name.
    """
    if s is None:
        return []
    s = s.strip()
    if s == '' or s.lower() == 'none':
        return []
    return [item.strip() for item in s.split(',') if item.strip()]
#----------------------------------------------------------------------------
# Epilog shown by --help: a usage example followed by the accepted --config
# values and the known metric names (sorted alphabetically).
_examples = '''examples:
# Train StyleGAN2 using the FFHQ dataset
python %(prog)s --num-gpus=8 --data-dir=~/datasets --config=config-f --dataset=ffhq --mirror-augment=true
valid configs:
''' + ', '.join(_valid_configs) + '''
valid metrics:
''' + ', '.join(sorted(metric_defaults.keys())) + '''
'''
def main():
    """Parse the command line, validate it, and start training through run().

    Prints an error and exits with status 1 when the dataset root directory
    does not exist, when --config is not one of ``_valid_configs``, or when a
    requested metric is not a key of ``metric_defaults``.
    """
    parser = argparse.ArgumentParser(
        description='Train StyleGAN2.',
        epilog=_examples,
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument('--result-dir', help='Root directory for run results (default: %(default)s)', default='results', metavar='DIR')
    parser.add_argument('--data-dir', help='Dataset root directory', required=True)
    parser.add_argument('--dataset', help='Training dataset', required=True)
    # Not marked required: combining required=True with a default made the
    # advertised default ('config-f') impossible to use.
    parser.add_argument('--config', help='Training config (default: %(default)s)', default='config-f', dest='config_id', metavar='CONFIG')
    parser.add_argument('--num-gpus', help='Number of GPUs (default: %(default)s)', default=1, type=int, metavar='N')
    parser.add_argument('--total-kimg', help='Training length in thousands of images (default: %(default)s)', metavar='KIMG', default=25000, type=int)
    parser.add_argument('--gamma', help='R1 regularization weight (default is config dependent)', default=None, type=float)
    parser.add_argument('--mirror-augment', help='Mirror augment (default: %(default)s)', default=False, metavar='BOOL', type=_str_to_bool)
    parser.add_argument('--mirror-augment-v', help='Mirror augment vertically (default: %(default)s)', default=False, metavar='BOOL', type=_str_to_bool)
    parser.add_argument('--metrics', help='Comma-separated list of metrics or "none" (default: %(default)s)', default='fid50k', type=_parse_comma_sep)
    parser.add_argument('--min-h', help='lowest dim of height', default=4, type=int)
    parser.add_argument('--min-w', help='lowest dim of width', default=4, type=int)
    parser.add_argument('--res-log2', help='multiplier for image size, the training image size (height, width) should be (min_h * 2**res_log2, min_w * 2**res_log2)', default=7, type=int)
    parser.add_argument('--lr', help='base learning rate', default=0.003, type=float)
    parser.add_argument('--glr', help='overwrite base learning rate for G', default=None, type=float)
    parser.add_argument('--dlr', help='overwrite base learning rate for D', default=None, type=float)
    parser.add_argument('--use-raw', help='Use raw image dataset, i.e. created from create_from_images_raw (default: %(default)s)', default=True, metavar='BOOL', type=_str_to_bool)
    parser.add_argument('--use-attention', help='Experimental: Use google attention (default: %(default)s)', default=False, metavar='BOOL', type=_str_to_bool)
    # The dashed spelling matches every other flag; the underscore spelling is
    # kept as an alias so existing command lines keep working.
    parser.add_argument('--resume-with-new-nets', '--resume_with_new_nets', dest='resume_with_new_nets', help='Experimental: Copy from checkpoint instead of direct load, useful for network structure modification (default: %(default)s)', default=False, metavar='BOOL', type=_str_to_bool)
    args = parser.parse_args()

    # The shell does not expand '~' in '--data-dir=~/datasets' (the form used
    # in the usage example), so expand it here before checking the path.
    args.data_dir = os.path.expanduser(args.data_dir)

    if not os.path.exists(args.data_dir):
        print ('Error: dataset root directory does not exist.')
        sys.exit(1)

    if args.config_id not in _valid_configs:
        print ('Error: --config value must be one of: ', ', '.join(_valid_configs))
        sys.exit(1)

    for metric in args.metrics:
        if metric not in metric_defaults:
            print ('Error: unknown metric \'%s\'' % metric)
            sys.exit(1)

    # The argparse dest names match run()'s parameter names one-to-one.
    run(**vars(args))
#----------------------------------------------------------------------------
# Script entry point: only start training when executed directly, not on import.
if __name__ == '__main__':
    main()
#----------------------------------------------------------------------------