-
Notifications
You must be signed in to change notification settings - Fork 2
/
yolact_kenning.py
executable file
·217 lines (183 loc) · 6.63 KB
/
yolact_kenning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
#!/usr/bin/env python3
# Copyright 2022-2024 Antmicro <www.antmicro.com>
#
# SPDX-License-Identifier: Apache-2.0
"""YOLACT ROS2 node implementation."""
import traceback
from gc import collect
from pathlib import Path
import rclpy
from kenning.modelwrappers.instance_segmentation.yolact import YOLACT
from kenning_computer_vision_msgs.msg import BoxMsg, MaskMsg, SegmentationMsg
from cvnode_base.core.cvnode_base import BaseCVNode
from cvnode_base.utils.image import imageToMat
class YOLACTOnnx:
    """
    ONNX runtime wrapper for YOLACT model.

    Parameters
    ----------
    node : BaseCVNode
        ROS2 node owning this wrapper; used for parameters and logging.
    """

    # Maps the ``device`` node parameter to ONNX Runtime execution providers.
    EXECUTION_PROVIDERS = {
        "cpu": ("CPUExecutionProvider",),
        "cuda": ("CUDAExecutionProvider",),
    }

    def __init__(self, node: BaseCVNode):
        self.node = node  # ROS2 node
        # "device" may already be declared, e.g. when the node prepares a
        # backend more than once; re-declaring a parameter raises in rclpy.
        if not self.node.has_parameter("device"):
            self.node.declare_parameter(
                "device", rclpy.Parameter.Type.STRING
            )

    def prepare(self) -> bool:
        """
        Prepare node and model for inference.

        Reads the ``model_path`` and ``device`` node parameters, then sets
        ``self.model`` (YOLACT wrapper) and ``self.runtime`` (ONNX runtime
        using the execution provider matching ``device``, "cpu" or "cuda").

        Returns
        -------
        bool
            True if preparation was successful, False otherwise.
        """
        logger = self.node.get_logger()
        try:
            from kenning.runtimes.onnx import ONNXRuntime
        except ImportError:
            logger.error("Cannot import ONNXRuntime")
            logger.error(str(traceback.format_exc()))
            return False

        model_path = Path(self.node.get_parameter("model_path").value)
        if not model_path.exists():
            logger.error(f"File {model_path} does not exist")
            return False

        # Validate the device before building the model so that a bad
        # configuration fails without doing any model setup work.
        device = self.node.get_parameter("device").value
        if not device:
            logger.error("Please specify device for ONNX runtime")
            return False
        execution_providers = self.EXECUTION_PROVIDERS.get(device)
        if execution_providers is None:
            logger.error(
                f"Device {device} is not supported by ONNX runtime"
            )
            return False

        self.model = YOLACT(model_path, None, top_k=100, score_threshold=0.3)
        self.runtime = ONNXRuntime(
            model_path,
            # Fresh list so the runtime never mutates the shared table.
            execution_providers=list(execution_providers),
            disable_performance_measurements=True,
        )
        return self.runtime.prepare_local()
class YOLACTTVM:
    """
    TVM runtime wrapper for YOLACT model.

    Parameters
    ----------
    node : BaseCVNode
        ROS2 node owning this wrapper; used for parameters and logging.
    """

    def __init__(self, node: BaseCVNode):
        self.node = node  # ROS2 node
        # "device" may already be declared, e.g. when the node prepares a
        # backend more than once; re-declaring a parameter raises in rclpy.
        if not self.node.has_parameter("device"):
            self.node.declare_parameter(
                "device", rclpy.Parameter.Type.STRING
            )

    def prepare(self) -> bool:
        """
        Prepare node and model for inference.

        Reads the ``model_path`` and ``device`` node parameters, then sets
        ``self.model`` (YOLACT wrapper) and ``self.runtime`` (TVM runtime).
        ``device`` is forwarded verbatim as the TVM ``contextname``; its
        accepted values are not validated here.

        Returns
        -------
        bool
            True if preparation was successful, False otherwise.
        """
        logger = self.node.get_logger()
        try:
            from kenning.runtimes.tvm import TVMRuntime
        except ImportError:
            logger.error("Cannot import TVMRuntime")
            logger.error(str(traceback.format_exc()))
            return False

        model_path = Path(self.node.get_parameter("model_path").value)
        if not model_path.exists():
            logger.error(f"File {model_path} does not exist")
            return False

        device = self.node.get_parameter("device").value
        if not device:
            logger.error("Please specify device for TVM runtime")
            return False

        self.model = YOLACT(model_path, None, top_k=100, score_threshold=0.3)
        self.runtime = TVMRuntime(
            model_path,
            contextname=device,
            disable_performance_measurements=True,
        )
        return self.runtime.prepare_local()
