-
Notifications
You must be signed in to change notification settings - Fork 0
/
gen_controlnet_image_annotation.py
89 lines (73 loc) · 2.68 KB
/
gen_controlnet_image_annotation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
from controlnet_aux.processor import Processor
from tqdm import tqdm
from multiprocessing import Pool
from glob import glob
from PIL import Image
import argparse
def parse_args() -> argparse.Namespace:
    """Parse command-line options and ensure the output directory exists.

    Options:
        --dataset_path: root directory that is searched for input images
            (default ``"."``).
        --num_processes: number of worker processes used for annotation
            (default ``1``).
        --check_images: when set, images are validated before annotation.
        --save_path: output root. When omitted it defaults to
            ``<dataset_path>/../<processor_id>``.
        --processor_id: identifier handed to ``Processor``; required.
        --device: parsed for the caller's convenience.
            NOTE(review): nothing in the visible code reads ``args.device``
            -- confirm whether it is still needed.

    Returns:
        The parsed namespace, with ``save_path`` always populated.

    Raises:
        SystemExit: via ``parser.error`` when ``--processor_id`` is missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default=".")
    parser.add_argument("--num_processes", type=int, default=1)
    parser.add_argument("--check_images", action="store_true")
    parser.add_argument("--save_path", type=str, default=None)
    parser.add_argument("--processor_id", type=str, default=None)
    parser.add_argument("--device", type=str, default="cuda")
    args = parser.parse_args()

    # Fail fast with a readable message instead of the TypeError that
    # os.path.join(..., None) would raise further down.
    if args.processor_id is None:
        parser.error("--processor_id is required")

    if args.save_path is None:
        args.save_path = os.path.join(args.dataset_path, "..", args.processor_id)
    # exist_ok makes this safe whether the directory was user-supplied or derived.
    os.makedirs(args.save_path, exist_ok=True)
    return args
# Lower-case file extensions (without the dot) that are treated as images.
_IMAGE_EXTENSIONS = frozenset({"png", "jpg", "jpeg", "gif", "webp", "bmp"})


def is_image(image_path):
    """Return True if ``image_path`` has a recognised image file extension.

    The decision is based on the file name only; the file is never opened.
    Matching is case-insensitive, so ``photo.JPG`` counts as an image, and a
    path without any extension (e.g. a file literally named ``png``) does not.

    Args:
        image_path: path string to classify.

    Returns:
        ``True`` when the extension is one of png, jpg, jpeg, gif, webp, bmp;
        ``False`` otherwise.
    """
    # splitext only looks at the final path component, so a dot in a parent
    # directory name is never mistaken for an extension.
    extension = os.path.splitext(image_path)[1]
    return extension[1:].lower() in _IMAGE_EXTENSIONS
def init_subprocess(processor_id):
    """Pool initializer: build the annotator for the current process.

    Constructs ``Processor(processor_id)`` and binds it to the module-level
    name ``processor``, which ``get_annotation`` later calls.

    Args:
        processor_id: identifier forwarded unchanged to ``Processor``.
    """
    global processor
    annotator = Processor(processor_id)
    processor = annotator
def get_annotation(parameters):
    """Annotate one image with the module-level ``processor`` and save it.

    Args:
        parameters: ``(image_path, save_path)`` tuple. ``image_path`` is the
            input file; ``save_path`` is where the result is written.

    Returns:
        None. The annotated image is written to ``save_path``.

    Raises:
        Whatever ``Image.open``, the processor call, or ``save`` raise;
        errors are intentionally not swallowed here.
    """
    image_path, save_path = parameters

    # A bare file name has an empty dirname, and os.makedirs("") raises.
    target_dir = os.path.dirname(save_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    # The context manager closes the underlying file once processing is done,
    # so long-lived pool workers do not accumulate open file handles.
    with Image.open(image_path) as image:
        # `processor` is the module-level object set by init_subprocess.
        result = processor(image, to_pil=True)
        result.save(save_path)
def is_valid_image(image_path):
    """Return True if ``image_path`` can be opened and converted to RGBA.

    Defined at module level (not under the ``__main__`` guard) so that
    ``multiprocessing`` workers can look it up by name when it is passed to
    ``Pool.imap``.

    Args:
        image_path: path of the image file to probe.

    Returns:
        ``True`` when opening and converting succeed, ``False`` otherwise.
    """
    try:
        with Image.open(image_path) as image:
            image.convert("RGBA")
    except Exception:
        # Deliberately broad: any failure while opening or decoding means
        # the file should be skipped, whatever the concrete error type.
        print(f"Error opening {image_path}")
        return False
    return True


def main():
    """Find images under the dataset path and write an annotation for each.

    Steps:
        1. Recursively list the dataset directory and keep regular files
           with an image extension.
        2. With ``--check_images``, drop files that fail ``is_valid_image``.
        3. Mirror each image's path relative to the dataset root under
           ``args.save_path`` and run ``get_annotation`` in a worker pool.
    """
    args = parse_args()

    candidates = glob(f"{args.dataset_path}/**", recursive=True)
    # glob also yields directories; keep only real files that look like images.
    image_paths = [
        path for path in candidates if is_image(path) and os.path.isfile(path)
    ]

    if args.check_images:
        print("check images")
        with Pool() as p:
            results = list(
                tqdm(
                    p.imap(is_valid_image, image_paths),
                    total=len(image_paths),
                )
            )
        image_paths = [path for path, ok in zip(image_paths, results) if ok]

    print(f"num images: {len(image_paths)}")

    # Build the work list only after filtering, so rejected images are
    # really skipped and the progress total matches the work done.
    input_parameters = [
        (
            image_path,
            os.path.join(
                args.save_path, os.path.relpath(image_path, args.dataset_path)
            ),
        )
        for image_path in image_paths
    ]

    print("gen image_annotation:")
    with Pool(
        processes=args.num_processes,
        initializer=init_subprocess,
        initargs=(args.processor_id,),
    ) as p:
        # list() drains the lazy imap iterator so every task runs before
        # the pool is torn down.
        list(
            tqdm(
                p.imap(get_annotation, input_parameters),
                total=len(input_parameters),
            )
        )


if __name__ == "__main__":
    main()