# modal_stable_diffusion.py
# git clone https://github.com/modal-labs/modal-examples
# cd modal-examples
# modal run 06_gpu_and_ml/stable_diffusion/stable_diffusion_cli.py --prompt 'An 1600s oil painting of the New York City skyline'
# modal run 06_gpu_and_ml/stable_diffusion/stable_diffusion_cli.py --prompt 'An 1600s oil painting of Silicon Valley VC investors'
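#
# The local entrypoint's keyword arguments become CLI options, so (assuming
# Modal's usual snake_case -> kebab-case flag mapping) a run like the
# following should also work, with an example prompt of your choosing:
#   modal run 06_gpu_and_ml/stable_diffusion/stable_diffusion_cli.py \
#       --prompt 'A watercolor of a lighthouse' --samples 2 --steps 30 --batch-size 2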
from __future__ import annotations

import io
import os
import time
from pathlib import Path

from modal import Image, Secret, Stub, method

stub = Stub("stable-diffusion-cli")

model_id = "runwayml/stable-diffusion-v1-5"
cache_path = "/vol/cache"
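
# `download_models` runs once while the container image is built (see
# `run_function` below), so the scheduler config and pipeline weights are
# cached under `cache_path` inside the image and are not downloaded at runtime.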
def download_models():
    import diffusers
    import torch

    hugging_face_token = os.environ["HUGGINGFACE_TOKEN"]

    # Download scheduler configuration. Experiment with different schedulers
    # to identify one that works best for your use-case.
    scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
        use_auth_token=hugging_face_token,
        cache_dir=cache_path,
    )
    scheduler.save_pretrained(cache_path, safe_serialization=True)

    # Downloads all other models.
    pipe = diffusers.StableDiffusionPipeline.from_pretrained(
        model_id,
        use_auth_token=hugging_face_token,
        revision="fp16",
        torch_dtype=torch.float16,
        cache_dir=cache_path,
    )
    pipe.save_pretrained(cache_path, safe_serialization=True)
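
# Container image: Debian slim with Python 3.10, pinned diffusers/transformers,
# a CUDA 11.7 build of torch, a pre-release xformers wheel, and the model
# weights baked in by running `download_models` with the Hugging Face secret.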
image = (
    Image.debian_slim(python_version="3.10")
    .pip_install(
        "accelerate",
        "diffusers[torch]>=0.15.1",
        "ftfy",
        "torchvision",
        "transformers~=4.25.1",
        "triton",
        "safetensors",
    )
    .pip_install(
        "torch==2.0.1+cu117",
        find_links="https://download.pytorch.org/whl/torch_stable.html",
    )
    .pip_install("xformers", pre=True)
    .run_function(
        download_models,
        secrets=[Secret.from_name("huggingface-secret")],
    )
)
stub.image = image
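
# `StableDiffusion` runs remotely on an A10G GPU. `__enter__` is Modal's
# container-lifecycle hook: it executes once per container start, so the
# pipeline is loaded from the baked-in cache a single time and then reused
# across `run_inference` calls.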
@stub.cls(gpu="A10G")
class StableDiffusion:
    def __enter__(self):
        import diffusers
        import torch

        torch.backends.cuda.matmul.allow_tf32 = True

        scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained(
            cache_path,
            subfolder="scheduler",
            solver_order=2,
            prediction_type="epsilon",
            thresholding=False,
            algorithm_type="dpmsolver++",
            solver_type="midpoint",
            denoise_final=True,  # important if steps are <= 10
            low_cpu_mem_usage=True,
            device_map="auto",
        )
        self.pipe = diffusers.StableDiffusionPipeline.from_pretrained(
            cache_path,
            scheduler=scheduler,
            low_cpu_mem_usage=True,
            device_map="auto",
        )
        self.pipe.enable_xformers_memory_efficient_attention()
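
    # Generates `batch_size` images for the prompt in a single pipeline call
    # and returns each one as PNG-encoded bytes.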
    @method()
    def run_inference(
        self, prompt: str, steps: int = 20, batch_size: int = 4
    ) -> list[bytes]:
        import torch

        with torch.inference_mode():
            with torch.autocast("cuda"):
                images = self.pipe(
                    [prompt] * batch_size,
                    num_inference_steps=steps,
                    guidance_scale=7.0,
                ).images

        # Convert to PNG bytes
        image_output = []
        for image in images:
            with io.BytesIO() as buf:
                image.save(buf, format="PNG")
                image_output.append(buf.getvalue())
        return image_output
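
# `modal run` invokes this local entrypoint on your machine; each loop
# iteration calls `run_inference` remotely and writes the returned PNGs to
# /tmp/stable-diffusion.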
@stub.local_entrypoint()
def entrypoint(
    prompt: str, samples: int = 5, steps: int = 10, batch_size: int = 1
):
    print(
        f"prompt => {prompt}, steps => {steps}, samples => {samples}, batch_size => {batch_size}"
    )

    dir = Path("/tmp/stable-diffusion")
    if not dir.exists():
        dir.mkdir(exist_ok=True, parents=True)

    sd = StableDiffusion()
    for i in range(samples):
        t0 = time.time()
        images = sd.run_inference.call(prompt, steps, batch_size)
        total_time = time.time() - t0
        print(
            f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image)."
        )
        for j, image_bytes in enumerate(images):
            output_path = dir / f"output_{j}_{i}.png"
            print(f"Saving it to {output_path}")
            with open(output_path, "wb") as f:
                f.write(image_bytes)