diff --git a/cvm/models/ops/blocks/gather_excite.py b/cvm/models/ops/blocks/gather_excite.py index b2dbeb2..85ca0cc 100644 --- a/cvm/models/ops/blocks/gather_excite.py +++ b/cvm/models/ops/blocks/gather_excite.py @@ -2,45 +2,31 @@ from functools import partial from contextlib import contextmanager from torch import nn -import torch.nn.functional as F -from . import norm_act -from .vanilla_conv2d import Conv2d1x1 -from ..functional import make_divisible -from typing import OrderedDict from .depthwise_separable_conv2d import DepthwiseBlock, DepthwiseConv2dBN from .interpolate import Interpolate _GE_INNER_NONLINEAR: nn.Module = partial(nn.ReLU, inplace=True) _GE_GATING_FN: nn.Module = nn.Sigmoid -_GE_DIVISOR: int = 8 -_GE_USE_NORM: bool = True @contextmanager def ge( inner_nonlinear: nn.Module = _GE_INNER_NONLINEAR, - gating_fn: nn.Module = _GE_GATING_FN, - divisor: int = _GE_DIVISOR, - use_norm: bool = _GE_USE_NORM + gating_fn: nn.Module = _GE_GATING_FN ): global _GE_INNER_NONLINEAR global _GE_GATING_FN - global _GE_DIVISOR - global _GE_USE_NORM _pre_inner_fn = _GE_INNER_NONLINEAR _pre_fn = _GE_GATING_FN - _pre_divisor = _GE_DIVISOR - _pre_use_norm = _GE_USE_NORM + _GE_INNER_NONLINEAR = inner_nonlinear _GE_GATING_FN = gating_fn - _GE_DIVISOR = divisor - _GE_USE_NORM = use_norm + yield + _GE_INNER_NONLINEAR = _pre_inner_fn _GE_GATING_FN = _pre_fn - _GE_DIVISOR = _pre_divisor - _GE_USE_NORM = _pre_use_norm class GatherExciteBlock(nn.Module):