-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathmodels.py
506 lines (436 loc) · 20.7 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections import OrderedDict
from copy import deepcopy
from diffusers.models import AutoencoderKL, UNet2DConditionModel
import numpy as np
from onnx import shape_inference
import onnx_graphsurgeon as gs
from polygraphy.backend.onnx.loader import fold_constants
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from cuda import cudart
import onnx
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UNet2DConditionModel
class Optimizer():
    """Apply clean-up passes to an ONNX model via onnx-graphsurgeon.

    The ONNX model passed to the constructor is imported into a graphsurgeon
    graph stored on ``self.graph``; every pass below replaces or mutates that
    graph and can optionally hand back the exported ONNX model.
    """

    def __init__(
        self,
        onnx_graph,
        verbose=False
    ):
        # Work on the graphsurgeon representation; passes re-export as needed.
        self.graph = gs.import_onnx(onnx_graph)
        self.verbose = verbose

    def info(self, prefix):
        """Print node/tensor/input/output counts of the current graph (verbose only)."""
        if not self.verbose:
            return
        graph = self.graph
        print(f"{prefix} .. {len(graph.nodes)} nodes, {len(graph.tensors().keys())} tensors, "
              f"{len(graph.inputs)} inputs, {len(graph.outputs)} outputs")

    def cleanup(self, return_onnx=False):
        """Remove dangling nodes/tensors, then topologically sort the graph.

        Returns the exported ONNX model when ``return_onnx`` is true, else None.
        """
        self.graph.cleanup().toposort()
        return gs.export_onnx(self.graph) if return_onnx else None

    def select_outputs(self, keep, names=None):
        """Keep only the graph outputs at indices ``keep``, in that order.

        When ``names`` is given, the first ``len(names)`` retained outputs are
        renamed to those names.
        """
        current = self.graph.outputs
        self.graph.outputs = [current[idx] for idx in keep]
        if not names:
            return
        for pos, new_name in enumerate(names):
            self.graph.outputs[pos].name = new_name

    def fold_constants(self, return_onnx=False):
        """Run Polygraphy constant folding (with ONNX Runtime shape inference enabled).

        Returns the folded ONNX model when ``return_onnx`` is true, else None.
        """
        # Inside the method body this name resolves to the module-level
        # Polygraphy `fold_constants` import, not to this method.
        folded = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True)
        self.graph = gs.import_onnx(folded)
        return folded if return_onnx else None

    def infer_shapes(self, return_onnx=False):
        """Run ONNX shape inference over the graph.

        Returns the shape-annotated ONNX model when ``return_onnx`` is true,
        else None.

        Raises:
            TypeError: if the serialized model is larger than 2 GiB, the limit
                this in-memory inference path supports.
        """
        exported = gs.export_onnx(self.graph)
        size_limit = 2147483648  # 2 GiB
        if exported.ByteSize() > size_limit:
            raise TypeError("ERROR: model size exceeds supported 2GB limit")
        inferred = shape_inference.infer_shapes(exported)
        self.graph = gs.import_onnx(inferred)
        return inferred if return_onnx else None
def get_path(version, inpaint=False):
    """Return the Hugging Face model id for a Stable Diffusion ``version``.

    Args:
        version: one of "1.4", "1.5", "2.0-base", "2.0", "2.1", "2.1-base".
        inpaint: pick the inpainting checkpoint. Only honoured for the 1.x and
            2.0 families; "2.1" and "2.1-base" always return the base model.

    Raises:
        ValueError: for an unrecognised version string.
    """
    # (version, regular checkpoint, inpainting checkpoint or None if unsupported)
    table = (
        ("1.4", "CompVis/stable-diffusion-v1-4", "runwayml/stable-diffusion-inpainting"),
        ("1.5", "runwayml/stable-diffusion-v1-5", "runwayml/stable-diffusion-inpainting"),
        ("2.0-base", "stabilityai/stable-diffusion-2-base", "stabilityai/stable-diffusion-2-inpainting"),
        ("2.0", "stabilityai/stable-diffusion-2", "stabilityai/stable-diffusion-2-inpainting"),
        ("2.1", "stabilityai/stable-diffusion-2-1", None),
        ("2.1-base", "stabilityai/stable-diffusion-2-1-base", None),
    )
    for known, regular, inpainting in table:
        if version == known:
            if inpaint and inpainting is not None:
                return inpainting
            return regular
    raise ValueError(f"Incorrect version {version}")
