diff --git a/.gitignore b/.gitignore index 830a4d9..2ad723b 100644 --- a/.gitignore +++ b/.gitignore @@ -110,4 +110,5 @@ play.py preprocess_data.py res/ adj.md +tensorrt/build/* diff --git a/README.md b/README.md index 933b3b2..86f65c3 100644 --- a/README.md +++ b/README.md @@ -4,20 +4,23 @@ My implementation of [BiSeNetV1](https://arxiv.org/abs/1808.00897) and [BiSeNetV The mIOU evaluation result of the models trained and evaluated on cityscapes train/val set is: -| none | ss | ssc | msf | mscf | fps | link | +| none | ss | ssc | msf | mscf | fps(fp16/fp32) | link | |------|:--:|:---:|:---:|:----:|:---:|:----:| -| bisenetv1 | 74.85 | 76.46 | 77.36 | 78.72 | - | [download](https://drive.google.com/file/d/1e1_E7OrpjTaD5Rael7Fus5lg-uGZ5TUZ/view?usp=sharing) | -| bisenetv2 | 74.39 | 74.44 | 76.10 | 75.94 | - | [download](https://drive.google.com/file/d/1r_F-KZg-3s2pPcHRIuHZhZ0DQ0wocudk/view?usp=sharing) | +| bisenetv1 | 75.55 | 76.90 | 77.40 | 78.91 | 60/19 | [download](https://drive.google.com/file/d/140MBBAt49N1z1wsKueoFA6HB_QuYud8i/view?usp=sharing) | +| bisenetv2 | 74.12 | 74.18 | 75.89 | 75.87 | 50/16 | [download](https://drive.google.com/file/d/1qq38u9JT4pp1ubecGLTCHHtqwntH0FCY/view?usp=sharing) | > Where **ss** means single scale evaluation, **ssc** means single scale crop evaluation, **msf** means multi-scale evaluation with flip augment, and **mscf** means multi-scale crop evaluation with flip evaluation. The eval scales of multi-scales evaluation are `[0.5, 0.75, 1.0, 1.25, 1.5, 1.75]`, and the crop size of crop evaluation is `[1024, 1024]`. +> The fps is tested in different way from the paper. For more information, please see [here](./tensorrt). + Note that the model has a big variance, which means that the results of training for many times would vary within a relatively big margin. For example, if you train bisenetv2 for many times, you will observe that the result of **ss** evaluation of bisenetv2 varies between 72.1-74.4. ## platform My platform is like this: -* ubuntu 16.04 -* cuda 10.1.243 +* ubuntu 18.04 +* nvidia Tesla T4 gpu, driver 450.51.05 +* cuda 10.2 * cudnn 7 * miniconda python 3.6.9 * pytorch 1.6.0 @@ -59,7 +62,12 @@ Then you need to change the field of `im_root` and `train/val_im_anns` in the co In order to train the model, you can run command like this: ``` $ export CUDA_VISIBLE_DEVICES=0,1 + +# if you want to train with apex $ python -m torch.distributed.launch --nproc_per_node=2 tools/train.py --model bisenetv2 # or bisenetv1 + +# if you want to train with pytorch fp16 feature from torch 1.6 +$ python -m torch.distributed.launch --nproc_per_node=2 tools/train_amp.py --model bisenetv2 # or bisenetv1 ``` Note that though `bisenetv2` has fewer flops, it requires much more training iterations. The the training time of `bisenetv1` is shorter. @@ -70,6 +78,9 @@ You can also load the trained model weights and finetune from it: ``` $ export CUDA_VISIBLE_DEVICES=0,1 $ python -m torch.distributed.launch --nproc_per_node=2 tools/train.py --finetune-from ./res/model_final.pth --model bisenetv2 # or bisenetv1 + +# same with pytorch fp16 feature +$ python -m torch.distributed.launch --nproc_per_node=2 tools/train_amp.py --finetune-from ./res/model_final.pth --model bisenetv2 # or bisenetv1 ``` @@ -79,6 +90,10 @@ You can also evaluate a trained model like this: $ python tools/evaluate.py --model bisenetv1 --weight-path /path/to/your/weight.pth ``` +## Infer with tensorrt +You can go to [tensorrt](./tensorrt) For details. + + ### Be aware that this is the refactored version of the original codebase. You can go to the `old` directory for original implementation. diff --git a/datasets/cityscapes/gtFine b/datasets/cityscapes/gtFine index ab8357c..ae71826 120000 --- a/datasets/cityscapes/gtFine +++ b/datasets/cityscapes/gtFine @@ -1 +1 @@ -/data2/zzy/.datasets/cityscapes/gtFine/ \ No newline at end of file +/data2/zzy/.datasets/cityscapes//gtFine/ \ No newline at end of file diff --git a/datasets/cityscapes/leftImg8bit b/datasets/cityscapes/leftImg8bit index ddf1ba5..eed1adb 120000 --- a/datasets/cityscapes/leftImg8bit +++ b/datasets/cityscapes/leftImg8bit @@ -1 +1 @@ -/data2/zzy/.datasets/cityscapes/leftImg8bit/ \ No newline at end of file +/data2/zzy/.datasets/cityscapes//leftImg8bit/ \ No newline at end of file diff --git a/lib/base_dataset.py b/lib/base_dataset.py index 89ec240..70d4ebe 100644 --- a/lib/base_dataset.py +++ b/lib/base_dataset.py @@ -40,7 +40,7 @@ def __init__(self, dataroot, annpath, trans_func=None, mode='train'): def __getitem__(self, idx): impth, lbpth = self.img_paths[idx], self.lb_paths[idx] - img, label = cv2.imread(impth), cv2.imread(lbpth, 0) + img, label = cv2.imread(impth)[:, :, ::-1], cv2.imread(lbpth, 0) if not self.lb_map is None: label = self.lb_map[label] im_lb = dict(im=img, lb=label) diff --git a/lib/models/bisenetv1.py b/lib/models/bisenetv1.py index 92620a6..a84fa23 100644 --- a/lib/models/bisenetv1.py +++ b/lib/models/bisenetv1.py @@ -13,6 +13,7 @@ class ConvBNReLU(nn.Module): + def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): super(ConvBNReLU, self).__init__() self.conv = nn.Conv2d(in_chan, @@ -38,16 +39,39 @@ def init_weight(self): if not ly.bias is None: nn.init.constant_(ly.bias, 0) +class UpSample(nn.Module): + + def __init__(self, n_chan, factor=2): + super(UpSample, self).__init__() + out_chan = n_chan * factor * factor + self.proj = nn.Conv2d(n_chan, out_chan, 1, 1, 0) + self.up = nn.PixelShuffle(factor) + self.init_weight() + + def forward(self, x): + feat = self.proj(x) + feat = self.up(feat) + return feat + + def init_weight(self): + nn.init.xavier_normal_(self.proj.weight, gain=1.) + + class BiSeNetOutput(nn.Module): - def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs): + + def __init__(self, in_chan, mid_chan, n_classes, up_factor=32, *args, **kwargs): super(BiSeNetOutput, self).__init__() + self.up_factor = up_factor + out_chan = n_classes * up_factor * up_factor self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) - self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False) + self.conv_out = nn.Conv2d(mid_chan, out_chan, kernel_size=1, bias=True) + self.up = nn.PixelShuffle(up_factor) self.init_weight() def forward(self, x): x = self.conv(x) x = self.conv_out(x) + x = self.up(x) return x def init_weight(self): @@ -79,7 +103,7 @@ def __init__(self, in_chan, out_chan, *args, **kwargs): def forward(self, x): feat = self.conv(x) - atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = torch.mean(feat, dim=(2, 3), keepdim=True) atten = self.conv_atten(atten) atten = self.bn_atten(atten) atten = self.sigmoid_atten(atten) @@ -102,28 +126,25 @@ def __init__(self, *args, **kwargs): self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) + self.up32 = nn.Upsample(scale_factor=2.) + self.up16 = nn.Upsample(scale_factor=2.) self.init_weight() def forward(self, x): - H0, W0 = x.size()[2:] feat8, feat16, feat32 = self.resnet(x) - H8, W8 = feat8.size()[2:] - H16, W16 = feat16.size()[2:] - H32, W32 = feat32.size()[2:] - avg = F.avg_pool2d(feat32, feat32.size()[2:]) + avg = torch.mean(feat32, dim=(2, 3), keepdim=True) avg = self.conv_avg(avg) - avg_up = F.interpolate(avg, (H32, W32), mode='nearest') feat32_arm = self.arm32(feat32) - feat32_sum = feat32_arm + avg_up - feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') + feat32_sum = feat32_arm + avg + feat32_up = self.up32(feat32_sum) feat32_up = self.conv_head32(feat32_up) feat16_arm = self.arm16(feat16) feat16_sum = feat16_arm + feat32_up - feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') + feat16_up = self.up16(feat16_sum) feat16_up = self.conv_head16(feat16_up) return feat16_up, feat32_up # x8, x16 @@ -203,7 +224,7 @@ def __init__(self, in_chan, out_chan, *args, **kwargs): def forward(self, fsp, fcp): fcat = torch.cat([fsp, fcp], dim=1) feat = self.convblk(fcat) - atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = torch.mean(feat, dim=(2, 3), keepdim=True) atten = self.conv1(atten) atten = self.relu(atten) atten = self.conv2(atten) @@ -231,14 +252,17 @@ def get_params(self): class BiSeNetV1(nn.Module): - def __init__(self, n_classes, *args, **kwargs): + + def __init__(self, n_classes, output_aux=True, *args, **kwargs): super(BiSeNetV1, self).__init__() self.cp = ContextPath() self.sp = SpatialPath() self.ffm = FeatureFusionModule(256, 256) - self.conv_out = BiSeNetOutput(256, 256, n_classes) - self.conv_out16 = BiSeNetOutput(128, 64, n_classes) - self.conv_out32 = BiSeNetOutput(128, 64, n_classes) + self.conv_out = BiSeNetOutput(256, 256, n_classes, up_factor=8) + self.output_aux = output_aux + if self.output_aux: + self.conv_out16 = BiSeNetOutput(128, 64, n_classes, up_factor=8) + self.conv_out32 = BiSeNetOutput(128, 64, n_classes, up_factor=16) self.init_weight() def forward(self, x): @@ -248,13 +272,12 @@ def forward(self, x): feat_fuse = self.ffm(feat_sp, feat_cp8) feat_out = self.conv_out(feat_fuse) - feat_out16 = self.conv_out16(feat_cp8) - feat_out32 = self.conv_out32(feat_cp16) - - feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True) - feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True) - feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True) - return feat_out, feat_out16, feat_out32 + if self.output_aux: + feat_out16 = self.conv_out16(feat_cp8) + feat_out32 = self.conv_out32(feat_cp16) + return feat_out, feat_out16, feat_out32 + feat_out = feat_out.argmax(dim=1) + return feat_out def init_weight(self): for ly in self.children(): @@ -276,11 +299,13 @@ def get_params(self): if __name__ == "__main__": - net = BiSeNet(19) + net = BiSeNetV1(19) net.cuda() net.eval() in_ten = torch.randn(16, 3, 640, 480).cuda() out, out16, out32 = net(in_ten) print(out.shape) + print(out16.shape) + print(out32.shape) net.get_params() diff --git a/lib/models/bisenetv2.py b/lib/models/bisenetv2.py index 56fc0a5..a06133b 100644 --- a/lib/models/bisenetv2.py +++ b/lib/models/bisenetv2.py @@ -23,6 +23,24 @@ def forward(self, x): return feat +class UpSample(nn.Module): + + def __init__(self, n_chan, factor=2): + super(UpSample, self).__init__() + out_chan = n_chan * factor * factor + self.proj = nn.Conv2d(n_chan, out_chan, 1, 1, 0) + self.up = nn.PixelShuffle(factor) + self.init_weight() + + def forward(self, x): + feat = self.proj(x) + feat = self.up(feat) + return feat + + def init_weight(self): + nn.init.xavier_normal_(self.proj.weight, gain=1.) + + class DetailBranch(nn.Module): def __init__(self): @@ -234,6 +252,8 @@ def __init__(self): 128, 128, kernel_size=1, stride=1, padding=0, bias=False), ) + self.up1 = nn.Upsample(scale_factor=4) + self.up2 = nn.Upsample(scale_factor=4) ##TODO: does this really has no relu? self.conv = nn.Sequential( nn.Conv2d( @@ -249,12 +269,10 @@ def forward(self, x_d, x_s): left2 = self.left2(x_d) right1 = self.right1(x_s) right2 = self.right2(x_s) - right1 = F.interpolate( - right1, size=dsize, mode='bilinear', align_corners=True) + right1 = self.up1(right1) left = left1 * torch.sigmoid(right1) right = left2 * torch.sigmoid(right2) - right = F.interpolate( - right, size=dsize, mode='bilinear', align_corners=True) + right = self.up2(right) out = self.conv(left + right) return out @@ -262,38 +280,48 @@ def forward(self, x_d, x_s): class SegmentHead(nn.Module): - def __init__(self, in_chan, mid_chan, n_classes): + def __init__(self, in_chan, mid_chan, n_classes, up_factor=8, aux=True): super(SegmentHead, self).__init__() self.conv = ConvBNReLU(in_chan, mid_chan, 3, stride=1) self.drop = nn.Dropout(0.1) - self.conv_out = nn.Conv2d( - mid_chan, n_classes, kernel_size=1, stride=1, - padding=0, bias=True) + self.up_factor = up_factor + + out_chan = n_classes * up_factor * up_factor + if aux: + self.conv_out = nn.Sequential( + ConvBNReLU(mid_chan, up_factor * up_factor, 3, stride=1), + nn.Conv2d(up_factor * up_factor, out_chan, 1, 1, 0), + nn.PixelShuffle(up_factor) + ) + else: + self.conv_out = nn.Sequential( + nn.Conv2d(mid_chan, out_chan, 1, 1, 0), + nn.PixelShuffle(up_factor) + ) - def forward(self, x, size=None): + def forward(self, x): feat = self.conv(x) feat = self.drop(feat) feat = self.conv_out(feat) - if not size is None: - feat = F.interpolate(feat, size=size, - mode='bilinear', align_corners=True) return feat class BiSeNetV2(nn.Module): - def __init__(self, n_classes): + def __init__(self, n_classes, output_aux=True): super(BiSeNetV2, self).__init__() + self.output_aux = output_aux self.detail = DetailBranch() self.segment = SegmentBranch() self.bga = BGALayer() ## TODO: what is the number of mid chan ? - self.head = SegmentHead(128, 1024, n_classes) - self.aux2 = SegmentHead(16, 128, n_classes) - self.aux3 = SegmentHead(32, 128, n_classes) - self.aux4 = SegmentHead(64, 128, n_classes) - self.aux5_4 = SegmentHead(128, 128, n_classes) + self.head = SegmentHead(128, 1024, n_classes, up_factor=8, aux=False) + if self.output_aux: + self.aux2 = SegmentHead(16, 128, n_classes, up_factor=4) + self.aux3 = SegmentHead(32, 128, n_classes, up_factor=8) + self.aux4 = SegmentHead(64, 128, n_classes, up_factor=16) + self.aux5_4 = SegmentHead(128, 128, n_classes, up_factor=32) self.init_weights() @@ -303,12 +331,15 @@ def forward(self, x): feat2, feat3, feat4, feat5_4, feat_s = self.segment(x) feat_head = self.bga(feat_d, feat_s) - logits = self.head(feat_head, size) - logits_aux2 = self.aux2(feat2, size) - logits_aux3 = self.aux3(feat3, size) - logits_aux4 = self.aux4(feat4, size) - logits_aux5_4 = self.aux5_4(feat5_4, size) - return logits, logits_aux2, logits_aux3, logits_aux4, logits_aux5_4 + logits = self.head(feat_head) + if self.output_aux: + logits_aux2 = self.aux2(feat2) + logits_aux3 = self.aux3(feat3) + logits_aux4 = self.aux4(feat4) + logits_aux5_4 = self.aux5_4(feat5_4) + return logits, logits_aux2, logits_aux3, logits_aux4, logits_aux5_4 + pred = logits.argmax(dim=1) + return pred def init_weights(self): for name, module in self.named_modules(): @@ -365,11 +396,13 @@ def init_weights(self): # feat = segment(x)[0] # print(feat.size()) # - x = torch.randn(16, 3, 512, 1024) + x = torch.randn(16, 3, 1024, 2048) model = BiSeNetV2(n_classes=19) - logits = model(x)[0] - print(logits.size()) + outs = model(x) + for out in outs: + print(out.size()) + # print(logits.size()) - for name, param in model.named_parameters(): - if len(param.size()) == 1: - print(name) + # for name, param in model.named_parameters(): + # if len(param.size()) == 1: + # print(name) diff --git a/lib/transform_cv2.py b/lib/transform_cv2.py index 5dea3e6..02a324d 100644 --- a/lib/transform_cv2.py +++ b/lib/transform_cv2.py @@ -129,13 +129,14 @@ def __init__(self, mean=(0, 0, 0), std=(1., 1., 1.)): def __call__(self, im_lb): im, lb = im_lb['im'], im_lb['lb'] - im = im[:, :, ::-1].transpose(2, 0, 1).astype(np.float32) # to rgb order + im = im.transpose(2, 0, 1).astype(np.float32) im = torch.from_numpy(im).div_(255) dtype, device = im.dtype, im.device mean = torch.as_tensor(self.mean, dtype=dtype, device=device)[:, None, None] std = torch.as_tensor(self.std, dtype=dtype, device=device)[:, None, None] im = im.sub_(mean).div_(std).clone() - lb = torch.from_numpy(lb.astype(np.int64).copy()).clone() + if not lb is None: + lb = torch.from_numpy(lb.astype(np.int64).copy()).clone() return dict(im=im, lb=lb) @@ -153,87 +154,5 @@ def __call__(self, im_lb): if __name__ == '__main__': - # from PIL import Image - # im = Image.open(imgpth) - # lb = Image.open(lbpth) - # print(lb.size) - # im.show() - # lb.show() - import cv2 - im = cv2.imread(imgpth) - lb = cv2.imread(lbpth, 0) - lb = lb * 10 - - trans = Compose([ - RandomHorizontalFlip(), - RandomShear(p=0.5, rate=3), - RandomRotate(p=0.5, degree=5), - RandomScale([0.5, 0.7]), - RandomCrop((768, 768)), - RandomErasing(p=1, size=(36, 36)), - ChannelShuffle(p=1), - ColorJitter( - brightness=0.3, - contrast=0.3, - saturation=0.5 - ), - # RandomEqualize(p=0.1), - ]) - # inten = dict(im=im, lb=lb) - # out = trans(inten) - # im = out['im'] - # lb = out['lb'] - # cv2.imshow('lb', lb) - # cv2.imshow('org', im) - # cv2.waitKey(0) - - - ### try merge rotate and shear here - im = cv2.imread(imgpth) - lb = cv2.imread(lbpth, 0) - im = cv2.resize(im, (1024, 512)) - lb = cv2.resize(lb, (1024, 512), interpolation=cv2.INTER_NEAREST) - lb = lb * 10 - inten = dict(im=im, lb=lb) - trans1 = Compose([ - RandomShear(p=1, rate=0.15), - # RandomRotate(p=1, degree=10), - ]) - trans2 = Compose([ - # RandomShearRotate(p_shear=1, p_rot=0, rate_shear=0.1, rot_degree=9), - RandomHFlipShearRotate(p_flip=0.5, p_shear=1, p_rot=0, rate_shear=0.1, rot_degree=9), - ]) - out1 = trans1(inten) - im1 = out1['im'] - lb1 = out1['lb'] - # cv2.imshow('lb', lb1) - cv2.imshow('org1', im1) - out2 = trans2(inten) - im2 = out2['im'] - lb2 = out2['lb'] - # cv2.imshow('lb', lb1) - # cv2.imshow('org2', im2) - cv2.waitKey(0) - print(np.sum(im1-im2)) - print('====') - #### - - - totensor = ToTensor( - mean=(0.406, 0.456, 0.485), - std=(0.225, 0.224, 0.229) - ) - # print(im[0, :2, :2]) - print(lb[:2, :2]) - out = totensor(out) - im = out['im'] - lb = out['lb'] - print(im.size()) - # print(im[0, :2, :2]) - # print(lb[:2, :2]) - - out = totensor(inten) - im = out['im'] - print(im.size()) - print(im[0, 502:504, 766:768]) + pass diff --git a/tensorrt/CMakeLists.txt b/tensorrt/CMakeLists.txt new file mode 100644 index 0000000..3e9bcbe --- /dev/null +++ b/tensorrt/CMakeLists.txt @@ -0,0 +1,21 @@ +CMAKE_MINIMUM_REQUIRED(VERSION 2.8) + +PROJECT(segment) + +set(CMAKE_CXX_FLAGS "-std=c++14 -O1") + + +link_directories(/usr/local/cuda/lib64) + + +find_package(CUDA REQUIRED) +find_package(OpenCV REQUIRED) + +add_executable(segment segment.cpp trt_dep.cpp) +target_include_directories( + segment PUBLIC ${CUDA_INCLUDE_DIRS} ${CUDNN_INCLUDE_DIRS} ${OpenCV_INCLUDE_DIRS}) +target_link_libraries( + segment -lnvinfer -lnvinfer_plugin -lnvparsers -lnvonnxparser + ${CUDA_LIBRARIES} + ${OpenCV_LIBRARIES} + ) diff --git a/tensorrt/README.md b/tensorrt/README.md new file mode 100644 index 0000000..9a1dca0 --- /dev/null +++ b/tensorrt/README.md @@ -0,0 +1,64 @@ + +### My platform + +* ubuntu 18.04 +* nvidia Tesla T4 gpu, driver 450.51.05 +* cuda 10.2, cudnn 7 +* cmake 3.10.2 +* opencv built from source +* tensorrt 7.0.0 + + +### Export model to onnx +I export the model like this: +``` +$ python tools/export_onnx.py --model bisenetv1 --weight-path /path/to/your/model.pth --outpath ./model.onnx +``` + +**NOTE:** I use cropsize of `1024x2048` here in my example, you should change it according to your specific application. The inference cropsize is fixed from this step on, so you should decide the inference cropsize when you export the model here. + +### Build with source code +Just use the standard cmake build method: +``` +mkdir -p tensorrt/build +cd tensorrt/build +cmake .. +make +``` +This would generate a `./segment` in the `tensorrt/build` directory. + + +### Convert onnx to tensorrt model +If you can successfully compile the source code, you can parse the onnx model to tensorrt model like this: +``` +$ ./segment compile /path/to/onnx.model /path/to/saved_model.trt +``` +If your gpu support acceleration with fp16 inferenece, you can add a `--fp16` option to in this step: +``` +$ ./segment compile /path/to/onnx.model /path/to/saved_model.trt --fp16 +``` +Note that I use the simplest method to parse the command line args, so please do **Not** change the order of the above command. + + +### Infer with one single image +Run inference like this: +``` +$ ./segment run /path/to/saved_model.trt /path/to/input/image.jpg /path/to/saved_img.jpg +``` + +### Test speed +The speed depends on the specific gpu platform you are working on, you can test the fps on your gpu like this: +``` +$ ./segment test /path/to/saved_model.trt +``` + + +## Tips: +1. Since tensorrt 7.0.0 cannot parse well the `bilinear interpolation` op exported from pytorch, I replace them with pytorch `nn.PixelShuffle`, which would bring some performance overhead(more flops and parameters), and make inference a bit slower. Also due to the `nn.PixelShuffle` op, you **must** export the onnx model with input size to be *n* times of 32. + +2. There would be some problem for tensorrt 7.0.0 to parse the `nn.AvgPool2d` op from pytorch with onnx opset11. So I use opset10 to export the model. + +4. The speed(fps) is tested on a single nvidia Tesla T4 gpu with `batchsize=1` and `cropsize=(1024,2048)`. Please note that T4 gpu is almost 2 times slower than 2080ti, you should evaluate the speed considering your own platform and cropsize. Also note that the performance would be affected if your gpu is concurrently working on other tasks. Please make sure no other program is running on your gpu when you test the speed. + +5. On my platform, after compiling with tensorrt, the model size of bisenetv1 is 33Mb(fp16) and 133Mb(fp32), and the size of bisenetv2 is 29Mb(fp16) and 54Mb(fp32). However, the fps of bisenetv1 is 60(fp16) and 19(fp32), while the fps of bisenetv2 is 50(fp16) and 16(fp32). It is obvious that bisenetv2 has fewer parameters than bisenetv1, but the speed is otherwise. I am not sure whether it is because tensorrt has worse optimization strategy in some ops used in bisenetv2(such as depthwise convolution) or because of the limitation of the gpu on different ops. Please tell me if you have better idea on this. + diff --git a/tensorrt/segment.cpp b/tensorrt/segment.cpp new file mode 100644 index 0000000..47f86d2 --- /dev/null +++ b/tensorrt/segment.cpp @@ -0,0 +1,182 @@ +#include "NvInfer.h" +#include "NvOnnxParser.h" +#include "NvInferPlugin.h" +#include +#include "NvInferRuntimeCommon.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "trt_dep.hpp" + + +using nvinfer1::IHostMemory; +using nvinfer1::IBuilder; +using nvinfer1::INetworkDefinition; +using nvinfer1::ICudaEngine; +using nvinfer1::IInt8Calibrator; +using nvinfer1::IBuilderConfig; +using nvinfer1::IRuntime; +using nvinfer1::IExecutionContext; +using nvinfer1::ILogger; +using nvinfer1::Dims3; +using nvinfer1::Dims2; +using Severity = nvinfer1::ILogger::Severity; + +using std::string; +using std::ios; +using std::ofstream; +using std::ifstream; +using std::vector; +using std::cout; +using std::endl; +using std::array; + +using cv::Mat; + + + + +vector> get_color_map(); + +void compile_onnx(vector args); +void run_with_trt(vector args); +void test_speed(vector args); + + +int main(int argc, char* argv[]) { + if (argc < 3) { + cout << "usage is ./segment compile/run/test\n"; + std::abort(); + } + + vector args; + for (int i{1}; i < argc; ++i) args.emplace_back(argv[i]); + + if (args[0] == "compile") { + if (argc < 4) { + cout << "usage is: ./segment compile input.onnx output.trt [--fp16]\n"; + std::abort(); + } + compile_onnx(args); + } else if (args[0] == "run") { + if (argc < 5) { + cout << "usage is ./segment run ./xxx.trt input.jpg result.jpg\n"; + std::abort(); + } + run_with_trt(args); + } else if (args[0] == "test") { + if (argc < 3) { + cout << "usage is ./segment test ./xxx.trt\n"; + std::abort(); + } + test_speed(args); + } + + return 0; +} + + +void compile_onnx(vector args) { + bool use_fp16{false}; + if ((args.size() >= 4) && args[3] == "--fp16") use_fp16 = true; + + TrtSharedEnginePtr engine = parse_to_engine(args[1], use_fp16); + serialize(engine, args[2]); +} + + +void run_with_trt(vector args) { + + TrtSharedEnginePtr engine = deserialize(args[1]); + + Dims3 i_dims = static_cast( + engine->getBindingDimensions(engine->getBindingIndex("input_image"))); + Dims3 o_dims = static_cast( + engine->getBindingDimensions(engine->getBindingIndex("preds"))); + const int iH{i_dims.d[2]}, iW{i_dims.d[3]}; + const int oH{o_dims.d[1]}, oW{o_dims.d[2]}; + + // prepare image and resize + Mat im = cv::imread(args[2]); + if (im.empty()) { + cout << "cannot read image \n"; + std::abort(); + } + // CHECK (!im.empty()) << "cannot read image \n"; + int orgH{im.rows}, orgW{im.cols}; + if ((orgH != iH) || orgW != iW) { + cout << "resize orignal image of (" << orgH << "," << orgW + << ") to (" << iH << ", " << iW << ") according to model require\n"; + cv::resize(im, im, cv::Size(iW, iH), cv::INTER_CUBIC); + } + + // normalize and convert to rgb + array mean{0.485f, 0.456f, 0.406f}; + array variance{0.229f, 0.224f, 0.225f}; + float scale = 1.f / 255.f; + for (int i{0}; i < 3; ++ i) { + variance[i] = 1.f / variance[i]; + } + vector data(iH * iW * 3); + for (int h{0}; h < iH; ++h) { + cv::Vec3b *p = im.ptr(h); + for (int w{0}; w < iW; ++w) { + for (int c{0}; c < 3; ++c) { + int idx = (2 - c) * iH * iW + h * iW + w; // to rgb order + data[idx] = (p[w][c] * scale - mean[c]) * variance[c]; + } + } + } + + // call engine + vector res = infer_with_engine(engine, data); + + // generate colored out + vector> color_map = get_color_map(); + Mat pred(cv::Size(oW, oH), CV_8UC3); + int idx{0}; + for (int i{0}; i < oH; ++i) { + uint8_t *ptr = pred.ptr(i); + for (int j{0}; j < oW; ++j) { + ptr[0] = color_map[res[idx]][0]; + ptr[1] = color_map[res[idx]][1]; + ptr[2] = color_map[res[idx]][2]; + ptr += 3; + ++ idx; + } + } + + // resize back and save + if ((orgH != oH) || orgW != oW) { + cv::resize(pred, pred, cv::Size(orgW, orgH), cv::INTER_NEAREST); + } + cv::imwrite(args[3], pred); + +} + + +vector> get_color_map() { + vector> color_map(256, vector(3)); + std::minstd_rand rand_eng(123); + std::uniform_int_distribution u(0, 255); + for (int i{0}; i < 256; ++i) { + for (int j{0}; j < 3; ++j) { + color_map[i][j] = u(rand_eng); + } + } + return color_map; +} + + +void test_speed(vector args) { + TrtSharedEnginePtr engine = deserialize(args[1]); + test_fps_with_engine(engine); +} diff --git a/tensorrt/trt_dep.cpp b/tensorrt/trt_dep.cpp new file mode 100644 index 0000000..40741c5 --- /dev/null +++ b/tensorrt/trt_dep.cpp @@ -0,0 +1,249 @@ + +#include +#include +#include +#include +#include +#include +#include + +#include "trt_dep.hpp" + + +using nvinfer1::IHostMemory; +using nvinfer1::IBuilder; +using nvinfer1::INetworkDefinition; +using nvinfer1::ICudaEngine; +using nvinfer1::IInt8Calibrator; +using nvinfer1::IBuilderConfig; +using nvinfer1::IRuntime; +using nvinfer1::IExecutionContext; +using nvinfer1::ILogger; +using nvinfer1::Dims3; +using nvinfer1::Dims2; +using Severity = nvinfer1::ILogger::Severity; + +using std::string; +using std::ios; +using std::ofstream; +using std::ifstream; +using std::vector; +using std::cout; +using std::endl; +using std::array; + + +Logger gLogger; + + +TrtSharedEnginePtr shared_engine_ptr(ICudaEngine* ptr) { + return TrtSharedEnginePtr(ptr, TrtDeleter()); +} + + +TrtSharedEnginePtr parse_to_engine(string onnx_pth, bool use_fp16) { + unsigned int maxBatchSize{1}; + int memory_limit = 1U << 30; // 1G + + auto builder = TrtUniquePtr(nvinfer1::createInferBuilder(gLogger)); + if (!builder) { + cout << "create builder failed\n"; + std::abort(); + } + + const auto explicitBatch = 1U << static_cast( + nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); + auto network = TrtUniquePtr( + builder->createNetworkV2(explicitBatch)); + if (!network) { + cout << "create network failed\n"; + std::abort(); + } + + auto config = TrtUniquePtr(builder->createBuilderConfig()); + if (!config) { + cout << "create builder config failed\n"; + std::abort(); + } + + auto parser = TrtUniquePtr(nvonnxparser::createParser(*network, gLogger)); + if (!parser) { + cout << "create parser failed\n"; + std::abort(); + } + + int verbosity = (int)nvinfer1::ILogger::Severity::kWARNING; + bool state = parser->parseFromFile(onnx_pth.c_str(), verbosity); + if (!state) { + cout << "parse model failed\n"; + std::abort(); + } + + + config->setMaxWorkspaceSize(memory_limit); + if (use_fp16 && builder->platformHasFastFp16()) { + config->setFlag(nvinfer1::BuilderFlag::kFP16); // fp16 + } + // TODO: see if use dla or int8 + + auto output = network->getOutput(0); + output->setType(nvinfer1::DataType::kINT32); + + TrtSharedEnginePtr engine = shared_engine_ptr( + builder->buildEngineWithConfig(*network, *config)); + if (!engine) { + cout << "create engine failed\n"; + std::abort(); + } + + return engine; +} + + +void serialize(TrtSharedEnginePtr engine, string save_path) { + + auto trt_stream = TrtUniquePtr(engine->serialize()); + if (!trt_stream) { + cout << "serialize engine failed\n"; + std::abort(); + } + + ofstream ofile(save_path, ios::out | ios::binary); + ofile.write((const char*)trt_stream->data(), trt_stream->size()); + + ofile.close(); +} + + +TrtSharedEnginePtr deserialize(string serpth) { + + ifstream ifile(serpth, ios::in | ios::binary); + if (!ifile) { + cout << "read serialized file failed\n"; + std::abort(); + } + + ifile.seekg(0, ios::end); + const int mdsize = ifile.tellg(); + ifile.clear(); + ifile.seekg(0, ios::beg); + vector buf(mdsize); + ifile.read(&buf[0], mdsize); + ifile.close(); + cout << "model size: " << mdsize << endl; + + auto runtime = TrtUniquePtr(nvinfer1::createInferRuntime(gLogger)); + TrtSharedEnginePtr engine = shared_engine_ptr( + runtime->deserializeCudaEngine((void*)&buf[0], mdsize, nullptr)); + return engine; +} + + +vector infer_with_engine(TrtSharedEnginePtr engine, vector& data) { + Dims3 out_dims = static_cast( + engine->getBindingDimensions(engine->getBindingIndex("preds"))); + + const int batchsize{1}, H{out_dims.d[1]}, W{out_dims.d[2]}; + const int in_size{static_cast(data.size())}; + const int out_size{batchsize * H * W}; + vector buffs(2); + vector res(out_size); + + auto context = TrtUniquePtr(engine->createExecutionContext()); + if (!context) { + cout << "create execution context failed\n"; + std::abort(); + } + + cudaError_t state; + state = cudaMalloc(&buffs[0], in_size * sizeof(float)); + if (state) { + cout << "allocate memory failed\n"; + std::abort(); + } + state = cudaMalloc(&buffs[1], out_size * sizeof(int)); + if (state) { + cout << "allocate memory failed\n"; + std::abort(); + } + cudaStream_t stream; + state = cudaStreamCreate(&stream); + if (state) { + cout << "create stream failed\n"; + std::abort(); + } + + state = cudaMemcpyAsync( + buffs[0], &data[0], in_size * sizeof(float), + cudaMemcpyHostToDevice, stream); + if (state) { + cout << "transmit to device failed\n"; + std::abort(); + } + context->enqueueV2(&buffs[0], stream, nullptr); + // context->enqueue(1, &buffs[0], stream, nullptr); + state = cudaMemcpyAsync( + &res[0], buffs[1], out_size * sizeof(int), + cudaMemcpyDeviceToHost, stream); + if (state) { + cout << "transmit to host failed \n"; + std::abort(); + } + cudaStreamSynchronize(stream); + + cudaFree(buffs[0]); + cudaFree(buffs[1]); + cudaStreamDestroy(stream); + + return res; +} + + +void test_fps_with_engine(TrtSharedEnginePtr engine) { + Dims3 in_dims = static_cast( + engine->getBindingDimensions(engine->getBindingIndex("input_image"))); + Dims3 out_dims = static_cast( + engine->getBindingDimensions(engine->getBindingIndex("preds"))); + const int batchsize{1}; + const int oH{out_dims.d[1]}, oW{out_dims.d[2]}; + const int iH{in_dims.d[2]}, iW{in_dims.d[3]}; + const int in_size{batchsize * 3 * iH * iW}; + const int out_size{batchsize * oH * oW}; + + auto context = TrtUniquePtr(engine->createExecutionContext()); + if (!context) { + cout << "create execution context failed\n"; + std::abort(); + } + + vector buffs(2); + cudaError_t state; + state = cudaMalloc(&buffs[0], in_size * sizeof(float)); + if (state) { + cout << "allocate memory failed\n"; + std::abort(); + } + state = cudaMalloc(&buffs[1], out_size * sizeof(int)); + if (state) { + cout << "allocate memory failed\n"; + std::abort(); + } + + cout << "\ntest with cropsize of (" << iH << ", " << iW << ") ...\n"; + auto start = std::chrono::steady_clock::now(); + const int n_loops{1000}; + for (int i{0}; i < n_loops; ++i) { + // context->execute(1, &buffs[0]); + context->executeV2(&buffs[0]); + } + auto end = std::chrono::steady_clock::now(); + double duration = std::chrono::duration(end - start).count(); + duration /= 1000.; + cout << "running " << n_loops << " times, use time: " + << duration << "s" << endl; + cout << "fps is: " << static_cast(n_loops) / duration << endl; + + cudaFree(buffs[0]); + cudaFree(buffs[1]); +} + diff --git a/tensorrt/trt_dep.hpp b/tensorrt/trt_dep.hpp new file mode 100644 index 0000000..94a61b7 --- /dev/null +++ b/tensorrt/trt_dep.hpp @@ -0,0 +1,57 @@ +#ifndef _TRT_DEP_HPP_ +#define _TRT_DEP_HPP_ + +#include "NvInfer.h" +#include "NvOnnxParser.h" +#include "NvInferPlugin.h" +#include +#include "NvInferRuntimeCommon.h" + +#include +#include +#include +#include + + +using std::string; +using std::vector; +using std::cout; +using std::endl; + +using nvinfer1::ICudaEngine; +using nvinfer1::ILogger; +using Severity = nvinfer1::ILogger::Severity; + + +class Logger: public ILogger { + public: + void log(Severity severity, const char* msg) override { + if (severity != Severity::kINFO) { + std::cout << msg << std::endl; + } + } +}; + +struct TrtDeleter { + template + void operator()(T* obj) const { + if (obj) {obj->destroy();} + } +}; + +template +using TrtUniquePtr = std::unique_ptr; +using TrtSharedEnginePtr = std::shared_ptr; + + +extern Logger gLogger; + + +TrtSharedEnginePtr shared_engine_ptr(ICudaEngine* ptr); +TrtSharedEnginePtr parse_to_engine(string onnx_path, bool use_fp16); +void serialize(TrtSharedEnginePtr engine, string save_path); +TrtSharedEnginePtr deserialize(string serpth); +vector infer_with_engine(TrtSharedEnginePtr engine, vector& data); +void test_fps_with_engine(TrtSharedEnginePtr engine); + +#endif diff --git a/tools/conver_to_trt.py b/tools/conver_to_trt.py new file mode 100644 index 0000000..eb3dcf6 --- /dev/null +++ b/tools/conver_to_trt.py @@ -0,0 +1,54 @@ +import argparse +import os.path as osp +import sys +sys.path.insert(0, '.') + +import torch +from torch2trt import torch2trt + +from lib.models import model_factory +from configs import cfg_factory + +torch.set_grad_enabled(False) + + +parse = argparse.ArgumentParser() +parse.add_argument('--config', dest='config', type=str, default='deeplab_cityscapes',) +parse.add_argument('--weight-path', dest='weight_pth', type=str, + default='model_final.pth') +parse.add_argument('--outpath', dest='out_pth', type=str, + default='model.onnx') +args = parse.parse_args() + + +cfg = cfg_factory[args.config] +if cfg.use_sync_bn: cfg.use_sync_bn = False + +net = model_factory[cfg.model_type](19, output_aux=False).cuda() +# net.load_state_dict(torch.load(args.weight_pth)) +net.eval() + + +# dummy_input = torch.randn(1, 3, *cfg.crop_size) +dummy_input = torch.randn(1, 3, 1024, 2048).cuda() +input_names = ['input_image'] +output_names = ['preds',] + +trt_model = torch2trt(net, [dummy_input, ]) +# torch.onnx.export(net, dummy_input, args.out_pth, +# input_names=input_names, output_names=output_names, +# verbose=False, opset_version=11) +# +# +# import onnx +# import onnxruntime as ort +# +# print('checking {}'.format(args.out_pth)) +# onnx_obj = onnx.load(args.out_pth) +# print('model loaded') +# onnx.checker.check_model(onnx_obj) +# +# sess = ort.InferenceSession(args.out_pth, None) +# print(sess.get_outputs()[0].name) +# print(sess.get_outputs()[0].shape) +# print(len(sess.get_outputs())) diff --git a/tools/demo.py b/tools/demo.py index e37e636..15d94e9 100644 --- a/tools/demo.py +++ b/tools/demo.py @@ -4,11 +4,11 @@ import argparse import torch import torch.nn as nn -import torchvision.transforms as transforms from PIL import Image import numpy as np import cv2 +import lib.transform_cv2 as T from lib.models import model_factory from configs import cfg_factory @@ -25,10 +25,7 @@ cfg = cfg_factory[args.model] -# palette and mean/std palette = np.random.randint(0, 256, (256, 3), dtype=np.uint8) -mean = torch.tensor([0.3257, 0.3690, 0.3223], dtype=torch.float32).view(-1, 1, 1) -std = torch.tensor([0.2112, 0.2148, 0.2115], dtype=torch.float32).view(-1, 1, 1) # define model net = model_factory[cfg.model_type](19) @@ -37,9 +34,12 @@ net.cuda() # prepare data -im = cv2.imread(args.img_path) -im = im[:, :, ::-1].transpose(2, 0, 1).astype(np.float32) -im = torch.from_numpy(im).div_(255).sub_(mean).div_(std).unsqueeze(0).cuda() +to_tensor = T.ToTensor( + mean=(0.3257, 0.3690, 0.3223), # city, rgb + std=(0.2112, 0.2148, 0.2115), +) +im = cv2.imread(args.img_path)[:, :, ::-1] +im = to_tensor(dict(im=im, lb=None))['im'].unsqueeze(0).cuda() # inference out = net(im)[0].argmax(dim=1).squeeze().detach().cpu().numpy() diff --git a/tools/export_onnx.py b/tools/export_onnx.py new file mode 100644 index 0000000..a9b7458 --- /dev/null +++ b/tools/export_onnx.py @@ -0,0 +1,39 @@ +import argparse +import os.path as osp +import sys +sys.path.insert(0, '.') + +import torch + +from lib.models import model_factory +from configs import cfg_factory + +torch.set_grad_enabled(False) + + +parse = argparse.ArgumentParser() +parse.add_argument('--model', dest='model', type=str, default='bisenetv1',) +parse.add_argument('--weight-path', dest='weight_pth', type=str, + default='model_final.pth') +parse.add_argument('--outpath', dest='out_pth', type=str, + default='model.onnx') +args = parse.parse_args() + + +cfg = cfg_factory[args.model] +if cfg.use_sync_bn: cfg.use_sync_bn = False + +net = model_factory[cfg.model_type](19, output_aux=False) +net.load_state_dict(torch.load(args.weight_pth), strict=False) +net.eval() + + +# dummy_input = torch.randn(1, 3, *cfg.crop_size) +dummy_input = torch.randn(1, 3, 1024, 2048) +input_names = ['input_image'] +output_names = ['preds',] + +torch.onnx.export(net, dummy_input, args.out_pth, + input_names=input_names, output_names=output_names, + verbose=False, opset_version=10) + diff --git a/tools/train.py b/tools/train.py index cf42910..32a8625 100644 --- a/tools/train.py +++ b/tools/train.py @@ -128,16 +128,6 @@ def set_meters(): return time_meter, loss_meter, loss_pre_meter, loss_aux_meters -def save_model(states, save_pth): - logger = logging.getLogger() - logger.info('\nsave models to {}'.format(save_pth)) - for name, state in states.items(): - save_name = 'model_final_{}.pth'.format(name) - modelpth = osp.join(save_pth, save_name) - if dist.is_initialized() and dist.get_rank() == 0: - torch.save(state, modelpth) - - def train(): logger = logging.getLogger() is_dist = dist.is_initialized() diff --git a/tools/train_amp.py b/tools/train_amp.py index 284fc23..04e9721 100644 --- a/tools/train_amp.py +++ b/tools/train_amp.py @@ -99,6 +99,7 @@ def set_model_dist(net): net = nn.parallel.DistributedDataParallel( net, device_ids=[local_rank, ], + # find_unused_parameters=True, output_device=local_rank) return net @@ -112,15 +113,6 @@ def set_meters(): return time_meter, loss_meter, loss_pre_meter, loss_aux_meters -def save_model(states, save_pth): - logger = logging.getLogger() - logger.info('\nsave models to {}'.format(save_pth)) - for name, state in states.items(): - save_name = 'model_final_{}.pth'.format(name) - modelpth = osp.join(save_pth, save_name) - if dist.is_initialized() and dist.get_rank() == 0: - torch.save(state, modelpth) - def train(): logger = logging.getLogger()