Commit: model
CoinCheung committed Nov 29, 2018
1 parent af28f4a commit d45d519
Showing 6 changed files with 482 additions and 1 deletion.
7 changes: 6 additions & 1 deletion .gitignore
@@ -14,7 +14,6 @@ dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
@@ -102,3 +101,9 @@ venv.bak/

# mypy
.mypy_cache/


## Coin:
gtFine_trainvaltest.zip
gtFine/

13 changes: 13 additions & 0 deletions README
@@ -0,0 +1,13 @@

README and scripts
------------------

The README and various scripts for inspection, preparation, and evaluation can be found in our git repository:
https://github.com/mcordts/cityscapesScripts

Contact
-------

Marius Cordts, Mohamed Omran
www.cityscapes-dataset.net
[email protected]
1 change: 1 addition & 0 deletions adj.md
@@ -0,0 +1 @@
1. Check whether the Xception backbone should end with a BN, a ReLU, or a plain conv layer.
23 changes: 23 additions & 0 deletions license.txt
@@ -0,0 +1,23 @@
----------------------
The Cityscapes Dataset
----------------------


License agreement
-----------------

This dataset is made freely available to academic and non-academic entities for non-commercial purposes such as academic research, teaching, scientific publications, or personal experimentation. Permission is granted to use the data given that you agree:

1. That the dataset comes "AS IS", without express or implied warranty. Although every effort has been made to ensure accuracy, we (Daimler AG, MPI Informatics, TU Darmstadt) do not accept any responsibility for errors or omissions.
2. That you include a reference to the Cityscapes Dataset in any work that makes use of the dataset. For research papers, cite our preferred publication as listed on our website; for other media cite our preferred publication as listed on our website or link to the Cityscapes website.
3. That you do not distribute this dataset or modified versions. It is permissible to distribute derivative works insofar as they are abstract representations of this dataset (such as models trained on it or additional annotations that do not directly include any of our data) and do not allow recovery of the dataset or anything similar in character.
4. That you may not use the dataset or any derivative work for commercial purposes such as, for example, licensing or selling the data, or using the data with the purpose of procuring commercial gain.
5. That all rights not expressly granted to you are reserved by us (Daimler AG, MPI Informatics, TU Darmstadt).


Contact
-------

Marius Cordts, Mohamed Omran
www.cityscapes-dataset.net
[email protected]
199 changes: 199 additions & 0 deletions model.py
@@ -0,0 +1,199 @@
#!/usr/bin/python
# -*- encoding: utf-8 -*-


import torch
import torch.nn as nn
import torch.utils.model_zoo as modelzoo
import torch.nn.functional as F
import torchvision


resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'

class ConvBNReLU(nn.Module):
    # conv -> batch norm -> ReLU block used by both paths
    def __init__(self, in_chan, out_chan, ks=3, stride=1, *args, **kwargs):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(in_chan,
                out_chan,
                kernel_size=ks,
                stride=stride,
                padding=ks // 2,  # center the kernel: same-size output at stride 1
                bias=True)
        self.bn = nn.BatchNorm2d(out_chan)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = F.relu(x, inplace=True)
        return x


class SpatialPath(nn.Module):
    # three stride-2 conv blocks: output is 1/8 input resolution, 256 channels
    def __init__(self, *args, **kwargs):
        super(SpatialPath, self).__init__()
        self.conv1 = ConvBNReLU(3, 64, stride=2)
        self.conv2 = ConvBNReLU(64, 128, stride=2)
        self.conv3 = ConvBNReLU(128, 256, stride=2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return x


class AttentionRefinementModule(nn.Module):
    # channel attention from global average pooling: 1x1 conv -> bn -> sigmoid
    def __init__(self, in_chan, *args, **kwargs):
        super(AttentionRefinementModule, self).__init__()
        self.in_chan = in_chan
        self.conv = nn.Conv2d(in_chan,
                in_chan,
                kernel_size=1,
                bias=True)
        self.bn = nn.BatchNorm2d(in_chan)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        assert self.in_chan == x.size()[1]
        in_ten = x
        x = F.avg_pool2d(x, x.size()[2:])
        x = self.conv(x)
        x = self.bn(x)
        x = self.sigmoid(x)
        x = torch.mul(in_ten, x)
        return x



class ContextPath(nn.Module):
    # resnet18 backbone with attention refinement on the 1/16 and 1/32 features
    def __init__(self, n_classes=10, *args, **kwargs):
        super(ContextPath, self).__init__()
        resnet = torchvision.models.resnet18()
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        self.arm16 = AttentionRefinementModule(256)
        self.arm32 = AttentionRefinementModule(512)

        # auxiliary heads on the raw 1/16 and 1/32 features; padding keeps
        # the spatial size of the logits
        self.conv_feat16 = nn.Conv2d(256,
                n_classes,
                kernel_size=3,
                padding=1,
                bias=True)
        self.conv_feat32 = nn.Conv2d(512,
                n_classes,
                kernel_size=3,
                padding=1,
                bias=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)

        feat16 = self.layer3(x)       # 1/16 resolution, 256 channels
        feat32 = self.layer4(feat16)  # 1/32 resolution, 512 channels
        avg = F.avg_pool2d(feat32, feat32.size()[2:])  # global context
        feat16_arm = self.arm16(feat16)
        feat32_arm = self.arm32(feat32)

        feat32_with_avg = torch.mul(feat32_arm, avg)
        feat32_up = F.interpolate(feat32_with_avg, scale_factor=4)
        feat16_up = F.interpolate(feat16_arm, scale_factor=2)

        feat_out = torch.cat((feat32_up, feat16_up), dim=1)
        feat_out16 = self.conv_feat16(feat16)
        feat_out32 = self.conv_feat32(feat32)

        return feat_out, feat_out16, feat_out32
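
    # Shape sketch (illustrative, for a 224x224 input): feat16 is
    # (N, 256, 14, 14) and feat32 is (N, 512, 7, 7); after x2 and x4
    # nearest upsampling both reach 28x28, so feat_out is (N, 768, 28, 28),
    # i.e. 1/8 of the input resolution.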

    def init_weight(self):
        # load ImageNet-pretrained resnet18 weights into the backbone;
        # the classifier head ('fc.*') has no counterpart here and is skipped
        state_dict = modelzoo.load_url(resnet18_url)
        self_state_dict = self.state_dict()
        self_state_dict.update(
                {k: v for k, v in state_dict.items() if not k.startswith('fc.')})
        self.load_state_dict(self_state_dict)




class FeatureFusionModule(nn.Module):
    def __init__(self, in_chan, n_classes, *args, **kwargs):
        super(FeatureFusionModule, self).__init__()
        self.convblk = ConvBNReLU(in_chan, n_classes, ks=3)
        self.conv1 = nn.Conv2d(n_classes, n_classes, 1)
        self.conv2 = nn.Conv2d(n_classes, n_classes, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, fsp, fcp):
        # fuse the spatial-path and context-path features, then reweight
        # the fused map with channel attention from global pooling
        fcat = torch.cat((fsp, fcp), dim=1)
        feat = self.convblk(fcat)
        atten = F.avg_pool2d(feat, feat.size()[2:])
        atten = self.conv1(atten)
        atten = F.relu(atten, inplace=True)
        atten = self.conv2(atten)
        atten = self.sigmoid(atten)
        feat_atten = torch.mul(feat, atten)
        feat_out = feat_atten + feat
        return feat_out
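
# Note (illustrative): the returned feat * atten + feat is a residual form
# of channel attention, so the attention rescales the fused features rather
# than replacing them.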



class BiSeNet(nn.Module):
    def __init__(self, n_classes, *args, **kwargs):
        super(BiSeNet, self).__init__()
        self.sp = SpatialPath()
        self.cp = ContextPath(n_classes)
        # 1024 = 256 (spatial path) + 768 (context path) concatenated channels
        self.ffm = FeatureFusionModule(1024, n_classes)

    def forward(self, x):
        feat_sp = self.sp(x)
        feat_cp, feat16, feat32 = self.cp(x)
        feat_out = self.ffm(feat_sp, feat_cp)
        return feat_out, feat16, feat32




if __name__ == "__main__":
    net = BiSeNet(21)
    in_ten = torch.randn(10, 3, 224, 224)
    out, out16, out32 = net(in_ten)
    print(out.shape)
    print(out16.shape)
    print(out32.shape)

    convbnrelu = ConvBNReLU(3, 10)
    print(convbnrelu(in_ten).shape)
    sp = SpatialPath()
    out = sp(in_ten)
    print(out.shape)
    cp = ContextPath(10)
    out, out16, out32 = cp(in_ten)
    print(out.shape)
    print(out16.shape)
    print(out32.shape)
    # arm = AttentionRefinementModule(3, 10)
    # out = arm(in_ten)
    # print(out.shape)
    # # out_x, out_aux = net(in_ten)
    # # print(out_x.shape)
    # # print(out_aux.shape)
    # in_ten = torch.randn(1, 2, 3, 3)
    # print(in_ten)
    # import numpy as np
    # sig = np.arange(2).reshape(1, 2, 1, 1).astype(np.float32)
    # sig = torch.tensor(sig)
    # print(torch.mul(in_ten, sig))

    ffm = FeatureFusionModule(in_chan=1024, n_classes=21)
    feat1 = torch.randn(10, 768, 32, 32)
    feat2 = torch.randn(10, 256, 32, 32)
    feat_out = ffm(feat1, feat2)
    print(feat_out.shape)
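
    # A minimal end-to-end sketch (assumptions: the init_weight loader above,
    # 19 classes as for Cityscapes, and an arbitrary input size):
    net = BiSeNet(n_classes=19)
    net.cp.init_weight()  # pull ImageNet resnet18 weights into the backbone
    net.eval()
    with torch.no_grad():
        out, out16, out32 = net(torch.randn(1, 3, 224, 224))
    print(out.shape, out16.shape, out32.shape)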