-
Notifications
You must be signed in to change notification settings - Fork 0
/
convert_ckpt_pytorch_to_tf2.py
441 lines (368 loc) · 24.9 KB
/
convert_ckpt_pytorch_to_tf2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
import torch
import tensorflow as tf
import numpy as np
from absl import app
from absl import flags
from unet import UNet
from transformer import TransformerModel
from autoencoder import AutoencoderKL
# Command-line interface: the converter takes one required flag pointing at
# the pytorch checkpoint to convert.  (Original help text read "Path to
# pytorch ckpt path." — redundant and ungrammatical.)
flags.DEFINE_string("pytorch_ckpt_path", None, "Path to the pytorch checkpoint file.")
FLAGS = flags.FLAGS
def get_state_dict(filename):
    """Load a pytorch checkpoint and return its state dict as numpy arrays.

    Args:
        filename: Path to a pytorch checkpoint file whose top-level dict
            contains a "state_dict" entry.

    Returns:
        Dict mapping parameter names to numpy arrays.
    """
    # map_location="cpu" so checkpoints saved from a GPU run still load on
    # CPU-only hosts (the original call raised if CUDA was unavailable).
    sd = torch.load(filename, map_location="cpu")["state_dict"]
    for key, value in sd.items():
        sd[key] = value.numpy()
    return sd
def get_transformer_weights(sd):
    """Collect the text-encoder transformer weights from the pytorch state
    dict, ordered to match the TF2 TransformerModel's set_weights layout.

    Attention projection matrices are transposed and reshaped so the fused
    head dimension is split into (8 heads, 64 per head); feed-forward linear
    weights are transposed from pytorch (out, in) to TF (in, out).
    """
    prefix = "cond_stage_model.transformer"
    weights = []
    for layer in range(32):
        attn = f"{prefix}.attn_layers.layers.{2 * layer}"
        ff = f"{prefix}.attn_layers.layers.{2 * layer + 1}"
        # Self-attention sublayer: q, k, v projections split into heads,
        # then the output projection, its bias, and the pre-attention norm.
        for proj in ("to_q", "to_k", "to_v"):
            weights.append(sd[f"{attn}.1.{proj}.weight"].T.reshape(-1, 8, 64))
        weights.append(sd[f"{attn}.1.to_out.weight"].T.reshape(8, 64, -1))
        weights.append(sd[f"{attn}.1.to_out.bias"])
        weights.append(sd[f"{attn}.0.weight"])
        weights.append(sd[f"{attn}.0.bias"])
        # Feed-forward sublayer: two linear layers plus its own norm.
        weights.append(sd[f"{ff}.1.net.0.0.weight"].T)
        weights.append(sd[f"{ff}.1.net.0.0.bias"])
        weights.append(sd[f"{ff}.1.net.2.weight"].T)
        weights.append(sd[f"{ff}.1.net.2.bias"])
        weights.append(sd[f"{ff}.0.weight"])
        weights.append(sd[f"{ff}.0.bias"])
    # Final norm and the token / positional embedding tables.
    weights.append(sd[f"{prefix}.norm.weight"])
    weights.append(sd[f"{prefix}.norm.bias"])
    weights.append(sd[f"{prefix}.token_emb.weight"])
    weights.append(sd[f"{prefix}.pos_emb.emb.weight"])
    return weights
