Add 'CBAM (Convolutional Block Attention Module)'
ffiirree committed Dec 2, 2022
1 parent b8ea8c1 commit 6091026
Showing 4 changed files with 70 additions and 7 deletions.
7 changes: 2 additions & 5 deletions README.md
@@ -43,15 +43,12 @@

### Attention Blocks

- [x] `Non-Local` - [Non-local Neural Networks](https://arxiv.org/abs/1711.07971), CVPR, 2018
- [x] `Squeeze-and-Excitation` - [Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507), CVPR, 2018
- [x] `Gather-Excite` - [Gather-Excite: Exploiting Feature Context in Convolutional Neural Networks](https://arxiv.org/abs/1810.12348), NeurIPS, 2018
- [x] `CBAM` - [CBAM: Convolutional Block Attention Module](https://arxiv.org/abs/1807.06521), ECCV, 2018
- [x] `SKNets` - [Selective Kernel Networks](https://arxiv.org/abs/1903.06586), CVPR, 2019

### Self-Attention

- [x] `Non-Local` - [Non-local Neural Networks](https://arxiv.org/abs/1711.07971), CVPR, 2018


### Transformer

- [x] `ViT` - [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929), ICLR, 2021
3 changes: 2 additions & 1 deletion cvm/models/ops/blocks/__init__.py
@@ -17,4 +17,5 @@
from .non_local import NonLocalBlock
from .interpolate import Interpolate
from .gather_excite import GatherExciteBlock
from .selective_kernel import SelectiveKernelBlock
from .cbam import CBAM
65 changes: 65 additions & 0 deletions cvm/models/ops/blocks/cbam.py
@@ -0,0 +1,65 @@
import torch
from torch import nn
from .vanilla_conv2d import Conv2d1x1
from .norm_act import normalizer_fn, activation_fn


class ChannelAttention(nn.Module):
    def __init__(
        self,
        in_channels,
        rd_ratio,
        gate_fn: nn.Module = nn.Sigmoid
    ) -> None:
        super().__init__()

        rd_channels = int(in_channels * rd_ratio)

        self.max_pool = nn.AdaptiveMaxPool2d((1, 1))
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Shared MLP (two 1x1 convolutions with a bottleneck of rd_channels).
        self.mlp = nn.Sequential(
            Conv2d1x1(in_channels, rd_channels, bias=True),
            activation_fn(),
            Conv2d1x1(rd_channels, in_channels, bias=True)
        )
        self.gate = gate_fn()

    def forward(self, x):
        # Sum the MLP outputs of the max- and avg-pooled descriptors, then gate channel-wise.
        return x * self.gate(self.mlp(self.max_pool(x)) + self.mlp(self.avg_pool(x)))


class SpatialAttention(nn.Module):
    def __init__(
        self,
        kernel_size: int = 7,
        gate_fn: nn.Module = nn.Sigmoid
    ) -> None:
        super().__init__()

        self.conv = nn.Conv2d(2, 1, kernel_size, padding=(kernel_size - 1) // 2, bias=False)
        self.norm = normalizer_fn(1)
        self.gate = gate_fn()

    def forward(self, x):
        # Concatenate channel-wise max and mean maps, convolve and normalize, then gate spatially.
        s = torch.cat([torch.amax(x, dim=1, keepdim=True), torch.mean(x, dim=1, keepdim=True)], dim=1)
        return x * self.gate(self.norm(self.conv(s)))


class CBAM(nn.Sequential):
    r"""
    Paper: CBAM: Convolutional Block Attention Module, https://arxiv.org/abs/1807.06521
    Code: https://github.com/Jongchan/attention-module
    """

    def __init__(
        self,
        in_channels,
        rd_ratio,
        kernel_size: int = 7,
        gate_fn: nn.Module = nn.Sigmoid
    ) -> None:
        super().__init__(
            ChannelAttention(in_channels, rd_ratio, gate_fn=gate_fn),
            SpatialAttention(kernel_size=kernel_size, gate_fn=gate_fn)
        )
2 changes: 1 addition & 1 deletion cvm/version.py
@@ -1 +1 @@
__version__ = '0.0.25'
__version__ = '0.0.26'