class YOLACTTFLite:
    """
    YOLACT model executed through the Kenning TFLite runtime.

    Parameters
    ----------
    node : BaseCVNode
        ROS2 node that owns this wrapper; supplies parameters and logging.
    """

    def __init__(self, node: BaseCVNode):
        self.node = node  # ROS2 node

    def prepare(self) -> bool:
        """
        Build the YOLACT wrapper and the TFLite runtime for the node.

        On success ``self.model`` and ``self.runtime`` are populated from the
        ``model_path`` node parameter.

        Returns
        -------
        bool
            Result of ``prepare_local()`` on the runtime, or False when the
            runtime cannot be imported or the model file is missing.
        """
        log = self.node.get_logger()
        try:
            from kenning.runtimes.tflite import TFLiteRuntime
        except ImportError:
            log.error("Cannot import TFLiteRuntime")
            log.error(str(traceback.format_exc()))
            return False

        path = Path(self.node.get_parameter("model_path").value)
        if path.exists():
            self.model = YOLACT(path, None, top_k=100, score_threshold=0.3)
            self.runtime = TFLiteRuntime(
                path, disable_performance_measurements=True
            )
            return self.runtime.prepare_local()

        log.error(f"File {path} does not exist")
        return False
class YOLACTNode(BaseCVNode):
    """
    ROS2 node for YOLACT model.

    The runtime is chosen with the ``backend`` node parameter (a key of
    ``backends``) and the model file with the ``model_path`` parameter.
    """

    # Maps the ``backend`` parameter value to its runtime wrapper class.
    backends = {
        "tvm": YOLACTTVM,
        "tflite": YOLACTTFLite,
        "onnxruntime": YOLACTOnnx,
    }

    def __init__(self, node_name: str):
        # Assigned before super().__init__() -- NOTE(review): presumably so the
        # attribute exists if the base constructor touches it; confirm.
        self.yolact = None  # Wrapper for YOLACT model with runtime
        super().__init__(node_name=node_name)
        self.declare_parameter("backend", rclpy.Parameter.Type.STRING)
        self.declare_parameter("model_path", rclpy.Parameter.Type.STRING)

    def prepare(self) -> bool:
        """
        Instantiate the wrapper selected by ``backend`` and prepare it.

        Returns
        -------
        bool
            False if the backend is unknown, otherwise the result of the
            wrapper's ``prepare()``.
        """
        backend = self.get_parameter("backend").value
        if backend not in self.backends:
            self.get_logger().error(f"Backend {backend} is not supported")
            return False
        self.yolact = self.backends[backend](self)
        return self.yolact.prepare()

    def run_inference(self, X):
        """
        Run YOLACT on one input frame and build a SegmentationMsg.

        Parameters
        ----------
        X
            Input whose ``frame`` field holds the image; it is converted with
            ``imageToMat(..., "rgb8")`` and transposed with (2, 0, 1),
            presumably HWC -> CHW (TODO confirm against imageToMat).

        Returns
        -------
        tuple
            ``(True, msg)`` where ``msg`` lists one box, score, mask and class
            name per detection from the first postprocessed result.
        """
        x = imageToMat(X.frame, "rgb8").transpose(2, 0, 1)
        x = self.yolact.model.preprocess_input([x])
        self.yolact.runtime.load_input([x])
        self.yolact.runtime.run()
        preds = self.yolact.runtime.extract_output()
        preds = self.yolact.model.postprocess_outputs(preds)
        msg = SegmentationMsg()
        if preds:
            for y in preds[0]:
                box = BoxMsg()
                box._xmin = float(y.xmin)
                box._xmax = float(y.xmax)
                box._ymin = float(y.ymin)
                box._ymax = float(y.ymax)
                msg._boxes.append(box)
                msg._scores.append(y.score)
                mask = MaskMsg()
                # Mask is sent flattened together with its first two
                # dimensions so the receiver can reshape it.
                mask._data = y.mask.flatten()
                mask._dimension = [y.mask.shape[0], y.mask.shape[1]]
                msg._masks.append(mask)
                msg._classes.append(y.clsname)
        return True, msg

    def cleanup(self):
        """Drop the model/runtime wrapper and force a garbage collection."""
        # Rebind rather than ``del`` so the attribute keeps the None state set
        # in __init__ and a repeated cleanup() does not raise AttributeError.
        self.yolact = None
        collect()
if __name__ == "__main__":
    rclpy.init()
    node = YOLACTNode("yolact_node")
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the node; proceed to cleanup.
        pass
    finally:
        # Always release the node, even if spin() exits with an exception.
        node.destroy_node()
        # Only shut down a context that is still running; shutting down an
        # already shut-down context raises.
        if rclpy.ok():
            rclpy.shutdown()