def get_unet_weights(sd):
    """Collect diffusion UNet weights from the pytorch state dict, ordered
    to match the TF2 UNet's set_weights layout.

    Conventions applied throughout:
      * conv kernels: pytorch (out, in, kh, kw) -> TF (kh, kw, in, out)
        via transpose(2, 3, 1, 0);
      * linear weights: transposed (.T) to TF (in, out);
      * 1x1 convs used as projections: squeeze().T to a plain matrix;
      * attention projections: transposed then reshaped so the fused head
        dimension splits into (8 heads, head_dim).

    The append order must exactly mirror the TF model's variable creation
    order — do not reorder any of the statements below.
    """
    # Channel multipliers (relative to the 320 base width) for the spatial
    # transformer in input block i / output block i respectively.
    shape_dict = {1: 1, 2: 1, 4: 2, 5: 2, 7: 4, 8: 4}
    shape_dict1 = {3: 4, 4: 4, 5: 4, 6: 2, 7: 2, 8: 2, 9: 1, 10: 1, 11: 1}
    weights = []
    # Stem conv followed by the two-layer timestep-embedding MLP.
    weights.append(sd["model.diffusion_model.input_blocks.0.0.weight"].transpose(2, 3, 1, 0))
    weights.append(sd["model.diffusion_model.input_blocks.0.0.bias"])
    weights.append(sd["model.diffusion_model.time_embed.0.weight"].T)
    weights.append(sd["model.diffusion_model.time_embed.0.bias"])
    weights.append(sd["model.diffusion_model.time_embed.2.weight"].T)
    weights.append(sd["model.diffusion_model.time_embed.2.bias"])
    # --- Encoder (input) blocks 1..11 ---
    weights1 = []
    for i in range(1, 12):
        if i in (3, 6, 9):
            # Blocks 3/6/9 are pure strided-conv downsamplers.
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.0.op.weight"].transpose(2, 3, 1, 0))
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.0.op.bias"])
            continue
        # ResNet block: norm/conv in_layers, timestep-embedding projection,
        # then norm/conv out_layers.
        weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.0.in_layers.0.weight"])
        weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.0.in_layers.0.bias"])
        weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.0.in_layers.2.weight"].transpose(2, 3, 1, 0))
        weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.0.in_layers.2.bias"])
        weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.0.emb_layers.1.weight"].T)
        weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.0.emb_layers.1.bias"])
        weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.0.out_layers.0.weight"])
        weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.0.out_layers.0.bias"])
        weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.0.out_layers.3.weight"].transpose(2, 3, 1, 0))
        weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.0.out_layers.3.bias"])
        if i in (4, 7):
            # Channel count changes at blocks 4 and 7, so the residual path
            # needs a 1x1 skip projection.
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.0.skip_connection.weight"].squeeze().T)
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.0.skip_connection.bias"])
        if i in (1, 2, 4, 5, 7, 8):
            # Spatial transformer: attn1 is self-attention (q/k/v all sized
            # 320*mult); attn2 cross-attends to the 1280-dim text context,
            # hence the fixed 1280 leading dim on its k/v reshapes.
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.1.proj_in.weight"].squeeze().T)
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.1.proj_in.bias"])
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.1.transformer_blocks.0.attn1.to_q.weight"].T.reshape(320 * shape_dict[i], 8, 40 * shape_dict[i]))
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.1.transformer_blocks.0.attn1.to_k.weight"].T.reshape(320 * shape_dict[i], 8, 40 * shape_dict[i]))
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.1.transformer_blocks.0.attn1.to_v.weight"].T.reshape(320 * shape_dict[i], 8, 40 * shape_dict[i]))
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.1.transformer_blocks.0.attn1.to_out.0.weight"].T.reshape(8, 40 * shape_dict[i], 320 * shape_dict[i]))
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.1.transformer_blocks.0.attn1.to_out.0.bias"])
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.1.transformer_blocks.0.attn2.to_q.weight"].T.reshape(320 * shape_dict[i], 8, 40 * shape_dict[i]))
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.1.transformer_blocks.0.attn2.to_k.weight"].T.reshape(1280, 8, 40 * shape_dict[i]))
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.1.transformer_blocks.0.attn2.to_v.weight"].T.reshape(1280, 8, 40 * shape_dict[i]))
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.1.transformer_blocks.0.attn2.to_out.0.weight"].T.reshape(8, 40 * shape_dict[i], 320 * shape_dict[i]))
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.1.transformer_blocks.0.attn2.to_out.0.bias"])
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.1.transformer_blocks.0.ff.net.0.proj.weight"].T)
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.1.transformer_blocks.0.ff.net.0.proj.bias"])
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.1.transformer_blocks.0.ff.net.2.weight"].T)
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.1.transformer_blocks.0.ff.net.2.bias"])
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.1.transformer_blocks.0.norm1.weight"])
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.1.transformer_blocks.0.norm1.bias"])
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.1.transformer_blocks.0.norm2.weight"])
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.1.transformer_blocks.0.norm2.bias"])
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.1.transformer_blocks.0.norm3.weight"])
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.1.transformer_blocks.0.norm3.bias"])
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.1.proj_out.weight"].squeeze().T)
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.1.proj_out.bias"])
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.1.norm.weight"])
            weights1.append(sd[f"model.diffusion_model.input_blocks.{i}.1.norm.bias"])
    # --- Middle block: resnet (0), spatial transformer (1), resnet (2) ---
    # The middle block runs at the 4x multiplier, hence the fixed 320*4/40*4.
    weights2 = []
    for i in range(3):
        if i in (0, 2):
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.in_layers.0.weight"])
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.in_layers.0.bias"])
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.in_layers.2.weight"].transpose(2, 3, 1, 0))
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.in_layers.2.bias"])
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.emb_layers.1.weight"].T)
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.emb_layers.1.bias"])
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.out_layers.0.weight"])
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.out_layers.0.bias"])
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.out_layers.3.weight"].transpose(2, 3, 1, 0))
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.out_layers.3.bias"])
        else:
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.proj_in.weight"].squeeze().T)
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.proj_in.bias"])
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn1.to_q.weight"].T.reshape(320 * 4, 8, 40 * 4))
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn1.to_k.weight"].T.reshape(320 * 4, 8, 40 * 4))
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn1.to_v.weight"].T.reshape(320 * 4, 8, 40 * 4))
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn1.to_out.0.weight"].T.reshape(8, 40 * 4, 320 * 4))
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn1.to_out.0.bias"])
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn2.to_q.weight"].T.reshape(320 * 4, 8, 40 * 4))
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn2.to_k.weight"].T.reshape(1280, 8, 40 * 4))
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn2.to_v.weight"].T.reshape(1280, 8, 40 * 4))
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn2.to_out.0.weight"].T.reshape(8, 40 * 4, 320 * 4))
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn2.to_out.0.bias"])
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.ff.net.0.proj.weight"].T)
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.ff.net.0.proj.bias"])
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.ff.net.2.weight"].T)
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.ff.net.2.bias"])
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.norm1.weight"])
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.norm1.bias"])
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.norm2.weight"])
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.norm2.bias"])
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.norm3.weight"])
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.norm3.bias"])
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.proj_out.weight"].squeeze().T)
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.proj_out.bias"])
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.norm.weight"])
            weights2.append(sd[f"model.diffusion_model.middle_block.{i}.norm.bias"])
    # --- Decoder (output) blocks 0..11 ---
    weights3 = []
    for i in range(12):
        # Every output block starts with a resnet block; skip connections
        # are always present here because of the UNet concat inputs.
        weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.0.in_layers.0.weight"])
        weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.0.in_layers.0.bias"])
        weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.0.in_layers.2.weight"].transpose(2, 3, 1, 0))
        weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.0.in_layers.2.bias"])
        weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.0.emb_layers.1.weight"].T)
        weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.0.emb_layers.1.bias"])
        weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.0.out_layers.0.weight"])
        weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.0.out_layers.0.bias"])
        weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.0.out_layers.3.weight"].transpose(2, 3, 1, 0))
        weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.0.out_layers.3.bias"])
        weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.0.skip_connection.weight"].squeeze().T)
        weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.0.skip_connection.bias"])
        if i in (3, 4, 5, 6, 7, 8, 9, 10, 11):
            # Spatial transformer (same layout as the input-block one, with
            # the decoder multiplier table shape_dict1).
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.1.proj_in.weight"].squeeze().T)
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.1.proj_in.bias"])
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.1.transformer_blocks.0.attn1.to_q.weight"].T.reshape(320 * shape_dict1[i], 8, 40 * shape_dict1[i]))
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.1.transformer_blocks.0.attn1.to_k.weight"].T.reshape(320 * shape_dict1[i], 8, 40 * shape_dict1[i]))
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.1.transformer_blocks.0.attn1.to_v.weight"].T.reshape(320 * shape_dict1[i], 8, 40 * shape_dict1[i]))
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.1.transformer_blocks.0.attn1.to_out.0.weight"].T.reshape(8, 40 * shape_dict1[i], 320 * shape_dict1[i]))
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.1.transformer_blocks.0.attn1.to_out.0.bias"])
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.1.transformer_blocks.0.attn2.to_q.weight"].T.reshape(320 * shape_dict1[i], 8, 40 * shape_dict1[i]))
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.1.transformer_blocks.0.attn2.to_k.weight"].T.reshape(1280, 8, 40 * shape_dict1[i]))
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.1.transformer_blocks.0.attn2.to_v.weight"].T.reshape(1280, 8, 40 * shape_dict1[i]))
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.1.transformer_blocks.0.attn2.to_out.0.weight"].T.reshape(8, 40 * shape_dict1[i], 320 * shape_dict1[i]))
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.1.transformer_blocks.0.attn2.to_out.0.bias"])
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.1.transformer_blocks.0.ff.net.0.proj.weight"].T)
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.1.transformer_blocks.0.ff.net.0.proj.bias"])
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.1.transformer_blocks.0.ff.net.2.weight"].T)
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.1.transformer_blocks.0.ff.net.2.bias"])
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.1.transformer_blocks.0.norm1.weight"])
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.1.transformer_blocks.0.norm1.bias"])
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.1.transformer_blocks.0.norm2.weight"])
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.1.transformer_blocks.0.norm2.bias"])
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.1.transformer_blocks.0.norm3.weight"])
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.1.transformer_blocks.0.norm3.bias"])
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.1.proj_out.weight"].squeeze().T)
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.1.proj_out.bias"])
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.1.norm.weight"])
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.1.norm.bias"])
        if i in (2, 5, 8):
            # Upsampler position: submodule 1 for block 2 (which has no
            # spatial transformer), submodule 2 otherwise.
            if i == 2:
                j = 1
            else:
                j = 2
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.{j}.conv.weight"].transpose(2, 3, 1, 0))
            weights3.append(sd[f"model.diffusion_model.output_blocks.{i}.{j}.conv.bias"])
    weights = weights + weights1 + weights2 + weights3
    # Final norm + conv that project back to latent channels.
    weights.append(sd[f"model.diffusion_model.out.0.weight"])
    weights.append(sd[f"model.diffusion_model.out.0.bias"])
    weights.append(sd[f"model.diffusion_model.out.2.weight"].transpose(2, 3, 1, 0))
    weights.append(sd[f"model.diffusion_model.out.2.bias"])
    return weights
