9
9
import cv2
10
10
import numpy as np
11
11
12
+ WEIGHTS_URL = 'https://github.com/notAi-tech/LogoDet/releases/download/292_classes_v1/weights'
13
+ CLASSES_URL = 'https://github.com/notAi-tech/LogoDet/releases/download/292_classes_v1/classes'
14
+
15
+ home = os .path .expanduser ("~" )
16
+ model_folder = os .path .join (home , '.LogoDet/' )
17
+ if not os .path .exists (model_folder ):
18
+ os .mkdir (model_folder )
19
+
20
+ model_path = os .path .join (model_folder , 'weights' )
21
+
22
+ if not os .path .exists (model_path ):
23
+ print ('Downloading the checkpoint to' , model_path )
24
+ pydload .dload (WEIGHTS_URL , save_to_path = model_path , max_time = None )
25
+
26
+ classes_path = os .path .join (model_folder , 'classes' )
27
+
28
+ if not os .path .exists (classes_path ):
29
+ print ('Downloading the class list to' , classes_path )
30
+ pydload .dload (CLASSES_URL , save_to_path = classes_path , max_time = None )
31
+
32
+ detection_model = models .load_model (model_path , backbone_name = 'resnet50' )
33
+ classes = open (classes_path ).readlines ()
34
+ classes = [i .strip () for i in classes if i .strip ()]
35
+
36
+ def detect_single (img_path , min_prob = 0.4 ):
37
+ image = read_image_bgr (img_path )
38
+ image = preprocess_image (image )
39
+ image , scale = resize_image (image )
40
+ boxes , scores , labels = detection_model .predict_on_batch (np .expand_dims (image , axis = 0 ))
41
+ boxes /= scale
42
+ processed_boxes = []
43
+ for box , score , label in zip (boxes [0 ], scores [0 ], labels [0 ]):
44
+ if score < min_prob :
45
+ continue
46
+ box = box .astype (int ).tolist ()
47
+ label = classes [label ]
48
+ processed_boxes .append ({'box' : box , 'score' : score , 'label' : label })
49
+
50
+ return processed_boxes
51
+
52
+
53
+ def detect_batch ():
54
+ # TODO
55
+ pass
0 commit comments