def get_embedding_dim(version):
    """Return the text-embedding width used by a Stable Diffusion ``version``.

    Versions 1.x use 768; versions 2.x use 1024.

    Raises:
        ValueError: for an unrecognised version string.
    """
    families = (
        (("1.4", "1.5"), 768),
        (("2.0", "2.0-base", "2.1", "2.1-base"), 1024),
    )
    for versions, dim in families:
        if version in versions:
            return dim
    raise ValueError(f"Incorrect version {version}")
class BaseModel():
    """Common configuration and shape bookkeeping for exportable SD sub-models.

    Subclasses override the ``get_*`` hooks to describe their ONNX inputs,
    outputs, dynamic axes and optimization profiles.

    Args:
        hf_token: Hugging Face auth token (subclasses pass it to ``from_pretrained``).
        fp16: whether the model should be built in half precision.
        device: torch device the model is placed on.
        verbose: print graph statistics during ``optimize``.
        path: Hugging Face model id or local path.
        max_batch_size: upper bound for the batch dimension.
        embedding_dim: text-embedding width.
        text_maxlen: token sequence length of the text encoder input.
    """

    def __init__(
        self,
        hf_token,
        fp16=False,
        device='cuda',
        verbose=True,
        path="",
        max_batch_size=16,
        embedding_dim=768,
        text_maxlen=77,
    ):
        self.name = "SD Model"
        self.hf_token = hf_token
        self.fp16 = fp16
        self.device = device
        self.verbose = verbose
        self.path = path

        self.min_batch = 1
        self.max_batch = max_batch_size
        self.min_image_shape = 256   # min image resolution: 256x256
        self.max_image_shape = 1024  # max image resolution: 1024x1024
        # Latents are 8x smaller than images in each spatial dimension.
        self.min_latent_shape = self.min_image_shape // 8
        self.max_latent_shape = self.max_image_shape // 8

        self.embedding_dim = embedding_dim
        self.text_maxlen = text_maxlen

    def get_model(self):
        """Return the torch module to export (overridden by subclasses)."""
        pass

    def get_input_names(self):
        """Return the ONNX input names (overridden by subclasses)."""
        pass

    def get_output_names(self):
        """Return the ONNX output names (overridden by subclasses)."""
        pass

    def get_dynamic_axes(self):
        """Return the ONNX dynamic-axes mapping; None means all axes are static."""
        return None

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return example input tensor(s) for export (overridden by subclasses)."""
        pass

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
        """Return {input_name: [min, opt, max] shapes}; None in the base class."""
        return None

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return {tensor_name: shape} for a concrete size; None in the base class."""
        return None

    def optimize(self, onnx_graph):
        """Run cleanup, constant folding and shape inference; return the ONNX model."""
        opt = Optimizer(onnx_graph, verbose=self.verbose)
        opt.info(self.name + ': original')
        opt.cleanup()
        opt.info(self.name + ': cleanup')
        opt.fold_constants()
        opt.info(self.name + ': fold constants')
        opt.infer_shapes()
        opt.info(self.name + ': shape inference')
        onnx_opt_graph = opt.cleanup(return_onnx=True)
        opt.info(self.name + ': finished')
        return onnx_opt_graph

    def check_dims(self, batch_size, image_height, image_width):
        """Validate a requested batch/image size and return the latent size.

        Returns:
            ``(latent_height, latent_width)``, i.e. the image size divided by 8.

        Raises:
            AssertionError: if the batch size is outside
                ``[min_batch, max_batch]``, if either image dimension is not a
                multiple of 8, or if the latent size is outside the supported
                range.
        """
        assert self.min_batch <= batch_size <= self.max_batch
        # Both dimensions must be divisible by 8; otherwise `// 8` below would
        # silently truncate the latent size.
        assert image_height % 8 == 0 and image_width % 8 == 0
        latent_height = image_height // 8
        latent_width = image_width // 8
        assert self.min_latent_shape <= latent_height <= self.max_latent_shape
        assert self.min_latent_shape <= latent_width <= self.max_latent_shape
        return (latent_height, latent_width)

    def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape):
        """Return the min/max batch, image and latent sizes for a profile.

        A static batch (``static_batch``) or static shape (``static_shape``)
        pins the corresponding min and max to the requested value; otherwise
        the class-wide limits are used.

        Returns:
            A 10-tuple ``(min_batch, max_batch, min_image_height,
            max_image_height, min_image_width, max_image_width,
            min_latent_height, max_latent_height, min_latent_width,
            max_latent_width)``.
        """
        min_batch = batch_size if static_batch else self.min_batch
        max_batch = batch_size if static_batch else self.max_batch
        latent_height = image_height // 8
        latent_width = image_width // 8
        min_image_height = image_height if static_shape else self.min_image_shape
        max_image_height = image_height if static_shape else self.max_image_shape
        min_image_width = image_width if static_shape else self.min_image_shape
        max_image_width = image_width if static_shape else self.max_image_shape
        min_latent_height = latent_height if static_shape else self.min_latent_shape
        max_latent_height = latent_height if static_shape else self.max_latent_shape
        min_latent_width = latent_width if static_shape else self.min_latent_shape
        max_latent_width = latent_width if static_shape else self.max_latent_shape
        return (min_batch, max_batch, min_image_height, max_image_height, min_image_width, max_image_width,
                min_latent_height, max_latent_height, min_latent_width, max_latent_width)