def get_decoder_weights(sd):
    """Collect the VAE decoder weights from the pytorch state dict, ordered
    to match the TF2 decoder's set_weights layout.

    Conv kernels move from pytorch (out, in, kh, kw) to TF (kh, kw, in, out);
    1x1 convs used as projections become plain matrices via squeeze().T.
    """
    p = "first_stage_model.decoder"
    weights = []

    def add_block(which):
        # Resnet block: norm1/conv1 then norm2/conv2, plus a 1x1 shortcut
        # on the up-levels where the channel count changes.
        base = f"{p}.{which}"
        weights.append(sd[base + ".norm1.weight"])
        weights.append(sd[base + ".norm1.bias"])
        weights.append(sd[base + ".conv1.weight"].transpose(2, 3, 1, 0))
        weights.append(sd[base + ".conv1.bias"])
        weights.append(sd[base + ".norm2.weight"])
        weights.append(sd[base + ".norm2.bias"])
        weights.append(sd[base + ".conv2.weight"].transpose(2, 3, 1, 0))
        weights.append(sd[base + ".conv2.bias"])
        if which in ("up.0.block.0", "up.1.block.0"):
            weights.append(sd[base + ".nin_shortcut.weight"].squeeze().T)
            weights.append(sd[base + ".nin_shortcut.bias"])

    def add_attn(which):
        # Single-head attention block: norm, then q/k/v/out 1x1 projections.
        base = f"{p}.{which}"
        weights.append(sd[base + ".norm.weight"])
        weights.append(sd[base + ".norm.bias"])
        for proj in ("q", "k", "v", "proj_out"):
            weights.append(sd[f"{base}.{proj}.weight"].squeeze().T)
            weights.append(sd[f"{base}.{proj}.bias"])

    def add_upsample(level):
        base = f"{p}.up.{level}.upsample.conv"
        weights.append(sd[base + ".weight"].transpose(2, 3, 1, 0))
        weights.append(sd[base + ".bias"])

    weights.append(sd[p + ".conv_in.weight"].transpose(2, 3, 1, 0))
    weights.append(sd[p + ".conv_in.bias"])
    add_block("mid.block_1")
    add_attn("mid.attn_1")
    add_block("mid.block_2")
    # Up path runs from the deepest level (3) to the shallowest (0); the
    # final level has no upsampler.
    for level in (3, 2, 1, 0):
        for blk in range(3):
            add_block(f"up.{level}.block.{blk}")
        if level > 0:
            add_upsample(level)
    weights.append(sd[p + ".norm_out.weight"])
    weights.append(sd[p + ".norm_out.bias"])
    weights.append(sd[p + ".conv_out.weight"].transpose(2, 3, 1, 0))
    weights.append(sd[p + ".conv_out.bias"])
    return weights
