diff --git a/README.md b/README.md index 3180eaf2..38b6a279 100755 --- a/README.md +++ b/README.md @@ -11,6 +11,8 @@ * **Flesh Digressions**: @aydao’s circular constant layer script edited to work with ADA see aydao_flesh_digressions.py * **Raw dataset creations**: Taken from the @skyflynil repo, reduces the size of datasets dramatically. Use `create_from images_raw` and `create_from image_folders_raw` in dataset creation, and use `--use-raw=True` in training (False by default!) * **align faces script**: From @pbaylies, this script will align images better for projection. +* **top-k training**: Improve generator training by only propagating gradients from images the discriminator was most unsure of: [Sinha & Zhao](https://arxiv.org/abs/2002.06224). +* **@aydao's config**: Extra large config for huge datasets (>100k img) ## StyleGAN2 with adaptive discriminator augmentation (ADA)
— Official TensorFlow implementation diff --git a/dataset_tool.py b/dataset_tool.py index ab224a17..8a22a720 100755 --- a/dataset_tool.py +++ b/dataset_tool.py @@ -756,7 +756,7 @@ def create_from_images_raw(tfrecord_dir, image_dir, shuffle, resolution_log2=7, order = tfr.choose_shuffled_order() if shuffle else np.arange(len(image_filenames)) tfr.create_tfr_writer(img.shape) for idx in range(order.size): - print('loading: ' + image_filenames[order[idx]]) + # print('loading: ' + image_filenames[order[idx]]) # img = np.asarray(PIL.Image.open(image_filenames[order[idx]])) # if (img.shape[1] != 1024) or (img.shape[0] != 1024): # error('Input images must have the same width and height') diff --git a/dnnlib/tflib/network.py b/dnnlib/tflib/network.py index ff0c169e..45bb5f9c 100755 --- a/dnnlib/tflib/network.py +++ b/dnnlib/tflib/network.py @@ -120,6 +120,7 @@ def _init_fields(self, name: str, static_kwargs: dict, build_func: Callable, bui self._trainables = None self._var_global_to_local = None self._run_cache = dict() + self.epochs = tf.Variable(0., dtype=tf.float32, name='epochs') def _init_graph(self) -> None: assert self._var_inits is not None @@ -537,6 +538,12 @@ def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = ops.append(var.assign(new_value)) return tf.group(*ops) + def update_epochs(self, epochs: TfExpressionEx = 0) -> tf.Operation: + """Construct a TensorFlow op that updates the epoch counter of this network.""" + with tfutil.absolute_name_scope(self.scope + "/_Epochs"): + op = self.epochs.assign(epochs) + return op + def run(self, *in_arrays: Tuple[Union[np.ndarray, None], ...], input_transform: dict = None, diff --git a/train.py b/train.py index ba12c4e2..bdd907df 100755 --- a/train.py +++ b/train.py @@ -172,7 +172,7 @@ def setup_training_options( cfg_specs = { 'auto': dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=0.05, map=2), # populated dynamically based on 'gpus' and 'res' - 'wav': dict(ref_gpus=2, kimg=25000, mb=8, mbstd=8, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8), # uses mixed-precision, 11GB GPU + 'aydao': dict(ref_gpus=2, kimg=25000, mb=16, mbstd=8, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8), # uses mixed-precision, 11GB GPU '11gb-gpu': dict(ref_gpus=1, kimg=25000, mb=4, mbstd=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8), # uses mixed-precision, 11GB GPU '11gb-gpu-complex': dict(ref_gpus=1, kimg=25000, mb=4, mbstd=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8), # uses mixed-precision, 11GB GPU '24gb-gpu': dict(ref_gpus=1, kimg=25000, mb=8, mbstd=8, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8), # uses mixed-precision, 24GB GPU @@ -212,14 +212,17 @@ def setup_training_options( args.G_args.num_fp16_res = args.D_args.num_fp16_res = 4 # enable mixed-precision training args.G_args.conv_clamp = args.D_args.conv_clamp = 256 # clamp activations to avoid float16 overflow - if cfg == 'wav': + if cfg == 'aydao': args.loss_args.pl_weight = 0 args.G_args.style_mixing_prob = None - - fmap_base = 32 << 10 - fmap_max = 1024 - args.G_args.fmap_base = args.D_args.fmap_base = fmap_base - args.G_args.fmap_max = args.D_args.fmap_max = fmap_max + args.G_args.fmap_base = 32 << 10 + args.G_args.fmap_max = 1024 + args.loss_args.G_top_k = True + # args.loss_args.G_top_k_gamma = 0.9862 # takes 12500 kimg to decay to 0.5 + args.loss_args.G_top_k_gamma = 0.9726 # takes 6250 kimg to decay to 0.5 + args.loss_args.G_top_k_frac = 0.5 + args.minibatch_gpu = 2 # probably will need to set this pretty low with such a large G + # args.G_args.num_fp16_res = 6 # making more layers fp16 can help as well if cfg == 'cifar' or cfg.split('-')[-1] == 'complex': args.loss_args.pl_weight = 0 # disable path length regularization @@ -570,7 +573,7 @@ def main(): group.add_argument('--metricdata', help='Dataset to evaluate metrics against (optional)', metavar='PATH') group = parser.add_argument_group('base config') - group.add_argument('--cfg', help='Base config (default: auto)', choices=['auto', '11gb-gpu','11gb-gpu-complex', '24gb-gpu','24gb-gpu-complex', '48gb-gpu', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar', 'cifarbaseline', 'wav']) + group.add_argument('--cfg', help='Base config (default: auto)', choices=['auto', '11gb-gpu','11gb-gpu-complex', '24gb-gpu','24gb-gpu-complex', '48gb-gpu', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar', 'cifarbaseline', 'aydao']) group.add_argument('--gamma', help='Override R1 gamma', type=float, metavar='FLOAT') group.add_argument('--kimg', help='Override training duration', type=int, metavar='INT') diff --git a/training/loss.py b/training/loss.py index 9d819d29..12e0660f 100755 --- a/training/loss.py +++ b/training/loss.py @@ -88,7 +88,7 @@ def eval_D(D, aug, images, labels, report=None, augment_inputs=True, return_aux= # Non-saturating logistic loss with R1 and path length regularizers, used # in the paper "Analyzing and Improving the Image Quality of StyleGAN". -def stylegan2(G, D, aug, fake_labels, real_images, real_labels, r1_gamma=10, pl_minibatch_shrink=2, pl_decay=0.01, pl_weight=2, **_kwargs): +def stylegan2(G, D, aug, fake_labels, real_images, real_labels, r1_gamma=10, pl_minibatch_shrink=2, pl_decay=0.01, pl_weight=2, G_top_k = False, G_top_k_gamma = 0.9, G_top_k_frac = 0.5, **_kwargs): # Evaluate networks for the main loss. minibatch_size = tf.shape(fake_labels)[0] fake_latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) @@ -98,7 +98,14 @@ def stylegan2(G, D, aug, fake_labels, real_images, real_labels, r1_gamma=10, pl_ # Non-saturating logistic loss from "Generative Adversarial Nets". with tf.name_scope('Loss_main'): - G_loss = tf.nn.softplus(-D_fake.scores) # -log(sigmoid(D_fake.scores)), pylint: disable=invalid-unary-operand-type + D_fake_scores = D_fake.scores + if G_top_k: + k_frac = tf.maximum(G_top_k_gamma ** G.epochs, G_top_k_frac) + k = tf.cast(tf.ceil(tf.cast(minibatch_size, tf.float32) * k_frac), tf.int32) + lowest_k_scores, _ = tf.nn.top_k(-tf.squeeze(D_fake_scores), k=k) # want smallest probabilities not largest + D_fake_scores = tf.expand_dims(-lowest_k_scores, axis=1) + G_loss = tf.nn.softplus(-D_fake_scores) # -log(sigmoid(D_fake_scores)), pylint: disable=invalid-unary-operand-type + D_loss = tf.nn.softplus(D_fake.scores) # -log(1 - sigmoid(D_fake.scores)) D_loss += tf.nn.softplus(-D_real.scores) # -log(sigmoid(D_real.scores)), pylint: disable=invalid-unary-operand-type G_reg = 0 diff --git a/training/training_loop.py b/training/training_loop.py index 7b5e2b3b..7e709776 100755 --- a/training/training_loop.py +++ b/training/training_loop.py @@ -202,6 +202,8 @@ def training_loop( D_reg_op = D_reg_opt.apply_updates(allow_no_op=True) Gs_beta_in = tf.placeholder(tf.float32, name='Gs_beta_in', shape=[]) Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta_in) + Gs_epochs = tf.placeholder(tf.float32, name='Gs_epochs', shape=[]) + Gs_epochs_op = Gs.update_epochs(Gs_epochs) tflib.init_uninitialized_vars() with tf.device('/gpu:0'): peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() @@ -234,6 +236,8 @@ def training_loop( Gs_nimg = min(Gs_nimg, cur_nimg * G_smoothing_rampup) Gs_beta = 0.5 ** (minibatch_size / max(Gs_nimg, 1e-8)) + epochs = float(100 * cur_nimg / (total_kimg * 1000)) # 100 total top k "epochs" in total_kimg + # Run training ops. for _repeat_idx in range(minibatch_repeats): rounds = range(0, minibatch_size, minibatch_gpu * num_gpus) @@ -247,7 +251,7 @@ def training_loop( tflib.run([G_train_op, data_fetch_op]) if run_G_reg: tflib.run(G_reg_op) - tflib.run([D_train_op, Gs_update_op], {Gs_beta_in: Gs_beta}) + tflib.run([D_train_op, Gs_update_op, Gs_epochs_op], {Gs_beta_in: Gs_beta, Gs_epochs: epochs}) if run_D_reg: tflib.run(D_reg_op) @@ -257,7 +261,7 @@ def training_loop( tflib.run(G_train_op) if run_G_reg: tflib.run(G_reg_op) - tflib.run(Gs_update_op, {Gs_beta_in: Gs_beta}) + tflib.run([Gs_update_op, Gs_epochs_op], {Gs_beta_in: Gs_beta, Gs_epochs: epochs}) for _round in rounds: tflib.run(data_fetch_op) tflib.run(D_train_op)