From 1a61a691e0357f245bee7a67589beb7f083936aa Mon Sep 17 00:00:00 2001 From: Peter Baylies Date: Sat, 19 Dec 2020 12:34:01 -0500 Subject: [PATCH] Add fid10k metrics, updates for variance and truncation with labels in the notebook. --- WikiArt Example Generation.ipynb | 38 ++++++++++++++++++-------------- metrics/metric_defaults.py | 2 ++ 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/WikiArt Example Generation.ipynb b/WikiArt Example Generation.ipynb index 9d7c5117..e0db055d 100644 --- a/WikiArt Example Generation.ipynb +++ b/WikiArt Example Generation.ipynb @@ -25,6 +25,7 @@ "import dnnlib\n", "import dnnlib.tflib as tflib\n", "network_pkl = 'WikiArt5.pkl'\n", + "#network_pkl = 'WikiArt_Uncond2.pkl'\n", "\n", "dnnlib.tflib.init_tf()\n", "with dnnlib.util.open_url(network_pkl) as f:\n", @@ -39,7 +40,7 @@ }, { "cell_type": "code", - "execution_count": 184, + "execution_count": 2, "metadata": { "scrolled": false }, @@ -47,7 +48,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "c77fb8afc211410b9743a12bf97938ae", + "model_id": "1782f66682b3476689255fbd89fb1bfc", "version_major": 2, "version_minor": 0 }, @@ -61,7 +62,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "e5eea77c8b1c4ba8a1de4e3df1f096aa", + "model_id": "f3041137c22d40c998d31193bfa89708", "version_major": 2, "version_minor": 0 }, @@ -91,9 +92,9 @@ ")\n", "seed = widgets.IntSlider(min=0, max=100000, step=1, value=9, description='Seed: ')\n", "scale = widgets.FloatSlider(min=0, max=25, step=0.1, value=2, description='Global Scale: ')\n", - "truncation = widgets.FloatSlider(min=-2, max=2, step=0.1, value=1, description='Truncation: ')\n", - "variance = widgets.FloatSlider(min=-1, max=1, step=0.1, value=0.4, description='Variance: ')\n", - "iterations = widgets.IntSlider(min=0, max=100, step=1, value=64, description='Iterations: ')\n", + "truncation = widgets.FloatSlider(min=-5, max=10, step=0.1, value=1, description='Truncation: ')\n", + "variance = widgets.FloatSlider(min=0, max=10, step=0.1, value=0.4, description='Variance: ')\n", + "iterations = widgets.IntSlider(min=0, max=1000, step=1, value=64, description='Iterations: ')\n", "top_box = widgets.HBox([artist, genre, style])\n", "mid_box = widgets.HBox([variance, iterations])\n", "bot_box = widgets.HBox([seed, scale, truncation])\n", @@ -109,19 +110,23 @@ " all_seeds = [seed] * batch_size\n", " all_z = np.stack([np.random.RandomState(seed).randn(*Gs.input_shape[1:]) for seed in all_seeds]) # [minibatch, component]\n", " all_w = Gs.components.mapping.run(all_z, l1) # [minibatch, layer, component]\n", - " #all_w = Gs.components.mapping.run(all_z, np.tile(l1, (batch_size, 1))) # [minibatch, layer, component]\n", - " if truncation != 1:\n", - " w_avg = Gs.get_var('dlatent_avg')\n", - " all_w = w_avg + (all_w - w_avg) * truncation # [minibatch, layer, component]\n", + " total = 0.0\n", + " acc_w = np.zeros((batch_size,18,512))\n", + " for i in range(400): # calculate approximate center\n", + " acc_w += Gs.components.mapping.run(0*all_z+np.random.RandomState(i).randn(512), np.tile(l1, (batch_size, 1))) # [minibatch, layer, component]\n", + " total+=1.0\n", + " acc_w /= total\n", + " w_avg = acc_w\n", " if variance == 0 or iterations < 1:\n", + " if truncation != 1:\n", + " all_w = w_avg + (all_w - w_avg) * truncation # [minibatch, layer, component]\n", " all_images = Gs.components.synthesis.run(all_w, **Gs_syn_kwargs)\n", " else:\n", " acc_w = np.zeros((batch_size,18,512))\n", " total = 0.0\n", " for i in range(iterations):\n", - " all_w = Gs.components.mapping.run(all_z + variance*np.random.RandomState(i).randn(512), np.tile(l1 + variance*np.random.RandomState(i).randn(167), (batch_size, 1))) # [minibatch, layer, component]\n", + " all_w = Gs.components.mapping.run(all_z + variance*np.random.RandomState(i).randn(512), np.tile(l1 + 0.1*variance*np.random.RandomState(i).randn(167), (batch_size, 1))) # [minibatch, layer, component]\n", " if truncation != 1:\n", - " w_avg = Gs.get_var('dlatent_avg')\n", " all_w = w_avg + (all_w - w_avg) * truncation # [minibatch, layer, component]\n", " acc_w += all_w\n", " total+=1.0\n", @@ -138,13 +143,13 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "4d7791b74ef0415dad6651e0f6cf0e80", + "model_id": "7a1eeb8cf0e048258111a8665e355eab", "version_major": 2, "version_minor": 0 }, @@ -158,7 +163,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "588de2e8dcf44ca78ef94461efd1725c", + "model_id": "b218e2d962f547af9c1982648354a799", "version_major": 2, "version_minor": 0 }, @@ -184,8 +189,7 @@ " all_seeds = [seed1] * batch_size\n", " all_z = np.stack([np.random.RandomState(seed).randn(*Gs.input_shape[1:]) for seed in all_seeds]) # [minibatch, component]\n", " all_l = scale * np.random.RandomState(seed2).randn(167)\n", - " #all_w = Gs.components.mapping.run(all_z, np.tile(all_l, (batch_size, 1))) # [minibatch, layer, component]\n", - " all_w = Gs.components.mapping.run(all_z, None) # [minibatch, layer, component]\n", + " all_w = Gs.components.mapping.run(all_z, np.tile(all_l, (batch_size, 1))) # [minibatch, layer, component]\n", " if truncation != 1:\n", " w_avg = Gs.get_var('dlatent_avg')\n", " all_w = w_avg + (all_w - w_avg) * truncation # [minibatch, layer, component]\n", diff --git a/metrics/metric_defaults.py b/metrics/metric_defaults.py index b456e9c6..040e527c 100755 --- a/metrics/metric_defaults.py +++ b/metrics/metric_defaults.py @@ -14,12 +14,14 @@ metric_defaults = EasyDict([(args.name, args) for args in [ # ADA paper. + EasyDict(name='fid10k_full', class_name='metrics.frechet_inception_distance.FID', max_reals=None, num_fakes=10000, minibatch_per_gpu=8, force_dataset_args=dict(shuffle=False, max_images=None, repeat=False, mirror_augment=False)), EasyDict(name='fid50k_full', class_name='metrics.frechet_inception_distance.FID', max_reals=None, num_fakes=50000, minibatch_per_gpu=8, force_dataset_args=dict(shuffle=False, max_images=None, repeat=False, mirror_augment=False)), EasyDict(name='kid50k_full', class_name='metrics.kernel_inception_distance.KID', max_reals=1000000, num_fakes=50000, minibatch_per_gpu=8, force_dataset_args=dict(shuffle=False, max_images=None, repeat=False, mirror_augment=False)), EasyDict(name='pr50k3_full', class_name='metrics.precision_recall.PR', max_reals=200000, num_fakes=50000, nhood_size=3, minibatch_per_gpu=8, row_batch_size=10000, col_batch_size=10000, force_dataset_args=dict(shuffle=False, max_images=None, repeat=False, mirror_augment=False)), EasyDict(name='is50k', class_name='metrics.inception_score.IS', num_images=50000, num_splits=10, minibatch_per_gpu=8, force_dataset_args=dict(shuffle=False, max_images=None)), # Legacy: StyleGAN2. + EasyDict(name='fid10k', class_name='metrics.frechet_inception_distance.FID', max_reals=10000, num_fakes=10000, minibatch_per_gpu=8, force_dataset_args=dict(shuffle=False, max_images=None)), EasyDict(name='fid50k', class_name='metrics.frechet_inception_distance.FID', max_reals=50000, num_fakes=50000, minibatch_per_gpu=8, force_dataset_args=dict(shuffle=False, max_images=None)), EasyDict(name='kid50k', class_name='metrics.kernel_inception_distance.KID', max_reals=50000, num_fakes=50000, minibatch_per_gpu=8, force_dataset_args=dict(shuffle=False, max_images=None)), EasyDict(name='pr50k3', class_name='metrics.precision_recall.PR', max_reals=50000, num_fakes=50000, nhood_size=3, minibatch_per_gpu=8, row_batch_size=10000, col_batch_size=10000, force_dataset_args=dict(shuffle=False, max_images=None)),