forked from Puzer/stylegan-encoder
-
Notifications
You must be signed in to change notification settings - Fork 0
/
encode_images.py
112 lines (83 loc) · 5.03 KB
/
encode_images.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import argparse
import pickle
from tqdm import tqdm
import PIL.Image
import numpy as np
import dnnlib
import dnnlib.tflib as tflib
import config
from encoder.generator_model import Generator
from encoder.perceptual_model import PerceptualModel
URL_FFHQ = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl
def split_to_batches(l, n):
for i in range(0, len(l), n):
yield l[i:i + n]
def make_checkpoint(counter, generator, names, generated_images_dir, dlatent_dir):
# max iterations is, like, 500,000, so six digits should be enough.
# Just to be safe, 8.
counterstring = format(counter, '08d')
generated_images = generator.generate_images()
generated_dlatents = generator.get_dlatents()
for img_array, dlatent, img_name in zip(generated_images, generated_dlatents, names):
img = PIL.Image.fromarray(img_array, 'RGB')
img.save(os.path.join(generated_images_dir, f'{counterstring}_iterations_{img_name}.png'), 'PNG')
np.save(os.path.join(dlatent_dir, f'{counterstring}_iterations_{img_name}.npy'), dlatent)
return True
def main():
parser = argparse.ArgumentParser(description='Find latent representation of reference images using perceptual loss')
parser.add_argument('src_dir', help='Directory with images for encoding')
parser.add_argument('generated_images_dir', help='Directory for storing generated images')
parser.add_argument('dlatent_dir', help='Directory for storing dlatent representations')
# for now it's unclear if larger batch leads to better performance/quality
parser.add_argument('--batch_size', default=1, help='Batch size for generator and perceptual model', type=int)
# Perceptual model params
parser.add_argument('--image_size', default=256, help='Size of images for perceptual model', type=int)
parser.add_argument('--lr', default=1., help='Learning rate for perceptual model', type=float)
parser.add_argument('--iterations', default=1000, help='Number of optimization steps for each batch', type=int)
# Generator params
parser.add_argument('--randomize_noise', default=False, help='Add noise to dlatents during optimization', type=bool)
args, other_args = parser.parse_known_args()
ref_images = [os.path.join(args.src_dir, x) for x in os.listdir(args.src_dir)]
ref_images = list(filter(os.path.isfile, ref_images))
if len(ref_images) == 0:
raise Exception('%s is empty' % args.src_dir)
os.makedirs(args.generated_images_dir, exist_ok=True)
os.makedirs(args.dlatent_dir, exist_ok=True)
# Initialize generator and perceptual model
tflib.init_tf()
with dnnlib.util.open_url(URL_FFHQ, cache_dir=config.cache_dir) as f:
generator_network, discriminator_network, Gs_network = pickle.load(f)
generator = Generator(Gs_network, args.batch_size, randomize_noise=args.randomize_noise)
#TODO: load dlatents here to pick up training if interrupted.
# latent = np.load('filename.npy')
# generator.set_dlatents()
perceptual_model = PerceptualModel(args.image_size, layer=9, batch_size=args.batch_size)
perceptual_model.build_perceptual_model(generator.generated_image)
# Optimize (only) dlatents by minimizing perceptual loss between reference and generated images in feature space
counter = 0
for images_batch in tqdm(split_to_batches(ref_images, args.batch_size), total=len(ref_images)//args.batch_size):
names = [os.path.splitext(os.path.basename(x))[0] for x in images_batch]
perceptual_model.set_reference_images(images_batch)
op = perceptual_model.optimize(generator.dlatent_variable, iterations=args.iterations, learning_rate=args.lr)
pbar = tqdm(op, leave=False, total=args.iterations)
for loss in pbar:
counter = counter+1
checkpointed = False
if counter % 100 == 0 or counter < 100:
checkpointed = make_checkpoint(counter, generator, names, args.generated_images_dir, args.dlatent_dir)
print("****************************")
print(f"*counter: {counter} *")
print("****************************")
pbar.set_description(' '.join(names) + f" counter: {counter}, checkpointed: {checkpointed}" +' Last Loss: %.2f' % loss) # This is the output
print(' '.join(names), ' loss:', loss)
# Generate images from found dlatents and save them
generated_images = generator.generate_images()
generated_dlatents = generator.get_dlatents()
for img_array, dlatent, img_name in zip(generated_images, generated_dlatents, names):
img = PIL.Image.fromarray(img_array, 'RGB')
img.save(os.path.join(args.generated_images_dir, f'{args.iterations}_iters_{img_name}.png'), 'PNG')
np.save(os.path.join(args.dlatent_dir, f'{args.iterations}_{img_name}.npy'), dlatent)
generator.reset_dlatents()
if __name__ == "__main__":
main()