diff --git a/mmaction/models/heads/base.py b/mmaction/models/heads/base.py index b905bf376b..a0e72c077a 100644 --- a/mmaction/models/heads/base.py +++ b/mmaction/models/heads/base.py @@ -5,6 +5,22 @@ import torch.nn.functional as F +class AvgConsensus(nn.Module): + """Average consensus module. + + Attributes: + dim (int): Decide which dim consensus function to apply. + Default: 1. + """ + + def __init__(self, dim=1): + super(AvgConsensus, self).__init__() + self.dim = dim + + def forward(self, input): + return input.mean(dim=self.dim, keepdim=True) + + class BaseHead(nn.Module, metaclass=ABCMeta): """Base class for head. diff --git a/mmaction/models/heads/tsn_head.py b/mmaction/models/heads/tsn_head.py index 1c2fe889ab..b323bf1378 100644 --- a/mmaction/models/heads/tsn_head.py +++ b/mmaction/models/heads/tsn_head.py @@ -2,23 +2,7 @@ from mmcv.cnn.weight_init import normal_init from ..registry import HEADS -from .base import BaseHead - - -class AvgConsensus(nn.Module): - """Average consensus module. - - Attributes: - dim (int): Decide which dim consensus function to apply. - Default: 1. - """ - - def __init__(self, dim=1): - super(AvgConsensus, self).__init__() - self.dim = dim - - def forward(self, input): - return input.mean(dim=self.dim, keepdim=True) +from .base import AvgConsensus, BaseHead @HEADS.register_module