Skip to content

Commit 382b22c

Browse files
mbaek01wooyeolbaek
authored and committed
Add sd3.5 and flux
1 parent 043aadc commit 382b22c

19 files changed

+638
-10
lines changed

assets/0-<|startoftext|>.png

-2.83 KB
Binary file not shown.

assets/1-<a>.png

-1.2 KB
Binary file not shown.

assets/10-<hello>.png

-956 Bytes
Binary file not shown.

assets/11-<world>.png

-834 Bytes
Binary file not shown.

assets/12-<.>.png

-803 Bytes
Binary file not shown.

assets/13-<|endoftext|>.png

-1.4 KB
Binary file not shown.

assets/2-<cap-.png

-921 Bytes
Binary file not shown.

assets/4--bara>.png

-1.99 KB
Binary file not shown.

assets/5-<holding>.png

-1.2 KB
Binary file not shown.

assets/6-<a>.png

-1.25 KB
Binary file not shown.

assets/7-<sign>.png

-1.26 KB
Binary file not shown.

assets/8-<that>.png

-954 Bytes
Binary file not shown.

assets/9-<reads>.png

-982 Bytes
Binary file not shown.

attention_map_diffusers/modules.py

+463-8
Large diffs are not rendered by default.

attention_map_diffusers/utils.py

+23-2
Original file line number | Diff line number | Diff line change
@@ -6,10 +6,11 @@
66

77
from diffusers.models import Transformer2DModel
88
from diffusers.models.unets import UNet2DConditionModel
9-
from diffusers.models.transformers import SD3Transformer2DModel, FluxTransformer2DModel
9+
from diffusers.models.transformers import SD3Transformer2DModel, FluxTransformer2DModel, SanaTransformer2DModel
10+
from diffusers.models.transformers.sana_transformer import SanaTransformerBlock
1011
from diffusers.models.transformers.transformer_flux import FluxTransformerBlock
1112
from diffusers.models.attention import BasicTransformerBlock, JointTransformerBlock
12-
from diffusers import FluxPipeline
13+
from diffusers import FluxPipeline, SanaPipeline
1314
from diffusers.models.attention_processor import (
1415
AttnProcessor,
1516
AttnProcessor2_0,
@@ -46,6 +47,7 @@ def register_cross_attention_hook(model, hook_function, target_name):
4647
module.processor.store_attn_map = True
4748
elif isinstance(module.processor, AttnProcessor2_0):
4849
module.processor.store_attn_map = True
50+
print(f'registered at {name}')
4951
elif isinstance(module.processor, LoRAAttnProcessor):
5052
module.processor.store_attn_map = True
5153
elif isinstance(module.processor, LoRAAttnProcessor2_0):
@@ -77,6 +79,20 @@ def replace_call_method_for_unet(model):
7779
return model
7880

7981

82+
def replace_call_method_for_sana(model):
    """Rebind ``forward`` on Sana modules, walking the module tree recursively.

    If ``model`` is a ``SanaTransformer2DModel`` (matched by class name), its
    ``forward`` is replaced with ``SanaTransformer2DModelForward`` bound to
    that instance. Every direct child that is a ``SanaTransformerBlock``
    (again matched by class name) gets ``SanaTransformerBlockForward`` bound
    the same way, and the function then recurses into each child.

    Args:
        model: Object exposing ``named_children()``.

    Returns:
        The same ``model`` object, patched in place.
    """
    if model.__class__.__name__ == 'SanaTransformer2DModel':
        # ``__get__`` turns the plain function into a method bound to ``model``.
        bound_forward = SanaTransformer2DModelForward.__get__(model, SanaTransformer2DModel)
        model.forward = bound_forward

    for _child_name, child in model.named_children():
        is_sana_block = child.__class__.__name__ == 'SanaTransformerBlock'
        if is_sana_block:
            child.forward = SanaTransformerBlockForward.__get__(child, SanaTransformerBlock)

        # Descend regardless of the child's type so nested blocks are reached.
        replace_call_method_for_sana(child)

    return model
94+
95+
8096
def replace_call_method_for_sd3(model):
8197
if model.__class__.__name__ == 'SD3Transformer2DModel':
8298
model.forward = SD3Transformer2DModelForward.__get__(model, SD3Transformer2DModel)
@@ -122,6 +138,11 @@ def init_pipeline(pipeline):
122138
pipeline.transformer = register_cross_attention_hook(pipeline.transformer, hook_function, 'attn')
123139
pipeline.transformer = replace_call_method_for_flux(pipeline.transformer)
124140

141+
elif pipeline.transformer.__class__.__name__ == 'SanaTransformer2DModel':
142+
SanaPipeline.__call__ = SanaPipeline_call
143+
pipeline.transformer = register_cross_attention_hook(pipeline.transformer, hook_function, 'attn2')
144+
pipeline.transformer = replace_call_method_for_sana(pipeline.transformer)
145+
125146
else:
126147
if pipeline.unet.__class__.__name__ == 'UNet2DConditionModel':
127148
pipeline.unet = register_cross_attention_hook(pipeline.unet, hook_function, 'attn2')

demo/demo-flux-dev.py

+37
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,37 @@
1+
import torch
2+
from diffusers import FluxPipeline
3+
from attention_map_diffusers import (
4+
attn_maps,
5+
init_pipeline,
6+
save_attention_maps
7+
)
8+
9+
# Load the FLUX.1-dev checkpoint with bfloat16 weights.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
# Lower-VRAM alternative (disabled): offload the model to CPU instead of
# keeping the whole pipeline on the GPU.
# pipe.enable_model_cpu_offload()
pipe.to('cuda')

##### 1. Replace modules and register hook #####
pipe = init_pipeline(pipe)
################################################

# Batched prompts are discouraged because CPU memory could be exceeded.
# NOTE(review): the original remark names sd3; presumably it applies to FLUX
# as well -- confirm.
prompts = [
    # "A photo of a puppy wearing a hat.",
    "A capybara holding a sign that reads Hello World.",
]

result = pipe(prompts, num_inference_steps=15, guidance_scale=4.5)
images = result.images

# One PNG per generated image, named by its position in the batch.
for index, picture in enumerate(images):
    picture.save(f'{index}-flux-dev.png')

##### 2. Process and save attention maps #####
save_attention_maps(
    attn_maps,
    pipe.tokenizer,
    prompts,
    base_dir='attn_maps',
    unconditional=False,
)
##############################################

demo/demo-flux-schnell.py

+37
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,37 @@
1+
import torch
2+
from diffusers import FluxPipeline
3+
from attention_map_diffusers import (
4+
attn_maps,
5+
init_pipeline,
6+
save_attention_maps
7+
)
8+
9+
# Load the FLUX.1-schnell checkpoint with bfloat16 weights.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
# Lower-VRAM alternative (disabled): offload the model to CPU instead of
# keeping the whole pipeline on the GPU.
# pipe.enable_model_cpu_offload()
pipe.to('cuda')

##### 1. Replace modules and register hook #####
pipe = init_pipeline(pipe)
################################################

# Batched prompts are discouraged because CPU memory could be exceeded.
# NOTE(review): the original remark names sd3; presumably it applies to FLUX
# as well -- confirm.
prompts = [
    # "A photo of a puppy wearing a hat.",
    "A capybara holding a sign that reads Hello World.",
]

result = pipe(prompts, num_inference_steps=15, guidance_scale=4.5)
images = result.images

# One PNG per generated image, named by its position in the batch.
for index, picture in enumerate(images):
    picture.save(f'{index}-flux-schnell.png')

##### 2. Process and save attention maps #####
save_attention_maps(
    attn_maps,
    pipe.tokenizer,
    prompts,
    base_dir='attn_maps',
    unconditional=False,
)
##############################################

demo/demo-sana.py

+42
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,42 @@
1+
import torch
2+
from diffusers import SanaPipeline
3+
from attention_map_diffusers import (
4+
attn_maps,
5+
init_pipeline,
6+
save_attention_maps
7+
)
8+
9+
# Load the fp16 variant of the Sana 1600M / 1024px checkpoint.
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
    variant="fp16",
    torch_dtype=torch.float16,
)
pipe.to("cuda")

