-
Notifications
You must be signed in to change notification settings - Fork 6
/
run_one_image.py
105 lines (92 loc) · 3.54 KB
/
run_one_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
r"""
Run the model on a single image and its corresponding trimap.

With the default inputs, run:
    python run_one_image.py \
        --config-dir configs/ViTS_1024.py \
        --checkpoint-dir path/to/checkpoint
The predicted alpha matte is written to ``./demo/result.png``.

To run on your own image, run:
    python run_one_image.py \
        --config-dir <config file under ./configs/ (e.g. a ViTS or ViTB config) matching your checkpoint> \
        --checkpoint-dir <path to your checkpoint> \
        --image-dir <path to your image> \
        --trimap-dir <path to your trimap> \
        --output-dir <path of the output image file, e.g. out/result.png> \
        --device <your device, e.g. cuda or cpu> \
        --sample-strategy <sampler and step count, e.g. ddim10>

Note: pass the config with ``--config-dir``; the ``--config-file`` option
inherited from detectron2's argument parser is not read by this script.
"""
import cv2
from PIL import Image
from re import findall
from os.path import join as opj
from torchvision.transforms import functional as F
from detectron2.engine import default_argument_parser
from detectron2.config import LazyConfig, instantiate
from detectron2.checkpoint import DetectionCheckpointer
def infer_one_image(model, input, save_dir=None):
    """
    Infer the alpha matte of one image and optionally write it to disk.

    Input:
        model: the initialized matting model; it is called as ``model(input)``.
        input: the sample passed straight to the model (``get_data`` builds it
            as a dict with ``'image'`` and ``'trimap'`` tensors).
        save_dir: path of the output image FILE (despite the name, not a
            directory), e.g. ``demo/result.png``. If None, nothing is written.
    Output:
        The raw model output (the predicted alpha matte).
    Raises:
        OSError: if ``cv2.imwrite`` reports that the file could not be written.
    """
    import torch

    # Pure inference: no autograd graph is needed, which saves memory.
    with torch.no_grad():
        output = model(input)

    if save_dir is not None:
        # NOTE(review): assumes ``output`` is a single-channel array that
        # cv2 accepts (e.g. uint8 HxW) -- confirm against the model's forward.
        rgb = cv2.cvtColor(output, cv2.COLOR_GRAY2RGB)
        # opj with a single argument just normalises a path-like to str.
        # cv2.imwrite reports failure (bad extension, missing parent
        # directory, ...) by returning False rather than raising.
        if not cv2.imwrite(opj(save_dir), rgb):
            raise OSError(f"Failed to write alpha matte to {save_dir!r}")
    return output
def init_model(model, checkpoint, device, sample_strategy):
    """
    Build the matting model from a config file and load its weights.

    Input:
        model: path of the LazyConfig file to load (e.g.
            ``configs/ViTS_1024.py``). Despite the name, this is a config
            path, not a model object.
        checkpoint: checkpoint path handed to ``DetectionCheckpointer.load``.
        device: device the model is moved to, e.g. ``'cuda'`` or ``'cpu'``.
        sample_strategy: sampler spec such as ``'ddim10'``. If it contains
            ``'ddim'`` the config's ``use_ddim`` flag is set to True (False
            otherwise), and its first run of digits sets
            ``cfg.diffusion.steps``. Pass None to keep the config's own
            sampling settings.
    Output:
        The instantiated model, moved to ``device``, switched to eval mode,
        with the checkpoint weights loaded.
    Raises:
        ValueError: if ``sample_strategy`` contains no step count.
    """
    cfg = LazyConfig.load(model)
    if sample_strategy is not None:
        step_counts = findall(r"\d+", sample_strategy)
        if not step_counts:
            # Fail with a clear message instead of a bare IndexError.
            raise ValueError(
                "sample_strategy must contain a step count, e.g. 'ddim10'; "
                f"got {sample_strategy!r}"
            )
        cfg.difmatte.args["use_ddim"] = "ddim" in sample_strategy
        cfg.diffusion.steps = int(step_counts[0])

    # Instantiate the network and the diffusion process separately, then
    # inject both into the top-level difmatte config before instantiating it.
    cfg.difmatte.model = instantiate(cfg.model)
    cfg.difmatte.diffusion = instantiate(cfg.diffusion)
    difmatte = instantiate(cfg.difmatte)

    difmatte.to(device)
    difmatte.eval()
    DetectionCheckpointer(difmatte).load(checkpoint)
    return difmatte
def get_data(image_dir, trimap_dir):
    """
    Load one image/trimap pair as batched tensors.

    Input:
        image_dir: path of the image file (converted to RGB).
        trimap_dir: path of the trimap file (converted to grayscale).
    Output:
        dict with ``'image'`` (shape (1, 3, H, W)) and ``'trimap'``
        (shape (1, 1, H, W)) float tensors. The trimap is snapped to the
        three values 0.0 (background), 0.5 (unknown) and 1.0 (foreground).
    """
    rgb = F.to_tensor(Image.open(image_dir).convert('RGB'))[None]
    tri = F.to_tensor(Image.open(trimap_dir).convert('L'))[None]

    # Classify every pixel from its original value first, then overwrite;
    # the three ranges are disjoint, so this matches sequential thresholding.
    is_fg = tri > 0.9
    is_unknown = (tri >= 0.1) & (tri <= 0.9)
    is_bg = tri < 0.1
    tri[is_fg] = 1.00000
    tri[is_unknown] = 0.50000
    tri[is_bg] = 0.00000

    return dict(image=rgb, trimap=tri)
if __name__ == '__main__':
    import os

    # detectron2's default parser already defines options such as
    # --config-file and --num-gpus; this script reads only the ones below.
    parser = default_argument_parser()
    parser.add_argument('--config-dir', type=str, default='configs/ViTS_1024.py',
                        help='path of the LazyConfig file describing the model')
    parser.add_argument('--checkpoint-dir', type=str, required=True,
                        help='path of the model checkpoint to load')
    parser.add_argument('--image-dir', type=str, default='demo/retriever_rgb.png',
                        help='path of the input image')
    parser.add_argument('--trimap-dir', type=str, default='demo/retriever_trimap.png',
                        help='path of the input trimap')
    parser.add_argument('--output-dir', type=str, default='demo/result.png',
                        help='path of the output alpha-matte image FILE')
    parser.add_argument('--device', type=str, default='cuda',
                        help="device to run on, e.g. 'cuda' or 'cpu'")
    parser.add_argument('--sample-strategy', type=str, default="ddim10",
                        help="sampler and step count; include 'ddim' to use "
                             "DDIM sampling, e.g. 'ddim10'")
    args = parser.parse_args()

    # cv2.imwrite returns False instead of raising when the parent directory
    # does not exist, so make sure it is there before doing any work.
    output_parent = os.path.dirname(args.output_dir)
    if output_parent:
        os.makedirs(output_parent, exist_ok=True)

    inputs = get_data(args.image_dir, args.trimap_dir)
    print('Initializing model...Please wait...')
    model = init_model(args.config_dir, args.checkpoint_dir, args.device, args.sample_strategy)
    print('Model initialized. Start inferencing...')
    infer_one_image(model, inputs, args.output_dir)
    print('Inferencing finished.')