upload files for TensorRT and TF-TRT for yolo LP #508

Open
wants to merge 10 commits into base: tensorrt
17 changes: 17 additions & 0 deletions peekingduck/configs/model/license_plate_detector.yml
@@ -0,0 +1,17 @@
input: ["img"]
output: ["bboxes", "bbox_labels", "bbox_scores"]

model_type: v4tiny # v4 or v4tiny
classes: ../weights/yolo_license_plate/classes.names
weights_dir: ["../weights/yolo_license_plate"]
model_weights_dir: {
v4: ../weights/yolo_license_plate/LPyolov4,
v4tiny: ../weights/yolo_license_plate/LPyolov4tiny,
}
blob_file: "yolo_license_plate.zip"

size: 416
max_output_size_per_class: 50
max_total_size: 50
yolo_score_threshold: 0.1
yolo_iou_threshold: 0.3
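As a quick illustration of how these keys fit together (a sketch, not part of this diff): a node can resolve the active weights directory by indexing `model_weights_dir` with `model_type`.

```python
# Sketch only: resolving the weights directory from the config above.
import yaml

with open("peekingduck/configs/model/license_plate_detector.yml") as infile:
    config = yaml.safe_load(infile)

# model_type is "v4" or "v4tiny"; it selects the matching weights directory
weights_dir = config["model_weights_dir"][config["model_type"]]
print(weights_dir)  # ../weights/yolo_license_plate/LPyolov4tiny
```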
21 changes: 21 additions & 0 deletions peekingduck/configs/model/yolo_lp_tensorrt.yml
@@ -0,0 +1,21 @@
input: ["img"]
output: ["bboxes", "bbox_labels", "bbox_scores"]

yolo_plugin_path: "/home/aisg/Cvhub/tensorrt_demo/tensorrt_demos/plugins/libyolo_layer.so"

model_type: v4 # v4 or v4tiny
TensorRT_path: {
v4: "/home/aisg/Cvhub/tensorrt_demo/tensorrt_demos/yolo/yolov4-LP.trt",
v4tiny: "/home/aisg/Cvhub/tensorrt_demo/tensorrt_demos/yolo/yolov4tiny-LP.trt",
}

converted_model_path: {
v4: "../weights/TRTv4",
v4tiny: "../weights/TRTv4tiny",
}

size: 416
max_output_size_per_class: 50
max_total_size: 50
yolo_score_threshold: 0.1
yolo_iou_threshold: 0.3
20 changes: 20 additions & 0 deletions peekingduck/configs/model/yolo_lp_tf_trt.yml
@@ -0,0 +1,20 @@
input: ["img"]
output: ["bboxes", "bbox_labels", "bbox_scores"]


model_type: v4 # v4 or v4tiny
model_weights_dir: {
v4: "../weights/yolo_license_plate/LPyolov4",
v4tiny: "../weights/yolo_license_plate/LPyolov4tiny",
}

converted_model_path: {
v4: "../weights/TRTv4",
v4tiny: "../weights/TRTv4tiny",
}

size: 416
max_output_size_per_class: 50
max_total_size: 50
yolo_score_threshold: 0.1
yolo_iou_threshold: 0.3
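The SavedModels under `converted_model_path` are presumably produced by a TF-TRT conversion step outside this diff. A minimal sketch of such a conversion with TensorFlow's `TrtGraphConverterV2` (paths mirror the config above; the FP16 precision mode is an assumption):

```python
# Hedged sketch: converting the TF weights to a TF-TRT SavedModel.
# FP16 precision is an assumption; the actual conversion script is
# not included in this diff.
from tensorflow.python.compiler.tensorrt import trt_convert as trt

params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(precision_mode="FP16")
converter = trt.TrtGraphConverterV2(
    input_saved_model_dir="../weights/yolo_license_plate/LPyolov4",
    conversion_params=params,
)
converter.convert()
converter.save("../weights/TRTv4")
```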
17 changes: 17 additions & 0 deletions peekingduck/pipeline/nodes/model/tensorrtv1/__init__.py
@@ -0,0 +1,17 @@
# Copyright 2021 AI Singapore
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Yolo-related files for TensorRT yolo node
"""
302 changes: 302 additions & 0 deletions peekingduck/pipeline/nodes/model/tensorrtv1/detector.py
@@ -0,0 +1,302 @@
# Copyright 2021 AI Singapore

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Object detection class using TensorRT yolo ,with custom yolo_layer
plugins, single label model to find license plate object bboxes
"""

import ctypes
from typing import Dict, Any, List, Tuple
import numpy as np
import cv2
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit  # initializes a CUDA context on import (side effect)


class HostDeviceMem:
"""
Simple helper data class that's a little nicer to use than a 2-tuple.
"""

def __init__(self, host_mem, device_mem):
self.host = host_mem
self.device = device_mem

def __str__(self):
return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)

def __repr__(self):
return self.__str__()


class Detector:
"""
Object detection class using yolo model to find object bboxes
"""

def __init__(self, config: Dict[str, Any], cuda_ctx=None) -> None:
self.config = config
self.trt_logger = trt.Logger(trt.Logger.INFO)
self.cuda_ctx = cuda_ctx
if self.cuda_ctx:
self.cuda_ctx.push()

        try:
            ctypes.cdll.LoadLibrary(self.config["yolo_plugin_path"])
        except OSError as error:
            raise SystemExit(
                f"ERROR: failed to load {self.config['yolo_plugin_path']}. "
                'Did you forget to do a "make" in the "plugins/" subdirectory?'
            ) from error

self.engine = self._load_engine()
self.input_shape = self.get_input_shape(self.engine)

        try:
            self.context = self.engine.create_execution_context()
        except Exception as error:
            raise RuntimeError(
                "failed to create TensorRT execution context"
            ) from error
finally:
if self.cuda_ctx:
self.cuda_ctx.pop()

def _load_engine(self) -> Any:
trtbin = self.config["TensorRT_path"][self.config["model_type"]]
with open(trtbin, "rb") as file, trt.Runtime(self.trt_logger) as runtime:
return runtime.deserialize_cuda_engine(file.read())

def allocate_buffers(self) -> Tuple[List, List, List, Any]:
"""Allocates all host/device in/out buffers required for an engine."""
inputs = []
outputs = []
bindings = []
output_idx = 0
stream = cuda.Stream()
for binding in self.engine:
binding_dims = self.engine.get_binding_shape(binding)
if len(binding_dims) == 4:
# explicit batch case (TensorRT 7+)
size = trt.volume(binding_dims)
elif len(binding_dims) == 3:
# implicit batch case (TensorRT 6 or older)
size = trt.volume(binding_dims) * self.engine.max_batch_size
else:
raise ValueError(
"bad dims of binding %s: %s" % (binding, str(binding_dims))
)
dtype = trt.nptype(self.engine.get_binding_dtype(binding))
# Allocate host and device buffers
host_mem = cuda.pagelocked_empty(size, dtype)
device_mem = cuda.mem_alloc(host_mem.nbytes)
# Append the device buffer to device bindings.
bindings.append(int(device_mem))
# Append to the appropriate list.
if self.engine.binding_is_input(binding):
inputs.append(HostDeviceMem(host_mem, device_mem))
else:
# each grid has 3 anchors, each anchor generates a detection
# output of 7 float32 values
assert size % 7 == 0
outputs.append(HostDeviceMem(host_mem, device_mem))
output_idx += 1
assert len(inputs) == 1
assert len(outputs) == 1
return inputs, outputs, bindings, stream

@staticmethod
def get_input_shape(engine):
"""Get input shape of the TensorRT YOLO engine."""
binding = engine[0]
assert engine.binding_is_input(binding)
binding_dims = engine.get_binding_shape(binding)
if len(binding_dims) == 4:
return tuple(binding_dims[2:])
if len(binding_dims) == 3:
return tuple(binding_dims[1:])
raise ValueError("bad dims of binding %s: %s" % (binding, str(binding_dims)))

    def do_inference_v2(self):
        """do_inference_v2 (for TensorRT 7.0+)

        This function is generalized for multiple inputs/outputs for full
        dimension networks.
        Inputs and outputs are expected to be lists of HostDeviceMem objects
        produced by the allocate_buffers method.
        """
        # Transfer input data to the GPU.
        for inp in self.inputs:
            cuda.memcpy_htod_async(inp.device, inp.host, self.stream)
        # Run inference.
        self.context.execute_async_v2(
            bindings=self.bindings, stream_handle=self.stream.handle
        )
        # Transfer predictions back from the GPU.
        for out in self.outputs:
            cuda.memcpy_dtoh_async(out.host, out.device, self.stream)
        # Synchronize the stream.
        self.stream.synchronize()
        # Return only the host outputs.
        return [out.host for out in self.outputs]

@staticmethod
def _nms_boxes(detections, nms_threshold):
"""
Apply the Non-Maximum Suppression (NMS) algorithm on the bounding
boxes with their confidence scores and return an array with the
indexes of the bounding boxes we want to keep.

# Args
detections: Nx7 numpy arrays of
[[x, y, w, h, box_confidence, class_id, class_prob],
......]
"""
x_coord = detections[:, 0]
y_coord = detections[:, 1]
width = detections[:, 2]
height = detections[:, 3]
box_confidences = detections[:, 4]

areas = width * height
ordered = box_confidences.argsort()[::-1]

keep = list()
while ordered.size > 0:
# Index of the current element:
i = ordered[0]
keep.append(i)
xxmin = np.maximum(x_coord[i], x_coord[ordered[1:]])
yymin = np.maximum(y_coord[i], y_coord[ordered[1:]])
xxmax = np.minimum(
x_coord[i] + width[i], x_coord[ordered[1:]] + width[ordered[1:]]
)
yymax = np.minimum(
y_coord[i] + height[i], y_coord[ordered[1:]] + height[ordered[1:]]
)

width1 = np.maximum(0.0, xxmax - xxmin)
height1 = np.maximum(0.0, yymax - yymin)
intersection = width1 * height1
union = areas[i] + areas[ordered[1:]] - intersection
iou = intersection / union
indexes = np.where(iou <= nms_threshold)[0]
ordered = ordered[indexes + 1]

keep = np.array(keep)
return keep
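A small worked example of `_nms_boxes` (illustrative values only): with an IoU threshold of 0.3, the second box below overlaps the first at IoU ≈ 0.68 and is suppressed, while the disjoint third box survives.

```python
# Illustrative check of _nms_boxes; rows are
# [x, y, w, h, box_confidence, class_id, class_prob].
import numpy as np

dets = np.array(
    [
        [10.0, 10.0, 20.0, 20.0, 0.9, 0.0, 0.9],    # kept (highest score)
        [12.0, 12.0, 20.0, 20.0, 0.8, 0.0, 0.8],    # suppressed (IoU ~0.68)
        [100.0, 100.0, 20.0, 20.0, 0.7, 0.0, 0.7],  # kept (no overlap)
    ]
)
print(Detector._nms_boxes(dets, 0.3))  # -> [0 2]
```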

    def _post_process(self, trt_outputs: List[np.ndarray]):
        """
        Post-process TRT output

        args:
            trt_outputs: a list of tensors where each tensor contains a
                        multiple of 7 float32 numbers in the order of
                        [x, y, w, h, box_conf_score, class_id, class_prob]
        return:
            boxes: list of bbox coordinates in [xmin, ymin, xmax, ymax]
            scores: list of bbox confidence scores
            classes: list of class ids
        """
        detections = []
        for output in trt_outputs:
            detection = output.reshape((-1, 7))
            # drop detections below the configured confidence threshold
            score_threshold = self.config["yolo_score_threshold"]
            detection = detection[detection[:, 4] >= score_threshold]
            detections.append(detection)
        detections = np.concatenate(detections, axis=0)

        # Per-class NMS with the configured IoU threshold
        nms_detections = np.zeros((0, 7), dtype=detections.dtype)
        for class_id in set(detections[:, 5]):
            idxs = np.where(detections[:, 5] == class_id)
            cls_detections = detections[idxs]
            keep = self._nms_boxes(cls_detections, self.config["yolo_iou_threshold"])
            nms_detections = np.concatenate(
                [nms_detections, cls_detections[keep]], axis=0
            )

xmin = nms_detections[:, 0].reshape(-1, 1)
ymin = nms_detections[:, 1].reshape(-1, 1)
width = nms_detections[:, 2].reshape(-1, 1)
height = nms_detections[:, 3].reshape(-1, 1)
boxes = np.concatenate([xmin, ymin, xmin + width, ymin + height], axis=1)
scores = nms_detections[:, 4]
classes = nms_detections[:, 5]

# update the labels names of the object detected
# labels = np.asarray([self.class_labels[int(i)] for i in classes])

return boxes, classes, scores

@staticmethod
def bbox_scaling(bboxes: List[list], scale_factor: float) -> List[list]:
"""
To scale the width and height of bboxes from v4tiny
After the conversion of the model in .cfg and .weight file format, from
Alexey's Darknet repo, to tf model, bboxes are bigger.
So downscaling is required for a better fit
"""
for idx, box in enumerate(bboxes):
xmin, ymin, xmax, ymax = tuple(box)
center_x = (xmin + xmax) / 2
center_y = (ymin + ymax) / 2
scaled_xmin = center_x - ((xmax - xmin) / 2 * scale_factor)
scaled_xmax = center_x + ((xmax - xmin) / 2 * scale_factor)
scaled_ymin = center_y - ((ymax - ymin) / 2 * scale_factor)
scaled_ymax = center_y + ((ymax - ymin) / 2 * scale_factor)
bboxes[idx] = [scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax]

return bboxes
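For example, scaling the box `[100, 100, 200, 300]` about its centre with `scale_factor=0.5` halves each half-extent:

```python
# centre (150, 200), half-extents (50, 100) -> (25, 50)
print(Detector.bbox_scaling([[100, 100, 200, 300]], 0.5))
# -> [[125.0, 150.0, 175.0, 250.0]]
```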

    def predict(self, frame: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Detect the bounding boxes of all license plate objects in one image

args:
image: (Numpy Array) input image

return:
boxes: (Numpy Array) an array of bounding box with
definition like (xmin, ymin, xmax, ymax), in a
coordinate system with origin point in
the left top corner
labels: (Numpy Array) an array of class labels
scores: (Numpy Array) an array of confidence scores
"""
        # Host/device buffers are (re)allocated for every frame
        self.inputs, self.outputs, self.bindings, self.stream = self.allocate_buffers()

        # Resize to the engine's input resolution, convert BGR -> RGB,
        # reorder to CHW, and normalize pixel values to [0, 1]
        image_data = cv2.resize(frame, (self.input_shape[1], self.input_shape[0]))
        image_data = cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB)
        image_data = image_data.transpose((2, 0, 1)).astype(np.float32)
        image_data = image_data / 255.0

self.inputs[0].host = np.ascontiguousarray(image_data)
if self.cuda_ctx:
self.cuda_ctx.push()
trt_outputs = self.do_inference_v2()
if self.cuda_ctx:
self.cuda_ctx.pop()

bboxes, labels, scores = self._post_process(trt_outputs)

return bboxes, labels, scores
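A hedged usage sketch of the detector (assumes the serialized `.trt` engine and the compiled `yolo_layer` plugin already exist at the paths in `yolo_lp_tensorrt.yml`; the image filename is hypothetical):

```python
# Sketch only: wiring the config file above into the Detector.
import cv2
import yaml

with open("peekingduck/configs/model/yolo_lp_tensorrt.yml") as infile:
    config = yaml.safe_load(infile)

detector = Detector(config)
frame = cv2.imread("car.jpg")  # hypothetical test image
bboxes, labels, scores = detector.predict(frame)
```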