NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA)
+
+
+
+
1. Definitions
+
+
“Licensor” means any person or entity that distributes its Work.
+
+
“Software” means the original work of authorship made available under
+this License.
+
+
“Work” means the Software and any additions to or derivative works of
+the Software that are made available under this License.
+
+
The terms “reproduce,” “reproduction,” “derivative works,” and
+“distribution” have the meaning as provided under U.S. copyright law;
+provided, however, that for the purposes of this License, derivative
+works shall not include works that remain separable from, or merely
+link (or bind by name) to the interfaces of, the Work.
+
+
Works, including the Software, are “made available” under this License
+by including in or with the Work either (a) a copyright notice
+referencing the applicability of this License to the Work, or (b) a
+copy of this License.
+
+
2. License Grants
+
+
2.1 Copyright Grant. Subject to the terms and conditions of this
+License, each Licensor grants to you a perpetual, worldwide,
+non-exclusive, royalty-free, copyright license to reproduce,
+prepare derivative works of, publicly display, publicly perform,
+sublicense and distribute its Work and any resulting derivative
+works in any form.
+
+
3. Limitations
+
+
3.1 Redistribution. You may reproduce or distribute the Work only
+if (a) you do so under this License, (b) you include a complete
+copy of this License with your distribution, and (c) you retain
+without modification any copyright, patent, trademark, or
+attribution notices that are present in the Work.
+
+
3.2 Derivative Works. You may specify that additional or different
+terms apply to the use, reproduction, and distribution of your
+derivative works of the Work (“Your Terms”) only if (a) Your Terms
+provide that the use limitation in Section 3.3 applies to your
+derivative works, and (b) you identify the specific derivative
+works that are subject to Your Terms. Notwithstanding Your Terms,
+this License (including the redistribution requirements in Section
+3.1) will continue to apply to the Work itself.
+
+
3.3 Use Limitation. The Work and any derivative works thereof only may
+be used or intended for use non-commercially. Notwithstanding the
+foregoing, NVIDIA and its affiliates may use the Work and any
+derivative works commercially. As used herein, “non-commercially”
+means for research or evaluation purposes only.
+
+
3.4 Patent Claims. If you bring or threaten to bring a patent claim
+against any Licensor (including any claim, cross-claim or
+counterclaim in a lawsuit) to enforce any patents that you allege
+are infringed by any Work, then your rights under this License from
+such Licensor (including the grant in Section 2.1) will terminate immediately.
+
+
3.5 Trademarks. This License does not grant any rights to use any
+Licensor’s or its affiliates’ names, logos, or trademarks, except
+as necessary to reproduce the notices described in this License.
+
+
3.6 Termination. If you violate any term of this License, then your
+rights under this License (including the grant in Section 2.1)
+will terminate immediately.
+
+
4. Disclaimer of Warranty.
+
+
THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
+NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
+THIS LICENSE.
+
+
5. Limitation of Liability.
+
+
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
+THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
+SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
+INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
+OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
+(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
+LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
+COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
+THE POSSIBILITY OF SUCH DAMAGES.
+
+
+
+
+
+
diff --git a/docs/stylegan2-ada-teaser-1024x252.png b/docs/stylegan2-ada-teaser-1024x252.png
new file mode 100755
index 00000000..14eb641b
Binary files /dev/null and b/docs/stylegan2-ada-teaser-1024x252.png differ
diff --git a/docs/stylegan2-ada-training-curves.png b/docs/stylegan2-ada-training-curves.png
new file mode 100755
index 00000000..94cbb5e5
Binary files /dev/null and b/docs/stylegan2-ada-training-curves.png differ
diff --git a/docs/train-help.txt b/docs/train-help.txt
new file mode 100755
index 00000000..0f1a9b2f
--- /dev/null
+++ b/docs/train-help.txt
@@ -0,0 +1,89 @@
+usage: train.py [-h] --outdir DIR [--gpus INT] [--snap INT] [--seed INT] [-n]
+ --data PATH [--res INT] [--mirror BOOL] [--metrics LIST]
+ [--metricdata PATH]
+ [--cfg {auto,stylegan2,paper256,paper512,paper1024,cifar,cifarbaseline}]
+ [--gamma FLOAT] [--kimg INT] [--aug {noaug,ada,fixed,adarv}]
+ [--p FLOAT] [--target TARGET]
+ [--augpipe {blit,geom,color,filter,noise,cutout,bg,bgc,bgcf,bgcfn,bgcfnc}]
+ [--cmethod {nocmethod,bcr,zcr,pagan,wgangp,auxrot,spectralnorm,shallowmap,adropout}]
+ [--dcap FLOAT] [--resume RESUME] [--freezed INT]
+
+Train a GAN using the techniques described in the paper
+"Training Generative Adversarial Networks with Limited Data".
+
+optional arguments:
+ -h, --help show this help message and exit
+
+general options:
+ --outdir DIR Where to save the results (required)
+ --gpus INT Number of GPUs to use (default: 1 gpu)
+ --snap INT Snapshot interval (default: 50 ticks)
+ --seed INT Random seed (default: 1000)
+ -n, --dry-run Print training options and exit
+
+training dataset:
+ --data PATH Training dataset path (required)
+ --res INT Dataset resolution (default: highest available)
+ --mirror BOOL Augment dataset with x-flips (default: false)
+
+metrics:
+ --metrics LIST Comma-separated list or "none" (default: fid50k_full)
+ --metricdata PATH Dataset to evaluate metrics against (optional)
+
+base config:
+ --cfg {auto,stylegan2,paper256,paper512,paper1024,cifar,cifarbaseline}
+ Base config (default: auto)
+ --gamma FLOAT Override R1 gamma
+ --kimg INT Override training duration
+
+discriminator augmentation:
+ --aug {noaug,ada,fixed,adarv}
+ Augmentation mode (default: ada)
+ --p FLOAT Specify augmentation probability for --aug=fixed
+ --target TARGET Override ADA target for --aug=ada and --aug=adarv
+ --augpipe {blit,geom,color,filter,noise,cutout,bg,bgc,bgcf,bgcfn,bgcfnc}
+ Augmentation pipeline (default: bgc)
+
+comparison methods:
+ --cmethod {nocmethod,bcr,zcr,pagan,wgangp,auxrot,spectralnorm,shallowmap,adropout}
+ Comparison method (default: nocmethod)
+ --dcap FLOAT Multiplier for discriminator capacity
+
+transfer learning:
+ --resume RESUME Resume from network pickle (default: noresume)
+ --freezed INT Freeze-D (default: 0 discriminator layers)
+
+examples:
+
+ # Train custom dataset using 1 GPU.
+ python train.py --outdir=~/training-runs --gpus=1 --data=~/datasets/custom
+
+ # Train class-conditional CIFAR-10 using 2 GPUs.
+ python train.py --outdir=~/training-runs --gpus=2 --data=~/datasets/cifar10c \
+ --cfg=cifar
+
+ # Transfer learn MetFaces from FFHQ using 4 GPUs.
+ python train.py --outdir=~/training-runs --gpus=4 --data=~/datasets/metfaces \
+ --cfg=paper1024 --mirror=1 --resume=ffhq1024 --snap=10
+
+ # Reproduce original StyleGAN2 config F.
+ python train.py --outdir=~/training-runs --gpus=8 --data=~/datasets/ffhq \
+ --cfg=stylegan2 --res=1024 --mirror=1 --aug=noaug
+
+available base configs (--cfg):
+ auto Automatically select reasonable defaults based on resolution
+ and GPU count. Good starting point for new datasets.
+ stylegan2 Reproduce results for StyleGAN2 config F at 1024x1024.
+ paper256 Reproduce results for FFHQ and LSUN Cat at 256x256.
+ paper512 Reproduce results for BreCaHAD and AFHQ at 512x512.
+ paper1024 Reproduce results for MetFaces at 1024x1024.
+ cifar Reproduce results for CIFAR-10 (tuned configuration).
+ cifarbaseline Reproduce results for CIFAR-10 (baseline configuration).
+
+transfer learning source networks (--resume):
+ ffhq256 FFHQ trained at 256x256 resolution.
+ ffhq512 FFHQ trained at 512x512 resolution.
+ ffhq1024 FFHQ trained at 1024x1024 resolution.
+ celebahq256 CelebA-HQ trained at 256x256 resolution.
+ lsundog256 LSUN Dog trained at 256x256 resolution.
+  <path or URL>  Custom network pickle.
diff --git a/generate.py b/generate.py
new file mode 100755
index 00000000..42210a5a
--- /dev/null
+++ b/generate.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Generate images using pretrained network pickle."""
+
+import argparse
+import os
+import pickle
+import re
+
+import numpy as np
+import PIL.Image
+
+import dnnlib
+import dnnlib.tflib as tflib
+
+#----------------------------------------------------------------------------
+
+def generate_images(network_pkl, seeds, truncation_psi, outdir, class_idx, dlatents_npz):
+ tflib.init_tf()
+ print('Loading networks from "%s"...' % network_pkl)
+ with dnnlib.util.open_url(network_pkl) as fp:
+ _G, _D, Gs = pickle.load(fp)
+
+ os.makedirs(outdir, exist_ok=True)
+
+ # Render images for a given dlatent vector.
+ if dlatents_npz is not None:
+ print(f'Generating images from dlatents file "{dlatents_npz}"')
+ dlatents = np.load(dlatents_npz)['dlatents']
+ assert dlatents.shape[1:] == (18, 512) # [N, 18, 512]
+ imgs = Gs.components.synthesis.run(dlatents, output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True))
+ for i, img in enumerate(imgs):
+            fname = f'{outdir}/dlatent{i:02d}.png'
+            PIL.Image.fromarray(img, 'RGB').save(fname)
+            print(f'Saved {fname}')
+ return
+
+ # Render images for dlatents initialized from random seeds.
+ Gs_kwargs = {
+ 'output_transform': dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
+ 'randomize_noise': False
+ }
+ if truncation_psi is not None:
+ Gs_kwargs['truncation_psi'] = truncation_psi
+
+ noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]
+ label = np.zeros([1] + Gs.input_shapes[1][1:])
+ if class_idx is not None:
+ label[:, class_idx] = 1
+
+ for seed_idx, seed in enumerate(seeds):
+ print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
+ rnd = np.random.RandomState(seed)
+ z = rnd.randn(1, *Gs.input_shape[1:]) # [minibatch, component]
+ tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars}) # [height, width]
+ images = Gs.run(z, label, **Gs_kwargs) # [minibatch, height, width, channel]
+ PIL.Image.fromarray(images[0], 'RGB').save(f'{outdir}/seed{seed:04d}.png')
+
+#----------------------------------------------------------------------------
+
+def _parse_num_range(s):
+ '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
+
+ range_re = re.compile(r'^(\d+)-(\d+)$')
+ m = range_re.match(s)
+ if m:
+ return list(range(int(m.group(1)), int(m.group(2))+1))
+ vals = s.split(',')
+ return [int(x) for x in vals]
+
+#----------------------------------------------------------------------------
+
+_examples = '''examples:
+
+ # Generate curated MetFaces images without truncation (Fig.10 left)
+ python %(prog)s --outdir=out --trunc=1 --seeds=85,265,297,849 \\
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metfaces.pkl
+
+ # Generate uncurated MetFaces images with truncation (Fig.12 upper left)
+ python %(prog)s --outdir=out --trunc=0.7 --seeds=600-605 \\
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metfaces.pkl
+
+ # Generate class conditional CIFAR-10 images (Fig.17 left, Car)
+ python %(prog)s --outdir=out --trunc=1 --seeds=0-35 --class=1 \\
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/cifar10.pkl
+
+ # Render image from projected latent vector
+ python %(prog)s --outdir=out --dlatents=out/dlatents.npz \\
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/ffhq.pkl
+'''
+
+#----------------------------------------------------------------------------
+
+def main():
+ parser = argparse.ArgumentParser(
+ description='Generate images using pretrained network pickle.',
+ epilog=_examples,
+ formatter_class=argparse.RawDescriptionHelpFormatter
+ )
+
+ parser.add_argument('--network', help='Network pickle filename', dest='network_pkl', required=True)
+ g = parser.add_mutually_exclusive_group(required=True)
+ g.add_argument('--seeds', type=_parse_num_range, help='List of random seeds')
+ g.add_argument('--dlatents', dest='dlatents_npz', help='Generate images for saved dlatents')
+ parser.add_argument('--trunc', dest='truncation_psi', type=float, help='Truncation psi (default: %(default)s)', default=0.5)
+ parser.add_argument('--class', dest='class_idx', type=int, help='Class label (default: unconditional)')
+ parser.add_argument('--outdir', help='Where to save the output images', required=True, metavar='DIR')
+
+ args = parser.parse_args()
+ generate_images(**vars(args))
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ main()
+
+#----------------------------------------------------------------------------
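
Note: the --seeds argument above accepts either an inclusive range "a-c" or a comma-separated list "a,b,c". A minimal sketch of the parser's behavior, reusing _parse_num_range from generate.py (assumes the repository root is on the Python path):

    from generate import _parse_num_range

    # Range form expands inclusively; list form splits on commas.
    assert _parse_num_range('600-605') == [600, 601, 602, 603, 604, 605]
    assert _parse_num_range('85,265,297,849') == [85, 265, 297, 849]
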
diff --git a/metrics/__init__.py b/metrics/__init__.py
new file mode 100755
index 00000000..2c61c745
--- /dev/null
+++ b/metrics/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+# empty
diff --git a/metrics/frechet_inception_distance.py b/metrics/frechet_inception_distance.py
new file mode 100755
index 00000000..1f6be674
--- /dev/null
+++ b/metrics/frechet_inception_distance.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Frechet Inception Distance (FID) from the paper
+"GANs trained by a two time-scale update rule converge to a local Nash equilibrium"."""
+
+import os
+import pickle
+import numpy as np
+import scipy
+import tensorflow as tf
+import dnnlib
+import dnnlib.tflib as tflib
+
+from metrics import metric_base
+
+#----------------------------------------------------------------------------
+
+class FID(metric_base.MetricBase):
+ def __init__(self, max_reals, num_fakes, minibatch_per_gpu, use_cached_real_stats=True, **kwargs):
+ super().__init__(**kwargs)
+ self.max_reals = max_reals
+ self.num_fakes = num_fakes
+ self.minibatch_per_gpu = minibatch_per_gpu
+ self.use_cached_real_stats = use_cached_real_stats
+
+ def _evaluate(self, Gs, G_kwargs, num_gpus, **_kwargs): # pylint: disable=arguments-differ
+ minibatch_size = num_gpus * self.minibatch_per_gpu
+ with dnnlib.util.open_url('https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/inception_v3_features.pkl') as f: # identical to http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+ feature_net = pickle.load(f)
+
+ # Calculate statistics for reals.
+ cache_file = self._get_cache_file_for_reals(max_reals=self.max_reals)
+ os.makedirs(os.path.dirname(cache_file), exist_ok=True)
+ if self.use_cached_real_stats and os.path.isfile(cache_file):
+ with open(cache_file, 'rb') as f:
+ mu_real, sigma_real = pickle.load(f)
+ else:
+ nfeat = feature_net.output_shape[1]
+ mu_real = np.zeros(nfeat)
+ sigma_real = np.zeros([nfeat, nfeat])
+ num_real = 0
+ for images, _labels, num in self._iterate_reals(minibatch_size):
+ if self.max_reals is not None:
+ num = min(num, self.max_reals - num_real)
+ if images.shape[1] == 1:
+ images = np.tile(images, [1, 3, 1, 1])
+ for feat in list(feature_net.run(images, num_gpus=num_gpus, assume_frozen=True))[:num]:
+ mu_real += feat
+ sigma_real += np.outer(feat, feat)
+ num_real += 1
+ if self.max_reals is not None and num_real >= self.max_reals:
+ break
+ mu_real /= num_real
+ sigma_real /= num_real
+ sigma_real -= np.outer(mu_real, mu_real)
+ with open(cache_file, 'wb') as f:
+ pickle.dump((mu_real, sigma_real), f)
+
+ # Construct TensorFlow graph.
+ result_expr = []
+ for gpu_idx in range(num_gpus):
+ with tf.device('/gpu:%d' % gpu_idx):
+ Gs_clone = Gs.clone()
+ feature_net_clone = feature_net.clone()
+ latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
+ labels = self._get_random_labels_tf(self.minibatch_per_gpu)
+ images = Gs_clone.get_output_for(latents, labels, **G_kwargs)
+ if images.shape[1] == 1: images = tf.tile(images, [1, 3, 1, 1])
+ images = tflib.convert_images_to_uint8(images)
+ result_expr.append(feature_net_clone.get_output_for(images))
+
+ # Calculate statistics for fakes.
+ feat_fake = []
+ for begin in range(0, self.num_fakes, minibatch_size):
+ self._report_progress(begin, self.num_fakes)
+ feat_fake += list(np.concatenate(tflib.run(result_expr), axis=0))
+ feat_fake = np.stack(feat_fake[:self.num_fakes])
+ mu_fake = np.mean(feat_fake, axis=0)
+ sigma_fake = np.cov(feat_fake, rowvar=False)
+
+ # Calculate FID.
+ m = np.square(mu_fake - mu_real).sum()
+ s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False) # pylint: disable=no-member
+ dist = m + np.trace(sigma_fake + sigma_real - 2*s)
+ self._report_result(np.real(dist))
+
+#----------------------------------------------------------------------------
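
The loop above only accumulates feature means and covariances; the reported score is the closed-form Fréchet distance between two Gaussians, as computed at the end of _evaluate. A minimal NumPy/SciPy sketch of that final step, shown here on synthetic statistics rather than Inception features:

    import numpy as np
    import scipy.linalg

    def frechet_distance(mu_r, sigma_r, mu_f, sigma_f):
        # FID = |mu_f - mu_r|^2 + Tr(S_f + S_r - 2 sqrt(S_f S_r)), matching FID._evaluate above.
        m = np.square(mu_f - mu_r).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma_f, sigma_r), disp=False)
        return float(np.real(m + np.trace(sigma_f + sigma_r - 2 * s)))

    rng = np.random.RandomState(0)
    a, b = rng.randn(1000, 16), rng.randn(1000, 16) + 0.1   # stand-in feature sets
    print(frechet_distance(a.mean(0), np.cov(a, rowvar=False),
                           b.mean(0), np.cov(b, rowvar=False)))
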
diff --git a/metrics/inception_score.py b/metrics/inception_score.py
new file mode 100755
index 00000000..c33f0893
--- /dev/null
+++ b/metrics/inception_score.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Inception Score (IS) from the paper
+"Improved techniques for training GANs"."""
+
+import pickle
+import numpy as np
+import tensorflow as tf
+import dnnlib
+import dnnlib.tflib as tflib
+
+from metrics import metric_base
+
+#----------------------------------------------------------------------------
+
+class IS(metric_base.MetricBase):
+ def __init__(self, num_images, num_splits, minibatch_per_gpu, **kwargs):
+ super().__init__(**kwargs)
+ self.num_images = num_images
+ self.num_splits = num_splits
+ self.minibatch_per_gpu = minibatch_per_gpu
+
+ def _evaluate(self, Gs, G_kwargs, num_gpus, **_kwargs): # pylint: disable=arguments-differ
+ minibatch_size = num_gpus * self.minibatch_per_gpu
+ with dnnlib.util.open_url('https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/inception_v3_softmax.pkl') as f:
+ inception = pickle.load(f)
+ activations = np.empty([self.num_images, inception.output_shape[1]], dtype=np.float32)
+
+ # Construct TensorFlow graph.
+ result_expr = []
+ for gpu_idx in range(num_gpus):
+ with tf.device(f'/gpu:{gpu_idx}'):
+ Gs_clone = Gs.clone()
+ inception_clone = inception.clone()
+ latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
+ labels = self._get_random_labels_tf(self.minibatch_per_gpu)
+ images = Gs_clone.get_output_for(latents, labels, **G_kwargs)
+ if images.shape[1] == 1: images = tf.tile(images, [1, 3, 1, 1])
+ images = tflib.convert_images_to_uint8(images)
+ result_expr.append(inception_clone.get_output_for(images))
+
+ # Calculate activations for fakes.
+ for begin in range(0, self.num_images, minibatch_size):
+ self._report_progress(begin, self.num_images)
+ end = min(begin + minibatch_size, self.num_images)
+ activations[begin:end] = np.concatenate(tflib.run(result_expr), axis=0)[:end-begin]
+
+ # Calculate IS.
+ scores = []
+ for i in range(self.num_splits):
+ part = activations[i * self.num_images // self.num_splits : (i + 1) * self.num_images // self.num_splits]
+ kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
+ kl = np.mean(np.sum(kl, 1))
+ scores.append(np.exp(kl))
+ self._report_result(np.mean(scores), suffix='_mean')
+ self._report_result(np.std(scores), suffix='_std')
+
+#----------------------------------------------------------------------------
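
For reference, the score reported above is the exponential of the mean KL divergence between each sample's class posterior and the split's marginal, averaged over splits. A small NumPy sketch of that step, applied to arbitrary softmax outputs (random probabilities here, not Inception activations):

    import numpy as np

    def inception_score(probs, num_splits=10):
        # probs: [N, num_classes] softmax outputs, as consumed per split in IS._evaluate above.
        scores = []
        for i in range(num_splits):
            part = probs[i * len(probs) // num_splits : (i + 1) * len(probs) // num_splits]
            kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
            scores.append(np.exp(np.mean(np.sum(kl, axis=1))))
        return np.mean(scores), np.std(scores)

    rng = np.random.RandomState(0)
    probs = rng.dirichlet(np.ones(10), size=5000)   # valid probability vectors
    print(inception_score(probs))
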
diff --git a/metrics/kernel_inception_distance.py b/metrics/kernel_inception_distance.py
new file mode 100755
index 00000000..20fa8db5
--- /dev/null
+++ b/metrics/kernel_inception_distance.py
@@ -0,0 +1,94 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Kernel Inception Distance (KID) from the paper
+"Demystifying MMD GANs"."""
+
+import os
+import pickle
+import numpy as np
+import tensorflow as tf
+import dnnlib
+import dnnlib.tflib as tflib
+
+from metrics import metric_base
+
+#----------------------------------------------------------------------------
+
+def compute_kid(feat_real, feat_fake, num_subsets=100, max_subset_size=1000):
+ n = feat_real.shape[1]
+ m = min(min(feat_real.shape[0], feat_fake.shape[0]), max_subset_size)
+ t = 0
+ for _subset_idx in range(num_subsets):
+ x = feat_fake[np.random.choice(feat_fake.shape[0], m, replace=False)]
+ y = feat_real[np.random.choice(feat_real.shape[0], m, replace=False)]
+ a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
+ b = (x @ y.T / n + 1) ** 3
+ t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
+ return t / num_subsets / m
+
+#----------------------------------------------------------------------------
+
+class KID(metric_base.MetricBase):
+ def __init__(self, max_reals, num_fakes, minibatch_per_gpu, use_cached_real_stats=True, **kwargs):
+ super().__init__(**kwargs)
+ self.max_reals = max_reals
+ self.num_fakes = num_fakes
+ self.minibatch_per_gpu = minibatch_per_gpu
+ self.use_cached_real_stats = use_cached_real_stats
+
+ def _evaluate(self, Gs, G_kwargs, num_gpus, **_kwargs): # pylint: disable=arguments-differ
+ minibatch_size = num_gpus * self.minibatch_per_gpu
+ with dnnlib.util.open_url('https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/inception_v3_features.pkl') as f: # identical to http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+ feature_net = pickle.load(f)
+
+ # Calculate statistics for reals.
+ cache_file = self._get_cache_file_for_reals(max_reals=self.max_reals)
+ os.makedirs(os.path.dirname(cache_file), exist_ok=True)
+ if self.use_cached_real_stats and os.path.isfile(cache_file):
+ with open(cache_file, 'rb') as f:
+ feat_real = pickle.load(f)
+ else:
+ feat_real = []
+ for images, _labels, num in self._iterate_reals(minibatch_size):
+ if self.max_reals is not None:
+ num = min(num, self.max_reals - len(feat_real))
+ if images.shape[1] == 1:
+ images = np.tile(images, [1, 3, 1, 1])
+ feat_real += list(feature_net.run(images, num_gpus=num_gpus, assume_frozen=True))[:num]
+ if self.max_reals is not None and len(feat_real) >= self.max_reals:
+ break
+ feat_real = np.stack(feat_real)
+ with open(cache_file, 'wb') as f:
+ pickle.dump(feat_real, f)
+
+ # Construct TensorFlow graph.
+ result_expr = []
+ for gpu_idx in range(num_gpus):
+ with tf.device('/gpu:%d' % gpu_idx):
+ Gs_clone = Gs.clone()
+ feature_net_clone = feature_net.clone()
+ latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
+ labels = self._get_random_labels_tf(self.minibatch_per_gpu)
+ images = Gs_clone.get_output_for(latents, labels, **G_kwargs)
+ if images.shape[1] == 1: images = tf.tile(images, [1, 3, 1, 1])
+ images = tflib.convert_images_to_uint8(images)
+ result_expr.append(feature_net_clone.get_output_for(images))
+
+ # Calculate statistics for fakes.
+ feat_fake = []
+ for begin in range(0, self.num_fakes, minibatch_size):
+ self._report_progress(begin, self.num_fakes)
+ feat_fake += list(np.concatenate(tflib.run(result_expr), axis=0))
+ feat_fake = np.stack(feat_fake[:self.num_fakes])
+
+ # Calculate KID.
+ kid = compute_kid(feat_real, feat_fake)
+ self._report_result(np.real(kid), fmt='%-12.8f')
+
+#----------------------------------------------------------------------------
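
compute_kid above depends only on NumPy, so it can be exercised directly on arbitrary feature arrays. A quick usage sketch with random features standing in for Inception activations (assumes the repository environment, since importing the module also pulls in TensorFlow and dnnlib):

    import numpy as np
    from metrics.kernel_inception_distance import compute_kid   # defined above

    rng = np.random.RandomState(0)
    feat_real = rng.randn(5000, 2048)
    feat_fake = rng.randn(5000, 2048) * 1.1          # slightly mismatched scale
    print(compute_kid(feat_real, feat_fake))         # unbiased polynomial-kernel MMD estimate
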
diff --git a/metrics/linear_separability.py b/metrics/linear_separability.py
new file mode 100755
index 00000000..d95e12b8
--- /dev/null
+++ b/metrics/linear_separability.py
@@ -0,0 +1,184 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Linear Separability (LS) from the paper
+"A Style-Based Generator Architecture for Generative Adversarial Networks"."""
+
+import pickle
+from collections import defaultdict
+import numpy as np
+import sklearn.svm
+import tensorflow as tf
+import dnnlib
+import dnnlib.tflib as tflib
+
+from metrics import metric_base
+
+#----------------------------------------------------------------------------
+
+classifier_urls = [
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-00-male.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-01-smiling.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-02-attractive.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-03-wavy-hair.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-04-young.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-05-5-o-clock-shadow.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-06-arched-eyebrows.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-07-bags-under-eyes.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-08-bald.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-09-bangs.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-10-big-lips.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-11-big-nose.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-12-black-hair.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-13-blond-hair.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-14-blurry.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-15-brown-hair.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-16-bushy-eyebrows.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-17-chubby.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-18-double-chin.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-19-eyeglasses.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-20-goatee.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-21-gray-hair.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-22-heavy-makeup.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-23-high-cheekbones.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-24-mouth-slightly-open.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-25-mustache.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-26-narrow-eyes.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-27-no-beard.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-28-oval-face.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-29-pale-skin.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-30-pointy-nose.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-31-receding-hairline.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-32-rosy-cheeks.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-33-sideburns.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-34-straight-hair.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-35-wearing-earrings.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-36-wearing-hat.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-37-wearing-lipstick.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-38-wearing-necklace.pkl',
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-39-wearing-necktie.pkl',
+]
+
+#----------------------------------------------------------------------------
+
+def prob_normalize(p):
+ p = np.asarray(p).astype(np.float32)
+ assert len(p.shape) == 2
+ return p / np.sum(p)
+
+def mutual_information(p):
+ p = prob_normalize(p)
+ px = np.sum(p, axis=1)
+ py = np.sum(p, axis=0)
+ result = 0.0
+ for x in range(p.shape[0]):
+ p_x = px[x]
+ for y in range(p.shape[1]):
+ p_xy = p[x][y]
+ p_y = py[y]
+ if p_xy > 0.0:
+ result += p_xy * np.log2(p_xy / (p_x * p_y)) # get bits as output
+ return result
+
+def entropy(p):
+ p = prob_normalize(p)
+ result = 0.0
+ for x in range(p.shape[0]):
+ for y in range(p.shape[1]):
+ p_xy = p[x][y]
+ if p_xy > 0.0:
+ result -= p_xy * np.log2(p_xy)
+ return result
+
+def conditional_entropy(p):
+ # H(Y|X) where X corresponds to axis 0, Y to axis 1
+    # i.e., how many bits of additional information are needed to determine where we are on axis 1 if we know where we are on axis 0?
+ p = prob_normalize(p)
+ y = np.sum(p, axis=0, keepdims=True) # marginalize to calculate H(Y)
+ return max(0.0, entropy(y) - mutual_information(p)) # can slip just below 0 due to FP inaccuracies, clean those up.
+
+#----------------------------------------------------------------------------
+
+class LS(metric_base.MetricBase):
+ def __init__(self, num_samples, num_keep, attrib_indices, minibatch_per_gpu, **kwargs):
+ assert num_keep <= num_samples
+ super().__init__(**kwargs)
+ self.num_samples = num_samples
+ self.num_keep = num_keep
+ self.attrib_indices = attrib_indices
+ self.minibatch_per_gpu = minibatch_per_gpu
+
+ def _evaluate(self, Gs, G_kwargs, num_gpus, **_kwargs): # pylint: disable=arguments-differ
+ minibatch_size = num_gpus * self.minibatch_per_gpu
+
+ # Construct TensorFlow graph for each GPU.
+ result_expr = []
+ for gpu_idx in range(num_gpus):
+ with tf.device(f'/gpu:{gpu_idx}'):
+ Gs_clone = Gs.clone()
+
+ # Generate images.
+ latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
+ labels = self._get_random_labels_tf(self.minibatch_per_gpu)
+ dlatents = Gs_clone.components.mapping.get_output_for(latents, labels, **G_kwargs)
+ images = Gs_clone.get_output_for(latents, None, **G_kwargs)
+ if images.shape[1] == 1: images = tf.tile(images, [1, 3, 1, 1])
+
+ # Downsample to 256x256. The attribute classifiers were built for 256x256.
+ if images.shape[2] > 256:
+ factor = images.shape[2] // 256
+ images = tf.reshape(images, [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor])
+ images = tf.reduce_mean(images, axis=[3, 5])
+
+ # Run classifier for each attribute.
+ result_dict = dict(latents=latents, dlatents=dlatents[:,-1])
+ for attrib_idx in self.attrib_indices:
+ with dnnlib.util.open_url(classifier_urls[attrib_idx]) as f:
+ classifier = pickle.load(f)
+ logits = classifier.get_output_for(images, None)
+ predictions = tf.nn.softmax(tf.concat([logits, -logits], axis=1))
+ result_dict[attrib_idx] = predictions
+ result_expr.append(result_dict)
+
+ # Sampling loop.
+ results = []
+ for begin in range(0, self.num_samples, minibatch_size):
+ self._report_progress(begin, self.num_samples)
+ results += tflib.run(result_expr)
+ results = {key: np.concatenate([value[key] for value in results], axis=0) for key in results[0].keys()}
+
+ # Calculate conditional entropy for each attribute.
+ conditional_entropies = defaultdict(list)
+ for attrib_idx in self.attrib_indices:
+ # Prune the least confident samples.
+ pruned_indices = list(range(self.num_samples))
+ pruned_indices = sorted(pruned_indices, key=lambda i: -np.max(results[attrib_idx][i]))
+ pruned_indices = pruned_indices[:self.num_keep]
+
+ # Fit SVM to the remaining samples.
+ svm_targets = np.argmax(results[attrib_idx][pruned_indices], axis=1)
+ for space in ['latents', 'dlatents']:
+ svm_inputs = results[space][pruned_indices]
+ try:
+ svm = sklearn.svm.LinearSVC()
+ svm.fit(svm_inputs, svm_targets)
+ svm.score(svm_inputs, svm_targets)
+ svm_outputs = svm.predict(svm_inputs)
+ except:
+ svm_outputs = svm_targets # assume perfect prediction
+
+ # Calculate conditional entropy.
+ p = [[np.mean([case == (row, col) for case in zip(svm_outputs, svm_targets)]) for col in (0, 1)] for row in (0, 1)]
+ conditional_entropies[space].append(conditional_entropy(p))
+
+ # Calculate separability scores.
+ scores = {key: 2**np.sum(values) for key, values in conditional_entropies.items()}
+ self._report_result(scores['latents'], suffix='_z')
+ self._report_result(scores['dlatents'], suffix='_w')
+
+#----------------------------------------------------------------------------
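
The entropy helpers above operate on a 2x2 joint table of (SVM prediction, classifier label) frequencies. A tiny worked sketch of conditional_entropy on toy tables (assumes the repository environment for the import):

    from metrics.linear_separability import conditional_entropy   # defined above

    agree       = [[50, 0], [0, 50]]     # perfect agreement    -> H(Y|X) ~ 0 bits
    independent = [[25, 25], [25, 25]]   # labels carry no info -> H(Y|X) ~ 1 bit
    print(conditional_entropy(agree), conditional_entropy(independent))
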
diff --git a/metrics/metric_base.py b/metrics/metric_base.py
new file mode 100755
index 00000000..84fab746
--- /dev/null
+++ b/metrics/metric_base.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Common definitions for quality metrics."""
+
+import os
+import time
+import hashlib
+import pickle
+import numpy as np
+import tensorflow as tf
+import dnnlib
+import dnnlib.tflib as tflib
+
+from training import dataset
+
+#----------------------------------------------------------------------------
+# Base class for metrics.
+
+class MetricBase:
+ def __init__(self, name, force_dataset_args={}, force_G_kwargs={}):
+ # Constructor args.
+ self.name = name
+ self.force_dataset_args = force_dataset_args
+ self.force_G_kwargs = force_G_kwargs
+
+ # Configuration.
+ self._dataset_args = dnnlib.EasyDict()
+ self._run_dir = None
+ self._progress_fn = None
+
+ # Internal state.
+ self._results = []
+ self._network_name = ''
+ self._eval_time = 0
+ self._dataset = None
+
+ def configure(self, dataset_args={}, run_dir=None, progress_fn=None):
+ self._dataset_args = dnnlib.EasyDict(dataset_args)
+ self._dataset_args.update(self.force_dataset_args)
+ self._run_dir = run_dir
+ self._progress_fn = progress_fn
+
+ def run(self, network_pkl, num_gpus=1, G_kwargs=dict(is_validation=True)):
+ self._results = []
+ self._network_name = os.path.splitext(os.path.basename(network_pkl))[0]
+ self._eval_time = 0
+ self._dataset = None
+
+ with tf.Graph().as_default(), tflib.create_session().as_default(): # pylint: disable=not-context-manager
+ self._report_progress(0, 1)
+ time_begin = time.time()
+ with dnnlib.util.open_url(network_pkl) as f:
+ G, D, Gs = pickle.load(f)
+
+ G_kwargs = dnnlib.EasyDict(G_kwargs)
+ G_kwargs.update(self.force_G_kwargs)
+ self._evaluate(G=G, D=D, Gs=Gs, G_kwargs=G_kwargs, num_gpus=num_gpus)
+
+ self._eval_time = time.time() - time_begin # pylint: disable=attribute-defined-outside-init
+ self._report_progress(1, 1)
+ if self._dataset is not None:
+ self._dataset.close()
+ self._dataset = None
+
+ result_str = self.get_result_str()
+ print(result_str)
+ if self._run_dir is not None and os.path.isdir(self._run_dir):
+ with open(os.path.join(self._run_dir, f'metric-{self.name}.txt'), 'at') as f:
+ f.write(result_str + '\n')
+
+ def get_result_str(self):
+ title = self._network_name
+ if len(title) > 29:
+ title = '...' + title[-26:]
+ result_str = f'{title:<30s} time {dnnlib.util.format_time(self._eval_time):<12s}'
+ for res in self._results:
+ result_str += f' {self.name}{res.suffix} {res.fmt % res.value}'
+ return result_str.strip()
+
+ def update_autosummaries(self):
+ for res in self._results:
+ tflib.autosummary.autosummary('Metrics/' + self.name + res.suffix, res.value)
+
+ def _evaluate(self, **_kwargs):
+ raise NotImplementedError # to be overridden by subclasses
+
+ def _report_result(self, value, suffix='', fmt='%-10.4f'):
+ self._results += [dnnlib.EasyDict(value=value, suffix=suffix, fmt=fmt)]
+
+ def _report_progress(self, cur, total):
+ if self._progress_fn is not None:
+ self._progress_fn(cur, total)
+
+ def _get_cache_file_for_reals(self, extension='pkl', **kwargs):
+ all_args = dnnlib.EasyDict(metric_name=self.name)
+ all_args.update(self._dataset_args)
+ all_args.update(kwargs)
+ md5 = hashlib.md5(repr(sorted(all_args.items())).encode('utf-8'))
+ dataset_name = os.path.splitext(os.path.basename(self._dataset_args.path))[0]
+ return dnnlib.make_cache_dir_path('metrics', f'{md5.hexdigest()}-{self.name}-{dataset_name}.{extension}')
+
+ def _get_dataset_obj(self):
+ if self._dataset is None:
+ self._dataset = dataset.load_dataset(**self._dataset_args)
+ return self._dataset
+
+ def _iterate_reals(self, minibatch_size):
+ print(f'Calculating real image statistics for {self.name}...')
+ dataset_obj = self._get_dataset_obj()
+ while True:
+ images = []
+ labels = []
+ for _ in range(minibatch_size):
+ image, label = dataset_obj.get_minibatch_np(1)
+ if image is None:
+ break
+ images.append(image)
+ labels.append(label)
+ num = len(images)
+ if num == 0:
+ break
+ images = np.concatenate(images + [images[-1]] * (minibatch_size - num), axis=0)
+ labels = np.concatenate(labels + [labels[-1]] * (minibatch_size - num), axis=0)
+ yield images, labels, num
+ if num < minibatch_size:
+ break
+
+ def _get_random_labels_tf(self, minibatch_size):
+ return self._get_dataset_obj().get_random_labels_tf(minibatch_size)
+
+#----------------------------------------------------------------------------
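
Subclasses only need to override _evaluate and report values through _report_result; network loading, dataset handling, caching, and logging are handled by MetricBase.run. A minimal skeleton following that pattern (illustrative only; the class and pickle names are hypothetical, and the repository environment is assumed):

    from metrics import metric_base

    class DummyMetric(metric_base.MetricBase):
        def __init__(self, minibatch_per_gpu, **kwargs):
            super().__init__(**kwargs)
            self.minibatch_per_gpu = minibatch_per_gpu

        def _evaluate(self, Gs, G_kwargs, num_gpus, **_kwargs):
            # Real metrics build a TF graph here and sample from Gs; this sketch reports a constant.
            self._report_result(0.0, suffix='_dummy')

    # usage (hypothetical network pickle path):
    # DummyMetric(name='dummy', minibatch_per_gpu=8).run('network-snapshot.pkl', num_gpus=1)
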
diff --git a/metrics/metric_defaults.py b/metrics/metric_defaults.py
new file mode 100755
index 00000000..b456e9c6
--- /dev/null
+++ b/metrics/metric_defaults.py
@@ -0,0 +1,36 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Default metric definitions."""
+
+from dnnlib import EasyDict
+
+#----------------------------------------------------------------------------
+
+metric_defaults = EasyDict([(args.name, args) for args in [
+ # ADA paper.
+ EasyDict(name='fid50k_full', class_name='metrics.frechet_inception_distance.FID', max_reals=None, num_fakes=50000, minibatch_per_gpu=8, force_dataset_args=dict(shuffle=False, max_images=None, repeat=False, mirror_augment=False)),
+ EasyDict(name='kid50k_full', class_name='metrics.kernel_inception_distance.KID', max_reals=1000000, num_fakes=50000, minibatch_per_gpu=8, force_dataset_args=dict(shuffle=False, max_images=None, repeat=False, mirror_augment=False)),
+ EasyDict(name='pr50k3_full', class_name='metrics.precision_recall.PR', max_reals=200000, num_fakes=50000, nhood_size=3, minibatch_per_gpu=8, row_batch_size=10000, col_batch_size=10000, force_dataset_args=dict(shuffle=False, max_images=None, repeat=False, mirror_augment=False)),
+ EasyDict(name='is50k', class_name='metrics.inception_score.IS', num_images=50000, num_splits=10, minibatch_per_gpu=8, force_dataset_args=dict(shuffle=False, max_images=None)),
+
+ # Legacy: StyleGAN2.
+ EasyDict(name='fid50k', class_name='metrics.frechet_inception_distance.FID', max_reals=50000, num_fakes=50000, minibatch_per_gpu=8, force_dataset_args=dict(shuffle=False, max_images=None)),
+ EasyDict(name='kid50k', class_name='metrics.kernel_inception_distance.KID', max_reals=50000, num_fakes=50000, minibatch_per_gpu=8, force_dataset_args=dict(shuffle=False, max_images=None)),
+ EasyDict(name='pr50k3', class_name='metrics.precision_recall.PR', max_reals=50000, num_fakes=50000, nhood_size=3, minibatch_per_gpu=8, row_batch_size=10000, col_batch_size=10000, force_dataset_args=dict(shuffle=False, max_images=None)),
+ EasyDict(name='ppl2_wend', class_name='metrics.perceptual_path_length.PPL', num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, minibatch_per_gpu=2, force_dataset_args=dict(shuffle=False, max_images=None), force_G_kwargs=dict(dtype='float32', mapping_dtype='float32', num_fp16_res=0)),
+
+ # Legacy: StyleGAN.
+ EasyDict(name='ppl_zfull', class_name='metrics.perceptual_path_length.PPL', num_samples=50000, epsilon=1e-4, space='z', sampling='full', crop=True, minibatch_per_gpu=2, force_dataset_args=dict(shuffle=False, max_images=None), force_G_kwargs=dict(dtype='float32', mapping_dtype='float32', num_fp16_res=0)),
+ EasyDict(name='ppl_wfull', class_name='metrics.perceptual_path_length.PPL', num_samples=50000, epsilon=1e-4, space='w', sampling='full', crop=True, minibatch_per_gpu=2, force_dataset_args=dict(shuffle=False, max_images=None), force_G_kwargs=dict(dtype='float32', mapping_dtype='float32', num_fp16_res=0)),
+ EasyDict(name='ppl_zend', class_name='metrics.perceptual_path_length.PPL', num_samples=50000, epsilon=1e-4, space='z', sampling='end', crop=True, minibatch_per_gpu=2, force_dataset_args=dict(shuffle=False, max_images=None), force_G_kwargs=dict(dtype='float32', mapping_dtype='float32', num_fp16_res=0)),
+ EasyDict(name='ppl_wend', class_name='metrics.perceptual_path_length.PPL', num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=True, minibatch_per_gpu=2, force_dataset_args=dict(shuffle=False, max_images=None), force_G_kwargs=dict(dtype='float32', mapping_dtype='float32', num_fp16_res=0)),
+ EasyDict(name='ls', class_name='metrics.linear_separability.LS', num_samples=200000, num_keep=100000, attrib_indices=range(40), minibatch_per_gpu=4, force_dataset_args=dict(shuffle=False, max_images=None)),
+]])
+
+#----------------------------------------------------------------------------
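
Each entry bundles a metric class path (class_name) with the constructor kwargs of that class, so a metric can be built straight from the table. A hedged sketch (assumes dnnlib.util.construct_class_by_name is available here, as in earlier StyleGAN releases, and uses a hypothetical dataset path):

    import dnnlib
    from metrics.metric_defaults import metric_defaults   # table defined above

    metric = dnnlib.util.construct_class_by_name(**metric_defaults['fid50k_full'])  # assumed dnnlib helper
    metric.configure(dataset_args=dict(path='~/datasets/ffhq'), run_dir=None)       # hypothetical dataset path
    # metric.run('network-snapshot.pkl', num_gpus=1)
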
diff --git a/metrics/perceptual_path_length.py b/metrics/perceptual_path_length.py
new file mode 100755
index 00000000..15a327ba
--- /dev/null
+++ b/metrics/perceptual_path_length.py
@@ -0,0 +1,119 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Perceptual Path Length (PPL) from the paper
+"A Style-Based Generator Architecture for Generative Adversarial Networks"."""
+
+import pickle
+import numpy as np
+import tensorflow as tf
+import dnnlib
+import dnnlib.tflib as tflib
+
+from metrics import metric_base
+
+#----------------------------------------------------------------------------
+
+# Normalize batch of vectors.
+def normalize(v):
+ return v / tf.sqrt(tf.reduce_sum(tf.square(v), axis=-1, keepdims=True))
+
+# Spherical interpolation of a batch of vectors.
+def slerp(a, b, t):
+ a = normalize(a)
+ b = normalize(b)
+ d = tf.reduce_sum(a * b, axis=-1, keepdims=True)
+ p = t * tf.math.acos(d)
+ c = normalize(b - d * a)
+ d = a * tf.math.cos(p) + c * tf.math.sin(p)
+ return normalize(d)
+
+#----------------------------------------------------------------------------
+
+class PPL(metric_base.MetricBase):
+ def __init__(self, num_samples, epsilon, space, sampling, crop, minibatch_per_gpu, **kwargs):
+ assert space in ['z', 'w']
+ assert sampling in ['full', 'end']
+ super().__init__(**kwargs)
+ self.num_samples = num_samples
+ self.epsilon = epsilon
+ self.space = space
+ self.sampling = sampling
+ self.crop = crop
+ self.minibatch_per_gpu = minibatch_per_gpu
+
+ def _evaluate(self, Gs, G_kwargs, num_gpus, **_kwargs): # pylint: disable=arguments-differ
+ minibatch_size = num_gpus * self.minibatch_per_gpu
+
+ # Construct TensorFlow graph.
+ distance_expr = []
+ for gpu_idx in range(num_gpus):
+ with tf.device(f'/gpu:{gpu_idx}'):
+ Gs_clone = Gs.clone()
+ noise_vars = [var for name, var in Gs_clone.components.synthesis.vars.items() if name.startswith('noise')]
+
+ # Generate random latents and interpolation t-values.
+ lat_t01 = tf.random_normal([self.minibatch_per_gpu * 2] + Gs_clone.input_shape[1:])
+ lerp_t = tf.random_uniform([self.minibatch_per_gpu], 0.0, 1.0 if self.sampling == 'full' else 0.0)
+ labels = tf.reshape(tf.tile(self._get_random_labels_tf(self.minibatch_per_gpu), [1, 2]), [self.minibatch_per_gpu * 2, -1])
+
+ # Interpolate in W or Z.
+ if self.space == 'w':
+ dlat_t01 = Gs_clone.components.mapping.get_output_for(lat_t01, labels, **G_kwargs)
+ dlat_t01 = tf.cast(dlat_t01, tf.float32)
+ dlat_t0, dlat_t1 = dlat_t01[0::2], dlat_t01[1::2]
+ dlat_e0 = tflib.lerp(dlat_t0, dlat_t1, lerp_t[:, np.newaxis, np.newaxis])
+ dlat_e1 = tflib.lerp(dlat_t0, dlat_t1, lerp_t[:, np.newaxis, np.newaxis] + self.epsilon)
+ dlat_e01 = tf.reshape(tf.stack([dlat_e0, dlat_e1], axis=1), dlat_t01.shape)
+ else: # space == 'z'
+ lat_t0, lat_t1 = lat_t01[0::2], lat_t01[1::2]
+ lat_e0 = slerp(lat_t0, lat_t1, lerp_t[:, np.newaxis])
+ lat_e1 = slerp(lat_t0, lat_t1, lerp_t[:, np.newaxis] + self.epsilon)
+ lat_e01 = tf.reshape(tf.stack([lat_e0, lat_e1], axis=1), lat_t01.shape)
+ dlat_e01 = Gs_clone.components.mapping.get_output_for(lat_e01, labels, **G_kwargs)
+
+ # Synthesize images.
+ with tf.control_dependencies([var.initializer for var in noise_vars]): # use same noise inputs for the entire minibatch
+ images = Gs_clone.components.synthesis.get_output_for(dlat_e01, randomize_noise=False, **G_kwargs)
+ images = tf.cast(images, tf.float32)
+
+ # Crop only the face region.
+ if self.crop:
+ c = int(images.shape[2] // 8)
+ images = images[:, :, c*3 : c*7, c*2 : c*6]
+
+ # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
+ factor = images.shape[2] // 256
+ if factor > 1:
+ images = tf.reshape(images, [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor])
+ images = tf.reduce_mean(images, axis=[3,5])
+
+ # Scale dynamic range from [-1,1] to [0,255] for VGG.
+ images = (images + 1) * (255 / 2)
+ if images.shape[1] == 1: images = tf.tile(images, [1, 3, 1, 1])
+
+ # Evaluate perceptual distance.
+ img_e0, img_e1 = images[0::2], images[1::2]
+ with dnnlib.util.open_url('https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/vgg16_zhang_perceptual.pkl') as f:
+ distance_measure = pickle.load(f)
+ distance_expr.append(distance_measure.get_output_for(img_e0, img_e1) * (1 / self.epsilon**2))
+
+ # Sampling loop.
+ all_distances = []
+ for begin in range(0, self.num_samples, minibatch_size):
+ self._report_progress(begin, self.num_samples)
+ all_distances += tflib.run(distance_expr)
+ all_distances = np.concatenate(all_distances, axis=0)
+
+ # Reject outliers.
+ lo = np.percentile(all_distances, 1, interpolation='lower')
+ hi = np.percentile(all_distances, 99, interpolation='higher')
+ filtered_distances = np.extract(np.logical_and(lo <= all_distances, all_distances <= hi), all_distances)
+ self._report_result(np.mean(filtered_distances))
+
+#----------------------------------------------------------------------------
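
slerp above is a TF graph op; the NumPy sketch below mirrors the same spherical interpolation, which can be handy for checking its behavior outside the graph (t=0 recovers a normalized, t=1 recovers b normalized):

    import numpy as np

    def slerp_np(a, b, t):
        # NumPy counterpart of the TF slerp defined above; operates on batches of vectors.
        a = a / np.linalg.norm(a, axis=-1, keepdims=True)
        b = b / np.linalg.norm(b, axis=-1, keepdims=True)
        d = np.sum(a * b, axis=-1, keepdims=True)
        p = t * np.arccos(d)
        c = b - d * a
        c = c / np.linalg.norm(c, axis=-1, keepdims=True)
        out = a * np.cos(p) + c * np.sin(p)
        return out / np.linalg.norm(out, axis=-1, keepdims=True)

    rng = np.random.RandomState(0)
    a, b = rng.randn(4, 512), rng.randn(4, 512)
    mid = slerp_np(a, b, 0.5)   # unit vectors halfway along the great circle from a to b
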
diff --git a/metrics/precision_recall.py b/metrics/precision_recall.py
new file mode 100755
index 00000000..dab3fecc
--- /dev/null
+++ b/metrics/precision_recall.py
@@ -0,0 +1,234 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Precision/Recall (PR) from the paper
+"Improved Precision and Recall Metric for Assessing Generative Models"."""
+
+import os
+import pickle
+import numpy as np
+import tensorflow as tf
+import dnnlib
+import dnnlib.tflib as tflib
+
+from metrics import metric_base
+
+#----------------------------------------------------------------------------
+
+def batch_pairwise_distances(U, V):
+ """ Compute pairwise distances between two batches of feature vectors."""
+ with tf.variable_scope('pairwise_dist_block'):
+ # Squared norms of each row in U and V.
+ norm_u = tf.reduce_sum(tf.square(U), 1)
+ norm_v = tf.reduce_sum(tf.square(V), 1)
+
+ # norm_u as a row and norm_v as a column vectors.
+ norm_u = tf.reshape(norm_u, [-1, 1])
+ norm_v = tf.reshape(norm_v, [1, -1])
+
+ # Pairwise squared Euclidean distances.
+ D = tf.maximum(norm_u - 2*tf.matmul(U, V, False, True) + norm_v, 0.0)
+
+ return D
+
+#----------------------------------------------------------------------------
+
+class DistanceBlock():
+ """Distance block."""
+ def __init__(self, num_features, num_gpus):
+ self.num_features = num_features
+ self.num_gpus = num_gpus
+
+ # Initialize TF graph to calculate pairwise distances.
+ with tf.device('/cpu:0'):
+ self._features_batch1 = tf.placeholder(tf.float16, shape=[None, self.num_features])
+ self._features_batch2 = tf.placeholder(tf.float16, shape=[None, self.num_features])
+ features_split2 = tf.split(self._features_batch2, self.num_gpus, axis=0)
+ distances_split = []
+ for gpu_idx in range(self.num_gpus):
+ with tf.device(f'/gpu:{gpu_idx}'):
+ distances_split.append(batch_pairwise_distances(self._features_batch1, features_split2[gpu_idx]))
+ self._distance_block = tf.concat(distances_split, axis=1)
+
+ def pairwise_distances(self, U, V):
+ """Evaluate pairwise distances between two batches of feature vectors."""
+ return self._distance_block.eval(feed_dict={self._features_batch1: U, self._features_batch2: V})
+
+#----------------------------------------------------------------------------
+
+class ManifoldEstimator():
+ """Finds an estimate for the manifold of given feature vectors."""
+ def __init__(self, distance_block, features, row_batch_size, col_batch_size, nhood_sizes, clamp_to_percentile=None):
+ """Find an estimate of the manifold of given feature vectors."""
+ num_images = features.shape[0]
+ self.nhood_sizes = nhood_sizes
+ self.num_nhoods = len(nhood_sizes)
+ self.row_batch_size = row_batch_size
+ self.col_batch_size = col_batch_size
+ self._ref_features = features
+ self._distance_block = distance_block
+
+ # Estimate manifold of features by calculating distances to kth nearest neighbor of each sample.
+ self.D = np.zeros([num_images, self.num_nhoods], dtype=np.float16)
+ distance_batch = np.zeros([row_batch_size, num_images], dtype=np.float16)
+ seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
+
+ for begin1 in range(0, num_images, row_batch_size):
+ end1 = min(begin1 + row_batch_size, num_images)
+ row_batch = features[begin1:end1]
+
+ for begin2 in range(0, num_images, col_batch_size):
+ end2 = min(begin2 + col_batch_size, num_images)
+ col_batch = features[begin2:end2]
+
+ # Compute distances between batches.
+ distance_batch[0:end1-begin1, begin2:end2] = self._distance_block.pairwise_distances(row_batch, col_batch)
+
+ # Find the kth nearest neighbor from the current batch.
+ self.D[begin1:end1, :] = np.partition(distance_batch[0:end1-begin1, :], seq, axis=1)[:, self.nhood_sizes]
+
+ if clamp_to_percentile is not None:
+ max_distances = np.percentile(self.D, clamp_to_percentile, axis=0)
+ self.D[self.D > max_distances] = 0 #max_distances # 0
+
+ def evaluate(self, eval_features, return_realism=False, return_neighbors=False):
+ """Evaluate if new feature vectors are in the estimated manifold."""
+ num_eval_images = eval_features.shape[0]
+ num_ref_images = self.D.shape[0]
+ distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float16)
+ batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
+ #max_realism_score = np.zeros([num_eval_images,], dtype=np.float32)
+ realism_score = np.zeros([num_eval_images,], dtype=np.float32)
+ nearest_indices = np.zeros([num_eval_images,], dtype=np.int32)
+
+ for begin1 in range(0, num_eval_images, self.row_batch_size):
+ end1 = min(begin1 + self.row_batch_size, num_eval_images)
+ feature_batch = eval_features[begin1:end1]
+
+ for begin2 in range(0, num_ref_images, self.col_batch_size):
+ end2 = min(begin2 + self.col_batch_size, num_ref_images)
+ ref_batch = self._ref_features[begin2:end2]
+
+ distance_batch[0:end1-begin1, begin2:end2] = self._distance_block.pairwise_distances(feature_batch, ref_batch)
+
+ # From the minibatch of new feature vectors, determine if they are in the estimated manifold.
+ # If a feature vector is inside a hypersphere of some reference sample, then the new sample lies on the estimated manifold.
+ # The radii of the hyperspheres are determined from distances of neighborhood size k.
+ samples_in_manifold = distance_batch[0:end1-begin1, :, None] <= self.D
+ batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)
+
+ #max_realism_score[begin1:end1] = np.max(self.D[:, 0] / (distance_batch[0:end1-begin1, :] + 1e-18), axis=1)
+ #nearest_indices[begin1:end1] = np.argmax(self.D[:, 0] / (distance_batch[0:end1-begin1, :] + 1e-18), axis=1)
+ nearest_indices[begin1:end1] = np.argmin(distance_batch[0:end1-begin1, :], axis=1)
+ realism_score[begin1:end1] = self.D[nearest_indices[begin1:end1], 0] / np.min(distance_batch[0:end1-begin1, :], axis=1)
+
+ if return_realism and return_neighbors:
+ return batch_predictions, realism_score, nearest_indices
+ elif return_realism:
+ return batch_predictions, realism_score
+ elif return_neighbors:
+ return batch_predictions, nearest_indices
+
+ return batch_predictions
+
+#----------------------------------------------------------------------------
+
+def knn_precision_recall_features(ref_features, eval_features, feature_net, nhood_sizes,
+ row_batch_size, col_batch_size, num_gpus):
+ """Calculates k-NN precision and recall for two sets of feature vectors."""
+ state = dnnlib.EasyDict()
+ #num_images = ref_features.shape[0]
+ num_features = feature_net.output_shape[1]
+ state.ref_features = ref_features
+ state.eval_features = eval_features
+
+ # Initialize DistanceBlock and ManifoldEstimators.
+ distance_block = DistanceBlock(num_features, num_gpus)
+ state.ref_manifold = ManifoldEstimator(distance_block, state.ref_features, row_batch_size, col_batch_size, nhood_sizes)
+ state.eval_manifold = ManifoldEstimator(distance_block, state.eval_features, row_batch_size, col_batch_size, nhood_sizes)
+
+ # Evaluate precision and recall using k-nearest neighbors.
+ #print(f'Evaluating k-NN precision and recall with {num_images} samples...')
+ #start = time.time()
+
+ # Precision: How many points from eval_features are in ref_features manifold.
+ state.precision, state.realism_scores, state.nearest_neighbors = state.ref_manifold.evaluate(state.eval_features, return_realism=True, return_neighbors=True)
+ state.knn_precision = state.precision.mean(axis=0)
+
+ # Recall: How many points from ref_features are in eval_features manifold.
+ state.recall = state.eval_manifold.evaluate(state.ref_features)
+ state.knn_recall = state.recall.mean(axis=0)
+
+ #elapsed_time = time.time() - start
+ #print(f'Done evaluation in: {elapsed_time:g}s')
+
+ return state
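+
+#----------------------------------------------------------------------------
+# Minimal NumPy-only sketch of the hypersphere test that ManifoldEstimator and
+# knn_precision_recall_features() implement above. The helper below is purely
+# illustrative and is not called by the metric: it assumes small in-memory
+# feature arrays, uses squared Euclidean distances (the membership test is the
+# same for any monotonic distance), and requires more than k samples per set.
+
+def _knn_precision_recall_sketch(ref_features, eval_features, k=3):
+    def pairwise_sq_dist(a, b):
+        return np.sum(np.square(a), axis=1)[:, None] + np.sum(np.square(b), axis=1)[None, :] - 2 * (a @ b.T)
+    def kth_neighbor_radii(feats):
+        d = pairwise_sq_dist(feats, feats)
+        return np.sort(d, axis=1)[:, k]  # index 0 is the point itself, so index k is the k-th neighbor
+    def fraction_covered(queries, refs, radii):
+        d = pairwise_sq_dist(queries, refs)
+        return float(np.mean(np.any(d <= radii[None, :], axis=1)))
+    precision = fraction_covered(eval_features, ref_features, kth_neighbor_radii(ref_features))  # fakes covered by real manifold
+    recall = fraction_covered(ref_features, eval_features, kth_neighbor_radii(eval_features))    # reals covered by fake manifold
+    return precision, recall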
+
+#----------------------------------------------------------------------------
+
+class PR(metric_base.MetricBase):
+ def __init__(self, max_reals, num_fakes, nhood_size, minibatch_per_gpu, row_batch_size, col_batch_size, **kwargs):
+ super().__init__(**kwargs)
+ self.max_reals = max_reals
+ self.num_fakes = num_fakes
+ self.nhood_size = nhood_size
+ self.minibatch_per_gpu = minibatch_per_gpu
+ self.row_batch_size = row_batch_size
+ self.col_batch_size = col_batch_size
+
+ def _evaluate(self, Gs, G_kwargs, num_gpus, **_kwargs): # pylint: disable=arguments-differ
+ minibatch_size = num_gpus * self.minibatch_per_gpu
+ with dnnlib.util.open_url('https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/vgg16.pkl') as f:
+ feature_net = pickle.load(f)
+
+ # Calculate features for reals.
+ cache_file = self._get_cache_file_for_reals(max_reals=self.max_reals)
+ os.makedirs(os.path.dirname(cache_file), exist_ok=True)
+ if os.path.isfile(cache_file):
+ with open(cache_file, 'rb') as f:
+ feat_real = pickle.load(f)
+ else:
+ feat_real = []
+ for images, _labels, num in self._iterate_reals(minibatch_size):
+ if images.shape[1] == 1: images = np.tile(images, [1, 3, 1, 1])
+ feat_real += list(feature_net.run(images, num_gpus=num_gpus, assume_frozen=True))[:num]
+ if self.max_reals is not None and len(feat_real) >= self.max_reals:
+ break
+ if self.max_reals is not None and len(feat_real) > self.max_reals:
+ feat_real = feat_real[:self.max_reals]
+ feat_real = np.stack(feat_real)
+ with open(cache_file, 'wb') as f:
+ pickle.dump(feat_real, f)
+
+ # Construct TensorFlow graph.
+ result_expr = []
+ for gpu_idx in range(num_gpus):
+ with tf.device(f'/gpu:{gpu_idx}'):
+ Gs_clone = Gs.clone()
+ feature_net_clone = feature_net.clone()
+ latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
+ labels = self._get_random_labels_tf(self.minibatch_per_gpu)
+ images = Gs_clone.get_output_for(latents, labels, **G_kwargs)
+ if images.shape[1] == 1: images = tf.tile(images, [1, 3, 1, 1])
+ images = tflib.convert_images_to_uint8(images)
+ result_expr.append(feature_net_clone.get_output_for(images))
+
+ # Calculate features for fakes.
+ feat_fake = []
+ for begin in range(0, self.num_fakes, minibatch_size):
+ self._report_progress(begin, self.num_fakes)
+ feat_fake += list(np.concatenate(tflib.run(result_expr), axis=0))
+ feat_fake = np.stack(feat_fake[:self.num_fakes])
+
+ # Calculate precision and recall.
+ state = knn_precision_recall_features(ref_features=feat_real, eval_features=feat_fake, feature_net=feature_net,
+            nhood_sizes=[self.nhood_size], row_batch_size=self.row_batch_size, col_batch_size=self.col_batch_size, num_gpus=num_gpus)
+ self._report_result(state.knn_precision[0], suffix='_precision')
+ self._report_result(state.knn_recall[0], suffix='_recall')
+
+#----------------------------------------------------------------------------
diff --git a/projector.py b/projector.py
new file mode 100755
index 00000000..8f6be7e7
--- /dev/null
+++ b/projector.py
@@ -0,0 +1,289 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Project given image to the latent space of pretrained network pickle."""
+
+import argparse
+import os
+import pickle
+import imageio
+
+import numpy as np
+import PIL.Image
+import tensorflow as tf
+import tqdm
+
+import dnnlib
+import dnnlib.tflib as tflib
+
+class Projector:
+ def __init__(self):
+ self.num_steps = 1000
+ self.dlatent_avg_samples = 10000
+ self.initial_learning_rate = 0.1
+ self.initial_noise_factor = 0.05
+ self.lr_rampdown_length = 0.25
+ self.lr_rampup_length = 0.05
+ self.noise_ramp_length = 0.75
+ self.regularize_noise_weight = 1e5
+ self.verbose = True
+
+ self._Gs = None
+ self._minibatch_size = None
+ self._dlatent_avg = None
+ self._dlatent_std = None
+ self._noise_vars = None
+ self._noise_init_op = None
+ self._noise_normalize_op = None
+ self._dlatents_var = None
+ self._dlatent_noise_in = None
+ self._dlatents_expr = None
+ self._images_float_expr = None
+ self._images_uint8_expr = None
+ self._target_images_var = None
+ self._lpips = None
+ self._dist = None
+ self._loss = None
+ self._reg_sizes = None
+ self._lrate_in = None
+ self._opt = None
+ self._opt_step = None
+ self._cur_step = None
+
+ def _info(self, *args):
+ if self.verbose:
+ print('Projector:', *args)
+
+ def set_network(self, Gs, dtype='float16'):
+ if Gs is None:
+ self._Gs = None
+ return
+ self._Gs = Gs.clone(randomize_noise=False, dtype=dtype, num_fp16_res=0, fused_modconv=True)
+
+ # Compute dlatent stats.
+ self._info(f'Computing W midpoint and stddev using {self.dlatent_avg_samples} samples...')
+ latent_samples = np.random.RandomState(123).randn(self.dlatent_avg_samples, *self._Gs.input_shapes[0][1:])
+ dlatent_samples = self._Gs.components.mapping.run(latent_samples, None) # [N, L, C]
+ dlatent_samples = dlatent_samples[:, :1, :].astype(np.float32) # [N, 1, C]
+ self._dlatent_avg = np.mean(dlatent_samples, axis=0, keepdims=True) # [1, 1, C]
+ self._dlatent_std = (np.sum((dlatent_samples - self._dlatent_avg) ** 2) / self.dlatent_avg_samples) ** 0.5
+ self._info(f'std = {self._dlatent_std:g}')
+
+ # Setup noise inputs.
+ self._info('Setting up noise inputs...')
+ self._noise_vars = []
+ noise_init_ops = []
+ noise_normalize_ops = []
+ while True:
+ n = f'G_synthesis/noise{len(self._noise_vars)}'
+            if n not in self._Gs.vars:
+ break
+ v = self._Gs.vars[n]
+ self._noise_vars.append(v)
+ noise_init_ops.append(tf.assign(v, tf.random_normal(tf.shape(v), dtype=tf.float32)))
+ noise_mean = tf.reduce_mean(v)
+ noise_std = tf.reduce_mean((v - noise_mean)**2)**0.5
+ noise_normalize_ops.append(tf.assign(v, (v - noise_mean) / noise_std))
+ self._noise_init_op = tf.group(*noise_init_ops)
+ self._noise_normalize_op = tf.group(*noise_normalize_ops)
+
+ # Build image output graph.
+ self._info('Building image output graph...')
+ self._minibatch_size = 1
+ self._dlatents_var = tf.Variable(tf.zeros([self._minibatch_size] + list(self._dlatent_avg.shape[1:])), name='dlatents_var')
+ self._dlatent_noise_in = tf.placeholder(tf.float32, [], name='noise_in')
+ dlatents_noise = tf.random.normal(shape=self._dlatents_var.shape) * self._dlatent_noise_in
+ self._dlatents_expr = tf.tile(self._dlatents_var + dlatents_noise, [1, self._Gs.components.synthesis.input_shape[1], 1])
+ self._images_float_expr = tf.cast(self._Gs.components.synthesis.get_output_for(self._dlatents_expr), tf.float32)
+ self._images_uint8_expr = tflib.convert_images_to_uint8(self._images_float_expr, nchw_to_nhwc=True)
+
+ # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
+ proc_images_expr = (self._images_float_expr + 1) * (255 / 2)
+ sh = proc_images_expr.shape.as_list()
+ if sh[2] > 256:
+ factor = sh[2] // 256
+ proc_images_expr = tf.reduce_mean(tf.reshape(proc_images_expr, [-1, sh[1], sh[2] // factor, factor, sh[2] // factor, factor]), axis=[3,5])
+
+ # Build loss graph.
+ self._info('Building loss graph...')
+ self._target_images_var = tf.Variable(tf.zeros(proc_images_expr.shape), name='target_images_var')
+ if self._lpips is None:
+ with dnnlib.util.open_url('https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/vgg16_zhang_perceptual.pkl') as f:
+ self._lpips = pickle.load(f)
+ self._dist = self._lpips.get_output_for(proc_images_expr, self._target_images_var)
+ self._loss = tf.reduce_sum(self._dist)
+
+ # Build noise regularization graph.
+ self._info('Building noise regularization graph...')
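+        # The regularizer below penalizes spatial correlation in each noise map: it sums the squared
+        # mean of products between the map and single-pixel shifts of itself along both spatial axes,
+        # then repeats the measurement on 2x box-downscaled copies until the map is 8x8 or smaller,
+        # so correlations are discouraged at every scale.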
+ reg_loss = 0.0
+ for v in self._noise_vars:
+ sz = v.shape[2]
+ while True:
+ reg_loss += tf.reduce_mean(v * tf.roll(v, shift=1, axis=3))**2 + tf.reduce_mean(v * tf.roll(v, shift=1, axis=2))**2
+ if sz <= 8:
+ break # Small enough already
+ v = tf.reshape(v, [1, 1, sz//2, 2, sz//2, 2]) # Downscale
+ v = tf.reduce_mean(v, axis=[3, 5])
+ sz = sz // 2
+ self._loss += reg_loss * self.regularize_noise_weight
+
+ # Setup optimizer.
+ self._info('Setting up optimizer...')
+ self._lrate_in = tf.placeholder(tf.float32, [], name='lrate_in')
+ self._opt = tflib.Optimizer(learning_rate=self._lrate_in)
+ self._opt.register_gradients(self._loss, [self._dlatents_var] + self._noise_vars)
+ self._opt_step = self._opt.apply_updates()
+
+ def start(self, target_images):
+ assert self._Gs is not None
+
+ # Prepare target images.
+ self._info('Preparing target images...')
+ target_images = np.asarray(target_images, dtype='float32')
+ target_images = (target_images + 1) * (255 / 2)
+ sh = target_images.shape
+ assert sh[0] == self._minibatch_size
+ if sh[2] > self._target_images_var.shape[2]:
+ factor = sh[2] // self._target_images_var.shape[2]
+ target_images = np.reshape(target_images, [-1, sh[1], sh[2] // factor, factor, sh[3] // factor, factor]).mean((3, 5))
+
+ # Initialize optimization state.
+ self._info('Initializing optimization state...')
+ dlatents = np.tile(self._dlatent_avg, [self._minibatch_size, 1, 1])
+ tflib.set_vars({self._target_images_var: target_images, self._dlatents_var: dlatents})
+ tflib.run(self._noise_init_op)
+ self._opt.reset_optimizer_state()
+ self._cur_step = 0
+
+ def step(self):
+ assert self._cur_step is not None
+ if self._cur_step >= self.num_steps:
+ return 0, 0
+
+ # Choose hyperparameters.
+ t = self._cur_step / self.num_steps
+ dlatent_noise = self._dlatent_std * self.initial_noise_factor * max(0.0, 1.0 - t / self.noise_ramp_length) ** 2
+ lr_ramp = min(1.0, (1.0 - t) / self.lr_rampdown_length)
+ lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
+ lr_ramp = lr_ramp * min(1.0, t / self.lr_rampup_length)
+ learning_rate = self.initial_learning_rate * lr_ramp
+
+ # Execute optimization step.
+ feed_dict = {self._dlatent_noise_in: dlatent_noise, self._lrate_in: learning_rate}
+ _, dist_value, loss_value = tflib.run([self._opt_step, self._dist, self._loss], feed_dict)
+ tflib.run(self._noise_normalize_op)
+ self._cur_step += 1
+ return dist_value, loss_value
+
+ @property
+ def cur_step(self):
+ return self._cur_step
+
+ @property
+ def dlatents(self):
+ return tflib.run(self._dlatents_expr, {self._dlatent_noise_in: 0})
+
+ @property
+ def noises(self):
+ return tflib.run(self._noise_vars)
+
+ @property
+ def images_float(self):
+ return tflib.run(self._images_float_expr, {self._dlatent_noise_in: 0})
+
+ @property
+ def images_uint8(self):
+ return tflib.run(self._images_uint8_expr, {self._dlatent_noise_in: 0})
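+
+#----------------------------------------------------------------------------
+# Standalone sketch of the schedules used by Projector.step() above, handy for
+# plotting or sanity-checking hyperparameters without building the TF graph.
+# The helper name and the default arguments are illustrative; the authoritative
+# values are the attributes set in Projector.__init__().
+
+def _projection_schedule_sketch(num_steps=1000, initial_learning_rate=0.1, initial_noise_factor=0.05,
+                                lr_rampdown_length=0.25, lr_rampup_length=0.05, noise_ramp_length=0.75, dlatent_std=1.0):
+    schedule = []
+    for cur_step in range(num_steps):
+        t = cur_step / num_steps
+        # Additive W noise decays quadratically to zero over the first 75% of the steps.
+        dlatent_noise = dlatent_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
+        # Learning rate ramps up linearly over the first 5% and follows a cosine rampdown over the last 25%.
+        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
+        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
+        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
+        schedule.append((initial_learning_rate * lr_ramp, dlatent_noise))
+    return schedule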
+
+#----------------------------------------------------------------------------
+
+def project(network_pkl: str, target_fname: str, outdir: str, save_video: bool, seed: int):
+ # Load networks.
+ tflib.init_tf({'rnd.np_random_seed': seed})
+ print('Loading networks from "%s"...' % network_pkl)
+ with dnnlib.util.open_url(network_pkl) as fp:
+ _G, _D, Gs = pickle.load(fp)
+
+ # Load target image.
+ target_pil = PIL.Image.open(target_fname)
+ w, h = target_pil.size
+ s = min(w, h)
+ target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
+    target_pil = target_pil.convert('RGB')
+ target_pil = target_pil.resize((Gs.output_shape[3], Gs.output_shape[2]), PIL.Image.ANTIALIAS)
+ target_uint8 = np.array(target_pil, dtype=np.uint8)
+ target_float = target_uint8.astype(np.float32).transpose([2, 0, 1]) * (2 / 255) - 1
+
+ # Initialize projector.
+ proj = Projector()
+ proj.set_network(Gs)
+ proj.start([target_float])
+
+ # Setup output directory.
+ os.makedirs(outdir, exist_ok=True)
+ target_pil.save(f'{outdir}/target.png')
+ writer = None
+ if save_video:
+ writer = imageio.get_writer(f'{outdir}/proj.mp4', mode='I', fps=60, codec='libx264', bitrate='16M')
+
+ # Run projector.
+ with tqdm.trange(proj.num_steps) as t:
+ for step in t:
+ assert step == proj.cur_step
+ if writer is not None:
+ writer.append_data(np.concatenate([target_uint8, proj.images_uint8[0]], axis=1))
+ dist, loss = proj.step()
+ t.set_postfix(dist=f'{dist[0]:.4f}', loss=f'{loss:.2f}')
+
+ # Save results.
+ PIL.Image.fromarray(proj.images_uint8[0], 'RGB').save(f'{outdir}/proj.png')
+ np.savez(f'{outdir}/dlatents.npz', dlatents=proj.dlatents)
+ if writer is not None:
+ writer.close()
+
+#----------------------------------------------------------------------------
+
+def _str_to_bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
+ return True
+ if v.lower() in ('no', 'false', 'f', 'n', '0'):
+ return False
+ raise argparse.ArgumentTypeError('Boolean value expected.')
+
+#----------------------------------------------------------------------------
+
+_examples = '''examples:
+
+ python %(prog)s --outdir=out --target=targetimg.png \\
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/ffhq.pkl
+'''
+
+#----------------------------------------------------------------------------
+
+def main():
+ parser = argparse.ArgumentParser(
+ description='Project given image to the latent space of pretrained network pickle.',
+ epilog=_examples,
+ formatter_class=argparse.RawDescriptionHelpFormatter
+ )
+
+ parser.add_argument('--network', help='Network pickle filename', dest='network_pkl', required=True)
+ parser.add_argument('--target', help='Target image file to project to', dest='target_fname', required=True)
+ parser.add_argument('--save-video', help='Save an mp4 video of optimization progress (default: true)', type=_str_to_bool, default=True)
+ parser.add_argument('--seed', help='Random seed', type=int, default=303)
+ parser.add_argument('--outdir', help='Where to save the output images', required=True, metavar='DIR')
+ project(**vars(parser.parse_args()))
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ main()
+
+#----------------------------------------------------------------------------
diff --git a/style_mixing.py b/style_mixing.py
new file mode 100755
index 00000000..7d183f85
--- /dev/null
+++ b/style_mixing.py
@@ -0,0 +1,120 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Generate style mixing image matrix using pretrained network pickle."""
+
+import argparse
+import os
+import pickle
+import re
+
+import numpy as np
+import PIL.Image
+
+import dnnlib
+import dnnlib.tflib as tflib
+
+#----------------------------------------------------------------------------
+
+def style_mixing_example(network_pkl, row_seeds, col_seeds, truncation_psi, col_styles, outdir, minibatch_size=4):
+ tflib.init_tf()
+ print('Loading networks from "%s"...' % network_pkl)
+ with dnnlib.util.open_url(network_pkl) as fp:
+ _G, _D, Gs = pickle.load(fp)
+
+ w_avg = Gs.get_var('dlatent_avg') # [component]
+ Gs_syn_kwargs = {
+ 'output_transform': dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
+ 'randomize_noise': False,
+ 'minibatch_size': minibatch_size
+ }
+
+ print('Generating W vectors...')
+ all_seeds = list(set(row_seeds + col_seeds))
+ all_z = np.stack([np.random.RandomState(seed).randn(*Gs.input_shape[1:]) for seed in all_seeds]) # [minibatch, component]
+ all_w = Gs.components.mapping.run(all_z, None) # [minibatch, layer, component]
+ all_w = w_avg + (all_w - w_avg) * truncation_psi # [minibatch, layer, component]
+ w_dict = {seed: w for seed, w in zip(all_seeds, list(all_w))} # [layer, component]
+
+ print('Generating images...')
+ all_images = Gs.components.synthesis.run(all_w, **Gs_syn_kwargs) # [minibatch, height, width, channel]
+ image_dict = {(seed, seed): image for seed, image in zip(all_seeds, list(all_images))}
+
+ print('Generating style-mixed images...')
+ for row_seed in row_seeds:
+ for col_seed in col_seeds:
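+            # Start from the row seed's W+ vector and overwrite the layers listed in col_styles with
+            # the corresponding layers from the column seed; the remaining layers stay with the row
+            # seed, so the mixed image combines attributes controlled by different layer ranges.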
+ w = w_dict[row_seed].copy()
+ w[col_styles] = w_dict[col_seed][col_styles]
+ image = Gs.components.synthesis.run(w[np.newaxis], **Gs_syn_kwargs)[0]
+ image_dict[(row_seed, col_seed)] = image
+
+ print('Saving images...')
+ os.makedirs(outdir, exist_ok=True)
+ for (row_seed, col_seed), image in image_dict.items():
+ PIL.Image.fromarray(image, 'RGB').save(f'{outdir}/{row_seed}-{col_seed}.png')
+
+ print('Saving image grid...')
+ _N, _C, H, W = Gs.output_shape
+ canvas = PIL.Image.new('RGB', (W * (len(col_seeds) + 1), H * (len(row_seeds) + 1)), 'black')
+ for row_idx, row_seed in enumerate([None] + row_seeds):
+ for col_idx, col_seed in enumerate([None] + col_seeds):
+ if row_seed is None and col_seed is None:
+ continue
+ key = (row_seed, col_seed)
+ if row_seed is None:
+ key = (col_seed, col_seed)
+ if col_seed is None:
+ key = (row_seed, row_seed)
+ canvas.paste(PIL.Image.fromarray(image_dict[key], 'RGB'), (W * col_idx, H * row_idx))
+ canvas.save(f'{outdir}/grid.png')
+
+#----------------------------------------------------------------------------
+
+def _parse_num_range(s):
+    '''Accept either a comma-separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
+
+ range_re = re.compile(r'^(\d+)-(\d+)$')
+ m = range_re.match(s)
+ if m:
+ return list(range(int(m.group(1)), int(m.group(2))+1))
+ vals = s.split(',')
+ return [int(x) for x in vals]
+
+#----------------------------------------------------------------------------
+
+_examples = '''examples:
+
+ python %(prog)s --outdir=out --trunc=1 --rows=85,100,75,458,1500 --cols=55,821,1789,293 \\
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metfaces.pkl
+'''
+
+#----------------------------------------------------------------------------
+
+def main():
+ parser = argparse.ArgumentParser(
+ description='Generate style mixing image matrix using pretrained network pickle.',
+ epilog=_examples,
+ formatter_class=argparse.RawDescriptionHelpFormatter
+ )
+
+ parser.add_argument('--network', help='Network pickle filename', dest='network_pkl', required=True)
+ parser.add_argument('--rows', dest='row_seeds', type=_parse_num_range, help='Random seeds to use for image rows', required=True)
+ parser.add_argument('--cols', dest='col_seeds', type=_parse_num_range, help='Random seeds to use for image columns', required=True)
+ parser.add_argument('--styles', dest='col_styles', type=_parse_num_range, help='Style layer range (default: %(default)s)', default='0-6')
+ parser.add_argument('--trunc', dest='truncation_psi', type=float, help='Truncation psi (default: %(default)s)', default=0.5)
+ parser.add_argument('--outdir', help='Where to save the output images', required=True, metavar='DIR')
+
+ args = parser.parse_args()
+ style_mixing_example(**vars(args))
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ main()
+
+#----------------------------------------------------------------------------
diff --git a/train.py b/train.py
new file mode 100755
index 00000000..5b36d792
--- /dev/null
+++ b/train.py
@@ -0,0 +1,563 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Train a GAN using the techniques described in the paper
+"Training Generative Adversarial Networks with Limited Data"."""
+
+import os
+import argparse
+import json
+import re
+import tensorflow as tf
+import dnnlib
+import dnnlib.tflib as tflib
+
+from training import training_loop
+from training import dataset
+from metrics import metric_defaults
+
+#----------------------------------------------------------------------------
+
+class UserError(Exception):
+ pass
+
+#----------------------------------------------------------------------------
+
+def setup_training_options(
+    # General options (not included in desc).
+    gpus       = None, # Number of GPUs: <int>, default = 1 gpu
+    snap       = None, # Snapshot interval: <int>, default = 50 ticks
+
+    # Training dataset.
+    data       = None, # Training dataset (required): <path>
+    res        = None, # Override dataset resolution: <int>, default = highest available
+    mirror     = None, # Augment dataset with x-flips: <bool>, default = False
+
+    # Metrics (not included in desc).
+    metrics    = None, # List of metric names: [], ['fid50k_full'] (default), ...
+    metricdata = None, # Metric dataset (optional): <path>
+
+    # Base config.
+    cfg        = None, # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar', 'cifarbaseline'
+    gamma      = None, # Override R1 gamma: <float>, default = depends on cfg
+    kimg       = None, # Override training duration: <int>, default = depends on cfg
+
+    # Discriminator augmentation.
+    aug        = None, # Augmentation mode: 'ada' (default), 'noaug', 'fixed', 'adarv'
+    p          = None, # Specify p for 'fixed' (required): <float>
+    target     = None, # Override ADA target for 'ada' and 'adarv': <float>, default = depends on aug
+    augpipe    = None, # Augmentation pipeline: 'blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc' (default), ..., 'bgcfnc'
+
+    # Comparison methods.
+    cmethod    = None, # Comparison method: 'nocmethod' (default), 'bcr', 'zcr', 'pagan', 'wgangp', 'auxrot', 'spectralnorm', 'shallowmap', 'adropout'
+    dcap       = None, # Multiplier for discriminator capacity: <float>, default = 1
+
+    # Transfer learning.
+    resume     = None, # Load previous network: 'noresume' (default), 'ffhq256', 'ffhq512', 'ffhq1024', 'celebahq256', 'lsundog256', <file>, <url>
+    freezed    = None, # Freeze-D: <int>, default = 0 discriminator layers
+):
+ # Initialize dicts.
+ args = dnnlib.EasyDict()
+ args.G_args = dnnlib.EasyDict(func_name='training.networks.G_main')
+ args.D_args = dnnlib.EasyDict(func_name='training.networks.D_main')
+ args.G_opt_args = dnnlib.EasyDict(beta1=0.0, beta2=0.99)
+ args.D_opt_args = dnnlib.EasyDict(beta1=0.0, beta2=0.99)
+ args.loss_args = dnnlib.EasyDict(func_name='training.loss.stylegan2')
+ args.augment_args = dnnlib.EasyDict(class_name='training.augment.AdaptiveAugment')
+
+ # ---------------------------
+ # General options: gpus, snap
+ # ---------------------------
+
+ if gpus is None:
+ gpus = 1
+ assert isinstance(gpus, int)
+ if not (gpus >= 1 and gpus & (gpus - 1) == 0):
+ raise UserError('--gpus must be a power of two')
+ args.num_gpus = gpus
+
+ if snap is None:
+ snap = 50
+ assert isinstance(snap, int)
+ if snap < 1:
+ raise UserError('--snap must be at least 1')
+ args.image_snapshot_ticks = snap
+ args.network_snapshot_ticks = snap
+
+ # -----------------------------------
+ # Training dataset: data, res, mirror
+ # -----------------------------------
+
+ assert data is not None
+ assert isinstance(data, str)
+ data_name = os.path.basename(os.path.abspath(data))
+ if not os.path.isdir(data) or len(data_name) == 0:
+ raise UserError('--data must point to a directory containing *.tfrecords')
+ desc = data_name
+
+ with tf.Graph().as_default(), tflib.create_session().as_default(): # pylint: disable=not-context-manager
+ args.train_dataset_args = dnnlib.EasyDict(path=data, max_label_size='full')
+ dataset_obj = dataset.load_dataset(**args.train_dataset_args) # try to load the data and see what comes out
+ args.train_dataset_args.resolution = dataset_obj.shape[-1] # be explicit about resolution
+ args.train_dataset_args.max_label_size = dataset_obj.label_size # be explicit about label size
+ validation_set_available = dataset_obj.has_validation_set
+ dataset_obj.close()
+ dataset_obj = None
+
+ if res is None:
+ res = args.train_dataset_args.resolution
+ else:
+ assert isinstance(res, int)
+ if not (res >= 4 and res & (res - 1) == 0):
+ raise UserError('--res must be a power of two and at least 4')
+ if res > args.train_dataset_args.resolution:
+ raise UserError(f'--res cannot exceed maximum available resolution in the dataset ({args.train_dataset_args.resolution})')
+ desc += f'-res{res:d}'
+ args.train_dataset_args.resolution = res
+
+ if mirror is None:
+ mirror = False
+ else:
+ assert isinstance(mirror, bool)
+ if mirror:
+ desc += '-mirror'
+ args.train_dataset_args.mirror_augment = mirror
+
+ # ----------------------------
+ # Metrics: metrics, metricdata
+ # ----------------------------
+
+ if metrics is None:
+ metrics = ['fid50k_full']
+ assert isinstance(metrics, list)
+ assert all(isinstance(metric, str) for metric in metrics)
+
+ args.metric_arg_list = []
+ for metric in metrics:
+ if metric not in metric_defaults.metric_defaults:
+ raise UserError('\n'.join(['--metrics can only contain the following values:', 'none'] + list(metric_defaults.metric_defaults.keys())))
+ args.metric_arg_list.append(metric_defaults.metric_defaults[metric])
+
+ args.metric_dataset_args = dnnlib.EasyDict(args.train_dataset_args)
+ if metricdata is not None:
+ assert isinstance(metricdata, str)
+ if not os.path.isdir(metricdata):
+ raise UserError('--metricdata must point to a directory containing *.tfrecords')
+ args.metric_dataset_args.path = metricdata
+
+ # -----------------------------
+ # Base config: cfg, gamma, kimg
+ # -----------------------------
+
+ if cfg is None:
+ cfg = 'auto'
+ assert isinstance(cfg, str)
+ desc += f'-{cfg}'
+
+ cfg_specs = {
+ 'auto': dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=0.05, map=2), # populated dynamically based on 'gpus' and 'res'
+ 'stylegan2': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8), # uses mixed-precision, unlike original StyleGAN2
+ 'paper256': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=0.5, lrate=0.0025, gamma=1, ema=20, ramp=None, map=8),
+ 'paper512': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=1, lrate=0.0025, gamma=0.5, ema=20, ramp=None, map=8),
+ 'paper1024': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=2, ema=10, ramp=None, map=8),
+ 'cifar': dict(ref_gpus=2, kimg=100000, mb=64, mbstd=32, fmaps=0.5, lrate=0.0025, gamma=0.01, ema=500, ramp=0.05, map=2),
+ 'cifarbaseline': dict(ref_gpus=2, kimg=100000, mb=64, mbstd=32, fmaps=0.5, lrate=0.0025, gamma=0.01, ema=500, ramp=0.05, map=8),
+ }
+
+ assert cfg in cfg_specs
+ spec = dnnlib.EasyDict(cfg_specs[cfg])
+ if cfg == 'auto':
+ desc += f'{gpus:d}'
+ spec.ref_gpus = gpus
+ spec.mb = max(min(gpus * min(4096 // res, 32), 64), gpus) # keep gpu memory consumption at bay
+ spec.mbstd = min(spec.mb // gpus, 4) # other hyperparams behave more predictably if mbstd group size remains fixed
+ spec.fmaps = 1 if res >= 512 else 0.5
+ spec.lrate = 0.002 if res >= 1024 else 0.0025
+ spec.gamma = 0.0002 * (res ** 2) / spec.mb # heuristic formula
+ spec.ema = spec.mb * 10 / 32
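+        # Example: with --gpus=1 and a 256x256 dataset, this resolves to mb=16, mbstd=4, fmaps=0.5,
+        # lrate=0.0025, gamma = 0.0002 * 256^2 / 16 = 0.8192, and ema = 5 kimg.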
+
+ args.total_kimg = spec.kimg
+ args.minibatch_size = spec.mb
+ args.minibatch_gpu = spec.mb // spec.ref_gpus
+ args.D_args.mbstd_group_size = spec.mbstd
+ args.G_args.fmap_base = args.D_args.fmap_base = int(spec.fmaps * 16384)
+ args.G_args.fmap_max = args.D_args.fmap_max = 512
+ args.G_opt_args.learning_rate = args.D_opt_args.learning_rate = spec.lrate
+ args.loss_args.r1_gamma = spec.gamma
+ args.G_smoothing_kimg = spec.ema
+ args.G_smoothing_rampup = spec.ramp
+ args.G_args.mapping_layers = spec.map
+ args.G_args.num_fp16_res = args.D_args.num_fp16_res = 4 # enable mixed-precision training
+ args.G_args.conv_clamp = args.D_args.conv_clamp = 256 # clamp activations to avoid float16 overflow
+
+ if cfg == 'cifar':
+ args.loss_args.pl_weight = 0 # disable path length regularization
+ args.G_args.style_mixing_prob = None # disable style mixing
+ args.D_args.architecture = 'orig' # disable residual skip connections
+
+ if gamma is not None:
+ assert isinstance(gamma, float)
+ if not gamma >= 0:
+ raise UserError('--gamma must be non-negative')
+ desc += f'-gamma{gamma:g}'
+ args.loss_args.r1_gamma = gamma
+
+ if kimg is not None:
+ assert isinstance(kimg, int)
+ if not kimg >= 1:
+ raise UserError('--kimg must be at least 1')
+ desc += f'-kimg{kimg:d}'
+ args.total_kimg = kimg
+
+ # ---------------------------------------------------
+ # Discriminator augmentation: aug, p, target, augpipe
+ # ---------------------------------------------------
+
+ if aug is None:
+ aug = 'ada'
+ else:
+ assert isinstance(aug, str)
+ desc += f'-{aug}'
+
+ if aug == 'ada':
+ args.augment_args.tune_heuristic = 'rt'
+ args.augment_args.tune_target = 0.6
+
+ elif aug == 'noaug':
+ pass
+
+ elif aug == 'fixed':
+ if p is None:
+ raise UserError(f'--aug={aug} requires specifying --p')
+
+ elif aug == 'adarv':
+ if not validation_set_available:
+ raise UserError(f'--aug={aug} requires separate validation set; please see "python dataset_tool.py pack -h"')
+ args.augment_args.tune_heuristic = 'rv'
+ args.augment_args.tune_target = 0.5
+
+ else:
+ raise UserError(f'--aug={aug} not supported')
+
+ if p is not None:
+ assert isinstance(p, float)
+ if aug != 'fixed':
+ raise UserError('--p can only be specified with --aug=fixed')
+ if not 0 <= p <= 1:
+ raise UserError('--p must be between 0 and 1')
+ desc += f'-p{p:g}'
+ args.augment_args.initial_strength = p
+
+ if target is not None:
+ assert isinstance(target, float)
+ if aug not in ['ada', 'adarv']:
+ raise UserError('--target can only be specified with --aug=ada or --aug=adarv')
+ if not 0 <= target <= 1:
+ raise UserError('--target must be between 0 and 1')
+ desc += f'-target{target:g}'
+ args.augment_args.tune_target = target
+
+ assert augpipe is None or isinstance(augpipe, str)
+ if augpipe is None:
+ augpipe = 'bgc'
+ else:
+ if aug == 'noaug':
+ raise UserError('--augpipe cannot be specified with --aug=noaug')
+ desc += f'-{augpipe}'
+
+ augpipe_specs = {
+ 'blit': dict(xflip=1, rotate90=1, xint=1),
+ 'geom': dict(scale=1, rotate=1, aniso=1, xfrac=1),
+ 'color': dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
+ 'filter': dict(imgfilter=1),
+ 'noise': dict(noise=1),
+ 'cutout': dict(cutout=1),
+ 'bg': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
+ 'bgc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
+ 'bgcf': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1),
+ 'bgcfn': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1),
+ 'bgcfnc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1, cutout=1),
+ }
+
+ assert augpipe in augpipe_specs
+ if aug != 'noaug':
+ args.augment_args.apply_func = 'training.augment.augment_pipeline'
+ args.augment_args.apply_args = augpipe_specs[augpipe]
+
+ # ---------------------------------
+ # Comparison methods: cmethod, dcap
+ # ---------------------------------
+
+ assert cmethod is None or isinstance(cmethod, str)
+ if cmethod is None:
+ cmethod = 'nocmethod'
+ else:
+ desc += f'-{cmethod}'
+
+ if cmethod == 'nocmethod':
+ pass
+
+ elif cmethod == 'bcr':
+ args.loss_args.func_name = 'training.loss.cmethods'
+ args.loss_args.bcr_real_weight = 10
+ args.loss_args.bcr_fake_weight = 10
+ args.loss_args.bcr_augment = dnnlib.EasyDict(func_name='training.augment.augment_pipeline', xint=1, xint_max=1/32)
+
+ elif cmethod == 'zcr':
+ args.loss_args.func_name = 'training.loss.cmethods'
+ args.loss_args.zcr_gen_weight = 0.02
+ args.loss_args.zcr_dis_weight = 0.2
+ args.G_args.num_fp16_res = args.D_args.num_fp16_res = 0 # disable mixed-precision training
+ args.G_args.conv_clamp = args.D_args.conv_clamp = None
+
+ elif cmethod == 'pagan':
+ if aug != 'noaug':
+ raise UserError(f'--cmethod={cmethod} is not compatible with discriminator augmentation; please specify --aug=noaug')
+ args.D_args.use_pagan = True
+ args.augment_args.tune_heuristic = 'rt' # enable ada heuristic
+ args.augment_args.pop('apply_func', None) # disable discriminator augmentation
+ args.augment_args.pop('apply_args', None)
+ args.augment_args.tune_target = 0.95
+
+ elif cmethod == 'wgangp':
+ if aug != 'noaug':
+ raise UserError(f'--cmethod={cmethod} is not compatible with discriminator augmentation; please specify --aug=noaug')
+ if gamma is not None:
+ raise UserError(f'--cmethod={cmethod} is not compatible with --gamma')
+ args.loss_args = dnnlib.EasyDict(func_name='training.loss.wgangp')
+ args.G_opt_args.learning_rate = args.D_opt_args.learning_rate = 0.001
+ args.G_args.num_fp16_res = args.D_args.num_fp16_res = 0 # disable mixed-precision training
+ args.G_args.conv_clamp = args.D_args.conv_clamp = None
+ args.lazy_regularization = False
+
+ elif cmethod == 'auxrot':
+ if args.train_dataset_args.max_label_size > 0:
+ raise UserError(f'--cmethod={cmethod} is not compatible with label conditioning; please specify a dataset without labels')
+ args.loss_args.func_name = 'training.loss.cmethods'
+ args.loss_args.auxrot_alpha = 10
+ args.loss_args.auxrot_beta = 5
+ args.D_args.score_max = 5 # prepare D to output 5 scalars per image instead of just 1
+
+ elif cmethod == 'spectralnorm':
+ args.D_args.use_spectral_norm = True
+
+ elif cmethod == 'shallowmap':
+ if args.G_args.mapping_layers == 2:
+ raise UserError(f'--cmethod={cmethod} is a no-op for --cfg={cfg}')
+ args.G_args.mapping_layers = 2
+
+ elif cmethod == 'adropout':
+ if aug != 'noaug':
+ raise UserError(f'--cmethod={cmethod} is not compatible with discriminator augmentation; please specify --aug=noaug')
+ args.D_args.adaptive_dropout = 1
+ args.augment_args.tune_heuristic = 'rt' # enable ada heuristic
+ args.augment_args.pop('apply_func', None) # disable discriminator augmentation
+ args.augment_args.pop('apply_args', None)
+ args.augment_args.tune_target = 0.6
+
+ else:
+ raise UserError(f'--cmethod={cmethod} not supported')
+
+ if dcap is not None:
+ assert isinstance(dcap, float)
+ if not dcap > 0:
+ raise UserError('--dcap must be positive')
+ desc += f'-dcap{dcap:g}'
+ args.D_args.fmap_base = max(int(args.D_args.fmap_base * dcap), 1)
+ args.D_args.fmap_max = max(int(args.D_args.fmap_max * dcap), 1)
+
+ # ----------------------------------
+ # Transfer learning: resume, freezed
+ # ----------------------------------
+
+ resume_specs = {
+ 'ffhq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
+ 'ffhq512': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
+ 'ffhq1024': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
+ 'celebahq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
+ 'lsundog256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
+ }
+
+ assert resume is None or isinstance(resume, str)
+ if resume is None:
+ resume = 'noresume'
+ elif resume == 'noresume':
+ desc += '-noresume'
+ elif resume in resume_specs:
+ desc += f'-resume{resume}'
+ args.resume_pkl = resume_specs[resume] # predefined url
+ else:
+ desc += '-resumecustom'
+ args.resume_pkl = resume # custom path or url
+
+ if resume != 'noresume':
+ args.augment_args.tune_kimg = 100 # make ADA react faster at the beginning
+ args.G_smoothing_rampup = None # disable EMA rampup
+
+ if freezed is not None:
+ assert isinstance(freezed, int)
+ if not freezed >= 0:
+ raise UserError('--freezed must be non-negative')
+ desc += f'-freezed{freezed:d}'
+ args.D_args.freeze_layers = freezed
+
+ return desc, args
+
+#----------------------------------------------------------------------------
+
+def run_training(outdir, seed, dry_run, **hyperparam_options):
+ # Setup training options.
+ tflib.init_tf({'rnd.np_random_seed': seed})
+ run_desc, training_options = setup_training_options(**hyperparam_options)
+
+ # Pick output directory.
+ prev_run_dirs = []
+ if os.path.isdir(outdir):
+ prev_run_dirs = [x for x in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, x))]
+ prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
+ prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
+ cur_run_id = max(prev_run_ids, default=-1) + 1
+ training_options.run_dir = os.path.join(outdir, f'{cur_run_id:05d}-{run_desc}')
+ assert not os.path.exists(training_options.run_dir)
+
+ # Print options.
+ print()
+ print('Training options:')
+ print(json.dumps(training_options, indent=2))
+ print()
+ print(f'Output directory: {training_options.run_dir}')
+ print(f'Training data: {training_options.train_dataset_args.path}')
+ print(f'Training length: {training_options.total_kimg} kimg')
+ print(f'Resolution: {training_options.train_dataset_args.resolution}')
+ print(f'Number of GPUs: {training_options.num_gpus}')
+ print()
+
+ # Dry run?
+ if dry_run:
+ print('Dry run; exiting.')
+ return
+
+ # Kick off training.
+ print('Creating output directory...')
+ os.makedirs(training_options.run_dir)
+ with open(os.path.join(training_options.run_dir, 'training_options.json'), 'wt') as f:
+ json.dump(training_options, f, indent=2)
+ with dnnlib.util.Logger(os.path.join(training_options.run_dir, 'log.txt')):
+ training_loop.training_loop(**training_options)
+
+#----------------------------------------------------------------------------
+
+def _str_to_bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
+ return True
+ if v.lower() in ('no', 'false', 'f', 'n', '0'):
+ return False
+ raise argparse.ArgumentTypeError('Boolean value expected.')
+
+def _parse_comma_sep(s):
+ if s is None or s.lower() == 'none' or s == '':
+ return []
+ return s.split(',')
+
+#----------------------------------------------------------------------------
+
+_cmdline_help_epilog = '''examples:
+
+ # Train custom dataset using 1 GPU.
+ python %(prog)s --outdir=~/training-runs --gpus=1 --data=~/datasets/custom
+
+ # Train class-conditional CIFAR-10 using 2 GPUs.
+ python %(prog)s --outdir=~/training-runs --gpus=2 --data=~/datasets/cifar10c \\
+ --cfg=cifar
+
+ # Transfer learn MetFaces from FFHQ using 4 GPUs.
+ python %(prog)s --outdir=~/training-runs --gpus=4 --data=~/datasets/metfaces \\
+ --cfg=paper1024 --mirror=1 --resume=ffhq1024 --snap=10
+
+ # Reproduce original StyleGAN2 config F.
+ python %(prog)s --outdir=~/training-runs --gpus=8 --data=~/datasets/ffhq \\
+ --cfg=stylegan2 --res=1024 --mirror=1 --aug=noaug
+
+available base configs (--cfg):
+ auto Automatically select reasonable defaults based on resolution
+ and GPU count. Good starting point for new datasets.
+ stylegan2 Reproduce results for StyleGAN2 config F at 1024x1024.
+ paper256 Reproduce results for FFHQ and LSUN Cat at 256x256.
+ paper512 Reproduce results for BreCaHAD and AFHQ at 512x512.
+ paper1024 Reproduce results for MetFaces at 1024x1024.
+ cifar Reproduce results for CIFAR-10 (tuned configuration).
+ cifarbaseline Reproduce results for CIFAR-10 (baseline configuration).
+
+transfer learning source networks (--resume):
+ ffhq256 FFHQ trained at 256x256 resolution.
+ ffhq512 FFHQ trained at 512x512 resolution.
+ ffhq1024 FFHQ trained at 1024x1024 resolution.
+ celebahq256 CelebA-HQ trained at 256x256 resolution.
+ lsundog256 LSUN Dog trained at 256x256 resolution.
+  <file or url>  Custom network pickle.
+'''
+
+#----------------------------------------------------------------------------
+
+def main():
+ parser = argparse.ArgumentParser(
+ description='Train a GAN using the techniques described in the paper\n"Training Generative Adversarial Networks with Limited Data".',
+ epilog=_cmdline_help_epilog,
+ formatter_class=argparse.RawDescriptionHelpFormatter
+ )
+
+ group = parser.add_argument_group('general options')
+ group.add_argument('--outdir', help='Where to save the results (required)', required=True, metavar='DIR')
+ group.add_argument('--gpus', help='Number of GPUs to use (default: 1 gpu)', type=int, metavar='INT')
+ group.add_argument('--snap', help='Snapshot interval (default: 50 ticks)', type=int, metavar='INT')
+ group.add_argument('--seed', help='Random seed (default: %(default)s)', type=int, default=1000, metavar='INT')
+ group.add_argument('-n', '--dry-run', help='Print training options and exit', action='store_true', default=False)
+
+ group = parser.add_argument_group('training dataset')
+ group.add_argument('--data', help='Training dataset path (required)', metavar='PATH', required=True)
+ group.add_argument('--res', help='Dataset resolution (default: highest available)', type=int, metavar='INT')
+ group.add_argument('--mirror', help='Augment dataset with x-flips (default: false)', type=_str_to_bool, metavar='BOOL')
+
+ group = parser.add_argument_group('metrics')
+ group.add_argument('--metrics', help='Comma-separated list or "none" (default: fid50k_full)', type=_parse_comma_sep, metavar='LIST')
+ group.add_argument('--metricdata', help='Dataset to evaluate metrics against (optional)', metavar='PATH')
+
+ group = parser.add_argument_group('base config')
+ group.add_argument('--cfg', help='Base config (default: auto)', choices=['auto', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar', 'cifarbaseline'])
+ group.add_argument('--gamma', help='Override R1 gamma', type=float, metavar='FLOAT')
+ group.add_argument('--kimg', help='Override training duration', type=int, metavar='INT')
+
+ group = parser.add_argument_group('discriminator augmentation')
+ group.add_argument('--aug', help='Augmentation mode (default: ada)', choices=['noaug', 'ada', 'fixed', 'adarv'])
+ group.add_argument('--p', help='Specify augmentation probability for --aug=fixed', type=float, metavar='FLOAT')
+ group.add_argument('--target', help='Override ADA target for --aug=ada and --aug=adarv', type=float)
+ group.add_argument('--augpipe', help='Augmentation pipeline (default: bgc)', choices=['blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc', 'bgcf', 'bgcfn', 'bgcfnc'])
+
+ group = parser.add_argument_group('comparison methods')
+ group.add_argument('--cmethod', help='Comparison method (default: nocmethod)', choices=['nocmethod', 'bcr', 'zcr', 'pagan', 'wgangp', 'auxrot', 'spectralnorm', 'shallowmap', 'adropout'])
+ group.add_argument('--dcap', help='Multiplier for discriminator capacity', type=float, metavar='FLOAT')
+
+ group = parser.add_argument_group('transfer learning')
+ group.add_argument('--resume', help='Resume from network pickle (default: noresume)')
+ group.add_argument('--freezed', help='Freeze-D (default: 0 discriminator layers)', type=int, metavar='INT')
+
+ args = parser.parse_args()
+ try:
+ run_training(**vars(args))
+ except UserError as err:
+ print(f'Error: {err}')
+ exit(1)
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ main()
+
+#----------------------------------------------------------------------------
diff --git a/training/__init__.py b/training/__init__.py
new file mode 100755
index 00000000..2c61c745
--- /dev/null
+++ b/training/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+# empty
diff --git a/training/augment.py b/training/augment.py
new file mode 100755
index 00000000..17296fc3
--- /dev/null
+++ b/training/augment.py
@@ -0,0 +1,587 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Adaptive discriminator augmentation (ADA) from the paper
+"Training Generative Adversarial Networks with Limited Data"."""
+
+import numpy as np
+import tensorflow as tf
+import scipy.signal
+import dnnlib
+import dnnlib.tflib as tflib
+
+from training import loss
+
+#----------------------------------------------------------------------------
+# Main class for adaptive discriminator augmentation (ADA).
+# - Performs adaptive tuning of augmentation strength during training.
+# - Acts as a wrapper for the augmentation pipeline.
+# - Keeps track of the necessary training statistics.
+# - Calculates statistics for the validation set, if available.
+
+class AdaptiveAugment:
+ def __init__(self,
+ apply_func = None, # Function representing the augmentation pipeline. Can be a fully-qualified name, a function object, or None.
+ apply_args = {}, # Keyword arguments for the augmentation pipeline.
+ initial_strength = 0, # Augmentation strength (p) to use initially.
+ tune_heuristic = None, # Heuristic for tuning the augmentation strength dynamically: 'rt', 'rv', None.
+ tune_target = None, # Target value for the selected heuristic.
+ tune_kimg = 500, # Adjustment speed, measured in how many kimg it takes for the strength to increase/decrease by one unit.
+ stat_decay_kimg = 0, # Exponential moving average to use for training statistics, measured as the half-life in kimg. 0 = disable EMA.
+ ):
+ tune_stats = {
+ 'rt': {'Loss/signs/real'},
+ 'rv': {'Loss/scores/fake', 'Loss/scores/real', 'Loss/scores/valid'},
+ None: {},
+ }
+ assert tune_heuristic in tune_stats
+ assert apply_func is None or isinstance(apply_func, str) or dnnlib.util.is_top_level_function(apply_func)
+
+ # Configuration.
+ self.apply_func = dnnlib.util.get_obj_by_name(apply_func) if isinstance(apply_func, str) else apply_func
+ self.apply_args = apply_args
+ self.strength = initial_strength
+ self.tune_heuristic = tune_heuristic
+ self.tune_target = tune_target
+ self.tune_kimg = tune_kimg
+ self.stat_decay_kimg = stat_decay_kimg
+
+ # Runtime state.
+ self._tune_stats = tune_stats[tune_heuristic]
+ self._strength_var = None
+ self._acc_vars = dict() # {name: [var, ...], ...}
+ self._acc_decay_in = None
+ self._acc_decay_ops = dict() # {name: op, ...}
+ self._valid_images = None
+ self._valid_labels = None
+ self._valid_images_in = None
+ self._valid_labels_in = None
+ self._valid_op = None
+ self._valid_ofs = 0
+
+ def init_validation_set(self, D_gpus, training_set):
+ assert self._valid_images is None
+ images, labels = training_set.load_validation_set_np()
+ if images.shape[0] == 0:
+ return
+ self._valid_images = images
+ self._valid_labels = labels
+
+ # Build validation graph.
+ with tflib.absolute_name_scope('Validation'), tf.control_dependencies(None):
+ with tf.device('/cpu:0'):
+ self._valid_images_in = tf.placeholder(training_set.dtype, name='valid_images_in', shape=[None]+training_set.shape)
+ self._valid_labels_in = tf.placeholder(training_set.label_dtype, name='valid_labels_in', shape=[None,training_set.label_size])
+ images_in_gpus = tf.split(self._valid_images_in, len(D_gpus))
+ labels_in_gpus = tf.split(self._valid_labels_in, len(D_gpus))
+ ops = []
+ for gpu, (D_gpu, images_in_gpu, labels_in_gpu) in enumerate(zip(D_gpus, images_in_gpus, labels_in_gpus)):
+ with tf.device(f'/gpu:{gpu}'):
+ images_expr = tf.cast(images_in_gpu, tf.float32) * (2 / 255) - 1
+ D_valid = loss.eval_D(D_gpu, self, images_expr, labels_in_gpu, report='valid')
+ ops += [D_valid.scores]
+ self._valid_op = tf.group(*ops)
+
+ def apply(self, images, labels, enable=True):
+ if not enable or self.apply_func is None or (self.strength == 0 and self.tune_heuristic is None):
+ return images, labels
+ with tf.name_scope('Augment'):
+ images, labels = self.apply_func(images, labels, strength=self.get_strength_var(), **self.apply_args)
+ return images, labels
+
+ def get_strength_var(self):
+ if self._strength_var is None:
+ with tflib.absolute_name_scope('Augment'), tf.control_dependencies(None):
+ self._strength_var = tf.Variable(np.float32(self.strength), name='strength', trainable=False)
+ return self._strength_var
+
+ def report_stat(self, name, expr):
+ if name in self._tune_stats:
+ expr = self._increment_acc(name, expr)
+ return expr
+
+ def tune(self, nimg_delta):
+ acc = {name: self._read_and_decay_acc(name, nimg_delta) for name in self._tune_stats}
+ nimg_ratio = nimg_delta / (self.tune_kimg * 1000)
+ strength = self.strength
+
+ if self.tune_heuristic == 'rt':
+ assert self.tune_target is not None
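+            # rt is the accumulated mean of sign(D(real)): the fraction of real images currently scored
+            # as real minus the fraction scored as fake. It drifts towards 1 as the discriminator
+            # overfits, so p is nudged up whenever rt exceeds the target and down otherwise.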
+ rt = acc['Loss/signs/real']
+ strength += nimg_ratio * np.sign(rt - self.tune_target)
+
+ if self.tune_heuristic == 'rv':
+ assert self.tune_target is not None
+ assert self._valid_images is not None
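+            # rv compares the discriminator's scores on training images against a held-out validation
+            # set, normalized by the real-vs-fake score gap: rv approaches 1 when validation images
+            # start to look like fakes to D (overfitting) and 0 when training and validation images
+            # are scored alike.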
+ rv = (acc['Loss/scores/real'] - acc['Loss/scores/valid']) / max(acc['Loss/scores/real'] - acc['Loss/scores/fake'], 1e-8)
+ strength += nimg_ratio * np.sign(rv - self.tune_target)
+
+ self._set_strength(strength)
+
+ def run_validation(self, minibatch_size):
+ if self._valid_images is not None:
+ indices = [(self._valid_ofs + i) % self._valid_images.shape[0] for i in range(minibatch_size)]
+ tflib.run(self._valid_op, {self._valid_images_in: self._valid_images[indices], self._valid_labels_in: self._valid_labels[indices]})
+ self._valid_ofs += len(indices)
+
+ def _set_strength(self, strength):
+ strength = max(strength, 0)
+ if self._strength_var is not None and strength != self.strength:
+ tflib.set_vars({self._strength_var: strength})
+ self.strength = strength
+
+ def _increment_acc(self, name, expr):
+ with tf.name_scope('acc_' + name):
+ with tf.control_dependencies(None):
+ acc_var = tf.Variable(tf.zeros(2), name=name, trainable=False) # [acc_num, acc_sum]
+ if name not in self._acc_vars:
+ self._acc_vars[name] = []
+ self._acc_vars[name].append(acc_var)
+ expr_num = tf.shape(tf.reshape(expr, [-1]))[0]
+ expr_sum = tf.reduce_sum(expr)
+            acc_op = tf.assign_add(acc_var, [tf.cast(expr_num, tf.float32), expr_sum])
+ with tf.control_dependencies([acc_op]):
+ return tf.identity(expr)
+
+ def _read_and_decay_acc(self, name, nimg_delta):
+ acc_vars = self._acc_vars[name]
+ acc_num, acc_sum = tuple(np.sum(tflib.run(acc_vars), axis=0))
+ if nimg_delta > 0:
+ with tflib.absolute_name_scope('Augment'), tf.control_dependencies(None):
+ if self._acc_decay_in is None:
+ self._acc_decay_in = tf.placeholder(tf.float32, name='acc_decay_in', shape=[])
+ if name not in self._acc_decay_ops:
+ with tf.name_scope('acc_' + name):
+ ops = [tf.assign(var, var * self._acc_decay_in) for var in acc_vars]
+ self._acc_decay_ops[name] = tf.group(*ops)
+ acc_decay = 0.5 ** (nimg_delta / (self.stat_decay_kimg * 1000)) if self.stat_decay_kimg > 0 else 0
+ tflib.run(self._acc_decay_ops[name], {self._acc_decay_in: acc_decay})
+ return acc_sum / acc_num if acc_num > 0 else 0
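+
+#----------------------------------------------------------------------------
+# Toy simulation of the 'rt' feedback loop implemented by AdaptiveAugment.tune()
+# above. The overfitting dynamics below are made up purely for illustration; in
+# training, rt comes from the accumulated 'Loss/signs/real' statistic reported by
+# the discriminator. The helper is not used anywhere in this module.
+
+def _simulate_ada_feedback(num_ticks=1000, nimg_per_tick=4000, tune_kimg=500, tune_target=0.6):
+    p = 0.0
+    history = []
+    for _ in range(num_ticks):
+        # Hypothetical measurement: assume the discriminator overfits less as p grows.
+        rt = np.clip(0.9 - 0.6 * p + np.random.normal(0.0, 0.02), -1.0, 1.0)
+        # Same update rule as tune(): integrate the sign of the error, clamped at zero.
+        p += (nimg_per_tick / (tune_kimg * 1000)) * np.sign(rt - tune_target)
+        p = max(p, 0.0)
+        history.append((rt, p))
+    return history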
+
+#----------------------------------------------------------------------------
+# Helper for randomly gating augmentation parameters based on the given probability.
+
+def gate_augment_params(probability, params, disabled_val):
+ shape = tf.shape(params)
+ cond = (tf.random_uniform(shape[:1], 0, 1) < probability)
+ disabled_val = tf.broadcast_to(tf.convert_to_tensor(disabled_val, dtype=params.dtype), shape)
+ return tf.where(cond, params, disabled_val)
+
+#----------------------------------------------------------------------------
+# Helpers for constructing batched transformation matrices.
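+# Each row entry may be a Python scalar or a 1-D tensor of per-sample values; scalar entries are
+# broadcast to the batch size inferred from the tensor entries, and the result is a float32 tensor
+# of shape [batch, num_rows, num_cols] holding one homogeneous matrix per sample.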
+
+def construct_batch_of_matrices(*rows):
+ rows = [[tf.convert_to_tensor(x, dtype=tf.float32) for x in r] for r in rows]
+ batch_elems = [x for r in rows for x in r if x.shape.rank != 0]
+ assert all(x.shape.rank == 1 for x in batch_elems)
+ batch_size = tf.shape(batch_elems[0])[0] if len(batch_elems) else 1
+ rows = [[tf.broadcast_to(x, [batch_size]) for x in r] for r in rows]
+ return tf.transpose(rows, [2, 0, 1])
+
+def translate_2d(tx, ty):
+ return construct_batch_of_matrices(
+ [1, 0, tx],
+ [0, 1, ty],
+ [0, 0, 1])
+
+def translate_3d(tx, ty, tz):
+ return construct_batch_of_matrices(
+ [1, 0, 0, tx],
+ [0, 1, 0, ty],
+ [0, 0, 1, tz],
+ [0, 0, 0, 1])
+
+def scale_2d(sx, sy):
+ return construct_batch_of_matrices(
+ [sx, 0, 0],
+ [0, sy, 0],
+ [0, 0, 1])
+
+def scale_3d(sx, sy, sz):
+ return construct_batch_of_matrices(
+ [sx, 0, 0, 0],
+ [0, sy, 0, 0],
+ [0, 0, sz, 0],
+ [0, 0, 0, 1])
+
+def rotate_2d(theta):
+ return construct_batch_of_matrices(
+ [tf.cos(theta), tf.sin(-theta), 0],
+ [tf.sin(theta), tf.cos(theta), 0],
+ [0, 0, 1])
+
+def rotate_3d(v, theta):
+ vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2]
+ s = tf.sin(theta); c = tf.cos(theta); cc = 1 - c
+ return construct_batch_of_matrices(
+ [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0],
+ [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0],
+ [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0],
+ [0, 0, 0, 1])
+
+def translate_2d_inv(tx, ty):
+ return translate_2d(-tx, -ty)
+
+def scale_2d_inv(sx, sy):
+ return scale_2d(1/sx, 1/sy)
+
+def rotate_2d_inv(theta):
+ return rotate_2d(-theta)
+
+#----------------------------------------------------------------------------
+# Coefficients of various wavelet decomposition low-pass filters.
+
+wavelets = {
+ 'haar': [0.7071067811865476, 0.7071067811865476],
+ 'db1': [0.7071067811865476, 0.7071067811865476],
+ 'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
+ 'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
+ 'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
+ 'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
+ 'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
+ 'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
+ 'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
+ 'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
+ 'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
+ 'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
+ 'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
+ 'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
+ 'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
+ 'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
+}
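+
+# Sanity check (illustrative only, not executed by the pipeline): the low-pass
+# coefficients of an orthogonal wavelet sum to sqrt(2), e.g.
+#   assert all(np.isclose(np.sum(taps), np.sqrt(2)) for taps in wavelets.values())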
+
+#----------------------------------------------------------------------------
+# Versatile image augmentation pipeline from the paper
+# "Training Generative Adversarial Networks with Limited Data".
+#
+# All augmentations are disabled by default; individual augmentations can
+# be enabled by setting their probability multipliers to 1.
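+#
+# Illustrative call (not part of the original code; the parameter choices are arbitrary):
+#   images_aug, labels_aug = augment_pipeline(images, labels, strength=p,
+#       xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1,
+#       brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1)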
+
+def augment_pipeline(
+ images, # Input images: NCHW, float32, dynamic range [-1,+1].
+ labels, # Input labels.
+ strength = 1, # Overall multiplier for augmentation probability; can be a Tensor.
+ debug_percentile = None, # Percentile value for visualizing parameter ranges; None = normal operation.
+
+ # Pixel blitting.
+ xflip = 0, # Probability multiplier for x-flip.
+ rotate90 = 0, # Probability multiplier for 90 degree rotations.
+ xint = 0, # Probability multiplier for integer translation.
+ xint_max = 0.125, # Range of integer translation, relative to image dimensions.
+
+ # General geometric transformations.
+ scale = 0, # Probability multiplier for isotropic scaling.
+ rotate = 0, # Probability multiplier for arbitrary rotation.
+ aniso = 0, # Probability multiplier for anisotropic scaling.
+ xfrac = 0, # Probability multiplier for fractional translation.
+ scale_std = 0.2, # Log2 standard deviation of isotropic scaling.
+ rotate_max = 1, # Range of arbitrary rotation, 1 = full circle.
+ aniso_std = 0.2, # Log2 standard deviation of anisotropic scaling.
+ xfrac_std = 0.125, # Standard deviation of fractional translation, relative to image dimensions.
+
+ # Color transformations.
+ brightness = 0, # Probability multiplier for brightness.
+ contrast = 0, # Probability multiplier for contrast.
+ lumaflip = 0, # Probability multiplier for luma flip.
+ hue = 0, # Probability multiplier for hue rotation.
+ saturation = 0, # Probability multiplier for saturation.
+ brightness_std = 0.2, # Standard deviation of brightness.
+ contrast_std = 0.5, # Log2 standard deviation of contrast.
+ hue_max = 1, # Range of hue rotation, 1 = full circle.
+ saturation_std = 1, # Log2 standard deviation of saturation.
+
+ # Image-space filtering.
+ imgfilter = 0, # Probability multiplier for image-space filtering.
+ imgfilter_bands = [1,1,1,1], # Probability multipliers for individual frequency bands.
+ imgfilter_std = 1, # Log2 standard deviation of image-space filter amplification.
+
+ # Image-space corruptions.
+ noise = 0, # Probability multiplier for additive RGB noise.
+ cutout = 0, # Probability multiplier for cutout.
+ noise_std = 0.1, # Standard deviation of additive RGB noise.
+ cutout_size = 0.5, # Size of the cutout rectangle, relative to image dimensions.
+):
+ # Determine input shape.
+ batch, channels, height, width = images.shape.as_list()
+ if batch is None:
+ batch = tf.shape(images)[0]
+
+ # -------------------------------------
+ # Select parameters for pixel blitting.
+ # -------------------------------------
+
+ # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in
+ I_3 = tf.eye(3, batch_shape=[batch])
+ G_inv = I_3
+
+ # Apply x-flip with probability (xflip * strength).
+ if xflip > 0:
+ i = tf.floor(tf.random_uniform([batch], 0, 2))
+ i = gate_augment_params(xflip * strength, i, 0)
+ if debug_percentile is not None:
+ i = tf.floor(tf.broadcast_to(debug_percentile, [batch]) * 2)
+ G_inv @= scale_2d_inv(1 - 2 * i, 1)
+
+ # Apply 90 degree rotations with probability (rotate90 * strength).
+ if rotate90 > 0:
+ i = tf.floor(tf.random_uniform([batch], 0, 4))
+ i = gate_augment_params(rotate90 * strength, i, 0)
+ if debug_percentile is not None:
+ i = tf.floor(tf.broadcast_to(debug_percentile, [batch]) * 4)
+ G_inv @= rotate_2d_inv(-np.pi / 2 * i)
+
+ # Apply integer translation with probability (xint * strength).
+ if xint > 0:
+ t = tf.random_uniform([batch, 2], -xint_max, xint_max)
+ t = gate_augment_params(xint * strength, t, 0)
+ if debug_percentile is not None:
+ t = (tf.broadcast_to(debug_percentile, [batch, 2]) * 2 - 1) * xint_max
+ G_inv @= translate_2d_inv(tf.rint(t[:,0] * width), tf.rint(t[:,1] * height))
+
+ # --------------------------------------------------------
+ # Select parameters for general geometric transformations.
+ # --------------------------------------------------------
+
+ # Apply isotropic scaling with probability (scale * strength).
+ if scale > 0:
+ s = 2 ** tf.random_normal([batch], 0, scale_std)
+ s = gate_augment_params(scale * strength, s, 1)
+ if debug_percentile is not None:
+ s = 2 ** (tflib.erfinv(tf.broadcast_to(debug_percentile, [batch]) * 2 - 1) * scale_std)
+ G_inv @= scale_2d_inv(s, s)
+
+ # Apply pre-rotation with probability p_rot.
+ p_rot = 1 - tf.sqrt(tf.cast(tf.maximum(1 - rotate * strength, 0), tf.float32)) # P(pre OR post) = p
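+ # Rotation is split into a pre- and a post-rotation around the anisotropic scaling.
+ # Gating each with p_rot such that 1 - (1 - p_rot)^2 = rotate * strength makes the
+ # probability of applying at least one of them equal to the requested probability.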
+ if rotate > 0:
+ theta = tf.random_uniform([batch], -np.pi * rotate_max, np.pi * rotate_max)
+ theta = gate_augment_params(p_rot, theta, 0)
+ if debug_percentile is not None:
+ theta = (tf.broadcast_to(debug_percentile, [batch]) * 2 - 1) * np.pi * rotate_max
+ G_inv @= rotate_2d_inv(-theta) # Before anisotropic scaling.
+
+ # Apply anisotropic scaling with probability (aniso * strength).
+ if aniso > 0:
+ s = 2 ** tf.random_normal([batch], 0, aniso_std)
+ s = gate_augment_params(aniso * strength, s, 1)
+ if debug_percentile is not None:
+ s = 2 ** (tflib.erfinv(tf.broadcast_to(debug_percentile, [batch]) * 2 - 1) * aniso_std)
+ G_inv @= scale_2d_inv(s, 1 / s)
+
+ # Apply post-rotation with probability p_rot.
+ if rotate > 0:
+ theta = tf.random_uniform([batch], -np.pi * rotate_max, np.pi * rotate_max)
+ theta = gate_augment_params(p_rot, theta, 0)
+ if debug_percentile is not None:
+ theta = tf.zeros([batch])
+ G_inv @= rotate_2d_inv(-theta) # After anisotropic scaling.
+
+ # Apply fractional translation with probability (xfrac * strength).
+ if xfrac > 0:
+ t = tf.random_normal([batch, 2], 0, xfrac_std)
+ t = gate_augment_params(xfrac * strength, t, 0)
+ if debug_percentile is not None:
+ t = tflib.erfinv(tf.broadcast_to(debug_percentile, [batch, 2]) * 2 - 1) * xfrac_std
+ G_inv @= translate_2d_inv(t[:,0] * width, t[:,1] * height)
+
+ # ----------------------------------
+ # Execute geometric transformations.
+ # ----------------------------------
+
+ # Execute if the transform is not identity.
+ if G_inv is not I_3:
+
+ # Setup orthogonal lowpass filter.
+ Hz = wavelets['sym6']
+ Hz = np.asarray(Hz, dtype=np.float32)
+ Hz = np.reshape(Hz, [-1, 1, 1]).repeat(channels, axis=1) # [tap, channel, 1]
+ Hz_pad = Hz.shape[0] // 4
+
+ # Calculate padding.
+ cx = (width - 1) / 2
+ cy = (height - 1) / 2
+ cp = np.transpose([[-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1]]) # [xyz, idx]
+ cp = G_inv @ cp[np.newaxis] # [batch, xyz, idx]
+ cp = cp[:, :2, :] # [batch, xy, idx]
+ m_lo = tf.ceil(tf.reduce_max(-cp, axis=[0,2]) - [cx, cy] + Hz_pad * 2)
+ m_hi = tf.ceil(tf.reduce_max( cp, axis=[0,2]) - [cx, cy] + Hz_pad * 2)
+ m_lo = tf.clip_by_value(m_lo, [0, 0], [width-1, height-1])
+ m_hi = tf.clip_by_value(m_hi, [0, 0], [width-1, height-1])
+
+ # Pad image and adjust origin.
+ images = tf.transpose(images, [0, 2, 3, 1]) # NCHW => NHWC
+ pad = [[0, 0], [m_lo[1], m_hi[1]], [m_lo[0], m_hi[0]], [0, 0]]
+ images = tf.pad(tensor=images, paddings=pad, mode='REFLECT')
+ T_in = translate_2d(cx + m_lo[0], cy + m_lo[1])
+ T_out = translate_2d_inv(cx + Hz_pad, cy + Hz_pad)
+ G_inv = T_in @ G_inv @ T_out
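+ # G_inv now maps output pixel indices to indices of the reflection-padded input:
+ # T_out shifts output indices into centered coordinates, the accumulated transform
+ # is applied, and T_in shifts the result into the padded input's index space.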
+
+ # Upsample.
+ shape = [batch, tf.shape(images)[1] * 2, tf.shape(images)[2] * 2, channels]
+ images = tf.nn.depthwise_conv2d_backprop_input(input_sizes=shape, filter=Hz[np.newaxis, :], out_backprop=images, strides=[1,2,2,1], padding='SAME', data_format='NHWC')
+ images = tf.nn.depthwise_conv2d_backprop_input(input_sizes=shape, filter=Hz[:, np.newaxis], out_backprop=images, strides=[1,1,1,1], padding='SAME', data_format='NHWC')
+ G_inv = scale_2d(2, 2) @ G_inv @ scale_2d_inv(2, 2) # Account for the increased resolution.
+
+ # Execute transformation.
+ transforms = tf.reshape(G_inv, [-1, 9])[:, :8]
+ shape = [(height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2]
+ images = tf.contrib.image.transform(images=images, transforms=transforms, output_shape=shape, interpolation='BILINEAR')
+
+ # Downsample and crop.
+ images = tf.nn.depthwise_conv2d(input=images, filter=Hz[np.newaxis,:], strides=[1,1,1,1], padding='SAME', data_format='NHWC')
+ images = tf.nn.depthwise_conv2d(input=images, filter=Hz[:,np.newaxis], strides=[1,2,2,1], padding='SAME', data_format='NHWC')
+ images = images[:, Hz_pad : height + Hz_pad, Hz_pad : width + Hz_pad, :]
+ images = tf.transpose(images, [0, 3, 1, 2]) # NHWC => NCHW
+
+ # --------------------------------------------
+ # Select parameters for color transformations.
+ # --------------------------------------------
+
+ # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out
+ I_4 = tf.eye(4, batch_shape=[batch])
+ C = I_4
+
+ # Apply brightness with probability (brightness * strength).
+ if brightness > 0:
+ b = tf.random_normal([batch], 0, brightness_std)
+ b = gate_augment_params(brightness * strength, b, 0)
+ if debug_percentile is not None:
+ b = tflib.erfinv(tf.broadcast_to(debug_percentile, [batch]) * 2 - 1) * brightness_std
+ C = translate_3d(b, b, b) @ C
+
+ # Apply contrast with probability (contrast * strength).
+ if contrast > 0:
+ c = 2 ** tf.random_normal([batch], 0, contrast_std)
+ c = gate_augment_params(contrast * strength, c, 1)
+ if debug_percentile is not None:
+ c = 2 ** (tflib.erfinv(tf.broadcast_to(debug_percentile, [batch]) * 2 - 1) * contrast_std)
+ C = scale_3d(c, c, c) @ C
+
+ # Apply luma flip with probability (lumaflip * strength).
+ v = np.array([1, 1, 1, 0]) / np.sqrt(3) # Luma axis.
+ if lumaflip > 0:
+ i = tf.floor(tf.random_uniform([batch], 0, 2))
+ i = gate_augment_params(lumaflip * strength, i, 0)
+ if debug_percentile is not None:
+ i = tf.floor(tf.broadcast_to(debug_percentile, [batch]) * 2)
+ i = tf.reshape(i, [batch, 1, 1])
+ C = (I_4 - 2 * np.outer(v, v) * i) @ C # Householder reflection.
+
+ # Apply hue rotation with probability (hue * strength).
+ if hue > 0 and channels > 1:
+ theta = tf.random_uniform([batch], -np.pi * hue_max, np.pi * hue_max)
+ theta = gate_augment_params(hue * strength, theta, 0)
+ if debug_percentile is not None:
+ theta = (tf.broadcast_to(debug_percentile, [batch]) * 2 - 1) * np.pi * hue_max
+ C = rotate_3d(v, theta) @ C # Rotate around v.
+
+ # Apply saturation with probability (saturation * strength).
+ if saturation > 0 and channels > 1:
+ s = 2 ** tf.random_normal([batch], 0, saturation_std)
+ s = gate_augment_params(saturation * strength, s, 1)
+ if debug_percentile is not None:
+ s = 2 ** (tflib.erfinv(tf.broadcast_to(debug_percentile, [batch]) * 2 - 1) * saturation_std)
+ s = tf.reshape(s, [batch, 1, 1])
+ C = (np.outer(v, v) + (I_4 - np.outer(v, v)) * s) @ C
+
+ # ------------------------------
+ # Execute color transformations.
+ # ------------------------------
+
+ # Execute if the transform is not identity.
+ if C is not I_4:
+ images = tf.reshape(images, [batch, channels, height * width])
+ if channels == 3:
+ images = C[:, :3, :3] @ images + C[:, :3, 3:]
+ elif channels == 1:
+ C = tf.reduce_mean(C[:, :3, :], axis=1, keepdims=True)
+ images = images * tf.reduce_sum(C[:, :, :3], axis=2, keepdims=True) + C[:, :, 3:]
+ else:
+ raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
+ images = tf.reshape(images, [batch, channels, height, width])
+
+ # ----------------------
+ # Image-space filtering.
+ # ----------------------
+
+ if imgfilter > 0:
+ num_bands = 4
+ assert len(imgfilter_bands) == num_bands
+ expected_power = np.array([10, 1, 1, 1]) / 13 # Expected power spectrum (1/f).
+
+ # Apply amplification for each band with probability (imgfilter * strength * band_strength).
+ g = tf.ones([batch, num_bands]) # Global gain vector (identity).
+ for i, band_strength in enumerate(imgfilter_bands):
+ t_i = 2 ** tf.random_normal([batch], 0, imgfilter_std)
+ t_i = gate_augment_params(imgfilter * strength * band_strength, t_i, 1)
+ if debug_percentile is not None:
+ t_i = 2 ** (tflib.erfinv(tf.broadcast_to(debug_percentile, [batch]) * 2 - 1) * imgfilter_std) if band_strength > 0 else tf.ones([batch])
+ t = tf.ones([batch, num_bands]) # Temporary gain vector.
+ t = tf.concat([t[:, :i], t_i[:, np.newaxis], t[:, i+1:]], axis=-1) # Replace i'th element.
+ t /= tf.sqrt(tf.reduce_sum(expected_power * tf.square(t), axis=-1, keepdims=True)) # Normalize power.
+ g *= t # Accumulate into global gain.
+
+ # Construct filter bank.
+ Hz_lo = wavelets['sym2']
+ Hz_lo = np.asarray(Hz_lo, dtype=np.float32) # H(z)
+ Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z)
+ Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2 # H(z) * H(z^-1) / 2
+ Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2 # H(-z) * H(-z^-1) / 2
+ Hz_bands = np.eye(num_bands, 1) # Bandpass(H(z), b_i)
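+ # The loop below builds one filter per frequency band: row 0 ends up as the
+ # lowest band and row num_bands-1 as the highest; the per-image gains in g then
+ # control how much each band is amplified.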
+ for i in range(1, num_bands):
+ Hz_bands = np.dstack([Hz_bands, np.zeros_like(Hz_bands)]).reshape(num_bands, -1)[:, :-1]
+ Hz_bands = scipy.signal.convolve(Hz_bands, [Hz_lo2])
+ Hz_bands[i, (Hz_bands.shape[1] - Hz_hi2.size) // 2 : (Hz_bands.shape[1] + Hz_hi2.size) // 2] += Hz_hi2
+
+ # Construct combined amplification filter.
+ Hz_prime = g @ Hz_bands # [batch, tap]
+ Hz_prime = tf.transpose(Hz_prime) # [tap, batch]
+ Hz_prime = tf.tile(Hz_prime[:, :, np.newaxis], [1, 1, channels]) # [tap, batch, channels]
+ Hz_prime = tf.reshape(Hz_prime, [-1, batch * channels, 1]) # [tap, batch * channels, 1]
+
+ # Apply filter.
+ images = tf.reshape(images, [1, -1, height, width])
+ pad = Hz_bands.shape[1] // 2
+ pad = [[0,0], [0,0], [pad, pad], [pad, pad]]
+ images = tf.pad(tensor=images, paddings=pad, mode='REFLECT')
+ images = tf.nn.depthwise_conv2d(input=images, filter=Hz_prime[np.newaxis,:], strides=[1,1,1,1], padding='VALID', data_format='NCHW')
+ images = tf.nn.depthwise_conv2d(input=images, filter=Hz_prime[:,np.newaxis], strides=[1,1,1,1], padding='VALID', data_format='NCHW')
+ images = tf.reshape(images, [-1, channels, height, width])
+
+ # ------------------------
+ # Image-space corruptions.
+ # ------------------------
+
+ # Apply additive RGB noise with probability (noise * strength).
+ if noise > 0:
+ sigma = tf.abs(tf.random_normal([batch], 0, noise_std))
+ sigma = gate_augment_params(noise * strength, sigma, 0)
+ if debug_percentile is not None:
+ sigma = tflib.erfinv(tf.broadcast_to(debug_percentile, [batch])) * noise_std
+ sigma = tf.reshape(sigma, [-1, 1, 1, 1])
+ images += tf.random_normal([batch, channels, height, width]) * sigma
+
+ # Apply cutout with probability (cutout * strength).
+ if cutout > 0:
+ size = tf.fill([batch, 2], cutout_size)
+ size = gate_augment_params(cutout * strength, size, 0)
+ center = tf.random_uniform([batch, 2], 0, 1)
+ if debug_percentile is not None:
+ size = tf.fill([batch, 2], cutout_size)
+ center = tf.broadcast_to(debug_percentile, [batch, 2])
+ size = tf.reshape(size, [batch, 2, 1, 1, 1])
+ center = tf.reshape(center, [batch, 2, 1, 1, 1])
+ coord_x = tf.reshape(tf.range(width, dtype=tf.float32), [1, 1, 1, width])
+ coord_y = tf.reshape(tf.range(height, dtype=tf.float32), [1, 1, height, 1])
+ mask_x = (tf.abs((coord_x + 0.5) / width - center[:, 0]) >= size[:, 0] / 2)
+ mask_y = (tf.abs((coord_y + 0.5) / height - center[:, 1]) >= size[:, 1] / 2)
+ mask = tf.cast(tf.logical_or(mask_x, mask_y), tf.float32)
+ images *= mask
+
+ return images, labels
+
+#----------------------------------------------------------------------------
diff --git a/training/dataset.py b/training/dataset.py
new file mode 100755
index 00000000..b96876ed
--- /dev/null
+++ b/training/dataset.py
@@ -0,0 +1,233 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Streaming images and labels from dataset created with dataset_tool.py."""
+
+import os
+import glob
+import numpy as np
+import tensorflow as tf
+import dnnlib.tflib as tflib
+
+#----------------------------------------------------------------------------
+# Dataset class that loads images from tfrecords files.
+
+class TFRecordDataset:
+ def __init__(self,
+ tfrecord_dir, # Directory containing a collection of tfrecords files.
+ resolution = None, # Dataset resolution, None = autodetect.
+ label_file = None, # Relative path of the labels file, None = autodetect.
+ max_label_size = 0, # 0 = no labels, 'full' = full labels, N = first N label components.
+ max_images = None, # Maximum number of images to use, None = use all images.
+ max_validation = 10000, # Maximum size of the validation set, None = use all available images.
+ mirror_augment = False, # Apply mirror augment?
+ repeat = True, # Repeat dataset indefinitely?
+ shuffle = True, # Shuffle images?
+ shuffle_mb = 4096, # Shuffle data within specified window (megabytes), 0 = disable shuffling.
+ prefetch_mb = 2048, # Amount of data to prefetch (megabytes), 0 = disable prefetching.
+ buffer_mb = 256, # Read buffer size (megabytes).
+ num_threads = 2, # Number of concurrent threads.
+ _is_validation = False,
+):
+ self.tfrecord_dir = tfrecord_dir
+ self.resolution = None
+ self.resolution_log2 = None
+ self.shape = [] # [channels, height, width]
+ self.dtype = 'uint8'
+ self.label_file = label_file
+ self.label_size = None # components
+ self.label_dtype = None
+ self.has_validation_set = None
+ self.mirror_augment = mirror_augment
+ self.repeat = repeat
+ self.shuffle = shuffle
+ self._max_validation = max_validation
+ self._np_labels = None
+ self._tf_minibatch_in = None
+ self._tf_labels_var = None
+ self._tf_labels_dataset = None
+ self._tf_datasets = dict()
+ self._tf_iterator = None
+ self._tf_init_ops = dict()
+ self._tf_minibatch_np = None
+ self._cur_minibatch = -1
+ self._cur_lod = -1
+
+ # List files in the dataset directory.
+ assert os.path.isdir(self.tfrecord_dir)
+ all_files = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*')))
+ self.has_validation_set = (self._max_validation > 0) and any(os.path.basename(f).startswith('validation-') for f in all_files)
+ all_files = [f for f in all_files if os.path.basename(f).startswith('validation-') == _is_validation]
+
+ # Inspect tfrecords files.
+ tfr_files = [f for f in all_files if f.endswith('.tfrecords')]
+ assert len(tfr_files) >= 1
+ tfr_shapes = []
+ for tfr_file in tfr_files:
+ tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)
+ for record in tf.python_io.tf_record_iterator(tfr_file, tfr_opt):
+ tfr_shapes.append(self.parse_tfrecord_np(record).shape)
+ break
+
+ # Autodetect label filename.
+ if self.label_file is None:
+ guess = [f for f in all_files if f.endswith('.labels')]
+ if len(guess):
+ self.label_file = guess[0]
+ elif not os.path.isfile(self.label_file):
+ guess = os.path.join(self.tfrecord_dir, self.label_file)
+ if os.path.isfile(guess):
+ self.label_file = guess
+
+ # Determine shape and resolution.
+ max_shape = max(tfr_shapes, key=np.prod)
+ self.resolution = resolution if resolution is not None else max_shape[1]
+ self.resolution_log2 = int(np.log2(self.resolution))
+ self.shape = [max_shape[0], self.resolution, self.resolution]
+ tfr_lods = [self.resolution_log2 - int(np.log2(shape[1])) for shape in tfr_shapes]
+ assert all(shape[0] == max_shape[0] for shape in tfr_shapes)
+ assert all(shape[1] == shape[2] for shape in tfr_shapes)
+ assert all(shape[1] == self.resolution // (2**lod) for shape, lod in zip(tfr_shapes, tfr_lods))
+ assert all(lod in tfr_lods for lod in range(self.resolution_log2 - 1))
+
+ # Load labels.
+ assert max_label_size == 'full' or max_label_size >= 0
+ self._np_labels = np.zeros([1<<30, 0], dtype=np.float32)
+ if self.label_file is not None and max_label_size != 0:
+ self._np_labels = np.load(self.label_file)
+ assert self._np_labels.ndim == 2
+ if max_label_size != 'full' and self._np_labels.shape[1] > max_label_size:
+ self._np_labels = self._np_labels[:, :max_label_size]
+ if max_images is not None and self._np_labels.shape[0] > max_images:
+ self._np_labels = self._np_labels[:max_images]
+ self.label_size = self._np_labels.shape[1]
+ self.label_dtype = self._np_labels.dtype.name
+
+ # Build TF expressions.
+ with tf.name_scope('Dataset'), tf.device('/cpu:0'), tf.control_dependencies(None):
+ self._tf_minibatch_in = tf.placeholder(tf.int64, name='minibatch_in', shape=[])
+ self._tf_labels_var = tflib.create_var_with_large_initial_value(self._np_labels, name='labels_var')
+ self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(self._tf_labels_var)
+ for tfr_file, tfr_shape, tfr_lod in zip(tfr_files, tfr_shapes, tfr_lods):
+ if tfr_lod < 0:
+ continue
+ dset = tf.data.TFRecordDataset(tfr_file, compression_type='', buffer_size=buffer_mb<<20)
+ if max_images is not None:
+ dset = dset.take(max_images)
+ dset = dset.map(self.parse_tfrecord_tf, num_parallel_calls=num_threads)
+ dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset))
+ bytes_per_item = np.prod(tfr_shape) * np.dtype(self.dtype).itemsize
+ if self.shuffle and shuffle_mb > 0:
+ dset = dset.shuffle(((shuffle_mb << 20) - 1) // bytes_per_item + 1)
+ if self.repeat:
+ dset = dset.repeat()
+ if prefetch_mb > 0:
+ dset = dset.prefetch(((prefetch_mb << 20) - 1) // bytes_per_item + 1)
+ dset = dset.batch(self._tf_minibatch_in)
+ self._tf_datasets[tfr_lod] = dset
+ self._tf_iterator = tf.data.Iterator.from_structure(self._tf_datasets[0].output_types, self._tf_datasets[0].output_shapes)
+ self._tf_init_ops = {lod: self._tf_iterator.make_initializer(dset) for lod, dset in self._tf_datasets.items()}
+
+ def close(self):
+ pass
+
+ # Use the given minibatch size and level-of-detail for the data returned by get_minibatch_tf().
+ def configure(self, minibatch_size, lod=0):
+ lod = int(np.floor(lod))
+ assert minibatch_size >= 1 and lod in self._tf_datasets
+ if self._cur_minibatch != minibatch_size or self._cur_lod != lod:
+ self._tf_init_ops[lod].run({self._tf_minibatch_in: minibatch_size})
+ self._cur_minibatch = minibatch_size
+ self._cur_lod = lod
+
+ # Get next minibatch as TensorFlow expressions.
+ def get_minibatch_tf(self):
+ images, labels = self._tf_iterator.get_next()
+ if self.mirror_augment:
+ images = tf.cast(images, tf.float32)
+ images = tf.where(tf.random_uniform([tf.shape(images)[0]]) < 0.5, images, tf.reverse(images, [3]))
+ images = tf.cast(images, self.dtype)
+ return images, labels
+
+ # Get next minibatch as NumPy arrays.
+ def get_minibatch_np(self, minibatch_size, lod=0): # => (images, labels) or (None, None)
+ self.configure(minibatch_size, lod)
+ if self._tf_minibatch_np is None:
+ with tf.name_scope('Dataset'):
+ self._tf_minibatch_np = self.get_minibatch_tf()
+ try:
+ return tflib.run(self._tf_minibatch_np)
+ except tf.errors.OutOfRangeError:
+ return None, None
+
+ # Get random labels as TensorFlow expression.
+ def get_random_labels_tf(self, minibatch_size): # => labels
+ with tf.name_scope('Dataset'):
+ if self.label_size > 0:
+ with tf.device('/cpu:0'):
+ return tf.gather(self._tf_labels_var, tf.random_uniform([minibatch_size], 0, self._np_labels.shape[0], dtype=tf.int32))
+ return tf.zeros([minibatch_size, 0], self.label_dtype)
+
+ # Get random labels as NumPy array.
+ def get_random_labels_np(self, minibatch_size): # => labels
+ if self.label_size > 0:
+ return self._np_labels[np.random.randint(self._np_labels.shape[0], size=[minibatch_size])]
+ return np.zeros([minibatch_size, 0], self.label_dtype)
+
+ # Load validation set as NumPy array.
+ def load_validation_set_np(self):
+ images = []
+ labels = []
+ if self.has_validation_set:
+ validation_set = TFRecordDataset(
+ tfrecord_dir=self.tfrecord_dir, resolution=self.shape[2], max_label_size=self.label_size,
+ max_images=self._max_validation, repeat=False, shuffle=False, prefetch_mb=0, _is_validation=True)
+ validation_set.configure(1)
+ while True:
+ image, label = validation_set.get_minibatch_np(1)
+ if image is None:
+ break
+ images.append(image)
+ labels.append(label)
+ images = np.concatenate(images, axis=0) if len(images) else np.zeros([0] + self.shape, dtype=self.dtype)
+ labels = np.concatenate(labels, axis=0) if len(labels) else np.zeros([0, self.label_size], self.label_dtype)
+ assert list(images.shape[1:]) == self.shape
+ assert labels.shape[1] == self.label_size
+ assert images.shape[0] <= self._max_validation
+ return images, labels
+
+ # Parse individual image from a tfrecords file into TensorFlow expression.
+ @staticmethod
+ def parse_tfrecord_tf(record):
+ features = tf.parse_single_example(record, features={
+ 'shape': tf.FixedLenFeature([3], tf.int64),
+ 'data': tf.FixedLenFeature([], tf.string)})
+ data = tf.decode_raw(features['data'], tf.uint8)
+ return tf.reshape(data, features['shape'])
+
+ # Parse individual image from a tfrecords file into NumPy array.
+ @staticmethod
+ def parse_tfrecord_np(record):
+ ex = tf.train.Example()
+ ex.ParseFromString(record)
+ shape = ex.features.feature['shape'].int64_list.value # pylint: disable=no-member
+ data = ex.features.feature['data'].bytes_list.value[0] # pylint: disable=no-member
+ return np.frombuffer(data, np.uint8).reshape(shape) # frombuffer: np.fromstring is deprecated.
+
+#----------------------------------------------------------------------------
+# Construct a dataset object using the given options.
+
+def load_dataset(path=None, resolution=None, max_images=None, max_label_size=0, mirror_augment=False, repeat=True, shuffle=True, seed=None):
+ _ = seed
+ assert os.path.isdir(path)
+ return TFRecordDataset(
+ tfrecord_dir=path,
+ resolution=resolution, max_images=max_images, max_label_size=max_label_size,
+ mirror_augment=mirror_augment, repeat=repeat, shuffle=shuffle)
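+
+# Illustrative usage (hypothetical dataset path; minibatch size chosen arbitrarily):
+#   dataset = load_dataset(path='datasets/my_tfrecords', max_label_size='full')
+#   images, labels = dataset.get_minibatch_np(8) # uint8 NCHW images, float32 labels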
+
+#----------------------------------------------------------------------------
diff --git a/training/loss.py b/training/loss.py
new file mode 100755
index 00000000..9d819d29
--- /dev/null
+++ b/training/loss.py
@@ -0,0 +1,307 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Loss functions."""
+
+import numpy as np
+import tensorflow as tf
+import dnnlib
+import dnnlib.tflib as tflib
+from dnnlib.tflib.autosummary import autosummary
+
+#----------------------------------------------------------------------------
+# Report statistic for all interested parties (AdaptiveAugment and tfevents).
+
+def report_stat(aug, name, value):
+ if aug is not None:
+ value = aug.report_stat(name, value)
+ value = autosummary(name, value)
+ return value
+
+#----------------------------------------------------------------------------
+# Report loss terms and collect them into EasyDict.
+
+def report_loss(aug, G_loss, D_loss, G_reg=None, D_reg=None):
+ assert G_loss is not None and D_loss is not None
+ terms = dnnlib.EasyDict(G_reg=None, D_reg=None)
+ terms.G_loss = report_stat(aug, 'Loss/G/loss', G_loss)
+ terms.D_loss = report_stat(aug, 'Loss/D/loss', D_loss)
+ if G_reg is not None: terms.G_reg = report_stat(aug, 'Loss/G/reg', G_reg)
+ if D_reg is not None: terms.D_reg = report_stat(aug, 'Loss/D/reg', D_reg)
+ return terms
+
+#----------------------------------------------------------------------------
+# Evaluate G and return results as EasyDict.
+
+def eval_G(G, latents, labels, return_dlatents=False):
+ r = dnnlib.EasyDict()
+ r.args = dnnlib.EasyDict()
+ r.args.is_training = True
+ if return_dlatents:
+ r.args.return_dlatents = True
+ r.images = G.get_output_for(latents, labels, **r.args)
+
+ r.dlatents = None
+ if return_dlatents:
+ r.images, r.dlatents = r.images
+ return r
+
+#----------------------------------------------------------------------------
+# Evaluate D and return results as EasyDict.
+
+def eval_D(D, aug, images, labels, report=None, augment_inputs=True, return_aux=0):
+ r = dnnlib.EasyDict()
+ r.images_aug = images
+ r.labels_aug = labels
+ if augment_inputs and aug is not None:
+ r.images_aug, r.labels_aug = aug.apply(r.images_aug, r.labels_aug)
+
+ r.args = dnnlib.EasyDict()
+ r.args.is_training = True
+ if aug is not None:
+ r.args.augment_strength = aug.get_strength_var()
+ if return_aux > 0:
+ r.args.score_size = return_aux + 1
+ r.scores = D.get_output_for(r.images_aug, r.labels_aug, **r.args)
+
+ r.aux = None
+ if return_aux:
+ r.aux = r.scores[:, 1:]
+ r.scores = r.scores[:, :1]
+
+ if report is not None:
+ report_ops = [
+ report_stat(aug, 'Loss/scores/' + report, r.scores),
+ report_stat(aug, 'Loss/signs/' + report, tf.sign(r.scores)),
+ report_stat(aug, 'Loss/squares/' + report, tf.square(r.scores)),
+ ]
+ with tf.control_dependencies(report_ops):
+ r.scores = tf.identity(r.scores)
+ return r
+
+#----------------------------------------------------------------------------
+# Non-saturating logistic loss with R1 and path length regularizers, used
+# in the paper "Analyzing and Improving the Image Quality of StyleGAN".
+
+def stylegan2(G, D, aug, fake_labels, real_images, real_labels, r1_gamma=10, pl_minibatch_shrink=2, pl_decay=0.01, pl_weight=2, **_kwargs):
+ # Evaluate networks for the main loss.
+ minibatch_size = tf.shape(fake_labels)[0]
+ fake_latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
+ G_fake = eval_G(G, fake_latents, fake_labels, return_dlatents=True)
+ D_fake = eval_D(D, aug, G_fake.images, fake_labels, report='fake')
+ D_real = eval_D(D, aug, real_images, real_labels, report='real')
+
+ # Non-saturating logistic loss from "Generative Adversarial Nets".
+ with tf.name_scope('Loss_main'):
+ G_loss = tf.nn.softplus(-D_fake.scores) # -log(sigmoid(D_fake.scores)), pylint: disable=invalid-unary-operand-type
+ D_loss = tf.nn.softplus(D_fake.scores) # -log(1 - sigmoid(D_fake.scores))
+ D_loss += tf.nn.softplus(-D_real.scores) # -log(sigmoid(D_real.scores)), pylint: disable=invalid-unary-operand-type
+ G_reg = 0
+ D_reg = 0
+
+ # R1 regularizer from "Which Training Methods for GANs do actually Converge?".
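+ # Penalizes the discriminator gradient on real data: (r1_gamma / 2) * E[ ||grad_x D(x)||^2 ].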
+ if r1_gamma != 0:
+ with tf.name_scope('Loss_R1'):
+ r1_grads = tf.gradients(tf.reduce_sum(D_real.scores), [real_images])[0]
+ r1_penalty = tf.reduce_sum(tf.square(r1_grads), axis=[1,2,3])
+ r1_penalty = report_stat(aug, 'Loss/r1_penalty', r1_penalty)
+ D_reg += r1_penalty * (r1_gamma * 0.5)
+
+ # Path length regularizer from "Analyzing and Improving the Image Quality of StyleGAN".
+ if pl_weight != 0:
+ with tf.name_scope('Loss_PL'):
+
+ # Evaluate the regularization term using a smaller minibatch to conserve memory.
+ G_pl = G_fake
+ if pl_minibatch_shrink > 1:
+ pl_minibatch_size = minibatch_size // pl_minibatch_shrink
+ pl_latents = fake_latents[:pl_minibatch_size]
+ pl_labels = fake_labels[:pl_minibatch_size]
+ G_pl = eval_G(G, pl_latents, pl_labels, return_dlatents=True)
+
+ # Compute |J*y|.
+ pl_noise = tf.random_normal(tf.shape(G_pl.images)) / np.sqrt(np.prod(G.output_shape[2:]))
+ pl_grads = tf.gradients(tf.reduce_sum(G_pl.images * pl_noise), [G_pl.dlatents])[0]
+ pl_lengths = tf.sqrt(tf.reduce_mean(tf.reduce_sum(tf.square(pl_grads), axis=2), axis=1))
+
+ # Track exponential moving average of |J*y|.
+ with tf.control_dependencies(None):
+ pl_mean_var = tf.Variable(name='pl_mean', trainable=False, initial_value=0, dtype=tf.float32)
+ pl_mean = pl_mean_var + pl_decay * (tf.reduce_mean(pl_lengths) - pl_mean_var)
+ pl_update = tf.assign(pl_mean_var, pl_mean)
+
+ # Calculate (|J*y|-a)^2.
+ with tf.control_dependencies([pl_update]):
+ pl_penalty = tf.square(pl_lengths - pl_mean)
+ pl_penalty = report_stat(aug, 'Loss/pl_penalty', pl_penalty)
+
+ # Apply weight.
+ #
+ # Note: The division in pl_noise decreases the weight by num_pixels, and the reduce_mean
+ # in pl_lengths decreases it by num_affine_layers. The effective weight then becomes:
+ #
+ # gamma_pl = pl_weight / num_pixels / num_affine_layers
+ # = 2 / (r^2) / (log2(r) * 2 - 2)
+ # = 1 / (r^2 * (log2(r) - 1))
+ # = ln(2) / (r^2 * (ln(r) - ln(2)))
+ #
+ G_reg += tf.tile(pl_penalty, [pl_minibatch_shrink]) * pl_weight
+
+ return report_loss(aug, G_loss, D_loss, G_reg, D_reg)
+
+#----------------------------------------------------------------------------
+# Hybrid loss used for comparison methods used in the paper
+# "Training Generative Adversarial Networks with Limited Data".
+
+def cmethods(G, D, aug, fake_labels, real_images, real_labels,
+ r1_gamma=10, r2_gamma=0,
+ pl_minibatch_shrink=2, pl_decay=0.01, pl_weight=2,
+ bcr_real_weight=0, bcr_fake_weight=0, bcr_augment=None,
+ zcr_gen_weight=0, zcr_dis_weight=0, zcr_noise_std=0.1,
+ auxrot_alpha=0, auxrot_beta=0,
+ **_kwargs,
+):
+ # Evaluate networks for the main loss.
+ minibatch_size = tf.shape(fake_labels)[0]
+ fake_latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
+ G_fake = eval_G(G, fake_latents, fake_labels)
+ D_fake = eval_D(D, aug, G_fake.images, fake_labels, report='fake')
+ D_real = eval_D(D, aug, real_images, real_labels, report='real')
+
+ # Non-saturating logistic loss from "Generative Adversarial Nets".
+ with tf.name_scope('Loss_main'):
+ G_loss = tf.nn.softplus(-D_fake.scores) # -log(sigmoid(D_fake.scores)), pylint: disable=invalid-unary-operand-type
+ D_loss = tf.nn.softplus(D_fake.scores) # -log(1 - sigmoid(D_fake.scores))
+ D_loss += tf.nn.softplus(-D_real.scores) # -log(sigmoid(D_real.scores)), pylint: disable=invalid-unary-operand-type
+ G_reg = 0
+ D_reg = 0
+
+ # R1 and R2 regularizers from "Which Training Methods for GANs do actually Converge?".
+ if r1_gamma != 0 or r2_gamma != 0:
+ with tf.name_scope('Loss_R1R2'):
+ if r1_gamma != 0:
+ r1_grads = tf.gradients(tf.reduce_sum(D_real.scores), [real_images])[0]
+ r1_penalty = tf.reduce_sum(tf.square(r1_grads), axis=[1,2,3])
+ r1_penalty = report_stat(aug, 'Loss/r1_penalty', r1_penalty)
+ D_reg += r1_penalty * (r1_gamma * 0.5)
+ if r2_gamma != 0:
+ r2_grads = tf.gradients(tf.reduce_sum(D_fake.scores), [G_fake.images])[0]
+ r2_penalty = tf.reduce_sum(tf.square(r2_grads), axis=[1,2,3])
+ r2_penalty = report_stat(aug, 'Loss/r2_penalty', r2_penalty)
+ D_reg += r2_penalty * (r2_gamma * 0.5)
+
+ # Path length regularizer from "Analyzing and Improving the Image Quality of StyleGAN".
+ if pl_weight != 0:
+ with tf.name_scope('Loss_PL'):
+ pl_minibatch_size = minibatch_size // pl_minibatch_shrink
+ pl_latents = fake_latents[:pl_minibatch_size]
+ pl_labels = fake_labels[:pl_minibatch_size]
+ G_pl = eval_G(G, pl_latents, pl_labels, return_dlatents=True)
+ pl_noise = tf.random_normal(tf.shape(G_pl.images)) / np.sqrt(np.prod(G.output_shape[2:]))
+ pl_grads = tf.gradients(tf.reduce_sum(G_pl.images * pl_noise), [G_pl.dlatents])[0]
+ pl_lengths = tf.sqrt(tf.reduce_mean(tf.reduce_sum(tf.square(pl_grads), axis=2), axis=1))
+ with tf.control_dependencies(None):
+ pl_mean_var = tf.Variable(name='pl_mean', trainable=False, initial_value=0, dtype=tf.float32)
+ pl_mean = pl_mean_var + pl_decay * (tf.reduce_mean(pl_lengths) - pl_mean_var)
+ pl_update = tf.assign(pl_mean_var, pl_mean)
+ with tf.control_dependencies([pl_update]):
+ pl_penalty = tf.square(pl_lengths - pl_mean)
+ pl_penalty = report_stat(aug, 'Loss/pl_penalty', pl_penalty)
+ G_reg += tf.tile(pl_penalty, [pl_minibatch_shrink]) * pl_weight
+
+ # bCR regularizer from "Improved consistency regularization for GANs".
+ if (bcr_real_weight != 0 or bcr_fake_weight != 0) and bcr_augment is not None:
+ with tf.name_scope('Loss_bCR'):
+ if bcr_real_weight != 0:
+ bcr_real_images, bcr_real_labels = dnnlib.util.call_func_by_name(D_real.images_aug, D_real.labels_aug, **bcr_augment)
+ D_bcr_real = eval_D(D, aug, bcr_real_images, bcr_real_labels, report='real_bcr', augment_inputs=False)
+ bcr_real_penalty = tf.square(D_bcr_real.scores - D_real.scores)
+ bcr_real_penalty = report_stat(aug, 'Loss/bcr_penalty/real', bcr_real_penalty)
+ D_loss += bcr_real_penalty * bcr_real_weight # NOTE: Must not use lazy regularization for this term.
+ if bcr_fake_weight != 0:
+ bcr_fake_images, bcr_fake_labels = dnnlib.util.call_func_by_name(D_fake.images_aug, D_fake.labels_aug, **bcr_augment)
+ D_bcr_fake = eval_D(D, aug, bcr_fake_images, bcr_fake_labels, report='fake_bcr', augment_inputs=False)
+ bcr_fake_penalty = tf.square(D_bcr_fake.scores - D_fake.scores)
+ bcr_fake_penalty = report_stat(aug, 'Loss/bcr_penalty/fake', bcr_fake_penalty)
+ D_loss += bcr_fake_penalty * bcr_fake_weight # NOTE: Must not use lazy regularization for this term.
+
+ # zCR regularizer from "Improved consistency regularization for GANs".
+ if zcr_gen_weight != 0 or zcr_dis_weight != 0:
+ with tf.name_scope('Loss_zCR'):
+ zcr_fake_latents = fake_latents + tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) * zcr_noise_std
+ G_zcr = eval_G(G, zcr_fake_latents, fake_labels)
+ if zcr_gen_weight > 0:
+ zcr_gen_penalty = -tf.reduce_mean(tf.square(G_fake.images - G_zcr.images), axis=[1,2,3])
+ zcr_gen_penalty = report_stat(aug, 'Loss/zcr_gen_penalty', zcr_gen_penalty)
+ G_loss += zcr_gen_penalty * zcr_gen_weight
+ if zcr_dis_weight > 0:
+ D_zcr = eval_D(D, aug, G_zcr.images, fake_labels, report='fake_zcr', augment_inputs=False)
+ zcr_dis_penalty = tf.square(D_fake.scores - D_zcr.scores)
+ zcr_dis_penalty = report_stat(aug, 'Loss/zcr_dis_penalty', zcr_dis_penalty)
+ D_loss += zcr_dis_penalty * zcr_dis_weight
+
+ # Auxiliary rotation loss from "Self-supervised GANs via auxiliary rotation loss".
+ if auxrot_alpha != 0 or auxrot_beta != 0:
+ with tf.name_scope('Loss_AuxRot'):
+ idx = tf.range(minibatch_size * 4, dtype=tf.int32) // minibatch_size
+ b0 = tf.logical_or(tf.equal(idx, 0), tf.equal(idx, 1))
+ b1 = tf.logical_or(tf.equal(idx, 0), tf.equal(idx, 3))
+ b2 = tf.logical_or(tf.equal(idx, 0), tf.equal(idx, 2))
+ if auxrot_alpha != 0:
+ auxrot_fake = tf.tile(G_fake.images, [4, 1, 1, 1])
+ auxrot_fake = tf.where(b0, auxrot_fake, tf.reverse(auxrot_fake, [2]))
+ auxrot_fake = tf.where(b1, auxrot_fake, tf.reverse(auxrot_fake, [3]))
+ auxrot_fake = tf.where(b2, auxrot_fake, tf.transpose(auxrot_fake, [0, 1, 3, 2]))
+ D_auxrot_fake = eval_D(D, aug, auxrot_fake, fake_labels, return_aux=4)
+ G_loss += tf.nn.sparse_softmax_cross_entropy_with_logits(labels=idx, logits=D_auxrot_fake.aux) * auxrot_alpha
+ if auxrot_beta != 0:
+ auxrot_real = tf.tile(real_images, [4, 1, 1, 1])
+ auxrot_real = tf.where(b0, auxrot_real, tf.reverse(auxrot_real, [2]))
+ auxrot_real = tf.where(b1, auxrot_real, tf.reverse(auxrot_real, [3]))
+ auxrot_real = tf.where(b2, auxrot_real, tf.transpose(auxrot_real, [0, 1, 3, 2]))
+ D_auxrot_real = eval_D(D, aug, auxrot_real, real_labels, return_aux=4)
+ D_loss += tf.nn.sparse_softmax_cross_entropy_with_logits(labels=idx, logits=D_auxrot_real.aux) * auxrot_beta
+
+ return report_loss(aug, G_loss, D_loss, G_reg, D_reg)
+
+#----------------------------------------------------------------------------
+# WGAN-GP loss with epsilon penalty, used in the paper
+# "Progressive Growing of GANs for Improved Quality, Stability, and Variation".
+
+def wgangp(G, D, aug, fake_labels, real_images, real_labels, wgan_epsilon=0.001, wgan_lambda=10, wgan_target=1, **_kwargs):
+ minibatch_size = tf.shape(fake_labels)[0]
+ fake_latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
+ G_fake = eval_G(G, fake_latents, fake_labels)
+ D_fake = eval_D(D, aug, G_fake.images, fake_labels, report='fake')
+ D_real = eval_D(D, aug, real_images, real_labels, report='real')
+
+ # WGAN loss from "Wasserstein Generative Adversarial Networks".
+ with tf.name_scope('Loss_main'):
+ G_loss = -D_fake.scores # pylint: disable=invalid-unary-operand-type
+ D_loss = D_fake.scores - D_real.scores
+
+ # Epsilon penalty from "Progressive Growing of GANs for Improved Quality, Stability, and Variation".
+ with tf.name_scope('Loss_epsilon'):
+ epsilon_penalty = report_stat(aug, 'Loss/epsilon_penalty', tf.square(D_real.scores))
+ D_loss += epsilon_penalty * wgan_epsilon
+
+ # Gradient penalty from "Improved Training of Wasserstein GANs".
+ with tf.name_scope('Loss_GP'):
+ mix_factors = tf.random_uniform([minibatch_size, 1, 1, 1], 0, 1, dtype=G_fake.images.dtype)
+ mix_images = tflib.lerp(tf.cast(real_images, G_fake.images.dtype), G_fake.images, mix_factors)
+ mix_labels = real_labels # NOTE: Mixing is performed without respect to fake_labels.
+ D_mix = eval_D(D, aug, mix_images, mix_labels, report='mix')
+ mix_grads = tf.gradients(tf.reduce_sum(D_mix.scores), [mix_images])[0]
+ mix_norms = tf.sqrt(tf.reduce_sum(tf.square(mix_grads), axis=[1,2,3]))
+ mix_norms = report_stat(aug, 'Loss/mix_norms', mix_norms)
+ gradient_penalty = tf.square(mix_norms - wgan_target)
+ D_reg = gradient_penalty * (wgan_lambda / (wgan_target**2))
+
+ return report_loss(aug, G_loss, D_loss, None, D_reg)
+
+#----------------------------------------------------------------------------
diff --git a/training/networks.py b/training/networks.py
new file mode 100755
index 00000000..f9bd0614
--- /dev/null
+++ b/training/networks.py
@@ -0,0 +1,632 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Network architectures from the paper
+"Training Generative Adversarial Networks with Limited Data"."""
+
+import numpy as np
+import tensorflow as tf
+import dnnlib
+import dnnlib.tflib as tflib
+from dnnlib.tflib.ops.upfirdn_2d import upsample_2d, downsample_2d, upsample_conv_2d, conv_downsample_2d
+from dnnlib.tflib.ops.fused_bias_act import fused_bias_act
+
+# NOTE: Do not import any application-specific modules here!
+# Specify all network parameters as kwargs.
+
+#----------------------------------------------------------------------------
+# Get/create weight tensor for convolution or fully-connected layer.
+
+def get_weight(shape, gain=1, equalized_lr=True, lrmul=1, weight_var='weight', trainable=True, use_spectral_norm=False):
+ fan_in = np.prod(shape[:-1]) # [kernel, kernel, fmaps_in, fmaps_out] for conv2d, [in, out] for fully-connected.
+ he_std = gain / np.sqrt(fan_in) # He init.
+
+ # Apply equalized learning rate from the paper
+ # "Progressive Growing of GANs for Improved Quality, Stability, and Variation".
+ if equalized_lr:
+ init_std = 1.0 / lrmul
+ runtime_coef = he_std * lrmul
+ else:
+ init_std = he_std / lrmul
+ runtime_coef = lrmul
+
+ # Create variable.
+ init = tf.initializers.random_normal(0, init_std)
+ w = tf.get_variable(weight_var, shape=shape, initializer=init, trainable=trainable) * runtime_coef
+ if use_spectral_norm:
+ w = apply_spectral_norm(w, state_var=weight_var+'_sn')
+ return w
+
+#----------------------------------------------------------------------------
+# Bias and activation function.
+
+def apply_bias_act(x, act='linear', gain=None, lrmul=1, clamp=None, bias_var='bias', trainable=True):
+ b = tf.get_variable(bias_var, shape=[x.shape[1]], initializer=tf.initializers.zeros(), trainable=trainable) * lrmul
+ return fused_bias_act(x, b=tf.cast(b, x.dtype), act=act, gain=gain, clamp=clamp)
+
+#----------------------------------------------------------------------------
+# Fully-connected layer.
+
+def dense_layer(x, fmaps, lrmul=1, weight_var='weight', trainable=True, use_spectral_norm=False):
+ if len(x.shape) > 2:
+ x = tf.reshape(x, [-1, np.prod([d.value for d in x.shape[1:]])])
+ w = get_weight([x.shape[1].value, fmaps], lrmul=lrmul, weight_var=weight_var, trainable=trainable, use_spectral_norm=use_spectral_norm)
+ w = tf.cast(w, x.dtype)
+ return tf.matmul(x, w)
+
+#----------------------------------------------------------------------------
+# 2D convolution op with optional upsampling, downsampling, and padding.
+
+def conv2d(x, w, up=False, down=False, resample_kernel=None, padding=0):
+ assert not (up and down)
+ kernel = w.shape[0].value
+ assert w.shape[1].value == kernel
+ assert kernel >= 1 and kernel % 2 == 1
+
+ w = tf.cast(w, x.dtype)
+ if up:
+ x = upsample_conv_2d(x, w, data_format='NCHW', k=resample_kernel, padding=padding)
+ elif down:
+ x = conv_downsample_2d(x, w, data_format='NCHW', k=resample_kernel, padding=padding)
+ else:
+ padding_mode = {0: 'SAME', -(kernel // 2): 'VALID'}[padding]
+ x = tf.nn.conv2d(x, w, data_format='NCHW', strides=[1,1,1,1], padding=padding_mode)
+ return x
+
+#----------------------------------------------------------------------------
+# 2D convolution layer.
+
+def conv2d_layer(x, fmaps, kernel, up=False, down=False, resample_kernel=None, lrmul=1, trainable=True, use_spectral_norm=False):
+ w = get_weight([kernel, kernel, x.shape[1].value, fmaps], lrmul=lrmul, trainable=trainable, use_spectral_norm=use_spectral_norm)
+ return conv2d(x, tf.cast(w, x.dtype), up=up, down=down, resample_kernel=resample_kernel)
+
+#----------------------------------------------------------------------------
+# Modulated 2D convolution layer from the paper
+# "Analyzing and Improving Image Quality of StyleGAN".
+
+def modulated_conv2d_layer(x, y, fmaps, kernel, up=False, down=False, demodulate=True, resample_kernel=None, lrmul=1, fused_modconv=False, trainable=True, use_spectral_norm=False):
+ assert not (up and down)
+ assert kernel >= 1 and kernel % 2 == 1
+
+ # Get weight.
+ wshape = [kernel, kernel, x.shape[1].value, fmaps]
+ w = get_weight(wshape, lrmul=lrmul, trainable=trainable, use_spectral_norm=use_spectral_norm)
+ if x.dtype.name == 'float16' and not fused_modconv and demodulate:
+ w *= np.sqrt(1 / np.prod(wshape[:-1])) / tf.reduce_max(tf.abs(w), axis=[0,1,2]) # Pre-normalize to avoid float16 overflow.
+ ww = w[np.newaxis] # [BkkIO] Introduce minibatch dimension.
+
+ # Modulate.
+ s = dense_layer(y, fmaps=x.shape[1].value, weight_var='mod_weight', trainable=trainable, use_spectral_norm=use_spectral_norm) # [BI] Transform incoming W to style.
+ s = apply_bias_act(s, bias_var='mod_bias', trainable=trainable) + 1 # [BI] Add bias (initially 1).
+ if x.dtype.name == 'float16' and not fused_modconv and demodulate:
+ s *= 1 / tf.reduce_max(tf.abs(s)) # Pre-normalize to avoid float16 overflow.
+ ww *= tf.cast(s[:, np.newaxis, np.newaxis, :, np.newaxis], w.dtype) # [BkkIO] Scale input feature maps.
+
+ # Demodulate.
+ if demodulate:
+ d = tf.rsqrt(tf.reduce_sum(tf.square(ww), axis=[1,2,3]) + 1e-8) # [BO] Scaling factor.
+ ww *= d[:, np.newaxis, np.newaxis, np.newaxis, :] # [BkkIO] Scale output feature maps.
+
+ # Reshape/scale input.
+ if fused_modconv:
+ x = tf.reshape(x, [1, -1, x.shape[2], x.shape[3]]) # Fused => reshape minibatch to convolution groups.
+ w = tf.reshape(tf.transpose(ww, [1, 2, 3, 0, 4]), [ww.shape[1], ww.shape[2], ww.shape[3], -1])
+ else:
+ x *= tf.cast(s[:, :, np.newaxis, np.newaxis], x.dtype) # [BIhw] Not fused => scale input activations.
+
+ # 2D convolution.
+ x = conv2d(x, tf.cast(w, x.dtype), up=up, down=down, resample_kernel=resample_kernel)
+
+ # Reshape/scale output.
+ if fused_modconv:
+ x = tf.reshape(x, [-1, fmaps, x.shape[2], x.shape[3]]) # Fused => reshape convolution groups back to minibatch.
+ elif demodulate:
+ x *= tf.cast(d[:, :, np.newaxis, np.newaxis], x.dtype) # [BOhw] Not fused => scale output activations.
+ return x
+
+#----------------------------------------------------------------------------
+# Normalize 2nd raw moment of the given activation tensor along specified axes.
+
+def normalize_2nd_moment(x, axis=1, eps=1e-8):
+ return x * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=axis, keepdims=True) + eps)
+
+#----------------------------------------------------------------------------
+# Minibatch standard deviation layer from the paper
+# "Progressive Growing of GANs for Improved Quality, Stability, and Variation".
+
+def minibatch_stddev_layer(x, group_size=None, num_new_features=1):
+ if group_size is None:
+ group_size = tf.shape(x)[0]
+ else:
+ group_size = tf.minimum(group_size, tf.shape(x)[0]) # Minibatch must be divisible by (or smaller than) group_size.
+
+ G = group_size
+ F = num_new_features
+ _N, C, H, W = x.shape.as_list()
+ c = C // F
+
+ y = tf.cast(x, tf.float32) # [NCHW] Cast to FP32.
+ y = tf.reshape(y, [G, -1, F, c, H, W]) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
+ y -= tf.reduce_mean(y, axis=0) # [GnFcHW] Subtract mean over group.
+ y = tf.reduce_mean(tf.square(y), axis=0) # [nFcHW] Calc variance over group.
+ y = tf.sqrt(y + 1e-8) # [nFcHW] Calc stddev over group.
+ y = tf.reduce_mean(y, axis=[2,3,4]) # [nF] Take average over channels and pixels.
+ y = tf.cast(y, x.dtype) # [nF] Cast back to original data type.
+ y = tf.reshape(y, [-1, F, 1, 1]) # [nF11] Add missing dimensions.
+ y = tf.tile(y, [G, 1, H, W]) # [NFHW] Replicate over group and pixels.
+ return tf.concat([x, y], axis=1) # [NCHW] Append to input as new channels.
+
+#----------------------------------------------------------------------------
+# Spectral normalization from the paper
+# "Spectral Normalization for Generative Adversarial Networks".
+
+def apply_spectral_norm(w, state_var='sn', iterations=1, eps=1e-8):
+ fmaps = w.shape[-1].value
+ w_mat = tf.reshape(w, [-1, fmaps])
+ u_var = tf.get_variable(state_var, shape=[1,fmaps], initializer=tf.initializers.random_normal(), trainable=False)
+
+ u = u_var
+ for _ in range(iterations):
+ v = tf.matmul(u, w_mat, transpose_b=True)
+ v *= tf.rsqrt(tf.reduce_sum(tf.square(v)) + eps)
+ u = tf.matmul(v, w_mat)
+ sigma_inv = tf.rsqrt(tf.reduce_sum(tf.square(u)) + eps)
+ u *= sigma_inv
+
+ with tf.control_dependencies([tf.assign(u_var, u)]):
+ return w * sigma_inv
+
+#----------------------------------------------------------------------------
+# Main generator network.
+# Composed of two sub-networks (mapping and synthesis) that are defined below.
+
+def G_main(
+ latents_in, # First input: Latent vectors (Z) [minibatch, latent_size].
+ labels_in, # Second input: Conditioning labels [minibatch, label_size].
+
+ # Evaluation mode.
+ is_training = False, # Network is under training? Enables and disables specific features.
+ is_validation = False, # Network is under validation? Chooses which value to use for truncation_psi.
+ return_dlatents = False, # Return dlatents (W) in addition to the images?
+
+ # Truncation & style mixing.
+ truncation_psi = 0.5, # Style strength multiplier for the truncation trick. None = disable.
+ truncation_cutoff = None, # Number of layers for which to apply the truncation trick. None = disable.
+ truncation_psi_val = None, # Value for truncation_psi to use during validation.
+ truncation_cutoff_val = None, # Value for truncation_cutoff to use during validation.
+ dlatent_avg_beta = 0.995, # Decay for tracking the moving average of W during training. None = disable.
+ style_mixing_prob = 0.9, # Probability of mixing styles during training. None = disable.
+
+ # Sub-networks.
+ components = dnnlib.EasyDict(), # Container for sub-networks. Retained between calls.
+ mapping_func = 'G_mapping', # Build func name for the mapping network.
+ synthesis_func = 'G_synthesis', # Build func name for the synthesis network.
+ is_template_graph = False, # True = template graph constructed by the Network class, False = actual evaluation.
+
+ **kwargs, # Arguments for sub-networks (mapping and synthesis).
+):
+ # Validate arguments.
+ assert not is_training or not is_validation
+ assert isinstance(components, dnnlib.EasyDict)
+ if is_validation:
+ truncation_psi = truncation_psi_val
+ truncation_cutoff = truncation_cutoff_val
+ if is_training or (truncation_psi is not None and not tflib.is_tf_expression(truncation_psi) and truncation_psi == 1):
+ truncation_psi = None
+ if is_training:
+ truncation_cutoff = None
+ if not is_training or (dlatent_avg_beta is not None and not tflib.is_tf_expression(dlatent_avg_beta) and dlatent_avg_beta == 1):
+ dlatent_avg_beta = None
+ if not is_training or (style_mixing_prob is not None and not tflib.is_tf_expression(style_mixing_prob) and style_mixing_prob <= 0):
+ style_mixing_prob = None
+
+ # Setup components.
+ if 'synthesis' not in components:
+ components.synthesis = tflib.Network('G_synthesis', func_name=globals()[synthesis_func], **kwargs)
+ num_layers = components.synthesis.input_shape[1]
+ dlatent_size = components.synthesis.input_shape[2]
+ if 'mapping' not in components:
+ components.mapping = tflib.Network('G_mapping', func_name=globals()[mapping_func], dlatent_broadcast=num_layers, **kwargs)
+
+ # Evaluate mapping network.
+ dlatents = components.mapping.get_output_for(latents_in, labels_in, is_training=is_training, **kwargs)
+ dlatents = tf.cast(dlatents, tf.float32)
+
+ # Update moving average of W.
+ dlatent_avg = tf.get_variable('dlatent_avg', shape=[dlatent_size], initializer=tf.initializers.zeros(), trainable=False)
+ if dlatent_avg_beta is not None:
+ with tf.variable_scope('DlatentAvg'):
+ batch_avg = tf.reduce_mean(dlatents[:, 0], axis=0)
+ update_op = tf.assign(dlatent_avg, tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta))
+ with tf.control_dependencies([update_op]):
+ dlatents = tf.identity(dlatents)
+
+ # Perform style mixing regularization.
+ if style_mixing_prob is not None:
+ with tf.variable_scope('StyleMix'):
+ latents2 = tf.random_normal(tf.shape(latents_in))
+ dlatents2 = components.mapping.get_output_for(latents2, labels_in, is_training=is_training, **kwargs)
+ dlatents2 = tf.cast(dlatents2, tf.float32)
+ layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
+ mixing_cutoff = tf.cond(
+ tf.random_uniform([], 0.0, 1.0) < style_mixing_prob,
+ lambda: tf.random_uniform([], 1, num_layers, dtype=tf.int32),
+ lambda: num_layers)
+ dlatents = tf.where(tf.broadcast_to(layer_idx < mixing_cutoff, tf.shape(dlatents)), dlatents, dlatents2)
+
+ # Apply truncation.
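+ # w' = lerp(dlatent_avg, w, psi) per layer; layers below truncation_cutoff use
+ # truncation_psi, the remaining layers use 1 (i.e. no truncation).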
+ if truncation_psi is not None:
+ with tf.variable_scope('Truncation'):
+ layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
+ layer_psi = np.ones(layer_idx.shape, dtype=np.float32)
+ if truncation_cutoff is None:
+ layer_psi *= truncation_psi
+ else:
+ layer_psi = tf.where(layer_idx < truncation_cutoff, layer_psi * truncation_psi, layer_psi)
+ dlatents = tflib.lerp(dlatent_avg, dlatents, layer_psi)
+
+ # Evaluate synthesis network.
+ images_out = components.synthesis.get_output_for(dlatents, is_training=is_training, force_clean_graph=is_template_graph, **kwargs)
+ images_out = tf.identity(images_out, name='images_out')
+ if return_dlatents:
+ return images_out, dlatents
+ return images_out
+
+#----------------------------------------------------------------------------
+# Generator mapping network.
+
+def G_mapping(
+ latents_in, # First input: Latent vectors (Z) [minibatch, latent_size].
+ labels_in, # Second input: Conditioning labels [minibatch, label_size].
+
+ # Input & output dimensions.
+ latent_size = 512, # Latent vector (Z) dimensionality.
+ label_size = 0, # Label dimensionality, 0 if no labels.
+ dlatent_size = 512, # Disentangled latent (W) dimensionality.
+ dlatent_broadcast = None, # Output disentangled latent (W) as [minibatch, dlatent_size] or [minibatch, dlatent_broadcast, dlatent_size].
+
+ # Internal details.
+ mapping_layers = 8, # Number of mapping layers.
+ mapping_fmaps = None, # Number of activations in the mapping layers, None = same as dlatent_size.
+ mapping_lrmul = 0.01, # Learning rate multiplier for the mapping layers.
+ mapping_nonlinearity = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
+ normalize_latents = True, # Normalize latent vectors (Z) before feeding them to the mapping layers?
+ label_fmaps = None, # Label embedding dimensionality, None = same as latent_size.
+ dtype = 'float32', # Data type to use for intermediate activations and outputs.
+
+ **_kwargs, # Ignore unrecognized keyword args.
+):
+ # Inputs.
+ latents_in.set_shape([None, latent_size])
+ labels_in.set_shape([None, label_size])
+ latents_in = tf.cast(latents_in, dtype)
+ labels_in = tf.cast(labels_in, dtype)
+ x = latents_in
+
+ # Normalize latents.
+ if normalize_latents:
+ with tf.variable_scope('Normalize'):
+ x = normalize_2nd_moment(x)
+
+ # Embed labels, normalize, and concatenate with latents.
+ if label_size > 0:
+ with tf.variable_scope('LabelEmbed'):
+ fmaps = label_fmaps if label_fmaps is not None else latent_size
+ y = labels_in
+ y = apply_bias_act(dense_layer(y, fmaps=fmaps))
+ y = normalize_2nd_moment(y)
+ x = tf.concat([x, y], axis=1)
+
+ # Mapping layers.
+ for layer_idx in range(mapping_layers):
+ with tf.variable_scope(f'Dense{layer_idx}'):
+ fmaps = mapping_fmaps if mapping_fmaps is not None and layer_idx < mapping_layers - 1 else dlatent_size
+ x = apply_bias_act(dense_layer(x, fmaps=fmaps, lrmul=mapping_lrmul), act=mapping_nonlinearity, lrmul=mapping_lrmul)
+
+ # Broadcast.
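+ # Tile the single W into one copy per synthesis layer so that style mixing and
+ # truncation can later operate per layer.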
+ if dlatent_broadcast is not None:
+ with tf.variable_scope('Broadcast'):
+ x = tf.tile(x[:, np.newaxis], [1, dlatent_broadcast, 1])
+
+ # Output.
+ assert x.dtype == tf.as_dtype(dtype)
+ return tf.identity(x, name='dlatents_out')
+
+#----------------------------------------------------------------------------
+# Generator synthesis network.
+
+def G_synthesis(
+ dlatents_in, # Input: Disentangled latents (W) [minibatch, num_layers, dlatent_size].
+
+ # Input & output dimensions.
+ dlatent_size = 512, # Disentangled latent (W) dimensionality.
+ num_channels = 3, # Number of output color channels.
+ resolution = 1024, # Output resolution.
+
+ # Capacity.
+ fmap_base = 16384, # Overall multiplier for the number of feature maps.
+ fmap_decay = 1, # Log2 feature map reduction when doubling the resolution.
+ fmap_min = 1, # Minimum number of feature maps in any layer.
+ fmap_max = 512, # Maximum number of feature maps in any layer.
+ fmap_const = None, # Number of feature maps in the constant input layer. None = default.
+
+ # Internal details.
+ use_noise = True, # Enable noise inputs?
+ randomize_noise = True, # True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables.
+ architecture = 'skip', # Architecture: 'orig', 'skip', 'resnet'.
+ nonlinearity = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
+ dtype = 'float32', # Data type to use for intermediate activations and outputs.
+ num_fp16_res = 0, # Use FP16 for the N highest resolutions, regardless of dtype.
+ conv_clamp = None, # Clamp the output of convolution layers to [-conv_clamp, +conv_clamp], None = disable clamping.
+ resample_kernel = [1,3,3,1], # Low-pass filter to apply when resampling activations, None = box filter.
+ fused_modconv = False, # Implement modulated_conv2d_layer() using grouped convolution?
+
+ **_kwargs, # Ignore unrecognized keyword args.
+):
+ resolution_log2 = int(np.log2(resolution))
+ assert resolution == 2**resolution_log2 and resolution >= 4
+ def nf(stage): return np.clip(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_min, fmap_max)
+ assert architecture in ['orig', 'skip', 'resnet']
+ act = nonlinearity
+ num_layers = resolution_log2 * 2 - 2
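+ # With the defaults (resolution=1024, fmap_base=16384, fmap_decay=1, fmap_max=512):
+ # resolution_log2 = 10, num_layers = 18, and nf() gives 512 feature maps up through
+ # the 64x64 block, then 256, 128, 64, and 32 for the 128x128 ... 1024x1024 blocks.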
+
+ # Disentangled latent (W).
+ dlatents_in.set_shape([None, num_layers, dlatent_size])
+ dlatents_in = tf.cast(dlatents_in, dtype)
+
+ # Noise inputs.
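+ # Layer 0 gets a 4x4 noise buffer; each subsequent pair of layers gets buffers at
+ # the next power-of-two resolution (8x8, 16x16, ...), matching the spatial size of
+ # the activations they are added to.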
+ noise_inputs = []
+ if use_noise:
+ for layer_idx in range(num_layers - 1):
+ res = (layer_idx + 5) // 2
+ shape = [1, 1, 2**res, 2**res]
+ noise_inputs.append(tf.get_variable(f'noise{layer_idx}', shape=shape, initializer=tf.initializers.random_normal(), trainable=False))
+
+ # Single convolution layer with all the bells and whistles.
+ def layer(x, layer_idx, fmaps, kernel, up=False):
+ x = modulated_conv2d_layer(x, dlatents_in[:, layer_idx], fmaps=fmaps, kernel=kernel, up=up, resample_kernel=resample_kernel, fused_modconv=fused_modconv)
+ if use_noise:
+ if randomize_noise:
+ noise = tf.random_normal([tf.shape(x)[0], 1, x.shape[2], x.shape[3]], dtype=x.dtype)
+ else:
+ noise = tf.cast(noise_inputs[layer_idx], x.dtype)
+ noise_strength = tf.get_variable('noise_strength', shape=[], initializer=tf.initializers.zeros())
+ x += noise * tf.cast(noise_strength, x.dtype)
+ return apply_bias_act(x, act=act, clamp=conv_clamp)
+
+ # Main block for one resolution.
+ def block(x, res): # res = 3..resolution_log2
+ x = tf.cast(x, 'float16' if res > resolution_log2 - num_fp16_res else dtype)
+ t = x
+ with tf.variable_scope('Conv0_up'):
+ x = layer(x, layer_idx=res*2-5, fmaps=nf(res-1), kernel=3, up=True)
+ with tf.variable_scope('Conv1'):
+ x = layer(x, layer_idx=res*2-4, fmaps=nf(res-1), kernel=3)
+ if architecture == 'resnet':
+ with tf.variable_scope('Skip'):
+ t = conv2d_layer(t, fmaps=nf(res-1), kernel=1, up=True, resample_kernel=resample_kernel)
+ x = (x + t) * (1 / np.sqrt(2))
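+ # Scaling by 1/sqrt(2) keeps the variance of the summed residual and skip branches
+ # roughly equal to that of a single branch.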
+ return x
+
+ # Upsampling block.
+ def upsample(y):
+ with tf.variable_scope('Upsample'):
+ return upsample_2d(y, k=resample_kernel)
+
+ # ToRGB block.
+ def torgb(x, y, res): # res = 2..resolution_log2
+ with tf.variable_scope('ToRGB'):
+ t = modulated_conv2d_layer(x, dlatents_in[:, res*2-3], fmaps=num_channels, kernel=1, demodulate=False, fused_modconv=fused_modconv)
+ t = apply_bias_act(t, clamp=conv_clamp)
+ t = tf.cast(t, dtype)
+ if y is not None:
+ t += tf.cast(y, t.dtype)
+ return t
+
+ # Layers for 4x4 resolution.
+ y = None
+ with tf.variable_scope('4x4'):
+ with tf.variable_scope('Const'):
+ fmaps = fmap_const if fmap_const is not None else nf(1)
+ x = tf.get_variable('const', shape=[1, fmaps, 4, 4], initializer=tf.initializers.random_normal())
+ x = tf.tile(tf.cast(x, dtype), [tf.shape(dlatents_in)[0], 1, 1, 1])
+ with tf.variable_scope('Conv'):
+ x = layer(x, layer_idx=0, fmaps=nf(1), kernel=3)
+ if architecture == 'skip':
+ y = torgb(x, y, 2)
+
+ # Layers for >=8x8 resolutions.
+ for res in range(3, resolution_log2 + 1):
+ with tf.variable_scope(f'{2**res}x{2**res}'):
+ x = block(x, res)
+ if architecture == 'skip':
+ y = upsample(y)
+ if architecture == 'skip' or res == resolution_log2:
+ y = torgb(x, y, res)
+
+ images_out = y
+ assert images_out.dtype == tf.as_dtype(dtype)
+ return tf.identity(images_out, name='images_out')
+
+#----------------------------------------------------------------------------
+# Discriminator.
+
+def D_main(
+ images_in, # First input: Images [minibatch, channel, height, width].
+ labels_in, # Second input: Conditioning labels [minibatch, label_size].
+
+ # Input dimensions.
+ num_channels = 3, # Number of input color channels. Overridden based on dataset.
+ resolution = 1024, # Input resolution. Overridden based on dataset.
+ label_size = 0, # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.
+
+ # Capacity.
+ fmap_base = 16384, # Overall multiplier for the number of feature maps.
+ fmap_decay = 1, # Log2 feature map reduction when doubling the resolution.
+ fmap_min = 1, # Minimum number of feature maps in any layer.
+ fmap_max = 512, # Maximum number of feature maps in any layer.
+
+ # Internal details.
+ mapping_layers = 0, # Number of additional mapping layers for the conditioning labels.
+ mapping_fmaps = None, # Number of activations in the mapping layers, None = default.
+ mapping_lrmul = 0.1, # Learning rate multiplier for the mapping layers.
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ nonlinearity = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
+ mbstd_group_size = None, # Group size for the minibatch standard deviation layer, None = entire minibatch.
+ mbstd_num_features = 1, # Number of features for the minibatch standard deviation layer, 0 = disable.
+ dtype = 'float32', # Data type to use for intermediate activations and outputs.
+ num_fp16_res = 0, # Use FP16 for the N highest resolutions, regardless of dtype.
+ conv_clamp = None, # Clamp the output of convolution layers to [-conv_clamp, +conv_clamp], None = disable clamping.
+ resample_kernel = [1,3,3,1], # Low-pass filter to apply when resampling activations, None = box filter.
+
+ # Comparison methods.
+ augment_strength = 0, # AdaptiveAugment.get_strength_var() for pagan & adropout.
+ use_pagan = False, # pagan: Enable?
+ pagan_num = 16, # pagan: Number of active bits with augment_strength=1.
+ pagan_fade = 0.5, # pagan: Relative duration of fading in new bits.
+ score_size = 1, # auxrot: Number of scalars to output. Can vary between evaluations.
+ score_max = 1, # auxrot: Maximum number of scalars to output. Must be set at construction time.
+ use_spectral_norm = False, # spectralnorm: Enable?
+ adaptive_dropout = 0, # adropout: Standard deviation to use with augment_strength=1, 0 = disable.
+ freeze_layers = 0, # Freeze-D: Number of layers to freeze.
+
+ **_kwargs, # Ignore unrecognized keyword args.
+):
+ resolution_log2 = int(np.log2(resolution))
+ assert resolution == 2**resolution_log2 and resolution >= 4
+ def nf(stage): return np.clip(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_min, fmap_max)
+ assert architecture in ['orig', 'skip', 'resnet']
+ if mapping_fmaps is None:
+ mapping_fmaps = nf(0)
+ act = nonlinearity
+
+ # Inputs.
+ images_in.set_shape([None, num_channels, resolution, resolution])
+ labels_in.set_shape([None, label_size])
+ images_in = tf.cast(images_in, dtype)
+ labels_in = tf.cast(labels_in, dtype)
+
+ # Label embedding and mapping.
+ if label_size > 0:
+ y = labels_in
+ with tf.variable_scope('LabelEmbed'):
+ y = apply_bias_act(dense_layer(y, fmaps=mapping_fmaps))
+ y = normalize_2nd_moment(y)
+ for idx in range(mapping_layers):
+ with tf.variable_scope(f'Mapping{idx}'):
+ y = apply_bias_act(dense_layer(y, fmaps=mapping_fmaps, lrmul=mapping_lrmul), act=act, lrmul=mapping_lrmul)
+ labels_in = y
+
+ # Adaptive multiplicative dropout.
+ def adrop(x):
+ if adaptive_dropout != 0:
+ s = [tf.shape(x)[0], x.shape[1]] + [1] * (x.shape.rank - 2)
+ x *= tf.cast(tf.exp(tf.random_normal(s) * (augment_strength * adaptive_dropout)), x.dtype)
+ return x
+
+ # Freeze-D.
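+ # Discriminator layers are instantiated in order from the input (highest-resolution)
+ # side, so the first `freeze_layers` calls below return trainable=False and those
+ # layers stay fixed at their initial (or resumed) weights.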
+ cur_layer_idx = 0
+ def is_next_layer_trainable():
+ nonlocal cur_layer_idx
+ trainable = (cur_layer_idx >= freeze_layers)
+ cur_layer_idx += 1
+ return trainable
+
+ # Construct PA-GAN bit vector.
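+ # PA-GAN comparison method: bits are enabled progressively as augment_strength grows
+ # (each active bit is set with probability up to 0.5, new bits fading in over a
+ # window controlled by pagan_fade), and pagan_signs flips the sign of the final
+ # score once per active bit.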
+ pagan_bits = None
+ pagan_signs = None
+ if use_pagan:
+ with tf.variable_scope('PAGAN'):
+ idx = tf.range(pagan_num, dtype=tf.float32)
+ active = (augment_strength * pagan_num - idx - 1) / max(pagan_fade, 1e-8) + 1
+ prob = tf.clip_by_value(active[np.newaxis, :], 0, 1) * 0.5
+ rnd = tf.random_uniform([tf.shape(images_in)[0], pagan_num])
+ pagan_bits = tf.cast(rnd < prob, dtype=tf.float32)
+ pagan_signs = tf.reduce_prod(1 - pagan_bits * 2, axis=1, keepdims=True)
+
+ # FromRGB block.
+ def fromrgb(x, y, res): # res = 2..resolution_log2
+ with tf.variable_scope('FromRGB'):
+ trainable = is_next_layer_trainable()
+ t = tf.cast(y, 'float16' if res > resolution_log2 - num_fp16_res else dtype)
+ t = adrop(conv2d_layer(t, fmaps=nf(res-1), kernel=1, trainable=trainable))
+ if pagan_bits is not None:
+ with tf.variable_scope('PAGAN'):
+ t += dense_layer(tf.cast(pagan_bits, t.dtype), fmaps=nf(res-1), trainable=trainable)[:, :, np.newaxis, np.newaxis]
+ t = apply_bias_act(t, act=act, clamp=conv_clamp, trainable=trainable)
+ if x is not None:
+ t += tf.cast(x, t.dtype)
+ return t
+
+ # Main block for one resolution.
+ def block(x, res): # res = 2..resolution_log2
+ x = tf.cast(x, 'float16' if res > resolution_log2 - num_fp16_res else dtype)
+ t = x
+ with tf.variable_scope('Conv0'):
+ trainable = is_next_layer_trainable()
+ x = apply_bias_act(adrop(conv2d_layer(x, fmaps=nf(res-1), kernel=3, trainable=trainable, use_spectral_norm=use_spectral_norm)), act=act, clamp=conv_clamp, trainable=trainable)
+ with tf.variable_scope('Conv1_down'):
+ trainable = is_next_layer_trainable()
+ x = apply_bias_act(adrop(conv2d_layer(x, fmaps=nf(res-2), kernel=3, down=True, resample_kernel=resample_kernel, trainable=trainable, use_spectral_norm=use_spectral_norm)), act=act, clamp=conv_clamp, trainable=trainable)
+ if architecture == 'resnet':
+ with tf.variable_scope('Skip'):
+ trainable = is_next_layer_trainable()
+ t = adrop(conv2d_layer(t, fmaps=nf(res-2), kernel=1, down=True, resample_kernel=resample_kernel, trainable=trainable))
+ x = (x + t) * (1 / np.sqrt(2))
+ return x
+
+ # Downsampling block.
+ def downsample(y):
+ with tf.variable_scope('Downsample'):
+ return downsample_2d(y, k=resample_kernel)
+
+ # Layers for >=8x8 resolutions.
+ x = None
+ y = images_in
+ for res in range(resolution_log2, 2, -1):
+ with tf.variable_scope(f'{2**res}x{2**res}'):
+ if architecture == 'skip' or res == resolution_log2:
+ x = fromrgb(x, y, res)
+ x = block(x, res)
+ if architecture == 'skip':
+ y = downsample(y)
+
+ # Layers for 4x4 resolution.
+ with tf.variable_scope('4x4'):
+ if architecture == 'skip':
+ x = fromrgb(x, y, 2)
+ x = tf.cast(x, dtype)
+ if mbstd_num_features > 0:
+ with tf.variable_scope('MinibatchStddev'):
+ x = minibatch_stddev_layer(x, mbstd_group_size, mbstd_num_features)
+ with tf.variable_scope('Conv'):
+ trainable = is_next_layer_trainable()
+ x = apply_bias_act(adrop(conv2d_layer(x, fmaps=nf(1), kernel=3, trainable=trainable, use_spectral_norm=use_spectral_norm)), act=act, clamp=conv_clamp, trainable=trainable)
+ with tf.variable_scope('Dense0'):
+ trainable = is_next_layer_trainable()
+ x = apply_bias_act(adrop(dense_layer(x, fmaps=nf(0), trainable=trainable)), act=act, trainable=trainable)
+
+ # Output layer (always trainable).
+ with tf.variable_scope('Output'):
+ if label_size > 0:
+ assert score_max == 1
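+ # Conditional score: dot product of the final features with the mapped label
+ # embedding, scaled by 1/sqrt(mapping_fmaps) (projection-style conditioning).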
+ x = apply_bias_act(dense_layer(x, fmaps=mapping_fmaps))
+ x = tf.reduce_sum(x * labels_in, axis=1, keepdims=True) / np.sqrt(mapping_fmaps)
+ else:
+ x = apply_bias_act(dense_layer(x, fmaps=score_max))
+ if pagan_signs is not None:
+ assert score_max == 1
+ x *= pagan_signs
+ scores_out = x[:, :score_size]
+
+ # Output.
+ assert scores_out.dtype == tf.as_dtype(dtype)
+ scores_out = tf.identity(scores_out, name='scores_out')
+ return scores_out
+
+#----------------------------------------------------------------------------
diff --git a/training/training_loop.py b/training/training_loop.py
new file mode 100755
index 00000000..f70c11f8
--- /dev/null
+++ b/training/training_loop.py
@@ -0,0 +1,326 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Main training loop."""
+
+import os
+import pickle
+import time
+import PIL.Image
+import numpy as np
+import tensorflow as tf
+import dnnlib
+import dnnlib.tflib as tflib
+from dnnlib.tflib.autosummary import autosummary
+
+from training import dataset
+
+#----------------------------------------------------------------------------
+# Select size and contents of the image snapshot grids that are exported
+# periodically during training.
+
+def setup_snapshot_image_grid(training_set):
+ gw = np.clip(7680 // training_set.shape[2], 7, 32)
+ gh = np.clip(4320 // training_set.shape[1], 4, 32)
+
+ # Unconditional.
+ if training_set.label_size == 0:
+ reals, labels = training_set.get_minibatch_np(gw * gh)
+ return (gw, gh), reals, labels
+
+ # Row per class.
+ cw, ch = (gw, 1)
+ nw = (gw - 1) // cw + 1
+ nh = (gh - 1) // ch + 1
+
+ # Collect images.
+ blocks = [[] for _i in range(nw * nh)]
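+ # Each block holds up to cw*ch examples of one class; once a class's block is full,
+ # further examples of that class spill into blocks offset by label_size.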
+ for _iter in range(1000000):
+ real, label = training_set.get_minibatch_np(1)
+ idx = np.argmax(label[0])
+ while idx < len(blocks) and len(blocks[idx]) >= cw * ch:
+ idx += training_set.label_size
+ if idx < len(blocks):
+ blocks[idx].append((real, label))
+ if all(len(block) >= cw * ch for block in blocks):
+ break
+
+ # Layout grid.
+ reals = np.zeros([gw * gh] + training_set.shape, dtype=training_set.dtype)
+ labels = np.zeros([gw * gh, training_set.label_size], dtype=training_set.label_dtype)
+ for i, block in enumerate(blocks):
+ for j, (real, label) in enumerate(block):
+ x = (i % nw) * cw + j % cw
+ y = (i // nw) * ch + j // cw
+ if x < gw and y < gh:
+ reals[x + y * gw] = real[0]
+ labels[x + y * gw] = label[0]
+ return (gw, gh), reals, labels
+
+#----------------------------------------------------------------------------
+
+def save_image_grid(images, filename, drange, grid_size):
+ lo, hi = drange
+ gw, gh = grid_size
+ images = np.asarray(images, dtype=np.float32)
+ images = (images - lo) * (255 / (hi - lo))
+ images = np.rint(images).clip(0, 255).astype(np.uint8)
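+ # Tile the [N, C, H, W] batch into a single [gh*H, gw*W, C] image in row-major grid
+ # order before handing it to PIL.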
+ _N, C, H, W = images.shape
+ images = images.reshape(gh, gw, C, H, W)
+ images = images.transpose(0, 3, 1, 4, 2)
+ images = images.reshape(gh * H, gw * W, C)
+ PIL.Image.fromarray(images, {3: 'RGB', 1: 'L'}[C]).save(filename)
+
+#----------------------------------------------------------------------------
+# Main training script.
+
+def training_loop(
+ run_dir = '.', # Output directory.
+ G_args = {}, # Options for generator network.
+ D_args = {}, # Options for discriminator network.
+ G_opt_args = {}, # Options for generator optimizer.
+ D_opt_args = {}, # Options for discriminator optimizer.
+ loss_args = {}, # Options for loss function.
+ train_dataset_args = {}, # Options for dataset to train with.
+ metric_dataset_args = {}, # Options for dataset to evaluate metrics against.
+ augment_args = {}, # Options for adaptive augmentations.
+ metric_arg_list = [], # Metrics to evaluate during training.
+ num_gpus = 1, # Number of GPUs to use.
+ minibatch_size = 32, # Global minibatch size.
+ minibatch_gpu = 4, # Number of samples processed at a time by one GPU.
+ G_smoothing_kimg = 10, # Half-life of the exponential moving average (EMA) of generator weights.
+ G_smoothing_rampup = None, # EMA ramp-up coefficient, None = no ramp-up.
+ minibatch_repeats = 4, # Number of minibatches to run in the inner loop.
+ lazy_regularization = True, # Perform regularization as a separate training step?
+ G_reg_interval = 4, # How often to perform regularization for G? Ignored if lazy_regularization=False.
+ D_reg_interval = 16, # How often to perform regularization for D? Ignored if lazy_regularization=False.
+ total_kimg = 25000, # Total length of the training, measured in thousands of real images.
+ kimg_per_tick = 4, # Progress snapshot interval.
+ image_snapshot_ticks = 50, # How often to save image snapshots? None = only save 'reals.png' and 'fakes_init.png'.
+ network_snapshot_ticks = 50, # How often to save network snapshots? None = disable network snapshots.
+ resume_pkl = None, # Network pickle to resume training from.
+ abort_fn = None, # Callback function for determining whether to abort training.
+ progress_fn = None, # Callback function for updating training progress.
+):
+ assert minibatch_size % (num_gpus * minibatch_gpu) == 0
+ start_time = time.time()
+
+ print('Loading training set...')
+ training_set = dataset.load_dataset(**train_dataset_args)
+ print('Image shape:', np.int32(training_set.shape).tolist())
+ print('Label shape:', [training_set.label_size])
+ print()
+
+ print('Constructing networks...')
+ with tf.device('/gpu:0'):
+ G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
+ D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
+ Gs = G.clone('Gs')
+ if resume_pkl is not None:
+ print(f'Resuming from "{resume_pkl}"')
+ with dnnlib.util.open_url(resume_pkl) as f:
+ rG, rD, rGs = pickle.load(f)
+ G.copy_vars_from(rG)
+ D.copy_vars_from(rD)
+ Gs.copy_vars_from(rGs)
+ G.print_layers()
+ D.print_layers()
+
+ print('Exporting sample images...')
+ grid_size, grid_reals, grid_labels = setup_snapshot_image_grid(training_set)
+ save_image_grid(grid_reals, os.path.join(run_dir, 'reals.png'), drange=[0,255], grid_size=grid_size)
+ grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
+ grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=minibatch_gpu)
+ save_image_grid(grid_fakes, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size)
+
+ print(f'Replicating networks across {num_gpus} GPUs...')
+ G_gpus = [G]
+ D_gpus = [D]
+ for gpu in range(1, num_gpus):
+ with tf.device(f'/gpu:{gpu}'):
+ G_gpus.append(G.clone(f'{G.name}_gpu{gpu}'))
+ D_gpus.append(D.clone(f'{D.name}_gpu{gpu}'))
+
+ print('Initializing augmentations...')
+ aug = None
+ if augment_args.get('class_name', None) is not None:
+ aug = dnnlib.util.construct_class_by_name(**augment_args)
+ aug.init_validation_set(D_gpus=D_gpus, training_set=training_set)
+
+ print('Setting up optimizers...')
+ G_opt_args = dict(G_opt_args)
+ D_opt_args = dict(D_opt_args)
+ for args, reg_interval in [(G_opt_args, G_reg_interval), (D_opt_args, D_reg_interval)]:
+ args['minibatch_multiplier'] = minibatch_size // num_gpus // minibatch_gpu
+ if lazy_regularization:
+ mb_ratio = reg_interval / (reg_interval + 1)
+ args['learning_rate'] *= mb_ratio
+ if 'beta1' in args: args['beta1'] **= mb_ratio
+ if 'beta2' in args: args['beta2'] **= mb_ratio
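+ # Lazy regularization: the regularizer runs only every reg_interval minibatches (and
+ # is multiplied by reg_interval below), so the learning rate and Adam betas are
+ # rescaled by reg_interval / (reg_interval + 1) to keep the optimizer dynamics
+ # approximately equivalent to evaluating the regularizer on every minibatch.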
+ G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
+ D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
+ G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
+ D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)
+
+ print('Constructing training graph...')
+ data_fetch_ops = []
+ training_set.configure(minibatch_gpu)
+ for gpu, (G_gpu, D_gpu) in enumerate(zip(G_gpus, D_gpus)):
+ with tf.name_scope(f'Train_gpu{gpu}'), tf.device(f'/gpu:{gpu}'):
+
+ # Fetch training data via temporary variables.
+ with tf.name_scope('DataFetch'):
+ real_images_var = tf.Variable(name='images', trainable=False, initial_value=tf.zeros([minibatch_gpu] + training_set.shape))
+ real_labels_var = tf.Variable(name='labels', trainable=False, initial_value=tf.zeros([minibatch_gpu, training_set.label_size]))
+ real_images_write, real_labels_write = training_set.get_minibatch_tf()
+ real_images_write = tflib.convert_images_from_uint8(real_images_write)
+ data_fetch_ops += [tf.assign(real_images_var, real_images_write)]
+ data_fetch_ops += [tf.assign(real_labels_var, real_labels_write)]
+
+ # Evaluate loss function and register gradients.
+ fake_labels = training_set.get_random_labels_tf(minibatch_gpu)
+ terms = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, aug=aug, fake_labels=fake_labels, real_images=real_images_var, real_labels=real_labels_var, **loss_args)
+ if lazy_regularization:
+ if terms.G_reg is not None: G_reg_opt.register_gradients(tf.reduce_mean(terms.G_reg * G_reg_interval), G_gpu.trainables)
+ if terms.D_reg is not None: D_reg_opt.register_gradients(tf.reduce_mean(terms.D_reg * D_reg_interval), D_gpu.trainables)
+ else:
+ if terms.G_reg is not None: terms.G_loss += terms.G_reg
+ if terms.D_reg is not None: terms.D_loss += terms.D_reg
+ G_opt.register_gradients(tf.reduce_mean(terms.G_loss), G_gpu.trainables)
+ D_opt.register_gradients(tf.reduce_mean(terms.D_loss), D_gpu.trainables)
+
+ print('Finalizing training ops...')
+ data_fetch_op = tf.group(*data_fetch_ops)
+ G_train_op = G_opt.apply_updates()
+ D_train_op = D_opt.apply_updates()
+ G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
+ D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
+ Gs_beta_in = tf.placeholder(tf.float32, name='Gs_beta_in', shape=[])
+ Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta_in)
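+ # Gs tracks an exponential moving average of G's weights; it generates the image
+ # snapshot grids and is saved alongside G and D in the network pickles.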
+ tflib.init_uninitialized_vars()
+ with tf.device('/gpu:0'):
+ peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
+
+ print('Initializing metrics...')
+ summary_log = tf.summary.FileWriter(run_dir)
+ metrics = []
+ for args in metric_arg_list:
+ metric = dnnlib.util.construct_class_by_name(**args)
+ metric.configure(dataset_args=metric_dataset_args, run_dir=run_dir)
+ metrics.append(metric)
+
+ print(f'Training for {total_kimg} kimg...')
+ print()
+ if progress_fn is not None:
+ progress_fn(0, total_kimg)
+ tick_start_time = time.time()
+ maintenance_time = tick_start_time - start_time
+ cur_nimg = 0
+ cur_tick = -1
+ tick_start_nimg = cur_nimg
+ running_mb_counter = 0
+
+ done = False
+ while not done:
+
+ # Compute EMA decay parameter.
+ Gs_nimg = G_smoothing_kimg * 1000.0
+ if G_smoothing_rampup is not None:
+ Gs_nimg = min(Gs_nimg, cur_nimg * G_smoothing_rampup)
+ Gs_beta = 0.5 ** (minibatch_size / max(Gs_nimg, 1e-8))
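+ # Per-update EMA decay chosen so that a weight's contribution halves after Gs_nimg
+ # images; e.g. minibatch_size=32 and G_smoothing_kimg=10 give
+ # beta = 0.5 ** (32 / 10000) ≈ 0.9978.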
+
+ # Run training ops.
+ for _repeat_idx in range(minibatch_repeats):
+ rounds = range(0, minibatch_size, minibatch_gpu * num_gpus)
+ run_G_reg = (lazy_regularization and running_mb_counter % G_reg_interval == 0)
+ run_D_reg = (lazy_regularization and running_mb_counter % D_reg_interval == 0)
+ cur_nimg += minibatch_size
+ running_mb_counter += 1
+
+ # Fast path without gradient accumulation.
+ if len(rounds) == 1:
+ tflib.run([G_train_op, data_fetch_op])
+ if run_G_reg:
+ tflib.run(G_reg_op)
+ tflib.run([D_train_op, Gs_update_op], {Gs_beta_in: Gs_beta})
+ if run_D_reg:
+ tflib.run(D_reg_op)
+
+ # Slow path with gradient accumulation.
+ else:
+ for _round in rounds:
+ tflib.run(G_train_op)
+ if run_G_reg:
+ tflib.run(G_reg_op)
+ tflib.run(Gs_update_op, {Gs_beta_in: Gs_beta})
+ for _round in rounds:
+ tflib.run(data_fetch_op)
+ tflib.run(D_train_op)
+ if run_D_reg:
+ tflib.run(D_reg_op)
+
+ # Run validation.
+ if aug is not None:
+ aug.run_validation(minibatch_size=minibatch_size)
+
+ # Tune augmentation parameters.
+ if aug is not None:
+ aug.tune(minibatch_size * minibatch_repeats)
+
+ # Perform maintenance tasks once per tick.
+ done = (cur_nimg >= total_kimg * 1000) or (abort_fn is not None and abort_fn())
+ if done or cur_tick < 0 or cur_nimg >= tick_start_nimg + kimg_per_tick * 1000:
+ cur_tick += 1
+ tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
+ tick_start_nimg = cur_nimg
+ tick_end_time = time.time()
+ total_time = tick_end_time - start_time
+ tick_time = tick_end_time - tick_start_time
+
+ # Report progress.
+ print(' '.join([
+ f"tick {autosummary('Progress/tick', cur_tick):<5d}",
+ f"kimg {autosummary('Progress/kimg', cur_nimg / 1000.0):<8.1f}",
+ f"time {dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)):<12s}",
+ f"sec/tick {autosummary('Timing/sec_per_tick', tick_time):<7.1f}",
+ f"sec/kimg {autosummary('Timing/sec_per_kimg', tick_time / tick_kimg):<7.2f}",
+ f"maintenance {autosummary('Timing/maintenance_sec', maintenance_time):<6.1f}",
+ f"gpumem {autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30):<5.1f}",
+ f"augment {autosummary('Progress/augment', aug.strength if aug is not None else 0):.3f}",
+ ]))
+ autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
+ autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))
+ if progress_fn is not None:
+ progress_fn(cur_nimg // 1000, total_kimg)
+
+ # Save snapshots.
+ if image_snapshot_ticks is not None and (done or cur_tick % image_snapshot_ticks == 0):
+ grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=minibatch_gpu)
+ save_image_grid(grid_fakes, os.path.join(run_dir, f'fakes{cur_nimg // 1000:06d}.png'), drange=[-1,1], grid_size=grid_size)
+ if network_snapshot_ticks is not None and (done or cur_tick % network_snapshot_ticks == 0):
+ pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg // 1000:06d}.pkl')
+ with open(pkl, 'wb') as f:
+ pickle.dump((G, D, Gs), f)
+ if len(metrics):
+ print('Evaluating metrics...')
+ for metric in metrics:
+ metric.run(pkl, num_gpus=num_gpus)
+
+ # Update summaries.
+ for metric in metrics:
+ metric.update_autosummaries()
+ tflib.autosummary.save_summaries(summary_log, cur_nimg)
+ tick_start_time = time.time()
+ maintenance_time = tick_start_time - tick_end_time
+
+ print()
+ print('Exiting...')
+ summary_log.close()
+ training_set.close()
+
+#----------------------------------------------------------------------------