diff --git a/pytorch_toolbelt/__init__.py b/pytorch_toolbelt/__init__.py
index 3ac3d77b9..8c271f4d2 100644
--- a/pytorch_toolbelt/__init__.py
+++ b/pytorch_toolbelt/__init__.py
@@ -1,3 +1,3 @@
 from __future__ import absolute_import
 
-__version__ = '0.0.8'
+__version__ = '0.0.9'
diff --git a/pytorch_toolbelt/modules/activations.py b/pytorch_toolbelt/modules/activations.py
new file mode 100644
index 000000000..d89b5f9ab
--- /dev/null
+++ b/pytorch_toolbelt/modules/activations.py
@@ -0,0 +1,76 @@
+from functools import partial
+
+from torch import nn
+from torch.nn import functional as F
+
+
+def swish(x):
+    return x * x.sigmoid()
+
+
+def hard_sigmoid(x, inplace=False):
+    return F.relu6(x + 3, inplace) / 6
+
+
+def hard_swish(x, inplace=False):
+    return x * hard_sigmoid(x, inplace)
+
+
+class HardSigmoid(nn.Module):
+    def __init__(self, inplace=False):
+        super(HardSigmoid, self).__init__()
+        self.inplace = inplace
+
+    def forward(self, x):
+        return hard_sigmoid(x, inplace=self.inplace)
+
+
+class Swish(nn.Module):
+    def __init__(self, inplace=False):
+        super(Swish, self).__init__()
+
+    def forward(self, x):
+        return swish(x)
+
+
+class HardSwish(nn.Module):
+    def __init__(self, inplace=False):
+        super(HardSwish, self).__init__()
+        self.inplace = inplace
+
+    def forward(self, x):
+        return hard_swish(x, inplace=self.inplace)
+
+
+def get_activation_module(activation_name: str, **kwargs) -> nn.Module:
+    if activation_name.lower() == 'relu':
+        return partial(nn.ReLU, **kwargs)
+
+    if activation_name.lower() == 'relu6':
+        return partial(nn.ReLU6, **kwargs)
+
+    if activation_name.lower() == 'leaky_relu':
+        return partial(nn.LeakyReLU, **kwargs)
+
+    if activation_name.lower() == 'elu':
+        return partial(nn.ELU, **kwargs)
+
+    if activation_name.lower() == 'selu':
+        return partial(nn.SELU, **kwargs)
+
+    if activation_name.lower() == 'celu':
+        return partial(nn.CELU, **kwargs)
+
+    if activation_name.lower() == 'glu':
+        return partial(nn.GLU, **kwargs)
+
+    if activation_name.lower() == 'prelu':
+        return partial(nn.PReLU, **kwargs)
+
+    if activation_name.lower() == 'hard_sigmoid':
+        return partial(HardSigmoid, **kwargs)
+
+    if activation_name.lower() == 'hard_swish':
+        return partial(HardSwish, **kwargs)
+
+    raise ValueError(f'Activation \'{activation_name}\' is not supported')
diff --git a/pytorch_toolbelt/modules/backbone/mobilenet.py b/pytorch_toolbelt/modules/backbone/mobilenet.py
index 0cd3b74f3..4b9df71c8 100644
--- a/pytorch_toolbelt/modules/backbone/mobilenet.py
+++ b/pytorch_toolbelt/modules/backbone/mobilenet.py
@@ -1,25 +1,29 @@
+from __future__ import absolute_import
+
 import torch.nn as nn
 import math
 
+from ..activations import get_activation_module
+
 
-def conv_bn(inp, oup, stride):
+def conv_bn(inp, oup, stride, activation: nn.Module):
     return nn.Sequential(
         nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
         nn.BatchNorm2d(oup),
-        nn.ReLU6(inplace=True)
+        activation(inplace=True)
     )
 
 
-def conv_1x1_bn(inp, oup):
+def conv_1x1_bn(inp, oup, activation: nn.Module):
     return nn.Sequential(
         nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
         nn.BatchNorm2d(oup),
-        nn.ReLU6(inplace=True)
+        activation(inplace=True)
     )
 
 
 class InvertedResidual(nn.Module):
-    def __init__(self, inp, oup, stride, expand_ratio):
+    def __init__(self, inp, oup, stride, expand_ratio, activation: nn.Module):
         super(InvertedResidual, self).__init__()
         self.stride = stride
         assert stride in [1, 2]
@@ -32,7 +36,7 @@ def __init__(self, inp, oup, stride, expand_ratio):
                 # dw
                 nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                 nn.BatchNorm2d(hidden_dim),
-                nn.ReLU6(inplace=True),
+                activation(inplace=True),
                 # pw-linear
                 nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                 nn.BatchNorm2d(oup),
@@ -42,11 +46,11 @@ def __init__(self, inp, oup, stride, expand_ratio):
                 # pw
                 nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                 nn.BatchNorm2d(hidden_dim),
-                nn.ReLU6(inplace=True),
+                activation(inplace=True),
                 # dw
                 nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                 nn.BatchNorm2d(hidden_dim),
-                nn.ReLU6(inplace=True),
+                activation(inplace=True),
                 # pw-linear
                 nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                 nn.BatchNorm2d(oup),
@@ -60,8 +64,11 @@ def forward(self, x):
 
 
 class MobileNetV2(nn.Module):
-    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
+    def __init__(self, n_class=1000, input_size=224, width_mult=1., activation='relu6'):
         super(MobileNetV2, self).__init__()
+
+        act = get_activation_module(activation)
+
         block = InvertedResidual
         input_channel = 32
         last_channel = 1280
@@ -80,7 +87,7 @@ def __init__(self, n_class=1000, input_size=224, width_mult=1.):
         assert input_size % 32 == 0
         input_channel = int(input_channel * width_mult)
         self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
-        self.layer0 = conv_bn(3, input_channel, 2)
+        self.layer0 = conv_bn(3, input_channel, 2, act)
 
         # building inverted residual blocks
         for layer_index, (t, c, n, s) in enumerate(interverted_residual_setting):
@@ -89,16 +96,16 @@ def __init__(self, n_class=1000, input_size=224, width_mult=1.):
 
             blocks = []
             for i in range(n):
                 if i == 0:
-                    blocks.append(block(input_channel, output_channel, s, expand_ratio=t))
+                    blocks.append(block(input_channel, output_channel, s, expand_ratio=t, activation=act))
                 else:
-                    blocks.append(block(input_channel, output_channel, 1, expand_ratio=t))
+                    blocks.append(block(input_channel, output_channel, 1, expand_ratio=t, activation=act))
                 input_channel = output_channel
 
             self.add_module(f'layer{layer_index + 1}', nn.Sequential(*blocks))
 
         # building last several layers
-        self.final_layer = conv_1x1_bn(input_channel, self.last_channel)
+        self.final_layer = conv_1x1_bn(input_channel, self.last_channel, activation=act)
 
         # building classifier
         self.classifier = nn.Sequential(
diff --git a/pytorch_toolbelt/modules/backbone/mobilenetv3.py b/pytorch_toolbelt/modules/backbone/mobilenetv3.py
index 53ee52485..ccecacc26 100644
--- a/pytorch_toolbelt/modules/backbone/mobilenetv3.py
+++ b/pytorch_toolbelt/modules/backbone/mobilenetv3.py
@@ -6,43 +6,15 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from pytorch_toolbelt.modules.dropblock import DropBlockScheduled, DropBlock2D
-from pytorch_toolbelt.modules import Identity
-
-
-def swish(x):
-    return x * x.sigmoid()
-
-
-def hard_sigmoid(x, inplace=False):
-    return F.relu6(x + 3, inplace) / 6
-
-
-def hard_swish(x, inplace=False):
-    return x * hard_sigmoid(x, inplace)
-
-
-class HardSigmoid(nn.Module):
-    def __init__(self, inplace=False):
-        super(HardSigmoid, self).__init__()
-        self.inplace = inplace
-
-    def forward(self, x):
-        return hard_sigmoid(x, inplace=self.inplace)
-
-
-class HardSwish(nn.Module):
-    def __init__(self, inplace=False):
-        super(HardSwish, self).__init__()
-        self.inplace = inplace
-
-    def forward(self, x):
-        return hard_swish(x, inplace=self.inplace)
+# from pytorch_toolbelt.modules.dropblock import DropBlockScheduled, DropBlock2D
+from pytorch_toolbelt.modules.activations import HardSwish, HardSigmoid
+from pytorch_toolbelt.modules.identity import Identity
 
 
 def _make_divisible(v, divisor, min_value=None):
     """
     Ensure that all layers have a channel number that is divisible by 8
+
     It can be seen here:
     https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
     :param v:
@@ -59,9 +31,9 @@ def _make_divisible(v, divisor, min_value=None):
     return new_v
 
 
-# https://github.com/jonnedtc/Squeeze-Excitation-PyTorch/blob/master/networks.py
 class SqEx(nn.Module):
-    """Squeeze-Excitation block, implemented in ONNX & CoreML friendly way
+    """Squeeze-Excitation block. Implemented in ONNX & CoreML friendly way.
+    Original implementation: https://github.com/jonnedtc/Squeeze-Excitation-PyTorch/blob/master/networks.py
     """
 
     def __init__(self, n_features, reduction=4):
@@ -89,24 +61,26 @@ def __init__(self, inplanes, outplanes, expplanes, k=3, stride=1, drop_prob=0, n
         super(LinearBottleneck, self).__init__()
         self.conv1 = nn.Conv2d(inplanes, expplanes, kernel_size=1, bias=False)
         self.bn1 = nn.BatchNorm2d(expplanes)
-        self.db1 = DropBlockScheduled(DropBlock2D(drop_prob=drop_prob, block_size=7), start_value=0.,
-                                      stop_value=drop_prob, nr_steps=num_steps, start_step=start_step)
-        # TODO: first doesn't have act?
+        self.db1 = nn.Dropout2d(drop_prob)
+        # self.db1 = DropBlockScheduled(DropBlock2D(drop_prob=drop_prob, block_size=7), start_value=0.,
+        #                               stop_value=drop_prob, nr_steps=num_steps, start_step=start_step)
+        self.act1 = activation(**act_params)  # first does have act according to MobileNetV2
 
         self.conv2 = nn.Conv2d(expplanes, expplanes, kernel_size=k, stride=stride, padding=k // 2,
                                bias=False, groups=expplanes)
         self.bn2 = nn.BatchNorm2d(expplanes)
-        self.db2 = DropBlockScheduled(DropBlock2D(drop_prob=drop_prob, block_size=7), start_value=0.,
-                                      stop_value=drop_prob, nr_steps=num_steps, start_step=start_step)
+        self.db2 = nn.Dropout2d(drop_prob)
+        # self.db2 = DropBlockScheduled(DropBlock2D(drop_prob=drop_prob, block_size=7), start_value=0.,
+        #                               stop_value=drop_prob, nr_steps=num_steps, start_step=start_step)
         self.act2 = activation(**act_params)
 
         self.se = SqEx(expplanes) if SE else Identity()
 
         self.conv3 = nn.Conv2d(expplanes, outplanes, kernel_size=1, bias=False)
         self.bn3 = nn.BatchNorm2d(outplanes)
-        self.db3 = DropBlockScheduled(DropBlock2D(drop_prob=drop_prob, block_size=7), start_value=0.,
-                                      stop_value=drop_prob, nr_steps=num_steps, start_step=start_step)
-        self.act3 = activation(**act_params)
+        self.db3 = nn.Dropout2d(drop_prob)
+        # self.db3 = DropBlockScheduled(DropBlock2D(drop_prob=drop_prob, block_size=7), start_value=0.,
+        #                               stop_value=drop_prob, nr_steps=num_steps, start_step=start_step)
 
         self.stride = stride
         self.expplanes = expplanes
@@ -119,6 +93,7 @@ def forward(self, x):
         out = self.conv1(x)
         out = self.bn1(out)
         out = self.db1(out)
+        out = self.act1(out)
 
         out = self.conv2(out)
         out = self.bn2(out)
@@ -130,10 +105,9 @@ def forward(self, x):
 
         out = self.conv3(out)
         out = self.bn3(out)
         out = self.db3(out)
-        out = self.act3(out)
 
         if self.stride == 1 and self.inplanes == self.outplanes:  # TODO: or add 1x1?
-            out = out + residual  # No inplace if there is in-place activation before
+            out += residual  # No inplace if there is in-place activation before
 
         return out
@@ -187,7 +161,6 @@ def __init__(self, inplanes, num_classes, expplanes1, expplanes2):
         self.avgpool = nn.AdaptiveAvgPool2d(1)
 
         self.conv2 = nn.Conv2d(expplanes1, expplanes2, kernel_size=1, stride=1, bias=False)
-        self.bn2 = nn.BatchNorm2d(expplanes2)
         self.act2 = HardSwish(inplace=True)
 
         self.dropout = nn.Dropout(p=0.2, inplace=True)
@@ -207,7 +180,6 @@ def forward(self, x):
         out = self.avgpool(out)
 
         out = self.conv2(out)
-        out = self.bn2(out)
         out = self.act2(out)
 
         # flatten for input to fully-connected layer
@@ -246,16 +218,16 @@ def __init__(self, num_classes=1000, scale=1., in_channels=3, drop_prob=0.0, num
             [80, 184, 80, 1, 3, drop_prob, False, HardSwish],  # -> 14x14
             [80, 480, 112, 1, 3, drop_prob, True, HardSwish],  # -> 14x14
             [112, 672, 112, 1, 3, drop_prob, True, HardSwish],  # -> 14x14
-            [112, 672, 160, 1, 5, drop_prob, True, HardSwish],  # -> 14x14
-            [160, 672, 160, 2, 5, drop_prob, True, HardSwish],  # -> 7x7 #TODO
+            [112, 672, 160, 2, 5, drop_prob, True, HardSwish],  # -> 7x7
+            [160, 672, 160, 1, 5, drop_prob, True, HardSwish],  # -> 7x7
             [160, 960, 160, 1, 5, drop_prob, True, HardSwish],  # -> 7x7
         ]
         self.bottlenecks_setting_small = [
             # in, exp, out, s, k, dp, se, act
-            [16, 64, 24, 2, 3, 0, True, nn.ReLU],  # -> 56x56 #TODO
-            [24, 72, 24, 2, 3, 0, False, nn.ReLU],  # -> 28x28
-            [24, 88, 40, 1, 3, 0, False, nn.ReLU],  # -> 28x28
-            [40, 96, 40, 2, 5, 0, True, HardSwish],  # -> 14x14 #TODO
+            [16, 64, 16, 2, 3, 0, True, nn.ReLU],  # -> 56x56
+            [16, 72, 24, 2, 3, 0, False, nn.ReLU],  # -> 28x28
+            [24, 88, 24, 1, 3, 0, False, nn.ReLU],  # -> 28x28
+            [24, 96, 40, 2, 5, 0, True, HardSwish],  # -> 14x14
             [40, 240, 40, 1, 5, drop_prob, True, HardSwish],  # -> 14x14
             [40, 240, 40, 1, 5, drop_prob, True, HardSwish],  # -> 14x14
             [40, 120, 48, 1, 5, drop_prob, True, HardSwish],  # -> 14x14
@@ -290,7 +262,6 @@ def __init__(self, num_classes=1000, scale=1., in_channels=3, drop_prob=0.0, num
 
     def _make_bottlenecks(self):
         layers = []
-        modules = OrderedDict()
 
        stage_name = "Bottleneck"
diff --git a/pytorch_toolbelt/modules/dropblock.py b/pytorch_toolbelt/modules/dropblock.py
index a3805269f..679503514 100644
--- a/pytorch_toolbelt/modules/dropblock.py
+++ b/pytorch_toolbelt/modules/dropblock.py
@@ -41,14 +41,13 @@ def forward(self, x):
             mask = (torch.rand(x.shape[0], *x.shape[2:]) < gamma).to(x)
 
             # compute block mask
-            block_mask = self._compute_block_mask(mask)
+            block_mask, keeped = self._compute_block_mask(mask)
 
             # apply block mask
             out = x * block_mask[:, None, :, :]
 
             # scale output
-            out = out * block_mask.numel() / block_mask.sum()
-
+            out = out * (block_mask.numel() / keeped).to(out)
             return out
 
     def _compute_block_mask(self, mask):
@@ -60,9 +59,10 @@ def _compute_block_mask(self, mask):
 
         if self.block_size % 2 == 0:
             block_mask = block_mask[:, :, :-1, :-1]
 
+        keeped = block_mask.numel() - block_mask.sum().to(torch.float32)  # prevent overflow in float16
         block_mask = 1 - block_mask.squeeze(1)
 
-        return block_mask
+        return block_mask, keeped
 
     def _compute_gamma(self, x):
         return self.drop_prob / (self.block_size ** 2)
@@ -146,7 +146,7 @@ def forward(self, x):
 
     def step(self):
         idx = self.i.item()
-        if idx > self.start_step and idx < self.start_step + self.nr_steps:
+        if self.start_step < idx < self.start_step + self.nr_steps:
             self.dropblock.drop_prob += self.step_size
 
         self.i += 1
diff --git a/pytorch_toolbelt/modules/encoders.py b/pytorch_toolbelt/modules/encoders.py
index b0afd5256..fca538fab 100644
--- a/pytorch_toolbelt/modules/encoders.py
+++ b/pytorch_toolbelt/modules/encoders.py
@@ -280,9 +280,9 @@ def encoder_layers(self):
 
 
 class MobilenetV2Encoder(EncoderModule):
-    def __init__(self, layers=[2, 3, 5, 7]):
+    def __init__(self, layers=[2, 3, 5, 7], activation='relu6'):
         super().__init__([32, 16, 24, 32, 64, 96, 160, 320], [2, 2, 4, 8, 16, 16, 32, 32], layers)
-        encoder = MobileNetV2()
+        encoder = MobileNetV2(activation=activation)
 
         self.layer0 = encoder.layer0
         self.layer1 = encoder.layer1
diff --git a/pytorch_toolbelt/modules/srm.py b/pytorch_toolbelt/modules/srm.py
new file mode 100644
index 000000000..77635047c
--- /dev/null
+++ b/pytorch_toolbelt/modules/srm.py
@@ -0,0 +1,35 @@
+import torch
+from torch import nn
+
+
+class SRMLayer(nn.Module):
+    """An implementation of SRM block, proposed in
+    "SRM : A Style-based Recalibration Module for Convolutional Neural Networks".
+
+    """
+
+    def __init__(self, channels: int):
+        super(SRMLayer, self).__init__()
+
+        # Equal to torch.einsum('bck,ck->bc', A, B)
+        self.cfc = nn.Conv1d(channels, channels,
+                             kernel_size=2,
+                             bias=False,
+                             groups=channels)
+        self.bn = nn.BatchNorm1d(channels)
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+
+        # Style pooling
+        mean = x.view(b, c, -1).mean(-1).unsqueeze(-1)
+        std = x.view(b, c, -1).std(-1).unsqueeze(-1)
+        u = torch.cat((mean, std), -1)  # (b, c, 2)
+
+        # Style integration
+        z = self.cfc(u)  # (b, c, 1)
+        z = self.bn(z)
+        g = torch.sigmoid(z)
+        g = g.view(b, c, 1, 1)
+
+        return x * g.expand_as(x)
diff --git a/pytorch_toolbelt/utils/catalyst_utils.py b/pytorch_toolbelt/utils/catalyst_utils.py
index ada06c30d..3bffc4e00 100644
--- a/pytorch_toolbelt/utils/catalyst_utils.py
+++ b/pytorch_toolbelt/utils/catalyst_utils.py
@@ -48,11 +48,6 @@ def to_cpu(self, data):
             return data
         raise ValueError("Unsupported type", type(data))
 
-    def _log_image(self, loggers, mode: str, image, name, step: int, suffix=""):
-        for logger in loggers:
-            if isinstance(logger, TensorboardLogger):
-                logger.loggers[mode].add_image(f"{name}{suffix}", tensor_from_rgb_image(image), step)
-
     def on_loader_start(self, state):
         self.best_score = None
         self.best_input = None
@@ -83,12 +78,12 @@ def on_loader_end(self, state: RunnerState) -> None:
         if self.best_score is not None:
             best_samples = self.visualize_batch(self.best_input, self.best_output)
             for i, image in enumerate(best_samples):
-                logger.add_image(f"Best Batch/{i}/epoch", tensor_from_rgb_image(image), state.step)
+                logger.add_image(f"{self.target_metric}/best/{i}", tensor_from_rgb_image(image), state.step)
 
         if self.worst_score is not None:
             worst_samples = self.visualize_batch(self.worst_input, self.worst_output)
             for i, image in enumerate(worst_samples):
-                logger.add_image(f"Worst Batch/{i}/epoch", tensor_from_rgb_image(image), state.step)
+                logger.add_image(f"{self.target_metric}/worst/{i}", tensor_from_rgb_image(image), state.step)
 
 
 class EpochJaccardMetric(Callback):
@@ -265,7 +260,7 @@ def on_loader_end(self, state):
         num_classes = len(class_names)
         cm = confusion_matrix(outputs, targets, labels=range(num_classes))
 
-        fig = plot_confusion_matrix(cm, figsize=(6 + num_classes // 4, 6 + num_classes // 4), class_names=class_names, normalize=True, noshow=True)
+        fig = plot_confusion_matrix(cm, figsize=(6 + num_classes // 3, 6 + num_classes // 3), class_names=class_names, normalize=True, noshow=True)
         fig = render_figure_to_tensor(fig)
 
         logger = _get_tensorboard_logger(state)
diff --git a/pytorch_toolbelt/utils/fs.py b/pytorch_toolbelt/utils/fs.py
index 0c990774a..c7a8155b3 100644
--- a/pytorch_toolbelt/utils/fs.py
+++ b/pytorch_toolbelt/utils/fs.py
@@ -53,7 +53,7 @@ def auto_file(filename: str, where: str = '.') -> str:
 
         raise FileNotFoundError('Given file could not be found with recursive search:' + filename)
 
     if len(files) > 1:
-        raise FileNotFoundError('More than one file matches given filename. Please specify it explicitly:' + filename)
+        raise FileNotFoundError('More than one file matches given filename. Please specify it explicitly:\n' + '\n'.join(files))
 
     return files[0]
diff --git a/pytorch_toolbelt/utils/visualization.py b/pytorch_toolbelt/utils/visualization.py
index 1f8a50428..2e3d1df12 100644
--- a/pytorch_toolbelt/utils/visualization.py
+++ b/pytorch_toolbelt/utils/visualization.py
@@ -19,7 +19,7 @@ def plot_confusion_matrix(cm, class_names,
     matplotlib.use('Agg')
     import matplotlib.pyplot as plt
 
-    cmap = plt.cm.Blues
+    cmap = plt.cm.Oranges
 
     if normalize:
         cm = cm.astype(np.float32) / cm.sum(axis=1)[:, np.newaxis]
diff --git a/tests/test_modules.py b/tests/test_modules.py
index cda2fc4d9..e2263dac1 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -17,6 +17,7 @@ def test_resnet18_encoder():
 @pytest.mark.parametrize(['encoder', 'encoder_params'], [
     [E.SqueezenetEncoder, {'layers': [0, 1, 2, 3]}],
     [E.MobilenetV2Encoder, {'layers': [0, 1, 2, 3, 4, 5, 6, 7]}],
+    [E.MobilenetV2Encoder, {'layers': [3, 5, 7], 'activation': 'elu'}],
     [E.MobilenetV3Encoder, {'small': False}],
     [E.MobilenetV3Encoder, {'small': True}],
     [E.Resnet18Encoder, {'layers': [0, 1, 2, 3, 4]}],
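
Usage sketch (not part of the patch): a minimal example of the API surface this diff introduces, assuming only what the patch itself defines (get_activation_module in pytorch_toolbelt.modules.activations, the activation argument on MobilenetV2Encoder, and SRMLayer in pytorch_toolbelt.modules.srm). The shape of the encoder output is taken from EncoderModule's existing behaviour and is not shown in this diff.

    import torch

    from pytorch_toolbelt.modules.activations import get_activation_module
    from pytorch_toolbelt.modules.encoders import MobilenetV2Encoder
    from pytorch_toolbelt.modules.srm import SRMLayer

    # The activation registry returns a module factory (functools.partial over the
    # nn.Module class), so it is instantiated like the class itself.
    hard_swish_cls = get_activation_module('hard_swish')
    act = hard_swish_cls()  # HardSwish()

    # The encoder forwards the activation name down to the MobileNetV2 backbone,
    # mirroring the new case added to tests/test_modules.py.
    encoder = MobilenetV2Encoder(layers=[3, 5, 7], activation='elu')
    x = torch.rand(2, 3, 224, 224)
    feature_maps = encoder(x)  # per-layer feature maps, as defined by EncoderModule

    # SRMLayer recalibrates channels of an NCHW feature map and keeps its shape.
    srm = SRMLayer(channels=32)
    fm = torch.rand(2, 32, 56, 56)
    assert srm(fm).shape == fm.shape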