-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
96 lines (76 loc) · 2.86 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""Interactive OpenCV UI for point-prompted segmentation with Segment Anything (SAM).

Left-click an object in the "image" window to segment it: the predicted mask
is saved to ``mask.png`` and its outline is drawn on the displayed image.
Press 'r' to clear the drawings and 'c' to quit.

Usage:
    python app.py -i path/to/image.jpg [--checkpoint CKPT] [--model-type TYPE] [--device DEV]

Author : HyeonWoo Jeong
References
- https://stackoverflow.com/questions/49799057/how-to-draw-a-point-in-an-image-using-given-co-ordinate-with-python-opencv
"""
import argparse
import cv2
import numpy as np

######################################################################
# Argument parsing: image path plus optional SAM model settings
######################################################################
parser = argparse.ArgumentParser(
    description="Click a point in an image to segment it with Segment Anything (SAM)."
)
parser.add_argument("-i", "--image", required=True, help="Path to the image")
parser.add_argument(
    "--checkpoint",
    default="./models/sam_vit_h_4b8939.pth",
    help="Path to the SAM checkpoint file",
)
parser.add_argument(
    "--model-type",
    default="vit_h",
    help="SAM model type (key of sam_model_registry); must match the checkpoint",
)
parser.add_argument(
    "--device",
    default="cpu",
    help="Device to run SAM on, e.g. 'cpu' or 'cuda'",
)
args = parser.parse_args()

# Most recently clicked point, stored as a one-element list [(x, y)];
# written by the mouse callback.
select_point = []

######################################################################
# SAM (Segment Anything) setup
######################################################################
from segment_anything import SamPredictor, sam_model_registry

image = cv2.imread(args.image)
if image is None:
    # cv2.imread reports an unreadable/missing file by returning None rather
    # than raising, so fail here with a clear message instead of crashing on
    # image.copy() below.
    parser.error(f"could not read image: {args.image}")
# Pristine copy used to undo all drawings when 'r' is pressed.
clone = image.copy()

sam_checkpoint = args.checkpoint
model_type = args.model_type
device = args.device
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
# cv2.imread yields BGR channel order, whereas SamPredictor.set_image assumes
# RGB unless told otherwise; declare the real format so SAM sees true colors.
predictor.set_image(image, image_format="BGR")
# Random color: three components drawn uniformly from [0, 1) followed by a
# fixed 4th component of 0.6 (presumably an alpha/opacity value).
# NOTE(review): `color` is not referenced anywhere else in this file; it looks
# like a leftover from a mask-overlay helper — confirm before removing.
color = np.append(np.random.random(3), 0.6)
def select_point_callback(event, x, y, flags, param):
    """OpenCV mouse callback: segment the object under a left click with SAM.

    On ``cv2.EVENT_LBUTTONDOWN`` the clicked pixel is sent to the global
    ``predictor`` as a single point prompt with label 1. The predicted mask is
    written to ``mask.png`` (0 outside the mask, 255 inside). The click point
    and the mask outline are then drawn onto the global ``image`` and shown in
    the "image" window. All other mouse events are ignored.

    Args:
        event: OpenCV mouse event code.
        x, y: Cursor position in image pixel coordinates.
        flags, param: Required by the OpenCV callback signature; unused.
    """
    global select_point
    if event != cv2.EVENT_LBUTTONDOWN:
        return

    select_point = [(x, y)]
    # Label 1 marks a foreground point in SamPredictor.predict's convention.
    masks, _scores, _logits = predictor.predict(
        point_coords=np.array([[x, y]]),
        point_labels=np.array([1]),
        multimask_output=False,
    )
    # multimask_output=False yields a single mask; collapse it to a 2-D
    # (h, w) array so it works directly with imwrite and findContours.
    h, w = masks.shape[-2:]
    mask = (masks.reshape(h, w) * 255).astype(np.uint8)
    cv2.imwrite("mask.png", mask)

    # Mark the clicked point (red in OpenCV's BGR order).
    cv2.circle(image, select_point[0], 5, (0, 0, 255), 5)
    # findContours returns (contours, hierarchy) in OpenCV 4 but
    # (image, contours, hierarchy) in OpenCV 3; index -2 is the contours in both.
    contours = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]
    cv2.drawContours(image, contours, -1, (60, 200, 200), 5)
    cv2.imshow("image", image)
cv2.namedWindow("image")
cv2.setMouseCallback("image", select_point_callback)

# Event loop: redraw the (possibly annotated) image and poll the keyboard.
#   'r'       -> discard all drawn points/contours by restoring the original
#   'c' / 'q' -> leave the loop and exit
while True:
    cv2.imshow("image", image)
    key = cv2.waitKey(1) & 0xFF
    if key == ord("r"):
        # Rebinding the module-level `image` is enough: the mouse callback
        # looks the name up at call time, so later clicks draw on this copy.
        image = clone.copy()
        select_point = []
    elif key in (ord("c"), ord("q")):
        break

# close all open windows
cv2.destroyAllWindows()