forked from viktorfa/oot_diffusion
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
64 lines (52 loc) · 2.43 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
import argparse
from PIL import Image
from pathlib import Path
import time
from oot_diffusion import OOTDiffusionModel
DEFAULT_HG_ROOT = Path(os.getcwd()) / "ootd_models"
example_model_path = Path(__file__).parent / "oot_diffusion/assets/model_1.png"
example_garment_path = Path(__file__).parent / "oot_diffusion/assets/cloth_1.jpg"
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="oms diffusion")
parser.add_argument("--cloth_path", type=str, default=str(example_garment_path))
parser.add_argument("--person_path", type=str, default=str(example_model_path))
parser.add_argument("--hg_root", type=str, default=str(DEFAULT_HG_ROOT))
parser.add_argument("--cache_dir", type=str, required=False)
parser.add_argument("--output_path", type=str, default="./output_img")
args = parser.parse_args()
if args.person_path == str(example_model_path):
print(
f"Using example model image from {example_model_path}. Use --person_path to specify a different image."
)
if args.cloth_path == str(example_garment_path):
print(
f"Using example garment image from {example_garment_path}. Use --cloth_path to specify a different image."
)
if args.hg_root == str(DEFAULT_HG_ROOT):
print(
f"Using default hg_root to store models path {DEFAULT_HG_ROOT}. Use --hg_root to specify a different path."
)
model = OOTDiffusionModel(
hg_root=args.hg_root,
cache_dir=args.cache_dir,
)
start_time = time.perf_counter()
model.load_pipe()
end_time_load_model = time.perf_counter()
print(f"Model loaded in {end_time_load_model - start_time:.2f} seconds.")
model_image = Image.open(args.person_path)
garment_image = Image.open(args.cloth_path)
start_generate_time = time.perf_counter()
result_images, result_mask = model.generate(
model_path=model_image, cloth_path=garment_image
)
end_generate_time = time.perf_counter()
print(f"Generated in {end_generate_time - start_generate_time:.2f} seconds.")
os.makedirs(args.output_path, exist_ok=True)
with open(f"{args.output_path}/result_mask.png", "wb") as f:
result_mask.save(f, "PNG")
for i, result_image in enumerate(result_images):
with open(f"{args.output_path}/result_image_{i}.png", "wb") as f:
result_image.save(f, "PNG")
print(f"See {args.output_path}/ for the result images.")