-
Notifications
You must be signed in to change notification settings - Fork 881
/
sdxl_gen_img.py
executable file
·3210 lines (2778 loc) · 137 KB
/
sdxl_gen_img.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import itertools
import json
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
import glob
import importlib
import inspect
import time
import zipfile
from diffusers.utils import deprecate
from diffusers.configuration_utils import FrozenDict
import argparse
import math
import os
import random
import re
import diffusers
import numpy as np
import torch
from library.device_utils import init_ipex, clean_memory, get_preferred_device
init_ipex()
import torchvision
from diffusers import (
AutoencoderKL,
DDPMScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
DDIMScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
# UNet2DConditionModel,
StableDiffusionPipeline,
)
from einops import rearrange
from tqdm import tqdm
from torchvision import transforms
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor
import PIL
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import library.model_util as model_util
import library.train_util as train_util
import library.sdxl_model_util as sdxl_model_util
import library.sdxl_train_util as sdxl_train_util
from networks.lora import LoRANetwork
from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
from library.original_unet import FlashAttentionFunction
from networks.control_net_lllite import ControlNetLLLite
from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
# --- scheduler settings ---
SCHEDULER_LINEAR_START = 8.5e-4
SCHEDULER_LINEAR_END = 1.2e-2
SCHEDULER_TIMESTEPS = 1_000
# NOTE: the name is misspelled ("SCHEDLER"); kept unchanged because other code may reference it
SCHEDLER_SCHEDULE = "scaled_linear"

# --- other settings ---
LATENT_CHANNELS = 4
DOWNSAMPLING_FACTOR = 8
# Hugging Face model id of the CLIP vision model
CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"

# region module replacement (swapping attention implementations for speed)
"""
高速化のためのモジュール入れ替え
"""
def replace_unet_modules(unet: InferSdxlUNet2DConditionModel, mem_eff_attn, xformers, sdpa):
    """Select the attention implementation used by the SDXL U-Net.

    The flags are checked in priority order mem_eff_attn > xformers > sdpa. If none is set,
    the U-Net is left unchanged.

    Args:
        unet: the project's own SDXL U-Net (provides ``set_use_memory_efficient_attention``
            and ``set_use_sdpa``).
        mem_eff_attn: enable the U-Net's built-in memory-efficient attention.
        xformers: enable xformers memory-efficient attention.
        sdpa: enable torch scaled_dot_product_attention.

    Raises:
        ImportError: if ``xformers`` is requested but the xformers package cannot be imported.
    """
    # NOTE(review): the positional args of set_use_memory_efficient_attention look like
    # (xformers, mem_eff) judging from the three call sites below — confirm in sdxl_original_unet.
    if mem_eff_attn:
        logger.info("Enable memory efficient attention for U-Net")
        # this is our own U-Net, not the diffusers one, so no module replacement is required
        unet.set_use_memory_efficient_attention(False, True)
    elif xformers:
        logger.info("Enable xformers for U-Net")
        try:
            # import under an alias so the boolean `xformers` parameter is not shadowed by the module
            import xformers.ops as _xformers_ops  # noqa: F401
        except ImportError as e:
            raise ImportError("No xformers / xformersがインストールされていないようです") from e
        unet.set_use_memory_efficient_attention(True, False)
    elif sdpa:
        logger.info("Enable SDPA for U-Net")
        unet.set_use_memory_efficient_attention(False, False)
        unet.set_use_sdpa(True)
# TODO common train_util.py
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers, sdpa):
    """Switch the VAE attention implementation according to the given flags.

    The flags are checked in priority order mem_eff_attn > xformers > sdpa. If none is set,
    nothing is changed.
    """
    if mem_eff_attn:
        replace_vae_attn_to_memory_efficient()
        return

    if xformers:
        # The original note suspects that replace_vae_attn_to_xformers() may raise errors at some
        # resolutions, so diffusers' built-in xformers switch is used for now instead.
        vae.set_use_memory_efficient_attention_xformers(True)
        return

    if sdpa:
        replace_vae_attn_to_sdpa()
def _vae_version_tuple(version: str) -> Tuple[int, ...]:
    """Return the leading numeric components of a version string.

    ``"0.15.0.dev0"`` -> ``(0, 15, 0)``. Parsing stops at the first non-numeric component. This is
    used instead of a plain string comparison, which would order ``"0.9.0"`` after ``"0.15.0"``.
    """
    parts = []
    for component in version.split("."):
        if not component.isdigit():
            break
        parts.append(int(component))
    return tuple(parts)


