
Commit

add tensorrt
CoinCheung committed Sep 29, 2020
1 parent aa3876b commit 12cd095
Showing 18 changed files with 815 additions and 174 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -110,4 +110,5 @@ play.py
preprocess_data.py
res/
adj.md
tensorrt/build/*

25 changes: 20 additions & 5 deletions README.md
@@ -4,20 +4,23 @@ My implementation of [BiSeNetV1](https://arxiv.org/abs/1808.00897) and [BiSeNetV


The mIOU evaluation results of the models trained and evaluated on the cityscapes train/val set are:
| none | ss | ssc | msf | mscf | fps | link |
| none | ss | ssc | msf | mscf | fps(fp16/fp32) | link |
|------|:--:|:---:|:---:|:----:|:---:|:----:|
| bisenetv1 | 74.85 | 76.46 | 77.36 | 78.72 | - | [download](https://drive.google.com/file/d/1e1_E7OrpjTaD5Rael7Fus5lg-uGZ5TUZ/view?usp=sharing) |
| bisenetv2 | 74.39 | 74.44 | 76.10 | 75.94 | - | [download](https://drive.google.com/file/d/1r_F-KZg-3s2pPcHRIuHZhZ0DQ0wocudk/view?usp=sharing) |
| bisenetv1 | 75.55 | 76.90 | 77.40 | 78.91 | 60/19 | [download](https://drive.google.com/file/d/140MBBAt49N1z1wsKueoFA6HB_QuYud8i/view?usp=sharing) |
| bisenetv2 | 74.12 | 74.18 | 75.89 | 75.87 | 50/16 | [download](https://drive.google.com/file/d/1qq38u9JT4pp1ubecGLTCHHtqwntH0FCY/view?usp=sharing) |

> Where **ss** means single scale evaluation, **ssc** means single scale crop evaluation, **msf** means multi-scale evaluation with flip augment, and **mscf** means multi-scale crop evaluation with flip augment. The eval scales of multi-scale evaluation are `[0.5, 0.75, 1.0, 1.25, 1.5, 1.75]`, and the crop size of crop evaluation is `[1024, 1024]`.
> The fps is tested in a different way from the paper. For more information, please see [here](./tensorrt).
Note that the results have a relatively large variance: training the same model several times can produce noticeably different scores. For example, if you train bisenetv2 many times, you will observe its **ss** evaluation result vary between 72.1 and 74.4.
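
To make these modes concrete, here is a minimal sketch of multi-scale flip (**msf**) evaluation. It is illustrative only, not the exact code of `tools/evaluate.py`; it assumes a `net` whose train-mode forward returns a tuple with the main logits first.

```
import torch
import torch.nn.functional as F

@torch.no_grad()
def eval_msf(net, im, n_classes=19, scales=(0.5, 0.75, 1.0, 1.25, 1.5, 1.75)):
    # im: (N, 3, H, W) normalized image tensor; returns an (N, H, W) label map
    N, _, H, W = im.size()
    probs = torch.zeros(N, n_classes, H, W, device=im.device)
    for scale in scales:
        size = (int(H * scale), int(W * scale))
        im_s = F.interpolate(im, size=size, mode='bilinear', align_corners=True)
        for flip in (False, True):
            x = torch.flip(im_s, dims=(3,)) if flip else im_s
            out = net(x)
            logits = out[0] if isinstance(out, (tuple, list)) else out
            if flip:  # undo the mirror before accumulating
                logits = torch.flip(logits, dims=(3,))
            logits = F.interpolate(logits, size=(H, W), mode='bilinear',
                                   align_corners=True)
            probs += torch.softmax(logits, dim=1)
    return probs.argmax(dim=1)
```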


## platform
My platform is like this:
* ubuntu 16.04
* cuda 10.1.243
* ubuntu 18.04
* nvidia Tesla T4 gpu, driver 450.51.05
* cuda 10.2
* cudnn 7
* miniconda python 3.6.9
* pytorch 1.6.0
@@ -59,7 +62,12 @@ Then you need to change the field of `im_root` and `train/val_im_anns` in the co
In order to train the model, you can run a command like this:
```
$ export CUDA_VISIBLE_DEVICES=0,1
# if you want to train with apex
$ python -m torch.distributed.launch --nproc_per_node=2 tools/train.py --model bisenetv2 # or bisenetv1
# if you want to train with the pytorch-native fp16 (amp) feature from torch 1.6
$ python -m torch.distributed.launch --nproc_per_node=2 tools/train_amp.py --model bisenetv2 # or bisenetv1
```

Note that though `bisenetv2` has fewer flops, it requires many more training iterations, so the training time of `bisenetv1` is shorter.
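
For reference, `tools/train_amp.py` presumably follows the standard `torch.cuda.amp` recipe that shipped with torch 1.6. A minimal sketch of one training step (the optimizer, criteria, and data here are placeholders, not the script's exact code):

```
import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()  # dynamic loss scaling keeps fp16 gradients from underflowing

def train_step(net, optim, criteria, im, lb):
    optim.zero_grad()
    with autocast():  # the forward pass runs in mixed precision
        out, out16, out32 = net(im)  # main output plus two auxiliary heads
        loss = criteria(out, lb) + criteria(out16, lb) + criteria(out32, lb)
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optim)             # unscale gradients, then optimizer step
    scaler.update()                # adapt the scale for the next iteration
    return loss.item()
```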
Expand All @@ -70,6 +78,9 @@ You can also load the trained model weights and finetune from it:
```
$ export CUDA_VISIBLE_DEVICES=0,1
$ python -m torch.distributed.launch --nproc_per_node=2 tools/train.py --finetune-from ./res/model_final.pth --model bisenetv2 # or bisenetv1
# same with pytorch fp16 feature
$ python -m torch.distributed.launch --nproc_per_node=2 tools/train_amp.py --finetune-from ./res/model_final.pth --model bisenetv2 # or bisenetv1
```


@@ -79,6 +90,10 @@ You can also evaluate a trained model like this:
$ python tools/evaluate.py --model bisenetv1 --weight-path /path/to/your/weight.pth
```

## Infer with tensorrt
You can go to [tensorrt](./tensorrt) for details.

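A common route is to export the deploy-mode model to ONNX first and build a TensorRT engine from that. A hypothetical sketch (module path, input size, and opset are assumptions; the repo's actual conversion code lives under [tensorrt](./tensorrt)):

```
import torch
from lib.models.bisenetv1 import BiSeNetV1

net = BiSeNetV1(19, output_aux=False)  # deploy mode: single argmax output
state = torch.load('./res/model_final.pth', map_location='cpu')
net.load_state_dict(state, strict=False)  # aux heads are absent in this mode
net.eval()

dummy = torch.randn(1, 3, 1024, 2048)  # a cityscapes-sized input
torch.onnx.export(net, dummy, 'bisenetv1.onnx',
                  input_names=['input'], output_names=['preds'],
                  opset_version=11)
```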

### Be aware that this is the refactored version of the original codebase. You can go to the `old` directory for the original implementation.


2 changes: 1 addition & 1 deletion datasets/cityscapes/gtFine
2 changes: 1 addition & 1 deletion datasets/cityscapes/leftImg8bit
2 changes: 1 addition & 1 deletion lib/base_dataset.py
@@ -40,7 +40,7 @@ def __init__(self, dataroot, annpath, trans_func=None, mode='train'):

def __getitem__(self, idx):
impth, lbpth = self.img_paths[idx], self.lb_paths[idx]
img, label = cv2.imread(impth), cv2.imread(lbpth, 0)
img, label = cv2.imread(impth)[:, :, ::-1], cv2.imread(lbpth, 0)  # cv2 loads BGR; [:, :, ::-1] converts to RGB
if self.lb_map is not None:
label = self.lb_map[label]
im_lb = dict(im=img, lb=label)
75 changes: 50 additions & 25 deletions lib/models/bisenetv1.py
@@ -13,6 +13,7 @@


class ConvBNReLU(nn.Module):

def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
super(ConvBNReLU, self).__init__()
self.conv = nn.Conv2d(in_chan,
@@ -38,16 +39,39 @@ def init_weight(self):
if ly.bias is not None: nn.init.constant_(ly.bias, 0)


class UpSample(nn.Module):

def __init__(self, n_chan, factor=2):
super(UpSample, self).__init__()
out_chan = n_chan * factor * factor
self.proj = nn.Conv2d(n_chan, out_chan, 1, 1, 0)
self.up = nn.PixelShuffle(factor)
self.init_weight()

def forward(self, x):
feat = self.proj(x)
feat = self.up(feat)
return feat

def init_weight(self):
nn.init.xavier_normal_(self.proj.weight, gain=1.)

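# Shape sanity check (an illustrative sketch, not part of this file):
# nn.PixelShuffle(r) rearranges (N, C*r*r, H, W) into (N, C, r*H, r*W), so
# projecting to n_chan * factor**2 channels first makes UpSample a learnable,
# export-friendly substitute for interpolation:
#
#   up = UpSample(n_chan=128, factor=2)
#   x = torch.randn(1, 128, 32, 64)
#   assert up(x).shape == (1, 128, 64, 128)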

class BiSeNetOutput(nn.Module):
def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):

def __init__(self, in_chan, mid_chan, n_classes, up_factor=32, *args, **kwargs):
super(BiSeNetOutput, self).__init__()
self.up_factor = up_factor
out_chan = n_classes * up_factor * up_factor
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
self.conv_out = nn.Conv2d(mid_chan, out_chan, kernel_size=1, bias=True)
self.up = nn.PixelShuffle(up_factor)
self.init_weight()

def forward(self, x):
x = self.conv(x)
x = self.conv_out(x)
x = self.up(x)
return x

def init_weight(self):
@@ -79,7 +103,7 @@ def __init__(self, in_chan, out_chan, *args, **kwargs):

def forward(self, x):
feat = self.conv(x)
atten = F.avg_pool2d(feat, feat.size()[2:])
atten = torch.mean(feat, dim=(2, 3), keepdim=True)  # global average pool; exports more cleanly than F.avg_pool2d
atten = self.conv_atten(atten)
atten = self.bn_atten(atten)
atten = self.sigmoid_atten(atten)
@@ -102,28 +126,25 @@ def __init__(self, *args, **kwargs):
self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
self.up32 = nn.Upsample(scale_factor=2.)  # fixed 2x nearest-neighbor upsample
self.up16 = nn.Upsample(scale_factor=2.)

self.init_weight()

def forward(self, x):
H0, W0 = x.size()[2:]
feat8, feat16, feat32 = self.resnet(x)
H8, W8 = feat8.size()[2:]
H16, W16 = feat16.size()[2:]
H32, W32 = feat32.size()[2:]

avg = F.avg_pool2d(feat32, feat32.size()[2:])
avg = torch.mean(feat32, dim=(2, 3), keepdim=True)
avg = self.conv_avg(avg)
avg_up = F.interpolate(avg, (H32, W32), mode='nearest')

feat32_arm = self.arm32(feat32)
feat32_sum = feat32_arm + avg_up
feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
feat32_sum = feat32_arm + avg
feat32_up = self.up32(feat32_sum)
feat32_up = self.conv_head32(feat32_up)

feat16_arm = self.arm16(feat16)
feat16_sum = feat16_arm + feat32_up
feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
feat16_up = self.up16(feat16_sum)
feat16_up = self.conv_head16(feat16_up)

return feat16_up, feat32_up # x8, x16
Expand Down Expand Up @@ -203,7 +224,7 @@ def __init__(self, in_chan, out_chan, *args, **kwargs):
def forward(self, fsp, fcp):
fcat = torch.cat([fsp, fcp], dim=1)
feat = self.convblk(fcat)
atten = F.avg_pool2d(feat, feat.size()[2:])
atten = torch.mean(feat, dim=(2, 3), keepdim=True)
atten = self.conv1(atten)
atten = self.relu(atten)
atten = self.conv2(atten)
@@ -231,14 +252,17 @@ def get_params(self):


class BiSeNetV1(nn.Module):
def __init__(self, n_classes, *args, **kwargs):

def __init__(self, n_classes, output_aux=True, *args, **kwargs):
super(BiSeNetV1, self).__init__()
self.cp = ContextPath()
self.sp = SpatialPath()
self.ffm = FeatureFusionModule(256, 256)
self.conv_out = BiSeNetOutput(256, 256, n_classes)
self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
self.conv_out = BiSeNetOutput(256, 256, n_classes, up_factor=8)
self.output_aux = output_aux
if self.output_aux:
self.conv_out16 = BiSeNetOutput(128, 64, n_classes, up_factor=8)
self.conv_out32 = BiSeNetOutput(128, 64, n_classes, up_factor=16)
self.init_weight()

def forward(self, x):
@@ -248,13 +272,12 @@ def forward(self, x):
feat_fuse = self.ffm(feat_sp, feat_cp8)

feat_out = self.conv_out(feat_fuse)
feat_out16 = self.conv_out16(feat_cp8)
feat_out32 = self.conv_out32(feat_cp16)

feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
return feat_out, feat_out16, feat_out32
if self.output_aux:
feat_out16 = self.conv_out16(feat_cp8)
feat_out32 = self.conv_out32(feat_cp16)
return feat_out, feat_out16, feat_out32
feat_out = feat_out.argmax(dim=1)
return feat_out

def init_weight(self):
for ly in self.children():
@@ -276,11 +299,13 @@ def get_params(self):


if __name__ == "__main__":
net = BiSeNet(19)
net = BiSeNetV1(19)
net.cuda()
net.eval()
in_ten = torch.randn(16, 3, 640, 480).cuda()
out, out16, out32 = net(in_ten)
print(out.shape)
print(out16.shape)
print(out32.shape)

net.get_params()
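
# Deploy-mode usage (an illustrative sketch, not part of this file): with
# output_aux=False the forward pass returns per-pixel class ids directly,
# the single-output form suited to the tensorrt deployment added in this commit:
#
#   net = BiSeNetV1(19, output_aux=False)
#   net.eval()
#   with torch.no_grad():
#       preds = net(torch.randn(1, 3, 1024, 2048))  # -> shape (1, 1024, 2048)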
