Merge pull request #12 from BloodAxe/develop
0.0.9
BloodAxe authored Jun 3, 2019
2 parents 41ea208 + 3c179e6 commit 3ef7f37
Showing 11 changed files with 169 additions and 84 deletions.
2 changes: 1 addition & 1 deletion pytorch_toolbelt/__init__.py
@@ -1,3 +1,3 @@
from __future__ import absolute_import

__version__ = '0.0.8'
__version__ = '0.0.9'
76 changes: 76 additions & 0 deletions pytorch_toolbelt/modules/activations.py
@@ -0,0 +1,76 @@
from functools import partial

from torch import nn
from torch.nn import functional as F


def swish(x):
return x * x.sigmoid()


def hard_sigmoid(x, inplace=False):
return F.relu6(x + 3, inplace) / 6


def hard_swish(x, inplace=False):
return x * hard_sigmoid(x, inplace)


class HardSigmoid(nn.Module):
def __init__(self, inplace=False):
super(HardSigmoid, self).__init__()
self.inplace = inplace

def forward(self, x):
return hard_sigmoid(x, inplace=self.inplace)


class Swish(nn.Module):
def __init__(self, inplace=False):
super(Swish, self).__init__()

def forward(self, x):
return swish(x)


class HardSwish(nn.Module):
def __init__(self, inplace=False):
super(HardSwish, self).__init__()
self.inplace = inplace

def forward(self, x):
return hard_swish(x, inplace=self.inplace)


def get_activation_module(activation_name: str, **kwargs) -> nn.Module:
if activation_name.lower() == 'relu':
return partial(nn.ReLU, **kwargs)

if activation_name.lower() == 'relu6':
return partial(nn.ReLU6, **kwargs)

if activation_name.lower() == 'leaky_relu':
return partial(nn.LeakyReLU, **kwargs)

if activation_name.lower() == 'elu':
return partial(nn.ELU, **kwargs)

if activation_name.lower() == 'selu':
return partial(nn.SELU, **kwargs)

if activation_name.lower() == 'celu':
return partial(nn.CELU, **kwargs)

if activation_name.lower() == 'glu':
return partial(nn.GLU, **kwargs)

if activation_name.lower() == 'prelu':
return partial(nn.PReLU, **kwargs)

if activation_name.lower() == 'hard_sigmoid':
return partial(HardSigmoid, **kwargs)

if activation_name.lower() == 'hard_swish':
return partial(HardSwish, **kwargs)

raise ValueError(f'Activation \'{activation_name}\' is not supported')
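For reference, a minimal usage sketch of the new factory; the import path follows the file shown above, and the tensors are illustrative only:

import torch
from pytorch_toolbelt.modules.activations import get_activation_module

# The factory returns a constructor (a functools.partial), not a module instance.
act_cls = get_activation_module('hard_swish')
act = act_cls(inplace=False)          # -> HardSwish(inplace=False)

x = torch.randn(2, 8, 4, 4)
y = act(x)                            # elementwise x * relu6(x + 3) / 6

# Extra keyword arguments are forwarded to the wrapped module:
leaky = get_activation_module('leaky_relu', negative_slope=0.1)()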
33 changes: 20 additions & 13 deletions pytorch_toolbelt/modules/backbone/mobilenet.py
@@ -1,25 +1,29 @@
from __future__ import absolute_import

import torch.nn as nn
import math

from ..activations import get_activation_module


def conv_bn(inp, oup, stride):
def conv_bn(inp, oup, stride, activation: nn.Module):
return nn.Sequential(
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU6(inplace=True)
activation(inplace=True)
)


def conv_1x1_bn(inp, oup):
def conv_1x1_bn(inp, oup, activation: nn.Module):
return nn.Sequential(
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU6(inplace=True)
activation(inplace=True)
)


class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride, expand_ratio):
def __init__(self, inp, oup, stride, expand_ratio, activation: nn.Module):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]
@@ -32,7 +36,7 @@ def __init__(self, inp, oup, stride, expand_ratio):
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.ReLU6(inplace=True),
activation(inplace=True),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
@@ -42,11 +46,11 @@ def __init__(self, inp, oup, stride, expand_ratio):
# pw
nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.ReLU6(inplace=True),
activation(inplace=True),
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.ReLU6(inplace=True),
activation(inplace=True),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
@@ -60,8 +64,11 @@ def forward(self, x):


class MobileNetV2(nn.Module):
def __init__(self, n_class=1000, input_size=224, width_mult=1.):
def __init__(self, n_class=1000, input_size=224, width_mult=1., activation='relu6'):
super(MobileNetV2, self).__init__()

act = get_activation_module(activation)

block = InvertedResidual
input_channel = 32
last_channel = 1280
@@ -80,7 +87,7 @@ def __init__(self, n_class=1000, input_size=224, width_mult=1.):
assert input_size % 32 == 0
input_channel = int(input_channel * width_mult)
self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
self.layer0 = conv_bn(3, input_channel, 2)
self.layer0 = conv_bn(3, input_channel, 2, act)

# building inverted residual blocks
for layer_index, (t, c, n, s) in enumerate(interverted_residual_setting):
@@ -89,16 +96,16 @@ def __init__(self, n_class=1000, input_size=224, width_mult=1.):
blocks = []
for i in range(n):
if i == 0:
blocks.append(block(input_channel, output_channel, s, expand_ratio=t))
blocks.append(block(input_channel, output_channel, s, expand_ratio=t, activation=act))
else:
blocks.append(block(input_channel, output_channel, 1, expand_ratio=t))
blocks.append(block(input_channel, output_channel, 1, expand_ratio=t, activation=act))

input_channel = output_channel

self.add_module(f'layer{layer_index + 1}', nn.Sequential(*blocks))

# building last several layers
self.final_layer = conv_1x1_bn(input_channel, self.last_channel)
self.final_layer = conv_1x1_bn(input_channel, self.last_channel, activation=act)

# building classifier
self.classifier = nn.Sequential(
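A brief sketch of building the backbone with the now-pluggable activation; the constructor arguments follow this diff, while the forward pass and output shape are assumed from the standard MobileNetV2 implementation:

import torch
from pytorch_toolbelt.modules.backbone.mobilenet import MobileNetV2

# 'relu6' (the default) reproduces the previous hard-coded behaviour; 'hard_swish' swaps it out.
model = MobileNetV2(n_class=1000, width_mult=1.0, activation='hard_swish')
logits = model(torch.randn(1, 3, 224, 224))   # expected shape: (1, 1000)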
77 changes: 24 additions & 53 deletions pytorch_toolbelt/modules/backbone/mobilenetv3.py
@@ -6,43 +6,15 @@
import torch.nn as nn
import torch.nn.functional as F

from pytorch_toolbelt.modules.dropblock import DropBlockScheduled, DropBlock2D
from pytorch_toolbelt.modules import Identity


def swish(x):
return x * x.sigmoid()


def hard_sigmoid(x, inplace=False):
return F.relu6(x + 3, inplace) / 6


def hard_swish(x, inplace=False):
return x * hard_sigmoid(x, inplace)


class HardSigmoid(nn.Module):
def __init__(self, inplace=False):
super(HardSigmoid, self).__init__()
self.inplace = inplace

def forward(self, x):
return hard_sigmoid(x, inplace=self.inplace)


class HardSwish(nn.Module):
def __init__(self, inplace=False):
super(HardSwish, self).__init__()
self.inplace = inplace

def forward(self, x):
return hard_swish(x, inplace=self.inplace)
# from pytorch_toolbelt.modules.dropblock import DropBlockScheduled, DropBlock2D
from pytorch_toolbelt.modules.activations import HardSwish, HardSigmoid
from pytorch_toolbelt.modules.identity import Identity


def _make_divisible(v, divisor, min_value=None):
"""
Ensure that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
:param v:
@@ -59,9 +31,9 @@ def _make_divisible(v, divisor, min_value=None):
return new_v


# https://github.com/jonnedtc/Squeeze-Excitation-PyTorch/blob/master/networks.py
class SqEx(nn.Module):
"""Squeeze-Excitation block, implemented in ONNX & CoreML friendly way
"""Squeeze-Excitation block. Implemented in ONNX & CoreML friendly way.
Original implementation: https://github.com/jonnedtc/Squeeze-Excitation-PyTorch/blob/master/networks.py
"""

def __init__(self, n_features, reduction=4):
@@ -89,24 +61,26 @@ def __init__(self, inplanes, outplanes, expplanes, k=3, stride=1, drop_prob=0, n
super(LinearBottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, expplanes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(expplanes)
self.db1 = DropBlockScheduled(DropBlock2D(drop_prob=drop_prob, block_size=7), start_value=0.,
stop_value=drop_prob, nr_steps=num_steps, start_step=start_step)
# TODO: first doesn't have act?
self.db1 = nn.Dropout2d(drop_prob)
# self.db1 = DropBlockScheduled(DropBlock2D(drop_prob=drop_prob, block_size=7), start_value=0.,
# stop_value=drop_prob, nr_steps=num_steps, start_step=start_step)
self.act1 = activation(**act_params) # first does have act according to MobileNetV2

self.conv2 = nn.Conv2d(expplanes, expplanes, kernel_size=k, stride=stride, padding=k // 2, bias=False,
groups=expplanes)
self.bn2 = nn.BatchNorm2d(expplanes)
self.db2 = DropBlockScheduled(DropBlock2D(drop_prob=drop_prob, block_size=7), start_value=0.,
stop_value=drop_prob, nr_steps=num_steps, start_step=start_step)
self.db2 = nn.Dropout2d(drop_prob)
# self.db2 = DropBlockScheduled(DropBlock2D(drop_prob=drop_prob, block_size=7), start_value=0.,
# stop_value=drop_prob, nr_steps=num_steps, start_step=start_step)
self.act2 = activation(**act_params)

self.se = SqEx(expplanes) if SE else Identity()

self.conv3 = nn.Conv2d(expplanes, outplanes, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(outplanes)
self.db3 = DropBlockScheduled(DropBlock2D(drop_prob=drop_prob, block_size=7), start_value=0.,
stop_value=drop_prob, nr_steps=num_steps, start_step=start_step)
self.act3 = activation(**act_params)
self.db3 = nn.Dropout2d(drop_prob)
# self.db3 = DropBlockScheduled(DropBlock2D(drop_prob=drop_prob, block_size=7), start_value=0.,
# stop_value=drop_prob, nr_steps=num_steps, start_step=start_step)

self.stride = stride
self.expplanes = expplanes
@@ -119,6 +93,7 @@ def forward(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = self.db1(out)
out = self.act1(out)

out = self.conv2(out)
out = self.bn2(out)
@@ -130,10 +105,9 @@ def forward(self, x):
out = self.conv3(out)
out = self.bn3(out)
out = self.db3(out)
out = self.act3(out)

if self.stride == 1 and self.inplanes == self.outplanes: # TODO: or add 1x1?
out = out + residual # No inplace if there is in-place activation before
out += residual # No inplace if there is in-place activation before

return out

@@ -187,7 +161,6 @@ def __init__(self, inplanes, num_classes, expplanes1, expplanes2):
self.avgpool = nn.AdaptiveAvgPool2d(1)

self.conv2 = nn.Conv2d(expplanes1, expplanes2, kernel_size=1, stride=1, bias=False)
self.bn2 = nn.BatchNorm2d(expplanes2)
self.act2 = HardSwish(inplace=True)

self.dropout = nn.Dropout(p=0.2, inplace=True)
@@ -207,7 +180,6 @@ def forward(self, x):
out = self.avgpool(out)

out = self.conv2(out)
out = self.bn2(out)
out = self.act2(out)

# flatten for input to fully-connected layer
@@ -246,16 +218,16 @@ def __init__(self, num_classes=1000, scale=1., in_channels=3, drop_prob=0.0, num
[80, 184, 80, 1, 3, drop_prob, False, HardSwish], # -> 14x14
[80, 480, 112, 1, 3, drop_prob, True, HardSwish], # -> 14x14
[112, 672, 112, 1, 3, drop_prob, True, HardSwish], # -> 14x14
[112, 672, 160, 1, 5, drop_prob, True, HardSwish], # -> 14x14
[160, 672, 160, 2, 5, drop_prob, True, HardSwish], # -> 7x7 #TODO
[112, 672, 160, 2, 5, drop_prob, True, HardSwish], # -> 7x7
[160, 672, 160, 1, 5, drop_prob, True, HardSwish], # -> 7x7
[160, 960, 160, 1, 5, drop_prob, True, HardSwish], # -> 7x7
]
self.bottlenecks_setting_small = [
# in, exp, out, s, k, dp, se, act
[16, 64, 24, 2, 3, 0, True, nn.ReLU], # -> 56x56 #TODO
[24, 72, 24, 2, 3, 0, False, nn.ReLU], # -> 28x28
[24, 88, 40, 1, 3, 0, False, nn.ReLU], # -> 28x28
[40, 96, 40, 2, 5, 0, True, HardSwish], # -> 14x14 #TODO
[16, 64, 16, 2, 3, 0, True, nn.ReLU], # -> 56x56
[16, 72, 24, 2, 3, 0, False, nn.ReLU], # -> 28x28
[24, 88, 24, 1, 3, 0, False, nn.ReLU], # -> 28x28
[24, 96, 40, 2, 5, 0, True, HardSwish], # -> 14x14
[40, 240, 40, 1, 5, drop_prob, True, HardSwish], # -> 14x14
[40, 240, 40, 1, 5, drop_prob, True, HardSwish], # -> 14x14
[40, 120, 48, 1, 5, drop_prob, True, HardSwish], # -> 14x14
@@ -290,7 +262,6 @@ def __init__(self, num_classes=1000, scale=1., in_channels=3, drop_prob=0.0, num

def _make_bottlenecks(self):
layers = []

modules = OrderedDict()
stage_name = "Bottleneck"

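The hard activations are now defined once in pytorch_toolbelt.modules.activations and only imported here; a small numeric sketch of what they compute (values approximate):

import torch
from pytorch_toolbelt.modules.activations import HardSigmoid, HardSwish

x = torch.tensor([-4.0, -1.0, 0.0, 1.0, 4.0])
print(HardSigmoid()(x))   # relu6(x + 3) / 6      -> approx. [0.00, 0.33, 0.50, 0.67, 1.00]
print(HardSwish()(x))     # x * relu6(x + 3) / 6  -> approx. [0.00, -0.33, 0.00, 0.67, 4.00]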
10 changes: 5 additions & 5 deletions pytorch_toolbelt/modules/dropblock.py
@@ -41,14 +41,13 @@ def forward(self, x):
mask = (torch.rand(x.shape[0], *x.shape[2:]) < gamma).to(x)

# compute block mask
block_mask = self._compute_block_mask(mask)
block_mask, keeped = self._compute_block_mask(mask)

# apply block mask
out = x * block_mask[:, None, :, :]

# scale output
out = out * block_mask.numel() / block_mask.sum()

out = out * (block_mask.numel() / keeped).to(out)
return out

def _compute_block_mask(self, mask):
@@ -60,9 +59,10 @@ def _compute_block_mask(self, mask):
if self.block_size % 2 == 0:
block_mask = block_mask[:, :, :-1, :-1]

keeped = block_mask.numel() - block_mask.sum().to(torch.float32) # prevent overflow in float16
block_mask = 1 - block_mask.squeeze(1)

return block_mask
return block_mask, keeped

def _compute_gamma(self, x):
return self.drop_prob / (self.block_size ** 2)
@@ -146,7 +146,7 @@ def forward(self, x):

def step(self):
idx = self.i.item()
if idx > self.start_step and idx < self.start_step + self.nr_steps:
if self.start_step < idx < self.start_step + self.nr_steps:
self.dropblock.drop_prob += self.step_size

self.i += 1
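A usage sketch for the patched DropBlock2D, assuming the usual training-mode guard earlier in forward() (not shown in this hunk); drop_prob and block_size are illustrative. The rescaling now divides by the count of kept positions, computed in float32, which avoids the float16 overflow noted in the diff:

import torch
from pytorch_toolbelt.modules.dropblock import DropBlock2D

drop = DropBlock2D(drop_prob=0.2, block_size=3)
drop.train()                        # block dropping is assumed to apply only in training mode
x = torch.randn(4, 16, 32, 32)
y = drop(x)                         # contiguous blocks zeroed, output rescaled by numel / kept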
4 changes: 2 additions & 2 deletions pytorch_toolbelt/modules/encoders.py
@@ -280,9 +280,9 @@ def encoder_layers(self):


class MobilenetV2Encoder(EncoderModule):
def __init__(self, layers=[2, 3, 5, 7]):
def __init__(self, layers=[2, 3, 5, 7], activation='relu6'):
super().__init__([32, 16, 24, 32, 64, 96, 160, 320], [2, 2, 4, 8, 16, 16, 32, 32], layers)
encoder = MobileNetV2()
encoder = MobileNetV2(activation=activation)

self.layer0 = encoder.layer0
self.layer1 = encoder.layer1
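Finally, a sketch of the updated encoder, which forwards the activation choice to the MobileNetV2 backbone; the input resolution is illustrative, and EncoderModule is assumed to return one feature map per requested layer:

import torch
from pytorch_toolbelt.modules.encoders import MobilenetV2Encoder

encoder = MobilenetV2Encoder(layers=[2, 3, 5, 7], activation='hard_swish')
features = encoder(torch.randn(1, 3, 256, 256))
for fm in features:
    print(fm.shape)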