'''utils.py: ONNX export and timing utilities (by lyuwenyu).
'''
import contextlib
import time
from collections import OrderedDict

import numpy as np
import onnx
import onnx_graphsurgeon
import torch
from PIL import Image


def to_binary_data(path, size=(640, 640), output_name='input_tensor.bin'):
    '''Dump an image as a raw float32 blob that trtexec can load with
    --loadInputs='image:input_tensor.bin'
    '''
    im = Image.open(path).convert('RGB').resize(size)  # force 3-channel RGB
    # HWC uint8 -> NCHW float32 in [0, 1]
    data = np.asarray(im, dtype=np.float32).transpose(2, 0, 1)[None] / 255.
    data.tofile(output_name)
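
# Hedged usage sketch: 'sample.jpg' and 'model.engine' are placeholder file names.
#   to_binary_data('sample.jpg', size=(640, 640), output_name='input_tensor.bin')
#   # then feed the raw tensor to trtexec:
#   #   trtexec --loadEngine=model.engine --loadInputs='image:input_tensor.bin'
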
def yolo_insert_nms(path, score_threshold=0.01, iou_threshold=0.7, max_output_boxes=300, simplify=False):
    '''Append a TensorRT EfficientNMS_TRT plugin node to an exported ONNX graph.

    References:
    http://www.xavierdupre.fr/app/onnxcustom/helpsphinx/api/onnxops/onnx__EfficientNMS_TRT.html
    https://huggingface.co/spaces/muttalib1326/Punjabi_Character_Detection/blob/3dd1e17054c64e5f6b2254278f96cfa2bf418cd4/utils/add_nms.py
    '''
    onnx_model = onnx.load(path)

    if simplify:
        # aliased import so it does not shadow the `simplify` argument
        from onnxsim import simplify as onnx_simplify
        onnx_model, _ = onnx_simplify(onnx_model, overwrite_input_shapes={'image': [1, 3, 640, 640]})

    graph = onnx_graphsurgeon.import_onnx(onnx_model)
    graph.toposort()
    graph.fold_constants()
    graph.cleanup()

    topk = max_output_boxes
    attrs = OrderedDict(
        plugin_version='1',
        background_class=-1,           # -1: no background class to skip
        max_output_boxes=topk,
        score_threshold=score_threshold,
        iou_threshold=iou_threshold,
        score_activation=False,        # scores are used as-is, no sigmoid inside the plugin
        box_coding=0,                  # 0: boxes given as [x1, y1, x2, y2]
    )

    # Output tensors follow the EfficientNMS_TRT plugin contract.
    outputs = [
        onnx_graphsurgeon.Variable('num_dets', np.int32, [-1, 1]),
        onnx_graphsurgeon.Variable('det_boxes', np.float32, [-1, topk, 4]),
        onnx_graphsurgeon.Variable('det_scores', np.float32, [-1, topk]),
        onnx_graphsurgeon.Variable('det_classes', np.int32, [-1, topk]),
    ]

    # The exported graph is expected to expose exactly two outputs: boxes, then scores.
    graph.layer(
        op='EfficientNMS_TRT',
        name='batched_nms',
        inputs=[graph.outputs[0], graph.outputs[1]],
        outputs=outputs,
        attrs=attrs,
    )

    graph.outputs = outputs
    graph.cleanup().toposort()
    onnx.save(onnx_graphsurgeon.export_onnx(graph), 'yolo_w_nms.onnx')
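
# Hedged usage sketch: 'model.onnx' is a placeholder; the exported graph is assumed to have
# two outputs (boxes, then scores), and TensorRT must provide the EfficientNMS_TRT plugin.
#   yolo_insert_nms('model.onnx', score_threshold=0.01, iou_threshold=0.7, max_output_boxes=300)
#   # writes yolo_w_nms.onnx; build an engine with, e.g.:
#   #   trtexec --onnx=yolo_w_nms.onnx --saveEngine=yolo_w_nms.engine --fp16
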
class TimeProfiler(contextlib.ContextDecorator):
    '''Accumulating timer; usable as a context manager or as a decorator.'''

    def __init__(self):
        self.total = 0

    def __enter__(self):
        self.start = self.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.total += self.time() - self.start

    def reset(self):
        self.total = 0

    def time(self):
        # Synchronize CUDA so pending GPU work is included in the measurement.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return time.time()
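
# Hedged usage sketch: 'model' and 'images' are placeholders for a torch module and an input batch.
#   profiler = TimeProfiler()
#   with torch.no_grad():
#       for _ in range(100):
#           with profiler:
#               _ = model(images)
#   print(f'avg latency: {profiler.total / 100 * 1000:.2f} ms')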