-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate.py
143 lines (92 loc) · 4.58 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import argparse
from pathlib import Path
from tqdm import tqdm
# torch
import torch
from einops import repeat
# vision imports
from PIL import Image
from torchvision.utils import make_grid, save_image
# dalle related classes and utils
from dalle_pytorch import __version__
from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE, DALLE
from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, YttmTokenizer, ChineseTokenizer
# argument parsing
# Builds the command-line interface and parses sys.argv into `args`; the rest
# of the script reads all of its settings from `args`.

parser = argparse.ArgumentParser()

parser.add_argument('--dalle_path', type = str, required = True,
                    help='path to your trained DALL-E')

parser.add_argument('--vqgan_model_path', type=str, default = None,
                    help='path to your trained VQGAN weights. This should be a .ckpt file. (only valid when taming option is enabled)')

parser.add_argument('--vqgan_config_path', type=str, default = None,
                    help='path to your trained VQGAN config. This should be a .yaml file. (only valid when taming option is enabled)')

# several prompts can be given at once, separated by '|'
parser.add_argument('--text', type = str, required = True,
                    help='your text prompt')

parser.add_argument('--num_images', type = int, default = 128, required = False,
                    help='number of images')

parser.add_argument('--batch_size', type = int, default = 4, required = False,
                    help='batch size')

# despite the name this is a float; it is forwarded as `filter_thres` to the
# DALL-E generate calls
parser.add_argument('--top_k', type = float, default = 0.9, required = False,
                    help='top k filter threshold')

parser.add_argument('--outputs_dir', type = str, default = './outputs', required = False,
                    help='output directory')

# loaded with HugTokenizer when --hug is set, otherwise with YttmTokenizer
parser.add_argument('--bpe_path', type = str,
                    help='path to your huggingface BPE json file')

parser.add_argument('--hug', dest='hug', action = 'store_true',
                    help='load --bpe_path with HugTokenizer instead of YttmTokenizer')

parser.add_argument('--chinese', dest='chinese', action = 'store_true',
                    help='use ChineseTokenizer (ignored when --bpe_path is given)')

parser.add_argument('--taming', dest='taming', action='store_true',
                    help='use a VQGanVAE built from --vqgan_model_path / --vqgan_config_path as the VAE')

parser.add_argument('--gentxt', dest='gentxt', action='store_true',
                    help='let DALL-E generate a text from each prompt and generate images for that generated text')

args = parser.parse_args()
# helper fns

def exists(val):
    """Return True for any value other than None (0, '' and [] all count as existing)."""
    if val is None:
        return False
    return True
# tokenizer
# A BPE file takes precedence over --chinese. When neither is given, the
# default `tokenizer` imported from dalle_pytorch.tokenizer stays in use.

if exists(args.bpe_path):
    if args.hug:
        tokenizer_cls = HugTokenizer
    else:
        tokenizer_cls = YttmTokenizer

    tokenizer = tokenizer_cls(args.bpe_path)

elif args.chinese:
    tokenizer = ChineseTokenizer()
# load DALL-E
# Reads the training checkpoint from disk and splits it into its parts.

dalle_path = Path(args.dalle_path)

assert dalle_path.exists(), 'trained DALL-E must exist'

# Deserialize onto the CPU: without map_location, torch.load tries to restore
# every tensor to the device it was saved from and raises if that device is
# not available on this machine. The model itself is moved to the GPU later
# in the script, before the weights are copied into it.
load_obj = torch.load(str(dalle_path), map_location = 'cpu')

# 'hparams', 'vae_params' and 'weights' must be present (KeyError otherwise);
# 'vae_class_name' and 'version' fall back to None when the checkpoint does
# not carry them.
dalle_params = load_obj.pop('hparams')
vae_params = load_obj.pop('vae_params')
weights = load_obj.pop('weights')
vae_class_name = load_obj.pop('vae_class_name', None)
version = load_obj.pop('version', None)
# friendly print
# Tell the user which library version produced the checkpoint, or warn when
# the checkpoint does not record one.

if version is None:
    print('You are loading a model trained on an older version of DALL-E pytorch - it may not be compatible with the most recent version')
else:
    print(f'Loading a model trained with DALLE-pytorch version {version}')
# load VAE
# --taming takes priority; otherwise the VAE hyperparameters stored in the
# checkpoint decide between a DiscreteVAE and the OpenAI one.

if args.taming:
    vae = VQGanVAE(args.vqgan_model_path, args.vqgan_config_path)
elif vae_params is None:
    vae = OpenAIDiscreteVAE()
else:
    vae = DiscreteVAE(**vae_params)

# When the checkpoint records the VAE class it was trained with, the VAE built
# above must be of that same class.
assert vae_class_name is None or vae.__class__.__name__ == vae_class_name, f'you trained DALL-E using {vae_class_name} but are trying to generate with {vae.__class__.__name__} - please make sure you are passing in the correct paths and settings for the VAE to use for generation'
# reconstitute DALL-E
# Rebuild the model from the saved hyperparameters around the chosen VAE,
# move it to the GPU, then restore the trained weights.

dalle = DALLE(vae = vae, **dalle_params)
dalle = dalle.cuda()

dalle.load_state_dict(weights)
# generate images
# For every '|'-separated prompt: tokenize it (or let DALL-E write the text
# first when --gentxt is set), generate `num_images` images in chunks of
# `batch_size`, and save them with the caption into a per-prompt directory.

# NOTE(review): image_size is not used anywhere else in this script
image_size = vae.image_size

texts = args.text.split('|')

# wrap the list itself (not enumerate) so tqdm knows the total
for text in tqdm(texts):
    if args.gentxt:
        text_tokens, gen_texts = dalle.generate_texts(tokenizer, text=text, filter_thres = args.top_k)
        # from here on the generated text replaces the prompt (progress
        # description, directory name and caption)
        text = gen_texts[0]
    else:
        text_tokens = tokenizer.tokenize([text], dalle.text_seq_len).cuda()

    # repeat the single tokenized prompt once per requested image
    text_tokens = repeat(text_tokens, '() n -> b n', b = args.num_images)

    outputs = []

    for text_chunk in tqdm(text_tokens.split(args.batch_size), desc = f'generating images for - {text}'):
        output = dalle.generate_images(text_chunk, filter_thres = args.top_k)
        outputs.append(output)

    outputs = torch.cat(outputs)

    # save all images

    file_name = text
    # directory name: the text with spaces replaced by '_', cut to 100 characters
    outputs_dir = Path(args.outputs_dir) / file_name.replace(' ', '_')[:(100)]
    outputs_dir.mkdir(parents = True, exist_ok = True)

    for i, image in enumerate(tqdm(outputs, desc = 'saving images')):
        save_image(image, outputs_dir / f'{i}.png', normalize=True)

    # Written once per prompt. The encoding is explicit because the caption
    # may contain non-ASCII text (e.g. with the Chinese tokenizer) and the
    # platform default encoding is not guaranteed to be able to represent it.
    with open(outputs_dir / 'caption.txt', 'w', encoding = 'utf-8') as f:
        f.write(file_name)

    print(f'created {args.num_images} images at "{str(outputs_dir)}"')