# convert-to-diffusers.py
# forked from kiri-art/docker-diffusers-api
import os
import requests
import subprocess
from utils import Storage
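
# Conversion is driven entirely by environment variables:
#   CHECKPOINT_URL        - original .ckpt checkpoint to convert (skip if unset)
#   CHECKPOINT_CONFIG_URL - optional original .yaml config for the checkpoint
#   MODEL_ID              - output directory for the converted diffusers model
#   _CONVERT_SPECIAL      - if set, use hafriedlander's stable_diffusion_2 script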
CHECKPOINT_CONFIG_URL = os.environ.get("CHECKPOINT_CONFIG_URL", None)
CHECKPOINT_URL = os.environ.get("CHECKPOINT_URL", None)
MODEL_ID = os.environ.get("MODEL_ID", None)
_CONVERT_SPECIAL = os.environ.get("_CONVERT_SPECIAL", None)

if __name__ == "__main__":
    # response = requests.get(
    #     "https://github.com/huggingface/diffusers/raw/main/scripts/convert_original_stable_diffusion_to_diffusers.py"
    # )
    # open("convert_original_stable_diffusion_to_diffusers.py", "wb").write(
    #     response.content
    # )

    if not CHECKPOINT_URL or CHECKPOINT_URL == "":
        quit()

    CHECKPOINT_DIR = "/root/.cache/checkpoints"
    fname = CHECKPOINT_DIR + "/" + CHECKPOINT_URL.split("/").pop()
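
    # If a .yaml config for the checkpoint was supplied, download it alongside
    # the checkpoint so it can be passed to the conversion script below.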
    if CHECKPOINT_CONFIG_URL:
        storage = Storage(CHECKPOINT_CONFIG_URL)
        configPath = (
            CHECKPOINT_DIR + "/" + CHECKPOINT_URL.split("/").pop() + "_config.yaml"
        )
        print(f"Downloading {CHECKPOINT_CONFIG_URL} to {configPath}...")
        storage.download_file(configPath)
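
    # _CONVERT_SPECIAL switches to an alternate conversion script from
    # hafriedlander's stable_diffusion_2 branch of diffusers.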
    specialSrc = "https://raw.githubusercontent.com/hafriedlander/diffusers/stable_diffusion_2/scripts/convert_original_stable_diffusion_to_diffusers.py"
    specialPath = CHECKPOINT_DIR + "/" + "convert_special.py"
    if _CONVERT_SPECIAL:
        storage = Storage(specialSrc)
        print(f"Downloading {specialSrc} to {specialPath}")
        storage.download_file(specialPath)

    scriptPath = (
        specialPath
        if _CONVERT_SPECIAL
        else "./diffusers/scripts/convert_original_stable_diffusion_to_diffusers.py"
    )
print("Converting " + fname + " to diffusers model " + MODEL_ID + "...", flush=True)
subprocess.run(
["pip", "install", "omegaconf", "pytorch_lightning", "tensorboard"], check=True
)
subprocess.run(["apt-get", "install", "-y", "wget"], check=True)
    subprocess.run(
        [
            "sed",
            "-i",
            # Force loading into CPU
            "s/torch.load(args.checkpoint_path)/torch.load(args.checkpoint_path, map_location=torch.device('cpu'))/",
            scriptPath,
        ]
    )
    # Sanity check that the sed substitution above actually landed. subprocess.run()
    # is documented to "Wait for command to complete, then return a CompletedProcess
    # instance", yet without this grep in between there seems to be a race condition
    # (perhaps a filesystem-level caching gotcha) where the conversion script runs
    # before sed's changes are visible; this needs further investigation.
    subprocess.run(
        [
            "grep",
            "torch.load",
            scriptPath,
        ],
        check=True,
    )
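
    # Finally, run the conversion script on the downloaded checkpoint, writing
    # the diffusers-format model to MODEL_ID.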
    args = [
        "python3",
        scriptPath,
        "--extract_ema",
        "--checkpoint_path",
        fname,
        "--dump_path",
        MODEL_ID,
    ]
    if CHECKPOINT_CONFIG_URL:
        args.append("--original_config_file")
        args.append(configPath)
    subprocess.run(
        args,
        check=True,
    )
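
# A sketch of how this script is typically driven (the URLs and model name below
# are hypothetical; only the environment variable names come from the code above):
#
#   CHECKPOINT_URL=https://example.com/my-model.ckpt \
#   CHECKPOINT_CONFIG_URL=https://example.com/my-model.yaml \
#   MODEL_ID=my-model \
#     python3 convert-to-diffusers.py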