-
Notifications
You must be signed in to change notification settings - Fork 0
/
prediction_s900.py
111 lines (93 loc) · 3.32 KB
/
prediction_s900.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
from torchvision import transforms
from PIL import Image
import cv2
import warnings
import argparse
import pyttsx3
import config
from utils import load_model, generate_backbone_name
from cvzone.HandTrackingModule import HandDetector
# Suppress every warning category for the whole process so library
# warnings do not clutter the console during the real-time loop.
warnings.filterwarnings("ignore")

# Run inference on the GPU when CUDA is available, otherwise on the CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
def clamp_crop_box(x, y, w, h, frame_height, frame_width):
    """Return the padded hand crop as ``(top, bottom, left, right)`` slice bounds.

    The bounding box ``(x, y, w, h)`` is padded vertically by ``round(w / 2)``
    and horizontally by ``round(h / 2)`` on each side, then clipped to the
    frame. Clipping matters because a negative slice start would make NumPy
    count from the end of the axis and produce an empty (or wrong) crop
    whenever the hand is close to the top or left edge of the frame.

    Args:
        x: Left edge of the bounding box in pixels.
        y: Top edge of the bounding box in pixels.
        w: Width of the bounding box in pixels.
        h: Height of the bounding box in pixels.
        frame_height: Number of rows in the frame.
        frame_width: Number of columns in the frame.

    Returns:
        A tuple ``(top, bottom, left, right)`` with ``0 <= top <= bottom`` and
        ``0 <= left <= right``, suitable for ``frame[top:bottom, left:right]``.
    """
    pad_rows = round(w / 2)
    pad_cols = round(h / 2)
    top = max(y - pad_rows, 0)
    left = max(x - pad_cols, 0)
    # Never let the end bound fall below the start bound (or below zero),
    # so a degenerate box yields an empty slice instead of a wrapped one.
    bottom = max(min(y + pad_rows + h, frame_height), top)
    right = max(min(x + w + pad_cols, frame_width), left)
    return top, bottom, left, right


def main():
    """Run real-time hand-sign prediction on the default webcam.

    Parses ``--weight_path`` / ``--backbone`` from the command line, loads the
    model, then loops over camera frames: detects one hand, classifies the
    padded crop around it, overlays the predicted label, and builds up a word.
    Predicting the label 'Y' commits the most recent non-confirm prediction to
    the word; predicting 'O' prints and speaks the word and clears it.
    Pressing ESC (key code 27) ends the loop.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--weight_path",
        "-w",
        type=str,
        default=config.S900_WEIGHT,
        required=False,
        help="path of the weight that is supposed to be loaded for prediction",
    )
    parser.add_argument(
        "--backbone",
        "-b",
        type=str,
        default=config.REAL_TIME_BACKBONE,
        required=False,
        help="backbone of the model architecture that is to be used for prediction",
    )
    args = parser.parse_args()
    backbone = generate_backbone_name(args.backbone)

    # Per-channel normalisation statistics applied to every crop.
    mean = [0.5172, 0.4853, 0.4789]
    std = [0.2236, 0.2257, 0.2162]
    img_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    img_size = config.IMG_SIZE
    labels = config.ALPHABETS
    detector = HandDetector(maxHands=1)
    offset = 20  # margin (pixels) around the bounding box when drawing
    word = ''
    # Most recent prediction whose index is not 24; only the latest one is
    # ever consumed, so a single slot replaces an ever-growing list.
    pending_index = None

    model = load_model(backbone, args.weight_path, device)
    model.eval()

    cap = cv2.VideoCapture(0)
    try:
        while True:
            success, frame = cap.read()
            if not success or frame is None:
                # Camera missing, busy, or disconnected: stop cleanly instead
                # of crashing on a None frame.
                print("failed to read a frame from the camera")
                break

            # Copy before detection so the overlay image is taken before
            # findHands gets a chance to touch the frame.
            img_output = frame.copy()
            hands, annotated = detector.findHands(frame)

            crop = None
            if hands:
                x, y, w, h = hands[0]['bbox']
                top, bottom, left, right = clamp_crop_box(
                    x, y, w, h, annotated.shape[0], annotated.shape[1]
                )
                crop = annotated[top:bottom, left:right]

            # Skip classification when there is no hand or the crop is empty;
            # the frame is still displayed below.
            if crop is not None and crop.size:
                resized = cv2.resize(crop, (img_size, img_size))
                cv2.imshow("ImageCrop_line", crop)

                # NOTE(review): OpenCV frames are BGR and Image.fromarray does
                # not reorder channels - confirm this matches how the model
                # was trained before changing it.
                batch = img_transform(Image.fromarray(resized)).unsqueeze(0)
                with torch.no_grad():  # inference only, no autograd graph
                    output = model(batch.to(device))
                _, predicted = torch.max(output, 1)
                index = predicted.item()
                label = str(labels[index])

                # Filled label background, label text, then the hand box.
                cv2.rectangle(imgOutput := img_output, (x - offset, y - offset - 50),
                              (x - offset + 90, y - offset), (255, 0, 255), cv2.FILLED) if False else \
                    cv2.rectangle(img_output, (x - offset, y - offset - 50),
                                  (x - offset + 90, y - offset), (255, 0, 255), cv2.FILLED)
                cv2.putText(img_output, label, (x, y - 26),
                            cv2.FONT_HERSHEY_COMPLEX, 1.7, (255, 255, 255), 2)
                cv2.rectangle(img_output, (x - offset, y - offset),
                              (x + w + offset, y + h + offset), (255, 0, 255), 4)

                if label == 'Y':
                    print('letter recorded')
                    if pending_index is not None:
                        word = word + str(labels[pending_index])
                    pending_index = None
                if label == 'O':
                    print(word)
                    pyttsx3.speak(word)
                    word = ''
                # Index 24 is presumably the 'Y' confirm sign and so is never
                # stored as a candidate letter - TODO confirm against
                # config.ALPHABETS.
                if index != 24:
                    pending_index = index

            cv2.imshow("Image", img_output)
            # Poll the keyboard once per frame; a second waitKey elsewhere
            # could swallow the ESC press before it is checked here.
            if (cv2.waitKey(1) & 0xFF) == 27:
                break
    finally:
        # Always free the camera and close windows, including on errors.
        cap.release()
        cv2.destroyAllWindows()


if __name__ == "__main__":
    main()