class CLIP(BaseModel):
    """CLIP text encoder, exported as ``input_ids -> text_embeddings``."""

    def __init__(self,
                 hf_token,
                 device,
                 verbose,
                 path,
                 max_batch_size,
                 embedding_dim
                 ):
        super().__init__(hf_token, device=device, verbose=verbose, path=path,
                         max_batch_size=max_batch_size, embedding_dim=embedding_dim)
        self.name = "CLIP"

    def get_model(self):
        """Load the ``text_encoder`` subfolder of ``self.path`` onto ``self.device``."""
        encoder = CLIPTextModel.from_pretrained(self.path,
                                                subfolder="text_encoder",
                                                use_auth_token=self.hf_token)
        return encoder.to(self.device)

    def get_input_names(self):
        return ['input_ids']

    def get_output_names(self):
        # 'pooler_output' is dropped again in optimize().
        return ['text_embeddings', 'pooler_output']

    def get_dynamic_axes(self):
        """Only the batch axis is dynamic."""
        return {'input_ids': {0: 'B'}, 'text_embeddings': {0: 'B'}}

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
        """Return [min, opt, max] shapes for ``input_ids``."""
        self.check_dims(batch_size, image_height, image_width)
        dims = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
        lo_batch, hi_batch = dims[0], dims[1]
        seq_len = self.text_maxlen
        return {'input_ids': [(lo_batch, seq_len), (batch_size, seq_len), (hi_batch, seq_len)]}

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return concrete input/output shapes for one batch size."""
        self.check_dims(batch_size, image_height, image_width)
        seq_len = self.text_maxlen
        return {
            'input_ids': (batch_size, seq_len),
            'text_embeddings': (batch_size, seq_len, self.embedding_dim),
        }

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return an all-zero int32 token tensor of shape (batch_size, text_maxlen)."""
        self.check_dims(batch_size, image_height, image_width)
        return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device)

    def optimize(self, onnx_graph):
        """Drop the pooler output, fold constants, infer shapes and rename the output."""
        tag = self.name
        opt = Optimizer(onnx_graph, verbose=self.verbose)
        opt.info(f"{tag}: original")
        # Keep only graph output #0, discarding output #1 (pooler_output).
        opt.select_outputs([0])
        opt.cleanup()
        opt.info(f"{tag}: remove output[1]")
        opt.fold_constants()
        opt.info(f"{tag}: fold constants")
        opt.infer_shapes()
        opt.info(f"{tag}: shape inference")
        # Give the remaining output its public name (the log label is historical).
        opt.select_outputs([0], names=['text_embeddings'])
        opt.info(f"{tag}: remove output[0]")
        result = opt.cleanup(return_onnx=True)
        opt.info(f"{tag}: finished")
        return result
def make_CLIP(version, hf_token, device, verbose, max_batch_size, inpaint=False):
    """Build a CLIP wrapper whose path and embedding width follow ``version``."""
    model_path = get_path(version, inpaint=inpaint)
    text_dim = get_embedding_dim(version)
    return CLIP(hf_token=hf_token, device=device, verbose=verbose, path=model_path,
                max_batch_size=max_batch_size, embedding_dim=text_dim)
