
Commit

Merge pull request dvschultz#9 from JCBrouwer/main
Add top-k training and aydao's mega config
dvschultz authored Nov 21, 2020
2 parents 0d89d18 + eb24558 commit 0e47c2f
Showing 6 changed files with 53 additions and 6 deletions.
README.md (2 additions & 0 deletions)
@@ -11,6 +11,8 @@
* **Flesh Digressions**: @aydao’s circular constant layer script, edited to work with ADA; see `aydao_flesh_digressions.py`
* **Raw dataset creations**: Taken from the @skyflynil repo, reduces the size of datasets dramatically. Use `create_from images_raw` and `create_from image_folders_raw` in dataset creation, and use `--use-raw=True` in training (False by default!)
* **align faces script**: From @pbaylies, this script will align images better for projection.
* **top-k training**: Improve generator training by only propagating gradients from images the discriminator was most unsure of ([Sinha & Zhao](https://arxiv.org/abs/2002.06224)); a minimal sketch of the idea follows this file's diff.
* **@aydao's config**: Extra-large config for huge datasets (>100k images)

## StyleGAN2 with adaptive discriminator augmentation (ADA)<br>&mdash; Official TensorFlow implementation

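A minimal NumPy sketch of the top-k idea described in the bullet above (illustrative values, not code from this commit). It follows the selection rule this commit adds to `training/loss.py` further down: only the k lowest-scoring fakes contribute to the generator loss.

```python
import numpy as np

# Made-up discriminator scores (logits) for one minibatch of 8 generated images;
# higher means the discriminator thinks the image looks more real.
d_fake_scores = np.array([2.1, -0.3, 0.7, -1.5, 0.2, 1.8, -0.9, 0.05])

k_frac = 0.5                                   # fraction of the minibatch to keep
k = int(np.ceil(len(d_fake_scores) * k_frac))  # k = 4 here

# Keep only the k lowest-scoring fakes and compute the non-saturating
# generator loss, softplus(-score) = -log(sigmoid(score)), on those alone.
kept_scores = np.sort(d_fake_scores)[:k]
g_loss = np.logaddexp(0.0, -kept_scores).mean()

print(k, kept_scores, round(float(g_loss), 4))
```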
dataset_tool.py (7 additions & 1 deletion)
@@ -695,6 +695,8 @@ def create_from_images(tfrecord_dir, image_dir, shuffle):
        error('No input images found')

    img = np.asarray(PIL.Image.open(image_filenames[0]))
    print(img.shape)
    shape = img.shape
    resolution = img.shape[0]
    channels = img.shape[2] if img.ndim == 3 else 1
    if img.shape[1] != resolution:
@@ -709,9 +711,13 @@ def create_from_images(tfrecord_dir, image_dir, shuffle):
        for idx in range(order.size):
            img = np.asarray(PIL.Image.open(image_filenames[order[idx]]))
            if channels == 1:
                print("Greyscale, adding dimension:", image_filenames[order[idx]], img.shape)
                img = img[np.newaxis, :, :] # HW => CHW
            else:
                img = img.transpose([2, 0, 1]) # HWC => CHW
            if img.shape != shape:
                print("Wrong shape:", image_filenames[order[idx]], img.shape, "should be", shape)
                continue
            tfr.add_image(img)

#----------------------------------------------------------------------------
@@ -750,7 +756,7 @@ def create_from_images_raw(tfrecord_dir, image_dir, shuffle, resolution_log2=7,
        order = tfr.choose_shuffled_order() if shuffle else np.arange(len(image_filenames))
        tfr.create_tfr_writer(img.shape)
        for idx in range(order.size):
            print('loading: ' + image_filenames[order[idx]])
            # print('loading: ' + image_filenames[order[idx]])
            # img = np.asarray(PIL.Image.open(image_filenames[order[idx]]))
            # if (img.shape[1] != 1024) or (img.shape[0] != 1024):
            # error('Input images must have the same width and height')
dnnlib/tflib/network.py (7 additions & 0 deletions)
@@ -120,6 +120,7 @@ def _init_fields(self, name: str, static_kwargs: dict, build_func: Callable, bui
        self._trainables = None
        self._var_global_to_local = None
        self._run_cache = dict()
        self.epochs = tf.Variable(0., dtype=tf.float32, name='epochs')

    def _init_graph(self) -> None:
        assert self._var_inits is not None
@@ -537,6 +538,12 @@ def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx =
                    ops.append(var.assign(new_value))
            return tf.group(*ops)

    def update_epochs(self, epochs: TfExpressionEx = 0) -> tf.Operation:
        """Construct a TensorFlow op that updates the epoch counter of this network."""
        with tfutil.absolute_name_scope(self.scope + "/_Epochs"):
            op = self.epochs.assign(epochs)
        return op

    def run(self,
            *in_arrays: Tuple[Union[np.ndarray, None], ...],
            input_transform: dict = None,
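Storing the epoch counter as a graph variable with an assign op, rather than a Python float, lets the loss graph (which is built once) read a schedule value that the training loop keeps feeding in. A minimal standalone sketch of that pattern, assuming TensorFlow 1.x as used by this repo; the names below are illustrative, not the repo's:

```python
import tensorflow as tf  # TensorFlow 1.x, as used by this repo

epochs_var = tf.Variable(0., dtype=tf.float32, name='epochs')
epochs_in = tf.placeholder(tf.float32, shape=[], name='epochs_in')
update_op = epochs_var.assign(epochs_in)

# Anything built from epochs_var (e.g. a decay schedule like the one in
# training/loss.py) sees the new value the next time the graph is evaluated.
k_frac = tf.maximum(0.9726 ** epochs_var, 0.5)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(update_op, {epochs_in: 25.0})  # the training loop feeds the current epoch
    print(sess.run(k_frac))                 # ~0.5
```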
train.py (22 additions & 1 deletion)
@@ -172,6 +172,7 @@ def setup_training_options(

    cfg_specs = {
        'auto': dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=0.05, map=2), # populated dynamically based on 'gpus' and 'res'
        'aydao': dict(ref_gpus=2, kimg=25000, mb=16, mbstd=8, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8), # aydao's extra-large config for huge datasets; uses mixed-precision
        '11gb-gpu': dict(ref_gpus=1, kimg=25000, mb=4, mbstd=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8), # uses mixed-precision, 11GB GPU
        '11gb-gpu-complex': dict(ref_gpus=1, kimg=25000, mb=4, mbstd=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8), # uses mixed-precision, 11GB GPU
        '24gb-gpu': dict(ref_gpus=1, kimg=25000, mb=8, mbstd=8, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8), # uses mixed-precision, 24GB GPU
@@ -211,6 +212,26 @@ def setup_training_options(
        args.G_args.num_fp16_res = args.D_args.num_fp16_res = 4 # enable mixed-precision training
        args.G_args.conv_clamp = args.D_args.conv_clamp = 256 # clamp activations to avoid float16 overflow

    if cfg == 'aydao':
        # disable path length and style mixing regularization
        args.loss_args.pl_weight = 0
        args.G_args.style_mixing_prob = None

        # double generator capacity
        args.G_args.fmap_base = 32 << 10
        args.G_args.fmap_max = 1024

        # enable top-k training
        args.loss_args.G_top_k = True
        # args.loss_args.G_top_k_gamma = 0.99   # takes ~70% of full training from scratch to decay to 0.5
        # args.loss_args.G_top_k_gamma = 0.9862 # takes 12500 kimg to decay to 0.5 (~1/2 of total_kimg when training from scratch)
        args.loss_args.G_top_k_gamma = 0.9726   # takes 6250 kimg to decay to 0.5 (~1/4 of total_kimg when training from scratch)
        args.loss_args.G_top_k_frac = 0.5

        # reduce in-memory size; you need a BIG GPU for this model
        args.minibatch_gpu = 4       # will probably need to be set fairly low with such a large G, though higher values work better for top-k training
        args.G_args.num_fp16_res = 6 # making more layers fp16 can help as well

    if cfg == 'cifar' or cfg.split('-')[-1] == 'complex':
        args.loss_args.pl_weight = 0 # disable path length regularization
        args.G_args.style_mixing_prob = None # disable style mixing
@@ -560,7 +581,7 @@ def main():
    group.add_argument('--metricdata', help='Dataset to evaluate metrics against (optional)', metavar='PATH')

    group = parser.add_argument_group('base config')
    group.add_argument('--cfg', help='Base config (default: auto)', choices=['auto', '11gb-gpu','11gb-gpu-complex', '24gb-gpu','24gb-gpu-complex', '48gb-gpu', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar', 'cifarbaseline'])
    group.add_argument('--cfg', help='Base config (default: auto)', choices=['auto', '11gb-gpu','11gb-gpu-complex', '24gb-gpu','24gb-gpu-complex', '48gb-gpu', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar', 'cifarbaseline', 'aydao'])
    group.add_argument('--gamma', help='Override R1 gamma', type=float, metavar='FLOAT')
    group.add_argument('--kimg', help='Override training duration', type=int, metavar='INT')

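The kimg figures quoted in the commented-out `G_top_k_gamma` lines above follow from the epoch convention this commit adds in `training/training_loop.py` (`epochs = 100 * cur_nimg / (total_kimg * 1000)`, i.e. 100 top-k "epochs" over a full run). A quick arithmetic check, assuming the default `total_kimg = 25000` used by these configs:

```python
import math

TOTAL_KIMG = 25000  # default run length for these configs

def kimg_to_reach(gamma, target=0.5):
    """kimg at which gamma ** epochs first reaches `target`, using
    epochs = 100 * cur_nimg / (TOTAL_KIMG * 1000) from training_loop.py."""
    epochs_needed = math.log(target) / math.log(gamma)
    return epochs_needed / 100 * TOTAL_KIMG

for gamma in (0.99, 0.9862, 0.9726):
    print(gamma, round(kimg_to_reach(gamma)))
# 0.99   -> ~17200 kimg (~69% of the run)
# 0.9862 -> ~12500 kimg (about half the run)
# 0.9726 -> ~6240 kimg  (about a quarter of the run)
```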
training/loss.py (9 additions & 2 deletions)
@@ -88,7 +88,7 @@ def eval_D(D, aug, images, labels, report=None, augment_inputs=True, return_aux=
# Non-saturating logistic loss with R1 and path length regularizers, used
# in the paper "Analyzing and Improving the Image Quality of StyleGAN".

def stylegan2(G, D, aug, fake_labels, real_images, real_labels, r1_gamma=10, pl_minibatch_shrink=2, pl_decay=0.01, pl_weight=2, **_kwargs):
def stylegan2(G, D, aug, fake_labels, real_images, real_labels, r1_gamma=10, pl_minibatch_shrink=2, pl_decay=0.01, pl_weight=2, G_top_k = False, G_top_k_gamma = 0.9, G_top_k_frac = 0.5, **_kwargs):
    # Evaluate networks for the main loss.
    minibatch_size = tf.shape(fake_labels)[0]
    fake_latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
@@ -98,7 +98,14 @@ def stylegan2(G, D, aug, fake_labels, real_images, real_labels, r1_gamma=10, pl_

    # Non-saturating logistic loss from "Generative Adversarial Nets".
    with tf.name_scope('Loss_main'):
        G_loss = tf.nn.softplus(-D_fake.scores) # -log(sigmoid(D_fake.scores)), pylint: disable=invalid-unary-operand-type
        D_fake_scores = D_fake.scores
        if G_top_k:
            k_frac = tf.maximum(G_top_k_gamma ** G.epochs, G_top_k_frac)
            k = tf.cast(tf.ceil(tf.cast(minibatch_size, tf.float32) * k_frac), tf.int32)
            lowest_k_scores, _ = tf.nn.top_k(-tf.squeeze(D_fake_scores), k=k) # want smallest probabilities not largest
            D_fake_scores = tf.expand_dims(-lowest_k_scores, axis=1)
        G_loss = tf.nn.softplus(-D_fake_scores) # -log(sigmoid(D_fake_scores)), pylint: disable=invalid-unary-operand-type

        D_loss = tf.nn.softplus(D_fake.scores) # -log(1 - sigmoid(D_fake.scores))
        D_loss += tf.nn.softplus(-D_real.scores) # -log(sigmoid(D_real.scores)), pylint: disable=invalid-unary-operand-type
    G_reg = 0
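Two details of the block above that are easy to miss: TensorFlow 1.x has `tf.nn.top_k` but no bottom-k, so the code negates the scores, takes `top_k`, and negates back to get the k lowest; and the `squeeze`/`expand_dims` pair restores the `[k, 1]` column layout of `D_fake.scores`, while `D_loss` still sees every score. A small NumPy sketch of that selection with made-up values (not code from this commit):

```python
import numpy as np

# Made-up fake-image scores with the same [minibatch, 1] layout as D_fake.scores.
scores = np.array([[1.2], [-0.4], [0.3], [-2.0], [0.9], [-1.1]], dtype=np.float32)
k = 3

flat = np.squeeze(scores, axis=1)      # [minibatch]
neg_top_k = np.sort(-flat)[::-1][:k]   # what tf.nn.top_k(-flat, k).values would return
lowest_k = -neg_top_k                  # the k smallest scores: [-2., -1.1, -0.4]
kept = lowest_k[:, None]               # expand_dims(..., axis=1) -> shape [k, 1]

print(kept.shape, kept.ravel())        # (3, 1) [-2.  -1.1 -0.4]
```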
training/training_loop.py (6 additions & 2 deletions)
@@ -202,6 +202,8 @@ def training_loop(
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    Gs_beta_in = tf.placeholder(tf.float32, name='Gs_beta_in', shape=[])
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta_in)
    Gs_epochs = tf.placeholder(tf.float32, name='Gs_epochs', shape=[])
    Gs_epochs_op = Gs.update_epochs(Gs_epochs)
    tflib.init_uninitialized_vars()
    with tf.device('/gpu:0'):
        peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
@@ -234,6 +236,8 @@
            Gs_nimg = min(Gs_nimg, cur_nimg * G_smoothing_rampup)
        Gs_beta = 0.5 ** (minibatch_size / max(Gs_nimg, 1e-8))

        epochs = float(100 * cur_nimg / (total_kimg * 1000)) # 100 total top k "epochs" in total_kimg

        # Run training ops.
        for _repeat_idx in range(minibatch_repeats):
            rounds = range(0, minibatch_size, minibatch_gpu * num_gpus)
@@ -247,7 +251,7 @@
                tflib.run([G_train_op, data_fetch_op])
                if run_G_reg:
                    tflib.run(G_reg_op)
                tflib.run([D_train_op, Gs_update_op], {Gs_beta_in: Gs_beta})
                tflib.run([D_train_op, Gs_update_op, Gs_epochs_op], {Gs_beta_in: Gs_beta, Gs_epochs: epochs})
                if run_D_reg:
                    tflib.run(D_reg_op)

@@ -257,7 +261,7 @@
                    tflib.run(G_train_op)
                if run_G_reg:
                    tflib.run(G_reg_op)
                tflib.run(Gs_update_op, {Gs_beta_in: Gs_beta})
                tflib.run([Gs_update_op, Gs_epochs_op], {Gs_beta_in: Gs_beta, Gs_epochs: epochs})
                for _round in rounds:
                    tflib.run(data_fetch_op)
                    tflib.run(D_train_op)
