-
Notifications
You must be signed in to change notification settings - Fork 577
/
Copy pathstable_diffusion_inpaint.py
117 lines (107 loc) · 3.72 KB
/
stable_diffusion_inpaint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import argparse
import functools
import glob
import os
import sys
from pathlib import Path

import numpy as np
import PIL.Image as Image
import torch
from diffusers import StableDiffusionInpaintPipeline

from utils import load_img_to_array, save_array_to_img
from utils.crop_for_replacing import recover_size, resize_and_pad
from utils.mask_processing import crop_for_filling_pre, crop_for_filling_post
# Hugging Face Hub id of the inpainting checkpoint used by both entry points.
_SD_INPAINT_MODEL_ID = "stabilityai/stable-diffusion-2-inpainting"


@functools.lru_cache(maxsize=1)
def _load_inpaint_pipeline(device="cuda"):
    """Return a Stable Diffusion 2 inpainting pipeline placed on *device*.

    Loading the checkpoint dominates the cost of an inpainting call, so the
    pipeline for the most recently requested device is memoized and reused.
    Requesting a different device evicts it and loads a fresh one. The cached
    pipeline stays resident in memory for the life of the process.
    """
    # Fork (and afterwards restore) the CPU RNG around loading so a cache miss
    # leaves the caller's seeded torch RNG state exactly as a cache hit would.
    with torch.random.fork_rng(devices=[]):
        return StableDiffusionInpaintPipeline.from_pretrained(
            _SD_INPAINT_MODEL_ID,
            torch_dtype=torch.float32,
        ).to(device)


def fill_img_with_sd(
    img: np.ndarray,
    mask: np.ndarray,
    text_prompt: str,
    device="cuda",
    step: int = 50,
):
    """Inpaint the masked region of *img*, guided by *text_prompt*.

    The image and mask are cropped by ``crop_for_filling_pre``. The crop is
    inpainted with Stable Diffusion, and ``crop_for_filling_post`` merges the
    result back into the full image.

    Args:
        img: Input image array. Presumably HxWx3 uint8, since it goes through
            ``Image.fromarray``; TODO confirm against ``crop_for_filling_pre``.
        mask: Mask array. Which value marks the area to fill is defined by
            ``crop_for_filling_pre`` / the diffusers pipeline, not visible here.
        text_prompt: Description of what to paint into the masked area.
        device: Torch device the pipeline runs on.
        step: Number of denoising steps (``num_inference_steps``). 50 equals
            the diffusers pipeline default, so omitting it keeps prior output.

    Returns:
        The filled full-size image produced by ``crop_for_filling_post``.
    """
    pipe = _load_inpaint_pipeline(device)
    img_crop, mask_crop = crop_for_filling_pre(img, mask)
    img_crop_filled = pipe(
        prompt=text_prompt,
        image=Image.fromarray(img_crop),
        mask_image=Image.fromarray(mask_crop),
        num_inference_steps=step,
    ).images[0]
    return crop_for_filling_post(img, mask, np.array(img_crop_filled))


def replace_img_with_sd(
    img: np.ndarray,
    mask: np.ndarray,
    text_prompt: str,
    step: int = 50,
    device="cuda",
):
    """Regenerate the area outside the masked object, guided by *text_prompt*.

    The image and mask are resized/padded by ``resize_and_pad``. The inverted
    mask is inpainted, the result is restored to the original size with
    ``recover_size``, and the original pixels under the mask are composited
    back on top.

    Args:
        img: Input image array of shape (height, width, channels). The rank is
            fixed by the ``img.shape`` unpacking below.
        mask: Object mask. Presumably 0/255 valued, given the ``255 - mask``
            inversion and ``/ 255`` scaling; TODO confirm.
        text_prompt: Description of the new surroundings.
        step: Number of denoising steps (``num_inference_steps``).
        device: Torch device the pipeline runs on.

    Returns:
        An array with ``img``'s height and width. Where the recovered mask is
        255 the pixels come from ``img``; elsewhere they come from the
        generated image. The result is floating point, because the blend
        divides the mask by 255, and is not cast back to ``img``'s dtype.
    """
    pipe = _load_inpaint_pipeline(device)
    img_padded, mask_padded, padding_factors = resize_and_pad(img, mask)
    # The pipeline repaints white mask pixels. Inverting the mask keeps the
    # object and regenerates everything around it.
    generated = pipe(
        prompt=text_prompt,
        image=Image.fromarray(img_padded),
        mask_image=Image.fromarray(255 - mask_padded),
        num_inference_steps=step,
    ).images[0]
    height, width, _ = img.shape
    img_resized, mask_resized = recover_size(
        np.array(generated), mask_padded, (height, width), padding_factors)
    # Scale the mask to [0, 1] and add a trailing channel axis so it
    # broadcasts over the color channels in the blend.
    alpha = np.expand_dims(mask_resized, -1) / 255
    return img_resized * (1 - alpha) + img * alpha
def setup_args(parser):
    """Register this script's command-line options on *parser*.

    Adds, in order:
    - four required string options: input image, text prompt, mask glob and
      output directory;
    - an optional integer ``--seed``;
    - a ``--deterministic`` on/off flag.

    Args:
        parser: The ``argparse.ArgumentParser`` to extend in place.
    """
    required_string_options = (
        ("--input_img", "Path to a single input img"),
        ("--text_prompt", "Text prompt"),
        ("--input_mask_glob", "Glob to input masks"),
        ("--output_dir", "Output path to the directory with results."),
    )
    for flag, description in required_string_options:
        parser.add_argument(flag, type=str, required=True, help=description)

    # Reproducibility controls; both are optional.
    parser.add_argument(
        "--seed",
        type=int,
        help="Specify seed for reproducibility.",
    )
    parser.add_argument(
        "--deterministic",
        action="store_true",
        help="Use deterministic algorithms for reproducibility.",
    )
if __name__ == "__main__":
    # Example usage:
    #   python stable_diffusion_inpaint.py \
    #       --input_img FA_demo/FA1_dog.png \
    #       --input_mask_glob "results/FA1_dog/mask*.png" \
    #       --text_prompt "a teddy bear on a bench" \
    #       --output_dir results
    parser = argparse.ArgumentParser()
    setup_args(parser)
    args = parser.parse_args(sys.argv[1:])
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if args.deterministic:
        # PyTorch's reproducibility notes require this cuBLAS workspace
        # setting for deterministic CUDA matmuls once deterministic
        # algorithms are enforced.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        torch.use_deterministic_algorithms(True)

    img_stem = Path(args.input_img).stem
    mask_ps = sorted(glob.glob(args.input_mask_glob))
    if not mask_ps:
        # Report a usage error instead of silently producing nothing, and do
        # it before creating the output directory.
        parser.error(
            f"no files matched --input_mask_glob {args.input_mask_glob!r}")

    out_dir = Path(args.output_dir) / img_stem
    out_dir.mkdir(parents=True, exist_ok=True)
    img = load_img_to_array(args.input_img)

    for mask_p in mask_ps:
        # Re-seed before every mask so each fill starts from the same global
        # RNG state, independent of how many masks were processed before it.
        if args.seed is not None:
            torch.manual_seed(args.seed)
        mask = load_img_to_array(mask_p)
        img_filled_p = out_dir / f"filled_with_{Path(mask_p).name}"
        img_filled = fill_img_with_sd(
            img, mask, args.text_prompt, device=device)
        save_array_to_img(img_filled, img_filled_p)