diff --git a/README.md b/README.md
index 4121b51..974b308 100644
--- a/README.md
+++ b/README.md
@@ -49,6 +49,7 @@
 - [x] `CBAM` - [CBAM: Convolutional Block Attention Module](https://arxiv.org/abs/1807.06521), ECCV, 2018
 - [x] `SKNets` - [Selective Kernel Networks](https://arxiv.org/abs/1903.06586), CVPR, 2019
 - [x] `ECA` - [ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks](https://arxiv.org/abs/1910.03151), CVPR, 2019
+- [x] `GlobalContextBlock` - [GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond](https://arxiv.org/abs/1904.11492), 2019
 
 ### Transformer
 
diff --git a/cvm/models/ops/blocks/__init__.py b/cvm/models/ops/blocks/__init__.py
index a808c72..3a9fc66 100644
--- a/cvm/models/ops/blocks/__init__.py
+++ b/cvm/models/ops/blocks/__init__.py
@@ -19,4 +19,6 @@
 from .gather_excite import GatherExciteBlock
 from .selective_kernel import SelectiveKernelBlock
 from .cbam import CBAM
-from .efficient_channel_attention import EfficientChannelAttention
\ No newline at end of file
+from .efficient_channel_attention import EfficientChannelAttention
+from .norm import LayerNorm2d
+from .global_context import GlobalContextBlock
\ No newline at end of file
diff --git a/cvm/models/ops/blocks/global_context.py b/cvm/models/ops/blocks/global_context.py
new file mode 100644
index 0000000..cb4842e
--- /dev/null
+++ b/cvm/models/ops/blocks/global_context.py
@@ -0,0 +1,41 @@
+import torch
+from torch import nn
+from .vanilla_conv2d import Conv2d1x1
+from .norm import LayerNorm2d
+
+
+class GlobalContextBlock(nn.Module):
+    r"""
+    Paper: GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond, https://arxiv.org/abs/1904.11492
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        rd_ratio
+    ) -> None:
+        super().__init__()
+
+        channels = int(in_channels * rd_ratio)
+
+        self.conv1x1 = Conv2d1x1(in_channels, 1, bias=True)
+        self.softmax = nn.Softmax(dim=1)
+
+        self.transform = nn.Sequential(
+            Conv2d1x1(in_channels, channels),
+            LayerNorm2d(channels),
+            nn.ReLU(inplace=True),
+            Conv2d1x1(channels, in_channels)
+        )
+
+    def forward(self, x):
+        # context modeling: attention-pool the feature map into a
+        # query-independent global context vector of shape (N, C, 1)
+        c = torch.einsum(
+            "ncx, nxo -> nco",
+            x.view(x.shape[0], x.shape[1], -1),
+            self.softmax(self.conv1x1(x).view(x.shape[0], -1, 1))
+        )
+
+        # transform the context, then fuse it back by broadcast addition
+        return x + self.transform(c.unsqueeze(-1))
diff --git a/cvm/models/ops/blocks/norm.py b/cvm/models/ops/blocks/norm.py
new file mode 100644
index 0000000..ab9dbd9
--- /dev/null
+++ b/cvm/models/ops/blocks/norm.py
@@ -0,0 +1,18 @@
+from torch import nn
+import torch.nn.functional as F
+
+
+class LayerNorm2d(nn.LayerNorm):
+    """ LayerNorm for channels of '2D' spatial BCHW tensors """
+
+    def __init__(
+        self,
+        channels
+    ):
+        super().__init__(channels)
+
+    def forward(self, x):
+        x = x.permute(0, 2, 3, 1)
+        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        x = x.permute(0, 3, 1, 2)
+        return x
diff --git a/cvm/version.py b/cvm/version.py
index 2068bc3..dc982b5 100644
--- a/cvm/version.py
+++ b/cvm/version.py
@@ -1 +1 @@
-__version__ = '0.0.26'
\ No newline at end of file
+__version__ = '0.0.27'
\ No newline at end of file
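
For reference, a minimal usage sketch of the two new blocks (not part of the patch itself). It assumes the package is installed and importable as `cvm`; the constructor signatures and parameter names come from the diff above.

    import torch
    from cvm.models.ops.blocks import GlobalContextBlock, LayerNorm2d

    x = torch.randn(2, 64, 14, 14)  # an NCHW feature map

    # LayerNorm2d normalizes over the channel dimension of an NCHW tensor
    # and preserves the input shape.
    ln = LayerNorm2d(64)
    assert ln(x).shape == x.shape

    # GlobalContextBlock is residual, so its output shape matches its input;
    # rd_ratio sets the bottleneck width of the transform (64 * 0.25 = 16).
    gc = GlobalContextBlock(64, rd_ratio=0.25)
    assert gc(x).shape == x.shape

GCNet replaces the per-query attention maps of a non-local block with a single query-independent map, so the block adds little cost beyond the 1x1 convolutions; the paper places LayerNorm inside the bottleneck transform to ease optimization, which is why the `LayerNorm2d` helper is introduced alongside.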