squeeze_excite.py

from functools import partial
from contextlib import contextmanager
from collections import OrderedDict

from torch import nn

from . import factory
from .vanilla_conv2d import Conv2d1x1
from ..functional import make_divisible

# Module-wide defaults for SEBlock; they can be overridden temporarily with
# the `se` context manager below.
_SE_INNER_NONLINEAR: nn.Module = partial(nn.ReLU, inplace=True)
_SE_GATING_FN: nn.Module = nn.Sigmoid
_SE_DIVISOR: int = 8
_SE_USE_NORM: bool = False


@contextmanager
def se(
    inner_nonlinear: nn.Module = _SE_INNER_NONLINEAR,
    gating_fn: nn.Module = _SE_GATING_FN,
    divisor: int = _SE_DIVISOR,
    use_norm: bool = _SE_USE_NORM
):
    """Temporarily override the module-wide SE defaults for any SEBlock
    constructed inside the `with` body."""
    global _SE_INNER_NONLINEAR
    global _SE_GATING_FN
    global _SE_DIVISOR
    global _SE_USE_NORM

    _pre_inner_fn = _SE_INNER_NONLINEAR
    _pre_fn = _SE_GATING_FN
    _pre_divisor = _SE_DIVISOR
    _pre_use_norm = _SE_USE_NORM

    _SE_INNER_NONLINEAR = inner_nonlinear
    _SE_GATING_FN = gating_fn
    _SE_DIVISOR = divisor
    _SE_USE_NORM = use_norm

    try:
        yield
    finally:
        # Restore the previous defaults even if the body raises.
        _SE_INNER_NONLINEAR = _pre_inner_fn
        _SE_GATING_FN = _pre_fn
        _SE_DIVISOR = _pre_divisor
        _SE_USE_NORM = _pre_use_norm


class SEBlock(nn.Sequential):
    """Squeeze-and-Excitation block.

    Globally average-pools the input, reduces the channels by `rd_ratio`
    with a 1x1 convolution, expands them back, and rescales the original
    input with the resulting per-channel gate.
    """

    def __init__(
        self,
        channels,
        rd_ratio,
        inner_activation_fn: nn.Module = None,
        gating_fn: nn.Module = None
    ):
        squeezed_channels = make_divisible(int(channels * rd_ratio), _SE_DIVISOR)
        inner_activation_fn = inner_activation_fn or _SE_INNER_NONLINEAR
        gating_fn = gating_fn or _SE_GATING_FN

        layers = OrderedDict([])
        layers['pool'] = nn.AdaptiveAvgPool2d((1, 1))
        layers['reduce'] = Conv2d1x1(channels, squeezed_channels, bias=True)
        if _SE_USE_NORM:
            layers['norm'] = factory.normalizer_fn(squeezed_channels)
        layers['act'] = inner_activation_fn()
        layers['expand'] = Conv2d1x1(squeezed_channels, channels, bias=True)
        layers['gate'] = gating_fn()
        super().__init__(layers)

    def _forward(self, input):
        # Run only the squeeze/excite branch; nn.Sequential.forward is
        # bypassed so that `forward` can multiply the gate onto the input.
        for module in self:
            input = module(input)
        return input

    def forward(self, x):
        return x * self._forward(x)
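
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The channel count,
# spatial size, and reduction ratio below are arbitrary example values.
# Because of the relative imports above, this file must be run as part of its
# package (e.g. `python -m <package>.squeeze_excite`, package name assumed),
# not as a standalone script.
if __name__ == "__main__":
    import torch

    x = torch.randn(2, 64, 32, 32)

    # Defaults: ReLU bottleneck, Sigmoid gate, channel divisor 8.
    block = SEBlock(channels=64, rd_ratio=0.25)
    print(block(x).shape)  # torch.Size([2, 64, 32, 32]); input rescaled per channel

    # Temporarily swap in a Hardsigmoid gate for blocks built in this scope.
    with se(gating_fn=nn.Hardsigmoid):
        block_hs = SEBlock(channels=64, rd_ratio=0.25)
    print(block_hs(x).shape)  # same shape; only the gating function differs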