def get_encoder_weights(sd):
    """Collect the VAE encoder weights from the pytorch state dict, ordered
    to match the TF2 encoder's set_weights layout.

    Conv kernels move from pytorch (out, in, kh, kw) to TF (kh, kw, in, out);
    1x1 convs used as projections become plain matrices via squeeze().T.
    """
    p = "first_stage_model.encoder"
    weights = [
        sd[p + ".conv_in.weight"].transpose(2, 3, 1, 0),
        sd[p + ".conv_in.bias"],
    ]

    def add_block(which):
        # Resnet block: norm1/conv1 then norm2/conv2, plus a 1x1 shortcut
        # on the down-levels where the channel count changes.
        base = f"{p}.{which}"
        weights.append(sd[base + ".norm1.weight"])
        weights.append(sd[base + ".norm1.bias"])
        weights.append(sd[base + ".conv1.weight"].transpose(2, 3, 1, 0))
        weights.append(sd[base + ".conv1.bias"])
        weights.append(sd[base + ".norm2.weight"])
        weights.append(sd[base + ".norm2.bias"])
        weights.append(sd[base + ".conv2.weight"].transpose(2, 3, 1, 0))
        weights.append(sd[base + ".conv2.bias"])
        if which in ("down.1.block.0", "down.2.block.0"):
            weights.append(sd[base + ".nin_shortcut.weight"].squeeze().T)
            weights.append(sd[base + ".nin_shortcut.bias"])

    def add_attn(which):
        # Single-head attention block: norm, then q/k/v/out 1x1 projections.
        base = f"{p}.{which}"
        weights.append(sd[base + ".norm.weight"])
        weights.append(sd[base + ".norm.bias"])
        for proj in ("q", "k", "v", "proj_out"):
            weights.append(sd[f"{base}.{proj}.weight"].squeeze().T)
            weights.append(sd[f"{base}.{proj}.bias"])

    def add_downsample(level):
        base = f"{p}.down.{level}.downsample.conv"
        weights.append(sd[base + ".weight"].transpose(2, 3, 1, 0))
        weights.append(sd[base + ".bias"])

    # Down path: two resnet blocks per level; the last level (3) has no
    # downsampler.
    for level in range(4):
        add_block(f"down.{level}.block.0")
        add_block(f"down.{level}.block.1")
        if level < 3:
            add_downsample(level)
    add_block("mid.block_1")
    add_attn("mid.attn_1")
    add_block("mid.block_2")
    weights.append(sd[p + ".norm_out.weight"])
    weights.append(sd[p + ".norm_out.bias"])
    weights.append(sd[p + ".conv_out.weight"].transpose(2, 3, 1, 0))
    weights.append(sd[p + ".conv_out.bias"])
    return weights
