generate_prior.py
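"""Generate class prior images with SDXL and, optionally, a foreground mask for
each image by detecting the object with OWL-ViT and segmenting it with SAM."""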
from diffusers import AutoencoderKL, StableDiffusionXLPipeline
import torch
import numpy as np
from PIL import Image
import os
import argparse
from tqdm import trange
from utils.models import load_sam, load_owl

# Local paths to the SDXL base model and the fp16-fix VAE checkpoints.
model_id = '/data/model/stable-diffusion-xl-base-1.0'
vae_id = '/data/model/sdxl-vae-fp16-fix'

# Prompt template; "{}" is filled with the target class name.
prompt = "a photo of a {}, simple background, full body view, award winning photography, highly detailed"
negative_prompt = "anime, cartoon, graphic, text, painting, crayon, graphite, abstract glitch, blurry"

# Detection confidence threshold for OWL-ViT.
owl_threshold = 0.4
def main(args):
    num_images = args.num_images

    # Load the fp16-fix VAE and the SDXL base pipeline.
    vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16)
    pipe = StableDiffusionXLPipeline.from_pretrained(
        model_id,
        vae=vae,
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    )
    pipe = pipe.to("cuda")

    gen_prompt = prompt.format(args.gen_class)

    if args.out_dir is None:
        args.out_dir = f"dataset/prior/{args.gen_class.replace(' ', '_')}"
    os.makedirs(args.out_dir, exist_ok=True)

    if args.gen_mask:
        # Load SAM for segmentation and OWL-ViT for open-vocabulary detection.
        predictor = load_sam("vit_h", 'cuda')
        processor, model = load_owl('cuda')
        if args.owl_query is None:
            args.owl_query = args.gen_class
        if args.mask_dir is None:
            args.mask_dir = args.out_dir
        os.makedirs(args.mask_dir, exist_ok=True)

    for i in trange(num_images):
        if args.gen_mask:
            # Retry up to 10 times until exactly one object is detected.
            for TRY in range(10):
                img = pipe(prompt=gen_prompt, negative_prompt=negative_prompt).images[0]

                # Detect the target class with OWL-ViT to get a bounding box.
                inputs = processor(text=[args.owl_query], images=img, return_tensors="pt").to('cuda')
                with torch.no_grad():
                    outputs = model(**inputs)
                target_sizes = torch.Tensor([img.size[::-1]])
                results = processor.post_process_object_detection(
                    outputs=outputs,
                    target_sizes=target_sizes,
                    threshold=owl_threshold,
                )
                boxes = np.array(results[0]["boxes"].cpu().detach())
                if len(boxes) != 1:
                    print("did not detect exactly one object, regenerating")
                    continue

                # Segment the detected object with SAM, prompted by the box.
                input_box = boxes[0]
                predictor.set_image(np.array(img))
                masks, _, _ = predictor.predict(
                    point_coords=None,
                    point_labels=None,
                    box=input_box[None, :],
                    multimask_output=False,
                )
                mask = Image.fromarray(masks[0].astype(np.uint8) * 255)
                mask.save(f"{args.mask_dir}/mask_{i}.png")
                break
            else:
                raise ValueError("Failed to generate a mask after 10 attempts, try another class.")
        else:
            img = pipe(prompt=gen_prompt, negative_prompt=negative_prompt).images[0]
        img.save(f"{args.out_dir}/img_{i}.png")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate class prior images (and optional masks) with SDXL.")
    parser.add_argument("--gen_class", type=str, required=True)
    parser.add_argument("--out_dir", type=str, default=None)
    parser.add_argument("--num_images", type=int, default=50)
    parser.add_argument("--gen_mask", action="store_true")
    parser.add_argument("--owl_query", type=str, default=None)
    parser.add_argument("--mask_dir", type=str, default=None)
    args = parser.parse_known_args()[0]
    main(args)
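
# Example usage (a sketch, assuming the SDXL checkpoints above exist at the given local paths
# and the class name "cat" is only illustrative):
#   python generate_prior.py --gen_class "cat" --num_images 50
#   python generate_prior.py --gen_class "cat" --gen_mask --owl_query "cat" --mask_dir dataset/prior/cat_masks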