class UNet2DConditionModel_Cnet(torch.nn.Module):
    """Fuse a ControlNet and a UNet into one module for export.

    The ControlNet's down/mid residuals are multiplied by ``conditioning_scale``
    and passed to the UNet as additional residuals.

    Args:
        unet: the denoising UNet.
        controlnet: the ControlNet producing the residuals.
        conditioning_scale: multiplier applied to every ControlNet residual.
            The default of 1.0 passes the residuals through unscaled. As a
            plain Python float it is presumably traced into an exported graph
            as a constant; confirm if a runtime-tunable scale is needed.
    """

    def __init__(
        self,
        unet,
        controlnet,
        conditioning_scale=1.0,
    ):
        super().__init__()
        self.unet = unet
        self.controlnet = controlnet
        self.conditioning_scale = conditioning_scale

    def forward(
        self, sample, encoder_hidden_states, controlnet_cond, timestep
    ):
        """Run the ControlNet, then the UNet; return whatever ``self.unet`` returns."""
        down_block_res_samples, mid_block_res_sample = self.controlnet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            controlnet_cond=controlnet_cond,
            return_dict=False,
        )
        scale = self.conditioning_scale
        down_block_res_samples = [res * scale for res in down_block_res_samples]
        # Out-of-place so the tensor returned by the ControlNet is not mutated.
        mid_block_res_sample = mid_block_res_sample * scale

        noise_pred = self.unet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            down_block_additional_residuals=down_block_res_samples,
            mid_block_additional_residual=mid_block_res_sample,
        )
        return noise_pred