def _vae_attention_forward(attn, hidden_states, attention_fn):
    """Shared forward of the diffusers VAE self-attention block with a pluggable attention kernel.

    Args:
        attn: the patched attention module (``self`` of the replaced forward). It must expose
            ``group_norm``, ``to_q``/``to_k``/``to_v``, ``to_out`` (projection, dropout), ``heads``
            and ``rescale_output_factor``.
        hidden_states: feature map of shape (batch, channel, height, width).
        attention_fn: callable ``(q, k, v) -> out``. It takes and returns tensors laid out as
            (batch, heads, tokens, head_dim).

    Returns:
        A tensor with the same shape as ``hidden_states``:
        ``(to_out(attention(norm(x))) + x) / rescale_output_factor``.
    """
    residual = hidden_states
    batch, channel, height, width = hidden_states.shape
    num_tokens = height * width

    # norm, then flatten the spatial dims into a token axis: (b, c, h, w) -> (b, n, c)
    hidden_states = attn.group_norm(hidden_states)
    hidden_states = hidden_states.view(batch, channel, num_tokens).transpose(1, 2)

    def to_heads(t):
        # (b, n, h*d) -> (b, h, n, d)
        return t.reshape(batch, num_tokens, attn.heads, -1).transpose(1, 2)

    query = to_heads(attn.to_q(hidden_states))
    key = to_heads(attn.to_k(hidden_states))
    value = to_heads(attn.to_v(hidden_states))

    out = attention_fn(query, key, value)
    out = out.transpose(1, 2).reshape(batch, num_tokens, -1)  # (b, h, n, d) -> (b, n, h*d)

    # linear projection + dropout are applied to the attention OUTPUT (not to the normalized input)
    hidden_states = attn.to_out[0](out)
    hidden_states = attn.to_out[1](hidden_states)
    hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

    # residual connection and rescale
    return (hidden_states + residual) / attn.rescale_output_factor


def _install_vae_attention_forward(forward):
    """Monkey-patch the diffusers attention class so that every instance uses *forward*.

    For diffusers < 0.15, ``attention.AttentionBlock`` is patched. Its ``query``/``key``/``value``/
    ``proj_attn``/``num_heads`` attributes are aliased to the newer ``to_q``/``to_k``/``to_v``/
    ``to_out``/``heads`` names on first use. Otherwise ``attention_processor.Attention`` is patched.
    """
    if _vae_version_tuple(diffusers.__version__) < (0, 15):

        def forward_0_14(self, hidden_states, **kwargs):
            if not hasattr(self, "to_q"):
                self.to_q = self.query
                self.to_k = self.key
                self.to_v = self.value
                self.to_out = [self.proj_attn, torch.nn.Identity()]
                self.heads = self.num_heads
            return forward(self, hidden_states, **kwargs)

        diffusers.models.attention.AttentionBlock.forward = forward_0_14
    else:
        diffusers.models.attention_processor.Attention.forward = forward


def replace_vae_attn_to_memory_efficient():
    """Replace the VAE attention forward with the project's FlashAttention implementation."""
    logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
    flash_func = FlashAttentionFunction
    q_bucket_size = 512
    k_bucket_size = 1024

    def flash_attention(query, key, value):
        # NOTE(review): positional args assumed to be (q, k, v, mask, causal, q_bucket, k_bucket),
        # i.e. no mask and non-causal — confirm against library.original_unet.FlashAttentionFunction
        return flash_func.apply(query, key, value, None, False, q_bucket_size, k_bucket_size)

    def forward_flash_attn(self, hidden_states, **kwargs):
        return _vae_attention_forward(self, hidden_states, flash_attention)

    _install_vae_attention_forward(forward_flash_attn)


def replace_vae_attn_to_xformers():
    """Replace the VAE attention forward with xformers memory-efficient attention."""
    logger.info("VAE: Attention.forward has been replaced to xformers")
    import xformers.ops

    def xformers_attention(query, key, value):
        # xformers.ops.memory_efficient_attention expects (batch, tokens, heads, head_dim)
        query, key, value = (t.transpose(1, 2).contiguous() for t in (query, key, value))
        out = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
        return out.transpose(1, 2)

    def forward_xformers(self, hidden_states, **kwargs):
        return _vae_attention_forward(self, hidden_states, xformers_attention)

    _install_vae_attention_forward(forward_xformers)


