Skip to content

Commit

Permalink
Add more sliders for style mixing.
Browse files Browse the repository at this point in the history
  • Loading branch information
pbaylies committed Dec 22, 2020
1 parent 1a61a69 commit a8f0b1c
Showing 1 changed file with 29 additions and 13 deletions.
42 changes: 29 additions & 13 deletions WikiArt Example Generation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,16 @@
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/pb/anaconda3/envs/tf/lib/python3.6/site-packages/requests/__init__.py:91: RequestsDependencyWarning: urllib3 (1.26.2) or chardet (3.0.4) doesn't match a supported version!\n",
" RequestsDependencyWarning)\n"
]
}
],
"source": [
"import ipywidgets as widgets\n",
"import pickle\n",
Expand Down Expand Up @@ -143,18 +152,18 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7a1eeb8cf0e048258111a8665e355eab",
"model_id": "7c67b87b9a924d8cb1976c3447a2dcd3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(children=(HBox(children=(IntSlider(value=0, description='Seed 1: ', max=100000), IntSlider(value=0, descr"
"VBox(children=(HBox(children=(IntSlider(value=0, description='Content Seed: ', max=100000), IntSlider(value=0,…"
]
},
"metadata": {},
Expand All @@ -163,7 +172,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b218e2d962f547af9c1982648354a799",
"model_id": "dab59134d449418a9f63a1a064fdbfe6",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -176,31 +185,38 @@
}
],
"source": [
"seed1 = widgets.IntSlider(min=0, max=100000, step=1, value=0, description='Seed 1: ')\n",
"seed2 = widgets.IntSlider(min=0, max=100000, step=1, value=0, description='Seed 2: ')\n",
"seed1 = widgets.IntSlider(min=0, max=100000, step=1, value=0, description='Content Seed: ')\n",
"seed2 = widgets.IntSlider(min=0, max=100000, step=1, value=0, description='Content Label: ')\n",
"seed1b = widgets.IntSlider(min=0, max=100000, step=1, value=0, description='Style Seed: ')\n",
"seed2b = widgets.IntSlider(min=0, max=100000, step=1, value=0, description='Style Label: ')\n",
"scale = widgets.FloatSlider(min=-5, max=5, step=0.05, value=0, description='Scale: ')\n",
"truncation = widgets.FloatSlider(min=-2, max=2, step=0.1, value=1, description='Truncation: ')\n",
"top_box = widgets.HBox([seed1, seed2])\n",
"mid_box = widgets.HBox([seed1b, seed2b])\n",
"bot_box = widgets.HBox([scale, truncation])\n",
"ui = widgets.VBox([top_box, bot_box])\n",
"ui = widgets.VBox([top_box, mid_box, bot_box])\n",
"\n",
"def display_sample(seed1, seed2, scale, truncation):\n",
"def display_sample(seed1, seed2, seed1b, seed2b, scale, truncation):\n",
" batch_size = 1\n",
" all_seeds = [seed1] * batch_size\n",
" all_seedsb = [seed1b] * batch_size\n",
" all_z = np.stack([np.random.RandomState(seed).randn(*Gs.input_shape[1:]) for seed in all_seeds]) # [minibatch, component]\n",
" all_zb = np.stack([np.random.RandomState(seed).randn(*Gs.input_shape[1:]) for seed in all_seedsb]) # [minibatch, component]\n",
" all_l = scale * np.random.RandomState(seed2).randn(167)\n",
" all_lb = scale * np.random.RandomState(seed2b).randn(167)\n",
" all_w = Gs.components.mapping.run(all_z, np.tile(all_l, (batch_size, 1))) # [minibatch, layer, component]\n",
" all_wb = Gs.components.mapping.run(all_zb, np.tile(all_lb, (batch_size, 1))) # [minibatch, layer, component]\n",
" if truncation != 1:\n",
" w_avg = Gs.get_var('dlatent_avg')\n",
" all_w = w_avg + (all_w - w_avg) * truncation # [minibatch, layer, component]\n",
" all_w = all_w\n",
" all_wb = w_avg + (all_wb - w_avg) * truncation # [minibatch, layer, component]\n",
" all_w = np.concatenate((all_w[:,0:9,:], all_wb[:,9:18,:]), axis=1)\n",
" all_images = Gs.components.synthesis.run(all_w, **Gs_syn_kwargs)\n",
" display(PIL.Image.fromarray(np.median(all_images, axis=0).astype(np.uint8)))\n",
"\n",
"out = widgets.interactive_output(display_sample, {'seed1': seed1, 'seed2': seed2, 'scale': scale, 'truncation': truncation})\n",
"out = widgets.interactive_output(display_sample, {'seed1': seed1, 'seed2': seed2, 'seed1b': seed1b, 'seed2b': seed2b, 'scale': scale, 'truncation': truncation})\n",
"\n",
"display(ui, out)\n",
"# Alternatively, less sliders to play with..."
"display(ui, out)"
]
},
{
Expand Down

0 comments on commit a8f0b1c

Please sign in to comment.