Skip to content

Commit

Permalink
Add 'Efficient Channel Attention'
Browse files Browse the repository at this point in the history
  • Loading branch information
ffiirree committed Dec 4, 2022
1 parent 6091026 commit 1adb783
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
- [x] `Gather-Excite` - [Gather-Excite: Exploiting Feature Context in Convolutional Neural Networks](https://arxiv.org/abs/1810.12348), NeurIPS, 2018
- [x] `CBAM` - [CBAM: Convolutional Block Attention Module](https://arxiv.org/abs/1807.06521), ECCV, 2018
- [x] `SKNets` - [Selective Kernel Networks](https://arxiv.org/abs/1903.06586), CVPR, 2019
- [x] `ECA` - [ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks](https://arxiv.org/abs/1910.03151), CVPR, 2019

### Transformer

Expand Down
3 changes: 2 additions & 1 deletion cvm/models/ops/blocks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@
from .interpolate import Interpolate
from .gather_excite import GatherExciteBlock
from .selective_kernel import SelectiveKernelBlock
from .cbam import CBAM
from .cbam import CBAM
from .efficient_channel_attention import EfficientChannelAttention
31 changes: 31 additions & 0 deletions cvm/models/ops/blocks/efficient_channel_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import math
import torch
from torch import nn


class EfficientChannelAttention(nn.Module):
r"""
Paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks, https://arxiv.org/abs/1910.03151
"""
def __init__(
self,
in_channels,
gamma=2,
beta=2
) -> None:
super().__init__()

t = int(abs((math.log(in_channels, 2) + beta) / gamma))
k = max(t if t % 2 else t + 1, 3)

self.pool = nn.AdaptiveAvgPool2d((1, 1))
self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2)
self.gate = nn.Sigmoid()

def forward(self, x: torch.Tensor):
y = self.pool(x)
y = self.conv(y.view(y.shape[0], 1, -1))
y = y.view(y.shape[0], -1, 1, 1)
y = self.gate(y)

return x * y.expand_as(x)

0 comments on commit 1adb783

Please sign in to comment.