-
Notifications
You must be signed in to change notification settings - Fork 0
/
gen_depth_multi_processes.py
101 lines (85 loc) · 2.96 KB
/
gen_depth_multi_processes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import numpy as np
import argparse
from PIL import Image
import os
import huggingface_hub
import pandas as pd
import argparse
from glob import glob
from multiprocessing import Pool, current_process
from tqdm import tqdm
import json
from diffusers import MarigoldDepthPipeline
import torch
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the depth-generation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            callers that pass nothing behave exactly as before.

    Returns:
        The parsed namespace. ``rel_path`` falls back to ``dataset_path``
        when it was not given on the command line.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default="")
    parser.add_argument("--num_processes", type=int, default=1)
    parser.add_argument("--save_path", type=str, default=None)
    parser.add_argument("--rel_path", type=str, default=None)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--num_gpus", type=int, default=1)
    parser.add_argument(
        "--model_path", type=str, default="prs-eth/marigold-depth-lcm-v1-0"
    )
    parser.add_argument("--check_images", default=False, action="store_true")
    args = parser.parse_args(argv)
    # An omitted --rel_path defaults to the dataset root.
    if args.rel_path is None:
        args.rel_path = args.dataset_path
    return args
def is_image(image_path: str) -> bool:
    """Return True when ``image_path`` has a recognised image extension.

    Only the path string is inspected; the file itself is never opened.
    The comparison is case-insensitive, and a path without any extension
    (for example a file or directory literally named ``png``) is rejected.

    Args:
        image_path: Path of the candidate file.

    Returns:
        True if the extension is one of the supported image types.
    """
    image_types = {"png", "jpg", "jpeg", "gif", "webp", "bmp"}
    # splitext only looks at the final path component, so dots in parent
    # directory names cannot be mistaken for an extension separator.
    extension = os.path.splitext(image_path)[1]
    return extension[1:].lower() in image_types
def is_valid_image(image_path):
    """Return True when the file can be opened and converted to RGBA.

    Any failure while opening or converting is reported on stdout and
    turned into a ``False`` result instead of being raised.

    Args:
        image_path: Path of the image file to check.

    Returns:
        True if the image opened and converted without error, else False.
    """
    try:
        # The context manager closes the file handle even when the
        # conversion raises.
        with Image.open(image_path) as image:
            image.convert("RGBA")
    except Exception:
        # Deliberately broad: this is a best-effort validity probe.
        print(f"Error opening {image_path}")
        return False
    return True
def init_subprocess(model_path, num_gpus):
    """Pool initializer: load the depth pipeline once in this worker.

    The pipeline is stored in the module-level global ``pipe`` so that
    ``process`` can reuse it for every image handled by this worker.

    Args:
        model_path: Model identifier or path passed to
            ``MarigoldDepthPipeline.from_pretrained``.
        num_gpus: Number of CUDA devices to spread the workers over.
    """
    global pipe
    # NOTE(review): `_identity` is a private multiprocessing attribute;
    # this assumes it holds a 1-based worker counter inside Pool workers.
    worker_index = current_process()._identity[0] - 1
    # Assign workers to GPUs round-robin.
    device = f"cuda:{worker_index % num_gpus}"
    pipe = MarigoldDepthPipeline.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
    ).to(device)
    pipe.set_progress_bar_config(disable=True)
def process(image_path):
    """Generate and save the depth visualisation for one image.

    The output path is the input path with ``"imgs_unzip"`` replaced by
    ``"depths"``. Nothing is done when that output path already exists;
    this also covers inputs whose path does not contain ``"imgs_unzip"``,
    because then the output path equals the input path.

    Uses the module-level ``pipe`` created by ``init_subprocess``.

    Args:
        image_path: Path of the source image.

    Returns:
        Always None. Unreadable images are reported on stdout and skipped.
    """
    global pipe
    save_path = image_path.replace("imgs_unzip", "depths")
    if os.path.exists(save_path):
        return None
    try:
        # Close the source file as soon as the RGB copy has been made.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
    except Exception as e:
        # Best effort: one bad file must not abort the whole run, but the
        # failure is reported instead of being dropped silently.
        print(f"Error opening {image_path}: {e}")
        return None
    result = pipe(image, num_inference_steps=4).prediction
    result = pipe.image_processor.visualize_depth(result)[0]
    output_dir = os.path.dirname(save_path)
    # os.makedirs("") raises, so skip it for paths without a directory part.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    result.save(save_path)
    return None
if __name__ == "__main__":
    args = parse_args()
    # Recursively list everything below the dataset root, then keep only
    # image files whose depth output does not exist yet.
    candidates = glob(f"{args.dataset_path}/**", recursive=True)
    image_paths = [
        path
        for path in candidates
        if is_image(path)
        and not os.path.exists(path.replace("imgs_unzip", "depths"))
    ]
    # Optional, slower pre-check that actually opens every image.
    if args.check_images:
        image_paths = [path for path in image_paths if is_valid_image(path)]
    print(f"num images:{len(image_paths)}")
    print("gen depths")
    # Each worker loads its own pipeline through init_subprocess.
    with Pool(
        processes=args.num_processes,
        initializer=init_subprocess,
        initargs=(args.model_path, args.num_gpus),
    ) as pool:
        # process() returns nothing useful; iterate only to drive the pool
        # and advance the progress bar.
        for _ in tqdm(pool.imap(process, image_paths), total=len(image_paths)):
            pass