
Commit

Add 'Global Context Block'
ffiirree committed Dec 4, 2022
1 parent b87ba3e commit 851c21d
Showing 5 changed files with 64 additions and 2 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -49,6 +49,7 @@
- [x] `CBAM` - [CBAM: Convolutional Block Attention Module](https://arxiv.org/abs/1807.06521), ECCV, 2018
- [x] `SKNets` - [Selective Kernel Networks](https://arxiv.org/abs/1903.06586), CVPR, 2019
- [x] `ECA` - [ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks](https://arxiv.org/abs/1910.03151), CVPR, 2019
- [x] `GlobalContextBlock` - [GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond](https://arxiv.org/abs/1904.11492), 2019

### Transformer

4 changes: 3 additions & 1 deletion cvm/models/ops/blocks/__init__.py
@@ -19,4 +19,6 @@
from .gather_excite import GatherExciteBlock
from .selective_kernel import SelectiveKernelBlock
from .cbam import CBAM
from .efficient_channel_attention import EfficientChannelAttention
from .norm import LayerNorm2d
from .global_context import GlobalContextBlock
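
With these exports, the new blocks are importable from the package's blocks namespace (an illustrative line, assuming the package layout shown in this diff):

from cvm.models.ops.blocks import GlobalContextBlock, LayerNorm2d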
41 changes: 41 additions & 0 deletions cvm/models/ops/blocks/global_context.py
@@ -0,0 +1,41 @@
import torch
from torch import nn
from .vanilla_conv2d import Conv2d1x1
from .norm import LayerNorm2d


class GlobalContextBlock(nn.Module):
    r"""
    Paper: GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond, https://arxiv.org/abs/1904.11492
    """

    def __init__(
        self,
        in_channels: int,
        rd_ratio: float
    ) -> None:
        super().__init__()

        # bottleneck width of the transform branch
        channels = int(in_channels * rd_ratio)

        # context modeling: one attention logit per spatial position
        self.conv1x1 = Conv2d1x1(in_channels, 1, bias=True)
        self.softmax = nn.Softmax(dim=1)

        # transform: 1x1 conv -> LayerNorm -> ReLU -> 1x1 conv
        self.transform = nn.Sequential(
            Conv2d1x1(in_channels, channels),
            LayerNorm2d(channels),
            nn.ReLU(inplace=True),
            Conv2d1x1(channels, in_channels)
        )

    def forward(self, x):
        # context modeling: softmax over the H*W positions gives global
        # attention weights; the einsum pools features into a per-channel
        # context of shape [N, C, 1]
        c = torch.einsum(
            "ncx, nxo -> nco",
            x.view(x.shape[0], x.shape[1], -1),                    # [N, C, H*W]
            self.softmax(self.conv1x1(x).view(x.shape[0], -1, 1))  # [N, H*W, 1]
        )
        # rescale the input by the [N, C, 1, 1] context before the transform
        c = x * c.unsqueeze(-1)

        # transform, fused back into the input by element-wise addition
        return x + self.transform(c)
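
A quick shape check for the new block (illustrative, not part of this commit; it assumes `Conv2d1x1` behaves as a 1x1 `nn.Conv2d` wrapper, as its usage above suggests):

import torch
from cvm.models.ops.blocks.global_context import GlobalContextBlock

gc = GlobalContextBlock(in_channels=64, rd_ratio=0.25)
x = torch.randn(2, 64, 32, 32)
y = gc(x)
assert y.shape == x.shape  # the residual fusion keeps the input shape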
18 changes: 18 additions & 0 deletions cvm/models/ops/blocks/norm.py
@@ -0,0 +1,18 @@
from torch import nn
import torch.nn.functional as F


class LayerNorm2d(nn.LayerNorm):
    """LayerNorm over the channel dimension of NCHW (2D spatial) tensors."""

    def __init__(
        self,
        channels
    ):
        super().__init__(channels)

    def forward(self, x):
        # move channels last, normalize over them, then restore NCHW
        x = x.permute(0, 2, 3, 1)
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = x.permute(0, 3, 1, 2)
        return x
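
An illustrative check (not part of the diff): because the permutes move channels to the last axis, the statistics are computed per spatial location across channels:

import torch
from cvm.models.ops.blocks.norm import LayerNorm2d

ln = LayerNorm2d(8)
x = torch.randn(2, 8, 4, 4)
y = ln(x)
# with the default affine init (weight=1, bias=0), every (n, :, h, w) slice
# has zero mean and unit variance across its 8 channels
print(y.mean(dim=1).abs().max())  # ~0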
2 changes: 1 addition & 1 deletion cvm/version.py
@@ -1 +1 @@
__version__ = '0.0.26'
__version__ = '0.0.27'
