Skip to content

Commit

Permalink
Add example WikiArt notebook for conditional generation.
Browse files Browse the repository at this point in the history
  • Loading branch information
pbaylies committed Nov 8, 2020
1 parent 5ed6189 commit d815767
Showing 1 changed file with 149 additions and 0 deletions.
149 changes: 149 additions & 0 deletions WikiArt4 Example Generation.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Download the conditional WikiArt model from here: https://archive.org/download/wikiart-stylegan2-conditional-model/WikiArt4.pkl"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ipywidgets as widgets\n",
"import pickle\n",
"import math\n",
"import random\n",
"import PIL.Image\n",
"import numpy as np\n",
"import pickle\n",
"import dnnlib\n",
"import dnnlib.tflib as tflib\n",
"network_pkl = 'WikiArt4.pkl'\n",
"\n",
"dnnlib.tflib.init_tf()\n",
"with dnnlib.util.open_url(network_pkl) as f:\n",
" _G, _D, Gs = pickle.load(f)\n",
"\n",
"Gs_syn_kwargs = dnnlib.EasyDict()\n",
"batch_size = 8\n",
"Gs_syn_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)\n",
"Gs_syn_kwargs.randomize_noise = True\n",
"Gs_syn_kwargs.minibatch_size = batch_size"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"artist = widgets.Dropdown(\n",
" options=[('Unknown Artist', 0), ('Boris Kustodiev', 1), ('Camille Pissarro', 2), ('Childe Hassam', 3), ('Claude Monet', 4), ('Edgar Degas', 5), ('Eugene Boudin', 6), ('Gustave Dore', 7), ('Ilya Repin', 8), ('Ivan Aivazovsky', 9), ('Ivan Shishkin', 10), ('John Singer Sargent', 11), ('Marc Chagall', 12), ('Martiros Saryan', 13), ('Nicholas Roerich', 14), ('Pablo Picasso', 15), ('Paul Cezanne', 16), ('Pierre Auguste Renoir', 17), ('Pyotr Konchalovsky', 18), ('Raphael Kirchner', 19), ('Rembrandt', 20), ('Salvador Dali', 21), ('Vincent Van Gogh', 22), ('Hieronymus Bosch', 23), ('Leonardo Da Vinci', 24), ('Albrecht Durer', 25), ('Edouard Cortes', 26), ('Sam Francis', 27), ('Juan Gris', 28), ('Lucas Cranach The Elder', 29), ('Paul Gauguin', 30), ('Konstantin Makovsky', 31), ('Egon Schiele', 32), ('Thomas Eakins', 33), ('Gustave Moreau', 34), ('Francisco Goya', 35), ('Edvard Munch', 36), ('Henri Matisse', 37), ('Fra Angelico', 38), ('Maxime Maufra', 39), ('Jan Matejko', 40), ('Mstislav Dobuzhinsky', 41), ('Alfred Sisley', 42), ('Mary Cassatt', 43), ('Gustave Loiseau', 44), ('Fernando Botero', 45), ('Zinaida Serebriakova', 46), ('Georges Seurat', 47), ('Isaac Levitan', 48), ('Joaquã­n Sorolla', 49), ('Jacek Malczewski', 50), ('Berthe Morisot', 51), ('Andy Warhol', 52), ('Arkhip Kuindzhi', 53), ('Niko Pirosmani', 54), ('James Tissot', 55), ('Vasily Polenov', 56), ('Valentin Serov', 57), ('Pietro Perugino', 58), ('Pierre Bonnard', 59), ('Ferdinand Hodler', 60), ('Bartolome Esteban Murillo', 61), ('Giovanni Boldini', 62), ('Henri Martin', 63), ('Gustav Klimt', 64), ('Vasily Perov', 65), ('Odilon Redon', 66), ('Tintoretto', 67), ('Gene Davis', 68), ('Raphael', 69), ('John Henry Twachtman', 70), ('Henri De Toulouse Lautrec', 71), ('Antoine Blanchard', 72), ('David Burliuk', 73), ('Camille Corot', 74), ('Konstantin Korovin', 75), ('Ivan Bilibin', 76), ('Titian', 77), ('Maurice Prendergast', 78), ('Edouard Manet', 79), ('Peter Paul Rubens', 80), ('Aubrey Beardsley', 81), ('Paolo Veronese', 82), ('Joshua Reynolds', 83), ('Kuzma Petrov Vodkin', 84), ('Gustave Caillebotte', 85), ('Lucian Freud', 86), ('Michelangelo', 87), ('Dante Gabriel Rossetti', 88), ('Felix Vallotton', 89), ('Nikolay Bogdanov Belsky', 90), ('Georges Braque', 91), ('Vasily Surikov', 92), ('Fernand Leger', 93), ('Konstantin Somov', 94), ('Katsushika Hokusai', 95), ('Sir Lawrence Alma Tadema', 96), ('Vasily Vereshchagin', 97), ('Ernst Ludwig Kirchner', 98), ('Mikhail Vrubel', 99), ('Orest Kiprensky', 100), ('William Merritt Chase', 101), ('Aleksey Savrasov', 102), ('Hans Memling', 103), ('Amedeo Modigliani', 104), ('Ivan Kramskoy', 105), ('Utagawa Kuniyoshi', 106), ('Gustave Courbet', 107), ('William Turner', 108), ('Theo Van Rysselberghe', 109), ('Joseph Wright', 110), ('Edward Burne Jones', 111), ('Koloman Moser', 112), ('Viktor Vasnetsov', 113), ('Anthony Van Dyck', 114), ('Raoul Dufy', 115), ('Frans Hals', 116), ('Hans Holbein The Younger', 117), ('Ilya Mashkov', 118), ('Henri Fantin Latour', 119), ('M.C. Escher', 120), ('El Greco', 121), ('Mikalojus Ciurlionis', 122), ('James Mcneill Whistler', 123), ('Karl Bryullov', 124), ('Jacob Jordaens', 125), ('Thomas Gainsborough', 126), ('Eugene Delacroix', 127), ('Canaletto', 128)],\n",
" value=0,\n",
" description='Artist: '\n",
")\n",
"genre = widgets.Dropdown(\n",
" options=[('Abstract Painting', 129), ('Cityscape', 130), ('Genre Painting', 131), ('Illustration', 132), ('Landscape', 133), ('Nude Painting', 134), ('Portrait', 135), ('Religious Painting', 136), ('Sketch And Study', 137), ('Still Life', 138), ('Unknown Genre', 139)],\n",
" value=139,\n",
" description='Genre: '\n",
")\n",
"style = widgets.Dropdown(\n",
" options=[('Abstract Expressionism', 140), ('Action Painting', 141), ('Analytical Cubism', 142), ('Art Nouveau', 143), ('Baroque', 144), ('Color Field Painting', 145), ('Contemporary Realism', 146), ('Cubism', 147), ('Early Renaissance', 148), ('Expressionism', 149), ('Fauvism', 150), ('High Renaissance', 151), ('Impressionism', 152), ('Mannerism Late Renaissance', 153), ('Minimalism', 154), ('Naive Art Primitivism', 155), ('New Realism', 156), ('Northern Renaissance', 157), ('Pointillism', 158), ('Pop Art', 159), ('Post Impressionism', 160), ('Realism', 161), ('Rococo', 162), ('Romanticism', 163), ('Symbolism', 164), ('Synthetic Cubism', 165), ('Ukiyo-e', 166)],\n",
" value=140,\n",
" description='Style: '\n",
")\n",
"seed = widgets.IntSlider(min=0, max=100000, step=1, value=0, description='Seed: ')\n",
"scale = widgets.FloatSlider(min=0, max=25, step=0.1, value=5, description='Global Scale: ')\n",
"truncation = widgets.FloatSlider(min=-2, max=2, step=0.1, value=1, description='Truncation: ')\n",
"scale1 = widgets.FloatSlider(min=0, max=5, step=0.01, value=1, description='Scale 1: ')\n",
"scale2 = widgets.FloatSlider(min=0, max=5, step=0.01, value=1, description='Scale 2: ')\n",
"scale3 = widgets.FloatSlider(min=0, max=5, step=0.01, value=1, description='Scale 3: ')\n",
"top_box = widgets.HBox([artist, genre, style])\n",
"mid_box = widgets.HBox([scale1, scale2, scale3])\n",
"bot_box = widgets.HBox([seed, scale, truncation])\n",
"ui = widgets.VBox([top_box, mid_box, bot_box])\n",
"\n",
"def display_sample(artist, genre, style, scale1, scale2, scale3, seed, scale, truncation):\n",
" batch_size = 1\n",
" l1 = np.zeros((1,167))\n",
" l1[0][artist] = scale1\n",
" l1[0][genre] = scale2\n",
" l1[0][style] = scale3\n",
" l1 = scale * l1\n",
" all_seeds = [seed] * batch_size\n",
" all_z = np.stack([np.random.RandomState(seed).randn(*Gs.input_shape[1:]) for seed in all_seeds]) # [minibatch, component]\n",
" all_w = Gs.components.mapping.run(all_z, np.tile(l1, (batch_size, 1))) # [minibatch, layer, component]\n",
" if truncation != 1:\n",
" w_avg = Gs.get_var('dlatent_avg')\n",
" all_w = w_avg + (all_w - w_avg) * truncation # [minibatch, layer, component]\n",
" all_w = all_w\n",
" all_images = Gs.components.synthesis.run(all_w, **Gs_syn_kwargs)\n",
" display(PIL.Image.fromarray(np.median(all_images, axis=0).astype(np.uint8)))\n",
"\n",
"out = widgets.interactive_output(display_sample, {'artist': artist, 'genre': genre, 'style': style, 'seed': seed, 'scale1': scale1, 'scale2': scale2, 'scale3': scale3, 'scale': scale, 'truncation': truncation})\n",
"\n",
"display(ui, out)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"seed1 = widgets.IntSlider(min=0, max=100000, step=1, value=0, description='Seed 1: ')\n",
"seed2 = widgets.IntSlider(min=0, max=100000, step=1, value=0, description='Seed 2: ')\n",
"scale = widgets.FloatSlider(min=-10, max=10, step=0.1, value=0, description='Scale: ')\n",
"truncation = widgets.FloatSlider(min=-2, max=2, step=0.1, value=1, description='Truncation: ')\n",
"top_box = widgets.HBox([seed1, seed2])\n",
"bot_box = widgets.HBox([scale, truncation])\n",
"ui = widgets.VBox([top_box, bot_box])\n",
"\n",
"def display_sample(seed1, seed2, scale, truncation):\n",
" batch_size = 1\n",
" all_seeds = [seed1] * batch_size\n",
" all_z = np.stack([np.random.RandomState(seed).randn(*Gs.input_shape[1:]) for seed in all_seeds]) # [minibatch, component]\n",
" all_l = scale * np.random.RandomState(seed2).randn(167)\n",
" #all_w = Gs.components.mapping.run(all_z, np.tile(all_l, (batch_size, 1))) # [minibatch, layer, component]\n",
" all_w = Gs.components.mapping.run(all_z, None) # [minibatch, layer, component]\n",
" if truncation != 1:\n",
" w_avg = Gs.get_var('dlatent_avg')\n",
" all_w = w_avg + (all_w - w_avg) * truncation # [minibatch, layer, component]\n",
" all_w = all_w\n",
" all_images = Gs.components.synthesis.run(all_w, **Gs_syn_kwargs)\n",
" display(PIL.Image.fromarray(np.median(all_images, axis=0).astype(np.uint8)))\n",
"\n",
"out = widgets.interactive_output(display_sample, {'seed1': seed1, 'seed2': seed2, 'scale': scale, 'truncation': truncation})\n",
"\n",
"display(ui, out)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit d815767

Please sign in to comment.