inference_lstm.py
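"""Real-time action recognition with MediaPipe Pose and a pre-trained LSTM.

Reads frames from the webcam, flattens the 33 MediaPipe Pose landmarks of
each frame into a feature vector, buffers n_time_steps frames into a
sequence, and classifies each sequence with the Keras model
"swing_clap_nothing.h5" as SWING HAND, CLAPPING HAND, or NOTHING.
Prediction runs on a background thread so the video loop stays responsive.
"""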
import cv2
import mediapipe as mp
import numpy as np
import threading
import tensorflow as tf
label = "Warmup...."   # text drawn on the frame until detection starts
n_time_steps = 10      # number of frames per sequence fed to the LSTM
lm_list = []           # buffer of per-frame landmark vectors
mpPose = mp.solutions.pose
pose = mpPose.Pose()
mpDraw = mp.solutions.drawing_utils
model = tf.keras.models.load_model("swing_clap_nothing.h5")
cap = cv2.VideoCapture(0)  # default webcam
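# Flatten one frame's pose landmarks into a feature vector of
# (x, y, z, visibility) for each of the 33 MediaPipe Pose landmarks.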
def make_landmark_timestep(results):
    c_lm = []
    for id, lm in enumerate(results.pose_landmarks.landmark):
        c_lm.append(lm.x)
        c_lm.append(lm.y)
        c_lm.append(lm.z)
        c_lm.append(lm.visibility)
    return c_lm
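# Draw the pose skeleton plus a filled circle at each landmark position.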
def draw_landmark_on_image(mpDraw, results, img):
    mpDraw.draw_landmarks(img, results.pose_landmarks, mpPose.POSE_CONNECTIONS)
    for id, lm in enumerate(results.pose_landmarks.landmark):
        h, w, c = img.shape
        print(id, lm)
        cx, cy = int(lm.x * w), int(lm.y * h)
        cv2.circle(img, (cx, cy), 5, (255, 0, 0), cv2.FILLED)
    return img
def draw_class_on_image(label, img):  # draw the label text on the image
    font = cv2.FONT_HERSHEY_SIMPLEX
    bottomLeftCornerOfText = (10, 30)
    fontScale = 1
    fontColor = (0, 255, 0)
    thickness = 2
    lineType = 2
    cv2.putText(img, label,
                bottomLeftCornerOfText,
                font,
                fontScale,
                fontColor,
                thickness,
                lineType)
    return img
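# Run the LSTM on one buffered sequence and update the shared label.
# Called from a worker thread so prediction does not block the video loop.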
def detect(model, lm_list):
    global label
    lm_list = np.array(lm_list)
    lm_list = np.expand_dims(lm_list, axis=0)  # shape: (1, n_time_steps, n_features)
    print(lm_list.shape)
    results = model.predict(lm_list)
    print(results)
    if np.argmax(results) == 0:  # np.argmax returns the index of the highest score
        label = "SWING HAND"
    elif np.argmax(results) == 1:
        label = "CLAPPING HAND"
    else:
        label = "NOTHING"
    return label
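# Main loop: skip the first warmup_frames frames, then collect landmark
# vectors and trigger a prediction every n_time_steps frames.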
i = 0
warmup_frames = 60
while True:
    success, img = cap.read()
    if not success:
        break
    imgRGB = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    results = pose.process(imgRGB)
    i = i + 1
    if i > warmup_frames:
        print("Start detect....")
        if results.pose_landmarks:
            c_lm = make_landmark_timestep(results)
            lm_list.append(c_lm)
            if len(lm_list) == n_time_steps:
                # predict on a background thread: multithreading keeps the
                # video loop responsive and suits this I/O-bound case;
                # multiprocessing would fit CPU-bound computation better
                t1 = threading.Thread(target=detect, args=(model, lm_list,))
                t1.start()
                lm_list = []
            img = draw_landmark_on_image(mpDraw, results, img)
    img = draw_class_on_image(label, img)
    cv2.imshow("Image", img)
    if cv2.waitKey(1) == ord('q'):
        break
cap.release()
cv2.destroyAllWindows()