-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathgradio_controlnet_openpose.py
137 lines (124 loc) · 4.44 KB
/
gradio_controlnet_openpose.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import torch
from diffusers import UniPCMultistepScheduler, AutoencoderKL, ControlNetModel
from diffusers.pipelines import StableDiffusionControlNetPipeline
import gradio as gr
import argparse
from controlnet_aux import OpenposeDetector
from oms_diffusion.garment_adapter.garment_diffusion import ClothAdapter
# ---------------------------------------------------------------------------
# Command-line configuration and one-time model loading.
# Everything below runs at import time: the script parses its CLI and loads
# all models before the Gradio UI is built.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description="oms diffusion")
parser.add_argument(
    "--model_path",
    type=str,
    required=True,
    help="checkpoint passed to ClothAdapter",
)
parser.add_argument(
    "--pipe_path",
    type=str,
    default="SG161222/Realistic_Vision_V4.0_noVAE",
    help="base Stable Diffusion pipeline to load (Hub id or local path)",
)
# Previously hard-coded; the default keeps the original behaviour.
parser.add_argument(
    "--device",
    type=str,
    default="cuda",
    help="torch device for the OpenPose detector and ClothAdapter (default: cuda)",
)
args = parser.parse_args()
device = args.device

# OpenPose detector used by the "get pose" button to turn a photo into a
# pose skeleton image.
openpose_model = OpenposeDetector.from_pretrained("lllyasviel/ControlNet").to(device)

# ControlNet conditioned on OpenPose skeletons, loaded in half precision.
control_net_openpose = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16
)

# A separate VAE is supplied explicitly (the default base checkpoint is a
# "_noVAE" variant); cast to float16 to match the rest of the pipeline.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(dtype=torch.float16)

pipe = StableDiffusionControlNetPipeline.from_pretrained(
    args.pipe_path, vae=vae, controlnet=control_net_openpose, torch_dtype=torch.float16
)
# Swap in UniPC, reusing the pipeline's existing scheduler configuration.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# Garment adapter wrapping the pipeline; it receives the target device.
# NOTE(review): pipe itself is never moved to `device` here -- presumably
# ClothAdapter handles placement; confirm.
full_net = ClothAdapter(pipe, args.model_path, device)
def get_pose(image):
    """Run the OpenPose detector on *image* and return its output.

    Wired to the "get pose" button: the returned image replaces the
    contents of the pose image component.

    Args:
        image: the uploaded image (a PIL image, since the component uses
            ``type="pil"``), or ``None`` when nothing has been uploaded.

    Returns:
        The detector's pose rendering, or ``None`` when *image* is ``None``,
        leaving the component empty.
    """
    if image is None:
        # The button can be clicked before an image is uploaded; don't hand
        # None to the detector.
        return None
    return openpose_model(image)
def process(
    cloth_image,
    cloth_mask_image,
    prompt,
    a_prompt,
    n_prompt,
    num_samples,
    width,
    height,
    sample_steps,
    scale,
    seed,
    pose_image,
):
    """Generate images of the given cloth, conditioned on a pose image.

    Parameters arrive in the order of the UI's ``ips`` input list.

    The first eleven are forwarded to ``full_net.generate`` POSITIONALLY, in
    this order: cloth, mask, prompt, added prompt, sample count, negative
    prompt, seed, scale, steps, height, width. That is not this function's
    own parameter order, so the call below reorders them on purpose.

    NOTE(review): this order must match ``ClothAdapter.generate``'s
    signature; verify it there if that signature changes.

    *pose_image* is passed as the ``image=`` keyword, presumably the
    ControlNet conditioning image.

    Returns:
        ``(images, cloth_mask)``: the two values returned by
        ``full_net.generate``. They are shown in the result gallery and the
        "cloth mask" output image respectively.

    Raises:
        gr.Error: if no cloth image has been uploaded.
    """
    if cloth_image is None:
        # The cloth image is the one required input; fail with a readable
        # UI message instead of an error from inside the generation pipeline.
        raise gr.Error("Please upload a cloth image.")
    images, result_mask = full_net.generate(
        cloth_image,
        cloth_mask_image,
        prompt,
        a_prompt,
        num_samples,
        n_prompt,
        seed,
        scale,
        sample_steps,
        height,
        width,
        image=pose_image,
    )
    return images, result_mask
# Build the Gradio interface in three columns:
#   1. cloth inputs, the prompt and advanced sampling options;
#   2. the pose image with a button to extract a pose from a photo;
#   3. generated images plus the cloth mask returned by generation.
# .queue() enables Gradio's request queue for this app.
block = gr.Blocks().queue()
with block:
    with gr.Row():
        # A space after "##" is required for Markdown to render a heading;
        # "##You ..." renders as literal text.
        gr.Markdown(
            "## You can enlarge the image resolution to get a better face, "
            "but the cloth may lose control. We will release a "
            "high-resolution checkpoint soon."
        )
    with gr.Row():
        with gr.Column():
            cloth_image = gr.Image(label="cloth Image", type="pil")
            cloth_mask_image = gr.Image(
                label="cloth mask Image, if not supplied, will be produced by inner segment algorithm",
                type="pil",
            )
            prompt = gr.Textbox(label="Prompt", value="a photography of a model")
            run_button = gr.Button(value="Run")
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(
                    label="Images", minimum=1, maximum=12, value=4, step=1
                )
                height = gr.Slider(
                    label="Height", minimum=256, maximum=768, value=512, step=64
                )
                width = gr.Slider(
                    label="Width", minimum=192, maximum=576, value=384, step=64
                )
                sample_steps = gr.Slider(
                    label="Steps", minimum=1, maximum=100, value=20, step=1
                )
                scale = gr.Slider(
                    label="Guidance Scale", minimum=1, maximum=10.0, value=2.5, step=0.1
                )
                # Minimum of -1 presumably means "random seed" to
                # ClothAdapter.generate -- TODO confirm.
                seed = gr.Slider(
                    label="Seed", minimum=-1, maximum=2147483647, step=1, value=1234
                )
                a_prompt = gr.Textbox(
                    label="Added Prompt", value="best quality, high quality"
                )
                n_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="bare, monochrome, lowres, bad anatomy, worst quality, low quality",
                )
        with gr.Column():
            pose_image = gr.Image(label="pose Image", type="pil")
            pose_button = gr.Button(value="get pose")
        with gr.Column():
            result_gallery = gr.Gallery(
                label="Output", show_label=False, elem_id="gallery", min_width=384
            )
            cloth_seg_image = gr.Image(
                label="cloth mask", type="pil", width=192, height=256
            )

    # Order must match process()'s parameter order exactly.
    ips = [
        cloth_image,
        cloth_mask_image,
        prompt,
        a_prompt,
        n_prompt,
        num_samples,
        width,
        height,
        sample_steps,
        scale,
        seed,
        pose_image,
    ]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery, cloth_seg_image])
    # Replace the uploaded pose photo in place with the detector's output.
    pose_button.click(fn=get_pose, inputs=pose_image, outputs=pose_image)

# 0.0.0.0 binds every network interface, so the demo is reachable from other
# hosts on port 7860.
block.launch(server_name="0.0.0.0", server_port=7860)