From d8157671dd0d4fbd013c0ed5792ca7bea5d30cdf Mon Sep 17 00:00:00 2001 From: Peter Baylies Date: Sun, 8 Nov 2020 15:36:25 -0500 Subject: [PATCH] Add example WikiArt notebook for conditional generation. --- WikiArt4 Example Generation.ipynb | 149 ++++++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 WikiArt4 Example Generation.ipynb diff --git a/WikiArt4 Example Generation.ipynb b/WikiArt4 Example Generation.ipynb new file mode 100644 index 00000000..18c86413 --- /dev/null +++ b/WikiArt4 Example Generation.ipynb @@ -0,0 +1,149 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Download the conditional WikiArt model from here: https://archive.org/download/wikiart-stylegan2-conditional-model/WikiArt4.pkl" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ipywidgets as widgets\n", + "import pickle\n", + "import math\n", + "import random\n", + "import PIL.Image\n", + "import numpy as np\n", + "import pickle\n", + "import dnnlib\n", + "import dnnlib.tflib as tflib\n", + "network_pkl = 'WikiArt4.pkl'\n", + "\n", + "dnnlib.tflib.init_tf()\n", + "with dnnlib.util.open_url(network_pkl) as f:\n", + " _G, _D, Gs = pickle.load(f)\n", + "\n", + "Gs_syn_kwargs = dnnlib.EasyDict()\n", + "batch_size = 8\n", + "Gs_syn_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)\n", + "Gs_syn_kwargs.randomize_noise = True\n", + "Gs_syn_kwargs.minibatch_size = batch_size" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "artist = widgets.Dropdown(\n", + " options=[('Unknown Artist', 0), ('Boris Kustodiev', 1), ('Camille Pissarro', 2), ('Childe Hassam', 3), ('Claude Monet', 4), ('Edgar Degas', 5), ('Eugene Boudin', 6), ('Gustave Dore', 7), ('Ilya Repin', 8), ('Ivan Aivazovsky', 9), ('Ivan Shishkin', 10), ('John Singer Sargent', 11), ('Marc Chagall', 12), ('Martiros Saryan', 13), ('Nicholas Roerich', 14), ('Pablo Picasso', 15), ('Paul Cezanne', 16), ('Pierre Auguste Renoir', 17), ('Pyotr Konchalovsky', 18), ('Raphael Kirchner', 19), ('Rembrandt', 20), ('Salvador Dali', 21), ('Vincent Van Gogh', 22), ('Hieronymus Bosch', 23), ('Leonardo Da Vinci', 24), ('Albrecht Durer', 25), ('Edouard Cortes', 26), ('Sam Francis', 27), ('Juan Gris', 28), ('Lucas Cranach The Elder', 29), ('Paul Gauguin', 30), ('Konstantin Makovsky', 31), ('Egon Schiele', 32), ('Thomas Eakins', 33), ('Gustave Moreau', 34), ('Francisco Goya', 35), ('Edvard Munch', 36), ('Henri Matisse', 37), ('Fra Angelico', 38), ('Maxime Maufra', 39), ('Jan Matejko', 40), ('Mstislav Dobuzhinsky', 41), ('Alfred Sisley', 42), ('Mary Cassatt', 43), ('Gustave Loiseau', 44), ('Fernando Botero', 45), ('Zinaida Serebriakova', 46), ('Georges Seurat', 47), ('Isaac Levitan', 48), ('Joaquã­n Sorolla', 49), ('Jacek Malczewski', 50), ('Berthe Morisot', 51), ('Andy Warhol', 52), ('Arkhip Kuindzhi', 53), ('Niko Pirosmani', 54), ('James Tissot', 55), ('Vasily Polenov', 56), ('Valentin Serov', 57), ('Pietro Perugino', 58), ('Pierre Bonnard', 59), ('Ferdinand Hodler', 60), ('Bartolome Esteban Murillo', 61), ('Giovanni Boldini', 62), ('Henri Martin', 63), ('Gustav Klimt', 64), ('Vasily Perov', 65), ('Odilon Redon', 66), ('Tintoretto', 67), ('Gene Davis', 68), ('Raphael', 69), ('John Henry Twachtman', 70), ('Henri De Toulouse Lautrec', 71), ('Antoine Blanchard', 72), ('David Burliuk', 73), ('Camille Corot', 74), ('Konstantin Korovin', 75), ('Ivan Bilibin', 76), ('Titian', 77), ('Maurice Prendergast', 78), ('Edouard Manet', 79), ('Peter Paul Rubens', 80), ('Aubrey Beardsley', 81), ('Paolo Veronese', 82), ('Joshua Reynolds', 83), ('Kuzma Petrov Vodkin', 84), ('Gustave Caillebotte', 85), ('Lucian Freud', 86), ('Michelangelo', 87), ('Dante Gabriel Rossetti', 88), ('Felix Vallotton', 89), ('Nikolay Bogdanov Belsky', 90), ('Georges Braque', 91), ('Vasily Surikov', 92), ('Fernand Leger', 93), ('Konstantin Somov', 94), ('Katsushika Hokusai', 95), ('Sir Lawrence Alma Tadema', 96), ('Vasily Vereshchagin', 97), ('Ernst Ludwig Kirchner', 98), ('Mikhail Vrubel', 99), ('Orest Kiprensky', 100), ('William Merritt Chase', 101), ('Aleksey Savrasov', 102), ('Hans Memling', 103), ('Amedeo Modigliani', 104), ('Ivan Kramskoy', 105), ('Utagawa Kuniyoshi', 106), ('Gustave Courbet', 107), ('William Turner', 108), ('Theo Van Rysselberghe', 109), ('Joseph Wright', 110), ('Edward Burne Jones', 111), ('Koloman Moser', 112), ('Viktor Vasnetsov', 113), ('Anthony Van Dyck', 114), ('Raoul Dufy', 115), ('Frans Hals', 116), ('Hans Holbein The Younger', 117), ('Ilya Mashkov', 118), ('Henri Fantin Latour', 119), ('M.C. Escher', 120), ('El Greco', 121), ('Mikalojus Ciurlionis', 122), ('James Mcneill Whistler', 123), ('Karl Bryullov', 124), ('Jacob Jordaens', 125), ('Thomas Gainsborough', 126), ('Eugene Delacroix', 127), ('Canaletto', 128)],\n", + " value=0,\n", + " description='Artist: '\n", + ")\n", + "genre = widgets.Dropdown(\n", + " options=[('Abstract Painting', 129), ('Cityscape', 130), ('Genre Painting', 131), ('Illustration', 132), ('Landscape', 133), ('Nude Painting', 134), ('Portrait', 135), ('Religious Painting', 136), ('Sketch And Study', 137), ('Still Life', 138), ('Unknown Genre', 139)],\n", + " value=139,\n", + " description='Genre: '\n", + ")\n", + "style = widgets.Dropdown(\n", + " options=[('Abstract Expressionism', 140), ('Action Painting', 141), ('Analytical Cubism', 142), ('Art Nouveau', 143), ('Baroque', 144), ('Color Field Painting', 145), ('Contemporary Realism', 146), ('Cubism', 147), ('Early Renaissance', 148), ('Expressionism', 149), ('Fauvism', 150), ('High Renaissance', 151), ('Impressionism', 152), ('Mannerism Late Renaissance', 153), ('Minimalism', 154), ('Naive Art Primitivism', 155), ('New Realism', 156), ('Northern Renaissance', 157), ('Pointillism', 158), ('Pop Art', 159), ('Post Impressionism', 160), ('Realism', 161), ('Rococo', 162), ('Romanticism', 163), ('Symbolism', 164), ('Synthetic Cubism', 165), ('Ukiyo-e', 166)],\n", + " value=140,\n", + " description='Style: '\n", + ")\n", + "seed = widgets.IntSlider(min=0, max=100000, step=1, value=0, description='Seed: ')\n", + "scale = widgets.FloatSlider(min=0, max=25, step=0.1, value=5, description='Global Scale: ')\n", + "truncation = widgets.FloatSlider(min=-2, max=2, step=0.1, value=1, description='Truncation: ')\n", + "scale1 = widgets.FloatSlider(min=0, max=5, step=0.01, value=1, description='Scale 1: ')\n", + "scale2 = widgets.FloatSlider(min=0, max=5, step=0.01, value=1, description='Scale 2: ')\n", + "scale3 = widgets.FloatSlider(min=0, max=5, step=0.01, value=1, description='Scale 3: ')\n", + "top_box = widgets.HBox([artist, genre, style])\n", + "mid_box = widgets.HBox([scale1, scale2, scale3])\n", + "bot_box = widgets.HBox([seed, scale, truncation])\n", + "ui = widgets.VBox([top_box, mid_box, bot_box])\n", + "\n", + "def display_sample(artist, genre, style, scale1, scale2, scale3, seed, scale, truncation):\n", + " batch_size = 1\n", + " l1 = np.zeros((1,167))\n", + " l1[0][artist] = scale1\n", + " l1[0][genre] = scale2\n", + " l1[0][style] = scale3\n", + " l1 = scale * l1\n", + " all_seeds = [seed] * batch_size\n", + " all_z = np.stack([np.random.RandomState(seed).randn(*Gs.input_shape[1:]) for seed in all_seeds]) # [minibatch, component]\n", + " all_w = Gs.components.mapping.run(all_z, np.tile(l1, (batch_size, 1))) # [minibatch, layer, component]\n", + " if truncation != 1:\n", + " w_avg = Gs.get_var('dlatent_avg')\n", + " all_w = w_avg + (all_w - w_avg) * truncation # [minibatch, layer, component]\n", + " all_w = all_w\n", + " all_images = Gs.components.synthesis.run(all_w, **Gs_syn_kwargs)\n", + " display(PIL.Image.fromarray(np.median(all_images, axis=0).astype(np.uint8)))\n", + "\n", + "out = widgets.interactive_output(display_sample, {'artist': artist, 'genre': genre, 'style': style, 'seed': seed, 'scale1': scale1, 'scale2': scale2, 'scale3': scale3, 'scale': scale, 'truncation': truncation})\n", + "\n", + "display(ui, out)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "seed1 = widgets.IntSlider(min=0, max=100000, step=1, value=0, description='Seed 1: ')\n", + "seed2 = widgets.IntSlider(min=0, max=100000, step=1, value=0, description='Seed 2: ')\n", + "scale = widgets.FloatSlider(min=-10, max=10, step=0.1, value=0, description='Scale: ')\n", + "truncation = widgets.FloatSlider(min=-2, max=2, step=0.1, value=1, description='Truncation: ')\n", + "top_box = widgets.HBox([seed1, seed2])\n", + "bot_box = widgets.HBox([scale, truncation])\n", + "ui = widgets.VBox([top_box, bot_box])\n", + "\n", + "def display_sample(seed1, seed2, scale, truncation):\n", + " batch_size = 1\n", + " all_seeds = [seed1] * batch_size\n", + " all_z = np.stack([np.random.RandomState(seed).randn(*Gs.input_shape[1:]) for seed in all_seeds]) # [minibatch, component]\n", + " all_l = scale * np.random.RandomState(seed2).randn(167)\n", + " #all_w = Gs.components.mapping.run(all_z, np.tile(all_l, (batch_size, 1))) # [minibatch, layer, component]\n", + " all_w = Gs.components.mapping.run(all_z, None) # [minibatch, layer, component]\n", + " if truncation != 1:\n", + " w_avg = Gs.get_var('dlatent_avg')\n", + " all_w = w_avg + (all_w - w_avg) * truncation # [minibatch, layer, component]\n", + " all_w = all_w\n", + " all_images = Gs.components.synthesis.run(all_w, **Gs_syn_kwargs)\n", + " display(PIL.Image.fromarray(np.median(all_images, axis=0).astype(np.uint8)))\n", + "\n", + "out = widgets.interactive_output(display_sample, {'seed1': seed1, 'seed2': seed2, 'scale': scale, 'truncation': truncation})\n", + "\n", + "display(ui, out)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}