# The VAE and the text encoder are switched to bfloat16 after loading.
pipe.vae.to(torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)

##### 1. Replace modules and register hook #####
# TODO: not implemented yet.
pipe = init_pipeline(pipe)
################################################

prompts = [
    "a cyberpunk cat with a neon sign that says 'Sana'",
    # "A capybara holding a sign that reads Hello World.",
]

# Fixed seed on a CUDA generator.
generator = torch.Generator(device="cuda").manual_seed(42)
result = pipe(
    prompt=prompts,
    height=1024,
    width=1024,
    guidance_scale=5.0,
    num_inference_steps=20,
    generator=generator,
)
images = result.images

# One PNG per generated image, named by its position in the batch.
for index, picture in enumerate(images):
    picture.save(f'{index}-sana.png')

##### 2. Process and save attention maps #####
save_attention_maps(
    attn_maps,
    pipe.tokenizer,
    prompts,
    base_dir='attn_maps',
    unconditional=True,
)
##############################################

demo/demo-sd3-5.py

+36
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,36 @@
1+
import torch
2+
from diffusers import StableDiffusion3Pipeline
3+
from attention_map_diffusers import (
4+
attn_maps,
5+
init_pipeline,
6+
save_attention_maps
7+
)
8+
9+
# Load the Stable Diffusion 3.5 medium checkpoint with bfloat16 weights.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")

##### 1. Replace modules and register hook #####
pipe = init_pipeline(pipe)
################################################

# Batched prompts are not recommended for sd3, as CPU memory could be exceeded.
prompts = [
    # "A photo of a puppy wearing a hat.",
    "A capybara holding a sign that reads Hello World.",
]

result = pipe(prompts, num_inference_steps=15, guidance_scale=4.5)
images = result.images

# One PNG per generated image, named by its position in the batch.
for index, picture in enumerate(images):
    picture.save(f'{index}-sd3-5.png')

##### 2. Process and save attention maps #####
save_attention_maps(
    attn_maps,
    pipe.tokenizer,
    prompts,
    base_dir='attn_maps',
    unconditional=True,
)
##############################################

0 commit comments

Comments
 (0)