class UNet(BaseModel):
    """ControlNet-conditioned UNet, exported as a single graph.

    Inputs are ``sample``, ``encoder_hidden_states``, ``controlnet_cond`` and
    ``timestep``; the output is ``latent``. The batch dimension of every
    batched tensor is ``2 * batch_size`` (presumably for classifier-free
    guidance; confirm against the caller).

    Args (in addition to BaseModel's):
        unet_dim: channel count of ``sample``.
        controlnet_path: Hugging Face id of the ControlNet checkpoint.
    """

    def __init__(self,
                 hf_token,
                 fp16=False,
                 device='cuda',
                 verbose=True,
                 path="",
                 max_batch_size=16,
                 embedding_dim=768,
                 text_maxlen=77,
                 unet_dim=4,
                 controlnet_path="lllyasviel/sd-controlnet-canny",
                 ):
        super().__init__(hf_token, fp16=fp16, device=device, verbose=verbose, path=path,
                         max_batch_size=max_batch_size, embedding_dim=embedding_dim,
                         text_maxlen=text_maxlen)
        self.unet_dim = unet_dim
        self.controlnet_path = controlnet_path
        self.name = "UNet"

    def get_model(self):
        """Load the UNet and ControlNet and wrap them in UNet2DConditionModel_Cnet.

        The UNet is read from the ``unet`` subfolder of ``self.path``. When no
        path was given it falls back to "runwayml/stable-diffusion-v1-5". The
        ControlNet is read from ``self.controlnet_path``. Both are loaded as
        float16 when ``self.fp16`` is set (float32 otherwise) and moved to
        ``self.device``.
        """
        # A full StableDiffusionControlNetPipeline is not needed: only its unet
        # and controlnet modules would be used, and building it also loads
        # the VAE, text encoder and safety checker.
        torch_dtype = torch.float16 if self.fp16 else torch.float32
        unet_path = self.path or "runwayml/stable-diffusion-v1-5"
        unet = UNet2DConditionModel.from_pretrained(unet_path,
                                                    subfolder="unet",
                                                    use_auth_token=self.hf_token,
                                                    torch_dtype=torch_dtype).to(self.device)
        controlnet = ControlNetModel.from_pretrained(self.controlnet_path,
                                                     use_auth_token=self.hf_token,
                                                     torch_dtype=torch_dtype).to(self.device)
        return UNet2DConditionModel_Cnet(unet=unet, controlnet=controlnet).to(self.device)

    def get_input_names(self):
        return ['sample', 'encoder_hidden_states', 'controlnet_cond', 'timestep']

    def get_output_names(self):
        return ['latent']

    def get_dynamic_axes(self):
        """Batch is dynamic everywhere; spatial dims are dynamic for image-like tensors."""
        return {
            'sample': {0: '2B', 2: 'H', 3: 'W'},
            'encoder_hidden_states': {0: '2B'},
            'latent': {0: '2B', 2: 'H', 3: 'W'},
            # controlnet_cond is image-resolution; get_input_profile varies its
            # H/W, so those axes must be dynamic in the exported graph too.
            'controlnet_cond': {0: '2B', 2: '8H', 3: '8W'},
        }

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
        """Return [min, opt, max] shapes for each dynamic input."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        (min_batch, max_batch, min_image_height, max_image_height, min_image_width, max_image_width,
         min_latent_height, max_latent_height, min_latent_width, max_latent_width) = \
            self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
        return {
            'sample': [(2 * min_batch, self.unet_dim, min_latent_height, min_latent_width),
                       (2 * batch_size, self.unet_dim, latent_height, latent_width),
                       (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width)],
            'encoder_hidden_states': [(2 * min_batch, self.text_maxlen, self.embedding_dim),
                                      (2 * batch_size, self.text_maxlen, self.embedding_dim),
                                      (2 * max_batch, self.text_maxlen, self.embedding_dim)],
            # The optimal shape is the requested size, not the minimum.
            'controlnet_cond': [(2 * min_batch, 3, min_image_height, min_image_width),
                                (2 * batch_size, 3, image_height, image_width),
                                (2 * max_batch, 3, max_image_height, max_image_width)],
        }

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return concrete tensor shapes for one batch size and image size."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        return {
            'sample': (2 * batch_size, self.unet_dim, latent_height, latent_width),  # e.g. 2,4,64,64
            'encoder_hidden_states': (2 * batch_size, self.text_maxlen, self.embedding_dim),  # e.g. 2,77,768
            'controlnet_cond': (2 * batch_size, 3, image_height, image_width),  # e.g. 2,3,512,512
            'latent': (2 * batch_size, 4, latent_height, latent_width),  # e.g. 2,4,64,64
        }

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return example (sample, encoder_hidden_states, controlnet_cond, timestep) tensors."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        dtype = torch.float16 if self.fp16 else torch.float32
        # NOTE(review): sample/controlnet_cond/timestep are float32 while the
        # model may be fp16; presumably the exporter handles the cast (e.g. via
        # autocast). Confirm before changing dtypes here.
        return (
            torch.randn(2 * batch_size, self.unet_dim, latent_height, latent_width,
                        dtype=torch.float32, device=self.device),  # sample
            torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim,
                        dtype=dtype, device=self.device),  # encoder_hidden_states
            torch.randn(2 * batch_size, 3, image_height, image_width,
                        dtype=torch.float32, device=self.device),  # controlnet_cond
            torch.tensor([1.], dtype=torch.float32, device=self.device),  # timestep
        )
def make_UNet(version, hf_token, device, verbose, max_batch_size, inpaint=False):
    """Build the ControlNet-conditioned UNet wrapper (always fp16) for ``version``."""
    model_path = get_path(version, inpaint=inpaint)
    text_dim = get_embedding_dim(version)
    # Channel count of the `sample` input: 9 for inpainting checkpoints, else 4.
    sample_channels = 9 if inpaint else 4
    return UNet(hf_token=hf_token, fp16=True, device=device, verbose=verbose, path=model_path,
                max_batch_size=max_batch_size, embedding_dim=text_dim, unet_dim=sample_channels)
class VAE(BaseModel):
    """VAE decoder, exported as ``latent -> images`` (images are 8x the latent size)."""

    def __init__(self,
                 hf_token,
                 device,
                 verbose,
                 path,
                 max_batch_size,
                 embedding_dim
                 ):
        super().__init__(hf_token, device=device, verbose=verbose, path=path,
                         max_batch_size=max_batch_size, embedding_dim=embedding_dim)
        self.name = "VAE decoder"

    def get_model(self):
        """Load the ``vae`` subfolder and route ``forward`` to ``decode``."""
        loaded = AutoencoderKL.from_pretrained(self.path,
                                               subfolder="vae",
                                               use_auth_token=self.hf_token)
        decoder = loaded.to(self.device)
        # Export only the decoding path.
        decoder.forward = decoder.decode
        return decoder

    def get_input_names(self):
        return ['latent']

    def get_output_names(self):
        return ['images']

    def get_dynamic_axes(self):
        """Batch and spatial axes are dynamic; image axes are 8x the latent axes."""
        return {
            'latent': {0: 'B', 2: 'H', 3: 'W'},
            'images': {0: 'B', 2: '8H', 3: '8W'},
        }

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
        """Return [min, opt, max] shapes for the 4-channel ``latent`` input."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        dims = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
        lo_batch, hi_batch = dims[0], dims[1]
        lo_h, hi_h, lo_w, hi_w = dims[6], dims[7], dims[8], dims[9]
        return {
            'latent': [(lo_batch, 4, lo_h, lo_w),
                       (batch_size, 4, latent_height, latent_width),
                       (hi_batch, 4, hi_h, hi_w)],
        }

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return concrete latent/image shapes for one batch size."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        shapes = {}
        shapes['latent'] = (batch_size, 4, latent_height, latent_width)
        shapes['images'] = (batch_size, 3, image_height, image_width)
        return shapes

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return a random float32 latent tensor of the matching shape."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device)
def make_VAE(version, hf_token, device, verbose, max_batch_size, inpaint=False):
    """Build a VAE-decoder wrapper whose path and embedding width follow ``version``."""
    model_path = get_path(version, inpaint=inpaint)
    text_dim = get_embedding_dim(version)
    return VAE(hf_token=hf_token, device=device, verbose=verbose, path=model_path,
               max_batch_size=max_batch_size, embedding_dim=text_dim)
