From d45d5195474c59ea903d58ad0207e228e3f3b757 Mon Sep 17 00:00:00 2001
From: coincheung <867153576@qq.com>
Date: Thu, 29 Nov 2018 13:26:55 +0800
Subject: [PATCH] model

---
 .gitignore  |   7 +-
 README      |  13 +++
 adj.md      |   1 +
 license.txt |  23 +++
 model.py    | 199 +++++++++++++++++++++++++++++++++++++++++++
 xception.py | 240 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 6 files changed, 482 insertions(+), 1 deletion(-)
 create mode 100644 README
 create mode 100644 adj.md
 create mode 100644 license.txt
 create mode 100644 model.py
 create mode 100644 xception.py

diff --git a/.gitignore b/.gitignore
index 894a44c..b38d235 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,7 +14,6 @@ dist/
 downloads/
 eggs/
 .eggs/
-lib/
 lib64/
 parts/
 sdist/
@@ -102,3 +101,9 @@ venv.bak/
 
 # mypy
 .mypy_cache/
+
+
+## Coin:
+gtFine_trainvaltest.zip
+gtFine/
+
diff --git a/README b/README
new file mode 100644
index 0000000..a6495c2
--- /dev/null
+++ b/README
@@ -0,0 +1,13 @@
+
+README and scripts
+------------------
+
+The README and various scripts for inspection, preparation, and evaluation can be found in our git repository:
+https://github.com/mcordts/cityscapesScripts
+
+Contact
+-------
+
+Marius Cordts, Mohamed Omran
+www.cityscapes-dataset.net
+mail@cityscapes-dataset.net
diff --git a/adj.md b/adj.md
new file mode 100644
index 0000000..03393e6
--- /dev/null
+++ b/adj.md
@@ -0,0 +1 @@
+1. see if xception should end with bn or relu or conv
diff --git a/license.txt b/license.txt
new file mode 100644
index 0000000..20f10ca
--- /dev/null
+++ b/license.txt
@@ -0,0 +1,23 @@
+----------------------
+The Cityscapes Dataset
+----------------------
+
+
+License agreement
+-----------------
+
+This dataset is made freely available to academic and non-academic entities for non-commercial purposes such as academic research, teaching, scientific publications, or personal experimentation. Permission is granted to use the data given that you agree:
+
+1. That the dataset comes "AS IS", without express or implied warranty. Although every effort has been made to ensure accuracy, we (Daimler AG, MPI Informatics, TU Darmstadt) do not accept any responsibility for errors or omissions.
+2. That you include a reference to the Cityscapes Dataset in any work that makes use of the dataset. For research papers, cite our preferred publication as listed on our website; for other media cite our preferred publication as listed on our website or link to the Cityscapes website.
+3. That you do not distribute this dataset or modified versions. It is permissible to distribute derivative works in as far as they are abstract representations of this dataset (such as models trained on it or additional annotations that do not directly include any of our data) and do not allow to recover the dataset or something similar in character.
+4. That you may not use the dataset or any derivative work for commercial purposes as, for example, licensing or selling the data, or using the data with a purpose to procure a commercial gain.
+5. That all rights not expressly granted to you are reserved by us (Daimler AG, MPI Informatics, TU Darmstadt).
+
+
+Contact
+-------
+
+Marius Cordts, Mohamed Omran
+www.cityscapes-dataset.net
+mail@cityscapes-dataset.net
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..96bcaf5
--- /dev/null
+++ b/model.py
@@ -0,0 +1,199 @@
+#!/usr/bin/python
+# -*- encoding: utf-8 -*-
+
+
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as modelzoo
+import torch.nn.functional as F
+import torchvision
+
+
+resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
+
+class ConvBNReLU(nn.Module):
+    def __init__(self, in_chan, out_chan, ks = 3, stride=1, *args, **kwargs):
+        super(ConvBNReLU, self).__init__()
+        self.conv = nn.Conv2d(in_chan,
+                out_chan,
+                kernel_size = ks,
+                stride=stride,
+                padding = ks // 2,  # keep the spatial size for odd kernel sizes
+                bias=True)
+        self.bn = nn.BatchNorm2d(out_chan)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = F.relu(x, inplace = True)
+        return x
+
+
+class SpatialPath(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super(SpatialPath, self).__init__()
+        self.conv1 = ConvBNReLU(3, 64, stride = 2)
+        self.conv2 = ConvBNReLU(64, 128, stride = 2)
+        self.conv3 = ConvBNReLU(128, 256, stride = 2)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.conv3(x)
+        return x
+
+
+class AttentionRefinementModule(nn.Module):
+    def __init__(self, in_chan, *args, **kwargs):
+        super(AttentionRefinementModule, self).__init__()
+        self.in_chan = in_chan
+        self.conv = nn.Conv2d(in_chan,
+                in_chan,
+                kernel_size = 1,
+                bias=True)
+        self.bn = nn.BatchNorm2d(in_chan)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        assert self.in_chan == x.size()[1]
+        in_ten = x
+        x = F.avg_pool2d(x, x.size()[2:])
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.sigmoid(x)
+        x = torch.mul(in_ten, x)
+        return x
+
+
+
+class ContextPath(nn.Module):
+    def __init__(self, n_classes = 10, *args, **kwargs):
+        super(ContextPath, self).__init__()
+        resnet = torchvision.models.resnet18()
+        self.conv1 = resnet.conv1
+        self.bn1 = resnet.bn1
+        self.relu = resnet.relu
+        self.maxpool = resnet.maxpool
+        self.layer1 = resnet.layer1
+        self.layer2 = resnet.layer2
+        self.layer3 = resnet.layer3
+        self.layer4 = resnet.layer4
+
+        self.arm16 = AttentionRefinementModule(256)
+        self.arm32 = AttentionRefinementModule(512)
+
+        self.conv_feat16 = nn.Conv2d(256,
+                n_classes,
+                kernel_size = 3,
+                bias=True)
+        self.conv_feat32 = nn.Conv2d(512,
+                n_classes,
+                kernel_size = 3,
+                bias=True)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+        x = self.layer1(x)
+        x = self.layer2(x)
+
+        feat16 = self.layer3(x)    # 1/16 of the input size
+        feat32 = self.layer4(feat16)    # 1/32 of the input size
+        avg = F.avg_pool2d(feat32, feat32.size()[2:])    # global context vector
+        feat16_arm = self.arm16(feat16)
+        feat32_arm = self.arm32(feat32)
+
+        feat32_with_avg = torch.mul(feat32_arm, avg)
+        feat32_up = F.interpolate(feat32_with_avg, scale_factor = 4)
+        feat16_up = F.interpolate(feat16_arm, scale_factor = 2)
+
+        feat_out = torch.cat((feat32_up, feat16_up), dim = 1)
+        feat_out16 = self.conv_feat16(feat16)
+        feat_out32 = self.conv_feat32(feat32)
+
+        return feat_out, feat_out16, feat_out32
+
+    def init_weight(self):
+        # copy the ImageNet-pretrained resnet18 weights into the matching layers;
+        # the classifier weights (fc.*) have no counterpart here and are skipped
+        state_dict = modelzoo.load_url(resnet18_url)
+        self_state_dict = self.state_dict()
+        for k, v in state_dict.items():
+            if k in self_state_dict and v.size() == self_state_dict[k].size():
+                self_state_dict[k] = v
+        self.load_state_dict(self_state_dict)
+
+
+
+
+class FeatureFusionModule(nn.Module):
+    def __init__(self, in_chan, n_classes, *args, **kwargs):
+        super(FeatureFusionModule, self).__init__()
+        self.convblk = ConvBNReLU(in_chan, n_classes, ks = 3)
+        self.conv1 = nn.Conv2d(n_classes, n_classes, 1)
+        self.conv2 = nn.Conv2d(n_classes, n_classes, 1)
+        self.sigmoid = nn.Sigmoid()
+
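+    # Feature fusion, following the BiSeNet FFM design: the spatial-path and
+    # context-path features are concatenated and projected by ConvBNReLU, then a
+    # channel-attention vector (global average pool -> 1x1 conv -> ReLU -> 1x1 conv
+    # -> sigmoid) re-weights the projected features, which are added back onto them
+    # as a residual in forward() below.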
+    def forward(self, fsp, fcp):
+        fcat = torch.cat((fsp, fcp), dim = 1)
+        feat = self.convblk(fcat)
+        atten = F.avg_pool2d(feat, feat.size()[2:])
+        atten = self.conv1(atten)
+        atten = F.relu(atten, inplace = True)
+        atten = self.conv2(atten)
+        atten = self.sigmoid(atten)
+        feat_atten = torch.mul(feat, atten)
+        feat_out = feat_atten + feat
+        return feat_out
+
+
+
+class BiSeNet(nn.Module):
+    def __init__(self, n_classes, *args, **kwargs):
+        super(BiSeNet, self).__init__()
+        self.sp = SpatialPath()
+        self.cp = ContextPath(n_classes)
+        self.ffm = FeatureFusionModule(1024, n_classes)
+
+
+    def forward(self, x):
+        feat_sp = self.sp(x)
+        feat_cp, feat16, feat32 = self.cp(x)
+        feat_out = self.ffm(feat_sp, feat_cp)
+        return feat_out, feat16, feat32
+
+
+
+
+if __name__ == "__main__":
+    net = BiSeNet(21)
+    in_ten = torch.randn(10, 3, 224, 224)
+    out, out16, out32 = net(in_ten)
+    print(out.shape)
+    print(out16.shape)
+    print(out32.shape)
+
+    convbnrelu = ConvBNReLU(3, 10)
+    print(convbnrelu(in_ten).shape)
+    sp = SpatialPath()
+    out = sp(in_ten)
+    print(out.shape)
+    cp = ContextPath(10)
+    out, out16, out32 = cp(in_ten)
+    print(out.shape)
+    print(out16.shape)
+    print(out32.shape)
+    # arm = AttentionRefinementModule(3, 10)
+    # out = arm(in_ten)
+    # print(out.shape)
+    # # out_x, out_aux = net(in_ten)
+    # # print(out_x.shape)
+    # # print(out_aux.shape)
+    # in_ten = torch.randn(1, 2, 3,3)
+    # print(in_ten)
+    # import numpy as np
+    # sig = np.arange(2).reshape(1,2,1,1).astype(np.float32)
+    # sig = torch.tensor(sig)
+    # print(torch.mul(in_ten, sig))
+
+    ffm = FeatureFusionModule(in_chan = 1024, n_classes = 21)
+    feat1 = torch.randn(10, 768, 32, 32)
+    feat2 = torch.randn(10, 256, 32, 32)
+    feat_out = ffm(feat1, feat2)
+    print(feat_out.shape)
diff --git a/xception.py b/xception.py
new file mode 100644
index 0000000..ca56300
--- /dev/null
+++ b/xception.py
@@ -0,0 +1,240 @@
+"""
+Coin: I must give thanks to [Cadene](https://github.com/Cadene) for his xception models.
+"""
+"""
+Ported to pytorch thanks to [tstandley](https://github.com/tstandley/Xception-PyTorch)
+
+@author: tstandley
+Adapted by cadene
+
+Creates an Xception Model as defined in:
+
+Francois Chollet
+Xception: Deep Learning with Depthwise Separable Convolutions
+https://arxiv.org/pdf/1610.02357.pdf
+
+These weights were ported from the Keras implementation.
+Achieves the following performance on the validation set:
+
+Loss:0.9173 Prec@1:78.892 Prec@5:94.292
+
+REMEMBER to set your image size to 3x299x299 for both test and validation
+
+normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
+                                 std=[0.5, 0.5, 0.5])
+
+The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
+"""
+# from __future__ import print_function, division, absolute_import
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.model_zoo as model_zoo
+from torch.nn import init
+
+__all__ = ['xception']
+
+pretrained_settings = {
+    'xception': {
+        'imagenet': {
+            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth',
+            'input_space': 'RGB',
+            'input_size': [3, 299, 299],
+            'input_range': [0, 1],
+            'mean': [0.5, 0.5, 0.5],
+            'std': [0.5, 0.5, 0.5],
+            'num_classes': 1000,
+            'scale': 0.8975 # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
+        }
+    }
+}
+
+
+class SeparableConv2d(nn.Module):
+    def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False):
+        super(SeparableConv2d,self).__init__()
+
+        self.conv1 = nn.Conv2d(in_channels,in_channels,kernel_size,stride,padding,dilation,groups=in_channels,bias=bias)
+        self.pointwise = nn.Conv2d(in_channels,out_channels,1,1,0,1,1,bias=bias)
+
+    def forward(self,x):
+        x = self.conv1(x)
+        x = self.pointwise(x)
+        return x
+
+
+class Block(nn.Module):
+    def __init__(self,in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True):
+        super(Block, self).__init__()
+
+        if out_filters != in_filters or strides!=1:
+            self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False)
+            self.skipbn = nn.BatchNorm2d(out_filters)
+        else:
+            self.skip=None
+
+        self.relu = nn.ReLU(inplace=True)
+        rep=[]
+
+        filters=in_filters
+        if grow_first:
+            rep.append(self.relu)
+            rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False))
+            rep.append(nn.BatchNorm2d(out_filters))
+            filters = out_filters
+
+        for i in range(reps-1):
+            rep.append(self.relu)
+            rep.append(SeparableConv2d(filters,filters,3,stride=1,padding=1,bias=False))
+            rep.append(nn.BatchNorm2d(filters))
+
+        if not grow_first:
+            rep.append(self.relu)
+            rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False))
+            rep.append(nn.BatchNorm2d(out_filters))
+
+        if not start_with_relu:
+            rep = rep[1:]
+        else:
+            rep[0] = nn.ReLU(inplace=False)
+
+        if strides != 1:
+            rep.append(nn.MaxPool2d(3,strides,1))
+        self.rep = nn.Sequential(*rep)
+
+    def forward(self,inp):
+        x = self.rep(inp)
+
+        if self.skip is not None:
+            skip = self.skip(inp)
+            skip = self.skipbn(skip)
+        else:
+            skip = inp
+
+        x+=skip
+        return x
+
+
+class Xception(nn.Module):
+    """
+    Xception optimized for the ImageNet dataset, as specified in
+    https://arxiv.org/pdf/1610.02357.pdf
+    """
+    def __init__(self, num_classes=1000):
+        """ Constructor
+        Args:
+            num_classes: number of classes
+        """
+        super(Xception, self).__init__()
+        self.num_classes = num_classes
+
+        self.conv1 = nn.Conv2d(3, 32, 3,2, 0, bias=False)
+        self.bn1 = nn.BatchNorm2d(32)
+        self.relu = nn.ReLU(inplace=True)
+
+        self.conv2 = nn.Conv2d(32,64,3,bias=False)
+        self.bn2 = nn.BatchNorm2d(64)
+        #do relu here
+
+        self.block1=Block(64,128,2,2,start_with_relu=False,grow_first=True)
+        self.block2=Block(128,256,2,2,start_with_relu=True,grow_first=True)
+        self.block3=Block(256,728,2,2,start_with_relu=True,grow_first=True)
+
+        self.block4=Block(728,728,3,1,start_with_relu=True,grow_first=True)
+        self.block5=Block(728,728,3,1,start_with_relu=True,grow_first=True)
+        self.block6=Block(728,728,3,1,start_with_relu=True,grow_first=True)
+        self.block7=Block(728,728,3,1,start_with_relu=True,grow_first=True)
+
+        self.block8=Block(728,728,3,1,start_with_relu=True,grow_first=True)
+        self.block9=Block(728,728,3,1,start_with_relu=True,grow_first=True)
+        self.block10=Block(728,728,3,1,start_with_relu=True,grow_first=True)
+        self.block11=Block(728,728,3,1,start_with_relu=True,grow_first=True)
+
+        self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False)
+
+        self.conv3 = SeparableConv2d(1024,1536,3,1,1)
+        self.bn3 = nn.BatchNorm2d(1536)
+
+        #do relu here
+        self.conv4 = SeparableConv2d(1536,2048,3,1,1)
+        self.bn4 = nn.BatchNorm2d(2048)
+
+        # self.fc = nn.Linear(2048, num_classes)
+
+        # #------- init weights --------
+        # for m in self.modules():
+        #     if isinstance(m, nn.Conv2d):
+        #         n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+        #         m.weight.data.normal_(0, math.sqrt(2. / n))
+        #     elif isinstance(m, nn.BatchNorm2d):
+        #         m.weight.data.fill_(1)
+        #         m.bias.data.zero_()
+        # #-----------------------------
+
+    def features(self, input):
+        x = self.conv1(input)
+        x = self.bn1(x)
+        x = self.relu(x)
+
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = self.relu(x)
+
+        x = self.block1(x)
+        x = self.block2(x)
+        x = self.block3(x)
+        x = self.block4(x)
+        x = self.block5(x)
+        x = self.block6(x)
+        x = self.block7(x)
+        x = self.block8(x)
+        x = self.block9(x)
+        x = self.block10(x)
+        x = self.block11(x)
+        x = self.block12(x)
+
+        x = self.conv3(x)
+        x = self.bn3(x)
+        x = self.relu(x)
+
+        x = self.conv4(x)
+        x = self.bn4(x)
+        return x
+
+    def logits(self, features):
+        ## TODO
+        x = self.relu(features)
+
+        # x = F.adaptive_avg_pool2d(x, (1, 1))
+        # x = x.view(x.size(0), -1)
+        # x = self.last_linear(x)
+        return x
+
+    def forward(self, x):
+        x = self.features(x)
+        x = self.logits(x)
+        return x
+
+
+def xception(num_classes=1000, pretrained='imagenet'):
+    model = Xception(num_classes=num_classes)
+    if pretrained:
+        settings = pretrained_settings['xception'][pretrained]
+        # assert num_classes == settings['num_classes'], \
+        #     "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)
+
+        state_dict = model_zoo.load_url(settings['url'])
+        new_state = {k:v for k,v in state_dict.items() if not 'fc' in k}
+        model.load_state_dict(new_state)
+
+        model.input_space = settings['input_space']
+        model.input_size = settings['input_size']
+        model.input_range = settings['input_range']
+        model.mean = settings['mean']
+        model.std = settings['std']
+
+    # TODO: ugly
+    # model.last_linear = model.fc
+    # del model.fc
+    return model
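+
+
+if __name__ == "__main__":
+    # A minimal usage sketch, not part of the original port: it follows the
+    # preprocessing notes in the module docstring (resize to 333, center crop to
+    # 299x299, normalize with mean = std = 0.5) and runs a quick forward-pass
+    # smoke test. `pretrained=None` is passed here only to skip the weight download;
+    # the `preprocess` pipeline is what would be applied to a PIL image at test time.
+    from torchvision import transforms
+
+    preprocess = transforms.Compose([
+        transforms.Resize(333),
+        transforms.CenterCrop(299),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    ])
+
+    net = xception(num_classes=1000, pretrained=None)
+    in_ten = torch.randn(2, 3, 299, 299)
+    out = net(in_ten)
+    # logits() currently only applies a ReLU, so the output is still a 2048-channel
+    # feature map rather than class scores
+    print(out.shape)  # expected: torch.Size([2, 2048, 10, 10]) for a 299x299 input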