def replace_vae_attn_to_sdpa():
    """Replace the VAE attention forward with torch scaled_dot_product_attention."""
    logger.info("VAE: Attention.forward has been replaced to sdpa")

    def sdpa_attention(query, key, value):
        # scaled_dot_product_attention attends over dim -2, so heads must precede tokens: (b, h, n, d)
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
        )

    def forward_sdpa(self, hidden_states, **kwargs):
        return _vae_attention_forward(self, hidden_states, sdpa_attention)

    _install_vae_attention_forward(forward_sdpa)
# endregion
# region 画像生成の本体:lpw_stable_diffusion.py (ASL)からコピーして修正
# https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
# Pipelineだけ独立して使えないのと機能追加するのとでコピーして修正
class PipelineLike:
    """Standalone SDXL image generation pipeline (txt2img, img2img, inpainting).

    Copied from the diffusers community pipeline ``lpw_stable_diffusion.py`` and modified, because
    that pipeline cannot be used on its own and to add features. The additions are: two text
    encoders, SDXL size/crop conditioning, Textual Inversion token replacement, ControlNet-LLLite,
    CLIP-vision pooled-embedding replacement and gradual latent upscaling.
    """

    def __init__(
        self,
        device,
        vae: AutoencoderKL,
        text_encoders: List[CLIPTextModel],
        tokenizers: List[CLIPTokenizer],
        unet: InferSdxlUNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        clip_skip: int,
    ):
        """Store the models and normalize outdated scheduler configs.

        Args:
            device: torch device used for all computation.
            vae: autoencoder used to encode init images and decode latents.
            text_encoders: text encoders, paired index-by-index with ``tokenizers``.
            tokenizers: tokenizers, one per text encoder.
            unet: the SDXL U-Net.
            scheduler: diffusion scheduler; its config may be rewritten (steps_offset, clip_sample).
            clip_skip: forwarded to ``get_weighted_text_embeddings``.
        """
        super().__init__()
        self.device = device
        self.clip_skip = clip_skip

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        self.vae = vae
        self.text_encoders = text_encoders
        self.tokenizers = tokenizers
        self.unet: InferSdxlUNet2DConditionModel = unet
        self.scheduler = scheduler
        self.safety_checker = None

        # optional CLIP vision conditioning (replaces the pooled text embedding, see __call__)
        self.clip_vision_model: CLIPVisionModelWithProjection = None
        self.clip_vision_processor: CLIPImageProcessor = None
        self.clip_vision_strength = 0.0

        # Textual Inversion: one {token_id: [replacement token ids]} dict per text encoder
        self.token_replacements_list = []
        for _ in range(len(self.text_encoders)):
            self.token_replacements_list.append({})

        # ControlNet # not supported yet
        self.control_nets: List[ControlNetLLLite] = []
        # if control_nets is empty, ControlNet does nothing whether this flag is True or False
        self.control_net_enabled = True

        self.gradual_latent: GradualLatent = None

    # Textual Inversion
    def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids):
        """Register that ``target_token_id`` expands to ``rep_token_ids`` for one text encoder."""
        self.token_replacements_list[text_encoder_index][target_token_id] = rep_token_ids

    def set_enable_control_net(self, en: bool):
        self.control_net_enabled = en

    def get_token_replacer(self, tokenizer):
        """Return a function that expands Textual Inversion tokens for ``tokenizer``'s text encoder."""
        tokenizer_index = self.tokenizers.index(tokenizer)
        token_replacements = self.token_replacements_list[tokenizer_index]

        def replace_tokens(tokens):
            # accepts a tensor or a list of token ids; always returns a new list
            if isinstance(tokens, torch.Tensor):
                tokens = tokens.tolist()

            new_tokens = []
            for token in tokens:
                if token in token_replacements:
                    replacement = token_replacements[token]
                    new_tokens.extend(replacement)
                else:
                    new_tokens.append(token)
            return new_tokens

        return replace_tokens

    def set_control_nets(self, ctrl_nets):
        # list of (control_net, ratio) pairs, as unpacked in __call__
        self.control_nets = ctrl_nets

    def set_gradual_latent(self, gradual_latent):
        if gradual_latent is None:
            logger.info("gradual_latent is disabled")
            self.gradual_latent = None
        else:
            logger.info(f"gradual_latent is enabled: {gradual_latent}")
            # GradualLatent: __call__ reads ratio, start_timesteps, every_n_steps, ratio_step, gaussian_blur_ksize
            self.gradual_latent = gradual_latent

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        init_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
        mask_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
        height: int = 1024,
        width: int = 1024,
        original_height: int = None,
        original_width: int = None,
        original_height_negative: int = None,
        original_width_negative: int = None,
        crop_top: int = 0,
        crop_left: int = 0,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_scale: float = None,
        strength: float = 0.8,
        # num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        max_embeddings_multiples: Optional[int] = 3,
        output_type: Optional[str] = "pil",
        vae_batch_size: float = None,
        return_latents: bool = False,
        # return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        is_cancelled_callback: Optional[Callable[[], bool]] = None,
        callback_steps: Optional[int] = 1,
        img2img_noise=None,
        clip_guide_images=None,
        **kwargs,
    ):
        """Run the denoising loop and decode the result.

        Args:
            prompt: prompt or list of prompts (one per image; a prompt containing " AND " enables
                regional prompting, whose last sub-prompt supplies the pooled embedding).
            negative_prompt: negative prompt(s); defaults to "" for every image.
            init_image / mask_image: enable img2img / inpainting when given.
            height, width: output size in pixels; must be divisible by 8.
            original_*, crop_top, crop_left: SDXL size/crop conditioning; the ``*_negative``
                sizes are used for the negative (uncond) branch.
            guidance_scale: classifier-free guidance scale; guidance is disabled when <= 1.
            negative_scale: if set (and guidance is enabled), an additional empty-prompt
                "real uncond" branch is run and the negative prompt is pushed away with this scale.
            strength: img2img strength in [0, 1].
            vae_batch_size: VAE batch size; a value < 1 is a fraction of the batch size.
            return_latents: return the final latents instead of decoding them.
            callback / callback_steps: called as ``callback(i, t, latents)`` every ``callback_steps`` steps.
            is_cancelled_callback: if it returns True at a callback step, ``None`` is returned.
            img2img_noise: noise used by ``scheduler.add_noise`` for img2img / inpainting.
            clip_guide_images: hint images for ControlNet-LLLite.
            **kwargs: forwarded to ``get_weighted_text_embeddings``.

        Returns:
            A list of PIL images if ``output_type == "pil"``; otherwise a float32 numpy array of shape
            (batch, height, width, channels) clamped to [0, 1]. Returns the latents if
            ``return_latents`` is set, or ``None`` if cancelled.

        Raises:
            ValueError: on an invalid prompt type, strength, size, callback_steps, negative prompt
                count, latents shape, or a mask/init image size mismatch.
        """
        # TODO support secondary prompt
        num_images_per_prompt = 1  # fixed because already prompt is repeated

        if isinstance(prompt, str):
            batch_size = 1
            prompt = [prompt]
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        regional_network = " AND " in prompt[0]

        vae_batch_size = (
            batch_size
            if vae_batch_size is None
            else (int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size)))
        )

        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
            )

        # get prompt text embeddings

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        if not do_classifier_free_guidance and negative_scale is not None:
            logger.info(f"negative_scale is ignored if guidance scalle <= 1.0")
            negative_scale = None

        # get unconditional embeddings for classifier free guidance
        if negative_prompt is None:
            negative_prompt = [""] * batch_size
        elif isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt] * batch_size
        if batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )

        tes_text_embs = []
        tes_uncond_embs = []
        tes_real_uncond_embs = []
        real_uncond_pool = None
        for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
            token_replacer = self.get_token_replacer(tokenizer)

            # the pools kept after the loop are those of the last encoder (text encoder 2)
            text_embeddings, text_pool, uncond_embeddings, uncond_pool, _ = get_weighted_text_embeddings(
                tokenizer,
                text_encoder,
                prompt=prompt,
                uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
                max_embeddings_multiples=max_embeddings_multiples,
                clip_skip=self.clip_skip,
                token_replacer=token_replacer,
                device=self.device,
                **kwargs,
            )
            tes_text_embs.append(text_embeddings)
            tes_uncond_embs.append(uncond_embeddings)

            if negative_scale is not None:
                # same call/return convention as above; only the uncond part (empty prompt) is used
                _, _, real_uncond_embeddings, real_uncond_pool, _ = get_weighted_text_embeddings(
                    tokenizer,
                    text_encoder,
                    # required for prompts over 75 tokens: uncond is built to match this prompt's token length
                    prompt=prompt,
                    uncond_prompt=[""] * batch_size,
                    max_embeddings_multiples=max_embeddings_multiples,
                    clip_skip=self.clip_skip,
                    token_replacer=token_replacer,
                    device=self.device,
                    **kwargs,
                )
                tes_real_uncond_embs.append(real_uncond_embeddings)

        # concatenate the text encoder outputs along the feature axis
        text_embeddings = torch.cat(tes_text_embs, dim=2)
        if do_classifier_free_guidance:
            uncond_embeddings = torch.cat(tes_uncond_embs, dim=2)

        if do_classifier_free_guidance:
            if negative_scale is None:
                text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
            else:
                real_uncond_embeddings = torch.cat(tes_real_uncond_embs, dim=2)
                text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])

        if self.control_nets:
            # reuse the guide images as ControlNet hints
            if isinstance(clip_guide_images, PIL.Image.Image):
                clip_guide_images = [clip_guide_images]
            if isinstance(clip_guide_images[0], PIL.Image.Image):
                clip_guide_images = [preprocess_image(im) for im in clip_guide_images]
                clip_guide_images = torch.cat(clip_guide_images)
            if isinstance(clip_guide_images, list):
                clip_guide_images = torch.stack(clip_guide_images)

            clip_guide_images = clip_guide_images.to(self.device, dtype=text_embeddings.dtype)

        # create size embs
        if original_height is None:
            original_height = height
        if original_width is None:
            original_width = width
        if original_height_negative is None:
            original_height_negative = original_height
        if original_width_negative is None:
            original_width_negative = original_width
        if crop_top is None:
            crop_top = 0
        if crop_left is None:
            crop_left = 0
        emb1 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
        uc_emb1 = sdxl_train_util.get_timestep_embedding(
            torch.FloatTensor([original_height_negative, original_width_negative]).unsqueeze(0), 256
        )
        emb2 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
        emb3 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([height, width]).unsqueeze(0), 256)
        c_size_vector = (
            torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1)
        )
        uc_size_vector = (
            torch.cat([uc_emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1)
        )

        if regional_network:
            # use last pool for conditioning
            num_sub_prompts = len(text_pool) // batch_size
            text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts]  # last subprompt

        if init_image is not None and self.clip_vision_model is not None:
            logger.info(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}")
            vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device)
            pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype)

            clip_vision_embeddings = self.clip_vision_model(pixel_values=pixel_values, output_hidden_states=True, return_dict=True)
            clip_vision_embeddings = clip_vision_embeddings.image_embeds

            if len(clip_vision_embeddings) == 1 and batch_size > 1:
                clip_vision_embeddings = clip_vision_embeddings.repeat((batch_size, 1))

            clip_vision_embeddings = clip_vision_embeddings * self.clip_vision_strength
            assert clip_vision_embeddings.shape == text_pool.shape, f"{clip_vision_embeddings.shape} != {text_pool.shape}"
            text_pool = clip_vision_embeddings  # replace: same as ComfyUI (?)

        # the vector embeddings must line up with the chunks of text_embeddings built above
        c_vector = torch.cat([text_pool, c_size_vector], dim=1)
        if do_classifier_free_guidance:
            uc_vector = torch.cat([uncond_pool, uc_size_vector], dim=1)
            if negative_scale is None:
                vector_embeddings = torch.cat([uc_vector, c_vector])
            else:
                # the empty-prompt "real uncond" branch uses the positive size conditioning so that the
                # negative branch alone carries the negative sizes
                real_uc_vector = torch.cat([real_uncond_pool, c_size_vector], dim=1)
                vector_embeddings = torch.cat([uc_vector, c_vector, real_uc_vector])
        else:
            vector_embeddings = c_vector

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps, self.device)

        latents_dtype = text_embeddings.dtype
        init_latents_orig = None
        mask = None

        if init_image is None:
            # get the initial random noise unless the user supplied it

            # Unlike in other pipelines, latents need to be generated in the target device
            # for 1-to-1 results reproducibility with the CompVis implementation.
            # However this currently doesn't work in `mps`.
            latents_shape = (
                batch_size * num_images_per_prompt,
                self.unet.in_channels,
                height // 8,
                width // 8,
            )

            if latents is None:
                if self.device.type == "mps":
                    # randn does not exist on mps
                    latents = torch.randn(
                        latents_shape,
                        generator=generator,
                        device="cpu",
                        dtype=latents_dtype,
                    ).to(self.device)
                else:
                    latents = torch.randn(
                        latents_shape,
                        generator=generator,
                        device=self.device,
                        dtype=latents_dtype,
                    )
            else:
                if latents.shape != latents_shape:
                    raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
                latents = latents.to(self.device)

            timesteps = self.scheduler.timesteps.to(self.device)

            # scale the initial noise by the standard deviation required by the scheduler
            latents = latents * self.scheduler.init_noise_sigma
        else:
            # image to tensor
            if isinstance(init_image, PIL.Image.Image):
                init_image = [init_image]
            if isinstance(init_image[0], PIL.Image.Image):
                init_image = [preprocess_image(im) for im in init_image]
                init_image = torch.cat(init_image)
            if isinstance(init_image, list):
                init_image = torch.stack(init_image)

            # mask image to tensor
            if mask_image is not None:
                if isinstance(mask_image, PIL.Image.Image):
                    mask_image = [mask_image]
                if isinstance(mask_image[0], PIL.Image.Image):
                    mask_image = torch.cat([preprocess_mask(im) for im in mask_image])  # H*W, 0 for repaint

            # encode the init image into latents and scale the latents
            init_image = init_image.to(device=self.device, dtype=latents_dtype)
            if init_image.size()[-2:] == (height // 8, width // 8):
                # already latents: use as-is
                init_latents = init_image
            else:
                if vae_batch_size >= batch_size:
                    init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist
                    init_latents = init_latent_dist.sample(generator=generator)
                else:
                    clean_memory()
                    init_latents = []
                    for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)):
                        init_latent_dist = self.vae.encode(
                            (init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0)).to(
                                self.vae.dtype
                            )
                        ).latent_dist
                        init_latents.append(init_latent_dist.sample(generator=generator))
                    init_latents = torch.cat(init_latents)

                init_latents = sdxl_model_util.VAE_SCALE_FACTOR * init_latents

            if len(init_latents) == 1:
                init_latents = init_latents.repeat((batch_size, 1, 1, 1))
            init_latents_orig = init_latents

            # preprocess mask
            if mask_image is not None:
                mask = mask_image.to(device=self.device, dtype=latents_dtype)
                if len(mask) == 1:
                    mask = mask.repeat((batch_size, 1, 1, 1))

                # check sizes
                if not mask.shape == init_latents.shape:
                    raise ValueError("The mask and init_image should be the same size!")

            # get the original timestep using init_timestep
            offset = self.scheduler.config.get("steps_offset", 0)
            init_timestep = int(num_inference_steps * strength) + offset
            init_timestep = min(init_timestep, num_inference_steps)

            timesteps = self.scheduler.timesteps[-init_timestep]
            timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)

            # add noise to latents using the timesteps
            latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps)

            t_start = max(num_inference_steps - init_timestep + offset, 0)
            timesteps = self.scheduler.timesteps[t_start:].to(self.device)

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # number of stacked U-Net inputs: (negative, positive[, real uncond]) or just positive
        num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1

        if self.control_nets:
            if self.control_net_enabled:
                for control_net, _ in self.control_nets:
                    with torch.no_grad():
                        control_net.set_cond_image(clip_guide_images)
            else:
                for control_net, _ in self.control_nets:
                    control_net.set_cond_image(None)

            each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets)

        enable_gradual_latent = False
        if self.gradual_latent:
            if not hasattr(self.scheduler, "set_gradual_latent_params"):
                logger.info("gradual_latent is not supported for this scheduler. Ignoring.")
                logger.info(f'{self.scheduler.__class__.__name__}')
            else:
                enable_gradual_latent = True
                step_elapsed = 1000
                current_ratio = self.gradual_latent.ratio

                # first, we downscale the latents to the specified ratio
                latent_height, latent_width = latents.shape[-2:]
                org_dtype = latents.dtype
                if org_dtype == torch.bfloat16:
                    latents = latents.float()
                latents = torch.nn.functional.interpolate(
                    latents, scale_factor=current_ratio, mode="bicubic", align_corners=False
                ).to(org_dtype)

                # apply unsharp mask
                if self.gradual_latent.gaussian_blur_ksize:
                    latents = self.gradual_latent.apply_unshark_mask(latents)

        for i, t in enumerate(tqdm(timesteps)):
            resized_size = None
            if enable_gradual_latent:
                # gradually upscale the latents
                if (
                    t < self.gradual_latent.start_timesteps
                    and current_ratio < 1.0
                    and step_elapsed >= self.gradual_latent.every_n_steps
                ):
                    current_ratio = min(current_ratio + self.gradual_latent.ratio_step, 1.0)
                    # make divisible by 8 because size of latents must be divisible at bottom of UNet
                    h = int(latent_height * current_ratio) // 8 * 8
                    w = int(latent_width * current_ratio) // 8 * 8
                    resized_size = (h, w)
                    self.scheduler.set_gradual_latent_params(resized_size, self.gradual_latent)
                    step_elapsed = 0
                else:
                    self.scheduler.set_gradual_latent_params(None, None)
                step_elapsed += 1

            # expand the latents if we are doing classifier free guidance
            latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # disable each control net once the progress exceeds its ratio
            if self.control_nets and self.control_net_enabled:
                for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_nets, each_control_net_enabled)):
                    if not enabled or ratio >= 1.0:
                        continue
                    if ratio < i / len(timesteps):
                        logger.info(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})")
                        control_net.set_cond_image(None)
                        each_control_net_enabled[j] = False

            # predict the noise residual
            # TODO Diffusers' ControlNet (original_control_net.call_unet_and_control_net) is not wired up yet
            noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings)

            # perform guidance
            if do_classifier_free_guidance:
                if negative_scale is None:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input)  # uncond by negative prompt
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                else:
                    noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk(
                        num_latent_input
                    )  # uncond is real uncond
                    noise_pred = (
                        noise_pred_uncond
                        + guidance_scale * (noise_pred_text - noise_pred_uncond)
                        - negative_scale * (noise_pred_negative - noise_pred_uncond)
                    )

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            if mask is not None:
                # masking: restore the (re-noised) original where mask is 1
                init_latents_proper = self.scheduler.add_noise(init_latents_orig, img2img_noise, torch.tensor([t]))
                latents = (init_latents_proper * mask) + (latents * (1 - mask))

            # call the callback, if provided
            if i % callback_steps == 0:
                if callback is not None:
                    callback(i, t, latents)
                if is_cancelled_callback is not None and is_cancelled_callback():
                    return None

        if return_latents:
            return latents

        latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
        if vae_batch_size >= batch_size:
            image = self.vae.decode(latents.to(self.vae.dtype)).sample
        else:
            clean_memory()
            images = []
            for i in tqdm(range(0, batch_size, vae_batch_size)):
                images.append(
                    self.vae.decode(
                        (latents[i : i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).to(self.vae.dtype)
                    ).sample
                )
            image = torch.cat(images)

        image = (image / 2 + 0.5).clamp(0, 1)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        clean_memory()

        if output_type == "pil":
            image = (image * 255).round().astype("uint8")
            image = [Image.fromarray(im) for im in image]

        return image
# return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
# Tokenizer for the prompt attention syntax; presumably consumed by parse_prompt_attention below.
# The alternatives are tried in order:
#   escaped \( \) \[ \] \\ , a lone backslash, "(", "[",
#   ":<weight>)" (weight captured in group 1), ")", "]", a run of plain text, and a bare ":".
re_attention = re.compile(
    r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
    flags=re.VERBOSE,
)
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1
\( - literal character '('
\[ - literal character '['
\) - literal character ')'
\] - literal character ']'
\\ - literal character '\'
anything else - just text
>>> parse_prompt_attention('normal text')
[['normal text', 1.0]]
>>> parse_prompt_attention('an (important) word')
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
>>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]]
>>> parse_prompt_attention('\(literal\]')
[['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]]
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
[['a ', 1.0],
['house', 1.5730000000000004],
[' ', 1.1],
['on', 1.0],
[' a ', 1.1],
['hill', 0.55],
[', sun, ', 1.1],
['sky', 1.4641000000000006],
['.', 1.1]]
"""
res = []
round_brackets = []
square_brackets = []
round_bracket_multiplier = 1.1
square_bracket_multiplier = 1 / 1.1
def multiply_range(start_position, multiplier):
for p in range(start_position, len(res)):
res[p][1] *= multiplier