Skip to content
This repository was archived by the owner on Jun 24, 2021. It is now read-only.

Commit 3a3f227

Browse files
committed
add prediction boilerplate
Signed-off-by: Praneeth <[email protected]>
1 parent 008b8c9 commit 3a3f227

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

predictor.py

+44
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,47 @@
99
import cv2
1010
import numpy as np
1111

12+
WEIGHTS_URL = 'https://github.com/notAi-tech/LogoDet/releases/download/292_classes_v1/weights'
13+
CLASSES_URL = 'https://github.com/notAi-tech/LogoDet/releases/download/292_classes_v1/classes'
14+
15+
home = os.path.expanduser("~")
16+
model_folder = os.path.join(home, '.LogoDet/')
17+
if not os.path.exists(model_folder):
18+
os.mkdir(model_folder)
19+
20+
model_path = os.path.join(model_folder, 'weights')
21+
22+
if not os.path.exists(model_path):
23+
print('Downloading the checkpoint to', model_path)
24+
pydload.dload(WEIGHTS_URL, save_to_path=model_path, max_time=None)
25+
26+
classes_path = os.path.join(model_folder, 'classes')
27+
28+
if not os.path.exists(classes_path):
29+
print('Downloading the class list to', classes_path)
30+
pydload.dload(CLASSES_URL, save_to_path=classes_path, max_time=None)
31+
32+
detection_model = models.load_model(model_path, backbone_name='resnet50')
33+
classes = open(classes_path).readlines()
34+
classes = [i.strip() for i in classes if i.strip()]
35+
36+
def detect_single(img_path, min_prob=0.4):
37+
image = read_image_bgr(img_path)
38+
image = preprocess_image(image)
39+
image, scale = resize_image(image)
40+
boxes, scores, labels = detection_model.predict_on_batch(np.expand_dims(image, axis=0))
41+
boxes /= scale
42+
processed_boxes = []
43+
for box, score, label in zip(boxes[0], scores[0], labels[0]):
44+
if score < min_prob:
45+
continue
46+
box = box.astype(int).tolist()
47+
label = classes[label]
48+
processed_boxes.append({'box': box, 'score': score, 'label': label})
49+
50+
return processed_boxes
51+
52+
53+
def detect_batch():
54+
# TODO
55+
pass

0 commit comments

Comments
 (0)