def save_checkpoint(sd):
    """Build the TF2 models, load the converted pytorch weights, and save
    each model as its own TF2 checkpoint in the working directory.

    Each model is run once on dummy inputs first so its variables get
    created before set_weights is called.
    """
    # --- Text encoder ---
    vocab_size = 30522
    text_encoder = TransformerModel(
        vocab_size,
        encoder_stack_size=32,
        hidden_size=1280,
        num_heads=8,
        filter_size=1280 * 4,
        dropout_rate=0.1,
    )
    # One prompt (padded to 77 tokens) and one empty prompt, tiled to a
    # batch of 8 for the dummy forward pass.
    prompt = [101, 1037, 7865, 6071, 2003, 2652, 2858, 1010, 3514, 2006, 10683, 102]
    prompt_ids = np.asarray([prompt + [0] * (77 - len(prompt))])
    empty_ids = np.asarray([[101, 102] + [0] * 75])
    batch = tf.constant(np.vstack([np.tile(empty_ids, [4, 1]), np.tile(prompt_ids, [4, 1])]))
    text_encoder(batch)
    text_encoder.set_weights(get_transformer_weights(sd))
    # --- Diffusion UNet ---
    batch_size = 8
    unet = UNet()
    latents = np.random.uniform(-1, 1, [batch_size // 2, 32, 32, 4]).astype("float32")
    latents = np.concatenate([latents, latents], axis=0)
    timesteps = tf.constant([981] * batch_size)
    context = tf.constant(np.random.uniform(-1, 1, (batch_size, 77, 1280)).astype("float32"))
    unet(latents, timesteps, context)
    unet.set_weights(get_unet_weights(sd))
    # --- VAE autoencoder ---
    autoencoder = AutoencoderKL(latent_channels=4)
    dummy_images = tf.constant(np.random.uniform(-1, 1, (4, 256, 256, 3)).astype("float32"))
    autoencoder(dummy_images)
    autoencoder._encoder.set_weights(get_encoder_weights(sd))
    autoencoder._quant_conv.set_weights([
        sd["first_stage_model.quant_conv.weight"].squeeze().T,
        sd["first_stage_model.quant_conv.bias"],
    ])
    autoencoder._post_quant_conv.set_weights([
        sd["first_stage_model.post_quant_conv.weight"].squeeze().T,
        sd["first_stage_model.post_quant_conv.bias"],
    ])
    autoencoder._decoder.set_weights(get_decoder_weights(sd))
    # --- Save one checkpoint per sub-model ---
    tf.train.Checkpoint(transformer=text_encoder).save("transformer")
    tf.train.Checkpoint(unet=unet).save("unet")
    tf.train.Checkpoint(autoencoder=autoencoder).save("autoencoder")
def main(_):
    """absl entry point: convert the pytorch checkpoint to TF2 checkpoints."""
    save_checkpoint(get_state_dict(FLAGS.pytorch_ckpt_path))
if __name__ == "__main__":
    # --pytorch_ckpt_path is mandatory; absl aborts with a usage error if
    # it is missing before main runs.
    flags.mark_flag_as_required("pytorch_ckpt_path")
    app.run(main)