class TorchVAEEncoder(torch.nn.Module):
    """Wrap a VAE so that ``forward`` returns one sample of the encoded latent.

    Args:
        token: Hugging Face auth token.
        device: torch device for the loaded AutoencoderKL.
        path: model id/path whose ``vae`` subfolder is loaded.
    """

    def __init__(self, token, device, path):
        super().__init__()
        self.path = path
        autoencoder = AutoencoderKL.from_pretrained(path, subfolder="vae", use_auth_token=token)
        self.vae_encoder = autoencoder.to(device)

    def forward(self, x):
        """Encode ``x`` and draw a sample from the resulting latent distribution."""
        posterior = self.vae_encoder.encode(x).latent_dist
        return posterior.sample()
class VAEEncoder(BaseModel):
    """VAE encoder, exported as ``images -> latent`` (latent is 1/8 the image size)."""

    def __init__(self,
                 hf_token,
                 device,
                 verbose,
                 path,
                 max_batch_size,
                 embedding_dim
                 ):
        super(VAEEncoder, self).__init__(hf_token, device=device, verbose=verbose, path=path,
                                         max_batch_size=max_batch_size, embedding_dim=embedding_dim)
        self.name = "VAE encoder"

    def get_model(self):
        """Return a TorchVAEEncoder built from ``self.path`` on ``self.device``."""
        vae_encoder = TorchVAEEncoder(self.hf_token, self.device, self.path)
        return vae_encoder

    def get_input_names(self):
        return ['images']

    def get_output_names(self):
        return ['latent']

    def get_dynamic_axes(self):
        """Batch and spatial axes are dynamic; image axes are 8x the latent axes."""
        return {
            'images': {0: 'B', 2: '8H', 3: '8W'},
            'latent': {0: 'B', 2: 'H', 3: 'W'}
        }

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
        """Return [min, opt, max] shapes for the 3-channel ``images`` input."""
        # check_dims already asserts min_batch <= batch_size <= max_batch, and
        # get_minmax_dims computes the batch bounds, so neither is repeated here.
        self.check_dims(batch_size, image_height, image_width)
        min_batch, max_batch, min_image_height, max_image_height, min_image_width, max_image_width, _, _, _, _ = \
            self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
        return {
            'images': [(min_batch, 3, min_image_height, min_image_width),
                       (batch_size, 3, image_height, image_width),
                       (max_batch, 3, max_image_height, max_image_width)],
        }

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return concrete image/latent shapes for one batch size."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        return {
            'images': (batch_size, 3, image_height, image_width),
            'latent': (batch_size, 4, latent_height, latent_width)
        }

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return a random float32 image tensor of shape (batch_size, 3, H, W)."""
        self.check_dims(batch_size, image_height, image_width)
        return torch.randn(batch_size, 3, image_height, image_width, dtype=torch.float32, device=self.device)
def make_VAEEncoder(version, hf_token, device, verbose, max_batch_size, inpaint=False):
    """Build a VAE-encoder wrapper whose path and embedding width follow ``version``."""
    model_path = get_path(version, inpaint=inpaint)
    text_dim = get_embedding_dim(version)
    return VAEEncoder(hf_token=hf_token, device=device, verbose=verbose, path=model_path,
                      max_batch_size=max_batch_size, embedding_dim=text_dim)
def make_tokenizer(version, hf_token):
    """Load the CLIP tokenizer from the ``tokenizer`` subfolder of ``version``'s checkpoint."""
    repo_id = get_path(version)
    return CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", use_auth_token=hf_token)