Add fid10k metrics, updates for variance and truncation with labels in the notebook.
pbaylies committed Dec 19, 2020
1 parent 6004052 · commit 1a61a69
Showing 2 changed files with 23 additions and 17 deletions.
38 changes: 21 additions & 17 deletions WikiArt Example Generation.ipynb
@@ -25,6 +25,7 @@
"import dnnlib\n",
"import dnnlib.tflib as tflib\n",
"network_pkl = 'WikiArt5.pkl'\n",
"#network_pkl = 'WikiArt_Uncond2.pkl'\n",
"\n",
"dnnlib.tflib.init_tf()\n",
"with dnnlib.util.open_url(network_pkl) as f:\n",
@@ -39,15 +40,15 @@
},
{
"cell_type": "code",
"execution_count": 184,
"execution_count": 2,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c77fb8afc211410b9743a12bf97938ae",
"model_id": "1782f66682b3476689255fbd89fb1bfc",
"version_major": 2,
"version_minor": 0
},
@@ -61,7 +62,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e5eea77c8b1c4ba8a1de4e3df1f096aa",
"model_id": "f3041137c22d40c998d31193bfa89708",
"version_major": 2,
"version_minor": 0
},
@@ -91,9 +92,9 @@
")\n",
"seed = widgets.IntSlider(min=0, max=100000, step=1, value=9, description='Seed: ')\n",
"scale = widgets.FloatSlider(min=0, max=25, step=0.1, value=2, description='Global Scale: ')\n",
"truncation = widgets.FloatSlider(min=-2, max=2, step=0.1, value=1, description='Truncation: ')\n",
"variance = widgets.FloatSlider(min=-1, max=1, step=0.1, value=0.4, description='Variance: ')\n",
"iterations = widgets.IntSlider(min=0, max=100, step=1, value=64, description='Iterations: ')\n",
"truncation = widgets.FloatSlider(min=-5, max=10, step=0.1, value=1, description='Truncation: ')\n",
"variance = widgets.FloatSlider(min=0, max=10, step=0.1, value=0.4, description='Variance: ')\n",
"iterations = widgets.IntSlider(min=0, max=1000, step=1, value=64, description='Iterations: ')\n",
"top_box = widgets.HBox([artist, genre, style])\n",
"mid_box = widgets.HBox([variance, iterations])\n",
"bot_box = widgets.HBox([seed, scale, truncation])\n",
@@ -109,19 +110,23 @@
" all_seeds = [seed] * batch_size\n",
" all_z = np.stack([np.random.RandomState(seed).randn(*Gs.input_shape[1:]) for seed in all_seeds]) # [minibatch, component]\n",
" all_w = Gs.components.mapping.run(all_z, l1) # [minibatch, layer, component]\n",
" #all_w = Gs.components.mapping.run(all_z, np.tile(l1, (batch_size, 1))) # [minibatch, layer, component]\n",
" if truncation != 1:\n",
" w_avg = Gs.get_var('dlatent_avg')\n",
" all_w = w_avg + (all_w - w_avg) * truncation # [minibatch, layer, component]\n",
" total = 0.0\n",
" acc_w = np.zeros((batch_size,18,512))\n",
" for i in range(400): # calculate approximate center\n",
" acc_w += Gs.components.mapping.run(0*all_z+np.random.RandomState(i).randn(512), np.tile(l1, (batch_size, 1))) # [minibatch, layer, component]\n",
" total+=1.0\n",
" acc_w /= total\n",
" w_avg = acc_w\n",
" if variance == 0 or iterations < 1:\n",
" if truncation != 1:\n",
" all_w = w_avg + (all_w - w_avg) * truncation # [minibatch, layer, component]\n",
" all_images = Gs.components.synthesis.run(all_w, **Gs_syn_kwargs)\n",
" else:\n",
" acc_w = np.zeros((batch_size,18,512))\n",
" total = 0.0\n",
" for i in range(iterations):\n",
" all_w = Gs.components.mapping.run(all_z + variance*np.random.RandomState(i).randn(512), np.tile(l1 + variance*np.random.RandomState(i).randn(167), (batch_size, 1))) # [minibatch, layer, component]\n",
" all_w = Gs.components.mapping.run(all_z + variance*np.random.RandomState(i).randn(512), np.tile(l1 + 0.1*variance*np.random.RandomState(i).randn(167), (batch_size, 1))) # [minibatch, layer, component]\n",
" if truncation != 1:\n",
" w_avg = Gs.get_var('dlatent_avg')\n",
" all_w = w_avg + (all_w - w_avg) * truncation # [minibatch, layer, component]\n",
" acc_w += all_w\n",
" total+=1.0\n",
@@ -138,13 +143,13 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4d7791b74ef0415dad6651e0f6cf0e80",
"model_id": "7a1eeb8cf0e048258111a8665e355eab",
"version_major": 2,
"version_minor": 0
},
@@ -158,7 +163,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "588de2e8dcf44ca78ef94461efd1725c",
"model_id": "b218e2d962f547af9c1982648354a799",
"version_major": 2,
"version_minor": 0
},
@@ -184,8 +189,7 @@
" all_seeds = [seed1] * batch_size\n",
" all_z = np.stack([np.random.RandomState(seed).randn(*Gs.input_shape[1:]) for seed in all_seeds]) # [minibatch, component]\n",
" all_l = scale * np.random.RandomState(seed2).randn(167)\n",
" #all_w = Gs.components.mapping.run(all_z, np.tile(all_l, (batch_size, 1))) # [minibatch, layer, component]\n",
" all_w = Gs.components.mapping.run(all_z, None) # [minibatch, layer, component]\n",
" all_w = Gs.components.mapping.run(all_z, np.tile(all_l, (batch_size, 1))) # [minibatch, layer, component]\n",
" if truncation != 1:\n",
" w_avg = Gs.get_var('dlatent_avg')\n",
" all_w = w_avg + (all_w - w_avg) * truncation # [minibatch, layer, component]\n",
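The substantive change in this notebook cell is how truncation interacts with the label input: instead of pulling W toward the global `dlatent_avg` variable, the updated code estimates an approximate, label-conditioned center of W by averaging 400 mapping-network outputs and truncates toward that. Below is a minimal, untested sketch of the same idea, assuming the `Gs` generator from the notebook is loaded, `label` is the 167-dim conditioning vector (`l1`), and the WikiArt model's usual 512-dim Z / 18-layer W shapes:

```python
import numpy as np

def label_conditioned_truncation(Gs, all_z, label, truncation=0.7,
                                 n_center_samples=400, batch_size=1):
    """Truncation trick toward an approximate, label-conditioned center of W."""
    label_batch = np.tile(label, (batch_size, 1))

    # Map the input latents to W under the given label.  [minibatch, layer, component]
    all_w = Gs.components.mapping.run(all_z, label_batch)

    # Estimate the center of W for this label by averaging many mapped samples,
    # instead of reading the global 'dlatent_avg' variable.
    acc_w = np.zeros_like(all_w)
    for i in range(n_center_samples):
        z = np.tile(np.random.RandomState(i).randn(512), (batch_size, 1))  # assumes Z dim 512
        acc_w += Gs.components.mapping.run(z, label_batch)
    w_center = acc_w / n_center_samples

    # Standard truncation trick, but toward the estimated per-label center.
    if truncation != 1:
        all_w = w_center + (all_w - w_center) * truncation
    return all_w
```

Presumably the motivation is that for a conditional model the global average latent mixes all artists/genres/styles, so truncating toward a per-label center keeps strongly truncated samples representative of the chosen labels rather than of an unconditional mean.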
2 changes: 2 additions & 0 deletions metrics/metric_defaults.py
@@ -14,12 +14,14 @@

metric_defaults = EasyDict([(args.name, args) for args in [
# ADA paper.
EasyDict(name='fid10k_full', class_name='metrics.frechet_inception_distance.FID', max_reals=None, num_fakes=10000, minibatch_per_gpu=8, force_dataset_args=dict(shuffle=False, max_images=None, repeat=False, mirror_augment=False)),
EasyDict(name='fid50k_full', class_name='metrics.frechet_inception_distance.FID', max_reals=None, num_fakes=50000, minibatch_per_gpu=8, force_dataset_args=dict(shuffle=False, max_images=None, repeat=False, mirror_augment=False)),
EasyDict(name='kid50k_full', class_name='metrics.kernel_inception_distance.KID', max_reals=1000000, num_fakes=50000, minibatch_per_gpu=8, force_dataset_args=dict(shuffle=False, max_images=None, repeat=False, mirror_augment=False)),
EasyDict(name='pr50k3_full', class_name='metrics.precision_recall.PR', max_reals=200000, num_fakes=50000, nhood_size=3, minibatch_per_gpu=8, row_batch_size=10000, col_batch_size=10000, force_dataset_args=dict(shuffle=False, max_images=None, repeat=False, mirror_augment=False)),
EasyDict(name='is50k', class_name='metrics.inception_score.IS', num_images=50000, num_splits=10, minibatch_per_gpu=8, force_dataset_args=dict(shuffle=False, max_images=None)),

# Legacy: StyleGAN2.
EasyDict(name='fid10k', class_name='metrics.frechet_inception_distance.FID', max_reals=10000, num_fakes=10000, minibatch_per_gpu=8, force_dataset_args=dict(shuffle=False, max_images=None)),
EasyDict(name='fid50k', class_name='metrics.frechet_inception_distance.FID', max_reals=50000, num_fakes=50000, minibatch_per_gpu=8, force_dataset_args=dict(shuffle=False, max_images=None)),
EasyDict(name='kid50k', class_name='metrics.kernel_inception_distance.KID', max_reals=50000, num_fakes=50000, minibatch_per_gpu=8, force_dataset_args=dict(shuffle=False, max_images=None)),
EasyDict(name='pr50k3', class_name='metrics.precision_recall.PR', max_reals=50000, num_fakes=50000, nhood_size=3, minibatch_per_gpu=8, row_batch_size=10000, col_batch_size=10000, force_dataset_args=dict(shuffle=False, max_images=None)),
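For reference, `metric_defaults` is an `EasyDict` keyed by each entry's `name`, so the two new FID variants can be looked up and inspected like any existing metric config. A small sketch, assuming the repository root is on the Python path (the fork's training/metric-runner command line isn't shown in this diff, so it isn't guessed at here):

```python
from metrics.metric_defaults import metric_defaults

# The new entries are ordinary name-keyed configs, alongside fid50k etc.
fid10k      = metric_defaults['fid10k']       # legacy-style: 10k reals vs. 10k fakes
fid10k_full = metric_defaults['fid10k_full']  # ADA-style dataset args: all reals, 10k fakes

print(fid10k.class_name)                             # metrics.frechet_inception_distance.FID
print(fid10k.max_reals, fid10k.num_fakes)            # 10000 10000
print(fid10k_full.max_reals, fid10k_full.num_fakes)  # None 10000
```

The practical point of the 10k variants is cost: generating 10,000 fakes instead of 50,000 makes each FID evaluation roughly five times cheaper, which suits quick progress checks during training.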
